diff --git a/CHANGELOG.md b/CHANGELOG.md index e9ba95ce852..18b4681105c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for `torch.float` weighted networks for FID and KID calculations ([#2483](https://github.com/Lightning-AI/torchmetrics/pull/2483)) +- Added `zero_division` argument to selected classification metrics ([#2198](https://github.com/Lightning-AI/torchmetrics/pull/2198)) + + ### Changed - Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424)) diff --git a/requirements/_tests.txt b/requirements/_tests.txt index 69f82700725..6bf2a66a3d4 100644 --- a/requirements/_tests.txt +++ b/requirements/_tests.txt @@ -15,5 +15,6 @@ pyGithub ==2.3.0 fire <=0.6.0 cloudpickle >1.3, <=3.0.0 -scikit-learn >=1.1.1, <1.4.0 +scikit-learn >=1.1.1, <1.3.0; python_version < "3.9" +scikit-learn >=1.4.0, <1.5.0; python_version >= "3.9" cachier ==3.0.0 diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 93f26441c2a..526ad1ae0da 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -49,7 +49,8 @@ class BinaryFBetaScore(BinaryStatScores): The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false - positives and false negatives respectively. If this case is encountered a score of 0 is returned. + positives and false negatives respectively. If this case is encountered a score of `zero_division` + (0 or 1, default is 0) is returned. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -83,6 +84,8 @@ class BinaryFBetaScore(BinaryStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -125,6 +128,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( @@ -135,14 +139,24 @@ def __init__( **kwargs, ) if validate_args: - _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index) + _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index, zero_division) self.validate_args = validate_args + self.zero_division = zero_division self.beta = beta def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() - return _fbeta_reduce(tp, fp, tn, fn, self.beta, average="binary", multidim_average=self.multidim_average) + return _fbeta_reduce( + tp, + fp, + tn, + fn, + self.beta, + average="binary", + multidim_average=self.multidim_average, + zero_division=self.zero_division, + ) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None @@ -197,7 +211,7 @@ class MulticlassFBetaScore(MulticlassStatScores): The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class, the metric for that class - will be set to 0 and the overall metric may therefore be affected in turn. + will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be affected in turn. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -249,6 +263,8 @@ class MulticlassFBetaScore(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -306,6 +322,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( @@ -318,14 +335,26 @@ def __init__( **kwargs, ) if validate_args: - _multiclass_fbeta_score_arg_validation(beta, num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_fbeta_score_arg_validation( + beta, num_classes, top_k, average, multidim_average, ignore_index, zero_division + ) self.validate_args = validate_args + self.zero_division = zero_division self.beta = beta def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() - return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average) + return _fbeta_reduce( + tp, + fp, + tn, + fn, + self.beta, + average=self.average, + multidim_average=self.multidim_average, + zero_division=self.zero_division, + ) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None @@ -380,7 +409,7 @@ class MultilabelFBetaScore(MultilabelStatScores): The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any label, the metric for that label - will be set to 0 and the overall metric may therefore be affected in turn. + will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be affected in turn. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -430,6 +459,8 @@ class MultilabelFBetaScore(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -485,6 +516,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( @@ -497,15 +529,26 @@ def __init__( **kwargs, ) if validate_args: - _multilabel_fbeta_score_arg_validation(beta, num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_fbeta_score_arg_validation( + beta, num_labels, threshold, average, multidim_average, ignore_index, zero_division + ) self.validate_args = validate_args + self.zero_division = zero_division self.beta = beta def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _fbeta_reduce( - tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average, multilabel=True + tp, + fp, + tn, + fn, + self.beta, + average=self.average, + multidim_average=self.multidim_average, + multilabel=True, + zero_division=self.zero_division, ) def plot( @@ -559,7 +602,8 @@ class BinaryF1Score(BinaryFBetaScore): The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false - positives and false negatives respectively. If this case is encountered a score of 0 is returned. + positives and false negatives respectively. If this case is encountered a score of `zero_division` + (0 or 1, default is 0) is returned. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -592,6 +636,8 @@ class BinaryF1Score(BinaryFBetaScore): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -633,6 +679,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( @@ -641,6 +688,7 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + zero_division=zero_division, **kwargs, ) @@ -696,7 +744,7 @@ class MulticlassF1Score(MulticlassFBetaScore): The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class, the metric for that class - will be set to 0 and the overall metric may therefore be affected in turn. + will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be affected in turn. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -748,6 +796,8 @@ class MulticlassF1Score(MulticlassFBetaScore): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -804,6 +854,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( @@ -814,6 +865,7 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + zero_division=zero_division, **kwargs, ) @@ -869,7 +921,7 @@ class MultilabelF1Score(MultilabelFBetaScore): The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any label, the metric for that label - will be set to 0 and the overall metric may therefore be affected in turn. + will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be affected in turn. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -919,6 +971,8 @@ class MultilabelF1Score(MultilabelFBetaScore): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -973,6 +1027,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( @@ -983,6 +1038,7 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + zero_division=zero_division, **kwargs, ) @@ -1039,7 +1095,8 @@ class FBetaScore(_ClassificationTaskWrapper): The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class/label, the metric for that - class/label will be set to 0 and the overall metric may therefore be affected in turn. + class/label will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be + affected in turn. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of @@ -1070,6 +1127,7 @@ def __new__( # type: ignore[misc] top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> Metric: """Initialize task metric.""" @@ -1079,6 +1137,7 @@ def __new__( # type: ignore[misc] "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "zero_division": zero_division, }) if task == ClassificationTask.BINARY: return BinaryFBetaScore(beta, threshold, **kwargs) @@ -1104,7 +1163,8 @@ class F1Score(_ClassificationTaskWrapper): The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class/label, the metric for that - class/label will be set to 0 and the overall metric may therefore be affected in turn. + class/label will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be + affected in turn. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of @@ -1133,6 +1193,7 @@ def __new__( # type: ignore[misc] top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> Metric: """Initialize task metric.""" @@ -1142,6 +1203,7 @@ def __new__( # type: ignore[misc] "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "zero_division": zero_division, }) if task == ClassificationTask.BINARY: return BinaryF1Score(threshold, **kwargs) diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 4a230122b1b..0ea04849696 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -65,6 +65,8 @@ class BinaryJaccardIndex(BinaryConfusionMatrix): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: + Value to replace when there is a division by zero. Should be `0` or `1`. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): @@ -97,15 +99,17 @@ def __init__( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs ) + self.zero_division = zero_division def compute(self) -> Tensor: """Compute metric.""" - return _jaccard_index_reduce(self.confmat, average="binary") + return _jaccard_index_reduce(self.confmat, average="binary", zero_division=self.zero_division) def plot( # type: ignore[override] self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None @@ -187,6 +191,8 @@ class MulticlassJaccardIndex(MulticlassConfusionMatrix): validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: + Value to replace when there is a division by zero. Should be `0` or `1`. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): @@ -224,6 +230,7 @@ def __init__( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( @@ -233,10 +240,13 @@ def __init__( _multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average) self.validate_args = validate_args self.average = average + self.zero_division = zero_division def compute(self) -> Tensor: """Compute metric.""" - return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index) + return _jaccard_index_reduce( + self.confmat, average=self.average, ignore_index=self.ignore_index, zero_division=self.zero_division + ) def plot( # type: ignore[override] self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None @@ -319,6 +329,8 @@ class MultilabelJaccardIndex(MultilabelConfusionMatrix): validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: + Value to replace when there is a division by zero. Should be `0` or `1`. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): @@ -354,6 +366,7 @@ def __init__( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, **kwargs: Any, ) -> None: super().__init__( @@ -368,10 +381,11 @@ def __init__( _multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index, average) self.validate_args = validate_args self.average = average + self.zero_division = zero_division def compute(self) -> Tensor: """Compute metric.""" - return _jaccard_index_reduce(self.confmat, average=self.average) + return _jaccard_index_reduce(self.confmat, average=self.average, zero_division=self.zero_division) def plot( # type: ignore[override] self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 124215f4e03..0380545b5ac 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -18,7 +18,9 @@ from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores -from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce +from torchmetrics.functional.classification.precision_recall import ( + _precision_recall_reduce, +) from torchmetrics.metric import Metric from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE @@ -42,7 +44,7 @@ class BinaryPrecision(BinaryStatScores): Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is - encountered a score of 0 is returned. + encountered a score of `zero_division` (0 or 1, default is 0) is returned. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -73,6 +75,7 @@ class BinaryPrecision(BinaryStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -112,7 +115,14 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _precision_recall_reduce( - "precision", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average + "precision", + tp, + fp, + tn, + fn, + average="binary", + multidim_average=self.multidim_average, + zero_division=self.zero_division, ) def plot( @@ -165,8 +175,8 @@ class MulticlassPrecision(MulticlassStatScores): Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is - encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be - affected in turn. + encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and + the overall metric may therefore be affected in turn. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -217,6 +227,7 @@ class MulticlassPrecision(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -269,7 +280,15 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _precision_recall_reduce( - "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k + "precision", + tp, + fp, + tn, + fn, + average=self.average, + multidim_average=self.multidim_average, + top_k=self.top_k, + zero_division=self.zero_division, ) def plot( @@ -322,8 +341,8 @@ class MultilabelPrecision(MultilabelStatScores): Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is - encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be - affected in turn. + encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and + the overall metric may therefore be affected in turn. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -373,6 +392,7 @@ class MultilabelPrecision(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -423,7 +443,15 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _precision_recall_reduce( - "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True + "precision", + tp, + fp, + tn, + fn, + average=self.average, + multidim_average=self.multidim_average, + multilabel=True, + zero_division=self.zero_division, ) def plot( @@ -476,7 +504,7 @@ class BinaryRecall(BinaryStatScores): Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is - encountered a score of 0 is returned. + encountered a score of `zero_division` (0 or 1, default is 0) is returned. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -507,6 +535,7 @@ class BinaryRecall(BinaryStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -546,7 +575,14 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _precision_recall_reduce( - "recall", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average + "recall", + tp, + fp, + tn, + fn, + average="binary", + multidim_average=self.multidim_average, + zero_division=self.zero_division, ) def plot( @@ -599,8 +635,8 @@ class MulticlassRecall(MulticlassStatScores): Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is - encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be - affected in turn. + encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and + the overall metric may therefore be affected in turn. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -650,6 +686,7 @@ class MulticlassRecall(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -702,7 +739,15 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _precision_recall_reduce( - "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k + "recall", + tp, + fp, + tn, + fn, + average=self.average, + multidim_average=self.multidim_average, + top_k=self.top_k, + zero_division=self.zero_division, ) def plot( @@ -755,8 +800,8 @@ class MultilabelRecall(MultilabelStatScores): Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is - encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be - affected in turn. + encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and + the overall metric may therefore be affected in turn. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -805,6 +850,7 @@ class MultilabelRecall(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`. Example (preds is int tensor): >>> from torch import tensor @@ -855,7 +901,15 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _precision_recall_reduce( - "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True + "recall", + tp, + fp, + tn, + fn, + average=self.average, + multidim_average=self.multidim_average, + multilabel=True, + zero_division=self.zero_division, ) def plot( diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 1ae0d4285e6..96d797fd5d6 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -169,13 +169,15 @@ def __init__( validate_args: bool = True, **kwargs: Any, ) -> None: + zero_division = kwargs.pop("zero_division", 0) super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division) self.threshold = threshold self.multidim_average = multidim_average self.ignore_index = ignore_index self.validate_args = validate_args + self.zero_division = zero_division self._create_state(size=1, multidim_average=multidim_average) @@ -313,15 +315,19 @@ def __init__( validate_args: bool = True, **kwargs: Any, ) -> None: + zero_division = kwargs.pop("zero_division", 0) super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, zero_division + ) self.num_classes = num_classes self.top_k = top_k self.average = average self.multidim_average = multidim_average self.ignore_index = ignore_index self.validate_args = validate_args + self.zero_division = zero_division self._create_state( size=1 if (average == "micro" and top_k == 1) else num_classes, multidim_average=multidim_average @@ -461,15 +467,19 @@ def __init__( validate_args: bool = True, **kwargs: Any, ) -> None: + zero_division = kwargs.pop("zero_division", 0) super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, zero_division + ) self.num_labels = num_labels self.threshold = threshold self.average = average self.multidim_average = multidim_average self.ignore_index = ignore_index self.validate_args = validate_args + self.zero_division = zero_division self._create_state(size=num_labels, multidim_average=multidim_average) diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index 0f0e883266c..83f61955960 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -43,17 +43,18 @@ def _fbeta_reduce( average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], multidim_average: Literal["global", "samplewise"] = "global", multilabel: bool = False, + zero_division: float = 0, ) -> Tensor: beta2 = beta**2 if average == "binary": - return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp) + return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp, zero_division) if average == "micro": tp = tp.sum(dim=0 if multidim_average == "global" else 1) fn = fn.sum(dim=0 if multidim_average == "global" else 1) fp = fp.sum(dim=0 if multidim_average == "global" else 1) - return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp) + return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp, zero_division) - fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp) + fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp, zero_division) return _adjust_weights_safe_divide(fbeta_score, average, multilabel, tp, fp, fn) @@ -62,10 +63,11 @@ def _binary_fbeta_score_arg_validation( threshold: float = 0.5, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + zero_division: float = 0, ) -> None: if not (isinstance(beta, float) and beta > 0): raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division) def binary_fbeta_score( @@ -76,6 +78,7 @@ def binary_fbeta_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `F-score`_ metric for binary tasks. @@ -106,6 +109,8 @@ def binary_fbeta_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -136,11 +141,13 @@ def binary_fbeta_score( """ if validate_args: - _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index) + _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index, zero_division) _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) - return _fbeta_reduce(tp, fp, tn, fn, beta, average="binary", multidim_average=multidim_average) + return _fbeta_reduce( + tp, fp, tn, fn, beta, average="binary", multidim_average=multidim_average, zero_division=zero_division + ) def _multiclass_fbeta_score_arg_validation( @@ -150,10 +157,11 @@ def _multiclass_fbeta_score_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + zero_division: float = 0, ) -> None: if not (isinstance(beta, float) and beta > 0): raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index, zero_division) def multiclass_fbeta_score( @@ -166,6 +174,7 @@ def multiclass_fbeta_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `F-score`_ metric for multiclass tasks. @@ -206,6 +215,8 @@ def multiclass_fbeta_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -254,13 +265,17 @@ def multiclass_fbeta_score( """ if validate_args: - _multiclass_fbeta_score_arg_validation(beta, num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_fbeta_score_arg_validation( + beta, num_classes, top_k, average, multidim_average, ignore_index, zero_division + ) _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) preds, target = _multiclass_stat_scores_format(preds, target, top_k) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) - return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average) + return _fbeta_reduce( + tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average, zero_division=zero_division + ) def _multilabel_fbeta_score_arg_validation( @@ -270,10 +285,13 @@ def _multilabel_fbeta_score_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + zero_division: float = 0, ) -> None: if not (isinstance(beta, float) and beta > 0): raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, zero_division + ) def multilabel_fbeta_score( @@ -286,6 +304,7 @@ def multilabel_fbeta_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `F-score`_ metric for multilabel tasks. @@ -325,6 +344,8 @@ def multilabel_fbeta_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -371,11 +392,23 @@ def multilabel_fbeta_score( """ if validate_args: - _multilabel_fbeta_score_arg_validation(beta, num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_fbeta_score_arg_validation( + beta, num_labels, threshold, average, multidim_average, ignore_index, zero_division + ) _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) - return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average, multilabel=True) + return _fbeta_reduce( + tp, + fp, + tn, + fn, + beta, + average=average, + multidim_average=multidim_average, + multilabel=True, + zero_division=zero_division, + ) def binary_f1_score( @@ -385,6 +418,7 @@ def binary_f1_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute F-1 score for binary tasks. @@ -413,6 +447,8 @@ def binary_f1_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -450,6 +486,7 @@ def binary_f1_score( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + zero_division=zero_division, ) @@ -462,6 +499,7 @@ def multiclass_f1_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute F-1 score for multiclass tasks. @@ -500,6 +538,8 @@ def multiclass_f1_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -557,6 +597,7 @@ def multiclass_f1_score( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + zero_division=zero_division, ) @@ -569,6 +610,7 @@ def multilabel_f1_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute F-1 score for multilabel tasks. @@ -606,6 +648,8 @@ def multilabel_f1_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when + :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -661,6 +705,7 @@ def multilabel_f1_score( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + zero_division=zero_division, ) @@ -677,6 +722,7 @@ def fbeta_score( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `F-score`_ metric. @@ -702,20 +748,40 @@ def fbeta_score( task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) + return binary_fbeta_score( + preds, target, beta, threshold, multidim_average, ignore_index, validate_args, zero_division + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_fbeta_score( - preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + beta, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + zero_division, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_fbeta_score( - preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + beta, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + zero_division, ) raise ValueError(f"Unsupported task `{task}` passed.") @@ -732,6 +798,7 @@ def f1_score( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute F-1 score. @@ -756,19 +823,19 @@ def f1_score( task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args, zero_division) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_f1_score( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args, zero_division ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_f1_score( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args, zero_division ) raise ValueError(f"Unsupported task `{task}` passed.") diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 7e928525ad8..1d240df68af 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -39,6 +39,7 @@ def _jaccard_index_reduce( confmat: Tensor, average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]], ignore_index: Optional[int] = None, + zero_division: float = 0.0, ) -> Tensor: """Perform reduction of an un-normalized confusion matrix into jaccard score. @@ -57,6 +58,8 @@ def _jaccard_index_reduce( ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation + zero_division: + Value to replace when there is a division by zero. Should be `0` or `1`. """ allowed_average = ["binary", "micro", "macro", "weighted", "none", None] @@ -79,7 +82,7 @@ def _jaccard_index_reduce( num = num.sum() denom = denom.sum() - (denom[ignore_index] if ignore_index_cond else 0.0) - jaccard = _safe_divide(num, denom) + jaccard = _safe_divide(num, denom, zero_division=zero_division) if average is None or average == "none" or average == "micro": return jaccard @@ -100,6 +103,7 @@ def binary_jaccard_index( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0.0, ) -> Tensor: r"""Calculate the Jaccard index for binary tasks. @@ -126,7 +130,8 @@ def binary_jaccard_index( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + zero_division: + Value to replace when there is a division by zero. Should be `0` or `1`. Example (preds is int tensor): >>> from torch import tensor @@ -149,7 +154,7 @@ def binary_jaccard_index( _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) confmat = _binary_confusion_matrix_update(preds, target) - return _jaccard_index_reduce(confmat, average="binary") + return _jaccard_index_reduce(confmat, average="binary", zero_division=zero_division) def _multiclass_jaccard_index_arg_validation( @@ -170,6 +175,7 @@ def multiclass_jaccard_index( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0.0, ) -> Tensor: r"""Calculate the Jaccard index for multiclass tasks. @@ -204,7 +210,8 @@ def multiclass_jaccard_index( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + zero_division: + Value to replace when there is a division by zero. Should be `0` or `1`. Example (pred is integer tensor): >>> from torch import tensor @@ -230,7 +237,7 @@ def multiclass_jaccard_index( _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) - return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index) + return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index, zero_division=zero_division) def _multilabel_jaccard_index_arg_validation( @@ -253,6 +260,7 @@ def multilabel_jaccard_index( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0.0, ) -> Tensor: r"""Calculate the Jaccard index for multilabel tasks. @@ -288,7 +296,8 @@ def multilabel_jaccard_index( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + zero_division: + Value to replace when there is a division by zero. Should be `0` or `1`. Example (preds is int tensor): >>> from torch import tensor @@ -311,7 +320,7 @@ def multilabel_jaccard_index( _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) - return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index) + return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index, zero_division=zero_division) def jaccard_index( @@ -324,6 +333,7 @@ def jaccard_index( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0.0, ) -> Tensor: r"""Calculate the Jaccard index. @@ -351,13 +361,15 @@ def jaccard_index( """ task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: - return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) + return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args, zero_division) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args) + return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args, zero_division) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) + return multilabel_jaccard_index( + preds, target, num_labels, threshold, average, ignore_index, validate_args, zero_division + ) raise ValueError(f"Not handled value: {task}") diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index beb70d54bbc..96214c82274 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -44,17 +44,18 @@ def _precision_recall_reduce( multidim_average: Literal["global", "samplewise"] = "global", multilabel: bool = False, top_k: int = 1, + zero_division: float = 0, ) -> Tensor: different_stat = fp if stat == "precision" else fn # this is what differs between the two scores if average == "binary": - return _safe_divide(tp, tp + different_stat) + return _safe_divide(tp, tp + different_stat, zero_division) if average == "micro": tp = tp.sum(dim=0 if multidim_average == "global" else 1) fn = fn.sum(dim=0 if multidim_average == "global" else 1) different_stat = different_stat.sum(dim=0 if multidim_average == "global" else 1) - return _safe_divide(tp, tp + different_stat) + return _safe_divide(tp, tp + different_stat, zero_division) - score = _safe_divide(tp, tp + different_stat) + score = _safe_divide(tp, tp + different_stat, zero_division) return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k) @@ -65,6 +66,7 @@ def binary_precision( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `Precision`_ for binary tasks. @@ -95,6 +97,7 @@ def binary_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -129,7 +132,9 @@ def binary_precision( _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) - return _precision_recall_reduce("precision", tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + return _precision_recall_reduce( + "precision", tp, fp, tn, fn, average="binary", multidim_average=multidim_average, zero_division=zero_division + ) def multiclass_precision( @@ -141,6 +146,7 @@ def multiclass_precision( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `Precision`_ for multiclass tasks. @@ -182,6 +188,7 @@ def multiclass_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -237,7 +244,15 @@ def multiclass_precision( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) return _precision_recall_reduce( - "precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k + "precision", + tp, + fp, + tn, + fn, + average=average, + multidim_average=multidim_average, + top_k=top_k, + zero_division=zero_division, ) @@ -250,6 +265,7 @@ def multilabel_precision( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `Precision`_ for multilabel tasks. @@ -289,6 +305,7 @@ def multilabel_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -340,7 +357,15 @@ def multilabel_precision( preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _precision_recall_reduce( - "precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True + "precision", + tp, + fp, + tn, + fn, + average=average, + multidim_average=multidim_average, + multilabel=True, + zero_division=zero_division, ) @@ -351,6 +376,7 @@ def binary_recall( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `Recall`_ for binary tasks. @@ -381,6 +407,7 @@ def binary_recall( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -415,7 +442,9 @@ def binary_recall( _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) - return _precision_recall_reduce("recall", tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + return _precision_recall_reduce( + "recall", tp, fp, tn, fn, average="binary", multidim_average=multidim_average, zero_division=zero_division + ) def multiclass_recall( @@ -427,6 +456,7 @@ def multiclass_recall( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `Recall`_ for multiclass tasks. @@ -468,6 +498,7 @@ def multiclass_recall( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -523,7 +554,15 @@ def multiclass_recall( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) return _precision_recall_reduce( - "recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k + "recall", + tp, + fp, + tn, + fn, + average=average, + multidim_average=multidim_average, + top_k=top_k, + zero_division=zero_division, ) @@ -536,6 +575,7 @@ def multilabel_recall( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `Recall`_ for multilabel tasks. @@ -575,6 +615,7 @@ def multilabel_recall( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -626,7 +667,15 @@ def multilabel_recall( preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _precision_recall_reduce( - "recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True + "recall", + tp, + fp, + tn, + fn, + average=average, + multidim_average=multidim_average, + multilabel=True, + zero_division=zero_division, ) @@ -642,6 +691,7 @@ def precision( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `Precision`_. @@ -669,20 +719,20 @@ def precision( """ assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args, zero_division) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_precision( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args, zero_division ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_precision( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args, zero_division ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" @@ -701,6 +751,7 @@ def recall( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + zero_division: float = 0, ) -> Tensor: r"""Compute `Recall`_. @@ -729,19 +780,19 @@ def recall( task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args, zero_division) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_recall( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args, zero_division ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_recall( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args, zero_division ) raise ValueError(f"Not handled value: {task}") diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index aa8e0bf5016..a412efc180f 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -26,12 +26,14 @@ def _binary_stat_scores_arg_validation( threshold: float = 0.5, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + zero_division: float = 0, ) -> None: """Validate non tensor input. - ``threshold`` has to be a float in the [0,1] range - ``multidim_average`` has to be either "global" or "samplewise" - ``ignore_index`` has to be None or int + - ``zero_division`` has to be 0 or 1 """ if not (isinstance(threshold, float) and (0 <= threshold <= 1)): @@ -43,6 +45,8 @@ def _binary_stat_scores_arg_validation( ) if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + if zero_division not in [0, 1]: + raise ValueError(f"Expected argument `zero_division` to be 0 or 1, but got {zero_division}.") def _binary_stat_scores_tensor_validation( @@ -220,6 +224,7 @@ def _multiclass_stat_scores_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + zero_division: float = 0, ) -> None: """Validate non tensor input. @@ -228,6 +233,7 @@ def _multiclass_stat_scores_arg_validation( - ``average`` has to be "micro" | "macro" | "weighted" | "none" - ``multidim_average`` has to be either "global" or "samplewise" - ``ignore_index`` has to be None or int + - ``zero_division`` has to be 0 or 1 """ if not isinstance(num_classes, int) or num_classes < 2: @@ -248,6 +254,8 @@ def _multiclass_stat_scores_arg_validation( ) if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + if zero_division not in [0, 1]: + raise ValueError(f"Expected argument `zero_division` to be 0 or 1, but got {zero_division}.") def _multiclass_stat_scores_tensor_validation( @@ -560,6 +568,7 @@ def _multilabel_stat_scores_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + zero_division: float = 0, ) -> None: """Validate non tensor input. @@ -568,6 +577,7 @@ def _multilabel_stat_scores_arg_validation( - ``average`` has to be "micro" | "macro" | "weighted" | "none" - ``multidim_average`` has to be either "global" or "samplewise" - ``ignore_index`` has to be None or int + - ``zero_division`` has to be 0 or 1 """ if not isinstance(num_labels, int) or num_labels < 2: @@ -584,6 +594,8 @@ def _multilabel_stat_scores_arg_validation( ) if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + if zero_division not in [0, 1]: + raise ValueError(f"Expected argument `zero_division` to be 0 or 1, but got {zero_division}.") def _multilabel_stat_scores_tensor_validation( diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index 12613103ca6..68cd344877d 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -43,16 +43,21 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor: return res -def _safe_divide(num: Tensor, denom: Tensor) -> Tensor: +def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor: """Safe division, by preventing division by zero. - Additionally casts to float if input is not already to secure backwards compatibility. + Function will cast to float if input is not already to secure backwards compatibility. + + Args: + num: numerator tensor + denom: denominator tensor, which may contain zeros + zero_division: value to replace elements divided by zero """ - denom[denom == 0.0] = 1 num = num if num.is_floating_point() else num.float() denom = denom if denom.is_floating_point() else denom.float() - return num / denom + zero_division = torch.tensor(zero_division).float().to(num.device) + return torch.where(denom != 0, num / denom, zero_division) def _adjust_weights_safe_divide( diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 6e80411f5c1..f21f29f8152 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -60,5 +60,6 @@ _MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic") _IPADIC_AVAILABLE = RequirementCache("ipadic") _SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece") +_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0") _LATEX_AVAILABLE: bool = shutil.which("latex") is not None diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 3a334708485..03c39d336fe 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -49,7 +49,7 @@ seed_all(42) -def _reference_sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, multidim_average): +def _reference_sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, multidim_average, zero_division=0): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -64,14 +64,14 @@ def _reference_sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, mu if multidim_average == "global": target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds) + return sk_fn(target, preds, zero_division=zero_division) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - res.append(sk_fn(true, pred)) + res.append(sk_fn(true, pred, zero_division=zero_division)) return np.stack(res) @@ -90,7 +90,10 @@ class TestBinaryFBetaScore(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_fbeta_score(self, ddp, inputs, module, functional, compare, ignore_index, multidim_average): + @pytest.mark.parametrize("zero_division", [0, 1]) + def test_binary_fbeta_score( + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, zero_division + ): """Test class implementation of metric.""" preds, target = inputs if ignore_index == -1: @@ -110,13 +113,22 @@ def test_binary_fbeta_score(self, ddp, inputs, module, functional, compare, igno sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, + zero_division=zero_division, ), - metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "zero_division": zero_division, + }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - def test_binary_fbeta_score_functional(self, inputs, module, functional, compare, ignore_index, multidim_average): + @pytest.mark.parametrize("zero_division", [0, 1]) + def test_binary_fbeta_score_functional( + self, inputs, module, functional, compare, ignore_index, multidim_average, zero_division + ): """Test functional implementation of metric.""" preds, target = inputs if ignore_index == -1: @@ -133,11 +145,13 @@ def test_binary_fbeta_score_functional(self, inputs, module, functional, compare sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, + zero_division=zero_division, ), metric_args={ "threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average, + "zero_division": zero_division, }, ) @@ -183,14 +197,22 @@ def test_binary_fbeta_score_half_gpu(self, inputs, module, functional, compare, ) -def _reference_sklearn_fbeta_score_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): +def _reference_sklearn_fbeta_score_multiclass( + preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0 +): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds, average=average, labels=list(range(NUM_CLASSES)) if average is None else None) + return sk_fn( + target, + preds, + average=average, + labels=list(range(NUM_CLASSES)) if average is None else None, + zero_division=zero_division, + ) preds = preds.numpy() target = target.numpy() @@ -199,7 +221,24 @@ def _reference_sklearn_fbeta_score_multiclass(preds, target, sk_fn, ignore_index pred = pred.flatten() true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - r = sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)) if average is None else None) + + if len(pred) == 0 and average == "weighted": + # The result of sk_fn([], [], labels=None, average="weighted", zero_division=zero_division) + # varies depending on the sklearn version: + # 1.2 -> the value of zero_division + # 1.3 -> nan + # 1.4 -> nan + # To avoid breaking some test cases by this behavior, + # hard coded to return 0 in this special case. + r = 0.0 + else: + r = sk_fn( + true, + pred, + average=average, + labels=list(range(NUM_CLASSES)) if average is None else None, + zero_division=zero_division, + ) res.append(0.0 if np.isnan(r).any() else r) return np.stack(res, 0) @@ -224,8 +263,9 @@ class TestMulticlassFBetaScore(MetricTester): @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_multiclass_fbeta_score( - self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average, zero_division ): """Test class implementation of metric.""" preds, target = inputs @@ -247,20 +287,23 @@ def test_multiclass_fbeta_score( ignore_index=ignore_index, multidim_average=multidim_average, average=average, + zero_division=zero_division, ), metric_args={ "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "zero_division": zero_division, }, ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_multiclass_fbeta_score_functional( - self, inputs, module, functional, compare, ignore_index, multidim_average, average + self, inputs, module, functional, compare, ignore_index, multidim_average, average, zero_division ): """Test functional implementation of metric.""" preds, target = inputs @@ -279,12 +322,14 @@ def test_multiclass_fbeta_score_functional( ignore_index=ignore_index, multidim_average=multidim_average, average=average, + zero_division=zero_division, ), metric_args={ "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "zero_division": zero_division, }, ) @@ -368,18 +413,18 @@ def test_top_k( assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) -def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average): +def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division): if average == "micro": preds = preds.flatten() target = target.flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds) + return sk_fn(target, preds, zero_division=zero_division) fbeta_score, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - fbeta_score.append(sk_fn(true, pred)) + fbeta_score.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) res = np.stack(fbeta_score, axis=0) @@ -396,13 +441,13 @@ def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignor return None -def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average): +def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average, zero_division): fbeta_score, weights = [], [] for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - fbeta_score.append(sk_fn(true, pred)) + fbeta_score.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) else: @@ -410,7 +455,7 @@ def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] true, pred = remove_ignore_index(true, pred, ignore_index) - scores.append(sk_fn(true, pred)) + scores.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) fbeta_score.append(np.stack(scores)) @@ -430,7 +475,9 @@ def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore return None -def _reference_sklearn_fbeta_score_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): +def _reference_sklearn_fbeta_score_multilabel( + preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0 +): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -444,10 +491,13 @@ def _reference_sklearn_fbeta_score_multilabel(preds, target, sk_fn, ignore_index target.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), preds.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), average=average, + zero_division=zero_division, ) if multidim_average == "global": - return _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average) - return _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average) + return _reference_sklearn_fbeta_score_multilabel_global( + preds, target, sk_fn, ignore_index, average, zero_division + ) + return _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average, zero_division) @pytest.mark.parametrize("inputs", _multilabel_cases) @@ -470,8 +520,9 @@ class TestMultilabelFBetaScore(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_multilabel_fbeta_score( - self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average, zero_division ): """Test class implementation of metric.""" preds, target = inputs @@ -493,6 +544,7 @@ def test_multilabel_fbeta_score( ignore_index=ignore_index, multidim_average=multidim_average, average=average, + zero_division=zero_division, ), metric_args={ "num_labels": NUM_CLASSES, @@ -500,14 +552,16 @@ def test_multilabel_fbeta_score( "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "zero_division": zero_division, }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_multilabel_fbeta_score_functional( - self, inputs, module, functional, compare, ignore_index, multidim_average, average + self, inputs, module, functional, compare, ignore_index, multidim_average, average, zero_division ): """Test functional implementation of metric.""" preds, target = inputs @@ -526,6 +580,7 @@ def test_multilabel_fbeta_score_functional( ignore_index=ignore_index, multidim_average=multidim_average, average=average, + zero_division=zero_division, ), metric_args={ "num_labels": NUM_CLASSES, @@ -533,6 +588,7 @@ def test_multilabel_fbeta_score_functional( "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "zero_division": zero_division, }, ) diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 8fa17ca1d32..6901868eac9 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -37,7 +37,7 @@ from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None): +def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None, zero_division=0): preds = preds.view(-1).numpy() target = target.view(-1).numpy() if np.issubdtype(preds.dtype, np.floating): @@ -45,7 +45,7 @@ def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_jaccard_index(y_true=target, y_pred=preds) + return sk_jaccard_index(y_true=target, y_pred=preds, zero_division=zero_division) @pytest.mark.parametrize("inputs", _binary_cases) @@ -53,8 +53,9 @@ class TestBinaryJaccardIndex(MetricTester): """Test class for `BinaryJaccardIndex` metric.""" @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("zero_division", [0, 1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_jaccard_index(self, inputs, ddp, ignore_index): + def test_binary_jaccard_index(self, inputs, ddp, ignore_index, zero_division): """Test class implementation of metric.""" preds, target = inputs if ignore_index is not None: @@ -64,15 +65,19 @@ def test_binary_jaccard_index(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=BinaryJaccardIndex, - reference_metric=partial(_reference_sklearn_jaccard_index_binary, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_jaccard_index_binary, ignore_index=ignore_index, zero_division=zero_division + ), metric_args={ "threshold": THRESHOLD, "ignore_index": ignore_index, + "zero_division": zero_division, }, ) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - def test_binary_jaccard_index_functional(self, inputs, ignore_index): + @pytest.mark.parametrize("zero_division", [0, 1]) + def test_binary_jaccard_index_functional(self, inputs, ignore_index, zero_division): """Test functional implementation of metric.""" preds, target = inputs if ignore_index is not None: @@ -81,11 +86,10 @@ def test_binary_jaccard_index_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=binary_jaccard_index, - reference_metric=partial(_reference_sklearn_jaccard_index_binary, ignore_index=ignore_index), - metric_args={ - "threshold": THRESHOLD, - "ignore_index": ignore_index, - }, + reference_metric=partial( + _reference_sklearn_jaccard_index_binary, ignore_index=ignore_index, zero_division=zero_division + ), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "zero_division": zero_division}, ) def test_binary_jaccard_index_differentiability(self, inputs): @@ -129,7 +133,7 @@ def test_binary_jaccard_index_dtype_gpu(self, inputs, dtype): ) -def _reference_sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average="macro"): +def _reference_sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average="macro", zero_division=0): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -137,13 +141,15 @@ def _reference_sklearn_jaccard_index_multiclass(preds, target, ignore_index=None preds = preds.flatten() target = target.flatten() target, preds = remove_ignore_index(target, preds, ignore_index) + if average is None: + return sk_jaccard_index( + y_true=target, y_pred=preds, average=average, labels=list(range(NUM_CLASSES)), zero_division=zero_division + ) if ignore_index is not None and 0 <= ignore_index < NUM_CLASSES: labels = [i for i in range(NUM_CLASSES) if i != ignore_index] - res = sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=labels) - return np.insert(res, ignore_index, 0.0) if average is None else res - if average is None: - return sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=list(range(NUM_CLASSES))) - return sk_jaccard_index(y_true=target, y_pred=preds, average=average) + res = sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=labels, zero_division=zero_division) + return np.insert(res, ignore_index, 0) if average is None else res + return sk_jaccard_index(y_true=target, y_pred=preds, average=average, zero_division=zero_division) @pytest.mark.parametrize("inputs", _multiclass_cases) @@ -152,8 +158,9 @@ class TestMulticlassJaccardIndex(MetricTester): @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("zero_division", [0, 1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multiclass_jaccard_index(self, inputs, ddp, ignore_index, average): + def test_multiclass_jaccard_index(self, inputs, ddp, ignore_index, average, zero_division): """Test class implementation of metric.""" preds, target = inputs if ignore_index is not None: @@ -164,18 +171,23 @@ def test_multiclass_jaccard_index(self, inputs, ddp, ignore_index, average): target=target, metric_class=MulticlassJaccardIndex, reference_metric=partial( - _reference_sklearn_jaccard_index_multiclass, ignore_index=ignore_index, average=average + _reference_sklearn_jaccard_index_multiclass, + ignore_index=ignore_index, + average=average, + zero_division=zero_division, ), metric_args={ "num_classes": NUM_CLASSES, "ignore_index": ignore_index, "average": average, + "zero_division": zero_division, }, ) @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - def test_multiclass_jaccard_index_functional(self, inputs, ignore_index, average): + @pytest.mark.parametrize("zero_division", [0, 1]) + def test_multiclass_jaccard_index_functional(self, inputs, ignore_index, average, zero_division): """Test functional implementation of metric.""" preds, target = inputs if ignore_index is not None: @@ -185,12 +197,16 @@ def test_multiclass_jaccard_index_functional(self, inputs, ignore_index, average target=target, metric_functional=multiclass_jaccard_index, reference_metric=partial( - _reference_sklearn_jaccard_index_multiclass, ignore_index=ignore_index, average=average + _reference_sklearn_jaccard_index_multiclass, + ignore_index=ignore_index, + average=average, + zero_division=zero_division, ), metric_args={ "num_classes": NUM_CLASSES, "ignore_index": ignore_index, "average": average, + "zero_division": zero_division, }, ) @@ -233,7 +249,7 @@ def test_multiclass_jaccard_index_dtype_gpu(self, inputs, dtype): ) -def _reference_sklearn_jaccard_index_multilabel(preds, target, ignore_index=None, average="macro"): +def _reference_sklearn_jaccard_index_multilabel(preds, target, ignore_index=None, average="macro", zero_division=0): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -243,16 +259,18 @@ def _reference_sklearn_jaccard_index_multilabel(preds, target, ignore_index=None preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) if ignore_index is None: - return sk_jaccard_index(y_true=target, y_pred=preds, average=average) + return sk_jaccard_index(y_true=target, y_pred=preds, average=average, zero_division=zero_division) if average == "micro": - return _reference_sklearn_jaccard_index_binary(torch.tensor(preds), torch.tensor(target), ignore_index) + return _reference_sklearn_jaccard_index_binary( + torch.tensor(preds), torch.tensor(target), ignore_index, zero_division=zero_division + ) scores, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i], target[:, i] true, pred = remove_ignore_index(true, pred, ignore_index) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - scores.append(sk_jaccard_index(true, pred)) + scores.append(sk_jaccard_index(true, pred, zero_division=zero_division)) weights.append(confmat[1, 0] + confmat[1, 1]) scores = np.stack(scores, axis=0) weights = np.stack(weights, axis=0) @@ -269,8 +287,9 @@ class TestMultilabelJaccardIndex(MetricTester): @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("zero_division", [0, 1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multilabel_jaccard_index(self, inputs, ddp, ignore_index, average): + def test_multilabel_jaccard_index(self, inputs, ddp, ignore_index, average, zero_division): """Test class implementation of metric.""" preds, target = inputs if ignore_index is not None: @@ -281,18 +300,23 @@ def test_multilabel_jaccard_index(self, inputs, ddp, ignore_index, average): target=target, metric_class=MultilabelJaccardIndex, reference_metric=partial( - _reference_sklearn_jaccard_index_multilabel, ignore_index=ignore_index, average=average + _reference_sklearn_jaccard_index_multilabel, + ignore_index=ignore_index, + average=average, + zero_division=zero_division, ), metric_args={ "num_labels": NUM_CLASSES, "ignore_index": ignore_index, "average": average, + "zero_division": zero_division, }, ) @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1]) - def test_multilabel_jaccard_index_functional(self, inputs, ignore_index, average): + @pytest.mark.parametrize("zero_division", [0, 1]) + def test_multilabel_jaccard_index_functional(self, inputs, ignore_index, average, zero_division): """Test functional implementation of metric.""" preds, target = inputs if ignore_index is not None: @@ -302,12 +326,16 @@ def test_multilabel_jaccard_index_functional(self, inputs, ignore_index, average target=target, metric_functional=multilabel_jaccard_index, reference_metric=partial( - _reference_sklearn_jaccard_index_multilabel, ignore_index=ignore_index, average=average + _reference_sklearn_jaccard_index_multilabel, + ignore_index=ignore_index, + average=average, + zero_division=zero_division, ), metric_args={ "num_labels": NUM_CLASSES, "ignore_index": ignore_index, "average": average, + "zero_division": zero_division, }, ) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 86fbe262aea..00eee202cc0 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -49,7 +49,7 @@ seed_all(42) -def _reference_sklearn_precision_recall_binary(preds, target, sk_fn, ignore_index, multidim_average): +def _reference_sklearn_precision_recall_binary(preds, target, sk_fn, ignore_index, multidim_average, zero_division=0): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -64,14 +64,14 @@ def _reference_sklearn_precision_recall_binary(preds, target, sk_fn, ignore_inde if multidim_average == "global": target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds) + return sk_fn(target, preds, zero_division=zero_division) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - res.append(sk_fn(true, pred)) + res.append(sk_fn(true, pred, zero_division=zero_division)) return np.stack(res) @@ -90,7 +90,10 @@ class TestBinaryPrecisionRecall(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_precision_recall(self, ddp, inputs, module, functional, compare, ignore_index, multidim_average): + @pytest.mark.parametrize("zero_division", [0, 1]) + def test_binary_precision_recall( + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, zero_division + ): """Test class implementation of metric.""" preds, target = inputs if ignore_index == -1: @@ -110,14 +113,21 @@ def test_binary_precision_recall(self, ddp, inputs, module, functional, compare, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, + zero_division=zero_division, ), - metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "zero_division": zero_division, + }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_binary_precision_recall_functional( - self, inputs, module, functional, compare, ignore_index, multidim_average + self, inputs, module, functional, compare, ignore_index, multidim_average, zero_division ): """Test functional implementation of metric.""" preds, target = inputs @@ -135,11 +145,13 @@ def test_binary_precision_recall_functional( sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, + zero_division=zero_division, ), metric_args={ "threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average, + "zero_division": zero_division, }, ) @@ -184,7 +196,9 @@ def test_binary_precision_recall_half_gpu(self, inputs, module, functional, comp ) -def _reference_sklearn_precision_recall_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): +def _reference_sklearn_precision_recall_multiclass( + preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0 +): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) @@ -192,7 +206,13 @@ def _reference_sklearn_precision_recall_multiclass(preds, target, sk_fn, ignore_ preds = preds.numpy().flatten() target = target.numpy().flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds, average=average, labels=list(range(NUM_CLASSES)) if average is None else None) + return sk_fn( + target, + preds, + average=average, + labels=list(range(NUM_CLASSES)) if average is None else None, + zero_division=zero_division, + ) preds = preds.numpy() target = target.numpy() @@ -201,7 +221,23 @@ def _reference_sklearn_precision_recall_multiclass(preds, target, sk_fn, ignore_ pred = pred.flatten() true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - r = sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)) if average is None else None) + if len(pred) == 0 and average == "weighted": + # The result of sk_fn([], [], labels=None, average="weighted", zero_division=zero_division) + # varies depending on the sklearn version: + # 1.2 -> the value of zero_division + # 1.3 -> nan + # 1.4 -> nan + # To avoid breaking some test cases by this behavior, + # hard coded to return 0 in this special case. + r = 0.0 + else: + r = sk_fn( + true, + pred, + average=average, + labels=list(range(NUM_CLASSES)) if average is None else None, + zero_division=zero_division, + ) res.append(0.0 if np.isnan(r).any() else r) return np.stack(res, 0) @@ -223,8 +259,18 @@ class TestMulticlassPrecisionRecall(MetricTester): @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_multiclass_precision_recall( - self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average + self, + ddp, + inputs, + module, + functional, + compare, + ignore_index, + multidim_average, + average, + zero_division, ): """Test class implementation of metric.""" preds, target = inputs @@ -246,20 +292,23 @@ def test_multiclass_precision_recall( ignore_index=ignore_index, multidim_average=multidim_average, average=average, + zero_division=zero_division, ), metric_args={ "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "zero_division": zero_division, }, ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_multiclass_precision_recall_functional( - self, inputs, module, functional, compare, ignore_index, multidim_average, average + self, inputs, module, functional, compare, ignore_index, multidim_average, average, zero_division ): """Test functional implementation of metric.""" preds, target = inputs @@ -278,12 +327,14 @@ def test_multiclass_precision_recall_functional( ignore_index=ignore_index, multidim_average=multidim_average, average=average, + zero_division=zero_division, ), metric_args={ "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "zero_division": zero_division, }, ) @@ -367,18 +418,18 @@ def test_top_k( assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) -def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average): +def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division): if average == "micro": preds = preds.flatten() target = target.flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds) + return sk_fn(target, preds, zero_division=zero_division) precision_recall, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - precision_recall.append(sk_fn(true, pred)) + precision_recall.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) res = np.stack(precision_recall, axis=0) @@ -395,13 +446,13 @@ def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, return None -def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average): +def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average, zero_division): precision_recall, weights = [], [] for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - precision_recall.append(sk_fn(true, pred)) + precision_recall.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) else: @@ -409,7 +460,7 @@ def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, i for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] true, pred = remove_ignore_index(true, pred, ignore_index) - scores.append(sk_fn(true, pred)) + scores.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) precision_recall.append(np.stack(scores)) @@ -429,7 +480,9 @@ def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, i return None -def _reference_sklearn_precision_recall_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): +def _reference_sklearn_precision_recall_multilabel( + preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0 +): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -443,10 +496,15 @@ def _reference_sklearn_precision_recall_multilabel(preds, target, sk_fn, ignore_ target.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), preds.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), average=average, + zero_division=zero_division, ) if multidim_average == "global": - return _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average) - return _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average) + return _reference_sklearn_precision_recall_multilabel_global( + preds, target, sk_fn, ignore_index, average, zero_division + ) + return _reference_sklearn_precision_recall_multilabel_local( + preds, target, sk_fn, ignore_index, average, zero_division + ) @pytest.mark.parametrize("inputs", _multilabel_cases) @@ -465,8 +523,9 @@ class TestMultilabelPrecisionRecall(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_multilabel_precision_recall( - self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average, zero_division ): """Test class implementation of metric.""" preds, target = inputs @@ -488,6 +547,7 @@ def test_multilabel_precision_recall( ignore_index=ignore_index, multidim_average=multidim_average, average=average, + zero_division=zero_division, ), metric_args={ "num_labels": NUM_CLASSES, @@ -495,14 +555,16 @@ def test_multilabel_precision_recall( "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "zero_division": zero_division, }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("zero_division", [0, 1]) def test_multilabel_precision_recall_functional( - self, inputs, module, functional, compare, ignore_index, multidim_average, average + self, inputs, module, functional, compare, ignore_index, multidim_average, average, zero_division ): """Test functional implementation of metric.""" preds, target = inputs @@ -521,6 +583,7 @@ def test_multilabel_precision_recall_functional( ignore_index=ignore_index, multidim_average=multidim_average, average=average, + zero_division=zero_division, ), metric_args={ "num_labels": NUM_CLASSES, @@ -528,6 +591,7 @@ def test_multilabel_precision_recall_functional( "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "zero_division": zero_division, }, ) diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index d629c86583a..df01b67e4ab 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -33,7 +33,7 @@ multilabel_sensitivity_at_specificity, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 +from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3, _TORCH_GREATER_EQUAL_1_11 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -83,6 +83,7 @@ def _reference_sklearn_sensitivity_at_specificity_binary(preds, target, min_spec return _sensitivity_at_specificity_x_multilabel(preds, target, min_specificity) +@pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3") @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize("inputs", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) class TestBinarySensitivityAtSpecificity(MetricTester): @@ -209,6 +210,7 @@ def _reference_sklearn_sensitivity_at_specificity_multiclass(preds, target, min_ return sensitivity, thresholds +@pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3") @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( "inputs", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) @@ -340,6 +342,7 @@ def _reference_sklearn_sensitivity_at_specificity_multilabel(preds, target, min_ return sensitivity, thresholds +@pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3") @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( "inputs", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py index 48e0c679195..1f68839eb4e 100644 --- a/tests/unittests/retrieval/test_ndcg.py +++ b/tests/unittests/retrieval/test_ndcg.py @@ -80,6 +80,7 @@ def test_class_metric( "ignore_index": ignore_index, "aggregation": aggregation, } + target = target if target.min() >= 0 else target - target.min() self.run_class_metric_test( ddp=ddp, @@ -107,6 +108,7 @@ def test_class_metric_ignore_index( """Test class implementation of metric with ignore_index argument.""" metric_args = {"empty_target_action": empty_target_action, "top_k": k, "ignore_index": -100} + target = target if target.min() >= 0 else target - target.min() self.run_class_metric_test( ddp=ddp, indexes=indexes, @@ -121,6 +123,7 @@ def test_class_metric_ignore_index( @pytest.mark.parametrize("k", [None, 1, 4, 10]) def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): """Test functional implementation of metric.""" + target = target if target.min() >= 0 else target - target.min() self.run_functional_metric_test( preds=preds, target=target, @@ -133,6 +136,7 @@ def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): @pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target) def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor): """Test dtype support of the metric on CPU.""" + target = target if target.min() >= 0 else target - target.min() self.run_precision_test_cpu( indexes=indexes, preds=preds, @@ -144,6 +148,7 @@ def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor): @pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target) def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): """Test dtype support of the metric on GPU.""" + target = target if target.min() >= 0 else target - target.min() self.run_precision_test_gpu( indexes=indexes, preds=preds,