Skip to content

Commit

Permalink
Merge 20706ac into f22955b
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Jul 13, 2021
2 parents f22955b + 20706ac commit c567640
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 27 deletions.
65 changes: 39 additions & 26 deletions brian2/stateupdaters/exact.py
Expand Up @@ -40,22 +40,25 @@ def get_linear_system(eqs, variables):
ValueError
If the equations cannot be converted into an M * X + B form.
'''
diff_eqs = eqs.get_substituted_expressions(variables)
diff_eq_names = [name for name, _ in diff_eqs]
diff_eqs = {name: str_to_sympy(expr.code, variables).expand()
for name, expr in eqs.get_substituted_expressions(variables)}

symbols = [Symbol(name, real=True) for name in diff_eq_names]
# Sometimes, in particular in testing, variables defined as differential
# equations are actually constant (e.g. `dv/dt = 0/second`). We ignore
# them here
symbols = [Symbol(name, real=True) for name, expr in diff_eqs.items()
if expr != 0]

coefficients = sp.zeros(len(diff_eq_names))
constants = sp.zeros(len(diff_eq_names), 1)

for row_idx, (name, expr) in enumerate(diff_eqs):
s_expr = str_to_sympy(expr.code, variables).expand()
coefficients = sp.zeros(len(symbols))
constants = sp.zeros(len(symbols), 1)

for row_idx, symbol in enumerate(symbols):
s_expr = diff_eqs[symbol.name]
current_s_expr = s_expr
for col_idx, symbol in enumerate(symbols):
current_s_expr = current_s_expr.collect(symbol)
constant_wildcard = Wild('c', exclude=[symbol])
factor_wildcard = Wild('c_'+name, exclude=symbols)
factor_wildcard = Wild('c_'+symbol.name, exclude=symbols)
one_pattern = factor_wildcard*symbol + constant_wildcard
matches = current_s_expr.match(one_pattern)
if matches is None:
Expand All @@ -64,15 +67,16 @@ def get_linear_system(eqs, variables):
'%s, could not be '
'separated into linear '
'components.') %
(expr, name))
(sympy_to_str(s_expr),
symbol.name))

coefficients[row_idx, col_idx] = matches[factor_wildcard]
current_s_expr = matches[constant_wildcard]

# The remaining constant should be a true constant
constants[row_idx] = current_s_expr

return (diff_eq_names, coefficients, constants)
return [s.name for s in symbols], coefficients, constants


class IndependentStateUpdater(StateUpdateMethod):
Expand Down Expand Up @@ -191,40 +195,49 @@ def __call__(self, equations, variables=None, method_options=None):
('Expression "{}" is not guaranteed to be constant over a '
'time step').format(sympy_to_str(entry)))

symbols = [Symbol(variable, real=True) for variable in varnames]
solution = sp.solve_linear_system(matrix.row_join(constants), *symbols)
if solution is None or set(symbols) != set(solution.keys()):
raise UnsupportedEquationsException('Cannot solve the given '
'equations with this '
'stateupdater.')
b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols])

# Solve the system
dt = Symbol('dt', real=True, positive=True)
# Add the constant terms as new variables
const_vars = []
const_terms = []
for idx, (varname, const_term) in enumerate(zip(varnames, constants)):
if const_term != 0:
matrix = matrix.col_insert(matrix.cols, sp.Matrix([1 if i == idx else 0
for i in range(matrix.rows)]))
matrix = matrix.row_insert(matrix.rows, sp.zeros(1, matrix.cols))
const_vars.append('_const_term_' + varname)
const_terms.append(const_term)

try:
A = (matrix * dt).exp()
except NotImplementedError:
raise UnsupportedEquationsException('Cannot solve the given '
'equations with this '
'stateupdater.')

if method_options['simplify']:
A = A.applyfunc(lambda x:
sp.factor_terms(sp.cancel(sp.signsimp(x))))
C = sp.ImmutableMatrix(A * b) - b
_S = sp.MatrixSymbol('_S', len(varnames), 1)
updates = A * _S + C

_S = sp.MatrixSymbol('_S', len(varnames) + len(const_vars), 1)
updates = A * _S
updates = updates.as_explicit()
abstract_code = []

# Add code for the constant terms:
for const_var, const_term in zip(const_vars, const_terms):
abstract_code.append(const_var + ' = ' + sympy_to_str(const_term))

# The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
# replace them with the state variable names
abstract_code = []
for idx, (variable, update) in enumerate(zip(varnames, updates)):
# replace them with the state variable names
for variable, update in zip(varnames, updates[:len(varnames)]):
rhs = update
if rhs.has(I, re, im):
raise UnsupportedEquationsException('The solution to the linear system '
'contains complex values '
'which is currently not implemented.')
for row_idx, varname in enumerate(varnames):

for row_idx, varname in enumerate(itertools.chain(varnames, const_vars)):
rhs = rhs.subs(_S[row_idx, 0], varname)

# Do not overwrite the real state variables yet, the update step
Expand Down
5 changes: 4 additions & 1 deletion brian2/tests/test_synapses.py
Expand Up @@ -1585,6 +1585,8 @@ def test_event_driven_dependency_error():

@pytest.mark.codegen_independent
def test_event_driven_dependency_error2():
pytest.xfail("This will be fixed with the rewrite of the equation "
"dependency check.")
stim = SpikeGeneratorGroup(1, [0], [0]*ms, period=5*ms)
tau = 5*ms
syn = Synapses(stim, stim, '''
Expand Down Expand Up @@ -1962,7 +1964,8 @@ def test_vectorisation_STDP_like():
neurons = NeuronGroup(6, '''dv/dt = rate : 1
ge : 1
rate : Hz
dA/dt = -A/(1*ms) : 1''', threshold='v>1', reset='v=0')
dA/dt = -A/(1*ms) : 1''', threshold='v>1',
reset='v=0', method='euler')
# Note that the synapse does not actually increase the target v, we want
# to have simple control about when neurons spike. Also, we separate the
# "depression" and "facilitation" completely. The example also uses
Expand Down

0 comments on commit c567640

Please sign in to comment.