diff --git a/fooof/analysis/error.py b/fooof/analysis/error.py index 76e8eda5e..ceb85c2d0 100644 --- a/fooof/analysis/error.py +++ b/fooof/analysis/error.py @@ -102,7 +102,7 @@ def compute_pointwise_error_fg(fg, plot_errors=True, return_errors=False, **plt_ def compute_pointwise_error(model, data): - """Calculate pointwise error between original data and a model fit of that data. + """Calculate point-wise error between original data and a model fit of that data. Parameters ---------- diff --git a/fooof/bands/bands.py b/fooof/bands/bands.py index 2a03c6502..6a07a687a 100644 --- a/fooof/bands/bands.py +++ b/fooof/bands/bands.py @@ -1,4 +1,4 @@ -"""A data oject for managing band definitions.""" +"""A data object for managing band definitions.""" from collections import OrderedDict @@ -60,7 +60,7 @@ def __len__(self): return self.n_bands def __iter__(self): - """Define iteratation as stepping across each band.""" + """Define iteration as stepping across each band.""" for label, band_definition in self.bands.items(): yield (label, band_definition) diff --git a/fooof/core/info.py b/fooof/core/info.py index dbf344ee6..9cd9c3e49 100644 --- a/fooof/core/info.py +++ b/fooof/core/info.py @@ -70,7 +70,7 @@ def get_ap_indices(aperiodic_mode): Returns ------- indices : dict - Mapping of the column labels and indices for the aperiodc parameters. + Mapping of the column labels and indices for the aperiodic parameters. """ if aperiodic_mode == 'fixed': diff --git a/fooof/core/io.py b/fooof/core/io.py index c135f4933..037238ae1 100644 --- a/fooof/core/io.py +++ b/fooof/core/io.py @@ -198,7 +198,7 @@ def load_json(file_name, file_path): def load_jsonlines(file_name, file_path): - """Load a jsonlines file, yielding data line by line. + """Load a json-lines file, yielding data line by line. Parameters ---------- diff --git a/fooof/core/modutils.py b/fooof/core/modutils.py index 1027797bb..c342a52c9 100644 --- a/fooof/core/modutils.py +++ b/fooof/core/modutils.py @@ -79,7 +79,7 @@ def docs_append_to_section(docstring, section, add): Parameters ---------- - ds : str + docstring : str Docstring to update. section : str Name of the section within the docstring to add to. diff --git a/fooof/core/reports.py b/fooof/core/reports.py index a3dc69207..20323b1ce 100644 --- a/fooof/core/reports.py +++ b/fooof/core/reports.py @@ -16,6 +16,7 @@ REPORT_FONT = {'family': 'monospace', 'weight': 'normal', 'size': 16} +SAVE_FORMAT = 'pdf' ################################################################################################### ################################################################################################### @@ -61,7 +62,7 @@ def save_report_fm(fm, file_name, file_path=None, plt_log=False): ax2.set_yticks([]) # Save out the report - plt.savefig(fpath(file_path, fname(file_name, 'pdf'))) + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) plt.close() @@ -104,5 +105,5 @@ def save_report_fg(fg, file_name, file_path=None): plot_fg_peak_cens(fg, ax3) # Save out the report - plt.savefig(fpath(file_path, fname(file_name, 'pdf'))) + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) plt.close() diff --git a/fooof/core/utils.py b/fooof/core/utils.py index 2a224432b..f0fcdb588 100644 --- a/fooof/core/utils.py +++ b/fooof/core/utils.py @@ -210,7 +210,7 @@ def check_inds(inds): This function works only on indices defined for 1 dimension. """ - # Typcasting: if a single int, convert to an array + # Typecasting: if a single int, convert to an array if isinstance(inds, int): inds = np.array([inds]) # Typecasting: if a list or range, convert to an array diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py index e5b007b77..009671b55 100644 --- a/fooof/objs/fit.py +++ b/fooof/objs/fit.py @@ -238,14 +238,10 @@ def _reset_internal_settings(self): # Bandwidth limits are given in 2-sided peak bandwidth # Convert to gaussian std parameter limits self._gauss_std_limits = tuple([bwl / 2 for bwl in self.peak_width_limits]) - # Bounds for aperiodic fitting. Drops bounds on knee parameter if not set to fit knee - self._ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) # Otherwise, assume settings are unknown (have been cleared) and set to None else: self._gauss_std_limits = None - self._ap_bounds = None def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): @@ -286,7 +282,7 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res self._peak_fit = None - def add_data(self, freqs, power_spectrum, freq_range=None): + def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): """Add data (frequencies, and power spectrum values) to the current object. Parameters @@ -298,6 +294,9 @@ def add_data(self, freqs, power_spectrum, freq_range=None): freq_range : list of [float, float], optional Frequency range to restrict power spectrum to. If not provided, keeps the entire range. + clear_results : bool, optional, default: True + Whether to clear prior results, if any are present in the object. + This should only be set to False if data for the current results are being re-added. Notes ----- @@ -305,10 +304,12 @@ def add_data(self, freqs, power_spectrum, freq_range=None): they will be cleared by this method call. """ - # If any data is already present, then clear data & results + # If any data is already present, then clear previous data + # Also clear results, if present, unless indicated not to # This is to ensure object consistency of all data & results - if np.any(self.freqs): - self._reset_data_results(True, True, True) + self._reset_data_results(clear_freqs=self.has_data, + clear_spectrum=self.has_data, + clear_results=self.has_model and clear_results) self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \ self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose) @@ -717,6 +718,10 @@ def _simple_ap_fit(self, freqs, power_spectrum): np.log10(self.freqs[-1]) - np.log10(self.freqs[0])) if not self._ap_guess[2] else self._ap_guess[2]] + # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee + ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ + else tuple(bound[0::2] for bound in self._ap_bounds) + # Collect together guess parameters guess = np.array([off_guess + kne_guess + exp_guess]) @@ -729,7 +734,7 @@ def _simple_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), freqs, power_spectrum, p0=guess, - maxfev=self._maxfev, bounds=self._ap_bounds) + maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError: raise FitError("Model fitting failed due to not finding parameters in " "the simple aperiodic component fit.") @@ -774,6 +779,10 @@ def _robust_ap_fit(self, freqs, power_spectrum): freqs_ignore = freqs[perc_mask] spectrum_ignore = power_spectrum[perc_mask] + # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee + ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ + else tuple(bound[0::2] for bound in self._ap_bounds) + # Second aperiodic fit - using results of first fit as guess parameters # See note in _simple_ap_fit about warnings try: @@ -781,7 +790,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), freqs_ignore, spectrum_ignore, p0=popt, - maxfev=self._maxfev, bounds=self._ap_bounds) + maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError: raise FitError("Model fitting failed due to not finding " "parameters in the robust aperiodic fit.") @@ -851,7 +860,7 @@ def _fit_peaks(self, flat_iter): guess_std = compute_gauss_std(fwhm) except ValueError: - # This procedure can fail (extremely rarely), if both le & ri ind's end up as None + # This procedure can fail (very rarely), if both left & right inds end up as None # In this case, default the guess to the average of the peak width limits guess_std = np.mean(self.peak_width_limits) @@ -1027,21 +1036,21 @@ def _drop_peak_overlap(self, guess): Notes ----- For any gaussians with an overlap that crosses the threshold, - the lowest height guess guassian is dropped. + the lowest height guess Gaussian is dropped. """ - # Sort the peak guesses by increasing frequency, so adjacenent peaks can - # be compared from right to left. + # Sort the peak guesses by increasing frequency + # This is so adjacent peaks can be compared from right to left guess = sorted(guess, key=lambda x: float(x[0])) # Calculate standard deviation bounds for checking amount of overlap - # The bounds are the gaussian frequncy +/- gaussian standard deviation + # The bounds are the gaussian frequency +/- gaussian standard deviation bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh, peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess] # Loop through peak bounds, comparing current bound to that of next peak # If the left peak's upper bound extends pass the right peaks lower bound, - # Then drop the guassian with the lower height. + # then drop the Gaussian with the lower height drop_inds = [] for ind, b_0 in enumerate(bounds[:-1]): b_1 = bounds[ind + 1] diff --git a/fooof/objs/group.py b/fooof/objs/group.py index 8bd293d3a..937563242 100644 --- a/fooof/objs/group.py +++ b/fooof/objs/group.py @@ -557,7 +557,7 @@ def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" # Only check & warn on first power spectrum - # This is to avoid spamming stdout for every spectrum in the group + # This is to avoid spamming standard output for every spectrum in the group if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() diff --git a/fooof/objs/utils.py b/fooof/objs/utils.py index 253a4f13b..65bec0f5b 100644 --- a/fooof/objs/utils.py +++ b/fooof/objs/utils.py @@ -138,7 +138,7 @@ def combine_fooofs(fooofs): -------- Combine FOOOF objects together (where `fm1`, `fm2` & `fm3` are assumed to be defined and fit): - >>> fg = combine_fooofs([fm1, fm2, f3]) # doctest:+SKIP + >>> fg = combine_fooofs([fm1, fm2, fm3]) # doctest:+SKIP Combine FOOOFGroup objects together (where `fg1` & `fg2` are assumed to be defined and fit): diff --git a/fooof/plts/aperiodic.py b/fooof/plts/aperiodic.py index b20038b81..fafc1dd5e 100644 --- a/fooof/plts/aperiodic.py +++ b/fooof/plts/aperiodic.py @@ -70,9 +70,9 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp]. freq_range : list of [float, float] The frequency range to plot the peak fits across, as [f_min, f_max]. - control_offset : boolean, optonal, default: False + control_offset : boolean, optional, default: False Whether to control for the offset, by setting it to zero. - log_freqs : boolean, optonal, default: False + log_freqs : boolean, optional, default: False Whether to plot the x-axis in log space. colors : str or list of str, optional Color(s) to plot data. diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py index 6818e4840..68be6ea5b 100644 --- a/fooof/plts/fm.py +++ b/fooof/plts/fm.py @@ -283,7 +283,7 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs): Notes ----- - This line representents the bandwidth (width or gaussian standard deviation) of + This line represents the bandwidth (width or gaussian standard deviation) of the peak, though what is literally plotted is the full-width half-max. """ diff --git a/fooof/tests/analysis/test_error.py b/fooof/tests/analysis/test_error.py index d14307f9f..2a87e1512 100644 --- a/fooof/tests/analysis/test_error.py +++ b/fooof/tests/analysis/test_error.py @@ -11,7 +11,7 @@ def test_compute_pointwise_error_fm(tfm): assert np.all(errs) def test_compute_pointwise_error_fm_plt(tfm, skip_if_no_mpl): - """Run a seperate test to run with plot pass-through.""" + """Run a separate test to run with plot pass-through.""" compute_pointwise_error_fm(tfm, True, False) @@ -21,7 +21,7 @@ def test_compute_pointwise_error_fg(tfg): assert np.all(errs) def test_compute_pointwise_error_fg_plt(tfg, skip_if_no_mpl): - """Run a seperate test to run with plot pass-through.""" + """Run a separate test to run with plot pass-through.""" compute_pointwise_error_fg(tfg, True, False) diff --git a/fooof/tests/objs/test_fit.py b/fooof/tests/objs/test_fit.py index 430b55eed..6a620e893 100644 --- a/fooof/tests/objs/test_fit.py +++ b/fooof/tests/objs/test_fit.py @@ -226,36 +226,68 @@ def test_fooof_load(): for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfm, meta_dat) is not None -def test_adds(): - """Tests methods that add data to FOOOF objects. +def test_add_data(): + """Tests method to add data to FOOOF objects.""" - Checks: add_data, add_settings, add_results. - """ - - # Note: uses it's own tfm, to not add stuff to the global one + # This test uses it's own FOOOF object, to not add stuff to the global one tfm = get_tfm() - # Test adding data + # Test data for adding freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) + + # Test adding data tfm.add_data(freqs, pows) + assert tfm.has_data assert np.all(tfm.freqs == freqs) assert np.all(tfm.power_spectrum == np.log10(pows)) + # Test that prior data does not get cleared, when requesting not to clear + tfm._reset_data_results(True, True, True) + tfm.add_results(FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25])) + tfm.add_data(freqs, pows, clear_results=False) + assert tfm.has_data + assert tfm.has_model + + # Test that prior data does get cleared, when requesting not to clear + tfm._reset_data_results(True, True, True) + tfm.add_data(freqs, pows, clear_results=True) + assert tfm.has_data + assert not tfm.has_model + +def test_add_settings(): + """Tests method to add settings to FOOOF objects.""" + + # This test uses it's own FOOOF object, to not add stuff to the global one + tfm = get_tfm() + # Test adding settings fooof_settings = FOOOFSettings([1, 4], 6, 0, 2, 'fixed') tfm.add_settings(fooof_settings) for setting in OBJ_DESC['settings']: assert getattr(tfm, setting) == getattr(fooof_settings, setting) +def test_add_meta_data(): + """Tests method to add meta data to FOOOF objects.""" + + # This test uses it's own FOOOF object, to not add stuff to the global one + tfm = get_tfm() + # Test adding meta data fooof_meta_data = FOOOFMetaData([3, 40], 0.5) tfm.add_meta_data(fooof_meta_data) for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfm, meta_dat) == getattr(fooof_meta_data, meta_dat) +def test_add_results(): + """Tests method to add results to FOOOF objects.""" + + # This test uses it's own FOOOF object, to not add stuff to the global one + tfm = get_tfm() + # Test adding results fooof_results = FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]) tfm.add_results(fooof_results) + assert tfm.has_model for setting in OBJ_DESC['results']: assert getattr(tfm, setting) == getattr(fooof_results, setting.strip('_')) diff --git a/fooof/tests/plts/test_styles.py b/fooof/tests/plts/test_styles.py index 75b002d1b..f1b2f09af 100644 --- a/fooof/tests/plts/test_styles.py +++ b/fooof/tests/plts/test_styles.py @@ -9,7 +9,6 @@ def test_check_n_style(skip_if_no_mpl): # Check can pass None and do nothing check_n_style(None) - assert True # Check can pass a callable def checker(*args): diff --git a/fooof/tests/plts/test_templates.py b/fooof/tests/plts/test_templates.py index a1e1f3240..f2faa6808 100644 --- a/fooof/tests/plts/test_templates.py +++ b/fooof/tests/plts/test_templates.py @@ -12,25 +12,20 @@ @plot_test def test_plot_scatter_1(skip_if_no_mpl): - dat = np.random.randint(0, 100, 100) + data = np.random.randint(0, 100, 100) - plot_scatter_1(dat, 'label', 'title') + plot_scatter_1(data, 'label', 'title') @plot_test def test_plot_scatter_2(skip_if_no_mpl): - plt.close('all') + data1 = np.random.randint(0, 100, 100) + data2 = np.random.randint(0, 100, 100) - dat1 = np.random.randint(0, 100, 100) - dat2 = np.random.randint(0, 100, 100) - - plot_scatter_2(dat1, 'label1', dat2, 'label2', 'title') - - ax = plt.gca() - assert ax.has_data() + plot_scatter_2(data1, 'label1', data2, 'label2', 'title') @plot_test def test_plot_hist(skip_if_no_mpl): - dat = np.random.randint(0, 100, 100) - plot_hist(dat, 'label', 'title') + data = np.random.randint(0, 100, 100) + plot_hist(data, 'label', 'title') diff --git a/fooof/tests/plts/test_utils.py b/fooof/tests/plts/test_utils.py index 51a2117f4..e7b2c5abb 100644 --- a/fooof/tests/plts/test_utils.py +++ b/fooof/tests/plts/test_utils.py @@ -61,7 +61,7 @@ def test_check_plot_kwargs(skip_if_no_mpl): plot_kwargs_out = check_plot_kwargs(plot_kwargs, defaults) assert plot_kwargs_out == defaults - # Check it keeps orignal values, and updates to defaults parameters when missing + # Check it keeps original values, and updates to defaults parameters when missing plot_kwargs = {'alpha' : 0.5} defaults = {'alpha' : 1, 'linewidth' : 2} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) diff --git a/fooof/utils/io.py b/fooof/utils/io.py index d931bea1e..3c2449a46 100644 --- a/fooof/utils/io.py +++ b/fooof/utils/io.py @@ -1,4 +1,4 @@ -"""Utilities for input / ouput for data and models.""" +"""Utilities for input / output for data and models.""" ################################################################################################### ###################################################################################################