In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
import os

# XLA reads this when a device client is first created, so setting it before the
# jax import keeps the intent obvious.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

from jax import config

# Use double precision and run everything on the CPU for the rest of the notebook.
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

In [3]:
from __future__ import annotations
import jax
import jax.numpy as jnp
import pandas as pd

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap

from typing import Optional, List, Dict, Any, Union, Set, Tuple
from matplotlib.axes import Axes
from dataclasses import dataclass, field

In [4]:
def pandas_to_nx(
    node_attrs: pd.DataFrame, edge_attrs: pd.DataFrame, global_attrs: pd.Series
) -> nx.DiGraph:
    """Convert node_attrs, edge_attrs and global_attrs from pandas datatypes to a NetworkX DiGraph.

    Args:
        node_attrs: DataFrame with one row per node; the index holds the node ids and
            every column becomes a node attribute.
        edge_attrs: DataFrame indexed by a two-level (source, target) index; every
            column becomes an edge attribute. The index level names are ignored.
        global_attrs: Series whose items become graph-level attributes (``G.graph``).

    Returns:
        A directed graph containing every node of ``node_attrs`` (including nodes
        that have no edges), every edge of ``edge_attrs`` and the global attributes.
    """
    G = nx.DiGraph()
    # Add nodes first: nodes without edges are kept, and the graph's node order
    # follows the row order of node_attrs.
    G.add_nodes_from(node_attrs.to_dict(orient="index").items())
    # Unpack the (source, target) index directly so named, unnamed or empty
    # indices all work.
    G.add_edges_from(
        (src, dst, attrs)
        for (src, dst), attrs in edge_attrs.to_dict(orient="index").items()
    )
    G.graph.update(global_attrs.to_dict())
    return G


def nx_to_pandas(G: nx.DiGraph) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
    """Split a NetworkX DiGraph into pandas node, edge and global tables.

    Args:
        G: Input directed graph

    Returns:
        Tuple containing:
        - DataFrame of node attributes, indexed by node id
        - DataFrame of edge attributes, indexed by an unnamed (source, target) MultiIndex
        - Series of global graph attributes (``G.graph``)
    """
    node_table = pd.DataFrame.from_dict(
        {node: attrs for node, attrs in G.nodes(data=True)}, orient="index"
    )
    edge_table = (
        nx.to_pandas_edgelist(G)
        .set_index(["source", "target"])
        .rename_axis([None, None])  # drop the "source"/"target" level names
    )
    graph_series = pd.Series(G.graph)
    return node_table, edge_table, graph_series


def swc_to_nx(fname: str, num_lines: Optional[int] = None) -> nx.DiGraph:
    """Read a SWC morphology file into a NetworkX DiGraph.

    Each row is read as ``(i, id, x, y, z, r, parent)``. In the standard SWC layout
    the second column is the structure-type identifier; it is stored under the
    ``"id"`` node attribute.

    Args:
        fname: Path to the SWC file
        num_lines: Number of data rows to keep (all rows when None)

    Returns:
        A directed graph representing the morphology where:
        - Nodes are keyed by the first SWC column and have attributes: id, x, y, z, r (radius)
        - Edges point from parent to child; rows whose parent is -1 add no edge
    """
    # ndmin=2 keeps a file with a single sample as a (1, 7) table instead of a flat
    # row, so the per-row unpacking below always works.
    rows = np.loadtxt(fname, ndmin=2)[:num_lines]

    graph = nx.DiGraph()
    for i, type_id, x, y, z, r, p in rows.tolist():  # tolist: np.float64 -> float
        graph.add_node(int(i), **{"id": int(type_id), "x": x, "y": y, "z": z, "r": r})
        if p != -1:
            graph.add_edge(int(p), int(i))
    return graph


def nx_to_jax(G: nx.DiGraph) -> jax.tree_util.PyTree:
    """Convert a NetworkX DiGraph to Jax trees of stacked attributes.

    The per-node (and per-edge) attribute dicts are stacked leaf by leaf with
    ``jax.tree_util.tree_map``, so every node must carry the same attribute keys,
    and likewise every edge.

    Args:
        G: Input directed graph

    Returns:
        Tuple ``(jax_node_attrs, jax_edge_attrs, jax_global_attrs)``:
        - jax_node_attrs: one array per node attribute (rows in ``G.nodes`` order)
          plus the node ids under ``"index"``.
        - jax_edge_attrs: one array per edge attribute (rows in ``G.edges`` order)
          plus the endpoint ids under ``"index_pre"`` and ``"index_post"``.
        - jax_global_attrs: ``G.graph`` with every value converted to a jax array.
        Graphs without nodes or edges yield empty index arrays instead of raising.
    """

    def stack(*leaves):
        return jnp.array(leaves)

    node_records = list(G.nodes(data=True))
    if node_records:
        node_ids, jax_node_attrs = jax.tree_util.tree_map(stack, *node_records)
    else:
        # tree_map needs at least one tree, so the empty case is handled explicitly.
        node_ids, jax_node_attrs = jnp.array([], dtype=int), {}
    jax_node_attrs["index"] = jnp.asarray(node_ids)

    edge_records = list(G.edges(data=True))
    if edge_records:
        pre_ids, post_ids, jax_edge_attrs = jax.tree_util.tree_map(stack, *edge_records)
    else:
        pre_ids = post_ids = jnp.array([], dtype=int)
        jax_edge_attrs = {}
    jax_edge_attrs["index_pre"] = jnp.asarray(pre_ids)
    jax_edge_attrs["index_post"] = jnp.asarray(post_ids)

    jax_global_attrs = {k: jnp.array(v) for k, v in G.graph.items()}

    return jax_node_attrs, jax_edge_attrs, jax_global_attrs

def jax_to_nx(jax_node_attrs: jax.tree_util.PyTree, jax_edge_attrs: jax.tree_util.PyTree, jax_global_attrs: jax.tree_util.PyTree) -> nx.DiGraph:
    """Build a NetworkX DiGraph from Jax attribute trees.

    The conversion goes through pandas: ``jax_to_pandas`` produces the node/edge
    DataFrames and the global Series, which ``pandas_to_nx`` then assembles.

    Args:
        jax_node_attrs: Jax tree of node attributes
        jax_edge_attrs: Jax tree of edge attributes
        jax_global_attrs: Jax tree of global graph attributes

    Returns:
        A NetworkX DiGraph representing the morphology.
    """
    pandas_tables = jax_to_pandas(jax_node_attrs, jax_edge_attrs, jax_global_attrs)
    return pandas_to_nx(*pandas_tables)

def jax_to_pandas(jax_node_attrs: jax.tree_util.PyTree, jax_edge_attrs: jax.tree_util.PyTree, jax_global_attrs: jax.tree_util.PyTree) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
    """Convert a Jax tree to pandas datatypes.

    The node ids may be stored under ``("index",)`` (as written by ``pandas_to_jax``)
    or under ``"index"`` (as written by ``nx_to_jax``); both are accepted. Edge
    endpoints are read from ``"index_pre"`` and ``"index_post"``.

    Args:
        jax_node_attrs: Jax tree of node attributes
        jax_edge_attrs: Jax tree of edge attributes
        jax_global_attrs: Jax tree of global graph attributes

    Returns:
        Tuple containing:
        - DataFrame of node attributes, indexed by node id (id column removed)
        - DataFrame of edge attributes, indexed by an unnamed (pre, post) MultiIndex
        - Series of global graph attributes

    Raises:
        KeyError: if no node-id entry or edge-endpoint entries are present.
    """
    index_key = ("index",) if ("index",) in jax_node_attrs else "index"
    node_index = np.asarray(jax_node_attrs[index_key])
    # Leave the id column out up front rather than dropping it after construction.
    node_attrs_df = pd.DataFrame(
        {k: np.asarray(v).tolist() for k, v in jax_node_attrs.items() if k != index_key},
        index=node_index,
    )

    edge_index_keys = ("index_pre", "index_post")
    edge_index = pd.MultiIndex.from_arrays(
        [np.asarray(jax_edge_attrs[k]) for k in edge_index_keys]
    )
    edge_attrs_df = pd.DataFrame(
        {k: np.asarray(v).tolist() for k, v in jax_edge_attrs.items() if k not in edge_index_keys},
        index=edge_index,
    )

    global_attrs = pd.Series(jax_global_attrs)

    return node_attrs_df, edge_attrs_df, global_attrs

def pandas_to_jax(node_df: pd.DataFrame, edge_df: pd.DataFrame, global_attrs: pd.Series) -> jax.tree_util.PyTree:
    """Convert pandas datatypes to a Jax tree.

    Args:
        node_df: DataFrame of node attributes, indexed by node id
        edge_df: DataFrame of edge attributes, indexed by a two-level (pre, post) index
        global_attrs: Series of global graph attributes

    Returns:
        Tuple ``(jax_node_attrs, jax_edge_attrs, jax_global_attrs)``:
        - jax_node_attrs: one array per node column plus the node ids under ``("index",)``.
        - jax_edge_attrs: one array per edge column plus the endpoint ids under
          ``"index_pre"`` and ``"index_post"``.
        - jax_global_attrs: every global attribute converted to a jax array.
        Empty frames yield empty index arrays.
    """

    def stack(*leaves):
        return jnp.array(leaves)

    node_attrs = node_df.to_dict(orient="index")
    if node_attrs:
        node_ids, jax_node_attrs = jax.tree_util.tree_map(stack, *node_attrs.items())
    else:
        # tree_map needs at least one tree, so the empty case is handled explicitly.
        node_ids, jax_node_attrs = jnp.array([], dtype=int), {}
    jax_node_attrs[("index",)] = jnp.asarray(node_ids)

    # Each edge record is ((pre, post), attrs), so the stacked result is
    # ((pre_ids, post_ids), stacked_attrs).
    edge_attrs = edge_df.to_dict(orient="index")
    if edge_attrs:
        (pre_ids, post_ids), jax_edge_attrs = jax.tree_util.tree_map(stack, *edge_attrs.items())
    else:
        pre_ids = post_ids = jnp.array([], dtype=int)
        jax_edge_attrs = {}
    jax_edge_attrs["index_pre"] = jnp.asarray(pre_ids)
    jax_edge_attrs["index_post"] = jnp.asarray(post_ids)

    jax_global_attrs = {k: jnp.array(v) for k, v in global_attrs.items()}

    return jax_node_attrs, jax_edge_attrs, jax_global_attrs

In [94]:
from dataclasses import dataclass, field, replace
from copy import deepcopy

def tree_at(tree, inds):
    """Gather entries ``inds`` from every leaf of ``tree``, keeping the tree structure.

    Each leaf is indexed via ``leaf.at[inds].get()``, so leaves are expected to be
    jax arrays.
    """

    def _gather(leaf):
        return leaf.at[inds].get()

    return jax.tree.map(_gather, tree)

@dataclass
class TestModule:
    """Prototype module storing node, edge and global attributes as dicts of jax arrays.

    A module is either a *base* (``self.base is self``) that holds the attributes of
    all nodes and edges, or a *view* returned by :meth:`select` that holds a subset
    of rows and refers back to its base. :meth:`set` and :meth:`insert` never modify
    existing arrays or dicts: they build a new base and return a copy of ``self``
    pointing at it.
    """

    # Node attributes keyed by tuples, e.g. ("index",), ("morphology", "x"), ("channels", "m").
    node_attrs: dict[Any, Any]
    # Edge attributes; expected to contain "index_pre" and "index_post".
    edge_attrs: dict[Any, Any]
    # Graph-level attributes; "externals", "channels" and "synapses" are always present.
    global_attrs: dict[Any, Any]

    # Module owning the full attributes; a module created without a base is its own base.
    base: TestModule = None

    def __post_init__(self):
        if self.base is None:
            self.base = self

        # Shallow-copy before filling in the default groups so the caller's dict is
        # never mutated.
        self.global_attrs = dict(self.global_attrs)
        for group in ("externals", "channels", "synapses"):
            self.global_attrs.setdefault(group, {})

    @property
    def _nodes_in_view(self):
        # NOTE(review): these are node *ids*, but set/insert use them as row positions
        # in the base arrays; that only lines up when ids equal 0..N-1 — confirm.
        return self.node_attrs[("index",)]

    @property
    def _edges_in_view(self):
        # NOTE(review): this returns the source-node ids ("index_pre") of the edges in
        # view, yet select/set use it as edge row positions. Looks unintended — confirm.
        return self.edge_attrs["index_pre"]

    @property
    def _num_nodes(self):
        return len(self._nodes_in_view)

    @property
    def _num_edges(self):
        return len(self._edges_in_view)

    def __repr__(self):
        node_keys = list(self.node_attrs.keys())
        edge_keys = list(self.edge_attrs.keys())
        global_keys = list(self.global_attrs.keys())
        return f"TestModule(node_attrs={self._num_nodes}*{node_keys}, edge_attrs={self._num_edges}*{edge_keys}, global_attrs={global_keys})"

    def _select_nodes(self, inds):
        return tree_at(self.node_attrs, inds)

    def _select_edges(self, inds):
        return tree_at(self.edge_attrs, inds)

    def select(self, nodes=None, edges=None):
        """Return a view holding the node/edge rows at positions ``nodes`` / ``edges``.

        Positions are relative to this module's own arrays; ``None`` keeps the rows
        given by ``_nodes_in_view`` / ``_edges_in_view``.
        """
        node_inds = self._nodes_in_view if nodes is None else nodes
        edge_inds = self._edges_in_view if edges is None else edges

        node_attrs = self._select_nodes(node_inds)
        edge_attrs = self._select_edges(edge_inds)

        return replace(self, node_attrs=node_attrs, edge_attrs=edge_attrs, base=self.base)

    @property
    def pandas(self):
        return jax_to_pandas(self.node_attrs, self.edge_attrs, self.global_attrs)

    @property
    def nodes(self):
        return self.pandas[0]

    @property
    def edges(self):
        return self.pandas[1]

    @property
    def globals(self):
        return self.pandas[2]

    @property
    def recordings(self):
        df = pd.DataFrame.from_dict(self.globals["recordings"], orient="index").T.set_index("index")
        df.index.name = None
        return df

    @property
    def externals(self):
        df = pd.DataFrame.from_dict(self.globals["externals"], orient="index").T.set_index("index")
        df.index.name = None
        return df

    def set(self, key, value):
        """Return a copy of this module whose base has ``key`` set to ``value``.

        ``key`` is looked up in this order:
        1. as the node attribute ``("morphology", key)``;
        2. as the node attribute ``("channels", key)``;
        3. as an edge attribute;
        4. as a global attribute.

        Node and edge attributes are only overwritten at the rows in view. Global
        attributes are overwritten entirely (``value`` is broadcast).

        Raises:
            ValueError: if ``key`` is not found in any attribute group.
        """
        base = self.base

        for node_key in (("morphology", key), ("channels", key)):
            if node_key in base.node_attrs:
                node_attrs = base.node_attrs.copy()
                node_attrs[node_key] = node_attrs[node_key].at[self._nodes_in_view].set(value)
                # base=None makes the new base its own base (see __post_init__).
                updated_base = replace(base, node_attrs=node_attrs, base=None)
                return replace(self, base=updated_base)

        if key in base.edge_attrs:
            edge_attrs = base.edge_attrs.copy()
            edge_attrs[key] = edge_attrs[key].at[self._edges_in_view].set(value)
            updated_base = replace(base, edge_attrs=edge_attrs, base=None)
            return replace(self, base=updated_base)

        if key in base.global_attrs:
            global_attrs = base.global_attrs.copy()
            # `...` (unlike `:`) also works for 0-d arrays.
            global_attrs[key] = global_attrs[key].at[...].set(value)
            updated_base = replace(base, global_attrs=global_attrs, base=None)
            return replace(self, base=updated_base)

        raise ValueError(f"Key {key} not found in any attribute")

    def insert(self, channel):
        """Insert ``channel`` into the nodes in view.

        Registers ``channel`` under ``global_attrs["channels"][channel.name]``. For
        every entry of ``channel.states`` and ``channel.params``, it writes the value
        into the node attribute ``("channels", name)`` at the nodes in view. When that
        attribute is newly created, rows outside the view are filled with NaN.

        Returns:
            A copy of ``self`` whose base carries the new channel.
        """
        base = self.base

        global_attrs = dict(base.global_attrs)
        # Build a new channel registry instead of writing into the shared nested dict,
        # which would leak the channel into the old base and all of its views.
        global_attrs["channels"] = {**global_attrs["channels"], channel.name: channel}

        node_attrs = dict(base.node_attrs)
        for name, value in {**channel.states, **channel.params}.items():
            attr_key = ("channels", name)
            if attr_key not in node_attrs:
                node_attrs[attr_key] = jnp.stack([jnp.nan * value] * base._num_nodes, axis=0)
            node_attrs[attr_key] = node_attrs[attr_key].at[self._nodes_in_view].set(value)

        # base=None makes the new base its own base (see __post_init__).
        updated_base = replace(base, node_attrs=node_attrs, global_attrs=global_attrs, base=None)
        return replace(self, base=updated_base)
    
    # def record(self, key):
    #     if ("recordings", key) not in self.base.node_attrs:
    #         self.base.node_attrs[("recordings", key)] = jnp.stack([False]*self.base._num_nodes, axis=0)
    #     self.base.node_attrs[("recordings", key)] = self.base.node_attrs[("recordings", key)].at[self._nodes_in_view].set(True)

    # def stimulate(self, key, values):
    #     if ("externals", key) not in self.base.node_attrs:
    #         self.base.node_attrs[("externals", key)] = jnp.stack([False]*self.base._num_nodes, axis=0)
    #     self.base.node_attrs[("externals", key)] = self.base.node_attrs[("externals", key)].at[self._nodes_in_view].set(True)

    #     for idx in self._nodes_in_view:
    #         self.base.global_attrs["externals"][int(idx)] = values

    # def add_to_group(self, group):
    #     if ("groups", group) not in self.base.node_attrs:
    #         self.base.node_attrs[("groups", group)] = jnp.stack([False]*self.base._num_nodes, axis=0)
    #     self.base.node_attrs[("groups", group)] = self.base.node_attrs[("groups", group)].at[self._nodes_in_view].set(True)
    #     return self
    
# def connect(pre_view, post_view, synapse):
#     pre_view.base.global_attrs["synapses"][synapse.name] = synapse
#     for k, v in {**synapse.states, **synapse.params}.items():
#         pre_view.base.edge_attrs[k] = jnp.nan*jnp.stack([v]*pre_view._num_edges, axis=0)
#         pre_view.base.edge_attrs[k] = pre_view.base.edge_attrs[k].at[pre_view._edges_in_view].set(v)
#     return pre_view.base

In [95]:
class TestChannel:
    """Minimal stand-in for an ion channel: a name plus state and parameter dicts."""

    def __init__(self, name=None):
        # Fall back to the class name only when no name is given at all.
        self.name = "TestChannel" if name is None else name
        # Initial values of the channel's state variables.
        self.states = dict(m=0.0, h=0.0)
        # Channel parameters; note the differing shapes, (1,) and (2, 2).
        self.params = dict(
            E=jnp.array([0.0]),
            g=jnp.array([[0.0, 0.0], [0.0, 0.0]]),
        )

class TestSynapse:
    """Minimal stand-in for a synapse: a name plus state and parameter dicts."""

    def __init__(self, name=None):
        # Fall back to the class name only when no name is given at all.
        self.name = "TestSynapse" if name is None else name
        # Initial values of the synapse's state variables.
        self.states = dict(g=0.0)
        # Synapse parameters, each a length-1 array.
        self.params = dict(E=jnp.array([0.0]), g=jnp.array([0.0]))

In [96]:
G = swc_to_nx("../jaxley/tests/swc_files/morph_ca1_n120.swc")
jax_node_attrs, jax_edge_attrs, jax_global_attrs = nx_to_jax(G)

# Namespace the node attributes: the id column becomes ("index",) and every other
# attribute is grouped under "morphology", e.g. ("morphology", "x").
jax_node_attrs = {
    ((key,) if key == "index" else ("morphology", key)): values
    for key, values in jax_node_attrs.items()
}

cell = TestModule(jax_node_attrs, jax_edge_attrs, jax_global_attrs)
# Select rows 1 and 2 of the cell and insert a test channel there.
view = cell.select(nodes=jnp.array([1, 2])).insert(TestChannel())

In [99]:
@jax.jit
def test():
    """Trace ``select(...).set(...)`` under ``jax.jit``; the result is discarded."""
    selection = cell.select(nodes=jnp.array([1, 2]))
    selection.set("r", 20)


test()

# connect(cell.select(nodes=jnp.array([1, 2])), cell.select(nodes=jnp.array([3, 4])), TestSynapse())

# cell.select(nodes=jnp.array([1, 2, 3, 4, 5])).select(nodes=jnp.array([1, 3])).insert(TestChannel()) # this is slow!
# cell.select(nodes=jnp.array([2, 3, 4, 5])).record("v")
# cell.select(nodes=jnp.array([2, 3, 4, 5])).stimulate("i", jnp.array([0.0, 0.0, 0.0, 0.0]))
# cell.select(nodes=jnp.array([2, 3, 4, 5])).add_to_group("test")


In [867]:
@jax.tree_util.register_dataclass
@dataclass
class MorphTree:
    """MorphTree is a custom cataclass that holds the node and edge attributes of a morphology.
    
    MorphTree is used to store the node and edge attributes of a morphology / jaxley Module
    as a pytree, to allow for easy manipulation of the Module / morphology parameters 
    using jax transformations.

    MorphTree also allows for easy conversion to and from pandas DataFrames, and networkx 
    DiGraphs, as well as basic convenience functions for plotting and renaming or reordering
    of nodes and edges.
    """
    node_attrs: Dict[int, Dict[str, Any]]
    edge_attrs: Dict[Tuple[int, int], Dict[str, Any]]
    global_attrs: Dict[str, Any] = field(default_factory=dict)

    @property
    def nodes(self) -> jnp.ndarray:
        """Returns the node indices as a jax array."""
        return jnp.array(list(self.node_attrs.keys())).astype(int)

    @property
    def edges(self) -> jnp.ndarray:
        """Returns the edge indices as a jax array."""
        return jnp.array(list(self.edge_attrs.keys())).astype(int)

    def __repr__(self) -> str:
        n_nodes = len(self.node_attrs)
        n_edges = len(self.edge_attrs)

        node_keys = list(next(iter(self.node_attrs.values())).keys())
        if len(self.edge_attrs) > 0:
            edge_keys = list(next(iter(self.edge_attrs.values())).keys())
        else:
            edge_keys = []

        node_attrs = node_keys if len(self.node_attrs) > 0 else []
        edge_attrs = edge_keys if len(self.edge_attrs) > 0 else []
        return f"MorphTree(nodes={n_nodes}*{node_attrs}, edges={n_edges}*{edge_attrs}, global={list(self.global_attrs.keys())})"
    
    def __iter__(self):
        """Allows unpacking of MorphTree as: node_attrs, edge_attrs, global_attrs = *tree"""
        yield self.node_attrs
        yield self.edge_attrs
        yield self.global_attrs
    
    def node(self, i: int) -> Dict[str, Any]:
        """Returns the node attributes for the node with index i."""
        return self.node_attrs[i]
    
    def edge(self, i: int, j: int) -> Dict[str, Any]:
        """Returns the edge attributes for the edge between nodes i and j."""
        return self.edge_attrs[i, j]
    
    def to_nx(self) -> nx.DiGraph:
        """Returns the MorphTree as a networkx DiGraph."""
        G = nx.DiGraph()
        G.add_nodes_from(self.node_attrs.items())
        G.add_edges_from((i, j, d) for (i, j), d in self.edge_attrs.items())
        G.graph.update(self.global_attrs)
        return G
    
    @staticmethod
    def from_nx(G: nx.DiGraph) -> MorphTree:
        """Returns a MorphTree from a networkx DiGraph."""
        node_attrs = {n: G.nodes[n] for n in G.nodes}
        edge_attrs = {(i, j): G.edges[i, j] for i, j in G.edges}
        return MorphTree(node_attrs, edge_attrs, G.graph)
    
    def to_pandas(self, return_global_attrs: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Returns the MorphTree as a pandas DataFrame."""
        node_df = pd.DataFrame(self.node_attrs.values(), index=self.node_attrs.keys())
        edge_df = pd.DataFrame(self.edge_attrs.values(), index=self.edge_attrs.keys())
        edge_index = pd.MultiIndex.from_arrays(np.array(self.edges).T)
        edge_df = edge_df.set_index(edge_index)

        if return_global_attrs:
            return node_df, edge_df, pd.Series(self.global_attrs)
        return node_df, edge_df
    
    @staticmethod
    def from_pandas(node_df: pd.DataFrame, edge_df: pd.DataFrame, global_attrs: pd.Series = pd.Series()) -> MorphTree:
        """Returns a MorphTree from a pandas DataFrame."""
        node_attrs = node_df.to_dict(orient="index")
        edge_attrs = edge_df.to_dict(orient="index")
        return MorphTree(node_attrs, edge_attrs, global_attrs.to_dict())
    
    def plot(self, dims=(0,1), ax: Optional[Axes] = None, **kwargs: Any) -> Axes:
        """Uses networkx to plot the MorphTree.
        
        Args:
            dims: Dimensions to plot (0:x, 1:y, 2:z).
            ax: plt.Axes.
            **kwargs: kwargs for networkx.draw.

        Returns:
            The Axes object on which the MorphTree was plotted.
        """
        G = self.to_nx()
        pos = {}
        dims2axes = {0: "x", 1: "y", 2: "z"}
        for n, attr in G.nodes(data=True):
            if "x" in attr:  # assume y is also present
                pos[n] = (attr[dims2axes[dims[0]]], attr[dims2axes[dims[1]]])
        
        ax = ax if ax is not None else plt.gca()
        nx.draw(G, pos, with_labels=True, ax=ax, **kwargs)
        return ax

    def reindex_nodes(self, mapping: dict) -> MorphTree:
        """Reindexes the nodes of the MorphTree according to the mapping dictionary.
        
        Args:
            mapping: A dict mapping the old to new node indices.

        Returns:
            A new MorphTree with the nodes reindexed according to the mapping.
        """
        new_node_attrs = {mapping[i]: attrs for i, attrs in self.node_attrs.items()}
        
        new_edge_attrs = {}
        for (i, j), attrs in self.edge_attrs.items():
            new_edge_attrs[(mapping[i], mapping[j])] = attrs
            
        return MorphTree(new_node_attrs, new_edge_attrs, self.global_attrs)
    
    def reorder_tree(self, new_order: jnp.ndarray) -> MorphTree:
        """Reorders the nodes of the MorphTree according to the new order.
        
        Edges are flipped to ensure they are always in ascending order.
        
        Args:
            new_order: New node order. new_order[i] is the new index of node i.

        Returns:
            A new MorphTree with the nodes reordered according to the new order.
        """
        # TODO: check this does what I think it does, i.e. change the edge orientation
        # in order of appearance of the nodes in self.nodes.
        edges = np.array(self.edges)
        np_order = np.array(new_order)
        idx_i = np.where(edges[:,0] == np_order[:, None])[0]
        idx_j = np.where(edges[:,1] == np_order[:, None])[0]
        is_descending = ~(idx_i < idx_j)
        for (i,j) in edges[is_descending]:
            print(i,j)
            self.edge_attrs[j, i] = self.edge_attrs.pop((i, j))
        return self
    
    def subgraph(self, nodes: List[int]) -> MorphTree:
        """Returns a subset of nodes in the MorphTree.

        Edges are only included if both nodes are in the subgraph.
        
        Args:
            nodes: List of node indices to include in the subgraph.

        Returns:
            A new MorphTree containing only the specified nodes and their edges.
        """
        node_attrs_subset = {i: self.node_attrs[i] for i in nodes}
        edge_attrs_subset = {(i,j): attrs for (i,j), attrs in self.edge_attrs.items() if i in nodes and j in nodes}
        return MorphTree(node_attrs_subset, edge_attrs_subset, self.global_attrs)
