Skip to content

Commit

Permalink
fix bug (ndcg and recall nan)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanfeng97 committed Oct 18, 2022
1 parent 9fe043d commit e124154
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion dhg/metrics/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def recall(
assert y_true.max() == 1, "The input y_true must be binary."
pred_seq = y_true.gather(1, torch.argsort(y_pred, dim=-1, descending=True))[:, :k]
num_true = y_true.sum(dim=1)
res_list = (pred_seq.sum(dim=1) / num_true).detach().cpu()
res_list = (pred_seq.sum(dim=1) / num_true).cpu()
res_list[torch.isinf(res_list)] = 0
res_list[torch.isnan(res_list)] = 0
if ret_batch:
return [res.item() for res in res_list]
else:
Expand Down Expand Up @@ -169,6 +171,7 @@ def ndcg(

res_list = (pred_dcg / ideal_dcg).detach().cpu()
res_list[torch.isinf(res_list)] = 0
res_list[torch.isnan(res_list)] = 0
if ret_batch:
return [res.item() for res in res_list]
else:
Expand Down
5 changes: 4 additions & 1 deletion dhg/metrics/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def recall(
assert y_true.max() == 1, "The input y_true must be binary."
pred_seq = y_true.gather(1, torch.argsort(y_pred, dim=-1, descending=True))[:, :k]
num_true = y_true.sum(dim=1)
res_list = (pred_seq.sum(dim=1) / num_true).detach().cpu()
res_list = (pred_seq.sum(dim=1) / num_true).cpu()
res_list[torch.isinf(res_list)] = 0
res_list[torch.isnan(res_list)] = 0
if ret_batch:
return res_list
else:
Expand Down Expand Up @@ -252,6 +254,7 @@ def ndcg(

res_list = pred_dcg / ideal_dcg
res_list[torch.isinf(res_list)] = 0
res_list[torch.isnan(res_list)] = 0
if ret_batch:
return res_list
else:
Expand Down

0 comments on commit e124154

Please sign in to comment.