# Run the LNN model

First, we have to create the PyTorch objects out of the NPZ files. NPZ files behave like dictionaries of arrays. In our case, they contain two keys:

- `X`: the featurized systems
- `y`: the associated measurements

We can pass those dict-like archives to an adapter class that exposes them as PyTorch Datasets, which the DataLoaders then consume. We also need the observation model corresponding to each measurement type.

In [1]:
# Enable IPython's autoreload so that edits to imported modules are reloaded
# before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Experiment configuration: the knobs a reader is most likely to tune.
DATASET: str = "PKIS2"        # key into the per-dataset tables in the next cell ("ChEMBL" or "PKIS2")
LEARNING_RATE: float = 5e-3   # learning rate passed to the Adam optimizer
MAX_EPOCHS: int = 100         # intended maximum number of training epochs per fold
N_SPLITS: int = 5             # intended number of cross-validation folds
ITEMS_PER_ROW: int = 3        # number of plots per row in the analysis grid

In [3]:
# Per-dataset settings: (measurement types to model, single kinase to train on).
# Indexing with an unsupported DATASET raises KeyError.
_DATASET_SETTINGS = {
    "ChEMBL": (["pKiMeasurement", "pIC50Measurement", "pKdMeasurement"], "P35968"),
    "PKIS2": (["PercentageDisplacementMeasurement"], "ABL2"),
}
MEASUREMENT_TYPES, ONE_KINASE = _DATASET_SETTINGS[DATASET]

In [4]:
from pathlib import Path
from collections import defaultdict
import numpy as np
import shutil
import time

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import pytorch_lightning as pl

from kinoml.utils import seed_everything
from kinoml.core import measurements as measurement_types
from kinoml.datasets.torch_datasets import XyNpzTorchDataset

# Working directory of the notebook, taken from IPython's directory history.
HERE = Path(_dh[-1])
_trial = 0  # NOTE(review): not referenced anywhere else in this notebook
# One output folder per run, named after the current Unix time in whole seconds.
OUT = HERE.joinpath("_output", DATASET, f"{time.time():.0f}")
OUT.mkdir(parents=True, exist_ok=True)
print(f"Reporting results at path: {OUT}")
# Seed all RNGs so the random train/val/test splits are identical across runs;
# otherwise each run evaluates on different groups and results are not comparable.
seed_everything()



Reporting results at path: /home/jaime/devel/py/openkinome/experiments-binding-affinity/ligand-based/MorganFingerprint/LNN/_output/PKIS2/1605109764


## Load featurized data and create observation models

In [5]:
# Load every featurized NPZ file of this dataset into a nested mapping
#   datasets[kinase][measurement_type] -> XyNpzTorchDataset
# File stems are split on "__" into exactly three parts: <DATASET>__<kinase>__<measurement_type>.
# sorted() makes the loading order independent of the filesystem's directory order.
datasets = defaultdict(dict)
for npz in sorted(HERE.glob(f"../_output/{DATASET}__*.npz")):
    _, kinase, measurement_type = npz.stem.split("__")
    datasets[kinase][measurement_type] = ds = XyNpzTorchDataset(npz)
    # Optionally override the split indices with contiguous fractions [0, a), [a, b), [b, 1]:
    # a, b = 0.6, 0.7
    # ds.indices = {
    #     "train": list(range(0, int(a * len(ds)))),
    #     "test": list(range(int(a * len(ds)), int(b * len(ds)))),
    #     "val": list(range(int(b * len(ds)), len(ds))),
    # }
# Fail here with a clear message instead of a KeyError several cells later.
assert datasets, f"No featurized files matching {DATASET}__*.npz found in {HERE / '../_output'}"

In [6]:
# Map each measurement type name to its PyTorch observation model function.
obs_models = dict(
    (name, getattr(measurement_types, name).observation_model(backend="pytorch"))
    for name in MEASUREMENT_TYPES
)
obs_models

{'PercentageDisplacementMeasurement': <function kinoml.core.measurements.PercentageDisplacementMeasurement._observation_model_pytorch(dG_over_KT, inhibitor_conc=1, standard_conc=1, **kwargs)>}

Now that we have all the data-dependent objects, we can start with the model-specific definitions.

## Train the model

In [7]:
from kinoml.ml.torch_models import NeuralNetworkRegression
from kinoml.ml.lightning_modules import ObservationModelModule, CrossValidateTrainer, MultiDataModule
from pytorch_lightning import callbacks as plcb

In [8]:
# One dataset + observation model per measurement type, all for the selected kinase.
datamodule = MultiDataModule(
    datasets=[datasets[ONE_KINASE][mtype] for mtype in MEASUREMENT_TYPES],
    observation_models=[obs_models[mtype] for mtype in MEASUREMENT_TYPES],
    batch_size=128,
    num_workers=1,
)

# Configure callbacks
# Stop a fold once val_loss has not improved by at least min_delta for `patience` epochs.
early_stopping = plcb.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00001,
    patience=10,
    mode="min",
)
# Keep the 5 checkpoints with the lowest val_loss, plus the last one.
checkpoints = plcb.ModelCheckpoint(
    filepath=OUT / "chk-{epoch}-{val_loss:.4f}",
    monitor="val_loss",
    mode="min",
    save_top_k=5,
    save_last=True,
)

# Configure trainer. Use the config-cell constants so changing N_SPLITS / MAX_EPOCHS
# actually takes effect (they were previously duplicated here as literals).
trainer = CrossValidateTrainer(
    nfolds=N_SPLITS,
    max_epochs=MAX_EPOCHS,
    callbacks=[early_stopping],
    checkpoint_callback=checkpoints,
    logger=pl.loggers.TensorBoardLogger(OUT / "tensorboard_logs", name=""),
)

# Set up the network; input size comes from the first measurement type's dataset.
input_size = datasets[ONE_KINASE][MEASUREMENT_TYPES[0]].input_size()
nn_model = NeuralNetworkRegression(input_size=input_size, hidden_size=350)

# Configure Lightning adapter module
module = ObservationModelModule(
    nn_model=nn_model,
    optimizer=torch.optim.Adam(nn_model.parameters(), lr=LEARNING_RATE),
    loss_function=torch.nn.MSELoss(),
)

# Run loop: first over datamodules (measurement types), then over kfolds
# TODO: Assess strategy? We start with smallest datasets first!
trainer.fit(model=module, datamodule=datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     
4 | metric_rmse   | RootMeanSquaredError    | 0     


DS #0 PercentageDisplacementMeasurement, fold=0




HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     
4 | metric_rmse   | RootMeanSquaredError    | 0     



DS #0 PercentageDisplacementMeasurement, fold=1




HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     
4 | metric_rmse   | RootMeanSquaredError    | 0     



DS #0 PercentageDisplacementMeasurement, fold=2




HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     
4 | metric_rmse   | RootMeanSquaredError    | 0     



DS #0 PercentageDisplacementMeasurement, fold=3




HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     
4 | metric_rmse   | RootMeanSquaredError    | 0     



DS #0 PercentageDisplacementMeasurement, fold=4




HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...





## Performance on the test set

In [9]:
import pandas as pd
# Wait on https://github.com/PyTorchLightning/pytorch-lightning/pull/4480 to use multiple dataloaders;
# until then, test one dataset at a time (reverse=True: presumably largest first).
for index in datamodule.dataset_indices_by_size(reverse=True):
    label = datamodule.measurement_types[index]
    print(f"Performance for {label}")
    test_summary = trainer.test(datamodule=datamodule, dataset_index=index)
    display(pd.DataFrame.from_dict(test_summary))
    print(f"^ Performance for {label}", "", "*************************************************", "", sep="\n")

Performance for PercentageDisplacementMeasurement
Test results for DS #0 for fold 0




HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_MAE': tensor(24.7645),
 'test_MSE': tensor(1119.5229),
 'test_R2': -2.009995422712897,
 'test_RMSE': tensor(33.4593),
 'test_loss': tensor(1119.5229),
 'train_loss': tensor(117.5616),
 'val_MAE': tensor(14.4190),
 'val_MSE': tensor(487.2249),
 'val_R2': 0.05994957596864359,
 'val_RMSE': tensor(22.0732),
 'val_loss': tensor(487.2249)}
--------------------------------------------------------------------------------

Test results for DS #0 for fold 1




HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_MAE': tensor(15.7533),
 'test_MSE': tensor(561.6644),
 'test_R2': -0.3694760485069408,
 'test_RMSE': tensor(23.6995),
 'test_loss': tensor(561.6644),
 'train_loss': tensor(96.5852),
 'val_MAE': tensor(14.8351),
 'val_MSE': tensor(447.6775),
 'val_R2': 0.1019282443445757,
 'val_RMSE': tensor(21.1584),
 'val_loss': tensor(447.6775)}
--------------------------------------------------------------------------------

Test results for DS #0 for fold 2




HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_MAE': tensor(15.7795),
 'test_MSE': tensor(558.8057),
 'test_R2': -0.08904228651826274,
 'test_RMSE': tensor(23.6391),
 'test_loss': tensor(558.8057),
 'train_loss': tensor(50.0401),
 'val_MAE': tensor(16.8695),
 'val_MSE': tensor(590.9111),
 'val_R2': 0.03618130896518823,
 'val_RMSE': tensor(24.3087),
 'val_loss': tensor(590.9111)}
--------------------------------------------------------------------------------

Test results for DS #0 for fold 3




HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_MAE': tensor(15.1360),
 'test_MSE': tensor(517.9116),
 'test_R2': 0.12734836221493873,
 'test_RMSE': tensor(22.7577),
 'test_loss': tensor(517.9116),
 'train_loss': tensor(48.9267),
 'val_MAE': tensor(14.2085),
 'val_MSE': tensor(499.5436),
 'val_R2': 0.08068097863242496,
 'val_RMSE': tensor(22.3505),
 'val_loss': tensor(499.5436)}
--------------------------------------------------------------------------------

Test results for DS #0 for fold 4




HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_MAE': tensor(14.0250),
 'test_MSE': tensor(447.6107),
 'test_R2': -0.06183302508871713,
 'test_RMSE': tensor(21.1568),
 'test_loss': tensor(447.6107),
 'train_loss': tensor(81.6823),
 'val_MAE': tensor(13.3604),
 'val_MSE': tensor(434.2366),
 'val_R2': 0.2002503164448982,
 'val_RMSE': tensor(20.8383),
 'val_loss': tensor(434.2366)}
--------------------------------------------------------------------------------



Unnamed: 0,val_R2,val_MAE,val_MSE,val_RMSE,val_loss,train_loss,test_R2,test_MAE,test_MSE,test_RMSE,test_loss
mean,0.095798,14.738495,491.918738,22.145808,491.918738,78.959171,-0.4806,17.091681,641.10307,24.942458,641.10307
std,0.056596,1.16898,55.071199,1.217339,55.071199,26.632692,0.780984,3.888828,242.727988,4.356247,242.727988


^ Performance for PercentageDisplacementMeasurement

*************************************************



In [10]:
%load_ext tensorboard
%tensorboard --logdir {best_run.logger.log_dir}

Reusing TensorBoard on port 6024 (pid 21716), started 6 days, 23:59:14 ago. (Use '!kill 21716' to kill it.)

Save the best run's checkpoint under an easy-to-remember path for the next section.

In [11]:
# Copy the best checkpoint of the best run to a fixed path, OUT / "best.ckpt";
# the returned destination path is displayed as the cell output.
best_run = trainer.best_run()
best_checkpoint = best_run.checkpoint_callback.best_model_path
shutil.copy(best_checkpoint, OUT / "best.ckpt")

PosixPath('/home/jaime/devel/py/openkinome/experiments-binding-affinity/ligand-based/MorganFingerprint/LNN/_output/PKIS2/1605109764/best.ckpt')

## Analysis of the best model

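Reload the best model from the copied checkpoint. The non-weight constructor arguments (network, optimizer, loss) must be passed again; only the weights are taken from the checkpoint.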

In [12]:
# Rebuild the network first so the optimizer is attached to THIS network's parameters.
# (Previously the optimizer wrapped the training-time `nn_model`, not the reloaded one.)
best_nn_model = NeuralNetworkRegression(input_size=input_size, hidden_size=350)
bestmodel = ObservationModelModule.load_from_checkpoint(
    str(OUT / "best.ckpt"),
    # We need to re-specify the additional arguments upon checkpoint; weights will be taken from ckpt
    # See why: https://github.com/PyTorchLightning/pytorch-lightning/pull/1896#issue-420336432
    nn_model=best_nn_model,
    optimizer=torch.optim.Adam(best_nn_model.parameters(), lr=LEARNING_RATE),
    loss_function=torch.nn.MSELoss(),
)

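Below, for every cross-validation fold's model, we plot predicted vs. observed values on that fold's train, validation and test subsets, and then average the test metrics across folds. This is just for illustrative purposes.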

In [24]:
from ipywidgets import HBox, VBox, Output, HTML
from kinoml.analysis.plots import predicted_vs_observed, performance

for index in datamodule.dataset_indices_by_size(reverse=True):
    plots, metrics = [], []
    mtype = datamodule.measurement_types[index]
    mtype_class = getattr(measurement_types, mtype)
    obs_model = datamodule.observation_models[index]
    display(HTML(f"<h3>{mtype}</h3>"))
    # NOTE(review): dataloaders are looked up by split and fold only, not by dataset `index`;
    # that is only correct while a single measurement type is modelled -- confirm.
    for fold, model in enumerate(trainer._models):
        for ttype in ["train", "val", "test"]:
            # Bind the loader for this split/fold explicitly; previously `dataloader` was
            # never defined in this notebook and only worked through leftover kernel state.
            dataloader = trainer._dataloaders[ttype][fold]
            indices = dataloader.sampler.indices
            model_input = dataloader.dataset.data_X[indices]
            observed = dataloader.dataset.data_y[indices]
            with torch.no_grad():  # inference only: no autograd graph needed
                prediction = model(model_input, observation_model=obs_model).detach().numpy()
            output = Output()
            with output:
                title = f"Fold {fold}, {ttype}={observed.shape[0]}"
                print(title)
                print("-" * len(title))
                these_metrics = performance(prediction, observed)
                display(predicted_vs_observed(prediction, observed, mtype_class, n_boot=100, sample_ratio=0.75, with_metrics=False))
            plots.append(output)

            if ttype == "test":
                metrics.append(these_metrics)
    # Pad with empty widgets so every row of the grid has exactly ITEMS_PER_ROW items
    for _ in range((ITEMS_PER_ROW - (len(plots) % ITEMS_PER_ROW)) % ITEMS_PER_ROW):
        plots.append(Output())
    # Lay the plots out in rows of ITEMS_PER_ROW (plain slicing; no numpy object arrays of widgets)
    rows = [plots[start:start + ITEMS_PER_ROW] for start in range(0, len(plots), ITEMS_PER_ROW)]
    display(VBox([HBox(row) for row in rows]))

    # Average the per-fold test metrics (np.std: population std, ddof=0)
    if not metrics:
        continue  # no folds were trained; nothing to average
    average = {}
    for key in metrics[0]:
        values = [fold_metrics[key] for fold_metrics in metrics]
        average[key] = {
            "mean": np.mean(values),
            "std": np.std(values),
        }
    display(HTML("Bootstrapped averages across folds (test):"))
    display(pd.DataFrame.from_dict(average))

HTML(value='<h3>PercentageDisplacementMeasurement</h3>')

VBox(children=(HBox(children=(Output(), Output(), Output())), HBox(children=(Output(), Output(), Output())), H…

HTML(value='Bootstrapped averages across folds (test):')

Unnamed: 0,mae,mse,r2,rmse
mean,13.278384,515.21776,-0.03003,19.188776
std,7.775405,346.67717,0.539869,10.794567


In [None]:
# Emit kinoml's environment watermark (presumably package/version info -- confirm) for provenance.
from kinoml.utils import watermark
watermark()