diff --git a/specparam/measures/gof.py b/specparam/measures/gof.py index e0a3bc561..0d4240756 100644 --- a/specparam/measures/gof.py +++ b/specparam/measures/gof.py @@ -76,7 +76,7 @@ def compute_mean_abs_error(power_spectrum, modeled_spectrum): Returns ------- error : float - Computed MAE. + Computed mean absolute error. """ error = np.abs(power_spectrum - modeled_spectrum).mean() @@ -97,7 +97,7 @@ def compute_mean_squared_error(power_spectrum, modeled_spectrum): Returns ------- error : float - Computed MSE. + Computed mean squared error. """ error = ((power_spectrum - modeled_spectrum) ** 2).mean() @@ -118,7 +118,7 @@ def compute_root_mean_squared_error(power_spectrum, modeled_spectrum): Returns ------- error : float - Computed rMSE. + Computed root mean squared error. """ error = np.sqrt(((power_spectrum - modeled_spectrum) ** 2).mean()) @@ -126,11 +126,33 @@ def compute_root_mean_squared_error(power_spectrum, modeled_spectrum): return error +def compute_median_abs_error(power_spectrum, modeled_spectrum): + """Calculate the median absolute error. + + Parameters + ---------- + power_spectrum : 1d array + Real data power spectrum. + modeled_spectrum : 1d array + Modelled power spectrum. + + Returns + ------- + error : float + Computed median absolute error. + """ + + error = np.median(np.abs(modeled_spectrum - power_spectrum), axis=0) + + return error + + # Collect available error functions together ERROR_FUNCS = { 'mae' : compute_mean_abs_error, 'mse' : compute_mean_squared_error, 'rmse' : compute_root_mean_squared_error, + 'medae' : compute_median_abs_error, } @@ -143,7 +165,7 @@ def compute_error(power_spectrum, modeled_spectrum, error_metric='mae'): Real data power spectrum. modeled_spectrum : 1d array Modelled power spectrum. - error_metric : {'mae', 'mse', 'rsme'} or callable + error_metric : {'mae', 'mse', 'rsme', 'medae'} or callable Which approach to take to compute the error. Returns diff --git a/specparam/tests/measures/test_gof.py b/specparam/tests/measures/test_gof.py index 8a34803d3..77e6bee88 100644 --- a/specparam/tests/measures/test_gof.py +++ b/specparam/tests/measures/test_gof.py @@ -36,8 +36,13 @@ def test_compute_root_mean_squared_error(tfm): error = compute_root_mean_squared_error(tfm.power_spectrum, tfm.modeled_spectrum_) assert isinstance(error, float) +def test_compute_median_abs_error(tfm): + + error = compute_median_abs_error(tfm.power_spectrum, tfm.modeled_spectrum_) + assert isinstance(error, float) + def test_compute_error(tfm): - for metric in ['mae', 'mse', 'rmse']: + for metric in ['mae', 'mse', 'rmse', 'medae']: error = compute_error(tfm.power_spectrum, tfm.modeled_spectrum_) assert isinstance(error, float)