# Imports

In [None]:
import pandas as pd
import numpy as np
# import os

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

from numpy.random import normal, uniform, shuffle
# import random

# from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import TimeSeriesSplit, GridSearchCV, KFold
from sklearn.metrics import mean_squared_error, r2_score, make_scorer, auc
from sklearn.ensemble import RandomForestRegressor

import dill
# import datetime

import itertools

In [None]:
%matplotlib inline

In [None]:
# SEED = 73 # random seed

In [None]:
# Restore every variable of a previously dumped interpreter session from the cache file.
# NOTE(review): this runs before any data is read, so later cells can silently rely on
# restored (stale) variables; presumably it also fails when the cache file does not exist
# yet — confirm whether this cell is meant to run on a fresh checkout.
dill.load_session('cache/Mixed_model_session.db')

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# data_path = 'drive/My Drive/Colab Notebooks/NIR/data/'

In [None]:
# Directory prefix of the prepared CSV inputs read by the loading cells.
data_path = 'data/prepared/'

In [None]:
# Directory prefix under which the plotting helpers save their figures.
results_path = 'results LD+SD/mixed_model/'

# Data reading

## Reading

### Synth data for model building

In [None]:
# Feature tables of the synthetic data set (train / test split).
X_train, X_test = (
    pd.read_csv(data_path + file_name) for file_name in ('X_train.csv', 'X_test.csv')
)

# Matching target tables; note the lower-case 'y' in the file names.
Y_train, Y_test = (
    pd.read_csv(data_path + file_name) for file_name in ('y_train.csv', 'y_test.csv')
)

### Real data for analysis

In [None]:
conditions = ['LD', 'SD']

In [None]:
X_real = {}
Y_real = {}
for cond in conditions:
    X_real[cond] = pd.read_csv(data_path + 'real_X_' + cond + '.csv')
    Y_real[cond] = pd.read_csv(data_path + 'real_Y_' + cond + '.csv')

In [None]:
# Display the LD feature table.
X_real['LD']

In [None]:
# Give every target row a 'days' label: the feature table's day values shifted up by one
# row, with a hard-coded final day appended (40 for LD, 34 for SD).
# NOTE(review): presumably each Y row holds the expression at the *next* time point and
# 40 / 34 are the last sampling days — confirm against the data preparation step.
Y_real['LD']['days'] = X_real['LD'].days.tolist()[1:] + [40]
Y_real['SD']['days'] = X_real['SD'].days.tolist()[1:] + [34]

## Check the data

In [None]:
# Last rows of the training features.
X_train.tail()

In [None]:
# How many training rows exist for each value of 'days'.
X_train.days.value_counts()

In [None]:
# Last rows of the training targets.
Y_train.tail()

In [None]:
days = list(set(X_test.days.values))
days.append(40)
days.sort()
days

In [None]:
X = X_train.drop('days', axis=1)
Y = Y_train.drop('days', axis=1)

In [None]:
X_test.tail()

# Random forest mixed model

##  Define regulators for each target

In [None]:
genes = X_test.drop('days', axis = 1).columns.tolist()
targets = [name for name in genes if not name.startswith('FT')]
FTs = sorted(list(set(genes) - set(targets)))

In [None]:
regulators = {
    'AP1': genes[1:],
    'FD': ['LFY'],
    'LFY': genes.copy(),
    'TFL1a': ['AP1'],
    'TFL1c': ['AP1']
}

regulators['LFY'].remove('LFY')

In [None]:
regulators

In [None]:
regulators["LFY"]

In [None]:
genes

## Model building

In [None]:
def GridS(model, grid, cv, **kwargs):
    """Build an (unfitted) GridSearchCV for ``model`` over ``grid``.

    Defaults: all CPU cores, scoring by negative MSE and R2, refit on the
    best negative-MSE candidate, verbose progress output.

    Parameters
    ----------
    model : estimator to tune.
    grid : dict mapping parameter names to the candidate values.
    cv : cross-validation splitter (or anything GridSearchCV accepts as ``cv``).
    **kwargs : extra GridSearchCV keyword arguments; they override the
        defaults above (previously they were accepted but silently ignored).

    Returns
    -------
    GridSearchCV
    """
    search_options = {
        'n_jobs': -1,
        'scoring': ['neg_mean_squared_error', 'r2'],
        'refit': 'neg_mean_squared_error',
        'cv': cv,
        'verbose': 10,
    }
    # Caller-supplied options win over the defaults.
    search_options.update(kwargs)
    return GridSearchCV(model, grid, **search_options)

In [None]:
# cv = TimeSeriesSplit(n_splits=5) # number of splits must be divider of days-1
cv = KFold(n_splits=5, shuffle=True)
grid = {
    'n_estimators': [10, 20, 50, 100, 300],
    'min_samples_leaf': [10, 20, 50],
    'max_depth': [6, 7, 8, None]
}

In [None]:
model = {}
gs = {}

for t in targets:
    model[t] = RandomForestRegressor()
    gs[t] = GridS(model[t], grid, cv)

In [None]:
for t in targets:
    gs[t].fit(X[regulators[t]].values, Y[t].values)

In [None]:
for t in targets:
    print(t , gs[t].best_params_)

Select, for each target, the model with the best cross-validation score:

In [None]:
# Refitted best estimator of every target's grid search, keyed by target name.
best_model = {t: gs[t].best_estimator_ for t in targets}

### Save the model

In [None]:
# Serialize the dict of per-target best estimators with dill (binary file).
# NOTE(review): the '.md' extension is unusual for a binary dump — confirm it is intended.
with open('cache/Mixed_model.md', 'wb') as ouf:
    dill.dump(best_model, ouf)

## Make predictions

In [None]:
def predict(data):
    """Predict every target from its regulator columns of ``data``.

    Relies on the notebook-level ``targets``, ``regulators`` and ``best_model``.

    Parameters
    ----------
    data : DataFrame containing at least the regulator columns of every target.

    Returns
    -------
    DataFrame with one column per target (in ``targets`` order) and one row
    per row of ``data``.
    """
    return pd.DataFrame({
        target: best_model[target].predict(data[regulators[target]].values)
        for target in targets
    })

In [None]:
# Per-target predictions on the training features (X has no 'days' column) and on the
# test features (predict only selects the regulator columns, so 'days' is ignored).
predictions_train = predict(X)
predictions_test = predict(X_test)

## Evaluating

In [None]:
def print_scores(true_values, predictions):
    """Return a per-target score table (despite the name, nothing is printed).

    Parameters
    ----------
    true_values, predictions : DataFrames whose columns are compared position
        by position; column i is reported under ``targets[i]``.

    Returns
    -------
    DataFrame indexed by ``targets`` with columns 'MSE', 'R2' and 'r'
    (Pearson correlation), each rounded to two decimals.
    """
    observed = true_values.values
    estimated = predictions.values

    mse_per_target = mean_squared_error(observed, estimated, multioutput='raw_values').round(2)
    r2_per_target = r2_score(observed, estimated, multioutput='raw_values').round(2)
    # Off-diagonal entry of the 2x2 correlation matrix is the Pearson r of the pair.
    pearson_per_target = [
        np.corrcoef(observed.transpose()[col], estimated.transpose()[col]).round(2)[0, 1]
        for col in range(len(targets))
    ]

    return pd.DataFrame(
        {'MSE': mse_per_target, 'R2': r2_per_target, 'r': pearson_per_target},
        index=targets,
    )

In [None]:
def comparsion_plot(true_values, predictions, data_type = 'train', current_target=targets[0]):
    """Scatter predictions against true values, save the figure and show it.

    A red identity line spanning the range of ``true_values`` is drawn for
    reference. The figure is written to
    ``results_path + current_target + '_' + data_type + '.png'``.

    Parameters
    ----------
    true_values, predictions : paired sequences of equal length.
    data_type : label used in the title and the file name.
    current_target : gene name used in the axis labels and the file name
        (the default is bound to ``targets[0]`` when the function is defined).
    """
    plt.scatter(true_values, predictions)

    # Identity line y = x over the observed range.
    diagonal = [min(true_values), max(true_values)]
    plt.plot(diagonal, [min(true_values), max(true_values)], 'r')

    label_suffix = ' (' + current_target + ')'
    plt.xlabel('True Values' + label_suffix)
    plt.ylabel('Predictions' + label_suffix)

    plt.axis('equal')
    plt.axis('square')
    plt.title('Predictions on ' + data_type)

    output_file = results_path + current_target + '_' + data_type + '.png'
    plt.savefig(output_file, bbox_inches='tight', dpi=300)
    plt.show()

Scores on train:  
(r - Pearson correlation)

In [None]:
# Score table for the training data.
print_scores(Y, predictions_train)

Scores on test:

In [None]:
# Score table for the test data; 'days' is dropped so only target columns are compared.
print_scores(Y_test.drop('days', axis=1), predictions_test)

In [None]:
for i in range(5):
    comparsion_plot(Y.iloc[:, i], predictions_train.iloc[:, i], data_type='train', current_target=targets[i])

In [None]:
for i in range(5):
    comparsion_plot(Y_test.drop('days', axis=1).iloc[:, i], predictions_test.iloc[:, i], data_type='test', current_target=targets[i])

# Real data

In [None]:
# Show where the figures below will be written.
results_path

In [None]:
def time_long_plot(true_values, predicted, condition = 'LD', predict_method = 'static'):
    """Plot true and predicted expression over time, one figure per target.

    Each figure is saved to
    ``results_path + '<predict_method>/<target>_<condition>_pred_on_real.png'``
    and then shown.

    Parameters
    ----------
    true_values : DataFrame with a 'days' column; column i (by position) is
        plotted as the measurements of ``targets[i]``.
    predicted : DataFrame whose column i holds the predictions for ``targets[i]``.
    condition : condition label used in the title and the file name.
    predict_method : label used in the title and as the output sub-directory.
    """
    days_axis = true_values['days']

    # Iterate over the targets themselves rather than a hard-coded count of 5.
    for i, target in enumerate(targets):
        plt.plot(days_axis, true_values.iloc[:, i].values, 'o')
        plt.plot(days_axis, predicted.iloc[:, i].values, '-')
        plt.title(target + ' ' + condition + f' ({predict_method} prediction method)')
        plt.xlabel('days')
        plt.ylabel('rltv expr lvl')
        plt.legend(['true values', 'predictions'])
        plt.savefig(results_path + f'{predict_method}/{target}_{condition}_pred_on_real.png',
                    bbox_inches='tight', dpi=300)
        plt.show()

## Static

In [None]:
# Static predictions: every row is predicted from the measured regulator values.
predictions_real = {name: predict(X_real[name]) for name in ('LD', 'SD')}

In [None]:
for cond in conditions:
    for i in range(5):
        comparsion_plot(Y_real[cond].iloc[:, i], predictions_real[cond].iloc[:, i], data_type='real ' + cond, current_target=targets[i])

In [None]:
for cond in conditions:
    print(cond)
    print(print_scores(Y_real[cond].drop('days', axis=1), predictions_real[cond]))

In [None]:
for cond in conditions:
    time_long_plot(Y_real[cond], predictions_real[cond], cond, 'static')

## Dynamic

In [None]:
# Dynamic prediction: only the first row uses measured target values. For every later
# row the FT columns come from the measurements, while the target columns are replaced
# by the model's own predictions from the previous step.
# Note: this overwrites the static predictions stored under the same keys.
for cond in conditions:

    features = X_real[cond].drop('days', axis = 1)
    # Explicit copy so the in-place column assignment below never touches `features`.
    X_temp = features.loc[[0]].copy()

    # Collect the one-row prediction frames and concatenate once at the end
    # (DataFrame.append was removed in pandas 2.0 and grew the frame quadratically).
    step_predictions = []

    for i in range(len(features)):
        current_predictions = predict(X_temp)
        step_predictions.append(current_predictions)

        if i < len(features) - 1:
            X_temp = features.loc[[i + 1]].copy()
            X_temp[targets] = current_predictions.values

    predictions_real[cond] = pd.concat(step_predictions, ignore_index=True)

    time_long_plot(Y_real[cond], predictions_real[cond], cond, 'dynamic')

In [None]:
for cond in conditions:
    for i in range(5):
        comparsion_plot(Y_real[cond].iloc[:, i], predictions_real[cond].iloc[:, i], data_type='real_dynamic_predictions_' + cond, current_target=targets[i])

In [None]:
for cond in conditions:
    print(cond)
    print(print_scores(Y_real[cond].drop('days', axis=1), predictions_real[cond]))

# Knock out simulation

## Methods definitions

In [None]:
def time_long_plot_ko(true_values, predicted, pred_for_ko, KO_gene, condition, predict_method = 'dynamic'):
    """Plot measurements, wild-type predictions and knock-out predictions per target.

    One figure per target is saved to
    ``results_path + '<predict_method>/KOs/<target>/<target>_<condition>_with_<KO>_KO.png'``
    and then shown. ``<KO>`` is the gene name, or the gene names joined with
    '_' when several genes are knocked out together (previously the raw list
    repr ended up in the file name because the joined name was never used).

    Parameters
    ----------
    true_values : DataFrame with a 'days' column; column i (by position) holds
        the measurements of ``targets[i]``.
    predicted : DataFrame of wild-type predictions, column i for ``targets[i]``.
    pred_for_ko : DataFrame of knock-out predictions, column i for ``targets[i]``.
    KO_gene : a gene name (str) or an iterable of gene names.
    condition : condition label used in the title and the file name.
    predict_method : output sub-directory name.
    """
    x_axis = true_values['days']

    # File-name-safe label of the knocked-out gene(s); it does not depend on the target.
    if isinstance(KO_gene, str):
        KO_name = KO_gene
    else:
        KO_name = '_'.join(KO_gene)

    for i, target in enumerate(targets):
        plt.plot(x_axis, true_values.iloc[:,i].values, 'o')
        plt.plot(x_axis, predicted.iloc[:, i], '--')
        plt.plot(x_axis, pred_for_ko.iloc[:, i], '-')

        # The title keeps the original KO_gene value (list repr for several genes).
        plt.title(target + f' ({condition} with {KO_gene} KO)') # for correct title
        plt.xlabel('days')
        plt.ylabel('rltv expr lvl')
        plt.legend(['true values', 'predictions on WT', 'predictions with KO'])
        plt.savefig(results_path + f'{predict_method}/KOs/{target}/{target}_{condition}_with_{KO_name}_KO.png',
                    bbox_inches='tight', dpi=300)
        plt.show()

## KO Iterations set

In [None]:
targets_with_cond = [name + '_LD' for name in targets]
targets_with_cond = targets_with_cond + [name + '_SD' for name in targets]

In [None]:
targets_with_cond

In [None]:
genes

In [None]:
targets

In [None]:
# FTs = [name for name in genes if name.startswith('FT')]
FTs

In [None]:
FT_pairs = list(itertools.combinations(FTs, 2))
FT_pairs

In [None]:
FT_trios = list(itertools.combinations(FTs, 3))
FT_trios

In [None]:
FT_quads = list(itertools.combinations(FTs, 4))
FT_quads

In [None]:
KO_iterations = [targets, FTs, FT_pairs, FT_trios, FT_quads]

## KO simulations

In [None]:
# AUC table: one row per target/condition pair; a 'WT' column now, one column per
# knock-out added by the simulation loop. Initialised with floats because the areas
# written later are floats.
# (Two leftover debugging cells were removed here: a bare `X_ko`, which is only
# defined inside the simulation loop below, and a hard-coded `KO_iterations[4][2]` probe.)
AUC = pd.DataFrame(0.0, index=targets_with_cond, columns=['WT'])

In [None]:
# Dynamic knock-out simulation. For every knock-out set and condition, the knocked-out
# gene columns are forced to 0 and the series is predicted step by step, feeding each
# step's target predictions into the next row's inputs.
for ko_genes in KO_iterations:
    for cond in conditions:
        days_axis = Y_real[cond]['days']

        for ko_name in ko_genes:

            # Gene combinations arrive as tuples; pandas needs a list to address
            # several columns at once.
            if isinstance(ko_name, tuple):
                ko_name = list(ko_name)

            # Fresh float copy of the features, so predictions can be written into it
            # without dtype changes and without touching X_real.
            X_ko = X_real[cond].drop('days', axis = 1).astype(float)

            # One-row prediction frames, concatenated once after the loop
            # (DataFrame.append was removed in pandas 2.0).
            ko_steps = []

            for i in range(len(X_ko)):
                # Re-applied on every step: a knocked-out target would otherwise be
                # overwritten by the prediction fed forward from the previous step.
                X_ko[ko_name] = 0.0
                current_predictions = predict(X_ko.loc[[i]])
                ko_steps.append(current_predictions)

                if i < len(X_ko) - 1:
                    # Single .loc call writes into X_ko itself; the former chained
                    # `X_ko.loc[[i + 1]][targets] = ...` only modified a temporary copy,
                    # so predictions were never fed forward.
                    X_ko.loc[i + 1, targets] = current_predictions.iloc[0].values

            pred_ko = pd.concat(ko_steps, ignore_index=True)

            time_long_plot_ko(Y_real[cond], predictions_real[cond], pred_ko, KO_gene=ko_name, condition=cond)

            column = str(ko_name)

            # AUCs table:
            for i, target in enumerate(targets):
                current_target = target + '_' + cond
                AUC.loc[current_target, 'WT'] = auc(days_axis, predictions_real[cond].iloc[:, i])
                AUC.loc[current_target, column] = auc(days_axis, pred_ko.iloc[:, i])

In [None]:
# AUC table transposed: knock-outs as rows, target/condition pairs as columns.
AUC.T

In [None]:
regulators

In [None]:
# Row labels of the AUC table (target name + '_' + condition).
AUC.index

In [None]:
# Per target/condition, the AUC values that will be shown in its bar plot.
AUCs = {}

for ind in AUC.index:
    # Row labels end in '_LD' / '_SD'; stripping three characters leaves the target name.
    target_name = ind[:-3]
    if target_name in ('AP1', 'LFY'):
        # Keep every knock-out except the knock-out of the target itself.
        AUCs[ind] = AUC.T[ind].drop(target_name)
    else:
        # Keep only the wild type and the knock-outs of this target's regulators.
        AUCs[ind] = AUC.T[ind][['WT'] + list(regulators[target_name])]

# AUCs barplots

In [None]:
def AUC_barplots(data, gene):
    """Bar plot of knock-out AUCs relative to the wild type; saved and shown.

    Parameters
    ----------
    data : Series of AUC values that contains a 'WT' entry; every value is
        divided by the 'WT' value before plotting.
    gene : label used for the title (underscores shown as spaces) and for the
        file name ``results_path + 'dynamic/barplots/<gene>.png'``.
    """
    relative_auc = data / data['WT']

    plt.figure(figsize=(15, 8))

    bars = sns.barplot(x = relative_auc.keys(), y = relative_auc.values)
    plt.xticks(rotation=90)
    plt.xlabel('KO genes')
    plt.ylabel('KO/WT transcription level')
    plt.title(gene.replace('_', ' '))
    # Dashed reference line at the wild-type level (ratio 1).
    bars.axes.axhline(1, dashes=(5, 1))

    output_file = results_path + f'dynamic/barplots/{gene}.png'
    plt.savefig(output_file, bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
# One relative-AUC bar plot per target/condition pair.
for key, auc_values in AUCs.items():
    AUC_barplots(auc_values, key)

# Save and load the session

In [None]:
# Persist the whole interpreter session (all notebook variables) to the cache file.
dill.dump_session('cache/Mixed_model_session.db')

In [None]:
# Restore the session from the same cache file.
# NOTE(review): run directly after the dump above this reloads the state that was just
# saved; presumably it is meant to be executed on its own in a new kernel — confirm.
dill.load_session('cache/Mixed_model_session.db')

## Some trash

In [None]:
# import torch

In [None]:
# model_save_name = 'RF_regressor.pt'
# # path = F"/content/gdrive/My Drive/{model_save_name}" 
# model_path = data_path + model_save_name
# torch.save(total_model.state_dict(), model_path)

In [None]:
def logging(description, out):
    """Append ``'<description>: <out>'`` to results/log.txt and echo it to stdout.

    Both arguments are converted with ``str``. The log file is opened in
    append mode, so earlier entries are kept.
    """
    entry = ': '.join((str(description), str(out)))
    with open('results/log.txt', 'a') as log_file:
        log_file.write(entry + '\n')
    print(entry)

In [None]:
# def clear_log():
#     with open('results/NN/NN_log.txt', 'w') as out_file:
#         out_file.write('log file have been cleared ' + str(datetime.datetime.now().strftime('%d-%m-%Y %H:%M:%S')))