In [1]:
%load_ext autoreload
%autoreload 2


import torch
from torch import nn
import numpy as np
import seaborn as sns


import pandas as pd
import numpy as np
from spacecutter.models import OrdinalLogisticModel
import torch
from torch import nn
import datamol as dm
import matplotlib.pyplot as plt

from skorch import NeuralNet
from skorch.dataset import Dataset
from skorch.helper import SkorchDoctor
from skorch.callbacks import EarlyStopping
from spacecutter.models import OrdinalLogisticModel, OrdinalLogisticMultiTaskModel
from spacecutter.losses import CumulativeLinkLoss, MultiTaskCumulativeLinkLoss
from spacecutter.callbacks import AscensionCallback
from spacecutter.losses import CumulativeLinkLoss
from sklearn.metrics import mean_absolute_error
from scipy.stats import kendalltau

from utils import train_data

In [2]:
import os

# Root directory of the ADMET project data.
# Set the POLARIS_ADMET_DIR environment variable to run the notebook elsewhere, e.g. on Colab:
#   'drive/MyDrive/Polaris_ASAP_competition/polaris_challenge/admet'
# When the variable is unset (or empty) the original local Google Drive path is used.
proj_dir = os.environ.get('POLARIS_ADMET_DIR') or (
    '/Users/robertarbon/Library/CloudStorage/GoogleDrive-robert.arbon@gmail.com/My Drive/Polaris_ASAP_competition/polaris_challenge/admet'
)

In [3]:
# The CSV headers are normalised to the dotted spellings used by the rest of the notebook.
column_renames = {'Molecule Name': 'Molecule.Name', 'LogMDR1-MDCKII':'LogMDR1.MDCKII'}

# Training split with imputed values.
df_imp = pd.read_csv(
    f'{proj_dir}/dm_features/ordinal_data_split_2/train_admet_split2_log_pmm_imputed.csv'
).rename(columns=column_renames)

# Validation frame: same split, but without imputation.
df_val = pd.read_csv(
    f'{proj_dir}/dm_features/ordinal_data_split_2/train_admet_split2_features.csv'
).rename(columns=column_renames)

# The feature files above do not carry the SMILES column, so it is re-attached
# from the full training table, matching rows on the molecule name.
df_smiles = pd.read_csv(f'{proj_dir}/data/train_admet_all.csv').rename(columns=column_renames)
smiles_lookup = df_smiles.loc[:, ['Molecule.Name', 'CXSMILES']]

df_imp = df_imp.merge(smiles_lookup, on='Molecule.Name', how='left')
df_val = df_val.merge(smiles_lookup, on='Molecule.Name', how='left')

In [4]:
def to_model_format(train, val):
    """
    Stack the training and validation splits into single arrays.

    Parameters
    ----------
    train, val : sequence
        ``split[0]`` is the 2-D feature array and ``split[1]`` is a mapping
        ``target -> {'values': ndarray}`` holding one label array per target.
        Only the targets present in ``train`` are used; ``val`` must contain them too.

    Returns
    -------
    train_val_X : np.ndarray (float32)
        Training rows followed by validation rows.
    train_val_y : np.ndarray (float32)
        One column per target (targets in sorted order), rows aligned with ``train_val_X``.
    train_ix, val_ix : np.ndarray
        Row indices of the training / validation part of the stacked arrays.
    config : dict
        ``n_features``, ``n_tasks`` and ``n_classes_per_task``. The class counts are the
        number of distinct non-NaN training labels per target.
    """
    targets = sorted(train[1].keys())

    def _stack_targets(split):
        # One column per target, always in the same (sorted) order.
        return np.concatenate(
            [split[1][target]['values'].reshape(-1, 1) for target in targets], axis=1
        )

    train_X = train[0]
    train_y = _stack_targets(train)
    n_tasks = train_y.shape[1]

    # Missing labels are NaN and must not be counted as a class of their own.
    n_classes_per_task = []
    for i in range(n_tasks):
        labels = train_y[:, i]
        n_classes_per_task.append(int(np.unique(labels[~np.isnan(labels)]).shape[0]))

    n_obs, n_features = train_X.shape
    print(f" {n_tasks} tasks\n classes/task: {n_classes_per_task}\n features: {n_features}, obs: {n_obs}")

    val_X = val[0]
    val_y = _stack_targets(val)
    train_val_X = np.vstack([train_X, val_X]).astype(np.float32)
    train_val_y = np.vstack([train_y, val_y]).astype(np.float32)
    train_ix = np.arange(train_X.shape[0])
    val_ix = np.arange(train_X.shape[0], train_val_X.shape[0])

    config = {
        'n_features': n_features,
        'n_tasks': n_tasks,
        'n_classes_per_task': n_classes_per_task,
    }

    return train_val_X, train_val_y, train_ix, val_ix, config

def ord_to_cont(train, y_ord):
    """
    Map ordinal labels back to continuous values through the per-target bins.

    Parameters
    ----------
    train : sequence
        ``train[1]`` is a mapping keyed by target name (its sorted keys define the
        column order of ``y_ord``).
    y_ord : np.ndarray, shape (n_obs, n_targets)
        Ordinal labels; NaN marks a missing label.

    Returns
    -------
    np.ndarray, shape (n_obs, n_targets)
        ``bins[label]`` for every label, NaN where the label is NaN.
    """
    targets = sorted(train[1].keys())
    # NOTE(review): the original read `targets[2][target]['bins']`, which indexes a
    # target-name string and always raises. `train[2]` is the smallest correction;
    # confirm that `train_data` really returns the bins in the third element.
    bins_by_target = train[2]

    y_cont = []
    for i, target in enumerate(targets):
        bins = bins_by_target[target]['bins']
        # Labels arrive as floats (NaN needs a float dtype), so cast before indexing.
        column = [np.nan if np.isnan(label) else bins[int(label)] for label in y_ord[:, i]]
        y_cont.append(np.array(column).reshape(-1, 1))
    return np.concatenate(y_cont, axis=1)

def mtl_mae(train, y_pred, y_true_cont):
    """
    Mean absolute error between predictions and continuous targets, pooled over all tasks.

    ``y_pred`` holds ordinal labels and is first converted with ``ord_to_cont``.
    Entries whose absolute error is NaN are left out of the mean.
    """
    predicted_cont = ord_to_cont(train, y_pred)
    abs_error = np.abs(predicted_cont - y_true_cont)
    is_observed = ~np.isnan(abs_error)
    return np.mean(abs_error, where=is_observed)
    
def plot_results(train, val_y_pred, val_y_true_cont, train_y_pred, train_y_true_cont):
    """
    Scatter predicted against true continuous values, one subplot per target.

    Parameters
    ----------
    train : sequence
        Training split; ``train[1]`` is keyed by target name and is also what
        ``ord_to_cont`` uses to convert the ordinal predictions.
    val_y_pred, train_y_pred : np.ndarray, shape (n_obs, n_targets)
        Ordinal predictions, columns in sorted-target order.
    val_y_true_cont, train_y_true_cont : array-like, shape (n_obs, n_targets)
        True continuous values; NaN marks a missing measurement.

    Returns
    -------
    The axes object returned by ``plt.subplots`` (an array of Axes for several
    targets, a single Axes for one target).
    """
    val_pred_all = np.asarray(ord_to_cont(train, val_y_pred), dtype=float)
    train_pred_all = np.asarray(ord_to_cont(train, train_y_pred), dtype=float)
    val_true_all = np.asarray(val_y_true_cont, dtype=float)
    train_true_all = np.asarray(train_y_true_cont, dtype=float)
    cols = sns.color_palette('colorblind')

    targets = sorted(train[1].keys())
    fig, axes = plt.subplots(len(targets), figsize=(6, 3*len(targets)))
    # plt.subplots returns a bare Axes when there is a single target.
    for i, ax in enumerate(np.atleast_1d(axes)):
        # Each panel shows only its own target's column.
        val_pred, val_true = val_pred_all[:, i], val_true_all[:, i]
        train_pred, train_true = train_pred_all[:, i], train_true_all[:, i]

        ax.scatter(val_pred, val_true, label='validation', color=cols[0])
        ax.scatter(train_pred, train_true, label='train', color=cols[1])

        # Range of the y=x reference line; NaNs (missing values) are ignored.
        plotted = np.concatenate([val_pred, val_true, train_pred, train_true])
        plotted = plotted[np.isfinite(plotted)]
        if plotted.size:
            min_val, max_val = plotted.min(), plotted.max()
            ax.plot([min_val, max_val], [min_val, max_val], label='y=x', color='black')

        # Validation MAE for this target only, over the observed entries.
        abs_error = np.abs(val_pred - val_true)
        is_observed = ~np.isnan(abs_error)
        val_mae = abs_error[is_observed].mean() if is_observed.any() else np.nan
        # Axes-fraction coordinates keep the text in the top-left corner whatever the data range.
        ax.annotate(text=f"val MAE: {val_mae:4.2f}", xy=(0.1, 0.9), xycoords='axes fraction')

        ax.set_title(targets[i])
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.legend(loc='lower right')
    fig.tight_layout()
    return axes


def predict(skorch_model, X):
    """
    Predict the most likely ordinal class per task.

    Parameters
    ----------
    skorch_model : object
        Fitted skorch net; its ``module_`` is called on ``X`` and must return one
        ``(n_obs, n_classes)`` tensor per task.
    X : array-like
        Feature matrix, converted with ``torch.as_tensor``.

    Returns
    -------
    list of np.ndarray
        One 1-D array of class indices per task. Use ``np.column_stack`` on the
        result to get an ``(n_obs, n_tasks)`` array.

    Notes
    -----
    The module is switched to eval mode and left in that mode.
    """
    mod = skorch_model.module_
    mod.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        task_outputs = mod(torch.as_tensor(X))

    return [np.argmax(out.detach().cpu().numpy(), axis=1) for out in task_outputs]



In [5]:
# Build the train/validation splits and stack them into model-ready arrays.
train, val = train_data(
    df_imp,
    imp_ix=1,
    df_val=df_val,
    n_cuts=None,
    features='chemberta',
    proj_dir=proj_dir,
    remove_nans=False,
)
X, y, train_ix, val_ix, config = to_model_format(train, val)
n_features = config['n_features']

# Backbone: two full-width hidden layers.
backbone_layers = [
    nn.Linear(n_features, n_features),
    nn.ReLU(),
    nn.Linear(n_features, n_features),
    nn.ReLU(),
]
backbone = nn.Sequential(*backbone_layers)

# Head: one hidden layer, then a projection to a single output.
# NOTE(review): the trailing ReLU clamps that output to >= 0 - confirm this is intended.
head_layers = [
    nn.Linear(n_features, n_features),
    nn.ReLU(),
    nn.Linear(n_features, 1),
    nn.ReLU(),
]
head = nn.Sequential(*head_layers)



training data
using chemberta
	creating new scaler
validation data
using chemberta
	using existing scaler
 5 tasks
 classes/task: [59, 165, 160, 274, 178]
 features: 384, obs: 354


In [8]:
def holdout_split(ds, y):
    """Split the stacked dataset at the precomputed train / validation row indices."""
    train_subset = torch.utils.data.Subset(ds, train_ix)
    val_subset = torch.utils.data.Subset(ds, val_ix)
    return train_subset, val_subset


n_classes_per_task = config['n_classes_per_task']

# Named callbacks; early stopping restores the best weights seen.
training_callbacks = [
    ('ascension', AscensionCallback()),
    ('early_stopping', EarlyStopping(threshold=0.0001, load_best=True, patience=10)),
]

model = NeuralNet(
    # Module and its constructor arguments
    module=OrdinalLogisticMultiTaskModel,
    module__backbone=backbone,
    module__head=head,
    module__n_classes=n_classes_per_task,
    # Loss
    criterion=MultiTaskCumulativeLinkLoss,
    criterion__n_tasks=config['n_tasks'],
    criterion__n_classes_per_task=n_classes_per_task,
    criterion__loss_reduction='inv_num_classes',
    # Optimiser
    optimizer=torch.optim.Adam,
    optimizer__weight_decay=1e-3,
    # Training loop: one batch the size of the whole stacked array, fixed validation rows
    train_split=holdout_split,
    callbacks=training_callbacks,
    verbose=0,
    batch_size=X.shape[0],
    max_epochs=500,
)

model.fit(X, y)

