From 53c26c16c315b2332ac624dde75cb6a87285191b Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Thu, 17 Mar 2022 12:36:37 +0100 Subject: [PATCH 1/4] Initial commit --- Makefile | 3 + src/sparseml/transformers/export.py | 6 +- tests/sparseml/transformers/__init__.py | 15 ++ .../transformers/test_transformers.py | 198 ++++++++++++++++++ 4 files changed, 219 insertions(+), 3 deletions(-) create mode 100644 tests/sparseml/transformers/__init__.py create mode 100644 tests/sparseml/transformers/test_transformers.py diff --git a/Makefile b/Makefile index 7605d6fdd7f..40b3743e7cb 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,9 @@ PYTEST_ARGS ?= "" ifneq ($(findstring deepsparse,$(TARGETS)),deepsparse) PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/deepsparse endif +ifneq ($(findstring transformers,$(TARGETS)),transformers) + PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/transformers +endif ifneq ($(findstring keras,$(TARGETS)),keras) PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/keras endif diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py index f42cd44a668..f4d82cb25b3 100644 --- a/src/sparseml/transformers/export.py +++ b/src/sparseml/transformers/export.py @@ -71,13 +71,13 @@ from sparseml.transformers.utils import SparseAutoModel -__all__ = ["export_transformer_to_onnx"] +__all__ = ["export_transformer_to_onnx", "load_task_model"] _LOGGER = logging.getLogger(__name__) -def _load_task_model(task: str, model_path: str, config: Any) -> Module: +def load_task_model(task: str, model_path: str, config: Any) -> Module: if task == "masked-language-modeling" or task == "mlm": return SparseAutoModel.masked_language_modeling_from_pretrained( model_name_or_path=model_path, @@ -156,7 +156,7 @@ def export_transformer_to_onnx( tokenizer = AutoTokenizer.from_pretrained( model_path, model_max_length=sequence_length ) - model = _load_task_model(task, model_path, config) + model = load_task_model(task, model_path, config) _LOGGER.info(f"loaded model, config, and tokenizer from {model_path}") trainer = Trainer( diff --git a/tests/sparseml/transformers/__init__.py b/tests/sparseml/transformers/__init__.py new file mode 100644 index 00000000000..db990014a67 --- /dev/null +++ b/tests/sparseml/transformers/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sparseml.transformers as _transformers # noqa: F401 diff --git a/tests/sparseml/transformers/test_transformers.py b/tests/sparseml/transformers/test_transformers.py new file mode 100644 index 00000000000..09f997d5e58 --- /dev/null +++ b/tests/sparseml/transformers/test_transformers.py @@ -0,0 +1,198 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import math +import os +import shutil +import tarfile + +import numpy as np +import onnx +import onnxruntime as ort +import pytest +from transformers import AutoConfig + +from sparseml.transformers.sparsification import Trainer +from sparsezoo import Zoo +from src.sparseml.transformers import export_transformer_to_onnx, load_task_model + + +def _is_yaml_recipe_present(model_path): + return any( + [ + file + for file in glob.glob(os.path.join(model_path, "*")) + if (file.endswith(".yaml") or ("recipe" in file)) + ] + ) + + +def _run_inference_onnx(path_onnx, input): + ort_sess = ort.InferenceSession(path_onnx) + with np.load(input) as data: + input_0, input_1, input_2 = ( + data["input_0"].reshape(1, -1), + data["input_1"].reshape(1, -1), + data["input_2"].reshape(1, -1), + ) + output = ort_sess.run( + None, + {"input_ids": input_0, "attention_mask": input_1, "token_type_ids": input_2}, + ) + return output + + +def _compare_onnx_models(model1, model2): + optional_nodes_model1 = [ + "If", + "Equal", + "Gather", + "Shape", + # ops above are those which are used in the + # original graph to create logits and softmax heads + "Constant", + "Cast", + ] # ops above are the remaining optional nodes + optional_nodes_model2 = [ + "Constant", + "Squeeze", + ] # ops above are + # used in the original graph to create + # logits and softmax heads + + nodes1 = model1.graph.node + nodes1_names = [node.name for node in nodes1] + + nodes2 = model2.graph.node + nodes2_names = [node.name for node in nodes2] + + # Extract ops which are in nodes1 but not in nodes2 + nodes1_names_diff = [ + node_name for node_name in nodes1_names if node_name not in nodes2_names + ] + + # Extract ops which are in nodes2 but not in nodes1 + nodes2_names_diff = [ + node_name for node_name in nodes2_names if node_name not in nodes1_names + ] + # Assert that there are no important ops names in + # nodes1_names_diff or nodes2_names_diff + assert not [ + x for x in nodes1_names_diff if x.split("_")[0] not in optional_nodes_model1 + ] + assert not [ + x for x in nodes2_names_diff if x.split("_")[0] not in optional_nodes_model2 + ] + + # Compare the structure of nodes which share names across m1 and m2 + for node1 in nodes1: + if node1.name in set(nodes1_names).intersection(set(nodes2_names)): + for node2 in nodes2: + if node1.name == node2.name: + _compare_onnx_nodes(node1, node2) + + +def _compare_onnx_nodes(n1, n2): + # checking for consistent lengths seems like a sufficient test for now. + # due to internal structure, the naming of graph nodes + # may vary, even thought the semantics remain unchanged. + assert len(n1.input) == len(n2.input) + assert len(n1.output) == len(n2.output) + assert len(n1.op_type) == len(n2.op_type) + assert len(n1.attribute) == len(n2.attribute) + + +@pytest.mark.parametrize( + "model_stub, recipe_present, task", + [ + ( + "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned-conservative", # noqa: E501 + False, + "question-answering", + ) + ], + scope="function", +) +class TestModelFromZoo: + @pytest.fixture() + def setup(self, model_stub, recipe_present, task): + # setup + model = Zoo.load_model_from_stub(model_stub) + model.download() + + path_onnx = model.onnx_file.downloaded_path() + model_path = os.path.join(os.path.dirname(path_onnx), "pytorch") + + yield path_onnx, model_path, recipe_present, task + + # teardown + shutil.rmtree(os.path.dirname(model_path)) + + def test_load_weights_apply_recipe(self, setup): + path_onnx, model_path, recipe_present, task = setup + config = AutoConfig.from_pretrained(model_path) + model = load_task_model(task, model_path, config) + + assert model + assert recipe_present == _is_yaml_recipe_present(model_path) + if recipe_present: + + trainer = Trainer( + model=model, + model_state_path=model_path, + recipe=None, + recipe_args=None, + teacher=None, + ) + applied = trainer.apply_manager(epoch=math.inf, checkpoint=None) + + assert applied + + def test_outputs(self, setup): + path_onnx, model_path, recipe_present, task = setup + path_retrieved_onnx = export_transformer_to_onnx( + task=task, + model_path=model_path, + onnx_file_name="retrieved_model.onnx", + ) + + inputs_tar_path = os.path.join( + os.path.dirname(path_onnx), "sample-inputs.tar.gz" + ) + my_tar = tarfile.open(inputs_tar_path) + my_tar.extractall(model_path) + my_tar.close() + + inputs = glob.glob(os.path.join(model_path, "sample-inputs/*")) + for input in inputs: + out1 = _run_inference_onnx(path_onnx, input) + out2 = _run_inference_onnx(path_retrieved_onnx, input) + for o1, o2 in zip(out1, out2): + pytest.approx(o1, abs=1e-5) == o2 + + def test_export_to_onnx(self, setup): + path_onnx, model_path, recipe_present, task = setup + path_retrieved_onnx = export_transformer_to_onnx( + task=task, + model_path=model_path, + onnx_file_name="retrieved_model.onnx", + ) + + m1 = onnx.load(path_onnx) + m2 = onnx.load(os.path.join(model_path, path_retrieved_onnx)) + + assert m2 + + _compare_onnx_models(m1, m2) From f2f6bb0ad69ac827425c39df2354888d2cf9a12e Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Mon, 4 Apr 2022 13:26:23 +0200 Subject: [PATCH 2/4] Refactor after Ben's comments --- .../transformers/test_transformers.py | 155 +++++++----------- 1 file changed, 62 insertions(+), 93 deletions(-) diff --git a/tests/sparseml/transformers/test_transformers.py b/tests/sparseml/transformers/test_transformers.py index 09f997d5e58..aa8c73db505 100644 --- a/tests/sparseml/transformers/test_transformers.py +++ b/tests/sparseml/transformers/test_transformers.py @@ -16,9 +16,8 @@ import math import os import shutil -import tarfile +from collections import Counter, OrderedDict -import numpy as np import onnx import onnxruntime as ort import pytest @@ -26,6 +25,7 @@ from sparseml.transformers.sparsification import Trainer from sparsezoo import Zoo +from sparsezoo.utils import load_numpy_list from src.sparseml.transformers import export_transformer_to_onnx, load_task_model @@ -39,79 +39,45 @@ def _is_yaml_recipe_present(model_path): ) -def _run_inference_onnx(path_onnx, input): +def _run_inference_onnx(path_onnx, input_data): ort_sess = ort.InferenceSession(path_onnx) - with np.load(input) as data: - input_0, input_1, input_2 = ( - data["input_0"].reshape(1, -1), - data["input_1"].reshape(1, -1), - data["input_2"].reshape(1, -1), - ) + model = onnx.load(path_onnx) + input_names = [inp.name for inp in model.graph.input] + + model_input = OrderedDict( + [(k, v.reshape(1, -1)) for k, v in zip(input_names, input_data.values())] + ) + output = ort_sess.run( None, - {"input_ids": input_0, "attention_mask": input_1, "token_type_ids": input_2}, + model_input, ) return output -def _compare_onnx_models(model1, model2): - optional_nodes_model1 = [ - "If", - "Equal", - "Gather", - "Shape", - # ops above are those which are used in the - # original graph to create logits and softmax heads - "Constant", - "Cast", - ] # ops above are the remaining optional nodes - optional_nodes_model2 = [ - "Constant", - "Squeeze", - ] # ops above are - # used in the original graph to create - # logits and softmax heads - - nodes1 = model1.graph.node - nodes1_names = [node.name for node in nodes1] - - nodes2 = model2.graph.node - nodes2_names = [node.name for node in nodes2] - - # Extract ops which are in nodes1 but not in nodes2 - nodes1_names_diff = [ - node_name for node_name in nodes1_names if node_name not in nodes2_names +def _compare_onnx_models(model_1, model_2): + major_nodes = [ + "QLinearMatMul", + "Gemm", + "MatMul", + "MatMulInteger", + "Conv", + "QLinearConv", + "ConvInteger", + "QuantizeLinear", + "DeQuantizeLinear", ] - # Extract ops which are in nodes2 but not in nodes1 - nodes2_names_diff = [ - node_name for node_name in nodes2_names if node_name not in nodes1_names - ] - # Assert that there are no important ops names in - # nodes1_names_diff or nodes2_names_diff - assert not [ - x for x in nodes1_names_diff if x.split("_")[0] not in optional_nodes_model1 - ] - assert not [ - x for x in nodes2_names_diff if x.split("_")[0] not in optional_nodes_model2 - ] - - # Compare the structure of nodes which share names across m1 and m2 - for node1 in nodes1: - if node1.name in set(nodes1_names).intersection(set(nodes2_names)): - for node2 in nodes2: - if node1.name == node2.name: - _compare_onnx_nodes(node1, node2) + nodes1 = model_1.graph.node + nodes1_names = [node.name for node in nodes1] + nodes1_count = Counter([node_name.split("_")[0] for node_name in nodes1_names]) + nodes2 = model_2.graph.node + nodes2_names = [node.name for node in nodes2] + nodes2_count = Counter([node_name.split("_")[0] for node_name in nodes2_names]) -def _compare_onnx_nodes(n1, n2): - # checking for consistent lengths seems like a sufficient test for now. - # due to internal structure, the naming of graph nodes - # may vary, even thought the semantics remain unchanged. - assert len(n1.input) == len(n2.input) - assert len(n1.output) == len(n2.output) - assert len(n1.op_type) == len(n2.op_type) - assert len(n1.attribute) == len(n2.attribute) + for node in major_nodes: + assert nodes1_count[node] == nodes2_count[node] @pytest.mark.parametrize( @@ -129,19 +95,20 @@ class TestModelFromZoo: @pytest.fixture() def setup(self, model_stub, recipe_present, task): # setup + self.onnx_retrieved_name = "retrieved_model.onnx" model = Zoo.load_model_from_stub(model_stub) model.download() - path_onnx = model.onnx_file.downloaded_path() - model_path = os.path.join(os.path.dirname(path_onnx), "pytorch") - - yield path_onnx, model_path, recipe_present, task + yield model, recipe_present, task # teardown + model_path = model.framework_files[0].dir_path shutil.rmtree(os.path.dirname(model_path)) def test_load_weights_apply_recipe(self, setup): - path_onnx, model_path, recipe_present, task = setup + model, recipe_present, task = setup + model_path = model.framework_files[0].dir_path + config = AutoConfig.from_pretrained(model_path) model = load_task_model(task, model_path, config) @@ -160,39 +127,41 @@ def test_load_weights_apply_recipe(self, setup): assert applied - def test_outputs(self, setup): - path_onnx, model_path, recipe_present, task = setup + def test_export_to_onnx(self, setup): + model, recipe_present, task = setup + path_onnx = model.onnx_file.downloaded_path() + model_path = model.framework_files[0].dir_path + path_retrieved_onnx = export_transformer_to_onnx( task=task, model_path=model_path, - onnx_file_name="retrieved_model.onnx", + onnx_file_name=self.onnx_retrieved_name, ) - inputs_tar_path = os.path.join( - os.path.dirname(path_onnx), "sample-inputs.tar.gz" - ) - my_tar = tarfile.open(inputs_tar_path) - my_tar.extractall(model_path) - my_tar.close() + zoo_model = onnx.load(path_onnx) + export_model = onnx.load(os.path.join(model_path, path_retrieved_onnx)) - inputs = glob.glob(os.path.join(model_path, "sample-inputs/*")) - for input in inputs: - out1 = _run_inference_onnx(path_onnx, input) - out2 = _run_inference_onnx(path_retrieved_onnx, input) - for o1, o2 in zip(out1, out2): - pytest.approx(o1, abs=1e-5) == o2 + assert export_model + + onnx.checker.check_model(export_model) + _compare_onnx_models(zoo_model, export_model) + + def test_outputs_ort(self, setup): + + model, recipe_present, task = setup + path_onnx = model.onnx_file.downloaded_path() + model_path = model.framework_files[0].dir_path + inputs_path = model.data_inputs.path + + input_data = load_numpy_list(inputs_path)[0] - def test_export_to_onnx(self, setup): - path_onnx, model_path, recipe_present, task = setup path_retrieved_onnx = export_transformer_to_onnx( task=task, model_path=model_path, - onnx_file_name="retrieved_model.onnx", + onnx_file_name=self.onnx_retrieved_name, ) - m1 = onnx.load(path_onnx) - m2 = onnx.load(os.path.join(model_path, path_retrieved_onnx)) - - assert m2 - - _compare_onnx_models(m1, m2) + out1 = _run_inference_onnx(path_onnx, input_data) + out2 = _run_inference_onnx(path_retrieved_onnx, input_data) + for o1, o2 in zip(out1, out2): + pytest.approx(o1, abs=1e-5) == o2 From da182db186a245821464ed59045ffab60b7076c7 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Thu, 17 Mar 2022 12:36:37 +0100 Subject: [PATCH 3/4] Add testing for weights load and recipe application --- tests/sparseml/transformers/test_transformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sparseml/transformers/test_transformers.py b/tests/sparseml/transformers/test_transformers.py index aa8c73db505..648f3e9cc49 100644 --- a/tests/sparseml/transformers/test_transformers.py +++ b/tests/sparseml/transformers/test_transformers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +#import sparseml.transformers as __transformers # noqa: F401 import glob import math @@ -21,9 +22,8 @@ import onnx import onnxruntime as ort import pytest -from transformers import AutoConfig - from sparseml.transformers.sparsification import Trainer +from transformers import AutoConfig from sparsezoo import Zoo from sparsezoo.utils import load_numpy_list from src.sparseml.transformers import export_transformer_to_onnx, load_task_model From 212a7d42e9102e7cc78094eb8436164341c3a875 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Thu, 7 Apr 2022 11:19:50 +0200 Subject: [PATCH 4/4] Fix style --- tests/sparseml/transformers/test_transformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sparseml/transformers/test_transformers.py b/tests/sparseml/transformers/test_transformers.py index 648f3e9cc49..aa8c73db505 100644 --- a/tests/sparseml/transformers/test_transformers.py +++ b/tests/sparseml/transformers/test_transformers.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -#import sparseml.transformers as __transformers # noqa: F401 import glob import math @@ -22,8 +21,9 @@ import onnx import onnxruntime as ort import pytest -from sparseml.transformers.sparsification import Trainer from transformers import AutoConfig + +from sparseml.transformers.sparsification import Trainer from sparsezoo import Zoo from sparsezoo.utils import load_numpy_list from src.sparseml.transformers import export_transformer_to_onnx, load_task_model