Skip to content

Commit

Permalink
Merge pull request #312 from mdekstrand/fix/309-main
Browse files Browse the repository at this point in the history
Fix nDCG truncation bug (#309)
  • Loading branch information
mdekstrand committed Mar 11, 2022
2 parents 0716300 + 65081ed commit 371edfd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lenskit/metrics/topn.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,11 @@ def ndcg(recs, truth, discount=np.log2, k=None):
The maximum list length.
"""

tpos = truth.index.get_indexer(recs['item'])

if k is not None:
recs = recs.iloc[:k]

tpos = truth.index.get_indexer(recs['item'])

if 'rating' in truth.columns:
i_rates = np.sort(truth.rating.values)[::-1]
if k is not None:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_topn_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def test_ndcg_perfect():
assert ndcg(recs, truth) == approx(1.0)


def test_ndcg_perfect_k_short():
recs = pd.DataFrame({'item': [2, 3, 1]})
truth = pd.DataFrame({'item': [1, 2, 3], 'rating': [3.0, 5.0, 4.0]})
truth = truth.set_index('item')
assert ndcg(recs, truth, k=2) == approx(1.0)
assert ndcg(recs[:2], truth, k=2) == approx(1.0)


def test_ndcg_wrong():
recs = pd.DataFrame({'item': [1, 2]})
truth = pd.DataFrame({'item': [1, 2, 3], 'rating': [3.0, 5.0, 4.0]})
Expand Down

0 comments on commit 371edfd

Please sign in to comment.