# Chapter 4 Appendix - I/O Equivalence for Spiking Neuron Models

In this notebook, we provide the source code for all the figures appearing in the 
Chapter 4 Appendix of BMEBW4020, Circuits in Brain. 

We provide additional information on the geometric interpretation of the phase response in terms of isochrons. We also provide experimental results showing the approximation error of the Project-Integrate-Fire (PIF) neuron with respect to biological spike generators.

The purpose of this demo is to help readers understand the material covered 
in the chapter better and get more intuition by running the code. 
Readers are encouraged to modify the code to create more interesting phenomena. 

This notebook is based on MATLAB scripts and notebooks by members of the Bionet Lab at Columbia University: Wenze Li, Tingkai Liu, Mehmet Kerem Turkcan, Chung-Heng Yeh.

*Authors*: Shashwat Shukla <shashwat.shukla@columbia.edu>, Tingkai Liu <tl2747@columbia.edu>

*Copyright 2023*, Aurel A. Lazar, Tingkai Liu, Shashwat Shukla.

## Setup

In [2]:
%load_ext autoreload
%autoreload 2

# import libraries
import sys
from os.path import exists
import ipynbname
import re
import time
from collections import namedtuple
from tqdm.auto import tqdm
import IPython.display
import ipywidgets as widgets

import numpy as np
np.random.seed(0)  # fix random seed
from scipy import signal
from scipy.integrate import cumulative_trapezoid  # `cumtrapz` was removed in SciPy 1.14
from scipy.optimize import fsolve

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.colors import to_hex
from mpl_toolkits.mplot3d import Axes3D

import bokeh as bk
from bokeh.io import output_notebook
import bokeh.plotting
import bokeh.layouts
import bokeh.models
output_notebook()  # render Bokeh animations inline

import plotly
import plotly.graph_objects as go
#from plotting.plot_plotly import plot_3d, animate_traces
import plotly.io as pio
pio.renderers.default = 'notebook'



In [4]:
sys.path.insert(1, '../scripts/') # add path to tag handling script
from link_figures import save_tagged_fig  # used below to save every tagged figure
nb_name = ipynbname.name()
fig_dir = "../book_src/web_figures/{}/figures/".format(nb_name) # relative path to figures directory
fig_count = 0 # number of figures rendered so far

In [5]:
from compneuro.neurons.hodgkin_huxley import HodgkinHuxley
from compneuro.neurons.hodgkin_huxley_3state import HodgkinHuxley3State
from compneuro.neurons.hodgkin_huxley_rinzel import HodgkinHuxleyRinzel
from compneuro.neurons.leaky_integrate_fire import LeakyIntegrateFire
from compneuro.utils.phase_response import PIF, iPRC, solve_multiplicative
from compneuro.utils.signal import spike_detect
from compneuro.utils.neuron import limit_cycle, isochron
from plotting.plot_bokeh import plot_quiver
from plotting.plot_basic import plot_spikes


## Isochrons
Isochrons provide an intuitive interpretation of the perturbed solution. Informally, if the limit cycle solution $x^{0}$ is asymptotically stable, it attracts all nearby states and therefore quickly damps the orbital deviation $z(t)$. However, after the perturbed state converges back onto the limit cycle, the _phase shift_ caused by the perturbation remains and is generally not negligible. To build intuition about this phase shift, we introduce the concept of an isochron.

Isochrons are manifolds in the state space consisting of states that share the same asymptotic phase. In the plot below, we present isochrons of Rinzel's reduced Hodgkin-Huxley model. For each colored curve in the $(V, R)$ phase plane, all points along the curve converge onto the limit cycle with the same asymptotic phase.

**Note:** isochrons typically fill the entire state space (every state in the basin of attraction of the limit cycle lies on exactly one isochron); here, for illustration purposes, only 10 lines are plotted, with a phase interval of $0.2\pi$ between them.

The isochron plot illustrates a few important concepts:
1. The isochrons are drawn at equal intervals in phase. We observe a larger angular separation between them towards higher $V$ and a smaller one towards lower $V$. This is expected: the state moves faster through the high-$V$ part of the cycle, so an equal phase interval covers a longer stretch of the orbit there.
2. Isochrons are closely related to the phase change caused by a perturbation and to the consequent change in spike time. Given a current state $\mathbf{x}^{0}$ on the limit cycle, a positive perturbation in, for example, the voltage $V$ pushes the state horizontally to the right _onto_ a different isochron. The resulting phase shift is set by the number of isochrons crossed, i.e., by the local _density_ of isochrons along the direction of the perturbation.
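The Rinzel isochrons above have to be computed numerically. As a hedged illustration of the defining property, the sketch below uses a hypothetical stand-in, the Stuart-Landau oscillator $\dot z = (1+i\omega_0)z - (1+ic)|z|^2 z$, whose isochrons are known in closed form: in polar coordinates the asymptotic phase (in time units) is $\Theta = (\phi - c\ln r)/(\omega_0 - c)$, so each isochron is a logarithmic spiral. Two states placed on the same spiral, one inside and one outside the limit cycle $r = 1$, should end up at the same point once the transient has died out, while a state on a different isochron should not. The parameter values and helper names (`f`, `flow`, `on_isochron`) are illustrative assumptions, not part of the `compneuro` package.

```python
import numpy as np

OMEGA0, C = 2 * np.pi, 0.5   # hypothetical Stuart-Landau parameters (C = shear)
OMEGA = OMEGA0 - C           # angular frequency on the limit cycle r = 1


def f(s):
    """Stuart-Landau vector field in Cartesian coordinates."""
    x, y = s
    r2 = x * x + y * y
    return np.array([x - OMEGA0 * y - r2 * (x - C * y),
                     y + OMEGA0 * x - r2 * (y + C * x)])


def flow(s0, t_end=6.0, dt=1e-3):
    """Integrate the unperturbed dynamics forward with classical RK4."""
    s = np.array(s0, dtype=float)
    for _ in range(int(round(t_end / dt))):
        k1 = f(s)
        k2 = f(s + dt / 2 * k1)
        k3 = f(s + dt / 2 * k2)
        k4 = f(s + dt * k3)
        s = s + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return s


def on_isochron(theta, r):
    """State at radius r on the isochron of asymptotic phase theta (time units),
    i.e. on the log-spiral phi - C ln r = OMEGA * theta."""
    phi = OMEGA * theta + C * np.log(r)
    return np.array([r * np.cos(phi), r * np.sin(phi)])


inside = flow(on_isochron(0.2, 0.6))    # starts inside the limit cycle
outside = flow(on_isochron(0.2, 1.5))   # starts outside, on the same isochron
other = flow(on_isochron(0.5, 1.5))     # starts on a different isochron
print(np.linalg.norm(inside - outside))  # tiny: same asymptotic phase
print(np.linalg.norm(inside - other))    # O(1): the phase difference persists
```

The spiral shape comes from the shear term $c$: states far from the cycle rotate at a different speed while they relax, so points of equal asymptotic phase do not lie on a straight ray.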

### Geometric Interpretation of $\psi(t,I)$ 
To interpret $\psi(t,I)$ we look again at the equation:
\begin{equation}
\dot{\theta} = 1 + \psi(\theta(t),I)u(t)
\end{equation}
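Using the chain rule, with $y$ denoting the state and $f(y,I)$ the unperturbed vector field, we can rewrite the left-hand side as
\begin{equation}
\dot{\theta} = \nabla_{y} \theta \cdot \frac{dy}{dt} = \nabla_{y} \theta \cdot [f(y,I) + u(t)]
\end{equation}
By definition, $\dot{\theta}=1$ along the unperturbed flow, i.e., $\nabla_{y} \theta \cdot f(y,I) = 1$, which gives us
\begin{equation}
    \dot{\theta} = 1+ \nabla_{y} \theta \cdot u(t)
\end{equation}

For a weak perturbation the state stays close to the limit cycle $x^{0}$, so to first order $\nabla_{y} \theta$ can be evaluated at $y = x^{0}(\theta(t))$. Comparing with the equation for $\dot{\theta}$ above, we obtain
\begin{equation}
    \nabla_{y} \theta \big|_{y = x^{0}(\theta(t))} \equiv \psi(\theta(t),I)
\end{equation}
Here $\nabla_{y} \theta$ is the rate of change of the phase $\theta$ in response to a change in the state $y$, evaluated on the limit cycle at phase $\theta(t)$. This has a nice geometric interpretation in terms of _isochrons_: $\nabla_{y} \theta$ is normal to the isochrons and its magnitude is their local density, so the iPRC $\psi(\theta,I)$ is large where isochrons are closely packed along the direction of the perturbation.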
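The identity $\nabla_{y}\theta \equiv \psi$ can be checked numerically on a toy model. The sketch below again assumes the Stuart-Landau oscillator as a hypothetical stand-in for the Rinzel neuron, because its asymptotic phase $\Theta = (\phi - c\ln r)/(\omega_0 - c)$ has a closed-form gradient. It estimates the asymptotic phase the way it is defined, by integrating forward until the orbit has relaxed onto the limit cycle and subtracting the elapsed time, then compares a central finite difference along $x$ (the analog of a voltage kick) with the analytic gradient. It also confirms $\nabla_{y}\theta \cdot f = 1$ on the cycle. All names and parameter values are illustrative.

```python
import numpy as np

OMEGA0, C = 2 * np.pi, 0.5      # hypothetical Stuart-Landau parameters
OMEGA = OMEGA0 - C              # angular frequency on the limit cycle r = 1
T = 2 * np.pi / OMEGA           # period


def f(s):
    """Stuart-Landau vector field in Cartesian coordinates."""
    x, y = s
    r2 = x * x + y * y
    return np.array([x - OMEGA0 * y - r2 * (x - C * y),
                     y + OMEGA0 * x - r2 * (y + C * x)])


def asymptotic_phase(s0, t_relax=6.0, dt=1e-3):
    """Integrate forward (RK4) until the orbit has relaxed onto the cycle, read off
    its phase there and subtract the elapsed time (phase in time units, mod T)."""
    s = np.array(s0, dtype=float)
    n = int(round(t_relax / dt))
    for _ in range(n):
        k1 = f(s)
        k2 = f(s + dt / 2 * k1)
        k3 = f(s + dt / 2 * k2)
        k4 = f(s + dt * k3)
        s = s + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return (np.arctan2(s[1], s[0]) / OMEGA - n * dt) % T


def grad_theta(s):
    """Analytic gradient of (phi - C ln r) / OMEGA in Cartesian coordinates."""
    x, y = s
    r2 = x * x + y * y
    return np.array([-y - C * x, x - C * y]) / (OMEGA * r2)


phi0 = np.pi / 3
x0 = np.array([np.cos(phi0), np.sin(phi0)])   # a point on the limit cycle
eps = 1e-4
kick = np.array([eps, 0.0])                   # small "voltage" kick along x
fd = (asymptotic_phase(x0 + kick) - asymptotic_phase(x0 - kick)) / (2 * eps)
print(fd, grad_theta(x0)[0])                  # finite difference vs. analytic psi
print(grad_theta(x0) @ f(x0))                 # ~1: d(theta)/dt = 1 on the cycle
```

The finite difference is exactly the "kick and wait" experiment used to measure a PRC, so its agreement with the gradient illustrates why the iPRC can be read off the isochron geometry.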

In [4]:
I_inj = 10
hhr = HodgkinHuxleyRinzel()
rinzel_lc_t, rinzel_lc, x0_iso, t_iso, iso = isochron(
    hhr, I_ext=I_inj, dist_to_lc_threshold=1e-1, N_skip_cycles=5, dt=1e-5, Npoints=100
)


In [5]:
# compute gradients
VV, RR = np.meshgrid(np.linspace(-80, 50, 30), np.linspace(0, 1.0, 30))
rinzel_gradients = hhr.ode(0.0, np.vstack([VV.ravel(), RR.ravel()]), I_inj)
d_VV, d_RR = (
    rinzel_gradients[0].reshape(VV.shape),
    rinzel_gradients[1].reshape(VV.shape),
)

# plot line plots
line_figs = []
ylims = dict(V=[-90, 50], R=[0.4, 1.0])
for n, (lab) in enumerate(hhr.states.keys()):
    p = bk.plotting.figure(
        height=250,
        width=400,
        y_range=ylims[lab],
        x_axis_label="Time [sec]",
        y_axis_label=lab,
        title=f"{lab}(t)",
    )
    if(n > 0):
        p.x_range = line_figs[-1].x_range
    else:
        p.xaxis.visible = False
    p.line(rinzel_lc_t, rinzel_lc[n], line_width=2)
    p.title.text_font_size = '12pt'
    p.title.align = "center"
    p.xaxis.axis_label_text_font_style = "normal"
    p.yaxis.axis_label_text_font_style = "normal"
    line_figs.append(p)

# plot phase plane with quiver
p = bk.plotting.figure(
    height=350, width=400, x_axis_label="V [mV]", y_axis_label="R", title="Phase Plane"
)
plot_quiver(p, VV.ravel(), RR.ravel(), d_VV.ravel(), d_RR.ravel())
p.line(rinzel_lc[0], rinzel_lc[1], line_color="black", line_width=2)
p.title.text_font_size = '12pt'
p.title.align = "center"
p.xaxis.axis_label_text_font_style = "normal"
p.yaxis.axis_label_text_font_style = "normal"

p.scatter(*x0_iso, color="red", size=10)
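# sample the colormap evenly over [0, 1]; integer inputs would index a 256-entry
# lookup table and saturate for longer isochron arrays
cs = mpl.colormaps.get_cmap("coolwarm")(np.linspace(0, 1, len(iso)))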
for tt in np.arange(iso.shape[0])[::30]:
    p.line(*iso[tt], line_width=2, color=mpl.colors.to_hex(cs[tt]))
p.x_range = line_figs[0].y_range
p.y_range = line_figs[1].y_range

p_phi = bk.plotting.figure(
    height=150, width=400, x_axis_label="t [sec]", y_axis_label="ϕ", title="Phase", tools="box_zoom,reset"
)
p_phi.line(rinzel_lc_t, np.linspace(0, np.pi, len(rinzel_lc_t)), line_color="black")
for tt in np.arange(iso.shape[0])[::30]:
    p_phi.scatter(
        rinzel_lc_t[tt],
        np.linspace(0, np.pi, len(rinzel_lc_t))[tt],
        color=mpl.colors.to_hex(cs[tt]),
        size=10,
    )
p_phi.title.align = "center"
p_phi.title.text_font_size = '12pt'
p_phi.xaxis.axis_label_text_font_style = "normal"
p_phi.yaxis.axis_label_text_font_style = "normal"

fig = bk.layouts.row([bk.layouts.column(line_figs), bk.layouts.column([p, p_phi])])

# render figure
bk.io.show(fig)
# specify figure caption and tag
tag = "rinzel-isochron"
caption = "Isochrons of Rinzel's reduced Hodgkin-Huxley model."
# save figure with name specified by tag
save_tagged_fig(tag, fig, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

## PIF Approximation Error vs. Injected Current
Intuitively, the smaller the perturbation relative to the injected current $I$, the better the PRC-based approximation of the PIF neuron.
In the figure below, we see that the absolute differences between the spike times of the neuron and of its PIF approximation grow with the amplitude of the input current. Note also from the red error bars (mean ± one standard deviation) that the spread of the error increases with the input current amplitude.
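The same trend can be reproduced on a toy oscillator whose iPRC is known exactly. The sketch below is a hedged illustration, not the procedure used for the figure: it assumes the Stuart-Landau oscillator as a stand-in for the Rinzel neuron, lets a hypothetical sinusoidal input enter the $x$ equation (the analog of a current into $V$), and defines a spike as the moment the asymptotic phase $\Theta = (\phi - c\ln r)/(\omega_0 - c)$ wraps through the period $T$. The PIF neuron integrates $\dot\theta = 1 + \psi(\theta)u(t)$, with $\psi$ the $x$-component of $\nabla\Theta$ on the limit cycle, and fires when $\theta$ reaches $T$. Since the PIF is a first-order approximation in the input, its spike-time error should shrink roughly quadratically as the amplitude $a$ decreases.

```python
import numpy as np

OMEGA0, C = 2 * np.pi, 0.5        # hypothetical Stuart-Landau parameters
OMEGA = OMEGA0 - C                # angular frequency on the limit cycle r = 1
T = 2 * np.pi / OMEGA             # period of the unperturbed oscillation


def f(s, u):
    """Stuart-Landau vector field; the input u enters the x-equation only."""
    x, y = s
    r2 = x * x + y * y
    return np.array([x - OMEGA0 * y - r2 * (x - C * y) + u,
                     y + OMEGA0 * x - r2 * (y + C * x)])


def phase(s):
    """Asymptotic phase (phi - C ln r) / OMEGA in time units, wrapped to [0, T)."""
    x, y = s
    return ((np.arctan2(y, x) - 0.5 * C * np.log(x * x + y * y)) / OMEGA) % T


def psi(theta):
    """x-component of the iPRC, i.e. of grad(phase) on the limit cycle."""
    return (-np.sin(OMEGA * theta) - C * np.cos(OMEGA * theta)) / OMEGA


def spike_times(a, n_periods=5.5, dt=2e-3):
    """Spike times of the full model and of its PIF approximation for amplitude a."""
    u = lambda t: a * np.sin(2 * np.pi * 0.7 * t)   # hypothetical test stimulus
    pif_rhs = lambda t, th: 1 + psi(th) * u(t)      # PIF: d(theta)/dt = 1 + psi * u
    s, th = np.array([1.0, 0.0]), 0.0               # both start at phase 0
    p_prev, true_tk, pif_tk = phase(s), [], []
    for k in range(int(n_periods * T / dt)):
        t = k * dt
        # full model: one RK4 step, spike when the asymptotic phase wraps past T
        k1 = f(s, u(t))
        k2 = f(s + dt / 2 * k1, u(t + dt / 2))
        k3 = f(s + dt / 2 * k2, u(t + dt / 2))
        k4 = f(s + dt * k3, u(t + dt))
        s = s + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        p_new = phase(s)
        if p_new < p_prev - T / 2:
            true_tk.append(t + dt * (T - p_prev) / (p_new + T - p_prev))
        p_prev = p_new
        # PIF neuron: one RK4 step, spike (and subtract T) at threshold T
        m1 = pif_rhs(t, th)
        m2 = pif_rhs(t + dt / 2, th + dt / 2 * m1)
        m3 = pif_rhs(t + dt / 2, th + dt / 2 * m2)
        m4 = pif_rhs(t + dt, th + dt * m3)
        th_new = th + dt / 6 * (m1 + 2 * m2 + 2 * m3 + m4)
        if th_new >= T:
            pif_tk.append(t + dt * (T - th) / (th_new - th))
            th_new -= T
        th = th_new
    n = min(len(true_tk), len(pif_tk))
    return np.array(true_tk[:n]), np.array(pif_tk[:n])


for a in (0.01, 0.1, 0.3):
    tk, tk_pif = spike_times(a)
    print(f"a = {a:.2f}: mean |spike-time error| = {np.mean(np.abs(tk - tk_pif)):.2e}")
```

Defining the spike through the asymptotic phase isolates the error of the PIF reduction itself; with a threshold on $x$ (closer to a voltage-based spike detector) the comparison would also include the offset between the geometric and the asymptotic phase.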

In [10]:
# stimulus composed of weighted sinc functions
def sinc_stimulus(t, omega, K=15):
    out = np.zeros_like(t)
    omega_pi = omega / np.pi
    for k in range(1, K + 1):
        out += np.random.rand() * omega_pi * np.sinc(omega_pi * t - k)
    return out / np.max(out)

In [7]:
dt = 1e-6
I_inj = 50
neu = HodgkinHuxleyRinzel()
t_lc, lc, prc = iPRC(neu, I_ext=I_inj, dt=dt, verbose=False)
period = t_lc[-1]
t = np.arange(-0.1, 0.2, dt)
omega = 2 * np.pi * 50
u = sinc_stimulus(t, omega, K=15)

In [8]:
a_range = np.linspace(1, 20, 10)
neu = HodgkinHuxleyRinzel(num=len(a_range))
neu_res = neu.solve(t, I_ext=I_inj + np.outer(a_range, u).T, verbose=True)
neu_spike_mask = spike_detect(neu_res["V"], height=20)


In [9]:
err_arr = []
trim = 4  # discard the first and last few spikes (transients and edge effects)
n_lc = len(t_lc)
for n, ss in enumerate(neu_spike_mask):
    (tk_idx,) = np.where(ss)
    u_n = a_range[n] * u
    starts = tk_idx[trim:-trim]
    # we create a 2D array of stimulus of shape (len(starts), 2 * n_lc);
    # each row in this array corresponds to the stimulus from a given spike
    # up to 2 limit cycles later.
    u_chunks = np.zeros((len(starts), 2 * n_lc))
    for chunk_i, idx in enumerate(starts):
        _u_chunk = u_n[idx + 1 : min(idx + 1 + 2 * n_lc, len(t))]
        u_chunks[chunk_i, : len(_u_chunk)] = _u_chunk

    # For each input stimulus, we integrate the PIF voltage within these 2 limit cycles
    # and find the sample at which the voltage first crosses the threshold (period).
    # Sample j of the cumulative sum covers (j + 1) * dt, hence the "+ 1".
    # The resulting time is the PIF Inter-Spike-Interval.
    # The spike time error of the PIF is therefore the difference between the
    # Inter-Spike-Intervals of the biophysical neuron model and of the PIF,
    # both measured from the same spike (the ISIs are sliced to match `starts`).
    V_pif_chunk = dt * np.cumsum(1 + u_chunks * np.tile(prc[0], 2)[None, :], axis=1)
    spk_diff_pif = np.array([np.where(_V >= period)[0][0] + 1 for _V in V_pif_chunk])
    isi_neu = np.diff(tk_idx)[trim : len(tk_idx) - trim]
    err_arr.append((isi_neu - spk_diff_pif) * dt)

mean_err_arr = np.array([np.mean(np.abs(err)) for err in err_arr])
std_err_arr = np.array([np.std(np.abs(err)) for err in err_arr])

# flatten into (amplitude, |error|) pairs for the scatter plot
x_arr = np.repeat(a_range, [len(err) for err in err_arr])
y_arr = np.abs(np.concatenate(err_arr))

In [10]:
plot = bk.plotting.figure(width=600, height=400)
plot.circle(x=x_arr, y=1e3 * y_arr, color="green")
plot.title.text = "PIF spike-time errors as a function of input amplitude"
plot.title.text_font_size = '12pt'
plot.title.align = "center"
plot.xaxis.axis_label = "Input current amplitude [mA]"
plot.xaxis.axis_label_text_font_style = "normal"
plot.yaxis.axis_label = "Spike-time difference (absolute value) [ms]"
plot.yaxis.axis_label_text_font_style = "normal"
error_db = bk.models.ColumnDataSource(
    data=dict(
        base=a_range,
        lower=1e3 * (mean_err_arr - std_err_arr),
        upper=1e3 * (mean_err_arr + std_err_arr),
    )
)
plot.add_layout(
    bk.models.Whisker(
        source=error_db,
        base="base",
        upper="upper",
        lower="lower",
        line_width=3,
        line_color="red",
    )
)


# render figure
bk.io.show(plot)
# specify figure caption and tag
tag = "pif-error"
caption = "Mean and standard deviation of the absolute PIF spike-time error."
# save figure with name specified by tag
save_tagged_fig(tag, plot, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

## I/O Equivalence of HHN and IAF with Variable Threshold

In [11]:
dt = 1e-5
t = np.arange(0., 0.04, dt)
omega = 2 * np.pi * 100
u = .5 * np.sin(t * omega)
b = 1.
I_injected = 50

In [12]:
t_lc, lc = limit_cycle(HodgkinHuxleyRinzel(), I_ext=I_injected, dt=dt)


In [13]:
t_lc.max() - t_lc.min()

0.00313

In [14]:
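# find the times of the zero-crossings of the voltage trace along the limit cycle
# (the trace is interpolated onto a 10 ns grid for precision)
t_lc_long = np.arange(t_lc.min(), t_lc.max(), 1e-8)
V_lc_long = np.interp(t_lc_long, t_lc, lc[0])
delta_idx, = np.where(np.diff(np.sign(V_lc_long)))
delta_ks = t_lc_long[delta_idx]
# the two zero-crossings split one period into two intervals, which serve as the
# alternating thresholds of the IAF neuron (assumes exactly 2 crossings per cycle)
thres = np.array([np.diff(delta_ks).item(), (t_lc.max()-t_lc.min()-np.diff(delta_ks)).item()])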

In [15]:
# initialize the neuron model at the first zero-crossing
init_idx =int(t_lc_long[delta_idx][0]//dt)
hhr = HodgkinHuxleyRinzel(V=lc[0, init_idx], R=lc[1, init_idx])
res_noinput = hhr.solve(t, I_ext=I_injected + u + b, verbose=True, solver='Euler')
res_multiply = solve_multiplicative(hhr, t, stimulus=b+u, I_ext=I_injected, verbose=True)
V_hhn = res_multiply['V'][0]

# zero-crossings of the HHN with multiplicative coupling
t_long = np.arange(t.min(), t.max(), 1e-8)
V_long = np.interp(t_long, t, V_hhn)
hhn_tk_idx, = np.where(np.diff(np.sign(V_long)))
hhn_tk = t_long[hhn_tk_idx]


In [16]:
V_iaf = np.zeros_like(V_hhn)
tk_idx_iaf = []
for tt, _u in enumerate(u):
    if tt == 0:
        continue
    new_V = V_iaf[tt-1] + dt * (b + _u)
    # threshold value alternates between the two zero-crossing intervals
    if new_V > thres[len(tk_idx_iaf) % len(thres)]:
        new_V = 0
        tk_idx_iaf.append(tt)
    V_iaf[tt] = new_V
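
The IAF cell above rests on a time-rescaling argument: if $x^{0}(\tau)$ solves $dx/d\tau = f(x)$, then under multiplicative coupling $dx/dt = (b+u(t))f(x)$ the solution is $x(t) = x^{0}\left(\int_0^t (b+u(s))\,ds\right)$, provided $b+u>0$. The trigger times are therefore the times at which $\int (b+u)$ reaches the next limit-cycle crossing interval, which is what an IAF neuron with alternating thresholds computes. The sketch below checks this on a hypothetical toy oscillator (a uniform rotation at 5 Hz, with triggers at crossings of $x = 0.5$ so that the two intervals are $1/3$ and $2/3$ of the period); all names and values are illustrative assumptions.

```python
import numpy as np

W = 2 * np.pi * 5.0    # hypothetical base oscillator: uniform rotation at 5 Hz
B = 1.0                # bias b
LEVEL = 0.5            # trigger: crossings of x = LEVEL


def u(t):
    """Hypothetical stimulus with |u| < B, so that b + u(t) > 0."""
    return 0.5 * np.sin(2 * np.pi * 3.0 * t)


def rhs(t, s):
    """Multiplicative coupling: dx/dt = (b + u(t)) * f(x), f = rotation at W."""
    return (B + u(t)) * W * np.array([-s[1], s[0]])


def oscillator_triggers(t_end=1.0, dt=1e-4):
    s, tk = np.array([1.0, 0.0]), []               # start at angle 0
    for k in range(int(round(t_end / dt))):
        t = k * dt
        k1 = rhs(t, s)
        k2 = rhs(t + dt / 2, s + dt / 2 * k1)
        k3 = rhs(t + dt / 2, s + dt / 2 * k2)
        k4 = rhs(t + dt, s + dt * k3)
        s_new = s + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        g0, g1 = s[0] - LEVEL, s_new[0] - LEVEL
        if g0 * g1 < 0:                            # crossing inside this step
            tk.append(t + dt * g0 / (g0 - g1))
        s = s_new
    return np.array(tk)


def iaf_triggers(t_end=1.0, dt=1e-4):
    # Crossings of x = 0.5 sit at angles pi/3 and 5*pi/3, so after the initial
    # gap of pi/3 the gaps alternate 4*pi/3, 2*pi/3. Dividing by W turns them
    # into thresholds in units of the integral of (b + u).
    thresholds = np.array([np.pi / 3] + [4 * np.pi / 3, 2 * np.pi / 3] * 20) / W
    V, i, tk = 0.0, 0, []
    for k in range(int(round(t_end / dt))):
        t = k * dt
        V_new = V + dt / 2 * ((B + u(t)) + (B + u(t + dt)))   # trapezoidal rule
        if V_new >= thresholds[i]:
            tk.append(t + dt * (thresholds[i] - V) / (V_new - V))
            V_new -= thresholds[i]                 # reset, keeping the overshoot
            i += 1
        V = V_new
    return np.array(tk)


tk_osc, tk_iaf = oscillator_triggers(), iaf_triggers()
print(len(tk_osc), np.max(np.abs(tk_osc - tk_iaf)))   # same count, tiny mismatch
```

Subtracting the threshold instead of resetting to exactly zero (as the notebook cell does) keeps the sub-step overshoot, which removes a small bias that otherwise accumulates over many spikes.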

In [17]:
TOOLS = "pan, box_zoom, reset"

fig1 = bk.plotting.figure(
    width=700,
    height=250,
    #     x_axis_label="Time [ms]",
    y_axis_label="u(t)+b",
    title="Input Stimulus",
    tools=TOOLS,
)
fig1.line(1e3 * t, u + b, color="red", line_width=2)
fig1.title.text_font_size = '12pt'
fig1.title.align = "center"
fig1.xaxis.axis_label_text_font_style = "normal"
fig1.yaxis.axis_label_text_font_style = "normal"


fig2 = bk.plotting.figure(
    width=700,
    height=250,
    #     x_axis_label="Time [ms]",
    y_axis_label="V[mV]",
    title=f"Hodgkin-Huxley Neuron w/ Multiplicative Coupling, Injected Current={I_injected} μA",
    tools=TOOLS,
)
fig2.line(1e3 * t, V_hhn.T, color="red", line_width=2)
fig2.circle(1e3 * t_long[hhn_tk_idx], V_long[hhn_tk_idx], color="green", size=10)
fig2.title.text_font_size = '12pt'
fig2.title.align = "center"
fig2.xaxis.axis_label_text_font_style = "normal"
fig2.yaxis.axis_label_text_font_style = "normal"


fig3 = bk.plotting.figure(
    width=700,
    height=250,
    x_axis_label="Time [ms]",
    y_axis_label="V",
    title="I/O Equivalent IAF Neuron with Variable Threshold",
    tools=TOOLS,
)
fig3.line(1e3 * t, V_iaf, color="red", line_width=2)
fig3.circle(1e3 * t[tk_idx_iaf], V_iaf[tk_idx_iaf], color="green", size=10)
fig3.title.text_font_size = '12pt'
fig3.title.align = "center"
fig3.xaxis.axis_label_text_font_style = "normal"
fig3.yaxis.axis_label_text_font_style = "normal"

fig2.x_range = fig1.x_range  # sync x-axis pan and zoom
fig3.x_range = fig1.x_range  # sync x-axis pan and zoom

fig = bk.layouts.column([fig1, fig2, fig3])

# render figure
bk.io.show(fig)
# specify figure caption and tag
tag = "hhn-iaf-equiv"
caption = "I/O equivalence between HHN under multiplicative coupling and IAF neuron with variable threshold, with trigger times defined as zero crossings."
# save figure with name specified by tag
save_tagged_fig(tag, fig, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

### Computing the iPRC using Malkin's Method
In this section, we show the iPRCs of the Hodgkin-Huxley and Rinzel neurons driven by a constant injected current. The iPRCs are obtained with Malkin's method, which linearizes the neuron dynamics around the limit cycle and computes the iPRC as the periodic solution of the corresponding adjoint equation.

**Note**: the iPRCs are normalized so that the largest absolute value of each of the phase response curves is `1`. This normalization is not required for our definitions and is done for visualization purposes.
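For a system whose iPRC is known analytically, Malkin's method can be sketched in a few lines. The example below assumes the Stuart-Landau oscillator as a hypothetical stand-in for the HH and Rinzel models (which the notebook handles with `iPRC` from `compneuro`; this is not that implementation). It integrates the adjoint equation $\dot Q = -J(x^{0}(t))^{\top} Q$ backward in time, so that the non-periodic component decays and only the periodic solution survives, and normalizes the result by $Q \cdot f(x^{0}) = 1$, the condition $\dot\theta = 1$ on the cycle. For this model the result can be checked against the closed-form gradient $\nabla\Theta = \left[(-y, x) - c\,(x, y)\right]/(\omega_0 - c)$ on the unit circle. Function names and parameter values are illustrative.

```python
import numpy as np


def sl_jacobian(x, y, omega0, c):
    """Jacobian of the Stuart-Landau vector field
    dx/dt = x - omega0*y - r^2 (x - c*y),  dy/dt = y + omega0*x - r^2 (y + c*x)."""
    r2 = x * x + y * y
    return np.array([
        [1 - r2 - 2 * x * x + 2 * c * x * y, -omega0 - 2 * x * y + c * r2 + 2 * c * y * y],
        [omega0 - 2 * x * y - c * r2 - 2 * c * x * x, 1 - r2 - 2 * y * y - 2 * c * x * y],
    ])


def malkin_iprc(omega0=2 * np.pi, c=0.5, n_periods=10, n_steps=1000):
    """iPRC via the adjoint equation dQ/dt = -J(x0(t))^T Q, integrated backward in
    time so that all but the periodic solution decay; normalized so Q . f(x0) = 1."""
    omega = omega0 - c                     # angular frequency on the cycle r = 1
    T = 2 * np.pi / omega
    h = -T / n_steps                       # negative step: backward integration

    def rhs(t, Q):
        x, y = np.cos(omega * t), np.sin(omega * t)   # limit cycle x0(t)
        return -sl_jacobian(x, y, omega0, c).T @ Q

    def rk4(t, Q):
        k1 = rhs(t, Q)
        k2 = rhs(t + h / 2, Q + h / 2 * k1)
        k3 = rhs(t + h / 2, Q + h / 2 * k2)
        k4 = rhs(t + h, Q + h * k3)
        return Q + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    t, Q = n_periods * T, np.array([1.0, 1.0])      # arbitrary final condition
    for _ in range((n_periods - 1) * n_steps):      # transient: let it converge
        Q, t = rk4(t, Q), t + h
    prc = np.empty((n_steps + 1, 2))                # one period, t = 0 ... T
    prc[n_steps] = Q
    for k in range(n_steps - 1, -1, -1):
        Q, t = rk4(t, Q), t + h
        prc[k] = Q
    f0 = omega * np.array([0.0, 1.0])               # f(x0(0)) at x0(0) = (1, 0)
    return T, prc / (prc[0] @ f0)


T, prc = malkin_iprc()
print(prc[0])   # compare with the analytic gradient (-c, 1) / (omega0 - c)
```

Normalizing at a single phase is enough because $Q \cdot f(x^{0}(t))$ is constant along any solution of the adjoint equation; backward integration is used because the transverse adjoint mode grows forward in time and decays backward.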

In [18]:
# compute phase response using Malkin's method
hhn_limit_cycle_t_m, hhn_limit_cycle_m, hhn_prc_m = iPRC(
    HodgkinHuxley(), I_ext=15, dt=1e-5, verbose=True,
)
hhn_period_m = hhn_limit_cycle_t_m[-1]
hhr_limit_cycle_t_m, hhr_limit_cycle_m, hhr_prc_m = iPRC(
    HodgkinHuxleyRinzel(), I_ext=15, dt=1e-5, verbose=True,
)
hhr_period_m = hhr_limit_cycle_t_m[-1]

hhn_prc_m_norm = hhn_prc_m / np.max(np.abs(hhn_prc_m), axis=1)[:,None]
hhr_prc_m_norm = hhr_prc_m / np.max(np.abs(hhr_prc_m), axis=1)[:,None]


In [19]:
ds = int(10)  # downsampling factor
figs = []
colors = ["red", "green", "skyblue", "blue"]
for t_lc, prc, Model, model_name in zip(
    [hhn_limit_cycle_t_m, hhr_limit_cycle_t_m],
    [hhn_prc_m_norm, hhr_prc_m_norm],
    [HodgkinHuxley, HodgkinHuxleyRinzel],
    ['Hodgkin-Huxley', 'Rinzel']
):
    data = bk.models.ColumnDataSource(
        {
            **{'x': np.linspace(0, 1, len(t_lc))[::ds]}, # 1e3 * t_lc[::ds]
            **{
                key: prc[n,::ds]
                for n,key 
                in enumerate(Model.Default_States.keys())
            }
        }
    )

    fig = bk.plotting.figure(
        width=400,
        height=400,
    )
    if(len(figs) > 0):
        fig.x_range = figs[0].x_range
        fig.y_range = figs[0].y_range
    fig.title.text = f"iPRC - {model_name} Neuron"
    fig.title.text_font_size = '12pt'
    fig.title.align = "center"
    fig.xaxis.axis_label = f"Normalized Phase - Period {t_lc.max()*1e3:.1f} ms"
    fig.xaxis.axis_label_text_font_style = "normal"
    fig.yaxis.axis_label = "Normalized iPRC ψ"
    fig.yaxis.axis_label_text_font_style = "normal"

    for n, key in enumerate(Model.Default_States.keys()):
        fig.line(
            x="x",
            y=key,
            source=data,
            line_width=2,
            color=colors[n],
            legend_label=key,
        )
    fig.legend.location = "bottom_left"
    fig.legend.background_fill_alpha = 0.8
    figs.append(fig)
layout = bk.layouts.grid([figs])


# render figure
bk.io.show(layout)
# specify figure caption and tag
tag = "iprc-hhn-rinzel"
caption = "Infinitesimal Phase Response Curves (iPRCs) of the Hodgkin-Huxley and Rinzel Neurons."
# save figure with name specified by tag
save_tagged_fig(tag, layout, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)