# Develop the IDR retrieval evaluator

In [1]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [24]:
import os
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from torchmetrics import MetricCollection
from torchmetrics.functional import pairwise_cosine_similarity, retrieval_hit_rate
from torchmetrics.retrieval import (
    RetrievalFallOut,
    RetrievalHitRate,
    RetrievalMAP,
    RetrievalMRR,
    RetrievalNormalizedDCG,
    RetrievalPrecision,
    RetrievalRPrecision,
)

from src.eval.retrieval import IDRRetrievalDataModule, IDRRetrievalEvaluator, IDRRetrievalModule
from src.modules.compound_transforms import DGLPretrainedFromSmiles
from src.modules.images import CNNEncoder
from src.modules.molecules import GINPretrainedWithLinearHead
from src.modules.transforms import DefaultJUMPTransform

In [3]:
# Make sure the three remote cpjump project directories are mounted locally via sshfs.
# A mount is treated as present when its `jump/` sub-directory is visible.
for i in range(1, 4):
    mount_point = f"../cpjump{i}"
    if Path(f"{mount_point}/jump/").exists():
        print(f"cpjump{i} already mounted.")
        continue

    print(f"Mounting cpjump{i}...")
    # A non-zero return value from os.system means the command did not succeed;
    # report it instead of silently continuing with an unmounted directory.
    status = os.system(f"sshfs bioclust:/projects/cpjump{i}/ {mount_point}")
    if status != 0:
        print(f"Failed to mount cpjump{i} (os.system returned {status}).")

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


In [4]:
# Keyword arguments for the retrieval datamodule, grouped by concern.
datamodule_kwargs = dict(
    # -- input tables and image root --
    selected_compounds_path="../cpjump1/excape-db/selected_compounds.csv",
    image_metadata_path="../cpjump1/idr0033-rohban-pathways/processed_metadata.csv",
    data_root_dir="../cpjump1/screen_1751",
    # -- loader settings --
    image_batch_size=8,
    compound_batch_size=8,
    num_workers=8,
    pin_memory=False,
    prefetch_factor=3,
    # -- transforms --
    compound_transform=DGLPretrainedFromSmiles(),
    transform=DefaultJUMPTransform(size=256),
    # -- column names --
    compound_gene_col="Gene_Symbol",
    image_gene_col="Gene Symbol",
    col_fstring="FileName_{channel}",
    channels=None,
    target_col="Activity_Flag",
    smiles_col="SMILES",
    # -- caching / collation --
    use_cache=False,
    mol_collate_fn=None,
    img_collate_fn=None,
)
datamodule = IDRRetrievalDataModule(**datamodule_kwargs)

# Both encoders are configured with the same size value
# (presumably the shared embedding width -- TODO confirm).
embedding_size = 128
image_encoder = CNNEncoder("resnet18", target_num=embedding_size)
molecule_encoder = GINPretrainedWithLinearHead("gin_supervised_infomax", out_dim=embedding_size)

# Retrieval module wrapping the two encoders.
example_input_path = "../cpjump1/jump/models/eval/test/example.pt"
idr_model = IDRRetrievalModule(
    image_encoder=image_encoder,
    molecule_encoder=molecule_encoder,
    example_input_path=example_input_path,
)

Downloading gin_supervised_infomax_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_infomax.pth...
Pretrained model loaded


In [51]:
# Re-create the retrieval module with the same arguments as the earlier
# construction in this notebook, reusing the already-built encoders.
# NOTE(review): presumably re-run to pick up autoreloaded changes to
# IDRRetrievalModule -- confirm, and consider keeping a single construction cell.
idr_model = IDRRetrievalModule(
    image_encoder=image_encoder,
    molecule_encoder=molecule_encoder,
    example_input_path="../cpjump1/jump/models/eval/test/example.pt",
)

In [5]:
# Lightning trainer on a single GPU; no logger argument is passed here.
trainer = Trainer(accelerator="gpu", devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [35]:
# Run prediction through the datamodule as a whole.
# NOTE(review): `res` is not read by any later cell visible in this notebook,
# and the recorded progress bar shows zero iterations -- confirm this cell is still needed.
res = trainer.predict(idr_model, datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

In [52]:
# Manually run the datamodule's data preparation and its setup for the "predict" stage
# so that its predict dataloaders can be requested directly.
datamodule.prepare_data()
datamodule.setup("predict")

In [53]:
# Predict dataloaders; the evaluation loop below iterates `dl` by gene and reads
# dl[gene]["molecule"] and dl[gene]["image"].
dl = datamodule.predict_dataloader()

In [54]:
def concat_from_list_of_dict(res, key):
    """Join the tensors stored under ``key`` across a list of dicts.

    Args:
        res: Iterable of mappings (e.g. per-batch outputs), each holding a tensor at ``key``.
        key: Key whose tensors are gathered from every mapping.

    Returns:
        One tensor made by concatenating the gathered tensors along dimension 0.
    """
    gathered = []
    for batch_out in res:
        gathered.append(batch_out[key])
    return torch.cat(gathered, dim=0)

In [86]:
# Per-gene retrieval evaluation: embed the gene's compounds and images with the
# trainer, then score them with the module's `retrieval` method.
out_metrics = {}
for gene in dl:
    gene_loaders = dl[gene]

    # Molecule pass: the batch outputs provide both the embeddings and the activity flags.
    molecule_outputs = trainer.predict(idr_model, gene_loaders["molecule"])
    activities = concat_from_list_of_dict(molecule_outputs, "activity_flag")
    compound_emb = concat_from_list_of_dict(molecule_outputs, "compound")

    # Image pass.
    image_outputs = trainer.predict(idr_model, gene_loaders["image"])
    image_emb = concat_from_list_of_dict(image_outputs, "image")

    gene_metrics = idr_model.retrieval(
        image_embeddings=image_emb,
        compound_embeddings=compound_emb,
        activities=activities,
    )
    out_metrics[gene] = gene_metrics

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

In [94]:
# Logger.log_metrics is documented to take a flat mapping of metric name -> float,
# but `out_metrics` is nested (gene -> metric -> value). Flatten it into
# "<gene>/<metric>" scalar entries before handing it to the trainer's first logger.
flat_metrics = {
    f"{gene_name}/{metric_name}": float(per_gene[metric_name])
    for gene_name, per_gene in out_metrics.items()
    for metric_name in per_gene
}
trainer.loggers[0].log_metrics(flat_metrics)

In [109]:
# Create a Weights & Biases logger with all-default arguments.
wandb_logger = WandbLogger()

  rank_zero_warn(


In [110]:
# Bare attribute access: this only displays the method object and logs nothing.
# NOTE(review): looks like leftover inspection; the recorded output below shows the
# logger object itself, so it appears to come from a different version of this cell.
wandb_logger.log_metrics

<lightning.pytorch.loggers.wandb.WandbLogger at 0x7f39eb9bf010>

In [104]:
# Average every metric over genes: each gene contributes value / n_genes to a running total.
n_genes = len(out_metrics)
metric_totals = defaultdict(int)
for per_gene in out_metrics.values():
    for name in per_gene:
        metric_totals[name] += per_gene[name] / n_genes

# Convert the accumulated values to plain Python floats in an ordinary dict.
mean_metrics = {name: float(total) for name, total in metric_totals.items()}

## Metric Test

In [181]:
# Pick the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Move the module to that device and switch it to evaluation mode
# (the trailing semicolon suppresses the cell's repr output).
idr_model.to(device)
idr_model.eval();

In [187]:
# Compute the image/molecule matrix and the molecule activity labels with no batch limit.
# NOTE(review): `img_dataloader` and `mol_dataloader` are not defined in any cell
# visible in this notebook -- they come from hidden kernel state and must be defined
# before this cell can run on a fresh kernel.
dist, molecule_activities = idr_model.get_distance_matrix(
    image_dataloader=img_dataloader, molecule_dataloader=mol_dataloader, max_num_batches=None
)

In [229]:
# Assemble the retrieval metrics one entry at a time, then wrap them in a collection.
metric_table = {}
metric_table["RetrievalMRR"] = RetrievalMRR()

# Hit rate at several cut-offs.
metric_table["RetrievalHitRate_top_1"] = RetrievalHitRate(top_k=1)
metric_table["RetrievalHitRate_top_3"] = RetrievalHitRate(top_k=3)
metric_table["RetrievalHitRate_top_5"] = RetrievalHitRate(top_k=5)
metric_table["RetrievalHitRate_top_10"] = RetrievalHitRate(top_k=10)

# Remaining top-5 metrics and the normalized DCG.
metric_table["RetrievalFallOut_top_5"] = RetrievalFallOut(top_k=5)
metric_table["RetrievalMAP_top_5"] = RetrievalMAP(top_k=5)
metric_table["RetrievalPrecision_top_5"] = RetrievalPrecision(top_k=5)
metric_table["RetrievalNormalizedDCG"] = RetrievalNormalizedDCG()
# "RetrievalRPrecision": RetrievalRPrecision()  (currently disabled)

retrieval_metrics = MetricCollection(metric_table)

In [230]:
# Inspect `dist`, the first value returned by get_distance_matrix above.
dist

tensor([[-0.0018,  0.0432,  0.0789,  ...,  0.0255,  0.0430,  0.0028],
        [-0.0261, -0.0429, -0.0069,  ..., -0.0523, -0.0571, -0.0432],
        [ 0.0019,  0.0239,  0.0101,  ...,  0.0945,  0.0225,  0.0465],
        ...,
        [ 0.1606,  0.0905,  0.1849,  ...,  0.1315,  0.1821,  0.1393],
        [ 0.1091,  0.0410,  0.0730,  ...,  0.1198,  0.1352,  0.0692],
        [ 0.1708,  0.1863,  0.2509,  ...,  0.1794,  0.2304,  0.2223]])

In [231]:
# Stand-alone hit-rate metric.
# NOTE(review): the name suggests a cut-off of 3 but top_k is 5, and `rh3` is not
# used by any later cell visible in this notebook -- confirm whether it is still needed.
rh3 = RetrievalHitRate(top_k=5)

In [232]:
# One integer id per column of `dist`; passed (after expansion) as `indexes` to the
# retrieval metrics so that every column carries its own query id.
indexes = torch.arange(dist.shape[1])

In [233]:
# Expand the activity labels and the per-column ids to the shape of `dist`,
# then evaluate the whole metric collection in one call.
target_matrix = molecule_activities.expand(dist.T.shape).T
query_ids = indexes.expand(dist.shape)
retrieval_metrics(preds=dist, target=target_matrix, indexes=query_ids)

{'RetrievalFallOut_top_5': tensor(0.0495),
 'RetrievalHitRate_top_1': tensor(0.0270),
 'RetrievalHitRate_top_10': tensor(0.7838),
 'RetrievalHitRate_top_5': tensor(0.2703),
 'RetrievalMAP_top_5': tensor(0.1083),
 'RetrievalMRR': tensor(0.1887),
 'RetrievalNormalizedDCG': tensor(0.5436),
 'RetrievalPrecision_top_5': tensor(0.0595)}

In [212]:
# Cross-check the collection by scoring every column of `dist` on its own with the
# functional hit rate. The column count is taken from `dist` instead of being hard-coded.
# NOTE: top_k=4 here differs from the cut-offs used in the metric collection (1, 3, 5, 10).
# A plain `for` loop is kept on purpose: later cells read the loop variable `i`
# (left at the last column index) after this cell has run.
rhs = []
for i in range(dist.shape[1]):
    rhs.append(retrieval_hit_rate(preds=dist[:, i], target=molecule_activities, top_k=4).item())



In [207]:
# Row indices of `dist` ordered by decreasing value in column `i`.
# `i` is the loop variable left over from the per-column hit-rate loop.
sorted_id = list(range(dist.shape[0]))
sorted_id.sort(key=lambda row: dist[row, i], reverse=True)

In [209]:
# Column `i` of `dist`, rearranged into the descending order given by `sorted_id`.
dist[sorted_id, i]

tensor([ 0.2223,  0.2065,  0.1724,  0.1472,  0.1469,  0.1406,  0.1393,  0.1358,
         0.1271,  0.1198,  0.1185,  0.1169,  0.1123,  0.1114,  0.1094,  0.1046,
         0.1027,  0.0972,  0.0968,  0.0966,  0.0917,  0.0916,  0.0883,  0.0880,
         0.0878,  0.0871,  0.0792,  0.0785,  0.0779,  0.0757,  0.0723,  0.0703,
         0.0699,  0.0697,  0.0692,  0.0641,  0.0639,  0.0597,  0.0589,  0.0586,
         0.0567,  0.0559,  0.0546,  0.0520,  0.0486,  0.0480,  0.0476,  0.0468,
         0.0465,  0.0451,  0.0436,  0.0433,  0.0413,  0.0387,  0.0380,  0.0367,
         0.0353,  0.0352,  0.0343,  0.0328,  0.0326,  0.0323,  0.0263,  0.0253,
         0.0239,  0.0195,  0.0178,  0.0152,  0.0147,  0.0133,  0.0124,  0.0120,
         0.0082,  0.0028,  0.0026, -0.0012, -0.0025, -0.0037, -0.0069, -0.0101,
        -0.0131, -0.0139, -0.0144, -0.0178, -0.0211, -0.0214, -0.0223, -0.0233,
        -0.0235, -0.0309, -0.0333, -0.0344, -0.0385, -0.0393, -0.0414, -0.0432,
        -0.0481, -0.0482, -0.0544, -0.05

In [210]:
# Activity labels in the same ranked order, to see where the active entries fall.
molecule_activities[sorted_id]

tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 1.,
        1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
        1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.])

In [213]:
# Hit rate of the last column scored by the per-column loop.
rhs[-1]

0.0

In [199]:
# Mean of the per-column hit rates.
# NOTE(review): this cell's execution count is lower than that of the loop cell that
# builds `rhs`, so the recorded output may predate the current top_k=4 loop -- re-run to confirm.
np.mean(rhs)

0.2702702702702703

In [69]:
# Same call as the earlier get_distance_matrix cell, limited to 5 batches.
# NOTE(review): the recorded output is a TypeError (conv2d received a dict), and
# `img_dataloader` / `mol_dataloader` are not defined in any visible cell -- this looks
# like a leftover debugging cell; fix or remove it before sharing.
idr_model.get_distance_matrix(image_dataloader=img_dataloader, molecule_dataloader=mol_dataloader, max_num_batches=5)

TypeError: conv2d() received an invalid combination of arguments - got (dict, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!dict!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!dict!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)


In [22]:
# Display the datamodule's test image datasets (the recorded output is keyed by gene symbol).
datamodule.test_image_datasets

{'BRCA1': IDRRetrievalImageDataset(n_images=37),
 'HIF1A': IDRRetrievalImageDataset(n_images=83),
 'HSPA5': IDRRetrievalImageDataset(n_images=45),
 'JUN': IDRRetrievalImageDataset(n_images=87),
 'STAT3': IDRRetrievalImageDataset(n_images=123),
 'TP53': IDRRetrievalImageDataset(n_images=88)}