# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import warnings

import os
import sys

In [None]:
# Put the shared ``pythoncode`` folder (one level up) at the front of the
# module search path so the project's helper modules can be imported here.
pythoncodepath = os.path.abspath(os.path.join('..', 'pythoncode'))
sys.path = [pythoncodepath] + sys.path
# Must come after the sys.path change above; presumably ``importhelper`` is
# located in ``pythoncode`` -- TODO confirm.
import importhelper
# presumably adds sub-folders of ``pythoncode`` to sys.path as well -- TODO confirm
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils

# Create Cell and stimulate

In [None]:
# Time range: used as x-limits of the stimulus plot and passed to both cells as ``t_rng``.
stim_t_rng = (1, 32)
# Passed to both cells as ``predur`` (presumably a pre-stimulus duration -- TODO confirm units).
predur = 10

In [None]:
# Load experimental data
# (the preprocessed stimulus table; it is later handed to both cells as ``stimulus``).
data_folder = os.path.join('..', 'step0_preprocess_iGluSnFR_data', 'data_preprocessed')
stimulus    = pd.read_csv(os.path.join(data_folder, 'Franke2017_stimulus_time_and_amp_corrected.csv'))

In [None]:
# Plot the stimulus columns against 'Time', with the x-axis limited to ``stim_t_rng``.
fig, ax = plt.subplots(1,1,figsize=(12,2),subplot_kw=dict(xlim=stim_t_rng))
stimulus.plot(x='Time', ax=ax)
plt.show()

## Select optimized cells

In [None]:
# Root folder of the optimization step whose results are loaded in this notebook.
cbc_folder = os.path.join('..', 'step2a_optimize_cbc', )

# Maps the cell label ('OFF' / 'ON') to the folder holding its optimization output.
cell2folder = {
    'OFF': os.path.join(cbc_folder, 'optim_data', 'optimize_OFF_submission2'),
    'ON': os.path.join(cbc_folder, 'optim_data', 'optimize_ON_submission2'),
}

## Create cells


In [None]:
import retsim_cells
# Reload so that edits to ``retsim_cells`` are picked up without restarting the kernel.
importlib.reload(retsim_cells)

# Keyword arguments shared by the ON and the OFF cell.
kwargs = dict(
    stimulus=stimulus,
    stim_type='Light',
    t_rng=stim_t_rng,
    expt_base_file=os.path.join(cbc_folder, 'retsim_files', 'expt_CBC_base.cc'),
    cone_densfile='dens_cone_optimized_submission2.n',
    nval_file='nval_optimize_CBCs.n',
    # Absolute path with a trailing slash.
    retsim_path=os.path.abspath(os.path.join('..', 'neuronc', 'models', 'retsim')) + '/'
)

# Create cells.
# Each cell gets its own bipolar-cell type, experiment file name, density file
# and channel-parameter file; everything else comes from ``kwargs``.
ON_cell = retsim_cells.CBC(
    bp_type='CBC5o',
    predur=predur,
    expt_file='test_ON',
    bp_densfile         = 'dens_CBC5o_optimize_ON.n',
    chanparams_file     = 'chanparams_CBC5o_optimize_ON.n',
    **kwargs
)

OFF_cell = retsim_cells.CBC(
    bp_type='CBC3a',
    predur=predur,
    expt_file='test_OFF',
    bp_densfile         = 'dens_CBC3a_optimize_OFF.n',
    chanparams_file     = 'chanparams_CBC3a_optimize_OFF.n',
    **kwargs
)

# Order is ON first, OFF second; later cells iterate over this list.
cells = [ON_cell, OFF_cell]

## Set parameters

### Defaults and units

In [None]:
# Attach default parameter values and parameter units to each cell, loaded
# from the cell's optimization folder.
for cell, cell_name in zip([ON_cell, OFF_cell], ['ON', 'OFF']):

    cell.params_default = data_utils.load_var(os.path.join(cell2folder[cell_name], 'cell_params_default.pkl'))
    cell.params_unit = data_utils.load_var(os.path.join(cell2folder[cell_name], 'cell_params_unit.pkl'))
    
    # Entries of 'final_cpl_dict.pkl' override / extend the defaults loaded above.
    cell.params_default.update(data_utils.load_var(os.path.join(cell2folder[cell_name], 'final_cpl_dict.pkl')))

### Optimized parameters

In [None]:
# Number of parameter sets kept per cell (those with the smallest total loss).
N_param_sets = 5

# Maps ``cell.bp_type`` to its list of selected parameter dicts.
cell2params_list = {}

for cell in cells:
    optim_folder = cell2folder['OFF' if cell.is_OFF_bp else 'ON']
    samples = data_utils.load_var(os.path.join(optim_folder, 'post_data', 'post_model_output_list.pkl'))
    # Indices of the samples in ascending order of total loss.
    d_sort_idxs = np.argsort([sample['loss']['total'] for sample in samples])
    cell2params_list[cell.bp_type] = [samples[idx]['params'] for idx in d_sort_idxs[:N_param_sets]]

In [None]:
# List params not in optimized params.
# Only the first selected parameter set of each cell is checked.
for cell in cells:
    print(cell.bp_type, ':')
    for p_name, p_value in cell.params_default.items():
        if p_name not in cell2params_list[cell.bp_type][0].keys():
            print(p_name, end=', ')
    print()

### Cell loss

In [None]:
# Output folder of this notebook.
data_utils.make_dir('data')

# Maps ``cell.bp_type`` to the loss object stored by the optimization
# (later used through its ``loss_params`` attribute and ``calc_loss`` method).
cell2loss = {}
for cell in cells:
    optim_folder = cell2folder['OFF' if cell.is_OFF_bp else 'ON']
    cell2loss[cell.bp_type] = data_utils.load_var(os.path.join(optim_folder, 'loss.pkl'))
    
# Keep a copy next to the simulation results saved by this notebook.
data_utils.save_var(cell2loss, os.path.join('data', 'cell2loss.pkl'))

### Prepare cells

In [None]:
# Create c++ files.
ON_cell.create_retsim_expt_file(verbose=False, on2cone_nodes=[1077, 980, 1190])
OFF_cell.create_retsim_expt_file(verbose=False, off2cone_nodes=[686, 1037, 828, 950, 879])
# Compile c++ files.
!(cd {cell.retsim_path} && make)

In [None]:
# Set the rotation of each cell and initialise retsim; the value returned by
# ``init_retsim`` is displayed as an image below.
ON_cell.set_rot(mxrot=-90, myrot=0)
ON_im = ON_cell.init_retsim(verbose=False, print_comps=True, update=True)

OFF_cell.set_rot(mxrot=-90, myrot=60)
OFF_im = OFF_cell.init_retsim(verbose=False, print_comps=True, update=True)

# Plot.
# ON cell on the left, OFF cell on the right, without axes.
fig, axs = plt.subplots(1,2,figsize=(14, 8))
axs[0].imshow(ON_im)
axs[1].imshow(OFF_im)
for ax in axs: ax.axis('off')
plt.show()

# Get cell response with different channels removed.

## Remove channel functions

In [None]:
def make_params_passive(full_params):
    """Return a copy of *full_params* with (almost) all ``cd_`` entries set to zero.

    Every entry whose name contains ``'cd_'`` is replaced by ``0.0``, except
    the entries named ``'cd_L_at'``, ``'cd_T_at'`` and ``'cd_P_at'``, which
    keep their value. The input mapping itself is not modified.
    """
    kept_names = ['cd_L_at', 'cd_T_at', 'cd_P_at']
    passive_params = full_params.copy()

    for name in list(passive_params.keys()):
        if 'cd_' not in name or name in kept_names:
            continue
        passive_params[name] = 0.0

    return passive_params

In [None]:
def remove_HCN_channels(full_params):
    """Return a copy of *full_params* with every ``'cd_H'`` entry set to zero.

    An entry matches when its name contains the substring ``'cd_H'``; matched
    values become ``0.0``. The input mapping itself is not modified.
    """
    reduced = full_params.copy()
    matching = [name for name in reduced.keys() if 'cd_H' in name]
    for name in matching:
        reduced[name] = 0.0
    return reduced

In [None]:
def remove_Kv_channels(full_params):
    """Return a copy of *full_params* with every ``'cd_Kv'`` entry set to zero.

    An entry matches when its name contains the substring ``'cd_Kv'``; matched
    values become ``0.0``. The input mapping itself is not modified.
    """
    reduced = full_params.copy()
    for name in list(reduced.keys()):
        if 'cd_Kv' not in name:
            continue
        reduced[name] = 0.0
    return reduced

In [None]:
def remove_Kir_channels(full_params):
    """Return a copy of *full_params* with every ``'cd_Kir'`` entry set to zero.

    An entry matches when its name contains the substring ``'cd_Kir'``;
    matched values become ``0.0``. The input mapping itself is not modified.
    """
    reduced = full_params.copy()
    matching = [name for name in reduced.keys() if 'cd_Kir' in name]
    for name in matching:
        reduced[name] = 0.0
    return reduced

In [None]:
def remove_Na_channels(full_params):
    """Return a copy of *full_params* with every ``'cd_N'`` entry set to zero.

    An entry matches when its name contains the substring ``'cd_N'`` (note:
    not ``'cd_Na'``); matched values become ``0.0``. The input mapping itself
    is not modified.
    """
    reduced = full_params.copy()
    for name in list(reduced.keys()):
        if 'cd_N' not in name:
            continue
        reduced[name] = 0.0
    return reduced

In [None]:
def remove_somaCa_channels(full_params):
    """Return a copy of *full_params* with ``cd_L_s``, ``cd_T_s`` and ``cd_P_s`` zeroed.

    Only entries with exactly these names are set to ``0.0``; names that are
    absent are not added. The input mapping itself is not modified.
    """
    reduced = full_params.copy()
    for name in ('cd_L_s', 'cd_T_s', 'cd_P_s'):
        if name in reduced:
            reduced[name] = 0.0
    return reduced

In [None]:
def remove_L_at_channels(full_params):
    """Return a copy of *full_params* with every ``'cd_L_at'`` entry set to zero.

    An entry matches when its name contains the substring ``'cd_L_at'``;
    matched values become ``0.0``. The input mapping itself is not modified.
    """
    reduced = full_params.copy()
    matching = [name for name in reduced.keys() if 'cd_L_at' in name]
    for name in matching:
        reduced[name] = 0.0
    return reduced

In [None]:
def remove_T_at_channels(full_params):
    """Return a copy of *full_params* with every ``'cd_T_at'`` entry set to zero.

    An entry matches when its name contains the substring ``'cd_T_at'``;
    matched values become ``0.0``. The input mapping itself is not modified.
    """
    reduced = full_params.copy()
    for name in list(reduced.keys()):
        if 'cd_T_at' not in name:
            continue
        reduced[name] = 0.0
    return reduced

In [None]:
def get_mode2params_dict(full_params, isOFF=False):
    """Build one parameter set per "mode" from a full parameter set.

    The returned dict maps, in this order, ``'all'`` to a copy of
    *full_params*, then ``'no_HCN'``, ``'no_Kv'``, ``'no_Kir'``, ``'no_Na'``,
    ``'no_somaCa'`` and ``'passive'`` to the output of the corresponding
    removal function. With ``isOFF=True`` the modes ``'no_L_at'`` and
    ``'no_T_at'`` are appended. Callers rely on this key order.
    """
    base_params = full_params.copy()

    mode_builders = [
        ('no_HCN', remove_HCN_channels),
        ('no_Kv', remove_Kv_channels),
        ('no_Kir', remove_Kir_channels),
        ('no_Na', remove_Na_channels),
        ('no_somaCa', remove_somaCa_channels),
        ('passive', make_params_passive),
    ]
    if isOFF:
        mode_builders.append(('no_L_at', remove_L_at_channels))
        mode_builders.append(('no_T_at', remove_T_at_channels))

    # 'all' holds the unreduced copy itself; every other mode gets its own copy
    # from the respective removal function.
    mode2params = {'all': base_params}
    for mode, build in mode_builders:
        mode2params[mode] = build(base_params)

    return mode2params

## Prepare cells

In [None]:
# Output folder for the simulation results saved further below.
data_utils.make_dir('data')

In [None]:
# Assigned to each cell's ``rec_type`` attribute in ``prepare_cell``.
rec_type = 'test'

In [None]:
def prepare_cell(cell):
    """Set the cell's ``rec_type`` from the notebook-level ``rec_type`` and create its retsim stimulus file."""
    cell.rec_type = rec_type
    cell.create_retsim_stim_file()

In [None]:
# Apply the preparation step to the ON and the OFF cell.
for cell in cells: prepare_cell(cell)

### Plot functions

In [None]:
def plot_Vm(rec_data_sorted, rec_time):
    """Plot the ``'BC Vm Soma'`` traces of every mode, one subplot row per mode.

    Parameters
    ----------
    rec_data_sorted : dict
        Maps a mode name to a list of recordings; each recording is indexed
        with ``'BC Vm Soma'``.
    rec_time : array-like
        Common x-values for all traces.
    """
    # squeeze=False always yields a 2-D axes array, so a dict with a single
    # mode works too (otherwise a bare Axes is returned and zip() fails).
    fig, axs = plt.subplots(len(rec_data_sorted),1,figsize=(12,10), sharex=True, squeeze=False)
    for ax, (mode, rec_data_list) in zip(axs[:, 0], rec_data_sorted.items()):
        ax.set_title(mode)
        for rec_data in rec_data_list:
            ax.plot(rec_time, rec_data['BC Vm Soma'])
    plt.tight_layout()

In [None]:
def plot_rate(rec_data_sorted, rec_time):
    """Plot the column-averaged ``'rate BC'`` trace of every mode, one subplot row per mode.

    Parameters
    ----------
    rec_data_sorted : dict
        Maps a mode name to a list of recordings; each recording is indexed
        with ``'rate BC'`` and the result is averaged with ``mean(axis=1)``.
    rec_time : array-like
        Common x-values for all traces.
    """
    # squeeze=False always yields a 2-D axes array, so a dict with a single
    # mode works too (otherwise a bare Axes is returned and zip() fails).
    fig, axs = plt.subplots(len(rec_data_sorted),1,figsize=(12,10), sharex=True, squeeze=False)
    for ax, (mode, rec_data_list) in zip(axs[:, 0], rec_data_sorted.items()):
        ax.set_title(mode)
        for rec_data in rec_data_list:
            ax.plot(rec_time, rec_data['rate BC'].mean(axis=1))
    plt.tight_layout()

In [None]:
def compare_to_full_model(rec_data_sorted, rec_time):
    """Plot, per mode, the difference of each recording to the matching ``'all'`` recording.

    Each mode gets one row with four panels: the ``'BC Vm Soma'`` difference
    trace (scaled by 1e3), a bar per parameter set summarising it, the
    difference of the column-averaged ``'rate BC'`` trace, and a bar per
    parameter set summarising that. Recordings are paired by position with
    the entries of ``rec_data_sorted['all']``.

    Parameters
    ----------
    rec_data_sorted : dict
        Maps a mode name to a list of recordings; must contain the key ``'all'``.
    rec_time : array-like
        Common x-values for all traces.
    """
    # squeeze=False keeps ``axs`` two-dimensional even for a single mode, which
    # the column indexing below (``axs[:,1]``) requires.
    fig, axs = plt.subplots(
        len(rec_data_sorted),4,figsize=(12,10), sharex='col',
        gridspec_kw=dict(width_ratios=[5,1,5,1]), squeeze=False,
    )

    for axs_row, (mode, rec_data_list) in zip(axs, rec_data_sorted.items()):
        axs_row[0].set_title(mode)
        for idx, (rec_data, rec_data_all) in enumerate(zip(rec_data_list, rec_data_sorted['all'])):
            Vm_diff = 1e3*(rec_data['BC Vm Soma']-rec_data_all['BC Vm Soma'])
            rate_diff = (rec_data['rate BC'].mean(axis=1)-rec_data_all['rate BC'].mean(axis=1))

            axs_row[0].plot(rec_time, Vm_diff, lw=1)
            # NOTE(review): this is the square of the mean difference, not the
            # mean squared difference (``np.mean(Vm_diff**2)``) -- confirm intent.
            axs_row[1].bar(idx, np.mean(Vm_diff)**2)
            axs_row[2].plot(rec_time, rate_diff, lw=1)
            # NOTE(review): same as above for the rate difference.
            axs_row[3].bar(idx, np.mean(rate_diff)**2)

    # Bar panels start at zero and show at least the range [0, 1].
    for ax in np.append(axs[:,1], axs[:,3]):
        ax.set_ylim((0, np.max([1, ax.get_ylim()[1]])))

    plt.tight_layout(w_pad=0)
    plt.show()

### Loss functions

In [None]:
def compute_losses(loss, rec_data_sorted):
    """Evaluate *loss* on every recording and collect the values per mode and loss name.

    Parameters
    ----------
    loss : object
        Provides ``loss_params`` (its keys are the loss names besides
        ``'total'``) and ``calc_loss(rec_data=...)`` returning a mapping from
        loss name to value.
    rec_data_sorted : dict
        Maps a mode name to a list of recordings; each recording is indexed
        with ``'rate BC'`` and ``'BC Vm Soma'``.

    Returns
    -------
    dict
        ``{mode: {loss_name: [value per recording, ...]}}``; every mode has an
        entry for every loss name, even when it has no recordings.
    """
    loss_names = ['total']
    loss_names.extend(loss.loss_params.keys())

    # Create all (empty) per-mode entries first, then fill them.
    collected = {}
    for mode in rec_data_sorted.keys():
        collected[mode] = {name: [] for name in loss_names}

    for mode, recordings in rec_data_sorted.items():
        per_name = collected[mode]
        for recording in recordings:
            # The loss is computed on the column-averaged rate and the somatic Vm.
            model_output = {
                'rate': recording['rate BC'].mean(axis=1).values,
                'Vm': recording['BC Vm Soma'].values,
            }
            result = loss.calc_loss(rec_data=model_output)
            for name in loss_names:
                per_name[name].append(result[name])

    return collected

In [None]:
def loss2text(loss):
    """Format a loss value as a short label.

    Returns ``'<0'`` for negative values, ``'0'`` for exactly zero and the
    value with three decimals otherwise.
    """
    if loss < 0:
        return '<0'
    if loss == 0:
        return '0'
    return format(loss, '.3f')
        

def plot_values(ax, idx, values):
    """Draw a summary of *values* at x-position *idx* on *ax*.

    Draws a red marker at the mean, orange markers for the individual values,
    a black line from the minimum to the maximum, and a three-line text label
    (maximum, mean, minimum; formatted with ``loss2text``) at the mean.
    """
    value_mean = np.mean(values)
    ax.plot(idx, value_mean, marker='_', markersize=10, c='r', markeredgewidth=2)

    x_positions = np.full(len(values), idx)
    ax.plot(x_positions, values, marker='_', markersize=5, c='orange', alpha=0.5, markeredgewidth=2)

    value_min = np.min(values)
    value_max = np.max(values)
    ax.plot([idx, idx], [value_min, value_max], c='k')

    label = '\n'.join([loss2text(value_max), loss2text(value_mean), loss2text(value_min)])
    ax.text(idx, value_mean, label)

In [None]:
def plot_losses(losses):
    """Plot the loss values of all modes, one subplot row per loss name.

    The left column shows the loss values of each mode (x-position = mode
    index); the right column shows, per mode, the absolute loss minus the
    absolute loss of the ``'all'`` mode.

    Parameters
    ----------
    losses : dict
        ``{mode: {loss_name: [values, ...]}}``; must contain the mode
        ``'all'``. All modes are expected to share the same loss names.
    """
    modes = list(losses.keys())
    # Rows are indexed by loss name, so the grid must be sized by the number of
    # loss names. Sizing it by the number of modes drops loss names when there
    # are fewer modes than losses and leaves empty rows otherwise.
    l_names = list(losses[modes[0]].keys()) if modes else []
    # squeeze=False keeps ``axs`` two-dimensional even for a single row.
    fig, axs = plt.subplots(len(l_names),2,figsize=(12,10), sharey='col', squeeze=False)

    for idx, (mode, loss_dict) in enumerate(losses.items()):
        for ax_row, (l_name, l_values) in zip(axs, loss_dict.items()):
            ax_row[0].set_ylabel(l_name, rotation=0, ha='right')
            plot_values(ax=ax_row[0], idx=idx, values=l_values)
            ax_row[0].axhline(0, color='gray')

            # Reference line at the mean loss of the unreduced model.
            if mode == 'all': ax_row[0].axhline(np.mean(l_values), color='red', alpha=0.3, ls='--')

            add_err = np.abs(l_values) - np.abs(losses['all'][l_name])
            plot_values(ax=ax_row[1], idx=idx, values=add_err)
            ax_row[1].axhline(0, color='gray')

    for ax in axs.flatten():
        ax.set_xticks(np.arange(len(modes)))
        ax.set_xticklabels(modes)

    plt.tight_layout()

In [None]:
def get_modes_to_remove(losses):
    """Return (and print) the modes whose total loss is not worse than that of ``'all'``.

    A mode is selected when its maximum total loss exceeds the maximum total
    loss of ``'all'`` by less than ``1e-3``, or when its mean total loss is
    below the mean total loss of ``'all'``. The mode ``'all'`` itself always
    satisfies the first condition and is therefore part of the result.
    """
    def _is_removable(loss_dict):
        totals = loss_dict['total']
        reference = losses['all']['total']
        if np.max(totals) - np.max(reference) < 1e-3:
            return True
        return np.mean(totals) - np.mean(reference) < 0.0

    remove_modes = [mode for mode, loss_dict in losses.items() if _is_removable(loss_dict)]
    print(remove_modes)
    return remove_modes

## OFF cell

### Generate data

In [None]:
# One mode -> params dict per selected OFF parameter set.
all_OFF_params = []
for params in cell2params_list[OFF_cell.bp_type]:
    all_OFF_params.append(get_mode2params_dict(params, isOFF=True))
len(all_OFF_params)

In [None]:
# Mode names in the order produced by ``get_mode2params_dict``.
modes = list(all_OFF_params[0].keys())
modes

In [None]:
# Flatten to one list: parameter set is the outer loop, mode the inner loop.
# The sorting cell below recovers the mode from the list position.
all_OFF_params_list = [params[mode] for params in all_OFF_params for mode in modes]
len(all_OFF_params_list)

In [None]:
# presumably a simulation timeout (units not visible here) -- TODO confirm
OFF_cell.timeout = 250000

In [None]:
# True: read previously saved results; False: run the simulations and save them.
load = True

if not load:
    OFF_cell_rec_data_list = OFF_cell.run_parallel(sim_params_list=all_OFF_params_list, n_parallel=25)
    data_utils.save_var((OFF_cell_rec_data_list, all_OFF_params_list), os.path.join('data', 'OFF_cell_rec_data.pkl'))
else:
    # The parameter list stored with the results replaces the one built above.
    OFF_cell_rec_data_list, all_OFF_params_list = data_utils.load_var(os.path.join('data', 'OFF_cell_rec_data.pkl'))

### Plot data

In [None]:
# Element [1] of the first simulation result is used as the time axis for all
# OFF recordings; element [0] of each result is the recorded data (see below).
rec_time = OFF_cell_rec_data_list[0][1].copy()

In [None]:
# Group the recordings by mode.
OFF_rec_data_sorted = {mode: [] for mode in modes}

for i, (rec_data_i, params_i) in enumerate(zip(OFF_cell_rec_data_list, all_OFF_params_list)):
    # Relies on the flattened order: modes cycle fastest.
    mode = modes[i%len(modes)]
    
    # Print which 'cd_' parameters are zero, to check that the mode matches the parameters.
    print('### Mode: ', mode)
    print('\tZero channels:', end='\t')
    for k, v in params_i.items():
        if 'cd_' in k and v == 0:
            print(k, end=',')
    
    OFF_rec_data_sorted[mode].append(rec_data_i[0])
    print()

In [None]:
# Somatic membrane potential traces per mode (OFF cell).
plot_Vm(rec_data_sorted=OFF_rec_data_sorted, rec_time=rec_time)

In [None]:
# Column-averaged 'rate BC' traces per mode (OFF cell).
plot_rate(rec_data_sorted=OFF_rec_data_sorted, rec_time=rec_time)

In [None]:
# Differences of every mode to the 'all' mode (OFF cell).
compare_to_full_model(rec_data_sorted=OFF_rec_data_sorted, rec_time=rec_time)

### Show loss

In [None]:
# Loss values per mode and loss name for the OFF cell.
OFF_losses = compute_losses(loss=cell2loss[OFF_cell.bp_type], rec_data_sorted=OFF_rec_data_sorted)
plot_losses(OFF_losses)

In [None]:
# Modes whose total loss is not worse than that of the 'all' mode.
OFF_remove_modes = get_modes_to_remove(losses=OFF_losses)

## ON cell

### Generate data

In [None]:
# One mode -> params dict per selected ON parameter set.
all_ON_params = []
for params in cell2params_list[ON_cell.bp_type]:
    all_ON_params.append(get_mode2params_dict(params, isOFF=False))

In [None]:
# Flatten to one list: parameter set is the outer loop, mode (dict order) the inner loop.
all_ON_params_list = [params_i for params in all_ON_params for params_i in params.values()]
len(all_ON_params_list)

In [None]:
# Rebinds ``modes`` (previously the OFF modes) to the ON modes.
modes = list(all_ON_params[0].keys())

In [None]:
# True: read previously saved results; False: run the simulations and save them.
load = True

if not load:
    ON_cell_rec_data_list = ON_cell.run_parallel(sim_params_list=all_ON_params_list, n_parallel=20)
    data_utils.save_var((ON_cell_rec_data_list, all_ON_params_list), os.path.join('data', 'ON_cell_rec_data.pkl'))
else:
    # The parameter list stored with the results replaces the one built above.
    ON_cell_rec_data_list, all_ON_params_list = data_utils.load_var(os.path.join('data', 'ON_cell_rec_data.pkl'))

In [None]:
# Rebinds ``rec_time`` (previously the OFF time axis) to the ON time axis.
# NOTE(review): an offset of 1 is added here, whereas the OFF-cell time axis is
# used without any offset -- confirm which of the two is intended.
rec_time = ON_cell_rec_data_list[0][1].copy() + 1

In [None]:
# Group the recordings by mode.
ON_rec_data_sorted = {mode: [] for mode in modes}

for i, (rec_data_i, params_i) in enumerate(zip(ON_cell_rec_data_list, all_ON_params_list)):
    # Relies on the flattened order: modes cycle fastest.
    mode = modes[i%len(modes)]
    
    # Print which 'cd_' parameters are zero, to check that the mode matches the parameters.
    print('Mode: ', mode.ljust(12), ' Zero:', end='\t')
    for k, v in params_i.items():
        if 'cd_' in k and v == 0:
            print(k, end=',')
    
    ON_rec_data_sorted[mode].append(rec_data_i[0])
    print()

### Plot data.

In [None]:
# Somatic membrane potential traces per mode (ON cell).
plot_Vm(rec_data_sorted=ON_rec_data_sorted, rec_time=rec_time)

In [None]:
# Column-averaged 'rate BC' traces per mode (ON cell).
plot_rate(rec_data_sorted=ON_rec_data_sorted, rec_time=rec_time)

In [None]:
# Differences of every mode to the 'all' mode (ON cell).
compare_to_full_model(rec_data_sorted=ON_rec_data_sorted, rec_time=rec_time)

### Show loss

In [None]:
# Loss values per mode and loss name for the ON cell.
ON_losses = compute_losses(loss=cell2loss[ON_cell.bp_type], rec_data_sorted=ON_rec_data_sorted)
plot_losses(ON_losses)

In [None]:
# Modes whose total loss is not worse than that of the 'all' mode.
ON_remove_modes = get_modes_to_remove(losses=ON_losses)

# Simulate reduced cells

In [None]:
def get_rm_multiple_channels_params_list(params_list, remove_modes):
    """Apply all channel removals named in *remove_modes* to every parameter set.

    Parameters
    ----------
    params_list : list
        Parameter sets; each is copied, the inputs are not modified.
    remove_modes : container of str
        Mode names. Only ``'no_Na'``, ``'no_somaCa'``, ``'no_Kv'``,
        ``'no_HCN'``, ``'no_Kir'``, ``'no_L_at'`` and ``'no_T_at'`` have an
        effect; other entries (e.g. ``'all'`` or ``'passive'``) are ignored.

    Returns
    -------
    list
        One reduced parameter set per entry of *params_list*.
    """
    # Removal functions in the order in which they are applied.
    ordered_removers = (
        ('no_Na', remove_Na_channels),
        ('no_somaCa', remove_somaCa_channels),
        ('no_Kv', remove_Kv_channels),
        ('no_HCN', remove_HCN_channels),
        ('no_Kir', remove_Kir_channels),
        ('no_L_at', remove_L_at_channels),
        ('no_T_at', remove_T_at_channels),
    )

    reduced_params_list = []
    for params in params_list:
        reduced = params.copy()
        for mode, remover in ordered_removers:
            if mode in remove_modes:
                reduced = remover(reduced)
        reduced_params_list.append(reduced)

    return reduced_params_list

In [None]:
# Parameter sets with all removable channel groups of each cell removed at once.
reduced_OFF_params_list =\
    get_rm_multiple_channels_params_list(params_list=cell2params_list[OFF_cell.bp_type], remove_modes=OFF_remove_modes)
reduced_ON_params_list =\
    get_rm_multiple_channels_params_list(params_list=cell2params_list[ON_cell.bp_type], remove_modes=ON_remove_modes)

In [None]:
# Sanity check: the ON and OFF parameter sets must differ pairwise.
for params1, params2 in zip(reduced_ON_params_list, reduced_OFF_params_list):
    assert params1 != params2

In [None]:
# True: read previously saved results; False: run the simulations and save them.
load = True

if not load:
    OFF_reduced_rec_data_list = OFF_cell.run_parallel(sim_params_list=reduced_OFF_params_list, n_parallel=20)
    data_utils.save_var((OFF_reduced_rec_data_list, reduced_OFF_params_list),
                        os.path.join('data', 'OFF_reduced_rec_data_list.pkl'))
else:
    # Bind the stored parameters to ``reduced_OFF_params_list`` (the name was
    # misspelled before), so that the variable always describes the loaded
    # recordings, as in the other load cells of this notebook.
    OFF_reduced_rec_data_list, reduced_OFF_params_list =\
        data_utils.load_var(os.path.join('data', 'OFF_reduced_rec_data_list.pkl'))

In [None]:
# True: read previously saved results; False: run the simulations and save them.
load = True

if not load:
    ON_reduced_rec_data_list = ON_cell.run_parallel(sim_params_list=reduced_ON_params_list, n_parallel=20)
    data_utils.save_var((ON_reduced_rec_data_list, reduced_ON_params_list),
                        os.path.join('data', 'ON_reduced_rec_data_list.pkl'))
else:
    # Bind the stored parameters to ``reduced_ON_params_list`` (the name was
    # misspelled before), so that the variable always describes the loaded
    # recordings, as in the other load cells of this notebook.
    ON_reduced_rec_data_list, reduced_ON_params_list =\
        data_utils.load_var(os.path.join('data', 'ON_reduced_rec_data_list.pkl'))

In [None]:
def print_loss_reduced(reduced_rec_data_list, rec_data_sorted, cell):
    """Print the total loss of the full and the reduced model for each parameter set.

    Parameters
    ----------
    reduced_rec_data_list : list
        Simulation results of the reduced model; element ``[0]`` of each
        entry is the recording.
    rec_data_sorted : dict
        Recordings grouped by mode; only the ``'all'`` entry is used. Entries
        are paired by position with *reduced_rec_data_list*.
    cell : object
        Its ``bp_type`` selects the loss from the notebook-level ``cell2loss``.
    """
    def _calc_loss(recording):
        # Same model output as used for the per-mode losses: column-averaged
        # rate and somatic Vm.
        model_output = {
            'rate': recording['rate BC'].mean(axis=1).values,
            'Vm': recording['BC Vm Soma'].values,
        }
        return cell2loss[cell.bp_type].calc_loss(rec_data=model_output)

    for reduced_entry, full_recording in zip(reduced_rec_data_list, rec_data_sorted['all']):
        loss_reduced = _calc_loss(reduced_entry[0])
        loss_full = _calc_loss(full_recording)

        total_full = loss_full['total']
        total_reduced = loss_reduced['total']

        print(f"All params: {total_full:.4f}" +
              f" vs. reduced: {total_reduced:.4f}" +
              f" Difference {total_reduced-total_full:.6f}")

In [None]:
# Total loss of the full vs. the reduced ON cell, per parameter set.
print_loss_reduced(reduced_rec_data_list=ON_reduced_rec_data_list, rec_data_sorted=ON_rec_data_sorted, cell=ON_cell)

In [None]:
# Total loss of the full vs. the reduced OFF cell, per parameter set.
print_loss_reduced(reduced_rec_data_list=OFF_reduced_rec_data_list, rec_data_sorted=OFF_rec_data_sorted, cell=OFF_cell)

# Save all data for figure

In [None]:
# Add the recordings of the reduced models as an extra mode 'minimal'
# (element [0] of each simulation result is the recording).
ON_rec_data_sorted['minimal'] = [rec_data[0] for rec_data in ON_reduced_rec_data_list]
OFF_rec_data_sorted['minimal'] = [rec_data[0] for rec_data in OFF_reduced_rec_data_list]

In [None]:
data_utils.save_var(ON_rec_data_sorted, os.path.join('data', 'ON_data_sorted.pkl'))
data_utils.save_var(OFF_rec_data_sorted, os.path.join('data', 'OFF_data_sorted.pkl'))

# NOTE(review): ``rec_time`` holds whatever was assigned last; when the cells
# are run top to bottom that is the ON-cell time axis -- confirm it is also
# valid for the OFF data saved above.
data_utils.save_var(rec_time, os.path.join('data', 'rec_time.pkl'))