In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".2"

import time
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from copy import deepcopy

import jax
import jax.numpy as jnp
from jax import jit, lax, grad, value_and_grad, vmap
import optax
import distrax

from neuron import h
_ = h.load_file("stdlib.hoc")

from neurax.integrate import solve
from neurax.cell import Cell
from neurax.stimulus import Stimulus, step_current
from neurax.recording import Recording
from neurax.connection import Connection

--No graphics will be displayed.


In [3]:
device_str = "gpu"
jax.config.update('jax_platform_name', device_str)

cpus = jax.devices("cpu")
gpus = jax.devices("gpu")

device = cpus[0] if device_str == "cpu" else gpus[0]

# Setup

### Define model

In [4]:
num_cells = 16
nseg_per_branch = 4
num_branches = 15
ncomp = num_branches * nseg_per_branch
parents = jnp.asarray([-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])

assert len(parents) == num_branches
g_na = jnp.asarray([[0.12] * ncomp] * num_cells)
g_kd = jnp.asarray([[0.036] * ncomp] * num_cells)
g_leak = jnp.asarray([[0.0003] * ncomp] * num_cells)
params = jnp.stack([g_na, g_kd, g_leak])

init_v = jnp.asarray([[-62.0] *num_branches*nseg_per_branch] * num_cells)
init_m = jnp.asarray([[0.074901] *num_branches*nseg_per_branch] * num_cells)
init_h = jnp.asarray([[0.4889] *num_branches*nseg_per_branch] * num_cells)
init_n = jnp.asarray([[0.3644787] *num_branches*nseg_per_branch] * num_cells)
u = jnp.stack([init_v, init_m, init_h, init_n])

length = 100.0  # um (length of a single branch)
radius = 10.0  # um
r_a = 10_000  # ohm cm 
i_delay = 5.0  # ms
i_amp = 1.0  # nA
i_dur = 2.0  # ms
dt = 0.025  # ms
t_max = 15.0  # ms

### Set up model

In [5]:
time_vec = jnp.arange(0, t_max+dt, dt)

In [6]:
cells = [Cell(num_branches, parents, nseg_per_branch, length, radius, r_a)] * num_cells
recs = [
    Recording(15, 0, 0.0),
]
conns = []
for pre in jnp.arange(5):
    for post in jnp.arange(5, 15):
        conn = Connection(pre, 0, 0.0, post, 7+pre, 1.0)
        conns.append(conn)
for pre in jnp.arange(5, 10):
    for post in jnp.arange(15, 16):
        conn = Connection(pre, 0, 0.0, post, 2+pre, 1.0)
        conns.append(conn)
for pre in jnp.arange(10, 15):
    for post in jnp.arange(15, 16):
        conn = Connection(pre, 0, 0.0, post, -3+pre, 0.5)
        conns.append(conn)
        
syn_params = jnp.asarray([1.0] * len(conns))

### Parameter bounds

In [7]:
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

def expit(x):
    return -jnp.log(1 / x - 1)

lower_gna = jnp.asarray([[0.06] * ncomp] * num_cells)
upper_gna = jnp.asarray([[0.18] * ncomp] * num_cells)

lower_gkd = jnp.asarray([[0.02] * ncomp] * num_cells)
upper_gkd = jnp.asarray([[0.05] * ncomp] * num_cells)

lower_gleak = jnp.asarray([[0.0001] * ncomp] * num_cells)
upper_gleak = jnp.asarray([[0.0005] * ncomp] * num_cells)

lower = jnp.stack([lower_gna, lower_gkd, lower_gleak])
upper = jnp.stack([upper_gna, upper_gkd, upper_gleak])

tf = distrax.Lambda(
    forward=lambda x: sigmoid(x) * (upper-lower) + lower, 
    inverse=lambda x: expit((x - lower) / (upper - lower))
)

lower_syn = jnp.asarray([0.0] * len(conns))
upper_syn = jnp.asarray([2.0] * len(conns))

tf_syn = distrax.Lambda(
    forward=lambda x: sigmoid(x) * (upper_syn-lower_syn) + lower_syn, 
    inverse=lambda x: expit((x - lower_syn) / (upper_syn - lower_syn))
)

### ODE

In [8]:
checkpoint_inds = [200, 400, 600]
def ode(diff_params, diff_syn_params, stim_traces):
    # Prepare params.
    diff_params = tf.forward(diff_params)
    diff_syn_params = tf_syn.forward(diff_syn_params)
    
    # Prepare stimuli.
    stims = [
        Stimulus(
            cell_ind=i, 
            branch_ind=0, 
            loc=0.0, 
            current=stim_traces[i]
        ) for i in range(5)
    ]
    
    # Solve ODE.
    s = solve(
        cells,
        u,
        diff_params,
        diff_syn_params,
        stims,
        recs,
        conns,
        t_max=t_max,
        dt=dt,
        solver="stone",
        checkpoint_inds=checkpoint_inds
    )
    
    return s


def loss_fn(diff_params, diff_syn_params, stim_traces, target):
    s = ode(diff_params, diff_syn_params, stim_traces)
    voltage_at_last_time = s[0, -1]
    target_at_last_time = target[0, -1]
    return (voltage_at_last_time - target_at_last_time) ** 2

### Generate data

In [9]:
g_na = jnp.asarray([[0.12] * ncomp] * num_cells)
g_kd = jnp.asarray([[0.036] * ncomp] * num_cells)
g_leak = jnp.asarray([[0.0003] * ncomp] * num_cells)

theta_gt = tf.inverse(jnp.stack([g_na, g_kd, g_leak]))
synapse_gt = tf_syn.inverse(jnp.asarray([1.0] * len(conns)))

In [10]:
n_inputs = 5

n_batches = 5
batchsize = 10
n_train = n_batches * batchsize

stim_intensities = 2 * jnp.asarray(np.random.rand(n_train, n_inputs))

stims = jnp.asarray(
    [
        [
            stim_intensities[j, i] * step_current(i_delay, i_dur, i_amp, time_vec) for i in range(n_inputs)
        ] for j in range(n_train)
    ]
)
stims_in_batches = jnp.reshape(stims, (n_batches, batchsize, n_inputs, -1))

In [11]:
vmaped_ode = vmap(jit(ode), in_axes=(None, None, 0))

In [12]:
targets = []
for batch in stims_in_batches:
    start_time = time.time()
    results = vmaped_ode(
        theta_gt,
        synapse_gt,
        batch, 
    )
    print("Time:  ", time.time() - start_time)
    targets.append(results)
targets_in_batches = jnp.asarray(targets)

Time:   6.056895732879639
Time:   0.10374283790588379
Time:   0.10261225700378418
Time:   0.10201191902160645
Time:   0.1020362377166748


### Compile grad

In [13]:
vmaped_grad = vmap(jit(value_and_grad(loss_fn, argnums=(0, 1))), in_axes=(None, None, 0, 0))

In [14]:
start_time = time.time()
loss, grad_val = vmaped_grad(
    params,
    syn_params,
    stims_in_batches[0],
    targets_in_batches[0],
)
print("Time:  ", time.time() - start_time)
print("loss", loss)

Time:   69.25602912902832
loss [656.444   620.7225  714.29193 637.1032  624.85315 659.56555 639.52386
 662.9869  626.3393  617.9082 ]


In [15]:
loss

Array([656.444  , 620.7225 , 714.29193, 637.1032 , 624.85315, 659.56555,
       639.52386, 662.9869 , 626.3393 , 617.9082 ], dtype=float32)

# Optimize

In [30]:
params_membrane = jnp.asarray(np.random.rand(3, num_cells, ncomp) * (upper - lower) + lower)
params_syn = jnp.asarray(np.random.rand(len(conns)) * (upper_syn - lower_syn) + lower_syn)

In [31]:
theta_membrane = tf.inverse(theta_gt)
theta_syn = tf_syn.inverse(synapse_gt)

theta_membrane = tf.inverse(jnp.stack([g_na, g_kd, g_leak]))
theta_syn = tf_syn.inverse(jnp.asarray([1.0] * len(conns)))

In [32]:
optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init((theta_membrane, theta_syn))

In [33]:
num_iter = 100
for iter_ in range(num_iter):
    for batch_stims, batch_targets in zip(stims_in_batches, targets_in_batches):
        loss, grad = vmaped_grad(
            theta_membrane,
            theta_syn,
            batch_stims,
            batch_targets,
        )
        loss = jnp.mean(loss, axis=0)
        grad_membrane = jnp.mean(grad[0], axis=0)
        grad_synapse = jnp.mean(grad[1], axis=0)
        
        if jnp.any(jnp.isnan(grad_membrane)) or jnp.any(jnp.isnan(grad_synapse)):
            print(f"iter {iter_}, loss {loss}, ============= skipping")
            theta_membrane = theta_membrane.at[0, 0, 0].add(1e-3)
        else:
            print(f"iter {iter_}, loss {loss}")
            updates, opt_state = optimizer.update((grad_membrane, grad_synapse), opt_state)
            theta_membrane, theta_syn = optax.apply_updates((theta_membrane, theta_syn), updates)
    
final_membrane = tf.forward(theta_membrane)
final_syn = tf_syn.forward(theta_syn)

iter 0, loss 1.5951912013534297e-09
iter 0, loss 0.3034658133983612
iter 0, loss 0.7308279871940613
iter 0, loss 0.14372754096984863
iter 1, loss 0.05911582335829735
iter 1, loss 0.19042475521564484
iter 1, loss 0.3617531359195709
iter 1, loss 0.2165220081806183
iter 1, loss 0.09423833340406418
iter 2, loss 0.025481006130576134
iter 2, loss 0.05709538981318474
iter 2, loss 0.11733829975128174
iter 2, loss 0.17675797641277313
iter 2, loss 0.1471520960330963
iter 3, loss 0.022862737998366356
iter 3, loss 0.01046544685959816
iter 3, loss 0.07165174186229706
iter 3, loss 0.06967740505933762
iter 3, loss 0.06961514055728912
iter 4, loss 0.019784068688750267
iter 4, loss 0.0055564227513968945
iter 4, loss 0.04278094694018364
iter 4, loss 0.09801667183637619
iter 5, loss 0.059145908802747726
iter 5, loss 0.02201390638947487
iter 5, loss 0.012631750665605068
iter 5, loss 0.021405000239610672
iter 5, loss 0.07326125353574753
iter 6, loss 0.0796840712428093
iter 6, loss 0.0573418103158474
iter 6

In [128]:
jitted_sum_ode(opt_params)

Array(-61.988945, dtype=float32)