# Groups (aka sectionlists)

In many cases, you might want to group several compartments (or branches, or cells) and assign a parameter value or mechanism to the whole group. For example, you might want to define a couple of branches as basal and then insert a Hodgkin-Huxley mechanism only into those branches. Or you might define a couple of cells as fast spiking and give them a high sodium conductance. This tutorial shows how to do this.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Fraction of GPU memory XLA preallocates; set it before JAX initializes a backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

import jax
# I have experienced stability issues with float32, so enable 64-bit floats.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

In [3]:
import time
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import neurax as nx
from neurax.channels import HHChannel
from neurax.synapses import GlutamateSynapse

### Setup

In [4]:
# Number of segments per branch.
nseg_per_branch = 8

# Stimulus.
i_delay = 3.0  # ms
i_amp = 0.05  # nA
i_dur = 2.0  # ms

# Duration and step size.
dt = 0.025  # ms
t_max = 50.0  # ms

time_vec = jnp.arange(0.0, t_max+dt, dt)
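As a sanity check on the setup, the arithmetic below counts the time points and stimulus steps implied by `dt`, `t_max`, `i_delay`, and `i_dur`. `jnp.arange` excludes its stop value, which is why the setup passes `t_max+dt` so that `t_max` itself is included. This is a plain-Python sketch that does not call neurax; building the grid from an integer step count is a choice of this sketch to keep floating-point rounding from changing the number of points.

```python
# Plain-Python sketch of the time grid implied by the setup parameters.
dt = 0.025     # ms
t_max = 50.0   # ms
i_delay = 3.0  # ms
i_dur = 2.0    # ms

# Build the grid from an integer step count so rounding cannot add or drop a point.
n_steps = round(t_max / dt)                         # number of dt intervals
time_points = [k * dt for k in range(n_steps + 1)]  # includes both 0 and t_max

# Step indices at which the current pulse starts and ends.
stim_start_step = round(i_delay / dt)
stim_end_step = round((i_delay + i_dur) / dt)

print(len(time_points))                # 2001
print(stim_start_step, stim_end_step)  # 120 200
print(stim_end_step - stim_start_step) # 80
```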

### Define a network

In [5]:
comp = nx.Compartment()
branch = nx.Branch([comp for _ in range(nseg_per_branch)])
cell = nx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))
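The `parents` array encodes the branch tree, assuming the usual convention that entry `i` is the parent of branch `i` and `-1` marks the root. The plain-Python sketch below (no neurax calls) turns that array into child lists and counts compartments, which helps predict the indices that appear in the tables later in this tutorial.

```python
# Sketch: interpret the parents array used to build `cell`.
parents = [-1, 0, 0, 1, 1]  # parents[i] is the parent branch of branch i; -1 marks the root
nseg_per_branch = 8
ncells = 5

# Collect the children of every branch.
children = {b: [] for b in range(len(parents))}
for branch, parent in enumerate(parents):
    if parent != -1:
        children[parent].append(branch)

ncomp_per_cell = len(parents) * nseg_per_branch
ncomp_network = ncells * ncomp_per_cell

print(children)        # {0: [1, 2], 1: [3, 4], 2: [], 3: [], 4: []}
print(ncomp_per_cell)  # 40
print(ncomp_network)   # 200
```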

In [6]:
_ = np.random.seed(0)
conn_builder = nx.ConnectivityBuilder([cell.total_nbranches for _ in range(5)])
connectivities = [nx.Connectivity(GlutamateSynapse(), conn_builder.fc(np.arange(0, 2), np.arange(2, 5)))]
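Here `conn_builder.fc(...)` connects the cells in `np.arange(0, 2)` (cells 0 and 1) to those in `np.arange(2, 5)` (cells 2, 3, and 4). Assuming `fc` means all-to-all ("fully connected") between the two sets, the cell-level pairs can be enumerated as below. This is a conceptual sketch only; it ignores where on each cell a synapse is placed and is not neurax's internal representation.

```python
from itertools import product

# Sketch: cell-level pairs produced by an all-to-all ("fully connected") rule.
pre_cells = range(0, 2)   # same indices as np.arange(0, 2)
post_cells = range(2, 5)  # same indices as np.arange(2, 5)

pairs = list(product(pre_cells, post_cells))
print(pairs)       # [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4)]
print(len(pairs))  # 6
```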

In [7]:
network = nx.Network([cell for _ in range(5)], connectivities)
network.insert(HHChannel())

### Group: apical dendrites
Assume that, in each of the five neurons in this network, the second and fourth branches are apical dendrites. We can define this as follows:

In [8]:
for cell_ind in range(5):
    network.cell(cell_ind).branch(1).add_to_group("apical")
    network.cell(cell_ind).branch(3).add_to_group("apical")
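Conceptually, a group is a named collection of compartments. The sketch below computes which global compartment indices end up in `apical`, assuming compartments are numbered cell by cell, then branch by branch, then segment by segment (40 per cell, 8 per branch). Under that assumption the first apical compartment is index 8, which matches the first row of the `network.apical.show()` output. The `groups` dictionary and `branch_comps` helper are hypothetical stand-ins, not neurax's internal data structures.

```python
# Hypothetical stand-in for a group store: name -> set of global compartment indices.
nseg_per_branch = 8
nbranches_per_cell = 5
ncomp_per_cell = nseg_per_branch * nbranches_per_cell  # 40

def branch_comps(cell_ind, branch_ind):
    """Global compartment indices of one branch, assuming cell-major, then branch-major numbering."""
    start = cell_ind * ncomp_per_cell + branch_ind * nseg_per_branch
    return list(range(start, start + nseg_per_branch))

groups = {}
for cell_ind in range(5):
    for branch_ind in (1, 3):  # second and fourth branch
        groups.setdefault("apical", set()).update(branch_comps(cell_ind, branch_ind))

apical = sorted(groups["apical"])
print(len(apical))  # 80
print(apical[:5])   # [8, 9, 10, 11, 12]
print(apical[-1])   # 191
```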

After this, we can access `network.apical` just like any other part of the network:

In [9]:
network.apical.set_params("radius", 0.3)
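Setting a parameter on a group only changes the compartments that belong to it. A minimal sketch of that behaviour, using a plain list as a hypothetical per-compartment parameter store; the default radius of 1.0 and the `set_group_param` helper are assumptions for illustration, not neurax values or functions:

```python
# Sketch: set_params on a group as an assignment restricted to the group's compartments.
ncomp_network = 200
default_radius = 1.0  # hypothetical default, for illustration only
radius = [default_radius] * ncomp_network

# Apical compartments: branches 1 and 3 of every cell (8 segments per branch, 40 per cell).
apical = [c * 40 + b * 8 + k for c in range(5) for b in (1, 3) for k in range(8)]

def set_group_param(values, group, new_value):
    """Return a copy of `values` with `new_value` written at every index in `group`."""
    updated = list(values)
    for ind in group:
        updated[ind] = new_value
    return updated

radius = set_group_param(radius, apical, 0.3)
print(sum(r == 0.3 for r in radius))  # 80
print(radius[0], radius[8])           # 1.0 0.3
```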

In [10]:
network.apical.show()

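| | comp_index | branch_index | cell_index | controlled_by_param | length | radius | axial_resistivity | voltages |
|---|---|---|---|---|---|---|---|---|
| 8 | 0 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |
| 9 | 1 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |
| 10 | 2 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |
| 11 | 3 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |
| 12 | 4 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 187 | 3 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |
| 188 | 4 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |
| 189 | 5 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |
| 190 | 6 | 0 | 0 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 |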


### Group: fast spiking
Similarly, you could define a group of fast-spiking cells. Assume that the first, second, and fourth cells are fast-spiking:

In [11]:
network.cell(0).add_to_group("fast_spiking")
network.cell(1).add_to_group("fast_spiking")
network.cell(3).add_to_group("fast_spiking")

In [12]:
network.fast_spiking.set_params("gNa", 0.4)
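The fast-spiking group is defined at the cell level, so it contains every compartment of cells 0, 1, and 3. Assuming compartments are numbered cell by cell with 40 per cell, the plain-Python sketch below lists those compartments and writes `gNa = 0.4` into them. The default `gNa` of 0.12 is an assumption for illustration, not a value read from neurax.

```python
# Sketch: a cell-level group covers every compartment of its member cells.
ncomp_per_cell = 40
ncells = 5
fast_spiking_cells = [0, 1, 3]

fast_spiking = [c * ncomp_per_cell + k for c in fast_spiking_cells for k in range(ncomp_per_cell)]

gNa = [0.12] * (ncells * ncomp_per_cell)  # assumed default, for illustration only
for ind in fast_spiking:
    gNa[ind] = 0.4

print(len(fast_spiking))                  # 120
print(fast_spiking[0], fast_spiking[-1])  # 0 159
print(gNa[80], gNa[120])                  # 0.12 0.4
```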

In [13]:
network.fast_spiking.show("HHChannel")

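| | comp_index | branch_index | cell_index | gNa | gK | gLeak | m | h | n |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |
| 1 | 1 | 0 | 0 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |
| 2 | 2 | 0 | 0 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |
| 3 | 3 | 0 | 0 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |
| 4 | 4 | 0 | 0 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 155 | 155 | 19 | 3 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |
| 156 | 156 | 19 | 3 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |
| 157 | 157 | 19 | 3 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |
| 158 | 158 | 19 | 3 | 0.4 | 0.036 | 0.0003 | 0.2 | 0.2 | 0.2 |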


### How groups are interpreted by `.make_trainable()`
If you make a parameter of a group trainable, it is treated as a single parameter shared by all compartments in the group:

In [14]:
network.fast_spiking.make_trainable("gNa")

Consequently, `get_parameters()` returns only a single trainable parameter, which sets the sodium conductance of every compartment of every fast-spiking neuron:

In [15]:
network.get_parameters()

[{'gNa': DeviceArray([[0.4]], dtype=float64)}]
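One way to picture a shared group parameter: the optimizer sees a single number, which is copied into every compartment of the group before simulating. By the chain rule, the gradient with respect to that number is then the sum of the per-compartment gradients. The plain-Python sketch below checks this with a finite difference on a made-up quadratic loss; the loss, the target value 0.3, the default of 0.12, and the helper names are all hypothetical and chosen only for illustration.

```python
# Sketch: one shared trainable value broadcast to all compartments of a group.
fast_spiking = [c * 40 + k for c in (0, 1, 3) for k in range(40)]  # 120 compartments

def expand_shared(theta, group, ncomp=200, default=0.12):
    """Copy the single trainable value `theta` into every compartment of `group`."""
    values = [default] * ncomp  # default is an assumed placeholder
    for ind in group:
        values[ind] = theta
    return values

def loss(values, target=0.3):
    """Hypothetical loss: squared distance of each compartment's gNa from `target`."""
    return sum((v - target) ** 2 for v in values)

theta = 0.4
# Analytic gradient: sum over group compartments of d/dv (v - target)^2 = 2 (v - target).
analytic = sum(2 * (theta - 0.3) for _ in fast_spiking)

# Central finite difference with respect to the single shared value.
eps = 1e-6
numeric = (loss(expand_shared(theta + eps, fast_spiking))
           - loss(expand_shared(theta - eps, fast_spiking))) / (2 * eps)

print(analytic, numeric)  # both close to 24 (= 120 compartments * 2 * 0.1)
```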

If you instead want a separate parameter for every fast-spiking cell, do not use the group; index the cells directly instead (recall that the fast-spiking neurons have indices 0, 1, and 3). Here we do this for the axial resistivity:

In [16]:
network.cell([0,1,3]).make_trainable("axial_resistivity")

In [17]:
network.get_parameters()

[{'gNa': DeviceArray([[0.4]], dtype=float64)},
 {'axial_resistivity': DeviceArray([[5000.],
               [5000.],
               [5000.]], dtype=float64)}]

This creates three parameters for the axial resistivity, one for each cell.
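By contrast with the shared `gNa`, `network.cell([0,1,3]).make_trainable("axial_resistivity")` yields one trainable value per selected cell, matching the 3×1 array returned by `get_parameters()`. A plain-Python sketch of how such a per-cell vector could be expanded to compartments; the `expand_per_cell` helper and the 40-compartments-per-cell numbering are assumptions for illustration, not neurax internals.

```python
# Sketch: one trainable value per selected cell, each broadcast to that cell's compartments.
ncomp_per_cell = 40
ncells = 5
cells = [0, 1, 3]
params = [[5000.0], [5000.0], [5000.0]]  # same shape as the axial_resistivity entry above

def expand_per_cell(params, cells, default=5000.0):
    """Write params[i][0] into every compartment of cells[i]; other cells keep `default`."""
    values = [default] * (ncells * ncomp_per_cell)
    for (value,), cell_ind in zip(params, cells):
        start = cell_ind * ncomp_per_cell
        for k in range(ncomp_per_cell):
            values[start + k] = value
    return values

# Changing only the second parameter affects only cell 1 (compartments 40 to 79).
params[1][0] = 4000.0
values = expand_per_cell(params, cells)
print(values[0], values[40], values[79], values[80], values[120])  # 5000.0 4000.0 4000.0 5000.0 5000.0
```

Each entry can thus be optimized independently, whereas the group version ties all fast-spiking cells to one value.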