From 77f03ff0a768bc9e12edc7178fb396349f2b4256 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Sat, 26 May 2018 12:02:08 -0600 Subject: [PATCH] Simplify list-based metric helper This simplifies the list-based metric API to use `List` instead of `LongList`. In this case the performance shift is likely not worth it. It also adds tests for the NDPM metric. --- .../recommend/ListOnlyTopNMetric.java | 13 ++-- .../recommend/RecommendEvalTask.java | 2 +- .../recommend/TopNEntropyMetric.java | 11 +-- .../traintest/recommend/TopNLengthMetric.java | 4 +- .../traintest/recommend/TopNMAPMetric.java | 10 +-- .../traintest/recommend/TopNMRRMetric.java | 10 +-- .../traintest/recommend/TopNNDCGMetric.java | 26 ++++--- .../traintest/recommend/TopNNDPMMetric.java | 77 ++++++++++--------- .../recommend/TopNPopularityMetric.java | 10 +-- .../recommend/TopNPrecisionRecallMetric.java | 10 +-- 10 files changed, 84 insertions(+), 89 deletions(-) diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/ListOnlyTopNMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/ListOnlyTopNMetric.java index 720614dc54..d599288cd0 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/ListOnlyTopNMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/ListOnlyTopNMetric.java @@ -24,7 +24,6 @@ */ package org.lenskit.eval.traintest.recommend; -import it.unimi.dsi.fastutil.longs.LongList; import org.lenskit.api.Recommender; import org.lenskit.api.ResultList; import org.lenskit.eval.traintest.TestUser; @@ -37,7 +36,7 @@ /** * Intermediate class for top-N metrics that only depend on the list of recommended items, not their details. - * Metrics extending this class will implement the {@link #measureUser(Recommender, TestUser, int, LongList, Object)} method + * Metrics extending this class will implement the {@link #measureUserRecList(Recommender, TestUser, int, List, Object)} method * instead of {@link #measureUser(Recommender, TestUser, int, ResultList, Object)}. The recommend eval task uses this * subclass to improve efficiency when results are not used in the evaluation. * @@ -59,13 +58,13 @@ protected ListOnlyTopNMetric(Class resType, Class recommendations, X context); } diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/RecommendEvalTask.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/RecommendEvalTask.java index 01dcffbf9f..548406bd94 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/RecommendEvalTask.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/RecommendEvalTask.java @@ -441,7 +441,7 @@ public MetricResult measureUser(Recommender rec, TestUser user, int n, ResultLis @Nonnull public MetricResult measureUser(Recommender rec, TestUser user, int n, LongList recommendations) { - return ((ListOnlyTopNMetric) metric).measureUser(rec, user, n, recommendations, context); + return ((ListOnlyTopNMetric) metric).measureUserRecList(rec, user, n, recommendations, context); } @Nonnull diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNEntropyMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNEntropyMetric.java index 4c4b192605..db22b2a82d 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNEntropyMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNEntropyMetric.java @@ -26,8 +26,6 @@ import it.unimi.dsi.fastutil.longs.Long2IntMap; import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; -import it.unimi.dsi.fastutil.longs.LongIterator; -import it.unimi.dsi.fastutil.longs.LongList; import org.lenskit.api.Recommender; import org.lenskit.api.RecommenderEngine; import org.lenskit.eval.traintest.AlgorithmInstance; @@ -39,6 +37,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.util.List; /** * Metric that measures the entropy of the top N recommendations across all users. @@ -66,7 +65,7 @@ public TopNEntropyMetric() { @Nonnull @Override - public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Context context) { + public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List recommendations, Context context) { context.addUser(recommendations); return MetricResult.empty(); } @@ -95,10 +94,8 @@ public static class Context { private Long2IntMap counts = new Long2IntOpenHashMap(); private int recCount = 0; - private synchronized void addUser(LongList recs) { - LongIterator iter = recs.iterator(); - while (iter.hasNext()) { - long item = iter.nextLong(); + private synchronized void addUser(List recs) { + for (long item: recs) { counts.put(item, counts.get(item) +1); recCount +=1; } diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNLengthMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNLengthMetric.java index e85d9b6d8f..a8fa9cbe74 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNLengthMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNLengthMetric.java @@ -24,7 +24,6 @@ */ package org.lenskit.eval.traintest.recommend; -import it.unimi.dsi.fastutil.longs.LongList; import org.apache.commons.math3.stat.descriptive.moment.Mean; import org.lenskit.api.Recommender; import org.lenskit.api.RecommenderEngine; @@ -37,6 +36,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.util.List; /** * Metric that measures how long a TopN list actually is. @@ -53,7 +53,7 @@ public TopNLengthMetric() { @Nonnull @Override - public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Mean context) { + public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List recommendations, Mean context) { int n = recommendations.size(); synchronized (context) { context.increment(n); diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNMAPMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNMAPMetric.java index fcb3a59d07..4162e91cf4 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNMAPMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNMAPMetric.java @@ -25,8 +25,6 @@ package org.lenskit.eval.traintest.recommend; import com.fasterxml.jackson.annotation.JsonCreator; -import it.unimi.dsi.fastutil.longs.LongIterator; -import it.unimi.dsi.fastutil.longs.LongList; import it.unimi.dsi.fastutil.longs.LongSet; import org.apache.commons.lang3.StringUtils; import org.apache.commons.math3.stat.descriptive.moment.Mean; @@ -43,6 +41,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.util.List; /** * Compute the mean average precision. @@ -107,7 +106,7 @@ public MetricResult getAggregateMeasurements(Context context) { @Nonnull @Override - public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recs, Context context) { + public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List recs, Context context) { LongSet good = goodItems.selectItems(context.universe, rec, user); if (good.isEmpty()) { logger.warn("no good items for user {}", user.getUserId()); @@ -121,10 +120,9 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength int n = 0; double ngood = 0; double sum = 0; - LongIterator iter = recs.iterator(); - while (iter.hasNext()) { + for (long id: recs) { n += 1; - if(good.contains(iter.nextLong())) { + if (good.contains(id)) { // it is good ngood += 1; // add to MAP sum diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNMRRMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNMRRMetric.java index 186e500e64..2fa07300b1 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNMRRMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNMRRMetric.java @@ -25,8 +25,6 @@ package org.lenskit.eval.traintest.recommend; import com.fasterxml.jackson.annotation.JsonCreator; -import it.unimi.dsi.fastutil.longs.LongIterator; -import it.unimi.dsi.fastutil.longs.LongList; import it.unimi.dsi.fastutil.longs.LongSet; import org.apache.commons.lang3.StringUtils; import org.apache.commons.math3.stat.descriptive.moment.Mean; @@ -43,6 +41,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.util.List; /** * Compute the mean reciprocal rank. @@ -104,7 +103,7 @@ public MetricResult getAggregateMeasurements(Context context) { @Nonnull @Override - public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Context context) { + public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List recommendations, Context context) { LongSet good = goodItems.selectItems(context.universe, rec, user); if (good.isEmpty()) { logger.warn("no good items for user {}", user.getUserId()); @@ -112,10 +111,9 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength Integer rank = null; int i = 0; - LongIterator iter = recommendations.iterator(); - while (iter.hasNext()) { + for (long item: recommendations) { i++; - if(good.contains(iter.nextLong())) { + if (good.contains(item)) { rank = i; break; } diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNNDCGMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNNDCGMetric.java index 9d2baf1429..dc8baee6a8 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNNDCGMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNNDCGMetric.java @@ -26,7 +26,9 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import it.unimi.dsi.fastutil.longs.*; +import it.unimi.dsi.fastutil.longs.Long2DoubleFunction; +import it.unimi.dsi.fastutil.longs.Long2DoubleMap; +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; import org.apache.commons.lang3.StringUtils; import org.apache.commons.math3.stat.descriptive.moment.Mean; import org.lenskit.api.Recommender; @@ -45,8 +47,9 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; -import java.util.Arrays; import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; /** * Measure the nDCG of the top-N recommendations, using ratings as scores. @@ -120,7 +123,7 @@ public MetricResult getAggregateMeasurements(Mean context) { @Nonnull @Override - public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Mean context) { + public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List recommendations, Mean context) { if (recommendations == null) { return MetricResult.empty(); } @@ -135,15 +138,16 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength throw new IllegalArgumentException("value " + av + " for attribute " + gainAttribute + " is not numeric"); } } - long[] ideal = ratings.keySet().toLongArray(); - LongArrays.quickSort(ideal, LongComparators.oppositeComparator(LongUtils.keyValueComparator(ratings))); - if (targetLength >= 0 && ideal.length > targetLength) { - ideal = Arrays.copyOf(ideal, targetLength); - } + + List ideal = + ratings.keySet() + .stream() + .sorted(LongUtils.keyValueComparator(ratings).reversed()) + .limit(targetLength >= 0 ? targetLength : ratings.size()) + .collect(Collectors.toList()); double idealGain = computeDCG(ideal, ratings); - long[] actual = recommendations.toLongArray(); - double gain = computeDCG(actual, ratings); + double gain = computeDCG(recommendations, ratings); double score = gain / idealGain; @@ -156,7 +160,7 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength /** * Compute the DCG of a list of items with respect to a value vector. */ - double computeDCG(long[] items, Long2DoubleFunction values) { + double computeDCG(List items, Long2DoubleFunction values) { double gain = 0; int rank = 0; diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNNDPMMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNNDPMMetric.java index 4570ad8a66..13b61162be 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNNDPMMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNNDPMMetric.java @@ -26,10 +26,10 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import it.unimi.dsi.fastutil.longs.Long2DoubleFunction; import it.unimi.dsi.fastutil.longs.Long2DoubleMap; -import it.unimi.dsi.fastutil.longs.LongList; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.stat.descriptive.moment.Mean; import org.lenskit.api.Recommender; import org.lenskit.api.RecommenderEngine; @@ -38,10 +38,13 @@ import org.lenskit.eval.traintest.TestUser; import org.lenskit.eval.traintest.metrics.MetricResult; import org.lenskit.util.math.Scalars; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.Collections; +import java.util.List; /** * Measure the nDPM of the top-N recommendations, using rankings. @@ -49,6 +52,7 @@ * The paper used as a reference for this implementation is http://www2.cs.uregina.ca/~yyao/PAPERS/jasis_ndpm.pdf. */ public class TopNNDPMMetric extends ListOnlyTopNMetric { + private static final Logger logger = LoggerFactory.getLogger(TopNNDPMMetric.class); public static final String DEFAULT_COLUMN = "TopN.nDPM"; /** @@ -76,23 +80,32 @@ public Mean createContext(AlgorithmInstance algorithm, DataSet dataSet, Recommen @Nonnull @Override public MetricResult getAggregateMeasurements(Mean context) { + logger.warn("The NDPM metric is not well-understood and is currently broken."); return MetricResult.singleton(DEFAULT_COLUMN, context.getResult()); } @Nonnull @Override - public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Mean context) { + public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List recommendations, Mean context) { if (recommendations == null) { return MetricResult.empty(); } Long2DoubleMap ratings = user.getTestRatings(); - long[] actual = recommendations.toLongArray(); + RealVector scores = new ArrayRealVector(recommendations.size()); + int i = 0; + for (long item: recommendations) { + if (ratings.containsKey(item)) { + scores.setEntry(i, ratings.get(item)); + i += 1; + } + } + scores = scores.getSubVector(0, i); - double dpm = computeDPM(actual, ratings); + double dpm = computeDPM(scores); - double normalizingFactor = computeNormalizingFactor(actual, ratings); + double normalizingFactor = computeNormalizingFactor(scores); double nDPM = dpm / normalizingFactor; // Normalized nDPM @@ -104,31 +117,26 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength } /** - * Compute dpm of list of items, with respect to user's ratings. + * Compute dpm of list of item preference scores. */ - - double computeDPM(long [] actual_item, Long2DoubleFunction value) { + static double computeDPM(RealVector scores) { + int n = scores.getDimension(); int nCompatible = 0; int nDisagree = 0; - for(int i = 0; i < actual_item.length; i++){ - for(int j = i+1; j < actual_item.length; j++){ - double valueOne; - double valueTwo; + for(int i = 0; i < n; i++){ + double v1 = scores.getEntry(i); + if (Double.isNaN(v1)) continue; - if (value.containsKey(actual_item[i])) { - valueOne = value.get(actual_item[i]); + for(int j = i+1; j < n; j++){ + double v2 = scores.getEntry(j); - if (value.containsKey(actual_item[j])) { - valueTwo = value.get(actual_item[j]); + if (Double.isNaN(v2)) continue; - if (Scalars.isZero(valueOne - valueTwo)) { - nCompatible++; - } - if(valueOne < valueTwo){ - nDisagree++; - } - } + if (Scalars.isZero(v1 - v2)) { + nCompatible++; + } else if (v1 < v2) { + nDisagree++; } } } @@ -138,23 +146,18 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength return dpm; } - double computeNormalizingFactor(long [] actual_item, Long2DoubleFunction value) { + static double computeNormalizingFactor(RealVector scores) { int npairs = 0; + int n = scores.getDimension(); - for(int i = 0; i < actual_item.length; i++) { - for(int j = i+1; j < actual_item.length; j++) { - double valueOne; - double valueTwo; + for(int i = 0; i < n; i++) { + double v1 = scores.getEntry(i); - if (value.containsKey(actual_item[i])) { - valueOne = value.get(actual_item[i]); + for(int j = i+1; j < n; j++) { + double v2 = scores.getEntry(j); - if (value.containsKey(actual_item[j])) { - valueTwo = value.get(actual_item[j]); - if(valueOne < valueTwo || valueOne > valueTwo) { - npairs++; - } - } + if (!Scalars.isZero(v1 - v2)) { + npairs++; } } } diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNPopularityMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNPopularityMetric.java index 856fb80f19..64aabefab6 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNPopularityMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNPopularityMetric.java @@ -24,8 +24,6 @@ */ package org.lenskit.eval.traintest.recommend; -import it.unimi.dsi.fastutil.longs.LongIterator; -import it.unimi.dsi.fastutil.longs.LongList; import org.apache.commons.math3.stat.descriptive.moment.Mean; import org.lenskit.LenskitRecommender; import org.lenskit.api.Recommender; @@ -41,6 +39,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.Collections; +import java.util.List; import java.util.Set; /** @@ -67,7 +66,7 @@ public Context createContext(AlgorithmInstance algorithm, DataSet dataSet, Recom @Nonnull @Override - public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recs, Context context) { + public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List recs, Context context) { RatingSummary summary = null; if (rec instanceof LenskitRecommender) { summary = ((LenskitRecommender) rec).get(RatingSummary.class); @@ -76,9 +75,8 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength return MetricResult.empty(); } double pop = 0; - LongIterator iter = recs.iterator(); - while (iter.hasNext()) { - pop += summary.getItemRatingCount(iter.nextLong()); + for (long item: recs) { + pop += summary.getItemRatingCount(item); } pop = pop / recs.size(); diff --git a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNPrecisionRecallMetric.java b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNPrecisionRecallMetric.java index a5cf2dcc16..4f371a34bd 100644 --- a/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNPrecisionRecallMetric.java +++ b/lenskit-eval/src/main/java/org/lenskit/eval/traintest/recommend/TopNPrecisionRecallMetric.java @@ -25,8 +25,6 @@ package org.lenskit.eval.traintest.recommend; import com.fasterxml.jackson.annotation.JsonCreator; -import it.unimi.dsi.fastutil.longs.LongIterator; -import it.unimi.dsi.fastutil.longs.LongList; import it.unimi.dsi.fastutil.longs.LongSet; import org.apache.commons.lang3.StringUtils; import org.lenskit.api.Recommender; @@ -41,6 +39,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.util.List; /** * A metric to compute the precision and recall of a recommender given a @@ -92,14 +91,13 @@ public TopNPrecisionRecallMetric(ItemSelector good, String sfx) { @Nonnull @Override - public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recs, Context context) { + public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List recs, Context context) { int tp = 0; LongSet items = goodItems.selectItems(context.universe, rec, user); - LongIterator iter = recs.iterator(); - while (iter.hasNext()) { - if(items.contains(iter.nextLong())) { + for (long item: recs) { + if(items.contains(item)) { tp += 1; } }