# Optimize stimuli to target OFF and ON CBCs

- [Create cells](#Cells)
- [Create equilibrium files for the cells that can be loaded](#Equilibrium-(EQ)-files)
- [Prepare them for optimization](#Prepare-optimization)
- [Generate prior samples for both cells as targets](#Prior-samples)
- [Optimize for both cell types](#Optimize-for-specific-BC)
- [Remove ion channels and test response](#Test-why-cells-respond-differently)

## Select mode: full_inference  / load_only / test

- *full_inference*
    - Runs the whole inference. $\Rightarrow$ **COMSOL is required.**
    - Takes a long time
- *load_only*
    - Will not generate new samples, but loads the data generated for the paper.
- *test*
    - Runs the whole inference, but with fewer samples. $\Rightarrow$ **COMSOL is required.**
    - Illustrates how the inference works, without spending too much CPU power and time.
    - However, it might lead to problems, because too few samples are generated leading to bad inference.
    - Don't use these results in subsequent steps.

In [None]:
# Select exactly one mode (see the description above):
#   'test'           -> reduced run, requires COMSOL; do not reuse the results.
#   'full_inference' -> complete run, requires COMSOL; takes a long time.
#   'load_only'      -> no new simulations; loads the data generated for the paper.
#inference_mode = 'test'
#inference_mode = 'full_inference'
inference_mode = 'load_only'

# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from multiprocessing import Pool

In [None]:
import os
import sys

In [None]:
# Make the shared helper code in ../_pythoncode importable. It is prepended to sys.path so
# that it takes precedence over identically named installed modules.
pythoncodepath = os.path.abspath(os.path.join('..', '_pythoncode'))
sys.path = [pythoncodepath] + sys.path
import importhelper
# Presumably adds the subfolders of _pythoncode to sys.path as well (see importhelper).
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils

In [None]:
import optimize_stimulus_utils as utils
importlib.reload(utils);

# Cells

In [None]:
# Equilibration duration; only used once, when the cells are initialised.
# (Time values are presumably in seconds.)
predur = 20.

# Stimulus protocol: silent lead-in, stimulus, silent tail.
predur_stim, stimdur, postdur_stim = 0.02, 0.04, 0.02

# Whole simulated window, and the part of it recorded after the lead-in.
t_rng = (0, predur_stim + stimdur + postdur_stim)
rec_dur = t_rng[1] - predur_stim

In [None]:
import retsim_cells
importlib.reload(retsim_cells)

# Folder of the preceding CBC parameter-optimization step; its retsim base experiment file
# and optimization results are reused in this notebook.
cbc_folder = os.path.join('..', 'step2a_optimize_cbc')

# Settings shared by both cells. stim_type 'Vext' presumably means stimulation through the
# extracellular voltage; all cells use the same retsim installation (retsim_path).
kwargs = dict(
    stim_type='Vext',
    make_cones=False,
    predur=predur,
    nval_file='nval_optimize_CBCs.n',
    expt_base_file=os.path.join(
        cbc_folder, 'retsim_files', 'expt_CBC_base.cc'),
    retsim_path=os.path.abspath(os.path.join(
        '..', 'NeuronC', 'models', 'retsim')) + '/'
)

# Create ON cell.
# ON cone bipolar cell model of type 'CBC5o' with its own experiment, COMSOL compartment,
# stimulus, density and channel-parameter files.
ON_cell = retsim_cells.CBC(
    bp_type='CBC5o',
    expt_file='optimize_stimulus_ON',
    comsol_compfile='ON_BC_optimize_stimulation.csv',
    retsim_stim_file_base='ON_BC_stimulus_optimize_stimulation',
    bp_densfile='dens_CBC5o_optimize_ON.n',
    chanparams_file='chanparams_CBC5o_optimize_ON.n',
    **kwargs,
)

# Create OFF cell.
# OFF cone bipolar cell model of type 'CBC3a', configured analogously.
OFF_cell = retsim_cells.CBC(
    bp_type='CBC3a',
    expt_file='optimize_stimulus_OFF',
    comsol_compfile='OFF_BC_optimize_stimulation.csv',
    retsim_stim_file_base='OFF_BC_stimulus_optimize_stimulation',
    bp_densfile='dens_CBC3a_optimize_OFF.n',
    chanparams_file='chanparams_CBC3a_optimize_OFF.n',
    **kwargs,
)

# Order matters: cell2params_list and other per-cell lists below follow this order.
cells = [ON_cell, OFF_cell]

In [None]:
def reset_cells():
    """Restore the default recording type, time steps and time window on every cell in ``cells``."""
    time_steps = dict(sim_dt=1e-5, syn_dt=1e-5, rec_dt=1e-5, stim_dt=1e-4)
    for cell in cells:
        cell.rec_type = 'optimize'
        cell.set_n_cones = 0

        for attr_name, step in time_steps.items():
            setattr(cell, attr_name, step)

        cell.update_t_rng(t_rng)
        cell.predur = predur


reset_cells()

## Prepare cells

In [None]:
# Write the c++ experiment file of each cell (compiled in the next cell).
for cell in cells:
    cell.create_retsim_expt_file(verbose=True)

In [None]:
# IPython shell escape. `cell` is the loop variable left over from the previous cell; this is
# fine because both cells were created with the same retsim_path.
!(cd {cell.retsim_path} && make) # Compile c++ files.

## Parameters

In [None]:
# Result folders of the CBC parameter optimization, keyed by cell polarity.
cell2folder = {
    polarity: os.path.join(cbc_folder, 'optim_data', f'optimize_{polarity}_submission2')
    for polarity in ('ON', 'OFF')
}

### Defaults and units

In [None]:
# Load each cell's default parameters and units, then override the defaults with the
# entries stored in final_cpl_dict.pkl.
for cell, cell_name in zip([ON_cell, OFF_cell], ['ON', 'OFF']):
    folder = cell2folder[cell_name]
    cell.params_default = data_utils.load_var(os.path.join(folder, 'cell_params_default.pkl'))
    cell.params_unit = data_utils.load_var(os.path.join(folder, 'cell_params_unit.pkl'))
    final_cpl = data_utils.load_var(os.path.join(folder, 'final_cpl_dict.pkl'))
    cell.params_default.update(final_cpl)

### Posterior parameters

In [None]:
N_param_sets = 5  # Number of lowest-loss posterior samples kept per cell.

# For every cell (in the order of `cells`), keep the parameters of its N_param_sets
# posterior samples with the lowest total loss, best first.
cell2params_list = []
for cell in cells:
    optim_folder = cell2folder['OFF' if cell.is_OFF_bp else 'ON']
    samples = data_utils.load_var(
        os.path.join(optim_folder, 'post_data', 'post_model_output_list.pkl'))
    total_losses = [sample['loss']['total'] for sample in samples]
    d_sort_idxs = np.argsort(total_losses)
    best_params = [samples[idx]['params'] for idx in d_sort_idxs[:N_param_sets]]
    cell2params_list.append(best_params)

In [None]:
# cell2params_list follows the order of `cells`; look up each cell's position directly
# instead of building an object array and comparing it elementwise with numpy.
ON_cell_params = cell2params_list[cells.index(ON_cell)]
OFF_cell_params = cell2params_list[cells.index(OFF_cell)]

# 'b_rrp' of every selected parameter set (presumably the readily releasable pool size).
ON_rrps = np.asarray([params_i['b_rrp'] for params_i in ON_cell_params])
OFF_rrps = np.asarray([params_i['b_rrp'] for params_i in OFF_cell_params])

In [None]:
# Per parameter set: b_rrp plus 8 times the recording duration. This is used below to
# normalise the mean release (presumably the maximum possible release).
cell2maxrelease = {
    bp_type: rrps + rec_dur * 8
    for bp_type, rrps in (('CBC5o', ON_rrps), ('CBC3a', OFF_rrps))
}

## Compartments

In [None]:
# Orient each cell, initialise retsim and keep the returned rendering of its compartments.
ON_cell.set_rot(mxrot=-90, myrot=0)
ON_im = ON_cell.init_retsim(verbose=False, print_comps=True, update=True)

OFF_cell.set_rot(mxrot=-90, myrot=60)
OFF_im = OFF_cell.init_retsim(verbose=False, print_comps=True, update=True)

# Plot both renderings side by side, without axes.
fig, axs = plt.subplots(1, 2, figsize=(14, 8))
for ax, im in zip(axs, (ON_im, OFF_im)):
    ax.imshow(im)
    ax.axis('off')
plt.show()

# Equilibrium (EQ) files

In [None]:
def set_zero_stim(cell):
    """Give ``cell`` a zero stimulus on all of its BC compartments and write it as stimulus 0.

    The stimulus table has a 'Time' column (0 and 10) and one all-zero column
    'C<i>' per compartment i in range(cell.n_bc_comps).
    """
    columns = {'Time': np.array([0, 10])}
    for comp_idx in range(cell.n_bc_comps):
        columns[f'C{comp_idx}'] = np.array([0, 0])
    cell.set_stim(pd.DataFrame(columns))
    cell.create_retsim_stim_file(stim_idx=0)

In [None]:
def run_cell_with_params(cell, sim_params, verbose=False):
    """Run ``cell`` with ``sim_params`` on its already written stimulus 0.

    Kept at module level so it can be pickled for multiprocessing.Pool.starmap.
    Returns the first element of ``cell.run``'s result.
    """
    run_outputs = cell.run(
        sim_params=sim_params,
        reset_retsim_stim=False,
        stim_idx=0,
        plot=False,
        verbose=verbose,
    )
    return run_outputs[0]

## Set names

In [None]:
# Give every parameter set its own EQ file, e.g. 'CBC5o_optimize_stim_pidx_0.eq'.
for cell, cell_params in zip(cells, cell2params_list):
    for pidx, cell_params_i in enumerate(cell_params):
        cell_params_i['eqfile'] = f'{cell.bp_type}_optimize_stim_pidx_{pidx}.eq'

## Create

In [None]:
def prepare_cells_create_eq(predur=None):
    """Reset all cells and configure them to simulate from scratch and save their EQ state.

    predur: if given, overrides each cell's predur for the EQ-creation run.
    Each cell also gets a zero stimulus written as stimulus 0.
    """
    reset_cells()

    for cell in cells:
        if predur is not None:
            cell.predur = predur
        cell.params_default.update(run_predur_only=0, load_eq=0, save_eq=1)
        set_zero_stim(cell)

In [None]:
allow_skip_eqs = True # Skip if they exist.

# Check whether every expected EQ file is already in the retsim folder; report the first
# missing file of each cell.
all_eqs_exist = True
for cell, cell_params in zip(cells, cell2params_list):
    present_files = set(os.listdir(cell.retsim_path))
    for cell_params_i in cell_params:
        if cell_params_i['eqfile'] in present_files:
            continue
        all_eqs_exist = False
        print(cell_params_i['eqfile'], 'and maybe others dont exist')
        break

In [None]:
# Create the EQ files, one run per (cell, parameter set), unless all of them already exist
# and skipping is allowed.
if (not allow_skip_eqs) or (not all_eqs_exist):
    prepare_cells_create_eq()

    parallel_params_list = [
        (cell, cell_params_i, True)
        for cell, cell_params in zip(cells, cell2params_list)
        for cell_params_i in cell_params
    ]

    with Pool(processes=20) as pool:
        eq_rec_data_list = pool.starmap(run_cell_with_params, parallel_params_list)
else:
    # Nothing was simulated. Assign (not compare) so that the plotting cell below can test it.
    eq_rec_data_list = None

## Load

In [None]:
def prepare_cells_load_eq(post_load_predur=0.5):
    """Reset all cells and configure them to start from their saved EQ state.

    post_load_predur: value passed to retsim as 'post_load_predur' (presumably the time
    simulated after loading the EQ file, before the stimulus protocol starts).
    Each cell also gets a zero stimulus written as stimulus 0.
    """
    reset_cells()

    load_settings = {
        'run_predur_only': 0,
        'save_eq': 0,
        'load_eq': 1,
        'post_load_predur': post_load_predur,
        'rec_predur': predur_stim,
        'post_load_predur_timinc': 1e-4,
    }
    for cell in cells:
        cell.params_default.update(load_settings)
        set_zero_stim(cell)

In [None]:
prepare_cells_load_eq()

# One job per (cell, parameter set); the third entry is run_cell_with_params' verbose flag.
parallel_params_list = [
    (cell, cell_params_i, True)
    for cell, cell_params in zip(cells, cell2params_list)
    for cell_params_i in cell_params
]

with Pool(processes=20) as pool:
    load_eq_rec_data_list = pool.starmap(run_cell_with_params, parallel_params_list)

In [None]:
import plot_eq_samples
importlib.reload(plot_eq_samples);

# Plot the EQ-creation runs (only if they were simulated in this session),
# then the runs started from the loaded EQ files.
if eq_rec_data_list:
    plot_eq_samples.plot_eq_rec_data(eq_rec_data_list, parallel_params_list)

plot_eq_samples.plot_eq_rec_data(load_eq_rec_data_list, parallel_params_list)

# Prepare optimization

This and the following steps require the software COMSOL.

Do the following if you want to run the code below:

- Copy the folder *COMSOL2retsim_COMSOL* to a machine that can run COMSOL.
- Open the COMSOL model file used below (`comsol_filename`, i.e. *single_flat_2D_rd.mph*) in COMSOL.
- Run the notebook *comsol2retsim.ipynb* in this folder.
- Make sure that this notebook, *comsol2retsim.ipynb*, and COMSOL have read and write permissions for the folder *COMSOL2retsim_interface*, since this folder is the interface: they communicate by writing files to and reading files from this folder.

If you can run COMSOL and SNPE/retsim on the same machine, this might seem a little cumbersome, but it should still work.

In [None]:
# Stimulus parametrisation: number of free parameters, normalisation flag, spline mode.
n_params_stim, normalize_stim, spline_mode = 4, True, 'cubic'

In [None]:
# load_only reads the published results; the other modes write to a separate folder.
if inference_mode not in ('load_only', 'test', 'full_inference'):
    raise NotImplementedError()

output_folder = ('optimize_stimulus_submission2' if inference_mode == 'load_only'
                 else 'optimize_stimulus')

## Parameters and optimizer

In [None]:
# Parameter container for the n_params_stim stimulus parameters (a DummyParameters object;
# later used as optim.params and for plotting sampling distributions).
from param_funcs import DummyParameters
params = DummyParameters(n_params_stim)

In [None]:
import optimize_stimulus_protocols
importlib.reload(optimize_stimulus_protocols)

# Optimizer over the stimulus parameters. It evaluates both cells, each with all of its
# selected parameter sets (cell2params_list). COMSOL is driven through files exchanged in
# the COMSOL2retsim_interface folder.
optim = optimize_stimulus_protocols.OptimizerStimulusMultiParams(
    cells=cells, cell_params_list=cell2params_list,
    t_rng=t_rng, params=params,
    comsol_filename="single_flat_2D_rd.mph",
    comsol_batch_size=4,
    comsol_samples_prefix='stimulus',
    comsol_samples_suffix='.csv',
    comsol_global_parameters={'tmax': t_rng[1]},  # End of the simulated window, passed as 'tmax'.
    n_parallel=30, n_reps=1,
    predur_stim=predur_stim,
    output_folder=output_folder,
    set_comsol2retsim_folder=os.path.abspath('COMSOL2retsim_interface') # Ensure this is the right folder to the interface.
)

In [None]:
import comsol_comp_utils
importlib.reload(comsol_comp_utils)

# For each cell: position its region 'R1' in xy (presumably centring it) and write the
# compartment file used to exchange positions with COMSOL, with the soma at z=30.
for cell in optim.cells:
    comsol_comp_utils.center_xy_region(cell=cell, region='R1')
    comsol_comp_utils.create_comsol_comp_file(cell=cell, z_soma=30, verbose=True)

In [None]:
class dummy_loss():
    """Placeholder loss that always reports zero; replaced by the target loss further below."""

    def calc_loss(self, rec_data_dict, verbose=False):
        zero = 0.0
        return dict(f0=zero, f1=zero, total=zero)

optim.set_loss(dummy_loss())

## Stimulus

In [None]:
import stim_funcs
importlib.reload(stim_funcs)

# Turns a parameter vector into a stimulus waveform on optim.stim_time
# (create_stimulus(params=...) is used that way below).
# Note: 'stim_mulitplier' is spelled as the StimulusGenerator keyword expects.
stim_generator = stim_funcs.StimulusGenerator(
    stim_time=optim.stim_time, n_params=n_params_stim,
    stim_mode='charge neutral',
    predur=predur_stim, postdur=postdur_stim,
    stim_mulitplier=0.5e-6, normalize_stim=normalize_stim, spline_mode=spline_mode,
)

optim.set_stim_generator(stim_generator)
# Store the generator so the stimuli can be regenerated from their parameters later.
data_utils.save_var(stim_generator, os.path.join('optim_data', optim.output_folder, 'stim_generator.pkl'))

In [None]:
# Initialise optim.rec_data (used later as the loss' init_rec_data). Loading stored data is
# only allowed in load_only mode; eq_exists=True because the EQ files were handled above.
optim.init_rec_data(allow_loading=(inference_mode=='load_only'),
                    force_loading=False, verbose=True, eq_exists=True)

In [None]:
# Plot the initial recordings of both cells.
optim.plot_init_rec_data()

## Prior

In [None]:
from delfi import distribution

# Independent Gaussian prior over the stimulus parameters: mean 0, standard deviation 0.3 each.
mean = np.full(n_params_stim, 0)
std  = np.full(n_params_stim, 0.3)

prior = distribution.Gaussian(m=np.array(mean), S=np.diag(np.array(std)**2))

In [None]:
import plot_sampling_dists
importlib.reload(plot_sampling_dists);

# Plot the 1D marginals of the prior on [-1.5, 1.5] (no posteriors yet).
PP = plot_sampling_dists.SamplingDistPlotter(
    params=params, prior=prior, posterior_list=[],
    lbs=np.full(params.p_N, -1.5), ubs=np.full(params.p_N, 1.5)
)
PP.plot_sampling_dists_1D(plot_peak_lines=False, figsize=(12,8), opt_x=False)

# Tests

In [None]:
# Four parameter vectors drawn from the prior, used to test the pipeline.
test_params_list = [prior.gen(1).flatten() for _ in range(4)]

In [None]:
# Plot the stimulus waveform of every test parameter vector.
for params_i in test_params_list:
    stim_generator.create_stimulus(params=params_i, plot=True);

In [None]:
# Simulate the test stimuli on both cells (requires COMSOL, hence skipped in load_only mode).
if inference_mode != 'load_only':
    model_output_list = optim.run_parallel(params_list=test_params_list, verbose=True)

In [None]:
if inference_mode != 'load_only':
    N_stims = len(model_output_list)

    # One row per test stimulus: column 0 shows the stimulus current, the other columns
    # show the release rate of each cell (one trace per parameter set).
    # squeeze=False keeps axs 2D even for a single stimulus.
    fig, axs = plt.subplots(N_stims, 1 + len(optim.cells), figsize=(12, N_stims*1.4),
                            sharex='col', sharey='col', squeeze=False)

    for idx, model_output in enumerate(model_output_list):
        # The stimulus is the same for all cells, so plot it once (converted to uA).
        axs[idx, 0].plot(
            optim.stim_time,
            optim.stim_generator.create_stimulus(params=model_output['params'])*1e6,
            c='darkred'
        )
        axs[idx, 0].set_ylabel('I (uA)')
        axs[idx, 0].set_title('Loss = {:.2f}'.format(model_output['loss']['total']))

        for cell_idx, cell in enumerate(optim.cells):
            ax = axs[idx, 1 + cell_idx]
            rates = model_output[cell.bp_type]['rate'][0, :, :]

            axs[0, 1 + cell_idx].set_title(cell.bp_type + ' rate')
            ax.plot(optim.rec_data[cell.bp_type]['Time'], rates)
            # Legend entries: mean rate times rec_dur, relative to cell2maxrelease.
            ax.legend(
                [f"{(mrate*rec_dur)/maxrelease:.3f}" for mrate, maxrelease in
                 zip(np.mean(rates, axis=0), cell2maxrelease[cell.bp_type])],
                fontsize=8, loc='upper right'
            )

    plt.tight_layout()

# Prior samples

In [None]:
from os import environ
# Select the GNU threading layer for MKL before the GPU check runs (presumably needed for
# the numerical backend used by delfi).
environ["MKL_THREADING_LAYER"] = "GNU"

import gpu_test
# Stop here if the GPU check does not pass.
assert gpu_test.run(verbose=False)

In [None]:
import delfi_funcs
importlib.reload(delfi_funcs); 

# Optimizer used to draw stimuli from the prior and simulate them. Results are stored in the
# 'prior_*' folders and re-scored for each target cell further below.
delfi_optim_prior = delfi_funcs.DELFI_Optimizer(
    optim=optim, prior=prior, n_parallel=20,
    gen_minibatch=4, scalar_loss=True,
    post_as_truncated_normal=False,
    samples_folder='prior_samples',
    backups_folder='prior_backups',
    snpe_folder='prior_snpe',
)

if inference_mode != 'load_only':

    delfi_optim_prior.init_SNPE(
        verbose                 = False,
        pseudo_obs_dim          = 0,
        pseudo_obs_n            = 1,
        kernel_bandwidth        = 0.25,
        kernel_bandwidth_perc   = 20,
        kernel_bandwidth_min    = 1,
        pseudo_obs_use_all_data = False,
        n_components            = 1,
    )    

    # A single round of 400 samples, limited to 24 hours; starts without initial training data.
    delfi_optim_prior.run_SNPE(
        max_duration_minutes  = 60*24,
        max_rounds            = 1,
        n_samples_per_round   = 400,
        continue_optimization = False,
        load_init_tds         = False,
    )

In [None]:
# Save the 'b_rrp' values of the selected ON and OFF parameter sets in the prior optimizer's retsim folder.
for rrps_values, fname in ((ON_rrps, 'ON_rrps.pkl'), (OFF_rrps, 'OFF_rrps.pkl')):
    data_utils.save_var(rrps_values, os.path.join(delfi_optim_prior.retsim_folder, fname))

# Optimize for specific BC

- Select cell target
- [Define loss function](#Loss)
    - [Update prior samples loss with that loss function](#Update-loss-of-prior-samples)
- [Run inference](#Inference)
- [Show results](#Inference-results)
- [Export data for figure](#Export-data)
- **Repeat these steps for other cell**

In [None]:
# Select target
# Bipolar cell type the stimulus should preferentially activate. Run the following section
# once with 'CBC5o' (ON) and once with 'CBC3a' (OFF).

cell_target = 'CBC5o'
#cell_target = 'CBC3a'

## Loss

In [None]:
import loss_funcs_stimulus
importlib.reload(loss_funcs_stimulus);

# Loss for the selected target cell, computed relative to the initial recordings and using
# cell2maxrelease for normalisation (mode 'maxrelbase', p=1). Replaces dummy_loss on optim.
loss = loss_funcs_stimulus.LossOptimizeStimulationMultiParams(
    init_rec_data=optim.rec_data, cell_target=cell_target, mode='maxrelbase',
    p=1, maxrel=cell2maxrelease
)
optim.set_loss(loss)

### Update loss of prior samples

In [None]:
# Load the prior samples (round 0 of the prior optimizer) and show how many there are.
initial_samples = data_utils.load_var(os.path.join(delfi_optim_prior.samples_folder, 'delfi_samples_r0.pkl'))
initial_samples['loss']['total'].size

In [None]:
def get_rec_data_i(i, samples):
    """Extract sample ``i`` from the stacked ``samples`` as a per-cell rec_data dict.

    For every cell in ``loss.cells`` the 'rate' and 'Vm' traces of sample ``i`` are returned
    with a leading axis of length 1 (the format passed to ``loss.calc_loss`` below).
    """
    rec_data_i = {}
    for cell in loss.cells:
        rec_data_i[cell] = {
            trace: np.expand_dims(samples[cell][trace][i, :, :], 0)
            for trace in ('rate', 'Vm')
        }
    return rec_data_i

In [None]:
# Sanity check: loss of one prior sample (index 4) under the new target loss.
loss.calc_loss(get_rec_data_i(i=4, samples=initial_samples))

In [None]:
from copy import deepcopy

# Re-score every prior sample with the target loss. Only the 'loss' entry is modified, so copy
# it deeply: with a plain dict.copy() the nested loss arrays would be shared, and writing
# into them would silently overwrite initial_samples as well.
updated_initial_samples = initial_samples.copy()
updated_initial_samples['loss'] = deepcopy(initial_samples['loss'])

for smp_idx in range(initial_samples['loss']['total'].size):
    updated_loss = loss.calc_loss(get_rec_data_i(i=smp_idx, samples=initial_samples))

    for key, value in updated_loss.items():
        updated_initial_samples['loss'][key][smp_idx] = value

## Inference

In [None]:
# Optimizer for the selected target, with target-specific sample/backup/SNPE folders.
delfi_optim = delfi_funcs.DELFI_Optimizer(
    optim=optim, prior=prior, n_parallel=20,
    gen_minibatch=4, scalar_loss=True,
    post_as_truncated_normal=False,
    samples_folder=f'target_{cell_target}_samples',
    backups_folder=f'target_{cell_target}_backups',
    snpe_folder=f'target_{cell_target}_snpe',
)

In [None]:
# Store the re-scored prior samples as round 0 of the target optimization.
data_utils.save_var(updated_initial_samples,
                    os.path.join(delfi_optim.samples_folder, 'delfi_samples_r0.pkl'))

### Plot prior samples

In [None]:
# Load only the round-0 file, i.e. the re-scored prior samples saved above. Name it
# explicitly: os.listdir returns entries in arbitrary order, so its first entry may
# belong to a later round once SNPE has run.
samples, n_samples, d_sort_index = delfi_optim.load_samples(
    files=['delfi_samples_r0.pkl'],
    concat_traces=True, list_traces=False, return_sort_idx=True,
    return_n_samples=True, verbose=True
)

# Lowest loss among the prior samples.
d_min_idx = d_sort_index[0]
print('\nd_min = {:.5f}'.format(samples['loss']['total'][d_min_idx]))

In [None]:
# Show the 10 best and 10 worst prior samples for the target.
utils.plot_samples(optim, samples, d_sort_index, cell_target, plot_best_n=10, plot_worst_n=10)

### SNPE

In [None]:
# Control flags for the SNPE run below (they must be defined before use):
# - continue_optimization: resume an earlier run instead of initialising a new SNPE.
# - only_load_data: skip both initialisation and running.
# - load_init_tds: presumably use the re-scored prior samples saved as delfi_samples_r0.pkl
#   as initial training data; confirm against delfi_funcs.DELFI_Optimizer.run_SNPE.
continue_optimization = False
only_load_data = False
load_init_tds = True

if inference_mode != 'load_only':

    if not (continue_optimization or only_load_data):
        delfi_optim.init_SNPE(
            verbose                 = False,
            pseudo_obs_dim          = 0,
            pseudo_obs_n            = 1,
            kernel_bandwidth        = 0.25,
            kernel_bandwidth_perc   = 20,
            pseudo_obs_use_all_data = False,
            n_components            = 2,
        )

    delfi_optim.nn_epochs    = 200
    delfi_optim.nn_minibatch = 20

    if not only_load_data:
        # Two rounds of 100 samples each, limited to 24 hours.
        delfi_optim.run_SNPE(
            max_duration_minutes  = 60*24,
            max_rounds            = 2,
            n_samples_per_round   = 100,
            continue_optimization = continue_optimization,
            load_init_tds         = load_init_tds,
        )

## Inference results

In [None]:
# Load data.
inf_snpes            = data_utils.load_var(os.path.join(delfi_optim.snpe_folder, 'inf_snpes.pkl'))
sample_distributions = data_utils.load_var(os.path.join(delfi_optim.snpe_folder, 'sample_distributions.pkl'))
logs                 = data_utils.load_var(os.path.join(delfi_optim.snpe_folder, 'logs.pkl'))
pseudo_obs           = data_utils.load_var(os.path.join(delfi_optim.snpe_folder, 'pseudo_obs.pkl'))
kernel_bandwidths    = data_utils.load_var(os.path.join(delfi_optim.snpe_folder, 'kernel_bandwidths.pkl'))

# Split prior and posteriors.
prior, posteriors = sample_distributions[0], sample_distributions[1:]

In [None]:
import plot_obs_and_bw
importlib.reload(plot_obs_and_bw)

# Plot pseudo observations and kernel bandwidths, then the training logs.
plot_obs_and_bw.plot(pseudo_obs, kernel_bandwidths)
plot_obs_and_bw.plot_logs(logs)

### Load samples

In [None]:
# Load the target optimization's samples (no file list given, presumably all rounds) and
# report the lowest loss.
samples, n_samples, d_sort_index = delfi_optim.load_samples(
    concat_traces=True, list_traces=False, return_sort_idx=True,
    return_n_samples=True, verbose=True
)

d_min_idx = d_sort_index[0]
print('\nd_min = {:.5f}'.format(samples['loss']['total'][d_min_idx]))

In [None]:
# Show the 6 best and 2 worst samples of the target optimization.
utils.plot_samples(optim, samples, d_sort_index, cell_target, plot_best_n=6, plot_worst_n=2)

### Plot posterior

In [None]:
import plot_sampling_dists
importlib.reload(plot_sampling_dists);

# Plot prior and posterior marginals on [-2, 2] for every stimulus parameter.
PP = plot_sampling_dists.SamplingDistPlotter(
    params=optim.params, prior=delfi_optim.prior, posterior_list=posteriors,
    lbs=np.full(n_params_stim, -2), ubs=np.full(n_params_stim, 2)
)

PP.plot_sampling_dists_1D(
    params=None, plot_peak_lines=False, figsize=(12,8), opt_x=True
)

In [None]:
# Plot parameter correlations of the posterior(s).
PP.plot_correlation()

## Export data

In [None]:
# Target-specific export folder inside the optimizer's general folder.
post_folder = os.path.join(delfi_optim.general_folder, f'post_data_{cell_target}')
data_utils.make_dir(post_folder)
post_folder

### Posterior data

##### Samples from posterior

In [None]:
# Draw 100 parameter vectors from the last posterior and convert each to a stimulus waveform.
post_params = posteriors[-1].gen(100)
post_stimuli = np.full((post_params.shape[0], optim.stim_time.size), np.nan)

for i, post_params_i in enumerate(post_params):
    if i % 10 == 0:
        print(i, end=', ')  # progress indicator
    post_stimuli[i, :] = stim_generator.create_stimulus(params=post_params_i)

In [None]:
# Left: mean +/- one standard deviation of the posterior stimuli. Right: the first 20 stimuli.
post_mean = np.mean(post_stimuli, axis=0)
post_std = np.std(post_stimuli, axis=0)

fig, axs = plt.subplots(1, 2, figsize=(12, 3))
axs[0].plot(optim.stim_time, post_mean)
axs[0].fill_between(optim.stim_time, post_mean - post_std, post_mean + post_std, alpha=0.3)
axs[1].plot(optim.stim_time, post_stimuli[:20, :].T)
plt.show()

In [None]:
# Export the sampled posterior stimuli for the figures.
data_utils.save_var(post_stimuli, os.path.join(post_folder, 'post_sampled_stimuli.pkl'))

### General data

In [None]:
# Recording time axis shifted to start at 0. Subtract out of place: an in-place `-=` would
# also modify optim.rec_data's own 'Time' array, which is used elsewhere (e.g. for plotting).
rec_time = optim.rec_data[optim.cells[0].bp_type]['Time']
rec_time = rec_time - rec_time[0]

In [None]:
# Export the time axes and the stimulus lead-in duration to the general folder.
for value, fname in ((optim.stim_time, 'stim_time.pkl'),
                     (rec_time, 'rec_time.pkl'),
                     (predur_stim, 'predur_stim.pkl')):
    data_utils.save_var(value, os.path.join(delfi_optim.general_folder, fname))

In [None]:
# Show where the general data was written.
print(delfi_optim.general_folder)

### Best samples

In [None]:
n_best_export = 100

# Export the parameters and stimulus waveforms of the lowest-loss samples, best first.
# Clip to the number of available samples so smaller runs (e.g. test mode) do not fail.
n_export = min(n_best_export, len(d_sort_index))

best_stimuli_params = np.full((n_export, n_params_stim), np.nan)
best_stimuli = np.full((n_export, optim.stim_time.size), np.nan)

for i, sample_idx in enumerate(d_sort_index[:n_export]):
    params_i = samples['params'][sample_idx]
    best_stimuli_params[i, :] = params_i
    best_stimuli[i, :] = stim_generator.create_stimulus(params=params_i)

data_utils.save_var(best_stimuli_params, os.path.join(post_folder, 'best_stimuli_params.pkl'))
data_utils.save_var(best_stimuli, os.path.join(post_folder, 'best_stimuli.pkl'))

 #####  **$\Rightarrow$ Repeat for other cell!**

# Test why cells respond differently

In [None]:
# Recording time axis shifted to start at 0. Subtract out of place: an in-place `-=` would
# also modify optim.rec_data's own 'Time' array, which is used elsewhere (e.g. for plotting).
rec_time = optim.rec_data[optim.cells[0].bp_type]['Time']
rec_time = rec_time - rec_time[0]

In [None]:
# Load the best stimulus parameters exported for each target cell type.
cell2best_stimuli = {
    cell.bp_type: data_utils.load_var(os.path.join(
        'optim_data', optim.output_folder, 'post_data_' + cell.bp_type, 'best_stimuli_params.pkl'))
    for cell in cells
}

In [None]:
# The two best stimuli for each target: first CBC5o (ON), then CBC3a (OFF).
stim_params_list = [
    cell2best_stimuli[bp_type][rank]
    for bp_type in ('CBC5o', 'CBC3a')
    for rank in (0, 1)
]

In [None]:
# Plot the selected stimuli (converted to uA). Use a dedicated loop variable: reusing
# `params` would overwrite the DummyParameters object defined earlier under that name.
for stim_params in stim_params_list:
    plt.plot(optim.stim_time, stim_generator.create_stimulus(stim_params)*1e6)
plt.ylabel(r'Current ($\mu A$)');

In [None]:
from copy import deepcopy

# Run the channel-removal experiments on a deep copy of the optimizer and its cells.
test_optim = deepcopy(optim)
# Note: this rebinds the global `cells`, so reset_cells() and prepare_cells_*() now act on
# the copied cells.
cells = test_optim.cells

test_optim.DEBUG = True
test_optim.raw_data_labels = []

In [None]:
# Folder for the stimuli and recordings of the channel-removal experiments.
rm_ch_folder = os.path.join('optim_data', test_optim.output_folder, 'removed_ion_channels')
data_utils.make_dir(rm_ch_folder)

## Generate reference solutions

In [None]:
# Create stimuli
if inference_mode == 'load_only':
    cell2stims = data_utils.load_var(os.path.join(rm_ch_folder, 'cell2stims.pkl'))
else:
    test_optim.set_comsol_input(stim_params_list)
    test_optim.run_comsol(verbose=True)
    test_optim.get_comsol_output(n_stimuli=len(stim_params_list))
    
    cell2stims = {cell.bp_type: deepcopy(cell.stim) for cell in test_optim.cells}
    data_utils.save_var(cell2stims, os.path.join(rm_ch_folder, 'cell2stims.pkl'))
    
# Set stimuli
for cell in test_optim.cells:
    cell.set_stim(cell2stims[cell.bp_type])

In [None]:
# Reference run with unmodified parameters, recorded with rec_type 'heatmap_vm'.
# COMSOL is skipped because the stimuli were already set on the cells.
if inference_mode != 'load_only':
    test_optim.cell_params_list = deepcopy(optim.cell_params_list)
    test_optim.rec_type = 'heatmap_vm'
    ref_rec_data_list_vm = test_optim.run_parallel(
        params_list=stim_params_list, verbose=True, skip_comsol=True)
    data_utils.save_var(ref_rec_data_list_vm, os.path.join(rm_ch_folder, 'ref_rec_data_list_vm.pkl'))
else:
    ref_rec_data_list_vm = data_utils.load_var(os.path.join(rm_ch_folder, 'ref_rec_data_list_vm.pkl'))

In [None]:
# Plot the reference 'heatmap_vm' recordings of both cells (offset removed, shared y-axis).
utils.plot_Vm_of_z(rec_time, ref_rec_data_list_vm, ON_cell, OFF_cell, sharey=True, rm_offset=True)

In [None]:
# Reference run with unmodified parameters, recorded with rec_type 'synapses'; cached on disk.
if inference_mode != 'load_only':
    test_optim.rec_type = 'synapses'
    ref_rec_data_list = test_optim.run_parallel(
        params_list=stim_params_list, verbose=True, skip_comsol=True)
    data_utils.save_var(ref_rec_data_list, os.path.join(rm_ch_folder, 'ref_rec_data_list.pkl'))
else:
    ref_rec_data_list = data_utils.load_var(os.path.join(rm_ch_folder, 'ref_rec_data_list.pkl'))

In [None]:
# Reference synapse recordings: extracellular voltage.
utils.plot_synapses(rec_time, ref_rec_data_list, rec_type='Vext', sharey='row', rm_offset=True)

In [None]:
# Reference synapse recordings: membrane voltage.
utils.plot_synapses(rec_time, ref_rec_data_list, rec_type='Vm', sharey='row', rm_offset=True)

In [None]:
# Reference synapse recordings: calcium (shared y-axis across all panels).
utils.plot_synapses(rec_time, ref_rec_data_list, rec_type='Ca', sharey=True, rm_offset=True)

## Remove ion channels

In [None]:
def set_eq_files(k, v):
    """Prefix the value of the 'eqfile' entry with 'modified_'; return all other values unchanged."""
    return 'modified_' + v if k == 'eqfile' else v

# Same parameter sets as cell2params_list, but pointing at separate 'modified_' EQ files.
test_optim.cell_params_list = [
    [{k: set_eq_files(k, v) for k, v in params.items()} for params in p_list]
    for p_list in cell2params_list
]

In [None]:
import plot_eq_samples
importlib.reload(plot_eq_samples);

def create_and_test_eq_fils(modified_cell_params, predur=20, post_load_predur=2):
    """Create EQ files for ``modified_cell_params``, start runs from them and plot the result.

    modified_cell_params: one list of parameter dicts per cell in ``test_optim.cells``.
    predur: predur used while creating the EQ files.
    post_load_predur: passed to prepare_cells_load_eq for the runs started from the EQ files.

    The reference runs (``load_eq_rec_data_list``) are plotted first, followed by the runs
    with the modified parameters.
    """
    jobs = [
        (cell, cell_params_i, False)
        for cell, cell_params in zip(test_optim.cells, modified_cell_params)
        for cell_params_i in cell_params
    ]

    # 1) Simulate from scratch and save the EQ state of every parameter set.
    prepare_cells_create_eq(predur=predur)
    with Pool(processes=20) as pool:
        pool.starmap(run_cell_with_params, jobs)

    # 2) Start from the new EQ files with a zero stimulus and keep the recordings.
    prepare_cells_load_eq(post_load_predur=post_load_predur)
    with Pool(processes=20) as pool:
        modified_load_eq_rec_data_list = pool.starmap(run_cell_with_params, jobs)

    plot_eq_samples.plot_eq_rec_data(load_eq_rec_data_list, jobs)
    plot_eq_samples.plot_eq_rec_data(modified_load_eq_rec_data_list, jobs)

In [None]:
def stimulate(modified_cell_params):
    """Apply the selected stimuli to test_optim's cells using ``modified_cell_params``.

    Restores each cell's stimulus from ``cell2stims`` and returns the 'synapses' recordings of
    ``test_optim.run_parallel`` for ``stim_params_list`` (COMSOL is skipped).
    """
    test_optim.cell_params_list = modified_cell_params
    test_optim.rec_type = 'synapses'
    for cell in test_optim.cells:
        cell.set_stim(cell2stims[cell.bp_type])
    return test_optim.run_parallel(params_list=stim_params_list, verbose=True, skip_comsol=True)

### All channels

In [None]:
# Baseline: unmodified parameters run through the same EQ-creation and stimulation pipeline.
# These keep the original eqfile names, so their EQ files are re-created with predur=20.
# Results are cached in all_params.pkl.
if inference_mode != 'load_only':
    modified_cell_params = deepcopy(cell2params_list)
    create_and_test_eq_fils(modified_cell_params, predur=20, post_load_predur=0.5)
    modified_rec_data_list = stimulate(modified_cell_params)

    data_utils.save_var((modified_cell_params, modified_rec_data_list),
                        os.path.join(rm_ch_folder, 'all_params.pkl'))
    
else:
    modified_cell_params, modified_rec_data_list =\
        data_utils.load_var(os.path.join(rm_ch_folder, 'all_params.pkl'))

In [None]:
# Reference vs. baseline re-run: membrane voltage, release rate and calcium.
for plot_kwargs in (dict(rec_type='Vm', sharey='row', rm_offset=True),
                    dict(rec_type='rate'),
                    dict(rec_type='Ca', sharey='row', rm_offset=True)):
    utils.plot_synapses(rec_time, ref_rec_data_list, modified_rec_data_list, **plot_kwargs)

### L-Type from OFF-AT

In [None]:
def get_removed_L_at_OFF_list():
    """Return copies of cell2params_list with 'cd_L_at' set to 0.0 where 'cd_T_at' is present.

    Only parameter sets that contain a 'cd_T_at' entry are changed (per the section title,
    presumably those of the OFF cell). Every copy points to a 'modified_' EQ file, so creating
    EQ files for the modified parameters does not overwrite the EQ files of the original ones.
    """

    def modify_params(params_dict):
        params_dict = params_dict.copy()

        if 'cd_T_at' in params_dict:
            params_dict['cd_L_at'] = 0.0
        if 'eqfile' in params_dict:
            params_dict['eqfile'] = 'modified_' + params_dict['eqfile']

        return params_dict

    return [[modify_params(params) for params in p_list] for p_list in cell2params_list]

In [None]:
# Remove 'cd_L_at' (where 'cd_T_at' exists), re-create the EQ files, stimulate; cached in rm_L_at.pkl.
if inference_mode != 'load_only':
    modified_cell_params = get_removed_L_at_OFF_list()
    create_and_test_eq_fils(modified_cell_params, predur=20, post_load_predur=0.5)
    modified_rec_data_list = stimulate(modified_cell_params)

    data_utils.save_var((modified_cell_params, modified_rec_data_list),
                        os.path.join(rm_ch_folder, 'rm_L_at.pkl'))
    
else:
    modified_cell_params, modified_rec_data_list =\
        data_utils.load_var(os.path.join(rm_ch_folder, 'rm_L_at.pkl'))

In [None]:
# Reference vs. 'cd_L_at' removed: membrane voltage, release rate and calcium.
for plot_kwargs in (dict(rec_type='Vm', sharey='row', rm_offset=True),
                    dict(rec_type='rate'),
                    dict(rec_type='Ca', sharey='row', rm_offset=True)):
    utils.plot_synapses(rec_time, ref_rec_data_list, modified_rec_data_list, **plot_kwargs)

### Ca from AT

In [None]:
def get_removed_Ca_at_list():
    """Return copies of cell2params_list with the '_at' Ca channel densities set to 0.0.

    Every key containing 'cd_T_at', 'cd_L_at' or 'cd_P_at' is zeroed. Every copy points to a
    'modified_' EQ file, so creating EQ files for the modified parameters does not overwrite
    the EQ files of the original ones.
    """
    removed_channels = ('cd_T_at', 'cd_L_at', 'cd_P_at')

    def modify_param(k, v):
        if k == 'eqfile':
            return 'modified_' + v
        if any(channel in k for channel in removed_channels):
            return 0.0
        return v

    return [[{k: modify_param(k, v) for k, v in params.items()} for params in p_list] for p_list in cell2params_list]

In [None]:
# Remove all '_at' Ca channels, re-create the EQ files, stimulate; cached in rm_Ca_at.pkl.
if inference_mode != 'load_only':
    modified_cell_params = get_removed_Ca_at_list()
    create_and_test_eq_fils(modified_cell_params, predur=20, post_load_predur=0.5)
    modified_rec_data_list = stimulate(modified_cell_params)
    data_utils.save_var((modified_cell_params, modified_rec_data_list),
                        os.path.join(rm_ch_folder, 'rm_Ca_at.pkl'))
    
else:
    modified_cell_params, modified_rec_data_list =\
        data_utils.load_var(os.path.join(rm_ch_folder, 'rm_Ca_at.pkl'))

In [None]:
# Reference vs. '_at' Ca channels removed: membrane voltage.
utils.plot_synapses(rec_time, ref_rec_data_list, modified_rec_data_list, rec_type='Vm', sharey='row', rm_offset=True)

### T-Type from AT

In [None]:
def get_removed_T_at_list():
    """Return copies of cell2params_list with every key containing 'cd_T_at' set to 0.0.

    Every copy points to a 'modified_' EQ file, so creating EQ files for the modified
    parameters does not overwrite the EQ files of the original ones.
    """

    def modify_param(k, v):
        if k == 'eqfile':
            return 'modified_' + v
        if 'cd_T_at' in k:
            return 0.0
        return v

    return [[{k: modify_param(k, v) for k, v in params.items()} for params in p_list] for p_list in cell2params_list]

In [None]:
# Remove the '_at' T-type channel, re-create the EQ files, stimulate; cached in rm_T_at.pkl.
if inference_mode != 'load_only':
    modified_cell_params = get_removed_T_at_list()
    create_and_test_eq_fils(modified_cell_params, predur=20, post_load_predur=0.5)
    modified_rec_data_list = stimulate(modified_cell_params)

    data_utils.save_var((modified_cell_params, modified_rec_data_list),
                        os.path.join(rm_ch_folder, 'rm_T_at.pkl'))
    
else:
    modified_cell_params, modified_rec_data_list =\
        data_utils.load_var(os.path.join(rm_ch_folder, 'rm_T_at.pkl'))

In [None]:
# Reference vs. '_at' T-type removed: membrane voltage, release rate and calcium.
for plot_kwargs in (dict(rec_type='Vm', sharey='row', rm_offset=True),
                    dict(rec_type='rate'),
                    dict(rec_type='Ca', sharey='row', rm_offset=True)):
    utils.plot_synapses(rec_time, ref_rec_data_list, modified_rec_data_list, **plot_kwargs)

### Channels not important for light response

In [None]:
def get_removed_CaS_list():
    """Return copies of cell2params_list with 'cd_L_s', 'cd_T_s' and 'cd_P_s' set to 0.0.

    Every copy points to a 'modified_' EQ file, so creating EQ files for the modified
    parameters does not overwrite the EQ files of the original ones.
    """
    removed_channels = ('cd_L_s', 'cd_T_s', 'cd_P_s')

    def modify_param(k, v):
        if k == 'eqfile':
            return 'modified_' + v
        if k in removed_channels:
            return 0.0
        return v

    return [[{k: modify_param(k, v) for k, v in params.items()} for params in p_list] for p_list in cell2params_list]

In [None]:
# Remove the '_s' Ca channels, re-create the EQ files, stimulate; cached in rm_CaS.pkl.
if inference_mode != 'load_only':
    modified_cell_params = get_removed_CaS_list()
    create_and_test_eq_fils(modified_cell_params, predur=20, post_load_predur=0.5)
    modified_rec_data_list = stimulate(modified_cell_params)

    data_utils.save_var((modified_cell_params, modified_rec_data_list),
                        os.path.join(rm_ch_folder, 'rm_CaS.pkl'))
    
else:
    modified_cell_params, modified_rec_data_list =\
        data_utils.load_var(os.path.join(rm_ch_folder, 'rm_CaS.pkl'))

In [None]:
# Reference vs. '_s' Ca channels removed: membrane voltage, release rate and calcium.
for plot_kwargs in (dict(rec_type='Vm', sharey='row', rm_offset=True),
                    dict(rec_type='rate'),
                    dict(rec_type='Ca', sharey='row', rm_offset=True)):
    utils.plot_synapses(rec_time, ref_rec_data_list, modified_rec_data_list, **plot_kwargs)

In [None]:
def get_removed_CaS_and_N_list():
    """Return copies of cell2params_list with 'cd_L_s', 'cd_T_s', 'cd_P_s' and 'cd_N' set to 0.0.

    Every copy points to a 'modified_' EQ file, so creating EQ files for the modified
    parameters does not overwrite the EQ files of the original ones.
    """
    removed_channels = ('cd_L_s', 'cd_T_s', 'cd_P_s', 'cd_N')

    def modify_param(k, v):
        if k == 'eqfile':
            return 'modified_' + v
        if k in removed_channels:
            return 0.0
        return v

    return [[{k: modify_param(k, v) for k, v in params.items()} for params in p_list] for p_list in cell2params_list]

In [None]:
# Remove the '_s' Ca channels and 'cd_N', re-create the EQ files, stimulate; cached in rm_CaS_and_N.pkl.
if inference_mode != 'load_only':
    modified_cell_params = get_removed_CaS_and_N_list()
    create_and_test_eq_fils(modified_cell_params, predur=20, post_load_predur=0.5)
    modified_rec_data_list = stimulate(modified_cell_params)

    data_utils.save_var((modified_cell_params, modified_rec_data_list),
                        os.path.join(rm_ch_folder, 'rm_CaS_and_N.pkl'))
    
else:
    modified_cell_params, modified_rec_data_list =\
        data_utils.load_var(os.path.join(rm_ch_folder, 'rm_CaS_and_N.pkl'))

In [None]:
# Reference vs. '_s' Ca channels and 'cd_N' removed: membrane voltage, release rate and calcium.
for plot_kwargs in (dict(rec_type='Vm', sharey='row', rm_offset=True),
                    dict(rec_type='rate'),
                    dict(rec_type='Ca', sharey='row', rm_offset=True)):
    utils.plot_synapses(rec_time, ref_rec_data_list, modified_rec_data_list, **plot_kwargs)

### Passive

In [None]:
def get_passive_list():
    """Return copies of cell2params_list with all channel densities except the '_at' Ca ones zeroed.

    Every key containing 'cd_' is set to 0.0, except 'cd_L_at', 'cd_T_at' and 'cd_P_at'.
    The 'eqfile' entry is prefixed with 'passive' so the original EQ files are kept.
    """
    kept_channels = ('cd_L_at', 'cd_T_at', 'cd_P_at')

    def modify_param(k, v):
        if k == 'eqfile':
            return 'passive' + v
        is_removed_channel = ('cd_' in k) and (k not in kept_channels)
        return 0.0 if is_removed_channel else v

    return [
        [{k: modify_param(k, v) for k, v in params.items()} for params in p_list]
        for p_list in cell2params_list
    ]

In [None]:
# Passive variant (only the '_at' Ca channels kept). It uses a longer predur (30) and
# post_load_predur (1) than the other variants. Results are cached in passive.pkl.
if inference_mode != 'load_only':
    modified_cell_params = get_passive_list()
    create_and_test_eq_fils(modified_cell_params, predur=30, post_load_predur=1)
    modified_rec_data_list = stimulate(modified_cell_params)

    data_utils.save_var((modified_cell_params, modified_rec_data_list),
                        os.path.join(rm_ch_folder, 'passive.pkl'))
    
else:
    modified_cell_params, modified_rec_data_list =\
        data_utils.load_var(os.path.join(rm_ch_folder, 'passive.pkl'))

In [None]:
# Reference vs. passive variant: membrane voltage, release rate and calcium.
for plot_kwargs in (dict(rec_type='Vm', sharey='row', rm_offset=True),
                    dict(rec_type='rate'),
                    dict(rec_type='Ca', sharey='row', rm_offset=True)):
    utils.plot_synapses(rec_time, ref_rec_data_list, modified_rec_data_list, **plot_kwargs)