In [1]:
import umap
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
from __future__ import annotations

In [120]:
from __future__ import annotations
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import umap
import umap.plot
from bokeh.plotting import figure, show, output_notebook, save, output_file
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Spectral4
from rdkit import Chem
from rdkit import Chem, RDLogger
from rdkit.Chem.Draw import rdMolDraw2D, MolDraw2DSVG, rdDepictor
# Raise the RDKit logger threshold to CRITICAL so lower-severity messages
# (e.g. parse failures reported while filtering SMILES) do not flood the output.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Bokeh notebook output mode; figures are additionally written to HTML via save().
output_notebook()

class Projector(object):
    """Project generated, training and (optionally) test SMILES into one UMAP map.

    On construction the three SMILES collections are cleaned:

    1. entries RDKit cannot parse are dropped,
    2. each collection is de-duplicated (first occurrence wins, order kept),
    3. generated SMILES that also occur in the training or test collection
       are removed from ``gen_df``.

    NOTE: duplicates are detected on the raw SMILES *strings*; two different
    spellings of the same molecule are not treated as duplicates.

    Attributes set by ``__init__``:
        gen, train_ds, test_ds: the cleaned, de-duplicated lists.
        gen_df, train_df, test_df: single-column frames (``'gen'``, ``'train'``,
            ``'test'``); ``test_df`` only exists when a non-empty test set remains.

    Attributes set by ``project``:
        train_fp, gen_fp, test_fp: fingerprints returned by ``get_fp``.
        The frames gain ``x``/``y`` columns and, when ``view=True``, the SMILES
        column is renamed to ``'smiles'`` and ``label``/``image`` columns are added.
    """

    # Default target of the saved Bokeh page (the path the notebook originally used).
    DEFAULT_OUTPUT_PATH = "./projection/irak4_spe-20221212_1512/prj-epoch-5.html"

    def __init__(self, gen: list, train_ds: list, test_ds=None):
        """Store the inputs, clean them and print a size summary.

        Args:
            gen: generated SMILES.
            train_ds: training SMILES.
            test_ds: optional test SMILES; ``None`` or an empty list disables
                every test-set step.

        Raises:
            AssertionError: if ``gen`` or ``train_ds`` is empty, or if nothing
                is left of the generated/training data after cleaning.
        """
        self.gen = gen
        self.train_ds = train_ds
        self.test_ds = test_ds

        assert len(self.gen) > 0 and len(self.train_ds) > 0

        self.__remove_invalid()
        self.__remove_duplicate()
        self._report()

    @property
    def _has_test(self):
        """True when a non-empty test collection is present (single source of truth)."""
        return self.test_ds is not None and len(self.test_ds) > 0

    @staticmethod
    def _keep_parsable(smiles):
        """Return the entries of ``smiles`` that RDKit can parse.

        Every entry is passed through ``str()`` first so non-string values
        (e.g. NaN read from a CSV) are rejected instead of raising.
        """
        return [smi for smi in smiles if Chem.MolFromSmiles(str(smi)) is not None]

    def __remove_invalid(self):
        """Drop unparsable SMILES from the generated, training and test lists."""
        self.gen = self._keep_parsable(self.gen)
        self.train_ds = self._keep_parsable(self.train_ds)

        if self._has_test:
            self.test_ds = self._keep_parsable(self.test_ds)

    def __remove_duplicate(self):
        """De-duplicate each collection and remove known molecules from ``gen_df``.

        ``dict.fromkeys`` keeps the first occurrence and the input order, so the
        row order (and therefore the UMAP input) is the same on every run.
        Membership is tested with ``Series.isin`` instead of a cross join, which
        avoids materialising ``len(gen) * len(train)`` rows.
        """
        self.gen = list(dict.fromkeys(self.gen))
        self.gen_df = pd.DataFrame(self.gen, columns=['gen'], dtype='object')

        self.train_ds = list(dict.fromkeys(self.train_ds))
        self.train_df = pd.DataFrame(self.train_ds, columns=['train'], dtype='object')

        known = set(self.train_ds)

        if self._has_test:
            self.test_ds = list(dict.fromkeys(self.test_ds))
            self.test_df = pd.DataFrame(self.test_ds, columns=['test'], dtype='object')
            known.update(self.test_ds)

        is_novel = ~self.gen_df['gen'].isin(known)
        self.gen_df = self.gen_df.loc[is_novel].reset_index(drop=True)

    def __len__(self):
        """Number of generated SMILES left after cleaning (rows of ``gen_df``)."""
        return int(self.gen_df.shape[0])

    def _report(self):
        """Assert that generated and training data are non-empty and print their sizes."""
        assert self.gen_df.shape[0] > 0 and self.train_df.shape[0] > 0
        print(f"Generated data: #{self.gen_df.shape[0]:,}. \nTraining data: #{self.train_df.shape[0]:,}.")
        if self._has_test:
            print(f"Test data: #{self.test_df.shape[0]:,}")

    def get_fp(self, sms, radius=3, n_bits=2048):
        """Compute Morgan bit-vector fingerprints for a sequence of SMILES.

        Args:
            sms: iterable of SMILES.
            radius: Morgan radius passed to RDKit (default 3).
            n_bits: fingerprint length passed to RDKit as ``nBits`` (default 2048).

        Returns:
            A list with one entry per SMILES; each entry is a list of the
            fingerprint bits.

        Raises:
            ValueError: if a SMILES cannot be parsed.
        """
        # Imported here so the method does not depend on another cell having
        # imported rdkit.Chem.AllChem first.
        from rdkit.Chem import AllChem

        fingerprints = []
        for smi in sms:
            mol = Chem.MolFromSmiles(str(smi))
            if mol is None:
                raise ValueError(f"Cannot fingerprint invalid SMILES: {smi!r}")
            bit_vect = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)
            fingerprints.append(list(np.array(bit_vect)))
        return fingerprints

    def init_map(self):
        """Create the Bokeh figure with a hover tool showing class, SMILES and image."""
        # NOTE(review): `plot_width`/`plot_height` are the older Bokeh keyword names;
        # newer releases use `width`/`height` -- confirm against the installed version.
        ds_plot_figure = figure(
            title='UMAP',
            plot_width=1000,
            plot_height=600,
            tools=('pan, wheel_zoom, reset, box_select, save')
        )

        # Field names must match the ColumnDataSource keys built in plot_map.
        TOOLTIPS = [
            ("Class", "@label"),
            ("SMILES", "@SMILES"),
            ("", "@image{safe}")
        ]

        ds_plot_figure.add_tools(HoverTool(tooltips=TOOLTIPS))
        return ds_plot_figure

    def get_mol_graph(self, smi):
        """Return the SVG text of the molecule drawn from ``smi``."""
        return self._mol2svg(Chem.MolFromSmiles(str(smi)))

    def _mol2svg(self, mol):
        """Draw ``mol`` on a 200x100 SVG canvas and return the SVG text."""
        d2d = rdMolDraw2D.MolDraw2DSVG(200, 100)
        d2d.DrawMolecule(mol)
        d2d.FinishDrawing()
        return d2d.GetDrawingText()

    def save_map(self, figure, path=None):
        """Write ``figure`` to an HTML file.

        Args:
            figure: the Bokeh figure to save.
            path: target file; defaults to ``DEFAULT_OUTPUT_PATH``. Missing
                parent directories are created.
        """
        import os

        target = self.DEFAULT_OUTPUT_PATH if path is None else path
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        output_file(target)
        save(figure)

    def plot_map(self, figure, data, label, output_path=None):
        """Add one scatter layer per data frame to ``figure`` and save it.

        Args:
            figure: Bokeh figure, e.g. from ``init_map``.
            data: list of frames with ``x``, ``y``, ``label``, ``image`` and
                ``smiles`` columns.
            label: legend names, one per frame.
            output_path: optional HTML target forwarded to ``save_map``.
        """
        if len(label) == 2:
            palette = ['green', 'blue']
        else:
            palette = ['green', 'blue', 'pink']

        # Distinct loop names: the parameters `data`/`label` are not rebound.
        for frame, legend_name, colour in zip(data, label, palette):
            source = ColumnDataSource(data=dict(
                x=frame['x'].values,
                y=frame['y'].values,
                label=frame['label'].values,
                image=frame['image'].values,
                SMILES=frame['smiles'].values
            ))

            figure.circle(
                x='x',
                y='y',
                source=source,
                color=colour,
                line_alpha=0.6,
                fill_alpha=0.6,
                size=6,
                muted_alpha=0.1,
                legend_label=legend_name
            )

        figure.legend.location = "top_right"
        figure.legend.click_policy = "mute"

        self.save_map(figure, output_path)

    @staticmethod
    def _smiles_column(frame, original_name):
        """Name of the SMILES column: ``original_name`` or, after a view run, ``'smiles'``."""
        return original_name if original_name in frame.columns else 'smiles'

    @staticmethod
    def _with_embedding(frame, embedding):
        """Return a copy of ``frame`` whose ``x``/``y`` columns hold the 2-D embedding.

        Assigning (rather than concatenating) overwrites coordinates from an
        earlier run instead of appending duplicate columns.
        """
        coords = np.asarray(embedding)
        out = frame.copy()
        out['x'] = coords[:, 0]
        out['y'] = coords[:, 1]
        return out

    def project(self, view=False, output_path=None):
        """Fit UMAP on the training fingerprints and embed test and generated data.

        Args:
            view: when True, also build the Bokeh scatter plot and save it.
            output_path: optional HTML target used when ``view`` is True.

        The method can be called repeatedly; coordinates are overwritten.
        """
        train_col = self._smiles_column(self.train_df, 'train')
        gen_col = self._smiles_column(self.gen_df, 'gen')

        # Fingerprints
        self.train_fp = self.get_fp(self.train_df[train_col].to_list())
        self.gen_fp = self.get_fp(self.gen_df[gen_col].to_list())
        if self._has_test:
            test_col = self._smiles_column(self.test_df, 'test')
            self.test_fp = self.get_fp(self.test_df[test_col].to_list())

        ds_reducer = umap.UMAP(
            random_state=42,
        )

        # The reducer is fitted on the training set only ...
        train_embed = ds_reducer.fit_transform(X=self.train_fp)
        self.train_df = self._with_embedding(self.train_df, train_embed)

        # ... and test / generated molecules are mapped into that fitted space.
        if self._has_test:
            test_embed = ds_reducer.transform(X=self.test_fp)
            self.test_df = self._with_embedding(self.test_df, test_embed)

        gen_embed = ds_reducer.transform(X=self.gen_fp)
        self.gen_df = self._with_embedding(self.gen_df, gen_embed)

        if not view:
            return

        ds_plot_figure = self.init_map()

        # rename() is a no-op when the column was already renamed by a previous run.
        self.gen_df = self.gen_df.rename(columns={'gen': 'smiles'})
        self.gen_df['label'] = 'gen'
        self.gen_df['image'] = self.gen_df['smiles'].apply(self.get_mol_graph)

        self.train_df = self.train_df.rename(columns={'train': 'smiles'})
        self.train_df['label'] = 'train'
        self.train_df['image'] = self.train_df['smiles'].apply(self.get_mol_graph)

        data = [self.gen_df, self.train_df]
        label = ['Gen', 'Train']

        if self._has_test:
            self.test_df = self.test_df.rename(columns={'test': 'smiles'})
            self.test_df['label'] = 'test'
            self.test_df['image'] = self.test_df['smiles'].apply(self.get_mol_graph)
            data.append(self.test_df)
            label.append('Test')

        self.plot_map(ds_plot_figure, data, label, output_path)
            

In [121]:
# Input locations. An earlier run sampled from
# './generation/gptneo_irak4_spe_20221103-1811-checkpoint-500/10_000_gen_sample.csv'.
gen_csv = './generation/irak4_spe/checkpoint-70-30_000_gen_sample.csv'
train_txt = '../GPT2/dataset/IRAK4/irak4_train.txt'
test_txt = '../GPT2/dataset/IRAK4/irak4_test.txt'

# Generated samples are kept as a DataFrame; its first column is used later.
gen = pd.read_csv(gen_csv)

# Train/test files are read without a header row and reduced to their first
# column as plain Python lists.
train, test = (
    pd.read_csv(txt_path, header=None).iloc[:, 0].to_list()
    for txt_path in (train_txt, test_txt)
)

print(f"Training set: {len(train):,} counts")
print(f"Test set: {len(test):,} counts")

Training set: 840 counts
Test set: 10 counts


In [122]:
# The first column of the generated-sample table holds the SMILES to project.
generated_smiles = gen.iloc[:, 0].to_list()

projector = Projector(gen=generated_smiles, train_ds=train, test_ds=test)
projector.project(view=True)

Generated data: #1,024. 
Training data: #840.
Test data: #10
