# Using the pretrained model

## Loading the checkpoint file

In [1]:
# Re-import modified modules automatically before each cell runs, so edits to
# imported project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings
from functools import partial

import dgl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import torch.nn as nn
from hydra import compose, initialize, initialize_config_dir, initialize_config_module
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger, WandbLogger
from omegaconf import DictConfig, OmegaConf
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from umap import UMAP

import wandb
from src.callbacks.wandb import WandbLogCallback
from src.eval import Evaluator, EvaluatorList, FinetunableEvaluator
from src.eval.ogb.datamodule import BBBPDataModule, EsolDataModule, HIVDataModule, LipoDataModule, Tox21DataModule
from src.eval.ogb.module import BBBPModule, EsolModule, HIVModule, LipoModule, Tox21Module
from src.models.jump_cl.module import BasicJUMPModule
from src.modules.collate_fn.dgl_labels import label_graph_collate_function
from src.modules.compound_transforms.dgllife_transform import DGLPretrainedFromInchi, DGLPretrainedFromSmiles
from src.modules.molecules.dgllife_gat import GATPretrainedWithLinearHead
from src.modules.molecules.dgllife_gin import GINPretrainedWithLinearHead
from src.splitters import RandomSplitter, ScaffoldSplitter

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [3]:
# Lightning checkpoint of the run whose Hydra config is loaded in the next cells.
# NOTE(review): this path starts with "./" while the Hydra config_path for the
# same run starts with "../" — confirm both resolve from the intended directory.
ckpt = "./models/runs/2023-07-25_16-12-29/checkpoints/last.ckpt"

In [4]:
# Hydra keeps one process-wide instance; calling initialize() again (for example
# when this cell is re-run) raises because it is already initialised. Clearing
# it first makes the cell idempotent; clear() is a no-op on a fresh kernel.
GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../models/runs/2023-07-25_16-12-29/.hydra/")

hydra.initialize()

In [5]:
# Compose the config saved with the run and print it as YAML for inspection.
cfg = compose(config_name="config.yaml")
print(OmegaConf.to_yaml(cfg))

task_name: train
tags:
- med_jump_cl
- simple_contrastive_training
- pretrained_gin_infomax
- pretrained_resnet18
train: true
test: true
compile: false
ckpt_path: null
seed: 12345
data:
  compound_transform:
    _target_: src.modules.molecules.dgllife_gin.dgl_pretrained_featurizer
    _partial_: true
  _target_: src.models.jump_cl.datamodule.BasicJUMPDataModule
  batch_size: 128
  num_workers: 24
  pin_memory: false
  prefetch_factor: null
  collate_fn:
    _target_: src.modules.collate_fn.dgl_image.image_graph_collate_function
    _partial_: true
  transform:
    _target_: src.modules.transforms.DefaultJUMPTransform
    _convert_: object
    size: 128
    dim:
    - -2
    - -1
  splitter:
    _target_: src.splitters.ScaffoldSplitter
    train: 25000
    test: 3000
    val: 2000
  split_path: ${paths.split_path}/med_jump_cl/
  dataloader_config:
    train:
      batch_size: ${data.batch_size}
      num_workers: ${data.num_workers}
      pin_memory: ${data.pin_memory}
      prefetch_fa

In [6]:
# Rebuild the three sub-modules described in the run config; the checkpoint
# loader receives them as keyword arguments.
model_cfg = cfg.model
image_encoder = instantiate(model_cfg.image_encoder)
molecule_encoder = instantiate(model_cfg.molecule_encoder)
criterion = instantiate(model_cfg.criterion)

checkpoint_kwargs = {
    "image_encoder": image_encoder,
    "molecule_encoder": molecule_encoder,
    "criterion": criterion,
    "example_input_path": None,
    # Load on CPU first; later cells move the model to the chosen device.
    "map_location": torch.device("cpu"),
}
model = BasicJUMPModule.load_from_checkpoint(ckpt, **checkpoint_kwargs)

Downloading gin_supervised_masking_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_masking.pth...
Pretrained model loaded


## Plotting the molecule embeddings

In [23]:
# Directory holding the train/test/val compound-id CSVs read in the next cell.
# NOTE(review): hard-coded relative path — it only resolves from one working directory.
split_path = "../cpjump1/jump/models/splits/bigger_jump_cl"

In [24]:
# Each split file contributes its first column as a plain Python list; the
# files are read in train, test, val order.
train_cpds, test_cpds, val_cpds = (
    pd.read_csv(f"{split_path}/{file_name}").iloc[:, 0].tolist()
    for file_name in ("train_ids.csv", "test_ids.csv", "val_ids.csv")
)

# All compounds, concatenated in train + test + val order.
compound_list = [*train_cpds, *test_cpds, *val_cpds]

In [25]:
# Same split sizes as `data.splitter` in the composed config (25000 / 3000 / 2000).
# In this section the splitter is only used to generate scaffolds (next cell).
splitter = ScaffoldSplitter(train=25000, test=3000, val=2000, compound_list=compound_list)

In [26]:
# Indexed by position further down, with each entry fed to the featurizer as a
# collection of compounds — presumably one group per scaffold; confirm against
# ScaffoldSplitter.generate_scaffolds.
scaffolds = splitter.generate_scaffolds()

In [27]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [28]:
def get_emb_from_smiles(smiles, model, featurizer, device, verbose=False):
    """Embed molecules with ``model.molecule_encoder`` and return a NumPy array.

    Every entry of ``smiles`` is converted with ``featurizer``; the results are
    batched with ``dgl.batch`` (64 per batch, input order preserved) and
    encoded on ``device`` without tracking gradients.

    Args:
        smiles: Iterable of molecule identifiers accepted by ``featurizer``.
            Despite the name, nothing here requires SMILES: this notebook
            calls it with an InChI featurizer.
        model: Module exposing a ``molecule_encoder`` callable. As a side
            effect it is moved to ``device`` and switched to eval mode, and it
            is left in that state on return.
        featurizer: Callable turning one identifier into an object that
            ``dgl.batch`` can combine.
        device: Device on which the encoder is run.
        verbose: If True, wrap the batch loop in a ``tqdm`` progress bar.

    Returns:
        np.ndarray: Encoder outputs of all batches concatenated along axis 0,
        in the same order as ``smiles``.

    Raises:
        ValueError: If ``smiles`` is empty, because ``np.concatenate`` has
            nothing to join.
    """
    model.to(device)
    model.eval()

    graphs = [featurizer(smile) for smile in smiles]
    dl = DataLoader(graphs, batch_size=64, shuffle=False, collate_fn=dgl.batch)
    feats = []

    pbar = tqdm(dl) if verbose else dl
    # Pure inference: disable autograd so no computation graph is built and
    # held on the device for every batch.
    with torch.no_grad():
        for b in pbar:
            f = model.molecule_encoder(b.to(device))
            feats.append(f.detach().cpu().numpy())

    return np.concatenate(feats)

In [29]:
# Number of scaffold groups to embed and plot.
n_ex = 100
# Every third position starting at 2: 2, 5, 8, ...
idx = [2 + 3 * k for k in range(n_ex)]

In [30]:
# Build the featurizer once instead of once per scaffold group: its
# construction does not depend on the loop variable.
# NOTE(review): assumes a DGLPretrainedFromInchi instance can be reused across
# calls, as the SMILES featurizer is in the evaluation section — confirm.
inchi_featurizer = DGLPretrainedFromInchi()

scaffold_feats = []

for i in tqdm(idx, leave=False):
    scaffold_feats.append(
        get_emb_from_smiles(
            smiles=scaffolds[i], model=model, featurizer=inchi_featurizer, device=device, verbose=False
        )
    )

# One label per embedded molecule: the position (0, 1, 2, ...) of its group in
# `idx`, repeated once for every row that group produced.
scaffold_index = np.concatenate([np.ones(scaf.shape[0]) * i for i, scaf in enumerate(scaffold_feats)])

# X: all embeddings stacked row-wise; y: the matching group label per row.
X = np.concatenate(scaffold_feats)
y = scaffold_index

  0%|          | 0/100 [00:00<?, ?it/s]

In [31]:
# t-SNE and UMAP are stochastic, and PCA may pick a randomized solver, so
# without a seed the projections (and the scatter plot) change on every run.
# Seeding UMAP may reduce its parallelism; acceptable at this data size.
PROJECTION_SEED = 12345  # same value as `seed` in the composed run config

tsne_proj = TSNE(n_components=2, perplexity=50, n_jobs=-1, random_state=PROJECTION_SEED).fit_transform(X)
pca_proj = PCA(n_components=2, random_state=PROJECTION_SEED).fit_transform(X)
umap_proj = UMAP(n_components=2, n_neighbors=50, random_state=PROJECTION_SEED).fit_transform(X)

In [32]:
# Gather the three 2-D projections into one frame with "<method>-x" and
# "<method>-y" columns, in tsne, pca, umap order.
projections = {"tsne": tsne_proj, "pca": pca_proj, "umap": umap_proj}
df = pd.DataFrame(
    {
        f"{method}-{axis}": coords[:, col]
        for method, coords in projections.items()
        for col, axis in enumerate("xy")
    }
)

# Attach the group label (as a string, so it is treated as a category when
# plotting) and the identifier of every embedded molecule.
df = df.assign(
    scaffold=y.astype(int).astype(str),
    inchi=np.concatenate([scaffolds[i] for i in idx]),
)

In [33]:
# Choose which projection to display: "tsne", "pca" or "umap".
proj_type = "tsne"
x_col, y_col = f"{proj_type}-x", f"{proj_type}-y"
px.scatter(df, x=x_col, y=y_col, color="scaffold", width=800, height=800)

## Evaluation

In [8]:
# Same device selection as in the plotting section, repeated so this section
# does not depend on that one having been run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [9]:
# Featurizer passed to the datamodule below as its `compound_transform`.
featz = DGLPretrainedFromSmiles()

In [10]:
# One DataLoader config per split; they differ only in `shuffle`, which is
# enabled for the training split alone.
loader_config = DictConfig(
    {
        split: {
            "batch_size": 64,
            "shuffle": split == "train",
            "num_workers": 8,
        }
        for split in ("train", "val", "test")
    }
)

In [25]:
# Select the downstream task: datamodule class, Lightning module class and run name.
dmc = LipoDataModule
mc = LipoModule
# The sample batch shown further down has continuous labels, hence MSE rather
# than the BCE-with-logits alternative kept here:
# criterion = nn.BCEWithLogitsLoss()
# NOTE(review): `criterion` is not passed to the module below (that argument is
# commented out), and this assignment rebinds the name that was used when
# loading the checkpoint.
criterion = nn.MSELoss()
name = "lipo"

In [12]:
# Instantiate the selected datamodule with the SMILES featurizer, the
# label/graph collate function and the per-split DataLoader settings.
dm = dmc(
    root_dir="./data/",
    compound_transform=featz,
    collate_fn=label_graph_collate_function,
    dataloader_config=loader_config,
)

In [13]:
# Run the datamodule's preparation and setup hooks by hand so that a batch can
# be drawn below before any Trainer is involved.
dm.prepare_data()
dm.setup()

In [14]:
# Pull one training batch; it is used as `example_input` for the module and
# for the manual forward passes below.
dl = dm.train_dataloader()
b = next(iter(dl))

In [26]:
# Wrap the pretrained cross-modal model in the task module. The optimizer is
# passed as a class and the scheduler as a partial, so both are presumably
# instantiated inside the module — confirm.
m = mc(
    cross_modal_module=model,
    optimizer=torch.optim.Adam,
    scheduler=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.95),
    # No criterion is passed, so the module falls back to whatever it uses by default.
    # criterion=criterion,
    molecule_encoder_attribute_name="molecule_encoder",
    example_input=b,
)

In [34]:
# Smoke test: move the module and every entry of the sample batch to the
# device, then run one forward pass (trailing `;` hides the output).
m = m.to(device)
bb = {k: v.to(device) for k, v in b.items()}
m(**bb);

In [28]:
# Tell Weights & Biases which notebook this run comes from.
os.environ["WANDB_NOTEBOOK_NAME"] = "./notebooks/7.0-gw-checkpoint.ipynb"

In [29]:
# Log to TensorBoard, CSV and Weights & Biases side by side, all rooted at the
# current working directory. The W&B run is tagged and grouped by task name.
logger = [
    TensorBoardLogger(save_dir=os.getcwd()),
    CSVLogger(save_dir=os.getcwd()),
    WandbLogger(save_dir=os.getcwd(), project="jump_models", tags=[name, "validation"], group=name),
]
# Lighter alternative: logger = CSVLogger(save_dir=os.getcwd(), name="lightning_logs")
callbacks = [
    RichModelSummary(),
    WandbLogCallback(
        watch=True,
        watch_log="all",
        log_freq=100,
    ),
]

# NOTE(review): the accelerator is fixed to "gpu" even though `device` above
# falls back to CPU — confirm a GPU is always expected for this section.
trainer = Trainer(accelerator="gpu", max_epochs=50, logger=logger, callbacks=callbacks)


# Bundle module, datamodule and trainer; the cells below call
# `finetune()` / `evaluate()` on it.
evaluator = FinetunableEvaluator(
    model=m,
    datamodule=dm,
    trainer=trainer,
    name=name,
)

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [30]:
# Silence every Python warning from here on to keep training output readable.
# NOTE(review): this is a blanket filter and also hides warnings that may matter.
warnings.filterwarnings("ignore")

In [35]:
# Debug check: show the predictions for the sample batch. In the saved output
# they are in the tens to hundreds, far from the labels displayed in the next cell.
m(**bb)

tensor([[ 89.4749],
        [118.2019],
        [ 98.5668],
        [ 29.4051],
        [101.6879],
        [101.4642],
        [ 66.0896],
        [ 99.9193],
        [ -8.9107],
        [ 17.9166],
        [133.8878],
        [131.8886],
        [ 75.3815],
        [ 13.3442],
        [ 71.6203],
        [ 52.4193],
        [ 52.6021],
        [ 17.8721],
        [101.9396],
        [ 10.5105],
        [ 30.2667],
        [ 75.8556],
        [ 88.6641],
        [ 63.2152],
        [ 19.1693],
        [ 28.3296],
        [ -5.5612],
        [ 28.1050],
        [ 30.0842],
        [ 25.9844],
        [ 88.9908],
        [ 79.0325],
        [ 83.3029],
        [ 72.5576],
        [  3.1259],
        [ 74.6118],
        [167.8580],
        [ 95.0305],
        [122.1017],
        [106.2537],
        [ 30.3841],
        [100.0746],
        [ 10.3026],
        [  2.6383],
        [ 88.4504],
        [ 94.8880],
        [134.1628],
        [ 35.7002],
        [100.5984],
        [ 67.7810],


In [36]:
# Debug check: show the sample batch (batched graph under 'compound', targets under 'label').
bb

{'compound': Graph(num_nodes=1767, num_edges=5599,
       ndata_schemes={'atomic_number': Scheme(shape=(), dtype=torch.int64), 'chirality_type': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={'bond_type': Scheme(shape=(), dtype=torch.int64), 'bond_direction_type': Scheme(shape=(), dtype=torch.int64)}),
 'label': tensor([[ 3.9900],
         [ 2.5400],
         [ 3.4000],
         [ 0.5100],
         [ 3.4400],
         [ 3.4600],
         [ 2.5100],
         [ 2.7400],
         [ 0.7700],
         [-1.0100],
         [ 2.0900],
         [ 3.3000],
         [ 3.2000],
         [ 2.3500],
         [ 2.1600],
         [ 1.9200],
         [ 3.3300],
         [ 1.3300],
         [ 1.3200],
         [ 1.9000],
         [ 2.3100],
         [ 2.7600],
         [ 3.5000],
         [ 1.9500],
         [ 0.4400],
         [ 0.9000],
         [ 0.2000],
         [ 1.5100],
         [ 0.6600],
         [ 1.0000],
         [ 2.8000],
         [ 1.0300],
         [ 4.1000],
         [ 4.00

In [31]:
# NOTE(review): the output saved with this cell ends in
# "RuntimeError: Found dtype Double but expected Float" — presumably a dtype
# mismatch between labels and predictions; confirm and fix before relying on it.
evaluator.finetune()

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: 0it [00:00, ?it/s]

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


Training: 0it [00:00, ?it/s]

RuntimeError: Found dtype Double but expected Float

In [24]:
# Run the evaluator's evaluation step (the saved output shows a test loop).
evaluator.evaluate()

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

In [31]:
# Train by calling the Trainer directly instead of going through the evaluator.
# The saved output shows this run being stopped by a KeyboardInterrupt.
trainer.fit(m, dm)

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [22]:
# NOTE(review): the saved output lists AUC and Binary* metrics, which presumably
# come from an earlier classification run rather than the Lipo setup above —
# re-run to refresh.
trainer.test(m, dm)

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test/loss': 1.255344033241272,
  'test/AUC': 0.7004243731498718,
  'train/BinaryAccuracy': 0.5686274766921997,
  'train/BinaryRecall': 0.855766773223877,
  'train/BinaryPrecision': 0.5585212707519531,
  'train/BinaryF1Score': 0.670244574546814}]

In [25]:
# Close the active Weights & Biases run; the run summary is printed as output.
wandb.finish()

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁
train/BinaryAccuracy,▁
train/BinaryF1Score,▁
train/BinaryPrecision,▁
train/BinaryRecall,▁
train/loss_step,█▄▁▆▃▇▁█▅▂
trainer/global_step,▁▂▃▃▄▅▆▆▇██████
val/AUC,▁
val/AUC_best,▁
val/loss,▁

0,1
epoch,0.0
train/BinaryAccuracy,0.98104
train/BinaryF1Score,0.05246
train/BinaryPrecision,0.06484
train/BinaryRecall,0.05239
train/loss_step,0.12273
trainer/global_step,514.0
val/AUC,0.73093
val/AUC_best,0.73093
val/loss,0.08415
