
Commit 5b735b7

Merge branch 'main' into flux-kontext-modular
2 parents de3846b + 4acbfbf

22 files changed: +1055 / -91 lines

.github/workflows/nightly_tests.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -340,6 +340,9 @@ jobs:
           - backend: "optimum_quanto"
             test_location: "quanto"
             additional_deps: []
+          - backend: "nvidia_modelopt"
+            test_location: "modelopt"
+            additional_deps: []
     runs-on:
       group: aws-g6e-xlarge-plus
     container:
```

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -188,6 +188,8 @@
     title: torchao
   - local: quantization/quanto
     title: quanto
+  - local: quantization/modelopt
+    title: NVIDIA ModelOpt
 
 - title: Model accelerators and hardware
   isExpanded: false
```
docs/source/en/quantization/modelopt.md (new file)

Lines changed: 141 additions & 0 deletions

<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# NVIDIA ModelOpt

[NVIDIA-ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a unified library of state-of-the-art model optimization techniques like quantization, pruning, distillation, and speculative decoding. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed.

Before you begin, make sure you have `nvidia_modelopt` installed.

```bash
pip install -U "nvidia_modelopt[hf]"
```

Quantize a model by passing [`NVIDIAModelOptConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

The example below only quantizes the weights to FP8.

```python
import torch
from diffusers import AutoModel, SanaPipeline, NVIDIAModelOptConfig

model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dtype = torch.bfloat16

quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
transformer = AutoModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = SanaPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

> **Note:**
>
> The quantization methods in NVIDIA-ModelOpt are designed to reduce the memory footprint of model weights using various QAT (Quantization-Aware Training) and PTQ (Post-Training Quantization) techniques while maintaining model performance. However, the actual performance gain during inference depends on the deployment framework (e.g., TRT-LLM, TensorRT) and the specific hardware configuration.
>
> More details can be found [here](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples).

## NVIDIAModelOptConfig

The `NVIDIAModelOptConfig` class accepts the following parameters (a short configuration sketch follows the list):

- `quant_type`: A string naming one of the quantization types listed below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantization of the [`SD3Transformer2DModel`]'s pos_embed projection blocks, specify `modules_to_not_convert=["pos_embed.proj.weight"]`.
- `disable_conv_quantization`: A boolean that, when set to `True`, disables quantization for all convolutional layers in the model. This is useful because channel and block quantization (used with INT4, NF4, NVFP4) generally don't work well with convolutional layers. To disable quantization for specific convolutional layers only, use `modules_to_not_convert` instead.
- `algorithm`: The algorithm used to determine scales, defaults to `"max"`. Check the modelopt documentation for more algorithms and details.
- `forward_loop`: The forward loop function used to calibrate activations during quantization. If not provided, static scale values computed from the weights alone are used.
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method, which is invoked based on `quant_type`.
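
For illustration, here is a minimal sketch combining several of these parameters, not a recommended configuration. The FP8 setting and the Sana checkpoint are carried over from the example above, and the `modules_to_not_convert` entry is the sample name from the parameter list, so adapt it to the module names of the model you actually load.

```python
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig

# Sketch: FP8 weight quantization that skips one projection layer and all convolutional layers.
# The module name is the example from the parameter list above; adjust it for your model.
quantization_config = NVIDIAModelOptConfig(
    quant_type="FP8",
    quant_method="modelopt",
    modules_to_not_convert=["pos_embed.proj.weight"],
    disable_conv_quantization=True,
    algorithm="max",  # default scale algorithm
)

transformer = AutoModel.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```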

## Supported quantization types

ModelOpt supports weight-only, per-channel, and block quantization for int8, fp8, int4, nf4, and nvfp4. The quantization methods are designed to reduce the memory footprint of the model weights while maintaining the performance of the model during inference.

Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.

The supported quantization methods are as follows:

| **Quantization Type** | **Supported Schemes** | **Required Kwargs** | **Additional Notes** |
|-----------------------|-----------------------|---------------------|----------------------|
| **INT8** | `int8 weight only`, `int8 channel quantization`, `int8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |
| **FP8** | `fp8 weight only`, `fp8 channel quantization`, `fp8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |
| **INT4** | `int4 weight only`, `int4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1` is the only supported value for now |
| **NF4** | `nf4 weight only`, `nf4 double block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize + scale_channel_quantize + scale_block_quantize` | `channel_quantize = -1` and `scale_channel_quantize = -1` are the only supported values for now |
| **NVFP4** | `nvfp4 weight only`, `nvfp4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1` is the only supported value for now |

Refer to the [official modelopt documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) for a better understanding of the available quantization methods and the exhaustive list of configuration options.
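
As a concrete reading of the table, an INT4 block-quantization config combines `quant_type` with `channel_quantize` and `block_quantize`. The sketch below assumes these kwargs are accepted directly by [`NVIDIAModelOptConfig`] (they may equally be passed through the `kwargs` dict described earlier), and the block size of 128 is an arbitrary illustrative value, not a recommendation.

```python
from diffusers import NVIDIAModelOptConfig

# Sketch of INT4 block quantization per the table above.
# channel_quantize must currently be -1; block_quantize sets the block size (128 is an assumed value).
int4_config = NVIDIAModelOptConfig(
    quant_type="INT4",
    quant_method="modelopt",
    channel_quantize=-1,
    block_quantize=128,
)
```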

## Serializing and Deserializing quantized models

To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.

```python
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig
from modelopt.torch.opt import enable_huggingface_checkpointing

enable_huggingface_checkpointing()

model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"}
quant_config_fp8 = NVIDIAModelOptConfig(**quant_config_fp8)
model = AutoModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quant_config_fp8,
    torch_dtype=torch.bfloat16,
)
model.save_pretrained('path/to/sana_fp8', safe_serialization=False)
```

To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.

```python
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig, SanaPipeline
from modelopt.torch.opt import enable_huggingface_checkpointing

enable_huggingface_checkpointing()

quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
transformer = AutoModel.from_pretrained(
    "path/to/sana_fp8",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

docs/source/en/tutorials/autopipeline.md

Lines changed: 29 additions & 85 deletions

````diff
@@ -12,112 +12,56 @@ specific language governing permissions and limitations under the License.
 
 # AutoPipeline
 
-Diffusers provides many pipelines for basic tasks like generating images, videos, audio, and inpainting. On top of these, there are specialized pipelines for adapters and features like upscaling, super-resolution, and more. Different pipeline classes can even use the same checkpoint because they share the same pretrained model! With so many different pipelines, it can be overwhelming to know which pipeline class to use.
+[AutoPipeline](../api/models/auto_model) is a *task-and-model* pipeline that automatically selects the correct pipeline subclass based on the task. It handles the complexity of loading different pipeline subclasses without needing to know the specific pipeline subclass name.
 
-The [AutoPipeline](../api/pipelines/auto_pipeline) class is designed to simplify the variety of pipelines in Diffusers. It is a generic *task-first* pipeline that lets you focus on a task ([`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]) without needing to know the specific pipeline class. The [AutoPipeline](../api/pipelines/auto_pipeline) automatically detects the correct pipeline class to use.
+This is unlike [`DiffusionPipeline`], a *model-only* pipeline that automatically selects the pipeline subclass based on the model.
 
-For example, let's use the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint.
-
-Under the hood, [AutoPipeline](../api/pipelines/auto_pipeline):
-
-1. Detects a `"stable-diffusion"` class from the [model_index.json](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0/blob/main/model_index.json) file.
-2. Depending on the task you're interested in, it loads the [`StableDiffusionPipeline`], [`StableDiffusionImg2ImgPipeline`], or [`StableDiffusionInpaintPipeline`]. Any parameter (`strength`, `num_inference_steps`, etc.) you would pass to these specific pipelines can also be passed to the [AutoPipeline](../api/pipelines/auto_pipeline).
-
-<hfoptions id="autopipeline">
-<hfoption id="text-to-image">
+[`AutoPipelineForImage2Image`] returns a specific pipeline subclass, (for example, [`StableDiffusionXLImg2ImgPipeline`]), which can only be used for image-to-image tasks.
 
 ```py
-from diffusers import AutoPipelineForText2Image
 import torch
-
-pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
-    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-prompt = "cinematic photo of Godzilla eating sushi with a cat in a izakaya, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(37)
-image = pipe_txt2img(prompt, generator=generator).images[0]
-image
-```
-
-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png"/>
-</div>
-
-</hfoption>
-<hfoption id="image-to-image">
-
-```py
 from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import load_image
-import torch
-
-pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
-    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png")
-
-prompt = "cinematic photo of Godzilla eating burgers with a cat in a fast food restaurant, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(53)
-image = pipe_img2img(prompt, image=init_image, generator=generator).images[0]
-image
-```
-
-Notice how the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint is used for both text-to-image and image-to-image tasks? To save memory and avoid loading the checkpoint twice, use the [`~DiffusionPipeline.from_pipe`] method.
 
-```py
-pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_txt2img).to("cuda")
-image = pipeline(prompt, image=init_image, generator=generator).images[0]
-image
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline)
+"StableDiffusionXLImg2ImgPipeline {
+  "_class_name": "StableDiffusionXLImg2ImgPipeline",
+  ...
+"
 ```
 
-You can learn more about the [`~DiffusionPipeline.from_pipe`] method in the [Reuse a pipeline](../using-diffusers/loading#reuse-a-pipeline) guide.
-
-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png"/>
-</div>
-
-</hfoption>
-<hfoption id="inpainting">
+Loading the same model with [`DiffusionPipeline`] returns the [`StableDiffusionXLPipeline`] subclass. It can be used for text-to-image, image-to-image, or inpainting tasks depending on the inputs.
 
 ```py
-from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
 import torch
+from diffusers import DiffusionPipeline
 
-pipeline = AutoPipelineForInpainting.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-mask.png")
-
-prompt = "cinematic photo of a owl, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(38)
-image = pipeline(prompt, image=init_image, mask_image=mask_image, generator=generator, strength=0.4).images[0]
-image
+pipeline = DiffusionPipeline.from_pretrained(
+    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline)
+"StableDiffusionXLPipeline {
+  "_class_name": "StableDiffusionXLPipeline",
+  ...
+"
 ```
 
-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png"/>
-</div>
+Check the [mappings](https://github.com/huggingface/diffusers/blob/130fd8df54f24ffb006d84787b598d8adc899f23/src/diffusers/pipelines/auto_pipeline.py#L114) to see whether a model is supported or not.
 
-</hfoption>
-</hfoptions>
-
-## Unsupported checkpoints
-
-The [AutoPipeline](../api/pipelines/auto_pipeline) supports [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [Stable Diffusion XL](../api/pipelines/stable_diffusion/stable_diffusion_xl), [ControlNet](../api/pipelines/controlnet), [Kandinsky 2.1](../api/pipelines/kandinsky.md), [Kandinsky 2.2](../api/pipelines/kandinsky_v22), and [DeepFloyd IF](../api/pipelines/deepfloyd_if) checkpoints.
-
-If you try to load an unsupported checkpoint, you'll get an error.
+Trying to load an unsupported model returns an error.
 
 ```py
-from diffusers import AutoPipelineForImage2Image
 import torch
+from diffusers import AutoPipelineForImage2Image
 
 pipeline = AutoPipelineForImage2Image.from_pretrained(
-    "openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True
+    "openai/shap-e-img2img", torch_dtype=torch.float16,
 )
 "ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
 ```
+
+There are three types of [AutoPipeline](../api/models/auto_model) classes, [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`] and [`AutoPipelineForInpainting`]. Each of these classes have a predefined mapping, linking a pipeline to their task-specific subclass.
+
+When [`~AutoPipelineForText2Image.from_pretrained`] is called, it extracts the class name from the `model_index.json` file and selects the appropriate pipeline subclass for the task based on the mapping.
````
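
For reference alongside the rewritten tutorial, the three task classes resolve the same checkpoint to different task-specific subclasses, and [`~DiffusionPipeline.from_pipe`] lets them share already-loaded components. A minimal sketch, assuming the same Juggernaut-XL checkpoint used above and a CUDA device:

```py
import torch
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)

# Each Auto class maps the SDXL checkpoint to its task-specific subclass.
pipe_t2i = AutoPipelineForText2Image.from_pretrained(
    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
)
print(type(pipe_t2i).__name__)  # StableDiffusionXLPipeline

# from_pipe reuses the components that are already loaded instead of reading the checkpoint again.
pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i)
print(type(pipe_i2i).__name__)  # StableDiffusionXLImg2ImgPipeline

pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_t2i)
print(type(pipe_inpaint).__name__)  # StableDiffusionXLInpaintPipeline
```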

setup.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -132,6 +132,7 @@
     "gguf>=0.10.0",
     "torchao>=0.7.0",
     "bitsandbytes>=0.43.3",
+    "nvidia_modelopt[hf]>=0.33.1",
     "regex!=2019.12.17",
     "requests",
     "tensorboard",
@@ -244,6 +245,7 @@ def run(self):
 extras["gguf"] = deps_list("gguf", "accelerate")
 extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
 extras["torchao"] = deps_list("torchao", "accelerate")
+extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
 
 if os.name == "nt":  # windows
     extras["flax"] = []  # jax is not supported on windows
```

src/diffusers/__init__.py

Lines changed: 21 additions & 0 deletions

```diff
@@ -13,6 +13,7 @@
     is_k_diffusion_available,
     is_librosa_available,
     is_note_seq_available,
+    is_nvidia_modelopt_available,
     is_onnx_available,
     is_opencv_available,
     is_optimum_quanto_available,
@@ -111,6 +112,18 @@
 else:
     _import_structure["quantizers.quantization_config"].append("QuantoConfig")
 
+try:
+    if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_nvidia_modelopt_objects
+
+    _import_structure["utils.dummy_nvidia_modelopt_objects"] = [
+        name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")
+
 try:
     if not is_onnx_available():
         raise OptionalDependencyNotAvailable()
@@ -795,6 +808,14 @@
     else:
         from .quantizers.quantization_config import QuantoConfig
 
+    try:
+        if not is_nvidia_modelopt_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_nvidia_modelopt_objects import *
+    else:
+        from .quantizers.quantization_config import NVIDIAModelOptConfig
+
     try:
         if not is_onnx_available():
             raise OptionalDependencyNotAvailable()
```

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -39,6 +39,7 @@
     "gguf": "gguf>=0.10.0",
     "torchao": "torchao>=0.7.0",
     "bitsandbytes": "bitsandbytes>=0.43.3",
+    "nvidia_modelopt[hf]": "nvidia_modelopt[hf]>=0.33.1",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "tensorboard": "tensorboard",
```
