# Adaptive Exponential Integrate-and-Fire (AdEx) Model

Reference: [Scholarpedia: Adaptive exponential integrate-and-fire model](http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model)

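$$
C\frac{dV}{dt} = -g_L(V - E_L) + g_L\Delta_T\exp\left(\frac{V - V_T}{\Delta_T}\right) - w + I \tag{1}
$$

$$
\tau_w\frac{dw}{dt} = a(V - E_L) - w \tag{2}
$$

$$
\text{at } t = t^f: \quad \begin{cases} V \rightarrow V_r \\ w \rightarrow w + b \end{cases} \tag{3}
$$

Here $t^f$ denotes a spike time: the moment the exponential term drives $V$ upward without bound (in a simulation, when $V$ crosses a numerical cut-off).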

![Behaviour](http://www.scholarpedia.org/w/images/0/02/AdExDiagram.png)

In [None]:
# Parameter sets from https://brian2.readthedocs.io/en/stable/examples/frompapers.Brette_Gerstner_2005.html
# Brian2 attaches physical units (ms, nS, nA, mV) to these values. This notebook
# does not import Brian2, so they are written as plain numbers in the units
# simulate() works in: time in ms, conductance in nS, current in nA,
# voltage in mV, capacitance in pF.
C = 281     # Membrane capacitance [pF] (needed by the fast-spiking set)
VT = -50.4  # Effective threshold potential [mV] (needed by the bursting set)

# Pick an electrophysiological behaviour, then e.g.
# simulate(tauw=tauw, a=a, b=b, Vr=Vr)
tauw, a, b, Vr = 144, 4, 0.0805, -70.6  # Regular spiking (as in the paper)
# tauw, a, b, Vr = 20, 4, 0.5, VT + 5  # Bursting
# tauw, a, b, Vr = 144, 2 * C / 144, 0, -70.6  # Fast spiking (pF/ms = nS)

In [None]:
%matplotlib inline
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
from matplotlib import pyplot as plt


# General simulation parameters (time in ms, currents in nA)
dt = 0.02  # Numerical time step [ms]
taux = 10  # OU noise time constant [ms]; sensible range (1, 100)
sigma = 0.15  # OU noise amplitude [nA]; sensible range (0.05, 0.5)
v_spike = 20  # Voltage [mV] drawn at spike times so spikes are visible
N = 1  # Number of neurons
I0 = 1  # Maximum stimulus current [nA]
Tdur = 150  # Simulation duration [ms]
Tstim = 100  # Stimulus duration [ms]
stimuli = ['step', 'delta', 'sine', 'ramp']  # Types of stimulus

# round() rather than int(): float quotients such as 0.3 / 0.02 evaluate to
# 14.999999999999998, which int() would truncate to 14.
n_steps = round(Tdur / dt)
# n_steps + 1 samples so that both t = 0 and t = Tdur are included.
sim_times = np.linspace(0, Tdur, num=n_steps + 1)


def get_noise(N, n_steps, dt, taux, sigma):
    """Return Ornstein-Uhlenbeck noise traces, one row per neuron.

    Euler-Maruyama integration of ``dx = -x/taux dt + sigma*sqrt(2/taux) dW``.
    The ``sqrt(2*dt/taux)`` diffusion factor makes the stationary standard
    deviation approximately ``sigma`` (exactly so in the limit dt << taux).
    Every trace starts at 0, so it takes on the order of ``taux`` to reach
    that stationary spread.

    Parameters
    ----------
    N : int
        Number of independent traces (rows).
    n_steps : int
        Number of samples per trace (columns).
    dt : float
        Integration time step.
    taux : float
        Correlation time constant, in the same time unit as ``dt``.
    sigma : float
        Target stationary standard deviation.

    Returns
    -------
    ndarray, shape (N, n_steps)
    """
    draws = np.random.randn(N, n_steps)
    trace = np.zeros_like(draws)
    leak = dt / taux                       # fraction of x decaying per step
    kick = sigma * np.sqrt(2 * dt / taux)  # scale of the random increment
    for step in range(1, n_steps):
        prev = trace[:, step - 1]
        trace[:, step] = prev - leak * prev + kick * draws[:, step]
    return trace


def dvdt(t, sol, EL, gL, DeltaT, VT, C, a, tauw):
    """Right-hand side of AdEx equations (1) and (2), excluding the input current.

    Parameters
    ----------
    t : float
        Current time. Unused; kept so the signature has the usual ``f(t, y)`` form.
    sol : ndarray, shape (N, 2)
        State per neuron: column 0 is the membrane potential V, column 1 the
        adaptation current w.
    EL, gL, DeltaT, VT, C, a, tauw : float
        Model parameters. With the notebook's units (mV, nS, pF, ms), w is in
        pA and dV/dt comes out in mV/ms.

    Returns
    -------
    ndarray, shape (N, 2)
        Column 0 is dV/dt, column 1 is dw/dt.

    Notes
    -----
    For ``DeltaT == 0`` the exponential term ``DeltaT*exp((V-VT)/DeltaT)`` is
    replaced by its limit 0 (valid for V < VT), i.e. a leaky integrate-and-fire
    neuron. Evaluating the expression directly would give 0/0 and 0*inf (NaN).
    """
    v = sol[:, 0]
    w = sol[:, 1]
    v_rel = v - EL
    if DeltaT == 0:
        spike_term = np.zeros_like(v)
    else:
        spike_term = DeltaT * np.exp((v - VT) / DeltaT)
    dv = (gL * (spike_term - v_rel) - w) / C
    dw = (a * v_rel - w) / tauw
    return np.column_stack((dv, dw))


def _make_input(stimulus):
    """Return the stimulus current [nA] sampled at ``sim_times``."""
    # round() instead of int(): float quotients like 0.3/0.02 = 14.999... would truncate.
    s0 = round((Tdur - Tstim) / 2 / dt)  # stimulus is centred in the run
    s_end = s0 + round(Tstim / dt)
    Input = np.zeros_like(sim_times)

    if stimulus == 'step':
        Input[s0:s_end] = I0
    elif stimulus == 'delta':
        s_end = s0 + round(1 / dt)  # 1 ms pulse
        Input[s0:s_end] = I0
    elif stimulus == 'sine':
        f = 0.5
        # The 3*pi shift puts f*t at 1.5*pi, so the sinusoid starts at its
        # minimum and the current starts at zero.
        stim_times = np.linspace(0, (s_end - s0) * dt, num=s_end - s0) + 3 * np.pi
        Input[s0:s_end] = (1 + np.sin(f * stim_times)) * I0 / 2
    elif stimulus == 'ramp':
        Input[s0:s_end] = I0 * np.linspace(0, 1, num=s_end - s0)
    else:
        raise ValueError(f"unknown stimulus {stimulus!r}; expected one of {stimuli}")
    return Input


def _plot_traces(Input, trace, theta):
    """Plot the input current, V and w of the first neuron on shared time axes."""
    fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(14, 8))
    axes[0].plot(sim_times, Input)
    axes[0].set_ylabel('$I$ [nA]')
    axes[1].plot(sim_times, trace[0, :, 0])
    axes[1].set_ylim(-100, 30)
    axes[1].axhline(y=theta, linestyle=':', color='grey')
    axes[1].set_ylabel('$V$ [mV]')
    axes[2].plot(sim_times, trace[0, :, 1] / 1e3)  # w is integrated in pA; show nA
    axes[2].set_ylabel('$w$ [nA]')
    axes[2].set_xlabel('Time [ms]')


def simulate(stimulus='step', gL=30, EL=-70.6, Vr=-70.6, a=4, b=0.0805,
             C=281, DeltaT=2, tauw=144, VT=-50.4, Trefract=0, noise=True):
    """Simulate N AdEx neurons with forward Euler and plot the first one.

    Units: time in ms, voltage in mV, conductance in nS, capacitance in pF,
    stimulus current and ``b`` in nA. The adaptation variable w is integrated
    in pA (``a`` [nS] * (V - EL) [mV]), so ``b`` is converted to pA before it
    is added at each spike.

    Parameters
    ----------
    stimulus : str
        One of ``stimuli`` ('step', 'delta', 'sine', 'ramp').
    gL, EL, Vr, a, b, C, DeltaT, tauw, VT : float
        AdEx parameters (see equations (1)-(3)).
    Trefract : float
        Refractory period [ms]. During it, V is clamped to ``Vr`` and w to
        its post-spike value.
    noise : bool
        If true, add Ornstein-Uhlenbeck current noise from ``get_noise``.

    Returns
    -------
    None
        The figure is the only output.

    Raises
    ------
    ValueError
        If ``stimulus`` is not a supported type.
    """
    Input = _make_input(stimulus)

    I_scale = 1e3 / C      # nA -> mV/ms: 1 nA = 1e3 pA, and pA / pF = mV/ms
    b_pA = 1e3 * b         # w is in pA; b is given in nA
    n_hold_steps = round(Trefract / dt)
    theta = VT + 5 * DeltaT  # numerical spike cut-off
    n_t = len(sim_times)

    v0 = EL  # Initial membrane potential
    trace = np.zeros((N, n_t, 2))
    trace[:, 0, :] = [v0, a * (v0 - EL)]
    spikes = np.zeros((N, n_t), dtype=bool)
    hold_steps = np.zeros(N, dtype=int)
    w_reset = np.zeros(N)  # w value held during each neuron's refractory period

    if noise:
        ou = get_noise(N, n_t, dt, taux, sigma)
    else:
        ou = np.zeros((N, n_t))
    drive = np.tile(I_scale * Input, (N, 1)) + I_scale * ou

    state = trace[:, 0, :]
    for i in range(1, n_t):
        rhs = dvdt(sim_times[i], state, EL, gL, DeltaT, VT, C, a, tauw)
        rhs[:, 0] += drive[:, i]
        trace[:, i, :] = state + dt * rhs
        # `state` is a view into `trace`, so the resets below are also what
        # gets recorded for step i.
        state = trace[:, i, :]

        spiked = np.flatnonzero(state[:, 0] > theta)
        if spiked.size:
            # Read w before overwriting it so b is added exactly once.
            w_reset[spiked] = state[spiked, 1] + b_pA
            state[spiked, 0] = Vr
            state[spiked, 1] = w_reset[spiked]
            spikes[spiked, i] = True
            hold_steps[spiked] = n_hold_steps

        refractory = hold_steps > 0
        state[refractory, 0] = Vr
        state[refractory, 1] = w_reset[refractory]
        hold_steps[refractory] -= 1

    trace[spikes, 0] = v_spike  # draw spikes as full-height lines
    _plot_traces(Input, trace, theta)


# Interactive controls for simulate(): each (min, max) tuple becomes a slider,
# the stimulus list a dropdown, and the boolean a checkbox.
interact(
    simulate,
    stimulus=stimuli,
    gL=(0., 40),        # leak conductance [nS]
    EL=(-100., -50),    # leak reversal potential [mV]
    Vr=(-100., -50),    # reset potential [mV]
    a=(0., 10),         # subthreshold adaptation [nS]
    b=(0., 10),         # spike-triggered adaptation increment
    C=(200., 350),      # membrane capacitance [pF]
    DeltaT=(0., 5),     # slope factor [mV]
    tauw=(100., 200),   # adaptation time constant [ms]
    VT=(-65., -30),     # effective threshold potential [mV]
    Trefract=(0., 5),   # refractory period [ms]
    noise=True,
)