This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
5 changes: 4 additions & 1 deletion examples/pytorch/question-answering/run_qa.py
@@ -106,6 +106,9 @@ class DataTrainingArguments:
    onnx_export_path: Optional[str] = field(
        default=None, metadata={"help": "The filename and path where the ONNX model will be output."}
    )
+    num_exported_samples: Optional[int] = field(
+        default=20, metadata={"help": "Number of sample inputs/outputs to export; defaults to 20."}
+    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
@@ -647,7 +650,7 @@ def compute_metrics(p: EvalPrediction):
    if data_args.onnx_export_path:
        logger.info("*** Export to ONNX ***")
        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
-        export_model(model, eval_dataloader, data_args.onnx_export_path)
+        export_model(model, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)


def _mp_fn(index):
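For reference, HfArgumentParser derives CLI flags from the DataTrainingArguments field names, so the new field arrives as --num_exported_samples. A minimal sketch of that mapping, using a trimmed-down stand-in dataclass (ExportArguments is illustrative, not a name from this PR):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


# Trimmed-down stand-in for DataTrainingArguments, just the two export fields.
@dataclass
class ExportArguments:
    onnx_export_path: Optional[str] = field(default=None)
    num_exported_samples: Optional[int] = field(default=20)


parser = HfArgumentParser(ExportArguments)
(args,) = parser.parse_args_into_dataclasses(
    args=["--onnx_export_path", "onnx-out", "--num_exported_samples", "8"]
)
assert args.onnx_export_path == "onnx-out"
assert args.num_exported_samples == 8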
64 changes: 59 additions & 5 deletions examples/pytorch/question-answering/sparseml_utils.py
@@ -1,12 +1,20 @@
+import inspect
+import collections
import math
+import os
+from typing import Any

+import numpy
import torch
import torch.nn.functional as F

+import onnxruntime
from sparseml.pytorch.optim.manager import ScheduledModifierManager
from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
from sparseml.pytorch.utils import ModuleExporter, logger
from trainer_qa import QuestionAnsweringTrainer
+from transformers.modeling_outputs import QuestionAnsweringModelOutput
+from transformers.models.bert.modeling_bert import BertForQuestionAnswering


class SparseMLQATrainer(QuestionAnsweringTrainer):
@@ -107,15 +115,61 @@ def compute_loss(self, model, inputs, return_outputs=False):
        return (loss, outputs) if return_outputs else loss


-def export_model(model, dataloader, output_dir):
+class QuestionAnsweringModuleExporter(ModuleExporter):
+    """
+    Module exporter class for Question Answering
+    """
+
+    @classmethod
+    def get_output_names(cls, out: Any):
+        if not isinstance(out, QuestionAnsweringModelOutput):
+            raise ValueError(f"Expected QuestionAnsweringModelOutput, got {type(out)}")
+        expected = ["start_logits", "end_logits"]
+        if any(name not in out for name in expected):
+            raise ValueError("Expected output names not found in model output")
+        return expected
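A minimal sketch of why the membership test above works, assuming a transformers version where ModelOutput subclasses OrderedDict (true for the 4.x line this example targets):

import torch
from transformers.modeling_outputs import QuestionAnsweringModelOutput

# ModelOutput subclasses act as ordered mappings over the fields that were
# actually set, so `name in out` and `out.keys()` see only populated fields.
out = QuestionAnsweringModelOutput(
    start_logits=torch.zeros(1, 8), end_logits=torch.zeros(1, 8)
)
assert "start_logits" in out and "end_logits" in out
assert list(out.keys()) == ["start_logits", "end_logits"]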


+def export_model(model, dataloader, output_dir, num_exported_samples):
    """
    Export a trained model to ONNX
    :param model: trained model
    :param dataloader: dataloader to get sample batch
    :param output_dir: output directory for ONNX model
+    :param num_exported_samples: number of sample inputs/outputs to save
    """
-    exporter = ModuleExporter(model, output_dir=output_dir)
+    exporter = QuestionAnsweringModuleExporter(model, output_dir=output_dir)

+    sess = None
+    num_samples = 0

+    sample_inputs = os.path.join(output_dir, "sample-inputs")
+    sample_outputs = os.path.join(output_dir, "sample-outputs")
+    os.makedirs(sample_inputs, exist_ok=True)
+    os.makedirs(sample_outputs, exist_ok=True)

+    forward_args_spec = inspect.getfullargspec(BertForQuestionAnswering.forward)
    for _, sample_batch in enumerate(dataloader):
-        sample_input = (sample_batch["input_ids"], sample_batch["attention_mask"], sample_batch["token_type_ids"])
-        exporter.export_onnx(sample_batch=sample_input, convert_qat=True)
-        break
+        if sess is None:
+            one_sample_input = collections.OrderedDict(
+                [(f, sample_batch[f][0].reshape(1, -1)) for f in forward_args_spec.args if f in sample_batch]
+            )

+            try:
+                exporter.export_onnx(sample_batch=one_sample_input, convert_qat=True)
+                onnx_file = os.path.join(output_dir, "model.onnx")
+            except Exception as exception:
+                raise RuntimeError("Error exporting ONNX model and/or sample inputs/outputs") from exception

+            sess = onnxruntime.InferenceSession(onnx_file)

+        input_names = list(sample_batch.keys())
+        output_names = [o.name for o in sess.get_outputs()]
+        for input_vals in zip(*sample_batch.values()):
+            input_feed = {k: v.reshape(1, -1).numpy() for k, v in zip(input_names, input_vals)}
+            output_vals = sess.run(output_names, input_feed)
+            output_dict = {name: val for name, val in zip(output_names, output_vals)}
+            file_idx = f"{num_samples:04d}"
+            numpy.savez(os.path.join(sample_inputs, f"inp-{file_idx}.npz"), **input_feed)
+            numpy.savez(os.path.join(sample_outputs, f"out-{file_idx}.npz"), **output_dict)
+            num_samples += 1
+            if num_samples >= num_exported_samples:
+                return
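One way to sanity-check the exported artifacts afterwards, as a minimal standalone sketch (output_dir is a hypothetical path standing in for --onnx_export_path; it assumes the layout written above: model.onnx plus sample-inputs/ and sample-outputs/):

import os

import numpy
import onnxruntime

output_dir = "onnx-out"  # hypothetical path, stands in for --onnx_export_path
sess = onnxruntime.InferenceSession(os.path.join(output_dir, "model.onnx"))
output_names = [o.name for o in sess.get_outputs()]

# Replay the first saved sample through the ONNX model and compare against
# the outputs recorded at export time.
inp = numpy.load(os.path.join(output_dir, "sample-inputs", "inp-0000.npz"))
ref = numpy.load(os.path.join(output_dir, "sample-outputs", "out-0000.npz"))
outs = sess.run(output_names, {name: inp[name] for name in inp.files})
for name, val in zip(output_names, outs):
    numpy.testing.assert_allclose(ref[name], val, rtol=1e-4, atol=1e-5)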