In [None]:
!pip install -qU diffusers transformers accelerate bitsandbytes gguf torchao

# Quantization

**Quantization** techniques focus on representing data with less information while also trying not to lose too much accuracy. This often means converting a data type to represent the same information with fewer bits.
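
As a rough illustration of the idea, the sketch below maps a handful of floats to 8-bit integers with a single absmax scale and then maps them back. The helper names `quantize_absmax` and `dequantize` are hypothetical and are not part of any library used in this notebook; real quantizers are considerably more involved.

```python
def quantize_absmax(values, bits=8):
    """Map floats to signed integers using one absmax scale for the whole list."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8-bit
    absmax = max(abs(v) for v in values)
    scale = absmax / qmax if absmax > 0 else 1.0
    quantized = [round(v / scale) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate floats from the integers and the scale."""
    return [q * scale for q in quantized]


weights = [0.12, -0.53, 0.98, -1.27, 0.004]
q, scale = quantize_absmax(weights)
restored = dequantize(q, scale)

# The round trip is lossy: each value can be off by up to half a quantization step.
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

The integers need only 8 bits each, and the price is a reconstruction error bounded by half of `scale`.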

# bitsandbytes

**bitsandbytes** is the easiest option for quantizing a model to 8-bit or 4-bit.

8-bit quantization (the LLM.int8() algorithm) multiplies outlier values in fp16 and non-outlier values in int8, converts the non-outlier results back to fp16, and then adds the two together to return the output in fp16. This reduces the degradative effect outlier values have on a model's performance.
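
The outlier handling described above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the bitsandbytes kernel: the function names are hypothetical, ordinary Python floats stand in for fp16, and the input is a single hidden-state vector rather than a batch.

```python
def absmax_quantize(vec):
    """Quantize a vector to the int8 range with one absmax scale."""
    absmax = max((abs(v) for v in vec), default=0.0)
    scale = absmax / 127 if absmax > 0 else 1.0
    return [round(v / scale) for v in vec], scale


def int8_mixed_matvec(x, weight, threshold=6.0):
    """Compute x @ weight, keeping outlier input dimensions in full precision."""
    outlier_idx = [i for i, v in enumerate(x) if abs(v) >= threshold]
    regular_idx = [i for i, v in enumerate(x) if abs(v) < threshold]
    n_out = len(weight[0])

    # Outlier dimensions: plain floating-point products (fp16 in the real algorithm).
    out = [sum(x[i] * weight[i][j] for i in outlier_idx) for j in range(n_out)]

    # Regular dimensions: quantize the inputs and each weight column to int8.
    x_q, x_scale = absmax_quantize([x[i] for i in regular_idx])
    for j in range(n_out):
        w_q, w_scale = absmax_quantize([weight[i][j] for i in regular_idx])
        int_acc = sum(a * b for a, b in zip(x_q, w_q))  # integer accumulation
        out[j] += int_acc * x_scale * w_scale           # rescale to float and add
    return out


x = [0.5, -1.2, 40.0, 0.3]  # the third value is an outlier
weight = [[0.2, -0.1], [0.4, 0.3], [-0.5, 0.25], [0.1, -0.6]]
approx = int8_mixed_matvec(x, weight)
exact = [sum(x[i] * weight[i][j] for i in range(len(x))) for j in range(2)]
print(approx, exact)
```

Because the large value never passes through the int8 path, it does not stretch the quantization scale used for the small values.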

4-bit quantization compresses a model even further, and it is commonly used with **QLoRA** to finetune quantized LLMs.

##### 8-bit

Quantizing a model to 8-bit roughly halves its memory usage relative to 16-bit weights.
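
A back-of-the-envelope estimate makes the savings concrete. The sketch below counts only the raw weight storage and ignores quantization constants, unquantized modules, and activations. The parameter count of 12 billion is an assumed, approximate figure for the FLUX.1-dev transformer, and `weight_memory_gib` is a hypothetical helper.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


num_params = 12e9  # assumption: approximate size of the FLUX.1-dev transformer

for bits in (16, 8, 4):
    print(f"{bits:>2}-bit: {weight_memory_gib(num_params, bits):.1f} GiB")
```

Actual footprints will be somewhat higher than these numbers, since some layers stay in higher precision.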

bitsandbytes is supported in both `transformers` and `diffusers`, so we can quantize both the `FluxTransformer2DModel` and `T5EncoderModel`.

For Ada and higher-series GPUs, we can change `torch_dtype` to `torch.bfloat16`.

The `CLIPTextModel` and `AutoencoderKL` are not quantized because they are already small in size and because `AutoencoderKL` only has a few `torch.nn.Linear` layers.

In [None]:
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='text_encoder_2',
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)

transformer_8bit = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. We can change the data type of these modules with the `torch_dtype` parameter:

In [None]:
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=torch.float32, # new here
)

Now we can generate an image using our quantized models. Setting `device_map="auto"` automatically fills all available space on the GPUs first, then the CPU, and finally, the hard drive if there is still not enough memory.

In [None]:
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
pipe_kwargs = {
    'prompt': 'a cat holding a sign that says hello world',
    'height': 1024,
    'width': 1024,
    'guidance_scale': 3.5,
    'num_inference_steps': 50,
    'max_sequence_length': 512,
}

image = pipe(
    **pipe_kwargs,
    generator=torch.manual_seed(111),
).images[0]
image

When there is enough memory, we can also directly move the pipeline to the GPU with `.to('cuda')` and apply `enable_model_cpu_offload()` to optimize GPU memory usage.

Once a model is quantized, we can push the model to the Hub with the `push_to_hub()` method. The quantization `config.json` file is pushed first, followed by the quantized model weights. We can also save the serialized 8-bit models locally with `save_pretrained()`.

We can check our memory footprint with the `get_memory_footprint` method:

In [None]:
print(transformer_8bit.get_memory_footprint())

Quantized models can be loaded with the `from_pretrained()` method without specifying the `quantization_config` parameter:

In [None]:
from diffusers import FluxTransformer2DModel

model_8bit = FluxTransformer2DModel.from_pretrained(
    'hf-internal-testing/flux.1-dev-int8-pkg',
    subfolder='transformer'
)

##### 4-bit

Quantizing a model to 4-bit reduces its memory usage by roughly 4x relative to 16-bit weights:

In [None]:
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='text_encoder_2',
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

As with 8-bit, we can change the data type of all other modules such as `torch.nn.LayerNorm` with the `torch_dtype` parameter:

In [None]:
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=torch.float32, # new here
)

Generate an image using our 4-bit quantized models:

In [None]:
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer_4bit,
    text_encoder_2=text_encoder_2_4bit,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs,
    generator=torch.manual_seed(111),
).images[0]
image

As with 8-bit, we can load quantized models with `from_pretrained()` without specifying the `quantization_config` parameter:

In [None]:
from diffusers import FluxTransformer2DModel

model_4bit = FluxTransformer2DModel.from_pretrained(
    'hf-internal-testing/flux.1-dev-nf4-pkg',
    subfolder='transformer'
)

## 8-bit (LLM.int8() algorithm)

### Outlier threshold

An **"outlier"** is a hidden state value whose magnitude exceeds a certain threshold, and these values are computed in fp16. While the values are usually normally distributed (`[-3.5, 3.5]`), this distribution can be very different for large models (`[-60, -6]` or `[6, 60]`). 8-bit quantization works well for values around 5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).
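
To get a feel for how the threshold changes the share of values routed through the fp16 path, the sketch below counts outliers in synthetic data. Everything here is an assumption for illustration: the data is randomly generated rather than taken from a model, `outlier_fraction` is a hypothetical helper, and whether the real comparison is strict or inclusive is not specified in this notebook.

```python
import random


def outlier_fraction(hidden_states, threshold):
    """Share of values whose magnitude reaches the outlier threshold."""
    outliers = sum(1 for v in hidden_states if abs(v) >= threshold)
    return outliers / len(hidden_states)


random.seed(0)

# Synthetic hidden states: mostly standard normal, plus a few large-magnitude values.
hidden_states = [random.gauss(0.0, 1.0) for _ in range(10_000)]
hidden_states += [random.choice([-1, 1]) * random.uniform(6.0, 60.0) for _ in range(20)]

for threshold in (4.0, 6.0, 10.0):
    print(threshold, outlier_fraction(hidden_states, threshold))
```

Raising the threshold sends fewer values through the fp16 path, which saves compute but pushes more large values into int8.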

To find the best threshold for a model, we can experiment with the `llm_int8_threshold` parameter in `BitsAndBytesConfig`:

In [None]:
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel


quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='text_encoder_2',
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quantization_config = DiffusersBitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=10, # change here
)

transformer_8bit = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='transformer',
    quantization_config=quantization_config,
)

In [None]:
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs,
    generator=torch.manual_seed(111),
).images[0]
image

### Skip module conversion

For some models, quantizing every module to 8-bit is unnecessary and can actually cause instability. For example, for diffusion models like **Stable Diffusion 3**, the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in `BitsAndBytesConfig`:

In [None]:
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=['proj_out']
)

model_8bit = SD3Transformer2DModel.from_pretrained(
    'stabilityai/stable-diffusion-3-medium-diffusers',
    subfolder='transformer',
    quantization_config=quantization_config,
)

## 4-bit (QLoRA algorithm)

### Compute data type

To speed up computation, we can change the compute data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in `BitsAndBytesConfig`:

In [None]:
import torch
from diffusers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

### Normal Float 4 (NF4)

NF4 is a 4-bit data type from the [**QLoRA**](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. We should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter:

In [None]:
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='text_encoder_2',
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4'
)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, we should set `bnb_4bit_compute_dtype` and `torch_dtype` to the same value.
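
To make the NF4 data type more tangible, the sketch below quantizes one block of weights by normalizing with the block's absmax and snapping each value to the nearest of 16 levels. The level values are an assumption: they follow the table given in the QLoRA paper, rounded here to four decimals. The function names are hypothetical, and bitsandbytes itself packs two 4-bit codes per byte in optimized kernels.

```python
# Assumed NF4 levels (QLoRA paper, rounded to 4 decimals). Spacing is denser near zero.
NF4_LEVELS = [
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
]


def quantize_block_nf4(block):
    """Return one 4-bit code per value plus the block's absmax constant."""
    absmax = max(abs(v) for v in block) or 1.0
    codes = [
        min(range(16), key=lambda i: abs(v / absmax - NF4_LEVELS[i]))
        for v in block
    ]
    return codes, absmax


def dequantize_block_nf4(codes, absmax):
    """Look up each code and rescale by the stored absmax."""
    return [NF4_LEVELS[c] * absmax for c in codes]


block = [0.02, -0.31, 0.15, 0.48, -0.07, 0.0]
codes, absmax = quantize_block_nf4(block)
restored = dequantize_block_nf4(codes, absmax)
print(codes, restored)
```

The uneven spacing of the levels is what adapts the format to normally distributed weights, where most values sit close to zero.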

In [None]:
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer_4bit,
    text_encoder_2=text_encoder_2_4bit,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs,
    generator=torch.manual_seed(111),
).images[0]
image

### Nested quantization

Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter:

In [None]:
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True, # new here
)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True, # new here
)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
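
The figure of roughly 0.4 bits per parameter quoted for nested quantization can be reproduced with a little arithmetic. The sketch assumes the block sizes described in the QLoRA paper: one 32-bit constant per 64 weights, re-quantized to 8 bits with one 32-bit constant per 256 first-level constants. The function names are hypothetical.

```python
def constant_overhead_bits(block_size=64, constant_bits=32):
    """Bits per parameter spent on one quantization constant per block."""
    return constant_bits / block_size


def nested_overhead_bits(block_size=64, nested_bits=8,
                         nested_block_size=256, constant_bits=32):
    """Overhead after quantizing the constants themselves."""
    first_level = nested_bits / block_size                        # 8-bit constants
    second_level = constant_bits / (block_size * nested_block_size)
    return first_level + second_level


plain = constant_overhead_bits()
nested = nested_overhead_bits()
saving = plain - nested
print(f"plain: {plain:.3f}, nested: {nested:.3f}, saving: {saving:.3f} bits/parameter")
```

Under these assumptions the saving comes out slightly below 0.4 bits per parameter, which matches the rounded figure in the text.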

In [None]:
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer_4bit,
    text_encoder_2=text_encoder_2_4bit,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs,
    generator=torch.manual_seed(111),
).images[0]
image

## Dequantizing bitsandbytes models

Once quantized, we can dequantize a model to its original precision, but this may result in a small loss of quality.

In [None]:
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)


# Dequantize model
text_encoder_2_4bit.dequantize()
transformer_4bit.dequantize()

In [None]:
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer_4bit,
    text_encoder_2=text_encoder_2_4bit,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs,
    generator=torch.manual_seed(111),
).images[0]
image

# GGUF

The **GGUF** file format is typically used to store models for inference with [GGML](https://github.com/ggml-org/ggml) and supports a variety of block-wise quantization options.

Diffusers supports loading checkpoints that were prequantized and saved in the GGUF format. Since GGUF is a single-file format, we load the model with `from_single_file` on a model class and pass in a `GGUFQuantizationConfig`.

When using GGUF checkpoints, the quantized weights remain in a low memory `dtype` (typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows us to set the `compute_dtype`.
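
The "store compact, dequantize during the forward pass" pattern can be sketched without any GGUF tooling. The example below is only an analogy under stated assumptions: it uses a simple block-wise 8-bit scheme with a block length of 32, the class `BlockQuantizedVector` is hypothetical, and the actual GGUF quantization types (such as `Q2_K`) use more elaborate layouts.

```python
BLOCK_SIZE = 32  # assumed block length for this illustration


def quantize_blocks(values):
    """Split values into blocks and store (scale, int8 codes) per block."""
    blocks = []
    for start in range(0, len(values), BLOCK_SIZE):
        chunk = values[start:start + BLOCK_SIZE]
        absmax = max(abs(v) for v in chunk)
        scale = absmax / 127 if absmax > 0 else 1.0
        blocks.append((scale, [round(v / scale) for v in chunk]))
    return blocks


class BlockQuantizedVector:
    """Keeps only the compact blocks; float weights exist only inside dot()."""

    def __init__(self, values):
        self.blocks = quantize_blocks(values)

    def dequantize(self):
        return [q * scale for scale, codes in self.blocks for q in codes]

    def dot(self, activations):
        weights = self.dequantize()  # temporary, discarded after the call
        return sum(w * a for w, a in zip(weights, activations))


weights = [((i * 37) % 101 - 50) / 100 for i in range(64)]
activations = [1.0] * 64
qv = BlockQuantizedVector(weights)
print(qv.dot(activations), sum(weights))
```

Each block carries its own scale, so a large value in one block does not reduce the precision of the others.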

In [None]:
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

quant_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

In [None]:
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    generator=torch.manual_seed(111)
).images[0]
image

# torchao

[**TorchAO**](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, `FullyShardedDataParallel` (FSDP), and more.

Quantize a model by passing `TorchAoConfig` to `from_pretrained()`. This works for any model in any modality, as long as it supports loading with the Hugging Face Accelerate library and contains `torch.nn.Linear` layers.

In [None]:
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = 'black-forest-labs/FLUX.1-dev'
dtype = torch.bfloat16

quant_config = TorchAoConfig('int8wo')

transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=dtype,
)

pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to('cuda')

# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")

In [None]:
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    num_inference_steps=50,
    guidance_scale=4.5,
    max_sequence_length=512
).images[0]
image

TorchAO is fully compatible with `torch.compile`, setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code:

In [None]:
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = 'black-forest-labs/FLUX.1-dev'
dtype = torch.bfloat16

quant_config = TorchAoConfig('int8wo')

transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=dtype,
)
# apply torch.compile
transformer = torch.compile(
    transformer,
    mode='max-autotune',
    fullgraph=True,
)

pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to('cuda')

In [None]:
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    num_inference_steps=50,
    guidance_scale=4.5,
    max_sequence_length=512
).images[0]
image

TorchAO also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes.

The `TorchAoConfig` class accepts:
* `quant_type`: a string naming one of the quantization types listed below.
* `modules_to_not_convert`: a list of full or partial module names for which quantization should not be performed.
* `kwargs`: a dictionary of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.

## Supported quantization types

TorchAO supports:
* **weight-only quantization**, which stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
* **dynamic-activation quantization**, which stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
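
The difference between the two modes can be sketched with a single dot product. This is a conceptual illustration, not TorchAO's implementation: the function names are hypothetical, the scheme is plain absmax int8, and Python floats stand in for `bfloat16`.

```python
def absmax_quantize(values):
    """Quantize a list of floats to the int8 range with one absmax scale."""
    absmax = max(abs(v) for v in values)
    scale = absmax / 127 if absmax > 0 else 1.0
    return [round(v / scale) for v in values], scale


def weight_only_dot(weight_q, weight_scale, activations):
    """Weights are stored as int8 but dequantized; the math runs in floating point."""
    return sum((q * weight_scale) * a for q, a in zip(weight_q, activations))


def dynamic_activation_dot(weight_q, weight_scale, activations):
    """Activations are quantized at call time; the accumulation runs on integers."""
    act_q, act_scale = absmax_quantize(activations)
    int_acc = sum(w * a for w, a in zip(weight_q, act_q))
    return int_acc * weight_scale * act_scale


weights = [0.8, -0.3, 0.05, 0.6]
activations = [1.5, -2.0, 0.25, 3.0]
weight_q, weight_scale = absmax_quantize(weights)  # done once, ahead of time

exact = sum(w * a for w, a in zip(weights, activations))
wo = weight_only_dot(weight_q, weight_scale, activations)
dq = dynamic_activation_dot(weight_q, weight_scale, activations)
print(exact, wo, dq)
```

Both results approximate the exact value, but the dynamic variant adds a second source of rounding error from the activations, which is the quality tradeoff mentioned above.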

The supported quantization methods:

| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |

Some quantization methods are aliases. For example, the `int8wo` type we used above is a common shorthand for `int8_weight_only`.

## Serializing and deserializing quantized models

To serialize a quantized model in a given dtype, we first load the model with the desired quantization dtype and then save it using the `save_pretrained()` method:

In [None]:
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quant_config = TorchAoConfig('int8wo')

transformer = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
transformer.save_pretrained(
    '/path/to/flux_int8wo',
    safe_serialization=False
)

To load a serialized quantized model, pass its path to `from_pretrained()` with `use_safetensors=False`:

In [None]:
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    '/path/to/flux_int8wo',
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
)

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to('cuda')

In [None]:
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    num_inference_steps=30,
    guidance_scale=7.0
).images[0]
image

Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them.

To work around this, we can load the state dict manually into the model. However, this requires `weights_only=False` in `torch.load`, so it should only be run if the weights come from a trusted source.

In [None]:
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

quant_config = TorchAoConfig('uint4wo')

# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder='transformer',
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
transformer.save_pretrained(
    '/path/to/flux_uint4wo',
    safe_serialization=False,
    max_shard_size="50GB"
)

In [None]:
# Load the model
state_dict = torch.load(
    '/path/to/flux_uint4wo/diffusion_pytorch_model.bin',
    weights_only=False,
    map_location='cpu'
)

with init_empty_weights():
    transformer = FluxTransformer2DModel.from_config(
        '/path/to/flux_uint4wo/config.json'
    )
transformer.load_state_dict(
    state_dict,
    strict=True,
    assign=True,
)

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to('cuda')

In [None]:
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    num_inference_steps=30,
    guidance_scale=7.0
).images[0]
image