Describe the bug
When Sage attention (`SAGE` or `SAGE_HUB`) is combined with ring attention (context parallel), the forward pass fails because the `sageattn` CUDA kernels internally call `tensor.data_ptr()`, which assumes `storage_offset == 0`. In ring attention, `funcol.all_gather_tensor(...).chunk(world_size)` produces tensor views with a non-zero `storage_offset`; when these views are passed to `sageattn`, the kernel reads from the wrong memory address.
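The offending views are easy to reproduce without any distributed setup; here is a minimal single-process sketch (shapes are arbitrary, chosen only for illustration):

```python
import torch

# chunk() returns views into one shared buffer, so every chunk after the
# first has a non-zero storage_offset. This is exactly what ring attention
# produces via funcol.all_gather_tensor(...).chunk(world_size).
gathered = torch.randn(4, 8)       # stands in for the all-gathered buffer
chunks = gathered.chunk(2)         # world_size == 2

print(chunks[0].storage_offset())  # 0
print(chunks[1].storage_offset())  # 16: the kind of view sageattn mishandles

# .contiguous() copies a view into fresh storage with storage_offset == 0,
# which is one possible workaround at the call site.
print(chunks[1].contiguous().storage_offset())  # 0
```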
Reproduction
```python
import argparse

import torch
import torch.distributed as dist
from diffusers import FluxPipeline
from diffusers.models._modeling_parallel import ContextParallelConfig


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ring", type=int, default=1)
    parser.add_argument("--ulysses", type=int, default=2)
    parser.add_argument("--sage", action="store_true")
    args = parser.parse_args()

    # 1. init dist — torchrun injects RANK / LOCAL_RANK / WORLD_SIZE automatically
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # 2. load the model and move it to this rank's GPU
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to(f"cuda:{local_rank}")

    # 3. enable sage attention
    if args.sage:
        pipe.transformer.set_attention_backend("sage")

    # 4. build the context-parallel config and enable CP on the transformer
    parallelism_config = ContextParallelConfig(
        ring_degree=args.ring,
        ulysses_degree=args.ulysses,
    )
    pipe.transformer.enable_parallelism(config=parallelism_config)

    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=50,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    image.save("flux-dev.png")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
Launch command (two GPUs):

```shell
torchrun --nproc_per_node=2 ./demo.py --ring 2 --ulysses 1 --sage
```
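Until the kernel handles views correctly, one plausible workaround is to force `q`/`k`/`v` contiguous before they reach the kernel. Below is an untested sketch that wraps the public `sageattention.sageattn` entry point; note that it would have to run before diffusers imports the backend, since diffusers binds the function by name at import time (an assumption about import order, not verified here):

```python
import sageattention

_orig_sageattn = sageattention.sageattn

def _contiguous_sageattn(q, k, v, *args, **kwargs):
    # .contiguous() materializes any view (storage_offset != 0) into fresh,
    # zero-offset storage, so the kernel's address computation is valid again.
    return _orig_sageattn(q.contiguous(), k.contiguous(), v.contiguous(), *args, **kwargs)

# Untested sketch: apply before diffusers imports the sage backend.
sageattention.sageattn = _contiguous_sageattn
```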
Logs
```
[rank1]: Traceback (most recent call last):
[rank1]: File "/root/workspace/optkit/examples/flux/./flux_context_parallel.py", line 37, in <module>
[rank1]: image = pipe(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/flux/pipeline_flux.py", line 949, in __call__
[rank1]: noise_pred = self.transformer(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]: output = function_reference.forward(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/peft_utils.py", line 315, in wrapper
[rank1]: result = forward_fn(self, *args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 726, in forward
[rank1]: encoder_hidden_states, hidden_states = block(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 453, in forward
[rank1]: attention_outputs = self.attn(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 352, in forward
[rank1]: return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 118, in __call__
[rank1]: hidden_states = dispatch_attention_fn(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_dispatch.py", line 432, in dispatch_attention_fn
[rank1]: return backend_fn(**kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_dispatch.py", line 3262, in _sage_attention
[rank1]: out = _templated_context_parallel_attention(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_dispatch.py", line 2269, in _templated_context_parallel_attention
[rank1]: return TemplatedRingAttention.apply(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 576, in apply
[rank1]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank1]: File "/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_dispatch.py", line 1935, in forward
[rank1]: lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
[rank1]: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
```
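Note that CUDA reports illegal accesses asynchronously, so the traceback blames the next synchronizing op (the LSE merge) rather than the kernel that actually faulted. Re-running with blocking launches should localize the error to the `sageattn` call itself:

```shell
CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node=2 ./demo.py --ring 2 --ulysses 1 --sage
```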
System Info
- 🤗 Diffusers version: 0.37.1
- Platform: Linux-5.10.134-15.al8.x86_64-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.8.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 1.11.0
- Transformers version: 5.5.4
- Accelerate version: 1.13.0
- PEFT version: 0.19.1
- Bitsandbytes version: not installed
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 49140 MiB
NVIDIA GeForce RTX 4090, 49140 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes (torchrun, 2 processes)
Who can help?
No response