
Commit f771be1

Flux fp16 inference fix (#9097)
* clipping for fp16
* fix typo
* added fp16 inference to docs
* fix docs typo
* include link for fp16 investigation

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent b6fac9d commit f771be1

2 files changed: +35 -3 lines changed


docs/source/en/api/pipelines/flux.md

Lines changed: 31 additions & 3 deletions
@@ -37,7 +37,7 @@ Both checkpoints have slightly difference usage which we detail below.
 
 ```python
 import torch
-from diffusers import FluxPipeline
+from diffusers import FluxPipeline
 
 pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
 pipe.enable_model_cpu_offload()
@@ -61,7 +61,7 @@ out.save("image.png")
 
 ```python
 import torch
-from diffusers import FluxPipeline
+from diffusers import FluxPipeline
 
 pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
 pipe.enable_model_cpu_offload()
@@ -77,6 +77,34 @@ out = pipe(
 out.save("image.png")
 ```
 
+## Running FP16 inference
+Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
+
+FP16 inference code:
+```python
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) # can replace schnell with dev
+# to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
+pipe.enable_sequential_cpu_offload()
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+
+pipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once
+
+prompt = "A cat holding a sign that says hello world"
+out = pipe(
+    prompt=prompt,
+    guidance_scale=0.,
+    height=768,
+    width=1360,
+    num_inference_steps=4,
+    max_sequence_length=256,
+).images[0]
+out.save("image.png")
+```
+
 ## Single File Loading for the `FluxTransformer2DModel`
 
 The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
@@ -134,4 +162,4 @@ image.save("flux-fp8-dev.png")
 
 [[autodoc]] FluxPipeline
 - all
-- __call__
+- __call__
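The new docs section above notes that forcing the text encoders to run in FP32 removes the FP16 output difference, but the commit itself does not ship code for that. The following is a minimal, hypothetical sketch of the idea, assuming the standard `FluxPipeline` attributes `text_encoder`/`text_encoder_2` and its `encode_prompt` helper:

```python
import torch
from diffusers import FluxPipeline

# Hypothetical sketch (not part of this commit): keep the text encoders in FP32
# while the transformer and VAE run in FP16.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
pipe.text_encoder.to(torch.float32)    # CLIP text encoder in FP32
pipe.text_encoder_2.to(torch.float32)  # T5 text encoder in FP32
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
# Encode with the FP32 text encoders, then cast the embeddings to FP16 so they
# match the FP16 transformer weights.
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt=prompt, prompt_2=prompt, max_sequence_length=256
)
out = pipe(
    prompt_embeds=prompt_embeds.to(torch.float16),
    pooled_prompt_embeds=pooled_prompt_embeds.to(torch.float16),
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
).images[0]
out.save("image_fp32_text_encoders.png")
```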

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 4 additions & 0 deletions
@@ -125,6 +125,8 @@ def forward(
         gate = gate.unsqueeze(1)
         hidden_states = gate * self.proj_out(hidden_states)
         hidden_states = residual + hidden_states
+        if hidden_states.dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
 
         return hidden_states

@@ -223,6 +225,8 @@ def forward(
 
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+        if encoder_hidden_states.dtype == torch.float16:
+            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
 
         return encoder_hidden_states, hidden_states
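For context, 65504 is the largest finite FP16 value (`torch.finfo(torch.float16).max`), so the new `clip` calls bound activations that would otherwise overflow to `inf` and then propagate NaNs through later layers. A standalone illustration of the overflow being guarded against:

```python
import torch

# The largest finite FP16 value is 65504; anything beyond it overflows to inf.
print(torch.finfo(torch.float16).max)  # 65504.0

x = torch.tensor([70000.0, -70000.0, 123.0])
print(x.to(torch.float16))                      # tensor([inf, -inf, 123.], dtype=torch.float16)
print(x.clip(-65504, 65504).to(torch.float16))  # tensor([65504., -65504., 123.], dtype=torch.float16)
```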
