Skip to content

Commit

Permalink
Centralized targets & patterns of devices (in NEURON) (#153)
Browse files Browse the repository at this point in the history
* Mixed in `PatternlessDevice` for the `VoltageRecorder`

* Centralized initialisation of targets and patterns on master node

* Use new centralized targetting and pattern system in the nrn devices

* fixed targetting test
  • Loading branch information
Helveg committed Oct 30, 2020
1 parent d620a97 commit 37f7305
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 23 deletions.
1 change: 1 addition & 0 deletions bsb/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SuffixTakenError=_e(),
ReceptorSpecificationError=_e(),
),
ParallelIntegrityError=_e("rank"),
),
ConnectivityError=_e(),
MorphologyError=_e(
Expand Down
50 changes: 45 additions & 5 deletions bsb/simulation/targetting.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,28 @@ def _targets_cell_type(self):
"""
cell_types = [self.scaffold.get_cell_type(t) for t in self.cell_types]
if len(cell_types) != 1:
# Compile a list of the different cell type cells.
# Concatenate a list of the different cell type cells.
target_cells = np.array([])
for t in cell_types:
if t.entity:
ids = self.scaffold.get_entities_by_type(t.name)
else:
ids = self.scaffold.get_cells_by_type(t.name)[:, 0]
target_cells = np.hstack((target_cells, ids))
return target_cells
target_cells = np.concatenate((target_cells, ids))
ids = target_cells
else:
# Retrieve a single list
t = cell_types[0]
if t.entity:
ids = self.scaffold.get_entities_by_type(t.name)
else:
ids = self.scaffold.get_cells_by_type(t.name)[:, 0]
return ids
n = len(ids)
# Use the `cell_fraction` or `cell_count` attribute to determine what portion of
# the selected ids to exclude.
r_threshold = getattr(self, "cell_fraction", getattr(self, "cell_count", n) / n)
ids = ids[np.random.random_sample(n) <= r_threshold]
return ids

def _targets_representatives(self):
target_types = [
Expand All @@ -124,7 +129,42 @@ def get_targets(self):
"""
Return the targets of the device.
"""
return self._get_targets()
if hasattr(self, "_targets"):
return self._targets
raise ParallelIntegrityError(
f"MPI process %rank% failed a checkpoint."
+ " `initialise_targets` should always be called before `get_targets` on all MPI processes.",
self.adapter.pc_id,
)

def get_patterns(self):
"""
Return the patterns of the device.
"""
if hasattr(self, "_patterns"):
return self._patterns
raise ParallelIntegrityError(
f"MPI process %rank% failed a checkpoint."
+ " `initialise_patterns` should always be called before `get_patterns` on all MPI processes.",
self.adapter.pc_id,
)

def initialise_targets(self):
if self.adapter.pc_id == 0:
targets = self._get_targets()
else:
targets = None
# Broadcast to make sure all the nodes have the same targets for each device.
self._targets = self.scaffold.MPI.COMM_WORLD.bcast(targets, root=0)

def initialise_patterns(self):
if self.adapter.pc_id == 0:
# Have root 0 prepare the possibly random patterns.
patterns = self.create_patterns()
else:
patterns = None
# Broadcast to make sure all the nodes have the same patterns for each device.
self._patterns = self.scaffold.MPI.COMM_WORLD.bcast(patterns, root=0)

# Define new targetting methods above this line or they will not be registered.
neuron_targetting_types = [s[9:] for s in vars().keys() if s.startswith("_targets_")]
Expand Down
11 changes: 4 additions & 7 deletions bsb/simulators/neuron/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,14 +528,11 @@ def prepare_devices(self):
# CamelCase the snake_case to obtain the class name
device_class = "".join(x.title() for x in device.device.split("_"))
device.__class__ = device_module.__dict__[device_class]
# Re-initialise the device
# TODO: Switch to better config in v4
device.initialise(device.scaffold)
if self.pc_id == 0:
# Have root 0 prepare the possibly random patterns.
patterns = device.create_patterns()
else:
patterns = None
# Broadcast to make sure all the nodes have the same patterns for each device.
device.patterns = self.scaffold.MPI.COMM_WORLD.bcast(patterns, root=0)
device.initialise_targets()
device.initialise_patterns()

def create_devices(self):
for device in self.devices.values():
Expand Down
2 changes: 1 addition & 1 deletion bsb/simulators/neuron/devices/current_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def create_patterns(self):
return self.parameters

def get_pattern(self, target, cell=None, section=None, synapse=None):
return self.patterns
return self.get_patterns()
2 changes: 1 addition & 1 deletion bsb/simulators/neuron/devices/spike_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def create_patterns(self):
return patterns

def get_pattern(self, target, cell=None, section=None, synapse=None):
return self.patterns[target]
return self.get_patterns()[target]


class GeneratorRecorder(PresetPathMixin, PresetMetaMixin, SimulationRecorder):
Expand Down
10 changes: 2 additions & 8 deletions bsb/simulators/neuron/devices/voltage_recorder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ..adapter import NeuronDevice
from ..adapter import NeuronDevice, PatternlessDevice
import numpy as np


class VoltageRecorder(NeuronDevice):
class VoltageRecorder(PatternlessDevice, NeuronDevice):
casts = {"x": float}

def implement(self, target, location):
Expand All @@ -26,9 +26,3 @@ def implement(self, target, location):
)
else:
self.adapter.register_recorder(group, cell, section.record(), section=section)

def create_patterns(self):
pass

def get_pattern(self, target, cell=None, section=None, synapse=None):
pass
7 changes: 6 additions & 1 deletion tests/configs/test_double_neuron_network_relay.json
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,12 @@
"test_representatives": {
"device": "spike_generator",
"io": "input",
"targetting": "representatives"
"targetting": "representatives",
"parameters": {
"start": 0,
"number": 0,
"interval": 20
}
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions tests/test_targetting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@ def test_representatives(self):
"""
Test that 1 cell per non-relay cell model is chosen.
"""
from patch import p

config = JSONConfig(double_nn_config)
scaffold = Scaffold(config)
scaffold.compile_network()
adapter = scaffold.create_adapter("neuron")
adapter.h = p
adapter.load_balance()
device = adapter.devices["test_representatives"]
device.initialise_targets()
targets = adapter.devices["test_representatives"].get_targets()
self.assertEqual(
1,
Expand Down

0 comments on commit 37f7305

Please sign in to comment.