Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _reset_internal_settings(self):

# Bandwidth limits are given in 2-sided peak bandwidth
# Convert to gaussian std parameter limits
self._gauss_std_limits = tuple([bwl / 2 for bwl in self.peak_width_limits])
self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits)

# Otherwise, assume settings are unknown (have been cleared) and set to None
else:
Expand Down Expand Up @@ -378,7 +378,8 @@ def add_results(self, fooof_result):
self._check_loaded_results(fooof_result._asdict())


def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False, **plot_kwargs):
def report(self, freqs=None, power_spectrum=None, freq_range=None,
plt_log=False, **plot_kwargs):
"""Run model fit, and display a report, which includes a plot, and printed results.

Parameters
Expand Down Expand Up @@ -791,9 +792,10 @@ def _simple_ap_fit(self, freqs, power_spectrum):
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
freqs, power_spectrum, p0=guess,
maxfev=self._maxfev, bounds=ap_bounds)
except RuntimeError:
raise FitError("Model fitting failed due to not finding parameters in "
"the simple aperiodic component fit.")
except RuntimeError as excp:
error_msg = ("Model fitting failed due to not finding parameters in "
"the simple aperiodic component fit.")
raise FitError(error_msg) from excp

return aperiodic_params

Expand Down Expand Up @@ -847,12 +849,14 @@ def _robust_ap_fit(self, freqs, power_spectrum):
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
freqs_ignore, spectrum_ignore, p0=popt,
maxfev=self._maxfev, bounds=ap_bounds)
except RuntimeError:
raise FitError("Model fitting failed due to not finding "
"parameters in the robust aperiodic fit.")
except TypeError:
raise FitError("Model fitting failed due to sub-sampling "
"in the robust aperiodic fit.")
except RuntimeError as excp:
error_msg = ("Model fitting failed due to not finding "
"parameters in the robust aperiodic fit.")
raise FitError(error_msg) from excp
except TypeError as excp:
error_msg = ("Model fitting failed due to sub-sampling "
"in the robust aperiodic fit.")
raise FitError(error_msg) from excp

return aperiodic_params

Expand Down Expand Up @@ -981,8 +985,8 @@ def _fit_peak_guess(self, guess):

# Unpacks the embedded lists into flat tuples
# This is what the fit function requires as input
gaus_param_bounds = (tuple([item for sublist in lo_bound for item in sublist]),
tuple([item for sublist in hi_bound for item in sublist]))
gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist),
tuple(item for sublist in hi_bound for item in sublist))

# Flatten guess, for use with curve fit
guess = np.ndarray.flatten(guess)
Expand All @@ -991,13 +995,15 @@ def _fit_peak_guess(self, guess):
try:
gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat,
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds)
except RuntimeError:
raise FitError("Model fitting failed due to not finding "
"parameters in the peak component fit.")
except LinAlgError:
raise FitError("Model fitting failed due to a LinAlgError during peak fitting. "
"This can happen with settings that are too liberal, leading, "
"to a large number of guess peaks that cannot be fit together.")
except RuntimeError as excp:
error_msg = ("Model fitting failed due to not finding "
"parameters in the peak component fit.")
raise FitError(error_msg) from excp
except LinAlgError as excp:
error_msg = ("Model fitting failed due to a LinAlgError during peak fitting. "
"This can happen with settings that are too liberal, leading, "
"to a large number of guess peaks that cannot be fit together.")
raise FitError(error_msg) from excp

# Re-organize params into 2d matrix
gaussian_params = np.array(group_three(gaussian_params))
Expand Down Expand Up @@ -1164,8 +1170,8 @@ def _calc_error(self, metric=None):
self.error_ = np.sqrt(((self.power_spectrum - self.fooofed_spectrum_) ** 2).mean())

else:
msg = "Error metric '{}' not understood or not implemented.".format(metric)
raise ValueError(msg)
error_msg = "Error metric '{}' not understood or not implemented.".format(metric)
raise ValueError(error_msg)


def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
Expand Down Expand Up @@ -1258,10 +1264,11 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
if self._check_data:
# Check if there are any infs / nans, and raise an error if so
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")
error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")
raise DataError(error_msg)

return freqs, power_spectrum, freq_range, freq_res

Expand Down
2 changes: 1 addition & 1 deletion fooof/objs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def combine_fooofs(fooofs):
fg.power_spectra = temp_power_spectra

# Set the check data mode, as True if any of the inputs have it on, False otherwise
fg.set_check_data_mode(any([getattr(f_obj, '_check_data') for f_obj in fooofs]))
fg.set_check_data_mode(any(getattr(f_obj, '_check_data') for f_obj in fooofs))

# Add data information information
fg.add_meta_data(fooofs[0].get_meta_data())
Expand Down