In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jax import config
config.update("jax_enable_x64", False)
device = "gpu"
config.update("jax_platform_name", device)

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".8"

import time
import pickle

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

import jax.numpy as jnp
from jax import jit, value_and_grad, vmap
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

import jaxley as jx
from jaxley.channels import HH
from jaxley.utils.colors import network_cols

2024-05-27 09:17:34.190681: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-27 09:17:34.190772: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-27 09:17:34.192158: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
from neuron import h

_ = h.load_file("stdlib.hoc")
_ = h.load_file("import3d.hoc")

In [4]:
nseg_per_branch = 4

i_delay = 3.0  # ms
i_amp = 0.5  # nA
i_dur = 2.0  # ms
dt = 0.025  # ms
t_max = 20.0  # ms

### Panel A

In [5]:
cell = jx.Compartment()

2024-05-27 09:17:38.946559: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


### Panel B

In [6]:
cell.insert(HH())

current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
cell.stimulate(current)

cell.record()

cell.set("axial_resistivity", 1_000.0)
cell.set("v", -62.0)
cell.set("HH_m", 0.074901)
cell.set("HH_h", 0.4889)
cell.set("HH_n", 0.3644787)

cell.make_trainable("HH_gNa")
cell.make_trainable("HH_gLeak")
cell.make_trainable("radius")

Added 1 stimuli. See `.currents` for details.
Added 1 recordings. See `.recordings` for details.
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1
Number of newly added trainable parameters: 1. Total number of trainable parameters: 2
Number of newly added trainable parameters: 1. Total number of trainable parameters: 3


In [7]:
parameters = cell.get_parameters()

In [8]:
def simulate_jaxley():
    return jx.integrate(cell, delta_t=dt)

jitted_sim = jit(simulate_jaxley)
voltages_jaxley = jitted_sim()

NEURON

In [9]:
for sec in h.allsec():
    h.delete_section(sec=sec)

soma = h.Section(name='soma')
soma.nseg = 1

stim = h.IClamp(soma(0.1))
stim.delay = i_delay
stim.dur = i_dur
stim.amp = i_amp

counter = 0
voltage_recs = {}

v = h.Vector()
v.record(soma(0.05)._ref_v)

soma.insert("hh")
soma.Ra = 1_000.0

soma.gnabar_hh = 0.120  # S/cm2
soma.gkbar_hh = 0.036  # S/cm2
soma.gl_hh = 0.0003  # S/cm2
soma.ena = 50  # mV
soma.ek = -77.0  # mV
soma.el_hh = -54.3  # mV

h.dt = dt
tstop = t_max
v_init = -62.0

def initialize():
    h.finitialize(v_init)
    h.fcurrent()

def integrate():
    while h.t < tstop:
        h.fadvance()

initialize()
integrate()
voltages_neuron = np.asarray([voltage_recs[key] for key in voltage_recs])

In [10]:
# time_vec = jnp.arange(0, t_max+2*dt, dt)

# with open("../results/01_accuracy/time_vec.pkl", "wb") as handle:
#     pickle.dump(time_vec, handle)

# with open("../results/01_accuracy/voltages_neuron.pkl", "wb") as handle:
#     pickle.dump(voltages_neuron, handle)

# with open("../results/01_accuracy/voltages_jaxley.pkl", "wb") as handle:
#     pickle.dump(voltages_jaxley, handle)

# Simulation

In [11]:
batch_sizes = [1, 10, 100, 1_000, 10_000, 100_000, 1_000_000] # , 1_000_000]

In [12]:
computer = "vm"

### Data

In [13]:
def simulate_jaxley_vmappable(current):
    current = jx.step_current(0.1, 5.0, 1e-5 * current, dt, t_max)
    data_stimuli = cell.data_stimulate(current, None)
    return jx.integrate(cell, delta_t=dt, t_max=t_max, data_stimuli=data_stimuli)[:, -1]

jitted_vmapped_sim = jit(vmap(simulate_jaxley_vmappable, in_axes=[0,]))

# Compile.
time_jaxley_gpu = {}

for batch_size in batch_sizes:
    _ = jitted_vmapped_sim(jnp.arange(batch_size))

    start_time = time.time()
    voltages_batch_100 = jitted_vmapped_sim(jnp.arange(batch_size))
    time_jaxley_gpu[batch_size] = time.time() - start_time

In [14]:
with open(f"../results/03_timing/currents_jaxley_{device}_{computer}_nseg{nseg_per_branch}_pointneuron.pkl", "wb") as handle:
    pickle.dump(time_jaxley_gpu, handle)

In [15]:
time_jaxley_gpu

{1: 0.0119476318359375,
 10: 0.012282371520996094,
 100: 0.011378049850463867,
 1000: 0.013270139694213867,
 10000: 0.014665603637695312,
 100000: 0.017347097396850586,
 1000000: 0.06923913955688477}

In [16]:
# Evaluate NEURON.
start_time = time.time()
initialize()
integrate()
time_neuron_once = (time.time() - start_time)

time_neuron = {}
for batch_size in batch_sizes:
    time_neuron[batch_size] = time_neuron_once * batch_size

In [17]:
with open(f"../results/03_timing/currents_neuron_{computer}_nseg{nseg_per_branch}_pointneuron.pkl", "wb") as handle:
    pickle.dump(time_neuron, handle)

### Parameters

In [18]:
def simulate_jaxley_vmappable(parameter):
    pstate = cell.data_set("radius", 1e-5 * parameter, None)
    return jx.integrate(cell, delta_t=dt, t_max=t_max, param_state=pstate)[:, -1]

jitted_vmapped_sim = jit(vmap(simulate_jaxley_vmappable, in_axes=[0,]))

In [19]:
batch_sizes

[1, 10, 100, 1000, 10000, 100000, 1000000]

In [20]:
# Compile.
time_jaxley_gpu = {}
for batch_size in batch_sizes:
    _ = jitted_vmapped_sim(jnp.arange(batch_size))
    start_time = time.time()
    voltages_batch_100 = jitted_vmapped_sim(jnp.arange(batch_size))
    time_jaxley_gpu[batch_size] = time.time() - start_time

In [21]:
time_jaxley_gpu

{1: 0.004081010818481445,
 10: 0.002918243408203125,
 100: 0.006276369094848633,
 1000: 0.007804155349731445,
 10000: 0.008698701858520508,
 100000: 0.011684417724609375,
 1000000: 0.04788374900817871}

In [22]:
with open(f"../results/03_timing/parameters_jaxley_{device}_{computer}_nseg{nseg_per_branch}_pointneuron.pkl", "wb") as handle:
    pickle.dump(time_jaxley_gpu, handle)

In [23]:
# Evaluate NEURON.
start_time = time.time()
initialize()
integrate()
time_neuron_once = time.time() - start_time

time_neuron = {}
for batch_size in batch_sizes:
    time_neuron[batch_size] = time_neuron_once * batch_size

In [24]:
with open(f"../results/03_timing/parameters_neuron_{computer}_nseg{nseg_per_branch}_pointneuron.pkl", "wb") as handle:
    pickle.dump(time_neuron, handle)

# Gradient

In [25]:
batch_sizes = [1, 10, 100, 1_000, 10_000, 100_000, 1_000_000]

### Data

In [26]:
def simulate_jaxley_vmappable(param, current):
    current = jx.step_current(0.1, 5.0, 1e-5 * current, dt, t_max)
    data_stimuli = cell.data_stimulate(current, None)
    pstate = cell.data_set("radius", 1e-5 * param, None)
    return jx.integrate(cell, delta_t=dt, t_max=t_max, data_stimuli=data_stimuli, param_state=pstate, checkpoint_lengths=[10, 9, 9])[0, -1]

vmapped_sim = vmap(simulate_jaxley_vmappable, in_axes=(None, 0))

def loss(param, currents):
    return jnp.mean(vmapped_sim(param, currents))

grad_fn = jit(value_and_grad(loss, argnums=(0,)))

In [27]:
# Compile.
time_jaxley_gpu = {}
for batch_size in batch_sizes:
    _ = grad_fn(jnp.ones(1,), jnp.arange(batch_size).astype(float))
    start_time = time.time()
    _ = grad_fn(jnp.ones(1,), jnp.arange(batch_size).astype(float))
    time_jaxley_gpu[batch_size] = time.time() - start_time

In [28]:
time_jaxley_gpu

{1: 0.001024007797241211,
 10: 0.0043277740478515625,
 100: 0.020061731338500977,
 1000: 0.23946332931518555,
 10000: 1.314502239227295,
 100000: 13.305858850479126,
 1000000: 110.9902594089508}

In [31]:
with open(f"../results/03_timing/gradient_currents_jaxley_{device}_{computer}_nseg{nseg_per_branch}_pointneuron.pkl", "wb") as handle:
    pickle.dump(time_jaxley_gpu, handle)

### Parameters

In [32]:
def simulate_jaxley_vmappable(parameter):
    pstate = cell.data_set("radius", 1e-5 * parameter, None)
    return jx.integrate(cell, delta_t=dt, t_max=t_max, param_state=pstate, checkpoint_lengths=[10, 9, 9])[0, -1]

vmapped_sim = vmap(simulate_jaxley_vmappable, in_axes=(0,))

def loss(parameters):
    return jnp.mean(vmapped_sim(parameters))

grad_fn = jit(value_and_grad(loss))

In [33]:
# Compile.
time_jaxley_gpu = {}
for batch_size in batch_sizes:
    _ = grad_fn(jnp.arange(batch_size).astype(float))
    start_time = time.time()
    _ = grad_fn(jnp.arange(batch_size).astype(float))
    time_jaxley_gpu[batch_size] = time.time() - start_time

In [34]:
time_jaxley_gpu

{1: 0.0007579326629638672,
 10: 0.003962993621826172,
 100: 0.01786661148071289,
 1000: 0.19870233535766602,
 10000: 1.0762073993682861,
 100000: 10.718116760253906,
 1000000: 91.69111728668213}

In [35]:
with open(f"../results/03_timing/gradient_parameters_jaxley_{device}_{computer}_nseg{nseg_per_branch}_pointneuron.pkl", "wb") as handle:
    pickle.dump(time_jaxley_gpu, handle)