# Run the LNN model

First, we have to create the PyTorch objects out of the NPZ files. NPZ files behave like dictionaries of arrays. In our case, they contain two keys:

- `X`: the featurized systems
- `y`: the associated measurements

We can pass those dict-like arrays to an adapter class for Torch Datasets, which will be ingested by the DataLoaders. We also need the corresponding observation models.

In [1]:
# Name of the dataset to train on; used as the lookup key for the per-dataset settings below
DATASET = "PKIS2"
# Step size handed to the optimizer
LEARNING_RATE = 0.005
# Maximum number of training epochs (keep in sync with the trainer's `max_epochs`)
MAX_EPOCHS = 100

In [2]:
# Measurement types available for each dataset; only the entry for DATASET is kept
MEASUREMENT_TYPES = dict(
    ChEMBL=["pKiMeasurement", "pIC50Measurement", "pKdMeasurement"],
    PKIS2=["PercentageDisplacementMeasurement"],
)[DATASET]

# Kinase whose data is used for training, one entry per dataset
ONE_KINASE = dict(ChEMBL="P35968", PKIS2="ABL2")[DATASET]

In [3]:
from pathlib import Path
from collections import defaultdict
import numpy as np
import shutil
import time

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import pytorch_lightning as pl

from kinoml.utils import seed_everything
from kinoml.core import measurements as measurement_types
from kinoml.datasets.torch_datasets import XyNpzTorchDataset

# `_dh` is IPython's directory history; its last entry is used as the notebook's directory
HERE = Path(_dh[-1])
_trial = 0
# One output folder per run, named after the current Unix timestamp (whole seconds)
OUT = HERE.joinpath("_output", DATASET, f"{time.time():.0f}")
OUT.mkdir(exist_ok=True, parents=True)
print("Reporting results at path:", OUT)
# Seed the RNGs so the random train/test splits are reproducible; otherwise the groups
# get mixed differently on every run, which biases the model evaluation
seed_everything()



Reporting results at path: /home/jaime/devel/py/openkinome/experiments-binding-affinity/ligand-based/MorganFingerprint/LNN/_output/PKIS2/1604508433


## Load featurized data and create observation models

In [4]:
# Nested mapping: kinase -> measurement type -> Torch dataset
datasets = defaultdict(dict)
for npz in HERE.glob(f"../_output/{DATASET}__*.npz"):
    # The file stem is split on "__" into three parts; the first one is discarded
    _, kinase, measurement_type = npz.stem.split("__")
    ds = XyNpzTorchDataset(npz)
    datasets[kinase][measurement_type] = ds
    # To use custom split boundaries, override the indices here: [0, a], [a, b], [b, 1]
    # a, b = 0.6, 0.7
    # ds.indices = {
    #     "train": list(range(0, int(a * len(ds)))),
    #     "test": list(range(int(a * len(ds)), int(b * len(ds)))),
    #     "val": list(range(int(b * len(ds)), len(ds))),
    # }

In [5]:
# Map each measurement type name to the PyTorch flavour of its observation model
obs_models = {}
for mtype_name in MEASUREMENT_TYPES:
    measurement_class = getattr(measurement_types, mtype_name)
    obs_models[mtype_name] = measurement_class.observation_model(backend="pytorch")
obs_models

{'PercentageDisplacementMeasurement': <function kinoml.core.measurements.PercentageDisplacementMeasurement._observation_model_pytorch(dG_over_KT, inhibitor_conc=1, standard_conc=1, **kwargs)>}

Now that we have all the data-dependent objects, we can start with the model-specific definitions.

## Train the model

In [6]:
from kinoml.ml.torch_models import NeuralNetworkRegression
from kinoml.ml.lightning_modules import ObservationModelModule, CrossValidateTrainer, MultiDataModule
from pytorch_lightning import callbacks as plcb

In [7]:
# Bundle the datasets of the selected kinase with their observation models,
# one entry per measurement type (both lists share the same order)
datamodule = MultiDataModule(
    datasets=[datasets[ONE_KINASE][mtype] for mtype in MEASUREMENT_TYPES],
    observation_models=[obs_models[mtype] for mtype in MEASUREMENT_TYPES],
    batch_size=128, num_workers=4,
)

# Configure callbacks
# Stop a run once `val_loss` has not improved for 10 consecutive checks
early_stopping = plcb.EarlyStopping(
    monitor="val_loss",
    min_delta=0.000000,
    patience=10,
    mode="min",
)
# Keep the 5 checkpoints with the lowest `val_loss`, plus the last one, under OUT
checkpoints = plcb.ModelCheckpoint(
    filepath=OUT / "chk-{epoch}-{val_loss:.4f}",
    monitor="val_loss",
    mode="min",
    save_top_k=5,
    save_last=True,
)

# Configure trainer
trainer = CrossValidateTrainer(
    nfolds=5,
    # Use the notebook-level constant instead of a duplicated literal so that
    # changing MAX_EPOCHS in the configuration cell actually takes effect
    max_epochs=MAX_EPOCHS,
    callbacks=[early_stopping],
    checkpoint_callback=checkpoints,
    logger=pl.loggers.TensorBoardLogger(OUT / "tensorboard_logs", name="")
)

# Set up the network; the input size is read from the first measurement type's dataset
input_size = datasets[ONE_KINASE][MEASUREMENT_TYPES[0]].input_size()
nn_model = NeuralNetworkRegression(input_size=input_size, hidden_size=350)

# Configure Lightning adapter module
module = ObservationModelModule(
    nn_model=nn_model,
    optimizer=torch.optim.Adam(nn_model.parameters(), lr=LEARNING_RATE),
    loss_function=torch.nn.MSELoss(),
)

# Run loop: first over datamodules (measurement types), then over kfolds
# TODO: Assess strategy? We start with smallest datasets first!
trainer.fit(model=module, datamodule=datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     


DS #0 PercentageDisplacementMeasurement, fold=0


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     



DS #0 PercentageDisplacementMeasurement, fold=1


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     



DS #0 PercentageDisplacementMeasurement, fold=2


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     



DS #0 PercentageDisplacementMeasurement, fold=3


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type                    | Params
----------------------------------------------------------
0 | nn_model      | NeuralNetworkRegression | 179 K 
1 | loss_function | MSELoss                 | 0     
2 | metric_mae    | MeanAbsoluteError       | 0     
3 | metric_mse    | MeanSquaredError        | 0     



DS #0 PercentageDisplacementMeasurement, fold=4


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...





## Performance on the test set

In [8]:
import pandas as pd
# Wait on https://github.com/PyTorchLightning/pytorch-lightning/pull/4480 to use multiple dataloaders
for index in datamodule.dataset_indices_by_size(reverse=True):
    print(f"Performance for {datamodule.measurement_types[index]}")
    # Point the datamodule at the dataset being evaluated before testing
    datamodule.active_dataset_index = index
    test_results = trainer.test(datamodule=datamodule, dataset_index=index)
    display(pd.DataFrame.from_dict(test_results))
    print(f"^ Performance for {datamodule.measurement_types[index]}")
    # Blank line, separator, blank line
    print("\n*************************************************\n")

Performance for PercentageDisplacementMeasurement


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'MAE': tensor(9.4969),
 'MSE': tensor(246.8831),
 'R2': tensor(0.3900),
 'test_loss': tensor(246.8831),
 'train_loss': tensor(79.9119),
 'val_loss': tensor(1081.7161)}
--------------------------------------------------------------------------------



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'MAE': tensor(7.6515),
 'MSE': tensor(198.7881),
 'R2': tensor(0.4483),
 'test_loss': tensor(198.7881),
 'train_loss': tensor(28.7918),
 'val_loss': tensor(567.1509)}
--------------------------------------------------------------------------------



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'MAE': tensor(6.2960),
 'MSE': tensor(147.3171),
 'R2': tensor(0.6064),
 'test_loss': tensor(147.3171),
 'train_loss': tensor(21.6361),
 'val_loss': tensor(318.4663)}
--------------------------------------------------------------------------------



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'MAE': tensor(6.7895),
 'MSE': tensor(135.6022),
 'R2': tensor(0.6837),
 'test_loss': tensor(135.6022),
 'train_loss': tensor(42.5975),
 'val_loss': tensor(546.5746)}
--------------------------------------------------------------------------------



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'MAE': tensor(7.8563),
 'MSE': tensor(195.3854),
 'R2': tensor(0.5234),
 'test_loss': tensor(195.3854),
 'train_loss': tensor(62.5729),
 'val_loss': tensor(448.7876)}
--------------------------------------------------------------------------------



Unnamed: 0,val_loss,train_loss,test_loss,R2,MAE,MSE
mean,592.539099,47.102034,184.795203,0.530352,7.618028,184.795203
std,259.935302,21.540224,39.968525,0.10564,1.097487,39.968525


^ Performance for PercentageDisplacementMeasurement

*************************************************



In [9]:
%load_ext tensorboard
%tensorboard --logdir {best_run.logger.log_dir}

Reusing TensorBoard on port 6024 (pid 21716), started 0:56:12 ago. (Use '!kill 21716' to kill it.)

Save the best run's checkpoint under an easy-to-remember path for the next section.

In [10]:
# Copy the best run's top checkpoint to a fixed file name inside OUT
best_run = trainer.best_run()
best_checkpoint_path = best_run.checkpoint_callback.best_model_path
shutil.copy(best_checkpoint_path, OUT / "best.ckpt")

PosixPath('/home/jaime/devel/py/openkinome/experiments-binding-affinity/ligand-based/MorganFingerprint/LNN/_output/PKIS2/1604508433/best.ckpt')

## Analysis of the best model

Measure performance against all data

In [11]:
# Fresh network with the same architecture as the one trained above; its weights
# will be taken from the checkpoint
best_nn_model = NeuralNetworkRegression(input_size=input_size, hidden_size=350)
bestmodel = ObservationModelModule.load_from_checkpoint(
    str(OUT / "best.ckpt"),
    # We need to re-specify the additional arguments upon checkpoint; weights will be taken from ckpt
    # See why: https://github.com/PyTorchLightning/pytorch-lightning/pull/1896#issue-420336432
    nn_model=best_nn_model,
    # Bind the optimizer to the restored network's parameters, not to the
    # `nn_model` left over from the training cell
    optimizer=torch.optim.Adam(best_nn_model.parameters(), lr=LEARNING_RATE),
    loss_function=torch.nn.MSELoss(),
)

The plots below show the performance on the whole dataset. Note that the model never saw the `test` group, which was held out before the KFold split.

In [12]:
from ipywidgets import HBox, VBox, Output
from kinoml.analysis.plots import predicted_vs_observed
# One output widget per (dataset, split) pair, collected in iteration order
plots = []
for index in datamodule.dataset_indices_by_size(reverse=True):
    for ttype in ("train", "val", "test"):
        # Sample indices of this split, read from the matching dataloader's sampler
        make_dataloader = getattr(datamodule, f"{ttype}_dataloader")
        indices = make_dataloader(dataset_index=index).sampler.indices
        dataset = datamodule.datasets[index]
        observed = dataset.data_y[indices]
        model_input = dataset.data_X[indices]

        raw_prediction = bestmodel(model_input, observation_model=datamodule.observation_models[index])
        prediction = raw_prediction.detach().numpy()

        mtype = datamodule.measurement_types[index]
        mtype_class = getattr(measurement_types, mtype)

        # Capture the title and the plot inside their own widget
        output = Output()
        with output:
            title = f"{mtype} ({ttype}={observed.shape[0]})"
            print(title)
            print("-" * len(title))
            display(predicted_vs_observed(prediction, observed, mtype_class, n_boot=100, sample_ratio=0.75))
        plots.append(output)
# Lay the widgets out in rows of three (train, val, test)
VBox([HBox(plots[start:start + 3]) for start in range(0, len(plots), 3)])

VBox(children=(HBox(children=(Output(), Output(), Output())),))

In [13]:
from kinoml.utils import watermark
# Print package versions and environment details so the run can be traced back later
watermark()

Watermark
---------
pandas            1.1.3
torch             1.6.0
numpy             1.19.2
pytorch_lightning 1.0.4
last updated: 2020-11-04 17:47:53 CET 2020-11-04T17:47:53+01:00

CPython 3.7.8
IPython 7.18.1

compiler   : GCC 7.5.0
system     : Linux
release    : 4.19.128-microsoft-standard
machine    : x86_64
processor  : x86_64
CPU cores  : 8
interpreter: 64bit
host name  : jrodriguez
Git hash   : 1b20dfc063bd9d929133c103af7bb40cf9394687
watermark 2.0.2

conda
-----
sys.version: 3.7.6 | packaged by conda-forge | (defau...
sys.prefix: /opt/miniconda
sys.executable: /opt/miniconda/bin/python
conda location: /opt/miniconda/lib/python3.7/site-packages/conda
conda-build: /opt/miniconda/bin/conda-build
conda-convert: /opt/miniconda/bin/conda-convert
conda-debug: /opt/miniconda/bin/conda-debug
conda-develop: /opt/miniconda/bin/conda-develop
conda-env: /opt/miniconda/bin/conda-env
conda-index: /opt/miniconda/bin/conda-index
conda-inspect: /opt/miniconda/bin/conda-inspect
conda-metapackage