Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions neural_compressor/torch/export/pt2e_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,20 @@
# limitations under the License.
"""Export model for quantization."""

from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch.fx.graph_module import GraphModule

from neural_compressor.common.utils import logger
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, TORCH_VERSION_2_7_0, get_torch_version, is_ipex_imported
from neural_compressor.torch.utils import (
TORCH_VERSION_2_2_2,
TORCH_VERSION_2_7_0,
TORCH_VERSION_2_8_0,
get_torch_version,
is_ipex_imported,
)

__all__ = ["export", "export_model_for_pt2e_quant"]

Expand Down Expand Up @@ -52,7 +59,10 @@ def export_model_for_pt2e_quant(
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
# updated to use the official `torch.export` API when that is ready.
cur_version = get_torch_version()
if cur_version >= TORCH_VERSION_2_7_0:
if cur_version >= TORCH_VERSION_2_8_0:
export_func = torch.export.export
export_func = partial(export_func, strict=True)
elif cur_version >= TORCH_VERSION_2_7_0:
export_func = torch.export.export_for_training
else:
export_func = torch._export.capture_pre_autograd_graph
Expand Down
1 change: 1 addition & 0 deletions neural_compressor/torch/utils/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def get_ipex_version():

TORCH_VERSION_2_2_2 = Version("2.2.2")
TORCH_VERSION_2_7_0 = Version("2.7.0")
TORCH_VERSION_2_8_0 = Version("2.8.0")


def get_torch_version():
Expand Down
Loading