Skip to content

Commit

Permalink
Merge a0d0f24 into 1660f35
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Nov 3, 2020
2 parents 1660f35 + a0d0f24 commit c0235e5
Show file tree
Hide file tree
Showing 9 changed files with 476 additions and 83 deletions.
109 changes: 103 additions & 6 deletions brian2/equations/codestrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
information about its namespace. Only serves as a parent class, its subclasses
`Expression` and `Statements` are the ones that are actually used.
'''

try:
from collections.abc import Hashable
except ImportError:
from collections import Hashable
import re
import string
from collections.abc import Hashable
from typing import Sequence
import numbers

import sympy
import numpy as np

from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers
from brian2.utils.topsort import topsort
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str

__all__ = ['Expression', 'Statements']
Expand All @@ -38,6 +40,7 @@ def __init__(self, code):

# : Set of identifiers in the code string
self.identifiers = get_identifiers(code)
self.template_identifiers = get_identifiers(code, template=True)

code = property(lambda self: self._code,
doc='The code string')
Expand Down Expand Up @@ -80,6 +83,11 @@ class Statements(CodeString):
pass


class Default(dict):
def __missing__(self, key):
return f'{{{key}}}'


class Expression(CodeString):
'''
Class for representing an expression.
Expand All @@ -103,7 +111,7 @@ def __init__(self, code=None, sympy_expression=None):

if code is None:
code = sympy_to_str(sympy_expression)
else:
elif '{' not in code:
# Just try to convert it to a sympy expression to get syntax errors
# for incorrect expressions
str_to_sympy(code)
Expand Down Expand Up @@ -196,6 +204,95 @@ def __ne__(self, other):
def __hash__(self):
return hash(self.code)

def _do_substitution(self, to_replace, replacement):
# Replacements can be lists, deal with single replacements
# as single-element lists
replaced_name = False
replaced_placeholder = False
if not isinstance(replacement, Sequence) or isinstance(replacement, str):
replacement = [replacement]
replacement_strs = []
for one_replacement in replacement:
if isinstance(one_replacement, str):
if any(c not in string.ascii_letters + '_{}'
for c in one_replacement):
# Check whether the replacement can be interpreted as an expression
try:
expr = Expression(one_replacement)
replacement_strs.append(expr.code)
except SyntaxError:
raise SyntaxError(f'Replacement \'{one_replacement}\' for'
f'\'{to_replace}\' is neither a name nor a '
f'valid expression.')
else:
replacement_strs.append(one_replacement)
elif isinstance(one_replacement, (numbers.Number, np.ndarray)):
if not getattr(one_replacement, 'shape', ()) == ():
raise TypeError(f'Cannot replace variable \'{to_replace}\' with an '
f'array of values.')
replacement_strs.append(repr(one_replacement))
elif isinstance(one_replacement, Expression):
replacement_strs.append(one_replacement.code)
else:
raise TypeError(f'Cannot replace \'{to_replace}\' with an object of type '
f'\'{type(one_replacement)}\'.')

if len(replacement_strs) == 1:
replacement_str = replacement_strs[0]
# Be careful if the string is more than just a name/number
if any(c not in string.ascii_letters + string.digits + '_.{}'
for c in replacement_str):
replacement_str = '(' + replacement_str + ')'
else:
replacement_str = '(' + (' + '.join(replacement_strs)) + ')'

new_expr = self
if to_replace in new_expr.identifiers:
code = new_expr.code
new_expr = Expression(re.sub(r'(?<!\w|{)' + to_replace + r'(?!\w|})',
replacement_str, code))
replaced_name = True
if to_replace in new_expr.template_identifiers:
code = new_expr.code
new_expr = Expression(code.replace('{' + to_replace + '}',
replacement_str))
replaced_placeholder = True
if not (replaced_name or replaced_placeholder):
raise KeyError(f'Replacement argument \'{to_replace}\' does not correspond '
f'to any name or placeholder in the equations.')
if replaced_name and replaced_placeholder:
logger.warn(f'Replacement argument \'{to_replace}\' replaced both a name '
f'and a placeholder \'{{{to_replace}}}\'.',
name_suffix='ambiguous_replacement')
return new_expr

def __call__(self, **replacements):
if len(replacements) == 0:
return self

# Figure out in which order elements should be substituted
dependencies = {}
for to_replace, replacement in replacements.items():
if not isinstance(replacement, Sequence) or isinstance(replacement, str):
replacement = [replacement]
for one_replacement in replacement:
dependencies[to_replace] = set()
if not isinstance(one_replacement, (numbers.Number, np.ndarray, str, Expression)):
raise TypeError(f'Cannot use an object of type \'{type(one_replacement)}\''
f'to replace \'{to_replace}\' in an expression.')
if isinstance(one_replacement, Expression):
dependencies[to_replace] |= one_replacement.identifiers | one_replacement.template_identifiers
# We only care about dependencies to values that are replaced at the same time
for dep_key, deps in dependencies.items():
dependencies[dep_key] = {d for d in deps if d in dependencies}

replacements_in_order = topsort(dependencies)[::-1]
expr = self
for to_replace in replacements_in_order:
replacement = replacements[to_replace]
expr = expr._do_substitution(to_replace, replacement)
return expr


def is_constant_over_dt(expression, variables, dt_value):
'''
Expand Down
Loading

0 comments on commit c0235e5

Please sign in to comment.