In [1093]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1094]:
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".8"

In [1095]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad

import jaxley as jx
from jaxley.channels import HH
from jaxley.synapses import IonotropicSynapse
from jaxley_mech.channels.fm97 import Na, K, KA, KCa, Ca, Leak


In [1096]:
dt = 0.025
t_max = 100.0
time_vec = np.arange(0, t_max+2*dt, dt)

i_delay = 10.0
i_dur = 80.0
i_amp = 5.0  # nA
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)

cell_id = "20161028_1"
np.random.seed(0)

time_vec = jnp.arange(0.0, 2*dt, dt)

In [965]:
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(4)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 1, 2, 2]))
net = jx.Network([cell]*2)


net.cell(0).insert(Na())
net.cell(0).insert(Leak())

net.cell(1).branch(1).insert(Na())
net.cell(0).insert(K())
net.compute_xyz()

net.cell(0).branch(0).loc(0.0).record()
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
net.cell(0).branch(2).loc(0.0).stimulate(current)

Added 1 recordings. See `.recordings` for details.
Added 1 stimuli. See `.currents` for details.


In [1068]:
import warnings
import networkx as nx

def to_graph(df):
    G = nx.Graph()
    edges = df[['parent_branch_index', 'child_branch_index']].values
    G.add_edges_from(edges)
    return G

In [44]:
net.branch_edges.isin([0,1,4])

Unnamed: 0,parent_branch_index,child_branch_index
0,True,True
1,True,False
2,False,False
3,False,True
4,False,False
5,False,False
6,False,False
7,False,False


In [115]:
def to_module(view):
    warnings.warn("This function is experimental and may not work as expected.")

    modules = np.array([jx.Network, jx.Cell, jx.Branch, jx.Compartment])
    indices = ["cell_index", "branch_index", "comp_index"]
    num_unique_elements = view.view[indices].nunique().values
    return_type =  modules[np.hstack((num_unique_elements > 1, np.array(True)))][0]

    viewed_indices = {col: np.unique(net[:].view[col]) for col in indices}
    
    where_to_view_edges = view.pointer.branch_edges.isin(viewed_indices["branch_index"])
    viewed_branch_edges = view.pointer.branch_edges.loc[where_to_view_edges.any(axis=1)]
    graph = to_graph(viewed_branch_edges)
    
    assert nx.is_connected(graph), "The branches currently in view are not all connected."
    assert not any(view.view[indices].duplicated()), "View must not contain duplicates."
    
    module_instances = []
    num_unique_elements = np.hstack([1, num_unique_elements])
    for num, module in zip(reversed(num_unique_elements), reversed(modules)):
        args = () 
        if module in [jx.Branch, jx.Network]:
            args = [module_instances]
        elif module == jx.Cell:
            root_node_index = viewed_branch_edges["parent_branch_index"].min()
            levels = list(nx.bfs_layers(graph, root_node_index))
            levels = [[i for i in indices if i in viewed_indices["branch_index"]] for indices in levels]
            parents = sum([[i-1]*len(level) for i, level in enumerate(levels)], [])
            args = [module_instances, parents]
        
        module_instances = [module(*args) for _ in range(num)]
        if module == return_type:
            break

    # TODO: sync module attrs
    module = module_instances[0]
    indices_of_viewed_elements = view.view.index
    module.nodes = view.pointer.nodes.loc[indices_of_viewed_elements]
    module.nodes.reset_index(drop=True, inplace=True)
    
    # drop columns with all nan or all False
    module.nodes = module.nodes.dropna(axis=1)
    module.nodes = module.nodes.loc[:, (module.nodes != 0).any()]
    module.xyzr = np.array(view.pointer.xyzr)[viewed_indices["branch_index"]].tolist()
    edge_indices = set(view.view.index.tolist()).intersection(set(view.pointer.edges.index.tolist()))
    module.edges = view.pointer.edges.loc[list(edge_indices)]
    # module.xyzr = view.xyzr # <- needs to be implemented!    

    module.nbranches_per_cell = view.pointer.nbranches_per_cell[viewed_indices["cell_index"]]
    module.total_nbranches = sum(module.nbranches_per_cell)
    module.cumsum_nbranches = np.cumsum([0] + module.nbranches_per_cell.tolist())
    module.comb_parents = view.pointer.comb_parents[viewed_indices["branch_index"]]
    # module.comb_branches_in_each_level 
    # module.initialized_morph
    # module.initialized_syns
    # module.synapses
    # module.synapse_param_names
    # module.synapse_state_names
    # module.synapse_names
    # module.channels
    # module.membrane_current_names
    # module.indices_set_by_trainables
    # module.trainable_params
    # module.allow_make_trainable
    # module.num_trainable_params
    # module.recordings
    # module.currents
    # module.current_inds
    # module.cells
    # module.branch_edges
    # module.initialized_conds
    return module

In [117]:
to_module(net[0,0,0])



TypeError: only integer scalar arrays can be converted to a scalar index

In [12]:
# def get_connections(net):
#     def get_global_comp_indices(loc):
#         get_cols = lambda loc: [f"{loc}_locs", f"{loc}_branch_index", f"{loc}_cell_index"]
#         cols = get_cols(loc)
#         edges = net.edges[cols]
#         locs = edges[cols[1]].values.astype(float)
#         branch_indices, cell_indices = edges[cols[1:]].values.T.astype(int)
#         comp_ind_from_loc = lambda x: index_of_loc(np.zeros_like(x),x, net.nseg)
#         comp_indices = np.array(list(map(comp_ind_from_loc, locs)))
#         global_comp_indices = net._local_inds_to_global(cell_indices, branch_indices, comp_indices)
#         return global_comp_indices

#     global_pre_comp_indices = get_global_comp_indices("pre")
#     global_post_comp_indices = get_global_comp_indices("post")

#     syn_types = net.edges["type"].values.astype(str)
#     syn_ids = [net.cell._infer_synapse_type_ind(syn_type)[0] for syn_type in syn_types]
#     synapses = [net.synapses[syn_id] for syn_id in syn_ids]
#     return global_pre_comp_indices, global_post_comp_indices, synapses

# pre, post, syns = get_connections(net)
# # pre_view = net_copy[:].view.loc[pre_global_comp_indices]
# # post_view = net_copy[:].view.loc[post_global_comp_indices]

In [None]:
cell = jx.read_swc(f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc", nseg=4, max_branch_len=300.0, min_radius=5.0)
cell.compute_xyz()
cell._update_nodes_with_xyz()
cell.nodes

In [None]:
# comp = jx.Compartment()
# comp.compute_xyz()
# comp.update_nodes_with_xyz()

# print(comp.nodes[["x", "y", "z"]])
# print(comp.xyzr)

# cell = jx.read_swc(f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc", nseg=4, max_branch_len=300.0, min_radius=5.0)

# misc ideas and Qs:
- merge `syn_edges` and `branch_edges` -> edges and replace branch_edge by "type=branch" label, why does base have `edges` and comp has `syn_edges` and `branch_edges`?
- connect should only have to update list of edges!
- what is the reasoning of having comp be a float. I find this counterintuitive since it's discrete. If only important for plotting, then we should only have this in plotting too.

In [None]:
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(4)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))
network = jx.Network([cell for _ in range(5)])

network.cell(0).connect(network.cell(1), GlutamateSynapse)
network.cell(1).connect(network.cell(2), GlutamateSynapse)


In [None]:
type(GlutamateSynapse())

In [None]:
network.cell(0)

In [None]:
network.edges

In [None]:
# "20170610_1" is a t-off-mini
# "20161028_1" is a t-off-alpha
cell_id = "20161028_1"

if cell_id == "20161028_1":
    soma_branch = 1
    dendrite_branch = 50
elif cell_id == "20170610_1":
    soma_branch = 0
    dendrite_branch = 70
else:
    raise ValueError


cell = jx.read_swc(f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc", nseg=4, max_branch_len=300.0, min_radius=5.0)

cell.insert(HH())
cell.insert(KA())
cell.insert(Ca())
cell.insert(KCa())

cell.set("v", -65.0)
cell.init_states()

cell.delete_trainables()

cell.delete_stimuli()
cell.delete_recordings()

cell.branch(soma_branch).loc(0.4).stimulate(current)
cell.branch(soma_branch).loc(0.4).record()
cell.branch(dendrite_branch).loc(1.0).record()

In [None]:
dims=[0,1]
soma = cell.xyzr[1][0]
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
dmin, dmax = 0, 182
for i in np.unique(cell.show()["branch_index"]):
    xyzr = cell.xyzr[i]
    d = np.sqrt(np.sum((xyzr[:, dims] - soma[dims])**2, axis=1)).mean()
    c = np.array(plt.cm.viridis((d-dmin)/(dmax-dmin)))
    cell.branch(i).vis(col=c, ax=ax, dims=dims)
plt.show()

In [None]:
cell.branch(dendrite_branch).show()

In [188]:
# global_post_indices = post_cell_view.view.groupby("cell_index").sample(num_pre, replace=True).index.to_numpy()        
# global_post_indices = global_post_indices.reshape((-1, num_pre), order="F").ravel()

274 ms ± 5.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
cell = jx.read_swc(f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc", nseg=4, max_branch_len=300.0, min_radius=5.0)
cell.compute_xyz()
cell._update_nodes_with_xyz()
cell.nodes

In [None]:
# comp = jx.Compartment()
# comp.compute_xyz()
# comp.update_nodes_with_xyz()

# print(comp.nodes[["x", "y", "z"]])
# print(comp.xyzr)

# cell = jx.read_swc(f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc", nseg=4, max_branch_len=300.0, min_radius=5.0)

# misc ideas and Qs:
- merge `syn_edges` and `branch_edges` -> edges and replace branch_edge by "type=branch" label, why does base have `edges` and comp has `syn_edges` and `branch_edges`?
- connect should only have to update list of edges!
- what is the reasoning of having comp be a float. I find this counterintuitive since it's discrete. If only important for plotting, then we should only have this in plotting too.

In [None]:
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(4)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))
network = jx.Network([cell for _ in range(5)])

network.cell(0).connect(network.cell(1), GlutamateSynapse)
network.cell(1).connect(network.cell(2), GlutamateSynapse)


In [None]:
type(GlutamateSynapse())

In [None]:
network.cell(0)

In [None]:
network.edges

In [None]:
# "20170610_1" is a t-off-mini
# "20161028_1" is a t-off-alpha
cell_id = "20161028_1"

if cell_id == "20161028_1":
    soma_branch = 1
    dendrite_branch = 50
elif cell_id == "20170610_1":
    soma_branch = 0
    dendrite_branch = 70
else:
    raise ValueError


cell = jx.read_swc(f"../../jaxley_experiments/nex/rgc/morphologies/{cell_id}.swc", nseg=4, max_branch_len=300.0, min_radius=5.0)

cell.insert(HH())
cell.insert(KA())
cell.insert(Ca())
cell.insert(KCa())

cell.set("v", -65.0)
cell.init_states()

cell.delete_trainables()

cell.delete_stimuli()
cell.delete_recordings()

cell.branch(soma_branch).loc(0.4).stimulate(current)
cell.branch(soma_branch).loc(0.4).record()
cell.branch(dendrite_branch).loc(1.0).record()

In [None]:
dims=[0,1]
soma = cell.xyzr[1][0]
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
dmin, dmax = 0, 182
for i in np.unique(cell.show()["branch_index"]):
    xyzr = cell.xyzr[i]
    d = np.sqrt(np.sum((xyzr[:, dims] - soma[dims])**2, axis=1)).mean()
    c = np.array(plt.cm.viridis((d-dmin)/(dmax-dmin)))
    cell.branch(i).vis(col=c, ax=ax, dims=dims)
plt.show()

In [None]:
cell.branch(dendrite_branch).show()