# Create MOA datasets

In [1]:
# Reload imported modules automatically before each cell runs, so edits to project code are picked up.
%load_ext autoreload
%autoreload 2

## Setup and data loading

In [10]:
import logging
import os
import os.path as osp
import shutil
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import yaml
from lightning import LightningDataModule
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, Dataset
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy

from src.eval.moa.datamodule import JumpMOADataModule
from src.eval.moa.module import JumpMOAImageModule
from src.modules.collate_fn import image_graph_label_collate_function, label_graph_collate_function
from src.modules.compound_transforms import DGLPretrainedFromSmiles
from src.modules.images.timm_pretrained import CNNEncoder
from src.modules.molecules.dgllife_gin import GINPretrainedWithLinearHead

# from src.modules.transforms import DefaultJUMPTransform
from src.splitters import StratifiedSplitter
from src.utils.io import download_and_extract_zip, load_image_paths_to_array

In [11]:
# Mount each remote cpjump project directory over sshfs unless its jump/ folder is already visible.
for i in range(1, 4):
    if not Path(f"../cpjump{i}/jump/").exists():
        print(f"Mounting cpjump{i}...")
        # os.system returns the command status; anything other than 0 means sshfs did not succeed.
        mount_status = os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
        if mount_status != 0:
            # Report the failure instead of silently continuing with a missing mount.
            print(f"Failed to mount cpjump{i} (sshfs returned status {mount_status}).")
    else:
        print(f"cpjump{i} already mounted.")

Mounting cpjump1...
Mounting cpjump2...
Mounting cpjump3...


In [12]:
# Directories on the cpjump1 mount holding the JUMP metadata tables and the image load data.
metadata_path = "../cpjump1/jump/metadata"
load_data_path = "../cpjump1/jump/load_data"

In [13]:
# List the files available in the metadata directory.
os.listdir(metadata_path)

['compound.csv.gz',
 'crispr.csv.gz',
 'microscope_config.csv',
 'microscope_filter.csv',
 'orf.csv.gz',
 'plate.csv.gz',
 'README.md',
 'well.csv.gz',
 'compound.csv',
 'crispr.csv',
 'orf.csv',
 'plate.csv',
 'well.csv',
 'complete_metadata.csv',
 'resolution.csv',
 'JUMP-Target-1_compound_metadata.tsv',
 'JUMP-Target-1_compound_platemap.tsv',
 'JUMP-Target-1_crispr_metadata.tsv',
 'JUMP-Target-1_crispr_platemap.tsv',
 'JUMP-Target-1_orf_metadata.tsv',
 'JUMP-Target-1_orf_platemap.tsv',
 'JUMP-Target-2_compound_metadata.tsv',
 'JUMP-Target-2_compound_platemap.tsv',
 'JUMP-MOA_compound_metadata.tsv',
 'local_metadata.csv']

In [14]:
# Read the image load table from the "final" parquet data under load_data_path.
load_df = pd.read_parquet(osp.join(load_data_path, "final"))

In [15]:
# Preview the load table: source/batch/plate/well/site keys plus one file path column per channel.
load_df

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,Metadata_Site,FileName_OrigAGP,FileName_OrigDNA,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA
0,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A01,1,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
1,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A01,3,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
2,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A01,6,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
3,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A02,1,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
4,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A02,2,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
...,...,...,...,...,...,...,...,...,...,...
5121195,source_9,20211103-Run16,GR00004421,Z47,3,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...
5121196,source_9,20211103-Run16,GR00004421,Z47,4,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...
5121197,source_9,20211103-Run16,GR00004421,Z48,1,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...
5121198,source_9,20211103-Run16,GR00004421,Z48,2,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...


## Creating the dataset

In [16]:
# Read the JUMP-MOA compound annotations (tab-separated), keeping only the columns used below.
moa = pd.read_csv(
    osp.join(metadata_path, "JUMP-MOA_compound_metadata.tsv"),
    sep="\t",
    usecols=["smiles", "InChIKey", "moa", "pert_type"],
)

In [17]:
# Read the local per-well metadata table.
meta = pd.read_csv(osp.join(metadata_path, "local_metadata.csv"))

In [18]:
# Read the compound table that maps JCP2022 identifiers to InChIKey / InChI.
compounds = pd.read_csv(osp.join(metadata_path, "compound.csv"))

In [19]:
# Preview the per-well metadata (plate type, well, compound identifiers, treatment type).
meta

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_PlateType,Metadata_Well,Metadata_JCP2022,Metadata_InChIKey,Metadata_InChI,Metadata_Sites_Per_Well,trt
0,source_1,Batch1_20221004,UL001641,COMPOUND,A02,JCP2022_033924,IAZDPXIOMUYVGZ-UHFFFAOYSA-N,InChI=1S/C2H6OS/c1-4(2)3/h1-2H3,4,compound_neg
1,source_1,Batch1_20221004,UL001641,COMPOUND,A03,JCP2022_085227,SRVFFFJZQVENJC-UHFFFAOYSA-N,InChI=1S/C17H30N2O5/c1-6-23-17(22)14-13(24-14)...,4,compound_pos
2,source_1,Batch1_20221004,UL001641,COMPOUND,A04,JCP2022_033924,IAZDPXIOMUYVGZ-UHFFFAOYSA-N,InChI=1S/C2H6OS/c1-4(2)3/h1-2H3,4,compound_neg
3,source_1,Batch1_20221004,UL001641,COMPOUND,A05,JCP2022_036592,IPPYTNWGGOIMDZ-UHFFFAOYSA-N,InChI=1S/C17H12ClF3N4O2/c1-24(15-13(18)6-9(7-2...,4,compound_trt
4,source_1,Batch1_20221004,UL001641,COMPOUND,A06,JCP2022_071885,PYZMXVUWLLQNEP-UHFFFAOYSA-N,InChI=1S/C10H7ClN4/c11-9-3-1-2-8(4-9)6-15-7-13...,4,compound_trt
...,...,...,...,...,...,...,...,...,...,...
861157,source_9,20211103-Run16,GR00004421,COMPOUND,Z44,JCP2022_999999,,,4,compound_trt
861158,source_9,20211103-Run16,GR00004421,COMPOUND,Z45,JCP2022_999999,,,4,compound_trt
861159,source_9,20211103-Run16,GR00004421,COMPOUND,Z46,JCP2022_999999,,,4,compound_trt
861160,source_9,20211103-Run16,GR00004421,COMPOUND,Z47,JCP2022_033924,IAZDPXIOMUYVGZ-UHFFFAOYSA-N,InChI=1S/C2H6OS/c1-4(2)3/h1-2H3,4,compound_neg


In [20]:
# Attach the JUMP compound identifiers to every MOA entry that has both an InChIKey and an MOA label.
jump_moa = pd.merge(
    moa.dropna(subset=["InChIKey", "moa"]),
    compounds,
    how="inner",
    left_on="InChIKey",
    right_on="Metadata_InChIKey",
)
jump_moa

Unnamed: 0,InChIKey,moa,pert_type,smiles,Metadata_JCP2022,Metadata_InChIKey,Metadata_InChI
0,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...
1,QDBVSOZTVKXUES-UHFFFAOYSA-N,histone lysine demethylase inhibitor,trt,CN(C)CCCNC(=O)c1ccc(cc1)-c1cc(O)c2ncccc2c1,JCP2022_072637,QDBVSOZTVKXUES-UHFFFAOYSA-N,InChI=1S/C21H23N3O2/c1-24(2)12-4-11-23-21(26)1...
2,RIJLVEAXPNLDTC-UHFFFAOYSA-N,JAK inhibitor,trt,O=C(Nc1nc2cccc(-c3ccc(CN4CCS(=O)(=O)CC4)cc3)n2...,JCP2022_078581,RIJLVEAXPNLDTC-UHFFFAOYSA-N,InChI=1S/C21H23N5O3S/c27-20(17-8-9-17)23-21-22...
3,YGUFCDOEKKVKJK-UHFFFAOYSA-N,protein tyrosine kinase inhibitor,trt,CC1(N)CCN(CC1)c1cnc(c(N)n1)-c1cccc(Cl)c1Cl,JCP2022_108339,YGUFCDOEKKVKJK-UHFFFAOYSA-N,InChI=1S/C16H19Cl2N5/c1-16(20)5-7-23(8-6-16)12...
4,BCZUAADEACICHN-UHFFFAOYSA-N,hepatocyte growth factor receptor inhibitor,trt,Cn1cc(cn1)-c1ccc2nnc(Sc3ccc4ncccc4c3)n2n1,JCP2022_005529,BCZUAADEACICHN-UHFFFAOYSA-N,InChI=1S/C18H13N7S/c1-24-11-13(10-20-24)16-6-7...
5,XIXXNJFWPAVKFR-UHFFFAOYSA-N,phosphodiesterase inhibitor,trt,Cc1c(O)nc2Nc3cccc(Cl)c3Cn12,JCP2022_103938,XIXXNJFWPAVKFR-UHFFFAOYSA-N,InChI=1S/C11H10ClN3O/c1-6-10(16)14-11-13-9-4-2...
6,WBKCKEHGXNWYMO-UHFFFAOYSA-N,histone lysine demethylase inhibitor,trt,CCOC(=O)CCNc1cc(nc(n1)-c1ccccn1)N1CCc2ccccc2CC1,JCP2022_097676,WBKCKEHGXNWYMO-UHFFFAOYSA-N,InChI=1S/C24H27N5O2/c1-2-31-23(30)10-14-26-21-...
7,VUIRVWPJNKZOSS-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(C)c1ccccc1-c1ncc(C)c(NCc2ccc(cc2)-n2ccnn2)n1,JCP2022_096342,VUIRVWPJNKZOSS-UHFFFAOYSA-N,InChI=1S/C23H24N6/c1-16(2)20-6-4-5-7-21(20)23-...
8,SDGJBAUIGHSMRI-UHFFFAOYSA-N,AMPK inhibitor,trt,CCC(=O)Nc1cccc(Oc2nc(Nc3ccc(cc3OC)N3CCN(C)CC3)...,JCP2022_082441,SDGJBAUIGHSMRI-UHFFFAOYSA-N,InChI=1S/C25H29ClN6O3/c1-4-23(33)28-17-6-5-7-1...
9,PHXJVRSECIGDHY-UHFFFAOYSA-N,Bcr-Abl kinase inhibitor,trt,CN1CCN(Cc2ccc(NC(=O)c3ccc(C)c(c3)C#Cc3cnc4cccn...,JCP2022_068713,PHXJVRSECIGDHY-UHFFFAOYSA-N,InChI=1S/C29H27F3N6O/c1-20-5-6-22(16-21(20)8-1...


In [21]:
# Preview the raw MOA annotation table.
moa

Unnamed: 0,InChIKey,moa,pert_type,smiles
0,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...
1,ODADKLYLWWCHNB-LDYBVBFYSA-N,HMGCR inhibitor,trt,CC(C)=CCC\C(C)=C\CC\C(C)=C\CC[C@]1(C)CCc2cc(O)...
2,QDBVSOZTVKXUES-UHFFFAOYSA-N,histone lysine demethylase inhibitor,trt,CN(C)CCCNC(=O)c1ccc(cc1)-c1cc(O)c2ncccc2c1
3,CXJCGSPAPOTTSF-VURMDHGXSA-N,CDC inhibitor,trt,CCOC(=O)c1c(\C(=C/N)C#N)c2ccc(Cl)c(Cl)c2n1C
4,RFZQYGBLRIKROZ-PCLIKHOPSA-N,phosphoinositide dependent kinase inhibitor,trt,Cc1cccc(\C=N\Nc2cc(N3CCOCC3)n3nc(cc3n2)-c2ccnc...
...,...,...,...,...
86,FPYJSJDOHRDAMT-KQWNVCNZSA-N,hepatocyte growth factor receptor inhibitor,trt,CN(c1cccc(Cl)c1)S(=O)(=O)c1ccc2NC(=O)\C(=C/c3[...
87,DDLZLOKCJHBUHD-UJZLGWIISA-N,glycogen synthase kinase inhibitor,trt,O\N=C1/C(Nc2ccccc12)=C1/C(=O)Nc2cc(Br)ccc12
88,AQGNHMOJWBZFQQ-UHFFFAOYSA-N,glycogen synthase kinase inhibitor,trt,Cc1c[nH]c(n1)-c1cnc(NCCNc2ccc(cn2)C#N)nc1-c1cc...
89,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O


In [22]:
# Find every well whose compound (matched on InChIKey) carries an MOA annotation.
moa_wells = pd.merge(
    moa.dropna(subset=["InChIKey", "moa"]),
    meta,
    how="inner",
    left_on="InChIKey",
    right_on="Metadata_InChIKey",
)

In [23]:
# Preview the MOA-annotated wells.
moa_wells

Unnamed: 0,InChIKey,moa,pert_type,smiles,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_PlateType,Metadata_Well,Metadata_JCP2022,Metadata_InChIKey,Metadata_InChI,Metadata_Sites_Per_Well,trt
0,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,TARGET2,J09,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
1,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,source_10,2021_08_09_U2OS_48_hr_run13,Dest210727-153003,TARGET2,J09,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
2,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,source_10,2021_08_12_U2OS_48_hr_run15,Dest210803-153958,TARGET2,J09,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
3,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,source_10,2021_08_17_U2OS_48_hr_run16,Dest210809-134534,TARGET2,J09,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
4,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,source_10,2021_08_20_U2OS_48_hr_run17,Dest210810-173723,TARGET2,J09,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8817,AQGNHMOJWBZFQQ-UHFFFAOYSA-N,glycogen synthase kinase inhibitor,trt,Cc1c[nH]c(n1)-c1cnc(NCCNc2ccc(cn2)C#N)nc1-c1cc...,source_8,J3,A1170543,COMPOUND,G17,JCP2022_003104,AQGNHMOJWBZFQQ-UHFFFAOYSA-N,InChI=1S/C22H18Cl2N8/c1-13-10-29-21(31-13)17-1...,9,compound_trt
8818,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,source_2,20210823_Batch_10,1086291979,COMPOUND,E18,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,6,compound_trt
8819,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,source_3,CP60,BR5874b3,COMPOUND,C20,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt
8820,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,source_3,CP_36_all_Phenix1,BAY5874b,COMPOUND,C20,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt


In [40]:
# Keep only the wells that do not sit on COMPOUND plates.
moa_wells2 = moa_wells.loc[moa_wells["Metadata_PlateType"] != "COMPOUND"]

In [41]:
# Number of distinct well positions per MOA class among the non-COMPOUND plates.
moa_wells2.groupby("moa")["Metadata_Well"].nunique()

moa
Aurora kinase inhibitor                        14
Bcr-Abl kinase inhibitor                       10
DYRK inhibitor                                  5
JAK inhibitor                                   5
JNK inhibitor                                   5
bromodomain inhibitor                           5
hepatocyte growth factor receptor inhibitor     5
phospholipase inhibitor                         5
ubiquitin specific protease inhibitor           5
Name: Metadata_Well, dtype: int64

In [38]:
# .groupby("moa", as_index=False).apply(lambda x: x.sample(min(900, len(x)), replace=False))

Metadata_Well
J01     244
O23     217
B01     140
F24     140
N24     140
J09     117
I11     104
C06     104
J08     104
D15     104
L15     104
D12     104
N12     104
F16     103
F11      13
K10      13
I20      13
F13      13
I18      13
N09      13
C20      13
H07      13
B09      13
R09       9
B33       9
V35       9
V11       9
AD09      9
Y44       9
F35       9
N33       9
AD33      9
K34       9
Y20       9
I44       9
I42       9
Y18       9
Y42       9
H31       9
X07       9
X31       9
J33       9
Z09       9
Z33       9
F37       9
V13       9
V37       9
C44       9
S20       9
S44       9
AA10      9
AA34      9
R33       9
O24       1
K01       1
G24       1
C01       1
Name: count, dtype: int64

In [34]:
# Plate-type breakdown of the annotated wells, leaving out the positive-control wells.
moa_wells.loc[moa_wells["trt"] != "compound_pos", "Metadata_PlateType"].value_counts()

Metadata_PlateType
TARGET2     1399
CRISPR       564
ORF          217
COMPOUND     151
TARGET1       40
Name: count, dtype: int64

In [24]:
# Expand every MOA-annotated well into its image rows by joining on the well keys.
# The right join keeps every row of moa_wells, even one without a match in load_df.
# NOTE(review): such rows would carry NaN file paths — confirm whether an inner join is intended.
moa_load_df = pd.merge(
    load_df,
    moa_wells,
    how="right",
    on=["Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_Well"],
)

In [25]:
# Preview the joined table: one row per image site, with MOA and compound metadata attached.
moa_load_df

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,Metadata_Site,FileName_OrigAGP,FileName_OrigDNA,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA,InChIKey,moa,pert_type,smiles,Metadata_PlateType,Metadata_JCP2022,Metadata_InChIKey,Metadata_InChI,Metadata_Sites_Per_Well,trt
0,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,1,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
1,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,2,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
2,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,3,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
3,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,4,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
4,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,5,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
31273,source_8,J3,A1170542,C20,4,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,COMPOUND,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt
31274,source_8,J3,A1170542,C20,5,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,COMPOUND,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt
31275,source_8,J3,A1170542,C20,7,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,COMPOUND,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt
31276,source_8,J3,A1170542,C20,8,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,/projects/cpjump2/jump/images/source_8/J3/A117...,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,COMPOUND,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt


In [47]:
# Cap every MOA class at MAX_IMAGES_PER_MOA rows; classes with fewer rows are kept whole.
MAX_IMAGES_PER_MOA = 900
# Fixed seed so the sampled subset is identical after a kernel restart.
SAMPLING_SEED = 42
sampled_targets = moa_load_df.groupby("moa", as_index=False).apply(
    lambda group: group.sample(
        min(MAX_IMAGES_PER_MOA, len(group)),
        replace=False,
        random_state=SAMPLING_SEED,
    )
)

In [27]:
# Number of image rows available for each MOA class.
moa_load_df["moa"].value_counts()

moa
Aurora kinase inhibitor                           23268
Bcr-Abl kinase inhibitor                           1608
ubiquitin specific protease inhibitor               846
JAK inhibitor                                       816
bromodomain inhibitor                               816
DYRK inhibitor                                      816
hepatocyte growth factor receptor inhibitor         816
JNK inhibitor                                       810
phospholipase inhibitor                             808
glycogen synthase kinase inhibitor                   78
antihistamine                                        78
pyruvate dehydrogenase kinase inhibitor              64
histone lysine methyltransferase inhibitor           48
histone lysine demethylase inhibitor                 48
LXR agonist                                          48
acetylcholine receptor antagonist                    46
smoothened receptor agonist                          36
MAP kinase inhibitor                        

## Dataset

In [53]:
class JumpMOADataset(Dataset):
    """Dataset over the rows of an MOA load dataframe.

    Each item is a dict with the integer-encoded target under "label" and,
    depending on the flags, the stacked channel images under "image" and/or
    the (optionally transformed) SMILES under "compound".
    """

    def __init__(
        self,
        moa_load_df: pd.DataFrame,
        transform: Optional[Callable] = None,
        compound_transform: Optional[Callable] = None,
        return_image: bool = True,
        return_compound: bool = False,
        target_col: str = "moa",
        smiles_col: str = "smiles",
        use_cache: bool = True,
        channels: Optional[List[str]] = None,
        data_root_dir: Optional[str] = None,
    ):
        """Initializes the dataset.

        Args:
            moa_load_df (pd.DataFrame):
                The load dataframe with the metadata. It must contain
                `target_col`, `smiles_col` and one `FileName_Orig{channel}`
                column per channel. The dataframe passed in is never modified.
            transform (Optional[Callable], optional):
                The transform to apply to the images.
                Defaults to None.
            compound_transform (Optional[Callable], optional):
                The compound transform to apply to the compounds.
                Defaults to None.
            return_image (bool, optional):
                Whether to return the image.
                Defaults to True.
            return_compound (bool, optional):
                Whether to return the compound.
                Defaults to False.
            target_col (str, optional):
                The name of the column with the target class.
                Defaults to "moa".
            smiles_col (str, optional):
                The name of the column with the smiles.
                Defaults to "smiles".
            use_cache (bool, optional):
                Whether to use a cache for the compounds.
                Defaults to True.
            channels (Optional[List[str]], optional):
                The channels to use.
                Defaults to None, which selects DNA, AGP, ER, Mito and RNA.
            data_root_dir (Optional[str], optional):
                If given, every "/projects/" in the image path columns is
                replaced by this string verbatim (so it should normally end
                with a path separator, e.g. "../").
                Defaults to None.
        """
        super().__init__()

        self.target_col = target_col
        self.smiles_col = smiles_col
        self.channels = channels or ["DNA", "AGP", "ER", "Mito", "RNA"]
        self.return_image = return_image
        self.return_compound = return_compound

        if data_root_dir:
            # Work on a private copy: rewriting the path columns on the caller's
            # frame would silently change data that is still used elsewhere.
            moa_load_df = moa_load_df.copy()
            for channel in self.channels:
                path_col = f"FileName_Orig{channel}"
                # regex=False: both strings are literals, and the default value of
                # `regex` is not the same across pandas versions.
                moa_load_df[path_col] = moa_load_df[path_col].str.replace(
                    "/projects/", data_root_dir, regex=False
                )
        self.moa_load_df = moa_load_df

        # Sorted unique targets define the label encoding. np.asarray keeps this
        # working when `unique()` returns a pandas extension array rather than an ndarray.
        self.targets = np.sort(np.asarray(self.moa_load_df[self.target_col].unique()))
        self.target_to_num = {target: i for i, target in enumerate(self.targets)}

        self.transform = transform
        self.compound_transform = compound_transform

        self.n_compounds = self.moa_load_df[smiles_col].nunique()
        self.n_images = len(self.moa_load_df)

        self.use_cache = use_cache
        # Maps a raw compound (SMILES string) to its transformed representation.
        self.compound_cache: Dict[Any, Any] = {}

    def __len__(self) -> int:
        """Returns the number of rows (images) in the dataset."""
        return self.n_images

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(n_compounds={self.n_compounds}, n_images={self.n_images})"

    def get_transformed_compound(self, compound):
        """Applies `compound_transform` to `compound`, reusing the cached result when caching is on."""
        if self.use_cache and compound in self.compound_cache:
            return self.compound_cache[compound]

        transformed_compound = self.compound_transform(compound)
        if self.use_cache:
            self.compound_cache[compound] = transformed_compound
        return transformed_compound

    def get_default_collate_fn(self):
        """Return the default collate function that should be used for this dataset.

        Returns None when compounds are not returned, a graph-aware collate
        function otherwise.
        """
        if self.return_compound and self.return_image:
            return image_graph_label_collate_function
        elif not self.return_compound:
            return None
        elif not self.return_image:
            return label_graph_collate_function

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Returns the data at the given index.

        Args:
            idx (int):
                The positional index of the row to return.

        Returns:
            Dict[str, Any]:
                A dict with:
                  - "label": the integer code of the row's target;
                  - "image": the image tensor, after `transform` if one is set
                    (only when `return_image` is True);
                  - "compound": the SMILES, after `compound_transform` if one
                    is set (only when `return_compound` is True).
        """
        row = self.moa_load_df.iloc[idx]
        output = {"label": self.target_to_num[row[self.target_col]]}

        if self.return_image:
            img_paths = [row[f"FileName_Orig{channel}"] for channel in self.channels]

            img_array = load_image_paths_to_array(img_paths)  # A numpy array: (5, 768, 768)
            img_array = torch.from_numpy(img_array)

            if self.transform:
                img_array = self.transform(img_array)

            output["image"] = img_array

        if self.return_compound:
            smile = row[self.smiles_col]

            if self.compound_transform:
                transformed_compound = self.get_transformed_compound(smile)
            else:
                transformed_compound = smile

            output["compound"] = transformed_compound

        return output

### Tests

In [4]:
# Point every channel's image paths at the local mounts by swapping the "/projects/" prefix for "../".
for channel in ("DNA", "AGP", "ER", "Mito", "RNA"):
    path_column = f"FileName_Orig{channel}"
    local_paths = moa_load_df[path_column].str.replace("/projects/", "../")
    moa_load_df[path_column] = local_paths

NameError: name 'moa_load_df' is not defined

In [None]:
# Check that the file path columns now point at the local mounts.
moa_load_df

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,Metadata_Site,FileName_OrigAGP,FileName_OrigDNA,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA,InChIKey,moa,pert_type,smiles,Metadata_PlateType,Metadata_JCP2022,Metadata_InChIKey,Metadata_InChI,Metadata_Sites_Per_Well,trt
0,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,1,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
1,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,2,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
2,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,3,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
3,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,4,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
4,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,J09,5,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,../cpjump3/jump/images/source_10/2021_08_03_U2...,ZYVXTMKTGDARKR-UHFFFAOYSA-N,DYRK inhibitor,trt,COc1cc(ccc1Nc1nccc(n1)-c1cn(C)c2cnccc12)N1CCN(...,TARGET2,JCP2022_116560,ZYVXTMKTGDARKR-UHFFFAOYSA-N,InChI=1S/C24H27N7O/c1-29-10-12-31(13-11-29)17-...,6,target_trt
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
31273,source_8,J3,A1170542,C20,4,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,COMPOUND,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt
31274,source_8,J3,A1170542,C20,5,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,COMPOUND,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt
31275,source_8,J3,A1170542,C20,7,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,COMPOUND,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt
31276,source_8,J3,A1170542,C20,8,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,../cpjump2/jump/images/source_8/J3/A1170542/so...,LKZLGMAAKNEGCH-UHFFFAOYSA-N,ubiquitin specific protease inhibitor,trt,CC(=O)c1cc(c(Sc2cccc(Cl)c2Cl)s1)[N+]([O-])=O,COMPOUND,JCP2022_050096,LKZLGMAAKNEGCH-UHFFFAOYSA-N,InChI=1S/C12H7Cl2NO3S2/c1-6(16)10-5-8(15(17)18...,9,compound_trt


In [30]:
# Build a dataset that returns both the transformed image and the transformed compound.
# NOTE(review): the `DefaultJUMPTransform` import is commented out in the imports cell, so this
# cell raises NameError on a fresh kernel — confirm the import path and restore it.
dataset = JumpMOADataset(
    moa_load_df=moa_load_df,
    transform=DefaultJUMPTransform(size=256),
    compound_transform=DGLPretrainedFromSmiles(),
    return_image=True,
    return_compound=True,
)

In [31]:
# Fetch the first sample to check the returned dict (label, image and compound).
dataset[0]

{'label': 3,
 'image': tensor([[[-0.2042, -0.2285, -0.2285,  ..., -0.2285, -0.2285, -0.2042],
          [-0.2285, -0.2528, -0.2528,  ..., -0.1556, -0.1799, -0.1556],
          [-0.2285, -0.1312, -0.1312,  ..., -0.1799, -0.2528, -0.2285],
          ...,
          [-0.1556, -0.2285, -0.2528,  ..., -0.2042, -0.2528, -0.2042],
          [-0.2042, -0.2285, -0.2042,  ..., -0.2042, -0.1069, -0.2528],
          [-0.2285, -0.2042, -0.2042,  ..., -0.2528, -0.1556, -0.2285]],
 
         [[-0.3619, -0.3619, -0.3619,  ...,  1.1231,  0.8756,  0.7519],
          [-0.3371, -0.3619, -0.3619,  ...,  0.4054,  0.6529,  0.5786],
          [-0.3371, -0.3619, -0.3619,  ..., -0.0154,  0.5539,  0.6776],
          ...,
          [-0.3866, -0.3866, -0.3866,  ...,  0.9004,  1.0241,  0.9004],
          [-0.3619, -0.3866, -0.3619,  ...,  0.8756,  1.1974,  1.4449],
          [-0.3619, -0.3619, -0.3866,  ...,  0.8509,  1.1974,  1.4944]],
 
         [[-0.3536, -0.3316, -0.3756,  ..., -0.0456,  0.3503,  1.0102],
      

In [35]:
# Shuffled loader over the whole dataset, batched with the dataset's own collate function.
dl = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=dataset.get_default_collate_fn(),
)

In [36]:
# Pull a single batch from the loader for inspection.
b = next(iter(dl))

In [37]:
# Display the collated batch; the full tensor repr is long, so the `.shape` check in the next cell is the lighter view.
b

{'label': tensor([ 1.,  1., 20.,  1., 20., 15.,  1.,  1.,  1.,  1.,  1.,  3.,  1.,  2.,
          1.,  1.,  1.,  1.,  3.,  1.,  1.,  6.,  1.,  1.,  1.,  1.,  1.,  5.,
          1.,  1.,  1.,  1.]),
 'image': tensor([[[[-3.0593e-01, -3.0593e-01, -3.0593e-01,  ...,  5.4897e+00,
             5.7895e+00,  4.1407e+00],
           [-2.5597e-01, -3.3091e-01, -3.0593e-01,  ...,  5.2898e+00,
             4.4904e+00,  5.1400e+00],
           [-3.8087e-01, -3.5589e-01, -4.0585e-01,  ...,  4.5154e+00,
             3.8659e+00,  4.0158e+00],
           ...,
           [-3.0593e-01, -3.5589e-01, -3.3091e-01,  ..., -3.5589e-01,
            -3.3091e-01, -4.3083e-01],
           [-3.8087e-01, -3.0593e-01, -3.3091e-01,  ..., -3.3091e-01,
            -3.8087e-01, -3.3091e-01],
           [-3.8087e-01, -3.8087e-01, -4.0585e-01,  ..., -4.3083e-01,
            -3.8087e-01, -3.8087e-01]],
 
          [[-6.6179e-01, -3.8758e-01, -5.0510e-01,  ...,  2.0411e+00,
             3.0400e+00,  2.7854e+00],
           

In [39]:
# Shape of the batched image tensor (the recorded output is [32, 5, 256, 256]).
b["image"].shape

torch.Size([32, 5, 256, 256])

## DataModule

In [6]:
# Per-split DataLoader settings: every split shares the same batch size and
# worker count; only the training split is shuffled.
dataloader_config = DictConfig(
    {
        split: {
            "batch_size": 16,
            "num_workers": 4,
            "shuffle": split == "train",
        }
        for split in ("train", "val", "test")
    }
)

In [7]:
# Assemble the MOA datamodule from the `moa_1000.csv` load table and its split directory.
# NOTE(review): the notebook's import cell has the `DefaultJUMPTransform` import
# commented out — confirm the name is defined earlier in the notebook.
dm = JumpMOADataModule(
    moa_load_df_path="../cpjump1/jump/models/eval/test/moa_1000.csv",
    split_path="../cpjump1/jump/models/eval/test/",
    dataloader_config=dataloader_config,
    force_split=False,  # presumably reuses split files already present under `split_path` — confirm
    transform=DefaultJUMPTransform(size=256),
    compound_transform=DGLPretrainedFromSmiles(),
    return_image=True,
    return_compound=True,
    collate_fn=None,  # presumably falls back to the dataset's default collate function — confirm
    metadata_dir="../cpjump1/jump/metadata",  # same literal as `metadata_path` set near the top of the notebook
    load_data_dir="../cpjump1/jump/load_data",  # same literal as `load_data_path` set near the top of the notebook
    splitter=StratifiedSplitter(
        train=0.75,
        val=0.15,
        test=0.1,
    ),  # the three fractions sum to 1.0
    max_obs_per_class=1000,
)

In [71]:
# Surface INFO-level log records (e.g. the `timm` / `src.modules` messages below) in the notebook output.
logging.basicConfig(level=logging.INFO)

In [8]:
# Run the datamodule's one-off preparation hook by hand.
# Presumably this creates or loads the load table and split files configured above — confirm in JumpMOADataModule.
dm.prepare_data()

In [9]:
# Build the split datasets by hand so `dm.train_dataset` / `dm.train_dataloader` can be used below.
# NOTE(review): stock Lightning stage names are "fit"/"validate"/"test"/"predict", and
# `data_root_dir` is not part of the base `LightningDataModule.setup` signature —
# presumably JumpMOADataModule overrides `setup` to accept both; confirm, since
# `trainer.fit` calls `setup` itself and will not pass `data_root_dir`.
dm.setup("train", data_root_dir="../")

In [19]:
# Persist the batch currently bound to `b` as a reusable example input.
# NOTE(review): `b` is assigned in more than one cell (from the standalone `dl` above and
# from `dm.train_dataloader` further down); on a top-to-bottom run this saves the
# 32-sample batch from the standalone loader — confirm that is the intended one.
torch.save(b, "../cpjump1/jump/models/eval/test/example.pt")

In [75]:
# Image encoder built on a ResNet-18 backbone (the log below shows timm loading pretrained weights).
# `target_num=128` presumably sets the embedding width — the trainer summary reports a [4, 128] output.
image_encoder = CNNEncoder("resnet18", target_num=128)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)
INFO:timm.models._hub:[timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:src.modules.images.timm_pretrained:Using model resnet/resnest with projection head


In [76]:
# Pretrained GIN molecule encoder with a 128-dim linear head.
# NOTE(review): not referenced by any later cell visible here (the model below only
# takes `image_encoder`) — confirm whether it is still needed.
molecule_encoder = GINPretrainedWithLinearHead("gin_supervised_infomax", out_dim=128)

INFO:src.modules.molecules.dgllife_gin:Using pretrained model: gin_supervised_infomax


Downloading gin_supervised_infomax_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_infomax.pth...
Pretrained model loaded


In [78]:
# Image-only MOA classifier wrapping the encoder above.
# NOTE(review): `example_input=b` uses whatever batch `b` currently holds; the recorded
# trainer summary shows a batch of 4, i.e. the one drawn from `dm.train_dataloader(batch_size=4)`.
model = JumpMOAImageModule(
    image_encoder=image_encoder,
    optimizer=torch.optim.Adam,  # the optimizer class is passed, not an instance
    scheduler=None,
    criterion=None,  # presumably defaults to cross-entropy (the trainer summary lists CrossEntropyLoss) — confirm
    cross_modal_module=None,
    example_input=b,
)

In [79]:
# Display the module repr to inspect the layer structure.
model

JumpMOAImageModule(
  (image_encoder): CNNEncoder(
    (backbone): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act1): ReLU(inplace=True)
          (aa): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act2): ReLU(inplace=True)
        )
        (1): BasicBlock(
   

In [82]:
# Log this run to the "jump_moa" W&B project under the "debug" group.
# `log_model=True` asks the logger to upload checkpoints as well — confirm that is wanted for debug runs.
logger = WandbLogger(project="jump_moa", log_model=True, group="debug")
# Short 5-epoch run; device selection is left to Lightning's defaults.
trainer = Trainer(max_epochs=5, logger=logger)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [83]:
# Sanity-check the forward pass: the batch dict is unpacked into keyword arguments.
# The recorded output has 26 logits per sample.
model(**b)

tensor([[-0.0015, -0.0026, -0.0322, -0.0981, -0.0607,  0.0043,  0.0200, -0.0356,
         -0.0269, -0.0480, -0.0120,  0.0151,  0.0308,  0.0494, -0.0364, -0.0587,
          0.0831,  0.0616, -0.0041, -0.0368, -0.0117,  0.0286,  0.0584, -0.0170,
          0.0418,  0.0571],
        [ 0.0431,  0.0387, -0.0616, -0.0470, -0.0454,  0.0400, -0.0117, -0.0495,
         -0.0074, -0.0596, -0.0026, -0.0100,  0.0266,  0.0082, -0.0329, -0.0936,
          0.0742,  0.0481, -0.0047, -0.0305,  0.0184,  0.0241,  0.0328, -0.0024,
          0.0200,  0.0087],
        [ 0.0217,  0.0231, -0.0398, -0.0545, -0.0935,  0.0154,  0.0294, -0.0766,
         -0.0098, -0.0666, -0.0524,  0.0256,  0.0063,  0.0221, -0.0084, -0.0642,
          0.0077,  0.0614, -0.0200, -0.0372, -0.0124,  0.0236,  0.0473,  0.0358,
          0.0296,  0.0135],
        [ 0.0176,  0.0184, -0.0262, -0.0352, -0.0913,  0.0009,  0.0535, -0.0424,
         -0.0165, -0.0665, -0.0252, -0.0154, -0.0270,  0.0104,  0.0493, -0.0633,
          0.0073,  0.0348

In [84]:
# Labels of the current batch, to compare against the logits above.
b["label"]

tensor([ 0, 14, 13, 15])

In [86]:
# Standalone cross-entropy loss for checking the model's output by hand.
criterion = nn.CrossEntropyLoss()

In [87]:
# Logits (`x`) and integer class targets (`y`) for the manual loss computation in the next cell.
x = model(**b)
y = b["label"]

In [88]:
# Loss on one batch before `trainer.fit`; the recorded 3.2844 is close to ln(26) ≈ 3.258,
# i.e. roughly chance level for 26 classes.
criterion(x, y)

tensor(3.2844, grad_fn=<NllLossBackward0>)

In [89]:
# Train and validate on the datamodule's splits; `datamodule=` is the explicit
# keyword form of passing `dm` as the second positional argument.
trainer.fit(model=model, datamodule=dm)

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



   | Name                | Type                        | Params | In sizes         | Out sizes
----------------------------------------------------------------------------------------------------
0  | image_encoder       | CNNEncoder                  | 11.2 M | [4, 5, 256, 256] | [4, 128] 
1  | head                | Sequential                  | 72.5 K | [4, 256]         | [4, 26]  
2  | criterion           | CrossEntropyLoss            | 0      | ?                | ?        
3  | train_loss          | MeanMetric                  | 0      | ?                | ?        
4  | val_loss            | MeanMetric                  | 0      | ?                | ?        
5  | test_loss           | MeanMetric                  | 0      | ?                | ?        
6  | train_other_metrics | MetricCollection            | 0      | ?                | ?        
7  | val_other_metrics   | MetricCollection            | 0      | ?                | ?        
8  | test_other_metrics  | MetricCollection

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [14]:
# Display the batch currently bound to `b` (here the 4-sample batch from the datamodule loader, per the labels shown).
b

{'label': tensor([ 0, 14, 13, 15]),
 'image': tensor([[[[-1.1375, -0.3726,  0.1374,  ..., -1.1375, -0.6275,  0.3923],
           [ 0.9023,  0.6473,  1.4122,  ...,  0.9023, -0.1176,  0.3923],
           [-0.3726, -1.6474,  0.6473,  ...,  0.3923,  2.4321, -0.1176],
           ...,
           [ 0.3923,  0.6473, -1.1375,  ...,  0.6473, -0.1176,  1.6672],
           [-0.6275, -0.3726, -0.6275,  ...,  0.6473,  1.6672, -0.1176],
           [-0.1176, -0.6275, -1.1375,  ..., -1.6474, -0.6275,  0.6473]],
 
          [[-0.3807, -0.5040, -0.1957,  ..., -0.1957, -0.0723, -0.1340],
           [-0.5040, -0.4424, -0.2573,  ..., -0.0723, -0.0723, -0.1340],
           [-0.2573, -0.4424, -0.3807,  ..., -0.1957, -0.1340, -0.1340],
           ...,
           [-0.1340, -0.1340, -0.0723,  ..., -0.1340,  0.1127, -0.0106],
           [-0.0106, -0.1957, -0.3190,  ..., -0.3190, -0.0723,  0.1127],
           [-0.1340, -0.0106, -0.1340,  ..., -0.1340, -0.0723, -0.0723]],
 
          [[-0.1535, -0.2926, -0.2926,  .

In [10]:
# Small training loader for quick inspection; this rebinds `dl`, replacing the standalone DataLoader created earlier.
# NOTE(review): `batch_size` is not a parameter of the stock Lightning `train_dataloader` hook —
# presumably JumpMOADataModule adds it as an override of `dataloader_config`; confirm.
dl = dm.train_dataloader(batch_size=4)

In [13]:
# Draw one batch from the datamodule's training loader; this rebinds `b`.
b = next(iter(dl))

In [21]:
# Number of distinct MOA classes in the training split (the recorded 26 matches the logit width).
dm.train_dataset.moa_load_df["moa"].nunique()

26

In [62]:
# Integer class labels of the 4-sample batch drawn from the datamodule loader.
b["label"]

tensor([ 0, 14, 13, 15])