# Inference on Hodgkin-Huxley model: tutorial

In this tutorial, we use `sbi` to do inference on a [Hodgkin-Huxley model](https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model) from neuroscience (Hodgkin and Huxley, 1952). We will infer two parameters ($\bar g_{Na}$, $\bar g_K$) from a current-clamp recording that we generate synthetically (in practice, this would be an experimental observation).

Note: you can find the original version of this notebook at [https://github.com/mackelab/sbi/blob/main/examples/00_HH_simulator.ipynb](https://github.com/mackelab/sbi/blob/main/examples/00_HH_simulator.ipynb) in the `sbi` repository.

First we are going to import basic packages.

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch

# visualization
import matplotlib as mpl
import matplotlib.pyplot as plt

# sbi
from sbi import utils
from sbi import analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi

import pickle

import sys

sys.path.append("../simulators")


import jupyter_black

jupyter_black.load()

# remove top and right axis from plots
mpl.rcParams["axes.spines.right"] = False
mpl.rcParams["axes.spines.top"] = False

## Different required components

Before running inference, let us define the different required components:

1. observed data
2. simulator
3. prior over model parameters

## 1. Observed data
Let us assume we current-clamped a neuron and recorded the following voltage trace:



In [None]:
# Load the data

with open("HH_observation.pickle", "rb") as f:
    data_dict = pickle.load(f)

observation_trace = data_dict["observation_trace"]
I = data_dict["I"]
dt = data_dict["dt"]
t_on = data_dict["t_on"]
t_off = data_dict["t_off"]
A_soma = data_dict["A_soma"]

In [None]:
# plot the data

fig = plt.figure(figsize=(7, 5))
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
ax = plt.subplot(gs[0])
plt.plot(observation_trace["time"], observation_trace["data"])
plt.ylabel("voltage (mV)")
plt.title("observed data")
plt.setp(ax, xticks=[], yticks=[-80, -20, 40])

ax = plt.subplot(gs[1])
plt.plot(observation_trace["time"], I * A_soma * 1e3, "k", lw=2)
plt.xlabel("time (ms)")
plt.ylabel("input (nA)")

ax.set_xticks([0, max(observation_trace["time"]) / 2, max(observation_trace["time"])])
ax.set_yticks([0, 1.1 * np.max(I * A_soma * 1e3)])
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter("%.2f"))

In fact, this voltage trace was not measured experimentally but synthetically generated by simulating a Hodgkin-Huxley model with particular parameters ($\bar g_{Na}$,$\bar g_K$). We will come back to this point later in the tutorial.

## 2. Simulator

We would like to infer the posterior over the two parameters ($\bar g_{Na}$,$\bar g_K$) of a Hodgkin-Huxley model, given the observed electrophysiological recording above. The model has channel kinetics as in [Pospischil et al. 2008](https://link.springer.com/article/10.1007/s00422-008-0263-8), and is defined by the following set of differential equations (parameters of interest highlighted in orange):

$$
\scriptsize
\begin{align}
C_m\frac{dV}{dt}&=g_{\text{l}}\left(E_{\text{l}}-V\right)+
                    \color{orange}{\bar{g}_{Na}}m^3h\left(E_{Na}-V\right)+
                    \color{orange}{\bar{g}_{K}}n^4\left(E_K-V\right)+
                    \bar{g}_Mp\left(E_K-V\right)+
                    I_{inj}+
                    \sigma\eta\left(t\right)\\
\frac{dq}{dt}&=\frac{q_\infty\left(V\right)-q}{\tau_q\left(V\right)},\;q\in\{m,h,n,p\}
\end{align}
$$

Above, $V$ represents the membrane potential, $C_m$ is the membrane capacitance, $g_{\text{l}}$ is the leak conductance, $E_{\text{l}}$ is the membrane reversal potential, $\bar{g}_c$ is the density of channels of type $c$ ($\text{Na}^+$, $\text{K}^+$, M), $E_c$ is the reversal potential of $c$, ($m$, $h$, $n$, $p$) are the respective channel gating kinetic variables, and $\sigma \eta(t)$ is the intrinsic neural noise. The right hand side of the voltage dynamics is composed of a leak current, a voltage-dependent $\text{Na}^+$ current, a delayed-rectifier $\text{K}^+$ current, a slow voltage-dependent $\text{K}^+$ current responsible for spike-frequency adaptation, and an injected current $I_{\text{inj}}$. Channel gating variables $q$ have dynamics fully characterized by the neuron membrane potential $V$, given the respective steady-state $q_{\infty}(V)$ and time constant $\tau_{q}(V)$ (details in Pospischil et al. 2008).
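
To make the gating equation concrete, here is a minimal sketch of how a single gating variable $q$ relaxes towards its steady state $q_\infty$ when the voltage is held fixed. It uses forward-Euler integration, and the values of `q_inf` and `tau_q` are hypothetical placeholders rather than the Pospischil et al. kinetics.

```python
import numpy as np


def relax_gate(q0, q_inf, tau_q, dt, n_steps):
    """Forward-Euler integration of dq/dt = (q_inf - q) / tau_q at fixed voltage."""
    q = np.empty(n_steps + 1)
    q[0] = q0
    for k in range(n_steps):
        q[k + 1] = q[k] + dt * (q_inf - q[k]) / tau_q
    return q


# Hypothetical values for illustration only (not taken from the model above).
q_trace = relax_gate(q0=0.0, q_inf=0.8, tau_q=2.0, dt=0.01, n_steps=2000)
print(q_trace[-1])  # close to q_inf after ten time constants
```

In the full model, $q_\infty$ and $\tau_q$ change with $V$ at every time step, so the gate continuously chases a moving target.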

The input current $I_{\text{inj}}$ is defined as

In [None]:
from HH_model import syn_current

I, t_on, t_off, dt, t, A_soma = syn_current()
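
The implementation of `syn_current` lives in `HH_model` and is not shown here. As a rough sketch, assuming the input is a simple step current that is switched on at `t_on` and off at `t_off`, it could look like the following. The function name `step_current` and all default values are assumptions made for illustration and may differ from the actual `syn_current`.

```python
import numpy as np


def step_current(duration=120.0, dt=0.01, t_on=10.0, t_off=110.0, amplitude=1.0):
    """Hypothetical step current: `amplitude` between t_on and t_off, zero elsewhere."""
    t = np.arange(0.0, duration + dt, dt)
    current = np.zeros_like(t)
    current[(t >= t_on) & (t < t_off)] = amplitude
    return current, t


I_step, t_step = step_current()
print(I_step.min(), I_step.max())
```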

The Hodgkin-Huxley simulator is given by:

In [None]:
from HH_model import HHsimulator

Putting the input current and the simulator together:

In [None]:
def run_HH_model(params):
    params = np.asarray(params)

    # input current, time step
    I, t_on, t_off, dt, t, A_soma = syn_current()

    t = np.arange(0, len(I), 1) * dt

    # initial voltage
    V0 = -70

    states = HHsimulator(V0, params.reshape(1, -1), dt, t, I)

    return dict(data=states.reshape(-1), time=t, dt=dt, I=I.reshape(-1))

To get an idea of the output of the Hodgkin-Huxley model, let us generate some voltage traces for different parameters ($\bar g_{Na}$,$\bar g_K$), given the input current $I_{\text{inj}}$:

In [None]:
# three sets of (g_Na, g_K, tau)

params = np.array([[50.0, 1.0, 600], [4.0, 1.5, 200], [20.0, 15.0, 800]])

# test for different taus
# params = np.array([[50.0, 5.0, 50], [50, 5, 1000], [50.0, 5.0, 2000]])

num_samples = len(params[:, 0])
sim_samples = np.zeros((num_samples, len(I)))
for i in range(num_samples):
    sim_samples[i, :] = run_HH_model(params=params[i, :])["data"]

In [None]:
# colors for traces
col_min = 0
num_colors = num_samples + col_min
cm1 = mpl.cm.viridis  # mpl.cm.Blues
col1 = [cm1(1.0 * i / num_colors) for i in range(col_min, num_colors)]

fig = plt.figure(figsize=(7, 5))
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
ax = plt.subplot(gs[0])
for i in range(num_samples):
    plt.plot(t, sim_samples[i, :], color=col1[i], lw=2, label=i)
plt.ylabel("voltage (mV)")
ax.set_xticks([])
ax.set_yticks([-80, -20, 40])
plt.legend()

ax = plt.subplot(gs[1])
plt.plot(t, I * A_soma * 1e3, "k", lw=2)
plt.xlabel("time (ms)")
plt.ylabel("input (nA)")

ax.set_xticks([0, max(t) / 2, max(t)])
ax.set_yticks([0, 1.1 * np.max(I * A_soma * 1e3)])
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter("%.2f"))
plt.show()

As can be seen, the voltage traces can be quite diverse for different parameter values.

Often, we are not interested in matching the exact trace, but only in matching certain features thereof. 

In this example of the Hodgkin-Huxley model, the summary features are:
- the number of spikes, 
- the mean resting potential, 
- the standard deviation of the resting potential, 
- and the first four voltage moments: mean, standard deviation, skewness and kurtosis. 

Using the function `calculate_summary_statistics()` imported below, we obtain these statistics from the output of the Hodgkin-Huxley simulator.

In [None]:
from HH_model import calculate_summary_statistics
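
To give a feeling for what such a function might compute, here is a small NumPy sketch of the seven features listed above, applied to a synthetic trace. The function `summary_stats_sketch`, the spike threshold, and the exact definitions are assumptions for illustration; `calculate_summary_statistics` in `HH_model` may define them differently.

```python
import numpy as np


def summary_stats_sketch(v, t, t_on, t_off, threshold=-10.0):
    """Illustrative summary statistics of a voltage trace (assumed definitions)."""
    v_stim = v[(t >= t_on) & (t < t_off)]

    # Spike count: upward crossings of a hypothetical threshold during the stimulus.
    above = v_stim > threshold
    n_spikes = int(np.sum(~above[:-1] & above[1:]))

    # Resting potential: mean and std of the voltage before stimulus onset.
    v_rest = v[t < t_on]
    rest_mean, rest_std = v_rest.mean(), v_rest.std()

    # First four moments of the voltage during the stimulus.
    mean, std = v_stim.mean(), v_stim.std()
    z = (v_stim - mean) / std
    skewness, kurtosis = np.mean(z**3), np.mean(z**4)

    return np.array([n_spikes, rest_mean, rest_std, mean, std, skewness, kurtosis])


# Synthetic trace: rest at -70 mV with five narrow Gaussian "spikes".
t_demo = np.arange(0.0, 120.0, 0.01)
v_demo = np.full_like(t_demo, -70.0)
for t_spike in (20.0, 40.0, 60.0, 80.0, 100.0):
    v_demo += 110.0 * np.exp(-0.5 * ((t_demo - t_spike) / 0.5) ** 2)

stats_demo = summary_stats_sketch(v_demo, t_demo, t_on=10.0, t_off=110.0)
print(stats_demo)
```

Reducing a trace of thousands of time points to a handful of numbers like these is what makes the density estimation problem tractable with only 1,000 simulations.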

Lastly, we define a function that performs all of the above steps at once. The function `simulation_wrapper` takes in conductance values, runs the Hodgkin-Huxley model and then returns the summary statistics.

In [None]:
def simulation_wrapper(params):
    """
    Returns summary statistics from conductance values in `params`.

    Summarizes the output of the HH simulator and converts it to `torch.Tensor`.
    """
    obs = run_HH_model(params)
    summstats = torch.as_tensor(calculate_summary_statistics(obs))
    return summstats

In [None]:
# Let's get a quick intuition for the summary statistics
print(simulation_wrapper(params[1]))
# Returned summary statistics (7 values):
#   spike count,
#   resting potential,
#   std of resting potential,
#   mean of voltage during stimulus,
#   moments of voltage during stimulus

In [None]:
# compare these with the actual voltage trace
sim = run_HH_model(params[1])["data"]
plt.plot(sim)

Now, let's also compute the summary statistics for the observed data:

In [None]:
observation_summary_statistics = torch.as_tensor(
    calculate_summary_statistics(observation_trace)
)

## 3. Prior over model parameters

Now that we have the simulator, we need to define the prior over the model parameters ($\bar g_{Na}$, $\bar g_K$), which in this case is chosen to be a uniform distribution:

In [None]:
prior_min = [0.5, 1e-4, 100]  # [mS/cm2, mS/cm2, ms]
prior_max = [8.0, 1.5, 1000]
prior = utils.torchutils.BoxUniform(
    low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max)
)
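
As a plain NumPy stand-in (not the `sbi` implementation of `BoxUniform`), the following sketch shows what a box-uniform prior amounts to: independent uniform draws within the bounds, and a constant log-density inside the box. The helper names `sample_box_uniform` and `log_prob_box_uniform` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same bounds as the prior defined above.
low = np.array([0.5, 1e-4, 100.0])
high = np.array([8.0, 1.5, 1000.0])


def sample_box_uniform(n, low, high, rng):
    """Draw n independent samples, uniform in each dimension."""
    return rng.uniform(low, high, size=(n, low.size))


def log_prob_box_uniform(theta, low, high):
    """Constant log-density inside the box, -inf outside."""
    theta = np.atleast_2d(theta)
    inside = np.all((theta >= low) & (theta <= high), axis=1)
    log_density = -np.sum(np.log(high - low))
    return np.where(inside, log_density, -np.inf)


theta_demo = sample_box_uniform(5, low, high, rng)
print(theta_demo)
print(log_prob_box_uniform(theta_demo, low, high))
```

Any parameter set outside the box has zero prior probability, so a posterior trained on simulations from this prior cannot place mass there.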

### Data generation

In [None]:
simulator, prior = prepare_for_sbi(simulation_wrapper, prior)

In [None]:
theta, x = simulate_for_sbi(
    simulator, proposal=prior, num_simulations=1_000, num_workers=6
)

In [None]:
# save the data
data_dict = dict(theta=theta, x=x)
with open("HH_dataset.pickle", "wb") as f:
    pickle.dump(data_dict, f)

We already simulated 1_000 traces for you with two different priors:

The first dataset (`HH_dataset.pickle`) is generated from this prior:
```
prior_min = [0.5, 1e-4, 100]  # [mS/cm2, mS/cm2, ms]
prior_max = [8.0, 1.5, 1000]
prior = utils.torchutils.BoxUniform(
    low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max)
)
```

Second dataset (`HH_dataset2.pickle`) from this one:
```
prior_min = [0.5, 1e-4, 100]  # [mS/cm2, mS/cm2, ms]
prior_max = [80.0, 15, 1000]
prior = utils.torchutils.BoxUniform(
    low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max)
)
```

*Question*: Which prior should you choose?

In [None]:
# load the data
num_simulations = 1_000

with open("HH_dataset.pickle", "rb") as f:
    data_dict = pickle.load(f)

theta = data_dict["theta"][:num_simulations]
x = data_dict["x"][:num_simulations]

## 2.1 Check the simulator (Prior predictive check)

For inference it is essential that the simulator is *well specified*. This means that the simulator can generate our observed data, so there is some $\theta_i$ with $p(\theta_i)>0$ such that $p(s(x_o)|\theta_i)>0$. This is in general hard to evaluate, but we can at least look at the *prior predictives* and see if the `observation_summary_statistics` can be reproduced by the simulator.
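
Besides visual inspection, a rough numerical version of this check is to ask, for each summary statistic, whether the observed value lies within the range covered by the prior predictive simulations. The sketch below uses synthetic data, and the helper `prior_predictive_coverage` and its percentile bounds are assumptions for illustration. It only tests each statistic on its own, so it can miss cases where the combination of statistics is implausible.

```python
import numpy as np


def prior_predictive_coverage(x_sim, x_obs, lower=0.5, upper=99.5):
    """Per statistic: does x_obs fall inside the central percentile range of x_sim?"""
    x_sim = np.asarray(x_sim, dtype=float)
    x_obs = np.asarray(x_obs, dtype=float).reshape(-1)
    lo, hi = np.nanpercentile(x_sim, [lower, upper], axis=0)
    return (x_obs >= lo) & (x_obs <= hi)


# Synthetic stand-ins for simulated and observed summary statistics.
rng = np.random.default_rng(0)
x_sim_demo = rng.normal(size=(1000, 3))
x_obs_demo = np.array([0.1, -0.3, 25.0])

covered = prior_predictive_coverage(x_sim_demo, x_obs_demo)
print(covered)  # the third statistic lies far outside the simulated range
```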


In [None]:
labels_sumstats = [
    "spike count",
    "resting potential",
    "std of resting potential",
    "mean v",
    "moment 1",
    "moment 2",
    "moment 3",
    "moment 4",
]

fig, axes = analysis.pairplot(
    x,
    figsize=(8, 8),
    points=observation_summary_statistics,
    points_offdiag={"markersize": 6},
    points_colors="r",
    labels=labels_sumstats,
    points_labels=["observed summary statistics"],
    legend=True,
    title="Prior predictive check",
)
print("Summary stats of observed data: ", observation_summary_statistics)

*Question*: Is this what we expect?

## Inference
Now that we have all the required components, we can run inference with SNPE to identify parameters for which the simulated activity matches the observed trace.

In [None]:
inference = SNPE(prior, density_estimator="maf")

density_estimator = inference.append_simulations(theta, x).train()

In [None]:
plt.figure(1, figsize=(4, 3), dpi=200)
plt.plot(-np.array(inference.summary["training_log_probs"]), label="training")
plt.plot(
    -np.array(inference.summary["validation_log_probs"]), label="validation", alpha=1
)
plt.xlabel("epoch")
plt.ylabel("-log(p)")
plt.legend()

In [None]:
posterior = inference.build_posterior(density_estimator)

### Coming back to the observed data
As mentioned at the beginning of the tutorial, the observed data were generated by the Hodgkin-Huxley model with a set of known parameters ($\bar g_{Na}$, $\bar g_K$). To illustrate how to compute the summary statistics of the observed data, let us regenerate the observed data:

In [None]:
# true parameters and respective labels
true_params = np.array([50.0, 5.0, 800])
labels_params = [r"$g_{Na}$", r"$g_{K}$", r"$\tau$"]

In [None]:
observation_trace = run_HH_model(true_params)
# convert to a tensor, as done for the loaded observation above
observation_summary_statistics = torch.as_tensor(
    calculate_summary_statistics(observation_trace)
)

As already shown above, the observed voltage trace looks as follows:

In [None]:
fig = plt.figure(figsize=(7, 5))
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
ax = plt.subplot(gs[0])
plt.plot(observation_trace["time"], observation_trace["data"])
plt.ylabel("voltage (mV)")
plt.title("observed data")
plt.setp(ax, xticks=[], yticks=[-80, -20, 40])

ax = plt.subplot(gs[1])
plt.plot(observation_trace["time"], I * A_soma * 1e3, "k", lw=2)
plt.xlabel("time (ms)")
plt.ylabel("input (nA)")

ax.set_xticks([0, max(observation_trace["time"]) / 2, max(observation_trace["time"])])
ax.set_yticks([0, 1.1 * np.max(I * A_soma * 1e3)])
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter("%.2f"))

## Analysis of the posterior given the observed data

After running the inference algorithm, let us inspect the inferred posterior distribution over the parameters ($\bar g_{Na}$,$\bar g_K$), given the observed trace. To do so, we first draw samples (i.e. consistent parameter sets) from the posterior:

In [None]:
samples = posterior.sample((10000,), x=observation_summary_statistics)
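
In addition to plotting, posterior samples can be summarized numerically, for example by their mean and a central credible interval per parameter. The sketch below uses synthetic Gaussian samples as a stand-in for `samples`, and the helper `summarize_samples` is hypothetical. With the actual posterior samples, one would pass `samples.numpy()` instead.

```python
import numpy as np


def summarize_samples(samples, true_theta=None, level=0.95):
    """Posterior mean and central credible interval for each parameter."""
    samples = np.asarray(samples, dtype=float)
    alpha = (1.0 - level) / 2.0
    lower, upper = np.quantile(samples, [alpha, 1.0 - alpha], axis=0)
    summary = {"mean": samples.mean(axis=0), "lower": lower, "upper": upper}
    if true_theta is not None:
        true_theta = np.asarray(true_theta, dtype=float)
        # Does the interval contain the ground-truth value of each parameter?
        summary["covers_truth"] = (true_theta >= lower) & (true_theta <= upper)
    return summary


# Synthetic stand-in for posterior samples over two parameters.
rng = np.random.default_rng(0)
samples_demo = rng.normal(loc=[50.0, 5.0], scale=[2.0, 0.5], size=(10_000, 2))
summary_demo = summarize_samples(samples_demo, true_theta=[50.0, 5.0])
print(summary_demo)
```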

In [None]:
fig, axes = analysis.pairplot(
    samples,
    limits=[[0.5, 80], [1e-4, 15.00], [100, 1000]],
    ticks=[[0.5, 80], [1e-4, 15.0], [100, 1000]],
    figsize=(5, 5),
    points=true_params,
    points_offdiag={"markersize": 6},
    points_colors="r",
    labels=labels_params,
    points_labels=["true theta"],
    legend=True,
)

As can be seen, the inferred posterior contains the ground-truth parameters (red) in a high-probability region. Now, let us sample parameters from the posterior distribution, simulate the Hodgkin-Huxley model for these parameter sets and compare the simulations with the observed data. Such simulations are called posterior predictives.

### Posterior predictives

In [None]:
# Draw a sample from the posterior and convert to numpy for plotting.
posterior_sample = posterior.sample((10,), x=observation_summary_statistics).numpy()

In [None]:
# simulate and plot samples
x_posterior = np.array(
    [
        run_HH_model(posterior_sample[i])["data"]
        for i in range(posterior_sample.shape[0])
    ]
)

In [None]:
fig = plt.figure(figsize=(7, 5))

# plot observation
t = observation_trace["time"]
y_obs = observation_trace["data"]


plt.plot(
    t, x_posterior[0], "-", lw=2, label="posterior sample", color="blue", alpha=0.5
)
for i in range(1, 10):
    plt.plot(t, x_posterior[i], "-", lw=2, color="blue", alpha=0.5)

plt.plot(t, y_obs, lw=2, label="observation", color="red")


plt.xlabel("time (ms)")
plt.ylabel("voltage (mV)")

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.3, 1), loc="upper right")

ax.set_xticks([0, 60, 120])
ax.set_yticks([-80, -20, 40]);

As can be seen, the samples from the inferred posterior lead to simulations that closely resemble the observed data, confirming that `SNPE` did a good job at capturing the observed data in this simple case.

## Simulation-based calibration

Let's also check if the posterior is well calibrated.

In [None]:
from sbi.analysis import run_sbc, sbc_rank_plot

In [None]:
# generate test data set from prior
num_sbc_simulations = 100
num_posterior_samples = 1000


theta_test, x_test = simulate_for_sbi(
    simulator, proposal=prior, num_simulations=num_sbc_simulations, num_workers=1
)

In [None]:
# run SBC

ranks, dap_samples = run_sbc(
    theta_test, x_test, posterior, num_posterior_samples=num_posterior_samples
)

In [None]:
fig, ax = sbc_rank_plot(ranks, num_posterior_samples)
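
To build intuition for how the ranks behave, here is a self-contained NumPy sketch of simulation-based calibration on a toy problem with a known posterior (this is not the `sbi` implementation of `run_sbc`). We assume a prior $\theta \sim \mathcal{N}(0, 1)$ and a simulator $x \sim \mathcal{N}(\theta, 1)$, for which the exact posterior is $\mathcal{N}(x/2, 1/2)$. The rank of each true $\theta$ among its posterior samples is computed once with the exact posterior and once with a deliberately too narrow one.

```python
import numpy as np

rng = np.random.default_rng(0)
num_sbc_runs, num_post_samples = 2000, 100

# Toy model: theta ~ N(0, 1), x | theta ~ N(theta, 1).
theta_true = rng.normal(0.0, 1.0, size=num_sbc_runs)
x_sim = theta_true + rng.normal(0.0, 1.0, size=num_sbc_runs)


def sbc_ranks(post_std):
    """Rank of each true theta among samples from N(x / 2, post_std**2)."""
    post_samples = rng.normal(
        x_sim[:, None] / 2.0, post_std, size=(num_sbc_runs, num_post_samples)
    )
    return np.sum(post_samples < theta_true[:, None], axis=1)


ranks_exact = sbc_ranks(np.sqrt(0.5))  # exact posterior standard deviation
ranks_narrow = sbc_ranks(0.3 * np.sqrt(0.5))  # overconfident posterior

bins = dict(bins=10, range=(0, num_post_samples + 1))
hist_exact, _ = np.histogram(ranks_exact, **bins)
hist_narrow, _ = np.histogram(ranks_narrow, **bins)
print("exact: ", hist_exact)
print("narrow:", hist_narrow)
```

With the exact posterior the rank histogram is approximately flat, whereas the overconfident posterior piles ranks up at both ends, because the true parameter often falls outside the bulk of the posterior samples.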

*Question*: Can you explain this plot?

## References


A. L. Hodgkin and A. F. Huxley. A quantitative description of membrane current and its application to conduction and excitation in nerve. The Journal of Physiology, 117(4):500–544, 1952.

M. Pospischil, M. Toledo-Rodriguez, C. Monier, Z. Piwkowska, T. Bal, Y. Frégnac, H. Markram, and A. Destexhe. Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biological Cybernetics, 99(4-5):427–441, 2008.

## Helper cells for the notebook

In [None]:
# save the data
data_dict = dict(
    true_params=true_params,
    observation_trace=observation_trace,
    I=I,
    t_on=t_on,
    t_off=t_off,
    dt=dt,
    t=t,
    A_soma=A_soma,
)
with open("HH_observation.pickle", "wb") as f:
    pickle.dump(data_dict, f)