From 68a1757f2997c6a8a73464d49ce89894b0595426 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 19 Sep 2025 02:49:00 -0400 Subject: [PATCH] fix export for pt 2.8 Signed-off-by: yiliu30 --- neural_compressor/torch/export/pt2e_export.py | 14 ++++++++++++-- neural_compressor/torch/utils/environ.py | 1 + 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/export/pt2e_export.py b/neural_compressor/torch/export/pt2e_export.py index 2beb61fb698..f668169de32 100644 --- a/neural_compressor/torch/export/pt2e_export.py +++ b/neural_compressor/torch/export/pt2e_export.py @@ -13,13 +13,20 @@ # limitations under the License. """Export model for quantization.""" +from functools import partial from typing import Any, Dict, Optional, Tuple, Union import torch from torch.fx.graph_module import GraphModule from neural_compressor.common.utils import logger -from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, TORCH_VERSION_2_7_0, get_torch_version, is_ipex_imported +from neural_compressor.torch.utils import ( + TORCH_VERSION_2_2_2, + TORCH_VERSION_2_7_0, + TORCH_VERSION_2_8_0, + get_torch_version, + is_ipex_imported, +) __all__ = ["export", "export_model_for_pt2e_quant"] @@ -52,7 +59,10 @@ def export_model_for_pt2e_quant( # Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be # updated to use the official `torch.export` API when that is ready. cur_version = get_torch_version() - if cur_version >= TORCH_VERSION_2_7_0: + if cur_version >= TORCH_VERSION_2_8_0: + export_func = torch.export.export + export_func = partial(export_func, strict=True) + elif cur_version >= TORCH_VERSION_2_7_0: export_func = torch.export.export_for_training else: export_func = torch._export.capture_pre_autograd_graph diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py index c8b8f596e61..6d8307841e6 100644 --- a/neural_compressor/torch/utils/environ.py +++ b/neural_compressor/torch/utils/environ.py @@ -152,6 +152,7 @@ def get_ipex_version(): TORCH_VERSION_2_2_2 = Version("2.2.2") TORCH_VERSION_2_7_0 = Version("2.7.0") +TORCH_VERSION_2_8_0 = Version("2.8.0") def get_torch_version():