diff --git a/models/export.py b/models/export.py
index 5ca17efc781a..dc4a039d68bb 100644
--- a/models/export.py
+++ b/models/export.py
@@ -8,6 +8,7 @@ from copy import deepcopy
 import sys
 import time
+import os
 
 sys.path.append('./')  # to run '$ python *.py' files in subdirectories
@@ -197,20 +198,33 @@ def load_state_dict(model, state_dict, train, exclude_anchors):
         print(f'{prefix} starting export with onnx {onnx.__version__}...')
         f = opt.weights.replace('.pt', '.onnx')  # filename
 
-        if not sparseml_wrapper.enabled:
-            torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'],
-                              dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # size(1,3,640,640)
-                                            'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
-        else:
-            # export through SparseML so quantized and pruned graphs can be corrected
-            save_dir = '/'.join(f.split('/')[:-1])
-            save_name = f.split('/')[-1]
-            exporter = ModuleExporter(model, save_dir)
-            exporter.export_onnx(img, name=save_name, convert_qat=True)
-            try:
-                skip_onnx_input_quantize(f, f)
-            except:
-                pass
+        # export through SparseML so quantized and pruned graphs can be corrected
+        save_dir = os.path.dirname(f)
+        save_name = os.path.basename(f)
+
+        # count the outputs so we know how many names and dynamic axes to register;
+        # nested lists/tuples of outputs can be returned when exporting with dynamic axes
+        def _count_outputs(outputs):
+            count = 0
+            if isinstance(outputs, (list, tuple)):
+                for out in outputs:
+                    count += _count_outputs(out)
+            else:
+                count += 1
+            return count
+
+        outputs = model(img)
+        num_outputs = _count_outputs(outputs)
+        input_names = ['input']
+        output_names = [f'out_{i}' for i in range(num_outputs)]
+        dynamic_axes = {k: {0: 'batch'} for k in (input_names + output_names)} if opt.dynamic else None
+        exporter = ModuleExporter(model, save_dir)
+        exporter.export_onnx(img, name=save_name, convert_qat=True,
+                             input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
+        try:
+            skip_onnx_input_quantize(f, f)
+        except Exception:
+            pass
 
         # Checks
         model_onnx = onnx.load(f)  # load onnx model
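
Reviewer note: a quick way to sanity-check this export path is to reload the resulting graph and confirm the renamed inputs/outputs and the dynamic batch axis. A minimal sketch, using only onnx APIs this script already relies on; 'yolov5s.onnx' is a hypothetical path, since the real filename is derived from opt.weights above:

# Sketch, not part of the diff: inspect tensor names and axes of the exported model.
import onnx

model_onnx = onnx.load('yolov5s.onnx')  # hypothetical path
onnx.checker.check_model(model_onnx)

for tensor in list(model_onnx.graph.input) + list(model_onnx.graph.output):
    # dim_param holds the symbolic axis name ('batch') when --dynamic was used;
    # dim_value holds the fixed size otherwise
    dims = [d.dim_param or d.dim_value for d in tensor.type.tensor_type.shape.dim]
    print(tensor.name, dims)  # expect 'input' and 'out_0'..'out_N' with 'batch' at dim 0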