Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion specparam/plts/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def plot_event_model(event, **plot_kwargs):
color=PARAM_COLORS['presence'], ax=next(axes))
next(axes).axis('off')

# 03: goodness of fit
# 03: metrics
for ind, glabel in enumerate(event.results.metrics.labels):
plot_param_over_time_yshade(\
None, event.results.event_time_results[glabel],
Expand Down
43 changes: 26 additions & 17 deletions specparam/plts/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def plot_group_model(group, **plot_kwargs):

# Goodness of fit plot
ax1 = plt.subplot(gs[0, 1])
plot_group_goodness(group, ax1, **scatter_kwargs, custom_styler=None)
plot_group_metrics(group, ax1, **scatter_kwargs, custom_styler=None)

# Center frequencies plot
ax2 = plt.subplot(gs[1, :])
Expand All @@ -79,17 +79,17 @@ def plot_group_aperiodic(group, ax=None, **plot_kwargs):
if group.modes.aperiodic.name == 'knee':
plot_scatter_2(group.results.get_params('aperiodic', 'exponent'), 'Exponent',
group.results.get_params('aperiodic', 'knee'), 'Knee',
'Aperiodic Fit', ax=ax)
'Aperiodic Parameters', ax=ax)
else:
plot_scatter_1(group.results.get_params('aperiodic', 'exponent'), 'Exponent',
'Aperiodic Fit', ax=ax)
'Aperiodic Parameters', ax=ax)


@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_group_goodness(group, ax=None, **plot_kwargs):
"""Plot goodness of fit results, in a scatter plot.
def plot_group_metrics(group, ax=None, **plot_kwargs):
"""Plot metrics results, in a scatter plot.

Parameters
----------
Expand All @@ -101,17 +101,26 @@ def plot_group_goodness(group, ax=None, **plot_kwargs):
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

# Get indices of metrics to plot
err_ind = find_first_ind(group.results.metrics.labels, 'error')
err_label = group.results.metrics.labels[err_ind]
gof_ind = find_first_ind(group.results.metrics.labels, 'gof')
gof_label = group.results.metrics.labels[gof_ind]
if len(group.results.metrics) == 0:
ax.set(xticks=[], yticks=[])

plot_scatter_2(group.results.get_metrics(err_label),
group.results.metrics.flabels[err_ind],
group.results.get_metrics(gof_label),
group.results.metrics.flabels[gof_ind],
'Fit Quality', ax=ax)
if len(group.results.metrics) == 1:
plot_scatter_1(group.results.get_metrics(group.results.metrics.labels[0]),
group.results.metrics.flabels[0],
'Metrics', ax=ax)

elif len(group.results.metrics) >= 2:
ind1 = 0
ind2 = 1
if 'error' in group.results.metrics.categories:
ind1 = find_first_ind(group.results.metrics.labels, 'error')
if 'gof' in group.results.metrics.categories:
ind2 = find_first_ind(group.results.metrics.labels, 'gof')
plot_scatter_2(group.results.get_metrics(group.results.metrics.labels[ind1]),
group.results.metrics.flabels[ind1],
group.results.get_metrics(group.results.metrics.labels[ind2]),
group.results.metrics.flabels[ind2],
'Metrics', ax=ax)


@savefig
Expand All @@ -130,5 +139,5 @@ def plot_group_peak_frequencies(group, ax=None, **plot_kwargs):
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

plot_hist(group.results.get_params('peak', 0)[:, 0], 'Center Frequency',
'Peaks - Center Frequencies', x_lims=group.data.freq_range, ax=ax)
plot_hist(group.results.get_params('peak', 'cf')[:, 0], 'Center Frequency',
'Peak Parameters - Center Frequencies', x_lims=group.data.freq_range, ax=ax)
2 changes: 1 addition & 1 deletion specparam/plts/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def plot_time_model(time, **plot_kwargs):
colors=[PARAM_COLORS[plabel] for plabel in time.modes.periodic.params.labels],
title='Periodic Parameters - ' + blabel, ax=next(axes))

# 03: goodness of fit
# 03: metrics
err_ind = find_first_ind(time.results.metrics.labels, 'error')
gof_ind = find_first_ind(time.results.metrics.labels, 'gof')
plot_params_over_time(None, \
Expand Down
4 changes: 2 additions & 2 deletions specparam/reports/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from specparam.io.utils import create_file_path
from specparam.modutils.dependencies import safe_import, check_dependency
from specparam.plts.templates import plot_text
from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness,
from specparam.plts.group import (plot_group_aperiodic, plot_group_metrics,
plot_group_peak_frequencies)
from specparam.reports.strings import (gen_settings_str, gen_model_results_str,
gen_group_results_str, gen_time_results_str,
Expand Down Expand Up @@ -99,7 +99,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True):

# Goodness of fit plot
ax2 = plt.subplot(grid[1, 1])
plot_group_goodness(group, ax2, custom_styler=None)
plot_group_metrics(group, ax2, custom_styler=None)

# Peak center frequencies plot
ax3 = plt.subplot(grid[2, :])
Expand Down
82 changes: 43 additions & 39 deletions specparam/reports/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ def gen_methods_report_str(concise=False):
return output


# TODO: UPDATE
def gen_methods_text_str(model=None):
"""Generate a string representation of a template methods report.

Expand All @@ -334,32 +333,43 @@ def gen_methods_text_str(model=None):
If None, the text is returned as a template, without values.
"""

template = (
if model:
settings_names = list(model.algorithm.settings.values.keys())
settings_values = list(model.algorithm.settings.values.values())
else:
settings_names = []
settings_values = []

template = [
"The periodic & aperiodic spectral parameterization algorithm (version {}) "
"was used to parameterize neural power spectra. "
"The model was fit with {} aperiodic mode and {} periodic mode. "
"Settings for the algorithm were set as: "
"peak width limits : {}; "
"max number of peaks : {}; "
"minimum peak height : {}; "
"peak threshold : {}; ."
]

if settings_names:
settings_strs = [el + ' : {}, ' for el in settings_names]
settings_strs[-1] = settings_strs[-1][:-2] + '. '
template.extend(settings_strs)
else:
template.extend('XX. ')

template.extend([
"Power spectra were parameterized across the frequency range "
"{} to {} Hz."
)
])

if model:
freq_range = model.data.freq_range if model.data.has_data else ('XX', 'XX')
if model and model.data.has_data:
freq_range = model.data.freq_range
else:
freq_range = ('XX', 'XX')

methods_str = template.format(MODULE_VERSION,
model.modes.aperiodic.name if model else 'XX',
model.modes.periodic.name if model else 'XX',
model.algorithm.settings.peak_width_limits if model else 'XX',
model.algorithm.settings.max_n_peaks if model else 'XX',
model.algorithm.settings.min_peak_height if model else 'XX',
model.algorithm.settings.peak_threshold if model else 'XX',
*freq_range)
methods_str = ''.join(template).format(\
MODULE_VERSION,
model.modes.aperiodic.name if model else 'XX',
model.modes.periodic.name if model else 'XX',
*settings_values,
*freq_range)

return methods_str

Expand Down Expand Up @@ -401,21 +411,18 @@ def gen_model_results_str(model, concise=False):
_report_str_model(model),
'',

# Aperiodic parameters
'Aperiodic Parameters (\'{}\' mode)'.format(model.modes.aperiodic.name),
'(' + ', '.join(model.modes.aperiodic.params.labels) + ')',
', '.join(['{:2.4f}'] * \
len(model.results.params.aperiodic.params)).format(*model.results.params.aperiodic.params),
'',

# Peak parameters
'Peak Parameters (\'{}\' mode) {} peaks found'.format(\
model.modes.periodic.name, model.results.n_peaks),
*[peak_str.format(*op) for op in model.results.params.periodic.params],
'',

# Metrics
'Model fit quality metrics:',
'Model metrics:',
*['{:>18s} is {:1.4f} {:8s}'.format('{:s} ({:s})'.format(*key.split('_')), res, ' ') \
for key, res in model.results.metrics.results.items()],
'',
Expand Down Expand Up @@ -460,28 +467,31 @@ def gen_group_results_str(group, concise=False):
_report_str_model(group),
'',

# Aperiodic parameters
'Aperiodic Parameters (\'{}\' mode)'.format(group.modes.aperiodic.name),
*[el for el in [\
'{:8s} - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}'.format(label, \
*compute_arr_desc(group.results.get_params('aperiodic', label))) \
for label in group.modes.aperiodic.params.labels]],
'',

# Peak Parameters
'Peak Parameters (\'{}\' mode) {} total peaks found'.format(\
group.modes.periodic.name, sum(group.results.n_peaks)),
'',
]

# Metrics
'Model fit quality metrics:',
*['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\
'{:s} ({:s})'.format(*label.split('_')),
*compute_arr_desc(group.results.get_metrics(label))) \
for label in group.results.metrics.labels],
'',
if len(group.results.metrics) > 0:
str_lst.extend([
'Model metrics:',
*['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\
'{:s} ({:s})'.format(*label.split('_')),
*compute_arr_desc(group.results.get_metrics(label))) \
for label in group.results.metrics.labels],
'',
])

str_lst.extend([
DIVIDER,
]
])

output = _format(str_lst, concise)

Expand Down Expand Up @@ -525,15 +535,13 @@ def gen_time_results_str(time, concise=False):
_report_str_model(time),
'',

# Aperiodic parameters
'Aperiodic Parameters (\'{}\' mode)'.format(time.modes.aperiodic.name),
*[el for el in [\
'{:8s} - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}'.format(label, \
*compute_arr_desc(time.results.time_results[label])) \
for label in time.modes.aperiodic.params.labels]],
'',

# Peak Parameters
'Peak Parameters (\'{}\' mode) - mean values across windows'.format(\
time.modes.periodic.name),
*[peak_str.format(*[band_label] + \
Expand All @@ -543,8 +551,7 @@ def gen_time_results_str(time, concise=False):
for band_label in time.results.bands.labels],
'',

# Metrics
'Model fit quality metrics (values across windows):',
'Model metrics (values across windows):',
*['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\
'{:s} ({:s})'.format(*key.split('_')),
*compute_arr_desc(time.results.time_results[key])) \
Expand Down Expand Up @@ -597,15 +604,13 @@ def gen_event_results_str(event, concise=False):
_report_str_model(event),
'',

# Aperiodic parameters
'Aperiodic Parameters (\'{}\' mode)'.format(event.modes.aperiodic.name),
*[el for el in [\
'{:8s} - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}'.format(label, \
*compute_arr_desc(np.mean(event.results.event_time_results[label]))) \
for label in event.modes.aperiodic.params.labels]],
'',

# Peak Parameters
'Peak Parameters (\'{}\' mode) - mean values across windows'.format(\
event.modes.periodic.name),
*[peak_str.format(*[band_label] + \
Expand All @@ -616,8 +621,7 @@ def gen_event_results_str(event, concise=False):
for band_label in event.results.bands.labels],
'',

# Metrics
'Model fit quality metrics (values across events):',
'Model metrics (values across events):',
*['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\
'{:s} ({:s})'.format(*key.split('_')),
*compute_arr_desc(np.mean(event.results.event_time_results[key], 1))) \
Expand Down
2 changes: 1 addition & 1 deletion specparam/results/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def add_results(self, results):


def get_results(self):
"""Return model fit parameters and goodness of fit metrics.
"""Return model fit parameters and metrics.

Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions specparam/tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ def test_fit_knee():
assert np.allclose(gauss, tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0])

def test_fit_default_metrics():
"""Test goodness of fit & error metrics, post model fitting."""
"""Test computing metrics, post model fitting."""

tfm = SpectralModel(verbose=False)

# Hack fake data with known properties: total error magnitude 2
tfm.data.power_spectrum = np.array([1, 2, 3, 4, 5])
tfm.results.model.modeled_spectrum = np.array([1, 2, 5, 4, 5])

# Check default goodness of fit and error measures
# Check default metrics
tfm.results.metrics.compute_metrics(tfm.data, tfm.results)
assert np.isclose(tfm.results.metrics.results['error_mae'], 0.4)
assert np.isclose(tfm.results.metrics.results['gof_rsquared'], 0.75757575)
Expand Down
6 changes: 3 additions & 3 deletions specparam/tests/plts/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def test_plot_group_aperiodic(tfg, skip_if_no_mpl):
file_name='test_plot_group_aperiodic.png')

@plot_test
def test_plot_group_goodness(tfg, skip_if_no_mpl):
def test_plot_group_metrics(tfg, skip_if_no_mpl):

plot_group_goodness(tfg, file_path=TEST_PLOTS_PATH,
file_name='test_plot_group_goodness.png')
plot_group_metrics(tfg, file_path=TEST_PLOTS_PATH,
file_name='test_plot_group_metrics.png')

@plot_test
def test_plot_group_peak_frequencies(tfg, skip_if_no_mpl):
Expand Down