Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions specparam/measures/metrics.py

This file was deleted.

73 changes: 73 additions & 0 deletions specparam/metrics/definitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Collect together library of available built in metrics."""

from specparam.metrics.metrics import Metric
from specparam.metrics.error import (compute_mean_abs_error, compute_mean_squared_error,
compute_root_mean_squared_error, compute_median_abs_error)
from specparam.metrics.gof import compute_r_squared, compute_adj_r_squared

###################################################################################################
## ERROR METRICS

error_mae = Metric(
category='error',
measure='mae',
description='Mean absolute error of the model fit to the data.',
func=compute_mean_abs_error,
)

error_mse = Metric(
category='error',
measure='mse',
description='Mean squared error of the model fit to the data.',
func=compute_mean_squared_error
)

error_rmse = Metric(
category='error',
measure='rmse',
description='Root mean squared error of the model fit to the data.',
func=compute_root_mean_squared_error,
)

error_medae = Metric(
category='error',
measure='medae',
description='Median absolute error of the model fit to the data.',
func=compute_median_abs_error,
)

###################################################################################################
## GOF

gof_rsquared = Metric(
category='gof',
measure='rsquared',
description='R-squared between the model fit and the data.',
func=compute_r_squared,
)

gof_adjrsquared = Metric(
category='gof',
measure='adjrsquared',
description='Adjusted R-squared between the model fit and the data.',
func=compute_adj_r_squared,
kwargs={'n_params' : lambda data, results: \
results.params.periodic.params.size + results.params.aperiodic.params.size},
)

###################################################################################################
## COLLECT ALL METRICS TOGETHER

METRICS = {

# Available error metrics
'error_mae' : error_mae,
'error_mse' : error_mse,
'error_rmse' : error_rmse,
'error_medae' : error_medae,

# Available GOF / r-squared metrics
'gof_rsquared' : gof_rsquared,
'gof_adjrsquared' : gof_adjrsquared,

}
File renamed without changes.
File renamed without changes.
85 changes: 85 additions & 0 deletions specparam/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Metric object."""

import numpy as np

###################################################################################################
###################################################################################################

class Metric():
"""Define a metric to apply to a power spectrum model.

Parameters
----------
category : str
The category of measure, e.g. 'error' or 'gof'.
measure : str
The specific measure, e.g. 'r_squared'.
description : str
Description of the metric.
func : callable
The function that computes the metric.
kwargs : dictionary
Additional keyword argument to compute the metric.
Each key should be the name of the additional argument.
Each value should be a lambda function that takes 'data' & 'results'
and returns the desired parameter / computed value.
"""

def __init__(self, category, measure, description, func, kwargs=None):
"""Initialize metric."""

self.category = category
self.measure = measure
self.description = description
self.func = func
self.result = np.nan
self.kwargs = {} if not kwargs else kwargs


def __repr__(self):
"""Set string representation of object."""

return 'Metric: ' + self.label


@property
def label(self):
"""Define label property."""

return self.category + '_' + self.measure


@property
def flabel(self):
"""Define formatted label property."""

if self.category == 'error':
flabel = '{} ({})'.format(self.category.capitalize(), self.measure.upper())
if self.category == 'gof':
flabel = '{} ({})'.format(self.category.upper(), self.measure)

return flabel


def compute_metric(self, data, results):
"""Compute metric.

Parameters
----------
data : Data
Model data.
results : Results
Model results.
"""

kwargs = {}
for key, lfunc in self.kwargs.items():
kwargs[key] = lfunc(data, results)

self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs)


def reset(self):
"""Reset metric result."""

self.result = np.nan
79 changes: 2 additions & 77 deletions specparam/objs/metrics.py → specparam/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,86 +4,11 @@

import numpy as np

from specparam.metrics.metric import Metric

###################################################################################################
###################################################################################################

class Metric():
"""Define a metric to apply to a power spectrum model.

Parameters
----------
category : str
The category of measure, e.g. 'error' or 'gof'.
measure : str
The specific measure, e.g. 'r_squared'.
func : callable
The function that computes the metric.
kwargs : dictionary
Additional keyword argument to compute the metric.
Each key should be the name of the additional argument.
Each value should be a lambda function that takes 'data' & 'results'
and returns the desired parameter / computed value.
"""

def __init__(self, category, measure, func, kwargs=None):
"""Initialize metric."""

self.category = category
self.measure = measure
self.func = func
self.result = np.nan
self.kwargs = {} if not kwargs else kwargs


def __repr__(self):
"""Set string representation of object."""

return 'Metric: ' + self.label


@property
def label(self):
"""Define label property."""

return self.category + '_' + self.measure


@property
def flabel(self):
"""Define formatted label property."""

if self.category == 'error':
flabel = '{} ({})'.format(self.category.capitalize(), self.measure.upper())
if self.category == 'gof':
flabel = '{} ({})'.format(self.category.upper(), self.measure)

return flabel


def compute_metric(self, data, results):
"""Compute metric.

Parameters
----------
data : Data
Model data.
results : Results
Model results.
"""

kwargs = {}
for key, lfunc in self.kwargs.items():
kwargs[key] = lfunc(data, results)

self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs)


def reset(self):
"""Reset metric result."""

self.result = np.nan


class Metrics():
"""Define a collection of metrics.

Expand Down
7 changes: 7 additions & 0 deletions specparam/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ class SpectralModel(BaseModel):
Which approach to take for fitting the aperiodic component.
periodic_mode : {'gaussian', 'skewed_gaussian', 'cauchy'}
Which approach to take for fitting the periodic component.
metrics : Metrics or list of Metric or list or str
Metrics definition(s) to use to evaluate the model.
bands : Bands or dict or int or None, optional
Bands object with band definitions, or definition that can be turned into a Bands object.
debug : bool, optional, default: False
Whether to run in debug mode.
If in debug, any errors encountered during fitting will raise an error.
verbose : bool, optional, default: True
Verbosity mode. If True, prints out warnings and general status updates.
**model_kwargs
Expand Down
8 changes: 4 additions & 4 deletions specparam/objs/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from specparam.bands.bands import check_bands
from specparam.modes.modes import Modes
from specparam.objs.metrics import Metrics
from specparam.objs.params import ModelParameters
from specparam.objs.components import ModelComponents
from specparam.measures.metrics import METRICS
from specparam.metrics.metrics import Metrics
from specparam.metrics.definitions import METRICS
from specparam.utils.checks import check_inds, check_array_dim
from specparam.modutils.errors import NoModelError
from specparam.modutils.docs import (copy_doc_func_to_method, docs_get_section,
Expand All @@ -37,8 +37,8 @@ class Results():
Modes object with fit mode definitions.
metrics : Metrics
Metrics object with metric definitions.
bands : Bands
Bands object with band definitions.
bands : Bands or dict or int or None
Bands object with band definitions, or definition that can be turned into a Bands object.

Attributes
----------
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Test functions for specparam.measures.metrics."""
"""Test functions for specparam.metrics.definitions."""

from specparam.objs.metrics import Metric
from specparam.metrics.metrics import Metric

from specparam.measures.metrics import *
from specparam.metrics.definitions import *

###################################################################################################
###################################################################################################
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test functions for specparam.measures.error."""
"""Test functions for specparam.metrics.error."""

from specparam.measures.error import *
from specparam.metrics.error import *

###################################################################################################
###################################################################################################
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test functions for specparam.measures.gof."""
"""Test functions for specparam.metrics.gof."""

from specparam.measures.gof import *
from specparam.metrics.gof import *

###################################################################################################
###################################################################################################
Expand Down
30 changes: 30 additions & 0 deletions specparam/tests/metrics/test_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Tests for specparam.metrics.metric"""

from specparam.metrics.error import compute_mean_abs_error
from specparam.metrics.gof import compute_adj_r_squared

from specparam.metrics.metric import *

###################################################################################################
###################################################################################################

def test_metric(tfm):

metric = Metric('error', 'mae', 'Description.', compute_mean_abs_error)
assert isinstance(metric, Metric)
assert isinstance(metric.label, str)

metric.compute_metric(tfm.data, tfm.results)
assert isinstance(metric.result, float)

def test_metric_kwargs(tfm):

metric = Metric('gof', 'ar2', 'Description.', compute_adj_r_squared,
{'n_params' : lambda data, results: \
results.params.periodic.params.size + results.params.aperiodic.params.size})

assert isinstance(metric, Metric)
assert isinstance(metric.label, str)

metric.compute_metric(tfm.data, tfm.results)
assert isinstance(metric.result, float)
Loading