diff --git a/fooof/analysis.py b/fooof/analysis.py index 041f165b2..2c78f8222 100644 --- a/fooof/analysis.py +++ b/fooof/analysis.py @@ -1,19 +1,64 @@ -"""Basic analysis functions for FOOOF results.""" +"""Analysis functions for FOOOF results.""" import numpy as np ################################################################################################### ################################################################################################### -def get_band_peak_group(peak_params, band_def, n_fits): - """Extracts peaks within a given band of interest, for a group of FOOOF model fits. +def get_band_peak_fm(fm, band, ret_one=True, attribute='peak_params'): + """Extract peaks from a band of interest from a FOOOFGroup object. + + Parameters + ---------- + fm : FOOOF + FOOOF object to extract peak data from. + band : tuple of (float, float) + Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound). + ret_one : bool, optional, default: True + Whether to return single peak (if True) or all peaks within the range found (if False). + If True, returns the highest power peak within the search range. + attribute : {'peak_params', 'gaussian_params'} + Which attribute of peak data to extract data from. + + Returns + ------- + 1d or 2d array + Peak data. Each row is a peak, as [CF, Amp, BW] + """ + + return get_band_peak(getattr(fm, attribute + '_'), band, ret_one) + + +def get_band_peaks_fg(fg, band, attribute='peak_params'): + """Extract peaks from a band of interest from a FOOOF object. + + Parameters + ---------- + fg : FOOOFGroup + FOOOFGroup object to extract peak data from. + band : tuple of (float, float) + Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound). + attribute : {'peak_params', 'gaussian_params'} + Which attribute of peak data to extract data from. + + Returns + ------- + 2d array + Peak data. Each row is a peak, as [CF, Amp, BW]. + """ + + return get_band_peaks_group(fg.get_all_data(attribute), band, len(fg)) + + +def get_band_peaks_group(peak_params, band, n_fits): + """Extracts peaks within a given band of interest. Parameters ---------- peak_params : 2d array Peak parameters, for a group fit, from FOOOF, with shape of [n_peaks, 4]. - band_def : [float, float] - Defines the band of interest, as [lower_frequency_bound, upper_frequency_bound]. + band : tuple of (float, float) + Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound). n_fits : int The number of model fits in the FOOOFGroup data. @@ -34,32 +79,32 @@ def get_band_peak_group(peak_params, band_def, n_fits): >>> peaks = np.empty((0, 3)) >>> for f_res in fg: - >>> peaks = np.vstack((peaks, get_band_peak(f_res.peak_params, band_def, ret_one=False))) + >>> peaks = np.vstack((peaks, get_band_peak(f_res.peak_params, band, ret_one=False))) """ band_peaks = np.zeros(shape=[n_fits, 3]) - for ind in range(n_fits): - # Extacts an array per FOOOF fit, and extracts band peaks from it + # Extacts an array per FOOOF fit, and extracts band peaks from it + for ind in range(n_fits): band_peaks[ind, :] = get_band_peak(peak_params[tuple([peak_params[:, -1] == ind])][:, 0:3], - band_def=band_def, ret_one=True) + band=band, ret_one=True) return band_peaks -def get_band_peak(peak_params, band_def, ret_one=True): - """Extracts peaks within a given band of interest, for a FOOOF model fit. +def get_band_peak(peak_params, band, ret_one=True): + """Extracts peaks within a given band of interest. Parameters ---------- peak_params : 2d array Peak parameters, from FOOOF, with shape of [n_peaks, 3]. - band_def : [float, float] - Defines the band of interest, as [lower_frequency_bound, upper_frequency_bound]. + band : tuple of (float, float) + Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound). ret_one : bool, optional, default: True Whether to return single peak (if True) or all peaks within the range found (if False). - If True, returns the highest amplitude peak within the search range. + If True, returns the highest power peak within the search range. Returns ------- @@ -72,7 +117,7 @@ def get_band_peak(peak_params, band_def, ret_one=True): return np.array([np.nan, np.nan, np.nan]) # Find indices of peaks in the specified range, and check the number found - peak_inds = (peak_params[:, 0] >= band_def[0]) & (peak_params[:, 0] <= band_def[1]) + peak_inds = (peak_params[:, 0] >= band[0]) & (peak_params[:, 0] <= band[1]) n_peaks = sum(peak_inds) # If there are no peaks within the specified range @@ -82,12 +127,12 @@ def get_band_peak(peak_params, band_def, ret_one=True): band_peaks = peak_params[peak_inds, :] - # If results > 1 and ret_one, then we return the highest amplitude peak + # If results > 1 and ret_one, then we return the highest power peak # Call a sub-function to select highest power peak in band if n_peaks > 1 and ret_one: band_peaks = get_highest_amp_peak(band_peaks) - # If results == 1, return peak - [cen, power, bw] + # If results == 1, return single peak return np.squeeze(band_peaks) @@ -96,13 +141,13 @@ def get_highest_amp_peak(band_peaks): Parameters ---------- - peak_params : 2d array + band_peaks : 2d array Peak parameters, from FOOOF, with shape of [n_peaks, 3]. Returns ------- - band_peaks : array - Peak data. Each row is a peak, as [CF, Amp, BW]. + 1d array + Seleced peak data. Row is a peak, as [CF, Amp, BW]. """ # Catch & return NaN if empty diff --git a/fooof/bands.py b/fooof/bands.py new file mode 100644 index 000000000..a9ee639f9 --- /dev/null +++ b/fooof/bands.py @@ -0,0 +1,132 @@ +"""A class for managing band definitions.""" + +from collections import OrderedDict + +################################################################################################### +################################################################################################### + +class Bands(): + """Class to hold bands definitions. + + Attributes + ---------- + bands : dict + Dictionary of band definitions. + Each entry should be as {'label' : (f_low, f_high)}. + """ + + def __init__(self, input_bands={}): + """Initialize the Bands object. + + Parameters + ---------- + input_bands : dict, optional + A dictionary of oscillation bands to use. + """ + + self.bands = OrderedDict() + + for label, band_def in input_bands.items(): + self.add_band(label, band_def) + + def __getitem__(self, label): + + try: + return self.bands[label] + except KeyError: + message = "The label '{}' was not found in the defined bands.".format(label) + raise BandNotDefinedError(message) from None + + def __getattr__(self, label): + + return self.__getitem__(label) + + def __repr__(self): + + return '\n'.join(['{:8} : {:2} - {:2} Hz'.format(key, *val) \ + for key, val in self.bands.items()]) + + def __len__(self): + + return self.n_bands + + def __iter__(self): + + for label, band_definition in self.bands.items(): + yield (label, band_definition) + + @property + def labels(self): + """Get the labels for all bands defined in the object.""" + + return list(self.bands.keys()) + + @property + def n_bands(self): + """Get the number of bands defined in the object.""" + + return len(self.bands) + + + def add_band(self, label, band_definition): + """Add a new oscillation band definition. + + Parameters + ---------- + label : str + Band label to add. + band_definition : tuple of (float, float) + The lower and upper frequency limit of the band, in Hz. + """ + + self._check_band(label, band_definition) + self.bands[label] = band_definition + + + def remove_band(self, label): + """Remove a previously defined oscillation band. + + Parameters + ---------- + label : str + Band label to remove from band definitions. + """ + + self.bands.pop(label) + + + @staticmethod + def _check_band(label, band_definition): + """Check that a proposed band definition is valid. + + Parameters + ---------- + label : str + The name of the new band. + band_definition : tuple of (float, float) + The lower and upper frequency limit of the band, in Hz. + + Raises + ------ + InconsistentDataError + If band definition is not properly formatted. + """ + + # Check that band name is a string + if not isinstance(label, str): + raise InconsistentDataError('Band name definition is not a string.') + + # Check that band limits has the right size + if not len(band_definition) == 2: + raise InconsistentDataError('Band limit definition is not the right size.') + + # Safety check that limits are in correct order + if not band_definition[0] < band_definition[1]: + raise InconsistentDataError('Band limit definitions are invalid.') + + +class BandNotDefinedError(Exception): + pass + +class InconsistentDataError(Exception): + pass diff --git a/fooof/funcs.py b/fooof/funcs.py index 7ed5e9ae8..661ca51fa 100644 --- a/fooof/funcs.py +++ b/fooof/funcs.py @@ -2,13 +2,64 @@ import numpy as np -from fooof import FOOOFGroup -from fooof.synth.gen import gen_freqs +from fooof import FOOOF, FOOOFGroup +from fooof.data import FOOOFResults from fooof.utils import compare_info +from fooof.synth.gen import gen_freqs +from fooof.analysis import get_band_peaks_fg ################################################################################################### ################################################################################################### +def average_fg(fg, bands, avg_method='mean'): + """Average across a FOOOFGroup object. + + Parameters + ---------- + fg : FOOOFGroup + A FOOOFGroup object with data to average across. + bands : Bands + Bands object that defines the frequency bands to collapse peaks across. + avg : {'mean', 'median'} + Averaging function to use. + + Returns + ------- + fm : FOOOF + FOOOF object containing the average results from the FOOOFGroup input. + """ + + if avg_method not in ['mean', 'median']: + raise ValueError('Requested average method not understood.') + if not len(fg): + raise ValueError('Input FOOOFGroup has no fit results - can not proceed.') + + if avg_method == 'mean': + avg_func = np.nanmean + elif avg_method == 'median': + avg_func = np.nanmedian + + ap_params = avg_func(fg.get_all_data('aperiodic_params'), 0) + + peak_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'peak_params'), 0) \ + for label, band in bands]) + gaussian_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'gaussian_params'), 0) \ + for label, band in bands]) + + r2 = avg_func(fg.get_all_data('r_squared')) + error = avg_func(fg.get_all_data('error')) + + results = FOOOFResults(ap_params, peak_params, r2, error, gaussian_params) + + # Create the new FOOOF object, with settings, data info & results + fm = FOOOF() + fm.add_settings(fg.get_settings()) + fm.add_data_info(fg.get_data_info()) + fm.add_results(results) + + return fm + + def combine_fooofs(fooofs): """Combine a group of FOOOF and/or FOOOFGroup objects into a single FOOOFGroup object. diff --git a/fooof/tests/conftest.py b/fooof/tests/conftest.py index 99311c6d6..9be886105 100644 --- a/fooof/tests/conftest.py +++ b/fooof/tests/conftest.py @@ -8,7 +8,7 @@ import numpy as np from fooof.core.modutils import safe_import -from fooof.tests.utils import get_tfm, get_tfg +from fooof.tests.utils import get_tfm, get_tfg, get_tbands plt = safe_import('.pyplot', 'matplotlib') @@ -45,6 +45,10 @@ def tfm(): def tfg(): yield get_tfg() +@pytest.fixture(scope='session') +def tbands(): + yield get_tbands() + @pytest.fixture(scope='session') def skip_if_no_mpl(): if not safe_import('matplotlib'): diff --git a/fooof/tests/test_analysis.py b/fooof/tests/test_analysis.py index abb468ab0..7307aff1e 100644 --- a/fooof/tests/test_analysis.py +++ b/fooof/tests/test_analysis.py @@ -7,15 +7,23 @@ ################################################################################################### ################################################################################################### -def test_get_band_peak_group(): +def test_get_band_peak_fm(tfm): + + assert np.all(get_band_peak_fm(tfm, (8, 12))) + +def test_get_band_peaks_fg(tfg): + + assert np.all(get_band_peaks_fg(tfg, (8, 12))) + +def test_get_band_peaks_group(): dat = np.array([[10, 1, 1.8, 0], [13, 1, 2, 2], [14, 2, 4, 2]]) - out1 = get_band_peak_group(dat, [8, 12], 3) + out1 = get_band_peaks_group(dat, [8, 12], 3) assert out1.shape == (3, 3) assert np.array_equal(out1[0, :], [10, 1, 1.8]) - out2 = get_band_peak_group(dat, [12, 16], 3) + out2 = get_band_peaks_group(dat, [12, 16], 3) assert out2.shape == (3, 3) assert np.array_equal(out2[2, :], [14, 2, 4]) @@ -50,4 +58,4 @@ def test_empty_inputs(): dat = np.empty(shape=[0, 4]) - assert np.all(get_band_peak_group(dat, [8, 12], 0)) + assert np.all(get_band_peaks_group(dat, [8, 12], 0)) diff --git a/fooof/tests/test_bands.py b/fooof/tests/test_bands.py new file mode 100644 index 000000000..3b5bbd783 --- /dev/null +++ b/fooof/tests/test_bands.py @@ -0,0 +1,48 @@ +"""Test functions for FOOOF bands.""" + +from py.test import raises + +from fooof.bands import * + +################################################################################################### +################################################################################################### + +def test_bands(): + + bands = Bands() + assert isinstance(bands, Bands) + +def test_bands_add_band(): + + bands = Bands() + bands.add_band('test', (5, 10)) + assert bands.bands == {'test' : (5, 10)} + +def test_bands_remove_band(): + + bands = Bands() + bands.add_band('test', (5, 10)) + bands.remove_band('test') + assert bands.bands == {} + +def test_bands_errors(): + + bands = Bands() + with raises(InconsistentDataError): + bands.add_band(1, (1, 1)) + with raises(InconsistentDataError): + bands.add_band('test', (1, 1, 1)) + with raises(InconsistentDataError): + bands.add_band('test', (2, 1)) + +def test_bands_dunders(tbands): + + assert tbands['theta'] + assert tbands.alpha + assert repr(tbands) + assert len(tbands) == 3 + +def test_bands_properties(tbands): + + assert set(tbands.labels) == set(['theta', 'alpha', 'beta']) + assert tbands.n_bands == 3 diff --git a/fooof/tests/test_funcs.py b/fooof/tests/test_funcs.py index a83fa4cdb..24a3f9743 100644 --- a/fooof/tests/test_funcs.py +++ b/fooof/tests/test_funcs.py @@ -14,6 +14,11 @@ ################################################################################################### ################################################################################################### +def test_average_fg(tfg, tbands): + + nfm = average_fg(tfg, tbands) + assert nfm + def test_combine_fooofs(tfm, tfg): tfm2 = tfm.copy(); tfm3 = tfm.copy() diff --git a/fooof/tests/utils.py b/fooof/tests/utils.py index 51814283a..911b914b9 100644 --- a/fooof/tests/utils.py +++ b/fooof/tests/utils.py @@ -3,6 +3,7 @@ from functools import wraps from fooof import FOOOF, FOOOFGroup +from fooof.bands import Bands from fooof.synth import gen_power_spectrum, gen_group_power_spectra, param_sampler from fooof.core.modutils import safe_import @@ -36,6 +37,11 @@ def get_tfg(): return tfg +def get_tbands(): + """Get a bands object, for testing.""" + + return Bands({'theta' : (4, 8), 'alpha' : (8, 12), 'beta' : (13, 30)}) + def default_group_params(): """Create default parameters for generating a test group of power spectra."""