Fix Flux Context Parallel Bug (Incoherent Image Generation) #12443
Conversation
Thanks for the PR. From the looks of it, it does seem like it is fully LLM-generated. Also, FWIW, we strive to keep our modeling implementations simple, so I am not sure yet if the changes align with that philosophy. @DN6 WDYT?
Hi @mali-afridi, the issue seems to be because an unsupported backend is being used with CP. This snippet should work:

```python
import torch
from diffusers import FluxPipeline
from diffusers import ContextParallelConfig

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)
    device = torch.device("cuda")

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to(device)
    pipe.transformer.set_attention_backend("_native_cudnn")
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

    prompt = "A picture of a cat holding a sign that says hello"
    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0, generator=generator).images[0]

    if rank == 0:
        image.save("output.png")
except Exception as e:
    raise e
finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```

I've opened a PR to raise an error when an incompatible backend is used: #12446
Interesting, yeah the
For Qwen Reproducibility:
What does this PR do?
Fix Context Parallelism: Implement Ring Attention Pattern for Coherent Multi-GPU Generation
🐛 Problem
I did some testing of https://huggingface.co/docs/diffusers/main/training/distributed_inference on the main branch.
Context parallelism in diffusers was producing fragmented/split images when using multiple GPUs. Instead of generating a single coherent image, each GPU was independently generating its own portion, resulting in visible seams or completely different content in each image segment.
Example: Running with torchrun --nproc-per-node=2 would produce an image that looked like two different images side by side rather than one unified image.

🔍 Root Cause Analysis
The issue stems from how attention was computed in context parallel mode:
Before (Broken):
Each GPU was computing attention using only its local sequence chunk for Q, K, and V. This meant each rank's queries never saw keys or values from the other ranks, so every GPU effectively denoised its own independent segment of the image.
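Schematically (this is an illustration, not the actual diffusers code), the broken path amounts to running scaled-dot-product attention on each rank's local chunk only; shapes below are made up for the example:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: with 2 ranks, each holds half of a 512-token latent sequence.
B, H, S_local, D = 1, 24, 256, 128
q_local = torch.randn(B, H, S_local, D)
k_local = torch.randn(B, H, S_local, D)
v_local = torch.randn(B, H, S_local, D)

# Broken: attention is computed over the local chunk only, so tokens on this
# rank never attend to the other rank's tokens; each GPU effectively denoises
# an independent half of the image.
out_local = F.scaled_dot_product_attention(q_local, k_local, v_local)
```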
✅ Solution: Ring Attention Pattern
This PR implements the Ring Attention pattern, where each rank's queries attend over keys and values contributed by all ranks rather than only the local chunk.

After (Fixed): every GPU's attention sees the full key/value context, so the ranks jointly denoise a single coherent image that matches single-GPU output.
📝 Implementation Details
The fix is applied directly in the attention processors after rotary embeddings but before attention computation:
FluxAttnProcessor (transformer_flux.py):
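The diff itself is not reproduced here; below is a minimal sketch of the idea described above: queries stay local while keys and values are collected from every rank before scaled-dot-product attention. The helpers gather_seq_dim and context_parallel_attention are illustrative names rather than functions from the PR, and a single all-gather stands in for the round-by-round ring exchange.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def gather_seq_dim(t: torch.Tensor, dim: int = 2) -> torch.Tensor:
    # All-gather a (B, H, S_local, D) tensor along the sequence dimension so
    # every rank ends up with the full-sequence tensor.
    world_size = dist.get_world_size()
    chunks = [torch.empty_like(t) for _ in range(world_size)]
    dist.all_gather(chunks, t.contiguous())
    return torch.cat(chunks, dim=dim)

def context_parallel_attention(query, key, value):
    # Called after rotary embeddings: queries stay local, but keys/values are
    # gathered from all ranks so each rank's queries attend over the full
    # key/value context instead of only its own chunk.
    full_key = gather_seq_dim(key)
    full_value = gather_seq_dim(value)
    return F.scaled_dot_product_attention(query, full_key, full_value)
```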
🧪 Testing
For testing, run the following with torchrun --nproc-per-node=2:
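The test script is not included in the description; the following is a close variant assembled from the reproduction snippet earlier in this thread (same model, prompt, seed, and ContextParallelConfig(ring_degree=2)). It omits the set_attention_backend call, so adjust it to match the configuration you want to compare.

```python
# test_cp_flux.py -- launch with: torchrun --nproc-per-node=2 test_cp_flux.py
import torch
from diffusers import FluxPipeline, ContextParallelConfig

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

# Same seed on every rank so all ranks start from identical latents.
generator = torch.Generator().manual_seed(42)
image = pipe(
    "A picture of a cat holding a sign that says hello",
    num_inference_steps=28,
    guidance_scale=4.0,
    generator=generator,
).images[0]

if rank == 0:
    image.save("output.png")
torch.distributed.destroy_process_group()
```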
Before Fix:
Result: Two different images side-by-side in output
After Fix:
Result: Single coherent image matching single-GPU output
Summary: This PR fixes context parallelism by ensuring each GPU's attention queries can access the full key-value context from all GPUs, implementing the Ring Attention pattern for coherent multi-GPU image generation.
Fixes # (issue)

Note: I have observed that some tensors in QwenImage (encoder_hidden_states, encoder_hidden_mask, etc.) cannot be evenly divided by world_size. I am also willing to open a follow-up PR adding QwenImage support for context parallelism by padding those tensors to be divisible by the world size, similar to chengzeyi/ParaAttention#53, if you would like; a rough sketch of the padding idea follows.
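A minimal sketch of that padding idea (pad the sequence dimension up to the next multiple of the world size before sharding); pad_to_multiple is an illustrative helper rather than existing diffusers API, and the padded positions would still need to be masked in attention and dropped afterwards:

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(t: torch.Tensor, multiple: int, dim: int = 1) -> torch.Tensor:
    """Pad `dim` (the sequence dimension) so its length is divisible by `multiple`."""
    remainder = t.size(dim) % multiple
    if remainder == 0:
        return t
    pad = multiple - remainder
    # F.pad takes (left, right) pairs starting from the last dimension.
    pad_spec = [0, 0] * (t.dim() - dim - 1) + [0, pad]
    return F.pad(t, pad_spec)

# e.g. encoder_hidden_states of shape (batch, seq_len=77, dim) with world_size=2
encoder_hidden_states = torch.randn(1, 77, 3072)
padded = pad_to_multiple(encoder_hidden_states, multiple=2, dim=1)
print(padded.shape)  # torch.Size([1, 78, 3072])
```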
Before submitting
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@sayakpaul @a-r-r-o-w