From 0af38b0d789ae655054bdf9e8e46f1777e781e81 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Tue, 23 Jul 2013 19:40:15 +0200 Subject: [PATCH] Fix bugs and add tests for StateMonitor, no longer derive from Groups (this doesn't play nicely with subexpressions) --- .../numpy_rt/templates/statemonitor.py_ | 2 +- .../weave_rt/templates/statemonitor.cpp | 8 +- brian2/monitors/statemonitor.py | 90 ++++++++----------- brian2/tests/test_monitor.py | 45 ++++++++++ 4 files changed, 90 insertions(+), 55 deletions(-) diff --git a/brian2/codegen/runtime/numpy_rt/templates/statemonitor.py_ b/brian2/codegen/runtime/numpy_rt/templates/statemonitor.py_ index ab287297a..640f432f1 100644 --- a/brian2/codegen/runtime/numpy_rt/templates/statemonitor.py_ +++ b/brian2/codegen/runtime/numpy_rt/templates/statemonitor.py_ @@ -12,7 +12,7 @@ _recorded_{{_varname}}.resize((_new_len, _num_indices)) _t[-1] = _clock_t _vectorisation_idx = _indices -_record_idx = _indices +_neuron_idx = _indices[:] {% for line in code_lines %} {{line}} {% endfor %} diff --git a/brian2/codegen/runtime/weave_rt/templates/statemonitor.cpp b/brian2/codegen/runtime/weave_rt/templates/statemonitor.cpp index 94e0d27db..0dbeb8d04 100644 --- a/brian2/codegen/runtime/weave_rt/templates/statemonitor.cpp +++ b/brian2/codegen/runtime/weave_rt/templates/statemonitor.cpp @@ -44,12 +44,14 @@ const npy_intp* _record_strides = _record_data->strides; for (int _idx=0; _idx < _num_indices; _idx++) { - const int _record_idx = _indices[_idx]; - const int _vectorisation_idx = _record_idx; + const int _neuron_idx = _indices[_idx]; + const int _vectorisation_idx = _neuron_idx; {% for line in code_lines %} {{line}} {% endfor %} - double *recorded_entry = ((double*)(_record_data->data + (_new_len - 1)*_record_strides[0] + _idx*_record_strides[1])); + + // FIXME: This will not work for variables with other data types + double *recorded_entry = (double*)(_record_data->data + (_new_len - 1)*_record_strides[0] + _idx*_record_strides[1]); *recorded_entry = _to_record_{{_varname}}; } } diff --git a/brian2/monitors/statemonitor.py b/brian2/monitors/statemonitor.py index 78efbd76b..be8da2d62 100644 --- a/brian2/monitors/statemonitor.py +++ b/brian2/monitors/statemonitor.py @@ -8,33 +8,16 @@ from brian2.core.base import BrianObject from brian2.core.scheduler import Scheduler from brian2.core.preferences import brian_prefs -from brian2.groups.group import Group -from brian2.units.fundamentalunits import Unit +from brian2.units.fundamentalunits import Unit, Quantity from brian2.units.allunits import second from brian2.memory.dynamicarray import DynamicArray, DynamicArray1D from brian2.groups.group import create_runner_codeobj +from units.fundamentalunits import have_same_dimensions __all__ = ['StateMonitor'] -class MonitorVariable(Value): - def __init__(self, name, unit, dtype, monitor): - Value.__init__(self, name, unit, dtype) - self.monitor = weakref.proxy(monitor) - - def get_value(self): - return self.monitor._values[self.name] - - -class MonitorTime(Value): - def __init__(self, monitor): - Value.__init__(self, 't', second, np.float64) - self.monitor = weakref.proxy(monitor) - - def get_value(self): - return self.monitor._t.data.copy() - -class StateMonitor(BrianObject, Group): +class StateMonitor(BrianObject): ''' Record values of state variables during a run @@ -100,16 +83,19 @@ def __init__(self, source, variables, record=None, when=None, elif isinstance(variables, str): variables = [variables] self.variables = variables - + + if len(self.variables) == 0: + raise ValueError('No variables to record') + # record should always be an array of ints self.record_all = False if record is None or record is False: - record = np.array([], dtype=int) + record = np.array([], dtype=np.int32) elif record is True: self.record_all = True - record = np.arange(len(source)) + record = np.arange(len(source), dtype=np.int32) else: - record = np.array(record, dtype=int) + record = np.array(record, dtype=np.int32) #: The array of recorded indices self.indices = record @@ -117,43 +103,27 @@ def __init__(self, source, variables, record=None, when=None, # create data structures self.reinit() - # initialise Group access + # Setup specifiers self.specifiers = {} for variable in variables: spec = source.specifiers[variable] - if isinstance(spec, ArrayVariable): - self.specifiers['_source_' + variable] = ArrayVariable(variable, - spec.unit, - spec.dtype, - spec.array, - '_record_idx', - group=spec.group, - constant=spec.constant, - scalar=spec.scalar, - is_bool=spec.is_bool) - elif isinstance(spec, Subexpression): - self.specifiers['_source_' + variable] = weakref.proxy(spec) - else: - raise TypeError('Variable %s cannot be recorded.' % variable) - self.specifiers[variable] = MonitorVariable(variable, - spec.unit, - spec.dtype, - self) + self.specifiers[variable] = weakref.proxy(spec) + self.specifiers['_recorded_'+variable] = ReadOnlyValue('_recorded_'+variable, Unit(1), self._values[variable].dtype, self._values[variable]) + self.specifiers['_t'] = ReadOnlyValue('_t', Unit(1), self._t.dtype, self._t) - self.specifiers['_clock_t'] = AttributeValue('t', second, np.float64, self.clock, 't_') - - self.specifiers['t'] = MonitorTime(self) - - Group.__init__(self) + self.specifiers['_indices'] = ArrayVariable('_indices', Unit(1), + np.int32, self.indices, + index='', group=None, + constant=True) def reinit(self): - self._values = dict((v, DynamicArray((0, len(self.variables)), + self._values = dict((v, DynamicArray((0, len(self.indices)), use_numpy_resize=True, dtype=self.source.specifiers[v].dtype)) for v in self.variables) @@ -163,7 +133,7 @@ def reinit(self): def pre_run(self, namespace): # Some dummy code so that code generation takes care of the indexing # and subexpressions - code = ['_to_record_%s = _source_%s' % (v, v) + code = ['_to_record_%s = %s' % (v, v) for v in self.variables] code += ['_recorded_%s = _recorded_%s' % (v, v) for v in self.variables] @@ -174,13 +144,31 @@ def pre_run(self, namespace): additional_specifiers=self.specifiers, additional_namespace=namespace, template_name='statemonitor', - indices={'_record_idx': Index('_record_idx', self.record_all)}, + indices={'_neuron_idx': Index('_neuron_idx', self.record_all)}, template_kwds={'_variable_names': self.variables}, codeobj_class=self.codeobj_class) def update(self): self.codeobj() + def __getattr__(self, item): + # TODO: Decide about the interface + if item == 't': + return Quantity(self._t.data.copy(), dim=second.dim) + elif item == 't_': + return self._t.data.copy() + elif item in self.variables: + unit = self.specifiers[item].unit + if have_same_dimensions(unit, 1): + return self._values[item].data.copy() + else: + return Quantity(self._values[item].data.copy(), + dim=unit.dim) + elif item.endswith('_') and item[:-1] in self.variables: + return self._values[item[:-1]].data.copy() + else: + getattr(super(StateMonitor, self), item) + def __repr__(self): description = '<{classname}, recording {variables} from {source}>' return description.format(classname=self.__class__.__name__, diff --git a/brian2/tests/test_monitor.py b/brian2/tests/test_monitor.py index a32c923ec..073c0a4af 100644 --- a/brian2/tests/test_monitor.py +++ b/brian2/tests/test_monitor.py @@ -35,6 +35,51 @@ def test_spike_monitor(): brian_prefs.codegen.target = language_before +def test_state_monitor(): + # Record all neurons + language_before = brian_prefs.codegen.target + for language in languages: + brian_prefs.codegen.target = language + defaultclock.t = 0*second + # Check that all kinds of variables can be recorded + G = NeuronGroup(2, '''dv/dt = -v / (10*ms) : 1 + f = clip(v, 0.1, 0.9) : 1 + rate: Hz''', threshold='v>1', reset='v=0') + G.rate = [100, 1000] * Hz + G.v = 1 + + # Use a single StateMonitor + v_mon = StateMonitor(G, 'v', record=True) + v_mon1 = StateMonitor(G, 'v', record=[1]) + + # Use a StateMonitor for specified variables + multi_mon = StateMonitor(G, ['v', 'f', 'rate'], record=True) + multi_mon1 = StateMonitor(G, ['v', 'f', 'rate'], record=[1]) + + net = Network(G, v_mon, v_mon1, + multi_mon, multi_mon1) + net.run(10*ms) + + # Check v recording + assert_allclose(v_mon.v, + np.exp(np.tile(-v_mon.t - defaultclock.dt, (2, 1)).T / (10*ms))) + assert_allclose(v_mon.v_, + np.exp(np.tile(-v_mon.t_ - defaultclock.dt_, (2, 1)).T / float(10*ms))) + assert_equal(v_mon.v, multi_mon.v) + assert_equal(v_mon.v_, multi_mon.v_) + assert_equal(v_mon.v[:, 1:2], v_mon1.v) + assert_equal(multi_mon.v[:, 1:2], multi_mon1.v) + + # Other variables + assert_equal(multi_mon.rate_, np.tile(np.atleast_2d(G.rate_), + (multi_mon.rate.shape[0], 1))) + assert_equal(multi_mon.rate[:, 1:2], multi_mon1.rate) + assert_allclose(np.clip(multi_mon.v, 0.1, 0.9), multi_mon.f) + assert_allclose(np.clip(multi_mon1.v, 0.1, 0.9), multi_mon1.f) + + brian_prefs.codegen.target = language_before + if __name__ == '__main__': test_spike_monitor() + test_state_monitor()