In [1]:
# Enable IPython's autoreload extension so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

from jax import config

# Use 64-bit floats and pin JAX to the CPU backend.
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

# Fraction of device memory the XLA client may claim.
# NOTE(review): per the JAX docs this is a GPU preallocation setting; with the
# platform pinned to "cpu" above it is presumably inert -- confirm.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

In [3]:
import jaxley as jx
import numpy as np
import jax.numpy as jnp
from jaxley.connect import connect
from jaxley.channels import HH, Leak, K, CaT
from jaxley.synapses import TestSynapse, IonotropicSynapse
import matplotlib.pyplot as plt
import pandas as pd
import jax
from jax import vmap, jit

from jaxley.modules.base import View
from jaxley.utils.cell_utils import params_to_pstate, loc_of_index
from jaxley.connect import fully_connect
from jaxley.utils.cell_utils import local_index_of_loc
from copy import copy
from jax import vmap
from jaxley.connect import connectivity_matrix_connect
from typing import List, Tuple, Optional, Any, Union
from copy import deepcopy

### New channel API 
- change param / state naming convention for channels and synapses
- channels and synapses should support *args and **kwargs
- add current_name to synapses
- ensure that args and kwargs in channels functions are same order and standardized (enforce this with tests)
- make channel / synapse name user facing ?

- ensure that channel/synapse currents are prefixed on the module side
- add a `prepare_mechanism` function (assert that the name does not already exist, assert that the mechanism conforms to the build rules)
- are we testing inner vs outer loop of setting states / params in init_states / _step_channel_currents etc?

- warn if states / params do not contain all global params
update_states vs. init_state
try:
    channel.update_states()
    channel.compute_current()
    channel.init_states()
except KeyError:
    warn("Some global param / state seems to be missing")

- change `self.synapses = {}`, `self.channels = {}` to dicts and remove synapse_names, etc.
- make synapse_current_names and membrane_current_names a property
- change recordings to a dict (either `{"v": [0,1,2], ...}` or `{0: ["v", "h"], ...}`)
- merge externals and external_inds into one dict, `{1: {"i": np.zeros(1,1000), "v": np.zeros(1,1000)}}` or `{"i": {"input": np.zeros(1,1000), "indices": [0]}, ...}`

### other things
- jx.integrate(net1, t_max=0) runs one step
- `render` takes all attrs and displays them in a dataframe, keep everything in pytrees

In [625]:
class Module:
    """Node in a nested module tree (prototype of the new module API).

    A module is either a *container*, holding deep copies of its submodules, or
    a *leaf* (``submodules is None``). Leaves own the data: ``_states``,
    ``_params`` and a ``channels`` dict, which the leaf subclass must create.
    Most operations (``states``, ``params``, ``set``, ``insert``, ...) act
    directly on a leaf and fan out to every submodule on a container.

    Submodules can also be reached by name: if the first submodule's ``name``
    is ``"branch"``, then ``module.branch(i)`` is ``module.at(i)`` and
    ``module.branchs`` is the submodule list (the plural is simply
    ``name + "s"``).
    """

    def __init__(self, submodules=(), id=0):
        """Create a module.

        Args:
            submodules: Iterable of modules that are deep-copied to become the
                children, or ``None`` to create a leaf. The default is an empty
                (immutable) tuple, i.e. a container without children.
            id: Index of this module. When the module is used as a submodule,
                its copy's id is offset by its position in the parent.
        """
        self.groups = {}
        self._submodules = None
        self.id = id
        self._states = {}
        self._params = {}

        self.submodules = submodules

    def __repr__(self):
        repr_str = f"{self.name}[{self.id}]"
        if self.submodules is not None:
            repr_str += "".join(f"[{sm}]" for sm in self.submodules)
        return repr_str

    def __str__(self):
        return self.__repr__()

    def __getattr__(self, key):
        """Resolve ``<submodule name>`` and ``<submodule name>s`` dynamically.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: For every other name, so that ``hasattr``, typos and
                errors raised inside properties are not silently turned into
                ``None``.
        """
        # Read from __dict__ directly: going through the `submodules` property
        # would re-enter __getattr__ whenever `_submodules` is not set yet
        # (e.g. on a half-constructed copy).
        submodules = self.__dict__.get("_submodules")
        if submodules and not key.startswith("__"):
            sub_name = submodules[0].name
            if key == sub_name:
                return self.at
            if key == sub_name + "s":
                return submodules
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )

    def __iter__(self):
        """Iterate over the direct submodules (nothing for a leaf)."""
        yield from self.submodules or ()

    @property
    def name(self):
        """Lower-cased class name; used for repr and name-based child access."""
        return self.__class__.__name__.lower()

    @property
    def submodules(self):
        """List of direct children, or ``None`` for a leaf."""
        return self._submodules

    @submodules.setter
    def submodules(self, submodules):
        # Assigning None is a no-op: it never turns a container back into a leaf.
        if submodules is not None:
            # Deep-copy so that `[template] * n` yields n independent children.
            self._submodules = [deepcopy(u) for u in submodules]
            for i, b in enumerate(self._submodules):
                b.id += i

    def _leaf_values(self, own, channel_attr):
        """Merge a leaf's own dict with the `channel_attr` dict of each channel."""
        merged = dict(own)
        for channel in self.channels.values():
            merged.update(getattr(channel, channel_attr))
        return merged

    def _assign_leaf_values(self, own, channel_attr, updates):
        """Write `updates` into `own`, or else into every channel that has the key.

        Keys known to neither the leaf nor any of its channels are ignored.
        """
        for key, value in updates.items():
            if key in own:
                own[key] = value
                continue
            for channel in self.channels.values():
                channel_values = getattr(channel, channel_attr)
                if key in channel_values:
                    channel_values[key] = value

    @property
    def states(self):
        """Leaf: dict of own plus channel states. Container: list of child states."""
        if self.submodules is None:
            return self._leaf_values(self._states, "channel_states")
        return [sm.states for sm in self.submodules]

    @states.setter
    def states(self, dct):
        if self.submodules is None:
            self._assign_leaf_values(self._states, "channel_states", dct)
        else:
            for sm in self.submodules:
                sm.states = dct

    @property
    def params(self):
        """Leaf: dict of own plus channel params. Container: list of child params."""
        if self.submodules is None:
            return self._leaf_values(self._params, "channel_params")
        return [sm.params for sm in self.submodules]

    @params.setter
    def params(self, dct):
        if self.submodules is None:
            self._assign_leaf_values(self._params, "channel_params", dct)
        else:
            for sm in self.submodules:
                sm.params = dct

    def at(self, index):
        """Return the submodule at `index` (``None`` when called on a leaf)."""
        if self.submodules is not None:
            return self.submodules[index]

    def insert(self, mech):
        """Store a deep copy of `mech` under ``mech.name`` in every leaf below."""
        if self.submodules is None:
            self.channels[mech.name] = deepcopy(mech)
        else:
            for sm in self.submodules:
                sm.insert(mech)

    def set(self, key, value):
        """Set state or param `key` to `value` on every leaf that has that key.

        Leaves without the key are left untouched, so a key that only exists on
        some leaves (e.g. a channel state) can be set through a container.
        """
        if self.submodules is None:
            if key in self.states:
                self.states = {key: value}
            elif key in self.params:
                self.params = {key: value}
        else:
            for sm in self.submodules:
                sm.set(key, value)

    def flatten(self):
        """Return all leaves below this module (or ``[self]`` for a leaf), in order."""
        if self.submodules is None:
            return [self]
        submodules = []
        for sm in self.submodules:
            submodules += sm.flatten()
        return submodules

    @property
    def flat(self):
        """A plain ``Module`` whose children are the leaves themselves.

        `_submodules` is assigned directly to bypass the deep-copying setter, so
        changes made through the returned module reach the original leaves.
        """
        flat_module = Module()
        flat_module._submodules = self.flatten()
        return flat_module

    def select(self, index):
        """Like `flat`, restricted to leaves whose flat position is in `index`."""
        flat_module = Module()
        comps = self.flatten()
        flat_module._submodules = [comp for i, comp in enumerate(comps) if i in index]
        return flat_module

    @property
    def tree(self):
        """Nested list of the children's trees.

        Raises:
            NotImplementedError: On a leaf that does not override `tree`. The
                base class has nothing to return there and previously recursed
                until the interpreter's recursion limit was hit.
        """
        if self.submodules is None:
            raise NotImplementedError(
                f"{type(self).__name__} is a leaf module and must override `tree`"
            )
        return [m.tree for m in self.submodules]

    @property
    def nodes(self):
        """DataFrame with one row per leaf, built from the leaves' `tree` dicts."""
        return pd.DataFrame(self.flat.tree)

    @property
    def xyzr(self):
        """Leaf: its `_xyzr` attribute. Container: list of the children's `xyzr`."""
        if self.submodules is None:
            return self._xyzr
        else:
            return [m.xyzr for m in self.submodules]

class Compartment(Module):
    """Leaf module: one compartment with its own states, params and channels."""

    def __init__(self):
        # Passing None as the submodules makes this module a leaf.
        super().__init__(submodules=None)
        self.recordings = dict()
        self.externals = dict()
        self.channels = dict()
        for param_name in ("radius", "length", "capacitance"):
            self._params[param_name] = 1
        self._states["v"] = -70
        # Single row of x, y, z and the compartment radius.
        radius = self.params["radius"]
        self._xyzr = np.array([[0, 0, 0, radius]])

    @property
    def tree(self):
        """Flat dict of all states, then all params, then ``<name>_index``."""
        return {**self.states, **self.params, f"{self.name}_index": self.id}

    @property
    def name(self):
        """Short name used instead of the lower-cased class name."""
        return "comp"


class Branch(Module):
    """Container module whose submodules are the given compartments."""

    def __init__(self, compartments):
        super().__init__(submodules=compartments)
        

class Cell(Module):
    """Container module whose submodules are the given branches.

    `parents` is accepted but not used anywhere in this class yet, and
    `branch_edges` always starts out as an empty list.
    """

    def __init__(self, branches = None, parents = None):
        super().__init__(submodules=branches)
        self.branch_edges = list()
    

class Network(Module):
    """Container of cells plus the synapses connecting their compartments.

    `synapses` is a list in registration order. `states` and `params` return
    the per-cell list produced by `Module`, followed by ONE extra trailing
    dict holding the merged synapse states / params.

    NOTE(review): the previous getters could never succeed -- they called
    ``.values()`` on the `synapses` list and unpacked the per-cell list as a
    mapping -- so the return shape chosen here is a proposal; confirm it
    matches the intended API.
    """

    def __init__(self, cells):
        super().__init__(cells)
        self.synapses = []

    def _synapse_values(self, attr):
        """Merge the `attr` dict of every registered synapse into one dict."""
        merged = {}
        for synapse in self.synapses:
            merged.update(getattr(synapse, attr))
        return merged

    # `Module.<prop>.getter` replaces only the getter, so the setters inherited
    # from Module keep working (`net.states = {...}` still fans out to cells).
    @Module.states.getter
    def states(self):
        """Per-cell states followed by the merged synapse-state dict."""
        return [*super().states, self._synapse_values("synapse_states")]

    @Module.params.getter
    def params(self):
        """Per-cell params followed by the merged synapse-param dict."""
        return [*super().params, self._synapse_values("synapse_params")]

    def connect(self, comp1, comp2, synapse):
        """Connect `comp1` to `comp2` through `synapse` and register the synapse.

        The synapse is only registered once ``synapse.connect`` has returned,
        so a failed connection does not leave it behind in `synapses`.
        """
        synapse.connect(comp1, comp2)
        self.synapses.append(synapse)


In [626]:
# Build a 4 x 4 x 4 hierarchy from one template per level; each container
# deep-copies the templates it is given, so all 64 compartments are independent.
comp = Compartment()
branch = Branch(4 * [comp])
cell = Cell(4 * [branch])
net = Network(4 * [cell])

In [627]:
# States of every leaf compartment in the network, flattened into one list.
net.flat.states

[{'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70},
 {'v': -70}]

In [613]:
# Insert an HH channel into every compartment of branches 0 and 3, then set
# HH_m on the first two compartments of the flattened cell.
cell.branch(0).insert(HH())
cell.branch(3).insert(HH())
cell.select([0, 1]).set("HH_m", 0.123)

In [614]:
# One row per compartment: its states, params and compartment index.
cell.nodes

Unnamed: 0,v,HH_m,HH_h,HH_n,radius,length,capacitance,HH_gNa,HH_gK,HH_gLeak,HH_eNa,HH_eK,HH_eLeak,comp_index
0,-70,0.123,0.2,0.2,1,1,1,0.12,0.036,0.0003,50.0,-77.0,-54.3,0
1,-70,0.123,0.2,0.2,1,1,1,0.12,0.036,0.0003,50.0,-77.0,-54.3,1
2,-70,0.2,0.2,0.2,1,1,1,0.12,0.036,0.0003,50.0,-77.0,-54.3,2
3,-70,0.2,0.2,0.2,1,1,1,0.12,0.036,0.0003,50.0,-77.0,-54.3,3
4,-70,,,,1,1,1,,,,,,,0
5,-70,,,,1,1,1,,,,,,,1
6,-70,,,,1,1,1,,,,,,,2
7,-70,,,,1,1,1,,,,,,,3
8,-70,,,,1,1,1,,,,,,,0
9,-70,,,,1,1,1,,,,,,,1


In [919]:
def get_params_all_trainable(net):
    """Build all parameters of `net` with `HH_gNa` zeroed via trainable params.

    Makes `HH_gNa` trainable on the ``"all"`` selection of cells, branches and
    locs, overwrites the first trainable entry with zeros, and resolves the
    result through `params_to_pstate`.

    Args:
        net: Network exposing `cell`, `get_parameters`, `to_jax`,
            `indices_set_by_trainables` and `get_all_parameters`.

    Returns:
        Whatever ``net.get_all_parameters`` returns for the resulting pstate,
        using the ``"jaxley.thomas"`` voltage solver.
    """
    net.cell("all").branch("all").loc("all").make_trainable("HH_gNa")
    params = net.get_parameters()
    # Zero every entry of the trainable HH_gNa array (functional JAX update).
    params[0]["HH_gNa"] = params[0]["HH_gNa"].at[:].set(0.0)
    net.to_jax()
    pstate = params_to_pstate(params, net.indices_set_by_trainables)
    return net.get_all_parameters(pstate, voltage_solver="jaxley.thomas")

def get_params_set(net):
    """Build all parameters of `net` after setting `HH_gNa` to 0.0 via `set`.

    Args:
        net: Network exposing `set`, `get_parameters`, `to_jax`,
            `indices_set_by_trainables` and `get_all_parameters`.

    Returns:
        Whatever ``net.get_all_parameters`` returns for the resulting pstate,
        using the ``"jaxley.thomas"`` voltage solver.
    """
    net.set("HH_gNa", 0.0)
    # Read the parameters before converting the network with `to_jax`.
    trainable_params = net.get_parameters()
    net.to_jax()
    param_state = params_to_pstate(
        trainable_params, net.indices_set_by_trainables
    )
    return net.get_all_parameters(param_state, voltage_solver="jaxley.thomas")

def SimpleComp():
    """Return a newly constructed `jx.Compartment`."""
    return jx.Compartment()

def SimpleBranch(ncomp):
    """Return a `jx.Branch` built from `ncomp` references to one compartment."""
    template = jx.Compartment()
    return jx.Branch(ncomp * [template])

def SimpleCell(nbranch, ncomp):
    """Return a `jx.Cell` of `nbranch` branches with `ncomp` compartments each.

    Branches are wired as a binary tree: branch 0 is the root (parent ``-1``)
    and branch ``i > 0`` has parent ``(i - 1) // 2``. For up to nine branches
    this reproduces the previously hard-coded ``[-1, 0, 0, 1, 1, 2, 2, 3, 3]``;
    the closed form also yields a parent for every branch beyond the ninth,
    where the fixed list was too short for the number of branches.

    Args:
        nbranch: Number of branches in the cell.
        ncomp: Number of compartments per branch.
    """
    branch = SimpleBranch(ncomp)
    parents = [(i - 1) // 2 if i else -1 for i in range(nbranch)]
    return jx.Cell([branch] * nbranch, parents=parents)

def SimpleNet(ncell, nbranch, ncomp):
    """Return a `jx.Network` built from `ncell` references to one `SimpleCell`.

    Args:
        ncell: Number of cells in the network.
        nbranch: Number of branches per cell.
        ncomp: Number of compartments per branch.
    """
    template = SimpleCell(nbranch, ncomp)
    return jx.Network(ncell * [template])

In [4]:
    # def _filter_by_mech(
    #     self, param_states: Dict, mech: Union[Channel, Synapse]
    # ) -> Dict:
    #     """Filter params/states to include only those relevant to the active mech.

    #     Args:
    #         param_states: The param_states dictionary to filter.
    #         mech: The active mechanism.

    #     Returns:
    #         The filtered dictionary.
    #     """
    #     filter_keys = []
    #     if "v" in param_states:
    #         filter_keys += ["v"] + list(mech.states)
    #     if "radius" in param_states:
    #         module_params = ["radius", "length", "axial_resistivity", "capacitance"]
    #         filter_keys += module_params + list(mech.params)
        
    #     is_global = lambda x: not x.startswith(f"{mech.name}_")
    #     filtered_param_states = {}
    #     for key in filter_keys:
    #         filtered_param_states[key] = param_states[key]
    #         if key in param_states and is_global(key):
    #             param_state_inds = self._inds_of_state_param(key)
    #             filtered_inds = index_of_a_in_b(mech.indices, param_state_inds)
    #             filtered_param_states[key] = filtered_param_states[key][filtered_inds]
        
    #     is_channel = isinstance(mech, Channel)
    #     i_mech = mech.current_name if is_channel else f"i_{mech.name}"
    #     if i_mech in param_states:
    #         filtered_param_states[i_mech] = param_states[i_mech][mech.indices]
    #     return filtered_param_states