<a href="https://colab.research.google.com/github/nalewkoz/neural-bifurcations/blob/main/single_neuron_2d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title 
# @markdown *Run this cell to import packages, load helper functions, and set default parameters*

# Standard imports

import matplotlib.pyplot as plt
import numpy as np

# Figure settings
import ipywidgets as widgets       # interactive display
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")

# @title Helper functions

def default_pars(**kwargs):
  pars = {}

  # Keep:
  pars["tau"] = 1

  # To play with:
  pars["b0"]   = 0.9
  pars["b1"]   = 1.1
  pars["tauw"] = 2
  pars["I"]     = 0

  # simulation parameters
  pars['T'] = 100.       # Total duration of simulation [ms]
  pars['dt'] = .1       # Simulation time step [ms]

  # External parameters if any
  pars.update(kwargs)

  # Vector of discretized time points [ms]
  pars['range_t'] = np.arange(0, pars['T'], pars['dt'])
  
  return pars


def plot_nullclines(u, w_u_null, w_w_null):

  #plt.figure()
  plt.plot(u, w_u_null, 'b', label='u-nullcline')
  plt.plot(u, w_w_null, 'r', label='w-nullcline')
  plt.xlabel(r'$u$')
  plt.ylabel(r'$w$')
  plt.ylim([-2,2.5])
  plt.legend(loc='best')
  #plt.show()


def my_plot_vector(pars, my_n_skip=2, myscale=5):
  uw_grid = np.linspace(-2.5, 2.5, 20)
  u, w = np.meshgrid(uw_grid, uw_grid)
  dudt, dwdt = uwderivs(u, w, **pars)

  n_skip = my_n_skip

  plt.quiver(u[::n_skip, ::n_skip], w[::n_skip, ::n_skip],
             dudt[::n_skip, ::n_skip], dwdt[::n_skip, ::n_skip],
             angles='xy', scale_units='xy', scale=myscale, facecolor='c')

  plt.xlabel(r'$u$')
  plt.ylabel(r'$w$')

def simulate_uw(u_init, w_init, I, tau, tauw, dt, range_t, b0, b1, **other_pars):
  """Simulate the FitzHugh-Nagumo model with forward Euler (R = 1)."""
  # Initialize activity arrays
  Lt = range_t.size
  u = np.append(u_init, np.zeros(Lt - 1))
  w = np.append(w_init, np.zeros(Lt - 1))
  I_ext = I * np.ones(Lt)

  for k in range(Lt - 1):

    # Euler increment of u: dt * du/dt
    du = dt / tau * (u[k] - u[k]**3/3 - w[k] + I_ext[k])

    # Euler increment of w: dt * dw/dt
    dw = dt / tauw * (b0 + b1*u[k] - w[k])

    # Update using Euler's method
    u[k + 1] = u[k] + du
    w[k + 1] = w[k] + dw

  return u, w

def my_plot_trajectory(pars, mycolor, x_init, mylabel):
  pars = pars.copy()
  pars['u_init'], pars['w_init'] = x_init[0], x_init[1]
  u_tj, w_tj = simulate_uw(**pars)

  plt.plot(u_tj, w_tj, color=mycolor, label=mylabel)
  plt.plot(x_init[0], x_init[1], 'o', color=mycolor, ms=8)
  plt.xlabel(r'$u$')
  plt.ylabel(r'$w$')
  return u_tj, w_tj

def my_plot_intime(pars, u, w):
  ttab = pars['range_t']
  plt.plot(ttab, u, 'k-', label=r'$u$')
  plt.plot(ttab, w, 'k:', label=r'$w$')


def my_plot_trajectories(pars, dx, n, mylabel):
  """
  Expects:
  pars    : Parameter dictionary
  dx      : increment of initial values
  n       : n*n trjectories
  mylabel : label for legend

  Returns:
    figure of trajectory
  """
  pars = pars.copy()
  for ie in range(n):
    for ii in range(n):
      pars['u_init'], pars['w_init'] = dx * ie, dx * ii
      u_tj, w_tj = simulate_uw(**pars)
      if (ie == n-1) & (ii == n-1):
          plt.plot(u_tj, w_tj, 'gray', alpha=0.8, label=mylabel)
      else:
          plt.plot(u_tj, w_tj, 'gray', alpha=0.8)

  plt.xlabel(r'$u$')
  plt.ylabel(r'$w$')

def plot_all(pars, u_tab, u_init, w_init):        
    # Compute nullclines
    w_u_null = get_w_at_u_nullcline(u_tab, **pars)
    w_w_null = get_w_at_w_nullcline(u_tab, **pars)

    plt.figure(figsize=(16,7))
    plt.subplot(1,2,1)
    # Visualize nullclines
    plot_nullclines(u_tab, w_u_null, w_w_null)

    # Vector field
    my_plot_vector(pars)

    # Trajectory from the given initial condition
    u_traj, w_traj = my_plot_trajectory(pars, 'gray', [u_init, w_init], 'trajectory')

    plt.subplot(1,2,2)
    my_plot_intime(pars, u_traj, w_traj)
    plt.xlabel('time')
    plt.ylabel('u/w')
    plt.legend()

# Section 1: 2D models of neurons

The Hodgkin-Huxley (HH) model is not easy to analyze, especially without simulations. Leaky integrate-and-fire models are simple, but seem "artificial": action potentials are generated by construction, through an arbitrary threshold. Fortunately, there are approximate two-dimensional models that let us understand possible mechanisms of spike generation in more detail. The general form of these two-dimensional models is
\begin{align}
\tau \dot{u} &= F(u,w) + R I
\\
\tau_w \dot{w} &= G(u,w),
\end{align}
where $u$, $I$, and $R$ denote the membrane voltage, the input current, and the membrane resistance, respectively, and $F$ and $G$ are, for now, unspecified functions. The variable $w$ is known as the 'recovery variable'. In the reduction from the HH model, the fast sodium activation variable $m$ is replaced by its instantaneous steady-state value, and $w$ lumps together the slower gating variables $h$ (sodium inactivation) and $n$ (potassium activation).

Importantly, $\tau_w \gg \tau$, i.e., the recovery variable evolves slowly compared to the voltage. This separation of time scales is part of the reason why leaky integrate-and-fire models work pretty well.
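To see how the general form above can be simulated, here is a minimal forward-Euler sketch (the same scheme the helper `simulate_uw` uses for the FitzHugh-Nagumo case). It assumes `F` and `G` are passed in as Python callables; `simulate_2d` is a hypothetical helper, not part of this notebook, and the linear placeholder model is chosen only to illustrate the effect of $\tau_w \gg \tau$.

```python
import numpy as np


def simulate_2d(F, G, u0, w0, I=0.0, R=1.0, tau=1.0, tauw=10.0, dt=0.01, T=50.0):
    """Forward-Euler integration of
         tau  * du/dt = F(u, w) + R * I
         tauw * dw/dt = G(u, w)
    F and G are arbitrary callables (hypothetical placeholders for a concrete model).
    Returns arrays u and w with n_steps + 1 samples, including the initial state.
    """
    n_steps = int(round(T / dt))
    u = np.empty(n_steps + 1)
    w = np.empty(n_steps + 1)
    u[0], w[0] = u0, w0
    for k in range(n_steps):
        # Both increments are computed from the state at step k
        u[k + 1] = u[k] + dt / tau * (F(u[k], w[k]) + R * I)
        w[k + 1] = w[k] + dt / tauw * G(u[k], w[k])
    return u, w


# Placeholder linear model: both variables relax to 0, but with tauw = 10 * tau
# the recovery variable w decays ten times more slowly than u.
u, w = simulate_2d(lambda u, w: -u, lambda u, w: -w, u0=1.0, w0=1.0, T=5.0)
# After T = 5: u has decayed to below 0.01, while w is still about 0.6
```

Swapping in the FitzHugh-Nagumo choices of `F` and `G` from the next section turns this into the same simulation that `simulate_uw` performs.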

# Section 2: FitzHugh-Nagumo model

The FitzHugh-Nagumo model is obtained with the choice:
\begin{align}
F(u,w) &= u - \frac{1}{3}u^3 - w
\\
G(u,w) &= b_0 + b_1 u - w.
\end{align}
Note that both $F$ and $G$ are linear in $w$. The only nonlinearity is pretty simple: a cubic term in $u$. In the code, we set $R = 1$, so the input $I$ enters the $u$-equation directly.
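A short calculation shows what the cubic term does to the geometry of the phase plane (keeping $R$ explicit, as in Section 1). Setting $\dot u = 0$ and differentiating the resulting curve with respect to $u$:
\begin{align}
\dot u = 0 \quad &\Longleftrightarrow \quad w = u - \tfrac{1}{3}u^3 + RI,
\\
\frac{d}{du}\left(u - \tfrac{1}{3}u^3\right) &= 1 - u^2 = 0 \quad \Longleftrightarrow \quad u = \pm 1.
\end{align}
So the $u$-nullcline is N-shaped, with a local minimum at $\left(-1, -\tfrac{2}{3} + RI\right)$ and a local maximum at $\left(1, \tfrac{2}{3} + RI\right)$; changing $I$ shifts it vertically. Equating it with the $w$-nullcline $w = b_0 + b_1 u$ gives the condition for an intersection,
\begin{align}
-\tfrac{1}{3}u^3 + (1 - b_1)\,u + RI - b_0 = 0,
\end{align}
whose left-hand side has derivative $-u^2 + (1 - b_1)$ with respect to $u$. For $b_1 \ge 1$ (e.g. the default $b_1 = 1.1$) the left-hand side is strictly decreasing, so the two nullclines intersect exactly once.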

## Section 2.1: Nullclines

Here, we will plot the nullclines.
* The u-nullcline is the set of points defined by the equation $\dot{u} = 0$.
* The w-nullcline is the set of points defined by the equation $\dot{w} = 0$.

Because both $F$ and $G$ are linear in $w$, each nullcline equation can be solved explicitly for $w$. Thus, for each nullcline we will compute $w$ as a function of $u$.

Plot the nullclines for different values of $I$ (you can also play with the other parameters) and answer the following questions:
1. What kind of object are nullclines geometrically?
2. What happens at the intersections of two nullclines?

In [None]:
def get_w_at_u_nullcline(u, I, **other_pars):
    '''w as a function of u at the u-nullcline'''
    ######################################################################
    # TODO for students: return the value of w as a function of u at the 
    # u-nullcline
    raise NotImplementedError("Student exercise: compute the vector field")
    ######################################################################
    return ...

def get_w_at_w_nullcline(u, b0, b1, **other_pars):
    '''w as a function of u at the w-nullcline'''
    return b0 + b1*u

# Set parameters
pars = default_pars(b0=0.9, b1=1.1, I=0)
u_tab = np.linspace(-2.5, 2.5, 100)

# Compute nullclines
w_u_null = get_w_at_u_nullcline(u_tab, **pars)
w_w_null = get_w_at_w_nullcline(u_tab, **pars)

# Visualize
plot_nullclines(u_tab, w_u_null, w_w_null)
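The intersections seen in the plot can also be computed. A minimal sketch, assuming $R = 1$ as in the code: equating the two nullclines yields a cubic in $u$, and `np.roots` returns its roots, of which only the real ones are intersections. The helper name `nullcline_intersections` is hypothetical, not part of the notebook.

```python
import numpy as np


def nullcline_intersections(b0=0.9, b1=1.1, I=0.0):
    """Fixed points of the FitzHugh-Nagumo model (R = 1).

    Equating the u-nullcline w = u - u**3/3 + I with the w-nullcline
    w = b0 + b1*u gives the cubic -u**3/3 + (1 - b1)*u + (I - b0) = 0.
    """
    # np.roots takes polynomial coefficients from the highest power down
    roots = np.roots([-1.0 / 3.0, 0.0, 1.0 - b1, I - b0])
    # Keep only the real roots; complex roots are not points of the phase plane
    u_star = np.sort(roots[np.abs(roots.imag) < 1e-7].real)
    w_star = b0 + b1 * u_star
    return u_star, w_star


u_star, w_star = nullcline_intersections(b0=0.9, b1=1.1, I=0.0)
# With these defaults there is a single intersection near (u, w) = (-1.321, -0.553)
```

Adding `plt.plot(u_star, w_star, 'ko')` after `plot_nullclines` would mark the intersections on the figure.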

## Section 2.2: Vector field and trajectories

At each time $t$, the values $u(t)$ and $w(t)$ correspond to a single point $(u(t),w(t))$ in the phase plane. Therefore, the time evolution of the system traces out a continuous curve in the phase plane, called a trajectory. The tangent vector to the trajectory, $\left(\frac{du(t)}{dt},\frac{dw(t)}{dt}\right)$, indicates the direction in which the state is evolving and how fast it changes along each axis. In fact, because the right-hand sides of the model equations depend only on $(u,w)$ (for constant $I$), we can compute the tangent vector $\left(\frac{du}{dt},\frac{dw}{dt}\right)$ at every point $(u,w)$ of the phase plane; it describes how the system behaves whenever a trajectory passes through that point.

The map of tangent vectors in the phase plane is called the **vector field**. The behavior of any trajectory in the phase plane is determined by i) the initial conditions $(u(0),w(0))$ and ii) the vector field $(\dot{u}, \dot{w})$.

In plots, the value of the vector field at a particular point of the phase plane is represented by an arrow. The orientation and the length of the arrow reflect the direction and the norm of the vector, respectively.
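To make the arrows concrete, here is a minimal sketch that evaluates $(\dot u, \dot w)$ for the FitzHugh-Nagumo model at a single point and on a grid, assuming $R = 1$ and the notebook's default parameters ($b_0 = 0.9$, $b_1 = 1.1$, $\tau = 1$, $\tau_w = 2$, $I = 0$). `fhn_vector_field` is a hypothetical helper written for illustration; the grid mirrors the one in `my_plot_vector`.

```python
import numpy as np


def fhn_vector_field(u, w, b0=0.9, b1=1.1, I=0.0, tau=1.0, tauw=2.0):
    """Return (du/dt, dw/dt) of the FitzHugh-Nagumo model with R = 1.

    Works elementwise, so u and w may be scalars or arrays of the same shape.
    """
    dudt = (u - u**3 / 3 - w + I) / tau
    dwdt = (b0 + b1 * u - w) / tauw
    return dudt, dwdt


# One point: at (u, w) = (0, 0) the vector is (0.0, 0.45), i.e. the arrow points straight up
du0, dw0 = fhn_vector_field(0.0, 0.0)

# A 20 x 20 grid over [-2.5, 2.5]^2, as used for a quiver plot
grid = np.linspace(-2.5, 2.5, 20)
U, W = np.meshgrid(grid, grid)
dU, dW = fhn_vector_field(U, W)
speed = np.hypot(dU, dW)   # norm of each vector, i.e. the arrow length up to the quiver scale
```

Because the function works elementwise on NumPy arrays, the same code serves both the single-point check and the whole grid passed to `plt.quiver`.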

Now let's plot and analyze the vector field and a sample trajectory. 

1. Implement the function `uwderivs` that calculates $\dot{u} = \frac{du}{dt}$ and $\dot{w} = \frac{dw}{dt}$.
2. Change the initial conditions (focus on $u$) for $I=0$. What kind of stimulation protocol does this correspond to? Do we observe a threshold behavior here? Are the action potentials stereotyped?
3. Now change $I$. Can we define a threshold? What kind of bifurcation does the system exhibit?
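One way to explore question 2 numerically: start near the resting state, displace only $u$, and record the peak of $u(t)$. A minimal sketch assuming $R = 1$, $I = 0$, the default $b_0$ and $b_1$, and $\tau_w = 5$ (the value used in the `plot_all` example); `peak_u_after_kick` is a hypothetical helper, and `w_init = -0.553` is approximately the resting value of $w$ for these parameters.

```python
import numpy as np


def peak_u_after_kick(u_init, w_init=-0.553, b0=0.9, b1=1.1, I=0.0,
                      tau=1.0, tauw=5.0, dt=0.01, T=50.0):
    """Euler-integrate the FitzHugh-Nagumo model (R = 1) from (u_init, w_init)
    and return the largest value of u reached, including the initial value."""
    u, w = u_init, w_init
    u_max = u
    for _ in range(int(round(T / dt))):
        du = dt / tau * (u - u**3 / 3 - w + I)
        dw = dt / tauw * (b0 + b1 * u - w)
        u, w = u + du, w + dw
        u_max = max(u_max, u)
    return u_max


# Displace u from (approximately) rest at u = -1.32 while keeping w fixed
for u0 in np.linspace(-1.3, 0.5, 7):
    print(f"u_init = {u0:+.2f}  ->  peak u = {peak_u_after_kick(u0):+.2f}")
```

Plotting the peak against `u_init` on a finer grid shows whether the transition from small to large excursions is sharp (an all-or-none threshold) or gradual.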



In [None]:
def uwderivs(u, w, b0, b1, I, tau, tauw, **other_pars):
  """Time derivatives for u and w"""
  ######################################################################
  # TODO for students: compute dudt and dwdt
  raise NotImplementedError("Student exercise: compute the vector field")
  ######################################################################

  # Compute the derivative of u
  dudt = ...

  # Compute the derivative of w
  dwdt = ...

  return dudt, dwdt


# Set parameters
pars = default_pars(b0=0.9, b1=1.1, tauw=5, I=0.55)
u_init = -1.5
w_init = -0.6
u_tab = np.linspace(-2.5, 2.5, 100)
plot_all(pars, u_tab, u_init, w_init)
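For question 3, a linear stability analysis of the fixed point complements the simulations. A minimal sketch, assuming $R = 1$ and the parameters of the cell above ($b_0 = 0.9$, $b_1 = 1.1$, $\tau = 1$, $\tau_w = 5$): for each $I$ it finds the fixed point and the eigenvalues of the Jacobian there. If the eigenvalues whose real part changes sign are a complex-conjugate pair, the loss of stability is a Hopf bifurcation; whether it is sub- or supercritical cannot be decided from the linearization alone. The helper names `fhn_fixed_points` and `fhn_jacobian` are hypothetical.

```python
import numpy as np


def fhn_fixed_points(b0, b1, I):
    """Real roots u* of -u**3/3 + (1 - b1)*u + (I - b0) = 0 (nullcline intersections, R = 1)."""
    roots = np.roots([-1.0 / 3.0, 0.0, 1.0 - b1, I - b0])
    return np.sort(roots[np.abs(roots.imag) < 1e-7].real)


def fhn_jacobian(u, b1, tau, tauw):
    """Jacobian of (du/dt, dw/dt) at voltage u; b0 and I are constant terms and drop out."""
    return np.array([[(1.0 - u**2) / tau, -1.0 / tau],
                     [b1 / tauw,          -1.0 / tauw]])


b0, b1, tau, tauw = 0.9, 1.1, 1.0, 5.0
for I in np.arange(0.0, 1.51, 0.25):
    for u_star in fhn_fixed_points(b0, b1, I):
        eig = np.linalg.eigvals(fhn_jacobian(u_star, b1, tau, tauw))
        kind = "stable" if np.all(eig.real < 0) else "unstable"
        print(f"I = {I:.2f}: u* = {u_star:+.3f}, "
              f"eigenvalues = {np.round(eig, 3)}, {kind}")
```

With these parameters the fixed point should lose stability between $I = 0.5$ and $I = 0.75$ and regain it between $I = 1.0$ and $I = 1.25$; comparing with `plot_all` at values of $I$ in that window shows what the trajectories do once the fixed point is unstable.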