From f9760acf3f76646f8c216b0a1db276ed2923b030 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Thu, 22 Jul 2021 17:25:00 -0400 Subject: [PATCH] Fix input/output shapes for exported samples --- examples/pytorch/question-answering/sparseml_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py index a2fe04010004..eb32511d85f8 100644 --- a/examples/pytorch/question-answering/sparseml_utils.py +++ b/examples/pytorch/question-answering/sparseml_utils.py @@ -217,9 +217,9 @@ def export_model(model, dataloader, output_dir, num_exported_samples): input_names = list(sample_batch.keys()) output_names = [o.name for o in sess.get_outputs()] for input_vals in zip(*sample_batch.values()): - input_feed = {k: v.reshape(1, -1).numpy() for k, v in zip(input_names, input_vals)} - output_vals = sess.run(output_names, input_feed) - output_dict = {name: val for name, val in zip(output_names, output_vals)} + input_feed = {k: v.numpy() for k, v in zip(input_names, input_vals)} + output_vals = sess.run(output_names, {k: input_feed[k].reshape(1, -1) for k in input_feed}) + output_dict = {name: numpy.squeeze(val) for name, val in zip(output_names, output_vals)} file_idx = f"{num_samples}".zfill(4) numpy.savez(f"{sample_inputs}/inp-{file_idx}.npz", **input_feed) numpy.savez(f"{sample_outputs}/out-{file_idx}.npz", **output_dict)