Skip to content
This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions models/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from copy import deepcopy
import sys
import time
import os

sys.path.append('./') # to run '$ python *.py' files in subdirectories

Expand Down Expand Up @@ -197,20 +198,33 @@ def load_state_dict(model, state_dict, train, exclude_anchors):

print(f'{prefix} starting export with onnx {onnx.__version__}...')
f = opt.weights.replace('.pt', '.onnx') # filename
if not sparseml_wrapper.enabled:
torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
else:
# export through SparseML so quantized and pruned graphs can be corrected
save_dir = '/'.join(f.split('/')[:-1])
save_name = f.split('/')[-1]
exporter = ModuleExporter(model, save_dir)
exporter.export_onnx(img, name=save_name, convert_qat=True)
try:
skip_onnx_input_quantize(f, f)
except:
pass
# export through SparseML so quantized and pruned graphs can be corrected
save_dir = os.path.join(f.split(os.path.sep)[:-1])
save_name = f.split(os.path.sep)[-1]

# get the number of outputs so we know how to name and change dynamic axes
# nested outputs can be returned if model is exported with dynamic
def _count_outputs(outputs):
count = 0
if isinstance(outputs, list) or isinstance(outputs, tuple):
for out in outputs:
count += _count_outputs(out)
else:
count += 1
return count

outputs = model(img)
num_outputs = _count_outputs(outputs)
input_names = ['input']
output_names = [f'out_{i}' for i in range(num_outputs)]
dynamic_axes = {k: {0: 'batch'} for k in (input_names + output_names)} if opt.dynamic else None
exporter = ModuleExporter(model, save_dir)
exporter.export_onnx(img, name=save_name, convert_qat=True,
input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
try:
skip_onnx_input_quantize(f, f)
except:
pass

# Checks
model_onnx = onnx.load(f) # load onnx model
Expand Down