From a93c4298c9a3fa153751349ad26e7e0ba9678d2b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 3 Sep 2025 12:38:57 +0530 Subject: [PATCH 1/6] feat: support aobaseconfig classes. --- .../quantizers/quantization_config.py | 152 +++++++++++++++--- .../quantizers/torchao/torchao_quantizer.py | 92 ++++++++--- tests/quantization/torchao/test_torchao.py | 23 +++ 3 files changed, 228 insertions(+), 39 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index bf857956512c..250441023281 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -21,12 +21,13 @@ """ import copy +import dataclasses import importlib.metadata import inspect import json import os import warnings -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from enum import Enum from functools import partial from typing import Any, Callable, Dict, List, Optional, Union @@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin): """This is a config class for torchao quantization/sparsity techniques. Args: - quant_type (`str`): + quant_type (Union[`str`, AOBaseConfig]): The type of quantization we want to use, currently supporting: - **Integer quantization:** - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, @@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin): - **Unsigned Integer quantization:** - Full function names: `uintx_weight_only` - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` + - An AOBaseConfig instance: for more advanced configuration options. modules_to_not_convert (`List[str]`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision. 
@@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin): ```python from diffusers import FluxTransformer2DModel, TorchAoConfig + # AOBaseConfig-based configuration + from torchao.quantization import Int8WeightOnlyConfig + + quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) + + # String-based config quantization_config = TorchAoConfig("int8wo") transformer = FluxTransformer2DModel.from_pretrained( "black-forest-labs/Flux.1-Dev", @@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin): def __init__( self, - quant_type: str, + quant_type: Union[str, "AOBaseConfig"], # noqa: F821 modules_to_not_convert: Optional[List[str]] = None, **kwargs, ) -> None: @@ -504,8 +512,13 @@ def __init__( else: self.quant_type_kwargs = kwargs + self.post_init() + + def post_init(self): TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() - if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): + AO_VERSION = self._get_ao_version() + + if isinstance(self.quant_type, str) and self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp") if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): raise ValueError( @@ -517,22 +530,95 @@ def __init__( f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the " f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." 
) + elif AO_VERSION > version.parse("0.9.0"): + from torchao.quantization.quant_api import AOBaseConfig - method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type] - signature = inspect.signature(method) - all_kwargs = { - param.name - for param in signature.parameters.values() - if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD] - } - unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs) - - if len(unsupported_kwargs) > 0: + if not isinstance(self.quant_type, AOBaseConfig): + raise TypeError( + f"`quant_type` must be either a string or an `AOBaseConfig` instance, got {type(self.quant_type)}." + ) + else: raise ValueError( - f'The quantization method "{quant_type}" does not support the following keyword arguments: ' - f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." + f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. " + f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances." ) + if isinstance(self.quant_type, str): + method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type] + signature = inspect.signature(method) + all_kwargs = { + param.name + for param in signature.parameters.values() + if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD] + } + unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs) + + if len(unsupported_kwargs) > 0: + raise ValueError( + f'The quantization method "{self.quant_type}" does not support the following keyword arguments: ' + f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." 
+ ) + + def to_dict(self): + """Convert configuration to a dictionary.""" + d = super().to_dict() + + if isinstance(self.quant_type, str): + # Handle layout serialization if present + if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]: + if is_dataclass(d["quant_type_kwargs"]["layout"]): + d["quant_type_kwargs"]["layout"] = [ + d["quant_type_kwargs"]["layout"].__class__.__name__, + dataclasses.asdict(d["quant_type_kwargs"]["layout"]), + ] + if isinstance(d["quant_type_kwargs"]["layout"], list): + assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs" + assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string" + assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict" + else: + raise ValueError("layout must be a list") + else: + # Handle AOBaseConfig serialization + from torchao.core.config import config_to_dict + + # For now we assume there is 1 config per Transformer, however in the future + # We may want to support a config per fqn. 
+ d["quant_type"] = {"default": config_to_dict(self.quant_type)} + + return d + + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): + """Create configuration from a dictionary.""" + ao_version = cls._get_ao_version() + assert ao_version > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict" + config_dict = config_dict.copy() + quant_type = config_dict.pop("quant_type") + + if isinstance(quant_type, str): + return cls(quant_type=quant_type, **config_dict) + # Check if we only have one key which is "default" + # In the future we may update this + assert len(quant_type) == 1 and "default" in quant_type, ( + "Expected only one key 'default' in quant_type dictionary" + ) + quant_type = quant_type["default"] + + # Deserialize quant_type if needed + from torchao.core.config import config_from_dict + + quant_type = config_from_dict(quant_type) + + return cls(quant_type=quant_type, **config_dict) + + @staticmethod + def _get_ao_version() -> version.Version: + """Centralized check for TorchAO availability and version requirements.""" + if not is_torchao_available(): + raise ValueError("TorchAoConfig requires torchao to be installed. 
Install with `pip install torchao`") + + return version.parse(importlib.metadata.version("torchao")) + @classmethod def _get_torchao_quant_type_to_method(cls): r""" @@ -681,8 +767,38 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool: raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.") def get_apply_tensor_subclass(self): - TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() - return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs) + """Create the appropriate quantization method based on configuration.""" + if isinstance(self.quant_type, str): + methods = self._get_torchao_quant_type_to_method() + quant_type_kwargs = self.quant_type_kwargs.copy() + if ( + not torch.cuda.is_available() + and is_torchao_available() + and self.quant_type == "int4_weight_only" + and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0") + and quant_type_kwargs.get("layout", None) is None + ): + if torch.xpu.is_available(): + if version.parse(importlib.metadata.version("torchao")) >= version.parse( + "0.11.0" + ) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"): + from torchao.dtypes import Int4XPULayout + from torchao.quantization.quant_primitives import ZeroPointDomain + + quant_type_kwargs["layout"] = Int4XPULayout() + quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT + else: + raise ValueError( + "TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch." 
+ ) + else: + from torchao.dtypes import Int4CPULayout + + quant_type_kwargs["layout"] = Int4CPULayout() + + return methods[self.quant_type](**quant_type_kwargs) + else: + return self.quant_type def __repr__(self): r""" diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 976bc8a1e0e5..ecb65ffea32c 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -18,9 +18,10 @@ """ import importlib +import re import types from fnmatch import fnmatch -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from packaging import version @@ -107,6 +108,21 @@ def _update_torch_safe_globals(): _update_torch_safe_globals() +def fuzzy_match_size(config_name: str) -> Optional[str]: + """ + Extract the size digit from strings like "4weight", "8weight". Returns the digit as a string if found, otherwise + None.
+ """ + config_name = config_name.lower() + + str_match = re.search(r"(\d)weight", config_name) + + if str_match: + return str_match.group(1) + + return None + + logger = logging.get_logger(__name__) @@ -176,8 +192,7 @@ def validate_environment(self, *args, **kwargs): def update_torch_dtype(self, torch_dtype): quant_type = self.quantization_config.quant_type - - if quant_type.startswith("int") or quant_type.startswith("uint"): + if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")): if torch_dtype is not None and torch_dtype != torch.bfloat16: logger.warning( f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but " @@ -197,24 +212,44 @@ def update_torch_dtype(self, torch_dtype): def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": quant_type = self.quantization_config.quant_type - - if quant_type.startswith("int8") or quant_type.startswith("int4"): - # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8 - return torch.int8 - elif quant_type == "uintx_weight_only": - return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8) - elif quant_type.startswith("uint"): - return { - 1: torch.uint1, - 2: torch.uint2, - 3: torch.uint3, - 4: torch.uint4, - 5: torch.uint5, - 6: torch.uint6, - 7: torch.uint7, - }[int(quant_type[4])] - elif quant_type.startswith("float") or quant_type.startswith("fp"): - return torch.bfloat16 + from accelerate.utils import CustomDtype + + if isinstance(quant_type, str): + if quant_type.startswith("int8"): + # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8 + return torch.int8 + elif quant_type.startswith("int4"): + return CustomDtype.INT4 + elif quant_type == "uintx_weight_only": + return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8) + elif quant_type.startswith("uint"): + 
return { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + }[int(quant_type[4])] + elif quant_type.startswith("float") or quant_type.startswith("fp"): + return torch.bfloat16 + + elif self.quantization_config._get_ao_version() > version.Version("0.9.0"): + from torchao.core.config import AOBaseConfig + + quant_type = self.quantization_config.quant_type + if isinstance(quant_type, AOBaseConfig): + # Extract size digit using fuzzy match on the class name + config_name = quant_type.__class__.__name__ + size_digit = fuzzy_match_size(config_name) + + # Map the extracted digit to appropriate dtype + if size_digit == "4": + return CustomDtype.INT4 + else: + # Default to int8 + return torch.int8 if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION): return target_dtype @@ -297,6 +332,21 @@ def get_cuda_warm_up_factor(self): # Original mapping for non-AOBaseConfig types # For the uint types, this is a best guess. Once these types become more used # we can look into their nuances. 
+ if self.quantization_config._get_ao_version() > version.Version("0.9.0"): + from torchao.core.config import AOBaseConfig + + quant_type = self.quantization_config.quant_type + # For autoquant case, it will be treated in the string implementation below in map_to_target_dtype + if isinstance(quant_type, AOBaseConfig): + # Extract size digit using fuzzy match on the class name + config_name = quant_type.__class__.__name__ + size_digit = fuzzy_match_size(config_name) + + if size_digit == "4": + return 8 + else: + return 4 + map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4} quant_type = self.quantization_config.quant_type for pattern, target_dtype in map_to_target_dtype.items(): diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 920c3a55f56c..14f9017a7c75 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -14,11 +14,13 @@ # limitations under the License. 
import gc +import importlib.metadata import tempfile import unittest from typing import List import numpy as np +from packaging import version from parameterized import parameterized from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel @@ -65,6 +67,9 @@ from torchao.quantization.quant_primitives import MappingType from torchao.utils import get_model_size_in_bytes + if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"): + from torchao.quantization import Int8WeightOnlyConfig + @require_torch @require_torch_accelerator @@ -522,6 +527,15 @@ def test_sequential_cpu_offload(self): inputs = self.get_dummy_inputs(torch_device) _ = pipe(**inputs) + @require_torchao_version_greater_or_equal("0.9.0") + def test_aobase_config(self): + quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + _ = pipe(**inputs) + # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @@ -576,6 +590,7 @@ def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, inputs = self.get_dummy_tensor_inputs(torch_device) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + print(f"{output_slice=}") weight = quantized_model.transformer_blocks[0].ff.net[2].weight self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) @@ -628,6 +643,14 @@ def test_int_a16w8_cpu(self): self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + @require_torchao_version_greater_or_equal("0.9.0") + def test_aobase_config(self): + 
quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + device = torch_device + self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) + self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + @require_torchao_version_greater_or_equal("0.7.0") class TorchAoCompileTest(QuantCompileTests, unittest.TestCase): From 44ed55efe4fb585426ecb2ef2d5cb27af8910756 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Mon, 8 Sep 2025 19:24:09 -0700 Subject: [PATCH 2/6] [docs] AOBaseConfig (#12302) init Co-authored-by: Sayak Paul --- docs/source/en/quantization/torchao.md | 105 ++++++++++++++++--------- 1 file changed, 66 insertions(+), 39 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 5c7578dcbb4e..18cc109e0785 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -11,69 +11,96 @@ specific language governing permissions and limitations under the License. --> # torchao -[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more. +[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. 
-Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed. +Make sure PyTorch 2.5+ and torchao are installed with the command below. ```bash -pip install -U torch torchao +uv pip install -U torch torchao ``` +Each quantization dtype is available as a separate instance of an [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. This provides more flexible configuration options by exposing more available arguments. -Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. +Pass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig), to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`]. -The example below only quantizes the weights to int8.
- -```python +```py import torch -from diffusers import FluxPipeline, AutoModel, TorchAoConfig - -model_id = "black-forest-labs/FLUX.1-dev" -dtype = torch.bfloat16 +from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig +from torchao.quantization import Int8WeightOnlyConfig -quantization_config = TorchAoConfig("int8wo") -transformer = AutoModel.from_pretrained( - model_id, - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=dtype, +pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128))} ) -pipe = FluxPipeline.from_pretrained( - model_id, - transformer=transformer, - torch_dtype=dtype, +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantization_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, + device_map="cuda" ) -pipe.to("cuda") +``` -# Without quantization: ~31.447 GB -# With quantization: ~20.40 GB -print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB") +For simple use cases, you could also provide a string identifier in [`TorchAoConfig`] as shown below. -prompt = "A cat holding a sign that says hello world" -image = pipe( - prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 -).images[0] -image.save("output.png") +```py +import torch +from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig + +pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={"transformer": TorchAoConfig("int8wo")} +) +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantization_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, + device_map="cuda" +) ``` -TorchAO is fully compatible with [torch.compile](../optimization/fp16#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
+## torch.compile + +torchao supports [torch.compile](../optimization/fp16#torchcompile), which can speed up inference with one line of code. ```python -# In the above code, add the following after initializing the transformer -transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True) +import torch +from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig +from torchao.quantization import Int4WeightOnlyConfig + +pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))} +) +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantization_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, + device_map="cuda" +) + +pipeline.transformer.compile(mode="max-autotune", fullgraph=True) ``` -For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware. +Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks). > [!TIP] > The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible. -torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization).
Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future. +## autoquant + +torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment. + +```py +import torch +from diffusers import DiffusionPipeline +from torchao.quantization import autoquant + +# Load the pipeline +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16, + device_map="cuda" +) -The `TorchAoConfig` class accepts three parameters: -- `quant_type`: A string value mentioning one of the quantization types below. -- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`. -- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
+transformer = autoquant(pipeline.transformer) +``` ## Supported quantization types From 5524a9da94e110adc465fc32316f12c9ebfa02b6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Sep 2025 12:11:22 +0530 Subject: [PATCH 3/6] up --- .../quantizers/quantization_config.py | 46 +++++++++---------- tests/quantization/torchao/test_torchao.py | 1 - 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 250441023281..d959f2beb72a 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -518,32 +518,20 @@ def post_init(self): TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() AO_VERSION = self._get_ao_version() - if isinstance(self.quant_type, str) and self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): - is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp") - if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): + if isinstance(self.quant_type, str): + if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): + is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp") + if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): + raise ValueError( + f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You " + f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`." + ) + raise ValueError( - f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You " - f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`." + f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. 
If you think the " + f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." ) - raise ValueError( - f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the " - f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." - ) - elif AO_VERSION > version.parse("0.9.0"): - from torchao.quantization.quant_api import AOBaseConfig - - if not isinstance(self.quant_type, AOBaseConfig): - raise TypeError( - f"`quant_type` must be either a string or an `AOBaseConfig` instance, got {type(self.quant_type)}." - ) - else: - raise ValueError( - f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. " - f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances." - ) - - if isinstance(self.quant_type, str): method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type] signature = inspect.signature(method) all_kwargs = { @@ -558,6 +546,18 @@ def post_init(self): f'The quantization method "{self.quant_type}" does not support the following keyword arguments: ' f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." ) + elif AO_VERSION > version.parse("0.9.0"): + from torchao.quantization.quant_api import AOBaseConfig + + if not isinstance(self.quant_type, AOBaseConfig): + raise TypeError( + f"`quant_type` must be either a string or an `AOBaseConfig` instance, got {type(self.quant_type)}." + ) + else: + raise ValueError( + f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. " + f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances." 
+ ) def to_dict(self): """Convert configuration to a dictionary.""" diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 14f9017a7c75..38997de17b12 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -590,7 +590,6 @@ def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, inputs = self.get_dummy_tensor_inputs(torch_device) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - print(f"{output_slice=}") weight = quantized_model.transformer_blocks[0].ff.net[2].weight self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) From a8bcb0324599adf8cf0d90dc0a2e17f201f2675a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Sep 2025 17:18:12 +0530 Subject: [PATCH 4/6] replace with is_torchao_version --- src/diffusers/quantizers/quantization_config.py | 17 ++++------------- .../quantizers/torchao/torchao_quantizer.py | 4 ++-- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index d959f2beb72a..850f215faa76 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -34,7 +34,7 @@ from packaging import version -from ..utils import is_torch_available, is_torchao_available, logging +from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging if is_torch_available(): @@ -516,7 +516,6 @@ def __init__( def post_init(self): TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() - AO_VERSION = self._get_ao_version() if isinstance(self.quant_type, str): if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): @@ -546,7 +545,7 @@ def post_init(self): f'The 
quantization method "{self.quant_type}" does not support the following keyword arguments: ' f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." ) - elif AO_VERSION > version.parse("0.9.0"): + elif is_torchao_version(">", "0.9.0"): from torchao.quantization.quant_api import AOBaseConfig if not isinstance(self.quant_type, AOBaseConfig): @@ -590,8 +589,8 @@ def to_dict(self): @classmethod def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): """Create configuration from a dictionary.""" - ao_version = cls._get_ao_version() - assert ao_version > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict" + if not is_torchao_version(">", "0.9.0"): + raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict") config_dict = config_dict.copy() quant_type = config_dict.pop("quant_type") @@ -611,14 +610,6 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): return cls(quant_type=quant_type, **config_dict) - @staticmethod - def _get_ao_version() -> version.Version: - """Centralized check for TorchAO availability and version requirements.""" - if not is_torchao_available(): - raise ValueError("TorchAoConfig requires torchao to be installed. 
Install with `pip install torchao`") - - return version.parse(importlib.metadata.version("torchao")) - @classmethod def _get_torchao_quant_type_to_method(cls): r""" diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index ecb65ffea32c..2334c7af8630 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -235,7 +235,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": elif quant_type.startswith("float") or quant_type.startswith("fp"): return torch.bfloat16 - elif self.quantization_config._get_ao_version() > version.Version("0.9.0"): + elif is_torchao_version(">", "0.9.0"): from torchao.core.config import AOBaseConfig quant_type = self.quantization_config.quant_type @@ -332,7 +332,7 @@ def get_cuda_warm_up_factor(self): # Original mapping for non-AOBaseConfig types # For the uint types, this is a best guess. Once these types become more used # we can look into their nuances. 
- if self.quantization_config._get_ao_version() > version.Version("0.9.0"): + if is_torchao_version(">", "0.9.0"): from torchao.core.config import AOBaseConfig quant_type = self.quantization_config.quant_type From 32c09b3d40b431a1a2bbc9a5a863484334a797bf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Sep 2025 17:29:20 +0530 Subject: [PATCH 5/6] up --- src/diffusers/quantizers/quantization_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 850f215faa76..d8d27d165827 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -759,7 +759,9 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool: def get_apply_tensor_subclass(self): """Create the appropriate quantization method based on configuration.""" - if isinstance(self.quant_type, str): + if not isinstance(self.quant_type, str): + return self.quant_type + else: methods = self._get_torchao_quant_type_to_method() quant_type_kwargs = self.quant_type_kwargs.copy() if ( @@ -788,8 +790,6 @@ def get_apply_tensor_subclass(self): quant_type_kwargs["layout"] = Int4CPULayout() return methods[self.quant_type](**quant_type_kwargs) - else: - return self.quant_type def __repr__(self): r""" From ee62edffb1d58874b59efa46ba7356eeba05aa65 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Sep 2025 15:26:12 +0530 Subject: [PATCH 6/6] up --- .../quantizers/quantization_config.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index d8d27d165827..5dd8f56717df 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -515,9 +515,21 @@ def __init__( self.post_init() def post_init(self): - TORCHAO_QUANT_TYPE_METHODS = 
self._get_torchao_quant_type_to_method() + if not isinstance(self.quant_type, str): + if is_torchao_version("<=", "0.9.0"): + raise ValueError( + f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. " + f"Upgrade to torchao > 0.9.0 to use AOBaseConfig." + ) + + from torchao.quantization.quant_api import AOBaseConfig + + if not isinstance(self.quant_type, AOBaseConfig): + raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}") + + elif isinstance(self.quant_type, str): + TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() - if isinstance(self.quant_type, str): if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp") if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): @@ -545,18 +557,6 @@ def post_init(self): f'The quantization method "{self.quant_type}" does not support the following keyword arguments: ' f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." ) - elif is_torchao_version(">", "0.9.0"): - from torchao.quantization.quant_api import AOBaseConfig - - if not isinstance(self.quant_type, AOBaseConfig): - raise TypeError( - f"`quant_type` must be either a string or an `AOBaseConfig` instance, got {type(self.quant_type)}." - ) - else: - raise ValueError( - f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. " - f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances." - ) def to_dict(self): """Convert configuration to a dictionary."""