Fix topk metrics #1330

Merged
merged 13 commits on Oct 20, 2021
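
The changes below relax the topk_args annotation from List[int] to Iterable[int] across the top-k metrics and callbacks, and move the shared top-k bookkeeping into a TopKMetric base class. As a quick illustration (not part of the diff; the import path is assumed from catalyst's public metrics package), the annotation now admits a tuple:

import torch
from catalyst.metrics import AccuracyMetric  # import path assumed; the class is defined in catalyst/metrics/_accuracy.py

metric = AccuracyMetric(topk_args=(1, 3), num_classes=4)  # a tuple, not only a list
logits = torch.rand(8, 4)
targets = torch.randint(0, 4, (8,))
metric.update(logits, targets)
metric.compute_key_value()
# keys like 'accuracy01', 'accuracy01/std', 'accuracy03', 'accuracy03/std'
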
4 changes: 2 additions & 2 deletions catalyst/callbacks/metrics/accuracy.py
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import Iterable, Union

import torch

@@ -83,7 +83,7 @@ def __init__(
self,
input_key: str,
target_key: str,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
num_classes: int = None,
log_on_batch: bool = True,
prefix: str = None,
6 changes: 3 additions & 3 deletions catalyst/callbacks/metrics/cmc_score.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Iterable

from catalyst.callbacks.metric import LoaderMetricCallback
from catalyst.metrics._cmc_score import CMCMetric, ReidCMCMetric
@@ -134,7 +134,7 @@ def __init__(
embeddings_key: str,
labels_key: str,
is_query_key: str,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
prefix: str = None,
suffix: str = None,
):
@@ -177,7 +177,7 @@ def __init__(
pids_key: str,
cids_key: str,
is_query_key: str,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
prefix: str = None,
suffix: str = None,
):
10 changes: 5 additions & 5 deletions catalyst/callbacks/metrics/recsys.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Iterable

from catalyst.callbacks.metric import BatchMetricCallback
from catalyst.metrics._hitrate import HitrateMetric
@@ -93,7 +93,7 @@ def __init__(
self,
input_key: str,
target_key: str,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
log_on_batch: bool = True,
prefix: str = None,
suffix: str = None,
@@ -194,7 +194,7 @@ def __init__(
self,
input_key: str,
target_key: str,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
log_on_batch: bool = True,
prefix: str = None,
suffix: str = None,
@@ -295,7 +295,7 @@ def __init__(
self,
input_key: str,
target_key: str,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
log_on_batch: bool = True,
prefix: str = None,
suffix: str = None,
@@ -396,7 +396,7 @@ def __init__(
self,
input_key: str,
target_key: str,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
log_on_batch: bool = True,
prefix: str = None,
suffix: str = None,
97 changes: 13 additions & 84 deletions catalyst/metrics/_accuracy.py
@@ -1,16 +1,17 @@
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Iterable, Optional, Union

import numpy as np

import torch

from catalyst.metrics._additive import AdditiveMetric
from catalyst.metrics._metric import ICallbackBatchMetric
from catalyst.metrics._topk_metric import TopKMetric
from catalyst.metrics.functional._accuracy import accuracy, multilabel_accuracy
from catalyst.metrics.functional._misc import get_default_topk_args


class AccuracyMetric(ICallbackBatchMetric):
class AccuracyMetric(TopKMetric):
"""
This metric computes accuracy for the multiclass classification case.
It computes the mean value of accuracy and its approximate std value
@@ -49,8 +50,6 @@ class AccuracyMetric(ICallbackBatchMetric):

metric.compute_key_value()
# {
# 'accuracy': 0.5,
# 'accuracy/std': 0.0,
# 'accuracy01': 0.5,
# 'accuracy01/std': 0.0,
# 'accuracy03': 1.0,
@@ -121,92 +120,22 @@ class AccuracyMetric(ICallbackBatchMetric):

def __init__(
self,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
num_classes: int = None,
compute_on_call: bool = True,
prefix: str = None,
suffix: str = None,
):
"""Init AccuracyMetric"""
super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
self.metric_name_mean = f"{self.prefix}accuracy{self.suffix}"
self.metric_name_std = f"{self.prefix}accuracy{self.suffix}/std"
self.topk_args: List[int] = topk_args or get_default_topk_args(num_classes)
self.additive_metrics: List[AdditiveMetric] = [
AdditiveMetric() for _ in range(len(self.topk_args))
]

def reset(self) -> None:
"""Reset all fields"""
for metric in self.additive_metrics:
metric.reset()

def update(self, logits: torch.Tensor, targets: torch.Tensor) -> List[float]:
"""
Updates metric value with accuracy for new data and return intermediate metrics values.

Args:
logits: tensor of logits
targets: tensor of targets

Returns:
list of accuracy@k values
"""
values = accuracy(logits, targets, topk=self.topk_args)
values = [v.item() for v in values]
for value, metric in zip(values, self.additive_metrics):
metric.update(value, len(targets))
return values

def update_key_value(self, logits: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
"""
Update metric value with accuracy for new data and return intermediate metrics
values in key-value format.

Args:
logits: tensor of logits
targets: tensor of targets

Returns:
dict of accuracy@k values
"""
values = self.update(logits=logits, targets=targets)
output = {
f"{self.prefix}accuracy{key:02d}{self.suffix}": value
for key, value in zip(self.topk_args, values)
}
output[self.metric_name_mean] = output[f"{self.prefix}accuracy01{self.suffix}"]
return output

def compute(self) -> Tuple[List[float], List[float]]:
"""
Compute accuracy for all data

Returns:
list of mean values, list of std values
"""
means, stds = zip(*(metric.compute() for metric in self.additive_metrics))
return means, stds

def compute_key_value(self) -> Dict[str, float]:
"""
Compute accuracy for all data and return results in key-value format

Returns:
dict of metrics
"""
means, stds = self.compute()
output_mean = {
f"{self.prefix}accuracy{key:02d}{self.suffix}": value
for key, value in zip(self.topk_args, means)
}
output_std = {
f"{self.prefix}accuracy{key:02d}{self.suffix}/std": value
for key, value in zip(self.topk_args, stds)
}
output_mean[self.metric_name_mean] = output_mean[f"{self.prefix}accuracy01{self.suffix}"]
output_std[self.metric_name_std] = output_std[f"{self.prefix}accuracy01{self.suffix}/std"]
return {**output_mean, **output_std}
self.topk_args = topk_args or get_default_topk_args(num_classes)
super().__init__(
metric_name="accuracy",
metric_function=accuracy,
topk_args=self.topk_args,
compute_on_call=compute_on_call,
prefix=prefix,
suffix=suffix,
)
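
Note (not part of the diff): the refactored __init__ above delegates to a TopKMetric base class imported from catalyst/metrics/_topk_metric.py, a file added elsewhere in this PR and not shown in this section. Below is a rough sketch of what such a base class plausibly looks like, reconstructed from the per-metric code removed in this file; every name beyond the arguments passed in the __init__ call above is an assumption.

# Sketch only: reconstructed from the deleted per-metric logic, not the actual file.
from typing import Callable, Dict, Iterable, List, Tuple

import torch

from catalyst.metrics._additive import AdditiveMetric
from catalyst.metrics._metric import ICallbackBatchMetric


class TopKMetric(ICallbackBatchMetric):
    """Shared top-k bookkeeping for metrics such as accuracy and hitrate."""

    def __init__(
        self,
        metric_name: str,
        metric_function: Callable,
        topk_args: Iterable[int] = None,
        compute_on_call: bool = True,
        prefix: str = None,
        suffix: str = None,
    ):
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.metric_name = metric_name
        self.metric_function = metric_function
        self.topk_args = tuple(topk_args or (1,))
        # one running mean/std accumulator per k
        self.additive_metrics: List[AdditiveMetric] = [AdditiveMetric() for _ in self.topk_args]

    def reset(self) -> None:
        for metric in self.additive_metrics:
            metric.reset()

    def update(self, logits: torch.Tensor, targets: torch.Tensor) -> List[float]:
        values = self.metric_function(logits, targets, topk=self.topk_args)
        values = [v.item() for v in values]
        for value, metric in zip(values, self.additive_metrics):
            metric.update(value, len(targets))
        return values

    def update_key_value(self, logits: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        values = self.update(logits=logits, targets=targets)
        return {
            f"{self.prefix}{self.metric_name}{k:02d}{self.suffix}": value
            for k, value in zip(self.topk_args, values)
        }

    def compute(self) -> Tuple[List[float], List[float]]:
        means, stds = zip(*(metric.compute() for metric in self.additive_metrics))
        return means, stds

    def compute_key_value(self) -> Dict[str, float]:
        # per the updated docstrings, only per-k keys are reported (no aggregate 'name' / 'name/std')
        means, stds = self.compute()
        kv: Dict[str, float] = {}
        for k, mean, std in zip(self.topk_args, means, stds):
            kv[f"{self.prefix}{self.metric_name}{k:02d}{self.suffix}"] = mean
            kv[f"{self.prefix}{self.metric_name}{k:02d}{self.suffix}/std"] = std
        return kv
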


class MultilabelAccuracyMetric(AdditiveMetric, ICallbackBatchMetric):
8 changes: 4 additions & 4 deletions catalyst/metrics/_cmc_score.py
@@ -173,7 +173,6 @@ def __init__(
self.labels_key = labels_key
self.is_query_key = is_query_key
self.topk_args = topk_args or (1,)
self.metric_name = f"{self.prefix}cmc{self.suffix}"

def reset(self, num_batches: int, num_samples: int) -> None:
"""
@@ -227,7 +226,8 @@ def compute_key_value(self) -> Dict[str, float]:
"""
values = self.compute()
kv_metrics = {
f"{self.metric_name}{k:02d}": value for k, value in zip(self.topk_args, values)
f"{self.prefix}cmc{k:02d}{self.suffix}": value
for k, value in zip(self.topk_args, values)
}
return kv_metrics
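
Note (not part of the diff): before this change the key was built from metric_name = f"{prefix}cmc{suffix}" followed by f"{metric_name}{k:02d}", so a non-empty suffix ended up between the metric name and the top-k index; the new formatting keeps the index next to the name, matching the other top-k metrics. A minimal illustration with a hypothetical suffix value:

prefix, suffix, k = "", "_query", 5       # hypothetical values, for illustration only
old_key = f"{prefix}cmc{suffix}{k:02d}"   # 'cmc_query05': the suffix splits the name and the index
new_key = f"{prefix}cmc{k:02d}{suffix}"   # 'cmc05_query': the index stays next to the metric name
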

@@ -310,7 +310,6 @@ def __init__(
self.cids_key = cids_key
self.is_query_key = is_query_key
self.topk_args = topk_args or (1,)
self.metric_name = f"{self.prefix}cmc{self.suffix}"

def reset(self, num_batches: int, num_samples: int) -> None:
"""
@@ -384,7 +383,8 @@ def compute_key_value(self) -> Dict[str, float]:
"""
values = self.compute()
kv_metrics = {
f"{self.metric_name}{k:02d}": value for k, value in zip(self.topk_args, values)
f"{self.prefix}cmc{k:02d}{self.suffix}": value
for k, value in zip(self.topk_args, values)
}
return kv_metrics

100 changes: 12 additions & 88 deletions catalyst/metrics/_hitrate.py
@@ -1,13 +1,10 @@
from typing import Any, Dict, List
from typing import Iterable

import torch

from catalyst.metrics._additive import AdditiveMetric
from catalyst.metrics._metric import ICallbackBatchMetric
from catalyst.metrics._topk_metric import TopKMetric
from catalyst.metrics.functional._hitrate import hitrate


class HitrateMetric(ICallbackBatchMetric):
class HitrateMetric(TopKMetric):
"""Calculates the hitrate.

Args:
@@ -39,8 +36,6 @@ class HitrateMetric(ICallbackBatchMetric):

metric.compute_key_value()
# {
# 'hitrate': 0.0,
# 'hitrate/std': 0.0,
# 'hitrate01': 0.0,
# 'hitrate01/std': 0.0,
# 'hitrate02': 0.25,
@@ -125,91 +120,20 @@ class HitrateMetric(ICallbackBatchMetric):

def __init__(
self,
topk_args: List[int] = None,
topk_args: Iterable[int] = None,
compute_on_call: bool = True,
prefix: str = None,
suffix: str = None,
):
"""Init HitrateMetric"""
super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
self.metric_name_mean = f"{self.prefix}hitrate{self.suffix}"
self.metric_name_std = f"{self.prefix}hitrate{self.suffix}/std"
self.topk_args: List[int] = topk_args or [1]
self.additive_metrics: List[AdditiveMetric] = [
AdditiveMetric() for _ in range(len(self.topk_args))
]

def reset(self) -> None:
"""Reset all fields"""
for metric in self.additive_metrics:
metric.reset()

def update(self, logits: torch.Tensor, targets: torch.Tensor) -> List[float]:
"""
Update metric value with hitrate for new data and return intermediate metrics values.

Args:
logits (torch.Tensor): tensor of logits
targets (torch.Tensor): tensor of targets

Returns:
list of hitrate@k values
"""
values = hitrate(logits, targets, topk=self.topk_args)
values = [v.item() for v in values]
for value, metric in zip(values, self.additive_metrics):
metric.update(value, len(targets))
return values

def update_key_value(self, logits: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
"""
Update metric value with hitrate for new data and return intermediate metrics
values in key-value format.

Args:
logits (torch.Tensor): tensor of logits
targets (torch.Tensor): tensor of targets

Returns:
dict of hitrate@k values
"""
values = self.update(logits=logits, targets=targets)
output = {
f"{self.prefix}hitrate{key:02d}{self.suffix}": value
for key, value in zip(self.topk_args, values)
}
output[self.metric_name_mean] = output[f"{self.prefix}hitrate01{self.suffix}"]
return output

def compute(self) -> Any:
"""
Compute hitrate for all data

Returns:
list of mean values, list of std values
"""
means, stds = zip(*(metric.compute() for metric in self.additive_metrics))
return means, stds

def compute_key_value(self) -> Dict[str, float]:
"""
Compute hitrate for all data and return results in key-value format

Returns:
dict of metrics
"""
means, stds = self.compute()
output_mean = {
f"{self.prefix}hitrate{key:02d}{self.suffix}": value
for key, value in zip(self.topk_args, means)
}
output_std = {
f"{self.prefix}hitrate{key:02d}{self.suffix}/std": value
for key, value in zip(self.topk_args, stds)
}
output_mean[self.metric_name_mean] = output_mean[f"{self.prefix}hitrate01{self.suffix}"]
output_std[self.metric_name_std] = output_std[f"{self.prefix}hitrate01{self.suffix}/std"]
return {**output_mean, **output_std}
super().__init__(
metric_name="hitrate",
metric_function=hitrate,
topk_args=topk_args,
compute_on_call=compute_on_call,
prefix=prefix,
suffix=suffix,
)


__all__ = ["HitrateMetric"]
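
A short usage sketch (not part of the diff) showing that the refactored class keeps the per-k key format from the docstring above; the import path is assumed from catalyst's public metrics package:

import torch
from catalyst.metrics import HitrateMetric  # import path assumed

metric = HitrateMetric(topk_args=(1, 2))        # any Iterable[int] is accepted now
scores = torch.tensor([[4.0, 2.0, 3.0, 1.0]])   # predicted item scores for one user
targets = torch.tensor([[0.0, 0.0, 1.0, 0.0]])  # ground-truth relevance for the same items
metric.update(scores, targets)
metric.compute_key_value()
# keys like 'hitrate01', 'hitrate01/std', 'hitrate02', 'hitrate02/std'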