Skip to content

Commit

Permalink
Merge pull request #90 from brian-team/monitor_return_values
Browse files Browse the repository at this point in the history
Implement indexing semantics for StateMonitor
  • Loading branch information
thesamovar committed Aug 9, 2013
2 parents 04c133d + c89e053 commit 15f380c
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 25 deletions.
98 changes: 88 additions & 10 deletions brian2/monitors/statemonitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import weakref
import collections

import numpy as np

Expand All @@ -15,18 +16,85 @@
__all__ = ['StateMonitor']


class StateMonitorView(object):
def __init__(self, monitor, item):
self.monitor = monitor
self.item = item
self.indices = self._calc_indices(item)
self._group_attribute_access_active = True

def __getattr__(self, item):
# We do this because __setattr__ and __getattr__ are not active until
# _group_attribute_access_active attribute is set, and if it is set,
# then __getattr__ will not be called. Therefore, if getattr is called
# with this name, it is because it hasn't been set yet and so this
# method should raise an AttributeError to agree that it hasn't been
# called yet.
if item == '_group_attribute_access_active':
raise AttributeError
if not hasattr(self, '_group_attribute_access_active'):
raise AttributeError

if item == 't':
return Quantity(self.monitor._t.data.copy(), dim=second.dim)
elif item == 't_':
return self.monitor._t.data.copy()
elif item in self.monitor.variables:
unit = self.monitor.specifiers[item].unit
return Quantity(self.monitor._values[item].data.T[self.indices].copy(),
dim=unit.dim)
elif item.endswith('_') and item[:-1] in self.monitor.variables:
return self.monitor._values[item[:-1]].data.T[self.indices].copy()
else:
raise AttributeError('Unknown attribute %s' % item)

def _calc_indices(self, item):
'''
Convert the neuron indices to indices into the stored values. For example, if neurons [0, 5, 10] have been
recorded, [5, 10] is converted to [1, 2].
'''
if isinstance(item, int):
indices = np.nonzero(self.monitor.indices == item)[0]
if len(indices) == 0:
raise IndexError('Neuron number %d has not been recorded' % item)
return indices[0]

if self.monitor.record_all:
return item
indices = []
for index in item:
if index in self.monitor.indices:
indices.append(np.nonzero(self.monitor.indices == index)[0][0])
else:
raise IndexError('Neuron number %d has not been recorded' % index)
return np.array(indices)

def __repr__(self):
description = '<{classname}, giving access to elements {elements} recorded by {monitor}>'
return description.format(classname=self.__class__.__name__,
elements=repr(self.item),
monitor=self.monitor.name)


class StateMonitor(BrianObject):
'''
Record values of state variables during a run
To extract recorded values after a run, use `t` attribute for the
array of times at which values were recorded, and variable name attribute
for the values. The values will have shape ``(len(t), len(indices))``,
where `indices` are the array indices which were recorded.
for the values. The values will have shape ``(len(indices), len(t))``,
where `indices` are the array indices which were recorded. When indexing the
`StateMonitor` directly, the returned object can be used to get the
recorded values for the specified indices, i.e. the indexing semantic
refers to the indices in `source`, not to the relative indices of the
recorded values. For example, when recording only neurons with even numbers,
`mon[[0, 2]].v` will return the values for neurons 0 and 2, whereas
`mon.v[[0, 2]]` will return the values for the first and third *recorded*
neurons, i.e. for neurons 0 and 4.
Parameters
----------
source : `NeuronGroup`, `Group`
source : `Group`
Which object to record values from.
variables : str, sequence of str, True
Which variables to record, or ``True`` to record all variables
Expand Down Expand Up @@ -58,7 +126,7 @@ class StateMonitor(BrianObject):
G.V = rand(len(G))
M = StateMonitor(G, True, record=range(5))
run(100*ms)
plot(M.t, M.V)
plot(M.t, M.V.T)
show()
'''
Expand Down Expand Up @@ -153,6 +221,19 @@ def pre_run(self, namespace):
def update(self):
self.codeobj()

def __getitem__(self, item):
if isinstance(item, (int, np.ndarray)):
return StateMonitorView(self, item)
elif isinstance(item, collections.Sequence):
index_array = np.array(item)
if not np.issubdtype(index_array.dtype, np.int):
raise TypeError('Index has to be an integer or a sequence '
'of integers')
return StateMonitorView(self, item)
else:
raise TypeError('Cannot use object of type %s as an index'
% type(item))

def __getattr__(self, item):
# We do this because __setattr__ and __getattr__ are not active until
# _group_attribute_access_active attribute is set, and if it is set,
Expand All @@ -172,13 +253,10 @@ def __getattr__(self, item):
return self._t.data.copy()
elif item in self.variables:
unit = self.specifiers[item].unit
if have_same_dimensions(unit, 1):
return self._values[item].data.copy()
else:
return Quantity(self._values[item].data.copy(),
dim=unit.dim)
return Quantity(self._values[item].data.T.copy(),
dim=unit.dim)
elif item.endswith('_') and item[:-1] in self.variables:
return self._values[item[:-1]].data.copy()
return self._values[item[:-1]].data.T.copy()
else:
raise AttributeError('Unknown attribute %s' % item)

Expand Down
4 changes: 2 additions & 2 deletions brian2/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_math_functions():
net = Network(G, mon)
net.run(clock.dt)

assert_equal(numpy_result, mon.func_[0, :],
assert_equal(numpy_result, mon.func_.flatten(),
'Function %s did not return the correct values' % func.__name__)

# Functions/operators
Expand All @@ -70,7 +70,7 @@ def test_math_functions():
net = Network(G, mon)
net.run(clock.dt)

assert_equal(numpy_result, mon.func_[0, :],
assert_equal(numpy_result, mon.func_.flatten(),
'Function %s did not return the correct values' % func.__name__)


Expand Down
40 changes: 32 additions & 8 deletions brian2/tests/test_monitor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from numpy.testing.utils import assert_allclose, assert_equal
from numpy.testing.utils import assert_allclose, assert_equal, assert_raises

from brian2 import *

Expand Down Expand Up @@ -81,22 +81,46 @@ def test_state_monitor():
np.arange(len(nothing_mon.t)) * defaultclock.dt)

# Check v recording
assert_allclose(v_mon.v,
assert_allclose(v_mon.v.T,
np.exp(np.tile(-v_mon.t - defaultclock.dt, (2, 1)).T / (10*ms)))
assert_allclose(v_mon.v_,
assert_allclose(v_mon.v_.T,
np.exp(np.tile(-v_mon.t_ - defaultclock.dt_, (2, 1)).T / float(10*ms)))
assert_equal(v_mon.v, multi_mon.v)
assert_equal(v_mon.v_, multi_mon.v_)
assert_equal(v_mon.v[:, 1:2], v_mon1.v)
assert_equal(multi_mon.v[:, 1:2], multi_mon1.v)
assert_equal(v_mon.v[1:2], v_mon1.v)
assert_equal(multi_mon.v[1:2], multi_mon1.v)

# Other variables
assert_equal(multi_mon.rate_, np.tile(np.atleast_2d(G.rate_),
(multi_mon.rate.shape[0], 1)))
assert_equal(multi_mon.rate[:, 1:2], multi_mon1.rate)
assert_equal(multi_mon.rate_.T, np.tile(np.atleast_2d(G.rate_),
(multi_mon.rate.shape[1], 1)))
assert_equal(multi_mon.rate[1:2], multi_mon1.rate)
assert_allclose(np.clip(multi_mon.v, 0.1, 0.9), multi_mon.f)
assert_allclose(np.clip(multi_mon1.v, 0.1, 0.9), multi_mon1.f)

# Check indexing semantics
G = NeuronGroup(10, 'v:volt')
G.v = np.arange(10) * volt
mon = StateMonitor(G, 'v', record=[5, 6, 7])

net = Network(G, mon)
net.run(2 * defaultclock.dt)

assert_equal(mon.v, np.array([[5, 5],
[6, 6],
[7, 7]]) * volt)
assert_equal(mon.v_, np.array([[5, 5],
[6, 6],
[7, 7]]))
assert_equal(mon[5].v, mon.v[0])
assert_equal(mon[7].v, mon.v[2])
assert_equal(mon[[5, 7]].v, mon.v[[0, 2]])
assert_equal(mon[np.array([5, 7])].v, mon.v[[0, 2]])

assert_raises(IndexError, lambda: mon[8])
assert_raises(TypeError, lambda: mon['string'])
assert_raises(TypeError, lambda: mon[5.0])
assert_raises(TypeError, lambda: mon[[5.0, 6.0]])

brian_prefs.codegen.target = language_before


Expand Down
10 changes: 5 additions & 5 deletions brian2/tests/test_refractory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def test_refractoriness_variables():
net = Network(G, mon)
net.run(20*ms)
# No difference before the spike
assert_equal(mon.v[mon.t < 10*ms], mon.w[mon.t < 10*ms])
assert_equal(mon[0].v[mon.t < 10*ms], mon[0].w[mon.t < 10*ms])
# v is not updated during refractoriness
in_refractoriness = mon.v[(mon.t >= 10*ms) & (mon.t <15*ms)]
in_refractoriness = mon[0].v[(mon.t >= 10*ms) & (mon.t <15*ms)]
assert_equal(in_refractoriness, np.zeros_like(in_refractoriness))
# w should evolve as before
assert_equal(mon.w[mon.t < 5*ms], mon.w[(mon.t >= 10*ms) & (mon.t <15*ms)])
assert np.all(mon.w[(mon.t >= 10*ms) & (mon.t <15*ms)] > 0)
assert_equal(mon[0].w[mon.t < 5*ms], mon[0].w[(mon.t >= 10*ms) & (mon.t <15*ms)])
assert np.all(mon[0].w[(mon.t >= 10*ms) & (mon.t <15*ms)] > 0)
# After refractoriness, v should increase again
assert np.all(mon.v[(mon.t >= 15*ms) & (mon.t <20*ms)] > 0)
assert np.all(mon[0].v[(mon.t >= 15*ms) & (mon.t <20*ms)] > 0)


def test_refractoriness_threshold():
Expand Down

0 comments on commit 15f380c

Please sign in to comment.