From 5ba5dda6574947b4bcacaae39920800b63a24e00 Mon Sep 17 00:00:00 2001 From: Kelvin Jiang <20145768+kelvin-jiang@users.noreply.github.com> Date: Mon, 2 Aug 2021 13:13:54 -0400 Subject: [PATCH] Add scripts and docs for LisT5 FEVER pipeline (#203) * added scripts and docs for LisT5 pipeline for FEVER task * added short description and link to paper in README --- experiments/list5/README.md | 128 +++++++++++ experiments/list5/calculate_fever_scores.py | 92 ++++++++ .../calculate_label_prediction_scores.py | 111 +++++++++ .../list5/calculate_oracle_accuracy.py | 91 ++++++++ .../convert_run_to_label_prediction_input.py | 134 +++++++++++ ...convert_run_to_sentence_selection_input.py | 133 +++++++++++ ...onvert_sentence_selection_output_to_run.py | 76 +++++++ experiments/list5/expand_docs_to_sentences.py | 53 +++++ experiments/list5/fever_utils.py | 112 ++++++++++ .../generate_label_prediction_data_gold.py | 143 ++++++++++++ .../generate_label_prediction_data_noisy.py | 162 ++++++++++++++ .../list5/generate_sentence_selection_data.py | 120 ++++++++++ .../generate_sentence_selection_ner_data.py | 150 +++++++++++++ experiments/list5/merge_runs.py | 49 ++++ experiments/list5/predict_for_submission.py | 106 +++++++++ .../list5/ukp-athene/convert_to_run.py | 17 ++ experiments/list5/ukp-athene/doc_retrieval.py | 210 ++++++++++++++++++ experiments/list5/ukp-athene/requirements.txt | 166 ++++++++++++++ 18 files changed, 2053 insertions(+) create mode 100644 experiments/list5/README.md create mode 100644 experiments/list5/calculate_fever_scores.py create mode 100644 experiments/list5/calculate_label_prediction_scores.py create mode 100644 experiments/list5/calculate_oracle_accuracy.py create mode 100644 experiments/list5/convert_run_to_label_prediction_input.py create mode 100644 experiments/list5/convert_run_to_sentence_selection_input.py create mode 100644 experiments/list5/convert_sentence_selection_output_to_run.py create mode 100644 experiments/list5/expand_docs_to_sentences.py create mode 100644 experiments/list5/fever_utils.py create mode 100644 experiments/list5/generate_label_prediction_data_gold.py create mode 100644 experiments/list5/generate_label_prediction_data_noisy.py create mode 100644 experiments/list5/generate_sentence_selection_data.py create mode 100644 experiments/list5/generate_sentence_selection_ner_data.py create mode 100644 experiments/list5/merge_runs.py create mode 100644 experiments/list5/predict_for_submission.py create mode 100644 experiments/list5/ukp-athene/convert_to_run.py create mode 100644 experiments/list5/ukp-athene/doc_retrieval.py create mode 100644 experiments/list5/ukp-athene/requirements.txt diff --git a/experiments/list5/README.md b/experiments/list5/README.md new file mode 100644 index 00000000..2d6b06cf --- /dev/null +++ b/experiments/list5/README.md @@ -0,0 +1,128 @@ +# LisT5: FEVER Pipeline with T5 + +This page describes replication for the LisT5 pipeline for fact verification, outlined in the following paper: +* Kelvin Jiang, Ronak Pradeep, Jimmy Lin. [Exploring Listwise Evidence Reasoning with T5 for Fact Verification.](https://aclanthology.org/2021.acl-short.51.pdf) _ACL 2021_. + +Some initial setup: + +```bash +mkdir data/list5 +mkdir runs/list5 +``` + +## Document Retrieval + +1. Retrieve with anserini + +Follow instructions [here](https://github.com/castorini/anserini/blob/master/docs/experiments-fever.md) to download the FEVER dataset and build an index with anserini. Assume the anserini directory is located at `~/anserini`. 
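+
+The commands below assume the following artifacts exist under `~/anserini` once those instructions are complete (a sketch for orientation; the paths are simply the ones referenced by the commands in this guide):
+
+```
+~/anserini/indexes/fever/lucene-index-fever-paragraph    # paragraph-level Lucene index
+~/anserini/collections/fever/queries.sentence.test.tsv   # test claims as queries
+~/anserini/collections/fever/shared_task_test.jsonl      # FEVER shared task test set
+~/anserini/collections/fever/wiki-pages/                 # FEVER Wikipedia dump (JSONL files)
+```
+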
After the dataset has been indexed, run the following command to retrieve (note the use of the paragraph index): + +```bash +sh ~/anserini/target/appassembler/bin/SearchCollection \ + -index ~/anserini/indexes/fever/lucene-index-fever-paragraph \ + -topicreader TsvInt -topics ~/anserini/collections/fever/queries.sentence.test.tsv \ + -output runs/list5/run.fever-anserini-paragraph.test.tsv \ + -bm25 -bm25.k1 0.6 -bm25.b 0.5 +``` + +2. Retrieve with MediaWiki API (UKP-Athene) + +Also retrieve documents with MediaWiki API, code originating from [UKP-Athene's repository](https://github.com/UKPLab/fever-2018-team-athene). Install the dependencies for UKP-Athene's code using `experiments/list5/ukp-athene/requirements.txt` if necessary. + +```bash +python experiments/list5/ukp-athene/doc_retrieval.py \ + --db-file ~/anserini/collections/fever/wiki-pages \ + --in-file ~/anserini/collections/fever/shared_task_test.jsonl \ + --out-file runs/list5/run.fever-ukp-athene-paragraph.test.jsonl +``` + +Convert the MediaWiki API results to run format. + +```bash +python experiments/list5/ukp-athene/convert_to_run.py \ + --dataset_file runs/list5/run.fever-ukp-athene-paragraph.test.jsonl \ + --output_run_file runs/list5/run.fever-ukp-athene-paragraph.test.tsv +``` + +3. Merge retrieval runs + +Merge the results of the two methods of retrieval into a single run. Make sure that the anserini run file comes first in the list of `--input_run_file` arguments. + +```bash +python experiments/list5/merge_runs.py \ + --input_run_file runs/list5/run.fever-anserini-paragraph.test.tsv \ + --input_run_file runs/list5/run.fever-ukp-athene-paragraph.test.tsv \ + --output_run_file runs/list5/run.fever-paragraph.test.tsv \ + --strategy zip +``` + +## Sentence Selection + +4. Expand document IDs to all sentence IDs + +Expand run file to a sentence ID granularity. + +```bash +python experiments/list5/expand_docs_to_sentences.py \ + --input_run_file runs/list5/run.fever-paragraph.test.tsv \ + --collection_folder ~/anserini/collections/fever/wiki-pages \ + --output_run_file runs/list5/run.fever-sentence-top-150.test.tsv \ + --k 150 +``` + +5. Convert run file to T5 input file for monoT5 re-ranking + +Re-rank the top `k = 200` sentences, a tradeoff between efficiency and recall. + +```bash +python experiments/list5/convert_run_to_sentence_selection_input.py \ + --dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \ + --run_file runs/list5/run.fever-sentence-top-150.test.tsv \ + --collection_folder ~/anserini/collections/fever/wiki-pages \ + --output_id_file data/list5/query-doc-pairs-id-test-ner-rerank-top-200.txt \ + --output_text_file data/list5/query-doc-pairs-text-test-ner-rerank-top-200.txt \ + --k 200 --type mono --ner +``` + +6. Re-rank T5 input file to get scores file + +Run inference of T5 sentence selection model (e.g. using Google Cloud TPUs). Assume the sentence selection T5 output file is at `data/list5/query-doc-pairs-scores-test-ner-rerank-top-200.txt`. + +7. Convert scores file back to run file + +```bash +python experiments/list5/convert_sentence_selection_output_to_run.py \ + --id_file data/list5/query-doc-pairs-id-test-ner-rerank-top-200.txt \ + --scores_file data/list5/query-doc-pairs-scores-test-ner-rerank-top-200.txt \ + --output_run_file runs/list5/run.fever-sentence-top-150-reranked.txt \ + --type mono +``` + +## Label Prediction + +8. Convert re-ranked run file to T5 input file for label prediction + +Make sure to use `--format concat` to specify listwise format. 
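+
+With `--format concat`, each line of the output text file concatenates the claim with its top re-ranked evidence sentences. An illustrative sketch of the line format written by `convert_run_to_label_prediction_input.py` (placeholder text, not real data):
+
+```
+hypothesis: <claim text> sentence1: <entity 1> . <evidence sentence 1> sentence2: <entity 2> . <evidence sentence 2> ...
+```
+
+The accompanying id file holds the matching tab-separated query ID and evidence sentence IDs, which are consumed again in step 10.
+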
+ +```bash +python experiments/list5/convert_run_to_label_prediction_input.py \ + --dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \ + --run_file runs/list5/run.fever-sentence-top-150-reranked.txt \ + --collection_folder ~/anserini/collections/fever/wiki-pages \ + --output_id_file data/list5/query-doc-pairs-id-test-ner-label-pred-concat.txt \ + --output_text_file data/list5/query-doc-pairs-text-test-ner-label-pred-concat.txt \ + --format concat +``` + +9. Predict labels for T5 input file to get scores file + +Run inference of T5 label prediction model (e.g. using Google Cloud TPUs). Assume the label prediction T5 output file is at `data/list5/query-doc-pairs-scores-test-ner-label-pred-concat.txt`. + +10. Convert scores file to FEVER submission file + +```bash +python experiments/list5/predict_for_submission.py \ + --id_file data/list5/query-doc-pairs-id-test-ner-label-pred-concat.txt \ + --scores_file data/list5/query-doc-pairs-scores-test-ner-label-pred-concat.txt \ + --dataset_file ~/anserini/collections/fever/shared_task_test.jsonl \ + --output_predictions_file data/list5/predictions.jsonl +``` diff --git a/experiments/list5/calculate_fever_scores.py b/experiments/list5/calculate_fever_scores.py new file mode 100644 index 00000000..29f4b5f1 --- /dev/null +++ b/experiments/list5/calculate_fever_scores.py @@ -0,0 +1,92 @@ +import argparse +import json + +def calculate_scores(args): + evidences = {} + labels = {} + + total = 0 + correct = 0 + strict_correct = 0 + total_hits = 0 + total_precision = 0 + total_recall = 0 + + with open(args.dataset_file, 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + query_id = line_json['id'] + if 'label' in line_json: # no "label" field in test datasets + label = line_json['label'] + labels[query_id] = label + + if label not in ['SUPPORTS', 'REFUTES']: + continue + + evidence_sets = [] + for annotator in line_json['evidence']: + evidence_set = [[evidence[2], evidence[3]] for evidence in annotator] + evidence_sets.append(evidence_set) + evidences[query_id] = evidence_sets + + def check_evidence_set(pred_evidence_set, true_evidence_sets): + for evidence_set in true_evidence_sets: + if all([evidence in pred_evidence_set for evidence in evidence_set]): + return True + + return False + + with open(args.submission_file, 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + query_id = line_json['id'] + pred_label = line_json['predicted_label'] + pred_evidence_set = line_json['predicted_evidence'] + + total += 1 + if pred_label == labels[query_id]: + correct += 1 + if labels[query_id] == 'NOT ENOUGH INFO' or check_evidence_set(pred_evidence_set, evidences[query_id]): + strict_correct += 1 + + if labels[query_id] != 'NOT ENOUGH INFO': + total_hits += 1 + + # calculate precision + correct_evidence = [ev for ev_set in evidences[query_id] for ev in ev_set if ev[1] is not None] + if len(pred_evidence_set) == 0: + total_precision += 1 + else: + curr_precision = 0 + curr_precision_hits = 0 + for pred in pred_evidence_set: + curr_precision_hits += 1 + if pred in correct_evidence: + curr_precision += 1 + total_precision += curr_precision / curr_precision_hits + + # calculate recall + if len(evidences[query_id]) == 0 or all([len(ev_set) == 0 for ev_set in evidences[query_id]]) or \ + check_evidence_set(pred_evidence_set, evidences[query_id]): + total_recall += 1 + + print('****************************************') + fever_score = strict_correct / total + print(f'FEVER Score: 
{fever_score}') + label_acc = correct / total + print(f'Label Accuracy: {label_acc}') + precision = (total_precision / total_hits) if total_hits > 0 else 1.0 + print(f'Evidence Precision: {precision}') + recall = (total_recall / total_hits) if total_hits > 0 else 0.0 + print(f'Evidence Recall: {recall}') + f1 = 2.0 * precision * recall / (precision + recall) + print(f'Evidence F1: {f1}') + print('****************************************') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Calculates various metrics used in the FEVER shared task.') + parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.') + parser.add_argument('--submission_file', required=True, help='Submission file to FEVER shared task.') + args = parser.parse_args() + + calculate_scores(args) diff --git a/experiments/list5/calculate_label_prediction_scores.py b/experiments/list5/calculate_label_prediction_scores.py new file mode 100644 index 00000000..f1d8fef9 --- /dev/null +++ b/experiments/list5/calculate_label_prediction_scores.py @@ -0,0 +1,111 @@ +import argparse +import glob +import json +import numpy as np +from sklearn.metrics import accuracy_score, f1_score + +from fever_utils import make_sentence_id + +def calculate_scores(args): + evidences = {} + + with open(args.dataset_file, 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + + evidence_sets = [] + if line_json['label'] != 'NOT ENOUGH INFO': + for annotator in line_json['evidence']: + evidence_set = [make_sentence_id(evidence[2], evidence[3]) for evidence in annotator] + evidence_sets.append(evidence_set) + evidences[line_json['id']] = evidence_sets + + def aggregate(scores): + if args.num_classes == 4: + # filter out samples predicted weak and remove weak scores + scores = scores[np.argmax(scores, axis=1) != 3][:, :3] + if len(scores) == 0: + return 1 + + if args.strategy == 'first': + return np.argmax(scores[0]) + elif args.strategy == 'sum': + return np.argmax(np.sum(np.exp(scores), axis=0)) + elif args.strategy == 'nei_default': + maxes = np.argmax(scores, axis=1) + if (0 in maxes and 2 in maxes) or (0 not in maxes and 2 not in maxes): + return 1 + elif 0 in maxes: + return 0 + elif 2 in maxes: + return 2 + return -1 + elif args.strategy == 'max': + return np.argmax(np.max(np.exp(scores), axis=0)) + return -1 + + for scores_file in sorted(glob.glob(f'{args.scores_files_prefix}*')): + labels = [] + pred_labels = [] + fever_scores = [] + with open(args.id_file, 'r', encoding='utf-8') as f_id, open(scores_file, 'r', encoding='utf-8') as f_scores: + curr_query = None + curr_label = None # actual label for current query + curr_scores = [] + curr_evidences = [] + for id_line, scores_line in zip(f_id, f_scores): + query_id, sent_ids, label_str = id_line.strip().split('\t') + query_id = int(query_id) + + if query_id != curr_query: + if curr_query is not None: + # aggregate to get predicted label + pred_label = aggregate(np.array(curr_scores)) + pred_labels.append(pred_label) + # calculate FEVER score + fever_scores.append(int(pred_label == curr_label and (pred_label == 1 or \ + any([set(ev_set).issubset(set(curr_evidences)) for ev_set in evidences[curr_query]])))) + curr_query = query_id + curr_scores.clear() + curr_evidences.clear() + # save actual label + if label_str == 'false': + curr_label = 0 + elif label_str == 'weak': + curr_label = 1 + elif label_str == 'true': + curr_label = 2 + labels.append(curr_label) + + # save predicted evidence(s) and scores + if 
args.num_classes == 3: + _, false_score, nei_score, true_score = scores_line.strip().split('\t') + scores = [float(false_score), float(nei_score), float(true_score)] + elif args.num_classes == 4: + _, false_score, ignore_score, true_score, nei_score = scores_line.strip().split('\t') + scores = [float(false_score), float(nei_score), float(true_score), float(ignore_score)] + curr_scores.append(scores) + curr_evidences.extend(sent_ids.strip().split(' ')) + + # handle last query + pred_label = aggregate(np.array(curr_scores)) + pred_labels.append(pred_label) + fever_scores.append(int(pred_label == curr_label and (pred_label == 1 or \ + any([set(ev_set).issubset(set(curr_evidences)) for ev_set in evidences[curr_query]])))) + + print(scores_file) + print(f'Label Accuracy: {accuracy_score(labels, pred_labels)}') + print(f'Predicted Label F1 Scores: {f1_score(labels, pred_labels, average=None)}') + print(f'Predicted Label Distribution: {[pred_labels.count(i) for i in range(args.num_classes)]}') + print(f'FEVER Score: {sum(fever_scores) / len(fever_scores)}') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Calculates various metrics of label prediction output files.') + parser.add_argument('--id_file', required=True, help='Input query-doc pair ids file.') + parser.add_argument('--scores_files_prefix', required=True, help='Prefix of all T5 label prediction scores files.') + parser.add_argument('--dataset_file', help='FEVER dataset file.') + parser.add_argument('--num_classes', type=int, default=3, help='Number of label prediction classes.') + parser.add_argument('--strategy', help='Format of scores file and method of aggregation if applicable.') + args = parser.parse_args() + + calculate_scores(args) diff --git a/experiments/list5/calculate_oracle_accuracy.py b/experiments/list5/calculate_oracle_accuracy.py new file mode 100644 index 00000000..350a0c38 --- /dev/null +++ b/experiments/list5/calculate_oracle_accuracy.py @@ -0,0 +1,91 @@ +import argparse +from collections import Counter +import ftfy +import json + +from fever_utils import make_sentence_id + +def calculate_stats(args): + evidences = {} + num_evidences = [] + + correct = {cutoff: 0 for cutoff in args.k} + max_cutoff = max(args.k) + num_verifiable_queries = 0 + num_queries = 0 + + # read in dataset file and save evidences to dicts + with open(args.dataset_file, 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + + query_id = line_json['id'] + label = line_json['label'] + if label != 'NOT ENOUGH INFO': + num_verifiable_queries += 1 + num_queries += 1 + + annotators = [] + # no evidence set for NEI queries + if label != 'NOT ENOUGH INFO': + for annotator in line_json['evidence']: + evidence_set = [] + for evidence in annotator: + evidence[2] = ftfy.fix_text(evidence[2]) + evidence_set.append(make_sentence_id(evidence[2], evidence[3])) + annotators.append(evidence_set) + num_evidences.append(min([len(evidence_set) for evidence_set in annotators])) + + evidences[query_id] = annotators + + # read in run file and record cutoff counts + with open(args.run_file, 'r', encoding='utf-8') as f: + curr_query = None + pred_sent_ids = [] + for line in f: + query_id, sent_id, rank = line.strip().split('\t') + query_id = int(query_id) + + if query_id != curr_query: + if curr_query is not None: + for rank in args.k: + if not evidences[curr_query]: # if query is NEI, assume it is correct + correct[rank] += 1 + else: + for evidence_set in evidences[curr_query]: + if all([evidence in 
pred_sent_ids for evidence in evidence_set]): + correct[rank] += 1 + break + curr_query = query_id + pred_sent_ids.clear() + + if int(rank) <= max_cutoff: + pred_sent_ids.append(sent_id) + + # handle last query + for rank in args.k: + if not evidences[curr_query]: # if query is NEI, assume it is correct + correct[rank] += 1 + else: + for evidence_set in evidences[curr_query]: + if all([evidence in pred_sent_ids for evidence in evidence_set]): + correct[rank] += 1 + break + + # print number of queries that can be verified with each minimum number of evidences + evidences_counter = Counter(num_evidences) + for num_evidence, count in evidences_counter.most_common(): + print(f'{num_evidence}-verifiable queries: {count / num_verifiable_queries}') + print('--------------------------------------------------') + # print oracle accuracies + for cutoff, num_correct in correct.items(): + print(f'Oracle accuracy for top {cutoff}: {num_correct / num_queries}') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Calculates oracle accuracy (upper bound for label prediction).') + parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.') + parser.add_argument('--run_file', required=True, help='Run file generated after re-ranking.') + parser.add_argument('--k', nargs='+', type=int, help='Cutoff values to calculate oracle accuracy for.') + args = parser.parse_args() + + calculate_stats(args) \ No newline at end of file diff --git a/experiments/list5/convert_run_to_label_prediction_input.py b/experiments/list5/convert_run_to_label_prediction_input.py new file mode 100644 index 00000000..a9889129 --- /dev/null +++ b/experiments/list5/convert_run_to_label_prediction_input.py @@ -0,0 +1,134 @@ +import argparse +import json +import os + +from fever_utils import extract_sentences, make_sentence_id, normalize_text, split_sentence_id, truncate + +def convert_run(args): + queries = {} + labels = {} + evidences = {} + docs = {} + + num_truncated = 0 + + # read in dataset file and save queries to dicts + with open(args.dataset_file, 'r', encoding='utf-8') as f: + print('Reading FEVER dataset file...') + for line in f: + line_json = json.loads(line.strip()) + query_id = line_json['id'] + + query = line_json['claim'] + queries[query_id] = query + + if args.has_labels: + label = line_json['label'] + if label == 'SUPPORTS': + labels[query_id] = 'true' + elif label == 'REFUTES': + labels[query_id] = 'false' + else: # label == 'NOT ENOUGH INFO' + labels[query_id] = 'weak' + + def generate_samples(query_id, pred_sent_ids): + evidence_sets = [] + if args.format == 'concat': + evidence_sets = [[sent_id for sent_id in pred_sent_ids]] + elif args.format == 'agg': + evidence_sets = [[sent_id] for sent_id in pred_sent_ids] + else: # args.format == 'seq': + curr_preds = [] + for sent_id in pred_sent_ids: + curr_preds.append(sent_id) + evidence_sets.append([pred for pred in curr_preds]) + + return evidence_sets + + # read in run file and take top run file ranking predictions + with open(args.run_file, 'r', encoding='utf-8') as f: + print('Reading run file...') + curr_query = None + pred_sent_ids = [] + for line in f: + query_id, sent_id, rank = line.strip().split('\t') + query_id = int(query_id) + + # if we reach a new query in the run file, generate samples for previous query + if query_id != curr_query: + if curr_query is not None: + evidences[curr_query] = generate_samples(curr_query, pred_sent_ids) + curr_query = query_id + pred_sent_ids.clear() + + if int(rank) <= 
args.max_evidences: + doc_id, _ = split_sentence_id(sent_id) + docs[doc_id] = 'N/A' # placeholder + pred_sent_ids.append(sent_id) + + # handle the final query + evidences[curr_query] = generate_samples(curr_query, pred_sent_ids) + + # read through all wiki dump files and save doc text for involved docs + print('Reading wiki pages...') + for file in os.listdir(args.collection_folder): + with open(os.path.join(args.collection_folder, file), 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + if line_json['id'] in docs: + docs[line_json['id']] = line_json['lines'] + + # write query-doc text pairs to files + with open(args.output_id_file, 'w', encoding='utf-8') as f_id, \ + open(args.output_text_file, 'w', encoding='utf-8') as f_text: + print('Writing query-doc pairs to files...') + for query_id, query_text in queries.items(): + if args.has_labels: + label = labels[query_id] + + for evidence_ids in evidences[query_id]: + evidence_texts = [] + for evidence in evidence_ids: + # get specific sentence from within doc_text + doc_id, sent_num = split_sentence_id(evidence) + entity = doc_id.replace('_', ' ') # prepend entity name to document text + doc_text = docs[doc_id] + sent_text, _ = extract_sentences(doc_text)[sent_num] + evidence_texts.append(f'{normalize_text(entity)} . {normalize_text(sent_text)}') + + # format evidence ids and texts in proper format + evidence_ids_str = ' '.join(evidence_ids) + prefixed_evidence_texts = [] + for i, evidence_text in enumerate(evidence_texts): + if args.format == 'agg': + prefixed_evidence_texts.append(f'premise: {evidence_text}') + else: + truncated_text, num_truncated = truncate(query_text, evidence_text, args.max_evidences, + args.max_seq_len, num_truncated) + prefixed_evidence_texts.append(f'sentence{i + 1}: {truncated_text}') + evidence_texts_str = ' '.join(prefixed_evidence_texts) + + if args.has_labels: + f_id.write(f'{query_id}\t{evidence_ids_str}\t{label}\n') + else: + f_id.write(f'{query_id}\t{evidence_ids_str}\n') + f_text.write(f'hypothesis: {query_text} {evidence_texts_str}\n') + + print(f'Number of sentences truncated: {num_truncated}') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Converts run files to T5 label prediction model input format.') + parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.') + parser.add_argument('--run_file', required=True, help='Run file generated after re-ranking.') + parser.add_argument('--collection_folder', required=True, help='FEVER wiki-pages directory.') + parser.add_argument('--output_id_file', required=True, help='Output query-doc id pairs file (empty for test set).') + parser.add_argument('--output_text_file', required=True, help='Output query-doc text pairs file.') + parser.add_argument('--max_evidences', type=int, default=5, help='Max concatenated evidences per line.') + parser.add_argument('--max_seq_len', type=int, default=512, help='Max number of tokens per line.') + parser.add_argument('--format', required=True, choices=['concat', 'agg', 'seq']) + parser.add_argument('--has_labels', action='store_true', help='Whether the dataset file is labelled.') + args = parser.parse_args() + + convert_run(args) + + print('Done!') diff --git a/experiments/list5/convert_run_to_sentence_selection_input.py b/experiments/list5/convert_run_to_sentence_selection_input.py new file mode 100644 index 00000000..8c52cbb8 --- /dev/null +++ b/experiments/list5/convert_run_to_sentence_selection_input.py @@ -0,0 +1,133 @@ +import argparse 
+import ftfy +import itertools +import json +import os + +from fever_utils import extract_entities, extract_sentences, make_sentence_id, normalize_text, split_sentence_id + +def convert_run(args): + queries = {} + evidences = {} + pred_evidences = {} + docs = {} + + # read in dataset file and save queries and evidences to dicts + with open(args.dataset_file, 'r', encoding='utf-8') as f: + print('Reading FEVER dataset file...') + for line in f: + line_json = json.loads(line.strip()) + + query_id = line_json['id'] + + query = line_json['claim'] + queries[query_id] = query + + # only save evidences for non-test sets and non-NEI queries + deduped_evidence_set = set() + if args.has_labels and line_json['label'] != 'NOT ENOUGH INFO': + for annotator in line_json['evidence']: + for evidence in annotator: + evidence[2] = ftfy.fix_text(evidence[2]) + docs[evidence[2]] = 'N/A' # placeholder + deduped_evidence_set.add(make_sentence_id(evidence[2], evidence[3])) + evidences[query_id] = deduped_evidence_set + + # read in run file and save rankings to dict + with open(args.run_file, 'r', encoding='utf-8') as f: + print('Reading run file...') + for line in f: + query_id, sent_id, rank = line.strip().split('\t') + query_id = int(query_id) + doc_id, _ = split_sentence_id(sent_id) + docs[doc_id] = 'N/A' # placeholder + if query_id not in pred_evidences: + pred_evidences[query_id] = [] + if args.k is None or int(rank) <= args.k: + pred_evidences[query_id].append(sent_id) + + # read through all wiki dump files and save doc text for involved docs + print('Reading wiki pages...') + for file in os.listdir(args.collection_folder): + with open(os.path.join(args.collection_folder, file), 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + if line_json['id'] in docs: + docs[line_json['id']] = line_json['lines'] + + # write query-doc pairs to files + with open(args.output_id_file, 'w', encoding='utf-8') as f_id, \ + open(args.output_text_file, 'w', encoding='utf-8') as f_text: + print('Writing query-doc pairs to files...') + for query_id, sent_ids in pred_evidences.items(): + query_text = queries[query_id] + if args.type == 'mono': + if args.ner: + ner_entities = extract_entities(query_text) + + for rank, sent_id in enumerate(sent_ids): + if args.has_labels: + relevance = 'true' if sent_id in evidences[query_id] else 'false' + + # get specific sentence from within doc_text + doc_id, sent_num = split_sentence_id(sent_id) + entity = doc_id.replace('_', ' ') # prepend entity name to document text + doc_text = docs[doc_id] + sent_text, _ = extract_sentences(doc_text)[sent_num] + + # write query-doc pair ids and texts + if args.has_labels: + f_id.write(f'{query_id}\t{sent_id}\t{rank + 1}\t{relevance}\n') + else: + f_id.write(f'{query_id}\t{sent_id}\t{rank + 1}\n') + if args.ner: + numbered_entities = [f'Entity{i + 1}: {entity}' for i, entity in enumerate(ner_entities)] + entities_str = ' '.join(numbered_entities) + f_text.write( + f'Query: {query_text} Document: {entity} . {normalize_text(sent_text)} {entities_str} Relevant:\n' + ) + else: + f_text.write( + f'Query: {query_text} Document: {entity} . 
{normalize_text(sent_text)} Relevant:\n') + else: # args.type == 'duo' + ranked_sent_ids = [(sent_id, i) for i, sent_id in enumerate(sent_ids)] + for (sent_id_1, rank_1), (sent_id_2, rank_2) in itertools.permutations(ranked_sent_ids, 2): + if args.has_labels: + relevance = 'true' if sent_id_1 in evidences[query_id] else 'false' + + # get specific sentence from within doc_text + doc_id_1, sent_1_num = split_sentence_id(sent_id_1) + entity_1 = doc_id_1.replace('_', ' ') # prepend entity name to document text + doc_text_1 = docs[doc_id_1] + sent_1_text, _ = extract_sentences(doc_text_1)[sent_1_num] + + doc_id_2, sent_2_num = split_sentence_id(sent_id_2) + entity_2 = doc_id_2.replace('_', ' ') # prepend entity name to document text + doc_text_2 = docs[doc_id_2] + sent_2_text, _ = extract_sentences(doc_text_2)[sent_2_num] + + # write query-doc pair ids and texts + if args.has_labels: + f_id.write(f'{query_id}\t{sent_id_1}\t{rank_1 + 1}\t{sent_id_2}\t{rank_2 + 1}\t{relevance}\n') + else: + f_id.write(f'{query_id}\t{sent_id_1}\t{rank_1 + 1}\t{sent_id_2}\t{rank_2 + 1}\n') + f_text.write( + f'Query: {query_text} Document1: {entity_1} . {normalize_text(sent_1_text)} Document2: {entity_2} . {normalize_text(sent_2_text)} Relevant:\n' + ) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Converts run files to T5 sentence re-ranking model input format.') + parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.') + parser.add_argument('--run_file', required=True, help='Run file from running retrieval with anserini.') + parser.add_argument('--collection_folder', required=True, help='FEVER wiki-pages directory.') + parser.add_argument('--output_id_file', required=True, help='Output query-doc id pairs file.') + parser.add_argument('--output_text_file', required=True, help='Output query-doc text pairs file.') + parser.add_argument('--k', type=int, help='Number of top sentences to include for re-ranking.') + parser.add_argument('--type', required=True, choices=['mono', 'duo'], help='Type of T5 inference.') + parser.add_argument('--has_labels', action='store_true', help='Whether the dataset file is labelled.') + parser.add_argument('--ner', action='store_true', help='Whether to append NER entities (only for mono re-ranking).') + args = parser.parse_args() + + convert_run(args) + + print('Done!') diff --git a/experiments/list5/convert_sentence_selection_output_to_run.py b/experiments/list5/convert_sentence_selection_output_to_run.py new file mode 100644 index 00000000..cfe1e139 --- /dev/null +++ b/experiments/list5/convert_sentence_selection_output_to_run.py @@ -0,0 +1,76 @@ +import argparse +import numpy as np + +def convert_output(args): + print('Converting T5 output...') + + with open(args.id_file, 'r', encoding='utf-8') as f_id, open(args.scores_file, 'r', encoding='utf-8') as f_scores, \ + open(args.output_run_file, 'w', encoding='utf-8') as f_run: + curr_qid = None + curr_scores = {} + for id_line, scores_line in zip(f_id, f_scores): + if args.type == 'mono': + if args.has_labels: + query_id, sent_id, _, _ = id_line.strip().split('\t') + else: + query_id, sent_id, _ = id_line.strip().split('\t') + else: # args.type == 'duo' + if args.has_labels: + query_id, sent_id_1, _, sent_id_2, _, _ = id_line.strip().split('\t') + else: + query_id, sent_id_1, _, sent_id_2, _ = id_line.strip().split('\t') + _, score = scores_line.strip().split('\t') + + # check if we have reached a new query_id + if query_id != curr_qid: + # sort previously accumulated doc scores and 
write to run file + sorted_scores = sorted(curr_scores.items(), key=lambda x: x[1], reverse=True) + curr_index = 1 + for curr_sid, curr_score in sorted_scores: + if args.k is not None and curr_index > args.k: + break + # keep the top predicted result even if it does not meet the threshold + if curr_index == 1 or args.p is None or np.exp(curr_score) >= args.p: + f_run.write(f'{curr_qid}\t{curr_sid}\t{curr_index}\n') + curr_index += 1 + + # update curr_qid and curr_scores with new query_id + curr_qid = query_id + curr_scores.clear() + + # save current score + if args.type == 'mono': + curr_scores[sent_id] = float(score) + else: # args.type == 'duo' + if sent_id_1 not in curr_scores: + curr_scores[sent_id_1] = 0 + if sent_id_2 not in curr_scores: + curr_scores[sent_id_2] = 0 + curr_scores[sent_id_1] += np.exp(float(score)) + curr_scores[sent_id_2] += 1 - np.exp(float(score)) + + # write last query_id to file + sorted_scores = sorted(curr_scores.items(), key=lambda x: x[1], reverse=True) + curr_index = 1 + for curr_sid, curr_score in sorted_scores: + if args.k is not None and curr_index > args.k: + break + # keep the top predicted result even if it does not meet the threshold + if curr_index == 1 or args.p is None or np.exp(curr_score) >= args.p: + f_run.write(f'{curr_qid}\t{curr_sid}\t{curr_index}\n') + curr_index += 1 + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Converts T5 re-ranking outputs to anserini run file format.') + parser.add_argument('--id_file', required=True, help='Input query-doc pair ids file.') + parser.add_argument('--scores_file', required=True, help='Prediction scores file outputted by T5 re-ranking model.') + parser.add_argument('--output_run_file', required=True, help='Output run file.') + parser.add_argument('--p', type=float, help='Optional probability threshold.') + parser.add_argument('--k', type=int, help='Optional top-k cutoff.') + parser.add_argument('--type', required=True, choices=['mono', 'duo'], help='Type of T5 inference.') + parser.add_argument('--has_labels', action='store_true', help='Whether the dataset file is labelled.') + args = parser.parse_args() + + convert_output(args) + + print('Done!') diff --git a/experiments/list5/expand_docs_to_sentences.py b/experiments/list5/expand_docs_to_sentences.py new file mode 100644 index 00000000..ba1bce17 --- /dev/null +++ b/experiments/list5/expand_docs_to_sentences.py @@ -0,0 +1,53 @@ +import argparse +import json +import os + +from fever_utils import extract_sentences, make_sentence_id + +def convert_run(args): + doc_sentences = {} + rankings = {} + + # read in input run file and save rankings to dict + with open(args.input_run_file, 'r', encoding='utf-8') as f: + print('Reading run file...') + for line in f: + query_id, doc_id, rank = line.strip().split('\t') + if doc_id not in doc_sentences: + doc_sentences[doc_id] = [] + if query_id not in rankings: + rankings[query_id] = [] + rankings[query_id].append(doc_id) + + # read through all wiki dump files and save sentence IDs for involved docs + print('Reading wiki pages...') + for file in os.listdir(args.collection_folder): + with open(os.path.join(args.collection_folder, file), 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + if line_json['id'] in doc_sentences: + sent_ids = [id for sent, id in extract_sentences(line_json['lines']) if sent] + doc_sentences[line_json['id']].extend(sent_ids) + + # write expanded sentence IDs to output run file + with open(args.output_run_file, 'w', 
encoding='utf-8') as f: + print('Writing sentences to run file...') + for query_id, doc_ids in rankings.items(): + query_index = 1 + for doc_id in doc_ids[:args.k]: + for sent_num in doc_sentences[doc_id]: + sent_id = make_sentence_id(doc_id, sent_num) + f.write(f'{query_id}\t{sent_id}\t{query_index}\n') + query_index += 1 + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Expands document-level anserini run file to sentence-level.') + parser.add_argument('--input_run_file', required=True, help='Input document-level run file.') + parser.add_argument('--collection_folder', required=True, help='FEVER wiki-pages directory.') + parser.add_argument('--output_run_file', required=True, help='Output sentence-level run file.') + parser.add_argument('--k', default=100, type=int, help='Top k documents to expand.') + args = parser.parse_args() + + convert_run(args) + + print('Done!') diff --git a/experiments/list5/fever_utils.py b/experiments/list5/fever_utils.py new file mode 100644 index 00000000..4c45553e --- /dev/null +++ b/experiments/list5/fever_utils.py @@ -0,0 +1,112 @@ +import re +import spacy +from transformers import T5Tokenizer + +# regex patterns in order of priority +PTB_SYMBOLS = [ + (re.compile(r'-LRB-'), '('), + (re.compile(r'-RRB-'), ')'), + (re.compile(r'\( ([^\(\)]*?) \)'), '(\\1)'), + (re.compile(r'-LSB-'), '['), + (re.compile(r'-RSB-'), ']'), + (re.compile(r'\[ ([^\[\]]*?) \]'), ''), # most occurrences of [] contain pronounciations, which we don't want + (re.compile(r'-LCB-'), '{'), + (re.compile(r'-RCB-'), '}'), + (re.compile(r'\{ ([^\{\}]*?) \}'), '{\\1}'), + (re.compile(r'-COLON-'), ':'), + (re.compile(r'–|−'), '-'), + (re.compile(r'`` ([^`]*?) \'\''), '"\\1"'), + (re.compile(r'` ([^`]*?) \''), '\'\\1\''), + (re.compile(r' ([,\.:;\'!?])'), '\\1') +] + +# NER +nlp = spacy.load('en_core_web_sm') + +# T5 tokenizer +tokenizer = T5Tokenizer.from_pretrained('t5-3b') + +def extract_sentences(lines): + """ + Extracts the non-empty sentences and their numbers of the "lines" field in + a JSON object from a FEVER wiki-pages JSONL file. + """ + sentences = [] + + sentence_index = 0 + for line in lines.split('\n'): + tokens = line.split('\t') + if not tokens[0].isnumeric() or int(tokens[0]) != sentence_index: + # skip non-sentences, caused by unexpected \n's + continue + else: + sentences.append((tokens[1], tokens[0])) + sentence_index += 1 + + return sentences + +def make_sentence_id(doc_id, sentence_num): + """ + Returns the sentence ID of a Wikipedia document ID and the number + corresponding to its specific sentence index in the document. + """ + return f'{doc_id}_{sentence_num}' + +def split_sentence_id(sentence_id): + """ + Returns the original document ID and sentence number of a sentence ID. + """ + separator_index = sentence_id.rfind('_') + doc_id = sentence_id[:separator_index] + sent_num = int(sentence_id[separator_index + 1:]) + + return doc_id, sent_num + +def normalize_text(text): + """ + Normalizes text found in FEVER dataset, removing punctuation tokens and + cleaning whitespace around punctuation. + """ + for regexp, substitution in PTB_SYMBOLS: + text = regexp.sub(substitution, text) + + return text + +def remove_disambiguation(doc_id): + """ + Normalizes and removes disambiguation info from a document ID. 
+ """ + doc_id = doc_id.replace('_', ' ').replace('-COLON-', ':') + if '-LRB-' in doc_id: + doc_id = doc_id[:doc_id.find('-LRB-') - 1] + + return doc_id + +def extract_entities(text): + """ + Extracts named entities from text using spaCy's en_core_web_sm NER module. + """ + doc = nlp(text) + ner_entities = list( + set([ + entity.text for entity in doc.ents + # ignore entities that are less likely to correspond to a Wikipedia article + if entity.label_ not in ['DATE', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL', 'CARDINAL'] + ])) + + return ner_entities + +def truncate(query, sent, num_sents, line_len, trunc_count): + """ + Truncates evidence sentence to fit in a T5 input line with max line_len + tokens and num_sents sentences to be concatenated. Accounts for query text + and tracks the number of sentences truncated. + """ + query_tokens = tokenizer.tokenize(query) + trunc_len = (line_len - 2 - len(query_tokens) - 3 * num_sents - 1) // num_sents + tokens = tokenizer.tokenize(sent) + if len(tokens) > trunc_len: + tokens = tokens[:trunc_len] + return tokenizer.convert_tokens_to_string(tokens), trunc_count + 1 + + return sent, trunc_count diff --git a/experiments/list5/generate_label_prediction_data_gold.py b/experiments/list5/generate_label_prediction_data_gold.py new file mode 100644 index 00000000..685ce782 --- /dev/null +++ b/experiments/list5/generate_label_prediction_data_gold.py @@ -0,0 +1,143 @@ +import argparse +import ftfy +import json +import os +import random + +from fever_utils import extract_sentences, make_sentence_id, normalize_text, split_sentence_id, truncate + +def generate_data(args): + queries = {} + labels = {} + evidences = {} + docs = {} + + num_truncated = 0 + + # read in dataset file and save queries and evidences to dicts + with open(args.dataset_file, 'r', encoding='utf-8') as f: + print('Reading FEVER dataset file...') + for line in f: + line_json = json.loads(line.strip()) + query_id = line_json['id'] + + query = line_json['claim'] + queries[query_id] = query + + label = line_json['label'] + if label == 'SUPPORTS': + labels[query_id] = 'true' + elif label == 'REFUTES': + labels[query_id] = 'false' + else: # label == 'NOT ENOUGH INFO' + labels[query_id] = 'weak' + + annotators = [] + if label != 'NOT ENOUGH INFO': # no evidence set for NEI queries, will sample from run files later + for annotator in line_json['evidence']: + evidence_set = [] + for evidence in annotator: + evidence[2] = ftfy.fix_text(evidence[2]) + docs[evidence[2]] = 'N/A' # placeholder + evidence_set.append(make_sentence_id(evidence[2], evidence[3])) + annotators.append(evidence_set) + evidences[query_id] = annotators + + # samples evidence from pred_sent_ids + def negative_sample(query_id, pred_sent_ids): + neg_sent_ids = random.sample(pred_sent_ids, random.randint(1, args.max_evidences)) + + for sent_id in neg_sent_ids: + doc_id, _ = split_sentence_id(sent_id) + docs[doc_id] = 'N/A' # placeholder + + return [neg_sent_ids] + + # read in run file and sample run file ranking predictions for queries + with open(args.run_file, 'r', encoding='utf-8') as f: + print('Reading run file...') + curr_query = None + pred_sent_ids = [] + for line in f: + query_id, sent_id, rank = line.strip().split('\t') + query_id = int(query_id) + + # if we reach a new query in the run file, perform sampling for previous query if needed + if query_id != curr_query: + if curr_query is not None and len(evidences[curr_query]) == 0: + evidences[curr_query] = negative_sample(curr_query, pred_sent_ids) + curr_query = 
query_id + pred_sent_ids.clear() + + if args.min_rank <= int(rank) <= args.max_rank: + pred_sent_ids.append(sent_id) + + # handle the final query + if len(evidences[curr_query]) == 0: + evidences[curr_query] = negative_sample(curr_query, pred_sent_ids) + + # read through all wiki dump files and save doc text for involved docs + print('Reading wiki pages...') + for file in os.listdir(args.collection_folder): + with open(os.path.join(args.collection_folder, file), 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + if line_json['id'] in docs: + docs[line_json['id']] = line_json['lines'] + + # write query-doc text pairs to files + with open(args.output_id_file, 'w', encoding='utf-8') as f_id, \ + open(args.output_text_file, 'w', encoding='utf-8') as f_text: + print('Writing query-doc pairs to files...') + for query_id, query_text in queries.items(): + label = labels[query_id] + + for evidence_ids in evidences[query_id]: + evidence_texts = [] + for evidence in evidence_ids: + # get specific sentence from within doc_text + doc_id, sent_num = split_sentence_id(evidence) + entity = doc_id.replace('_', ' ') # prepend entity name to document text + doc_text = docs[doc_id] + sent_text, _ = extract_sentences(doc_text)[sent_num] + evidence_texts.append(f'{normalize_text(entity)} . {normalize_text(sent_text)}') + + if args.format == 'concat': + evidence_ids_str = ' '.join(evidence_ids) + prefixed_evidence_texts = [] + for i, evidence_text in enumerate(evidence_texts): + truncated_text, num_truncated = truncate(query_text, evidence_text, args.max_evidences, + args.max_seq_len, num_truncated) + prefixed_evidence_texts.append(f'sentence{i + 1}: {truncated_text}') + evidence_texts_str = ' '.join(prefixed_evidence_texts) + + f_id.write(f'{query_id}\t{evidence_ids_str}\n') + f_text.write(f'hypothesis: {query_text} {evidence_texts_str}\t{label}\n') + else: # args.format == 'agg' + for evidence_id, evidence_text in zip(evidence_ids, evidence_texts): + f_id.write(f'{query_id}\t{evidence_id}\n') + f_text.write(f'hypothesis: {query_text} premise: {evidence_text}\t{label}\n') + + print(f'Number of sentences truncated: {num_truncated}') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generates "gold" FEVER label prediction training data.') + parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.') + parser.add_argument('--run_file', required=True, help='Run file generated after re-ranking.') + parser.add_argument('--collection_folder', required=True, help='FEVER wiki-pages directory.') + parser.add_argument('--output_id_file', required=True, help='Output query-doc id pairs file.') + parser.add_argument('--output_text_file', required=True, help='Output query-doc text pairs file.') + parser.add_argument('--min_rank', type=int, help='Smallest rank to sample from.') + parser.add_argument('--max_rank', type=int, help='Largest rank to sample from.') + parser.add_argument('--max_evidences', type=int, required=True, help='Max number of evidences to negative sample.') + parser.add_argument('--max_seq_len', type=int, default=512, help='Max number of tokens per line.') + parser.add_argument('--format', required=True, choices=['concat', 'agg'], help='Format of output query-doc files.') + parser.add_argument('--seed', type=int, help='Optional seed for random sampling.') + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + + generate_data(args) + + print('Done!') diff --git 
a/experiments/list5/generate_label_prediction_data_noisy.py b/experiments/list5/generate_label_prediction_data_noisy.py new file mode 100644 index 00000000..add990e5 --- /dev/null +++ b/experiments/list5/generate_label_prediction_data_noisy.py @@ -0,0 +1,162 @@ +import argparse +import ftfy +import json +import os +import random + +from fever_utils import extract_sentences, make_sentence_id, normalize_text, split_sentence_id, truncate + +def generate_data(args): + queries = {} + labels = {} + evidences = {} + evidence_relevances = {} + docs = {} + + num_truncated = 0 + + # read in dataset file and save queries and evidences to dicts + with open(args.dataset_file, 'r', encoding='utf-8') as f: + print('Reading FEVER dataset file...') + for line in f: + line_json = json.loads(line.strip()) + query_id = line_json['id'] + + query = line_json['claim'] + queries[query_id] = query + + label = line_json['label'] + if label == 'SUPPORTS': + labels[query_id] = 'true' + elif label == 'REFUTES': + labels[query_id] = 'false' + else: # label == 'NOT ENOUGH INFO' + labels[query_id] = 'weak' + + annotators = [] + if label != 'NOT ENOUGH INFO': # no evidence set for NEI queries, will sample from run files later + for annotator in line_json['evidence']: + evidence_set = [] + for evidence in annotator: + evidence[2] = ftfy.fix_text(evidence[2]) + evidence_set.append(make_sentence_id(evidence[2], evidence[3])) + annotators.append(evidence_set) + else: + annotators.append([]) + evidences[query_id] = annotators + + # for each evidence set, check if all gold evidences are in pred_sent_ids and randomly insert if not present + def generate_samples(query_id, pred_sent_ids): + all_sent_ids = [] + all_relevances = [] + + for true_evidence_set in evidences[query_id]: + sent_ids = [evidence for evidence in pred_sent_ids] + relevances = [int(evidence in true_evidence_set) for evidence in pred_sent_ids] + + # randomly insert relevant evidences if query is not NEI and not all true evidences are in sent_ids + if len(true_evidence_set) != 0 and len(true_evidence_set) != sum(relevances): + for evidence in true_evidence_set: + # stop inserting if all evidences are relevant + if sum(relevances) == len(relevances): + break + if evidence not in sent_ids: + doc_id, _ = split_sentence_id(evidence) + docs[doc_id] = 'N/A' # placeholder + + overwrite_index = random.choice([i for i in range(len(relevances)) if relevances[i] == 0]) + sent_ids[overwrite_index] = evidence + relevances[overwrite_index] = 1 + + all_sent_ids.append(sent_ids) + all_relevances.append(relevances) + + return all_sent_ids, all_relevances + + # read in run file and sample run file ranking predictions for queries + with open(args.run_file, 'r', encoding='utf-8') as f: + print('Reading run file...') + curr_query = None + pred_sent_ids = [] + for line in f: + query_id, sent_id, rank = line.strip().split('\t') + query_id = int(query_id) + + # if we reach a new query in the run file, perform sampling for previous query if needed + if query_id != curr_query: + if curr_query is not None: + all_sent_ids, all_relevances = generate_samples(curr_query, pred_sent_ids) + evidences[curr_query] = all_sent_ids + evidence_relevances[curr_query] = all_relevances + curr_query = query_id + pred_sent_ids.clear() + + if int(rank) <= args.max_evidences: + doc_id, _ = split_sentence_id(sent_id) + docs[doc_id] = 'N/A' # placeholder + pred_sent_ids.append(sent_id) + + # handle the final query + all_sent_ids, all_relevances = generate_samples(curr_query, pred_sent_ids) + 
evidences[curr_query] = all_sent_ids + evidence_relevances[curr_query] = all_relevances + + # read through all wiki dump files and save doc text for involved docs + print('Reading wiki pages...') + for file in os.listdir(args.collection_folder): + with open(os.path.join(args.collection_folder, file), 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + if line_json['id'] in docs: + docs[line_json['id']] = line_json['lines'] + + # write query-doc text pairs to files + with open(args.output_id_file, 'w', encoding='utf-8') as f_id, \ + open(args.output_text_file, 'w', encoding='utf-8') as f_text: + print('Writing query-doc pairs to files...') + for query_id, query_text in queries.items(): + label = labels[query_id] + + for evidence_ids, relevances in zip(evidences[query_id], evidence_relevances[query_id]): + evidence_texts = [] + for evidence in evidence_ids: + # get specific sentence from within doc_text + doc_id, sent_num = split_sentence_id(evidence) + entity = doc_id.replace('_', ' ') # prepend entity name to document text + doc_text = docs[doc_id] + sent_text, _ = extract_sentences(doc_text)[sent_num] + evidence_texts.append(f'{normalize_text(entity)} . {normalize_text(sent_text)}') + + # format evidence ids and texts in proper format + evidence_ids_str = ' '.join(evidence_ids) + relevances_str = ','.join([str(relevance) for relevance in relevances]) + prefixed_evidence_texts = [] + for i, evidence_text in enumerate(evidence_texts): + truncated_text, num_truncated = truncate(query_text, evidence_text, args.max_evidences, + args.max_seq_len, num_truncated) + prefixed_evidence_texts.append(f'sentence{i + 1}: {truncated_text}') + evidence_texts_str = ' '.join(prefixed_evidence_texts) + + f_id.write(f'{query_id}\t{evidence_ids_str}\t{relevances_str}\n') + f_text.write(f'hypothesis: {query_text} {evidence_texts_str}\t{label}\n') + + print(f'Number of sentences truncated: {num_truncated}') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generates "noise-infused" FEVER label prediction training data.') + parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.') + parser.add_argument('--run_file', required=True, help='Run file generated after re-ranking.') + parser.add_argument('--collection_folder', required=True, help='FEVER wiki-pages directory.') + parser.add_argument('--output_id_file', required=True, help='Output query-doc id pairs file.') + parser.add_argument('--output_text_file', required=True, help='Output query-doc text pairs file.') + parser.add_argument('--max_evidences', type=int, default=5, help='Max concatenated evidences per line.') + parser.add_argument('--max_seq_len', type=int, default=512, help='Max number of tokens per line.') + parser.add_argument('--seed', type=int, help='Optional seed for random sampling.') + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + + generate_data(args) + + print('Done!') diff --git a/experiments/list5/generate_sentence_selection_data.py b/experiments/list5/generate_sentence_selection_data.py new file mode 100644 index 00000000..e5dac240 --- /dev/null +++ b/experiments/list5/generate_sentence_selection_data.py @@ -0,0 +1,120 @@ +import argparse +import ftfy +import json +import os +import random + +from fever_utils import extract_sentences, make_sentence_id, normalize_text, split_sentence_id + +def generate_data(args): + queries = {} + evidences = {} + pred_evidences = {} + docs = {} + + # read in dataset file and save 
queries and evidences to dicts + with open(args.dataset_file, 'r', encoding='utf-8') as f: + print('Reading FEVER dataset file...') + for line in f: + line_json = json.loads(line.strip()) + + query_id = line_json['id'] + + query = line_json['claim'] + queries[query_id] = query + + # only save evidences for non-test sets and non-NEI queries + deduped_evidence_set = set() + if line_json['label'] != 'NOT ENOUGH INFO': + for annotator in line_json['evidence']: + for evidence in annotator: + evidence[2] = ftfy.fix_text(evidence[2]) + docs[evidence[2]] = 'N/A' # placeholder + deduped_evidence_set.add(make_sentence_id(evidence[2], evidence[3])) + evidences[query_id] = deduped_evidence_set + + def generate_samples(query_id, pred_sent_ids): + curr_pred_evidences = [] + + # include all ground truth relevant evidences as positive samples + for sent_id in evidences[query_id]: + curr_pred_evidences.append(sent_id) + + # sample negative evidences from pred_sent_ids + neg_pred_sent_ids = [pred for pred in pred_sent_ids if pred not in evidences[query_id]] + neg_sent_ids = random.sample(neg_pred_sent_ids, min(len(evidences[query_id]), len(neg_pred_sent_ids))) + for sent_id in neg_sent_ids: + doc_id, _ = split_sentence_id(sent_id) + docs[doc_id] = 'N/A' # placeholder + curr_pred_evidences.append(sent_id) + + return curr_pred_evidences + + # read in run file and negative sample using run file ranking predictions + with open(args.run_file, 'r', encoding='utf-8') as f: + print('Reading run file...') + curr_query = None + pred_sent_ids = [] + for line in f: + query_id, sent_id, rank = line.strip().split('\t') + query_id = int(query_id) + + # if we reach a new query in the run file, perform sampling for the previous query + if query_id != curr_query: + if curr_query is not None: + pred_evidences[curr_query] = generate_samples(curr_query, pred_sent_ids) + curr_query = query_id + pred_sent_ids.clear() + + if args.min_rank <= int(rank) <= args.max_rank: + pred_sent_ids.append(sent_id) + + # perform sampling for the final query + pred_evidences[curr_query] = generate_samples(curr_query, pred_sent_ids) + + # read through all wiki dump files and save doc text for involved docs + print('Reading wiki pages...') + for file in os.listdir(args.collection_folder): + with open(os.path.join(args.collection_folder, file), 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + if line_json['id'] in docs: + docs[line_json['id']] = line_json['lines'] + + # write query-doc text pairs to files + with open(args.output_id_file, 'w', encoding='utf-8') as f_id, \ + open(args.output_text_file, 'w', encoding='utf-8') as f_text: + print('Writing query-doc pairs to files...') + for query_id, sent_ids in pred_evidences.items(): + query_text = queries[query_id] + + for rank, sent_id in enumerate(sent_ids): + relevance = 'true' if sent_id in evidences[query_id] else 'false' + # get specific sentence from within doc_text + doc_id, sent_num = split_sentence_id(sent_id) + entity = doc_id.replace('_', ' ') # prepend entity name to document text + doc_text = docs[doc_id] + sent_text, _ = extract_sentences(doc_text)[sent_num] + + f_id.write(f'{query_id}\t{sent_id}\t{rank + 1}\n') + f_text.write( + f'Query: {query_text} Document: {entity} . 
{normalize_text(sent_text)} Relevant:\t{relevance}\n') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generates FEVER re-ranking training data.') + parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.') + parser.add_argument('--run_file', required=True, help='Run file from running retrieval with anserini.') + parser.add_argument('--collection_folder', required=True, help='FEVER wiki-pages directory.') + parser.add_argument('--output_id_file', required=True, help='Output query-doc id pairs file.') + parser.add_argument('--output_text_file', required=True, help='Output query-doc text pairs file.') + parser.add_argument('--min_rank', type=int, help='Smallest rank to sample from (for negative samples).') + parser.add_argument('--max_rank', type=int, help='Largest rank to sample from (for negative samples).') + parser.add_argument('--seed', type=int, help='Optional seed for random sampling.') + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + + generate_data(args) + + print('Done!') diff --git a/experiments/list5/generate_sentence_selection_ner_data.py b/experiments/list5/generate_sentence_selection_ner_data.py new file mode 100644 index 00000000..06962cae --- /dev/null +++ b/experiments/list5/generate_sentence_selection_ner_data.py @@ -0,0 +1,150 @@ +import argparse +import ftfy +import json +import os +import random + +from fever_utils import extract_entities, extract_sentences, make_sentence_id, normalize_text, remove_disambiguation, split_sentence_id + +def generate_data(args): + queries = {} + evidences = {} + pred_evidences = {} + docs = {} + + num_actual = 0 + num_pred = 0 + correct = 0 + + # read in dataset file and save queries and evidences to dicts + with open(args.dataset_file, 'r', encoding='utf-8') as f: + print('Reading FEVER dataset file...') + for line in f: + line_json = json.loads(line.strip()) + + query_id = line_json['id'] + + query = line_json['claim'] + queries[query_id] = query + + # only save evidences for non-test sets and non-NEI queries + deduped_evidence_set = set() + if line_json['label'] != 'NOT ENOUGH INFO': + for annotator in line_json['evidence']: + for evidence in annotator: + evidence[2] = ftfy.fix_text(evidence[2]) + docs[evidence[2]] = 'N/A' # placeholder + deduped_evidence_set.add(make_sentence_id(evidence[2], evidence[3])) + evidences[query_id] = deduped_evidence_set + + def generate_samples(query_id, pred_sent_ids): + curr_pred_evidences = [] + + # include all ground truth relevant evidences as positive samples + for sent_id in evidences[query_id]: + curr_pred_evidences.append(sent_id) + + # sample negative evidences from pred_sent_ids + neg_pred_sent_ids = [pred for pred in pred_sent_ids if pred not in evidences[query_id]] + neg_sent_ids = random.sample(neg_pred_sent_ids, min(len(evidences[query_id]), len(neg_pred_sent_ids))) + for sent_id in neg_sent_ids: + doc_id, _ = split_sentence_id(sent_id) + docs[doc_id] = 'N/A' # placeholder + curr_pred_evidences.append(sent_id) + + return curr_pred_evidences + + # read in run file and negative sample using run file ranking predictions + with open(args.run_file, 'r', encoding='utf-8') as f: + print('Reading run file...') + curr_query = None + pred_sent_ids = [] + for line in f: + query_id, sent_id, rank = line.strip().split('\t') + query_id = int(query_id) + + # if we reach a new query in the run file, perform sampling for the previous query + if query_id != curr_query: + if curr_query is not None: + 
pred_evidences[curr_query] = generate_samples(curr_query, pred_sent_ids) + curr_query = query_id + pred_sent_ids.clear() + + if args.min_rank <= int(rank) <= args.max_rank: + pred_sent_ids.append(sent_id) + + # perform sampling for the final query + pred_evidences[curr_query] = generate_samples(curr_query, pred_sent_ids) + + # read through all wiki dump files and save doc text for involved docs + print('Reading wiki pages...') + for file in os.listdir(args.collection_folder): + with open(os.path.join(args.collection_folder, file), 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + if line_json['id'] in docs: + docs[line_json['id']] = line_json['lines'] + + # write query-doc text pairs to files + with open(args.output_id_file, 'w', encoding='utf-8') as f_id, \ + open(args.output_text_file, 'w', encoding='utf-8') as f_text: + print('Writing query-doc pairs to files...') + for query_id, sent_ids in pred_evidences.items(): + query_text = queries[query_id] + + # only track actual entities that can be found within the query + actual_entities = [] + for sent_id in evidences[query_id]: + entity = remove_disambiguation(split_sentence_id(sent_id)[0]) + if entity not in actual_entities and entity.lower() in query_text.lower(): + actual_entities.append(entity) + num_actual += len(actual_entities) + + # run NER to get predicted entities + ner_entities = extract_entities(query_text) + num_pred += len(ner_entities) + + correct += sum([int(entity.lower() in [ent.lower() for ent in ner_entities]) for entity in actual_entities]) + + for rank, sent_id in enumerate(sent_ids): + relevance = 'true' if sent_id in evidences[query_id] else 'false' + # get specific sentence from within doc_text + doc_id, sent_num = split_sentence_id(sent_id) + entity = doc_id.replace('_', ' ') # prepend entity name to document text + doc_text = docs[doc_id] + sent_text, _ = extract_sentences(doc_text)[sent_num] + + numbered_entities = [f'Entity{i + 1}: {entity}' for i, entity in enumerate(ner_entities)] + entities_str = ' '.join(numbered_entities) + + f_id.write(f'{query_id}\t{sent_id}\t{rank + 1}\n') + f_text.write( + f'Query: {query_text} Document: {entity} . 
{normalize_text(sent_text)} {entities_str} Relevant:\t{relevance}\n'
+                )
+
+    print('****************************************')
+    print(f'Actual Entities: {num_actual}')
+    print(f'Predicted Entities: {num_pred}')
+    print(f'Correctly Predicted Entities: {correct}')
+    print(f'Precision: {correct / num_pred}')
+    print(f'Recall: {correct / num_actual}')
+    print('****************************************')
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Generates FEVER re-ranking training data with NER entities.')
+    parser.add_argument('--dataset_file', required=True, help='FEVER dataset file.')
+    parser.add_argument('--run_file', required=True, help='Run file from running retrieval with anserini.')
+    parser.add_argument('--collection_folder', required=True, help='FEVER wiki-pages directory.')
+    parser.add_argument('--output_id_file', required=True, help='Output query-doc id pairs file.')
+    parser.add_argument('--output_text_file', required=True, help='Output query-doc text pairs file.')
+    parser.add_argument('--min_rank', type=int, help='Smallest rank to sample from (for negative samples).')
+    parser.add_argument('--max_rank', type=int, help='Largest rank to sample from (for negative samples).')
+    parser.add_argument('--seed', type=int, help='Optional seed for random sampling.')
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        random.seed(args.seed)
+
+    generate_data(args)
+
+    print('Done!')
diff --git a/experiments/list5/merge_runs.py b/experiments/list5/merge_runs.py
new file mode 100644
index 00000000..a6a2dc6d
--- /dev/null
+++ b/experiments/list5/merge_runs.py
@@ -0,0 +1,49 @@
+import argparse
+import itertools
+
+def merge_runs(args):
+    rankings = {}
+
+    # read in input run files and save rankings to dict
+    for input_file in args.input_run_file:
+        with open(input_file, 'r', encoding='utf-8') as f:
+            print(f'Reading input run file {input_file}...')
+            for line in f:
+                query_id, doc_id, _ = line.strip().split('\t')
+                if query_id not in rankings:
+                    rankings[query_id] = {}
+                if input_file not in rankings[query_id]:
+                    rankings[query_id][input_file] = []
+                rankings[query_id][input_file].append(doc_id)
+
+    # write merged document rankings to output run file
+    with open(args.output_run_file, 'w', encoding='utf-8') as f:
+        print('Writing merged results to run file...')
+        for query_id, files in rankings.items():
+            doc_ids = []
+            doc_ids_set = set()
+            if args.strategy == 'zip':
+                for curr_doc_ids in itertools.zip_longest(*files.values()):
+                    for doc_id in curr_doc_ids:
+                        if doc_id and doc_id not in doc_ids_set:
+                            doc_ids.append(doc_id)
+                            doc_ids_set.add(doc_id)
+            else: # args.strategy == 'sequential'
+                for curr_doc_ids in files.values():
+                    for doc_id in curr_doc_ids:
+                        if doc_id not in doc_ids_set:
+                            doc_ids.append(doc_id)
+                            doc_ids_set.add(doc_id)
+            for i, doc_id in enumerate(doc_ids):
+                f.write(f'{query_id}\t{doc_id}\t{i + 1}\n')
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Merges several anserini run files.')
+    parser.add_argument('--input_run_file', required=True, action='append', help='Input run files.')
+    parser.add_argument('--output_run_file', required=True, help='Output run file.')
+    parser.add_argument('--strategy', required=True, choices=['zip', 'sequential'], help='Strategy to merge the runs.')
+    args = parser.parse_args()
+
+    merge_runs(args)
+
+    print('Done!')
diff --git a/experiments/list5/predict_for_submission.py b/experiments/list5/predict_for_submission.py
new file mode 100644
index 00000000..8281dcd8
--- /dev/null
+++ b/experiments/list5/predict_for_submission.py
@@ -0,0 +1,106 @@
+import argparse
+import csv
+import json
+import numpy as np
+from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
+
+from fever_utils import split_sentence_id
+
+def predict(args):
+    preds = {}
+
+    def aggregate(query_id, scores, sent_ids):
+        pred = {}
+
+        best = np.argmax(scores[0])
+
+        pred['id'] = query_id
+        if best == 0:
+            pred['predicted_label'] = 'REFUTES'
+        elif best == 1:
+            pred['predicted_label'] = 'NOT ENOUGH INFO'
+        else: # best == 2
+            pred['predicted_label'] = 'SUPPORTS'
+        pred['predicted_evidence'] = [list(split_sentence_id(sent)) for sent in sent_ids[:args.evidence_k]] # keep top evidence_k sentences
+
+        return best, pred
+
+    with open(args.id_file, 'r', encoding='utf-8') as f_id, open(args.scores_file, 'r', encoding='utf-8') as f_scores, \
+            open(args.output_predictions_file, 'w', encoding='utf-8') as f_out:
+        print('Reading scores file...')
+        curr_query = None
+        curr_sent_ids = []
+        curr_scores = []
+        for id_line, scores_line in zip(f_id, f_scores):
+            if args.has_labels:
+                query_id, sent_ids, label = id_line.strip().split('\t')
+            else:
+                query_id, sent_ids = id_line.strip().split('\t')
+            query_id = int(query_id)
+            _, false_score, nei_score, true_score = scores_line.strip().split('\t')
+
+            if query_id != curr_query:
+                if curr_query is not None:
+                    best, pred = aggregate(curr_query, curr_scores, curr_sent_ids)
+                    json.dump(pred, f_out)
+                    f_out.write('\n')
+                    preds[curr_query] = best
+                curr_query = query_id
+                curr_sent_ids.clear()
+                curr_scores.clear()
+
+            curr_sent_ids = sent_ids.split(' ')
+            curr_scores.append((float(false_score), float(nei_score), float(true_score)))
+
+        best, pred = aggregate(curr_query, curr_scores, curr_sent_ids)
+        json.dump(pred, f_out)
+        f_out.write('\n')
+        preds[curr_query] = best
+
+    # print label prediction metrics if dataset file provided
+    if args.dataset_file:
+        actual_labels = []
+        pred_labels = []
+        with open(args.dataset_file, 'r', encoding='utf-8') as f:
+            print('Reading FEVER dataset file...')
+            for line in f:
+                line_json = json.loads(line.strip())
+
+                label = line_json['label']
+                if label == 'SUPPORTS':
+                    actual_labels.append(2)
+                elif label == 'REFUTES':
+                    actual_labels.append(0)
+                else: # label == 'NOT ENOUGH INFO'
+                    actual_labels.append(1)
+
+                query_id = line_json['id']
+                pred_labels.append(preds[query_id])
+
+        print('****************************************')
+        print(f'Number of Queries: {len(actual_labels)}')
+        print(f'Label Accuracy: {accuracy_score(actual_labels, pred_labels)}')
+        print(f'Label F1: {f1_score(actual_labels, pred_labels, average=None)}')
+        print(confusion_matrix(actual_labels, pred_labels))
+        print('****************************************')
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Predicts labels and evidence sentences for FEVER submission.')
+    parser.add_argument('--id_file', required=True, help='Input query-doc pair ids file.')
+    parser.add_argument('--scores_file',
+                        required=True,
+                        help='Prediction scores file outputted by T5 label prediction model.')
+    parser.add_argument('--dataset_file', help='FEVER dataset file (only if labelled).')
+    parser.add_argument('--output_predictions_file',
+                        required=True,
+                        help='Output predictions file in FEVER submission format.')
+    parser.add_argument('--evidence_k',
+                        type=int,
+                        default=5,
+                        help='Number of top sentences to use as evidence for FEVER submission.')
+    parser.add_argument('--has_labels', action='store_true', help='Whether the id file is labelled.')
+    args = parser.parse_args()
+
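+    # aggregate per-claim T5 label scores and write one JSON object per line in FEVER submission format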
predict(args) + + print('Done!') diff --git a/experiments/list5/ukp-athene/convert_to_run.py b/experiments/list5/ukp-athene/convert_to_run.py new file mode 100644 index 00000000..21586d96 --- /dev/null +++ b/experiments/list5/ukp-athene/convert_to_run.py @@ -0,0 +1,17 @@ +import argparse +import json + +parser = argparse.ArgumentParser() +parser.add_argument('--dataset_file', required=True, help='UKP-Athene doc retrieval output file.') +parser.add_argument('--output_run_file', required=True, help='Output run file.') +args = parser.parse_args() + +with open(args.dataset_file, 'r', encoding='utf-8') as f_in, open(args.output_run_file, 'w', encoding='utf-8') as f_out: + for line in f_in: + i = 0 + line_json = json.loads(line.strip()) + qid = line_json['id'] + + for did in line_json['predicted_pages']: + i += 1 + f_out.write(f'{qid}\t{did}\t{i}\n') diff --git a/experiments/list5/ukp-athene/doc_retrieval.py b/experiments/list5/ukp-athene/doc_retrieval.py new file mode 100644 index 00000000..c89b9bab --- /dev/null +++ b/experiments/list5/ukp-athene/doc_retrieval.py @@ -0,0 +1,210 @@ +import argparse +import json +import os +import re +import time +from multiprocessing.pool import ThreadPool +import nltk +import wikipedia +from allennlp.predictors import Predictor +from tqdm import tqdm +from unicodedata import normalize + +def processed_line(method, line): + nps, wiki_results, pages = method.exact_match(line) + line['noun_phrases'] = nps + line['predicted_pages'] = pages + line['wiki_results'] = wiki_results + return line + +def process_line_with_progress(method, line, progress=None): + if progress is not None and line['id'] in progress: + return progress[line['id']] + else: + return processed_line(method, line) + +class Doc_Retrieval: + def __init__(self, database_path, add_claim=False, k_wiki_results=None): + self.add_claim = add_claim + self.k_wiki_results = k_wiki_results + self.proter_stemm = nltk.PorterStemmer() + self.tokenizer = nltk.word_tokenize + self.predictor = Predictor.from_path( + "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo-constituency-parser-2018.03.14.tar.gz") + + self.db = {} + for file in os.listdir(database_path): + with open(os.path.join(database_path, file), 'r', encoding='utf-8') as f: + for line in f: + line_json = json.loads(line.strip()) + self.db[line_json['id']] = line_json['lines'] + + def get_NP(self, tree, nps): + if isinstance(tree, dict): + if "children" not in tree: + if tree['nodeType'] == "NP": + nps.append(tree['word']) + elif "children" in tree: + if tree['nodeType'] == "NP": + nps.append(tree['word']) + self.get_NP(tree['children'], nps) + else: + self.get_NP(tree['children'], nps) + elif isinstance(tree, list): + for sub_tree in tree: + self.get_NP(sub_tree, nps) + + return nps + + def get_subjects(self, tree): + subject_words = [] + subjects = [] + for subtree in tree['children']: + if subtree['nodeType'] == "VP" or subtree['nodeType'] == 'S' or subtree['nodeType'] == 'VBZ': + subjects.append(' '.join(subject_words)) + subject_words.append(subtree['word']) + else: + subject_words.append(subtree['word']) + return subjects + + def get_noun_phrases(self, line): + claim = line['claim'] + tokens = self.predictor.predict(claim) + nps = [] + tree = tokens['hierplane_tree']['root'] + noun_phrases = self.get_NP(tree, nps) + subjects = self.get_subjects(tree) + for subject in subjects: + if len(subject) > 0: + noun_phrases.append(subject) + if self.add_claim: + noun_phrases.append(claim) + return list(set(noun_phrases)) + + def 
get_doc_for_claim(self, noun_phrases): + predicted_pages = [] + for np in noun_phrases: + if len(np) > 300: + continue + i = 1 + while i < 12: + try: + docs = wikipedia.search(np) + if self.k_wiki_results is not None: + predicted_pages.extend(docs[:self.k_wiki_results]) + else: + predicted_pages.extend(docs) + except (ConnectionResetError, ConnectionError, ConnectionAbortedError, ConnectionRefusedError): + print("Connection reset error received! Trial #" + str(i)) + time.sleep(600 * i) + i += 1 + else: + break + + predicted_pages = set(predicted_pages) + processed_pages = [] + for page in predicted_pages: + page = page.replace(" ", "_") + page = page.replace("(", "-LRB-") + page = page.replace(")", "-RRB-") + page = page.replace(":", "-COLON-") + processed_pages.append(page) + + return processed_pages + + def np_conc(self, noun_phrases): + noun_phrases = set(noun_phrases) + predicted_pages = [] + for np in noun_phrases: + page = np.replace('( ', '-LRB-') + page = page.replace(' )', '-RRB-') + page = page.replace(' - ', '-') + page = page.replace(' :', '-COLON-') + page = page.replace(' ,', ',') + page = page.replace(" 's", "'s") + page = page.replace(' ', '_') + + if len(page) < 1: + continue + doc_lines = self.db.get(normalize("NFD", page)) + if doc_lines is not None: + predicted_pages.append(page) + return predicted_pages + + def exact_match(self, line): + noun_phrases = self.get_noun_phrases(line) + wiki_results = self.get_doc_for_claim(noun_phrases) + wiki_results = list(set(wiki_results)) + + claim = normalize("NFD", line['claim']) + claim = claim.replace(".", "") + claim = claim.replace("-", " ") + words = [self.proter_stemm.stem(word.lower()) for word in self.tokenizer(claim)] + words = set(words) + predicted_pages = self.np_conc(noun_phrases) + + for page in wiki_results: + page = normalize("NFD", page) + processed_page = re.sub("-LRB-.*?-RRB-", "", page) + processed_page = re.sub("_", " ", processed_page) + processed_page = re.sub("-COLON-", ":", processed_page) + processed_page = processed_page.replace("-", " ") + processed_page = processed_page.replace("–", " ") + processed_page = processed_page.replace(".", "") + page_words = [ + self.proter_stemm.stem(word.lower()) for word in self.tokenizer(processed_page) if len(word) > 0 + ] + + if all([item in words for item in page_words]): + if ':' in page: + page = page.replace(":", "-COLON-") + predicted_pages.append(page) + predicted_pages = list(set(predicted_pages)) + return noun_phrases, wiki_results, predicted_pages + +def get_map_function(parallel, p=None): + assert not parallel or p is not None, "A ThreadPool object should be given if parallel is True" + return p.imap_unordered if parallel else map + +def main(db_file, k_wiki, in_file, out_file, add_claim=True, parallel=True): + method = Doc_Retrieval(database_path=db_file, add_claim=add_claim, k_wiki_results=k_wiki) + processed = dict() + path = os.getcwd() + lines = [] + with open(os.path.join(path, in_file), "r", encoding="utf-8") as f: + for line in f.readlines(): + lines.append(json.loads(line.strip())) + if os.path.isfile(os.path.join(path, in_file + ".progress")): + with open(os.path.join(path, in_file + ".progress"), 'rb') as f_progress: + import pickle + progress = pickle.load(f_progress) + print(os.path.join(path, in_file + ".progress") + " exists. 
Load it as progress file.") + else: + progress = dict() + + try: + with ThreadPool(processes=4 if parallel else None) as p: + for line in tqdm(get_map_function(parallel, p)(lambda l: process_line_with_progress(method, l, progress), + lines), + total=len(lines)): + processed[line['id']] = line + progress[line['id']] = line + with open(os.path.join(path, out_file), "w+") as f2: + for line in lines: + f2.write(json.dumps(processed[line['id']]) + "\n") + finally: + with open(os.path.join(path, in_file + ".progress"), 'wb') as f_progress: + import pickle + pickle.dump(progress, f_progress, pickle.HIGHEST_PROTOCOL) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--db-file', type=str, help="database file which saves pages") + parser.add_argument('--in-file', type=str, help="input dataset") + parser.add_argument('--out-file', type=str, help="path to save output dataset") + parser.add_argument('--k-wiki', type=int, help="first k pages for wiki search") + parser.add_argument('--parallel', type=bool, default=True) + parser.add_argument('--add-claim', type=bool, default=True) + args = parser.parse_args() + + main(args.db_file, args.k_wiki, args.in_file, args.out_file, args.add_claim, args.parallel) diff --git a/experiments/list5/ukp-athene/requirements.txt b/experiments/list5/ukp-athene/requirements.txt new file mode 100644 index 00000000..18f76458 --- /dev/null +++ b/experiments/list5/ukp-athene/requirements.txt @@ -0,0 +1,166 @@ +absl-py==0.2.2 +aiofiles==0.3.2 +alabaster==0.7.11 +allennlp==0.5.1 +astor==0.6.2 +astroid==1.6.4 +atomicwrites==1.1.5 +attrs==18.1.0 +autopep8==1.3.5 +awscli==1.15.31 +Babel==2.6.0 +backcall==0.1.0 +beautifulsoup4==4.6.0 +bleach==1.5.0 +boto==2.48.0 +boto3==1.7.31 +botocore==1.10.31 +bz2file==0.98 +cffi==1.11.2 +chardet==3.0.4 +click==6.7 +colorama==0.3.9 +conllu==0.10.6 +cookies==2.2.1 +cycler==0.10.0 +cymem==1.31.2 +cysignals==1.7.2 +Cython==0.28.5 +cytoolz==0.8.2 +decorator==4.3.0 +dill==0.2.7.1 +docutils==0.14 +editdistance==0.4 +entrypoints==0.2.3 +flaky==3.4.0 +Flask==0.12.1 +Flask-Cors==3.0.3 +future==0.16.0 +gast==0.2.0 +gensim==3.4.0 +gevent==1.2.2 +greenlet==0.4.14 +grpcio==1.12.0 +h5py==2.7.1 +html5lib==0.9999999 +httptools==0.0.11 +idna==2.6 +imagesize==1.0.0 +ipykernel==4.8.2 +ipython==6.4.0 +ipython-genutils==0.2.0 +ipywidgets==7.2.1 +isort==4.3.4 +itsdangerous==0.24 +jedi==0.12.0 +Jinja2==2.10 +jmespath==0.9.3 +jsonschema==2.6.0 +jupyter==1.0.0 +jupyter-client==5.2.3 +jupyter-console==5.2.0 +jupyter-core==4.4.0 +Keras==2.2.0 +Keras-Applications==1.0.2 +Keras-Preprocessing==1.0.1 +kiwisolver==1.0.1 +lazy-object-proxy==1.3.1 +Markdown==2.6.11 +# MarkupSafe==1.0 +matplotlib==3.0.0 +mccabe==0.6.1 +mistune==0.8.3 +more-itertools==4.2.0 +msgpack-numpy==0.4.1 +msgpack-python==0.5.6 +munkres==1.0.12 +murmurhash==0.28.0 +nbconvert==5.3.1 +nbformat==4.4.0 +networkx==2.1 +nltk==3.3 +notebook==5.5.0 +numpy==1.14.5 +numpydoc==0.8.0 +olefile==0.45.1 +overrides==1.9 +packaging==17.1 +pandas==0.23.2 +pandocfilters==1.4.2 +parameters==0.2.1 +parso==0.2.1 +pathlib==1.0.1 +pexpect==4.6.0 +pickleshare==0.7.4 +Pillow==5.1.0 +plac==0.9.6 +pluggy==0.6.0 +preshed==1.0.0 +prettytable==0.7.2 +prompt-toolkit==1.0.15 +protobuf==3.5.2.post1 +psycopg2==2.7.4 +ptyprocess==0.5.2 +py==1.5.4 +pyasn1==0.4.3 +pycodestyle==2.4.0 +pycorenlp==0.3.0 +pycparser==2.18 +pyfasttext==0.4.5 +Pygments==2.2.0 +pyhocon==0.3.35 +pylint==1.9.1 +pyparsing==2.2.0 +pytest==3.6.3 +python-dateutil==2.7.3 +pytz==2017.3 +PyYAML==3.12 +pyzmq==17.0.0 +qtconsole==4.3.1 
+regex==2017.4.5
+requests==2.18.4
+responses==0.9.0
+retinasdk==1.0.0
+rope==0.10.7
+rsa==3.4.2
+s3transfer==0.1.13
+sanic==0.6.0
+Sanic-Cors==0.9.4
+Sanic-Plugins-Framework==0.5.2.dev20180201
+scikit-learn==0.19.1
+scipy==1.1.0
+seaborn==0.9.0
+Send2Trash==1.5.0
+simplegeneric==0.8.1
+six==1.11.0
+smart-open==1.5.7
+snowballstemmer==1.2.1
+spacy==2.0.11
+Sphinx==1.7.6
+sphinxcontrib-websupport==1.1.0
+tensorboard==1.8.0
+tensorboard-pytorch==0.7.1
+tensorboardX==1.2
+tensorflow-gpu==1.8.0
+tensorflow-hub==0.1.0
+termcolor==1.1.0
+terminado==0.8.1
+testpath==0.3.1
+thinc==6.10.2
+toolz==0.9.0
+torch==0.4.0
+torchvision==0.2.1
+tornado==5.0.2
+tqdm==4.23.4
+traitlets==4.3.2
+typing==3.6.4
+ujson==1.35
+Unidecode==1.0.22
+urllib3==1.22
+uvloop==0.10.1
+wcwidth==0.1.7
+websockets==5.0.1
+Werkzeug==0.14.1
+widgetsnbextension==3.2.1
+wikipedia==1.4.0
+wrapt==1.10.11
\ No newline at end of file