Skip to content
Permalink
Browse files

Java rewrite of retrieve.py for fetching MS MARCO passages (#644)

  • Loading branch information...
edwardhdlu authored and lintool committed May 15, 2019
1 parent 32e0191 commit 949cd3940d8795058270ec365f842498a5c5df49
@@ -60,6 +60,13 @@ python ./src/main/python/msmarco/retrieve.py --index ${DATA_DIR}/lucene-index-ms
Retrieval speed will vary by machine:
On a modern desktop with an SSD, we can get ~0.04 per query (taking about five minutes).
On a slower machine with mechanical disks, the entire process might take as long as a couple of hours.
Alternatively, we can run the same script implemented in Java to remove Python overhead, which ends up being ~4x faster.

```
./target/appassembler/bin/SearchMsmarco -index ${DATA_DIR}/lucene-index-msmarco \
-qid_queries ${DATA_DIR}/queries.dev.small.tsv -output ${DATA_DIR}/run.dev.small.tsv -hits 1000
```

The option `-hits` specifies the of documents per query to be retrieved.
Thus, the output file should have approximately 6980 * 1000 = 6.9M lines.

@@ -105,6 +105,10 @@
<mainClass>io.anserini.search.SearchCollection</mainClass>
<id>SearchCollection</id>
</program>
<program>
<mainClass>io.anserini.search.SearchMsmarco</mainClass>
<id>SearchMsmarco</id>
</program>
<program>
<mainClass>io.anserini.eval.Eval</mainClass>
<id>Eval</id>
@@ -0,0 +1,39 @@
package io.anserini.search;

import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.spi.StringArrayOptionHandler;

public class RetrieveArgs {
// required arguments
@Option(name = "-qid_queries", metaVar = "[file]", required = true, usage="query id - query mapping file")
public String qid_queries = "";

@Option(name = "-output", metaVar = "[file]", required = true, usage = "output file")
public String output = "";

@Option(name = "-index", metaVar = "[path]", required = true, usage = "index path")
public String index = "";

// optional arguments
@Option(name = "-hits", metaVar = "[number]", usage = "number of hits to retrieve")
public int hits = 10;

@Option(name = "-k1", metaVar = "[value]", usage = "BM25 k1 parameter")
public float k1 = 0.82f;

@Option(name = "-b", metaVar = "[value]", usage = "BM25 b parameter")
public float b = 0.72f;

// See our MS MARCO documentation to understand how these parameter values were tuned.
@Option(name = "-rm3", usage = "use RM3 query expansion model")
public boolean rm3 = false;

@Option(name = "-fbTerms", metaVar = "[number]", usage = "RM3 parameter: number of expansion terms")
public int fbTerms = 10;

@Option(name = "-fbDocs", metaVar = "[number]", usage = "RM3 parameter: number of documents")
public int fbDocs = 10;

@Option(name = "-originalQueryWeight", metaVar = "[value]", usage = "RM3 parameter: weight to assign to the original query")
public float originalQueryWeight = 0.5f;
}
@@ -224,4 +224,7 @@

@Option(name = "-model", metaVar = "[file]", required = false, usage = "ranklib model file")
public String model = "";

@Option(name = "-qid_queries", metaVar = "[file]", usage="query id - query mapping file")
public String qid_queries = "";
}
@@ -0,0 +1,66 @@
package io.anserini.search;

import org.apache.commons.io.FileUtils;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.OptionHandlerFilter;
import org.kohsuke.args4j.ParserProperties;

import java.io.File;
import java.util.List;

/*
* Java rewrite of retrieve.py
*/
public class SearchMsmarco {
public static void main(String[] args) throws Exception {
RetrieveArgs retrieveArgs = new RetrieveArgs();
CmdLineParser parser = new CmdLineParser(retrieveArgs, ParserProperties.defaults().withUsageWidth(90));

try {
parser.parseArgument(args);
} catch (CmdLineException e) {
System.err.println(e.getMessage());
parser.printUsage(System.err);
System.err.println("Example: Eval " + parser.printExample(OptionHandlerFilter.REQUIRED));
return;
}

SimpleSearcher searcher = new SimpleSearcher(retrieveArgs.index);
searcher.setBM25Similarity(retrieveArgs.k1, retrieveArgs.b);
System.out.println("Initializing BM25, setting k1=" + retrieveArgs.k1 + " and b=" + retrieveArgs.b + "");

if (retrieveArgs.rm3) {
searcher.setRM3Reranker(retrieveArgs.fbTerms, retrieveArgs.fbDocs, retrieveArgs.originalQueryWeight);
System.out.println("Initializing RM3, setting fbTerms=" + retrieveArgs.fbTerms + ", fbDocs=" + retrieveArgs.fbDocs
+ " and originalQueryWeight=" + retrieveArgs.originalQueryWeight);
}

File fout = new File(retrieveArgs.output);
FileUtils.writeStringToFile(fout, "", "utf-8"); // clear the file

long startTime = System.nanoTime();
List<String> lines = FileUtils.readLines(new File(retrieveArgs.qid_queries), "utf-8");

for (int lineNumber = 0; lineNumber < lines.size(); ++lineNumber) {
String line = lines.get(lineNumber);
String[] split = line.trim().split("\t");
String qid = split[0];
String query = split[1];

SimpleSearcher.Result[] hits = searcher.search(query, retrieveArgs.hits);

if (lineNumber % 10 == 0) {
double timePerQuery = (double) (System.nanoTime() - startTime) / (lineNumber + 1) / 10e9;
System.out.format("Retrieving query " + lineNumber + " (%.3f s/query)\n", timePerQuery);
}

for (int rank = 0; rank < hits.length; ++rank) {
String docno = hits[rank].docid;
FileUtils.writeStringToFile(fout, qid + "\t" + docno + "\t" + (rank + 1) + "\n", "utf-8", true);
}
}

System.out.println("Done!");
}
}

0 comments on commit 949cd39

Please sign in to comment.
You can’t perform that action at this time.