Added mean_absolute_percentage_error in metrics fixes scikit-learn#10708 (scikit-learn#15007)

Co-authored-by: mohamed-ali <m.ali.jamaoui@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Joel Nothman <joel.nothman@gmail.com>
Co-authored-by: Roman Yurchak <rth.yurchak@pm.me>
6 people authored and jayzed82 committed Oct 22, 2020
1 parent cb5045a commit 4db62e4
Showing 10 changed files with 207 additions and 46 deletions.
3 changes: 2 additions & 1 deletion doc/modules/classes.rst
@@ -900,7 +900,7 @@ Miscellaneous
manifold.smacof
manifold.spectral_embedding
manifold.trustworthiness


.. _metrics_ref:

@@ -981,6 +981,7 @@ details.
metrics.mean_squared_error
metrics.mean_squared_log_error
metrics.median_absolute_error
metrics.mean_absolute_percentage_error
metrics.r2_score
metrics.mean_poisson_deviance
metrics.mean_gamma_deviance
117 changes: 77 additions & 40 deletions doc/modules/model_evaluation.rst
@@ -54,51 +54,52 @@ the model and the data, like :func:`metrics.mean_squared_error`, are
available as neg_mean_squared_error which return the negated value
of the metric.
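
For instance, a minimal sketch of this convention (the estimator and synthetic
data below are purely illustrative)::

    >>> from sklearn.datasets import make_regression
    >>> from sklearn.linear_model import LinearRegression
    >>> from sklearn.model_selection import cross_val_score
    >>> X, y = make_regression(n_samples=100, random_state=0)
    >>> scores = cross_val_score(LinearRegression(), X, y,
    ...                          scoring="neg_mean_squared_error", cv=5)
    >>> (scores <= 0).all()  # the scorer negates MSE, so higher is better
    True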

============================== ============================================= ==================================
Scoring Function Comment
============================== ============================================= ==================================
==================================== ============================================== ==================================
Scoring Function Comment
==================================== ============================================== ==================================
**Classification**
'accuracy' :func:`metrics.accuracy_score`
'balanced_accuracy' :func:`metrics.balanced_accuracy_score`
'average_precision' :func:`metrics.average_precision_score`
'neg_brier_score' :func:`metrics.brier_score_loss`
'f1' :func:`metrics.f1_score` for binary targets
'f1_micro' :func:`metrics.f1_score` micro-averaged
'f1_macro' :func:`metrics.f1_score` macro-averaged
'f1_weighted' :func:`metrics.f1_score` weighted average
'f1_samples' :func:`metrics.f1_score` by multilabel sample
'neg_log_loss' :func:`metrics.log_loss` requires ``predict_proba`` support
'precision' etc. :func:`metrics.precision_score` suffixes apply as with 'f1'
'recall' etc. :func:`metrics.recall_score` suffixes apply as with 'f1'
'jaccard' etc. :func:`metrics.jaccard_score` suffixes apply as with 'f1'
'roc_auc' :func:`metrics.roc_auc_score`
'roc_auc_ovr' :func:`metrics.roc_auc_score`
'roc_auc_ovo' :func:`metrics.roc_auc_score`
'roc_auc_ovr_weighted' :func:`metrics.roc_auc_score`
'roc_auc_ovo_weighted' :func:`metrics.roc_auc_score`
'accuracy' :func:`metrics.accuracy_score`
'balanced_accuracy' :func:`metrics.balanced_accuracy_score`
'average_precision' :func:`metrics.average_precision_score`
'neg_brier_score' :func:`metrics.brier_score_loss`
'f1' :func:`metrics.f1_score` for binary targets
'f1_micro' :func:`metrics.f1_score` micro-averaged
'f1_macro' :func:`metrics.f1_score` macro-averaged
'f1_weighted' :func:`metrics.f1_score` weighted average
'f1_samples' :func:`metrics.f1_score` by multilabel sample
'neg_log_loss' :func:`metrics.log_loss` requires ``predict_proba`` support
'precision' etc. :func:`metrics.precision_score` suffixes apply as with 'f1'
'recall' etc. :func:`metrics.recall_score` suffixes apply as with 'f1'
'jaccard' etc. :func:`metrics.jaccard_score` suffixes apply as with 'f1'
'roc_auc' :func:`metrics.roc_auc_score`
'roc_auc_ovr' :func:`metrics.roc_auc_score`
'roc_auc_ovo' :func:`metrics.roc_auc_score`
'roc_auc_ovr_weighted' :func:`metrics.roc_auc_score`
'roc_auc_ovo_weighted' :func:`metrics.roc_auc_score`

**Clustering**
'adjusted_mutual_info_score' :func:`metrics.adjusted_mutual_info_score`
'adjusted_rand_score' :func:`metrics.adjusted_rand_score`
'completeness_score' :func:`metrics.completeness_score`
'fowlkes_mallows_score' :func:`metrics.fowlkes_mallows_score`
'homogeneity_score' :func:`metrics.homogeneity_score`
'mutual_info_score' :func:`metrics.mutual_info_score`
'normalized_mutual_info_score' :func:`metrics.normalized_mutual_info_score`
'v_measure_score' :func:`metrics.v_measure_score`
'adjusted_mutual_info_score' :func:`metrics.adjusted_mutual_info_score`
'adjusted_rand_score' :func:`metrics.adjusted_rand_score`
'completeness_score' :func:`metrics.completeness_score`
'fowlkes_mallows_score' :func:`metrics.fowlkes_mallows_score`
'homogeneity_score' :func:`metrics.homogeneity_score`
'mutual_info_score' :func:`metrics.mutual_info_score`
'normalized_mutual_info_score' :func:`metrics.normalized_mutual_info_score`
'v_measure_score' :func:`metrics.v_measure_score`

**Regression**
'explained_variance' :func:`metrics.explained_variance_score`
'max_error' :func:`metrics.max_error`
'neg_mean_absolute_error' :func:`metrics.mean_absolute_error`
'neg_mean_squared_error' :func:`metrics.mean_squared_error`
'neg_root_mean_squared_error' :func:`metrics.mean_squared_error`
'neg_mean_squared_log_error' :func:`metrics.mean_squared_log_error`
'neg_median_absolute_error' :func:`metrics.median_absolute_error`
'r2' :func:`metrics.r2_score`
'neg_mean_poisson_deviance' :func:`metrics.mean_poisson_deviance`
'neg_mean_gamma_deviance' :func:`metrics.mean_gamma_deviance`
============================== ============================================= ==================================
'explained_variance' :func:`metrics.explained_variance_score`
'max_error' :func:`metrics.max_error`
'neg_mean_absolute_error' :func:`metrics.mean_absolute_error`
'neg_mean_squared_error' :func:`metrics.mean_squared_error`
'neg_root_mean_squared_error' :func:`metrics.mean_squared_error`
'neg_mean_squared_log_error' :func:`metrics.mean_squared_log_error`
'neg_median_absolute_error' :func:`metrics.median_absolute_error`
'r2' :func:`metrics.r2_score`
'neg_mean_poisson_deviance' :func:`metrics.mean_poisson_deviance`
'neg_mean_gamma_deviance' :func:`metrics.mean_gamma_deviance`
'neg_mean_absolute_percentage_error' :func:`metrics.mean_absolute_percentage_error`
==================================== ============================================== ==================================


Usage examples:
@@ -1963,6 +1964,42 @@ function::
>>> mean_squared_log_error(y_true, y_pred)
0.044...

.. _mean_absolute_percentage_error:

Mean absolute percentage error
------------------------------
The :func:`mean_absolute_percentage_error` (MAPE), also known as mean absolute
percentage deviation (MAPD), is an evaluation metric for regression problems.
This metric is sensitive to relative errors; it is, for example, not changed by
a global scaling of the target variable, as the sketch below illustrates.
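
A minimal sketch of this scale invariance (the data below are illustrative and
not part of the official example)::

    >>> import numpy as np
    >>> from sklearn.metrics import mean_absolute_error
    >>> from sklearn.metrics import mean_absolute_percentage_error
    >>> y_true = np.array([2.0, 8.0, 32.0])
    >>> y_pred = np.array([1.0, 12.0, 48.0])
    >>> mean_absolute_percentage_error(y_true, y_pred)
    0.5
    >>> # rescaling the targets leaves MAPE unchanged ...
    >>> mean_absolute_percentage_error(100 * y_true, 100 * y_pred)
    0.5
    >>> # ... whereas the mean absolute error grows with the scale
    >>> mean_absolute_error(y_true, y_pred)
    7.0
    >>> mean_absolute_error(100 * y_true, 100 * y_pred)
    700.0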

If :math:`\hat{y}_i` is the predicted value of the :math:`i`-th sample
and :math:`y_i` is the corresponding true value, then the mean absolute percentage
error (MAPE) estimated over :math:`n_{\text{samples}}` is defined as

.. math::

    \text{MAPE}(y, \hat{y}) = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}}-1} \frac{\left| y_i - \hat{y}_i \right|}{\max(\epsilon, \left| y_i \right|)}

where :math:`\epsilon` is an arbitrarily small yet strictly positive number to
avoid undefined results when :math:`y` is zero.
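
As a hedged illustration of the role of :math:`\epsilon` (not part of the
official example): when a true value is exactly zero, the corresponding term is
divided by the machine epsilon rather than by zero, so the result is a very
large but still finite number::

    >>> import numpy as np
    >>> from sklearn.metrics import mean_absolute_percentage_error
    >>> score = mean_absolute_percentage_error([0.0, 1.0], [1.0, 1.0])
    >>> np.isfinite(score)
    True
    >>> score > 1e15  # roughly 1 / (2 * eps) for float64
    True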

The :func:`mean_absolute_percentage_error` function supports multioutput.

Here is a small example of usage of the :func:`mean_absolute_percentage_error`
function::

>>> from sklearn.metrics import mean_absolute_percentage_error
>>> y_true = [1, 10, 1e6]
>>> y_pred = [0.9, 15, 1.2e6]
>>> mean_absolute_percentage_error(y_true, y_pred)
0.2666...
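
A hedged sketch of the multioutput behaviour, reusing the data from the
function's docstring example (printed values are approximate)::

    >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
    >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
    >>> mean_absolute_percentage_error(y_true, y_pred, multioutput='raw_values')
    array([0.3809..., 0.7222...])
    >>> mean_absolute_percentage_error(y_true, y_pred)
    0.5515...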

In the above example, if we had used `mean_absolute_error`, it would have ignored
the small-magnitude values and only reflected the error in the prediction of the
highest-magnitude value. MAPE does not have this problem because it calculates
the error relative to the actual output, as the sketch below shows.
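
A hedged sketch of that comparison on the same data (printed values are
approximate)::

    >>> from sklearn.metrics import mean_absolute_error
    >>> from sklearn.metrics import mean_absolute_percentage_error
    >>> y_true = [1, 10, 1e6]
    >>> y_pred = [0.9, 15, 1.2e6]
    >>> # per-sample relative errors are 0.1, 0.5 and 0.2; their mean is 0.2666...
    >>> mean_absolute_percentage_error(y_true, y_pred)
    0.2666...
    >>> # MAE is dominated by the 2e5 error on the largest target
    >>> mean_absolute_error(y_true, y_pred)
    66668.36...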

.. _median_absolute_error:

Median absolute error
2 changes: 1 addition & 1 deletion doc/whats_new/_contributors.rst
@@ -176,4 +176,4 @@

.. _Nicolas Hug: https://github.com/NicolasHug

.. _Guillaume Lemaitre: https://github.com/glemaitre
.. _Guillaume Lemaitre: https://github.com/glemaitre
6 changes: 6 additions & 0 deletions doc/whats_new/v0.24.rst
@@ -150,6 +150,12 @@ Changelog
:mod:`sklearn.metrics`
......................

- |Feature| Added :func:`metrics.mean_absolute_percentage_error` metric and
  the associated scorer for regression problems. :issue:`10708` is fixed by
  :pr:`15007` from :user:`Ashutosh Hathidara <ashutosh1919>`. The scorer and
  some practical test cases were taken from :pr:`10711` by
  :user:`Mohamed Ali Jamaoui <mohamed-ali>`.

- |Fix| Fixed a bug in :func:`metrics.mean_squared_error` where the
average of multiple RMSE values was incorrectly calculated as the root of the
average of multiple MSE values.
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -64,6 +64,7 @@
from ._regression import mean_squared_error
from ._regression import mean_squared_log_error
from ._regression import median_absolute_error
from ._regression import mean_absolute_percentage_error
from ._regression import r2_score
from ._regression import mean_tweedie_deviance
from ._regression import mean_poisson_deviance
@@ -128,6 +129,7 @@
'mean_gamma_deviance',
'mean_tweedie_deviance',
'median_absolute_error',
'mean_absolute_percentage_error',
'multilabel_confusion_matrix',
'mutual_info_score',
'ndcg_score',
77 changes: 77 additions & 0 deletions sklearn/metrics/_regression.py
@@ -20,6 +20,7 @@
# Michael Eickenberg <michael.eickenberg@gmail.com>
# Konstantin Shmelkov <konstantin.shmelkov@polytechnique.edu>
# Christian Lorentzen <lorentzen.ch@googlemail.com>
# Ashutosh Hathidara <ashutoshhathidara98@gmail.com>
# License: BSD 3 clause

import numpy as np
@@ -41,6 +42,7 @@
"mean_squared_error",
"mean_squared_log_error",
"median_absolute_error",
"mean_absolute_percentage_error",
"r2_score",
"explained_variance_score",
"mean_tweedie_deviance",
@@ -192,6 +194,81 @@ def mean_absolute_error(y_true, y_pred, *,
return np.average(output_errors, weights=multioutput)


def mean_absolute_percentage_error(y_true, y_pred,
                                   sample_weight=None,
                                   multioutput='uniform_average'):
    """Mean absolute percentage error regression loss.

    Note here that we do not represent the output as a percentage in range
    [0, 100]. Instead, we represent it in range [0, 1/eps]. Read more in the
    :ref:`User Guide <mean_absolute_percentage_error>`.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Ground truth (correct) target values.

    y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Estimated target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    multioutput : {'raw_values', 'uniform_average'} or array-like
        Defines aggregating of multiple output values.
        Array-like value defines weights used to average errors.
        If input is list then the shape must be (n_outputs,).

        'raw_values' :
            Returns a full set of errors in case of multioutput input.

        'uniform_average' :
            Errors of all outputs are averaged with uniform weight.

    Returns
    -------
    loss : float or ndarray of floats in the range [0, 1/eps]
        If multioutput is 'raw_values', then mean absolute percentage error
        is returned for each output separately.
        If multioutput is 'uniform_average' or an ndarray of weights, then the
        weighted average of all output errors is returned.

        MAPE output is non-negative floating point. The best value is 0.0.
        But note that bad predictions can lead to arbitrarily large
        MAPE values, especially if some y_true values are very close to zero.
        Note that we return a large value instead of `inf` when y_true is zero.

    Examples
    --------
    >>> from sklearn.metrics import mean_absolute_percentage_error
    >>> y_true = [3, -0.5, 2, 7]
    >>> y_pred = [2.5, 0.0, 2, 8]
    >>> mean_absolute_percentage_error(y_true, y_pred)
    0.3273...
    >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
    >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
    >>> mean_absolute_percentage_error(y_true, y_pred)
    0.5515...
    >>> mean_absolute_percentage_error(y_true, y_pred, multioutput=[0.3, 0.7])
    0.6198...
    """
    y_type, y_true, y_pred, multioutput = _check_reg_targets(
        y_true, y_pred, multioutput)
    check_consistent_length(y_true, y_pred, sample_weight)
    epsilon = np.finfo(np.float64).eps
    mape = np.abs(y_pred - y_true) / np.maximum(np.abs(y_true), epsilon)
    output_errors = np.average(mape,
                               weights=sample_weight, axis=0)
    if isinstance(multioutput, str):
        if multioutput == 'raw_values':
            return output_errors
        elif multioutput == 'uniform_average':
            # pass None as weights to np.average: uniform mean
            multioutput = None

    return np.average(output_errors, weights=multioutput)


@_deprecate_positional_args
def mean_squared_error(y_true, y_pred, *,
sample_weight=None,
6 changes: 5 additions & 1 deletion sklearn/metrics/_scorer.py
@@ -30,7 +30,7 @@
f1_score, roc_auc_score, average_precision_score,
precision_score, recall_score, log_loss,
balanced_accuracy_score, explained_variance_score,
brier_score_loss, jaccard_score)
brier_score_loss, jaccard_score, mean_absolute_percentage_error)

from .cluster import adjusted_rand_score
from .cluster import homogeneity_score
@@ -614,6 +614,9 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
greater_is_better=False)
neg_mean_absolute_error_scorer = make_scorer(mean_absolute_error,
greater_is_better=False)
neg_mean_absolute_percentage_error_scorer = make_scorer(
mean_absolute_percentage_error, greater_is_better=False
)
neg_median_absolute_error_scorer = make_scorer(median_absolute_error,
greater_is_better=False)
neg_root_mean_squared_error_scorer = make_scorer(mean_squared_error,
@@ -674,6 +677,7 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
max_error=max_error_scorer,
neg_median_absolute_error=neg_median_absolute_error_scorer,
neg_mean_absolute_error=neg_mean_absolute_error_scorer,
neg_mean_absolute_percentage_error=neg_mean_absolute_percentage_error_scorer, # noqa
neg_mean_squared_error=neg_mean_squared_error_scorer,
neg_mean_squared_log_error=neg_mean_squared_log_error_scorer,
neg_root_mean_squared_error=neg_root_mean_squared_error_scorer,
16 changes: 13 additions & 3 deletions sklearn/metrics/tests/test_common.py
@@ -41,6 +41,7 @@
from sklearn.metrics import max_error
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_absolute_percentage_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_tweedie_deviance
from sklearn.metrics import mean_poisson_deviance
@@ -98,6 +99,7 @@
"mean_absolute_error": mean_absolute_error,
"mean_squared_error": mean_squared_error,
"median_absolute_error": median_absolute_error,
"mean_absolute_percentage_error": mean_absolute_percentage_error,
"explained_variance_score": explained_variance_score,
"r2_score": partial(r2_score, multioutput='variance_weighted'),
"mean_normal_deviance": partial(mean_tweedie_deviance, power=0),
@@ -425,7 +427,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
# Regression metrics with "multioutput-continuous" format support
MULTIOUTPUT_METRICS = {
"mean_absolute_error", "median_absolute_error", "mean_squared_error",
"r2_score", "explained_variance_score"
"r2_score", "explained_variance_score", "mean_absolute_percentage_error"
}

# Symmetric with respect to their input arguments y_true and y_pred
@@ -472,7 +474,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"macro_f0.5_score", "macro_f2_score", "macro_precision_score",
"macro_recall_score", "log_loss", "hinge_loss",
"mean_gamma_deviance", "mean_poisson_deviance",
"mean_compound_poisson_deviance"
"mean_compound_poisson_deviance", "mean_absolute_percentage_error"
}


@@ -1371,7 +1373,15 @@ def test_thresholded_multilabel_multioutput_permutations_invariance(name):
y_true_perm = y_true[:, perm]

current_score = metric(y_true_perm, y_score_perm)
assert_almost_equal(score, current_score)
if metric == mean_absolute_percentage_error:
assert np.isfinite(current_score)
assert current_score > 1e6
# Here we do not compare the values in the case of MAPE because
# whenever a y_true value is exactly zero, the MAPE value is not
# meaningful. In this case we only expect a very large, finite value.
else:
assert_almost_equal(score, current_score)


@pytest.mark.parametrize(
