Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 65 additions & 20 deletions fooof/analysis.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,64 @@
"""Basic analysis functions for FOOOF results."""
"""Analysis functions for FOOOF results."""

import numpy as np

###################################################################################################
###################################################################################################

def get_band_peak_group(peak_params, band_def, n_fits):
"""Extracts peaks within a given band of interest, for a group of FOOOF model fits.
def get_band_peak_fm(fm, band, ret_one=True, attribute='peak_params'):
"""Extract peaks from a band of interest from a FOOOFGroup object.

Parameters
----------
fm : FOOOF
FOOOF object to extract peak data from.
band : tuple of (float, float)
Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound).
ret_one : bool, optional, default: True
Whether to return single peak (if True) or all peaks within the range found (if False).
If True, returns the highest power peak within the search range.
attribute : {'peak_params', 'gaussian_params'}
Which attribute of peak data to extract data from.

Returns
-------
1d or 2d array
Peak data. Each row is a peak, as [CF, Amp, BW]
"""

return get_band_peak(getattr(fm, attribute + '_'), band, ret_one)


def get_band_peaks_fg(fg, band, attribute='peak_params'):
"""Extract peaks from a band of interest from a FOOOF object.

Parameters
----------
fg : FOOOFGroup
FOOOFGroup object to extract peak data from.
band : tuple of (float, float)
Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound).
attribute : {'peak_params', 'gaussian_params'}
Which attribute of peak data to extract data from.

Returns
-------
2d array
Peak data. Each row is a peak, as [CF, Amp, BW].
"""

return get_band_peaks_group(fg.get_all_data(attribute), band, len(fg))


def get_band_peaks_group(peak_params, band, n_fits):
"""Extracts peaks within a given band of interest.

Parameters
----------
peak_params : 2d array
Peak parameters, for a group fit, from FOOOF, with shape of [n_peaks, 4].
band_def : [float, float]
Defines the band of interest, as [lower_frequency_bound, upper_frequency_bound].
band : tuple of (float, float)
Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound).
n_fits : int
The number of model fits in the FOOOFGroup data.

Expand All @@ -34,32 +79,32 @@ def get_band_peak_group(peak_params, band_def, n_fits):

>>> peaks = np.empty((0, 3))
>>> for f_res in fg:
>>> peaks = np.vstack((peaks, get_band_peak(f_res.peak_params, band_def, ret_one=False)))
>>> peaks = np.vstack((peaks, get_band_peak(f_res.peak_params, band, ret_one=False)))

"""

band_peaks = np.zeros(shape=[n_fits, 3])
for ind in range(n_fits):

# Extacts an array per FOOOF fit, and extracts band peaks from it
# Extacts an array per FOOOF fit, and extracts band peaks from it
for ind in range(n_fits):
band_peaks[ind, :] = get_band_peak(peak_params[tuple([peak_params[:, -1] == ind])][:, 0:3],
band_def=band_def, ret_one=True)
band=band, ret_one=True)

return band_peaks


def get_band_peak(peak_params, band_def, ret_one=True):
"""Extracts peaks within a given band of interest, for a FOOOF model fit.
def get_band_peak(peak_params, band, ret_one=True):
"""Extracts peaks within a given band of interest.

Parameters
----------
peak_params : 2d array
Peak parameters, from FOOOF, with shape of [n_peaks, 3].
band_def : [float, float]
Defines the band of interest, as [lower_frequency_bound, upper_frequency_bound].
band : tuple of (float, float)
Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound).
ret_one : bool, optional, default: True
Whether to return single peak (if True) or all peaks within the range found (if False).
If True, returns the highest amplitude peak within the search range.
If True, returns the highest power peak within the search range.

Returns
-------
Expand All @@ -72,7 +117,7 @@ def get_band_peak(peak_params, band_def, ret_one=True):
return np.array([np.nan, np.nan, np.nan])

# Find indices of peaks in the specified range, and check the number found
peak_inds = (peak_params[:, 0] >= band_def[0]) & (peak_params[:, 0] <= band_def[1])
peak_inds = (peak_params[:, 0] >= band[0]) & (peak_params[:, 0] <= band[1])
n_peaks = sum(peak_inds)

# If there are no peaks within the specified range
Expand All @@ -82,12 +127,12 @@ def get_band_peak(peak_params, band_def, ret_one=True):

band_peaks = peak_params[peak_inds, :]

# If results > 1 and ret_one, then we return the highest amplitude peak
# If results > 1 and ret_one, then we return the highest power peak
# Call a sub-function to select highest power peak in band
if n_peaks > 1 and ret_one:
band_peaks = get_highest_amp_peak(band_peaks)

# If results == 1, return peak - [cen, power, bw]
# If results == 1, return single peak
return np.squeeze(band_peaks)


Expand All @@ -96,13 +141,13 @@ def get_highest_amp_peak(band_peaks):

Parameters
----------
peak_params : 2d array
band_peaks : 2d array
Peak parameters, from FOOOF, with shape of [n_peaks, 3].

Returns
-------
band_peaks : array
Peak data. Each row is a peak, as [CF, Amp, BW].
1d array
Seleced peak data. Row is a peak, as [CF, Amp, BW].
"""

# Catch & return NaN if empty
Expand Down
132 changes: 132 additions & 0 deletions fooof/bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""A class for managing band definitions."""

from collections import OrderedDict

###################################################################################################
###################################################################################################

class Bands():
"""Class to hold bands definitions.

Attributes
----------
bands : dict
Dictionary of band definitions.
Each entry should be as {'label' : (f_low, f_high)}.
"""

def __init__(self, input_bands={}):
"""Initialize the Bands object.

Parameters
----------
input_bands : dict, optional
A dictionary of oscillation bands to use.
"""

self.bands = OrderedDict()

for label, band_def in input_bands.items():
self.add_band(label, band_def)

def __getitem__(self, label):

try:
return self.bands[label]
except KeyError:
message = "The label '{}' was not found in the defined bands.".format(label)
raise BandNotDefinedError(message) from None

def __getattr__(self, label):

return self.__getitem__(label)

def __repr__(self):

return '\n'.join(['{:8} : {:2} - {:2} Hz'.format(key, *val) \
for key, val in self.bands.items()])

def __len__(self):

return self.n_bands

def __iter__(self):

for label, band_definition in self.bands.items():
yield (label, band_definition)

@property
def labels(self):
"""Get the labels for all bands defined in the object."""

return list(self.bands.keys())

@property
def n_bands(self):
"""Get the number of bands defined in the object."""

return len(self.bands)


def add_band(self, label, band_definition):
"""Add a new oscillation band definition.

Parameters
----------
label : str
Band label to add.
band_definition : tuple of (float, float)
The lower and upper frequency limit of the band, in Hz.
"""

self._check_band(label, band_definition)
self.bands[label] = band_definition


def remove_band(self, label):
"""Remove a previously defined oscillation band.

Parameters
----------
label : str
Band label to remove from band definitions.
"""

self.bands.pop(label)


@staticmethod
def _check_band(label, band_definition):
"""Check that a proposed band definition is valid.

Parameters
----------
label : str
The name of the new band.
band_definition : tuple of (float, float)
The lower and upper frequency limit of the band, in Hz.

Raises
------
InconsistentDataError
If band definition is not properly formatted.
"""

# Check that band name is a string
if not isinstance(label, str):
raise InconsistentDataError('Band name definition is not a string.')

# Check that band limits has the right size
if not len(band_definition) == 2:
raise InconsistentDataError('Band limit definition is not the right size.')

# Safety check that limits are in correct order
if not band_definition[0] < band_definition[1]:
raise InconsistentDataError('Band limit definitions are invalid.')


class BandNotDefinedError(Exception):
pass

class InconsistentDataError(Exception):
pass
55 changes: 53 additions & 2 deletions fooof/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,64 @@

import numpy as np

from fooof import FOOOFGroup
from fooof.synth.gen import gen_freqs
from fooof import FOOOF, FOOOFGroup
from fooof.data import FOOOFResults
from fooof.utils import compare_info
from fooof.synth.gen import gen_freqs
from fooof.analysis import get_band_peaks_fg

###################################################################################################
###################################################################################################

def average_fg(fg, bands, avg_method='mean'):
"""Average across a FOOOFGroup object.

Parameters
----------
fg : FOOOFGroup
A FOOOFGroup object with data to average across.
bands : Bands
Bands object that defines the frequency bands to collapse peaks across.
avg : {'mean', 'median'}
Averaging function to use.

Returns
-------
fm : FOOOF
FOOOF object containing the average results from the FOOOFGroup input.
"""

if avg_method not in ['mean', 'median']:
raise ValueError('Requested average method not understood.')
if not len(fg):
raise ValueError('Input FOOOFGroup has no fit results - can not proceed.')

if avg_method == 'mean':
avg_func = np.nanmean
elif avg_method == 'median':
avg_func = np.nanmedian

ap_params = avg_func(fg.get_all_data('aperiodic_params'), 0)

peak_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'peak_params'), 0) \
for label, band in bands])
gaussian_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'gaussian_params'), 0) \
for label, band in bands])

r2 = avg_func(fg.get_all_data('r_squared'))
error = avg_func(fg.get_all_data('error'))

results = FOOOFResults(ap_params, peak_params, r2, error, gaussian_params)

# Create the new FOOOF object, with settings, data info & results
fm = FOOOF()
fm.add_settings(fg.get_settings())
fm.add_data_info(fg.get_data_info())
fm.add_results(results)

return fm


def combine_fooofs(fooofs):
"""Combine a group of FOOOF and/or FOOOFGroup objects into a single FOOOFGroup object.

Expand Down
6 changes: 5 additions & 1 deletion fooof/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np

from fooof.core.modutils import safe_import
from fooof.tests.utils import get_tfm, get_tfg
from fooof.tests.utils import get_tfm, get_tfg, get_tbands

plt = safe_import('.pyplot', 'matplotlib')

Expand Down Expand Up @@ -45,6 +45,10 @@ def tfm():
def tfg():
yield get_tfg()

@pytest.fixture(scope='session')
def tbands():
yield get_tbands()

@pytest.fixture(scope='session')
def skip_if_no_mpl():
if not safe_import('matplotlib'):
Expand Down
Loading