In [1]:
# Reload all imported modules before each cell executes, so edits to library
# code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Set XLA's client memory fraction before jax is imported/configured, so it is
# guaranteed to be in place before the XLA backend initialises. This limits the
# fraction of accelerator memory XLA preallocates; presumably irrelevant while
# running on CPU as configured below.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

from jax import config

# Enable 64-bit dtypes in JAX and pin computations to the CPU backend.
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

In [3]:
import jaxley as jx
import numpy as np
import jax.numpy as jnp
from jaxley.connect import connect
from jaxley.channels import HH, Leak, K, CaT
from jaxley.synapses import TestSynapse, IonotropicSynapse
import matplotlib.pyplot as plt
import pandas as pd
import jax
from jax import vmap, jit

from jaxley.modules.base import View
from jaxley.utils.cell_utils import params_to_pstate, loc_of_index
from jaxley.connect import fully_connect
from jaxley.utils.cell_utils import local_index_of_loc
from copy import copy
from jax import vmap
from jaxley.connect import connectivity_matrix_connect
from typing import List, Tuple, Optional, Any, Union
from copy import deepcopy

### New channel API 
- change param / state naming convention for channels and synapses
- channels and synapses should support *args and **kwargs
- add current_name to synapses
- ensure that args and kwargs in channels functions are same order and standardized (enforce this with tests)
- make channel / synapse names user-facing?

- ensure that channel/synapse currents are prefixed on the module side
- add a `prepare_mechanism` function (assert that the name does not already exist and that the mechanism conforms to the build rules)
- are we testing inner vs outer loop of setting states / params in init_states / _step_channel_currents etc?

- warn if states / params do not contain all global params (cf. `update_states` vs. `init_states`), e.g.:
  ```python
  try:
      channel.update_states()
      channel.compute_current()
      channel.init_states()
  except KeyError:
      warn("Some global param / state seems to be missing")
  ```

- change `self.synapses = {}`, `self.channels = {}` to dicts and remove synapse_names, etc.
- make synapse_current_names and membrane_current_names a property
- change recordings to a dict (either `{"v": [0,1,2], ...}` or `{0: ["v", "h"], ...}`)
- merge externals and external_inds into one dict, e.g. `{1: {"i": np.zeros((1, 1000)), "v": np.zeros((1, 1000))}}` or `{"i": {"input": np.zeros((1, 1000)), "indices": [0]}, ...}`

### other things
- `jx.integrate(net1, t_max=0)` runs one step
- `render` takes all attrs and displays them in a dataframe, keep everything in pytrees

In [996]:
def recurse(f):
    """Lift a per-leaf method ``f`` to a whole module tree.

    The returned wrapper applies ``f`` directly to a leaf (a module whose
    ``submodules`` is ``None``). For a container it recurses into every child
    (via ``iter(self)``) and returns the results as nested lists mirroring
    the hierarchy. If the first child's result is ``None`` (e.g. ``f`` is a
    setter), the container returns ``None`` instead of a list of ``None``s.
    An empty container returns ``[]``.

    Args:
        f: Function taking the module as first argument, plus any extra
            positional/keyword arguments forwarded unchanged.

    Returns:
        The wrapped function (metadata copied from ``f``).
    """
    from functools import wraps

    @wraps(f)
    def wrapper(self, *args, **kwargs):
        if self.submodules is None:
            # Leaf: apply ``f`` directly.
            return f(self, *args, **kwargs)
        out = [wrapper(sm, *args, **kwargs) for sm in self]
        # Use `is None`, not `== None`: leaf results may be numpy/jax arrays,
        # for which `==` is elementwise and cannot be used in an `if`.
        # The `out and` guard keeps empty containers from raising IndexError.
        if out and out[0] is None:
            return None
        return out
    return wrapper

class Module:
    """Node of a hierarchical model (e.g. network > cell > branch > comp).

    A module is either a container, holding a list of ``submodules``, or a
    leaf, with ``submodules is None``. In this notebook the leaves are
    ``Compartment`` objects. Methods decorated with ``@recurse`` run on every
    leaf and return results nested like the hierarchy.
    """

    def __init__(self, submodules=(), index=0):
        """
        Args:
            submodules: Child modules. Each one is deep-copied, so the same
                template instance may be repeated. ``None`` makes this
                module a leaf.
            index: Position of this module within its parent.
        """
        self.groups = {}
        self._submodules = None
        self.index = index
        self._states = {}
        self._params = {}

        self.submodules = submodules
        if submodules is not None:
            # Children are numbered by their position within this module only.
            for i, sm in enumerate(self.submodules):
                sm.index = i

    @property
    def name(self):
        """Lower-cased class name, used for attribute access and repr."""
        return self.__class__.__name__.lower()

    @property
    def submodules(self):
        """List of child modules, or ``None`` for a leaf."""
        return self._submodules

    @submodules.setter
    def submodules(self, submodules):
        # Deep copy so repeated/shared template instances become independent.
        # Assigning ``None`` is a no-op.
        if submodules is not None:
            self._submodules = [deepcopy(u) for u in submodules]

    @property
    @recurse
    def states(self):
        """State values: a dict per leaf, nested lists for containers.

        At a leaf, the module's own ``_states`` are merged with the
        ``channel_states`` of every inserted channel. Channel entries win on
        key collisions.
        """
        channel_states = {}
        for channel in self.channels.values():
            channel_states.update(channel.channel_states)
        return {**self._states, **channel_states}

    @states.setter
    @recurse
    def states(self, dct):
        # Each key goes to the module's own states if present there, otherwise
        # to every inserted channel that has it; unknown keys are ignored.
        for k, v in dct.items():
            if k in self._states:
                self._states[k] = v
            else:
                for c in self.channels.values():
                    if k in c.channel_states:
                        c.channel_states[k] = v

    @property
    @recurse
    def params(self):
        """Parameter values: a dict per leaf, nested lists for containers.

        At a leaf, the module's own ``_params`` are merged with the
        ``channel_params`` of every inserted channel. Channel entries win on
        key collisions.
        """
        channel_params = {}
        for channel in self.channels.values():
            channel_params.update(channel.channel_params)
        return {**self._params, **channel_params}

    @params.setter
    @recurse
    def params(self, dct):
        # Same routing as the ``states`` setter: own params first, then any
        # channel that has the key; unknown keys are ignored.
        for k, v in dct.items():
            if k in self._params:
                self._params[k] = v
            else:
                for c in self.channels.values():
                    if k in c.channel_params:
                        c.channel_params[k] = v

    @property
    def flat(self):
        """Shallow copy whose submodules are every module in this subtree.

        The modules are listed in pre-order, including ``self``. A leaf
        returns itself.
        """
        if self.submodules is not None:
            flat_module = copy(self)
            flat_module._submodules = self.flatten()
            return flat_module
        return self

    @property
    @recurse
    def param_states(self):
        # Leaves must override this property (Compartment does). On a leaf
        # that does not, this would recurse until RecursionError.
        return self.param_states

    @property
    def nodes(self):
        """DataFrame with one row of states and params per leaf.

        For a container, rows are indexed by a MultiIndex of per-level
        positions. A leaf yields a single row with index 0.
        """
        if self.submodules is None:
            return pd.DataFrame(self.flat.param_states, index=[0])

        nodes = []
        node_inds = []
        for inds, sm in self.enumerate():
            nodes.append(sm.param_states)
            node_inds.append(inds)
        return pd.DataFrame(nodes, index=pd.MultiIndex.from_tuples(node_inds))

    @property
    @recurse
    def xyzr(self):
        """Per-leaf coordinate/radius arrays (``_xyzr``), nested like the tree."""
        return self._xyzr

    @property
    def shape(self):
        """Number of children at each level below this module.

        For example, a branch of 4 comps has shape ``(4,)`` and a leaf has
        shape ``()``. Assumes all children share the first child's shape.
        """
        if self.submodules is None:
            return ()
        if not self.submodules:
            return (0,)
        # ``shape`` is a property and must not be called.
        return (len(self),) + self.submodules[0].shape

    def __repr__(self, indent=""):
        repr_str = f"{indent}{self.name}[{self.index}]"
        if self.submodules is not None:
            repr_str += "(\n"
            repr_str += "".join([f"{sm.__repr__(indent + '    ')},\n" for sm in self])
            repr_str += f"{indent})"
        return repr_str

    def __str__(self):
        return self.__repr__()

    def __len__(self):
        # NOTE(review): raises TypeError for leaves (``submodules is None``).
        return len(self.submodules)

    def __getattr__(self, key):
        """Resolve ``<child name>`` and ``<child name>s`` on containers.

        Only called when normal lookup fails. For example, ``cell.branch``
        returns ``cell.at``, so ``cell.branch(0)`` is the first branch, and
        ``cell.branchs`` returns the list of children. Any other name raises
        ``AttributeError`` instead of silently evaluating to ``None``.
        """
        # Private/dunder names are never resolved here. Read ``_submodules``
        # straight from ``__dict__`` so a half-initialised object (e.g. during
        # copy) cannot bounce back into ``__getattr__`` and recurse forever.
        if not key.startswith("_"):
            submodules = self.__dict__.get("_submodules")
            if submodules:
                sub_name = submodules[0].name
                if key == sub_name:
                    return self.at
                if key == sub_name + "s":
                    return submodules
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )

    def __getitem__(self, index):
        return self.at(index)

    def __iter__(self):
        # Leaves yield nothing.
        if self.submodules is not None:
            for sm in self.submodules:
                yield sm

    def enumerate(self):
        """Yield ``(index_tuple, leaf)`` for every leaf, depth-first.

        ``index_tuple`` holds the position at each level below this module.
        It is ``()`` when this module is itself a leaf.
        """
        if self.submodules is None:
            yield (), self
            return
        for i, child in enumerate(self):
            for sub_index, leaf in child.enumerate():
                yield (i, *sub_index), leaf

    def at(self, index):
        """Child at ``index``; ``None`` for a leaf."""
        if self.submodules is not None:
            return self.submodules[index]

    @recurse
    def insert(self, mech):
        """Insert an independent deep copy of ``mech`` into every leaf."""
        self.channels[mech.name] = deepcopy(mech)

    @recurse
    def set(self, key, value):
        """Set ``key`` on every leaf that has it as a state or parameter.

        States take precedence. Leaves without ``key`` are left unchanged.
        """
        if key in self.states:
            self.states = {key: value}
        elif key in self.params:
            self.params = {key: value}

    def flatten(self, comps_only=False):
        """All modules of this subtree in pre-order, or only leaves if ``comps_only``."""
        submodules = sum([sm.flatten(comps_only) for sm in self], [])
        if comps_only and self.submodules is not None:
            return submodules
        return [self] + submodules

    def select(self, index):
        """Shallow copy whose submodules are the leaves at flat positions ``index``.

        Positions follow ``enumerate()`` (depth-first) order. The selected
        leaves are the original objects, not copies, so mutating the
        selection mutates this module.
        """
        flat_module = copy(self)
        comps = [comp for i, (inds, comp) in enumerate(self.enumerate()) if i in index]
        flat_module._submodules = comps
        return flat_module

    @recurse
    def record(self, key):
        # NOTE(review): placeholder buffer of fixed length 1000 per leaf.
        self.recordings[key] = jnp.empty(1000)

    @recurse
    def clamp(self, key, values):
        """Store ``values`` under ``externals[key]`` on every leaf."""
        self.externals[key] = values

    def stimulate(self, values):
        """Clamp key ``"i"`` (presumably the injected current) on every leaf."""
        self.clamp("i", values)

    @recurse
    def move(self, x, y, z):
        """Shift each leaf's xyz coordinates in place; the radius column is untouched."""
        self.xyzr[:, :3] += np.array([x, y, z])

class Compartment(Module):
    """Leaf module: one compartment with its own states, params and mechanisms."""

    def __init__(self):
        # ``None`` submodules marks this module as a leaf.
        super().__init__(None)
        self.recordings = {}
        self.externals = {}
        self.channels = {}
        self._params.update({"radius": 1, "length": 1, "capacitance": 1})
        self._states.update({"v": -70})
        # Float dtype so that ``Module.move`` can shift coordinates in place by
        # non-integer offsets (in-place ``+=`` of floats into an int array raises).
        self._xyzr = np.array([[0, 0, 0, self.params["radius"]]], dtype=float)

    @property
    def param_states(self):
        """States and params merged into one dict; params win on key collisions."""
        leaves = {}
        leaves.update(self.states)
        leaves.update(self.params)
        return leaves

    @property
    def name(self):
        return "comp"

    @property
    def stimuli(self):
        """The stimulus stored under ``externals["i"]``; raises KeyError if none was set."""
        return self.externals["i"]


class Branch(Module):
    """A container of compartments."""

    def __init__(self, compartments):
        super().__init__(compartments)

    def __repr__(self, indent=""):
        """One-line repr listing the contained compartment indices.

        Example: ``branch[0](comp[0,1,2,3])``.
        """
        # Build the join outside the f-string. Reusing double quotes inside an
        # f-string replacement field is only valid from Python 3.12 (PEP 701).
        comp_indices = ",".join(str(sm.index) for sm in self)
        return f"{indent}{self.name}[{self.index}](comp[{comp_indices}])"
        

class Cell(Module):
    """A container of branches, plus optional parent information."""

    def __init__(self, branches=None, parents=None):
        # NOTE(review): ``parents`` is stored verbatim and not used in this
        # notebook; presumably the parent branch of each branch. Confirm format.
        super().__init__(submodules=branches)
        self.parents = parents


class Network(Module):
    """Top-level module: a collection of cells plus the synapses between them."""

    def __init__(self, cells):
        super().__init__(cells)
        # One deep-copied synapse object per connection, in insertion order.
        self.synapses = []

    @property
    def synapse_states(self):
        """The ``synapse_states`` of each synapse, in connection order."""
        return [syn.synapse_states for syn in self.synapses]

    @property
    def synapse_params(self):
        """The ``synapse_params`` of each synapse, in connection order."""
        return [syn.synapse_params for syn in self.synapses]

    @property
    def states(self):
        # Nested per-cell module states, followed by one entry per synapse.
        module_states = super().states
        return module_states + self.synapse_states

    @property
    def params(self):
        # Nested per-cell module params, followed by one entry per synapse.
        module_params = super().params
        return module_params + self.synapse_params

    def connect(self, comp1, comp2, synapse):
        """Add a copy of ``synapse`` from ``comp1`` (pre) to ``comp2`` (post)."""
        # NOTE(review): ``index`` is a module's position within its immediate
        # parent only, so pre/post do not uniquely identify a compartment in
        # the network. TODO confirm the intended indexing.
        new_syn = deepcopy(synapse)
        new_syn.pre = comp1.index
        new_syn.post = comp2.index
        self.synapses.append(new_syn)

    @property
    def edges(self):
        """DataFrame with one row per synapse (params merged with states)."""
        rows = [
            {**syn_params, **syn_states}
            for syn_params, syn_states in zip(self.synapse_params, self.synapse_states)
        ]
        return pd.DataFrame(rows)


In [997]:
# Build a 4 x 4 x 4 hierarchy from template objects. Each constructor
# deep-copies its children, so repeating one instance is safe.
comp = Compartment()
branch = Branch([comp for _ in range(4)])
cell = Cell([branch for _ in range(4)])
net = Network([cell for _ in range(4)])

In [1001]:
# Display the module hierarchy (rendered by Module.__repr__).
net

network[0](
    cell[0](
        branch[0](comp[0,1,2,3]),
        branch[1](comp[0,1,2,3]),
        branch[2](comp[0,1,2,3]),
        branch[3](comp[0,1,2,3]),
    ),
    cell[1](
        branch[0](comp[0,1,2,3]),
        branch[1](comp[0,1,2,3]),
        branch[2](comp[0,1,2,3]),
        branch[3](comp[0,1,2,3]),
    ),
    cell[2](
        branch[0](comp[0,1,2,3]),
        branch[1](comp[0,1,2,3]),
        branch[2](comp[0,1,2,3]),
        branch[3](comp[0,1,2,3]),
    ),
    cell[3](
        branch[0](comp[0,1,2,3]),
        branch[1](comp[0,1,2,3]),
        branch[2](comp[0,1,2,3]),
        branch[3](comp[0,1,2,3]),
    ),
)

In [486]:
# Store jnp.ones(100) under externals["i"] (presumably an injected current)
# on every compartment of branch 0.
cell.branch(0).stimulate(jnp.ones(100))

In [447]:
# net.connect(net.at(0).at(0), net.at(0).at(1), IonotropicSynapse())
# net.connect(net.at(0).at(0), net.at(0).at(2), TestSynapse())

In [448]:
# Insert an independent deep copy of an HH channel into every compartment of
# branches 0 and 3.
cell.branch(0).insert(HH())
cell.branch(3).insert(HH())
# Set HH_m on the first two compartments in depth-first order, i.e. comps 0
# and 1 of branch 0; the selection shares those compartments with `cell`.
cell.select([0, 1]).set("HH_m", 0.123)

In [450]:
# One row per compartment, indexed by (branch, comp); compartments without a
# given channel show NaN in that channel's columns.
cell.nodes

Unnamed: 0,Unnamed: 1,v,HH_m,HH_h,HH_n,radius,length,capacitance,HH_gNa,HH_gK,HH_gLeak,HH_eNa,HH_eK,HH_eLeak
0,0,-70,0.123,0.2,0.2,1,1,1,0.12,0.036,0.0003,50.0,-77.0,-54.3
0,1,-70,0.123,0.2,0.2,1,1,1,0.12,0.036,0.0003,50.0,-77.0,-54.3
0,2,-70,0.2,0.2,0.2,1,1,1,0.12,0.036,0.0003,50.0,-77.0,-54.3
0,3,-70,0.2,0.2,0.2,1,1,1,0.12,0.036,0.0003,50.0,-77.0,-54.3
1,0,-70,,,,1,1,1,,,,,,
1,1,-70,,,,1,1,1,,,,,,
1,2,-70,,,,1,1,1,,,,,,
1,3,-70,,,,1,1,1,,,,,,
2,0,-70,,,,1,1,1,,,,,,
2,1,-70,,,,1,1,1,,,,,,
