Skip to content

fix(ROCm): stream sync race in group_offloading + LSE shape mismatch in ring/Ulysses attention#13502

Open
Dev-next-gen wants to merge 2 commits intohuggingface:mainfrom
Dev-next-gen:fix/rocm-lse-shape-and-stream-sync
Open

fix(ROCm): stream sync race in group_offloading + LSE shape mismatch in ring/Ulysses attention#13502
Dev-next-gen wants to merge 2 commits intohuggingface:mainfrom
Dev-next-gen:fix/rocm-lse-shape-and-stream-sync

Conversation

@Dev-next-gen
Copy link
Copy Markdown

Summary

Two independent bugs that together prevent use_stream=True group offloading and context-parallel (ring/Ulysses) attention from working on AMD ROCm (tested on gfx1101 / RX 7800 XT, ROCm 7.1, PyTorch 2.7). Both fixes are backward-compatible with NVIDIA CUDA.


Fix 1 — group_offloading.py: gate default stream on transfer stream

Root cause: ModuleGroup._onload_from_memory launches async CPU→GPU tensor copies on a dedicated transfer stream but returns without making the default stream (on which the module forward pass runs) wait for those copies to complete.

On CUDA, implicit stream ordering and driver-level synchronization generally mask this race. On ROCm, the first matmul executes before the async copies finish, raising:

RuntimeError: Expected all tensors to be on the same device, but found at
least two devices, cuda:0 and cpu!  (when checking argument for argument
mat2 in method wrapper_CUDA_mm)

This affects any pipeline using enable_group_offload(use_stream=True), including FLUX.1-dev with int8 block-level offloading on ROCm.

Fix: After the transfer context block, call default_stream.wait_stream(self.stream) so the forward pass is gated on completed transfers. A stream.synchronize() fallback is included for backends that do not expose wait_stream. On CUDA this is a no-op when streams are already synchronized.


Fix 2 — attention_dispatch.py: replace is_torch_version guard with lse.ndim check

Root cause: Two call sites in TemplatedRingAttention.forward and _ulysses_context_parallel_attention condition the LSE unsqueeze on is_torch_version("<", "2.9.0"), assuming torch≥2.9 always returns LSE as [B,H,S,1] (4D) from aten._scaled_dot_product_flash_attention.

That assumption holds on CUDA but not on ROCm: on ROCm 7.x + torch≥2.9 (AOTriton / hipBLASLt backend), LSE is still returned as [B,H,S] (3D). The ring merge then broadcasts a 3D tensor against a 4D out, raising:

RuntimeError: The size of tensor a (24) must match the size of tensor b
(128) at non-singleton dimension 3

This blocks ring attention and Ulysses context-parallel attention on AMD hardware with any torch≥2.9 build.

Fix: Replace the is_torch_version guard with lse.ndim < out.ndim (resp. lse.ndim == 3). This is backend-agnostic: on CUDA where LSE is already 4D the condition is False and behaviour is unchanged; on ROCm where LSE is 3D the unsqueeze happens regardless of torch version. Applied to both affected call sites.


Hardware / software

| GPU | 5× AMD RX 7800 XT (gfx1101) |
| ROCm | 7.1 |
| PyTorch | 2.7 |
| Model | FLUX.1-dev, int8 quantization via torchao |
| Config | 4-GPU tensor parallel + block-level group offload + use_stream=True |

CUDA regression

None. Both fixes are either no-ops or logically equivalent on CUDA:

  • wait_stream when streams are synchronized = no-op
  • lse.ndim < out.ndim when LSE is already 4D = False, same branch as before

Related

…oad_from_memory

## Problem

`ModuleGroup._onload_from_memory` schedules async CPU→GPU tensor copies on a
dedicated transfer stream, but returns without making the default stream (on
which the module's forward pass runs) wait for those copies to finish.

On NVIDIA CUDA, implicit stream ordering and driver-level synchronization
generally prevent this race from manifesting. On **AMD ROCm** (tested on
gfx1101 / RX 7800 XT with ROCm 7.x), the race is reliable: the first matmul
in the freshly onloaded module executes before the async copies complete,
raising:

    RuntimeError: Expected all tensors to be on the same device, but found at
    least two devices, cuda:0 and cpu!  (when checking argument for argument
    mat2 in method wrapper_CUDA_mm)

This affects any pipeline that uses `enable_group_offload(use_stream=True)`,
including FLUX.1-dev with int8 group offloading on ROCm.

## Fix

After the `with context:` block, call `default_stream.wait_stream(self.stream)`
so the forward pass is gated on the completed transfers. A `stream.synchronize()`
fallback is included for backends that do not expose `wait_stream`.

On CUDA this call is a no-op when both streams are already synchronized,
so existing behaviour is preserved.

## Reproduction (ROCm)

```python
from diffusers import FluxPipeline
from diffusers.hooks import apply_group_offloading

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
apply_group_offloading(pipe.transformer, offload_type="block_level",
                       offload_device=torch.device("cpu"),
                       onload_device=torch.device("cuda"),
                       use_stream=True, num_blocks_per_group=1)
pipe("test prompt", num_inference_steps=4)
# → RuntimeError: Expected all tensors to be on the same device … cpu vs cuda
# Fixed with this patch.
```

Tested on: 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7, diffusers main.
CUDA regression: none (wait_stream is a no-op when streams are synchronized).
…for ROCm compat

## Problem

Two call sites in `TemplatedRingAttention.forward` and
`_ulysses_context_parallel_attention` condition the LSE unsqueeze on
`is_torch_version("<", "2.9.0")`, following the assumption introduced in
huggingface#12693 that torch>=2.9 always returns LSE with shape [B,H,S,1] (4D) from
`_scaled_dot_product_flash_attention`.

That assumption holds on **NVIDIA CUDA** but not on **AMD ROCm**: on ROCm 7.x
with torch>=2.9, `aten._scaled_dot_product_flash_attention` (backed by
AOTriton / hipBLASLt) still returns LSE as [B,H,S] (3D). The downstream ring
merge then broadcasts a 3D tensor against a 4D `out` tensor, raising:

    RuntimeError: The size of tensor a (24) must match the size of tensor b
    (128) at non-singleton dimension 3

This blocks ring / context-parallel attention entirely on AMD hardware with
any torch>=2.9 build.

## Fix

Replace the `is_torch_version` guard with `lse.ndim < out.ndim` (resp.
`lse.ndim == 3`). This is backend-agnostic: on CUDA torch>=2.9 where LSE is
already 4D the condition is False and behaviour is unchanged; on ROCm where
LSE is 3D the unsqueeze happens regardless of torch version.

The same logical fix is applied to both affected call sites:
- `TemplatedRingAttention.forward` (ring merge loop)
- `_ulysses_context_parallel_attention` (Ulysses all-to-all path)

## Tested on

- 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7, diffusers main
- Ring attention + context parallel with FLUX.1-dev, 4-GPU tensor parallel
- CUDA regression: none (ndim guard is equivalent to version guard on CUDA)
@github-actions github-actions Bot added models hooks size/S PR with diff < 50 LOC labels Apr 19, 2026
Dev-next-gen added a commit to Dev-next-gen/ao that referenced this pull request Apr 19, 2026
…_get_to_kwargs

## Problem

`_get_to_kwargs` explicitly discarded the `non_blocking` argument parsed from
`torch._C._nn._parse_to`, with a comment saying it is "not very useful for
most tensor subclasses". As a result, any call to `tensor.to(device,
non_blocking=True)` on a `TorchAOBaseTensor` subclass silently became a
blocking transfer at the inner-tensor level.

This matters in practice for async CPU→GPU offloading workflows such as
`diffusers` `enable_group_offload(use_stream=True)`: the diffusers hook
schedules copies with `non_blocking=True` so that the transfer stream and
the compute stream can overlap. Because the flag was dropped, all copies
became blocking, negating the overlap benefit.

On AMD ROCm (gfx1xxx) the missing non_blocking also interacts with a
separate stream-ordering race (fixed in huggingface/diffusers#13502): the
default stream can race ahead of "blocking" copies that the OS scheduler
hasn't committed yet, producing device-mismatch errors in the first matmul.

## Fix

1. `_get_to_kwargs`: include `non_blocking` in the returned kwargs dict.
2. `TorchAOBaseTensor._to_copy.default`: pop `non_blocking` from kwargs and
   forward it to every inner `.to()` call for both `tensor_data_names` and
   `optional_tensor_data_names`.

The change is backward-compatible: when `non_blocking=False` (the default),
behaviour is identical to before.

## Tested on

- 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7
- FLUX.1-dev int8 (`Int8WeightOnlyConfig`) with `enable_group_offload(use_stream=True)`
- Companion fix in diffusers: huggingface/diffusers#13502
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

hooks models size/S PR with diff < 50 LOC

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant