# Hodgkin-Huxley | Run time

In [None]:
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

from itertools import product as itproduct

In [None]:
from sys import path as sys_path
from os.path import abspath as os_path_abspath
sys_path.append(os_path_abspath('..'))
import addpaths

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import stim_utils
import math_utils
import frame_utils
import metric_utils
import plot_utils as pltu

# Model

In [None]:
import hodgkin_huxley

# Simulation interval shared by every stimulus and solver in this notebook.
t0 = 0
tmax = 100

# Base Hodgkin-Huxley model instance.
neuron = hodgkin_huxley.neuron()

In [None]:
# Stimulus window: on at t = 10, off 10 time units before tmax.
stim_onset = 10
stim_offset = tmax - 10

stims = [
    stim_utils.Istim(
        Iamp=0.15, onset=stim_onset, offset=stim_offset, name='Step',
    ),
    stim_utils.Istim_noisy(
        Iamp=0.2, onset=stim_onset, offset=stim_offset, name='Noisy',
        nknots=51, seed=1146,
    ),
]

# Visual check of each stimulus over the simulated interval.
for stim in stims:
    stim.plot(t0=t0, tmax=tmax)

# Generator

In [None]:
from data_generator_HH import data_generator_HH
from copy import deepcopy

# One data generator per stimulus, keyed by the stimulus object.
# Each generator gets its own deep copy of the base model with that
# stimulus's current attached.
gens = {}

for stim in stims:

    # Copy the pristine base model for every stimulus. Do not copy the previous
    # stimulus's modified copy: that could also drag along a deep copy of the
    # previous stimulus through its bound `get_I_at_t`. This also leaves the
    # module-level `neuron` untouched; it is reused later, e.g. for axis units.
    stim_neuron = deepcopy(neuron)
    stim_neuron.get_Istim_at_t = stim.get_I_at_t

    gens[stim] = data_generator_HH(
        t0=t0, tmax=tmax, t_eval_adaptive=math_utils.t_arange(t0, tmax, 1),
        return_vars=['events', 'ys'],
        # y0: presumably the steady state at v = -65 mV — confirm compute_yinf semantics.
        model=stim_neuron, y0=stim_neuron.compute_yinf(-65), thresh=0.0,
        n_samples=20, n_parallel=20,
        gen_det_sols=True, gen_acc_sols=True, acc_same_ts=True,
        base_folder='_data/run_time'
    )
    gens[stim].update_subfoldername(stim=stim.name)
    gens[stim].load_acc_sols_from_file()

## Test

In [None]:
method, adaptive, step_param, pert_method = 'RKCK', 1, 1e-4, 'conrad'

# The same solver configuration is used twice: once through the generator's
# data pipeline and once via a directly constructed solver.
solver_kw = dict(method=method, adaptive=adaptive, step_param=step_param, pert_method=pert_method)

gen = gens[stims[0]]
data = gen.gen_data(**solver_kw)
solver = gen.get_solver(**solver_kw)
sol = solver.solve(
    tmax=tmax,
    n_samples=gen.n_samples,
    return_vars=['ys', 'events'],
)

In [None]:
# Sanity check: the generator's data and the direct solve need not be
# identical, but they should look similar. Black: generator samples;
# dotted colored: direct-solve samples; red: reference solution.
thresh = gens[stims[0]].thresh
# Show at most 3 direct-solve samples. `max` here would plot every sample
# and raise IndexError whenever fewer than 3 samples exist.
n_show = min(3, gen.n_samples)

plt.plot(data.ts, data.vs.T, 'k.--');
plt.vlines(np.concatenate([elist[0] for elist in data.events]), ymin=thresh-5, ymax=thresh+5);
for i in range(n_show):
    plt.plot(sol.get_ts(sampleidx=i), sol.get_ys(sampleidx=i, yidx=0), ':');
    plt.vlines(np.concatenate(sol.events[i]), ymin=thresh-10, ymax=thresh+10, color=f'C{i}');

plt.plot(data.acc_ts, data.acc_vs, 'r.--', zorder=100);
plt.vlines(data.acc_events, ymin=-20, ymax=20, color='c', ls=':');
plt.axhline(thresh)
plt.show()

## Data

In [None]:
# Solver sweep. Each entry is (pert_method, adaptive, methods, step_params),
# and is expanded below into one run per (step_param, method) combination.
# adaptive=0: fixed-step runs, step_params presumably are step sizes — confirm.
# adaptive=1: adaptive runs, step_params presumably are error tolerances — confirm.
solver_params = [
    ('conrad', 0, ['EE'], [0.5, 0.25, 0.1, 0.05, 0.025, 0.01, 0.005]),
    ('conrad', 0, ['EEMP'], [0.5, 0.4, 0.25, 0.1, 0.025, 0.01]),
    ('conrad', 0, ['FE'], [0.05, 0.04, 0.025, 0.01, 0.005]),
    ('conrad', 0, ['RKBS'], [0.1, 0.08, 0.05, 0.025, 0.01]),
    ('conrad', 0, ['RKDP'], [0.1, 0.08, 0.05, 0.025]),
    ('conrad', 1, ['RKBS'], [1e-2, 1e-4, 1e-6, 1e-8]),
    ('conrad', 1, ['RKDP'], [1e-2, 1e-4, 1e-6, 1e-8, 1e-10]),
    ('conrad', 1, ['RKCK'], [1e-2, 1e-4, 1e-6, 1e-8, 1e-10]),
]

In [None]:
# Run every solver setting for every stimulus and save the results.
# Inside the errstate block, numpy overflow / invalid-value conditions raise
# FloatingPointError instead of only warning. np.errstate restores the
# previous error settings on exit, even if data generation raises.
# overwrite=False presumably skips settings already saved, and
# allowgenerror=True presumably records failing runs instead of aborting
# the sweep — confirm in data_generator_HH.
with np.errstate(over='raise', invalid='raise'):
    for stim, gen in gens.items():

        print('----------------------------------------------------------')
        print(stim, ':', gen.subfoldername)
        print('----------------------------------------------------------')

        for pert_method, adaptive, methods, step_params in solver_params:
            for step_param, method in itproduct(step_params, methods):
                gen.gen_and_save_data(
                    method=method, adaptive=adaptive, step_param=step_param,
                    pert_method=pert_method, overwrite=False, allowgenerror=True,
                )

# Load data

In [None]:
from data_loader import data_loader

# Build one results frame per stimulus and concatenate them once at the end.
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, and
# appending in a loop copies the accumulated frame on every iteration.
stim_dfs = []
for stim, gen in gens.items():
    stim_df = data_loader(gen).load_data2dataframe(solver_params, drop_traces=False, MAE_SS_flat=True, allowgenerror=True)
    metric_utils.add_det_nODEcalls(stim_df, T=gen.tmax-gen.t0) # Add num. det ODE calls
    metric_utils.add_spike_time_errors(stim_df, min_distance=5, spike_idx=-1, use_abs=True) # Add STE, avoid jitter spikes in highly perturbed solutions
    stim_df['stimfun'] = stim      # stimulus object, used to re-evaluate the current when plotting
    stim_df['stim'] = stim.name    # stimulus name, used for grouping
    stim_dfs.append(stim_df)

# pd.concat raises on an empty list; keep the old "empty frame" result instead.
df = pd.concat(stim_dfs, ignore_index=True) if stim_dfs else pd.DataFrame()

# Plot

## Plot functions

In [None]:
def plot_voltage_trace_and_stim(ax, data_row, ylabel=None):
    """Plot an example voltage trace together with its stimulus current.

    Draws, into ``ax``:
    * the reference solution (``data_row.acc_ts`` / ``data_row.acc_vs``),
    * the mean and uncertainty of the sampled solutions
      (``data_row.ts`` / ``data_row.vs``),
    * the stimulus ``data_row.stimfun`` evaluated on 1001 points spanning
      the reference time axis.
    It then fixes the y-limits, ticks and axis labels.

    Parameters
    ----------
    ax : Axes to draw into.
    data_row : row providing acc_ts, acc_vs, ts, vs, stimfun, t0 and tmax.
    ylabel : optional y-axis label; the y label is left unset when None.
    """
    ref_ts = data_row.acc_ts

    pltu.plot_sample_trace(ax, ref_ts, data_row.acc_vs, label='ref.', intpol_dt=0.1)
    pltu.plot_mean_and_uncertainty(ax, data_row.ts, data_row.vs, intpol_dt=0.1, smpidxs=[0, 2])

    # The stimulus is drawn in the band y in [-100, -87], below the y-limits
    # that are set for the voltage trace afterwards.
    t_grid = np.linspace(ref_ts[0], ref_ts[-1], 1001)
    I_grid = list(map(data_row.stimfun.get_I_at_t, t_grid))
    pltu.plot_stim_on_trace_plot(ax, ts=t_grid, stim=I_grid, stim_y0=-100, stim_y1=-87, fill=False)

    ax.set_ylim((-85, 48))
    ax.set_yticks([-50, 0])
    pltu.move_xaxis_outward(ax, scale=7)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.set_xlabel('Time (ms)', labelpad=-4)
    ax.set_xticks([data_row.t0, data_row.tmax])

## Figure

In [None]:
# Cost measure on the x-axis; one accuracy metric (with its y-scale) per row.
x = 'det_nODEcalls_per_time'
ys = ['MAE_SR', 'spike_time_errors']
yscales = ['symlog'] * 2

assert len(ys) == len(yscales)

### Prepare axes ###
# axs is indexed axs[row, col] below. Row 0 holds the example traces and
# rows 1.. hold the metrics in `ys`. Columns hold the (pert_method, stim) groups.
sbnx = df.pert_method.nunique() * len(stims)
sbny = len(ys) + 1
row_heights = [0.5] + [1] * len(ys)

fig, axs = pltu.subplots(
    sbnx, sbny,
    sharex='row', sharey='row',
    xsize='text', ysizerow=1.0, yoffsize=0.3,
    squeeze=False,
    gridspec_kw={'height_ratios': row_heights},
)

plot_methods = set()
df_groups = df.groupby(['pert_method', 'stim'], sort=False)

### Plot stimuli ###
# Top row: one example run (fixed-step EE, step_param=0.25) per column.
for ax, (_, group) in zip(axs[0,:], df_groups):
    data_row = frame_utils.get_data_rows(df=group, method='EE', adaptive=0, step_param=0.25, n_ex=1)
    assert data_row.n_samples > 0, 'Pick an example that terminated'
    plot_voltage_trace_and_stim(
        ax=ax, data_row=data_row, ylabel="v(t) (mV)" if ax is axs[0,0] else None,
    )

### Plot data ###
# Metric rows: y vs. x percentiles for every (method, adaptive, pert_param) setting.
for i, y in enumerate(ys):
    for ax, (_, group) in zip(axs[1+i,:], df_groups):
        for (method, adaptive, pert_param), subgroup in group.groupby(['method', 'adaptive', 'pert_param'], sort=False):
            # Keep only runs that produced samples (n_samples == 0 presumably
            # marks runs that did not terminate, cf. the assertion above).
            finished = subgroup[subgroup['n_samples'] > 0.0]
            pltu.plot_xy_percentiles(ax,
                datax=list(finished[x]),
                datay=list(finished[y]),
                marker=pltu.method2marker(method), color=pltu.method2color(method),
                mean_kw=dict(alpha=0.8, ls=pltu.mode2ls(adaptive), mfc=pltu.mode2mfc(adaptive), ms=6, clip_on=False),
                line_kw=dict(color='k', lw=0.5, alpha=1.0),
            )
            plot_methods.add(method)

        ax.set_xscale('log')
        if yscales[i] == 'symlog':
            # `linthresh` replaced `linthreshy` in Matplotlib 3.3; the old name was removed in 3.5.
            ax.set_yscale(yscales[i], linthresh=1e-3)
        else:
            ax.set_yscale(yscales[i])
        
### Decorate ###
plot_methods = pltu.sort_methods(list(plot_methods))

# Labels.
# x-axis label for the chosen cost measure. Raw strings avoid the invalid "\m"
# escape (SyntaxWarning on Python 3.12+). An unknown `x` falls back to the raw
# column name instead of leaving `xlab` undefined.
xlabels = {
    'run_times': 'Run time (s)',
    'det_nODEcalls': r'N$_\mathrm{det.}$(ODE)',
    'det_nODEcalls_per_time': r'N$_\mathrm{det.}$(ODE) (1/ms)',
}
xlab = xlabels.get(x, x)
pltu.set_labs(axs[-1,:], xlabs=xlab)

# y-axis label of each metric row, with units taken from the model.
# Metrics without a dedicated label fall back to their column name.
for i, y in enumerate(ys):
    if y == 'spike_time_errors':
        ylab = f'STE ({neuron.get_t_unit()})'
    elif y in ('MAE_SR', 'MAE_SS'):
        ylab = (
            r'$\mathrm{MAE}_\mathrm{SR}$' if y == 'MAE_SR' else r'$\mathrm{MAE}_\mathrm{SS}$'
        ) + f' ({neuron.get_y_units()[0]})'
    else:
        ylab = y
    axs[i+1, 0].set_ylabel(ylab)

# Give all metric panels a common x-range and hide the x tick labels on every
# metric row except the bottom one.
pltu.make_share_xlims(axs[1:,:])
for ax in axs[1:-1,:].flat: ax.set_xticklabels([])
# Start all metric y-axes at 0 and add a grid.
for ax in axs[1:,:].flat:
    ax.set_ylim(0,None)
    pltu.grid(ax)
    
# First freeze the current x-limits (set_xlim turns off x autoscaling), so the
# shaded band does not widen them. Then shade y in [0, 1e-3]. 1e-3 is the same
# value as the symlog linear threshold used in the plotting loop; presumably
# the band marks that linear region — confirm.
for ax in axs[1:,:].flat:
    ax.set_xlim(ax.get_xlim())
    ax.fill_between(ax.get_xlim(), [1e-3, 1e-3], color='lightgray', zorder=-1000)

# Align axis labels across subplots and remove the top/right spines.
fig.align_labels(axs)
sns.despine()
    
pltu.move_xaxis_outward(axs)
# Panel labels, then layout. The first metric row is nudged down only after
# tight_layout (order matters: tight_layout recomputes the axes positions).
pltu.set_labs(axs, panel_nums='auto', panel_num_space=1, panel_num_va='center')
pltu.tight_layout(h_pad=-1, w_pad=1, rect=[0,0,1,1])
for ax in axs[1,:]: pltu.move_box(ax, dy=-0.015)

# Method / mode legend, placed in the first panel of the second metric row.
pltu.make_method_and_mode_legend(
    ax=axs[2,0], methods=plot_methods, example_method=plot_methods[-1], pert_method=df.pert_method.unique()[0],
    legend1_kw=dict(labelspacing=0.1, handlelength=0.2, borderpad=0.0, loc='lower left', bbox_to_anchor=(0.01,-0.08)),
    legend2_kw=dict(labelspacing=0.1, borderpad=0.0, loc='lower left', bbox_to_anchor=(0.22,-0.08)),
    mode_lines=True, mode_handlelength=2.0,
)

# Save the figure, then display both the live figure and the saved file.
pltu.savefig("HH_summary_runtime")
plt.show()
pltu.show_saved_figure(fig)

## Text

In [None]:
# RKBS, adaptive, step_param=1e-2: mean of det_nODEcalls_per_time for the
# first two matching rows (presumably one per stimulus — confirm).
# The query is run once instead of twice.
rows = frame_utils.get_data_rows(df=df, method='RKBS', adaptive=1, step_param=1e-2)
for row_idx in (0, 1):
    print(rows.det_nODEcalls_per_time.iloc[row_idx].mean())

In [None]:
# RKDP, adaptive, step_param=1e-2: mean of det_nODEcalls_per_time for the
# first two matching rows (presumably one per stimulus — confirm).
# The query is run once instead of twice.
rows = frame_utils.get_data_rows(df=df, method='RKDP', adaptive=1, step_param=1e-2)
for row_idx in (0, 1):
    print(rows.det_nODEcalls_per_time.iloc[row_idx].mean())

In [None]:
# RKCK, adaptive, step_param=1e-2: mean of det_nODEcalls_per_time for the
# first two matching rows (presumably one per stimulus — confirm).
# The query is run once instead of twice.
rows = frame_utils.get_data_rows(df=df, method='RKCK', adaptive=1, step_param=1e-2)
for row_idx in (0, 1):
    print(rows.det_nODEcalls_per_time.iloc[row_idx].mean())

In [None]:
# RKCK, adaptive, step_param=1e-6: mean of det_nODEcalls_per_time for the
# first two matching rows (presumably one per stimulus — confirm).
# The query is run once instead of twice.
rows = frame_utils.get_data_rows(df=df, method='RKCK', adaptive=1, step_param=1e-6)
for row_idx in (0, 1):
    print(rows.det_nODEcalls_per_time.iloc[row_idx].mean())

In [None]:
# EE, fixed step, step_param=0.01: mean of det_nODEcalls_per_time for the
# first two matching rows (presumably one per stimulus — confirm).
# The query is run once instead of twice.
rows = frame_utils.get_data_rows(df=df, method='EE', adaptive=0, step_param=0.01)
for row_idx in (0, 1):
    print(rows.det_nODEcalls_per_time.iloc[row_idx].mean())

In [None]:
# EE, fixed step, step_param=0.01: mean of MAE_SS for the first two matching
# rows (presumably one per stimulus — confirm).
# The query is run once instead of twice.
rows = frame_utils.get_data_rows(df=df, method='EE', adaptive=0, step_param=0.01)
for row_idx in (0, 1):
    print(rows.MAE_SS.iloc[row_idx].mean())

In [None]:
# RKBS, adaptive, step_param=1e-6: mean of det_nODEcalls_per_time for the
# first two matching rows (presumably one per stimulus — confirm).
# The query is run once instead of twice.
rows = frame_utils.get_data_rows(df=df, method='RKBS', adaptive=1, step_param=1e-6)
for row_idx in (0, 1):
    print(rows.det_nODEcalls_per_time.iloc[row_idx].mean())

In [None]:
# RKBS, adaptive, step_param=1e-6: mean of MAE_SS for the first two matching
# rows (presumably one per stimulus — confirm).
# The query is run once instead of twice.
rows = frame_utils.get_data_rows(df=df, method='RKBS', adaptive=1, step_param=1e-6)
for row_idx in (0, 1):
    print(rows.MAE_SS.iloc[row_idx].mean())