diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 053a7a5aeb..f1122e4362 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -29,7 +29,7 @@ from transformers.modeling_utils import get_parameter_dtype from transformers.utils import is_tf_available, is_torch_available -from ...onnx.utils import _get_onnx_external_data_tensors, check_model_uses_external_data +from ...onnx.utils import _get_onnx_external_constants, _get_onnx_external_data_tensors, check_model_uses_external_data from ...utils import ( DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, @@ -592,6 +592,7 @@ def remap(value): if model_uses_external_data or FORCE_ONNX_EXTERNAL_DATA: tensors_paths = _get_onnx_external_data_tensors(onnx_model) + constant_paths = _get_onnx_external_constants(onnx_model) logger.info("Saving external data to one file...") # try free model memory @@ -618,6 +619,10 @@ def remap(value): for tensor in tensors_paths: os.remove(output.parent / tensor) + for tensor in constant_paths: + if os.path.isfile(output.parent / tensor): + os.remove(output.parent / tensor) + return input_names, output_names diff --git a/optimum/onnx/utils.py b/optimum/onnx/utils.py index 3eca9a8610..b52c4f4cda 100644 --- a/optimum/onnx/utils.py +++ b/optimum/onnx/utils.py @@ -19,6 +19,19 @@ from onnx.external_data_helper import ExternalDataInfo, _get_initializer_tensors +def _get_onnx_external_constants(model: onnx.ModelProto) -> List[str]: + external_constants = [] + + for node in model.graph.node: + if node.op_type == "Constant": + for attribute in node.attribute: + external_datas = attribute.t.external_data + for external_data in external_datas: + external_constants.append(external_data.value) + + return external_constants + + def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]: """ Gets the paths of the external data tensors in the model.