Make the data in predict.py a command-line argument rather than a field in the config. Also make the data loading lazy.
kentonl committed Jul 15, 2018
1 parent b733669 commit c7ce6c2
Showing 2 changed files with 19 additions and 15 deletions.
3 changes: 1 addition & 2 deletions README.md
```diff
@@ -44,8 +44,7 @@ This repository contains the code for replicating results from
 * `clusters` should be left empty and is only used for evaluation purposes.
 * `doc_key` indicates the genre, which can be one of the following: `"bc", "bn", "mz", "nw", "pt", "tc", "wb"`
 * `speakers` indicates the speaker of each word. These can be all empty strings if there is only one known speaker.
-* Change the value of `eval_path` in the configuration file to the path to this new file.
-* Run `python predict.py <experiment> <output_file>`, which outputs the original file extended with annotations of the predicted clusters.
+* Run `python predict.py <experiment> <input_file> <output_file>`, which outputs the input jsonlines with predicted clusters.

 ## Other Quirks

```
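For concreteness, a record in the `.jsonlines` input format described in the README hunk above could be assembled as in the sketch below. The `sentences` field (tokenized text) and all concrete values are illustrative assumptions, not taken from this diff:

```python
import json

# Illustrative input record; "clusters", "doc_key", and "speakers" follow the
# README above. "sentences" (tokenized text) is an assumed field name -- check
# the repository's data documentation for the exact schema.
example = {
    "doc_key": "nw",                       # genre identifier
    "sentences": [["John", "met", "Mary", "."]],  # assumed field, per-token
    "speakers": [["", "", "", ""]],        # empty strings: one known speaker
    "clusters": [],                        # left empty; used only for evaluation
}

# One JSON object per line, as the .jsonlines format requires.
with open("input.jsonlines", "w") as f:
    f.write(json.dumps(example) + "\n")
```

The resulting file can then be passed as `<input_file>` to the new `predict.py` invocation.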
31 changes: 18 additions & 13 deletions predict.py
```diff
@@ -12,23 +12,28 @@
 if __name__ == "__main__":
   config = util.initialize_from_env()

+  # Input file in .jsonlines format.
+  input_filename = sys.argv[2]
+
   # Predictions will be written to this file in .jsonlines format.
-  output_filename = sys.argv[2]
+  output_filename = sys.argv[3]

   model = cm.CorefModel(config)
-  model.load_eval_data()

   with tf.Session() as session:
     model.restore(session)

-    with open(output_filename, "w") as f:
-      for example_num, (tensorized_example, example) in enumerate(model.eval_data):
-        feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)}
-        _, _, _, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(model.predictions, feed_dict=feed_dict)
-        predicted_antecedents = model.get_predicted_antecedents(top_antecedents, top_antecedent_scores)
-        example["predicted_clusters"], _ = model.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents)
-
-        f.write(json.dumps(example))
-        f.write("\n")
-        if example_num % 100 == 0:
-          print("Decoded {} examples.".format(example_num + 1))
+    with open(output_filename, "w") as output_file:
+      with open(input_filename) as input_file:
+        for example_num, line in enumerate(input_file.readlines()):
+          example = json.loads(line)
+          tensorized_example = model.tensorize_example(example, is_training=False)
+          feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)}
+          _, _, _, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(model.predictions, feed_dict=feed_dict)
+          predicted_antecedents = model.get_predicted_antecedents(top_antecedents, top_antecedent_scores)
+          example["predicted_clusters"], _ = model.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents)
+
+          output_file.write(json.dumps(example))
+          output_file.write("\n")
+          if example_num % 100 == 0:
+            print("Decoded {} examples.".format(example_num + 1))
```
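One note on the "lazy" loading: `input_file.readlines()` still reads every line of the input into memory before the loop starts, even though tensorization now happens per example. Iterating the file object directly is the fully streaming idiom. A minimal sketch of that pattern, with a hypothetical `predict` callable standing in for the `session.run` and cluster-decoding steps shown in the diff:

```python
import json

def stream_jsonlines(path):
    """Yield one parsed record at a time without loading the whole file."""
    with open(path) as f:
        for line in f:  # file objects iterate lazily, one line at a time
            line = line.strip()
            if line:
                yield json.loads(line)

def annotate(input_path, output_path, predict):
    """Write each record back out, extended with predicted clusters.

    `predict` is a placeholder for the model call: it takes one example
    dict and returns its predicted clusters.
    """
    with open(output_path, "w") as output_file:
        for example_num, example in enumerate(stream_jsonlines(input_path), start=1):
            example["predicted_clusters"] = predict(example)
            output_file.write(json.dumps(example) + "\n")
            if example_num % 100 == 0:
                print("Decoded {} examples.".format(example_num))
```

Because records are parsed and written one at a time, peak memory stays constant regardless of input size, which matters for large `.jsonlines` files.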
