Skip to content

Commit

Permalink
Reduce the time needed to get the components of a linear system, in p…
Browse files Browse the repository at this point in the history
…articular fail much faster for a non-linear system (important when a state updater is determined automatically).
  • Loading branch information
Marcel Stimberg committed Oct 11, 2013
1 parent c1164b5 commit e48fb30
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions brian2/stateupdaters/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

logger = get_logger(__name__)


def get_linear_system(eqs):
'''
Convert equations into a linear system using sympy.
Expand All @@ -42,30 +41,30 @@ def get_linear_system(eqs):
diff_eq_names = eqs.diff_eq_names

symbols = [Symbol(name, real=True) for name in diff_eq_names]
# Coefficients
wildcards = [Wild('c_' + name, exclude=symbols) for name in diff_eq_names]

#Additive constant
constant_wildcard = Wild('c', exclude=symbols)

pattern = reduce(operator.add, [c * s for c, s in zip(wildcards, symbols)])
pattern += constant_wildcard

coefficients = sp.zeros(len(diff_eq_names))
constants = sp.zeros((len(diff_eq_names), 1))
constants = sp.zeros(len(diff_eq_names), 1)

for row_idx, (name, expr) in enumerate(diff_eqs):
s_expr = expr.sympy_expr.expand()
pattern_matches = s_expr.match(pattern)
if pattern_matches is None:
raise ValueError(('The expression "%s", defining the variable %s, '
'could not be separated into linear components') %
(expr, name))

for col_idx in xrange(len(diff_eq_names)):
coefficients[row_idx, col_idx] = pattern_matches[wildcards[col_idx]]

constants[row_idx] = pattern_matches[constant_wildcard]

current_s_expr = s_expr
for col_idx, (name, symbol) in enumerate(zip(eqs.diff_eq_names, symbols)):
current_s_expr = current_s_expr.collect(symbol)
constant_wildcard = Wild('c', exclude=[symbol])
factor_wildcard = Wild('c_'+name, exclude=symbols)
one_pattern = factor_wildcard*symbol + constant_wildcard
matches = current_s_expr.match(one_pattern)
if matches is None:
raise ValueError(('The expression "%s", defining the variable %s, '
'could not be separated into linear components') %
(expr, name))

coefficients[row_idx, col_idx] = matches[factor_wildcard]
current_s_expr = matches[constant_wildcard]

# The remaining constant should be a true constant
constants[row_idx] = current_s_expr

return (diff_eq_names, coefficients, constants)

Expand Down Expand Up @@ -209,7 +208,7 @@ def can_integrate(self, equations, variables):

# It worked
return True

def __call__(self, equations, variables=None):

if variables is None:
Expand Down Expand Up @@ -250,7 +249,9 @@ def __call__(self, equations, variables=None):
# replace them with the state variable names
abstract_code = []
for idx, (variable, update) in enumerate(zip(varnames, updates)):
rhs = update.subs(_S[idx, 0], variable)
rhs = update
for row_idx, varname in enumerate(varnames):
rhs = rhs.subs(_S[row_idx, 0], varname)
identifiers = get_identifiers(sympy_to_str(rhs))
for identifier in identifiers:
if identifier in variables:
Expand Down

0 comments on commit e48fb30

Please sign in to comment.