# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Export PyTorch models to the local device
"""
import collections
import json
import logging
import os
import shutil
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy
import onnx
import torch
from onnx import numpy_helper
from packaging import version
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse
from sparseml.onnx.utils import ONNXGraph
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import (
    adjust_quantization_for_onnx_export,
    tensors_export,
    tensors_module_forward,
    tensors_to_device,
)
from sparseml.pytorch.utils.model import (
    is_parallel_model,
    save_model,
    script_model,
    trace_model,
)
from sparseml.utils import clean_path, create_parent_dirs
from sparsezoo.utils import save_onnx, validate_onnx


__all__ = [
    "ModuleExporter",
    "export_onnx",
]


_PARSED_TORCH_VERSION = version.parse(torch.__version__)

MODEL_ONNX_NAME = "model.onnx"
CONFIG_JSON_NAME = "config.json"

_LOGGER = logging.getLogger(__name__)


class ModuleExporter(object):
    """
    An exporter for exporting PyTorch modules into ONNX format
    as well as numpy arrays for the input and output tensors.

    :param module: the module to export
    :param output_dir: the directory to export the module and extras to
    """

    def __init__(
        self,
        module: Module,
        output_dir: str,
    ):
        if is_parallel_model(module):
            module = module.module
        self._module = deepcopy(module).to("cpu").eval()
        self._output_dir = clean_path(output_dir)
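
    # Illustrative usage sketch (not part of the library; `MyModel` is a
    # placeholder for any torch.nn.Module):
    #
    #   exporter = ModuleExporter(MyModel(), output_dir="exports")
    #   exporter.export_onnx(sample_batch=torch.randn(1, 3, 224, 224))
    #   exporter.export_pytorch()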

    def export_to_zoo(
        self,
        dataloader: DataLoader,
        original_dataloader: Optional[DataLoader] = None,
        shuffle: bool = False,
        max_samples: int = 20,
        data_split_cb: Optional[Callable[[Any], Tuple[Any, Any]]] = None,
        label_mapping_cb: Optional[Callable[[Any], Any]] = None,
        trace_script: bool = False,
        fail_on_torchscript_failure: bool = True,
        export_entire_model: bool = False,
    ):
        """
        Creates and exports all related content for the module, including
        sample data, ONNX, PyTorch, and TorchScript files.

        :param dataloader: DataLoader used to generate sample data
        :param original_dataloader: Optional dataloader to obtain the untransformed
            image.
        :param shuffle: Whether to shuffle sample data
        :param max_samples: Max number of data samples to create
        :param data_split_cb: Optional callback function to split a data sample into
            a tuple (features, labels). If not provided, assumes the dataloader
            returns a tuple (features, labels).
        :param label_mapping_cb: Optional callback function to map dataset labels to
            other formats.
        :param trace_script: If True, creates torchscript via tracing. Otherwise,
            creates the torchscript via scripting.
        :param fail_on_torchscript_failure: If True, fails if torchscript is unable
            to export the model.
        :param export_entire_model: Exports the entire model instead of its state_dict
        """
        sample_batches = []
        sample_labels = []
        sample_originals = None
        if original_dataloader is not None:
            sample_originals = []
            for originals in original_dataloader:
                sample_originals.append(originals)
                if len(sample_originals) == max_samples:
                    break

        for sample in dataloader:
            if data_split_cb is not None:
                features, labels = data_split_cb(sample)
            else:
                features, labels = sample

            if label_mapping_cb:
                labels = label_mapping_cb(labels)

            sample_batches.append(features)
            sample_labels.append(labels)
            if len(sample_batches) == max_samples:
                break

        self.export_onnx(sample_batch=sample_batches[0])
        self.export_pytorch(export_entire_model=export_entire_model)
        try:
            if trace_script:
                self.export_torchscript(sample_batch=sample_batches[0])
            else:
                self.export_torchscript()
        except Exception as e:
            if fail_on_torchscript_failure:
                raise e
            else:
                _LOGGER.warning(
                    f"Unable to create torchscript file. Following error occurred: {e}"
                )

        self.export_samples(
            sample_batches,
            sample_labels=sample_labels,
            sample_originals=sample_originals,
        )
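
    # Sketch of a full zoo-style export, assuming a hypothetical `my_dataset`
    # whose dataloader yields (features, labels) tuples:
    #
    #   loader = DataLoader(my_dataset, batch_size=1)
    #   exporter.export_to_zoo(loader, max_samples=10)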

    @classmethod
    def get_output_names(cls, out: Any):
        """
        Get names of output tensors. Framework-specific derived exporters
        may override this method.

        :param out: outputs of the model
        :return: list of names
        """
        return _get_output_names(out)

    def export_onnx(
        self,
        sample_batch: Any,
        name: str = MODEL_ONNX_NAME,
        opset: int = TORCH_DEFAULT_ONNX_OPSET,
        disable_bn_fusing: bool = True,
        convert_qat: bool = False,
        **export_kwargs,
    ):
        """
        Export an onnx file for the current module and for a sample batch.
        The sample batch is fed through the model to freeze the graph for a
        particular execution.

        :param sample_batch: the batch to export an onnx for, handles creating the
            static graph for onnx as well as setting dimensions
        :param name: name of the onnx file to save
        :param opset: onnx opset to use for the exported model.
            Default is based on the torch version.
        :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
            fusing during torch export. Default and suggested setting is True. Batch
            norm fusing will change the exported parameter names as well as affect
            sensitivity analyses of the exported graph. Additionally, the DeepSparse
            inference engine, and other engines, perform batch norm fusing at model
            compilation.
        :param convert_qat: if True and quantization aware training is detected in
            the module being exported, the resulting QAT ONNX model will be converted
            to a fully quantized ONNX model using `ONNXToDeepsparse`. Default
            is False.
        :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
            call. Useful to pass in dynamic_axes, input_names, output_names, etc.
            See more on the torch.onnx.export api spec in the PyTorch docs:
            https://pytorch.org/docs/stable/onnx.html
        """
        if not export_kwargs:
            export_kwargs = {}

        module = deepcopy(self._module).cpu()  # don't modify the original model
        if "output_names" not in export_kwargs:
            sample_batch = tensors_to_device(sample_batch, "cpu")
            module.eval()
            with torch.no_grad():
                out = tensors_module_forward(
                    sample_batch, module, check_feat_lab_inp=False
                )
                export_kwargs["output_names"] = self.get_output_names(out)

        adjust_quantization_for_onnx_export(module)  # in-place operation

        export_onnx(
            module=module,
            sample_batch=sample_batch,
            file_path=os.path.join(self._output_dir, name),
            opset=opset,
            disable_bn_fusing=disable_bn_fusing,
            convert_qat=convert_qat,
            **export_kwargs,
        )
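
    # Sketch: explicit input names and a dynamic batch dimension; input_names
    # passes through export_kwargs to torch.onnx.export, while dynamic_axes is
    # handled by the module-level export_onnx below (shape is illustrative):
    #
    #   exporter.export_onnx(
    #       sample_batch=torch.randn(1, 3, 224, 224),
    #       input_names=["input"],
    #       dynamic_axes={"input": {0: "batch"}},
    #   )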

    def export_torchscript(
        self,
        name: str = "model.pts",
        sample_batch: Optional[Any] = None,
    ):
        """
        Export the torchscript into a pts file within a framework directory. If
        a sample batch is provided, will create torchscript model in trace mode.
        Otherwise uses script to create torchscript.

        :param name: name of the torchscript file to save
        :param sample_batch: If provided, will create torchscript model via tracing
            using the sample_batch
        """
        path = os.path.join(self._output_dir, "framework", name)
        create_parent_dirs(path)
        if sample_batch:
            trace_model(path, self._module, sample_batch)
        else:
            script_model(path, self._module)
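
    # Sketch: trace when a sample batch is given, script otherwise:
    #
    #   exporter.export_torchscript(sample_batch=torch.randn(1, 3, 224, 224))
    #   exporter.export_torchscript()  # scripting, no sample batch needed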

    def create_deployment_folder(
        self,
        labels_to_class_mapping: Optional[Union[str, Dict[int, str]]] = None,
        onnx_model_name: Optional[str] = None,
    ) -> str:
        """
        Create a deployment folder inside the `self._output_dir` directory.

        :param labels_to_class_mapping: information about the mapping
            from integer labels to string class names.
            Can be either a string (path to the .json serialized dictionary)
            or a dictionary. Default is None
        :param onnx_model_name: name of the onnx model file. Defaults to `model.onnx`
        :return: path to the deployment folder
        """
        deployment_folder_dir = os.path.join(self._output_dir, "deployment")
        if os.path.isdir(deployment_folder_dir):
            shutil.rmtree(deployment_folder_dir)
        os.makedirs(deployment_folder_dir)
        _LOGGER.info(f"Created deployment folder at {deployment_folder_dir}")

        # copy over model onnx
        onnx_model_name = onnx_model_name or MODEL_ONNX_NAME
        expected_onnx_model_dir = os.path.join(self._output_dir, onnx_model_name)
        deployment_onnx_model_dir = os.path.join(
            deployment_folder_dir, onnx_model_name
        )
        _copy_file(src=expected_onnx_model_dir, target=deployment_onnx_model_dir)
        _LOGGER.info(
            f"Saved {onnx_model_name} in the deployment "
            f"folder at {deployment_onnx_model_dir}"
        )

        # create config.json
        config_file_path = _create_config_file(save_dir=deployment_folder_dir)
        if labels_to_class_mapping:
            # append `labels_to_class_mapping` info to config.json
            _save_label_to_class_mapping(
                labels_to_class_mapping=labels_to_class_mapping,
                config_file_path=config_file_path,
            )
        return deployment_folder_dir
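
    # Sketch with an inline label mapping (class names are placeholders);
    # produces <output_dir>/deployment/ holding model.onnx and config.json:
    #
    #   exporter.create_deployment_folder(
    #       labels_to_class_mapping={0: "cat", 1: "dog"}
    #   )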

    def export_pytorch(
        self,
        optimizer: Optional[Optimizer] = None,
        recipe: Optional[str] = None,
        epoch: Optional[int] = None,
        name: str = "model.pth",
        use_zipfile_serialization_if_available: bool = True,
        include_modifiers: bool = False,
        export_entire_model: bool = False,
        arch_key: Optional[str] = None,
    ):
        """
        Export the pytorch state dicts into pth file within a
        pytorch framework directory.

        :param optimizer: optional optimizer to export along with the module
        :param recipe: the recipe used to obtain the model
        :param epoch: optional epoch to export along with the module
        :param name: name of the pytorch file to save
        :param use_zipfile_serialization_if_available: for torch >= 1.6.0 only
            exports the Module's state dict using the new zipfile serialization
        :param include_modifiers: if True, and a ScheduledOptimizer is provided
            as the optimizer, the associated ScheduledModifierManager and its
            Modifiers will be exported under the 'manager' key. Default is False
        :param export_entire_model: Exports entire file instead of state_dict
        :param arch_key: if provided, the `arch_key` will be saved in the
            checkpoint
        """
        pytorch_path = os.path.join(self._output_dir, "training")
        pth_path = os.path.join(pytorch_path, name)
        create_parent_dirs(pth_path)

        if export_entire_model:
            torch.save(self._module, pth_path)
        else:
            save_model(
                pth_path,
                self._module,
                optimizer,
                recipe,
                epoch,
                use_zipfile_serialization_if_available=(
                    use_zipfile_serialization_if_available
                ),
                include_modifiers=include_modifiers,
                arch_key=arch_key,
            )
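
    # Sketch: save a checkpoint with an optimizer and recipe (both are
    # placeholders) to <output_dir>/training/model.pth:
    #
    #   exporter.export_pytorch(optimizer=optimizer, recipe="recipe.yaml", epoch=10)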

    def export_samples(
        self,
        sample_batches: List[Any],
        sample_labels: Optional[List[Any]] = None,
        sample_originals: Optional[List[Any]] = None,
        exp_counter: int = 0,
    ):
        """
        Export a list of sample batches as inputs and outputs through the model.

        :param sample_batches: a list of the sample batches to feed through the module
            for saving inputs and outputs
        :param sample_labels: an optional list of sample labels that correspond to
            the batches for saving
        :param sample_originals: an optional list of untransformed samples that
            correspond to the batches for saving
        :param exp_counter: the counter to start exporting the tensor files at
        """
        sample_batches = [tensors_to_device(batch, "cpu") for batch in sample_batches]
        inputs_dir = os.path.join(self._output_dir, "sample-inputs")
        outputs_dir = os.path.join(self._output_dir, "sample-outputs")
        labels_dir = os.path.join(self._output_dir, "sample-labels")
        originals_dir = os.path.join(self._output_dir, "sample_originals")

        with torch.no_grad():
            for batch, lab, orig in zip(
                sample_batches,
                sample_labels if sample_labels else [None for _ in sample_batches],
                sample_originals
                if sample_originals
                else [None for _ in sample_batches],
            ):
                out = tensors_module_forward(batch, self._module)

                exported_input = tensors_export(
                    batch,
                    inputs_dir,
                    name_prefix="inp",
                    counter=exp_counter,
                    break_batch=True,
                )
                if isinstance(out, dict):
                    new_out = []
                    for key in out:
                        new_out.append(out[key])
                    out = new_out
                exported_output = tensors_export(
                    out,
                    outputs_dir,
                    name_prefix="out",
                    counter=exp_counter,
                    break_batch=True,
                )
                if lab is not None:
                    tensors_export(
                        lab, labels_dir, "lab", counter=exp_counter, break_batch=True
                    )
                if orig is not None:
                    tensors_export(
                        orig,
                        originals_dir,
                        "orig",
                        counter=exp_counter,
                        break_batch=True,
                    )

                assert len(exported_input) == len(exported_output)
                exp_counter += len(exported_input)
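
    # For reference, export_samples writes numbered tensor files (names are
    # indicative; the exact naming comes from tensors_export) under:
    #   <output_dir>/sample-inputs/inp-0000.npz, ...
    #   <output_dir>/sample-outputs/out-0000.npz, ...
    #   plus sample-labels/ and sample_originals/ when provided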


def export_onnx(
    module: Module,
    sample_batch: Any,
    file_path: str,
    opset: int = TORCH_DEFAULT_ONNX_OPSET,
    disable_bn_fusing: bool = True,
    convert_qat: bool = False,
    dynamic_axes: Optional[Union[str, Dict[str, List[int]]]] = None,
    skip_input_quantize: bool = False,
    **export_kwargs,
):
    """
    Export an onnx file for the given module and for a sample batch.
    The sample batch is fed through the model to freeze the graph for a
    particular execution.

    :param module: torch Module object to export
    :param sample_batch: the batch to export an onnx for, handles creating the
        static graph for onnx as well as setting dimensions
    :param file_path: path to the onnx file to save
    :param opset: onnx opset to use for the exported model.
        Default is based on the torch version.
    :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
        fusing during torch export. Default and suggested setting is True. Batch
        norm fusing will change the exported parameter names as well as affect
        sensitivity analyses of the exported graph. Additionally, the DeepSparse
        inference engine, and other engines, perform batch norm fusing at model
        compilation.
    :param convert_qat: if True and quantization aware training is detected in
        the module being exported, the resulting QAT ONNX model will be converted
        to a fully quantized ONNX model using `ONNXToDeepsparse`. Default
        is False.
    :param dynamic_axes: dictionary of input or output names to list of dimensions
        of those tensors that should be exported as dynamic. May pass 'batch'
        to set the first dimension of all inputs and outputs to dynamic. Default
        is an empty dict
    :param skip_input_quantize: if True, the export flow will attempt to delete
        the first QuantizeLinear node(s) immediately after model input and set
        the model input type to UINT8. Default is False
    :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
        call. Useful to pass in dynamic_axes, input_names, output_names, etc.
        See more on the torch.onnx.export api spec in the PyTorch docs:
        https://pytorch.org/docs/stable/onnx.html
    """
    if (
        _PARSED_TORCH_VERSION >= version.parse("1.10.0")
        and opset < 13
        and convert_qat
    ):
        raise ValueError(
            "Exporting onnx with QAT and opset < 13 may result in errors. "
            "Please use opset>=13 with QAT. "
            "See https://github.com/pytorch/pytorch/issues/77455 for more info. "
        )

    if not export_kwargs:
        export_kwargs = {}

    if isinstance(sample_batch, Dict) and not isinstance(
        sample_batch, collections.OrderedDict
    ):
        warnings.warn(
            "Sample inputs passed into the ONNX exporter should be in "
            "the same order defined in the model forward function. "
            "Consider using OrderedDict for this purpose.",
            UserWarning,
        )

    sample_batch = tensors_to_device(sample_batch, "cpu")
    create_parent_dirs(file_path)

    module = deepcopy(module).cpu()

    with torch.no_grad():
        out = tensors_module_forward(sample_batch, module, check_feat_lab_inp=False)

    if "input_names" not in export_kwargs:
        if isinstance(sample_batch, Tensor):
            export_kwargs["input_names"] = ["input"]
        elif isinstance(sample_batch, Dict):
            export_kwargs["input_names"] = list(sample_batch.keys())
            sample_batch = tuple(
                [sample_batch[f] for f in export_kwargs["input_names"]]
            )
        elif isinstance(sample_batch, Iterable):
            export_kwargs["input_names"] = [
                "input_{}".format(index) for index, _ in enumerate(iter(sample_batch))
            ]
            if isinstance(sample_batch, List):
                sample_batch = tuple(sample_batch)  # torch.onnx.export requires tuple

    if "output_names" not in export_kwargs:
        export_kwargs["output_names"] = _get_output_names(out)

    # Set all batch sizes to be dynamic
    if dynamic_axes == "batch":
        # 'batch' shorthand: the full mapping is built in the else branch below
        dynamic_axes = None
    if dynamic_axes is not None:
        for tensor_name in (
            export_kwargs["input_names"] + export_kwargs["output_names"]
        ):
            if tensor_name not in dynamic_axes:
                dynamic_axes[tensor_name] = {0: "batch"}
            else:
                dynamic_axes[tensor_name][0] = "batch"
    else:
        dynamic_axes = {
            tensor_name: {0: "batch"}
            for tensor_name in (
                export_kwargs["input_names"] + export_kwargs["output_names"]
            )
        }

    # disable active quantization observers because they cannot be exported
    disabled_observers = []
    for submodule in module.modules():
        if (
            hasattr(submodule, "observer_enabled")
            and submodule.observer_enabled[0] == 1
        ):
            submodule.observer_enabled[0] = 0
            disabled_observers.append(submodule)

    is_quant_module = any(
        hasattr(submodule, "qconfig") and submodule.qconfig
        for submodule in module.modules()
    )
    batch_norms_wrapped = False
    if (
        _PARSED_TORCH_VERSION >= version.parse("1.7")
        and not is_quant_module
        and disable_bn_fusing
    ):
        # prevent batch norm fusing by adding a trivial operation before every
        # batch norm layer
        batch_norms_wrapped = _wrap_batch_norms(module)

    kwargs = dict(
        model=module,
        args=sample_batch,
        f=file_path,
        verbose=False,
        opset_version=opset,
        dynamic_axes=dynamic_axes,
        **export_kwargs,
    )
    if _PARSED_TORCH_VERSION < version.parse("1.10.0"):
        kwargs["strip_doc_string"] = True
    else:
        kwargs["training"] = torch.onnx.TrainingMode.PRESERVE
        kwargs["do_constant_folding"] = not module.training
        kwargs["keep_initializers_as_inputs"] = False
    torch.onnx.export(**kwargs)

    # re-enable disabled quantization observers
    for submodule in disabled_observers:
        submodule.observer_enabled[0] = 1

    # onnx file fixes
    onnx_model = onnx.load(file_path)
    _fold_identity_initializers(onnx_model)
    _flatten_qparams(onnx_model)
    if batch_norms_wrapped:
        # fix changed batch norm names
        _unwrap_batchnorms(onnx_model)

        # clean up graph from any injected / wrapped operations
        _delete_trivial_onnx_adds(onnx_model)
    save_onnx(onnx_model, file_path)

    if convert_qat and is_quant_module:
        use_qlinear_conv = hasattr(module, "export_with_qlinearconv") and (
            module.export_with_qlinearconv
        )
        use_qlinear_matmul = hasattr(module, "export_with_qlinearmatmul") and (
            module.export_with_qlinearmatmul
        )
        exporter = ONNXToDeepsparse(
            use_qlinear_conv=use_qlinear_conv,
            use_qlinear_matmul=use_qlinear_matmul,
            skip_input_quantize=skip_input_quantize,
        )
        exporter.export(pre_transforms_model=file_path, file_path=file_path)
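
# Sketch of calling the standalone function directly; `model` and the batch
# shape are placeholders:
#
#   export_onnx(
#       module=model,
#       sample_batch=torch.randn(1, 3, 224, 224),
#       file_path="exports/model.onnx",
#       convert_qat=True,
#   )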


def _copy_file(src: str, target: str):
    if not os.path.exists(src):
        raise ValueError(
            f"Attempting to copy file from {src}, but the file does not exist."
        )
    shutil.copyfile(src, target)


def _create_config_file(save_dir: str) -> str:
    config_file_path = os.path.join(save_dir, CONFIG_JSON_NAME)
    with open(config_file_path, "w"):
        # create empty json file
        pass
    _LOGGER.info(f"Created {CONFIG_JSON_NAME} file at {save_dir}")
    return config_file_path


def _save_label_to_class_mapping(
    labels_to_class_mapping: Union[str, Dict[int, str]],
    config_file_path: str,
    key_name: str = "labels_to_class_mapping",
):
    """
    Appends `labels_to_class_mapping` information to the config.json file:
    - new key: `labels_to_class_mapping`
    - new value: a dictionary that maps the integer
      labels to string class names

    If config.json already contains `labels_to_class_mapping`,
    this information will be overwritten.

    :param labels_to_class_mapping: information about the mapping from
        integer labels to string class names. Can be either a string
        (path to the .json serialized dictionary) or a dictionary.
    :param config_file_path: path to the `config.json` file.
    :param key_name: the key under which the information about
        the mapping will be stored inside the config.json file
    """
    is_config_empty = os.stat(config_file_path).st_size == 0
    if not is_config_empty:
        with open(config_file_path, "r") as infile:
            config = json.load(infile)
    else:
        config = {}

    # check whether the label names are not already present in the config
    if key_name in config.keys():
        _LOGGER.warning(
            f"File: {CONFIG_JSON_NAME} already contains key {key_name}. "
            f"{key_name} data will be overwritten"
        )

    if isinstance(labels_to_class_mapping, str):
        with open(labels_to_class_mapping) as infile:
            labels_to_class_mapping = json.load(infile)

    config[key_name] = labels_to_class_mapping

    with open(config_file_path, "w") as outfile:
        json.dump(config, outfile)
    _LOGGER.info(
        f"Appended {key_name} data to {CONFIG_JSON_NAME} at {config_file_path}"
    )
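
# After _save_label_to_class_mapping runs, config.json holds, e.g. (values are
# illustrative; json.dump stringifies the integer keys):
#
#   {"labels_to_class_mapping": {"0": "cat", "1": "dog"}}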


def _flatten_qparams(model: onnx.ModelProto):
    # transforms any QuantizeLinear/DequantizeLinear that have
    # zero_point/scale with shapes `(1,)` into shape `()`
    graph = ONNXGraph(model)
    inits_to_flatten = set()
    for node in model.graph.node:
        if node.op_type in ["QuantizeLinear", "DequantizeLinear"]:
            # scale is required if the input is an initializer
            scale_init = graph.get_init_by_name(node.input[1])
            if scale_init is not None and list(scale_init.dims) == [1]:
                inits_to_flatten.add(node.input[1])

                # zero_point is optional AND shape must match
                # scale. so if scale is (1,), then so will zero point
                if len(node.input) == 3:
                    inits_to_flatten.add(node.input[2])

    for i, init in enumerate(model.graph.initializer):
        if init.name not in inits_to_flatten:
            continue
        a = numpy_helper.to_array(init)
        assert a.shape == (1,)
        b = numpy.array(a[0])
        assert b.shape == ()
        assert b.dtype == a.dtype
        model.graph.initializer[i].CopyFrom(numpy_helper.from_array(b, name=init.name))


def _fold_identity_initializers(model: onnx.ModelProto):
    # folds any Identity nodes that have a single input (which is an initializer)
    # and a single output
    matches = []
    graph = ONNXGraph(model)

    def is_match(node: onnx.NodeProto) -> bool:
        return (
            node.op_type == "Identity"
            and len(node.input) == 1
            and len(node.output) == 1
            and node.input[0] in graph._name_to_initializer
        )

    for node in model.graph.node:
        if not is_match(node):
            continue
        matches.append(node)

        # find any node in the graph that uses the output of `node`
        # as an input. replace the input with `node`'s input
        for other in graph.get_node_children(node):
            for i, other_input_i in enumerate(other.input):
                # NOTE: this just replaces the str ids
                if other_input_i == node.output[0]:
                    other.input[i] = node.input[0]

    for node in matches:
        model.graph.node.remove(node)


def _get_output_names(out: Any):
    """
    Get names of output tensors.

    :param out: outputs of the model
    :return: list of names
    """
    output_names = None
    if isinstance(out, Tensor):
        output_names = ["output"]
    elif hasattr(out, "keys") and callable(out.keys):
        output_names = list(out.keys())
    elif isinstance(out, Iterable):
        output_names = ["output_{}".format(index) for index, _ in enumerate(iter(out))]
    return output_names
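
# The derived names follow the output structure, e.g.:
#   Tensor                     -> ["output"]
#   {"logits": t1, "aux": t2}  -> ["logits", "aux"]
#   (t1, t2)                   -> ["output_0", "output_1"]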


class _AddNoOpWrapper(Module):
    # trivial wrapper to break up Conv-BN blocks

    def __init__(self, module: Module):
        super().__init__()
        self.bn_wrapper_replace_me = module

    def forward(self, inp):
        inp = inp + 0  # no-op
        return self.bn_wrapper_replace_me(inp)


def _get_submodule(module: Module, path: List[str]) -> Module:
    if not path:
        return module
    return _get_submodule(getattr(module, path[0]), path[1:])
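
# e.g. _get_submodule(model, ["encoder", "block1"]) returns model.encoder.block1,
# and _get_submodule(model, []) returns model itself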


def _wrap_batch_norms(module: Module) -> bool:
    # wrap all batch norm layers in module with a trivial wrapper
    # to prevent BN fusing during export
    batch_norms_wrapped = False
    for name, submodule in module.named_modules():
        if (
            isinstance(submodule, torch.nn.BatchNorm1d)
            or isinstance(submodule, torch.nn.BatchNorm2d)
            or isinstance(submodule, torch.nn.BatchNorm3d)
        ):
            submodule_path = name.split(".")
            parent_module = _get_submodule(module, submodule_path[:-1])
            setattr(parent_module, submodule_path[-1], _AddNoOpWrapper(submodule))
            batch_norms_wrapped = True
    return batch_norms_wrapped
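
# Wrapping changes exported parameter names, e.g. `layer1.bn.weight` becomes
# `layer1.bn.bn_wrapper_replace_me.weight`; _unwrap_batchnorms strips the
# marker from the ONNX graph after export.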


def _delete_trivial_onnx_adds(model: onnx.ModelProto):
    # delete all add nodes in the graph with second inputs as constant nodes set to 0
    add_nodes = [node for node in model.graph.node if node.op_type == "Add"]
    for add_node in add_nodes:
        try:
            add_const_node = [
                node for node in model.graph.node if node.output[0] == add_node.input[1]
            ][0]
            add_const_val = numpy_helper.to_array(add_const_node.attribute[0].t)
            if numpy.all(add_const_val == 0.0):
                # update graph edges
                parent_node = [
                    node
                    for node in model.graph.node
                    if add_node.input[0] in node.output
                ]
                if not parent_node:
                    continue
                parent_node[0].output[0] = add_node.output[0]

                # remove node and constant
                model.graph.node.remove(add_node)
                model.graph.node.remove(add_const_node)
        except Exception:  # skip node on any error
            continue


def _unwrap_batchnorms(model: onnx.ModelProto):
    for init in model.graph.initializer:
        init.name = init.name.replace(".bn_wrapper_replace_me", "")
    for node in model.graph.node:
        for idx in range(len(node.input)):
            node.input[idx] = node.input[idx].replace(".bn_wrapper_replace_me", "")
        for idx in range(len(node.output)):
            node.output[idx] = node.output[idx].replace(".bn_wrapper_replace_me", "")

    validate_onnx(model)