Dense retrieval draft #278

Merged: 14 commits, Jan 4, 2021

96 changes: 96 additions & 0 deletions docs/dense-retrieval.md
@@ -0,0 +1,96 @@
# Dense Retrieval Replication

Please follow the [development installation](https://github.com/castorini/pyserini#development-installation) instructions to set up `pyserini`.
Once that's done, it's easy to replicate our dense retrieval runs!

## MS MARCO Passage Ranking

MS MARCO passage ranking task, dense retrieval with TCT-ColBERT over an HNSW (approximate nearest-neighbor) index:
```bash
$ python -m pyserini.dsearch --topics msmarco_passage_dev_subset \
--index msmarco-passage-tct_colbert-hnsw \
--encoded-queries msmarco-passage-dev-subset-tct_colbert \
--output runs/run.msmarco-passage.tct_colbert.hnsw.tsv \
--msmarco
```
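With `--msmarco`, each line of the run file is a tab-separated `qid`, `docid`, `rank` triple (this is what the output loop in `pyserini/dsearch/__main__.py` below writes); the ids shown here are illustrative only:
```
1048585	7187158	1
1048585	7187157	2
1048585	7187163	3
```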

To evaluate, using the official MS MARCO evaluation script:
```bash
$ wget https://www.dropbox.com/s/khsplt2fhqwjs0v/qrels.dev.small.tsv -P collections/msmarco-passage/
$ python tools/scripts/msmarco/msmarco_eval.py collections/msmarco-passage/qrels.dev.small.tsv runs/run.msmarco-passage.tct_colbert.hnsw.tsv
```
```
#####################
MRR @10: 0.33395142584254184
QueriesRanked: 6980
#####################
```
We can also use the official TREC evaluation tool, `trec_eval`, to compute metrics other than MRR@10.
For that we first need to convert the run and qrels files to TREC format:
```bash
python tools/scripts/msmarco/convert_msmarco_to_trec_run.py \
--input runs/run.msmarco-passage.tct_colbert.hnsw.tsv \
--output runs/run.msmarco-passage.tct_colbert.hnsw.trec

python tools/scripts/msmarco/convert_msmarco_to_trec_qrels.py \
--input collections/msmarco-passage/qrels.dev.small.tsv \
--output collections/msmarco-passage/qrels.dev.small.trec
```
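For intuition, here is a minimal Python sketch of what the run-file conversion amounts to; this is an assumption about the script's behavior, not its actual code (in particular, since MS MARCO runs carry no scores, we synthesize one from the rank):
```python
# Hypothetical sketch: convert 'qid<TAB>docid<TAB>rank' MS MARCO run lines
# into the six-column TREC run format expected by trec_eval.
with open('runs/run.msmarco-passage.tct_colbert.hnsw.tsv') as f_in, \
        open('runs/run.msmarco-passage.tct_colbert.hnsw.trec', 'w') as f_out:
    for line in f_in:
        qid, docid, rank = line.strip().split('\t')
        # trec_eval needs a score column; 1/rank preserves the ordering (assumed convention)
        f_out.write(f'{qid} Q0 {docid} {rank} {1.0 / int(rank):.6f} pyserini\n')
```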
And run the trec_eval tool:
```bash
tools/eval/trec_eval.9.0.4/trec_eval -c -mrecall.1000 -mmap \
collections/msmarco-passage/qrels.dev.small.trec runs/run.msmarco-passage.tct_colbert.hnsw.trec
```
```
map all 0.3407
recall_1000 all 0.9618
```

MS MARCO passage ranking task, dense retrieval with TCT-ColBERT over a brute-force flat index. Brute-force search is exhaustive and exact, which is why it scores slightly higher than the approximate HNSW index above but runs slower; `--batch 12` searches batches of queries in parallel to compensate.

```bash
$ python -m pyserini.dsearch --topics msmarco_passage_dev_subset \
--index msmarco-passage-tct_colbert-bf \
--encoded-queries msmarco-passage-dev-subset-tct_colbert \
--batch 12 \
--output runs/run.msmarco-passage.tct_colbert.bf.tsv \
--msmarco
```

To evaluate, using the official MS MARCO evaluation script:
```bash
$ wget https://www.dropbox.com/s/khsplt2fhqwjs0v/qrels.dev.small.tsv -P collections/msmarco-passage/
$ python tools/scripts/msmarco/msmarco_eval.py collections/msmarco-passage/qrels.dev.small.tsv runs/run.msmarco-passage.tct_colbert.bf.tsv
```
```
#####################
MRR @10: 0.3344603629417369
QueriesRanked: 6980
#####################
```

We can also use the official TREC evaluation tool, `trec_eval`, to compute metrics other than MRR@10.
For that we first need to convert the run and qrels files to TREC format:
```bash
python tools/scripts/msmarco/convert_msmarco_to_trec_run.py \
--input runs/run.msmarco-passage.tct_colbert.bf.tsv \
--output runs/run.msmarco-passage.tct_colbert.bf.trec

python tools/scripts/msmarco/convert_msmarco_to_trec_qrels.py \
--input collections/msmarco-passage/qrels.dev.small.tsv \
--output collections/msmarco-passage/qrels.dev.small.trec
```
And run the trec_eval tool:
```bash
tools/eval/trec_eval.9.0.4/trec_eval -c -mrecall.1000 -mmap \
collections/msmarco-passage/qrels.dev.small.trec runs/run.msmarco-passage.tct_colbert.bf.trec
```

```
map all 0.3412
recall_1000 all 0.9637
```
1 change: 1 addition & 0 deletions docs/pypi-replication.md
@@ -56,6 +56,7 @@ map all 0.2805
recall_1000 all 0.9470
```


## MS MARCO Document Ranking

MS MARCO document ranking task, BM25 baseline:
19 changes: 19 additions & 0 deletions pyserini/dsearch/__init__.py
@@ -0,0 +1,19 @@
#
# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from ._dsearcher import SimpleDenseSearcher, QueryEncoder

__all__ = ['SimpleDenseSearcher', 'QueryEncoder']
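For reference, a minimal sketch of using these classes from Python, mirroring what `pyserini/dsearch/__main__.py` below does (the index and query names come from the docs above; the example query is one of the MS MARCO dev queries):
```python
from pyserini.dsearch import QueryEncoder, SimpleDenseSearcher

# load pre-encoded dev queries and a prebuilt brute-force index
query_encoder = QueryEncoder.load_encoded_queries('msmarco-passage-dev-subset-tct_colbert')
searcher = SimpleDenseSearcher.from_prebuilt_index('msmarco-passage-tct_colbert-bf')

# look up the query's pre-computed embedding and retrieve the top 10 passages
hits = searcher.search(query_encoder.encode("what is paula deen's brother"), 10)
for i, hit in enumerate(hits):
    print(f'{i + 1:2} {hit.docid:7} {hit.score:.5f}')
```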
101 changes: 101 additions & 0 deletions pyserini/dsearch/__main__.py
@@ -0,0 +1,101 @@
#
# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os

import numpy as np
from tqdm import tqdm

from pyserini.dsearch import QueryEncoder, SimpleDenseSearcher
from pyserini.search import get_topics

parser = argparse.ArgumentParser(description='Search a Faiss index.')
parser.add_argument('--index', type=str, metavar='path to index or index name', required=True,
                    help="Path to Faiss index or name of prebuilt index.")
parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
                    help="Name of topics. Available: msmarco_passage_dev_subset.")
parser.add_argument('--encoded-queries', type=str, metavar='path to query embedding or query name', required=True,
                    help="Path to query embeddings or name of pre-encoded queries.")
parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.")
parser.add_argument('--batch', type=int, metavar='num', required=False, default=1,
                    help="Search a batch of queries in parallel.")
parser.add_argument('--msmarco', action='store_true', default=False, help="Output in MS MARCO format.")
parser.add_argument('--output', type=str, metavar='path', required=True, help="Path to output file.")
args = parser.parse_args()

topics = get_topics(args.topics)

if os.path.exists(args.encoded_queries):
    # create query encoder from a query embedding directory
    query_encoder = QueryEncoder(args.encoded_queries)
else:
    # create query encoder from the name of a set of pre-encoded queries
    query_encoder = QueryEncoder.load_encoded_queries(args.encoded_queries)

if not query_encoder:
    exit()

if os.path.exists(args.index):
    # create searcher from an index directory
    searcher = SimpleDenseSearcher(args.index)
else:
    # create searcher from a prebuilt index name
    searcher = SimpleDenseSearcher.from_prebuilt_index(args.index)

if not searcher:
    exit()

# exit if the topics name is invalid
if topics == {}:
    print(f'Topic {args.topics} Not Found')
    exit()

# build output path
output_path = args.output

print(f'Running {args.topics} topics, saving to {output_path}...')
tag = 'Faiss'

if args.batch > 1:
    with open(output_path, 'w') as target_file:
        topic_keys = sorted(topics.keys())
        for i in tqdm(range(0, len(topic_keys), args.batch)):
            topic_key_batch = topic_keys[i: i + args.batch]
            topic_emb_batch = np.array([query_encoder.encode(topics[topic].get('title').strip())
                                        for topic in topic_key_batch])

Member: I got lost here reading the code... the query encodings are pre-stored somewhere, right? Where are they loaded? Should we also have something like `.load_encoded_queries()` to load pre-encoded queries?

Member (Author): There is a line `query_encoder = QueryEncoder(searcher.index_dir)` where I load the pre-encoded queries from the index directory. Right now I package the pre-encoded queries with the prebuilt index, i.e., the registered `msmarco-passage-tct_colbert.tar.gz` contains the index + docids + query text + query embeddings. Will we always have pre-encoded queries for a prebuilt index? If so, we could keep packing the prebuilt index and pre-encoded queries together, as I'm doing now.

Member: I see. I think the pre-encoded queries should be separate, because an index can have multiple query sets - for example, for MS MARCO there are dev queries and test queries. Something like `.load_encoded_queries()` would be clearer.

Member (Author): Sure, sounds good.

            hits = searcher.batch_search(topic_emb_batch, topic_key_batch, k=args.hits, threads=args.batch)
            for topic in hits:
                for idx, hit in enumerate(hits[topic]):
                    if args.msmarco:
                        target_file.write(f'{topic}\t{hit.docid}\t{idx + 1}\n')
                    else:
                        target_file.write(f'{topic} Q0 {hit.docid} {idx + 1} {hit.score:.6f} {tag}\n')
    exit()

with open(output_path, 'w') as target_file:
    for topic in tqdm(sorted(topics.keys())):
        search = topics[topic].get('title').strip()
        hits = searcher.search(query_encoder.encode(search), args.hits)
        docids = [hit.docid.strip() for hit in hits]
        scores = [hit.score for hit in hits]

        if args.msmarco:
            for i, docid in enumerate(docids):
                target_file.write(f'{topic}\t{docid}\t{i + 1}\n')
        else:
            for i, (docid, score) in enumerate(zip(docids, scores)):
                target_file.write(f'{topic} Q0 {docid} {i + 1} {score:.6f} {tag}\n')
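Without `--msmarco`, the same loops emit standard six-column TREC run lines, `qid Q0 docid rank score tag`, which `trec_eval` can consume directly; for example (illustrative values):
```
1048585 Q0 7187158 1 70.532303 Faiss
1048585 Q0 7187157 2 69.919403 Faiss
```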