# Simulation and Inference for Neuroscience
## Final Project: Parameter inference for the pyloric network

The notebook below contains the final project.
The final project is meant to be a bit more open-ended and less structured than the exercises. You are free to answer the questions below as you see fit, and you can reuse as much of your previous code as you want. You will be graded in groups of 2, but you are welcome to collaborate with others.

If you need additional dependencies to solve any of the questions, feel free to install and/or import them. Make sure you document your process and thinking, such that it is clear how you arrived at your final answer.

The project will be graded based on a brief presentation and a follow-up interview about your code and results (10-15 min). Take note of the following:
- Make sure you understand what you are doing and that you can explain and defend your analysis.
- Ensure your notebook is presentable and easy to follow. Intermediate results should be backed up by plots (and/or printouts). Plots should be readable.
- You should be able to present what you did with the help of your plots in about 2-5 minutes.
- Make sure the notebook can be run from start to finish without errors (if you cache intermediate results, it is fine to load them).
- We encourage the use of coding assistants.


**Before you start, please add your name below!**

**Names:** _Firstname Lastname, Firstname Lastname_

---

In the following project you will perform parameter inference on the pyloric network of the stomatogastric ganglion (STG) of crustaceans.

The pyloric network of the stomatogastric ganglion (STG) is a well-studied component of the crustacean nervous system that provides valuable insights into neural circuit function. It consists of about 14 neurons, connected by a set of cholinergic and glutamatergic synapses, that generate a rhythmic motor pattern controlling the pylorus of the crustacean stomach. Computational models of the pyloric network have been instrumental in understanding how neural circuits maintain functional stability despite biological variability and perturbations.

The key neuron types include:
- Anterior Burster (AB) neuron
- Pyloric Dilator (PD) neurons
- Lateral Pyloric (LP) neuron
- Pyloric (PY) neurons

The simplified model that we will use below represents all neurons of the same type by one compartmental model and treats the electrically coupled AB and PD neurons as a single neuron. This leads to the following circuit:

![pyloric network](../assets/pyloric_schematic.png)

For more context you can read the following papers:
- [Similar network activity from disparate circuit parameters](../assets/nn1352.pdf)
- [Alternative to hand-tuning conductance-based models](../assets/alternative-to-hand-tuning-conductance-based-models-tqk0oa7i15.pdf)
- [Training deep neural density estimators to identify mechanistic models of neural dynamics](../assets/elife-56261-v3-4.pdf)

In this project, we will focus on inferring parameters of the synapses.


In [1]:
# configure jax to use 64bit precision and cpu
from jax import config

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

In [2]:
import sys
sys.path.append("..")

from pyloric import PyloricNetwork
import jaxley as jx
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import jit, vmap
import numpy as np
import jax

import torch
from torch import Tensor
from numpy import ndarray
from typing import Union, Optional, Tuple


In [3]:
# utils
def plot_pyloric(ts: Union[ndarray, Tensor], v: Union[ndarray, Tensor], axs: Optional[plt.Axes] = None, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Plot the voltage of the pyloric network for each neuron.

    Args:
        ts: The time points to plot.
        v: The voltage of the pyloric network.
        axs: The axes to plot on. Allows to plot multiple traces in one figure.

    Returns:
        fig: The figure.
        axs: The axes.
    """
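    if axs is None:
        fig, axs = plt.subplots(3, 1, figsize=(10, 5), sharex=True, layout='constrained')
    else:
        # reuse the figure of the provided axes so that `fig` is always defined
        fig = axs[0].get_figure()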
    for ax_i, v_i in zip(axs, v):
        ax_i.plot(ts, v_i, **kwargs)
        ax_i.set_ylabel('V (mV)')
    axs[0].set_title('AB/PD')
    axs[1].set_title('LP')
    axs[2].set_title('PY')
    axs[2].set_xlabel('t (ms)')
    return fig, axs

In [None]:
net = PyloricNetwork() # instantiate the pyloric circuit model
net.record() # record the voltage of all neurons
net.init_states() # initialize the states

# set up the simulation parameters (no stimulus needed)
dt = 0.025
t_max = 4_000
ts = jnp.arange(0, t_max, dt)

@jit # compile the simulator
def simulate(params: jnp.ndarray) -> jnp.ndarray:
    """Simulate the pyloric network.
    
    Args:
        params: The synaptic conductances of shape (7,).
        params[[0, 2, 4, 5, 6]]: Glutamatergic synapses.
        params[[1, 3]]: Cholinergic synapses.
        for details see `PyloricNetwork`.

    Returns:
        v: The voltages of the pyloric network. Shape (3, num_steps).
    """
    # set the synapse parameters
    pstate = None
    for i in [0,2,4,5,6]:
        pstate = net.select(edges=i).data_set("GlutamatergicSynapse_gS", params[i], pstate)
    for i in [1,3]:
        pstate = net.select(edges=i).data_set("CholinergicSynapse_gS", params[i], pstate)
    
    # simulate the network
    v = jx.integrate(net, param_state=pstate, t_max=t_max-dt)
    return v

### Inspect the model
Familiarize yourself with the model. Read the code in `pyloric/channels.py`, `pyloric/synapses.py` and `pyloric/model.py` to gain a rough understanding of how it is implemented. You should pay particular attention to `PyloricNetwork`.

In addition, you can use `.nodes` and `.edges` to inspect the neurons and synapses, respectively.

Also take a closer look at the imported data and try to understand what you are looking at.

In [None]:
# import and plot the data
t_obs, *v_obs = jnp.array(np.loadtxt("../data/pyloric_observation.csv", unpack=True))
v_obs = jnp.array(v_obs)
fig, axs = plot_pyloric(t_obs, v_obs)
plt.show()

### Identify suitable synaptic conductances for the pyloric network
In the following task we will try to identify parameters for the pyloric network that reproduce the observed activity from our experimental data. You are free in your choice of method and how you approach this task.

**Be careful**: the observation is very long (4 s), which makes the simulator quite expensive to run many times. While developing your code, think about how you can get away with running it less often or more cheaply. For example, fit only 1 second (or even less) of the observation and see if the result generalizes to the whole observation. Also, once you are sure your pipeline is working, make sure to cache intermediate results to avoid re-running the simulator.
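As a rough illustration of the two suggestions above, the sketch below crops a trace to its first second and caches an array on disk. The helper names `crop_to_window` and `cached`, the cache file name, and the synthetic sine traces are assumptions made for this example; they are not part of the project code, and in practice you would pass `t_obs`, `v_obs` and your own simulator call instead.

```python
import os
import tempfile

import numpy as np


def crop_to_window(ts, v, t_end):
    """Keep only the samples with t < t_end (in ms).

    ts has shape (num_steps,), v has shape (..., num_steps).
    """
    mask = np.asarray(ts) < t_end
    return np.asarray(ts)[mask], np.asarray(v)[..., mask]


def cached(path, compute_fn):
    """Load an array from `path` if it exists, otherwise compute and save it."""
    if os.path.exists(path):
        return np.load(path)
    result = np.asarray(compute_fn())
    np.save(path, result)
    return result


# Synthetic stand-in for the observation (hypothetical data, 3 traces, 4 s).
ts_demo = np.arange(0, 4_000, 0.025)
v_demo = np.stack([np.sin(2 * np.pi * ts_demo / 500 + p) for p in (0.0, 1.0, 2.0)])

# Fit on the first second only.
ts_short, v_short = crop_to_window(ts_demo, v_demo, t_end=1_000)

# The first call computes and saves, the second call only loads from disk.
cache_path = os.path.join(tempfile.mkdtemp(), "v_short.npy")
first = cached(cache_path, lambda: v_short)
second = cached(cache_path, lambda: v_short * 0.0)  # not evaluated, file exists
```

Note that a simulator compiled with a fixed `t_max` would need its own shorter `t_max` to actually save compute; cropping only the observation does not make the simulation cheaper.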


Make sure the results you obtain are sound. It is sufficient to do this qualitatively, but bonus points if you can do it quantitatively.
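One possible way to make such a check quantitative is to compare a few summary statistics of a simulated trace with those of the observation. The sketch below is only a starting point under stated assumptions: the statistics (mean, standard deviation, number of upward threshold crossings), the threshold of -30 mV, and the function names are all hypothetical choices for illustration, and the sine traces stand in for real voltages.

```python
import numpy as np


def count_upward_crossings(v, threshold=-30.0):
    """Count upward threshold crossings per trace; v has shape (n_traces, num_steps)."""
    above = v > threshold
    return np.sum(~above[:, :-1] & above[:, 1:], axis=1)


def summary_stats(v, threshold=-30.0):
    """Stack per-trace mean, standard deviation and crossing count into one vector."""
    counts = count_upward_crossings(v, threshold).astype(float)
    return np.concatenate([v.mean(axis=1), v.std(axis=1), counts])


def stats_distance(v_sim, v_ref, threshold=-30.0):
    """Root-mean-square relative difference between the two summary vectors."""
    s_sim = summary_stats(v_sim, threshold)
    s_ref = summary_stats(v_ref, threshold)
    scale = np.abs(s_ref) + 1e-8  # avoid division by zero
    return float(np.sqrt(np.mean(((s_sim - s_ref) / scale) ** 2)))


# Hypothetical traces: oscillations around -50 mV with different periods (ms).
ts_demo = np.arange(0, 1_000, 0.025)


def make_trace(period):
    return -50.0 + 30.0 * np.sin(2 * np.pi * ts_demo / period)


v_ref_demo = np.stack([make_trace(100), make_trace(200)])
v_sim_demo = np.stack([make_trace(100), make_trace(250)])

counts = count_upward_crossings(v_ref_demo)
d_same = stats_distance(v_ref_demo, v_ref_demo)
d_diff = stats_distance(v_sim_demo, v_ref_demo)
```

On noisy traces a plain threshold can register several crossings per event, so some smoothing or a refractory window may be needed before such counts are meaningful.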

In [None]:
# Here is an example of how to run the simulator
param_guess = 0.001*jnp.ones(7)
v = simulate(param_guess)
fig, axs = plot_pyloric(ts, v)
plt.show()

In [9]:
# ...and how to parallelize the simulation
parallel_simulate = vmap(simulate)
key = jax.random.PRNGKey(0)

param_batch = param_guess.reshape(1,-1).repeat(5,axis=0) + 0.0001*jax.random.normal(key, (5,7))
v_samples = parallel_simulate(param_batch)

# ... and how to adapt the simulator to the sbi framework
def simulate_for_sbi(theta: Tensor) -> Tensor:
    """Simulate the pyloric network for the given parameters.
    
    Args:
        theta: The parameters of the synapses. Shape (n_samples,7).
        
    Returns:
        v: The voltages of the pyloric network. Shape (n_samples, 3, num_steps).
    """
    theta = theta.to(torch.float64).numpy()
    v = torch.tensor(parallel_simulate(theta)).to(torch.float32)
    noise = torch.randn_like(v) * 0.1 # add a bit of observation noise
    v += noise
    return v

In [None]:
# implement your pipeline here
bounds = (1e-5, 10) # uS
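
Because the bounds span six orders of magnitude, one option is to draw candidate conductances uniformly in log-space rather than in linear space. This is an assumption, not something the project prescribes; the helper `sample_log_uniform` below is hypothetical and uses plain NumPy so that it runs on its own.

```python
import numpy as np

bounds = (1e-5, 10)  # uS, same values as in the cell above
n_params = 7  # number of synaptic conductances


def sample_log_uniform(rng, n_samples, low, high, n_dims):
    """Draw samples whose log10 is uniform between log10(low) and log10(high)."""
    log_low, log_high = np.log10(low), np.log10(high)
    return 10.0 ** rng.uniform(log_low, log_high, size=(n_samples, n_dims))


rng = np.random.default_rng(0)
theta = sample_log_uniform(rng, 100, bounds[0], bounds[1], n_params)  # shape (100, 7)
```

A batch like `theta` has the `(n_samples, 7)` shape that `parallel_simulate` expects. If you fit in log-space, remember to transform back to conductances before calling the simulator.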
