# Obtaining the gradient and training (optimizing) the parameters

In this tutorial, we describe how to use JAX's automatic differentiation to obtain gradients through `jaxley` simulations and how to optimize the parameters with the Adam optimizer.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# I have experienced stability issues with float32.
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".8"

In [3]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import jaxley as jx
from jaxley.channels import HHChannel
from jaxley.synapses import GlutamateSynapse

### Setup

In [4]:
# Number of segments per branch.
nseg_per_branch = 8

# Stimulus.
i_delay = 3.0  # ms
i_amp = 0.05  # nA
i_dur = 2.0  # ms

# Duration and step size.
dt = 0.025  # ms
t_max = 50.0  # ms

In [5]:
time_vec = jnp.arange(0.0, t_max+dt, dt)

### Let's define a network

In [6]:
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))

In [7]:
_ = np.random.seed(0)
conn_builder = jx.ConnectivityBuilder([cell.total_nbranches for _ in range(5)])
connectivities = [jx.Connectivity(GlutamateSynapse(), conn_builder.fc(np.arange(0, 2), np.arange(2, 5)))]

In [11]:
network = jx.Network([cell for _ in range(5)], connectivities)

In [12]:
network.insert(HHChannel())

In [13]:
for cell_ind in range(5):
    network.cell(cell_ind).branch(1).comp(0.0).record()
    
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
for stim_ind in range(2):
    network.cell(stim_ind).branch(1).comp(0.0).stimulate(current)
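
To make the stimulus concrete, here is a minimal sketch of what a step current with these arguments could look like as an array. The function `step_current_sketch` is hypothetical and is not how `jx.step_current` is implemented; in particular, the boundary convention (current on from `i_delay` up to, but not including, `i_delay + i_dur`) and the inclusion of the final time point are assumptions.

```python
import jax.numpy as jnp

def step_current_sketch(i_delay, i_dur, i_amp, dt, t_max):
    """Hypothetical stand-in: zero everywhere except a plateau of height i_amp."""
    # Work with integer step indices to avoid floating-point boundary issues.
    n_steps = int(round(t_max / dt)) + 1
    start = int(round(i_delay / dt))
    stop = int(round((i_delay + i_dur) / dt))
    return jnp.zeros(n_steps).at[start:stop].set(i_amp)

# Same values as in the setup cell above.
sketch = step_current_sketch(3.0, 2.0, 0.05, 0.025, 50.0)
```

With the values from the setup cell, the plateau spans `i_dur / dt = 80` time steps.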

### Defining trainable parameters

This follows the same API as `.set_params()` seen in the previous tutorial. If you want to share a single `radius` parameter across the entire network, do:

In [14]:
network.make_trainable("radius")

Number of newly added trainable parameters: 1. Total number of trainable parameters: 1


We can also define parameters for individual compartments. To do this, use the `"all"` key. The following defines a separate sodium conductance parameter for every compartment in the entire network:

In [15]:
network.cell("all").branch("all").comp("all").make_trainable("gNa")

Number of newly added trainable parameters: 200. Total number of trainable parameters: 201


### Making synaptic parameters trainable

Synaptic parameters can be made trainable in exactly the same way. To use a single parameter for all synaptic conductances in the entire network, do

In [16]:
network.make_trainable("gS")

Number of newly added trainable parameters: 1. Total number of trainable parameters: 202


and to use a separate synaptic conductance for every synapse, do

In [17]:
network.GlutamateSynapse("all").make_trainable("gS")

Number of newly added trainable parameters: 6. Total number of trainable parameters: 208


### Running the simulation again

Once all parameters are defined, use `.get_parameters()` to obtain all trainable parameters. This is also a good time to check how many trainable parameters your network has:

In [18]:
params = network.get_parameters()
# print(params)
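
One way to check the count is to sum the sizes of all arrays in the parameter container. The sketch below assumes that `params` is a JAX pytree (for example a list of dictionaries of arrays); `stand_in_params` is a hypothetical stand-in whose shapes mirror the counts printed above and is not actual `jaxley` output.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `network.get_parameters()`; the structure is an assumption.
stand_in_params = [
    {"radius": jnp.ones(1)},   # one radius shared by the whole network
    {"gNa": jnp.ones(200)},    # one gNa per compartment
    {"gS": jnp.ones(1)},       # one gS shared by all synapses
    {"gS": jnp.ones(6)},       # one gS per synapse
]

# Flatten the pytree into its array leaves and add up their sizes.
leaves = jax.tree_util.tree_leaves(stand_in_params)
n_trainable = sum(leaf.size for leaf in leaves)
print(n_trainable)  # 208
```

If the assumption holds, the same two lines applied to `params` should reproduce the total reported by `make_trainable`.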

You can now run the simulation with the trainable parameters by passing them to the `jx.integrate` function.

In [19]:
s = jx.integrate(network, delta_t=dt, params=params)

### Defining a loss function

Let us define a loss function to be optimized:

In [20]:
def loss(params):
    s = jx.integrate(network, delta_t=dt, params=params)
    return jnp.sum(s[0, -1])

We can now use `JAX`'s built-in functions to take the gradient through the entire ODE:

In [21]:
jitted_grad = jit(value_and_grad(loss))

In [22]:
value, gradient = jitted_grad(params)
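
To illustrate what differentiating through an ODE solve means, here is a self-contained toy example that does not use `jaxley`. It integrates a leaky integrator with forward Euler and differentiates the final value with respect to the leak conductance. The names `simulate` and `toy_loss` and all constants are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

def simulate(g, dt=0.025, n_steps=400, i_ext=1.0):
    """Forward-Euler solve of dv/dt = -g * v + i_ext, starting from v = 0."""
    def step(v, _):
        v_next = v + dt * (-g * v + i_ext)
        return v_next, v_next

    # `lax.scan` unrolls the time loop in a way that JAX can differentiate.
    _, voltages = jax.lax.scan(step, jnp.zeros(()), None, length=n_steps)
    return voltages

def toy_loss(g):
    # As in the tutorial's loss: the value at the final time step.
    return simulate(g)[-1]

toy_grad = jit(value_and_grad(toy_loss))
toy_value, toy_gradient = toy_grad(0.5)
```

The gradient is negative here: a larger leak conductance lowers the voltage reached at the end of the simulation.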

### Training

We will use the Adam optimizer from the [optax library](https://optax.readthedocs.io/en/latest/) to optimize the free parameters (you have to install the package with `pip install optax` first):

In [23]:
import optax

Before training, however, we constrain all parameters to lie within a prespecified range (so that, e.g., conductances cannot become negative):

In [24]:
transform = jx.ParamTransform(
    lowers={"gNa": 0.05, "gK": 0.01, "gLeak": 0.0001, "radius": 0.1, "length": 1.0, "axial_resistivity": 500.0, "gS": 0.01}, 
    uppers={"gNa": 1.1, "gK": 0.3, "gLeak": 0.001, "radius": 5.0, "length": 20.0, "axial_resistivity": 5500.0, "gS": 5.0}, 
)
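
A common way to implement such a box constraint is a scaled logistic sigmoid. The sketch below shows the idea in plain JAX; whether `jx.ParamTransform` uses exactly this mapping is an assumption, and `forward_sketch` and `inverse_sketch` are hypothetical names. The direction follows the usage in this tutorial: forward maps unconstrained values into the bounds, inverse maps bounded values back.

```python
import jax
import jax.numpy as jnp

def forward_sketch(x, lower, upper):
    """Map an unconstrained value to the interval (lower, upper)."""
    return lower + (upper - lower) * jax.nn.sigmoid(x)

def inverse_sketch(p, lower, upper):
    """Map a value inside (lower, upper) back to the unconstrained space (logit)."""
    scaled = (p - lower) / (upper - lower)
    return jnp.log(scaled) - jnp.log1p(-scaled)

# Bounds for gNa as in the cell above; 0.12 is an arbitrary value inside them.
lower, upper = 0.05, 1.1
g_na = jnp.asarray(0.12)

unconstrained = inverse_sketch(g_na, lower, upper)
recovered = forward_sketch(unconstrained, lower, upper)

# Even extreme unconstrained values stay within the bounds.
far_out = forward_sketch(jnp.asarray([-50.0, 50.0]), lower, upper)
```

Because the optimizer only ever sees the unconstrained values, it can take arbitrary steps without pushing a parameter outside its range.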

Let's modify the loss function accordingly:

In [25]:
def loss(params):
    params = transform.forward(params)
    s = jx.integrate(network, delta_t=dt, params=params)
    return jnp.sum(s[0, -1])

jitted_grad = jit(value_and_grad(loss))
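# The loss now expects unconstrained parameters, so map `params` through the inverse transform first.
value, gradient = jitted_grad(transform.inverse(params))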

Then we define the optimizer:

In [26]:
opt_params = transform.inverse(params)
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(opt_params)

In [27]:
epoch_losses = []

for epoch in range(5):
    loss_val, gradient = jitted_grad(opt_params)
    updates, opt_state = optimizer.update(gradient, opt_state)
    opt_params = optax.apply_updates(opt_params, updates)

    print(f"epoch {epoch}, loss {loss_val}")
    epoch_losses.append(loss_val)
    
final_params = transform.forward(opt_params)

epoch 0, loss -64.97740512171988
epoch 1, loss -65.03050878143252
epoch 2, loss -65.07820463321355
epoch 3, loss -65.12091871011332
epoch 4, loss -65.15909222980945


Indeed, the loss decreases at every epoch.