# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

In [None]:
import os
import sys

In [None]:
# Make the shared helper modules in ../../_pythoncode importable.
pythoncodepath = os.path.abspath(os.path.join('..', '..', '_pythoncode'))
# Prepend in place (rather than rebinding sys.path to a new list) and only when
# it is not already first, so re-running this cell does not keep growing sys.path.
if not sys.path or sys.path[0] != pythoncodepath:
    sys.path.insert(0, pythoncodepath)
import importhelper
# Project helper; see importhelper for which folders it registers.
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils
import interpolation_utils
import lin_trans_utils

# Define data to be plotted.

## Get final model output

In [None]:
cbc_optim_folder = os.path.join('..', '..', 'step2a_optimize_cbc', 'optim_data')

In [None]:
sorted(os.listdir(cbc_optim_folder))

In [None]:
cell2folder = {
    'OFF':  os.path.join(cbc_optim_folder, 'optimize_OFF_submission2'),
    'ON':   os.path.join(cbc_optim_folder, 'optimize_ON_submission2'),
}

In [None]:
final_model_outputs = {}
for cell, folder in cell2folder.items():
    final_model_outputs[cell] = data_utils.load_var(os.path.join(folder, 'post_data', 'final_model_output.pkl'))        

## Get other iGluSnFR traces.

In [None]:
experimental_data_folder = os.path.join('..', '..', 'ExperimentalData', 'PreprocessedData')

In [None]:
os.listdir(experimental_data_folder)

In [None]:
drug_traces_sorted = data_utils.load_var(os.path.join(
    experimental_data_folder, 'drug_traces_sorted.pkl'))
no_drug_traces_sorted = data_utils.load_var(os.path.join(
    experimental_data_folder, 'no_drug_traces_sorted.pkl'))
no_drug_traces_sorted.keys()

## Summarize data

In [None]:
all_iGluSnFR_traces = {'OFF': {}, 'ON': {}}

### CBC3a

In [None]:
all_iGluSnFR_traces['OFF']['model_output'] = pd.DataFrame({
    'Time': final_model_outputs['OFF']['Time-Target'],
    'mean': final_model_outputs['OFF']['iGlu'],
})

all_iGluSnFR_traces['OFF']['strychnine'] = pd.DataFrame({
    'Time': drug_traces_sorted['Strychnine']['BC3a']['Time'],
    'mean': drug_traces_sorted['Strychnine']['BC3a']['mean']
})

all_iGluSnFR_traces['OFF']['no_drug'] = pd.DataFrame({
    'Time': no_drug_traces_sorted['BC3a']['Time'],
    'mean': no_drug_traces_sorted['BC3a']['mean']
})

all_iGluSnFR_traces['OFF']['similar_strychnine'] = pd.DataFrame({
    'Time': drug_traces_sorted['Strychnine']['BC4']['Time'],
    'mean': drug_traces_sorted['Strychnine']['BC4']['mean']
})

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
for name, trace in all_iGluSnFR_traces['OFF'].items():
    ax.plot(trace['Time'], trace['mean']-trace['mean'].iloc[0], label=name)
plt.legend();

### CBC5o

In [None]:
all_iGluSnFR_traces['ON']['model_output'] = pd.DataFrame({
    'Time': final_model_outputs['ON']['Time-Target'],
    'mean': final_model_outputs['ON']['iGlu'],
})

all_iGluSnFR_traces['ON']['strychnine'] = pd.DataFrame({
    'Time': drug_traces_sorted['Strychnine']['BC5o']['Time'],
    'mean': drug_traces_sorted['Strychnine']['BC5o']['mean']
})

all_iGluSnFR_traces['ON']['no_drug'] = pd.DataFrame({
    'Time': no_drug_traces_sorted['BC5o']['Time'],
    'mean': no_drug_traces_sorted['BC5o']['mean']
})

all_iGluSnFR_traces['ON']['similar_strychnine'] = pd.DataFrame({
    'Time': drug_traces_sorted['Strychnine']['BC7']['Time'],
    'mean': drug_traces_sorted['Strychnine']['BC7']['mean']
})

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
for name, trace in all_iGluSnFR_traces['ON'].items():
    ax.plot(trace['Time'], trace['mean']-trace['mean'].iloc[0], label=name)
plt.legend();

## Compute loss between traces

In [None]:
def norm_target(target):
    """Return a baseline-corrected, peak-scaled copy of an iGluSnFR trace.

    If the trace starts before t = 1.0, the mean of all samples with
    ``Time <= 1.0`` is subtracted as a baseline. The resulting trace is then
    divided by its maximum value.

    Parameters
    ----------
    target : pandas.DataFrame
        Trace with 'Time' and 'mean' columns. It is not modified.

    Returns
    -------
    pandas.DataFrame
        The normalised copy.
    """
    normed = target.copy()
    time = normed['Time']
    if time.iloc[0] < 1.0:
        in_baseline = time <= 1.0
        normed['mean'] = normed['mean'] - np.mean(normed['mean'][in_baseline])
    peak = normed['mean'].max()
    normed['mean'] = normed['mean'] / peak
    return normed

In [None]:
model_output_time = all_iGluSnFR_traces['OFF']['model_output']['Time']

In [None]:
import loss_funcs
importlib.reload(loss_funcs);

def compute_iGluSnFR_loss(trace, target, plot=False):
    
    target = norm_target(target)
    
    loss = loss_funcs.LossOptimizeCell(
        target=target, rec_time=model_output_time, t_drop=0.5, loss_params='iGlu only'
    )
    
    intpol_iGluSnFR_trace = interpolation_utils.in_ex_polate(
      x_old=trace['Time'], y_old=trace['mean'], x_new=loss.target_time
    )
    
    trans_iGluSnFR_trace, iGluSnFR_loss = lin_trans_utils.best_lin_trans(
      trace=intpol_iGluSnFR_trace, target=loss.target, loss_fun=loss.compute_iGluSnFR_trace_loss
    )
    
    _, f_norm_loss = loss.rate2best_iGluSnFR_trace(trace=np.zeros(loss.target_time.size))
    iGluSnFR_loss /= f_norm_loss
    
    if plot:
        plt.figure(1,(12,1))
        plt.title("{:.3g}".format(iGluSnFR_loss))
        plt.plot(trace['Time'], trace['mean'], label='original')
        plt.plot(loss.target_time, trans_iGluSnFR_trace, label='fit')
        plt.plot(loss.target_time, loss.target, alpha=0.8, lw=1, label='target')
        plt.legend()
        plt.show()
        
    return iGluSnFR_loss

In [None]:
# Fixed seed for reproducibility (presumably the loss helpers use numpy's
# global RNG; TODO confirm).
np.random.seed(12)

# Nested dict {target key: {trace key: loss}}; keys are "<cell> <label>".
all_iGluSnFR_losses = {}
for target_cell, target_dict in all_iGluSnFR_traces.items():
    for target_label, target in target_dict.items():
        target_losses = all_iGluSnFR_losses.setdefault(f'{target_cell} {target_label}', {})
        # Score every collected trace against this target.
        for trace_cell, trace_dict in all_iGluSnFR_traces.items():
            for trace_label, trace in trace_dict.items():
                target_losses[f'{trace_cell} {trace_label}'] = compute_iGluSnFR_loss(
                    trace=trace, target=target, plot=True)

In [None]:
all_iGluSnFR_losses = pd.DataFrame(all_iGluSnFR_losses)

In [None]:
fig, ax = plt.subplots(1,1,figsize=(9,9))
sns.heatmap(all_iGluSnFR_losses,annot=True, fmt='.3f', ax=ax)
ax.axis('equal');

In [None]:
fig, ax = plt.subplots(1,1,figsize=(9,9))
sns.heatmap(all_iGluSnFR_losses,annot=True, fmt='.2f', ax=ax)
ax.axis('equal');

In [None]:
all_iGluSnFR_losses_asym = all_iGluSnFR_losses.copy()

In [None]:
all_iGluSnFR_losses_asym.iloc[:,:] = np.tril(all_iGluSnFR_losses) - np.triu(all_iGluSnFR_losses).T \
                                    + np.triu(all_iGluSnFR_losses) - np.tril(all_iGluSnFR_losses).T

In [None]:
fig, ax = plt.subplots(1,1,figsize=(9,9))
sns.heatmap(all_iGluSnFR_losses_asym, annot=True, fmt='.3f', ax=ax)
ax.axis('equal');

# Export data

In [None]:
# Write the loss matrix as CSV (6 decimals, human-readable) and as a pickle.
# Presumably make_dir creates the folder when missing; TODO confirm in data_utils.
data_utils.make_dir('source_data')
all_iGluSnFR_losses.to_csv('source_data/all_iGluSnFR_losses.csv', float_format='%.6f')
data_utils.save_var(all_iGluSnFR_losses, 'source_data/all_iGluSnFR_losses.pkl')

In [None]:
# Build the export column names and check that every trace ends on exactly the
# model-output time grid (the export keeps only those last samples).
cols = []
for target_cell, target_dict in all_iGluSnFR_traces.items():
    for target_label, target in target_dict.items():
        print(target_cell, target_label, target.shape)
        cols.append(target_cell + ' ' + target_label)
        tail_time = target['Time'].values[-model_output_time.size:]
        # np.array_equal returns False on a length mismatch instead of failing to broadcast.
        assert np.array_equal(tail_time, model_output_time.values), \
            f'{target_cell} {target_label}: time axis does not end on the model-output grid'

In [None]:
# Export every trace, normalised as for the loss computation, on the
# model-output time grid. Sizes come from the data rather than hard-coded numbers.
n_samples = model_output_time.size
trace_data_exdf = pd.DataFrame(np.nan, index=range(n_samples), columns=['Time/s'] + cols)

# Assign the raw values: assigning the Series would align on its index, which
# is inherited from its source frame and need not be 0..n_samples-1.
trace_data_exdf['Time/s'] = model_output_time.to_numpy()

for target_cell, target_dict in all_iGluSnFR_traces.items():
    for target_label, target in target_dict.items():
        print(target_cell, target_label, target.shape)
        col = target_cell + ' ' + target_label
        # Keep only the last n_samples points, which lie on model_output_time.
        trace_data_exdf[col] = norm_target(target)['mean'].values[-n_samples:]

trace_data_exdf.to_csv('source_data/compared_iGluSnFR_traces.csv', float_format='%.6f', index=False)

## Show exported data

In [None]:
trace_data_exdf = pd.read_csv('source_data/compared_iGluSnFR_traces.csv')
trace_data_exdf.head()

In [None]:
fig, axs = plt.subplots(2,1,figsize=(12,5))
trace_data_exdf.plot(x='Time/s', y=[col for col in trace_data_exdf.columns if 'OFF' in col], ax=axs[0])
trace_data_exdf.plot(x='Time/s', y=[col for col in trace_data_exdf.columns if 'ON' in col], ax=axs[1])