
[bug] The mask is not correctly sharded for QwenImageTransformer + Ulysses SP #13696

@zhtmike

Description


Describe the bug

The mask is not correctly sharded in QwenImageTransformer when Ulysses SP is enabled.
Because the QwenImageTransformer's mask is non-contiguous over the joint (text + image) sequence, it should be expanded in an interleaved manner before being fed to each SP rank.

With the current implementation there is no clear visual difference with or without correct mask sharding. For RL training, however, we find that it causes abnormal losses and metrics.
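To make the sharding mismatch concrete, here is a minimal, self-contained sketch (plain PyTorch, no diffusers internals; the per-segment layout is an assumption for illustration): chunking the concatenated joint mask contiguously gives each rank a different mask than sharding the text and image segments separately and rebuilding the per-rank order.

```python
import torch

# Toy sizes: 2 SP ranks, joint sequence = 8 text tokens + 16 image tokens.
sp = 2
text_mask = torch.tensor([1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.bool)  # 3 valid text tokens
image_mask = torch.ones(16, dtype=torch.bool)                         # image tokens always valid
joint = torch.cat([text_mask, image_mask])                            # shape (24,)

# Naive sharding: chunk the concatenated mask contiguously per rank.
naive = list(joint.chunk(sp))

# Segment-aware ("interleaved") sharding: shard text and image separately,
# then rebuild each rank's mask in the order that rank actually sees tokens.
text_shards = text_mask.chunk(sp)
image_shards = image_mask.chunk(sp)
interleaved = [torch.cat([t, i]) for t, i in zip(text_shards, image_shards)]

# The two schemes disagree: rank 0's naive shard covers all text plus a few
# image tokens, while the segment-aware shard mixes half of each segment.
print(naive[0].int().tolist())
print(interleaved[0].int().tolist())
assert not torch.equal(naive[0], interleaved[0])
```

The exact per-rank token layout depends on how diffusers performs the Ulysses all-to-all; the point is only that a contiguous chunk of the concatenated mask does not match a per-segment shard, so attention is masked at the wrong positions.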

The issue can also be validated with the following SP accuracy test.

Reproduction

The test fails because the forward outputs with and without SP are not the same.

"""
Run with 2 GPUs:
    torchrun --nproc_per_node=2 --local-ranks-filter=0 test.py
"""

import os
from datetime import timedelta

import torch
import torch.distributed
from torch.distributed.device_mesh import init_device_mesh
from diffusers import ContextParallelConfig, QwenImageTransformer2DModel

_SP_SIZE = 2
_MODEL_CONFIG = {
    "patch_size": 2,
    "in_channels": 16,
    "out_channels": 4,
    "num_layers": 2,
    "attention_head_dim": 16,
    "num_attention_heads": 4,
    "joint_attention_dim": 16,
    "guidance_embeds": False,
    "axes_dims_rope": (8, 4, 4),
}


def main():
    torch.distributed.init_process_group("nccl", timeout=timedelta(seconds=36000))
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    rank = torch.distributed.get_rank()
    mesh = init_device_mesh("cuda", mesh_shape=(1, 1, _SP_SIZE), mesh_dim_names=("dp", "ring", "ulysses"))

    # SP model: parallelized along the Ulysses mesh dim, weights synced from rank 0.
    module_sp = QwenImageTransformer2DModel(**_MODEL_CONFIG)
    module_sp.enable_parallelism(config=ContextParallelConfig(ulysses_degree=_SP_SIZE, mesh=mesh))
    module_sp = module_sp.to("cuda", dtype=torch.bfloat16)
    for p in module_sp.parameters():
        torch.distributed.broadcast(p.data, src=0)

    # Reference model: identical weights, no sequence parallelism.
    module_no_sp = QwenImageTransformer2DModel(**_MODEL_CONFIG)
    module_no_sp = module_no_sp.to("cuda", dtype=torch.bfloat16)
    module_no_sp.load_state_dict({k: v.clone() for k, v in module_sp.state_dict().items()}, strict=False)

    batch_size, latent_h, text_seq_len = 2, 4, 8
    latent_dim, text_dim = _MODEL_CONFIG["in_channels"], _MODEL_CONFIG["joint_attention_dim"]

    # Identical random inputs on every rank: sample on rank 0, then broadcast.
    hidden_states = torch.zeros(batch_size, latent_h * latent_h, latent_dim, dtype=torch.bfloat16, device="cuda")
    encoder_hidden_states = torch.zeros(batch_size, text_seq_len, text_dim, dtype=torch.bfloat16, device="cuda")
    if rank == 0:
        torch.manual_seed(42)
        hidden_states.normal_()
        encoder_hidden_states.normal_()
    torch.distributed.broadcast(hidden_states, src=0)
    torch.distributed.broadcast(encoder_hidden_states, src=0)

    # Ragged text mask: sample 0 keeps 2 valid tokens, sample 1 keeps 6.
    encoder_hidden_states_mask = torch.zeros(batch_size, text_seq_len, dtype=torch.bool, device="cuda")
    encoder_hidden_states_mask[0, :2] = True
    encoder_hidden_states_mask[1, :6] = True

    model_inputs = dict(
        hidden_states=hidden_states,
        timestep=torch.full([batch_size], 0.5, dtype=torch.float32, device="cuda"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        img_shapes=[[(1, latent_h, latent_h)]] * batch_size,
        return_dict=False,
    )

    module_sp.eval()
    module_no_sp.eval()
    with torch.no_grad():
        output_sp = module_sp(**model_inputs)[0]
        output_no_sp = module_no_sp(**model_inputs)[0]

    torch.testing.assert_close(output_sp.float(), output_no_sp.float(), rtol=1e-2, atol=1e-2)

    if rank == 0:
        print(f"mean(SP)={output_sp.float().mean():.6f}  mean(no-SP)={output_no_sp.float().mean():.6f}  ✓")


if __name__ == "__main__":
    main()

Logs

[rank0]: Traceback (most recent call last):
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/verl-omni/test.py", line 96, in <module>
[rank0]:     main()
[rank0]:   File "/scratch/fq9hpsac/mikecheung/gitlocal/verl-omni/test.py", line 89, in main
[rank0]:     torch.testing.assert_close(output_sp.float(), output_no_sp.float(), rtol=1e-2, atol=1e-2)
[rank0]:   File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1600, in assert_close
[rank0]:     raise error_metas[0].to_error(msg)
[rank0]: AssertionError: Tensor-likes are not close!

[rank0]: Mismatched elements: 66 / 512 (12.9%)
[rank0]: Greatest absolute difference: 0.03934478759765625 at index (0, 12, 11) (up to 0.01 allowed)
[rank0]: Greatest relative difference: 72.2631607055664 at index (0, 12, 9) (up to 0.01 allowed)

System Info

diffusers==0.38.0

Who can help?

@sayakpaul
