From db33c244a05e8419c5736f64f276501cc1adb264 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Wed, 7 Aug 2013 15:39:37 +0200 Subject: [PATCH 1/3] Implement new indexing semantic for StateMonitor, closes #88 --- brian2/monitors/statemonitor.py | 96 +++++++++++++++++++++++++++++---- brian2/tests/test_functions.py | 4 +- brian2/tests/test_monitor.py | 14 ++--- brian2/tests/test_refractory.py | 10 ++-- 4 files changed, 100 insertions(+), 24 deletions(-) diff --git a/brian2/monitors/statemonitor.py b/brian2/monitors/statemonitor.py index 8dbc6f20f..03e12aff5 100644 --- a/brian2/monitors/statemonitor.py +++ b/brian2/monitors/statemonitor.py @@ -1,4 +1,5 @@ import weakref +import collections import numpy as np @@ -15,18 +16,85 @@ __all__ = ['StateMonitor'] +class StateMonitorView(object): + def __init__(self, monitor, item): + self.monitor = monitor + self.item = item + self.indices = self._calc_indices(item) + self._group_attribute_access_active = True + + def __getattr__(self, item): + # We do this because __setattr__ and __getattr__ are not active until + # _group_attribute_access_active attribute is set, and if it is set, + # then __getattr__ will not be called. Therefore, if getattr is called + # with this name, it is because it hasn't been set yet and so this + # method should raise an AttributeError to agree that it hasn't been + # called yet. + if item == '_group_attribute_access_active': + raise AttributeError + if not hasattr(self, '_group_attribute_access_active'): + raise AttributeError + + if item == 't': + return Quantity(self.monitor._t.data.copy(), dim=second.dim) + elif item == 't_': + return self.monitor._t.data.copy() + elif item in self.monitor.variables: + unit = self.monitor.specifiers[item].unit + return Quantity(self.monitor._values[item].data.T[self.indices].copy(), + dim=unit.dim) + elif item.endswith('_') and item[:-1] in self.monitor.variables: + return self.monitor._values[item[:-1]].data.T[self.indices].copy() + else: + raise AttributeError('Unknown attribute %s' % item) + + def _calc_indices(self, item): + ''' + Convert the neuron indices to indices into the stored values. For example, if neurons [0, 5, 10] have been + recorded, [5, 10] is converted to [1, 2]. + ''' + if isinstance(item, int): + indices = np.nonzero(self.monitor.indices == item)[0] + if len(indices) == 0: + raise IndexError('Neuron number %d has not been recorded' % item) + return indices[0] + + if self.monitor.record_all: + return index_array + indices = [] + for position, index in enumerate(self.monitor.indices): + if index in index_array: + indices.append(position) + else: + raise IndexError('Neuron number %d has not been recorded' % index) + return np.array(indices) + + def __repr__(self): + description = '<{classname}, giving access to elements {elements} recorded by {monitor}>' + return description.format(classname=self.__class__.__name__, + elements=repr(self.item), + monitor=self.monitor.name) + + class StateMonitor(BrianObject): ''' Record values of state variables during a run To extract recorded values after a run, use `t` attribute for the array of times at which values were recorded, and variable name attribute - for the values. The values will have shape ``(len(t), len(indices))``, - where `indices` are the array indices which were recorded. + for the values. The values will have shape ``(len(indices), len(t))``, + where `indices` are the array indices which were recorded. When indexing the + `StateMonitor` directly, the returned object can be used to get the + recorded values for the specified indices, i.e. the indexing semantic + refers to the indices in `source`, not to the relative indices of the + recorded values. For example, when recording only neurons with even numbers, + `mon[[0, 2]].v` will return the values for neurons 0 and 2, whereas + `mon.v[[0, 2]]` will return the values for the first and third *recorded* + neurons, i.e. for neurons 0 and 4. Parameters ---------- - source : `NeuronGroup`, `Group` + source : `Group` Which object to record values from. variables : str, sequence of str, True Which variables to record, or ``True`` to record all variables @@ -58,7 +126,7 @@ class StateMonitor(BrianObject): G.V = rand(len(G)) M = StateMonitor(G, True, record=range(5)) run(100*ms) - plot(M.t, M.V) + plot(M.t, M.V.T) show() ''' @@ -153,6 +221,17 @@ def pre_run(self, namespace): def update(self): self.codeobj() + def __getitem__(self, item): + if isinstance(item, (int, np.ndarray)): + return StateMonitorView(self, item) + elif isinstance(item, collections.Seqquence): + index_array = np.array(item) + if not np.issubdtype(index_array.dtype, np.int): + raise TypeError('Index has to be an integer or a sequence of integers') + return StateMonitorView(self, item) + else: + raise TypeError('Cannot use object of type %s as an index' % type(item)) + def __getattr__(self, item): # We do this because __setattr__ and __getattr__ are not active until # _group_attribute_access_active attribute is set, and if it is set, @@ -172,13 +251,10 @@ def __getattr__(self, item): return self._t.data.copy() elif item in self.variables: unit = self.specifiers[item].unit - if have_same_dimensions(unit, 1): - return self._values[item].data.copy() - else: - return Quantity(self._values[item].data.copy(), - dim=unit.dim) + return Quantity(self._values[item].data.T.copy(), + dim=unit.dim) elif item.endswith('_') and item[:-1] in self.variables: - return self._values[item[:-1]].data.copy() + return self._values[item[:-1]].data.T.copy() else: raise AttributeError('Unknown attribute %s' % item) diff --git a/brian2/tests/test_functions.py b/brian2/tests/test_functions.py index 54ad07907..12a9e64f3 100644 --- a/brian2/tests/test_functions.py +++ b/brian2/tests/test_functions.py @@ -48,7 +48,7 @@ def test_math_functions(): net = Network(G, mon) net.run(clock.dt) - assert_equal(numpy_result, mon.func_[0, :], + assert_equal(numpy_result, mon.func_.flatten(), 'Function %s did not return the correct values' % func.__name__) # Functions/operators @@ -70,7 +70,7 @@ def test_math_functions(): net = Network(G, mon) net.run(clock.dt) - assert_equal(numpy_result, mon.func_[0, :], + assert_equal(numpy_result, mon.func_.flatten(), 'Function %s did not return the correct values' % func.__name__) diff --git a/brian2/tests/test_monitor.py b/brian2/tests/test_monitor.py index 9f07d3253..19ee19749 100644 --- a/brian2/tests/test_monitor.py +++ b/brian2/tests/test_monitor.py @@ -81,19 +81,19 @@ def test_state_monitor(): np.arange(len(nothing_mon.t)) * defaultclock.dt) # Check v recording - assert_allclose(v_mon.v, + assert_allclose(v_mon.v.T, np.exp(np.tile(-v_mon.t - defaultclock.dt, (2, 1)).T / (10*ms))) - assert_allclose(v_mon.v_, + assert_allclose(v_mon.v_.T, np.exp(np.tile(-v_mon.t_ - defaultclock.dt_, (2, 1)).T / float(10*ms))) assert_equal(v_mon.v, multi_mon.v) assert_equal(v_mon.v_, multi_mon.v_) - assert_equal(v_mon.v[:, 1:2], v_mon1.v) - assert_equal(multi_mon.v[:, 1:2], multi_mon1.v) + assert_equal(v_mon.v[1:2], v_mon1.v) + assert_equal(multi_mon.v[1:2], multi_mon1.v) # Other variables - assert_equal(multi_mon.rate_, np.tile(np.atleast_2d(G.rate_), - (multi_mon.rate.shape[0], 1))) - assert_equal(multi_mon.rate[:, 1:2], multi_mon1.rate) + assert_equal(multi_mon.rate_.T, np.tile(np.atleast_2d(G.rate_), + (multi_mon.rate.shape[1], 1))) + assert_equal(multi_mon.rate[1:2], multi_mon1.rate) assert_allclose(np.clip(multi_mon.v, 0.1, 0.9), multi_mon.f) assert_allclose(np.clip(multi_mon1.v, 0.1, 0.9), multi_mon1.f) diff --git a/brian2/tests/test_refractory.py b/brian2/tests/test_refractory.py index 1c459dc0e..0056ffcc5 100644 --- a/brian2/tests/test_refractory.py +++ b/brian2/tests/test_refractory.py @@ -42,15 +42,15 @@ def test_refractoriness_variables(): net = Network(G, mon) net.run(20*ms) # No difference before the spike - assert_equal(mon.v[mon.t < 10*ms], mon.w[mon.t < 10*ms]) + assert_equal(mon[0].v[mon.t < 10*ms], mon[0].w[mon.t < 10*ms]) # v is not updated during refractoriness - in_refractoriness = mon.v[(mon.t >= 10*ms) & (mon.t <15*ms)] + in_refractoriness = mon[0].v[(mon.t >= 10*ms) & (mon.t <15*ms)] assert_equal(in_refractoriness, np.zeros_like(in_refractoriness)) # w should evolve as before - assert_equal(mon.w[mon.t < 5*ms], mon.w[(mon.t >= 10*ms) & (mon.t <15*ms)]) - assert np.all(mon.w[(mon.t >= 10*ms) & (mon.t <15*ms)] > 0) + assert_equal(mon[0].w[mon.t < 5*ms], mon[0].w[(mon.t >= 10*ms) & (mon.t <15*ms)]) + assert np.all(mon[0].w[(mon.t >= 10*ms) & (mon.t <15*ms)] > 0) # After refractoriness, v should increase again - assert np.all(mon.v[(mon.t >= 15*ms) & (mon.t <20*ms)] > 0) + assert np.all(mon[0].v[(mon.t >= 15*ms) & (mon.t <20*ms)] > 0) def test_refractoriness_threshold(): From 715a948e150b977d7bb080da01c4cc568127fecc Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 9 Aug 2013 11:11:37 +0200 Subject: [PATCH 2/3] Add tests for the new indexing semantics --- brian2/monitors/statemonitor.py | 16 +++++++++------- brian2/tests/test_monitor.py | 26 +++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/brian2/monitors/statemonitor.py b/brian2/monitors/statemonitor.py index 03e12aff5..5bf9e247a 100644 --- a/brian2/monitors/statemonitor.py +++ b/brian2/monitors/statemonitor.py @@ -60,11 +60,11 @@ def _calc_indices(self, item): return indices[0] if self.monitor.record_all: - return index_array + return item indices = [] - for position, index in enumerate(self.monitor.indices): - if index in index_array: - indices.append(position) + for index in item: + if index in self.monitor.indices: + indices.append(np.nonzero(self.monitor.indices == index)[0][0]) else: raise IndexError('Neuron number %d has not been recorded' % index) return np.array(indices) @@ -224,13 +224,15 @@ def update(self): def __getitem__(self, item): if isinstance(item, (int, np.ndarray)): return StateMonitorView(self, item) - elif isinstance(item, collections.Seqquence): + elif isinstance(item, collections.Sequence): index_array = np.array(item) if not np.issubdtype(index_array.dtype, np.int): - raise TypeError('Index has to be an integer or a sequence of integers') + raise TypeError('Index has to be an integer or a sequence ' + 'of integers') return StateMonitorView(self, item) else: - raise TypeError('Cannot use object of type %s as an index' % type(item)) + raise TypeError('Cannot use object of type %s as an index' + % type(item)) def __getattr__(self, item): # We do this because __setattr__ and __getattr__ are not active until diff --git a/brian2/tests/test_monitor.py b/brian2/tests/test_monitor.py index 19ee19749..a2c466142 100644 --- a/brian2/tests/test_monitor.py +++ b/brian2/tests/test_monitor.py @@ -1,5 +1,5 @@ import numpy as np -from numpy.testing.utils import assert_allclose, assert_equal +from numpy.testing.utils import assert_allclose, assert_equal, assert_raises from brian2 import * @@ -97,6 +97,30 @@ def test_state_monitor(): assert_allclose(np.clip(multi_mon.v, 0.1, 0.9), multi_mon.f) assert_allclose(np.clip(multi_mon1.v, 0.1, 0.9), multi_mon1.f) + # Check indexing semantics + G = NeuronGroup(10, 'v:volt') + G.v = np.arange(10) * volt + mon = StateMonitor(G, 'v', record=[5, 6, 7]) + + net = Network(G, mon) + net.run(2 * defaultclock.dt) + + assert_equal(mon.v, np.array([[5, 5], + [6, 6], + [7, 7]]) * volt) + assert_equal(mon.v_, np.array([[5, 5], + [6, 6], + [7, 7]])) + assert_equal(mon[5].v, mon.v[0]) + assert_equal(mon[7].v, mon.v[2]) + assert_equal(mon[[5, 7]].v, mon.v[[0, 2]]) + assert_equal(mon[np.array([5, 7])].v, mon.v[[0, 2]]) + + assert_raises(IndexError, lambda: mon[8]) + assert_raises(TypeError, lambda: mon['string']) + assert_raises(TypeError, lambda: mon[5.0]) + assert_raises(TypeError, lambda: mon[[5.0, 6.0]]) + brian_prefs.codegen.target = language_before From c89e0537c06a063ee37c787177d20cc4387a9175 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 9 Aug 2013 14:49:53 +0200 Subject: [PATCH 3/3] Replace numpy scalar values with numpy data types with their Python equivalents (e.g. replace np.int64(1) with int(1)) -- weave otherwise treats them as general Python objects, not as numbers --- brian2/core/namespace.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/brian2/core/namespace.py b/brian2/core/namespace.py index ceb11afa8..6d7bd467f 100644 --- a/brian2/core/namespace.py +++ b/brian2/core/namespace.py @@ -169,11 +169,24 @@ def resolve(self, identifier, additional_namespace=None, strip_units=False): # use the first match (according to resolution order) resolved = matches[0][1] + + # Remove units if strip_units and isinstance(resolved, Quantity): if resolved.ndim == 0: resolved = float(resolved) else: resolved = np.asarray(resolved) + + # Use standard Python types if possible + if not isinstance(resolved, np.ndarray) and hasattr(resolved, 'dtype'): + numpy_type = resolved.dtype + if np.can_cast(numpy_type, np.int_): + resolved = int(resolved) + elif np.can_cast(numpy_type, np.float_): + resolved = float(resolved) + elif np.can_cast(numpy_type, np.complex_): + resolved = complex(resolved) + return resolved def resolve_all(self, identifiers, additional_namespace=None,