From e0fe53ec8992972cdea15188d61090cd36719372 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Tue, 26 Feb 2019 15:11:23 -0800 Subject: [PATCH 01/15] Make new data file, mv & rename FOOOFResults, add FOOOFSettings --- fooof/data.py | 48 +++++++++++++++++++++++++++++++++++++++ fooof/fit.py | 34 ++++++--------------------- fooof/group.py | 4 ++-- fooof/tests/test_group.py | 6 ++--- 4 files changed, 60 insertions(+), 32 deletions(-) create mode 100644 fooof/data.py diff --git a/fooof/data.py b/fooof/data.py new file mode 100644 index 000000000..9e3543baf --- /dev/null +++ b/fooof/data.py @@ -0,0 +1,48 @@ +""" """ + +from collections import namedtuple + +################################################################################################### +################################################################################################### + +FOOOFResults = namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params', + 'r_squared', 'error', 'gaussian_params']) + +FOOOFSettings = namedtuple('FOOOFSettings', ['peak_width_limits', 'max_n_peaks', + 'min_peak_amplitude', 'peak_threshold', + 'aperiodic_mode']) + +FOOOFResults.__doc__ = """\ +The resulting parameters and associated data of a FOOOF model fit. + +Attributes +---------- +aperiodic_params : 1d array, len 2 or 3 + Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. + The knee parameter is only included if aperiodic is fit with knee. Otherwise, length is 2. +peak_params : 2d array, shape=[n_peaks, 3] + Fitted parameter values for the peaks. Each row is a peak, as [CF, Amp, BW]. +r_squared : float + R-squared of the fit between the input power spectrum and the full model fit. +error : float + Root mean squared error of the full model fit. +gaussian_params : 2d array, shape=[n_peaks, 3] + Parameters that define the gaussian fit(s). Each row is a gaussian, as [mean, amp, std]. +""" + +FOOOFSettings.__doc__ = """\ +The user defined settings for a FOOOF object. + +Attributes +---------- +peak_width_limits : tuple of (float, float), optional, default: [0.5, 12.0] + Limits on possible peak width, as [lower_bound, upper_bound]. +max_n_peaks : int, optional, default: inf + Maximum number of gaussians to be fit in a single spectrum. +min_peak_amplitude : float, optional, default: 0 + Minimum amplitude threshold for a peak to be modeled. +peak_threshold : float, optional, default: 2.0 + Threshold for detecting peaks, units of standard deviation. +aperiodic_mode : {'fixed', 'knee'} + Which approach to take to fitting the aperiodic component. +""" \ No newline at end of file diff --git a/fooof/fit.py b/fooof/fit.py index 5747f45e6..ce9a812d0 100644 --- a/fooof/fit.py +++ b/fooof/fit.py @@ -33,13 +33,10 @@ import warnings from copy import deepcopy -from collections import namedtuple import numpy as np from scipy.optimize import curve_fit -from fooof.utils import trim_spectrum -from fooof.plts.fm import plot_fm from fooof.core.io import save_fm, load_json from fooof.core.reports import save_report_fm from fooof.core.funcs import gaussian_function, get_ap_func, infer_ap_func @@ -47,31 +44,14 @@ from fooof.core.modutils import copy_doc_func_to_method from fooof.core.strings import gen_settings_str, gen_results_str_fm, gen_issue_str, gen_wid_warn_str +from fooof.plts.fm import plot_fm +from fooof.utils import trim_spectrum +from fooof.data import FOOOFResults, FOOOFSettings from fooof.synth.gen import gen_freqs, gen_aperiodic, gen_peaks ################################################################################################### ################################################################################################### -FOOOFResult = namedtuple('FOOOFResult', ['aperiodic_params', 'peak_params', - 'r_squared', 'error', 'gaussian_params']) -FOOOFResult.__doc__ = """\ -The resulting parameters and associated data of a FOOOF model fit. - -Attributes ----------- -aperiodic_params : 1d array, len 2 or 3 - Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. - The knee parameter is only included if aperiodic is fit with knee. Otherwise, length is 2. -peak_params : 2d array, shape=[n_peaks, 3] - Fitted parameter values for the peaks. Each row is a peak, as [CF, Amp, BW]. -r_squared : float - R-squared of the fit between the input power spectrum and the full model fit. -error : float - Root mean squared error of the full model fit. -gaussian_params : 2d array, shape=[n_peaks, 3] - Parameters that define the gaussian fit(s). Each row is a gaussian, as [mean, amp, std]. -""" - class FOOOF(object): """Model the physiological power spectrum as a combination of aperiodic and periodic components. @@ -261,11 +241,11 @@ def add_data(self, freqs, power_spectrum, freq_range=None): def add_results(self, fooof_result, regenerate=False): - """Add results data back into object from a FOOOFResult object. + """Add results data back into object from a FOOOFResults object. Parameters ---------- - fooof_result : FOOOFResult + fooof_result : FOOOFResults An object containing the results from fitting a FOOOF model. regenerate : bool, optional, default: False Whether to regenerate the model fits from the given fit parameters. @@ -429,8 +409,8 @@ def print_report_issue(concise=False): def get_results(self): """Return model fit parameters and goodness of fit metrics.""" - return FOOOFResult(self.aperiodic_params_, self.peak_params_, self.r_squared_, - self.error_, self._gaussian_params) + return FOOOFResults(self.aperiodic_params_, self.peak_params_, self.r_squared_, + self.error_, self._gaussian_params) @copy_doc_func_to_method(plot_fm) diff --git a/fooof/group.py b/fooof/group.py index f60d59079..bdf11fba8 100644 --- a/fooof/group.py +++ b/fooof/group.py @@ -279,14 +279,14 @@ def get_fooof(self, ind, regenerate=False): Parameters ---------- ind : int - The index of the FOOOFResult in FOOOFGroup.group_results to load. + The index of the FOOOFResults in FOOOFGroup.group_results to load. regenerate : bool, optional, default: False Whether to regenerate the model fits from the given fit parameters. Returns ------- inst : FOOOF() object - The FOOOFResult data loaded into a FOOOF object. + The FOOOFResults data loaded into a FOOOF object. """ # Initialize a FOOOF object, with same settings as current FOOOFGroup diff --git a/fooof/tests/test_group.py b/fooof/tests/test_group.py index 3a8292e47..bee861332 100644 --- a/fooof/tests/test_group.py +++ b/fooof/tests/test_group.py @@ -11,7 +11,7 @@ import numpy as np from fooof.group import * -from fooof.fit import FOOOFResult +from fooof.fit import FOOOFResults from fooof.synth import gen_group_power_spectra from fooof.core.utils import get_obj_desc @@ -51,7 +51,7 @@ def test_fg_fit(): assert out assert len(out) == n_spectra - assert isinstance(out[0], FOOOFResult) + assert isinstance(out[0], FOOOFResults) assert np.all(out[1].aperiodic_params) def test_fg_fit_par(): @@ -66,7 +66,7 @@ def test_fg_fit_par(): assert out assert len(out) == n_spectra - assert isinstance(out[0], FOOOFResult) + assert isinstance(out[0], FOOOFResults) assert np.all(out[1].aperiodic_params) def test_fg_print(tfg): From 3600beede161766925d0c44700d12bb5cc598545 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Tue, 26 Feb 2019 15:19:01 -0800 Subject: [PATCH 02/15] Add a get_settings method, that uses FOOOFSettings --- fooof/fit.py | 21 ++++++++++++++++++++- fooof/tests/test_fit.py | 9 ++++++--- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/fooof/fit.py b/fooof/fit.py index ce9a812d0..2f7c67086 100644 --- a/fooof/fit.py +++ b/fooof/fit.py @@ -407,12 +407,31 @@ def print_report_issue(concise=False): def get_results(self): - """Return model fit parameters and goodness of fit metrics.""" + """Return model fit parameters and goodness of fit metrics. + + Returns + ------- + FOOOFResults + Object containing the FOOOF model fit results from the current FOOOF object. + """ return FOOOFResults(self.aperiodic_params_, self.peak_params_, self.r_squared_, self.error_, self._gaussian_params) + def get_settings(self): + """Return user defined settings of the FOOOF object. + + Returns + ------- + FOOOFSettings + Object containing the settings from the current FOOOF object. + """ + + return FOOOFSettings(self.peak_width_limits, self.max_n_peaks, self.min_peak_amplitude, + self.peak_threshold, self.aperiodic_mode) + + @copy_doc_func_to_method(plot_fm) def plot(self, plt_log=False, save_fig=False, file_name='FOOOF_plot', file_path=None, ax=None): diff --git a/fooof/tests/test_fit.py b/fooof/tests/test_fit.py index 9da8b0601..2fe6893d8 100644 --- a/fooof/tests/test_fit.py +++ b/fooof/tests/test_fit.py @@ -116,14 +116,17 @@ def test_copy(): def test_fooof_prints_get(tfm): """Test methods that print, return results (alias and pass through methods). - Checks: print_settings, print_results, get_results.""" + Checks: print_settings, print_results, get_results, get_settings.""" tfm.print_settings() tfm.print_results() tfm.print_report_issue() - out = tfm.get_results() - assert out + results = tfm.get_results() + assert results + + settings = tfm.get_settings() + assert settings @plot_test def test_fooof_plot(tfm, skip_if_no_mpl): From 058f3708290de6f98e6506ded690c3fd28b88b44 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Tue, 26 Feb 2019 15:28:43 -0800 Subject: [PATCH 03/15] Make consistent order for printing / checking settings --- fooof/core/strings.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/fooof/core/strings.py b/fooof/core/strings.py index 138104ff7..bb1d621e7 100644 --- a/fooof/core/strings.py +++ b/fooof/core/strings.py @@ -27,8 +27,8 @@ def gen_wid_warn_str(freq_res, bwl): output = '\n'.join([ '', - "FOOOF WARNING: Lower-bound peak width limit is < or ~= the frequency resolution: " + \ - "{:1.2f} <= {:1.2f}".format(freq_res, bwl), + 'FOOOF WARNING: Lower-bound peak width limit is < or ~= the frequency resolution: ' + \ + '{:1.2f} <= {:1.2f}'.format(freq_res, bwl), '\tLower bounds below frequency-resolution have no effect (effective lower bound is freq-res)', '\tToo low a limit may lead to overfitting noise as small bandwidth peaks.', '\tWe recommend a lower bound of approximately 2x the frequency resolution.', @@ -57,11 +57,11 @@ def gen_settings_str(f_obj, description=False, concise=False): """ # Parameter descriptions to print out, if requested - desc = {'aperiodic_mode' : 'The aproach taken to fitting the aperiodic component.', - 'peak_width_limits' : 'Enforced limits for peak widths, in Hz.', + desc = {'peak_width_limits' : 'Enforced limits for peak widths, in Hz.', 'max_n_peaks' : 'The maximum number of peaks that can be extracted.', - 'min_peak_amplitude' : "Minimum absolute amplitude of a peak, above aperiodic component.", - 'peak_threshold' : "Threshold at which to stop searching for peaks."} + 'min_peak_amplitude' : 'Minimum absolute amplitude of a peak, above aperiodic component.', + 'peak_threshold' : 'Threshold at which to stop searching for peaks.', + 'aperiodic_mode' : 'The aproach taken to fitting the aperiodic component.'} # Clear description for printing if not requested if not description: @@ -77,16 +77,16 @@ def gen_settings_str(f_obj, description=False, concise=False): '', # Settings - include descriptions if requested - *[el for el in ['Aperiodic Mode : {}'.format(f_obj.aperiodic_mode), - '{}'.format(desc['aperiodic_mode']), - 'Peak Width Limits : {}'.format(f_obj.peak_width_limits), + *[el for el in ['Peak Width Limits : {}'.format(f_obj.peak_width_limits), '{}'.format(desc['peak_width_limits']), 'Max Number of Peaks : {}'.format(f_obj.max_n_peaks), '{}'.format(desc['max_n_peaks']), 'Minimum Amplitude : {}'.format(f_obj.min_peak_amplitude), '{}'.format(desc['min_peak_amplitude']), 'Amplitude Threshold: {}'.format(f_obj.peak_threshold), - '{}'.format(desc['peak_threshold'])] if el != ''], + '{}'.format(desc['peak_threshold']), + 'Aperiodic Mode : {}'.format(f_obj.aperiodic_mode), + '{}'.format(desc['aperiodic_mode'])] if el != ''], # Footer '', @@ -280,7 +280,7 @@ def gen_issue_str(concise=False): 'If FOOOF gives you any weird / bad fits, please let us know!', 'To do so, send us a FOOOF report, and a FOOOF data file, ', '', - "With a FOOOF object (fm), after fitting, run the following commands:", + 'With a FOOOF object (fm), after fitting, run the following commands:', "fm.create_report('FOOOF_bad_fit_report')", "fm.save('FOOOF_bad_fit_data', True, True, True)", '', From 0c5f51db7caadeb705154086b31973e294e6cc79 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Tue, 26 Feb 2019 21:44:22 -0800 Subject: [PATCH 04/15] Consolidate on using tuples, & doc updates --- fooof/data.py | 18 +++++++++--------- fooof/fit.py | 20 ++++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/fooof/data.py b/fooof/data.py index 9e3543baf..5c1f21f33 100644 --- a/fooof/data.py +++ b/fooof/data.py @@ -1,4 +1,4 @@ -""" """ +"""Data objects for FOOOF.""" from collections import namedtuple @@ -17,16 +17,16 @@ Attributes ---------- -aperiodic_params : 1d array, len 2 or 3 +aperiodic_params : 1d array Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. - The knee parameter is only included if aperiodic is fit with knee. Otherwise, length is 2. -peak_params : 2d array, shape=[n_peaks, 3] + The knee parameter is only included if aperiodic is fit with knee. +peak_params : 2d array Fitted parameter values for the peaks. Each row is a peak, as [CF, Amp, BW]. r_squared : float R-squared of the fit between the input power spectrum and the full model fit. error : float Root mean squared error of the full model fit. -gaussian_params : 2d array, shape=[n_peaks, 3] +gaussian_params : 2d array Parameters that define the gaussian fit(s). Each row is a gaussian, as [mean, amp, std]. """ @@ -35,8 +35,8 @@ Attributes ---------- -peak_width_limits : tuple of (float, float), optional, default: [0.5, 12.0] - Limits on possible peak width, as [lower_bound, upper_bound]. +peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) + Limits on possible peak width, as (lower_bound, upper_bound). max_n_peaks : int, optional, default: inf Maximum number of gaussians to be fit in a single spectrum. min_peak_amplitude : float, optional, default: 0 @@ -44,5 +44,5 @@ peak_threshold : float, optional, default: 2.0 Threshold for detecting peaks, units of standard deviation. aperiodic_mode : {'fixed', 'knee'} - Which approach to take to fitting the aperiodic component. -""" \ No newline at end of file + Which approach to take for fitting the aperiodic component. +""" diff --git a/fooof/fit.py b/fooof/fit.py index 2f7c67086..bfb13120e 100644 --- a/fooof/fit.py +++ b/fooof/fit.py @@ -5,10 +5,10 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. Private attributes of the FOOOF method, not publicly exposed, are documented below. -Attributes (private) ----------- +Private Attributes +------------------ _spectrum_flat : 1d array - Flattened power spectrum (aperiodic component removed) + Flattened power spectrum (aperiodic component removed). _spectrum_peak_rm : 1d array Power spectrum with peaks removed (not flattened). _gaussian_params : 2d array @@ -24,7 +24,7 @@ _ap_bounds : tuple of tuple of float Upper and lower bounds on fitting aperiodic component. _bw_std_edge : float - Bandwidth threshold for edge rejection of peaks, in units of gaussian std. deviation. + Bandwidth threshold for edge rejection of peaks, in units of gaussian standard deviation. _gauss_overlap_thresh : float Degree of overlap (in units of guassian std. deviation) between gaussian guesses to drop one. _gauss_std_limits : list of [float, float] @@ -62,8 +62,8 @@ class FOOOF(object): Parameters ---------- - peak_width_limits : tuple of (float, float), optional, default: [0.5, 12.0] - Limits on possible peak width, as [lower_bound, upper_bound]. + peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) + Limits on possible peak width, as (lower_bound, upper_bound). max_n_peaks : int, optional, default: inf Maximum number of gaussians to be fit in a single spectrum. min_peak_amplitude : float, optional, default: 0 @@ -71,7 +71,7 @@ class FOOOF(object): peak_threshold : float, optional, default: 2.0 Threshold for detecting peaks, units of standard deviation. aperiodic_mode : {'fixed', 'knee'} - Which approach to take to fitting the aperiodic component. + Which approach to take for fitting the aperiodic component. verbose : boolean, optional, default: True Whether to be verbose in printing out warnings. @@ -110,7 +110,7 @@ class FOOOF(object): get smoother power spectra, as this will give better FOOOF fits. """ - def __init__(self, peak_width_limits=[0.5, 12.0], max_n_peaks=np.inf, min_peak_amplitude=0.0, + def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_amplitude=0.0, peak_threshold=2.0, aperiodic_mode='fixed', verbose=True): """Initialize FOOOF object with run parameters.""" @@ -134,7 +134,7 @@ def __init__(self, peak_width_limits=[0.5, 12.0], max_n_peaks=np.inf, min_peak_a self._ap_amp_thresh = 0.025 # Guess parameters for aperiodic fitting, [offset, knee, exponent] # If offset guess is None, the first value of the power spectrum is used as offset guess - self._ap_guess = [None, 0, 2] + self._ap_guess = (None, 0, 2) # Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, sl_low_bound), # (offset_high_bound, knee_high_bound, sl_high_bound)) # By default, aperiodic fitting is unbound, but can be restricted here, if desired @@ -166,7 +166,7 @@ def _reset_internal_settings(self): # Bandwidth limits are given in 2-sided peak bandwidth. # Convert to gaussian std parameter limits. - self._gauss_std_limits = [bwl / 2 for bwl in self.peak_width_limits] + self._gauss_std_limits = tuple([bwl / 2 for bwl in self.peak_width_limits]) # Bounds for aperiodic fitting. Drops bounds on knee parameter if not set to fit knee self._ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ else tuple(bound[0::2] for bound in self._ap_bounds) From f2e84990bebda19837b0e6b347924d3d41201c1e Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Wed, 27 Feb 2019 16:50:16 -0800 Subject: [PATCH 05/15] Update and reorg some setting and IO stuff --- fooof/fit.py | 106 +++++++++++++++++++++++++++++-------------------- fooof/group.py | 6 ++- 2 files changed, 67 insertions(+), 45 deletions(-) diff --git a/fooof/fit.py b/fooof/fit.py index bfb13120e..ff39909f9 100644 --- a/fooof/fit.py +++ b/fooof/fit.py @@ -121,11 +121,11 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_a raise ImportError('Scipy version of >= 0.19.0 required.') # Set input parameters - self.aperiodic_mode = aperiodic_mode self.peak_width_limits = peak_width_limits self.max_n_peaks = max_n_peaks self.min_peak_amplitude = min_peak_amplitude self.peak_threshold = peak_threshold + self.aperiodic_mode = aperiodic_mode self.verbose = verbose ## SETTINGS - these are updateable by the user if required. @@ -136,9 +136,9 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_a # If offset guess is None, the first value of the power spectrum is used as offset guess self._ap_guess = (None, 0, 2) # Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, sl_low_bound), - # (offset_high_bound, knee_high_bound, sl_high_bound)) - # By default, aperiodic fitting is unbound, but can be restricted here, if desired - # Even if fitting without knee, leave bounds for knee (they are dropped later) + # (offset_high_bound, knee_high_bound, sl_high_bound)) + # By default, aperiodic fitting is unbound, but can be restricted here, if desired + # Even if fitting without knee, leave bounds for knee (they are dropped later) self._ap_bounds = ((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)) # Threshold for how far (units of gaus std dev) a peak has to be from edge to keep. self._bw_std_edge = 1.0 @@ -240,15 +240,31 @@ def add_data(self, freqs, power_spectrum, freq_range=None): self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose) - def add_results(self, fooof_result, regenerate=False): - """Add results data back into object from a FOOOFResults object. + def add_settings(self, fooof_settings): + """Add settings into object from a FOOOFSettings object. + + Parameters + ---------- + fooof_settings : FOOOFSettings + An object containing the settings for a FOOOF model. + """ + + self.aperiodic_mode = fooof_settings.aperiodic_mode + self.peak_width_limits = fooof_settings.peak_width_limits + self.max_n_peaks = fooof_settings.max_n_peaks + self.min_peak_amplitude = fooof_settings.min_peak_amplitude + self.peak_threshold = fooof_settings.peak_threshold + + self._check_loaded_settings(fooof_settings._asdict()) + + + def add_results(self, fooof_result): + """Add results data into object from a FOOOFResults object. Parameters ---------- fooof_result : FOOOFResults An object containing the results from fitting a FOOOF model. - regenerate : bool, optional, default: False - Whether to regenerate the model fits from the given fit parameters. """ self.aperiodic_params_ = fooof_result.aperiodic_params @@ -257,8 +273,7 @@ def add_results(self, fooof_result, regenerate=False): self.error_ = fooof_result.error self._gaussian_params = fooof_result.gaussian_params - if regenerate: - self._regenerate_model() + self._check_loaded_results(fooof_result._asdict()) def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False): @@ -415,8 +430,8 @@ def get_results(self): Object containing the FOOOF model fit results from the current FOOOF object. """ - return FOOOFResults(self.aperiodic_params_, self.peak_params_, self.r_squared_, - self.error_, self._gaussian_params) + return FOOOFResults(self.aperiodic_params_, self.peak_params_, + self.r_squared_, self.error_, self._gaussian_params) def get_settings(self): @@ -428,8 +443,9 @@ def get_settings(self): Object containing the settings from the current FOOOF object. """ - return FOOOFSettings(self.peak_width_limits, self.max_n_peaks, self.min_peak_amplitude, - self.peak_threshold, self.aperiodic_mode) + return FOOOFSettings(self.peak_width_limits, self.max_n_peaks, + self.min_peak_amplitude, self.peak_threshold, + self.aperiodic_mode) @copy_doc_func_to_method(plot_fm) @@ -451,8 +467,8 @@ def save(self, file_name='FOOOF_results', file_path=None, append=False, save_fm(self, file_name, file_path, append, save_results, save_settings, save_data) - def load(self, file_name='FOOOF_results', file_path=None): - """Load in FOOOF file. Reads in a JSON file. + def load(self, file_name='FOOOF_results', file_path=None, regenerate=True): + """Load in FOOOF file. Reads in a FOOOF formatted JSON file. Parameters ---------- @@ -460,6 +476,8 @@ def load(self, file_name='FOOOF_results', file_path=None): File from which to load data. file_path : str, optional Path to directory from which to load. If not provided, loads from current directory. + regenerate : bool, optional, default: True + Whether to regenerate the model fit from the loaded data, if data is available. """ # Reset data in object, so old data can't interfere @@ -471,6 +489,13 @@ def load(self, file_name='FOOOF_results', file_path=None): self._check_loaded_settings(data) self._check_loaded_results(data) + # Regenerate model components, based on what's available + if regenerate: + if self.freq_res: + self._regenerate_freqs() + if np.all(self.freqs) and np.all(self.aperiodic_params_): + self._regenerate_model() + def copy(self): """Return a copy of the FOOOF object.""" @@ -486,29 +511,6 @@ def _check_width_limits(self): print(gen_wid_warn_str(self.freq_res, self.peak_width_limits[0])) - def _check_loaded_results(self, data, regenerate=True): - """Check if results added, check data, and regenerate model, if requested. - - Parameters - ---------- - data : dict - The dictionary of data that has been added to the object. - regenerate : bool, optional, default: True - Whether to regenerate the power_spectrum model. - """ - - # If results loaded, check dimensions of peak parameters - # This fixes an issue where they end up the wrong shape if they are empty (no peaks) - if set(get_obj_desc()['results']).issubset(set(data.keys())): - self.peak_params_ = check_array_dim(self.peak_params_) - self._gaussian_params = check_array_dim(self._gaussian_params) - - # Regenerate power_spectrum model & components - if regenerate: - if np.all(self.freqs) and np.all(self.aperiodic_params_): - self._regenerate_model() - - def _simple_ap_fit(self, freqs, power_spectrum): """Fit the aperiodic component of the power spectrum. @@ -929,9 +931,21 @@ def _add_from_dict(self, data): for key in data.keys(): setattr(self, key, data[key]) - # Reconstruct frequency vector, if data available to do so - if self.freq_res: - self.freqs = gen_freqs(self.freq_range, self.freq_res) + + def _check_loaded_results(self, data): + """Check if results have been added and check data. + + Parameters + ---------- + data : dict + A dictionary of data that has been added to the object. + """ + + # If results loaded, check dimensions of peak parameters + # This fixes an issue where they end up the wrong shape if they are empty (no peaks) + if set(get_obj_desc()['results']).issubset(set(data.keys())): + self.peak_params_ = check_array_dim(self.peak_params_) + self._gaussian_params = check_array_dim(self._gaussian_params) def _check_loaded_settings(self, data): @@ -940,7 +954,7 @@ def _check_loaded_settings(self, data): Parameters ---------- data : dict - The dictionary of data that has been added to the object. + A dictionary of data that has been added to the object. """ # If settings not loaded from file, clear from object, so that default @@ -960,6 +974,12 @@ def _check_loaded_settings(self, data): self._reset_internal_settings() + def _regenerate_freqs(self): + """Regenerate the frequency vector, given the object metadata.""" + + self.freqs = gen_freqs(self.freq_range, self.freq_res) + + def _regenerate_model(self): """Regenerate model fit from parameters.""" diff --git a/fooof/group.py b/fooof/group.py index bdf11fba8..895bbb7b8 100644 --- a/fooof/group.py +++ b/fooof/group.py @@ -266,7 +266,7 @@ def load(self, file_name='FOOOFGroup_results', file_path=None): if ind == 0: self._check_loaded_settings(data) - self._check_loaded_results(data, False) + self._check_loaded_results(data) self.group_results.append(self._get_results()) # Reset peripheral data from last loaded result, keeping freqs info @@ -302,7 +302,9 @@ def get_fooof(self, ind, regenerate=False): fm._add_from_dict({'freq_range': self.freq_range, 'freq_res': self.freq_res}) # Add results for specified power spectrum, regenerating full fit if requested - fm.add_results(self.group_results[ind], regenerate=regenerate) + fm.add_results(self.group_results[ind]) + if regenerate: + fm._regenerate_model() return fm From 28ed03383ce909ed50a81484b987643d30618a06 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Wed, 27 Feb 2019 17:01:27 -0800 Subject: [PATCH 06/15] Add tests of FOOOF data objects --- fooof/tests/test_data.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 fooof/tests/test_data.py diff --git a/fooof/tests/test_data.py b/fooof/tests/test_data.py new file mode 100644 index 000000000..204350ecc --- /dev/null +++ b/fooof/tests/test_data.py @@ -0,0 +1,31 @@ +"""Tests for the FOOOF data objects.""" + +from fooof.core.utils import get_obj_desc + +from fooof.data import * + +################################################################################################### +################################################################################################### + +def test_fooof_results(): + + results = FOOOFResults([], [], None, None, []) + assert results + + # Check that the object has the correct fields, given the object description + results_fields = get_obj_desc()['results'] + for field in results_fields: + getattr(results, field.strip('_')) + assert True + + +def test_fooof_settings(): + + settings = FOOOFSettings([], None, None, None, None) + assert settings + + # Check that the object has the correct fields, given the object description + settings_fields = get_obj_desc()['settings'] + for field in settings_fields: + getattr(settings, field) + assert True From b57d4071dc9d4db61a70ddf236a84f49c88c999a Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Thu, 28 Feb 2019 18:54:16 -0800 Subject: [PATCH 07/15] Move SynParams to data file --- fooof/data.py | 46 +++++++++++++++++++++----------- fooof/tests/test_data.py | 25 +++++++++++------ fooof/tests/test_synth_params.py | 4 --- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/fooof/data.py b/fooof/data.py index 5c1f21f33..726e2446c 100644 --- a/fooof/data.py +++ b/fooof/data.py @@ -5,13 +5,29 @@ ################################################################################################### ################################################################################################### -FOOOFResults = namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params', - 'r_squared', 'error', 'gaussian_params']) - FOOOFSettings = namedtuple('FOOOFSettings', ['peak_width_limits', 'max_n_peaks', 'min_peak_amplitude', 'peak_threshold', 'aperiodic_mode']) +FOOOFSettings.__doc__ = """\ +The user defined settings for a FOOOF object. + +Attributes +---------- +peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) + Limits on possible peak width, as (lower_bound, upper_bound). +max_n_peaks : int, optional, default: inf + Maximum number of gaussians to be fit in a single spectrum. +min_peak_amplitude : float, optional, default: 0 + Minimum amplitude threshold for a peak to be modeled. +peak_threshold : float, optional, default: 2.0 + Threshold for detecting peaks, units of standard deviation. +aperiodic_mode : {'fixed', 'knee'} + Which approach to take for fitting the aperiodic component. +""" + +FOOOFResults = namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params', + 'r_squared', 'error', 'gaussian_params']) FOOOFResults.__doc__ = """\ The resulting parameters and associated data of a FOOOF model fit. @@ -30,19 +46,19 @@ Parameters that define the gaussian fit(s). Each row is a gaussian, as [mean, amp, std]. """ -FOOOFSettings.__doc__ = """\ -The user defined settings for a FOOOF object. + +SynParams = namedtuple('SynParams', ['aperiodic_params', 'gaussian_params', 'nlv']) + +SynParams.__doc__ = """\ +Stores parameters used to synthesize a single power spectra. Attributes ---------- -peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) - Limits on possible peak width, as (lower_bound, upper_bound). -max_n_peaks : int, optional, default: inf - Maximum number of gaussians to be fit in a single spectrum. -min_peak_amplitude : float, optional, default: 0 - Minimum amplitude threshold for a peak to be modeled. -peak_threshold : float, optional, default: 2.0 - Threshold for detecting peaks, units of standard deviation. -aperiodic_mode : {'fixed', 'knee'} - Which approach to take for fitting the aperiodic component. +aperiodic_params : list, len 2 or 3 + Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. + The knee parameter is only included if aperiodic is fit with knee. Otherwise, length is 2. +gaussian_params : list or list of lists + Fitted parameter values for the peaks. Each list is a peak, as [CF, Amp, BW]. +nlv : float + Noise level added to the generated power spectrum. """ diff --git a/fooof/tests/test_data.py b/fooof/tests/test_data.py index 204350ecc..5337b7fc8 100644 --- a/fooof/tests/test_data.py +++ b/fooof/tests/test_data.py @@ -7,6 +7,17 @@ ################################################################################################### ################################################################################################### +def test_fooof_settings(): + + settings = FOOOFSettings([], None, None, None, None) + assert settings + + # Check that the object has the correct fields, given the object description + settings_fields = get_obj_desc()['settings'] + for field in settings_fields: + getattr(settings, field) + assert True + def test_fooof_results(): results = FOOOFResults([], [], None, None, []) @@ -18,14 +29,12 @@ def test_fooof_results(): getattr(results, field.strip('_')) assert True +def test_syn_params(): -def test_fooof_settings(): - - settings = FOOOFSettings([], None, None, None, None) - assert settings + syn_params = SynParams([1, 1], [10, 1, 1], 0.05) + assert syn_params - # Check that the object has the correct fields, given the object description - settings_fields = get_obj_desc()['settings'] - for field in settings_fields: - getattr(settings, field) + # Check that the object has the correct fields + for field in ['aperiodic_params', 'gaussian_params', 'nlv']: + getattr(syn_params, field) assert True diff --git a/fooof/tests/test_synth_params.py b/fooof/tests/test_synth_params.py index 77b01ecc6..b11a44f05 100644 --- a/fooof/tests/test_synth_params.py +++ b/fooof/tests/test_synth_params.py @@ -7,10 +7,6 @@ ################################################################################################### ################################################################################################### -def test_syn_params(): - - assert SynParams([1, 1], [10, 1, 1], 0.05) - def test_update_syn_ap_params(): syn_params = SynParams([1, 1], [10, 1, 1], 0.05) From dc888a8b73a835c5ea91e136bfcb84776ba06f26 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Thu, 28 Feb 2019 19:07:24 -0800 Subject: [PATCH 08/15] Update test related to settings and add/get updates --- fooof/tests/test_fit.py | 45 +++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/fooof/tests/test_fit.py b/fooof/tests/test_fit.py index 2fe6893d8..345eaa8f5 100644 --- a/fooof/tests/test_fit.py +++ b/fooof/tests/test_fit.py @@ -12,6 +12,7 @@ import pkg_resources as pkg from fooof import FOOOF +from fooof.data import FOOOFSettings, FOOOFResults from fooof.synth import gen_power_spectrum from fooof.core.utils import group_three, get_obj_desc @@ -105,6 +106,44 @@ def test_fooof_load(): tfm.load(file_name_res, file_path) assert tfm +def test_adds(): + """Tests methods that add data to FOOOF objects. + + Checks: add_data, add_settings, add_results. + """ + + # Note: uses it's own tfm, to not add stuff to the global one + tfm = get_tfm() + + # Test adding data + freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) + tfm.add_data(freqs, pows) + assert np.all(tfm.freqs == freqs) + assert np.all(tfm.power_spectrum == np.log10(pows)) + + # Test adding settings + fooof_settings = FOOOFSettings([1, 4], 6, 0, 2, 'fixed') + tfm.add_settings(fooof_settings) + for setting in get_obj_desc()['settings']: + assert getattr(tfm, setting) == getattr(fooof_settings, setting) + + # Test adding results + fooof_results = FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]) + tfm.add_results(fooof_results) + for setting in get_obj_desc()['results']: + assert getattr(tfm, setting) == getattr(fooof_results, setting.strip('_')) + +def test_gets(tfm): + """Tests methods that return FOOOF data objects. + + Checks: get_settings, get_results + """ + + settings = tfm.get_settings() + assert isinstance(settings, FOOOFSettings) + results = tfm.get_results() + assert isinstance(results, FOOOFResults) + def test_copy(): """Test copy FOOOF method.""" @@ -122,12 +161,6 @@ def test_fooof_prints_get(tfm): tfm.print_results() tfm.print_report_issue() - results = tfm.get_results() - assert results - - settings = tfm.get_settings() - assert settings - @plot_test def test_fooof_plot(tfm, skip_if_no_mpl): """Check the alias to plot FOOOF.""" From c2f32a00ed0d30f00f48244d6f76a15769767f9a Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Thu, 28 Feb 2019 19:09:15 -0800 Subject: [PATCH 09/15] Remove old code from moving SynParams -> data --- fooof/synth/params.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/fooof/synth/params.py b/fooof/synth/params.py index 8989619be..15a18d42d 100644 --- a/fooof/synth/params.py +++ b/fooof/synth/params.py @@ -1,31 +1,14 @@ """Classes & functions for managing parameter choices for synthesizing power spectra.""" -from collections import namedtuple - import numpy as np -from fooof.core.utils import check_flat, get_data_indices from fooof.core.funcs import infer_ap_func +from fooof.core.utils import check_flat, get_data_indices +from fooof.data import SynParams ################################################################################################### ################################################################################################### -SynParams = namedtuple('SynParams', ['aperiodic_params', 'gaussian_params', 'nlv']) - -SynParams.__doc__ = """\ -Stores parameters used to synthesize a single power spectra. - -Attributes ----------- -aperiodic_params : list, len 2 or 3 - Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. - The knee parameter is only included if aperiodic is fit with knee. Otherwise, length is 2. -gaussian_params : list or list of lists - Fitted parameter values for the peaks. Each list is a peak, as [CF, Amp, BW]. -nlv : float - Noise level added to the generated power spectrum. -""" - def update_syn_ap_params(syn_params, delta, field=None): """Update the aperiodic parameter definition in a SynParams object. From 5d089f9f1f86d916a12eebfa40af3d54c6c31aa1 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Mon, 4 Mar 2019 22:15:13 -0800 Subject: [PATCH 10/15] Update get & comp info, add FOOOFDataInfo to consolidate appraoch to data --- fooof/core/io.py | 2 +- fooof/core/utils.py | 2 +- fooof/data.py | 14 ++++ fooof/fit.py | 44 +++++----- fooof/funcs.py | 8 +- fooof/tests/test_fit.py | 13 +-- fooof/tests/test_funcs.py | 12 +-- fooof/tests/test_utils.py | 28 ++----- fooof/utils.py | 163 +++++++++++++++++++++++++------------- 9 files changed, 175 insertions(+), 111 deletions(-) diff --git a/fooof/core/io.py b/fooof/core/io.py index 633d31233..533764fb7 100644 --- a/fooof/core/io.py +++ b/fooof/core/io.py @@ -89,7 +89,7 @@ def save_fm(fm, file_name, file_path=None, append=False, # Set and select which variables to keep. Use a set to drop any potential overlap # Note that results also saves frequency information to be able to recreate freq vector attributes = get_obj_desc() - keep = set((attributes['results'] + attributes['freq_info'] if save_results else []) + \ + keep = set((attributes['results'] + attributes['data_info'] if save_results else []) + \ (attributes['settings'] if save_settings else []) + \ (attributes['data'] if save_data else [])) obj_dict = dict_select_keys(obj_dict, keep) diff --git a/fooof/core/utils.py b/fooof/core/utils.py index f8309f0b0..32df2bfa1 100644 --- a/fooof/core/utils.py +++ b/fooof/core/utils.py @@ -114,7 +114,7 @@ def get_obj_desc(): 'settings' : ['peak_width_limits', 'max_n_peaks', 'min_peak_amplitude', 'peak_threshold', 'aperiodic_mode'], 'data' : ['power_spectrum', 'freq_range', 'freq_res'], - 'freq_info' : ['freq_range', 'freq_res'], + 'data_info' : ['freq_range', 'freq_res'], 'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_', 'peak_params_', '_gaussian_params']} diff --git a/fooof/data.py b/fooof/data.py index 726e2446c..e9373758c 100644 --- a/fooof/data.py +++ b/fooof/data.py @@ -26,6 +26,20 @@ """ +FOOOFDataInfo = namedtuple('FOOOFDataInfo', ['freq_range', 'freq_res']) + +FOOOFDataInfo.__doc__ = """\ +Data related information for a FOOOF object. + +Attributes +---------- +freq_range : list of [float, float] + Frequency range of the power spectrum, as [lowest_freq, highest_freq]. +freq_res : float + Frequency resolution of the power spectrum. +""" + + FOOOFResults = namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params', 'r_squared', 'error', 'gaussian_params']) FOOOFResults.__doc__ = """\ diff --git a/fooof/fit.py b/fooof/fit.py index ff39909f9..f31838cda 100644 --- a/fooof/fit.py +++ b/fooof/fit.py @@ -46,7 +46,7 @@ from fooof.plts.fm import plot_fm from fooof.utils import trim_spectrum -from fooof.data import FOOOFResults, FOOOFSettings +from fooof.data import FOOOFResults, FOOOFSettings, FOOOFDataInfo from fooof.synth.gen import gen_freqs, gen_aperiodic, gen_peaks ################################################################################################### @@ -249,11 +249,8 @@ def add_settings(self, fooof_settings): An object containing the settings for a FOOOF model. """ - self.aperiodic_mode = fooof_settings.aperiodic_mode - self.peak_width_limits = fooof_settings.peak_width_limits - self.max_n_peaks = fooof_settings.max_n_peaks - self.min_peak_amplitude = fooof_settings.min_peak_amplitude - self.peak_threshold = fooof_settings.peak_threshold + for setting in get_obj_desc()['settings']: + setattr(self, setting, getattr(fooof_settings, setting)) self._check_loaded_settings(fooof_settings._asdict()) @@ -421,31 +418,40 @@ def print_report_issue(concise=False): print(gen_issue_str(concise)) - def get_results(self): - """Return model fit parameters and goodness of fit metrics. + def get_settings(self): + """Return user defined settings of the FOOOF object. Returns ------- - FOOOFResults - Object containing the FOOOF model fit results from the current FOOOF object. + FOOOFSettings + Object containing the settings from the current FOOOF object. """ - return FOOOFResults(self.aperiodic_params_, self.peak_params_, - self.r_squared_, self.error_, self._gaussian_params) + return FOOOFSettings(**{key : getattr(self, key) for key in get_obj_desc()['settings']}) - def get_settings(self): - """Return user defined settings of the FOOOF object. + def get_data_info(self): + """Return data information from the FOOOF object. Returns ------- - FOOOFSettings - Object containing the settings from the current FOOOF object. + FOOOFDataInfo + Object containing information about the data from the current FOOOF object. + """ + + return FOOOFDataInfo(**{key : getattr(self, key) for key in get_obj_desc()['data_info']}) + + + def get_results(self): + """Return model fit parameters and goodness of fit metrics. + + Returns + ------- + FOOOFResults + Object containing the FOOOF model fit results from the current FOOOF object. """ - return FOOOFSettings(self.peak_width_limits, self.max_n_peaks, - self.min_peak_amplitude, self.peak_threshold, - self.aperiodic_mode) + return FOOOFResults(**{key.strip('_') : getattr(self, key) for key in get_obj_desc()['results']}) @copy_doc_func_to_method(plot_fm) diff --git a/fooof/funcs.py b/fooof/funcs.py index 2f8dd445d..05cf18ac2 100644 --- a/fooof/funcs.py +++ b/fooof/funcs.py @@ -4,7 +4,7 @@ from fooof import FOOOFGroup from fooof.synth.gen import gen_freqs -from fooof.utils import get_settings, get_obj_desc, compare_settings, compare_data_info +from fooof.utils import get_obj_desc, compare_info ################################################################################################### ################################################################################################### @@ -24,12 +24,12 @@ def combine_fooofs(fooofs): """ # Compare settings - if not compare_settings(fooofs) or not compare_data_info(fooofs): + if not compare_info(fooofs, 'settings') or not compare_info(fooofs, 'data_info'): raise ValueError("These objects have incompatible settings or data," \ "and so cannot be combined.") # Initialize FOOOFGroup object, with settings derived from input objects - fg = FOOOFGroup(**get_settings(fooofs[0]), verbose=fooofs[0].verbose) + fg = FOOOFGroup(*fooofs[0].get_settings(), verbose=fooofs[0].verbose) fg.power_spectra = np.empty([0, len(fooofs[0].freqs)]) # Add FOOOF results from each FOOOF object to group @@ -44,7 +44,7 @@ def combine_fooofs(fooofs): fg.power_spectra = np.vstack([fg.power_spectra, f_obj.power_spectrum]) # Add data information information - for data_info in get_obj_desc()['freq_info']: + for data_info in get_obj_desc()['data_info']: setattr(fg, data_info, getattr(fooofs[0], data_info)) fg.freqs = gen_freqs(fg.freq_range, fg.freq_res) diff --git a/fooof/tests/test_fit.py b/fooof/tests/test_fit.py index 345eaa8f5..0052e8f8d 100644 --- a/fooof/tests/test_fit.py +++ b/fooof/tests/test_fit.py @@ -12,7 +12,7 @@ import pkg_resources as pkg from fooof import FOOOF -from fooof.data import FOOOFSettings, FOOOFResults +from fooof.data import FOOOFSettings, FOOOFDataInfo, FOOOFResults from fooof.synth import gen_power_spectrum from fooof.core.utils import group_three, get_obj_desc @@ -136,11 +136,13 @@ def test_adds(): def test_gets(tfm): """Tests methods that return FOOOF data objects. - Checks: get_settings, get_results + Checks: get_settings, get_data_info, get_results """ settings = tfm.get_settings() assert isinstance(settings, FOOOFSettings) + data_info = tfm.get_data_info() + assert isinstance(data_info, FOOOFDataInfo) results = tfm.get_results() assert isinstance(results, FOOOFResults) @@ -152,10 +154,11 @@ def test_copy(): assert tfm != ntfm -def test_fooof_prints_get(tfm): - """Test methods that print, return results (alias and pass through methods). +def test_fooof_prints(tfm): + """Test methods that print (alias and pass through methods). - Checks: print_settings, print_results, get_results, get_settings.""" + Checks: print_settings, print_results. + """ tfm.print_settings() tfm.print_results() diff --git a/fooof/tests/test_funcs.py b/fooof/tests/test_funcs.py index c5e922357..4be6607b0 100644 --- a/fooof/tests/test_funcs.py +++ b/fooof/tests/test_funcs.py @@ -4,7 +4,7 @@ import numpy as np -from fooof.utils import compare_settings +from fooof.utils import compare_info from fooof.group import FOOOFGroup from fooof.synth import gen_group_power_spectra @@ -23,7 +23,7 @@ def test_combine_fooofs(tfm, tfg): fg1 = combine_fooofs([tfm, tfm2]) assert fg1 assert len(fg1) == 2 - assert compare_settings([fg1, tfm]) + assert compare_info([fg1, tfm], 'settings') assert fg1.group_results[0] == tfm.get_results() assert fg1.group_results[-1] == tfm2.get_results() @@ -31,7 +31,7 @@ def test_combine_fooofs(tfm, tfg): fg2 = combine_fooofs([tfm, tfm2, tfm3]) assert fg2 assert len(fg2) == 3 - assert compare_settings([fg2, tfm]) + assert compare_info([fg2, tfm], 'settings') assert fg2.group_results[0] == tfm.get_results() assert fg2.group_results[-1] == tfm3.get_results() @@ -39,7 +39,7 @@ def test_combine_fooofs(tfm, tfg): nfg1 = combine_fooofs([tfg, tfg2]) assert nfg1 assert len(nfg1) == len(tfg) + len(tfg2) - assert compare_settings([nfg1, tfg, tfg2]) + assert compare_info([nfg1, tfg, tfg2], 'settings') assert nfg1.group_results[0] == tfg.group_results[0] assert nfg1.group_results[-1] == tfg2.group_results[-1] @@ -47,7 +47,7 @@ def test_combine_fooofs(tfm, tfg): nfg2 = combine_fooofs([tfg, tfg2, tfg3]) assert nfg2 assert len(nfg2) == len(tfg) + len(tfg2) + len(tfg3) - assert compare_settings([nfg2, tfg, tfg2, tfg3]) + assert compare_info([nfg2, tfg, tfg2, tfg3], 'settings') assert nfg2.group_results[0] == tfg.group_results[0] assert nfg2.group_results[-1] == tfg3.group_results[-1] @@ -55,7 +55,7 @@ def test_combine_fooofs(tfm, tfg): mfg3 = combine_fooofs([tfg, tfm, tfg2, tfm2]) assert mfg3 assert len(mfg3) == len(tfg) + 1 + len(tfg2) + 1 - assert compare_settings([tfg, tfm, tfg2, tfm2]) + assert compare_info([tfg, tfm, tfg2, tfm2], 'settings') assert mfg3.group_results[0] == tfg.group_results[0] assert mfg3.group_results[-1] == tfm2.get_results() diff --git a/fooof/tests/test_utils.py b/fooof/tests/test_utils.py index bc25163da..107101f01 100644 --- a/fooof/tests/test_utils.py +++ b/fooof/tests/test_utils.py @@ -17,35 +17,23 @@ def test_trim_spectrum(): assert np.array_equal(f_out, np.array([2., 3., 4.])) assert np.array_equal(p_out, np.array([3., 4., 5.])) -def test_get_settings(tfm, tfg): +def test_get_info(tfm, tfg): for f_obj in [tfm, tfg]: - assert get_settings(f_obj) + assert get_info(f_obj, 'settings') + assert get_info(f_obj, 'data_info') -def test_get_data_info(tfm, tfg): +def test_compare_info(tfm, tfg): for f_obj in [tfm, tfg]: - assert get_data_info(f_obj) -def test_compare_settings(tfm, tfg): - - for f_obj in [tfm, tfg]: f_obj2 = f_obj.copy() - assert compare_settings([f_obj, f_obj2]) - + assert compare_info([f_obj, f_obj2], 'settings') f_obj2.peak_width_limits = [2, 4] f_obj2._reset_internal_settings() + assert not compare_info([f_obj, f_obj2], 'settings') - assert not compare_settings([f_obj, f_obj2]) - -def test_compare_data_info(tfm, tfg): - - for f_obj in [tfm, tfg]: - f_obj2 = f_obj.copy() - - assert compare_data_info([f_obj, f_obj2]) - + assert compare_info([f_obj, f_obj2], 'data_info') f_obj2.freq_range = [5, 25] - - assert not compare_data_info([f_obj, f_obj2]) + assert not compare_info([f_obj, f_obj2], 'data_info') diff --git a/fooof/utils.py b/fooof/utils.py index 982d09a7c..e561684f4 100644 --- a/fooof/utils.py +++ b/fooof/utils.py @@ -1,4 +1,4 @@ -"""Utility functions for FOOOF.""" +"""Public utility & helper functions for FOOOF.""" import numpy as np @@ -44,89 +44,142 @@ def trim_spectrum(freqs, power_spectra, f_range): return freqs_ext, power_spectra_ext -def get_settings(f_obj): - """Get a dictionary of current settings from a FOOOF or FOOOFGroup object. +def get_info(f_obj, aspect): + """ Parameters ---------- f_obj : FOOOF or FOOOFGroup - FOOOF derived object to get settings from. + FOOOF derived object to get attributes from. + aspect : {'settings', 'data_info', 'results'} + Which set of attributes to compare the objects across. Returns ------- - dictionary - Settings for the input FOOOF derived object. + dict + xx """ - return {setting : getattr(f_obj, setting) for setting in get_obj_desc()['settings']} + return {key : getattr(f_obj, key) for key in get_obj_desc()[aspect]} -def get_data_info(f_obj): - """Get a dictionary of current data information from a FOOOF or FOOOFGroup object. +def compare_info(lst, aspect): + """Compare a specified aspect of FOOOF objects across instances. Parameters ---------- - f_obj : FOOOF or FOOOFGroup - FOOOF derived object to get data information from. + lst : list of FOOOF or FOOOFGroup objects + FOOOF related objects whose attibutes are to be compared. + aspect : {'setting', 'data_info'} + Which set of attributes to compare the objects across. Returns ------- - dictionary - Data information for the input FOOOF derived object. - - Notes - ----- - Data information for a FOOOF object is the frequency range and frequency resolution. + consistent : bool + Whether the settings are consistent across the input list of objects. """ - return {dat_info : getattr(f_obj, dat_info) for dat_info in get_obj_desc()['freq_info']} + # Check specified aspect of the objects are the same across instances + for f_obj_1, f_obj_2 in zip(lst[:-1], lst[1:]): + if getattr(f_obj_1, 'get_' + aspect)() != getattr(f_obj_2, 'get_' + aspect)(): + consistent = False + break + else: + consistent = True + return consistent -def compare_settings(lst): - """Compare the settings between FOOOF and/or FOOOFGroup objects. - Parameters - ---------- - lst : list of FOOOF or FOOOFGroup objects - FOOOF related objects whose settings are to be compared. - Returns - ------- - bool - Whether the settings are consistent across the input list of objects. - """ - # Check settings are the same across list of given objects - for ind, f_obj in enumerate(lst[:-1]): - if get_settings(f_obj) != get_settings(lst[ind+1]): - return False - # If no settings fail comparison, return that objects have consistent settings - return True +# def get_settings(f_obj): +# """Get a dictionary of current settings from a FOOOF or FOOOFGroup object. +# Parameters +# ---------- +# f_obj : FOOOF or FOOOFGroup +# FOOOF derived object to get settings from. -def compare_data_info(lst): - """Compare the data information between FOOOF and/or FOOOFGroup objects. +# Returns +# ------- +# dictionary +# Settings for the input FOOOF derived object. +# """ - Parameters - ---------- - lst : list of FOOOF or FOOOFGroup objects - FOOOF related objects whose settings are to be compared. +# return {setting : getattr(f_obj, setting) for setting in get_obj_desc()['settings']} - Returns - ------- - bool - Whether the data information is consistent across the input list of objects. - Notes - ----- - Data information for a FOOOF object is the frequency range and frequency resolution. - """ +# def get_data_info(f_obj): +# """Get a dictionary of current data information from a FOOOF or FOOOFGroup object. + +# Parameters +# ---------- +# f_obj : FOOOF or FOOOFGroup +# FOOOF derived object to get data information from. + +# Returns +# ------- +# dictionary +# Data information for the input FOOOF derived object. + +# Notes +# ----- +# Data information for a FOOOF object is the frequency range and frequency resolution. +# """ + +# return {dat_info : getattr(f_obj, dat_info) for dat_info in get_obj_desc()['data_info']} + + +# def compare_settings(lst): +# """Compare the settings between FOOOF and/or FOOOFGroup objects. + +# Parameters +# ---------- +# lst : list of FOOOF or FOOOFGroup objects +# FOOOF related objects whose settings are to be compared. + +# Returns +# ------- +# consistent : bool +# Whether the settings are consistent across the input list of objects. +# """ + +# # Check settings are the same across list of given objects +# for f_obj_1, f_obj_2 in zip(lst[:-1], lst[1:]): +# if f_obj_1.get_settings() != f_obj_2.get_settings(): +# consistent = False +# break +# else: +# consistent = True + +# return consistent + + +# def compare_data_info(lst): +# """Compare the data information between FOOOF and/or FOOOFGroup objects. + +# Parameters +# ---------- +# lst : list of FOOOF or FOOOFGroup objects +# FOOOF related objects whose settings are to be compared. + +# Returns +# ------- +# consistent : bool +# Whether the data information is consistent across the input list of objects. + +# Notes +# ----- +# Data information for a FOOOF object is the frequency range and frequency resolution. +# """ - # Check data information is the same across the list of given objects - for ind, f_obj in enumerate(lst[:-1]): - if get_data_info(f_obj) != get_data_info(lst[ind+1]): - return False +# # Check data information is the same across the list of given objects +# for f_obj_1, f_obj_2 in zip(lst[:-1], lst[1:]): +# if get_data_info(f_obj_1) != get_data_info(f_obj_2): +# consistent = False +# break +# else: +# consistent = True - # If no data info comparisons fail, return that objects have consistent information - return True +# return consistent From 848e2427026941fcbeef4714633681ca663990b1 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Mon, 4 Mar 2019 22:20:41 -0800 Subject: [PATCH 11/15] Finish clean up for moving to get_ & compare_ info funcs --- fooof/tests/test_utils.py | 1 + fooof/utils.py | 99 +-------------------------------------- 2 files changed, 3 insertions(+), 97 deletions(-) diff --git a/fooof/tests/test_utils.py b/fooof/tests/test_utils.py index 107101f01..574628936 100644 --- a/fooof/tests/test_utils.py +++ b/fooof/tests/test_utils.py @@ -22,6 +22,7 @@ def test_get_info(tfm, tfg): for f_obj in [tfm, tfg]: assert get_info(f_obj, 'settings') assert get_info(f_obj, 'data_info') + assert get_info(f_obj, 'results') def test_compare_info(tfm, tfg): diff --git a/fooof/utils.py b/fooof/utils.py index e561684f4..341fc42de 100644 --- a/fooof/utils.py +++ b/fooof/utils.py @@ -45,7 +45,7 @@ def trim_spectrum(freqs, power_spectra, f_range): def get_info(f_obj, aspect): - """ + """Get a specified selection of information from a FOOOF derived object. Parameters ---------- @@ -57,7 +57,7 @@ def get_info(f_obj, aspect): Returns ------- dict - xx + The set of specified info from the FOOOF derived object. """ return {key : getattr(f_obj, key) for key in get_obj_desc()[aspect]} @@ -88,98 +88,3 @@ def compare_info(lst, aspect): consistent = True return consistent - - - - - -# def get_settings(f_obj): -# """Get a dictionary of current settings from a FOOOF or FOOOFGroup object. - -# Parameters -# ---------- -# f_obj : FOOOF or FOOOFGroup -# FOOOF derived object to get settings from. - -# Returns -# ------- -# dictionary -# Settings for the input FOOOF derived object. -# """ - -# return {setting : getattr(f_obj, setting) for setting in get_obj_desc()['settings']} - - -# def get_data_info(f_obj): -# """Get a dictionary of current data information from a FOOOF or FOOOFGroup object. - -# Parameters -# ---------- -# f_obj : FOOOF or FOOOFGroup -# FOOOF derived object to get data information from. - -# Returns -# ------- -# dictionary -# Data information for the input FOOOF derived object. - -# Notes -# ----- -# Data information for a FOOOF object is the frequency range and frequency resolution. -# """ - -# return {dat_info : getattr(f_obj, dat_info) for dat_info in get_obj_desc()['data_info']} - - -# def compare_settings(lst): -# """Compare the settings between FOOOF and/or FOOOFGroup objects. - -# Parameters -# ---------- -# lst : list of FOOOF or FOOOFGroup objects -# FOOOF related objects whose settings are to be compared. - -# Returns -# ------- -# consistent : bool -# Whether the settings are consistent across the input list of objects. -# """ - -# # Check settings are the same across list of given objects -# for f_obj_1, f_obj_2 in zip(lst[:-1], lst[1:]): -# if f_obj_1.get_settings() != f_obj_2.get_settings(): -# consistent = False -# break -# else: -# consistent = True - -# return consistent - - -# def compare_data_info(lst): -# """Compare the data information between FOOOF and/or FOOOFGroup objects. - -# Parameters -# ---------- -# lst : list of FOOOF or FOOOFGroup objects -# FOOOF related objects whose settings are to be compared. - -# Returns -# ------- -# consistent : bool -# Whether the data information is consistent across the input list of objects. - -# Notes -# ----- -# Data information for a FOOOF object is the frequency range and frequency resolution. -# """ - -# # Check data information is the same across the list of given objects -# for f_obj_1, f_obj_2 in zip(lst[:-1], lst[1:]): -# if get_data_info(f_obj_1) != get_data_info(f_obj_2): -# consistent = False -# break -# else: -# consistent = True - -# return consistent From 15841f7836ff9848ae60f8b87845951f437c2561 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Mon, 4 Mar 2019 22:32:13 -0800 Subject: [PATCH 12/15] Move info & index related tools to their own internal file --- fooof/core/info.py | 47 ++++++++++++++++++++++++++++++++++ fooof/core/io.py | 3 ++- fooof/core/utils.py | 47 ---------------------------------- fooof/fit.py | 3 ++- fooof/funcs.py | 3 ++- fooof/group.py | 2 +- fooof/synth/params.py | 3 ++- fooof/tests/test_core_info.py | 31 ++++++++++++++++++++++ fooof/tests/test_core_io.py | 2 +- fooof/tests/test_core_utils.py | 27 ------------------- fooof/tests/test_data.py | 2 +- fooof/tests/test_fit.py | 3 ++- fooof/tests/test_group.py | 2 +- fooof/utils.py | 2 +- 14 files changed, 93 insertions(+), 84 deletions(-) create mode 100644 fooof/core/info.py create mode 100644 fooof/tests/test_core_info.py diff --git a/fooof/core/info.py b/fooof/core/info.py new file mode 100644 index 000000000..23c39dca5 --- /dev/null +++ b/fooof/core/info.py @@ -0,0 +1,47 @@ +"""Internal functions to manage info related to FOOOF objects.""" + +def get_obj_desc(): + """Get dictionary specifying FOOOF object names and kind of attributes. + + Returns + ------- + attibutes : dict + Mapping of FOOOF object attributes, and what kind of data they are. + """ + + attributes = {'results' : ['aperiodic_params_', 'peak_params_', 'error_', + 'r_squared_', '_gaussian_params'], + 'settings' : ['peak_width_limits', 'max_n_peaks', 'min_peak_amplitude', + 'peak_threshold', 'aperiodic_mode'], + 'data' : ['power_spectrum', 'freq_range', 'freq_res'], + 'data_info' : ['freq_range', 'freq_res'], + 'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_', + 'peak_params_', '_gaussian_params']} + + return attributes + + +def get_data_indices(aperiodic_mode): + """Get a dictionary mapping the column labels to indices in FOOOF data (FOOOFResults). + + Parameters + ---------- + aperiodic_mode : {'fixed', 'knee'} + Which approach taken to fit the aperiodic component. + + Returns + ------- + indices : dict + Mapping for data columns to the column indices in which they appear. + """ + + indices = { + 'CF' : 0, + 'Amp' : 1, + 'BW' : 2, + 'offset' : 0, + 'knee' : 1 if aperiodic_mode == 'knee' else None, + 'exponent' : 1 if aperiodic_mode == 'fixed' else 2 + } + + return indices diff --git a/fooof/core/io.py b/fooof/core/io.py index 533764fb7..417b707ce 100644 --- a/fooof/core/io.py +++ b/fooof/core/io.py @@ -5,7 +5,8 @@ import json from json import JSONDecodeError -from fooof.core.utils import dict_array_to_lst, dict_select_keys, dict_lst_to_array, get_obj_desc +from fooof.core.info import get_obj_desc +from fooof.core.utils import dict_array_to_lst, dict_select_keys, dict_lst_to_array ################################################################################################### ################################################################################################### diff --git a/fooof/core/utils.py b/fooof/core/utils.py index 32df2bfa1..abd285547 100644 --- a/fooof/core/utils.py +++ b/fooof/core/utils.py @@ -100,53 +100,6 @@ def check_array_dim(arr): return np.empty([0, 3]) if arr.ndim == 1 else arr -def get_obj_desc(): - """Get dictionary specifying FOOOF object names and kind of attributes. - - Returns - ------- - attibutes : dict - Mapping of FOOOF object attributes, and what kind of data they are. - """ - - attributes = {'results' : ['aperiodic_params_', 'peak_params_', 'error_', - 'r_squared_', '_gaussian_params'], - 'settings' : ['peak_width_limits', 'max_n_peaks', 'min_peak_amplitude', - 'peak_threshold', 'aperiodic_mode'], - 'data' : ['power_spectrum', 'freq_range', 'freq_res'], - 'data_info' : ['freq_range', 'freq_res'], - 'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_', - 'peak_params_', '_gaussian_params']} - - return attributes - - -def get_data_indices(aperiodic_mode): - """Get a dictionary mapping the column labels to indices in FOOOF data (FOOOFResults). - - Parameters - ---------- - aperiodic_mode : {'fixed', 'knee'} - Which approach taken to fit the aperiodic component. - - Returns - ------- - indices : dict - Mapping for data columns to the column indices in which they appear. - """ - - indices = { - 'CF' : 0, - 'Amp' : 1, - 'BW' : 2, - 'offset' : 0, - 'knee' : 1 if aperiodic_mode == 'knee' else None, - 'exponent' : 1 if aperiodic_mode == 'fixed' else 2 - } - - return indices - - def check_iter(obj, length): """Check an object to ensure that it is iterable, and make it iterable if not. diff --git a/fooof/fit.py b/fooof/fit.py index f31838cda..8f9c1d240 100644 --- a/fooof/fit.py +++ b/fooof/fit.py @@ -40,7 +40,8 @@ from fooof.core.io import save_fm, load_json from fooof.core.reports import save_report_fm from fooof.core.funcs import gaussian_function, get_ap_func, infer_ap_func -from fooof.core.utils import group_three, check_array_dim, get_obj_desc +from fooof.core.utils import group_three, check_array_dim +from fooof.core.info import get_obj_desc from fooof.core.modutils import copy_doc_func_to_method from fooof.core.strings import gen_settings_str, gen_results_str_fm, gen_issue_str, gen_wid_warn_str diff --git a/fooof/funcs.py b/fooof/funcs.py index 05cf18ac2..bcb5ec981 100644 --- a/fooof/funcs.py +++ b/fooof/funcs.py @@ -4,7 +4,8 @@ from fooof import FOOOFGroup from fooof.synth.gen import gen_freqs -from fooof.utils import get_obj_desc, compare_info +from fooof.utils import compare_info +from fooof.core.info import get_obj_desc ################################################################################################### ################################################################################################### diff --git a/fooof/group.py b/fooof/group.py index 895bbb7b8..8fc42f4a3 100644 --- a/fooof/group.py +++ b/fooof/group.py @@ -16,7 +16,7 @@ from fooof.core.reports import save_report_fg from fooof.core.strings import gen_results_str_fg from fooof.core.io import save_fg, load_jsonlines -from fooof.core.utils import get_data_indices +from fooof.core.info import get_data_indices from fooof.core.modutils import copy_doc_func_to_method, copy_doc_class, safe_import ################################################################################################### diff --git a/fooof/synth/params.py b/fooof/synth/params.py index 15a18d42d..47ceb386b 100644 --- a/fooof/synth/params.py +++ b/fooof/synth/params.py @@ -3,7 +3,8 @@ import numpy as np from fooof.core.funcs import infer_ap_func -from fooof.core.utils import check_flat, get_data_indices +from fooof.core.utils import check_flat +from fooof.core.info import get_data_indices from fooof.data import SynParams ################################################################################################### diff --git a/fooof/tests/test_core_info.py b/fooof/tests/test_core_info.py new file mode 100644 index 000000000..f7d3085e5 --- /dev/null +++ b/fooof/tests/test_core_info.py @@ -0,0 +1,31 @@ +"""Tests for FOOOF core.info.""" + +from fooof.core.info import * + +################################################################################################### +################################################################################################### + +def test_get_obj_desc(tfm): + + desc = get_obj_desc() + objs = dir(tfm) + + # Test that everything in dict is a valid component of the fooof object + for ke, va in desc.items(): + for it in va: + assert it in objs + +def test_get_data_indices(): + + indices_fixed = get_data_indices('fixed') + assert indices_fixed + for ke, va in indices_fixed.items(): + if ke == 'knee': + assert not va + else: + assert isinstance(va, int) + + indices_knee = get_data_indices('knee') + assert indices_knee + for ke, va in indices_knee.items(): + assert isinstance(va, int) diff --git a/fooof/tests/test_core_io.py b/fooof/tests/test_core_io.py index da832d06c..3510fc2b2 100644 --- a/fooof/tests/test_core_io.py +++ b/fooof/tests/test_core_io.py @@ -4,7 +4,7 @@ import pkg_resources as pkg from fooof import FOOOF -from fooof.core.utils import get_obj_desc +from fooof.core.info import get_obj_desc from fooof.core.io import * diff --git a/fooof/tests/test_core_utils.py b/fooof/tests/test_core_utils.py index 3e09bfa40..98c3aa27b 100644 --- a/fooof/tests/test_core_utils.py +++ b/fooof/tests/test_core_utils.py @@ -59,33 +59,6 @@ def test_check_array_dim(): out = check_array_dim(np.array([1, 2, 3])) assert out.shape == (0, 3) -def test_get_obj_desc(): - - desc = get_obj_desc() - - tfm = FOOOF() - objs = dir(tfm) - - # Test that everything in dict is a valid component of the fooof object - for ke, va in desc.items(): - for it in va: - assert it in objs - -def test_get_data_indices(): - - indices_fixed = get_data_indices('fixed') - assert indices_fixed - for ke, va in indices_fixed.items(): - if ke == 'knee': - assert not va - else: - assert isinstance(va, int) - - indices_knee = get_data_indices('knee') - assert indices_knee - for ke, va in indices_knee.items(): - assert isinstance(va, int) - def test_check_iter(): # Note: generator case not tested diff --git a/fooof/tests/test_data.py b/fooof/tests/test_data.py index 5337b7fc8..cb9575c81 100644 --- a/fooof/tests/test_data.py +++ b/fooof/tests/test_data.py @@ -1,6 +1,6 @@ """Tests for the FOOOF data objects.""" -from fooof.core.utils import get_obj_desc +from fooof.core.info import get_obj_desc from fooof.data import * diff --git a/fooof/tests/test_fit.py b/fooof/tests/test_fit.py index 0052e8f8d..f33dfbdad 100644 --- a/fooof/tests/test_fit.py +++ b/fooof/tests/test_fit.py @@ -14,7 +14,8 @@ from fooof import FOOOF from fooof.data import FOOOFSettings, FOOOFDataInfo, FOOOFResults from fooof.synth import gen_power_spectrum -from fooof.core.utils import group_three, get_obj_desc +from fooof.core.utils import group_three +from fooof.core.info import get_obj_desc from fooof.tests.utils import get_tfm, plot_test diff --git a/fooof/tests/test_group.py b/fooof/tests/test_group.py index bee861332..1472b719a 100644 --- a/fooof/tests/test_group.py +++ b/fooof/tests/test_group.py @@ -13,7 +13,7 @@ from fooof.group import * from fooof.fit import FOOOFResults from fooof.synth import gen_group_power_spectra -from fooof.core.utils import get_obj_desc +from fooof.core.info import get_obj_desc from fooof.tests.utils import default_group_params, plot_test diff --git a/fooof/utils.py b/fooof/utils.py index 341fc42de..8e5140855 100644 --- a/fooof/utils.py +++ b/fooof/utils.py @@ -3,7 +3,7 @@ import numpy as np from fooof.synth import gen_freqs -from fooof.core.utils import get_obj_desc +from fooof.core.info import get_obj_desc ################################################################################################### ################################################################################################### From 3af502bd461610db47fd708edea1231924a9ab66 Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Tue, 5 Mar 2019 00:21:56 -0800 Subject: [PATCH 13/15] Update & fix combine_fooofs & consolidate approach with add_data_info --- fooof/fit.py | 31 +++++++++++++++++++++++-------- fooof/funcs.py | 19 ++++++++++++------- fooof/group.py | 31 +++++++++++++++++++++++++------ fooof/tests/test_fit.py | 4 ++-- 4 files changed, 62 insertions(+), 23 deletions(-) diff --git a/fooof/fit.py b/fooof/fit.py index 8f9c1d240..1effa50f4 100644 --- a/fooof/fit.py +++ b/fooof/fit.py @@ -200,16 +200,16 @@ def _reset_data_results(self, clear_freqs=True, clear_spectrum=True, clear_resul self.power_spectrum = None if clear_results: + self.aperiodic_params_ = None + self.peak_params_ = None + self.r_squared_ = None + self.error_ = None + self._gaussian_params = None + self.fooofed_spectrum_ = None - self.aperiodic_params_ = np.array([np.nan, np.nan]) if \ - self.aperiodic_mode == 'fixed' else np.array([np.nan, np.nan, np.nan]) - self.peak_params_ = np.array([np.nan, np.nan, np.nan]) - self.r_squared_ = np.nan - self.error_ = np.nan self._spectrum_flat = None self._spectrum_peak_rm = None - self._gaussian_params = np.array([np.nan, np.nan, np.nan]) self._ap_fit = None self._peak_fit = None @@ -247,7 +247,7 @@ def add_settings(self, fooof_settings): Parameters ---------- fooof_settings : FOOOFSettings - An object containing the settings for a FOOOF model. + A FOOOF data object containing the settings for a FOOOF model. """ for setting in get_obj_desc()['settings']: @@ -256,13 +256,28 @@ def add_settings(self, fooof_settings): self._check_loaded_settings(fooof_settings._asdict()) + def add_data_info(self, fooof_data_info): + """Add data information into object from a FOOOFDataInfo object. + + Parameters + ---------- + fooof_data_info : FOOOFDataInfo + A FOOOF data object containing information about the data. + """ + + for data_info in get_obj_desc()['data_info']: + setattr(self, data_info, getattr(fooof_data_info, data_info)) + + self._regenerate_freqs() + + def add_results(self, fooof_result): """Add results data into object from a FOOOFResults object. Parameters ---------- fooof_result : FOOOFResults - An object containing the results from fitting a FOOOF model. + A FOOOF data object containing the results from fitting a FOOOF model. """ self.aperiodic_params_ = fooof_result.aperiodic_params diff --git a/fooof/funcs.py b/fooof/funcs.py index bcb5ec981..7ed5e9ae8 100644 --- a/fooof/funcs.py +++ b/fooof/funcs.py @@ -5,7 +5,6 @@ from fooof import FOOOFGroup from fooof.synth.gen import gen_freqs from fooof.utils import compare_info -from fooof.core.info import get_obj_desc ################################################################################################### ################################################################################################### @@ -31,23 +30,29 @@ def combine_fooofs(fooofs): # Initialize FOOOFGroup object, with settings derived from input objects fg = FOOOFGroup(*fooofs[0].get_settings(), verbose=fooofs[0].verbose) - fg.power_spectra = np.empty([0, len(fooofs[0].freqs)]) + + # Use a temporary store to collect spectra, because we only add them if consistently present + temp_power_spectra = np.empty([0, len(fooofs[0].freqs)]) # Add FOOOF results from each FOOOF object to group for f_obj in fooofs: # Add FOOOFGroup object if isinstance(f_obj, FOOOFGroup): fg.group_results.extend(f_obj.group_results) - fg.power_spectra = np.vstack([fg.power_spectra, f_obj.power_spectra]) + if f_obj.power_spectra is not None: + temp_power_spectra = np.vstack([temp_power_spectra, f_obj.power_spectra]) # Add FOOOF object else: fg.group_results.append(f_obj.get_results()) - fg.power_spectra = np.vstack([fg.power_spectra, f_obj.power_spectrum]) + if f_obj.power_spectrum is not None: + temp_power_spectra = np.vstack([temp_power_spectra, f_obj.power_spectrum]) + + # If the number of collected power spectra is consistent, then add them to object + if len(fg) == temp_power_spectra.shape[0]: + fg.power_spectra = temp_power_spectra # Add data information information - for data_info in get_obj_desc()['data_info']: - setattr(fg, data_info, getattr(fooofs[0], data_info)) - fg.freqs = gen_freqs(fg.freq_range, fg.freq_res) + fg.add_data_info(fooofs[0].get_data_info()) return fg diff --git a/fooof/group.py b/fooof/group.py index 8fc42f4a3..2b23afa5c 100644 --- a/fooof/group.py +++ b/fooof/group.py @@ -38,7 +38,7 @@ def __init__(self, *args, **kwargs): FOOOF.__init__(self, *args, **kwargs) - self.power_spectra = np.array([]) + self.power_spectra = None#np.array([]) self._reset_group_results() @@ -59,6 +59,26 @@ def __getitem__(self, index): return self.group_results[index] + def _reset_data_results(self, clear_freqs=True, clear_spectrum=True, clear_results=True, clear_spectra=True): + """Set (or reset) data & results attributes to empty. + + Parameters + ---------- + clear_freqs : bool, optional, default: True + Whether to clear frequency attributes. + clear_power_spectrum : bool, optional, default: True + Whether to clear power spectrum attribute. + clear_results : bool, optional, default: True + Whether to clear model results attributes. + clear_spectra : bool, optional, default: True + Whether to clear power spectra attribute. + """ + + super()._reset_data_results(clear_freqs, clear_spectrum, clear_results) + if clear_spectra: + self.power_spectra = None + + def _reset_group_results(self, length=0): """Set (or reset) results to be empty. @@ -164,7 +184,7 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1): self.power_spectra), self.verbose, len(self.power_spectra))) - self._reset_data_results(clear_freqs=False) + self._reset_data_results(clear_freqs=False, clear_spectra=False) def get_results(self): @@ -290,16 +310,15 @@ def get_fooof(self, ind, regenerate=False): """ # Initialize a FOOOF object, with same settings as current FOOOFGroup - fm = FOOOF(self.peak_width_limits, self.max_n_peaks, self.min_peak_amplitude, - self.peak_threshold, self.aperiodic_mode, self.verbose) + fm = FOOOF(*self.get_settings(), verbose=self.verbose) # Add data for specified single power spectrum, if available # The power spectrum is inverted back to linear, as it's re-logged when added to FOOOF if np.any(self.power_spectra): fm.add_data(self.freqs, np.power(10, self.power_spectra[ind])) - # If no power spectrum data available, copy over frequency information + # If no power spectrum data available, copy over data information & regenerate freqs else: - fm._add_from_dict({'freq_range': self.freq_range, 'freq_res': self.freq_res}) + fm.add_data_info(self.get_data_info) # Add results for specified power spectrum, regenerating full fit if requested fm.add_results(self.group_results[ind]) diff --git a/fooof/tests/test_fit.py b/fooof/tests/test_fit.py index f33dfbdad..4a7ccd9cb 100644 --- a/fooof/tests/test_fit.py +++ b/fooof/tests/test_fit.py @@ -184,8 +184,8 @@ def test_fooof_resets(): and tfm.power_spectrum is None and tfm.fooofed_spectrum_ is None and tfm._spectrum_flat is None \ and tfm._spectrum_peak_rm is None and tfm._ap_fit is None and tfm._peak_fit is None - assert np.all(np.isnan(tfm.aperiodic_params_)) and np.all(np.isnan(tfm.peak_params_)) \ - and np.all(np.isnan(tfm.r_squared_)) and np.all(np.isnan(tfm.error_)) and np.all(np.isnan(tfm._gaussian_params)) + # assert np.all(np.isnan(tfm.aperiodic_params_)) and np.all(np.isnan(tfm.peak_params_)) \ + # and np.all(np.isnan(tfm.r_squared_)) and np.all(np.isnan(tfm.error_)) and np.all(np.isnan(tfm._gaussian_params)) def test_fooof_report(skip_if_no_mpl): """Check that running the top level model method runs.""" From a5e8e3e226a9288238c7f0c8a8996f04f4f2fdab Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Tue, 5 Mar 2019 01:24:36 -0800 Subject: [PATCH 14/15] Update tests for updates, and related fixes --- fooof/core/info.py | 14 +++++--- fooof/group.py | 2 +- fooof/tests/test_fit.py | 16 +++++++--- fooof/tests/test_funcs.py | 67 +++++++++++++++++++++------------------ fooof/tests/test_group.py | 9 ++++++ 5 files changed, 67 insertions(+), 41 deletions(-) diff --git a/fooof/core/info.py b/fooof/core/info.py index 23c39dca5..e41443400 100644 --- a/fooof/core/info.py +++ b/fooof/core/info.py @@ -9,14 +9,18 @@ def get_obj_desc(): Mapping of FOOOF object attributes, and what kind of data they are. """ - attributes = {'results' : ['aperiodic_params_', 'peak_params_', 'error_', - 'r_squared_', '_gaussian_params'], - 'settings' : ['peak_width_limits', 'max_n_peaks', 'min_peak_amplitude', - 'peak_threshold', 'aperiodic_mode'], + attributes = {'results' : ['aperiodic_params_', 'peak_params_', + 'r_squared_', 'error_', + '_gaussian_params'], + 'settings' : ['peak_width_limits', 'max_n_peaks', + 'min_peak_amplitude', 'peak_threshold', + 'aperiodic_mode'], 'data' : ['power_spectrum', 'freq_range', 'freq_res'], 'data_info' : ['freq_range', 'freq_res'], 'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_', - 'peak_params_', '_gaussian_params']} + 'peak_params_', '_gaussian_params'], + 'model_components' : ['_spectrum_flat', '_spectrum_peak_rm', + '_ap_fit', '_peak_fit']} return attributes diff --git a/fooof/group.py b/fooof/group.py index 2b23afa5c..e82ba8bcf 100644 --- a/fooof/group.py +++ b/fooof/group.py @@ -318,7 +318,7 @@ def get_fooof(self, ind, regenerate=False): fm.add_data(self.freqs, np.power(10, self.power_spectra[ind])) # If no power spectrum data available, copy over data information & regenerate freqs else: - fm.add_data_info(self.get_data_info) + fm.add_data_info(self.get_data_info()) # Add results for specified power spectrum, regenerating full fit if requested fm.add_results(self.group_results[ind]) diff --git a/fooof/tests/test_fit.py b/fooof/tests/test_fit.py index 4a7ccd9cb..4d1712234 100644 --- a/fooof/tests/test_fit.py +++ b/fooof/tests/test_fit.py @@ -128,6 +128,12 @@ def test_adds(): for setting in get_obj_desc()['settings']: assert getattr(tfm, setting) == getattr(fooof_settings, setting) + # Test adding data info + fooof_data_info = FOOOFDataInfo([3, 40], 0.5) + tfm.add_data_info(fooof_data_info) + for data_info in get_obj_desc()['data_info']: + assert getattr(tfm, data_info) == getattr(fooof_data_info, data_info) + # Test adding results fooof_results = FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]) tfm.add_results(fooof_results) @@ -180,12 +186,12 @@ def test_fooof_resets(): tfm._reset_data_results() tfm._reset_internal_settings() - assert tfm.freqs is None and tfm.freq_range is None and tfm.freq_res is None \ - and tfm.power_spectrum is None and tfm.fooofed_spectrum_ is None and tfm._spectrum_flat is None \ - and tfm._spectrum_peak_rm is None and tfm._ap_fit is None and tfm._peak_fit is None + desc = get_obj_desc() - # assert np.all(np.isnan(tfm.aperiodic_params_)) and np.all(np.isnan(tfm.peak_params_)) \ - # and np.all(np.isnan(tfm.r_squared_)) and np.all(np.isnan(tfm.error_)) and np.all(np.isnan(tfm._gaussian_params)) + for data in ['data', 'results', 'model_components']: + for field in desc[data]: + assert getattr(tfm, field) == None + assert tfm.freqs == None and tfm.fooofed_spectrum_ == None def test_fooof_report(skip_if_no_mpl): """Check that running the top level model method runs.""" diff --git a/fooof/tests/test_funcs.py b/fooof/tests/test_funcs.py index 4be6607b0..a83fa4cdb 100644 --- a/fooof/tests/test_funcs.py +++ b/fooof/tests/test_funcs.py @@ -20,44 +20,51 @@ def test_combine_fooofs(tfm, tfg): tfg2 = tfg.copy(); tfg3 = tfg.copy() # Check combining 2 FOOOFs - fg1 = combine_fooofs([tfm, tfm2]) - assert fg1 - assert len(fg1) == 2 - assert compare_info([fg1, tfm], 'settings') - assert fg1.group_results[0] == tfm.get_results() - assert fg1.group_results[-1] == tfm2.get_results() + nfg1 = combine_fooofs([tfm, tfm2]) + assert nfg1 + assert len(nfg1) == 2 + assert compare_info([nfg1, tfm], 'settings') + assert nfg1.group_results[0] == tfm.get_results() + assert nfg1.group_results[-1] == tfm2.get_results() # Check combining 3 FOOOFs - fg2 = combine_fooofs([tfm, tfm2, tfm3]) - assert fg2 - assert len(fg2) == 3 - assert compare_info([fg2, tfm], 'settings') - assert fg2.group_results[0] == tfm.get_results() - assert fg2.group_results[-1] == tfm3.get_results() + nfg2 = combine_fooofs([tfm, tfm2, tfm3]) + assert nfg2 + assert len(nfg2) == 3 + assert compare_info([nfg2, tfm], 'settings') + assert nfg2.group_results[0] == tfm.get_results() + assert nfg2.group_results[-1] == tfm3.get_results() # Check combining 2 FOOOFGroups - nfg1 = combine_fooofs([tfg, tfg2]) - assert nfg1 - assert len(nfg1) == len(tfg) + len(tfg2) - assert compare_info([nfg1, tfg, tfg2], 'settings') - assert nfg1.group_results[0] == tfg.group_results[0] - assert nfg1.group_results[-1] == tfg2.group_results[-1] + nfg3 = combine_fooofs([tfg, tfg2]) + assert nfg3 + assert len(nfg3) == len(tfg) + len(tfg2) + assert compare_info([nfg3, tfg, tfg2], 'settings') + assert nfg3.group_results[0] == tfg.group_results[0] + assert nfg3.group_results[-1] == tfg2.group_results[-1] # Check combining 3 FOOOFGroups - nfg2 = combine_fooofs([tfg, tfg2, tfg3]) - assert nfg2 - assert len(nfg2) == len(tfg) + len(tfg2) + len(tfg3) - assert compare_info([nfg2, tfg, tfg2, tfg3], 'settings') - assert nfg2.group_results[0] == tfg.group_results[0] - assert nfg2.group_results[-1] == tfg3.group_results[-1] + nfg4 = combine_fooofs([tfg, tfg2, tfg3]) + assert nfg4 + assert len(nfg4) == len(tfg) + len(tfg2) + len(tfg3) + assert compare_info([nfg4, tfg, tfg2, tfg3], 'settings') + assert nfg4.group_results[0] == tfg.group_results[0] + assert nfg4.group_results[-1] == tfg3.group_results[-1] # Check combining a mixture of FOOOF & FOOOFGroup - mfg3 = combine_fooofs([tfg, tfm, tfg2, tfm2]) - assert mfg3 - assert len(mfg3) == len(tfg) + 1 + len(tfg2) + 1 - assert compare_info([tfg, tfm, tfg2, tfm2], 'settings') - assert mfg3.group_results[0] == tfg.group_results[0] - assert mfg3.group_results[-1] == tfm2.get_results() + nfg5 = combine_fooofs([tfg, tfm, tfg2, tfm2]) + assert nfg5 + assert len(nfg5) == len(tfg) + 1 + len(tfg2) + 1 + assert compare_info([nfg5, tfg, tfm, tfg2, tfm2], 'settings') + assert nfg5.group_results[0] == tfg.group_results[0] + assert nfg5.group_results[-1] == tfm2.get_results() + + # Check combining objects with no data + tfm2._reset_data_results(False, True, True) + tfg2._reset_data_results(False, True, True, True) + nfg6 = combine_fooofs([tfm2, tfg2]) + assert len(nfg6) == 1 + len(tfg2) + assert nfg6.power_spectra == None def test_combine_errors(tfm, tfg): diff --git a/fooof/tests/test_group.py b/fooof/tests/test_group.py index 1472b719a..7cb367967 100644 --- a/fooof/tests/test_group.py +++ b/fooof/tests/test_group.py @@ -144,3 +144,12 @@ def test_fg_get_fooof(tfg): # Check that regenerated model is created for result in desc['results']: assert np.all(getattr(tfm1, result)) + + # Test when object has no data (clear a copy of tfg) + new_tfg = tfg.copy() + new_tfg._reset_data_results(False, True, True, True) + tfm2 = new_tfg.get_fooof(0, True) + assert tfm2 + # Check that data info is copied over properly + for data_info in desc['data_info']: + assert getattr(tfm2, data_info) From 4bc5e101f23c3b8eb21bb317c84cd393a67a110b Mon Sep 17 00:00:00 2001 From: TomDonoghue Date: Tue, 19 Mar 2019 17:58:35 -0700 Subject: [PATCH 15/15] Remove outdated comment --- fooof/group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fooof/group.py b/fooof/group.py index e82ba8bcf..30aa57f72 100644 --- a/fooof/group.py +++ b/fooof/group.py @@ -38,7 +38,7 @@ def __init__(self, *args, **kwargs): FOOOF.__init__(self, *args, **kwargs) - self.power_spectra = None#np.array([]) + self.power_spectra = None self._reset_group_results()