Skip to content

Commit

Permalink
Fix bugs and add tests for StateMonitor, no longer derive from Groups…
Browse files Browse the repository at this point in the history
… (this doesn't play nicely with subexpressions)
  • Loading branch information
Marcel Stimberg committed Jul 23, 2013
1 parent ac84555 commit 0af38b0
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 55 deletions.
2 changes: 1 addition & 1 deletion brian2/codegen/runtime/numpy_rt/templates/statemonitor.py_
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ _recorded_{{_varname}}.resize((_new_len, _num_indices))
_t[-1] = _clock_t

_vectorisation_idx = _indices
_record_idx = _indices
_neuron_idx = _indices[:]
{% for line in code_lines %}
{{line}}
{% endfor %}
Expand Down
8 changes: 5 additions & 3 deletions brian2/codegen/runtime/weave_rt/templates/statemonitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@
const npy_intp* _record_strides = _record_data->strides;
for (int _idx=0; _idx < _num_indices; _idx++)
{
const int _record_idx = _indices[_idx];
const int _vectorisation_idx = _record_idx;
const int _neuron_idx = _indices[_idx];
const int _vectorisation_idx = _neuron_idx;
{% for line in code_lines %}
{{line}}
{% endfor %}
double *recorded_entry = ((double*)(_record_data->data + (_new_len - 1)*_record_strides[0] + _idx*_record_strides[1]));

// FIXME: This will not work for variables with other data types
double *recorded_entry = (double*)(_record_data->data + (_new_len - 1)*_record_strides[0] + _idx*_record_strides[1]);
*recorded_entry = _to_record_{{_varname}};
}
}
Expand Down
90 changes: 39 additions & 51 deletions brian2/monitors/statemonitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,16 @@
from brian2.core.base import BrianObject
from brian2.core.scheduler import Scheduler
from brian2.core.preferences import brian_prefs
from brian2.groups.group import Group
from brian2.units.fundamentalunits import Unit
from brian2.units.fundamentalunits import Unit, Quantity
from brian2.units.allunits import second
from brian2.memory.dynamicarray import DynamicArray, DynamicArray1D
from brian2.groups.group import create_runner_codeobj
from units.fundamentalunits import have_same_dimensions

__all__ = ['StateMonitor']


class MonitorVariable(Value):
def __init__(self, name, unit, dtype, monitor):
Value.__init__(self, name, unit, dtype)
self.monitor = weakref.proxy(monitor)

def get_value(self):
return self.monitor._values[self.name]


class MonitorTime(Value):
def __init__(self, monitor):
Value.__init__(self, 't', second, np.float64)
self.monitor = weakref.proxy(monitor)

def get_value(self):
return self.monitor._t.data.copy()

class StateMonitor(BrianObject, Group):
class StateMonitor(BrianObject):
'''
Record values of state variables during a run
Expand Down Expand Up @@ -100,60 +83,47 @@ def __init__(self, source, variables, record=None, when=None,
elif isinstance(variables, str):
variables = [variables]
self.variables = variables


if len(self.variables) == 0:
raise ValueError('No variables to record')

# record should always be an array of ints
self.record_all = False
if record is None or record is False:
record = np.array([], dtype=int)
record = np.array([], dtype=np.int32)
elif record is True:
self.record_all = True
record = np.arange(len(source))
record = np.arange(len(source), dtype=np.int32)
else:
record = np.array(record, dtype=int)
record = np.array(record, dtype=np.int32)

#: The array of recorded indices
self.indices = record

# create data structures
self.reinit()

# initialise Group access
# Setup specifiers
self.specifiers = {}
for variable in variables:
spec = source.specifiers[variable]
if isinstance(spec, ArrayVariable):
self.specifiers['_source_' + variable] = ArrayVariable(variable,
spec.unit,
spec.dtype,
spec.array,
'_record_idx',
group=spec.group,
constant=spec.constant,
scalar=spec.scalar,
is_bool=spec.is_bool)
elif isinstance(spec, Subexpression):
self.specifiers['_source_' + variable] = weakref.proxy(spec)
else:
raise TypeError('Variable %s cannot be recorded.' % variable)
self.specifiers[variable] = MonitorVariable(variable,
spec.unit,
spec.dtype,
self)
self.specifiers[variable] = weakref.proxy(spec)

self.specifiers['_recorded_'+variable] = ReadOnlyValue('_recorded_'+variable, Unit(1),
self._values[variable].dtype,
self._values[variable])

self.specifiers['_t'] = ReadOnlyValue('_t', Unit(1), self._t.dtype,
self._t)

self.specifiers['_clock_t'] = AttributeValue('t', second, np.float64,
self.clock, 't_')

self.specifiers['t'] = MonitorTime(self)

Group.__init__(self)
self.specifiers['_indices'] = ArrayVariable('_indices', Unit(1),
np.int32, self.indices,
index='', group=None,
constant=True)

def reinit(self):
self._values = dict((v, DynamicArray((0, len(self.variables)),
self._values = dict((v, DynamicArray((0, len(self.indices)),
use_numpy_resize=True,
dtype=self.source.specifiers[v].dtype))
for v in self.variables)
Expand All @@ -163,7 +133,7 @@ def reinit(self):
def pre_run(self, namespace):
# Some dummy code so that code generation takes care of the indexing
# and subexpressions
code = ['_to_record_%s = _source_%s' % (v, v)
code = ['_to_record_%s = %s' % (v, v)
for v in self.variables]
code += ['_recorded_%s = _recorded_%s' % (v, v)
for v in self.variables]
Expand All @@ -174,13 +144,31 @@ def pre_run(self, namespace):
additional_specifiers=self.specifiers,
additional_namespace=namespace,
template_name='statemonitor',
indices={'_record_idx': Index('_record_idx', self.record_all)},
indices={'_neuron_idx': Index('_neuron_idx', self.record_all)},
template_kwds={'_variable_names': self.variables},
codeobj_class=self.codeobj_class)

def update(self):
self.codeobj()

def __getattr__(self, item):
# TODO: Decide about the interface
if item == 't':
return Quantity(self._t.data.copy(), dim=second.dim)
elif item == 't_':
return self._t.data.copy()
elif item in self.variables:
unit = self.specifiers[item].unit
if have_same_dimensions(unit, 1):
return self._values[item].data.copy()
else:
return Quantity(self._values[item].data.copy(),
dim=unit.dim)
elif item.endswith('_') and item[:-1] in self.variables:
return self._values[item[:-1]].data.copy()
else:
getattr(super(StateMonitor, self), item)

def __repr__(self):
description = '<{classname}, recording {variables} from {source}>'
return description.format(classname=self.__class__.__name__,
Expand Down
45 changes: 45 additions & 0 deletions brian2/tests/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,51 @@ def test_spike_monitor():
brian_prefs.codegen.target = language_before


def test_state_monitor():
# Record all neurons
language_before = brian_prefs.codegen.target
for language in languages:
brian_prefs.codegen.target = language
defaultclock.t = 0*second
# Check that all kinds of variables can be recorded
G = NeuronGroup(2, '''dv/dt = -v / (10*ms) : 1
f = clip(v, 0.1, 0.9) : 1
rate: Hz''', threshold='v>1', reset='v=0')
G.rate = [100, 1000] * Hz
G.v = 1

# Use a single StateMonitor
v_mon = StateMonitor(G, 'v', record=True)
v_mon1 = StateMonitor(G, 'v', record=[1])

# Use a StateMonitor for specified variables
multi_mon = StateMonitor(G, ['v', 'f', 'rate'], record=True)
multi_mon1 = StateMonitor(G, ['v', 'f', 'rate'], record=[1])

net = Network(G, v_mon, v_mon1,
multi_mon, multi_mon1)
net.run(10*ms)

# Check v recording
assert_allclose(v_mon.v,
np.exp(np.tile(-v_mon.t - defaultclock.dt, (2, 1)).T / (10*ms)))
assert_allclose(v_mon.v_,
np.exp(np.tile(-v_mon.t_ - defaultclock.dt_, (2, 1)).T / float(10*ms)))
assert_equal(v_mon.v, multi_mon.v)
assert_equal(v_mon.v_, multi_mon.v_)
assert_equal(v_mon.v[:, 1:2], v_mon1.v)
assert_equal(multi_mon.v[:, 1:2], multi_mon1.v)

# Other variables
assert_equal(multi_mon.rate_, np.tile(np.atleast_2d(G.rate_),
(multi_mon.rate.shape[0], 1)))
assert_equal(multi_mon.rate[:, 1:2], multi_mon1.rate)
assert_allclose(np.clip(multi_mon.v, 0.1, 0.9), multi_mon.f)
assert_allclose(np.clip(multi_mon1.v, 0.1, 0.9), multi_mon1.f)

brian_prefs.codegen.target = language_before

if __name__ == '__main__':
test_spike_monitor()
test_state_monitor()

0 comments on commit 0af38b0

Please sign in to comment.