In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules (e.g. the local `src` package) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [103]:
import os

import dgl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from src.modules.molecules.dgllife_pretrained import GINPretrainedWithLinearHead, dgl_mol_featurizer
from src.splitters import RandomSplitter, ScaffoldSplitter

In [63]:
# Directory holding the saved split id files (train_ids.csv, test_ids.csv,
# val_ids.csv) that later cells read. The path is relative, so it resolves
# against the notebook's working directory.
split_path = "../cpjump1/jump/models/splits/bigger_jump_cl"

In [68]:
# Read the first column of each saved split file, in train / test / val order.
id_files = (
    f"{split_path}/train_ids.csv",
    f"{split_path}/test_ids.csv",
    f"{split_path}/val_ids.csv",
)
train_cpds, test_cpds, val_cpds = (
    pd.read_csv(id_file).iloc[:, 0].tolist() for id_file in id_files
)

# Pool every compound id into a single list (train first, then test, then val).
compound_list = [*train_cpds, *test_cpds, *val_cpds]

In [110]:
# Build a ScaffoldSplitter over the pooled compounds with fixed split sizes.
# NOTE(review): the next cell rebinds `splitter` to a RandomSplitter, so on a
# top-to-bottom run this instance is replaced before `splitter.split()` is
# called. The recorded execution counts show this cell was run last — confirm
# which splitter is intended and drop the other cell.
scaffold_split_sizes = {"train": 25000, "test": 3000, "val": 2000}
splitter = ScaffoldSplitter(**scaffold_split_sizes, compound_list=compound_list)

In [99]:
# Build a RandomSplitter over the pooled compounds with fixed split sizes.
# NOTE(review): this rebinds the `splitter` created in the previous cell
# (a ScaffoldSplitter), so on a top-to-bottom run this is the one that
# `splitter.split()` uses — confirm which splitter is intended.
random_split_sizes = {"train": 25000, "test": 3000, "val": 2000}
splitter = RandomSplitter(**random_split_sizes, compound_list=compound_list)

In [111]:
# Instantiate the GIN encoder with a linear head. Per the recorded cell output,
# construction downloads the "gin_supervised_infomax" checkpoint from
# data.dgl.ai and reports "Pretrained model loaded".
# NOTE(review): nothing here puts the model in eval mode; if `extract` does not
# do so itself, dropout / batch-norm could affect the features — TODO confirm.
head_options = {"out_dim": 512, "pooling": "mean", "preload": True}
model = GINPretrainedWithLinearHead("gin_supervised_infomax", **head_options)

Downloading gin_supervised_infomax_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_infomax.pth...
Pretrained model loaded


In [112]:
# Run the splitter and unpack its result as train / test / val compound lists.
# NOTE(review): the feature-extraction cells below re-read train_cpds,
# test_cpds and val_cpds from the CSV files under split_path, which overwrites
# these values — confirm whether the splitter output or the saved split is the
# one that should be embedded and plotted.
train_cpds, test_cpds, val_cpds = splitter.split()

In [113]:
# Embed the training compounds.
# NOTE(review): train_cpds is re-read from disk here, replacing whatever
# `splitter.split()` returned earlier — confirm this is intended.
train_cpds = pd.read_csv(f"{split_path}/train_ids.csv").iloc[:, 0].tolist()

# Featurize every compound into a graph, then batch the graphs with dgl.batch
# in their original order (no shuffling) so features stay aligned with ids.
train_graphs = list(map(dgl_mol_featurizer, train_cpds))
train_dl = DataLoader(train_graphs, batch_size=32, shuffle=False, collate_fn=dgl.batch)

# Extract features batch by batch, move them to host memory as NumPy arrays,
# and join all batches along the first axis.
train_feats = np.concatenate(
    [model.extract(graph_batch).detach().cpu().numpy() for graph_batch in tqdm(train_dl)]
)

  0%|          | 0/782 [00:00<?, ?it/s]

In [114]:
# Embed the test compounds.
# NOTE(review): test_cpds is re-read from disk here, replacing whatever
# `splitter.split()` returned earlier — confirm this is intended.
test_cpds = pd.read_csv(f"{split_path}/test_ids.csv").iloc[:, 0].tolist()

# Featurize every compound into a graph, then batch the graphs with dgl.batch
# in their original order (no shuffling) so features stay aligned with ids.
test_graphs = list(map(dgl_mol_featurizer, test_cpds))
test_dl = DataLoader(test_graphs, batch_size=32, shuffle=False, collate_fn=dgl.batch)

# Extract features batch by batch, move them to host memory as NumPy arrays,
# and join all batches along the first axis.
test_feats = np.concatenate(
    [model.extract(graph_batch).detach().cpu().numpy() for graph_batch in tqdm(test_dl)]
)

  0%|          | 0/63 [00:00<?, ?it/s]

In [115]:
# Embed the validation compounds.
# NOTE(review): val_cpds is re-read from disk here, replacing whatever
# `splitter.split()` returned earlier — confirm this is intended.
val_cpds = pd.read_csv(f"{split_path}/val_ids.csv").iloc[:, 0].tolist()

# Featurize every compound into a graph, then batch the graphs with dgl.batch
# in their original order (no shuffling) so features stay aligned with ids.
val_graphs = list(map(dgl_mol_featurizer, val_cpds))
val_dl = DataLoader(val_graphs, batch_size=32, shuffle=False, collate_fn=dgl.batch)

# Extract features batch by batch, move them to host memory as NumPy arrays,
# and join all batches along the first axis.
val_feats = np.concatenate(
    [model.extract(graph_batch).detach().cpu().numpy() for graph_batch in tqdm(val_dl)]
)

  0%|          | 0/94 [00:00<?, ?it/s]

In [116]:
# Stack the per-split feature arrays row-wise, in train, test, val order.
feature_blocks = (train_feats, test_feats, val_feats)
all_feats = np.vstack(feature_blocks)

In [120]:
# Project the embeddings onto their first two principal components for plotting.
# random_state is pinned so the projection is reproducible even when
# scikit-learn selects a randomized SVD solver for this input size.
proj = PCA(n_components=2, random_state=0)

# Reuse `all_feats` (train, test, val rows stacked in that order by the
# preceding cell) instead of stacking the three arrays a second time.
projected = proj.fit_transform(all_feats)

In [121]:
# One split label per row, repeated to match the row counts of the stacked
# features (train rows first, then test, then val).
rows_per_split = {
    "train": len(train_feats),
    "test": len(test_feats),
    "val": len(val_feats),
}
splits = [label for label, n_rows in rows_per_split.items() for _ in range(n_rows)]

# Two projected coordinates per compound plus the split it belongs to.
df = pd.DataFrame(projected, columns=["dim-1", "dim-2"])
df["split"] = splits

In [122]:
# Interactive scatter of the 2-D projection, coloured by split membership.
split_scatter = px.scatter(
    df,
    x="dim-1",
    y="dim-2",
    color="split",
    height=1000,
    width=1200,
)
split_scatter