In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
import os
from jax import config

# JAX configuration: 64-bit dtypes and the CPU platform.
JAX_OPTIONS = {"jax_enable_x64": True, "jax_platform_name": "cpu"}
for _option, _value in JAX_OPTIONS.items():
    config.update(_option, _value)

# Memory fraction setting for the XLA client.
os.environ.update({"XLA_PYTHON_CLIENT_MEM_FRACTION": ".8"})

In [3]:
from __future__ import annotations
import jax
import jax.numpy as jnp
import pandas as pd

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap

from typing import Optional, List, Dict, Any, Union, Set, Tuple
from matplotlib.axes import Axes
from dataclasses import dataclass, field

In [4]:
def pandas_to_nx(
    node_attrs: pd.DataFrame, edge_attrs: pd.DataFrame, global_attrs: pd.Series
) -> nx.DiGraph:
    """Convert node_attrs, edge_attrs and global_attrs from pandas datatypes to a NetworkX DiGraph.

    Args:
        node_attrs: DataFrame containing node attributes, indexed by node id
        edge_attrs: DataFrame containing edge attributes, indexed by an unnamed
            2-level (source, target) index
        global_attrs: Series containing global graph attributes

    Returns:
        A directed graph with nodes, edges and global attributes from the input data.
        Nodes listed in `node_attrs` are included even if they have no edges.
    """
    # edge_attr=None adds no edge attributes (no columns or no rows); True copies every column.
    has_edge_attrs = None if edge_attrs.empty else True
    G = nx.from_pandas_edgelist(
        edge_attrs.reset_index(),
        source="level_0",  # names reset_index gives to an unnamed 2-level index
        target="level_1",
        edge_attr=has_edge_attrs,
        create_using=nx.DiGraph(),
    )

    # add_nodes_from also creates nodes that appear in no edge; set_node_attributes
    # silently skips nodes missing from the graph, which dropped isolated nodes.
    G.add_nodes_from(node_attrs.to_dict(orient="index").items())
    G.graph.update(global_attrs.to_dict())
    return G


def nx_to_pandas(G: nx.DiGraph) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
    """Convert a NetworkX DiGraph to pandas datatypes.

    Args:
        G: Input directed graph

    Returns:
        Tuple containing:
        - DataFrame of node attributes, one row per node, indexed by node id
        - DataFrame of edge attributes, indexed by an unnamed 2-level (source, target) index
        - Series of global graph attributes
    """
    edge_table = nx.to_pandas_edgelist(G).set_index(["source", "target"])
    # Unnamed levels are what `pandas_to_nx` expects (reset_index -> level_0 / level_1).
    edge_table.index = edge_table.index.set_names([None, None])

    node_table = pd.DataFrame.from_dict(
        {node: attrs for node, attrs in G.nodes(data=True)}, orient="index"
    )
    global_series = pd.Series(G.graph)

    return node_table, edge_table, global_series


def swc_to_nx(fname: str, num_lines: Optional[int] = None) -> nx.DiGraph:
    """Read a SWC morphology file into a NetworkX DiGraph.

    Each data row is read as seven numbers: sample index, a type column stored
    as ``id`` (presumably the SWC structure type — confirm), x, y, z, radius and
    parent index (-1 marks a root).

    Args:
        fname: Path to the SWC file (``#`` comment lines are skipped by `np.loadtxt`)
        num_lines: Number of data rows to read; all rows if None

    Returns:
        A directed graph representing the morphology where:
        - Nodes are keyed by the sample index and have attributes: id, x, y, z, r (radius)
        - Edges point from parent to child
    """
    # ndmin=2 keeps a single-row file as a (1, 7) table; without it loadtxt returns a flat
    # row and the unpacking loop below fails.
    i_id_xyzr_p = np.loadtxt(fname, ndmin=2)[:num_lines]

    graph = nx.DiGraph()
    for i, type_id, x, y, z, r, p in i_id_xyzr_p.tolist():  # tolist: np.float64 -> float
        graph.add_node(int(i), id=int(type_id), x=x, y=y, z=z, r=r)
        if p != -1:
            graph.add_edge(int(p), int(i))
    return graph


def nx_to_jax(G: nx.DiGraph) -> jax.tree_util.PyTree:
    """Convert a NetworkX DiGraph to a Jax tree.

    Every node (edge) attribute is stacked into one array with a row per node
    (edge). The per-node (per-edge) records are combined with
    `jax.tree_util.tree_map`, so all nodes (edges) must carry the same attribute keys.

    Args:
        G: Input directed graph

    Returns:
        Tuple ``(jax_node_attrs, jax_edge_attrs, jax_global_attrs)`` of dicts of
        arrays. Node ids are stored under ``"index"``; edge endpoints under
        ``"index_pre"`` (source) and ``"index_post"`` (target).
    """
    stack = lambda *args: jnp.array(args)

    nodes = list(G.nodes(data=True))
    if nodes:
        inds, jax_node_attrs = jax.tree_util.tree_map(stack, *nodes)
    else:
        # tree_map needs at least one tree to map over.
        inds, jax_node_attrs = jnp.array([], dtype=int), {}
    jax_node_attrs["index"] = jnp.array(inds)

    edges = list(G.edges(data=True))
    if edges:
        pre, post, jax_edge_attrs = jax.tree_util.tree_map(stack, *edges)
    else:
        # e.g. a single-node graph: no edges, hence no edge attribute keys.
        pre = post = jnp.array([], dtype=int)
        jax_edge_attrs = {}
    jax_edge_attrs["index_pre"] = jnp.array(pre)
    jax_edge_attrs["index_post"] = jnp.array(post)

    jax_global_attrs = {k: jnp.array(v) for k, v in G.graph.items()}

    return jax_node_attrs, jax_edge_attrs, jax_global_attrs

def jax_to_nx(jax_node_attrs: jax.tree_util.PyTree, jax_edge_attrs: jax.tree_util.PyTree, jax_global_attrs: jax.tree_util.PyTree) -> nx.DiGraph:
    """Convert a Jax tree to a NetworkX DiGraph by going through the pandas representation.

    Args:
        jax_node_attrs: Jax tree of node attributes
        jax_edge_attrs: Jax tree of edge attributes
        jax_global_attrs: Jax tree of global graph attributes

    Returns:
        A NetworkX DiGraph representing the morphology.
    """
    pandas_repr = jax_to_pandas(jax_node_attrs, jax_edge_attrs, jax_global_attrs)
    return pandas_to_nx(*pandas_repr)

def jax_to_pandas(jax_node_attrs: jax.tree_util.PyTree, jax_edge_attrs: jax.tree_util.PyTree, jax_global_attrs: jax.tree_util.PyTree) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
    """Convert a Jax tree to pandas datatypes.

    Args:
        jax_node_attrs: Dict of per-node arrays; node ids are read from the ``("index",)`` key
        jax_edge_attrs: Dict of per-edge arrays; endpoints are read from ``"index_pre"``
            and ``"index_post"``
        jax_global_attrs: Dict of global graph attributes

    Returns:
        Tuple containing:
        - DataFrame of node attributes, indexed by node id (the id column is dropped)
        - DataFrame of edge attributes, indexed by an unnamed 2-level (pre, post) index
          (the endpoint columns are dropped)
        - Series of global graph attributes
    """
    node_ids = np.array(jax_node_attrs[("index",)])
    node_columns = {key: arr.tolist() for key, arr in jax_node_attrs.items()}
    nodes_df = pd.DataFrame(node_columns, index=node_ids)
    nodes_df = nodes_df.drop(columns=[("index",)])

    endpoints = np.vstack([jax_edge_attrs["index_pre"], jax_edge_attrs["index_post"]])
    edge_index = pd.MultiIndex.from_arrays(endpoints)
    edge_columns = {key: arr.tolist() for key, arr in jax_edge_attrs.items()}
    edges_df = pd.DataFrame(edge_columns, index=edge_index)
    edges_df = edges_df.drop(columns=["index_pre", "index_post"])

    globals_series = pd.Series(jax_global_attrs)

    return nodes_df, edges_df, globals_series

def pandas_to_jax(node_df: pd.DataFrame, edge_df: pd.DataFrame, global_attrs: pd.Series) -> jax.tree_util.PyTree:
    """Convert pandas datatypes to a Jax tree (inverse of `jax_to_pandas`).

    Args:
        node_df: DataFrame of node attributes, indexed by node id
        edge_df: DataFrame of edge attributes, indexed by a 2-level (pre, post) index
        global_attrs: Series of global graph attributes

    Returns:
        Tuple ``(jax_node_attrs, jax_edge_attrs, jax_global_attrs)`` of dicts of
        arrays, one row per node/edge. Node ids are stored under ``("index",)``;
        edge endpoints under ``"index_pre"`` and ``"index_post"``.

    Raises:
        ValueError: If a non-empty `edge_df` is not indexed by a 2-level index.
    """
    stack = lambda *args: jnp.array(args)

    if len(node_df):
        node_records = node_df.to_dict(orient="index").items()  # (node_id, {col: value})
        inds, jax_node_attrs = jax.tree_util.tree_map(stack, *node_records)
    else:
        # tree_map needs at least one tree; keep the columns as empty arrays.
        inds = jnp.array([], dtype=int)
        jax_node_attrs = {col: jnp.array([]) for col in node_df.columns}
    jax_node_attrs[("index", )] = jnp.array(inds)

    if len(edge_df):
        if edge_df.index.nlevels != 2:
            raise ValueError("edge_df must be indexed by a 2-level (pre, post) index")
        # Each record is ((pre, post), {col: value}), so the stacked result has the same
        # nesting; the endpoints come back as one (pre, post) pair, not two items.
        edge_records = edge_df.to_dict(orient="index").items()
        (pre, post), jax_edge_attrs = jax.tree_util.tree_map(stack, *edge_records)
    else:
        pre = post = jnp.array([], dtype=int)
        jax_edge_attrs = {col: jnp.array([]) for col in edge_df.columns}
    jax_edge_attrs["index_pre"] = jnp.array(pre)
    jax_edge_attrs["index_post"] = jnp.array(post)

    jax_global_attrs = {k: jnp.array(v) for k, v in global_attrs.items()}

    return jax_node_attrs, jax_edge_attrs, jax_global_attrs

In [479]:
def tree_filter(tree, condition, do_x = None, do_y = None):
    """Map every leaf of `tree` through `do_x` or `do_y`, depending on `condition`.

    Args:
        tree: Any pytree.
        condition: Predicate ``condition(path, leaf) -> bool``, where `path` is the
            `jax.tree_util` key path of the leaf.
        do_x: Applied to leaves where `condition` holds; if None they become None.
        do_y: Applied to all other leaves; if None they become None.

    Returns:
        A pytree with the same structure and the transformed leaves.
    """
    # Bind the callbacks to new names. Reassigning `do_x` to a lambda that reads `do_x`
    # makes the lambda refer to itself (late binding), so the callback was never called.
    on_true = (lambda _: None) if do_x is None else do_x
    on_false = (lambda _: None) if do_y is None else do_y

    update_if = lambda path, val: on_true(val) if condition(path, val) else on_false(val)
    return jax.tree_util.tree_map_with_path(update_if, tree)

def tree_apply_at(tree, func_mapper: Union[dict[str, callable], callable]):
    """Apply functions to the leaves of `tree`.

    Args:
        tree: A pytree. When `func_mapper` is a dict, the top level of `tree` is
            expected to be a dict.
        func_mapper: Either one function applied to every leaf, or a dict mapping
            top-level keys of `tree` to the function applied to the leaves under
            that key. Leaves under keys missing from the dict are returned unchanged.

    Returns:
        A pytree of the same structure with the updated leaves.
    """
    # Builtin `callable`: `typing.Callable` is not imported in this notebook.
    if callable(func_mapper):
        return jax.tree.map(func_mapper, tree)

    def update_if_key_matches(path, value):
        # Only dict keys at the top level are matched; other path entries
        # (e.g. sequence indices) have no `.key` and leave the leaf unchanged.
        head = path[0] if path else None
        if hasattr(head, "key") and head.key in func_mapper:
            return func_mapper[head.key](value)
        return value

    return jax.tree_util.tree_map_with_path(update_if_key_matches, tree)

def tree_set_at(tree, keys_values: Union[dict[str, Any], Any], inds = None):
    """Return `tree` with its array leaves overwritten at positions `inds`.

    Args:
        tree: Pytree of arrays supporting ``.at[...].set``.
        keys_values: Either a dict mapping top-level keys of `tree` to the value
            written into the leaves under that key, or a single value written into
            every leaf.
        inds: Positions along the leading axis to overwrite; all positions if None.

    Returns:
        The updated pytree (produced via functional ``.at[...].set`` updates).
    """
    where = slice(None) if inds is None else inds

    def make_setter(value):
        return lambda leaf: leaf.at[where].set(value)

    if isinstance(keys_values, dict):
        per_key = {key: make_setter(value) for key, value in keys_values.items()}
        return tree_apply_at(tree, per_key)
    return tree_apply_at(tree, make_setter(keys_values))

def tree_get_at(tree, keys = None, inds = None):
    """Gather positions `inds` from the array leaves of `tree`.

    Args:
        tree: Pytree of arrays supporting ``.at[...].get``.
        keys: Top-level keys whose leaves are gathered; leaves under other keys are
            returned unchanged. If None, every leaf is gathered.
        inds: Positions along the leading axis to take; all positions if None.

    Returns:
        A pytree of the same structure.
    """
    where = slice(None) if inds is None else inds

    def gather(leaf):
        return leaf.at[where].get()

    if keys is not None:
        return tree_apply_at(tree, dict.fromkeys(keys, gather))
    return tree_apply_at(tree, gather)

def has_top_level_key(d, key):
    """Return True if the first element of any (tuple) key of `d` equals `key`."""
    for full_key in d:
        if full_key[0] == key:
            return True
    return False

In [501]:
from dataclasses import dataclass, field, replace

@dataclass
class TestModule:
    """Container of node, edge and global attributes that supports views.

    A module whose ``base`` is itself owns the full attribute arrays. `select`
    returns a *view*: a module holding a subset of the rows plus a reference to
    the module it was selected from (``base``). The editing methods (`set`,
    `insert`, `record`, `add_to_group`, `stimulate`) never modify arrays in place.
    They return a new view whose ``base`` carries the updated arrays.

    Node ids stored under ``("index",)`` are used directly as row positions into
    the base arrays (see `_set_node_attrs`), so they are expected to be
    ``0..num_nodes-1``.
    """

    # Per-node arrays keyed by ("index",) or (category, name) tuples, e.g. ("morphology", "r").
    node_attrs: dict[Any, Any]
    # Per-edge arrays; expected to contain "index_pre" and "index_post".
    edge_attrs: dict[Any, Any]
    # Graph-level values; "externals", "channels" and "synapses" registries are ensured.
    global_attrs: dict[Any, Any]

    # Module this view was selected from; None means the module is its own base.
    base: Optional[TestModule] = None

    def __post_init__(self):
        if self.base is None:
            self.base = self

        # Add missing registries on a copy so the caller's dict is never mutated.
        global_attrs = dict(self.global_attrs)
        for registry in ("externals", "channels", "synapses"):
            global_attrs.setdefault(registry, {})
        self.global_attrs = global_attrs

    @property
    def _nodes_in_view(self):
        """Node ids (the ``("index",)`` column) of the rows held by this module."""
        return self.node_attrs[("index",)]

    @property
    def _edges_in_view(self):
        """Source node id of every edge held by this module (one entry per edge).

        These are node ids, not edge row positions; only the length is used, as the edge count.
        """
        return self.edge_attrs["index_pre"]

    @property
    def _num_nodes(self):
        return len(self._nodes_in_view)

    @property
    def _num_edges(self):
        return len(self._edges_in_view)

    def __repr__(self):
        node_keys = list(self.node_attrs.keys())
        edge_keys = list(self.edge_attrs.keys())
        global_keys = list(self.global_attrs.keys())
        return f"TestModule(node_attrs={self._num_nodes}*{node_keys}, edge_attrs={self._num_edges}*{edge_keys}, global_attrs={global_keys})"

    def _rebased(self, **changes):
        """Return a copy of ``self.base`` with `changes` applied that is its own base.

        `dataclasses.replace` copies every field, including ``base``. Passing
        ``base=None`` makes `__post_init__` point the copy at itself rather than
        at the stale, pre-update base.
        """
        return replace(self.base, base=None, **changes)

    def _select_nodes(self, keys = None, inds = None):
        """Gather rows `inds` of this module's node arrays (see `tree_get_at`)."""
        return tree_get_at(self.node_attrs, keys=keys, inds=inds)

    def _select_edges(self, keys = None, inds = None):
        """Gather rows `inds` of this module's edge arrays (see `tree_get_at`)."""
        return tree_get_at(self.edge_attrs, keys=keys, inds=inds)

    def _set_node_attrs(self, keys_values, inds = None):
        """Write `keys_values` into the base at node ids `inds` and return the refreshed view.

        Args:
            keys_values: Dict mapping node-attribute keys to the value to write, or a
                single value written to every attribute.
            inds: Node ids (= base row positions) to write; defaults to the nodes in view.
        """
        inds = self._nodes_in_view if inds is None else inds
        node_attrs = tree_set_at(self.base.node_attrs, keys_values=keys_values, inds=inds)
        updated_base = self._rebased(node_attrs=node_attrs)
        updated_node_attrs = updated_base._select_nodes(inds=inds)
        return replace(self, base=updated_base, node_attrs=updated_node_attrs)

    def _set_edge_attrs(self, keys_values, inds = None):
        # NOTE(review): unlike `_set_node_attrs`, this returns only the updated edge
        # dict of this module and does not propagate the change to the base.
        return tree_set_at(self.edge_attrs, keys_values=keys_values, inds=inds)

    def select(self, nodes=None, edges=None):
        """Return a view restricted to the given rows; the base is shared.

        Args:
            nodes: Row positions into this module's node arrays; None keeps all current nodes.
            edges: Row positions into this module's edge arrays; None keeps all current edges.
        """
        # None keeps the current rows. Gathering by node ids would read past the end of a
        # sub-view's shorter arrays, and `_edges_in_view` holds node ids, not edge rows.
        node_attrs = self.node_attrs if nodes is None else self._select_nodes(inds=nodes)
        edge_attrs = self.edge_attrs if edges is None else self._select_edges(inds=edges)
        return replace(self, node_attrs=node_attrs, edge_attrs=edge_attrs, base=self.base)

    def _init_node_attrs(self, d: dict[str, Any], inds = None, init_value = jnp.nan):
        """Add new node attributes to the base and set them on `inds`.

        Each key of `d` becomes a new column spanning all base nodes, filled with
        ``init_value * value``. Boolean columns with a NaN `init_value` are filled
        with False instead. The nodes `inds` then receive the value from `d`.

        Args:
            d: Dict mapping new node-attribute keys to the value for the selected nodes.
            inds: Node ids to set; defaults to the nodes in view.
            init_value: Multiplier giving the fill value for all other nodes.
        """
        inds = self._nodes_in_view if inds is None else inds
        # New columns need one row per *base* node. Sizing them by this view left the
        # base with arrays of different lengths.
        num_base_nodes = self.base._num_nodes
        fill_nan = isinstance(init_value, float) and np.isnan(init_value)

        def init_column(x):
            dtype = x.dtype if isinstance(x, jnp.ndarray) else np.dtype(type(x))
            # NaN cast to bool is True, which would flag every base node as set.
            scale = False if (fill_nan and dtype == np.bool_) else init_value
            return jnp.full((num_base_nodes, *jnp.shape(x)), scale * x, dtype=dtype)

        base_node_attrs = {**self.base.node_attrs, **jax.tree.map(init_column, d)}
        init_view = replace(self, base=self._rebased(node_attrs=base_node_attrs))
        return init_view._set_node_attrs(keys_values=d, inds=inds)

    @property
    def pandas(self):
        """(nodes, edges, globals) of this module as pandas objects (via `jax_to_pandas`)."""
        return jax_to_pandas(self.node_attrs, self.edge_attrs, self.global_attrs)

    @property
    def nodes(self):
        return self.pandas[0]

    @property
    def edges(self):
        return self.pandas[1]

    @property
    def globals(self):
        return self.pandas[2]

    @property
    def recordings(self):
        # NOTE(review): reads global_attrs["recordings"], but `record()` stores flags in the
        # node attributes ("recordings", key). As written this raises KeyError unless a
        # "recordings" global is added elsewhere — confirm the intended source.
        df = pd.DataFrame.from_dict(self.globals["recordings"], orient="index").T.set_index("index")
        df.index.name = None
        return df

    @property
    def externals(self):
        # NOTE(review): `set_index("index")` expects an "index" entry, while `stimulate()`
        # stores (key, node_id) -> values pairs in global_attrs["externals"] — confirm.
        df = pd.DataFrame.from_dict(self.globals["externals"], orient="index").T.set_index("index")
        df.index.name = None
        return df

    def set(self, key, value):
        """Set node attribute `key` to `value` on the nodes in view.

        Only ``("morphology", key)`` and ``("channels", key)`` attributes can be set;
        a morphology attribute takes precedence if both exist.

        Raises:
            ValueError: If `key` is neither a morphology nor a channel attribute.
        """
        for category in ("morphology", "channels"):
            if (category, key) in self.node_attrs:
                return self._set_node_attrs(keys_values={(category, key): value}, inds=self._nodes_in_view)
        # TODO: support edge attributes (a draft based on `_set_edges` was left unfinished).
        raise ValueError(f"Key {key} not found or not mutable via `.set()`")

    def insert(self, channel):
        """Insert `channel` into the nodes in view.

        The channel's states, params and a ``channel.name`` flag (True) are written as
        ``("channels", name)`` node attributes. On first insertion the channel is
        registered in ``global_attrs["channels"]`` and its columns are created on the base.
        """
        channel_param_states = {**channel.states, **channel.params, channel.name: True}
        channel_setter = {("channels", k): v for k, v in channel_param_states.items()}
        if channel.name in self.base.global_attrs["channels"]:
            return self._set_node_attrs(keys_values=channel_setter)

        # Build new dicts instead of writing into the nested "channels" dict, which is
        # shared with other modules and would register the channel on all of them.
        channels = {**self.base.global_attrs["channels"], channel.name: channel}
        global_attrs = {**self.base.global_attrs, "channels": channels}
        updated_base = self._rebased(global_attrs=global_attrs)
        updated_view = replace(self, base=updated_base, global_attrs=global_attrs)
        return updated_view._init_node_attrs(channel_setter)

    def _set_flag(self, key):
        """Set boolean node attribute `key` to True on the nodes in view.

        If the base has no `key` yet, the column is created with False for all other nodes.
        """
        if key not in self.base.node_attrs:
            return self._init_node_attrs({key: True}, init_value=False)
        return self._set_node_attrs({key: True})

    def record(self, key):
        """Mark the nodes in view for recording `key` (node attribute ``("recordings", key)``)."""
        return self._set_flag(("recordings", key))

    def add_to_group(self, group):
        """Add the nodes in view to `group` (node attribute ``("groups", group)``)."""
        return self._set_flag(("groups", group))

    def stimulate(self, key, values):
        """Flag the nodes in view as stimulated via `key` and store `values` for each of them.

        The values are stored in ``global_attrs["externals"]`` under ``(key, node_id)``.
        """
        updated_view = self._set_flag(("externals", key))

        # Copy the externals dict rather than writing into the one shared with other modules.
        externals = dict(updated_view.base.global_attrs["externals"])
        externals.update({(key, int(idx)): values for idx in self._nodes_in_view})
        global_attrs = {**updated_view.base.global_attrs, "externals": externals}
        updated_base = updated_view._rebased(global_attrs=global_attrs)
        return replace(updated_view, base=updated_base, global_attrs=global_attrs)


# def connect(pre_view, post_view, synapse):
#     pre_view_base_global_attrs = pre_view.base.global_attrs.copy()
#     pre_view_base_global_attrs["synapses"][synapse.name] = synapse
    
#     pre_view_base_edge_attrs = pre_view.base.edge_attrs.copy()
#     pre_nodes = pre_view._nodes_in_view
#     post_nodes = post_view._nodes_in_view

#     for k, v in {**synapse.states, **synapse.params}.items():
#         pre_view_base_edge_attrs[k] = jnp.nan*jnp.stack([v]*pre_view._num_edges, axis=0)
#         pre_view_base_edge_attrs[k] = pre_view_base_edge_attrs[k].at[pre_view._edges_in_view].set(v)

#     updated_base = replace(pre_view.base, edge_attrs=pre_view_base_edge_attrs, global_attrs=pre_view_base_global_attrs)
#     updated_edge_attrs = updated_base._select_edges(pre_view._edges_in_view)

#     return 

In [502]:
class TestChannel:
    """Minimal channel stub with scalar states and parameters, as read by `TestModule.insert`."""

    def __init__(self, name=None):
        self.name = "TestChannel" if name is None else name
        self.states = dict(m=0.0, h=0.0)
        # Array-valued alternative that was tried:
        # {"E": jnp.array([0.0]), "g": jnp.array([[0.0, 0.0], [0.0, 0.0]])}
        self.params = dict(E=0.0, g=0.0)

class TestSynapse:
    """Minimal synapse stub with scalar states and parameters."""

    def __init__(self, name=None):
        self.name = "TestSynapse" if name is None else name
        self.states = dict(g=0.0)
        # Array-valued alternative that was tried: {"E": jnp.array([0.0]), "g": jnp.array([0.0])}
        self.params = dict(E=0.0, g=0.0)

In [503]:
SWC_PATH = "../jaxley/tests/swc_files/morph_ca1_n120.swc"

G = swc_to_nx(SWC_PATH)
jax_node_attrs, jax_edge_attrs, jax_global_attrs = nx_to_jax(G)

# TMP FIX FOR INDEXING, otherwise index drift for select: node ids are used as row
# positions, so shift the (presumably 1-based) SWC ids to start at 0. Edge endpoints
# refer to the same ids and are shifted too, so edges stay consistent with nodes.
jax_node_attrs["index"] = jax_node_attrs["index"] - 1
jax_edge_attrs["index_pre"] = jax_edge_attrs["index_pre"] - 1
jax_edge_attrs["index_post"] = jax_edge_attrs["index_post"] - 1

# Namespace node attributes: ("morphology", name) for SWC columns, ("index",) for the ids.
jax_node_attrs = {("morphology", k) if k != "index" else (k,): v for k, v in jax_node_attrs.items()}

cell = TestModule(jax_node_attrs, jax_edge_attrs, jax_global_attrs)
view = cell.select(nodes=jnp.array([1, 2, 4, 6, 7]))

# Chain the editing API on the view; each call returns a new view.
view = view.set("r", 20)
view = view.insert(TestChannel())
view = view.record("i")
view = view.add_to_group("test")
view = view.stimulate("i", jnp.array([0.0, 0.0]))

In [504]:
# Node attribute table of the edited view.
view_nodes = view.nodes
view_nodes

Unnamed: 0_level_0,channels,channels,channels,channels,channels,externals,groups,morphology,morphology,morphology,morphology,morphology,recordings
Unnamed: 0_level_1,E,TestChannel,g,h,m,i,test,id,r,x,y,z,i
1,0.0,True,0.0,0.0,0.0,True,True,1,20.0,1.85,-4.03,0.0,True
2,0.0,True,0.0,0.0,0.0,True,True,1,20.0,1.98,-6.0,0.0,True
4,0.0,True,0.0,0.0,0.0,True,True,1,20.0,2.17,-8.49,0.0,True
6,0.0,True,0.0,0.0,0.0,True,True,1,20.0,2.74,-9.85,0.0,True
7,0.0,True,0.0,0.0,0.0,True,True,1,20.0,3.0,-10.35,0.0,True


In [467]:
@jax.jit
def test():
    """Smoke test: `select(...).set(...)` must trace under `jax.jit`.

    Returns the view's node attributes so the jitted function has an output to check
    instead of discarding the result.
    """
    new_cell = cell.select(nodes=jnp.array([1, 2])).set("r", 20)
    return new_cell.node_attrs

jit_node_attrs = test()
assert bool(jnp.all(jit_node_attrs[("morphology", "r")] == 20))

# connect(cell.select(nodes=jnp.array([1, 2])), cell.select(nodes=jnp.array([3, 4])), TestSynapse())

# cell.select(nodes=jnp.array([1, 2, 3, 4, 5])).select(nodes=jnp.array([1, 3])).insert(TestChannel()) # this is slow!
# cell.select(nodes=jnp.array([2, 3, 4, 5])).record("v")
# cell.select(nodes=jnp.array([2, 3, 4, 5])).stimulate("i", jnp.array([0.0, 0.0, 0.0, 0.0]))
# cell.select(nodes=jnp.array([2, 3, 4, 5])).add_to_group("test")
