Refactored LTR commandline, fixed issues #44, #46
Issue #44 Don't create reranker cascade for every query
Issue #46 Switch to float features before it's too late
lintool committed Nov 5, 2015
2 parents 67d92bf + e8b76b7 commit b5ab1cb
Showing 12 changed files with 154 additions and 148 deletions.
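In short, the two fixes reshape the reranking flow as sketched below (condensed from the diff that follows; per-query setup of query, filter, rs, and queryTokens is elided):

// Issue #44: build the reranker cascade once, outside the query loop.
// Per-query state now travels in a RerankerContext passed to run(),
// instead of being fixed in the cascade's constructor.
RerankerCascade cascade = new RerankerCascade();
cascade.add(new RemoveRetweetsTemporalTiebreakReranker())
    .add(new TweetsLtrDataGenerator(out));

for (MicroblogTopic topic : topics) {
  RerankerContext context = new RerankerContext(searcher, query, topic.getId(),
      topic.getQuery(), queryTokens, filter);
  cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context);
}

// Issue #46: feature extractors now return float rather than int, so feature
// values are floating point from the moment they are extracted.
float[] features = extractors.extractAll(terms, context);
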
115 changes: 61 additions & 54 deletions src/main/java/io/anserini/ltr/DumpTweetsLtrData.java
@@ -8,104 +8,111 @@
import io.anserini.rerank.twitter.RemoveRetweetsTemporalTiebreakReranker;
import io.anserini.search.MicroblogTopic;
import io.anserini.search.MicroblogTopicSet;
import io.anserini.search.SearchTweets;
import io.anserini.search.SearchArgs;
import io.anserini.util.AnalyzerUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.nio.file.Paths;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.NumericRangeFilter;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.LMDirichletSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.MMapDirectory;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.OptionHandlerFilter;
import org.kohsuke.args4j.ParserProperties;

import com.google.common.collect.Sets;

@SuppressWarnings("deprecation")
public class DumpTweetsLtrData {
private static final String INDEX_OPTION = "index";
private static final String QUERIES_OPTION = "queries";
private static final String NUM_RESULTS_OPTION = "num_results";
private static final Logger LOG = LogManager.getLogger(DumpTweetsLtrData.class);

private DumpTweetsLtrData() {}

@SuppressWarnings("static-access")
public static void main(String[] args) throws Exception {
Options options = new Options();
long curTime = System.nanoTime();
SearchArgs searchArgs = new SearchArgs();
CmdLineParser parser = new CmdLineParser(searchArgs, ParserProperties.defaults().withUsageWidth(90));

options.addOption(OptionBuilder.withArgName("path").hasArg()
.withDescription("index location").create(INDEX_OPTION));
options.addOption(OptionBuilder.withArgName("num").hasArg()
.withDescription("number of results to return").create(NUM_RESULTS_OPTION));
options.addOption(OptionBuilder.withArgName("file").hasArg()
.withDescription("file containing topics in TREC format").create(QUERIES_OPTION));

CommandLine cmdline = null;
CommandLineParser parser = new GnuParser();
try {
cmdline = parser.parse(options, args);
} catch (ParseException exp) {
System.err.println("Error parsing command line: " + exp.getMessage());
System.exit(-1);
parser.parseArgument(args);
} catch (CmdLineException e) {
System.err.println(e.getMessage());
parser.printUsage(System.err);
System.err.println("Example: SearchTweets" + parser.printExample(OptionHandlerFilter.REQUIRED));
return;
}

if (!cmdline.hasOption(QUERIES_OPTION) || !cmdline.hasOption(INDEX_OPTION)) {
HelpFormatter formatter = new HelpFormatter();
formatter.printHelp(SearchTweets.class.getName(), options);
System.exit(-1);
LOG.info("Reading index at " + searchArgs.index);
Directory dir;
if (searchArgs.inmem) {
LOG.info("Using MMapDirectory with preload");
dir = new MMapDirectory(Paths.get(searchArgs.index));
((MMapDirectory) dir).setPreload(true);
} else {
LOG.info("Using default FSDirectory");
dir = FSDirectory.open(Paths.get(searchArgs.index));
}

File indexLocation = new File(cmdline.getOptionValue(INDEX_OPTION));
if (!indexLocation.exists()) {
System.err.println("Error: " + indexLocation + " does not exist!");
System.exit(-1);
}

String topicsFile = cmdline.getOptionValue(QUERIES_OPTION);
IndexReader reader = DirectoryReader.open(dir);
IndexSearcher searcher = new IndexSearcher(reader);

int numResults = 1000;
try {
if (cmdline.hasOption(NUM_RESULTS_OPTION)) {
numResults = Integer.parseInt(cmdline.getOptionValue(NUM_RESULTS_OPTION));
}
} catch (NumberFormatException e) {
System.err.println("Invalid " + NUM_RESULTS_OPTION + ": " + cmdline.getOptionValue(NUM_RESULTS_OPTION));
if (searchArgs.ql) {
LOG.info("Using QL scoring model");
searcher.setSimilarity(new LMDirichletSimilarity(searchArgs.mu));
} else if (searchArgs.bm25) {
LOG.info("Using BM25 scoring model");
searcher.setSimilarity(new BM25Similarity(searchArgs.k1, searchArgs.b));
} else {
LOG.error("Error: Must specify scoring model!");
System.exit(-1);
}

PrintStream out = new PrintStream(System.out, true, "UTF-8");
PrintStream out = new PrintStream(new FileOutputStream(new File(searchArgs.output)));
RerankerCascade cascade = new RerankerCascade();
cascade.add(new RemoveRetweetsTemporalTiebreakReranker()).add(new TweetsLtrDataGenerator(out));

IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(indexLocation.getAbsolutePath())));
IndexSearcher searcher = new IndexSearcher(reader);
searcher.setSimilarity(new LMDirichletSimilarity(2500.0f));
MicroblogTopicSet topics = MicroblogTopicSet.fromFile(new File(searchArgs.topics));

MicroblogTopicSet topics = MicroblogTopicSet.fromFile(new File(topicsFile));
LOG.info("Initialized complete! (elapsed time = " + (System.nanoTime()-curTime)/1000000 + "ms)");
long totalTime = 0;
int cnt = 0;
for ( MicroblogTopic topic : topics ) {
long curQueryTime = System.nanoTime();

Filter filter = NumericRangeFilter.newLongRange(StatusField.ID.name, 0L, topic.getQueryTweetTime(), true, true);
Query query = AnalyzerUtils.buildBagOfWordsQuery(StatusField.TEXT.name, IndexTweets.ANALYZER, topic.getQuery());

TopDocs rs = searcher.search(query, filter, numResults);
TopDocs rs = searcher.search(query, filter, searchArgs.hits);

RerankerContext context = new RerankerContext(searcher, query, topic.getId(), topic.getQuery(),
Sets.newHashSet(AnalyzerUtils.tokenize(IndexTweets.ANALYZER, topic.getQuery())), filter);
RerankerCascade cascade = new RerankerCascade(context);

cascade.add(new RemoveRetweetsTemporalTiebreakReranker()).add(new LtrDataGenerator());

cascade.run(ScoredDocuments.fromTopDocs(rs, searcher));
cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context);
long qtime = (System.nanoTime()-curQueryTime)/1000000;
LOG.info("Query " + topic.getId() + " (elapsed time = " + qtime + "ms)");
totalTime += qtime;
cnt++;
}

LOG.info("All queries completed!");
LOG.info("Total elapsed time = " + totalTime + "ms");
LOG.info("Average query latency = " + (totalTime/cnt) + "ms");

reader.close();
out.close();
}
src/main/java/io/anserini/ltr/{LtrDataGenerator.java → TweetsLtrDataGenerator.java}
@@ -1,29 +1,36 @@
package io.anserini.ltr;

import io.anserini.index.IndexTweets.StatusField;
import io.anserini.ltr.feature.IntFeatureExtractors;
import io.anserini.ltr.feature.FeatureExtractors;
import io.anserini.ltr.feature.MatchingTermCount;
import io.anserini.ltr.feature.SumMatchingTf;
import io.anserini.ltr.feature.QueryFeatures;
import io.anserini.ltr.feature.SumMatchingTf;
import io.anserini.rerank.Reranker;
import io.anserini.rerank.RerankerContext;
import io.anserini.rerank.ScoredDocuments;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Terms;

public class LtrDataGenerator implements Reranker {
public LtrDataGenerator() {}
public class TweetsLtrDataGenerator implements Reranker {
private final PrintStream out;

public TweetsLtrDataGenerator(PrintStream out) throws FileNotFoundException {
this.out = out;
}

@Override
public ScoredDocuments rerank(ScoredDocuments docs, RerankerContext context) {
IndexReader reader = context.getIndexSearcher().getIndexReader();
IntFeatureExtractors intFeatureExtractors = new IntFeatureExtractors();
intFeatureExtractors.add(new MatchingTermCount()).add(new SumMatchingTf());
intFeatureExtractors.add(new QueryFeatures());
FeatureExtractors extractors = new FeatureExtractors();
extractors.add(new MatchingTermCount());
extractors.add(new SumMatchingTf());
extractors.add(new QueryFeatures());

for (int i = 0; i < docs.documents.length; i++) {
Terms terms = null;
@@ -33,14 +40,14 @@ public ScoredDocuments rerank(ScoredDocuments docs, RerankerContext context) {
continue;
}

System.out.print(context.getQueryId() + "\t");
System.out.print(docs.documents[i].getField(StatusField.ID.name).stringValue() + "\t");
System.out.print(docs.scores[i] + "\t");
out.print(context.getQueryId() + "\t");
out.print(docs.documents[i].getField(StatusField.ID.name).stringValue() + "\t");
out.print(docs.scores[i] + "\t");

int[] intFeatures = intFeatureExtractors.extractAll(terms, context);
System.out.print(Arrays.toString(intFeatures));
float[] intFeatures = extractors.extractAll(terms, context);
out.print(Arrays.toString(intFeatures));

System.out.print("\n");
out.print("\n");
}

return docs;
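
For reference, each line the generator emits is tab-separated: query id, tweet id, retrieval score, then the feature vector as rendered by Arrays.toString. An illustrative line (values made up):

MB001	123456789	-5.3029	[2.0, 7.0, 3.0]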
12 changes: 12 additions & 0 deletions src/main/java/io/anserini/ltr/feature/FeatureExtractor.java
@@ -0,0 +1,12 @@
package io.anserini.ltr.feature;

import io.anserini.rerank.RerankerContext;

import org.apache.lucene.index.Terms;

/**
* A feature extractor.
*/
public interface FeatureExtractor {
float extract(Terms terms, RerankerContext context);
}
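
Since the interface now returns float, fractional feature values become expressible. A hypothetical extractor (not part of this commit) sketching the new contract, using only Lucene calls that appear elsewhere in this diff:

package io.anserini.ltr.feature;

import io.anserini.rerank.RerankerContext;

import java.io.IOException;
import java.util.Set;

import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;

/**
 * Hypothetical example: the fraction of query terms that appear in the
 * document, a value the old int-based interface could not represent.
 */
public class FractionOfQueryTermsMatched implements FeatureExtractor {
  @Override
  public float extract(Terms terms, RerankerContext context) {
    Set<String> queryTokens = context.getQueryTokens();
    if (queryTokens.isEmpty()) {
      return 0.0f;
    }
    try {
      TermsEnum termsEnum = terms.iterator();
      int matched = 0;
      for (String token : queryTokens) {
        // seekExact returns true iff the exact term occurs in this document
        if (termsEnum.seekExact(new BytesRef(token))) {
          matched++;
        }
      }
      return (float) matched / queryTokens.size();
    } catch (IOException e) {
      return 0.0f;
    }
  }
}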
src/main/java/io/anserini/ltr/feature/{IntFeatureExtractors.java → FeatureExtractors.java}
@@ -9,20 +9,20 @@
import com.google.common.collect.Lists;

/**
* A collection of {@link IntFeatureExtractor}s.
* A collection of {@link FeatureExtractor}s.
*/
public class IntFeatureExtractors {
public List<IntFeatureExtractor> extractors = Lists.newArrayList();
public class FeatureExtractors {
public List<FeatureExtractor> extractors = Lists.newArrayList();

public IntFeatureExtractors() {}
public FeatureExtractors() {}

public IntFeatureExtractors add(IntFeatureExtractor extractor) {
public FeatureExtractors add(FeatureExtractor extractor) {
extractors.add(extractor);
return this;
}

public int[] extractAll(Terms terms, RerankerContext context) {
int[] features = new int[extractors.size()];
public float[] extractAll(Terms terms, RerankerContext context) {
float[] features = new float[extractors.size()];

for (int i=0; i<extractors.size(); i++) {
features[i] = extractors.get(i).extract(terms, context);
12 changes: 0 additions & 12 deletions src/main/java/io/anserini/ltr/feature/IntFeatureExtractor.java

This file was deleted.

4 changes: 2 additions & 2 deletions src/main/java/io/anserini/ltr/feature/MatchingTermCount.java
@@ -13,10 +13,10 @@
* Computes the number of query terms that are found in the document. If there are three terms in
* the query and all three terms are found in the document, the feature value is three.
*/
public class MatchingTermCount implements IntFeatureExtractor {
public class MatchingTermCount implements FeatureExtractor {

@Override
public int extract(Terms terms, RerankerContext context) {
public float extract(Terms terms, RerankerContext context) {
try {
Set<String> queryTokens = context.getQueryTokens();
TermsEnum termsEnum = terms.iterator();
7 changes: 2 additions & 5 deletions src/main/java/io/anserini/ltr/feature/QueryFeatures.java
@@ -2,21 +2,18 @@

import io.anserini.rerank.RerankerContext;

import java.io.IOException;
import java.util.Set;

import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;

/**
* Compute query features for LTR (as described by Macdonald et al., CIKM 2012)
* But just # of tokens for now :-(
*/
public class QueryFeatures implements IntFeatureExtractor {
public class QueryFeatures implements FeatureExtractor {

@Override
public int extract(Terms terms, RerankerContext context) {
public float extract(Terms terms, RerankerContext context) {
Set<String> queryTokens = context.getQueryTokens();
return queryTokens.size();
}
4 changes: 2 additions & 2 deletions src/main/java/io/anserini/ltr/feature/SumMatchingTf.java
@@ -14,10 +14,10 @@
* terms and the first occurs twice in the document and the second occurs once in the document, the
* sum of the matching term frequencies is three.
*/
public class SumMatchingTf implements IntFeatureExtractor {
public class SumMatchingTf implements FeatureExtractor {

@Override
public int extract(Terms terms, RerankerContext context) {
public float extract(Terms terms, RerankerContext context) {
try {
Set<String> queryTokens = context.getQueryTokens();
TermsEnum termsEnum = terms.iterator();
7 changes: 1 addition & 6 deletions src/main/java/io/anserini/rerank/RerankerCascade.java
@@ -8,13 +8,8 @@
* Representation of a cascade of rerankers, applied in sequence.
*/
public class RerankerCascade {
final RerankerContext context;
final List<Reranker> rerankers = Lists.newArrayList();

public RerankerCascade(RerankerContext context) {
this.context = context;
}

/**
* Adds a reranker to this cascade.
*
@@ -33,7 +28,7 @@ public RerankerCascade add(Reranker reranker) {
* @param docs input documents
* @return reranked results
*/
public ScoredDocuments run(ScoredDocuments docs) {
public ScoredDocuments run(ScoredDocuments docs, RerankerContext context) {
ScoredDocuments results = docs;

for (Reranker reranker : rerankers) {