# Load models from checkpoints and evaluate them on the evaluation tasks

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from copy import deepcopy
from pathlib import Path

import molfeat
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict

from src import utils
from src.models.jump_cl import BasicJUMPModule
from src.modules.collate_fn import default_collate
from src.modules.collate_fn.default_collate import _default_collate_fn_map
from src.modules.losses.contrastive_losses import InfoNCE, NTXent, RegInfoNCE, RegNTXent
from src.utils import instantiate_evaluator_list

In [3]:
for i in range(1, 4):
    if not Path(f"../cpjump{i}/jump/").exists():
        print(f"Mounting cpjump{i}...")
        os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
    else:
        print(f"cpjump{i} already mounted.")

Mounting cpjump1...
Mounting cpjump2...
Mounting cpjump3...


## Load config

In [4]:
ckpt_str = "../cpjump1/jump/logs/train/multiruns/{run}/checkpoints/epoch_{epoch:0>3}.ckpt"
single_run_ckpt_str = "../cpjump1/jump/logs/train/runs/{run}/checkpoints/epoch_{epoch:0>3}.ckpt"

run_dict = {
    "small1": (run := "2023-08-16_11-59-26/0", "small_jump_cl", epoch := 43, ckpt_str.format(run=run, epoch=epoch)),
    "small": (run := "2023-08-17_13-32-50/0", "small_jump_cl", epoch := 41, ckpt_str.format(run=run, epoch=epoch)),
    "med": (run := "2023-08-07_11-55-54", "med_jump_cl", epoch := 5, ckpt_str.format(run=run, epoch=epoch)),
    "big": (run := "2023-08-01_11-37-40", "big_jump_cl", epoch := 1, ckpt_str.format(run=run, epoch=epoch)),
    "new_small": (
        run := "2023-08-22_17-15-50",
        "fp_small",
        epoch := 43,
        single_run_ckpt_str.format(run=run, epoch=epoch),
    ),
    "new_big": (run := "2023-08-23_20-49-23", "fp_big", epoch := 1, single_run_ckpt_str.format(run=run, epoch=epoch)),
}

In [5]:
run, experiment, epoch, ckpt = run_dict["new_big"]

In [6]:
os.system(f"cat ../cpjump1/jump/logs/train/runs/{run}/.hydra/config.yaml");

task_name: train
tags:
- big_jump_cl
- fingerprints
- clip_like
- ${model.image_encoder.instance_model_name}
train: true
test: true
evaluate: true
compile: false
ckpt_path: null
seed: 2354
data:
  compound_transform:
    _target_: src.modules.compound_transforms.fp_transform.FPTransform
    fps:
    - maccs
    - ecfp
    compound_str_type: inchi
    params:
      ecfp:
        radius: 2
  _target_: src.models.jump_cl.datamodule.BasicJUMPDataModule
  batch_size: 1024
  num_workers: 16
  pin_memory: null
  prefetch_factor: 2
  drop_last: true
  transform:
    _target_: src.modules.transforms.DefaultJUMPTransform
    _convert_: object
    size: 128
    dim:
    - -2
    - -1
  force_split: false
  splitter:
    _target_: src.splitters.ScaffoldSplitter
    train: 90112
    test: 10240
    val: 5120
    retrieval: 3072
  use_compond_cache: false
  data_root_dir: ${paths.projects_dir}/
  split_path: ${paths.split_path}/fp_big4/
  dataloader_config:
    train:
      batch_size: ${data.batch_

In [7]:
os.listdir(f"../cpjump1/jump/logs/train/runs/{run}/")

['.hydra',
 'tags.log',
 'config_tree.log',
 'wandb',
 'csv',
 'tensorboard',
 'nan_batches',
 'checkpoints',
 'eval']

## Load the config and instantiate the model, loggers and evaluators

In [8]:
initialize(version_base=None, config_path="../configs/")

hydra.initialize()

In [9]:
cfg = compose(
    config_name="train.yaml",
    overrides=[
        "evaluate=true",
        "eval=retrieval",
        "paths.projects_dir=..",
        f"paths.output_dir=../cpjump1/jump/logs/train/multiruns/{run}",
        "experiment=fp_big",
        "data.batch_size=4",
        # "model/molecule_encoder=gin_masking.yaml",
        "trainer.devices=1",
        # "eval.moa_image_task.datamodule.data_root_dir=../",
    ],
)
print(OmegaConf.to_yaml(cfg))

task_name: train
tags:
- big_jump_cl
- fingerprints
- clip_like
- ${model.image_encoder.instance_model_name}
train: true
test: true
evaluate: true
compile: false
ckpt_path: null
seed: 12345
data:
  compound_transform:
    _target_: src.modules.compound_transforms.fp_transform.FPTransform
    fps:
    - maccs
    - ecfp
    compound_str_type: inchi
    params:
      ecfp:
        radius: 2
  _target_: src.models.jump_cl.datamodule.BasicJUMPDataModule
  batch_size: 4
  num_workers: 24
  pin_memory: null
  prefetch_factor: 2
  drop_last: true
  transform:
    _target_: src.modules.transforms.DefaultJUMPTransform
    _convert_: object
    size: 128
    dim:
    - -2
    - -1
  force_split: false
  splitter:
    _target_: src.splitters.ScaffoldSplitter
    train: 90112
    test: 10240
    val: 5120
    retrieval: 3072
  use_compond_cache: false
  data_root_dir: ${paths.projects_dir}/
  split_path: ${paths.split_path}/fp_big4/
  dataloader_config:
    train:
      batch_size: ${data.batch_si

## Instantiate datamodule

In [10]:
dm = instantiate(cfg.data)



In [11]:
dm.prepare_data()

In [12]:
dm.setup("train")

In [19]:
dl = dm.train_dataloader()



In [14]:
df = dm.train_dataset.load_df

for col in df.columns:
    if col.startswith("FileName"):
        df[col] = df[col].str.replace("/projects/", "../")

In [22]:
dm.val_dataset.load_df

Unnamed: 0_level_0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,FileName_OrigDNA,FileName_OrigAGP,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA,Metadata_InChI,Metadata_PlateType,Metadata_Site
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
source_1__Batch1_20221004__UL001651__AC35__1,source_1,Batch1_20221004,UL001651,AC35,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,"InChI=1S/C10H10F3N5OS/c11-10(12,13)7-14-15-8(2...",COMPOUND,1
source_1__Batch1_20221004__UL001651__AC35__2,source_1,Batch1_20221004,UL001651,AC35,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,"InChI=1S/C10H10F3N5OS/c11-10(12,13)7-14-15-8(2...",COMPOUND,2
source_1__Batch1_20221004__UL001651__AC35__3,source_1,Batch1_20221004,UL001651,AC35,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,"InChI=1S/C10H10F3N5OS/c11-10(12,13)7-14-15-8(2...",COMPOUND,3
source_1__Batch1_20221004__UL001651__AC35__4,source_1,Batch1_20221004,UL001651,AC35,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,/projects/cpjump1/jump/images/source_1/Batch1_...,"InChI=1S/C10H10F3N5OS/c11-10(12,13)7-14-15-8(2...",COMPOUND,4
source_2__20210808_Batch_4__1086292952__M13__1,source_2,20210808_Batch_4,1086292952,M13,/projects/cpjump3/jump/images/source_2/2021080...,/projects/cpjump3/jump/images/source_2/2021080...,/projects/cpjump3/jump/images/source_2/2021080...,/projects/cpjump3/jump/images/source_2/2021080...,/projects/cpjump3/jump/images/source_2/2021080...,"InChI=1S/C10H10F3N5OS/c11-10(12,13)7-14-15-8(2...",COMPOUND,1
...,...,...,...,...,...,...,...,...,...,...,...,...
source_9__20210901_Run8__GR00003315__AB19__4,source_9,20210901_Run8,GR00003315,AB19,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,InChI=1S/C9H9N5O2S/c10-8-11-9(13-12-8)17-5-6-2...,COMPOUND,4
source_9__20210901_Run8__GR00003339__L19__1,source_9,20210901_Run8,GR00003339,L19,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,InChI=1S/C9H9N5O2S/c10-8-11-9(13-12-8)17-5-6-2...,COMPOUND,1
source_9__20210901_Run8__GR00003339__L19__2,source_9,20210901_Run8,GR00003339,L19,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,InChI=1S/C9H9N5O2S/c10-8-11-9(13-12-8)17-5-6-2...,COMPOUND,2
source_9__20210901_Run8__GR00003339__L19__3,source_9,20210901_Run8,GR00003339,L19,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,/projects/cpjump1/jump/images/source_9/2021090...,InChI=1S/C9H9N5O2S/c10-8-11-9(13-12-8)17-5-6-2...,COMPOUND,3


In [16]:
batches[1]

NameError: name 'batches' is not defined

In [None]:
batches[0]["compound"].shape

torch.Size([4, 2167])

In [29]:
batches[2]["image"].shape

torch.Size([4, 5, 128, 128])

## Instantiate model

In [57]:
device = torch.device("cuda:0")

In [58]:
cfg.model["_target_"] += ".load_from_checkpoint"
with open_dict(cfg.model):
    cfg.model["checkpoint_path"] = ckpt

In [59]:
model = instantiate(cfg.model, map_location="cuda:0")

In [60]:
model.to(device);

In [36]:
embs = []
for i in range(len(batches)):
    batches[i] = {k: v.to(device) for k, v in batches[i].items()}
    embs.append(model(**batches[i]))

In [37]:
embs[0]["compound_emb"]

tensor([[ 0.0207, -0.0000,  0.0000,  ..., -0.0519,  0.4667, -0.2460],
        [ 0.0333, -0.0745,  0.0000,  ...,  0.0563,  0.5618, -0.1488],
        [ 0.0682, -0.0746,  0.4007,  ...,  0.0000,  0.3068, -0.0000],
        [ 0.0000, -0.0485,  0.2254,  ...,  0.0313,  0.0613, -0.0600]],
       device='cuda:0', grad_fn=<NativeDropoutBackward0>)

In [39]:
nan_loss_batch = "../cpjump1/jump/logs/train/runs/2023-08-23_20-49-23/nan_batches/running_epoch_0_batch_51.pt"

In [40]:
nan_batch = torch.load(nan_loss_batch)

In [41]:
nan_batch["compound"]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [47]:
nan_batch_cpu = {k: v.to("cpu") for k, v in nan_batch.items()}
model.to("cpu")

BasicJUMPModule(
  (image_encoder): CNNEncoder(
    (backbone): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act1): ReLU(inplace=True)
          (aa): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act2): ReLU(inplace=True)
        )
        (1): BasicBlock(
      

In [49]:
nan_out = model.training_step(nan_batch_cpu, 0)

  rank_zero_warn(


In [53]:
image_emb = model.image_encoder(nan_batch_cpu["image"])

In [54]:
compound_emb = model.molecule_encoder(nan_batch_cpu["compound"])

In [55]:
loss = model.criterion(image_emb, compound_emb)

In [68]:
for i in range(1024):
    if nan_batch_cpu["image"][i].isnan().any():
        print(i)

537


In [69]:
nan_batch_cpu["image"][537]

tensor([[[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

        [[18.3959, 17.1319, 15.8679,  ..., -0.2032,  0.3385,  0.5191],
         [17.1319, 15.6873, 14.6039,  ...,  0.1579,  0.1579,  0.1579],
         [16.9513, 16.7707, 14.4233,  ...,  0.1579, -0.2032, -0.0226],
         ...,
         [-0.3838, -0.0226, -0.3838,  ...,  0.1579, -0.2032, -0.3838],
         [-0.2032, -0.0226, -0.3838,  ..., -0.2032, -0.0226,  0.3385],
         [-0.3838,  0.3385, -0.0226,  ...,  0.1579, -0.2032, -0.0226]],

        [[14.2149, 10.4271,  8.9120,  ...,  1.3365,  0.5790, -0.1786],
         [12.6998, 12.6998, 10.4271,  ..., -0

In [71]:
dm.transform

DefaultJUMPTransform(
  (transform): Sequential(
    (0): RandomHorizontalFlip(p=0.5)
    (1): RandomVerticalFlip(p=0.5)
    (2): RandomCrop(size=(128, 128), pad_if_needed=False, fill=0, padding_mode=constant)
    (3): ToImageTensor()
    (4): ConvertDtype()
    (5): ImageNormalization()
  )
)

In [24]:
dm.setup("test")

In [25]:
df = dm.test_dataset.load_df

for col in df.columns:
    if col.startswith("FileName"):
        df[col] = df[col].str.replace("/projects/", "../")

In [26]:
test_dl = dm.test_dataloader()

In [32]:
for b in test_dl:
    break

In [33]:
b["image"].isnan().any()

tensor(False)

In [35]:
nan_to_num?

Object `nan_to_num` not found.


In [45]:
transform = instantiate(cfg.data.transform)

In [61]:
nan_batch["image"]

tensor([[[[-3.9186e-01, -4.0917e-01, -4.0917e-01,  ..., -3.2263e-01,
           -3.2263e-01, -3.2263e-01],
          [-4.0917e-01, -3.9186e-01, -4.0917e-01,  ..., -3.3993e-01,
           -3.5724e-01, -3.3993e-01],
          [-3.5724e-01, -3.9186e-01, -4.0917e-01,  ..., -3.3993e-01,
           -3.9186e-01, -3.7455e-01],
          ...,
          [-4.6110e-01, -4.7841e-01, -4.6110e-01,  ...,  2.2738e+00,
            2.2565e+00,  2.1699e+00],
          [-4.4379e-01, -4.0917e-01, -4.4379e-01,  ...,  2.3603e+00,
            2.3950e+00,  2.2046e+00],
          [-4.6110e-01, -4.2648e-01, -4.4379e-01,  ...,  2.3430e+00,
            2.4296e+00,  2.3777e+00]],

         [[ 6.9876e-02,  6.9876e-02,  1.8677e-01,  ...,  5.6669e-01,
            7.4203e-01,  6.8358e-01],
          [ 6.2513e-01,  4.4979e-01,  4.2056e-01,  ...,  5.0824e-01,
            4.4979e-01,  3.6212e-01],
          [ 7.4203e-01,  5.9591e-01,  4.4979e-01,  ...,  3.3289e-01,
            1.8677e-01,  1.2832e-01],
          ...,
     

In [66]:
transformed_nan_b = transform(nan_batch["image"].to("cpu")).to("cpu")

In [68]:
model.to("cpu")
im_emb = model.image_encoder(transformed_nan_b)

In [69]:
im_emb

tensor([[-4.4974e-03, -7.0135e-03, -3.6734e-03,  ...,  3.0655e-03,
          1.7795e-02, -1.8031e-02],
        [-5.8477e-03, -9.2124e-03, -1.8504e-02,  ...,  2.3520e-02,
          1.8007e-02, -3.6282e-02],
        [-7.6105e-03, -8.5919e-03,  7.5200e-03,  ...,  7.9799e-03,
         -3.2365e-03, -1.0332e-02],
        ...,
        [-6.4327e-03, -3.1845e-04, -4.7621e-02,  ...,  1.3264e-02,
          2.8813e-02, -1.2199e-02],
        [-1.0584e-02, -8.1476e-04, -7.5834e-03,  ...,  2.3543e-02,
         -1.7742e-02, -2.0531e-03],
        [-5.5999e-03, -5.3685e-03,  2.9388e-02,  ..., -4.7939e-05,
          3.5367e-02, -6.8224e-03]], grad_fn=<AddmmBackward0>)

## Test losses

In [None]:
criterions = []

In [55]:
infonce = InfoNCE(temperature=0.5, norm=True, eps=1e-8)

In [62]:
model.molecule_encoder(nan_batch["compound"])

tensor([[ 0.0289, -0.0334,  0.1331,  ...,  0.0502,  0.1946, -0.0000],
        [ 0.0000, -0.0348,  0.0489,  ...,  0.0423,  0.0000, -0.0754],
        [ 0.0325, -0.0344,  0.1656,  ...,  0.0571,  0.1605, -0.0654],
        ...,
        [ 0.0329, -0.0433,  0.2460,  ...,  0.0621,  0.0000, -0.0000],
        [ 0.0223, -0.0403,  0.0000,  ...,  0.0320,  0.3211, -0.0940],
        [ 0.0353, -0.0580,  0.3490,  ...,  0.0697,  0.1115, -0.0423]],
       device='cuda:0', grad_fn=<NativeDropoutBackward0>)

In [70]:
infonce(model.molecule_encoder(nan_batch["compound"].to("cpu")), im_emb)

tensor(6.9285, grad_fn=<NegBackward0>)

In [210]:
res = {}
j = 1

for criterion in criterions:
    for i in range(3):
        key = f"{criterion.__class__.__name__}_{i}"
        if key in res:
            key += f"_{j}"
            j += 1
        res[key] = criterion(
            embs[i]["image_emb"],
            embs[i]["compound_emb"],
        )

In [211]:
res

{'NtXentLoss_0': tensor(2.4853, device='cuda:0', grad_fn=<NegBackward0>),
 'NtXentLoss_1': tensor(2.4590, device='cuda:0', grad_fn=<NegBackward0>),
 'NtXentLoss_2': tensor(2.4400, device='cuda:0', grad_fn=<NegBackward0>),
 'NTXent_0': tensor(1.1239, device='cuda:0', grad_fn=<NegBackward0>),
 'NTXent_1': tensor(1.0914, device='cuda:0', grad_fn=<NegBackward0>),
 'NTXent_2': tensor(1.0687, device='cuda:0', grad_fn=<NegBackward0>),
 'ContrastiveLossWithTemperature_0': tensor(1.4052, device='cuda:0', grad_fn=<DivBackward0>),
 'ContrastiveLossWithTemperature_1': tensor(1.3802, device='cuda:0', grad_fn=<DivBackward0>),
 'ContrastiveLossWithTemperature_2': tensor(1.3642, device='cuda:0', grad_fn=<DivBackward0>),
 'NTXent_0_1': tensor(-inf, device='cuda:0', grad_fn=<AddBackward0>),
 'NTXent_1_2': tensor(-inf, device='cuda:0', grad_fn=<AddBackward0>),
 'NTXent_2_3': tensor(-inf, device='cuda:0', grad_fn=<AddBackward0>),
 'InfoNCE_0': tensor(-inf, device='cuda:0', grad_fn=<AddBackward0>),
 'InfoN

In [175]:
l11 / l12, l21 / l22

(tensor(1.0298, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(1.0107, device='cuda:0', grad_fn=<DivBackward0>))

In [138]:
embeddings_a = embs[0]["compound_emb"]
embeddings_b = embs[0]["image_emb"]
temperature = 0.5

In [167]:
embeddings_a_abs = F.normalize(embeddings_a, dim=1)
embeddings_b_abs = F.normalize(embeddings_b, dim=1)

out = torch.cat([embeddings_a_abs, embeddings_b_abs], dim=0)
n_samples = out.shape[0]

# Calculate cosine similarity
sim = torch.mm(out, out.t().contiguous())
sim = torch.exp(sim / temperature)

# Negative similarity
mask = ~torch.eye(n_samples, device=sim.device).bool()
neg = sim.masked_select(mask).view(n_samples, -1).sum(dim=-1)

# Positive similarity
pos = torch.exp(torch.sum(embeddings_a * embeddings_b, dim=-1) / temperature)
pos = torch.cat([pos, pos], dim=0)

loss = -torch.log(pos / neg).mean()

In [168]:
loss

tensor(30.5000, device='cuda:0', grad_fn=<NegBackward0>)

In [165]:
embeddings_a_abs = F.normalize(embeddings_a, dim=1)

In [154]:
torch.mm(embs[0]["image_emb"], embs[0]["compound_emb"].t())

tensor([[-18.0215, -34.1079,  -9.9153,  -5.6235],
        [ -8.6863, -19.3624,  -5.2746,  -3.0330],
        [-19.6764, -35.8913, -10.7404,  -6.0404],
        [-21.9357, -40.2616, -12.6212,  -7.4393]], device='cuda:0',
       grad_fn=<MmBackward0>)

In [100]:
embs[0]["compound_emb"].shape

torch.Size([4, 256])

In [98]:
embs[0]["image_emb"].shape

torch.Size([4, 256])

In [122]:
trainer = instantiate(cfg.trainer, callbacks=utils.instantiate_callbacks(cfg.callbacks))

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Fix collate functions

In [82]:
default_collate?

[0;31mSignature:[0m [0mdefault_collate[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      /mnt/2547d4d7-6732-4154-b0e1-17b0c1e0c565/Document-2/Projet2/Stage/workspace/jump_models/src/modules/collate_fn/default_collate.py
[0;31mType:[0m      function