diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index ea7af31b3e..8280b1c39c 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -115,6 +115,7 @@ def optimize_for_ort(
     config_name: str | None = None,
     *,
     debug: bool = False,
+    clear_metadata: bool = False,
 ) -> tuple[ir.Model, dict[str, int]]:
     """
     Optimize the model for ORT backend.
@@ -128,6 +129,7 @@ def optimize_for_ort(
             Typically it identifies the Execution Provider (EP) to optimize for.
            If None, the default configuration will be used.
         debug: If debug is True, enable pattern matching tracer for debugging.
+        clear_metadata: If True, clear metadata and doc strings from the model.
 
     Returns:
         A tuple containing:
@@ -145,7 +147,6 @@ def optimize_for_ort(
     passes = ir.passes.Sequential(
         # Apply the ORT optimization passes.
         # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172
-        common_passes.ClearMetadataAndDocStringPass(),
        # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139
         common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1),
         common_passes.RemoveInitializersFromInputsPass(),
@@ -154,4 +155,8 @@ def optimize_for_ort(
     assert passes.in_place
     result = passes(model)
     assert result.model is model
+
+    if clear_metadata:
+        common_passes.ClearMetadataAndDocStringPass()(model)
+
     return model, fusion_count
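
For context (not part of the patch): a minimal usage sketch of the new opt-in flag. The import of optimize_for_ort from _core matches the patched module above; the ir.load call and the model path are assumptions for illustration only.

    # Sketch: metadata/doc strings are now preserved by default and only
    # cleared when the caller opts in via clear_metadata=True.
    from onnxscript import ir  # assumed public IR entry point
    from onnxscript.rewriter.ort_fusions._core import optimize_for_ort

    model = ir.load("model.onnx")  # hypothetical input model

    # Default behavior after this change: metadata is kept.
    model, fusion_counts = optimize_for_ort(model)

    # Opt in to stripping metadata and doc strings after optimization.
    model, fusion_counts = optimize_for_ort(model, clear_metadata=True)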