Skip to content

Commit

Permalink
Merge a1a1202 into 6ee0669
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed May 18, 2020
2 parents 6ee0669 + a1a1202 commit e22e05f
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 20 deletions.
57 changes: 45 additions & 12 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numbers

import sympy
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros, sqrt

from brian2.parsing.sympytools import sympy_to_str, str_to_sympy
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions, fail_for_dimension_mismatch
Expand All @@ -17,7 +17,7 @@
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
from brian2.core.functions import Function
from .simulator import RuntimeSimulator, CPPStandaloneSimulator
from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric
from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric, normalize_weights
from .optimizer import Optimizer
from .utils import callback_setup, make_dic

Expand Down Expand Up @@ -795,6 +795,12 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
if not isinstance(metric, TraceMetric):
raise TypeError("You can only use TraceMetric child metric with "
"TraceFitter")
if metric.t_weights is not None:
if not metric.t_weights.shape == (self.output.shape[1], ):
raise ValueError("The 't_weights' argument of the metric has "
"to be a one-dimensional array of length "
f"{self.output.shape[1]} but has shape "
f"{metric.t_weights.shape}")
self.bounds = dict(params)
best_params, error = super().fit(optimizer, metric, n_rounds,
callback, restart,
Expand All @@ -808,7 +814,7 @@ def generate_traces(self, params=None, param_init=None, level=0):
param_init=param_init, level=level+1)
return fits

def refine(self, params=None, t_start=None, normalization=None,
def refine(self, params=None, t_start=None, t_weights=None, normalization=None,
callback='text', calc_gradient=False, optimize=True,
level=0, **kwds):
"""
Expand All @@ -826,9 +832,19 @@ def refine(self, params=None, t_start=None, normalization=None,
refinement. If not given, the best parameters found so far by
`~.TraceFitter.fit` will be used.
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Initial simulation/model time that should be ignored for the error
calculation. If not set, will reuse the `t_start` value from the
Start of time window considered for calculating the fit error.
If not set, will reuse the `t_start` value from the
previously used metric.
t_weights : `~.ndarray`, optional
A 1-dimensional array of weights for each time point. This array
has to have the same size as the input/output traces that are used
for fitting. A value of 0 means that data points are ignored. The
weight values will be normalized so only the relative values matter.
For example, an array containing 1s, and 2s, will weigh the
regions with 2s twice as high (with respect to the squared error)
as the regions with 1s. Using instead values of 0.5 and 1 would have
the same effect. Cannot be combined with ``t_start``. If not set, will
reuse the `t_weights` value from the previously used metric.
normalization : float, optional
A normalization term that will be used rescale results before
handing them to the optimization algorithm. Can be useful if the
Expand Down Expand Up @@ -895,8 +911,18 @@ def refine(self, params=None, t_start=None, normalization=None,
'the fit function first.')
params = self.best_params

if t_start is None:
t_start = getattr(self.metric, 't_start', 0*second)
if t_weights is not None:
t_weights = normalize_weights(t_weights)
elif t_start is None:
t_weights = getattr(self.metric, 't_weights', None)
if t_weights is None:
t_start = getattr(self.metric, 't_start', 0*second)
else:
t_start = None

if t_start is not None and t_weights is not None:
raise ValueError("Cannot use both 't_weights' and 't_start'.")

if normalization is None:
normalization = getattr(self.metric, 'normalization', 1.)
else:
Expand Down Expand Up @@ -924,23 +950,30 @@ def refine(self, params=None, t_start=None, normalization=None,
optimize=optimize,
level=level+1)

t_start_steps = int(round(t_start / self.dt))
if t_weights is None:
t_start_steps = int(round(t_start / self.dt))

def _calc_error(params):
param_dic = get_param_dic([params[p] for p in self.parameter_names],
self.parameter_names, self.n_traces, 1)
self.simulator.run(self.duration, param_dic,
self.parameter_names, name='refine')
trace = getattr(self.simulator.statemonitor, self.output_var+'_')
residual = trace[:, t_start_steps:] - self.output_[:, t_start_steps:]
if t_weights is None:
residual = trace[:, t_start_steps:] - self.output_[:, t_start_steps:]
else:
residual = (trace - self.output_) * sqrt(t_weights)
return residual.flatten() * normalization

def _calc_gradient(params):
residuals = []
for name in self.parameter_names:
trace = getattr(self.simulator.statemonitor,
f'S_{self.output_var}_{name}_')
residual = trace[:, t_start_steps:]
if t_weights is None:
residual = trace[:, t_start_steps:]
else:
residual = trace * sqrt(t_weights)
residuals.append(residual.flatten() * normalization)
gradient = array(residuals)
return gradient.T
Expand Down Expand Up @@ -1047,10 +1080,10 @@ def generate_spikes(self, params=None, param_init=None, level=0):


class OnlineTraceFitter(Fitter):
"""Input nad output have to have the same dimensions."""
def __init__(self, model, input_var, input, output_var, output, dt,
n_samples=30, method=None, reset=None, refractory=False,
threshold=None, level=0, param_init=None, t_start=0*second):
threshold=None, param_init=None,
t_start=0*second):
"""Initialize the fitter."""
super().__init__(dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method,
Expand Down
66 changes: 61 additions & 5 deletions brian2modelfitting/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from brian2 import Hz, second, Quantity, ms, us, get_dimensions
from brian2.units.fundamentalunits import check_units, in_unit, DIMENSIONLESS
from numpy import (array, sum, abs, amin, digitize, rint, arange, inf, NaN,
clip)
clip, mean)


def firing_rate(spikes):
Expand Down Expand Up @@ -110,6 +110,16 @@ def calc_eFEL(traces, inp_times, feat_list, dt):
return results


def normalize_weights(t_weights):
if any(t_weights < 0):
raise ValueError("Weights in 't_weights' have to be positive.")
mean_weights = mean(t_weights)
if mean_weights == 0:
raise ValueError("Weights in 't_weights' cannot be all zero.")
t_weights = t_weights / mean_weights
return t_weights


class Metric(metaclass=abc.ABCMeta):
"""
Metric abstract class to define functions required for a custom metric
Expand Down Expand Up @@ -240,6 +250,44 @@ class TraceMetric(Metric):
Input traces have to be shaped into 2D array.
"""

@check_units(t_start=second)
def __init__(self, t_start=0*second, t_weights=None, normalization=1.,
**kwds):
"""
Initialize the metric.
Parameters
----------
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Start of time window considered for calculating the fit error.
t_weights : `~.ndarray`, optional
A 1-dimensional array of weights for each time point. This array
has to have the same size as the input/output traces that are used
for fitting. A value of 0 means that data points are ignored. The
weight values will be normalized so only the relative values matter.
For example, an array containing 1s, and 2s, will weigh the
regions with 2s twice as high (with resepct to the squared error)
as the regions with 1s. Using instead values of 0.5 and 1 would have
the same effect. Cannot be combined with ``t_start``.
normalization : float, optional
A normalization term that will be used rescale results before
handing them to the optimization algorithm. Can be useful if the
algorithm makes assumptions about the scale of errors, e.g. if the
size of steps in the parameter space depends on the absolute value
of the error. Trace-based metrics divide the traces itself by the
value, other metrics use it to scale the total error. Not used by
default, i.e. defaults to 1.
"""
if t_weights is not None and t_start != 0*second:
raise ValueError("Cannot use both 't_weights' and 't_start'.")
super(TraceMetric, self).__init__(t_start=t_start,
normalization=normalization)
if t_weights is not None:
self.t_weights = normalize_weights(t_weights)
else:
self.t_weights = None


def calc(self, model_traces, data_traces, dt):
"""
Perform the error calculation across all parameters,
Expand Down Expand Up @@ -267,9 +315,14 @@ def calc(self, model_traces, data_traces, dt):
``(n_samples, )``.
"""
start_steps = int(round(self.t_start/dt))
features = self.get_features(model_traces[:, :, start_steps:] * float(self.normalization),
data_traces[:, start_steps:] * float(self.normalization),
dt)
if self.t_weights is not None:
features = self.get_features(model_traces * float(self.normalization),
data_traces * float(self.normalization),
dt)
else:
features = self.get_features(model_traces[:, :, start_steps:] * float(self.normalization),
data_traces[:, start_steps:] * float(self.normalization),
dt)
errors = self.get_errors(features)

return errors
Expand Down Expand Up @@ -391,7 +444,10 @@ class MSEMetric(TraceMetric):
def get_features(self, model_traces, data_traces, dt):
# Note that the traces have already beeen normalized in
# TraceMetric.calc
return ((model_traces - data_traces)**2).mean(axis=2)
error = (model_traces - data_traces)**2
if self.t_weights is not None:
error *= self.t_weights
return error.mean(axis=2)

def get_errors(self, features):
return features.mean(axis=1)
Expand Down
46 changes: 46 additions & 0 deletions brian2modelfitting/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,52 @@ def test_calc_mse_t_start():
assert_equal(mse.calc(inp, out, 0.1*ms), np.zeros(5))


def test_calc_mse_t_weights():
with pytest.raises(ValueError):
# t_start and t_weights
MSEMetric(t_start=1*ms, t_weights=np.ones(20))
with pytest.raises(ValueError):
# all values zero
MSEMetric(t_weights=np.zeros(20))
with pytest.raises(ValueError):
# negative values
weights = np.ones(20)
weights[17] = -1
MSEMetric(t_weights=weights)

weights = np.ones(20)
weights[:10] = 0
mse = MSEMetric(t_weights=weights)
out = np.random.rand(2, 20)
inp = np.random.rand(5, 2, 20)

errors = mse.calc(inp, out, 0.1*ms)
assert_equal(np.shape(errors), (5,))
assert(np.all(errors > 0))
# Everything before 1ms should be ignored, so having the same values for
# the rest should give an error of 0
inp[:, :, 10:] = out[None, :, 10:]
assert_equal(mse.calc(inp, out, 0.1*ms), np.zeros(5))


def test_calc_mse_t_weights_normalization():
# check that normalization works correctly
dt = 0.1*ms
metric1 = MSEMetric(t_start=50*dt)
weights = np.ones(100)
weights[:50] = 0
metric2 = MSEMetric(t_weights=weights)
weights2 = weights*2 # should not make any difference
metric3 = MSEMetric(t_weights=weights2)
data_traces = np.random.rand(3, 100)
model_traces = np.random.rand(2, 3, 100)
error_1 = metric1.calc(model_traces=model_traces, data_traces=data_traces, dt=dt)
error_2 = metric2.calc(model_traces=model_traces, data_traces=data_traces, dt=dt)
error_3 = metric3.calc(model_traces=model_traces, data_traces=data_traces, dt=dt)
assert_almost_equal(error_1, error_2)
assert_almost_equal(error_1, error_3)


def test_calc_gf():
assert_raises(TypeError, GammaFactor)
assert_raises(DimensionMismatchError, GammaFactor, delta=10*mV)
Expand Down

0 comments on commit e22e05f

Please sign in to comment.