## Izhikevich model of spiking neuron

$\dot{v} = 0.04 v^{2} + 5v + 140 - u + I$  
$\dot{u} = a\left( bv - u\right)$

with the auxiliary after-spike reset

$$v \ge 30\,\mathrm{mV} \Rightarrow \begin{cases}
    v \leftarrow c \\
    u \leftarrow u + d
\end{cases}$$

where $v$ represents the membrane potential and $u$ the membrane recovery variable.

**Reference**:
Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on Neural Networks, 14(6), 1569-1572.

<img src="izhi_illustration.png" alt="Izhikevich model explanation" style="width: 447px; height: 303px"/>

In [1]:
# imports and other prepare stuff
# %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import ipywidgets as widgets
import scipy.signal as ss
from IPython.display import display
# matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so "seaborn-muted" alone fails on current
# releases; use whichever of the two names this installation provides.
_MUTED_STYLE = "seaborn-v0_8-muted"
plt.style.use(_MUTED_STYLE if _MUTED_STYLE in plt.style.available else "seaborn-muted")

# colour cycle of the active style, reused below for consistent line colours
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

In [18]:
class IzhikevichModel:
    """
    Class represents and is able to integrate Izhikevich model of artificial neuron.

    State equations, integrated with the forward Euler method and step ``dt``:
        v' = 0.04*v^2 + 5*v + 140 - u + I
        u' = a*(b*v - u)
    with the after-spike reset: if v >= v_reset_cond then v <- c, u <- u + d.

    Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on
        Neural Networks, 14(6), 1569-1572.
    """

    # default integration parameters
    dt = 0.1             # Euler integration step
    v_0 = -70.0          # initial membrane potential
    u_0 = -14.0          # initial recovery variable (= b * v_0 for the default b = 0.2)
    v_reset_cond = 30.0  # spike threshold that triggers the reset

    def __init__(self, a=0.02, b=0.2, c=-65.0, d=8.0, T=250):
        """
        Set neuron parameters.

        Parameters
        ----------
        a, b, c, d : float
            Izhikevich model parameters (see the class docstring).
        T : float
            Total simulated time; the model uses ``round(T / dt)``
            integration points.
        """
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        # integration params
        self.T_total = T
        # round() rather than int(): T / dt is inexact in floating point
        # (e.g. 3 / 0.1 == 29.999999999999996) and truncation would drop a point
        self.n_points = int(round(self.T_total / self.dt))
        # input
        self.I = None
        # state variables
        self.u = self.u_0
        self.v = self.v_0
        self.t = 0.

    def set_input(self, I):
        """
        Set input to the neuron.

        Parameters
        ----------
        I : scalar or array-like
            External current. A scalar (int, float or numpy scalar) is
            broadcast to a constant current over all ``n_points`` samples;
            otherwise exactly one value per integration point is required.
        """
        I = np.asarray(I, dtype=float)
        # a 0-d input means a constant current
        if I.ndim == 0:
            I = np.full((self.n_points,), float(I))
        # input must have correct shape
        assert len(I) == self.n_points, "Input length must equal n_points"
        self.I = I

    def _rhs(self, v, u, I):
        """
        Right hand side of the equation:
        v' = 0.04*v^2 + 5*v + 140 - u + I
        u' = a*(b*v - u)

        Returns the tuple ``(v', u')``.
        """
        v_deriv = 0.04 * np.power(v, 2) + 5.0 * v + 140.0 - u + I
        u_deriv = self.a * (self.b * v - u)
        return (v_deriv, u_deriv)

    def step(self):
        """
        Advance state by dt.

        Implements the reset condition:
        if v >= v_reset_cond then v -> c; u -> u + d
        """
        assert self.I is not None, "Input must be set"
        # round() guards against float error in t / dt; plain int() could
        # floor e.g. 9.999999999999998 to 9 and reuse an input sample
        t_idx = int(round(self.t / self.dt))
        v_deriv, u_deriv = self._rhs(self.v, self.u, self.I[t_idx])
        # Euler dt step
        self.v += self.dt * v_deriv
        self.u += self.dt * u_deriv
        # derive time from the step index instead of accumulating dt (no drift)
        self.t = (t_idx + 1) * self.dt
        # reset condition
        if self.v >= self.v_reset_cond:
            self.v = self.c
            self.u += self.d

    def integrate(self):
        """
        Run full integration, starting from the current state.

        Returns
        -------
        times, v, u : np.ndarray
            Arrays of length ``n_points``; index 0 holds the state before the
            first step. Recorded ``v`` is taken after any reset, so it stays
            below ``v_reset_cond``.
        """
        times = np.zeros((self.n_points,))
        v = np.zeros((self.n_points,))
        u = np.zeros((self.n_points,))
        times[0] = self.t
        v[0] = self.v
        u[0] = self.u
        for i in range(1, self.n_points):
            self.step()
            times[i] = self.t
            v[i] = self.v
            u[i] = self.u

        return times, v, u

In [19]:
# Index of the first sample that receives external current; earlier samples
# get zero input. With the default dt of 0.1 this is t = 100.
INPUT_START = 1000


def get_ext_input(I_max, I_period, current_type, t_total, input_length):
    """
    Build the external current waveform.

    Parameters
    ----------
    I_max : float
        Amplitude of the current.
    I_period : float
        Period of the "sine" and "sq. pulse" waveforms; rise time of the "ramp"
        (after which the ramp stays at ``I_max``).
    current_type : str
        One of "constant", "sine", "sq. pulse", "ramp".
    t_total : float
        Time of the last sample; samples are spaced evenly on ``[0, t_total]``.
    input_length : int
        Number of samples to generate.

    Returns
    -------
    float or np.ndarray
        ``I_max`` itself for "constant" (to be broadcast by the caller),
        otherwise an array of ``input_length`` samples.

    Raises
    ------
    ValueError
        If ``current_type`` is not recognised.
    """
    if current_type == "constant":
        return I_max
    if current_type not in ("sine", "sq. pulse", "ramp"):
        raise ValueError("Unknown current type")

    time = np.linspace(0, t_total, input_length)
    if current_type == "ramp":
        # linear rise to I_max over I_period, then flat; np.minimum also covers
        # time == I_period, where the old `<` / `>` masks both produced 0
        return I_max * np.minimum(time / I_period, 1.0)

    phase = 2 * np.pi * time / I_period
    if current_type == "sine":
        return I_max * np.sin(phase)
    return I_max * ss.square(phase)

def integrate_and_plot(a, b, c, d, I_max, I_period=200, current_type="constant", T=700):
    """
    Simulate one Izhikevich neuron and plot its response.

    The external current is zero for the first ``INPUT_START`` samples and
    follows the ``current_type`` waveform afterwards. The figure shows the
    membrane potential (with the recovery variable on a hidden twin axis),
    the input current below it, and the (u, v) phase portrait on the right.

    Parameters
    ----------
    a, b, c, d : float
        Izhikevich model parameters.
    I_max : float
        Amplitude of the external current.
    I_period : float
        Period ("sine", "sq. pulse") or rise time ("ramp") of the current,
        in the same time units as ``T``.
    current_type : str
        Waveform name understood by ``get_ext_input``.
    T : float
        Total simulated time.
    """
    model = IzhikevichModel(a=a, b=b, c=c, d=d, T=T)
    ext_input = np.zeros(model.n_points)
    input_length = model.n_points - INPUT_START
    if input_length > 0:
        # Sample the waveform on the model's own grid (spacing dt) starting at
        # stimulus onset. Passing the full T_total here would stretch the shorter
        # stimulus window over the whole run and distort I_period.
        t_input = (input_length - 1) * model.dt
        ext_input[INPUT_START:] = get_ext_input(
            I_max, I_period, current_type, t_input, input_length
        )
    model.set_input(ext_input)
    t, v, u = model.integrate()

    # time at which the external current switches on
    onset = INPUT_START * model.dt

    def _hide_spines(ax, *sides):
        # remove the given frame lines of an axis
        for side in sides:
            ax.spines[side].set_visible(False)

    fig = plt.figure(constrained_layout=True, figsize=(15, 8))
    spec = gridspec.GridSpec(ncols=3, nrows=3, figure=fig)

    # bottom-left: input current
    ax2 = fig.add_subplot(spec[2, :2])
    ax2.set_ylim([-20, 20])
    ax2.set_ylabel('INPUT CURRENT [AU]')
    ax2.set_xlabel("TIME [ms]")
    ax2.axvline(onset, 0, 1, linestyle="--", color="grey", linewidth=0.7)
    _hide_spines(ax2, 'right', 'top')

    # top-left: membrane potential, sharing the time axis with the input
    ax1 = fig.add_subplot(spec[:2, :2], sharex=ax2)
    ax1.set_ylim([-90, 20])
    ax1.set_ylabel('MEMBRANE POTENTIAL [mV]')
    _hide_spines(ax1, 'right', 'top', 'bottom')
    ax1.axvline(onset, 0, 1, linestyle="--", color="grey", linewidth=0.7)
    # twin axis for the recovery variable u, drawn without ticks
    ax12 = ax1.twinx()
    ax12.set_ylim([-20, 10])
    ax12.set_yticks([])
    _hide_spines(ax12, 'right', 'top', 'bottom')

    # right column: phase portrait (u, v)
    ax3 = fig.add_subplot(spec[:2, 2], sharey=ax1)
    _hide_spines(ax3, 'right', 'top')
    ax3.set_xlabel("MEMBRANE RECOVERY")
    ax3.set_ylim([-90, 20])
    ax3.set_xlim([-20, 10])

    ax1.plot(t, v, color=colors[0])
    ax12.plot(t, u, color=colors[1])
    ax2.plot(t, model.I, color=colors[2])
    ax3.scatter(u, v, s=7, c=colors[3])
    # render inside the ipywidgets Output area (recommended for interactive plots)
    plt.show()

In [20]:
def setup_silders():
    """
    Create the control widgets for the interactive demo and lay them out.

    Returns
    -------
    grid : widgets.GridspecLayout
        A 7 x 2 grid with a "Model parameters" header, the a/b/c/d sliders,
        an "External current parameters" header, the current sliders and,
        in the last row, the current-type toggle buttons.
    sliders : dict
        Widgets keyed by the keyword argument they control
        ("a", "b", "c", "d", "I_max", "I_period", "T", "current_type").
    """
    def fill_cell():
        # a fresh Layout per widget, stretched to fill its grid cell
        return widgets.Layout(height="auto", width="auto")

    sliders = {
        "a": widgets.FloatSlider(min=0.02, max=0.1, step=0.008, value=0.02, description="a"),
        "b": widgets.FloatSlider(min=0.2, max=0.25, step=0.01, value=0.2, description="b"),
        "c": widgets.IntSlider(min=-65, max=-50, step=5, value=-65, description="c"),
        "d": widgets.FloatSlider(min=0.05, max=8., step=0.1, value=8., description="d"),
        "I_max": widgets.FloatSlider(min=0, max=20, step=0.5, value=10., description="I max"),
        "I_period": widgets.FloatSlider(min=10, max=500, step=5, value=200, description="I period"),
        "T": widgets.IntSlider(min=500, max=2000, step=5, value=750, description="time"),
        "current_type": widgets.ToggleButtons(
            options=['constant', 'sq. pulse', 'sine', 'ramp'],
            value="constant",
            description='Current type',
            disabled=False,
            layout=fill_cell(),
        ),
    }

    grid = widgets.GridspecLayout(7, 2)
    # row 0: section header for the neuron parameters
    grid[0, :] = widgets.Button(description="Model parameters", layout=fill_cell())
    # rows 1-2: the four model parameters, two per row
    grid[1, 0], grid[1, 1] = sliders["a"], sliders["b"]
    grid[2, 0], grid[2, 1] = sliders["c"], sliders["d"]
    # row 3: section header for the stimulus
    grid[3, :] = widgets.Button(description="External current parameters", layout=fill_cell())
    # rows 4-6: stimulus controls
    grid[4, 0] = sliders["I_period"]
    grid[5, 0], grid[5, 1] = sliders["I_max"], sliders["T"]
    grid[6, :] = sliders["current_type"]
    return grid, sliders

In [21]:
# Build the control panel and bind it to integrate_and_plot: interactive_output
# calls integrate_and_plot with the widgets' current values, passed as keyword
# arguments named by the keys of `sliders`, whenever one of them changes.
grid, sliders = setup_silders()
ui = widgets.interactive_output(integrate_and_plot, sliders)

# Display the controls first, then the output area that holds the plots.
display(grid, ui)

GridspecLayout(children=(Button(description='Model parameters', layout=Layout(grid_area='widget001', height='a…

Output()