Status: Open
Labels: bug (Something isn't working)
Description
Describe the bug
I am not sure whether the backward pass of Ulysses SP is formally supported, but backward ops such as `_native_attention_backward_op` are implemented in the codebase. When I run `QwenImageTransformer2DModel` with a backward pass, I encounter shape-mismatch errors.
Reproduction
The issue can be reproduced with the following snippet (it relies on PR #13278, which fixes the forward pass):
```python
import argparse

import torch
import torch.distributed as dist
import torch.nn.functional as F

from diffusers.models import QwenImageTransformer2DModel
from diffusers.models._modeling_parallel import ContextParallelConfig


def init_model(device, enable_sp: bool):
    model = QwenImageTransformer2DModel(
        num_layers=2,
        num_attention_heads=4,
        attention_head_dim=32,
        joint_attention_dim=3584,
        axes_dims_rope=(8, 12, 12),
    ).to(device, dtype=torch.bfloat16)
    if enable_sp:
        model.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
    return model


def make_batch(device):
    dtype = torch.bfloat16
    torch.manual_seed(0)
    hidden_states = torch.randn(2, 256, 64, device=device, dtype=dtype)
    encoder_hidden_states = torch.randn(2, 32, 3584, device=device, dtype=dtype)
    encoder_hidden_states_mask = torch.ones(2, 32, device=device, dtype=torch.bool)
    timestep = torch.rand(2, device=device, dtype=dtype)
    img_shapes = [[(1, 16, 16)]] * 2
    target = torch.randn(2, 256, 64, device=device, dtype=dtype)
    return (
        hidden_states,
        encoder_hidden_states,
        encoder_hidden_states_mask,
        timestep,
        img_shapes,
        target,
    )


def train(enable_sp: bool):
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)
    model = init_model(device, enable_sp)
    model.train()
    (
        hidden_states,
        encoder_hidden_states,
        encoder_hidden_states_mask,
        timestep,
        img_shapes,
        target,
    ) = make_batch(device)
    pred = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        timestep=timestep,
        img_shapes=img_shapes,
        return_dict=False,
    )[0]
    loss = F.mse_loss(pred.float(), target.float())
    loss.backward()
    if rank == 0:
        print(f"loss={loss.item():.6f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-sp", action="store_true")
    args = parser.parse_args()
    train(enable_sp=args.enable_sp)
```
Logs
- Without Ulysses SP enabled, `torchrun --nproc-per-node 2 toy_train.py` runs and produces the expected output:
  loss=1.351188
- With Ulysses SP enabled, `torchrun --nproc-per-node 2 toy_train.py --enable-sp` fails with:
  RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([2, 2, 288, 32]) and output[0] has a shape of torch.Size([2, 288, 2, 32]).
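For what it's worth, the two shapes in the error differ only by a swap of dims 1 and 2, which makes me suspect a missing transpose between a (batch, heads, seq, head_dim) layout on one side of the backward and a (batch, seq, heads, head_dim) layout on the other. The numbers are consistent with the repro config: 288 = 256 image tokens + 32 text tokens, and 2 = 4 attention heads / ulysses_degree 2. A minimal sketch of the mismatch (the layout interpretation is my assumption, not confirmed from the diffusers code):

```python
import torch

# Shapes taken verbatim from the error message.
# Assumed layout: (batch, seq, heads_per_rank, head_dim).
saved_output = torch.zeros(2, 288, 2, 32)

# Swapping dims 1 and 2 yields exactly the grad_output shape from the error,
# hinting that the backward op receives (batch, heads, seq, head_dim) while
# the forward saved (batch, seq, heads, head_dim).
grad_output = saved_output.permute(0, 2, 1, 3)
print(tuple(grad_output.shape))  # (2, 2, 288, 32)
```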
System Info
- 🤗 Diffusers version: 0.38.0.dev0
- Platform: Linux-5.15.0-1053-nvidia-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.12.12
- PyTorch version (GPU?): 2.9.1+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.36.2
- Transformers version: 4.57.6
- Accelerate version: 1.12.0
- PEFT version: 0.18.1
- Bitsandbytes version: not installed
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: 4x NVIDIA H800, 81559 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: yes
Who can help?