Fix for evaluation not using official ground truth spans; also fix predicted span tokenization issue.
obryanlouis committed Jan 6, 2018
1 parent 28c18bc commit 82feaa3
Showing 10 changed files with 126 additions and 25 deletions.
15 changes: 11 additions & 4 deletions README.md
@@ -70,17 +70,24 @@ python3 setup.py
```

### Training
The following command will start model training, creating the model parameters
or restoring them from the last checkpoint if one exists. After each epoch,
the Dev F1/EM scores are calculated, and if the F1 score is a new high score,
the model parameters are saved. There is no mechanism to automatically stop
training; it must be stopped manually.
```
python3 train_local.py --num_gpus=<NUMBER OF GPUS>
```
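The checkpoint-and-save-on-best-F1 behaviour described above can be sketched roughly as
follows. This is illustrative only, not the repository's `train_local.py`; `build_model`,
`run_one_epoch`, and `evaluate_dev` are hypothetical helpers.
```
# Illustrative sketch only (not the repository's train_local.py).
# build_model, run_one_epoch, and evaluate_dev are hypothetical helpers.
import tensorflow as tf

def train(checkpoint_dir, num_epochs=100):
    model = build_model()                       # hypothetical: builds the graph
    saver = tf.train.Saver()
    best_f1 = 0.0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt is not None:                    # restore from the last checkpoint if one exists
            saver.restore(sess, ckpt)
        for epoch in range(num_epochs):         # no automatic stopping; interrupt manually
            run_one_epoch(sess, model)          # hypothetical: one pass over the training data
            f1, em = evaluate_dev(sess, model)  # hypothetical: Dev F1 / EM after each epoch
            if f1 > best_f1:                    # save only when F1 sets a new high score
                best_f1 = f1
                saver.save(sess, checkpoint_dir + "/model.ckpt")
```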

### Evaluation
The following command will evaluate the model
on the Dev dataset and print out the exact match (em) and f1 scores.
on the Dev dataset and print out the exact match and f1 scores.
To make it easier to use SQuAD-compatible model outputs, the predicted
answer string for each question will be written to a file called
`predictions.json` in the `evaluation_dir`.
In addition, if the `visualize_evaluated_results` flag is `true`, then
the passages, questions, ground truth spans, and spans predicted by the
model will be written to output files specified in the `evaluation_dir`
flag.
the passages, questions, and ground truth spans will be written to output
files specified in the `evaluation_dir` flag.
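`predictions.json` follows the standard SQuAD prediction format: a single JSON object
mapping each official SQuAD question id to the predicted answer string. A minimal
illustration (the ids and answers below are made up):
```
import json

# Hypothetical ids and answers; the real keys are the "id" fields from the
# SQuAD Dev set and the values are the model's predicted answer strings.
predictions = {
    "56be4db0acb8001400a502ec": "Denver Broncos",
    "56d20362e7d4791d009025e8": "gold",
}
with open("predictions.json", "w") as predictions_file:
    json.dump(predictions, predictions_file)
```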

```
python3 evaluate_local.py --num_gpus=<NUMBER OF GPUS>
25 changes: 18 additions & 7 deletions datasets/squad_data.py
@@ -72,6 +72,10 @@ def __init__(self, files_dir, options, vocab):
constants.QUESTION_NER_FILE_PATTERN)
self.text_tokens_files = get_data_files_list(files_dir,
constants.TEXT_TOKENS_FILE_PATTERN)
self.question_ids_to_squad_id_files = get_data_files_list(files_dir,
constants.QUESTION_IDS_TO_SQUAD_QUESTION_ID_FILE_PATTERN)
self.question_ids_to_passage_context_files = get_data_files_list(files_dir,
constants.QUESTION_IDS_TO_PASSAGE_CONTEXT_FILE_PATTERN)

assert len(self.context_files) > 0
assert len(self.context_files) == len(self.question_files)
@@ -85,6 +89,8 @@ def __init__(self, files_dir, options, vocab):
assert len(self.context_files) == len(self.question_pos_files)
assert len(self.context_files) == len(self.question_ner_files)
assert len(self.context_files) == len(self.text_tokens_files)
assert len(self.context_files) == len(self.question_ids_to_squad_id_files)
assert len(self.context_files) == len(self.question_ids_to_passage_context_files)

self.zip_ds = tf.contrib.data.Dataset.zip({
_CONTEXT_KEY: self._create_ds(self.context_placeholder),
@@ -104,15 +110,16 @@ def __init__(self, files_dir, options, vocab):
self.num_samples_in_current_files = 0

def get_sentences_for_all_gnd_truths(self, question_id):
gnd_truths_list = self.question_ids_to_ground_truths[question_id]
sentences = []
for start_idx, end_idx in gnd_truths_list:
sentences.append(self.get_sentence(question_id, start_idx, end_idx))
return sentences
passage_context = self.question_ids_to_passage_context[question_id]
return passage_context.acceptable_gnd_truths

def get_sentence(self, example_idx, start_idx, end_idx):
list_text_tokens = self.text_tokens_dict[example_idx]
return " ".join(list_text_tokens[start_idx: end_idx + 1])
# A 'PassageContext' is defined in preprocessing/create_train_data.py
passage_context = self.question_ids_to_passage_context[example_idx]
max_word_id = max(passage_context.word_id_to_text_positions.keys())
text_start_idx = passage_context.word_id_to_text_positions[min(start_idx, max_word_id)].start_idx
text_end_idx = passage_context.word_id_to_text_positions[min(end_idx, max_word_id)].end_idx
return passage_context.passage_str[text_start_idx:text_end_idx]

def _create_ds(self, placeholder):
return tf.contrib.data.Dataset.from_tensor_slices(placeholder) \
@@ -158,6 +165,10 @@ def load_next_file(self, increment_file_number):
self.qst_ner = self._load_2d_np_arr_with_possible_padding(self.question_ner_files[self.current_file_number], max_qst_len, pad_value=0)
self.text_tokens_dict = load_text_file(
self.text_tokens_files[self.current_file_number])
self.question_ids_to_squad_ids = load_text_file(
self.question_ids_to_squad_id_files[self.current_file_number])
self.question_ids_to_passage_context = load_text_file(
self.question_ids_to_passage_context_files[self.current_file_number])

if increment_file_number:
self.current_file_number = (self.current_file_number + 1) % len(self.context_files)
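For readers following the rewritten `get_sentence` above, here is a minimal, self-contained
sketch (not repository code; the passage and offsets are made up) of the same idea: token
ids are mapped back to character offsets so the answer is sliced out of the original passage
string instead of being re-joined from tokens.
```
# Minimal sketch of the character-offset lookup that the new get_sentence
# performs via PassageContext.word_id_to_text_positions.
class TextPosition:
    def __init__(self, start_idx, end_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx

passage = "The quick brown fox jumped over the lazy dog."
# word id -> character positions of that token in the original passage string
word_id_to_text_positions = {
    0: TextPosition(0, 3),    # "The"
    1: TextPosition(4, 9),    # "quick"
    2: TextPosition(10, 15),  # "brown"
    3: TextPosition(16, 19),  # "fox"
}

def span_text(start_word_id, end_word_id):
    # Clamp out-of-range word ids, then slice the original string from the
    # first token's start offset to the last token's end offset.
    max_word_id = max(word_id_to_text_positions.keys())
    start = word_id_to_text_positions[min(start_word_id, max_word_id)].start_idx
    end = word_id_to_text_positions[min(end_word_id, max_word_id)].end_idx
    return passage[start:end]

print(span_text(1, 3))  # prints "quick brown fox", with the original spacing
```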
2 changes: 2 additions & 0 deletions datasets/test_data.py
@@ -35,6 +35,8 @@
class _TestDataset:
def __init__(self, test_data):
self.test_data = test_data
self.question_ids_to_squad_ids = None
self.question_ids_to_passage_context = None

def get_sentences_for_all_gnd_truths(self, ctx_id):
sentences = []
2 changes: 2 additions & 0 deletions preprocessing/constants.py
@@ -29,6 +29,8 @@
QUESTION_POS_FILE_PATTERN = "question.pos.%d.npy"
QUESTION_NER_FILE_PATTERN = "question.ner.%d.npy"
TEXT_TOKENS_FILE_PATTERN = "text_tokens.%d"
QUESTION_IDS_TO_SQUAD_QUESTION_ID_FILE_PATTERN = "question_ids_to_squad_question_id.%d"
QUESTION_IDS_TO_PASSAGE_CONTEXT_FILE_PATTERN = "passage_context.%d"

VECTORS_URL = "http://nlp.stanford.edu/data/glove.840B.300d.zip"
WORD_VEC_DIM = 300
40 changes: 37 additions & 3 deletions preprocessing/create_train_data.py
@@ -19,6 +19,22 @@
# tries to make sure that at least one of the "qa" options in the acceptable
# answers list is accurate and includes it in the data set.

class TextPosition:
def __init__(self, start_idx, end_idx):
self.start_idx = start_idx
self.end_idx = end_idx

class PassageContext:
'''Class used to save the tokenization positions in a given passage
so that the original strings can be used for constructing answer
spans rather than joining tokenized strings, which isn't 100% correct.
'''
def __init__(self, passage_str, word_id_to_text_positions,
acceptable_gnd_truths):
self.passage_str = passage_str
self.word_id_to_text_positions = word_id_to_text_positions
self.acceptable_gnd_truths = acceptable_gnd_truths

class DataParser():
def __init__(self, data_dir, download_dir):
self.data_dir = data_dir
@@ -177,6 +193,8 @@ def _create_train_data_internal(self, data_file, is_dev):
question_pos = []
context_ner = []
question_ner = []
question_ids_to_squad_question_id = {}
question_ids_to_passage_context = {}
self.value_idx = 0
for article in dataset:
for paragraph in article["paragraphs"]:
@@ -186,13 +204,27 @@
assert tok_context is not None
ctx_offset_dict = {}
ctx_end_offset_dict = {}
word_idx_to_text_position = {}
for z in range(len(tok_context)):
tok = tok_context[z]
ctx_offset_dict[tok.idx] = tok
ctx_end_offset_dict[tok.idx + len(tok.text)] = tok
st = tok.idx
end = tok.idx + len(tok.text)
ctx_offset_dict[st] = tok
ctx_end_offset_dict[end] = tok
word_idx_to_text_position[z] = TextPosition(st, end)
for qa in paragraph["qas"]:
self.question_id += 1
acceptable_gnd_truths = []
for answer in qa["answers"]:
acceptable_gnd_truths.append(answer["text"])
question_ids_to_passage_context[self.question_id] = \
PassageContext(context, word_idx_to_text_position,
acceptable_gnd_truths)
question = qa["question"]
squad_question_id = qa["id"]
assert squad_question_id is not None
question_ids_to_squad_question_id[self.question_id] = \
squad_question_id
tok_question = self.nlp(question)
qst_ner_dict = self._get_ner_dict(tok_question)
assert tok_question is not None
@@ -229,7 +261,9 @@ def _create_train_data_internal(self, data_file, is_dev):
context_pos = context_pos,
question_pos = question_pos,
context_ner = context_ner,
question_ner = question_ner)
question_ner = question_ner,
question_ids_to_squad_question_id = question_ids_to_squad_question_id,
question_ids_to_passage_context = question_ids_to_passage_context)

def _create_padded_array(self, list_of_py_arrays, max_len, pad_value):
return [py_arr + [pad_value] * (max_len - len(py_arr)) for py_arr in list_of_py_arrays]
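The offset bookkeeping added to the tokenization loop above relies on spaCy exposing each
token's character offset as `Token.idx`. A small illustrative sketch, assuming an English
spaCy model is installed (the example sentence is made up):
```
# Illustrative sketch of the offset bookkeeping in the tokenization loop above.
# spaCy exposes each token's character offset as Token.idx; this assumes an
# English spaCy model is installed ("en" here, "en_core_web_sm" in newer versions).
import spacy

nlp = spacy.load("en")
context = "Super Bowl 50 was an American football game."
tok_context = nlp(context)

word_idx_to_text_position = {}
for z, tok in enumerate(tok_context):
    st = tok.idx                   # character offset where the token starts
    end = tok.idx + len(tok.text)  # character offset just past the token
    word_idx_to_text_position[z] = (st, end)

# Every token can be recovered verbatim from the original passage string.
for z, (st, end) in word_idx_to_text_position.items():
    assert context[st:end] == tok_context[z].text
```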
14 changes: 14 additions & 0 deletions preprocessing/dataset_files_saver.py
@@ -44,12 +44,18 @@ def save(self):
print("Saving text tokens to binary pickle files")
filtered_text_tokens = {}
filtered_gnd_truths_dict = {}
filtered_squad_ids_dict = {}
filtered_passage_contexts_dict = {}
for z in range(question_ids_np_arr.shape[0]):
question_id = question_ids_np_arr[z]
filtered_gnd_truths_dict[question_id] = \
self.data.question_ids_to_ground_truths[question_id]
filtered_text_tokens[question_id] = \
self.data.text_tokens_dict[question_id]
filtered_squad_ids_dict[question_id] = \
self.data.question_ids_to_squad_question_id[question_id]
filtered_passage_contexts_dict[question_id] = \
self.data.question_ids_to_passage_context[question_id]

save_pickle_file(file_names.text_tokens_file_name,
filtered_text_tokens)
@@ -58,6 +64,14 @@ def save(self):
save_pickle_file(file_names.question_ids_to_ground_truths_file_name,
filtered_gnd_truths_dict)

print("Saving question ids to SQuAD question ids dict")
save_pickle_file(file_names.question_ids_to_squad_question_id_file_name,
filtered_squad_ids_dict)

print("Saving passage contexts to SQuAD question ids dict")
save_pickle_file(file_names.question_ids_to_passage_context_file_name,
filtered_passage_contexts_dict)

print("Saving span numpy arrays")
np.save(file_names.spn_file_name,
self.data.spans[batch_idx:next_batch_idx])
12 changes: 10 additions & 2 deletions preprocessing/dataset_files_wrapper.py
@@ -24,7 +24,9 @@ def create_new_file_names(self):
self._full_file_name(constants.CONTEXT_POS_FILE_PATTERN),
self._full_file_name(constants.QUESTION_POS_FILE_PATTERN),
self._full_file_name(constants.CONTEXT_NER_FILE_PATTERN),
self._full_file_name(constants.QUESTION_NER_FILE_PATTERN))
self._full_file_name(constants.QUESTION_NER_FILE_PATTERN),
self._full_file_name(constants.QUESTION_IDS_TO_SQUAD_QUESTION_ID_FILE_PATTERN),
self._full_file_name(constants.QUESTION_IDS_TO_PASSAGE_CONTEXT_FILE_PATTERN))
self.next_batch_number += 1
return file_names

@@ -45,7 +47,9 @@ def __init__(self,
context_pos_file_name,
question_pos_file_name,
context_ner_file_name,
question_ner_file_name):
question_ner_file_name,
question_ids_to_squad_question_id_file_name,
question_ids_to_passage_context_file_name):
self.text_tokens_file_name = text_tokens_file_name
self.qst_file_name = qst_file_name
self.ctx_file_name = ctx_file_name
Expand All @@ -59,3 +63,7 @@ def __init__(self,
self.question_pos_file_name = question_pos_file_name
self.context_ner_file_name = context_ner_file_name
self.question_ner_file_name = question_ner_file_name
self.question_ids_to_squad_question_id_file_name = \
question_ids_to_squad_question_id_file_name
self.question_ids_to_passage_context_file_name = \
question_ids_to_passage_context_file_name
6 changes: 5 additions & 1 deletion preprocessing/raw_training_data.py
@@ -15,7 +15,9 @@ def __init__(self,
context_pos,
question_pos,
context_ner,
question_ner):
question_ner,
question_ids_to_squad_question_id,
question_ids_to_passage_context):
self.list_contexts = list_contexts
self.list_word_in_question = list_word_in_question
self.list_questions = list_questions
@@ -28,3 +30,5 @@ def __init__(self,
self.question_pos = question_pos
self.context_ner = context_ner
self.question_ner = question_ner
self.question_ids_to_squad_question_id = question_ids_to_squad_question_id
self.question_ids_to_passage_context = question_ids_to_passage_context
14 changes: 11 additions & 3 deletions test/print_training_data.py → print_training_data.py
@@ -7,9 +7,14 @@
PRINT_LIMIT = 10
WORD_PRINT_LIMIT = 5

def _print_qst_or_ctx_np_arr(arr, vocab, ds, is_ctx, wiq_or_wic):
def _print_qst_or_ctx_np_arr(arr, vocab, ds, is_ctx, wiq_or_wic,
question_ids=None, question_ids_to_squad_ids=None):
for z in range(PRINT_LIMIT):
l = []
if question_ids_to_squad_ids is not None:
question_id = question_ids[z]
squad_id = question_ids_to_squad_ids[question_id]
l.append("[SQUAD ID: " + squad_id + "]")
for zz in range(arr.shape[1]):
i = arr[z, zz]
if i == vocab.PAD_ID:
@@ -29,9 +34,12 @@ def _print_gnd_truths(ds, vocab):

def _print_ds(vocab, ds):
print("Context")
_print_qst_or_ctx_np_arr(ds.ctx, vocab, ds, is_ctx=True, wiq_or_wic=ds.wiq)
_print_qst_or_ctx_np_arr(ds.ctx, vocab, ds, is_ctx=True, wiq_or_wic=ds.wiq,
question_ids=ds.qid)
print("Questions")
_print_qst_or_ctx_np_arr(ds.qst, vocab, ds, is_ctx=False, wiq_or_wic=ds.wic)
_print_qst_or_ctx_np_arr(ds.qst, vocab, ds, is_ctx=False, wiq_or_wic=ds.wic,
question_ids=ds.qid,
question_ids_to_squad_ids=ds.question_ids_to_squad_ids)
print("Spans")
print(ds.spn[:PRINT_LIMIT])
print("Ground truths")
21 changes: 16 additions & 5 deletions train/evaluation_util.py
@@ -1,6 +1,7 @@
"""Utility methods for evaluating a model.
"""

import json
import math
import os
import time
@@ -90,6 +91,7 @@ def _eval(session, towers, squad_dataset, options, is_train, sample_limit):
estimated_total_dev_samples = squad_dataset.estimate_total_dev_ds_size()
total_samples_processed = 0
start_time = time.time()
squad_prediction_format = {} # key=squad question id, value=prediction (string)
while True:
if total_samples_processed >= estimated_total_dev_samples \
and num_dev_files == 1:
@@ -126,15 +128,20 @@ def _eval(session, towers, squad_dataset, options, is_train, sample_limit):
start, end = get_best_start_and_end(start_span_probs[zz],
end_span_probs[zz], options)
example_index = data_indices[zz]
passages.append(dataset.get_sentence(example_index, 0, squad_dataset.get_max_ctx_len() - 1))
question_word_ids = qst_values[zz]
question = find_question_sentence(question_word_ids, squad_dataset.vocab)
questions.append(question)
# These need to be the original sentences from the training/dev
# sets, without any padding/unique word replacements.
text_predictions.append(dataset.get_sentence(example_index, start, end))
prediction_str = dataset.get_sentence(example_index, start, end)
if dataset.question_ids_to_squad_ids is not None:
squad_question_id = \
dataset.question_ids_to_squad_ids[example_index]
if squad_question_id in squad_prediction_format:
continue
squad_prediction_format[squad_question_id] = prediction_str
text_predictions.append(prediction_str)
acceptable_gnd_truths = dataset.get_sentences_for_all_gnd_truths(example_index)
ground_truths.append(acceptable_gnd_truths)
passages.append(dataset.get_sentence(example_index, 0, squad_dataset.get_max_ctx_len() - 1))
questions.append(question)
if not is_train:
squad_dataset.increment_val_samples_processed(batch_increment)
else:
@@ -155,6 +162,10 @@ def _eval(session, towers, squad_dataset, options, is_train, sample_limit):
num_dev_files, readable_eta(est_time_left)),
end="\r", flush=True)
print("")
if not is_train:
with open(os.path.join(options.evaluation_dir,
"predictions.json"), mode="w") as predictions_file:
json.dump(squad_prediction_format, predictions_file)
if options.verbose_logging:
print("text_predictions", utf8_str(text_predictions),
"ground_truths", utf8_str(ground_truths))
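The evaluation loop above delegates span selection to `get_best_start_and_end`. A common way
to implement that kind of selection is to maximize `start_prob * end_prob` subject to
`start <= end` plus a length cap; the sketch below illustrates that approach and is not
necessarily the repository's exact implementation.
```
import numpy as np

# Common span-selection strategy: maximize start_probs[s] * end_probs[e]
# subject to s <= e <= s + max_answer_len. The repository's
# get_best_start_and_end may differ in detail.
def best_span(start_probs, end_probs, max_answer_len=30):
    best_score, best = -1.0, (0, 0)
    for start in range(len(start_probs)):
        # Only consider end positions within the allowed answer length.
        end_limit = min(start + max_answer_len, len(end_probs) - 1)
        for end in range(start, end_limit + 1):
            score = start_probs[start] * end_probs[end]
            if score > best_score:
                best_score, best = score, (start, end)
    return best

start_probs = np.array([0.10, 0.60, 0.20, 0.10])
end_probs = np.array([0.05, 0.15, 0.70, 0.10])
print(best_span(start_probs, end_probs))  # (1, 2)
```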
