Skip to content

Commit

Permalink
Fix normalized errors in callback
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 17, 2020
1 parent fbd864b commit 6d924b2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
10 changes: 6 additions & 4 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,9 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
output_dim = DIMENSIONLESS
else:
output_dim = self.output_dim
error_dim = self.metric.get_dimensions(output_dim)
# Correct the units for the normalization factor
error_dim = self.metric.get_normalized_dimensions(output_dim)
best_error = Quantity(float(self.best_error), dim=error_dim)
errors = Quantity(errors, dim=error_dim)
param_dicts = [{p: Quantity(v, dim=self.model[p].dim)
for p, v in zip(self.parameter_names,
Expand All @@ -545,13 +547,13 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
param_dicts = [{p: v for p, v in zip(self.parameter_names,
one_param_set)}
for one_param_set in parameters]
best_error = self.best_error

if callback(param_dicts,
errors,
self.best_params,
self.best_error,
best_error,
index) is True:
print('Stopping simulation')
break

return self.best_params, self.best_error
Expand Down Expand Up @@ -916,7 +918,7 @@ def _callback_wrapper(params, iter, resid, *args, **kwds):
error = mean(resid**2)
errors.append(error)
if self.use_units:
error_dim = self.output_dim**2
error_dim = self.output_dim**2 * get_dimensions(normalization)**2
all_errors = Quantity(errors, dim=error_dim)
params = {p: Quantity(val, dim=self.model[p].dim)
for p, val in params.items()}
Expand Down
31 changes: 25 additions & 6 deletions brian2modelfitting/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
except ImportError:
warnings.warn('eFEL package not found.')
from itertools import repeat
from brian2 import Hz, second, Quantity, ms, us
from brian2 import Hz, second, Quantity, ms, us, get_dimensions
from brian2.units.fundamentalunits import check_units, in_unit, DIMENSIONLESS
from numpy import (array, sum, square, reshape, abs, amin, digitize,
rint, arange, atleast_2d, NaN, float64, split, shape,)
Expand Down Expand Up @@ -158,6 +158,23 @@ def get_dimensions(self, output_dim):
"""
return DIMENSIONLESS

def get_normalized_dimensions(self, output_dim):
"""
The physical dimensions of the normalized error. This will be
the same as the dimensions returned by `~.Metric.get_dimensions` if
the ``normalization`` is not used or set to a dimensionless value.
Parameters
----------
output_dim : `.Dimension`
The dimensions of the output variable.
Returns
-------
dim : `.Dimension`
The physical dimensions of the normalized error.
"""
return DIMENSIONLESS

@abc.abstractmethod
def get_features(self, model_results, target_results, dt):
Expand Down Expand Up @@ -249,8 +266,8 @@ def calc(self, model_traces, data_traces, dt):
``(n_samples, )``.
"""
start_steps = int(round(self.t_start/dt))
features = self.get_features(model_traces[:, :, start_steps:] * self.normalization,
data_traces[:, start_steps:] * self.normalization,
features = self.get_features(model_traces[:, :, start_steps:] * float(self.normalization),
data_traces[:, start_steps:] * float(self.normalization),
dt)
errors = self.get_errors(features)

Expand Down Expand Up @@ -332,7 +349,7 @@ def calc(self, model_spikes, data_spikes, dt):
model_spikes = relevant_model_spikes
data_spikes = relevant_data_spikes
features = self.get_features(model_spikes, data_spikes, dt)
errors = self.get_errors(features) * self.normalization
errors = self.get_errors(features) * float(self.normalization)

return errors

Expand Down Expand Up @@ -381,6 +398,8 @@ def get_errors(self, features):
def get_dimensions(self, output_dim):
return output_dim**2

def get_normalized_dimensions(self, output_dim):
return output_dim**2 * get_dimensions(self.normalization)**2

class FeatureMetric(TraceMetric):
def __init__(self, stim_times, feat_list, weights=None, combine=None,
Expand Down Expand Up @@ -483,7 +502,7 @@ def get_errors(self, features):
sample_error += total
errors.append(sample_error)

return array(errors) * self.normalization
return array(errors) * float(self.normalization)


class GammaFactor(SpikeMetric):
Expand Down Expand Up @@ -535,7 +554,7 @@ def get_features(self, traces, output, dt):
rate_correction=self.rate_correction)
gf_for_sample.append(gf)
all_gf.append(gf_for_sample)
return array(all_gf) * self.normalization
return array(all_gf) * float(self.normalization)

def get_errors(self, features):
errors = features.mean(axis=1)
Expand Down

0 comments on commit 6d924b2

Please sign in to comment.