Skip to content
This repository has been archived by the owner on Mar 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #708 from elehack/feature/rec-output
Browse files Browse the repository at this point in the history
Merge configurable recommendation output (#613).
  • Loading branch information
mdekstrand committed Apr 11, 2015
2 parents 9ebc8b0 + 2baf722 commit d0f8a20
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
package org.grouplens.lenskit.eval.metrics.topn;

import org.apache.commons.lang3.builder.Builder;
import org.grouplens.lenskit.eval.metrics.Metric;

/**
* @author <a href="http://www.grouplens.org">GroupLens Research</a>
*/
public abstract class TopNMetricBuilder<K extends Metric<?>> implements Builder<K> {
public abstract class TopNMetricBuilder<K> implements Builder<K> {
protected int listSize = 5;
protected ItemSelector candidates = ItemSelectors.testItems();
protected ItemSelector exclude = ItemSelectors.trainingItems();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ public static <T> MetricFactory<T> forMetric(Metric<T> m) {
return new Preinstantiated<T>(m);
}

/**
* Create a metric factory that wraps a class. The class is instantiated immediately.
* @param m The metric class.
* @return A metric factory that returns {@code m}.
*/
public static <T> MetricFactory<T> forMetricClass(Class<? extends Metric<T>> m) {
try {
return new Preinstantiated<T>(m.newInstance());
} catch (InstantiationException e) {
throw new RuntimeException("cannot instantiate " + m, e);
} catch (IllegalAccessException e) {
throw new RuntimeException("cannot instantiate " + m, e);
}
}

private static class Preinstantiated<T> extends MetricFactory<T> {
private final Metric<T> metric;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongSortedSet;
import org.apache.commons.lang3.builder.Builder;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.collections.LongUtils;
Expand Down Expand Up @@ -135,15 +137,16 @@ public static class Context {

public static class Factory extends MetricFactory<Context> {
private final List<Pair<Symbol, String>> predictChannels;
private final File file;

public Factory(List<Pair<Symbol, String>> pchans) {
predictChannels = pchans;
public Factory(File f, List<Pair<Symbol, String>> pchans) {
file = f;
predictChannels = ImmutableList.copyOf(pchans);
}

@Override
public OutputPredictMetric createMetric(TrainTestEvalTask task) throws IOException {
return new OutputPredictMetric(task.getOutputLayout(), task.getPredictOutput(),
predictChannels);
return new OutputPredictMetric(task.getOutputLayout(), file, predictChannels);
}

@Override
Expand All @@ -156,4 +159,44 @@ public List<String> getUserColumnLabels() {
return Collections.emptyList();
}
}

/**
* Configure the prediction output.
*/
public static class FactoryBuilder implements Builder<Factory> {
private File file;
private List<Pair<Symbol,String>> channels = Lists.newLinkedList();

public File getFile() {
return file;
}

public void setFile(File f) {
file = f;
}

public void setFile(String fn) {
setFile(new File(fn));
}

public void addChannel(Symbol chan, String col) {
channels.add(Pair.of(chan, col));
}

public void addChannel(String chan, String col) {
channels.add(Pair.of(Symbol.of(chan), col));
}

public List<Pair<Symbol,String>> getChannels() {
return Collections.unmodifiableList(channels);
}

@Override
public Factory build() {
if (file == null) {
throw new IllegalStateException("no file specified");
}
return new Factory(file, channels);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.grouplens.lenskit.eval.metrics.AbstractMetric;
import org.grouplens.lenskit.eval.metrics.topn.ItemSelector;
import org.grouplens.lenskit.eval.metrics.topn.ItemSelectors;
import org.grouplens.lenskit.eval.metrics.topn.TopNMetricBuilder;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.util.table.TableLayout;
import org.grouplens.lenskit.util.table.TableLayoutBuilder;
Expand Down Expand Up @@ -88,6 +89,7 @@ public Void doMeasureUser(TestUser user, Context context) {
try {
context.writer.writeRow(user.getUserId(), rec.getId(),
counter, rec.getScore());
counter += 1;
} catch (IOException e) {
throw Throwables.propagate(e);
}
Expand All @@ -114,11 +116,22 @@ public static class Context {
}

public static class Factory extends MetricFactory<Context> {
private final int listSize;
private final ItemSelector candidates;
private final ItemSelector exclude;
private final File file;

public Factory(File f, int size, ItemSelector cand, ItemSelector excl) {
file = f;
listSize = size;
candidates = cand;
exclude = excl;
}

@Override
public OutputTopNMetric createMetric(TrainTestEvalTask task) throws IOException {
return new OutputTopNMetric(task.getOutputLayout(), task.getRecommendOutput(), -1,
ItemSelectors.allItems(),
ItemSelectors.trainingItems());
return new OutputTopNMetric(task.getOutputLayout(), file,
listSize, candidates, exclude);
}

@Override
Expand All @@ -131,4 +144,37 @@ public List<String> getUserColumnLabels() {
return Collections.emptyList();
}
}

/**
* Configure the prediction output.
*/
public static class FactoryBuilder extends TopNMetricBuilder<Factory> {
private File file;

public FactoryBuilder() {
setListSize(-1);
setCandidates(ItemSelectors.allItems());
setCandidates(ItemSelectors.trainingItems());
}

public File getFile() {
return file;
}

public void setFile(File f) {
file = f;
}

public void setFile(String fn) {
setFile(new File(fn));
}

@Override
public Factory build() {
if (file == null) {
throw new IllegalStateException("no file specified");
}
return new Factory(file, getListSize(), getCandidates(), getExclude());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,8 @@ public SimpleEvaluator addMetric(Metric<?> metric) {
* @param metric The metric to be added.
* @return Itself for method chaining.
*/
public SimpleEvaluator addMetric(Class<? extends Metric<?>> metric) {
try {
result.addMetric(metric);
} catch (IllegalAccessException e) {
throw new IllegalArgumentException(e);
} catch (InstantiationException e) {
throw new IllegalArgumentException(e);
}
public <T> SimpleEvaluator addMetric(Class<? extends Metric<T>> metric) {
result.addMetric(metric);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@
import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutionException;

/**
Expand All @@ -71,17 +74,16 @@ public class TrainTestEvalTask extends AbstractTask<Table> {
private List<AlgorithmInstance> algorithms;
private List<ExternalAlgorithm> externalAlgorithms;
private List<MetricFactory<?>> metrics;
private List<Pair<Symbol,String>> predictChannels;
private boolean isolate;
private boolean separateAlgorithms;
private File outputFile;
private File userOutputFile;
private File predictOutputFile;
private File recommendOutputFile;
private File cacheDir;
private File taskGraphFile;
private File taskStatusFile;
private boolean cacheAll = false;
private OutputPredictMetric.FactoryBuilder defaultPredict = new OutputPredictMetric.FactoryBuilder();
private OutputTopNMetric.FactoryBuilder defaultRecommend = new OutputTopNMetric.FactoryBuilder();

private ExperimentSuite experiments;
private MeasurementSuite measurements;
Expand All @@ -98,7 +100,6 @@ public TrainTestEvalTask(String name) {
algorithms = Lists.newArrayList();
externalAlgorithms = Lists.newArrayList();
metrics = Lists.newArrayList();
predictChannels = Lists.newArrayList();
outputFile = new File("train-test-results.csv");
isolate = false;
}
Expand All @@ -125,13 +126,35 @@ public TrainTestEvalTask addExternalAlgorithm(ExternalAlgorithm algorithm) {
return this;
}

/**
* Add a metric.
* @param metric The metric to add.
* @return The task (for chaining).
*/
public TrainTestEvalTask addMetric(Metric<?> metric) {
metrics.add(MetricFactory.forMetric(metric));
return this;
}

public TrainTestEvalTask addMetric(Class<? extends Metric<?>> metricClass) throws IllegalAccessException, InstantiationException {
return addMetric(metricClass.newInstance());
/**
* Add a metric by factory.
* @param factory The metric factory to add.
* @return The task (for chaining).
*/
public TrainTestEvalTask addMetric(MetricFactory<?> factory) {
metrics.add(factory);
return this;
}

/**
* Add a metric by class.
* @param metricClass The metric class.
* @param <T> The metric's return type.
* @return The task (for chaining).
*/
public <T> TrainTestEvalTask addMetric(Class<? extends Metric<T>> metricClass) {
metrics.add(MetricFactory.forMetricClass(metricClass));
return this;
}

/**
Expand Down Expand Up @@ -177,15 +200,16 @@ public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channelSym) {
* @param label The column label. If {@code null}, the channel symbol's name is used.
* @return The command (for chaining).
* @see #setPredictOutput(File)
* @deprecated Use the {@code predictions} metric.
*/
@Deprecated
public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channelSym,
@Nullable String label) {
Preconditions.checkNotNull(channelSym, "channel is null");
if (label == null) {
label = channelSym.getName();
}
Pair<Symbol, String> entry = Pair.of(channelSym, label);
predictChannels.add(entry);
defaultPredict.addChannel(channelSym, label);
return this;
}

Expand All @@ -207,20 +231,26 @@ public TrainTestEvalTask setUserOutput(String fn) {
return setUserOutput(new File(fn));
}

@Deprecated
public TrainTestEvalTask setPredictOutput(File file) {
predictOutputFile = file;
logger.warn("predictOutput is deprecated, use predictions metric");
defaultPredict.setFile(file);
return this;
}

@Deprecated
public TrainTestEvalTask setPredictOutput(String fn) {
return setPredictOutput(new File(fn));
}

@Deprecated
public TrainTestEvalTask setRecommendOutput(File file) {
recommendOutputFile = file;
logger.warn("recommendOutput is deprecated, use predictions metric");
defaultRecommend.setFile(file);
return this;
}

@Deprecated
public TrainTestEvalTask setRecommendOutput(String fn) {
return setRecommendOutput(new File(fn));
}
Expand Down Expand Up @@ -317,19 +347,19 @@ List<MetricFactory<?>> getMetricFactories() {
}

List<Pair<Symbol,String>> getPredictionChannels() {
return predictChannels;
return defaultPredict.getChannels();
}

File getOutput() {
return outputFile;
}

File getPredictOutput() {
return predictOutputFile;
return defaultPredict.getFile();
}

File getRecommendOutput() {
return recommendOutputFile;
return defaultRecommend.getFile();
}

/**
Expand Down Expand Up @@ -476,11 +506,11 @@ ExperimentSuite createExperimentSuite() {
MeasurementSuite createMeasurementSuite() {
ImmutableList.Builder<MetricFactory<?>> activeMetrics = ImmutableList.builder();
activeMetrics.addAll(metrics);
if (recommendOutputFile != null) {
activeMetrics.add(new OutputTopNMetric.Factory());
if (defaultRecommend.getFile() != null) {
activeMetrics.add(defaultRecommend.build());
}
if (predictOutputFile != null) {
activeMetrics.add(new OutputPredictMetric.Factory(predictChannels));
if (defaultPredict.getFile() != null) {
activeMetrics.add(defaultPredict.build());
}
return new MeasurementSuite(activeMetrics.build());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
builder=org.grouplens.lenskit.eval.traintest.OutputPredictMetric$FactoryBuilder
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
builder=org.grouplens.lenskit.eval.traintest.OutputTopNMetric$FactoryBuilder
Loading

0 comments on commit d0f8a20

Please sign in to comment.