diff --git a/fooof/core/info.py b/fooof/core/info.py new file mode 100644 index 000000000..e41443400 --- /dev/null +++ b/fooof/core/info.py @@ -0,0 +1,51 @@ +"""Internal functions to manage info related to FOOOF objects.""" + +def get_obj_desc(): + """Get dictionary specifying FOOOF object names and kind of attributes. + + Returns + ------- + attibutes : dict + Mapping of FOOOF object attributes, and what kind of data they are. + """ + + attributes = {'results' : ['aperiodic_params_', 'peak_params_', + 'r_squared_', 'error_', + '_gaussian_params'], + 'settings' : ['peak_width_limits', 'max_n_peaks', + 'min_peak_amplitude', 'peak_threshold', + 'aperiodic_mode'], + 'data' : ['power_spectrum', 'freq_range', 'freq_res'], + 'data_info' : ['freq_range', 'freq_res'], + 'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_', + 'peak_params_', '_gaussian_params'], + 'model_components' : ['_spectrum_flat', '_spectrum_peak_rm', + '_ap_fit', '_peak_fit']} + + return attributes + + +def get_data_indices(aperiodic_mode): + """Get a dictionary mapping the column labels to indices in FOOOF data (FOOOFResults). + + Parameters + ---------- + aperiodic_mode : {'fixed', 'knee'} + Which approach taken to fit the aperiodic component. + + Returns + ------- + indices : dict + Mapping for data columns to the column indices in which they appear. + """ + + indices = { + 'CF' : 0, + 'Amp' : 1, + 'BW' : 2, + 'offset' : 0, + 'knee' : 1 if aperiodic_mode == 'knee' else None, + 'exponent' : 1 if aperiodic_mode == 'fixed' else 2 + } + + return indices diff --git a/fooof/core/io.py b/fooof/core/io.py index 633d31233..417b707ce 100644 --- a/fooof/core/io.py +++ b/fooof/core/io.py @@ -5,7 +5,8 @@ import json from json import JSONDecodeError -from fooof.core.utils import dict_array_to_lst, dict_select_keys, dict_lst_to_array, get_obj_desc +from fooof.core.info import get_obj_desc +from fooof.core.utils import dict_array_to_lst, dict_select_keys, dict_lst_to_array ################################################################################################### ################################################################################################### @@ -89,7 +90,7 @@ def save_fm(fm, file_name, file_path=None, append=False, # Set and select which variables to keep. Use a set to drop any potential overlap # Note that results also saves frequency information to be able to recreate freq vector attributes = get_obj_desc() - keep = set((attributes['results'] + attributes['freq_info'] if save_results else []) + \ + keep = set((attributes['results'] + attributes['data_info'] if save_results else []) + \ (attributes['settings'] if save_settings else []) + \ (attributes['data'] if save_data else [])) obj_dict = dict_select_keys(obj_dict, keep) diff --git a/fooof/core/strings.py b/fooof/core/strings.py index 138104ff7..bb1d621e7 100644 --- a/fooof/core/strings.py +++ b/fooof/core/strings.py @@ -27,8 +27,8 @@ def gen_wid_warn_str(freq_res, bwl): output = '\n'.join([ '', - "FOOOF WARNING: Lower-bound peak width limit is < or ~= the frequency resolution: " + \ - "{:1.2f} <= {:1.2f}".format(freq_res, bwl), + 'FOOOF WARNING: Lower-bound peak width limit is < or ~= the frequency resolution: ' + \ + '{:1.2f} <= {:1.2f}'.format(freq_res, bwl), '\tLower bounds below frequency-resolution have no effect (effective lower bound is freq-res)', '\tToo low a limit may lead to overfitting noise as small bandwidth peaks.', '\tWe recommend a lower bound of approximately 2x the frequency resolution.', @@ -57,11 +57,11 @@ def gen_settings_str(f_obj, description=False, concise=False): """ # Parameter descriptions to print out, if requested - desc = {'aperiodic_mode' : 'The aproach taken to fitting the aperiodic component.', - 'peak_width_limits' : 'Enforced limits for peak widths, in Hz.', + desc = {'peak_width_limits' : 'Enforced limits for peak widths, in Hz.', 'max_n_peaks' : 'The maximum number of peaks that can be extracted.', - 'min_peak_amplitude' : "Minimum absolute amplitude of a peak, above aperiodic component.", - 'peak_threshold' : "Threshold at which to stop searching for peaks."} + 'min_peak_amplitude' : 'Minimum absolute amplitude of a peak, above aperiodic component.', + 'peak_threshold' : 'Threshold at which to stop searching for peaks.', + 'aperiodic_mode' : 'The aproach taken to fitting the aperiodic component.'} # Clear description for printing if not requested if not description: @@ -77,16 +77,16 @@ def gen_settings_str(f_obj, description=False, concise=False): '', # Settings - include descriptions if requested - *[el for el in ['Aperiodic Mode : {}'.format(f_obj.aperiodic_mode), - '{}'.format(desc['aperiodic_mode']), - 'Peak Width Limits : {}'.format(f_obj.peak_width_limits), + *[el for el in ['Peak Width Limits : {}'.format(f_obj.peak_width_limits), '{}'.format(desc['peak_width_limits']), 'Max Number of Peaks : {}'.format(f_obj.max_n_peaks), '{}'.format(desc['max_n_peaks']), 'Minimum Amplitude : {}'.format(f_obj.min_peak_amplitude), '{}'.format(desc['min_peak_amplitude']), 'Amplitude Threshold: {}'.format(f_obj.peak_threshold), - '{}'.format(desc['peak_threshold'])] if el != ''], + '{}'.format(desc['peak_threshold']), + 'Aperiodic Mode : {}'.format(f_obj.aperiodic_mode), + '{}'.format(desc['aperiodic_mode'])] if el != ''], # Footer '', @@ -280,7 +280,7 @@ def gen_issue_str(concise=False): 'If FOOOF gives you any weird / bad fits, please let us know!', 'To do so, send us a FOOOF report, and a FOOOF data file, ', '', - "With a FOOOF object (fm), after fitting, run the following commands:", + 'With a FOOOF object (fm), after fitting, run the following commands:', "fm.create_report('FOOOF_bad_fit_report')", "fm.save('FOOOF_bad_fit_data', True, True, True)", '', diff --git a/fooof/core/utils.py b/fooof/core/utils.py index f8309f0b0..abd285547 100644 --- a/fooof/core/utils.py +++ b/fooof/core/utils.py @@ -100,53 +100,6 @@ def check_array_dim(arr): return np.empty([0, 3]) if arr.ndim == 1 else arr -def get_obj_desc(): - """Get dictionary specifying FOOOF object names and kind of attributes. - - Returns - ------- - attibutes : dict - Mapping of FOOOF object attributes, and what kind of data they are. - """ - - attributes = {'results' : ['aperiodic_params_', 'peak_params_', 'error_', - 'r_squared_', '_gaussian_params'], - 'settings' : ['peak_width_limits', 'max_n_peaks', 'min_peak_amplitude', - 'peak_threshold', 'aperiodic_mode'], - 'data' : ['power_spectrum', 'freq_range', 'freq_res'], - 'freq_info' : ['freq_range', 'freq_res'], - 'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_', - 'peak_params_', '_gaussian_params']} - - return attributes - - -def get_data_indices(aperiodic_mode): - """Get a dictionary mapping the column labels to indices in FOOOF data (FOOOFResults). - - Parameters - ---------- - aperiodic_mode : {'fixed', 'knee'} - Which approach taken to fit the aperiodic component. - - Returns - ------- - indices : dict - Mapping for data columns to the column indices in which they appear. - """ - - indices = { - 'CF' : 0, - 'Amp' : 1, - 'BW' : 2, - 'offset' : 0, - 'knee' : 1 if aperiodic_mode == 'knee' else None, - 'exponent' : 1 if aperiodic_mode == 'fixed' else 2 - } - - return indices - - def check_iter(obj, length): """Check an object to ensure that it is iterable, and make it iterable if not. diff --git a/fooof/data.py b/fooof/data.py new file mode 100644 index 000000000..e9373758c --- /dev/null +++ b/fooof/data.py @@ -0,0 +1,78 @@ +"""Data objects for FOOOF.""" + +from collections import namedtuple + +################################################################################################### +################################################################################################### + +FOOOFSettings = namedtuple('FOOOFSettings', ['peak_width_limits', 'max_n_peaks', + 'min_peak_amplitude', 'peak_threshold', + 'aperiodic_mode']) +FOOOFSettings.__doc__ = """\ +The user defined settings for a FOOOF object. + +Attributes +---------- +peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) + Limits on possible peak width, as (lower_bound, upper_bound). +max_n_peaks : int, optional, default: inf + Maximum number of gaussians to be fit in a single spectrum. +min_peak_amplitude : float, optional, default: 0 + Minimum amplitude threshold for a peak to be modeled. +peak_threshold : float, optional, default: 2.0 + Threshold for detecting peaks, units of standard deviation. +aperiodic_mode : {'fixed', 'knee'} + Which approach to take for fitting the aperiodic component. +""" + + +FOOOFDataInfo = namedtuple('FOOOFDataInfo', ['freq_range', 'freq_res']) + +FOOOFDataInfo.__doc__ = """\ +Data related information for a FOOOF object. + +Attributes +---------- +freq_range : list of [float, float] + Frequency range of the power spectrum, as [lowest_freq, highest_freq]. +freq_res : float + Frequency resolution of the power spectrum. +""" + + +FOOOFResults = namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params', + 'r_squared', 'error', 'gaussian_params']) +FOOOFResults.__doc__ = """\ +The resulting parameters and associated data of a FOOOF model fit. + +Attributes +---------- +aperiodic_params : 1d array + Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. + The knee parameter is only included if aperiodic is fit with knee. +peak_params : 2d array + Fitted parameter values for the peaks. Each row is a peak, as [CF, Amp, BW]. +r_squared : float + R-squared of the fit between the input power spectrum and the full model fit. +error : float + Root mean squared error of the full model fit. +gaussian_params : 2d array + Parameters that define the gaussian fit(s). Each row is a gaussian, as [mean, amp, std]. +""" + + +SynParams = namedtuple('SynParams', ['aperiodic_params', 'gaussian_params', 'nlv']) + +SynParams.__doc__ = """\ +Stores parameters used to synthesize a single power spectra. + +Attributes +---------- +aperiodic_params : list, len 2 or 3 + Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. + The knee parameter is only included if aperiodic is fit with knee. Otherwise, length is 2. +gaussian_params : list or list of lists + Fitted parameter values for the peaks. Each list is a peak, as [CF, Amp, BW]. +nlv : float + Noise level added to the generated power spectrum. +""" diff --git a/fooof/fit.py b/fooof/fit.py index 5747f45e6..1effa50f4 100644 --- a/fooof/fit.py +++ b/fooof/fit.py @@ -5,10 +5,10 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. Private attributes of the FOOOF method, not publicly exposed, are documented below. -Attributes (private) ----------- +Private Attributes +------------------ _spectrum_flat : 1d array - Flattened power spectrum (aperiodic component removed) + Flattened power spectrum (aperiodic component removed). _spectrum_peak_rm : 1d array Power spectrum with peaks removed (not flattened). _gaussian_params : 2d array @@ -24,7 +24,7 @@ _ap_bounds : tuple of tuple of float Upper and lower bounds on fitting aperiodic component. _bw_std_edge : float - Bandwidth threshold for edge rejection of peaks, in units of gaussian std. deviation. + Bandwidth threshold for edge rejection of peaks, in units of gaussian standard deviation. _gauss_overlap_thresh : float Degree of overlap (in units of guassian std. deviation) between gaussian guesses to drop one. _gauss_std_limits : list of [float, float] @@ -33,45 +33,26 @@ import warnings from copy import deepcopy -from collections import namedtuple import numpy as np from scipy.optimize import curve_fit -from fooof.utils import trim_spectrum -from fooof.plts.fm import plot_fm from fooof.core.io import save_fm, load_json from fooof.core.reports import save_report_fm from fooof.core.funcs import gaussian_function, get_ap_func, infer_ap_func -from fooof.core.utils import group_three, check_array_dim, get_obj_desc +from fooof.core.utils import group_three, check_array_dim +from fooof.core.info import get_obj_desc from fooof.core.modutils import copy_doc_func_to_method from fooof.core.strings import gen_settings_str, gen_results_str_fm, gen_issue_str, gen_wid_warn_str +from fooof.plts.fm import plot_fm +from fooof.utils import trim_spectrum +from fooof.data import FOOOFResults, FOOOFSettings, FOOOFDataInfo from fooof.synth.gen import gen_freqs, gen_aperiodic, gen_peaks ################################################################################################### ################################################################################################### -FOOOFResult = namedtuple('FOOOFResult', ['aperiodic_params', 'peak_params', - 'r_squared', 'error', 'gaussian_params']) -FOOOFResult.__doc__ = """\ -The resulting parameters and associated data of a FOOOF model fit. - -Attributes ----------- -aperiodic_params : 1d array, len 2 or 3 - Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. - The knee parameter is only included if aperiodic is fit with knee. Otherwise, length is 2. -peak_params : 2d array, shape=[n_peaks, 3] - Fitted parameter values for the peaks. Each row is a peak, as [CF, Amp, BW]. -r_squared : float - R-squared of the fit between the input power spectrum and the full model fit. -error : float - Root mean squared error of the full model fit. -gaussian_params : 2d array, shape=[n_peaks, 3] - Parameters that define the gaussian fit(s). Each row is a gaussian, as [mean, amp, std]. -""" - class FOOOF(object): """Model the physiological power spectrum as a combination of aperiodic and periodic components. @@ -82,8 +63,8 @@ class FOOOF(object): Parameters ---------- - peak_width_limits : tuple of (float, float), optional, default: [0.5, 12.0] - Limits on possible peak width, as [lower_bound, upper_bound]. + peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) + Limits on possible peak width, as (lower_bound, upper_bound). max_n_peaks : int, optional, default: inf Maximum number of gaussians to be fit in a single spectrum. min_peak_amplitude : float, optional, default: 0 @@ -91,7 +72,7 @@ class FOOOF(object): peak_threshold : float, optional, default: 2.0 Threshold for detecting peaks, units of standard deviation. aperiodic_mode : {'fixed', 'knee'} - Which approach to take to fitting the aperiodic component. + Which approach to take for fitting the aperiodic component. verbose : boolean, optional, default: True Whether to be verbose in printing out warnings. @@ -130,7 +111,7 @@ class FOOOF(object): get smoother power spectra, as this will give better FOOOF fits. """ - def __init__(self, peak_width_limits=[0.5, 12.0], max_n_peaks=np.inf, min_peak_amplitude=0.0, + def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_amplitude=0.0, peak_threshold=2.0, aperiodic_mode='fixed', verbose=True): """Initialize FOOOF object with run parameters.""" @@ -141,11 +122,11 @@ def __init__(self, peak_width_limits=[0.5, 12.0], max_n_peaks=np.inf, min_peak_a raise ImportError('Scipy version of >= 0.19.0 required.') # Set input parameters - self.aperiodic_mode = aperiodic_mode self.peak_width_limits = peak_width_limits self.max_n_peaks = max_n_peaks self.min_peak_amplitude = min_peak_amplitude self.peak_threshold = peak_threshold + self.aperiodic_mode = aperiodic_mode self.verbose = verbose ## SETTINGS - these are updateable by the user if required. @@ -154,11 +135,11 @@ def __init__(self, peak_width_limits=[0.5, 12.0], max_n_peaks=np.inf, min_peak_a self._ap_amp_thresh = 0.025 # Guess parameters for aperiodic fitting, [offset, knee, exponent] # If offset guess is None, the first value of the power spectrum is used as offset guess - self._ap_guess = [None, 0, 2] + self._ap_guess = (None, 0, 2) # Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, sl_low_bound), - # (offset_high_bound, knee_high_bound, sl_high_bound)) - # By default, aperiodic fitting is unbound, but can be restricted here, if desired - # Even if fitting without knee, leave bounds for knee (they are dropped later) + # (offset_high_bound, knee_high_bound, sl_high_bound)) + # By default, aperiodic fitting is unbound, but can be restricted here, if desired + # Even if fitting without knee, leave bounds for knee (they are dropped later) self._ap_bounds = ((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)) # Threshold for how far (units of gaus std dev) a peak has to be from edge to keep. self._bw_std_edge = 1.0 @@ -186,7 +167,7 @@ def _reset_internal_settings(self): # Bandwidth limits are given in 2-sided peak bandwidth. # Convert to gaussian std parameter limits. - self._gauss_std_limits = [bwl / 2 for bwl in self.peak_width_limits] + self._gauss_std_limits = tuple([bwl / 2 for bwl in self.peak_width_limits]) # Bounds for aperiodic fitting. Drops bounds on knee parameter if not set to fit knee self._ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ else tuple(bound[0::2] for bound in self._ap_bounds) @@ -219,16 +200,16 @@ def _reset_data_results(self, clear_freqs=True, clear_spectrum=True, clear_resul self.power_spectrum = None if clear_results: + self.aperiodic_params_ = None + self.peak_params_ = None + self.r_squared_ = None + self.error_ = None + self._gaussian_params = None + self.fooofed_spectrum_ = None - self.aperiodic_params_ = np.array([np.nan, np.nan]) if \ - self.aperiodic_mode == 'fixed' else np.array([np.nan, np.nan, np.nan]) - self.peak_params_ = np.array([np.nan, np.nan, np.nan]) - self.r_squared_ = np.nan - self.error_ = np.nan self._spectrum_flat = None self._spectrum_peak_rm = None - self._gaussian_params = np.array([np.nan, np.nan, np.nan]) self._ap_fit = None self._peak_fit = None @@ -260,15 +241,43 @@ def add_data(self, freqs, power_spectrum, freq_range=None): self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose) - def add_results(self, fooof_result, regenerate=False): - """Add results data back into object from a FOOOFResult object. + def add_settings(self, fooof_settings): + """Add settings into object from a FOOOFSettings object. + + Parameters + ---------- + fooof_settings : FOOOFSettings + A FOOOF data object containing the settings for a FOOOF model. + """ + + for setting in get_obj_desc()['settings']: + setattr(self, setting, getattr(fooof_settings, setting)) + + self._check_loaded_settings(fooof_settings._asdict()) + + + def add_data_info(self, fooof_data_info): + """Add data information into object from a FOOOFDataInfo object. + + Parameters + ---------- + fooof_data_info : FOOOFDataInfo + A FOOOF data object containing information about the data. + """ + + for data_info in get_obj_desc()['data_info']: + setattr(self, data_info, getattr(fooof_data_info, data_info)) + + self._regenerate_freqs() + + + def add_results(self, fooof_result): + """Add results data into object from a FOOOFResults object. Parameters ---------- - fooof_result : FOOOFResult - An object containing the results from fitting a FOOOF model. - regenerate : bool, optional, default: False - Whether to regenerate the model fits from the given fit parameters. + fooof_result : FOOOFResults + A FOOOF data object containing the results from fitting a FOOOF model. """ self.aperiodic_params_ = fooof_result.aperiodic_params @@ -277,8 +286,7 @@ def add_results(self, fooof_result, regenerate=False): self.error_ = fooof_result.error self._gaussian_params = fooof_result.gaussian_params - if regenerate: - self._regenerate_model() + self._check_loaded_results(fooof_result._asdict()) def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False): @@ -426,11 +434,40 @@ def print_report_issue(concise=False): print(gen_issue_str(concise)) + def get_settings(self): + """Return user defined settings of the FOOOF object. + + Returns + ------- + FOOOFSettings + Object containing the settings from the current FOOOF object. + """ + + return FOOOFSettings(**{key : getattr(self, key) for key in get_obj_desc()['settings']}) + + + def get_data_info(self): + """Return data information from the FOOOF object. + + Returns + ------- + FOOOFDataInfo + Object containing information about the data from the current FOOOF object. + """ + + return FOOOFDataInfo(**{key : getattr(self, key) for key in get_obj_desc()['data_info']}) + + def get_results(self): - """Return model fit parameters and goodness of fit metrics.""" + """Return model fit parameters and goodness of fit metrics. - return FOOOFResult(self.aperiodic_params_, self.peak_params_, self.r_squared_, - self.error_, self._gaussian_params) + Returns + ------- + FOOOFResults + Object containing the FOOOF model fit results from the current FOOOF object. + """ + + return FOOOFResults(**{key.strip('_') : getattr(self, key) for key in get_obj_desc()['results']}) @copy_doc_func_to_method(plot_fm) @@ -452,8 +489,8 @@ def save(self, file_name='FOOOF_results', file_path=None, append=False, save_fm(self, file_name, file_path, append, save_results, save_settings, save_data) - def load(self, file_name='FOOOF_results', file_path=None): - """Load in FOOOF file. Reads in a JSON file. + def load(self, file_name='FOOOF_results', file_path=None, regenerate=True): + """Load in FOOOF file. Reads in a FOOOF formatted JSON file. Parameters ---------- @@ -461,6 +498,8 @@ def load(self, file_name='FOOOF_results', file_path=None): File from which to load data. file_path : str, optional Path to directory from which to load. If not provided, loads from current directory. + regenerate : bool, optional, default: True + Whether to regenerate the model fit from the loaded data, if data is available. """ # Reset data in object, so old data can't interfere @@ -472,6 +511,13 @@ def load(self, file_name='FOOOF_results', file_path=None): self._check_loaded_settings(data) self._check_loaded_results(data) + # Regenerate model components, based on what's available + if regenerate: + if self.freq_res: + self._regenerate_freqs() + if np.all(self.freqs) and np.all(self.aperiodic_params_): + self._regenerate_model() + def copy(self): """Return a copy of the FOOOF object.""" @@ -487,29 +533,6 @@ def _check_width_limits(self): print(gen_wid_warn_str(self.freq_res, self.peak_width_limits[0])) - def _check_loaded_results(self, data, regenerate=True): - """Check if results added, check data, and regenerate model, if requested. - - Parameters - ---------- - data : dict - The dictionary of data that has been added to the object. - regenerate : bool, optional, default: True - Whether to regenerate the power_spectrum model. - """ - - # If results loaded, check dimensions of peak parameters - # This fixes an issue where they end up the wrong shape if they are empty (no peaks) - if set(get_obj_desc()['results']).issubset(set(data.keys())): - self.peak_params_ = check_array_dim(self.peak_params_) - self._gaussian_params = check_array_dim(self._gaussian_params) - - # Regenerate power_spectrum model & components - if regenerate: - if np.all(self.freqs) and np.all(self.aperiodic_params_): - self._regenerate_model() - - def _simple_ap_fit(self, freqs, power_spectrum): """Fit the aperiodic component of the power spectrum. @@ -930,9 +953,21 @@ def _add_from_dict(self, data): for key in data.keys(): setattr(self, key, data[key]) - # Reconstruct frequency vector, if data available to do so - if self.freq_res: - self.freqs = gen_freqs(self.freq_range, self.freq_res) + + def _check_loaded_results(self, data): + """Check if results have been added and check data. + + Parameters + ---------- + data : dict + A dictionary of data that has been added to the object. + """ + + # If results loaded, check dimensions of peak parameters + # This fixes an issue where they end up the wrong shape if they are empty (no peaks) + if set(get_obj_desc()['results']).issubset(set(data.keys())): + self.peak_params_ = check_array_dim(self.peak_params_) + self._gaussian_params = check_array_dim(self._gaussian_params) def _check_loaded_settings(self, data): @@ -941,7 +976,7 @@ def _check_loaded_settings(self, data): Parameters ---------- data : dict - The dictionary of data that has been added to the object. + A dictionary of data that has been added to the object. """ # If settings not loaded from file, clear from object, so that default @@ -961,6 +996,12 @@ def _check_loaded_settings(self, data): self._reset_internal_settings() + def _regenerate_freqs(self): + """Regenerate the frequency vector, given the object metadata.""" + + self.freqs = gen_freqs(self.freq_range, self.freq_res) + + def _regenerate_model(self): """Regenerate model fit from parameters.""" diff --git a/fooof/funcs.py b/fooof/funcs.py index 2f8dd445d..7ed5e9ae8 100644 --- a/fooof/funcs.py +++ b/fooof/funcs.py @@ -4,7 +4,7 @@ from fooof import FOOOFGroup from fooof.synth.gen import gen_freqs -from fooof.utils import get_settings, get_obj_desc, compare_settings, compare_data_info +from fooof.utils import compare_info ################################################################################################### ################################################################################################### @@ -24,29 +24,35 @@ def combine_fooofs(fooofs): """ # Compare settings - if not compare_settings(fooofs) or not compare_data_info(fooofs): + if not compare_info(fooofs, 'settings') or not compare_info(fooofs, 'data_info'): raise ValueError("These objects have incompatible settings or data," \ "and so cannot be combined.") # Initialize FOOOFGroup object, with settings derived from input objects - fg = FOOOFGroup(**get_settings(fooofs[0]), verbose=fooofs[0].verbose) - fg.power_spectra = np.empty([0, len(fooofs[0].freqs)]) + fg = FOOOFGroup(*fooofs[0].get_settings(), verbose=fooofs[0].verbose) + + # Use a temporary store to collect spectra, because we only add them if consistently present + temp_power_spectra = np.empty([0, len(fooofs[0].freqs)]) # Add FOOOF results from each FOOOF object to group for f_obj in fooofs: # Add FOOOFGroup object if isinstance(f_obj, FOOOFGroup): fg.group_results.extend(f_obj.group_results) - fg.power_spectra = np.vstack([fg.power_spectra, f_obj.power_spectra]) + if f_obj.power_spectra is not None: + temp_power_spectra = np.vstack([temp_power_spectra, f_obj.power_spectra]) # Add FOOOF object else: fg.group_results.append(f_obj.get_results()) - fg.power_spectra = np.vstack([fg.power_spectra, f_obj.power_spectrum]) + if f_obj.power_spectrum is not None: + temp_power_spectra = np.vstack([temp_power_spectra, f_obj.power_spectrum]) + + # If the number of collected power spectra is consistent, then add them to object + if len(fg) == temp_power_spectra.shape[0]: + fg.power_spectra = temp_power_spectra # Add data information information - for data_info in get_obj_desc()['freq_info']: - setattr(fg, data_info, getattr(fooofs[0], data_info)) - fg.freqs = gen_freqs(fg.freq_range, fg.freq_res) + fg.add_data_info(fooofs[0].get_data_info()) return fg diff --git a/fooof/group.py b/fooof/group.py index f60d59079..30aa57f72 100644 --- a/fooof/group.py +++ b/fooof/group.py @@ -16,7 +16,7 @@ from fooof.core.reports import save_report_fg from fooof.core.strings import gen_results_str_fg from fooof.core.io import save_fg, load_jsonlines -from fooof.core.utils import get_data_indices +from fooof.core.info import get_data_indices from fooof.core.modutils import copy_doc_func_to_method, copy_doc_class, safe_import ################################################################################################### @@ -38,7 +38,7 @@ def __init__(self, *args, **kwargs): FOOOF.__init__(self, *args, **kwargs) - self.power_spectra = np.array([]) + self.power_spectra = None self._reset_group_results() @@ -59,6 +59,26 @@ def __getitem__(self, index): return self.group_results[index] + def _reset_data_results(self, clear_freqs=True, clear_spectrum=True, clear_results=True, clear_spectra=True): + """Set (or reset) data & results attributes to empty. + + Parameters + ---------- + clear_freqs : bool, optional, default: True + Whether to clear frequency attributes. + clear_power_spectrum : bool, optional, default: True + Whether to clear power spectrum attribute. + clear_results : bool, optional, default: True + Whether to clear model results attributes. + clear_spectra : bool, optional, default: True + Whether to clear power spectra attribute. + """ + + super()._reset_data_results(clear_freqs, clear_spectrum, clear_results) + if clear_spectra: + self.power_spectra = None + + def _reset_group_results(self, length=0): """Set (or reset) results to be empty. @@ -164,7 +184,7 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1): self.power_spectra), self.verbose, len(self.power_spectra))) - self._reset_data_results(clear_freqs=False) + self._reset_data_results(clear_freqs=False, clear_spectra=False) def get_results(self): @@ -266,7 +286,7 @@ def load(self, file_name='FOOOFGroup_results', file_path=None): if ind == 0: self._check_loaded_settings(data) - self._check_loaded_results(data, False) + self._check_loaded_results(data) self.group_results.append(self._get_results()) # Reset peripheral data from last loaded result, keeping freqs info @@ -279,30 +299,31 @@ def get_fooof(self, ind, regenerate=False): Parameters ---------- ind : int - The index of the FOOOFResult in FOOOFGroup.group_results to load. + The index of the FOOOFResults in FOOOFGroup.group_results to load. regenerate : bool, optional, default: False Whether to regenerate the model fits from the given fit parameters. Returns ------- inst : FOOOF() object - The FOOOFResult data loaded into a FOOOF object. + The FOOOFResults data loaded into a FOOOF object. """ # Initialize a FOOOF object, with same settings as current FOOOFGroup - fm = FOOOF(self.peak_width_limits, self.max_n_peaks, self.min_peak_amplitude, - self.peak_threshold, self.aperiodic_mode, self.verbose) + fm = FOOOF(*self.get_settings(), verbose=self.verbose) # Add data for specified single power spectrum, if available # The power spectrum is inverted back to linear, as it's re-logged when added to FOOOF if np.any(self.power_spectra): fm.add_data(self.freqs, np.power(10, self.power_spectra[ind])) - # If no power spectrum data available, copy over frequency information + # If no power spectrum data available, copy over data information & regenerate freqs else: - fm._add_from_dict({'freq_range': self.freq_range, 'freq_res': self.freq_res}) + fm.add_data_info(self.get_data_info()) # Add results for specified power spectrum, regenerating full fit if requested - fm.add_results(self.group_results[ind], regenerate=regenerate) + fm.add_results(self.group_results[ind]) + if regenerate: + fm._regenerate_model() return fm diff --git a/fooof/synth/params.py b/fooof/synth/params.py index 8989619be..47ceb386b 100644 --- a/fooof/synth/params.py +++ b/fooof/synth/params.py @@ -1,31 +1,15 @@ """Classes & functions for managing parameter choices for synthesizing power spectra.""" -from collections import namedtuple - import numpy as np -from fooof.core.utils import check_flat, get_data_indices from fooof.core.funcs import infer_ap_func +from fooof.core.utils import check_flat +from fooof.core.info import get_data_indices +from fooof.data import SynParams ################################################################################################### ################################################################################################### -SynParams = namedtuple('SynParams', ['aperiodic_params', 'gaussian_params', 'nlv']) - -SynParams.__doc__ = """\ -Stores parameters used to synthesize a single power spectra. - -Attributes ----------- -aperiodic_params : list, len 2 or 3 - Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. - The knee parameter is only included if aperiodic is fit with knee. Otherwise, length is 2. -gaussian_params : list or list of lists - Fitted parameter values for the peaks. Each list is a peak, as [CF, Amp, BW]. -nlv : float - Noise level added to the generated power spectrum. -""" - def update_syn_ap_params(syn_params, delta, field=None): """Update the aperiodic parameter definition in a SynParams object. diff --git a/fooof/tests/test_core_info.py b/fooof/tests/test_core_info.py new file mode 100644 index 000000000..f7d3085e5 --- /dev/null +++ b/fooof/tests/test_core_info.py @@ -0,0 +1,31 @@ +"""Tests for FOOOF core.info.""" + +from fooof.core.info import * + +################################################################################################### +################################################################################################### + +def test_get_obj_desc(tfm): + + desc = get_obj_desc() + objs = dir(tfm) + + # Test that everything in dict is a valid component of the fooof object + for ke, va in desc.items(): + for it in va: + assert it in objs + +def test_get_data_indices(): + + indices_fixed = get_data_indices('fixed') + assert indices_fixed + for ke, va in indices_fixed.items(): + if ke == 'knee': + assert not va + else: + assert isinstance(va, int) + + indices_knee = get_data_indices('knee') + assert indices_knee + for ke, va in indices_knee.items(): + assert isinstance(va, int) diff --git a/fooof/tests/test_core_io.py b/fooof/tests/test_core_io.py index da832d06c..3510fc2b2 100644 --- a/fooof/tests/test_core_io.py +++ b/fooof/tests/test_core_io.py @@ -4,7 +4,7 @@ import pkg_resources as pkg from fooof import FOOOF -from fooof.core.utils import get_obj_desc +from fooof.core.info import get_obj_desc from fooof.core.io import * diff --git a/fooof/tests/test_core_utils.py b/fooof/tests/test_core_utils.py index 3e09bfa40..98c3aa27b 100644 --- a/fooof/tests/test_core_utils.py +++ b/fooof/tests/test_core_utils.py @@ -59,33 +59,6 @@ def test_check_array_dim(): out = check_array_dim(np.array([1, 2, 3])) assert out.shape == (0, 3) -def test_get_obj_desc(): - - desc = get_obj_desc() - - tfm = FOOOF() - objs = dir(tfm) - - # Test that everything in dict is a valid component of the fooof object - for ke, va in desc.items(): - for it in va: - assert it in objs - -def test_get_data_indices(): - - indices_fixed = get_data_indices('fixed') - assert indices_fixed - for ke, va in indices_fixed.items(): - if ke == 'knee': - assert not va - else: - assert isinstance(va, int) - - indices_knee = get_data_indices('knee') - assert indices_knee - for ke, va in indices_knee.items(): - assert isinstance(va, int) - def test_check_iter(): # Note: generator case not tested diff --git a/fooof/tests/test_data.py b/fooof/tests/test_data.py new file mode 100644 index 000000000..cb9575c81 --- /dev/null +++ b/fooof/tests/test_data.py @@ -0,0 +1,40 @@ +"""Tests for the FOOOF data objects.""" + +from fooof.core.info import get_obj_desc + +from fooof.data import * + +################################################################################################### +################################################################################################### + +def test_fooof_settings(): + + settings = FOOOFSettings([], None, None, None, None) + assert settings + + # Check that the object has the correct fields, given the object description + settings_fields = get_obj_desc()['settings'] + for field in settings_fields: + getattr(settings, field) + assert True + +def test_fooof_results(): + + results = FOOOFResults([], [], None, None, []) + assert results + + # Check that the object has the correct fields, given the object description + results_fields = get_obj_desc()['results'] + for field in results_fields: + getattr(results, field.strip('_')) + assert True + +def test_syn_params(): + + syn_params = SynParams([1, 1], [10, 1, 1], 0.05) + assert syn_params + + # Check that the object has the correct fields + for field in ['aperiodic_params', 'gaussian_params', 'nlv']: + getattr(syn_params, field) + assert True diff --git a/fooof/tests/test_fit.py b/fooof/tests/test_fit.py index 9da8b0601..4d1712234 100644 --- a/fooof/tests/test_fit.py +++ b/fooof/tests/test_fit.py @@ -12,8 +12,10 @@ import pkg_resources as pkg from fooof import FOOOF +from fooof.data import FOOOFSettings, FOOOFDataInfo, FOOOFResults from fooof.synth import gen_power_spectrum -from fooof.core.utils import group_three, get_obj_desc +from fooof.core.utils import group_three +from fooof.core.info import get_obj_desc from fooof.tests.utils import get_tfm, plot_test @@ -105,6 +107,52 @@ def test_fooof_load(): tfm.load(file_name_res, file_path) assert tfm +def test_adds(): + """Tests methods that add data to FOOOF objects. + + Checks: add_data, add_settings, add_results. + """ + + # Note: uses it's own tfm, to not add stuff to the global one + tfm = get_tfm() + + # Test adding data + freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) + tfm.add_data(freqs, pows) + assert np.all(tfm.freqs == freqs) + assert np.all(tfm.power_spectrum == np.log10(pows)) + + # Test adding settings + fooof_settings = FOOOFSettings([1, 4], 6, 0, 2, 'fixed') + tfm.add_settings(fooof_settings) + for setting in get_obj_desc()['settings']: + assert getattr(tfm, setting) == getattr(fooof_settings, setting) + + # Test adding data info + fooof_data_info = FOOOFDataInfo([3, 40], 0.5) + tfm.add_data_info(fooof_data_info) + for data_info in get_obj_desc()['data_info']: + assert getattr(tfm, data_info) == getattr(fooof_data_info, data_info) + + # Test adding results + fooof_results = FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]) + tfm.add_results(fooof_results) + for setting in get_obj_desc()['results']: + assert getattr(tfm, setting) == getattr(fooof_results, setting.strip('_')) + +def test_gets(tfm): + """Tests methods that return FOOOF data objects. + + Checks: get_settings, get_data_info, get_results + """ + + settings = tfm.get_settings() + assert isinstance(settings, FOOOFSettings) + data_info = tfm.get_data_info() + assert isinstance(data_info, FOOOFDataInfo) + results = tfm.get_results() + assert isinstance(results, FOOOFResults) + def test_copy(): """Test copy FOOOF method.""" @@ -113,18 +161,16 @@ def test_copy(): assert tfm != ntfm -def test_fooof_prints_get(tfm): - """Test methods that print, return results (alias and pass through methods). +def test_fooof_prints(tfm): + """Test methods that print (alias and pass through methods). - Checks: print_settings, print_results, get_results.""" + Checks: print_settings, print_results. + """ tfm.print_settings() tfm.print_results() tfm.print_report_issue() - out = tfm.get_results() - assert out - @plot_test def test_fooof_plot(tfm, skip_if_no_mpl): """Check the alias to plot FOOOF.""" @@ -140,12 +186,12 @@ def test_fooof_resets(): tfm._reset_data_results() tfm._reset_internal_settings() - assert tfm.freqs is None and tfm.freq_range is None and tfm.freq_res is None \ - and tfm.power_spectrum is None and tfm.fooofed_spectrum_ is None and tfm._spectrum_flat is None \ - and tfm._spectrum_peak_rm is None and tfm._ap_fit is None and tfm._peak_fit is None + desc = get_obj_desc() - assert np.all(np.isnan(tfm.aperiodic_params_)) and np.all(np.isnan(tfm.peak_params_)) \ - and np.all(np.isnan(tfm.r_squared_)) and np.all(np.isnan(tfm.error_)) and np.all(np.isnan(tfm._gaussian_params)) + for data in ['data', 'results', 'model_components']: + for field in desc[data]: + assert getattr(tfm, field) == None + assert tfm.freqs == None and tfm.fooofed_spectrum_ == None def test_fooof_report(skip_if_no_mpl): """Check that running the top level model method runs.""" diff --git a/fooof/tests/test_funcs.py b/fooof/tests/test_funcs.py index c5e922357..a83fa4cdb 100644 --- a/fooof/tests/test_funcs.py +++ b/fooof/tests/test_funcs.py @@ -4,7 +4,7 @@ import numpy as np -from fooof.utils import compare_settings +from fooof.utils import compare_info from fooof.group import FOOOFGroup from fooof.synth import gen_group_power_spectra @@ -20,44 +20,51 @@ def test_combine_fooofs(tfm, tfg): tfg2 = tfg.copy(); tfg3 = tfg.copy() # Check combining 2 FOOOFs - fg1 = combine_fooofs([tfm, tfm2]) - assert fg1 - assert len(fg1) == 2 - assert compare_settings([fg1, tfm]) - assert fg1.group_results[0] == tfm.get_results() - assert fg1.group_results[-1] == tfm2.get_results() + nfg1 = combine_fooofs([tfm, tfm2]) + assert nfg1 + assert len(nfg1) == 2 + assert compare_info([nfg1, tfm], 'settings') + assert nfg1.group_results[0] == tfm.get_results() + assert nfg1.group_results[-1] == tfm2.get_results() # Check combining 3 FOOOFs - fg2 = combine_fooofs([tfm, tfm2, tfm3]) - assert fg2 - assert len(fg2) == 3 - assert compare_settings([fg2, tfm]) - assert fg2.group_results[0] == tfm.get_results() - assert fg2.group_results[-1] == tfm3.get_results() + nfg2 = combine_fooofs([tfm, tfm2, tfm3]) + assert nfg2 + assert len(nfg2) == 3 + assert compare_info([nfg2, tfm], 'settings') + assert nfg2.group_results[0] == tfm.get_results() + assert nfg2.group_results[-1] == tfm3.get_results() # Check combining 2 FOOOFGroups - nfg1 = combine_fooofs([tfg, tfg2]) - assert nfg1 - assert len(nfg1) == len(tfg) + len(tfg2) - assert compare_settings([nfg1, tfg, tfg2]) - assert nfg1.group_results[0] == tfg.group_results[0] - assert nfg1.group_results[-1] == tfg2.group_results[-1] + nfg3 = combine_fooofs([tfg, tfg2]) + assert nfg3 + assert len(nfg3) == len(tfg) + len(tfg2) + assert compare_info([nfg3, tfg, tfg2], 'settings') + assert nfg3.group_results[0] == tfg.group_results[0] + assert nfg3.group_results[-1] == tfg2.group_results[-1] # Check combining 3 FOOOFGroups - nfg2 = combine_fooofs([tfg, tfg2, tfg3]) - assert nfg2 - assert len(nfg2) == len(tfg) + len(tfg2) + len(tfg3) - assert compare_settings([nfg2, tfg, tfg2, tfg3]) - assert nfg2.group_results[0] == tfg.group_results[0] - assert nfg2.group_results[-1] == tfg3.group_results[-1] + nfg4 = combine_fooofs([tfg, tfg2, tfg3]) + assert nfg4 + assert len(nfg4) == len(tfg) + len(tfg2) + len(tfg3) + assert compare_info([nfg4, tfg, tfg2, tfg3], 'settings') + assert nfg4.group_results[0] == tfg.group_results[0] + assert nfg4.group_results[-1] == tfg3.group_results[-1] # Check combining a mixture of FOOOF & FOOOFGroup - mfg3 = combine_fooofs([tfg, tfm, tfg2, tfm2]) - assert mfg3 - assert len(mfg3) == len(tfg) + 1 + len(tfg2) + 1 - assert compare_settings([tfg, tfm, tfg2, tfm2]) - assert mfg3.group_results[0] == tfg.group_results[0] - assert mfg3.group_results[-1] == tfm2.get_results() + nfg5 = combine_fooofs([tfg, tfm, tfg2, tfm2]) + assert nfg5 + assert len(nfg5) == len(tfg) + 1 + len(tfg2) + 1 + assert compare_info([nfg5, tfg, tfm, tfg2, tfm2], 'settings') + assert nfg5.group_results[0] == tfg.group_results[0] + assert nfg5.group_results[-1] == tfm2.get_results() + + # Check combining objects with no data + tfm2._reset_data_results(False, True, True) + tfg2._reset_data_results(False, True, True, True) + nfg6 = combine_fooofs([tfm2, tfg2]) + assert len(nfg6) == 1 + len(tfg2) + assert nfg6.power_spectra == None def test_combine_errors(tfm, tfg): diff --git a/fooof/tests/test_group.py b/fooof/tests/test_group.py index 3a8292e47..7cb367967 100644 --- a/fooof/tests/test_group.py +++ b/fooof/tests/test_group.py @@ -11,9 +11,9 @@ import numpy as np from fooof.group import * -from fooof.fit import FOOOFResult +from fooof.fit import FOOOFResults from fooof.synth import gen_group_power_spectra -from fooof.core.utils import get_obj_desc +from fooof.core.info import get_obj_desc from fooof.tests.utils import default_group_params, plot_test @@ -51,7 +51,7 @@ def test_fg_fit(): assert out assert len(out) == n_spectra - assert isinstance(out[0], FOOOFResult) + assert isinstance(out[0], FOOOFResults) assert np.all(out[1].aperiodic_params) def test_fg_fit_par(): @@ -66,7 +66,7 @@ def test_fg_fit_par(): assert out assert len(out) == n_spectra - assert isinstance(out[0], FOOOFResult) + assert isinstance(out[0], FOOOFResults) assert np.all(out[1].aperiodic_params) def test_fg_print(tfg): @@ -144,3 +144,12 @@ def test_fg_get_fooof(tfg): # Check that regenerated model is created for result in desc['results']: assert np.all(getattr(tfm1, result)) + + # Test when object has no data (clear a copy of tfg) + new_tfg = tfg.copy() + new_tfg._reset_data_results(False, True, True, True) + tfm2 = new_tfg.get_fooof(0, True) + assert tfm2 + # Check that data info is copied over properly + for data_info in desc['data_info']: + assert getattr(tfm2, data_info) diff --git a/fooof/tests/test_synth_params.py b/fooof/tests/test_synth_params.py index 77b01ecc6..b11a44f05 100644 --- a/fooof/tests/test_synth_params.py +++ b/fooof/tests/test_synth_params.py @@ -7,10 +7,6 @@ ################################################################################################### ################################################################################################### -def test_syn_params(): - - assert SynParams([1, 1], [10, 1, 1], 0.05) - def test_update_syn_ap_params(): syn_params = SynParams([1, 1], [10, 1, 1], 0.05) diff --git a/fooof/tests/test_utils.py b/fooof/tests/test_utils.py index bc25163da..574628936 100644 --- a/fooof/tests/test_utils.py +++ b/fooof/tests/test_utils.py @@ -17,35 +17,24 @@ def test_trim_spectrum(): assert np.array_equal(f_out, np.array([2., 3., 4.])) assert np.array_equal(p_out, np.array([3., 4., 5.])) -def test_get_settings(tfm, tfg): +def test_get_info(tfm, tfg): for f_obj in [tfm, tfg]: - assert get_settings(f_obj) + assert get_info(f_obj, 'settings') + assert get_info(f_obj, 'data_info') + assert get_info(f_obj, 'results') -def test_get_data_info(tfm, tfg): +def test_compare_info(tfm, tfg): for f_obj in [tfm, tfg]: - assert get_data_info(f_obj) -def test_compare_settings(tfm, tfg): - - for f_obj in [tfm, tfg]: f_obj2 = f_obj.copy() - assert compare_settings([f_obj, f_obj2]) - + assert compare_info([f_obj, f_obj2], 'settings') f_obj2.peak_width_limits = [2, 4] f_obj2._reset_internal_settings() + assert not compare_info([f_obj, f_obj2], 'settings') - assert not compare_settings([f_obj, f_obj2]) - -def test_compare_data_info(tfm, tfg): - - for f_obj in [tfm, tfg]: - f_obj2 = f_obj.copy() - - assert compare_data_info([f_obj, f_obj2]) - + assert compare_info([f_obj, f_obj2], 'data_info') f_obj2.freq_range = [5, 25] - - assert not compare_data_info([f_obj, f_obj2]) + assert not compare_info([f_obj, f_obj2], 'data_info') diff --git a/fooof/utils.py b/fooof/utils.py index 982d09a7c..8e5140855 100644 --- a/fooof/utils.py +++ b/fooof/utils.py @@ -1,9 +1,9 @@ -"""Utility functions for FOOOF.""" +"""Public utility & helper functions for FOOOF.""" import numpy as np from fooof.synth import gen_freqs -from fooof.core.utils import get_obj_desc +from fooof.core.info import get_obj_desc ################################################################################################### ################################################################################################### @@ -44,89 +44,47 @@ def trim_spectrum(freqs, power_spectra, f_range): return freqs_ext, power_spectra_ext -def get_settings(f_obj): - """Get a dictionary of current settings from a FOOOF or FOOOFGroup object. +def get_info(f_obj, aspect): + """Get a specified selection of information from a FOOOF derived object. Parameters ---------- f_obj : FOOOF or FOOOFGroup - FOOOF derived object to get settings from. + FOOOF derived object to get attributes from. + aspect : {'settings', 'data_info', 'results'} + Which set of attributes to compare the objects across. Returns ------- - dictionary - Settings for the input FOOOF derived object. + dict + The set of specified info from the FOOOF derived object. """ - return {setting : getattr(f_obj, setting) for setting in get_obj_desc()['settings']} + return {key : getattr(f_obj, key) for key in get_obj_desc()[aspect]} -def get_data_info(f_obj): - """Get a dictionary of current data information from a FOOOF or FOOOFGroup object. - - Parameters - ---------- - f_obj : FOOOF or FOOOFGroup - FOOOF derived object to get data information from. - - Returns - ------- - dictionary - Data information for the input FOOOF derived object. - - Notes - ----- - Data information for a FOOOF object is the frequency range and frequency resolution. - """ - - return {dat_info : getattr(f_obj, dat_info) for dat_info in get_obj_desc()['freq_info']} - - -def compare_settings(lst): - """Compare the settings between FOOOF and/or FOOOFGroup objects. +def compare_info(lst, aspect): + """Compare a specified aspect of FOOOF objects across instances. Parameters ---------- lst : list of FOOOF or FOOOFGroup objects - FOOOF related objects whose settings are to be compared. + FOOOF related objects whose attibutes are to be compared. + aspect : {'setting', 'data_info'} + Which set of attributes to compare the objects across. Returns ------- - bool + consistent : bool Whether the settings are consistent across the input list of objects. """ - # Check settings are the same across list of given objects - for ind, f_obj in enumerate(lst[:-1]): - if get_settings(f_obj) != get_settings(lst[ind+1]): - return False - - # If no settings fail comparison, return that objects have consistent settings - return True - - -def compare_data_info(lst): - """Compare the data information between FOOOF and/or FOOOFGroup objects. - - Parameters - ---------- - lst : list of FOOOF or FOOOFGroup objects - FOOOF related objects whose settings are to be compared. - - Returns - ------- - bool - Whether the data information is consistent across the input list of objects. - - Notes - ----- - Data information for a FOOOF object is the frequency range and frequency resolution. - """ - - # Check data information is the same across the list of given objects - for ind, f_obj in enumerate(lst[:-1]): - if get_data_info(f_obj) != get_data_info(lst[ind+1]): - return False + # Check specified aspect of the objects are the same across instances + for f_obj_1, f_obj_2 in zip(lst[:-1], lst[1:]): + if getattr(f_obj_1, 'get_' + aspect)() != getattr(f_obj_2, 'get_' + aspect)(): + consistent = False + break + else: + consistent = True - # If no data info comparisons fail, return that objects have consistent information - return True + return consistent