From 4e2b3d335c0faf70cff37ccdd39a41bf88df6a08 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 30 May 2022 18:01:17 +0800 Subject: [PATCH 1/9] remove `brainpy.running.Minitor`. Using `brainpy.tools.DotDict` instead. --- brainpy/analysis/lowdim/lowdim_analyzer.py | 2 +- brainpy/dyn/runners.py | 104 +++++----- brainpy/integrators/runner.py | 16 +- brainpy/running/__init__.py | 1 - brainpy/running/monitor.py | 221 --------------------- brainpy/running/runner.py | 180 +++++++++-------- brainpy/tools/others/dicts.py | 8 +- brainpy/train/runners/base_runner.py | 1 + brainpy/train/runners/offline_trainer.py | 7 +- brainpy/train/runners/online_trainer.py | 15 +- examples/simulation/Brette_2007_COBA.py | 3 +- examples/simulation/Brette_2007_COBAHH.py | 3 +- 12 files changed, 182 insertions(+), 379 deletions(-) delete mode 100644 brainpy/running/monitor.py diff --git a/brainpy/analysis/lowdim/lowdim_analyzer.py b/brainpy/analysis/lowdim/lowdim_analyzer.py index 0af9e672e..8fe7ec713 100644 --- a/brainpy/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/analysis/lowdim/lowdim_analyzer.py @@ -206,7 +206,7 @@ def __init__(self, # 'x_by_y_in_fy' : # 'y_by_x_in_fx' : # 'x_by_y_in_fx' : - self.analyzed_results = tools.DictPlus() + self.analyzed_results = tools.DotDict() def show_figure(self): global pyplot diff --git a/brainpy/dyn/runners.py b/brainpy/dyn/runners.py index 44e81d9d6..7f786aaa0 100644 --- a/brainpy/dyn/runners.py +++ b/brainpy/dyn/runners.py @@ -50,7 +50,7 @@ def check_and_format_inputs(host, inputs): if not isinstance(inputs, (tuple, list)): raise RunningError('"inputs" must be a tuple/list.') if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)): - if isinstance(inputs[0], str): + if isinstance(inputs[0], (str, bm.Variable)): inputs = [inputs] else: raise RunningError('Unknown input structure, only support inputs ' @@ -76,32 +76,35 @@ def check_and_format_inputs(host, inputs): # checking 1: absolute access # Check whether the input target node is accessible, # and check whether the target node has the attribute - nodes = host.nodes(method='absolute') - nodes[host.name] = host + nodes = host.nodes(method='absolute', level=-1, include_self=True) for one_input in inputs: key = one_input[0] - if not isinstance(key, str): + if isinstance(key, bm.Variable): + real_target = key + elif isinstance(key, str): + splits = key.split('.') + target = '.'.join(splits[:-1]) + key = splits[-1] + if target == '': + real_target = host + else: + if target not in nodes: + inputs_not_found_target.append(one_input) + continue + real_target = nodes[target] + if not hasattr(real_target, key): + raise RunningError(f'Input target key "{key}" is not defined in {real_target}.') + real_target = getattr(real_target, key) + else: raise RunningError(f'For each input, input[0] must be a string to ' f'specify variable of the target, but we got {key}.') - splits = key.split('.') - target = '.'.join(splits[:-1]) - key = splits[-1] - if target == '': - real_target = host - else: - if target not in nodes: - inputs_not_found_target.append(one_input) - continue - real_target = nodes[target] - if not hasattr(real_target, key): - raise RunningError(f'Input target key "{key}" is not defined in {real_target}.') - inputs_which_found_target.append((real_target, key) + tuple(one_input[1:])) + inputs_which_found_target.append((real_target, ) + tuple(one_input[1:])) # checking 2: relative access # Check whether the input target node is accessible # and check whether the target node has the attribute if 
len(inputs_not_found_target): - nodes = host.nodes(method='relative') + nodes = host.nodes(method='relative', level=-1, include_self=True) for one_input in inputs_not_found_target: splits = one_input[0].split('.') target, key = '.'.join(splits[:-1]), splits[-1] @@ -110,38 +113,39 @@ def check_and_format_inputs(host, inputs): real_target = nodes[target] if not hasattr(real_target, key): raise RunningError(f'Input target key "{key}" is not defined in {real_target}.') - inputs_which_found_target.append((real_target, key) + tuple(one_input[1:])) + real_target = getattr(real_target, key) + inputs_which_found_target.append((real_target, ) + tuple(one_input[1:])) # 3. format inputs # --------- formatted_inputs = [] for one_input in inputs_which_found_target: # input value - data_value = one_input[2] + data_value = one_input[1] # input type - if len(one_input) >= 4: - if one_input[3] == 'iter': + if len(one_input) >= 3: + if one_input[2] == 'iter': if not isinstance(data_value, Iterable): - raise ValueError(f'Input "{data_value}" for "{one_input[0]}.{one_input[1]}" ' + raise ValueError(f'Input "{data_value}" for "{one_input[0]}" \n' f'is set to be "iter" type, however we got the value with ' f'the type of {type(data_value)}') - elif one_input[3] == 'func': + elif one_input[2] == 'func': if not callable(data_value): - raise ValueError(f'Input "{data_value}" for "{one_input[0]}.{one_input[1]}" ' + raise ValueError(f'Input "{data_value}" for "{one_input[0]}" \n' f'is set to be "func" type, however we got the value with ' f'the type of {type(data_value)}') - elif one_input[3] != 'fix': + elif one_input[2] != 'fix': raise RunningError(f'Only support {SUPPORTED_INPUT_TYPE} input type, but ' - f'we got "{one_input[3]}" in {one_input}') + f'we got "{one_input[2]}"') - data_type = one_input[3] + data_type = one_input[2] else: data_type = 'fix' # operation - if len(one_input) == 5: - data_op = one_input[4] + if len(one_input) == 4: + data_op = one_input[3] else: data_op = '+' if data_op not in SUPPORTED_INPUT_OPS: @@ -149,7 +153,7 @@ def check_and_format_inputs(host, inputs): f'{data_op} in {one_input}') # final - format_inp = one_input[:2] + (data_value, data_type, data_op) + format_inp = (one_input[0], data_value, data_type, data_op) formatted_inputs.append(format_inp) return formatted_inputs @@ -178,6 +182,8 @@ class DSRunner(Runner): for example ``[(target1, value1), (target2, value2)]``. 
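    A minimal sketch of such an ``inputs`` specification (the target names,
    the ``currents`` array and the constant values below are illustrative
    only, not part of the API):

    >>> DSRunner(net,
    >>>          inputs=[('E.input', 20.),               # fixed value, added at every step
    >>>                  ('I.input', currents, 'iter'),  # an iterable, one value per time step
    >>>                  ('E.V', -60., 'fix', '=')])     # fixed value assigned with '='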
""" + target: DynamicalSystem + def __init__( self, target: DynamicalSystem, @@ -228,11 +234,10 @@ def build_inputs(self, inputs): func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - for target, key, value, type_, op in inputs: + for variable, value, type_, op in inputs: # variable - variable = getattr(target, key) if not isinstance(variable, bm.Variable): - raise RunningError(f'"{key}" in {target} is not a dynamically changed Variable, ' + raise RunningError(f'{variable}\n is not a dynamically changed Variable, ' f'its value will not change, we think there is no need to ' f'give its input.') @@ -332,10 +337,10 @@ def predict( xs = (times, xs,) # reset the model states if reset_state: - self.target.reset_batch_state(num_batch) + self.target.reset_state(num_batch) # init monitor - for key in self.mon.item_contents.keys(): - self.mon.item_contents[key] = [] # reshape the monitor items + for key in self.mon.var_names: + self.mon[key] = [] # reshape the monitor items # init progress bar if self.progress_bar and progress_bar: self._pbar = tqdm.auto.tqdm(total=num_step) @@ -349,9 +354,11 @@ def predict( self._pbar.close() # post-running for monitors for key, val in hists.items(): - self.mon.item_contents[key] = val + self.mon[key] = val if self.numpy_mon_after_run: - self.mon.numpy() + self.mon['ts'] = np.asarray(self.mon['ts']) + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) return outputs def _predict( @@ -401,8 +408,8 @@ def run(self, duration, start_t=None, shared_args: Dict = None, eval_time=False) # times times = jax.device_put(jnp.arange(start_t, end_t, self.dt)) # build monitor - for key in self.mon.item_contents.keys(): - self.mon.item_contents[key] = [] # reshape the monitor items + for key in self.mon.var_names: + self.mon[key] = [] # reshape the monitor items # running if self.progress_bar: self._pbar = tqdm.auto.tqdm(total=times.size) @@ -416,17 +423,14 @@ def run(self, duration, start_t=None, shared_args: Dict = None, eval_time=False) if self.progress_bar: self._pbar.close() # post-running - if self.jit: - self.mon.ts = times + self.dt - for key in self.mon.item_names: - self.mon.item_contents[key] = bm.asarray(hists[key]) - else: - self.mon.ts = times + self.dt - for key in self.mon.item_names: - self.mon.item_contents[key] = bm.asarray(self.mon.item_contents[key]) + self.mon.ts = times + self.dt + for key in hists.keys(): + self.mon[key] = bm.asarray(hists[key]) self._start_t = end_t if self.numpy_mon_after_run: - self.mon.numpy() + self.mon['ts'] = np.asarray(self.mon['ts']) + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) if eval_time: return running_time, outputs else: diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py index 7ffb14a3e..4e3954093 100644 --- a/brainpy/integrators/runner.py +++ b/brainpy/integrators/runner.py @@ -167,7 +167,7 @@ def __init__( self._dyn_args.update(dyn_args) # monitors - for k in self.mon.item_names: + for k in self.mon.var_names: if k not in self.target.variables and k not in self.fun_monitors: raise MonitorError(f'Variable "{k}" to monitor is not defined ' f'in the integrator {self.target}.') @@ -203,12 +203,12 @@ def __init__( _loop_func = bm.make_loop( self._step, dyn_vars=self.dyn_vars, - out_vars={k: self.variables[k] for k in self.mon.item_names}, + out_vars={k: self.variables[k] for k in self.monitors.keys()}, has_return=True ) else: def _loop_func(times): - out_vars = {k: [] for k in 
self.mon.item_names} + out_vars = {k: [] for k in self.monitors.keys()} returns = {k: [] for k in self.fun_monitors.keys()} for i in range(len(times)): _t = times[i] @@ -219,16 +219,16 @@ def _loop_func(times): # step call self._step(_t) # variable monitors - for k in self.mon.item_names: + for k in self.monitors.keys(): out_vars[k].append(bm.as_device_array(self.variables[k])) - out_vars = {k: bm.asarray(out_vars[k]) for k in self.mon.item_names} + out_vars = {k: bm.asarray(out_vars[k]) for k in self.monitors.keys()} return out_vars, returns self.step_func = _loop_func def _post(self, times, returns: dict): # monitor self.mon.ts = times + self.dt for key in returns.keys(): - self.mon.item_contents[key] = bm.asarray(returns[key]) + self.mon[key] = bm.asarray(returns[key]) def _step(self, t): # arguments @@ -296,6 +296,8 @@ def run(self, duration, start_t=None, eval_time=False): self._post(times, hists) self._start_t = end_t if self.numpy_mon_after_run: - self.mon.numpy() + self.mon.ts = np.asarray(self.mon.ts) + for key in returns.keys(): + self.mon[key] = np.asarray(self.mon[key]) if eval_time: return running_time diff --git a/brainpy/running/__init__.py b/brainpy/running/__init__.py index 0c08e726f..b6d7d1e23 100644 --- a/brainpy/running/__init__.py +++ b/brainpy/running/__init__.py @@ -6,5 +6,4 @@ """ from .parallel import * -from .monitor import * from .runner import * diff --git a/brainpy/running/monitor.py b/brainpy/running/monitor.py deleted file mode 100644 index 8284c38a1..000000000 --- a/brainpy/running/monitor.py +++ /dev/null @@ -1,221 +0,0 @@ -# -*- coding: utf-8 -*- - -import numpy as np - -from brainpy import math as bm -from brainpy.errors import MonitorError - -__all__ = [ - 'Monitor' -] - - -class Monitor(object): - """The basic Monitor class to store the past variable trajectories. - - Currently, :py:class:`brainpy.simulation.Monitor` support to specify: - - - variable key by `strings`. - - variable index by `None`, `int`, `list`, `tuple`, `1D array/tensor` - (==> all will be transformed into a 1D array/tensor) - - variable monitor interval by `None`, `int`, `float` - - Users can instance a monitor object by multiple ways: - - 1. list of strings. - - >>> Monitor(variables=['a', 'b', 'c']) - - 1.1. list of strings and list of intervals - - >>> Monitor(variables=['a', 'b', 'c'], - >>> intervals=[None, 1, 2] # ms - >>> ) - - 2. list of strings and string + indices - - >>> Monitor(variables=['a', ('b', bm.array([1,2,3])), 'c']) - - 2.1. list of string (+ indices) and list of intervals - - >>> Monitor(variables=['a', ('b', bm.array([1,2,3])), 'c'], - >>> intervals=[None, 2, 3]) - - 3. a dictionary with the format of {key: indices} - - >>> Monitor(variables={'a': None, 'b': bm.array([1,2,3])}) - - 3.1. a dictionary of variable and indexes, and a dictionary of time intervals - - >>> Monitor(variables={'a': None, 'b': bm.array([1,2,3])}, - >>> intervals={'b': 2.}) - - .. note:: - :py:class:`brainpy.simulation.Monitor` records any target variable with an - two-dimensional array/tensor with the shape of `(num_time_step, variable_size)`. - This means for any variable, no matter what's the shape of the data - (int, float, vector, matrix, 3D array/tensor), will be reshaped into a - one-dimensional vector. 
- - """ - - _KEYWORDS = ['_KEYWORDS', 'target', 'vars', 'intervals', 'ts', 'num_item', - 'item_names', 'item_indices', 'item_intervals', 'item_contents', - 'has_build'] - - def __init__(self, variables, intervals=None): - if isinstance(variables, (list, tuple)): - if intervals is not None: - if not isinstance(intervals, (list, tuple)): - raise MonitorError(f'"vars" and "intervals" must be the same type. ' - f'While we got type(vars)={type(variables)}, ' - f'type(intervals)={type(intervals)}.') - if len(variables) != len(intervals): - raise MonitorError(f'The length of "vars" and "every" are not equal.') - - elif isinstance(variables, dict): - if intervals is not None: - if not isinstance(intervals, dict): - raise MonitorError(f'"vars" and "every" must be the same type. ' - f'While we got type(vars)={type(variables)}, ' - f'type(intervals)={type(intervals)}.') - for key in intervals.keys(): - if key not in variables: - raise MonitorError(f'"{key}" is not in "vars": {list(variables.keys())}') - - else: - raise MonitorError(f'We only supports a format of list/tuple/dict of ' - f'"vars", while we got {type(variables)}.') - - self.has_build = False - self.ts = None - self.vars = variables - self.intervals = intervals - self.item_names = [] - self.item_indices = [] - self.item_intervals = [] - self.item_contents = dict() - self.num_item = len(variables) - self.build() - - def __repr__(self): - return (f'{self.__class__.__name__}(items={tuple(self.item_names)}, ' - f'indices={self.item_indices})') - - def build(self): - if not self.has_build: - item_names = [] - item_indices = [] - item_contents = dict() - - if isinstance(self.vars, (list, tuple)): - if self.intervals is None: - item_intervals = [None] * len(self.vars) - else: - item_intervals = list(self.intervals) - - for mon_var, interval in zip(self.vars, item_intervals): - # users monitor a variable by a string - if isinstance(mon_var, str): - mon_key = mon_var - mon_idx = None - # users monitor a variable by a tuple: `('b', bm.array([1,2,3]))` - elif isinstance(mon_var, (tuple, list)): - mon_key = mon_var[0] - mon_idx = mon_var[1] - else: - raise MonitorError(f'Unknown monitor item: {str(mon_var)}') - - # self.check(mon_key) - item_names.append(mon_key) - item_indices.append(mon_idx) - item_contents[mon_key] = [] - if interval is not None: - item_contents[f'{mon_key}.t'] = [] - - elif isinstance(self.vars, dict): - item_intervals = [] - # users monitor a variable by a dict: `{'a': None, 'b': bm.array([1,2,3])}` - for mon_key, mon_idx in self.vars.items(): - item_names.append(mon_key) - item_indices.append(mon_idx) - item_contents[mon_key] = [] - if self.intervals is None: - item_intervals.append(None) - else: - if mon_key in self.intervals: - item_intervals.append(self.intervals[mon_key]) - if self.intervals[mon_key] is not None: - item_contents[f'{mon_key}.t'] = [] - else: - raise MonitorError(f'Unknown monitors type: {type(self.vars)}') - - self.item_names = item_names - self.item_indices = item_indices - self.item_intervals = item_intervals - self.item_contents = item_contents - self.num_item = len(item_contents) - self.has_build = True - - def get(self, key): - if key not in self.item_contents: - raise ValueError(f'{key} is not defined in {self}') - return self.item_contents[key] - - def __getitem__(self, item: str): - """Get item in the monitor values. - - Parameters - ---------- - item : str - - Returns - ------- - value : ndarray - The monitored values. 
- """ - item_contents = super(Monitor, self).__getattribute__('item_contents') - if item not in item_contents: - raise ValueError(f'Do not have "{item}". Available items are:\n' - f'{list(item_contents.keys())}') - return item_contents[item] - - def __setitem__(self, key, value): - """Get item value in the monitor. - - Parameters - ---------- - key : str - The item key. - value : ndarray - The item value. - """ - item_contents = super(Monitor, self).__getattribute__('item_contents') - if key not in item_contents: - raise ValueError(f'Do not have "{key}". Available items are:\n' - f'{list(item_contents.keys())}') - self.item_contents[key] = value - - def __getattr__(self, item): - if item in self._KEYWORDS: - return super(Monitor, self).__getattribute__(item) - else: - item_contents = super(Monitor, self).__getattribute__('item_contents') - if item in item_contents: - return item_contents[item] - else: - super(Monitor, self).__getattribute__(item) - - def __setattr__(self, key, value): - if key in self._KEYWORDS: - object.__setattr__(self, key, value) - elif key in self.item_contents: - self.item_contents[key] = value - else: - object.__setattr__(self, key, value) - - def numpy(self): - for key, val in self.item_contents.items(): - self.item_contents[key] = np.asarray(val) - if self.ts is not None: - self.ts = np.asarray(self.ts) diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index 7a98ea963..31237a884 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -3,12 +3,14 @@ import types from typing import Callable, Dict, Sequence, Union +import numpy as np + +from brainpy import math as bm from brainpy.base import Base from brainpy.base.collector import TensorCollector from brainpy.errors import MonitorError, RunningError from brainpy.tools.checking import check_dict_data -from brainpy import math as bm -from .monitor import Monitor +from brainpy.tools.others import DotDict __all__ = [ 'Runner', @@ -22,14 +24,32 @@ class Runner(object): ---------- target: Any The target model. - monitors: None, list of str, tuple of str, Monitor + monitors: None, sequence of str, dict, Monitor Variables to monitor. + + - A list of string. Like `monitors=['a', 'b', 'c']` + - A list of string with index specification. Like `monitors=[('a', 1), ('b', [1,3,5]), 'c']` + - A dict with the explicit monitor target, like: `monitors={'a': model.spike, 'b': model.V}` + - A dict with the index specification, like: `monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}` + + fun_monitors: dict + Monitoring variables by callable functions. Should be a dict. + The `key` should be a string for later retrieval by `runner.mon[key]`. + The `value` should be a callable function which receives two arguments: `t` and `dt`. jit: bool, dict + The JIT settings. progress_bar: bool + Use progress bar or not? dyn_vars: Optional, dict + The dynamically changed variables. Instance of :py:class:`~.Variable`. numpy_mon_after_run : bool + Transform the JAX arrays into numpy ndarray or not, when finishing the network running? 
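+
+    A short sketch of how the declared monitors are read back after a run,
+    through the ``DotDict``-based ``mon`` attribute (the subclass ``DSRunner``,
+    the network ``net`` and the variable names here are illustrative):
+
+    >>> runner = DSRunner(net, monitors=['E.spike', ('E.V', [0, 1])])
+    >>> runner.run(100.)
+    >>> runner.mon.var_names      # keys of all declared monitors
+    >>> runner.mon['E.spike']     # recorded history of 'E.spike'
+    >>> runner.mon.ts             # the monitored time points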
""" + mon: DotDict + jit: Dict[str, bool] + target: Base + def __init__( self, target: Base, @@ -54,18 +74,75 @@ def __init__( else: raise ValueError(f'Unknown "jit" setting: {jit}') - # monitors if monitors is None: - self.mon = Monitor(variables=[]) - elif isinstance(monitors, (list, tuple, dict)): - self.mon = Monitor(variables=monitors) - elif isinstance(monitors, Monitor): - self.mon = monitors - self.mon.target = self + monitors = dict() + elif isinstance(monitors, (list, tuple)): + # format string monitors + _monitors = [] + for mon in monitors: + if isinstance(mon, str): + _monitors.append((mon, None)) + elif isinstance(mon, (tuple, list)): + if not isinstance(mon[0], str) and len(mon) == 2: + raise MonitorError(f'We expect the monitor format with (name, index). But we got {mon}') + if isinstance(mon[1], (int, np.integer)): + idx = bm.array([mon[1]]) + else: + idx = mon[1] + _monitors.append((mon[0], idx)) + else: + raise MonitorError(f'We do not support monitor with {type(mon)}: {mon}') + + # get monitor targets + monitors = {} + name2node = {node.name: node for node in list(target.nodes(level=-1).unique().values())} + for mon in _monitors: + key, index = mon[0], mon[1] + splits = key.split('.') + if len(splits) == 1: + if not hasattr(target, splits[0]): + raise RunningError(f'{target} does not has variable {key}.') + monitors[key] = (getattr(target, splits[-1]), index) + else: + if not hasattr(target, splits[0]): + if splits[0] not in name2node: + raise MonitorError(f'Cannot find target {key} in monitor of {target}, please check.') + else: + master = name2node[splits[0]] + assert len(splits) == 2 + monitors[key] = (getattr(master, splits[-1]), index) + else: + master = target + for s in splits[:-1]: + try: + master = getattr(master, s) + except KeyError: + raise MonitorError(f'Cannot find {key} in {master}, please check.') + monitors[key] = (getattr(master, splits[-1]), index) + elif isinstance(monitors, dict): + _monitors = dict() + for key, val in monitors.items(): + if not isinstance(key, str): + raise MonitorError('Expect the key of the dict "monitors" must be a string. But got ' + f'{type(key)}: {key}') + if isinstance(val, bm.Variable): + val = (val, None) + if isinstance(val, (tuple, list)): + if len(val) != 2: + raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' + f'But we got {val}') + if not isinstance(val[0], bm.Variable): + raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' + f'But we got {val}') + _monitors[key] = val + else: + raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' + f'But we got {val}') + monitors = _monitors else: - raise MonitorError(f'"monitors" only supports list/tuple/dict/ ' - f'instance of Monitor, not {type(monitors)}.') - self.mon.build() # build the monitor + raise MonitorError(f'We only supports a format of list/tuple/dict of ' + f'"vars", while we got {type(monitors)}.') + self.monitors = monitors # extra monitors if fun_monitors is None: @@ -73,6 +150,10 @@ def __init__( check_dict_data(fun_monitors, key_type=str, val_type=types.FunctionType) self.fun_monitors = fun_monitors + # monitor for user access + self.mon = DotDict() + self.mon['var_names'] = tuple(self.monitors.keys()) + tuple(self.fun_monitors.keys()) + # progress bar assert isinstance(progress_bar, bool), 'Must be a boolean variable.' 
self.progress_bar = progress_bar @@ -91,81 +172,14 @@ def __init__( self.numpy_mon_after_run = numpy_mon_after_run def format_monitors(self): - monitors = check_and_format_monitors(host=self.target, mon=self.mon) return_with_idx = dict() return_without_idx = dict() - for key, target, variable, idx, interval in monitors: - if interval is not None: - raise ValueError(f'Running with "{self.__class__.__name__}" does ' - f'not support "interval" in the monitor.') - data = target - for k in variable.split('.'): - data = getattr(data, k) - if not isinstance(data, bm.Variable): - raise RunningError(f'"{key}" in {target} is not a dynamically changed Variable, ' - f'its value will not change, we think there is no need to ' - f'monitor its trajectory.') + for key, (variable, idx) in self.monitors.items(): if idx is None: - return_without_idx[key] = data + return_without_idx[key] = variable else: - return_with_idx[key] = (data, bm.asarray(idx)) - + return_with_idx[key] = (variable, bm.asarray(idx)) return return_without_idx, return_with_idx def build_monitors(self, return_without_idx, return_with_idx) -> Callable: raise NotImplementedError - - -def check_and_format_monitors(host, mon): - """Return a formatted monitor items: - - >>> [(node, key, target, variable, idx, interval), - >>> ...... ] - - """ - assert isinstance(host, Base) - assert isinstance(mon, Monitor) - - formatted_mon_items = [] - - # master node: - # Check whether the input target node is accessible, - # and check whether the target node has the attribute - name2node = {node.name: node for node in list(host.nodes().unique().values())} - for key, idx, interval in zip(mon.item_names, mon.item_indices, mon.item_intervals): - # target and variable - splits = key.split('.') - if len(splits) == 1: - if not hasattr(host, splits[0]): - raise RunningError(f'{host} does not has variable {key}.') - target = host - variable = splits[-1] - else: - if not hasattr(host, splits[0]): - if splits[0] not in name2node: - raise RunningError(f'Cannot find target {key} in monitor of {host}, please check.') - else: - target = name2node[splits[0]] - assert len(splits) == 2 - variable = splits[-1] - else: - target = host - for s in splits[:-1]: - try: - target = getattr(target, s) - except KeyError: - raise RunningError(f'Cannot find {key} in {host}, please check.') - variable = splits[-1] - - # idx - if isinstance(idx, int): idx = bm.array([idx]) - - # interval - if interval is not None: - if not isinstance(interval, float): - raise RunningError(f'"interval" must be a float (denotes time), but we got {interval}') - - # append - formatted_mon_items.append((key, target, variable, idx, interval,)) - - return formatted_mon_items diff --git a/brainpy/tools/others/dicts.py b/brainpy/tools/others/dicts.py index f8c2a2fd3..9f877e3b4 100644 --- a/brainpy/tools/others/dicts.py +++ b/brainpy/tools/others/dicts.py @@ -3,16 +3,16 @@ import copy __all__ = [ - 'DictPlus', + 'DotDict', ] -class DictPlus(dict): +class DotDict(dict): """Python dictionaries with advanced dot notation access. 
For example: - >>> d = DictPlus({'a': 10, 'b': 20}) + >>> d = DotDict({'a': 10, 'b': 20}) >>> d.a 10 >>> d['a'] @@ -49,7 +49,7 @@ def __setattr__(self, name, value): self[name] = value def __setitem__(self, name, value): - super(DictPlus, self).__setitem__(name, value) + super(DotDict, self).__setitem__(name, value) try: p = object.__getattribute__(self, '__parent') key = object.__getattribute__(self, '__key') diff --git a/brainpy/train/runners/base_runner.py b/brainpy/train/runners/base_runner.py index a8808132f..180862b11 100644 --- a/brainpy/train/runners/base_runner.py +++ b/brainpy/train/runners/base_runner.py @@ -17,6 +17,7 @@ class DSTrainer(DSRunner): """Structural Trainer for Models with Recurrent Dynamics.""" + target: TrainingSystem train_nodes: Sequence[TrainingSystem] # need to be initialized by subclass def __init__( diff --git a/brainpy/train/runners/offline_trainer.py b/brainpy/train/runners/offline_trainer.py index 148b983da..f1bb1b401 100644 --- a/brainpy/train/runners/offline_trainer.py +++ b/brainpy/train/runners/offline_trainer.py @@ -2,6 +2,7 @@ from typing import Dict, Sequence, Union, Callable +import numpy as np import tqdm.auto from jax.experimental.host_callback import id_tap @@ -158,9 +159,11 @@ def fit( # final things for node in self.train_nodes: - self.mon.item_contents.pop(f'{node.name}-fit_record') + self.mon.pop(f'{node.name}-fit_record') if self.true_numpy_mon_after_run: - self.mon.numpy() + for key in self.mon.keys(): + if key != 'var_names': + self.mon[key] = np.asarray(self.mon[key]) def f_train(self, shared_kwargs: Dict = None) -> Callable: """Get training function.""" diff --git a/brainpy/train/runners/online_trainer.py b/brainpy/train/runners/online_trainer.py index 82e54816a..75cb71b45 100644 --- a/brainpy/train/runners/online_trainer.py +++ b/brainpy/train/runners/online_trainer.py @@ -2,6 +2,7 @@ from typing import Dict, Sequence, Union, Callable +import numpy as np import tqdm.auto from jax.experimental.host_callback import id_tap from jax.tree_util import tree_map @@ -118,11 +119,11 @@ def fit( # reset the model states if reset_state: - self.target.reset_batch_state(num_batch) + self.target.reset_state(num_batch) # init monitor - for key in self.mon.item_contents.keys(): - self.mon.item_contents[key] = [] # reshape the monitor items + for key in self.mon.var_names: + self.mon[key] = [] # reshape the monitor items # init progress bar if self.progress_bar: @@ -137,10 +138,12 @@ def fit( self._pbar.close() # post-running for monitors - for key in self.mon.item_names: - self.mon.item_contents[key] = hists[key] + for key in hists.keys(): + self.mon[key] = hists[key] if self.numpy_mon_after_run: - self.mon.numpy() + self.mon.ts = np.asarray(self.mon.ts) + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) def _fit( self, diff --git a/examples/simulation/Brette_2007_COBA.py b/examples/simulation/Brette_2007_COBA.py index dc0805fc3..2065fdfcb 100644 --- a/examples/simulation/Brette_2007_COBA.py +++ b/examples/simulation/Brette_2007_COBA.py @@ -38,8 +38,7 @@ def __init__(self, scale=1.0, method='exp_auto'): runner = bp.dyn.DSRunner(net, monitors=['E.spike'], inputs=[('E.input', 20.), ('I.input', 20.)]) -t = runner.run(100.) -print(t) +runner.run(100.) 
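+# Note: the wall-clock time of the simulation can be obtained with
+# `runner.run(100., eval_time=True)`, which returns the elapsed time
+# together with the outputs.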
# visualization bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) diff --git a/examples/simulation/Brette_2007_COBAHH.py b/examples/simulation/Brette_2007_COBAHH.py index bd26f161e..ce4768dae 100644 --- a/examples/simulation/Brette_2007_COBAHH.py +++ b/examples/simulation/Brette_2007_COBAHH.py @@ -35,6 +35,5 @@ def __init__(self, scale=1.): net = EINet(scale=1) runner = bp.dyn.DSRunner(net, monitors=['E.spike']) -t = runner.run(100.) -print(t) +runner.run(100.) bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) From 1afaf4962077c48066d3cb70cc49bddd3bfbebfa Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 2 Jun 2022 13:48:01 +0800 Subject: [PATCH 2/9] update runners and replace `Monitor` to `DotDict` --- brainpy/compat/__init__.py | 4 - brainpy/compat/monitor.py | 21 --- brainpy/compat/nn/runners/back_propagation.py | 14 +- brainpy/compat/nn/runners/offline_trainer.py | 27 +-- brainpy/compat/nn/runners/online_trainer.py | 20 +- brainpy/compat/nn/runners/rnn_runner.py | 18 +- brainpy/compat/tests/test_integrator_rnn.py | 82 ++++++++ .../compat/tests/test_ngrc_double_scroll.py | 151 +++++++++++++++ brainpy/compat/tests/test_ngrc_lorenz.py | 151 +++++++++++++++ .../tests/test_ngrc_lorenz_inference.py | 176 ++++++++++++++++++ brainpy/integrators/runner.py | 43 +++-- brainpy/running/runner.py | 111 ++++++----- brainpy/train/runners/back_propagation.py | 12 +- brainpy/train/runners/offline_trainer.py | 6 +- brainpy/train/runners/online_trainer.py | 6 +- 15 files changed, 715 insertions(+), 127 deletions(-) delete mode 100644 brainpy/compat/monitor.py create mode 100644 brainpy/compat/tests/test_integrator_rnn.py create mode 100644 brainpy/compat/tests/test_ngrc_double_scroll.py create mode 100644 brainpy/compat/tests/test_ngrc_lorenz.py create mode 100644 brainpy/compat/tests/test_ngrc_lorenz_inference.py diff --git a/brainpy/compat/__init__.py b/brainpy/compat/__init__.py index 5fa1bcbe7..c77b01528 100644 --- a/brainpy/compat/__init__.py +++ b/brainpy/compat/__init__.py @@ -13,9 +13,6 @@ 'set_default_odeint', 'set_default_sdeint', 'get_default_odeint', 'get_default_sdeint', - # monitor - 'Monitor', - # runners 'IntegratorRunner', 'DSRunner', 'StructRunner', 'ReportRunner' ] @@ -23,5 +20,4 @@ from . import brainobjects, layers, nn from .brainobjects import * from .integrators import * -from .monitor import * from .runners import * diff --git a/brainpy/compat/monitor.py b/brainpy/compat/monitor.py deleted file mode 100644 index c21cf0da0..000000000 --- a/brainpy/compat/monitor.py +++ /dev/null @@ -1,21 +0,0 @@ -# -*- coding: utf-8 -*- -import warnings - -from brainpy.running import monitor - -__all__ = [ - 'Monitor' -] - - -class Monitor(monitor.Monitor): - """Monitor class. - - .. deprecated:: 2.1.0 - Please use "brainpy.running.Monitor" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.running.Monitor" instead. 
' - '"brainpy.Monitor" is deprecated since version 2.1.0.', - DeprecationWarning) - super(Monitor, self).__init__(*args, **kwargs) diff --git a/brainpy/compat/nn/runners/back_propagation.py b/brainpy/compat/nn/runners/back_propagation.py index 0b3b4141a..85757e63e 100644 --- a/brainpy/compat/nn/runners/back_propagation.py +++ b/brainpy/compat/nn/runners/back_propagation.py @@ -564,18 +564,20 @@ def predict( if reset: self.target.initialize(num_batch) # init monitor - for key in self.mon.item_contents.keys(): - self.mon.item_contents[key] = [] # reshape the monitor items + for key in self.mon.var_names: + self.mon[key] = [] # reshape the monitor items # prediction outputs, hists = self._predict(xs=xs, forced_states=forced_states, forced_feedbacks=forced_feedbacks, shared_kwargs=shared_kwargs) # post-running for monitors - for key in self.mon.item_names: - self.mon.item_contents[key] = hists[key] + for key in hists.keys(): + self.mon[key] = hists[key] if self.numpy_mon_after_run: - self.mon.numpy() + self.mon.ts = np.asarray(self.mon.ts) + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) return outputs def _check_forced_states(self, forced_states, num_batch): @@ -722,7 +724,7 @@ def _make_predict_func(self, shared_kwargs: Dict): f'but got {type(shared_kwargs)}') def run_func(xs, forced_states, forced_feedbacks): - monitors = self.mon.item_contents.keys() + monitors = self.mon.var_names return self.target(xs, forced_states=forced_states, forced_feedbacks=forced_feedbacks, diff --git a/brainpy/compat/nn/runners/offline_trainer.py b/brainpy/compat/nn/runners/offline_trainer.py index a4ab65784..f333deffb 100644 --- a/brainpy/compat/nn/runners/offline_trainer.py +++ b/brainpy/compat/nn/runners/offline_trainer.py @@ -5,6 +5,7 @@ import tqdm.auto from jax.experimental.host_callback import id_tap +import numpy as np from brainpy.base import Base import brainpy.math as bm from brainpy.errors import NoImplementationError @@ -182,12 +183,12 @@ def fit( for node in self.target.entry_nodes: if node in self.train_nodes: inputs = node.data_pass_func({node.name: xs[node.name]}) - self.mon.item_contents[f'{node.name}.inputs'] = inputs + self.mon[f'{node.name}.inputs'] = inputs self._added_items.add(f'{node.name}.inputs') elif isinstance(self.target, Node): if self.target in self.train_nodes: inputs = self.target.data_pass_func({self.target.name: xs[self.target.name]}) - self.mon.item_contents[f'{self.target.name}.inputs'] = inputs + self.mon[f'{self.target.name}.inputs'] = inputs self._added_items.add(f'{self.target.name}.inputs') # format target data @@ -201,8 +202,8 @@ def fit( # training monitor_data = dict() for node in self.train_nodes: - monitor_data[f'{node.name}.inputs'] = self.mon.item_contents.get(f'{node.name}.inputs', None) - monitor_data[f'{node.name}.feedbacks'] = self.mon.item_contents.get(f'{node.name}.feedbacks', None) + monitor_data[f'{node.name}.inputs'] = self.mon.get(f'{node.name}.inputs', None) + monitor_data[f'{node.name}.feedbacks'] = self.mon.get(f'{node.name}.feedbacks', None) self.f_train(shared_kwargs)(monitor_data, ys) # close the progress bar @@ -211,9 +212,11 @@ def fit( # final things for key in self._added_items: - self.mon.item_contents.pop(key) + self.mon.pop(key) if self.true_numpy_mon_after_run: - self.mon.numpy() + for key in self.mon.keys(): + if key != 'var_names': + self.mon[key] = np.asarray(self.mon[key]) def f_train(self, shared_kwargs: Dict = None) -> Callable: """Get training function.""" @@ -248,14 +251,14 @@ def _add_monitor_items(self): if 
isinstance(self.target, Network): for node in self.train_nodes: if node not in self.target.entry_nodes: - if f'{node.name}.inputs' not in self.mon.item_names: - self.mon.item_names.append(f'{node.name}.inputs') - self.mon.item_contents[f'{node.name}.inputs'] = [] + if f'{node.name}.inputs' not in self.mon.var_names: + self.mon.var_names += (f'{node.name}.inputs', ) + self.mon[f'{node.name}.inputs'] = [] added_items.add(f'{node.name}.inputs') if node in self.target.fb_senders: - if f'{node.name}.feedbacks' not in self.mon.item_names: - self.mon.item_names.append(f'{node.name}.feedbacks') - self.mon.item_contents[f'{node.name}.feedbacks'] = [] + if f'{node.name}.feedbacks' not in self.mon.var_names: + self.mon.var_names += (f'{node.name}.feedbacks',) + self.mon[f'{node.name}.feedbacks'] = [] added_items.add(f'{node.name}.feedbacks') else: # brainpy.nn.Node instance does not need to monitor its inputs diff --git a/brainpy/compat/nn/runners/online_trainer.py b/brainpy/compat/nn/runners/online_trainer.py index eb4a005aa..279052fbb 100644 --- a/brainpy/compat/nn/runners/online_trainer.py +++ b/brainpy/compat/nn/runners/online_trainer.py @@ -6,6 +6,7 @@ from jax.experimental.host_callback import id_tap from jax.tree_util import tree_map +import numpy as np from brainpy.base import Base import brainpy.math as bm from brainpy.errors import NoImplementationError @@ -135,8 +136,8 @@ def fit( self.target.initialize(num_batch) # init monitor - for key in self.mon.item_contents.keys(): - self.mon.item_contents[key] = [] # reshape the monitor items + for key in self.mon.var_names: + self.mon[key] = [] # reshape the monitor items # init progress bar if self.progress_bar: @@ -157,10 +158,11 @@ def fit( self._pbar.close() # post-running for monitors - for key in self.mon.item_names: - self.mon.item_contents[key] = hists[key] + for key in hists.keys(): + self.mon[key] = hists[key] if self.numpy_mon_after_run: - self.mon.numpy() + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) def _fit( self, @@ -215,7 +217,7 @@ def _make_fit_func(self, shared_kwargs: Dict): def _step_func(all_inputs): xs, ys, forced_states, forced_feedbacks = all_inputs - monitors = tuple(self.mon.item_contents.keys()) + monitors = tuple(self.mon.var_names) _, outs = self.target(xs, forced_states=forced_states, @@ -243,7 +245,7 @@ def _step_func(all_inputs): else: def run_func(all_inputs): xs, ys, forced_states, forced_feedbacks = all_inputs - monitors = {key: [] for key in self.mon.item_contents.keys()} + monitors = {key: [] for key in self.mon.var_names} num_step = check_data_batch_size(xs) for i in range(num_step): one_xs = {key: tensor[i] for key, tensor in xs.items()} @@ -261,9 +263,9 @@ def run_func(all_inputs): def _add_monitor_items(self): added_items = set() for node in self.train_nodes: - if f'{node.name}.inputs' not in self.mon.item_names: + if f'{node.name}.inputs' not in self.mon.var_names: added_items.add(f'{node.name}.inputs') - if f'{node.name}.feedbacks' not in self.mon.item_names: + if f'{node.name}.feedbacks' not in self.mon.var_names: added_items.add(f'{node.name}.feedbacks') return tuple(added_items) diff --git a/brainpy/compat/nn/runners/rnn_runner.py b/brainpy/compat/nn/runners/rnn_runner.py index 9c1951ee4..846f272cf 100644 --- a/brainpy/compat/nn/runners/rnn_runner.py +++ b/brainpy/compat/nn/runners/rnn_runner.py @@ -3,6 +3,7 @@ from typing import Dict, Union import jax.numpy as jnp +import numpy as np import tqdm.auto from jax.experimental.host_callback import id_tap from jax.tree_util import 
tree_map @@ -41,6 +42,8 @@ class RNNRunner(Runner): Change the monitored iterm into NumPy arrays. """ + target: Node + def __init__(self, target: Node, jit=True, **kwargs): super(RNNRunner, self).__init__(target=target, **kwargs) assert isinstance(self.target, Node), '"target" must be an instance of brainpy.nn.Node.' @@ -127,8 +130,8 @@ def predict( if reset: self.target.initialize(num_batch) # init monitor - for key in self.mon.item_contents.keys(): - self.mon.item_contents[key] = [] # reshape the monitor items + for key in self.mon.var_names: + self.mon[key] = [] # reshape the monitor items # init progress bar if self.progress_bar and progress_bar: if num_step is None: @@ -144,10 +147,11 @@ def predict( if self.progress_bar and progress_bar: self._pbar.close() # post-running for monitors - for key in self.mon.item_names: - self.mon.item_contents[key] = hists[key] + for key in hists.keys(): + self.mon[key] = hists[key] if self.numpy_mon_after_run: - self.mon.numpy() + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) return outputs def _predict( @@ -200,7 +204,7 @@ def _make_predict_func(self, shared_kwargs: Dict): def _step_func(a_input): xs, forced_states, forced_feedbacks = a_input - monitors = self.mon.item_contents.keys() + monitors = self.mon.var_names outs = self.target(xs, forced_states=forced_states, forced_feedbacks=forced_feedbacks, @@ -225,7 +229,7 @@ def run_func(all_inputs): else: outputs = [] output_type = 'node' - monitors = {key: [] for key in self.mon.item_contents.keys()} + monitors = {key: [] for key in self.mon.var_names} num_step = check_data_batch_size(xs) for i in range(num_step): one_xs = {key: tensor[i] for key, tensor in xs.items()} diff --git a/brainpy/compat/tests/test_integrator_rnn.py b/brainpy/compat/tests/test_integrator_rnn.py new file mode 100644 index 000000000..83dcca39b --- /dev/null +++ b/brainpy/compat/tests/test_integrator_rnn.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- + +from functools import partial + +import matplotlib.pyplot as plt + +import brainpy as bp +import brainpy.math as bm + + +block = False +dt = 0.04 +num_step = int(1.0 / dt) +num_batch = 128 + + +@partial(bm.jit, + dyn_vars=bp.TensorCollector({'a': bm.random.DEFAULT}), + static_argnames=['batch_size']) +def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10): + # Create the white noise input + sample = bm.random.normal(size=(batch_size, 1, 1)) + bias = mean * 2.0 * (sample - 0.5) + samples = bm.random.normal(size=(batch_size, num_step, 1)) + noise_t = scale / dt ** 0.5 * samples + inputs = bias + noise_t + targets = bm.cumsum(inputs, axis=1) + return inputs, targets + + +def train_data(): + for _ in range(10): + yield build_inputs_and_targets(batch_size=num_batch) + + +def test_rnn_training(): + model = ( + bp.nn.Input(1) + >> + bp.nn.VanillaRNN(100, state_trainable=True) + >> + bp.nn.Dense(1) + ) + model.initialize(num_batch=num_batch) + + + # define loss function + def loss(predictions, targets, l2_reg=2e-4): + mse = bp.losses.mean_squared_error(predictions, targets) + l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2 + return mse + l2 + + + # define optimizer + lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) + opt = bp.optim.Adam(lr=lr, eps=1e-1) + + # create a trainer + trainer = bp.nn.BPTT(model, + loss=loss, + optimizer=opt, + max_grad_norm=5.0) + trainer.fit(train_data, + num_batch=num_batch, + num_train=5, + num_report=10) + + plt.plot(trainer.train_losses.numpy()) + plt.show(block=block) + + 
model.initialize(1) + x, y = build_inputs_and_targets(batch_size=1) + predicts = trainer.predict(x) + + plt.figure(figsize=(8, 2)) + plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth') + plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction') + plt.legend() + plt.show(block=block) + plt.close() + + diff --git a/brainpy/compat/tests/test_ngrc_double_scroll.py b/brainpy/compat/tests/test_ngrc_double_scroll.py new file mode 100644 index 000000000..4f4495561 --- /dev/null +++ b/brainpy/compat/tests/test_ngrc_double_scroll.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +"""Implementation of the paper: + +- Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir + computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2 + +The main task is forecasting the double-scroll system. +""" + + +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + + +block = False + + +def get_subset(data, start, end): + res = {'x': data['x'][start: end], + 'y': data['y'][start: end], + 'z': data['z'][start: end]} + res = bm.hstack([res['x'], res['y'], res['z']]) + return res.reshape((1,) + res.shape) + + +def plot_weights(Wout, coefs, bias=None): + Wout = np.asarray(Wout) + if bias is not None: + bias = np.asarray(bias) + Wout = np.concatenate([bias.reshape((1, 3)), Wout], axis=0) + coefs.insert(0, 'bias') + x_Wout, y_Wout, z_Wout = Wout[:, 0], Wout[:, 1], Wout[:, 2] + + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(131) + ax.grid(axis="y") + ax.set_xlabel("$[W_{out}]_x$") + ax.set_ylabel("Features") + ax.set_yticks(np.arange(len(coefs))) + ax.set_yticklabels(coefs) + ax.barh(np.arange(x_Wout.size), x_Wout) + + ax1 = fig.add_subplot(132) + ax1.grid(axis="y") + ax1.set_yticks(np.arange(len(coefs))) + ax1.set_xlabel("$[W_{out}]_y$") + ax1.barh(np.arange(y_Wout.size), y_Wout) + + ax2 = fig.add_subplot(133) + ax2.grid(axis="y") + ax2.set_yticks(np.arange(len(coefs))) + ax2.set_xlabel("$[W_{out}]_z$") + ax2.barh(np.arange(z_Wout.size), z_Wout) + + plt.show(block=block) + + +def plot_double_scroll(ground_truth, predictions): + fig = plt.figure(figsize=(15, 10)) + ax = fig.add_subplot(121, projection='3d') + ax.set_title("Generated attractor") + ax.set_xlabel("$x$") + ax.set_ylabel("$y$") + ax.set_zlabel("$z$") + ax.grid(False) + ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2]) + + ax2 = fig.add_subplot(122, projection='3d') + ax2.set_title("Real attractor") + ax2.grid(False) + ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2]) + plt.show(block=block) + + +dt = 0.02 +t_warmup = 10. # ms +t_train = 100. # ms +t_test = 800. 
# ms +num_warmup = int(t_warmup / dt) # warm up NVAR +num_train = int(t_train / dt) +num_test = int(t_test / dt) + + +def test_ngrc_double_scroll(): + bm.enable_x64() + + # Datasets # + # -------- # + data_series = bp.datasets.double_scroll_series(t_warmup + t_train + t_test, dt=dt) + + X_warmup = get_subset(data_series, 0, num_warmup - 1) + Y_warmup = get_subset(data_series, 1, num_warmup) + X_train = get_subset(data_series, num_warmup - 1, num_warmup + num_train - 1) + # Target: Lorenz[t] - Lorenz[t - 1] + dX_train = get_subset(data_series, num_warmup, num_warmup + num_train) - X_train + X_test = get_subset(data_series, + num_warmup + num_train - 1, + num_warmup + num_train + num_test - 1) + Y_test = get_subset(data_series, + num_warmup + num_train, + num_warmup + num_train + num_test) + + # Model # + # ----- # + + i = bp.nn.Input(3) + r = bp.nn.NVAR(delay=2, order=3) + di = bp.nn.LinearReadout(3, trainable=True, name='readout') + o = bp.nn.Summation() + # + # Cannot express the model as + # + # [i >> r >> di, i] >> o + # (i >> r >> di, i) >> o + # because it will concatenate the outputs of "i" and "di", + # then feed into the node "o". This is not the connection + # we want. + model = {i >> r >> di, i} >> o + # model = (i >> r >> di >> o) & (i >> o) + model.plot_node_graph() + model.initialize(num_batch=1) + + # Training # + # -------- # + + # warm-up + trainer = bp.nn.RidgeTrainer(model, beta=1e-5, jit=True) + + # training + outputs = trainer.predict(X_warmup) + print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) + trainer.fit([X_train, {'readout': dX_train}]) + plot_weights(di.Wff, r.get_feature_names_for_plot(), di.bias) + + # prediction + model = bm.jit(model) + outputs = [model(X_test[:, 0])] + for i in range(1, X_test.shape[1]): + outputs.append(model(outputs[i - 1])) + outputs = bm.asarray(outputs).squeeze() + print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) + plot_double_scroll(Y_test.numpy().squeeze(), outputs.numpy()) + plt.close() + + bm.disable_x64() + bp.base.clear_name_cache(True) + diff --git a/brainpy/compat/tests/test_ngrc_lorenz.py b/brainpy/compat/tests/test_ngrc_lorenz.py new file mode 100644 index 000000000..9766b1b6b --- /dev/null +++ b/brainpy/compat/tests/test_ngrc_lorenz.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +"""Implementation of the paper: + +- Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir + computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2 + +The main task is forecasting the Lorenz63 strange attractor. 
+""" + +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + +block = False + + +def get_subset(data, start, end): + res = {'x': data['x'][start: end], + 'y': data['y'][start: end], + 'z': data['z'][start: end]} + res = bm.hstack([res['x'], res['y'], res['z']]) + return res.reshape((1,) + res.shape) + + +def plot_weights(Wout, coefs, bias=None): + Wout = np.asarray(Wout) + if bias is not None: + bias = np.asarray(bias) + Wout = np.concatenate([bias.reshape((1, 3)), Wout], axis=0) + coefs.insert(0, 'bias') + x_Wout, y_Wout, z_Wout = Wout[:, 0], Wout[:, 1], Wout[:, 2] + + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(131) + ax.grid(axis="y") + ax.set_xlabel("$[W_{out}]_x$") + ax.set_ylabel("Features") + ax.set_yticks(np.arange(len(coefs))) + ax.set_yticklabels(coefs) + ax.barh(np.arange(x_Wout.size), x_Wout) + + ax1 = fig.add_subplot(132) + ax1.grid(axis="y") + ax1.set_yticks(np.arange(len(coefs))) + ax1.set_xlabel("$[W_{out}]_y$") + ax1.barh(np.arange(y_Wout.size), y_Wout) + + ax2 = fig.add_subplot(133) + ax2.grid(axis="y") + ax2.set_yticks(np.arange(len(coefs))) + ax2.set_xlabel("$[W_{out}]_z$") + ax2.barh(np.arange(z_Wout.size), z_Wout) + + plt.show(block=block) + + +def plot_lorenz(ground_truth, predictions): + fig = plt.figure(figsize=(15, 10)) + ax = fig.add_subplot(121, projection='3d') + ax.set_title("Generated attractor") + ax.set_xlabel("$x$") + ax.set_ylabel("$y$") + ax.set_zlabel("$z$") + ax.grid(False) + ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2]) + + ax2 = fig.add_subplot(122, projection='3d') + ax2.set_title("Real attractor") + ax2.grid(False) + ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2]) + plt.show(block=block) + + +dt = 0.01 +t_warmup = 5. # ms +t_train = 10. # ms +t_test = 120. # ms +num_warmup = int(t_warmup / dt) # warm up NVAR +num_train = int(t_train / dt) +num_test = int(t_test / dt) + + +def test_ngrc_lorenz(): + bm.enable_x64() + + # Datasets # + # -------- # + lorenz_series = bp.datasets.lorenz_series(t_warmup + t_train + t_test, + dt=dt, + inits={'x': 17.67715816276679, + 'y': 12.931379185960404, + 'z': 43.91404334248268}) + + X_warmup = get_subset(lorenz_series, 0, num_warmup - 1) + Y_warmup = get_subset(lorenz_series, 1, num_warmup) + X_train = get_subset(lorenz_series, num_warmup - 1, num_warmup + num_train - 1) + # Target: Lorenz[t] - Lorenz[t - 1] + dX_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train) - X_train + X_test = get_subset(lorenz_series, + num_warmup + num_train - 1, + num_warmup + num_train + num_test - 1) + Y_test = get_subset(lorenz_series, + num_warmup + num_train, + num_warmup + num_train + num_test) + + # Model # + # ----- # + + i = bp.nn.Input(3) + r = bp.nn.NVAR(delay=2, order=2, constant=True) + di = bp.nn.LinearReadout(3, bias_initializer=None, trainable=True, name='readout') + o = bp.nn.Summation() + # + # Cannot express the model as + # + # [i >> r >> di, i] >> o + # because it will concatenate the outputs of "i" and "di", + # then feed into the node "o". This is not the connection + # we want. 
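+  # The '&' operator below merges the two pathways (i >> r >> di >> o) and
+  # (i >> o), so the Summation node 'o' receives the readout 'di' and the raw
+  # input 'i' as two separate inputs and sums them, i.e. the prediction is
+  # X[t] + dX[t].  (Presumably the same connection as the set-based form used
+  # in the double-scroll test.)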
+ model = (i >> r >> di >> o) & (i >> o) + # model.plot_node_graph() + model.initialize(num_batch=1) + + print(r.get_feature_names()) + + # Training # + # -------- # + + # warm-up + trainer = bp.nn.RidgeTrainer(model, beta=2.5e-6) + + # training + outputs = trainer.predict(X_warmup) + print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) + trainer.fit([X_train, {'readout': dX_train}]) + plot_weights(di.Wff, r.get_feature_names_for_plot(), di.bias) + + # prediction + model = bm.jit(model) + outputs = [model(X_test[:, 0])] + for i in range(1, X_test.shape[1]): + outputs.append(model(outputs[i - 1])) + outputs = bm.asarray(outputs) + print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) + plot_lorenz(Y_test.numpy().squeeze(), outputs.numpy().squeeze()) + plt.close() + bm.disable_x64() + bp.base.clear_name_cache(True) diff --git a/brainpy/compat/tests/test_ngrc_lorenz_inference.py b/brainpy/compat/tests/test_ngrc_lorenz_inference.py new file mode 100644 index 000000000..d27f2415f --- /dev/null +++ b/brainpy/compat/tests/test_ngrc_lorenz_inference.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- + +"""Implementation of the paper: + +- Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir + computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2 + +The main task is forecasting the Lorenz63 strange attractor. +""" + +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + +block = False + + +def get_subset(data, start, end): + res = {'x': data['x'][start: end], + 'y': data['y'][start: end], + 'z': data['z'][start: end]} + X = bm.hstack([res['x'], res['y']]) + X = X.reshape((1,) + X.shape) + Y = res['z'] + Y = Y.reshape((1,) + Y.shape) + return X, Y + + +def plot_lorenz(x, y, true_z, predict_z, linewidth=.8): + fig1 = plt.figure() + fig1.set_figheight(8) + fig1.set_figwidth(12) + + t_all = t_warmup + t_train + t_test + ts = np.arange(0, t_all, dt) + + h = 240 + w = 2 + + # top left of grid is 0,0 + axs1 = plt.subplot2grid(shape=(h, w), loc=(0, 0), colspan=2, rowspan=30) + axs2 = plt.subplot2grid(shape=(h, w), loc=(36, 0), colspan=2, rowspan=30) + axs3 = plt.subplot2grid(shape=(h, w), loc=(72, 0), colspan=2, rowspan=30) + axs4 = plt.subplot2grid(shape=(h, w), loc=(132, 0), colspan=2, rowspan=30) + axs5 = plt.subplot2grid(shape=(h, w), loc=(168, 0), colspan=2, rowspan=30) + axs6 = plt.subplot2grid(shape=(h, w), loc=(204, 0), colspan=2, rowspan=30) + + # training phase x + axs1.set_title('training phase') + axs1.plot(ts[num_warmup:num_warmup + num_train], + x[num_warmup:num_warmup + num_train], + color='b', linewidth=linewidth) + axs1.set_ylabel('x') + axs1.axes.xaxis.set_ticklabels([]) + axs1.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) + axs1.axes.set_ybound(-21., 21.) + axs1.text(-.14, .9, 'a)', ha='left', va='bottom', transform=axs1.transAxes) + + # training phase y + axs2.plot(ts[num_warmup:num_warmup + num_train], + y[num_warmup:num_warmup + num_train], + color='b', linewidth=linewidth) + axs2.set_ylabel('y') + axs2.axes.xaxis.set_ticklabels([]) + axs2.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) + axs2.axes.set_ybound(-26., 26.) 
+ axs2.text(-.14, .9, 'b)', ha='left', va='bottom', transform=axs2.transAxes) + + # training phase z + axs3.plot(ts[num_warmup:num_warmup + num_train], + true_z[num_warmup:num_warmup + num_train], + color='b', linewidth=linewidth) + axs3.plot(ts[num_warmup:num_warmup + num_train], + predict_z[num_warmup:num_warmup + num_train], + color='r', linewidth=linewidth) + axs3.set_ylabel('z') + axs3.set_xlabel('time') + axs3.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) + axs3.axes.set_ybound(3., 48.) + axs3.text(-.14, .9, 'c)', ha='left', va='bottom', transform=axs3.transAxes) + + # testing phase x + axs4.set_title('testing phase') + axs4.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], + x[num_warmup + num_train:num_warmup + num_train + num_test], + color='b', linewidth=linewidth) + axs4.set_ylabel('x') + axs4.axes.xaxis.set_ticklabels([]) + axs4.axes.set_ybound(-21., 21.) + axs4.axes.set_xbound(t_warmup + t_train - .5, t_all + .5) + axs4.text(-.14, .9, 'd)', ha='left', va='bottom', transform=axs4.transAxes) + + # testing phase y + axs5.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], + y[num_warmup + num_train:num_warmup + num_train + num_test], + color='b', linewidth=linewidth) + axs5.set_ylabel('y') + axs5.axes.xaxis.set_ticklabels([]) + axs5.axes.set_ybound(-26., 26.) + axs5.axes.set_xbound(t_warmup + t_train - .5, t_all + .5) + axs5.text(-.14, .9, 'e)', ha='left', va='bottom', transform=axs5.transAxes) + + # testing phose z + axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], + true_z[num_warmup + num_train:num_warmup + num_train + num_test], + color='b', linewidth=linewidth) + axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], + predict_z[num_warmup + num_train:num_warmup + num_train + num_test], + color='r', linewidth=linewidth) + axs6.set_ylabel('z') + axs6.set_xlabel('time') + axs6.axes.set_ybound(3., 48.) + axs6.axes.set_xbound(t_warmup + t_train - .5, t_all + .5) + axs6.text(-.14, .9, 'f)', ha='left', va='bottom', transform=axs6.transAxes) + + plt.show(block=block) + + +dt = 0.02 +t_warmup = 10. # ms +t_train = 20. # ms +t_test = 50. 
# ms +num_warmup = int(t_warmup / dt) # warm up NVAR +num_train = int(t_train / dt) +num_test = int(t_test / dt) + + +def test_ngrc_lorenz_inference(): + bm.enable_x64() + # Datasets # + # -------- # + lorenz_series = bp.datasets.lorenz_series(t_warmup + t_train + t_test, + dt=dt, + inits={'x': 17.67715816276679, + 'y': 12.931379185960404, + 'z': 43.91404334248268}) + + X_warmup, Y_warmup = get_subset(lorenz_series, 0, num_warmup) + X_train, Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train) + X_test, Y_test = get_subset(lorenz_series, 0, num_warmup + num_train + num_test) + + # Model # + # ----- # + + i = bp.nn.Input(2) + r = bp.nn.NVAR(delay=4, order=2, stride=5) + o = bp.nn.LinearReadout(1, trainable=True) + model = i >> r >> o + model.plot_node_graph() + model.initialize(num_batch=1) + + # Training # + # -------- # + + trainer = bp.nn.RidgeTrainer(model, beta=0.05) + + # warm-up + outputs = trainer.predict(X_warmup) + print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) + + # training + trainer.fit([X_train, Y_train]) + + # prediction + outputs = trainer.predict(X_test, reset=True) + print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) + + plot_lorenz(x=lorenz_series['x'].flatten().numpy(), + y=lorenz_series['y'].flatten().numpy(), + true_z=lorenz_series['z'].flatten().numpy(), + predict_z=outputs.to_numpy().flatten()) + plt.close() + bm.disable_x64() + bp.base.clear_name_cache(True) diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py index 4e3954093..07f97943e 100644 --- a/brainpy/integrators/runner.py +++ b/brainpy/integrators/runner.py @@ -105,8 +105,11 @@ def __init__( Parameters ---------- target: Integrator + The target to run. monitors: sequence of str + The variables to monitor. fun_monitors: dict + The monitors with callable functions. inits: sequence, dict The initial value of variables. With this parameter, you can easily control the number of variables to simulate. 
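      A sketch of how ``inits`` controls how many copies of the integrator
      variables are simulated in parallel (the integrator ``integral`` and the
      variable names 'V' and 'w' are illustrative):

      >>> runner = IntegratorRunner(integral,        # an ODE integrator with variables ('V', 'w')
      >>>                           monitors=['V'],
      >>>                           inits={'V': bm.random.random(100),  # 100 initial values
      >>>                                  'w': bm.zeros(100)})
      >>> runner.run(100.)
      >>> runner.mon['V']          # history with shape (num_time_steps, 100)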
@@ -130,6 +133,31 @@ def __init__( progress_bar: bool numpy_mon_after_run: bool """ + + # initialize variables + if not isinstance(target, Integrator): + raise TypeError(f'Target must be instance of {Integrator.__name__}, ' + f'but we got {type(target)}') + if inits is not None: + if isinstance(inits, (list, tuple, bm.JaxArray, jnp.ndarray)): + assert len(target.variables) == len(inits) + inits = {k: inits[i] for i, k in enumerate(target.variables)} + assert isinstance(inits, dict), f'"inits" must be a dict, but we got {type(inits)}' + sizes = np.unique([np.size(v) for v in list(inits.values())]) + max_size = np.max(sizes) + else: + max_size = 1 + inits = dict() + self.variables = TensorCollector({v: bm.Variable(bm.zeros(max_size)) + for v in target.variables}) + for k in inits.keys(): + self.variables[k][:] = inits[k] + + # format string monitors + monitors = self._format_seq_monitors(monitors) + monitors = {k: (self.variables[k], i) for k, i in monitors} + + # initialize super class super(IntegratorRunner, self).__init__(target=target, monitors=monitors, fun_monitors=fun_monitors, @@ -179,20 +207,7 @@ def __init__( self.dyn_vars.update(self.target.vars().unique()) # Variables - if inits is not None: - if isinstance(inits, (list, tuple, bm.JaxArray, jnp.ndarray)): - assert len(self.target.variables) == len(inits) - inits = {k: inits[i] for i, k in enumerate(self.target.variables)} - assert isinstance(inits, dict), f'"inits" must be a dict, but we got {type(inits)}' - sizes = np.unique([np.size(v) for v in list(inits.values())]) - max_size = np.max(sizes) - else: - max_size = 1 - inits = dict() - self.variables = TensorCollector({v: bm.Variable(bm.zeros(max_size)) - for v in self.target.variables}) - for k in inits.keys(): - self.variables[k][:] = inits[k] + self.dyn_vars.update(self.variables) if len(self._dyn_args) > 0: self.idx = bm.Variable(bm.zeros(1, dtype=jnp.int_)) diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index 31237a884..f2e3d75f6 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -78,47 +78,9 @@ def __init__( monitors = dict() elif isinstance(monitors, (list, tuple)): # format string monitors - _monitors = [] - for mon in monitors: - if isinstance(mon, str): - _monitors.append((mon, None)) - elif isinstance(mon, (tuple, list)): - if not isinstance(mon[0], str) and len(mon) == 2: - raise MonitorError(f'We expect the monitor format with (name, index). 
But we got {mon}') - if isinstance(mon[1], (int, np.integer)): - idx = bm.array([mon[1]]) - else: - idx = mon[1] - _monitors.append((mon[0], idx)) - else: - raise MonitorError(f'We do not support monitor with {type(mon)}: {mon}') - + monitors = self._format_seq_monitors(monitors) # get monitor targets - monitors = {} - name2node = {node.name: node for node in list(target.nodes(level=-1).unique().values())} - for mon in _monitors: - key, index = mon[0], mon[1] - splits = key.split('.') - if len(splits) == 1: - if not hasattr(target, splits[0]): - raise RunningError(f'{target} does not has variable {key}.') - monitors[key] = (getattr(target, splits[-1]), index) - else: - if not hasattr(target, splits[0]): - if splits[0] not in name2node: - raise MonitorError(f'Cannot find target {key} in monitor of {target}, please check.') - else: - master = name2node[splits[0]] - assert len(splits) == 2 - monitors[key] = (getattr(master, splits[-1]), index) - else: - master = target - for s in splits[:-1]: - try: - master = getattr(master, s) - except KeyError: - raise MonitorError(f'Cannot find {key} in {master}, please check.') - monitors[key] = (getattr(master, splits[-1]), index) + monitors = self._find_monitor_targets(monitors) elif isinstance(monitors, dict): _monitors = dict() for key, val in monitors.items(): @@ -128,13 +90,20 @@ def __init__( if isinstance(val, bm.Variable): val = (val, None) if isinstance(val, (tuple, list)): - if len(val) != 2: + if not isinstance(val[0], bm.Variable): raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' f'But we got {val}') - if not isinstance(val[0], bm.Variable): + if len(val) == 1: + _monitors[key] = (val[0], None) + elif len(val) == 2: + if isinstance(val[1], (int, np.integer)): + idx = bm.array([val[1]]) + else: + idx = None if val[1] is None else bm.asarray(val[1]) + _monitors[key] = (val[0], idx) + else: raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' f'But we got {val}') - _monitors[key] = val else: raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' f'But we got {val}') @@ -181,5 +150,61 @@ def format_monitors(self): return_with_idx[key] = (variable, bm.asarray(idx)) return return_without_idx, return_with_idx + def _format_seq_monitors(self, monitors): + if not isinstance(monitors, (tuple, list)): + raise TypeError(f'Must be a sequence, but we got {type(monitors)}') + _monitors = [] + for mon in monitors: + if isinstance(mon, str): + _monitors.append((mon, None)) + elif isinstance(mon, (tuple, list)): + if isinstance(mon[0], str): + if len(mon) == 1: + _monitors.append((mon[0], None)) + elif len(mon) == 2: + if isinstance(mon[1], (int, np.integer)): + idx = bm.array([mon[1]]) + else: + idx = None if mon[1] is None else bm.asarray(mon[1]) + _monitors.append((mon[0], idx)) + else: + raise MonitorError(f'We expect the monitor format with (name, index). But we got {mon}') + else: + raise MonitorError(f'We expect the monitor format with (name, index). 
But we got {mon}') + else: + raise MonitorError(f'We do not support monitor with {type(mon)}: {mon}') + return _monitors + + def _find_monitor_targets(self, _monitors): + if not isinstance(_monitors, (tuple, list)): + raise TypeError(f'Must be a sequence, but we got {type(_monitors)}') + # get monitor targets + monitors = {} + name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())} + for mon in _monitors: + key, index = mon[0], mon[1] + splits = key.split('.') + if len(splits) == 1: + if not hasattr(self.target, splits[0]): + raise RunningError(f'{self.target} does not has variable {key}.') + monitors[key] = (getattr(self.target, splits[-1]), index) + else: + if not hasattr(self.target, splits[0]): + if splits[0] not in name2node: + raise MonitorError(f'Cannot find target {key} in monitor of {self.target}, please check.') + else: + master = name2node[splits[0]] + assert len(splits) == 2 + monitors[key] = (getattr(master, splits[-1]), index) + else: + master = self.target + for s in splits[:-1]: + try: + master = getattr(master, s) + except KeyError: + raise MonitorError(f'Cannot find {key} in {master}, please check.') + monitors[key] = (getattr(master, splits[-1]), index) + return monitors + def build_monitors(self, return_without_idx, return_with_idx) -> Callable: raise NotImplementedError diff --git a/brainpy/train/runners/back_propagation.py b/brainpy/train/runners/back_propagation.py index 0a1194946..85c1263ce 100644 --- a/brainpy/train/runners/back_propagation.py +++ b/brainpy/train/runners/back_propagation.py @@ -401,15 +401,17 @@ def predict( if reset_state: self.target.reset_state(num_batch) # init monitor - for key in self.mon.item_contents.keys(): - self.mon.item_contents[key] = [] # reshape the monitor items + for key in self.mon.var_names: + self.mon[key] = [] # reshape the monitor items # prediction outputs, hists = self._predict(xs=xs, shared_args=shared_args) # post-running for monitors - for key in self.mon.item_names: - self.mon.item_contents[key] = hists[key] + for key in hists.keys(): + self.mon[key] = bm.asarray(hists[key]) if self.numpy_mon_after_run: - self.mon.numpy() + self.mon.ts = np.asarray(self.mon.ts) + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) return outputs def _predict( diff --git a/brainpy/train/runners/offline_trainer.py b/brainpy/train/runners/offline_trainer.py index f1bb1b401..d46e67941 100644 --- a/brainpy/train/runners/offline_trainer.py +++ b/brainpy/train/runners/offline_trainer.py @@ -194,9 +194,9 @@ def func(_t, _dt): res = {k: v.value for k, v in return_without_idx.items()} res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) res.update({k: f(_t, _dt) for k, f in self.fun_monitors.items()}) - res.update({f'{node.name}-fit_record': node.fit_record for node in self.train_nodes}) - # for node in self.train_nodes: - # node.fit_record.clear() + res.update({f'{node.name}-fit_record': {k: node.fit_record.pop(k) + for k in node.fit_record.keys()} + for node in self.train_nodes}) return res return func diff --git a/brainpy/train/runners/online_trainer.py b/brainpy/train/runners/online_trainer.py index 75cb71b45..9f044c195 100644 --- a/brainpy/train/runners/online_trainer.py +++ b/brainpy/train/runners/online_trainer.py @@ -258,9 +258,9 @@ def func(t, dt): res = {k: v.value for k, v in return_without_idx.items()} res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) res.update({k: f(t, dt) for k, f in self.fun_monitors.items()}) - 
res.update({f'{node.name}-fit_record': node.fit_record for node in self.train_nodes}) - # for node in self.train_nodes: - # node.fit_record.clear() + res.update({f'{node.name}-fit_record': {k: node.fit_record.pop(k) + for k in node.fit_record.keys()} + for node in self.train_nodes}) return res return func From 37354455a6a120f574e0e27adbdbb0166264d722 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 2 Jun 2022 13:48:31 +0800 Subject: [PATCH 3/9] update codes --- brainpy/base/base.py | 35 +++++++++++++++++++++++++++-------- brainpy/dyn/base.py | 12 ++++++------ brainpy/tools/checking.py | 6 +++--- brainpy/train/base.py | 4 ++++ 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/brainpy/base/base.py b/brainpy/base/base.py index e1b5f776c..25f41a102 100644 --- a/brainpy/base/base.py +++ b/brainpy/base/base.py @@ -52,9 +52,27 @@ def name(self, name: str = None): self._name = self.unique_name(name=name) naming.check_name_uniqueness(name=self._name, obj=self) - def register_implicit_vars(self, variables): - assert isinstance(variables, dict), f'Must be a dict, but we got {type(variables)}' - self.implicit_vars.update(variables) + def register_implicit_vars(self, *variables, **named_variables): + from brainpy.math import Variable + for variable in variables: + if isinstance(variable, Variable): + self.implicit_vars[f'var{id(variable)}'] = variable + elif isinstance(variable, (tuple, list)): + for v in variable: + if not isinstance(v, Variable): + raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(v)}') + self.implicit_vars[f'var{id(variable)}'] = v + elif isinstance(variable, dict): + for k, v in variable.items(): + if not isinstance(v, Variable): + raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(v)}') + self.implicit_vars[k] = v + else: + raise ValueError(f'Unknown type: {type(variable)}') + for key, variable in named_variables.items(): + if not isinstance(variable, Variable): + raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(variable)}') + self.implicit_vars[key] = variable def register_implicit_nodes(self, *nodes, **named_nodes): for node in nodes: @@ -70,11 +88,12 @@ def register_implicit_nodes(self, *nodes, **named_nodes): if not isinstance(n, Base): raise ValueError(f'Must be instance of {Base.__name__}, but we got {type(n)}') self.implicit_nodes[k] = n - for node in named_nodes.values(): - for k, n in node.items(): - if not isinstance(n, Base): - raise ValueError(f'Must be instance of {Base.__name__}, but we got {type(n)}') - self.implicit_nodes[k] = n + else: + raise ValueError(f'Unknown type: {type(node)}') + for key, node in named_nodes.items(): + if not isinstance(node, Base): + raise ValueError(f'Must be instance of {Base.__name__}, but we got {type(node)}') + self.implicit_nodes[key] = node def vars(self, method='absolute', level=-1, include_self=True): """Collect all variables in this node and the children nodes. 
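As a usage note for the reworked registration helpers in the `base.py` hunk above: after this change, `register_implicit_vars` and `register_implicit_nodes` accept positional objects, sequences, dicts, and keyword arguments. The following is a minimal sketch under assumed import paths; the model, variable, and node names are illustrative only and do not appear in this patch.

import brainpy as bp
import brainpy.math as bm

class MyModel(bp.Base):
  def __init__(self):
    super(MyModel, self).__init__()
    v = bm.Variable(bm.zeros(3))
    # positional variables get an auto-generated key; keyword variables keep their name
    self.register_implicit_vars(v, spikes=bm.Variable(bm.zeros(3)))
    # child nodes can likewise be registered positionally or by keyword
    self.register_implicit_nodes(child=bp.dyn.LIF(10))

m = MyModel()
print(list(m.implicit_vars.keys()))  # e.g. ['var<id>', 'spikes']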
diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 11fb1215f..1527121dc 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -102,7 +102,7 @@ def register_delay( # delay steps if delay_step is None: delay_type = 'none' - elif isinstance(delay_step, int): + elif isinstance(delay_step, (int, np.integer, jnp.integer)): delay_type = 'homo' elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)): if delay_step.size == 1 and delay_step.ndim == 0: @@ -168,7 +168,7 @@ def get_delay_data( return self.global_delay_targets[name] if name in self.global_delay_vars: - if isinstance(delay_step, int): + if isinstance(delay_step, (int, np.integer)): return self.global_delay_vars[name](delay_step, *indices) else: if len(indices) == 0: @@ -176,7 +176,7 @@ def get_delay_data( return self.global_delay_vars[name](delay_step, *indices) elif name in self.local_delay_vars: - if isinstance(delay_step, int): + if isinstance(delay_step, (int, np.integer)): return self.local_delay_vars[name](delay_step) else: if len(indices) == 0: @@ -784,13 +784,13 @@ def __init__( def var_shape(self): return self.size if self.keep_size else self.num - def update(self, t, dt): + def update(self, t, dt, V): raise NotImplementedError('Must be implemented by the subclass.') - def current(self): + def current(self, V): raise NotImplementedError('Must be implemented by the subclass.') - def reset(self): + def reset(self, V): raise NotImplementedError('Must be implemented by the subclass.') diff --git a/brainpy/tools/checking.py b/brainpy/tools/checking.py index e0ff76019..af75ba3a5 100644 --- a/brainpy/tools/checking.py +++ b/brainpy/tools/checking.py @@ -257,10 +257,10 @@ def check_float(value: float, name=None, min_bound=None, max_bound=None, else: raise ValueError(f'{name} must be a float, but got None') if allow_int: - if not isinstance(value, (float, int)): + if not isinstance(value, (float, int, np.integer, np.floating)): raise ValueError(f'{name} must be a float, but got {type(value)}') else: - if not isinstance(value, float): + if not isinstance(value, (float, np.floating)): raise ValueError(f'{name} must be a float, but got {type(value)}') if min_bound is not None: if value < min_bound: @@ -292,7 +292,7 @@ def check_integer(value: int, name=None, min_bound=None, max_bound=None, allow_n return else: raise ValueError(f'{name} must be an int, but got None') - if not isinstance(value, int): + if not isinstance(value, (int, np.integer)): if hasattr(value, '__array__'): if not (np.issubdtype(value.dtype, np.integer) and value.ndim == 0 and value.size == 1): raise ValueError(f'{name} must be an int, but got {value}') diff --git a/brainpy/train/base.py b/brainpy/train/base.py index 1eb1aab37..70eddff79 100644 --- a/brainpy/train/base.py +++ b/brainpy/train/base.py @@ -53,6 +53,10 @@ def __init__(self, name: str = None, trainable: bool = False): def trainable(self): return self._trainable + @trainable.setter + def trainable(self, value): + self._trainable = value + def __repr__(self): return f"{type(self).__name__}(name={self.name}, trainable={self.trainable})" From fb63584ef3803e0088922122c2e458f3714c0400 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 4 Jun 2022 14:21:23 +0800 Subject: [PATCH 4/9] [module] update `brainpy.losses` module --- brainpy/losses/__init__.py | 420 +------------------------------ brainpy/losses/comparison.py | 416 ++++++++++++++++++++++++++++++ brainpy/losses/regularization.py | 58 +++++ brainpy/losses/utils.py | 32 +++ docs/auto_generater.py | 12 +- 5 files changed, 517 
insertions(+), 421 deletions(-) create mode 100644 brainpy/losses/comparison.py create mode 100644 brainpy/losses/regularization.py create mode 100644 brainpy/losses/utils.py diff --git a/brainpy/losses/__init__.py b/brainpy/losses/__init__.py index 70b3fdc07..04a304ae0 100644 --- a/brainpy/losses/__init__.py +++ b/brainpy/losses/__init__.py @@ -7,423 +7,7 @@ # - https://github.com/deepmind/optax/blob/master/optax/_src/loss.py # - https://github.com/google/jaxopt/blob/main/jaxopt/_src/loss.py -import jax.numpy as jn -import jax.scipy -from jax.tree_util import tree_flatten +from .comparison import * +from .regularization import * -import brainpy.math as bm -from brainpy import errors -__all__ = [ - 'cross_entropy_loss', - 'l1_loos', - 'l2_loss', - 'l2_norm', - 'huber_loss', - 'mean_absolute_error', - 'mean_squared_error', - 'mean_squared_log_error', -] - -_reduction_error = 'Only support reduction of "mean", "sum" and "none", but we got "%s".' - - -def _return(outputs, reduction): - if reduction == 'mean': - return outputs.mean() - elif reduction == 'sum': - return outputs.sum() - elif reduction == 'none': - return outputs - else: - raise errors.UnsupportedError(_reduction_error % reduction) - - -def cross_entropy_loss(logits, targets, weight=None, reduction='mean'): - r"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class. - - It is useful when training a classification problem with `C` classes. - If provided, the optional argument :attr:`weight` should be a 1D `Tensor` - assigning weight to each of the classes. This is particularly useful when - you have an unbalanced training set. - - The ``input`` is expected to contain raw, unnormalized scores for each class. - ``input`` has to be an array of size either :math:`(minibatch, C)` or - :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1` for the - `K`-dimensional case (described later). - - This criterion expects a class index in the range :math:`[0, C-1]` as the - `target` for each value of a 1D tensor of size `minibatch`. - - The loss can be described as: - - .. math:: - \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) - = -x[class] + \log\left(\sum_j \exp(x[j])\right) - - or in the case of the :attr:`weight` argument being specified: - - .. math:: - \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right) - - Can also be used for higher dimension inputs, such as 2D images, by providing - an input of size :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1`, - where :math:`K` is the number of dimensions, and a target of appropriate shape. - - Parameters - ---------- - logits : jmath.JaxArray - :math:`(N, C)` where `C = number of classes`, or - :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` - in the case of `K`-dimensional loss. - targets : jmath.JaxArray - :math:`(N, C)` or :math:`(N)` where each value is - :math:`0 \leq \text{targets}[i] \leq C-1`, or - :math:`(d_1, d_2, ..., d_K, N, C)` or :math:`(d_1, d_2, ..., d_K, N)` - with :math:`K \geq 1` in the case of K-dimensional loss. - weight : mjax.JaxArray, optional - A manual rescaling weight given to each class. If given, has to be an array of size `C`. - reduction : str, optional - Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. - - ``'none'``: no reduction will be applied, - - ``'mean'``: the weighted mean of the output is taken, - - ``'sum'``: the output will be summed. 
- - Returns - ------- - output : scalar, mjax.JaxArray - If :attr:`reduction` is ``'none'``, then the same size as the target: - :math:`(N)`, or :math:`(d_1, d_2, ..., d_K, N)` with :math:`K \geq 1` - in the case of K-dimensional loss. - """ - targets = bm.as_device_array(targets) - logits = bm.as_device_array(logits) - - # loss - if bm.ndim(targets) + 1 == bm.ndim(logits): - # targets_old = targets.reshape((-1,)) - # length = targets_old.shape[0] - # rows = jn.arange(length) - # targets = ops.zeros((length, logits.shape[-1])) - # targets[rows, targets_old] = 1. - # targets = targets.reshape(logits.shape).value - targets = bm.activations.one_hot(targets, logits.shape[-1]) - loss = jax.scipy.special.logsumexp(logits, axis=-1) - (logits * targets).sum(axis=-1) - - # weighted loss - if weight: - loss *= weight[targets] - raise NotImplementedError - - return _return(outputs=loss, reduction=reduction) - - -def cross_entropy_sparse(logits, labels): - r"""Computes the softmax cross-entropy loss. - - Args: - logits: (batch, ..., #class) tensor of logits. - labels: (batch, ...) integer tensor of label indexes in {0, ...,#nclass-1} or just a single integer. - - Returns: - (batch, ...) tensor of the cross-entropy for each entry. - """ - - if isinstance(labels, int): - labeled_logits = logits[..., labels] - else: - logits = bm.as_device_array(logits) - labels = bm.as_device_array(labels) - labeled_logits = jn.take_along_axis(logits, labels, -1).squeeze(-1) - loss = jax.scipy.special.logsumexp(logits, axis=-1) - labeled_logits - return loss - - -def cross_entropy_sigmoid(logits, labels): - """Computes the sigmoid cross-entropy loss. - - Args: - logits: (batch, ..., #class) tensor of logits. - labels: (batch, ..., #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1) - - Returns: - (batch, ...) tensor of the cross-entropies for each entry. - """ - return jax.numpy.maximum(logits, 0) - logits * labels + \ - jax.numpy.log(1 + jax.numpy.exp(-jax.numpy.abs(logits))) - - -def l1_loos(logits, targets, reduction='sum'): - r"""Creates a criterion that measures the mean absolute error (MAE) between each element in - the logits :math:`x` and targets :math:`y`. It is useful in regression problems. - - The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: - - .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = \left| x_n - y_n \right|, - - where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` - (default ``'mean'``), then: - - .. math:: - \ell(x, y) = - \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} - \end{cases} - - :math:`x` and :math:`y` are tensors of arbitrary shapes with a total - of :math:`n` elements each. - - The sum operation still operates over all the elements, and divides by :math:`n`. - - The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. - - Supports real-valued and complex-valued inputs. - - Parameters - ---------- - logits : jmath.JaxArray - :math:`(N, *)` where :math:`*` means, any number of additional dimensions. - targets : jmath.JaxArray - :math:`(N, *)`, same shape as the input. - reduction : str - Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. - Default: ``'mean'``. - - ``'none'``: no reduction will be applied, - - ``'mean'``: the sum of the output will be divided by the number of elements in the output, - - ``'sum'``: the output will be summed. 
Note: :attr:`size_average` - - Returns - ------- - output : scalar. - If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input. - """ - diff = (logits - targets).reshape((logits.shape[0], -1)) - norm = jn.linalg.norm(bm.as_device_array(diff), ord=1, axis=1, keepdims=False) - return _return(outputs=norm, reduction=reduction) - - -def l2_loss(predicts, targets): - r"""Computes the L2 loss. - - The 0.5 term is standard in "Pattern Recognition and Machine Learning" - by Bishop [1]_, but not "The Elements of Statistical Learning" by Tibshirani. - - Parameters - ---------- - - predicts: JaxArray - A vector of arbitrary shape. - targets: JaxArray - A vector of shape compatible with predictions. - - Returns - ------- - loss : float - A scalar value containing the l2 loss. - - References - ---------- - .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning. - """ - return bm.as_device_array(0.5 * (predicts - targets) ** 2) - - -def l2_norm(x): - """Computes the L2 loss. - - Args: - x: n-dimensional tensor of floats. - - Returns: - scalar tensor containing the l2 loss of x. - """ - leaves, _ = tree_flatten(x) - return jn.sqrt(sum(jn.vdot(x, x) for x in leaves)) - - -def mean_absolute_error(x, y, axis=None): - r"""Computes the mean absolute error between x and y. - - Args: - x: a tensor of shape (d0, .. dN-1). - y: a tensor of shape (d0, .. dN-1). - keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. - """ - r = bm.abs(x - y) - return jn.mean(bm.as_device_array(r), axis=axis) - - -def mean_squared_error(predicts, targets, axis=None): - r"""Computes the mean squared error between x and y. - - Args: - predicts: a tensor of shape (d0, .. dN-1). - targets: a tensor of shape (d0, .. dN-1). - keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. - """ - r = (predicts - targets) ** 2 - return jn.mean(bm.as_device_array(r), axis=axis) - - -def mean_squared_log_error(y_true, y_pred, axis=None): - r"""Computes the mean squared logarithmic error between y_true and y_pred. - - Args: - y_true: a tensor of shape (d0, .. dN-1). - y_pred: a tensor of shape (d0, .. dN-1). - keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. - """ - r = (bm.log1p(y_true) - bm.log1p(y_pred)) ** 2 - return jn.mean(bm.as_device_array(r), axis=axis) - - -def huber_loss(predicts, targets, delta: float = 1.0): - r"""Huber loss. - - Huber loss is similar to L2 loss close to zero, L1 loss away from zero. - If gradient descent is applied to the `huber loss`, it is equivalent to - clipping gradients of an `l2_loss` to `[-delta, delta]` in the backward pass. - - Parameters - ---------- - predicts: JaxArray - predictions - targets: JaxArray - ground truth - delta: float - radius of quadratic behavior - - Returns - ------- - loss : float - The loss value. - - References - ---------- - .. 
[1] https://en.wikipedia.org/wiki/Huber_loss - """ - diff = bm.as_device_array(bm.abs(targets - predicts)) - # 0.5 * err^2 if |err| <= d - # 0.5 * d^2 + d * (|err| - d) if |err| > d - return jn.where(diff > delta, delta * (diff - .5 * delta), 0.5 * diff ** 2) - - -def binary_logistic_loss(logits: float, labels: int, ) -> float: - """Binary logistic loss. - - Args: - labels: ground-truth integer label (0 or 1). - logits: score produced by the model (float). - Returns: - loss value - """ - # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1]. - # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba), - # where xlogx(proba) = proba * log(proba). - return bm.activations.softplus(logits) - labels * logits - - -def multiclass_logistic_loss(label: int, logits: jn.ndarray) -> float: - """Multiclass logistic loss. - - Args: - label: ground-truth integer label, between 0 and n_classes - 1. - logits: scores produced by the model, shape = (n_classes, ). - Returns: - loss value - """ - n_classes = logits.shape[0] - one_hot = jax.nn.one_hot(label, n_classes) - # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex. - # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba)) - return jax.scipy.special.logsumexp(logits) - jn.dot(logits, one_hot) - - -def smooth_labels(labels, alpha: float) -> jn.ndarray: - r"""Apply label smoothing. - Label smoothing is often used in combination with a cross-entropy loss. - Smoothed labels favour small logit gaps, and it has been shown that this can - provide better model calibration by preventing overconfident predictions. - References: - [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) - Args: - labels: one hot labels to be smoothed. - alpha: the smoothing factor, the greedy category with be assigned - probability `(1-alpha) + alpha / num_categories` - Returns: - a smoothed version of the one hot input labels. - """ - num_categories = labels.shape[-1] - return (1.0 - alpha) * labels + alpha / num_categories - - -def sigmoid_binary_cross_entropy(logits, labels): - """Computes sigmoid cross entropy given logits and multiple class labels. - Measures the probability error in discrete classification tasks in which - each class is an independent binary prediction and different classes are - not mutually exclusive. This may be used for multilabel image classification - for instance a model may predict that an image contains both a cat and a dog. - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) - Args: - logits: unnormalized log probabilities. - labels: the probability for that class. - Returns: - a sigmoid cross entropy loss. - """ - log_p = jax.nn.log_sigmoid(logits) - # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable - log_not_p = jax.nn.log_sigmoid(-logits) - return -labels * log_p - (1. - labels) * log_not_p - - -def softmax_cross_entropy(logits, labels): - """Computes the softmax cross entropy between sets of logits and labels. - Measures the probability error in discrete classification tasks in which - the classes are mutually exclusive (each entry is in exactly one class). - For example, each CIFAR-10 image is labeled with one and only one label: - an image can be a dog or a truck, but not both. - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) - Args: - logits: unnormalized log probabilities. 
- labels: a valid probability distribution (non-negative, sum to 1), e.g a - one hot encoding of which class is the correct one for each input. - Returns: - the cross entropy loss. - """ - logits = bm.as_device_array(logits) - labels = bm.as_device_array(labels) - return -jn.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) - - -def log_cosh(predicts, targets=None, ): - r"""Calculates the log-cosh loss for a set of predictions. - - log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` - for large x. It is a twice differentiable alternative to the Huber loss. - References: - [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) - Args: - predicts: a vector of arbitrary shape. - targets: a vector of shape compatible with predictions; if not provided - then it is assumed to be zero. - Returns: - the log-cosh loss. - """ - errors = (predicts - targets) if (targets is not None) else predicts - errors = bm.as_device_array(errors) - # log(cosh(x)) = log((exp(x) + exp(-x))/2) = log(exp(x) + exp(-x)) - log(2) - return jn.logaddexp(errors, -errors) - jn.log(2.0).astype(errors.dtype) diff --git a/brainpy/losses/comparison.py b/brainpy/losses/comparison.py new file mode 100644 index 000000000..730849fe6 --- /dev/null +++ b/brainpy/losses/comparison.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- + +""" +This module implements several loss functions. +""" + +# - https://github.com/deepmind/optax/blob/master/optax/_src/loss.py +# - https://github.com/google/jaxopt/blob/main/jaxopt/_src/loss.py + +import jax.numpy as jnp +import jax.scipy +from jax.tree_util import tree_flatten, tree_map + +import brainpy.math as bm +from brainpy import errors +from .utils import _return, _multi_return, _is_leaf + +__all__ = [ + 'cross_entropy_loss', + 'cross_entropy_sparse', + 'cross_entropy_sigmoid', + 'l1_loos', + 'l2_loss', + 'huber_loss', + 'mean_absolute_error', + 'mean_squared_error', + 'mean_squared_log_error', + 'binary_logistic_loss', + 'multiclass_logistic_loss', + 'smooth_labels', + 'sigmoid_binary_cross_entropy', + 'softmax_cross_entropy', + 'log_cosh_loss', +] + + +def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'): + r"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class. + + It is useful when training a classification problem with `C` classes. + If provided, the optional argument :attr:`weight` should be a 1D `Tensor` + assigning weight to each of the classes. This is particularly useful when + you have an unbalanced training set. + + The ``input`` is expected to contain raw, unnormalized scores for each class. + ``input`` has to be an array of size either :math:`(minibatch, C)` or + :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1` for the + `K`-dimensional case (described later). + + This criterion expects a class index in the range :math:`[0, C-1]` as the + `target` for each value of a 1D tensor of size `minibatch`. + + The loss can be described as: + + .. math:: + \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) + = -x[class] + \log\left(\sum_j \exp(x[j])\right) + + or in the case of the :attr:`weight` argument being specified: + + .. 
math:: + \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right) + + Can also be used for higher dimension inputs, such as 2D images, by providing + an input of size :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1`, + where :math:`K` is the number of dimensions, and a target of appropriate shape. + + Parameters + ---------- + predicts : jmath.JaxArray + :math:`(N, C)` where `C = number of classes`, or + :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + targets : jmath.JaxArray + :math:`(N, C)` or :math:`(N)` where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(d_1, d_2, ..., d_K, N, C)` or :math:`(d_1, d_2, ..., d_K, N)` + with :math:`K \geq 1` in the case of K-dimensional loss. + weight : JaxArray, optional + A manual rescaling weight given to each class. If given, has to be an array of size `C`. + reduction : str, optional + Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. + - ``'none'``: no reduction will be applied, + - ``'mean'``: the weighted mean of the output is taken, + - ``'sum'``: the output will be summed. + + Returns + ------- + output : scalar, mjax.JaxArray + If :attr:`reduction` is ``'none'``, then the same size as the target: + :math:`(N)`, or :math:`(d_1, d_2, ..., d_K, N)` with :math:`K \geq 1` + in the case of K-dimensional loss. + """ + targets = bm.as_device_array(targets) + predicts = bm.as_device_array(predicts) + + # loss + if bm.ndim(targets) + 1 == bm.ndim(predicts): + # targets_old = targets.reshape((-1,)) + # length = targets_old.shape[0] + # rows = jn.arange(length) + # targets = ops.zeros((length, logits.shape[-1])) + # targets[rows, targets_old] = 1. + # targets = targets.reshape(logits.shape).value + targets = bm.activations.one_hot(targets, predicts.shape[-1]) + loss = jax.scipy.special.logsumexp(predicts, axis=-1) - (predicts * targets).sum(axis=-1) + + # weighted loss + if weight: + loss *= weight[targets] + raise NotImplementedError + + return _return(outputs=loss, reduction=reduction) + + +def cross_entropy_sparse(predicts, targets): + r"""Computes the softmax cross-entropy loss. + + Args: + predicts: (batch, ..., #class) tensor of logits. + targets: (batch, ...) integer tensor of label indexes in {0, ...,#nclass-1} or just a single integer. + + Returns: + (batch, ...) tensor of the cross-entropy for each entry. + """ + predicts = bm.as_device_array(predicts) + targets = bm.as_device_array(targets) + if isinstance(targets, int): + labeled_logits = predicts[..., targets] + else: + labeled_logits = jnp.take_along_axis(predicts, targets, -1).squeeze(-1) + loss = jax.scipy.special.logsumexp(predicts, axis=-1) - labeled_logits + return loss + + +def cross_entropy_sigmoid(predicts, targets): + """Computes the sigmoid cross-entropy loss. + + Args: + predicts: (batch, ..., #class) tensor of logits. + targets: (batch, ..., #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1) + + Returns: + (batch, ...) tensor of the cross-entropies for each entry. + """ + predicts = bm.as_device_array(predicts) + targets = bm.as_device_array(targets) + return jnp.maximum(predicts, 0) - predicts * targets + jnp.log(1 + jnp.exp(-jnp.abs(predicts))) + + +def l1_loos(logits, targets, reduction='sum'): + r"""Creates a criterion that measures the mean absolute error (MAE) between each element in + the logits :math:`x` and targets :math:`y`. It is useful in regression problems. 
+ + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left| x_n - y_n \right|, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The sum operation still operates over all the elements, and divides by :math:`n`. + + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + + Supports real-valued and complex-valued inputs. + + Parameters + ---------- + logits : jmath.JaxArray + :math:`(N, *)` where :math:`*` means, any number of additional dimensions. + targets : jmath.JaxArray + :math:`(N, *)`, same shape as the input. + reduction : str + Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. + Default: ``'mean'``. + - ``'none'``: no reduction will be applied, + - ``'mean'``: the sum of the output will be divided by the number of elements in the output, + - ``'sum'``: the output will be summed. Note: :attr:`size_average` + + Returns + ------- + output : scalar. + If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input. + """ + diff = (logits - targets).reshape((logits.shape[0], -1)) + norm = jnp.linalg.norm(bm.as_device_array(diff), ord=1, axis=1, keepdims=False) + return _return(outputs=norm, reduction=reduction) + + +def l2_loss(predicts, targets): + r"""Computes the L2 loss. + + The 0.5 term is standard in "Pattern Recognition and Machine Learning" + by Bishop [1]_, but not "The Elements of Statistical Learning" by Tibshirani. + + Parameters + ---------- + + predicts: JaxArray + A vector of arbitrary shape. + targets: JaxArray + A vector of shape compatible with predictions. + + Returns + ------- + loss : float + A scalar value containing the l2 loss. + + References + ---------- + .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning. + """ + return bm.as_device_array(0.5 * (predicts - targets) ** 2) + + +def mean_absolute_error(x, y, axis=None): + r"""Computes the mean absolute error between x and y. + + Args: + x: a tensor of shape (d0, .. dN-1). + y: a tensor of shape (d0, .. dN-1). + axis: a sequence of the dimensions to keep, use `None` to return a scalar value. + + Returns: + tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. + """ + r = tree_map(lambda a, b: bm.mean(bm.abs(a - b), axis=axis), x, y, is_leaf=_is_leaf) + return _multi_return(r) + + +def mean_squared_error(predicts, targets, axis=None): + r"""Computes the mean squared error between x and y. + + Args: + predicts: a tensor of shape (d0, .. dN-1). + targets: a tensor of shape (d0, .. dN-1). + keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. + + Returns: + tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. + """ + r = tree_map(lambda a, b: bm.mean((a - b) ** 2, axis=axis), predicts, targets, is_leaf=_is_leaf) + return _multi_return(r) + + +def mean_squared_log_error(predicts, targets, axis=None): + r"""Computes the mean squared logarithmic error between y_true and y_pred. + + Args: + targets: a tensor of shape (d0, .. dN-1). + predicts: a tensor of shape (d0, .. dN-1). 
+ keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. + + Returns: + tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. + """ + r = tree_map(lambda a, b: bm.mean((bm.log1p(a) - bm.log1p(b)) ** 2, axis=axis), + predicts, targets, is_leaf=_is_leaf) + return _multi_return(r) + + +def huber_loss(predicts, targets, delta: float = 1.0): + r"""Huber loss. + + Huber loss is similar to L2 loss close to zero, L1 loss away from zero. + If gradient descent is applied to the `huber loss`, it is equivalent to + clipping gradients of an `l2_loss` to `[-delta, delta]` in the backward pass. + + Parameters + ---------- + predicts: JaxArray + predictions + targets: JaxArray + ground truth + delta: float + radius of quadratic behavior + + Returns + ------- + loss : float + The loss value. + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Huber_loss + """ + def _loss(y_predict, y_target): + # 0.5 * err^2 if |err| <= d + # 0.5 * d^2 + d * (|err| - d) if |err| > d + diff = bm.abs(y_predict - y_target) + return bm.where(diff > delta, + delta * (diff - .5 * delta), + 0.5 * diff ** 2) + + return tree_map(_loss, targets, predicts, is_leaf=_is_leaf) + + +def binary_logistic_loss(predicts: float, targets: int, ) -> float: + """Binary logistic loss. + + Args: + targets: ground-truth integer label (0 or 1). + predicts: score produced by the model (float). + Returns: + loss value + """ + # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1]. + # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba), + # where xlogx(proba) = proba * log(proba). + return bm.as_device_array(bm.activations.softplus(predicts) - targets * predicts) + + +def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float: + """Multiclass logistic loss. + + Args: + label: ground-truth integer label, between 0 and n_classes - 1. + logits: scores produced by the model, shape = (n_classes, ). + Returns: + loss value + """ + logits = bm.as_device_array(logits) + n_classes = logits.shape[0] + one_hot = bm.one_hot(label, n_classes) + # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex. + # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba)) + return jax.scipy.special.logsumexp(logits) - bm.dot(logits, one_hot) + + +def smooth_labels(labels, alpha: float) -> jnp.ndarray: + r"""Apply label smoothing. + Label smoothing is often used in combination with a cross-entropy loss. + Smoothed labels favour small logit gaps, and it has been shown that this can + provide better model calibration by preventing overconfident predictions. + References: + [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) + Args: + labels: one hot labels to be smoothed. + alpha: the smoothing factor, the greedy category with be assigned + probability `(1-alpha) + alpha / num_categories` + Returns: + a smoothed version of the one hot input labels. + """ + num_categories = labels.shape[-1] + return (1.0 - alpha) * labels + alpha / num_categories + + +def sigmoid_binary_cross_entropy(logits, labels): + """Computes sigmoid cross entropy given logits and multiple class labels. + Measures the probability error in discrete classification tasks in which + each class is an independent binary prediction and different classes are + not mutually exclusive. This may be used for multilabel image classification + for instance a model may predict that an image contains both a cat and a dog. 
+ References: + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) + Args: + logits: unnormalized log probabilities. + labels: the probability for that class. + Returns: + a sigmoid cross entropy loss. + """ + log_p = bm.log_sigmoid(logits) + # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable + log_not_p = bm.log_sigmoid(-logits) + return -labels * log_p - (1. - labels) * log_not_p + + +def softmax_cross_entropy(logits, labels): + """Computes the softmax cross entropy between sets of logits and labels. + Measures the probability error in discrete classification tasks in which + the classes are mutually exclusive (each entry is in exactly one class). + For example, each CIFAR-10 image is labeled with one and only one label: + an image can be a dog or a truck, but not both. + References: + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) + Args: + logits: unnormalized log probabilities. + labels: a valid probability distribution (non-negative, sum to 1), e.g a + one hot encoding of which class is the correct one for each input. + Returns: + the cross entropy loss. + """ + logits = bm.as_device_array(logits) + labels = bm.as_device_array(labels) + return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) + + +def log_cosh_loss(predicts, targets): + r"""Calculates the log-cosh loss for a set of predictions. + + log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` + for large x. It is a twice differentiable alternative to the Huber loss. + References: + [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) + Args: + predicts: a vector of arbitrary shape. + targets: a vector of shape compatible with predictions; if not provided + then it is assumed to be zero. + Returns: + the log-cosh loss. + """ + errors = bm.as_device_array(predicts - targets) + return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) diff --git a/brainpy/losses/regularization.py b/brainpy/losses/regularization.py new file mode 100644 index 000000000..e4612b48e --- /dev/null +++ b/brainpy/losses/regularization.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- + +from jax.tree_util import tree_flatten, tree_map + +import brainpy.math as bm +from .utils import _is_leaf, _multi_return + +__all__ = [ + 'l2_norm', + 'mean_absolute', + 'mean_square', + 'log_cosh', +] + + +def l2_norm(x, axis=None): + """Computes the L2 loss. + + Args: + x: n-dimensional tensor of floats. + + Returns: + scalar tensor containing the l2 loss of x. + """ + leaves, _ = tree_flatten(x, is_leaf=_is_leaf) + return bm.sqrt(bm.sum([bm.vdot(x, x) for x in leaves], axis=axis)) + + +def mean_absolute(outputs, axis=None): + r"""Computes the mean absolute error between x and y. + + Returns: + tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. + """ + r = tree_map(lambda a: bm.mean(bm.abs(a), axis=axis), outputs, is_leaf=_is_leaf) + return _multi_return(r) + + +def mean_square(predicts, axis=None): + r = tree_map(lambda a: bm.mean(a ** 2, axis=axis), predicts, is_leaf=_is_leaf) + return _multi_return(r) + + +def log_cosh(errors): + r"""Calculates the log-cosh loss for a set of predictions. + + log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` + for large x. It is a twice differentiable alternative to the Huber loss. + References: + [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) + Args: + errors: a vector of arbitrary shape. + Returns: + the log-cosh loss. 
+ """ + r = tree_map(lambda a: bm.logaddexp(a, -a) - bm.log(2.0).astype(a.dtype), + errors, is_leaf=_is_leaf) + return _multi_return(r) diff --git a/brainpy/losses/utils.py b/brainpy/losses/utils.py new file mode 100644 index 000000000..fec7c026c --- /dev/null +++ b/brainpy/losses/utils.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- + + +from jax.tree_util import tree_flatten + +import brainpy.math as bm +from brainpy.errors import UnsupportedError + +_reduction_error = 'Only support reduction of "mean", "sum" and "none", but we got "%s".' + + +def _is_leaf(x): + return isinstance(x, bm.JaxArray) + + +def _return(outputs, reduction): + if reduction == 'mean': + return outputs.mean() + elif reduction == 'sum': + return outputs.sum() + elif reduction == 'none': + return outputs + else: + raise UnsupportedError(_reduction_error % reduction) + + +def _multi_return(r): + leaves = tree_flatten(r)[0] + r = leaves[0] + for leaf in leaves[1:]: + r += leaf + return r diff --git a/docs/auto_generater.py b/docs/auto_generater.py index 64a773683..8fe673b62 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -388,9 +388,15 @@ def generate_losses_docs(path='apis/auto/'): if not os.path.exists(path): os.makedirs(path) - write_module(module_name='brainpy.losses', - filename=os.path.join(path, 'losses.rst'), - header='``brainpy.losses`` module') + module_and_name = [ + ('Comparison', 'comparison'), + ('Regularization', 'regularization'), + ] + write_submodules(module_name='brainpy.losses', + filename=os.path.join(path, 'losses.rst'), + header='``brainpy.losses`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) def generate_math_docs(path='apis/auto/math/'): From aac793e7769f8bddfe6fa3c5bcebefdf05900aa0 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 4 Jun 2022 14:22:30 +0800 Subject: [PATCH 5/9] [module] enhance `brainpy.analysis` for directly analyzing DynamicalSystem instance --- brainpy/analysis/__init__.py | 1 + brainpy/analysis/base.py | 12 + brainpy/analysis/constants.py | 9 + brainpy/analysis/highdim/slow_points.py | 584 ++++++++++++++---- .../highdim/tests/test_slow_points.py | 164 +++++ brainpy/analysis/lowdim/lowdim_analyzer.py | 27 +- brainpy/analysis/utils/measurement.py | 74 ++- brainpy/analysis/utils/others.py | 59 +- brainpy/base/collector.py | 36 +- brainpy/dyn/runners.py | 135 ++-- brainpy/math/controls.py | 10 +- examples/analysis/2d_decision_making_model.py | 21 +- examples/analysis/4d_HH_model.py | 3 +- examples/analysis/highdim_CANN.py | 6 +- examples/analysis/highdim_RNN_Analysis.py | 2 +- examples/analysis/highdim_gj_coupled_fhn.py | 49 +- 16 files changed, 929 insertions(+), 263 deletions(-) create mode 100644 brainpy/analysis/base.py create mode 100644 brainpy/analysis/highdim/tests/test_slow_points.py diff --git a/brainpy/analysis/__init__.py b/brainpy/analysis/__init__.py index aeeebe272..c63d15b4c 100644 --- a/brainpy/analysis/__init__.py +++ b/brainpy/analysis/__init__.py @@ -19,6 +19,7 @@ from .lowdim.lowdim_phase_plane import * from .lowdim.lowdim_bifurcation import * +from .constants import * from . import constants as C from . import stability from . 
import utils diff --git a/brainpy/analysis/base.py b/brainpy/analysis/base.py new file mode 100644 index 000000000..3001289c2 --- /dev/null +++ b/brainpy/analysis/base.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + + +__all__ = [ + 'BrainPyAnalyzer' +] + + +class BrainPyAnalyzer(object): + """Base class for analyzers in BrainPy""" + pass + diff --git a/brainpy/analysis/constants.py b/brainpy/analysis/constants.py index ae85b6527..e9691cca5 100644 --- a/brainpy/analysis/constants.py +++ b/brainpy/analysis/constants.py @@ -1,6 +1,15 @@ # -*- coding: utf-8 -*- +__all__ = [ + 'CONTINUOUS', + 'DISCRETE', +] + + +CONTINUOUS = 'continuous' +DISCRETE = 'discrete' + F_vmap_fx = 'F_vmap_fx' F_vmap_fy = 'F_vmap_fy' F_vmap_brentq_fx = 'F_vmap_brentq_fx' diff --git a/brainpy/analysis/highdim/slow_points.py b/brainpy/analysis/highdim/slow_points.py index 598768f24..50bf8098b 100644 --- a/brainpy/analysis/highdim/slow_points.py +++ b/brainpy/analysis/highdim/slow_points.py @@ -2,24 +2,36 @@ import time import warnings -from functools import partial +from typing import Callable, Union, Dict, Optional, Sequence -from jax import vmap -import jax.numpy +import jax.numpy as jnp import numpy as np +from jax import vmap from jax.scipy.optimize import minimize +from jax.tree_util import tree_flatten, tree_unflatten, tree_map import brainpy.math as bm -from brainpy import optimizers as optim -from brainpy.analysis import utils -from brainpy.errors import AnalyzerError +from brainpy import optimizers as optim, losses +from brainpy.analysis import utils, base, constants +from brainpy.base import TensorCollector +from brainpy.dyn.base import DynamicalSystem +from brainpy.dyn.runners import build_inputs, check_and_format_inputs +from brainpy.errors import AnalyzerError, UnsupportedError +from brainpy.types import Tensor __all__ = [ 'SlowPointFinder', ] +F_OPT_SOLVER = 'function_for_opt_solver' +F_GRADIENT_DESCENT = 'function_for_gradient_descent' + +SUPPORTED_OPT_SOLVERS = { + 'BFGS': lambda f, x0: minimize(f, x0, method='BFGS') +} -class SlowPointFinder(object): + +class SlowPointFinder(base.BrainPyAnalyzer): """Find fixed/slow points by numerical optimization. This class can help you: @@ -29,85 +41,253 @@ class SlowPointFinder(object): - exclude any non-unique fixed points according to a tolerance - exclude any far-away "outlier" fixed points - This model implementation is inspired by https://github.com/google-research/computation-thru-dynamics. - Parameters ---------- - f_cell : callable, function - The function to compute the recurrent units. + f_cell : callable, function, DynamicalSystem + The target of computing the recurrent units. + f_type : str The system's type: continuous system or discrete system. - 'continuous': continuous derivative function, denotes this is a continuous system, or - 'discrete': discrete update function, denotes this is a discrete system. + + verbose : bool + Whether output the optimization progress. + + f_loss: callable + The loss function. + - If ``f_type`` is `"discrete"`, the loss function must receive three arguments, i.e., + ``loss(outputs, targets, axis)``. + - If ``f_type`` is `"continuous"`, the loss function must receive two arguments, i.e., + ``loss(outputs, axis)``. + + .. versionadded:: 2.2.0 + + t: float + The time to evaluate the fixed points. + + .. versionadded:: 2.2.0 + + inputs: sequence + Same as ``inputs`` in :py:class:`~.DSRunner`. + + .. 
versionadded:: 2.2.0 + + excluded_vars: sequence + The excluded variables (can be a sequence of `Variable` instances), + when ``f_cell`` is an instance of :py:class:`~.DynamicalSystem`. + These variables will not be included for optimization of fixed points. + + .. versionadded:: 2.2.0 + + included_vars: dict + The target variables (can be a dict of `Variable` instances), + when ``f_cell`` is an instance of :py:class:`~.DynamicalSystem`. + These variables will be included for optimization of fixed points. + The candidate points later provided should have same keys as in ``included_vars``. + + .. versionadded:: 2.2.0 + f_loss_batch : callable, function The function to compute the loss. - verbose : bool - Whether print the optimization progress. + + .. deprecated:: 2.2.0 + Has been removed. Please use ``f_loss`` to set different loss function. + """ - def __init__(self, f_cell, f_type='continuous', f_loss_batch=None, verbose=True): + def __init__( + self, + f_cell: Union[Callable, DynamicalSystem], + f_type: str = None, + f_loss: Callable = None, + inputs: Sequence = None, + t: float = 0., + verbose: bool = True, + f_loss_batch: Callable = None, + included_vars: Dict[str, bm.Variable] = None, + excluded_vars: Sequence[bm.Variable] = (), + ): + super(SlowPointFinder, self).__init__() + + # update function + if included_vars is None: + self.included_vars = TensorCollector() + else: + if not isinstance(included_vars, dict): + raise TypeError(f'"included_vars" must be a dict but we got {type(included_vars)}') + self.included_vars = TensorCollector(included_vars) + if not isinstance(excluded_vars, (tuple, list)): + raise TypeError(f'"excluded_vars" must be a sequence but we got {type(excluded_vars)}') + for v in excluded_vars: + if not isinstance(v, bm.Variable): + raise TypeError(f'"excluded_vars" must be a sequence of Variable, ' + f'but we got {type(v)}') + self.excluded_vars = {f'_exclude_v{i}': v for i, v in enumerate(excluded_vars)} + self.target = f_cell + + if len(self.included_vars) > 0 and len(self.excluded_vars) > 0: + raise ValueError + + if isinstance(f_cell, DynamicalSystem): + # included variables + all_vars = f_cell.vars(method='relative', level=-1, include_self=True).unique() + # exclude variables + if len(self.included_vars) > 0: + _all_ids = [id(v) for v in self.included_vars.values()] + for k, v in all_vars.items(): + if id(v) not in _all_ids: + self.excluded_vars[k] = v + else: + self.included_vars = all_vars + if len(excluded_vars): + excluded_vars = [id(v) for v in excluded_vars] + for key, val in tuple(self.included_vars.items()): + if id(val) in excluded_vars: + self.included_vars.pop(key) + # input function + if inputs is not None: + inputs = check_and_format_inputs(host=self.target, inputs=inputs) + _input_step, _i = build_inputs(inputs) + if _i is not None: + raise UnsupportedError(f'Do not support iterable inputs when using fixed point finder.') + else: + _input_step = None + # update function + self.f_cell = self._generate_ds_cell_function(self.target, + self.included_vars, + self.excluded_vars, + t, + _input_step) + # check function type + if f_type is not None: + if f_type != constants.DISCRETE: + raise ValueError(f'"f_type" must be "{constants.DISCRETE}" when "f_cell" ' + f'is instance of {DynamicalSystem.__name__}') + f_type = constants.DISCRETE + + elif callable(f_cell): + self.f_cell = f_cell + if inputs is not None: + raise UnsupportedError('Do not support "inputs" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') + + else: + raise 
ValueError(f'Unknown type of "f_type": {type(f_cell)}') + if f_type not in [constants.DISCRETE, constants.CONTINUOUS]: + raise AnalyzerError(f'Only support "{constants.CONTINUOUS}" (continuous derivative function) or ' + f'"{constants.DISCRETE}" (discrete update function), not {f_type}.') self.verbose = verbose - if f_type not in ['discrete', 'continuous']: - raise AnalyzerError(f'Only support "continuous" (continuous derivative function) or ' - f'"discrete" (discrete update function), not {f_type}.') + self.f_type = f_type + + # loss functon + if f_loss_batch is not None: + raise UnsupportedError('"f_loss_batch" is no longer supported, please ' + 'use "f_loss" instead.') + if f_loss is None: + f_loss = losses.mean_squared_error if f_type == constants.DISCRETE else losses.mean_square + self.f_loss = f_loss # functions - self.f_cell = f_cell - if f_loss_batch is None: - if f_type == 'discrete': - self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h)) ** 2)) - self.f_loss_batch = bm.jit(lambda h: bm.mean((h - vmap(f_cell)(h)) ** 2, axis=1)) - if f_type == 'continuous': - self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2)) - self.f_loss_batch = bm.jit(lambda h: bm.mean((vmap(f_cell)(h)) ** 2, axis=1)) + self._opt_functions = dict() + # functions + if self.f_type == constants.DISCRETE: + # evaluate losses of a batch of inputs + self.f_eval_loss = bm.jit(lambda h: self.f_loss(h, vmap(self.f_cell)(h), axis=1)) else: - self.f_loss_batch = f_loss_batch - self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2)) - self.f_jacob_batch = bm.jit(vmap(bm.jacobian(f_cell))) + # evaluate losses of a batch of inputs + self.f_eval_loss = bm.jit(lambda h: self.f_loss(vmap(self.f_cell)(h), axis=1)) + # evaluate Jacobian matrix of a batch of inputs + + # if f_type == constants.DISCRETE: + # # overall loss function for fixed points optimization + # self.f_loss = bm.jit(lambda h: f_loss(h, f_cell(h))) + # # evaluate losses of a batch of inputs + # self.f_loss_batch = bm.jit(lambda h: f_loss(h, vmap(f_cell)(h), axis=1)) + # elif f_type == constants.CONTINUOUS: + # # overall loss function for fixed points optimization + # self.f_loss = bm.jit(lambda h: f_loss(f_cell(h))) + # # evaluate losses of a batch of inputs + # self.f_loss_batch = bm.jit(lambda h: f_loss(vmap(f_cell)(h), axis=1)) # essential variables self._losses = None self._fixed_points = None self._selected_ids = None - self.opt_losses = None + self._opt_losses = None @property - def fixed_points(self): + def opt_losses(self) -> np.ndarray: + """The optimization losses.""" + return np.asarray(self._opt_losses) + + @opt_losses.setter + def opt_losses(self, val): + raise UnsupportedError('Do not support set "opt_losses" by users.') + + @property + def fixed_points(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: """The final fixed points found.""" - return self._fixed_points + return tree_map(lambda a: np.asarray(a), self._fixed_points) + + @fixed_points.setter + def fixed_points(self, val): + raise UnsupportedError('Do not support set "fixed_points" by users.') @property - def losses(self): + def num_fps(self) -> int: + if isinstance(self._fixed_points, dict): + return tuple(self._fixed_points.values())[0].shape[0] + else: + return self._fixed_points.shape[0] + + @property + def losses(self) -> np.ndarray: """Losses of fixed points.""" - return self._losses + return np.asarray(self._losses) + + @losses.setter + def losses(self, val): + raise UnsupportedError('Do not support set "losses" by users.') @property - def selected_ids(self): + def 
selected_ids(self) -> np.ndarray: """The selected ids of candidate points.""" - return self._selected_ids - - def find_fps_with_gd_method(self, - candidates, - tolerance=1e-5, - num_batch=100, - num_opt=10000, - optimizer=None, - opt_setting=None): + return np.asarray(self._selected_ids) + + @selected_ids.setter + def selected_ids(self, val): + raise UnsupportedError('Do not support set "selected_ids" by users.') + + + def find_fps_with_gd_method( + self, + candidates: Union[Tensor, Dict[str, Tensor]], + tolerance: Union[float, Dict[str, float]] = 1e-5, + num_batch: int = 100, + num_opt: int = 10000, + optimizer: optim.Optimizer = None, + opt_setting: Optional[Dict] = None + ): """Optimize fixed points with gradient descent methods. Parameters ---------- - candidates : jax.ndarray, JaxArray + candidates : Tensor, dict The array with the shape of (batch size, state dim) of hidden states of RNN to start training for fixed points. + tolerance: float The loss threshold during optimization. + num_opt : int The maximum number of optimization steps. + num_batch : int The number of optimization steps between progress reports. + opt_setting: optional, dict The optimization settings. @@ -126,13 +306,13 @@ def find_fps_with_gd_method(self, optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999), beta1=0.9, beta2=0.999, eps=1e-8) else: - assert isinstance(optimizer, optim.Optimizer), (f'Must be an instance of ' - f'{optim.Optimizer.__name__}, ' - f'while we got {type(optimizer)}') + if not isinstance(optimizer, optim.Optimizer): + raise ValueError(f'Must be an instance of {optim.Optimizer.__name__}, ' + f'while we got {type(optimizer)}') else: - warnings.warn('Please use "optimizer" to set optimization method. ' + warnings.warn('\nPlease use "optimizer" to set optimization method. ' '"opt_setting" is deprecated since version 2.1.2. 
', - DeprecationWarning) + UserWarning) assert isinstance(opt_setting, dict) assert 'method' in opt_setting @@ -147,79 +327,120 @@ def find_fps_with_gd_method(self, opt_setting = opt_setting optimizer = opt_method(lr=opt_lr, **opt_setting) - if self.verbose: - print(f"Optimizing with {optimizer} to find fixed points:") - # set up optimization - fixed_points = bm.Variable(bm.asarray(candidates)) - grad_f = bm.grad(lambda: self.f_loss_batch(fixed_points.value).mean(), - grad_vars={'a': fixed_points}, return_value=True) - optimizer.register_vars({'a': fixed_points}) - dyn_vars = optimizer.vars() + {'_a': fixed_points} + num_candidate = self._check_candidates(candidates) + if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)): + raise ValueError('Candidates must be instance of JaxArray or dict of JaxArray.') + leaves, tree = tree_flatten(candidates, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + fixed_points = tree_unflatten(tree, [bm.TrainVar(leaf) for leaf in leaves]) + + def f_loss(): + return self.f_eval_loss(tree_map(lambda a: a.value, + fixed_points, + is_leaf=lambda x: isinstance(x, bm.JaxArray))).mean() + + grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True) + optimizer.register_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points}) + dyn_vars = optimizer.vars() + (fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points}) + dyn_vars = dyn_vars.unique() def train(idx): gradients, loss = grad_f() - optimizer.update(gradients) + optimizer.update(gradients if isinstance(gradients, dict) else {'a': gradients}) return loss - @partial(bm.jit, dyn_vars=dyn_vars, static_argnames=('start_i', 'num_batch')) - def batch_train(start_i, num_batch): + def batch_train(start_i, n_batch): f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True) - return f(bm.arange(start_i, start_i + num_batch)) + return f(bm.arange(start_i, start_i + n_batch)) # Run the optimization + if self.verbose: + print(f"Optimizing with {optimizer} to find fixed points:") opt_losses = [] do_stop = False num_opt_loops = int(num_opt / num_batch) for oidx in range(num_opt_loops): - if do_stop: break + if do_stop: + break batch_idx_start = oidx * num_batch start_time = time.time() - (_, losses) = batch_train(start_i=batch_idx_start, num_batch=num_batch) + (_, train_losses) = batch_train(start_i=batch_idx_start, n_batch=num_batch) batch_time = time.time() - start_time - opt_losses.append(losses) + opt_losses.append(train_losses) if self.verbose: print(f" " f"Batches {batch_idx_start + 1}-{batch_idx_start + num_batch} " - f"in {batch_time:0.2f} sec, Training loss {losses[-1]:0.10f}") + f"in {batch_time:0.2f} sec, Training loss {train_losses[-1]:0.10f}") - if losses[-1] < tolerance: + if train_losses[-1] < tolerance: do_stop = True if self.verbose: print(f' ' - f'Stop optimization as mean training loss {losses[-1]:0.10f} ' + f'Stop optimization as mean training loss {train_losses[-1]:0.10f} ' f'is below tolerance {tolerance:0.10f}.') - self.opt_losses = bm.concatenate(opt_losses) - self._losses = np.asarray(self.f_loss_batch(fixed_points)) - self._fixed_points = np.asarray(fixed_points) - self._selected_ids = np.arange(fixed_points.shape[0]) - def find_fps_with_opt_solver(self, candidates, opt_method=None): + self._opt_losses = bm.concatenate(opt_losses) + self._losses = self.f_eval_loss(tree_map(lambda a: a.value, + fixed_points, + is_leaf=lambda x: isinstance(x, bm.JaxArray))) + self._fixed_points = tree_map(lambda a: a.value, 
fixed_points, + is_leaf=lambda x: isinstance(x, bm.JaxArray)) + self._selected_ids = jnp.arange(num_candidate) + + def find_fps_with_opt_solver( + self, + candidates: Union[Tensor, Dict[str, Tensor]], + opt_solver: str = 'BFGS' + ): """Optimize fixed points with nonlinear optimization solvers. Parameters ---------- - candidates - opt_method: function, callable + candidates: Tensor, dict + The candidate (initial) fixed points. + opt_solver: str + The solver of the optimization. """ - assert bm.ndim(candidates) == 2 and isinstance(candidates, (bm.JaxArray, jax.numpy.ndarray)) - if opt_method is None: - opt_method = lambda f, x0: minimize(f, x0, method='BFGS') + # optimization function + num_candidate = self._check_candidates(candidates) + for var in self.included_vars.values(): + if bm.ndim(var) != 1: + raise ValueError('Cannot use opt solver.') + if self._opt_functions.get(F_OPT_SOLVER, None) is None: + self._opt_functions[F_OPT_SOLVER] = self._get_f_for_opt_solver( + candidates, SUPPORTED_OPT_SOLVERS[opt_solver]) + f_opt = self._opt_functions[F_OPT_SOLVER] + if self.verbose: - print(f"Optimizing to find fixed points:") - f_opt = bm.jit(vmap(lambda x0: opt_method(self.f_loss, x0))) - res = f_opt(bm.as_device_array(candidates)) - valid_ids = jax.numpy.where(res.success)[0] - self._fixed_points = np.asarray(res.x[valid_ids]) - self._losses = np.asarray(res.fun[valid_ids]) - self._selected_ids = np.asarray(valid_ids) + print(f"Optimizing with {opt_solver} to find fixed points:") + + # optimizing + res = f_opt(tree_map(lambda a: a.value, + candidates, + is_leaf=lambda a: isinstance(a, bm.JaxArray))) + + # results + valid_ids = jnp.where(res.success)[0] + fixed_points = res.x[valid_ids] + if isinstance(candidates, dict): + indices = [0] + for v in candidates.values(): + indices.append(v.shape[1]) + indices = np.cumsum(indices) + keys = tuple(candidates.keys()) + self._fixed_points = {key: fixed_points[:, indices[i]: indices[i + 1]] + for i, key in enumerate(keys)} + else: + self._fixed_points = fixed_points + self._losses = res.fun[valid_ids] + self._selected_ids = jnp.asarray(valid_ids) if self.verbose: print(f' ' - f'Found {len(valid_ids)} fixed points from {len(candidates)} initial points.') + f'Found {len(valid_ids)} fixed points from {num_candidate} initial points.') - def filter_loss(self, tolerance=1e-5): + def filter_loss(self, tolerance: float = 1e-5): """Filter fixed points whose speed larger than a given tolerance. Parameters @@ -230,18 +451,21 @@ def filter_loss(self, tolerance=1e-5): if self.verbose: print(f"Excluding fixed points with squared speed above " f"tolerance {tolerance}:") - num_fps = self.fixed_points.shape[0] + if isinstance(self._fixed_points, dict): + num_fps = tuple(self._fixed_points.values())[0].shape[0] + else: + num_fps = self._fixed_points.shape[0] ids = self._losses < tolerance - keep_ids = bm.where(ids)[0] - self._fixed_points = self._fixed_points[ids] + keep_ids = bm.as_device_array(bm.where(ids)[0]) + self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points) self._losses = self._losses[keep_ids] self._selected_ids = self._selected_ids[keep_ids] if self.verbose: print(f" " - f"Kept {self._fixed_points.shape[0]}/{num_fps} " + f"Kept {len(keep_ids)}/{num_fps} " f"fixed points with tolerance under {tolerance}.") - def keep_unique(self, tolerance=2.5e-2): + def keep_unique(self, tolerance: float = 2.5e-2): """Filter unique fixed points by choosing a representative within tolerance. 
Parameters @@ -251,16 +475,19 @@ def keep_unique(self, tolerance=2.5e-2): """ if self.verbose: print("Excluding non-unique fixed points:") - num_fps = self.fixed_points.shape[0] + if isinstance(self._fixed_points, dict): + num_fps = tuple(self._fixed_points.values())[0].shape[0] + else: + num_fps = self._fixed_points.shape[0] fps, keep_ids = utils.keep_unique(self.fixed_points, tolerance=tolerance) - self._fixed_points = fps + self._fixed_points = tree_map(lambda a: jnp.asarray(a), fps) self._losses = self._losses[keep_ids] self._selected_ids = self._selected_ids[keep_ids] if self.verbose: - print(f" Kept {self._fixed_points.shape[0]}/{num_fps} unique fixed points " + print(f" Kept {keep_ids.shape[0]}/{num_fps} unique fixed points " f"with uniqueness tolerance {tolerance}.") - def exclude_outliers(self, tolerance=1e0): + def exclude_outliers(self, tolerance: float = 1e0): """Exclude points whose closest neighbor is further than threshold. Parameters @@ -272,11 +499,15 @@ def exclude_outliers(self, tolerance=1e0): print("Excluding outliers:") if np.isinf(tolerance): return - if self._fixed_points.shape[0] <= 1: + if isinstance(self._fixed_points, dict): + num_fps = tuple(self._fixed_points.values())[0].shape[0] + else: + num_fps = self._fixed_points.shape[0] + if num_fps <= 1: return # Compute pairwise distances between all fixed points. - distances = utils.euclidean_distance(self._fixed_points) + distances = np.asarray(utils.euclidean_distance_jax(self.fixed_points, num_fps)) # Find second smallest element in each column of the pairwise distance matrix. # This corresponds to the closest neighbor for each fixed point. @@ -284,8 +515,7 @@ def exclude_outliers(self, tolerance=1e0): # Return data with outliers removed and indices of kept datapoints. keep_ids = np.where(closest_neighbor < tolerance)[0] - num_fps = self._fixed_points.shape[0] - self._fixed_points = self._fixed_points[keep_ids] + self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points) self._selected_ids = self._selected_ids[keep_ids] self._losses = self._losses[keep_ids] @@ -294,32 +524,39 @@ def exclude_outliers(self, tolerance=1e0): f"Kept {keep_ids.shape[0]}/{num_fps} fixed points " f"with within outlier tolerance {tolerance}.") - def compute_jacobians(self, points): + def compute_jacobians(self, points, stack_vars=True): """Compute the jacobian matrices at the points. Parameters ---------- points: np.ndarray, bm.JaxArray, jax.ndarray The fixed points with the shape of (num_point, num_dim). - - Returns - ------- - jacobians : bm.JaxArray - npoints number of jacobians, np array with shape npoints x dim x dim + stack_vars: bool """ - # if len(self.fixed_points) == 0: return - if bm.ndim(points) == 1: - points = bm.asarray([points, ]) - assert bm.ndim(points) == 2 - return self.f_jacob_batch(bm.asarray(points)) + ndim = np.unique([l.ndim for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.JaxArray))[0]]) + if len(ndim) != 1: + raise ValueError(f'Get multiple dimension of the evaluated points. 
{ndim}') + if ndim[0] == 1: + points = tree_map(lambda a: bm.asarray([a]), points) + elif ndim[0] == 2: + pass + else: + raise ValueError('Only support points of 1D: (num_feature,) or 2D: (num_point, num_feature)') - def decompose_eigenvalues(self, matrices, sort_by='magnitude', do_compute_lefts=True): + if isinstance(points, dict) and stack_vars: + points = bm.hstack(points.values()).value + return self._get_f_jocabian(stack_vars)(points) + + @staticmethod + def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=True): """Compute the eigenvalues of the matrices. Parameters ---------- matrices: np.ndarray, bm.JaxArray, jax.ndarray A 3D array with the shape of (num_matrices, dim, dim). + sort_by: str + The method of sorting. do_compute_lefts: bool Compute the left eigenvectors? Requires a pseudo-inverse call. @@ -335,6 +572,7 @@ def decompose_eigenvalues(self, matrices, sort_by='magnitude', do_compute_lefts= sort_fun = np.real else: raise ValueError("Not implemented yet.") + matrices = np.asarray(matrices) decompositions = [] for mat in matrices: @@ -348,3 +586,127 @@ def decompose_eigenvalues(self, matrices, sort_by='magnitude', do_compute_lefts= 'R': eig_vectors[:, indices], 'L': L}) return decompositions + + def _get_f_for_opt_solver(self, candidates, opt_method): + # update function + if isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)): + f_cell = self.f_cell + + elif isinstance(candidates, dict): + indices = [0] + for v in self.included_vars.values(): + indices.append(v.shape[0]) + indices = np.cumsum(indices) + keys = tuple(self.included_vars.keys()) + + def f_cell(x): + x = {keys[i]: x[indices[i]: indices[i + 1]] for i in range(len(keys))} + r = self.f_cell(x) + return r + + else: + raise ValueError(f'Only supports tensor or a dict of tensors. 
But we got {type(candidates)}') + + # loss function + if self.f_type == constants.DISCRETE: + # overall loss function for fixed points optimization + if isinstance(candidates, dict): + def f_loss(h): + return bm.as_device_array( + self.f_loss({key: h[indices[i]: indices[i + 1]] for i, key in enumerate(self.included_vars.keys())}, + {k: v for k, v in f_cell(h).items() if k in self.included_vars}) + ) + else: + def f_loss(h): + return bm.as_device_array(self.f_loss(h, f_cell(h))) + else: + # overall loss function for fixed points optimization + def f_loss(h): + return self.f_loss(f_cell(h)) + + excluded_data = {k: v.value for k, v in self.excluded_vars.items()} + + @bm.jit + @vmap + def f_opt(x0): + for k, v in self.included_vars.items(): + v.value = x0[k] + for k, v in self.excluded_vars.items(): + v.value = excluded_data[k] + if isinstance(x0, dict): + x0 = bm.concatenate(tuple(x0.values())).value + return opt_method(f_loss, x0) + + return f_opt + + def _generate_ds_cell_function(self, + ds_instance, + included_vars: Dict, + excluded_vars: Dict, + t=0., + f_input=None): + + excluded_data = {k: v.value for k, v in excluded_vars.items()} + + def f_cell(h: Dict): + for k, v in included_vars.items(): + v.value = bm.asarray(h[k], dtype=v.dtype) + for k, v in excluded_vars.items(): + v.value = excluded_data[k] + if f_input is not None: + f_input(t, bm.get_dt()) + ds_instance.update(t, bm.get_dt()) + return {k: v.value for k, v in included_vars.items()} + + return f_cell + + def _get_f_jocabian(self, stack=True): + name = f'f_eval_jacobian_stack={stack}' + if name not in self._opt_functions: + self._opt_functions[name] = self._generate_ds_jocabian(stack) + return self._opt_functions[name] + + def _generate_ds_jocabian(self, stack=True): + if stack and isinstance(self.target, DynamicalSystem): + indices = [0] + for var in self.included_vars.values(): + indices.append(var.shape[0]) + indices = np.cumsum(indices) + + def jacob(x0): + x0 = {k: x0[indices[i]:indices[i + 1]] for i, k in enumerate(self.included_vars.keys())} + r = self.f_cell(x0) + return bm.concatenate(list(r.values())) + else: + jacob = self.f_cell + + return bm.jit(vmap(bm.jacobian(jacob))) + + def _check_candidates(self, candidates): + if isinstance(self.target, DynamicalSystem): + if not isinstance(candidates, dict): + raise ValueError(f'When "f_cell" is instance of {DynamicalSystem.__name__}, ' + f'we should provide "candidates" as a dict, in which the key is ' + f'the variable name with relative path, and the value ' + f'is the candidate fixed point values. ') + for key in candidates: + if key not in self.included_vars: + raise KeyError(f'"{key}" is not defined in required variables ' + f'for fixed point optimization of {self.target}. ' + f'Please do not provide its initial values.') + + for key in self.included_vars.keys(): + if key not in candidates: + raise KeyError(f'"{key}" is defined in required variables ' + f'for fixed point optimization of {self.target}. ' + f'Please provide its initial values.') + + if isinstance(candidates, dict): + num_candidate = np.unique([leaf.shape[0] for leaf in candidates.values()]) + if len(num_candidate) != 1: + raise ValueError('The numbers of candidates for each variable should be the same. 
' + f'But we got {num_candidate}') + num_candidate = num_candidate[0] + else: + num_candidate = candidates.shape[0] + return num_candidate \ No newline at end of file diff --git a/brainpy/analysis/highdim/tests/test_slow_points.py b/brainpy/analysis/highdim/tests/test_slow_points.py new file mode 100644 index 000000000..6a9961607 --- /dev/null +++ b/brainpy/analysis/highdim/tests/test_slow_points.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- + +import brainpy as bp +import unittest +import brainpy.math as bm + + +class HH(bp.dyn.NeuGroup): + def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03, + V_th=20., C=1.0, name=None): + super(HH, self).__init__(size=size, name=name) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + + # variables + self.V = bm.Variable(bm.ones(self.num) * -65.) + self.m = bm.Variable(0.5 * bm.ones(self.num)) + self.h = bm.Variable(0.6 * bm.ones(self.num)) + self.n = bm.Variable(0.32 * bm.ones(self.num)) + self.spike = bm.Variable(bm.zeros(size, dtype=bool)) + self.input = bm.Variable(bm.zeros(size)) + + # integral functions + self.int_h = bp.ode.ExponentialEuler(self.dh) + self.int_n = bp.ode.ExponentialEuler(self.dn) + self.int_m = bp.ode.ExponentialEuler(self.dm) + self.int_V = bp.ode.ExponentialEuler(self.dV) + + def dh(self, h, t, V): + alpha = 0.07 * bm.exp(-(V + 65) / 20.) + beta = 1 / (1 + bm.exp(-(V + 35) / 10)) + dhdt = alpha * (1 - h) - beta * h + return dhdt + + def dn(self, n, t, V): + alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) + beta = 0.125 * bm.exp(-(V + 65) / 80) + dndt = alpha * (1 - n) - beta * n + return dndt + + def dm(self, m, t, V): + alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) + beta = 4.0 * bm.exp(-(V + 65) / 18) + dmdt = alpha * (1 - m) - beta * m + return dmdt + + def dV(self, V, t, m, h, n, Iext): + I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) + I_K = (self.gK * n ** 4.0) * (V - self.EK) + I_leak = self.gL * (V - self.EL) + dVdt = (- I_Na - I_K - I_leak + Iext) / self.C + return dVdt + + def update(self, t, dt): + m = self.int_m(self.m, t, self.V, dt=dt) + h = self.int_h(self.h, t, self.V, dt=dt) + n = self.int_n(self.n, t, self.V, dt=dt) + V = self.int_V(self.V, t, self.m, self.h, self.n, self.input, dt=dt) + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.h.value = h + self.n.value = n + self.m.value = m + self.input[:] = 0. + + +class TestFixedPointsFinding(unittest.TestCase): + def test_opt_solver_for_func1(self): + gamma = 0.641 # Saturation factor for gating variable + tau = 0.06 # Synaptic time constant [sec] + a = 270. + b = 108. + d = 0.154 + + JE = 0.3725 # self-coupling strength [nA] + JI = -0.1137 # cross-coupling strength [nA] + JAext = 0.00117 # Stimulus input strength [nA] + + mu = 20. # Stimulus firing rate [spikes/sec] + coh = 0.5 # Stimulus coherence [%] + Ib1 = 0.3297 + Ib2 = 0.3297 + + def ds1(s1, t, s2, coh=0.5, mu=20.): + I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh) + r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b))) + return - s1 / tau + (1. - s1) * gamma * r1 + + def ds2(s2, t, s1, coh=0.5, mu=20.): + I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh) + r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b))) + return - s2 / tau + (1. 
- s2) * gamma * r2 + + def step(s): + return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])]) + + finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS) + finder.find_fps_with_opt_solver(bm.random.random((100, 2))) + + def test_opt_solver_for_ds1(self): + hh = HH(1) + finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike]) + + with self.assertRaises(ValueError): + finder.find_fps_with_opt_solver(bm.random.random((100, 4))) + + finder.find_fps_with_opt_solver({'V': bm.random.random((100, 1)), + 'm': bm.random.random((100, 1)), + 'h': bm.random.random((100, 1)), + 'n': bm.random.random((100, 1))}) + + def test_gd_method_for_func1(self): + gamma = 0.641 # Saturation factor for gating variable + tau = 0.06 # Synaptic time constant [sec] + a = 270. + b = 108. + d = 0.154 + + JE = 0.3725 # self-coupling strength [nA] + JI = -0.1137 # cross-coupling strength [nA] + JAext = 0.00117 # Stimulus input strength [nA] + + mu = 20. # Stimulus firing rate [spikes/sec] + coh = 0.5 # Stimulus coherence [%] + Ib1 = 0.3297 + Ib2 = 0.3297 + + def ds1(s1, t, s2, coh=0.5, mu=20.): + I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh) + r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b))) + return - s1 / tau + (1. - s1) * gamma * r1 + + def ds2(s2, t, s1, coh=0.5, mu=20.): + I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh) + r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b))) + return - s2 / tau + (1. - s2) * gamma * r2 + + def step(s): + return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])]) + + finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS) + finder.find_fps_with_gd_method(bm.random.random((100, 2)), num_opt=100) + + def test_gd_method_for_func2(self): + hh = HH(1) + finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike]) + + with self.assertRaises(ValueError): + finder.find_fps_with_opt_solver(bm.random.random((100, 4))) + + finder.find_fps_with_gd_method({'V': bm.random.random((100, 1)), + 'm': bm.random.random((100, 1)), + 'h': bm.random.random((100, 1)), + 'n': bm.random.random((100, 1))}, + num_opt=100) + diff --git a/brainpy/analysis/lowdim/lowdim_analyzer.py b/brainpy/analysis/lowdim/lowdim_analyzer.py index 8fe7ec713..07958db36 100644 --- a/brainpy/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/analysis/lowdim/lowdim_analyzer.py @@ -3,13 +3,14 @@ from functools import partial import numpy as np -from jax import vmap from jax import numpy as jnp +from jax import vmap from jax.scipy.optimize import minimize import brainpy.math as bm from brainpy import errors, tools from brainpy.analysis import constants as C, utils +from brainpy.analysis.base import BrainPyAnalyzer from brainpy.base.collector import Collector pyplot = None @@ -21,7 +22,7 @@ ] -class LowDimAnalyzer(object): +class LowDimAnalyzer(BrainPyAnalyzer): r"""Automatic Analyzer for Low-dimensional Dynamical Systems. A dynamical model is characterized by a series of dynamical @@ -68,16 +69,18 @@ class LowDimAnalyzer(object): The optional setting. Maybe needed in the individual analyzer. 
""" - def __init__(self, - model, - target_vars, - fixed_vars=None, - target_pars=None, - pars_update=None, - resolutions=None, - jit_device=None, - lim_scale=1.05, - options=None, ): + def __init__( + self, + model, + target_vars, + fixed_vars=None, + target_pars=None, + pars_update=None, + resolutions=None, + jit_device=None, + lim_scale=1.05, + options=None, + ): # model # ----- self.model = utils.model_transform(model) diff --git a/brainpy/analysis/utils/measurement.py b/brainpy/analysis/utils/measurement.py index 24d7d9dd0..d6e90f019 100644 --- a/brainpy/analysis/utils/measurement.py +++ b/brainpy/analysis/utils/measurement.py @@ -1,13 +1,19 @@ # -*- coding: utf-8 -*- +from typing import Union +import brainpy.math as bm +import jax import jax.numpy as jnp import numpy as np +from jax.tree_util import tree_flatten from brainpy.tools.others import numba_jit +from functools import partial __all__ = [ 'find_indexes_of_limit_cycle_max', 'euclidean_distance', + 'euclidean_distance_jax', ] @@ -31,8 +37,8 @@ def find_indexes_of_limit_cycle_max(arr, tol=0.001): return _f1(arr, grad, tol) -# @tools.numba_jit -def euclidean_distance(points: np.ndarray): +@numba_jit +def euclidean_distance(points: np.ndarray, num_point=None): """Get the distance matrix. Equivalent to: @@ -50,13 +56,63 @@ def euclidean_distance(points: np.ndarray): dist_matrix: jnp.ndarray The distance matrix. """ - num_point = points.shape[0] - indices = np.triu_indices(num_point) - dist_mat = np.zeros((num_point, num_point)) - for idx in range(len(indices[0])): - i = indices[0][idx] - j = indices[1][idx] - dist_mat[i, j] = np.linalg.norm(points[i] - points[j]) + + if isinstance(points, dict): + if num_point is None: + raise ValueError('Please provide num_point') + indices = np.triu_indices(num_point) + dist_mat = np.zeros((num_point, num_point)) + for idx in range(len(indices[0])): + i = indices[0][idx] + j = indices[1][idx] + dist_mat[i, j] = np.sqrt(np.sum([np.sum((value[i] - value[j]) ** 2) for value in points.values()])) + else: + num_point = points.shape[0] + indices = np.triu_indices(num_point) + dist_mat = np.zeros((num_point, num_point)) + for idx in range(len(indices[0])): + i = indices[0][idx] + j = indices[1][idx] + dist_mat[i, j] = np.linalg.norm(points[i] - points[j]) dist_mat = np.maximum(dist_mat, dist_mat.T) return dist_mat + +@jax.jit +@partial(jax.vmap, in_axes=[0, 0, None]) +def _ed(i, j, leaves): + squares = bm.asarray([((leaf[i] - leaf[j]) ** 2).sum() for leaf in leaves]) + return bm.sqrt(bm.sum(squares)) + + +def euclidean_distance_jax(points: Union[jnp.ndarray, bm.ndarray], num_point=None): + """Get the distance matrix. + + Equivalent to: + + >>> from scipy.spatial.distance import squareform, pdist + >>> f = lambda points: squareform(pdist(points, metric="euclidean")) + + Parameters + ---------- + points: jnp.ndarray, bm.JaxArray + The points. + num_point: int + + Returns + ------- + dist_matrix: JaxArray + The distance matrix. 
+ """ + if isinstance(points, dict): + if num_point is None: + raise ValueError('Please provide num_point') + else: + num_point = points.shape[0] + indices = jnp.triu_indices(num_point) + dist_mat = bm.zeros((num_point, num_point)) + leaves, _ = tree_flatten(points) + dist_mat[indices] = _ed(*indices, leaves) + dist_mat = bm.maximum(dist_mat, dist_mat.T) + return dist_mat + diff --git a/brainpy/analysis/utils/others.py b/brainpy/analysis/utils/others.py index e6e6bac30..eb4fe9028 100644 --- a/brainpy/analysis/utils/others.py +++ b/brainpy/analysis/utils/others.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- +from typing import Union, Dict import jax.numpy as jnp from jax import vmap import numpy as np +from jax.tree_util import tree_flatten, tree_map import brainpy.math as bm from .function import f_without_jaxarray_return -from .measurement import euclidean_distance +from .measurement import euclidean_distance, euclidean_distance_jax __all__ = [ 'Segment', @@ -85,12 +87,59 @@ def get_sign2(f, *xyz, args=()): return jnp.sign(f(*(XYZ + args))).reshape(shape) -def keep_unique(candidates, tolerance=2.5e-2): +def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]], + tolerance: float=2.5e-2): """Filter unique fixed points by choosing a representative within tolerance. Parameters ---------- - candidates: np.ndarray + candidates: np.ndarray, dict + The fixed points with the shape of (num_point, num_dim). + tolerance: float + tolerance. + + Returns + ------- + fps_and_ids : tuple + A 2-tuple of (kept fixed points, ids of kept fixed points). + """ + if isinstance(candidates, dict): + element = tuple(candidates.values())[0] + num_fps = element.shape[0] + dtype = element.dtype + else: + num_fps = candidates.shape[0] + dtype = candidates.dtype + keep_ids = np.arange(num_fps) + if tolerance <= 0.0: + return candidates, keep_ids + if num_fps <= 1: + return candidates, keep_ids + candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.JaxArray)) + + # If point A and point B are within identical_tol of each other, and the + # A is first in the list, we keep A. + distances = np.asarray(euclidean_distance_jax(candidates, num_fps)) + example_idxs = np.arange(num_fps) + all_drop_idxs = [] + for fidx in range(num_fps - 1): + distances_f = distances[fidx, fidx + 1:] + drop_idxs = example_idxs[fidx + 1:][distances_f <= tolerance] + all_drop_idxs += list(drop_idxs) + keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs)) + if keep_ids.shape[0] > 0: + unique_fps = tree_map(lambda a: a[keep_ids], candidates) + else: + unique_fps = np.array([], dtype=dtype) + return unique_fps, keep_ids + + +def keep_unique_jax(candidates, tolerance=2.5e-2): + """Filter unique fixed points by choosing a representative within tolerance. + + Parameters + ---------- + candidates: Tesnor The fixed points with the shape of (num_point, num_dim). Returns @@ -107,14 +156,14 @@ def keep_unique(candidates, tolerance=2.5e-2): # If point A and point B are within identical_tol of each other, and the # A is first in the list, we keep A. 
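  # Steps below (same drop logic as the dict-aware `keep_unique` above): build the pairwise
  # distance matrix, then for each point drop every later point lying within `tolerance` of it,
  # and keep the surviving indices. For example, candidates [[0., 0.], [0., 0.01], [1., 1.]]
  # with the default tolerance 2.5e-2 keep rows 0 and 2.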
nfps = candidates.shape[0] - distances = euclidean_distance(candidates) + distances = euclidean_distance_jax(candidates) example_idxs = np.arange(nfps) all_drop_idxs = [] for fidx in range(nfps - 1): distances_f = distances[fidx, fidx + 1:] drop_idxs = example_idxs[fidx + 1:][distances_f <= tolerance] all_drop_idxs += list(drop_idxs) - keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs)) + keep_ids = np.setdiff1d(example_idxs, np.unique(np.asarray(all_drop_idxs))) if keep_ids.shape[0] > 0: unique_fps = candidates[keep_ids, :] else: diff --git a/brainpy/base/collector.py b/brainpy/base/collector.py index eb2ccbcbd..ea425ad43 100644 --- a/brainpy/base/collector.py +++ b/brainpy/base/collector.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- -import jax -import jax.numpy as jnp -from contextlib import contextmanager +from typing import Dict, Sequence, Union math = None @@ -55,12 +53,12 @@ def __add__(self, other): gather.update(other) return gather - def __sub__(self, other): + def __sub__(self, other: Union[Dict, Sequence]): """Remove other item in the collector. Parameters ---------- - other: dict + other: dict, sequence The items to remove. Returns @@ -70,14 +68,26 @@ def __sub__(self, other): """ if not isinstance(other, dict): raise ValueError(f'Only support dict, but we got {type(other)}.') - gather = type(self)() - for key, val in self.items(): - if key in other: - if id(val) != id(other[key]): - raise ValueError(f'Cannot remove {key}, because we got two different values: ' - f'{val} != {other[key]}') - else: - gather[key] = val + gather = type(self)(self) + if isinstance(other, dict): + for key, val in other.items(): + if key in gather: + if id(val) != id(gather[key]): + raise ValueError(f'Cannot remove {key}, because we got two different values: ' + f'{val} != {gather[key]}') + gather.pop(key) + else: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') + elif isinstance(other, (list, tuple)): + for key in other: + if key in gather: + gather.pop(key) + else: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') + else: + raise KeyError(f'Unknown type of "other". 
Only support dict/tuple/list, but we got {type(other)}') return gather def subset(self, var_type): diff --git a/brainpy/dyn/runners.py b/brainpy/dyn/runners.py index 7f786aaa0..e6d7a21ed 100644 --- a/brainpy/dyn/runners.py +++ b/brainpy/dyn/runners.py @@ -76,12 +76,14 @@ def check_and_format_inputs(host, inputs): # checking 1: absolute access # Check whether the input target node is accessible, # and check whether the target node has the attribute - nodes = host.nodes(method='absolute', level=-1, include_self=True) + nodes = None for one_input in inputs: key = one_input[0] if isinstance(key, bm.Variable): real_target = key elif isinstance(key, str): + if nodes is None: + nodes = host.nodes(method='absolute', level=-1, include_self=True) splits = key.split('.') target = '.'.join(splits[:-1]) key = splits[-1] @@ -159,6 +161,69 @@ def check_and_format_inputs(host, inputs): return formatted_inputs +def build_inputs(inputs): + fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + + _has_iter_array = False + for variable, value, type_, op in inputs: + # variable + if not isinstance(variable, bm.Variable): + raise RunningError(f'{variable}\n is not a dynamically changed Variable, ' + f'its value will not change, we think there is no need to ' + f'give its input.') + + # input data + if type_ == 'iter': + if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)): + array_inputs[op].append([variable, bm.asarray(value)]) + _has_iter_array = True + else: + next_inputs[op].append([variable, iter(value)]) + elif type_ == 'func': + func_inputs[op].append([variable, value]) + else: + fix_inputs[op].append([variable, value]) + + index = None + if _has_iter_array: + index = bm.Variable(bm.zeros(1, dtype=int)) + + def _f_ops(ops, var, data): + if ops == '=': + var[:] = data + elif ops == '+': + var += data + elif ops == '-': + var -= data + elif ops == '*': + var *= data + elif ops == '/': + var /= data + else: + raise ValueError + + def func(_t, _dt): + for ops, values in fix_inputs.items(): + for var, data in values: + _f_ops(ops, var, data) + for ops, values in array_inputs.items(): + for var, data in values: + _f_ops(ops, var, data[index[0]]) + for ops, values in func_inputs.items(): + for var, data in values: + _f_ops(ops, var, data(_t, _dt)) + for ops, values in next_inputs.items(): + for var, data in values: + _f_ops(ops, var, next(data)) + if _has_iter_array: + index[0] += 1 + + return func, index + + class DSRunner(Runner): """The runner for dynamical systems. @@ -205,13 +270,9 @@ def __init__( # Build the monitor function self._monitor_step = self.build_monitors(*self.format_monitors()) - # whether it has iterable input data - self._has_iter_array = False # default do not have iterable input array - self._i = bm.Variable(bm.asarray([0])) - # Build input function inputs = check_and_format_inputs(host=target, inputs=inputs) - self._input_step = self.build_inputs(inputs) + self._input_step, self._i = build_inputs(inputs) # start simulation time self._start_t = None @@ -219,71 +280,11 @@ def __init__( # JAX does not support iterator in fori_loop, scan, etc. # https://github.com/google/jax/issues/3567 # We use Variable i to index the current input data. 
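    # For example (a sketch using the array-input API shown by the examples in this patch):
    #   I, length = bp.inputs.section_input(values=[0., 5., 0.], durations=[100., 500., 100.], return_length=True)
    #   runner = bp.dyn.DSRunner(model, inputs=['input', I, 'iter'], monitors=['V'])
    # Array-backed 'iter' inputs are indexed with this Variable at every step, while
    # generator-style iterables are advanced with Python's next().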
- if self._has_iter_array: # must behind of "self.build_input()" + if self._i is not None: # must behind of "self.build_input()" self.dyn_vars.update({'_i': self._i}) - else: - self._i = None # run function self._predict_func = dict() - # self._run_func = self.build_run_function() - - def build_inputs(self, inputs): - fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - - for variable, value, type_, op in inputs: - # variable - if not isinstance(variable, bm.Variable): - raise RunningError(f'{variable}\n is not a dynamically changed Variable, ' - f'its value will not change, we think there is no need to ' - f'give its input.') - - # input data - if type_ == 'iter': - if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)): - array_inputs[op].append([variable, bm.asarray(value)]) - self._has_iter_array = True - else: - next_inputs[op].append([variable, iter(value)]) - elif type_ == 'func': - func_inputs[op].append([variable, value]) - else: - fix_inputs[op].append([variable, value]) - - def _f_ops(ops, var, data): - if ops == '=': - var[:] = data - elif ops == '+': - var += data - elif ops == '-': - var -= data - elif ops == '*': - var *= data - elif ops == '/': - var /= data - else: - raise ValueError - - def func(_t, _dt): - for ops, values in fix_inputs.items(): - for var, data in values: - _f_ops(ops, var, data) - for ops, values in array_inputs.items(): - for var, data in values: - _f_ops(ops, var, data[self._i[0]]) - for ops, values in func_inputs.items(): - for var, data in values: - _f_ops(ops, var, data(_t, _dt)) - for ops, values in next_inputs.items(): - for var, data in values: - _f_ops(ops, var, next(data)) - if self._has_iter_array: - self._i += 1 - - return func def build_monitors(self, return_without_idx, return_with_idx, flatten=False): if flatten: diff --git a/brainpy/math/controls.py b/brainpy/math/controls.py index 6046878c1..0f6c2fe02 100644 --- a/brainpy/math/controls.py +++ b/brainpy/math/controls.py @@ -712,14 +712,15 @@ def _body_fun(op): def _cond_fun(op): dyn_vals, static_vals = op for v, d in zip(dyn_vars, dyn_vals): v._value = d - return as_device_array(cond_fun(static_vals)) + r = cond_fun(static_vals) + return r if isinstance(r, JaxArray) else r dyn_init = [v.value for v in dyn_vars] try: turn_on_global_jit() - dyn_values, _ = lax.while_loop(cond_fun=_cond_fun, - body_fun=_body_fun, - init_val=(dyn_init, operands)) + dyn_values, out = lax.while_loop(cond_fun=_cond_fun, + body_fun=_body_fun, + init_val=(dyn_init, operands)) turn_off_global_jit() except UnexpectedTracerError as e: turn_off_global_jit() @@ -730,3 +731,4 @@ def _cond_fun(op): for v, d in zip(dyn_vars, dyn_init): v._value = d raise e for v, d in zip(dyn_vars, dyn_values): v._value = d + return out diff --git a/examples/analysis/2d_decision_making_model.py b/examples/analysis/2d_decision_making_model.py index d7090037b..dd8651b23 100644 --- a/examples/analysis/2d_decision_making_model.py +++ b/examples/analysis/2d_decision_making_model.py @@ -69,16 +69,17 @@ def fixed_point_finder(): def step(s): ds1 = int_s1.f(s[0], 0., s[1]) ds2 = int_s2.f(s[1], 0., s[0]) - return bm.asarray([ds1.value, ds2.value]) - - finder = bp.analysis.SlowPointFinder(f_cell=step) - finder.find_fps_with_gd_method( - candidates=bm.random.random((1000, 2)), - tolerance=1e-5, num_batch=200, - 
optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.01, 1, 0.9999)), - ) - # finder.find_fps_with_opt_solver(bm.random.random((1000, 2))) - finder.filter_loss(1e-5) + return bm.asarray([ds1, ds2]) + + finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS) + # finder.find_fps_with_gd_method( + # candidates=bm.random.random((1000, 2)), + # tolerance=1e-8, + # num_batch=200, + # optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.01, 1, 0.9999)), + # ) + finder.find_fps_with_opt_solver(bm.random.random((1000, 2))) + finder.filter_loss(1e-14) finder.keep_unique() print('fixed_points: ', finder.fixed_points) diff --git a/examples/analysis/4d_HH_model.py b/examples/analysis/4d_HH_model.py index 463dd5b06..be0280e34 100644 --- a/examples/analysis/4d_HH_model.py +++ b/examples/analysis/4d_HH_model.py @@ -89,7 +89,8 @@ def update(self, t, dt): bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) # analysis -finder = bp.analysis.SlowPointFinder(lambda h: model.step(h, I)) +finder = bp.analysis.SlowPointFinder(lambda h: model.step(h, I), + bp.analysis.CONTINUOUS) V = bm.random.normal(0., 5., (1000, model.num)) - 50. mhn = bm.random.random((1000, model.num * 3)) finder.find_fps_with_opt_solver(candidates=bm.hstack([V, mhn])) diff --git a/examples/analysis/highdim_CANN.py b/examples/analysis/highdim_CANN.py index 86602fb23..0899b3cd9 100644 --- a/examples/analysis/highdim_CANN.py +++ b/examples/analysis/highdim_CANN.py @@ -86,7 +86,8 @@ def find_fixed_points(): # candidates = bm.random.uniform(0, 20., (1000, cann.num)) - finder = bp.analysis.SlowPointFinder(f_cell=cann.cell) + finder = bp.analysis.SlowPointFinder(f_cell=cann.cell, + f_type=bp.analysis.CONTINUOUS) # finder.find_fps_with_gd_method( # candidates=candidates, # tolerance=1e-6, @@ -139,7 +140,8 @@ def verify_fixed_point_stability(num=3): fixed_points = np.load(fps_output_fn) cann = CANN1D(num=512, k=k, a=a, A=A) - finder = bp.analysis.SlowPointFinder(f_cell=cann.cell) + finder = bp.analysis.SlowPointFinder(f_cell=cann.cell, + f_type=bp.analysis.CONTINUOUS) J = finder.compute_jacobians(fixed_points[:num]) for i in range(num): diff --git a/examples/analysis/highdim_RNN_Analysis.py b/examples/analysis/highdim_RNN_Analysis.py index 52b19d43b..80fa55d65 100644 --- a/examples/analysis/highdim_RNN_Analysis.py +++ b/examples/analysis/highdim_RNN_Analysis.py @@ -260,7 +260,7 @@ def train(xs, ys): fp_candidates.shape # %% -finder = bp.analysis.SlowPointFinder(f_cell=f_cell, f_type='discrete') +finder = bp.analysis.SlowPointFinder(f_cell=f_cell, f_type=bp.analysis.DISCRETE) finder.find_fps_with_gd_method( candidates=fp_candidates, tolerance=1e-5, num_batch=200, diff --git a/examples/analysis/highdim_gj_coupled_fhn.py b/examples/analysis/highdim_gj_coupled_fhn.py index 9c0d31334..d8ba0c725 100644 --- a/examples/analysis/highdim_gj_coupled_fhn.py +++ b/examples/analysis/highdim_gj_coupled_fhn.py @@ -6,6 +6,7 @@ import brainpy as bp import brainpy.math as bm + bp.math.enable_x64() @@ -57,21 +58,18 @@ def d4_system(): plot_ids=list(range(model.num)), show=True) # analysis - def step(vw): - v, w = bm.split(vw, 2) - dv = model.dV(v, 0., w, Iext) - dw = model.dw(w, 0., v) - return bm.concatenate([dv, dw]) - - finder = bp.analysis.SlowPointFinder(f_cell=step) + finder = bp.analysis.SlowPointFinder(f_cell=model, + included_vars={'V': model.V, 'w': model.w}, + inputs=['Iext', Iext]) # finder.find_fps_with_gd_method( - # candidates=bm.random.normal(0., 2., (1000, model.num * 2)), - # tolerance=1e-5, + # 
candidates={'V': bm.random.normal(0., 2., (1000, model.num)), + # 'w': bm.random.normal(0., 2., (1000, model.num))}, + # tolerance=1e-7, # num_batch=200, - # opt_setting=dict(method=bm.optimizers.Adam, lr=bm.optimizers.ExponentialDecay(0.05, 1, 0.9999)), + # optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)) # ) - - finder.find_fps_with_opt_solver(candidates=bm.random.normal(0., 2., (1000, model.num * 2))) + finder.find_fps_with_opt_solver(candidates={'V': bm.random.normal(0., 2., (1000, model.num)), + 'w': bm.random.normal(0., 2., (1000, model.num))}) finder.filter_loss(1e-7) finder.keep_unique() @@ -79,11 +77,11 @@ def step(vw): print('losses:', finder.losses) if len(finder.fixed_points): jac = finder.compute_jacobians(finder.fixed_points) - for i in range(len(finder.fixed_points)): + for i in range(len(finder.selected_ids)): eigval, eigvec = np.linalg.eig(np.asarray(jac[i])) plt.figure() plt.scatter(np.real(eigval), np.imag(eigval)) - plt.plot([0, 0], [-1, 1], '--') + plt.plot([1, 1], [-1, 1], '--') plt.xlabel('Real') plt.ylabel('Imaginary') plt.title(f'FP {i}') @@ -103,17 +101,13 @@ def d8_system(): plot_ids=list(range(model.num)), show=True) - # analysis - def step(vw): - v, w = bm.split(vw, 2) - dv = model.dV(v, 0., w, Iext) - dw = model.dw(w, 0., v) - return bm.concatenate([dv, dw]) - - finder = bp.analysis.SlowPointFinder(f_cell=step) + finder = bp.analysis.SlowPointFinder(f_cell=model, + included_vars={'V': model.V, 'w': model.w}, + inputs=['Iext', Iext]) finder.find_fps_with_gd_method( - candidates=bm.random.normal(0., 2., (1000, model.num * 2)), - tolerance=1e-5, + candidates={'V': bm.random.normal(0., 2., (1000, model.num)), + 'w': bm.random.normal(0., 2., (1000, model.num))}, + tolerance=1e-6, num_batch=200, optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)), ) @@ -124,11 +118,11 @@ def step(vw): print('losses:', finder.losses) if len(finder.fixed_points): jac = finder.compute_jacobians(finder.fixed_points) - for i in range(len(finder.fixed_points)): + for i in range(finder.num_fps): eigval, eigvec = np.linalg.eig(np.asarray(jac[i])) plt.figure() plt.scatter(np.real(eigval), np.imag(eigval)) - plt.plot([0, 0], [-1, 1], '--') + plt.plot([1, 1], [-1, 1], '--') plt.xlabel('Real') plt.ylabel('Imaginary') plt.title(f'FP {i}') @@ -137,5 +131,4 @@ def step(vw): if __name__ == '__main__': d4_system() - # d8_system() - # analysis() + d8_system() From 58496ac8030418b650372e1a06e0d42e5dc67cfb Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 4 Jun 2022 15:19:36 +0800 Subject: [PATCH 6/9] [example] simulation model examples with hierarchy structures --- ...2017_unified_thalamus_oscillation_model.py | 343 ++++++++++++++++++ examples/simulation/hh_model.py | 27 ++ 2 files changed, 370 insertions(+) create mode 100644 examples/simulation/Li_2017_unified_thalamus_oscillation_model.py create mode 100644 examples/simulation/hh_model.py diff --git a/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py b/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py new file mode 100644 index 000000000..0d93457a7 --- /dev/null +++ b/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py @@ -0,0 +1,343 @@ +# -*- coding: utf-8 -*- + +""" +Implementation of the model: + +- Li, Guoshi, Craig S. Henriquez, and Flavio Fröhlich. "Unified + thalamic model generates multiple distinct oscillations with + state-dependent entrainment by stimulation." PLoS computational + biology 13.10 (2017): e1005797. 
+""" + + +from typing import Dict +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm +from brainpy.dyn import channels, synapses, synouts, synplast + + +class HTC(bp.dyn.CondNeuGroup): + def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-65.), ): + gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125) + IL = channels.IL(size, g_max=gL, E=-70) + IKL = channels.IKL(size, g_max=gKL) + INa = channels.INa_Ba2002(size, V_sh=-30) + IDR = channels.IKDR_Ba2002(size, V_sh=-30., phi=0.25) + Ih = channels.Ih_HM1992(size, g_max=0.01, E=-43) + + ICaL = channels.ICaL_IS2008(size, g_max=0.5) + IAHP = channels.IAHP_De1994(size, g_max=0.3, E=-90.) + ICaN = channels.ICaN_IS2008(size, g_max=0.5) + ICaT = channels.ICaT_HM1992(size, g_max=2.1) + ICaHT = channels.ICaHT_HM1992(size, g_max=3.0) + Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, ICaL=ICaL, + IAHP=IAHP, ICaN=ICaN, ICaT=ICaT, ICaHT=ICaHT) + + super(HTC, self).__init__(size, A=2.9e-4, V_initializer=V_initializer, V_th=20., + IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca) + + +class RTC(bp.dyn.CondNeuGroup): + def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-65.), ): + gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125) + IL = channels.IL(size, g_max=gL, E=-70) + IKL = channels.IKL(size, g_max=gKL) + INa = channels.INa_Ba2002(size, V_sh=-40) + IDR = channels.IKDR_Ba2002(size, V_sh=-40, phi=0.25) + Ih = channels.Ih_HM1992(size, g_max=0.01, E=-43) + + ICaL = channels.ICaL_IS2008(size, g_max=0.3) + IAHP = channels.IAHP_De1994(size, g_max=0.1, E=-90.) + ICaN = channels.ICaN_IS2008(size, g_max=0.6) + ICaT = channels.ICaT_HM1992(size, g_max=2.1) + ICaHT = channels.ICaHT_HM1992(size, g_max=0.6) + Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, ICaL=ICaL, + IAHP=IAHP, ICaN=ICaN, ICaT=ICaT, ICaHT=ICaHT) + + super(RTC, self).__init__(size, A=2.9e-4, V_initializer=V_initializer, V_th=20., + IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca) + + +class IN(bp.dyn.CondNeuGroup): + def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ): + gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125) + IL = channels.IL(size, g_max=gL, E=-60) + IKL = channels.IKL(size, g_max=gKL) + INa = channels.INa_Ba2002(size, V_sh=-30) + IDR = channels.IKDR_Ba2002(size, V_sh=-30, phi=0.25) + Ih = channels.Ih_HM1992(size, g_max=0.05, E=-43) + + IAHP = channels.IAHP_De1994(size, g_max=0.2, E=-90.) + ICaN = channels.ICaN_IS2008(size, g_max=0.1) + ICaHT = channels.ICaHT_HM1992(size, g_max=2.5) + Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, + IAHP=IAHP, ICaN=ICaN, ICaHT=ICaHT) + + super(IN, self).__init__(size, A=1.7e-4, V_initializer=V_initializer, V_th=20., + IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca) + + +class TRN(bp.dyn.CondNeuGroup): + def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ): + gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125) + IL = channels.IL(size, g_max=gL, E=-60) + IKL = channels.IKL(size, g_max=gKL) + INa = channels.INa_Ba2002(size, V_sh=-40) + IDR = channels.IKDR_Ba2002(size, V_sh=-40) + + IAHP = channels.IAHP_De1994(size, g_max=0.2, E=-90.) 
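+    # Ca2+-dependent and low-threshold (T-type) Ca2+ currents, coupled through the detailed Ca2+ dynamics below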
+ ICaN = channels.ICaN_IS2008(size, g_max=0.2) + ICaT = channels.ICaT_HP1992(size, g_max=1.3) + Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=100., d=0.5, + IAHP=IAHP, ICaN=ICaN, ICaT=ICaT) + + super(TRN, self).__init__(size, A=1.43e-4, + V_initializer=V_initializer, V_th=20., + IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ca=Ca) + + +class MgBlock(bp.dyn.SynapseOutput): + def __init__(self, E=0.): + super(MgBlock, self).__init__() + self.E = E + + def filter(self, g): + V = self.master.post.V.value + return g * (self.E - V) / (1 + bm.exp(-(V + 25) / 12.5)) + + +class Thalamus(bp.dyn.Network): + def __init__( + self, g_input: Dict[str, float], g_KL: Dict[str, float], + HTC_V_init=bp.init.OneInit(-65.), RTC_V_init=bp.init.OneInit(-65.), + IN_V_init=bp.init.OneInit(-70.), RE_V_init=bp.init.OneInit(-70.), + ): + super(Thalamus, self).__init__() + + # populations + self.HTC = HTC(size=(7, 7), gKL=g_KL['TC'], V_initializer=HTC_V_init) + self.RTC = RTC(size=(12, 12), gKL=g_KL['TC'], V_initializer=RTC_V_init) + self.RE = TRN(size=(10, 10), gKL=g_KL['RE'], V_initializer=IN_V_init) + self.IN = IN(size=(8, 8), gKL=g_KL['IN'], V_initializer=RE_V_init) + + # noises + self.poisson_HTC = bp.dyn.PoissonGroup(self.HTC.size, freqs=100) + self.poisson_RTC = bp.dyn.PoissonGroup(self.RTC.size, freqs=100) + self.poisson_IN = bp.dyn.PoissonGroup(self.IN.size, freqs=100) + self.poisson_RE = bp.dyn.PoissonGroup(self.RE.size, freqs=100) + self.noise2HTC = synapses.Exponential(self.poisson_HTC, self.HTC, bp.conn.One2One(), + output=synouts.COBA(E=0.), tau=5., + g_max=g_input['TC']) + self.noise2RTC = synapses.Exponential(self.poisson_RTC, self.RTC, bp.conn.One2One(), + output=synouts.COBA(E=0.), tau=5., + g_max=g_input['TC']) + self.noise2IN = synapses.Exponential(self.poisson_IN, self.IN, bp.conn.One2One(), + output=synouts.COBA(E=0.), tau=5., + g_max=g_input['IN']) + self.noise2RE = synapses.Exponential(self.poisson_RE, self.RE, bp.conn.One2One(), + output=synouts.COBA(E=0.), tau=5., + g_max=g_input['RE']) + + # HTC cells were connected with gap junctions + self.gj_HTC = synapses.GapJunction(self.HTC, self.HTC, + bp.conn.ProbDist(dist=2., prob=0.3, ), + conn_type='sparse', + g_max=1e-2) + + # HTC provides feedforward excitation to INs + self.HTC2IN_ampa = synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + alpha=0.94, + beta=0.18, + g_max=6e-3) + self.HTC2IN_nmda = synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + output=MgBlock(), + alpha=1., + beta=0.0067, + g_max=3e-3) + + # INs delivered feedforward inhibition to RTC cells + self.IN2RTC = synapses.GABAa(self.IN, self.RTC, bp.conn.FixedProb(0.3), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + output=synouts.COBA(E=-80), + alpha=10.5, + beta=0.166, + g_max=3e-3) + + # 20% RTC cells electrically connected with HTC cells + self.gj_RTC2HTC = synapses.GapJunction(self.RTC, self.HTC, + bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2), + conn_type='sparse', + g_max=1 / 300) + + # Both HTC and RTC cells sent glutamatergic synapses to RE neurons, while + # receiving GABAergic feedback inhibition from the RE population + self.HTC2RE_ampa = synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + alpha=0.94, + beta=0.18, + g_max=4e-3) + self.RTC2RE_ampa = synapses.AMPA(self.RTC, self.RE, 
bp.conn.FixedProb(0.2), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + alpha=0.94, + beta=0.18, + g_max=4e-3) + self.HTC2RE_nmda = synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + output=MgBlock(), + alpha=1., + beta=0.0067, + g_max=2e-3) + self.RTC2RE_nmda = synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + output=MgBlock(), + alpha=1., + beta=0.0067, + g_max=2e-3) + self.RE2HTC = synapses.GABAa(self.RE, self.HTC, bp.conn.FixedProb(0.2), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + output=synouts.COBA(E=-80), + alpha=10.5, + beta=0.166, + g_max=3e-3) + self.RE2RTC = synapses.GABAa(self.RE, self.RTC, bp.conn.FixedProb(0.2), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + output=synouts.COBA(E=-80), + alpha=10.5, + beta=0.166, + g_max=3e-3) + + # RE neurons were connected with both gap junctions and GABAergic synapses + self.gj_RE = synapses.GapJunction(self.RE, self.RE, + bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2), + conn_type='sparse', + g_max=1 / 300) + self.RE2RE = synapses.GABAa(self.RE, self.RE, bp.conn.FixedProb(0.2), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + output=synouts.COBA(E=-70), + alpha=10.5, beta=0.166, + g_max=1e-3) + + # 10% RE neurons project GABAergic synapses to local interneurons + # probability (0.05) was used for the RE->IN synapses according to experimental data + self.RE2IN = synapses.GABAa(self.RE, self.IN, bp.conn.FixedProb(0.05, pre_ratio=0.1), + delay_step=int(2 / bm.get_dt()), + plasticity=synplast.STD(tau=700, U=0.07), + output=synouts.COBA(E=-80), + alpha=10.5, beta=0.166, + g_max=1e-3, ) + + +states = { + 'delta': dict(g_input={'IN': 1e-4, 'RE': 1e-4, 'TC': 1e-4}, + g_KL={'TC': 0.035, 'RE': 0.03, 'IN': 0.01}), + 'spindle': dict(g_input={'IN': 3e-4, 'RE': 3e-4, 'TC': 3e-4}, + g_KL={'TC': 0.01, 'RE': 0.02, 'IN': 0.015}), + 'alpha': dict(g_input={'IN': 1.5e-3, 'RE': 1.5e-3, 'TC': 1.5e-3}, + g_KL={'TC': 0., 'RE': 0.01, 'IN': 0.02}), + 'gamma': dict(g_input={'IN': 1.5e-3, 'RE': 1.5e-3, 'TC': 1.7e-2}, + g_KL={'TC': 0., 'RE': 0.01, 'IN': 0.02}), +} + + +def rhythm_const_input(amp, freq, length, duration, t_start=0., t_end=None, dt=None): + if t_end is None: t_end = duration + if length > duration: + raise ValueError(f'Expected length <= duration, while we got {length} > {duration}') + sec_length = 1e3 / freq + values, durations = [0.], [t_start] + for t in np.arange(t_start, t_end, sec_length): + values.append(amp) + if t + length <= t_end: + durations.append(length) + values.append(0.) + if t + sec_length <= t_end: + durations.append(sec_length - length) + else: + durations.append(t_end - t - length) + else: + durations.append(t_end - t) + values.append(0.) 
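+  # final silent segment: zero input from t_end to the total duration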
+ durations.append(duration - t_end) + return bp.inputs.section_input(values=values, durations=durations, dt=dt, ) + + +def try_trn_neuron(): + trn = TRN(1) + I, length = bp.inputs.section_input(values=[0, -0.05, 0], + durations=[100, 100, 500], + return_length=True, + dt=0.01) + runner = bp.dyn.DSRunner(trn, + monitors=['V'], + inputs=['input', I, 'iter'], + dt=0.01) + runner.run(length) + + bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) + + +def try_network(): + duration = 3e3 + net = Thalamus( + IN_V_init=bp.init.OneInit(-70.), + RE_V_init=bp.init.OneInit(-70.), + HTC_V_init=bp.init.OneInit(-80.), + RTC_V_init=bp.init.OneInit(-80.), + **states['delta'], + ) + net.reset() + + # currents = rhythm_const_input(2e-4, freq=4., length=10., duration=duration, + # t_end=2e3, t_start=1e3) + # plt.plot(currents) + # plt.show() + + runner = bp.dyn.DSRunner( + net, + monitors=['HTC.spike', 'RTC.spike', 'RE.spike', 'IN.spike', + 'HTC.V', 'RTC.V', 'RE.V', 'IN.V', ], + # inputs=[('HTC.input', currents, 'iter'), + # ('RTC.input', currents, 'iter'), + # ('IN.input', currents, 'iter')], + ) + runner.run(duration) + + fig, gs = bp.visualize.get_figure(4, 2, 2, 5) + fig.add_subplot(gs[0, 0]) + bp.visualize.line_plot(runner.mon.ts, runner.mon.get('HTC.V'), ylabel='HTC', xlim=(0, duration)) + fig.add_subplot(gs[1, 0]) + bp.visualize.line_plot(runner.mon.ts, runner.mon.get('RTC.V'), ylabel='RTC', xlim=(0, duration)) + fig.add_subplot(gs[2, 0]) + bp.visualize.line_plot(runner.mon.ts, runner.mon.get('IN.V'), ylabel='IN', xlim=(0, duration)) + fig.add_subplot(gs[3, 0]) + bp.visualize.line_plot(runner.mon.ts, runner.mon.get('RE.V'), ylabel='RE', xlim=(0, duration)) + + fig.add_subplot(gs[0, 1]) + bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('HTC.spike'), xlim=(0, duration)) + fig.add_subplot(gs[1, 1]) + bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('RTC.spike'), xlim=(0, duration)) + fig.add_subplot(gs[2, 1]) + bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('IN.spike'), xlim=(0, duration)) + fig.add_subplot(gs[3, 1]) + bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('RE.spike'), xlim=(0, duration)) + + plt.show() + + +if __name__ == '__main__': + try_network() diff --git a/examples/simulation/hh_model.py b/examples/simulation/hh_model.py new file mode 100644 index 000000000..e43b75079 --- /dev/null +++ b/examples/simulation/hh_model.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- + +import brainpy as bp +from brainpy import dyn +from brainpy.dyn import channels + + +class HH(dyn.CondNeuGroup): + def __init__(self, size): + super(HH, self).__init__(size) + + self.INa = channels.INa_HH(size, ) + self.IK = channels.IK_HH(size, ) + self.IL = channels.IL(size, E=-54.387, g_max=0.03) + + +hh = HH(1) +I, length = bp.inputs.section_input(values=[0, 5, 0], + durations=[100, 500, 100], + return_length=True) +runner = bp.dyn.DSRunner(hh, + monitors=['V', 'INa.p', 'INa.q', 'IK.p'], + inputs=['input', I, 'iter']) +runner.run(length) + +bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) + From 240179de1322e550d7106c0ea740b750fa2ee6b9 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 5 Jun 2022 21:27:29 +0800 Subject: [PATCH 7/9] update examples --- brainpy/dyn/networks/cann.py | 6 +-- brainpy/losses/comparison.py | 3 +- examples/analysis/4d_HH_model.py | 17 ++++---- examples/analysis/highdim_CANN.py | 9 ++--- examples/analysis/highdim_gj_coupled_fhn.py | 2 +- .../Bazhenov_1998_thalamus_aug_response.py | 39 +++++++++++++++++++ .../Sanda_2021_hippo-tha-cortex-model.py | 
26 +++++++++++++ 7 files changed, 81 insertions(+), 21 deletions(-) create mode 100644 examples/simulation/Bazhenov_1998_thalamus_aug_response.py create mode 100644 examples/simulation/Sanda_2021_hippo-tha-cortex-model.py diff --git a/brainpy/dyn/networks/cann.py b/brainpy/dyn/networks/cann.py index 35380e8a8..2b0cb2480 100644 --- a/brainpy/dyn/networks/cann.py +++ b/brainpy/dyn/networks/cann.py @@ -6,8 +6,6 @@ __all__ = [ 'WuCANN1D', 'WuCANN2D', - 'CANN_SFA_1D', - 'CANN_SFA_2D', ] @@ -19,10 +17,10 @@ class WuCANN2D(NeuGroup): pass -class CANN_SFA_1D(NeuGroup): +class ACANN_1D(NeuGroup): pass -class CANN_SFA_2D(NeuGroup): +class ACANN_2D(NeuGroup): pass diff --git a/brainpy/losses/comparison.py b/brainpy/losses/comparison.py index 730849fe6..12d6e4254 100644 --- a/brainpy/losses/comparison.py +++ b/brainpy/losses/comparison.py @@ -9,10 +9,9 @@ import jax.numpy as jnp import jax.scipy -from jax.tree_util import tree_flatten, tree_map +from jax.tree_util import tree_map import brainpy.math as bm -from brainpy import errors from .utils import _return, _multi_return, _is_leaf __all__ = [ diff --git a/examples/analysis/4d_HH_model.py b/examples/analysis/4d_HH_model.py index be0280e34..dcceee557 100644 --- a/examples/analysis/4d_HH_model.py +++ b/examples/analysis/4d_HH_model.py @@ -82,15 +82,14 @@ def update(self, t, dt): self.input[:] = 0. -model = HH(1) I = 5. -run = bp.dyn.DSRunner(model, inputs=('input', I), monitors=['V']) -run(100) -bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) +model = HH(1) +runner = bp.dyn.DSRunner(model, inputs=('input', I), monitors=['V']) +runner.run(100) +bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) # analysis -finder = bp.analysis.SlowPointFinder(lambda h: model.step(h, I), - bp.analysis.CONTINUOUS) +finder = bp.analysis.SlowPointFinder(model, inputs=(model.input, I), excluded_vars=[model.input]) V = bm.random.normal(0., 5., (1000, model.num)) - 50. 
mhn = bm.random.random((1000, model.num * 3)) finder.find_fps_with_opt_solver(candidates=bm.hstack([V, mhn])) @@ -116,6 +115,6 @@ def update(self, t, dt): model.m[:] = fp[1] model.h[:] = fp[2] model.n[:] = fp[3] - run = bp.dyn.DSRunner(model, inputs=('input', I), monitors=['V']) - run(100) - bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', title=f'FP {i}', show=True) + runner = bp.dyn.DSRunner(model, inputs=('input', I), monitors=['V']) + runner.run(100) + bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', title=f'FP {i}', show=True) diff --git a/examples/analysis/highdim_CANN.py b/examples/analysis/highdim_CANN.py index 0899b3cd9..229e26694 100644 --- a/examples/analysis/highdim_CANN.py +++ b/examples/analysis/highdim_CANN.py @@ -86,15 +86,14 @@ def find_fixed_points(): # candidates = bm.random.uniform(0, 20., (1000, cann.num)) - finder = bp.analysis.SlowPointFinder(f_cell=cann.cell, - f_type=bp.analysis.CONTINUOUS) + finder = bp.analysis.SlowPointFinder(f_cell=cann, included_vars={'u': cann.u}) # finder.find_fps_with_gd_method( # candidates=candidates, # tolerance=1e-6, # optimizer = bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.1, , 0.999)), # num_batch=200 # ) - finder.find_fps_with_opt_solver(candidates) + finder.find_fps_with_opt_solver({'u': candidates}) finder.filter_loss(1e-5) finder.keep_unique() # finder.exclude_outliers() @@ -127,10 +126,10 @@ def verify_fixed_points_through_simulation(num=3): for i in range(num): cann.u[:] = fixed_points[i] - runner = bp.StructRunner(cann, + runner = bp.dyn.DSRunner(cann, monitors=['u'], dyn_vars=cann.vars()) - runner(100.) + runner.run(100.) plt.plot(runner.mon.ts, runner.mon.u.max(axis=1)) plt.ylim(0, runner.mon.u.max() + 1) plt.show() diff --git a/examples/analysis/highdim_gj_coupled_fhn.py b/examples/analysis/highdim_gj_coupled_fhn.py index d8ba0c725..3b39228dc 100644 --- a/examples/analysis/highdim_gj_coupled_fhn.py +++ b/examples/analysis/highdim_gj_coupled_fhn.py @@ -103,7 +103,7 @@ def d8_system(): finder = bp.analysis.SlowPointFinder(f_cell=model, included_vars={'V': model.V, 'w': model.w}, - inputs=['Iext', Iext]) + inputs=[model.Iext, Iext]) finder.find_fps_with_gd_method( candidates={'V': bm.random.normal(0., 2., (1000, model.num)), 'w': bm.random.normal(0., 2., (1000, model.num))}, diff --git a/examples/simulation/Bazhenov_1998_thalamus_aug_response.py b/examples/simulation/Bazhenov_1998_thalamus_aug_response.py new file mode 100644 index 000000000..195c1c001 --- /dev/null +++ b/examples/simulation/Bazhenov_1998_thalamus_aug_response.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +""" +Implementation of the model: + +- Bazhenov, Maxim, et al. "Cellular and network models for + intrathalamic augmenting responses during 10-Hz stimulation." + Journal of Neurophysiology 79.5 (1998): 2730-2748. +""" + +import brainpy as bp +from brainpy.dyn import neurons, synapses, channels + + +class RE(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(RE, self).__init__(size, A=1.43e-4) + + self.IL = channels.IL(size, ) + self.IKL = channels.IKL(size, ) + self.INa = channels.INa_TM1991(size, V_sh=-50.) + self.IK = channels.IK_TM1991(size, V_sh=-50.) + self.IT = channels.ICaT_HP1992(size, V_sh=0., phi_q=3., phi_p=3.) + + +class TC(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(TC, self).__init__(size, A=2.9e-4) + + self.IL = channels.IL(size, ) + self.IKL = channels.IKL(size, ) + self.INa = channels.INa_TM1991(size, V_sh=-50.) + self.IK = channels.IK_TM1991(size, V_sh=-50.) 
+ self.IT = channels.ICaT_HM1992(size, V_sh=0., ) + self.IA = channels.IKA1_HM1992(size, V_sh=0., phi_q=3.7255, phi_p=3.7) + + self.Ih = channels.Ih_De1996(size, ) + self.Ca = channels.CalciumFirstOrder(size, ) + diff --git a/examples/simulation/Sanda_2021_hippo-tha-cortex-model.py b/examples/simulation/Sanda_2021_hippo-tha-cortex-model.py new file mode 100644 index 000000000..0ef35cf8c --- /dev/null +++ b/examples/simulation/Sanda_2021_hippo-tha-cortex-model.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +import brainpy as bp +from brainpy.dyn import neurons, synapses + + +class HippoThaCortexModel(bp.dyn.Network): + def __init__(self, ): + super(HippoThaCortexModel, self).__init__() + + self.CA1Exc = neurons.AdExIF(800, R=1/7., tau=200/7., V_rest=-58, delta_T=2., + V_T=-50, tau_w=120, a=2, V_th=0., V_reset=-46, b=100) + self.CA3Exc = neurons.AdExIF(1200, R=1/7., tau=200/7., V_rest=-58, delta_T=2., + V_T=-50, tau_w=120, a=2, V_th=0., V_reset=-46, b=40) + self.CA1Inh = neurons.AdExIF(160, R=1/10., tau=200/10., V_rest=-70, delta_T=2., + V_T=-50, tau_w=30, a=2, V_th=0., V_reset=-58, b=10) + self.CA3Inh = neurons.AdExIF(240, R=1/10., tau=200/10., V_rest=-70, delta_T=2., + V_T=-50, tau_w=30, a=2, V_th=0., V_reset=-58, b=10) + for pop in [self.CA1Exc, self.CA3Exc, self.CA1Inh, self.CA3Inh]: + ou = neurons.OUProcess(self.CA1Exc.size, ) + conn = synapses.WeightedSum(ou, pop, bp.conn.One2One()) + self.register_implicit_nodes(ou, conn) + + + + From d51eb2053044d386f316065205ef3ce57bee72b8 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 5 Jun 2022 21:38:34 +0800 Subject: [PATCH 8/9] remove duplicate tests --- .../compat/nn/nodes/ANN/tests/test_conv.py | 82 --------- .../nn/nodes/ANN/tests/test_normalization.py | 158 ------------------ .../compat/nn/nodes/ANN/tests/test_pooling.py | 56 ------- 3 files changed, 296 deletions(-) delete mode 100644 brainpy/compat/nn/nodes/ANN/tests/test_conv.py delete mode 100644 brainpy/compat/nn/nodes/ANN/tests/test_normalization.py delete mode 100644 brainpy/compat/nn/nodes/ANN/tests/test_pooling.py diff --git a/brainpy/compat/nn/nodes/ANN/tests/test_conv.py b/brainpy/compat/nn/nodes/ANN/tests/test_conv.py deleted file mode 100644 index b2d57a782..000000000 --- a/brainpy/compat/nn/nodes/ANN/tests/test_conv.py +++ /dev/null @@ -1,82 +0,0 @@ -# -*- coding: utf-8 -*- -import random - -import pytest -from unittest import TestCase -import brainpy as bp -import jax.numpy as jnp -import numpy as np - -class TestConv(TestCase): - def test_Conv2D_img(self): - i = bp.nn.Input((200, 198, 4)) - b = bp.nn.Conv2D(32, (3, 3), strides=(1, 1), padding='VALID', groups=2) - model = i >> b - model.initialize(num_batch=2) - - img = jnp.zeros((2, 200, 198, 4)) - for k in range(4): - x = 30 + 60 * k - y = 20 + 60 * k - img = img.at[0, x:x + 10, y:y + 10, k].set(1.0) - img = img.at[1, x:x + 20, y:y + 20, k].set(3.0) - - out = model(img) - print("out shape: ", out.shape) - # print("First output channel:") - # plt.figure(figsize=(10, 10)) - # plt.imshow(np.array(out)[0, :, :, 0]) - # plt.show() - - def test_conv2D_fb(self): - i = bp.nn.Input((5, 5, 3)) - b = bp.nn.Conv2D(32, (3, 3)) - c = bp.nn.Conv2D(64, (3, 3)) - model = (i >> b >> c) & (b << c) - model.initialize(num_batch=2) - - input = bp.math.ones((2, 5, 5, 3)) - - out = model(input) - print("out shape: ", out.shape) - - def test_conv1D(self): - i = bp.nn.Input((5, 3)) - b = bp.nn.Conv1D(32, (3,)) - model = i >> b - model.initialize(num_batch=2) - - input = bp.math.ones((2, 5, 3)) - - out = model(input) - print("out shape: ", 
out.shape) - # print("First output channel:") - # plt.figure(figsize=(10, 10)) - # plt.imshow(np.array(out)[0, :, :]) - # plt.show() - - def test_conv2D(self): - i = bp.nn.Input((5, 5, 3)) - b = bp.nn.Conv2D(32, (3, 3)) - model = i >> b - model.initialize(num_batch=2) - - input = bp.math.ones((2, 5, 5, 3)) - - out = model(input) - print("out shape: ", out.shape) - # print("First output channel:") - # plt.figure(figsize=(10, 10)) - # plt.imshow(np.array(out)[0, :, :, 31]) - # plt.show() - - def test_conv3D(self): - i = bp.nn.Input((5, 5, 5, 3)) - b = bp.nn.Conv3D(32, (3, 3, 3)) - model = i >> b - model.initialize(num_batch=2) - - input = bp.math.ones((2, 5, 5, 5, 3)) - - out = model(input) - print("out shape: ", out.shape) diff --git a/brainpy/compat/nn/nodes/ANN/tests/test_normalization.py b/brainpy/compat/nn/nodes/ANN/tests/test_normalization.py deleted file mode 100644 index c57defb65..000000000 --- a/brainpy/compat/nn/nodes/ANN/tests/test_normalization.py +++ /dev/null @@ -1,158 +0,0 @@ -# -*- coding: utf-8 -*- - - -from unittest import TestCase - -import brainpy as bp - - -class TestBatchNorm1d(TestCase): - def test_batchnorm1d1(self): - i = bp.nn.Input((3, 4)) - b = bp.nn.BatchNorm1d() - model = i >> b - model.initialize(num_batch=2) - # model.plot_node_graph(fig_size=(5, 5), node_size=500) - - inputs = bp.math.ones((2, 3, 4)) - inputs[0, 0, :] = 2. - inputs[0, 1, 0] = 5. - print(inputs) - - print(model(inputs)) - - def test_batchnorm1d2(self): - i = bp.nn.Input(4) - b = bp.nn.BatchNorm1d() - o = bp.nn.DenseMD(4) - model = i >> b >> o - model.initialize(num_batch=2) - - inputs = bp.math.ones((2, 4)) - inputs[0, :] = 2. - print(inputs) - - print(model(inputs)) - - -class TestBatchNorm2d(TestCase): - def test_batchnorm2d(self): - i = bp.nn.Input((32, 32, 3)) - b = bp.nn.BatchNorm2d() - model = i >> b - model.initialize(num_batch=10) - - inputs = bp.math.ones((10, 32, 32, 3)) - inputs[0, 1, :, :] = 2. - print(inputs.shape) - - print(model(inputs).shape) - - -class TestBatchNorm3d(TestCase): - def test_batchnorm3d(self): - i = bp.nn.Input((32, 32, 16, 3)) - b = bp.nn.BatchNorm3d() - model = i >> b - model.initialize(num_batch=10) - - inputs = bp.math.ones((10, 32, 32, 16, 3)) - print(inputs.shape) - - print(model(inputs).shape) - - -class TestBatchNorm(TestCase): - def test_batchnorm1(self): - i = bp.nn.Input((3, 4)) - b = bp.nn.BatchNorm(axis=(0, 2), use_bias=False) # channel axis: 1 - model = i >> b - model.initialize(num_batch=2) - - inputs = bp.math.ones((2, 3, 4)) - inputs[0, 0, :] = 2. - inputs[0, 1, 0] = 5. - print(inputs) - - print(model(inputs)) - - def test_batchnorm2(self): - i = bp.nn.Input((3, 4)) - b = bp.nn.BatchNorm(axis=(0, 2)) # channel axis: 1 - f = bp.nn.Reshape((-1, 12)) - o = bp.nn.DenseMD(2) - model = i >> b >> f >> o - model.initialize(num_batch=2) - - inputs = bp.math.ones((2, 3, 4)) - inputs[0, 0, :] = 2. - inputs[0, 1, 0] = 5. - # print(inputs) - print(model(inputs)) - - # training - bp.math.random.seed() - X = bp.math.random.random((1000, 10, 3, 4)) - Y = bp.math.random.randint(0, 2, (1000, 10, 2)) - trainer = bp.nn.BPTT(model, - loss=bp.losses.cross_entropy_loss, - optimizer=bp.optim.Adam(lr=1e-3)) - trainer.fit([X, Y]) - - -class TestLayerNorm(TestCase): - def test_layernorm1(self): - i = bp.nn.Input((3, 4)) - l = bp.nn.LayerNorm() - model = i >> l - model.initialize(num_batch=2) - - inputs = bp.math.ones((2, 3, 4)) - inputs[0, 0, :] = 2. - inputs[0, 1, 0] = 5. 
- print(inputs) - - print(model(inputs)) - - def test_layernorm2(self): - i = bp.nn.Input((3, 4)) - l = bp.nn.LayerNorm(axis=2) - model = i >> l - model.initialize(num_batch=2) - - inputs = bp.math.ones((2, 3, 4)) - inputs[0, 0, :] = 2. - inputs[0, 1, 0] = 5. - print(inputs) - - print(model(inputs)) - - -class TestInstanceNorm(TestCase): - def test_instancenorm(self): - i = bp.nn.Input((3, 4)) - l = bp.nn.InstanceNorm() - model = i >> l - model.initialize(num_batch=2) - - inputs = bp.math.ones((2, 3, 4)) - inputs[0, 0, :] = 2. - inputs[0, 1, 0] = 5. - print(inputs) - - print(model(inputs)) - - -class TestGroupNorm(TestCase): - def test_groupnorm1(self): - i = bp.nn.Input((3, 4)) - l = bp.nn.GroupNorm(num_groups=2) - model = i >> l - model.initialize(num_batch=2) - - inputs = bp.math.ones((2, 3, 4)) - inputs[0, 0, :] = 2. - inputs[0, 1, 0] = 5. - print(inputs) - - print(model(inputs)) diff --git a/brainpy/compat/nn/nodes/ANN/tests/test_pooling.py b/brainpy/compat/nn/nodes/ANN/tests/test_pooling.py deleted file mode 100644 index 8ca5720a9..000000000 --- a/brainpy/compat/nn/nodes/ANN/tests/test_pooling.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -import random - -import pytest -from unittest import TestCase -import brainpy as bp -import jax.numpy as jnp -import jax -import numpy as np - - -class TestPool(TestCase): - def test_maxpool(self): - i = bp.nn.Input((3, 3, 1)) - p = bp.nn.MaxPool((2, 2)) - model = i >> p - model.initialize(num_batch=1) - - x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) - - y = model(x) - print("out shape: ", y.shape) - expected_y = jnp.array([ - [4., 5.], - [7., 8.], - ]).reshape((1, 2, 2, 1)) - np.testing.assert_allclose(y, expected_y) - - def test_minpool(self): - i = bp.nn.Input((3, 3, 1)) - p = bp.nn.MinPool((2, 2)) - model = i >> p - model.initialize(num_batch=1) - - x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) - - y = model(x) - print("out shape: ", y.shape) - expected_y = jnp.array([ - [0., 1.], - [3., 4.], - ]).reshape((1, 2, 2, 1)) - np.testing.assert_allclose(y, expected_y) - - def test_avgpool(self): - i = bp.nn.Input((3, 3, 1)) - p = bp.nn.AvgPool((2, 2)) - model = i >> p - model.initialize(num_batch=1) - - x = jnp.full((1, 3, 3, 1), 2.) - y = model(x) - print("out shape: ", y.shape) - np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.)) - - From 70a3e9ac72eaf2c62bfea633fab3f843887970ac Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 5 Jun 2022 21:53:09 +0800 Subject: [PATCH 9/9] fix `l2_norm` bug --- brainpy/losses/regularization.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/brainpy/losses/regularization.py b/brainpy/losses/regularization.py index e4612b48e..51811063c 100644 --- a/brainpy/losses/regularization.py +++ b/brainpy/losses/regularization.py @@ -2,6 +2,7 @@ from jax.tree_util import tree_flatten, tree_map +import jax.numpy as jnp import brainpy.math as bm from .utils import _is_leaf, _multi_return @@ -22,8 +23,8 @@ def l2_norm(x, axis=None): Returns: scalar tensor containing the l2 loss of x. """ - leaves, _ = tree_flatten(x, is_leaf=_is_leaf) - return bm.sqrt(bm.sum([bm.vdot(x, x) for x in leaves], axis=axis)) + leaves, _ = tree_flatten(x) + return jnp.sqrt(jnp.sum(jnp.asarray([jnp.vdot(x, x) for x in leaves]), axis=axis)) def mean_absolute(outputs, axis=None):
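
Note on the `l2_norm` fix in PATCH 9/9: the previous version flattened with a custom `is_leaf` and then called `bm.sum` directly on a Python list of per-leaf `vdot` values, whereas the fixed version stacks those scalars with `jnp.asarray` before reducing. The following minimal sketch, written in plain JAX and independent of brainpy, reproduces what the fixed function computes — the global L2 norm over every leaf of a pytree. The helper name `l2_norm_sketch` and the example parameter dict are illustrative only, not part of the patch.

# Minimal sketch (plain JAX, independent of brainpy) of what the fixed
# `l2_norm` computes: the global L2 norm over every leaf of a pytree.
import jax.numpy as jnp
from jax.tree_util import tree_flatten


def l2_norm_sketch(tree):
    # Flatten the pytree into a list of arrays, accumulate vdot(leaf, leaf)
    # (i.e. the sum of squares) per leaf, stack the scalars into one array,
    # and take a single square root at the end.
    leaves, _ = tree_flatten(tree)
    return jnp.sqrt(jnp.sum(jnp.asarray([jnp.vdot(leaf, leaf) for leaf in leaves])))


params = {'w': jnp.array([[3., 0.], [0., 4.]]), 'b': jnp.zeros(2)}
print(l2_norm_sketch(params))  # sqrt(3**2 + 4**2) = 5.0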
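If the function is re-exported at the package level alongside the other losses (an assumption; the `__all__` of brainpy/losses is not shown in this series), a typical use is a weight-decay style penalty added to a training loss, e.g. `loss = bp.losses.cross_entropy_loss(outputs, targets) + 1e-4 * bp.losses.l2_norm(params)`, where `params` is any pytree of trainable arrays.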