# Izhikevich neuron network | Uncertainty analysis

In [None]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

from itertools import product as itproduct

In [None]:
from sys import path as sys_path
from os.path import abspath as os_path_abspath
sys_path.append(os_path_abspath('..'))
import addpaths

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import metric_utils
import frame_utils
import plot_utils as pltu

# Model

In [None]:
# Build the default Izhikevich network and plot it. t0/tmax bound the
# simulation window used by every later cell.
import izhikevich_network

t0 = 0
tmax = 70
network = izhikevich_network.network()
network.plot()

In [None]:
from scipy.interpolate import interp1d

# Second network instance whose current lookup `get_I_at_t` is replaced by a
# cubic interpolant through its stored currents at the integer knots
# t0, t0+1, ..., tmax-1 and tmax.
network_smooth = izhikevich_network.network()

I_ts = np.append(np.arange(t0, tmax), tmax)
n_knots = I_ts.size
get_I_at_t = interp1d(
    x=I_ts,
    y=network_smooth.Is[:, :n_knots],  # one column of currents per knot
    kind='cubic',
)
network_smooth.get_I_at_t = get_I_at_t

In [None]:
# Compare component 0 of the input current of both networks on t in [0, 7).
ts = np.arange(0, 7, 0.2)
for net, lbl in ((network, 'network'), (network_smooth, 'network smooth')):
    plt.plot(ts, [net.get_I_at_t(t)[0] for t in ts], '.-', label=lbl)
plt.legend();

# Data

In [None]:
from data_generator_INN import data_generator_INN

n_samples = 20

# Step parameters per method, shared by the adaptive=0 and adaptive=2 modes.
_fixed_step_params = {
    'FE': [0.5, 0.05, 0.005],
    'RKBS': [1.0, 0.5, 0.05, 0.005],
    'RKDP': [1.0, 0.5, 0.05, 0.005],
}

# Each entry is (pert_method, adaptive, methods, step_params).
# Order: adaptive=0 (FE, RKBS, RKDP), adaptive=2 (FE, RKBS, RKDP), then the
# adaptive=1 runs, whose step_params are presumably tolerances.
solver_params = [
    ('conrad', adaptive, [method], list(steps))
    for adaptive in (0, 2)
    for method, steps in _fixed_step_params.items()
]
solver_params += [('conrad', 1, [method], [1e-2, 1e-5, 1e-8]) for method in ('RKBS', 'RKDP')]

## Step stimulus

In [None]:
# Data generator for the network driven by the step stimulus (`network`).
gen_steps = data_generator_INN(
    model=network, t0=t0, tmax=tmax, y0=network.y0,
    n_samples=n_samples, n_parallel=2,
    gen_det_sols=False, gen_acc_sols=True, acc_same_ts=False,
    return_vars=['events'],
    base_folder='_data/INN_uncertainty',
)
# Load the stored "acc" solutions (presumably the high-accuracy reference) from disk.
gen_steps.load_acc_sols_from_file()

In [None]:
# Run every solver configuration for the step stimulus, plotting each result.
# Step parameter is the outer loop and method the inner one.
for pert_method, adaptive, methods, step_params in solver_params:
    for step_param in step_params:
        for method in methods:
            gen_steps.gen_and_save_data(
                method=method, adaptive=adaptive, step_param=step_param,
                pert_method=pert_method, allowgenerror=False,
                plot=True, overwrite=False,
            )

## Smooth stimulus

In [None]:
# Data generator for the network with the interpolated (smooth) currents.
gen_smooth = data_generator_INN(
    model=network_smooth, t0=t0, tmax=tmax, y0=network_smooth.y0,
    n_samples=n_samples, n_parallel=2,
    gen_det_sols=False, gen_acc_sols=True, acc_same_ts=False,
    return_vars=['events'],
    base_folder='_data/INN_uncertainty_smooth_currents',
)
# Load the stored "acc" solutions (presumably the high-accuracy reference) from disk.
gen_smooth.load_acc_sols_from_file()

In [None]:
# Run every solver configuration for the smooth stimulus (no per-run plots).
# Step parameter is the outer loop and method the inner one.
for pert_method, adaptive, methods, step_params in solver_params:
    for step_param in step_params:
        for method in methods:
            gen_smooth.gen_and_save_data(
                method=method, adaptive=adaptive, step_param=step_param,
                pert_method=pert_method, allowgenerror=False,
                plot=False, overwrite=False,
            )

# Load data

In [None]:
from data_generator_INN import data_loader_INN

# Load the results of both stimulus types and tag each row with its stimulus.
df_steps = data_loader_INN(gen_steps).load_data2dataframe(solver_params, MAEs=True)
df_steps['stim'] = 'steps'

df_smooth = data_loader_INN(gen_smooth).load_data2dataframe(solver_params, MAEs=True)
df_smooth['stim'] = 'smooth'

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
# pd.concat with ignore_index=True is the equivalent.
df = pd.concat([df_steps, df_smooth], ignore_index=True)
df = df.sort_values(by=['stim', 'method', 'step_param', 'adaptive'])
df.columns

In [None]:
# Add the deterministic ODE-call metrics to df, given the simulated duration T.
# Presumably this adds the det_nODEcalls / det_nODEcalls_per_time columns used below.
_T_sim = tmax - t0
metric_utils.add_det_nODEcalls(df, T=_T_sim)
df.columns

### Plot traces

In [None]:
def plot_kde_traces_and_summary(ax, data_row, title=None, xlim=None):
    """Plot the sample KDE traces of one solver configuration and its reference.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    data_row : a single result row (e.g. from ``df.iterrows()``). The fields
        ``kde_ts``, ``acc_kde``, ``kdes`` and ``n_samples`` are plotted.
        ``method``, ``pert_method``, ``adaptive`` and ``step_param`` are used
        for the title.
    title : ``None`` for no title, ``'step_param'`` to label the panel with
        the step parameter only, or any other value to label it with the full
        method description.
    xlim : optional ``(left, right)`` x-limits. ``None`` keeps the default.
    """
    # Reference trace, drawn on top of the samples.
    pltu.plot_sample_trace(
        ax=ax, ts=data_row.kde_ts, ys=data_row.acc_kde, label='ref.',
        tr_kw=dict(clip_on=False, zorder=10000),
    )

    # Sample summary; qs=(10, 90) presumably sets the percentile band.
    pltu.plot_mean_and_uncertainty(
        ax=ax, ts=data_row.kde_ts, ys=data_row.kdes, smpidxs=np.arange(data_row.n_samples),
        tr_mean_kw=dict(clip_on=False), tr_bnds_kw=dict(clip_on=False), tr_smp_kw=dict(clip_on=False),
        qs=(10, 90),
    )

    # xlim used to be accepted but silently ignored.
    if xlim is not None:
        ax.set_xlim(xlim)

    if title is None:
        return
    if title == 'step_param':
        label = pltu.step_param2label(adaptive=data_row.adaptive, step_param=data_row.step_param)
    else:
        label = pltu.method2label(
            method=data_row.method, pert_method=data_row.pert_method,
            adaptive=data_row.adaptive, step_param=data_row.step_param,
        )
    ax.set_title(label)

In [None]:
# One grid of KDE-trace panels per stimulus, ordered by adaptive mode, method, step parameter.
for stim in ['steps', 'smooth']:
    print(stim)
    df_stim = df[df.stim == stim]
    fig, axs = pltu.auto_subplots(
        df_stim.shape[0], xsize='fullwidth',
        max_nx_sb=3, max_ny_sb=20, ysizerow=1.3,
    )
    ordered_rows = df_stim.sort_values(['adaptive', 'method', 'step_param']).iterrows()
    for ii, (_, data_row) in enumerate(ordered_rows):
        plot_kde_traces_and_summary(axs.flat[ii], data_row, title=True)
    axs[0, 0].legend()
    plt.tight_layout()
    plt.show()

### Plot run time vs MAE

In [None]:
def plot_run_time_vs_MAE_SR(ax, df, x='run_times', y='MAE_SR'):
    """Plot a cost measure against an error measure on log-log axes.

    One connected curve is drawn per (method, adaptive, pert_method) group, with
    points ordered by ``step_param``. Colour and marker encode the method;
    line style and marker fill encode the adaptive mode.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    df : results frame containing the columns ``x``, ``y``, ``method``,
        ``adaptive``, ``pert_method`` and ``step_param``.
    x : column plotted on the x-axis. Known names get a formatted label;
        otherwise the column name is used.
    y : column plotted on the y-axis. Only ``'MAE_SR'`` gets a y-label.
    """
    ax.set(xscale='log', yscale='log')

    for (method, adaptive, pert_method), df_group in df.groupby(['method', 'adaptive', 'pert_method']):
        df_group = df_group.sort_values(['step_param'])
        color = pltu.method2color(method)

        pltu.plot_xy_percentiles(
            ax,
            datax=list(df_group[x]),
            datay=list(df_group[y]),
            color=color, marker=pltu.method2marker(method),
            mean_kw=dict(alpha=0.6, ls=pltu.mode2ls(adaptive), mfc=pltu.mode2mfc(adaptive), ms=6, clip_on=False),
            line_kw=dict(color=color, lw=0.5, alpha=1.0, clip_on=False),
            connect=True, connect_alpha=0.4,
        )

    # Raw strings: "\m" is an invalid escape sequence in a normal string
    # (SyntaxWarning on Python 3.12+). The label text itself is unchanged.
    xlabels = {
        'run_times': 'Run time (s), prob.',
        'det_nODEcalls': r'N$_\mathrm{det.}$(ODE)',
        'det_nODEcalls_per_time': r'N$_\mathrm{det.}$(ODE) (1/ms)',
    }
    ax.set_xlabel(xlabels.get(x, x))
    if y == 'MAE_SR':
        ax.set_ylabel(r'$\mathrm{MAE}_\mathrm{SR}$ (Hz)')

In [None]:
# Sort by adaptive: one column per stimulus, one row per adaptive mode (0, 1, 2).
fig, axs = pltu.subplots(3, 2, xsize='fullwidth', ysizerow=2, sharex=True, sharey=True)
# Iterate over columns. Iterating `axs` directly walks its 3 rows, pairing only
# 2 rows with the 2 stimuli and only 2 of the 3 modes with each row.
for ax_col, stim in zip(axs.T, ['steps', 'smooth']):
    # Scalar key, so `name` is the mode value rather than a 1-tuple.
    for ax, (name, df_group) in zip(ax_col, df[df.stim == stim].groupby('adaptive')):
        ax.set_title(f"{stim}: adaptive = {name}")
        plot_run_time_vs_MAE_SR(ax=ax, df=df_group, x='det_nODEcalls_per_time')
plt.tight_layout()

In [None]:
# Sort by method: one column per stimulus, one row per method (FE, RKBS, RKDP).
fig, axs = pltu.subplots(3, 2, xsize='fullwidth', ysizerow=2, sharex=True, sharey=True)
# Iterate over columns. Iterating `axs` directly walks its 3 rows, pairing only
# 2 rows with the 2 stimuli and only 2 of the 3 methods with each row.
for ax_col, stim in zip(axs.T, ['steps', 'smooth']):
    # Scalar key, so `name` is the method string rather than a 1-tuple.
    for ax, (name, df_group) in zip(ax_col, df[df.stim == stim].groupby('method')):
        ax.set_title(f"{stim}: {name}")
        plot_run_time_vs_MAE_SR(ax=ax, df=df_group, x='det_nODEcalls_per_time')
plt.tight_layout()

### Plot step size vs MAE

In [None]:
def plot_step_param_vs_MAE_SR(ax, df):
    """Plot the MAE_SR distribution against the step size for each solver.

    The x-axis is categorical: one position per distinct ``step_param``, from
    largest to smallest. Each (method, adaptive) group is drawn at a small
    group-specific offset around those positions.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    df : results frame containing only adaptive modes 0 and 2, with columns
        ``method``, ``adaptive``, ``step_param`` and ``MAE_SR``. Every group
        must have a row for every step size in ``df``.
    """
    # np.array_equal also handles a different number of modes. An elementwise ==
    # on arrays of different length warns or raises instead of failing cleanly.
    modes = np.unique(df.adaptive)
    assert np.array_equal(modes, [0, 2]), f"expected adaptive modes [0, 2], got {modes}"

    # Largest step first, so the axis runs from coarse to fine.
    step_params = np.flip(np.unique(df.step_param))

    ax.set(xscale='linear', yscale='log')

    groups = df.groupby(['method', 'adaptive'])
    n_groups = len(groups)

    for i, ((method, adaptive), group) in enumerate(groups):
        # One MAE_SR entry per step size for this (method, adaptive) group.
        MAEs = [frame_utils.get_data_rows(group, n_ex=1, step_param=step_param)['MAE_SR']
                for step_param in step_params]

        pltu.plot_percentiles(
            ax, data=MAEs, connect=False,
            positions=pltu.get_x_positions(n_positions=len(MAEs), idx=i, n_idxs=n_groups, offset=0.12),
            color=pltu.method2color(method), marker=pltu.method2marker(method),
            mean_kw=dict(alpha=0.8, ls=pltu.mode2ls(adaptive), mfc=pltu.mode2mfc(adaptive), ms=6),
        )

    ax.set_xlabel('Step-size (ms)')
    ax.set_ylabel(r'$\mathrm{MAE}_\mathrm{SR}$ (Hz)')

    # Tick k is labelled with the k-th (descending) step size.
    ax.set_xticks(range(len(step_params)))
    ax.set_xticklabels(step_params)

    pltu.grid(ax, axis='y')

In [None]:
# MAE vs step size for the three step sizes shared by all methods, one panel per stimulus.
_shared_steps = [0.5, 0.05, 0.005]
n_stims = np.unique(df.stim).size
fig, axs = pltu.subplots(
    n_stims, 1,
    xsize='text', ysizerow=1.6,
    sharex=False, sharey=False, squeeze=False,
)

for ax, stim in zip(axs.flat, ['steps', 'smooth']):
    mask = (df.stim == stim) & (df.adaptive != 1) & df.step_param.isin(_shared_steps)
    plot_step_param_vs_MAE_SR(ax=ax, df=df[mask])

# Keep axis labels only on the outer panels.
for ax in axs[:, 1:].flat:
    ax.set_ylabel(None)
for ax in axs[:-1, :].flat:
    ax.set_xlabel(None)
plt.tight_layout()
sns.despine()
pltu.move_xaxis_outward(axs, scale=3)

# Figure - Pseudo-fixed step size

In [None]:
### Select parameters to plot ###
# Figure layout: row 0 shows adaptive=0 (fixed) and row 1 adaptive=2
# (pseudo-fixed) FE traces for the step stimulus at three step sizes.
# Row 2 holds the MAE-vs-step-size summaries for both stimuli.
# NOTE: `pert_method` set here is also used by later cells (the run-time figure).
pert_method = 'conrad'

plot_solver_params = [
    [dict(stim='steps', pert_method=pert_method, adaptive=0, method='FE', step_param=0.5),
     dict(stim='steps', pert_method=pert_method, adaptive=0, method='FE', step_param=0.05),
     dict(stim='steps', pert_method=pert_method, adaptive=0, method='FE', step_param=0.005)],
    [dict(stim='steps', pert_method=pert_method, adaptive=2, method='FE', step_param=0.5),
     dict(stim='steps', pert_method=pert_method, adaptive=2, method='FE', step_param=0.05),
     dict(stim='steps', pert_method=pert_method, adaptive=2, method='FE', step_param=0.005)],
]

### Create figure ###
fig = plt.figure(figsize=(pltu.TEXT_WIDTH, 4))

# 3 x 6 grid: trace panels span 2 columns (3 per row), summaries span 3 (2 in the last row).
n_rows = 3
n_cols = 6

axs = []
for i in range(2):
    axs.append([plt.subplot2grid((n_rows, n_cols), (i, j), colspan=2) for j in np.arange(0,6,2)])
axs = np.asarray(axs)
    
summary_axs = [plt.subplot2grid((n_rows, n_cols), (2, j), colspan=3) for j in np.arange(0,6,3)]


### Plot data ###

# Traces
for i, params_list in enumerate(plot_solver_params):
    for j, params in enumerate(params_list):
        plot_kde_traces_and_summary(axs[i][j], frame_utils.get_data_rows(df, n_ex=1, **params), title='step_param')
    
# Summaries (adaptive modes 0 and 2 only, step sizes shared by all methods)
for ax, stim, ttl in zip(summary_axs, ['steps', 'smooth'], ['Step stimulus', 'Smooth stimulus']):
    ax.set_title(ttl)
    plot_step_param_vs_MAE_SR(ax=ax, df=df[
        (df.stim==stim) & (df.adaptive != 1) & ((df.step_param == 0.5) | (df.step_param == 0.05) | (df.step_param == 0.005))
    ])
    
### Decorate ###
for ax in axs[1:,:].flat:
    ax.set_ylabel(None)
for ax in summary_axs[1:]:
    ax.set_ylabel(None)
for ax in axs[:-1,:].flat:
    ax.set_xticklabels([])
    ax.set_xlabel(None)
    
for ax in axs[-1,:]:
    ax.set_xlabel('Time (ms)')
# Re-labels the first column, including row 1 whose label was cleared above.
for ax in axs[:,0]:
    ax.set_ylabel('Rate (Hz)')

fig.align_labels()
    
# Right-hand mode label on the last trace panel of each row.
for ax, adaptive in zip(axs[:,-1], [0, 2]):
    ax.yaxis.set_label_position("right")
    ax.set_ylabel(pltu.mode2label(adaptive), labelpad=10, va='top')

# Decade ticks 1e-3 ... 1e1.
summary_axs[-1].set_yticks(10.**np.arange(-3,1.1))
     
pltu.make_share_ylims(axs, ylim=(0,75))
for ax in axs[:,1:].flat: ax.set_yticklabels([])
pltu.tight_layout(h_pad=-0.2, w_pad=0.2, rect=[0,0.12,1,0.99])
sns.despine()
pltu.move_xaxis_outward(axs, scale=3)
pltu.move_xaxis_outward(summary_axs, scale=3)
pltu.set_labs(list(axs.flatten())+summary_axs, panel_nums='auto', panel_num_space=2)

# Manual position nudges. They come after tight_layout, presumably so tight_layout does not undo them.
pltu.move_box(axs[-1,:], dy=+0.03)
pltu.move_box(summary_axs, dy=-0.1)
pltu.change_box(summary_axs, dy=+0.07)
pltu.move_box(summary_axs[-1], dx=+0.02)
    
# Legends
axs[0,-1].legend(loc='upper left', bbox_to_anchor=(0.02, 1.05))

pltu.make_method_and_mode_legend(
    ax=summary_axs[0], methods=df.method.unique(), modes=[0,2], pert_method=pert_method,
    legend1_kw=dict(loc='lower left', bbox_to_anchor=(0.0, -0.05)),
    legend2_kw=dict(loc='lower left', bbox_to_anchor=(0.3, -0.05)), mode_lines=False,
)

# NOTE(review): "pdeuso" looks like a typo of "pseudo". It is kept as-is because
# other documents may reference this file name; rename deliberately if intended.
pltu.savefig("INN_fixed_vs_pdeuso_fixed")
plt.show()

pltu.show_saved_figure(fig)

# Figure - Run time vs MAE

In [None]:
### Figure and axes ###
fig, axs = pltu.subplots(2, 1, ysizerow=2.3, sharex='row', sharey=True, squeeze=False)

plot_df = df[(df.method != 'RKCK')]

### Plot data ###
for i, (ax, stim) in enumerate(zip(axs.flat, ['steps', 'smooth'])):
    ax.set_title('Step stimulus' if stim == 'steps' else 'Smooth stimulus')
    plot_run_time_vs_MAE_SR(ax=ax, df=plot_df[plot_df.stim==stim], x='det_nODEcalls_per_time') 

### Decorate ###
for ax in axs[:,1:].flat:
    ax.set_ylabel(None)
    
for ax in axs.flat:
    ax.set_yscale('symlog', linthreshy=1e-1)
    ax.set_ylim(0, None)
    ax.set_xlim(ax.get_xlim())
    ax.fill_between(ax.get_xlim(), [1e-1, 1e-1], color='lightgray', zorder=-1000)

pltu.make_method_and_mode_legend(
    ax=axs[0,0], methods=pltu.sort_methods(np.unique(plot_df.method)), pert_method=pert_method,
    legend2_kw=dict(loc='lower left', bbox_to_anchor=(0.29,0)), modes=np.unique(plot_df.adaptive),
    mode_lines=True,
)
    
fig.align_labels()
pltu.move_xaxis_outward(axs, scale=5)
pltu.set_labs(axs, panel_nums='auto', panel_num_space=2)

pltu.tight_layout(h_pad=0.2, rect=[0,0,1,1], w_pad=1.5)
sns.despine()

pltu.grid(axs[0,:], axis='both')
pltu.savefig("INN_runtime")
plt.show()

pltu.show_saved_figure(fig)

In [None]:
# Scratch check of the mode-legend handles for FE in modes 0 and 1.
ax = plt.subplot(111)
ax.legend(handles=pltu._get_mode_legend_handles(method='FE', mean_kw=dict(), modes=[0, 1]))

# Text

In [None]:
# det_nODEcalls_per_time for FE, step 0.5, adaptive=0, smooth stimulus (presumably quoted in the text).
_row = frame_utils.get_data_rows(
    df=df, pert_method='conrad', stim='smooth', method='FE', adaptive=0, step_param=0.5,
)
_row.det_nODEcalls_per_time

# Appendix

In [None]:
# Deliberately stop "Run all" here; the appendix cells below are exploratory.
# A bare `raise` with no active exception only reports "No active exception
# to reraise", so raise the same exception type (RuntimeError) with a clear message.
raise RuntimeError("Stop: the appendix cells below are exploratory and not meant to run automatically.")

### Why is this one solution so bad?

In [None]:
# Side-by-side traces of two fixed-step (adaptive=0) smooth-stimulus runs.
_appendix_selections = [
    dict(method='RKDP', adaptive=0, step_param=0.05),
    dict(method='RKBS', adaptive=0, step_param=0.5),
]
for _k, _sel in enumerate(_appendix_selections):
    plot_kde_traces_and_summary(
        plt.subplot(1, 2, _k + 1),
        data_row=frame_utils.get_data_rows(df=df, stim='smooth', n_ex=1, **_sel),
        title=True,
    )

In [None]:
# Total number of reference ("acc") events for RKBS, step 0.5, adaptive=0, smooth stimulus.
_row = frame_utils.get_data_rows(df=df, stim='smooth', method='RKBS', adaptive=0, step_param=0.5, n_ex=1)
np.concatenate(_row.acc_events).size

### Why is pseudo-fixed more expensive?

In [None]:
# The same RKBS / step 0.5 / smooth-stimulus configuration in adaptive modes 0 (f) and 2 (pf).
_sel = dict(n_ex=1, method='RKBS', step_param=0.5, stim='smooth')
data_row_f = frame_utils.get_data_rows(df, adaptive=0, **_sel)
data_row_pf = frame_utils.get_data_rows(df, adaptive=2, **_sel)

In [None]:
# Run time per ODE call: position 0 = adaptive 0 run, position 1 = adaptive 2 run.
for _pos, _row in enumerate((data_row_f, data_row_pf)):
    plt.boxplot(np.sort(_row.run_times / _row.nODEcalls), positions=[_pos])
plt.show()

In [None]:
# Mean run time per ODE call for the adaptive=0 run, scaled by 1e3 (ms if run_times are in s).
np.mean(data_row_f.run_times / data_row_f.nODEcalls) * 1e3

In [None]:
# Mean run time per ODE call for the adaptive=2 run, scaled by 1e3 (ms if run_times are in s).
np.mean(data_row_pf.run_times / data_row_pf.nODEcalls) * 1e3