# HINT TOP (Trial Outcome Prediction) benchmark evaluation

In [1]:
# Reload edited project modules (e.g. under src/) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

## Dataset creation

In [2]:
import os
import subprocess
from pathlib import Path

import torch
from dgl import batch
from dgllife.model.model_zoo.gin_predictor import GINPredictor
from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, SMILESToBigraph
from lightning.pytorch import Trainer, loggers
from torch import nn

from src.eval.clinical_prediction.datamodule import HintClinicalDataModulePhaseI
from src.eval.clinical_prediction.module import HintClinicalModulePhaseI
from src.modules.compound_transforms.dgllife_transform import DGLPretrainedFromSmiles
from src.modules.compound_transforms.fp_transform import FPTransform
from src.modules.compound_transforms.pna import PNATransform
from src.modules.molecules.dgllife_gin import GINPretrainedWithLinearHead
from src.modules.molecules.pna import PNA

In [3]:
# Mount the cpjump1..3 project shares over SSHFS unless already mounted.
# A share counts as mounted when its `jump/` subdirectory is visible locally.
CPJUMP_HOST = "bioclust"

for i in range(1, 4):
    mount_point = Path(f"../cpjump{i}")
    if (mount_point / "jump").exists():
        print(f"cpjump{i} already mounted.")
        continue

    print(f"Mounting cpjump{i}...")
    # sshfs needs an existing, empty directory to mount onto.
    mount_point.mkdir(exist_ok=True)
    # Argument list (no shell) + check=True: a failed mount raises here instead
    # of being ignored and resurfacing later as missing-file errors.
    subprocess.run(
        ["sshfs", f"{CPJUMP_HOST}:/projects/cpjump{i}/", str(mount_point)],
        check=True,
    )

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


In [4]:
# All inputs live under the cpjump1 mount.
_CPJUMP1_ROOT = "../cpjump1"
metadata_path = f"{_CPJUMP1_ROOT}/jump/metadata"
load_data_path = f"{_CPJUMP1_ROOT}/jump/load_data"
hint_path = f"{_CPJUMP1_ROOT}/hint-clinical-trial-outcome-prediction/data"

In [5]:
# Raw directory listing of the JUMP metadata folder (unsorted, filesystem order).
os.listdir(path=metadata_path)

['compound.csv.gz',
 'crispr.csv.gz',
 'microscope_config.csv',
 'microscope_filter.csv',
 'orf.csv.gz',
 'plate.csv.gz',
 'README.md',
 'well.csv.gz',
 'compound.csv',
 'crispr.csv',
 'orf.csv',
 'plate.csv',
 'well.csv',
 'complete_metadata.csv',
 'resolution.csv',
 'JUMP-Target-1_compound_metadata.tsv',
 'JUMP-Target-1_compound_platemap.tsv',
 'JUMP-Target-1_crispr_metadata.tsv',
 'JUMP-Target-1_crispr_platemap.tsv',
 'JUMP-Target-1_orf_metadata.tsv',
 'JUMP-Target-1_orf_platemap.tsv',
 'JUMP-Target-2_compound_metadata.tsv',
 'JUMP-Target-2_compound_platemap.tsv',
 'JUMP-MOA_compound_metadata.tsv',
 'local_metadata.csv']

In [6]:
# Raw directory listing of the HINT data folder (unsorted, filesystem order).
os.listdir(path=hint_path)

['ADMET',
 'NCT00000378.xml',
 'README.md',
 'drugbank_mini.csv',
 'phase_III_test.csv',
 'phase_III_train.csv',
 'phase_III_valid.csv',
 'phase_II_test.csv',
 'phase_II_train.csv',
 'phase_II_valid.csv',
 'phase_I_test.csv',
 'phase_I_train.csv',
 'phase_I_valid.csv',
 'raw_data.csv',
 'sentence2embedding.pkl',
 'sponsor2approvalrate.csv',
 'sponsor2count.csv',
 'toy_test.csv',
 'toy_train.csv',
 'toy_valid.csv']

## Load phase csvs

In [7]:
# SMILES -> DGL graph: canonical atom order, self loops added,
# no virtual nodes and no explicit hydrogens.
smiles_to_graph = DGLPretrainedFromSmiles(
    explicit_hydrogens=False,
    num_virtual_nodes=0,
    canonical_atom_order=True,
    add_self_loop=True,
)

# SMILES -> MACCS and ECFP fingerprints (ECFP with radius 2).
smiles_to_fp = FPTransform(
    compound_str_type="smiles",
    fps=["maccs", "ecfp"],
    params={"ecfp": {"radius": 2}},
)

### Custom GIN

In [75]:
class GINPredictorWrapper(nn.Module):
    """dgllife ``GINPredictor`` (randomly initialised) that consumes raw SMILES.

    ``GINPredictor`` embeds *categorical* atom/bond features, one
    ``nn.Embedding`` per feature, and its ``forward`` takes them as lists of
    integer tensors. dgllife's ``PretrainAtomFeaturizer`` /
    ``PretrainBondFeaturizer`` produce that layout (atomic number + chirality
    per atom, bond type + bond direction per edge). The ``Canonical*Featurizer``
    classes emit float one-hot blocks, which cannot index embedding tables.

    Args:
        num_layers: number of GIN message-passing layers.
        emb_dim: width of the node embeddings.
        JK: jumping-knowledge mode forwarded to ``GINPredictor``.
        n_tasks: output size per molecule.
    """

    # Embedding-table sizes per categorical feature, in the order the feature
    # lists are passed to GINPredictor in ``forward``. They are intended as
    # upper bounds on the indices emitted by the Pretrain featurizers with
    # their default vocabularies. dgllife's pretrained GIN checkpoints use
    # [120, 3] / [6, 3]. Chirality is widened to 4 here in case the default
    # chirality vocabulary has 4 entries. TODO confirm against
    # dgllife.utils defaults. Oversized tables are harmless.
    NODE_EMB_SIZES = (120, 4)  # atomic_number, chirality_type
    EDGE_EMB_SIZES = (6, 3)  # bond_type (incl. self-loop index), bond_direction_type

    def __init__(
        self,
        num_layers=5,
        emb_dim=256,
        JK="last",
        n_tasks=256,
    ):
        super().__init__()
        self.atom_featurizer = PretrainAtomFeaturizer()
        # self_loop=True gives the self loops added by add_self_loop=True below
        # their own bond-type index, so edge features stay aligned with edges.
        self.bond_featurizer = PretrainBondFeaturizer(self_loop=True)

        self.smiles_to_bigraph = SMILESToBigraph(
            node_featurizer=self.atom_featurizer,
            edge_featurizer=self.bond_featurizer,
            add_self_loop=True,
            canonical_atom_order=True,
            num_virtual_nodes=0,
            explicit_hydrogens=False,
        )

        self.gin = GINPredictor(
            num_node_emb_list=list(self.NODE_EMB_SIZES),
            num_edge_emb_list=list(self.EDGE_EMB_SIZES),
            num_layers=num_layers,
            emb_dim=emb_dim,
            JK=JK,
            n_tasks=n_tasks,
        )

    def smiles_to_graph(self, smiles):
        """Featurize an iterable of SMILES strings into one batched DGLGraph.

        Raises:
            ValueError: if a SMILES string could not be featurized. The
                featurizer returns None in that case, which would otherwise
                fail obscurely inside ``batch``.
        """
        graphs = []
        for s in smiles:
            graph = self.smiles_to_bigraph(s)
            if graph is None:
                raise ValueError(f"Could not featurize SMILES {s!r}")
            graphs.append(graph)
        return batch(graphs)

    def forward(self, smiles, **kwargs):
        """Run GINPredictor on a list of SMILES strings.

        Returns one row per SMILES (presumably shape ``(len(smiles), n_tasks)``).
        Extra keyword arguments are accepted and ignored.
        """
        # Graphs are built on CPU; move them next to the weights so the module
        # also works after ``.to("cuda")``.
        device = next(self.gin.parameters()).device
        graphs = self.smiles_to_graph(smiles).to(device)
        node_feats = [graphs.ndata["atomic_number"], graphs.ndata["chirality_type"]]
        edge_feats = [graphs.edata["bond_type"], graphs.edata["bond_direction_type"]]
        return self.gin(graphs, node_feats, edge_feats)

In [77]:
# Instantiate the wrapper with its defaults spelled out.
# NOTE: `model` is rebound to the Lightning module in a later cell.
model = GINPredictorWrapper(num_layers=5, emb_dim=256, JK="last", n_tasks=256)

### Trainer

In [8]:
# dgllife "gin_supervised_contextpred" GIN with a linear head (out_dim=256)
# and mean pooling.
mol_model = GINPretrainedWithLinearHead(
    preload=False,
    pooling="mean",
    out_dim=256,
    pretrained_name="gin_supervised_contextpred",
)

Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded


In [75]:
# PNA molecule encoder; keyword arguments grouped by role (order is irrelevant).
pna = PNA(
    # Message passing: 7 residual PNA layers of width 200.
    hidden_dim=200,
    propagation_depth=7,
    residual=True,
    pretrans_layers=2,
    posttrans_layers=1,
    aggregators=["mean", "max", "min", "std"],
    scalers=["identity", "amplification", "attenuation"],
    pairwise_distances=False,
    # Normalisation, activations, regularisation.
    # NOTE(review): in PyTorch BatchNorm, `momentum` weights the *new* batch
    # statistic, so 0.93 makes running stats track almost only the latest
    # batch (PyTorch default is 0.1). Confirm this is intended.
    mid_batch_norm=True,
    last_batch_norm=True,
    batch_norm_momentum=0.93,
    activation="relu",
    last_activation="none",
    dropout=0.0,
    # readout_* settings.
    readout_aggregators=["min", "max", "mean"],
    readout_hidden_dim=200,
    readout_layers=2,
    readout_batchnorm=True,
    # Output sizes.
    target_dim=256,
    out_dim=256,
)

In [82]:
# Lightning module: PNA encoder trained with Adam at lr 1e-3.
# NOTE(review): the batches below carry SMILES (`smiles_list`). Confirm that
# "inchi" is the intended string type for PNATransform.
model = HintClinicalModulePhaseI(
    lr=1e-3,
    optimizer=torch.optim.Adam,
    compound_transform=PNATransform("inchi"),
    molecule_encoder=pna,
)

In [14]:
# Phase-I HINT datamodule reading from the mounted HINT data directory.
dm = HintClinicalDataModulePhaseI(
    batch_size=128,
    hint_dir=hint_path,
)

In [17]:
# Lightning DataModule preparation hook; what it does (presumably reading the
# HINT phase-I files) is defined in HintClinicalDataModulePhaseI. TODO confirm.
dm.prepare_data()

In [23]:
# Build the datasets for the "fit" stage; dm.train_dataset is inspected next.
dm.setup("fit")

In [24]:
# Inspect the training dataset object.
dm.train_dataset

<src.eval.clinical_prediction.datamodule.HintClinicalDataset at 0x7f1e05d7afb0>

In [25]:
# Training DataLoader, used below to pull a single debug batch.
dl = dm.train_dataloader()

In [26]:
# Grab one batch for manual debugging of the model below.
b = next(iter(dl))

In [29]:
# Display the batch dict; its "smiles_list" holds a list of SMILES per sample.
b

{'smiles_list': [['C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1'],
  ['CNC(=O)c1cc(Oc2ccc(NC(=O)Nc3ccc(Cl)c(C(F)(F)F)c3)cc2)ccn1',
   'C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1'],
  ['Cn1c(=O)cc(N2CCC[C@@H](N)C2)n(Cc2ccccc2C#N)c1=O',
   'Cn1c(=O)cc(N2CCC[C@@H](N)C2)n(Cc2ccccc2C#N)c1=O'],
  ['CCC1=C[C@@H]2CN(C1)Cc1c([nH]c3ccccc13)[C@@](C(=O)OC)(c1cc3c(cc1OC)N(C)[C@H]1[C@@](O)(C(=O)OC)[C@H](OC(C)=O)[C@]4(CC)C=CCN5CC[C@]31[C@@H]54)C2'],
  ['CO[C@H]1C[C@@H]2CC[C@@H](C)[C@@](O)(O2)C(=O)C(=O)N2CCCC[C@H]2C(=O)O[C@H]([C@H](C)C[C@@H]2CC[C@@H](O)[C@H](OC)C2)CC(=O)[C@H](C)/C=C(\\C)[C@@H](O)[C@@H](OC)C(=O)[C@H](C)C[C@H](C)/C=C/C=C/C=C/1C',
   'C[C@@H](O[C@H]1OCCN(Cc2n[nH]c(=O)[nH]2)[C@H]1c1ccc(F)cc1)c1cc(C(F)(F)F)cc(C(F)(F)F)c1'],
  ['Cn1c(=O)cc(N2CCC[C@@H](N)C2)n(Cc2ccccc2C#N)c1=O',
   'CC(=O)Oc1ccccc1C(=O)O',
   'CC(=O)Oc1ccccc1C(=O)O'],
  ['N=C(N)NCCC[C@H](N)C(=O)O',
   'CCN(CC)CCCC(C)Nc1ccnc2cc(Cl)ccc12',
   'CO[C@H]1C[C@@H]2CC[C@@H](C)[C@@](O)(O2)C(=O)C(=O)N2CCCC[C@H]2C(=O)O[C@H]([C@H](C)C[C@@

In [77]:
# Featurize all SMILES of the batch into one batched DGL graph.
# `ids` presumably maps molecules back to their batch items. TODO confirm.
batched_graphs, ids = model.get_batched_graphs(b["smiles_list"])

Empty edges for [Cl-].[Na+]
Empty edges for [Cl-].[Na+]


In [78]:
# Show the batched graph (node/edge counts and feature schemes).
batched_graphs

Graph(num_nodes=6664, num_edges=14318,
      ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
      edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})

In [83]:
# Print the encoder architecture.
model.molecule_encoder

PNA(
  (node_gnn): PNAGNN(
    (mp_layers): ModuleList(
      (0-6): 7 x PNALayer(
        (pretrans): MLP(
          (fully_connected): ModuleList(
            (0): FCLayer(
              (linear): Linear(in_features=600, out_features=200, bias=True)
              (batch_norm): BatchNorm1d(200, eps=1e-05, momentum=0.93, affine=True, track_running_stats=True)
              (activation): ReLU()
            )
            (1): FCLayer(
              (linear): Linear(in_features=200, out_features=200, bias=True)
              (batch_norm): BatchNorm1d(200, eps=1e-05, momentum=0.93, affine=True, track_running_stats=True)
            )
          )
        )
        (posttrans): MLP(
          (fully_connected): ModuleList(
            (0): FCLayer(
              (linear): Linear(in_features=2600, out_features=200, bias=True)
              (batch_norm): BatchNorm1d(200, eps=1e-05, momentum=0.93, affine=True, track_running_stats=True)
            )
          )
        )
      )
    )
    (atom

In [84]:
# Sanity-check the encoder's forward pass on the debug batch.
# Probe in eval mode under no_grad so this neither updates the BatchNorm
# running statistics nor builds an autograd graph. The encoder's previous
# train/eval mode is restored afterwards, even if the forward pass fails.
encoder = model.molecule_encoder
was_training = encoder.training
encoder.eval()
try:
    with torch.no_grad():
        b_emb = encoder(batched_graphs)
finally:
    encoder.train(was_training)

In [87]:
# One training step on the debug batch `b`.
# NOTE(review): the recorded run raised ValueError (input 127 rows vs target
# 128) right after "Empty edges for [Cl-].[Na+]" was printed. Presumably a
# molecule/sample is dropped during graph construction while its label is
# kept. TODO confirm in HintClinicalModulePhaseI.model_step / get_batched_graphs.
model.model_step(b, "train")

Empty edges for [Cl-].[Na+]
Empty edges for [Cl-].[Na+]


ValueError: Target size (torch.Size([128, 1])) must be the same as input size (torch.Size([127, 1]))

In [28]:
# NOTE(review): fails as recorded. HintClinicalModule.forward() requires a
# `compound` argument, which unpacking the batch dict `b` does not supply.
# Consider model.model_step(b, ...) or passing forward's arguments explicitly.
model(**b)

TypeError: HintClinicalModule.forward() missing 1 required positional argument: 'compound'

In [17]:
# Weights & Biases logger (log_model=True, see Lightning WandbLogger docs).
# NOTE(review): the run name says GIN, but `model` above is built with the PNA
# encoder. Confirm the name matches the encoder actually being trained.
wandb_logger = loggers.WandbLogger(
    name="gin-supervised-contextpred",
    project="clinical-trial-outcome-prediction",
    log_model=True,
)
logger = [wandb_logger]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgabriel-watkinson-work[0m ([33mjump_models[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [34]:
# Single-GPU training for 50 epochs, logging metrics on every step.
trainer = Trainer(
    logger=logger,
    max_epochs=50,
    log_every_n_steps=1,
    accelerator="gpu",
    devices=1,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [35]:
# Train and validate `model` on the datamodule.
# NOTE(review): the recorded output below lists a GINPretrainedWithLinearHead
# `molecule_encoder`, so it comes from an earlier `model` than the PNA-based one
# built above (out-of-order execution). Re-run to refresh it.
trainer.fit(model, dm)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                | Type                        | Params
---------------------------------------------------------------------
0  | molecule_encoder    | GINPretrainedWithLinearHead | 2.0 M 
1  | head                | Sequential                  | 66.0 K
2  | criterion           | BCEWithLogitsLoss           | 0     
3  | train_loss          | MeanMetric                  | 0     
4  | val_loss            | MeanMetric                  | 0     
5  | test_loss           | MeanMetric                  | 0     
6  | train_other_metrics | MetricCollection            | 0     
7  | val_other_metrics   | MetricCollection            | 0     
8  | test_other_metrics  | MetricCollection            | 0     
9  | train_plot_metrics  | MetricCollection            | 0     
10 | val_plot_metrics    | MetricCollection            | 0     
11 | test_plot_metrics   | MetricCollectio

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
