Merge pull request #65
Fix handling of multiple variables with differing t_start/t_weights in refine
mstimberg committed Sep 14, 2021
2 parents 650f904 + d2fcf67 commit 49eb822
Showing 2 changed files with 123 additions and 149 deletions.
136 changes: 64 additions & 72 deletions brian2modelfitting/fitter.py
@@ -4,7 +4,9 @@
from typing import Sequence, Mapping

import sympy
from numpy import ones, array, arange, concatenate, mean, argmin, nanargmin, reshape, zeros, sqrt, ndarray, broadcast_to, sum
from numpy import (ones, array, arange, concatenate, mean, argmin, nanargmin,
reshape, zeros, sqrt, ndarray, broadcast_to, sum, cumsum,
hstack)

from brian2.parsing.sympytools import sympy_to_str, str_to_sympy
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions, fail_for_dimension_mismatch
@@ -15,7 +17,7 @@
Quantity, get_logger, Expression, ms)
from brian2.input import TimedArray
from brian2.equations.equations import Equations, SUBEXPRESSION, SingleEquation
from brian2.devices import set_device, reset_device, device
from brian2.devices import device
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
from brian2.core.functions import Function
from .simulator import RuntimeSimulator, CPPStandaloneSimulator
@@ -1160,7 +1162,7 @@ def generate_traces(self, params=None, param_init=None, iteration=1e9,
level=level+1)
return fits

def refine(self, params=None, t_start=None, t_weights=None, normalization=None,
def refine(self, params=None, metric=None,
callback='text', calc_gradient=False, optimize=True,
iteration=1e9, level=0, **kwds):
"""
@@ -1177,28 +1179,14 @@ def refine(self, params=None, t_start=None, t_weights=None, normalization=None,
A dictionary with the parameters to use as a starting point for the
refinement. If not given, the best parameters found so far by
`~.TraceFitter.fit` will be used.
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Start of time window considered for calculating the fit error.
If not set, will reuse the `t_start` value from the
previously used metric.
t_weights : `~.ndarray`, optional
A 1-dimensional array of weights for each time point. This array
has to have the same size as the input/output traces that are used
for fitting. A value of 0 means that data points are ignored. The
weight values will be normalized so only the relative values matter.
For example, an array containing 1s, and 2s, will weigh the
regions with 2s twice as high (with respect to the squared error)
as the regions with 1s. Using instead values of 0.5 and 1 would have
the same effect. Cannot be combined with ``t_start``. If not set, will
reuse the `t_weights` value from the previously used metric.
normalization : float or array-like of float, optional
A normalization term that will be used rescale results before
handing them to the optimization algorithm. Can be useful if the
algorithm makes assumptions about the scale of errors, e.g. if the
size of steps in the parameter space depends on the absolute value
of the error. The difference between simulated and target traces
will be divided by this value. If not set, will reuse the
`normalization` value from the previously used metric.
metric: `~.MSEMetric` or dict, optional
Optimization metrics to use. Since the refinement only supports
mean-squared-error metrics, this is only useful to provide the
``t_start``/``t_weights``/``normalization`` values. In
the case of multiple fitted output variables, can either be a single
`~.MSEMetric` that is applied to all variables, or a dictionary with a
`~.MSEMetric` for each variable. If not given, will reuse the metrics
of a previous `~.Fitter.fit` call.
callback: `str` or `~typing.Callable`
Either the name of a provided callback function (``text`` or
``progressbar``), or a custom feedback function
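
A minimal usage sketch of the new ``metric`` argument documented above (the ``fitter`` object and the variable names are assumptions for illustration, not part of this commit):

    from brian2 import ms
    from brian2modelfitting import MSEMetric

    # Assume `fitter` is a TraceFitter fitted on two output variables,
    # e.g. 'v' and 'w'. Each variable can now use its own MSEMetric with
    # its own t_start / t_weights / normalization settings.
    refined_params, result = fitter.refine(
        metric={'v': MSEMetric(t_start=5*ms),
                'w': MSEMetric(normalization=0.1)})

    # A single MSEMetric is also accepted and is applied to all variables:
    refined_params, result = fitter.refine(metric=MSEMetric(t_start=5*ms))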
@@ -1264,32 +1252,30 @@ def refine(self, params=None, t_start=None, t_weights=None, normalization=None,
'the fit function first.')
params = self.best_params

if t_weights is not None:
t_weights = normalize_weights(t_weights)
elif t_start is None:
if self.metric is not None:
t_weights = getattr(self.metric[0], 't_weights', None)
if t_weights is None:
if self.metric is not None:
t_start = getattr(self.metric[0], 't_start', 0*second)
else:
t_start = 0*second
else:
t_start = None

if t_start is not None and t_weights is not None:
raise ValueError("Cannot use both 't_weights' and 't_start'.")

if normalization is None:
if self.metric is not None:
normalization = [getattr(metric, 'normalization', 1.)
for metric in self.metric]
if metric is None:
if self.metric is None:
metric = {output_var: MSEMetric()
for output_var in self.output_var}
else:
normalization = [1.] * len(self.output_var)
else:
if not isinstance(normalization, Sequence):
normalization = [normalization] * len(self.output_var)
normalization = [1.0/n for n in normalization]
metric = {output_var: m
for output_var, m in zip(self.output_var, self.metric)}
elif not isinstance(metric, Mapping):
metric = {output_var: metric
for output_var in self.output_var}

for var, m in metric.items():
if not isinstance(m, MSEMetric):
raise TypeError(f"The refine method only supports MSEMetric, but "
f"the metric for variable '{var}' is of type "
f"'{type(m)}'")

# Extract the necessary normalization info in flat arrays
t_weights = [getattr(metric[v], 't_weights', None)
for v in self.output_var]
t_start = [getattr(metric[v], 't_start', 0*second)
for v in self.output_var]
normalization = [getattr(metric[v], 'normalization', 1.0)
for v in self.output_var]

callback_func = callback_setup(callback, None)
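
As a concrete illustration of the extraction step above (variable names and metric settings are made up for this sketch), two fitted variables with different settings yield one entry per variable in each of the flat lists:

    from brian2 import ms, second
    from brian2modelfitting import MSEMetric

    output_var = ['v', 'w']                    # assumed fitted output variables
    metric = {'v': MSEMetric(t_start=5*ms),    # ignore the first 5 ms of 'v'
              'w': MSEMetric()}                # use the full 'w' trace

    t_weights = [getattr(metric[v], 't_weights', None) for v in output_var]
    t_start = [getattr(metric[v], 't_start', 0*second) for v in output_var]
    # t_weights -> [None, None]
    # t_start   -> [5*ms, 0*second], one value per variable instead of the
    #              single shared t_start used before this fix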

@@ -1313,10 +1299,8 @@ def refine(self, params=None, t_start=None, t_weights=None, normalization=None,
optimize=optimize,
level=level+1)

if t_weights is None:
t_start_steps = int(round(t_start / self.dt))
else:
t_start_steps = 0
t_start_steps = [int(round(t_s / self.dt)) if t_w is None else 0
for t_s, t_w in zip(t_start, t_weights)]

def _calc_error(params):
param_dic = get_param_dic([params[p] for p in self.parameter_names],
@@ -1326,41 +1310,49 @@ def _calc_error(params):
name='refine')
one_residual = []

for out_var, out, norm in zip(self.output_var,
self.output_,
normalization):
for out_var, out, t_s_steps, t_w, norm in zip(self.output_var,
self.output_,
t_start_steps,
t_weights,
normalization):
trace = getattr(self.simulator.statemonitor, out_var+'_')
if t_weights is None:
residual = trace[:, t_start_steps:] - out[:, t_start_steps:]
if t_w is None:
residual = trace[:, t_s_steps:] - out[:, t_s_steps:]
else:
residual = (trace - out) * sqrt(t_weights)
one_residual.append(residual*norm)
return array(one_residual).flatten()
residual = (trace - out) * sqrt(t_w)
one_residual.append((residual*norm).flatten())
return array(hstack(one_residual))

def _calc_gradient(params):
residuals = []
for name in self.parameter_names:
one_residual = []
for out_var, norm in zip(self.output_var, normalization):
for out_var, t_s_steps, t_w, norm in zip(self.output_var,
t_start_steps,
t_weights,
normalization):
trace = getattr(self.simulator.statemonitor,
f'S_{out_var}_{name}_')
if t_weights is None:
residual = trace[:, t_start_steps:]
if t_w is None:
residual = trace[:, t_s_steps:]
else:
residual = trace * sqrt(t_weights)
one_residual.append(residual*norm)
residuals.append(array(one_residual).flatten())
residual = trace * sqrt(t_w)
one_residual.append((residual*norm).flatten())
residuals.append(array(hstack(one_residual)))
gradient = array(residuals)
return gradient.T

tested_parameters = []
errors = []
combined_errors = []
def _callback_wrapper(params, iter, resid, *args, **kwds):
# TODO: Assumes all the outputs have the same size
output_len = self.output[0][:, t_start_steps:].size
error = tuple([mean(resid[idx*output_len:(idx + 1)*output_len]**2)
for idx in range(len(self.output_var))])
output_len = [output[:, t_s_steps:].size
for output, t_s_steps in zip(self.output,
t_start_steps)]
end_idx = cumsum(output_len)
start_idx = hstack([0, end_idx[:-1]])
error = tuple([mean(resid[start:end]**2)
for start, end in zip(start_idx, end_idx)])
combined_error = sum(array(error))
errors.append(error)
combined_errors.append(combined_error)
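
To make the bookkeeping in _callback_wrapper concrete, a small self-contained numpy sketch (trace counts and lengths made up): each output variable may keep a different number of time points, so the flat residual vector is split at boundaries derived from a cumulative sum rather than a single shared output length:

    import numpy as np

    # Two hypothetical output variables with three traces each; the first
    # keeps 80 time points per trace (e.g. due to a later t_start), the
    # second keeps 100.
    output_len = [3 * 80, 3 * 100]
    resid = np.random.randn(sum(output_len))   # flat residual as from _calc_error

    end_idx = np.cumsum(output_len)            # array([240, 540])
    start_idx = np.hstack([0, end_idx[:-1]])   # array([  0, 240])
    error = tuple(np.mean(resid[start:end] ** 2)
                  for start, end in zip(start_idx, end_idx))
    combined_error = np.sum(np.array(error))   # total error across variables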
