In [1]:
%load_ext autoreload
%autoreload 2

from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".8"

In [393]:
import jaxley as jx
from jaxley.channels import HH

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
import jax.numpy as jnp
from jaxley.utils.cell_utils import interpolate_xyz, loc_of_index

In [529]:
class WrappedModule:
    def __init__(self, module):
        self.nodes = module.nodes
        self.edges = module.edges
        self.branch_edges = module.branch_edges
        self.recordings = module.recordings
        self.xyzr = module.xyzr
        self.nseg = module.nseg
        self._in_view = self.nodes.index.to_numpy()
        self._scope = "global"
        self.groups = {}
        self.__class__.__name__ = module.__class__.__name__

        self._add_local_indices()
        self.view = self

    def _add_local_indices(self) -> pd.DataFrame:
        idx_cols = ["global_comp_index", "global_branch_index", "global_cell_index"]
        self.nodes.rename(columns={col.replace("global_", ""):col for col in idx_cols}, inplace=True)
        idcs = self.nodes[idx_cols]
        
        def reindex_a_by_b(df, a, b):
            df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1
            return df

        idcs = reindex_a_by_b(idcs, idx_cols[1], idx_cols[2])
        idcs = reindex_a_by_b(idcs, idx_cols[0], idx_cols[1:])
        idcs.columns = [col.replace("global", "local") for col in idx_cols]
        self.nodes = pd.concat([idcs, self.nodes], axis=1) 

    def _reformat_index(self, idx):
        idx = np.array([idx]) if isinstance(idx, (int, np.int64)) else idx
        idx = np.array(idx) if isinstance(idx, (list,range)) else idx
        idx = np.arange(self.nodes.index.max() + 1)[idx] if isinstance(idx, slice) else idx
        if isinstance(idx, str):
            assert idx == "all", "Only 'all' is allowed"
            idx = self.nodes.index.to_numpy() 
        assert isinstance(idx, np.ndarray), "Invalid type"
        assert idx.dtype == np.int64, "Invalid dtype"
        return idx.reshape(-1)

    def at(self, idx):
        idx = self._reformat_index(idx)
        new_indices = self._in_view[idx]
        return View(self, at=new_indices)

    def set(self, key, value):
        self.view.nodes.loc[self._in_view, key] = value

    def set_scope(self, scope):
        self._scope = scope
    
    def _at_level(self, level: str, idx):
        idx = self._reformat_index(idx)
        where = self.nodes[self._scope+f"_{level}_index"].isin(idx)
        inds = np.where(where)[0]
        return self.at(inds)

    def cell(self, idx):
        return self._at_level("cell", idx)
    
    def branch(self, idx):
        return self._at_level("branch", idx)
    
    def comp(self, idx):
        return self._at_level("comp", idx)
    
    def loc(self, at: float):
        comp_edges = np.linspace(0, 1, self.view.nseg+1)
        idx = np.digitize(at, comp_edges)
        view = self.comp(idx)
        return view
        
    def add_group(self, name, idx=None):
        idx = self._in_view if idx is None else idx
        idx = self._reformat_index(idx)
        self.view.groups[name] = idx

    def __getattr__(self, key):
        if key.startswith("__"):
            return super().__getattribute__(key)

        if key in self.groups:
            return self.at(self.groups[key])
        
    def show(self):
        return self.nodes.copy() # prevents this from being edited
    
    def __getitem__(self, idx):
        levels = ["network", "cell", "branch", "comp"]
        module = self.view.__class__.__name__.lower() # 
        module = "comp" if module == "compartment" else module
        
        children = levels[levels.index(module)+1:]
        idx = idx if isinstance(idx, tuple) else (idx,)
        view = self
        for i, child in enumerate(children):
            view = view._at_level(child, idx[i])
        return view
    
    def _iter_level(self, level):
        col = self._scope + f"_{level}_index"
        idxs = self.nodes[col].unique()
        for idx in idxs:
            yield self._at_level(level, idx)
    
    @property
    def cells(self):
        yield from self._iter_level("cell")
    
    @property
    def branches(self):
        yield from self._iter_level("branch")

    @property
    def comps(self):
        yield from self._iter_level("comp")    

    @property
    def shape(self):
        cols = ["global_cell_index", "global_branch_index", "global_comp_index"]
        raw_shape = self.nodes[cols].nunique().to_list()

        # ensure (net.shape -> dim=3, cell.shape -> dim=2, branch.shape -> dim=1, comp.shape -> dim=0)
        levels = ["network", "cell", "branch", "comp"]
        module = self.view.__class__.__name__.lower()
        module = "comp" if module == "compartment" else module
        shape = tuple(raw_shape[levels.index(module):])
        return shape
    
    # TODOs
    def _append_multiple_synapses(self, pre_rows, post_rows, synapse_type):
        pass
    
    def insert(self, channel):
        pass

    def record(self, state, verbose=False):
        pass

    def stimulate(self, state, verbose=False):
        pass

    def show(self):
        with View(net, self._at) as view:
            return view.nodes
    
    def set(self, key, value):
        self.nodes.loc[:,key] = value
    
    # def _at_level(self, level: str, idx):
    #     idx = self._reformat_index(idx)
    #     with View(net, self._at) as view:
    #         where = view.nodes[f"{level}_index"].isin(idx)
    #     inds = np.where(where)[0]
    #     return self.at(inds)

    # def cell(self, idx):
    #     return self._at_level("cell", idx)
    
    # def branch(self, idx):
    #     return self._at_level("branch", idx)
    
    # def comp(self, idx):
    #     return self._at_level("comp", idx)
        
    
test_net = WrappedModule(net)

In [407]:
class WrappedModule:
    def __init__(self, module):
        self.nodes = module.nodes
        self.edges = module.edges
        self.branch_edges = module.branch_edges
        self.recordings = module.recordings
        self._in_view = self.nodes.index.to_numpy()
        self._scope = "global"

        self._add_local_indices()
        self.view = View(self)

    def _add_local_indices(self) -> pd.DataFrame:
        idx_cols = ["global_comp_index", "global_branch_index", "global_cell_index"]
        self.nodes.rename(columns={col.replace("global_", ""):col for col in idx_cols}, inplace=True)
        idcs = self.nodes[idx_cols]
        
        def reindex_a_by_b(df, a, b):
            df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1
            return df

        idcs = reindex_a_by_b(idcs, idx_cols[1], idx_cols[2])
        idcs = reindex_a_by_b(idcs, idx_cols[0], idx_cols[1:])
        idcs.columns = [col.replace("global", "local") for col in idx_cols]
        self.nodes = pd.concat([idcs, self.nodes], axis=1) 

    def vis(self):
        pass

    def move(self, x,y,z, update_nodes=True):
        pass

    def move_to(self, x,y,z, update_nodes=True):
        pass

    def rotate(self, degrees, rotation_axis, update_nodes=True):
        pass


class View(WrappedModule):
    def __init__(self, pointer, at = None):
        # attrs with a static view
        self._scope = pointer._scope
        self.nseg = pointer.nseg
        self.view = pointer.view
        self.total_nbranches = pointer.total_nbranches
        
        # attrs affected by view
        self._in_view = pointer._in_view if at is None else at

        self.nodes = pointer.nodes.loc[self._in_view]
        self.branch_edges = pointer.branch_edges.loc[self._branch_edges_in_view]
        self.edges = pointer.edges.loc[self._edges_in_view]
        self.xyzr = self._view_of_xyzr(pointer)
        

        #TODO: 
        # self.recordings


        if len(self.nodes) == 0:
            raise ValueError("Nothing in view. Check your indices.")
        
    def _view_of_xyzr(self, pointer):
        prev_branch_inds = pointer._branches_in_view
        viewed_branch_inds = self._branches_in_view
        if prev_branch_inds is None:
            xyzr = pointer.xyzr
        else:
            branches2keep = np.isin(prev_branch_inds, viewed_branch_inds)
            branch_inds2keep = np.where(branches2keep)[0]
            xyzr = [pointer.xyzr[i] for i in branch_inds2keep]

        # Currently viewing with `.loc` will show the closest compartment
        # rather than the actual loc along the branch!
        viewed_nseg_for_branch = self.nodes.groupby("global_branch_index").size()
        incomplete_inds = np.where(viewed_nseg_for_branch != self.view.nseg)[0]
        incomplete_branch_inds = viewed_branch_inds[incomplete_inds]

        cond = self.nodes["global_branch_index"].isin(incomplete_branch_inds)
        interp_inds = self.nodes.loc[cond]
        local_inds_per_branch = interp_inds.groupby("global_branch_index")["local_comp_index"]
        locs = [loc_of_index(inds.to_numpy(), self.view.nseg) for _, inds in local_inds_per_branch]
        
        for i, loc in zip(incomplete_inds, locs):
            xyzr[i] = interpolate_xyz(loc, xyzr[i]).T
        return xyzr

    @property
    def _nodes_in_view(self):
        return self._in_view
    
    @property
    def _branch_edges_in_view(self):
        incl_branches = self.nodes["global_branch_index"].unique()
        pre = self.view.branch_edges["parent_branch_index"].isin(incl_branches)
        post = self.view.branch_edges["child_branch_index"].isin(incl_branches)
        viewed_branch_inds = self.view.branch_edges.index.to_numpy()[pre & post]
        return viewed_branch_inds
    
    @property
    def _edges_in_view(self):
        incl_comps = self.nodes["global_comp_index"].unique()
        pre = self.view.edges["global_pre_comp_index"].isin(incl_comps)
        post = self.view.edges["global_post_comp_index"].isin(incl_comps)
        viewed_edge_inds = self.view.edges.index.to_numpy()[pre & post]
        return viewed_edge_inds

    def __getattr__(self, name):
        # Delegate attribute access to the pointer if not found in View
        return getattr(self.pointer, name)
    
    def __enter__(self):
        return self
    
    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass

In [531]:
test_net.cell(0).nseg

4

In [524]:
# add test that asserts that every attr in view also has a corresponding attr in module
# apart from a few allowed exceptions
# this should trigger if new attrs are added to module that should potentially
# be included in view if they need to be accessed in a specific way

def test_view_attrs(module):
    exceptions = ["_scope", "_at", "view"]

    for name, attr in module.__dict__.items():
        if name not in exceptions:
            assert hasattr(View(module, np.array([0,1])), name), f"View missing attribute: {name}"


In [530]:
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])
net = jx.Network([cell]*5)
test_net = WrappedModule(net)

test_net.set_scope("local")
test_net.cell([0,2]).branch([0,1,2]).comp([0,3]).set("v", 90)

with test_net.cell([0,2]).branch([0,1,2]).comp([1,2]) as view:
    view.set("v", 10)
    view.set("capacitance", 3)

test_net.nodes

Unnamed: 0,local_comp_index,local_branch_index,local_cell_index,global_comp_index,global_branch_index,global_cell_index,length,radius,axial_resistivity,capacitance,v
0,0,0,0,0,0,0,10.0,1.0,5000.0,1.0,90.0
1,1,0,0,1,0,0,10.0,1.0,5000.0,1.0,-70.0
2,2,0,0,2,0,0,10.0,1.0,5000.0,1.0,-70.0
3,3,0,0,3,0,0,10.0,1.0,5000.0,1.0,90.0
4,0,1,0,4,1,0,10.0,1.0,5000.0,1.0,90.0
5,1,1,0,5,1,0,10.0,1.0,5000.0,1.0,-70.0
6,2,1,0,6,1,0,10.0,1.0,5000.0,1.0,-70.0
7,3,1,0,7,1,0,10.0,1.0,5000.0,1.0,90.0
8,0,2,0,8,2,0,10.0,1.0,5000.0,1.0,90.0
9,1,2,0,9,2,0,10.0,1.0,5000.0,1.0,-70.0
