In [1]:
import json
import logging
import os
import random as rd
from glob import glob
from pathlib import Path
from typing import List, Optional, Tuple

import datamol as dm
import hydra
import lightning as L
import numpy as np
import pandas as pd
import pyrootutils
import torch
import torch.nn as nn
import yaml
from hydra import compose, initialize, initialize_config_dir, initialize_config_module
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm

from src import utils
from src.eval.evaluators import EvaluatorList
from src.pyidr.screenio import ScreenReader

In [2]:
Path("../cpjump1/jump/").exists()

True

In [3]:
# Mount each remote cpjump project over sshfs unless its `jump/` folder is
# already visible locally.
for i in range(1, 4):
    if Path(f"../cpjump{i}/jump/").exists():
        print(f"cpjump{i} already mounted.")
        continue

    print(f"Mounting cpjump{i}...")
    # os.system returns the command's exit status; report a non-zero status so a
    # failed mount is visible here instead of surfacing later as missing files.
    status = os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
    if status != 0:
        print(f"Failed to mount cpjump{i} (exit status {status}).")

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


In [4]:
# Relative paths to the `metadata` and `load_data` folders under ../cpjump1/jump.
metadata_dir = "../cpjump1/jump/metadata"
load_data_dir = "../cpjump1/jump/load_data"

In [172]:
# Load the idr0033 screenA annotation CSV into a DataFrame.
idr_metadata = pd.read_csv("../cpjump1/idr0033-rohban-pathways/idr0033-screenA-annotation.csv")

In [173]:
idr_metadata.head()

Unnamed: 0,Plate,Well,Well Number,Characteristics [Organism],Term Source 1 REF,Term Source 1 Accession,Characteristics [Cell Line],Term Source 2 REF,Term Source 2 Accession,ORF Identifier,...,Phenotype 17,Phenotype 18,Phenotype 19,Phenotype 20,Phenotype 20 Term Name,Phenotype 20 Term Accession,Phenotype 21,Phenotype 21 Term Name,Phenotype 21 Term Accession,Phenotype 22
0,41744,A1,1,Homo sapiens,NCBITaxon,NCBITaxon_9606,U2OS,EFO,EFO_0002869,,...,,,,,,,,,,
1,41744,A2,2,Homo sapiens,NCBITaxon,NCBITaxon_9606,U2OS,EFO,EFO_0002869,,...,,,,,,,,,,
2,41744,A3,3,Homo sapiens,NCBITaxon,NCBITaxon_9606,U2OS,EFO,EFO_0002869,ccsbBroad304_00117,...,de-enriched for multi-nucleate,,,,,,low cell density,decreased cell numbers,CMPO_0000052,
3,41744,A4,4,Homo sapiens,NCBITaxon,NCBITaxon_9606,U2OS,EFO,EFO_0002869,ccsbBroad304_07101,...,,,,,,,,,,
4,41744,A5,5,Homo sapiens,NCBITaxon,NCBITaxon_9606,U2OS,EFO,EFO_0002869,ccsbBroad304_00150,...,,,,,,,,,,


In [174]:
# Gene symbols used to filter the annotation and activity tables in later cells.
positives_genes = [
    "BRCA1",
    "JUN",
    "HIF1A",
    "STAT3",
    "TP53",
    "HSPA5",
]
# Second gene-symbol list; only combined into `genes` in this cell.
negatives_genes = [
    "CASP7", "MAP4K2", "KCNH2", "MAPT", "APAF1",
    "KLK7", "KCNK3", "BCL2", "DDIT3", "ABCC1",
    "CBX1", "GALK1", "PTPN7", "MCOLN3", "GALR2",
]  # fmt: skip
# Positives first, then negatives.
genes = [*positives_genes, *negatives_genes]
# Annotation columns retained when the metadata is filtered.
cols_to_keep = [
    "Plate",
    "Well",
    "Gene Identifier",
    "Gene Symbol",
    "ORF Identifier",
    "ORF Sequence",
]

In [175]:
# Keep annotation rows that satisfy all three conditions, restricted to `cols_to_keep`:
#   1. the plate name does not end with "_illum_corrected",
#   2. the gene symbol is one of `positives_genes`,
#   3. the "Quality Control Comments" field is empty.
not_illum_corrected = ~idr_metadata["Plate"].str.endswith("_illum_corrected")
targets_positive_gene = idr_metadata["Gene Symbol"].isin(positives_genes)
without_qc_comment = idr_metadata["Quality Control Comments"].isna()

idr_metadata_small = idr_metadata.loc[
    not_illum_corrected & targets_positive_gene & without_qc_comment, cols_to_keep
]
idr_metadata_small

Unnamed: 0,Plate,Well,Gene Identifier,Gene Symbol,ORF Identifier,ORF Sequence
10,41744,A11,672.0,BRCA1,ccsbBroad304_00173,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
36,41744,B13,3309.0,HSPA5,BRDN0000464901,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
76,41744,D5,7157.0,TP53,BRDN0000464908,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
80,41744,D9,3091.0,HIF1A,BRDN0000464910,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
115,41744,E20,3725.0,JUN,ccsbBroad304_14682,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
...,...,...,...,...,...,...
2076,41757,G13,6774.0,STAT3,ccsbBroad304_01609,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
2182,41757,K23,3091.0,HIF1A,ccsbBroad304_06365,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
2254,41757,N23,6774.0,STAT3,BRDN0000464968,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
2268,41757,O13,7157.0,TP53,ccsbBroad304_07088,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...


In [176]:
# Persist the filtered annotation table; `index=False` leaves the row index out of the file.
idr_metadata_small.to_csv("../cpjump1/idr0033-rohban-pathways/filtered_metadata.csv", index=False)

In [132]:
# Earlier variant (kept commented out): build the table from two separate CSV exports.
# excape_db1 = pd.read_csv("../cpjump1/excape-db/excape_db_positive_with_negative.csv", nrows=700_000)
# excape_db2 = pd.read_csv("../cpjump1/excape-db/BRAC1-TP53.csv", nrows=700_000)

# excape_db = pd.concat([excape_db1, excape_db2])
# Load the activity table and keep only rows whose Gene_Symbol is in `positives_genes`.
# NOTE(review): `nrows=700_000` caps the number of rows read. The per-gene counts
# displayed by the following cell add up to exactly 700,000, so the file may hold
# more rows than are loaded here — confirm that this truncation is intended.
excape_db = pd.read_csv("../cpjump1/excape-db/selected_perturbations.csv", nrows=700_000)
excape_db = excape_db.query("Gene_Symbol.isin(@positives_genes)")

In [133]:
excape_db.groupby("Gene_Symbol")["Activity_Flag"].value_counts()

Gene_Symbol  Activity_Flag
BRCA1        N                253028
             A                  7808
HIF1A        N                  8205
             A                  2096
HSPA5        N                  2498
             A                   656
JUN          N                  3305
             A                    82
STAT3        N                198467
             A                   245
TP53         N                216844
             A                  6766
Name: count, dtype: int64

In [157]:
def _murcko_scaffold_smiles(smiles):
    """Return the Murcko scaffold of ``smiles`` as a SMILES string, or None if ``dm.to_mol`` returns None."""
    mol = dm.to_mol(smiles)
    if mol is None:
        return None
    return dm.to_smiles(dm.to_scaffold_murcko(mol))


def _sample_one_compound_per_scaffold(df, scaffold_count, flag, has_flag_col, k, rng, seed):
    """Choose up to ``k`` scaffolds that contain ``flag`` compounds and draw one such compound from each."""
    candidates = scaffold_count[scaffold_count[has_flag_col].astype(bool)]
    # `choice(..., replace=False)` raises when asked for more items than exist,
    # so never request more scaffolds than are available.
    n_choices = min(k, len(candidates))
    if n_choices <= 0:
        return df.iloc[0:0]

    # Selection probability is proportional to sqrt(number of compounds sharing the scaffold).
    weights = np.sqrt(candidates["n_scaffold"].to_numpy(dtype=float))
    weights = weights / weights.sum()
    chosen = rng.choice(candidates["scaffold"].to_numpy(), size=n_choices, replace=False, p=weights)

    # Restrict to compounds carrying the requested flag: a chosen scaffold may
    # also contain compounds with the other flag, which must not be drawn here.
    pool = df[df["scaffold"].isin(chosen) & (df["Activity_Flag"] == flag)]
    return pool.groupby("scaffold").apply(lambda x: x.sample(1, random_state=seed))


def sample_k_compounds_with_scaffolds(
    excape_db, gene_symbol, k_pos=20, k_neg=100, max_neg=30_000, max_pos=5_000, seed=42
):
    """Sample scaffold-diverse active ("A") and inactive ("N") compounds for one gene.

    Steps:
      1. Keep the rows of ``excape_db`` whose ``Gene_Symbol`` equals ``gene_symbol``.
      2. Subsample each ``Activity_Flag`` group to at most ``max_pos`` ("A") or
         ``max_neg`` ("N") rows, which bounds the number of scaffolds computed.
      3. Compute the Murcko scaffold SMILES of every remaining compound; compounds
         that cannot be parsed get ``None`` and are ignored afterwards, as are
         compounds whose scaffold is the empty string.
      4. Pick up to ``k_pos`` distinct scaffolds containing at least one "A"
         compound and up to ``k_neg`` containing at least one "N" compound, with
         probability proportional to sqrt(scaffold size), and draw one compound
         with the matching flag from each picked scaffold.

    Args:
        excape_db: DataFrame with at least the columns ``Gene_Symbol``,
            ``Activity_Flag`` and ``SMILES``.
        gene_symbol: Value of ``Gene_Symbol`` to sample for.
        k_pos: Number of scaffolds (hence compounds) to draw for flag "A".
        k_neg: Number of scaffolds (hence compounds) to draw for flag "N".
        max_neg: Cap on "N" rows kept before scaffolds are computed.
        max_pos: Cap on "A" rows kept before scaffolds are computed.
        seed: Seed for the scaffold choice and for every pandas ``sample`` call.

    Returns:
        DataFrame with the "A" samples followed by the "N" samples and a fresh
        integer index. It has fewer than ``k_pos + k_neg`` rows when fewer
        eligible scaffolds exist.
    """
    # A private RandomState seeded with `seed` yields the same draw sequence as
    # seeding NumPy's global generator, without altering global state.
    rng = np.random.RandomState(seed)
    df = excape_db.query("Gene_Symbol == @gene_symbol")

    sdict = {
        "A": max_pos,
        "N": max_neg,
    }

    df = (
        df.groupby("Activity_Flag")
        .apply(lambda x: x.sample(min(sdict[x.name], len(x)), random_state=seed))
        .reset_index(drop=True)
    )

    df["scaffold"] = [_murcko_scaffold_smiles(smiles) for smiles in tqdm(df["SMILES"].tolist())]

    # Per scaffold: whether it contains actives, whether it contains inactives,
    # and how many compounds share it.
    scaffold_count = (
        df.query('~scaffold.isnull() & scaffold != ""')
        .groupby("scaffold")
        .agg(
            has_pos=("Activity_Flag", lambda x: "A" in x.unique()),
            has_neg=("Activity_Flag", lambda x: "N" in x.unique()),
            n_scaffold=("scaffold", "count"),
        )
        .reset_index()
    )

    # Positives are drawn before negatives so the random stream is consumed in a fixed order.
    pos_samples = _sample_one_compound_per_scaffold(df, scaffold_count, "A", "has_pos", k_pos, rng, seed)
    neg_samples = _sample_one_compound_per_scaffold(df, scaffold_count, "N", "has_neg", k_neg, rng, seed)

    samples = pd.concat([pos_samples, neg_samples]).reset_index(drop=True)

    return samples

In [161]:
# Draw one compound sample per gene in `positives_genes`, keyed by gene symbol.
sampling_options = dict(k_pos=20, k_neg=100, max_neg=30_000, max_pos=5_000, seed=42)
samples = {
    gene_symbol: sample_k_compounds_with_scaffolds(
        excape_db=excape_db, gene_symbol=gene_symbol, **sampling_options
    )
    for gene_symbol in tqdm(positives_genes)
}

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/35000 [00:00<?, ?it/s]

[14:15:50] Explicit valence for atom # 2 N, 4, is greater than permitted
[14:15:52] Explicit valence for atom # 7 N, 5, is greater than permitted
[14:15:53] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:15:53] Explicit valence for atom # 2 N, 5, is greater than permitted
[14:15:54] Explicit valence for atom # 9 N, 5, is greater than permitted
[14:15:54] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:15:55] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:15:56] Explicit valence for atom # 9 N, 5, is greater than permitted


  0%|          | 0/3387 [00:00<?, ?it/s]

[14:15:58] Explicit valence for atom # 1 N, 5, is greater than permitted


  0%|          | 0/10301 [00:00<?, ?it/s]

[14:16:00] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:16:00] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:16:00] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:16:00] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:16:01] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:16:01] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:16:01] Explicit valence for atom # 4 N, 5, is greater than permitted
[14:16:01] Explicit valence for atom # 1 N, 5, is greater than permitted


  0%|          | 0/30245 [00:00<?, ?it/s]

[14:16:03] Explicit valence for atom # 6 N, 5, is greater than permitted
[14:16:03] Explicit valence for atom # 2 N, 5, is greater than permitted
[14:16:07] Explicit valence for atom # 2 N, 5, is greater than permitted
[14:16:08] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:16:09] Explicit valence for atom # 1 N, 5, is greater than permitted


  0%|          | 0/35000 [00:00<?, ?it/s]

[14:16:14] Explicit valence for atom # 1 N, 5, is greater than permitted
[14:16:15] Explicit valence for atom # 2 N, 5, is greater than permitted
[14:16:15] Explicit valence for atom # 10 N, 5, is greater than permitted
[14:16:15] Explicit valence for atom # 2 N, 5, is greater than permitted
[14:16:17] Explicit valence for atom # 2 N, 5, is greater than permitted
[14:16:17] Explicit valence for atom # 1 N, 5, is greater than permitted


  0%|          | 0/3154 [00:00<?, ?it/s]

In [162]:
# Stack the per-gene samples with a single concat call: growing a frame inside a
# loop copies all accumulated rows on every iteration. `pd.concat` rejects an
# empty list, so fall back to an empty frame when there are no samples.
selected_compounds = pd.concat(list(samples.values()), axis=0) if samples else pd.DataFrame()

In [166]:
list(selected_compounds.columns)

['Ambit_InchiKey',
 'Original_Entry_ID',
 'Entrez_ID',
 'Activity_Flag',
 'pXC50',
 'DB',
 'Original_Assay_ID',
 'Tax_ID',
 'Gene_Symbol',
 'Ortholog_Group',
 'SMILES',
 'scaffold']

In [167]:
# Columns to keep, in output order; the `scaffold` helper column is left out.
export_columns = [
    "SMILES",
    "Gene_Symbol",
    "Activity_Flag",
    "pXC50",
    "Ambit_InchiKey",
    "Original_Assay_ID",
    "Tax_ID",
    "Original_Entry_ID",
    "Entrez_ID",
]
# `sample(frac=1)` returns every row in a shuffled order; the fixed seed makes the order repeatable.
selected_compounds = selected_compounds[export_columns].sample(frac=1, random_state=42)

In [168]:
selected_compounds.Gene_Symbol.value_counts()

Gene_Symbol
HIF1A    120
BRCA1    120
JUN      120
STAT3    120
HSPA5    120
TP53     120
Name: count, dtype: int64

In [169]:
selected_compounds

Unnamed: 0,SMILES,Gene_Symbol,Activity_Flag,pXC50,Ambit_InchiKey,Original_Assay_ID,Tax_ID,Original_Entry_ID,Entrez_ID
100,O(C=1C(CC2=CC=CC=C2)=CC=CC1)CCN(C)C,HIF1A,N,,IZRPKIZLIFYYKR-UHFFFAOYNA-N,915,9606,298107,3091
50,O=C(N1CC(CC(C1)C)C)C2=C3C(=NC(=C2)C=4OC(=CC4)C...,HIF1A,N,,ZNYBBKBNQBRYLY-UHFFFAOYNA-N,651589,9606,3500511,3091
54,O=C(CN1N=C(N=C1)N(=O)=O)C=2C=C(O)C(O)=CC2,BRCA1,N,,NJYZXVXIVXSMQD-UHFFFAOYNA-N,624202,9606,674159,672
78,P1(O[C@@H]2[C@H](O[C@@H](N3C4=NC=NC(N)=C4N=C3)...,JUN,N,,IVOMOUWHDPKRLL-BJEHYBLCNA-N,1159526,9606,6076,3725
93,S(C=1N=C2N(C=C(C=C2)C)C(=O)C1C=O)CC3=CC=CC=C3,STAT3,N,,NPIFICATZHJQAS-UHFFFAOYNA-N,871,9606,803539,6774
...,...,...,...,...,...,...,...,...,...
71,S1C(C(OCC(=O)N2C(CCCC2)CC)=O)=CC=C1,BRCA1,N,,YZDFMRGYQOPYHN-UHFFFAOYNA-N,624202,9606,3901281,672
106,O(C(=O)C=1C=CC(CN2CCCCC2)=CC1)C,BRCA1,N,,VSUPXUIIYYFQDG-UHFFFAOYNA-N,624202,9606,782306,672
30,C1COC(=N1)NC(C2CC2)C3CC3,HIF1A,N,4.6,CQXADFVORZEARL-XWKXFZRBNA-N,688288,9606,CHEMBL289480,3091
75,O1CCN(CC(OC(=O)C2=CC(OC)=C(OC)C=C2)C)CC1,STAT3,N,,WYGBLDZQEBHKHE-UHFFFAOYNA-N,871,9606,2911045,6774


In [324]:
selected_compounds

Unnamed: 0,SMILES,Gene_Symbol,Activity_Flag,pXC50,Ambit_InchiKey,Original_Assay_ID,Tax_ID,Original_Entry_ID,Entrez_ID
100,O(C=1C(CC2=CC=CC=C2)=CC=CC1)CCN(C)C,HIF1A,N,,IZRPKIZLIFYYKR-UHFFFAOYNA-N,915,9606,298107,3091
50,O=C(N1CC(CC(C1)C)C)C2=C3C(=NC(=C2)C=4OC(=CC4)C...,HIF1A,N,,ZNYBBKBNQBRYLY-UHFFFAOYNA-N,651589,9606,3500511,3091
54,O=C(CN1N=C(N=C1)N(=O)=O)C=2C=C(O)C(O)=CC2,BRCA1,N,,NJYZXVXIVXSMQD-UHFFFAOYNA-N,624202,9606,674159,672
78,P1(O[C@@H]2[C@H](O[C@@H](N3C4=NC=NC(N)=C4N=C3)...,JUN,N,,IVOMOUWHDPKRLL-BJEHYBLCNA-N,1159526,9606,6076,3725
93,S(C=1N=C2N(C=C(C=C2)C)C(=O)C1C=O)CC3=CC=CC=C3,STAT3,N,,NPIFICATZHJQAS-UHFFFAOYNA-N,871,9606,803539,6774
...,...,...,...,...,...,...,...,...,...
71,S1C(C(OCC(=O)N2C(CCCC2)CC)=O)=CC=C1,BRCA1,N,,YZDFMRGYQOPYHN-UHFFFAOYNA-N,624202,9606,3901281,672
106,O(C(=O)C=1C=CC(CN2CCCCC2)=CC1)C,BRCA1,N,,VSUPXUIIYYFQDG-UHFFFAOYNA-N,624202,9606,782306,672
30,C1COC(=N1)NC(C2CC2)C3CC3,HIF1A,N,4.6,CQXADFVORZEARL-XWKXFZRBNA-N,688288,9606,CHEMBL289480,3091
75,O1CCN(CC(OC(=O)C2=CC(OC)=C(OC)C=C2)C)CC1,STAT3,N,,WYGBLDZQEBHKHE-UHFFFAOYNA-N,871,9606,2911045,6774


In [None]:
# Write the shuffled compound selection to disk without the row index.
selected_compounds.to_csv("../cpjump1/excape-db/selected_compounds.csv", index=False)

In [186]:
# Parse the experiment metadata JSON into plain Python objects.
metadata_path = Path("../cpjump1/idr0033-rohban-pathways/metadata_experiment_1751.json")
metadata = json.loads(metadata_path.read_text())

In [189]:
metadata.keys()

dict_keys(['screens', 'projects', 'total_experiments', 'total_images'])

In [242]:
# Select the screen stored under key "0" and collect the keys of its plates.
screen_meta = metadata["screens"]["0"]
plate_keys = screen_meta["plates"].keys()

In [260]:
# Flatten the nested plate -> well -> field metadata into one record per field image.
images_meta = []

for plate_key in tqdm(plate_keys, leave=False):
    plate = screen_meta["plates"][plate_key]
    plate_name = plate["plate_name"]
    # Plates whose name contains "_illum_corrected" are skipped.
    if "_illum_corrected" in plate_name:
        continue

    for well in tqdm(plate["wells"].values(), leave=False):
        for field in well["fields"].values():
            meta = field["image_metadata"]
            images_meta.append(
                {
                    "plate": plate_name,
                    "well": well["well_id"],
                    "field": field["field"],
                    "image_id": field["image_id"],
                    # Absent keys fall back to None, "trt" or "" respectively.
                    "Gene Symbol": meta.get("Gene Symbol"),
                    "Channels": meta["Channels"],
                    "Control Type": meta.get("Control Type", "trt"),
                    "Control Comments": meta.get("Control Comments", ""),
                    "ORF Identifier": meta.get("ORF Identifier"),
                }
            )

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

In [270]:
image_id_to_gene_symbol = pd.DataFrame(images_meta)

In [271]:
image_id_to_gene_symbol

Unnamed: 0,plate,well,field,image_id,Gene Symbol,Channels,Control Type,Control Comments,ORF Identifier
0,41744,1312840,0,3191231,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),
1,41744,1312840,1,3194679,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),
2,41744,1312840,2,3194680,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),
3,41744,1312840,3,3194681,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),
4,41744,1312840,4,3194682,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),
...,...,...,...,...,...,...,...,...,...
20731,41757,1315366,4,3230353,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),
20732,41757,1315366,5,3230354,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),
20733,41757,1315366,6,3230355,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),
20734,41757,1315366,7,3230356,,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,negative control,untreated cells (EMPTY),


In [291]:
# Rows whose gene symbol is one of `positives_genes`.
filtered_image_ids_to_symbol = image_id_to_gene_symbol[
    image_id_to_gene_symbol["Gene Symbol"].isin(positives_genes)
]

# An image counts as downloaded when its channel-0 PNG is present in the screen folder.
fs = os.listdir("../cpjump1/screen_1751")
files_on_disk = set(fs)
unique_ids = filtered_image_ids_to_symbol.image_id.unique()
file_names = [f"screen_1751_{unique_id}_0.png" for unique_id in unique_ids]
downloaded_images = [file_name for file_name in file_names if file_name in files_on_disk]
# The image id is the third "_"-separated token of the file name.
downloaded_ids = [int(file_name.split("_")[2]) for file_name in downloaded_images]

print(f"Found {len(downloaded_images)}/{len(file_names)} images in the file system")

filtered_image_ids_to_symbol = filtered_image_ids_to_symbol[
    filtered_image_ids_to_symbol["image_id"].isin(downloaded_ids)
]

Found 463/594 images in the file system


In [302]:
filtered_image_ids_to_symbol

Unnamed: 0,plate,well,field,image_id,Gene Symbol,Channels,Control Type,Control Comments,ORF Identifier
90,41744,1312494,0,3191565,BRCA1,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,ccsbBroad304_00173
91,41744,1312494,1,3191566,BRCA1,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,ccsbBroad304_00173
92,41744,1312494,2,3191567,BRCA1,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,ccsbBroad304_00173
93,41744,1312494,3,3191568,BRCA1,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,ccsbBroad304_00173
94,41744,1312494,4,3191569,BRCA1,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,ccsbBroad304_00173
...,...,...,...,...,...,...,...,...,...
20521,41757,1315039,1,3227407,STAT3,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,BRDN0000464969
20522,41757,1315039,2,3227408,STAT3,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,BRDN0000464969
20524,41757,1315039,4,3227410,STAT3,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,BRDN0000464969
20526,41757,1315039,6,3227412,STAT3,Hoechst 33342:nucleus;concanavalin A/AlexaFluo...,trt,,BRDN0000464969


In [298]:
channels = filtered_image_ids_to_symbol.Channels.values[0]

In [315]:
metadata["screens"]["0"].keys()

dict_keys(['experiment_id', 'experiment_map_annotation', 'plates', 'total_plates', 'total_images_in_screen'])

In [316]:
channels

'Hoechst 33342:nucleus;concanavalin A/AlexaFluor488 conjugate:endoplasmic reticulum;SYTO14 green fluorescent nucleic acid stain:nucleoli and cytoplasmic RNA;wheat germ agglutinin/AlexaFluor594 conjugate (WGA):Golgi apparatus and plasma membrane;phalloidin/AlexaFluor594 conjugate:F_actin;MitoTracker Deep Red:mitochondria'

In [300]:
channels.split(";")

['Hoechst 33342:nucleus',
 'concanavalin A/AlexaFluor488 conjugate:endoplasmic reticulum',
 'SYTO14 green fluorescent nucleic acid stain:nucleoli and cytoplasmic RNA',
 'wheat germ agglutinin/AlexaFluor594 conjugate (WGA):Golgi apparatus and plasma membrane',
 'phalloidin/AlexaFluor594 conjugate:F_actin',
 'MitoTracker Deep Red:mitochondria']

In [307]:
# Short channel name -> integer index, assigned in listing order starting at 0.
simpler_channel = {name: index for index, name in enumerate(["DNA", "ER", "RNA", "AGP", "Mito"])}

In [308]:
def row_fn(row, simpler_channel=simpler_channel):
    """Build the per-channel PNG file names for one image row.

    Args:
        row: Mapping-like object providing ``row["image_id"]``.
        simpler_channel: Mapping of channel name to the index used as the
            trailing number of the file name.

    Returns:
        Dict mapping ``FileName_<channel>`` to
        ``screen_1751_<image_id>_<index>.png`` for each channel.
    """
    return {
        f"FileName_{channel}": f"screen_1751_{row['image_id']}_{index}.png"
        for channel, index in simpler_channel.items()
    }

In [320]:
# Build the FileName_* columns in one DataFrame construction from a list of
# dicts; `.apply(pd.Series)` would create a separate Series object per row.
# Reusing the source index keeps the column-wise concat aligned row by row.
file_name_columns = pd.DataFrame(
    [row_fn(row) for _, row in filtered_image_ids_to_symbol.iterrows()],
    index=filtered_image_ids_to_symbol.index,
)
final_metadata = pd.concat([filtered_image_ids_to_symbol, file_name_columns], axis=1).drop(
    columns=["Channels", "Control Type", "Control Comments"]
)

In [321]:
final_metadata

Unnamed: 0,plate,well,field,image_id,Gene Symbol,ORF Identifier,FileName_DNA,FileName_ER,FileName_RNA,FileName_AGP,FileName_Mito
90,41744,1312494,0,3191565,BRCA1,ccsbBroad304_00173,screen_1751_3191565_0.png,screen_1751_3191565_1.png,screen_1751_3191565_2.png,screen_1751_3191565_3.png,screen_1751_3191565_4.png
91,41744,1312494,1,3191566,BRCA1,ccsbBroad304_00173,screen_1751_3191566_0.png,screen_1751_3191566_1.png,screen_1751_3191566_2.png,screen_1751_3191566_3.png,screen_1751_3191566_4.png
92,41744,1312494,2,3191567,BRCA1,ccsbBroad304_00173,screen_1751_3191567_0.png,screen_1751_3191567_1.png,screen_1751_3191567_2.png,screen_1751_3191567_3.png,screen_1751_3191567_4.png
93,41744,1312494,3,3191568,BRCA1,ccsbBroad304_00173,screen_1751_3191568_0.png,screen_1751_3191568_1.png,screen_1751_3191568_2.png,screen_1751_3191568_3.png,screen_1751_3191568_4.png
94,41744,1312494,4,3191569,BRCA1,ccsbBroad304_00173,screen_1751_3191569_0.png,screen_1751_3191569_1.png,screen_1751_3191569_2.png,screen_1751_3191569_3.png,screen_1751_3191569_4.png
...,...,...,...,...,...,...,...,...,...,...,...
20521,41757,1315039,1,3227407,STAT3,BRDN0000464969,screen_1751_3227407_0.png,screen_1751_3227407_1.png,screen_1751_3227407_2.png,screen_1751_3227407_3.png,screen_1751_3227407_4.png
20522,41757,1315039,2,3227408,STAT3,BRDN0000464969,screen_1751_3227408_0.png,screen_1751_3227408_1.png,screen_1751_3227408_2.png,screen_1751_3227408_3.png,screen_1751_3227408_4.png
20524,41757,1315039,4,3227410,STAT3,BRDN0000464969,screen_1751_3227410_0.png,screen_1751_3227410_1.png,screen_1751_3227410_2.png,screen_1751_3227410_3.png,screen_1751_3227410_4.png
20526,41757,1315039,6,3227412,STAT3,BRDN0000464969,screen_1751_3227412_0.png,screen_1751_3227412_1.png,screen_1751_3227412_2.png,screen_1751_3227412_3.png,screen_1751_3227412_4.png


In [323]:
final_metadata["Gene Symbol"].value_counts()

Gene Symbol
STAT3    123
TP53      88
JUN       87
HIF1A     83
HSPA5     45
BRCA1     37
Name: count, dtype: int64

In [322]:
# Save the per-image metadata, including the FileName_* columns, without the row index.
final_metadata.to_csv("../cpjump1/idr0033-rohban-pathways/processed_metadata.csv", index=False)

In [183]:
fs = os.listdir("../cpjump1/screen_1751")

In [303]:
[f for f in fs if "3191566" in f]

['screen_1751_3191566_0.png',
 'screen_1751_3191566_1.png',
 'screen_1751_3191566_2.png',
 'screen_1751_3191566_3.png',
 'screen_1751_3191566_4.png']

In [184]:
fs

['screen_1751_3230851_0.png',
 'screen_1751_3230851_1.png',
 'screen_1751_3230851_2.png',
 'screen_1751_3230851_3.png',
 'screen_1751_3230851_4.png',
 'screen_1751_3213328_0.png',
 'screen_1751_3213328_1.png',
 'screen_1751_3213328_2.png',
 'screen_1751_3213328_3.png',
 'screen_1751_3213328_4.png',
 'screen_1751_3254294_0.png',
 'screen_1751_3254294_1.png',
 'screen_1751_3254294_2.png',
 'screen_1751_3254294_3.png',
 'screen_1751_3254294_4.png',
 'screen_1751_3192023_0.png',
 'screen_1751_3192023_1.png',
 'screen_1751_3192023_2.png',
 'screen_1751_3192023_3.png',
 'screen_1751_3192023_4.png',
 'screen_1751_3230150_0.png',
 'screen_1751_3230150_1.png',
 'screen_1751_3230150_2.png',
 'screen_1751_3230150_3.png',
 'screen_1751_3230150_4.png',
 'screen_1751_3191498_0.png',
 'screen_1751_3191498_1.png',
 'screen_1751_3191498_2.png',
 'screen_1751_3191498_3.png',
 'screen_1751_3191498_4.png',
 'screen_1751_3244881_0.png',
 'screen_1751_3244881_1.png',
 'screen_1751_3244881_2.png',
 'screen_1

In [6]:
ScreenReader?

[0;31mInit signature:[0m [0mScreenReader[0m[0;34m([0m[0mf[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m      <no docstring>
[0;31mFile:[0m           /mnt/2547d4d7-6732-4154-b0e1-17b0c1e0c565/Document-2/Projet2/Stage/workspace/jump_models/src/pyidr/screenio.py
[0;31mType:[0m           type
[0;31mSubclasses:[0m     

In [14]:
bulk_annotation = pd.read_hdf("../cdna/bulk_annotations")

Object `pd.read_h5` not found.


In [7]:
# Parse the `.screen` file of plate 41744 with the project's ScreenReader.
with open("../cpjump1/idr0033-rohban-pathways/41744.screen") as f:
    screen = ScreenReader(f)

In [21]:
screen.wells[0]

{'Row': '0',
 'Column': '0',
 'ChannelNames': 'Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito',
 'Fields': ['/uod/idr/filesets/idr0033-rohban-pathways/20170214-original/images/41744/taoe005-u2os-72h-cp-a-au00044859_a01_s1_w<1506fb051-c2ad-45db-8a52-674278937a31,2e7e571af-8408-4d59-89ad-9141739170ac,396a49edf-a2af-4b25-88bf-cc5e82097e69,4aa25898c-cfcb-4930-9217-af6b79d8fd7d,5e47a4e68-3442-42e2-b756-ef7e68abc332>.tif',
  '/uod/idr/filesets/idr0033-rohban-pathways/20170214-original/images/41744/taoe005-u2os-72h-cp-a-au00044859_a01_s2_w<120ef055e-86ef-410f-9d7e-705f682cc370,2feeb40bb-9b62-4b2e-a72b-ef2b4f3ee99d,31db0aa69-6dd8-43e9-9160-a4997f7c7d59,4a48b50ba-f7e8-4ef9-bffe-aea0d9f7d5f0,58f4610c8-61a5-4717-b02d-e5d6a18a4e6f>.tif',
  '/uod/idr/filesets/idr0033-rohban-pathways/20170214-original/images/41744/taoe005-u2os-72h-cp-a-au00044859_a01_s3_w<10e6da511-41b5-4ab7-bf9d-c7981cc7e8ad,23db644df-02ee-429d-9559-09cf4625c62b,344efd3f4-13f9-4c50-ba11-9d8920ee3f6c,4df95942f-845c-4a19-8f98-ececb9330c44,5

In [23]:
# glob returns matches in arbitrary, filesystem-dependent order; sort them so
# positional access into the per-file results is reproducible.
screen_files = sorted(glob("../cpjump1/idr0033-rohban-pathways/*.screen"))

In [25]:
screen_files

['../cpjump1/idr0033-rohban-pathways/41744.screen',
 '../cpjump1/idr0033-rohban-pathways/41749.screen',
 '../cpjump1/idr0033-rohban-pathways/41754.screen',
 '../cpjump1/idr0033-rohban-pathways/41755.screen',
 '../cpjump1/idr0033-rohban-pathways/41756.screen',
 '../cpjump1/idr0033-rohban-pathways/41757.screen']

In [26]:
# Collect the `wells` of every `.screen` file, in the order of `screen_files`.
load_data = []

for screen_path in screen_files:
    with open(screen_path) as handle:
        screen = ScreenReader(handle)
    load_data.append(screen.wells)

In [35]:
# Take the first well of the first screen file as a worked example.
ex = load_data[0][0]
channels = ex["ChannelNames"]
# First entry of the well's "Fields" list.
field = ex["Fields"][0]

In [45]:
# Split the field pattern at its "<...>" section: the comma-separated entries
# between the brackets, and the text that precedes "<".
open_bracket = field.index("<")
close_bracket = field.index(">")
channel_files = field[open_bracket + 1 : close_bracket].split(",")
root = field[:open_bracket]

In [179]:
idr_metadata_small

Unnamed: 0,Plate,Well,Gene Identifier,Gene Symbol,ORF Identifier,ORF Sequence
10,41744,A11,672.0,BRCA1,ccsbBroad304_00173,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
36,41744,B13,3309.0,HSPA5,BRDN0000464901,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
76,41744,D5,7157.0,TP53,BRDN0000464908,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
80,41744,D9,3091.0,HIF1A,BRDN0000464910,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
115,41744,E20,3725.0,JUN,ccsbBroad304_14682,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
...,...,...,...,...,...,...
2076,41757,G13,6774.0,STAT3,ccsbBroad304_01609,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
2182,41757,K23,3091.0,HIF1A,ccsbBroad304_06365,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
2254,41757,N23,6774.0,STAT3,BRDN0000464968,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...
2268,41757,O13,7157.0,TP53,ccsbBroad304_07088,GGTCTATATAAGCAGAGCTCTCTGGCTAACTGTCGGGATCAACAAG...


In [50]:
fs

['screen_1751_3230851_0.png',
 'screen_1751_3230851_1.png',
 'screen_1751_3230851_2.png',
 'screen_1751_3230851_3.png',
 'screen_1751_3230851_4.png',
 'screen_1751_3213328_0.png',
 'screen_1751_3213328_1.png',
 'screen_1751_3213328_2.png',
 'screen_1751_3213328_3.png',
 'screen_1751_3213328_4.png',
 'screen_1751_3254294_0.png',
 'screen_1751_3254294_1.png',
 'screen_1751_3254294_2.png',
 'screen_1751_3254294_3.png',
 'screen_1751_3254294_4.png',
 'screen_1751_3192023_0.png',
 'screen_1751_3192023_1.png',
 'screen_1751_3192023_2.png',
 'screen_1751_3192023_3.png',
 'screen_1751_3192023_4.png',
 'screen_1751_3230150_0.png',
 'screen_1751_3230150_1.png',
 'screen_1751_3230150_2.png',
 'screen_1751_3230150_3.png',
 'screen_1751_3230150_4.png',
 'screen_1751_3191498_0.png',
 'screen_1751_3191498_1.png',
 'screen_1751_3191498_2.png',
 'screen_1751_3191498_3.png',
 'screen_1751_3191498_4.png',
 'screen_1751_3244881_0.png',
 'screen_1751_3244881_1.png',
 'screen_1751_3244881_2.png',
 'screen_1

In [17]:
pd.DataFrame(screen.wells)

Unnamed: 0,Row,Column,ChannelNames,Fields
0,0,0,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...
1,0,1,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...
2,0,2,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...
3,0,3,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...
4,0,4,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...
...,...,...,...,...
379,15,19,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...
380,15,20,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...
381,15,21,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...
382,15,22,"Hoechst,ERSyto,ERSytoBleed,PhGolgi,Mito",[/uod/idr/filesets/idr0033-rohban-pathways/201...


In [8]:
os.listdir(metadata_dir)

['compound.csv.gz',
 'crispr.csv.gz',
 'microscope_config.csv',
 'microscope_filter.csv',
 'orf.csv.gz',
 'plate.csv.gz',
 'README.md',
 'well.csv.gz',
 'compound.csv',
 'crispr.csv',
 'orf.csv',
 'plate.csv',
 'well.csv',
 'complete_metadata.csv',
 'resolution.csv',
 'JUMP-Target-1_compound_metadata.tsv',
 'JUMP-Target-1_compound_platemap.tsv',
 'JUMP-Target-1_crispr_metadata.tsv',
 'JUMP-Target-1_crispr_platemap.tsv',
 'JUMP-Target-1_orf_metadata.tsv',
 'JUMP-Target-1_orf_platemap.tsv',
 'JUMP-Target-2_compound_metadata.tsv',
 'JUMP-Target-2_compound_platemap.tsv',
 'JUMP-MOA_compound_metadata.tsv',
 'local_metadata.csv']

In [1]:
%load_ext autoreload
%autoreload 2

In [134]:
# Clear Hydra's global instance before (re-)initializing it in this kernel.
GlobalHydra.instance().clear()

In [135]:
# Initialize Hydra with the config directory given as a relative path.
initialize(version_base=None, config_path="../configs")

hydra.initialize()

In [171]:
# Compose the evaluator config. Each override uses the "+" prefix, which adds a
# key that the composed config does not define by itself; the `${paths...}`
# values are interpolations of the `paths.data_root_dir` override.
# NOTE(review): `paths.output_dir` is a hard-coded absolute path to one specific
# run directory — confirm it exists on the machine running this notebook.
cfg = compose(
    config_name="eval/evaluators.yaml",
    overrides=[
        "+paths.output_dir=/projects/cpjump1/jump/logs/train/runs/2023-07-20_11-52-43",
        "+paths.data_root_dir=../cpjump1",
        "+paths.metadata_path=${paths.data_root_dir}/jump/models/metadata",
        "+paths.raw_metadata_path=${paths.data_root_dir}/jump/metadata",
        "+paths.load_data_path=${paths.data_root_dir}/jump/load_data",
        "+data.transform.size=256",
        "+trainer.devices=[0]",
    ],
)

In [172]:
print(OmegaConf.to_yaml(cfg))

eval:
  moa_image_task:
    trainer:
      _target_: lightning.pytorch.trainer.Trainer
      default_root_dir: ${paths.output_dir}/eval/moa
      min_epochs: 0
      max_epochs: 100
      accelerator: gpu
      devices: ${trainer.devices}
      check_val_every_n_epoch: 2
      deterministic: false
      log_every_n_steps: 1
      gradient_clip_val: 0.5
      num_sanity_val_steps: 1
    callbacks:
      rich_progress_bar:
        _target_: lightning.pytorch.callbacks.RichProgressBar
      model_checkpoint: null
      early_stopping:
        _target_: lightning.pytorch.callbacks.EarlyStopping
        monitor: jump_moa/image/val/AUROC
        min_delta: 0
        patience: 10
        verbose: false
        mode: max
        strict: true
        check_finite: true
        stopping_threshold: null
        divergence_threshold: null
        check_on_train_epoch_end: null
      wandb_plotter:
        _target_: src.callbacks.wandb.WandbPlottingCallback
        watch: true
        watch_log: al

In [154]:
logging.basicConfig(level=logging.INFO)

In [176]:
# Instantiate the evaluators described by `cfg.eval`.
# NOTE(review): `model` is not defined in any visible cell of this notebook, so
# this cell depends on kernel state created elsewhere — confirm where it comes from.
ev_list = utils.instantiate_evaluator_list(cfg.eval, cross_modal_module=model)

INFO:src.utils.instantiators:Instantiating callback <lightning.pytorch.callbacks.RichProgressBar>
INFO:src.utils.instantiators:Instantiating callback <lightning.pytorch.callbacks.EarlyStopping>
INFO:src.utils.instantiators:Instantiating callback <src.callbacks.wandb.WandbPlottingCallback>
INFO:src.utils.instantiators:Instantiating callback <lightning.pytorch.callbacks.LearningRateMonitor>
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
INFO:src.utils.instantiators:Instantiating evaluator <src.eval.moa.module.JumpMOAImageModule>


In [177]:
ev_list

EvaluatorList(
    n_evaluators=1,
    evaluators=
        (JumpMOAImageModule) Evaluator(
            datamodule=JumpMOADataModule(../cpjump1/jump/models/eval/moa/image_task/moa_1024.csv),
            model=JumpMOAImageModule(CNNEncoder(512), num_classes=26),
        )
)

In [108]:
# Instantiate the object described by `cfg.eval.moa.model`, passing the cross-modal `model`
# and no example input path.
module = instantiate(cfg.eval.moa.model, cross_modal_module=model, example_input_path=None)

In [130]:
# Instantiate the callbacks declared under the MoA trainer config via the project helper.
callbacks = utils.instantiate_callbacks(cfg.eval.moa.trainer.callbacks)

In [131]:
# Display the instantiated callbacks.
callbacks

[<lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar at 0x7f589a8dee00>,
 <lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint at 0x7f589aaff490>,
 <lightning.pytorch.callbacks.early_stopping.EarlyStopping at 0x7f589aafe920>,
 <src.callbacks.wandb.WandbPlottingCallback at 0x7f589aafed70>,
 <lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor at 0x7f589aafef50>]

In [125]:
# Resolve the trainer config into a plain container, then drop the two entries that are
# not forwarded through `**tc` (the Hydra target and the callbacks, which are passed
# to Trainer separately).
tc = OmegaConf.to_container(cfg.eval.moa.trainer, resolve=True)
del tc["_target_"], tc["callbacks"]
tc

{'default_root_dir': '/projects/cpjump1/jump/logs/train/runs/2023-07-20_11-52-43/eval/moa',
 'min_epochs': 0,
 'max_epochs': 100,
 'accelerator': 'gpu',
 'devices': [0],
 'check_val_every_n_epoch': 2,
 'deterministic': False,
 'log_every_n_steps': 1,
 'gradient_clip_val': 0.5,
 'num_sanity_val_steps': 1}

In [None]:
# NOTE(review): unexecuted repeat of the earlier `callbacks` display cell — candidate for removal.
callbacks

In [132]:
# Build a Trainer by hand: plain keyword arguments come from `tc`, the callbacks are
# supplied as the already-instantiated `callbacks` object.
Trainer(**tc, callbacks=callbacks)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


<lightning.pytorch.trainer.trainer.Trainer at 0x7f589ab116f0>

In [112]:
# Attempt to build the Trainer directly from its config node.
# NOTE(review): this raises InstantiationException ("Missing key append", see the output);
# presumably the `callbacks` entry reaches Trainer as a dict-like config rather than a list.
# The manual `Trainer(callbacks=callbacks, **tc)` cell is the variant that succeeds — confirm
# and consider removing this failing cell.
instantiate(cfg.eval.moa.trainer)

InstantiationException: Error in call to target 'lightning.pytorch.trainer.trainer.Trainer':
ConfigAttributeError('Missing key append\n    full_key: append\n    object_type=dict')
full_key: eval.moa.trainer

In [110]:
# Attempt to instantiate the whole MoA evaluator from config, overriding `model` with `module`.
# NOTE(review): this raises the same InstantiationException at `eval.moa.trainer`
# ("Missing key append", see the output), so `evaluator` is not (re)bound by this cell.
evaluator = instantiate(cfg.eval.moa, model=module)

InstantiationException: Error in call to target 'lightning.pytorch.trainer.trainer.Trainer':
ConfigAttributeError('Missing key append\n    full_key: append\n    object_type=dict')
full_key: eval.moa.trainer

In [88]:
# Replace the evaluator's `model` attribute with the result of calling it with the
# cross-modal model (no image encoder, no example input path).
# NOTE(review): not idempotent — re-running this cell calls the already-replaced object again.
# Presumably `evaluator.model` starts out as a partial/factory; confirm.
evaluator.model = evaluator.model(cross_modal_module=model, image_encoder=None, example_input_path=None)

In [89]:
# Display the evaluator's repr.
evaluator

(JumpMOAImageModule) Evaluator(
            datamodule=JumpMOADataModule(../cpjump1/jump/models/eval/moa/image_task/moa_1024.csv),
            model=JumpMOAImageModule(
  (image_encoder): CNNEncoder(
    (backbone): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act1): ReLU(inplace=True)
          (aa): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchN

In [92]:
# Wrap the single evaluator in an EvaluatorList.
evaluator_list = EvaluatorList([evaluator])

In [94]:
# Inspect the evaluator's `trainer` attribute.
# NOTE(review): the output shows a plain mapping of settings and callbacks rather than a
# Trainer instance.
evaluator.trainer

{'callbacks': {'rich_progress_bar': <lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar object at 0x7f589b18b640>, 'model_checkpoint': <lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint object at 0x7f589b18b730>, 'early_stopping': <lightning.pytorch.callbacks.early_stopping.EarlyStopping object at 0x7f589b18a9b0>, 'wandb_plotter': <src.callbacks.wandb.WandbPlottingCallback object at 0x7f589b18b7c0>, 'lr_monitor': <lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor object at 0x7f589b18b880>}, 'default_root_dir': '/projects/cpjump1/jump/logs/train/runs/2023-07-20_11-52-43/eval/moa', 'min_epochs': 0, 'max_epochs': 100, 'log_every_n_steps': 1, 'gradient_clip_val': 0.5, 'num_sanity_val_steps': 1}

In [93]:
# Invoke the evaluator list's `run()`.
# NOTE(review): fails with "ConfigAttributeError: Missing key fit" (see the output); presumably
# because `evaluator.trainer` is still a dict-like config (as displayed in the previous cell)
# and not a Trainer — confirm.
evaluator_list.run()

ConfigAttributeError: Missing key fit
    full_key: fit
    object_type=dict

In [32]:
# Locations of the artefacts written by the training run that is reloaded below.
run_dir = Path("../cpjump1/jump/logs/train/runs/2023-08-07_14-45-24/")
ckpt_path = run_dir.joinpath("checkpoints")
best_ckpt = ckpt_path.joinpath("epoch_009.ckpt")
hparams_path = run_dir.joinpath("csv", "version_0", "hparams.yaml")

In [33]:
# Parse the hyper-parameters that were saved next to the run's CSV logs.
# NOTE(review): the file is parsed with FullLoader; `yaml.safe_load` would be stricter if the
# file only holds plain YAML types — confirm before pointing this at untrusted run directories.
with open(hparams_path) as hparams_file:
    hparams = yaml.load(stream=hparams_file, Loader=yaml.FullLoader)

In [35]:
# Wrap the loaded hyper-parameters in an OmegaConf DictConfig so that nodes such as
# `cfg2.model` can be handed to `instantiate`.
cfg2 = DictConfig(hparams)

In [39]:
# Override the model's example-input path with a location relative to this notebook
# (under the ../cpjump1 mount).
cfg2.model.example_input_path = "../cpjump1/jump/models/example_batch/simple_jump_cl/batch.pth"

In [40]:
# Instantiate the model described by the `model` node of the reloaded config.
# NOTE(review): `best_ckpt` is not used here, and no checkpoint weights are loaded in the cells
# visible in this part of the notebook — confirm that loading happens elsewhere if it is intended.
model = instantiate(cfg2.model)

Downloading gin_supervised_infomax_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_infomax.pth...
Pretrained model loaded


In [41]:
# Display the instantiated model's module tree.
model

BasicJUMPModule(
  (image_encoder): CNNEncoder(
    (backbone): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act1): ReLU(inplace=True)
          (aa): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act2): ReLU(inplace=True)
        )
        (1): BasicBlock(
      

In [49]:
# Call the evaluator's `model` attribute with the cross-modal model.
# NOTE(review): fails with FileNotFoundError (see the output). The reported path ends in a stray
# double quote (`.../eval/test/example.pt"`), which suggests a quoting typo in whatever config
# supplies the example path — confirm and fix at the source.
evaluator.model(cross_modal_module=model)

FileNotFoundError: [Errno 2] No such file or directory: '../cpjump1/jump/models/eval/test/example.pt"'