Repo layout (adding logs, models, indexes, runs), MS MARCO passage replication doc (#24)

* create docs, indexes, logs, models, runs and add MS Marco official eval
* remove index-dir from settings for added clarity
* index-dir bf2
* index-dir bf3
* index-dir bf4
* done doc
* fix typos
* fix typos 2
* fix heading
* Update docs/experiments-msmarco-passage.md
* Update experiments-msmarco-passage.md (resolve comments)

Co-authored-by: Rodrigo Frassetto Nogueira <rodrigonogueira4@gmail.com>

Commit 69de7db (1 parent: 2905235)
Showing 8 changed files with 357 additions and 9 deletions.

docs/experiments-msmarco-passage.md
@@ -0,0 +1,162 @@
# PyGaggle: Neural Ranking Baselines on [MS MARCO Passage Retrieval](https://github.com/microsoft/MSMARCO-Passage-Ranking)

This page contains instructions for running various neural reranking baselines on the MS MARCO *passage* ranking task.
Note that there is also a separate [MS MARCO *document* ranking task](https://github.com/castorini/anserini/blob/master/docs/experiments-msmarco-doc.md).

Prior to running this, we suggest looking at our first-stage [BM25 ranking instructions](https://github.com/castorini/anserini/blob/master/docs/experiments-msmarco-passage.md).
We rerank the BM25 run files, which contain ~1000 passages per query, using both monoBERT and monoT5.
monoBERT and monoT5 are pointwise rerankers: each passage is scored independently, using either BERT or T5 respectively.
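
To make "pointwise" concrete, here is a minimal illustrative sketch of that idea with Hugging Face Transformers, scoring each passage on its own with a BERT-style sequence-classification model. This is not the pygaggle code path used below; it assumes the monoBERT checkpoint downloaded later in this guide can be loaded as a standard sequence-classification model and that label index 1 corresponds to "relevant".

```
# Illustrative pointwise-scoring sketch (not the pygaggle code path used below).
# Assumes the monoBERT checkpoint unzipped later in this guide loads as a standard
# Hugging Face sequence-classification model, with label 1 meaning "relevant".
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_path = "models/monobert_msmarco_large"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path).eval()

def score(query: str, passage: str) -> float:
    # Each (query, passage) pair is scored on its own; no other passages are involved.
    inputs = tokenizer(query, passage, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.log_softmax(logits, dim=-1)[0, 1].item()

passages = ["first candidate passage ...", "second candidate passage ..."]
reranked = sorted(passages, key=lambda p: score("example query", p), reverse=True)
```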

Since it can take many hours to run these models on all 6980 queries from the MS MARCO dev set, we will instead use a subset of 105 queries randomly sampled from the dev set.
Running these instructions on the entire MS MARCO dev set should give roughly the same results as those reported in the corresponding papers.

Note 1: Run the following instructions from the root of this repo.
Note 2: PyGaggle must have been installed from source.
Note 3: Make sure that you have access to a GPU.

## Models

+ monoBERT-Large: Passage Re-ranking with BERT [(Nogueira et al., 2019)](https://arxiv.org/pdf/1901.04085.pdf)
+ monoT5-base: Document Ranking with a Pretrained Sequence-to-Sequence Model [(Nogueira et al., 2020)](https://arxiv.org/pdf/2003.06713.pdf)

## Data Prep

We first download the queries, qrels, and run file corresponding to the MS MARCO subset considered. The run file is generated by following the BM25 ranking instructions. We'll store all these files in the `data` directory.

```
wget https://www.dropbox.com/s/5xa5vjbjle0c8jv/msmarco_ans_small.zip -P data
```

To confirm, `msmarco_ans_small.zip` should have an MD5 checksum of `65d8007bfb2c72b5fc384738e5572f74`.
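
If you prefer to verify the download from Python, a minimal sketch using only the standard library:

```
# Compute the MD5 checksum of the downloaded archive and compare it against
# the expected value given above.
import hashlib

expected = "65d8007bfb2c72b5fc384738e5572f74"

md5 = hashlib.md5()
with open("data/msmarco_ans_small.zip", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        md5.update(chunk)

assert md5.hexdigest() == expected, f"checksum mismatch: {md5.hexdigest()}"
print("checksum OK")
```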

Next, we extract the contents into `data`.

```
unzip msmarco_ans_small.zip -d data
```

As a sanity check, we can evaluate the first-stage retrieved documents using the official MS MARCO evaluation script.

```
python evaluate/msmarco/msmarco_eval.py data/msmarco_ans_small/qrels.dev.small.tsv data/msmarco_ans_small/run.dev.small.tsv
```

The output should be:

```
#####################
MRR @10: 0.15906651549508694
QueriesRanked: 105
#####################
```
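
For reference, the `MRR @10` figure reported here is the official MS MARCO passage metric: for each query, take the reciprocal rank of the highest-ranked relevant passage if it appears in the top 10 (zero otherwise), then average over the evaluated queries. In symbols:

```
\mathrm{MRR@10} = \frac{1}{|Q|} \sum_{q \in Q}
  \begin{cases}
    1 / \mathrm{rank}_q & \text{if } \mathrm{rank}_q \le 10 \\
    0                   & \text{otherwise}
  \end{cases}
```

where Q is the set of evaluated queries and rank_q is the position of the first relevant passage for query q. This matches the evaluation script included later in this commit, which caps the rank at `MaxMRRRank = 10`.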

Let's download and extract the pre-built MS MARCO index into `indexes`:

```
wget https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-msmarco-passage-20191117-0ed488.tar.gz -P indexes
tar xvfz indexes/index-msmarco-passage-20191117-0ed488.tar.gz -C indexes
```
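
A note on why the index is needed: the run files contain only passage ids, and this Anserini index is what the re-ranking step uses to resolve those ids to passage text. Purely as an illustration (not part of the replication steps), here is a hedged sketch of fetching one passage's raw text with Pyserini; the import path differs across Pyserini versions, and the passage id below is an arbitrary placeholder.

```
# Hedged sketch (assumes Pyserini is installed): fetch the raw text of one passage
# from the pre-built index. The SimpleSearcher import path has moved across Pyserini
# releases; adjust it to match the version you have installed.
from pyserini.search import SimpleSearcher

searcher = SimpleSearcher("indexes/index-msmarco-passage-20191117-0ed488")
doc = searcher.doc("0")  # "0" is an arbitrary placeholder passage id
if doc is not None:
    print(doc.raw())
```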

## Model Prep

Let's download and extract monoBERT into `models`:

```
wget https://www.dropbox.com/s/jr0hpksboh7pa48/monobert_msmarco_large.zip -P models
unzip models/monobert_msmarco_large.zip -d models
```

The monoT5 model, in contrast, is downloaded automatically from Google Cloud Storage when the re-ranking script is run.

Now, we can begin re-ranking the set.

## Re-Ranking with monoBERT

First, let's evaluate using monoBERT!

```
python -um pygaggle.run.evaluate_passage_ranker --split dev \
                                                --method seq_class_transformer \
                                                --model-name-or-path models/monobert_msmarco_large \
                                                --data-dir data/msmarco_ans_small/ \
                                                --index-dir indexes/index-msmarco-passage-20191117-0ed488 \
                                                --dataset msmarco \
                                                --output-file runs/run.monobert.ans_small.dev.tsv
```

Upon completion, the following output will be visible:

```
precision@1 0.2761904761904762
recall@3 0.42698412698412697
recall@50 0.8174603174603176
recall@1000 0.8476190476190476
mrr 0.41089693612003686
mrr@10 0.4026795162509449
```

It takes about 52 minutes to re-rank this subset of MS MARCO using a P100.
The type of GPU will directly influence your inference time.
It is possible that the default batch size results in a GPU out-of-memory error.
In this case, assigning a batch size (using the `--batch-size` option) smaller than the default (96) should help!

The re-ranked run file `run.monobert.ans_small.dev.tsv` will also be available in the `runs` directory upon completion.
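
The run file follows the MS MARCO run format expected by the evaluation script: one `query_id<TAB>passage_id<TAB>rank` triple per line. As a quick sanity check, a small snippet like the following prints the top-ranked passage id for a few queries:

```
# Quick sanity check of the re-ranked run file: each line is
# query_id <TAB> passage_id <TAB> rank, so collect the rank-1 passage per query.
top1 = {}
with open("runs/run.monobert.ans_small.dev.tsv") as f:
    for line in f:
        qid, pid, rank = line.strip().split("\t")[:3]
        if int(rank) == 1:
            top1[qid] = pid

for qid, pid in list(top1.items())[:5]:
    print(f"query {qid}: top passage {pid}")
```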

We can use the official MS MARCO evaluation script to verify the MRR@10:

```
python evaluate/msmarco/msmarco_eval.py data/msmarco_ans_small/qrels.dev.small.tsv runs/run.monobert.ans_small.dev.tsv
```

You should see the same result. Great, let's move on to monoT5!

## Re-Ranking with monoT5

We use the monoT5-base variant as it is the easiest to run without access to larger GPUs/TPUs. Let us now re-rank the set:

```
python -um pygaggle.run.evaluate_passage_ranker --split dev \
                                                --method t5 \
                                                --model-name-or-path gs://neuralresearcher_data/doc2query/experiments/367 \
                                                --data-dir data/msmarco_ans_small \
                                                --model-type t5-base \
                                                --dataset msmarco \
                                                --index-dir indexes/index-msmarco-passage-20191117-0ed488 \
                                                --batch-size 32 \
                                                --output-file runs/run.monot5.ans_small.dev.tsv
```
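
For context on what this command computes: monoT5 casts reranking as text generation. Following the monoT5 paper, each query-passage pair is formatted as `Query: ... Document: ... Relevant:` and the relevance score is the probability of generating "true" (versus "false") as the next token. Below is a minimal illustrative sketch of that scoring rule with Hugging Face Transformers, using the stock `t5-base` checkpoint as a stand-in for the fine-tuned monoT5 weights (so the scores themselves are meaningless; it only shows the mechanism).

```
# Illustrative monoT5-style scoring sketch (not the pygaggle code path).
# The stock "t5-base" checkpoint stands in for the fine-tuned monoT5 weights.
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base").eval()

TRUE_ID = tokenizer.encode("true")[0]
FALSE_ID = tokenizer.encode("false")[0]

def score(query: str, passage: str) -> float:
    prompt = f"Query: {query} Document: {passage} Relevant:"
    inputs = tokenizer(prompt, truncation=True, max_length=512, return_tensors="pt")
    # Run a single decoding step and read the logits of the "true"/"false" tokens.
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
    with torch.no_grad():
        logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits[0, -1]
    pair = torch.log_softmax(logits[[FALSE_ID, TRUE_ID]], dim=0)
    return pair[1].item()  # log P("true") after normalizing over {"false", "true"}
```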

The following output will be visible after it has finished:

```
precision@1 0.26666666666666666
recall@3 0.4603174603174603
recall@50 0.8063492063492063
recall@1000 0.8476190476190476
mrr 0.3973368360121561
mrr@10 0.39044217687074834
```

It takes about 13 minutes to re-rank this subset of MS MARCO using a P100.
It is worth noting again that you might need to modify the batch size to best fit the GPU at hand.

Upon completion, the re-ranked run file `run.monot5.ans_small.dev.tsv` will be available in the `runs` directory.

We can use the official MS MARCO evaluation script to verify the MRR@10:

```
python evaluate/msmarco/msmarco_eval.py data/msmarco_ans_small/qrels.dev.small.tsv runs/run.monot5.ans_small.dev.tsv
```

You should see the same result.

If you were able to replicate any of these results, please submit a PR adding to the replication log!

## Replication Log

### monoBERT
+

### monoT5
+

evaluate/msmarco/msmarco_eval.py
@@ -0,0 +1,183 @@
""" | ||
This module computes evaluation metrics for MSMARCO dataset on the ranking task. | ||
Command line: | ||
python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file> | ||
Creation Date : 06/12/2018 | ||
Last Modified : 1/21/2019 | ||
Authors : Daniel Campos <dacamp@microsoft.com>, Rutger van Haasteren <ruvanh@microsoft.com> | ||
""" | ||
import sys | ||
import statistics | ||
|
||
from collections import Counter | ||
|
||
MaxMRRRank = 10 | ||

def load_reference_from_stream(f):
    """Load Reference reference relevant passages
    Args:f (stream): stream to load.
    Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
    """
    qids_to_relevant_passageids = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            if qid in qids_to_relevant_passageids:
                pass
            else:
                qids_to_relevant_passageids[qid] = []
            qids_to_relevant_passageids[qid].append(int(l[2]))
        except:
            raise IOError('\"%s\" is not valid format' % l)
    return qids_to_relevant_passageids

def load_reference(path_to_reference):
    """Load Reference reference relevant passages
    Args:path_to_reference (str): path to a file to load.
    Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
    """
    with open(path_to_reference, 'r') as f:
        qids_to_relevant_passageids = load_reference_from_stream(f)
    return qids_to_relevant_passageids

def load_candidate_from_stream(f):
    """Load candidate data from a stream.
    Args:f (stream): stream to load.
    Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance
    """
    qid_to_ranked_candidate_passages = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            pid = int(l[1])
            rank = int(l[2])
            if qid in qid_to_ranked_candidate_passages:
                pass
            else:
                # By default, all PIDs in the list of 1000 are 0. Only override those that are given
                tmp = [0] * 1000
                qid_to_ranked_candidate_passages[qid] = tmp
            qid_to_ranked_candidate_passages[qid][rank - 1] = pid
        except:
            raise IOError('\"%s\" is not valid format' % l)
    return qid_to_ranked_candidate_passages

def load_candidate(path_to_candidate):
    """Load candidate data from a file.
    Args:path_to_candidate (str): path to file to load.
    Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance
    """

    with open(path_to_candidate, 'r') as f:
        qid_to_ranked_candidate_passages = load_candidate_from_stream(f)
    return qid_to_ranked_candidate_passages

def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Perform quality checks on the dictionaries
    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        bool,str: Boolean whether allowed, message to be shown in case of a problem
    """
    message = ''
    allowed = True

    # Create sets of the QIDs for the submitted and reference queries
    candidate_set = set(qids_to_ranked_candidate_passages.keys())
    ref_set = set(qids_to_relevant_passageids.keys())

    # Check that we do not have multiple passages per query
    for qid in qids_to_ranked_candidate_passages:
        # Remove all zeros from the candidates
        duplicate_pids = set([item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1])

        if len(duplicate_pids - set([0])) > 0:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                qid=qid, pid=list(duplicate_pids)[0])
            allowed = False

    return allowed, message

def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR metric
    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        dict: dictionary of metrics {'MRR': <MRR Score>}
    """
    all_scores = {}
    MRR = 0
    qids_with_relevant_passages = 0
    ranking = []
    for qid in qids_to_ranked_candidate_passages:
        if qid in qids_to_relevant_passageids:
            ranking.append(0)
            target_pid = qids_to_relevant_passageids[qid]
            candidate_pid = qids_to_ranked_candidate_passages[qid]
            for i in range(0, MaxMRRRank):
                if candidate_pid[i] in target_pid:
                    MRR += 1 / (i + 1)
                    ranking.pop()
                    ranking.append(i + 1)
                    break
    if len(ranking) == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    MRR = MRR / len(qids_to_relevant_passageids)
    all_scores['MRR @10'] = MRR
    all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
    return all_scores

def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
    """Compute MRR metric
    Args:
    p_path_to_reference_file (str): path to reference file.
        Reference file should contain lines in the following format:
            QUERYID\tPASSAGEID
            Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs
    p_path_to_candidate_file (str): path to candidate file.
        Candidate file should contain lines in the following format:
            QUERYID\tPASSAGEID1\tRank
            If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is
            QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
            Where the values are separated by tabs and ranked in order of relevance
    Returns:
        dict: dictionary of metrics {'MRR': <MRR Score>}
    """

    qids_to_relevant_passageids = load_reference(path_to_reference)
    qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
    if perform_checks:
        allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
        if message != '': print(message)

    return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)

def main():
    """Command line:
    python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
    """

    if len(sys.argv) == 3:
        path_to_reference = sys.argv[1]
        path_to_candidate = sys.argv[2]
        metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
        print('#####################')
        for metric in sorted(metrics):
            print('{}: {}'.format(metric, metrics[metric]))
        print('#####################')

    else:
        print('Usage: msmarco_eval_ranking.py <reference ranking> <candidate ranking>')
        exit()

if __name__ == '__main__':
    main()
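
Besides the command-line entry point, the module can also be used programmatically, for example to verify a re-ranked run from another script or a notebook. A small sketch, assuming the `evaluate` package is importable from the repo root (you may need to add the repo root to `PYTHONPATH`):

```
# Programmatic use of the evaluation module (assumes the repo root is on PYTHONPATH
# so that evaluate.msmarco.msmarco_eval is importable).
from evaluate.msmarco.msmarco_eval import compute_metrics_from_files

metrics = compute_metrics_from_files(
    "data/msmarco_ans_small/qrels.dev.small.tsv",
    "runs/run.monobert.ans_small.dev.tsv",
)
print(metrics["MRR @10"], metrics["QueriesRanked"])
```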

@@ -0,0 +1 @@
# This is the default directory for indexes. Placeholder so that directory is kept in git.

@@ -0,0 +1 @@
# This is the default directory for logs. Placeholder so that directory is kept in git.

@@ -0,0 +1 @@
# This is the default directory for models. Placeholder so that directory is kept in git.