# Create a scaffold splitter and visualize its projection vs. a random split

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules (including the
# project's `src` package) are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2

In [106]:
import os

import dgl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from umap import UMAP

from src.modules.molecules.dgllife_gat import GATPretrainedWithLinearHead, dgl_canonical_featurizer
from src.modules.molecules.dgllife_gin import GINPretrainedWithLinearHead, dgl_pretrained_featurizer
from src.splitters import RandomSplitter, ScaffoldSplitter

OSError: /home/gwatk/miniconda3/envs/jump_models/lib/python3.10/site-packages/torch_sparse/_version_cuda.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE

In [21]:
# Directory holding train_ids.csv / test_ids.csv / val_ids.csv. The path is relative, so it
# resolves against the kernel's working directory.
split_path = "../cpjump1/jump/models/splits/bigger_jump_cl"

In [22]:
# Take the first column of each split CSV as a plain Python list of compound identifiers.
train_cpds = pd.read_csv(f"{split_path}/train_ids.csv").iloc[:, 0].tolist()
test_cpds = pd.read_csv(f"{split_path}/test_ids.csv").iloc[:, 0].tolist()
val_cpds = pd.read_csv(f"{split_path}/val_ids.csv").iloc[:, 0].tolist()

# Pool the three existing splits (train, then test, then val) into one list that the
# splitters below re-split.
compound_list = train_cpds + test_cpds + val_cpds

In [97]:
# Scaffold splitter over the pooled compounds; the requested sizes sum to 30000.
# NOTE(review): presumably the sizes must add up to len(compound_list) — confirm against ScaffoldSplitter.
splitter = ScaffoldSplitter(train=25000, test=3000, val=2000, compound_list=compound_list)

In [98]:
# Scaffold groups from the splitter. Later cells index this by position and pass each entry
# to the embedding helper as a collection of molecule strings.
scaffolds = splitter.generate_scaffolds()

In [91]:
# GAT encoder initialised from the "GAT_canonical_PCBA" checkpoint (downloaded on first use,
# per the cell output).
# NOTE(review): out_dim is presumably the size of the linear head's output and preload the
# switch for loading pretrained weights — confirm in GATPretrainedWithLinearHead.
model = GATPretrainedWithLinearHead(
    "GAT_canonical_PCBA",
    out_dim=512,
    preload=True,
)

Downloading GAT_canonical_PCBA_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gat_canonical_pcba.pth...
Pretrained model loaded


In [92]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [94]:
def get_emb_from_smiles(smiles, model, featurizer, device, verbose=False, batch_size=64):
    """Embed a collection of molecule strings with ``model`` and return the stacked features.

    Args:
        smiles: Iterable of molecule strings accepted by ``featurizer``.
            NOTE(review): callers in this notebook appear to pass InChI strings rather
            than SMILES — confirm the featurizer handles them.
        model: Object exposing ``to(device)``, ``eval()`` and ``extract(batched_graph)``.
        featurizer: Callable turning one molecule string into a graph that ``dgl.batch``
            can collate.
        device: Torch device the model and every batch are moved to.
        verbose: When True, wrap the batch loop in a tqdm progress bar.
        batch_size: Number of graphs per forward pass. Defaults to 64, the value that
            was previously hard-coded.

    Returns:
        np.ndarray: Per-molecule features concatenated along the first axis, in the same
        order as ``smiles`` (the loader does not shuffle).

    Raises:
        ValueError: If ``smiles`` is empty.
    """
    model.to(device)
    model.eval()

    graphs = [featurizer(smile) for smile in smiles]
    if not graphs:
        # np.concatenate([]) would raise a ValueError anyway; fail with a message that
        # points at the actual cause instead.
        raise ValueError("smiles is empty: there is nothing to embed")

    dl = DataLoader(graphs, batch_size=batch_size, shuffle=False, collate_fn=dgl.batch)
    batches = tqdm(dl) if verbose else dl

    feats = []
    # Inference only: disabling autograd avoids building a graph for every batch.
    with torch.no_grad():
        for b in batches:
            f = model.extract(b.to(device))
            feats.append(f.detach().cpu().numpy())

    return np.concatenate(feats)

In [95]:
# Pick n_ex scaffold positions: every third one, starting at position 2.
n_ex = 100
idx = [2 + 3 * k for k in range(n_ex)]

In [100]:
# Embed each selected scaffold group with the GAT model: one feature array per group.
scaffold_feats = [
    get_emb_from_smiles(
        smiles=scaffolds[scaffold_pos],
        model=model,
        featurizer=dgl_canonical_featurizer,
        device=device,
        verbose=False,
    )
    for scaffold_pos in tqdm(idx, leave=False)
]

# Label every embedded molecule with the position of its group within `idx`
# (0, 1, 2, ... as floats), not with the scaffold's own position in `scaffolds`.
scaffold_index = np.concatenate(
    [np.full(group.shape[0], float(group_pos)) for group_pos, group in enumerate(scaffold_feats)]
)

# Feature matrix and matching labels for the projections below.
X = np.concatenate(scaffold_feats)
y = scaffold_index

  0%|          | 0/100 [00:00<?, ?it/s]

In [101]:
# Project the scaffold embeddings to 2-D with three methods. Every estimator is seeded
# (as the full-dataset projections later in this notebook are), so re-running the cell
# reproduces the same layout instead of a fresh random one.
tsne_proj = TSNE(n_components=2, perplexity=50, n_jobs=-1, random_state=42).fit_transform(X)
pca_proj = PCA(n_components=2, random_state=42).fit_transform(X)
umap_proj = UMAP(n_components=2, n_neighbors=50, random_state=42).fit_transform(X)

In [102]:
# One row per embedded molecule: the 2-D coordinates from each projection, the scaffold
# label rendered as a string, and the molecule string itself (in `idx` order).
df = pd.DataFrame(tsne_proj, columns=["tsne-x", "tsne-y"]).assign(
    **{
        "pca-x": pca_proj[:, 0],
        "pca-y": pca_proj[:, 1],
        "umap-x": umap_proj[:, 0],
        "umap-y": umap_proj[:, 1],
        "scaffold": y.astype(int).astype(str),
        "inchi": np.concatenate([scaffolds[pos] for pos in idx]),
    }
)

In [105]:
# Choose which projection to plot: the frame has "tsne", "pca" and "umap" column pairs.
proj_type = "tsne"
# Colour points by the string-valued scaffold label. The title names the projection and
# the model so the figure can be read on its own.
px.scatter(
    df,
    x=f"{proj_type}-x",
    y=f"{proj_type}-y",
    color="scaffold",
    height=800,
    width=800,
    title=f"{proj_type.upper()} projection of GAT embeddings colored by scaffold",
)

In [79]:
# Random splitter with the same sizes as the scaffold splitter. `splitter` is reassigned
# here and again in the cells that follow, so its value depends on which cell ran last.
splitter = RandomSplitter(train=25000, test=3000, val=2000, compound_list=compound_list)
# Alternative left in place for quickly switching the splitter type:
# splitter = ScaffoldSplitter(train=25000, test=3000, val=2000, compound_list=compound_list)

In [102]:
# Random baseline: split() is unpacked as (train, test, val) collections of compounds.
# NOTE(review): no seed is passed here — confirm whether RandomSplitter is deterministic,
# otherwise the random-split colouring changes on every run.
splitter = RandomSplitter(train=25000, test=3000, val=2000, compound_list=compound_list)
random_train_cpds, random_test_cpds, random_val_cpds = splitter.split()

In [104]:
# Scaffold-based split with the same sizes as the random split; split() is unpacked as
# (train, test, val) collections of compounds.
splitter = ScaffoldSplitter(train=25000, test=3000, val=2000, compound_list=compound_list)
scaffold_train_cpds, scaffold_test_cpds, scaffold_val_cpds = splitter.split()

In [86]:
# Featurize every pooled compound into a graph, keeping `compound_list` order.
# NOTE(review): this uses dgl_pretrained_featurizer, while the only model built in this
# notebook (the GAT one) is paired with dgl_canonical_featurizer in the scaffold cell —
# confirm which model these graphs are meant for.
graphs = [dgl_pretrained_featurizer(comp) for comp in tqdm(compound_list, leave=False)]

  0%|          | 0/30000 [00:00<?, ?it/s]

In [95]:
# Embed all compound graphs in batches of 128; shuffle=False keeps `compound_list` order so
# the rows of `feats` line up with the compounds.
dl = DataLoader(graphs, batch_size=128, shuffle=False, collate_fn=dgl.batch)

# Put the model on the target device and in eval mode here, rather than relying on an
# earlier cell having already done so.
model.to(device)
model.eval()

feat_batches = []
# Inference only: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for batch in tqdm(dl):
        batch_features = model.extract(batch.to(device))
        feat_batches.append(batch_features.detach().cpu().numpy())

feats = np.concatenate(feat_batches)

  0%|          | 0/235 [00:00<?, ?it/s]

In [98]:
# 2-D projections of the full embedding matrix, both seeded with random_state=42.
# These reassign `tsne_proj` and `pca_proj` from the scaffold section above.
tsne_proj = TSNE(n_components=2, perplexity=50, n_jobs=-1, random_state=42).fit_transform(feats)
pca_proj = PCA(n_components=2, random_state=42).fit_transform(feats)

In [105]:
# Plotting frame for the full dataset: t-SNE and PCA coordinates plus the compound that
# each row belongs to.
df = pd.DataFrame(tsne_proj, columns=["tsne-1", "tsne-2"]).assign(
    **{"pca-1": pca_proj[:, 0], "pca-2": pca_proj[:, 1]}
)
df["compound"] = compound_list

# For each split strategy, default every row to "train" and then overwrite the test and
# val members. Val is written last, so it wins if a compound is listed in both.
split_members = {
    "random_split": (random_test_cpds, random_val_cpds),
    "scaffold_split": (scaffold_test_cpds, scaffold_val_cpds),
}
for split_column, (test_ids, val_ids) in split_members.items():
    df[split_column] = "train"
    df.loc[df["compound"].isin(test_ids), split_column] = "test"
    df.loc[df["compound"].isin(val_ids), split_column] = "val"

In [106]:
# Show the frame as the cell output to inspect coordinates and assigned split labels.
df

Unnamed: 0,tsne-1,tsne-2,pca-1,pca-2,compound,random_split,scaffold_split
0,-95.997498,15.171036,-0.035411,-0.044989,InChI=1S/C15H20N4OS/c1-15(6-3-7-18-8-15)10-5-4...,train,train
1,-95.896378,15.168567,0.204280,0.126972,"InChI=1S/C17H22N4O2S/c1-10(22)21-8-4-7-17(2,9-...",test,train
2,-95.879669,15.153924,0.009899,0.105141,"InChI=1S/C18H24N4O2S/c1-4-13(23)22-9-5-8-18(2,...",train,train
3,-95.937012,15.237575,0.240354,0.045607,InChI=1S/C19H26N4O2S/c1-5-14(24)23-10-6-9-19(2...,train,train
4,-95.992752,15.190013,0.290126,-0.164314,InChI=1S/C19H27N5O2S/c1-19(8-5-9-24(11-19)14(2...,train,train
...,...,...,...,...,...,...,...
29995,18.902733,-60.993237,-0.457270,-0.317803,InChI=1S/C15H10N2O3/c18-15-12-8-4-5-9-13(12)16...,val,train
29996,18.786655,-60.954170,-0.729785,-0.063444,InChI=1S/C19H18N2O2/c1-13(22)21-17-7-5-4-6-16(...,train,train
29997,-50.927944,6.364256,0.566371,0.107061,InChI=1S/C18H15N5O2/c24-18-14-7-4-9-19-17(14)2...,test,train
29998,48.628563,17.562490,0.368874,-0.817282,InChI=1S/C20H25NO2/c1-16-13-23-15-20(19-9-4-3-...,train,train


In [113]:
# Pick the projection ("tsne" or "pca") and the split labelling ("random" or "scaffold");
# both are used to build column names that must exist in `df`.
proj_type = "tsne"
split_type = "random"
# NOTE(review): the title says "GIN embeddings", but the only model constructed in this
# notebook is the GAT one — confirm which model produced `feats` before trusting the title.
px.scatter(
    df,
    x=f"{proj_type}-1",
    y=f"{proj_type}-2",
    color=f"{split_type}_split",
    height=1000,
    width=1200,
    title=f"{proj_type.upper()} projection of GIN embeddings colored by {split_type} split",
)