Describe the bug
The attention mask is not sharded correctly in QwenImageTransformer2DModel when Ulysses SP is enabled.
Because the mask consumed by QwenImageTransformer2DModel is non-contiguous along the sequence dimension, it must be expanded in an interleaved manner before being fed to each SP rank (see the sketch below).
With incorrect mask sharding the generated images show no obvious visual difference, but in RL training we observe that it causes abnormal losses and metrics.
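For illustration only (a minimal sketch, not the diffusers implementation), here is the shape of the problem: a contiguous slice of a per-token padding mask gives each rank different tokens than an interleaved one, so a mask sliced in one layout cannot be applied to a sequence that the SP attention redistributes in the other:

# Hypothetical sketch, not diffusers internals: contiguous vs. interleaved
# sharding of a per-token attention mask across P Ulysses SP ranks.
import torch

P = 2  # SP world size
mask = torch.tensor([[True, True, True, False, False, False, False, False]])  # (B=1, S=8)

# Contiguous sharding: rank r gets tokens [r*S/P, (r+1)*S/P).
contiguous = list(mask.chunk(P, dim=-1))
# rank 0 -> [T, T, T, F], rank 1 -> [F, F, F, F]

# Interleaved sharding: rank r gets tokens r, r+P, r+2P, ...
interleaved = [mask[..., r::P] for r in range(P)]
# rank 0 -> [T, T, F, F], rank 1 -> [T, F, F, F]

If the attention redistributes the sequence in one layout while the mask is sliced in the other, valid and padded tokens end up misaligned on every rank, which is consistent with the mismatch reported by the test below.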
The issue can also be validated with the following SP accuracy test.
Reproduction
The test fails because the forward outputs with and without SP do not match. The script builds two QwenImageTransformer2DModel instances with identical weights, runs one under Ulysses SP (degree 2) and one without, feeds both the same inputs with a ragged encoder_hidden_states_mask, and asserts that the outputs are close.
"""
Run with 2 GPUs:
torchrun --nproc_per_node=2 --local-ranks-filter=0 test.py
"""
import os
from datetime import timedelta

import torch
import torch.distributed
from torch.distributed.device_mesh import init_device_mesh

from diffusers import ContextParallelConfig, QwenImageTransformer2DModel

_SP_SIZE = 2
_MODEL_CONFIG = {
    "patch_size": 2,
    "in_channels": 16,
    "out_channels": 4,
    "num_layers": 2,
    "attention_head_dim": 16,
    "num_attention_heads": 4,
    "joint_attention_dim": 16,
    "guidance_embeds": False,
    "axes_dims_rope": (8, 4, 4),
}


def main():
    torch.distributed.init_process_group("nccl", timeout=timedelta(seconds=36000))
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    rank = torch.distributed.get_rank()
    mesh = init_device_mesh("cuda", mesh_shape=(1, 1, _SP_SIZE), mesh_dim_names=("dp", "ring", "ulysses"))

    # SP model: enable Ulysses context parallelism over the "ulysses" mesh dim,
    # then broadcast the weights so all ranks start identical.
    module_sp = QwenImageTransformer2DModel(**_MODEL_CONFIG)
    module_sp.enable_parallelism(config=ContextParallelConfig(ulysses_degree=_SP_SIZE, mesh=mesh))
    module_sp = module_sp.to("cuda", dtype=torch.bfloat16)
    for p in module_sp.parameters():
        torch.distributed.broadcast(p.data, src=0)

    # Reference model: same weights, no SP.
    module_no_sp = QwenImageTransformer2DModel(**_MODEL_CONFIG)
    module_no_sp = module_no_sp.to("cuda", dtype=torch.bfloat16)
    module_no_sp.load_state_dict({k: v.clone() for k, v in module_sp.state_dict().items()}, strict=False)

    # Identical random inputs on every rank (sampled on rank 0, then broadcast).
    batch_size, latent_h, text_seq_len = 2, 4, 8
    latent_dim, text_dim = _MODEL_CONFIG["in_channels"], _MODEL_CONFIG["joint_attention_dim"]
    hidden_states = torch.zeros(batch_size, latent_h * latent_h, latent_dim, dtype=torch.bfloat16, device="cuda")
    encoder_hidden_states = torch.zeros(batch_size, text_seq_len, text_dim, dtype=torch.bfloat16, device="cuda")
    if rank == 0:
        torch.manual_seed(42)
        hidden_states.normal_()
        encoder_hidden_states.normal_()
    torch.distributed.broadcast(hidden_states, src=0)
    torch.distributed.broadcast(encoder_hidden_states, src=0)

    # Ragged text mask: sample 0 keeps 2 valid tokens, sample 1 keeps 6.
    encoder_hidden_states_mask = torch.zeros(batch_size, text_seq_len, dtype=torch.bool, device="cuda")
    encoder_hidden_states_mask[0, :2] = True
    encoder_hidden_states_mask[1, :6] = True

    model_inputs = dict(
        hidden_states=hidden_states,
        timestep=torch.full([batch_size], 0.5, dtype=torch.float32, device="cuda"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        img_shapes=[[(1, latent_h, latent_h)]] * batch_size,
        return_dict=False,
    )

    module_sp.eval()
    module_no_sp.eval()
    with torch.no_grad():
        output_sp = module_sp(**model_inputs)[0]
        output_no_sp = module_no_sp(**model_inputs)[0]

    # SP and non-SP forward passes should agree within tolerance.
    torch.testing.assert_close(output_sp.float(), output_no_sp.float(), rtol=1e-2, atol=1e-2)
    if rank == 0:
        print(f"mean(SP)={output_sp.float().mean():.6f} mean(no-SP)={output_no_sp.float().mean():.6f} ✓")


if __name__ == "__main__":
    main()
Logs
[rank0]: Traceback (most recent call last):
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/verl-omni/test.py", line 96, in <module>
[rank0]: main()
[rank0]: File "/scratch/fq9hpsac/mikecheung/gitlocal/verl-omni/test.py", line 89, in main
[rank0]: torch.testing.assert_close(output_sp.float(), output_no_sp.float(), rtol=1e-2, atol=1e-2)
[rank0]: File "/scratch/fq9hpsac/mikecheung/miniforge3/envs/verl-omni/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1600, in assert_close
[rank0]: raise error_metas[0].to_error(msg)
[rank0]: AssertionError: Tensor-likes are not close!
[rank0]: Mismatched elements: 66 / 512 (12.9%)
[rank0]: Greatest absolute difference: 0.03934478759765625 at index (0, 12, 11) (up to 0.01 allowed)
[rank0]: Greatest relative difference: 72.2631607055664 at index (0, 12, 9) (up to 0.01 allowed)
System Info
diffusers==0.38.0
Who can help?
@sayakpaul