3 changes: 3 additions & 0 deletions Makefile
@@ -14,6 +14,9 @@ PYTEST_ARGS ?= ""
ifneq ($(findstring deepsparse,$(TARGETS)),deepsparse)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/deepsparse
endif
+ifneq ($(findstring transformers,$(TARGETS)),transformers)
+PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/transformers
+endif
ifneq ($(findstring keras,$(TARGETS)),keras)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/keras
endif
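For readers less familiar with the Makefile idiom above: each ifneq/findstring guard appends a pytest --ignore flag unless the matching integration name appears in TARGETS, so the new block simply extends the existing deepsparse/keras pattern to transformers. A minimal Python sketch of the same selection logic, for illustration only (the targets value here is hypothetical):

# illustration of the Makefile guard logic above, not part of the diff
targets = "deepsparse,keras"  # hypothetical TARGETS value without "transformers"

pytest_args = []
for integration in ("deepsparse", "transformers", "keras"):
    # substring check, mirroring Make's findstring semantics
    if integration not in targets:
        # mirrors: PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/<name>
        pytest_args.append(f"--ignore=tests/sparseml/{integration}")

print(pytest_args)  # ['--ignore=tests/sparseml/transformers']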
6 changes: 3 additions & 3 deletions src/sparseml/transformers/export.py
@@ -71,13 +71,13 @@
from sparseml.transformers.utils import SparseAutoModel


__all__ = ["export_transformer_to_onnx"]
__all__ = ["export_transformer_to_onnx", "load_task_model"]


_LOGGER = logging.getLogger(__name__)


-def _load_task_model(task: str, model_path: str, config: Any) -> Module:
+def load_task_model(task: str, model_path: str, config: Any) -> Module:
if task == "masked-language-modeling" or task == "mlm":
return SparseAutoModel.masked_language_modeling_from_pretrained(
model_name_or_path=model_path,
@@ -156,7 +156,7 @@ def export_transformer_to_onnx(
tokenizer = AutoTokenizer.from_pretrained(
model_path, model_max_length=sequence_length
)
-    model = _load_task_model(task, model_path, config)
+    model = load_task_model(task, model_path, config)
_LOGGER.info(f"loaded model, config, and tokenizer from {model_path}")

trainer = Trainer(
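With the leading underscore dropped and the name added to __all__, load_task_model becomes part of the public sparseml.transformers API, so callers (including the new tests below) can load a task-specific model directly. A minimal usage sketch, assuming a local HuggingFace-style checkpoint directory (the path below is hypothetical):

from transformers import AutoConfig

from sparseml.transformers import export_transformer_to_onnx, load_task_model

model_path = "./bert-qa-checkpoint"  # hypothetical checkpoint directory

config = AutoConfig.from_pretrained(model_path)
model = load_task_model(
    task="question-answering", model_path=model_path, config=config
)

# the same checkpoint can then be exported to ONNX
onnx_file = export_transformer_to_onnx(
    task="question-answering",
    model_path=model_path,
    onnx_file_name="model.onnx",
)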
15 changes: 15 additions & 0 deletions tests/sparseml/transformers/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sparseml.transformers as _transformers # noqa: F401
167 changes: 167 additions & 0 deletions tests/sparseml/transformers/test_transformers.py
@@ -0,0 +1,167 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import math
import os
import shutil
from collections import Counter, OrderedDict

import onnx
import onnxruntime as ort
import pytest
from transformers import AutoConfig

from sparseml.transformers import export_transformer_to_onnx, load_task_model
from sparseml.transformers.sparsification import Trainer
from sparsezoo import Zoo
from sparsezoo.utils import load_numpy_list


def _is_yaml_recipe_present(model_path):
    return any(
        file.endswith(".yaml") or "recipe" in file
        for file in glob.glob(os.path.join(model_path, "*"))
    )


def _run_inference_onnx(path_onnx, input_data):
    ort_sess = ort.InferenceSession(path_onnx)
    model = onnx.load(path_onnx)
    input_names = [inp.name for inp in model.graph.input]

    # add a batch dimension and pair each input array with its graph input name
    model_input = OrderedDict(
        [(k, v.reshape(1, -1)) for k, v in zip(input_names, input_data.values())]
    )

    return ort_sess.run(None, model_input)


def _compare_onnx_models(model_1, model_2):
    # compute-heavy (and quantization-related) node types whose counts
    # should match between the zoo model and the fresh export
    major_nodes = [
        "QLinearMatMul",
        "Gemm",
        "MatMul",
        "MatMulInteger",
        "Conv",
        "QLinearConv",
        "ConvInteger",
        "QuantizeLinear",
        "DequantizeLinear",
    ]

    # node names follow the "<OpType>_<index>" convention, so the prefix
    # before the first underscore identifies the node type
    nodes1_count = Counter(node.name.split("_")[0] for node in model_1.graph.node)
    nodes2_count = Counter(node.name.split("_")[0] for node in model_2.graph.node)

    for node in major_nodes:
        assert nodes1_count[node] == nodes2_count[node]


@pytest.mark.parametrize(
"model_stub, recipe_present, task",
[
(
"zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned-conservative", # noqa: E501
False,
"question-answering",
)
],
scope="function",
)
class TestModelFromZoo:
@pytest.fixture()
def setup(self, model_stub, recipe_present, task):
# setup
self.onnx_retrieved_name = "retrieved_model.onnx"
model = Zoo.load_model_from_stub(model_stub)
model.download()

yield model, recipe_present, task

# teardown
model_path = model.framework_files[0].dir_path
shutil.rmtree(os.path.dirname(model_path))

def test_load_weights_apply_recipe(self, setup):
model, recipe_present, task = setup
model_path = model.framework_files[0].dir_path

config = AutoConfig.from_pretrained(model_path)
model = load_task_model(task, model_path, config)

assert model
assert recipe_present == _is_yaml_recipe_present(model_path)
        if recipe_present:
trainer = Trainer(
model=model,
model_state_path=model_path,
recipe=None,
recipe_args=None,
teacher=None,
)
applied = trainer.apply_manager(epoch=math.inf, checkpoint=None)

assert applied

def test_export_to_onnx(self, setup):
model, recipe_present, task = setup
path_onnx = model.onnx_file.downloaded_path()
model_path = model.framework_files[0].dir_path

path_retrieved_onnx = export_transformer_to_onnx(
task=task,
model_path=model_path,
onnx_file_name=self.onnx_retrieved_name,
)

zoo_model = onnx.load(path_onnx)
export_model = onnx.load(os.path.join(model_path, path_retrieved_onnx))

assert export_model

onnx.checker.check_model(export_model)
_compare_onnx_models(zoo_model, export_model)

    def test_outputs_ort(self, setup):
        model, recipe_present, task = setup
path_onnx = model.onnx_file.downloaded_path()
model_path = model.framework_files[0].dir_path
inputs_path = model.data_inputs.path

input_data = load_numpy_list(inputs_path)[0]

path_retrieved_onnx = export_transformer_to_onnx(
task=task,
model_path=model_path,
onnx_file_name=self.onnx_retrieved_name,
)

out1 = _run_inference_onnx(path_onnx, input_data)
out2 = _run_inference_onnx(path_retrieved_onnx, input_data)
        for o1, o2 in zip(out1, out2):
            # the approx comparison must be asserted, or mismatches pass silently
            assert pytest.approx(o1, abs=1e-5) == o2
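A note on that final check: pytest.approx only has an effect when the comparison is asserted; a bare pytest.approx(o1, abs=1e-5) == o2 expression evaluates to a boolean that is silently discarded, so the loop above asserts it. A minimal sketch of the pattern with numpy arrays (the values are hypothetical and chosen to sit within the tolerance):

import numpy as np
import pytest

reference = np.array([0.12345, 0.67890])
exported = np.array([0.123451, 0.678904])  # within 1e-5 of reference

# approx wraps one side of the comparison and supports numpy arrays;
# the result must be asserted for a mismatch to fail the test
assert exported == pytest.approx(reference, abs=1e-5)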