# Test ion channels and create retsim files for CBC optimization

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import warnings

import os
import sys

In [None]:
# Make the shared helper modules in ../_pythoncode importable by prepending
# that folder to the module search path.
pythoncodepath = os.path.abspath(os.path.join('..', '_pythoncode'))
sys.path = [pythoncodepath] + sys.path
import importhelper
# Presumably also registers the subfolders of _pythoncode (data_utils,
# retsim_cells, ...) on sys.path — TODO confirm in importhelper.
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils

# Cell

In [None]:
# Time window of the stimulus: passed to the cell as `t_rng` and used as the
# x-limits of the plots below (presumably seconds, like the 'Time' column — confirm).
stim_t_rng = (1, 32)
# Passed to the cell as `predur`; presumably the pre-stimulus settling
# duration — TODO confirm units in retsim_cells.
predur = 5.0

## Stimulus and target

In [None]:
# Load experimental data
# target_dF_F: measured cone release traces (the optimization target).
# stimulus: time-aligned light stimulus; both tables have a 'Time' column.
data_folder = os.path.join('..', 'ExperimentalData', 'PreprocessedData')
target_dF_F = pd.read_csv(os.path.join(data_folder, 'ConeData_ReleaseMeanData.csv'))
stimulus    = pd.read_csv(os.path.join(data_folder, 'ConeData_stimulus_time_and_amp_corrected.csv'))

In [None]:
# Visual check: stimulus (top) and target release (bottom) over stim_t_rng.
fig, axs = plt.subplots(2,1,figsize=(12,4), subplot_kw=dict(xlim=stim_t_rng))
stimulus.plot(x='Time', ax=axs[0])
target_dF_F.plot(x='Time', ax=axs[1])
plt.show()

## Select optimized cells

In [None]:
# Select folder of optimized cone.
cone_folder = os.path.join('..', 'step1a_optimize_cones')
opt_cone_folder = os.path.join(cone_folder, 'optim_data', 'optimize_cone_submission2')
# Fail loudly and name the offending path. A bare `assert` is skipped under
# `python -O` and does not tell which folder is missing.
if not os.path.isdir(opt_cone_folder):
    raise FileNotFoundError(
        f"Optimized cone folder not found: {os.path.abspath(opt_cone_folder)}"
    )

## Create cell

In [None]:
import retsim_cells
# Reload so edits to retsim_cells are picked up without restarting the kernel.
importlib.reload(retsim_cells)

# Build the cone model driven by the light stimulus loaded above. The
# dens/nval/chanparams files are the retsim parameter files used during cone
# optimization; the experiment file is generated from the cone-optimization base.
cell = retsim_cells.Cone(
    predur=predur, t_rng=stim_t_rng,
    stimulus=stimulus, stim_type='Light',
    cone_densfile       = 'dens_cone_optimize_cone.n',
    nval_file           = 'nval_cone_optimize_cone.n',
    chanparams_file     = 'chanparams_cone_optimize_cone.n',
    expt_file_list      = ['plot_cone'],
    expt_base_file_list = [os.path.join(cone_folder, 'retsim_files', 'expt_optimize_cones.cc')],
    # Trailing '/' kept; retsim_cells presumably concatenates this path as a string.
    retsim_path=os.path.abspath(os.path.join('..', 'NeuronC', 'models', 'retsim')) + '/'
)

## Set parameters

### Defaults and units

In [None]:
# Default parameter values and their units, as saved by the cone optimization run.
cell.params_default = data_utils.load_var(os.path.join(opt_cone_folder, 'cell_params_default.pkl'))
cell.params_unit = data_utils.load_var(os.path.join(opt_cone_folder, 'cell_params_unit.pkl'))

### Optimized parameters

In [None]:
# Final output of the cone optimization; its 'params' entry holds the optimized parameters.
final_model_output = data_utils.load_var(os.path.join(opt_cone_folder, 'post_data', 'final_model_output.pkl'))

In [None]:
# Display the optimized parameter set.
final_model_output['params']

In [None]:
# Wrapped in a list so the code below works for one or more optimized parameter sets.
opt_params_list = [final_model_output['params']]

In [None]:
# List params not in optimized params: print, comma-separated on one line,
# every default parameter that the optimization did not return.
optimized_names = opt_params_list[0].keys()
not_optimized = [name for name, _ in cell.params_default.items() if name not in optimized_names]
print(''.join(f'{name}, ' for name in not_optimized))

### Cell loss

In [None]:
# Local output folder for everything this notebook saves.
data_utils.make_dir('data')

# Reuse the loss object from the cone optimization and keep a local copy in data/.
cell_loss = data_utils.load_var(os.path.join(opt_cone_folder, 'loss.pkl'))
data_utils.save_var(cell_loss, os.path.join('data', 'cell_loss.pkl'))

### Prepare cells

In [None]:
# Create c++ files.
cell.create_retsim_expt_file(verbose=True)
# Compile c++ files.
# `!` is an IPython shell escape; `{cell.retsim_path}` interpolates the Python value.
!(cd {cell.retsim_path} && make)

In [None]:
# Initialize retsim for this cell (see retsim_cells.Cone.init_retsim for the flags).
cell.init_retsim(verbose=False, print_comps=False, update=False)

# Get cell response with different channels removed

In [None]:
# Show the optimized parameter names containing 'cd' (presumably channel densities).
print([p_name for p_name in opt_params_list[0].keys() if 'cd' in p_name])

In [None]:
def make_params_passive(full_params):
    """Return a copy of ``full_params`` with every 'cd_' parameter set to 0.0,
    except 'cd_Ca_L' and 'cd_Ca_P', which keep their values.

    The input mapping is left unmodified.
    """
    kept_channels = ('cd_Ca_L', 'cd_Ca_P')
    passive = full_params.copy()
    passive.update({
        name: 0.0
        for name, _ in full_params.items()
        if 'cd_' in name and name not in kept_channels
    })
    return passive

In [None]:
def remove_HCN_channels(full_params):
    """Return a copy of ``full_params`` in which every parameter whose name
    contains 'cd_H' (presumably the HCN channel densities) is set to 0.0.

    The input mapping is left unmodified.
    """
    reduced = full_params.copy()
    hcn_names = [name for name, _ in reduced.items() if 'cd_H' in name]
    for name in hcn_names:
        reduced[name] = 0.0
    return reduced

In [None]:
def remove_Kv_channels(full_params):
    """Return a copy of ``full_params`` in which every parameter whose name
    contains 'cd_Kv' (presumably the Kv channel densities) is set to 0.0.

    The input mapping is left unmodified.
    """
    without_kv = full_params.copy()
    without_kv.update({key: 0.0 for key, _ in full_params.items() if 'cd_Kv' in key})
    return without_kv

In [None]:
def remove_ClCa_channels(full_params):
    """Return a copy of ``full_params`` in which every parameter whose name
    contains 'cd_ClCa' (presumably the Ca-activated Cl channel densities) is set to 0.0.

    The input mapping is left unmodified.
    """
    result = full_params.copy()
    for key in filter(lambda name: 'cd_ClCa' in name, list(result.keys())):
        result[key] = 0.0
    return result

In [None]:
def get_reduced_params_list(full_params):
    """Build the parameter variants used to test which channels are needed.

    Despite its name, this returns a dict mapping a mode name to a parameter
    dict: 'all' (a copy of the full set), 'passive', 'no_HCN', 'no_Kv' and
    'no_ClCa', in that order. Every variant is derived from the same copy of
    ``full_params``, so the input is never modified.
    """
    base = full_params.copy()
    reducers = (
        ('passive', make_params_passive),
        ('no_HCN', remove_HCN_channels),
        ('no_Kv', remove_Kv_channels),
        ('no_ClCa', remove_ClCa_channels),
    )
    variants = {'all': base}
    variants.update((mode, reducer(base)) for mode, reducer in reducers)
    return variants

## Prepare cells

In [None]:
# Recording type that prepare_cell() assigns to the cell.
rec_type = 'test'

In [None]:
def prepare_cell(cell):
    """Set ``cell.rec_type`` from the module-level ``rec_type`` and write the
    cell's retsim stimulus file."""
    setattr(cell, 'rec_type', rec_type)
    cell.create_retsim_stim_file()

In [None]:
# Apply rec_type and write the stimulus file before running simulations.
prepare_cell(cell)

In [None]:
# Single test simulation with plotting; the returned data is discarded.
_ = cell.run(plot=True)

### Plot functions

In [None]:
def plot_Vm(rec_data_sorted, rec_time):
    """Plot the membrane-potential traces ('Vm 0') of every mode, one row per mode.

    Parameters
    ----------
    rec_data_sorted : dict
        Maps a mode name to a list of recordings; each recording must support
        ``rec_data['Vm 0']`` (e.g. a pandas DataFrame).
    rec_time : array-like
        Time axis shared by all recordings.
    """
    # squeeze=False keeps ``axs`` 2-D even for a single mode. Otherwise
    # plt.subplots returns a bare Axes and the zip below would fail.
    fig, axs = plt.subplots(len(rec_data_sorted), 1, figsize=(12, 10), sharex=True, squeeze=False)
    for ax, (mode, rec_data_list) in zip(axs[:, 0], rec_data_sorted.items()):
        ax.set_title(mode)
        for rec_data in rec_data_list:
            ax.plot(rec_time, rec_data['Vm 0'])
    plt.tight_layout()

In [None]:
def plot_rate(rec_data_sorted, rec_time):
    """Plot the release-rate traces ('rate Cone') of every mode, one row per mode.

    Parameters
    ----------
    rec_data_sorted : dict
        Maps a mode name to a list of recordings; each recording must support
        ``rec_data['rate Cone']`` (e.g. a pandas DataFrame).
    rec_time : array-like
        Time axis shared by all recordings.
    """
    # squeeze=False keeps ``axs`` 2-D even for a single mode. Otherwise
    # plt.subplots returns a bare Axes and the zip below would fail.
    fig, axs = plt.subplots(len(rec_data_sorted), 1, figsize=(12, 10), sharex=True, squeeze=False)
    for ax, (mode, rec_data_list) in zip(axs[:, 0], rec_data_sorted.items()):
        ax.set_title(mode)
        for rec_data in rec_data_list:
            ax.plot(rec_time, rec_data['rate Cone'])
    plt.tight_layout()

In [None]:
def compare_to_full_model(rec_data_sorted, rec_time):
    """Plot how much every mode deviates from the full ('all') model.

    Each recording of a mode is paired with the 'all' recording at the same
    position. Columns per mode row:
      0: Vm difference trace (scaled by 1e3, presumably V -> mV),
      1: mean squared Vm difference per recording (bars),
      2: release-rate difference trace,
      3: mean squared rate difference per recording (bars).

    Parameters
    ----------
    rec_data_sorted : dict
        Maps a mode name to a list of recordings with columns 'Vm 0' and
        'rate Cone'; must contain the key 'all'.
    rec_time : array-like
        Time axis shared by all recordings.
    """
    # squeeze=False keeps ``axs`` 2-D even when there is only one mode.
    fig, axs = plt.subplots(
        len(rec_data_sorted), 4, figsize=(12, 10), sharex='col',
        gridspec_kw=dict(width_ratios=[5, 1, 5, 1]), squeeze=False,
    )
    reference = rec_data_sorted['all']

    for axs_row, (mode, rec_data_list) in zip(axs, rec_data_sorted.items()):
        axs_row[0].set_title(mode)
        for idx, (rec_data, rec_data_all) in enumerate(zip(rec_data_list, reference)):
            Vm_diff = 1e3 * (rec_data['Vm 0'] - rec_data_all['Vm 0'])
            rate_diff = rec_data['rate Cone'] - rec_data_all['rate Cone']

            axs_row[0].plot(rec_time, Vm_diff, lw=1)
            # Mean of the squared difference (MSE). Squaring the mean instead
            # lets positive and negative deviations cancel each other out.
            axs_row[1].bar(idx, np.mean(np.square(Vm_diff)))
            axs_row[2].plot(rec_time, rate_diff, lw=1)
            axs_row[3].bar(idx, np.mean(np.square(rate_diff)))

    # Bar columns start at 0 and extend at least to 1, so tiny errors are not blown up.
    for ax in np.append(axs[:, 1], axs[:, 3]):
        ax.set_ylim((0, np.max([1, ax.get_ylim()[1]])))

    plt.tight_layout(w_pad=0)
    plt.show()

### Loss functions

In [None]:
def compute_losses(loss, rec_data_sorted):
    """Evaluate ``loss`` on every recording and collect the results per mode.

    Parameters
    ----------
    loss : object
        Must expose ``loss_params`` (a mapping whose keys name the loss
        components) and ``calc_loss(rec_data=...)``, which returns a mapping
        containing 'total' and every key of ``loss_params``.
    rec_data_sorted : dict
        Maps a mode name to a list of recordings with columns 'rate Cone' and 'Vm 0'.

    Returns
    -------
    dict
        ``{mode: {loss_name: [value per recording, ...]}}`` with loss names
        'total' followed by the keys of ``loss.loss_params``.
    """
    loss_names = ['total', *loss.loss_params.keys()]
    result = {mode_name: {name: [] for name in loss_names} for mode_name in rec_data_sorted}

    for mode_name in rec_data_sorted:
        per_name = result[mode_name]
        for trace in rec_data_sorted[mode_name]:
            values = loss.calc_loss(rec_data={
                'rate': trace['rate Cone'].values, 'Vm': trace['Vm 0'].values
            })
            for name in loss_names:
                per_name[name].append(values[name])

    return result

In [None]:
def loss2text(loss):
    """Format a loss value for plot annotations.

    Negative values become '<0', exactly zero becomes '0', and anything else
    (including NaN) is formatted with three decimals.
    """
    return '<0' if loss < 0 else ('0' if loss == 0 else f"{loss:.3f}")
        

def plot_values(ax, idx, values):
    """Draw a summary of ``values`` at x-position ``idx`` on ``ax``.

    Draws the mean as a red tick, each individual value as a translucent
    orange tick, and a black vertical line from min to max. The max, mean
    and min are written (via loss2text) next to the mean.
    """
    center = np.mean(values)
    ax.plot(idx, center, marker='_', markersize=10, c='r', markeredgewidth=2)
    ax.plot(np.full(len(values), idx), values, marker='_', markersize=5, c='orange', alpha=0.5, markeredgewidth=2)

    lowest, highest = np.min(values), np.max(values)
    ax.plot([idx, idx], [lowest, highest], c='k')

    label = '\n'.join(loss2text(v) for v in (highest, center, lowest))
    ax.text(idx, center, label)

In [None]:
def plot_losses(losses):
    """Plot loss values per mode, with one row per loss component and two columns.

    The left column shows the loss values of every mode (see plot_values); for
    the 'all' mode a dashed red line marks its mean. The right column shows,
    per recording, ``|loss(mode)| - |loss('all')|``, i.e. the extra error caused
    by the mode's reduction. The x-axis of every panel is labelled with the
    mode names.

    Parameters
    ----------
    losses : dict
        Output of compute_losses: ``{mode: {loss_name: [values, ...]}}``. Must
        contain an 'all' entry, and every mode must have the same loss names.
    """
    mode_names = list(losses.keys())
    # Rows correspond to loss components, not modes, so size the grid by the
    # number of loss names. squeeze=False keeps ``axs`` 2-D for a single row.
    l_names = list(losses['all'].keys())
    fig, axs = plt.subplots(len(l_names), 2, figsize=(12, 10), sharey='col', squeeze=False)

    for ax_row, l_name in zip(axs, l_names):
        ax_row[0].set_ylabel(l_name, rotation=0, ha='right')

    for idx, (mode, loss_dict) in enumerate(losses.items()):
        for ax_row, l_name in zip(axs, l_names):
            l_values = loss_dict[l_name]
            plot_values(ax=ax_row[0], idx=idx, values=l_values)
            ax_row[0].axhline(0, color='gray')

            if mode == 'all':
                ax_row[0].axhline(np.mean(l_values), color='red', alpha=0.3, ls='--')

            add_err = np.abs(l_values) - np.abs(losses['all'][l_name])
            plot_values(ax=ax_row[1], idx=idx, values=add_err)
            ax_row[1].axhline(0, color='gray')

    for ax in axs.flatten():
        ax.set_xticks(np.arange(len(mode_names)))
        ax.set_xticklabels(mode_names)

    plt.tight_layout()

In [None]:
def get_modes_to_remove(losses):
    """Return (and print) the modes whose reduction is considered harmless.

    A mode qualifies when its maximum 'total' loss exceeds that of 'all' by
    less than 1e-3, or when its mean 'total' loss is lower than that of 'all'.
    'all' itself always qualifies, because its difference to itself is 0.
    """
    reference = losses['all']['total']
    ref_max = np.max(reference)
    ref_mean = np.mean(reference)

    remove_modes = [
        mode
        for mode, loss_dict in losses.items()
        if (np.max(loss_dict['total']) - ref_max < 1e-3)
        or (np.mean(loss_dict['total']) - ref_mean < 0.0)
    ]
    print(remove_modes)
    return remove_modes

In [None]:
def reduce_params(params_list, remove_modes):
    """Apply the channel removals listed in ``remove_modes`` to every parameter set.

    Only 'no_ClCa', 'no_Kv' and 'no_HCN' are honoured, applied in that order;
    any other mode name (e.g. 'all', 'passive') is ignored. Returns a new list
    of new dicts; the inputs are not modified.
    """
    removers = (
        ('no_ClCa', remove_ClCa_channels),
        ('no_Kv', remove_Kv_channels),
        ('no_HCN', remove_HCN_channels),
    )

    def _reduce_one(params):
        out = params.copy()
        for mode_name, remover in removers:
            if mode_name in remove_modes:
                out = remover(out)
        return out

    return [_reduce_one(params) for params in params_list]

## Generate data

In [None]:
# One dict of parameter variants (all / passive / no_HCN / no_Kv / no_ClCa)
# per optimized parameter set.
all_cone_params = [get_reduced_params_list(params) for params in opt_params_list]

In [None]:
# Flatten to one list of parameter dicts. The order is: for each optimized set,
# its variants in mode order. The grouping step below relies on this ordering.
all_cone_params_list = [params_i for params in all_cone_params for params_i in params.values()]
len(all_cone_params_list)

In [None]:
# Set load = True to reuse the results saved by a previous run instead of re-simulating.
load = False

if not load:
    cone_rec_data_list = cell.run_parallel(sim_params_list=all_cone_params_list, n_parallel=20)
    data_utils.save_var((cone_rec_data_list, all_cone_params_list), os.path.join('data', 'cone_rec_data.pkl'))
else:
    cone_rec_data_list, all_cone_params_list = data_utils.load_var(os.path.join('data', 'cone_rec_data.pkl'))

## Plot data

In [None]:
# Shared time axis; assumes each run_parallel result is a (rec_data, rec_time) pair — TODO confirm.
rec_time = cone_rec_data_list[0][1].copy()

In [None]:
# Mode names in the order produced by get_reduced_params_list.
modes = list(all_cone_params[0].keys())

In [None]:
# Group the simulation results by mode. Results follow all_cone_params_list,
# which cycles through the modes, so index i belongs to modes[i % len(modes)].
# For each result, print the channel-density parameters that are zero.
cone_rec_data_sorted = {mode: [] for mode in modes}

for i, (rec_data_i, params_i) in enumerate(zip(cone_rec_data_list, all_cone_params_list)):
    mode = modes[i%len(modes)]
    
    print('### Mode: ', mode)
    print('\tZero channels:', end='\t')
    for k, v in params_i.items():
        if 'cd_' in k and v == 0:
            print(k, end=',')
    
    cone_rec_data_sorted[mode].append(rec_data_i[0])
    print()

### Plot traces

In [None]:
# Membrane-potential traces for every mode.
plot_Vm(rec_data_sorted=cone_rec_data_sorted, rec_time=rec_time)

In [None]:
# Release-rate traces for every mode.
plot_rate(rec_data_sorted=cone_rec_data_sorted, rec_time=rec_time)

In [None]:
# Deviation of every mode from the full ('all') model.
compare_to_full_model(rec_data_sorted=cone_rec_data_sorted, rec_time=rec_time)

### Plot loss

In [None]:
# Evaluate the cone-optimization loss on every trace and plot it per mode.
cone_losses = compute_losses(loss=cell_loss, rec_data_sorted=cone_rec_data_sorted)
plot_losses(cone_losses)

# Remove unnecessary channels

In [None]:
# Choose which channel groups can be dropped and build the reduced parameter sets.
cone_remove_modes = get_modes_to_remove(losses=cone_losses)
reduced_cone_params_list = reduce_params(params_list=opt_params_list, remove_modes=cone_remove_modes)

In [None]:
# Simulate the reduced parameter sets, or set load = True to reuse the results
# (and the parameter sets they were produced with) from a previous run.
load = False

if not load:
    reduced_cone_rec_data_list = cell.run_parallel(sim_params_list=reduced_cone_params_list, n_parallel=20)
    data_utils.save_var(
        (reduced_cone_rec_data_list, reduced_cone_params_list),
        os.path.join('data', 'reduced_cone_rec_data_list.pkl')
    )
else:
    # Unpack into `reduced_cone_params_list` (correct spelling) so the loaded
    # parameters are the ones used by the CBC-stimulus simulations later on.
    reduced_cone_rec_data_list, reduced_cone_params_list = \
        data_utils.load_var(os.path.join('data', 'reduced_cone_rec_data_list.pkl'))

In [None]:
def print_loss_reduced(reduced_rec_data_list, rec_data_sorted):
    """Print the total loss of each reduced model next to its full-model counterpart.

    Reduced results are paired positionally with ``rec_data_sorted['all']``.
    Each reduced entry's first element holds the recording (columns 'rate Cone'
    and 'Vm 0'). The loss is computed with the module-level ``cell_loss``.
    """
    def _total_loss(trace):
        return cell_loss.calc_loss(rec_data={
            'rate': trace['rate Cone'].values, 'Vm': trace['Vm 0'].values
        })['total']

    for reduced_entry, full_trace in zip(reduced_rec_data_list, rec_data_sorted['all']):
        loss_reduced = _total_loss(reduced_entry[0])
        loss_all = _total_loss(full_trace)
        print(
            f"All params: {loss_all:.4f}"
            f" vs. reduced: {loss_reduced:.4f}"
            f" Difference {loss_reduced - loss_all:.6f}"
        )

In [None]:
# Compare the total loss of the reduced model with that of the full model.
print_loss_reduced(reduced_rec_data_list=reduced_cone_rec_data_list, rec_data_sorted=cone_rec_data_sorted)

## Simulate CBC light stimulus

In [None]:
# Load the stimulus used for the CBC optimization and plot it.
cbc_stimulus = pd.read_csv(os.path.join(data_folder, 'Franke2017_stimulus_time_and_amp_corrected.csv'))
cbc_stimulus.plot(x='Time', figsize=(12,2));

In [None]:
# Switch the cell to the CBC stimulus and write the corresponding retsim stimulus file.
cell.retsim_stim_file_base = 'Light_stimulus_optimize_CBCs'
cell.set_stim(cbc_stimulus)
cell.create_retsim_stim_file()

In [None]:
cell.rec_type = 'test'
# Simulate the reduced parameter sets first, then the full optimized ones, on the CBC stimulus.
cbc_stim_response = cell.run_parallel(sim_params_list=reduced_cone_params_list+opt_params_list, n_parallel=4)

In [None]:
# Assumes a single optimized parameter set: response 0 is the reduced model,
# response 1 the full model. Left: both release rates; right: their difference.
fig, axs = plt.subplots(1,2,figsize=(12,3))
axs[0].plot(cbc_stim_response[0][0]['rate Cone'])
axs[0].plot(cbc_stim_response[1][0]['rate Cone'])
axs[1].set_title('Difference')
axs[1].plot(cbc_stim_response[0][0]['rate Cone'] - cbc_stim_response[1][0]['rate Cone'])
plt.show()

# Save all data for figure

In [None]:
# Add the reduced model's traces as an extra 'minimal' mode for the figure data.
cone_rec_data_sorted['minimal'] = [rec_data[0] for rec_data in reduced_cone_rec_data_list]

In [None]:
# Persist traces per mode and the shared time axis for the figure notebook.
data_utils.save_var(cone_rec_data_sorted, os.path.join('data', 'cone_data_sorted.pkl'))
data_utils.save_var(rec_time, os.path.join('data', 'rec_time.pkl'))

# Create retsim files for CBC optimization

In [None]:
# Select which parameters to remove. e.g. cd_ClCa for the paper cone.
# NOTE: the removed channel is hard-coded here; check it against
# cone_remove_modes before reusing this notebook for a different cone.
final_cone_params = final_model_output['params'].copy()
final_cone_params['cd_ClCa'] = 0.0

In [None]:
# Full parameter set for the retsim files: the defaults, overridden by the final
# cone parameters. Judging by their names, the two helpers add units and
# adaptive-coupling parameters — see retsim_cells for details.
retsim_params = cell.params_default.copy()
retsim_params.update(final_cone_params)
retsim_params = cell.add_units_to_params(retsim_params)
retsim_params = cell.add_adaptive_cpl_params(retsim_params)
cell.print_params(retsim_params)

In [None]:
# Run the final cone parameters once as a sanity check.
rec_data = cell.run(sim_params=final_cone_params, verbose=False)

## Save to file

Save the optimized parameters to retsim files. <br>
If you want to use them in other experiments, read user information below.

In [None]:
# Folder holding the retsim parameter files of the optimized cone.
opt_cone_retsim_folder = os.path.join(opt_cone_folder, 'retsim')

In [None]:
# Show the available files.
os.listdir(opt_cone_retsim_folder)

In [None]:
# Template files from the cone optimization (must exist in opt_cone_retsim_folder)
# and the names under which the files with optimized values are written.
inputfiles = [
    'chanparams_cone_optimize_cone.n',
    'nval_cone_optimize_cone.n',
    'dens_cone_optimize_cone.n',
]

available_files = os.listdir(opt_cone_retsim_folder)
assert all(name in available_files for name in inputfiles)

outputfiles = [
    'chanparams_cone_optimized.n',
    'nval_cone_optimized.n',
    'dens_cone_optimized.n',
]

In [None]:
import update_retsim_param_files

# Write each input template to the matching output file in the '_optimized'
# subfolder, filling in the values from retsim_params (by find-and-replace, per
# the helper's name — see update_retsim_param_files).
update_retsim_param_files.find_and_replace_in_files(
    inputfiles=inputfiles,
    input_folder=opt_cone_retsim_folder,
    outputfiles=outputfiles,
    output_folder=os.path.join(opt_cone_retsim_folder, '_optimized'),
    params=retsim_params,
)

## README: How to use optimized cone for CBC experiments

To use these files for the CBC experiments, you have to do the following manually.<br>
However, if you have not run the full inference (e.g. if you used the test mode), skip the following steps, as the files are already prepared.

- Add the cone-density file to the retsim folder, and state the file's name as a parameter in your CBC experiments.
- Copy the information from the chanparams-file and add it to the CBC-chanparams-file you want to use. Make sure that channels which are used but have no optimized tau or off have values of 1 and 0, respectively (i.e. the defaults).
- Copy the information from the nval-file and add it to the CBC-nval-file you want to use.
- Make sure you don't overwrite optimization variables of the CBC files.
- Make sure you add the information to the correct (or both) CBC types (ON=dbp1 and OFF=hbp1)
- Use the test_cone function in your CBC notebook to test whether the cone produces the output that was optimized.