# Chapter 4 - I/O Equivalence for Spiking Neuron Models

In this notebook, we provide the source code for all the figures appearing in the 
Chapter 4 of BMEBW4020, Circuits in Brain. 

We illustrate the dynamics of multiplicatively and additively coupled Hodgkin-Huxley neurons and demonstrate the computation of the Phase Response Curves using Winfree's Method and the infinitesimal Phase Response Curves using Malkin's Method. We also describe the geometric interpretation of these phase response and also study the corresponding phase response manifolds.

The purpose of this demo is to help readers understand the material covered 
in the chapter better and get more intuition by running the code. 
Readers are encouraged to modify the code to create more interesting phenomena. 

This notebook is written based MATLAB scripts and notebooks by members of the Bionet Lab at Colubmia University: Wenze Li, Tingkai Liu, Mehmet Kerem Turkcan, Chung-Heng Yeh.

*Authors*: Shashwat Shukla <shashwat.shukla@columbia.edu>, Tingkai Liu <tl2747@columbia.edu>

*Copyright 2023*, Aurel A. Lazar, Tingkai Liu, Shashwat Shukla.

## Setup

In [11]:
%load_ext autoreload
%autoreload 2

# import libraries
import sys
from os.path import exists
import ipynbname
import re
import time
from collections import namedtuple
from tqdm.auto import tqdm
from tqdm.notebook import tqdm
import IPython.display
import ipywidgets as widgets
from itertools import cycle
from collections import OrderedDict

import numpy as np
np.random.seed(0)  # fix random seed
from scipy.optimize import fsolve

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

import bokeh as bk
import bokeh.layouts
import bokeh.models
import bokeh.plotting
from bokeh.io import output_notebook
from bokeh.palettes import all_palettes as palette
output_notebook()  # render Bokeh animations inline

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
pio.renderers.default='notebook' # render Plotly animations inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
sys.path.insert(1, '../scripts/') # add path to tag handling script
from link_figures import save_tagged_fig
nb_name = ipynbname.name()
fig_dir = "../book_src/web_figures/{}/figures/".format(nb_name) # relative path to figures directory
fig_count = 0 # number of figures rendered so far

ModuleNotFoundError: No module named 'link_figures'

In [7]:
from compneuro.neurons.hodgkin_huxley import HodgkinHuxley
from compneuro.neurons.hodgkin_huxley_3state import HodgkinHuxley3State
from compneuro.neurons.hodgkin_huxley_rinzel import HodgkinHuxleyRinzel
from compneuro.neurons.leaky_integrate_fire import LeakyIntegrateFire
from compneuro.utils.phase_response import PIF, iPRC, PRC, solve_multiplicative
from compneuro.utils.signal import spike_detect, spike_detect_local
from compneuro.utils.neuron import limit_cycle, isochron
from plotting.plot_basic import plot_spikes
from plotting.plot_bokeh import plot_quiver

ModuleNotFoundError: No module named 'compneuro.utils.phase_response'

In [1]:
# stimulus composed of weighted sinc functions
def sinc_stimulus(t, omega, K=15):
    out = np.zeros_like(t)
    omega_pi = omega / np.pi
    for k in range(1, K + 1):
        out += np.random.rand() * omega_pi * np.sinc(omega_pi * t - k)
    return out / np.max(out)

In [6]:
class HHN_f:
    """
    1. Define functions for alpha and beta of n,m,and h.
    2. Define functions for gating function and time constant of n,m,and h.
    """

    @classmethod
    def a_n(cls, v):
        return (10 - v + 1e-8) / (1e-7 + 100 * (np.exp((10 - v) / 10) - 1))

    @classmethod
    def b_n(cls, v):
        return 0.125 * np.exp(-v / 80)

    @classmethod
    def a_m(cls, v):
        return (25 - v + 1e-8) / (1e-8 + 10 * (np.exp((25 - v) / 10) - 1))

    @classmethod
    def b_m(cls, v):
        return 4 * np.exp(-v / 18)

    @classmethod
    def a_h(cls, v):
        return 0.07 * np.exp(-v / 20)

    @classmethod
    def b_h(cls, v):
        return 1 / (np.exp((30 - v) / 10) + 1)

    @classmethod
    def n_inf(cls, v):
        return cls.a_n(v) / (cls.a_n(v) + cls.b_n(v))

    @classmethod
    def n_tau(cls, v):
        return 1 / (cls.a_n(v) + cls.b_n(v))

    @classmethod
    def m_inf(cls, v):
        return cls.a_m(v) / (cls.a_m(v) + cls.b_m(v))

    @classmethod
    def m_tau(cls, v):
        return 1 / (cls.a_m(v) + cls.b_m(v))

    @classmethod
    def h_inf(cls, v):
        return cls.a_h(v) / (cls.a_h(v) + cls.b_h(v))

    @classmethod
    def h_tau(cls, v):
        return 1 / (cls.a_h(v) + cls.b_h(v))

# I/O Equivalence for Hodgkin-Huxley Neurons with Multiplicative Coupling

In the following, we show that the phase portraits of a neuron without stimulus, and with stimulus with multiplicative coupling are the same, though the voltage traces in the two cases are different.

In [4]:
dt = 1e-6
t = np.arange(-0.01, 0.04, dt)
omega = 2 * np.pi * 50
I_ext = 2 + sinc_stimulus(t, omega, K=15)
I_inj = 30

In [5]:
hhr = HodgkinHuxleyRinzel()
res_noinput = hhr.solve(t, I_ext=I_inj * np.ones_like(t), verbose=True)
res_multiply = solve_multiplicative(hhr, t, stimulus=I_ext, I_ext=I_inj, verbose=True)

NameError: name 'HodgkinHuxleyRinzel' is not defined

In [9]:
hhr = HodgkinHuxleyRinzel()
# We analyze the phase response of Rinzel Model for dV/dt = 0 and dR/dt = 0.
# When dV/dt = 0, R is the root of the equation I_ext - I_Na - I_K - I_leak = 0.

# Determine the slope S of the Rinzel approximation
S = (1 - HHN_f.h_inf(0)) / HHN_f.n_inf(0)

# Compute R
V1 = np.arange(-75, 45, 0.3)  # set voltage range
R_dV_zero = np.zeros(len(V1))

for i in range(len(V1)):
    # define the equation I_ext - I_Na - I_K - I_leak = 0 for each voltage.
    def func(x):
        return (
            I_inj
            - 120 * HHN_f.m_inf(V1[i] + 65) ** 3 * (1 - x) * (V1[i] - 50)
            - 36 * (x / S) ** 4 * (V1[i] + 77)
            - 0.3 * (V1[i] + 54.37)
        )

    # Solve the equation using fsolve.
    R_dV_zero[i] = fsolve(func, 0.5)

# When dR/dt = 0, R = R_infinity(V)
V2 = np.arange(-95, -15, 0.3)  # set voltage range
R_dR_zero = S * (HHN_f.n_inf(V2 + 65) + S * (1 - HHN_f.h_inf(V2 + 65))) / (1 + S ** 2)

# Here we get the vector field at many different state points
# that we will plot later for intuition.

# choose same limits as the limits of plot
Vgrid, Rgrid = np.meshgrid(np.arange(-100.1, 50 + 5, 5), np.arange(0, 1 + 0.05, 0.05))
dV, dR = hhr.ode(states=[Vgrid[None, ...], Rgrid[None, ...]], t=0.0, I_ext=0.0)

# Normalize vector fields according to their range
dV /= np.max(np.abs(dV))
dR /= np.max(np.abs(dR))

xmin = -100
xmax = 60
ymin = 0
ymax = 1
sc1 = 0.04
sc2 = (xmax - xmin) / (ymax - ymin)
xs = Vgrid
ys = Rgrid
xe = xs + sc1 * sc2 * dV
ye = ys + sc1 * dR

In [10]:
# Plot the result

## static plots

d1 = bk.models.ColumnDataSource(
    data=dict(vr=res_noinput["V"][0], r=res_noinput["R"][0])
)
d2 = bk.models.ColumnDataSource(data=dict(v1=V1, r_dv=R_dV_zero))
d3 = bk.models.ColumnDataSource(data=dict(v2=V2, r_dr=R_dR_zero))
d4 = bk.models.ColumnDataSource(
    data=dict(vr=res_multiply["V"][0], r=res_multiply["R"][0])
)
quiver = bk.models.ColumnDataSource(
    data=dict(x_start=xs[np.newaxis,...], y_start=ys[np.newaxis,...], x_end=xe, y_end=ye)
)
trace1 = bk.models.ColumnDataSource(data=dict(t=1e3 * t, vr=res_noinput["V"][0]))
trace2 = bk.models.ColumnDataSource(data=dict(t=1e3 * t, vr=res_multiply["V"][0]))

plot1 = bk.plotting.figure(
    x_range=(0, 20), y_range=(-100, 60), plot_width=400, plot_height=400
)
plot1.title.text = "Voltage trace without stimulus"
plot1.title.text_font_size = '12pt'
plot1.title.align = "center"
plot1.xaxis.axis_label = "Time [ms]"
plot1.xaxis.axis_label_text_font_style = "normal"
plot1.yaxis.axis_label = "Membrane Potential [mV]"
plot1.yaxis.axis_label_text_font_style = "normal"
plot1.line(x="t", y="vr", source=trace1, color="red", line_width=2)

plot2 = bk.plotting.figure(
    x_range=(xmin, xmax), y_range=(ymin, ymax), plot_width=400, plot_height=400
)
plot2.title.text = "Phase plane without stimulus"
plot2.title.text_font_size = '12pt'
plot2.title.align = "center"
plot2.xaxis.axis_label = "V[mV]"
plot2.xaxis.axis_label_text_font_style = "normal"
plot2.yaxis.axis_label = "R"
plot2.yaxis.axis_label_text_font_style = "normal"
plot2.add_layout(
    bk.models.Arrow(
        end=bk.models.NormalHead(size=4),
        source=quiver,
        x_start="x_start",
        y_start="y_start",
        x_end="x_end",
        y_end="y_end",
    )
)

plot2.line(
    x="vr", y="r", source=d1, color="red", legend_label="Limit Cycle", line_width=2
)
plot2.line(
    x="v1", y="r_dv", source=d2, color="black", legend_label="dV/dt=0", line_width=2
)
plot2.line(
    x="v2", y="r_dr", source=d3, color="blue", legend_label="dR/dt=0", line_width=2
)
plot2.legend.location = "center"
plot2.legend.background_fill_alpha = 0.9

plot3 = bk.plotting.figure(
    x_range=(0, 20), y_range=(-100, 60), plot_width=400, plot_height=400
)
plot3.title.text = "Voltage trace, multiplicatively coupled stimulus"
plot3.title.text_font_size = '12pt'
plot3.title.align = "center"
plot3.xaxis.axis_label = "Time [ms]"
plot3.xaxis.axis_label_text_font_style = "normal"
plot3.yaxis.axis_label = "Membrane Potential [mV]"
plot3.yaxis.axis_label_text_font_style = "normal"
plot3.line(x="t", y="vr", source=trace2, color="red", line_width=2)

plot4 = bk.plotting.figure(
    x_range=(xmin, xmax), y_range=(ymin, ymax), plot_width=400, plot_height=400
)
plot4.title.text = "Phase plane, multiplicatively coupled stimulus"
plot4.title.text_font_size = '12pt'
plot4.title.align = "center"
plot4.xaxis.axis_label = "V[mV]"
plot4.xaxis.axis_label_text_font_style = "normal"
plot4.yaxis.axis_label = "R"
plot4.yaxis.axis_label_text_font_style = "normal"
plot4.add_layout(
    bk.models.Arrow(
        end=bk.models.NormalHead(size=4),
        source=quiver,
        x_start="x_start",
        y_start="y_start",
        x_end="x_end",
        y_end="y_end",
    )
)
plot4.line(
    x="vr", y="r", source=d4, color="red", legend_label="Limit Cycle", line_width=2
)
plot4.line(
    x="v1", y="r_dv", source=d2, color="black", legend_label="dV/dt=0", line_width=2
)
plot4.line(
    x="v2", y="r_dr", source=d3, color="blue", legend_label="dR/dt=0", line_width=2
)
plot4.legend.location = "center"
plot4.legend.background_fill_alpha = 0.9

plot3.x_range = plot1.x_range  # sync x-axis pan and zoom
plot3.y_range = plot1.y_range  # sync y-axis pan and zoom
plot4.x_range = plot2.x_range  # sync x-axis pan and zoom
plot4.y_range = plot2.y_range  # sync y-axis pan and zoom

## animation

ds = int(100)  # downsampling factor
x = 1e3 * t[(t > 0.0) * (t < 0.02)][::ds]
v1 = res_noinput["V"][0][(t > 0.0) * (t < 0.02)][::ds]
v2 = res_multiply["V"][0][(t > 0.0) * (t < 0.02)][::ds]
r1 = res_noinput["R"][0][(t > 0.0) * (t < 0.02)][::ds]
r2 = res_multiply["R"][0][(t > 0.0) * (t < 0.02)][::ds]
xi = 0  # initial value for time slider
x_t = [x[xi]]
v1_t = [v1[xi]]
v2_t = [v2[xi]]
r1_t = [r1[xi]]
r2_t = [r2[xi]]
d_t = 1e3 * (t[1] - t[0]) * ds

signals = bk.models.ColumnDataSource(
    dict(x=x, v1=v1, v2=v2, r1=r1, r2=r2)
)  # signals database
nodes = bk.models.ColumnDataSource(
    dict(x_t=x_t, v1_t=v1_t, v2_t=v2_t, r1_t=r1_t, r2_t=r2_t)
)  # dynamically rendered time points
index_t = bk.models.ColumnDataSource(dict(index=x))

plot1.circle(x="x_t", y="v1_t", source=nodes, size=15, color="grey", level="overlay")
plot2.circle(x="v1_t", y="r1_t", source=nodes, size=15, color="grey", level="overlay")
plot3.circle(x="x_t", y="v2_t", source=nodes, size=15, color="grey", level="overlay")
plot4.circle(x="v2_t", y="r2_t", source=nodes, size=15, color="grey", level="overlay")

# time slider
Timecallback = bk.models.CustomJS(
    args=dict(signals=signals, nodes=nodes, dt=d_t),
    code="""
    const arrayPos = (array) => Math.abs(array - cb_obj.value) < 0.1*dt;
    const new_value = signals.data['x'].findIndex(arrayPos);
    
    nodes.data['x_t']  = [signals.data['x'][new_value]];
    nodes.data['v1_t']  = [signals.data['v1'][new_value]];
    nodes.data['v2_t']  = [signals.data['v2'][new_value]];
    nodes.data['r1_t']  = [signals.data['r1'][new_value]];
    nodes.data['r2_t']  = [signals.data['r2'][new_value]];
    nodes.change.emit();
""",
)

t_slider = bk.models.Slider(
    start=x[0], end=x[-1], value=x[xi], step=d_t, title="Time [ms]"
)
t_slider.js_on_change("value", Timecallback)

# play/pause
toggle_js = bk.models.CustomJS(
    args=dict(slider=t_slider, indexCDS=index_t, dt=d_t),
    code="""
    var check_and_iterate = function(index){
        var slider_val = slider.value;
        var toggle_val = cb_obj.active;
        if(toggle_val == false) {
            cb_obj.label = '► Play';
            clearInterval(looop);
            }
        else if(slider_val >= index[index.length - 1]) {
            cb_obj.label = '► Play';
            slider.value = index[0];
            cb_obj.active = false;
            clearInterval(looop);
            }
        else if(slider_val !== index[index.length - 1]){
            slider.value = slider.value + dt;
            }
        else {
        clearInterval(looop);
            }
    }
    if(cb_obj.active == false){
        cb_obj.label = '► Play';
        clearInterval(looop);
    }
    else {
        cb_obj.label = '❚❚ Pause';
        var looop = setInterval(check_and_iterate, 100, indexCDS.data['index']);
    };
""",
)

toggle = bk.models.Toggle(label="► Play", active=False)
toggle.js_on_change("active", toggle_js)

# render interactive plot
layout = bk.layouts.grid(
    [
        bk.models.Div(
            text=f"<b>Rinzel model simulation</b>", style={"font-size": "150%"}
        ),
        [plot1, plot3],
        [plot2, plot4],
        [t_slider],
        [toggle],
    ]
)


# render figure
bk.io.show(layout)
# specify figure caption and tag
tag = "rinzel-mult"
caption = "Voltage trace and phase plane plots for the Rinzel model under constant injected current 𝐼 and with multiplicative coupling. <b>(Top)</b> Spike train generated by a reduced HHN with constant injected current I and multiplicative coupling: (left) without stimulus, (right) with bandlimited stimulus. <b>(Bottom)</b> The orbit of a reduced HHN with a constant injected current I and multiplicative coupling: (left) without stimulus, (right) with bandlimited stimulus."
# save figure with name specified by tag
save_tagged_fig(tag, layout, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

NameError: name 'res_multiply' is not defined

## I/O Equivalent IAF Neuron with Variable Threshold for Hodgkin-Huxley Neuron with Multiplicative Coupling

In [11]:
dt = 1e-5
t = np.arange(0.0, 0.2, dt)
omega = 2 * np.pi * 50
I_inj = 8
I_ext = 1.5 + sinc_stimulus(t, omega, K=15)

hhn = HodgkinHuxley()
# Find the variable threshold for the IAF neuron
res = hhn.solve(t, I_ext=I_inj * np.ones_like(t), verbose=True)
spikes = spike_detect(res["V"][0], height=0)
spikes = np.where(spikes)[0]
delta = (spikes[-1] - spikes[-2]) * dt
# Simulate the Hodgkin-Huxley neuron with multiplicative coupling
res = solve_multiplicative(hhn, t, stimulus=I_ext, I_ext=I_inj, verbose=True)
V_hhn = res["V"][0]
# Encode the stimulus with an ideal IAF neuron
v = 0
V_iaf = np.zeros_like(t)
for i in tqdm(range(len(t))):
    v += dt * I_ext[i]
    if v >= delta:
        v = v - delta
    V_iaf[i] = v

  0%|          | 0/20000 [00:00<?, ?it/s]

NameError: name 'spike_detect' is not defined

In [12]:
mask_hhn = spike_detect(V_hhn)
mask_iaf = spike_detect(V_iaf)

trace_hhn = bk.models.ColumnDataSource(dict(x=1e3 * t, v=V_hhn))
trace_iaf = bk.models.ColumnDataSource(dict(x=1e3 * t, v=V_iaf))
spike_hhn = bk.models.ColumnDataSource(dict(xs=1e3 * t[mask_hhn], vs=V_hhn[mask_hhn]))
spike_iaf = bk.models.ColumnDataSource(dict(xs=1e3 * t[mask_iaf], vs=V_iaf[mask_iaf]))

plot1 = bk.plotting.figure(
    x_range=(-5, 205), y_range=(-78, 55), plot_width=800, plot_height=300
)
plot1.title.text = "Voltage trace and spikes of the Hodgkin-Huxley Neuron"
plot1.title.text_font_size = '12pt'
plot1.title.align = "center"
plot1.xaxis.axis_label = "Time [ms]"
plot1.xaxis.axis_label_text_font_style = "normal"
plot1.yaxis.axis_label = "Membrane Potential [mV]"
plot1.yaxis.axis_label_text_font_style = "normal"
plot1.line(
    x="x",
    y="v",
    source=trace_hhn,
    line_width=2,
    color="red",
    legend_label="Membrane Potential",
)
plot1.circle(
    x="xs", y="vs", source=spike_hhn, size=15, color="green", legend_label="Spike"
)
plot1.legend.location = "bottom_right"
plot1.legend.background_fill_alpha = 0.8

plot2 = bk.plotting.figure(
    x_range=(-5, 205), y_range=(-0.001, 0.02), plot_width=800, plot_height=300
)
plot2.title.text = "Voltage trace and spikes of the equivalent IAF Neuron"
plot2.title.text_font_size = '12pt'
plot2.title.align = "center"
plot2.title.align = "center"
plot2.xaxis.axis_label = "Time [ms]"
plot2.xaxis.axis_label_text_font_style = "normal"
plot2.yaxis.axis_label = "Membrane Potential [mV]"
plot2.yaxis.axis_label_text_font_style = "normal"
plot2.line(
    x="x",
    y="v",
    source=trace_iaf,
    line_width=2,
    color="red",
    legend_label="Membrane Potential",
)
plot2.circle(
    x="xs", y="vs", source=spike_iaf, size=15, color="green", legend_label="Spike"
)
plot2.legend.location = "bottom_right"
plot2.legend.background_fill_alpha = 0.8
plot2.x_range = plot1.x_range  # sync x-axis pan and zoom
layout = bk.layouts.grid([plot1, plot2])

# render figure
bk.io.show(layout)
# specify figure caption and tag
tag = "equiv-hhn-iaf"
caption = "I/O equivalence based on spike times deﬁned as maxima."
# save figure with name specified by tag
save_tagged_fig(tag, layout, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

NameError: name 'spike_detect' is not defined

## I/O Equivalence for HHNs with Multiplicative Coupling and Feedback

In [13]:
dt = 1e-5
t = np.arange(0.0, 0.2, dt)
omega = 2 * np.pi * 20
I_injected = 8
bias = 1.5
# u = sinc_stimulus(t, omega, K=15)
u = np.sin(t*omega)

In [14]:
t_lc,lc = limit_cycle(HodgkinHuxley(), I_ext=I_injected, dt=dt, verbose=False)
hhn = HodgkinHuxley(V=lc[0,0],n=lc[1,0],m=lc[2,0],h=lc[3,0])
hhn_states = np.zeros((4, len(t)))
hhn_input = np.zeros_like(t)
hhn_z = np.zeros_like(t)
z = 1.
for tt, (_t, _u) in tqdm(enumerate(zip(t, u)), total=len(t), desc="Solving HHN w/ Feedback"):
    d_x = hhn.ode(t=_t, states=hhn.state_arr, I_ext=I_injected)
    _I = _u * z + bias
    hhn.state_arr += dt * hhn.Time_Scale * np.vstack(d_x) * _I
    hhn_states[:, tt] = hhn.state_arr[:,0]
    hhn_input[tt] = _u * z
    hhn_z[tt] = z
    if tt >= 2:
        if spike_detect_local(hhn_states[0,tt-2],hhn_states[0,tt-1],hhn_states[0,tt], 0):
            z *= -1

tk_idx_hhn, = np.where(spike_detect(hhn_states[0]))
tk_hhn = t[tk_idx_hhn]

NameError: name 'limit_cycle' is not defined

In [15]:
# Stimulus encoding with ASDM
asdm_params = dict(
    delta = (t_lc.max()-t_lc.min())/2, 
    C = 1.,  # kappa (capacitance) [mF]
    bias = bias
)

def integrate_step_asdm(y, x, x_prev, C, bias, sign, dt):
    return y + dt * (sign * bias + 0.5 * (x + x_prev)) / C

tk_idx_asdm = []  # spike time indices
V_asdm = np.zeros_like(t)
V_asdm[0] = -asdm_params['delta']
sign = 1
delta = asdm_params['delta']
for tt in range(1, len(t)):
    V_asdm[tt] = integrate_step_asdm(V_asdm[tt - 1], u[tt], u[tt - 1], asdm_params['C'], asdm_params['bias'], sign, dt)
    if sign * V_asdm[tt] >= delta:
        V_asdm[tt] = sign * delta
        sign = - sign
        tk_idx_asdm.append(tt)
tk_asdm = t[tk_idx_asdm]  # spike times
tk_idx_asdm = np.array(tk_idx_asdm)

NameError: name 't_lc' is not defined

In [16]:
V_hhn = hhn_states[0]
trace_hhn = bk.models.ColumnDataSource(dict(x=1e3 * t, v=V_hhn))
trace_asdm = bk.models.ColumnDataSource(dict(x=1e3 * t, v=V_asdm))
spike_hhn = bk.models.ColumnDataSource(dict(xs=1e3 * t[tk_idx_hhn], vs=V_hhn[tk_idx_hhn]))
spike_asdm = bk.models.ColumnDataSource(dict(xs=1e3 * t[tk_idx_asdm], vs=V_asdm[tk_idx_asdm]))

plot1 = bk.plotting.figure(
    x_range=(-5, 205), y_range=(-78, 55), plot_width=800, plot_height=300
)
plot1.title.text = "Voltage trace and spikes of the Hodgkin-Huxley Neuron w/ Feedback"
plot1.title.text_font_size = '12pt'
plot1.title.align = "center"
plot1.xaxis.axis_label = "Time [ms]"
plot1.xaxis.axis_label_text_font_style = "normal"
plot1.yaxis.axis_label = "Membrane Potential [mV]"
plot1.yaxis.axis_label_text_font_style = "normal"
plot1.line(
    x="x",
    y="v",
    source=trace_hhn,
    line_width=2,
    color="red",
    legend_label="Membrane Potential",
)
plot1.circle(
    x="xs", y="vs", source=spike_hhn, size=15, color="green", legend_label="Spike"
)
plot1.legend.location = "bottom_right"
plot1.legend.background_fill_alpha = 0.8

plot2 = bk.plotting.figure(
    x_range=plot1.x_range, y_range=(-asdm_params['delta']*1.1, asdm_params['delta']*1.1), 
    plot_width=800, plot_height=300
)
plot2.title.text = "Voltage trace and spikes of the equivalent ASDM Neuron"
plot2.title.text_font_size = '12pt'
plot2.title.align = "center"
plot2.title.align = "center"
plot2.xaxis.axis_label = "Time [ms]"
plot2.xaxis.axis_label_text_font_style = "normal"
plot2.yaxis.axis_label = "ASDM Integrator Output"
plot2.yaxis.axis_label_text_font_style = "normal"
plot2.line(
    x="x",
    y="v",
    source=trace_asdm,
    line_width=2,
    color="red",
    legend_label="ASDM Integrator Output",
)
plot2.circle(
    x="xs", y="vs", source=spike_asdm, size=15, color="green", legend_label="Spike"
)
plot2.legend.location = "bottom_right"
plot2.legend.background_fill_alpha = 0.8
plot2.x_range = plot1.x_range  # sync x-axis pan and zoom
layout = bk.layouts.grid([plot1, plot2])

# render figure
bk.io.show(layout)
# specify figure caption and tag
tag = "equiv-hhn-asdm"
caption = "I/O equivalence based on spike times deﬁned as maxima."
# save figure with name specified by tag
save_tagged_fig(tag, layout, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

NameError: name 'hhn_states' is not defined

# I/O Equivalence for Hodgkin-Huxley Neurons with Additive Coupling

<!-- Consider now the neuron is additively coupled as follows:

$$ \dot{y} = f(y,I) + [u(t) \ 0 \ 0 \ 0 ]^{T} $$

where $I$ is a constant injected current (a.k.a bias $b$ in the case of Leaky Integrate-and-Fire model) that parametrizes the dynamic of the neuron.


## Phase Shift Process $\tau(t,I)$
Assuming that the original system $\dot{x} = f(x,I)$ has a $T$-periodic solution $x^{0}(t)$, the solution to the additively couple system can be written as:

$$ y(t) = x^{0}(t+\tau(t,I)) + z(t) $$

where

* $x^{0}(t+\tau(t,I))$: phase-shifted along limit cycle as in multiplicative coupling
* $z(t)$: orbital deviation -->
<!-- ## infinitesimal Phase Response Curve $\psi(t,I)$ - decoupling $u(t)$ and _phase shift_
To decouple the input from the phase shift caused by it, we introduce the concept of _infinitesimal Phase Response Curve_(iPRC), for that we invoke Malkin's Theorem.

### Malkin's Theorem
For _inifitesimally small perturbations_ $u(t)$, _Malkin's Theorem_ states that the solution to system

$$
\dot{y} = f(y,I) + u(t)
$$

is given by $y(t) = x^{0}(t+\tau(t,I))$, where

$$
\frac{d\tau}{dt} = \psi(t+\tau(t,I),I)u(t), \ \ \tau(0,b) =0
$$

and the PRC $\psi(t,I)$ is given as solution to the initial value problem:

$$
	\dot{\psi} = - [\mathbf{D}f(x^{0},I)]^{T} \psi, \ \ \ \psi(0,I)\cdot f(x^{0}(0),I) = 1
$$

where $\mathbf{D}f(x^{0},I) =\begin{bmatrix} \frac{\partial f}{\partial x_{1}}(x^{0},I) & \frac{\partial f}{\partial x_{2}}(x^{0},I) & \ldots & \frac{\partial f}{\partial x_{n}}(x^{0},I)&\end{bmatrix}^{T}$ is the Jacobian of system $f(\cdot)$ evaluated at point $x^{0}(t) \in \mathbb{R}^{n}$. -->

## Response to current pulse perturbation
_If $u(t)$ small enough_ the solution $y(t)$ can be approximated as 

$$ y(t) \approx x^{0}(t+\tau(t,I)) $$

where $\tau(t,b)$ is an **input-dependent** time shift function called the _Phase Shift Process_.

Equivalently, we can rewrite the addtive coupling equation as follows:

$$ \dot{y} \approx \big(1+\frac{d\tau(t,I)}{dt}\big) f(y,I) $$
 
The above equivalence shows that the approximation gives us a multiplicatively coupled system driven by $\frac{d\tau(t,I)}{dt}$

In [18]:
dt = 1e-6
t = np.arange(0.0, 0.02, dt)
I_bias = 10
hhr = HodgkinHuxleyRinzel()
# simulate without perturbation
I_ext = I_bias * np.ones_like(t)
res = hhr.solve(t, I_ext=I_ext, verbose="No Perturbation", solver="Euler")
v1 = res["V"][0]
r1 = res["R"][0]
# simulate with perturbation
I_ext = I_bias * np.ones_like(t)
I_ext[np.logical_and(1e3*t>=2.3, 1e3*t<=2.3+dt)] += 10000  # add a large impulse to input current
res = hhr.solve(t, I_ext=I_ext, verbose="Large Perturbation", solver="Euler")
v2 = res["V"][0]
r2 = res["R"][0]
# downsample signals
ds = int(5)  # downsampling factor
d_t = 1e3 * dt * ds
x = 1e3 * t[::ds]
v1 = v1[::ds]
r1 = r1[::ds]
v2 = v2[::ds]
r2 = r2[::ds]

No Perturbation:   0%|          | 0/20000 [00:00<?, ?it/s]

Large Perturbation:   0%|          | 0/20000 [00:00<?, ?it/s]

In [19]:
# store signals in bokeh datastructures
xi = 0  # initial value for time slider
signals = bk.models.ColumnDataSource(dict(x=x, v1=v1, r1=r1, v2=v2, r2=r2))
nodes = bk.models.ColumnDataSource(
    dict(x_t=[x[xi]], v1_t=[v1[xi]], r1_t=[r1[xi]], v2_t=[v2[xi]], r2_t=[r2[xi]])
)
link = bk.models.ColumnDataSource(dict(lx=[v1[xi], v2[xi]], ly=[r1[xi], r2[xi]]))

# static plots
plot1 = bk.plotting.figure(
    x_range=(-2, 22), y_range=(-100, 60), plot_width=400, plot_height=400
)
plot1.title.text = "Voltage trace"
plot1.title.text_font_size = '12pt'
plot1.title.align = "center"
plot1.xaxis.axis_label = "Time [ms]"
plot1.xaxis.axis_label_text_font_style = "normal"
plot1.yaxis.axis_label = "Voltage [mV]"
plot1.yaxis.axis_label_text_font_style = "normal"
plot1.line(
    x="x",
    y="v1",
    source=signals,
    color="red",
    line_width=2,
    legend_label="Unperturbed",
    level="overlay",
)
plot1.line(
    x="x", y="v2", source=signals, color="green", line_width=2, legend_label="Perturbed"
)
plot1.legend.location = "bottom_right"
plot1.legend.background_fill_alpha = 0.8

plot2 = bk.plotting.figure(
    x_range=(-80, 60), y_range=(-0.1, 1.0), plot_width=400, plot_height=400
)
plot2.title.text = "Phase plane"
plot2.title.text_font_size = '12pt'
plot2.title.align = "center"
plot2.xaxis.axis_label = "V"
plot2.xaxis.axis_label_text_font_style = "normal"
plot2.yaxis.axis_label = "R"
plot2.yaxis.axis_label_text_font_style = "normal"
plot2.line(
    x="v1",
    y="r1",
    source=signals,
    color="red",
    line_width=2,
    legend_label="Unperturbed",
    level="overlay",
)
plot2.line(
    x="v2",
    y="r2",
    source=signals,
    color="green",
    line_width=2,
    legend_label="Perturbed",
)
plot2.legend.location = "bottom_right"
plot2.legend.background_fill_alpha = 0.8

# animation
index_t = bk.models.ColumnDataSource(dict(index=x))
plot1.circle(x="x_t", y="v1_t", source=nodes, size=10, color="darkred", level="overlay")
plot1.circle(
    x="x_t", y="v2_t", source=nodes, size=10, color="darkgreen", level="overlay"
)
plot2.circle(
    x="v1_t", y="r1_t", source=nodes, size=10, color="darkred", level="overlay"
)
plot2.circle(
    x="v2_t", y="r2_t", source=nodes, size=10, color="darkgreen", level="overlay"
)
plot2.line(x="lx", y="ly", source=link, color="blue", level="overlay")

# time slider
Timecallback = bk.models.CustomJS(
    args=dict(signals=signals, nodes=nodes, link=link, dt=d_t),
    code="""
    const arrayPos = (array) => Math.abs(array - cb_obj.value) < 0.1*dt;
    const new_value = signals.data['x'].findIndex(arrayPos);
    
    nodes.data['x_t']  = [signals.data['x'][new_value]];
    nodes.data['v1_t']  = [signals.data['v1'][new_value]];
    nodes.data['v2_t']  = [signals.data['v2'][new_value]];
    nodes.data['r1_t']  = [signals.data['r1'][new_value]];
    nodes.data['r2_t']  = [signals.data['r2'][new_value]];
    link.data['lx']   = [signals.data['v1'][new_value], signals.data['v2'][new_value]];
    link.data['ly']   = [signals.data['r1'][new_value], signals.data['r2'][new_value]];
    nodes.change.emit();
    link.change.emit();
""",
)

t_slider = bk.models.Slider(
    start=x[0], end=x[-1], value=x[xi], step=d_t, title="Time [ms]"
)
t_slider.js_on_change("value", Timecallback)

# play/pause
toggle_js = bk.models.CustomJS(
    args=dict(slider=t_slider, indexCDS=index_t, dt=d_t),
    code="""
    var check_and_iterate = function(index){
        var slider_val = slider.value;
        var toggle_val = cb_obj.active;
        if(toggle_val == false) {
            cb_obj.label = '► Play';
            clearInterval(looop);
            }
        else if(slider_val >= index[index.length - 1]) {
            cb_obj.label = '► Play';
            slider.value = index[0];
            cb_obj.active = false;
            clearInterval(looop);
            }
        else if(slider_val !== index[index.length - 1]){
            slider.value = slider.value + dt;
            }
        else {
        clearInterval(looop);
            }
    }
    if(cb_obj.active == false){
        cb_obj.label = '► Play';
        clearInterval(looop);
    }
    else {
        cb_obj.label = '❚❚ Pause';
        var looop = setInterval(check_and_iterate, 1, indexCDS.data['index']);
    };
""",
)

toggle = bk.models.Toggle(label="► Play", active=False)
toggle.js_on_change("active", toggle_js)

# render interactive plot
layout = bk.layouts.grid(
    [
        bk.models.Div(
            text=f"<b>Rinzel perturbed vs unperturbed</b>", style={"font-size": "150%"}
        ),
        [plot1, plot2],
        [t_slider],
        [toggle],
    ],
    sizing_mode="scale_width",
)

# render figure
bk.io.show(layout)
# specify figure caption and tag
tag = "rinzel-perturb"
caption = "The response of the Rinzel neuron model to a current pulse perturbation."
# save figure with name specified by tag
save_tagged_fig(tag, layout, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

### Illustrating Winfree's Method for Deriving Phase Response Curve for Rinzel Neuron

In [17]:
dt = 5e-6
t = np.arange(0.0, 0.025, dt)
I_bias = 15
hhr = HodgkinHuxleyRinzel()
res_arr = []
# simulate without perturbation
I_ext = I_bias * np.ones_like(t)
res = hhr.solve(t, I_ext=I_ext, verbose=False)
res_arr.append(res)
v1 = res["V"][0]
r1 = res["R"][0]
ref_idx = spike_detect(v1)
ref_idx = np.where(ref_idx)[0]
# simulate with perturbations
p_amp = 2  # perturbation amplitude
p_range = np.linspace(ref_idx[2], ref_idx[3], 50)  # pertubation indices
signals_db = {}
for i, pid in enumerate(p_range):
    hhr.states["V"] = v1[round(pid)] + p_amp
    hhr.states["R"] = r1[round(pid)]
    res = hhr.solve(
        t[round(pid) :], 
        I_ext=I_ext[round(pid) :],
        reset_initial_state=False,
        verbose=False
    )
    res_arr.append(res)
    signals_db[f"xp{i}"] = [1e3 * t[round(pid) :]]
    signals_db[f"vp{i}"] = [res["V"][0]]
    signals_db[f"rp{i}"] = [res["R"][0]]

NameError: name 'spike_detect' is not defined

In [21]:
# store signals in bokeh datastructures
signals_db = bk.models.ColumnDataSource(signals_db)
trace1 = bk.models.ColumnDataSource(dict(x=1e3 * t, v=v1, r=r1))
trace2 = bk.models.ColumnDataSource(
    dict(
        x=signals_db.data["xp0"][0],
        v=signals_db.data["vp0"][0],
        r=signals_db.data["rp0"][0],
    )
)
index_p = bk.models.ColumnDataSource(dict(index=1e3 * dt * p_range))
vbar1 = bk.models.ColumnDataSource(
    dict(x=[1e3 * dt * p_range[0], 1e3 * dt * p_range[0]], y=[-80, 60])
)
vbar2 = bk.models.ColumnDataSource(
    dict(x=[1e3 * dt * p_range[0], 1e3 * dt * p_range[0]], y=[-0.1, 1.1])
)
node = bk.models.ColumnDataSource(
    dict(x=[trace2.data["x"][0]], v=[trace2.data["v"][0]], r=[trace2.data["r"][0]])
)
dp = 1e3 * dt * (p_range[1] - p_range[0])

In [8]:
# static plots
TOOLS = "pan, box_zoom, reset"

plot1 = bk.plotting.figure(
    tools=TOOLS, x_range=(-2, 26), y_range=(-80, 60), plot_width=400, plot_height=250
)
plot1.xaxis.axis_label = "Time [ms]"
plot1.xaxis.axis_label_text_font_style = "normal"
plot1.yaxis.axis_label = "Voltage [mV]"
plot1.yaxis.axis_label_text_font_style = "normal"
plot1.line(
    x="x", y="v", source=trace1, color="red", line_width=2, legend_label="Unperturbed"
)
plot1.legend.location = "bottom_right"
plot1.legend.background_fill_alpha = 0.8
plot1.rect(
    x=1e3 * (t[ref_idx[3]] + t[ref_idx[2]]) / 2,
    y=-10,
    width=1e3 * (t[ref_idx[3]] - t[ref_idx[2]]),
    height=140,
    fill_color="blue",
    fill_alpha=0.1,
)

plot2 = bk.plotting.figure(
    tools=TOOLS, x_range=(-2, 26), y_range=(-0.1, 1.1), plot_width=400, plot_height=250
)
plot2.xaxis.axis_label = "Time [ms]"
plot2.xaxis.axis_label_text_font_style = "normal"
plot2.yaxis.axis_label = "R"
plot2.yaxis.axis_label_text_font_style = "normal"
plot2.line(
    x="x", y="r", source=trace1, color="red", line_width=2, legend_label="Unperturbed"
)
plot2.legend.location = "bottom_right"
plot2.legend.background_fill_alpha = 0.8
plot2.x_range = plot1.x_range  # sync x-axis pan and zoom
plot2.rect(
    x=1e3 * (t[ref_idx[3]] + t[ref_idx[2]]) / 2,
    y=-10,
    width=1e3 * (t[ref_idx[3]] - t[ref_idx[2]]),
    height=140,
    fill_color="blue",
    fill_alpha=0.1,
)

plot3 = bk.plotting.figure(
    tools=TOOLS, x_range=(-90, 60), y_range=(-0.1, 1.0), plot_width=400, plot_height=500
)
plot3.xaxis.axis_label = "V"
plot3.xaxis.axis_label_text_font_style = "normal"
plot3.yaxis.axis_label = "R"
plot3.yaxis.axis_label_text_font_style = "normal"
plot3.line(
    x="v", y="r", source=trace1, color="red", line_width=2, legend_label="Unperturbed"
)
plot3.legend.location = "bottom_right"
plot3.legend.background_fill_alpha = 0.8
plot3.x_range = plot1.y_range  # sync pan and zoom
plot3.y_range = plot2.y_range  # sync pan and zoom

# animations
plot1.line(
    x="x", y="v", source=trace2, color="green", line_width=2, legend_label="Perturbed"
)
plot2.line(
    x="x", y="r", source=trace2, color="green", line_width=2, legend_label="Perturbed"
)
plot3.line(
    x="v", y="r", source=trace2, color="green", line_width=2, legend_label="Perturbed"
)

plot1.line(x="x", y="y", source=vbar1, color="cyan")
plot2.line(x="x", y="y", source=vbar2, color="cyan")

plot1.circle(x="x", y="v", source=node, size=5, color="lime", level="overlay")
plot2.circle(x="x", y="r", source=node, size=5, color="lime", level="overlay")
plot3.circle(x="v", y="r", source=node, size=5, color="lime", level="overlay")

plot1.circle(x=1e3 * t[0], y=v1[0], size=5, color="orange", level="overlay")
plot2.circle(x=1e3 * t[0], y=r1[0], size=5, color="orange", level="overlay")
plot3.circle(x=v1[0], y=r1[0], size=5, color="orange", level="overlay")

# time slider
Timecallback = bk.models.CustomJS(
    args=dict(
        signals_db=signals_db,
        trace2=trace2,
        vbar1=vbar1,
        vbar2=vbar2,
        node=node,
        dt=dp,
        index_p=index_p,
    ),
    code="""
    const arrayPos = (array) => Math.abs(array - cb_obj.value) < 0.1*dt;
    const new_value = index_p.data['index'].findIndex(arrayPos);
    
    trace2.data['x']  = signals_db.data['xp' + new_value.toString()][0];
    trace2.data['v']  = signals_db.data['vp' + new_value.toString()][0];
    trace2.data['r']  = signals_db.data['rp' + new_value.toString()][0];
    vbar1.data['x'] = [cb_obj.value, cb_obj.value];
    vbar2.data['x'] = [cb_obj.value, cb_obj.value];
    node.data['x']  = [trace2.data['x'][0]];
    node.data['v']  = [trace2.data['v'][0]];
    node.data['r']  = [trace2.data['r'][0]];
    trace2.change.emit();
    vbar1.change.emit();
    vbar2.change.emit();
    node.change.emit();
""",
)

t_slider = bk.models.Slider(
    start=1e3 * dt * p_range[0],
    end=1e3 * dt * p_range[-1],
    value=1e3 * dt * p_range[0],
    step=dp,
    title="Perturbation time [ms]",
)
t_slider.js_on_change("value", Timecallback)

# play/pause
toggle_js = bk.models.CustomJS(
    args=dict(slider=t_slider, indexCDS=index_p, dt=dp),
    code="""
    var check_and_iterate = function(index){
        var slider_val = slider.value;
        var toggle_val = cb_obj.active;
        if(toggle_val == false) {
            cb_obj.label = '► Play';
            clearInterval(looop);
            }
        else if(slider_val >= index[index.length - 1]) {
            cb_obj.label = '► Play';
            slider.value = index[0];
            cb_obj.active = false;
            clearInterval(looop);
            }
        else if(slider_val !== index[index.length - 1]){
            slider.value = slider.value + dt;
            }
        else {
        clearInterval(looop);
            }
    }
    if(cb_obj.active == false){
        cb_obj.label = '► Play';
        clearInterval(looop);
    }
    else {
        cb_obj.label = '❚❚ Pause';
        var looop = setInterval(check_and_iterate, 300, indexCDS.data['index']);
    };
""",
)

toggle = bk.models.Toggle(label="► Play", active=False)
toggle.js_on_change("active", toggle_js)

# render interactive plot
layout = bokeh.layouts.column([plot1, plot2], sizing_mode="scale_width")
layout = bokeh.layouts.row([layout, plot3], sizing_mode="scale_width")
layout = bokeh.layouts.column([layout, t_slider, toggle], sizing_mode="scale_width")


# render figure
bk.io.show(layout)
# specify figure caption and tag
tag = "rinzel-prc"
caption = "Winfree’s method of deriving the PRC of the Rinzel neuron model. For every pulse with amplitude $\epsilon$ applied at phase $θ$, the spike shift $PRC(θ, I, \epsilon) = T-T^{'} = t_{k+1}^{'} - t_{k+1}$"
# save figure with name specified by tag
save_tagged_fig(tag, layout, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

AttributeError: unexpected attribute 'plot_width' to figure, similar attributes are outer_width, width or min_width

### Computing the PRC using Winfree's Method
Winfree's method performs a impulse perturbation on the model neuron. The impulse perturbation ( adding $\epsilon \delta(t-\theta)$ to voltage $V$ at time $t$) probes the neuron dynamic in a similar way of characterizing linear filter by its impulse response

**Note**: the PRC's are normalized so that the largest absolute value of each of the phase response curves is `1`. This normalization is not required for our definitions and is done for visualization purposes.

In [9]:
# compute phase response using Winfree's method
hhn_limit_cycle_t_w, hhn_limit_cycle_w, hhn_prc_w = PRC(
    HodgkinHuxley(), I_ext=15, dt=1e-6,
    verbose=True, 
    winfree_steps=100, 
    winfree_perturb_amp=[.1, 1e-3, 1e-3, 1e-3]
)
hhn_period_w = hhn_limit_cycle_t_w[-1]
hhr_limit_cycle_t_w, hhr_limit_cycle_w, hhr_prc_w = PRC(
    HodgkinHuxleyRinzel(), I_ext=15, dt=1e-6,
    verbose=True, 
    winfree_steps=100, 
    winfree_perturb_amp=[.1, 1e-3]
)
hhr_period_w = hhr_limit_cycle_t_w[-1]

# normalize phase responses
hhn_prc_w_norm = hhn_prc_w / np.max(np.abs(hhn_prc_w), axis=1)[:,None]
hhr_prc_w_norm = hhr_prc_w / np.max(np.abs(hhr_prc_w), axis=1)[:,None]

NameError: name 'PRC' is not defined

In [10]:
ds = int(10)  # downsampling factor
figs = []
colors = ["red", "green", "skyblue", "blue"]
for t_lc, prc, Model, model_name in zip(
    [hhn_limit_cycle_t_w, hhr_limit_cycle_t_w],
    [hhn_prc_w_norm, hhr_prc_w_norm],
    [HodgkinHuxley, HodgkinHuxleyRinzel],
    ['Hodgkin-Huxley', 'Rinzel']
):

    data = bk.models.ColumnDataSource(
        {
            **{'x': np.linspace(0, 1, len(t_lc))[::ds]}, # 1e3 * t_lc[::ds]
            **{
                key: prc[n,::ds]
                for n,key 
                in enumerate(['V'])
            }
        }
    )

    fig = bk.plotting.figure(
        x_range=None if len(figs) == 0 else figs[0].x_range,
        y_range=None if len(figs) == 0 else figs[0].y_range,
        plot_width=400,
        plot_height=400,
    )
    fig.title.text = f"PRC {model_name} Neuron"
    fig.title.text_font_size = '12pt'
    fig.title.align = "center"
    fig.xaxis.axis_label = f"Normalized Phase - Period {t_lc.max()*1e3:.1f} ms"
    fig.xaxis.axis_label_text_font_style = "normal"
    fig.yaxis.axis_label = "Normalized Phase Shift"
    fig.yaxis.axis_label_text_font_style = "normal"

    for n, key in enumerate(['V']):
        fig.line(
            x="x",
            y=key,
            source=data,
            line_width=2,
            color=colors[n],
            legend_label=key,
        )
    fig.legend.location = "bottom_left"
    fig.legend.background_fill_alpha = 0.8
    figs.append(fig)
layout = bk.layouts.grid([figs])

# render figure
bk.io.show(layout)
# specify figure caption and tag
tag = "hhn-rinzel-prc"
caption = "PRC for Hodgkin-Huxley and Rinzel Neuron Models."
# save figure with name specified by tag
save_tagged_fig(tag, layout, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

NameError: name 'hhn_limit_cycle_t_w' is not defined

### Rinzel PRC Manifold

*Note*: This block may take a while to run (expect 5-10 min runtime).

In [25]:
neu = HodgkinHuxleyRinzel()
bias_range = np.linspace(9, 15, 10)

# compute PRCs
_, _, neu_prc = PRC(
    neu, 
    I_ext=bias_range[0],
    dt=1e-6,
    verbose=False,
    winfree_steps=30,
    winfree_perturb_amp=[.1, 1e-3]
)
neu_prc = neu_prc[0]  # retain the voltage iPRC

prc_manifold = np.full((len(bias_range), len(neu_prc)), np.nan)
prc_manifold[0] = neu_prc
for i in tqdm(range(1, len(bias_range))):
    _, _, neu_prc = PRC(
        neu, I_ext=bias_range[i], dt=1e-6,
        verbose=False,
        winfree_steps=30,
        winfree_perturb_amp=[.1, 1e-3]
    )
    neu_prc = neu_prc[0]  # retain the voltage iPRC
    prc_manifold[i, : len(neu_prc)] = neu_prc

  0%|          | 0/9 [00:00<?, ?it/s]

In [26]:
[X, Y] = np.meshgrid(1e-2 * np.arange(prc_manifold.shape[1]), bias_range)
fig = go.Figure(data=[go.Surface(x=X, y=Y, z=prc_manifold)])
fig.update_layout(
    title="PRCs of the Rinzel neuron",
    title_x=0.5,
    autosize=False,
    scene_camera_eye=dict(x=1.6, y=-1.6, z=1.8),
    scene=dict(
        xaxis_title="Phase [ms]", 
        yaxis_title="Injected Current [mA]", 
        zaxis_title="PRCs"
    ),
    width=800,
    height=600,
    margin=dict(r=30, b=30, l=30, t=50),
)

# render figure
fig.show()
# specify figure caption and tag
tag = "rinzel-prcs"
caption = "Phase Response Curves of the Rinzel Neuron Model for different values of injected current."
# save figure with name specified by tag
save_tagged_fig(tag, fig, fig_dir, "plotly")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

### The Relationship between the PSP and the iPRC of the LIF Neuron
Recall, the PRC of LIF neuron is given by

$$
PRC(\theta,b,v) = RC\ln(\frac{Rb}{Rb-ve^{\theta/RC}})
$$

where $\theta$ is the phase at which the perturbation happens, $b$ is the bias of the LIF neuron, $v$ is the perturbation amplitude.

The iPRC of LIF neuron is given by:

$$
\psi(s, b) = \frac{C}{b}\exp\left(\frac{s}{RC}\right), ~ 0\le s \le T
$$


In [27]:
def LIF_T(lif, b):
    if lif.params["R"] * b > lif.params["V_T"]:
        return (
            lif.params["R"]
            * lif.params["C"]
            * np.log(lif.params["R"] * b / (lif.params["R"] * b - lif.params["V_T"]))
        )
    else:
        return np.nan


def LIF_PRC(lif, theta, b, v):
    R = lif.params["R"]
    C = lif.params["C"]
    delta = lif.params["V_T"]
    if R * b > delta:
        T = R * C * np.log(R * b / (R * b - delta))
        return (
            R
            * C
            * np.log(R * b / (R * b - v * np.exp(np.remainder(theta, T) / (R * C))))
        )
    else:
        return np.nan


def LIF_iPRC(lif, s, b):
    R = lif.params["R"]
    C = lif.params["C"]
    delta = lif.params["V_T"]
    if R * b > delta:
        T = R * C * np.log(R * b / (R * b - delta))
        return C / b * np.exp(np.remainder(s, T) / (R * C))
    else:
        return np.nan

In [28]:
lif = LeakyIntegrateFire(C=1e-2, R=5, V_T=4)
b = 1  # bias
T = LIF_T(lif, b)
dt = 1e-3
t = np.arange(0, 3 * T, dt)
theta = np.linspace(0, T - dt, 50)
perturbs = np.array([
    d 
    for d 
    in np.around(np.linspace(-.5, .5, 8), 2)
    if not np.isclose(d, 0)
])
cs = plt.cm.get_cmap("coolwarm", len(perturbs))
colors = [mpl.colors.to_hex(cs(n)) for n in range(len(perturbs))]

In [29]:
# compute signals
res = lif.solve(t=t, I_ext=b * np.ones_like(t), verbose=False)
V_0 = res["V"][0]

signals = np.zeros((len(theta), len(perturbs), len(t)))
for i, th in enumerate(theta):
    for j, delta_I in enumerate(perturbs):
        I_ext = b * np.ones_like(t)
        I_ext[np.argmin(np.abs(t - th))] += delta_I * lif.params["C"] / dt
        res = lif.solve(t=t, I_ext=I_ext, verbose=False)
        signals[i,j] = res["V"][0]

In [30]:
iPRCs = LIF_iPRC(lif, theta, b)
nPRCs = np.zeros((len(perturbs), len(theta)))
for i, delta_I in enumerate(perturbs):
    nPRCs[i, :] = LIF_PRC(lif, theta, b, delta_I) / delta_I

In [31]:
# save data to Bokeh ColumnDataSource
signals_db = bk.models.ColumnDataSource(
    {f"V{i}": list(sig) for i, sig in enumerate(signals)}
)
nPRCs_db = bk.models.ColumnDataSource(
    dict(xs=[theta] * len(perturbs), ys=list(nPRCs), colors=colors)
)
signals_db_ref = bk.models.ColumnDataSource(
    dict(tp=[t] * len(perturbs), Vp=signals_db.data["V0"], colors=colors)
)

# construct bokeh plots
TOOLS = "pan, box_zoom, wheel_zoom, reset"

# work in milliseconds to get around limited decimal precision of Bokeh slider display:
dth = 1e3 * (theta[1] - theta[0])
index_t = bk.models.ColumnDataSource(dict(index=1e3 * theta))
vbar1 = bk.models.ColumnDataSource(dict(x=[1e3 * theta[0], 1e3 * theta[0]], y=[0, 4]))
vbar2 = bk.models.ColumnDataSource(
    dict(x=[1e3 * theta[0], 1e3 * theta[0]], y=[0, 0.05])
)

plot1 = bk.plotting.figure(tools=TOOLS, plot_width=720, plot_height=260)
plot1.title.text = "Voltage trace of LIF neurons"
plot1.title.text_font_size = '12pt'
plot1.title.align = "center"
plot1.xaxis.axis_label = "Time [s]"
plot1.xaxis.axis_label_text_font_style = "normal"
plot1.xaxis.ticker = [T, 2 * T, 3 * T]
plot1.xaxis.major_label_overrides = {T: "T", 2 * T: "2T", 3 * T: "3T"}
plot1.yaxis.axis_label = "Voltage"
plot1.yaxis.axis_label_text_font_style = "normal"

plot1.multi_line(
    xs="tp",
    ys="Vp",
    line_color="colors",
    line_width=2,
    source=signals_db_ref,
    legend_label="Perturbed",
)
plot1.line(
    x=t,
    y=V_0,
    line_color="black",
    line_width=2,
    line_dash="dashed",
    legend_label="Unperturbed",
)
plot1.line(x="x", y="y", source=vbar1, line_color="green", line_width=2)
plot1.legend.location = "top_left"
plot1.legend.background_fill_alpha = 0.8
color_mapper = bokeh.models.CategoricalColorMapper(
    factors=perturbs.astype(str), palette=colors
)
color_bar = bokeh.models.ColorBar(color_mapper=color_mapper)
plot1.add_layout(color_bar, "right")

plot2 = bk.plotting.figure(tools=TOOLS, plot_width=720, plot_height=260)
plot2.title.text = "Normalized PRC vs iPRC for LIF neuron"
plot2.title.text_font_size = '12pt'
plot2.title.align = "center"
plot2.xaxis.axis_label = "Phase [s]"
plot2.xaxis.axis_label_text_font_style = "normal"
plot2.xaxis.ticker = [0, T]
plot2.xaxis.major_label_overrides = {0: "0", T: "T"}
plot2.yaxis.axis_label = "Normalized PRC and ψ"
plot2.yaxis.axis_label_text_font_style = "normal"

plot2.multi_line(
    xs="xs",
    ys="ys",
    line_color="colors",
    line_width=2,
    source=nPRCs_db,
    legend_label="Normalized PRC",
)
plot2.line(
    x=theta,
    y=iPRCs,
    line_color="black",
    line_width=2,
    line_dash="dashed",
    legend_label="iPRC ψ",
)
plot2.line(x="x", y="y", source=vbar2, line_color="green", line_width=2)
plot2.legend.location = "top_left"
plot2.legend.background_fill_alpha = 0.8

# time slider
Timecallback = bk.models.CustomJS(
    args=dict(
        signals_db=signals_db,
        signals=signals_db_ref,
        vbar1=vbar1,
        vbar2=vbar2,
        indexCDS=index_t,
        dt=dth,
    ),
    code="""
    const arrayPos = (array) => Math.abs(array - cb_obj.value) < 0.1*dt;
    const new_value = indexCDS.data['index'].findIndex(arrayPos);
    signals.data['Vp']  = signals_db.data['V' + new_value.toString()];
    signals.change.emit();
    vbar1.data['x'] = [cb_obj.value/1000, cb_obj.value/1000];
    vbar1.change.emit();
    vbar2.data['x'] = [cb_obj.value/1000, cb_obj.value/1000];
    vbar2.change.emit();
""",
)

t_slider = bk.models.Slider(
    start=1e3 * theta[0],
    end=1e3 * theta[-1],
    value=1e3 * theta[0],
    step=dth,
    title="Theta [ms]",
)
t_slider.js_on_change("value", Timecallback)

# play/pause
toggle_js = bk.models.CustomJS(
    args=dict(slider=t_slider, indexCDS=index_t, dt=dth),
    code="""
    var check_and_iterate = function(index){
        var slider_val = slider.value;
        var toggle_val = cb_obj.active;
        if(toggle_val == false) {
            cb_obj.label = '► Play';
            clearInterval(looop);
            }
        else if(slider_val >= index[index.length - 1]) {
            cb_obj.label = '► Play';
            slider.value = index[0];
            cb_obj.active = false;
            clearInterval(looop);
            }
        else if(slider_val !== index[index.length - 1]){
            slider.value = slider.value + dt;
            }
        else {
        clearInterval(looop);
            }
    }
    if(cb_obj.active == false){
        cb_obj.label = '► Play';
        clearInterval(looop);
    }
    else {
        cb_obj.label = '❚❚ Pause';
        var looop = setInterval(check_and_iterate, 200, indexCDS.data['index']);
    };
""",
)

toggle = bk.models.Toggle(label="► Play", active=False)
toggle.js_on_change("active", toggle_js)

# render interactive plot
layout = bk.layouts.grid(
    [[plot1], [[plot2]], [t_slider], [None, toggle, None]], sizing_mode="scale_width"
)

# render figure
bk.io.show(layout)
# specify figure caption and tag
tag = "lif-prc-iprc"
caption = "(top) Unperturbed spike train (dashed black) and perturbed spike trains; (bottom) scaled PRCs and iPRC (dashed black) of the LIF neuron."
# save figure with name specified by tag
save_tagged_fig(tag, layout, fig_dir, "bokeh")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

### iPRC Curves for the LIF Neuron

Below, we draw the iPRC curves of an LIF neuron for different values of injected current.

In [32]:
lif = LeakyIntegrateFire(C=7e-5, R=5, V_T=1e-2)
theta_range = np.arange(0.000005, 0.0005, 0.000005)
bias_range = np.arange(0.001, 0.01, 0.0002)
[T, B] = np.meshgrid(theta_range, bias_range)
ff = np.vectorize(LIF_iPRC)(lif, T, B)
kk = np.vectorize(LIF_T)(lif, B)
ff[T > kk] = np.nan

In [33]:
fig = go.Figure(data=[go.Surface(x=1e3 * B, y=1e3 *T, z=ff)])
fig.update_layout(
    title="iPhase Response Curves for the LIF neuron", 
    title_x=0.5, 
    autosize=False,
    scene_camera_eye=dict(x=-1.2, y=-1.2, z=1.5),
    scene=dict(yaxis_title="Phase [ms]", xaxis_title="Injected Current [mA]", zaxis_title="ψ"),
    width=800,
    height=600,
    margin=dict(r=30, b=30, l=30, t=50),
)

# render figure
fig.show()
# specify figure caption and tag
tag = "lif-iprcs"
caption = "The iPhase Response Curves of the LIF neuron for different values of injected current. Put together, these curves can be viewed as a two-dimensional iPhase Response Manifold."
# save figure with name specified by tag
save_tagged_fig(tag, fig, fig_dir, "plotly")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

### iPhase Response Curves for the model neurons

Below, we draw the iPhase Response Curves for several model neurons.

In [34]:
iprc_manifolds = dict()
biases = [np.arange(7, 14, 0.5), np.arange(9, 15, 0.5)]
for Model, bias_range in zip(
    [HodgkinHuxley, HodgkinHuxleyRinzel],
    biases
):
    neu = Model()

    _, _, neu_prc = iPRC(neu, I_ext=bias_range[0], dt=1e-5)  # compute iPRCs
    neu_prc = neu_prc[0]  # retain the voltage iPRC

    iprcs = np.full((len(bias_range), len(neu_prc)), np.nan)
    iprcs[0] = neu_prc
    for i, bias in tqdm(enumerate(bias_range[1:]), total=len(bias_range)-1):
        _, _, neu_prc = iPRC(neu, I_ext=bias, dt=1e-5)  # compute iPRCs
        neu_prc = neu_prc[0]  # retain the voltage iPRC
        iprcs[i+1, : len(neu_prc)] = neu_prc
    
    iprc_manifolds[Model.__name__] = iprcs

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

In [35]:
fig = make_subplots(
    rows=1, cols=len(iprc_manifolds), 
    specs=[[{"type":"scene"} for _ in range(len(iprc_manifolds))]],
    subplot_titles=[f"iPhase Response Curves<br>{name}" for name in iprc_manifolds],
    column_widths=[400]*len(iprc_manifolds),
    row_heights=[500],
    horizontal_spacing = 0.01
)
for n, (model_name, iprcs) in enumerate(iprc_manifolds.items()):
    [X, Y] = np.meshgrid(1e-2 * np.arange(iprcs.shape[1]), biases[n])
    fig.add_trace(go.Surface(x=X, y=Y, z=iprcs, cmin=-.5, cmax=1.), row=1, col=n+1)
fig.update_scenes(
    camera_eye=dict(x=1.6, y=-1.6, z=1.8),
    xaxis_title="Phase [ms]", 
    yaxis_title="Injected Current [mA]", 
    zaxis_title="ψ",
)
fig.update_layout(
    margin=dict(r=30, b=30, l=30, t=50)
)

# render figure
fig.show()
# specify figure caption and tag
tag = "hhn-hhr-iprcs"
caption = "iPhase Response Curves of model neurons."
# save figure with name specified by tag
save_tagged_fig(tag, fig, fig_dir, "plotly")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)

### Multiple iPRCs can be traversed by the neuron 
When the input amplitude $max_{t}|u(t)|$ is _small_ as compared to current $I$, the neuron stays close to the same limit cycle (left). When the stimulus $u(t)$ is strong, however, the neuron can traverse across several limit cycles (right). Since reduced PIF assumes a single iPRC $\psi(t,I)$ (which corresponds to a single limit cycle), it is a good approximation only when the stimulus $u(t)$ is small.

In [36]:
dt = 1e-6
ds = int(20)  # downsampling factor
neu = HodgkinHuxleyRinzel()
N = 30  # number of limit cycles
bias_arr = np.linspace(15, 200, N)
cycle_arr = []
for I_inj in tqdm(bias_arr):
    t_lc, lc = limit_cycle(neu, dt=dt, I_ext=I_inj, verbose=False, N_spikes=3)
    cycle_arr.append([lc[0, ::ds], lc[1, ::ds]])

  0%|          | 0/30 [00:00<?, ?it/s]

In [37]:
t = np.arange(0, 0.02, dt)
omega = 2 * np.pi * 50
u = sinc_stimulus(t, omega, K=15)
u1 = 50 + u
u2 = 50 + 30 * u
res1 = neu.solve(t, I_ext=u1, verbose="Small Perturbation")
res2 = neu.solve(t, I_ext=u2, verbose="Large Perturbation")

# downsample signals
u1 = u1[::ds]
u2 = u2[::ds]
v1 = res1["V"][0][::ds]
r1 = res1["R"][0][::ds]
v2 = res2["V"][0][::ds]
r2 = res2["R"][0][::ds]
t = t[::ds]

Small Perturbation:   0%|          | 0/20000 [00:00<?, ?it/s]

Large Perturbation:   0%|          | 0/20000 [00:00<?, ?it/s]

In [38]:
# In between each pair of spikes, we find the average input amplitude,
# which are important for conditional PRCs (which are not covered in the lecture.
u1d = np.zeros_like(u1)
s1 = spike_detect(v1)
s1 = np.where(s1)[0]
for i in range(len(s1) - 1):
    u1d[s1[i] : s1[i + 1]] = np.mean(u1[s1[i] : s1[i + 1]])
u1d[: s1[0]] = np.mean(u1[: s1[0]])
u1d[s1[-1] :] = np.mean(u1[s1[-1] :])

u2d = np.zeros_like(u2)
s2 = spike_detect(v2)
s2 = np.where(s2)[0]
for i in range(len(s2) - 1):
    u2d[s2[i] : s2[i + 1]] = np.mean(u2[s2[i] : s2[i + 1]])
u2d[: s2[0]] = np.mean(u2[: s2[0]])
u2d[s2[-1] :] = np.mean(u2[s2[-1] :])

In [39]:
palette = cycle([
    mpl.colors.to_hex(plt.cm.coolwarm(n/N))
    for n in range(N)
])
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("max(u)=1, I=50", "max(u)=30, I=50"),
    specs=[[{}, {}], [{"type": "surface"}, {"type": "surface"}]],
    column_widths=[0.5, 0.5],
    row_heights=[0.2, 1.2],
    vertical_spacing=0.0,
)


for i, I_bias in enumerate(bias_arr):  # trace 0 to (2*N-1)
    x = cycle_arr[i][0]
    y = cycle_arr[i][1]
    z = I_bias * np.ones_like(x)
    c = next(palette)
    fig.add_scatter3d(
        x=x,
        y=y,
        z=z,
        mode="lines",
        line=dict(color=c, width=3),
        showlegend=False,
        row=2,
        col=1,
    )
    fig.add_scatter3d(
        x=x,
        y=y,
        z=z,
        mode="lines",
        line=dict(color=c, width=3),
        showlegend=False,
        row=2,
        col=2,
    )

fig.add_trace(
    go.Scatter(
        x=1e3 * t, y=u1, mode="lines", line=dict(color="blue"), showlegend=False
    ),
    row=1,
    col=1,
)  # trace 2*N
fig.add_trace(
    go.Scatter(
        x=1e3 * t, y=u2, mode="lines", line=dict(color="blue"), showlegend=False
    ),
    row=1,
    col=2,
)  # trace (2*N+1)
fig.add_trace(
    go.Scatter(
        x=[1e3 * t[0]],
        y=[u1[0]],
        mode="markers",
        showlegend=False,
        marker=dict(color="green", size=9),
    ),
    row=1,
    col=1,
)  # trace (2*N+2)
fig.add_trace(
    go.Scatter(
        x=[1e3 * t[0]],
        y=[u2[0]],
        mode="markers",
        showlegend=False,
        marker=dict(color="green", size=9),
    ),
    row=1,
    col=2,
)  # trace (2*N+3)
fig.add_scatter3d(
    x=[v1[0]],
    y=[r1[0]],
    z=[u1[0]],
    mode="markers",
    marker=dict(color="green", size=3),
    showlegend=False,
    row=2,
    col=1,
)  # trace (2*N+4)
fig.add_scatter3d(
    x=[v2[0]],
    y=[r2[0]],
    z=[u2[0]],
    mode="markers",
    marker=dict(color="green", size=3),
    showlegend=False,
    row=2,
    col=2,
)  # trace (2*N+5)
fig.add_scatter3d(
    x=v1,
    y=r1,
    z=u1,
    mode="lines",
    line=dict(color="black", width=3),
    showlegend=False,
    row=2,
    col=1,
)  # trace (2*N+6)
fig.add_scatter3d(
    x=v2,
    y=r2,
    z=u2,
    mode="lines",
    line=dict(color="black", width=3),
    showlegend=False,
    row=2,
    col=2,
)  # trace (2*N+7)
fig.add_trace(
    go.Scatter(
        x=1e3 * t,
        y=u1d,
        mode="lines",
        line=dict(color="red", width=0.9),
        showlegend=False,
    ),
    row=1,
    col=1,
)  # trace (2*N+8)
fig.add_trace(
    go.Scatter(
        x=1e3 * t,
        y=u2d,
        mode="lines",
        line=dict(color="red", width=0.9),
        showlegend=False,
    ),
    row=1,
    col=2,
)  # trace (2*N+9)


frames = [
    go.Frame(
        data=[
            go.Scatter(
                x=[1e3 * t[k]],
                y=[u1[k]],
                mode="markers",
                name="Timepoint",
                showlegend=False,
                marker=dict(color="green", size=9),
            ),
            go.Scatter(
                x=[1e3 * t[k]],
                y=[u2[k]],
                mode="markers",
                name="Timepoint",
                showlegend=False,
                marker=dict(color="green", size=9),
            ),
            go.Scatter3d(
                x=[v1[k]],
                y=[r1[k]],
                z=[u1[k]],
                mode="markers",
                marker=dict(color="green", size=3),
                showlegend=False,
            ),
            go.Scatter3d(
                x=[v2[k]],
                y=[r2[k]],
                z=[u2[k]],
                mode="markers",
                marker=dict(color="green", size=3),
                showlegend=False,
            ),
        ],
        traces=[
            2 * N + 2,
            2 * N + 3,
            2 * N + 4,
            2 * N + 5,
        ],  # specify that these frames are updating trace 0
        name=str(k),  # you need to name the frame for the animation to behave properly
    )
    for k in np.arange(0, len(t), 5)
]  # downsample frames by a factor of 5

fig.update(frames=frames)


def frame_args(duration):
    return {
        "frame": {"duration": duration, "redraw": True},
        "mode": "immediate",
        "fromcurrent": True,
        "transition": {"duration": duration, "easing": "linear"},
    }


sliders = [
    {
        "pad": {"b": 10, "t": 60},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": [
            {
                "args": [[f.name], frame_args(0)],
                "label": str(k),
                "method": "animate",
            }
            for k, f in enumerate(fig.frames)
        ],
    }
]

# Layout
fig.update_layout(
    width=800,
    height=800,
    scene=dict(aspectratio=dict(x=1, y=1, z=1)),
    updatemenus=[
        {
            "buttons": [
                {
                    "args": [None, frame_args(1)],
                    "label": "&#9654;",  # play symbol
                    "method": "animate",
                },
                {
                    "args": [[None], frame_args(0)],
                    "label": "&#9724;",  # pause symbol
                    "method": "animate",
                },
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "type": "buttons",
            "x": 0.1,
            "y": 0,
        }
    ],
    sliders=sliders,
    showlegend=False,
    margin=dict(r=30, b=30, l=30, t=50)
)

fig.layout.xaxis1.title = "Time [ms]"
fig.layout.yaxis1.title = "u(t) + I [mA]"
fig.layout.xaxis2.title = "Time [ms]"
fig.layout.yaxis2.title = "u(t) + I [mA]"
fig.layout.scene1.camera.eye = dict(x=-1.6, y=-1.6, z=1.8)
fig.layout.scene2.camera.eye = dict(x=-1.6, y=-1.6, z=1.8)
fig.layout.scene1.xaxis.title = "V [mV]"
fig.layout.scene2.xaxis.title = "V [mV]"
fig.layout.scene1.yaxis.title = "R"
fig.layout.scene2.yaxis.title = "R"
fig.layout.scene1.zaxis.title = "u(t) + I[mA]"
fig.layout.scene2.zaxis.title = "u(t) + I[mA]"

# render figure
fig.show()
# specify figure caption and tag
tag = "hhn-lcycles"
caption = "A characterization of the limit cycle orbits of the HHN with injected current I and input stimulus $u(t)$."
# save figure with name specified by tag
save_tagged_fig(tag, fig, fig_dir, "plotly")
# render caption
fig_count += 1 # update figure count
IPython.display.HTML(
    data="<center><b>Figure {}.</b> {}</center>".format(fig_count, caption)
)