Skip to content

Commit

Permalink
Initialize sensitivity parameter based on param_init
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 6, 2020
1 parent 3ed0b56 commit 75df433
Showing 1 changed file with 55 additions and 7 deletions.
62 changes: 55 additions & 7 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import abc
import numbers

import sympy
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros

from brian2.parsing.sympytools import sympy_to_str, str_to_sympy
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros
from brian2.utils.stringtools import get_identifiers

from brian2 import (NeuronGroup, defaultclock, get_device, Network,
StateMonitor, SpikeMonitor, second, get_local_namespace,
Quantity, get_logger)
from brian2.input import TimedArray
from brian2.equations.equations import Equations
from brian2.equations.equations import Equations, SUBEXPRESSION
from brian2.devices import set_device, reset_device, device
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
from brian2.core.functions import Function
Expand Down Expand Up @@ -78,7 +82,6 @@ def setup_fit():


def get_sensitivity_equations(group, parameters, namespace=None, level=1):
from sympy import Matrix
if namespace is None:
namespace = get_local_namespace(level)
namespace.update(group.namespace)
Expand All @@ -87,8 +90,8 @@ def get_sensitivity_equations(group, parameters, namespace=None, level=1):
diff_eqs = eqs.get_substituted_expressions(group.variables)
diff_eq_names = [name for name, _ in diff_eqs]

system = Matrix([str_to_sympy(diff_eq[1].code)
for diff_eq in diff_eqs]).as_mutable()
system = sympy.Matrix([str_to_sympy(diff_eq[1].code)
for diff_eq in diff_eqs])
J = system.jacobian([str_to_sympy(d) for d in diff_eq_names])

sensitivity = []
Expand All @@ -97,7 +100,7 @@ def get_sensitivity_equations(group, parameters, namespace=None, level=1):
F = system.jacobian([str_to_sympy(parameter)])
names = [str_to_sympy(f'S_{diff_eq_name}_{parameter}')
for diff_eq_name in diff_eq_names]
sensitivity.append(J * Matrix(names) + F)
sensitivity.append(J * sympy.Matrix(names) + F)
sensitivity_names.append(names)

new_eqs = []
Expand All @@ -117,6 +120,47 @@ def get_sensitivity_equations(group, parameters, namespace=None, level=1):
return new_eqs


def get_sensitivity_init(group, parameters, param_init):
'''
Calculate the initial values for the sensitivity parameters (necessary if
initial values are functions of parameters).
Parameters
----------
group : `NeuronGroup`
The group of neurons that will be simulated.
parameters : list of str
Names of the parameters that are fit.
param_init : dict
The dictionary with expressions to initialize the model variables.
Returns
-------
sensitivity_init : dict
Dictionary of expressions to initialize the sensitivity
parameters.
'''
sensitivity_dict = {}
for var_name, expr in param_init.items():
if not isinstance(expr, str):
continue
identifiers = get_identifiers(expr)
for identifier in identifiers:
if (identifier in group.variables
and getattr(group.variables[identifier],
'type', None) == SUBEXPRESSION):
raise NotImplementedError('Initializations that refer to a '
'subexpression are currently not '
'supported')
sympy_expr = str_to_sympy(expr)
for parameter in parameters:
diffed = sympy_expr.diff(str_to_sympy(parameter))
if diffed != sympy.S.Zero:
init_expr = sympy_to_str(diffed)
sensitivity_dict[f'S_{var_name}_{parameter}'] = init_expr
return sensitivity_dict


class Fitter(metaclass=abc.ABCMeta):
"""
Base Fitter class for model fitting applications.
Expand Down Expand Up @@ -670,14 +714,18 @@ def refine(self, params=None, t_start=None, normalization=None,
calc_gradient=calc_gradient,
name='neurons')
monitored_variables = [self.output_var]
param_init = dict(self.param_init)
if calc_gradient:
monitored_variables += [f'S_{self.output_var}_{p}'
for p in self.parameter_names]
param_init.update(get_sensitivity_init(neurons,
self.parameter_names,
param_init))
monitor = StateMonitor(neurons, monitored_variables, record=True,
name='monitor')
network = Network(neurons, monitor)

simulator.initialize(network, self.param_init, name='refine')
simulator.initialize(network, param_init, name='refine')

t_start_steps = int(round(t_start / self.dt))

Expand Down

0 comments on commit 75df433

Please sign in to comment.