Skip to content

Commit

Permalink
Merge pull request #75 from brian-team/multiple_template_arguments
Browse files Browse the repository at this point in the history
Added arbitrary keywords option to create_codeobj
  • Loading branch information
Marcel Stimberg committed Jul 10, 2013
2 parents a2fab76 + 551cc7a commit 7022bff
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
19 changes: 17 additions & 2 deletions brian2/codegen/languages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,33 @@ def translate_statement_sequence(self, statements, specifiers, namespace, indice
raise NotImplementedError

def create_codeobj(self, name, abstract_code, namespace, specifiers,
template, indices=None):
template, indices=None, template_kwds=None):
'''
The following arguments keywords are passed to the template:
* code_lines coming from translation applied to abstract_code, a list
of lines of code, given to the template as ``code_lines`` keyword.
* ``template_kwds`` dict
* ``kwds`` coming from `translate` function overwrite those in
``template_kwds`` (but you should ensure there are no name
clashes.
'''
if indices is None: # TODO: Do we ever create code without any index?
indices = {}
if template_kwds is None:
template_kwds = dict()
else:
template_kwds = template_kwds.copy()

namespace = self.prepare_namespace(namespace, specifiers)

logger.debug(name + " abstract code:\n" + abstract_code)
innercode, kwds = translate(abstract_code, specifiers, namespace,
brian_prefs['core.default_scalar_dtype'],
self, indices)
template_kwds.update(kwds)
logger.debug(name + " inner code:\n" + str(innercode))
code = template(innercode, **kwds)
code = template(innercode, **template_kwds)
logger.debug(name + " code:\n" + str(code))

specifiers.update(indices)
Expand Down
15 changes: 11 additions & 4 deletions brian2/groups/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def __setattr__(self, name, val):


def _create_codeobj(group, name, code, additional_namespace=None,
template=None, iterate_all=True, check_units=True):
template=None, template_kwds=None, iterate_all=True,
check_units=True):
''' A little helper function to reduce the amount of repetition when
calling the language's _create_codeobj (always pass self.specifiers and
self.namespace + additional namespace).
Expand Down Expand Up @@ -128,7 +129,8 @@ def _create_codeobj(group, name, code, additional_namespace=None,
template,
indices={'_neuron_idx':
Index('_neuron_idx',
iterate_all)})
iterate_all)},
template_kwds=template_kwds)


class GroupCodeRunner(BrianObject):
Expand Down Expand Up @@ -165,6 +167,8 @@ class GroupCodeRunner(BrianObject):
updaters (units are already checked for the equations and the generated
abstract code might have already replaced variables with their unit-less
values)
template_kwds : dict, optional
A dictionary of additional information that is passed to the template.
Notes
-----
Expand All @@ -178,13 +182,15 @@ class GroupCodeRunner(BrianObject):
`NeuronGroup.spikes` property in `post_update`.
'''
def __init__(self, group, template, code=None, iterate_all=True,
when=None, name='coderunner*', check_units=True):
when=None, name='coderunner*', check_units=True,
template_kwds=None):
BrianObject.__init__(self, when=when, name=name)
self.group = weakref.proxy(group)
self.template = template
self.abstract_code = code
self.iterate_all = iterate_all
self.check_units = check_units
self.template_kwds = template_kwds
# Try to generate the abstract code and the codeobject without any
# additional namespace. This might work in situations where the
# namespace is completely defined in the NeuronGroup. In this case,
Expand Down Expand Up @@ -214,7 +220,8 @@ def pre_run(self, namespace):
additional_namespace=namespace,
template=self.template,
iterate_all=self.iterate_all,
check_units=self.check_units)
check_units=self.check_units,
template_kwds=self.template_kwds)

def pre_update(self):
'''
Expand Down
3 changes: 2 additions & 1 deletion brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def update_abstract_code(self, additional_namespace):
namespace = dict(self.group.namespace)
if additional_namespace is not None:
namespace.update(additional_namespace[1])
unit = parse_expression_unit(ref, namespace, self.group.specifiers)
unit = parse_expression_unit(str(ref), namespace,
self.group.specifiers)
if have_same_dimensions(unit, second):
self.abstract_code = 'not_refractory = (t - lastspike) > %s\n' % ref
elif have_same_dimensions(unit, Unit(1)):
Expand Down

0 comments on commit 7022bff

Please sign in to comment.