In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Set XLA-related environment variables before `jax` is imported, so they are
# already in place whenever the backend gets initialised.
# NOTE(review): with the CPU platform forced below, this memory-fraction setting
# presumably has no effect — confirm whether it is still needed.
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

from jax import config

# I have experienced stability issues with float32.
config.update("jax_enable_x64", True)
# Force the CPU platform.
config.update("jax_platform_name", "cpu")

In [4]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

import jaxley as jx
from jaxley.channels import HH
# from jaxley.synapses import GlutamateSynapse
from jaxley_mech.channels.fm97 import Na, K, KA, KCa, Ca, Leak


In [5]:
# Simulation clock: step size and total duration.
dt, t_max = 0.025, 100.0
# Time axis from 0 up to (but excluding) t_max + 2*dt, in steps of dt.
time_vec = np.arange(0, t_max + 2 * dt, dt)

# Step-current stimulus: onset delay, duration and amplitude.
i_delay, i_dur = 10.0, 80.0
i_amp = 5.0  # nA
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)

# Fixed seed for NumPy's global RNG, and the morphology identifier used below.
np.random.seed(0)
cell_id = "20161028_1"

In [125]:
from jaxley.utils.cell_utils import index_of_loc, loc_of_index

In [112]:
# One default compartment, referenced four times to form a single branch.
comp = jx.Compartment()
branch = jx.Branch([comp] * 4)

In [130]:
# Probe `loc_of_index` for index 0; the second argument is presumably the number
# of compartments per branch — TODO confirm.
loc_of_index(0, 4)

0.875

In [133]:
# Sanity check: addressing the branch via `index_of_loc(0, 0.4, 4)` must yield
# the same table as addressing it with the location 0.4 directly.
assert np.all(
    branch.loc(index_of_loc(0, 0.4, 4)).show()
    == branch.loc(0.4).show()
)


In [None]:
# Build the hierarchy bottom-up: compartment -> branch -> cell -> network.
comp = jx.Compartment()
branch = jx.Branch([comp] * 4)
# Three branches; `parents` presumably marks branch 0 as the root (-1) with
# branches 1 and 2 attached to it — TODO confirm.
cell = jx.Cell([branch] * 3, parents=jnp.asarray([-1, 0, 0]))
network = jx.Network([cell] * 5)
network.compute_xyz()

In [None]:
# Stimulate `cell[0]` first with the unstacked `current`, then with stacks of
# 1, 4 and 3 copies of it, in that order.
cell[0].stimulate(current)
for n_rows in (1, 4, 3):
    cell[0].stimulate(jnp.stack([current] * n_rows))

In [None]:
# Display the cell's `currents` and `current_inds` attributes together as a tuple.
cell.currents, cell.current_inds

In [None]:
# For each level of the hierarchy: compute coordinates, call the private
# `_update_nodes_with_xyz` helper (presumably copies them into `nodes` — TODO
# confirm), and show the `nodes` table.
# A bare `obj.nodes` expression in the middle of a cell is evaluated but never
# rendered (only a cell's last expression is), so `display` is called explicitly.
for module in (comp, branch, cell):
    module.compute_xyz()
    module._update_nodes_with_xyz()
    display(module.nodes)

cell.vis()

# network.compute_xyz()
# network._update_nodes_with_xyz()
# network.nodes

In [None]:
# Show only the compartment/branch index columns and the x/y columns of `cell.nodes`.
cell.nodes[["comp_index", "branch_index", "x", "y"]]

In [None]:
# Load the morphology for `cell_id` from its SWC file, compute coordinates,
# and show the resulting `nodes` table.
cell = jx.read_swc(
    f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc",
    nseg=4,
    max_branch_len=300.0,
    min_radius=5.0,
)
cell.compute_xyz()
cell._update_nodes_with_xyz()
cell.nodes

In [None]:
# NOTE(review): everything in this cell is commented out, so running it does nothing.
# The call below is spelled `update_nodes_with_xyz`, whereas the other cells in
# this notebook call `_update_nodes_with_xyz` (leading underscore) — confirm
# which name is current before re-enabling.
# comp = jx.Compartment()
# comp.compute_xyz()
# comp.update_nodes_with_xyz()

# print(comp.nodes[["x", "y", "z"]])
# print(comp.xyzr)

# cell = jx.read_swc(f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc", nseg=4, max_branch_len=300.0, min_radius=5.0)

# misc ideas and Qs:
- Merge `syn_edges` and `branch_edges` into a single `edges` table, and replace `branch_edges` with a `type="branch"` label. Why does base have `edges` while comp has `syn_edges` and `branch_edges`?
- `connect` should only have to update the list of edges!
- What is the reasoning behind having comp be a float? I find this counterintuitive, since it is discrete. If the float is only important for plotting, then it should only exist in the plotting code.

In [None]:
# `GlutamateSynapse` is used below, but its import in the notebook's import cell
# is commented out, so this cell raised a NameError on a fresh kernel. Import it
# here so the cell is self-contained.
# NOTE(review): assumes the installed jaxley still exports `GlutamateSynapse`
# from `jaxley.synapses` — confirm.
from jaxley.synapses import GlutamateSynapse

# Five-branch cell, replicated five times into a network.
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(4)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))
network = jx.Network([cell for _ in range(5)])

# Connect cell 0 -> cell 1 and cell 1 -> cell 2.
# NOTE(review): the synapse class itself is passed rather than an instance —
# confirm that this is what `connect` expects.
network.cell(0).connect(network.cell(1), GlutamateSynapse)
network.cell(1).connect(network.cell(2), GlutamateSynapse)


In [None]:
# Inspect the type of a freshly constructed `GlutamateSynapse` instance.
type(GlutamateSynapse())

In [None]:
# Display whatever `network.cell(0)` returns for the first cell of the network.
network.cell(0)

In [None]:
# Display the network's `edges` attribute.
network.edges

In [None]:
# "20170610_1" is a t-off-mini
# "20161028_1" is a t-off-alpha
cell_id = "20161028_1"

# Branch indices used for stimulation/recording, per supported morphology:
# cell_id -> (soma_branch, dendrite_branch).
branch_indices = {
    "20161028_1": (1, 50),
    "20170610_1": (0, 70),
}
if cell_id not in branch_indices:
    # Same exception type as before, but now says what went wrong.
    raise ValueError(
        f"Unknown cell_id {cell_id!r}; expected one of {sorted(branch_indices)}."
    )
soma_branch, dendrite_branch = branch_indices[cell_id]


cell = jx.read_swc(
    f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc",
    nseg=4,
    max_branch_len=300.0,
    min_radius=5.0,
)

# Insert the channel mechanisms into the whole cell.
cell.insert(HH())
cell.insert(KA())
cell.insert(Ca())
cell.insert(KCa())

# Set "v" to -65.0 everywhere, then (re)initialise the states.
cell.set("v", -65.0)
cell.init_states()

# Clear any trainables, stimuli and recordings before attaching new ones.
cell.delete_trainables()

cell.delete_stimuli()
cell.delete_recordings()

# Stimulate and record at location 0.4 of the soma branch; additionally record
# at location 1.0 of the dendrite branch.
cell.branch(soma_branch).loc(0.4).stimulate(current)
cell.branch(soma_branch).loc(0.4).record()
cell.branch(dendrite_branch).loc(1.0).record()

In [None]:
# Draw every branch, coloured by its mean distance (in the `dims` plane) from a
# reference point on the soma branch.
dims = [0, 1]
# Reference point: first row of the soma branch's `xyzr` array. Uses
# `soma_branch` instead of a hard-coded 1 so the plot stays correct for
# morphologies whose soma is on a different branch.
soma = cell.xyzr[soma_branch][0]
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# Limits used to normalise distances onto the colormap.
# NOTE(review): 182 looks hand-tuned for one morphology — confirm.
dmin, dmax = 0, 182
for i in np.unique(cell.show()["branch_index"]):
    xyzr = cell.xyzr[i]
    # Mean Euclidean distance of this branch's points from the reference point.
    d = np.sqrt(np.sum((xyzr[:, dims] - soma[dims]) ** 2, axis=1)).mean()
    c = np.array(plt.cm.viridis((d - dmin) / (dmax - dmin)))
    cell.branch(i).vis(col=c, ax=ax, dims=dims)
plt.show()

In [None]:
# Display what `.show()` returns for the dendrite branch selected above.
cell.branch(dendrite_branch).show()