# Load models from checkpoints and evaluate them on the evaluation tasks

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from copy import deepcopy
from pathlib import Path

import molfeat
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict

from src import utils
from src.models.jump_cl import BasicJUMPModule
from src.modules.collate_fn import default_collate
from src.modules.collate_fn.default_collate import _default_collate_fn_map
from src.modules.losses.contrastive_losses import InfoNCE, NTXent, RegInfoNCE, RegNTXent
from src.utils import instantiate_evaluator_list

In [3]:
for i in range(1, 4):
    if not Path(f"../cpjump{i}/jump/").exists():
        print(f"Mounting cpjump{i}...")
        os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
    else:
        print(f"cpjump{i} already mounted.")

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


## Load config

In [4]:
ckpt_str = "../cpjump1/jump/logs/train/multiruns/{run}/checkpoints/epoch_{epoch:0>3}.ckpt"
single_run_ckpt_str = "../cpjump1/jump/logs/train/runs/{run}/checkpoints/epoch_{epoch:0>3}.ckpt"

run_dict = {
    "small1": (run := "2023-08-16_11-59-26/0", "small_jump_cl", epoch := 43, ckpt_str.format(run=run, epoch=epoch)),
    "small": (run := "2023-08-17_13-32-50/0", "small_jump_cl", epoch := 41, ckpt_str.format(run=run, epoch=epoch)),
    "med": (run := "2023-08-07_11-55-54", "med_jump_cl", epoch := 5, ckpt_str.format(run=run, epoch=epoch)),
    "big": (run := "2023-08-01_11-37-40", "big_jump_cl", epoch := 1, ckpt_str.format(run=run, epoch=epoch)),
    "new_small": (
        run := "2023-08-22_17-15-50",
        "fp_small",
        epoch := 43,
        single_run_ckpt_str.format(run=run, epoch=epoch),
    ),
    "new_big": (run := "2023-08-23_20-49-23", "fp_big", epoch := 1, single_run_ckpt_str.format(run=run, epoch=epoch)),
}

In [5]:
run, experiment, epoch, ckpt = run_dict["new_big"]

In [6]:
os.system(f"cat ../cpjump1/jump/logs/train/runs/{run}/.hydra/config.yaml");

task_name: train
tags:
- big_jump_cl
- fingerprints
- clip_like
- ${model.image_encoder.instance_model_name}
train: true
test: true
evaluate: true
compile: false
ckpt_path: null
seed: 2354
data:
  compound_transform:
    _target_: src.modules.compound_transforms.fp_transform.FPTransform
    fps:
    - maccs
    - ecfp
    compound_str_type: inchi
    params:
      ecfp:
        radius: 2
  _target_: src.models.jump_cl.datamodule.BasicJUMPDataModule
  batch_size: 1024
  num_workers: 16
  pin_memory: null
  prefetch_factor: 2
  drop_last: true
  transform:
    _target_: src.modules.transforms.DefaultJUMPTransform
    _convert_: object
    size: 128
    dim:
    - -2
    - -1
  force_split: false
  splitter:
    _target_: src.splitters.ScaffoldSplitter
    train: 90112
    test: 10240
    val: 5120
    retrieval: 3072
  use_compond_cache: false
  data_root_dir: ${paths.projects_dir}/
  split_path: ${paths.split_path}/fp_big4/
  dataloader_config:
    train:
      batch_size: ${data.batch_

In [7]:
os.listdir(f"../cpjump1/jump/logs/train/runs/{run}/.hydra")

['config.yaml', 'hydra.yaml', 'overrides.yaml']

## Load the config and instantiate the model, loggers and evaluators

In [42]:
GlobalHydra.instance().clear()

In [8]:
initialize(version_base=None, config_path=f"../configs")

hydra.initialize()

In [12]:
cfg = compose(
    config_name="train.yaml",
    overrides=[
        "evaluate=true",
        "eval=retrieval",
        "paths.projects_dir=..",
        f"paths.output_dir=../cpjump1/jump/logs/train/multiruns/{run}",
        "experiment=gin_context_pred/small",
        "data.batch_size=4",
        # "model/molecule_encoder=gin_masking.yaml",
        "trainer.devices=1",
        # "eval.moa_image_task.datamodule.data_root_dir=../",
    ],
)
print(OmegaConf.to_yaml(cfg))

task_name: train
tags:
- small_jump_cl
- pretrained_gin
- clip_like
- ${model.molecule_encoder.pretrained_name}
- ${model.image_encoder.instance_model_name}
train: true
test: true
evaluate: true
compile: false
ckpt_path: null
seed: 12345
data:
  compound_transform:
    _target_: src.modules.compound_transforms.dgllife_transform.DGLPretrainedFromInchi
    add_self_loop: true
    canonical_atom_order: true
    num_virtual_nodes: 0
    explicit_hydrogens: false
  _target_: src.models.jump_cl.datamodule.BasicJUMPDataModule
  batch_size: 4
  num_workers: 24
  pin_memory: null
  prefetch_factor: 3
  drop_last: true
  transform:
    _target_: src.modules.transforms.DefaultJUMPTransform
    _convert_: object
    size: 128
    dim:
    - -2
    - -1
  force_split: false
  splitter:
    _target_: src.splitters.ScaffoldSplitter
    train: 1024
    test: 256
    val: 128
    retrieval: 0
  use_compond_cache: false
  data_root_dir: ${paths.projects_dir}/
  split_path: ${paths.split_path}/fp_small3/

## Instantiate datamodule

In [13]:
dm = instantiate(cfg.data)



In [14]:
dm.prepare_data()

In [15]:
dm.setup("train")

In [16]:
dl = dm.train_dataloader()



## Instantiate model

In [17]:
device = torch.device("cuda:0")

In [19]:
cfg.model.scheduler.warmup_steps = [3]
cfg.model.monitor = "val/loss"

In [20]:
model = instantiate(cfg.model, map_location="cuda:0")

Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded


In [33]:
trainer = instantiate(
    cfg.trainer,
    limit_train_batches=1,
    limit_val_batches=1,
    num_sanity_val_steps=1,
    logger=False,
    check_val_every_n_epoch=1,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


In [95]:
trainer.callback_metrics

{'train/loss': tensor(1.3786),
 'train/loss_step': tensor(1.3786),
 'train/steps': tensor(3.),
 'train/loss_epoch': tensor(1.3786),
 'val/loss': tensor(1.3981)}

In [94]:
trainer.logged_metrics

{'train/loss_step': tensor(1.3786),
 'train/steps': tensor(3.),
 'train/loss_epoch': tensor(1.3786),
 'val/steps': tensor(3.),
 'val/loss': tensor(1.3981)}

## Test losses

In [36]:
batches = []
embs = []
for i, batch in enumerate(dm.train_dataloader()):
    batches.append({k: v.to(device) for k, v in batch.items()})
    embs.append(model(**batches[i]))

    if i == 3:
        break



In [38]:
embs[0]

{'image_emb': tensor([[ 0.3213, -0.1112, -0.3768,  ...,  0.2018, -0.3701, -0.4943],
         [-0.0396, -0.1335,  0.0044,  ...,  0.1353, -0.3113, -0.1665],
         [ 0.3029, -0.0417,  0.1074,  ...,  0.4303,  0.0726,  0.1281],
         [-0.1516, -0.0128, -0.0193,  ...,  0.1602, -0.1948, -0.0625]],
        device='cuda:0', grad_fn=<AddmmBackward0>),
 'compound_emb': tensor([[ 0.0813,  0.0202, -0.0115,  ...,  0.0848,  0.0620,  0.0570],
         [-0.0428, -0.0047,  0.0897,  ...,  0.1063,  0.0917,  0.1529],
         [-0.0306,  0.0627, -0.0716,  ...,  0.0193,  0.0201,  0.1033],
         [-0.1542,  0.0046, -0.0676,  ...,  0.0799,  0.0721,  0.0802]],
        device='cuda:0', grad_fn=<AddmmBackward0>)}

In [None]:
criterions = []

In [39]:
infonce = InfoNCE(temperature=0.5, norm=True, eps=1e-8)

In [42]:
infonce(embs[0]["image_emb"], embs[0]["image_emb"]) / 4

tensor(0.1633, device='cuda:0', grad_fn=<DivBackward0>)

In [66]:
z1 = embs[0]["image_emb"]
z2 = embs[0]["compound_emb"]

In [67]:
sim_matrix = torch.einsum("ik,jk->ij", z1, z2)

In [55]:
z1_abs = z1.norm(dim=1)
z2_abs = z2.norm(dim=1)
sim_matrix = sim_matrix / (torch.einsum("i,j->ij", z1_abs, z2_abs) + 1e-8)

In [69]:
sim_matrix = torch.exp(sim_matrix / 1)

In [70]:
sim_matrix

tensor([[1.0888, 0.7952, 0.8916, 1.0491],
        [0.6790, 0.7940, 0.6577, 1.1722],
        [0.8089, 1.0380, 0.7858, 1.1058],
        [0.9522, 0.8749, 0.8597, 1.1981]], device='cuda:0',
       grad_fn=<ExpBackward0>)

In [71]:
pos_sim = torch.diagonal(sim_matrix)
loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)  # This is the difference from InfoNCE

In [72]:
-torch.log(loss)

tensor([0.9214, 1.1506, 1.3238, 0.8076], device='cuda:0',
       grad_fn=<NegBackward0>)

In [61]:
(sim_matrix.sum(dim=1) - pos_sim)

tensor([4.4275, 4.8740, 4.4021, 4.3147], device='cuda:0',
       grad_fn=<SubBackward0>)

In [60]:
pos_sim / (sim_matrix.sum(dim=1) - pos_sim)

tensor([0.6140, 0.5577, 0.6175, 0.6300], device='cuda:0',
       grad_fn=<DivBackward0>)

In [70]:
infonce(model.molecule_encoder(nan_batch["compound"].to("cpu")), im_emb)

tensor(6.9285, grad_fn=<NegBackward0>)

In [210]:
res = {}
j = 1

for criterion in criterions:
    for i in range(3):
        key = f"{criterion.__class__.__name__}_{i}"
        if key in res:
            key += f"_{j}"
            j += 1
        res[key] = criterion(
            embs[i]["image_emb"],
            embs[i]["compound_emb"],
        )

In [211]:
res

{'NtXentLoss_0': tensor(2.4853, device='cuda:0', grad_fn=<NegBackward0>),
 'NtXentLoss_1': tensor(2.4590, device='cuda:0', grad_fn=<NegBackward0>),
 'NtXentLoss_2': tensor(2.4400, device='cuda:0', grad_fn=<NegBackward0>),
 'NTXent_0': tensor(1.1239, device='cuda:0', grad_fn=<NegBackward0>),
 'NTXent_1': tensor(1.0914, device='cuda:0', grad_fn=<NegBackward0>),
 'NTXent_2': tensor(1.0687, device='cuda:0', grad_fn=<NegBackward0>),
 'ContrastiveLossWithTemperature_0': tensor(1.4052, device='cuda:0', grad_fn=<DivBackward0>),
 'ContrastiveLossWithTemperature_1': tensor(1.3802, device='cuda:0', grad_fn=<DivBackward0>),
 'ContrastiveLossWithTemperature_2': tensor(1.3642, device='cuda:0', grad_fn=<DivBackward0>),
 'NTXent_0_1': tensor(-inf, device='cuda:0', grad_fn=<AddBackward0>),
 'NTXent_1_2': tensor(-inf, device='cuda:0', grad_fn=<AddBackward0>),
 'NTXent_2_3': tensor(-inf, device='cuda:0', grad_fn=<AddBackward0>),
 'InfoNCE_0': tensor(-inf, device='cuda:0', grad_fn=<AddBackward0>),
 'InfoN

In [175]:
l11 / l12, l21 / l22

(tensor(1.0298, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(1.0107, device='cuda:0', grad_fn=<DivBackward0>))

In [138]:
embeddings_a = embs[0]["compound_emb"]
embeddings_b = embs[0]["image_emb"]
temperature = 0.5

In [167]:
embeddings_a_abs = F.normalize(embeddings_a, dim=1)
embeddings_b_abs = F.normalize(embeddings_b, dim=1)

out = torch.cat([embeddings_a_abs, embeddings_b_abs], dim=0)
n_samples = out.shape[0]

# Calculate cosine similarity
sim = torch.mm(out, out.t().contiguous())
sim = torch.exp(sim / temperature)

# Negative similarity
mask = ~torch.eye(n_samples, device=sim.device).bool()
neg = sim.masked_select(mask).view(n_samples, -1).sum(dim=-1)

# Positive similarity
pos = torch.exp(torch.sum(embeddings_a * embeddings_b, dim=-1) / temperature)
pos = torch.cat([pos, pos], dim=0)

loss = -torch.log(pos / neg).mean()

In [168]:
loss

tensor(30.5000, device='cuda:0', grad_fn=<NegBackward0>)

In [165]:
embeddings_a_abs = F.normalize(embeddings_a, dim=1)

In [154]:
torch.mm(embs[0]["image_emb"], embs[0]["compound_emb"].t())

tensor([[-18.0215, -34.1079,  -9.9153,  -5.6235],
        [ -8.6863, -19.3624,  -5.2746,  -3.0330],
        [-19.6764, -35.8913, -10.7404,  -6.0404],
        [-21.9357, -40.2616, -12.6212,  -7.4393]], device='cuda:0',
       grad_fn=<MmBackward0>)

In [100]:
embs[0]["compound_emb"].shape

torch.Size([4, 256])

In [98]:
embs[0]["image_emb"].shape

torch.Size([4, 256])

In [122]:
trainer = instantiate(cfg.trainer, callbacks=utils.instantiate_callbacks(cfg.callbacks))

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Fix collate functions

In [82]:
default_collate?

[0;31mSignature:[0m [0mdefault_collate[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      /mnt/2547d4d7-6732-4154-b0e1-17b0c1e0c565/Document-2/Projet2/Stage/workspace/jump_models/src/modules/collate_fn/default_collate.py
[0;31mType:[0m      function