# HINT TOP benchmark evaluation

In [1]:
# Reload imported modules before each cell runs, so edits to the project's
# `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Dataset creation

In [2]:
import os
import subprocess
from pathlib import Path

import torch
from dgl import batch
from dgllife.model.model_zoo.gin_predictor import GINPredictor
from dgllife.utils import (
    CanonicalAtomFeaturizer,
    CanonicalBondFeaturizer,
    SMILESToBigraph,
)
from lightning.pytorch import Trainer, loggers, seed_everything
from torch import nn

from src.eval.clinical_prediction.datamodule import HintClinicalDataModulePhaseI
from src.eval.clinical_prediction.module import HintClinicalModulePhaseI
from src.modules.compound_transforms.dgllife_transform import DGLPretrainedFromSmiles
from src.modules.compound_transforms.fp_transform import FPTransform
from src.modules.molecules.dgllife_gin import GINPretrainedWithLinearHead

In [3]:
# Mount the remote cpjump1..3 project directories over SSHFS when they are not
# already available locally. A mount counts as present when its `jump/`
# sub-directory exists.
for i in range(1, 4):
    mount_point = Path(f"../cpjump{i}")
    if (mount_point / "jump").exists():
        print(f"cpjump{i} already mounted.")
        continue

    print(f"Mounting cpjump{i}...")
    try:
        # sshfs needs an existing directory to mount onto.
        mount_point.mkdir(parents=True, exist_ok=True)
        # Argument list instead of a shell string: no shell parsing involved,
        # and the exit status is available to report a failed mount.
        exit_code = subprocess.run(
            ["sshfs", f"bioclust:/projects/cpjump{i}/", str(mount_point)],
            check=False,
        ).returncode
    except OSError as err:  # e.g. sshfs is not installed
        print(f"Could not mount cpjump{i}: {err}")
        continue

    # Report instead of raising so that the remaining mounts are still tried.
    if exit_code != 0:
        print(f"Could not mount cpjump{i}: sshfs exited with code {exit_code}.")

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


In [4]:
# All data locations used below live under the cpjump1 mount.
cpjump1_root = "../cpjump1"
jump_root = f"{cpjump1_root}/jump"

metadata_path = f"{jump_root}/metadata"
load_data_path = f"{jump_root}/load_data"
hint_path = f"{cpjump1_root}/hint-clinical-trial-outcome-prediction/data"

In [5]:
# Show which files are available in the JUMP metadata directory.
os.listdir(metadata_path)

['compound.csv.gz',
 'crispr.csv.gz',
 'microscope_config.csv',
 'microscope_filter.csv',
 'orf.csv.gz',
 'plate.csv.gz',
 'README.md',
 'well.csv.gz',
 'compound.csv',
 'crispr.csv',
 'orf.csv',
 'plate.csv',
 'well.csv',
 'complete_metadata.csv',
 'resolution.csv',
 'JUMP-Target-1_compound_metadata.tsv',
 'JUMP-Target-1_compound_platemap.tsv',
 'JUMP-Target-1_crispr_metadata.tsv',
 'JUMP-Target-1_crispr_platemap.tsv',
 'JUMP-Target-1_orf_metadata.tsv',
 'JUMP-Target-1_orf_platemap.tsv',
 'JUMP-Target-2_compound_metadata.tsv',
 'JUMP-Target-2_compound_platemap.tsv',
 'JUMP-MOA_compound_metadata.tsv',
 'local_metadata.csv']

In [6]:
# Show which files are available in the HINT data directory
# (per-phase train/valid/test CSVs, among others).
os.listdir(hint_path)

['ADMET',
 'NCT00000378.xml',
 'README.md',
 'drugbank_mini.csv',
 'phase_III_test.csv',
 'phase_III_train.csv',
 'phase_III_valid.csv',
 'phase_II_test.csv',
 'phase_II_train.csv',
 'phase_II_valid.csv',
 'phase_I_test.csv',
 'phase_I_train.csv',
 'phase_I_valid.csv',
 'raw_data.csv',
 'sentence2embedding.pkl',
 'sponsor2approvalrate.csv',
 'sponsor2count.csv',
 'toy_test.csv',
 'toy_train.csv',
 'toy_valid.csv']

## Load phase csvs

In [7]:
# Graph transform: handed to the Lightning module below as `compound_transform`.
smiles_to_graph = DGLPretrainedFromSmiles(
    explicit_hydrogens=False,
    num_virtual_nodes=0,
    canonical_atom_order=True,
    add_self_loop=True,
)

# Fingerprint transform (MACCS + ECFP with radius 2).
# NOTE(review): not referenced by any later cell visible in this notebook.
fingerprint_kinds = ["maccs", "ecfp"]
ecfp_options = {"radius": 2}
smiles_to_fp = FPTransform(
    compound_str_type="smiles",
    fps=fingerprint_kinds,
    params={"ecfp": ecfp_options},
)

### Custom GIN

In [75]:
class GINPredictorWrapper(nn.Module):
    """GIN predictor that accepts raw SMILES strings.

    The module owns its featurization pipeline: SMILES strings are converted
    to DGL bigraphs with canonical atom/bond featurizers, batched into a single
    graph, and passed to a ``GINPredictor``.

    Args:
        num_layers: Number of GIN layers, forwarded to ``GINPredictor``.
        emb_dim: Embedding size, forwarded to ``GINPredictor``.
        JK: Jumping-knowledge mode, forwarded to ``GINPredictor``.
        n_tasks: Output size of the prediction head, forwarded to
            ``GINPredictor``.

    NOTE(review): ``GINPredictor`` appears to expect *lists of categorical
    (integer) feature tensors* that it looks up in embedding tables, whereas
    the canonical featurizers produce float feature matrices. Confirm that
    ``forward`` actually runs end to end before relying on this class.
    NOTE(review): the graphs are built with ``add_self_loop=True`` while the
    bond featurizer uses its default settings; dgllife may require
    ``CanonicalBondFeaturizer(self_loop=True)`` in that case. TODO confirm.
    """

    def __init__(
        self,
        num_layers=5,
        emb_dim=256,
        JK="last",
        n_tasks=256,
    ):
        super().__init__()
        self.atom_featurizer = CanonicalAtomFeaturizer()
        self.bond_featurizer = CanonicalBondFeaturizer()

        # SMILES -> DGL bigraph converter carrying the featurizers above.
        self.smiles_to_bigraph = SMILESToBigraph(
            node_featurizer=self.atom_featurizer,
            edge_featurizer=self.bond_featurizer,
            add_self_loop=True,
            canonical_atom_order=True,
            num_virtual_nodes=0,
            explicit_hydrogens=False,
        )

        # Input sizes are taken from the featurizers' "h" (atom) and "e" (bond)
        # feature fields.
        self.gin = GINPredictor(
            num_node_emb_list=[self.atom_featurizer.feat_size("h")],
            num_edge_emb_list=[self.bond_featurizer.feat_size("e")],
            num_layers=num_layers,
            emb_dim=emb_dim,
            JK=JK,
            n_tasks=n_tasks,
        )

    def smiles_to_graph(self, smiles):
        """Convert SMILES strings into one batched DGL graph.

        Args:
            smiles: An iterable of SMILES strings, or a single SMILES string.

        Returns:
            The result of ``dgl.batch`` over the per-molecule graphs.
        """
        # A bare string would otherwise be iterated character by character,
        # turning every character into its own "molecule".
        if isinstance(smiles, str):
            smiles = [smiles]
        return batch([self.smiles_to_bigraph(s) for s in smiles])

    def forward(self, smiles, **kwargs):
        """Featurize ``smiles`` and run the GIN predictor on the batch.

        Args:
            smiles: An iterable of SMILES strings, or a single SMILES string.
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            Whatever ``GINPredictor`` returns for the batched graph.

        Raises:
            KeyError: If the batched graph has no "h" node features or no "e"
                edge features.
        """
        graphs = self.smiles_to_graph(smiles)
        # Index explicitly: a missing feature field should fail here rather
        # than reach the GIN as a silent ``None``.
        node_feats = graphs.ndata["h"]
        edge_feats = graphs.edata["e"]
        return self.gin(graphs, node_feats, edge_feats)

In [77]:
# Instantiate the custom wrapper with its default hyper-parameters.
# NOTE(review): `model` is rebound to a HintClinicalModulePhaseI in the Trainer
# section below, so this instance is not the one that gets trained when the
# notebook is run top to bottom.
model = GINPredictorWrapper()

### Trainer

In [11]:
# Seed the Python, NumPy and PyTorch RNGs (and dataloader workers) before any
# weights are created, so that initialisation and the training run below are
# reproducible under Restart & Run All.
SEED = 42
seed_everything(SEED, workers=True)

# Molecule encoder: pretrained GIN checkpoint with a 256-dimensional output
# and mean pooling.
mol_model = GINPretrainedWithLinearHead(
    pretrained_name="gin_supervised_contextpred",
    out_dim=256,
    pooling="mean",
    preload=False,
)

Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded


In [13]:
# Lightning module for the phase I task: the graph transform feeds the
# molecule encoder, optimised with Adam at a learning rate of 1e-3.
model = HintClinicalModulePhaseI(
    compound_transform=smiles_to_graph,
    molecule_encoder=mol_model,
    lr=1e-3,
    optimizer=torch.optim.Adam,
)

In [14]:
# Data module reading the phase I data from the HINT directory.
dm = HintClinicalDataModulePhaseI(
    batch_size=128,
    hint_dir=hint_path,
)

In [15]:
# Run the data module's preparation step up front, before training starts.
dm.prepare_data()

In [17]:
# Log metrics (and model checkpoints, via log_model=True) to Weights & Biases.
wandb_logger = loggers.WandbLogger(
    name="gin-supervised-contextpred",
    project="clinical-trial-outcome-prediction",
    log_model=True,
)
logger = [wandb_logger]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgabriel-watkinson-work[0m ([33mjump_models[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [34]:
# Single-GPU trainer: 50 epochs, logging on every training step.
trainer = Trainer(
    max_epochs=50,
    accelerator="gpu",
    devices=1,
    log_every_n_steps=1,
    logger=logger,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [35]:
# Train the module on the data module's train/validation splits.
trainer.fit(model, datamodule=dm)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                | Type                        | Params
---------------------------------------------------------------------
0  | molecule_encoder    | GINPretrainedWithLinearHead | 2.0 M 
1  | head                | Sequential                  | 66.0 K
2  | criterion           | BCEWithLogitsLoss           | 0     
3  | train_loss          | MeanMetric                  | 0     
4  | val_loss            | MeanMetric                  | 0     
5  | test_loss           | MeanMetric                  | 0     
6  | train_other_metrics | MetricCollection            | 0     
7  | val_other_metrics   | MetricCollection            | 0     
8  | test_other_metrics  | MetricCollection            | 0     
9  | train_plot_metrics  | MetricCollection            | 0     
10 | val_plot_metrics    | MetricCollection            | 0     
11 | test_plot_metrics   | MetricCollectio

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
