Dense retrieval draft #278
Merged
Commits (14), all by MXueguang:

- 0677127 init commit for dense retrieval
- 9293d91 first draft of dsearcher implementation
- 09d0dab add QueryEncoder for dsearch
- a112edb add batch_search for SimpleDenseSearcher which will run faiss cpu mul…
- 7f9c272 move densesearch to dsearch
- 443c093 auto download prebuild faiss index
- 57b8cd4 update pypi replication for dense retrieval
- a7b9018 add pypi replication code for dense retrieval
- 587734a update dependency for dense retrieval
- 7cd184d Merge branch 'master' into dense_retrieval
- aa568de move dense retrieval doc to new doc file
- 2561cae fix typo in dense retrieval
- ea1eedc complete load pre encoded queires feature
- ac65688 update doc and change cli option name
# Dense Retrieval Replication

Please follow the development installation [here](https://github.com/castorini/pyserini#development-installation) to set up `pyserini`.
It's easy to replicate runs of dense retrieval experiments!

## MS MARCO Passage Ranking

MS MARCO passage ranking task, dense retrieval with TCT-ColBERT, HNSW index:

```bash
$ python -m pyserini.dsearch --topics msmarco_passage_dev_subset \
    --index msmarco-passage-tct_colbert-hnsw \
    --encoded-queries msmarco-passage-dev-subset-tct_colbert \
    --output runs/run.msmarco-passage.tct_colbert.hnsw.tsv \
    --msmarco
```

To evaluate, use the official MS MARCO evaluation script:

```bash
$ wget https://www.dropbox.com/s/khsplt2fhqwjs0v/qrels.dev.small.tsv -P collections/msmarco-passage/
$ python tools/scripts/msmarco/msmarco_eval.py collections/msmarco-passage/qrels.dev.small.tsv runs/run.msmarco-passage.tct_colbert.hnsw.tsv
```

```
#####################
MRR @10: 0.33395142584254184
QueriesRanked: 6980
#####################
```
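The MRR@10 reported above is the reciprocal rank of the first relevant passage in each query's top 10, averaged over all queries. A minimal stdlib-only sketch of the metric (the toy qrels/run dictionaries below are illustrative, not the MS MARCO files):

```python
def mrr_at_10(qrels, run):
    """Mean reciprocal rank at cutoff 10.

    qrels: dict mapping qid -> set of relevant docids
    run:   dict mapping qid -> list of docids, ranked best-first
    """
    total = 0.0
    for qid, relevant in qrels.items():
        rr = 0.0
        for rank, docid in enumerate(run.get(qid, [])[:10], start=1):
            if docid in relevant:
                rr = 1.0 / rank  # reciprocal rank of the first relevant hit
                break
        total += rr
    return total / len(qrels)

# Toy example: q1's relevant doc sits at rank 2, q2's at rank 1.
qrels = {'q1': {'d5'}, 'q2': {'d9'}}
run = {'q1': ['d3', 'd5', 'd7'], 'q2': ['d9', 'd1']}
print(mrr_at_10(qrels, run))  # (1/2 + 1/1) / 2 = 0.75
```

The official script adds bookkeeping (file parsing, sanity checks on duplicate ranks), but the core computation is this average.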
We can also use the official TREC evaluation tool, `trec_eval`, to compute metrics other than MRR@10.
For that we first need to convert the run and qrels files to the TREC format:

```bash
$ python tools/scripts/msmarco/convert_msmarco_to_trec_run.py \
    --input runs/run.msmarco-passage.tct_colbert.hnsw.tsv \
    --output runs/run.msmarco-passage.tct_colbert.hnsw.trec

$ python tools/scripts/msmarco/convert_msmarco_to_trec_qrels.py \
    --input collections/msmarco-passage/qrels.dev.small.tsv \
    --output collections/msmarco-passage/qrels.dev.small.trec
```

Then run the `trec_eval` tool:

```bash
$ tools/eval/trec_eval.9.0.4/trec_eval -c -mrecall.1000 -mmap \
    collections/msmarco-passage/qrels.dev.small.trec runs/run.msmarco-passage.tct_colbert.hnsw.trec
```

```
map                     all     0.3407
recall_1000             all     0.9618
```
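The run conversion above just rewrites each `qid<TAB>docid<TAB>rank` line of the MS MARCO run into the six-column TREC run format (`qid Q0 docid rank score tag`). A sketch of the idea; the `convert_run_line` helper, the tag, and the synthesized `1/rank` score are illustrative assumptions, since MS MARCO runs carry no scores (the real converter lives in `tools/scripts/msmarco/`):

```python
def convert_run_line(line, tag='tct_colbert'):
    # MS MARCO run lines are "qid<TAB>docid<TAB>rank"; TREC run lines are
    # "qid Q0 docid rank score tag". We synthesize a score that decreases
    # with rank, which is all trec_eval needs to reproduce the ordering.
    qid, docid, rank = line.rstrip('\n').split('\t')
    score = 1.0 / int(rank)  # assumption: any rank-consistent score works
    return f'{qid} Q0 {docid} {rank} {score:.6f} {tag}'

print(convert_run_line('188714\t1000052\t1'))
# 188714 Q0 1000052 1 1.000000 tct_colbert
```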
MS MARCO passage ranking task, dense retrieval with TCT-ColBERT, brute-force index:

```bash
$ python -m pyserini.dsearch --topics msmarco_passage_dev_subset \
    --index msmarco-passage-tct_colbert-bf \
    --encoded-queries msmarco-passage-dev-subset-tct_colbert \
    --batch 12 \
    --output runs/run.msmarco-passage.tct_colbert.bf.tsv \
    --msmarco
```

To evaluate, use the official MS MARCO evaluation script:

```bash
$ wget https://www.dropbox.com/s/khsplt2fhqwjs0v/qrels.dev.small.tsv -P collections/msmarco-passage/
$ python tools/scripts/msmarco/msmarco_eval.py collections/msmarco-passage/qrels.dev.small.tsv runs/run.msmarco-passage.tct_colbert.bf.tsv
```

```
#####################
MRR @10: 0.3344603629417369
QueriesRanked: 6980
#####################
```

We can also use the official TREC evaluation tool, `trec_eval`, to compute metrics other than MRR@10.
For that we first need to convert the run and qrels files to the TREC format:

```bash
$ python tools/scripts/msmarco/convert_msmarco_to_trec_run.py \
    --input runs/run.msmarco-passage.tct_colbert.bf.tsv \
    --output runs/run.msmarco-passage.tct_colbert.bf.trec

$ python tools/scripts/msmarco/convert_msmarco_to_trec_qrels.py \
    --input collections/msmarco-passage/qrels.dev.small.tsv \
    --output collections/msmarco-passage/qrels.dev.small.trec
```

Then run the `trec_eval` tool:

```bash
$ tools/eval/trec_eval.9.0.4/trec_eval -c -mrecall.1000 -mmap \
    collections/msmarco-passage/qrels.dev.small.trec runs/run.msmarco-passage.tct_colbert.bf.trec
```

```
map                     all     0.3412
recall_1000             all     0.9637
```
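The difference between the two indexes above: the HNSW index performs approximate nearest-neighbor search, while the brute-force (`bf`) index scores the query embedding against every document embedding by inner product and keeps the top k, which is why its numbers are slightly higher. A stdlib-only sketch of the brute-force idea, with toy 3-dimensional embeddings standing in for the 768-dimensional vectors Faiss handles at scale:

```python
import heapq

def brute_force_search(query_emb, doc_embs, k=2):
    """Exhaustive inner-product search: score every document, keep the top k."""
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))
    scored = ((dot(query_emb, emb), docid) for docid, emb in doc_embs.items())
    return heapq.nlargest(k, scored)  # list of (score, docid), best first

docs = {
    'd1': [0.1, 0.9, 0.0],
    'd2': [0.8, 0.1, 0.1],
    'd3': [0.7, 0.7, 0.0],
}
print(brute_force_search([1.0, 1.0, 0.0], docs))  # d3 scores highest, then d1
```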
New file: `pyserini/dsearch/__init__.py`
```python
#
# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from ._dsearcher import SimpleDenseSearcher, QueryEncoder

__all__ = ['SimpleDenseSearcher', 'QueryEncoder']
```
New file: `pyserini/dsearch/__main__.py`
```python
#
# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os

import numpy as np
from tqdm import tqdm

from pyserini.dsearch import QueryEncoder, SimpleDenseSearcher
from pyserini.search import get_topics

parser = argparse.ArgumentParser(description='Search a Faiss index.')
parser.add_argument('--index', type=str, metavar='path to index or index name', required=True,
                    help="Path to Faiss index or name of prebuilt index.")
parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
                    help="Name of topics. Available: msmarco_passage_dev_subset.")
parser.add_argument('--encoded-queries', type=str, metavar='path to query embedding or query name', required=True,
                    help="Path to query embeddings or name of pre-encoded queries.")
parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.")
parser.add_argument('--batch', type=int, metavar='num', required=False, default=1,
                    help="Search batches of queries in parallel.")
parser.add_argument('--msmarco', action='store_true', default=False, help="Output in MS MARCO format.")
parser.add_argument('--output', type=str, metavar='path', required=True, help="Path to output file.")
args = parser.parse_args()

topics = get_topics(args.topics)

if os.path.exists(args.encoded_queries):
    # create query encoder from a query embedding directory
    query_encoder = QueryEncoder(args.encoded_queries)
else:
    # create query encoder from a pre-encoded query name
    query_encoder = QueryEncoder.load_encoded_queries(args.encoded_queries)

if not query_encoder:
    exit()

if os.path.exists(args.index):
    # create searcher from an index directory
    searcher = SimpleDenseSearcher(args.index)
else:
    # create searcher from a prebuilt index name
    searcher = SimpleDenseSearcher.from_prebuilt_index(args.index)

if not searcher:
    exit()

# invalid topics name
if topics == {}:
    print(f'Topic {args.topics} Not Found')
    exit()

# build output path
output_path = args.output

print(f'Running {args.topics} topics, saving to {output_path}...')
tag = 'Faiss'

if args.batch > 1:
    with open(output_path, 'w') as target_file:
        topic_keys = sorted(topics.keys())
        for i in tqdm(range(0, len(topic_keys), args.batch)):
            topic_key_batch = topic_keys[i: i + args.batch]
            topic_emb_batch = np.array([query_encoder.encode(topics[topic].get('title').strip())
                                        for topic in topic_key_batch])
            hits = searcher.batch_search(topic_emb_batch, topic_key_batch, k=args.hits, threads=args.batch)
            for topic in hits:
                for idx, hit in enumerate(hits[topic]):
                    if args.msmarco:
                        target_file.write(f'{topic}\t{hit.docid}\t{idx + 1}\n')
                    else:
                        target_file.write(f'{topic} Q0 {hit.docid} {idx + 1} {hit.score:.6f} {tag}\n')
    exit()

with open(output_path, 'w') as target_file:
    for index, topic in enumerate(tqdm(sorted(topics.keys()))):
        search = topics[topic].get('title').strip()
        hits = searcher.search(query_encoder.encode(search), args.hits)
        docids = [hit.docid.strip() for hit in hits]
        scores = [hit.score for hit in hits]

        if args.msmarco:
            for i, docid in enumerate(docids):
                target_file.write(f'{topic}\t{docid}\t{i + 1}\n')
        else:
            for i, (docid, score) in enumerate(zip(docids, scores)):
                target_file.write(f'{topic} Q0 {docid} {i + 1} {score:.6f} {tag}\n')
```
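The `--batch` path above partitions the sorted topic ids into fixed-size chunks and hands each chunk to `batch_search` at once, so Faiss can score many queries in parallel. The chunking itself is plain slicing; a stdlib sketch (names are illustrative):

```python
def chunk(keys, batch_size):
    """Split a list of topic ids into consecutive batches of at most batch_size."""
    return [keys[i: i + batch_size] for i in range(0, len(keys), batch_size)]

topic_keys = ['q1', 'q2', 'q3', 'q4', 'q5']
print(chunk(topic_keys, 2))  # [['q1', 'q2'], ['q3', 'q4'], ['q5']]
```

Note that the last batch may be smaller than `batch_size`, which is why the script slices with `topic_keys[i: i + args.batch]` rather than assuming full batches.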
Review discussion:

Reviewer: I got lost here reading the code... the query encodings are pre-stored somewhere, right? Where are they loaded? Should we also have something like `.load_encoded_queries()` to load pre-encoded queries?

MXueguang: There is a line `query_encoder = QueryEncoder(searcher.index_dir)` where I load the pre-encoded queries from the index directory. Right now I place the pre-encoded queries with the prebuilt index, i.e. the registered `msmarco-passage-tct_colbert.tar.gz` contains the index plus docids, query text, and query embeddings. Will we always have pre-encoded queries for a prebuilt index? If so, I feel we can pack the prebuilt index and pre-encoded queries together, like what I am doing right now.

Reviewer: I see. I think the pre-encoded queries should be separate, because an index can have multiple query sets - for example, for MS MARCO, there are dev queries and test queries. I think something like `.load_encoded_queries()` would be clearer.

MXueguang: Sure, sounds good.
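The `.load_encoded_queries()` design the discussion settles on amounts to: resolve a name to a table of precomputed embeddings, then serve `encode(query_text)` as a dictionary lookup instead of running the encoder model. A stdlib-only sketch of that lookup idea; the file format here (a TSV of query text and space-separated floats) and the class name are assumptions for illustration, not what Pyserini actually ships:

```python
import csv
import io

class PreEncodedQueryLookup:
    """Serve query embeddings from a precomputed table instead of a model.

    Assumes a TSV with one row per query: the query text, then its embedding
    as space-separated floats. This format is illustrative only.
    """

    def __init__(self, tsv_file):
        self.embeddings = {}
        for text, vector in csv.reader(tsv_file, delimiter='\t'):
            self.embeddings[text] = [float(x) for x in vector.split()]

    def encode(self, query):
        # No model inference: just look up the precomputed vector.
        return self.embeddings[query]

data = io.StringIO('what is dense retrieval\t0.1 0.2 0.3\n')
encoder = PreEncodedQueryLookup(data)
print(encoder.encode('what is dense retrieval'))  # [0.1, 0.2, 0.3]
```

Keeping this lookup separate from the index, as the reviewer suggests, lets one prebuilt index pair with several query sets (dev, test) without repacking the index archive.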