# Simulate biphasic current pulses on CBCs

Use experimentally recorded currents in COMSOL to generate voltages.

## <font color='red'> Select mode: run_all  / load_only / test</font>

- *run_all*
    - Stimulate all cells, takes a while.
- *load_only*
    - Loads previously simulated results and only simulates what is missing.
- *test*
    - Quick run on a reduced set (2x2 electrodes and a few current indices only).

In [None]:
# Choose one of: 'test', 'run_all', 'load_only'.
simulation_mode = 'load_only'

# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import warnings
from multiprocessing import Pool
import itertools

import os
import sys

In [None]:
# Make the shared helper modules in ../_pythoncode importable (searched first).
pythoncodepath = os.path.abspath(os.path.join('..', '_pythoncode'))
sys.path = [pythoncodepath, *sys.path]

import importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils
import comsol_utils

# Create Cell and stimulate

In [None]:
# Simulation time window (stimulus times are plotted as 1e3*Time in ms, so presumably seconds).
t_rng = (0, 0.04)
# Pre-simulation duration passed to the cells (NOTE(review): units not shown here — confirm).
predur = 20.0
# Offset added to COMSOL stimulus times before the pulse starts.
rec_predur = 0.02

## Select optimized cells

In [None]:
cbc_folder = os.path.join('..', 'step2a_optimize_cbc')

# Optimization output folder for each cell polarity.
_optim_data_folder = os.path.join(cbc_folder, 'optim_data')
cell2folder = {
    'ON': os.path.join(_optim_data_folder, 'optimize_ON_submission2'),
    'OFF': os.path.join(_optim_data_folder, 'optimize_OFF_submission2'),
}

## Create cells


In [None]:
import retsim_cells
importlib.reload(retsim_cells)

# Arguments shared by both bipolar-cell models.
kwargs = dict(
    stim_type='Vext',
    make_cones=False,
    t_rng=t_rng,
    predur=predur,
    expt_base_file=os.path.join(cbc_folder, 'retsim_files', 'expt_CBC_base.cc'),
    nval_file='nval_optimize_CBCs.n',
    retsim_path=os.path.abspath(os.path.join('..', 'NeuronC', 'models', 'retsim')) + '/',
)

# Cell-specific model files, keyed by polarity.
_cell_specs = {
    'ON': dict(
        bp_type='CBC5o',
        expt_file='Vext_thresholds_ON',
        bp_densfile='dens_CBC5o_optimize_ON.n',
        chanparams_file='chanparams_CBC5o_optimize_ON.n',
        compfile='rd_ON.csv',
        comsol_compfile='rd_ON.csv',
        retsim_stim_file_base='Vext_thresholds_ON_rd',
    ),
    'OFF': dict(
        bp_type='CBC3a',
        expt_file='Vext_thresholds_OFF',
        bp_densfile='dens_CBC3a_optimize_OFF.n',
        chanparams_file='chanparams_CBC3a_optimize_OFF.n',
        compfile='rd_OFF.csv',
        comsol_compfile='rd_OFF.csv',
        retsim_stim_file_base='Vext_thresholds_OFF_rd',
    ),
}

# Create cells.
ON_cell = retsim_cells.CBC(**_cell_specs['ON'], **kwargs)
OFF_cell = retsim_cells.CBC(**_cell_specs['OFF'], **kwargs)

cells = [ON_cell, OFF_cell]

In [None]:
def reset_cells():
    """Restore the shared recording/time-step/time-window settings on every cell in `cells`."""
    for c in cells:
        c.rec_type = 'optimize'
        c.set_n_cones = 0

        # Simulation and synapse steps, then recording and stimulus steps.
        c.sim_dt = c.syn_dt = 1e-6
        c.rec_dt = c.stim_dt = 1e-5

        c.update_t_rng(t_rng)
        c.predur = predur


reset_cells()

## Set parameters

### Defaults and units

In [None]:
for cell_name, cell in (('ON', ON_cell), ('OFF', OFF_cell)):
    folder = cell2folder[cell_name]

    # Default parameter values and their units from the optimization run.
    cell.params_default = data_utils.load_var(os.path.join(folder, 'cell_params_default.pkl'))
    cell.params_unit = data_utils.load_var(os.path.join(folder, 'cell_params_unit.pkl'))

    # Merge in the values stored in final_cpl_dict.pkl, then fix the temperature
    # (NOTE(review): 'set_tempcel' presumably in degrees Celsius — confirm).
    cell.params_default.update(data_utils.load_var(os.path.join(folder, 'final_cpl_dict.pkl')))
    cell.params_default['set_tempcel'] = 33.5

### Optimized parameters

In [None]:
N_param_sets = 5

# For each cell, keep the N_param_sets optimization samples with the lowest total loss.
cell2params_list = []
for cell in cells:
    polarity = 'OFF' if cell.is_OFF_bp else 'ON'
    samples = data_utils.load_var(
        os.path.join(cell2folder[polarity], 'post_data', 'post_model_output_list.pkl'))
    total_losses = [sample['loss']['total'] for sample in samples]
    d_sort_idxs = np.argsort(total_losses)
    cell2params_list.append([samples[idx]['params'] for idx in d_sort_idxs[:N_param_sets]])

### Prepare cells

In [None]:
# Create c++ files.
# Write the retsim experiment file for each cell. The empty node lists presumably
# mean no cone connections are specified (cones are disabled via make_cones=False) — confirm.
ON_cell.create_retsim_expt_file(verbose=False, on2cone_nodes=[])
OFF_cell.create_retsim_expt_file(verbose=False, off2cone_nodes=[])
# Compile c++ files.
# IPython shell escape. `cell` is whatever the previous loop left bound; both cells
# were built with the same retsim_path (shared kwargs), so either one works here.
!(cd {cell.retsim_path} && make)

In [None]:
# Rotate each cell, initialise retsim and keep the returned images.
ON_cell.set_rot(mxrot=-90, myrot=0)
ON_im = ON_cell.init_retsim(verbose=False, print_comps=True, update=True)

OFF_cell.set_rot(mxrot=-90, myrot=60)
OFF_im = OFF_cell.init_retsim(verbose=False, print_comps=True, update=True)

# Show both images side by side without axes.
fig, axs = plt.subplots(1, 2, figsize=(14, 8))
for ax, im in zip(axs, (ON_im, OFF_im)):
    ax.imshow(im)
    ax.axis('off')
plt.show()

# EQ files

In [None]:
def set_zero_stim(cell):
    """Give `cell` an all-zero stimulus spanning t_rng and write it as retsim stimulus 0.

    The stimulus table has a 'Time' column (the unique values of 0, t_rng[0], t_rng[1])
    plus one zero column 'C<i>' per bipolar-cell compartment.
    """
    time = np.unique([0, t_rng[0], t_rng[1]])
    columns = {'Time': time}
    for comp_idx in range(cell.n_bc_comps):
        columns['C' + str(comp_idx)] = np.zeros(time.size)
    cell.set_stim(pd.DataFrame(columns))
    cell.create_retsim_stim_file(stim_idx=0)

In [None]:
def run_cell_with_params(cell, sim_params, verbose=False):
    """Run `cell` once with `sim_params` on stimulus 0 and return the first element of the result."""
    outputs = cell.run(
        sim_params=sim_params,
        reset_retsim_stim=False,
        stim_idx=0,
        plot=False,
        verbose=verbose,
    )
    return outputs[0]

## Set EQ filenames

In [None]:
# Name one EQ file per (cell type, parameter-set index).
for cell, cell_params in zip(cells, cell2params_list):
    for idx, params in enumerate(cell_params):
        params['eqfile'] = f"{cell.bp_type}_thresh_params_idx_{idx}.eq"

## Create EQ files

In [None]:
allow_skip_eqs = True  # Skip EQ generation when every EQ file already exists.

# Check whether each (cell, parameter set) already has its EQ file in retsim_path.
all_eqs_exist = True
for cell, cell_params in zip(cells, cell2params_list):
    existing_files = set(os.listdir(cell.retsim_path))
    if any(params['eqfile'] not in existing_files for params in cell_params):
        all_eqs_exist = False
        break

# Stays None when generation is skipped. The original `eq_rec_data_list is None`
# was a no-op comparison, which left the name unbound and broke the plotting cell.
eq_rec_data_list = None

if (not allow_skip_eqs) or (not all_eqs_exist):
    reset_cells()

    # Configure a save-EQ run (save_eq=1, load_eq=0) with a zero stimulus.
    for cell in cells:
        cell.params_default['run_predur_only'] = 0
        cell.params_default['load_eq'] = 0
        cell.params_default['save_eq'] = 1
        cell.params_default['rec_predur'] = rec_predur

        set_zero_stim(cell)

    # One job per (cell, parameter set); the trailing True is the `verbose` flag.
    parallel_params_list = []
    for cell, cell_params in zip(cells, cell2params_list):
        for cell_params_i in cell_params:
            parallel_params_list.append((cell, cell_params_i, True))

    with Pool(processes=20) as pool:
        eq_rec_data_list = pool.starmap(run_cell_with_params, parallel_params_list)

## Set to load EQ and test

Never skip this, it is relatively fast and a good sanity check.

In [None]:
reset_cells()

# Configure a load-EQ run (save_eq=0, load_eq=1) with a zero stimulus.
for cell in cells:
    cell.rec_type = 'optimize'
    cell.params_default.update({
        'run_predur_only': 0,
        'save_eq': 0,
        'load_eq': 1,
        'rec_predur': rec_predur,
    })
    set_zero_stim(cell)

# One (cell, parameter set, verbose) job per optimized parameter set.
parallel_params_list = [
    (cell_j, params, True)
    for cell_j, cell_params in zip(cells, cell2params_list)
    for params in cell_params
]

with Pool(processes=20) as pool:
    load_eq_rec_data_list = pool.starmap(run_cell_with_params, parallel_params_list)

In [None]:
import plot_eq_samples
importlib.reload(plot_eq_samples);

# Plot freshly generated EQ recordings (only if some were made) and the reloaded ones.
if eq_rec_data_list:
    plot_eq_samples.plot_eq_rec_data(eq_rec_data_list, parallel_params_list)
plot_eq_samples.plot_eq_rec_data(load_eq_rec_data_list, parallel_params_list)

# Select simulation parameters

In [None]:
# The mode is chosen in the first cell of the notebook. Do not reassign it here,
# which would silently override that choice; fall back to 'load_only' only if unset.
simulation_mode = globals().get('simulation_mode', 'load_only')
if simulation_mode not in ('test', 'run_all', 'load_only'):
    raise ValueError(f'Unknown simulation_mode: {simulation_mode!r}')

In [None]:
prefixs = ['ON', 'OFF']  # Prefixes for ON and OFF

# Electrode configurations and current indices for which stimuli are generated.
if simulation_mode != 'test':
    AxA_list = ['1x1', '2x2', '4x4', '10x10']
    j_list = np.arange(0, 17)
else:
    AxA_list = ['2x2']
    j_list = np.array([0, 1, 2, 3, 5, 8, 10])

## Create xy-positions to simulate

In [None]:
comsol_zo = 30  # NOTE(review): not referenced elsewhere in the visible notebook.

# Lateral offsets of the simulated cell copies; every copy lies on the x-axis (dy = 0).
dx_list = [0, 70, 140, 210, 280, 500]
dy_list = np.zeros(len(dx_list))
N_cells = len(dx_list)

# One (dx, dy) row per cell position.
dxdy_list = np.stack([dx_list, dy_list]).T

# Show the cell positions.
plt.figure(1, figsize=(3, 3))
plt.plot(dxdy_list[:, 0], dxdy_list[:, 1], '*')
plt.show()

# COMSOL

In [None]:
# Prepare output folders.
for prefix in prefixs:
    for AxA in AxA_list:
        data_utils.make_dir(os.path.join(AxA, 'comsol_Vext', prefix))

## Create compartment files for COMSOL

In [None]:
# Clear the 'Neurons' folder (force=True) before regenerating compartment files;
# the morphology step below reads 'Neurons/<prefix>_dx..._dy....n' files.
data_utils.clean_folder('Neurons', force=True)

In [None]:
import comsol_comp_utils
importlib.reload(comsol_comp_utils)

# Write one COMSOL compartment file per (cell, position), named <prefix>_dx<dx>_dy<dy>.n.
for cell in cells:
    print(cell.bp_type)
    prefix = 'OFF' if cell.is_OFF_bp else 'ON'

    for dx, dy in dxdy_list:
        # dx/dy are floats here, so the names contain e.g. '0.0'.
        cell.comsol_compfile = prefix + '_dx' + str(dx) + '_dy' + str(dy) + '.n'
        print('\t', cell.comsol_compfile)

        comsol_comp_utils.center_xy_region(cell=cell, region='R1')
        comsol_comp_utils.create_comsol_comp_file(
            cell=cell, z_soma=30, verbose=False, x0=dx, y0=dy, copy_to_comsol=False,
        )

## Get COMSOL output

- The notebook creates a single file containing the compartment positions of all shifted copies of a single cell.
- The user then has to run COMSOL manually.  $\Rightarrow$ **COMSOL is required.**
- The notebook reads the COMSOL output and moves the data somewhere else.

If you don't have COMSOL or don't want to run it, skip this step.

Otherwise you will have to do this step twice, once for each cell. Do the following:

- Select a cell by setting cell to ON_cell or OFF_cell in this script. (see script below)
- Open and run the COMSOL files for all NxN configurations you want to have.
- Copy the output and save it for the cell with the script below.
- Repeat for other cell.

In [None]:
# Choose the cell whose COMSOL output is processed next (ON_cell or OFF_cell).
cell = ON_cell  # alternative: cell = OFF_cell

###### Create single morphology file for COMSOL

In [None]:
from shutil import move

prefix = 'OFF' if cell.is_OFF_bp else 'ON'

# Concatenate the compartment positions of every shifted copy of `cell`
# into the single file used as COMSOL input.
cells_morph = []
for dxdy in dxdy_list:
    cell_morph = pd.read_csv(
        'Neurons/' + prefix + '_dx' + str(dxdy[0]) + '_dy' + str(dxdy[1]) + '.n',
        # `delim_whitespace=True` is deprecated since pandas 2.2; sep=r'\s+' is equivalent.
        sep=r'\s+', names=['x', 'y', 'z'],
    )

    # Each file must hold exactly one row per compartment of the cell model.
    assert cell_morph.shape[0] == cell.n_bc_comps + cell.n_cone_comps
    cells_morph.append(cell_morph)

n_cells = len(cells_morph)

cells_morph = pd.concat(cells_morph, axis=0, ignore_index=True)

cells_morph.to_csv('_all_neurons.csv', index=None, header=None)

##### <font color='red'> Open COMSOL files and extract Vext. </font>

In [None]:
# Pause until the user has (re)run COMSOL; fixes the "up to data" typo in the prompt.
input('Confirm that COMSOL outputs are up to date.')

##### Move all files to cell specific folder

In [None]:
# Move each configuration's COMSOL 'Vext*' exports into the subfolder of the selected cell.
for AxA in AxA_list:
    print(AxA)

    src_dir = AxA + '/comsol_Vext/'
    Vext_files = [name for name in os.listdir(src_dir) if name.startswith('Vext')]
    print(Vext_files)

    for Vext_file in Vext_files:
        move(src_dir + Vext_file, src_dir + prefix + '/' + Vext_file)

    print()

##### <font color='red'>Repeat for other cell.</font>

In [None]:
# Pause until the Vext exports of both the ON and the OFF cell have been moved.
input('Confirm Vext for both cells was extracted.')

# Make retsim stimuli

The [last step](#COMSOL) created extracellular voltages $V_{ex}$ for every compartment in COMSOL.
If you skipped it you can also use the precomputed $V_{ex}$ values.

This step translates $V_{ex}$ into retsim stimuli.

In [None]:
# Create (and empty) one stimulus folder per polarity, configuration and current index.
for prefix in ['OFF', 'ON']:
    for AxA in AxA_list:
        for j in j_list:
            stim_folder = os.path.join(AxA, 'comsol_Vext', prefix, 'j' + str(j))
            data_utils.make_dir(stim_folder)
            data_utils.clean_folder(stim_folder, verbose=False, force=True)

In [None]:
def make_stimuli_for_single_cells(AxA, j, prefix, n_comps):
    """Split one COMSOL Vext export into per-position retsim stimulus files.

    Reads ``<AxA>/comsol_Vext/<prefix>/Vext_k_<j>.csv``. Its first column is 'Time';
    the remaining columns hold ``n_comps`` compartment voltages for each cell copy,
    stored copy after copy. One CSV per copy is written to
    ``<AxA>/comsol_Vext/<prefix>/j<j>/dx<dx>_dy<dy>.stim``.

    Each trace gets a zero row at t=0 and two zero rows after the pulse. The
    COMSOL times are shifted by ``rec_predur``. The result is closed 1 ms after the
    pulse and at ``t_rng[1]``.

    Parameters
    ----------
    AxA : electrode configuration label, e.g. '2x2'.
    j : current index of the COMSOL export.
    prefix : 'ON' or 'OFF'.
    n_comps : number of compartments per cell copy.
    """
    outputfolder = os.path.join(AxA, 'comsol_Vext', prefix, 'j' + str(j))

    # Read COMSOL file.
    inputfile = os.path.join(AxA, 'comsol_Vext', prefix, 'Vext_k_' + str(j) + '.csv')
    Vext_cells_raw = comsol_utils.comsol2dataframe(inputfile)
    time = Vext_cells_raw['Time'].values

    # Exclude the leading 'Time' column when counting cells. The old
    # int(shape[1] / n_comps) included it and was wrong for n_comps == 1.
    n_value_cols = Vext_cells_raw.shape[1] - 1
    n_cells = n_value_cols // n_comps

    # Sanity checks.
    assert n_cells == N_cells
    assert n_cells * n_comps == n_value_cols
    # The shifted pulse must end before the simulation window does.
    assert time.max() + rec_predur < t_rng[1]

    # Time axis shared by all copies: 0, shifted COMSOL times, pulse end + 1 ms, window end.
    stim_time = np.concatenate(
        [[0], time + rec_predur, [time.max() + 1e-3 + rec_predur, t_rng[1]]]
    )

    # Split file for every cell and save each to its own stimulus file.
    for cell_i in range(n_cells):
        first_col = 1 + cell_i * n_comps
        Vext_cell = Vext_cells_raw.iloc[:, first_col:first_col + n_comps]
        Vext_cell = pd.concat(
            [
                pd.DataFrame(np.zeros((1, n_comps)), columns=Vext_cell.columns),
                Vext_cell,
                pd.DataFrame(np.zeros((2, n_comps)), columns=Vext_cell.columns),
            ],
            ignore_index=True,
        )
        Vext_cell.insert(0, 'Time', stim_time)

        dxdy = dxdy_list[cell_i]
        filename = 'dx' + str(int(dxdy[0])) + '_dy' + str(int(dxdy[1])) + '.stim'
        Vext_cell.to_csv(os.path.join(outputfolder, filename), index=False, sep=',')

In [None]:
# Split every COMSOL export: cell type x electrode configuration x current index.
for cell, AxA, j in itertools.product(cells, AxA_list, j_list):
    prefix = 'OFF' if cell.is_OFF_bp else 'ON'
    j_label = ('j' + str(j)).ljust(3)
    print(prefix, AxA, j_label, end='\t')
    make_stimuli_for_single_cells(AxA=AxA, j=j, prefix=prefix, n_comps=cell.n_bc_comps)

# Run experiment

In [None]:
# Pause until the retsim stimuli have been regenerated; fixes the "up to data" typo.
input('Confirm that retsim stimuli are up to date. Please press enter.')

In [None]:
# Electrode configurations and current indices that are actually simulated.
if simulation_mode != 'test':
    run_AxA_list = ['1x1', '2x2', '4x4', '10x10']
    run_j_list = [0, 1, 2, 3, 4, 5, 6, 8, 12]
else:
    run_AxA_list = ['2x2']
    run_j_list = np.array([0, 1, 2, 3, 5, 8, 10])

## Helper functions

In [None]:
def set_cell_stims(cell, AxA, j):
    """Load the per-position .stim files of `cell` for (AxA, j) and set them as its stimuli."""
    prefix = 'OFF' if cell.is_OFF_bp else 'ON'
    stims = []
    for dx, dy in dxdy_list:
        stim_file = f'{AxA}/comsol_Vext/{prefix}/j{j}/dx{int(dx)}_dy{int(dy)}.stim'
        stims.append(pd.read_csv(stim_file))
    cell.set_stim(stims)

In [None]:
def run(cell, cell_params, AxA, j, rec_type='optimize'):
    """Simulate `cell` at every position for one electrode configuration and current index.

    Loads the stimuli for (AxA, j) via set_cell_stims. Then, for each parameter set
    in `cell_params`, it merges the set into ``cell.params_default`` and runs all
    position stimuli in parallel.

    Parameters
    ----------
    cell : retsim CBC cell to simulate.
    cell_params : list of parameter dicts, one per optimized parameter set.
    AxA : electrode configuration label, e.g. '2x2'.
    j : current index.
    rec_type : recording type assigned to ``cell.rec_type`` before simulating.

    Returns
    -------
    list
        One ``cell.run_parallel_stimuli`` result per parameter set, in input order.
    """
    print('\tRunning for ' + AxA + ' with j' + str(j))

    # Apply rec_type; it was previously accepted but ignored. The default matches
    # the value reset_cells() sets, so existing calls behave the same.
    cell.rec_type = rec_type

    set_cell_stims(cell, AxA, j)

    sim_list_list = []
    for cell_params_i in cell_params:
        # Compact summary of the parameter set (floats to one significant digit).
        print({k: "{:.1g}".format(v) if isinstance(v, float) else v for k, v in cell_params_i.items()})

        # update() mutates params_default, so these values persist into later runs.
        cell.params_default.update(cell_params_i)
        # NOTE(review): N_cells + 2 parallel workers; the +2 margin is not explained here.
        sim_list_list.append(cell.run_parallel_stimuli(n_parallel=N_cells + 2))

    return sim_list_list

## Plot stimuli

In [None]:
def plot_stim(cell, AxA, j, ax=None):
    """Plot each stimulus of `cell` as the mean of every 10th compartment column (after 'Time').

    Times and voltages are scaled by 1e3 for display in ms / mV. A new figure is
    created when `ax` is None.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(12, 1))
    ax.set_title(AxA + '   j' + str(j))

    for stim in cell.stim:
        mean_vext = stim.iloc[:, 1::10].mean(axis=1)
        ax.plot(1e3 * stim['Time'], 1e3 * mean_vext)

    ax.set_xlabel('Time [ms]')
    ax.set_ylabel('Vext [mV]')
    plt.show()

In [None]:
# Visual check of the ON cell's stimuli for every configuration that will be simulated.
for AxA in run_AxA_list:
    for j in run_j_list:
        set_cell_stims(ON_cell, AxA, j)
        plot_stim(ON_cell, AxA, j)

## Run

In [None]:
# 'load_only' reuses the stored submission results; every other mode writes fresh data.
overwrite_bc_data = (simulation_mode != 'load_only')
bc_folder = 'bc_data_submission2' if (simulation_mode == 'load_only') else 'bc_data'
data_utils.make_dir(bc_folder)
# Report the flag actually used below (the undefined `allow_overwriting_bc_data` raised NameError).
print('Folder:', bc_folder, '  --> Gen new data:', overwrite_bc_data)

In [None]:
# Save the simulated positions, or verify that stored results used the same positions.
dxdy_list_filename = f'{bc_folder}/dxdy_list.pkl'
if os.path.isfile(dxdy_list_filename) and not overwrite_bc_data:
    # The comparison was previously evaluated and discarded (missing `assert`).
    # array_equal also catches shape mismatches instead of erroring on broadcast.
    assert np.array_equal(data_utils.load_var(dxdy_list_filename), dxdy_list)
else:
    data_utils.save_var(dxdy_list, dxdy_list_filename)

In [None]:
# Save the simulated current indices, or verify that stored results used the same ones.
run_j_list_filename = f'{bc_folder}/run_j_list.pkl'
if overwrite_bc_data or not os.path.isfile(run_j_list_filename):
    data_utils.save_var(run_j_list, run_j_list_filename)
else:
    assert np.all(data_utils.load_var(run_j_list_filename) == run_j_list)

In [None]:
# Simulate (or load and validate) one result file per cell and electrode configuration.
for cell, cell_params in zip(cells, cell2params_list):
    prefix = 'OFF' if cell.is_OFF_bp else 'ON'

    for AxA in run_AxA_list:
        filename = f'{bc_folder}/sim_{prefix}_{AxA}.pkl'
        print(filename)

        if overwrite_bc_data or not os.path.isfile(filename):
            sim_list = [run(cell=cell, cell_params=cell_params, AxA=AxA, j=j) for j in run_j_list]
            data_utils.save_var(sim_list, filename)
        else:
            print('File already exists. Will not be overwritten.')
            sim_list = data_utils.load_var(filename)
            # Nesting produced by run(): [current index][parameter set][position].
            assert len(sim_list) == len(run_j_list)
            assert all(len(per_j) == N_param_sets for per_j in sim_list)
            assert all(len(per_params) == len(dxdy_list)
                       for per_j in sim_list for per_params in per_j)

        del sim_list

        print()

# Export meta data

In [None]:
# Look up each cell's parameter sets by its position in `cells`. list.index replaces
# np.argwhere(np.asarray(cells) == cell), which relied on numpy object arrays.
ON_cell_params = cell2params_list[cells.index(ON_cell)]
OFF_cell_params = cell2params_list[cells.index(OFF_cell)]

# 'b_rrp' of each selected parameter set
# (NOTE(review): presumably the readily releasable pool size — confirm).
ON_rrps = np.asarray([params['b_rrp'] for params in ON_cell_params])
OFF_rrps = np.asarray([params['b_rrp'] for params in OFF_cell_params])

data_utils.save_var(ON_rrps, f'{bc_folder}/ON_rrps.pkl')
data_utils.save_var(OFF_rrps, f'{bc_folder}/OFF_rrps.pkl')