# Evaluate single model

## Setup

#### Packages

In [None]:
# Install packages
# Ray 0.8.1 with the RLlib extra (IPython `!` shell magic; runs pip in a subshell).
!pip install ray[rllib]==0.8.1  # also recommended: ray[debug]
# NOTE(review): pyarrow and pickle5 are removed right after installing Ray — presumably to
# avoid version conflicts with Ray 0.8.1; confirm the reason before changing this.
!pip uninstall -y pyarrow
!pip uninstall -y pickle5
# `===` is pip's (PEP 440) arbitrary-equality operator: it matches the version string exactly.
# `-f` (--find-links) adds PyTorch's wheel index as a place to look for these wheels.
!pip install torch===1.4.0 torchvision===0.5.0 -f https://download.pytorch.org/whl/torch_stable.html

Collecting ray[rllib]==0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/37/0e9877a2729d31881d9bb2cad1f9aedd2f451602af67706df6faaef33e7f/ray-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (74.3MB)
[K     |████████████████████████████████| 74.3MB 43kB/s 
Collecting redis>=3.3.2
[?25l  Downloading https://files.pythonhosted.org/packages/a7/7c/24fb0511df653cf1a5d938d8f5d19802a88cef255706fdda242ff97e91b7/redis-3.5.3-py2.py3-none-any.whl (72kB)
[K     |████████████████████████████████| 81kB 9.5MB/s 
Collecting funcsigs
  Downloading https://files.pythonhosted.org/packages/69/cb/f5be453359271714c01b9bd06126eaf2e368f1fddfff30818754b5ac2328/funcsigs-1.0.2-py2.py3-none-any.whl
Collecting colorama
  Downloading https://files.pythonhosted.org/packages/44/98/5b86278fbbf250d239ae0ecb724f8572af1c91f4a11edf4d36a206189440/colorama-0.4.4-py2.py3-none-any.whl
Collecting lz4; extra == "rllib"
[?25l  Downloading https://files.pythonhosted.org/packages/e3/52/151c815a486290608e4dc6699a0cfd741

In [None]:
# Mount Google Drive if needed (Colab only): makes Drive files available under /content/drive.
from google.colab import drive

drive.mount(
    '/content/drive',
    force_remount=True,  # remount even if a mount already exists
)

Mounted at /content/drive


#### Imports

In [None]:
# Setup some constants
# Root directory of the project (placeholder — set before running).
project_dir = "..."
# NOTE(review): presumably the drug-response target column used later in the notebook — confirm.
global_response_metric = "LN_IC50"
# Standard imports
import ray
from ray import tune
from ray.tune import track

import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import dill
import warnings

from scipy.stats import pearsonr

import torch
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch import nn

from sklearn import metrics

# Custom utilities imports
# os.path.join inserts a separator only when needed, so this works whether or not
# project_dir ends with "/" (plain concatenation silently built a wrong path otherwise).
sys.path.append(os.path.join(project_dir, "Scripts/Modules"))
from modeling import Dataset

print(ray.__version__)
print(torch.__version__)

0.8.1
1.4.0


## Network definitions and helper classes

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import json
import os

from sklearn import metrics
from scipy.stats import pearsonr

import torch
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch import nn

# Network definitions
# Linear model
class LinearMatrixFactorizationWithFeatures(torch.nn.Module):
    """Bilinear drug-cell line scorer built from two linear projections.

    Drug features and cell line features are each mapped by their own linear layer
    into an ``output_dim``-dimensional space. The score of a pair is the row-wise
    dot product of the two projections, optionally passed through
    ``out_activation_func``.
    """

    def __init__(self, drug_input_dim, cell_line_input_dim, output_dim,
                 out_activation_func=None,
                 drug_bias=True,
                 cell_line_bias=True):
        super(LinearMatrixFactorizationWithFeatures, self).__init__()
        # Attribute names define the state_dict keys, so they must stay stable.
        self.drug_linear = torch.nn.Linear(drug_input_dim, output_dim, bias=drug_bias)
        self.cell_line_linear = torch.nn.Linear(cell_line_input_dim, output_dim, bias=cell_line_bias)
        self.out_activation = out_activation_func

    def forward(self, drug_features, cell_line_features):
        """Return one score per row, shaped ``(-1, 1)``."""
        drug_proj = self.drug_linear(drug_features)
        cell_line_proj = self.cell_line_linear(cell_line_features)
        scores = (drug_proj * cell_line_proj).sum(dim=1).view(-1, 1)
        return self.out_activation(scores) if self.out_activation else scores
    
# Deep autoencoder with one hidden layer
class DeepAutoencoderOneHiddenLayer(nn.Module):
    """Symmetric fully connected autoencoder with one hidden layer per side.

    Encoder: input_dim -> hidden_dim -> code_dim, with an optional activation on the code.
    Decoder: code_dim -> hidden_dim -> input_dim, with no activation on the output.
    When ``dropout`` is set, a Dropout(dropout_rate) follows each hidden activation.
    """

    def __init__(self, input_dim, hidden_dim, code_dim, activation_func=nn.ReLU,
                 code_activation=True, dropout=False, dropout_rate=0.5):
        super(DeepAutoencoderOneHiddenLayer, self).__init__()

        def stack(dims, activate_last):
            # Linear -> activation [-> dropout] for each hidden transition,
            # then a final Linear, optionally followed by one more activation.
            layers = []
            for n_in, n_out in zip(dims[:-2], dims[1:-1]):
                layers += [nn.Linear(n_in, n_out), activation_func()]
                if dropout:
                    layers.append(nn.Dropout(dropout_rate))
            layers.append(nn.Linear(dims[-2], dims[-1]))
            if activate_last:
                layers.append(activation_func())
            return nn.Sequential(*layers)

        self.encoder = stack((input_dim, hidden_dim, code_dim), code_activation)
        self.decoder = stack((code_dim, hidden_dim, input_dim), False)

    def forward(self, x):
        """Return ``(code, reconstruction)`` for the input batch ``x``."""
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return code, reconstruction
    
# Deep autoencoder with two hidden layers
class DeepAutoencoderTwoHiddenLayers(nn.Module):
    """Symmetric fully connected autoencoder with two hidden layers per side.

    Encoder: input_dim -> hidden_dim1 -> hidden_dim2 -> code_dim, with an optional
    activation on the code. Decoder mirrors it back to input_dim, with no activation
    on the output. When ``dropout`` is set, a Dropout(dropout_rate) follows each
    hidden activation.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, code_dim, activation_func=nn.ReLU,
                 code_activation=True, dropout=False, dropout_rate=0.5):
        super(DeepAutoencoderTwoHiddenLayers, self).__init__()

        def stack(dims, activate_last):
            # Linear -> activation [-> dropout] for each hidden transition,
            # then a final Linear, optionally followed by one more activation.
            layers = []
            for n_in, n_out in zip(dims[:-2], dims[1:-1]):
                layers += [nn.Linear(n_in, n_out), activation_func()]
                if dropout:
                    layers.append(nn.Dropout(dropout_rate))
            layers.append(nn.Linear(dims[-2], dims[-1]))
            if activate_last:
                layers.append(activation_func())
            return nn.Sequential(*layers)

        self.encoder = stack((input_dim, hidden_dim1, hidden_dim2, code_dim), code_activation)
        self.decoder = stack((code_dim, hidden_dim2, hidden_dim1, input_dim), False)

    def forward(self, x):
        """Return ``(code, reconstruction)`` for the input batch ``x``."""
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return code, reconstruction
    
# Deep autoencoder with three hidden layers
class DeepAutoencoderThreeHiddenLayers(nn.Module):
    """Symmetric fully connected autoencoder with three hidden layers per side.

    Encoder: input_dim -> hidden_dim1 -> hidden_dim2 -> hidden_dim3 -> code_dim, with
    an optional activation on the code. Decoder mirrors it back to input_dim, with no
    activation on the output. When ``dropout`` is set, a Dropout(dropout_rate) follows
    each hidden activation.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, code_dim, activation_func=nn.ReLU,
                 code_activation=True, dropout=False, dropout_rate=0.5):
        super(DeepAutoencoderThreeHiddenLayers, self).__init__()

        def stack(dims, activate_last):
            # Linear -> activation [-> dropout] for each hidden transition,
            # then a final Linear, optionally followed by one more activation.
            layers = []
            for n_in, n_out in zip(dims[:-2], dims[1:-1]):
                layers += [nn.Linear(n_in, n_out), activation_func()]
                if dropout:
                    layers.append(nn.Dropout(dropout_rate))
            layers.append(nn.Linear(dims[-2], dims[-1]))
            if activate_last:
                layers.append(activation_func())
            return nn.Sequential(*layers)

        self.encoder = stack((input_dim, hidden_dim1, hidden_dim2, hidden_dim3, code_dim), code_activation)
        self.decoder = stack((code_dim, hidden_dim3, hidden_dim2, hidden_dim1, input_dim), False)

    def forward(self, x):
        """Return ``(code, reconstruction)`` for the input batch ``x``."""
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return code, reconstruction

class ForwardNetworkOneHiddenLayer(torch.nn.Module):
    """Feed-forward regressor: Linear -> activation -> Linear to a single output unit."""

    def __init__(self, input_dim, hidden_dim1, activation_func=nn.ReLU,
                 out_activation=None):
        super(ForwardNetworkOneHiddenLayer, self).__init__()
        hidden = nn.Linear(input_dim, hidden_dim1)
        nonlinearity = activation_func()
        head = nn.Linear(hidden_dim1, 1)
        self.layers = nn.Sequential(hidden, nonlinearity, head)
        self.out_activation = out_activation

    def forward(self, x):
        """Return the network output (last dimension of size 1), optionally passed through out_activation."""
        y = self.layers(x)
        if not self.out_activation:
            return y
        return self.out_activation(y)

# Rec system with incorporated autoencoders
class RecSystemWithAutoencoders(torch.nn.Module):
    """Scores drug-cell line pairs by the dot product of their autoencoder codes.

    Each autoencoder must return ``(code, reconstruction)``. The two codes are
    multiplied element-wise, so their shapes must be compatible. ``forward`` returns
    ``(prediction, drug_reconstruction, cell_line_reconstruction)``, where
    prediction is shaped ``(-1, 1)``.
    """

    def __init__(self,
                 drug_autoencoder,
                 cell_line_autoencoder,
                 out_activation=None):
        super(RecSystemWithAutoencoders, self).__init__()
        self.drug_autoencoder = drug_autoencoder
        self.cell_line_autoencoder = cell_line_autoencoder
        self.out_activation = out_activation

    def forward(self, drug_features, cell_line_features):
        drug_code, drug_rec = self.drug_autoencoder(drug_features)
        cl_code, cl_rec = self.cell_line_autoencoder(cell_line_features)

        scores = (drug_code * cl_code).sum(dim=1).view(-1, 1)
        if self.out_activation:
            scores = self.out_activation(scores)
        return scores, drug_rec, cl_rec

class ForwardLinearRegression(torch.nn.Module):
    """Single linear layer mapping ``input_dim`` features to one output, with optional output activation."""

    def __init__(self, input_dim, out_activation=None):
        super(ForwardLinearRegression, self).__init__()
        self.linear = nn.Linear(input_dim, 1)
        self.out_activation = out_activation

    def forward(self, x):
        prediction = self.linear(x)
        return prediction if not self.out_activation else self.out_activation(prediction)

class ForwardNetworkTwoHiddenLayers(torch.nn.Module):
    """Feed-forward regressor with two hidden layers and a single output unit.

    ``self.layers`` is: Linear -> activation -> Dropout(dropout_rate) ->
    Linear -> activation -> Linear(hidden_dim2, 1). Dropout is applied only
    after the first hidden layer.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, activation_func=nn.ReLU,
                 out_activation=None, dropout_rate=0.0):
        super(ForwardNetworkTwoHiddenLayers, self).__init__()
        stack = [nn.Linear(input_dim, hidden_dim1), activation_func()]
        stack.append(nn.Dropout(dropout_rate))
        stack += [nn.Linear(hidden_dim1, hidden_dim2), activation_func()]
        stack.append(nn.Linear(hidden_dim2, 1))
        self.layers = nn.Sequential(*stack)
        self.out_activation = out_activation

    def forward(self, x):
        """Return the network output, optionally passed through out_activation."""
        out = self.layers(x)
        return self.out_activation(out) if self.out_activation else out

class RecSystemCodeConcatenation(torch.nn.Module):
    """Feeds concatenated drug and cell line autoencoder codes to a forward network.

    With ``code_interactions`` enabled, the flattened per-sample outer product of the
    two codes is appended after the codes themselves. ``forward`` returns
    ``(forward_network_output, drug_reconstruction, cell_line_reconstruction)``.
    """

    def __init__(self, drug_autoencoder, cell_line_autoencoder,
                 forward_network,
                 code_interactions=False):
        super(RecSystemCodeConcatenation, self).__init__()
        self.drug_autoencoder = drug_autoencoder
        self.cell_line_autoencoder = cell_line_autoencoder
        self.forward_network = forward_network
        self.code_interactions = code_interactions

    def forward(self, drug_features, cell_line_features):
        drug_code, drug_rec = self.drug_autoencoder(drug_features)
        cl_code, cl_rec = self.cell_line_autoencoder(cell_line_features)

        parts = [drug_code, cl_code]
        if self.code_interactions:
            n_rows = cl_code.shape[0]
            # (batch, d, 1) @ (batch, 1, c) -> (batch, d, c): per-sample outer product.
            column = drug_code.view(drug_code.shape[0], drug_code.shape[1], 1)
            row = cl_code.view(n_rows, 1, cl_code.shape[1])
            outer = torch.bmm(column, row)
            parts.append(outer.view(n_rows, outer.shape[1] * outer.shape[2]))

        features = torch.cat(tuple(parts), dim=1)
        return self.forward_network(features), drug_rec, cl_rec

### MODEL CLASSES
class Model:
    """Wrapper around PyTorch model.

    Contains helper functions for training and evaluating the underlying network.
    The network is called as ``network(drug_batch, cell_line_batch)``.
    """
    def __init__(self, name, network):
        """Instance initializer.

        Args:
            name (str): Custom name of the model.
            network (PyTorch model): Underlying PyTorch model.
        """
        self.name = name
        self.network = network

    @staticmethod
    def _get_device():
        """Return the ``cuda:0`` device when CUDA is available, otherwise the CPU."""
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _l2_penalty(self):
        """Return 0.5 * sum of squared entries over all network parameters (weights and biases)."""
        return sum(0.5 * (param ** 2).sum() for param in self.network.parameters())

    def train(self, train_samples, cell_line_features, drug_features,
              batch_size, optimizer, criterion, reg_lambda=0, log=True, response_metric="AUC"):
        """Perform one epoch of training of the underlying network.

        Rows of ``train_samples`` are consumed in their current order (no shuffling), in
        consecutive mini-batches of ``batch_size`` rows. The last batch holds the remainder
        and is never empty.

        Args:
            train_samples (DataFrame): Table containing drug-cell line training pairs and corresponding
                response metric, with "COSMIC_ID" and "DRUG_ID" columns.
            cell_line_features (DataFrame): Cell line features data, indexed by COSMIC_ID.
            drug_features (DataFrame): Drug features data, indexed by DRUG_ID.
            batch_size (int): Batch size.
            optimizer (PyTorch optimizer): Optimizer to use.
            criterion (PyTorch cost function): Cost function to optimize.
            reg_lambda (float): Weight of the L2 regularization, defaults to 0.
            log (bool): If to print some information during training, defaults to True.
            response_metric (str): Column of ``train_samples`` used as the target, defaults to "AUC".

        Returns:
            loss (torch.Tensor): Loss of the last mini-batch of the epoch (response loss plus the
                weighted L2 penalty).

        Raises:
            ValueError: If ``batch_size`` is not positive or ``train_samples`` is empty.
        """
        n_samples = train_samples.shape[0]
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if n_samples == 0:
            raise ValueError("train_samples is empty")
        # Ceiling division. The former `n // batch_size + 1` added an empty trailing batch
        # whenever n was a multiple of batch_size. With a mean-reduced criterion such as
        # nn.MSELoss, that batch's loss is NaN, and it was the loss returned.
        no_batches = (n_samples + batch_size - 1) // batch_size

        device = self._get_device()
        if log:
            print(device)
        # Move the network into device and switch it to training mode
        self.network.to(device)
        self.network.train()

        for batch in range(no_batches):
            samples_batch = train_samples.iloc[batch * batch_size:(batch + 1) * batch_size]

            # Target values, shaped (batch, 1)
            y_batch = torch.from_numpy(samples_batch[response_metric].values).view(-1, 1).to(device)

            # Look up the feature rows for the cell lines and drugs in this batch
            cl_ids = samples_batch["COSMIC_ID"].values
            cell_line_input_batch = torch.from_numpy(cell_line_features.loc[cl_ids].values).to(device)
            drug_ids = samples_batch["DRUG_ID"].values
            drug_input_batch = torch.from_numpy(drug_features.loc[drug_ids].values).to(device)

            # Clear gradient buffers because we don't want to accumulate gradients
            optimizer.zero_grad()

            batch_output = self.network(drug_input_batch.float(), cell_line_input_batch.float())

            loss = criterion(batch_output, y_batch.float())
            if reg_lambda:
                # Skipped entirely when the weight is 0 (it would contribute nothing).
                loss = loss + reg_lambda * self._l2_penalty()

            loss.backward()
            optimizer.step()
        return loss

    def predict(self, samples, cell_line_features, drug_features, response_metric="AUC"):
        """Predict response for a given set of samples.

        Args:
            samples (DataFrame): Table containing drug-cell line pairs and corresponding response metric.
            cell_line_features (DataFrame): Cell line features data, indexed by COSMIC_ID.
            drug_features (DataFrame): Drug features data, indexed by DRUG_ID.
            response_metric (str): Column of ``samples`` holding the true values, defaults to "AUC".

        Returns:
            predicted (torch.Tensor): Model's predictions for provided samples (on the compute device).
            y_true (np.array): True response values for provided samples.
        """
        device = self._get_device()
        y_true = samples[response_metric].values

        cl_input = cell_line_features.loc[samples["COSMIC_ID"].values].values
        drug_input = drug_features.loc[samples["DRUG_ID"].values].values

        # Put the network on the same device as the inputs, even if train() has not done so.
        self.network.to(device)
        self.network.eval()
        with torch.no_grad():
            predicted = self.network(torch.from_numpy(drug_input).to(device).float(),
                                     torch.from_numpy(cl_input).to(device).float())
        return predicted, y_true

    @staticmethod
    def _to_numpy_1d(values):
        """Return ``values`` (tensor on any device, or array-like) as a flat NumPy array."""
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    @staticmethod
    def _entity_performance_df(samples, predicted, group_column, id_column_name, count_samples,
                               mean_training_auc, response_metric):
        """Shared implementation of the per-drug and per-entity evaluation tables.

        Args:
            samples (DataFrame): Drug-cell line pairs with the true response column.
            predicted (torch.Tensor or array-like): Predictions aligned with the rows of ``samples``.
            group_column (str): Column of ``samples`` to group by.
            id_column_name (str): Name of the identifier column in the output table.
            count_samples (callable): Maps a group's DataFrame to its "No. samples" value.
            mean_training_auc (float or None): Constant dummy prediction; None means the group's own mean.
            response_metric (str): Name of the true response column.

        Returns:
            DataFrame with one row per group that has at least two samples.
        """
        pred_col = "Predicted " + str(response_metric)
        with_predictions = samples.copy()
        with_predictions[pred_col] = Model._to_numpy_1d(predicted)

        table = {id_column_name: [], "Model RMSE": [], "Model correlation": [],
                 "Dummy RMSE": [], "Dummy correlation": [], "No. samples": []}

        # sort=False keeps groups in order of first appearance, like Series.unique().
        for entity, df in with_predictions.groupby(group_column, sort=False):
            if df.shape[0] < 2:
                # A correlation needs at least two points
                continue
            y_group = df[response_metric]
            # `is not None` so that a training mean of exactly 0.0 is still honoured.
            baseline = mean_training_auc if mean_training_auc is not None else y_group.mean()
            dummy_preds = [baseline] * df.shape[0]
            dummy_rmse = metrics.mean_squared_error(y_group, dummy_preds) ** 0.5
            dummy_corr = pearsonr(y_group, dummy_preds)

            try:
                model_rmse = metrics.mean_squared_error(y_group, df[pred_col]) ** 0.5
                model_corr = pearsonr(y_group, df[pred_col])
            except ValueError:
                model_rmse, model_corr = np.nan, (np.nan, np.nan)

            table[id_column_name].append(entity)
            table["Model RMSE"].append(model_rmse)
            table["Model correlation"].append(model_corr[0])
            table["Dummy RMSE"].append(dummy_rmse)
            table["Dummy correlation"].append(dummy_corr[0])
            table["No. samples"].append(count_samples(df))

        return pd.DataFrame(table)

    @staticmethod
    def per_drug_performance_df(samples, predicted, mean_training_auc=None, response_metric="AUC"):
        """Compute evaluation metrics per drug and return them in a DataFrame.

        Args:
            samples (DataFrame): Table containing drug-cell line pairs and corresponding response metric.
            predicted (torch.Tensor): Model's predictions for considered samples (CPU or CUDA).
            mean_training_auc (float): Mean of drug-response in training data for calculating dummy values,
                defaults to None. If None, mean of true AUC for a given drug is considered, resulting in
                dummy RMSE being the standard deviation of the AUC for a given drug.
            response_metric (str): Name of the true response column, defaults to "AUC".

        Returns:
            performance_per_drug (DataFrame): Table containing per-drug model and dummy performance metrics.
                "No. samples" is the number of distinct cell lines (COSMIC_ID) tested with the drug.
        """
        return Model._entity_performance_df(
            samples, predicted,
            group_column="DRUG_ID",
            id_column_name="Drug ID",
            count_samples=lambda df: df["COSMIC_ID"].nunique(),
            mean_training_auc=mean_training_auc,
            response_metric=response_metric)

    @staticmethod
    def per_entity_performance_df(samples, predicted, entity_type="DRUG_ID", mean_training_auc=None,
                                  response_metric="AUC"):
        """Compute evaluation metrics per entity (drug or cell line) and return them in a DataFrame.

        Args:
            samples (DataFrame): Table containing drug-cell line pairs and corresponding response metric.
            predicted (torch.Tensor): Model's predictions for considered samples (CPU or CUDA).
            entity_type (str): Column to group by (e.g. "DRUG_ID" or "COSMIC_ID"), defaults to "DRUG_ID".
            mean_training_auc (float): Mean of drug-response in training data for calculating dummy values,
                defaults to None. If None, mean of true AUC for a given entity is considered, resulting in
                dummy RMSE being the standard deviation of the AUC for a given entity.
            response_metric (str): Name of the true response column, defaults to "AUC".

        Returns:
            performance_per_entity (DataFrame): Table containing per-entity model and dummy performance metrics.
                "No. samples" is the number of rows for the entity.
        """
        return Model._entity_performance_df(
            samples, predicted,
            group_column=entity_type,
            id_column_name=entity_type,
            count_samples=len,
            mean_training_auc=mean_training_auc,
            response_metric=response_metric)

    @staticmethod
    def evaluate_predictions(y_true, preds):
        """Compute RMSE and correlation with true values for model predictions.

        Returns:
            tuple: ``(rmse, pearsonr(y_true, preds))``.
        """
        return metrics.mean_squared_error(y_true, preds) ** 0.5, pearsonr(y_true, preds)
    
       

    
class ModelWithAutoencoders(Model):
    """ Wrapper around PyTorch model involving autoencoders.

    Inherits from Model class. Train and predict methods are adjusted for optimizing
    drug sensitivity predictions as well as drug and cell line reconstructions.

    """
    def train(self, train_samples, cell_line_features, drug_features,
              batch_size, optimizer, criterion, reconstruction_term_drug=0.0,
              reconstruction_term_cl=0.0, reg_lambda=0.0, log=True, response_metric="AUC"):
        """Perform one epoch of training of the underlying network with autoencoders.

        In addition to the drug-response prediction loss, also optimize the difference between
        the drug and cell line inputs and the reconstructions returned by the network.
        Rows of ``train_samples`` are consumed in their current order (no shuffling), in
        consecutive mini-batches of ``batch_size`` rows. The last batch holds the remainder
        and is never empty.

        Args:
            train_samples (DataFrame): Table containing drug-cell line training pairs and corresponding
                response metric, with "COSMIC_ID" and "DRUG_ID" columns.
            cell_line_features (DataFrame): Cell line features data, indexed by COSMIC_ID.
            drug_features (DataFrame): Drug features data, indexed by DRUG_ID.
            batch_size (int): Batch size.
            optimizer (PyTorch optimizer): Optimizer to use.
            criterion (PyTorch cost function): Cost function used for the response and both reconstructions.
            reconstruction_term_drug (float): Weight of reconstruction of input data in
                drug autoencoder, defaults to 0.
            reconstruction_term_cl (float): Weight of reconstruction of input data in
                cell line autoencoder, defaults to 0.
            reg_lambda (float): Weight of the L2 regularization, defaults to 0.
            log (bool): If to print some information during training, defaults to True.
            response_metric (str): Column of ``train_samples`` used as the target, defaults to "AUC".

        Returns:
            loss (torch.Tensor): Combined loss of the last mini-batch of the epoch.
            drug_reconstruction_loss (torch.Tensor): Last mini-batch loss between drug input and its reconstruction.
            cl_reconstruction_loss (torch.Tensor): Last mini-batch loss between cell line input and its reconstruction.

        Raises:
            ValueError: If ``batch_size`` is not positive or ``train_samples`` is empty.
        """
        n_samples = train_samples.shape[0]
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if n_samples == 0:
            raise ValueError("train_samples is empty")
        # Ceiling division. The former `n // batch_size + 1` added an empty trailing batch
        # whenever n was a multiple of batch_size. With a mean-reduced criterion such as
        # nn.MSELoss, that batch's losses are NaN, and they were the values returned.
        no_batches = (n_samples + batch_size - 1) // batch_size

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if log:
            print(device)
        # Move the network into device and switch it to training mode
        self.network.to(device)
        self.network.train()

        for batch in range(no_batches):
            samples_batch = train_samples.iloc[batch * batch_size:(batch + 1) * batch_size]

            # Target values, shaped (batch, 1)
            y_batch = torch.from_numpy(samples_batch[response_metric].values).view(-1, 1).to(device)

            # Look up the feature rows for the cell lines and drugs in this batch
            cl_ids = samples_batch["COSMIC_ID"].values
            cl_input = torch.from_numpy(cell_line_features.loc[cl_ids].values).to(device).float()
            drug_ids = samples_batch["DRUG_ID"].values
            drug_input = torch.from_numpy(drug_features.loc[drug_ids].values).to(device).float()

            # Clear gradient buffers because we don't want to accumulate gradients
            optimizer.zero_grad()

            batch_output, batch_drug_reconstruction, batch_cl_reconstruction = self.network(
                drug_input, cl_input)

            output_loss = criterion(batch_output, y_batch.float())
            if reg_lambda:
                # L2 penalty: 0.5 * sum of squares over all parameters (skipped when weight is 0)
                reg_sum = 0
                for param in self.network.parameters():
                    reg_sum += 0.5 * (param ** 2).sum()
                output_loss = output_loss + reg_lambda * reg_sum
            drug_reconstruction_loss = criterion(batch_drug_reconstruction, drug_input)
            cl_reconstruction_loss = criterion(batch_cl_reconstruction, cl_input)

            # Weighted sum of response loss and both reconstruction losses
            loss = (output_loss
                    + reconstruction_term_drug * drug_reconstruction_loss
                    + reconstruction_term_cl * cl_reconstruction_loss)
            loss.backward()
            optimizer.step()

        return loss, drug_reconstruction_loss, cl_reconstruction_loss
    
    def predict(self, samples, cell_line_features, drug_features, response_metric="AUC"):
        """Predict response along with drug and cell line reconstructions for a given set of samples.

        Args:
            samples (DataFrame): Table containing drug-cell line pairs and corresponding response metric.
            cell_line_features (DataFrame): Cell line features data, indexed by COSMIC_ID.
            drug_features (DataFrame): Drug features data, indexed by DRUG_ID.
            response_metric (str): Column of ``samples`` holding the true values, defaults to "AUC".

        Returns:
            predicted: The network's raw output for the samples. For the autoencoder networks in this
                file it is the tuple ``(response_predictions, drug_reconstruction, cell_line_reconstruction)``.
            y_true (np.array): True response values for provided samples.
            drug_input (np.array): Drug input data for provided samples.
            cl_input (np.array): Cell line input data for provided samples.
        """
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        y_true = samples[response_metric].values

        cl_input = cell_line_features.loc[samples["COSMIC_ID"].values].values
        drug_input = drug_features.loc[samples["DRUG_ID"].values].values

        # Put the network on the same device as the inputs, even if train() has not done so.
        self.network.to(device)
        self.network.eval()
        with torch.no_grad():
            predicted = self.network(torch.from_numpy(drug_input).to(device).float(),
                                     torch.from_numpy(cl_input).to(device).float())
        return predicted, y_true, drug_input, cl_input
        
    def train_with_independence_penalty(self, train_samples, cell_line_features, drug_features,
             batch_size, optimizer, criterion, reconstruction_term_drug=0.0,
              reconstruction_term_cl=0.0, independence_term_drug=0.0, independence_term_cl=0.0, reg_lambda=0.0, log=True,
              response_metric="AUC"):
        """Perform one epoch of training of the underlying network with autoencoders.

        Rather than only drug-reponse prediction losss, also optimize for difference in drug and cell line
        input data and their corresponding reconstructions.

        Args:
            train_samples (DataFrame): Table containing drug-cell line training pairs and corresponding response metric.
            cell_line_features (DataFrame): Cell line features data.
            drug_features (DataFrame): Drug features data.
            batch_size (int): Batch size.
            optimizer (PyTorch optimizer): Optimizer to use.
            criterion (PyTorch cost function): Cost function to optimize.
            reconstruction_term_drug (float): Weight of reconstruction of input data in
                drug autoencoder, defaults to 0.
            reconstruction_term_cl (float): Weight of reconstruction of input data in
                cell line autoencoder, defaults to 0.
            reg_lambda (float): Weight of the L2 regularization, defaults to 0.
            log (bool): If to print some information during training, defaults to True.

        Returns:
            loss (float): Value of the loss drug response loss after one epoch of training.
            drug_recounstruction_loss (float): Loss between drug input and drug reconstruction.
            cl_reconstruction_loss (float): Loss between cell line input and cell line reconstruction.

        """
        no_batches = train_samples.shape[0] // batch_size + 1
        # Establish the device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if log:
          print(device)
        # Move the network into device
        self.network.to(device)
        # Training the model
        self.network.train()
        for batch in range(no_batches):
            # Separate response variable batch
            if batch != no_batches:
                samples_batch = train_samples.iloc[batch * batch_size:(batch + 1) * batch_size]
            else:
                samples_batch = train_samples.iloc[batch * batch_size:]

            # Extract output variable batch
            y_batch = torch.from_numpy(samples_batch[response_metric].values).view(-1, 1).to(device)

            # Extract cell lines IDs for which data shall be extracted
            cl_ids = samples_batch["COSMIC_ID"].values
            # Extract corresponding cell line data
            cell_line_input_batch = cell_line_features.loc[cl_ids].values
            cell_line_input_batch = torch.from_numpy(cell_line_input_batch).to(device)

            # Extract drug IDs for which data shall be extracted
            drug_ids = samples_batch["DRUG_ID"].values
            # Extract corresponding drug data
            drug_input_batch = drug_features.loc[drug_ids].values
            drug_input_batch = torch.from_numpy(drug_input_batch).to(device)

            # Clear gradient buffers because we don't want to accummulate gradients 
            optimizer.zero_grad()

            # Perform forward pass
            batch_output, batch_drug_reconstruction, batch_cl_reconstruction = self.network(
                drug_input_batch.float(), cell_line_input_batch.float())

            # L2 regularization
            reg_sum = 0
            for param in self.network.parameters():
                reg_sum += 0.5 * (param ** 2).sum()  # L2 norm

            # Compute the loss for this batch, including the drug and cell line reconstruction losses
            output_loss = criterion(batch_output, y_batch.float()) + reg_lambda * reg_sum
            drug_recounstruction_loss = criterion(batch_drug_reconstruction, drug_input_batch.float())
            cl_reconstruction_loss = criterion(batch_cl_reconstruction, cell_line_input_batch.float())

            # Compute independence loss
            # Covariance matrices
            t0 = time.time()
            drug_codes_batch = self.network.drug_autoencoder.encoder(drug_input_batch.float())
            cl_codes_batch = self.network.cell_line_autoencoder.encoder(cell_line_input_batch.float())
            
            drug_cov = self.__class__.covariance_matrix_torch(drug_codes_batch)
            cl_cov = self.__class__.covariance_matrix_torch(cl_codes_batch)
      
            drug_independence_loss = 0
            cl_independence_loss = 0
            drug_independence_loss = (drug_cov * drug_cov).sum() - torch.trace(drug_cov * drug_cov)
            cl_independence_loss = (cl_cov * cl_cov).sum() - torch.trace(cl_cov * cl_cov)

            # Combine the losses in the final cost function
            loss = output_loss + reconstruction_term_drug * drug_recounstruction_loss + \
                    reconstruction_term_cl * cl_reconstruction_loss + \
                    independence_term_drug * drug_independence_loss + independence_term_cl * cl_independence_loss
            # Get the gradients w.r.t. the parameters
            loss.backward()
            # Update the parameters
            optimizer.step()
            
        return loss, drug_recounstruction_loss, cl_reconstruction_loss, drug_independence_loss, cl_independence_loss

    @staticmethod
    def covariance_matrix_torch(m, rowvar=False):
        '''Estimate a covariance matrix given data (a PyTorch analogue of ``numpy.cov``).

        Covariance indicates the level to which two variables vary together.
        If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
        then the covariance matrix element `C_{ij}` is the covariance of
        `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

        Args:
            m: A 1-D or 2-D tensor containing multiple variables and observations.
                With the default ``rowvar=False`` each column is a variable and each
                row is one observation of all variables. A 1-D tensor is treated as
                a single variable. `m` itself is never modified.
            rowvar: If `rowvar` is True, then each row represents a
                variable, with observations in the columns. Otherwise, the
                relationship is transposed: each column represents a variable,
                while the rows contain observations.

        Returns:
            The unbiased (normalized by N - 1) covariance matrix of the variables,
            squeezed, so a single variable yields a 0-dim tensor.

        Raises:
            ValueError: If `m` has more than 2 dimensions.
            ZeroDivisionError: If there is only a single observation.
        '''
        if m.dim() > 2:
            raise ValueError('m has more than 2 dimensions')
        if m.dim() < 2:
            m = m.view(1, -1)
        if not rowvar and m.size(0) != 1:
            # Work with variables in rows from here on.
            m = m.t()
        fact = 1.0 / (m.size(1) - 1)
        # Centre out-of-place: `m` can be (a transposed view of) the caller's tensor, so an
        # in-place `-=` would silently modify the caller's data and fails outright on leaf
        # tensors that require grad.
        m = m - torch.mean(m, dim=1, keepdim=True)
        return fact * m.matmul(m.t()).squeeze()

def min_max_series(s, minimum=None, maximum=None):
    """Perform min-max scaling on a one-dimensional Series or array.

    Args:
        s: pandas Series or NumPy array to scale.
        minimum: Value mapped to 0. Defaults to ``s.min()`` when None.
        maximum: Value mapped to 1. Defaults to ``s.max()`` when None.

    Returns:
        ``(s - minimum) / (maximum - minimum)``. Values outside [minimum, maximum]
        map outside [0, 1]; a zero range produces inf/NaN entries.
    """
    # Compare against None explicitly: a legitimate bound of 0 is falsy and must not be
    # silently replaced by the statistics of `s` (e.g. when scaling validation/test data
    # with training-set bounds).
    lower = s.min() if minimum is None else minimum
    upper = s.max() if maximum is None else maximum
    return (s - lower) / (upper - lower)

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the current validation loss and the model.
    Whenever the loss improves, the model's ``state_dict`` is saved to ``path``; after
    ``patience`` consecutive calls without improvement, ``early_stop`` becomes True.
    A NaN loss always counts as "no improvement".
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pth'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): File the best model's state dict is written to.
                            Default: 'checkpoint.pth'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.inf` (lower case): the `np.Inf` alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss, checkpointing `model` if it improved."""
        score = -val_loss

        # NaN compares False with everything: without this guard a NaN loss would become the
        # best score and every later loss would then count as an improvement, so early
        # stopping could never trigger.
        if np.isnan(float(score)):
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = score >= self.best_score + self.delta

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

def _build_autoencoder(specs, prefix):
    """Build the `prefix` ("drug" or "cell_line") autoencoder with ``specs["num_layers"]`` hidden layers.

    Reads ``specs[prefix + "_dim"]``, ``specs[prefix + "_hidden_dim1"]`` up to
    ``specs[prefix + "_hidden_dim<num_layers>"]``, ``specs["code_dim"]`` and the shared
    activation/dropout settings.

    Raises:
        ValueError: If ``specs["num_layers"]`` is not 1, 2 or 3.
    """
    num_layers = specs["num_layers"]
    if num_layers == 1:
        autoencoder_cls = DeepAutoencoderOneHiddenLayer
    elif num_layers == 2:
        autoencoder_cls = DeepAutoencoderTwoHiddenLayers
    elif num_layers == 3:
        autoencoder_cls = DeepAutoencoderThreeHiddenLayers
    else:
        raise ValueError("Unsupported num_layers: {!r} (expected 1, 2 or 3)".format(num_layers))
    hidden_dims = [specs["{}_hidden_dim{}".format(prefix, i)] for i in range(1, num_layers + 1)]
    return autoencoder_cls(specs[prefix + "_dim"], *hidden_dims, specs["code_dim"],
                           activation_func=specs["activation_func"],
                           code_activation=specs["code_activation"],
                           dropout=specs["dropout"],
                           dropout_rate=specs["dropout_rate"])


def instantiate_system(specs, state_dict=None):
    """Create a recommender system in accordance with provided specs.

    Args:
        specs (dict): Architecture description. ``specs["architecture_type"]`` (case-insensitive)
            selects the model: "linear" builds a LinearMatrixFactorizationWithFeatures; any value
            containing "autoencoder" builds a RecSystemWithAutoencoders from a drug and a cell
            line autoencoder with ``specs["num_layers"]`` (1-3) hidden layers each.
        state_dict (dict, optional): Weights loaded into the network; skipped when falsy.

    Returns:
        The constructed (and optionally weight-loaded) network.

    Raises:
        ValueError: If the architecture type or the number of layers is not supported.
    """
    architecture_type = specs["architecture_type"].lower()
    if architecture_type == "linear":
        network = LinearMatrixFactorizationWithFeatures(specs["drug_dim"],
                                                        specs["cell_line_dim"], specs["code_dim"],
                                                        out_activation_func=specs["out_activation"],
                                                        drug_bias=specs["drug_bias"],
                                                        cell_line_bias=specs["cell_line_bias"])
    elif "autoencoder" in architecture_type:
        drug_autoencoder = _build_autoencoder(specs, "drug")
        cell_line_autoencoder = _build_autoencoder(specs, "cell_line")
        network = RecSystemWithAutoencoders(drug_autoencoder,
                                            cell_line_autoencoder,
                                            specs["out_activation"])
    else:
        # Previously this fell through and crashed with UnboundLocalError on `network`.
        raise ValueError("Unsupported architecture_type: {!r}".format(specs["architecture_type"]))
    # If state dict is provided, load the weights
    if state_dict:
        network.load_state_dict(state_dict)
    return network

## Load and preprocess data

In [None]:
# Load the preprocessed dataset object (serialized with dill).
with open(project_dir + "Data/Preprocessed Datasets/GDSC-KINOMEscan_proteins_intersection_+_remaining_GDSC_target_genes_dataset_with_IC50.pkl", "rb") as f:
    full_dataset = dill.load(f)

## Data preprocessing
#### Establish response data for samples (drug-cell line pairs)

response_df = full_dataset.response_data.copy()

#### Establish cell line features data
cell_line_data_original_df = full_dataset.full_cell_lines_data.copy()

# Search for cell lines present in response data, but missing the genomic features.
# The available IDs are collected into a set once, so each membership test is O(1) instead of
# recomputing `.unique()` and scanning it for every response cell line.
available_cell_lines = set(cell_line_data_original_df.cell_line_id.unique())
missing_cell_lines = [cosmic_id for cosmic_id in response_df.COSMIC_ID.unique()
                      if cosmic_id not in available_cell_lines]
# Put cell line IDs into index and drop cell line IDs columns
cell_line_data_original_df.index = cell_line_data_original_df.cell_line_id
cell_line_data_original_df = cell_line_data_original_df.drop("cell_line_id", axis=1)

# Extract response only for cell lines for which features are present
response_df = response_df[~response_df.COSMIC_ID.isin(missing_cell_lines)]

#### Establish drug features data
drug_data_original_df = full_dataset.drugs_data.copy()

# Convert drug index from LINCS name to GDSC drug ID
drug_data_original_df.index = drug_data_original_df.index.map(full_dataset.kinomescan_name_to_gdsc_id_mapper)

print(drug_data_original_df.shape, cell_line_data_original_df.shape, response_df.shape)

# Establish input data dimensionalities
drug_dim = drug_data_original_df.shape[1]
cell_line_dim = cell_line_data_original_df.shape[1]

# Keep only the sample identifiers and the response metric
response_df = response_df[["DRUG_ID", "COSMIC_ID", global_response_metric]]
print(response_df.shape)
# Shown as the cell output in the notebook
response_df.head()

## Trainable function for Ray

In [None]:
def _trainable_build_network(config):
    """Assemble the DEERS network (two autoencoders + forward network) from fixed_model_specs and `config`."""
    specs = fixed_model_specs
    drug_autoencoder = DeepAutoencoderOneHiddenLayer(specs["drug_dim"],
                                                     specs["drug_hidden_dim"], specs["code_dim"],
                                                     activation_func=specs["activation_func"],
                                                     code_activation=specs["code_activation"],
                                                     dropout=specs["auto_dropout"])
    cell_line_autoencoder = DeepAutoencoderOneHiddenLayer(specs["cell_line_dim"],
                                                          specs["cell_line_hidden_dim"], specs["code_dim"],
                                                          activation_func=specs["activation_func"],
                                                          code_activation=specs["code_activation"],
                                                          dropout=specs["auto_dropout"])
    # Input size 2 * code_dim: presumably the concatenated drug and cell line codes.
    forward_net = ForwardNetworkTwoHiddenLayers(2 * specs["code_dim"],
                                                specs["forward_net_hidden_dim1"],
                                                specs["forward_net_hidden_dim2"],
                                                dropout_rate=config["forward_net_drop_rate"],
                                                out_activation=specs["forward_net_out_act"])
    return RecSystemCodeConcatenation(drug_autoencoder, cell_line_autoencoder, forward_network=forward_net,
                                      code_interactions=specs["code_interactions"])


def _trainable_augment_batch(drug_input_batch, cell_line_input_batch, y_batch):
    """Append ``augm_data_factor`` noisy copies of a NumPy batch.

    Each copy adds Gaussian noise to the first ``cell_line_exp_idx`` cell line columns and to
    the response; the remaining cell line columns and the drug features are copied unchanged.
    Returns the enlarged (drug, cell line, response) arrays, originals first.
    """
    specs = fixed_training_specs
    noisy_block = cell_line_input_batch[:, :cell_line_exp_idx]
    untouched_block = cell_line_input_batch[:, cell_line_exp_idx:]
    drug_parts, cell_line_parts, y_parts = [drug_input_batch], [cell_line_input_batch], [y_batch]
    for _ in range(specs["augm_data_factor"]):
        cell_line_noise = np.random.normal(loc=specs["cl_noise_mean"],
                                           scale=specs["cl_noise_std"],
                                           size=noisy_block.shape)
        cell_line_parts.append(np.concatenate((noisy_block + cell_line_noise, untouched_block), axis=1))
        y_parts.append(y_batch + np.random.normal(loc=specs["auc_noise_mean"],
                                                  scale=specs["auc_noise_std"],
                                                  size=y_batch.shape))
        drug_parts.append(drug_input_batch)
    return (np.concatenate(drug_parts, axis=0),
            np.concatenate(cell_line_parts, axis=0),
            np.concatenate(y_parts, axis=0))


def _trainable_safe_evaluate(y_true, y_pred):
    """Return ``Model.evaluate_predictions(y_true, y_pred)``, or NaN placeholders if it raises ValueError."""
    try:
        return Model.evaluate_predictions(y_true, y_pred)
    except ValueError:
        return np.nan, (np.nan, np.nan)


def _trainable_evaluate_split(model, samples):
    """Predict on `samples`; return (metrics dict, per-drug performance DataFrame)."""
    predicted, y_true, drug_input, cl_input = model.predict(samples, cell_line_data_df, drug_data_df,
                                                            response_metric=global_response_metric)
    preds, drug_reconstruction, cl_reconstruction = predicted
    rmse, corr = _trainable_safe_evaluate(y_true, preds.cpu().numpy().reshape(-1))
    _, drug_rec_corr = _trainable_safe_evaluate(drug_input.flatten(),
                                                drug_reconstruction.cpu().numpy().flatten())
    _, cl_rec_corr = _trainable_safe_evaluate(cl_input.flatten(),
                                              cl_reconstruction.cpu().numpy().flatten())
    performance_df = Model.per_drug_performance_df(samples, preds.cpu(),
                                                   response_metric=global_response_metric)
    metrics = {
        "rmse": rmse,
        "corr": corr[0],
        "drug_rec_corr": drug_rec_corr[0],
        "cl_rec_corr": cl_rec_corr[0],
        "median_rmse": performance_df["Model RMSE"].median(),
        "median_corr": performance_df["Model correlation"].median(),
    }
    return metrics, performance_df


def trainable_autoencoders(config):
    """Ray Tune trainable: train one DEERS configuration and report per-epoch metrics via ``track.log``.

    Args:
        config (dict): Sampled hyperparameters; reads "learning_rate" and "forward_net_drop_rate".

    Reads notebook globals: samples_train, samples_val, split_seeds, experiment, fixed_model_specs,
    fixed_training_specs, cell_line_data_df, drug_data_df, global_response_metric (and
    cell_line_exp_idx when data augmentation is enabled). With "shuffle_train" set, the global
    samples_train is reassigned to a shuffled copy every epoch.

    Side effects: EarlyStopping writes checkpoint.pth on every improvement; when early stopping
    triggers or at the last epoch, performance_per_drug_train.csv, performance_per_drug_val.csv and
    network_end_state_dict.pth are written to the working directory and training stops.
    """
    warnings.filterwarnings("ignore")
    global samples_train
    # Seed torch (weight init) and NumPy (augmentation noise) so that trials are reproducible.
    seed = split_seeds[experiment - 1]
    torch.manual_seed(seed)
    np.random.seed(seed)
    print(samples_train.shape)

    network = _trainable_build_network(config)
    print(type(network))

    batch_size = fixed_training_specs["batch_size"]
    num_epochs = fixed_training_specs["num_epochs"]
    # Ceiling division: `n // batch_size + 1` produced an extra, empty batch whenever n was a
    # multiple of batch_size, and an empty batch yields NaN losses.
    no_batches = max(1, (samples_train.shape[0] + batch_size - 1) // batch_size)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    network.to(device)

    early_stopping = EarlyStopping(patience=fixed_training_specs["early_stopping_patience"], delta=0.0)

    optimizer = torch.optim.Adam(network.parameters(), lr=config["learning_rate"])
    if fixed_training_specs["weigh_by_true_auc"]:
        # Keep per-sample losses so they can be weighted by the true response below.
        output_criterion = nn.MSELoss(reduction="none")
    else:
        output_criterion = nn.MSELoss()
    reconstruction_criterion = nn.MSELoss()
    l2_lambda = fixed_training_specs.get("l2_lambda", 0.0)
    independence_term = fixed_training_specs["independence_term"]

    for epoch in range(1, num_epochs + 1):
        # model.predict() at the end of each epoch may switch the network out of training mode
        # (its code is not visible here), so re-enable it every epoch to keep dropout active.
        network.train()
        epoch_main_losses = 0
        epoch_y_losses = 0
        epoch_drug_indep_losses = 0
        epoch_cl_indep_losses = 0
        t0 = time.time()
        if fixed_training_specs["shuffle_train"]:
            samples_train = samples_train.sample(frac=1., random_state=11)
        for batch in range(no_batches):
            samples_batch = samples_train.iloc[batch * batch_size:(batch + 1) * batch_size]

            # Response, plus the cell line and drug feature rows of every sample in the batch
            y_batch = samples_batch[global_response_metric].values.reshape(-1, 1)
            cell_line_input_batch = cell_line_data_df.loc[samples_batch["COSMIC_ID"].values].values
            drug_input_batch = drug_data_df.loc[samples_batch["DRUG_ID"].values].values

            if fixed_training_specs["augment_data"]:
                drug_input_batch, cell_line_input_batch, y_batch = _trainable_augment_batch(
                    drug_input_batch, cell_line_input_batch, y_batch)

            # Clear gradient buffers because we don't want to accumulate gradients
            optimizer.zero_grad()

            cell_line_input_batch = torch.from_numpy(cell_line_input_batch).to(device).float()
            drug_input_batch = torch.from_numpy(drug_input_batch).to(device).float()
            y_batch = torch.from_numpy(y_batch).to(device).float()

            batch_output, batch_drug_reconstruction, batch_cl_reconstruction = network(
                drug_input_batch, cell_line_input_batch)

            if fixed_training_specs["weigh_by_true_auc"]:
                # Weight each sample's squared error by its true response
                output_loss = torch.mul(output_criterion(batch_output, y_batch), y_batch).mean()
            else:
                output_loss = output_criterion(batch_output, y_batch)
            drug_reconstruction_loss = reconstruction_criterion(batch_drug_reconstruction, drug_input_batch)
            cl_reconstruction_loss = reconstruction_criterion(batch_cl_reconstruction, cell_line_input_batch)

            # Independence losses: sum of squared off-diagonal entries of each code covariance matrix
            drug_cov = ModelWithAutoencoders.covariance_matrix_torch(
                network.drug_autoencoder.encoder(drug_input_batch))
            cl_cov = ModelWithAutoencoders.covariance_matrix_torch(
                network.cell_line_autoencoder.encoder(cell_line_input_batch))
            drug_independence_loss = (drug_cov * drug_cov).sum() - torch.trace(drug_cov * drug_cov)
            cl_independence_loss = (cl_cov * cl_cov).sum() - torch.trace(cl_cov * cl_cov)

            loss = fixed_training_specs["y_loss_weight"] * output_loss + \
                fixed_training_specs["reconstruction_term_drug"] * drug_reconstruction_loss + \
                fixed_training_specs["reconstruction_term_cl"] * cl_reconstruction_loss + \
                independence_term * drug_independence_loss + independence_term * cl_independence_loss
            # Optional L2 penalty; "l2_lambda" used to be declared in the specs but never applied.
            if l2_lambda:
                reg_sum = sum(0.5 * (param ** 2).sum() for param in network.parameters())
                loss = loss + l2_lambda * reg_sum

            loss.backward()
            optimizer.step()

            epoch_main_losses += loss.item()
            epoch_y_losses += output_loss.item()
            epoch_drug_indep_losses += drug_independence_loss.item()
            epoch_cl_indep_losses += cl_independence_loss.item()

        print("Epoch training time:", time.time() - t0)
        print("Epoch: {}, main loss: {}, output loss: {} drug independence loss: {}, cell line independence loss: {}".format(
            epoch, np.round(epoch_main_losses / no_batches, 3), np.round(epoch_y_losses / no_batches, 3),
            np.round(epoch_drug_indep_losses / no_batches, 3), np.round(epoch_cl_indep_losses / no_batches, 3)))

        # Evaluate on training and validation data
        model = ModelWithAutoencoders("Model", network)
        train_metrics, performance_df_train = _trainable_evaluate_split(model, samples_train)
        val_metrics, performance_df_val = _trainable_evaluate_split(model, samples_val)

        early_stopping(val_metrics["rmse"], network)

        # Save network's state dict at the end of the training
        if early_stopping.early_stop or epoch == num_epochs:
            performance_df_train.to_csv("./performance_per_drug_train.csv", index=False)
            performance_df_val.to_csv("./performance_per_drug_val.csv", index=False)
            torch.save(model.network.state_dict(), "./network_end_state_dict.pth")

        track.log(train_loss=epoch_main_losses / no_batches,
                  train_y_loss=epoch_y_losses / no_batches,
                  drug_independence_train_loss=epoch_drug_indep_losses / no_batches,
                  cl_independence_train_loss=epoch_cl_indep_losses / no_batches,
                  train_rmse=train_metrics["rmse"],
                  train_corr=train_metrics["corr"],
                  train_median_rmse=train_metrics["median_rmse"],
                  train_median_corr=train_metrics["median_corr"],
                  train_drug_rec_corr=train_metrics["drug_rec_corr"],
                  train_cl_rec_corr=train_metrics["cl_rec_corr"],
                  val_rmse=val_metrics["rmse"],
                  val_corr=val_metrics["corr"],
                  val_drug_rec_corr=val_metrics["drug_rec_corr"],
                  val_cl_rec_corr=val_metrics["cl_rec_corr"],
                  val_median_rmse=val_metrics["median_rmse"],
                  val_median_corr=val_metrics["median_corr"],
                  early_stopping_stop=early_stopping.early_stop,
                  early_stopping_counter=early_stopping.counter)

        # Stop once early stopping has triggered; continuing would keep training and overwrite
        # the files saved above with results from later epochs.
        if early_stopping.early_stop:
            break

## Specs for model and training, experiment setup

In [None]:
#### Experiment setup
# Fixed experiment parameters, model/training specs and the Ray Tune search space;
# all of them are also written to `global_experiment_dir`.
import json

# Cell lines held out for the validation and the test set
num_val_cell_lines, num_test_cell_lines = 100, 100

# Outer train/evaluation iterations, and parameter combinations tried in each of them
num_experimental_iterations = 5
num_tuning_samples = 20

# One data split seed per experimental iteration
split_seeds = [40, 65, 31, 9, 27]

# Metric (and optimization direction) for the parameter search
validation_metric, validation_mode = "val_rmse", "min"

########################################################
# SPECIFY RESULTS DIRECTORY
########################################################
global_experiment_name = "..."
global_experiment_dir = project_dir + "..." + global_experiment_name
print(global_experiment_dir)

if not os.path.exists(global_experiment_dir):
    os.makedirs(global_experiment_dir)

# Human-readable description of the experiment
setup = {
    "experiment name": global_experiment_name,
    "dataset name": full_dataset.name,
    "dataset description": full_dataset.description,
    "drug features dimension": drug_dim,
    "cell lines features dimension": cell_line_dim,
    "num experimental iterations": num_experimental_iterations,
    "num val cell lines": num_val_cell_lines,
    "num test cell lines": num_test_cell_lines,
    "data split seeds": str(split_seeds),
    "optimizer": "Adam",
    "num tuning samples": num_tuning_samples,
    "validation metric": validation_metric,
    "model": "RecSysNN",
}

# Architecture settings that are not tuned
fixed_model_specs = {
    "drug_dim": drug_dim,
    "cell_line_dim": cell_line_dim,
    "code_dim": 10,
    "drug_hidden_dim": 128,
    "cell_line_hidden_dim": 128,
    "auto_dropout": False,
    "activation_func": nn.ReLU,
    "code_activation": False,
    "forward_net_hidden_dim1": 512,
    "forward_net_hidden_dim2": 256,
    "forward_net_out_act": None,
    "code_interactions": False,
}

# Training settings that are not tuned
fixed_training_specs = {
    "batch_size": 512,
    "num_epochs": 150,
    "early_stopping_patience": 20,
    "weigh_by_true_auc": False,
    "augment_data": True,
    "augm_data_factor": 2,
    "cl_noise_mean": 0.0,
    "cl_noise_std": 0.6,
    "auc_noise_mean": 0.0,
    "auc_noise_std": 0.15,
    "reconstruction_term_drug": 0.1,
    "reconstruction_term_cl": 0.25,
    "independence_term": 0.1,
    "y_loss_weight": 1,
    "l2_lambda": 0.0,
    "shuffle_train": False,
    "shuffle_train_and_val": True,
    "scale_response_variable": True,
}

# Augmentation adds noise only to the cell line columns before this index
# (presumably the gene-expression features; confirm against the cell line feature layout).
if fixed_training_specs["augment_data"]:
    cell_line_exp_idx = 202

# Hyperparameter search space
search_space = {
    "learning_rate": tune.loguniform(1e-5, 1e-3),
    "forward_net_drop_rate": tune.grid_search([0.0, 0.5]),
}

# JSON cannot encode values such as nn.ReLU, so the specs are saved with every value stringified
fixed_model_specs_json = {key: str(value) for key, value in fixed_model_specs.items()}
fixed_training_specs_json = {key: str(value) for key, value in fixed_training_specs.items()}

# Save the experiment parameters
for file_name, payload in (("/experiment description.json", setup),
                           ("/fixed_model_specs.json", fixed_model_specs_json),
                           ("/fixed_training_specs.json", fixed_training_specs_json)):
    with open(global_experiment_dir + file_name, "w") as f:
        json.dump(payload, f)

# The search space holds Ray Tune sampling objects: pickle it with dill and keep a string copy
with open(global_experiment_dir + "/search_space.pkl", "wb") as f:
    dill.dump(search_space, f)
with open(global_experiment_dir + "/search_space.json", "w") as f:
    json.dump(str(search_space), f)

## Hyperparameter tuning loop

In [None]:
%%time
#### Hyperparameter tuning and evaluation iterations
# For every experimental iteration: split the samples so that val/test hold
# unseen cell lines, tune hyperparameters with Ray Tune on train/val, then
# retrain the best configuration on train+val, record per-epoch train/test
# metrics and save the results and the trained model in the Tune experiment dir.
import warnings
warnings.filterwarnings("ignore")


def _safe_evaluate(y_true, y_pred):
    """Evaluate predictions with ``Model.evaluate_predictions``.

    Returns the ``(rmse, corr)`` pair produced by ``Model.evaluate_predictions``.
    If that call (or unpacking its result) raises ``ValueError``, returns
    ``(np.nan, (np.nan, np.nan))`` instead so callers can still read ``corr[0]``.
    """
    try:
        rmse, corr = Model.evaluate_predictions(y_true, y_pred)
    except ValueError:
        rmse, corr = np.nan, (np.nan, np.nan)
    return rmse, corr


def _augment_batch(drug_batch, cell_line_batch, y_batch, specs, exp_idx):
    """Return a mini-batch extended with ``specs["augm_data_factor"]`` noisy copies.

    Each copy adds Gaussian noise (``cl_noise_mean``/``cl_noise_std``) to the
    first ``exp_idx`` cell-line columns and Gaussian noise
    (``auc_noise_mean``/``auc_noise_std``) to the response; the remaining
    cell-line columns and the drug inputs are repeated unchanged.
    Rows are ordered: original batch first, then each noisy copy in turn.

    Returns:
        (drug_batch, cell_line_batch, y_batch) as enlarged numpy arrays.
    """
    drug_parts, cl_parts, y_parts = [drug_batch], [cell_line_batch], [y_batch]
    for _ in range(specs["augm_data_factor"]):
        # Draw order (cell line noise, then response noise) is kept fixed so
        # results stay reproducible under a given numpy RNG state.
        cl_noise = np.random.normal(loc=specs["cl_noise_mean"],
                                    scale=specs["cl_noise_std"],
                                    size=cell_line_batch[:, :exp_idx].shape)
        cl_parts.append(np.concatenate((cell_line_batch[:, :exp_idx] + cl_noise,
                                        cell_line_batch[:, exp_idx:]), axis=1))
        y_parts.append(y_batch + np.random.normal(loc=specs["auc_noise_mean"],
                                                  scale=specs["auc_noise_std"],
                                                  size=y_batch.shape))
        drug_parts.append(drug_batch)
    return (np.concatenate(drug_parts, axis=0),
            np.concatenate(cl_parts, axis=0),
            np.concatenate(y_parts, axis=0))


for experiment in range(1, num_experimental_iterations + 1):
    # Data preprocessing
    # Split data into train/val/test sets - unseen cell lines.
    # The 4th returned value is not used in this loop.
    (samples_train, samples_val, samples_test,
     _cell_lines_unused, cell_lines_test) = Dataset.samples_train_test_split(
        response_df,
        num_val_cell_lines,
        num_test_cell_lines,
        split_seeds[experiment - 1],
        shuffle=True)

    # Scale the response variable if needed; min/max come from train only so
    # val/test information does not leak into the scaling.
    if fixed_training_specs["scale_response_variable"]:
        minimum, maximum = samples_train[global_response_metric].min(), samples_train[global_response_metric].max()
        samples_train[global_response_metric] = Dataset.min_max_series(samples_train[global_response_metric], minimum, maximum)
        samples_val[global_response_metric] = Dataset.min_max_series(samples_val[global_response_metric], minimum, maximum)
        samples_test[global_response_metric] = Dataset.min_max_series(samples_test[global_response_metric], minimum, maximum)

    print("Train:", samples_train[global_response_metric].mean(), samples_train[global_response_metric].min(),
          samples_train[global_response_metric].max())
    print("Val:", samples_val[global_response_metric].mean(), samples_val[global_response_metric].min(),
          samples_val[global_response_metric].max())
    print("Test:", samples_test[global_response_metric].mean(), samples_test[global_response_metric].min(),
          samples_test[global_response_metric].max())

    # Normalize the data using statistics of the rows present in the train split
    # Cell line data: only the "_exp" columns are standardized
    cols_subset = [col for col in list(cell_line_data_original_df) if col.endswith("_exp")]
    rows_subset = list(samples_train["COSMIC_ID"].unique())

    cell_line_data_df = Dataset.standardize_data(cell_line_data_original_df, cols_subset=cols_subset,
                                                 rows_subset=rows_subset)
    # Drug data
    rows_subset = list(samples_train["DRUG_ID"].unique())
    drug_data_df = Dataset.standardize_data(drug_data_original_df, rows_subset=rows_subset)

    # Resources to request by Tune (restart Ray for every experiment)
    ray.shutdown()
    ray.init(num_cpus=2, num_gpus=1)

    time.sleep(5)

    ################################################
    # RUN THE MAIN TUNE ANALYSIS #
    ################################################
    experiment_name = "Experiment " + str(experiment)
    # Randomized sampling
    analysis = tune.run(trainable_autoencoders,
                        name=experiment_name,
                        config=search_space,
                        num_samples=num_tuning_samples,
                        local_dir=global_experiment_dir,
                        resources_per_trial={"cpu": 1, "gpu": 0.48},
                        max_failures=3,
                        verbose=1,
                        stop={"early_stopping_stop": True})

    # Save summary dataframe for this analysis
    full_df = analysis.dataframe()
    full_df.to_csv(analysis._experiment_dir + "/" + "analysis_tuning_results.csv", index=False)

    # Merge training and validation samples for the final retraining
    samples_train_and_val = pd.concat([samples_train, samples_val], axis=0)

    # Extract best parameter combination for this tuning analysis
    best_config = analysis.get_best_config(metric=validation_metric, mode=validation_mode)

    # Save the config that is actually retrained below (selected with
    # validation_metric / validation_mode), one "name: value" line per parameter.
    with open(analysis._experiment_dir + "/best_config.txt", "w") as f:
        for param_name, param_value in best_config.items():
            f.write(str(param_name) + ": " + str(param_value) + "\n")

    # Setup seed
    torch.manual_seed(split_seeds[experiment - 1])

    # Setup the network
    drug_autoencoder = DeepAutoencoderOneHiddenLayer(fixed_model_specs["drug_dim"],
                                                     fixed_model_specs["drug_hidden_dim"], fixed_model_specs["code_dim"],
                                                     activation_func=fixed_model_specs["activation_func"],
                                                     code_activation=fixed_model_specs["code_activation"],
                                                     dropout=fixed_model_specs["auto_dropout"])

    cell_line_autoencoder = DeepAutoencoderOneHiddenLayer(fixed_model_specs["cell_line_dim"],
                                                          fixed_model_specs["cell_line_hidden_dim"], fixed_model_specs["code_dim"],
                                                          activation_func=fixed_model_specs["activation_func"],
                                                          code_activation=fixed_model_specs["code_activation"],
                                                          dropout=fixed_model_specs["auto_dropout"])
    # Forward network operating on the concatenated drug and cell line codes
    net = ForwardNetworkTwoHiddenLayers(2 * fixed_model_specs["code_dim"],
                                        fixed_model_specs["forward_net_hidden_dim1"],
                                        fixed_model_specs["forward_net_hidden_dim2"],
                                        dropout_rate=best_config["forward_net_drop_rate"],
                                        out_activation=fixed_model_specs["forward_net_out_act"])
    # Make the model together
    network = RecSystemCodeConcatenation(drug_autoencoder, cell_line_autoencoder, forward_network=net,
                                         code_interactions=fixed_model_specs["code_interactions"])
    print(type(network))
    # Number of mini-batches per epoch (ceil division). An exact multiple of the
    # batch size must not produce a trailing empty batch, whose losses would be
    # computed on empty tensors and turn the epoch averages into NaN.
    no_batches = (samples_train_and_val.shape[0] + fixed_training_specs["batch_size"] - 1) // fixed_training_specs["batch_size"]
    print("No. batches:", no_batches)

    # Establish the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # Move the network into device
    network.to(device)
    # Training the model
    network.train()

    # Early stopping object; not consulted below, since the number of epochs is
    # taken from the best tuning trial instead.
    early_stopping = EarlyStopping(patience=fixed_training_specs["early_stopping_patience"], delta=0.0)

    # Set optimizer and criterions
    optimizer = torch.optim.Adam(network.parameters(), lr=best_config["learning_rate"])
    if fixed_training_specs["weigh_by_true_auc"]:
        # Per-sample losses, weighted by the true response below
        output_criterion = nn.MSELoss(reduction="none")
    else:
        output_criterion = nn.MSELoss()
    reconstruction_criterion = nn.MSELoss()

    results = {"epoch": [],
               "train_loss": [],
               "train_y_loss": [],
               "drug_independence_train_loss": [],
               "cl_independence_train_loss": [],
               "train_rmse": [],
               "train_corr": [],
               "train_median_rmse": [],
               "train_median_corr": [],
               "train_drug_rec_corr": [],
               "train_cl_rec_corr": [],
               "test_rmse": [],
               "test_corr": [],
               "test_median_rmse": [],
               "test_median_corr": [],
               "test_drug_rec_corr": [],
               "test_cl_rec_corr": []}

    # Establish number of epochs from the best trial of the tuning run:
    # its last reported training_iteration minus its early-stopping counter.
    best_trials_logdir = analysis.get_best_logdir(metric=validation_metric, mode=validation_mode)
    df = full_df[full_df.logdir == best_trials_logdir]
    train_iters = df["training_iteration"].iloc[0]
    counter = df["early_stopping_counter"].iloc[0]

    # NOTE(review): this runs (train_iters - counter + 1) epochs, i.e. one more
    # than train_iters - counter; confirm the extra epoch is intended.
    for epoch in range(1, train_iters - counter + 2):
        # Re-enter training mode every epoch: the evaluation at the end of the
        # previous epoch goes through best_model.predict, which may switch the
        # network to eval mode and silently disable dropout for later epochs.
        network.train()
        # Shuffle the data
        if fixed_training_specs["shuffle_train_and_val"]:
            samples_train_and_val = samples_train_and_val.sample(frac=1., random_state=11)
        # Accumulators for the training losses over the epoch's batches
        epoch_main_losses = 0
        epoch_y_losses = 0
        epoch_drug_indep_losses = 0
        epoch_cl_indep_losses = 0
        t0 = time.time()
        for batch in range(no_batches):
            # Separate response variable batch; iloc clips at the end, so the
            # last batch simply takes the remaining rows.
            samples_batch = samples_train_and_val.iloc[batch * fixed_training_specs["batch_size"]:(batch + 1) * fixed_training_specs["batch_size"]]

            # Extract output variable batch as a column vector
            y_batch = samples_batch[global_response_metric].values.reshape(-1, 1)

            # Extract cell lines IDs for which data shall be extracted
            cl_ids = samples_batch["COSMIC_ID"].values
            # Extract corresponding cell line data
            cell_line_input_batch = cell_line_data_df.loc[cl_ids].values

            # Extract drug IDs for which data shall be extracted
            drug_ids = samples_batch["DRUG_ID"].values
            # Extract corresponding drug data
            drug_input_batch = drug_data_df.loc[drug_ids].values

            # Augment data if specified
            if fixed_training_specs["augment_data"]:
                drug_input_batch, cell_line_input_batch, y_batch = _augment_batch(
                    drug_input_batch, cell_line_input_batch, y_batch,
                    fixed_training_specs, cell_line_exp_idx)

            # Clear gradient buffers because we don't want to accumulate gradients
            optimizer.zero_grad()

            # Put data batches into torch and device
            cell_line_input_batch = torch.from_numpy(cell_line_input_batch).to(device)
            drug_input_batch = torch.from_numpy(drug_input_batch).to(device)
            y_batch = torch.from_numpy(y_batch).to(device)

            # Perform forward pass
            batch_output, batch_drug_reconstruction, batch_cl_reconstruction = network(
                drug_input_batch.float(), cell_line_input_batch.float())

            # Compute the loss for this batch, including the drug and cell line reconstruction losses
            if fixed_training_specs["weigh_by_true_auc"]:
                # Variant with weighting the samples by their true response
                output_loss = output_criterion(batch_output, y_batch.float())
                output_loss = torch.mul(output_loss, y_batch.float())
                output_loss = output_loss.mean()
            else:
                # Variant with normal loss
                output_loss = output_criterion(batch_output, y_batch.float())
            # Reconstruction losses
            drug_reconstruction_loss = reconstruction_criterion(batch_drug_reconstruction, drug_input_batch.float())
            cl_reconstruction_loss = reconstruction_criterion(batch_cl_reconstruction, cell_line_input_batch.float())

            # Compute independence loss
            # Covariance matrices of the codes
            drug_codes_batch = network.drug_autoencoder.encoder(drug_input_batch.float())
            cl_codes_batch = network.cell_line_autoencoder.encoder(cell_line_input_batch.float())
            drug_cov = ModelWithAutoencoders.covariance_matrix_torch(drug_codes_batch)
            cl_cov = ModelWithAutoencoders.covariance_matrix_torch(cl_codes_batch)

            # Dependence losses: sum of squared off-diagonal covariance entries
            drug_independence_loss = (drug_cov * drug_cov).sum() - torch.trace(drug_cov * drug_cov)
            cl_independence_loss = (cl_cov * cl_cov).sum() - torch.trace(cl_cov * cl_cov)

            # Combine the losses in the final cost function
            loss = fixed_training_specs["y_loss_weight"] * output_loss + \
                fixed_training_specs["reconstruction_term_drug"] * drug_reconstruction_loss + \
                fixed_training_specs["reconstruction_term_cl"] * cl_reconstruction_loss + \
                fixed_training_specs["independence_term"] * drug_independence_loss + fixed_training_specs["independence_term"] * cl_independence_loss

            # Get the gradients w.r.t. the parameters
            loss.backward()
            # Update the parameters
            optimizer.step()

            epoch_main_losses += loss.item()
            epoch_y_losses += output_loss.item()
            epoch_drug_indep_losses += drug_independence_loss.item()
            epoch_cl_indep_losses += cl_independence_loss.item()

        print("Epoch training time:", time.time() - t0)
        f = lambda x: np.round(x, 3)
        print("Epoch: {}, main loss: {}, output loss: {} drug independence loss: {}, cell line independence loss: {}".format(
            epoch, f(epoch_main_losses / no_batches), f(epoch_y_losses / no_batches),
            f(epoch_drug_indep_losses / no_batches), f(epoch_cl_indep_losses / no_batches)))

        # Update metrics for learning curve (mean loss per batch)
        results["train_loss"].append(epoch_main_losses / no_batches)
        results["train_y_loss"].append(epoch_y_losses / no_batches)
        results["drug_independence_train_loss"].append(epoch_drug_indep_losses / no_batches)
        results["cl_independence_train_loss"].append(epoch_cl_indep_losses / no_batches)

        # Evaluate on training data
        best_model = ModelWithAutoencoders("Best model", network)
        predicted, y_true, drug_input, cl_input = best_model.predict(samples_train_and_val, cell_line_data_df, drug_data_df,
                                                                     response_metric=global_response_metric)
        preds, drug_reconstruction, cl_reconstruction = predicted
        train_rmse, train_corr = _safe_evaluate(y_true, preds.cpu().numpy().reshape(-1))
        # Drug reconstruction training error
        train_drug_rec_rmse, train_drug_rec_corr = _safe_evaluate(
            drug_input.flatten(), drug_reconstruction.cpu().numpy().flatten())
        # Cell line reconstruction training error
        train_cl_rec_rmse, train_cl_rec_corr = _safe_evaluate(
            cl_input.flatten(), cl_reconstruction.cpu().numpy().flatten())

        train_performance_df = Model.per_drug_performance_df(samples_train_and_val, preds.cpu(),
                                                             response_metric=global_response_metric)
        train_median_rmse = train_performance_df["Model RMSE"].median()
        train_median_corr = train_performance_df["Model correlation"].median()

        results["epoch"].append(epoch)
        results["train_rmse"].append(train_rmse)
        results["train_corr"].append(train_corr[0])
        results["train_median_corr"].append(train_median_corr)
        results["train_median_rmse"].append(train_median_rmse)
        results["train_drug_rec_corr"].append(train_drug_rec_corr[0])
        results["train_cl_rec_corr"].append(train_cl_rec_corr[0])

        # Evaluate on test data
        predicted, y_true, drug_input, cl_input = best_model.predict(samples_test, cell_line_data_df, drug_data_df,
                                                                     response_metric=global_response_metric)
        preds, drug_reconstruction, cl_reconstruction = predicted

        # Main test error
        test_rmse, test_corr = _safe_evaluate(y_true, preds.cpu().numpy().reshape(-1))
        # Drug reconstruction test error
        test_drug_rec_rmse, test_drug_rec_corr = _safe_evaluate(
            drug_input.flatten(), drug_reconstruction.cpu().numpy().flatten())
        # Cell line reconstruction test error
        test_cl_rec_rmse, test_cl_rec_corr = _safe_evaluate(
            cl_input.flatten(), cl_reconstruction.cpu().numpy().flatten())

        # Per-drug test evaluation
        test_performance_df = Model.per_drug_performance_df(samples_test, preds.cpu(),
                                                            response_metric=global_response_metric)
        test_median_rmse = test_performance_df["Model RMSE"].median()
        test_median_corr = test_performance_df["Model correlation"].median()

        # Per-cell line test evaluation
        per_cl_test_performance_df = Model.per_entity_performance_df(
            samples_test, preds.cpu(), entity_type="COSMIC_ID", response_metric=global_response_metric)

        results["test_rmse"].append(test_rmse)
        results["test_corr"].append(test_corr[0])
        results["test_drug_rec_corr"].append(test_drug_rec_corr[0])
        results["test_cl_rec_corr"].append(test_cl_rec_corr[0])
        results["test_median_corr"].append(test_median_corr)
        results["test_median_rmse"].append(test_median_rmse)

    # Per-epoch learning curves of the retrained model
    best_model_results = pd.DataFrame(results)

    best_model_results.to_csv(analysis._experiment_dir + "/best_model_test_results.csv", index=False)
    train_performance_df.to_csv(analysis._experiment_dir + "/best_model_per_drug_train_results.csv", index=False)
    test_performance_df.to_csv(analysis._experiment_dir + "/best_model_per_drug_test_results.csv", index=False)

    # Save trained best model
    with open(analysis._experiment_dir + "/best_trained_model.pkl", "wb") as f:
        dill.dump(best_model, f)
    # Save trained best model's state dict
    best_network = network
    torch.save(best_network.state_dict(), analysis._experiment_dir + "/best_network_state_dict.pth")

    print("*" * 50)
    print("Experiment", experiment)
    print(best_config)
    print("*" * 50)