## Izhikevich model of spiking neuron

$\dot{v} = 0.04 v^{2} + 5v + 140 - u + I$  
$\dot{u} = a\left( bv - u\right)$

with the auxiliary after-spike reset

$v \ge 30\,\mathrm{mV} \Rightarrow \begin{cases}
    v \leftarrow c \\
    u \leftarrow u + d
\end{cases}$

where $v$ represents the membrane potential, $u$ the membrane recovery variable, $I$ the input current, and $a$, $b$, $c$, $d$ are dimensionless model parameters.

**Reference**:
Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on Neural Networks, 14(6), 1569-1572.

<img src="izhi_illustration.png" alt="Izhikevich model explanation" style="width: 447px; height: 303px"/>

In [1]:
# imports and other prepare stuff
# %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import ipywidgets as widgets
from scipy.signal import square
from scipy.integrate import solve_ivp
from IPython.display import display
# The bundled seaborn styles were renamed to "seaborn-v0_8-*" in matplotlib 3.6
# and the old names were removed in 3.8; use whichever name this install provides.
for _style_name in ("seaborn-v0_8-muted", "seaborn-muted"):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

# define basics
# colour cycle of the active style; the plots pick fixed entries (colors[0..3])
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
# sample index (not a time) at which the external current is switched on;
# with IzhikevichModel.dt = 0.1 this corresponds to 100 ms
INPUT_START = 1000
# font size of axis labels (tick labels use LABEL_SIZE - 2)
LABEL_SIZE = 16

In [2]:
def terminal(func):
    """
    Mark an event function as terminal for ``scipy.integrate.solve_ivp``.

    ``solve_ivp`` stops integrating as soon as an event function carrying the
    attribute ``terminal = True`` crosses zero.

    Parameters
    ----------
    func : callable
        Event function ``func(t, y, *args) -> float``.

    Returns
    -------
    callable
        The same function object, with ``func.terminal`` set to ``True``.
    """
    func.terminal = True
    return func


class IzhikevichModel:
    """
    Izhikevich model of an artificial spiking neuron, integrated with ``solve_ivp``.

        v' = 0.04 v^2 + 5 v + 140 - u + I
        u' = a (b v - u)
        v >= v_reset_cond  =>  v <- c,  u <- u + d

    Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on
        Neural Networks, 14(6), 1569-1572.
    """

    # default integration parameters
    dt = 0.1  # sampling step of the input current and of the returned time grid
    v_0 = -70.0  # initial membrane potential
    u_0 = -14.0  # initial membrane recovery variable
    v_reset_cond = 30.0  # threshold whose crossing counts as a spike and triggers the reset

    def __init__(self, a=0.02, b=0.2, c=-65.0, d=8.0, T=250):
        """
        Set neuron parameters and the integration horizon.

        Parameters
        ----------
        a, b, c, d : float
            Izhikevich model parameters (see class docstring).
        T : float
            Total integration time; the output grid has ``floor(T / dt)`` samples.
        """
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        # integration params
        self.T_total = T
        self.n_points = int(np.floor(self.T_total / self.dt))
        # input current, set via set_input()
        self.I = None
        # state variables (not updated by integrate(); kept for compatibility)
        self.u = self.u_0
        self.v = self.v_0
        self.t = 0.
        # number of spikes found by the most recent integrate() call
        self.num_spikes = 0

    @terminal
    def reset_condition(self, t, y, I):
        """
        Spike event: zero when the membrane potential reaches ``v_reset_cond``.

        It is terminal, i.e. ``solve_ivp`` stops there so that ``integrate`` can
        apply the after-spike reset and restart the integration.
        """
        return y[0] - self.v_reset_cond

    def set_input(self, I):
        """
        Set the input current of the neuron.

        Parameters
        ----------
        I : float, int or array_like
            A scalar is held constant for the whole run. An array is the current
            sampled every ``dt`` and must have at least ``n_points`` samples.

        Raises
        ------
        ValueError
            If an array input has fewer than ``n_points`` samples.
        """
        if np.ndim(I) == 0:
            # any real scalar (int, float, numpy scalar) -> constant current
            I = np.full(self.n_points, float(I))
        else:
            I = np.asarray(I)
        if len(I) < self.n_points:
            raise ValueError(
                f"input must have at least {self.n_points} samples, got {len(I)}"
            )
        self.I = I

    def _rhs(self, t, y, I):
        """
        Right hand side of the equation:
        v' = 0.04*v^2 + 5*v + 140 - u + I
        u' = a*(b*v - u)

        The input is piecewise constant: sample k applies on [k*dt, (k+1)*dt).
        """
        v, u = y
        # clamp so evaluations at t == T_total (index n_points) stay in range
        t_idx = min(int(t / self.dt), len(I) - 1)
        v_deriv = 0.04 * np.power(v, 2) + 5.0 * v + 140.0 - u + I[t_idx]
        u_deriv = self.a * (self.b * v - u)
        return (v_deriv, u_deriv)

    def integrate(self):
        """
        Integrate the model over ``[0, T_total)`` including after-spike resets.

        ``solve_ivp`` runs until the terminal spike event fires. The state is then
        reset (``v <- c``, ``u <- u + d``) at the exact event time, and integration
        restarts from there. The output is sampled on the fixed grid
        ``linspace(0, T_total, n_points, endpoint=False)`` and each grid point
        appears exactly once. ``num_spikes`` is recomputed on every call.

        Returns
        -------
        t, v, u : ndarray
            Time grid, membrane potential and membrane recovery variable.

        Raises
        ------
        RuntimeError
            If no input was set, or if ``solve_ivp`` reports a failure.
        """
        if self.I is None:
            raise RuntimeError("Input current not set; call set_input() first.")
        self.num_spikes = 0
        t_grid = np.linspace(0, self.T_total, self.n_points, endpoint=False)
        ts = []
        ys = []
        t_start = 0.0
        y0 = [self.v_0, self.u_0]
        t_eval = t_grid
        while True:
            # solve until spike (terminal event) or until T_total
            sol = solve_ivp(self._rhs, t_span=[t_start, self.T_total], t_eval=t_eval,
                            y0=y0, args=[self.I], events=self.reset_condition)
            if sol.status == -1:
                raise RuntimeError(f"Integration failed: {sol.message}")
            if len(sol.t):
                ts.append(np.asarray(sol.t))
                ys.append(np.asarray(sol.y))
            if sol.status != 1:
                # reached T_total without another spike
                break
            self.num_spikes += 1
            # restart exactly at the spike time from the reset state; grid points
            # <= the spike time were already returned in sol.t
            t_start = sol.t_events[0][-1]
            u_spike = sol.y_events[0][-1][1]
            y0 = [self.c, u_spike + self.d]
            t_eval = t_grid[t_grid > t_start]
            if t_eval.size == 0:
                break
        # stitch results together
        t = np.concatenate(ts)
        y = np.concatenate(ys, axis=1)
        # return as time, v variable, u variable
        return t, y[0, :], y[1, :]

In [3]:
def get_ext_input(I_max, I_period, current_type, t_total, input_length):
    """
    Construct external current of given type.

    Parameters
    ----------
    I_max : float
        Amplitude of the current.
    I_period : float
        Period of "sine" and "sq. pulse", or rise time of "ramp"; same units
        as ``t_total``.
    current_type : str
        One of "constant", "sine", "sq. pulse", "ramp".
    t_total : float
        Time spanned by the samples (first sample at 0, last at ``t_total``).
    input_length : int
        Number of samples for the non-constant types.

    Returns
    -------
    float or ndarray
        ``I_max`` itself for "constant" (the caller broadcasts it); otherwise an
        array of ``input_length`` samples.

    Raises
    ------
    ValueError
        If ``current_type`` is not one of the supported types.
    """
    if current_type == "constant":
        return I_max
    if current_type not in ("sine", "sq. pulse", "ramp"):
        raise ValueError("Unknown current type")
    time = np.linspace(0, t_total, input_length)
    if current_type == "sine":
        return I_max * np.sin(2 * np.pi * time / I_period)
    if current_type == "sq. pulse":
        return I_max * square(2 * np.pi * time / I_period)
    # ramp: rise linearly to I_max over I_period, then hold. Clamping also covers
    # time == I_period, which the strict </> comparisons used to map to 0.
    return I_max * np.minimum(time / I_period, 1.0)

def integrate_and_plot(a, b, c, d, I_max, I_period=200, current_type="constant", T=700):
    """
    Integrate Izhikevich model and plot the results.

    The external current is zero for the first ``INPUT_START`` samples and then
    follows ``get_ext_input``. The figure shows v(t) with u(t) on a twin axis,
    the input current below, and the (u, v) trajectory on the right.

    Parameters
    ----------
    a, b, c, d : float
        Izhikevich model parameters.
    I_max : float
        Amplitude of the external current.
    I_period : float
        Period ("sine", "sq. pulse") or rise time ("ramp") of the current, in
        the same time units as the plot.
    current_type : str
        One of "constant", "sine", "sq. pulse", "ramp".
    T : float
        Total simulated time.
    """
    # set up model and integrate
    model = IzhikevichModel(a=a, b=b, c=c, d=d, T=T)
    # one spare sample beyond n_points; the current is zero before INPUT_START
    ext_input = np.zeros(model.n_points + 1)
    input_length = ext_input.shape[0] - INPUT_START
    if input_length > 0:
        # the driven window's clock starts at 0 at input onset and advances by one
        # dt per sample, so I_period is measured in real (plotted) time units
        input_duration = (input_length - 1) * model.dt
        ext_input[INPUT_START:] = get_ext_input(I_max, I_period, current_type,
                                                input_duration, input_length)
    model.set_input(ext_input)
    t, v, u = model.integrate()
    # time at which the external current switches on (dashed marker)
    input_onset = INPUT_START * model.dt

    # set up figure
    fig = plt.figure(constrained_layout=True, figsize=(15, 8))
    spec = gridspec.GridSpec(ncols=3, nrows=3, figure=fig)
    # set up axis for timeseries of input current
    ax2 = fig.add_subplot(spec[2, :2])
    ax2.set_ylim([-20, 20])
    ax2.set_ylabel('INPUT CURRENT [AU]', size=LABEL_SIZE)
    ax2.set_xlabel("TIME [ms]", size=LABEL_SIZE)
    ax2.axvline(input_onset, 0, 1, linestyle="--", color="grey", linewidth=0.7)
    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2.tick_params(axis='both', which='major', labelsize=LABEL_SIZE - 2)

    # set up axis for timeseries of u and v variables
    ax1 = fig.add_subplot(spec[:2, :2], sharex=ax2)
    ax1.set_ylim([-90, 20])
    ax1.set_ylabel('MEMBRANE POTENTIAL [mV]', size=LABEL_SIZE)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)
    ax1.spines['bottom'].set_visible(False)
    ax1.axvline(input_onset, 0, 1, linestyle="--", color="grey", linewidth=0.7)
    ax1.tick_params(axis='both', which='major', labelsize=LABEL_SIZE - 2)
    # twin axis for the recovery variable u, drawn without its own ticks
    ax12 = ax1.twinx()
    ax12.set_ylim([-20, 10])
    ax12.set_yticklabels([])
    ax12.set_yticks([])
    ax12.spines['right'].set_visible(False)
    ax12.spines['top'].set_visible(False)
    ax12.spines['bottom'].set_visible(False)
    ax12.tick_params(axis='both', which='major', labelsize=LABEL_SIZE - 2)

    # set up axis for scatter u vs v
    ax3 = fig.add_subplot(spec[:2, 2], sharey=ax1)
    ax3.spines['right'].set_visible(False)
    ax3.spines['top'].set_visible(False)
    ax3.set_xlabel("MEMBRANE RECOVERY", size=LABEL_SIZE)
    scatter_colors = colors[3]
    ax3.set_ylim([-90, 20])
    ax3.set_xlim([-20, 10])
    ax3.tick_params(axis='both', which='major', labelsize=LABEL_SIZE - 2)

    # plot
    ax1.plot(t, v, color=colors[0], linewidth=2.5)
    ax12.plot(t, u, color=colors[1])
    # the model applies I[k] at grid time t[k], so align the first len(t) samples
    ax2.plot(t, model.I[:len(t)], color=colors[2])
    ax3.scatter(u, v, s=7, c=scatter_colors)
    plt.suptitle(f"Number of spikes: {model.num_spikes}", size=LABEL_SIZE + 3)

In [4]:
def setup_silders():
    """
    Create the control widgets and lay them out on a 7 x 2 grid.

    Returns
    -------
    grid : ipywidgets.GridspecLayout
        Two section headers ("Model parameters", "External current parameters"),
        each followed by its sliders, and the current-type toggle in the last row.
    sliders : dict
        Maps each keyword argument of ``integrate_and_plot`` to its widget, in the
        form accepted by ``widgets.interactive_output``.
    """
    full_width = dict(height="auto", width="auto")

    sliders = {
        "a": widgets.FloatSlider(min=0.02, max=0.1, step=0.008, value=0.02, description="a"),
        "b": widgets.FloatSlider(min=0.2, max=0.25, step=0.01, value=0.2, description="b"),
        "c": widgets.IntSlider(min=-65, max=-50, step=5, value=-65, description="c"),
        "d": widgets.FloatSlider(min=0.05, max=8., step=0.1, value=8., description="d"),
        "I_max": widgets.FloatSlider(min=0, max=20, step=0.5, value=10., description="I max"),
        "I_period": widgets.FloatSlider(min=10, max=500, step=5, value=200, description="I period"),
        "T": widgets.IntSlider(min=500, max=2000, step=5, value=750, description="time"),
        "current_type": widgets.ToggleButtons(
            options=['constant', 'sq. pulse', 'sine', 'ramp'],
            value="constant",
            description='Current type',
            disabled=False,
            layout=widgets.Layout(**full_width),
        ),
    }

    def section_header(title):
        # a Button stretched across the row, used purely as a section title
        return widgets.Button(description=title, layout=widgets.Layout(**full_width))

    grid = widgets.GridspecLayout(7, 2)
    placement = (
        ((0, slice(None)), section_header("Model parameters")),
        ((1, 0), sliders["a"]),
        ((1, 1), sliders["b"]),
        ((2, 0), sliders["c"]),
        ((2, 1), sliders["d"]),
        ((3, slice(None)), section_header("External current parameters")),
        ((4, 0), sliders["I_period"]),
        ((5, 0), sliders["I_max"]),
        ((5, 1), sliders["T"]),
        ((6, slice(None)), sliders["current_type"]),
    )
    for cell, widget in placement:
        grid[cell] = widget
    return grid, sliders

In [5]:
# Build the control grid, wire it to the plotting callback, and render both.
grid, sliders = setup_silders()
ui = widgets.interactive_output(
    integrate_and_plot,  # re-run whenever any control changes
    sliders,
)
display(grid, ui)

GridspecLayout(children=(Button(description='Model parameters', layout=Layout(grid_area='widget001', height='a…

Output()