In [1]:
# Automatically re-import edited modules (e.g. the local `neurax` package)
# before each cell executes, so library edits show up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".2"

import time
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from copy import deepcopy

import jax
import jax.numpy as jnp
from jax import jit, lax, grad, value_and_grad
import optax
import distrax

from neuron import h
_ = h.load_file("stdlib.hoc")

from neurax.integrate import solve
from neurax.cell import Cell
from neurax.stimulus import Stimulus, step_current
from neurax.recording import Recording
from neurax.connection import Connection

--No graphics will be displayed.


In [3]:
# Platform to run on: "cpu" or "gpu".
device_str = "cpu"
jax.config.update('jax_platform_name', device_str)

cpus = jax.devices("cpu")
# `jax.devices("gpu")` raises RuntimeError when no GPU backend is available,
# which would crash this cell on CPU-only machines even though only the CPU
# was requested. Only propagate that error when the GPU is actually needed.
try:
    gpus = jax.devices("gpu")
except RuntimeError:
    if device_str == "gpu":
        raise
    gpus = []

device = cpus[0] if device_str == "cpu" else gpus[0]

# Solve

### Define model

In [79]:
def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), mapping real values onto (0, 1).

    Delegates to `jax.nn.sigmoid`, whose gradient stays finite for large |x|.
    The hand-written ``1 / (1 + jnp.exp(-x))`` overflows ``exp(-x)`` to inf
    for x below roughly -88 in float32, and reverse-mode autodiff then
    produces ``0 * inf = nan`` even though the true gradient is ~0.
    """
    return jax.nn.sigmoid(x)

def expit(x):
    """Logit, i.e. the inverse of `sigmoid`: log(x / (1 - x)).

    Despite its name this is *not* scipy's ``expit`` (which is the logistic
    sigmoid itself); the name is kept because the parameter transform below
    calls it.

    Args:
        x: Value(s) in the open interval (0, 1).

    Returns:
        Unconstrained value(s) y with sigmoid(y) == x. x == 0 maps to -inf,
        x == 1 to +inf, and inputs outside [0, 1] give nan.
    """
    # Algebraically equal to -log(1/x - 1), but avoids the cancellation in
    # `1/x - 1` when x is close to 1 (i.e. parameters near the upper bound).
    return jnp.log(x) - jnp.log1p(-x)

# Box constraints for the trainable conductances, in the order they are
# stacked into `params`: (gNa, gKd, gLeak). The bounds are per-channel
# scalars shaped (3, 1, 1) so they broadcast against parameter arrays of
# shape (3, num_cells, ncomp). This keeps the cell independent of
# `num_cells` / `ncomp`, which are only defined in the following cell, so
# the notebook survives Restart & Run All.
lower_gna, upper_gna = 0.06, 0.18
lower_gkd, upper_gkd = 0.02, 0.05
lower_gleak, upper_gleak = 0.0001, 0.0005

lower = jnp.asarray([lower_gna, lower_gkd, lower_gleak])[:, None, None]
upper = jnp.asarray([upper_gna, upper_gkd, upper_gleak])[:, None, None]

# Bijector between the optimizer's unconstrained space and the box
# [lower, upper]: forward squashes with `sigmoid` and rescales, inverse
# rescales to (0, 1) and applies the logit (defined above as `expit`).
tf = distrax.Lambda(
    forward=lambda x: sigmoid(x) * (upper - lower) + lower,
    inverse=lambda x: expit((x - lower) / (upper - lower)),
)

In [80]:
# Network layout: `num_cells` identical cells, each made of `num_branches`
# branches with `nseg_per_branch` compartments per branch.
num_cells = 8
nseg_per_branch = 4
num_branches = 15
ncomp = num_branches * nseg_per_branch
# Parent branch of each branch; -1 presumably marks the root (branch 0).
parents = jnp.asarray([-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])

assert len(parents) == num_branches

# Every per-compartment array below holds one value per compartment of every
# cell. `dtype=float` yields JAX's default float dtype, matching what
# `jnp.asarray` of a nested list of Python floats produces.
comp_shape = (num_cells, ncomp)

# Channel conductances, uniform over all compartments.
g_na = jnp.full(comp_shape, 0.12, dtype=float)
g_kd = jnp.full(comp_shape, 0.036, dtype=float)
g_leak = jnp.full(comp_shape, 0.0003, dtype=float)
params = jnp.stack([g_na, g_kd, g_leak])  # (3, num_cells, ncomp)

# Initial state: voltage plus the gating variables m, h, n (presumably
# Hodgkin-Huxley style), uniform over all compartments.
init_v = jnp.full(comp_shape, -62.0, dtype=float)
init_m = jnp.full(comp_shape, 0.074901, dtype=float)
init_h = jnp.full(comp_shape, 0.4889, dtype=float)
init_n = jnp.full(comp_shape, 0.3644787, dtype=float)
u = jnp.stack([init_v, init_m, init_h, init_n])  # (4, num_cells, ncomp)

length = 100.0  # um (length of a single branch)
radius = 10.0  # um
r_a = 10_000  # ohm cm
i_delay = 5.0  # ms
i_amp = 1.0  # nA
i_dur = 2.0  # ms
dt = 0.025  # ms
t_max = 20.0  # ms

### Set up model

In [81]:
# Sample times 0, dt, ..., t_max (stop is t_max + dt so that t_max itself is
# included). NOTE(review): arange with a float step can gain or lose a sample
# through rounding; confirm len(time_vec) matches the number of steps `solve`
# uses for the stimulus currents built from it.
time_vec = jnp.arange(0, t_max+dt, dt)

In [82]:
# All cells share one morphology. NOTE(review): `[Cell(...)] * num_cells`
# puts the *same* Cell object in every slot; fine as long as `solve` treats
# cells as read-only — confirm before mutating any entry.
cells = [Cell(num_branches, parents, nseg_per_branch, length, radius, r_a)] * num_cells

# Step current injected at loc 0.0 of branch 0 in cells 0 and 1.
stims = [
    Stimulus(
        cell_ind=cell,
        branch_ind=0,
        loc=0.0,
        current=step_current(i_delay, i_dur, i_amp, time_vec),
    )
    for cell in (0, 1)
]

# Record at loc 0.0 of branch 0 and loc 1.0 of branch 6 in cells 0, 1, 3, 4.
recs = [
    Recording(cell, branch, loc)
    for cell in (0, 1, 3, 4)
    for branch, loc in ((0, 0.0), (6, 1.0))
]

# Synapses from cells 0 and 1 onto cell 3; argument layout presumably
# (pre cell, pre branch, pre loc, post cell, post branch, post loc) — TODO
# confirm against neurax.connection.Connection.
conns = [
    Connection(pre_cell, 0, 0.0, 3, 0, post_loc, synaptic_cond=1.0)
    for pre_cell, post_loc in ((0, 0.0), (1, 0.3))
]

# Gradient

In [83]:
# Time-step indices passed to `solve` as checkpoints (presumably for
# gradient checkpointing of the time loop — TODO confirm in neurax).
checkpoint_inds = [200, 400, 600]


def sum_ode(diff_params):
    """Simulate the network and reduce the `solve` output to its mean.

    `diff_params` lives in the unconstrained optimizer space; `tf.forward`
    maps it into the conductance bounds before simulating. Despite the
    function's name, the reduction is a mean, not a sum.
    """
    return jnp.mean(
        solve(
            cells,
            u,
            tf.forward(diff_params),
            stims,
            recs,
            conns,
            t_max=t_max,
            dt=dt,
            solver="stone",
            checkpoint_inds=checkpoint_inds,
        )
    )

In [84]:
# Compiled forward pass: unconstrained parameters -> mean of the simulation.
jitted_sum_ode = jax.jit(sum_ode)

In [87]:
# JAX dispatches work asynchronously, so block on the result before reading
# the clock; otherwise only dispatch time is measured. The first call after
# `jit` also includes tracing/compilation. The parameter transform is done
# before starting the clock so only the simulation is timed.
unconstrained_params = tf.inverse(params)
start_time = time.perf_counter()
result = jax.block_until_ready(jitted_sum_ode(unconstrained_params))
print("Time:  ", time.perf_counter() - start_time)

Time:   0.036626577377319336


In [88]:
# Compiled (mean, gradient w.r.t. the unconstrained parameters) of `sum_ode`.
jitted_grad = jax.jit(jax.value_and_grad(sum_ode))

In [91]:
# Block on both the value and the gradient so the timing covers the actual
# forward + backward computation, not just asynchronous dispatch.
unconstrained_params = tf.inverse(params)
start_time = time.perf_counter()
result = jax.block_until_ready(jitted_grad(unconstrained_params))
print("Time:  ", time.perf_counter() - start_time)
print("Solve", result[0])
# Gradient for each parameter (gNa, gKd, gLeak) at cell 0. Since
# ncomp == 60, the stride of 100 only ever selects compartment 0.
print("Grad", result[1][:, 0, ::100])

Time:   0.16089892387390137
Solve -60.95626
Grad [[ 0.00038098]
 [-0.0042358 ]
 [ 0.00612676]]


# Optimize

In [153]:
# Start the optimization from the same conductances used for the forward and
# gradient runs above (`params`), mapped into the optimizer's unconstrained
# space. Reusing `params` avoids duplicating the conductance constants.
opt_params = tf.inverse(params)

In [154]:
# Adam on the unconstrained parameters.
learning_rate = 1e-2
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(opt_params)

In [156]:
# Adam loop minimizing the mean simulated value returned by `sum_ode`.
# The gradient is bound to `grads`, not `grad`, so the loop does not shadow
# `jax.grad` imported at the top of the notebook.
# NOTE(review): opt_params / opt_state are not reset in this cell, so
# re-running it continues from where the previous run stopped.
num_iter = 10
for iter_ in range(num_iter):
    loss, grads = jitted_grad(opt_params)
    if jnp.any(jnp.isnan(grads)):
        # NaN gradient: skip the Adam step and nudge a single parameter entry
        # so the next iteration evaluates a slightly different point.
        print(f"iter {iter_}, loss {loss}, ============= skipping")
        opt_params = opt_params.at[0, 0, 0].add(1e-3)
    else:
        print(f"iter {iter_}, loss {loss}")
        updates, opt_state = optimizer.update(grads, opt_state)
        opt_params = optax.apply_updates(opt_params, updates)

# Map the optimized parameters back into the constrained conductance box.
final_params = tf.forward(opt_params)

iter 0, loss -65.22089385986328
iter 1, loss -65.22219848632812
iter 2, loss -65.22350311279297
iter 3, loss -65.22479248046875
iter 4, loss -65.22611999511719
iter 5, loss -65.2274169921875
iter 6, loss -65.22872924804688
iter 7, loss -65.23002624511719
iter 8, loss -65.23130798339844
iter 9, loss -65.23260498046875


In [128]:
# Objective evaluated at the current (optimized) unconstrained parameters.
# NOTE(review): the saved output below has execution count 128, older than
# the optimization cells (In[153]-In[156]); re-run to refresh it.
jitted_sum_ode(opt_params)

Array(-61.988945, dtype=float32)