# Start to finish training and evaluation of models

### Load packages and set seed

In [None]:
import os, pickle, random, sys, torch

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim

from sklearn.metrics import auc,precision_recall_curve,roc_curve,confusion_matrix,accuracy_score,recall_score,f1_score
from sklearn.model_selection import KFold, train_test_split
from torch.autograd import Variable
from torch.nn import Linear, Conv2d, BatchNorm2d, MaxPool2d, Dropout2d, GRU, Dropout, BatchNorm1d
from torch.nn.functional import relu, elu, relu6, sigmoid, tanh, softmax
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader

# Seed the NumPy, Python and PyTorch RNGs so runs are reproducible.
# torch's generator is the one used by the nn.init weight initialisation.
np.random.seed(10)
random.seed(10)
torch.manual_seed(10)

### Define functions for loading data

In [None]:
def load_training_and_validataion_dataset(path_to_partitions,train_splits):
    """Read all partition files and split them into training and validation frames.

    Every entry of ``path_to_partitions`` is read as a header-less,
    tab-separated table with the columns ``peptide``, ``label`` and
    ``HLA_allele``. Files are read in sorted file-name order so that partition
    index ``i`` always refers to the same file; ``os.listdir`` alone returns
    entries in arbitrary order, which would make the split non-reproducible.

    Args:
        path_to_partitions: Directory holding the partition files.
        train_splits: Currently unused; the training partitions are the fixed
            list below. Kept so existing callers continue to work.

    Returns:
        Tuple ``(training_df, validation_df, training_partions,
        validation_partions)``: the two concatenated DataFrames followed by
        the partition indices that went into each of them.
    """
    # Fixed split. A random draw would be: random.sample(range(10), train_splits)
    training_partions = [9, 0, 6, 3, 4, 8, 1, 7]
    # Every index in 0..9 that is not used for training is used for validation.
    validation_partions = [i for i in range(10) if i not in training_partions]

    partitions = []
    for file_name in sorted(os.listdir(path_to_partitions)):
        path_to_file = os.path.join(path_to_partitions, file_name)
        data = pd.read_csv(path_to_file,sep="\t",names=["peptide","label","HLA_allele"])
        partitions.append(data)

    training_df = pd.concat([partitions[i] for i in training_partions])
    validation_df = pd.concat([partitions[i] for i in validation_partions])
    return training_df, validation_df,training_partions,validation_partions

def retrieve_information_from_df(data_split,entire_df):
    """Annotate ``data_split`` with per-entry information looked up in ``entire_df``.

    For each row of ``data_split`` the single row of ``entire_df`` with the same
    ``peptide`` and ``HLA_allele`` is located and its values are copied over.

    Args:
        data_split: DataFrame with at least the columns ``peptide`` and
            ``HLA_allele``. It is modified in place.
        entire_df: DataFrame with the columns ``peptide``, ``HLA_allele``,
            ``immunogenicity``, ``Response``, ``tested_subjects``,
            ``positive_subjects`` and ``binding_score``.

    Returns:
        ``data_split`` itself, with the added columns ``immunogenicity``,
        ``response``, ``test`` (the tested-subject count),
        ``positive_subjects`` and ``binding_score``.

    Raises:
        AssertionError: If a (peptide, HLA) pair does not match exactly one
            row of ``entire_df``.
    """
    immunogenicity = []
    response = []
    tested_subjects = []
    positive_subjects = []
    binding_scores = []
    for _, row in data_split.iterrows():
        peptide, HLA = row["peptide"], row['HLA_allele']
        original_entry = entire_df[(entire_df['peptide']==peptide) & (entire_df['HLA_allele'] == HLA)]
        assert len(original_entry) == 1, (
            "expected exactly one entry for peptide %r / HLA %r, found %d"
            % (peptide, HLA, len(original_entry))
        )
        # Pull scalars out with .iloc[0]; calling float()/int() directly on a
        # one-element Series is deprecated in recent pandas versions.
        immunogenicity.append(float(original_entry['immunogenicity'].iloc[0]))
        response.append(original_entry['Response'].values[0])
        tested_subjects.append(int(original_entry['tested_subjects'].iloc[0]))
        positive_subjects.append(int(original_entry['positive_subjects'].iloc[0]))
        binding_scores.append(float(original_entry['binding_score'].iloc[0]))
    data_split['immunogenicity'] = immunogenicity
    data_split['response'] = response
    data_split['test'] = tested_subjects
    data_split['positive_subjects'] = positive_subjects
    data_split['binding_score'] = binding_scores
    return data_split


def encode_peptide_aaindex(aa_seq,aaindex_PCA,row):
    """Encode an amino-acid sequence as one feature vector per residue.

    Args:
        aa_seq: Sequence string; it is upper-cased before the lookup.
        aaindex_PCA: DataFrame indexed by residue letter with one column per
            feature component.
        row: The dataset row the sequence belongs to. Only used to give
            context in the error message.

    Returns:
        ``np.ndarray`` with one row per residue, in sequence order. The
        placeholder residues ``X`` and ``-`` are encoded as all-zero vectors.

    Raises:
        ValueError: If a residue other than ``X``/``-`` is missing from the
            index of ``aaindex_PCA``.
    """
    n_components = aaindex_PCA.shape[1]
    encoded_aa_seq = []
    for aa in aa_seq.upper():
        if aa in ("X", "-"):
            # Integer zeros, matching what np.array([0, 0, ...]) would produce.
            encoded_aa_seq.append(np.zeros(n_components, dtype=int))
            continue
        try:
            encoded_aa_seq.append(aaindex_PCA.loc[aa].to_numpy())
        except KeyError as err:
            # Raise instead of calling sys.exit() so the caller gets a normal
            # exception with context rather than the interpreter shutting down.
            raise ValueError(
                "Unknown residue %r in sequence %r; offending row:\n%s"
                % (aa, aa_seq, row)
            ) from err
    return np.array(encoded_aa_seq)

def encode_dataset(df,aaindex_PCA,HLA_dict,peptide_len,padding="right"):
    """Encode peptides, HLA sequences, binding scores and labels of a dataset.

    Args:
        df: DataFrame with the columns ``peptide``, ``HLA_allele``,
            ``binding_score`` and ``positive_subjects``.
        aaindex_PCA: Residue feature table passed on to
            ``encode_peptide_aaindex``.
        HLA_dict: Mapping from allele name (with ``:`` removed) to the HLA
            sequence that gets encoded.
        peptide_len: Number of rows every encoded peptide is padded to.
            Peptides longer than this are left unchanged.
        padding: Where the zero rows go for short peptides: ``"right"`` (after
            the peptide), ``"left"`` (before it) or ``"random"`` (a random
            split between both sides).

    Returns:
        Tuple ``(encoded_peptides, encoded_hlas, encoded_binding_scores,
        encoded_labels)`` of ``float32`` arrays. A label is
        ``min(1, positive_subjects)``.

    Raises:
        ValueError: If a peptide needs padding and ``padding`` is not one of
            the supported modes.
    """
    encoded_peptides = []
    encoded_labels = []
    encoded_hlas = []
    encoded_binding_scores = []
    for _, row in df.iterrows():
        peptide = row["peptide"]
        HLA = HLA_dict[row["HLA_allele"].replace(":","")]
        encoded_peptide = encode_peptide_aaindex(peptide,aaindex_PCA,row)
        binding_score = row['binding_score']

        # Pad with as many zero rows as are actually missing, not a fixed
        # single row, so peptides of any shorter length reach peptide_len.
        n_added = peptide_len - len(encoded_peptide)
        if n_added > 0:
            if padding == "right":
                pad_rows = (0, n_added)
            elif padding == "left":
                pad_rows = (n_added, 0)
            elif padding == "random":
                # For n_added == 1 this is the same draw as random.choice([0, 1]).
                top_pad = random.choice(range(n_added + 1))
                pad_rows = (top_pad, n_added - top_pad)
            else:
                raise ValueError(
                    "padding must be 'right', 'left' or 'random', got %r" % (padding,)
                )
            encoded_peptide = np.pad(encoded_peptide, (pad_rows, (0, 0)), 'constant')

        encoded_HLA = encode_peptide_aaindex(HLA,aaindex_PCA,row)
        encoded_label = min(1,row["positive_subjects"])
        encoded_peptides.append(encoded_peptide)
        encoded_hlas.append(encoded_HLA)
        encoded_labels.append(encoded_label)
        encoded_binding_scores.append(binding_score)

    encoded_peptides = np.array(encoded_peptides).astype('float32')
    encoded_hlas = np.array(encoded_hlas).astype('float32')
    encoded_labels = np.array(encoded_labels).astype('float32')
    encoded_binding_scores = np.array(encoded_binding_scores).astype('float32')
    return encoded_peptides, encoded_hlas, encoded_binding_scores, encoded_labels


### Load and encode data

In [None]:
# Reference tables: residue features (indexed by residue letter) and the HLA
# table whose "pseudo" column is turned into an allele -> sequence mapping.
aaindex_PCA = pd.read_csv('../data/PCA_repr_aa.csv', index_col=0)
hla_database = pd.read_csv(
    '../data/formatted_hla2paratope_MHC_pseudo.dat',
    sep=' ',
    index_col=0,
)
hla_dic = hla_database.to_dict("dict")["pseudo"]

# Full dataset. Alternative files that can be loaded instead:
#   '../data/filtered_data_IEDB_4_tested_len_9_10_full_HLA_IFNg_assay.csv'
#   '../data/deep_immuno_2.csv'
entire_df = pd.read_csv(
    "../data/filtered_data_IEDB_4_tested_len_9_10_full_HLA_Multi_assay_w_binding.csv"
)

# Assign the partitions to the training and validation data
(
    training_df,
    validation_df,
    training_partions,
    validation_partions,
) = load_training_and_validataion_dataset(
    path_to_partitions="../data/multi_assay_parts",
    train_splits=8,
)

# Training frame: add the per-entry information (tested and positive subjects,
# response, binding score), then shuffle the rows with a fixed seed.
training_df_entire = (
    retrieve_information_from_df(training_df, entire_df)
    .sample(frac=1, random_state=1)
    .reset_index(drop=True)
)

# Validation frame: same annotation and shuffling.
validation_df_entire = (
    retrieve_information_from_df(validation_df, entire_df)
    .sample(frac=1, random_state=1)
    .reset_index(drop=True)
)

print("##Encoding Training data")
(
    train_peptides_encoded,
    train_HLA_encoded,
    train_binding_scores_encoded,
    train_label_encoded,
) = encode_dataset(
    training_df_entire, aaindex_PCA, hla_dic, peptide_len=10, padding="right"
)
print("##Encoding Validation data")
(
    val_peptides_encoded,
    val_HLA_encoded,
    val_binding_scores_encoded,
    val_label_encoded,
) = encode_dataset(
    validation_df_entire, aaindex_PCA, hla_dic, peptide_len=10, padding="right"
)

# Add a singleton channel axis: peptides become (N, 1, 10, 12) and HLAs
# (N, 1, 34, 12). The HLA length is 46 for the aligned representation and 34
# if not aligned.
peptide_train = train_peptides_encoded.reshape(-1, 1, 10, 12)
HLA_train = train_HLA_encoded.reshape(-1, 1, 34, 12)
binding_train = train_binding_scores_encoded
label_train = train_label_encoded

peptide_val = val_peptides_encoded.reshape(-1, 1, 10, 12)
HLA_val = val_HLA_encoded.reshape(-1, 1, 34, 12)
binding_val = val_binding_scores_encoded
label_val = val_label_encoded


### Define functions for model initialization

In [None]:
def compute_conv_dim(dim_size,kernel_size,padding,stride):
    """Return the output size of a convolution along one dimension.

    Computes ``(dim_size - kernel_size + 2 * padding) / stride + 1`` and
    truncates the result to an int.
    """
    padded_span = dim_size - kernel_size + 2 * padding
    n_positions = padded_span / stride + 1
    return int(n_positions)
# Example: input size 34, kernel size 2, no padding, stride 1 -> 33
compute_conv_dim(34,2,0,1)

def initialize_weights(m):
    """Initialise the parameters of one module; meant for ``Module.apply``.

    * ``Conv2d`` / ``Linear``: Kaiming-uniform weights (ReLU gain); the bias,
      when present, is set to zero.
    * ``BatchNorm2d``: weight set to one, bias set to zero.
    * Any other module is left untouched.
    """
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
        return

    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    # Convolutional and fully connected layers share the same scheme.
    nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0)
          

### Defining the network class?

### Create training and validation data sets

In [None]:
batch_size = 100

# Each DataLoader is materialised into a plain list of batches so that batch i
# can be fetched by index from all four inputs.
(
    peptide_train_loader,
    HLA_train_loader,
    label_train_loader,
    binding_score_train_loader,
) = [
    list(DataLoader(array, batch_size=batch_size))
    for array in (peptide_train, HLA_train, label_train, binding_train)
]

(
    peptide_val_loader,
    HLA_val_loader,
    label_val_loader,
    binding_score_val_loader,
) = [
    list(DataLoader(array, batch_size=batch_size))
    for array in (peptide_val, HLA_val, label_val, binding_val)
]

### Define functions for model performance evaluation

In [None]:
def plot_learning_curve(train_accuracies,val_accuracies):
    """Plot training (red) and validation (blue) accuracy per epoch in a new figure.

    Args:
        train_accuracies: One training accuracy per epoch.
        val_accuracies: One validation accuracy per epoch.
    """
    epoch_axis = np.arange(len(train_accuracies))
    plt.figure()
    plt.plot(
        epoch_axis, train_accuracies, 'r',
        epoch_axis, val_accuracies, 'b',
    )
    plt.legend(['Train Accucary','Validation Accuracy'])
    plt.xlabel('epochs')
    plt.ylabel('Acc')



def _predict_batches(model, loaders):
    """Run ``model`` over pre-batched inputs; return targets, argmax labels and max outputs."""
    peptide_loader, HLA_loader, label_loader, binding_score_loader = loaders
    targets = []
    predicted = []
    top_outputs = []
    for i, peptides in enumerate(peptide_loader):
        labels = label_loader[i].long().reshape(-1)
        # The binding score is passed as a column vector, one row per peptide.
        binding_scores = binding_score_loader[i].reshape(len(peptides), 1)
        outputs = model(peptides, HLA_loader[i], binding_scores)
        top_output, predicted_labels = torch.max(outputs, 1)
        # Tensor.tolist() is used instead of .numpy().tolist() because
        # .numpy() only works for tensors that live on the CPU.
        predicted += predicted_labels.tolist()
        targets += labels.tolist()
        top_outputs += top_output.tolist()
    return targets, predicted, top_outputs


def validation(model,device,valid_loaders,train_loaders):
    """Predict on the training and validation batches without tracking gradients.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: Callable module taking ``(peptides, HLA, binding_scores)`` and
            returning one row of class outputs per sample.
        device: Unused; kept for interface compatibility.
        valid_loaders: Tuple ``(peptide, HLA, label, binding_score)`` of
            indexable batch sequences for the validation data.
        train_loaders: Same layout as ``valid_loaders`` for the training data.

    Returns:
        Tuple ``(all_train_targets, all_predicted_train_labels,
        all_val_targets, all_predicted_val_labels, all_probabilities_val)`` of
        Python lists. ``all_probabilities_val`` holds the maximum model output
        per validation sample; it is a probability only if the model's output
        is already normalised.
    """
    model.eval()
    with torch.no_grad():
        all_train_targets, all_predicted_train_labels, _ = _predict_batches(model, train_loaders)
        all_val_targets, all_predicted_val_labels, all_probabilities_val = _predict_batches(model, valid_loaders)

    return all_train_targets,all_predicted_train_labels,all_val_targets,all_predicted_val_labels,all_probabilities_val


def train(model, device, epochs, train_loaders, valid_loaders):
    """Train ``model`` with Adam and cross-entropy, with early stopping.

    After every epoch the model is evaluated on the training and validation
    batches via ``validation``. Training stops early once the validation
    accuracy has dropped below the previous epoch's value for ``patience``
    (2) epochs in a row.

    Args:
        model: Module taking ``(peptides, HLA, binding_scores)``.
        device: Only forwarded to ``validation``.
        epochs: Maximum number of epochs.
        train_loaders: Tuple ``(peptide, HLA, label, binding_score)`` of
            indexable batch sequences for the training data.
        valid_loaders: Same layout for the validation data.

    Returns:
        Tuple ``(model, train_accuracies, val_accuracies,
        all_val_targets_pr_epoch, all_val_predictions_pr_epoch,
        all_val_probabilities_pr_epoch)``; all lists have one entry per
        completed epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001,weight_decay=1e-4)

    # Early-stopping state: number of consecutive epochs with falling accuracy.
    patience = 2
    consecutive_drops = 0

    # Per-epoch history
    losses = []
    train_accuracies = []
    val_accuracies = []
    all_val_targets_pr_epoch = []
    all_val_predictions_pr_epoch = []
    all_val_probabilities_pr_epoch = []

    peptide_batches, HLA_batches, label_batches, binding_batches = train_loaders
    n_batches = len(peptide_batches)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for batch_idx in range(n_batches):
            batch_peptides = peptide_batches[batch_idx]
            batch_HLA = HLA_batches[batch_idx]
            batch_labels = label_batches[batch_idx].long().reshape(-1)
            batch_binding = binding_batches[batch_idx].reshape(len(batch_peptides),1)

            # Reset the gradients, then do one optimisation step
            optimizer.zero_grad()
            batch_outputs = model(batch_peptides, batch_HLA, batch_binding)
            batch_loss = criterion(batch_outputs, batch_labels)
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()
        # Mean loss over the batches of this epoch
        losses.append(epoch_loss / n_batches)

        (
            train_targets,
            train_predictions,
            val_targets,
            val_predictions,
            val_probabilities,
        ) = validation(model, device, valid_loaders, train_loaders)

        train_accuracies.append(accuracy_score(train_targets, train_predictions))
        val_accuracies.append(accuracy_score(val_targets, val_predictions))

        # Keep the raw validation outputs for later analysis
        all_val_targets_pr_epoch.append(val_targets)
        all_val_predictions_pr_epoch.append(val_predictions)
        all_val_probabilities_pr_epoch.append(val_probabilities)

        if epoch % 10 == 0:
            print("Epoch %2i : Train Loss %f , Train acc %f, Valid acc %f" % (epoch+1, losses[-1], train_accuracies[-1], val_accuracies[-1]))

        # Early stopping: compare this epoch's validation accuracy with the
        # previous epoch's (0 for the very first epoch).
        previous_val_accuracy = val_accuracies[-2] if len(val_accuracies) >= 2 else 0
        if val_accuracies[-1] < previous_val_accuracy:
            consecutive_drops += 1
            if consecutive_drops >= patience:
                break
        else:
            consecutive_drops = 0

    return model,train_accuracies,val_accuracies,all_val_targets_pr_epoch,all_val_predictions_pr_epoch,all_val_probabilities_pr_epoch

### Train and evaluate model

In [None]:
# NOTE(review): `device` is only reported; neither the network nor the batches
# are moved to it in this cell.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('Device state:', device)

train_loaders = (peptide_train_loader, HLA_train_loader, label_train_loader, binding_score_train_loader)
val_loaders = (peptide_val_loader, HLA_val_loader, label_val_loader, binding_score_val_loader)

# NOTE(review): Net is not defined in the visible part of this notebook.
net = Net()
net.apply(initialize_weights)

# The optimizer has to be built after the network exists. Building it first
# raises a NameError on a fresh run, or binds the optimizer to the parameters
# of a previous `net` so that the new network is never updated.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001,weight_decay=1e-4)

epochs = 100
train_accuracies = []
val_accuracies = []
losses = []
all_val_targets_pr_epoch = []
all_val_predictions_pr_epoch = []
all_val_probabilities_pr_epoch = []
for epoch in range(epochs):
    net.train()
    current_loss = 0
    for train_batch_index in range(len(peptide_train_loader)):
        train_peptides = peptide_train_loader[train_batch_index]
        train_HLA = HLA_train_loader[train_batch_index]
        train_labels = label_train_loader[train_batch_index].long().reshape(-1)
        train_binding_scores = binding_score_train_loader[train_batch_index].reshape(len(train_peptides),1)
        # Zero the parameter gradients
        optimizer.zero_grad()
        outputs = net(train_peptides,train_HLA,train_binding_scores)
        loss = criterion(outputs, train_labels)
        loss.backward()
        optimizer.step()
        current_loss += loss.item()

    # Mean training loss over the batches of this epoch
    losses.append(current_loss/len(peptide_train_loader))

    # Evaluate on the training and validation batches in eval mode without
    # gradients, using the shared validation() function.
    (
        all_train_targets,
        all_predicted_train_labels,
        all_val_targets,
        all_predicted_val_labels,
        all_probabilities_val,
    ) = validation(net, device, val_loaders, train_loaders)

    # Calculating the accuracies
    train_accuracies.append(accuracy_score(all_train_targets,all_predicted_train_labels))
    val_accuracies.append(accuracy_score(all_val_targets,all_predicted_val_labels))
    # Saving the predictions for further validation
    all_val_targets_pr_epoch.append(all_val_targets)
    all_val_predictions_pr_epoch.append(all_predicted_val_labels)
    all_val_probabilities_pr_epoch.append(all_probabilities_val)
    if epoch % 10 == 0:
        print("Epoch %2i : Train Loss %f , Train acc %f, Valid acc %f" % (epoch+1, losses[-1], train_accuracies[-1], val_accuracies[-1]))
    
