In [1]:
# Re-import edited modules (e.g. a locally modified `neurax`) automatically
# before executing code, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Set this before JAX can create an XLA backend, so it is guaranteed to take
# effect. It only limits GPU memory preallocation; with the CPU platform chosen
# below it has no effect, but it is kept for users who switch to a GPU.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

import jax

# `from jax.config import config` is deprecated and removed in recent JAX
# releases; the configuration object is reachable as `jax.config` instead.
# I have experienced stability issues with float32.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

In [3]:
import time
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import neurax as nx
from neurax.channels import HHChannel
from neurax.synapses import GlutamateSynapse

### Setup

In [4]:
# --- Morphology --------------------------------------------------------------
nseg_per_branch = 8  # segments (compartments) per branch

# --- Step-current stimulus ---------------------------------------------------
i_delay = 3.0  # ms
i_dur = 2.0  # ms
i_amp = 0.05  # nA

# --- Integration -------------------------------------------------------------
dt = 0.025  # ms, step size
t_max = 50.0  # ms, simulated duration

In [5]:
# Time points 0, dt, 2*dt, ..., t_max (assumes t_max is a multiple of dt).
# Built from an integer step count rather than `arange` with a float step:
# with a non-integer step the element count depends on floating-point rounding
# of (t_max + dt) / dt and can silently be off by one.
time_vec = jnp.arange(round(t_max / dt) + 1) * dt

### Define stimuli and recordings

In [6]:
# One recording per cell (cells 0-4). The positional arguments are presumably
# (cell index, branch index, location along the branch) — TODO confirm.
recs = [
    nx.Recording(cell_ind, 1, 0.0)
    for cell_ind in range(5)
]

# The same step current, injected at the same position into cells 0 and 1.
stims = [
    nx.Stimulus(
        stim_ind,
        1,
        0.0,
        current=nx.step_current(i_delay, i_dur, i_amp, time_vec),
    )
    for stim_ind in range(2)
]

### Let's define a network

In [19]:
# A single compartment object, reused for every segment of the branch.
comp = nx.Compartment()
branch = nx.Branch([comp] * nseg_per_branch)

# Five-branch cell. `parents[i]` is the parent of branch i, presumably with -1
# marking the root: branches 1 and 2 attach to branch 0, branches 3 and 4 to 1.
cell = nx.Cell([branch] * 5, parents=jnp.asarray([-1, 0, 0, 1, 1]))

In [28]:
# Seed NumPy's global RNG in case connectivity construction draws random
# numbers. `np.random.seed` returns None, so its result is not bound to `_`:
# assigning `_` in the user namespace stops IPython from updating its
# `_`/`__`/`___` output history.
np.random.seed(0)

# One entry per cell (five cells): the number of branches of that cell.
conn_builder = nx.ConnectivityBuilder([cell.total_nbranches] * 5)

# Cells 0-1 and cells 2-4; `fc` presumably builds all-to-all (pre -> post)
# glutamatergic connections between the two groups — TODO confirm direction.
pre_cells = np.arange(0, 2)
post_cells = np.arange(2, 5)
connectivities = [nx.Connectivity(GlutamateSynapse(), conn_builder.fc(pre_cells, post_cells))]

In [29]:
# Five references to the same `cell` object, wired together by `connectivities`.
network = nx.Network([cell] * 5, connectivities)

In [30]:
# Insert an `HHChannel` (presumably Hodgkin-Huxley) mechanism at network level,
# i.e. without first selecting a cell/branch/compartment subset.
network.insert(HHChannel())

### Defining trainable parameters

This follows the same API as `.set_params()` seen in the previous tutorial. To use a single shared `radius` parameter for the entire network, do:

In [31]:
# Called on the network itself (no cell/branch/comp view): one shared trainable
# `radius` for the whole network.
network.make_trainable("radius")

We can also define parameters for individual compartments. To do this, use the `"all"` key. The following defines a separate sodium-conductance (`gNa`) parameter for every compartment in the entire network:

In [32]:
# Selecting "all" at every level (cell -> branch -> comp) yields a separate
# trainable `gNa` for each compartment.
(
    network.cell("all")
    .branch("all")
    .comp("all")
    .make_trainable("gNa")
)

### Making synaptic parameters trainable

Synaptic parameters can be made trainable in exactly the same way. To use a single parameter for all synaptic conductances in the entire network, do:

In [33]:
# Option 1: a single conductance `gS` shared by all synapses in the network.
# This and the per-synapse option in the next cell are alternatives. Under
# Restart & Run All both would execute and register `gS` as trainable twice
# (presumably the later per-synapse setting wins, leaving the shared parameter
# as dead weight — confirm against neurax). Set the flag to True only if you
# also skip the per-synapse cell below.
USE_SHARED_GS = False
if USE_SHARED_GS:
    network.make_trainable("gS")

Alternatively, to use a separate synaptic conductance for each synapse, do:

In [35]:
# Select every glutamatergic synapse individually: one trainable `gS` each.
(
    network
    .GlutamateSynapse("all")
    .make_trainable("gS")
)

### Running the simulation again

Once all parameters are defined, you have to use `.get_parameters()` to obtain all trainable parameters. This is also the time to check how many trainable parameters your network has:

In [37]:
params = network.get_parameters()

# `params` is the pytree that `value_and_grad` differentiates further down, so
# its leaves are array-like; count the scalar entries across all of them.
num_trainable_params = sum(jnp.size(leaf) for leaf in jax.tree_util.tree_leaves(params))
num_trainable_params

You can now run the simulation with the trainable parameters by passing them to the `nx.integrate` function.

In [38]:
# Forward simulation using the current values of all trainable parameters.
s = nx.integrate(
    network,
    stimuli=stims,
    recordings=recs,
    delta_t=dt,
    params=params,
)

### Defining a loss function

Let us define a loss function to be optimized:

In [39]:
def loss(params):
    """Scalar objective to optimise with respect to ``params``.

    Integrates ``network`` with the module-level ``stims``, ``recs`` and ``dt``
    and reduces the entry at index ``[0, -1]`` of the result to a scalar with
    ``jnp.sum``. Presumably the result is shaped (n_recordings, n_timesteps),
    making this the last time point of the first recording — TODO confirm.
    """
    recorded = nx.integrate(
        network,
        stimuli=stims,
        recordings=recs,
        delta_t=dt,
        params=params,
    )
    last_entry = recorded[0, -1]
    return jnp.sum(last_entry)

We can then use `JAX`'s built-in transformations (`value_and_grad` and `jit`) to take the gradient through the entire ODE solve:

In [40]:
# `value_and_grad` returns (loss value, gradient w.r.t. `params`) in one call;
# `jit` compiles the simulation together with its backward pass.
jitted_grad = jax.jit(jax.value_and_grad(loss))

In [41]:
# The first call traces and compiles `loss` plus its gradient. `gradient` has
# the same pytree structure as `params`.
value, gradient = jitted_grad(params)