# Izhikevich neuron network | Examples

Simulate and plot example runs of the Izhikevich neuron network (INN).

In [None]:
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

from itertools import product as itproduct

In [None]:
# Make the parent directory (resolved relative to the current working
# directory) importable so the project-local `addpaths` module can be found.
# NOTE(review): importing `addpaths` presumably extends sys.path further for
# the project modules used below (plot_utils, data_generator_INN, ...) — confirm.
from sys import path as sys_path
from os.path import abspath as os_path_abspath
sys_path.append(os_path_abspath('..'))
import addpaths

In [None]:
# IPython autoreload, mode 2: reload imported modules before executing code,
# so edits to the project modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import plot_utils as pltu

# Model

In [None]:
# Simulated time window; the figure below labels this axis "Time (ms)".
t0 = 0
tmax = 70

In [None]:
import izhikevich_network
# Build the network model with DEBUG enabled and plot it. Later cells read
# network.y0, network.N and network.spike_kernel from this object.
network = izhikevich_network.network(DEBUG=True)
network.plot()

# Generator

In [None]:
from data_generator_INN import data_generator_INN

# Data generator for 40 samples of the network on [t0, tmax], starting from
# network.y0, with n_parallel=2 workers. Returned variables are the states
# ('ys') and spike events ('events'); files live under base_folder.
# NOTE(review): gen_det_sols=False / gen_acc_sols=True presumably skip the
# deterministic solutions and enable accurate reference solutions, which are
# loaded from disk right after construction — confirm.
gen = data_generator_INN(
    t0=t0,
    y0=network.y0,
    tmax=tmax,
    model=network,
    n_parallel=2,
    n_samples=40,
    gen_det_sols=False,
    gen_acc_sols=True,
    t_eval_adaptive=None,
    return_vars=['ys', 'events'],
    base_folder='_data/INN_uncertainty_example',
)
gen.update_subfoldername()
gen.load_acc_sols_from_file()

## Data

In [None]:
# Solver configuration reused by the get_solver, gen_and_save_data and
# load_data2dict calls below.
method, adaptive, step_param, pert_method = 'FE', 0, 0.5, 'conrad'

In [None]:
# Solver instance for the configuration above.
solver = gen.get_solver(
    method=method,
    adaptive=adaptive,
    step_param=step_param,
    pert_method=pert_method,
)

In [None]:
# Generate and save samples for the configuration above.
# NOTE(review): overwrite=False presumably keeps existing files on disk.
gen.gen_and_save_data(
    method=method,
    adaptive=adaptive,
    step_param=step_param,
    pert_method=pert_method,
    plot=True,
    overwrite=False,
)

## Load data

In [None]:
from data_generator_INN import data_loader_INN

# Load the saved samples into a dict. The figure below reads the keys
# 'ts', 'ys', 'events', 'kde_ts', 'kdes' and 'acc_kde' from it.
data_dict = data_loader_INN(gen).load_data2dict(
    method=method,
    adaptive=adaptive,
    step_param=step_param,
    pert_method=pert_method,
    MAEs=False,
)

# Plot

## Plot functions

In [None]:
# Four sample neurons spread over the population, at fractions 1/8, 3/8, 5/8
# and 7/8 of the neuron index range. Derived from network.N (also used by
# plot_events for its y-limits) rather than a hard-coded 1000, so the
# indices stay valid if the network size changes. For N=1000 this gives
# [125, 375, 625, 875], identical to the previous hard-coded values.
nidxs = [int(network.N * k / 8) for k in (1, 3, 5, 7)]

def plot_traces(ax, ts, vs, nidxs):
    """Draw the v(t) traces of the selected neurons.

    The x-range is fixed to the notebook-level window (t0, tmax) and the
    y-range to [-80, 40]. For each neuron index in ``nidxs``, column
    ``vs[:, neuron_idx]`` is plotted against ``ts`` and colored by its
    position in ``nidxs`` via ``pltu.neuron2color``.
    """
    ax.set_xlim(t0, tmax)
    ax.set_ylim(-80, 40)

    for color_idx, neuron_idx in enumerate(nidxs):
        trace = vs[:, neuron_idx]
        line_color = pltu.neuron2color(color_idx)
        ax.plot(ts, trace, color=line_color, clip_on=True, alpha=0.85)


def plot_events(ax, events, nidxs):
    """Plot a spike raster of all neurons and highlight selected neurons.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes. x-limits are set to (t0, tmax) and y-limits to
        (0, network.N), using the notebook-level globals.
    events : sequence of 1-D sequences
        ``events[i]`` holds the event times of neuron ``i``. The neuron
        index ``i`` is used as the y coordinate.
    nidxs : sequence of int
        Neurons drawn on top with an 'x' marker, each colored by its
        position in ``nidxs`` via ``pltu.neuron2color``.
    """
    ax.set_xlim(t0, tmax)
    ax.set_ylim(0, network.N)

    # Background raster: draw every (time, neuron) point with ONE scatter
    # call instead of one call per neuron. The points, marker, size, color
    # and zorder are the same, but this creates a single artist instead of
    # len(events) artists, which is much cheaper to build and render.
    event_arrays = [np.asarray(elist, dtype=float) for elist in events]
    if event_arrays:  # np.concatenate rejects an empty list
        counts = [arr.size for arr in event_arrays]
        all_times = np.concatenate(event_arrays)
        all_neurons = np.repeat(np.arange(len(event_arrays)), counts)
        ax.scatter(all_times, all_neurons, marker='.', zorder=0, color='darkgray', s=0.3)

    # Highlighted neurons, drawn above the raster.
    for i, nidx in enumerate(nidxs):
        ax.scatter(
            events[nidx], np.full(len(events[nidx]), nidx),
            marker='x', zorder=10, color=pltu.neuron2color(i), alpha=0.9
        )
    
    
def plot_kde(ax, kde_ts, kde):
    """Draw one KDE curve in black on a rate axis spanning 0 to 65.

    Y ticks are placed at 0, 25 and 50. Clipping is disabled, so the curve
    may be drawn beyond the axes limits.
    """
    ax.set_ylim(bottom=0, top=65)
    rate_ticks = [0, 25, 50]
    ax.set_yticks(rate_ticks)
    ax.plot(kde_ts, kde, color='k', clip_on=False)
    
    
def plot_kde_traces(ax, kde_ts, kdes):
    """Overlay every KDE in ``kdes`` as a thin, semi-transparent black line.

    Uses the same rate axis as plot_kde: limits 0 to 65, ticks at 0/25/50,
    and no clipping.
    """
    ax.set_ylim(bottom=0, top=65)
    ax.set_yticks([0, 25, 50])
    line_style = dict(lw=0.5, color='k', alpha=0.5, clip_on=False)
    for single_kde in kdes:
        ax.plot(kde_ts, single_kde, **line_style)
    
    
def plot_kde_summary(ax, kde_ts, kdes, acc_kde):
    """Compare an ensemble of KDEs against a reference KDE.

    The reference ``acc_kde`` is drawn first as a sample trace labelled
    'ref.'. The ensemble ``kdes`` is then summarized with
    ``pltu.plot_mean_and_uncertainty`` (presumably a mean line plus an
    uncertainty band; see plot_utils). The rate axis matches plot_kde.
    """
    ax.set_ylim(bottom=0, top=65)
    ax.set_yticks([0, 25, 50])
    reference = acc_kde
    pltu.plot_sample_trace(ax, kde_ts, reference, label='ref.')
    pltu.plot_mean_and_uncertainty(ax, kde_ts, kdes)

## Figure

In [None]:
# Samples shown in the three left columns (indices into data_dict['ys'],
# data_dict['events'] and data_dict['kdes']).
smpidxs = [11, 12, 13]

### Prepare plot ###
# 4 columns x 2 rows. Row 0 (height ratio 0.4) holds voltage traces; row 1
# holds rasters and KDEs. NOTE(review): assumes pltu.subplots takes
# (n_columns, n_rows), which is consistent with height_ratios having 2 entries.
sbnx = 4
sbny = 2
fig, axs = pltu.subplots(sbnx, sbny, gridspec_kw=dict(height_ratios=[0.4, 1]), ysizerow=1.)

# Twin y-axes carrying the KDE rates, appended in this order: one per sample
# column, then the top-right panel, then the bottom-right panel.
taxs = []

### Plot data ###
# Columns 0-2 (all but the last): traces on top; raster plus that sample's
# KDE (on a twin y-axis) below.
for smpidx, ax_col in zip(smpidxs, axs.T[:-1]):
    plot_traces(ax_col[0], data_dict['ts'], data_dict['ys'][smpidx], nidxs=nidxs)
    plot_events(ax_col[1], data_dict['events'][smpidx], nidxs=nidxs)
    tax = ax_col[1].twinx()
    taxs.append(tax)
    plot_kde(tax, data_dict['kde_ts'], data_dict['kdes'][smpidx])

# Plot single KDE traces (top-right panel: the KDEs of all samples overlaid)
tax = axs[0,-1].twinx()
taxs.append(tax)
plot_kde_traces(ax=tax, kde_ts=data_dict['kde_ts'], kdes=data_dict['kdes'])

# Plot summary KDE (bottom-right panel: all KDEs summarized vs. the reference KDE)
tax = axs[1,-1].twinx()
taxs.append(tax)
plot_kde_summary(ax=tax, kde_ts=data_dict['kde_ts'], kdes=data_dict['kdes'], acc_kde=data_dict['acc_kde'])

### Decorate ###
# NOTE(review): the sns.despine(..., left=0) loop over axs[0,:] below makes
# this spine visible again, and the loop over axs[:,-1] hides it once more,
# so this line appears to have no lasting effect.
axs[0,-1].spines['left'].set_visible(False)
for ax in axs[1,:]: sns.despine(ax=ax, top=1, right=1, left=0, bottom=0)
for ax in taxs: sns.despine(ax=ax, top=1, right=0, left=0, bottom=1)

# Only the two right-column twin axes get a rate label.
for ax in taxs[-2:]: ax.set_ylabel('Rate (Hz)')

# Top row: hide the top, right and bottom spines and the x ticks.
for ax in axs[0,:]:
    sns.despine(ax=ax, top=1, right=1, left=0, bottom=1)
    ax.set_xticks([])

# Twin axes keep only their right spine. Among the main axes, only column 0
# keeps y tick labels, and the right column has no left y axis at all.
for ax in taxs: ax.spines['left'].set_visible(False)
for ax in axs[:,1:].flat: ax.set_yticklabels([])
for ax in axs[:,-1]:
    ax.set_yticks([])
    ax.spines['left'].set_visible(False)

# Per-sample KDE twin axes: ticks without labels.
for ax in taxs[:-2]: ax.set_yticklabels([])

# NOTE(review): presumably gives the bottom-row axes a common x range — confirm in plot_utils.
pltu.make_share_xlims(axs[-1,:])

pltu.move_xaxis_outward(axs)
pltu.set_labs(axs[-1,:], xlabs='Time (ms)')
pltu.set_labs(axs[1,0], ylabs='Neuron')
pltu.set_labs(axs[0,0], ylabs='v(t)')
pltu.set_labs(axs[0,:], panel_nums='auto', panel_num_va='center')
pltu.tight_layout(w_pad=0.5, h_pad=0.5)

fig.align_labels()

# Legend of the summary panel (includes the 'ref.' trace).
taxs[-1].legend(loc='upper left', borderpad=0., bbox_to_anchor=(-0.1,1.1))

pltu.savefig("INN_examples")
plt.show()
pltu.show_saved_figure(fig)

## Appendix

In [None]:
# Solver for the timing benchmarks below. It uses a smaller step_param
# (0.1 instead of the 0.5 used above) and rebinds `solver`.
solver = gen.get_solver(
    method=method,
    adaptive=adaptive,
    step_param=0.1,
    pert_method=pert_method,
)

In [None]:
# Solve up to tmax. NOTE(review): presumably this populates the solver state
# (y, y_new, last_spike_times) that the benchmark cells below read — confirm.
solver.solve(tmax=tmax)

In [None]:
# Time a single eval_odefun call at t=tmax with the network's initial state.
%timeit solver.eval_odefun(t=tmax, y=network.y0)

In [None]:
# Time one solver step with the currently configured prestepfun (printed first).
print(solver.prestepfun)
%timeit solver.step(step_tmax=tmax+solver.h0)

In [None]:
# Same benchmark without the pre-step hook, for comparison.
# NOTE: this permanently clears solver.prestepfun for all later cells.
solver.prestepfun = None
%timeit solver.step(step_tmax=tmax+solver.h0)

In [None]:
# Time the spike kernel evaluated at the time elapsed since the solver's
# last spike times (presumably one entry per neuron — confirm).
%timeit network.spike_kernel((tmax-solver.last_spike_times))

In [None]:
# No spikes --> no spike kernel evaluation
# NOTE(review): presumably presolvefun resets the solver's spike state, so
# this times eval_odefun without spike-kernel cost — confirm.
solver.presolvefun(solver)
%timeit solver.eval_odefun(t=tmax, y=network.y0)

In [None]:
# Time dense_eval_at_y for the state component with the largest value in
# y_new, queried at the midpoint between that component's values in y and y_new.
idx = solver.y_new.argmax()
%timeit solver.dense_eval_at_y(np.mean([solver.y[idx], solver.y_new[idx]]), yidx=idx)