# HINT TOP (Trial Outcome Prediction) benchmark evaluation

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Dataset creation

In [2]:
import csv
import logging
import os
import os.path as osp
import shutil
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import yaml
from lightning import LightningDataModule
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy

from src.eval.moa.datamodule import JumpMOADataModule
from src.eval.moa.module import JumpMOAImageGraphModule, JumpMOAImageModule
from src.modules.collate_fn import image_graph_label_collate_function, label_graph_collate_function
from src.modules.compound_transforms import DGLPretrainedFromSmiles
from src.modules.images.timm_pretrained import CNNEncoder
from src.modules.molecules.dgllife_gin import GINPretrainedWithLinearHead
from src.modules.transforms import DefaultJUMPTransform
from src.splitters import StratifiedSplitter
from src.utils.io import download_and_extract_zip, load_image_paths_to_array



In [3]:
# Make sure the three cpjump project shares are available locally. A share is
# considered mounted when its `jump/` directory exists; otherwise it is
# mounted over sshfs.
for i in range(1, 4):
    if Path(f"../cpjump{i}/jump/").exists():
        print(f"cpjump{i} already mounted.")
        continue
    print(f"Mounting cpjump{i}...")
    # The exit status of sshfs is not inspected here.
    os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


In [7]:
# Locations of the data used below, all relative to the cpjump1 mount.
metadata_path = "../cpjump1/jump/metadata"  # JUMP metadata tables
load_data_path = "../cpjump1/jump/load_data"  # JUMP load_data directory
hint_path = "../cpjump1/hint-clinical-trial-outcome-prediction/data"  # HINT phase CSVs

In [5]:
# Show which metadata files are available.
os.listdir(metadata_path)

['compound.csv.gz',
 'crispr.csv.gz',
 'microscope_config.csv',
 'microscope_filter.csv',
 'orf.csv.gz',
 'plate.csv.gz',
 'README.md',
 'well.csv.gz',
 'compound.csv',
 'crispr.csv',
 'orf.csv',
 'plate.csv',
 'well.csv',
 'complete_metadata.csv',
 'resolution.csv',
 'JUMP-Target-1_compound_metadata.tsv',
 'JUMP-Target-1_compound_platemap.tsv',
 'JUMP-Target-1_crispr_metadata.tsv',
 'JUMP-Target-1_crispr_platemap.tsv',
 'JUMP-Target-1_orf_metadata.tsv',
 'JUMP-Target-1_orf_platemap.tsv',
 'JUMP-Target-2_compound_metadata.tsv',
 'JUMP-Target-2_compound_platemap.tsv',
 'JUMP-MOA_compound_metadata.tsv',
 'local_metadata.csv']

In [8]:
# Show the HINT data files (per-phase train/valid/test CSVs and auxiliary files).
os.listdir(hint_path)

['ADMET',
 'NCT00000378.xml',
 'README.md',
 'drugbank_mini.csv',
 'phase_III_test.csv',
 'phase_III_train.csv',
 'phase_III_valid.csv',
 'phase_II_test.csv',
 'phase_II_train.csv',
 'phase_II_valid.csv',
 'phase_I_test.csv',
 'phase_I_train.csv',
 'phase_I_valid.csv',
 'raw_data.csv',
 'sentence2embedding.pkl',
 'sponsor2approvalrate.csv',
 'sponsor2count.csv',
 'toy_test.csv',
 'toy_train.csv',
 'toy_valid.csv']

## Load phase CSVs

In [42]:
class Trial_Dataset(Dataset):
    """Map-style dataset over four parallel, per-trial sequences.

    Item ``index`` is the tuple ``(nctid, label, smiles, icdcode)`` built from
    position ``index`` of the sequences passed to the constructor. Values are
    returned exactly as stored; no parsing or type conversion happens here.
    The criteria column is not part of the item.

    The original note on this class says ``smiles_lst[index]`` is a list of
    SMILES. In this notebook the entries are the raw CSV strings, which are
    only turned into lists later, at collate time.
    """

    def __init__(self, nctid_lst, label_lst, smiles_lst, icdcode_lst):
        """Store the four sequences; they are expected to share one length."""
        self.nctid_lst, self.label_lst = nctid_lst, label_lst
        self.smiles_lst, self.icdcode_lst = smiles_lst, icdcode_lst

    def __len__(self):
        """Return the number of trials, taken from the identifier sequence."""
        return len(self.nctid_lst)

    def __getitem__(self, index):
        """Return ``(nctid, label, smiles, icdcode)`` for position ``index``."""
        columns = (self.nctid_lst, self.label_lst, self.smiles_lst, self.icdcode_lst)
        return tuple(column[index] for column in columns)

In [38]:
def csv_three_feature_2_dataloader(csvfile, shuffle, batch_size):
    """Build a ``DataLoader`` of trials from one phase CSV file.

    The first row is treated as a header and skipped. Columns are read by
    position, following the layout
    ``nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria``;
    only ``nctid`` (0), ``label`` (3), ``icdcodes`` (6) and ``smiless`` (8)
    are kept.

    Args:
        csvfile: Path of the CSV file to read.
        shuffle: Forwarded to ``DataLoader(shuffle=...)``.
        batch_size: Forwarded to ``DataLoader(batch_size=...)``.

    Returns:
        A ``DataLoader`` over a ``Trial_Dataset`` whose batches are assembled
        by ``trial_collate_fn``.

    Raises:
        IndexError: If a data row has fewer than nine columns.
    """
    # newline="" is what the csv module requires so that line breaks inside
    # quoted fields are parsed correctly. The encoding is pinned to UTF-8,
    # which is also what pd.read_csv assumes by default when the same file is
    # loaded elsewhere in this notebook.
    with open(csvfile, newline="", encoding="utf-8") as fin:
        reader = csv.reader(fin, delimiter=",")
        next(reader, None)  # drop the header row; tolerate an empty file
        rows = list(reader)

    nctid_lst = [row[0] for row in rows]
    label_lst = [row[3] for row in rows]
    icdcode_lst = [row[6] for row in rows]
    smiles_lst = [row[8] for row in rows]

    dataset = Trial_Dataset(nctid_lst, label_lst, smiles_lst, icdcode_lst)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=trial_collate_fn,
    )

In [35]:
def smiles_txt_to_2lst(smiles_txt_file):
    """Read a whitespace-separated ``<smiles> <label>`` text file.

    Each non-blank line contributes its first field as a SMILES string and its
    second field, converted with ``int``, as the label. Blank lines (for
    example a trailing empty line) are skipped. Any fields after the second
    are ignored.

    Args:
        smiles_txt_file: Path of the text file to read.

    Returns:
        A tuple ``(smiles_lst, label_lst)`` of two equally long lists.

    Raises:
        IndexError: If a non-blank line has fewer than two fields.
        ValueError: If a label field is not an integer literal.
    """
    smiles_lst, label_lst = [], []
    with open(smiles_txt_file) as fin:
        for line in fin:
            fields = line.split()  # split once and reuse both fields
            if not fields:
                continue
            smiles_lst.append(fields[0])
            label_lst.append(int(fields[1]))
    return smiles_lst, label_lst


def smiles_txt_to_lst(text):
    """Parse the text form of a list of SMILES strings into a real list.

    Example input::

        "['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=CC=C12', 'CNCCC=C1C2=CC=CC=C2CCC2=CC=CC=C12']"

    The outer brackets are removed, the remainder is split on commas, and the
    quote character around each item is dropped. An empty list (``"[]"``)
    yields ``[]`` rather than a list holding one empty string. Whitespace
    around the whole text is ignored.

    Args:
        text: String holding a bracketed, comma-separated list of quoted items.

    Returns:
        The list of items without their surrounding quotes.
    """
    inner = text.strip()[1:-1].strip()
    if not inner:
        return []
    # NOTE: items are taken verbatim; escape sequences such as a doubled
    # backslash are not decoded.
    return [item.strip()[1:-1] for item in inner.split(",")]


def icdcode_text_2_lst_of_lst(text):
    """Parse the text form of a list of ICD-code lists into nested lists.

    The input is expected to look like ``["['A', 'B']", "['C']"]``: an outer
    list whose double-quoted items are themselves the text of a list of
    single-quoted codes. The two leading and two trailing characters of
    ``text`` are dropped, the remainder is split on ``", "`` into groups, and
    each group is stripped of its brackets and split on commas, with the quote
    character around every code removed.

    Args:
        text: String in the nested-list format described above.

    Returns:
        A list with one list of code strings per group.
    """
    groups = text[2:-2].split('", "')
    return [
        [code.strip()[1:-1] for code in group[1:-1].split(",")]
        for group in groups
    ]


def trial_collate_fn(x):
    """Collate ``(nctid, label, smiles_text, icdcode_text)`` samples into a batch.

    Args:
        x: Sequence of samples, each indexable with the trial identifier at
            position 0, the label at 1, the SMILES list text at 2 and the
            ICD-code list text at 3.

    Returns:
        A four-element list ``[nctids, labels, smiles, icdcodes]``:
        identifiers as given, labels converted with ``int`` and passed through
        ``default_collate``, SMILES parsed by ``smiles_txt_to_lst`` (one list
        per sample) and ICD codes parsed by ``icdcode_text_2_lst_of_lst`` (one
        list of lists per sample).
    """
    nctids = [sample[0] for sample in x]  # e.g. ['NCT00604461', ..., 'NCT00788957']
    labels = default_collate(list(map(int, (sample[1] for sample in x))))
    smiles = list(map(smiles_txt_to_lst, (sample[2] for sample in x)))
    icdcodes = list(map(icdcode_text_2_lst_of_lst, (sample[3] for sample in x)))
    return [nctids, labels, smiles, icdcodes]

In [36]:
# Path of the phase I training split inside the HINT data directory.
phase_1_train = osp.join(hint_path, "phase_I_train.csv")

In [48]:
# Load the same split with pandas for quick inspection.
df = pd.read_csv(phase_1_train)

In [52]:
# Show the raw `smiless` value of the first row: a single string holding the
# text of a list of SMILES, not an actual list.
df["smiless"].values[0]

"['[H][N]1([H])[C@@H]2CCCC[C@H]2[N]([H])([H])[Pt]11OC(=O)C(=O)O1', '[H][N]1([H])[C@@H]2CCCC[C@H]2[N]([H])([H])[Pt]11OC(=O)C(=O)O1']"

In [43]:
# Unshuffled loader over the phase I training split, four trials per batch.
dl = csv_three_feature_2_dataloader(phase_1_train, False, 4)

In [44]:
# Pull the first batch to check the collate output.
b = next(iter(dl))

In [45]:
# Display the batch: [nctids, label tensor, SMILES lists, ICD-code lists].
b

[['NCT01187615', 'NCT01046487', 'NCT01381887', 'NCT02015676'],
 tensor([0, 1, 1, 1]),
 [['[H][N]1([H])[C@@H]2CCCC[C@H]2[N]([H])([H])[Pt]11OC(=O)C(=O)O1',
   '[H][N]1([H])[C@@H]2CCCC[C@H]2[N]([H])([H])[Pt]11OC(=O)C(=O)O1'],
  ['CC1=NC(NC2=NC=C(S2)C(=O)NC2=C(C)C=CC=C2Cl)=CC(=N1)N1CCN(CCO)CC1',
   'CC1=NC(NC2=NC=C(S2)C(=O)NC2=C(C)C=CC=C2Cl)=CC(=N1)N1CCN(CCO)CC1',
   'CC1=NC(NC2=NC=C(S2)C(=O)NC2=C(C)C=CC=C2Cl)=CC(=N1)N1CCN(CCO)CC1'],
  ['CN1C(=O)C=C(N2CCC[C@@H](N)C2)N(CC2=C(C=CC=C2)C#N)C1=O',
   '[H][C@]1(O[C@H](CO)[C@@H](O)[C@H](O)[C@H]1O)C1=CC=C(C)C(CC2=CC=C(S2)C2=CC=C(F)C=C2)=C1',
   '[H][C@]1(O[C@H](CO)[C@@H](O)[C@H](O)[C@H]1O)C1=CC=C(C)C(CC2=CC=C(S2)C2=CC=C(F)C=C2)=C1',
   '[H][C@]1(O[C@H](CO)[C@@H](O)[C@H](O)[C@H]1O)C1=CC=C(C)C(CC2=CC=C(S2)C2=CC=C(F)C=C2)=C1'],
  ['[H][N]1([H])[C@@H]2CCCC[C@H]2[N]([H])([H])[Pt]11OC(=O)C(=O)O1',
   '[H][C@@]1(C[C@@H](C)[C@]2([H])CC(=O)[C@H](C)\\\\C=C(C)\\\\[C@@H](O)[C@@H](OC)C(=O)[C@H](C)C[C@H](C)\\\\C=C\\\\C=C\\\\C=C(C)\\\\[C@H](C[C@]3([H])CC[C@@H](C