In [1]:
%load_ext autoreload
%autoreload 2

from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".8"

In [2]:
import jaxley as jx
from jaxley.channels import HH

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
import jax.numpy as jnp
from jaxley.utils.cell_utils import interpolate_xyz, loc_of_index
from copy import deepcopy
from jaxley.utils.cell_utils import v_interp

from jaxley.connect import connect
from jaxley.synapses import IonotropicSynapse, TestSynapse
from jaxley.utils.misc_utils import concat_and_ignore_empty


In [3]:
class DummyModule:
    """Prototype of a jaxley module whose selections are `View`s.

    Wraps an existing jaxley module (e.g. a `jx.Network`) and aliases its state.
    The node table gets additional `local_*` index columns so that cells, branches
    and compartments can be selected either by global index or by index relative
    to their parent level (see `set_scope`). Every selection (`at`, `cell`,
    `branch`, `comp`, `loc`, `[...]`, channel/synapse/group names) returns a
    `View`; mutating methods always write to `self.base`.
    """

    def __init__(self, module):
        self.nodes = module.nodes
        edge_levels = ["pre_comp", "pre_branch", "pre_cell", "post_comp", "post_branch", "post_cell"]
        edge_index_cols = [f"{scope}_{lvl}_index" for lvl in edge_levels for scope in ["global", "local"]]
        self.edges = pd.DataFrame(columns=edge_index_cols + ["pre_locs", "post_locs", "type", "type_ind"])

        self.branch_edges = module.branch_edges
        self.recordings = module.recordings
        self.synapses = module.synapses
        self.synapse_param_names = module.synapse_param_names
        self.synapse_state_names = module.synapse_state_names
        self.synapse_names = module.synapse_names
        self.channels = module.channels
        self.membrane_current_names = module.membrane_current_names
        self.trainable_params = module.trainable_params
        self.indices_set_by_trainables = module.indices_set_by_trainables
        self.comb_parents = module.comb_parents
        self.externals = module.externals
        self.external_inds = module.external_inds
        self.xyzr = module.xyzr
        self.nseg = module.nseg
        # Flags that `View.__init__` copies from its pointer; taken from the wrapped
        # module when it has them (None otherwise).
        self.initialized_morph = getattr(module, "initialized_morph", None)
        self.initialized_syns = getattr(module, "initialized_syns", None)
        self.allow_make_trainable = getattr(module, "allow_make_trainable", None)
        self._in_view = self.nodes.index.to_numpy()
        self._scope = "local"  # defaults to local scope
        self.groups = {}
        # Give this instance a private subclass named after the wrapped module type
        # (e.g. "Network"); level-dependent methods read `self.base.__class__.__name__`.
        # Assigning `self.__class__.__name__` would rename DummyModule itself and
        # leak into every other instance.
        self.__class__ = type(module.__class__.__name__, (type(self),), {})

        self._add_local_indices()
        self.base = self

    def _add_local_indices(self) -> None:
        """Prefix index columns with `global_` and add matching `local_` columns.

        Local branch indices restart at 0 in every cell; local comp indices restart
        at 0 in every branch. The wrapped module's node table is not modified.
        """
        idx_cols = ["global_comp_index", "global_branch_index", "global_cell_index"]
        nodes = self.nodes.rename(columns={col.replace("global_", ""): col for col in idx_cols})
        idcs = nodes[idx_cols].copy()

        def reindex_a_by_b(df, a, b):
            # Dense rank of `a` within each group of `b`, starting at 0.
            df[a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1
            return df

        idcs = reindex_a_by_b(idcs, idx_cols[1], idx_cols[2])
        idcs = reindex_a_by_b(idcs, idx_cols[0], idx_cols[1:])
        idcs.columns = [col.replace("global", "local") for col in idx_cols]
        self.nodes = pd.concat([idcs, nodes], axis=1)

    def _reformat_index(self, idx, num_indices=None):
        """Normalize an index specification to a flat integer array.

        Args:
            idx: None (nothing), an int, a list/range, a slice, "all" or an integer
                ndarray.
            num_indices: Number of valid indices that slices and "all" resolve
                against. Defaults to the number of nodes in view.

        Returns:
            1-D integer ndarray.
        """
        if num_indices is None:
            num_indices = len(self._in_view)
        if idx is None:
            idx = np.array([], dtype=int)
        elif isinstance(idx, (int, np.integer)):
            idx = np.array([idx])
        elif isinstance(idx, (list, range)):
            idx = np.array(idx)
        elif isinstance(idx, slice):
            idx = np.arange(num_indices)[idx]
        elif isinstance(idx, str):
            assert idx == "all", "Only 'all' is allowed"
            idx = np.arange(num_indices)
        assert isinstance(idx, np.ndarray), "Invalid type"
        if idx.size == 0:
            idx = idx.astype(int)  # e.g. `[]` becomes a float array
        assert np.issubdtype(idx.dtype, np.integer), "Invalid dtype"
        return idx.reshape(-1)

    def at(self, idx, sorted=False):
        """Return a View of the nodes at the given *positions* within this view."""
        idx = self._reformat_index(idx)
        new_indices = self._in_view[idx]
        new_indices = np.sort(new_indices) if sorted else new_indices
        return View(self, at=new_indices)

    def set(self, key, value):
        """Set a node or edge column for everything in view (NaN entries are skipped)."""
        if key in self.nodes.columns:
            not_nan = ~self.nodes[key].isna().to_numpy()
            self.base.nodes.loc[self._in_view[not_nan], key] = value
        elif key in self.edges.columns:
            not_nan = ~self.edges[key].isna().to_numpy()
            self.base.edges.loc[self._edges_in_view[not_nan], key] = value
        else:
            raise KeyError(f"Key '{key}' not found in nodes or edges")

    def set_scope(self, scope):
        """Set whether level indices are interpreted as "local" or "global"."""
        self._scope = scope

    def scope(self, scope):
        """Return a View of the current selection that uses `scope`."""
        view = self.view
        view.set_scope(scope)
        return view

    def _at_level(self, level: str, idx):
        """Select nodes whose `<scope>_<level>_index` is in `idx`."""
        col = self.nodes[f"{self._scope}_{level}_index"]
        # Slices and "all" refer to level index *values*, which in global scope can
        # exceed the number of nodes in view.
        num_indices = int(col.max()) + 1 if len(col) > 0 else 0
        idx = self._reformat_index(idx, num_indices)
        return self.at(np.flatnonzero(col.isin(idx).to_numpy()))

    def cell(self, idx):
        return self._at_level("cell", idx)

    def branch(self, idx):
        return self._at_level("branch", idx)

    def comp(self, idx):
        return self._at_level("comp", idx)

    def loc(self, at: float):
        """Select the compartment(s) containing relative position(s) `at` in [0, 1].

        Compartment k covers [k / nseg, (k + 1) / nseg); `at == 1.0` maps to the
        last compartment. The lookup always uses local comp indices, since a
        position along a branch is inherently branch-relative.
        """
        nseg = self.base.nseg
        comp_edges = np.linspace(0, 1, nseg + 1)
        # `np.digitize` returns 1-based bin numbers.
        idx = np.clip(np.digitize(at, comp_edges) - 1, 0, nseg - 1)
        idx = self._reformat_index(idx)
        where = self.nodes["local_comp_index"].isin(idx).to_numpy()
        return self.at(np.flatnonzero(where))

    def add_group(self, name):
        """Register the nodes in view as group `name` (accessible as `module.<name>`)."""
        self.base.groups[name] = self._in_view

    def __getattr__(self, key):
        """Resolve group, channel and synapse names to Views of the matching nodes.

        Only called when normal attribute lookup fails. Private/dunder names and
        the structural attributes read below never resolve to views; looking them
        up here would recurse.

        Raises:
            AttributeError: if `key` is not a group, channel or synapse name.
        """
        structural = ("base", "groups", "channels", "synapse_names", "nodes", "edges")
        if key.startswith("_") or key in structural:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
        base = self.base

        if key in base.groups:
            if key not in self.groups:
                return self.at(None)
            # Groups store node index labels; `at` expects positions in this view.
            return self.at(np.flatnonzero(np.isin(self._in_view, self.groups[key])))

        if key in [c._name for c in base.channels]:
            if key not in [c._name for c in self.channels]:
                return self.at(None)
            return self.at(np.flatnonzero(self.nodes[key].to_numpy()))

        if key in base.synapse_names:
            # If the same 2 nodes are connected by 2 different synapses,
            # module.SynapseA.edges will still contain both synapses
            # since view filters based on index ONLY. Returning only the row
            # that contains SynapseA is not possible currently. For setting
            # synapse parameters this has no effect however.
            if key not in self.synapse_names:
                return self.at(None)
            syn_edges = self.edges.loc[
                self.edges["type"] == key, ["global_pre_comp_index", "global_post_comp_index"]
            ]
            comp_inds_in_view = pd.unique(syn_edges.to_numpy().ravel("K"))
            return self.at(np.flatnonzero(self.nodes["global_comp_index"].isin(comp_inds_in_view)))

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")

    def show(self):
        """Return a copy of the node table with only the current scope's index columns.

        The `local_`/`global_` prefix of the kept index columns is stripped.
        """
        drop_prefix = "global_" if self._scope == "local" else "local_"
        keep_prefix = f"{self._scope}_"
        nodes = self.nodes.drop(columns=[c for c in self.nodes.columns if str(c).startswith(drop_prefix)])
        nodes.columns = [
            c[len(keep_prefix):] if str(c).startswith(keep_prefix) else c for c in nodes.columns
        ]
        return nodes

    def _module_level(self):
        """Level name ("network", "cell", "branch" or "comp") of the wrapped module."""
        module = self.base.__class__.__name__.lower()
        return "comp" if module == "compartment" else module

    def __getitem__(self, idx):
        """Index levels below the module, e.g. `net[cell, branch, comp]`.

        Fewer indices than levels select only the leading levels (`net[0]` is
        cell 0).

        Raises:
            IndexError: if more indices are given than there are levels.
        """
        levels = ["network", "cell", "branch", "comp"]
        children = levels[levels.index(self._module_level()) + 1:]
        idx = idx if isinstance(idx, tuple) else (idx,)
        if len(idx) > len(children):
            raise IndexError(f"Too many indices: got {len(idx)}, expected at most {len(children)}.")
        view = self
        for child, child_idx in zip(children, idx):
            view = view._at_level(child, child_idx)
        return view

    def _iter_level(self, level):
        """Yield one View per distinct `<scope>_<level>_index` value in view."""
        for idx in self.nodes[f"{self._scope}_{level}_index"].unique():
            yield self._at_level(level, idx)

    @property
    def cells(self):
        yield from self._iter_level("cell")

    @property
    def branches(self):
        yield from self._iter_level("branch")

    @property
    def comps(self):
        yield from self._iter_level("comp")

    @property
    def shape(self):
        """Numbers of unique (cells, branches, comps) in view, trimmed by module level.

        net.shape -> dim=3, cell.shape -> dim=2, branch.shape -> dim=1, comp.shape -> dim=0.
        """
        cols = ["global_cell_index", "global_branch_index", "global_comp_index"]
        raw_shape = self.nodes[cols].nunique().to_list()
        levels = ["network", "cell", "branch", "comp"]
        return tuple(raw_shape[levels.index(self._module_level()):])

    def copy(self, reset_index=False, as_module=False):
        view = deepcopy(self)
        # TODO: add reset_index, i.e. for parents, nodes, edges etc. such that they
        # start from 0/-1 and are contiguous
        if as_module:
            # TODO: initialize a new module with the same attributes
            pass
        return view

    @property
    def view(self):
        """A View of exactly the nodes currently in view."""
        return View(self, self._in_view)

    @property
    def _branches_in_view(self):
        """Sorted unique global branch indices of the nodes in view."""
        return np.unique(self.nodes["global_branch_index"].to_numpy())

    @property
    def _comps_in_view(self):
        """Sorted unique global comp indices of the nodes in view."""
        return np.unique(self.nodes["global_comp_index"].to_numpy())

    @property
    def _edges_in_view(self):
        """Labels of `base.edges` rows whose pre- and postsynaptic comps are both in view."""
        incl_comps = self.nodes["global_comp_index"].unique()
        pre = self.base.edges["global_pre_comp_index"].isin(incl_comps).to_numpy()
        post = self.base.edges["global_post_comp_index"].isin(incl_comps).to_numpy()
        return self.base.edges.index.to_numpy()[pre & post]

    def vis(self, dims=[0, 1], level="branch", ax=None, type="line", **kwargs):
        """Plot branch traces (`level="branch"`) or compartment centers (`level="comp"`)."""
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=(3, 3))
        if level == "branch":
            for coords_of_branch in self.xyzr:
                x1, x2 = coords_of_branch[:, dims].T
                if "line" in type.lower():
                    _ = ax.plot(x1, x2, **kwargs)
                elif "scatter" in type.lower():
                    _ = ax.scatter(x1, x2, **kwargs)
                else:
                    raise NotImplementedError
        if level == "comp":
            x1, x2 = self.nodes[["x", "y", "z"]].values[:, dims].T
            ax.scatter(x1, x2, **kwargs)
        return ax

    def record(self, state, verbose=True):
        """Record `state` at every node in view; duplicate recordings are dropped."""
        new_recs = pd.DataFrame(self._in_view, columns=["rec_index"])
        new_recs["state"] = state
        recordings = pd.concat([self.base.recordings, new_recs], ignore_index=True)
        is_duplicate = recordings.duplicated()
        self.base.recordings = recordings.loc[~is_duplicate].reset_index(drop=True)
        if verbose:
            print(f"Added {len(self._in_view) - int(is_duplicate.sum())} recordings. See `.recordings` for details.")

    def insert(self, channel):
        """Insert `channel` into every node in view, with its default params/states."""
        name = channel._name

        # Channel does not yet exist in the `jx.Module` at all.
        if name not in [c._name for c in self.base.channels]:
            self.base.channels.append(channel)
            self.base.nodes[name] = False  # Previous columns do not have the new channel.

        if channel.current_name not in self.base.membrane_current_names:
            self.base.membrane_current_names.append(channel.current_name)

        # Add a binary column that indicates if a channel is present.
        self.base.nodes.loc[self._in_view, name] = True

        # Loop over all new parameters, e.g. gNa, eNa.
        for key in channel.channel_params:
            self.base.nodes.loc[self._in_view, key] = channel.channel_params[key]

        # Loop over all new states.
        for key in channel.channel_states:
            self.base.nodes.loc[self._in_view, key] = channel.channel_states[key]

    def stimulate(self, current, verbose=False):
        """Inject `current` (one row, or one row per node in view) as external input "i"."""
        self._external_input("i", current, verbose)

    def _broadcast_to_view(self, values):
        """Return `values` as a 2-D array with one row per node in view.

        A 1-D input is treated as a single row; a single row is repeated for every
        node in view.
        """
        values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)
        batch_size = values.shape[0]
        num_inserted = len(self._in_view)
        # Validate before repeating so a mismatch is reported, not silently tiled.
        assert batch_size in [1, num_inserted], "Number of comps and stimuli do not match."
        if batch_size != num_inserted:
            values = jnp.repeat(values, num_inserted, axis=0)
        return values

    def _external_input(self, key, values, verbose=False):
        """Append external input `values` for `key` at every node in view to `base`."""
        values = self._broadcast_to_view(values)
        num_inserted = len(self._in_view)

        if key in self.base.externals.keys():
            self.base.externals[key] = jnp.concatenate([self.base.externals[key], values])
            self.base.external_inds[key] = jnp.concatenate(
                [self.base.external_inds[key], self._in_view]
            )
        else:
            self.base.externals[key] = values
            self.base.external_inds[key] = self._in_view

        if verbose:
            print(f"Added {num_inserted} external_states. See `.externals` for details.")

    def clamp(self, state_name, state_array, verbose=False):
        """Clamp `state_name` to `state_array` at every node in view."""
        self._external_input(state_name, state_array, verbose=verbose)

    def data_stimulate(self, current, data_stimuli, verbose=False):
        """Functional variant of `stimulate` that returns `(currents, inds)`.

        Args:
            current: One row, or one row per node in view.
            data_stimuli: Previous `(currents, inds)` tuple to extend, or None.

        Returns:
            Tuple of stacked currents and the matching node index labels. The
            labels are kept as a flat array, like `external_inds` in
            `_external_input` (assumes `data_stimuli[1]` was produced by this method).
        """
        current = self._broadcast_to_view(current)
        num_inserted = len(self._in_view)

        if data_stimuli is not None:
            currents, inds = data_stimuli
        else:
            currents = None
            inds = np.zeros((0,), dtype=self._in_view.dtype)

        # Same as in `.stimulate()`.
        currents = current if currents is None else jnp.concatenate([currents, current])
        inds = jnp.concatenate([jnp.asarray(inds), self._in_view])

        if verbose:
            print(f"Added {num_inserted} stimuli.")

        return (currents, inds)

    def data_set(self, key, val, param_state=None):
        """Functional variant of `set` that returns a list of parameter updates."""
        # Note: `data_set` does not support arrays for `val`.
        if key in self.nodes.columns:
            not_nan = ~self.nodes[key].isna()
            added_param_state = [
                {
                    "indices": np.atleast_2d(self._in_view[not_nan]),
                    "key": key,
                    "val": jnp.atleast_1d(jnp.asarray(val)),
                }
            ]
            if param_state is not None:
                param_state += added_param_state
            else:
                param_state = added_param_state
        else:
            raise KeyError("Key not recognized.")
        return param_state

    def move(self, x, y, z, update_nodes=True):
        """Translate every branch in view by (x, y, z)."""
        for i in self.nodes["global_branch_index"].unique():
            self.base.xyzr[i][:, :3] += np.array([x, y, z])
        if update_nodes:
            self._update_nodes_with_xyz()

    def move_to(self, x, y, z, update_nodes=True):
        """Translate the branches in view so the first point of the view's first branch is at (x, y, z).

        Raises:
            ValueError: if any coordinate in view is NaN.
        """
        # `== np.nan` is always False, so NaN must be detected with `np.isnan`.
        if np.isnan(np.concatenate(self.xyzr, axis=0)[:, :3]).any():
            raise ValueError(
                "NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values."
            )

        move_by = np.array([x, y, z]).T - self.xyzr[0][0, :3]  # move with respect to root idx
        for idx in self.nodes["global_branch_index"].unique():
            self.base.xyzr[idx][:, :3] += move_by
        if update_nodes:
            self._update_nodes_with_xyz()

    def rotate(self, degrees, rotation_axis="xy", update_nodes=True):
        """Rotate every branch in view about the origin within the given plane.

        Raises:
            ValueError: if `rotation_axis` is not "xy", "xz" or "yz".
        """
        axes = {"xy": [0, 1], "xz": [0, 2], "yz": [1, 2]}
        if rotation_axis not in axes:
            raise ValueError(f"rotation_axis must be one of {list(axes)}, got '{rotation_axis}'.")
        dims = axes[rotation_axis]

        radians = degrees / 180 * np.pi
        rotation_matrix = np.asarray(
            [[np.cos(radians), np.sin(radians)], [-np.sin(radians), np.cos(radians)]]
        )
        for i in self.nodes["global_branch_index"].unique():
            rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T
            self.base.xyzr[i][:, dims] = rot
        if update_nodes:
            self._update_nodes_with_xyz()

    def _update_nodes_with_xyz(self):
        """Recompute compartment centers from `base.xyzr` and write them to x/y/z.

        Always uses the module's `nseg` and all of `base.xyzr`, then writes only the
        rows in view. Assumes global comp index == branch * nseg + local comp index.

        Returns:
            `(centers, xyz)`: per-compartment centers of shape (num_comps, 3), and
            compartment end points of shape (3, num_branches, nseg + 1).
        """
        nseg = self.base.nseg
        num_branches = len(self.base.xyzr)
        # Compartment ends in [0, 1] per branch, offset by 2 * branch index so all
        # branches can be interpolated together on one monotonic axis.
        padding = 2 * np.arange(num_branches).reshape(-1, 1)
        comp_ends = (np.linspace(0, 1, nseg + 1).reshape(1, -1) + padding).reshape(-1)

        branch_lens = []
        for i, xyzr in enumerate(self.base.xyzr):
            seg_lens = np.sqrt(np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1))
            cum_len = np.hstack([np.array([0]), np.cumsum(seg_lens)])
            total = cum_len.max()
            # Zero-length or NaN branches collapse to 0; the padding must still be
            # added afterwards, otherwise the interpolation axis is not monotonic.
            norm_len = cum_len / total if total > 0 else np.zeros_like(cum_len, dtype=float)
            branch_lens.append(norm_len + 2 * i)
        branch_lens = np.hstack(branch_lens)

        xyz = np.vstack(self.base.xyzr)[:, :3]
        xyz = v_interp(comp_ends, branch_lens, xyz).reshape(3, num_branches, nseg + 1)
        centers = ((xyz[:, :, 1:] + xyz[:, :, :-1]) / 2).reshape(3, -1).T
        self.base.nodes.loc[self._in_view, ["x", "y", "z"]] = np.asarray(centers[self._in_view])
        return centers, xyz

    def _infer_synapse_type_ind(self, synapse_name):
        """Return `(type_ind, is_new_type)` for `synapse_name`."""
        syn_names = self.base.synapse_names
        is_new_type = synapse_name not in syn_names
        type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name)
        return type_ind, is_new_type

    def _update_synapse_state_names(self, synapse_type):
        # (Potentially) update variables that track meta information about synapses.
        self.base.synapse_names.append(synapse_type._name)
        self.base.synapse_param_names += list(synapse_type.synapse_params.keys())
        self.base.synapse_state_names += list(synapse_type.synapse_states.keys())
        self.base.synapses.append(synapse_type)

    def _append_multiple_synapses(self, pre, post, synapse_type):
        """Add one synapse of `synapse_type` per (pre node, post node) row pair.

        Raises:
            ValueError: if `pre` and `post` contain different numbers of nodes.
        """
        if len(pre.nodes) != len(post.nodes):
            raise ValueError("Number of pre- and postsynaptic compartments must match.")

        # Add synapse types to the module and infer their unique identifier.
        synapse_name = synapse_type._name
        type_ind, is_new = self._infer_synapse_type_ind(synapse_name)
        if is_new:  # synapse is not known
            self._update_synapse_state_names(synapse_type)

        index = len(self.base.edges)
        # Use the module's nseg (a single-node view reports nseg == 1) and keep one
        # location per row, aligned with the rows built below.
        nseg = self.base.nseg
        pre_loc = loc_of_index(pre.nodes["global_comp_index"].to_numpy(), nseg)
        post_loc = loc_of_index(post.nodes["global_comp_index"].to_numpy(), nseg)

        # Define new synapses. Each row is one synapse.
        cols = ["comp_index", "branch_index", "cell_index"]
        scopes = ["local", "global"]
        src_cols = [f"{scope}_{col}" for col in cols for scope in scopes]

        def _edge_columns(view, role):
            # e.g. "global_comp_index" -> "global_pre_comp_index".
            renamed = [f"{scope}_{role}_{col}" for col in cols for scope in scopes]
            return view.nodes[src_cols].set_axis(renamed, axis=1).reset_index(drop=True)

        new_rows = pd.concat([_edge_columns(pre, "pre"), _edge_columns(post, "post")], axis=1)
        new_rows["type"] = synapse_name
        new_rows["type_ind"] = type_ind
        # Column names match the `edges` schema declared in `__init__`.
        new_rows["pre_locs"] = pre_loc
        new_rows["post_locs"] = post_loc
        self.base.edges = concat_and_ignore_empty(
            [self.base.edges, new_rows],
            ignore_index=True, axis=0
        )

        indices = list(range(index, index + len(new_rows)))
        self._add_params_to_edges(synapse_type, indices)

    def _add_params_to_edges(self, synapse_type, indices):
        # Add parameters and states to the `.edges` table.
        for key, param_val in synapse_type.synapse_params.items():
            self.base.edges.loc[indices, key] = param_val

        # Update synaptic state array.
        for key, state_val in synapse_type.synapse_states.items():
            self.base.edges.loc[indices, key] = state_val

    def make_trainable(self, key, init_val=None, verbose=False):
        """Register node or edge parameter `key` (in view) as trainable.

        Args:
            key: Node or edge column name.
            init_val: Optional initial value, broadcast to the parameter shape.
            verbose: Currently unused.

        Raises:
            KeyError: if `key` is neither a node nor an edge column.
        """
        assert (
            self.allow_make_trainable
        ), "network.cell('all').make_trainable() is not supported. Use a for-loop over cells."
        if key in self.nodes.columns:
            not_nan = ~self.nodes[key].isna().to_numpy()
            num_branches = len(self._branches_in_view)
            inds = self._in_view[not_nan]
            params = jnp.asarray(self.nodes.loc[inds, key].to_numpy())
            params = params.reshape(num_branches, -1)
            if init_val is not None:
                # JAX arrays are immutable: `.at[].set` returns a new array.
                params = params.at[:].set(jnp.asarray(init_val))
            # TODO: discuss shapes
            self.base.trainable_params.append({key: params})
            self.base.indices_set_by_trainables.append(inds.reshape(num_branches, -1))
        elif key in self.edges.columns:
            not_nan = ~self.edges[key].isna().to_numpy()
            edge_inds = self._edges_in_view[not_nan]
            params = jnp.asarray(self.edges.loc[edge_inds, key].to_numpy())
            # TODO: discuss shapes
            if init_val is not None:
                params = params.at[:].set(jnp.asarray(init_val))
            self.base.trainable_params.append({key: params})
            # NOTE(review): these are edge labels, while `View._trainables_in_view`
            # matches stored indices against node labels — confirm intended design.
            self.base.indices_set_by_trainables.append(edge_inds)
        else:
            raise KeyError(f"Key '{key}' not found in nodes or edges")

    
class View(DummyModule):
    """A module-like window onto a subset of another module's nodes.

    `View(pointer, at)` selects the nodes of `pointer` whose index labels are in
    `at` (all of `pointer`'s nodes when `at` is None). Every other attribute
    (edges, xyzr, channels, synapses, recordings, trainables, externals, groups)
    is derived from that selection. Because View subclasses `DummyModule`, all
    selection and mutation methods work on a View too; mutations go to `base`.
    """

    def __init__(self, pointer, at=None):
        # Attributes that do not depend on the selection.
        self._scope = pointer._scope
        self.base = pointer.base
        self.initialized_morph = pointer.initialized_morph
        self.initialized_syns = pointer.initialized_syns
        self.allow_make_trainable = pointer.allow_make_trainable

        # `at` holds node index labels (not positions); None keeps the pointer's selection.
        self._in_view = pointer._in_view if at is None else at
        self.nodes = pointer.nodes.loc[self._in_view]
        # Fail early with a clear message instead of an obscure error from one of
        # the derived attributes below.
        if len(self.nodes) == 0:
            raise ValueError("Nothing in view. Check your indices.")

        # Attributes derived from the node selection.
        self.branch_edges = pointer.branch_edges.loc[self._branch_edges_in_view]
        self.edges = pointer.edges.loc[self._edges_in_view]
        self.xyzr = self._xyzr_in_view(pointer)
        # NOTE(review): a single-node view reports nseg == 1; code that needs the
        # module's compartments per branch should read `self.base.nseg`.
        self.nseg = 1 if len(self.nodes) == 1 else pointer.nseg
        self.total_nbranches = len(self._branches_in_view)
        self.nbranches_per_cell = self._nbranches_per_cell_in_view()
        self.cumsum_nbranches = np.cumsum(self.nbranches_per_cell)

        self.synapse_names = np.unique(self.edges["type"]).tolist()
        self.synapses, self.synapse_param_names, self.synapse_state_names = self._synapses_in_view(pointer)

        is_recorded = pointer.recordings["rec_index"].isin(self._comps_in_view)
        self.recordings = pointer.recordings.loc[is_recorded]

        self.channels = self._channels_in_view(pointer)
        # Same convention as `DummyModule.insert`, which registers `channel.current_name`.
        self.membrane_current_names = [c.current_name for c in self.channels]

        self.indices_set_by_trainables, self.trainable_params = self._trainables_in_view()

        self.comb_parents = self.base.comb_parents[self._branches_in_view]
        self.externals, self.external_inds = self._externals_in_view()
        self.groups = {k: np.intersect1d(v, self._in_view) for k, v in pointer.groups.items()}

        # TODO:
        # self.debug_states

    def _externals_in_view(self):
        """Return `(externals, external_inds)` dicts restricted to nodes in view.

        Mirrors the dict layout of `base.externals` / `base.external_inds`.
        """
        externals_in_view = {}
        external_inds_in_view = {}
        for name, inds in self.base.external_inds.items():
            in_view = np.isin(inds, self._in_view)
            if in_view.any():
                externals_in_view[name] = self.base.externals[name][in_view]
                external_inds_in_view[name] = inds[in_view]
        return externals_in_view, external_inds_in_view

    def _trainables_in_view(self):
        """Restrict `base` trainables to nodes in view.

        Rows whose indices are all in view are kept as-is. Rows that are only
        partially in view contribute their in-view indices (flattened) and their
        full parameter row.
        """
        indices_in_view = []
        params_in_view = []
        for inds, params in zip(self.base.indices_set_by_trainables, self.base.trainable_params):
            in_view = np.isin(inds, self._in_view)
            # 1-D index arrays (one index per parameter) are treated as single-column rows.
            rows_in_view = in_view if in_view.ndim > 1 else in_view[:, None]

            completely_in_view = rows_in_view.all(axis=1)
            indices_in_view.append(inds[completely_in_view])
            params_in_view.append({k: v[completely_in_view] for k, v in params.items()})

            partially_in_view = rows_in_view.any(axis=1) & ~completely_in_view
            indices_in_view.append(inds[partially_in_view][in_view[partially_in_view]])
            # NOTE(review): partial rows keep their whole parameter row, so shapes
            # differ from the flattened indices above — confirm intended behavior.
            params_in_view.append({k: v[partially_in_view] for k, v in params.items()})

        indices_in_view = [inds for inds in indices_in_view if len(inds) > 0]
        params_in_view = [p for p in params_in_view if len(next(iter(p.values()))) > 0]
        return indices_in_view, params_in_view

    def _channels_in_view(self, pointer):
        """Channels of `pointer` present in at least one node in view."""
        names = [channel._name for channel in pointer.channels]
        present = self.nodes[names].any(axis=0)
        present_names = set(present[present].index)
        return [c for c in pointer.channels if c._name in present_names]

    def _synapses_in_view(self, pointer):
        """Return `(synapses, param_names, state_names)` for synapse types in view.

        `synapses` keeps one entry per pointer synapse, with None for types not in
        view, so that type indices stay consistent.
        """
        viewed_synapses = []
        viewed_params = []
        viewed_states = []
        if pointer.synapses is not None:
            for syn in pointer.synapses:
                if syn is None:  # padding from a previous (recursive) view
                    continue
                in_view = syn._name in self.synapse_names
                viewed_synapses.append(syn if in_view else None)
                if in_view:
                    viewed_params += list(syn.synapse_params.keys())
                    viewed_states += list(syn.synapse_states.keys())
        return viewed_synapses, viewed_params, viewed_states

    def _nbranches_per_cell_in_view(self):
        cell_nodes = self.nodes.groupby("global_cell_index")
        return cell_nodes["global_branch_index"].nunique().to_numpy()

    def _xyzr_in_view(self, pointer):
        """Return xyzr for the branches in view, sorted by global branch index.

        Branches that are only partially in view are replaced by the points
        interpolated at the viewed compartments' locations.
        """
        # A module pointer may not define `_branches_in_view`; its xyzr is then
        # indexed by global branch index.
        prev_branch_inds = getattr(pointer, "_branches_in_view", None)
        viewed_branch_inds = self._branches_in_view
        if prev_branch_inds is None:
            xyzr = [pointer.xyzr[i] for i in viewed_branch_inds]
        else:
            keep = np.flatnonzero(np.isin(prev_branch_inds, viewed_branch_inds))
            xyzr = [pointer.xyzr[i] for i in keep]

        # Currently viewing with `.loc` will show the closest compartment
        # rather than the actual loc along the branch!
        nseg = self.base.nseg
        comps_per_branch = self.nodes.groupby("global_branch_index").size()
        incomplete_inds = np.flatnonzero(comps_per_branch.to_numpy() != nseg)
        incomplete_branch_inds = viewed_branch_inds[incomplete_inds]

        partial_nodes = self.nodes.loc[self.nodes["global_branch_index"].isin(incomplete_branch_inds)]
        local_inds_per_branch = partial_nodes.groupby("global_branch_index")["local_comp_index"]
        locs = [loc_of_index(inds.to_numpy(), nseg) for _, inds in local_inds_per_branch]

        for i, loc in zip(incomplete_inds, locs):
            xyzr[i] = interpolate_xyz(loc, xyzr[i]).T
        return xyzr

    @property
    def _nodes_in_view(self):
        return self._in_view

    @property
    def _branches_in_view(self):
        """Sorted unique global branch indices of the nodes in view."""
        return np.unique(self.nodes["global_branch_index"].to_numpy())

    @property
    def _comps_in_view(self):
        """Sorted unique global comp indices of the nodes in view."""
        return np.unique(self.nodes["global_comp_index"].to_numpy())

    @property
    def _branch_edges_in_view(self):
        """Labels of `base.branch_edges` rows whose parent and child are both in view."""
        incl_branches = self.nodes["global_branch_index"].unique()
        pre = self.base.branch_edges["parent_branch_index"].isin(incl_branches)
        post = self.base.branch_edges["child_branch_index"].isin(incl_branches)
        return self.base.branch_edges.index.to_numpy()[pre & post]

    @property
    def _edges_in_view(self):
        """Labels of `base.edges` rows whose pre- and postsynaptic comps are both in view."""
        incl_comps = self.nodes["global_comp_index"].unique()
        pre = self.base.edges["global_pre_comp_index"].isin(incl_comps).to_numpy()
        post = self.base.edges["global_post_comp_index"].isin(incl_comps).to_numpy()
        return self.base.edges.index.to_numpy()[pre & post]

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Dunder probes (copy/pickle
        # protocols) and `base` must fail fast instead of recursing.
        if name.startswith("__") or name == "base":
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        # Channel, synapse and group names resolve to sub-views of *this* view.
        return super().__getattr__(name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass


# Test that every attribute of a module also has a corresponding attribute
# (of the same type) in a View of it, apart from a few allowed exceptions.
# This should trigger when new attributes are added to the module that may
# also need to be included in View, e.g. because they must be restricted to the view.

def test_view_attrs(module):
    """Assert that a View of `module` exposes each of its instance attributes.

    For every entry in `module.__dict__` except the listed exceptions, the View
    over nodes [0, 1] must have an attribute of the same name and exact type.

    Raises:
        AssertionError: naming the first missing or mistyped attribute.
    """
    exceptions = ["_scope", "_at", "view"]
    view = View(module, np.array([0, 1]))  # build once, not per attribute

    for name, attr in module.__dict__.items():
        if name in exceptions:
            continue
        assert hasattr(view, name), f"View missing attribute: {name}"
        assert type(attr) is type(getattr(view, name)), f"Type mismatch: {name}"

#TODO replace global and local cell indexes with just global_cell_index


def connect(
    pre: "CompartmentView",
    post: "CompartmentView",
    synapse_type: "Synapse",
):
    """Connect two compartments with a chemical synapse.

    The pre- and postsynaptic compartments must be different compartments of the
    same network. Row i of `pre` is connected to row i of `post`.

    Args:
        pre: View of the presynaptic compartment.
        post: View of the postsynaptic compartment.
        synapse_type: The synapse to append

    Raises:
        ValueError: if `pre` and `post` belong to different modules, contain a
            different number of compartments, or pair a compartment with itself.
    """
    if pre.base is not post.base:
        raise ValueError("Pre and post compartments must be part of the same network.")

    pre_comps = pre.nodes["global_comp_index"].to_numpy()
    post_comps = post.nodes["global_comp_index"].to_numpy()
    if len(pre_comps) != len(post_comps):
        raise ValueError("Number of pre- and postsynaptic compartments must match.")
    if np.any(pre_comps == post_comps):
        raise ValueError("Pre and post compartments must be different.")

    pre._append_multiple_synapses(pre, post, synapse_type)

In [4]:
# Setup: 5 identical cells, each with 7 branches of 4 compartments.
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])
net = jx.Network([cell] * 5)

# Configure the network with jaxley's own API before wrapping it.
net.cell(0).insert(HH())
net.cell(1).branch(0).comp(0).record("v")
net.cell(0).branch(0).comp(0).stimulate(np.zeros((1, 10)))
net.cell(0).branch([0, 1]).comp(0).clamp("v", np.ones((1, 10)) * -65)
net.cell(0).branch(1).comp("all").stimulate(np.ones((1, 10)))
cell.compute_xyz()
net.compute_xyz()

# Wrap the network in the prototype view-based module.
net = DummyModule(net)
net._update_nodes_with_xyz()
net.allow_make_trainable = True

# Each entry: ((pre cell, branch, comp), (post cell, branch, comp), synapse class).
SYNAPSE_SPECS = [
    ((0, 0, 0), (1, 1, 0), IonotropicSynapse),
    ((2, 0, 0), (1, 1, 0), IonotropicSynapse),
    ((1, 0, 0), (1, 1, 0), IonotropicSynapse),
    ((0, 2, 0), (1, 1, 0), IonotropicSynapse),
    ((1, 0, 0), (1, 1, 0), TestSynapse),
    ((0, 2, 0), (1, 1, 0), TestSynapse),
]
for (pre_cell, pre_branch, pre_comp), (post_cell, post_branch, post_comp), synapse_cls in SYNAPSE_SPECS:
    connect(
        net.cell(pre_cell).branch(pre_branch).comp(pre_comp),
        net.cell(post_cell).branch(post_branch).comp(post_comp),
        synapse_cls(),  # a fresh synapse instance per connection
    )

Added 1 recordings. See `.recordings` for details.
Added 1 external_states. See `.externals` for details.
Added 2 external_states. See `.externals` for details.
Added 4 external_states. See `.externals` for details.


In [515]:
# Before: View wrapped a module and exposed a subset of its methods via wrappers. This meant:
# 1. Every method meant to be accessed in both View and Module had to be hidden
#    and re-exposed via wrappers in either View or Module. This was a lot of boilerplate.
# 2. View was a fundamentally different object from Module, so views only had
#    access to a subset of the module's methods and attributes. E.g., net.cell(0)
#    did not expose all its attributes and could not be simulated on its own.
# 3. Indexing with global vs. local indices and managing what was in view was clunky.
# ----------------------------
# Now: View is itself a Module (a subclass) with a different set of indices in view.
# This means all methods of Module also work on View. The job of View is to
# restrict how attributes are returned based on the indices in view, i.e.
# View(module, inds) behaves like a module that only contains the subset of
# nodes, edges etc. defined by inds.

# NEW/OLD FEATURES
# More flexible indexing / selection with `at`
# (tracked via the DataFrame index, which does not change with scope).
rng = np.random.default_rng(0)  # seeded so re-runs select the same nodes
rnd_inds = rng.integers(0, len(net.nodes), 10)
net.at(rnd_inds)

# Arbitrary selection
net.branch(0).show()
net.comp(0).show()

# Scope: global comp indices of the selection are shown after `->`.
net.set_scope("global")
net.cell([0, 2]).branch([0]).comp([1, 2]).show()  # -> [1,2]

net.set_scope("local")
net.cell([0, 2]).branch([0]).comp([1, 2]).show()  # -> [1,2,57,58]

net.scope("local").comp(0).show()
# vs.
net.scope("global").comp(0).show()

# Context management
with net.cell(0).branch(0).comp(0) as comp0:
    comp0.set("v", -70)
    comp0.set("HH_gK", 0.1)
net.cell(0).branch(0).comp([0, 1]).show()[["v", "HH_gK"]]

# Iterables (distinct loop names so the jaxley `cell`/`branch`/`comp` objects
# from the setup cell are not overwritten)
for cell_view in net.cells:
    for branch_view in cell_view.branches:
        for comp_view in branch_view.comps:
            comp_view.set("v", -71)

for comp_view in net.cell(0).branch(0).comps:
    comp_view.set("v", -72)
net.show()[["v"]]

# Indexing
net[0, 0, 0].show()

# Groups
net.cell(1).branch(0).add_group("group")
net.group.show()

# Channel and Synapse views
net.HH.show()
net.cell(0).HH.nodes
net.HH.cell(0).nodes

net.IonotropicSynapse.nodes
net.cell(1).IonotropicSynapse.nodes
net.IonotropicSynapse.cell(1).nodes

# Shape
net.shape
net.cell(0).shape

# Copying
cell0 = net.cell(0).copy()
cell0.show()

  norm_pathlens = pathlens / pathlens[-1]  # path lengths normalized to [0,1]


Unnamed: 0,local_comp_index,local_branch_index,local_cell_index,global_comp_index,global_branch_index,global_cell_index,length,radius,axial_resistivity,capacitance,...,HH_gLeak,HH_eNa,HH_eK,HH_eLeak,HH_m,HH_h,HH_n,x,y,z
0,0,0,0,0,0,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,5.0,0.0,0.0
1,1,0,0,1,0,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,15.0,0.0,0.0
2,2,0,0,2,0,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,25.0,0.0,0.0
3,3,0,0,3,0,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,35.0,0.0,0.0
4,0,1,0,4,1,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,42.940858,-4.04368,0.0
5,1,1,0,5,1,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,48.822575,-12.131041,0.0
6,2,1,0,6,1,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,54.704292,-20.218402,0.0
7,3,1,0,7,1,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,60.586009,-28.305763,0.0
8,0,2,0,8,2,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,42.940858,4.04368,0.0
9,1,2,0,9,2,0,10.0,1.0,5000.0,1.0,...,0.0003,50.0,-77.0,-54.3,0.2,0.2,0.2,48.822575,12.131041,0.0
