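# Validating evaluation methods

Interactive smoke tests of the evaluators configured under `cfg.eval` (simple image–compound retrieval, HINT clinical-trial outcome prediction, and OGB molecular property prediction), each run on tiny batches and a single device.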

In [1]:
# Reload edited modules (e.g. the local `src` package) before each cell runs, so
# changes to evaluator code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import os
import subprocess
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MetricCollection
from torchmetrics.functional import pairwise_cosine_similarity, retrieval_hit_rate
from torchmetrics.retrieval import (
    RetrievalFallOut,
    RetrievalHitRate,
    RetrievalMAP,
    RetrievalMRR,
    RetrievalNormalizedDCG,
    RetrievalPrecision,
    RetrievalRPrecision,
)

from src import utils
from src.eval.retrieval import IDRRetrievalDataModule, IDRRetrievalEvaluator, IDRRetrievalModule
from src.modules.compound_transforms import DGLPretrainedFromSmiles
from src.modules.images import CNNEncoder
from src.modules.molecules import GINPretrainedWithLinearHead

In [3]:
# Mount the remote cpjump1..cpjump3 project directories over SSHFS unless their `jump/`
# subdirectory is already visible locally. A failing `sshfs` raises CalledProcessError
# (check=True) instead of being silently ignored, which would otherwise surface much
# later as confusing missing-file errors in the data-loading cells.
SSHFS_HOST = "bioclust"

for i in range(1, 4):
    mount_point = Path(f"../cpjump{i}")
    if (mount_point / "jump").exists():
        print(f"cpjump{i} already mounted.")
        continue
    print(f"Mounting cpjump{i}...")
    # Argument list (no shell) so nothing in the paths is shell-interpreted.
    subprocess.run(["sshfs", f"{SSHFS_HOST}:/projects/cpjump{i}/", str(mount_point)], check=True)

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


## Loading config

In [4]:
# Reset Hydra's global state so that re-running the `initialize(...)` cell does not
# fail on an already-initialized Hydra instance.
GlobalHydra.instance().clear()

In [4]:
# Point Hydra at the project's config directory (relative path `../configs`).
# If re-executing, clear GlobalHydra first or this call will fail.
initialize(version_base=None, config_path="../configs")

hydra.initialize()

In [5]:
# Compose the training config for interactive checks: evaluators enabled, the
# "small" dataset experiment, batch size 4 on a single device, and a scratch output dir.
compose_overrides = [
    "evaluate=true",
    "eval=evaluators",
    "paths.projects_dir=..",
    "paths.output_dir=./tmp/21312FS12A",
    "experiment=final/dataset_experiments/small.yaml",
    "data.batch_size=4",
    "trainer.devices=1",
]
cfg = compose(config_name="train.yaml", overrides=compose_overrides)
print(OmegaConf.to_yaml(cfg))

task_name: train
tags:
- final_experiments
- pretrained
- ntxent
- single_view
- resnet34
- pna
train: true
load_first_bacth: true
test: true
evaluate: true
compile: false
ckpt_path: null
seed: 12345
data:
  compound_transform:
    _target_: src.modules.compound_transforms.pna.PNATransform
    compound_str_type: inchi
  transform:
    _target_: src.modules.transforms.ComplexTransform
    _convert_: object
    size: 512
    flip_p: 0.5
    resize_p: 0.3
    color_p: 0.2
    resize_min_ratio: 0.9
    intensity: 0.2
    brightness: 0.1
    fill_nan: true
    use_flip: true
    use_blur: false
    use_color_jitter: true
    use_drop: false
    use_resized_crop: true
  _target_: src.models.jump_cl.datamodule.BasicJUMPDataModule
  batch_size: 4
  num_workers: 16
  pin_memory: null
  prefetch_factor: 2
  drop_last: true
  force_split: false
  splitter:
    _target_: src.splitters.RandomSplitter
    train: 4096
    test: 8192
    val: 4096
    retrieval: 4096
  use_compond_cache: false
  data_

In [7]:
# List the evaluators configured under `cfg.eval`.
cfg.eval.keys()

dict_keys(['simple_retrieval', 'idr_graph_retrieval', 'batch_effect', 'plate_normalized', 'lipo', 'esol', 'bbbp', 'hiv', 'phaseI', 'phaseII', 'phaseIII'])

## Simple retrieval

In [7]:
# Inspect the Lightning Trainer settings the simple-retrieval evaluator will be built with.
cfg.eval.simple_retrieval.trainer

{'_target_': 'lightning.pytorch.trainer.Trainer', 'default_root_dir': '${paths.output_dir}/eval/simple_retrieval/', 'min_epochs': 5, 'max_epochs': 20, 'accelerator': 'gpu', 'detect_anomaly': True, 'devices': '${trainer.devices}', 'check_val_every_n_epoch': 1, 'deterministic': False}

In [8]:
# Shrink the simple-retrieval evaluator for a quick smoke test: batch size 4, one device,
# and only 3 prediction batches. `limit_predict_batches` is not present in the config,
# so it has to be added inside `open_dict`.
simple_trainer_cfg = cfg.eval.simple_retrieval.trainer
cfg.eval.simple_retrieval.datamodule.batch_size = 4
simple_trainer_cfg.devices = 1

with open_dict(simple_trainer_cfg):
    simple_trainer_cfg.limit_predict_batches = 3

In [9]:
# Build the simple-retrieval evaluator (exposes .datamodule, .model and .trainer) from
# its eval config and the shared model config.
simple = utils.instantiate_evaluator(cfg.eval.simple_retrieval, cfg.model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
# Prepare the retrieval dataset before requesting any dataloader.
simple.datamodule.setup()

[[36m2023-09-26 13:43:14,623[0m][[34msrc.eval.simple_retrieval.datamodule[0m][[32mINFO[0m] - Preparing retrieval dataset[0m


Preparing retrieval dataset


In [12]:
# Grab the prediction dataloader to inspect a single batch by hand.
dl = simple.datamodule.predict_dataloader()

In [20]:
# Pull one batch for a manual predict step.
b = next(iter(dl))

In [21]:
# Run one predict step by hand; the returned dict holds compound strings, image ids
# and the compound/image embeddings for the batch.
simple.model.predict_step(b, 0)

{'dataloader_idx': 0,
 'batch_idx': 0,
 'compound_str': ['InChI=1S/C10H10ClN3O2/c11-7-2-1-3-8(4-7)14-10(6-16)9(5-15)12-13-14/h1-4,15-16H,5-6H2',
  'InChI=1S/C10H10ClN3O2S2/c11-9-1-2-10(17-9)18(15,16)14-4-3-7-8(5-14)13-6-12-7/h1-2,6H,3-5H2,(H,12,13)',
  'InChI=1S/C10H10N2O3S2/c1-2-5-3-6-8(15)11-10(12-9(6)17-5)16-4-7(13)14/h3H,2,4H2,1H3,(H,13,14)(H,11,12,15)',
  'InChI=1S/C10H10N4O2/c1-16-10(15)8-9(11)13-14(12-8)7-5-3-2-4-6-7/h2-6H,1H3,(H2,11,13)'],
 'image_id': ['source_6__p210928CPU2OS48hw384exp030JUMP__110000296383__M17__5',
  'source_5__JUMPCPE-20211001-Run33_20211001_152017__AEOJUM806__I07__7',
  'source_1__Batch1_20221004__UL001643__Z17__2',
  'source_9__20210915-Run10__GR00003307__Z43__1'],
 'compound_emb': tensor([[-0.0327,  0.0117,  0.0466,  ...,  0.0163,  0.0128,  0.0203],
         [ 0.0585,  0.0194,  0.0242,  ...,  0.0032,  0.0170, -0.0668],
         [ 0.0341, -0.0110,  0.0852,  ...,  0.1249,  0.0764, -0.0036],
         [ 0.0528,  0.0212,  0.0700,  ...,  0.0390,  0.0383,  0.00

In [22]:
# Predict over the (limited) predict dataloader; returns a list with one dict per batch.
predictions = simple.trainer.predict(simple.model, simple.datamodule)

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

In [25]:
# Inspect the per-batch prediction dicts.
predictions

[{'dataloader_idx': 0,
  'batch_idx': 0,
  'compound_str': ['InChI=1S/C10H10ClN3O2/c11-7-2-1-3-8(4-7)14-10(6-16)9(5-15)12-13-14/h1-4,15-16H,5-6H2',
   'InChI=1S/C10H10ClN3O2S2/c11-9-1-2-10(17-9)18(15,16)14-4-3-7-8(5-14)13-6-12-7/h1-2,6H,3-5H2,(H,12,13)',
   'InChI=1S/C10H10N2O3S2/c1-2-5-3-6-8(15)11-10(12-9(6)17-5)16-4-7(13)14/h3H,2,4H2,1H3,(H,13,14)(H,11,12,15)',
   'InChI=1S/C10H10N4O2/c1-16-10(15)8-9(11)13-14(12-8)7-5-3-2-4-6-7/h2-6H,1H3,(H2,11,13)'],
  'image_id': ['source_6__p210928CPU2OS48hw384exp030JUMP__110000296383__M17__9',
   'source_10__2021_08_09_U2OS_48_hr_run13__Dest210727-153138__P06__1',
   'source_8__J4__A1166164__M09__4',
   'source_11__Batch5__EC000080__N08__4'],
  'compound_emb': tensor([[ 0.0163,  0.0241,  0.0165,  ...,  0.0816,  0.0436, -0.0081],
          [ 0.0281,  0.0094,  0.0201,  ...,  0.0373,  0.0158,  0.0103],
          [-0.0392, -0.0008,  0.0729,  ...,  0.0593,  0.0104, -0.0258],
          [ 0.0009,  0.0181,  0.0571,  ...,  0.0711,  0.0079, -0.0134]]),
  '

In [28]:
# Per-batch cross-modal retrieval: inside each predicted batch, image i and compound i
# form the only positive pair. NOTE: the "1:100" key label assumes batches of 100; with
# the batch size of 4 configured above, each query actually ranks only 4 candidates.
result_dict = defaultdict(list)

for batch_out in predictions:
    # Presumably rows index images and columns index compounds (batch_size x batch_size).
    scores = simple.distance_metric(batch_out["image_emb"], batch_out["compound_emb"])
    n_cols = scores.shape[1]

    # Every entry of column j gets query id j -> each compound ranks the images.
    query_ids_mol_to_img = torch.arange(n_cols).expand(scores.shape)
    # Transposed: every entry of row i gets query id i -> each image ranks the compounds.
    query_ids_img_to_mol = query_ids_mol_to_img.transpose(0, 1)
    # Positives sit on the diagonal.
    positives = torch.eye(n_cols)

    # Keep this call order: the metric collection may be stateful.
    mol_to_img_scores = simple.retrieval_metrics(preds=scores, target=positives, indexes=query_ids_mol_to_img)
    img_to_mol_scores = simple.retrieval_metrics(preds=scores, target=positives, indexes=query_ids_img_to_mol)

    for name in simple.metric_keys:
        forward_score = mol_to_img_scores[name]
        backward_score = img_to_mol_scores[name]
        result_dict[f"retrieval/1:100/mol_to_img/{name}"].append(forward_score)
        result_dict[f"retrieval/1:100/img_to_mol/{name}"].append(backward_score)
        result_dict[f"retrieval/1:100/avg/{name}"].append((forward_score + backward_score) / 2)

In [34]:
# Summarise the per-batch scores: mean and population std (np.std, ddof=0) across batches
# for every direction and metric, stored next to the raw lists under *_avg / *_std keys.
for metric in simple.metric_keys:
    for suffix, reduce_fn in (("avg", np.mean), ("std", np.std)):
        for direction in ("mol_to_img", "img_to_mol", "avg"):
            per_batch = result_dict[f"retrieval/1:100/{direction}/{metric}"]
            result_dict[f"retrieval/1:100/{direction}/{metric}_{suffix}"] = reduce_fn(per_batch)

In [36]:
# Raw per-batch metric values plus their *_avg / *_std summaries.
result_dict

defaultdict(list,
            {'retrieval/1:100/mol_to_img/RetrievalFallOut_top_01': [tensor(0.2500),
              tensor(0.2500),
              tensor(0.2500)],
             'retrieval/1:100/img_to_mol/RetrievalFallOut_top_01': [tensor(0.2500),
              tensor(0.2500),
              tensor(0.2500)],
             'retrieval/1:100/avg/RetrievalFallOut_top_01': [tensor(0.2500),
              tensor(0.2500),
              tensor(0.2500)],
             'retrieval/1:100/mol_to_img/RetrievalFallOut_top_05': [tensor(1.),
              tensor(1.),
              tensor(1.)],
             'retrieval/1:100/img_to_mol/RetrievalFallOut_top_05': [tensor(1.),
              tensor(1.),
              tensor(1.)],
             'retrieval/1:100/avg/RetrievalFallOut_top_05': [tensor(1.),
              tensor(1.),
              tensor(1.)],
             'retrieval/1:100/mol_to_img/RetrievalHitRate_top_01': [tensor(0.2500),
              tensor(0.2500),
              tensor(0.2500)],
             'ret

In [159]:
def concat_from_list_of_dict_to_list(res, key):
    """Concatenate ``r[key]`` across a list of per-batch dicts into a flat Python list.

    Each ``r[key]`` must be array-like along its first axis (e.g. a list of strings or
    a list of rows). The result is always a plain ``list``, as the name promises, so
    string entries come back as ``str`` and nested rows as nested lists.

    Raises:
        ValueError: if ``res`` is empty.
    """
    if not res:
        raise ValueError(f"Cannot concatenate key {key!r} from an empty list of results.")
    return np.concatenate([r[key] for r in res]).tolist()


def concat_from_list_of_dict_to_tensor(res, key):
    """Merge ``r[key]`` across a list of per-batch prediction dicts.

    - tensors are concatenated along dim 0 into a single tensor;
    - Python scalars (e.g. ``batch_idx``) are collected into a list, one per batch;
    - anything else (e.g. lists of strings) is concatenated into one flat list.

    Raises:
        ValueError: if ``res`` is empty.
    """
    if not res:
        raise ValueError(f"Cannot concatenate key {key!r} from an empty list of results.")
    first = res[0][key]
    if isinstance(first, torch.Tensor):
        return torch.cat([r[key] for r in res], dim=0)
    if isinstance(first, (int, float)):
        return [r[key] for r in res]
    return concat_from_list_of_dict_to_list(res, key)

In [52]:
# Inspect the first batch's prediction dict (dataloader_idx, batch_idx, compound_str,
# image_id, compound_emb, image_emb).
predictions[0]

{'dataloader_idx': 0,
 'batch_idx': 0,
 'compound_str': ['InChI=1S/C10H10ClN3O2/c11-7-2-1-3-8(4-7)14-10(6-16)9(5-15)12-13-14/h1-4,15-16H,5-6H2',
  'InChI=1S/C10H10ClN3O2S2/c11-9-1-2-10(17-9)18(15,16)14-4-3-7-8(5-14)13-6-12-7/h1-2,6H,3-5H2,(H,12,13)',
  'InChI=1S/C10H10N2O3S2/c1-2-5-3-6-8(15)11-10(12-9(6)17-5)16-4-7(13)14/h3H,2,4H2,1H3,(H,13,14)(H,11,12,15)',
  'InChI=1S/C10H10N4O2/c1-16-10(15)8-9(11)13-14(12-8)7-5-3-2-4-6-7/h2-6H,1H3,(H2,11,13)'],
 'image_id': ['source_6__p210928CPU2OS48hw384exp030JUMP__110000296383__M17__9',
  'source_10__2021_08_09_U2OS_48_hr_run13__Dest210727-153138__P06__1',
  'source_8__J4__A1166164__M09__4',
  'source_11__Batch5__EC000080__N08__4'],
 'compound_emb': tensor([[ 0.0163,  0.0241,  0.0165,  ...,  0.0816,  0.0436, -0.0081],
         [ 0.0281,  0.0094,  0.0201,  ...,  0.0373,  0.0158,  0.0103],
         [-0.0392, -0.0008,  0.0729,  ...,  0.0593,  0.0104, -0.0258],
         [ 0.0009,  0.0181,  0.0571,  ...,  0.0711,  0.0079, -0.0134]]),
 'image_emb': ten

In [60]:
# Retrieval over larger candidate pools: merge CHUNK_SIZE consecutive predicted batches
# into one pool, so each query ranks CHUNK_SIZE * batch_size candidates. The "1:1000"
# label assumes batches of 100. With batch size 4 and `limit_predict_batches=3`
# configured above, the pool here is much smaller (about 12 candidates). A trailing
# partial chunk gives a smaller pool that is still included in the averages.
CHUNK_SIZE = 10
RETRIEVAL_PREFIX = "retrieval/1:1000"

result_dict = defaultdict(list)

for start in range(0, len(predictions), CHUNK_SIZE):
    chunk = predictions[start : start + CHUNK_SIZE]
    # Only the embeddings are needed; skip concatenating ids / compound strings.
    image_emb = concat_from_list_of_dict_to_tensor(chunk, "image_emb")
    compound_emb = concat_from_list_of_dict_to_tensor(chunk, "compound_emb")

    # Presumably rows index images and columns index compounds (pool_size x pool_size).
    dist = simple.distance_metric(image_emb, compound_emb)
    assert dist.shape[0] == dist.shape[1], f"expected a square similarity matrix, got {tuple(dist.shape)}"
    pool_size = dist.shape[1]

    # Query id = column index -> each compound ranks all images (mol -> img).
    indexes_mol_to_img = torch.arange(pool_size).expand(dist.shape)
    # Query id = row index -> each image ranks all compounds (img -> mol).
    indexes_img_to_mol = indexes_mol_to_img.transpose(0, 1)
    # Matching (image i, compound i) pairs on the diagonal are the only positives.
    target = torch.eye(pool_size)

    res_mol_to_img = simple.retrieval_metrics(preds=dist, target=target, indexes=indexes_mol_to_img)
    res_img_to_mol = simple.retrieval_metrics(preds=dist, target=target, indexes=indexes_img_to_mol)

    for metric in simple.metric_keys:
        result_dict[f"{RETRIEVAL_PREFIX}/mol_to_img/{metric}"].append(res_mol_to_img[metric])
        result_dict[f"{RETRIEVAL_PREFIX}/img_to_mol/{metric}"].append(res_img_to_mol[metric])
        result_dict[f"{RETRIEVAL_PREFIX}/avg/{metric}"].append((res_mol_to_img[metric] + res_img_to_mol[metric]) / 2)

# Mean and population std (np.std, ddof=0) across pools for every direction and metric.
for metric in simple.metric_keys:
    for stat, reduce_fn in (("avg", np.mean), ("std", np.std)):
        for direction in ("mol_to_img", "img_to_mol", "avg"):
            result_dict[f"{RETRIEVAL_PREFIX}/{direction}/{metric}_{stat}"] = reduce_fn(
                result_dict[f"{RETRIEVAL_PREFIX}/{direction}/{metric}"]
            )

## HINT evaluation (clinical-trial outcome prediction)

In [11]:
# Configure the HINT phase-I evaluator: batch size 64 on a single device.
phase1_cfg = cfg.eval.phaseI
phase1_cfg.datamodule.batch_size = 64
phase1_cfg.trainer.devices = 1

In [37]:
# Build the HINT phase-I evaluator from its eval config and the shared model config.
hint = utils.instantiate_evaluator(cfg.eval.phaseI, cfg.model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [44]:
# Set up the phase-I datamodule so its datasets can be inspected.
hint.datamodule.setup()

In [46]:
# Peek at the trial records backing the training dataset.
# NOTE(review): `valid_df` presumably keeps only trials whose SMILES parsed (see the
# `valid_smiles` column) — verify against the dataset class.
hint.datamodule.train_dataset.valid_df

Unnamed: 0,nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria,valid_smiles
1,NCT01046487,completed,,1,phase 1,['cancer'],"[""['C05.2', 'C10.0', 'C16.0', 'C16.4', 'C17.0'...","['imatinib mesylate, cyclophosphamide (dosing ...",['CC1=NC(NC2=NC=C(S2)C(=O)NC2=C(C)C=CC=C2Cl)=C...,\n Inclusion Criteria:\n\n - ...,[Cc1nc(Nc2ncc(C(=O)Nc3c(C)cccc3Cl)s2)cc(N2CCN(...
2,NCT01381887,completed,,1,phase 1,"['diabetes mellitus, type 2']","[""['E11.65', 'E11.9', 'E11.21', 'E11.36', 'E11...","['placebo', 'canagliflozin 300mg/placebo', 'ca...",['CN1C(=O)C=C(N2CCC[C@@H](N)C2)N(CC2=C(C=CC=C2...,\n Inclusion Criteria:\n\n - ...,[Cn1c(=O)cc(N2CCC[C@@H](N)C2)n(Cc2ccccc2C#N)c1...
3,NCT02015676,completed,,1,phase 1/phase 2,['breast cancer'],"[""['C79.81', 'D24.1', 'D24.2', 'D24.9', 'D49.3...","['trastuzumab', 'paclitaxel', 'myocet']",['[H][N]1([H])[C@@H]2CCCC[C@H]2[N]([H])([H])[P...,\n Inclusion Criteria:\n\n - ...,[CO[C@H]1C[C@@H]2CC[C@@H](C)[C@@](O)(O2)C(=O)C...
4,NCT01813955,terminated,\n patient recruitment insufficient\n,0,early phase 1,"['schizophrenia', 'cognitive deficits']","[""['F20.0', 'F20.1', 'F20.2', 'F20.3', 'F20.5'...",['papaverine or placebo'],['COC1=C(OC)C=C(CC2=NC=CC3=CC(OC)=C(OC)C=C23)C...,\n Inclusion Criteria:\n\n - ...,[COc1ccc(Cc2nccc3cc(OC)c(OC)cc23)cc1OC]
5,NCT01213160,completed,,0,phase 1,"['cancer', 'advanced solid malignancies']","[""['C05.2', 'C10.0', 'C16.0', 'C16.4', 'C17.0'...",['azd4547'],['COC1=CC(OC)=CC(CCC2=CC(NC(=O)C3=CC=C(C=C3)N3...,\n Inclusion Criteria:\n\n - Jap...,[COc1cc(CCc2cc(NC(=O)c3ccc(N4C[C@@H](C)N[C@@H]...
...,...,...,...,...,...,...,...,...,...,...,...
1039,NCT01434225,completed,,0,phase 1/phase 2,['neonatal seizures'],"[""['E71.511', 'P29.11', 'P29.12', 'P29.2', 'P3...",['bumetanide'],['CCCCNC1=C(OC2=CC=CC=C2)C(=CC(=C1)C(O)=O)S(N)...,\n Inclusion Criteria:-\n\n - ...,[CCCCNc1cc(C(=O)O)cc(S(N)(=O)=O)c1Oc1ccccc1]
1040,NCT01476137,completed,,0,phase 1,['cancer'],"[""['C05.2', 'C10.0', 'C16.0', 'C16.4', 'C17.0'...","['gsk1120212', 'gsk2110183']",['CN1C(=O)C(C)=C2N(C(=O)N(C3CC3)C(=O)C2=C1NC1=...,\n Inclusion Criteria for Part 1:\n\n ...,[CC(=O)Nc1cccc(-n2c(=O)n(C3CC3)c(=O)c3c(Nc4ccc...
1041,NCT01676233,completed,,1,phase 1,['type 1 diabetes mellitus'],"[""['E10.65', 'E10.9', 'E10.21', 'E10.36', 'E10...","['insulin glargine (hoe901)', 'insulin glargin...","['[Na+].[Na+].[O-]P([O-])(F)=O', '[Na+].[Na+]....",\n Inclusion criteria :\n\n - ...,"[O=P([O-])([O-])F.[Na+].[Na+], O=P([O-])([O-])..."
1042,NCT00331630,completed,,1,early phase 1,['breast cancer'],"[""['C79.81', 'D24.1', 'D24.2', 'D24.9', 'D49.3...","['lapatinib ditosylate', 'paclitaxel albumin-s...",['CS(=O)(=O)CCNCC1=CC=C(O1)C1=CC2=C(C=C1)N=CN=...,\n DISEASE CHARACTERISTICS:\n\n ...,[CS(=O)(=O)CCNCc1ccc(-c2ccc3ncnc(Nc4ccc(OCc5cc...


In [48]:
# Run the phase-I evaluation end to end; the log below tracks `hint/phase_I/val/loss`.
hint.run()

No WandbLogger found. WandbCallback will not log anything.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Metric hint/phase_I/val/loss improved. New best score: 0.698


Metric hint/phase_I/val/loss improved by 0.003 >= min_delta = 0. New best score: 0.695


Metric hint/phase_I/val/loss improved by 0.006 >= min_delta = 0. New best score: 0.689


No WandbLogger found. WandbCallback will not log anything.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [73]:
# Set up the datamodule again before building dataloaders for manual debugging.
hint.datamodule.setup()

In [76]:
# Take one training batch for a manual forward/loss check.
dl = hint.datamodule.train_dataloader()
b = next(iter(dl))

In [79]:
# Work on a deep copy so manual forward passes cannot mutate the original batch.
b1 = copy.deepcopy(b)

In [98]:
# Manual forward pass through the HINT model: embed the batch's SMILES lists
# (presumably one list of SMILES per trial — confirm), score with the head, then
# compute the training criterion.
smiles_list = b1["smiles_list"]
targets = b1["label"]

compound_embeddings = hint.model.forward_smiles_lst_lst(smiles_list)
logits = hint.model.head(compound_embeddings)

loss = hint.model.criterion(logits, targets)

In [100]:
# Compare the raw 2-class logits with the integer labels.
logits, targets

(tensor([[ 0.1056,  0.0534],
         [ 0.0530, -0.0058],
         [ 0.0882, -0.0338],
         [ 0.0678,  0.0005]], grad_fn=<AddmmBackward0>),
 tensor([0, 0, 0, 1]))

In [96]:
# Expect (batch_size, 2).
logits.shape

torch.Size([4, 2])

## OGB evaluations

In [6]:
# Smoke-test settings for the OGB HIV evaluator: batch size 4, one device, and only
# 3 training batches (`limit_train_batches` must be added inside `open_dict`).
hiv_cfg = cfg.eval.hiv
hiv_cfg.datamodule.batch_size = 4
hiv_cfg.trainer.devices = 1

with open_dict(hiv_cfg.trainer):
    hiv_cfg.trainer.limit_train_batches = 3

In [129]:
# Build the OGB HIV evaluator.
# NOTE(review): the variable is named `tox` although it holds the HIV evaluator — consider renaming.
tox = utils.instantiate_evaluator(cfg.eval.hiv, cfg.model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [130]:
# Set up the HIV datamodule before requesting dataloaders.
tox.datamodule.setup()

In [131]:
# Take one HIV training batch for a manual forward/loss check.
dl = tox.datamodule.train_dataloader()
b = next(iter(dl))

In [139]:
# Manual forward pass on the HIV batch `b`. `targets`, `logits` and `loss` are
# recomputed here so this cell does not depend on a deleted cell; otherwise, on a fresh
# kernel, it would silently display the HINT values from the section above.
# NOTE(review): assumes the HIV evaluator exposes the same `model.model` /
# `model.criterion` API and batch keys ("compound", "label") as the ESOL cells — confirm.
b1 = copy.deepcopy(b)
targets = b1["label"]
logits = tox.model.model(b1["compound"])
loss = tox.model.criterion(logits, targets)
targets, logits

(tensor([1, 1, 0, 0], dtype=torch.int32),
 tensor([[ 0.0538,  0.0796],
         [-0.2209, -0.0526],
         [-0.1225, -0.1160],
         [ 0.0763, -0.0974]], grad_fn=<AddmmBackward0>))

In [144]:
# Display `loss` from the most recent manual forward pass.
loss

tensor(0.7409, grad_fn=<NllLossBackward0>)

In [118]:
# NOTE(review): the output below is stale — this cell ran as In [118], before the HIV
# cells above, and its 12-column targets do not match the (4, 2) logits. Re-run to refresh.
logits.shape, targets

(torch.Size([4, 2]),
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

In [7]:
# Smoke-test settings for the OGB ESOL evaluator: batch size 4, one device, and only
# 3 training batches (`limit_train_batches` must be added inside `open_dict`).
esol_cfg = cfg.eval.esol
esol_cfg.datamodule.batch_size = 4
esol_cfg.trainer.devices = 1

with open_dict(esol_cfg.trainer):
    esol_cfg.trainer.limit_train_batches = 3

In [15]:
# Build the OGB ESOL evaluator from its eval config and the shared model config.
esol = utils.instantiate_evaluator(cfg.eval.esol, cfg.model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [16]:
# Set up the ESOL datamodule before requesting dataloaders.
esol.datamodule.setup()

In [17]:
# Take one ESOL training batch for a manual forward/loss check.
dl = esol.datamodule.train_dataloader()
b = next(iter(dl))

In [18]:
# Manual forward pass for the ESOL regression task, on a deep copy of the batch so the
# original `b` stays untouched.
b1 = copy.deepcopy(b)

compound = b1["compound"]
targets = b1["label"]

logits = esol.model.model(compound)

loss = esol.model.criterion(logits, targets)

In [20]:
# Single-output predictions vs. regression targets.
logits, targets

(tensor([[-0.0000],
         [-0.2658],
         [-0.7437],
         [ 0.1805]], grad_fn=<MulBackward0>),
 tensor([[-1.3900],
         [-3.3900],
         [-1.6400],
         [ 0.5400]]))

In [21]:

# NOTE(review): this cell repeats the three cells below (one call each). If these metric
# objects are stateful (e.g. torchmetrics accumulating across calls), running both this
# cell and the following ones updates their state twice — keep only one version.
esol.model.loss_dict["train"](loss)
esol.model.plot_metrics_dict["train"](logits, targets)
other_metrics = esol.model.other_metrics_dict["train"](logits, targets)

In [22]:

# Feed the manual loss into the training-split loss tracker and show the returned value.
esol.model.loss_dict["train"](loss)

tensor(3.1564, grad_fn=<SqueezeBackward0>)

In [23]:
# Plot metrics for the training split (empty for ESOL, per the output below).
esol.model.plot_metrics_dict["train"](logits, targets)

{}

In [24]:
# Regression metrics (MSE, MAE, R2) for the manual batch.
other_metrics = esol.model.other_metrics_dict["train"](logits, targets)

In [25]:
# Display the regression metrics computed above.
other_metrics

{'ogb/esol/train/MeanSquaredError': tensor(3.1564, grad_fn=<SqueezeBackward0>),
 'ogb/esol/train/MeanAbsoluteError': tensor(1.4425, grad_fn=<SqueezeBackward0>),
 'ogb/esol/train/R2Score': tensor(-0.6266, grad_fn=<SqueezeBackward0>)}