# Batch effect evaluation on JUMP Target

In [1]:
%load_ext autoreload
%autoreload 2

## Dataset creation

In [2]:
import logging
import os
import os.path as osp
import shutil
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import yaml
from lightning import LightningDataModule
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy

from src.eval.moa.datamodule import JumpMOADataModule
from src.eval.moa.module import JumpMOAImageModule
from src.modules.collate_fn import image_graph_label_collate_function, label_graph_collate_function
from src.modules.compound_transforms import DGLPretrainedFromSmiles
from src.modules.images.timm_pretrained import CNNEncoder
from src.modules.molecules.dgllife_gin import GINPretrainedWithLinearHead
from src.modules.transforms import SimpleTransform
from src.splitters import StratifiedSplitter
from src.utils.io import download_and_extract_zip, load_image_paths_to_array



In [3]:
for i in range(1, 4):
    if not Path(f"../cpjump{i}/jump/").exists():
        print(f"Mounting cpjump{i}...")
        os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
    else:
        print(f"cpjump{i} already mounted.")

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


In [4]:
metadata_path = "../cpjump1/jump/metadata"
load_data_path = "../cpjump1/jump/load_data"

In [9]:
os.listdir(load_data_path)

['load_data_with_metadata', 'load_data_with_samples', 'final']

In [5]:
os.listdir(metadata_path)

['compound.csv.gz',
 'crispr.csv.gz',
 'microscope_config.csv',
 'microscope_filter.csv',
 'orf.csv.gz',
 'plate.csv.gz',
 'README.md',
 'well.csv.gz',
 'compound.csv',
 'crispr.csv',
 'orf.csv',
 'plate.csv',
 'well.csv',
 'complete_metadata.csv',
 'resolution.csv',
 'JUMP-Target-1_compound_metadata.tsv',
 'JUMP-Target-1_compound_platemap.tsv',
 'JUMP-Target-1_crispr_metadata.tsv',
 'JUMP-Target-1_crispr_platemap.tsv',
 'JUMP-Target-1_orf_metadata.tsv',
 'JUMP-Target-1_orf_platemap.tsv',
 'JUMP-Target-2_compound_metadata.tsv',
 'JUMP-Target-2_compound_platemap.tsv',
 'JUMP-MOA_compound_metadata.tsv',
 'local_metadata.csv']

In [14]:
load_df = pd.read_parquet(osp.join(load_data_path, "final"))

In [15]:
load_df

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,Metadata_Site,FileName_OrigAGP,FileName_OrigDNA,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA
0,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A01,1,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
1,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A01,3,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
2,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A01,6,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
3,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A02,1,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
4,source_10,2021_05_31_U2OS_48_hr_run1,Dest210531-152149,A02,2,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...
...,...,...,...,...,...,...,...,...,...,...
5121195,source_9,20211103-Run16,GR00004421,Z47,3,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...
5121196,source_9,20211103-Run16,GR00004421,Z47,4,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...
5121197,source_9,20211103-Run16,GR00004421,Z48,1,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...
5121198,source_9,20211103-Run16,GR00004421,Z48,2,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...


In [6]:
meta = pd.read_csv(osp.join(metadata_path, "local_metadata.csv"))

In [7]:
meta = pd.read_csv(osp.join(metadata_path, "local_metadata.csv"))
target_local = meta.query("Metadata_PlateType == 'TARGET2'")

In [8]:
target_local

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_PlateType,Metadata_Well,Metadata_JCP2022,Metadata_InChIKey,Metadata_InChI,Metadata_Sites_Per_Well,trt
114388,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,TARGET2,A01,JCP2022_043547,KBPLFHHGFOOTCA-UHFFFAOYSA-N,"InChI=1S/C8H18O/c1-2-3-4-5-6-7-8-9/h9H,2-8H2,1H3",6,target_trt
114389,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,TARGET2,A02,JCP2022_050797,LOUPRKONTZGTKE-UHFFFAOYSA-N,InChI=1S/C20H24N2O2/c1-3-13-12-22-9-7-14(13)10...,6,target_trt
114390,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,TARGET2,A03,JCP2022_050997,LPYXWGMUVRGUOY-UHFFFAOYSA-N,InChI=1S/C6H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/...,6,target_trt
114391,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,TARGET2,A04,JCP2022_108326,YGSDEFSMJLZEOE-UHFFFAOYSA-N,"InChI=1S/C7H6O3/c8-6-4-2-1-3-5(6)7(9)10/h1-4,8...",6,target_trt
114392,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,TARGET2,A05,JCP2022_033924,IAZDPXIOMUYVGZ-UHFFFAOYSA-N,InChI=1S/C2H6OS/c1-4(2)3/h1-2H3,6,target_neg
...,...,...,...,...,...,...,...,...,...,...
842725,source_9,20211103-Run16,GR00004409,TARGET2,Z44,JCP2022_060040,NMUSYJAQQFHJEW-UHFFFAOYSA-N,InChI=1S/C8H12N4O5/c9-7-10-2-12(8(16)11-7)6-5(...,4,target_trt
842726,source_9,20211103-Run16,GR00004409,TARGET2,Z45,JCP2022_019314,FABUFPQFXZVHFB-UHFFFAOYSA-N,InChI=1S/C27H42N2O5S/c1-15-9-8-10-27(7)22(34-2...,4,target_trt
842727,source_9,20211103-Run16,GR00004409,TARGET2,Z46,JCP2022_018899,DXZRBHUCOHBAHP-UHFFFAOYSA-N,InChI=1S/C15H13Cl2N3O2/c1-3-22-15(21)14-11(8(6...,4,target_trt
842728,source_9,20211103-Run16,GR00004409,TARGET2,Z47,JCP2022_033924,IAZDPXIOMUYVGZ-UHFFFAOYSA-N,InChI=1S/C2H6OS/c1-4(2)3/h1-2H3,4,target_neg


In [10]:
target_meta = pd.read_csv(osp.join(metadata_path, "JUMP-Target-2_compound_metadata.tsv"), sep="\t")

In [11]:
target_meta

Unnamed: 0,broad_sample,InChIKey,pert_iname,pubchem_cid,target,pert_type,control_type,smiles
0,BRD-K09338665-001-07-1,KBPLFHHGFOOTCA-UHFFFAOYSA-N,1-octanol,957.0,GJB4,trt,,CCCCCCCCO
1,BRD-K48278478-001-01-2,LOUPRKONTZGTKE-AFHBHXEDSA-N,quinine,94175.0,KCNN4,trt,,COc1ccc2nccc([C@@H](O)[C@H]3C[C@@H]4CC[N@]3C[C...
2,BRD-A85242401-001-12-3,KRGQEOSDQHTZMX-IGCYCDGOSA-N,ascorbic-acid,9888239.0,P3H1,trt,,OC[C@H](O)[C@H]1OC(=O)C(=O)C1O
3,BRD-K93632104-001-17-2,YGSDEFSMJLZEOE-UHFFFAOYSA-N,salicylic-acid,118212070.0,AKR1C1,trt,,OC(=O)c1ccccc1O
4,BRD-K57313110-001-06-8,ODHCTXKNWHHXJC-VKHMYHEASA-N,pidolic-acid,7405.0,VEGFA,trt,,OC(=O)[C@@H]1CCC(=O)N1
...,...,...,...,...,...,...,...,...
302,BRD-A69636825-003-04-7,HSUGRBWQSSZJOP-UHFFFAOYSA-N,diltiazem,3076.0,CACNG1,trt,,COc1ccc(cc1)C1Sc2ccccc2N(CCN(C)C)C(=O)C1OC(C)=O
303,BRD-K87782578-001-03-9,KXBDTLQSDKGAEB-UHFFFAOYSA-N,AVL-292,59174488.0,BTK,trt,,COCCOc1ccc(Nc2ncc(F)c(Nc3cccc(NC(=O)C=C)c3)n2)cc1
304,BRD-K98763141-001-30-8,JZFPYUNJRRFVQU-UHFFFAOYSA-N,niflumic-acid,4488.0,UGT1A9,trt,,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F
305,BRD-K19975102-001-02-0,YYDUWLSETXNJJT-MTJSOVHGSA-N,GNF-5837,59397065.0,NTRK1,trt,,Cc1ccc(NC(=O)Nc2cc(ccc2F)C(F)(F)F)cc1Nc1ccc2c(...


In [12]:
merged = pd.merge(
    target_meta.dropna(subset=["target"]),
    target_local,
    left_on=["InChIKey"],
    right_on=["Metadata_InChIKey"],
    how="inner",
)[["target", "smiles", "Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_Well", "pert_type"]]

In [13]:
merged

Unnamed: 0,target,smiles,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,pert_type
0,GJB4,CCCCCCCCO,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,trt
1,GJB4,CCCCCCCCO,source_10,2021_08_09_U2OS_48_hr_run13,Dest210727-153003,A01,trt
2,GJB4,CCCCCCCCO,source_10,2021_08_12_U2OS_48_hr_run15,Dest210803-153958,A01,trt
3,GJB4,CCCCCCCCO,source_10,2021_08_17_U2OS_48_hr_run16,Dest210809-134534,A01,trt
4,GJB4,CCCCCCCCO,source_10,2021_08_20_U2OS_48_hr_run17,Dest210810-173723,A01,trt
...,...,...,...,...,...,...,...
26266,UGT1A9,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F,source_9,20211102-Run15,GR00004395,O30,trt
26267,UGT1A9,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F,source_9,20211103-Run16,GR00004409,AE06,trt
26268,UGT1A9,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F,source_9,20211103-Run16,GR00004409,AE30,trt
26269,UGT1A9,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F,source_9,20211103-Run16,GR00004409,O06,trt


In [16]:
target_load_df = pd.merge(
    load_df,
    merged,
    on=["Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_Well"],
    how="inner",
)

In [17]:
target_load_df

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,Metadata_Site,FileName_OrigAGP,FileName_OrigDNA,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA,target,smiles,pert_type
0,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,1,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
1,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,2,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
2,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,3,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
3,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,4,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
4,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,5,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
...,...,...,...,...,...,...,...,...,...,...,...,...,...
144081,source_9,20211103-Run16,GR00004409,Z40,4,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,ATM,O=c1cc(oc(c1)-c1cccc2Sc3ccccc3Sc12)N1CCOCC1,trt
144082,source_9,20211103-Run16,GR00004409,Z42,1,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,PTK2B,Cc1oncc1C(=O)Nc1ccc(cc1)C(F)(F)F,trt
144083,source_9,20211103-Run16,GR00004409,Z42,2,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,PTK2B,Cc1oncc1C(=O)Nc1ccc(cc1)C(F)(F)F,trt
144084,source_9,20211103-Run16,GR00004409,Z42,3,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,PTK2B,Cc1oncc1C(=O)Nc1ccc(cc1)C(F)(F)F,trt


In [52]:
class BatchEffectDataModule(LightningDataModule):
    def __init__(
        self,
        target_load_df_path: str,
        split_path: str,
        label_col: str = "target",
        smiles_col: str = "smiles",
        val_size: float = 0.1,
        test_size: float = 0.2,
        collate_fn: Optional[Callable] = default_collate,
        batch_size: int = 256,
        num_workers: int = 16,
        pin_memory: bool = False,
        prefetch_factor: int = 3,
        drop_last: bool = False,
        metadata_path: str = "../cpjump1/jump/metadata",
        load_data_path: str = "../cpjump1/jump/load_data",
        random_state: int = 42,
    ):
        super().__init__()

        # paths
        self.target_load_df_path = target_load_df_path
        self.split_path = split_path

        # for prepare_data
        self.metadata_path = metadata_path
        self.load_data_path = load_data_path
        self.val_size = val_size
        self.test_size = test_size
        self.random_state = random_state

        # dataset args
        self.smiles_col = smiles_col
        self.label_col = label_col

        # dataloader args
        self.collate_fn = collate_fn
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor
        self.drop_last = drop_last

        # needed attributes
        self.target_load_df: Optional[pd.DataFrame] = None

        self.datasets: Dict[str, Dict[str, Dataset]] = {
            "random": {
                "train": None,
                "val": None,
                "test": None,
            },
            "plate_aware": {
                "train": None,
                "val": None,
                "test": None,
            },
            "source_aware": {
                "train": None,
                "val": None,
                "test": None,
            },
        }

        self.train_dataset: Optional[Dataset] = None
        self.val_dataset: Optional[Dataset] = None
        self.test_dataset: Optional[Dataset] = None

    @staticmethod
    def write_list_to_file(file_path: str, a_list: List[str]) -> None:
        with open(file_path, "w") as f:
            for item in a_list:
                f.write(f"{item}\n")

    @staticmethod
    def aware_split(df, col, test_size=0.2):
        unique_values = df[col].unique()
        train_values = np.random.choice(unique_values, size=int(test_size * len(unique_values)), replace=False)

        train_idx = df[df[col].isin(train_values)].index.tolist()
        train_idx = np.random.permutation(train_idx).tolist()
        test_ids = df[~df[col].isin(train_values)].index.tolist()
        test_ids = np.random.permutation(test_ids).tolist()

        return train_idx, test_ids

    def prepare_data(self) -> None:
        if not Path(self.target_load_df_path).exists():
            load_df = pd.read_parquet(osp.join(self.load_data_path, "final"))
            meta = pd.read_csv(osp.join(self.metadata_path, "local_metadata.csv"))
            target_local = meta.query("Metadata_PlateType == 'TARGET2'")
            target_meta = pd.read_csv(osp.join(self.metadata_path, "JUMP-Target-2_compound_metadata.tsv"), sep="\t")

            merged = pd.merge(
                target_meta.dropna(subset=["target"]),
                target_local,
                left_on=["InChIKey"],
                right_on=["Metadata_InChIKey"],
                how="inner",
            )[["target", "smiles", "Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_Well", "pert_type"]]

            target_load_df = pd.merge(
                load_df,
                merged,
                on=["Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_Well"],
                how="inner",
            )

            Path(self.target_load_df_path).parent.mkdir(parents=True, exist_ok=True)

            target_load_df.to_csv(self.target_load_df_path, index=False)

        if not Path(self.split_path).exists():
            if "target_load_df" not in locals():
                target_load_df = pd.read_csv(self.target_load_df_path)

            # Create split dir
            Path(self.split_path).mkdir(parents=True, exist_ok=True)

            # Create Random Split from the target_load_df
            index = np.arange(len(target_load_df))
            random_dir = Path(self.split_path) / "random"
            random_dir.mkdir(parents=True, exist_ok=True)

            train_val_index, test_index = train_test_split(
                index, test_size=self.test_size, random_state=self.random_state, stratify=target_load_df[self.label_col]
            )
            train_index, val_index = train_test_split(
                train_val_index,
                test_size=self.val_size / (1 - self.test_size),
                random_state=self.random_state,
                stratify=target_load_df[self.label_col].iloc[train_val_index],
            )

            self.write_list_to_file(osp.join(random_dir, "train.csv"), train_index)
            self.write_list_to_file(osp.join(random_dir, "val.csv"), val_index)
            self.write_list_to_file(osp.join(random_dir, "test.csv"), test_index)

            # Create Plate aware split from the target_load_df
            plate_dir = Path(self.split_path) / "plate_aware"
            plate_dir.mkdir(parents=True, exist_ok=True)

            train_val_index, test_index = self.aware_split(target_load_df, "Metadata_Plate", test_size=self.test_size)
            train_index, val_index = train_test_split(
                train_val_index,
                test_size=self.val_size / (1 - self.test_size),
                random_state=self.random_state,
                stratify=target_load_df[self.label_col].iloc[train_val_index],
            )

            self.write_list_to_file(osp.join(plate_dir, "train.csv"), train_index)
            self.write_list_to_file(osp.join(plate_dir, "val.csv"), val_index)
            self.write_list_to_file(osp.join(plate_dir, "test.csv"), test_index)

            # Create Source aware split from the target_load_df
            source_dir = Path(self.split_path) / "source_aware"
            source_dir.mkdir(parents=True, exist_ok=True)

            train_val_index, test_index = self.aware_split(target_load_df, "Metadata_Source", test_size=self.test_size)
            train_index, val_index = train_test_split(
                train_val_index,
                test_size=self.val_size / (1 - self.test_size),
                random_state=self.random_state,
                stratify=target_load_df[self.label_col].iloc[train_val_index],
            )

            self.write_list_to_file(osp.join(source_dir, "train.csv"), train_index)
            self.write_list_to_file(osp.join(source_dir, "val.csv"), val_index)
            self.write_list_to_file(osp.join(source_dir, "test.csv"), test_index)

    def setup(self, stage: Optional[str] = None) -> None:
        if self.target_load_df is None:
            self.target_load_df = pd.read_csv(self.target_load_df_path)
            self.labels = self.target_load_df[self.label_col].unique().tolist()
            self.labels.sort()
            self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)}
            self.num_to_labels = dict(enumerate(self.labels))

        if stage == "fit" or stage is None or stage == "validate":
            pass

        if stage == "test" or stage is None:
            pass

    def get_split_datamodule(self, split_dir: str) -> LightningDataModule:
        LightningDataModule.from_datasets(train_dataset)

    # def phase_dataloader(self, phase: str) -> DataLoader:
    #     return DataLoader(
    #         dataset=self.train_dataset,
    #         collate_fn=self.collate_fn,
    #         batch_size=self.batch_size,
    #         num_workers=self.num_workers,
    #         pin_memory=self.pin_memory,
    #         prefetch_factor=self.prefetch_factor,
    #         drop_last=self.drop_last,
    #         shuffle=(phase == "train"),
    #     )

    # def train_dataloader(self) -> DataLoader:
    #     return self.phase_dataloader("train")

    # def val_dataloader(self) -> DataLoader:
    #     return self.phase_dataloader("val")

    # def test_dataloader(self) -> DataLoader:
    #     return self.phase_dataloader("test")

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass

SyntaxError: expected ':' (3864810952.py, line 156)

In [19]:
target_load_df.Metadata_Batch.nunique()

85

In [20]:
target_load_df.Metadata_Source.nunique()

10

In [21]:
target_load_df.Metadata_Plate.nunique()

113

In [39]:
target_load_df.groupby(["Metadata_Batch", "target"]).size()

Metadata_Batch                  target 
20210823_Batch_10               ABL1       12
                                ADH1C      12
                                ADORA2A    12
                                ADRA2B     12
                                AGER       12
                                           ..
p211123CPU2OS48hw384exp036JUMP  TNNC1      12
                                TUBB4B     12
                                UGT1A9     24
                                USP1       12
                                VEGFA      12
Length: 10029, dtype: int64

In [40]:
def split_by_source(df, source_col="Metadata_Source", frac=0.8):
    sources = df[source_col].unique()
    train_sources = np.random.choice(sources, size=int(frac * len(sources)), replace=False)
    train_df = df[df[source_col].isin(train_sources)]
    val_df = df[~df[source_col].isin(train_sources)]
    return train_df, val_df

In [46]:
train_df, val_df = split_by_source(target_load_df, source_col="Metadata_Batch")

In [41]:
train_df, val_df = split_by_source(target_load_df)

In [51]:
train_df

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,Metadata_Site,FileName_OrigAGP,FileName_OrigDNA,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA,target,smiles,pert_type
0,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,1,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
1,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,2,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
2,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,3,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
3,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,4,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
4,source_10,2021_08_03_U2OS_48_hr_run12,Dest210726-160150,A01,5,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,/projects/cpjump3/jump/images/source_10/2021_0...,GJB4,CCCCCCCCO,trt
...,...,...,...,...,...,...,...,...,...,...,...,...,...
144081,source_9,20211103-Run16,GR00004409,Z40,4,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,ATM,O=c1cc(oc(c1)-c1cccc2Sc3ccccc3Sc12)N1CCOCC1,trt
144082,source_9,20211103-Run16,GR00004409,Z42,1,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,PTK2B,Cc1oncc1C(=O)Nc1ccc(cc1)C(F)(F)F,trt
144083,source_9,20211103-Run16,GR00004409,Z42,2,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,PTK2B,Cc1oncc1C(=O)Nc1ccc(cc1)C(F)(F)F,trt
144084,source_9,20211103-Run16,GR00004409,Z42,3,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,/projects/cpjump1/jump/images/source_9/2021110...,PTK2B,Cc1oncc1C(=O)Nc1ccc(cc1)C(F)(F)F,trt


In [45]:
val_df

Unnamed: 0,Metadata_Source,Metadata_Batch,Metadata_Plate,Metadata_Well,Metadata_Site,FileName_OrigAGP,FileName_OrigDNA,FileName_OrigER,FileName_OrigMito,FileName_OrigRNA,target,smiles,pert_type
14656,source_13,20220914_Run1,CP-CC9-R1-29,A01,0,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,GJB4,CCCCCCCCO,trt
14657,source_13,20220914_Run1,CP-CC9-R1-29,A01,2,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,GJB4,CCCCCCCCO,trt
14658,source_13,20220914_Run1,CP-CC9-R1-29,A01,3,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,GJB4,CCCCCCCCO,trt
14659,source_13,20220914_Run1,CP-CC9-R1-29,A01,6,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,GJB4,CCCCCCCCO,trt
14660,source_13,20220914_Run1,CP-CC9-R1-29,A01,7,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,/projects/cpjump1/jump/images/source_13/202209...,GJB4,CCCCCCCCO,trt
...,...,...,...,...,...,...,...,...,...,...,...,...,...
103619,source_5,JUMPCPE-20211014-Run36_20211014_223431,ACPJUM192,P22,2,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,UGT1A9,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F,trt
103620,source_5,JUMPCPE-20211014-Run36_20211014_223431,ACPJUM192,P22,3,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,UGT1A9,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F,trt
103621,source_5,JUMPCPE-20211014-Run36_20211014_223431,ACPJUM192,P22,4,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,UGT1A9,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F,trt
103622,source_5,JUMPCPE-20211014-Run36_20211014_223431,ACPJUM192,P22,7,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,/projects/cpjump2/jump/images/source_5/JUMPCPE...,UGT1A9,OC(=O)c1cccnc1Nc1cccc(c1)C(F)(F)F,trt


In [26]:
target_load_df.target.value_counts()

target
GJB4       1536
P2RY12     1536
GUCY1B1    1536
EZH2       1536
HTR2C      1536
           ... 
SLCO2B1     762
CHRM3       762
HPGDS       762
CACNB4      762
BTK         762
Name: count, Length: 118, dtype: int64

## Datamodule

## Module

## Evaluator