Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_offline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: "3.10"

- name: Install dependencies
run: |
Expand Down
8 changes: 8 additions & 0 deletions docs/source/openvino/export.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ Optional arguments:
--group-size GROUP_SIZE
The group size to use for quantization. Recommended value is 128 and -1 uses per-column
quantization.
--group-size-fallback {error,ignore,adjust}
Specifies how to handle operations that do not support the given group size. Possible values are:
`error`: raise an error if the given group size is not supported by a node, this is the default
behavior;
`ignore`: skip nodes that cannot be compressed with the given group size;
`adjust`: adjust the group size to the maximum supported value for each problematic node, if
there is no valid value greater than or equal to 32, then the node is quantized to the backup
precision which is int8_asym by default.
--backup-precision {none,int8_sym,int8_asym}
Defines a backup precision for mixed-precision weight compression. Only valid for 4-bit weight
formats. If not provided, backup precision is int8_asym. 'none' stands for original floating-
Expand Down
15 changes: 15 additions & 0 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ def parse_args_openvino(parser: "ArgumentParser"):
default=None,
help=("The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization."),
)
optional_group.add_argument(
"--group-size-fallback",
type=str,
choices=["error", "ignore", "adjust"],
default=None,
help=(
"Specifies how to handle operations that do not support the given group size. Possible values are: "
"`error`: raise an error if the given group size is not supported by a node, this is the default behavior; "
"`ignore`: skip nodes that cannot be compressed with the given group size; "
"`adjust`: adjust the group size to the maximum supported value for each problematic node, if there is no "
"valid value greater than or equal to 32, then the node is quantized to the backup precision which is "
"int8_asym by default. "
),
)
optional_group.add_argument(
"--backup-precision",
type=str,
Expand Down Expand Up @@ -595,6 +609,7 @@ def prepare_wc_config(args, default_configs):
"dtype": args.weight_format,
"backup_precision": args.backup_precision,
"statistics_path": args.quantization_statistics_path,
"group_size_fallback": args.group_size_fallback,
}


Expand Down
7 changes: 7 additions & 0 deletions optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..utils.import_utils import (
is_diffusers_available,
is_nncf_available,
is_nncf_version,
is_sentence_transformers_available,
)
from .utils import (
Expand Down Expand Up @@ -47,6 +48,12 @@
logging.disable(logging.INFO)
import nncf

if is_nncf_version("<", "2.19"):
raise ImportError(
"NNCF version 2.19 or higher is required to use NNCF-based quantization. "
f"Please upgrade your NNCF installation. The current version of NNCF is {nncf.__version__}."
)

logging.disable(logging.NOTSET)

# Suppress version mismatch logging
Expand Down
70 changes: 38 additions & 32 deletions optimum/intel/openvino/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from optimum.configuration_utils import BaseConfig

from ..utils.import_utils import is_nncf_available, is_nncf_version
from ..utils.import_utils import is_nncf_available
from .utils import (
PREDEFINED_CAUSAL_LANGUAGE_DATASETS,
PREDEFINED_LANGUAGE_DATASETS,
Expand Down Expand Up @@ -350,32 +350,22 @@ class OVQuantizationMethod(str, Enum):
"sym": False,
"group_size": -1,
},
"inceptionai/jais-13b": {
"bits": 4,
"sym": False,
"group_size": 128,
"ratio": 1.0,
"group_size_fallback": "adjust",
},
"HuggingFaceTB/SmolVLM2-256M-Video-Instruct": {
"bits": 4,
"sym": False,
"group_size": 128,
"ratio": 1.0,
"group_size_fallback": "adjust",
},
}

if is_nncf_available():
# TODO: Remove after update to NNCF 2.19 because `group_size_fallback` argument will be added to OVWeightQuantizationConfig
_DEFAULT_4BIT_WQ_CONFIGS.update(
{
"inceptionai/jais-13b": {
"bits": 4,
"sym": False,
"group_size": 128,
"ratio": 1.0,
"advanced_parameters": nncf.AdvancedCompressionParameters(
group_size_fallback_mode=nncf.GroupSizeFallbackMode.ADJUST,
),
},
"HuggingFaceTB/SmolVLM2-256M-Video-Instruct": {
"bits": 4,
"sym": False,
"group_size": 128,
"ratio": 1.0,
"advanced_parameters": nncf.AdvancedCompressionParameters(
group_size_fallback_mode=nncf.GroupSizeFallbackMode.ADJUST,
),
},
}
)

# Add configs for model id aliases
# The list below contains pairs of model ids: config for the second model id will be copied from the first model id.
Expand Down Expand Up @@ -726,6 +716,13 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
multiple times on the same model and dataset to avoid recomputing statistics.
Please note that the statistics depend on the dataset, so if you change the dataset, you should also change
the statistics path to avoid confusion.
group_size_fallback (`str`, *optional*):
Defines the behavior when the specified group size is not compatible with the weight shape. Possible values:
- "error": raises an error if the group size is not compatible with the weight shape (default);
- "ignore": skips quantization for the layers where the group size is not compatible with the weight shape;
- "adjust": automatically adjusts the group size to the maximum compatible value for each weight tensor,
if there is no valid value greater than or equal to 32, then the node is quantized to the backup precision
which is int8_asym by default.
kwargs: Additional parameters for nncf.compress_weights() call.
"""

Expand All @@ -750,6 +747,7 @@ def __init__(
lora_correction: bool = None,
backup_precision: Optional[str] = None,
statistics_path: Optional[str] = None,
group_size_fallback: Optional[str] = None,
**kwargs,
):
weight_format = kwargs.pop("weight_format", None)
Expand Down Expand Up @@ -781,6 +779,7 @@ def __init__(
self.backup_precision = backup_precision
self.dtype = dtype
self.statistics_path = statistics_path
self.group_size_fallback = group_size_fallback
self.post_init()

def post_init(self):
Expand Down Expand Up @@ -830,9 +829,6 @@ def post_init(self):
"quantization algorithm is selected and compression ratio is 1.0."
)

if self.dataset is None and self.quant_method == OVQuantizationMethod.AWQ and is_nncf_version("<", "2.17.0"):
raise ValueError("Data-free AWQ is available starting form NNCF 2.17. Please update nncf package.")

if self.dtype in ["int4", "int8"]:
bits = 4 if self.dtype == "int4" else 8
if self.bits is not None and self.bits != bits:
Expand Down Expand Up @@ -914,6 +910,13 @@ def post_init(self):
if self.gptq and self.lora_correction:
raise ValueError("The GPTQ and LoRA Correction algorithms can't be applied simultaneously")

valid_group_size_fallback_values = [e.value for e in nncf.GroupSizeFallbackMode]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if we should add a check somewhere (e.g. inside the `if is_nncf_available():` block) to ensure that the installed NNCF version is compatible with our integration (in our case >= v2.19, I imagine?).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point! Added

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

if self.group_size_fallback not in valid_group_size_fallback_values + [None]:
raise ValueError(
f"`group_size_fallback` must be one of the following: {valid_group_size_fallback_values}, "
f"but found: {self.group_size_fallback}"
)

def to_nncf_dict(self) -> Dict[str, Any]:
"""
Returns a dictionary with the variables that are ready to use for nncf.compress_weights() call.
Expand All @@ -923,8 +926,6 @@ def to_nncf_dict(self) -> Dict[str, Any]:
mode = self.dtype if self.dtype else signed_bitness[self.bits]
if mode in signed_bitness.values():
mode += "_sym" if self.sym else "_asym"
if mode == "mxfp4":
mode = "e2m1" if is_nncf_version("<=", "2.18") else "mxfp4"
if mode == "cb4":
mode = "cb4_f8e4m3"
mode = nncf.CompressWeightsMode(mode)
Expand All @@ -933,9 +934,14 @@ def to_nncf_dict(self) -> Dict[str, Any]:
sensitivity_metric = nncf.SensitivityMetric(self.sensitivity_metric) if self.sensitivity_metric else None
backup_mode = nncf.BackupMode(self.backup_precision) if self.backup_precision else None
kwargs = self.kwargs.copy()
if self.statistics_path:
if self.statistics_path or self.group_size_fallback:
advanced_parameters = kwargs.get("advanced_parameters", nncf.AdvancedCompressionParameters())
advanced_parameters = dataclasses.replace(advanced_parameters, statistics_path=self.statistics_path)
if self.statistics_path:
advanced_parameters = dataclasses.replace(advanced_parameters, statistics_path=self.statistics_path)
if self.group_size_fallback:
advanced_parameters = dataclasses.replace(
advanced_parameters, group_size_fallback_mode=nncf.GroupSizeFallbackMode(self.group_size_fallback)
)
kwargs["advanced_parameters"] = advanced_parameters
result = {
"mode": mode,
Expand Down
8 changes: 0 additions & 8 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@
from transformers.utils import is_accelerate_available

from optimum.quantization_base import OptimumQuantizer
from optimum.utils.logging import warn_once

from ..utils.import_utils import (
DATASETS_IMPORT_ERROR,
_nncf_version,
is_datasets_available,
is_diffusers_available,
is_nncf_version,
is_sentence_transformers_available,
)
from .configuration import (
Expand Down Expand Up @@ -771,12 +769,6 @@ def _prepare_visual_causal_lm_calibration_data(
and input_dict["pixel_values"].dim() == 4
and input_dict["pixel_values"].shape[0] > 1
):
if is_nncf_version("<=", "2.18"):
# TODO (Nikita): Remove once NNCF 2.19 is released.
warn_once(
logger,
"If you are facing RAM OOM issues, please update to the latest NNCF develop version.",
)
batch_size = input_dict["pixel_values"].shape[0]
for i in range(batch_size):
single_batch_input_dict = {}
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@
QUALITY_REQUIRE = ["black~=23.1", "ruff==0.4.4"]

EXTRAS_REQUIRE = {
"nncf": ["nncf>=2.18.0"],
"openvino": ["nncf>=2.18.0", "openvino>=2025.1.0", "openvino-tokenizers>=2025.1.0"],
"nncf": ["nncf>=2.19.0"],
"openvino": ["nncf>=2.19.0", "openvino>=2025.1.0", "openvino-tokenizers>=2025.1.0"],
"neural-compressor": ["neural-compressor[pt]>=3.4.1", "accelerate", "transformers<4.46", "datasets"],
"ipex": ["intel-extension-for-pytorch>=2.8", "transformers>4.54,<4.56", "accelerate"],
"diffusers": ["diffusers"],
Expand Down
11 changes: 8 additions & 3 deletions tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS, TemporaryDirectory
from optimum.intel.utils.import_utils import (
compare_versions,
is_nncf_version,
is_openvino_tokenizers_available,
is_openvino_version,
is_tokenizers_version,
Expand Down Expand Up @@ -466,11 +465,11 @@ class OVCLIExportTestCase(unittest.TestCase):
"--dataset coco --num-samples 1",
{
"vision_encoder": 75,
"prompt_encoder_mask_decoder": 61 if is_nncf_version("<=", "2.18") else 60,
"prompt_encoder_mask_decoder": 60,
},
{
"vision_encoder": {"int8": 75},
"prompt_encoder_mask_decoder": {"int8": 50 if is_nncf_version("<=", "2.18") else 49},
"prompt_encoder_mask_decoder": {"int8": 49},
},
),
(
Expand Down Expand Up @@ -528,6 +527,12 @@ class OVCLIExportTestCase(unittest.TestCase):
"int4 --ratio 1.0 --sym --group-size 8 --all-layers",
{"model": {"int4": 16}},
),
(
"text-generation-with-past",
"gpt2",
"int4 --sym --group-size-fallback adjust",
{"model": {"int8": 4, "int4": 20}},
),
(
"text-generation-with-past",
"llama_awq",
Expand Down
40 changes: 33 additions & 7 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,16 @@
import unittest
from collections import defaultdict
from collections.abc import Iterable
from enum import Enum
from functools import partial
from typing import Union, Type

import openvino as ov
import pytest
import numpy as np
import torch
from PIL import Image
from parameterized import parameterized
import nncf
from transformers import (
AutoModelForQuestionAnswering,
AutoTokenizer,
AutoProcessor,
AutoConfig,
Expand Down Expand Up @@ -85,7 +82,7 @@
from copy import deepcopy

from optimum.intel.openvino.quantization import InferRequestWrapper, OVCalibrationDatasetBuilder
from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version, is_nncf_version
from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
from utils_tests import (
MODEL_NAMES,
get_num_quantized_nodes,
Expand Down Expand Up @@ -362,11 +359,11 @@ class OVQuantizerTest(unittest.TestCase):
OVQuantizationConfig(bits=8, dataset="coco", num_samples=1),
{
"vision_encoder": 75,
"prompt_encoder_mask_decoder": 61 if is_nncf_version("<=", "2.18") else 60,
"prompt_encoder_mask_decoder": 60,
},
{
"vision_encoder": {"int8": 75},
"prompt_encoder_mask_decoder": {"int8": 50 if is_nncf_version("<=", "2.18") else 49},
"prompt_encoder_mask_decoder": {"int8": 49},
},
),
(
Expand Down Expand Up @@ -735,7 +732,7 @@ class OVWeightCompressionTest(unittest.TestCase):
dict(bits=4, dataset="coco", num_samples=1, group_size=2),
{
"vision_encoder": {"int8": 56, "int4": 94},
"prompt_encoder_mask_decoder": {"int8": 6, "int4": 94 if is_nncf_version("<=", "2.18") else 92},
"prompt_encoder_mask_decoder": {"int8": 6, "int4": 92},
},
),
(
Expand Down Expand Up @@ -938,6 +935,35 @@ class OVWeightCompressionTest(unittest.TestCase):
dict(bits=4, sym=False, group_size=32, ratio=1.0),
{"model": {"int8": 2, "int4": 14}},
),
(
OVModelForCausalLM,
"gpt2",
False,
dict(bits=4, sym=True, group_size_fallback="adjust"),
{"model": {"int8": 4, "int4": 20}},
),
(
OVModelForCausalLM,
"llama",
False,
dict(
bits=4,
sym=True,
group_size_fallback="adjust",
),
{"model": {"int8": 28, "int4": 2}},
),
(
OVModelForCausalLM,
"llama",
False,
dict(
bits=4,
sym=True,
group_size_fallback="ignore",
),
{"model": {"int8": 4}},
),
]

# filter models type depending on min max transformers version
Expand Down
6 changes: 3 additions & 3 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch

from optimum.exporters.tasks import TasksManager
from optimum.intel.utils.import_utils import is_nncf_version, is_openvino_version, is_transformers_version
from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version


SEED = 42
Expand Down Expand Up @@ -319,7 +319,7 @@
"transformer": 58,
"vae_decoder": 28,
"vae_encoder": 28,
"text_encoder": 16 if is_nncf_version(">", "2.17") else 18,
"text_encoder": 16,
},
"ltx-video": {
"transformer": 34,
Expand All @@ -329,7 +329,7 @@
},
"sam": {
"vision_encoder": 102 if is_openvino_version("<", "2025.2.0") else 150,
"prompt_encoder_mask_decoder": 100 if is_nncf_version("<=", "2.18") else 98,
"prompt_encoder_mask_decoder": 98,
},
"speecht5": {
"encoder": 28,
Expand Down