In [1]:
# Enable IPython's autoreload extension in mode 2, which reloads all imported
# modules before each cell is executed, so edits to library code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

from jax import config

# Cap on the fraction of device memory the XLA client preallocates.
# NOTE(review): this setting targets GPU preallocation; with the platform forced
# to CPU below it presumably has no effect - confirm whether it is still needed.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

# Use 64-bit floats/ints and run everything on the CPU backend.
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

In [129]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad

import jaxley as jx
from jaxley.channels import HH
from jaxley.synapses import IonotropicSynapse, TestSynapse
from jaxley_mech.channels.fm97 import Na, K, KA, KCa, Ca, Leak

import warnings
import networkx as nx
from jaxley.connection import connect

In [4]:
# Time step and total duration shared by the stimulus and the time axis.
dt = 0.025
t_max = 100.0
# Time axis extending two steps past t_max.
# NOTE(review): np.arange with a non-integer step can differ in length by one
# element due to floating-point rounding; confirm the expected number of samples.
time_vec = np.arange(0, t_max+2*dt, dt)

# Step-current parameters, passed to jx.step_current as (delay, duration, amplitude).
i_delay = 10.0
i_dur = 80.0
i_amp = 5.0  # nA
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
# Seed NumPy's global random number generator.
np.random.seed(0)

In [158]:
# Morphology: one compartment template, branches of 4 compartments, cells of
# 5 branches (parent indices -1, 0, 1, 2, 2), and a network of 3 such cells.
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(4)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 1, 2, 2]))
net = jx.Network([cell]*3)
# Three synapses from cell 0 to cell 1; the second and third connect the same
# pair of locations but use different synapse types.
connect(net[0,0,0], net[1,0,0], IonotropicSynapse())
connect(net[0,0,1], net[1,0,1], IonotropicSynapse())
connect(net[0,0,1], net[1,0,1], TestSynapse())
net.cell(2).add_to_group("cell2")


# Channels: Na, Leak and K on all of cell 0; Na only on branch 1 of cell 1.
net.cell(0).insert(Na())
net.cell(0).insert(Leak())

net.cell(1).branch(1).insert(Na())
net.cell(0).insert(K())
net.compute_xyz()

# Record from cell 0 / branch 0 and inject the step current into cell 0 / branch 2.
net.cell(0).branch(0).loc(0.0).record()
# NOTE(review): `current` is rebuilt here with the same arguments as in the
# setup cell above; one of the two definitions is redundant.
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
net.cell(0).branch(2).loc(0.0).stimulate(current)
# NOTE(review): "Na" and "K" are the bare channel columns in net.nodes, which
# also has "Na_gNa" and "K_gK"; in addition, K is only inserted into cell 0 in
# this cell, not cell 1. Confirm these are the intended trainable parameters.
net.cell(0).branch(1).make_trainable("Na")
net.cell(1).make_trainable("K")

Added 1 recordings. See `.recordings` for details.
Added 1 stimuli. See `.currents` for details.
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1
Number of newly added trainable parameters: 1. Total number of trainable parameters: 2


  self.pointer.edges = pd.concat(


In [184]:
# List the column names of the network's `nodes` table.
net.nodes.columns

Index(['comp_index', 'branch_index', 'cell_index', 'length', 'radius',
       'axial_resistivity', 'capacitance', 'v', 'Na', 'Na_gNa', 'Na_eNa',
       'Na_m', 'Na_h', 'Leak', 'Leak_gLeak', 'Leak_eLeak', 'K', 'K_gK', 'eK',
       'K_n'],
      dtype='object')

In [192]:
# Inspect the `edges` attribute of the view that selects cells 0 and 1.
net.cell([0,1]).edges

In [8]:
def to_graph(df):
    """Return an undirected graph built from a table of branch connections.

    Each row of `df` contributes one edge between the values in its
    `parent_branch_index` and `child_branch_index` columns.
    """
    endpoint_columns = ['parent_branch_index', 'child_branch_index']
    branch_pairs = df[endpoint_columns].to_numpy()

    branch_graph = nx.Graph()
    branch_graph.add_edges_from(branch_pairs)
    return branch_graph

def to_module(view):
    """Build a standalone jaxley module from the part of a module that is in view.

    The module type is the highest level with more than one unique element in
    view: `jx.Network` if more than one cell is viewed, else `jx.Cell` if more
    than one branch, else `jx.Branch` if more than one compartment, else
    `jx.Compartment`. A fresh instance is constructed bottom-up, then the viewed
    rows of `pointer.nodes` and several bookkeeping attributes are copied onto it.

    Experimental: a warning is issued via `warnings.warn`, and recordings,
    currents and synapses are not carried over.

    Args:
        view: Object exposing `.pointer` (the module being viewed) and `.view`
            (a table with `cell_index`, `branch_index` and `comp_index` columns
            whose index labels select rows of `pointer.nodes`).

    Returns:
        The first constructed instance of the selected module type.

    Raises:
        AssertionError: If the graph built from the viewed branch edges is not
            connected, or if the view contains duplicated index rows.
    """
    warnings.warn(
    """
    This function is experimental and may not work as expected.
    Recordings, currents and Synapses are lost in this operation currently.
    """
    )

    # From here on `view` is the table of viewed rows, not the view object.
    pointer, view = view.pointer, view.view
    # Module types ordered from highest to lowest level; `indices` lists the
    # index column that counts the elements one level below Network/Cell/Branch.
    modules = np.array([jx.Network, jx.Cell, jx.Branch, jx.Compartment])
    indices = ["cell_index", "branch_index", "comp_index"]
    num_unique_elements = view[indices].nunique().values # number of unique module elements in view
    # highest module with more than one unique element, i.e. if num_branches > 1, return Cell etc.
    # The appended True makes jx.Compartment the fallback when every count is 1.
    return_type =  modules[np.concatenate((num_unique_elements > 1, [True]))][0]

    # Unique cell / branch / compartment indices present in the view.
    viewed_indices = {col: np.unique(view[col]) for col in indices}
    
    # Keep every branch edge where at least one column holds a viewed branch index.
    where_to_view_edges = pointer.branch_edges.isin(viewed_indices["branch_index"])
    viewed_branch_edges = pointer.branch_edges.loc[where_to_view_edges.any(axis=1)]
    graph = to_graph(viewed_branch_edges)
    
    assert nx.is_connected(graph), "The branches currently in view are not all connected."
    assert not any(view[indices].duplicated()), "View must not contain duplicates."
    
    # Build instances bottom-up: compartments -> branches -> cells -> network,
    # stopping once the selected return type has been constructed.
    module_instances = []
    # Prepend 1 so the counts line up with `modules` (exactly one Network).
    num_unique_elements = np.hstack([1, num_unique_elements])
    for num, module in zip(reversed(num_unique_elements), reversed(modules)):
        args = () 
        if module in [jx.Branch, jx.Network]:
            # Branches and networks are built from the list of lower-level instances.
            args = [module_instances]
        elif module == jx.Cell:
            # A cell additionally needs the parent index of each branch.
            root_node_index = viewed_branch_edges["parent_branch_index"].min()
            levels = list(nx.bfs_layers(graph, root_node_index))
            # Restrict each BFS layer to branches that are in view. The loop
            # variable `indices` shadows the outer list only inside this comprehension.
            levels = [[i for i in indices if i in viewed_indices["branch_index"]] for indices in levels]
            # NOTE(review): every branch in BFS layer `i` gets parent `i - 1`, i.e.
            # the index of the previous layer rather than of a specific parent
            # branch - confirm this is correct for morphologies with several
            # branches per layer.
            parents = sum([[i-1]*len(level) for i, level in enumerate(levels)], [])
            args = [module_instances, parents]
        
        module_instances = [module(*args) for _ in range(num)]
        if module == return_type:
            break

    # All `num` instances are built from the same arguments; use the first one.
    module = module_instances[0]
    indices_of_viewed_elements = view.index
    module.nodes = pointer.nodes.loc[indices_of_viewed_elements]
    module.nodes.reset_index(drop=True, inplace=True)

    # Drop every column containing at least one NaN (pandas' default how="any"),
    # then keep only the columns that have at least one non-zero entry.
    # NOTE(review): the second filter also removes numeric columns that are
    # entirely 0 (e.g. an index column when only element 0 is in view) - confirm
    # that this is intended and not only meant for all-False flag columns.
    module.nodes = module.nodes.dropna(axis=1)
    module.nodes = module.nodes.loc[:, (module.nodes != 0).any()]

    # Restrict each level of the pointer's branch levels to viewed branches and
    # discard levels that end up empty.
    comb_branches_in_each_level = [idcs[np.isin(idcs, viewed_indices["branch_index"])] for idcs in pointer.comb_branches_in_each_level]
    comb_branches_in_each_level = [level for level in comb_branches_in_each_level if len(level) > 0]
    
    #TODO add channels
    # module.channels = pointer.channels
    # module.membrane_current_names

    #TODO add synapses
    # module.synapses = pointer.synapses

    #TODO
    # module.indices_set_by_trainables
    # module.trainable_params
    # module.num_trainable_params = sum([len(p) for p in module.get_parameters()])

    def add_to_module(module, key, value): 
        # Thin wrapper around attribute assignment on the new module.
        module.__setattr__(key, value)

    # TODO: These could be added to properties of View as well
    # then they could be accessed as view.attr and would only return params that are in view
    # NOTE(review): all values below are evaluated when the dict is built, so
    # `module.nbranches_per_cell` in "total_nbranches" and "cumsum_nbranches" is
    # the value the freshly constructed module already had, not the
    # "nbranches_per_cell" entry assigned by the loop further down - confirm
    # which one is intended.
    attrs = {
        "allow_make_trainable": pointer.allow_make_trainable, 
        "initialized_morph": pointer.initialized_morph, 
        "initialized_syns": pointer.initialized_syns, 
        "initialized_conds": pointer.initialized_conds, 
        "xyzr": np.array(pointer.xyzr)[viewed_indices["branch_index"]].tolist(),
        "nbranches_per_cell": [pointer.nbranches_per_cell[i] for i in viewed_indices["cell_index"]],
        "total_nbranches": sum(module.nbranches_per_cell),
        "cumsum_nbranches": np.cumsum([0] + module.nbranches_per_cell),
        "comb_parents": pointer.comb_parents[viewed_indices["branch_index"]],
        "comb_branches_in_each_level": comb_branches_in_each_level,
        "branch_edges": viewed_branch_edges,
    }
    for key, value in attrs.items():
        add_to_module(module, key, value)
    
    return module

## How to make view behave like module?
- In general, most methods should live in `Module` where possible and be passed through to the view. The difference is scope: called on the module, a method acts on the entire thing; called on a view, it acts only on whatever is in view. This could be done by giving `Module` a `self.viewed_indexes` attribute that decides what is `set`, where things are `insert`ed, etc. This would get rid of a lot of complexity and duplicated methods; for example, it would make many hidden methods in `Module` obsolete. `view` could then simply be a property that returns `View.nodes`.


1. CompartmentView.distance could be moved out of View and simplified to distance(comp1, comp2): return d(comp1.xyzr[:3], comp2.xyzr[:3]) or something similar
2. CellView.rotate -> view.rotate, and this rotates whatever xyzr is in view. 
3. `CellView.read_swc` will become from_graph

### Ideas:
- add a `groups` property

In [7]:
# Convert the view net[0,0,0] into a standalone module.
# NOTE: this rebinds `comp`, which the network-construction cell used for the
# template jx.Compartment().
comp = to_module(net[0,0,0])

    This function is experimental and may not work as expected.
    Recordings, currents and Synapses are lost in this operation currently.
    


In [None]:
#TODO: add currents, recordings?
#TODO: add synapses? Sketch of a possible approach (commented out):
# pre_comp_indices = view.pointer.edges["global_pre_comp_index"]
# post_comp_indices = view.pointer.edges["global_post_comp_index"]
# viewed_comp_indices = view.view["global_comp_index"]
# # pre and post comp indices are both in viewed_comp_indices
# viewed_edge_indices = pre_comp_indices.isin(viewed_comp_indices) & post_comp_indices.isin(viewed_comp_indices)
# viewed_edges = view.pointer.edges.loc[viewed_edge_indices]
# viewed_syn_types = np.unique(view.pointer.edges["type"])
# for syn_type in viewed_syn_types:
#     viewed_edges_of_same_type = viewed_edges[viewed_edges["type"] == syn_type]
#     pre_comp_indices = viewed_edges_of_same_type["global_pre_comp_index"]
#     post_comp_indices = viewed_edges_of_same_type["global_post_comp_index"]
#     synapse = [s for s in view.pointer.synapses if s._name == syn_type]
#     pre_view = view.view.loc[pre_comp_indices]
#     post_view = view.view.loc[post_comp_indices]
#     connect(pre_view, post_view, synapse)
# module.edges
# module.synapse_param_names 
# module.synapse_state_names
# module.recordings
# module.currents
# module.current_inds