# Hodgkin-Huxley | Number of samples

Compute the solutions of the Hodgkin-Huxley (HH) neuron model for:
- different stimuli
- different solvers

Bootstrap the MAE ratio as a function of the number of samples.

In [None]:
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

from itertools import product as itproduct

In [None]:
from sys import path as sys_path
from os.path import abspath as os_path_abspath
sys_path.append(os_path_abspath('..'))
import addpaths

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math_utils
import metric_utils
import plot_utils as pltu

# Model

In [None]:
import hodgkin_huxley

# Simulation time window and the Hodgkin-Huxley neuron model instance.
t0 = 0
tmax = 100
neuron = hodgkin_huxley.neuron()

In [None]:
import stim_utils

# Both stimuli are active between stim_onset and stim_offset
# (stim_offset lies 10 time units before tmax).
stim_onset = 10
stim_offset = tmax - 10
stims = [
    stim_utils.Istim(
        Iamp=0.15, onset=stim_onset, offset=stim_offset, name='Step',
    ),
    stim_utils.Istim_noisy(
        Iamp=0.2, onset=stim_onset, offset=stim_offset, name='Noisy',
        nknots=51, seed=1146,
    ),
]

In [None]:
# Visualize each stimulus over the simulation window.
for stim in stims:
    stim.plot(t0=t0, tmax=tmax)

# Generator

In [None]:
from data_generator_HH import data_generator_HH
from copy import deepcopy

# One data generator per stimulus, keyed by the stimulus object.
gens = {}

for stim in stims:
    # Copy the template neuron under its own name so the module-level
    # `neuron` stays the unmodified template instead of being rebound
    # (and copied from the previous stimulus' copy) on every iteration.
    stim_neuron = deepcopy(neuron)
    stim_neuron.get_Istim_at_t = stim.get_I_at_t

    gens[stim] = data_generator_HH(
        t0=t0, tmax=tmax, t_eval_adaptive=math_utils.t_arange(t0, tmax, 1),
        model=stim_neuron, y0=stim_neuron.compute_yinf(-65),
        n_samples=300, n_parallel=20,
        gen_det_sols=True, gen_acc_sols=True, acc_same_ts=True,
        base_folder='_data/n_samples_summary'
    )

    # Per-stimulus subfolder, then load the stored reference ("acc") solutions.
    gens[stim].update_subfoldername(stim=stim.name)
    gens[stim].load_acc_sols_from_file()

In [None]:
# Each entry: (pert_method, adaptive, methods, step_params, pert_params).
# Every perturbation method (with its pert_param) is combined with a
# fixed-step solver (EEMP, step_param 0.1) and an adaptive solver
# (RKBS, step_param 1e-3), in that order.
solver_params = [
    (pert_method, adaptive, [method], [step_param], [pert_param])
    for pert_method, pert_param in (('conrad', 1), ('abdulle', 0.1))
    for adaptive, method, step_param in ((0, 'EEMP', 0.1), (1, 'RKBS', 1e-3))
]

## Data

In [None]:
for stim, gen in gens.items():
    separator = '----------------------------------------------------------'
    print(separator)
    print(stim, ':', gen.subfoldername)
    print(separator)

    # Flatten solver_params into one (pert_method, adaptive, method,
    # step_param, pert_param) tuple per solver configuration, in order.
    configs = (
        (pert_method, adaptive, method, step_param, pert_param)
        for pert_method, adaptive, methods, step_params, pert_params in solver_params
        for method, step_param, pert_param in itproduct(methods, step_params, pert_params)
    )

    for pert_method, adaptive, method, step_param, pert_param in configs:
        # overwrite=False: existing data is not overwritten.
        gen.gen_and_save_data(
            method=method, adaptive=adaptive, step_param=step_param,
            pert_method=pert_method, pert_param=pert_param,
            overwrite=False, plot=True,
        )

# Load data

In [None]:
from data_loader import data_loader

# Load one dataframe per stimulus and concatenate once at the end.
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, and
# appending inside a loop re-copies the accumulated frame every iteration.
stim_dfs = []
for stim, gen in gens.items():
    stim_df = data_loader(gen).load_data2dataframe(solver_params, drop_traces=True, MAE_SS_flat=False)
    stim_df['stimfun'] = stim      # the stimulus object itself
    stim_df['stim'] = stim.name    # stimulus name, used for hue/colors in the plots
    stim_dfs.append(stim_df)

# Guard the empty case: pd.concat raises on an empty list.
df = pd.concat(stim_dfs, ignore_index=True) if stim_dfs else pd.DataFrame()

In [None]:
# Bootstrap settings: subsample sizes to evaluate, number of bootstrap
# draws per size, and the statistic used to average pairwise distances.
n_samples_list = [2,3,5,10,20]
n_boot = 1000
average_fun = np.mean

# Pre-create object columns so each cell can hold a scalar or an array.
df['avg_MAE_SS'] = None
for n in n_samples_list:
    df[f'avg_MAE_SS_n{n}'] = None
    df[f'avg_MAE_ratio_SS_n{n}'] = None

for row_idx, row in df.iterrows():
    dist_SS = row.MAE_SS
    # Average over the strict upper triangle (k=1 excludes the diagonal);
    # presumably MAE_SS is a symmetric pairwise-distance matrix.
    full_avg = average_fun(dist_SS[np.triu_indices(dist_SS.shape[0], 1)])
    df.at[row_idx, 'avg_MAE_SS'] = full_avg

    for n in n_samples_list:
        # First return value: bootstrapped averages over n-sample subsets
        # (presumably n_boot values — the plotting code assumes this).
        boot_avgs, _ = metric_utils.bootstrap_average_distance(
            dist_SS=dist_SS, dist_SR=row.MAE_SR,
            n_samples=n, n_boot=n_boot, average_fun=average_fun
        )
        boot_avgs = np.asarray(boot_avgs)
        df.at[row_idx, f'avg_MAE_SS_n{n}'] = boot_avgs
        # R: n-sample bootstrap average relative to the all-sample average.
        df.at[row_idx, f'avg_MAE_ratio_SS_n{n}'] = boot_avgs / full_avg

# Figure

## Plot functions

How many samples are needed to get a good estimate of the SS distance, and therefore an estimate of the SR distance?

In [None]:
def plot_R_distribution(ax, df, transform=None, lb=None, ub=None):
    """Plot R distributions as violins as a function of the stimulus and sample number.

    Draws one violin per (sample number n, stimulus) from the bootstrapped
    ratios stored in the ``avg_MAE_ratio_SS_n{n}`` columns of `df`.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df : rows of a single solver configuration (one method, adaptivity,
        perturbation method and step parameter); presumably one row per stimulus.
    transform : optional function applied to the R values and to the
        reference lines (e.g. ``np.log10``); identity if None.
    lb, ub : if both are given, dotted lines are drawn at ``transform(lb)``
        and ``transform(ub)`` and a solid line at ``transform(1)``.

    Relies on the notebook globals ``n_samples_list``, ``n_boot``, ``stims``
    and ``stim2color``; each ratio array is assumed to hold ``n_boot`` values.
    """
    assert df.method.nunique() == 1
    assert df.adaptive.nunique() == 1
    assert df.pert_method.nunique() == 1
    assert df.step_param.nunique() == 1

    if transform is None: transform = lambda x: x

    # Long-format table with one row per bootstrap draw. Build one frame per
    # data row and concatenate once: DataFrame.append was removed in pandas
    # 2.0, and ignore_index avoids duplicated index labels.
    frames = []
    for _, data_row in df.iterrows():
        frames.append(pd.DataFrame({
            'n': np.repeat(n_samples_list, n_boot),
            'stim': np.repeat(data_row.stim, n_boot*len(n_samples_list)),
            'R': transform(np.concatenate([data_row[f'avg_MAE_ratio_SS_n{n}'] for n in n_samples_list])),
        }))
    plot_df = pd.concat(frames, ignore_index=True)

    sns.violinplot(
        ax=ax, data=plot_df, x='n', y='R', hue='stim', scale="width",
        hue_order=[stim.name for stim in stims],
        palette=[stim2color[stim.name] for stim in stims],
        split=False, cut=0, linewidth=0.1, inner=None,
    )

    # Reference lines: bounds (dotted) and R = 1 (solid), behind the violins.
    if lb is not None and ub is not None:
        ax.axhline(transform(lb), color='gray', ls=':', alpha=1.0, zorder=-1000)
        ax.axhline(transform(1), color='gray', ls='-', alpha=1.0, zorder=-1000)
        ax.axhline(transform(ub), color='gray', ls=':', alpha=1.0, zorder=-1000)

In [None]:
from matplotlib.ticker import FuncFormatter

def plot_R_summary(ax, df, lb, ub, marker, ls):
    """Plot, for each row of `df`, the fraction of bootstrapped R values in [lb, ub].

    One line per row (presumably one per stimulus) is drawn over the sample
    numbers in ``n_samples_list``, at x positions 0..len-1 to match the
    categorical x axis of the violin plots, colored by ``stim2color[row.stim]``.
    The fractions are also printed.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df : rows holding ``avg_MAE_ratio_SS_n{n}`` arrays and a ``stim`` name.
    lb, ub : inclusive bounds on R.
    marker, ls : matplotlib marker and line style for the lines.

    Relies on the notebook globals ``n_samples_list``, ``stim2color`` and
    ``line_kws``.
    """
    # Show y tick labels as percentages (0.5 -> '50%'); row-independent, so set once.
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: '{:.0%}'.format(y)))

    # Use the `df` argument (not a notebook global) so exactly the given rows are plotted.
    for _, data_row in df.iterrows():
        percentages = [
            np.mean((data_row[f'avg_MAE_ratio_SS_n{n}'] >= lb) & (data_row[f'avg_MAE_ratio_SS_n{n}'] <= ub))
            for n in n_samples_list
        ]

        # label='_' keeps these lines out of automatic legends.
        ax.plot(percentages, color=stim2color[data_row.stim], marker=marker, ls=ls, label='_', clip_on=True, **line_kws)
        print(data_row.stim, percentages)

## Plot

In [None]:
### Metric parameters ###
# A bootstrapped ratio R_SS counts as "in bounds" when lb <= R_SS <= ub.
lb, ub = 0.5, 2.0

### Plot parameters ###
stim2color = {stim.name: pltu.neuron2color(i) for i, stim in enumerate(stims)}
# Marker / line style per (pert_method, adaptive) group within an axes row.
markers = ['^', 'o', 'v']
lss = ['-', '--', ':']

line_kws = dict(lw=1, markeredgewidth=0.0, alpha=0.8)

### Create axes ###
fig, axs = pltu.subplots(3, 2, yoffsize=0.6, ysizerow=1, sharey=False, squeeze=False, xsize='text')

all_legend_handles, all_legend_labels = [], [] # Collect plotted methods

### Plot data ###
# One axes row per solver method. Group by the bare column name (not a
# one-element list) so `method` is a scalar: pandas >= 2.0 yields 1-tuples
# as group keys when `by` is list-like.
for ax_row, (method, group) in zip(axs, df.groupby('method', sort=False)):

    legend_handles, legend_labels = [], []

    # One violin axes per (pert_method, adaptive) group; the last axes of the
    # row also collects the in-bounds summary lines of all groups.
    # NOTE(review): assumes each row has one more axes than there are groups,
    # so the summary does not share an axes with a violin plot — confirm
    # against pltu.subplots' argument order.
    for i, (ax, ((pert_method, adaptive), subgroup)) in enumerate(zip(ax_row, group.groupby(['pert_method', 'adaptive'], sort=False))):

        # Plot R distributions
        ax.set_title(pltu.method2label(method=method, adaptive=adaptive, pert_method=pert_method))
        plot_R_distribution(ax=ax, df=subgroup, transform=np.log10, lb=lb, ub=ub)

        # Plot R in bounds
        plot_R_summary(ax_row[-1], df=subgroup, lb=lb, ub=ub, marker=markers[i], ls=lss[i])

        # Collect legend handles
        legend_handles.append(pltu.get_legend_handle(
            marker=markers[i], ls=lss[i], color=stim2color[stims[0].name], altcolor=stim2color[stims[1].name], **line_kws
        ))
        legend_labels.append(pltu.method2label(
            method=method, adaptive=adaptive, pert_method=pert_method
        ))

    all_legend_handles.append(legend_handles)
    all_legend_labels.append(legend_labels)

### Decorate ###
for ax in axs[:,0].flat: ax.set_ylabel(pltu.text2mathtext('log_10 (R_SS )'))
# Raw string: '\in' is an invalid escape sequence in a regular string literal.
for ax in axs[:,-1]: ax.set_ylabel(pltu.text2mathtext(r'R_SS \in ' + f'[{lb}, {ub}]'))

for ax in axs[:,1:-1].flat:
    ax.set_ylabel(None)
    ax.legend().set_visible(False)

for ax in axs[1:,:-1].flat:
    ax.legend().set_visible(False)

# customize legends
ax = axs[0,0]
handles, labels = ax.get_legend_handles_labels()
labels = [pltu.stim2label(stim_name) for stim_name in labels]
ax.legend(handles=handles, labels=labels, loc='lower right')

# Labels and ticks
for ax in axs[-1,:]: ax.set_xlabel(r'No. samples')
for ax in axs[:-1,:].flat: ax.set_xlabel(None)
pltu.move_xaxis_outward(axs)

for ax in axs.flat:
    ax.set_xticks(range(len(n_samples_list)))
    ax.set_xticklabels(n_samples_list)

# Share y limits across the summary column; a single call covers all its axes.
pltu.make_share_ylims(axs[:,-1])
# NOTE(review): clears x tick labels of the first row only, while x labels
# are cleared for all but the last row — confirm this asymmetry is intended.
for ax in axs[:1,:].flat: ax.set_xticklabels([])
pltu.set_labs(axs, panel_nums='auto')
fig.align_ylabels()
fig.align_xlabels()
sns.despine()

pltu.tight_layout(w_pad=-2., h_pad=1.5, rect=[0, 0, 0.91, 0.995])
pltu.move_box(axs[:,-1], dx=0.08)

# Per-row method legends on the (shifted) summary column.
for ax, legend_handles, legend_labels in zip(axs[:,-1], all_legend_handles, all_legend_labels):
    ax.legend(
        handles=legend_handles, labels=legend_labels, loc='lower right',
        handlelength=2.3, bbox_to_anchor=(1,-0.1)
    )

pltu.savefig('HH_n_samples_summary')
pltu.show_saved_figure(fig)
plt.show()