Recsys hitrate (#956)
* add movielens

* First hitrate metric version

* hit-rate implementation

* typo in tests

* fixed codestyle

* fixed typos

* fix conflicts

* added docs

* add hitrate to init.py

* edit changelog

* fix conflict

Co-authored-by: Daniel Chepenko <dchepenk@yahoo-corp.co>
Co-authored-by: denyhoof <kde97@yandex.ru>
3 people committed Nov 7, 2020
1 parent 553f699 commit 590f680
Showing 5 changed files with 80 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- hitrate calculations ([#975](https://github.com/catalyst-team/catalyst/pull/975))
- extra functions for classification metrics ([#966](https://github.com/catalyst-team/catalyst/pull/966))
- `OneOf` and `OneOfV2` batch transforms ([#951](https://github.com/catalyst-team/catalyst/pull/951))
- ``precision_recall_fbeta_support`` metric ([#971](https://github.com/catalyst-team/catalyst/pull/971))
1 change: 1 addition & 0 deletions catalyst/metrics/__init__.py
@@ -4,6 +4,7 @@
from catalyst.metrics.cmc_score import cmc_score, cmc_score_count
from catalyst.metrics.dice import dice, calculate_dice
from catalyst.metrics.f1_score import f1_score, fbeta_score
from catalyst.metrics.hitrate import hitrate
from catalyst.metrics.classification import precision_recall_fbeta_support
from catalyst.metrics.precision import precision
from catalyst.metrics.recall import recall
50 changes: 50 additions & 0 deletions catalyst/metrics/hitrate.py
@@ -0,0 +1,50 @@
"""
Hitrate metric:
* :func:`hitrate`
"""
import torch


def hitrate(
outputs: torch.Tensor, targets: torch.Tensor, k=10
) -> torch.Tensor:
"""
Calculate the hit rate score given model outputs and targets.
Hit-rate is a metric for evaluating ranking systems.
Generate top-N recommendations and if one of the recommendation is
actually what user has rated, you consider that a hit.
By rate we mean any explicit form of user's interactions.
Add up all of the hits for all users and then divide by number of users
Compute top-N recomendation for each user in the training stage
and intentionally remove one of this items fro the training data.
Args:
outputs (torch.Tensor):
Tensor weith predicted score
size: [batch_size, slate_length]
model outputs, logits
targets (torch.Tensor):
Binary tensor with ground truth.
1 means the item is relevant
for the user and 0 not relevant
size: [batch_szie, slate_length]
ground truth, labels
k (int):
Parameter fro evaluation on top-k items
Returns:
hitrate (torch.Tensor): the hit rate score
"""
k = min(outputs.size(1), k)

_, indices_for_sort = outputs.sort(descending=True, dim=-1)
true_sorted_by_preds = torch.gather(
targets, dim=-1, index=indices_for_sort
)
true_sorted_by_pred_shrink = true_sorted_by_preds[:, :k]
hits = torch.sum(true_sorted_by_pred_shrink, dim=1) / k
return hits


__all__ = ["hitrate"]
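
A quick usage sketch (not part of the commit; a minimal example assuming the catalyst package from this change is importable). Note that, as implemented, the function returns one value per user: the share of relevant items among that user's top-k, which is why the test below expects 0.5 for a single hit at k=2.

import torch

from catalyst.metrics import hitrate

# two users, slates of four items; 1.0 marks an item the user actually rated
outputs = torch.tensor([[0.9, 0.1, 0.3, 0.2], [0.1, 0.8, 0.2, 0.6]])
targets = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])

# user 0 has their rated item in the top-2, user 1 does not
print(hitrate(outputs, targets, k=2))  # tensor([0.5000, 0.0000])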
21 changes: 21 additions & 0 deletions catalyst/metrics/tests/test_hitrate.py
@@ -0,0 +1,21 @@
import torch

from catalyst.utils import metrics


def test_hitrate():
    """
    Tests for catalyst.utils.metrics.hitrate metric.
    """
    # simple case: one relevant item ranked in the top-2 -> hitrate 0.5
    y_pred = [0.5, 0.2]
    y_true = [1.0, 0.0]

    hitrate = metrics.hitrate(torch.Tensor([y_pred]), torch.Tensor([y_true]))
    assert hitrate == 0.5

    # no relevant items in the slate -> hitrate 0.0
    y_pred = [0.5, 0.2]
    y_true = [0.0, 0.0]

    hitrate = metrics.hitrate(torch.Tensor([y_pred]), torch.Tensor([y_true]))
    assert hitrate == 0.0
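
To run the new test on its own, the standard pytest invocation should work from a repository checkout (command not part of the commit):

python -m pytest catalyst/metrics/tests/test_hitrate.py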
7 changes: 7 additions & 0 deletions docs/api/metrics.rst
@@ -76,6 +76,13 @@ Precision
    :undoc-members:
    :show-inheritance:

Hitrate
------------------------
.. automodule:: catalyst.metrics.hitrate
    :members:
    :undoc-members:
    :show-inheritance:

Functional
------------------------
.. automodule:: catalyst.metrics.functional
