In [435]:
%load_ext autoreload
%autoreload 2

# Configure JAX: enable 64-bit (float64) arrays and force the CPU backend.
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import os
# NOTE(review): presumably caps XLA's device-memory preallocation at 80%. It is set
# after `from jax import config` (likely still fine because no backend has been
# initialised yet) and probably has no effect with the CPU platform forced above —
# confirm whether it is still needed.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".8"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [436]:
import jaxley as jx
from jaxley.channels import HH

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
import jax.numpy as jnp

In [437]:
class DataFrameIndexer:
    """`.loc`-style accessor whose default row selection is `at`.

    A bare column key (or list of columns) is applied to the rows in `at`;
    an explicit ``(rows, cols)`` tuple bypasses the default row selection.
    """

    def __init__(self, df, at):
        self.df = df
        self.at = at

    def _split_key(self, key):
        # Tuple keys already carry (rows, cols); anything else names columns only.
        if isinstance(key, tuple):
            return key
        return self.at, key

    def __getitem__(self, key):
        row_sel, col_sel = self._split_key(key)
        return self.df.loc[row_sel, col_sel]

    def __setitem__(self, key, value):
        row_sel, col_sel = self._split_key(key)
        self.df.loc[row_sel, col_sel] = value

class DataFrameView:
    """Proxy exposing one of a view's DataFrames restricted to the rows in `at`.

    The backing table is ``view._<df_name>``. Attribute access is forwarded to
    the restricted frame (methods are wrapped so they run on the restricted
    frame at call time); ``.loc`` returns a `DataFrameIndexer` whose default
    rows are `at`. Subclasses override `at` to choose which rows are in view.
    """

    def __init__(self, view, df_name):
        self._view = view
        self._df_name = df_name

    @property
    def at(self):
        """Row labels in view; by default the view's node selection."""
        return self._view._at

    def _original_df(self):
        # The full, unrestricted backing DataFrame.
        return getattr(self._view, f"_{self._df_name}")

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If the proxy's own state is
        # missing (e.g. a half-constructed instance), fail cleanly instead of
        # recursing forever through `self._view`.
        if name in ("_view", "_df_name"):
            raise AttributeError(name)
        original_df = self._original_df()
        # Use the (overridable) `at` property; the former `self._at` does not
        # exist on this class and re-entered __getattr__.
        if name == "loc":
            return DataFrameIndexer(original_df, self.at)
        if callable(getattr(original_df, name)):
            def method(*args, **kwargs):
                return getattr(original_df.loc[self.at], name)(*args, **kwargs)
            return method
        return getattr(original_df.loc[self.at], name)

    def __getitem__(self, key):
        """Return column(s) `key` of the rows in view."""
        original_df = self._original_df()
        # Column indexing rather than getattr: works for list keys and for
        # column names that collide with DataFrame attributes.
        return original_df.loc[self.at][key]
    
class NodeView(DataFrameView):
    """Selection-restricted proxy onto a view's `_nodes` table."""

    def __init__(self, _view):
        super().__init__(_view, df_name="nodes")


class EdgeView(DataFrameView):
    """Proxy onto the view's `_edges` table, restricted to edges whose pre and
    post compartments are both among the selected nodes."""

    def __init__(self, _view):
        super().__init__(_view, "edges")

    @property
    def at(self):
        """Labels of edges with both endpoints in view."""
        edges = self._view._edges
        # The edge table refers to compartments via `global_*_comp_index`, so
        # match against the *global* index of the selected nodes. The
        # scope-dependent `comps` accessor yields local indices in "local"
        # scope, which would select the wrong edges.
        selected = self._view._nodes.loc[self._view._at, "global_comp_index"]
        comp_inds = np.unique(selected)
        pre = edges["global_pre_comp_index"].isin(comp_inds)
        post = edges["global_post_comp_index"].isin(comp_inds)
        return edges.index[pre & post]
        
class BranchEdgeView(DataFrameView):
    """Proxy onto the view's `_branch_edges` table, restricted to edges whose
    parent and child branches are both among the selected nodes' branches."""

    def __init__(self, view):
        super().__init__(view, "branch_edges")

    @property
    def at(self):
        """Labels of branch edges with both parent and child branch in view."""
        branch_edges = self._view._branch_edges
        # NOTE(review): assumes parent/child_branch_index are *global* branch
        # indices — TODO confirm. The scope-dependent `branches` accessor yields
        # per-cell local indices in "local" scope, so match against the global
        # branch index of the selected nodes instead.
        selected = self._view._nodes.loc[self._view._at, "global_branch_index"]
        branch_inds = np.unique(selected)
        pre = branch_edges["parent_branch_index"].isin(branch_inds)
        post = branch_edges["child_branch_index"].isin(branch_inds)
        return branch_edges.index[pre & post]
        

class RecordingView(DataFrameView):
    """Proxy onto the view's `_recordings` table, restricted to recordings of
    compartments among the selected nodes."""

    def __init__(self, view):
        super().__init__(view, "recordings")

    @property
    def at(self):
        """Labels of recordings whose `rec_index` is a selected compartment."""
        recordings = self._view._recordings
        # NOTE(review): assumes `rec_index` stores the *global* compartment
        # index — TODO confirm. Matching it against the scope-dependent `comps`
        # accessor would compare against local indices in "local" scope.
        selected = self._view._nodes.loc[self._view._at, "global_comp_index"]
        comp_inds = np.unique(selected)
        return recordings.index[recordings["rec_index"].isin(comp_inds)]

In [375]:
class View:
    """Selection state over a (wrapped) module's node and edge tables.

    Stores the module's DataFrames, the node labels currently in view (`_at`),
    named selections (`groups`), and the index scope ("global" or "local") read
    by the `comps` / `branches` / `cells` accessors.
    """

    def __init__(self, module):
        self._nodes = module._nodes
        self._edges = module._edges
        self._branch_edges = module._branch_edges
        # self._recordings = module._recordings
        self.groups = {}
        # Start with every node label of the module selected.
        self._at = np.array(module._nodes.index)
        self.scope = "global"

        self._add_local_indices()

    def _add_local_indices(self) -> None:
        """Prefix the index columns with `global_` and append `local_*` columns.

        Local branch indices restart at 0 within every cell; local comp indices
        restart at 0 within every branch. The result is stored on `self._nodes`;
        the module's own DataFrame is left untouched.
        """
        idx_cols = ["comp_index", "branch_index", "cell_index"]
        global_cols = ["global_" + col for col in idx_cols]
        local_cols = ["local_" + col for col in idx_cols]

        # Not in place: `self._nodes` is still the module's DataFrame here, and an
        # in-place rename would strip its original column names.
        nodes = self._nodes.rename(columns=dict(zip(idx_cols, global_cols)))
        # Drop stale local columns so wrapping an already-indexed table does not
        # produce duplicate columns.
        nodes = nodes.drop(columns=local_cols, errors="ignore")
        # Explicit copy: the rank assignment below must write into an
        # independent frame, not a slice (avoids SettingWithCopyWarning).
        idcs = nodes[global_cols].copy()

        def reindex_a_by_b(df, a, b):
            # Dense rank of column `a` within each group of `b`, starting at 0.
            df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1
            return df

        idcs = reindex_a_by_b(idcs, "global_branch_index", ["global_cell_index"])
        idcs = reindex_a_by_b(idcs, "global_comp_index", ["global_cell_index", "global_branch_index"])
        idcs.columns = local_cols
        self._nodes = pd.concat([nodes, idcs], axis=1)

        # TODO: add global2local to edges
        # def global2local(cell_idx, branch_idx, comp_idx):
        #     global_idcs = [(idcs["global_"+col]==idx) for col, idx in zip(idx_cols, [comp_idx, branch_idx, cell_idx])]
        #     idx = self._nodes.index[np.all(global_idcs, axis=0)]
        #     return self._nodes.loc[idx, ["local_"+col for col in idx_cols]]

    def set_scope(self, scope):
        """Switch between "global" and "local" indices; also resets the selection."""
        assert scope in ("global", "local"), f"Invalid scope: {scope}"
        self.reset()
        self.scope = scope

    @property
    def comps(self):
        """Compartment indices (in the current scope) of the nodes in view."""
        return self.nodes[self.scope + "_comp_index"]

    @property
    def branches(self):
        """Branch indices (in the current scope) of the nodes in view."""
        return self.nodes[self.scope + "_branch_index"]

    @property
    def cells(self):
        """Cell indices (in the current scope) of the nodes in view."""
        return self.nodes[self.scope + "_cell_index"]

    @property
    def nodes(self):
        return NodeView(self)

    @property
    def edges(self):
        return EdgeView(self)

    @property
    def branch_edges(self):
        return BranchEdgeView(self)

    @property
    def recordings(self):
        return RecordingView(self)

    def as_group(self, group_name):
        """Store the current selection under `group_name`."""
        self.groups[group_name] = self._at

    def reset(self):
        """Select every node again."""
        self._at = np.array(self._nodes.index)

    def __repr__(self) -> str:
        cells = np.unique(self.cells)
        branches = np.unique(self.branches)
        comps = np.unique(self.comps)
        return f"View of cells: {cells}, branches: {branches}, comps: {comps}"

    def _to_str(self):
        """Selected compartments grouped by cell and branch, as a string table."""
        df = self.nodes.loc[[self.scope + "_cell_index", self.scope + "_branch_index", self.scope + "_comp_index"]]
        df.columns = ["cell", "branch", "comps"]
        df = df.astype(str)
        df = df.groupby(["cell", "branch"]).agg(lambda x: ", ".join(x))
        return df.__str__()

class ModuleWrapper:
    """Wraps a jaxley module and routes table access through a selectable `View`.

    Selection methods (`at`, `cell`, `branch`, `comp`, `loc`) narrow the view's
    node selection in place and return `self`, so they can be chained. Groups
    saved with `add_group` are re-selected by attribute access
    (``wrapper.<group_name>``); any other unknown attribute is forwarded to the
    wrapped module.
    """

    def __init__(self, module):
        self.module = module
        self._nodes = module.nodes
        self._edges = module.edges
        self._branch_edges = module.branch_edges
        self._global_prefix = ""  # TODO: add set_scope global or local
        self._xyzr = module.xyzr
        self._recordings = module.recordings
        self.view = View(self)
        self._nseg = module.nseg

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Read state via __dict__ so a
        # half-initialised instance raises AttributeError instead of recursing
        # forever through `self.view` / `self.module`.
        view = self.__dict__.get("view")
        if view is not None and name in view.groups:
            self.reset_view()
            # Re-select the stored group (was `self._at`, which does not exist).
            return self.at(view.groups[name])
        if "module" not in self.__dict__:
            raise AttributeError(name)
        return getattr(self.module, name)

    @property
    def nodes(self):
        return self.view.nodes

    @property
    def edges(self):
        return self.view.edges

    @property
    def branch_edges(self):
        return self.view.branch_edges

    # TODO: port the remaining module properties to view-aware versions:
    #   recordings (-> self.view.recordings), xyzr (per selected branch),
    #   externals, external_inds, indices_set_by_trainables, trainable_params.

    def _reformat_index(self, idx):
        """Normalise a user index into a flat integer numpy array.

        Accepts an integer scalar (Python or numpy), a list/tuple/range, a slice
        (applied to ``arange(max node label + 1)``), an integer ndarray, or the
        string "all" (every node label in view).
        """
        if isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)):
            idx = np.array([idx])
        elif isinstance(idx, (list, tuple, range)):
            # An empty list would otherwise become a float64 array.
            idx = np.array(idx) if len(idx) > 0 else np.array([], dtype=np.int64)
        elif isinstance(idx, slice):
            idx = np.arange(self.nodes.index.max() + 1)[idx]
        elif isinstance(idx, str):
            assert idx == "all", "Only 'all' is allowed"
            idx = self.nodes.index.to_numpy()
        assert isinstance(idx, np.ndarray), "Invalid type"
        # Any integer dtype: arrays of Python ints can be int32 on some platforms
        # (e.g. Windows with NumPy < 2).
        assert np.issubdtype(idx.dtype, np.integer), "Invalid dtype"
        return idx.reshape(-1)

    def at(self, idx):
        """Keep only the nodes at positions `idx` of the current selection."""
        idx = self._reformat_index(idx)
        self.view._at = self.view._at[idx]
        return self

    def set(self, key, value):
        """Assign `value` to column(s) `key` for the nodes currently in view."""
        self.nodes.loc[key] = value

    def _at_level(self, level: str, idx):
        """Keep only nodes whose ``<scope>_<level>_index`` is in `idx`."""
        idx = self._reformat_index(idx)
        where = self.nodes[self.view.scope + f"_{level}_index"]. isin(idx)
        # Positions within the current selection, as expected by `at`.
        inds = np.where(where)[0]
        return self.at(inds)

    def cell(self, idx):
        return self._at_level("cell", idx)

    def branch(self, idx):
        return self._at_level("branch", idx)

    def comp(self, idx):
        return self._at_level("comp", idx)

    def loc(self, pos):
        """Select the compartment(s) at relative position(s) `pos` along each branch.

        `pos` is a number or sequence in [0, 1], or "all". Each position maps to
        the compartment whose segment of ``linspace(0, 1, nseg + 1)`` contains
        it; 1.0 maps to the last compartment.
        """
        if isinstance(pos, str):
            assert pos == "all", "Invalid position"
            return self
        pos = np.atleast_1d(np.asarray(pos, dtype=float))
        assert np.all((pos >= 0) & (pos <= 1)), "Position must be in [0, 1]"
        comp_edges = np.linspace(0, 1, self._nseg + 1)
        # np.digitize puts pos == 1.0 past the last edge (bin nseg), which is
        # not a compartment; clamp it into the last one.
        comp_idx = np.minimum(np.digitize(pos, comp_edges) - 1, self._nseg - 1)
        # Local comp index: the same relative position in every selected branch.
        where = self.nodes["local_comp_index"].isin(comp_idx)
        inds = np.where(where)[0]
        return self.at(inds)

    def show(self):
        """Return a copy of the selected nodes and reset the selection."""
        nodes = self.nodes.copy()
        self.reset_view()
        return nodes

    def add_group(self, group_name):
        """Save the current selection; re-select it via ``wrapper.<group_name>``."""
        self.view.as_group(group_name)
        return self

    def reset_view(self):
        """Select every node again."""
        self.view.reset()
        return self
    
# Toy morphology: a 4-compartment branch, a 5-branch cell built from it
# (parents=[-1, 0, 0, 1, 1]), and a network of 5 copies of that cell.
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])
net = jx.Network([cell]*5)
net.compute_xyz()
# net.cell(0).branch(0).comp(0).record("v")

# Wrap both the single cell and the network to try out the selection API.
test_cell = ModuleWrapper(cell)
test_net = ModuleWrapper(net)

In [367]:
# Local scope: branch indices restart at 0 per cell, comp indices per branch.
# Select cells 0 and 2, their branches 0-2 and compartments 0 and 3, then print
# the selected compartments grouped by cell and branch.
test_net.view.set_scope("local")
print(test_net.cell([0,2]).branch([0,1,2]).comp([0,3]).view._to_str())

            comps
cell branch      
0    0       0, 3
     1       0, 3
     2       0, 3
2    0       0, 3
     1       0, 3
     2       0, 3


In [None]:
def _reformat_index(self, idx):
    idx = np.array([idx]) if isinstance(idx, (int, np.int64)) else idx
    idx = np.array(idx) if isinstance(idx, (list,range)) else idx
    idx = np.arange(self.nodes.index.max() + 1)[idx] if isinstance(idx, slice) else idx
    if isinstance(idx, str):
        assert idx == "all", "Only 'all' is allowed"
        idx = self.nodes.index.to_numpy() 
    assert isinstance(idx, np.ndarray), "Invalid type"
    assert idx.dtype == np.int64, "Invalid dtype"
    return idx.reshape(-1)


In [691]:
# Rebuild the toy network (4 compartments per branch, 5 branches per cell,
# 5 cells) for the context-manager prototype below.
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])
net = jx.Network([cell]*5)

# Context manager `View(module, indices)`: exposes `module.nodes` restricted to
# `indices` as `view.nodes` and writes the edited rows back into `module.nodes`
# when the block exits, e.g.
#     with View(module, indices) as view: view.nodes.loc[:, "radius"] = 10
# changes the radius only of the compartments in `indices`.

class View:
    """Context manager exposing `module.nodes` restricted to `indices`.

    `view.nodes` is an independent copy of the selected rows. When the `with`
    block exits normally, the (possibly edited) rows are written back into
    `module.nodes`; if the block raises, the edits are discarded and the
    exception propagates. Unknown attributes are forwarded to `module`.
    """

    def __init__(self, module, indices):
        self.module = module
        self.indices = indices
        # Explicit copy so edits never touch `module.nodes` before __exit__.
        self.nodes = module.nodes.loc[indices].copy()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Commit only on success, so a failing block cannot leave the module
        # half-updated. Returning False never suppresses the exception.
        if exc_type is None:
            self.module.nodes.loc[self.indices] = self.nodes
        return False

    def __getattr__(self, name):
        # Look `module` up via __dict__ so a half-initialised instance raises
        # AttributeError instead of recursing forever.
        try:
            module = self.__dict__["module"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(module, name)
    

class WrappedModule:
    """Prototype wrapper tracking a node selection (`_at`) over a module's nodes.

    `at` narrows the selection (positions into the current selection) and
    returns `self`; `set` writes into the module's node table for the selected
    rows only; `show` returns the selected rows via the `View` context manager.
    """

    def __init__(self, module):
        self.module = module
        # Shared references: `set` writes straight into the module's tables.
        self.nodes = module.nodes
        self.edges = module.edges
        self.branch_edges = module.branch_edges
        self.recordings = module.recordings
        # Selected node labels; everything is selected initially.
        self._at = np.array(module.nodes.index)

    def _reformat_index(self, idx):
        """Normalise a user index into a flat integer numpy array.

        Accepts an integer scalar (Python or numpy), a list/tuple/range, a slice
        (applied to ``arange(max node label + 1)``), an integer ndarray, or the
        string "all" (every node label).
        """
        if isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)):
            idx = np.array([idx])
        elif isinstance(idx, (list, tuple, range)):
            # An empty list would otherwise become a float64 array.
            idx = np.array(idx) if len(idx) > 0 else np.array([], dtype=np.int64)
        elif isinstance(idx, slice):
            idx = np.arange(self.nodes.index.max() + 1)[idx]
        elif isinstance(idx, str):
            assert idx == "all", "Only 'all' is allowed"
            idx = self.nodes.index.to_numpy()
        assert isinstance(idx, np.ndarray), "Invalid type"
        # Any integer dtype: arrays of Python ints can be int32 on some platforms
        # (e.g. Windows with NumPy < 2).
        assert np.issubdtype(idx.dtype, np.integer), "Invalid dtype"
        return idx.reshape(-1)

    def at(self, idx):
        """Keep only the nodes at positions `idx` of the current selection."""
        idx = self._reformat_index(idx)
        self._at = self._at[idx]
        return self

    def show(self):
        """Return the selected rows of the wrapped module's nodes."""
        # Use the wrapped module, not a notebook-global `net`.
        with View(self.module, self._at) as view:
            return view.nodes

    def set(self, key, value):
        """Assign `value` to column(s) `key` for the selected nodes only."""
        # Restrict the write to the selection; `loc[:, key]` overwrote every row.
        self.nodes.loc[self._at, key] = value

    # TODO: cell/branch/comp selection (an `_at_level` helper matching
    # f"{level}_index" against `idx`) is not ported to this prototype yet.
        
    
# Wrap the network built at the top of this cell.
test_net = WrappedModule(net)

In [445]:
class WrappedModule:
    """Prototype wrapper whose selections return new `View` objects.

    On construction the node table is copied with its index columns renamed to
    `global_*` and per-cell `local_*` indices prepended. `set_scope` chooses
    which of the two the `cell` / `branch` / `comp` selectors match against.
    """

    def __init__(self, module):
        self.nodes = module.nodes
        self.edges = module.edges
        self.branch_edges = module.branch_edges
        self.recordings = module.recordings
        # Node labels in view (all of them for the root module).
        self._in_view = self.nodes.index.to_numpy()
        self._scope = "global"

        self._add_local_indices()
        # The root is its own view; writes through `set` land in `self.view.nodes`.
        self.view = self

    def _add_local_indices(self) -> None:
        """Rename index columns to `global_*` and prepend per-cell `local_*` columns.

        Local branch indices restart at 0 within every cell; local comp indices
        restart at 0 within every branch. The module's DataFrame is not mutated.
        """
        idx_cols = ["global_comp_index", "global_branch_index", "global_cell_index"]
        local_cols = [col.replace("global", "local") for col in idx_cols]
        # Not in place: `self.nodes` is still the module's own DataFrame here.
        nodes = self.nodes.rename(columns={col.replace("global_", ""): col for col in idx_cols})
        # Drop stale local columns so re-wrapping does not duplicate them.
        nodes = nodes.drop(columns=local_cols, errors="ignore")
        # Explicit copy: the rank assignment below must not write into a slice.
        idcs = nodes[idx_cols].copy()

        def reindex_a_by_b(df, a, b):
            # Dense rank of column `a` within each group of `b`, starting at 0.
            df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1
            return df

        idcs = reindex_a_by_b(idcs, idx_cols[1], idx_cols[2])
        idcs = reindex_a_by_b(idcs, idx_cols[0], idx_cols[1:])
        idcs.columns = local_cols
        self.nodes = pd.concat([idcs, nodes], axis=1)

    def _reformat_index(self, idx):
        """Normalise a user index into a flat integer numpy array.

        Accepts an integer scalar (Python or numpy), a list/tuple/range, a slice
        (applied to ``arange(max node label + 1)``), an integer ndarray, or the
        string "all" (every node label in view).
        """
        if isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)):
            idx = np.array([idx])
        elif isinstance(idx, (list, tuple, range)):
            # An empty list would otherwise become a float64 array.
            idx = np.array(idx) if len(idx) > 0 else np.array([], dtype=np.int64)
        elif isinstance(idx, slice):
            idx = np.arange(self.nodes.index.max() + 1)[idx]
        elif isinstance(idx, str):
            assert idx == "all", "Only 'all' is allowed"
            idx = self.nodes.index.to_numpy()
        assert isinstance(idx, np.ndarray), "Invalid type"
        # Any integer dtype: arrays of Python ints can be int32 on some platforms
        # (e.g. Windows with NumPy < 2).
        assert np.issubdtype(idx.dtype, np.integer), "Invalid dtype"
        return idx.reshape(-1)

    def at(self, idx):
        """Return a `View` of the nodes at positions `idx` of this selection."""
        idx = self._reformat_index(idx)
        new_indices = self._in_view[idx]
        return View(self, at=new_indices)

    def set(self, key, value):
        """Assign `value` to column(s) `key` of the root node table, rows in view."""
        self.view.nodes.loc[self._in_view, key] = value

    def set_scope(self, scope):
        """Choose whether selectors match "global" or "local" indices."""
        assert scope in ("global", "local"), f"Invalid scope: {scope}"
        self._scope = scope

    def _at_level(self, level: str, idx):
        """View of the nodes whose ``<scope>_<level>_index`` is in `idx`."""
        idx = self._reformat_index(idx)
        where = self.nodes[self._scope + f"_{level}_index"].isin(idx)
        # Positions within this selection, as expected by `at`.
        inds = np.where(where)[0]
        return self.at(inds)

    def cell(self, idx):
        return self._at_level("cell", idx)

    def branch(self, idx):
        return self._at_level("branch", idx)

    def comp(self, idx):
        return self._at_level("comp", idx)

class View(WrappedModule):
    """Restricted view of a `WrappedModule` (or of another `View`).

    `view` is copied from the parent, so writes through the inherited `set` land
    in the root's node table. `nodes`, `edges`, `branch_edges` and `recordings`
    are the parts of the parent's tables that belong to the selected nodes.
    """

    def __init__(self, pointer, at = None):
        self.view = pointer.view
        self._in_view = pointer._in_view if at is None else at
        self._scope = pointer._scope

        self.nodes = pointer.nodes.loc[self._in_view]
        self.branch_edges = pointer.branch_edges.loc[self._branch_edges_in_view]
        self.edges = pointer.edges.loc[self._edges_in_view]
        # Previously missing, so a View lacked the module's `recordings` attribute.
        self.recordings = self._recordings_in_view(pointer.recordings)

    @property
    def _nodes_in_view(self):
        return self._in_view

    @property
    def _branch_edges_in_view(self):
        """Labels of root branch edges whose parent and child branch are in view."""
        incl_branches = self.nodes["global_branch_index"].unique()
        branch_edges = self.view.branch_edges
        pre = branch_edges["parent_branch_index"].isin(incl_branches)
        post = branch_edges["child_branch_index"].isin(incl_branches)
        return branch_edges.index[(pre & post).to_numpy()].to_numpy()

    @property
    def _edges_in_view(self):
        """Labels of root edges whose pre and post compartments are in view."""
        incl_comps = self.nodes["global_comp_index"].unique()
        edges = self.view.edges
        pre = edges["global_pre_comp_index"].isin(incl_comps)
        post = edges["global_post_comp_index"].isin(incl_comps)
        return edges.index[(pre & post).to_numpy()].to_numpy()

    def _recordings_in_view(self, recordings):
        """Recordings restricted to compartments in view.

        Tables without a `rec_index` column (e.g. an empty recordings frame) are
        returned unchanged.
        """
        # NOTE(review): assumes `rec_index` holds the global compartment index —
        # TODO confirm against jaxley's recordings table.
        if "rec_index" not in getattr(recordings, "columns", ()):
            return recordings
        keep = recordings["rec_index"].isin(self.nodes["global_comp_index"])
        return recordings.loc[keep]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Writes already went to the root table via `set`; nothing to commit.
        pass

In [446]:
# Test: every instance attribute of the module must also exist on a View of it,
# apart from a few allowed exceptions. It should fail when new attributes are
# added to the module that may need view-specific handling (e.g. filtering by
# the selected indices) before they can be accessed through a View.

def test_view_attrs(module):
    exceptions = ["_scope", "_at", "view"]

    for name, attr in module.__dict__.items():
        if name not in exceptions:
            assert hasattr(View(module, np.array([0,1])), name), f"View missing attribute: {name}"


In [447]:
# Fresh toy network (4 compartments per branch, 5 branches per cell, 5 cells)
# wrapped with the index-scoped WrappedModule defined above.
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])
net = jx.Network([cell]*5)
test_net = WrappedModule(net)

# Select with local indices (cells 0 and 2, branches 0-2 of each, compartments
# 0 and 3 of each branch) and set column "v" there.
test_net.set_scope("local")
test_net.cell([0,2]).branch([0,1,2]).comp([0,3]).set("v", 90)

# Same cells/branches, compartments 1 and 2, used as a context manager; `set`
# writes into the root node table.
with test_net.cell([0,2]).branch([0,1,2]).comp([1,2]) as view:
    view.set("v", 10)
    view.set("capacitance", 3)

test_net.nodes

Unnamed: 0,local_comp_index,local_branch_index,local_cell_index,global_comp_index,global_branch_index,global_cell_index,length,radius,axial_resistivity,capacitance,v
0,0,0,0,0,0,0,10.0,1.0,5000.0,1.0,90.0
1,1,0,0,1,0,0,10.0,1.0,5000.0,3.0,10.0
2,2,0,0,2,0,0,10.0,1.0,5000.0,3.0,10.0
3,3,0,0,3,0,0,10.0,1.0,5000.0,1.0,90.0
4,0,1,0,4,1,0,10.0,1.0,5000.0,1.0,90.0
...,...,...,...,...,...,...,...,...,...,...,...
95,3,3,4,95,23,4,10.0,1.0,5000.0,1.0,-70.0
96,0,4,4,96,24,4,10.0,1.0,5000.0,1.0,-70.0
97,1,4,4,97,24,4,10.0,1.0,5000.0,1.0,-70.0
98,2,4,4,98,24,4,10.0,1.0,5000.0,1.0,-70.0
