In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".8"

In [200]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad

import jaxley as jx
from jaxley.channels import HH
from jaxley.synapses import IonotropicSynapse, TestSynapse
from jaxley_mech.channels.fm97 import Na, K, KA, KCa, Ca, Leak

import warnings
import networkx as nx
from jaxley.connection import connect
import pandas as pd

In [4]:
dt = 0.025
t_max = 100.0
time_vec = np.arange(0, t_max+2*dt, dt)

i_delay = 10.0
i_dur = 80.0
i_amp = 5.0  # nA
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
np.random.seed(0)

In [5]:
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(4)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 1, 2, 2]))
net = jx.Network([cell]*3)
connect(net[0,0,0], net[1,0,0], IonotropicSynapse())
connect(net[0,0,1], net[1,0,1], IonotropicSynapse())
connect(net[0,0,1], net[1,0,1], TestSynapse())
net.cell(2).add_to_group("cell2")
net.cell(2).branch(1).add_to_group("cell2brach1")


net.cell(0).insert(Na())
net.cell(0).insert(Leak())

net.cell(1).branch(1).insert(Na())
net.cell(0).insert(K())
net.compute_xyz()

net.cell(0).branch(0).loc(0.0).record()
net.cell(0).branch(0).loc(0.0).record("m")
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
net.cell(0).branch(2).loc(0.0).stimulate(current)
net.cell(0).branch(1).make_trainable("Na")
net.cell(1).make_trainable("K")

  self.pointer.edges = pd.concat(


Added 1 recordings. See `.recordings` for details.
Added 1 recordings. See `.recordings` for details.
Added 1 stimuli. See `.currents` for details.
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1
Number of newly added trainable parameters: 1. Total number of trainable parameters: 2


In [305]:
net.compute_xyz()
net.cell(1).move(0,30,0)
net.cell(2).move(0,-30,0)

In [389]:
return_type = "Network".lower()
module_build_cache = {"compartment": [], "branch": [], "cell": [], "network": []}
for cell_id, cell_groups in net.nodes.groupby("cell_index"):
    for branch_id, branch_groups in cell_groups.groupby("branch_index"):
        num_comps = len(branch_groups["comp_index"])
        module_build_cache["compartment"] = [jx.Compartment() for _ in range(num_comps)]
        module_build_cache["branch"].append(jx.Branch(module_build_cache["compartment"]))
    
    parents = np.arange(len(module_build_cache["branch"]))-1 # COMPUTE PARENTS HERE
    module_build_cache["cell"].append(jx.Cell(module_build_cache["branch"], parents))
    module_build_cache["branch"] = []
module_build_cache["network"] = [jx.Network(module_build_cache["cell"])]

module = module_build_cache[return_type][0]

In [410]:
module_graph = net.to_graph()

In [412]:
module = jx.utils.cell_utils.from_graph(module_graph)
module.to_graph()

<networkx.classes.digraph.DiGraph at 0x7f201d996c90>

In [310]:
# plot the graph
pos = {i: (n["x"], n["y"]) for i, n in module_graph.nodes(data=True)} 
plt.figure(figsize=(8, 8))
nx.draw(module_graph, pos, with_labels=True, node_size=200, node_color="skyblue", font_size=8, font_weight="bold", font_color="black", font_family="sans-serif")
plt.show()

TypeError: 'DataFrame' object is not callable

### Ideas:
- add a `groups` property
- show group membership in nodes
- make use to to_graph in plotting
- make use of from graph in swc import
- modules can be saved as graph, no need for pickle necessarily