
Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.

Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
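
For example, a minimal sketch (assuming the pipeline is already loaded and a FlashAttention-compatible kernel such as `flash-attn` is installed) that switches the transformer to a faster backend:

```py
# Switch the transformer's attention backend; "flash" assumes flash-attn is installed.
# See the backends table linked above for every supported value.
pipeline.transformer.set_attention_backend("flash")
```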

### Ring Attention

With [Ring Attention](https://huggingface.co/papers/2310.01889), key (K) and value (V) blocks are passed between devices so that each sequence shard eventually sees every other token's K/V. Each GPU computes attention for its local queries against the K/V block it currently holds, then forwards that block to the next GPU in the ring. No single GPU ever holds the full sequence, which keeps memory bounded, and the block transfers overlap with computation to hide communication latency.
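
To see why the result stays exact, here is a toy, single-process sketch of the blockwise accumulation Ring Attention performs (hypothetical shapes, no real inter-GPU communication, and without the numerically stable online softmax used in practice):

```py
import torch

world_size, tokens_per_rank, dim = 2, 4, 8
q = [torch.randn(tokens_per_rank, dim) for _ in range(world_size)]  # local queries per "rank"
kv = [(torch.randn(tokens_per_rank, dim), torch.randn(tokens_per_rank, dim)) for _ in range(world_size)]

# Each "rank" accumulates an unnormalized numerator and denominator as K/V blocks
# travel around the ring.
num = [torch.zeros(tokens_per_rank, dim) for _ in range(world_size)]
den = [torch.zeros(tokens_per_rank, 1) for _ in range(world_size)]
for step in range(world_size):
    for r in range(world_size):
        k, v = kv[(r + step) % world_size]       # block received at this ring step
        scores = (q[r] @ k.T / dim**0.5).exp()   # unnormalized attention weights
        num[r] += scores @ v
        den[r] += scores.sum(dim=-1, keepdim=True)
out = torch.cat([n / d for n, d in zip(num, den)])

# Matches full attention computed over the whole sequence at once
K = torch.cat([k for k, _ in kv]); V = torch.cat([v for _, v in kv])
ref = torch.softmax(torch.cat(q) @ K.T / dim**0.5, dim=-1) @ V
assert torch.allclose(out, ref, atol=1e-5)
```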
Pass a [`ContextParallelConfig`] with `ring_degree` set to the number of GPUs, either to the `parallel_config` argument of the transformer model or to [`~ModelMixin.enable_parallelism`] as shown below.

```py
import torch
from torch import distributed as dist

from diffusers import DiffusionPipeline, ContextParallelConfig


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
    )
    pipeline.transformer.set_attention_backend("_native_cudnn")

    cp_config = ContextParallelConfig(ring_degree=world_size)
    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Specify a generator so all ranks start from the same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    # Save from a single rank only to avoid redundant writes
    if dist.get_rank() == 0:
        image.save("output.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Run the script above with a PyTorch-compatible distributed launcher such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). Set `--nproc-per-node` to the number of available GPUs.

```shell
torchrun --nproc-per-node 2 above_script.py
```

### Ulysses Attention

[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends data to and receives data from every other device) so that each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally over all tokens for its assigned heads, then performs another all-to-all to regroup the results by tokens for the next layer.
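
As a toy, single-process illustration of the first all-to-all (hypothetical shapes, no real communication), the exchange trades a tokens-sharded layout for a heads-sharded one:

```py
import torch

world_size, seq_len, num_heads, head_dim = 2, 8, 4, 16

# Before the all-to-all, each "rank" holds ALL heads for seq_len // world_size tokens.
per_rank = [torch.randn(seq_len // world_size, num_heads, head_dim) for _ in range(world_size)]

# Simulate the all-to-all: rank s sends its r-th head block to rank r.
# (A real implementation uses torch.distributed.all_to_all.)
regrouped = [
    torch.cat([chunk.chunk(world_size, dim=1)[r] for chunk in per_rank], dim=0)
    for r in range(world_size)
]

# Afterwards, each "rank" holds ALL tokens for num_heads // world_size heads.
assert regrouped[0].shape == (seq_len, num_heads // world_size, head_dim)
```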
Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].

```py
# Set `ulysses_degree` to the number of available GPUs.
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```

### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.

```py
import torch

from diffusers import AutoModel, ContextParallelConfig, DiffusionPipeline

CKPT_ID = "black-forest-labs/FLUX.1-dev"

cp_config = ContextParallelConfig(ring_degree=2)
transformer = AutoModel.from_pretrained(
    CKPT_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=cp_config,
)

# `device` is the per-rank CUDA device, e.g. returned by a setup_distributed()
# helper like the one in the Ring Attention example above.
pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```