In [None]:
# %load_ext autoreload
# %autoreload 2

In [1]:
import os

# Set the XLA client memory fraction before importing JAX; the JAX docs describe
# XLA_PYTHON_CLIENT_* variables as settings that should be in place before JAX starts.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".8"

from jax import config

# Enable 64-bit dtypes and pin execution to the CPU backend.
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

In [2]:
from __future__ import annotations
import jax
import jax.numpy as jnp
import equinox as eqx
import pandas as pd

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
# from graphjax.graph import MorphTree

from typing import Optional, List, Dict, Any, Union, Set, Tuple
from dataclasses import dataclass, field

In [1257]:
def swc_to_morph_tree(fname: str, num_lines: Optional[int] = None) -> MorphTree:
    """Read an SWC file and build a ``MorphTree`` from it.

    Each row is read as ``index, type, x, y, z, r, parent``; the type column is
    stored under the node attribute key ``"id"``.

    Args:
        fname: path of the SWC file, parsed with ``np.loadtxt``.
        num_lines: keep only the first ``num_lines`` rows; ``None`` keeps all rows.

    Returns:
        A ``MorphTree`` with one node per row (labelled by the index column) and one
        ``(index, parent)`` edge for every row whose parent is not -1.
    """
    table = jnp.array(np.loadtxt(fname)[:num_lines])
    node_ids = table[:, 0]
    parent_ids = table[:, -1]

    # (child, parent) pairs; rows whose parent is -1 are roots and contribute no edge.
    pairs = jnp.stack([node_ids, parent_ids], axis=1)
    edges = pairs[parent_ids != -1].astype(int)

    nodes = node_ids.astype(int)

    # Middle columns are (type, x, y, z, r); only the type is cast to a Python int.
    type_col, x_col, y_col, z_col, r_col = table[:, 1:-1].T
    node_attrs = [
        {"id": int(t), "x": x, "y": y, "z": z, "r": r}
        for t, x, y, z, r in zip(type_col, x_col, y_col, z_col, r_col)
    ]
    return MorphTree(nodes, edges, node_attrs)

@dataclass
class MorphTree:
    """Graph container for neuron morphologies backed by JAX arrays.

    Attributes are stored positionally: ``node_attrs[k]`` belongs to ``nodes[k]``
    and ``edge_attrs[k]`` belongs to ``edges[k]``. Conversion helpers to and from
    ``networkx`` and ``pandas`` are provided.
    """

    # Node labels, one entry per node.
    nodes: jnp.ndarray
    # Edges as rows of (source_label, target_label).
    edges: jnp.ndarray
    # One attribute dict per node / edge; ``None`` means "no attributes yet".
    node_attrs: Optional[List[Dict[str, Any]]] = None
    edge_attrs: Optional[List[Dict[str, Any]]] = None
    # Attributes that apply to the tree as a whole.
    global_attrs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # One *independent* dict per element: ``[{}] * n`` would put the same dict
        # object in every slot, so updating one node's attributes in place (e.g.
        # ``node_attrs[i].update(...)``) would silently change every node.
        if self.node_attrs is None:
            self.node_attrs = [{} for _ in range(len(self.nodes))]
        if self.edge_attrs is None:
            self.edge_attrs = [{} for _ in range(len(self.edges))]

    def __repr__(self) -> str:
        n_nodes = len(self.nodes)
        n_edges = len(self.edges)

        # Attribute keys are read from the first node / edge only.
        node_attrs = list(self.node_attrs[0].keys()) if len(self.node_attrs) > 0 else []
        edge_attrs = list(self.edge_attrs[0].keys()) if len(self.edge_attrs) > 0 else []
        return f"MorphTree(nodes={n_nodes}*{node_attrs}, edges={n_edges}*{edge_attrs})"

    def node(self, i: int) -> Dict[str, Any]:
        """Return the attribute dict of the node labelled ``i``.

        Raises:
            ValueError: if no node carries label ``i``.
        """
        node_idx = jnp.where(self.nodes == i)[0]
        if len(node_idx) > 0:
            return self.node_attrs[node_idx[0]]
        raise ValueError(f"Node ({i}) does not exist.")

    def edge(self, i: int, j: int) -> Dict[str, Any]:
        """Return the attribute dict of the directed edge ``(i, j)``.

        Raises:
            ValueError: if the tree has no edge ``(i, j)``.
        """
        edge_idx = jnp.where(jnp.all(self.edges == jnp.array([i, j]), axis=1))[0]
        if len(edge_idx) > 0:
            return self.edge_attrs[edge_idx[0]]
        raise ValueError(f"Edge ({i}, {j}) does not exist.")

    def to_nx(self) -> nx.DiGraph:
        """Convert to a ``networkx.DiGraph`` carrying the node and edge attributes."""
        G = nx.DiGraph()
        node_map = map(int, self.nodes)
        # Transposing gives one row of sources and one row of targets.
        edge_map = map(lambda x: tuple(map(int, x)), self.edges.T)
        G.add_nodes_from(zip(node_map, self.node_attrs))
        G.add_edges_from(zip(*edge_map, self.edge_attrs))
        return G

    @staticmethod
    def from_nx(G: nx.DiGraph) -> MorphTree:
        """Build a ``MorphTree`` from a networkx graph, keeping its attribute dicts."""
        nodes = jnp.array(list(G.nodes))
        edges = jnp.array(list(G.edges))
        node_attrs = [G.nodes[n] for n in G.nodes]
        edge_attrs = [G.edges[e] for e in G.edges]
        return MorphTree(nodes, edges, node_attrs, edge_attrs)

    def to_pandas(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Return ``(node_df, edge_df)``.

        ``node_df`` is indexed by node label; ``edge_df`` by a two-level
        (source, target) ``MultiIndex``.
        """
        node_df = pd.DataFrame(self.node_attrs, index=np.array(self.nodes))
        edge_df = pd.DataFrame(self.edge_attrs)
        edge_index = pd.MultiIndex.from_arrays(np.array(self.edges).T)
        edge_df = edge_df.set_index(edge_index)
        return node_df, edge_df

    @staticmethod
    def from_pandas(node_df: pd.DataFrame, edge_df: pd.DataFrame) -> MorphTree:
        """Inverse of ``to_pandas``: rebuild a tree from node and edge tables.

        ``edge_df`` must be indexed by (source, target) pairs, as produced by
        ``to_pandas``.
        """
        node_attrs = node_df.to_dict(orient="records")
        edge_attrs = edge_df.to_dict(orient="records")
        # ``MultiIndex.to_numpy()`` yields a 1-D object array of tuples; build a
        # proper (num_edges, 2) array instead (also (0, 2) for an empty table).
        edges = np.asarray(edge_df.index.to_list()).reshape(-1, 2)
        nodes = node_df.index.to_numpy()
        return MorphTree(nodes, edges, node_attrs, edge_attrs)

    def plot(self, **kwargs: Any) -> None:
        """Draw the tree with networkx, using node ``x``/``y`` attributes as positions."""
        G = self.to_nx()
        pos = {}
        for n, attr in G.nodes(data=True):
            if "x" in attr:  # assume y is also present
                pos[n] = (attr["x"], attr["y"])
        nx.draw(G, pos, with_labels=True, **kwargs)
        plt.show()

    def reindex_nodes(self, new_indices: jnp.ndarray) -> MorphTree:
        """Relabel nodes so that label ``self.nodes[k]`` becomes ``new_indices[k]``.

        Edge endpoints are relabelled the same way; endpoints that are not found in
        ``self.nodes`` keep their label. Attribute lists are reused unchanged.
        """
        def remap(x: int) -> int:
            idx = jnp.argmax(self.nodes == x)
            exists = self.nodes[idx] == x
            return jnp.where(exists, new_indices[idx], x)

        new_nodes = vmap(remap)(self.nodes.ravel()).reshape(self.nodes.shape)
        new_edges = vmap(remap)(self.edges.ravel()).reshape(self.edges.shape)
        return MorphTree(new_nodes, new_edges, self.node_attrs, self.edge_attrs)

    def reorder_tree(self, new_order: jnp.ndarray) -> MorphTree:
        """Relabel nodes with ``self.nodes[new_order]`` and orient every edge so that
        its source precedes its target in the resulting node order.
        """
        # TODO: check this does what I think it does, i.e. change the edge orientation
        # in order of appearance of the nodes in self.nodes.
        new_nodes = self.nodes[new_order]
        new_tree = self.reindex_nodes(new_nodes)

        new_edges = new_tree.edges
        # Position of each edge endpoint within ``new_tree.nodes``, one value per
        # edge: compare as (num_edges, num_nodes) and reduce over the node axis so
        # the result stays aligned with the edge order.
        pos_i = jnp.argmax(new_edges[:, 0, None] == new_tree.nodes[None, :], axis=1)
        pos_j = jnp.argmax(new_edges[:, 1, None] == new_tree.nodes[None, :], axis=1)
        is_descending = (pos_i >= pos_j)[:, None]
        ascending_edges = jnp.where(is_descending, new_edges[:, ::-1], new_edges)
        return MorphTree(new_nodes, ascending_edges, new_tree.node_attrs, new_tree.edge_attrs)

def list_branches(tree: MorphTree, return_branchpoints: bool = False) -> Union[List[List[int]], Tuple[List[List[int]], Set[int], List[Tuple[int, int]]]]:
    """Split a morphology into unbranched paths ("branches").

    The tree is converted to an undirected networkx graph and its nodes are visited
    in depth-first order starting from a tip. Every branchpoint or tip starts one
    branch per not-yet-walked neighbor. A branch is extended through degree-2 nodes
    until it reaches a branchpoint, a tip, or a degree-2 node whose ``"id"`` differs
    from its second neighbor.

    If exactly one node has ``"id" == 1`` (a single-point soma), that node is emitted
    as its own one-node branch.

    Args:
        tree: morphology; only ``tree.to_nx()`` is used.
        return_branchpoints: if True, also return the branchpoint/tip nodes and the
            undirected edges incident to them.

    Returns:
        ``branches`` (list of node-label lists), or
        ``(branches, branchpoints, branchpoint_edges)`` if ``return_branchpoints``.
    """
    G = tree.to_nx().to_undirected()
    branches = []
    branchpoints = set()
    visited = set()  # walked edges, stored in the direction they were walked

    def is_branchpoint_or_tip(n: int) -> bool:
        """True for tips (degree <= 1), branchpoints (degree > 2) and degree-2 nodes
        whose ``"id"`` differs from their second neighbor."""
        is_leaf = G.degree(n) <= 1
        is_branching = G.degree(n) > 2
        if G.degree(n) == 2:
            _, j = G.neighbors(n)
            # trace dir matters here! For segment with node IDs: [1, 1, 2, 2]
            # -> [[1,1], [1,2,2]]
            # <- [[2,2], [2,1,1]]
            return not same_id(n, j)

        return is_leaf or is_branching

    def in_visited(n1: int, n2: int) -> bool:
        """True if the edge was already walked in either direction."""
        return (n1, n2) in visited or (n2, n1) in visited

    def same_id(n1: int, n2: int) -> bool:
        """True if both nodes share the same ``"id"``; also True if ``n1`` has none."""
        return G.nodes[n1]["id"] == G.nodes[n2]["id"] if "id" in G.nodes[n1] else True

    def is_soma(n: int) -> bool:
        # ``.get`` so that nodes without an "id" attribute count as non-somatic
        # instead of raising KeyError (``same_id`` tolerates them too).
        return G.nodes[n].get("id") == 1

    def soma_nodes() -> List[int]:
        return [i for i, n in G.nodes.items() if n.get("id") == 1]

    def walk_path(start: int, succ: int) -> List[int]:
        """Follow degree-2 nodes from edge ``(start, succ)``; return the node path."""
        path = [start, succ]
        visited.add((start, succ))

        while G.degree(succ) == 2:
            next_node = next(n for n in G.neighbors(succ) if n != path[-2])

            if in_visited(succ, next_node) or is_branchpoint_or_tip(succ):
                break

            path.append(next_node)
            visited.add((succ, next_node))
            succ = next_node

        return path

    # Start the traversal at a tip. Graphs without a degree-1 node (a single node,
    # or a cycle) start at their first node; an empty graph yields no branches.
    start = next((n for n in G.nodes if G.degree(n) == 1), None)
    if start is None:
        start = next(iter(G.nodes), None)

    single_soma = len(soma_nodes()) == 1
    traversal = nx.dfs_tree(G, start) if start is not None else []
    for node in traversal:
        if single_soma and is_soma(node):
            branches.append([node])

        elif is_branchpoint_or_tip(node):
            branchpoints.add(node)
            for succ in G.neighbors(node):
                if not in_visited(node, succ):
                    branches.append(walk_path(node, succ))

    if return_branchpoints:
        # Flatten in one pass (``sum(lists, [])`` is quadratic); same order as before.
        branchpoint_edges = [e for n in branchpoints for e in G.edges(n)]
        return branches, branchpoints, branchpoint_edges
    return branches

def compartmentalize(tree: MorphTree, num_comps: int = 1) -> MorphTree:
    """Discretize every branch of a morphology into ``num_comps`` compartments.

    Each branch from ``list_branches`` is split into equally long compartments.
    x, y, z, r are linearly interpolated along the branch path length at the
    compartment centers and at both branch ends. Branch end nodes are kept as
    branchpoint nodes unless they are tips of the whole tree. A single-node branch
    (single-point soma) is treated as a cylinder of length 2*r.

    Args:
        tree: morphology whose nodes carry ``id``, ``x``, ``y``, ``z`` and ``r``.
        num_comps: number of compartments per branch.

    Returns:
        A new ``MorphTree`` whose nodes are relabelled ``0..N-1`` and carry
        ``branch``, ``type`` (-1: branchpoint, 0: compartment), ``id``, ``l``
        (compartment length, 0 for branchpoints), ``x``, ``y``, ``z``, ``r``. Edges
        connect consecutive nodes along each branch.

    Raises:
        ValueError: if ``num_comps < 1``.
    """
    if num_comps < 1:
        raise ValueError(f"num_comps must be a positive integer, got {num_comps}.")

    branches = list_branches(tree)
    nodes_df = tree.to_pandas()[0].astype(float)

    # Labels for new compartment nodes must not collide with existing node labels.
    # Removing the existing labels from range(num_new + num_existing) always leaves
    # at least num_new free labels; sort them so the assignment is deterministic.
    existing_inds = set(nodes_df.index)
    num_new_inds = len(branches) * num_comps
    proposed_inds = set(range(num_new_inds + len(existing_inds)))
    proposed_comp_inds = sorted(proposed_inds - existing_inds)

    # jnp.interp applied independently to each column of fp (x, y, z, r).
    v_interp = vmap(jnp.interp, in_axes=(None, None, 1), out_axes=1)

    # Tip nodes appear exactly once in the edge list (degree 1). A set gives O(1)
    # membership tests inside the loop below.
    nodes_in_edges, node_counts_in_edges = np.unique(np.asarray(tree.edges), return_counts=True)
    tip_node_inds = set(nodes_in_edges[node_counts_in_edges == 1].tolist())

    # collect comps and comp_edges
    branch_nodes, branch_edges = [], []
    for i, branch in enumerate(branches):
        # .copy(): the "l" column is added below and must not target a view of nodes_df.
        node_attrs = nodes_df.loc[branch].copy()
        xyz_i = node_attrs[["x", "y", "z"]]
        edge_lens = (xyz_i.diff(axis=0).fillna(0) ** 2).sum(axis=1) ** 0.5
        node_attrs["l"] = edge_lens.cumsum()  # path length from the first node

        # For single-point somata, we set l = 2*r; this ensures
        # A_cylinder = 2*pi*r*l = 4*pi*r^2 = A_sphere.
        if len(branch) == 1:
            node_attrs = node_attrs.loc[branch * 2].copy()  # duplicate soma node
            radius = node_attrs["r"].iloc[0]
            node_attrs["l"] = np.array([0, 2 * radius])

        branch_id = node_attrs["id"].iloc[-1]  # TODO: handle multi ids within branch!
        branch_len = max(node_attrs["l"])
        comp_len = branch_len / num_comps
        comp_locs = list(np.linspace(comp_len / 2, branch_len - comp_len / 2, num_comps))

        # Row layout: [idx, branch_index, comp_type, comp_id, comp_len]
        # comp_type: -1: branchpoint, 0: compartment
        branch_tips = branch[0], branch[-1]
        branch_tip_attrs = [i, -1, -1, 0]
        comp_row = [i, 0, branch_id, comp_len]

        comp_inds = proposed_comp_inds[i * num_comps:(i + 1) * num_comps]
        comp_inds = np.array([branch_tips[0], *comp_inds, branch_tips[1]])
        comp_attrs = np.array([branch_tip_attrs] + [comp_row] * num_comps + [branch_tip_attrs])
        comp_attrs = np.hstack([comp_inds[:, None], comp_attrs])

        # Interpolate xyzr at the branch ends (0, branch_len) and compartment centers.
        x = jnp.array([0] + comp_locs + [branch_len])
        xp = jnp.array(node_attrs["l"].values)
        fp = jnp.array(node_attrs[["x", "y", "z", "r"]].values)
        # TODO: interpolate r differently!
        comp_attrs = np.hstack([comp_attrs, np.array(v_interp(x, xp, fp))])

        # Drop branch end nodes that are tips of the whole tree.
        if branch_tips[0] in tip_node_inds:
            comp_attrs = comp_attrs[1:]
        if branch_tips[1] in tip_node_inds:
            comp_attrs = comp_attrs[:-1]

        # Store edges and nodes
        branch_edges.append(list(zip(comp_attrs[:-1, 0], comp_attrs[1:, 0])))
        branch_nodes.append(comp_attrs)

    # TODO: add missing attrs, cell_index, axial_resistvity, membrane_capacitance, voltage
    # numpy (not jnp) so integer labels never pass through float32 when x64 is off.
    branch_nodes = np.concatenate(branch_nodes)
    comp_attrs_keys = ["idx", "branch", "type", "id", "l", "x", "y", "z", "r"]
    comp_df = pd.DataFrame(branch_nodes, columns=comp_attrs_keys)
    int_cols = ["idx", "branch", "type", "id"]
    comp_df[int_cols] = comp_df[int_cols].astype(int)

    # Branch end nodes shared by several branches appear once per branch; keep the first.
    comp_df = comp_df.drop_duplicates(subset=["idx"])
    comp_df = comp_df.set_index("idx")

    comps = jnp.array(comp_df.index)
    comp_attrs = comp_df.to_dict(orient="records")
    # Edge endpoints were stored in the float attribute matrix; cast them back to
    # integer labels so they compare exactly against ``comps``.
    edge_list = [edge for edges in branch_edges for edge in edges]
    comp_edges = jnp.asarray(np.asarray(edge_list, dtype=int).reshape(-1, 2))
    comp_edge_attrs = [{"comp_edge": True, "synapse": False} for _ in range(len(comp_edges))]

    comp_tree = MorphTree(comps, comp_edges, comp_attrs, comp_edge_attrs)
    # Relabel nodes to consecutive integers in row order.
    comp_tree = comp_tree.reindex_nodes(jnp.arange(len(comps)))
    return comp_tree

In [1334]:
class DummyChannel:
    """Stand-in channel: a name plus name-suffixed parameter and state dicts."""

    def __init__(self, name=None):
        if name is None:
            name = self.__class__.__name__
        self.name = name
        suffix = "_" + name
        self.params = {
            "gbar" + suffix: 1.0,
            "e" + suffix: 0.0,
            "nn_weights" + suffix: jnp.ones((10, 10)),
        }
        self.states = {"m" + suffix: 0.5, "h" + suffix: 0.5}

def dummy_insert(tree, inds, channel):
    """Insert ``channel`` into the nodes at positions ``inds`` of ``tree``.

    Every selected node's attribute dict is updated with ``channel.params`` and
    ``channel.states``, and the channel is appended to
    ``tree.global_attrs["channels"]``. That list is created on first insert, because
    a fresh tree starts with an empty ``global_attrs`` dict.
    """
    # TODO: Should Module and MorphTree be separate or the same thing?
    for i in inds:
        tree.node_attrs[i].update(channel.params)
        tree.node_attrs[i].update(channel.states)
    tree.global_attrs.setdefault("channels", []).append(channel)

def dummy_set(tree, inds, key, value):
    """Store ``value`` under ``key`` in the attribute dict of each node position in ``inds``."""
    targets = (tree.node_attrs[idx] for idx in inds)
    for attrs in targets:
        attrs[key] = value

def dummy_to_pytree(tree):
    """Collect every node-attribute column into a JAX array, skipping missing entries.

    Returns a dict mapping column name to ``jnp.array`` of that column's non-NaN values,
    in node-table order.
    """
    node_table = tree.to_pandas()[0]
    return {
        column: jnp.array(node_table[column][node_table[column].notna()].to_list())
        for column in node_table.columns
    }

class DummySynapse:
    """Stand-in synapse: a name plus name-suffixed parameter and state dicts."""

    def __init__(self, name=None):
        self.name = name if name is not None else type(self).__name__
        tag = self.name
        param_items = [
            (f"gbar_{tag}", 1.0),
            (f"e_{tag}", 0.0),
            (f"nn_weights_{tag}", jnp.ones((10, 10))),
        ]
        self.params = dict(param_items)
        self.states = dict([(f"m_{tag}", 0.5), (f"h_{tag}", 0.5)])

def dummy_connect(tree, pre, post, synapse):
    """Attach ``synapse`` to every existing edge ``(pre[k], post[k])`` of ``tree``.

    Matching edges get ``"synapse": True`` plus the synapse's params and states in
    their attribute dict. Requested pairs that are not edges of ``tree`` are ignored.
    """
    synapse_edges = jnp.vstack([pre, post]).T  # one (pre, post) row per synapse
    edges = jnp.asarray(tree.edges).reshape(-1, 2)
    # Compare (num_edges, 1, 2) against (1, num_synapses, 2): an edge matches a pair
    # only if both endpoints agree; it is a synapse edge if it matches any pair.
    is_synapse_edge = (edges[:, None, :] == synapse_edges[None, :, :]).all(axis=2).any(axis=1)
    synapse_idxs = np.where(np.asarray(is_synapse_edge))[0]
    # TODO: map / vectorize this
    for i in synapse_idxs:
        tree.edge_attrs[i]["synapse"] = True
        tree.edge_attrs[i].update(synapse.params)
        tree.edge_attrs[i].update(synapse.states)

# There can only be one edge per pair of nodes. (or use MultiDiGraph).
# This means all synapses need to live in the same edge (i,j)
# -> treat synapses more like channels, i.e. multiple channels per row in nodes -> multiple synapses per edge.
# Downside: cannot connect i and j with the same synapse twice, but can if one synapse is named differently.
# think about how to handle if i,j is a comp_edge and also connects via synapses


# TODO: node and edge attrs as list or dict?
# - pro: one can change node / edge idx without touching attrs, since the position of a node_idx maps to the position of its node_attr
# - con: hard to index into node / edge attrs
    

In [1335]:
from jaxley.io.graph import build_compartment_graph, to_swc_graph, _trace_branches, _remove_branch_points

# SWC morphologies from jaxley's test suite; one of them is selected by index below.
testcases = [
    "morph_3_types_single_point_soma.swc",
    "morph_3_types.swc",
    "morph_interrupted_soma.swc",
    "morph_soma_both_ends.swc",
    "morph_somatic_branchpoint.swc",
    "morph_non_somatic_branchpoint.swc",  # no soma!
    "morph_ca1_n120_single_point_soma.swc",
    "morph_ca1_n120.swc",
    "morph_l5pc_with_axon.swc",
    "morph_allen_485574832.swc",
]

swc_dir = "../jaxley/tests/swc_files/"
jx_graph = to_swc_graph(swc_dir + testcases[-3])

# Compartmentalize the same morphology with this notebook's implementation (on a
# copy of the graph) and with jaxley's, using one compartment per branch.
morph_comps = compartmentalize(MorphTree.from_nx(jx_graph.copy()), num_comps=1)
jx_comps = build_compartment_graph(jx_graph, ncomp=1)

In [1336]:
# Connect the same (pre, post) pairs with two differently named dummy synapses.
pre_nodes, post_nodes = jnp.array([0, 1]), jnp.array([2, 3])
for syn_name in ("test1", "test2"):
    dummy_connect(morph_comps, pre_nodes, post_nodes, DummySynapse(syn_name))

In [877]:
from jaxley.io.graph import build_compartment_graph, to_swc_graph, _trace_branches

# SWC morphologies from jaxley's test suite.
testcases = [
    "morph_3_types_single_point_soma.swc",
    "morph_3_types.swc",
    "morph_interrupted_soma.swc",
    "morph_soma_both_ends.swc",
    "morph_somatic_branchpoint.swc",
    "morph_non_somatic_branchpoint.swc",  # no soma!
    "morph_ca1_n120_single_point_soma.swc",
    "morph_ca1_n120.swc",
    "morph_l5pc_with_axon.swc",
    "morph_allen_485574832.swc",
]

# Positions in ``testcases`` treated as single-soma morphologies below.
single_soma_cases = [0, 6, 9]

for case_idx, testcase in enumerate(testcases):
    jx_graph = to_swc_graph("../jaxley/tests/swc_files/"+testcase)

    morph_branches = list_branches(MorphTree.from_nx(jx_graph.copy()))
    morph_branch_nodes = [np.sort(b) for b in morph_branches]

    # do jx_trace after morph_traces, since jax_trace modifies the graph
    jx_branches = _trace_branches(jx_graph.copy())[1]
    jx_branch_nodes = [np.sort(np.unique(b[:, :-1])) for b in jx_branches]
    if case_idx in single_soma_cases:
        # NOTE(review): presumably jaxley's node labels are offset by one for these
        # morphologies — confirm.
        jx_branch_nodes = [b-1 for b in jx_branch_nodes]

    # Pair each jaxley branch with the first morph branch that has the same nodes.
    # Distinct loop names avoid shadowing the outer case index.
    morph_eq_jx = []
    for jx_idx, jb in enumerate(jx_branch_nodes):
        for morph_idx, mb in enumerate(morph_branch_nodes):
            if len(jb) == len(mb) and np.allclose(jb, mb):
                morph_eq_jx.append((jx_idx, morph_idx))
                break

    if len(morph_eq_jx) > 0:
        # Build the matched index sets once instead of once per list element.
        matched_jx = {jx_idx for jx_idx, _ in morph_eq_jx}
        matched_morph = {morph_idx for _, morph_idx in morph_eq_jx}
        diff_morph_branches = [b for k, b in enumerate(morph_branch_nodes) if k not in matched_morph]
        diff_jx_branches = [b for k, b in enumerate(jx_branch_nodes) if k not in matched_jx]
    else:
        print("No branches are equal")
        diff_morph_branches = morph_branch_nodes
        diff_jx_branches = jx_branch_nodes

    # single soma handled differently and will lead to 1 diff branch
    print(f"testcase {testcase}: {len(diff_morph_branches)}, {len(diff_jx_branches)}")

testcase morph_3_types_single_point_soma.swc: 1, 1
testcase morph_3_types.swc: 0, 0
testcase morph_interrupted_soma.swc: 0, 0
testcase morph_soma_both_ends.swc: 0, 0
testcase morph_somatic_branchpoint.swc: 0, 0
testcase morph_non_somatic_branchpoint.swc: 0, 0
testcase morph_ca1_n120_single_point_soma.swc: 1, 1
testcase morph_ca1_n120.swc: 0, 0
testcase morph_l5pc_with_axon.swc: 0, 0
testcase morph_allen_485574832.swc: 1, 1


In [867]:
# from jaxley.io.graph import build_compartment_graph, to_swc_graph, _trace_branches

# testcases = [ 
# "morph_3_types_single_point_soma.swc",
# "morph_3_types.swc",
# "morph_interrupted_soma.swc",
# "morph_soma_both_ends.swc",
# "morph_somatic_branchpoint.swc",
# "morph_non_somatic_branchpoint.swc", # no soma!
# "morph_ca1_n120_single_point_soma.swc",
# "morph_ca1_n120.swc",
# "morph_l5pc_with_axon.swc",
# "morph_allen_485574832.swc",
# ]

# jx_graph = to_swc_graph("../jaxley/tests/swc_files/"+testcases[2])

# morph_branches = list_branches(MorphTree.from_nx(jx_graph.copy()))
# morph_branch_nodes = [np.sort(b) for b in morph_branches]

# # do jx_trace after morph_traces, since jax_trace modifies the graph
# jx_branches = _trace_branches(jx_graph.copy())[1]
# jx_branch_nodes = [np.sort(np.unique(b[:, :-1])) for b in jx_branches]
# # jx_branch_nodes = [b-1 for b in jx_branch_nodes]


# morph_eq_jx = []

# for i, b in enumerate(jx_branch_nodes):
#     for j, mb in enumerate(morph_branch_nodes):
#         if len(b) == len(mb):
#             if np.allclose(b, mb):
#                 morph_eq_jx.append((i,j))
#                 break
# if len(morph_eq_jx) > 0:
#     diff_morph_branches = [b for i, b in enumerate(morph_branch_nodes) if i not in np.array(morph_eq_jx)[:,1]]
#     diff_jx_branches = [b for j, b in enumerate(jx_branch_nodes) if j not in np.array(morph_eq_jx)[:,0]]
# else:
#     print("No branches are equal")
#     diff_morph_branches = morph_branch_nodes
#     diff_jx_branches = jx_branch_nodes

# jx_subgraph = jx_graph.subgraph(np.unique(np.hstack(diff_morph_branches)))

# # Get node positions and colors
# pos = {node: (jx_subgraph.nodes[node]['x'], jx_subgraph.nodes[node]['y']) for node in jx_subgraph.nodes()}

# # Create figure with 1x3 subplots
# fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# # Plot 1: Node indices
# node_colors = [jx_subgraph.nodes[node]['id'] for node in jx_subgraph.nodes()]
# nx.draw(jx_subgraph, pos=pos, node_color=node_colors, cmap='viridis', with_labels=True, ax=axes[0])
# axes[0].set_title('Node Indices')

# # Plot 2: JX branches
# node_colors = np.zeros(len(jx_subgraph.nodes()))
# for i, branch in enumerate(diff_jx_branches):
#     node_colors[np.isin(list(jx_subgraph.nodes()), branch)] = i + 1
# nx.draw(jx_subgraph, pos=pos, node_color=node_colors, cmap='tab10', with_labels=True, ax=axes[1])
# axes[1].set_title('JX Branches')

# # Plot 3: Morph branches
# node_colors = np.zeros(len(jx_subgraph.nodes()))
# for i, branch in enumerate(diff_morph_branches):
#     node_colors[np.isin(list(jx_subgraph.nodes()), branch)] = i + 1
# nx.draw(jx_subgraph, pos=pos, node_color=node_colors, cmap='tab10', with_labels=True, ax=axes[2])
# axes[2].set_title('Morph Branches')

# plt.tight_layout()
# plt.show()

In [4]:
def infer_module_type_from_inds(idxs: pd.DataFrame) -> str:
    """Return the finest module level whose index column holds a single unique value.

    Levels are checked in the order network, cell, branch, compartment, using the
    ``cell_index``, ``branch_index`` and ``comp_index`` columns. "network" is always
    a candidate, so a level is always returned.
    """
    level_columns = {
        "cell": "cell_index",
        "branch": "branch_index",
        "compartment": "comp_index",
    }
    candidates = ["network"]
    for level, column in level_columns.items():
        if idxs[column].nunique() == 1:
            candidates.append(level)
    return candidates[-1]


def build_module_scaffold(
    idxs: pd.DataFrame,
    return_type: Optional[str] = None,
    parent_branches: Optional[List[np.ndarray]] = None,
) -> Union[jx.Network, jx.Cell, jx.Branch, jx.Compartment]:
    """Builds a skeleton module from a DataFrame of indices.

    This is useful for instantiating a module that can be filled with data later.

    NOTE(review): relies on a global ``jx`` (jaxley) import that is not visible in
    this notebook — confirm it is imported before calling.

    Args:
        idxs: DataFrame containing cell_index, branch_index, comp_index, i.e.
            Module.nodes or View.view.
        return_type: One of "compartment", "branch", "cell", "network". If None, the
            type is inferred from the number of unique values in the indices, i.e.
            only 1 unique cell_index and 1 unique branch_index -> "branch".
        parent_branches: Optional parent arrays, looked up by ``cell_index`` value.
            If None, every cell gets ``parents = [-1, 0, 1, ...]`` (a linear chain),
            which ignores the real morphology.

    Returns:
        A skeleton module with the correct number of compartments, branches, cells, or
        networks. Every branch uses the compartment count of the largest branch.

    Raises:
        ValueError: if ``return_type`` is not one of the supported module types.
    """
    return_types = ["compartment", "branch", "cell", "network"]

    if return_type is None:  # infer return type from idxs
        return_type = infer_module_type_from_inds(idxs)
    if return_type not in return_types:
        raise ValueError(
            f"Unknown return_type {return_type!r}; expected one of {return_types}."
        )

    build_cache = {k: [] for k in return_types}

    comp = jx.Compartment()
    build_cache["compartment"] = [comp]

    if return_type in return_types[1:]:
        # value_counts() is sorted descending, so this is the largest per-branch count.
        nsegs = idxs["branch_index"].value_counts().iloc[0]
        branch = jx.Branch([comp for _ in range(nsegs)])
        build_cache["branch"] = [branch]

    if return_type in return_types[2:]:
        for cell_id, cell_groups in idxs.groupby("cell_index"):
            num_branches = cell_groups["branch_index"].nunique()
            default_parents = np.arange(num_branches) - 1  # ignores morphology
            parents = (
                default_parents if parent_branches is None else parent_branches[cell_id]
            )
            cell = jx.Cell([branch] * num_branches, parents)
            build_cache["cell"].append(cell)

    if return_type in return_types[3:]:
        build_cache["network"] = [jx.Network(build_cache["cell"])]

    return build_cache[return_type][0]

In [5]:
def convert_view_to_module(view: View, reset_index: bool = True) -> "Module":
    """Build a standalone module from a View of a module.

    The view is deep-copied. A scaffold module of the inferred type is then built
    from the copy's global cell/branch/comp indices, and the copy's attributes are
    written onto the scaffold. This can be used to call ``jx.integrate`` on part of a
    Module.

    Args:
        view: the View to convert; it is deep-copied, so the original is not modified.
        reset_index: if True, the row index of the copied ``nodes`` and ``edges``
            tables is reset to start from 0. Cell/branch/comp index columns are not
            re-enumerated yet (see TODO).

    Returns:
        A new module of the inferred type carrying a copy of the view's data.
    """
    # Local import: ``deepcopy`` is not imported anywhere else in this notebook.
    from copy import deepcopy

    view = deepcopy(view)
    if reset_index:
        view.nodes.reset_index(drop=True, inplace=True)
        view.edges.reset_index(drop=True, inplace=True)
        # TODO: also re-enumerate cell,branch,comp indices in nodes and edges

    # The scaffold builder expects un-prefixed index column names.
    index_cols = ["cell_index", "branch_index", "comp_index"]
    testnodes = view.nodes.rename(columns={"global_" + k: k for k in index_cols})
    mod_type = infer_module_type_from_inds(testnodes)
    module = build_module_scaffold(testnodes, mod_type)
    module.__dict__.update(view.__dict__)
    return module

In [6]:
# Two-compartment branch with HH inserted only into compartment 0; record everywhere.
branch = jx.Branch(ncomp=2)
branch[0].insert(HH())
branch.record()

# Turn the view of compartment 0 into a standalone module and simulate it briefly.
comp = convert_view_to_module(branch[0])
jx.integrate(comp, t_max=1)

Added 2 recordings. See `.recordings` for details.


Array([[-70.        , -69.48756081, -69.09088174, -68.77443768,
        -68.51477492, -68.29603413, -68.10727942, -67.94084395,
        -67.79127429, -67.65464046, -67.52807603, -67.40946647,
        -67.29723486, -67.19019277, -67.0874353 , -66.98826644,
        -66.89214546, -66.7986481 , -66.70743824, -66.61824695,
        -66.53085703, -66.44509138, -66.36080422, -66.27787448,
        -66.19620072, -66.11569714, -66.03629063, -65.9579183 ,
        -65.8805257 , -65.80406526, -65.72849517, -65.6537784 ,
        -65.57988194, -65.50677619, -65.43443445, -65.3628325 ,
        -65.29194826, -65.22176147, -65.15225352, -65.08340718,
        -65.01520646]], dtype=float64)

- change param / state naming convention for channels and synapses
- ensure that channel/synapse currents are prefixed on the module side
- ensure that args and kwargs in channel functions are in the same order and standardized (enforce this with tests)
- make channel / synapse name user-facing?
- jx.integrate(net1, t_max=0) runs one step
- are we testing inner vs outer loop of setting states / params in init_states / _step_channel_currents etc?
- add test to see if indices are handled correctly for two different mechs changing the same state at non-overlapping indices
- fix: ![image.png](attachment:image.png)
- add prepare_mechanism function (assert the name does not already exist; assert it conforms to the build rules)
- warn if states / params do not contain all global params
update_states vs. init_state
try:
    channel.update_states()
    channel.compute_current()
    channel.init_states()
except KeyError:
    warn("Some global param / state seems to be missing")


- in the new channel API, `_filter_params_states` would potentially have to filter all channels, even the ones not in the channel. Better solution: only filter global states!
- add current_name to synapses
- store global parameters in jaxnodes only once! i.e. for any global param jaxnodes["global_param_name"].shape == (1,) and duplicate if it is used multiple times
- should compute_current get dt arg?

In [25]:
def get_params_all_trainable(net):
    """Make ``HH_gNa`` trainable everywhere, zero it, and return all resolved parameters.

    The intermediate ``pstate`` is printed for debugging.
    """
    net.cell("all").branch("all").loc("all").make_trainable("HH_gNa")
    trainables = net.get_parameters()
    gna_values = trainables[0]["HH_gNa"]
    trainables[0]["HH_gNa"] = gna_values.at[:].set(0.0)
    net.to_jax()
    pstate = params_to_pstate(trainables, net.indices_set_by_trainables)
    print(pstate)
    return net.get_all_parameters(pstate, voltage_solver="jaxley.thomas")

def get_params_set(net):
    """Set ``HH_gNa`` to 0 on the whole network and return all resolved parameters."""
    net.set("HH_gNa", 0.0)
    trainables = net.get_parameters()
    net.to_jax()
    return net.get_all_parameters(
        params_to_pstate(trainables, net.indices_set_by_trainables),
        voltage_solver="jaxley.thomas",
    )

def SimpleComp():
    """Return a default ``jx.Compartment``."""
    return jx.Compartment()

def SimpleBranch(ncomp):
    """Return a branch built from ``ncomp`` copies of one default compartment."""
    return jx.Branch([jx.Compartment()] * ncomp)

def SimpleCell(nbranch, ncomp):
    """Return a cell of ``nbranch`` identical branches with ``ncomp`` compartments each.

    Branches form a binary tree: branch 0 is the root (parent -1) and branch k >= 1
    attaches to branch ``(k - 1) // 2``. This reproduces the parents
    ``[-1, 0, 0, 1, 1, 2, 2, 3, 3]`` for up to 9 branches and extends to any
    ``nbranch``; a fixed-length list would yield too few parents beyond 9.
    """
    branch = SimpleBranch(ncomp)
    parents = [-1 if k == 0 else (k - 1) // 2 for k in range(nbranch)]
    cell = jx.Cell([branch] * nbranch, parents=parents)
    return cell

def SimpleNet(ncell, nbranch, ncomp):
    """Return a network of ``ncell`` copies of one ``SimpleCell(nbranch, ncomp)``."""
    return jx.Network([SimpleCell(nbranch, ncomp)] * ncell)

Number of newly added trainable parameters: 3. Total number of trainable parameters: 3


Array([[ 1,  2, -1],
       [ 6,  7, -1],
       [ 3, -1, -1]], dtype=int64)

In [None]:
# global_params = ["v", "radius", "length", "axial_resistivity", "capacitance"]
# jaxnodes = {"global": {param: jnp.asarray(cell.nodes[param]) for param in global_params}}

# # Add channel-specific nodes and update globals
# for channel in cell.channels:
#     channel_dict = {}
#     for param, value in {**channel.states, **channel.params}.items():
#         if f"{channel._name}_" in param:
#             channel_dict[param] = jnp.asarray(cell.nodes[param][channel.indices])
#         else:
#             jaxnodes["global"][param] = jnp.asarray(cell.nodes[param])
#     jaxnodes[channel._name] = channel_dict

# # Update channel states
# for channel in cell.channels:
#     # Combine channel-specific and global nodes
#     channel_nodes = jaxnodes[channel._name].copy()
#     channel_nodes.update({k: v[channel.indices] for k, v in jaxnodes["global"].items()})
    
#     # Update states
#     channel_states_updated = channel.update_states(
#         channel_nodes, 0.025, channel_nodes["v"], channel_nodes
#     )
    
#     # Apply updates back to jaxnodes
#     for key, val in channel_states_updated.items():
#         mech_key = "global" if key in jaxnodes["global"] else channel._name
#         jaxnodes[mech_key][key] = jaxnodes[mech_key][key].at[channel.indices].set(val)

# def _iter_states_params(self, params=False, states=False) -> Tuple[str, jnp.ndarray]:
#     # TODO FROM #447: MAKE THIS WORK FOR VIEW?

#     # assert that either params or states is True
#     assert params or states, "Either params or states must be True."
#     global_states = ["v"]
#     morph_params = ["radius", "length", "axial_resistivity", "capacitance"]

#     for key in global_states + morph_params:
#         yield "global", key, self._inds_of_state_param[key]
            
#     mechs = self.channels + self.synapses
#     for mech in mechs:
#         data = self.nodes if isinstance(mech, Channel) else self.edges
#         params_states = mech.params if params else []
#         params_states += mech.states if states else []
#         for key in params_states:
#             if f"{mech._name}_" not in key:
#                 yield "global", key, jnp.asarray(data.index)
#             else:
#                 yield mech._name, key, mech.indices

# def _get_all_states_params(
#     self,
#     pstate: List[Dict],
#     voltage_solver=None,
#     delta_t=None,
#     all_params=None,
#     params=False,
#     states=False,
# ) -> Dict[str, jnp.ndarray]:
#     states_params = {}
#     pkeys = {k:[] for k in pstate}
#     for i, p in enumerate(pstate):
#         pkeys[p["key"]] += [i]

#     for mech_key, key, _ in self._iter_states_params(params, states):
#         jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges
#         states_params[mech_key][key] = jax_arrays[mech_key][key]

#         if key in pkeys:
#             for i in pkeys[key]:
#                 # `inds` is of shape `(num_params, num_comps_per_param)`.
#                 # `set_param` is of shape `(num_params,)`
#                 # We need to unsqueeze `set_param` to make it `(num_params, 1)`
#                 #  for the `.set()` to work. This is done with `[:, None]`.
#                 inds, set_param = pstate[i]["indices"], pstate[i]["val"]
#                 states_params[mech_key][key] = states_params[key].at[inds].set(set_param[:, None])

#     if params:
#         # Compute conductance params and add them to the params dictionary.
#         states_params["axial_conductances"] = self._compute_axial_conductances(
#             params=states_params
#         )

#     if states:
#         all_params = states_params if all_params is None and params else all_params
#         for current in self.membrane_current_names:
#             states_params[current] = jnp.zeros_like(states_params['v'])
#         # Add to the states the initial current through every channel.
#         states, _ = self._channel_currents(
#             states_params, delta_t, self.channels, self.nodes, all_params
#         )

#         # Add to the states the initial current through every synapse.
#         states, _ = self._synapse_currents(
#             states_params, self.synapses, all_params, delta_t, self.edges
#         )
#     return states_params
    