# Groups (aka sectionlists)

In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast-spiking and assign them a high value for the sodium conductance. This tutorial describes how to do this.

In [23]:
%load_ext autoreload
%autoreload 2



In [24]:
# I have experienced stability issues with float32.
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import os
# Fraction of GPU memory that XLA preallocates (no effect on the CPU backend selected above).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

In [25]:
import time
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import jaxley as jx
from jaxley.channels import HH
from jaxley.synapses import GlutamateSynapse

### Setup

In [26]:
# Number of segments per branch.
nseg_per_branch = 8

# Stimulus.
i_delay = 3.0  # ms
i_amp = 0.05  # nA
i_dur = 2.0  # ms

# Duration and step size.
dt = 0.025  # ms
t_max = 50.0  # ms

time_vec = jnp.arange(0.0, t_max+dt, dt)
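Like `np.arange`, `jnp.arange` with a non-integer step derives its length from a floating-point division, so depending on rounding the grid can come out one element longer or shorter than intended. A minimal sketch (plain NumPy, reusing the same `dt` and `t_max`) that builds the grid from an integer number of steps instead, so it always has `n_steps + 1` points from `0` to `t_max`:

```python
import numpy as np

dt = 0.025  # ms
t_max = 50.0  # ms

# Integer number of steps; round() guards against t_max / dt landing at 1999.999...
n_steps = int(round(t_max / dt))

# Grid built from integer indices, so its length is exactly n_steps + 1.
time_vec = np.arange(n_steps + 1) * dt

print(n_steps, len(time_vec), time_vec[-1])
```

The same pattern works with `jnp.arange`, since only integer arithmetic determines the length.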

### Define a network

In [27]:
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))
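In `parents`, entry `i` is the index of the parent of branch `i`, and `-1` marks the root. A plain-Python sketch (independent of jaxley) that reads the tree structure back out of the list used above, together with the resulting compartment count per cell:

```python
# Parent of each branch, as passed to jx.Cell above (-1 marks the root).
parents = [-1, 0, 0, 1, 1]
nseg_per_branch = 8

# Invert the parent list into a children list.
children = {b: [] for b in range(len(parents))}
for branch, parent in enumerate(parents):
    if parent >= 0:
        children[parent].append(branch)

def depth(branch):
    """Number of hops from `branch` up to the root branch."""
    d = 0
    while parents[branch] != -1:
        branch = parents[branch]
        d += 1
    return d

depths = [depth(b) for b in range(len(parents))]

# Every branch has the same number of compartments.
n_comps_per_cell = len(parents) * nseg_per_branch

print(children)          # {0: [1, 2], 1: [3, 4], 2: [], 3: [], 4: []}
print(depths)            # [0, 1, 1, 2, 2]
print(n_comps_per_cell)  # 40
```

So branches 1 and 2 attach to the root branch, and branches 3 and 4 attach to branch 1.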

In [28]:
network = jx.Network([cell for _ in range(5)])

pre = network.cell([0, 1])
post = network.cell([2, 3, 4])
pre.fully_connect(post, GlutamateSynapse())

network.insert(HH())
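As its name suggests, `fully_connect` connects every cell in `pre` to every cell in `post`. Assuming one synapse per pre/post cell pair, two presynaptic and three postsynaptic cells yield six synapses. The pairs can be enumerated with the standard library (this sketch says nothing about where on each cell jaxley attaches the synapse):

```python
from itertools import product

pre_cells = [0, 1]      # network.cell([0, 1])
post_cells = [2, 3, 4]  # network.cell([2, 3, 4])

# Assumption: exactly one synapse for every (pre, post) cell pair.
pairs = list(product(pre_cells, post_cells))

print(len(pairs))  # 6
print(pairs)       # [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4)]
```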

### Group: apical dendrites
Assume that, in each of the five neurons in this network, the second and fourth branches are apical dendrites. We can define this as:

In [29]:
for cell_ind in range(5):
    network.cell(cell_ind).branch(1).add_to_group("apical")
    network.cell(cell_ind).branch(3).add_to_group("apical")

After this, we can access `network.apical` in the same way as we previously accessed cells, e.g. `network.cell([0, 1])`:

In [30]:
network.apical.set("radius", 0.3)

In [31]:
network.apical.show()

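| | comp_index | branch_index | cell_index | length | radius | axial_resistivity | voltages | HH | HH_gNa | HH_gK | ... | HH_eNa | HH_eK | HH_eLeak | HH_m | HH_h | HH_n | global_comp_index | global_branch_index | global_cell_index | controlled_by_param |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 8 | 8 | 1 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 8 | 1 | 0 | 0 |
| 9 | 9 | 1 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 9 | 1 | 0 | 0 |
| 10 | 10 | 1 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 10 | 1 | 0 | 0 |
| 11 | 11 | 1 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 11 | 1 | 0 | 0 |
| 12 | 12 | 1 | 0 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 12 | 1 | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 187 | 187 | 23 | 4 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 187 | 23 | 4 | 0 |
| 188 | 188 | 23 | 4 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 188 | 23 | 4 | 0 |
| 189 | 189 | 23 | 4 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 189 | 23 | 4 | 0 |
| 190 | 190 | 23 | 4 | 10.0 | 0.3 | 5000.0 | -70.0 | True | 0.12 | 0.036 | ... | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 | 190 | 23 | 4 | 0 |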
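The rows in this table follow a simple pattern: branch `b` of cell `c` has global branch index `5*c + b` and owns compartments `8*(5*c + b)` through `8*(5*c + b) + 7`. A plain-NumPy sketch of which compartments the `apical` group contains and what `.set("radius", 0.3)` amounts to, assuming this contiguous numbering (read off the `global_*_index` columns, not taken from jaxley's internals). The 200-element `radius` array is a hypothetical stand-in for the network's per-compartment parameter table, with the default radius of 1.0 seen elsewhere in this tutorial:

```python
import numpy as np

nseg_per_branch = 8
nbranches_per_cell = 5
ncells = 5

def comp_indices(cell, branch):
    """Global compartment indices of one branch (assumes contiguous numbering)."""
    global_branch = cell * nbranches_per_cell + branch
    start = global_branch * nseg_per_branch
    return np.arange(start, start + nseg_per_branch)

# The "apical" group: branches 1 and 3 of every cell.
apical = np.concatenate([comp_indices(c, b) for c in range(ncells) for b in (1, 3)])

# Hypothetical per-compartment radius array; setting the group writes only its indices.
radius = np.ones(ncells * nbranches_per_cell * nseg_per_branch)
radius[apical] = 0.3

print(len(apical), apical[0], apical[-1])  # 80 8 191
```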


### Group: fast spiking
Similarly, you could define a group of fast-spiking cells. Assume that the first, second, and fourth cells are fast-spiking:

In [32]:
network.cell(0).add_to_group("fast_spiking")
network.cell(1).add_to_group("fast_spiking")
network.cell(3).add_to_group("fast_spiking")

In [33]:
network.fast_spiking.set("HH_gNa", 0.4)

In [34]:
network.fast_spiking.show(indices=False)

| | length | radius | axial_resistivity | voltages | HH | HH_gNa | HH_gK | HH_gLeak | HH_eNa | HH_eK | HH_eLeak | HH_m | HH_h | HH_n |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
| 1 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
| 2 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
| 3 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
| 4 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 155 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
| 156 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
| 157 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
| 158 | 10.0 | 1.0 | 5000.0 | -70.0 | True | 0.4 | 0.036 | 0.0003 | 50.0 | -77.0 | -54.3 | 0.2 | 0.2 | 0.2 |
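Since `fast_spiking` was built from whole cells and `apical` from individual branches, the two groups overlap: the apical branches of cells 0, 1 and 3 belong to both, so `network.fast_spiking.set("HH_gNa", 0.4)` also changes `HH_gNa` in those apical compartments. A set-based sketch that counts the overlap, under the same contiguous-numbering assumption as before (40 compartments per cell); representing groups as Python sets is an illustration, not jaxley's storage format:

```python
nseg_per_branch = 8
nbranches_per_cell = 5
ncomps_per_cell = nseg_per_branch * nbranches_per_cell  # 40

def cell_comps(cell):
    """All global compartment indices of one cell (assumes contiguous numbering)."""
    return set(range(cell * ncomps_per_cell, (cell + 1) * ncomps_per_cell))

def branch_comps(cell, branch):
    """Global compartment indices of one branch of one cell."""
    start = (cell * nbranches_per_cell + branch) * nseg_per_branch
    return set(range(start, start + nseg_per_branch))

fast_spiking = set().union(*(cell_comps(c) for c in (0, 1, 3)))
apical = set().union(*(branch_comps(c, b) for c in range(5) for b in (1, 3)))

# Compartments that are members of both groups.
both = fast_spiking & apical

print(len(fast_spiking), len(apical), len(both))  # 120 80 48
```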


### How groups are interpreted by `.make_trainable()`
If you make a parameter of a `group` trainable, it is treated as a single parameter shared by all compartments in that group:

In [35]:
network.fast_spiking.make_trainable("HH_gNa")

Number of newly added trainable parameters: 1. Total number of trainable parameters: 1


As such, `get_parameters()` returns only a single trainable parameter, which sets the sodium conductance of every compartment of every fast-spiking neuron:

In [36]:
network.get_parameters()

[{'HH_gNa': Array([[0.4]], dtype=float64)}]
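Conceptually, sharing means that the single value in the `(1, 1)` array is broadcast to all 120 compartments of the group, and that the gradient with respect to it is the sum of the per-compartment gradients. A toy NumPy illustration of that idea; the quadratic `toy_loss` is hypothetical, central finite differences stand in for `jax.grad`, and this is not jaxley's internal implementation:

```python
import numpy as np

n_comps = 120  # compartments in the fast_spiking group (3 cells x 40)

# Shape (1, 1), matching the printed {'HH_gNa': Array([[0.4]])}.
shared_gna = np.array([[0.4]])

def per_comp_values(param):
    """Broadcast the single shared value to every compartment of the group."""
    return np.broadcast_to(param[0, 0], (n_comps,))

def toy_loss(param):
    """Hypothetical loss: squared distance of each compartment's gNa from 0.12."""
    g = per_comp_values(param)
    return np.sum((g - 0.12) ** 2)

# Shared value => dL/dtheta = sum_i 2 * (theta - 0.12) = n_comps * 2 * (theta - 0.12).
analytic = n_comps * 2 * (0.4 - 0.12)

# Central finite difference on the shared parameter.
eps = 1e-6
numeric = (toy_loss(shared_gna + eps) - toy_loss(shared_gna - eps)) / (2 * eps)

print(analytic, numeric)
```

Both numbers agree (about 67.2): one shared parameter collects the gradient contributions of all compartments it controls.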

If, instead, you want a separate parameter for every fast-spiking cell, do not use the group; select the cells directly instead (recall that the fast-spiking neurons have indices `[0, 1, 3]`). Here, this is illustrated with the axial resistivity:

In [37]:
network.cell([0,1,3]).make_trainable("axial_resistivity")

Number of newly added trainable parameters: 3. Total number of trainable parameters: 4


In [38]:
network.get_parameters()

[{'HH_gNa': Array([[0.4]], dtype=float64)},
 {'axial_resistivity': Array([[5000.],
         [5000.],
         [5000.]], dtype=float64)}]

This generated three parameters for the axial resistivity, each corresponding to one cell.
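With per-cell parameters, row `k` of the `(3, 1)` `axial_resistivity` array controls every compartment of the `k`-th selected cell (0, 1 or 3), so the three cells can take different values. A NumPy sketch of this mapping, assuming the same 40-compartments-per-cell numbering as above; the dictionary `per_comp_axial` and the value `6000.0` are hypothetical, chosen only for illustration:

```python
import numpy as np

ncomps_per_cell = 40  # 5 branches x 8 compartments
fast_cells = [0, 1, 3]

# Mirrors the structure printed by get_parameters() above (values copied from it).
params = [
    {"HH_gNa": np.array([[0.4]])},
    {"axial_resistivity": np.array([[5000.0], [5000.0], [5000.0]])},
]

# Total number of trainable scalars: 1 + 3 = 4, as reported by make_trainable.
n_total = sum(v.size for p in params for v in p.values())

# Give cell 3 (row 2) its own value; cells 0 and 1 keep theirs.
axial = params[1]["axial_resistivity"].copy()
axial[2, 0] = 6000.0

# Row k controls every compartment of cell fast_cells[k].
per_comp_axial = {}
for k, cell in enumerate(fast_cells):
    for comp in range(cell * ncomps_per_cell, (cell + 1) * ncomps_per_cell):
        per_comp_axial[comp] = axial[k, 0]

print(n_total, len(per_comp_axial))  # 4 120
```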