# Plot the discrepancy functions of the cell models

In [None]:
import importlib
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import os
import sys

In [None]:
# Make the shared helper package in ../../_pythoncode importable.
pythoncodepath = os.path.abspath(os.path.join('..', '..', '_pythoncode'))
# Mutate sys.path in place (instead of rebinding it to a new list) and drop any
# earlier copy first, so re-running this cell does not pile up duplicate entries
# while the folder still ends up with the highest import priority.
sys.path[:] = [p for p in sys.path if p != pythoncodepath]
sys.path.insert(0, pythoncodepath)
import importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils
import plot_utils

In [None]:
# Apply the project's shared plot settings (presumably matplotlib rcParams,
# judging by the helper's name — defined in plot_utils).
plot_utils.set_rcParams()

In [None]:
# Figure number taken from this notebook's folder name: characters [3:5] of the
# last component of the current working directory. os.path.basename is used
# instead of splitting on '/' so this also works with Windows path separators.
# NOTE(review): assumes a folder name like 'figNN_...'; confirm the naming scheme.
fig_num = os.path.basename(os.getcwd())[3:5]
print(fig_num)

# Get loss functions.

In [None]:
# Optimization-output folders for the cone model and for the 'cbc' models
# (presumably cone bipolar cells, i.e. the OFF/ON cells below), both two
# levels above this notebook's folder.
cone_optim_folder, cbc_optim_folder = (
    os.path.join('..', '..', step_dir, 'optim_data')
    for step_dir in ('step1a_optimize_cones', 'step2a_optimize_cbc')
)

In [None]:
# Show the contents of the cone optimization folder. The entries are sorted
# because os.listdir returns them in arbitrary order.
sorted(os.listdir(cone_optim_folder))

In [None]:
# Show the contents of the CBC optimization folder. The entries are sorted
# because os.listdir returns them in arbitrary order.
sorted(os.listdir(cbc_optim_folder))

In [None]:
# Optimization run folder of each cell model, keyed by the cell label used in
# the rest of this notebook. The key order is also the order in which the cells
# are plotted below.
cell2folder = dict(
    Cone=os.path.join(cone_optim_folder, 'optimize_cone_submission2'),
    OFF=os.path.join(cbc_optim_folder, 'optimize_OFF_submission2'),
    ON=os.path.join(cbc_optim_folder, 'optimize_ON_submission2'),
)

In [None]:
# Load each cell model's saved loss object ('loss.pkl') from its run folder.
cell_losses = {
    cell: data_utils.load_var(os.path.join(folder, 'loss.pkl'))
    for cell, folder in cell2folder.items()
}

## Summarize the loss bounds of all cells.

In [None]:
# Get bounds.
# For every loss term that defines a 'good' range, collect the 'good' and
# 'acceptable' bound lists of each cell model. Missing bounds (None) are
# replaced by NaN so NumPy's NaN-aware reductions can skip them later.
all_loss_rngs = {}

for cell, cell_loss in cell_losses.items():
    for loss_name, loss_dict in cell_loss.loss_params.items():
        if 'good' not in loss_dict:
            continue

        entry = all_loss_rngs.setdefault(loss_name, {'good': [], 'acceptable': []})

        good, acceptable = (
            [np.nan if v is None else v for v in loss_dict[key]]
            for key in ('good', 'acceptable')
        )

        entry['good'].append(good)
        entry['acceptable'].append(acceptable)

all_loss_rngs

### Derive bounds

In [None]:
# Summarize bounds.
# Collapse the per-cell 'acceptable' bounds of each loss term into one range:
# the smallest lower bound and the largest upper bound over all cells. A side
# that is NaN (originally None) for every cell stays None.
loss_rngs = {}

for loss_name, loss_bounds in all_loss_rngs.items():
    acceptable_arr = np.asarray(loss_bounds['acceptable'])
    lbs, ubs = acceptable_arr[:, 0], acceptable_arr[:, 1]

    lb = None if np.all(np.isnan(lbs)) else np.nanmin(lbs)
    ub = None if np.all(np.isnan(ubs)) else np.nanmax(ubs)

    loss_rngs[loss_name] = [lb, ub]

In [None]:
# Display the summarized [lower, upper] bound per loss term (None = unbounded side).
loss_rngs

In [None]:
# y-axis limits for the paper figure, derived from the summarized bounds as they
# are *before* the manual corrections below. The lower limit is -1 when a lower
# bound exists and 0 otherwise; the upper limit is +1 when an upper bound exists
# and 0 otherwise.
ylims = {
    loss_name: (0 if lo is None else -1, 0 if hi is None else +1)
    for loss_name, (lo, hi) in loss_rngs.items()
}

ylims

### Manually correct the bounds

In [None]:
# Manual overrides of the summarized bounds: (loss term, bound index) -> value,
# where index 0 is the lower and 1 the upper bound. The Vm values are in the
# same units as the loss bounds. They look like volts, since the plotted Vm axis
# is later rescaled by 1e3 for 'mV' labels — confirm.
manual_bounds = {
    ('rate_rest', 1): 110,
    ('rate_rel_range', 1): 110,
    ('Vm_low', 1): -0.05,
    ('Vm_low', 0): -0.12,
    ('Vm_rest', 1): -0.01,
    ('Vm_rest', 0): -0.09,
    ('Vm_high', 0): -0.04,
    ('Vm_high', 1): 0.01,
    ('Vm_rel_range', 1): 0.05,
}
for (loss_name, bound_idx), bound_value in manual_bounds.items():
    loss_rngs[loss_name][bound_idx] = bound_value

In [None]:
# Display the bounds again after the manual corrections.
loss_rngs

# Evaluate

In [None]:
# Evaluate every cell model's discrepancy (loss) function on a 1000-point grid
# spanning the corrected [lower, upper] range of each summarized loss term.
loss_evals = {}

for cell, cell_loss in cell_losses.items():
    for loss_name, loss_dict in cell_loss.loss_params.items():
        # Only loss terms with summarized bounds are evaluated.
        if loss_name not in loss_rngs:
            continue

        x_lo, x_hi = loss_rngs[loss_name][0], loss_rngs[loss_name][1]
        xvals = np.linspace(x_lo, x_hi, 1000)

        if loss_name not in loss_evals:
            # The stored x-axis is only for plotting: Vm terms are scaled by 1e3
            # (presumably V -> mV, matching the 'mV' axis labels).
            x_plot = xvals.copy()
            if 'Vm' in loss_name:
                x_plot *= 1e3
            loss_evals[loss_name] = {'xvals': x_plot}

        # The loss itself is evaluated on the unscaled grid.
        good, acceptable = loss_dict['good'], loss_dict['acceptable']
        yvals = np.empty(xvals.size)
        for idx, x in enumerate(xvals):
            yvals[idx] = cell_loss.loss_value_in_range(
                value=x, good=good, acceptable=acceptable
            )

        loss_evals[loss_name][cell] = yvals

## Make simple plot.

In [None]:
# Quick-look plot: one row per loss term, with all cell models overlaid.
# squeeze=False keeps `axs` a 2-D array even when there is only one loss term,
# so the zip and axs[0] below also work in that case.
fig, axs = plt.subplots(len(loss_evals), 1, squeeze=False,
    figsize=(5, 0.8*len(loss_evals)), sharey=True, subplot_kw=dict(ylim=(-1.1, 1.1)))
axs = axs[:, 0]
for ax, (loss_name, loss_eval) in zip(axs, loss_evals.items()):
    ax.set_title(loss_name)
    for cell in cell2folder.keys():
        # A cell model may not define every loss term; skip it instead of
        # raising KeyError.
        if cell not in loss_eval:
            continue
        ax.plot(loss_eval['xvals'], loss_eval[cell], label=cell)

axs[0].legend()
plt.tight_layout()

# Paper figure

In [None]:
# LaTeX panel title for each loss term in the paper figure.
titles = dict(
    rate_rest=r'$\delta_{Rate}^{Rest}$',
    rate_rel_range=r'$\delta_{Rate}^\Delta$',
    Vm_rest=r'$\delta_{V}^{Rest}$',
    Vm_rel_range=r'$\delta_{V}^\Delta$',
    Vm_low=r'$\delta_{V}^{min}$',
    Vm_high=r'$\delta_{V}^{max}$',
)

In [None]:
# x-axis label for each loss term: 'rate_*' terms are release rates and all
# other terms are membrane potentials.
xlabels = {
    loss_name: ('Release rate (ves./s)' if loss_name.startswith('rate_')
                else 'Membrane potential (mV)')
    for loss_name in ('rate_rest', 'rate_rel_range', 'Vm_rest',
                      'Vm_rel_range', 'Vm_low', 'Vm_high')
}

In [None]:
# Per-cell line appearance in the paper figure: color, line style and line width.
colors = dict(Cone='firebrick', OFF='steelblue', ON='darkgreen')
linestyles = dict(Cone='--', OFF='-', ON=':')
lws = dict(Cone=1.5, OFF=1.5, ON=2)

In [None]:
import seaborn as sns

# Paper figure: one panel per loss term on a 3x2 grid, with all cell models overlaid.
fig, axs = plt.subplots(3, 2, figsize=(5.6, 2.7))
panel_axs = axs.flatten()

for ax, (loss_name, loss_eval) in zip(panel_axs, loss_evals.items()):
    ax.set_title(titles[loss_name])
    for cell in cell2folder.keys():
        # A cell model may not define every loss term; skip it instead of
        # raising KeyError.
        if cell not in loss_eval:
            continue
        ax.plot(loss_eval['xvals'], loss_eval[cell], label=cell,
                lw=lws[cell], color=colors[cell], ls=linestyles[cell],
                clip_on=False)
    # ylims holds -1 / +1 per side, or 0 on a side whose bound was None before
    # the manual corrections; the ticks are those limits plus 0.
    ax.set_ylim(ylims[loss_name])
    ax.set_yticks(np.unique(list(ylims[loss_name]) + [0]))
    ax.set_ylabel('Discrepancy', verticalalignment='center', labelpad=10)
    ax.set_xlabel(xlabels[loss_name])

# Hide grid panels left empty when there are fewer loss terms than panels.
for ax in panel_axs[len(loss_evals):]:
    ax.set_visible(False)

sns.despine()

plt.tight_layout(h_pad=0.5, rect=[0, -0.04, 1, 1.04])
axs[1, 0].legend(loc='upper left', bbox_to_anchor=(0.01, 1.7),
                 frameon=False, borderpad=0.0, labelspacing=0.1)
# Build the output path portably and create the target folder if needed;
# savefig fails when the folder does not exist.
figure_dir = os.path.join('..', '_figures')
os.makedirs(figure_dir, exist_ok=True)
plt.savefig(os.path.join(figure_dir, f'fig{fig_num}_discrepancy_function.pdf'))
plt.show()