Sage attention fails with ring-attention. #13506

@songh11

Description

Describe the bug

When using Sage attention (SAGE or SAGE_HUB) combined with ring attention (context parallel),
the forward pass fails because sageattn CUDA kernels internally call tensor.data_ptr(),
which assumes storage_offset == 0.
In ring attention, funcol.all_gather_tensor(...).chunk(world_size) produces tensor views
with non-zero storage_offset. When these views are passed to sageattn, the kernel
reads from the wrong memory address.
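The offset behavior is easy to reproduce in isolation. The sketch below (plain PyTorch, no distributed setup needed) shows that `chunk()` returns views carrying a non-zero `storage_offset`, and that `.clone()` re-materializes the data at offset 0, whereas `.contiguous()` does not help here because the view is already contiguous and is returned unchanged:

```python
import torch

# chunk() on a gathered tensor returns views into shared storage;
# every chunk after the first has a non-zero storage_offset.
gathered = torch.arange(16, dtype=torch.float32)
shards = gathered.chunk(2)
assert shards[0].storage_offset() == 0
assert shards[1].storage_offset() == 8

# .contiguous() is a no-op on an already-contiguous view, so the
# non-zero offset survives:
assert shards[1].contiguous().storage_offset() == 8

# .clone() allocates fresh storage, guaranteeing storage_offset == 0:
safe = shards[1].clone()
assert safe.storage_offset() == 0
```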

Reproduction

import torch
import torch.distributed as dist
from diffusers import FluxPipeline
from diffusers.models._modeling_parallel import ContextParallelConfig
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ring",    type=int, default=1)
    parser.add_argument("--ulysses", type=int, default=2)
    parser.add_argument("--sage", action="store_true")
    args = parser.parse_args()

    # 1. init dist — torchrun injects RANK / LOCAL_RANK / WORLD_SIZE automatically
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # 2. load model
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to(f"cuda:{local_rank}")  # move the pipeline to this rank's GPU (use enable_model_cpu_offload() instead if VRAM is tight)

    # 3. enable sage
    if args.sage:
        pipe.transformer.set_attention_backend("sage")

    # 4. build the context-parallel config and enable CP on the transformer
    parallelism_config = ContextParallelConfig(
        ring_degree=args.ring,
        ulysses_degree=args.ulysses,
    )
    pipe.transformer.enable_parallelism(config=parallelism_config)

    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=50,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(0)
    ).images[0]
    image.save("flux-dev.png")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
Run with:

torchrun --nproc_per_node=2 ./demo.py --ring 2 --ulysses 1 --sage
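As a temporary workaround sketch (not the diffusers fix; `materialize` is a hypothetical helper name), the shard views could be defensively copied before they reach the kernel:

```python
import torch

def materialize(t: torch.Tensor) -> torch.Tensor:
    """Return a tensor whose data starts at the beginning of its storage.

    clone() allocates fresh storage, so the result always has
    storage_offset == 0; tensors already rooted at the start of their
    storage are passed through unchanged.
    """
    return t.clone() if t.storage_offset() != 0 else t
```

Wrapping q, k, and v in `materialize()` before the sageattn call would sidestep the illegal memory access, at the cost of an extra copy per ring step.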

Logs

[rank1]: Traceback (most recent call last):
[rank1]:   File "/root/workspace/optkit/examples/flux/./flux_context_parallel.py", line 37, in <module>
[rank1]:     image = pipe(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/flux/pipeline_flux.py", line 949, in __call__
[rank1]:     noise_pred = self.transformer(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/peft_utils.py", line 315, in wrapper
[rank1]:     result = forward_fn(self, *args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 726, in forward
[rank1]:     encoder_hidden_states, hidden_states = block(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 453, in forward
[rank1]:     attention_outputs = self.attn(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 352, in forward
[rank1]:     return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 118, in __call__
[rank1]:     hidden_states = dispatch_attention_fn(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_dispatch.py", line 432, in dispatch_attention_fn
[rank1]:     return backend_fn(**kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_dispatch.py", line 3262, in _sage_attention
[rank1]:     out = _templated_context_parallel_attention(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_dispatch.py", line 2269, in _templated_context_parallel_attention
[rank1]:     return TemplatedRingAttention.apply(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 576, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_dispatch.py", line 1935, in forward
[rank1]:     lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
[rank1]: torch.AcceleratorError: CUDA error: an illegal memory access was encountered

System Info

  • 🤗 Diffusers version: 0.37.1
  • Platform: Linux-5.10.134-15.al8.x86_64-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.8.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.11.0
  • Transformers version: 5.5.4
  • Accelerate version: 1.13.0
  • PEFT version: 0.19.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 4090, 49140 MiB
    NVIDIA GeForce RTX 4090, 49140 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Labels: bug