# IN | summary of DAP and rebound burst

Make summary plots for the DAP and rebound-burst modes of the Izhikevich neuron.

In [None]:
import numpy as np
from itertools import product as itproduct

from matplotlib import pyplot as plt
import seaborn as sns

In [None]:
from sys import path as sys_path
from os.path import abspath as os_path_abspath
sys_path.append(os_path_abspath('..'))
import addpaths

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import data_utils
import metric_utils
import plot_utils as pltu

# Select modes

In [None]:
# Izhikevich firing modes for which data is generated and summarized.
neuron_modes = ['Rebound burst', 'DAP']

In [None]:
import izhikevich_parameters
# Parameter container for the Izhikevich model. Its `select_mode(mode)` method is
# called for each neuron mode and unpacked into (neuron, stimulus, time) parameters.
parameters = izhikevich_parameters.parameters()

# Generator

In [None]:
import izhikevich
from data_generator_IN import data_generator_IN

gens = dict()

for neuron_mode in neuron_modes:
    # Build the model for this firing mode from the mode's parameter set.
    neuron_parameters, stimulus_parameters, t_parameters = parameters.select_mode(neuron_mode)
    neuron = izhikevich.neuron(neuron_parameters, stimulus_parameters)

    # Report the mode being set up together with its time step.
    print(neuron_mode, t_parameters['dt'])

    # One data generator per mode, rooted at base_folder with a per-mode
    # subfolder name.
    gen = data_generator_IN(
        y0=neuron.y0,
        t0=0.0,
        tmax=t_parameters['tmax'],
        gen_acc_sols=True,
        model=neuron,
        n_samples=40,
        n_parallel=20,
        base_folder='_data/summary',
    )
    gen.subfoldername = neuron_mode
    # NOTE(review): presumably loads previously computed accurate solutions
    # from disk — confirm in data_generator_IN.
    gen.load_acc_sols_from_file()

    gens[neuron_mode] = gen

## Generate data

In [None]:
# One row per solver configuration: (pert_method, adaptive, methods, step_params).
# NOTE(review): step_params look like step sizes for adaptive in {0, 2} and like
# tolerances for adaptive == 1 — confirm against data_generator_IN.
solver_params = [
    ('conrad', adaptive, ['FE', 'RKBS', 'RKDP'], step_params)
    for adaptive, step_params in (
        (0, [0.5, 0.2, 0.02, 0.002]),
        (2, [0.5, 0.2, 0.02, 0.002]),
        (1, [1e-2, 1e-3, 1e-4]),
    )
]

In [None]:
for neuron_mode, gen in gens.items():
    # Banner identifying the mode whose data is being generated.
    print('----------------------------------------------------------')
    print(neuron_mode, ':', gen.subfoldername)
    print('----------------------------------------------------------')

    for pert_method, adaptive, methods, step_params in solver_params:
        # Step parameter in the outer loop, method in the inner loop: the same
        # order as itertools.product(step_params, methods).
        for step_param in step_params:
            for method in methods:
                # overwrite=False: existing data for this setting is kept.
                gen.gen_and_save_data(
                    method=method,
                    adaptive=adaptive,
                    step_param=step_param,
                    pert_method=pert_method,
                    overwrite=False,
                )

# Plots

## Number of spikes plots

In [None]:
def load_n_spikes(neuron_mode, method, adaptive, step_param, pert_method):
    """Return the spike count of every sample for one solver setting.

    The data is loaded through ``gens[neuron_mode].load_data_and_check``. For each
    entry of its ``events``, the number of elements in the first event list is
    counted.
    """
    output_data = gens[neuron_mode].load_data_and_check(method, adaptive, step_param, pert_method)
    counts = []
    for sample_events in output_data.events:
        # NOTE(review): index 0 presumably selects the spike events — confirm.
        counts.append(len(sample_events[0]))
    return counts


def load_all_n_spikes(neuron_mode, method, adaptive, step_params, pert_method):
    """Return per-sample spike counts for every step parameter in ``step_params``.

    The result has one entry per step parameter, in order. Each entry is the list
    returned by ``load_n_spikes`` for that step parameter.
    """
    all_counts = []
    for step_param in step_params:
        all_counts.append(load_n_spikes(neuron_mode, method, adaptive, step_param, pert_method))
    return all_counts


def load_acc_n_spikes(neuron_mode, method, adaptive, step_param, pert_method):
    """Return the spike count of the accurate (reference) solution.

    The solver settings only select which data file is loaded. The count is
    ``len(acc_events[0])`` of that file's data.
    """
    output_data = gens[neuron_mode].load_data_and_check(method, adaptive, step_param, pert_method)
    reference_events = output_data.acc_events[0]
    return len(reference_events)

In [None]:
def plot_n_spikes(ax, neuron_mode, pert_method, adaptive, methods, step_params, maketitle=True):
    """Plot the distribution of spike counts for several solver settings.

    For each method, the per-sample spike counts at every step parameter are
    drawn with ``pltu.plot_percentiles``, with the methods offset side by side
    around x ticks 0 .. len(step_params)-1. A dashed grey horizontal line marks
    the spike count of the reference (accurate) solution.

    Parameters
    ----------
    ax : axis to draw on.
    neuron_mode : key into ``gens`` selecting the neuron mode.
    pert_method, adaptive : solver settings passed to the data loaders.
    methods : solver method names; one series is drawn per method.
    step_params : step parameters; one x tick per entry.
    maketitle : if True, set the axis title from ``adaptive``.
    """
    for idx_method, method in enumerate(methods):
        n_spikes = load_all_n_spikes(neuron_mode, method, adaptive, step_params, pert_method)

        pltu.plot_percentiles(
            ax, data=n_spikes,
            positions=pltu.get_x_positions(n_positions=len(step_params), idx=idx_method, n_idxs=len(methods), offset=0.22),
            color=pltu.method2color(method),
            marker=pltu.method2marker(method),
            outl_kw=dict(clip_on=False),
            mean_kw=dict(clip_on=False),
            line_kw=dict(clip_on=False),
            connect=False,
        )

    # Reference spike count. It is looked up in the data file of the last method,
    # which is what the previous code did implicitly through the leaked loop
    # variable. It is skipped when there is no method or step parameter to locate
    # a data file (previously an UnboundLocalError / IndexError).
    if methods and step_params:
        ax.axhline(
            load_acc_n_spikes(neuron_mode, methods[-1], adaptive, step_params[0], pert_method),
            c='grey', ls='--', zorder=-20
        )

    if maketitle:
        ax.set_title(pltu.mode2label(adaptive))
    ax.set_xticks(np.arange(len(step_params)))
    ax.set_xticklabels([pltu.step_param2tick(step_param, adaptive) for step_param in step_params])
    ax.set_ylabel('No. spikes')
    ax.set_xlabel(pltu.mode2xlabel(adaptive, time_unit='ms'))

    # Spike counts are non-negative; anchor the y-axis at zero.
    ax.set_ylim(0, None)

In [None]:
"""Plot example"""
sbnx = len(solver_params)
sbny = len(plot_modes)

fig, axs = pltu.subplots(sbnx, sbny)

for (neuron_mode, tmin, tmax), axs_row in zip(plot_modes, axs):
    for ax, (pert_method, adaptive, methods, step_params) in zip(axs_row, solver_params):
        plot_n_spikes(
            ax=ax, neuron_mode=neuron_mode, pert_method=pert_method,
            adaptive=adaptive, methods=methods, step_params=step_params
        )
plt.tight_layout()

## Traces plot

In [None]:
# method to show traces for
plot_method = 'FE'
plot_pert_method = 'conrad'

# List of (neuron_mode, tmin, tmax)
plot_modes = [('Rebound burst', 40, 120), ('DAP', 0, 1000)]

# list of (adaptive, step_param)
trace_params = [(0, 0.2), (0, 0.02)]
ref_params = trace_params[0]

In [None]:
def plot_trace(ax, method, adaptive, step_param, pert_method, maketitle=True, n_samples=3):
    """Plot ``n_samples`` randomly chosen sample traces for one solver setting.

    The function reads the module-level globals ``neuron_mode``, ``tmin`` and
    ``tmax`` (set by the plotting loops that call it) to choose the generator and
    the displayed time window. Sample indices come from ``np.random.randint``, so
    they are drawn with replacement and a trace may appear more than once.
    """
    output_data = gens[neuron_mode].load_data_and_check(method, adaptive, step_param, pert_method)

    sample_idxs = np.random.randint(0, output_data.vs.shape[0], n_samples)

    color_idx = 0
    for sample_idx in sample_idxs:
        if isinstance(output_data.ts, list):
            ts_sample = output_data.ts[sample_idx]
        else:
            ts_sample = output_data.ts  # one time grid shared by all samples
        vs_sample = output_data.vs[sample_idx]

        # Indices of the grid points closest to the window edges.
        start = np.argmin(np.abs(ts_sample - tmin))
        stop = np.argmin(np.abs(ts_sample - tmax))

        ax.plot(ts_sample[start:stop], vs_sample[start:stop], lw=0.7, c=pltu.neuron2color(color_idx))
        color_idx += 1

    if maketitle:
        title = pltu.step_param2label(step_param=step_param, adaptive=adaptive)
    else:
        title = None
    decorate_trace_plot(ax, title=title)
    
    
def plot_reference(ax, maketitle=True):
    """Plot the accurate reference trace of the current neuron mode.

    The reference solution (``acc_ts`` / ``acc_vs``) is read from the data file
    selected by the globals ``plot_method``, ``ref_params`` and
    ``plot_pert_method``. Like ``plot_trace``, the function also reads the
    module-level globals ``neuron_mode``, ``tmin`` and ``tmax``.
    """
    output_data = gens[neuron_mode].load_data_and_check(
        plot_method, ref_params[0], ref_params[1], plot_pert_method
    )
    ts_ref = output_data.acc_ts
    vs_ref = output_data.acc_vs
    # Indices of the grid points closest to the window edges.
    tminidx, tmaxidx = np.argmin(np.abs(ts_ref - tmin)), np.argmin(np.abs(ts_ref - tmax))
    # Use the local vs_ref (previously assigned but unused, with acc_vs re-read).
    ax.plot(ts_ref[tminidx:tmaxidx], vs_ref[tminidx:tmaxidx], c='grey', lw=0.7)
    decorate_trace_plot(ax, title='Reference' if maketitle else None)

    
def decorate_trace_plot(ax, title):
    """Apply the shared title, labels and voltage axis of a trace panel.

    Parameters
    ----------
    ax : Axes-like object to decorate.
    title : str or None
        Panel title. When None, no title is set.
    """
    if title is not None:
        ax.set_title(title)

    # Shared axis labels.
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('v(t)')

    # Fixed voltage range and ticks so that all trace panels are comparable.
    ax.set_ylim((-75, 40))
    ax.set_yticks([-60, -30, 0, 30])

In [None]:
"""Plot example"""
sbnx = len(trace_params)+1
sbny = len(plot_modes)

fig, axs = pltu.subplots(sbnx, sbny)

for (neuron_mode, tmin, tmax), axs_row in zip(plot_modes, axs):
    for ax, (adaptive, step_param) in zip(axs_row[:-1], trace_params):
        plot_trace(ax, plot_method, adaptive, step_param, plot_pert_method, maketitle=(axs_row[0] == axs[0,0]))
    plot_reference(axs_row[-1], maketitle=(axs_row[0] == axs[0,0]))

plt.tight_layout()

## Figure

In [None]:
np.random.seed(777)  # fixed seed so the randomly chosen sample traces are reproducible

### Prepare axes ###
n_trace_cols = len(trace_params) + 1  # sample-trace columns + reference column
n_summary_cols = len(solver_params)   # one spike-count summary per solver config
sbnx = n_trace_cols + n_summary_cols
sbny = len(plot_modes)

fig, axs = pltu.auto_subplots(
    sbnx*sbny, max_nx_sb=sbnx, xsize='fullwidth', ysizerow=1.35,
    gridspec_kw=dict(width_ratios=[1]*n_trace_cols + [1.3]*n_summary_cols)
)

# Column index of the first summary panel, as a positive index so the slices
# below also behave correctly when there is only one (or no) solver config.
first_summary_col = axs.shape[1] - n_summary_cols

### Plot data ###
# neuron_mode, tmin and tmax must keep these names: plot_trace and
# plot_reference read them as module-level globals.
for (neuron_mode, tmin, tmax), axs_row in zip(plot_modes, axs):
    is_top_row = axs_row[0] == axs[0, 0]  # only the first row gets titles

    # Plot traces.
    for ax, (adaptive, step_param) in zip(axs_row[:len(trace_params)], trace_params):
        plot_trace(ax, plot_method, adaptive, step_param, plot_pert_method, maketitle=is_top_row)

    # Plot reference traces.
    plot_reference(axs_row[len(trace_params)], maketitle=is_top_row)

    # Plot summary.
    for ax, (pert_method, adaptive, methods, step_params) in zip(axs_row[first_summary_col:], solver_params):
        plot_n_spikes(
            ax=ax, neuron_mode=neuron_mode, pert_method=pert_method,
            adaptive=adaptive, methods=methods, step_params=step_params,
            maketitle=is_top_row
        )


### Decorate ###
pltu.set_labs(axs, panel_nums='auto', panel_num_space=0, panel_num_va='top')

# Labels and ticks: share y-limits across each row's summary panels, and keep
# y tick labels only on the first trace panel and the first summary panel.
for axs_row in axs:
    pltu.make_share_ylims(axs_row[first_summary_col:])
for ax in axs[:, 1:n_trace_cols].flat:
    ax.set_yticklabels([])
    ax.set_ylabel(None)
for ax in axs[:, first_summary_col+1:].flat:
    ax.set_yticklabels([])
    ax.set_ylabel(None)
for ax in axs[:-1, :].flat:
    ax.set_xlabel(None)

# Adjust axis spacing
pltu.move_xaxis_outward(axs)
pltu.tight_layout(w_pad=-1, h_pad=0.3, rect=(0,0.01,0.98,1))
for i in range(n_summary_cols):
    pltu.move_box(axs[:, -(i+1)], dx=+0.014*(i+1))

# Legend in the first summary panel of the top row. The perturbation method is
# taken from the same solver config as the methods, not from the variable
# leaked by the plotting loop above.
pltu.make_method_legend(
    axs[0, first_summary_col], methods=solver_params[0][2], pert_method=solver_params[0][0],
    legend_kw=dict(loc='lower right', handlelength=0.5, bbox_to_anchor=(1,-0.1))
)

sns.despine()
fig.align_labels(axs)

pltu.savefig("IN_mode_details")
pltu.show_saved_figure(fig)