In [1]:
import pickle
import pandas as pd
import re
from tqdm import tqdm
import numpy as np
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from pprint import pprint as pp
import yaml
import plotly.graph_objects as go

# Root directory of the pipeline; every data/plot path below is built by string
# concatenation onto it, so it must keep its trailing slash.
BASE = '/Users/morgunov/batista/Summer/pipeline/'

# SMILES tokenizer: one capturing group whose alternatives are tried left to right.
# Written as a raw string so the regex escapes (\[, \(, \., ...) are not parsed as
# (invalid) Python string escapes; the compiled pattern is unchanged.
# NOTE(review): `@` is listed before `@@`, so a bare `@@` outside brackets is
# tokenized as two `@` tokens; bracket atoms such as `[C@@H]` are unaffected.
REGEX_PATTERN = r"(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|@@|\?|>|!|\*|\$|\%[0-9]{2}|[0-9])"
regex = re.compile(REGEX_PATTERN)

In [2]:
class Graph:
    """Holds figure styling options and applies them to a plotly-style figure.

    Every option is a plain instance attribute, so it can be overridden either
    directly or in bulk through `update_parameters`.
    """

    def __init__(self):
        # Canvas size and colors
        self.width = 600
        self.height = 400
        self.background = "white"
        self.grid_color = "#e2e2e2"
        self.line_color = "#000000"
        self.show_xgrid = False
        self.show_ygrid = False
        # Typography
        self.font_family = 'Helvetica'
        self.text_color = "#333333"
        self.title_size = 20
        self.axis_title_size = 14
        self.tick_font_size = 12
        # Labels (empty until the caller supplies them)
        self.title = ''
        self.xaxis_title = ''
        self.yaxis_title = ''

    def update_parameters(self, params):
        """Set each key of `params` as an attribute on this instance."""
        for name in params:
            setattr(self, name, params[name])

    def style_figure(self, figure):
        """Apply the stored options to `figure` and return the same object.

        Calls `figure.update_layout` with a single layout dict, then
        `figure.update_xaxes` and `figure.update_yaxes` with keyword options.
        """

        def font(size):
            # Fresh dict on every call so no font spec is shared between settings.
            return {'family': self.font_family, 'size': size, 'color': self.text_color}

        def axis_style():
            # Options common to both axes; the line color/width make the axis line visible.
            return dict(
                title_font=font(self.axis_title_size),
                tickfont=font(self.tick_font_size),
                gridwidth=1,
                gridcolor=self.grid_color,
                linecolor=self.line_color,
                linewidth=2,
            )

        layout = {
            'margin': dict(t=50, b=50, l=50, r=50),
            'plot_bgcolor': self.background,
            'paper_bgcolor': self.background,
            'title': {'text': self.title, 'font': font(self.title_size)},
            'height': self.height,
            'width': self.width,
            'font': font(self.tick_font_size),
            'legend': {'font': font(self.tick_font_size)},
        }
        figure.update_layout(layout)

        figure.update_xaxes(
            title=self.xaxis_title,
            showgrid=self.show_xgrid,
            **axis_style()
        )
        figure.update_yaxes(
            title=self.yaxis_title,
            title_standoff=0,
            showgrid=self.show_ygrid,
            **axis_style()
        )
        return figure

In [3]:
class Dataset:
    """Tokenize a CSV of SMILES strings, filter it, and collect token statistics.

    Paths are built by string concatenation, so `base_path` must end with a
    path separator. The input is read from `<base_path>raw_data/<data_path>.csv`
    and must contain a `smiles` column.

    Attributes populated by `load_and_analyze`:
        vocab: set of tokens seen in the retained SMILES.
        token_to_freq: token -> number of occurrences in the retained rows.
        block_sizes: token count of every retained row.
        max_block_size: largest token count among retained rows.
        processed_smiles: set of retained SMILES strings.
    """

    def __init__(self, base_path, data_path, remove_isotopes=False, remove_invalides=True, remove_non_bio_friendly=False, token_len_cutoff=None, token_freq_cutoff=None):
        """Store the configuration and load the exclusion lists the enabled filters need.

        Args:
            base_path: pipeline root (with trailing separator).
            data_path: CSV stem under `raw_data/`; also used in output file names.
            remove_isotopes: drop SMILES containing a token listed in
                `processed_data/isotope_exceptions.yaml`.
            remove_invalides: drop SMILES listed in `processed_data/invalid_smiles.pkl`.
            remove_non_bio_friendly: drop SMILES containing a token listed in
                `processed_data/non_bio_friendly_exceptions.yaml`.
            token_len_cutoff: if set, drop SMILES with more tokens than this.
            token_freq_cutoff: if set, `load_and_analyze` additionally writes a
                copy of the retained set without SMILES containing rarer tokens.
        """
        self.base_path = base_path
        self.data_path = data_path
        self.vocab = set()
        self.max_block_size = 0
        self.block_sizes = []
        self.token_to_freq = {}
        self.remove_isotopes = remove_isotopes
        self.remove_invalides = remove_invalides
        self.remove_non_bio_friendly = remove_non_bio_friendly
        self.plot_suffix = ''  # appended to plot file names, one tag per active filter
        self.token_len_cutoff = token_len_cutoff
        self.token_freq_cutoff = token_freq_cutoff
        if self.remove_isotopes:
            with open(self.base_path + 'processed_data/isotope_exceptions.yaml', 'r') as f:
                self.isotope_exceptions = set(yaml.safe_load(f))
            self.plot_suffix += '_no_isotopes'
        if self.remove_invalides:
            # SECURITY: pickle executes arbitrary code on load; only point this at trusted files.
            with open(self.base_path + 'processed_data/invalid_smiles.pkl', 'rb') as f:
                self.invalid_smiles = pickle.load(f)
            self.plot_suffix += '_no_invalides'
        if self.remove_non_bio_friendly:
            with open(self.base_path + 'processed_data/non_bio_friendly_exceptions.yaml', 'r') as f:
                self.non_bio_friendly_exceptions = set(yaml.safe_load(f))
            self.plot_suffix += '_no_nonbio_friendly'

    def load_and_analyze(self):
        """Read the CSV, apply the enabled filters, and gather token statistics.

        Filters are applied per row in this order: isotopes, non-bio-friendly
        tokens, invalid SMILES, token-length cutoff. The retained set and the
        per-filter exclusion sets are pickled under `processed_data/`.
        """
        self.df = pd.read_csv(self.base_path + 'raw_data/' + self.data_path + '.csv')

        # Start from a clean slate so re-running this method (e.g. re-executing a
        # notebook cell) does not double-count tokens or block sizes.
        self.vocab = set()
        self.max_block_size = 0
        self.block_sizes = []
        self.token_to_freq = {}
        processed_smiles = set()

        pbar_desc = 'Analyzing ' + self.data_path
        if self.remove_isotopes:
            pbar_desc += ' (no isotopes)'
            self.smiles_with_isotopes = set()
        if self.remove_invalides:
            pbar_desc += ' (no invalides)'
            # SMILES actually dropped by the invalid filter; used for the plot statistics.
            self.smiles_excluded_as_invalid = set()
        if self.remove_non_bio_friendly:
            pbar_desc += ' (no non bio friendly)'
            self.smiles_with_non_bio_friendly = set()

        pbar = tqdm(self.df['smiles'].values, total=len(self.df['smiles'].values), desc=pbar_desc)
        for smile in pbar:
            tokens = regex.findall(smile.strip())

            if self.remove_isotopes and any(token in self.isotope_exceptions for token in tokens):
                self.smiles_with_isotopes.add(smile)
                continue
            if self.remove_non_bio_friendly and any(token in self.non_bio_friendly_exceptions for token in tokens):
                self.smiles_with_non_bio_friendly.add(smile)
                continue
            if self.remove_invalides and smile in self.invalid_smiles:
                self.smiles_excluded_as_invalid.add(smile)
                continue
            if self.token_len_cutoff is not None and len(tokens) > self.token_len_cutoff:
                continue

            processed_smiles.add(smile)
            for token in tokens:
                self.vocab.add(token)
                self.token_to_freq[token] = self.token_to_freq.get(token, 0) + 1

            self.max_block_size = max(self.max_block_size, len(tokens))
            self.block_sizes.append(len(tokens))

        # NOTE(review): the two exclusion pickles below do not include `data_path`
        # in their file name, so analyzing a second dataset overwrites the first one's file.
        if self.remove_isotopes:
            with open(self.base_path + 'processed_data/smiles_with_isotopes.pkl', 'wb') as f:
                pickle.dump(self.smiles_with_isotopes, f)
        if self.remove_non_bio_friendly:
            with open(self.base_path + 'processed_data/smiles_with_non_bio_friendly.pkl', 'wb') as f:
                pickle.dump(self.smiles_with_non_bio_friendly, f)
        with open(self.base_path + 'processed_data/processed_smiles_' + self.data_path + '.pkl', 'wb') as f:
            pickle.dump(processed_smiles, f)
        self.processed_smiles = processed_smiles
        if self.token_freq_cutoff is not None:
            self._remove_rare_tokens()

    def _remove_rare_tokens(self):
        """Write the retained SMILES minus those containing a token rarer than `token_freq_cutoff`.

        Rebuilds `vocab` and `max_block_size` from the surviving SMILES.
        `processed_smiles`, `block_sizes` and `token_to_freq` are left as
        computed by `load_and_analyze`.
        """
        new_smiles = set()
        self.max_block_size = 0
        self.vocab = set()

        pbar = tqdm(self.processed_smiles, total=len(self.processed_smiles), desc="Removing rare tokens")
        for smile in pbar:
            tokens = regex.findall(smile.strip())
            if any(self.token_to_freq[token] < self.token_freq_cutoff for token in tokens):
                continue
            new_smiles.add(smile)
            self.vocab.update(tokens)
            self.max_block_size = max(self.max_block_size, len(tokens))

        with open(self.base_path + 'processed_data/processed_smiles_' + self.data_path + '_no_rare_tokens.pkl', 'wb') as f:
            pickle.dump(new_smiles, f)

    def plot_distribution(self, block_sizes):
        """Write an HTML histogram of `block_sizes` with percentile markers and filter statistics.

        Must be called after `load_and_analyze`, which provides `self.df` and the
        exclusion sets reported in the statistics box.
        """
        graph = Graph()
        fig = go.Figure()
        fig.add_trace(go.Histogram(x=block_sizes, name=self.data_path, histnorm='probability density'))
        fig.update_layout(barmode='overlay', xaxis=dict(dtick=20))
        fig.update_traces(opacity=0.75)
        percentiles = [25, 50, 75, 95, 99, 99.9, 99.99]
        for i, percentile in enumerate(percentiles):
            val = np.percentile(block_sizes, percentile)
            # Stagger the label heights (cycling every 5) so neighbouring annotations do not overlap.
            y_anno = -40 * (1 + i % 5)
            fig.add_annotation(x=val, y=0, text=f'{percentile} percentile<br>{val:.0f} blocks',
                showarrow=True, arrowhead=1, ax=-10, ay=y_anno)
        total_rows = len(self.df["smiles"])
        stats_text = f"Total smiles: {len(block_sizes)}<br>Vocab Size: {len(self.vocab)}<br>Max Block Size: {self.max_block_size}"
        if self.remove_isotopes:
            isotope_fraction = len(self.smiles_with_isotopes) / total_rows
            stats_text += f"<br># excluded (isotope-containing): {len(self.smiles_with_isotopes)} or {isotope_fraction:.2%}"
        if self.remove_non_bio_friendly:
            non_bio_friendly_fraction = len(self.smiles_with_non_bio_friendly) / total_rows
            stats_text += f"<br># excluded (non bio friendly): {len(self.smiles_with_non_bio_friendly)} or {non_bio_friendly_fraction:.2%}"
        if self.remove_invalides:
            # Report what this dataset actually lost to the invalid filter, not the
            # size of the whole invalid list (which is shared across datasets).
            n_invalid = len(self.smiles_excluded_as_invalid)
            invalid_fraction = n_invalid / total_rows
            stats_text += f"<br># excluded (invalid): {n_invalid} or {invalid_fraction:.2%}"
        fig.add_annotation(x=0.75, y=1, text=stats_text,
                            showarrow=False, xref='paper', yref='paper', xanchor='left', yanchor='top', font=dict(size=12))
        graph.update_parameters({'title': f'Block size distribution for {self.data_path} partition. Isotopes removed: {self.remove_isotopes}',
                                'xaxis_title': 'Block size', 'yaxis_title': 'Probability density',
                                'width': 1280, 'height': 720, 'show_xgrid': True})
        graph.style_figure(fig)
        fig.write_html(self.base_path + f'plots/block_size_distribution_{self.data_path}{self.plot_suffix}.html', include_plotlyjs='cdn')

    def plot_token_frequencies(self, top_n=20):
        """Write an HTML bar chart of the `top_n` most frequent tokens (`None` plots all of them)."""
        ranked = sorted(self.token_to_freq.items(), key=lambda item: item[1], reverse=True)[:top_n]
        tokens = [token for token, _ in ranked]
        counts = [count for _, count in ranked]

        graph = Graph()
        fig = go.Figure(data=[go.Bar(x=tokens, y=counts, name=self.data_path)])
        # Avoid a "Top None ..." title when the caller asked for every token.
        scope = f'Top {top_n}' if top_n is not None else 'All'
        graph.update_parameters({'title': f'{scope} character frequencies for {self.data_path} partition. Isotopes removed: {self.remove_isotopes}',
                                'xaxis_title': 'Token', 'yaxis_title': 'Frequency',
                                'width': 1280, 'height': 720})

        graph.style_figure(fig)
        fig.write_html(self.base_path + f'plots/token_frequencies_{self.data_path}{self.plot_suffix}.html', include_plotlyjs='cdn')

    def export_vocab(self):
        """Write the vocabulary to `<base_path>vocab_<data_path>.txt`, one token per line.

        Tokens are written in sorted order so the file is identical across runs
        (set iteration order changes between interpreter sessions).
        """
        with open(self.base_path + f'vocab_{self.data_path}.txt', 'w') as f:
            for token in sorted(self.vocab):
                f.write(token + '\n')


In [4]:
# Analyze the merged MOSES + BindingDB set with every filter on and no length/frequency cutoff,
# then display the resulting vocabulary size and longest token sequence.
config = {
    'remove_isotopes': True,
    'remove_invalides': True,
    'remove_non_bio_friendly': True,
    'token_len_cutoff': None,
    'token_freq_cutoff': None,
}
binding_db = Dataset(BASE, 'moses_and_binding', **config)
binding_db.load_and_analyze()
(len(binding_db.vocab), binding_db.max_block_size)

Analyzing moses_and_binding (no isotopes) (no invalides) (no non bio friendly): 100%|██████████| 2895566/2895566 [00:30<00:00, 93516.95it/s]


(44, 130)

In [5]:
# Rank tokens from most to least frequent (ties keep their original order) and display the mapping.
sorted_token_freq = {
    token: binding_db.token_to_freq[token]
    for token in sorted(binding_db.token_to_freq, key=binding_db.token_to_freq.get, reverse=True)
}
sorted_token_freq

{'c': 31151990,
 'C': 21416624,
 '(': 12305744,
 ')': 12305744,
 '1': 9910896,
 'O': 6934241,
 '2': 5810532,
 'N': 4916586,
 '=': 4629659,
 'n': 4381610,
 '3': 1963472,
 'F': 1364125,
 '-': 898082,
 'S': 590858,
 'Cl': 511539,
 's': 437893,
 '[C@H]': 429608,
 '[nH]': 411651,
 'o': 397557,
 '[C@@H]': 395138,
 '4': 359984,
 '#': 261454,
 '\\': 132287,
 'Br': 100460,
 '5': 47774,
 '[C@]': 47130,
 '[C@@]': 43568,
 '[O-]': 33362,
 '/': 30583,
 '[N+]': 27785,
 'P': 17401,
 '.': 9663,
 '6': 6030,
 'I': 5902,
 '[n+]': 5363,
 'B': 2549,
 '[S+]': 1468,
 '[C-]': 1237,
 '[N-]': 1235,
 '[Si]': 1165,
 '[Na+]': 1164,
 '[NH3+]': 1080,
 '7': 684,
 '[H]': 251}

In [None]:
# Build the union of the filtered BindingDB set and the MOSES train/test SMILES.
# SECURITY: pickle executes arbitrary code on load; only use with trusted files.
with open(BASE + 'processed_data/processed_smiles_bindingDB_07_11_no_rare_tokens.pkl', 'rb') as f:
    binding_set = pickle.load(f)
moses_train_set = set(pd.read_csv(BASE + 'raw_data/train.csv.gz', compression='gzip')['SMILES'].tolist())
moses_test_set = set(pd.read_csv(BASE + 'raw_data/test.csv.gz', compression='gzip')['SMILES'].tolist())
moses_and_binding = binding_set | moses_train_set | moses_test_set
# Sort before writing: set iteration order differs between kernel sessions, which
# would otherwise make the CSV row order irreproducible.
# NOTE(review): this writes to processed_data/, while Dataset reads
# raw_data/<name>.csv -- confirm the file is moved/copied before it is analyzed.
pd.DataFrame({'smiles': sorted(moses_and_binding)}).to_csv(BASE + 'processed_data/moses_and_binding.csv')

In [None]:
# Analyze the train and validation partitions with every filter on, dropping
# SMILES longer than 130 tokens. Dataset.__init__ has no `cutoff` parameter
# (passing it raises TypeError); the length limit is `token_len_cutoff`.
config = dict(remove_isotopes=True, remove_invalides=True, remove_non_bio_friendly=True, token_len_cutoff=130)
train_data = Dataset(BASE, 'train', **config)
train_data.load_and_analyze()
val_data = Dataset(BASE, 'val', **config)
val_data.load_and_analyze()

In [None]:
# Load the previously saved "crumbles" pickles and report their sizes.
# NOTE(review): nothing visible in this notebook writes these files -- confirm they
# are still produced elsewhere, otherwise this cell fails on a fresh run.
# SECURITY: pickle executes arbitrary code on load; only use with trusted files.
with open(BASE + 'processed_data/smiles_crumbles_train.pkl', 'rb') as f:
    smile_crumbles_train = pickle.load(f)
with open(BASE + 'processed_data/smiles_crumbles_val.pkl', 'rb') as f:
    smile_crumbles_val = pickle.load(f)
print(len(smile_crumbles_train), len(smile_crumbles_val))

In [None]:
def create_train_val_partition(fname, validation_fraction=0.05, seed=42):
    """Split a pickled collection of SMILES into train/validation CSVs.

    Reads `<BASE>processed_data/<fname>.pkl` and writes
    `<fname>_train.csv.gz` and `<fname>_val.csv.gz` next to it, each with a
    single `smiles` column.

    Args:
        fname: stem of the pickle file under `processed_data/`.
        validation_fraction: share of molecules placed in the validation file;
            must lie in [0, 1].
        seed: seed for the shuffle, so the same input always gives the same split.

    Raises:
        ValueError: if `validation_fraction` is outside [0, 1].
    """
    import random  # local import: keeps this cell self-contained

    if not 0 <= validation_fraction <= 1:
        raise ValueError(f"validation_fraction must be in [0, 1], got {validation_fraction}")

    # SECURITY: pickle executes arbitrary code on load; only use with trusted files.
    with open(BASE + f'processed_data/{fname}.pkl', 'rb') as f:
        all_mols = pickle.load(f)

    # Set iteration order changes between interpreter sessions; sort first, then
    # shuffle with a fixed seed so the partition is random yet reproducible.
    all_mols_list = sorted(all_mols)
    random.Random(seed).shuffle(all_mols_list)

    validation_index = int((1 - validation_fraction) * len(all_mols_list))
    train_mols = all_mols_list[:validation_index]
    val_mols = all_mols_list[validation_index:]
    print(f"{validation_index=}")
    pd.DataFrame({'smiles': train_mols}).to_csv(BASE + f'processed_data/{fname}_train.csv.gz', compression='gzip')
    pd.DataFrame({'smiles': val_mols}).to_csv(BASE + f'processed_data/{fname}_val.csv.gz', compression='gzip')

# Split the rare-token-filtered MOSES + BindingDB set using the default validation fraction.
create_train_val_partition('processed_smiles_moses_and_binding_no_rare_tokens')

In [None]:
# Sweep over progressively stricter filter combinations and regenerate the
# block-size and token-frequency plots for both partitions under each one.
configs = [
    {'remove_isotopes': isotopes, 'remove_invalides': invalides, 'remove_non_bio_friendly': non_bio}
    for isotopes, invalides, non_bio in [
        (False, False, False),
        (False, True, False),
        (True, True, False),
        (True, True, True),
    ]
]
for config in configs:
    print(config)

    # Training partition
    train_data = Dataset(BASE, 'train', **config)
    train_data.load_and_analyze()
    train_data.plot_distribution(train_data.block_sizes)
    train_data.plot_token_frequencies(top_n=None)

    # Validation partition
    val_data = Dataset(BASE, 'val', **config)
    val_data.load_and_analyze()
    val_data.plot_distribution(val_data.block_sizes)
    val_data.plot_token_frequencies(top_n=None)