# Using the pretrained model

## Loading the checkpoint file

In [6]:
import os

import dgl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from hydra import compose, initialize, initialize_config_dir, initialize_config_module
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from umap import UMAP

from src.models.jump_cl.module import BasicJUMPModule
from src.modules.molecules.dgllife_gat import GATPretrainedWithLinearHead, dgl_canonical_featurizer
from src.modules.molecules.dgllife_gin import GINPretrainedWithLinearHead, dgl_pretrained_featurizer
from src.splitters import RandomSplitter, ScaffoldSplitter

In [2]:
os.getcwd()

'/mnt/2547d4d7-6732-4154-b0e1-17b0c1e0c565/Document-2/Projet2/Stage/workspace/jump_models'

In [5]:
ckpt = "./models/runs/2023-07-25_16-12-29/checkpoints/last.ckpt"

In [3]:
initialize(version_base=None, config_path="../models/runs/2023-07-25_16-12-29/.hydra/")

hydra.initialize()

In [4]:
cfg = compose(config_name="config.yaml")

In [9]:
image_encoder = instantiate(cfg.model.image_encoder)
molecule_encoder = instantiate(cfg.model.molecule_encoder)
criterion = instantiate(cfg.model.criterion)

Downloading gin_supervised_masking_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_masking.pth...
Pretrained model loaded


In [14]:
model = BasicJUMPModule.load_from_checkpoint(
    ckpt,
    image_encoder=image_encoder,
    molecule_encoder=molecule_encoder,
    criterion=criterion,
    example_input_path=None,
    map_location=torch.device("cuda:0"),
)

In [15]:
model

BasicJUMPModule(
  (image_encoder): CNNEncoder(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act1): ReLU(inplace=True)
          (aa): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act2): ReLU(inplace=True)
        )
        (1): BasicBlock(
         

## Plotting the molecule embeddings

In [17]:
split_path = "../cpjump1/jump/models/splits/bigger_jump_cl"

In [18]:
train_cpds = pd.read_csv(f"{split_path}/train_ids.csv").iloc[:, 0].tolist()
test_cpds = pd.read_csv(f"{split_path}/test_ids.csv").iloc[:, 0].tolist()
val_cpds = pd.read_csv(f"{split_path}/val_ids.csv").iloc[:, 0].tolist()

compound_list = train_cpds + test_cpds + val_cpds

In [19]:
splitter = ScaffoldSplitter(train=25000, test=3000, val=2000, compound_list=compound_list)

In [20]:
scaffolds = splitter.generate_scaffolds()

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [25]:
def get_emb_from_smiles(smiles, model, featurizer, device, verbose=False):
    model.to(device)
    model.eval()

    graphs = [featurizer(smile) for smile in smiles]
    dl = DataLoader(graphs, batch_size=64, shuffle=False, collate_fn=dgl.batch)
    feats = []

    pbar = tqdm(dl) if verbose else dl
    for b in pbar:
        f = model.molecule_encoder(b.to(device))
        feats.append(f.detach().cpu().numpy())

    return np.concatenate(feats)

In [23]:
n_ex = 100
idx = list(range(2, 2 + 3 * n_ex, 3))

In [26]:
scaffold_feats = []

for i in tqdm(idx, leave=False):
    scaffold_feats.append(
        get_emb_from_smiles(
            smiles=scaffolds[i], model=model, featurizer=dgl_pretrained_featurizer, device=device, verbose=False
        )
    )

scaffold_index = np.concatenate([np.ones(scaf.shape[0]) * i for i, scaf in enumerate(scaffold_feats)])

X = np.concatenate(scaffold_feats)
y = scaffold_index

  0%|          | 0/100 [00:00<?, ?it/s]

In [31]:
tsne_proj = TSNE(n_components=2, perplexity=50, n_jobs=-1).fit_transform(X)
pca_proj = PCA(n_components=2).fit_transform(X)
umap_proj = UMAP(n_components=2, n_neighbors=50).fit_transform(X)

In [32]:
df = pd.DataFrame(tsne_proj, columns=["tsne-x", "tsne-y"])
df["pca-x"] = pca_proj[:, 0]
df["pca-y"] = pca_proj[:, 1]
df["umap-x"] = umap_proj[:, 0]
df["umap-y"] = umap_proj[:, 1]
df["scaffold"] = y.astype(int).astype(str)
df["inchi"] = np.concatenate([scaffolds[i] for i in idx])

In [33]:
proj_type = "tsne"
px.scatter(df, x=f"{proj_type}-x", y=f"{proj_type}-y", color="scaffold", height=800, width=800)

## Evaluation

In [35]:
checkpoint = torch.load(ckpt, map_location=torch.device("cpu"))

In [39]:
checkpoint["state_dict"].keys()

odict_keys(['image_encoder.model.conv1.weight', 'image_encoder.model.bn1.weight', 'image_encoder.model.bn1.bias', 'image_encoder.model.bn1.running_mean', 'image_encoder.model.bn1.running_var', 'image_encoder.model.bn1.num_batches_tracked', 'image_encoder.model.layer1.0.conv1.weight', 'image_encoder.model.layer1.0.bn1.weight', 'image_encoder.model.layer1.0.bn1.bias', 'image_encoder.model.layer1.0.bn1.running_mean', 'image_encoder.model.layer1.0.bn1.running_var', 'image_encoder.model.layer1.0.bn1.num_batches_tracked', 'image_encoder.model.layer1.0.conv2.weight', 'image_encoder.model.layer1.0.bn2.weight', 'image_encoder.model.layer1.0.bn2.bias', 'image_encoder.model.layer1.0.bn2.running_mean', 'image_encoder.model.layer1.0.bn2.running_var', 'image_encoder.model.layer1.0.bn2.num_batches_tracked', 'image_encoder.model.layer1.1.conv1.weight', 'image_encoder.model.layer1.1.bn1.weight', 'image_encoder.model.layer1.1.bn1.bias', 'image_encoder.model.layer1.1.bn1.running_mean', 'image_encoder.mod

In [37]:
checkpoint.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters', 'datamodule_hparams_name', 'datamodule_hyper_parameters'])