Skip to content

Commit

Permalink
Add t_weight argument to weigh error temporarily
Browse files Browse the repository at this point in the history
Closes #22
  • Loading branch information
mstimberg committed Mar 23, 2020
1 parent 8ccb68e commit 85c725c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
49 changes: 45 additions & 4 deletions brian2modelfitting/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,39 @@ class TraceMetric(Metric):
Input traces have to be shaped into 2D array.
"""

@check_units(t_start=second)
def __init__(self, t_start=0*second, t_weights=None, normalization=1.,
**kwds):
"""
Initialize the metric.
Parameters
----------
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Start of time window considered for calculating the fit error.
t_weights : `~.ndarray`, optional
A 1-dimensional array of weights for each time point. This array
has to have the same size as the input/output traces that are used
for fitting. A value of 0 means that data points are ignored, and
a value of 1 means that a data point enters into the calculation
as usual. Values of > 1 will emphasize the error by the respective
factor at that point, a value of < 1 will de-emphasize it. Cannot
be combined with ``t_start``.
normalization : float, optional
A normalization term that will be used rescale results before
handing them to the optimization algorithm. Can be useful if the
algorithm makes assumptions about the scale of errors, e.g. if the
size of steps in the parameter space depends on the absolute value
of the error. Trace-based metrics divide the traces itself by the
value, other metrics use it to scale the total error. Not used by
default, i.e. defaults to 1.
"""
if t_weights is not None and t_start != 0*second:
raise ValueError("Cannot use both 't_weigths' and 't_start'.")
super(TraceMetric, self).__init__(t_start=t_start,
normalization=normalization)
self.t_weights = t_weights

def calc(self, model_traces, data_traces, dt):
"""
Perform the error calculation across all parameters,
Expand Down Expand Up @@ -266,9 +299,14 @@ def calc(self, model_traces, data_traces, dt):
``(n_samples, )``.
"""
start_steps = int(round(self.t_start/dt))
features = self.get_features(model_traces[:, :, start_steps:] * float(self.normalization),
data_traces[:, start_steps:] * float(self.normalization),
dt)
if self.t_weights is not None:
features = self.get_features(model_traces * float(self.normalization),
data_traces * float(self.normalization),
dt)
else:
features = self.get_features(model_traces[:, :, start_steps:] * float(self.normalization),
data_traces[:, start_steps:] * float(self.normalization),
dt)
errors = self.get_errors(features)

return errors
Expand Down Expand Up @@ -390,7 +428,10 @@ class MSEMetric(TraceMetric):
def get_features(self, model_traces, data_traces, dt):
# Note that the traces have already beeen normalized in
# TraceMetric.calc
return ((model_traces - data_traces)**2).mean(axis=2)
error = (model_traces - data_traces)**2
if self.t_weights is not None:
error *= self.t_weights
return error.mean(axis=2)

def get_errors(self, features):
return features.mean(axis=1)
Expand Down
40 changes: 40 additions & 0 deletions brian2modelfitting/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,37 @@ def test_calc_mse_t_start():
assert_equal(mse.calc(inp, out, 0.1*ms), np.zeros(5))


def test_calc_mse_t_weights():
with pytest.raises(ValueError):
MSEMetric(t_start=1*ms, t_weights=np.ones(20))
weights = np.ones(20)
weights[:10] = 0
mse = MSEMetric(t_weights=weights)
out = np.random.rand(2, 20)
inp = np.random.rand(5, 2, 20)

errors = mse.calc(inp, out, 0.1*ms)
assert_equal(np.shape(errors), (5,))
assert(np.all(errors > 0))
# Everything before 1ms should be ignored, so having the same values for
# the rest should give an error of 0
inp[:, :, 10:] = out[None, :, 10:]
assert_equal(mse.calc(inp, out, 0.1*ms), np.zeros(5))


def test_calc_mse_t_weights2():
mse = MSEMetric()
mse_weighted = MSEMetric(t_weights=2*np.ones(20))
out = np.random.rand(2, 20)
inp = np.random.rand(5, 2, 20)

errors = mse.calc(inp, out, 0.1*ms)
assert_equal(np.shape(errors), (5,))
assert(np.all(errors > 0))
errors_weighted = mse_weighted.calc(inp, out, 0.1*ms)
assert_almost_equal(errors_weighted, 2*errors)


def test_calc_gf():
assert_raises(TypeError, GammaFactor)
assert_raises(DimensionMismatchError, GammaFactor, delta=10*mV)
Expand All @@ -106,6 +137,15 @@ def test_calc_gf():
assert all(errors > 0)


def test_calc_gf_t_start():
# Spikes starting at 3ms are identical to data
model_spikes = [[np.array([1, 5, 9])*1e-3, np.array([2, 3, 5, 6])*1e-3],
[np.array([5, 9])*1e-3, np.array([3, 5, 6])*1e-3]]
data_spikes = [np.array([0, 5, 9])*1e-3, np.array([1, 3, 5, 6])*1e-3]
gf = GammaFactor(delta=0.5 * ms, time=10 * ms, t_start=2.5*ms)
assert_almost_equal(gf.calc(model_spikes, data_spikes, 0.1*ms), 0)


def test_get_features_mse():
mse = MSEMetric()
out_mse = np.random.rand(2, 20)
Expand Down

0 comments on commit 85c725c

Please sign in to comment.