Skip to content

Commit

Permalink
Merge pull request #91 from brian-team/specifiers_rewrite
Browse files Browse the repository at this point in the history
Specifiers rewrite
  • Loading branch information
thesamovar committed Aug 18, 2013
2 parents 15f380c + 5034ab4 commit a41f90e
Show file tree
Hide file tree
Showing 60 changed files with 1,703 additions and 1,634 deletions.
74 changes: 40 additions & 34 deletions brian2/codegen/codeobject.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import functools

from brian2.core.specifiers import (ArrayVariable, Value,
AttributeValue, Subexpression)
from brian2.core.variables import (ArrayVariable, Variable,
AttributeVariable, Subexpression,
StochasticVariable)
from .functions.base import Function
from brian2.core.preferences import brian_prefs, BrianPreference
from brian2.core.preferences import brian_prefs
from brian2.utils.logger import get_logger
from .translation import translate
from .runtime.targets import runtime_targets
Expand All @@ -30,20 +31,21 @@ def get_default_codeobject_class():
return codeobj_class


def prepare_namespace(namespace, specifiers):
def prepare_namespace(namespace, variables):
namespace = dict(namespace)
# Add variables referring to the arrays
arrays = []
for value in specifiers.itervalues():
for value in variables.itervalues():
if isinstance(value, ArrayVariable):
arrays.append((value.arrayname, value.get_value()))
namespace.update(arrays)

return namespace


def create_codeobject(name, abstract_code, namespace, specifiers, template_name,
codeobj_class=None, indices=None, template_kwds=None):
def create_codeobject(name, abstract_code, namespace, variables, template_name,
indices, variable_indices, codeobj_class=None,
template_kwds=None):
'''
The following arguments keywords are passed to the template:
Expand All @@ -54,8 +56,6 @@ def create_codeobject(name, abstract_code, namespace, specifiers, template_name,
``template_kwds`` (but you should ensure there are no name
clashes.
'''
if indices is None: # TODO: Do we ever create code without any index?
indices = {}

if template_kwds is None:
template_kwds = dict()
Expand All @@ -68,19 +68,22 @@ def create_codeobject(name, abstract_code, namespace, specifiers, template_name,
template = get_codeobject_template(template_name,
codeobj_class=codeobj_class)

namespace = prepare_namespace(namespace, specifiers)
namespace = prepare_namespace(namespace, variables)

logger.debug(name + " abstract code:\n" + abstract_code)
innercode, kwds = translate(abstract_code, specifiers, namespace,
brian_prefs['core.default_scalar_dtype'],
codeobj_class.language, indices)
iterate_all = template.iterate_all
innercode, kwds = translate(abstract_code, variables, namespace,
dtype=brian_prefs['core.default_scalar_dtype'],
language=codeobj_class.language,
variable_indices=variable_indices,
iterate_all=iterate_all)
template_kwds.update(kwds)
logger.debug(name + " inner code:\n" + str(innercode))
code = template(innercode, **template_kwds)
logger.debug(name + " code:\n" + str(code))

specifiers.update(indices)
codeobj = codeobj_class(code, namespace, specifiers)
variables.update(indices)
codeobj = codeobj_class(code, namespace, variables)
codeobj.compile()
return codeobj

Expand Down Expand Up @@ -111,40 +114,43 @@ class CodeObject(object):
#: The `Language` used by this `CodeObject`
language = None

def __init__(self, code, namespace, specifiers):
def __init__(self, code, namespace, variables):
self.code = code
self.compile_methods = self.get_compile_methods(specifiers)
self.compile_methods = self.get_compile_methods(variables)
self.namespace = namespace
self.specifiers = specifiers
self.variables = variables

# Specifiers can refer to values that are either constant (e.g. dt)
# Variables can refer to values that are either constant (e.g. dt)
# or change every timestep (e.g. t). We add the values of the
# constant specifiers here and add the names of non-constant specifiers
# constant variables here and add the names of non-constant variables
# to a list

# A list containing tuples of name and a function giving the value
self.nonconstant_values = []

for name, spec in self.specifiers.iteritems():
if isinstance(spec, Value):
if isinstance(spec, AttributeValue):
self.nonconstant_values.append((name, spec.get_value))
if not spec.scalar:
for name, var in self.variables.iteritems():
if isinstance(var, Variable) and not isinstance(var, Subexpression):
if not var.constant:
self.nonconstant_values.append((name, var.get_value))
if not var.scalar:
self.nonconstant_values.append(('_num' + name,
spec.get_len))
elif not isinstance(spec, Subexpression):
value = spec.get_value()
var.get_len))
else:
try:
value = var.get_value()
except TypeError: # A dummy Variable without value
continue
self.namespace[name] = value
# if it is a type that has a length, add a variable called
# '_num'+name with its length
if not spec.scalar:
self.namespace['_num' + name] = spec.get_len()
if not var.scalar:
self.namespace['_num' + name] = var.get_len()

def get_compile_methods(self, specifiers):
def get_compile_methods(self, variables):
meths = []
for var, spec in specifiers.items():
if isinstance(spec, Function):
meths.append(functools.partial(spec.on_compile,
for var, var in variables.items():
if isinstance(var, Function):
meths.append(functools.partial(var.on_compile,
language=self.language,
var=var))
return meths
Expand Down
13 changes: 8 additions & 5 deletions brian2/codegen/languages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Base class for languages, gives the methods which should be overridden to
implement a new language.
'''
from brian2.core.specifiers import (ArrayVariable, Value, AttributeValue,
from brian2.core.variables import (ArrayVariable, AttributeVariable,
Subexpression)
from brian2.utils.stringtools import get_identifiers

Expand Down Expand Up @@ -34,7 +34,8 @@ def translate_statement(self, statement):
'''
raise NotImplementedError

def translate_statement_sequence(self, statements, specifiers, namespace, indices):
def translate_statement_sequence(self, statements, variables, namespace,
variable_indices, iterate_all):
'''
Translate a sequence of `Statement` into the target language, taking
care to declare variables, etc. if necessary.
Expand All @@ -45,7 +46,7 @@ def translate_statement_sequence(self, statements, specifiers, namespace, indice
'''
raise NotImplementedError

def array_read_write(self, statements, specifiers):
def array_read_write(self, statements, variables):
'''
Helper function, gives the set of ArrayVariables that are read from and
written to in the series of statements. Returns the pair read, write
Expand All @@ -60,6 +61,8 @@ def array_read_write(self, statements, specifiers):
ids.add(stmt.var)
read = read.union(ids)
write.add(stmt.var)
read = set(var for var, spec in specifiers.items() if isinstance(spec, ArrayVariable) and var in read)
write = set(var for var, spec in specifiers.items() if isinstance(spec, ArrayVariable) and var in write)
read = set(varname for varname, var in variables.items()
if isinstance(var, ArrayVariable) and varname in read)
write = set(varname for varname, var in variables.items()
if isinstance(var, ArrayVariable) and varname in write)
return read, write
69 changes: 34 additions & 35 deletions brian2/codegen/languages/cpp_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from brian2.utils.logger import get_logger
from brian2.parsing.rendering import CPPNodeRenderer
from brian2.core.preferences import brian_prefs, BrianPreference
from brian2.core.specifiers import ArrayVariable
from brian2.core.variables import ArrayVariable

from .base import Language

Expand Down Expand Up @@ -125,35 +125,38 @@ def translate_statement(self, statement):
decl = ''
return decl + var + ' ' + op + ' ' + self.translate_expression(expr) + ';'

def translate_statement_sequence(self, statements, specifiers, namespace, indices):
read, write = self.array_read_write(statements, specifiers)
def translate_statement_sequence(self, statements, variables, namespace,
variable_indices, iterate_all):

# Note that C++ code does not care about the iterate_all argument -- it
# always has to loop over the elements

read, write = self.array_read_write(statements, variables)
lines = []
# read arrays
for var in read:
index_var = specifiers[var].index
index_spec = indices[index_var]
spec = specifiers[var]
if var not in write:
for varname in read:
index_var = variable_indices[varname]
var = variables[varname]
if varname not in write:
line = 'const '
else:
line = ''
line = line + c_data_type(spec.dtype) + ' ' + var + ' = '
line = line + '_ptr' + spec.arrayname + '[' + index_var + '];'
line = line + c_data_type(var.dtype) + ' ' + varname + ' = '
line = line + '_ptr' + var.arrayname + '[' + index_var + '];'
lines.append(line)
# simply declare variables that will be written but not read
for var in write:
if var not in read:
spec = specifiers[var]
line = c_data_type(spec.dtype) + ' ' + var + ';'
for varname in write:
if varname not in read:
var = variables[varname]
line = c_data_type(var.dtype) + ' ' + varname + ';'
lines.append(line)
# the actual code
lines.extend([self.translate_statement(stmt) for stmt in statements])
# write arrays
for var in write:
index_var = specifiers[var].index
index_spec = indices[index_var]
spec = specifiers[var]
line = '_ptr' + spec.arrayname + '[' + index_var + '] = ' + var + ';'
for varname in write:
index_var = variable_indices[varname]
var = variables[varname]
line = '_ptr' + var.arrayname + '[' + index_var + '] = ' + varname + ';'
lines.append(line)
code = '\n'.join(lines)
# set up the restricted pointers, these are used so that the compiler
Expand All @@ -163,14 +166,11 @@ def translate_statement_sequence(self, statements, specifiers, namespace, indice
# same array. E.g. in gapjunction code, v_pre and v_post refer to the
# same array if a group is connected to itself
arraynames = set()
for var, spec in specifiers.iteritems():
if isinstance(spec, ArrayVariable):
arrayname = spec.arrayname
for varname, var in variables.iteritems():
if isinstance(var, ArrayVariable):
arrayname = var.arrayname
if not arrayname in arraynames:
if spec.dtype != spec.array.dtype:
print spec.array
raise AssertionError('Conflicting dtype information for %s: %s - %s' % (var, spec.dtype, spec.array.dtype))
line = c_data_type(spec.dtype) + ' * ' + self.restrict + '_ptr' + arrayname + ' = ' + arrayname + ';'
line = c_data_type(var.dtype) + ' * ' + self.restrict + '_ptr' + arrayname + ' = ' + arrayname + ';'
lines.append(line)
arraynames.add(arrayname)
pointers = '\n'.join(lines)
Expand All @@ -179,23 +179,22 @@ def translate_statement_sequence(self, statements, specifiers, namespace, indice
user_functions = []
support_code = ''
hash_defines = ''
for var, spec in itertools.chain(namespace.items(),
specifiers.items()):
if isinstance(spec, Function):
user_functions.append(var)
speccode = spec.code(self, var)
for varname, variable in namespace.items():
if isinstance(variable, Function):
user_functions.append(varname)
speccode = variable.code(self, varname)
support_code += '\n' + deindent(speccode['support_code'])
hash_defines += deindent(speccode['hashdefine_code'])
# add the Python function with a leading '_python', if it
# exists. This allows the function to make use of the Python
# function via weave if necessary (e.g. in the case of randn)
if not spec.pyfunc is None:
pyfunc_name = '_python_' + var
if pyfunc_name in namespace:
if not variable.pyfunc is None:
pyfunc_name = '_python_' + varname
if pyfunc_name in namespace:
logger.warn(('Namespace already contains function %s, '
'not replacing it') % pyfunc_name)
else:
namespace[pyfunc_name] = spec.pyfunc
namespace[pyfunc_name] = variable.pyfunc

# delete the user-defined functions from the namespace
for func in user_functions:
Expand Down
22 changes: 11 additions & 11 deletions brian2/codegen/languages/numpy_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,26 @@ def translate_statement(self, statement):
op = '='
return var + ' ' + op + ' ' + self.translate_expression(expr)

def translate_statement_sequence(self, statements, specifiers, namespace, indices):
read, write = self.array_read_write(statements, specifiers)
def translate_statement_sequence(self, statements, variables, namespace,
variable_indices, iterate_all):
read, write = self.array_read_write(statements, variables)
lines = []
# read arrays
for var in read:
spec = specifiers[var]
index_spec = indices[spec.index]
spec = variables[var]
index = variable_indices[var]
line = var + ' = ' + spec.arrayname
if not index_spec.iterate_all:
line = line + '[' + spec.index + ']'
if not index in iterate_all:
line = line + '[' + index + ']'
lines.append(line)
# the actual code
lines.extend([self.translate_statement(stmt) for stmt in statements])
# write arrays
for var in write:
index_var = specifiers[var].index
index_spec = indices[index_var]
index_var = variable_indices[var]
# check if all operations were inplace and we're operating on the
# whole vector, if so we don't need to write the array back
if not index_spec.iterate_all:
if not index_var in iterate_all:
all_inplace = False
else:
all_inplace = True
Expand All @@ -54,8 +54,8 @@ def translate_statement_sequence(self, statements, specifiers, namespace, indice
all_inplace = False
break
if not all_inplace:
line = specifiers[var].arrayname
if index_spec.iterate_all:
line = variables[var].arrayname
if index_var in iterate_all:
line = line + '[:]'
else:
line = line + '[' + index_var + ']'
Expand Down
5 changes: 3 additions & 2 deletions brian2/codegen/runtime/numpy_rt/numpy_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

__all__ = ['NumpyCodeObject']


class NumpyCodeObject(CodeObject):
'''
Execute code using Numpy
Expand All @@ -18,10 +19,10 @@ class NumpyCodeObject(CodeObject):
'templates'))
language = NumpyLanguage()

def __init__(self, code, namespace, specifiers):
def __init__(self, code, namespace, variables):
# TODO: This should maybe go somewhere else
namespace['logical_not'] = np.logical_not
CodeObject.__init__(self, code, namespace, specifiers)
CodeObject.__init__(self, code, namespace, variables)

def compile(self):
super(NumpyCodeObject, self).compile()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# USE_SPECIFIERS { _post_synaptic, _synaptic_pre, _synaptic_post }
# USES_VARIABLES { _post_synaptic, _synaptic_pre, _synaptic_post }
# ITERATE_ALL { _idx }

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion brian2/codegen/runtime/numpy_rt/templates/ratemonitor.py_
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# { USE_SPECIFIERS _rate, _t, _spikes, _num_source_neurons, t, dt }
# { USES_VARIABLES _rate, _t, _spikes, _num_source_neurons, t, dt }

_new_len = len(_t) + 1
_t.resize(_new_len)
Expand Down
6 changes: 3 additions & 3 deletions brian2/codegen/runtime/numpy_rt/templates/reset.py_
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# USE_SPECIFIERS { _spikes }
_neuron_idx = _spikes
_vectorisation_idx = _neuron_idx
# USES_VARIABLES { _spikes }
_idx = _spikes
_vectorisation_idx = _idx
{% for line in code_lines %}
{{line}}
{% endfor %}
Loading

0 comments on commit a41f90e

Please sign in to comment.