Fix type check if several metrics are provided
mstimberg committed Nov 10, 2020
1 parent aaefa48 commit 9181970
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions brian2modelfitting/fitter.py
@@ -1,5 +1,6 @@
 import abc
 import numbers
+from typing import Sequence
 
 import sympy
 from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros, sqrt, ndarray, broadcast_to, sum
@@ -900,15 +901,19 @@ def calc_errors(self, metric):
     def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
             restart=False, start_iteration=None, penalty=None,
             level=0, **params):
-        if not isinstance(metric, TraceMetric):
-            raise TypeError("You can only use TraceMetric child metric with "
-                            "TraceFitter")
-        if metric.t_weights is not None:
-            if not metric.t_weights.shape == (self.output[0].shape[1], ):
-                raise ValueError("The 't_weights' argument of the metric has "
-                                 "to be a one-dimensional array of length "
-                                 f"{self.output[0].shape[1]} but has shape "
-                                 f"{metric.t_weights.shape}")
+        if not isinstance(metric, Sequence):
+            metric = [metric] * len(self.output_var)
+        for single_metric in metric:
+            if not isinstance(single_metric, TraceMetric):
+                raise TypeError("Metric has to be a 'TraceMetric', but is "
+                                f"type '{type(single_metric)}'.")
+        for single_metric, output in zip(metric, self.output):
+            if single_metric.t_weights is not None:
+                if not single_metric.t_weights.shape == (output.shape[1], ):
+                    raise ValueError("The 't_weights' argument of the metric has "
+                                     "to be a one-dimensional array of length "
+                                     f"{output.shape[1]} but has shape "
+                                     f"{single_metric.t_weights.shape}")
         self.bounds = dict(params)
         best_params, error = super().fit(optimizer=optimizer,
                                          metric=metric,
@@ -1218,9 +1223,12 @@ def calc_errors(self, metric):
     def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
             restart=False, start_iteration=None, penalty=None,
             level=0, **params):
-        if not isinstance(metric, SpikeMetric):
-            raise TypeError("You can only use SpikeMetric child metric with "
-                            "SpikeFitter")
+        if not isinstance(metric, Sequence):
+            metric = [metric] * len(self.output_var)
+        for single_metric in metric:
+            if not isinstance(single_metric, SpikeMetric):
+                raise TypeError("Metric has to be a 'SpikeMetric', but is "
+                                f"type '{type(single_metric)}'.")
         best_params, error = super().fit(optimizer=optimizer,
                                          metric=metric,
                                          n_rounds=n_rounds,
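
In words, fit() now accepts either a single metric or a sequence with one metric per fitted output variable: a lone metric is broadcast to all output variables, every entry is type-checked, and the t_weights shape is compared against each metric's own output trace instead of always output[0]. Below is a minimal standalone sketch of that broadcast-and-validate pattern, not the library code itself; TraceMetric is reduced to a stub, and ExampleTraceMetric and normalize_metrics are hypothetical names introduced only for this illustration.

from typing import Sequence


class TraceMetric:
    """Stand-in for brian2modelfitting's TraceMetric base class (illustration only)."""
    pass


class ExampleTraceMetric(TraceMetric):
    """Hypothetical concrete metric, used only in this sketch."""
    pass


def normalize_metrics(metric, output_vars):
    """Broadcast a single metric to all output variables, then type-check each one."""
    if not isinstance(metric, Sequence):
        # A lone metric applies to every fitted output variable.
        metric = [metric] * len(output_vars)
    for single_metric in metric:
        if not isinstance(single_metric, TraceMetric):
            raise TypeError("Metric has to be a 'TraceMetric', but is "
                            f"type '{type(single_metric)}'.")
    return metric


# A single metric is broadcast to both output variables ...
assert len(normalize_metrics(ExampleTraceMetric(), ['v', 'w'])) == 2
# ... while a sequence with one metric per output variable is kept as-is.
assert len(normalize_metrics([ExampleTraceMetric(), ExampleTraceMetric()], ['v', 'w'])) == 2

The same normalization is applied in SpikeFitter.fit(), with SpikeMetric as the required base class.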
