# Load models from checkpoints and evaluate them on the evaluation tasks

In [1]:
%load_ext autoreload
%autoreload 2

In [119]:
import os
from copy import deepcopy
from pathlib import Path

import molfeat
import pandas as pd
import torch
import torch.nn as nn
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict

from src import utils
from src.models.jump_cl import BasicJUMPModule
from src.modules.losses.contrastive_loss_with_temperature import ContrastiveLossWithTemperature
from src.modules.losses.multiview_losses import NTXent
from src.modules.losses.nt_xent import NtXentLoss
from src.utils import instantiate_evaluator_list

In [3]:
for i in range(1, 4):
    if not Path(f"../cpjump{i}/jump/").exists():
        print(f"Mounting cpjump{i}...")
        os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
    else:
        print(f"cpjump{i} already mounted.")

Mounting cpjump1...
Mounting cpjump2...
Mounting cpjump3...


## Load config

In [6]:
ckpt_str = "../cpjump1/jump/logs/train/multiruns/{run}/checkpoints/epoch_{epoch:0>3}.ckpt"
single_run_ckpt_str = "../cpjump1/jump/logs/train/runs/{run}/checkpoints/epoch_{epoch:0>3}.ckpt"

run_dict = {
    "small1": (run := "2023-08-16_11-59-26/0", "small_jump_cl", epoch := 43, ckpt_str.format(run=run, epoch=epoch)),
    "small": (run := "2023-08-17_13-32-50/0", "small_jump_cl", epoch := 41, ckpt_str.format(run=run, epoch=epoch)),
    "med": (run := "2023-08-07_11-55-54", "med_jump_cl", epoch := 5, ckpt_str.format(run=run, epoch=epoch)),
    "big": (run := "2023-08-01_11-37-40", "big_jump_cl", epoch := 1, ckpt_str.format(run=run, epoch=epoch)),
    "new_small": (
        run := "2023-08-22_17-15-50",
        "fp_small",
        epoch := 43,
        single_run_ckpt_str.format(run=run, epoch=epoch),
    ),
}

In [7]:
run, experiment, epoch, ckpt = run_dict["new_small"]

In [8]:
os.system(f"cat ../cpjump1/jump/logs/train/runs/{run}/.hydra/overrides.yaml");

- experiment=fp_small
- trainer=gpu
- trainer.devices=[1]
- data.num_workers=16
- callbacks=default


In [9]:
os.listdir(f"../cpjump1/jump/logs/train/runs/{run}/checkpoints")

['last.ckpt', 'epoch_043.ckpt']

## Load the config and instantiate the model, loggers and evaluators

In [10]:
initialize(version_base=None, config_path="../configs/")

hydra.initialize()

In [11]:
cfg = compose(
    config_name="train.yaml",
    overrides=[
        "evaluate=true",
        "eval=retrieval",
        "paths.projects_dir=..",
        f"paths.output_dir=../cpjump1/jump/logs/train/multiruns/{run}",
        "experiment=fp_small",
        "data.batch_size=4",
        # "model/molecule_encoder=gin_masking.yaml",
        "trainer.devices=1",
        # "eval.moa_image_task.datamodule.data_root_dir=../",
    ],
)
print(OmegaConf.to_yaml(cfg))

task_name: train
tags:
- small_jump_cl
- fingerprints
- clip_like
- ${model.image_encoder.instance_model_name}
train: true
test: true
evaluate: true
compile: false
ckpt_path: null
seed: 12345
data:
  compound_transform:
    _target_: src.modules.compound_transforms.fp_transform.FPTransform
    fps:
    - maccs
    - ecfp
    compound_str_type: inchi
    params:
      ecfp:
        radius: 2
  _target_: src.models.jump_cl.datamodule.BasicJUMPDataModule
  batch_size: 4
  num_workers: 24
  pin_memory: null
  prefetch_factor: 3
  drop_last: true
  transform:
    _target_: src.modules.transforms.DefaultJUMPTransform
    _convert_: object
    size: 128
    dim:
    - -2
    - -1
  force_split: false
  splitter:
    _target_: src.splitters.ScaffoldSplitter
    train: 1024
    test: 256
    val: 128
    retrieval: 0
  use_compond_cache: false
  data_root_dir: ${paths.projects_dir}/
  split_path: ${paths.split_path}/fp_small3/
  dataloader_config:
    train:
      batch_size: ${data.batch_size}

## Instantiate datamodule

In [12]:
dm = instantiate(cfg.data)



In [13]:
dm.prepare_data()

In [14]:
dm.setup("train")

In [15]:
dl = dm.train_dataloader()



In [18]:
df = dm.train_dataset.load_df

for col in df.columns:
    if col.startswith("FileName"):
        df[col] = df[col].str.replace("/projects/", "../")

In [20]:
dm.train_dataset.load_df

Unnamed: 0_level_0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,FileName_OrigDNA,FileName_OrigAGP,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA,Metadata_InChI,Metadata_PlateType,Metadata_Site
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
source_11__Batch2__EC000047__E18__1,source_11,Batch2,EC000047,E18,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,InChI=1S/C10H10ClN3OS/c11-7-3-1-2-4-8(7)15-6-5...,COMPOUND,1
source_11__Batch2__EC000047__E18__4,source_11,Batch2,EC000047,E18,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,InChI=1S/C10H10ClN3OS/c11-7-3-1-2-4-8(7)15-6-5...,COMPOUND,4
source_11__Batch2__EC000047__E18__5,source_11,Batch2,EC000047,E18,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,InChI=1S/C10H10ClN3OS/c11-7-3-1-2-4-8(7)15-6-5...,COMPOUND,5
source_11__Batch2__EC000047__E18__7,source_11,Batch2,EC000047,E18,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,InChI=1S/C10H10ClN3OS/c11-7-3-1-2-4-8(7)15-6-5...,COMPOUND,7
source_11__Batch2__EC000047__E18__8,source_11,Batch2,EC000047,E18,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,../cpjump2/jump/images/source_11/Batch2/EC0000...,InChI=1S/C10H10ClN3OS/c11-7-3-1-2-4-8(7)15-6-5...,COMPOUND,8
...,...,...,...,...,...,...,...,...,...,...,...,...
source_8__J3__A1170506__L06__3,source_8,J3,A1170506,L06,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,"InChI=1S/C9H8N2O2S/c10-14(12,13)8-3-4-9-7(6-8)...",COMPOUND,3
source_8__J3__A1170506__L06__5,source_8,J3,A1170506,L06,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,"InChI=1S/C9H8N2O2S/c10-14(12,13)8-3-4-9-7(6-8)...",COMPOUND,5
source_8__J3__A1170506__L06__6,source_8,J3,A1170506,L06,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,"InChI=1S/C9H8N2O2S/c10-14(12,13)8-3-4-9-7(6-8)...",COMPOUND,6
source_8__J3__A1170506__L06__7,source_8,J3,A1170506,L06,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,../cpjump2/jump/images/source_8/J3/A1170506/so...,"InChI=1S/C9H8N2O2S/c10-14(12,13)8-3-4-9-7(6-8)...",COMPOUND,7


In [21]:
b = next(iter(dl))



In [23]:
batches = []
for i, batch in enumerate(dl):
    batches.append(batch)
    if i == 2:
        break



In [26]:
batches[1]

{'image': tensor([[[[-1.2334e-01, -1.9673e-01, -1.2334e-01,  ...,  1.5068e+01,
             1.8517e+01,  1.8517e+01],
           [-1.9673e-01, -1.2334e-01, -1.9673e-01,  ...,  1.4114e+01,
             1.8224e+01,  1.8517e+01],
           [-1.9673e-01, -1.9673e-01, -4.9953e-02,  ...,  1.3160e+01,
             1.6976e+01,  1.7710e+01],
           ...,
           [-1.2334e-01, -1.9673e-01, -1.2334e-01,  ..., -4.9953e-02,
            -1.2334e-01, -1.2334e-01],
           [-1.9673e-01, -1.9673e-01, -1.2334e-01,  ..., -4.9953e-02,
            -1.2334e-01, -4.9953e-02],
           [-1.2334e-01, -4.9953e-02, -1.2334e-01,  ..., -1.2334e-01,
            -1.2334e-01, -1.2334e-01]],
 
          [[-4.5383e-01, -3.8236e-01, -3.4662e-01,  ...,  5.7998e+00,
             7.0505e+00,  7.8367e+00],
           [-4.1809e-01, -5.2530e-01, -4.1809e-01,  ...,  5.6569e+00,
             7.3722e+00,  7.3007e+00],
           [-4.5383e-01, -4.1809e-01, -4.1809e-01,  ...,  5.9070e+00,
             7.0505e+00,  7.40

In [25]:
batches[0]["compound"].shape

torch.Size([4, 2167])

In [27]:
batches[2]["image"].shape

torch.Size([4, 5, 128, 128])

## Instantiate model

In [31]:
device = torch.device("cuda:0")

In [28]:
cfg.model["_target_"] += ".load_from_checkpoint"
with open_dict(cfg.model):
    cfg.model["checkpoint_path"] = ckpt

In [29]:
model = instantiate(cfg.model, map_location="cuda:0")

In [32]:
model.to(device);

In [36]:
embs = []
for i in range(len(batches)):
    batches[i] = {k: v.to(device) for k, v in batches[i].items()}
    embs.append(model(**batches[i]))

In [43]:
embs[0]["compound_emb"]

tensor([[ 0.0000,  1.1347, -1.5460,  ...,  2.8691,  2.5293, -3.3855],
        [-0.8441,  0.5627, -0.0000,  ...,  0.0000,  2.7277, -1.9750],
        [ 0.0000, -0.4231, -1.9781,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0000, -0.4080, -0.9231,  ...,  0.8095,  1.1871, -0.1908]],
       device='cuda:0', grad_fn=<NativeDropoutBackward0>)

## Test losses

In [91]:
ntxent = NtXentLoss(temperature=1)

In [132]:
ntxent2 = NTXent(norm=True, tau=1, uniformity_reg=0, variance_reg=0, covariance_reg=0)

In [93]:
cl = ContrastiveLossWithTemperature(logit_scale=0)

In [133]:
l11 = ntxent2(
    embs[0]["image_emb"],
    embs[0]["compound_emb"],
)

l12 = ntxent2(
    embs[1]["image_emb"],
    embs[1]["compound_emb"],
)

l13 = ntxent2(
    embs[2]["image_emb"],
    embs[2]["compound_emb"],
)

In [135]:
l21 = cl(
    embs[0]["image_emb"],
    embs[0]["compound_emb"],
)

l22 = cl(
    embs[1]["image_emb"],
    embs[1]["compound_emb"],
)

l23 = cl(
    embs[2]["image_emb"],
    embs[2]["compound_emb"],
)

In [136]:
l11, l12, l13, l21, l22, l23

(tensor(1.1239, device='cuda:0', grad_fn=<NegBackward0>),
 tensor(1.0914, device='cuda:0', grad_fn=<NegBackward0>),
 tensor(1.0687, device='cuda:0', grad_fn=<NegBackward0>),
 tensor(6.6137, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(37.2393, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(11.4284, device='cuda:0', grad_fn=<DivBackward0>))

In [115]:
l11 / l12, l21 / l22

(tensor(1.0572, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.1776, device='cuda:0', grad_fn=<DivBackward0>))

In [103]:
embeddings_a = embs[0]["compound_emb"]
embeddings_b = embs[0]["image_emb"]

In [None]:
out = torch.cat([embeddings_a, embeddings_b], dim=0)
n_samples = out.shape[0]

# Calculate cosine similarity
sim = torch.mm(out, out.t().contiguous())
sim = torch.exp(sim / temperature)

# Negative similarity
mask = ~torch.eye(n_samples, device=sim.device).bool()
neg = sim.masked_select(mask).view(n_samples, -1).sum(dim=-1)

# Positive similarity
pos = torch.exp(torch.sum(embeddings_a * embeddings_b, dim=-1) / temperature)
pos = torch.cat([pos, pos], dim=0)

loss = -torch.log(pos / neg).mean()

In [100]:
embs[0]["compound_emb"].shape

torch.Size([4, 256])

In [98]:
embs[0]["image_emb"].shape

torch.Size([4, 256])

In [122]:
trainer = instantiate(cfg.trainer, callbacks=utils.instantiate_callbacks(cfg.callbacks))

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
