In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import logging
import os

import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from src.models.jump_cl.datamodule import BasicJUMPDataModule
from src.modules.images.timm_pretrained import CNNEncoder

In [3]:
os.getcwd()

'/mnt/2547d4d7-6732-4154-b0e1-17b0c1e0c565/Document-2/Projet2/Stage/workspace/jump_models'

In [4]:
initialize(version_base=None, config_path="../configs")

hydra.initialize()

In [106]:
cfg = compose(
    config_name="train.yaml",
    overrides=[
        "experiment=frozen_med",
        "paths.data_root_dir=../cpjump1/",
        "model/image_encoder=vit_base",
        "model/molecule_encoder=attentive_fp",
        "data/compound_transform=attentive_fp",
        "data.transform.size=384",
    ],
)

In [33]:
print(OmegaConf.to_yaml(cfg.model))

image_encoder:
  _target_: src.modules.images.timm_pretrained.CNNEncoder
  instance_model_name: vit_base_r50_s16_384.orig_in21k_ft_in1k
  target_num: ${model.embedding_dim}
  n_channels: 5
  pretrained: true
molecule_encoder:
  _target_: src.modules.molecules.attentive_fp.AttentiveFPWithLinearHead
  node_feat_size: 256
  edge_feat_size: 128
  num_layers: 4
  num_timesteps: 2
  graph_feat_size: 256
  n_tasks: 256
  dropout: 0.2
  out_dim: 512
criterion:
  _target_: src.modules.losses.contrastive_loss_with_temperature.ContrastiveLossWithTemperature
  logit_scale: 0
  logit_scale_min: -1
  logit_scale_max: 4.605170185988092
  requires_grad: false
optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  betas:
  - 0.9
  - 0.999
  eps: 1.0e-08
  weight_decay: 0.01
  amsgrad: false
  lr: ${model.lr}
scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
  _partial_: true
  T_0: 7
  T_mult: 2
  eta_min: 0
  last_epoch: -1
_target_: src.models.jump_cl.module.BasicJUMPMo

In [107]:
dm = instantiate(cfg.data)

In [108]:
dm.prepare_data()

In [109]:
dm.setup("fit")

INFO:src.models.jump_cl.datamodule:Loading image metadata df from ../cpjump1//jump/models/metadata/images_metadata.parquet
INFO:src.models.jump_cl.datamodule:Loading compound dictionary from ../cpjump1//jump/models/metadata/compound_dict.json
INFO:src.models.jump_cl.datamodule:Loading train ids from ../cpjump1//jump/models/splits/med_jump_cl/train_ids.csv
INFO:src.models.jump_cl.datamodule:Train, test, val lengths: 25000, 2000, 3000
INFO:src.models.jump_cl.datamodule:Preparing train dataset
INFO:src.models.jump_cl.datamodule:Preparing validation dataset


In [110]:
for chan in ["DNA", "AGP", "ER", "Mito", "RNA"]:
    dm.train_dataset.load_df[f"FileName_Orig{chan}"] = dm.train_dataset.load_df[f"FileName_Orig{chan}"].str.replace(
        "/projects/", "../"
    )

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dm.train_dataset.load_df[f"FileName_Orig{chan}"] = dm.train_dataset.load_df[f"FileName_Orig{chan}"].str.replace("/projects/", "../")


In [111]:
dm.train_dataset[0]

{'image': tensor([[[ 1.6790,  1.6244,  1.5152,  ..., -0.7096, -0.6959, -0.7096],
          [ 1.7472,  1.7336,  1.6244,  ..., -0.6959, -0.6959, -0.6959],
          [ 1.7472,  1.7199,  1.6517,  ..., -0.7096, -0.6823, -0.6686],
          ...,
          [ 0.5052,  1.0375,  1.3378,  ..., -0.7232, -0.7232, -0.7369],
          [ 0.3687,  0.9010,  1.3241,  ..., -0.7232, -0.7369, -0.7369],
          [ 0.1230,  0.6826,  1.1876,  ..., -0.7232, -0.7232, -0.7232]],
 
         [[-1.2070, -1.1039, -0.7433,  ...,  0.3128,  0.6992,  0.7249],
          [-1.0782, -1.0524, -0.8206,  ...,  0.3128,  0.4416,  0.4158],
          [-0.8979, -1.0524, -1.0009,  ...,  0.3643,  0.3643,  0.2098],
          ...,
          [ 0.4158,  0.1325, -0.0221,  ..., -0.7691, -0.8721, -0.8206],
          [ 0.5704,  0.2355, -0.0221,  ..., -0.7691, -0.9494, -0.8721],
          [ 0.3643,  0.2870,  0.0552,  ..., -0.7176, -0.8979, -0.8463]],
 
         [[-0.5458, -0.7003, -0.8548,  ..., -0.3913,  0.0465, -0.0823],
          [-0.5973,

In [12]:
callbacks = instantiate(cfg.callbacks.jump_cl_freezer)

In [67]:
model = instantiate(cfg.model)

In [69]:
model.image_encoder.backbone.patch_embed

HybridEmbed(
  (backbone): ResNetV2(
    (stem): Sequential(
      (conv): StdConv2dSame(3, 64, kernel_size=(7, 7), stride=(2, 2), bias=False)
      (norm): GroupNormAct(
        32, 64, eps=1e-05, affine=True
        (drop): Identity()
        (act): ReLU(inplace=True)
      )
      (pool): MaxPool2dSame(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ceil_mode=False)
    )
    (stages): Sequential(
      (0): ResNetStage(
        (blocks): Sequential(
          (0): Bottleneck(
            (downsample): DownsampleConv(
              (conv): StdConv2dSame(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (norm): GroupNormAct(
                32, 256, eps=1e-05, affine=True
                (drop): Identity()
                (act): Identity()
              )
            )
            (conv1): StdConv2dSame(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (norm1): GroupNormAct(
              32, 64, eps=1e-05, affine=True
        

In [113]:
x = dm.train_dataset[0]

In [46]:
x.ndata["h"].shape, x.edata["e"].shape

(torch.Size([14, 74]), torch.Size([44, 13]))

In [115]:
x["image"].unsqueeze(0).shape

torch.Size([1, 5, 384, 384])

In [65]:
model.image_encoder.backbone.patch_embed.proj

Conv2d(1024, 768, kernel_size=(1, 1), stride=(1, 1))

In [84]:
logging.basicConfig(level=logging.INFO)

In [116]:
im = CNNEncoder(
    instance_model_name="vit_base_r50_s16_384.orig_in21k_ft_in1k",
    target_num=512,
    n_channels=5,
    pretrained=True,
)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/vit_base_r50_s16_384.orig_in21k_ft_in1k)
INFO:timm.models._hub:[timm/vit_base_r50_s16_384.orig_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:src.modules.images.timm_pretrained:Using model vit_base_resnet50 with projection head


In [103]:
x["image"].shape

torch.Size([5, 256, 256])

In [118]:
im.forward(x["image"].unsqueeze(0)).shape

torch.Size([1, 512])

In [119]:
dl = dm.train_dataloader()



In [None]:
b = {
    "image": x["image"].unsqueeze(0),
    "compound": x["compound"],
}

{'image': tensor([[[-0.3262, -0.3435, -0.3262,  ..., -0.3435, -0.3262, -0.3089],
          [-0.3262, -0.3435, -0.3089,  ..., -0.2915, -0.3262, -0.3089],
          [-0.3089, -0.3262, -0.3262,  ..., -0.3435, -0.3262, -0.3262],
          ...,
          [-0.3089, -0.3262, -0.3262,  ..., -0.3262, -0.3262, -0.3262],
          [-0.3435, -0.3435, -0.3089,  ..., -0.3262, -0.3262, -0.3262],
          [-0.3262, -0.2915, -0.2915,  ..., -0.3089, -0.3435, -0.3262]],
 
         [[-0.7364, -0.7008, -0.7008,  ..., -0.7008, -0.7008, -0.7008],
          [-0.7364, -0.7364, -0.7364,  ..., -0.6830, -0.6830, -0.7008],
          [-0.7186, -0.7364, -0.7364,  ..., -0.7008, -0.6652, -0.6830],
          ...,
          [-0.6830, -0.6652, -0.7008,  ..., -0.3446, -0.4337, -0.5405],
          [-0.6652, -0.7008, -0.6652,  ..., -0.3090, -0.2734, -0.3624],
          [-0.7186, -0.6830, -0.7008,  ..., -0.3090, -0.2378, -0.2912]],
 
         [[-0.7122, -0.6796, -0.6959,  ..., -0.6633, -0.6308, -0.6796],
          [-0.6796,

: 

In [None]:
projects / cpjump1 / jump / models / example_batch / simple_jump_cl / batch.pth()

In [51]:
model.molecule_encoder.forward(x["compound"], get_node_weight=False)

tensor([[-0.0429, -0.0390,  0.0602,  0.0495, -0.0300,  0.0639, -0.0669,  0.0061,
         -0.0332, -0.0858,  0.0504,  0.0589, -0.0047,  0.0326, -0.0458, -0.0006,
         -0.1529, -0.0057, -0.0414, -0.0168, -0.0141, -0.0045,  0.0413, -0.0539,
         -0.0514, -0.0826,  0.0031,  0.0320,  0.0294,  0.0219, -0.0783,  0.0549,
          0.0621,  0.0181,  0.0320, -0.0391,  0.0571, -0.0513, -0.0249,  0.0668,
         -0.0379, -0.0061,  0.0385, -0.0471,  0.0023,  0.0157,  0.0326,  0.0374,
         -0.0731, -0.0163, -0.0617,  0.0898, -0.1262,  0.0227,  0.0920,  0.0767,
          0.0809, -0.0212,  0.0421,  0.0346,  0.0787, -0.0290,  0.1543,  0.0636,
         -0.0374,  0.0189,  0.0615,  0.0857, -0.0751, -0.0087, -0.0537, -0.0732,
          0.1246, -0.1340,  0.0598,  0.0149,  0.0045,  0.0066, -0.0073, -0.0267,
          0.0229, -0.1014, -0.0145,  0.0433,  0.0187,  0.0412,  0.0437, -0.0696,
          0.0050,  0.0152,  0.0164,  0.0696, -0.0305, -0.0265, -0.0670,  0.0506,
         -0.0842,  0.0778, -

In [26]:
callbacks.train_bn = False

In [27]:
callbacks.freeze_before_training(model)

In [28]:
params_groups = [
    {
        "params": list(model.image_head.parameters()),
        "name": "image_projection_head",
    },
    {
        "params": list(model.molecule_head.parameters()),
        "name": "molecule_projection_head",
    },
    {
        "params": list(model.criterion.parameters()),
        "name": "criterion",
    },
    {
        "params": list(model.image_backbone.parameters()),
        "name": "image_encoder",
    },
    {
        "params": list(model.molecule_backbone.parameters()),
        "name": "molecule_encoder",
    },
]
filtered_params_groups = [
    {
        "params": list(filter(lambda p: p.requires_grad, group["params"])),
        "name": group["name"],
    }
    for group in params_groups
]

params_len = {group["name"]: len(group["params"]) for group in params_groups}
group_lens = {group["name"]: len(group["params"]) for group in filtered_params_groups}

group_to_keep = [
    group["name"]
    for group in filtered_params_groups
    if group_lens[group["name"]] > 0 and group["name"] not in model.params_group_to_ignore
]

print(f"Number of params in each groups:\n{params_len}")
print(f"Number of require grad params in each groups:\n{group_lens}")
print(f"Params groups to keep:\n{group_to_keep}")

Number of params in each groups:
{'image_projection_head': 2, 'molecule_projection_head': 2, 'criterion': 1, 'image_encoder': 60, 'molecule_encoder': 42}
Number of require grad params in each groups:
{'image_projection_head': 2, 'molecule_projection_head': 2, 'criterion': 0, 'image_encoder': 0, 'molecule_encoder': 0}
Params groups to keep:
['image_projection_head', 'molecule_projection_head']


In [21]:
model.image_encoder

CNNEncoder(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (drop_block): Identity()
        (act1): ReLU(inplace=True)
        (aa): Identity()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (act2): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), paddin