Add scripts and docs for LisT5 FEVER pipeline (#203)
* added scripts and docs for LisT5 pipeline for FEVER task

* added short description and link to paper in README
kelvin-jiang committed Aug 2, 2021
1 parent 8fbe93d commit 5ba5dda
Showing 18 changed files with 2,053 additions and 0 deletions.
128 changes: 128 additions & 0 deletions experiments/list5/README.md
@@ -0,0 +1,128 @@
# LisT5: FEVER Pipeline with T5

This page describes how to replicate the LisT5 pipeline for fact verification, introduced in the following paper:
* Kelvin Jiang, Ronak Pradeep, Jimmy Lin. [Exploring Listwise Evidence Reasoning with T5 for Fact Verification.](https://aclanthology.org/2021.acl-short.51.pdf) _ACL 2021_.

Some initial setup:

```bash
mkdir data/list5
mkdir runs/list5
```

## Document Retrieval

1. Retrieve with Anserini

Follow the instructions [here](https://github.com/castorini/anserini/blob/master/docs/experiments-fever.md) to download the FEVER dataset and build an index with Anserini. Assume the Anserini directory is located at `~/anserini`. After the dataset has been indexed, run the following command to perform retrieval (note the use of the paragraph index):

```bash
sh ~/anserini/target/appassembler/bin/SearchCollection \
-index ~/anserini/indexes/fever/lucene-index-fever-paragraph \
-topicreader TsvInt -topics ~/anserini/collections/fever/queries.sentence.test.tsv \
-output runs/list5/run.fever-anserini-paragraph.test.tsv \
-bm25 -bm25.k1 0.6 -bm25.b 0.5
```

2. Retrieve with MediaWiki API (UKP-Athene)

Also retrieve documents with the MediaWiki API, using code originating from [UKP-Athene's repository](https://github.com/UKPLab/fever-2018-team-athene). If necessary, install the dependencies for UKP-Athene's code from `experiments/list5/ukp-athene/requirements.txt`.

```bash
python experiments/list5/ukp-athene/doc_retrieval.py \
--db-file ~/anserini/collections/fever/wiki-pages \
--in-file ~/anserini/collections/fever/shared_task_test.jsonl \
--out-file runs/list5/run.fever-ukp-athene-paragraph.test.jsonl
```

Convert the MediaWiki API results to run format.

```bash
python experiments/list5/ukp-athene/convert_to_run.py \
--dataset_file runs/list5/run.fever-ukp-athene-paragraph.test.jsonl \
--output_run_file runs/list5/run.fever-ukp-athene-paragraph.test.tsv
```

3. Merge retrieval runs

Merge the results of the two retrieval methods into a single run. Make sure that the Anserini run file comes first in the list of `--input_run_file` arguments; a conceptual sketch of the `zip` strategy follows the command below.

```bash
python experiments/list5/merge_runs.py \
--input_run_file runs/list5/run.fever-anserini-paragraph.test.tsv \
--input_run_file runs/list5/run.fever-ukp-athene-paragraph.test.tsv \
--output_run_file runs/list5/run.fever-paragraph.test.tsv \
--strategy zip
```
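
The `--strategy zip` name suggests interleaving: alternate between the two ranked lists, keeping only the first occurrence of each document, so the run passed first wins ties at each rank. The actual logic lives in `experiments/list5/merge_runs.py`; the snippet below is only a conceptual sketch of such a zip merge under that assumption, not that script's implementation.

```python
# Conceptual sketch of a zip-style merge of two ranked doc-id lists (assumption,
# not the merge_runs.py implementation): interleave, first occurrence wins.
from itertools import zip_longest
from typing import List

def zip_merge(primary: List[str], secondary: List[str]) -> List[str]:
    merged, seen = [], set()
    for a, b in zip_longest(primary, secondary):
        for doc_id in (a, b):
            if doc_id is not None and doc_id not in seen:
                seen.add(doc_id)
                merged.append(doc_id)
    return merged

# e.g. zip_merge(['d1', 'd2', 'd3'], ['d2', 'd4']) -> ['d1', 'd2', 'd4', 'd3']
```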

## Sentence Selection

4. Expand document IDs to all sentence IDs

Expand the run file from document IDs to sentence-level granularity.

```bash
python experiments/list5/expand_docs_to_sentences.py \
--input_run_file runs/list5/run.fever-paragraph.test.tsv \
--collection_folder ~/anserini/collections/fever/wiki-pages \
--output_run_file runs/list5/run.fever-sentence-top-150.test.tsv \
--k 150
```

5. Convert run file to T5 input file for monoT5 re-ranking

Re-rank only the top `k = 200` sentences, as a tradeoff between efficiency and recall.

```bash
python experiments/list5/convert_run_to_sentence_selection_input.py \
--dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \
--run_file runs/list5/run.fever-sentence-top-150.test.tsv \
--collection_folder ~/anserini/collections/fever/wiki-pages \
--output_id_file data/list5/query-doc-pairs-id-test-ner-rerank-top-200.txt \
--output_text_file data/list5/query-doc-pairs-text-test-ner-rerank-top-200.txt \
--k 200 --type mono --ner
```

6. Re-rank T5 input file to get scores file

Run inference with the T5 sentence selection model (e.g. on Google Cloud TPUs). Assume the resulting T5 output file is at `data/list5/query-doc-pairs-scores-test-ner-rerank-top-200.txt`.
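
As a purely illustrative alternative to TPU inference, the sketch below scores each formatted input line from step 5 with a monoT5-style model in PyTorch/Hugging Face. The checkpoint name, the `true`/`false` target tokens, and the two-column output format are assumptions, so verify them against the actual sentence selection model and against what `convert_sentence_selection_output_to_run.py` expects.

```python
# Illustrative sketch only -- NOT the pipeline's TPU workflow. Assumes a
# monoT5-style checkpoint that scores 'true'/'false' as the first decoded token;
# the model name and output column order are placeholders.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model_name = 'castorini/monot5-base-msmarco'  # stand-in, not the LisT5 checkpoint
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).eval()

token_false = tokenizer.encode('false', add_special_tokens=False)[0]
token_true = tokenizer.encode('true', add_special_tokens=False)[0]

with open('data/list5/query-doc-pairs-text-test-ner-rerank-top-200.txt') as f_in, \
     open('data/list5/query-doc-pairs-scores-test-ner-rerank-top-200.txt', 'w') as f_out:
    for line in f_in:
        inputs = tokenizer(line.strip(), return_tensors='pt', truncation=True, max_length=512)
        with torch.no_grad():
            # one decoder step starting from the decoder start (pad) token
            logits = model(**inputs, decoder_input_ids=torch.zeros((1, 1), dtype=torch.long)).logits
        log_probs = torch.log_softmax(logits[0, -1, [token_false, token_true]], dim=-1)
        f_out.write(f'{log_probs[0].item()}\t{log_probs[1].item()}\n')
```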

7. Convert scores file back to run file

```bash
python experiments/list5/convert_sentence_selection_output_to_run.py \
--id_file data/list5/query-doc-pairs-id-test-ner-rerank-top-200.txt \
--scores_file data/list5/query-doc-pairs-scores-test-ner-rerank-top-200.txt \
--output_run_file runs/list5/run.fever-sentence-top-150-reranked.txt \
--type mono
```

## Label Prediction

8. Convert re-ranked run file to T5 input file for label prediction

Make sure to use `--format concat` to specify the listwise format.

```bash
python experiments/list5/convert_run_to_label_prediction_input.py \
--dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \
--run_file runs/list5/run.fever-sentence-top-150-reranked.txt \
--collection_folder ~/anserini/collections/fever/wiki-pages \
--output_id_file data/list5/query-doc-pairs-id-test-ner-label-pred-concat.txt \
--output_text_file data/list5/query-doc-pairs-text-test-ner-label-pred-concat.txt \
--format concat
```

9. Predict labels for T5 input file to get scores file

Run inference with the T5 label prediction model (e.g. on Google Cloud TPUs). Assume the resulting T5 output file is at `data/list5/query-doc-pairs-scores-test-ner-label-pred-concat.txt`.
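
Again purely as an illustration (mirroring the step 6 sketch), one could score each listwise input line by comparing the logits of three class tokens. The `false`/`weak`/`true` targets and the four-column, tab-separated output (placeholder id, false, NEI/weak, true) are assumptions inferred from `calculate_label_prediction_scores.py` with `--num_classes 3`; the checkpoint path is a placeholder.

```python
# Illustrative sketch only, mirroring the step 6 example; the checkpoint path,
# class tokens, and output column order are assumptions to verify.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model_name = 'path/to/list5-label-prediction-checkpoint'  # placeholder
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).eval()

class_ids = [tokenizer.encode(t, add_special_tokens=False)[0] for t in ('false', 'weak', 'true')]

with open('data/list5/query-doc-pairs-text-test-ner-label-pred-concat.txt') as f_in, \
     open('data/list5/query-doc-pairs-scores-test-ner-label-pred-concat.txt', 'w') as f_out:
    for i, line in enumerate(f_in):
        inputs = tokenizer(line.strip(), return_tensors='pt', truncation=True, max_length=512)
        with torch.no_grad():
            logits = model(**inputs, decoder_input_ids=torch.zeros((1, 1), dtype=torch.long)).logits
        false_s, weak_s, true_s = torch.log_softmax(logits[0, -1, class_ids], dim=-1).tolist()
        # column order follows calculate_label_prediction_scores.py: id, false, NEI/weak, true
        f_out.write(f'{i}\t{false_s}\t{weak_s}\t{true_s}\n')
```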

10. Convert scores file to FEVER submission file

```bash
python experiments/list5/predict_for_submission.py \
--id_file data/list5/query-doc-pairs-id-test-ner-label-pred-concat.txt \
--scores_file data/list5/query-doc-pairs-scores-test-ner-label-pred-concat.txt \
--dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \
--output_predictions_file data/list5/predictions.jsonl
```
92 changes: 92 additions & 0 deletions experiments/list5/calculate_fever_scores.py
@@ -0,0 +1,92 @@
import argparse
import json

def calculate_scores(args):
    evidences = {}
    labels = {}

    total = 0
    correct = 0
    strict_correct = 0
    total_hits = 0
    total_precision = 0
    total_recall = 0

    with open(args.dataset_file, 'r', encoding='utf-8') as f:
        for line in f:
            line_json = json.loads(line.strip())
            query_id = line_json['id']
            if 'label' in line_json: # no "label" field in test datasets
                label = line_json['label']
                labels[query_id] = label

                if label not in ['SUPPORTS', 'REFUTES']:
                    continue

                evidence_sets = []
                for annotator in line_json['evidence']:
                    evidence_set = [[evidence[2], evidence[3]] for evidence in annotator]
                    evidence_sets.append(evidence_set)
                evidences[query_id] = evidence_sets

    def check_evidence_set(pred_evidence_set, true_evidence_sets):
        for evidence_set in true_evidence_sets:
            if all([evidence in pred_evidence_set for evidence in evidence_set]):
                return True

        return False

    with open(args.submission_file, 'r', encoding='utf-8') as f:
        for line in f:
            line_json = json.loads(line.strip())
            query_id = line_json['id']
            pred_label = line_json['predicted_label']
            pred_evidence_set = line_json['predicted_evidence']

            total += 1
            if pred_label == labels[query_id]:
                correct += 1
                if labels[query_id] == 'NOT ENOUGH INFO' or check_evidence_set(pred_evidence_set, evidences[query_id]):
                    strict_correct += 1

            if labels[query_id] != 'NOT ENOUGH INFO':
                total_hits += 1

                # calculate precision
                correct_evidence = [ev for ev_set in evidences[query_id] for ev in ev_set if ev[1] is not None]
                if len(pred_evidence_set) == 0:
                    total_precision += 1
                else:
                    curr_precision = 0
                    curr_precision_hits = 0
                    for pred in pred_evidence_set:
                        curr_precision_hits += 1
                        if pred in correct_evidence:
                            curr_precision += 1
                    total_precision += curr_precision / curr_precision_hits

                # calculate recall
                if len(evidences[query_id]) == 0 or all([len(ev_set) == 0 for ev_set in evidences[query_id]]) or \
                        check_evidence_set(pred_evidence_set, evidences[query_id]):
                    total_recall += 1

    print('****************************************')
    fever_score = strict_correct / total
    print(f'FEVER Score: {fever_score}')
    label_acc = correct / total
    print(f'Label Accuracy: {label_acc}')
    precision = (total_precision / total_hits) if total_hits > 0 else 1.0
    print(f'Evidence Precision: {precision}')
    recall = (total_recall / total_hits) if total_hits > 0 else 0.0
    print(f'Evidence Recall: {recall}')
    f1 = 2.0 * precision * recall / (precision + recall)
    print(f'Evidence F1: {f1}')
    print('****************************************')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Calculates various metrics used in the FEVER shared task.')
    parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.')
    parser.add_argument('--submission_file', required=True, help='Submission file to FEVER shared task.')
    args = parser.parse_args()

    calculate_scores(args)
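
For reference, a minimal sketch of driving the scorer above from Python rather than the command line. Both paths are placeholders, and the dataset must be a labeled split (e.g. the shared task dev set), since the blind test set has no `label` field.

```python
# A minimal sketch, assuming this file is importable and a labeled FEVER split;
# both paths below are placeholders.
import argparse
import os

from calculate_fever_scores import calculate_scores

args = argparse.Namespace(
    dataset_file=os.path.expanduser('~/anserini/collections/fever/shared_task_dev.jsonl'),
    submission_file='data/list5/predictions.jsonl',
)
calculate_scores(args)
```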
111 changes: 111 additions & 0 deletions experiments/list5/calculate_label_prediction_scores.py
@@ -0,0 +1,111 @@
import argparse
import glob
import json
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

from fever_utils import make_sentence_id

def calculate_scores(args):
    evidences = {}

    with open(args.dataset_file, 'r', encoding='utf-8') as f:
        for line in f:
            line_json = json.loads(line.strip())

            evidence_sets = []
            if line_json['label'] != 'NOT ENOUGH INFO':
                for annotator in line_json['evidence']:
                    evidence_set = [make_sentence_id(evidence[2], evidence[3]) for evidence in annotator]
                    evidence_sets.append(evidence_set)
            evidences[line_json['id']] = evidence_sets

    def aggregate(scores):
        if args.num_classes == 4:
            # filter out samples predicted weak and remove weak scores
            scores = scores[np.argmax(scores, axis=1) != 3][:, :3]
            if len(scores) == 0:
                return 1

        if args.strategy == 'first':
            return np.argmax(scores[0])
        elif args.strategy == 'sum':
            return np.argmax(np.sum(np.exp(scores), axis=0))
        elif args.strategy == 'nei_default':
            maxes = np.argmax(scores, axis=1)
            if (0 in maxes and 2 in maxes) or (0 not in maxes and 2 not in maxes):
                return 1
            elif 0 in maxes:
                return 0
            elif 2 in maxes:
                return 2
            return -1
        elif args.strategy == 'max':
            return np.argmax(np.max(np.exp(scores), axis=0))
        return -1

    for scores_file in sorted(glob.glob(f'{args.scores_files_prefix}*')):
        labels = []
        pred_labels = []
        fever_scores = []
        with open(args.id_file, 'r', encoding='utf-8') as f_id, open(scores_file, 'r', encoding='utf-8') as f_scores:
            curr_query = None
            curr_label = None # actual label for current query
            curr_scores = []
            curr_evidences = []
            for id_line, scores_line in zip(f_id, f_scores):
                query_id, sent_ids, label_str = id_line.strip().split('\t')
                query_id = int(query_id)

                if query_id != curr_query:
                    if curr_query is not None:
                        # aggregate to get predicted label
                        pred_label = aggregate(np.array(curr_scores))
                        pred_labels.append(pred_label)
                        # calculate FEVER score
                        fever_scores.append(int(pred_label == curr_label and (pred_label == 1 or \
                            any([set(ev_set).issubset(set(curr_evidences)) for ev_set in evidences[curr_query]]))))
                    curr_query = query_id
                    curr_scores.clear()
                    curr_evidences.clear()
                    # save actual label
                    if label_str == 'false':
                        curr_label = 0
                    elif label_str == 'weak':
                        curr_label = 1
                    elif label_str == 'true':
                        curr_label = 2
                    labels.append(curr_label)

                # save predicted evidence(s) and scores
                if args.num_classes == 3:
                    _, false_score, nei_score, true_score = scores_line.strip().split('\t')
                    scores = [float(false_score), float(nei_score), float(true_score)]
                elif args.num_classes == 4:
                    _, false_score, ignore_score, true_score, nei_score = scores_line.strip().split('\t')
                    scores = [float(false_score), float(nei_score), float(true_score), float(ignore_score)]
                curr_scores.append(scores)
                curr_evidences.extend(sent_ids.strip().split(' '))

            # handle last query
            pred_label = aggregate(np.array(curr_scores))
            pred_labels.append(pred_label)
            fever_scores.append(int(pred_label == curr_label and (pred_label == 1 or \
                any([set(ev_set).issubset(set(curr_evidences)) for ev_set in evidences[curr_query]]))))

        print(scores_file)
        print(f'Label Accuracy: {accuracy_score(labels, pred_labels)}')
        print(f'Predicted Label F1 Scores: {f1_score(labels, pred_labels, average=None)}')
        print(f'Predicted Label Distribution: {[pred_labels.count(i) for i in range(args.num_classes)]}')
        print(f'FEVER Score: {sum(fever_scores) / len(fever_scores)}')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Calculates various metrics of label prediction output files.')
    parser.add_argument('--id_file', required=True, help='Input query-doc pair ids file.')
    parser.add_argument('--scores_files_prefix', required=True, help='Prefix of all T5 label prediction scores files.')
    parser.add_argument('--dataset_file', help='FEVER dataset file.')
    parser.add_argument('--num_classes', type=int, default=3, help='Number of label prediction classes.')
    parser.add_argument('--strategy', help='Format of scores file and method of aggregation if applicable.')
    args = parser.parse_args()

    calculate_scores(args)
