Skip to content

Commit

Permalink
Fix bug in ONNXRT adapter where _rename_node function fails with model size > 2 GB (#1115)
Browse files Browse the repository at this point in the history

Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
  • Loading branch information
yuwenzho committed Aug 2, 2023
1 parent 59371fe commit 1f6b1ad
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
23 changes: 17 additions & 6 deletions neural_compressor/adaptor/onnxrt.py
Expand Up @@ -245,10 +245,12 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
if model.model.opset_import[0].version < 11: # pragma: no cover
logger.warning("Quantize input needs model opset 11 or newer.")
if self.backend == 'DnnlExecutionProvider' and \
any([i.domain in ['', 'ai.onnx'] and i.version < 15 for i in model.model.opset_import]):
any([i.domain in ['', 'ai.onnx'] and \
i.version < 15 for i in model.model.opset_import]): # pragma: no cover
from onnx import version_converter
from neural_compressor.model.onnx_model import ONNXModel
try:
model.model = self._rename_node(version_converter.convert_version(model.model, 15))
model = self._rename_node(ONNXModel(version_converter.convert_version(model.model, 15)))
except:
logging.warning("Fail to upgrade model opset_import to >= 15, "\
"please upgrate it manually to run with bf16 data type")
Expand Down Expand Up @@ -716,7 +718,7 @@ def _pre_optimize(self, model, level=1):
model.model = self._replace_gemm_with_matmul(tmp_model).model if \
options.onnxrt.graph_optimization.gemm2matmul and self.recipes.get('gemm_to_matmul', True) else \
tmp_model
model.model = self._rename_node(model.model)
model = self._rename_node(model)
model = self._revert_fusedconv(model)
if self.backend == 'TensorrtExecutionProvider':
model = self._revert_conv_add_fusion(model)
Expand Down Expand Up @@ -788,7 +790,8 @@ def _revert_fusedconv(self, model):
model.update()
return model

def _rename_node(self, model):
def _rename_node(self, model_wrapper):
model = model_wrapper.model
node_names = [i.name for i in model.graph.node]
if len(set(node_names)) < len(node_names):
logger.warning("This model has nodes with the same name, please check" \
Expand All @@ -797,8 +800,16 @@ def _rename_node(self, model):
for idx, node in enumerate(model.graph.node):
if node_names.count(node.name) > 1:
node.name = node.op_type + '_nc_rename_' + str(idx)
onnx.save(model, os.path.join(self.work_space, "renamed_model.onnx"))
return model
if model_wrapper.is_large_model:
onnx.save(model,
os.path.join(self.work_space, "renamed_model.onnx"),
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
convert_attribute=False)
else:
onnx.save(model, os.path.join(self.work_space, "renamed_model.onnx"))
return model_wrapper

@staticmethod
def _replace_gemm_with_matmul(model):
Expand Down
3 changes: 2 additions & 1 deletion neural_compressor/adaptor/ox_utils/quantizer.py
Expand Up @@ -336,6 +336,7 @@ def dfs(match_nodes, node, pattern):
children = self.model.get_children(match_nodes[1])
input_dtype = '1' # float32
output_dtype = '1' # 'float32'
outs = None
for inp in parent.input:
if inp in self.new_value_info:
input_dtype = str(self.new_value_info[inp].new_dtype)
Expand All @@ -345,7 +346,7 @@ def dfs(match_nodes, node, pattern):
if len(outs) > 0:
output_dtype = str(self.new_value_info[outs[0]].new_dtype)
break
if len(outs) == 0 or all([not self.should_cast(i) for i in children]):
if outs is None or len(outs) == 0 or all([not self.should_cast(i) for i in children]):
return
if input_dtype == str(match_nodes[1].attribute[0].i) and \
output_dtype == str(match_nodes[0].attribute[0].i) and \
Expand Down
15 changes: 15 additions & 0 deletions test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py
Expand Up @@ -770,6 +770,7 @@ def tearDownClass(self):
os.remove("benchmark.yaml")
os.remove("gather.yaml")
os.remove("rename.yaml")
os.remove("rename_model.onnx")
os.remove("rn50_9.onnx")
os.remove(self.mb_v2_export_path)
os.remove(self.rn50_export_path)
Expand Down Expand Up @@ -1007,6 +1008,20 @@ def test_adaptor(self):
q_model = quantizer.fit()
self.assertNotEqual(q_model, None)

conf.model.framework = 'onnxrt_integerops'
conf.quantization.approach = 'post_training_dynamic_quant'
conf.quantization.calibration.sampling_size = 1
conf.evaluation.accuracy.metric = {'MSE': {'compare_label': False}}
quantizer = Quantization(conf)
quantizer.calib_dataloader = self.rename_dataloader
quantizer.eval_dataloader = self.rename_dataloader
onnx.save(self.rename_model, 'rename_model.onnx')
quantizer.model = 'rename_model.onnx'
# force set the model to large model
quantizer.model._is_large_model = True
q_model = quantizer.fit()
self.assertNotEqual(q_model, None)

quantizer = Quantization("dynamic.yaml")
quantizer.calib_dataloader = self.cv_dataloader
quantizer.eval_dataloader = self.cv_dataloader
Expand Down

0 comments on commit 1f6b1ad

Please sign in to comment.