Skip to content

Commit

Permalink
[Model Compression] Expand export_model arguments: dummy input and on…
Browse files Browse the repository at this point in the history
…nx opset_version (#3968)
  • Loading branch information
xiaowu0162 committed Jul 26, 2021
1 parent deef0c4 commit 68818a3
Showing 1 changed file with 36 additions and 10 deletions.
46 changes: 36 additions & 10 deletions nni/compression/pytorch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ def _wrap_modules(self, layer, config):
wrapper.to(layer.module.weight.device)
return wrapper

def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None,
dummy_input=None, opset_version=None):
"""
Export pruned model weights, masks and onnx model(optional)
Expand All @@ -388,10 +389,21 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
input shape to onnx model, used for creating a dummy input tensor for torch.onnx.export
if the input has a complex structure (e.g., a tuple), please directly create the input and
pass it to dummy_input instead
note: this argument is deprecated and will be removed; please use dummy_input instead
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
device of the model, where to place the dummy input tensor for exporting onnx file;
the tensor is placed on cpu if ```device``` is None
only useful when both onnx_path and input_shape are passed
note: this argument is deprecated and will be removed; please use dummy_input instead
dummy_input: torch.Tensor or tuple
dummy input to the onnx model; used when input_shape is not enough to specify dummy input
user should ensure that the dummy_input is on the same device as the model
opset_version: int
opset_version parameter for torch.onnx.export; only useful when onnx_path is not None
if not passed, torch.onnx.export will use its default opset_version
"""
assert model_path is not None, 'model_path must be specified'
mask_dict = {}
Expand All @@ -412,17 +424,31 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N

torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)

if mask_path is not None:
torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)

if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
assert input_shape is not None or dummy_input is not None,\
'input_shape or dummy_input must be specified to export onnx model'
# create dummy_input using input_shape if input_shape is not passed
if dummy_input is None:
_logger.warning("""The argument input_shape and device will be removed in the future.
Please create a dummy input and pass it to dummy_input instead.""")
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape).to(device)
else:
input_data = dummy_input
if opset_version is not None:
torch.onnx.export(self.bound_model, input_data, onnx_path, opset_version=opset_version)
else:
torch.onnx.export(self.bound_model, input_data, onnx_path)
if dummy_input is None:
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
else:
_logger.info('Model in onnx saved to %s', onnx_path)

self._wrap_model()

Expand Down

0 comments on commit 68818a3

Please sign in to comment.