37 changes: 1 addition & 36 deletions brainpy/base/collector.py
@@ -71,7 +71,7 @@ def __sub__(self, other):
     if not isinstance(other, dict):
       raise ValueError(f'Only support dict, but we got {type(other)}.')
     gather = type(self)()
-    for key, val in self.values():
+    for key, val in self.items():
       if key in other:
         if id(val) != id(other[key]):
           raise ValueError(f'Cannot remove {key}, because we got two different values: '
@@ -170,38 +170,3 @@ def dict(self):
   def data(self):
     """Get all data in each value."""
     return [x.value for x in self.values()]
-
-  # @contextmanager
-  # def replicate(self):
-  #   """A context manager to use in a with statement that replicates
-  #   the variables in this collection to multiple devices.
-  #
-  #   Important: replicating also updates the random state in order
-  #   to have a new one per device.
-  #   """
-  #   global math
-  #   if math is None: from brainpy import math
-  #
-  #   replicated, saved_states = {}, {}
-  #   x = jnp.zeros((jax.local_device_count(), 1), dtype=math.float_)
-  #   sharded_x = jax.pmap(lambda x: x, axis_name='device')(x)
-  #   devices = [b.device() for b in sharded_x.device_buffers]
-  #   num_device = len(devices)
-  #   for k, d in self.items():
-  #     if isinstance(d, math.random.RandomState):
-  #       replicated[k] = jax.device_put_sharded([shard for shard in d.split(num_device)], devices)
-  #       saved_states[k] = d.value
-  #     else:
-  #       replicated[k] = jax.device_put_replicated(d.value, devices)
-  #   self.assign(replicated)
-  #   yield
-  #   visited = set()
-  #   for k, d in self.items():
-  #     # Careful not to reduce twice in case of
-  #     # a variable and a reference to it.
-  #     if id(d) not in visited:
-  #       if isinstance(d, math.random.RandomState):
-  #         d.value = saved_states[k]
-  #       else:
-  #         d.value = reduce_func(d)
-  #       visited.add(id(d))
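For context on the commented-out `replicate` helper deleted above: its core step was placing one copy of every variable on each local device (and splitting random states so each device gets its own). A hedged sketch of just that replication step in plain JAX, independent of brainpy:

```python
# Stand-alone sketch of the replication idea behind the deleted helper;
# assumes only the JAX API of that era (jax.device_put_replicated).
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = jnp.zeros((3,))
replicated = jax.device_put_replicated(x, devices)  # one copy per device
print(replicated.shape)  # (len(devices), 3)
```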
155 changes: 84 additions & 71 deletions brainpy/dyn/base.py
@@ -2,7 +2,7 @@
 
 import math as pm
 import warnings
-from typing import Union, Dict, Callable, Sequence
+from typing import Union, Dict, Callable, Sequence, List, Optional
 
 import jax.numpy as jnp
 import numpy as np
@@ -55,12 +55,13 @@ class DynamicalSystem(Base):
   """Global delay variables. Useful when the same target
   variable is used in multiple mappings."""
   global_delay_vars: Dict[str, bm.LengthDelay] = Collector()
+  global_delay_targets: Dict[str, bm.Variable] = Collector()
 
   def __init__(self, name=None):
     super(DynamicalSystem, self).__init__(name=name)
 
     # local delay variables
-    self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()
+    self.local_delay_vars: List[str] = []
 
   def __repr__(self):
     return f'{self.__class__.__name__}(name={self.name})'
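Note the asymmetry this hunk sets up: `global_delay_vars` and `global_delay_targets` are class-level Collectors, shared by every `DynamicalSystem` instance, while `local_delay_vars` becomes a per-instance list of names. A minimal plain-Python sketch of that attribute layout (illustration only, not brainpy code):

```python
# Class-level vs instance-level registries, as plain Python.
class Node:
    shared = {}            # class attribute: one dict shared by all instances

    def __init__(self):
        self.mine = []     # instance attribute: names this node registered

a, b = Node(), Node()
Node.shared['x'] = 1
print(b.shared)            # {'x': 1} -- visible from every instance
print(a.mine is b.mine)    # False -- each instance owns its list
```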
@@ -99,25 +100,22 @@ def __call__(self, *args, **kwargs):
   def register_delay(
       self,
       name: str,
-      delay_step: Union[int, Tensor, Callable, Initializer],
-      delay_target: Union[bm.JaxArray, jnp.ndarray],
+      delay_step: Optional[Union[int, Tensor, Callable, Initializer]],
+      delay_target: bm.Variable,
       initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None,
-      domain: str = 'global'
   ):
     """Register delay variable.
 
     Parameters
     ----------
     name: str
       The delay variable name.
-    delay_step: int, JaxArray, ndarray, callable, Initializer
+    delay_step: Optional, int, JaxArray, ndarray, callable, Initializer
       The number of the steps of the delay.
-    delay_target: JaxArray, ndarray, Variable
-      The target for delay.
+    delay_target: Variable
+      The target variable for delay.
     initial_delay_data: float, int, JaxArray, ndarray, callable, Initializer
       The initializer for the delay data.
-    domain: str
-      The domain of the delay data to store.
 
     Returns
     -------
@@ -130,8 +128,11 @@ def register_delay(
     elif isinstance(delay_step, int):
       delay_type = 'homo'
     elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)):
-      delay_type = 'heter'
-      delay_step = bm.asarray(delay_step)
+      if delay_step.size == 1 and delay_step.ndim == 0:
+        delay_type = 'homo'
+      else:
+        delay_type = 'heter'
+      delay_step = bm.asarray(delay_step)
     elif callable(delay_step):
       delay_step = init_param(delay_step, delay_target.shape, allow_none=False)
       delay_type = 'heter'
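The new branch treats a 0-d array like a plain scalar (a homogeneous delay) instead of unconditionally classifying every array as heterogeneous. A numpy-only sketch of the same dispatch rule (hypothetical helper, not part of the PR):

```python
# Hypothetical helper mirroring the dispatch added above (numpy only).
import numpy as np

def delay_kind(delay_step):
    arr = np.asarray(delay_step)
    # 0-d scalars count as homogeneous; any shaped array is heterogeneous
    return 'homo' if (arr.size == 1 and arr.ndim == 0) else 'heter'

print(delay_kind(5))              # homo
print(delay_kind(np.array(5)))    # homo  (0-d array)
print(delay_kind(np.array([5])))  # heter (1-element but 1-d)
```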
@@ -145,33 +146,29 @@ def register_delay(
                      'then provide us the number of delay steps.')
     if delay_target.shape[0] != delay_step.shape[0]:
       raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
-    max_delay_step = int(bm.max(delay_step))
+    if delay_type != 'none':
+      max_delay_step = int(bm.max(delay_step))
 
-    # delay domain
-    if domain not in ['global', 'local']:
-      raise ValueError('"domain" must be a string in ["global", "local"]. '
-                       f'Bug we got {domain}.')
+    # delay target
+    if not isinstance(delay_target, bm.Variable):
+      raise ValueError(f'"delay_target" must be an instance of Variable, but we got {type(delay_target)}')
 
     # delay variable
-    if domain == 'local':
-      self.local_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
-      self.register_implicit_nodes(self.local_delay_vars)
-    else:
+    self.global_delay_targets[name] = delay_target
+    if delay_type != 'none':
       if name not in self.global_delay_vars:
         self.global_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
         # save into local delay vars when first seen "var",
         # for later update current value!
-        self.local_delay_vars[name] = self.global_delay_vars[name]
+        self.local_delay_vars.append(name)
       else:
         if self.global_delay_vars[name].num_delay_step - 1 < max_delay_step:
           self.global_delay_vars[name].reset(delay_target, max_delay_step, initial_delay_data)
-      self.register_implicit_nodes(self.global_delay_vars)
+    self.register_implicit_nodes(self.global_delay_vars)
     return delay_step

   def get_delay_data(
       self,
       name: str,
-      delay_step: Union[int, bm.JaxArray, jnp.DeviceArray],
+      delay_step: Optional[Union[int, bm.JaxArray, jnp.DeviceArray]],
       *indices: Union[int, bm.JaxArray, jnp.DeviceArray],
   ):
     """Get delay data according to the provided delay steps.
@@ -180,7 +177,7 @@ def get_delay_data(
     ----------
     name: str
       The delay variable name.
-    delay_step: int, JaxArray, ndarray
+    delay_step: Optional, int, JaxArray, ndarray
       The delay length.
     indices: optional, int, JaxArray, ndarray
       The indices of the delay.
@@ -190,54 +187,27 @@ def get_delay_data(
     delay_data: JaxArray, ndarray
       The delay data at the given time.
     """
+    if delay_step is None:
+      return self.global_delay_targets[name]
 
     if name in self.global_delay_vars:
       if isinstance(delay_step, int):
         return self.global_delay_vars[name](delay_step, *indices)
       else:
         if len(indices) == 0:
           indices = (jnp.arange(delay_step.size), )
         return self.global_delay_vars[name](delay_step, *indices)
-
-    elif name in self.local_delay_vars:
-      if isinstance(delay_step, int):
-        return self.local_delay_vars[name](delay_step)
-      else:
-        if len(indices) == 0:
-          indices = (jnp.arange(delay_step.size), )
-        return self.local_delay_vars[name](delay_step, *indices)
     else:
       raise ValueError(f'{name} is not defined in delay variables.')
-
-  def update_delay(
-      self,
-      name: str,
-      delay_data: Union[float, bm.JaxArray, jnp.ndarray]
-  ):
-    """Update the delay according to the delay data.
-
-    Parameters
-    ----------
-    name: str
-      The name of the delay.
-    delay_data: float, JaxArray, ndarray
-      The delay data to update at the current time.
-    """
-    if name in self.local_delay_vars:
-      return self.local_delay_vars[name].update(delay_data)
-    else:
-      if name not in self.global_delay_vars:
-        raise ValueError(f'{name} is not defined in delay variables.')
-
-  def reset_delay(
-      self,
-      name: str,
-      delay_target: Union[bm.JaxArray, jnp.DeviceArray]
-  ):
-    """Reset the delay variable."""
-    if name in self.local_delay_vars:
-      return self.local_delay_vars[name].reset(delay_target)
-    else:
-      if name not in self.global_delay_vars:
-        raise ValueError(f'{name} is not defined in delay variables.')
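Taken together, the surviving methods define the new workflow: register a delay once by name, then read it back with `get_delay_data`. A hypothetical usage sketch; the class, size, and dynamics are invented for illustration, and only `register_delay`/`get_delay_data` come from this diff:

```python
# Hypothetical usage of the delay API after this PR.
import brainpy as bp
import brainpy.math as bm

class WithDelay(bp.dyn.DynamicalSystem):
    def __init__(self, size=10):
        super(WithDelay, self).__init__()
        self.V = bm.Variable(bm.zeros(size))
        # homogeneous 5-step delay on V; keep the returned steps around
        self.delay_step = self.register_delay('V', 5, self.V)

    def update(self, t, dt):
        v_delayed = self.get_delay_data('V', self.delay_step)  # V, 5 steps back
        self.V.value = self.V + 0.1 * (v_delayed - self.V) * dt
```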

   def update(self, t, dt):
     """The function to specify the updating rule.
@@ -297,7 +267,7 @@ def __repr__(self):
     return f'{cls_name}({split.join(children)})'
 
   def update(self, t, dt):
-    """Step function of a network.
+    """Update function of a container.
 
     In this update function, the update functions in children systems are
     iteratively called.
@@ -321,16 +291,6 @@ def __getattr__(self, item):
     else:
       return super(Container, self).__getattribute__(item)
 
-  def reset(self):
-    nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
-    neuron_groups = nodes.subset(NeuGroup)
-    synapse_groups = nodes.subset(TwoEndConn)
-    for node in neuron_groups.values():
-      node.reset()
-    for node in synapse_groups.values():
-      node.reset()
-    for node in (nodes - neuron_groups - synapse_groups).values():
-      node.reset()
 
   @classmethod
   def has(cls, **children_cls):
@@ -370,6 +330,59 @@ class Network(Container):
   def __init__(self, *ds_tuple, name=None, **ds_dict):
     super(Network, self).__init__(*ds_tuple, name=name, **ds_dict)
 
+  def update(self, t, dt):
+    """Step function of a network.
+
+    In this update function, the update functions in children systems are
+    iteratively called.
+    """
+    nodes = self.nodes(level=1, include_self=False)
+    nodes = nodes.subset(DynamicalSystem)
+    nodes = nodes.unique()
+    neuron_groups = nodes.subset(NeuGroup)
+    synapse_groups = nodes.subset(TwoEndConn)
+    other_nodes = nodes - neuron_groups - synapse_groups
+
+    # update synapse nodes
+    for node in synapse_groups.values():
+      node.update(t, dt)
+
+    # update neuron nodes
+    for node in neuron_groups.values():
+      node.update(t, dt)
+
+    # update other types of nodes
+    for node in other_nodes.values():
+      node.update(t, dt)
+
+    # update delays
+    for node in nodes.values():
+      for name in node.local_delay_vars:
+        self.global_delay_vars[name].update(self.global_delay_targets[name].value)
+
+  def reset(self):
+    nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
+    neuron_groups = nodes.subset(NeuGroup)
+    synapse_groups = nodes.subset(TwoEndConn)
+
+    # reset neuron nodes
+    for node in neuron_groups.values():
+      node.reset()
+
+    # reset synapse nodes
+    for node in synapse_groups.values():
+      node.reset()
+
+    # reset other types of nodes
+    for node in (nodes - neuron_groups - synapse_groups).values():
+      node.reset()
+
+    # reset delays
+    for node in nodes.values():
+      for name in node.local_delay_vars:
+        self.global_delay_vars[name].reset(self.global_delay_targets[name])


 class ConstantDelay(DynamicalSystem):
   """Class used to model constant delay variables.
@@ -436,7 +449,7 @@ def __init__(self, size, delay, dtype=None, dt=None, **kwargs):
f"be the same with the delay data size. But "
f"we got {delay.shape[0]} != {self.size[0]}")
delay = bm.around(delay / self.dt)
self.diag = bm.array(bm.arange(self.num), dtype=bm.int_)
self.diag = bm.array(bm.arange(self.num))
self.num_step = bm.array(delay, dtype=bm.uint32) + 1
self.in_idx = bm.Variable(self.num_step - 1)
self.out_idx = bm.Variable(bm.zeros(self.num, dtype=bm.uint32))
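Finally, an end-to-end sketch of the scheduling that `Network.update` now enforces: synapses first, then neurons, then other nodes, then the shared delay buffers are refreshed from their targets. The concrete models (`LIF`, `ExpCOBA`, `All2All`) are assumptions about the brainpy 2.1-era API and are not part of this diff:

```python
# Hedged end-to-end sketch; the model classes are assumed, not from the PR.
import brainpy as bp

pre = bp.dyn.LIF(10)
post = bp.dyn.LIF(10)
syn = bp.dyn.ExpCOBA(pre, post, bp.conn.All2All())

net = bp.dyn.Network(pre=pre, post=post, syn=syn)
net.update(t=0., dt=0.1)  # synapses -> neurons -> others -> delay buffers
net.reset()               # neurons -> synapses -> others -> delay buffers
```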