Commit
Add scripts and docs for LisT5 FEVER pipeline (#203)

* added scripts and docs for LisT5 pipeline for FEVER task
* added short description and link to paper in README
1 parent 8fbe93d · commit 5ba5dda

Showing 18 changed files with 2,053 additions and 0 deletions.

@@ -0,0 +1,128 @@

# LisT5: FEVER Pipeline with T5

This page describes how to replicate the LisT5 pipeline for fact verification, introduced in the following paper:

* Kelvin Jiang, Ronak Pradeep, Jimmy Lin. [Exploring Listwise Evidence Reasoning with T5 for Fact Verification.](https://aclanthology.org/2021.acl-short.51.pdf) _ACL 2021_.

Some initial setup:

```bash
mkdir data/list5
mkdir runs/list5
```

## Document Retrieval

1. Retrieve with anserini

Follow the instructions [here](https://github.com/castorini/anserini/blob/master/docs/experiments-fever.md) to download the FEVER dataset and build an index with anserini. Assume the anserini directory is located at `~/anserini`. After the dataset has been indexed, run the following command to retrieve (note the use of the paragraph index):

```bash
sh ~/anserini/target/appassembler/bin/SearchCollection \
  -index ~/anserini/indexes/fever/lucene-index-fever-paragraph \
  -topicreader TsvInt -topics ~/anserini/collections/fever/queries.sentence.test.tsv \
  -output runs/list5/run.fever-anserini-paragraph.test.tsv \
  -bm25 -bm25.k1 0.6 -bm25.b 0.5
```

2. Retrieve with MediaWiki API (UKP-Athene)

Also retrieve documents with the MediaWiki API, using code originating from [UKP-Athene's repository](https://github.com/UKPLab/fever-2018-team-athene). If necessary, install the dependencies for UKP-Athene's code from `experiments/list5/ukp-athene/requirements.txt`.

```bash
python experiments/list5/ukp-athene/doc_retrieval.py \
  --db-file ~/anserini/collections/fever/wiki-pages \
  --in-file ~/anserini/collections/fever/shared_task_test.jsonl \
  --out-file runs/list5/run.fever-ukp-athene-paragraph.test.jsonl
```

Convert the MediaWiki API results to run format:

```bash
python experiments/list5/ukp-athene/convert_to_run.py \
  --dataset_file runs/list5/run.fever-ukp-athene-paragraph.test.jsonl \
  --output_run_file runs/list5/run.fever-ukp-athene-paragraph.test.tsv
```

3. Merge retrieval runs

Merge the results of the two retrieval methods into a single run. Make sure that the anserini run file comes first in the list of `--input_run_file` arguments; a sketch of the interleaving idea behind `--strategy zip` follows the command below.

```bash
python experiments/list5/merge_runs.py \
  --input_run_file runs/list5/run.fever-anserini-paragraph.test.tsv \
  --input_run_file runs/list5/run.fever-ukp-athene-paragraph.test.tsv \
  --output_run_file runs/list5/run.fever-paragraph.test.tsv \
  --strategy zip
```

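The exact behavior is defined by `merge_runs.py` in this commit; purely as a mental model (an assumption, not a copy of the script), the `zip` strategy can be pictured as alternating between the two ranked lists for each query, keeping the anserini result first and dropping duplicates:

```python
from itertools import zip_longest

def zip_merge(anserini_docs, ukp_athene_docs, k=1000):
    """Interleave two ranked doc-id lists for one query, skipping duplicates."""
    merged, seen = [], set()
    for pair in zip_longest(anserini_docs, ukp_athene_docs):
        for doc_id in pair:
            if doc_id is None or doc_id in seen:
                continue
            seen.add(doc_id)
            merged.append(doc_id)
            if len(merged) == k:
                return merged
    return merged

print(zip_merge(['D1', 'D2', 'D3'], ['D2', 'D4']))  # ['D1', 'D2', 'D4', 'D3']
```

Under this picture, the order of the `--input_run_file` arguments matters because the first run supplies the top-ranked document for every query.
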
## Sentence Selection

4. Expand document IDs to all sentence IDs

Expand the run file from document IDs to sentence-ID granularity (see the sketch after the command below).

```bash
python experiments/list5/expand_docs_to_sentences.py \
  --input_run_file runs/list5/run.fever-paragraph.test.tsv \
  --collection_folder ~/anserini/collections/fever/wiki-pages \
  --output_run_file runs/list5/run.fever-sentence-top-150.test.tsv \
  --k 150
```

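For intuition, the expansion works over the FEVER `wiki-pages` dump, in which each document stores its sentences in a tab-separated, newline-delimited `lines` field. The sketch below shows the general idea under that assumption; the actual sentence-ID convention comes from `fever_utils.make_sentence_id`, so the `{doc_id}_{line_number}` form here is only a placeholder:

```python
import json

def expand_doc_to_sentences(wiki_pages_line):
    """Yield (sentence_id, sentence_text) for one JSON line of a FEVER wiki-pages file."""
    doc = json.loads(wiki_pages_line)
    for line in doc['lines'].split('\n'):
        fields = line.split('\t')
        if len(fields) < 2 or not fields[1]:
            continue  # skip empty or malformed sentence entries
        line_number, sentence_text = fields[0], fields[1]
        # Placeholder ID format; the real convention is defined by fever_utils.make_sentence_id.
        yield f"{doc['id']}_{line_number}", sentence_text
```
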
5. Convert run file to T5 input file for monoT5 re-ranking

Only the top `k = 200` sentences are re-ranked, as a tradeoff between efficiency and recall (see the example input line after the command below).

```bash
python experiments/list5/convert_run_to_sentence_selection_input.py \
  --dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \
  --run_file runs/list5/run.fever-sentence-top-150.test.tsv \
  --collection_folder ~/anserini/collections/fever/wiki-pages \
  --output_id_file data/list5/query-doc-pairs-id-test-ner-rerank-top-200.txt \
  --output_text_file data/list5/query-doc-pairs-text-test-ner-rerank-top-200.txt \
  --k 200 --type mono --ner
```

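The `--type mono` flag indicates pointwise (monoT5-style) re-ranking, where each claim-sentence pair becomes one text-to-text input line. The exact template is defined by the conversion script above; a hypothetical illustration (the claim and sentence are only examples):

```python
claim = 'Roman Atwood is a content creator.'
sentence = 'He is best known for his vlogs, where he posts updates about his life on a daily basis.'

# Assumed monoT5-style template; see convert_run_to_sentence_selection_input.py for the real one.
t5_input = f'Query: {claim} Document: {sentence} Relevant:'
```
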
6. Re-rank T5 input file to get scores file

Run inference with the T5 sentence selection model (e.g. using Google Cloud TPUs); a minimal local sketch of what this step computes is shown below. Assume the sentence selection T5 output file is at `data/list5/query-doc-pairs-scores-test-ner-rerank-top-200.txt`.

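The replication itself uses the T5/Mesh-TensorFlow tooling on Cloud TPUs, and neither the checkpoint nor the exact layout of the scores file is shown in this diff. Purely as an illustration, here is a sketch using Hugging Face `transformers` with a hypothetical local checkpoint, scoring the "false"/"true" token at the first decoded position for each input line; the two-column output is a placeholder, not necessarily the format `convert_sentence_selection_output_to_run.py` expects:

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_PATH = 'models/list5-sentence-selection'  # hypothetical checkpoint location
tokenizer = T5Tokenizer.from_pretrained(MODEL_PATH)
model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).eval()

FALSE_ID = tokenizer.encode('false')[0]
TRUE_ID = tokenizer.encode('true')[0]

@torch.no_grad()
def relevance_log_probs(batch):
    """Return [log P('false'), log P('true')] for each input line in the batch."""
    enc = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
    dec = torch.full((len(batch), 1), model.config.decoder_start_token_id, dtype=torch.long)
    logits = model(**enc, decoder_input_ids=dec).logits[:, 0, :]  # first decoder step only
    return torch.log_softmax(logits[:, [FALSE_ID, TRUE_ID]], dim=-1).tolist()

with open('data/list5/query-doc-pairs-text-test-ner-rerank-top-200.txt', encoding='utf-8') as f_in, \
        open('data/list5/query-doc-pairs-scores-test-ner-rerank-top-200.txt', 'w', encoding='utf-8') as f_out:
    lines = [line.strip() for line in f_in]
    for start in range(0, len(lines), 32):
        for false_lp, true_lp in relevance_log_probs(lines[start:start + 32]):
            f_out.write(f'{false_lp}\t{true_lp}\n')  # placeholder score-file layout
```
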
7. Convert scores file back to run file

```bash
python experiments/list5/convert_sentence_selection_output_to_run.py \
  --id_file data/list5/query-doc-pairs-id-test-ner-rerank-top-200.txt \
  --scores_file data/list5/query-doc-pairs-scores-test-ner-rerank-top-200.txt \
  --output_run_file runs/list5/run.fever-sentence-top-150-reranked.txt \
  --type mono
```

## Label Prediction

8. Convert re-ranked run file to T5 input file for label prediction

Make sure to use `--format concat` to specify the listwise format (an illustration of such an input follows the command below).

```bash
python experiments/list5/convert_run_to_label_prediction_input.py \
  --dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \
  --run_file runs/list5/run.fever-sentence-top-150-reranked.txt \
  --collection_folder ~/anserini/collections/fever/wiki-pages \
  --output_id_file data/list5/query-doc-pairs-id-test-ner-label-pred-concat.txt \
  --output_text_file data/list5/query-doc-pairs-text-test-ner-label-pred-concat.txt \
  --format concat
```

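With `--format concat`, each claim and its selected evidence sentences are packed into a single input sequence so the label prediction model can reason over all evidence jointly (the listwise formulation from the paper). The exact template is defined by the conversion script above; the field names and separators below are assumptions for illustration only:

```python
claim = 'Roman Atwood is a content creator.'
evidence = [
    'Roman Bernard Atwood is an American YouTube personality and vlogger.',
    'He is best known for his vlogs, where he posts updates about his life on a daily basis.',
]

# Assumed listwise "concat" template: one sequence containing the claim and all of its evidence
# sentences; see convert_run_to_label_prediction_input.py for the template actually used.
t5_input = f'hypothesis: {claim} ' + ' '.join(
    f'sentence{i + 1}: {sent}' for i, sent in enumerate(evidence))
```
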
9. Predict labels for T5 input file to get scores file

Run inference with the T5 label prediction model (e.g. using Google Cloud TPUs), analogous to step 6. Assume the label prediction T5 output file is at `data/list5/query-doc-pairs-scores-test-ner-label-pred-concat.txt`.

10. Convert scores file to FEVER submission file

```bash
python experiments/list5/predict_for_submission.py \
  --id_file data/list5/query-doc-pairs-id-test-ner-label-pred-concat.txt \
  --scores_file data/list5/query-doc-pairs-scores-test-ner-label-pred-concat.txt \
  --dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \
  --output_predictions_file data/list5/predictions.jsonl
```

@@ -0,0 +1,92 @@

import argparse
import json


def calculate_scores(args):
    evidences = {}
    labels = {}

    total = 0
    correct = 0
    strict_correct = 0
    total_hits = 0
    total_precision = 0
    total_recall = 0

    with open(args.dataset_file, 'r', encoding='utf-8') as f:
        for line in f:
            line_json = json.loads(line.strip())
            query_id = line_json['id']
            if 'label' in line_json:  # no "label" field in test datasets
                label = line_json['label']
                labels[query_id] = label

                if label not in ['SUPPORTS', 'REFUTES']:
                    continue

                evidence_sets = []
                for annotator in line_json['evidence']:
                    evidence_set = [[evidence[2], evidence[3]] for evidence in annotator]
                    evidence_sets.append(evidence_set)
                evidences[query_id] = evidence_sets

    def check_evidence_set(pred_evidence_set, true_evidence_sets):
        # predicted evidence is correct if it fully covers at least one annotated evidence set
        for evidence_set in true_evidence_sets:
            if all([evidence in pred_evidence_set for evidence in evidence_set]):
                return True

        return False

    with open(args.submission_file, 'r', encoding='utf-8') as f:
        for line in f:
            line_json = json.loads(line.strip())
            query_id = line_json['id']
            pred_label = line_json['predicted_label']
            pred_evidence_set = line_json['predicted_evidence']

            total += 1
            if pred_label == labels[query_id]:
                correct += 1
                if labels[query_id] == 'NOT ENOUGH INFO' or check_evidence_set(pred_evidence_set, evidences[query_id]):
                    strict_correct += 1

            if labels[query_id] != 'NOT ENOUGH INFO':
                total_hits += 1

                # calculate precision
                correct_evidence = [ev for ev_set in evidences[query_id] for ev in ev_set if ev[1] is not None]
                if len(pred_evidence_set) == 0:
                    total_precision += 1
                else:
                    curr_precision = 0
                    curr_precision_hits = 0
                    for pred in pred_evidence_set:
                        curr_precision_hits += 1
                        if pred in correct_evidence:
                            curr_precision += 1
                    total_precision += curr_precision / curr_precision_hits

                # calculate recall
                if len(evidences[query_id]) == 0 or all([len(ev_set) == 0 for ev_set in evidences[query_id]]) or \
                        check_evidence_set(pred_evidence_set, evidences[query_id]):
                    total_recall += 1

    print('****************************************')
    fever_score = strict_correct / total
    print(f'FEVER Score: {fever_score}')
    label_acc = correct / total
    print(f'Label Accuracy: {label_acc}')
    precision = (total_precision / total_hits) if total_hits > 0 else 1.0
    print(f'Evidence Precision: {precision}')
    recall = (total_recall / total_hits) if total_hits > 0 else 0.0
    print(f'Evidence Recall: {recall}')
    f1 = 2.0 * precision * recall / (precision + recall)
    print(f'Evidence F1: {f1}')
    print('****************************************')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Calculates various metrics used in the FEVER shared task.')
    parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.')
    parser.add_argument('--submission_file', required=True, help='Submission file to FEVER shared task.')
    args = parser.parse_args()

    calculate_scores(args)

@@ -0,0 +1,111 @@

import argparse
import glob
import json
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

from fever_utils import make_sentence_id


def calculate_scores(args):
    evidences = {}

    with open(args.dataset_file, 'r', encoding='utf-8') as f:
        for line in f:
            line_json = json.loads(line.strip())

            evidence_sets = []
            if line_json['label'] != 'NOT ENOUGH INFO':
                for annotator in line_json['evidence']:
                    evidence_set = [make_sentence_id(evidence[2], evidence[3]) for evidence in annotator]
                    evidence_sets.append(evidence_set)
            evidences[line_json['id']] = evidence_sets

    def aggregate(scores):
        if args.num_classes == 4:
            # filter out samples predicted weak and remove weak scores
            scores = scores[np.argmax(scores, axis=1) != 3][:, :3]
            if len(scores) == 0:
                return 1

        if args.strategy == 'first':
            return np.argmax(scores[0])
        elif args.strategy == 'sum':
            return np.argmax(np.sum(np.exp(scores), axis=0))
        elif args.strategy == 'nei_default':
            maxes = np.argmax(scores, axis=1)
            if (0 in maxes and 2 in maxes) or (0 not in maxes and 2 not in maxes):
                return 1
            elif 0 in maxes:
                return 0
            elif 2 in maxes:
                return 2
            return -1
        elif args.strategy == 'max':
            return np.argmax(np.max(np.exp(scores), axis=0))
        return -1

    for scores_file in sorted(glob.glob(f'{args.scores_files_prefix}*')):
        labels = []
        pred_labels = []
        fever_scores = []
        with open(args.id_file, 'r', encoding='utf-8') as f_id, open(scores_file, 'r', encoding='utf-8') as f_scores:
            curr_query = None
            curr_label = None  # actual label for current query
            curr_scores = []
            curr_evidences = []
            for id_line, scores_line in zip(f_id, f_scores):
                query_id, sent_ids, label_str = id_line.strip().split('\t')
                query_id = int(query_id)

                if query_id != curr_query:
                    if curr_query is not None:
                        # aggregate to get predicted label
                        pred_label = aggregate(np.array(curr_scores))
                        pred_labels.append(pred_label)
                        # calculate FEVER score
                        fever_scores.append(int(pred_label == curr_label and (pred_label == 1 or
                            any([set(ev_set).issubset(set(curr_evidences)) for ev_set in evidences[curr_query]]))))
                    curr_query = query_id
                    curr_scores.clear()
                    curr_evidences.clear()
                    # save actual label
                    if label_str == 'false':
                        curr_label = 0
                    elif label_str == 'weak':
                        curr_label = 1
                    elif label_str == 'true':
                        curr_label = 2
                    labels.append(curr_label)

                # save predicted evidence(s) and scores
                if args.num_classes == 3:
                    _, false_score, nei_score, true_score = scores_line.strip().split('\t')
                    scores = [float(false_score), float(nei_score), float(true_score)]
                elif args.num_classes == 4:
                    _, false_score, ignore_score, true_score, nei_score = scores_line.strip().split('\t')
                    scores = [float(false_score), float(nei_score), float(true_score), float(ignore_score)]
                curr_scores.append(scores)
                curr_evidences.extend(sent_ids.strip().split(' '))

            # handle last query
            pred_label = aggregate(np.array(curr_scores))
            pred_labels.append(pred_label)
            fever_scores.append(int(pred_label == curr_label and (pred_label == 1 or
                any([set(ev_set).issubset(set(curr_evidences)) for ev_set in evidences[curr_query]]))))

        print(scores_file)
        print(f'Label Accuracy: {accuracy_score(labels, pred_labels)}')
        print(f'Predicted Label F1 Scores: {f1_score(labels, pred_labels, average=None)}')
        print(f'Predicted Label Distribution: {[pred_labels.count(i) for i in range(args.num_classes)]}')
        print(f'FEVER Score: {sum(fever_scores) / len(fever_scores)}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Calculates various metrics of label prediction output files.')
    parser.add_argument('--id_file', required=True, help='Input query-doc pair ids file.')
    parser.add_argument('--scores_files_prefix', required=True, help='Prefix of all T5 label prediction scores files.')
    parser.add_argument('--dataset_file', help='FEVER dataset file.')
    parser.add_argument('--num_classes', type=int, default=3, help='Number of label prediction classes.')
    parser.add_argument('--strategy', help='Format of scores file and method of aggregation if applicable.')
    args = parser.parse_args()

    calculate_scores(args)