Skip to content

Commit

Permalink
Use the word "variable" instead of "specifier" everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Stimberg committed Aug 9, 2013
1 parent 851c302 commit 31cb60c
Show file tree
Hide file tree
Showing 32 changed files with 914 additions and 927 deletions.
64 changes: 33 additions & 31 deletions brian2/codegen/codeobject.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools

from brian2.core.specifiers import (ArrayVariable, Variable,
from brian2.core.variables import (ArrayVariable, Variable,
AttributeVariable, Subexpression,
StochasticVariable)
from .functions.base import Function
Expand Down Expand Up @@ -31,19 +31,19 @@ def get_default_codeobject_class():
return codeobj_class


def prepare_namespace(namespace, specifiers):
def prepare_namespace(namespace, variables):
namespace = dict(namespace)
# Add variables referring to the arrays
arrays = []
for value in specifiers.itervalues():
for value in variables.itervalues():
if isinstance(value, ArrayVariable):
arrays.append((value.arrayname, value.get_value()))
namespace.update(arrays)

return namespace


def create_codeobject(name, abstract_code, namespace, specifiers, template_name,
def create_codeobject(name, abstract_code, namespace, variables, template_name,
indices, variable_indices, iterate_all,
codeobj_class=None,
template_kwds=None):
Expand All @@ -69,10 +69,10 @@ def create_codeobject(name, abstract_code, namespace, specifiers, template_name,
template = get_codeobject_template(template_name,
codeobj_class=codeobj_class)

namespace = prepare_namespace(namespace, specifiers)
namespace = prepare_namespace(namespace, variables)

logger.debug(name + " abstract code:\n" + abstract_code)
innercode, kwds = translate(abstract_code, specifiers, namespace,
innercode, kwds = translate(abstract_code, variables, namespace,
dtype=brian_prefs['core.default_scalar_dtype'],
language=codeobj_class.language,
variable_indices=variable_indices,
Expand All @@ -82,8 +82,8 @@ def create_codeobject(name, abstract_code, namespace, specifiers, template_name,
code = template(innercode, **template_kwds)
logger.debug(name + " code:\n" + str(code))

specifiers.update(indices)
codeobj = codeobj_class(code, namespace, specifiers)
variables.update(indices)
codeobj = codeobj_class(code, namespace, variables)
codeobj.compile()
return codeobj

Expand Down Expand Up @@ -114,40 +114,42 @@ class CodeObject(object):
#: The `Language` used by this `CodeObject`
language = None

def __init__(self, code, namespace, specifiers):
def __init__(self, code, namespace, variables):
self.code = code
self.compile_methods = self.get_compile_methods(specifiers)
self.compile_methods = self.get_compile_methods(variables)
self.namespace = namespace
self.specifiers = specifiers
self.variables = variables

# Specifiers can refer to values that are either constant (e.g. dt)
# or change every timestep (e.g. t). We add the values of the
# constant specifiers here and add the names of non-constant specifiers
# constant variables here and add the names of non-constant variables
# to a list

# A list containing tuples of name and a function giving the value
self.nonconstant_values = []

for name, spec in self.specifiers.iteritems():
if isinstance(spec, Variable):
if not spec.constant:
self.nonconstant_values.append((name, spec.get_value))
if not spec.scalar:
self.nonconstant_values.append(('_num' + name,
spec.get_len))
else:
value = spec.get_value()
self.namespace[name] = value
# if it is a type that has a length, add a variable called
# '_num'+name with its length
if not spec.scalar:
self.namespace['_num' + name] = spec.get_len()

def get_compile_methods(self, specifiers):
for name, var in self.variables.iteritems():
if not var.constant:
self.nonconstant_values.append((name, var.get_value))
if not var.scalar:
self.nonconstant_values.append(('_num' + name,
var.get_len))
else:
try:
value = var.get_value()
except TypeError: # A dummy Variable without unit
continue
self.namespace[name] = value
# if it is a type that has a length, add a variable called
# '_num'+name with its length
if not var.scalar:
self.namespace['_num' + name] = var.get_len()

def get_compile_methods(self, variables):
meths = []
for var, spec in specifiers.items():
if isinstance(spec, Function):
meths.append(functools.partial(spec.on_compile,
for var, var in variables.items():
if isinstance(var, Function):
meths.append(functools.partial(var.on_compile,
language=self.language,
var=var))
return meths
Expand Down
12 changes: 7 additions & 5 deletions brian2/codegen/languages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Base class for languages, gives the methods which should be overridden to
implement a new language.
'''
from brian2.core.specifiers import (ArrayVariable, AttributeVariable,
from brian2.core.variables import (ArrayVariable, AttributeVariable,
Subexpression)
from brian2.utils.stringtools import get_identifiers

Expand Down Expand Up @@ -34,7 +34,7 @@ def translate_statement(self, statement):
'''
raise NotImplementedError

def translate_statement_sequence(self, statements, specifiers, namespace, indices):
def translate_statement_sequence(self, statements, variables, namespace, indices):
'''
Translate a sequence of `Statement` into the target language, taking
care to declare variables, etc. if necessary.
Expand All @@ -45,7 +45,7 @@ def translate_statement_sequence(self, statements, specifiers, namespace, indice
'''
raise NotImplementedError

def array_read_write(self, statements, specifiers):
def array_read_write(self, statements, variables):
'''
Helper function, gives the set of ArrayVariables that are read from and
written to in the series of statements. Returns the pair read, write
Expand All @@ -60,6 +60,8 @@ def array_read_write(self, statements, specifiers):
ids.add(stmt.var)
read = read.union(ids)
write.add(stmt.var)
read = set(var for var, spec in specifiers.items() if isinstance(spec, ArrayVariable) and var in read)
write = set(var for var, spec in specifiers.items() if isinstance(spec, ArrayVariable) and var in write)
read = set(varname for varname, var in variables.items()
if isinstance(var, ArrayVariable) and varname in read)
write = set(varname for varname, var in variables.items()
if isinstance(var, ArrayVariable) and varname in write)
return read, write
59 changes: 29 additions & 30 deletions brian2/codegen/languages/cpp_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from brian2.utils.logger import get_logger
from brian2.parsing.rendering import CPPNodeRenderer
from brian2.core.preferences import brian_prefs, BrianPreference
from brian2.core.specifiers import ArrayVariable
from brian2.core.variables import ArrayVariable

from .base import Language

Expand Down Expand Up @@ -125,38 +125,38 @@ def translate_statement(self, statement):
decl = ''
return decl + var + ' ' + op + ' ' + self.translate_expression(expr) + ';'

def translate_statement_sequence(self, statements, specifiers, namespace,
def translate_statement_sequence(self, statements, variables, namespace,
variable_indices, iterate_all):

# Note that C++ code does not care about the iterate_all argument -- it
# always has to loop over the elements

read, write = self.array_read_write(statements, specifiers)
read, write = self.array_read_write(statements, variables)
lines = []
# read arrays
for var in read:
index_var = variable_indices[specifiers[var]] + '_idx'
spec = specifiers[var]
if var not in write:
for varname in read:
index_var = variable_indices[variables[varname]] + '_idx'
var = variables[varname]
if varname not in write:
line = 'const '
else:
line = ''
line = line + c_data_type(spec.dtype) + ' ' + var + ' = '
line = line + '_ptr' + spec.arrayname + '[' + index_var + '];'
line = line + c_data_type(var.dtype) + ' ' + varname + ' = '
line = line + '_ptr' + var.arrayname + '[' + index_var + '];'
lines.append(line)
# simply declare variables that will be written but not read
for var in write:
if var not in read:
spec = specifiers[var]
line = c_data_type(spec.dtype) + ' ' + var + ';'
for varname in write:
if varname not in read:
var = variables[varname]
line = c_data_type(var.dtype) + ' ' + varname + ';'
lines.append(line)
# the actual code
lines.extend([self.translate_statement(stmt) for stmt in statements])
# write arrays
for var in write:
index_var = variable_indices[specifiers[var]] + '_idx'
spec = specifiers[var]
line = '_ptr' + spec.arrayname + '[' + index_var + '] = ' + var + ';'
for varname in write:
index_var = variable_indices[variables[varname]] + '_idx'
var = variables[varname]
line = '_ptr' + var.arrayname + '[' + index_var + '] = ' + varname + ';'
lines.append(line)
code = '\n'.join(lines)
# set up the restricted pointers, these are used so that the compiler
Expand All @@ -166,11 +166,11 @@ def translate_statement_sequence(self, statements, specifiers, namespace,
# same array. E.g. in gapjunction code, v_pre and v_post refer to the
# same array if a group is connected to itself
arraynames = set()
for var, spec in specifiers.iteritems():
if isinstance(spec, ArrayVariable):
arrayname = spec.arrayname
for varname, var in variables.iteritems():
if isinstance(var, ArrayVariable):
arrayname = var.arrayname
if not arrayname in arraynames:
line = c_data_type(spec.dtype) + ' * ' + self.restrict + '_ptr' + arrayname + ' = ' + arrayname + ';'
line = c_data_type(var.dtype) + ' * ' + self.restrict + '_ptr' + arrayname + ' = ' + arrayname + ';'
lines.append(line)
arraynames.add(arrayname)
pointers = '\n'.join(lines)
Expand All @@ -179,23 +179,22 @@ def translate_statement_sequence(self, statements, specifiers, namespace,
user_functions = []
support_code = ''
hash_defines = ''
for var, spec in itertools.chain(namespace.items(),
specifiers.items()):
if isinstance(spec, Function):
user_functions.append(var)
speccode = spec.code(self, var)
for varname, variable in namespace.items():
if isinstance(variable, Function):
user_functions.append(varname)
speccode = variable.code(self, varname)
support_code += '\n' + deindent(speccode['support_code'])
hash_defines += deindent(speccode['hashdefine_code'])
# add the Python function with a leading '_python', if it
# exists. This allows the function to make use of the Python
# function via weave if necessary (e.g. in the case of randn)
if not spec.pyfunc is None:
pyfunc_name = '_python_' + var
if pyfunc_name in namespace:
if not variable.pyfunc is None:
pyfunc_name = '_python_' + varname
if pyfunc_name in namespace:
logger.warn(('Namespace already contains function %s, '
'not replacing it') % pyfunc_name)
else:
namespace[pyfunc_name] = spec.pyfunc
namespace[pyfunc_name] = variable.pyfunc

# delete the user-defined functions from the namespace
for func in user_functions:
Expand Down
10 changes: 5 additions & 5 deletions brian2/codegen/languages/numpy_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def translate_statement(self, statement):
op = '='
return var + ' ' + op + ' ' + self.translate_expression(expr)

def translate_statement_sequence(self, statements, specifiers, namespace,
def translate_statement_sequence(self, statements, variables, namespace,
variable_indices, iterate_all):
read, write = self.array_read_write(statements, specifiers)
read, write = self.array_read_write(statements, variables)
lines = []
# read arrays
for var in read:
spec = specifiers[var]
spec = variables[var]
index = variable_indices[spec]
line = var + ' = ' + spec.arrayname
if not index in iterate_all:
Expand All @@ -42,7 +42,7 @@ def translate_statement_sequence(self, statements, specifiers, namespace,
lines.extend([self.translate_statement(stmt) for stmt in statements])
# write arrays
for var in write:
index_var = variable_indices[specifiers[var]]
index_var = variable_indices[variables[var]]
# check if all operations were inplace and we're operating on the
# whole vector, if so we don't need to write the array back
if not index_var in iterate_all:
Expand All @@ -54,7 +54,7 @@ def translate_statement_sequence(self, statements, specifiers, namespace,
all_inplace = False
break
if not all_inplace:
line = specifiers[var].arrayname
line = variables[var].arrayname
if index_var in iterate_all:
line = line + '[:]'
else:
Expand Down
4 changes: 2 additions & 2 deletions brian2/codegen/runtime/numpy_rt/numpy_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class NumpyCodeObject(CodeObject):
'templates'))
language = NumpyLanguage()

def __init__(self, code, namespace, specifiers):
def __init__(self, code, namespace, variables):
# TODO: This should maybe go somewhere else
namespace['logical_not'] = np.logical_not
CodeObject.__init__(self, code, namespace, specifiers)
CodeObject.__init__(self, code, namespace, variables)

def compile(self):
super(NumpyCodeObject, self).compile()
Expand Down
4 changes: 2 additions & 2 deletions brian2/codegen/runtime/weave_rt/weave_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class WeaveCodeObject(CodeObject):
'templates'))
language = CPPLanguage()

def __init__(self, code, namespace, specifiers):
super(WeaveCodeObject, self).__init__(code, namespace, specifiers)
def __init__(self, code, namespace, variables):
super(WeaveCodeObject, self).__init__(code, namespace, variables)
self.compiler = brian_prefs['codegen.runtime.weave.compiler']
self.extra_compile_args = brian_prefs['codegen.runtime.weave.extra_compile_args']

Expand Down
6 changes: 3 additions & 3 deletions brian2/codegen/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def __init__(self, template):
self.words = set([])
for v in temps:
self.words.update(get_identifiers(v))
#: The set of specifiers in this template
self.specifiers = set([])
#: The set of variables in this template
self.variables = set([])
for v in temps:
# This is the bit inside {} for USE_SPECIFIERS { list of words }
specifier_blocks = re.findall(r'\bUSE_SPECIFIERS\b\s*\{(.*?)\}',
v, re.M|re.S)
for block in specifier_blocks:
self.specifiers.update(get_identifiers(block))
self.variables.update(get_identifiers(block))

def __call__(self, code_lines, **kwds):
kwds['code_lines'] = code_lines
Expand Down
Loading

0 comments on commit 31cb60c

Please sign in to comment.