From c7ce6c2894fd75d155323698f7f38ec93c4913c4 Mon Sep 17 00:00:00 2001 From: kentonl Date: Sun, 15 Jul 2018 22:42:31 +0000 Subject: [PATCH] Make the data in predict.py a command-line argument rather than a field in the config. Also make the data loading lazy. --- README.md | 3 +-- predict.py | 31 ++++++++++++++++++------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 3fa9e2be..e8c01ff8 100644 --- a/README.md +++ b/README.md @@ -44,8 +44,7 @@ This repository contains the code for replicating results from * `clusters` should be left empty and is only used for evaluation purposes. * `doc_key` indicates the genre, which can be one of the following: `"bc", "bn", "mz", "nw", "pt", "tc", "wb"` * `speakers` indicates the speaker of each word. These can be all empty strings if there is only one known speaker. -* Change the value of `eval_path` in the configuration file to the path to this new file. -* Run `python predict.py `, which outputs the original file extended with annotations of the predicted clusters. +* Run `python predict.py `, which outputs the input jsonlines with predicted clusters. ## Other Quirks diff --git a/predict.py b/predict.py index d8ed23eb..c6e43665 100755 --- a/predict.py +++ b/predict.py @@ -12,23 +12,28 @@ if __name__ == "__main__": config = util.initialize_from_env() + # Input file in .jsonlines format. + input_filename = sys.argv[2] + # Predictions will be written to this file in .jsonlines format. 
- output_filename = sys.argv[2] + output_filename = sys.argv[3] model = cm.CorefModel(config) - model.load_eval_data() with tf.Session() as session: model.restore(session) - with open(output_filename, "w") as f: - for example_num, (tensorized_example, example) in enumerate(model.eval_data): - feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)} - _, _, _, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(model.predictions, feed_dict=feed_dict) - predicted_antecedents = model.get_predicted_antecedents(top_antecedents, top_antecedent_scores) - example["predicted_clusters"], _ = model.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents) - - f.write(json.dumps(example)) - f.write("\n") - if example_num % 100 == 0: - print("Decoded {} examples.".format(example_num + 1)) + with open(output_filename, "w") as output_file: + with open(input_filename) as input_file: + for example_num, line in enumerate(input_file): + example = json.loads(line) + tensorized_example = model.tensorize_example(example, is_training=False) + feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)} + _, _, _, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(model.predictions, feed_dict=feed_dict) + predicted_antecedents = model.get_predicted_antecedents(top_antecedents, top_antecedent_scores) + example["predicted_clusters"], _ = model.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents) + - output_file.write(json.dumps(example)) + output_file.write("\n") + if example_num % 100 == 0: + print("Decoded {} examples.".format(example_num + 1))