# Create animations of the CBCs

Create videos of the [OFF](#OFF) and [ON](#ON) cone bipolar cells (CBCs).

Note that this notebook is not optimized for speed and creates several GB of data.

# Imports

In [None]:
import importlib

In [None]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import os
import sys

In [None]:
# Put the shared project code folder (../pythoncode) at the front of the
# module search path, then let importhelper register its sub-folders.
pythoncodepath = os.path.abspath(os.path.join(os.pardir, 'pythoncode'))
sys.path = [pythoncodepath, *sys.path]
import importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils
import plot_utils

# Apply the project's matplotlib defaults before any figure is created.
plot_utils.set_rcParams()

# Helper functions

In [None]:
def merge_rec_data(df1, df2):
    """Merge two recordings of the same simulation into one DataFrame.

    Columns present in both frames must hold identical values (compared
    after casting to float). The result is ``df1.combine_first(df2)``: the
    union of rows and columns, with ``df2`` only filling gaps in ``df1``.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Recordings to merge. Neither input is modified.

    Returns
    -------
    pandas.DataFrame
        The merged recording.

    Raises
    ------
    ValueError
        If any shared column differs between ``df1`` and ``df2``.
    """
    # key=str keeps the ordering well-defined for mixed/non-string labels.
    shared = sorted(set(df1.columns) & set(df2.columns), key=str)
    mismatched = [
        col for col in shared
        if not df1[col].astype(float).equals(df2[col].astype(float))
    ]
    if mismatched:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # str() avoids a TypeError for non-string column labels.
        raise ValueError('Overlap must be equal: ' + ', '.join(map(str, mismatched)))
    # combine_first returns a new frame, so no defensive copies are needed.
    return df1.combine_first(df2)

# Load target and stimulus

In [None]:
# Root folder holding the CBC optimization runs (sibling project step 2a).
cbc_optim_folder = os.path.join(os.pardir, 'step2a_optimize_cbc', 'optim_data')

In [None]:
# Display the available optimization runs, sorted by name.
sorted(os.listdir(cbc_optim_folder))

In [None]:
# Map each CBC polarity to the folder of its submission-2 optimization run.
cell2folder = dict(
    OFF=os.path.join(cbc_optim_folder, 'optimize_OFF_submission2'),
    ON=os.path.join(cbc_optim_folder, 'optimize_ON_submission2'),
)

In [None]:
# Cache folder for simulation recordings and rendered heatmaps.
data_utils.make_dir('cbc_data')

# OFF

## Create cell

In [None]:
# Final output of the OFF optimization; provides 't_rng', 'Stimulus' and 'params'.
final_model_output = data_utils.load_var(
    os.path.join(cell2folder['OFF'], 'post_data', 'final_model_output.pkl'))

In [None]:
# Stimulus and its time range, exactly as used in the final OFF model output.
stim_t_rng, stimulus = final_model_output['t_rng'], final_model_output['Stimulus']

In [None]:
# Pre-simulation duration passed to the cell as `predur`.
predur = 10

# Load parameters.
# Start from the defaults, then override with the optimized parameters and
# finally with the final coupling dictionary (later updates win).
params_default = data_utils.load_var(os.path.join(cell2folder['OFF'], 'cell_params_default.pkl'))
params_default.update(final_model_output['params'])
params_default.update(data_utils.load_var(os.path.join(cell2folder['OFF'], 'final_cpl_dict.pkl')))

params_unit = data_utils.load_var(os.path.join(cell2folder['OFF'], 'cell_params_unit.pkl'))

In [None]:
import retsim_cells
importlib.reload(retsim_cells);

# Build the OFF CBC (bp_type 'CBC3a') with the optimized parameters.
# The initial t_rng is a placeholder; it is replaced via update_t_rng below.
OFF_cell = retsim_cells.CBC(
    bp_type = 'CBC3a',
    predur=predur, t_rng=(1.9,2.3),
    params_default=params_default, params_unit=params_unit,
    stimulus=stimulus, stim_type='Light',
    cone_densfile       = 'dens_cone_optimized_submission2.n',
    bp_densfile         = 'dens_CBC3a_optimize_OFF.n',
    nval_file           = 'nval_optimize_CBCs.n',
    chanparams_file     = 'chanparams_CBC3a_optimize_OFF.n',
    expt_file_list      = ['plot_cell_OFF'],
    expt_base_file_list = ['../step2a_optimize_cbc/retsim_files/expt_CBC_base.cc'],
    retsim_path=os.path.abspath(os.path.join('..', 'neuronc', 'models', 'retsim')) + '/'
)

In [None]:
# Create c++ file.
# NOTE(review): off2cone_nodes presumably lists the bipolar-cell nodes that
# connect to cones; confirm against create_retsim_expt_file.
OFF_cell.create_retsim_expt_file(verbose=False, off2cone_nodes=[686, 1037, 828, 950, 879])

In [None]:
# Compile c++ file.
# IPython shell escape, not plain Python. It runs `make` in the retsim folder
# inside a subshell, so the notebook's working directory is unchanged.
!(cd {OFF_cell.retsim_path} && make)

In [None]:
# Initialise retsim for the compiled OFF cell (verbose output).
OFF_cell.init_retsim(verbose=True)

## Tests

### Test model

In [None]:
# Short smoke-test run over t_rng=(1.9, 2.1) with the 'optimize' recording type.
OFF_cell.update_t_rng((1.9, 2.1))
OFF_cell.rec_type = 'optimize'
_ = OFF_cell.run(plot=True, verbose=False, update_cell_rec_data=True);

### Test cone output

In [None]:
import retsim_cell_tests
importlib.reload(retsim_cell_tests)

# Post-processed output of the separately optimized cone model (reference).
cone_post_data_folder = os.path.join('..', 'step1a_optimize_cones', 'optim_data',
                                     'optimize_cone_submission2', 'post_data')

# Compare this cell's cone output against the reference over t_rng=(1.0, 2.5).
retsim_cell_tests.test_cones(
    OFF_cell, os.path.join(cone_post_data_folder, 'final_model_output.pkl'), t_rng=(1.0,2.5)
);

### Test CBC output

In [None]:
import retsim_cell_tests
importlib.reload(retsim_cell_tests)

# Compare the OFF CBC output against the optimizer's final model output
# over t_rng=(1.0, 2.5). (A single os.path.join suffices; nesting it was redundant.)
retsim_cell_tests.test_CBC(
    OFF_cell, os.path.join(cell2folder['OFF'], 'post_data', 'final_model_output.pkl'),
    t_rng=(1.0,2.5)
);

## Run simulation

In [None]:
# Simulate the full stimulus time range stored with the final model output.
OFF_cell.update_t_rng(stim_t_rng)

### Run

Record rates, membrane voltage and calcium for all compartments.

In [None]:
# True: reuse the cached recording; False: re-run both simulations (slow).
load_data = True

if load_data:
    OFF_cell.rec_data = data_utils.load_var('cbc_data/OFF_cell_rec_data.pkl')
else:
    # Voltage recording.
    OFF_cell.rec_type = 'heatmap_vm'
    _ = OFF_cell.run(plot=True, verbose=True, update_cell_rec_data=True)
    
    # Calcium recording.
    OFF_cell.rec_type = 'heatmap_ca'
    _ = OFF_cell.run(plot=True, verbose=True, update_cell_rec_data=True)
    
    # Fold the calcium data into the 'heatmap_vm' recording so one dataset
    # holds everything; shared columns must agree (checked by merge_rec_data).
    OFF_cell.rec_data['heatmap_vm']['Data'] = \
        merge_rec_data(df1=OFF_cell.rec_data['heatmap_vm']['Data'], df2=OFF_cell.rec_data['heatmap_ca']['Data'])
    
    data_utils.save_var(OFF_cell.rec_data, 'cbc_data/OFF_cell_rec_data.pkl')

### Animate

In [None]:
# Animation frame times: steps of 50e-3 (presumably seconds, i.e. 50 ms)
# from 0 up to the simulated duration of OFF_cell.
t_list = np.arange(0, OFF_cell.get_t_rng()[1]-OFF_cell.get_t_rng()[0], 50e-3)
print('N:', t_list.size, '\t dt=', np.diff(t_list)[0])

In [None]:
import plot_cell_heatmap
importlib.reload(plot_cell_heatmap)

# Plotter over the merged 'heatmap_vm' recording. scale_n sets the drawing
# scale and is also used as the release-site radius.
scale_n = 10
CP = plot_cell_heatmap.CellPlotter(cell=OFF_cell, rec_type='heatmap_vm')
CP.compute_draw_data(scale_n=scale_n, release_rad=scale_n, inc_cone=False,
                     inc_release=True, flipz=True)

In [None]:
# Resting potential: first Vm sample of the compartment at node 0.
# NOTE(review): assumes CP.rec_data['Vm'] is a pandas DataFrame with one column
# per compartment. .iloc[0] selects the first row by position; plain [0] is a
# label lookup whose positional fallback is deprecated in pandas.
Vm_rest = CP.rec_data['Vm'][OFF_cell.node2comp(0)].iloc[0]
# Largest deviation from rest anywhere in the recording; the Vm colour range
# is centred on Vm_rest with this half-width.
Vm_max_diff = np.max(np.abs(CP.rec_data['Vm'].values - Vm_rest))

for plot_type in ['Vm', 'Ca', 'rate']:
    
    # None lets set_colormapping choose the limit itself.
    y_min = None
    y_max = None
    
    if plot_type == 'Vm':
        y_min = Vm_rest - Vm_max_diff
        y_max = Vm_rest + Vm_max_diff
    elif plot_type == 'rate':
        # Release-rate colour scale starts at 0.
        y_min = 0
        
    CP.set_colormapping(
        plot_type, data=None, symmetric=False, y_min=y_min, y_max=y_max, cmap=None
    )

In [None]:
# True: reuse cached frames; False: render them (slow, large).
load_data = True

if load_data:
    heatmaps_list = data_utils.load_var('cbc_data/heatmaps_list_OFF.pkl')
else:
    # One image sequence per panel (Vm, Ca, 'rate with cell'), presumably one
    # frame per entry of t_list.
    heatmaps_list = CP.get_image_sequences(
        plot_types_list=[['Vm'], ['Ca'], ['rate with cell']],
        t_list=t_list, extraspace=scale_n*6,
        nodes=None, to_array_stack=True
    );
    data_utils.save_var(heatmaps_list, 'cbc_data/heatmaps_list_OFF.pkl')

In [None]:
# Quick visual check: first frame of the first sequence (presumably Vm).
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.imshow(heatmaps_list[0][0])

In [None]:
# Arguments shared by the snapshot and the video: the three heatmap panels
# with their colour bars, plus the stimulus trace of this recording.
plot_kwargs = dict(
    data_list=heatmaps_list,
    colorbar_list=[CP.colormapping['Vm'], CP.colormapping['Ca'], CP.colormapping['rate']],
    titles=['Voltage', 'Calcium', 'Release'],
    cb_labels=['Membrane voltage (mV)', 'Calcium conc. (nM)', 'Release rate (ves./s)'],
    trace_df=pd.DataFrame({
        'Time': OFF_cell.rec_data['heatmap_vm']['Time'],
        'Stim': OFF_cell.rec_data['heatmap_vm']['Stim']
    }),
    set_sbny=5, sbnx=4, cb_width=0.15, figsize=(6.7,3.3), abc='ABCD',
    trace_kw=dict(color=(0.8, 0.1, 0.1)),
    tl_dict=dict(h_pad=-1, w_pad=-2, pad=0.2)
)

#### Snapshot

In [None]:
# High resolution for the publication snapshot (frame index 81).
plt.rcParams['figure.dpi'] = 600
plt.rcParams["savefig.dpi"] = 600

fig, *_ = CP.plot_data(**plot_kwargs, data_list_idx=81)
plt.savefig('../create_figures/_figures_apx/OFF_CBC_chirp.pdf')

#### Video

This takes several minutes.

In [None]:
# Lower on-screen dpi while rendering; frames are saved at 300 dpi.
plt.rcParams['figure.dpi'] = 100
plt.rcParams["savefig.dpi"] = 300

data_utils.make_dir('_animations')
# Frame interval matches the spacing of t_list.
CP.animate(dt=np.diff(t_list)[0], filename='_animations/OFF.mp4', **plot_kwargs)

# ON

## Create cell

In [None]:
# Final output of the ON optimization; provides 't_rng', 'Stimulus' and 'params'.
final_model_output = data_utils.load_var(os.path.join(cell2folder['ON'], 'post_data', 'final_model_output.pkl'))

In [None]:
# Stimulus and its time range, exactly as used in the final ON model output.
stim_t_rng, stimulus = final_model_output['t_rng'], final_model_output['Stimulus']

In [None]:
# Pre-simulation duration passed to the cell as `predur`.
predur = 10

# Load parameters.
# Start from the defaults, then override with the optimized parameters and
# finally with the final coupling dictionary (later updates win).
params_default = data_utils.load_var(os.path.join(cell2folder['ON'], 'cell_params_default.pkl'))
params_default.update(final_model_output['params'])
params_default.update(data_utils.load_var(os.path.join(cell2folder['ON'], 'final_cpl_dict.pkl')))

params_unit = data_utils.load_var(os.path.join(cell2folder['ON'], 'cell_params_unit.pkl'))

In [None]:
import retsim_cells
importlib.reload(retsim_cells)

# Build the ON CBC (bp_type 'CBC5o') with the optimized parameters.
# The initial t_rng is a placeholder; it is replaced via update_t_rng below.
# NOTE(review): the density/nval/chanparams file names differ in scheme from the
# OFF cell's ('dens_cone_optimized_submission2.n', 'nval_optimize_CBCs.n', ...);
# confirm these are the files produced by the optimize_ON_submission2 run.
ON_cell = retsim_cells.CBC(
    bp_type = 'CBC5o',
    predur=predur, t_rng=(1.9,2.3),
    params_default=params_default, params_unit=params_unit,
    stimulus=stimulus, stim_type='Light',
    cone_densfile       = 'dens_optimized_cone_submission2.n',
    bp_densfile         = 'dens_strychnine_optimize_bc_v3.n',
    nval_file           = 'nval_strychnine_optimize_bc_v3.n',
    chanparams_file     = 'chanparams_strychnine_optimize_CBC5_v3.n',
    expt_file_list      = ['plot_cell_ON'],
    # Same base experiment file and retsim install as the OFF cell. The
    # optimization data used here live under step2a_optimize_cbc, not the
    # stale step2_optimize_cbc. retsim_path is also the folder `make` runs in.
    expt_base_file_list = ['../step2a_optimize_cbc/retsim_files/expt_CBC_base.cc'],
    retsim_path=os.path.abspath(os.path.join('..', 'neuronc', 'models', 'retsim')) + '/'
)

In [None]:
# Create c++ file.
# NOTE(review): on2cone_nodes presumably lists the bipolar-cell nodes that
# connect to cones; confirm against create_retsim_expt_file.
ON_cell.create_retsim_expt_file(verbose=False, on2cone_nodes=[1077, 980, 1190])
# Compile c++ file.
# IPython shell escape (not plain Python): `make` runs in the retsim folder in a subshell.
!(cd {ON_cell.retsim_path} && make)

In [None]:
# Initialise retsim for the compiled ON cell (verbose output).
ON_cell.init_retsim(verbose=True)

## Tests

### Test model

In [None]:
# Short smoke-test run over t_rng=(1.9, 2.1) with the 'optimize' recording type.
ON_cell.update_t_rng((1.9, 2.1))
ON_cell.rec_type = 'optimize'
_ = ON_cell.run(plot=True, verbose=False, update_cell_rec_data=True)

### Test cone output

In [None]:
import retsim_cell_tests
importlib.reload(retsim_cell_tests)

# Post-processed output of the cone model used as reference.
# NOTE(review): the OFF section reads 'step1a_optimize_cones/.../optimize_cone_submission2';
# this path ('step1_optimize_cones', '_truncated' run) differs. Confirm it is intended.
cone_post_data_folder = os.path.join(
    '..', 'step1_optimize_cones', 'optim_data', '_optimize_cone_submission2_truncated', 'post_data'
)

# Compare this cell's cone output against the reference over t_rng=(1, 2.5).
retsim_cell_tests.test_cones(
    ON_cell, os.path.join(cone_post_data_folder, 'final_model_output.pkl'), t_rng=(1,2.5)
);

### Test CBC output

In [None]:
import retsim_cell_tests
importlib.reload(retsim_cell_tests)

# Compare the ON CBC output against the optimizer's final model output
# over t_rng=(1, 2.5). (A single os.path.join suffices; nesting it was redundant.)
retsim_cell_tests.test_CBC(
    ON_cell, os.path.join(cell2folder['ON'], 'post_data', 'final_model_output.pkl'), t_rng=(1,2.5)
);

## Run simulation

In [None]:
# Simulate the full stimulus time range stored with the final model output.
ON_cell.update_t_rng(stim_t_rng)

### Run

In [None]:
# True: reuse the cached recording; False: re-run both simulations (slow).
load_data = True

if load_data:
    ON_cell.rec_data = data_utils.load_var('cbc_data/ON_cell_rec_data.pkl')
else:
    # Voltage recording.
    ON_cell.rec_type = 'heatmap_vm'
    _ = ON_cell.run(plot=True, verbose=True, update_cell_rec_data=True)
    
    # Calcium recording.
    ON_cell.rec_type = 'heatmap_ca'
    _ = ON_cell.run(plot=True, verbose=True, update_cell_rec_data=True)
    
    # Fold the calcium data into the 'heatmap_vm' recording; shared columns
    # must agree (checked by merge_rec_data).
    ON_cell.rec_data['heatmap_vm']['Data'] = \
    merge_rec_data(df1=ON_cell.rec_data['heatmap_vm']['Data'], df2=ON_cell.rec_data['heatmap_ca']['Data'])
    
    data_utils.save_var(ON_cell.rec_data, 'cbc_data/ON_cell_rec_data.pkl')

### Animate

In [None]:
# Animation frame times: steps of 50e-3 (presumably seconds, i.e. 50 ms) from 0
# up to the simulated duration of ON_cell. This must use ON_cell, not OFF_cell,
# so the frames match this recording and the cell still runs when the OFF
# section was skipped.
t_list = np.arange(0, ON_cell.get_t_rng()[1]-ON_cell.get_t_rng()[0], 50e-3)
print('N:', t_list.size, '\t dt=', np.diff(t_list)[0])

In [None]:
import plot_cell_heatmap
importlib.reload(plot_cell_heatmap)

# Plotter over the merged 'heatmap_vm' recording. scale_n sets the drawing
# scale and is also used as the release-site radius.
scale_n = 10
CP = plot_cell_heatmap.CellPlotter(cell=ON_cell, rec_type='heatmap_vm')
CP.compute_draw_data(scale_n=scale_n, release_rad=scale_n, inc_cone=False, inc_release=True, flipz=True)

In [None]:
# Resting potential: first Vm sample of the ON cell's compartment at node 0.
# Uses ON_cell (the original used OFF_cell's node-to-compartment map for the ON recording).
# NOTE(review): assumes CP.rec_data['Vm'] is a pandas DataFrame with one column
# per compartment. .iloc[0] selects the first row by position; plain [0] is a
# label lookup whose positional fallback is deprecated in pandas.
Vm_rest = CP.rec_data['Vm'][ON_cell.node2comp(0)].iloc[0]
# Largest deviation from rest anywhere in the recording; the Vm colour range
# is centred on Vm_rest with this half-width.
Vm_max_diff = np.max(np.abs(CP.rec_data['Vm'].values - Vm_rest))

for plot_type in ['Vm', 'Ca', 'rate']:
    
    # None lets set_colormapping choose the limit itself.
    y_min = None
    y_max = None
    
    if plot_type == 'Vm':
        y_min = Vm_rest - Vm_max_diff
        y_max = Vm_rest + Vm_max_diff
    elif plot_type == 'rate':
        # Release-rate colour scale starts at 0.
        y_min = 0
        
    CP.set_colormapping(
        plot_type, data=None, symmetric=False, y_min=y_min, y_max=y_max, cmap=None
    )

In [None]:
# True: reuse cached frames; False: render them (slow, large).
load_data = True

if load_data:
    heatmaps_list = data_utils.load_var('cbc_data/heatmaps_list_ON.pkl')
else:
    # One image sequence per panel (Vm, Ca, 'rate with cell'), presumably one
    # frame per entry of t_list.
    heatmaps_list = CP.get_image_sequences(
        plot_types_list=[['Vm'], ['Ca'], ['rate with cell']],
        t_list=t_list, extraspace=scale_n*6,
        nodes=None, to_array_stack=True
    );
    data_utils.save_var(heatmaps_list, 'cbc_data/heatmaps_list_ON.pkl')

In [None]:
# Quick visual check: first frame of the last sequence (presumably release rate).
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.imshow(heatmaps_list[-1][0])

In [None]:
# Arguments shared by the snapshot and the video: the three heatmap panels
# with their colour bars, plus the stimulus trace. The trace is taken from
# ON_cell's own recording; the original read OFF_cell.rec_data, which mislabels
# the ON video and fails if the OFF section was not run.
plot_kwargs = dict(
    data_list=heatmaps_list,
    colorbar_list=[CP.colormapping['Vm'], CP.colormapping['Ca'], CP.colormapping['rate']],
    titles=['Voltage', 'Calcium', 'Release'],
    cb_labels=['Membrane voltage (mV)', 'Calcium conc. (nM)', 'Release rate (ves./s)'],
    trace_df=pd.DataFrame({
        'Time': ON_cell.rec_data['heatmap_vm']['Time'],
        'Stim': ON_cell.rec_data['heatmap_vm']['Stim']
    }),
    set_sbny=5, sbnx=4, cb_width=0.15, figsize=(6.7,3.3), abc='ABCD',
    trace_kw=dict(color=(0.8, 0.1, 0.1)),
    tl_dict=dict(h_pad=-1, w_pad=-2, pad=0.2)
)

#### Snapshot

In [None]:
# High resolution for the publication snapshot (frame index 22).
plt.rcParams['figure.dpi'] = 600
plt.rcParams["savefig.dpi"] = 600

fig, *_ = CP.plot_data(**plot_kwargs, data_list_idx=22)
plt.savefig('../create_figures/_figures_apx/ON_CBC_chirp.pdf')

#### Video

This takes several minutes.

In [None]:
# Lower on-screen dpi while rendering; frames are saved at 300 dpi.
plt.rcParams['figure.dpi'] = 100
plt.rcParams["savefig.dpi"] = 300

data_utils.make_dir('_animations')
# Frame interval matches the spacing of t_list.
CP.animate(dt=np.diff(t_list)[0], filename='_animations/ON.mp4', **plot_kwargs)