Fix topk metrics (#1140)
* add movielens

* First hitrate metric version

* merge with upstream

* fix topk_map

* fix error in map

* tests for map

* top_k ndcg

* edit changelog

* edit the tabs

* check codestyle

* check the indent

* check the indent

* remove trailing whitespace

Co-authored-by: Daniel Chepenko <dchepenk@yahoo-corp.co>
Co-authored-by: denyhoof <kde97@yandex.ru>
Co-authored-by: Даниил <zkid18@MacBook-Pro.Dlink>
4 people committed Mar 28, 2021
1 parent 2e3ef50 commit a7720d8
Showing 4 changed files with 114 additions and 51 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Data-Model device sync and ``Engine`` logic during `runner.predict_loader` ([#1134](https://github.com/catalyst-team/catalyst/issues/1134))
- BatchLimitLoaderWrapper logic for loaders with shuffle flag ([#1136](https://github.com/catalyst-team/catalyst/issues/1136))
- RecSys metrics Top_k calculations ([#1140](https://github.com/catalyst-team/catalyst/pull/1140))

## [21.03] - 2021-03-13 ([#1095](https://github.com/catalyst-team/catalyst/issues/1095))

64 changes: 36 additions & 28 deletions catalyst/metrics/functional/_average_precision.py
@@ -72,7 +72,7 @@ def binary_average_precision(
return ap


def average_precision(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
def average_precision(outputs: torch.Tensor, targets: torch.Tensor, k: int) -> torch.Tensor:
"""
Calculate the Average Precision for RecSys.
The precision metric summarizes the fraction of relevant items
@@ -104,6 +108,8 @@ def average_precision(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
and 0 not relevant
size: [batch_size, slate_length]
ground truth, labels
k:
Parameter for evaluation on top-k items
Returns:
ap_score (torch.Tensor):
@@ -112,21 +114,25 @@
Examples:
>>> average_precision(
>>> outputs=torch.tensor([
>>> [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
>>> [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
>>> ]),
>>> targets=torch.tensor([
>>> [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0],
>>> [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
>>> ]),
>>> k=10,
>>> )
tensor([0.6222, 0.4429])
"""
targets_sort_by_outputs = process_recsys_components(outputs, targets)
targets_sort_by_outputs = process_recsys_components(outputs, targets)[:, :k]
precisions = torch.zeros_like(targets_sort_by_outputs)

for index in range(outputs.size(1)):
for index in range(k):
precisions[:, index] = torch.sum(targets_sort_by_outputs[:, : (index + 1)], dim=1) / float(
index + 1
)

@@ -148,14 +154,12 @@ def mean_average_precision(
relevant items for each query
Args:
outputs (torch.Tensor):
Tensor with predicted score
outputs (torch.Tensor): Tensor with predicted score
size: [batch_size, slate_length]
model outputs, logits
targets (torch.Tensor):
Binary tensor with ground truth.
1 means the item is relevant
and 0 not relevant
1 means the item is relevant and 0 not relevant
size: [batch_size, slate_length]
ground truth, labels
topk (List[int]):
@@ -164,29 +168,33 @@
Returns:
map_at_k (Tuple[float]):
The map score for every k.
size: len(top_k)
Examples:
>>> mean_average_precision(
>>> outputs=torch.tensor([
>>> [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
>>> [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
>>> ]),
>>> targets=torch.tensor([
>>> [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0],
>>> [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
>>> ]),
>>> topk=[10],
>>> )
[tensor(0.5325)]
"""
results = []
for k in topk:
k = min(outputs.size(1), k)
results.append(torch.mean(average_precision(outputs, targets)[:k]))
results.append(torch.mean(average_precision(outputs, targets, k)))

return results


__all__ = ["binary_average_precision", "mean_average_precision", "average_precision"]
__all__ = [
"binary_average_precision",
"mean_average_precision",
"average_precision",
]
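For reference, the bug being fixed here: the old `torch.mean(average_precision(outputs, targets)[:k])` sliced the first `k` rows of the *batch* rather than evaluating at cutoff `k`; passing `k` into `average_precision` truncates the ranked item list instead. Below is a minimal standalone sketch of the AP@k logic this diff converges on, under stated assumptions: `average_precision_at_k` is a hypothetical name, `process_recsys_components` is approximated by an argsort, and the final reduction is inferred from the docstring examples. It is not Catalyst's actual implementation.

```python
import torch


def average_precision_at_k(outputs: torch.Tensor, targets: torch.Tensor, k: int) -> torch.Tensor:
    # Reorder relevance labels by descending model score and keep the top-k
    # (stands in for Catalyst's process_recsys_components helper).
    order = torch.argsort(outputs, dim=1, descending=True)
    rel = torch.gather(targets, 1, order)[:, :k]
    # Precision@i for every prefix i = 1..k
    ranks = torch.arange(1, k + 1, dtype=rel.dtype)
    precisions = torch.cumsum(rel, dim=1) / ranks
    # Average the precisions at relevant positions only; the clamp avoids
    # division by zero when nothing in the top-k is relevant.
    return (precisions * rel).sum(dim=1) / rel.sum(dim=1).clamp(min=1)


# Reproduces the docstring example: tensor([0.6222, 0.4429])
outputs = torch.tensor([[9.0, 8, 7, 6, 5, 4, 3, 2, 1, 0],
                        [9.0, 8, 7, 6, 5, 4, 3, 2, 1, 0]])
targets = torch.tensor([[1.0, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                        [0.0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
print(average_precision_at_k(outputs, targets, k=10))
```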
75 changes: 55 additions & 20 deletions catalyst/metrics/functional/tests/test_average_precision.py
@@ -19,7 +19,7 @@ def test_binary_average_precision_base():
targets = torch.Tensor([0, 0, 1, 1])

assert torch.isclose(
binary_average_precision(outputs, targets), torch.tensor(0.8333), atol=1e-3
binary_average_precision(outputs, targets), torch.tensor(0.8333), atol=1e-3,
)


@@ -115,70 +115,76 @@ def test_binary_average_precision_weighted():
), "ap test12 failed"


def test_avg_precision():
def test_average_precision():
"""
Tests for catalyst.avg_precision metric.
Tests for catalyst.metrics.average_precision metric.
"""
# # check everything is relevant
y_pred = [0.5, 0.2, 0.3, 0.8]
y_true = [1.0, 1.0, 1.0, 1.0]
k = 4

ap_val = average_precision(torch.Tensor([y_pred]), torch.Tensor([y_true]))
assert ap_val[0] == 1
avg_precision = average_precision(torch.Tensor([y_pred]), torch.Tensor([y_true]), k)
assert avg_precision[0] == 1

# # check if everything is relevant for 3 users
y_pred = [0.5, 0.2, 0.3, 0.8]
y_true = [1.0, 1.0, 1.0, 1.0]
k = 4

ap_val = average_precision(
torch.Tensor([y_pred, y_pred, y_pred]), torch.Tensor([y_true, y_true, y_true]),
avg_precision = average_precision(
torch.Tensor([y_pred, y_pred, y_pred]), torch.Tensor([y_true, y_true, y_true]), k,
)
assert torch.equal(ap_val, torch.ones(3))
assert torch.equal(avg_precision, torch.ones(3))

# # check everything is irrelevant
y_pred = [0.5, 0.2, 0.3, 0.8]
y_true = [0.0, 0.0, 0.0, 0.0]
k = 4

ap_val = average_precision(torch.Tensor([y_pred]), torch.Tensor([y_true]))
assert ap_val[0] == 0
avg_precision = average_precision(torch.Tensor([y_pred]), torch.Tensor([y_true]), k)
assert avg_precision[0] == 0

# # check if everything is irrelevant for 3 users
y_pred = [0.5, 0.2, 0.3, 0.8]
y_true = [0.0, 0.0, 0.0, 0.0]
k = 4

ap_val = average_precision(
torch.Tensor([y_pred, y_pred, y_pred]), torch.Tensor([y_true, y_true, y_true]),
avg_precision = average_precision(
torch.Tensor([y_pred, y_pred, y_pred]), torch.Tensor([y_true, y_true, y_true]), k,
)
assert torch.equal(ap_val, torch.zeros(3))
assert torch.equal(avg_precision, torch.zeros(3))

# # check 4 test with k
# # check 4
y_pred1 = [4.0, 2.0, 3.0, 1.0]
y_pred2 = [1.0, 2.0, 3.0, 4.0]
y_true1 = [0.0, 1.0, 1.0, 1.0]
y_true2 = [0.0, 1.0, 0.0, 0.0]
k = 4

y_pred_torch = torch.Tensor([y_pred1, y_pred2])
y_true_torch = torch.Tensor([y_true1, y_true2])

ap_val = average_precision(y_pred_torch, y_true_torch)
avg_precision = average_precision(y_pred_torch, y_true_torch, k)

assert np.isclose(ap_val[0], 0.6389, atol=1e-3)
assert np.isclose(ap_val[1], 0.333, atol=1e-3)
assert np.isclose(avg_precision[0], 0.6389, atol=1e-3)
assert np.isclose(avg_precision[1], 0.333, atol=1e-3)

# check 5
# Stanford Introduction to Information Retrieval primer
y_pred1 = np.arange(9, -1, -1)
y_true1 = [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]
y_pred2 = np.arange(9, -1, -1)
y_true2 = [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]
k = 10

y_pred_torch = torch.Tensor([y_pred1, y_pred2])
y_true_torch = torch.Tensor([y_true1, y_true2])

ap_val = average_precision(y_pred_torch, y_true_torch)
avg_precision = average_precision(y_pred_torch, y_true_torch, k)

assert np.isclose(ap_val[0], 0.6222, atol=1e-3)
assert np.isclose(ap_val[1], 0.4429, atol=1e-3)
assert np.isclose(avg_precision[0], 0.6222, atol=1e-3)
assert np.isclose(avg_precision[1], 0.4429, atol=1e-3)


def test_mean_avg_precision():
@@ -199,3 +205,32 @@ def test_mean_avg_precision():
map_at10 = mean_average_precision(y_pred_torch, y_true_torch, top_k)[0]

assert np.allclose(map_at10, 0.5325, atol=1e-3)

# check 2
# map_at1: (1.0 + 0.0) / 2 = 0.5
# map_at3: ((1 + 0.67)/2 + 0.5) / 2 = 0.6675
# map_at5: ((1 + 0.67)/2 + (0.5 + 0.4)/2) / 2 = 0.6425
# map_at10: ((1 + 0.67 + 0.5 + 0.44 + 0.5)/5 + (0.5 + 0.4 + 0.43)/3 ) / 2 = 0.53

y_pred1 = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
y_pred2 = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

y_true1 = [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]
y_true2 = [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]

y_pred_torch = torch.Tensor([y_pred1, y_pred2])
y_true_torch = torch.Tensor([y_true1, y_true2])

top_k = [1, 3, 5, 10]

map_k = mean_average_precision(y_pred_torch, y_true_torch, top_k)

map_at1 = map_k[0]
map_at3 = map_k[1]
map_at5 = map_k[2]
map_at10 = map_k[3]

assert np.allclose(map_at1, 0.5, atol=1e-3)
assert np.allclose(map_at3, 0.6675, atol=1e-3)
assert np.allclose(map_at5, 0.6425, atol=1e-3)
assert np.allclose(map_at10, 0.5325, atol=1e-3)
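Under the same assumptions (the hypothetical `average_precision_at_k` sketched earlier), the MAP@k these tests exercise is just the batch mean of AP@k per requested cutoff, with each `k` clamped to the slate length as in `mean_average_precision`:

```python
from typing import List

import torch


def mean_average_precision_at_k(
    outputs: torch.Tensor, targets: torch.Tensor, topk: List[int]
) -> List[torch.Tensor]:
    # One batch-mean AP@k per requested cutoff, clamping k to the slate length
    return [
        average_precision_at_k(outputs, targets, min(outputs.size(1), k)).mean()
        for k in topk
    ]


# For "check 2" above this yields approximately
# [0.5000, 0.6667, 0.6417, 0.5325], within the test's 1e-3 tolerance.
```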
25 changes: 22 additions & 3 deletions catalyst/metrics/functional/tests/test_ndcg.py
@@ -16,23 +16,23 @@ def test_dcg():
y_pred = np.arange(3, -1, -1)

dcg_at4 = torch.sum(
dcg(torch.tensor([y_pred]), torch.tensor([y_true]), gain_function="linear_rank")
dcg(torch.tensor([y_pred]), torch.tensor([y_true]), gain_function="linear_rank",)
)
assert torch.isclose(dcg_at4, torch.tensor(4.261), atol=0.05)

y_true = [2.0, 2.0, 1.0, 0.0]
y_pred = np.arange(3, -1, -1)

dcg_at4 = torch.sum(
dcg(torch.tensor([y_pred]), torch.tensor([y_true]), gain_function="linear_rank")
dcg(torch.tensor([y_pred]), torch.tensor([y_true]), gain_function="linear_rank",)
)
assert torch.isclose(dcg_at4, torch.tensor(4.631), atol=0.05)

y_true = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]
y_pred = np.arange(9, -1, -1)

dcg_at10 = torch.sum(
dcg(torch.tensor([y_pred]), torch.tensor([y_true]), gain_function="linear_rank")
dcg(torch.tensor([y_pred]), torch.tensor([y_true]), gain_function="linear_rank",)
)

assert torch.isclose(dcg_at10, torch.tensor(9.61), atol=0.05)
@@ -73,6 +73,7 @@ def test_sample_ndcg():
y_pred2 = [0.5, 0.2, 0.1]
y_true1 = [1.0, 0.0, 1.0]
y_true2 = [1.0, 0.0, 1.0]
top_k = [2]

outputs = torch.Tensor([y_pred1, y_pred2])
targets = torch.Tensor([y_true1, y_true2])
@@ -81,3 +82,21 @@
comp_ndcg_at2 = ndcg(outputs, targets, topk=[2])[0]

assert np.isclose(true_ndcg_at2, comp_ndcg_at2)

y_pred1 = [0.5, 0.2, 0.1]
y_pred2 = [0.5, 0.2, 0.1]
y_true1 = [1.0, 0.0, 1.0]
y_true2 = [1.0, 0.0, 1.0]
top_k = [1, 2]

outputs = torch.Tensor([y_pred1, y_pred2])
targets = torch.Tensor([y_true1, y_true2])

true_ndcg_at2 = 1.0 / (1.0 + 1 / math.log2(3))
comp_ndcg = ndcg(outputs, targets, topk=top_k)

comp_ndcg_at1 = comp_ndcg[0]
comp_ndcg_at2 = comp_ndcg[1]

assert np.isclose(1, comp_ndcg_at1)
assert np.isclose(true_ndcg_at2, comp_ndcg_at2)
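For the record, the expected value in this test follows from linear-gain DCG: the predicted order places relevances `[1, 0]` in the top two, so DCG@2 = 1/log2(2) = 1, while the ideal order scores 1 + 1/log2(3); hence NDCG@2 = 1 / (1 + 1/log2(3)) ≈ 0.6131. A hedged standalone sketch under the same assumptions as before (not Catalyst's `dcg`/`ndcg` internals; the linear gain matches this test, not necessarily the library default):

```python
import torch


def ndcg_at_k(outputs: torch.Tensor, targets: torch.Tensor, k: int) -> torch.Tensor:
    # Reorder relevance by descending score; linear gain rel_i / log2(i + 1)
    order = torch.argsort(outputs, dim=1, descending=True)
    rel = torch.gather(targets, 1, order)[:, :k]
    discounts = 1.0 / torch.log2(torch.arange(2, k + 2, dtype=rel.dtype))
    dcg = (rel * discounts).sum(dim=1)
    # Ideal DCG: the same discounts applied to relevance sorted descending
    ideal = torch.sort(targets, dim=1, descending=True).values[:, :k]
    idcg = (ideal * discounts).sum(dim=1)
    return dcg / idcg.clamp(min=1e-8)


outputs = torch.tensor([[0.5, 0.2, 0.1]])
targets = torch.tensor([[1.0, 0.0, 1.0]])
print(ndcg_at_k(outputs, targets, k=2))  # approx. tensor([0.6131])
```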
