Skip to content

Commit

Permalink
Merge pull request #403 from brian-team/improve_variable_handling
Browse files Browse the repository at this point in the history
Improve `Variable` objects
  • Loading branch information
thesamovar committed Feb 12, 2015
2 parents 1fdb52c + c11576b commit 88b0f67
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 14 deletions.
47 changes: 33 additions & 14 deletions brian2/core/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def get_dtype_str(val):
return 'unknown[%s, %s]' % (str(val), val.__class__.__name__)


def variables_by_owner(variables, owner):
owner_name = getattr(owner, 'name', None)
return dict([(varname, var) for varname, var in variables.iteritems()
if getattr(var.owner, 'name', None) is owner_name])


class Variable(object):
'''
An object providing information about model variables (including implicit
Expand All @@ -100,6 +106,11 @@ class Variable(object):
known as ``v_post`` in a `Synapse` connecting to the group).
unit : `Unit`
The unit of the variable.
owner : `Nameable`, optional
The object that "owns" this variable, e.g. the `NeuronGroup` or
`Synapses` object that declares the variable in its model equations.
Defaults to ``None`` (the value used for `Variable` objects without an
owner, e.g. external `Constant`\ s).
dtype : `dtype`, optional
The dtype used for storing the variable. Defaults to the preference
`core.default_scalar.dtype`.
Expand All @@ -115,7 +126,7 @@ class Variable(object):
for the variable ``N``, the number of neurons in a group). Defaults
to ``False``.
'''
def __init__(self, name, unit, dtype=None, scalar=False,
def __init__(self, name, unit, owner=None, dtype=None, scalar=False,
constant=False, read_only=False, dynamic=False):

#: The variable's unit.
Expand All @@ -124,6 +135,9 @@ def __init__(self, name, unit, dtype=None, scalar=False,
#: The variable's name.
self.name = name

#: The `Group` to which this variable belongs.
self.owner = weakproxy_with_fallback(owner) if owner is not None else None

#: The dtype used for storing the variable.
self.dtype = dtype
if dtype is None:
Expand Down Expand Up @@ -276,8 +290,12 @@ class Constant(Variable):
by value) should never have units attached.
value: reference to the variable value
The value of the constant.
owner : `Nameable`, optional
The object that "owns" this variable, for constants that belong to a
specific group, e.g. the ``N`` constant for a `NeuronGroup`. External
constants will have ``None`` (the default value).
'''
def __init__(self, name, unit, value):
def __init__(self, name, unit, value, owner=None):
# Determine the type of the value
is_bool = (value is True or
value is False or
Expand Down Expand Up @@ -307,7 +325,7 @@ def __init__(self, name, unit, value):
#: The constant's value
self.value = value

super(Constant, self).__init__(unit=unit, name=name,
super(Constant, self).__init__(unit=unit, name=name, owner=owner,
dtype=dtype, scalar=True, constant=True,
read_only=True)

Expand Down Expand Up @@ -371,16 +389,21 @@ class AttributeVariable(Variable):
dtype : `dtype`, optional
The dtype used for storing the variable. If none is given, defaults
to `core.default_float_dtype`.
owner : `Nameable`, optional
The object that "owns" this variable, e.g. the `NeuronGroup` to which
a ``dt`` value belongs (even if it is the attribute of a `Clock`
object). Defaults to ``None``.
constant : bool, optional
Whether the attribute's value is constant during a run. Defaults to
``False``.
scalar : bool, optional
Whether the variable is a scalar value (``True``) or vector-valued, e.g.
defined for every neuron (``False``). Defaults to ``True``.
'''
def __init__(self, name, unit, obj, attribute, dtype, constant=False,
def __init__(self, name, unit, obj, attribute, dtype=None, owner=None,
constant=False,
scalar=True):
super(AttributeVariable, self).__init__(unit=unit,
super(AttributeVariable, self).__init__(unit=unit, owner=owner,
name=name, dtype=dtype,
constant=constant,
scalar=scalar,
Expand Down Expand Up @@ -444,13 +467,11 @@ class ArrayVariable(Variable):
'''
def __init__(self, name, unit, owner, size, device, dtype=None,
constant=False, scalar=False, read_only=False, dynamic=False):
super(ArrayVariable, self).__init__(unit=unit, name=name,
super(ArrayVariable, self).__init__(unit=unit, name=name, owner=owner,
dtype=dtype, scalar=scalar,
constant=constant,
read_only=read_only,
dynamic=dynamic)
#: The `Group` to which this variable belongs.
self.owner = weakproxy_with_fallback(owner)

#: The `Device` responsible for memory access.
self.device = device
Expand Down Expand Up @@ -598,12 +619,10 @@ class Subexpression(Variable):
'''
def __init__(self, name, unit, owner, expr, device, dtype=None,
scalar=False):
super(Subexpression, self).__init__(unit=unit,
super(Subexpression, self).__init__(unit=unit, owner=owner,
name=name, dtype=dtype,
scalar=scalar,
constant=False, read_only=True)
#: The `Group` to which this variable belongs
self.owner = weakproxy_with_fallback(owner)

#: The `Device` responsible for memory access
self.device = device
Expand Down Expand Up @@ -1532,8 +1551,8 @@ def add_attribute_variable(self, name, unit, obj, attribute, dtype=None,
'of object "%r"') % (attribute, obj))
dtype = get_dtype(value)

var = AttributeVariable(name=name, unit=unit, obj=obj,
attribute=attribute, dtype=dtype,
var = AttributeVariable(name=name, unit=unit, owner=self.owner,
obj=obj, attribute=attribute, dtype=dtype,
constant=constant, scalar=scalar)
self._add_variable(name, var)

Expand All @@ -1551,7 +1570,7 @@ def add_constant(self, name, unit, value):
value: reference to the variable value
The value of the constant.
'''
var = Constant(name=name, unit=unit, value=value)
var = Constant(name=name, unit=unit, owner=self.owner, value=value)
self._add_variable(name, var)

def add_subexpression(self, name, unit, expr, dtype=None, scalar=False,
Expand Down
21 changes: 21 additions & 0 deletions brian2/tests/test_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

from brian2 import *
from brian2.core.variables import variables_by_owner
from brian2.utils.logger import catch_logs

def _compare(synapses, expected):
Expand Down Expand Up @@ -782,6 +783,25 @@ def test_repr():
assert len(func(S.equations))


def test_variables_by_owner():
# Test the `variables_by_owner` convenience function
G = NeuronGroup(10, 'v : 1')
G2 = NeuronGroup(10, '''v : 1
w : 1''')
S = Synapses(G, G2, 'x : 1')

# Check that the variables returned as owned by the pre/post groups are the
# variables stored in the respective groups. We only compare the `Variable`
# objects, as the names may be different (e.g. ``v_post`` vs. ``v``)
assert set(G.variables.values()) == set(variables_by_owner(S.variables, G).values())
assert set(G2.variables.values()) == set(variables_by_owner(S.variables, G2).values())
assert len(set(variables_by_owner(S.variables, S)) & set(G.variables.values())) == 0
assert len(set(variables_by_owner(S.variables, S)) & set(G2.variables.values())) == 0
# Just test a few examples for synaptic variables
assert all(varname in variables_by_owner(S.variables, S)
for varname in ['x', 'N', 'N_incoming', 'N_outgoing'])


if __name__ == '__main__':
test_creation()
test_incoming_outgoing()
Expand Down Expand Up @@ -809,3 +829,4 @@ def test_repr():
test_external_variables()
test_event_driven()
test_repr()
test_variables_by_owner()

0 comments on commit 88b0f67

Please sign in to comment.