# Hodgkin-Huxley | Calibration

Compute the solutions for the HH neuron for:
- different stimuli
- different solvers
- different perturbation parameters

In [None]:
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

from itertools import product as itproduct

In [None]:
from sys import path as sys_path
from os.path import abspath as os_path_abspath
sys_path.append(os_path_abspath('..'))
import addpaths

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import hodgkin_huxley
import ode_solver

import stim_utils
import math_utils
import frame_utils
import metric_utils
import plot_utils as pltu

# Model

In [None]:
# Simulation time window [t0, tmax] used by the stimuli and generators below.
# NOTE(review): units are presumably ms — confirm against hodgkin_huxley.
t0, tmax = 0, 100
# Model instance; later cells deep-copy it and attach a stimulus per copy.
neuron = hodgkin_huxley.neuron()

In [None]:
# Stimulus window shared by all but the constant stimulus:
# on at t=10, off 10 time units before tmax.
stim_onset, stim_offset = 10, tmax-10
stims = [
    # onset=0 and offset=2*tmax: the stimulus is on for the whole [t0, tmax] window.
    stim_utils.Istim(Iamp=0.15, onset=0, offset=tmax*2, name='Constant'),
    stim_utils.Istim(Iamp=0.15, onset=stim_onset, offset=stim_offset, name='Step'),
    # Seeds are fixed, presumably so the noisy stimuli are reproducible — confirm in stim_utils.
    stim_utils.Istim_noisy(Iamp=0.2, onset=stim_onset, offset=stim_offset, name='Noisy', nknots=51, seed=1146),
    stim_utils.Istim_noisy_no_offset(Iamp=0.01, onset=stim_onset, offset=stim_offset, nknots=51, seed=42, name='Subthreshold'),
]
# Visual check of every stimulus over the simulation window.
for stim in stims: stim.plot(t0=t0, tmax=tmax)

# Generator

In [None]:
from data_generator_HH import data_generator_HH
from copy import deepcopy

# Maps each stimulus object to the data generator configured for it.
gens = {}

max_step = 1.0
base_folder='_data/heatmap_pert_param'
# Only non-default max_step values get their own data folder.
if max_step != 1.0: base_folder += f'_max_step_{max_step!r}'
print(base_folder)
    
for stim in stims:
    
    # Rebind the global `neuron` to a fresh copy so every generator owns its
    # own model object, then point that copy at the current stimulus.
    neuron = deepcopy(neuron)
    neuron.get_Istim_at_t = stim.get_I_at_t
    
    gen = data_generator_HH(
        y0=neuron.compute_yinf(-65), t0=t0, tmax=tmax,
        t_eval_adaptive=math_utils.t_arange(t0, tmax, 1), max_step=max_step,
        model=neuron, n_samples=100, n_parallel=20, dt_min_eval_fixed=0.1,
        gen_det_sols=True, gen_acc_sols=True, acc_same_ts=True,
        base_folder=base_folder,
    )
    # Data of different stimuli go to different subfolders.
    gen.update_subfoldername(stim=stim.name)
    # NOTE(review): presumably loads precomputed reference ("acc") solutions
    # from disk if present — confirm in data_generator_HH.
    gen.load_acc_sols_from_file()    
    
    gens[stim] = gen

## Test

In [None]:
# Smoke test on the generator of the last stimulus in `stims`.
gen = gens[stims[-1]]

method = 'RKBS'
step_param = 1e-4

# Same solver with a tiny perturbation (data1), a huge one (data2)
# and no perturbation at all (data3).
data1 = gen.gen_data(method=method, adaptive=1, pert_method='abdulle_ln', step_param=step_param, pert_param=1e-9)
data2 = gen.gen_data(method=method, adaptive=1, pert_method='abdulle_ln', step_param=step_param, pert_param=100)
data3 = gen.gen_data(method=method, adaptive=1, pert_method=None, step_param=step_param, pert_param=0.0)

ax = plt.subplot(111)
# Blue: pert_param=100, red: pert_param=1e-9, black crosses: unperturbed,
# green dotted: first component (yidx=0) of the first reference solution.
plt.plot(data2.ts, data2.vs.T, c='b', alpha=0.2);
plt.plot(data1.ts, data1.vs.T, c='r', alpha=0.2);
plt.plot(data3.ts, data3.vs.T, 'kx');
plt.plot(gen.acc_sols[0].ts, gen.acc_sols[0].get_ys(yidx=0), 'g:')
plt.show()

# Data

In [None]:
def get_pert_params(explb, expub, expstep):
    """Return perturbation parameters 10**e on a grid of exponents e.

    The exponents run from `explb` (inclusive) towards `expub` in steps of
    `expstep`; `10**expub` is always appended as the final entry, so the
    upper bound is included.
    """
    exponents = np.arange(explb, expub, float(expstep))
    below_upper = 10**exponents
    upper = 10**expub
    return np.append(below_upper, upper)


# Solver configurations to generate and load. Each entry is a tuple
#   (pert_method, adaptive, methods, step_params, pert_params)
# and stands for every combination of its step_params, methods and
# pert_params (see the itproduct loop in the data-generation cell).
solver_params = [
    # Conrad
    ('conrad', 0, ['EE', 'EEMP'], [0.25, 0.025], get_pert_params(-2, 1, 0.5)),   
    ('conrad', 0, ['FE', 'HN'], [0.05, 0.025], get_pert_params(-2, 1, 0.5)),
    ('conrad', 1, ['RKBS', 'RKDP'], [1e-2, 1e-4], get_pert_params(-4, 1, 1)),
    
    # Abdulle, uniform
    # The upper bound of the pert_param grid differs per method and step size here.
    ('abdulle', 0, ['EE', 'EEMP'], [0.25], get_pert_params(-2, 0, 0.5)),
    ('abdulle', 0, ['EE'], [0.025], get_pert_params(-2, 1, 0.5)),
    ('abdulle', 0, ['EEMP'], [0.025], get_pert_params(-2, 2, 0.5)),
    
    ('abdulle', 0, ['FE', 'HN'], [0.05], get_pert_params(-2, 0, 0.5)),
    ('abdulle', 0, ['FE'], [0.025], get_pert_params(-2, 1, 0.5)),
    ('abdulle', 0, ['HN'], [0.025], get_pert_params(-2, 2, 0.5)),
    
    ('abdulle', 1, ['RKBS', 'RKDP'], [1e-2, 1e-4], get_pert_params(-4, 0, 1)),
    
     # Abdulle, lognormal
    ('abdulle_ln', 0, ['EE', 'EEMP'], [0.25, 0.025], get_pert_params(-2, 2, 0.5)),
    ('abdulle_ln', 0, ['FE', 'HN'], [0.05, 0.025], get_pert_params(-2, 2, 0.5)),    
    ('abdulle_ln', 1, ['RKBS', 'RKDP'], [1e-2, 1e-4], get_pert_params(-4, 2, 1)),
 ]

In [None]:
# Turn floating-point overflow / invalid-value warnings into exceptions while
# generating data. np.errstate restores the previous error settings on exit,
# including when generation raises, so a failed run cannot leave the kernel
# in 'raise' mode.
with np.errstate(over='raise', invalid='raise'):

    for stim, gen in gens.items():

        print('----------------------------------------------------------')
        print(stim, ':', gen.subfoldername)
        print('----------------------------------------------------------')

        # Each solver_params entry expands to all combinations of its
        # step sizes, methods and perturbation parameters.
        for pert_method, adaptive, methods, step_params, pert_params in solver_params:
            for step_param, method, pert_param in itproduct(step_params, methods, pert_params):
                # overwrite=False: NOTE(review): presumably skips combinations
                # already saved on disk — confirm in data_generator_HH.
                gen.gen_and_save_data(
                    method=method, adaptive=adaptive, step_param=step_param,
                    pert_method=pert_method, pert_param=pert_param,
                    overwrite=False, allowgenerror=True,
                )

# Load data

In [None]:
from data_loader import data_loader

# Load the generated data of every stimulus into one frame per stimulus and
# concatenate once at the end. DataFrame.append was removed in pandas 2.0
# and copies the whole frame on every call.
stim_dfs = []

for stim, gen in gens.items():
    stim_df = data_loader(gen).load_data2dataframe(
        solver_params=solver_params, allowgenerror=True, MAEs=True, MAE_SS_flat=True, drop_traces=True
    )

    # Tag the rows with the stimulus object and its name.
    stim_df['stimfun'] = stim
    stim_df['stim'] = stim.name

    stim_dfs.append(stim_df)

# pd.concat raises on an empty list; keep the empty-frame result in that case.
df = pd.concat(stim_dfs, ignore_index=True) if stim_dfs else pd.DataFrame()

In [None]:
# NOTE(review): the return value is discarded, so this presumably adds the
# metric columns (e.g. the MAE_ratio_* columns read by the plots below) to
# df in place — confirm in metric_utils.
metric_utils.add_MAE_metrics_to_df(df, metric='mean')

In [None]:
df.head()

# Plot - Summary

## Plot functions

In [None]:
def stim2kwargs(stim):
    """Return the line-plot keyword arguments for a stimulus name.

    The four named stimuli get a color from `pltu.neuron2color`, a marker,
    a line style and a short legend label. Any other value is drawn in black
    with `stim` itself as the label. Line width and alpha are the same for all.
    """
    shared = dict(lw=0.8, alpha=0.7)

    # (stimulus name, color index, marker, line style, legend label)
    known_styles = (
        ('Constant', 0, '+', '-', 'Const.'),
        ('Step', 1, '.', '--', 'Step'),
        ('Noisy', 2, 'x', ':', 'Noisy'),
        ('Subthreshold', 3, '*', 'dashdot', 'Subth.'),
    )

    for name, color_idx, marker, ls, label in known_styles:
        if stim == name:
            return dict(color=pltu.neuron2color(color_idx), marker=marker, ls=ls, label=label, **shared)

    return dict(color='k', marker='.', ls='-', label=stim, **shared)

In [None]:
def plot_R(ax, data_rows, mode='N', plot_kw=None, lim_xticks=False):
    """Plot a MAE ratio against log10 of the perturbation parameter.

    Parameters
    ----------
    ax : axes object to draw on.
    data_rows : rows exposing `pert_param` and the `MAE_ratio_*` columns
        as attributes.
    mode : {'N', 'D', 'ND'}
        Selects `MAE_ratio_RN`, `MAE_ratio_RD` or `MAE_ratio_RNRD`.
    plot_kw : dict, optional
        Keyword arguments for the main line. Defaults to a black solid
        line with '.' markers.
    lim_xticks : bool
        If True, restrict the x ticks to the smallest and largest
        log10(pert_param) and pad the x limits by 0.4.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported values. This is checked
        before anything is drawn.

    Notes
    -----
    The y axis is fixed to [0, 1.1]. Ratios above 1.1 are drawn at 1.1 and
    additionally marked with open circles.
    """
    # (column holding the ratio, axis label) per mode. The raw string keeps
    # the backslash of the math-text thin space without an invalid-escape
    # warning.
    modes = {
        'N': ('MAE_ratio_RN', "R_N"),
        'D': ('MAE_ratio_RD', "R_D"),
        'ND': ('MAE_ratio_RNRD', r"R_N R_D \,"),
    }
    if mode not in modes:
        raise ValueError(f"mode must be one of {sorted(modes)}, got {mode!r}")

    if plot_kw is None:
        plot_kw = dict(marker='.', ls='-', c='k')

    column, ylabel = modes[mode]
    ydata = getattr(data_rows, column)

    y_clip = 1.1
    log_pert_param = np.log10(data_rows.pert_param)

    ax.axhline(1, color='gray', ls=':', zorder=-100, clip_on=False)
    ax.set_yticks([0,1])
    ax.set_ylim([0,y_clip])

    ax.plot(log_pert_param, np.minimum(ydata, y_clip), clip_on=False, **plot_kw)

    # Mark the points that were clipped to the top of the axis.
    clipped = ydata > y_clip
    ax.plot(log_pert_param[clipped], np.minimum(ydata[clipped], y_clip),
            clip_on=False, color='k', marker='o', fillstyle='none', alpha=0.3)

    # Reference line at pert_param = 1.
    ax.axvline(0, color='darkred', ls='-', zorder=-100)
    ax.set_ylabel(pltu.text2mathtext(ylabel), labelpad=-5)
    ax.set_xlabel(pltu.text2mathtext("log_10 " + f"({pltu.pert_param_symbol})"))

    if lim_xticks:
        max_lpp = log_pert_param.max()
        min_lpp = log_pert_param.min()

        # Always keep the reference line at 0 inside the x range.
        ax.set_xlim(min_lpp-0.4, np.maximum(max_lpp, 0)+0.4)
        ax.set_xticks([min_lpp, max_lpp])
        ax.set_xticklabels([f"{min_lpp}".replace(".0", ""), f"{max_lpp}".replace(".0", "")])

## Figure

In [None]:
# Summary figure: R_N * R_D as a function of the perturbation parameter.
# The axes grid is indexed below as axs[(method, adaptive) group,
# (pert_method, step_param) group].
fig, axs = pltu.subplots(
    df.method.nunique(), 6, xsize='text',
    yoffsize=0.4, ysizerow=0.9, squeeze=False, sharey='row',
    gridspec_kw=dict(width_ratios=[1,1.2]*3)
)

pltu.tight_layout(h_pad=1.4, w_pad=0.2, rect=[0.12,0.11,1.02,0.94])

### Plot data ###
for i, ((method, adaptive), supergroup) in enumerate(df.groupby(by=["method", "adaptive"], sort=False)):
    pltu.row_title(axs[i,0], pltu.method2label(method=method, adaptive=adaptive), pad=35)
    
    for j, ((pert_method, step_param), group) in enumerate(supergroup.groupby(by=["pert_method", "step_param"], sort=False)):

        ax = axs[i,j]

        ax.set_title(pltu.step_param2label(step_param=step_param, adaptive=adaptive).replace('ms', ''), fontsize=plt.rcParams['font.size'], pad=4)

        # Group by the column name, not a one-element list: with a list key
        # pandas >= 2.0 yields 1-tuples such as ('Constant',), which would
        # never match a named stimulus in stim2kwargs.
        for stim, subgroup in group.groupby(by="stim", sort=False):
            plot_R(ax=ax, data_rows=subgroup, mode='ND', plot_kw=stim2kwargs(stim), lim_xticks=True)
            ax.set_xlabel(pltu.text2mathtext("log_10 " + f"({pltu.pert_param_symbol})"))

        # Below the last row, label every pair of columns with its
        # perturbation method.
        if i+1 == df.method.nunique() and j%2==0:
            if pert_method == 'conrad':
                ttl = 'State pert.' + '\n' + ' '
            elif pert_method == 'abdulle':
                ttl = 'Step-size pert.' + '\n' + 'uniform'
            elif pert_method == 'abdulle_ln':
                ttl = 'Step-size pert.' + '\n' + 'log-normal'
            else:
                ttl = 'error'

            ax.text(1.15, -1.45, ttl, transform=ax.transAxes, va='center', ha='center', fontsize=plt.rcParams['axes.titlesize'])
        
            
### Decorate ###
# Axis labels only on the bottom row and the first column.
for ax in axs[:-1,:].flat: ax.set_xlabel(None)
for ax in axs[:,1:].flat: ax.set_ylabel(None)

# Give every axes the width of the first one, keeping its own position and height.
box0 = np.array(axs.flat[0].get_position().bounds)
for ax in axs.flat: 
    box = np.array(ax.get_position().bounds)
    ax.set_position([box[0], box[1], box0[2], box[3]])

# Panel letters above the first row.
for i, ax in enumerate(axs[0,:]):
    panel_num = 'ABCDEFGHI'[i]
    ax.text(
        0.5, 1.5, panel_num, va='bottom', ha='center', fontweight='bold',
        transform=ax.transAxes, fontsize=plt.rcParams['axes.titlesize']
    )        
    
sns.despine()
pltu.move_xaxis_outward(axs, scale=2)
axs[0,0].legend(loc='upper right', frameon=False, bbox_to_anchor=(-0.35,0.11), handlelength=2.4)

pltu.savefig("calibration_lineplots")
plt.show()
pltu.show_saved_figure(fig)

## Text

In [None]:
# Largest R_N*R_D ratio over all rows matching RKBS (adaptive, step_param 1e-2),
# constant stimulus, 'conrad' perturbation.
frame_utils.get_data_rows(df=df, method='RKBS', adaptive=1, step_param=1e-2, stim='Constant', pert_method='conrad').MAE_ratio_RNRD.max()

In [None]:
# Largest R_N*R_D ratio over all rows matching RKDP (adaptive, step_param 1e-2),
# constant stimulus, 'conrad' perturbation.
frame_utils.get_data_rows(df=df, method='RKDP', adaptive=1, step_param=1e-2, stim='Constant', pert_method='conrad').MAE_ratio_RNRD.max()

# Plot - Histograms and Sigmoid

## Plot functions

In [None]:
def plot_MAE_hist(ax, data_row=None, log=False):
    """Plot histograms of the SS and SR MAEs and mark the DR MAE.

    Parameters
    ----------
    ax : axes object to draw on.
    data_row : row with an `n_samples` attribute and the items 'MAE_SS',
        'MAE_SR' and 'MAE_DR'. Nothing is drawn if it is None or has
        `n_samples == 0`.
    log : bool
        If True, plot log10 of the MAEs instead of the raw values.

    Notes
    -----
    A 2-D 'MAE_SS' is treated as a pairwise matrix: only the entries above
    the diagonal are used. Each histogram also gets a diamond marker at its
    mean on the x axis. 'MAE_DR' is drawn as a dashed vertical line.
    NOTE(review): SS/SR/DR presumably stand for sample-sample,
    sample-reference and deterministic-reference — confirm.
    """
    if data_row is None or data_row.n_samples == 0:
        return

    # Extract data
    MAE_SS = data_row['MAE_SS']
    MAE_SR = data_row['MAE_SR']
    MAE_DR = data_row['MAE_DR']

    # Flatten a pairwise matrix to its strict upper triangle.
    if MAE_SS.ndim == 2:
        MAE_SS = MAE_SS[np.triu_indices(MAE_SS.shape[0], 1)]

    # Transform data?
    if log:
        MAE_SS = np.log10(MAE_SS)
        MAE_SR = np.log10(MAE_SR)
        MAE_DR = np.log10(MAE_DR)

    # Plot
    for MAE, label, color in zip((MAE_SS, MAE_SR), ('SS', 'SR'), ('red', 'C0')):
        sns.histplot(
            MAE, ax=ax, label=label, color=color, bins=15,
            stat="probability", element="step", fill=True, alpha=0.2
        )
        # Mean of the distribution, marked on the x axis.
        ax.plot(np.mean(MAE), 0, color=color, ls='-', marker='d')

    ax.axvline(MAE_DR, color='grey', label='DR', ls='--', zorder=1000)

## Figure

In [None]:
### Parameters ###
pert_methods = ['conrad', 'abdulle', 'abdulle_ln']
R_modes = ['N', 'D', 'ND']

### Create axes ###
n_R_cols = len(R_modes)

# One row per perturbation method; per row three histogram axes followed by
# n_R_cols ratio axes.
# NOTE(review): axs is indexed as axs[pert_method, column] below, so
# pltu.subplots presumably takes (n columns, n rows) — confirm in plot_utils.
fig, axs = pltu.subplots(
    3+n_R_cols, len(pert_methods), squeeze=False,
    sharex='col', sharey=False, xsize='text',
    ysizerow=0.9, yoffsize=0.4,
)

pltu.tight_layout(w_pad=-1.3, h_pad=2, rect=[0.1,0.07,0.89,0.93])

### Plot data ###
for ax_row, pert_method in zip(axs, pert_methods):
    
    # Row title
    if pert_method == 'conrad':
        ttl = 'State pert.' + '\n' + ' '
    elif pert_method == 'abdulle':
        ttl = 'Step-size pert.' + '\n' + 'uniform'
    elif pert_method == 'abdulle_ln':
        ttl = 'Step-size pert.' + '\n' + 'log-normal'
        
    pltu.row_title(ax_row[0], ttl, pad=16, rotation=90, va='center', ha='center', fontsize=plt.rcParams['font.size'])
    
    # Get data
    data_rows = frame_utils.get_data_rows(
        df=df, method='EEMP', adaptive=0, step_param=0.025, pert_method=pert_method, stim='Noisy'
    )
    
    # Plot ratios in the last n_R_cols axes of this row. The axes must come
    # from ax_row; a bare `ax` here would be whatever a previous loop left behind.
    for ax, R_mode in zip(ax_row[-n_R_cols:], R_modes):
        plot_R(ax=ax, data_rows=data_rows, mode=R_mode)
    
    # Plot histograms for three selected perturbation parameters in the
    # remaining (non-ratio) axes of this row.
    data_rows = data_rows[(data_rows.pert_param == 0.1) | (data_rows.pert_param == 1) | (data_rows.pert_param == 10)]
    for ax, (i, data_row) in zip(ax_row[:-n_R_cols], data_rows.iterrows()):
        
        ax.set_xlabel(pltu.text2mathtext("log_10 (MAE)"))
        
        # Column titles only on the first row.
        if ax_row[0] == axs[0,0]:
            title = ax.set_title(
                pltu.text2mathtext(f"{pltu.pert_param_symbol}=" + f"{data_row.pert_param:.1f}".replace('.0', ''))
            )
            if data_row.pert_param == 1.0: title.set_color('darkred')
        
        if data_row.n_samples == 0:
            ax.text(0.5, 0.5, 'No data', transform=ax.transAxes, ha='center', va='center')
        else:
            # plot_MAE_hist is defined in this notebook; no `plot_MAE`
            # module is imported.
            plot_MAE_hist(ax=ax, data_row=data_row, log=True)
    
### Decorate ###
pltu.set_labs(axs, panel_nums='auto', panel_num_va='center')

for ax in axs[:-1,:].flat:
    ax.set_xlabel(None)
for ax in axs[:,0].flat:
    ax.set_ylabel('Norm. count')
    ax.set_yticklabels([0,1])
for ax in axs[:,1:-n_R_cols].flat: ax.set_ylabel(None)
for ax in axs[:,:-n_R_cols].flat:
    ax.set_yticks([0, ax.get_ylim()[1]*0.95])
    ax.set_xticks([-2,0])
for ax in axs[:,1:-n_R_cols].flat: ax.set_yticklabels([])

# Histogram axes and ratio axes each share their own x limits.
pltu.make_share_xlims(axs[:,:-n_R_cols])
pltu.make_share_xlims(axs[:,-n_R_cols:])
    
fig.align_labels()
sns.despine()
                
pltu.move_xaxis_outward(axs, scale=2)

# Shift the three ratio columns to the right by increasing amounts.
pltu.move_box(axs=axs[:,-3], dx=+0.05)
pltu.move_box(axs=axs[:,-2], dx=+0.08)
pltu.move_box(axs=axs[:,-1], dx=+0.11)

axs[0,2].legend(loc='upper left', frameon=False, handlelength=0.7)

pltu.savefig('calibration_histograms')
pltu.show_saved_figure(fig)

# Appendix

In [None]:
# A bare `raise` with no active exception raises RuntimeError.
# NOTE(review): presumably intentional, to stop "Run All" before the
# appendix cells below — confirm.
raise

## Plot all histograms

In [None]:
### Plot data ###
# One figure per solver configuration, with one row of histograms per
# stimulus and one column per perturbation parameter.
for (method, adaptive, step_param, pert_method), group in\
        df.groupby(['method', 'adaptive', 'step_param', 'pert_method']):
    
    print(method, adaptive, step_param, pert_method)
    
    fig, axs = pltu.subplots(
        group.pert_param.nunique(), group.stim.nunique(),
        squeeze=True, xsize='fullwidth', sharey=False, sharex=True, ysizerow=0.8
    )
    
    # Group by the column name so `stim` is the plain stimulus name; a
    # one-element list key yields 1-tuples on pandas >= 2.0.
    for i, (stim, subgroup) in enumerate(group.groupby('stim')):
        
        pltu.row_title(axs[i,0], stim)
        
        for j, (_, data_row) in enumerate(subgroup.sort_values(['pert_param']).iterrows()):
            
            axs[i,j].set_title(f"{data_row.pert_param:.2f}", pad=0.) 
            if data_row.n_samples == 0:
                axs[i,j].text(0.5, 0.5, 'No data', transform=axs[i,j].transAxes, ha='center', va='center')
            else:
                # plot_MAE_hist is defined in this notebook; no `plot_MAE`
                # module is imported.
                plot_MAE_hist(ax=axs[i,j], data_row=data_row, log=True)
                
            axs[i,j].set_yticks([])
            axs[i,j].set_ylabel(None)
                
    plt.tight_layout()
    plt.show()
    # Only the first configuration is shown; remove the break to plot all.
    break