From a8bd2fb90d77833de7c5138656acbdf40edce601 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 29 Oct 2025 19:28:17 +0000 Subject: [PATCH 1/7] update obj descriptions --- specparam/models/model.py | 7 +++++++ specparam/objs/results.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/specparam/models/model.py b/specparam/models/model.py index 7b2545f2..ed65443c 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -38,6 +38,13 @@ class SpectralModel(BaseModel): Which approach to take for fitting the aperiodic component. periodic_mode : {'gaussian', 'skewed_gaussian', 'cauchy'} Which approach to take for fitting the periodic component. + metrics : Metrics or list of Metric or list or str + Metrics definition(s) to use to evaluate the model. + bands : Bands or dict or int or None, optional + Bands object with band definitions, or definition that can be turned into a Bands object. + debug : bool, optional, default: False + Whether to run in debug mode. + If in debug, any errors encountered during fitting will raise an error. verbose : bool, optional, default: True Verbosity mode. If True, prints out warnings and general status updates. **model_kwargs diff --git a/specparam/objs/results.py b/specparam/objs/results.py index bd422fe7..9b81d0cf 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -37,8 +37,8 @@ class Results(): Modes object with fit mode definitions. metrics : Metrics Metrics object with metric definitions. - bands : Bands - Bands object with band definitions. + bands : Bands or dict or int or None + Bands object with band definitions, or definition that can be turned into a Bands object. Attributes ---------- From ca1260b380b3cc327e5edc313981c89544be8017 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 29 Oct 2025 19:28:35 +0000 Subject: [PATCH 2/7] refactor metric definitions --- specparam/measures/metrics.py | 59 ++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/specparam/measures/metrics.py b/specparam/measures/metrics.py index b073ccaf..062c0acd 100644 --- a/specparam/measures/metrics.py +++ b/specparam/measures/metrics.py @@ -6,19 +6,62 @@ from specparam.measures.gof import compute_r_squared, compute_adj_r_squared ################################################################################################### +## ERROR METRICS + +error_mae = Metric( + category='error', + measure='mae', + func=compute_mean_abs_error, +) + +error_mse = Metric( + category='error', + measure='mse', + func=compute_mean_squared_error +) + +error_rmse = Metric( + category='error', + measure='rmse', + func=compute_root_mean_squared_error, +) + +error_medae = Metric( + category='error', + measure='medae', + func=compute_median_abs_error, +) + ################################################################################################### +## GOF + +gof_rsquared = Metric( + category='gof', + measure='rsquared', + func=compute_r_squared, +) + +gof_adjrsquared = Metric( + category='gof', + measure='adjrsquared', + func=compute_adj_r_squared, + kwargs={'n_params' : lambda data, results: \ + results.params.periodic.params.size + results.params.aperiodic.params.size}, +) + +################################################################################################### +## COLLECT ALL METRICS TOGETHER METRICS = { # Available error metrics - 'error_mae' : Metric('error', 'mae', compute_mean_abs_error), - 'error_mse' : Metric('error', 'mse', compute_mean_squared_error), - 'error_rmse' : Metric('error', 'rmse', compute_root_mean_squared_error), - 'error_medae' : Metric('error', 'medae', compute_median_abs_error), + 'error_mae' : error_mae, + 'error_mse' : error_mse, + 'error_rmse' : error_rmse, + 'error_medae' : error_medae, # Available GOF / r-squared metrics - 'gof_rsquared' : Metric('gof', 'rsquared', compute_r_squared), - 'gof_adjrsquared' : Metric('gof', 'adjrsquared', compute_adj_r_squared, \ - {'n_params' : lambda data, results: \ - results.params.periodic.params.size + results.params.aperiodic.params.size}) + 'gof_rsquared' : gof_rsquared, + 'gof_adjrsquared' : gof_adjrsquared, + } From 13a5981c6cfdec20bdaa134e6f384ebf1b42fdc8 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 29 Oct 2025 19:37:04 +0000 Subject: [PATCH 3/7] add description field to Metric object --- specparam/objs/metrics.py | 5 ++++- specparam/tests/objs/test_metrics.py | 18 +++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/specparam/objs/metrics.py b/specparam/objs/metrics.py index 1915481d..6102d552 100644 --- a/specparam/objs/metrics.py +++ b/specparam/objs/metrics.py @@ -16,6 +16,8 @@ class Metric(): The category of measure, e.g. 'error' or 'gof'. measure : str The specific measure, e.g. 'r_squared'. + description : str + Description of the metric. func : callable The function that computes the metric. kwargs : dictionary @@ -25,11 +27,12 @@ class Metric(): and returns the desired parameter / computed value. """ - def __init__(self, category, measure, func, kwargs=None): + def __init__(self, category, measure, description, func, kwargs=None): """Initialize metric.""" self.category = category self.measure = measure + self.description = description self.func = func self.result = np.nan self.kwargs = {} if not kwargs else kwargs diff --git a/specparam/tests/objs/test_metrics.py b/specparam/tests/objs/test_metrics.py index 793b838c..b925d8d9 100644 --- a/specparam/tests/objs/test_metrics.py +++ b/specparam/tests/objs/test_metrics.py @@ -12,7 +12,7 @@ def test_metric(tfm): - metric = Metric('error', 'mae', compute_mean_abs_error) + metric = Metric('error', 'mae', 'Description.', compute_mean_abs_error) assert isinstance(metric, Metric) assert isinstance(metric.label, str) @@ -21,7 +21,7 @@ def test_metric(tfm): def test_metric_kwargs(tfm): - metric = Metric('gof', 'ar2', compute_adj_r_squared, + metric = Metric('gof', 'ar2', 'Description.', compute_adj_r_squared, {'n_params' : lambda data, results: \ results.params.periodic.params.size + results.params.aperiodic.params.size}) @@ -38,8 +38,8 @@ def test_metrics_null(): def test_metrics_obj(tfm): - er_metric = Metric('error', 'mae', compute_mean_abs_error) - gof_metric = Metric('gof', 'rsquared', compute_r_squared) + er_metric = Metric('error', 'mae', 'Description.', compute_mean_abs_error) + gof_metric = Metric('gof', 'rsquared', 'Description.', compute_r_squared) metrics = Metrics([er_metric, gof_metric]) assert isinstance(metrics, Metrics) @@ -61,8 +61,10 @@ def test_metrics_obj(tfm): def test_metrics_dict(tfm): - er_met_def = {'category' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} - gof_met_def = {'category' : 'gof', 'measure' : 'rsquared', 'func' : compute_r_squared} + er_met_def = {'category' : 'error', 'measure' : 'mae', + 'description' : 'Description.', 'func' : compute_mean_abs_error} + gof_met_def = {'category' : 'gof', 'measure' : 'rsquared', + 'description' : 'Description.', 'func' : compute_r_squared} metrics = Metrics([er_met_def, gof_met_def]) assert isinstance(metrics, Metrics) @@ -79,8 +81,10 @@ def test_metrics_dict(tfm): def test_metrics_kwargs(tfm): - er_met_def = {'category' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} + er_met_def = {'category' : 'error', 'measure' : 'mae', + 'description' : 'Description.', 'func' : compute_mean_abs_error} ar2_met_def = {'category' : 'gof', 'measure' : 'arsquared', + 'description' : 'Description.', 'func' : compute_adj_r_squared, 'kwargs' : {'n_params' : lambda data, results: \ results.params.periodic.params.size + results.params.aperiodic.params.size}} From b8a6d26be9a2cad9d3d43a8bb0a7b5889d766810 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 29 Oct 2025 19:39:12 +0000 Subject: [PATCH 4/7] add descriptions to metric definitions --- specparam/measures/metrics.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/specparam/measures/metrics.py b/specparam/measures/metrics.py index 062c0acd..8fb23b28 100644 --- a/specparam/measures/metrics.py +++ b/specparam/measures/metrics.py @@ -11,24 +11,28 @@ error_mae = Metric( category='error', measure='mae', + description='Mean absolute error of the model fit to the data.', func=compute_mean_abs_error, ) error_mse = Metric( category='error', measure='mse', + description='Mean squared error of the model fit to the data.', func=compute_mean_squared_error ) error_rmse = Metric( category='error', measure='rmse', + description='Root mean squared error of the model fit to the data.', func=compute_root_mean_squared_error, ) error_medae = Metric( category='error', measure='medae', + description='Median absolute error of the model fit to the data.', func=compute_median_abs_error, ) @@ -38,12 +42,14 @@ gof_rsquared = Metric( category='gof', measure='rsquared', + description='R-squared between the model fit and the data.', func=compute_r_squared, ) gof_adjrsquared = Metric( category='gof', measure='adjrsquared', + description='Adjusted R-squared between the model fit and the data.', func=compute_adj_r_squared, kwargs={'n_params' : lambda data, results: \ results.params.periodic.params.size + results.params.aperiodic.params.size}, From c76f4ba6feca74239afaf13a2746f9022ce50f83 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 29 Oct 2025 19:43:40 +0000 Subject: [PATCH 5/7] move metric files around --- specparam/{measures/metrics.py => metrics/definitions.py} | 0 specparam/{measures => metrics}/error.py | 0 specparam/{measures => metrics}/gof.py | 0 specparam/{objs => metrics}/metrics.py | 0 .../{measures/test_metrics.py => metrics/test_definitions.py} | 0 specparam/tests/{measures => metrics}/test_error.py | 0 specparam/tests/{measures => metrics}/test_gof.py | 0 specparam/tests/{objs => metrics}/test_metrics.py | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename specparam/{measures/metrics.py => metrics/definitions.py} (100%) rename specparam/{measures => metrics}/error.py (100%) rename specparam/{measures => metrics}/gof.py (100%) rename specparam/{objs => metrics}/metrics.py (100%) rename specparam/tests/{measures/test_metrics.py => metrics/test_definitions.py} (100%) rename specparam/tests/{measures => metrics}/test_error.py (100%) rename specparam/tests/{measures => metrics}/test_gof.py (100%) rename specparam/tests/{objs => metrics}/test_metrics.py (100%) diff --git a/specparam/measures/metrics.py b/specparam/metrics/definitions.py similarity index 100% rename from specparam/measures/metrics.py rename to specparam/metrics/definitions.py diff --git a/specparam/measures/error.py b/specparam/metrics/error.py similarity index 100% rename from specparam/measures/error.py rename to specparam/metrics/error.py diff --git a/specparam/measures/gof.py b/specparam/metrics/gof.py similarity index 100% rename from specparam/measures/gof.py rename to specparam/metrics/gof.py diff --git a/specparam/objs/metrics.py b/specparam/metrics/metrics.py similarity index 100% rename from specparam/objs/metrics.py rename to specparam/metrics/metrics.py diff --git a/specparam/tests/measures/test_metrics.py b/specparam/tests/metrics/test_definitions.py similarity index 100% rename from specparam/tests/measures/test_metrics.py rename to specparam/tests/metrics/test_definitions.py diff --git a/specparam/tests/measures/test_error.py b/specparam/tests/metrics/test_error.py similarity index 100% rename from specparam/tests/measures/test_error.py rename to specparam/tests/metrics/test_error.py diff --git a/specparam/tests/measures/test_gof.py b/specparam/tests/metrics/test_gof.py similarity index 100% rename from specparam/tests/measures/test_gof.py rename to specparam/tests/metrics/test_gof.py diff --git a/specparam/tests/objs/test_metrics.py b/specparam/tests/metrics/test_metrics.py similarity index 100% rename from specparam/tests/objs/test_metrics.py rename to specparam/tests/metrics/test_metrics.py From dfc5c9e7d24bb60161ca9c6b65c444f4e3c53f84 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 29 Oct 2025 19:48:47 +0000 Subject: [PATCH 6/7] fix paths for moves --- specparam/metrics/definitions.py | 8 ++++---- specparam/objs/results.py | 4 ++-- specparam/tests/metrics/test_definitions.py | 6 +++--- specparam/tests/metrics/test_error.py | 4 ++-- specparam/tests/metrics/test_gof.py | 4 ++-- specparam/tests/metrics/test_metrics.py | 8 ++++---- specparam/tests/models/test_group.py | 2 +- specparam/tests/models/test_model.py | 2 +- specparam/tests/models/test_utils.py | 2 +- 9 files changed, 20 insertions(+), 20 deletions(-) diff --git a/specparam/metrics/definitions.py b/specparam/metrics/definitions.py index 8fb23b28..1125864b 100644 --- a/specparam/metrics/definitions.py +++ b/specparam/metrics/definitions.py @@ -1,9 +1,9 @@ """Collect together library of available built in metrics.""" -from specparam.objs.metrics import Metric -from specparam.measures.error import (compute_mean_abs_error, compute_mean_squared_error, - compute_root_mean_squared_error, compute_median_abs_error) -from specparam.measures.gof import compute_r_squared, compute_adj_r_squared +from specparam.metrics.metrics import Metric +from specparam.metrics.error import (compute_mean_abs_error, compute_mean_squared_error, + compute_root_mean_squared_error, compute_median_abs_error) +from specparam.metrics.gof import compute_r_squared, compute_adj_r_squared ################################################################################################### ## ERROR METRICS diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 9b81d0cf..45247310 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -7,10 +7,10 @@ from specparam.bands.bands import check_bands from specparam.modes.modes import Modes -from specparam.objs.metrics import Metrics from specparam.objs.params import ModelParameters from specparam.objs.components import ModelComponents -from specparam.measures.metrics import METRICS +from specparam.metrics.metrics import Metrics +from specparam.metrics.definitions import METRICS from specparam.utils.checks import check_inds, check_array_dim from specparam.modutils.errors import NoModelError from specparam.modutils.docs import (copy_doc_func_to_method, docs_get_section, diff --git a/specparam/tests/metrics/test_definitions.py b/specparam/tests/metrics/test_definitions.py index 89289ccb..d8d6c3b7 100644 --- a/specparam/tests/metrics/test_definitions.py +++ b/specparam/tests/metrics/test_definitions.py @@ -1,8 +1,8 @@ -"""Test functions for specparam.measures.metrics.""" +"""Test functions for specparam.metrics.definitions.""" -from specparam.objs.metrics import Metric +from specparam.metrics.metrics import Metric -from specparam.measures.metrics import * +from specparam.metrics.definitions import * ################################################################################################### ################################################################################################### diff --git a/specparam/tests/metrics/test_error.py b/specparam/tests/metrics/test_error.py index 34bed15a..2c305c1e 100644 --- a/specparam/tests/metrics/test_error.py +++ b/specparam/tests/metrics/test_error.py @@ -1,6 +1,6 @@ -"""Test functions for specparam.measures.error.""" +"""Test functions for specparam.metrics.error.""" -from specparam.measures.error import * +from specparam.metrics.error import * ################################################################################################### ################################################################################################### diff --git a/specparam/tests/metrics/test_gof.py b/specparam/tests/metrics/test_gof.py index 6b54a4ee..bd9f5107 100644 --- a/specparam/tests/metrics/test_gof.py +++ b/specparam/tests/metrics/test_gof.py @@ -1,6 +1,6 @@ -"""Test functions for specparam.measures.gof.""" +"""Test functions for specparam.metrics.gof.""" -from specparam.measures.gof import * +from specparam.metrics.gof import * ################################################################################################### ################################################################################################### diff --git a/specparam/tests/metrics/test_metrics.py b/specparam/tests/metrics/test_metrics.py index b925d8d9..01f35587 100644 --- a/specparam/tests/metrics/test_metrics.py +++ b/specparam/tests/metrics/test_metrics.py @@ -1,11 +1,11 @@ -"""Tests for specparam.objs.metrics.""" +"""Tests for specparam.metrics.metrics""" from pytest import raises -from specparam.measures.error import compute_mean_abs_error -from specparam.measures.gof import compute_r_squared, compute_adj_r_squared +from specparam.metrics.error import compute_mean_abs_error +from specparam.metrics.gof import compute_r_squared, compute_adj_r_squared -from specparam.objs.metrics import * +from specparam.metrics.metrics import * ################################################################################################### ################################################################################################### diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index b706a3d5..2b72e875 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -11,7 +11,7 @@ import numpy as np from numpy.testing import assert_equal -from specparam.measures.metrics import METRICS +from specparam.metrics.definitions import METRICS from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import from specparam.sim import sim_group_power_spectra diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 7b8ee88d..97cf6d68 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -11,7 +11,7 @@ from specparam.utils.select import groupby from specparam.modutils.errors import FitError -from specparam.measures.metrics import METRICS +from specparam.metrics.definitions import METRICS from specparam.sim import gen_freqs, sim_power_spectrum from specparam.modes.definitions import AP_MODES, PE_MODES from specparam.models.utils import compare_model_objs diff --git a/specparam/tests/models/test_utils.py b/specparam/tests/models/test_utils.py index 8e2d9903..68a6a0d0 100644 --- a/specparam/tests/models/test_utils.py +++ b/specparam/tests/models/test_utils.py @@ -6,7 +6,7 @@ from specparam import SpectralGroupModel from specparam.sim import sim_group_power_spectra -from specparam.measures.metrics import METRICS +from specparam.metrics.definitions import METRICS from specparam.modutils.errors import NoModelError, IncompatibleSettingsError from specparam.tests.tdata import default_group_params From 8ab18333074682c287fe57b18a600245ae2c241a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 29 Oct 2025 19:52:50 +0000 Subject: [PATCH 7/7] split up metric & metrics --- specparam/metrics/metric.py | 85 +++++++++++++++++++++++++ specparam/metrics/metrics.py | 82 +----------------------- specparam/tests/metrics/test_metric.py | 30 +++++++++ specparam/tests/metrics/test_metrics.py | 22 +------ 4 files changed, 118 insertions(+), 101 deletions(-) create mode 100644 specparam/metrics/metric.py create mode 100644 specparam/tests/metrics/test_metric.py diff --git a/specparam/metrics/metric.py b/specparam/metrics/metric.py new file mode 100644 index 00000000..9735ca58 --- /dev/null +++ b/specparam/metrics/metric.py @@ -0,0 +1,85 @@ +"""Metric object.""" + +import numpy as np + +################################################################################################### +################################################################################################### + +class Metric(): + """Define a metric to apply to a power spectrum model. + + Parameters + ---------- + category : str + The category of measure, e.g. 'error' or 'gof'. + measure : str + The specific measure, e.g. 'r_squared'. + description : str + Description of the metric. + func : callable + The function that computes the metric. + kwargs : dictionary + Additional keyword argument to compute the metric. + Each key should be the name of the additional argument. + Each value should be a lambda function that takes 'data' & 'results' + and returns the desired parameter / computed value. + """ + + def __init__(self, category, measure, description, func, kwargs=None): + """Initialize metric.""" + + self.category = category + self.measure = measure + self.description = description + self.func = func + self.result = np.nan + self.kwargs = {} if not kwargs else kwargs + + + def __repr__(self): + """Set string representation of object.""" + + return 'Metric: ' + self.label + + + @property + def label(self): + """Define label property.""" + + return self.category + '_' + self.measure + + + @property + def flabel(self): + """Define formatted label property.""" + + if self.category == 'error': + flabel = '{} ({})'.format(self.category.capitalize(), self.measure.upper()) + if self.category == 'gof': + flabel = '{} ({})'.format(self.category.upper(), self.measure) + + return flabel + + + def compute_metric(self, data, results): + """Compute metric. + + Parameters + ---------- + data : Data + Model data. + results : Results + Model results. + """ + + kwargs = {} + for key, lfunc in self.kwargs.items(): + kwargs[key] = lfunc(data, results) + + self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs) + + + def reset(self): + """Reset metric result.""" + + self.result = np.nan \ No newline at end of file diff --git a/specparam/metrics/metrics.py b/specparam/metrics/metrics.py index 6102d552..38f96bc6 100644 --- a/specparam/metrics/metrics.py +++ b/specparam/metrics/metrics.py @@ -4,89 +4,11 @@ import numpy as np +from specparam.metrics.metric import Metric + ################################################################################################### ################################################################################################### -class Metric(): - """Define a metric to apply to a power spectrum model. - - Parameters - ---------- - category : str - The category of measure, e.g. 'error' or 'gof'. - measure : str - The specific measure, e.g. 'r_squared'. - description : str - Description of the metric. - func : callable - The function that computes the metric. - kwargs : dictionary - Additional keyword argument to compute the metric. - Each key should be the name of the additional argument. - Each value should be a lambda function that takes 'data' & 'results' - and returns the desired parameter / computed value. - """ - - def __init__(self, category, measure, description, func, kwargs=None): - """Initialize metric.""" - - self.category = category - self.measure = measure - self.description = description - self.func = func - self.result = np.nan - self.kwargs = {} if not kwargs else kwargs - - - def __repr__(self): - """Set string representation of object.""" - - return 'Metric: ' + self.label - - - @property - def label(self): - """Define label property.""" - - return self.category + '_' + self.measure - - - @property - def flabel(self): - """Define formatted label property.""" - - if self.category == 'error': - flabel = '{} ({})'.format(self.category.capitalize(), self.measure.upper()) - if self.category == 'gof': - flabel = '{} ({})'.format(self.category.upper(), self.measure) - - return flabel - - - def compute_metric(self, data, results): - """Compute metric. - - Parameters - ---------- - data : Data - Model data. - results : Results - Model results. - """ - - kwargs = {} - for key, lfunc in self.kwargs.items(): - kwargs[key] = lfunc(data, results) - - self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs) - - - def reset(self): - """Reset metric result.""" - - self.result = np.nan - - class Metrics(): """Define a collection of metrics. diff --git a/specparam/tests/metrics/test_metric.py b/specparam/tests/metrics/test_metric.py new file mode 100644 index 00000000..03d2865e --- /dev/null +++ b/specparam/tests/metrics/test_metric.py @@ -0,0 +1,30 @@ +"""Tests for specparam.metrics.metric""" + +from specparam.metrics.error import compute_mean_abs_error +from specparam.metrics.gof import compute_adj_r_squared + +from specparam.metrics.metric import * + +################################################################################################### +################################################################################################### + +def test_metric(tfm): + + metric = Metric('error', 'mae', 'Description.', compute_mean_abs_error) + assert isinstance(metric, Metric) + assert isinstance(metric.label, str) + + metric.compute_metric(tfm.data, tfm.results) + assert isinstance(metric.result, float) + +def test_metric_kwargs(tfm): + + metric = Metric('gof', 'ar2', 'Description.', compute_adj_r_squared, + {'n_params' : lambda data, results: \ + results.params.periodic.params.size + results.params.aperiodic.params.size}) + + assert isinstance(metric, Metric) + assert isinstance(metric.label, str) + + metric.compute_metric(tfm.data, tfm.results) + assert isinstance(metric.result, float) diff --git a/specparam/tests/metrics/test_metrics.py b/specparam/tests/metrics/test_metrics.py index 01f35587..a807d01f 100644 --- a/specparam/tests/metrics/test_metrics.py +++ b/specparam/tests/metrics/test_metrics.py @@ -2,6 +2,7 @@ from pytest import raises +from specparam.metrics.metric import Metric from specparam.metrics.error import compute_mean_abs_error from specparam.metrics.gof import compute_r_squared, compute_adj_r_squared @@ -10,27 +11,6 @@ ################################################################################################### ################################################################################################### -def test_metric(tfm): - - metric = Metric('error', 'mae', 'Description.', compute_mean_abs_error) - assert isinstance(metric, Metric) - assert isinstance(metric.label, str) - - metric.compute_metric(tfm.data, tfm.results) - assert isinstance(metric.result, float) - -def test_metric_kwargs(tfm): - - metric = Metric('gof', 'ar2', 'Description.', compute_adj_r_squared, - {'n_params' : lambda data, results: \ - results.params.periodic.params.size + results.params.aperiodic.params.size}) - - assert isinstance(metric, Metric) - assert isinstance(metric.label, str) - - metric.compute_metric(tfm.data, tfm.results) - assert isinstance(metric.result, float) - def test_metrics_null(): metrics = Metrics()