From 03039e2646457ad5d20c0e8a4843016e1e69b18b Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Tue, 13 Jul 2021 17:09:30 -0400
Subject: [PATCH 1/2] Export ONNX models with named inputs/outputs, samples

---
 examples/pytorch/question-answering/run_qa.py |  5 +-
 .../question-answering/sparseml_utils.py      | 58 +++++++++++++++++--
 2 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 5983ce648d1e..940e2ddc21ff 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -106,6 +106,9 @@ class DataTrainingArguments:
     onnx_export_path: Optional[str] = field(
         default=None, metadata={"help": "The filename and path which will be where onnx model is outputed"}
     )
+    num_exported_samples: Optional[int] = field(
+        default=20, metadata={"help": "Number of sample inputs/outputs to export, defaults to 20"}
+    )
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
@@ -647,7 +650,7 @@ def compute_metrics(p: EvalPrediction):
     if data_args.onnx_export_path:
         logger.info("*** Export to ONNX ***")
         eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
-        export_model(model, eval_dataloader, data_args.onnx_export_path)
+        export_model(model, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)
 
 
 def _mp_fn(index):
diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
index ca30f7c61954..42b49d3ce678 100644
--- a/examples/pytorch/question-answering/sparseml_utils.py
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -1,12 +1,17 @@
 import math
+import os
+from typing import Any
 
+import numpy
 import torch
 import torch.nn.functional as F
 
+import onnxruntime
 from sparseml.pytorch.optim.manager import ScheduledModifierManager
 from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
 from sparseml.pytorch.utils import ModuleExporter, logger
 from trainer_qa import QuestionAnsweringTrainer
+from transformers.modeling_outputs import QuestionAnsweringModelOutput
 
 
 class SparseMLQATrainer(QuestionAnsweringTrainer):
@@ -107,15 +112,58 @@ def compute_loss(self, model, inputs, return_outputs=False):
         return (loss, outputs) if return_outputs else loss
 
 
-def export_model(model, dataloader, output_dir):
+class QuestionAnsweringModuleExporter(ModuleExporter):
+    """
+    Module exporter class for Question Answering
+    """
+
+    def get_output_names(self, out: Any):
+        if not isinstance(out, QuestionAnsweringModelOutput):
+            raise ValueError(f"Expected QuestionAnsweringModelOutput, got {type(out)}")
+        expected = ["start_logits", "end_logits"]
+        if any(name not in out for name in expected):
+            raise ValueError("Expected output names not found in model output")
+        return expected
+
+
+def export_model(model, dataloader, output_dir, num_exported_samples):
     """
     Export a trained model to ONNX
     :param model: trained model
     :param dataloader: dataloader to get sample batch
     :param output_dir: output directory for ONNX model
     """
-    exporter = ModuleExporter(model, output_dir=output_dir)
+    exporter = QuestionAnsweringModuleExporter(model, output_dir=output_dir)
+
+    sess = None
+    num_samples = 0
+
+    sample_inputs = os.path.join(output_dir, "sample-inputs")
+    sample_outputs = os.path.join(output_dir, "sample-outputs")
+    os.makedirs(sample_inputs, exist_ok=True)
+    os.makedirs(sample_outputs, exist_ok=True)
+
     for _, sample_batch in enumerate(dataloader):
-        sample_input = (sample_batch["input_ids"], sample_batch["attention_mask"], sample_batch["token_type_ids"])
-        exporter.export_onnx(sample_batch=sample_input, convert_qat=True)
-        break
+        if sess is None:
+            one_sample_input = {f: sample_batch[f][0].reshape(1, -1) for f in sample_batch}
+
+            try:
+                exporter.export_onnx(sample_batch=one_sample_input, convert_qat=True)
+                onnx_file = os.path.join(output_dir, "model.onnx")
+            except Exception:
+                raise RuntimeError("Error exporting ONNX model and/or sample inputs/outputs")
+
+            sess = onnxruntime.InferenceSession(onnx_file)
+
+        input_names = list(sample_batch.keys())
+        output_names = [o.name for o in sess.get_outputs()]
+        for input_vals in zip(*sample_batch.values()):
+            input_feed = {k: v.reshape(1, -1).numpy() for k, v in zip(input_names, input_vals)}
+            output_vals = sess.run(output_names, input_feed)
+            output_dict = {name: val for name, val in zip(output_names, output_vals)}
+            file_idx = f"{num_samples}".zfill(4)
+            numpy.savez(f"{sample_inputs}/inp-{file_idx}.npz", **input_feed)
+            numpy.savez(f"{sample_outputs}/out-{file_idx}.npz", **output_dict)
+            num_samples += 1
+            if num_samples >= num_exported_samples:
+                return
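The sample files written by export_model pair each input feed (inp-XXXX.npz) with the logits the model produced for it (out-XXXX.npz). Below is a minimal sketch (not part of the patch) of how they could be replayed against the exported model for verification; the directory name is a placeholder for the --onnx_export_path value used during export, and the tolerances are illustrative:

    import os

    import numpy
    import onnxruntime

    export_dir = "onnx_export"  # placeholder: the --onnx_export_path used above

    # Load the first saved sample pair written by export_model.
    with numpy.load(os.path.join(export_dir, "sample-inputs", "inp-0000.npz")) as data:
        input_feed = {name: data[name] for name in data.files}
    with numpy.load(os.path.join(export_dir, "sample-outputs", "out-0000.npz")) as data:
        expected = {name: data[name] for name in data.files}

    # Re-run the saved inputs through the exported model and compare logits.
    sess = onnxruntime.InferenceSession(os.path.join(export_dir, "model.onnx"))
    actual = sess.run(list(expected.keys()), input_feed)
    for name, value in zip(expected.keys(), actual):
        numpy.testing.assert_allclose(value, expected[name], rtol=1e-4, atol=1e-5)
    print("saved samples match the exported model's outputs")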
From 4c9435c99b5f237ed546539ef7ebb52e080b35f2 Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Wed, 14 Jul 2021 16:17:55 -0400
Subject: [PATCH 2/2] Ensure passing BERT inputs in order of forward to ONNX export

---
 examples/pytorch/question-answering/sparseml_utils.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
index 42b49d3ce678..cc505174b263 100644
--- a/examples/pytorch/question-answering/sparseml_utils.py
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -1,3 +1,5 @@
+import collections
+import inspect
 import math
 import os
 from typing import Any
@@ -12,6 +14,7 @@
 from sparseml.pytorch.utils import ModuleExporter, logger
 from trainer_qa import QuestionAnsweringTrainer
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
+from transformers.models.bert.modeling_bert import BertForQuestionAnswering
 
 
 class SparseMLQATrainer(QuestionAnsweringTrainer):
@@ -116,7 +119,7 @@ class QuestionAnsweringModuleExporter(ModuleExporter):
     """
     Module exporter class for Question Answering
     """
-
+    @classmethod
     def get_output_names(self, out: Any):
         if not isinstance(out, QuestionAnsweringModelOutput):
             raise ValueError(f"Expected QuestionAnsweringModelOutput, got {type(out)}")
@@ -143,9 +146,12 @@ def export_model(model, dataloader, output_dir, num_exported_samples):
     os.makedirs(sample_inputs, exist_ok=True)
     os.makedirs(sample_outputs, exist_ok=True)
 
+    forward_args_spec = inspect.getfullargspec(BertForQuestionAnswering.forward)
     for _, sample_batch in enumerate(dataloader):
         if sess is None:
-            one_sample_input = {f: sample_batch[f][0].reshape(1, -1) for f in sample_batch}
+            one_sample_input = collections.OrderedDict(
+                [(f, sample_batch[f][0].reshape(1, -1)) for f in forward_args_spec.args if f in sample_batch]
+            )
 
             try:
                 exporter.export_onnx(sample_batch=one_sample_input, convert_qat=True)
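For context on why the second patch reorders the inputs: the exporter traces forward with the sample batch, binding graph inputs in the order the batch supplies them, so a plain dict whose key order differs from the argument order of BertForQuestionAnswering.forward can attach the wrong names to the wrong tensors. Below is a minimal, self-contained sketch (not part of the patch) of the same inspect-based reordering, with a toy forward standing in for the real model's signature:

    import collections
    import inspect


    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        """Toy stand-in sharing the leading argument order of BERT's forward."""


    # A dataloader batch is a dict whose key order need not match forward's.
    batch = {"attention_mask": "am", "input_ids": "ii", "token_type_ids": "tt"}

    # Re-order the keys to follow the forward signature; "self" is skipped
    # because it never appears in the batch.
    spec = inspect.getfullargspec(forward)
    ordered = collections.OrderedDict((name, batch[name]) for name in spec.args if name in batch)

    print(list(ordered))  # ['input_ids', 'attention_mask', 'token_type_ids']

On Python 3.7+ a plain dict also preserves insertion order, so the OrderedDict mainly makes the ordering intent explicit.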