Skip to content

Commit

Permalink
Merge pull request #832 from brian-team/stateupdate_caching
Browse files Browse the repository at this point in the history
Caching and small run preparation improvements
  • Loading branch information
mstimberg committed Sep 8, 2017
2 parents 91c9ba5 + 404779d commit e1fa9e9
Show file tree
Hide file tree
Showing 22 changed files with 478 additions and 228 deletions.
5 changes: 3 additions & 2 deletions brian2/codegen/generators/base.py
Expand Up @@ -249,6 +249,7 @@ def translate(self, code, dtype):
'you are sure that the order of operations does not '
'matter. ' + error_msg))

return self.translate_statement_sequence(scalar_statements,
vector_statements)
translated = self.translate_statement_sequence(scalar_statements,
vector_statements)

return translated
10 changes: 0 additions & 10 deletions brian2/codegen/generators/cpp_generator.py
Expand Up @@ -390,16 +390,6 @@ def determine_keywords(self):
support_code.extend(sc)
pointers.extend(ps)
hash_defines.extend(hd)


# delete the user-defined functions from the namespace and add the
# function namespaces (if any)
for funcname, func in user_functions:
del self.variables[funcname]
func_namespace = func.implementations[self.codeobj_class].get_namespace(self.owner)
if func_namespace is not None:
self.variables.update(func_namespace)

support_code.append(self.universal_support_code)


Expand Down
8 changes: 0 additions & 8 deletions brian2/codegen/generators/cython_generator.py
Expand Up @@ -285,14 +285,6 @@ def determine_keywords(self):
# fallback to Python object
load_namespace.append('{0} = _namespace["{1}"]'.format(varname, varname))

# delete the user-defined functions from the namespace and add the
# function namespaces (if any)
for funcname, func in user_functions:
del self.variables[funcname]
func_namespace = func.implementations[self.codeobj_class].get_namespace(self.owner)
if func_namespace is not None:
self.variables.update(func_namespace)

return {'load_namespace': '\n'.join(load_namespace),
'support_code_lines': support_code}

Expand Down
7 changes: 0 additions & 7 deletions brian2/codegen/generators/numpy_generator.py
Expand Up @@ -273,13 +273,6 @@ def translate_one_statement_sequence(self, statements, scalar=False):
lines.extend(self.vectorise_code(statements, variables,
variable_indices))

# Make sure we do not use the __call__ function of Function objects but
# rather the Python function stored internally. The __call__ function
# would otherwise return values with units
for varname, var in variables.iteritems():
if isinstance(var, Function):
variables[varname] = var.implementations[self.codeobj_class].get_code(self.owner)

return lines

def determine_keywords(self):
Expand Down
12 changes: 12 additions & 0 deletions brian2/codegen/runtime/cython_rt/cython_rt.py
Expand Up @@ -5,6 +5,7 @@

from brian2.core.variables import (DynamicArrayVariable, ArrayVariable,
AuxiliaryVariable, Subexpression)
from brian2.core.functions import Function
from brian2.core.preferences import prefs, BrianPreference
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers
Expand Down Expand Up @@ -114,6 +115,15 @@ def run(self):

# the following are copied from WeaveCodeObject

def _insert_func_namespace(self, func):
impl = func.implementations[self]
func_namespace = impl.get_namespace(self.owner)
if func_namespace is not None:
self.namespace.update(func_namespace)
if impl.dependencies is not None:
for dep in impl.dependencies.itervalues():
self._insert_func_namespace(dep)

def variables_to_namespace(self):

# Variables can refer to values that are either constant (e.g. dt)
Expand All @@ -125,6 +135,8 @@ def variables_to_namespace(self):
self.nonconstant_values = []

for name, var in self.variables.iteritems():
if isinstance(var, Function):
self._insert_func_namespace(var)
if isinstance(var, (AuxiliaryVariable, Subexpression)):
continue
try:
Expand Down
9 changes: 7 additions & 2 deletions brian2/codegen/runtime/numpy_rt/numpy_rt.py
Expand Up @@ -10,6 +10,7 @@
from brian2.core.preferences import prefs, BrianPreference
from brian2.core.variables import (DynamicArrayVariable, ArrayVariable,
AuxiliaryVariable, Subexpression)
from brian2.core.functions import Function

from ...codeobject import CodeObject, constant_or_scalar

Expand Down Expand Up @@ -188,8 +189,12 @@ def variables_to_namespace(self):
raise TypeError()
value = var.get_value()
except TypeError:
# A dummy Variable without value or a function
self.namespace[name] = var
# Either a dummy Variable without a value or a Function object
if isinstance(var, Function):
impl = var.implementations[self.__class__].get_code(self.owner)
self.namespace[name] = impl
else:
self.namespace[name] = var
continue

if isinstance(var, ArrayVariable):
Expand Down
15 changes: 14 additions & 1 deletion brian2/codegen/runtime/weave_rt/weave_rt.py
Expand Up @@ -162,6 +162,15 @@ def is_available(cls):
'failed_compile_test')
return False

def _insert_func_namespace(self, func):
impl = func.implementations[self]
func_namespace = impl.get_namespace(self.owner)
if func_namespace is not None:
self.namespace.update(func_namespace)
if impl.dependencies is not None:
for dep in impl.dependencies.itervalues():
self._insert_func_namespace(dep)

def variables_to_namespace(self):

# Variables can refer to values that are either constant (e.g. dt)
Expand All @@ -173,7 +182,11 @@ def variables_to_namespace(self):
self.nonconstant_values = []

for name, var in self.variables.iteritems():
if isinstance(var, (AuxiliaryVariable, Subexpression, Function)):
if isinstance(var, Function):
self._insert_func_namespace(var)
continue # Everything else has already been dealt with in the
# CodeGenerator (support code, renaming, etc.)
elif isinstance(var, (AuxiliaryVariable, Subexpression)):
continue
try:
value = var.get_value()
Expand Down
9 changes: 5 additions & 4 deletions brian2/codegen/translation.py
Expand Up @@ -22,14 +22,13 @@

from brian2.core.preferences import prefs
from brian2.core.variables import Variable, Subexpression, AuxiliaryVariable
from brian2.utils.caching import cached
from brian2.core.functions import Function
from brian2.utils.stringtools import (deindent, strip_empty_lines,
get_identifiers)
from brian2.utils.topsort import topsort
from brian2.units.fundamentalunits import Unit, DIMENSIONLESS
from brian2.parsing.statements import parse_statement
from brian2.parsing.sympytools import (str_to_sympy, sympy_to_str,
check_expression_for_multiple_stateful_functions)
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str

from .statements import Statement
from .optimisation import optimise_statements
Expand Down Expand Up @@ -166,9 +165,11 @@ def is_scalar_expression(expr, variables):
(isinstance(variables[name], Function) and variables[name].stateless)
for name in identifiers)


@cached
def make_statements(code, variables, dtype, optimise=True, blockname=''):
'''
make_statements(code, variables, dtype, optimise=True, blockname='')
Turn a series of abstract code statements into Statement objects, inferring
whether each line is a set/declare operation, whether the variables are
constant or not, and handling the cacheing of subexpressions.
Expand Down
18 changes: 13 additions & 5 deletions brian2/core/variables.py
Expand Up @@ -6,14 +6,15 @@
import functools
import numbers

import sympy
import numpy as np
import sympy

from brian2.utils.stringtools import get_identifiers, word_substitute
from brian2.units.fundamentalunits import (Quantity, get_unit, DIMENSIONLESS,
fail_for_dimension_mismatch,
Dimension)
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers, word_substitute
from brian2.utils.caching import CacheKey

from .base import weakproxy_with_fallback, device_override
from .preferences import prefs
Expand Down Expand Up @@ -89,7 +90,7 @@ def variables_by_owner(variables, owner):
if getattr(var.owner, 'name', None) is owner_name])


class Variable(object):
class Variable(CacheKey):
'''
An object providing information about model variables (including implicit
variables such as ``t`` or ``xi``). This class should never be
Expand Down Expand Up @@ -127,6 +128,9 @@ class Variable(object):
Whether this variable is an array. Allows for simpler check than testing
``isinstance(var, ArrayVariable)``. Defaults to ``False``.
'''

_cache_irrelevant_attributes = {'owner'}

def __init__(self, name, dimensions=DIMENSIONLESS, owner=None, dtype=None,
scalar=False, constant=False, read_only=False, dynamic=False,
array=False):
Expand Down Expand Up @@ -275,6 +279,7 @@ def __repr__(self):
read_only=repr(self.read_only))



# ------------------------------------------------------------------------------
# Concrete classes derived from `Variable` -- these are the only ones ever
# instantiated.
Expand Down Expand Up @@ -514,6 +519,10 @@ class DynamicArrayVariable(ArrayVariable):
corresponding pre- and post-synaptic indices are not). Defaults to
``False``.
'''
# The size of a dynamic variable can of course change and changes in
# size should not invalidate the cache
cache_irrelevant_attributes = (ArrayVariable._cache_irrelevant_attributes |
{'size'})

def __init__(self, name, owner, size, device, dimensions=DIMENSIONLESS,
dtype=None, constant=False, needs_reference_update=False,
Expand Down Expand Up @@ -550,6 +559,7 @@ def __init__(self, name, owner, size, device, dimensions=DIMENSIONLESS,
read_only=read_only,
unique=unique)


@property
def dimensions(self):
logger.warn('The DynamicArrayVariable.dimensions attribute is '
Expand Down Expand Up @@ -653,7 +663,6 @@ def __repr__(self):
expr=repr(self.expr),
owner=self.owner.name)


# ------------------------------------------------------------------------------
# Classes providing views on variables and storing variables information
# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -1364,7 +1373,6 @@ def dtype(self):
return self.get_item(slice(None), level=1).dtype



class Variables(collections.Mapping):
'''
A container class for storing `Variable` objects. Instances of this class
Expand Down
30 changes: 18 additions & 12 deletions brian2/devices/cpp_standalone/device.py
Expand Up @@ -19,6 +19,7 @@
from brian2.codegen.cpp_prefs import get_compiler_and_args
from brian2.core.network import Network
from brian2.devices.device import Device, all_devices, set_device, reset_device
from brian2.core.functions import Function
from brian2.core.variables import *
from brian2.core.namespace import get_local_namespace
from brian2.groups.group import Group
Expand Down Expand Up @@ -785,18 +786,23 @@ def copy_source_files(self, writer, directory):
'random', 'randomkit', 'randomkit.h'),
os.path.join(directory, 'brianlib', 'randomkit', 'randomkit.h'))

def _insert_func_namespace(self, func, code_object, namespace):
impl = func.implementations[CPPStandaloneCodeObject]
func_namespace = impl.get_namespace(code_object.owner)
if func_namespace is not None:
namespace.update(func_namespace)
if impl.dependencies is not None:
for dep in impl.dependencies.itervalues():
self._insert_func_namespace(dep, code_object, namespace)

def write_static_arrays(self, directory):
# # Find np arrays in the namespaces and convert them into static
# # arrays. Hopefully they are correctly used in the code: For example,
# # this works for the namespaces for functions with C++ (e.g. TimedArray
# # treats it as a C array) but does not work in places that are
# # implicitly vectorized (state updaters, resets, etc.). But arrays
# # shouldn't be used there anyway.
# Write Function namespaces as static arrays
for code_object in self.code_objects.itervalues():
for name, value in code_object.variables.iteritems():
if isinstance(value, np.ndarray):
self.static_arrays[name] = value

for var in code_object.variables.itervalues():
if isinstance(var, Function):
self._insert_func_namespace(var, code_object,
self.static_arrays)

logger.diagnostic("static arrays: "+str(sorted(self.static_arrays.keys())))

static_array_specs = []
Expand Down Expand Up @@ -853,7 +859,7 @@ def compile_source(self, directory, compiler, debug, clean):
os.remove('winmake.log')
with std_silent(debug):
if clean:
os.system('%s >>winmake.log 2>&1 && %s clean >>winmake.log 2>&1' % (vcvars_cmd, make_cmd))
os.system('%s >>winmake.log 2>&1 && %s clean > NUL 2>&1' % (vcvars_cmd, make_cmd))
x = os.system('%s >>winmake.log 2>&1 && %s %s>>winmake.log 2>&1' % (vcvars_cmd,
make_cmd,
make_args))
Expand All @@ -870,7 +876,7 @@ def compile_source(self, directory, compiler, debug, clean):
else:
with std_silent(debug):
if clean:
os.system('make clean')
os.system('make clean >/dev/null 2>&1')
if debug:
x = os.system('make debug')
else:
Expand Down
4 changes: 1 addition & 3 deletions brian2/devices/device.py
Expand Up @@ -268,7 +268,7 @@ def code_object_class(self, codeobj_class=None, fallback_pref='codegen.target'):
def code_object(self, owner, name, abstract_code, variables, template_name,
variable_indices, codeobj_class=None,
template_kwds=None, override_conditional_write=None):

name = find_name(name)
codeobj_class = self.code_object_class(codeobj_class)
template = getattr(codeobj_class.templater, template_name)
iterate_all = template.iterate_all
Expand Down Expand Up @@ -320,8 +320,6 @@ def code_object(self, owner, name, abstract_code, variables, template_name,
logger.diagnostic('%s snippet (scalar):\n%s' % (name, indent(code_representation(scalar_code))))
logger.diagnostic('%s snippet (vector):\n%s' % (name, indent(code_representation(vector_code))))

name = find_name(name)

code = template(scalar_code, vector_code,
owner=owner, variables=variables, codeobj_name=name,
variable_indices=variable_indices,
Expand Down
21 changes: 18 additions & 3 deletions brian2/equations/codestrings.py
Expand Up @@ -3,6 +3,8 @@
information about its namespace. Only serves as a parent class, its subclasses
`Expression` and `Statements` are the ones that are actually used.
'''
import collections

import sympy

from brian2.utils.logger import get_logger
Expand All @@ -14,7 +16,7 @@
logger = get_logger(__name__)


class CodeString(object):
class CodeString(collections.Hashable):
'''
A class for representing "code strings", i.e. a single Python expression
or a sequence of Python statements.
Expand All @@ -29,18 +31,31 @@ class CodeString(object):

def __init__(self, code):

# : The code string
self.code = code
self._code = code

# : Set of identifiers in the code string
self.identifiers = get_identifiers(code)

code = property(lambda self: self._code,
doc='The code string')

def __str__(self):
return self.code

def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, self.code)

def __eq__(self, other):
if not isinstance(other, CodeString):
return NotImplemented
return self.code == other.code

def __ne__(self, other):
return not self == other

def __hash__(self):
return hash(self.code)


class Statements(CodeString):
'''
Expand Down

0 comments on commit e1fa9e9

Please sign in to comment.