# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

import os
import sys

In [None]:
# Put the shared '_pythoncode' folder (two levels up) at the front of sys.path.
# This has to happen before `import importhelper`; presumably that module
# lives in this folder — TODO confirm.
pythoncodepath = os.path.abspath(os.path.join('..', '..', '_pythoncode'))
sys.path = [pythoncodepath] + sys.path
import importhelper
# NOTE(review): presumably adds the sub-folders of `pythoncodepath` to sys.path — confirm in importhelper.
importhelper.addfolders2path(pythoncodepath)

In [None]:
# Project helper modules; presumably found through the sys.path entries added earlier in this notebook.
import data_utils
import plot_utils
# Apply the project's matplotlib rcParams (the settings themselves are defined in plot_utils).
plot_utils.set_rcParams()

# Load data

In [None]:
# Data folder of the 'step2b_analyse_optimized_cbcs' step, relative to the notebook's working directory.
data_folder = os.path.join('..', '..', 'step2b_analyse_optimized_cbcs', 'data')

In [None]:
# Show the (alphabetically sorted) files available in the data folder as cell output.
sorted(os.listdir(data_folder))

In [None]:
# Load the recordings of the ON and the OFF cell and the recording time axis.
# The recording containers are indexed by mode name below (e.g. 'all', 'no_HCN')
# and yield a list of recordings per mode.
# NOTE(review): the '.pkl' suffix suggests data_utils.load_var unpickles these
# files — only load files from a trusted source.
ON_rec_data_sorted = data_utils.load_var(os.path.join(data_folder, 'ON_data_sorted.pkl'))
OFF_rec_data_sorted = data_utils.load_var(os.path.join(data_folder, 'OFF_data_sorted.pkl'))
rec_time = data_utils.load_var(os.path.join(data_folder, 'rec_time.pkl'))

In [None]:
# Mapping from cell name (e.g. 'CBC5o') to its loss object; used below to compute
# iGluSnFR traces and losses.
cell2loss = data_utils.load_var(os.path.join(data_folder, 'cell2loss.pkl'))

In [None]:
# Number of recordings of the full model ('all'); used below to size the trace colormap.
n_traces = len(ON_rec_data_sorted['all'])

# Plotting functions

## Data functions

In [None]:
def extract_data(rec_data_list, rec_data_full_list, cell):
    """Collect traces and losses of one condition, absolute and relative to a reference.

    Parameters
    ----------
    rec_data_list : list
        Recordings of the condition of interest. Every entry is indexed with
        'BC Vm Soma' and 'rate BC'; the latter is averaged along axis 1.
    rec_data_full_list : list
        Recordings of the reference condition, paired element-wise (via zip)
        with ``rec_data_list``.
    cell : str
        Key into the module-level ``cell2loss`` mapping.

    Returns
    -------
    tuple
        ``(Vms, rates, iGlus, Vms_rel, rate_rel, iGlus_rel, losses, losses_rel)``;
        every element is a list with one entry per recording. The ``*_rel``
        entries are "condition minus reference".
    """

    def traces_of(recordings):
        # Somatic Vm, mean release rate and the derived iGluSnFR trace per recording.
        somatic_vms = [rec['BC Vm Soma'].values for rec in recordings]
        mean_rates = [rec['rate BC'].mean(axis=1).values for rec in recordings]
        iglu_traces = [
            cell2loss[cell].rate2best_iGluSnFR_trace(trace=mean_rate)[0]
            for mean_rate in mean_rates
        ]
        return somatic_vms, mean_rates, iglu_traces

    def losses_of(rate_list, vm_list):
        # One loss evaluation per (rate, Vm) pair.
        return [
            cell2loss[cell].calc_loss({'rate': rate, 'Vm': Vm})
            for rate, Vm in zip(rate_list, vm_list)
        ]

    def minus_reference(values, references):
        # Element-wise difference between paired traces.
        return [value - reference for value, reference in zip(values, references)]

    # Traces of the condition first, then of the reference.
    Vms, rates, iGlus = traces_of(rec_data_list)
    full_Vms, full_rates, full_iGlus = traces_of(rec_data_full_list)

    Vms_rel = minus_reference(Vms, full_Vms)
    rate_rel = minus_reference(rates, full_rates)
    iGlus_rel = minus_reference(iGlus, full_iGlus)

    losses = losses_of(rates, Vms)
    full_losses = losses_of(full_rates, full_Vms)

    # Per-key loss difference; the keys are taken from the condition's loss.
    losses_rel = []
    for cond_loss, ref_loss in zip(losses, full_losses):
        losses_rel.append({key: cond_loss[key]-ref_loss[key] for key in cond_loss.keys()})

    return Vms, rates, iGlus, Vms_rel, rate_rel, iGlus_rel, losses, losses_rel

## Plot functions

In [None]:
from matplotlib import cm
from matplotlib import pyplot as plt

# Discrete viridis colormap with n_traces + 1 levels; indexed by trace number below.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `pyplot.get_cmap` takes the same (name, lut) arguments and is still available.
trace_mapper = plt.get_cmap('viridis', n_traces+1)

In [None]:
# Preview of the colour assigned to each trace index.
fig, ax = plt.subplots(1, 1, figsize=(3, 1))
for i in range(n_traces):
    # Marker size is a number; the string '10' only worked through implicit float() coercion.
    ax.plot(i, 0, marker='o', ms=10, c=trace_mapper(i))

In [None]:
# Line style of the individual traces drawn by plot_trace_mode.
single_trace_kw = {
    'ls': '-',
    'lw': 0.1,
    'zorder': 0,
    'alpha': 1,
}

# Line style for a mean trace (not referenced in the visible part of this notebook).
mean_kw = {
    'ls': '-',
    'lw': 0.4,
    'c': 'r',
    'zorder': 10,
    'alpha': 0.6,
}

# Style for a filled area (not referenced in the visible part of this notebook).
area_kw = {
    'color': 'darkgray',
    'zorder': -10,
    'lw': 0.0,
}

def plot_trace_mode(ax, rec_time, trace_list, pltidx=None):
    """Draw all traces of ``trace_list`` on ``ax``, coloured by their position in the list.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    rec_time : array-like
        Time axis; must provide ``.size`` and accept an index array.
    trace_list : iterable
        Traces that are indexed with the same index array as ``rec_time``.
    pltidx : index array, optional
        Samples to plot. ``None`` plots every sample.
    """
    sample_idx = np.arange(0, rec_time.size, 1) if pltidx is None else pltidx

    for trace_no, trace in enumerate(trace_list):
        trace_color = trace_mapper(trace_no)
        ax.plot(rec_time[sample_idx], trace[sample_idx], color=trace_color, **single_trace_kw)

    # NOTE(review): this acts on the *current* pyplot figure (not necessarily the
    # figure owning `ax`) and runs on every call.
    plt.tight_layout()

In [None]:
# Marker style of the per-trace loss ticks.
loss_lw = {
    'marker': '_',
    'ls': 'None',
    'markeredgewidth': 1,
    'zorder': -5,
    'markersize': 3,
    'clip_on': False,
}
# Style of the vertical line spanning the min..max loss.
loss_line_kw = {
    'solid_capstyle': 'butt',
    'marker': 'None',
    'c': 'k',
    'zorder': -10,
    'clip_on': False,
    'lw': 0.8,
}
# Marker style of the mean loss.
loss_mean_kw = {
    'marker': '_',
    'ls': '-',
    'alpha': 0.8,
    'zorder': 0,
    'markersize': 6,
    'c': 'red',
    'clip_on': False,
}

def plot_loss_mode(ax, losses):
    """Plot the 'total' (x=0) and 'iGluSnFR' (x=1) losses of all traces on ``ax``.

    For each of the two losses, one marker per trace is drawn (coloured by trace
    index), plus a vertical line from the minimum to the maximum and a marker at
    the mean.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    losses : iterable
        One loss container per trace; each is indexed with 'total' and 'iGluSnFR'.
    """
    values_by_name = {}
    for xpos, loss_name in enumerate(('total', 'iGluSnFR')):
        values = np.array([loss_i[loss_name] for loss_i in losses])
        for trace_no, value in enumerate(values):
            ax.plot(xpos, value, **loss_lw, color=trace_mapper(trace_no))
        ax.plot([xpos, xpos], [np.min(values), np.max(values)], **loss_line_kw)
        ax.plot(xpos, np.mean(values), **loss_mean_kw)
        values_by_name[loss_name] = values

    tot_losses = values_by_name['total']
    iGlu_losses = values_by_name['iGluSnFR']

    ax.set_xlim(-0.3, 1.3)

    # NOTE(review): the y-limits below are derived from the total losses only.
    all_positive = np.all(tot_losses>0) and np.all(iGlu_losses > 0)
    if not all_positive:
        # Mixed-sign (or zero) values: symmetric y-range of at least +-0.001,
        # x tick marks hidden.
        ax.tick_params(axis='x', length=0.0)
        absmax = np.max([np.max(np.abs(tot_losses)), 0.001])
        ax.set_ylim(-absmax, absmax)
    else:
        # Strictly positive values: y-range starts at zero and spans at least 0.001.
        ax.set_ylim(0, np.max([np.max(tot_losses), 0.001]))
        if np.max(np.abs(tot_losses)) <= 0.002:
            ax.set_yticks([0, 0.001])

    # Anchor the x-axis line at y = 0.
    ax.spines['bottom'].set_position('zero')

## Labels

In [None]:
# Row labels (mathtext) for each removed-channel mode.
mode_renaming = dict(
    no_HCN=r"w/o $HCN$",
    no_Kv=r"w/o $K_v$",
    no_Kir=r"w/o $K_{ir}$",
    no_Na=r"w/o $Na_{V}$",
    no_T_at=r"w/o $Ca_{T}\,@\,AT$",
    no_L_at=r"w/o $Ca_{L}\,@\,AT$",
    no_somaCa=r"w/o $Ca\,@\,S$",
    passive="passive",
)

In [None]:
# y-axis labels; empty strings leave an axis unlabelled.
# (Not referenced in the visible part of this notebook.)
ylabels = [r'V$_m$ (mV)', r'', r'iGluSnFR', r'']

In [None]:
# Column layout of the figure grid, left to right.
cols = ['mode', 'Vm', 'dVm', 'pad', 'iGlu', 'diGlu', 'pad', 'dloss']

# Relative width per column: label column 1, padding 0, loss columns 2,
# every other (trace) column 4.
width_ratios = []
for col in cols:
    wr = 4
    if col == 'mode':
        wr = 1
    elif col == 'pad':
        wr = 0.0
    elif 'loss' in col:
        wr = 2
    width_ratios.append(wr)

In [None]:
# Titles shown above the first row, keyed by column name.
col2title = dict(
    Vm=r'V$_m$',
    dVm=r'$\Delta$ V$_m$',
    iGlu=r'iGluSnFR',
    diGlu=r'$\Delta$ iGluSnFR',
    loss='Discrepancy',
    dloss=r'$\Delta$ Discrepancy',
)

In [None]:
# y-axis label per column; an empty string leaves the column unlabelled.
col2ylabel = dict(
    Vm=r'V$_m$ (mV)',
    dVm='',
    iGlu=r'iGluSnFR',
    diGlu='',
    dloss='Discrepancy',
)

## Plot

In [None]:
# Switch between the ON cell (True) and the OFF cell (False).
plot_ON = True

if not plot_ON:
    # OFF cell: additionally shows the modes 'no_T_at' and 'no_L_at'.
    cell, cell_type = 'CBC3a', 'OFF'
    rec_data_sorted = OFF_rec_data_sorted
    filename = 'OFF_removed_channels_traces'

    plot_mode_order = [
        'no_HCN', 'no_Kv', 'no_Kir', 'no_Na', 'no_somaCa',
        'no_T_at', 'no_L_at', 'passive',
    ]

else:
    # ON cell.
    cell, cell_type = 'CBC5o', 'ON'
    rec_data_sorted = ON_rec_data_sorted
    filename = 'ON_removed_channels_traces'

    plot_mode_order = [
        'no_HCN', 'no_Kv', 'no_Kir', 'no_Na', 'no_somaCa', 'passive',
    ]

In [None]:
from matplotlib import ticker

# Figure grid: one row per removed-channel mode, one column per entry of `cols`.
sbnx = len(cols)
sbny = len(plot_mode_order)

fig, axs = plt.subplots(
    nrows=sbny, ncols=sbnx, figsize=(5.6, 0.8*sbny+0.4), sharex='col', sharey=False,
    gridspec_kw=dict(width_ratios=width_ratios), squeeze=False,
)

sns.despine()

# Time axis of the iGluSnFR traces, taken from the loss object of the selected
# `cell`. The original always used 'CBC5o', also when the OFF cell was plotted.
# NOTE(review): the "+ 1" offset is kept unchanged; its reason is not visible here.
iGlu_time = cell2loss[cell].target_time + 1

# The Vm traces are drawn with every 10th sample only.
trace_pltidx = np.arange(0, rec_time.size, 10)

for ax_row, mode in zip(axs, plot_mode_order):

    rec_data_list = rec_data_sorted[mode]

    Vms, rates, iGlus, Vms_rel, rate_rel, iGlus_rel, losses, losses_rel =\
        extract_data(rec_data_list=rec_data_list, rec_data_full_list=rec_data_sorted['all'], cell=cell)

    # cols.index(...) gives the first column with that name, i.e. the axes to draw on.
    if 'Vm' in cols:
        # Scaled by 1e3 for the 'mV' axis label (assumes Vm is stored in V — TODO confirm).
        plot_trace_mode(ax_row[cols.index('Vm')], rec_time, [1e3*Vm for Vm in Vms],
                        pltidx=trace_pltidx)
    if 'dVm' in cols:
        plot_trace_mode(ax_row[cols.index('dVm')], rec_time, [1e3*Vm for Vm in Vms_rel],
                        pltidx=trace_pltidx)
    if 'iGlu' in cols:
        iGlu_ax = ax_row[cols.index('iGlu')]
        plot_trace_mode(iGlu_ax, iGlu_time, iGlus)
        iGlu_ax.set_yticks([0,1])
    if 'diGlu' in cols:
        plot_trace_mode(ax_row[cols.index('diGlu')], iGlu_time, iGlus_rel)
    if 'loss' in cols:
        plot_loss_mode(ax_row[cols.index('loss')], losses)
    if 'dloss' in cols:
        plot_loss_mode(ax_row[cols.index('dloss')], losses_rel)

# Column titles on the first row.
for ax, col in zip(axs[0,:], cols):
    if col in col2title:
        ax.set_title(col2title[col])

# x-axis decoration on the last row.
# NOTE(review): the label says 'ms' while the export cell names the same time
# axis 'Time/s' — confirm the unit.
for ax, col in zip(axs[-1,:], cols):
    if col not in ['pad', 'mode', 'loss', 'dloss']:
        ax.set_xlabel('Time (ms)')
        ax.set_xticks([0,30])
    elif 'loss' in col:
        ax.set_xticks([0,1])
        ax.set_xticklabels(['total', 'iGluSnFR'], rotation=60)

# Per-column decoration: y-labels, hidden padding columns, and row labels in
# the 'mode' column.
for ax_col, col in zip(axs.T, cols):
    if col in col2ylabel:
        plot_utils.set_labs(axs=ax_col, ylabs=col2ylabel[col])
        fig.align_ylabels(ax_col)

    elif col == 'pad':
        for ax in ax_col:
            ax.axis('off')

    elif col == 'mode':
        for ax, mode in zip(ax_col, plot_mode_order):
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.set_ylabel(mode_renaming[mode], rotation=0, ha='right', va='center')

plt.tight_layout(w_pad=-1.3, h_pad=0.3)

# After tight_layout, shift the iGluSnFR panels slightly to the left. The column
# indices are derived from `cols` (they were hard-coded as [4, 5]).
iGlu_col_idxs = [col_idx for col_idx, col_name in enumerate(cols) if col_name in ('iGlu', 'diGlu')]
for ax in axs[:, iGlu_col_idxs].flatten():
    box = np.array(ax.get_position().bounds)
    box[0] -= 0.02
    ax.set_position(box)

plt.savefig('../_figures_apx/figapx06_' + filename + '.pdf')

# Export 

In [None]:
# Export the plotted traces as CSV source data: one table with release rates and
# somatic membrane potentials, one with the iGluSnFR traces.
# The columns are collected in dicts and turned into DataFrames once, instead of
# inserting them one by one into an empty DataFrame (which fragments the frame
# and can trigger pandas' PerformanceWarning).

# Rates and Vm are exported with every 10th sample only.
tidx = np.arange(0,rec_time.size,10)

rates_columns = {'Time/s': rec_time[tidx]}
iGlus_columns = {'Time/s': iGlu_time}

for mode in plot_mode_order:

    rec_data_list = rec_data_sorted[mode]

    Vms, rates, iGlus, Vms_rel, rate_rel, iGlus_rel, losses, losses_rel =\
        extract_data(rec_data_list=rec_data_list, rec_data_full_list=rec_data_sorted['all'], cell=cell)

    for cell_no, (Vm, rate, iGlu) in enumerate(zip(Vms, rates, iGlus), start=1):
        prefix = 'Cell' + str(cell_no)
        rates_columns[prefix + ' Release rate/(ves./s): ' + mode] = rate[tidx]
        rates_columns[prefix + ' Somatic membrane potential/mV: ' + mode] = Vm[tidx]*1e3
        iGlus_columns[prefix + ' iGluSnFR: ' + mode] = iGlu

# Full-model traces: reconstructed from the results of the LAST mode processed
# above, as "mode trace - (mode trace - full-model trace)".
full_model_items = zip(Vms, rates, iGlus, Vms_rel, rate_rel, iGlus_rel)
for cell_no, (Vm, rate, iGlu, Vm_rel, rate_rel_i, iGlu_rel) in enumerate(full_model_items, start=1):
    prefix = 'Cell' + str(cell_no)
    rates_columns[prefix + ' Release rate/(ves./s): ' + 'full model'] =\
        (rate[tidx]-rate_rel_i[tidx])
    rates_columns[prefix + ' Somatic membrane potential/mV: ' + 'full model'] =\
        (Vm[tidx]-Vm_rel[tidx])*1e3
    iGlus_columns[prefix + ' iGluSnFR: ' + 'full model'] = (iGlu-iGlu_rel)

rates_rm_ch_exdf = pd.DataFrame(rates_columns)
iGlus_rm_ch_exdf = pd.DataFrame(iGlus_columns)

# Make sure the target folder exists before writing.
os.makedirs('source_data/CBC', exist_ok=True)

rates_rm_ch_exdf.to_csv('source_data/CBC/' + cell_type + '_rates_and_Vms_with_removed_ion_channels.csv')
iGlus_rm_ch_exdf.to_csv('source_data/CBC/' + cell_type + '_iGluSnFR_with_removed_ion_channels.csv')

In [None]:
# Sanity check: re-read the exported rates/Vm table (first CSV column is the index)
# and plot all columns of 'Cell1' whose name contains 'rate' against time.
rates_rm_ch_exdf = pd.read_csv('source_data/CBC/' + cell_type + '_rates_and_Vms_with_removed_ion_channels.csv', index_col=0)
rates_rm_ch_exdf.plot(x='Time/s', y=[col for col in rates_rm_ch_exdf.columns if 'rate' in col and 'Cell1' in col],
                      figsize=(12,2))

In [None]:
# Show the column names of the re-read rates/Vm table as cell output.
rates_rm_ch_exdf.columns

In [None]:
# Sanity check: plot all columns of 'Cell4' whose name contains 'mV' (membrane potential) against time.
rates_rm_ch_exdf.plot(x='Time/s', y=[col for col in rates_rm_ch_exdf.columns if 'mV' in col and 'Cell4' in col],
                      figsize=(12,2))

In [None]:
# Sanity check: re-read the exported iGluSnFR table (first CSV column is the index)
# and plot all 'iGluSnFR' columns of 'Cell1' against time.
iGlus_rm_ch_exdf = pd.read_csv('source_data/CBC/' + cell_type + '_iGluSnFR_with_removed_ion_channels.csv', index_col=0)
iGlus_rm_ch_exdf.plot(x='Time/s', y=[col for col in iGlus_rm_ch_exdf.columns if 'iGluSnFR' in col and 'Cell1' in col],
                      figsize=(12,2))

In [None]:
# Show the column names of the re-read iGluSnFR table as cell output.
list(iGlus_rm_ch_exdf.columns)