# Train the model

> This notebook is a simplified version of our workflow. It exposes the basic details of the training and evaluation loop more explicitly, but does not offer advanced features like early stopping, mini-batches or validation. Use the `*-lightning` version for those.

First, we have to create the PyTorch objects out of the NPZ files. NPZ files behave like dictionaries of arrays. In our case, they contain two keys:

- `X`: the featurized systems
- `y`: the associated measurements

We can pass those dict-like arrays to an adapter class for Torch Datasets, which will be ingested by the DataLoaders. We also need the corresponding observation models.

## Define hyper parameters

Edit `UPPERCASE` variables in the following cells to configure behavior of the training.

In [None]:
# Name of the dataset: selects the `{DATASET}__*.npz` input files, the output
# subdirectory and the entries of the per-dataset lookup tables defined further down.
DATASET = "ChEMBL"
# True: request the "pytorch" observation models and apply them to the network output.
# False: request the "null" backend and compare the raw network output against y.
WITH_OBSERVATION_MODEL = True
# Adam -- forwarded to torch.optim.Adam as lr, eps and betas
LEARNING_RATE = 0.001
EPSILON = 1e-7
BETAS = 0.9, 0.999
# Trainer
MAX_EPOCHS = 50  # optimization steps per fold; each step uses the whole training split
N_SPLITS = 5  # passed to the k-fold splitter as n_splits
SHUFFLE_FOLDS = False  # passed to the k-fold splitter as shuffle
VALIDATION = False  # TODO: VALIDATION=True is not implemented yet!
MIN_ITEMS_PER_DATASET = 50  # skip datasets if len(data) < N
# Bootstrapping -- forwarded to kinoml.analysis.plots.performance as n_boot and sample_ratio
N_BOOTSTRAPS = 1
BOOTSTRAP_SAMPLE_RATIO = 1
# Output
VERBOSE = False  # True: show per-fold headers, epoch progress bars and predicted-vs-observed plots

In [None]:
from importlib import import_module
# Model -- specified with the full import path to the class object
# (the part before the last dot is imported as a module, the rest is looked up in it)
MODEL_CLS = "kinoml.ml.torch_models.NeuralNetworkRegression"
# Extra keyword arguments for the model constructor
MODEL_KWARGS = {"hidden_size": 350}  # input_size is defined dynamically during training

⚠ From here on, you should _not_ need to modify anything else 🤞

In [3]:
# Turn the dotted path stored in MODEL_CLS into the class object it names:
# split off the trailing class name, import the module, then fetch the attribute.
model_module, model_class = MODEL_CLS.rsplit(".", maxsplit=1)
_model_namespace = import_module(model_module)
ModelCls = getattr(_model_namespace, model_class)

In [4]:
# TODO: This should be specified along the tensor files as metadata, and should not depend on the dataset identity
# Measurement type names available for each dataset. The table name starts with an
# underscore on purpose, so the hyperparameter snapshot taken below does not pick it up.
_MEASUREMENT_TYPES_BY_DATASET = {
    "ChEMBL": ["pKiMeasurement", "pIC50Measurement", "pKdMeasurement"],
    "PKIS2": ["PercentageDisplacementMeasurement"],
}
MEASUREMENT_TYPES = _MEASUREMENT_TYPES_BY_DATASET[DATASET]

# TODO: Make all datasets use the same kinase identifiers
# One kinase identifier per dataset, written in that dataset's own naming scheme.
_ONE_KINASE_BY_DATASET = {
    "ChEMBL": "P35968",
    "PKIS2": "ABL2",
}
ONE_KINASE = _ONE_KINASE_BY_DATASET[DATASET]

In [5]:
# Nasty trick: save all-caps local variables (CONSTANTS working as hyperparameters) so far in a dict to save it later
# - locals() at the top level of a cell is the notebook namespace, so this is a snapshot of
#   every name defined *up to this cell*; all-caps names assigned later are not included.
# - A name qualifies when upper-casing leaves it unchanged and it does not start with
#   "_" or "OE_" (the "_" rule also drops IPython's own underscore-prefixed variables).
_hparams = {key: value for key, value in locals().items() if key.upper() == key and not key.startswith(("_", "OE_"))}

In [6]:
from pathlib import Path
from collections import defaultdict
import numpy as np
import shutil
import time

from IPython.display import Markdown
import pandas as pd
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import pytorch_lightning as pl

from kinoml.utils import seed_everything
from kinoml.core import measurements as measurement_types
from kinoml.datasets.torch_datasets import XyNpzTorchDataset
from kinoml.core.measurements import null_observation_model

# `_dh` is provided by IPython (its directory history); the last entry is taken as
# the directory this notebook is running in.
HERE = Path(_dh[-1])
_trial = 0
# Every run reports into its own folder, named after the current Unix time in seconds.
OUT = HERE.joinpath("_output", DATASET, f"{time.time():.0f}")
OUT.mkdir(parents=True, exist_ok=True)
print("Reporting results at path:", OUT)
# Seed the random number generators so the random splits are reproducible; otherwise
# the train/test groups would be mixed differently on every run, biasing the evaluation.
seed_everything()



Reporting results at path: /home/jaime/devel/py/openkinome/experiments-binding-affinity/001_ligand-based/MorganFingerprint/LNN/_output/ChEMBL/1606992849


## Load featurized data and create observation models

In [7]:
# Nested mapping: kinase -> measurement type -> dataset object
datasets = defaultdict(dict)
for npz in HERE.glob(f"../_output/{DATASET}__*.npz"):
    # File stems are made of three "__"-separated fields; the first one is discarded,
    # the other two are the kinase and the measurement type.
    _, kinase, measurement_type = str(npz.stem).split("__")
    datasets[kinase][measurement_type] = ds = XyNpzTorchDataset(npz)
    if not VALIDATION:
        # Without a validation step, fold the validation indices into the test set...
        ds.indices["test"] = np.concatenate([ds.indices["test"], ds.indices["val"]])
        # ...and leave an empty *integer* array behind. A bare np.array([]) is float64,
        # which NumPy refuses to use as an index, even when it is empty.
        ds.indices["val"] = np.array([], dtype=int)

In [8]:
# Pick which flavour of observation model each measurement type should provide.
if WITH_OBSERVATION_MODEL:
    backend = "pytorch"
else:
    backend = "null"

# Measurement type name -> observation model obtained from the class of that name
# in kinoml.core.measurements.
obs_models = {
    mtype_name: getattr(measurement_types, mtype_name).observation_model(backend=backend)
    for mtype_name in MEASUREMENT_TYPES
}
obs_models

{'pKiMeasurement': <function kinoml.core.measurements.pKiMeasurement._observation_model_pytorch(dG_over_KT, standard_conc=1, **kwargs)>,
 'pIC50Measurement': <function kinoml.core.measurements.pIC50Measurement._observation_model_pytorch(dG_over_KT, substrate_conc=1e-06, michaelis_constant=1, standard_conc=1, **kwargs)>,
 'pKdMeasurement': <function kinoml.core.measurements.pKdMeasurement._observation_model_pytorch(dG_over_KT, standard_conc=1, **kwargs)>}

## Check X duplication

There's a chance we have several measurements per ligand, or, depending on the featurization scheme, even hash collisions... Let's quantify the amount of input tensor duplication we are facing.

In [9]:
# For every measurement type, count total vs. distinct rows of X in each kinase dataset
# and summarize the resulting ratios across kinases.
for mtype in MEASUREMENT_TYPES:
    display(Markdown(f"#### {mtype}"))
    unique = {}
    for kinase, dataset_by_mtype in datasets.items():
        if mtype in dataset_by_mtype:
            ds = dataset_by_mtype[mtype]
            all_ = ds.data_X.shape[0]
            # Rows are compared as a whole (axis=0), so two entries count as
            # duplicates only if their complete feature vectors are identical.
            unique_ = np.unique(ds.data_X, axis=0).shape[0]
            unique[kinase] = {"all": all_, "unique": unique_}
    if not unique:
        # No dataset was loaded for this type: an empty frame has no "unique"/"all"
        # columns and the ratio below would fail with a KeyError.
        print("No datasets found for", mtype)
        continue
    # One row per kinase, columns "all" and "unique"
    df = pd.DataFrame.from_dict(unique, orient="index")
    df["uniqueness"] = df["unique"] / df["all"]
    # This is how you highlight rows in pandas!
    # (kept under a separate name so `df` stays a DataFrame instead of becoming a Styler)
    styled_summary = df.describe().style.apply(
        lambda column: ["font-weight: bold"] * len(column),
        subset=pd.IndexSlice[["mean", "std"], :],
    )
    display(styled_summary)

#### pKiMeasurement

Unnamed: 0,all,unique,uniqueness
count,177.0,177.0,177.0
mean,91.576271,73.011299,0.90814
std,213.260181,159.515352,0.145384
min,1.0,1.0,0.333333
25%,2.0,2.0,0.873303
50%,11.0,8.0,1.0
75%,69.0,64.0,1.0
max,1497.0,1125.0,1.0


#### pIC50Measurement

Unnamed: 0,all,unique,uniqueness
count,347.0,347.0,347.0
mean,490.256484,387.616715,0.880322
std,1062.253429,814.083477,0.126062
min,1.0,1.0,0.462025
25%,6.0,6.0,0.815516
50%,68.0,59.0,0.893617
75%,399.0,321.0,1.0
max,9626.0,7252.0,1.0


#### pKdMeasurement

Unnamed: 0,all,unique,uniqueness
count,385.0,385.0,385.0
mean,44.285714,26.490909,0.726013
std,70.687814,20.648107,0.155696
min,1.0,1.0,0.084726
25%,19.0,15.0,0.635897
50%,31.0,22.0,0.707692
75%,47.0,33.0,0.833333
max,838.0,155.0,1.0


Now that we have all the data-dependent objects, we can start with the model-specific definitions.

### Training loop

In [None]:
from kinoml.ml.lightning_modules import KFold3Way, KFold
from IPython.display import Markdown
from tqdm.auto import trange, tqdm
from kinoml.ml.torch_models import NeuralNetworkRegression
from ipywidgets import HBox, VBox, Output, HTML
from kinoml.analysis.plots import predicted_vs_observed, performance
from kinoml.utils import fill_until_next_multiple
import pandas as pd
import torch.nn as nn

# VALIDATION=True is not supported yet. Refuse to start right away instead of raising
# from inside the first training epoch, after a model has already been built and a
# forward pass computed. The VALIDATION branches further down are kept as scaffolding
# for when the validation step gets implemented.
if VALIDATION:
    raise NotImplementedError("Validation step not implemented yet")

if VALIDATION:
    kfold = KFold3Way(n_splits=N_SPLITS, shuffle=SHUFFLE_FOLDS)
    ttypes = ["train", "val", "test"]
else:
    kfold = KFold(n_splits=N_SPLITS, shuffle=SHUFFLE_FOLDS)
    ttypes = ["train", "test"]

# kinase -> measurement type -> split -> metric -> {"mean": ..., "std": ...}
kinase_metrics = defaultdict(dict)
for kinase in tqdm(datasets):
    for mtype in MEASUREMENT_TYPES:
        if mtype not in datasets[kinase]:
            continue
        if datasets[kinase][mtype].data_X.shape[0] < MIN_ITEMS_PER_DATASET:
            print("Ignoring", kinase, "because it has less than", MIN_ITEMS_PER_DATASET, "entries for type", mtype)
            continue

        if VERBOSE:
            display(Markdown(f"#### {mtype}"))
        dataset = datasets[kinase][mtype]
        obs_model = obs_models[mtype]
        mtype_class = getattr(measurement_types, mtype)
        # split name -> list with one `performance` result per fold
        metrics = defaultdict(list)

        for fold_index, splits in enumerate(kfold.split(dataset.data_X, dataset.data_y)):
            if VALIDATION:
                train_indices, val_indices, test_indices = splits
            else:
                train_indices, test_indices = splits

            if VERBOSE:
                display(Markdown(f"##### Fold {fold_index}"))

            ####
            # TRAIN
            ####
            x_train = dataset.data_X[train_indices].float()
            x_test = dataset.data_X[test_indices].float()
            y_train = dataset.data_y[train_indices]
            y_test = dataset.data_y[test_indices]

            if VALIDATION:
                x_val = dataset.data_X[val_indices].float()
                y_val = dataset.data_y[val_indices]

            # A fresh model and optimizer for every fold, so nothing learnt on one
            # fold carries over to the next one.
            nn_model = ModelCls(input_size=x_train.shape[1], **MODEL_KWARGS)
            nn_model.train(True)

            optimizer = torch.optim.Adam(nn_model.parameters(), lr=LEARNING_RATE, eps=EPSILON, betas=BETAS)
            loss_function = torch.nn.MSELoss()

            if VERBOSE:
                range_epochs = trange(MAX_EPOCHS, desc="Epochs (+ featurization...)")
            else:
                range_epochs = range(MAX_EPOCHS)
            for epoch in range_epochs:
                optimizer.zero_grad()

                # Full-batch step: the whole training split goes through the model at once
                prediction = nn_model(x_train)
                if WITH_OBSERVATION_MODEL:
                    prediction = obs_model(prediction)

                # Give the prediction the same shape as the targets before comparing them
                prediction = prediction.view_as(y_train)

                # prediction = delta_g
                loss = loss_function(prediction, y_train)
                if VERBOSE:
                    range_epochs.set_description(f"Epochs (loss={loss.item():.2e})")

                # Gradients w.r.t. parameters
                loss.backward()

                # Optimizer
                optimizer.step()

            ####
            # EVAL
            ####
            nn_model.eval()

            # Explicit table with the indices, inputs and targets of each split, instead
            # of resolving variable names such as "x_train" through locals().
            fold_data = {
                "train": (train_indices, x_train, y_train),
                "test": (test_indices, x_test, y_test),
            }
            if VALIDATION:
                fold_data["val"] = (val_indices, x_val, y_val)

            outputs = []
            for ttype in ttypes:
                split_indices, split_x, observed = fold_data[ttype]
                # Everything printed or displayed inside this widget only becomes
                # visible if the widget itself is displayed (VERBOSE mode, below).
                output = Output()
                with output:
                    title = f"fold={fold_index}, {ttype}={split_indices.shape[0]}"
                    print(title)
                    print("-"*(len(title)))

                    with torch.no_grad():
                        predicted = nn_model(split_x)
                        if WITH_OBSERVATION_MODEL:
                            predicted = obs_model(predicted)

                    predicted = predicted.view_as(observed).detach().numpy()
                    observed = observed.detach().numpy()
                    these_metrics = performance(predicted, observed, n_boot=N_BOOTSTRAPS, sample_ratio=BOOTSTRAP_SAMPLE_RATIO)
                    metrics[ttype].append(these_metrics)
                    if VERBOSE:
                        display(predicted_vs_observed(predicted, observed, mtype_class, with_metrics=False))

                outputs.append(output)
            if VERBOSE:
                display(HBox(outputs))

        # Average performances

        # split -> metric -> mean and standard deviation across folds
        average = defaultdict(dict)
        for key in metrics["test"][0]:
            for label in ttypes:
                # The [0] below is super important: only the first element of each
                # fold's result is used, because we only want the mean of the means!
                values = [fold[key][0] for fold in metrics[label]]
                average[label][key] = {
                    "mean": np.mean(values),
                    "std": np.std(values)
                }
        if VERBOSE:
            for label in ttypes:
                display(HTML(f"Bootstrapped average across folds ({label}):"))
                display(pd.DataFrame.from_dict(average[label]))
        kinase_metrics[kinase][mtype] = average

### Summary

`kinase_metrics` is a nested dictionary with these dimensions:

- kinase name
- measurement type
- split (`train` or `test`)
- metric
- mean & standard deviation

In [12]:
def _flatten_metrics(metrics_by_split):
    """Flatten split -> metric -> statistic into `{split}_{metric}_{stat}` columns.

    Each value is wrapped in a 1-tuple so the result builds a single-row DataFrame.
    """
    flat = {}
    for split_name, metrics_by_name in metrics_by_split.items():
        for metric_name, stats in metrics_by_name.items():
            for stat_name, stat_value in stats.items():
                flat[f"{split_name}_{metric_name}_{stat_name}"] = (stat_value,)
    return flat


display(Markdown(f"### {DATASET}, observation model = {WITH_OBSERVATION_MODEL}"))
# Kinases are listed alphabetically, ignoring case
_kinases_sorted = sorted(kinase_metrics.items(), key=lambda kv: kv[0].lower())
for mtype in MEASUREMENT_TYPES:
    display(Markdown(f"#### {mtype}"))
    # One single-row frame per kinase that has results for this measurement type
    _rows = {}
    for _kinase_name, _metrics_by_mtype in _kinases_sorted:
        if mtype not in _metrics_by_mtype:
            continue
        # Number of measurements whose value is exactly zero
        _n_zeros = (datasets[_kinase_name][mtype].data_y == 0).sum().detach().numpy()
        _rows[_kinase_name] = pd.DataFrame.from_dict(
            _flatten_metrics(_metrics_by_mtype[mtype])
        ).assign(zeros=_n_zeros)
    df = pd.concat(_rows)

    # Concatenating a dict gives a two-level index; keep the kinase name only
    df.index = [index_entry[0] for index_entry in df.index]
    with pd.option_context("display.float_format", "{:.3f}".format, "display.max_rows", len(df)):
        display(df.style.background_gradient(subset=["train_r2_mean", "test_r2_mean"], low=0, high=1, vmin=0, vmax=1))

### ChEMBL, observation model = True

#### pKiMeasurement

Unnamed: 0,train_mae_mean,train_mae_std,train_mse_mean,train_mse_std,train_r2_mean,train_r2_std,train_rmse_mean,train_rmse_std,test_mae_mean,test_mae_std,test_mse_mean,test_mse_std,test_r2_mean,test_r2_std,test_rmse_mean,test_rmse_std,zeros
O60674-O60674,0.848317,0.111574,1.163537,0.264387,0.206386,0.112729,1.070819,0.129935,1.070177,0.289177,1.780807,0.937961,-3.155895,2.940528,1.286721,0.353775,0
O60885,0.787926,0.038169,0.944708,0.083885,0.116081,0.180053,0.970986,0.043535,0.937757,0.309898,1.369301,0.850595,-0.905688,1.644557,1.127486,0.31317,0
O96013,0.587511,0.028782,0.56435,0.053243,0.280744,0.149938,0.750379,0.035791,0.963989,0.387796,1.619548,1.137345,-1.010642,0.832008,1.18464,0.464947,0
P05129,0.498605,0.043839,0.463047,0.063356,0.672762,0.120324,0.678855,0.04693,1.039772,0.717593,1.94937,2.327983,-0.645012,1.50319,1.213817,0.68994,0
P08581,0.7184,0.068819,0.822364,0.167521,0.345565,0.16572,0.902573,0.087893,0.859731,0.26479,1.255374,0.730607,0.021602,0.354125,1.081262,0.293677,0
P11309,1.084153,0.078327,1.894601,0.258997,-0.018313,0.342972,1.373185,0.094676,1.107049,0.214074,1.922686,0.612277,-0.491382,0.672424,1.369403,0.217767,0
P12931,0.626715,0.021793,0.63825,0.037155,0.101594,0.15114,0.798573,0.023054,0.765619,0.167537,1.032618,0.393489,-0.884084,0.820959,0.99539,0.204493,0
P23458-P23458,0.624222,0.079481,0.60916,0.12631,0.20803,0.222573,0.776345,0.080299,0.913848,0.338266,1.261563,0.756792,-7.816945,12.057183,1.066471,0.352423,0
P24941,0.705466,0.049239,0.813558,0.143371,-0.015543,0.176827,0.898727,0.076473,1.00863,0.218737,1.638707,0.657002,-1.685086,1.31516,1.251594,0.268737,0
P29597-P29597,0.923619,0.11423,1.300601,0.303199,-0.150611,0.307303,1.132919,0.130751,0.87349,0.238627,1.160435,0.63259,-0.450985,0.732086,1.042094,0.272901,0


#### pIC50Measurement

Unnamed: 0,train_mae_mean,train_mae_std,train_mse_mean,train_mse_std,train_r2_mean,train_r2_std,train_rmse_mean,train_rmse_std,test_mae_mean,test_mae_std,test_mse_mean,test_mse_std,test_r2_mean,test_r2_std,test_rmse_mean,test_rmse_std,zeros
O00141,0.888094,0.086214,1.217273,0.225995,0.074177,0.199977,1.09862,0.101524,1.089986,0.459511,1.8788,1.328766,-1.295925,2.091227,1.289192,0.465601,0
O14920,0.83534,0.038187,1.121864,0.086749,0.154791,0.072952,1.058376,0.041288,0.917643,0.088026,1.325822,0.202874,-0.100057,0.278514,1.14832,0.084758,0
O15164,0.409569,0.059604,0.26551,0.092915,0.709268,0.052057,0.507707,0.087998,0.538847,0.168192,0.433214,0.25234,-0.750648,1.873085,0.630773,0.187989,0
O15264,0.507872,0.033986,0.372814,0.019599,0.703249,0.045243,0.610373,0.016077,0.81762,0.191253,1.127667,0.486097,-0.842172,1.253956,1.033258,0.245042,0
O15530,0.746913,0.063845,0.888824,0.133165,0.507413,0.061714,0.939939,0.073068,0.932322,0.363821,1.390962,0.987989,0.099683,0.510074,1.117749,0.376296,0
O60674-O60674,1.001878,0.027194,1.610578,0.083926,-0.220867,0.064529,1.268661,0.03283,1.138046,0.087194,2.068046,0.333419,-0.673854,0.515627,1.433345,0.116486,0
O60885,0.721925,0.02404,0.821862,0.045904,0.284456,0.040426,0.906209,0.025424,0.842234,0.154502,1.123335,0.360351,-0.067658,0.272959,1.045802,0.172141,0
O75582-O75582,0.991416,0.075004,1.489091,0.168088,0.302779,0.167122,1.218298,0.069573,1.406135,0.299191,2.712909,1.148589,-0.65418,0.482448,1.611188,0.342025,0
O95819,1.000683,0.077335,1.448684,0.094672,-0.17854,0.159616,1.202958,0.03971,1.507529,0.618472,3.152107,2.246673,-6.185066,5.943751,1.689328,0.546146,0
O96013,0.474494,0.045258,0.377099,0.060996,0.588988,0.08778,0.612107,0.049231,0.97667,0.59032,1.964973,2.028685,-0.735769,0.854823,1.246933,0.640415,0


#### pKdMeasurement

Unnamed: 0,train_mae_mean,train_mae_std,train_mse_mean,train_mse_std,train_r2_mean,train_r2_std,train_rmse_mean,train_rmse_std,test_mae_mean,test_mae_std,test_mse_mean,test_mse_std,test_r2_mean,test_r2_std,test_rmse_mean,test_rmse_std,zeros
O14578,0.466705,0.043244,0.396919,0.088127,0.264714,0.097782,0.625471,0.075533,0.687982,0.219282,0.863181,0.669688,-0.808948,1.252754,0.867708,0.332059,0
O14976,0.760472,0.053048,0.910698,0.113902,-0.288738,0.147591,0.952347,0.061105,0.948493,0.218169,1.417533,0.481557,-0.931098,0.54112,1.174636,0.194329,0
O15530,0.571796,0.035392,0.562984,0.044382,0.687742,0.055136,0.749726,0.029913,0.903857,0.370185,1.309722,0.911092,-0.335074,0.934822,1.073507,0.396618,0
O43293,0.720963,0.05323,0.869052,0.114751,0.266029,0.062602,0.930244,0.060807,1.123206,0.314648,2.186873,1.253058,-1.639441,1.762953,1.423463,0.400783,0
O60674-O60674,0.870581,0.100028,1.198622,0.272509,0.259486,0.14276,1.08764,0.125145,1.254064,0.393964,2.225249,1.170766,-0.769527,0.677996,1.446751,0.363538,0
O60885,0.602051,0.031849,0.719103,0.102989,0.313361,0.151321,0.845623,0.063443,0.929867,0.268632,1.403233,0.605832,-0.613081,0.854184,1.153839,0.268121,0
O75582-O75582,0.550629,0.060732,0.539115,0.145101,0.02133,0.32608,0.727874,0.096512,0.948027,0.113205,1.65634,0.339295,-3.399867,2.255962,1.28072,0.126872,0
O75676-O75676,0.439735,0.029807,0.334434,0.037963,0.402172,0.072593,0.577393,0.032432,0.890506,0.349607,1.338886,0.844816,-1.768131,1.259028,1.102865,0.350106,0
O95819,0.498484,0.039584,0.450628,0.051161,0.715991,0.102934,0.670213,0.037988,0.950122,0.243576,1.456926,0.73766,-2.028974,2.572225,1.160042,0.333511,0
P04629,0.696511,0.071817,0.679923,0.087039,0.421379,0.105771,0.822926,0.052114,1.522675,0.714122,3.180903,2.604275,-20.449534,24.49651,1.634537,0.713576,0


#### Overall performance

In [13]:
# NOTE: `df` is whatever the summary loop assigned last. When the cells are run top to
# bottom, that is the table of the last entry in MEASUREMENT_TYPES, so the statistics
# below describe that measurement type only.
overall_columns = ["train_r2_mean", "train_r2_std", "test_r2_mean", "test_r2_std", "zeros"]
overall_stats = df[overall_columns].describe()
# Render the "mean" and "std" rows in bold
overall_stats.style.apply(
    lambda column: ["font-weight: bold"] * len(column),
    subset=pd.IndexSlice[["mean", "std"], :],
)

Unnamed: 0,train_r2_mean,train_r2_std,test_r2_mean,test_r2_std,zeros
count,42.0,42.0,42.0,42.0,42.0
mean,0.284702,0.1143,-1.669861,1.694699,0.0
std,0.274219,0.07536,3.111056,3.658957,0.0
min,-0.525417,0.022413,-20.449534,0.288837,0.0
25%,0.14313,0.063544,-1.746117,0.691537,0.0
50%,0.298371,0.103916,-0.925246,0.95857,0.0
75%,0.438349,0.133223,-0.597206,1.516401,0.0
max,0.715991,0.368774,0.193687,24.49651,0.0


### Save reports to disk

In [None]:
%%capture cap --no-stderr
from kinoml.utils import watermark
import json

df.to_csv(OUT / "performance.csv")

with open(OUT / "performance.json", "w") as f:
    json.dump(kinase_metrics, f)

In [None]:
w = watermark()

# `cap` is the capture object created by the `%%capture cap` cell; store the stdout
# it recorded next to the other reports.
(OUT / "watermark.txt").write_text(cap.stdout)

# Hyperparameter snapshot; anything json cannot encode natively is saved as its str().
(OUT / "hparams.json").write_text(json.dumps(_hparams, default=str))