Skip to content
This repository has been archived by the owner on Mar 20, 2024. It is now read-only.

Commit

Permalink
Simplify list-based metric helper
Browse files Browse the repository at this point in the history
This simplifies the list-based metric API to use `List<Long>` instead of `LongList`. In this case, the performance benefit of the primitive-specialized list is likely not worth the extra API complexity.

It also adds tests for the NDPM metric.
  • Loading branch information
mdekstrand committed May 26, 2018
1 parent 8a575c6 commit 77f03ff
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 89 deletions.
Expand Up @@ -24,7 +24,6 @@
*/
package org.lenskit.eval.traintest.recommend;

import it.unimi.dsi.fastutil.longs.LongList;
import org.lenskit.api.Recommender;
import org.lenskit.api.ResultList;
import org.lenskit.eval.traintest.TestUser;
Expand All @@ -37,7 +36,7 @@

/**
* Intermediate class for top-N metrics that only depend on the list of recommended items, not their details.
* Metrics extending this class will implement the {@link #measureUser(Recommender, TestUser, int, LongList, Object)} method
* Metrics extending this class will implement the {@link #measureUserRecList(Recommender, TestUser, int, List, Object)} method
* instead of {@link #measureUser(Recommender, TestUser, int, ResultList, Object)}. The recommend eval task uses this
* subclass to improve efficiency when results are not used in the evaluation.
*
Expand All @@ -59,13 +58,13 @@ protected ListOnlyTopNMetric(Class<? extends TypedMetricResult> resType, Class<?
@Nonnull
@Override
public final MetricResult measureUser(Recommender rec, TestUser user, int targetLength, ResultList recommendations, X context) {
return measureUser(rec, user, targetLength,
LongUtils.asLongList(recommendations.idList()),
context);
return measureUserRecList(rec, user, targetLength,
LongUtils.asLongList(recommendations.idList()),
context);
}

/**
* Measurement method that only uses the recommend list.
* Measurement method that only uses the recommended list, without any scores or details.
*
* **Thread Safety:** This method may be called concurrently by multiple threads with the same recommender and
* context.
Expand All @@ -78,5 +77,5 @@ public final MetricResult measureUser(Recommender rec, TestUser user, int target
* @return The results of measuring this user.
*/
@Nonnull
public abstract MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, X context);
public abstract MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List<Long> recommendations, X context);
}
Expand Up @@ -441,7 +441,7 @@ public MetricResult measureUser(Recommender rec, TestUser user, int n, ResultLis

@Nonnull
public MetricResult measureUser(Recommender rec, TestUser user, int n, LongList recommendations) {
return ((ListOnlyTopNMetric<X>) metric).measureUser(rec, user, n, recommendations, context);
return ((ListOnlyTopNMetric<X>) metric).measureUserRecList(rec, user, n, recommendations, context);
}

@Nonnull
Expand Down
Expand Up @@ -26,8 +26,6 @@

import it.unimi.dsi.fastutil.longs.Long2IntMap;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongList;
import org.lenskit.api.Recommender;
import org.lenskit.api.RecommenderEngine;
import org.lenskit.eval.traintest.AlgorithmInstance;
Expand All @@ -39,6 +37,7 @@

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;

/**
* Metric that measures the entropy of the top N recommendations across all users.
Expand Down Expand Up @@ -66,7 +65,7 @@ public TopNEntropyMetric() {

@Nonnull
@Override
public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Context context) {
public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List<Long> recommendations, Context context) {
context.addUser(recommendations);
return MetricResult.empty();
}
Expand Down Expand Up @@ -95,10 +94,8 @@ public static class Context {
private Long2IntMap counts = new Long2IntOpenHashMap();
private int recCount = 0;

private synchronized void addUser(LongList recs) {
LongIterator iter = recs.iterator();
while (iter.hasNext()) {
long item = iter.nextLong();
private synchronized void addUser(List<Long> recs) {
for (long item: recs) {
counts.put(item, counts.get(item) +1);
recCount +=1;
}
Expand Down
Expand Up @@ -24,7 +24,6 @@
*/
package org.lenskit.eval.traintest.recommend;

import it.unimi.dsi.fastutil.longs.LongList;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.lenskit.api.Recommender;
import org.lenskit.api.RecommenderEngine;
Expand All @@ -37,6 +36,7 @@

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;

/**
* Metric that measures how long a TopN list actually is.
Expand All @@ -53,7 +53,7 @@ public TopNLengthMetric() {

@Nonnull
@Override
public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Mean context) {
public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List<Long> recommendations, Mean context) {
int n = recommendations.size();
synchronized (context) {
context.increment(n);
Expand Down
Expand Up @@ -25,8 +25,6 @@
package org.lenskit.eval.traintest.recommend;

import com.fasterxml.jackson.annotation.JsonCreator;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
Expand All @@ -43,6 +41,7 @@

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;

/**
* Compute the mean average precision.
Expand Down Expand Up @@ -107,7 +106,7 @@ public MetricResult getAggregateMeasurements(Context context) {

@Nonnull
@Override
public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recs, Context context) {
public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List<Long> recs, Context context) {
LongSet good = goodItems.selectItems(context.universe, rec, user);
if (good.isEmpty()) {
logger.warn("no good items for user {}", user.getUserId());
Expand All @@ -121,10 +120,9 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength
int n = 0;
double ngood = 0;
double sum = 0;
LongIterator iter = recs.iterator();
while (iter.hasNext()) {
for (long id: recs) {
n += 1;
if(good.contains(iter.nextLong())) {
if (good.contains(id)) {
// it is good
ngood += 1;
// add to MAP sum
Expand Down
Expand Up @@ -25,8 +25,6 @@
package org.lenskit.eval.traintest.recommend;

import com.fasterxml.jackson.annotation.JsonCreator;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
Expand All @@ -43,6 +41,7 @@

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;

/**
* Compute the mean reciprocal rank.
Expand Down Expand Up @@ -104,18 +103,17 @@ public MetricResult getAggregateMeasurements(Context context) {

@Nonnull
@Override
public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Context context) {
public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List<Long> recommendations, Context context) {
LongSet good = goodItems.selectItems(context.universe, rec, user);
if (good.isEmpty()) {
logger.warn("no good items for user {}", user.getUserId());
}

Integer rank = null;
int i = 0;
LongIterator iter = recommendations.iterator();
while (iter.hasNext()) {
for (long item: recommendations) {
i++;
if(good.contains(iter.nextLong())) {
if (good.contains(item)) {
rank = i;
break;
}
Expand Down
Expand Up @@ -26,7 +26,9 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import it.unimi.dsi.fastutil.longs.*;
import it.unimi.dsi.fastutil.longs.Long2DoubleFunction;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.lenskit.api.Recommender;
Expand All @@ -45,8 +47,9 @@

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
* Measure the nDCG of the top-N recommendations, using ratings as scores.
Expand Down Expand Up @@ -120,7 +123,7 @@ public MetricResult getAggregateMeasurements(Mean context) {

@Nonnull
@Override
public MetricResult measureUser(Recommender rec, TestUser user, int targetLength, LongList recommendations, Mean context) {
public MetricResult measureUserRecList(Recommender rec, TestUser user, int targetLength, List<Long> recommendations, Mean context) {
if (recommendations == null) {
return MetricResult.empty();
}
Expand All @@ -135,15 +138,16 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength
throw new IllegalArgumentException("value " + av + " for attribute " + gainAttribute + " is not numeric");
}
}
long[] ideal = ratings.keySet().toLongArray();
LongArrays.quickSort(ideal, LongComparators.oppositeComparator(LongUtils.keyValueComparator(ratings)));
if (targetLength >= 0 && ideal.length > targetLength) {
ideal = Arrays.copyOf(ideal, targetLength);
}

List<Long> ideal =
ratings.keySet()
.stream()
.sorted(LongUtils.keyValueComparator(ratings).reversed())
.limit(targetLength >= 0 ? targetLength : ratings.size())
.collect(Collectors.toList());
double idealGain = computeDCG(ideal, ratings);

long[] actual = recommendations.toLongArray();
double gain = computeDCG(actual, ratings);
double gain = computeDCG(recommendations, ratings);

double score = gain / idealGain;

Expand All @@ -156,7 +160,7 @@ public MetricResult measureUser(Recommender rec, TestUser user, int targetLength
/**
* Compute the DCG of a list of items with respect to a value vector.
*/
double computeDCG(long[] items, Long2DoubleFunction values) {
double computeDCG(List<Long> items, Long2DoubleFunction values) {
double gain = 0;
int rank = 0;

Expand Down

0 comments on commit 77f03ff

Please sign in to comment.