# STG 3

Simulate the parameterizations of the STG model shown in the original paper.

Simulate each parameterization with a highly accurate reference solver, and with probabilistic solvers at different step sizes (fixed-step methods) and tolerances (adaptive methods).

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from itertools import product as itproduct

In [None]:
from sys import path as sys_path
from os.path import abspath as os_path_abspath
sys_path.append(os_path_abspath('..'))
import addpaths

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import metric_utils
import frame_utils
import plot_utils as pltu

# Generator

In [None]:
import stg_model
# Compile the model's Cython code (per the helper's name) before any network is
# built or simulated in the cells below.
stg_model.compile_cython()

In [None]:
from data_generator_STG import data_generator_STG

def get_panel_generator(panel, tmax, n_samples=40):
    """Create the STG data generator for a single figure panel.

    A 3-neuron ``stg_model.stg_model`` network is built with ``g_params=panel``
    and wrapped in a ``data_generator_STG`` configured with base folder
    ``_data/uncertainty``; the generator's subfolder name is then updated
    from ``panel``.

    Parameters
    ----------
    panel : str
        Panel key; passed as ``g_params`` to the model and used for the
        subfolder name.
    tmax : float
        Simulation end time; multiplied by 1e3 before being handed to the
        generator (presumably seconds -> ms, as elsewhere in this notebook).
    n_samples : int
        Forwarded as ``n_samples`` to the generator.

    Returns
    -------
    data_generator_STG
        The configured generator.
    """
    network = stg_model.stg_model(n_neurons=3, g_params=panel, y0from1n=False)

    # NOTE(review): the 46 output indices presumably cover the full state
    # vector of the 3-neuron network -- TODO confirm against stg_model.
    generator_kwargs = dict(
        y0=network.y0,
        t0=0.0,
        tmax=tmax*1e3,
        acc_same_ts=False,
        vidx=None,
        yidxs=np.arange(46),
        max_step=10,
        model=network,
        n_samples=n_samples,
        gen_det_sols=True,
        gen_acc_sols=True,
        base_folder='_data/uncertainty',
        return_vars=['ys', 'events'],
    )
    panel_gen = data_generator_STG(**generator_kwargs)
    panel_gen.update_subfoldername(panel=panel)
    return panel_gen

<img src="_background/n3_traces.png">

In [None]:
# Simulation end time per figure panel (presumably seconds; get_panel_generator
# multiplies it by 1e3). Every panel runs for 2.5 except 'd' and 'i' (6).
panel2tmax = {panel_key: 2.5 for panel_key in 'abcdefghij'}
panel2tmax.update(d=6, i=6)

# Build one generator per simulated panel (only 'a'-'e' here) and load its
# stored accurate-solver ("acc") solutions.
panels = 'abcde'
gens = {}
for panel in panels:
    print(panel)
    panel_gen = get_panel_generator(panel=panel, tmax=panel2tmax[panel])
    gens[panel] = panel_gen
    panel_gen.load_acc_sols_from_file()

## Test

In [None]:
# Smoke test: a single adaptive RKDP run on panel 'e' with step_param 1e-3.
test_gen = gens['e']
sol = test_gen.gen_sol(
    method='RKDP',
    adaptive=1,
    step_param=1e-3,
    pert_method='conrad',
)

In [None]:
# Plot state indices 0, 13 and 26 of the smoke-test solution (the same
# indices that are collected in ``vidxs`` further down).
test_plot_idxs = [0, 13, 26]
sol.plot(y_idxs=test_plot_idxs)

## Data

In [None]:
# Solver settings to generate/load, one tuple per group:
#   (pert_method, adaptive, methods, step_params)
# The two non-adaptive groups (adaptive=0) sweep the same step_param values;
# the adaptive group (adaptive=1) sweeps 1e-3 down to 1e-7.
_fixed_step_params = (0.5, 0.25, 0.2, 0.1, 0.025, 0.01)
_adaptive_step_params = (1e-3, 1e-5, 1e-7)
solver_params = [
    ('conrad', 0, ['EE'], list(_fixed_step_params)),
    ('conrad', 0, ['EEMP'], list(_fixed_step_params)),
    ('conrad', 1, ['RKBS', 'RKDP'], list(_adaptive_step_params)),
]

In [None]:
# Generate and save data for every panel and every solver setting.
# Inside the errstate context, floating-point overflow / invalid operations
# raise errors instead of warnings; the previous numpy error settings are
# restored on exit, including when an exception escapes the loop.
with np.errstate(over='raise', invalid='raise'):
    for panel, gen in gens.items():

        print('----------------------------------------------------------')
        print(panel, ':', gen.subfoldername)
        print('----------------------------------------------------------')

        for pert_method, adaptive, methods, step_params in solver_params:
            # Every (step_param, method) combination of this group.
            for step_param, method in itproduct(step_params, methods):
                gen.gen_and_save_data(
                    method=method, adaptive=adaptive, step_param=step_param,
                    pert_method=pert_method, plot=False, allowgenerror=True,
                    overwrite=False
                )

# Load data

In [None]:
from data_generator_STG import data_loader_STG
import pandas as pd

# Load one results DataFrame per panel, tag it with the panel key, and
# concatenate once at the end. (DataFrame.append was deprecated in pandas 1.4
# and removed in 2.0; appending in a loop also re-copies the accumulated frame
# on every iteration.)
panel_dfs = []
for panel, gen in gens.items():
    panel_df = data_loader_STG(gen).load_data2dataframe(solver_params, MAEs=True, allowgenerror=True)
    metric_utils.add_det_nODEcalls(panel_df, T=gen.tmax-gen.t0)
    panel_df['panel'] = panel
    panel_dfs.append(panel_df)

# pd.concat raises on an empty list, so fall back to an empty frame.
df = pd.concat(panel_dfs, ignore_index=True) if panel_dfs else pd.DataFrame()
df.columns

# Plot

## Trace plots

In [None]:
# State indices plotted as voltages (ylabel 'v(t) (mV)') for the three neurons.
# NOTE(review): presumably each neuron contributes 13 state variables with its
# voltage first -- TODO confirm against stg_model.
vidxs = [13 * neuron for neuron in range(3)]

## Kernel density estimates

In [None]:
def plot_voltage(ax, data_row, event_idx, xlim=None):
    """Plot the reference voltage trace and the sample mean with uncertainty.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    data_row : result row (e.g. a pandas Series from ``DataFrame.iterrows``)
        Reads ``acc_ts``, ``acc_ys``, ``ts`` and ``ys``.
    event_idx : int
        Index into the module-level ``vidxs`` selecting the state to plot.
    xlim : (float, float) or None
        Time window in seconds; converted to ms before plotting.
    """
    ms_xlim = None if xlim is None else (xlim[0]*1e3, xlim[1]*1e3)  # s to ms
    vidx = vidxs[event_idx]

    # Reference solution.
    pltu.plot_sample_trace(
        ax=ax, ts=data_row.acc_ts, ys=data_row.acc_ys[:,vidx],
        tr_kw=dict(lw=0.4), label='ref.', xlim=ms_xlim
    )

    # ``ys`` is either one stacked array or a list of per-sample arrays
    # (presumably the latter for adaptive runs); take the selected column.
    if isinstance(data_row.ys, np.ndarray):
        sample_vs = data_row.ys[:,:,vidx]
    else:
        sample_vs = [ys_i[:,vidx] for ys_i in data_row.ys]
    pltu.plot_mean_and_uncertainty(
        ax=ax, ts=data_row.ts, ys=sample_vs,
        intpol_dt=1, xlim=ms_xlim, tr_mean_kw=dict(lw=0.4),
    )

    # NOTE(review): x data is in ms; 'Time (s)' assumes the caller rescales
    # tick labels (the summary figure uses pltu.scale_ticks for this).
    ax.set_ylabel('v(t) (mV)')
    ax.set_xlabel('Time (s)')

In [None]:
def plot_voltage_samples(ax, data_row, event_idx, smp_idxs=(0, 1, 2), xlim=None):
    """Plot selected individual voltage samples together with the reference trace.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    data_row : result row (e.g. a pandas Series from ``DataFrame.iterrows``)
        Reads ``adaptive``, ``ts``, ``ys``, ``acc_ts`` and ``acc_ys``.
    event_idx : int
        Index into the module-level ``vidxs`` selecting the state to plot.
    smp_idxs : sequence of int
        Sample indices to draw; the i-th listed sample is colored with
        ``pltu.neuron2color(i)``. Defaults to an immutable tuple (any sequence
        is accepted).
    xlim : (float, float) or None
        Time window in seconds; converted to ms before plotting.
    """
    if xlim is not None: xlim = (xlim[0]*1e3, xlim[1]*1e3) # s to ms

    vidx = vidxs[event_idx]

    for i, smp_idx in enumerate(smp_idxs):
        # Adaptive runs store one time grid per sample; otherwise ``ts`` is shared.
        pltu.plot_sample_trace(
            ax=ax, ts=data_row.ts[smp_idx] if data_row.adaptive else data_row.ts,
            ys=data_row.ys[smp_idx][:,vidx],
            tr_kw=dict(c=pltu.neuron2color(i), alpha=0.9, lw=0.7, ls='-'), xlim=xlim
        )

    # Reference trace drawn with a high zorder so it stays on top of the samples.
    pltu.plot_sample_trace(
        ax=ax, ts=data_row.acc_ts, ys=data_row.acc_ys[:,vidx],
        tr_kw=dict(lw=0.7, zorder=1000), label='ref.', xlim=xlim
    )

    # NOTE(review): x data is in ms; 'Time (s)' assumes the caller rescales tick labels.
    ax.set_ylabel('v(t) (mV)')
    ax.set_xlabel('Time (s)')

In [None]:
def plot_kde(ax, data_row, event_idx, xlim=None):
    """Plot the reference kernel density estimate and the sampled KDEs' mean/uncertainty.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    data_row : result row (e.g. a pandas Series from ``DataFrame.iterrows``)
        Reads ``kde_ts``, ``acc_kde`` and ``kdes``.
    event_idx : int
        Column of ``acc_kde`` and last-axis index of ``kdes`` to plot.
    xlim : (float, float) or None
        Time window in seconds; converted to ms before plotting.
    """
    if xlim is not None: xlim = (xlim[0]*1e3, xlim[1]*1e3) # s to ms

    # Reference KDE (``acc_kde``), then mean/uncertainty over the sample KDEs.
    pltu.plot_sample_trace(ax=ax, ts=data_row.kde_ts, ys=data_row.acc_kde[:,event_idx], xlim=xlim)
    pltu.plot_mean_and_uncertainty(ax=ax, ts=data_row.kde_ts, ys=data_row.kdes[:,:,event_idx], xlim=xlim)

    # NOTE(review): x data is in ms; 'Time (s)' assumes the caller rescales tick labels.
    ax.set_ylabel('Rate (Hz)')
    ax.set_xlabel('Time (s)')

### Test

In [None]:
# Example selection for the quick test figure: all RKBS rows of panel 'e'.
example_method, example_panel = 'RKBS', 'e'
data_rows = frame_utils.get_data_rows(df=df, method=example_method, panel=example_panel)
data_rows

In [None]:
# Quick visual check: for every selected result row, plot the voltage mean with
# uncertainty, samples 0/3/4, and the KDEs, all for event index 1.
n_results = data_rows.shape[0]
fig, axs = pltu.subplots(n_results, 3, xsize='fullwidth')
for ax_col, (_, data_row) in zip(axs.T, data_rows.iterrows()):
    plot_voltage(ax=ax_col[0], data_row=data_row, event_idx=1)
    plot_voltage_samples(
        ax=ax_col[1], data_row=data_row, event_idx=1,
        xlim=None, smp_idxs=[0, 3, 4],
    )
    plot_kde(ax=ax_col[2], data_row=data_row, event_idx=1)
plt.tight_layout()

## Summary plot

In [None]:
# Per panel: [full time window, zoomed window], both in seconds (the plot
# helpers convert to ms).
panel2xlim = {
    panel_key: [(0, t_end), zoom]
    for panel_key, t_end, zoom in (
        ('a', 2.5, (1, 2)),
        ('b', 2.5, (1, 2)),
        ('c', 2.5, (1.2, 2.2)),
        ('d', 6, (2.8, 3.8)),
        ('e', 2.5, (1.6, 2.6)),
    )
}

In [None]:
# Settings shared by all rows of the summary figure.
pert_method = 'conrad'
event_idx = 1

# Group by the scalar key 'panel' rather than ['panel']: when iterating, pandas
# >= 2.0 yields 1-tuples such as ('a',) as keys for a length-1 list. That would
# break the panel2xlim[panel] lookups and the panel != 'c' check further down.
# sort=False keeps panels in their order of appearance in df.
df_groups = df.groupby('panel', sort=False)

### Prepare axes ###
# One column per panel. Rows: (0) voltage mean/uncertainty, (1) voltage
# samples, (2) KDE for EE step 0.1, (3) KDE for EE step 0.01,
# (4) MAE_SR vs. deterministic ODE calls per ms.
sbnx = len(df_groups)
sbny = 5
fig, axs = pltu.subplots(
    sbnx, sbny,
    ysizerow=0.95, yoffsize=0.3, squeeze=False, xsize='fullwidth',
    gridspec_kw=dict(height_ratios=[0.7, 0.7, 0.5, 0.5, 1]),
)

# Methods appearing in the summary row; filled while plotting the summaries.
plot_methods = set()

### Plot traces ###
for idx, (panel, group) in enumerate(df_groups):
    ax_col = axs[:-1, idx]

    # One example row each for method 'EE' at step_param 0.1 and 0.01.
    ee_kwargs = dict(n_ex=1, pert_method=pert_method, adaptive=0, method='EE')
    data_row1 = frame_utils.get_data_rows(group, **ee_kwargs, step_param=0.1)
    data_row2 = frame_utils.get_data_rows(group, **ee_kwargs, step_param=0.01)

    full_xlim = panel2xlim[panel][0]
    zoom_xlim = panel2xlim[panel][1]
    # Hand-picked sample indices; panel 'c' uses a different order.
    chosen_samples = [3, 4, 0] if panel == 'c' else [3, 0, 4]

    plot_voltage(ax=ax_col[0], data_row=data_row1, event_idx=event_idx, xlim=full_xlim)
    plot_voltage_samples(
        ax=ax_col[1], data_row=data_row1, event_idx=event_idx, xlim=zoom_xlim,
        smp_idxs=chosen_samples,
    )
    plot_kde(ax=ax_col[-2], data_row=data_row1, event_idx=event_idx, xlim=full_xlim)
    plot_kde(ax=ax_col[-1], data_row=data_row2, event_idx=event_idx, xlim=full_xlim)

### Plot summaries ###
# Bottom row: per (method, adaptive, pert_param) group, MAE_SR of the chosen
# event against deterministic ODE calls per time, on log-log axes.
for idx, (panel, group) in enumerate(df_groups):
    ax = axs[-1, idx]
    by_solver = group.groupby(['method', 'adaptive', 'pert_param'], sort=False)
    for (method, adaptive, pert_param), subgroup in by_solver:
        plot_methods.add(method)

        # Only rows that actually produced samples contribute a point.
        valid_rows = [row for _, row in subgroup.iterrows() if row.n_samples > 0]
        pltu.plot_xy_percentiles(
            ax,
            datax=[row['det_nODEcalls_per_time'] for row in valid_rows],
            datay=[row['MAE_SR'][event_idx] for row in valid_rows],
            marker=pltu.method2marker(method), color=pltu.method2color(method),
            mean_kw=dict(alpha=0.6, ls=pltu.mode2ls(adaptive), mfc=pltu.mode2mfc(adaptive), ms=6),
            line_kw=dict(color='k', lw=0.5, alpha=1.0),
        )

    ax.grid(True, axis='both', alpha=.3, c='k', lw=plt.rcParams['ytick.major.width'])
    ax.set_xscale('log')
    ax.set_yscale('log')

### Decorate ###

# labels
# x label on the summary row, y label only on its first column, panel numbers everywhere.
pltu.set_labs(axs[-1,:], xlabs=r'N$_\mathrm{det.}$(ODE) (1/ms)')
pltu.set_labs(axs[-1,0], ylabs=r'$\mathrm{MAE}_\mathrm{SR}$ (Hz)')
pltu.set_labs(axs, panel_nums='auto', panel_num_space=2, panel_num_va='top')
# Keep y labels on the first column only.
for ax in axs[:,1:].flat: ax.set_ylabel(None)
# Clear x labels on rows 0-2; only row 3 (last KDE row) and the summary row keep them.
for ax in axs[:-2].flat: ax.set_xlabel(None)
    
# shared axis limits
# Summary row shares x and y limits; the top voltage row shares y limits;
# the two KDE rows share y limits separately for columns 0-1 and columns 2+.
pltu.make_share_ylims(axs[-1,:])
pltu.make_share_xlims(axs[-1,:])
pltu.make_share_ylims(axs[0,:])
pltu.make_share_ylims(axs[2:4,:2])
pltu.make_share_ylims(axs[2:4,2:])

# ticks
for ax in axs[2:4,:2].flat: ax.set_yticks([0,20]) 
for ax in axs[2:4,2:].flat: ax.set_yticks([0,60]) 
for ax in axs[0:2,:].flat: ax.set_yticks([-60, -30, 0, 30])   
# Trace rows are plotted in ms; presumably this relabels their x ticks by 1e-3
# so they match the 'Time (s)' labels -- confirm against pltu.scale_ticks.
pltu.scale_ticks(axs[:-1,:], scale=1e-3, x=True, y=False)

# alignment
fig.align_labels()
sns.despine()

pltu.move_xaxis_outward(axs, scale=2)
pltu.tight_layout(h_pad=-0.3, w_pad=0.2)
# Shift rows 1-3 by a growing offset (0.013, 0.026, 0.039) after tight_layout.
for i, ax_row in enumerate(axs[1:-1,:]): pltu.move_box(axs=ax_row, dy=0.013*(i+1))

# legend
# plot_methods was collected as a set; sort it and use the last method as the
# example entry of the method/mode legend on the first summary axes.
plot_methods = pltu.sort_methods(list(plot_methods))
pltu.make_method_and_mode_legend(
    ax=axs[-1,0], methods=plot_methods, example_method=plot_methods[-1], pert_method=pert_method,
    legend1_kw=dict(labelspacing=0.0, borderpad=0.0, loc='lower left', bbox_to_anchor=(0.00,-0.06), handlelength=0.4),
    legend2_kw=dict(labelspacing=0.0, borderpad=0.0, loc='lower left', bbox_to_anchor=(0.4,-0.06)),
)
# Legend of the top-left axes (plot_voltage labels the reference trace 'ref.').
axs[0,0].legend(
    loc='upper right', borderpad=0.15, frameon=True, framealpha=0.7,
    borderaxespad=0.0, bbox_to_anchor=(1,1.2)
)

# save and show
pltu.savefig("STG3_summary_runtime")
plt.show()

In [None]:
# Display the summary figure via the project helper (presumably re-displays the
# saved file rather than the live figure -- confirm in plot_utils).
pltu.show_saved_figure(fig)

# Text

In [None]:
# All EEMP result rows that produced at least one sample.
data_rows = frame_utils.get_data_rows(df=df, method='EEMP')
data_rows = data_rows.loc[data_rows['n_samples'] > 0]

In [None]:
# Print, per panel, the mean deterministic ODE calls per time of each EEMP run.
# Group by the scalar key 'panel' (not ['panel']) so that the printed key is the
# panel name itself; pandas >= 2.0 yields 1-tuples like ('a',) for a
# length-1 list key.
for panel, group in data_rows.groupby('panel'):
    print(panel)
    for _, data_row in group.iterrows():
        print(np.mean(data_row.det_nODEcalls_per_time))