# Test the losses

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import shutil
from copy import deepcopy
from pathlib import Path

import molfeat
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict

from src import utils
from src.models.jump_cl import BasicJUMPModule
from src.modules.collate_fn import default_collate
from src.modules.losses.autoencoder_losses import GraphImageVariatonalEncoderLoss, ImageGraphVariatonalEncoderLoss
from src.modules.losses.base_losses import CombinationLoss, RegularizationLoss
from src.modules.losses.contrastive_losses import InfoNCE, NTXent, RegInfoNCE, RegNTXent
from src.modules.losses.matching_losses import GraphImageMatchingLoss
from src.utils import instantiate_evaluator_list

In [3]:
for i in range(1, 4):
    if not Path(f"../cpjump{i}/jump/").exists():
        print(f"Mounting cpjump{i}...")
        os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
    else:
        print(f"cpjump{i} already mounted.")

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


## Load the config and instantiate the model, loggers and evaluators

In [4]:
GlobalHydra.instance().clear()

In [5]:
run = "../cpjump1/jump/logs/train/runs/2023-09-08_13-41-04"
ckpt = f"{run}/checkpoints/epoch_097.ckpt"

In [6]:
initialize(version_base=None, config_path=f"../{run}/.hydra")

hydra.initialize()

In [7]:
os.listdir(f"{run}/.hydra")

['config.yaml', 'hydra.yaml', 'overrides.yaml']

In [8]:
cfg = compose(
    config_name="config.yaml",
    overrides=[
        "evaluate=true",
        "eval=retrieval",
        "paths.projects_dir=..",
        f"paths.output_dir=../cpjump1/jump/logs/train/runs/{run}",
        # "experiment=fp_big",
        "data.batch_size=4",
        # "model/molecule_encoder=gin_masking.yaml",
        "trainer.devices=1",
        # "eval.moa_image_task.datamodule.data_root_dir=../",
    ],
)
print(OmegaConf.to_yaml(cfg))

task_name: train
tags:
- big_images
- big_jump_cl
- pretrained
- clip_like
- pna
- resnet34
train: true
load_first_bacth: true
test: true
evaluate: true
compile: false
ckpt_path: null
seed: 12345
data:
  compound_transform:
    _target_: src.modules.compound_transforms.pna.PNATransform
    compound_str_type: inchi
  _target_: src.models.jump_cl.datamodule.BasicJUMPDataModule
  batch_size: 4
  num_workers: 8
  pin_memory: null
  prefetch_factor: 2
  drop_last: true
  transform:
    _target_: src.modules.transforms.SimpleTransform
    _convert_: object
    size: 512
  force_split: false
  splitter:
    _target_: src.splitters.ScaffoldSplitter
    train: -1
    test: 8192
    val: 4096
    retrieval: 4096
  use_compond_cache: false
  data_root_dir: ${paths.projects_dir}/
  split_path: ${paths.split_path}/scaffold_split/
  dataloader_config:
    train:
      batch_size: ${data.batch_size}
      num_workers: ${data.num_workers}
      pin_memory: ${data.pin_memory}
      prefetch_factor: ${d

In [9]:
dm = instantiate(cfg.data)
dm.prepare_data()
dm.setup("test")
dl = dm.test_dataloader()



## Instantiate model

In [10]:
# device = torch.device("cuda:0")
device = torch.device("cpu")
cfg.model["_target_"] += ".load_from_checkpoint"
with open_dict(cfg.model):
    cfg.model["checkpoint_path"] = ckpt

In [11]:
model = instantiate(cfg.model, map_location=device, strict=False)

  rank_zero_warn(


In [12]:
batches = []
embs = []
for i, batch in enumerate(dl):
    batches.append({k: v.to(device) for k, v in batch.items()})
    embs.append(model(**batches[i]))

    if i == 3:
        break

## Losses

In [13]:
criterion = instantiate(cfg.model.criterion)

In [43]:
criterion

RegNTXent(
  (losses): ModuleDict(
    (temp_loss): NTXent()
    (reg_loss): RegularizationLoss(
      (mse_loss): MSELoss()
      (l1_loss): L1Loss()
    )
  )
)

In [55]:
criterion.losses["temp_loss"].temperature_param

<ClampedParameter>: 0.5

In [53]:
criterion.temperature

Parameter containing:
tensor(0.5000)

In [21]:
img1, comp1 = embs[0]["image_emb"], embs[0]["compound_emb"]
img2, comp2 = embs[1]["image_emb"], embs[1]["compound_emb"]

In [64]:
criterion(img1, comp1)

{'RegNTXent/loss': tensor(0.4345, device='cuda:0', grad_fn=<NegBackward0>),
 'RegNTXent/x_to_y_top1': tensor(0.7500, device='cuda:0'),
 'RegNTXent/x_to_y_top5': tensor(1., device='cuda:0'),
 'RegNTXent/x_to_y_top10': tensor(1., device='cuda:0'),
 'RegNTXent/x_to_y_mean_pos': tensor(1.5000, device='cuda:0'),
 'RegNTXent/x_to_y_mean_pos_normed': tensor(0.3750, device='cuda:0'),
 'RegNTXent/y_to_x_top1': tensor(0.5000, device='cuda:0'),
 'RegNTXent/y_to_x_top5': tensor(1., device='cuda:0'),
 'RegNTXent/y_to_x_top10': tensor(1., device='cuda:0'),
 'RegNTXent/y_to_x_mean_pos': tensor(1.5000, device='cuda:0'),
 'RegNTXent/y_to_x_mean_pos_normed': tensor(0.3750, device='cuda:0'),
 'Regularization/mse_loss': tensor(0.0025, device='cuda:0', grad_fn=<MseLossBackward0>),
 'Regularization/std_loss': tensor(1.9144, device='cuda:0', grad_fn=<AddBackward0>),
 'Regularization/cov_loss': tensor(0.0022, device='cuda:0', grad_fn=<AddBackward0>),
 'Regularization/loss': tensor(1.9152, device='cuda:0', gra

In [70]:
criterion(img2, comp2)

{'RegNTXent/loss': tensor(0.4503, device='cuda:0', grad_fn=<NegBackward0>),
 'RegNTXent/x_to_y_top1': tensor(0.5000, device='cuda:0'),
 'RegNTXent/x_to_y_top5': tensor(1., device='cuda:0'),
 'RegNTXent/x_to_y_top10': tensor(1., device='cuda:0'),
 'RegNTXent/x_to_y_mean_pos': tensor(1.5000, device='cuda:0'),
 'RegNTXent/x_to_y_mean_pos_normed': tensor(0.3750, device='cuda:0'),
 'RegNTXent/y_to_x_top1': tensor(0.5000, device='cuda:0'),
 'RegNTXent/y_to_x_top5': tensor(1., device='cuda:0'),
 'RegNTXent/y_to_x_top10': tensor(1., device='cuda:0'),
 'RegNTXent/y_to_x_mean_pos': tensor(1.5000, device='cuda:0'),
 'RegNTXent/y_to_x_mean_pos_normed': tensor(0.3750, device='cuda:0'),
 'Regularization/mse_loss': tensor(0.0026, device='cuda:0', grad_fn=<MseLossBackward0>),
 'Regularization/std_loss': tensor(1.9184, device='cuda:0', grad_fn=<AddBackward0>),
 'Regularization/cov_loss': tensor(0.0024, device='cuda:0', grad_fn=<AddBackward0>),
 'Regularization/loss': tensor(1.9193, device='cuda:0', gra

In [73]:
img2.shape

torch.Size([4, 512])

In [81]:
g2i_vae_loss = GraphImageVariatonalEncoderLoss(emb_dim=512, similarity="cosine", beta=1.0, detach_target=False)
g2i_vae_loss.to(device)

GraphImageVariatonalEncoder(
  (criterion): CosineSimilarity()
  (fc_mu): Linear(in_features=512, out_features=128, bias=True)
  (fc_var): Linear(in_features=512, out_features=128, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=512, bias=True)
  )
)

In [82]:
i2g_vae_loss = ImageGraphVariatonalEncoderLoss(emb_dim=512, similarity="cosine", beta=1.0, detach_target=False)
i2g_vae_loss.to(device)

ImageGraphVariatonalEncoder(
  (criterion): CosineSimilarity()
  (fc_mu): Linear(in_features=512, out_features=128, bias=True)
  (fc_var): Linear(in_features=512, out_features=128, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=512, bias=True)
  )
)

In [88]:
i2g_vae_loss(img2, comp1)

{'reconstruction_loss': tensor(0.0184, grad_fn=<MeanBackward0>),
 'kl_loss': tensor(0.1196, grad_fn=<MeanBackward1>),
 'loss': tensor(0.1380, grad_fn=<AddBackward0>)}

In [89]:
g2i_vae_loss(img2, comp1)

{'reconstruction_loss': tensor(-0.0020, grad_fn=<MeanBackward0>),
 'kl_loss': tensor(0.1290, grad_fn=<MeanBackward1>),
 'loss': tensor(0.1269, grad_fn=<AddBackward0>)}

In [94]:
GraphImageMatchingLoss?

[0;31mInit signature:[0m
[0mGraphImageMatchingLoss[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0membedding_dim[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnorm[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mname[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'GraphImageMatchingLoss'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfusion_layer[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0;34m**[0m[0mkwargs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in
a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()

In [26]:
gim = GraphImageMatchingLoss(embedding_dim=512, fusion_layer="deepset")
gim.to(device)

GraphImageMatchingLoss(
  (fusion_layer): DeepSetFusion(
    (image_proj): Linear(in_features=512, out_features=128, bias=True)
    (graph_proj): Linear(in_features=512, out_features=128, bias=True)
    (fusion): DeepsetFusionWithTransformer(
      (projections): ModuleDict(
        (image): Identity()
        (graph): Identity()
      )
      (attention): Identity()
      (pooling_function): TransformerEncoder(
        (layers): ModuleList(
          (0): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
            )
            (linear1): Linear(in_features=128, out_features=2048, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear2): Linear(in_features=2048, out_features=128, bias=True)
            (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=

In [68]:
ntxent = NTXent(
    temperature=0.1,
    return_rank=True,
)

In [91]:
combination = CombinationLoss(
    losses={
        "gim": gim,
        "i2g_ve": i2g_vae_loss,
        "g2i_ve": g2i_vae_loss,
        "ntxent": ntxent,
    },
    weights=[1.0, 1.0, 1.0, 1.0],
)

In [92]:
combination(img1, comp1)

{'gim/loss': tensor(0.7045, grad_fn=<NllLossBackward0>),
 'gim/auroc': tensor(0.4062),
 'gim/accuracy': tensor(0.5833),
 'gim/recall': tensor(0.5000),
 'gim/precision': tensor(0.4000),
 'gim/f1_score': tensor(0.4444),
 'i2g_ve/reconstruction_loss': tensor(0.0099, grad_fn=<MeanBackward0>),
 'i2g_ve/kl_loss': tensor(0.1223, grad_fn=<MeanBackward1>),
 'i2g_ve/loss': tensor(0.1322, grad_fn=<AddBackward0>),
 'g2i_ve/reconstruction_loss': tensor(-0.0067, grad_fn=<MeanBackward0>),
 'g2i_ve/kl_loss': tensor(0.1290, grad_fn=<MeanBackward1>),
 'g2i_ve/loss': tensor(0.1223, grad_fn=<AddBackward0>),
 'ntxent/loss': tensor(3.9879, grad_fn=<NegBackward0>),
 'ntxent/x_to_y_top1': tensor(0.),
 'ntxent/x_to_y_top5': tensor(1.),
 'ntxent/x_to_y_top10': tensor(1.),
 'ntxent/x_to_y_mean_pos': tensor(3.7500),
 'ntxent/x_to_y_mean_pos_normed': tensor(0.9375),
 'ntxent/y_to_x_top1': tensor(0.2500),
 'ntxent/y_to_x_top5': tensor(1.),
 'ntxent/y_to_x_top10': tensor(1.),
 'ntxent/y_to_x_mean_pos': tensor(3.),
 

In [93]:
combination(img2, comp2)

{'gim/loss': tensor(0.7706, grad_fn=<NllLossBackward0>),
 'gim/auroc': tensor(0.4375),
 'gim/accuracy': tensor(0.6667),
 'gim/recall': tensor(0.),
 'gim/precision': tensor(0.),
 'gim/f1_score': tensor(0.),
 'i2g_ve/reconstruction_loss': tensor(0.0179, grad_fn=<MeanBackward0>),
 'i2g_ve/kl_loss': tensor(0.1196, grad_fn=<MeanBackward1>),
 'i2g_ve/loss': tensor(0.1376, grad_fn=<AddBackward0>),
 'g2i_ve/reconstruction_loss': tensor(0.0304, grad_fn=<MeanBackward0>),
 'g2i_ve/kl_loss': tensor(0.1300, grad_fn=<MeanBackward1>),
 'g2i_ve/loss': tensor(0.1604, grad_fn=<AddBackward0>),
 'ntxent/loss': tensor(6.6429, grad_fn=<NegBackward0>),
 'ntxent/x_to_y_top1': tensor(0.),
 'ntxent/x_to_y_top5': tensor(1.),
 'ntxent/x_to_y_top10': tensor(1.),
 'ntxent/x_to_y_mean_pos': tensor(3.5000),
 'ntxent/x_to_y_mean_pos_normed': tensor(0.8750),
 'ntxent/y_to_x_top1': tensor(0.),
 'ntxent/y_to_x_top5': tensor(1.),
 'ntxent/y_to_x_top10': tensor(1.),
 'ntxent/y_to_x_mean_pos': tensor(3.2500),
 'ntxent/y_to_