Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,18 @@ def _onload_from_memory(self):
else:
self._process_tensors_from_modules(None)

# Gate the default stream on the transfer stream completing before the forward pass runs.
# On CUDA, implicit stream ordering often masks this race; on AMD ROCm (gfx1xxx) the
# first matmul can race ahead of the async CPU→GPU copies and raise a device-mismatch
# error ("mat2 is on cpu") inside the first matmul of the loaded module.
# `wait_stream` is a no-op when both handles refer to the same stream.
if self.stream is not None:
current_default = self._torch_accelerator_module.current_stream()
if hasattr(current_default, "wait_stream"):
current_default.wait_stream(self.stream)
else:
self.stream.synchronize()

def _offload_to_disk(self):
self._check_disk_offload_torchao()

Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,7 +1916,10 @@ def forward(

# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
# Use ndim check instead of torch version: on AMD ROCm, torch>=2.9 still returns
# LSE as [B,H,S] (3D) rather than [B,H,S,1] (4D), so the version gate is incorrect.
# Checking ndim is both backend-agnostic and torch-version-agnostic.
if lse.ndim < out.ndim:
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
Expand Down Expand Up @@ -2206,7 +2209,9 @@ def _templated_unified_attention(
# lse is of shape (B, S, H_LOCAL, 1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
# Use ndim check instead of torch version: on AMD ROCm, torch>=2.9 still returns
# LSE as [B,H,S] (3D), so SeqAllToAllDim must receive 4D regardless of torch version.
if lse.ndim == 3:
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
lse = lse.squeeze(-1)
Expand Down
Loading