diff --git a/brian2modelfitting/fitter.py b/brian2modelfitting/fitter.py index c82bf750..21ff34cc 100644 --- a/brian2modelfitting/fitter.py +++ b/brian2modelfitting/fitter.py @@ -17,7 +17,7 @@ from brian2.devices.cpp_standalone.device import CPPStandaloneDevice from brian2.core.functions import Function from .simulator import RuntimeSimulator, CPPStandaloneSimulator -from .metric import Metric, SpikeMetric, TraceMetric +from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric from .optimizer import Optimizer from .utils import callback_setup, make_dic @@ -265,7 +265,8 @@ class Fitter(metaclass=abc.ABCMeta): Dictionary of variables to be initialized with respective values """ def __init__(self, dt, model, input, output, input_var, output_var, - n_samples, threshold, reset, refractory, method, param_init): + n_samples, threshold, reset, refractory, method, param_init, + use_units=True): """Initialize the fitter.""" if dt is None: @@ -295,8 +296,14 @@ def __init__(self, dt, model, input, output, input_var, output_var, self.output = Quantity(output) self.output_ = array(output) self.output_var = output_var + if output_var == 'spikes': + self.output_dim = DIMENSIONLESS + else: + self.output_dim = model[output_var].dim self.model = model + self.use_units = use_units + input_dim = get_dimensions(input) input_dim = '1' if input_dim is DIMENSIONLESS else repr(input_dim) input_eqs = "{} = input_var(t, i % n_traces) : {}".format(input_var, @@ -307,7 +314,8 @@ def __init__(self, dt, model, input, output, input_var, output_var, self.input_traces = input_traces # initialization of attributes used later - self.best_params = None + self._best_params = None + self._best_error = None self.optimizer = None self.metric = None if not param_init: @@ -470,11 +478,10 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text', online_error: bool, optional Whether to calculate the squared error between target trace and simulated trace online. Defaults to ``False``. 
+ level : `int`, optional + How much farther to go down in the stack to find the namespace. **params bounds for each parameter - level : `int`, optional - How much farther to go down in the stack to find the namespace. - Returns ------- best_results : dict @@ -517,24 +524,62 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text', level=level+1) # Run Optimization Loop - error = None for index in range(n_rounds): best_params, parameters, errors = self.optimization_iter(optimizer, metric) - + self._best_error = nanmin(self.optimizer.errors) # create output variables - self.best_params = make_dic(self.parameter_names, best_params) - error = nanmin(self.optimizer.errors) - param_dicts = [{p: v for p, v in zip(self.parameter_names, + self._best_params = make_dic(self.parameter_names, best_params) + if self.use_units: + if self.output_var == 'spikes': + output_dim = DIMENSIONLESS + else: + output_dim = self.output_dim + # Correct the units for the normalization factor + error_dim = self.metric.get_normalized_dimensions(output_dim) + best_error = Quantity(float(self.best_error), dim=error_dim) + errors = Quantity(errors, dim=error_dim) + param_dicts = [{p: Quantity(v, dim=self.model[p].dim) + for p, v in zip(self.parameter_names, one_param_set)} - for one_param_set in parameters] - - if callback(param_dicts, errors, self.best_params, error, index) is True: + for one_param_set in parameters] + else: + param_dicts = [{p: v for p, v in zip(self.parameter_names, + one_param_set)} + for one_param_set in parameters] + best_error = self.best_error + + if callback(param_dicts, + errors, + self.best_params, + best_error, + index) is True: break - return self.best_params, error + return self.best_params, self.best_error + + @property + def best_params(self): + if self._best_params is None: + return None + if self.use_units: + params_with_units = {p: Quantity(v, dim=self.model[p].dim) + for p, v in self._best_params.items()} + return params_with_units + else: + return 
self._best_params + + @property + def best_error(self): + if self._best_error is None: + return None + if self.use_units: + error_dim = self.metric.get_dimensions(self.output_dim) + return Quantity(self._best_error, dim=error_dim) + else: + return self._best_error - def results(self, format='list'): + def results(self, format='list', use_units=None): """ Returns all of the gathered results (parameters and errors). In one of the 3 formats: 'dataframe', 'list', 'dict'. @@ -544,6 +589,10 @@ def results(self, format='list'): format: str The desired output format. Currently supported: ``dataframe``, ``list``, or ``dict``. + use_units: bool, optional + Whether to use units in the results. If not specified, defaults to + `.Tracefitter.use_units`, i.e. the value that was specified when + the `.Tracefitter` object was created (``True`` by default). Returns ------- @@ -552,40 +601,57 @@ def results(self, format='list'): 'list': list of dictionaries 'dict': dictionary of lists """ + if use_units is None: + use_units = self.use_units names = list(self.parameter_names) - names.append('errors') params = array(self.optimizer.tested_parameters) params = params.reshape(-1, params.shape[-1]) - errors = array([array(self.optimizer.errors).flatten()]) - data = concatenate((params, errors.transpose()), axis=1) + if use_units: + error_dim = self.metric.get_dimensions(self.output_dim) + errors = Quantity(array(self.optimizer.errors).flatten(), + dim=error_dim) + else: + errors = array(array(self.optimizer.errors).flatten()) + dim = self.model.dimensions if format == 'list': res_list = [] for j in arange(0, len(params)): - temp_data = data[j] + temp_data = params[j] res_dict = dict() - for i, n in enumerate(names[:-1]): - res_dict[n] = Quantity(temp_data[i], dim=dim[n]) - res_dict[names[-1]] = temp_data[-1] + for i, n in enumerate(names): + if use_units: + res_dict[n] = Quantity(temp_data[i], dim=dim[n]) + else: + res_dict[n] = float(temp_data[i]) + res_dict['error'] = errors[j] 
res_list.append(res_dict) return res_list elif format == 'dict': res_dict = dict() - for i, n in enumerate(names[:-1]): - res_dict[n] = Quantity(data[:, i], dim=dim[n]) + for i, n in enumerate(names): + if use_units: + res_dict[n] = Quantity(params[:, i], dim=dim[n]) + else: + res_dict[n] = array(params[:, i]) - res_dict[names[-1]] = data[:, -1] + res_dict['error'] = errors return res_dict elif format == 'dataframe': from pandas import DataFrame - return DataFrame(data=data, columns=names) + if use_units: + logger.warn('Results in dataframes do not support units. ' + 'Specify "use_units=False" to avoid this warning.', + name_suffix='dataframe_units') + data = concatenate((params, array(errors)[None, :].transpose()), axis=1) + return DataFrame(data=data, columns=names + ['error']) def generate(self, params=None, output_var=None, param_init=None, level=0): """ @@ -640,14 +706,34 @@ def generate(self, params=None, output_var=None, param_init=None, level=0): class TraceFitter(Fitter): - """Input and output have to have the same dimensions.""" + """ + A `Fitter` for fitting recorded traces (e.g. of the membrane potential). + + Parameters + ---------- + model + input_var + input + output_var + output + dt + n_samples + method + reset + refractory + threshold + param_init + use_units: bool, optional + Whether to use units in all user-facing interfaces, e.g. in the callback + arguments or in the returned parameter dictionary and errors. Defaults + to ``True``. 
+ """ def __init__(self, model, input_var, input, output_var, output, dt, n_samples=30, method=None, reset=None, refractory=False, - threshold=None, param_init=None): - """Initialize the fitter.""" + threshold=None, param_init=None, use_units=True): super().__init__(dt, model, input, output, input_var, output_var, n_samples, threshold, reset, refractory, method, - param_init) + param_init, use_units=use_units) # We store the bounds set in TraceFitter.fit, so that Tracefitter.refine # can reuse them self.bounds = None @@ -677,10 +763,11 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text', raise TypeError("You can only use TraceMetric child metric with " "TraceFitter") self.bounds = dict(params) - self.best_params, error = super().fit(optimizer, metric, n_rounds, - callback, restart, level=level+1, - **params) - return self.best_params, error + best_params, error = super().fit(optimizer, metric, n_rounds, + callback, restart, + level=level+1, + **params) + return best_params, error def generate_traces(self, params=None, param_init=None, level=0): """Generates traces for best fit of parameters and all inputs""" @@ -829,13 +916,22 @@ def _calc_gradient(params): errors = [] def _callback_wrapper(params, iter, resid, *args, **kwds): error = mean(resid**2) - params = {p: float(val) for p, val in params.items()} - tested_parameters.append(params) errors.append(error) + if self.use_units: + error_dim = self.output_dim**2 * get_dimensions(normalization)**2 + all_errors = Quantity(errors, dim=error_dim) + params = {p: Quantity(val, dim=self.model[p].dim) + for p, val in params.items()} + else: + all_errors = array(errors) + params = {p: float(val) for p, val in params.items()} + tested_parameters.append(params) + best_idx = argmin(errors) - best_error = errors[best_idx] + best_error = all_errors[best_idx] best_params = tested_parameters[best_idx] - return callback_func(params, array(errors), + + return callback_func(params, all_errors, best_params, 
best_error, iter) assert 'Dfun' not in kwds @@ -858,19 +954,26 @@ def _callback_wrapper(params, iter, resid, *args, **kwds): iter_cb=iter_cb, **kwds) - return {p: float(val) for p, val in result.params.items()}, result + if self.use_units: + param_dict = {p: Quantity(float(val), dim=self.model[p].dim) + for p, val in result.params.items()} + else: + param_dict = {p: float(val) + for p, val in result.params.items()} + + return param_dict, result class SpikeFitter(Fitter): def __init__(self, model, input, output, dt, reset, threshold, input_var='I', refractory=False, n_samples=30, - method=None, param_init=None): + method=None, param_init=None, use_units=True): """Initialize the fitter.""" if method is None: method = 'exponential_euler' super().__init__(dt, model, input, output, input_var, 'v', n_samples, threshold, reset, refractory, method, - param_init) + param_init, use_units=use_units) self.output_var = 'spikes' if param_init: @@ -897,10 +1000,10 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text', if not isinstance(metric, SpikeMetric): raise TypeError("You can only use SpikeMetric child metric with " "SpikeFitter") - self.best_params, error = super().fit(optimizer, metric, n_rounds, - callback, restart, level=level+1, - **params) - return self.best_params, error + best_params, error = super().fit(optimizer, metric, n_rounds, + callback, restart, level=level+1, + **params) + return best_params, error def generate_spikes(self, params=None, param_init=None, level=0): """Generates traces for best fit of parameters and all inputs""" @@ -943,8 +1046,9 @@ def __init__(self, model, input_var, input, output_var, output, dt, self.simulator = None - def fit(self, optimizer, metric=None, n_rounds=1, callback='text', + def fit(self, optimizer, n_rounds=1, callback='text', restart=False, level=0, **params): + metric = MSEMetric() # not used, but makes error dimensions correct return super(OnlineTraceFitter, self).fit(optimizer, metric=metric, 
n_rounds=n_rounds, callback=callback, diff --git a/brian2modelfitting/metric.py b/brian2modelfitting/metric.py index d27b720b..5857c461 100644 --- a/brian2modelfitting/metric.py +++ b/brian2modelfitting/metric.py @@ -7,8 +7,8 @@ except ImportError: warnings.warn('eFEL package not found.') from itertools import repeat -from brian2 import Hz, second, Quantity, ms, us -from brian2.units.fundamentalunits import check_units, in_unit +from brian2 import Hz, second, Quantity, ms, us, get_dimensions +from brian2.units.fundamentalunits import check_units, in_unit, DIMENSIONLESS from numpy import (array, sum, square, reshape, abs, amin, digitize, rint, arange, atleast_2d, NaN, float64, split, shape,) @@ -136,6 +136,46 @@ def __init__(self, t_start=0*second, normalization=1., **kwds): self.t_start = t_start self.normalization = 1/normalization + def get_dimensions(self, output_dim): + """ + The physical dimensions of the error. In metrics such as `MSEMetric`, + this depends on the dimensions of the output variable (e.g. if the + output variable has units of volts, the mean squared error will have + units of volt²); in other metrics, e.g. `FeatureMetric`, this cannot + be defined in a meaningful way since the metric combines different + types of errors. In cases where defining dimensions is not meaningful, + this method should return `DIMENSIONLESS`. + + Parameters + ---------- + output_dim : `.Dimension` + The dimensions of the output variable. + + Returns + ------- + dim : `.Dimension` + The physical dimensions of the error. + """ + return DIMENSIONLESS + + def get_normalized_dimensions(self, output_dim): + """ + The physical dimensions of the normalized error. This will be + the same as the dimensions returned by `~.Metric.get_dimensions` if + the ``normalization`` is not used or set to a dimensionless value. + + Parameters + ---------- + output_dim : `.Dimension` + The dimensions of the output variable. 
+ + Returns + ------- + dim : `.Dimension` + The physical dimensions of the normalized error. + """ + return DIMENSIONLESS + @abc.abstractmethod def get_features(self, model_results, target_results, dt): """ @@ -226,8 +266,8 @@ def calc(self, model_traces, data_traces, dt): ``(n_samples, )``. """ start_steps = int(round(self.t_start/dt)) - features = self.get_features(model_traces[:, :, start_steps:] * self.normalization, - data_traces[:, start_steps:] * self.normalization, + features = self.get_features(model_traces[:, :, start_steps:] * float(self.normalization), + data_traces[:, start_steps:] * float(self.normalization), dt) errors = self.get_errors(features) @@ -309,7 +349,7 @@ def calc(self, model_spikes, data_spikes, dt): model_spikes = relevant_model_spikes data_spikes = relevant_data_spikes features = self.get_features(model_spikes, data_spikes, dt) - errors = self.get_errors(features) * self.normalization + errors = self.get_errors(features) * float(self.normalization) return errors @@ -355,6 +395,11 @@ def get_features(self, model_traces, data_traces, dt): def get_errors(self, features): return features.mean(axis=1) + def get_dimensions(self, output_dim): + return output_dim**2 + + def get_normalized_dimensions(self, output_dim): + return output_dim**2 * get_dimensions(self.normalization)**2 class FeatureMetric(TraceMetric): def __init__(self, stim_times, feat_list, weights=None, combine=None, @@ -457,7 +502,7 @@ def get_errors(self, features): sample_error += total errors.append(sample_error) - return array(errors) * self.normalization + return array(errors) * float(self.normalization) class GammaFactor(SpikeMetric): @@ -509,7 +554,7 @@ def get_features(self, traces, output, dt): rate_correction=self.rate_correction) gf_for_sample.append(gf) all_gf.append(gf_for_sample) - return array(all_gf) * self.normalization + return array(all_gf) * float(self.normalization) def get_errors(self, features): errors = features.mean(axis=1) diff --git 
a/brian2modelfitting/tests/test_modelfitting_tracefitter.py b/brian2modelfitting/tests/test_modelfitting_tracefitter.py index 9a3f07c5..5dc8adc2 100644 --- a/brian2modelfitting/tests/test_modelfitting_tracefitter.py +++ b/brian2modelfitting/tests/test_modelfitting_tracefitter.py @@ -63,6 +63,23 @@ def fin(): return dt, tf +@pytest.fixture +def setup_no_units(request): + dt = 0.01 * ms + tf = TraceFitter(dt=dt, + model=model, + input_var='v', + output_var='I', + input=input_traces, + output=output_traces, + n_samples=2, + use_units=False) + + def fin(): + reinit_devices() + request.addfinalizer(fin) + + return dt, tf @pytest.fixture def setup_constant(request): @@ -161,6 +178,7 @@ def test_tracefitter_init(setup): assert isinstance(tf.model, Equations) + def test_tracefitter_init_errors(setup): dt, _ = setup with pytest.raises(Exception): @@ -202,10 +220,37 @@ def test_fitter_fit(setup): assert isinstance(tf.simulator, Simulator) assert isinstance(results, dict) + assert all(isinstance(v, Quantity) for v in results.values()) + assert isinstance(errors, Quantity) + assert 'g' in results.keys() + + assert_equal(results, tf.best_params) + assert_equal(errors, tf.best_error) + + +def test_fitter_fit_no_units(setup_no_units): + dt, tf = setup_no_units + results, errors = tf.fit(n_rounds=2, + optimizer=n_opt, + metric=metric, + g=[1*nS, 30*nS], + callback=None) + + attr_fit = ['optimizer', 'metric', 'best_params'] + for attr in attr_fit: + assert hasattr(tf, attr) + + assert isinstance(tf.metric, Metric) + assert isinstance(tf.optimizer, Optimizer) + assert isinstance(tf.simulator, Simulator) + + assert isinstance(results, dict) + assert all(isinstance(v, float) for v in results.values()) assert isinstance(errors, float) assert 'g' in results.keys() assert_equal(results, tf.best_params) + assert_equal(errors, tf.best_error) def test_fitter_fit_callback(setup): @@ -217,7 +262,7 @@ def our_callback(params, errors, best_params, best_error, index): assert all(isinstance(p, 
dict) for p in params) assert isinstance(errors, np.ndarray) assert isinstance(best_params, dict) - assert isinstance(best_error, float) + assert isinstance(best_error, Quantity) assert isinstance(index, int) results, errors = tf.fit(n_rounds=2, optimizer=n_opt, @@ -234,7 +279,7 @@ def our_callback(params, errors, best_params, best_error, index): assert all(isinstance(p, dict) for p in params) assert isinstance(errors, np.ndarray) assert isinstance(best_params, dict) - assert isinstance(best_error, float) + assert isinstance(best_error, Quantity) assert isinstance(index, int) return True # stop @@ -269,7 +314,7 @@ def test_fitter_fit_tstart(setup_constant): metric=MSEMetric(t_start=50*dt), c=[0 * mV, 30 * mV]) # Fit should be close to 20mV - assert np.abs(params['c']*volt - 20*mV) < 1*mV + assert np.abs(params['c'] - 20*mV) < 1*mV @pytest.mark.skipif(lmfit is None, reason="needs lmfit package") def test_fitter_refine(setup): @@ -376,7 +421,7 @@ def test_fitter_refine_tstart(setup_constant): t_start=50*dt) # Fit should be close to 20mV - assert np.abs(params['c']*volt - 20*mV) < 1*mV + assert np.abs(params['c'] - 20*mV) < 1*mV @pytest.mark.skipif(lmfit is None, reason="needs lmfit package") @@ -391,7 +436,7 @@ def test_fitter_refine_reuse_tstart(setup_constant): params, result = tf.refine({'c': 5 * mV}, c=[0 * mV, 30 * mV]) # Fit should be close to 20mV - assert np.abs(params['c'] * volt - 20 * mV) < 1 * mV + assert np.abs(params['c'] - 20 * mV) < 1 * mV @pytest.mark.skipif(lmfit is None, reason="needs lmfit package") @@ -415,7 +460,7 @@ def our_callback(params, errors, best_params, best_error, index): assert isinstance(params, dict) assert isinstance(errors, np.ndarray) assert isinstance(best_params, dict) - assert isinstance(best_error, float) + assert isinstance(best_error, Quantity) assert isinstance(index, int) tf.refine({'g': 5 * nS}, g=[1 * nS, 30 * nS], callback=our_callback) @@ -543,7 +588,7 @@ def test_fitter_generate_traces_standalone(setup_standalone): 
assert_equal(np.shape(traces), np.shape(output_traces)) -def test_fitter_results(setup): +def test_fitter_results(setup, caplog): dt, tf = setup best_params, errors = tf.fit(n_rounds=2, optimizer=n_opt, @@ -554,9 +599,10 @@ def test_fitter_results(setup): params_list = tf.results(format='list') assert isinstance(params_list, list) assert isinstance(params_list[0], dict) + print(params_list) assert isinstance(params_list[0]['g'], Quantity) assert 'g' in params_list[0].keys() - assert 'errors' in params_list[0].keys() + assert 'error' in params_list[0].keys() assert_equal(np.shape(params_list), (4,)) assert_equal(len(params_list[0]), 2) assert have_same_dimensions(params_list[0]['g'].dim, nS) @@ -564,17 +610,53 @@ def test_fitter_results(setup): params_dic = tf.results(format='dict') assert isinstance(params_dic, dict) assert 'g' in params_dic.keys() - assert 'errors' in params_dic.keys() + assert 'error' in params_dic.keys() assert isinstance(params_dic['g'], Quantity) assert_equal(len(params_dic), 2) assert_equal(np.shape(params_dic['g']), (4,)) - assert_equal(np.shape(params_dic['errors']), (4,)) + assert_equal(np.shape(params_dic['error']), (4,)) + + # Should raise a warning because dataframe cannot have units + assert len(caplog.records) == 0 + params_df = tf.results(format='dataframe') + assert len(caplog.records) == 1 + assert isinstance(params_df, pd.DataFrame) + assert_equal(params_df.shape, (4, 2)) + assert 'g' in params_df.keys() + assert 'error' in params_df.keys() + + +def test_fitter_results_no_units(setup_no_units, caplog): + dt, tf = setup_no_units + tf.fit(n_rounds=2, + optimizer=n_opt, + metric=metric, + g=[1*nS, 30*nS], + restart=False) + + params_list = tf.results(format='list') + assert isinstance(params_list, list) + assert isinstance(params_list[0], dict) + assert isinstance(params_list[0]['g'], float) + assert 'g' in params_list[0].keys() + assert 'error' in params_list[0].keys() + assert_equal(np.shape(params_list), (4,)) + 
assert_equal(len(params_list[0]), 2) + + params_dic = tf.results(format='dict') + assert isinstance(params_dic, dict) + assert 'g' in params_dic.keys() + assert 'error' in params_dic.keys() + assert isinstance(params_dic['g'], np.ndarray) + assert_equal(len(params_dic), 2) + assert_equal(np.shape(params_dic['g']), (4,)) + assert_equal(np.shape(params_dic['error']), (4,)) params_df = tf.results(format='dataframe') assert isinstance(params_df, pd.DataFrame) assert_equal(params_df.shape, (4, 2)) assert 'g' in params_df.keys() - assert 'errors' in params_df.keys() + assert 'error' in params_df.keys() # OnlineTraceFitter @@ -631,16 +713,16 @@ def test_onlinetracefitter_fit(setup_online): optimizer=n_opt, g=[1*nS, 30*nS], restart=False,) - + print(otf.best_params) attr_fit = ['optimizer', 'metric', 'best_params'] for attr in attr_fit: assert hasattr(otf, attr) - assert otf.metric is None + assert isinstance(otf.metric, MSEMetric) assert isinstance(otf.optimizer, Optimizer) assert isinstance(results, dict) - assert isinstance(errors, float) + assert isinstance(errors, Quantity) assert 'g' in results.keys() assert_equal(results, otf.best_params) diff --git a/brian2modelfitting/utils.py b/brian2modelfitting/utils.py index 21ec168f..effb1ffb 100644 --- a/brian2modelfitting/utils.py +++ b/brian2modelfitting/utils.py @@ -4,8 +4,8 @@ def callback_text(params, errors, best_params, best_error, index): """Default callback print-out for Fitters""" - param_str = ', '.join([f"{p}={v}" for p, v in sorted(best_params.items())]) - print(f"Round {index}: Best parameters {param_str} (error: {best_error})") + param_str = ', '.join([f"{p}={v!s}" for p, v in sorted(best_params.items())]) + print(f"Round {index}: Best parameters {param_str} (error: {best_error!s})") def callback_none(params, errors, best_params, best_error, index):