Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add movielens * First hitrate metric version * hit-rate implemenation * typo in tests * fixed codestyle * fixed typos * fix conflicts * added docs * add hitrate to init.py * edit changelog * fix conflict Co-authored-by: Daniel Chepenko <dchepenk@yahoo-corp.co> Co-authored-by: denyhoof <kde97@yandex.ru>
- Loading branch information
1 parent
553f699
commit 590f680
Showing
5 changed files
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
""" | ||
Hitrate metric: | ||
* :func:`hitrate` | ||
""" | ||
import torch | ||
|
||
|
||
def hitrate( | ||
outputs: torch.Tensor, targets: torch.Tensor, k=10 | ||
) -> torch.Tensor: | ||
""" | ||
Calculate the hit rate score given model outputs and targets. | ||
Hit-rate is a metric for evaluating ranking systems. | ||
Generate top-N recommendations and if one of the recommendation is | ||
actually what user has rated, you consider that a hit. | ||
By rate we mean any explicit form of user's interactions. | ||
Add up all of the hits for all users and then divide by number of users | ||
Compute top-N recomendation for each user in the training stage | ||
and intentionally remove one of this items fro the training data. | ||
Args: | ||
outputs (torch.Tensor): | ||
Tensor weith predicted score | ||
size: [batch_size, slate_length] | ||
model outputs, logits | ||
targets (torch.Tensor): | ||
Binary tensor with ground truth. | ||
1 means the item is relevant | ||
for the user and 0 not relevant | ||
size: [batch_szie, slate_length] | ||
ground truth, labels | ||
k (int): | ||
Parameter fro evaluation on top-k items | ||
Returns: | ||
hitrate (torch.Tensor): the hit rate score | ||
""" | ||
k = min(outputs.size(1), k) | ||
|
||
_, indices_for_sort = outputs.sort(descending=True, dim=-1) | ||
true_sorted_by_preds = torch.gather( | ||
targets, dim=-1, index=indices_for_sort | ||
) | ||
true_sorted_by_pred_shrink = true_sorted_by_preds[:, :k] | ||
hits = torch.sum(true_sorted_by_pred_shrink, dim=1) / k | ||
return hits | ||
|
||
|
||
__all__ = ["hitrate"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
|
||
from catalyst.utils import metrics | ||
|
||
|
||
def test_hitrate(): | ||
""" | ||
Tests for catalyst.utils.metrics.hitrate metric. | ||
""" | ||
y_pred = [0.5, 0.2] | ||
y_true = [1.0, 0.0] | ||
|
||
hitrate = metrics.hitrate(torch.Tensor([y_pred]), torch.Tensor([y_true])) | ||
assert hitrate == 0.5 | ||
|
||
# check 1 simple case | ||
y_pred = [0.5, 0.2] | ||
y_true = [0.0, 0.0] | ||
|
||
hitrate = metrics.hitrate(torch.Tensor([y_pred]), torch.Tensor([y_true])) | ||
assert hitrate == 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters