[BUG] Ulysses DistributedAttention silently produces incorrect output when #GPUs does not divide global sequence length #7384

@selflein

Description

Describe the bug
When using deepspeed.sequence.DistributedAttention on a sequence of length N split across G GPUs, where N % G != 0, the output is silently corrupted; no error is raised.
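
The per-rank chunk lengths diverge as soon as N is not divisible by G; presumably the all-to-all inside DistributedAttention assumes equal chunk sizes on every rank, so the shorter last chunk silently misaligns the exchanged data. A quick standalone illustration of the length mismatch (using torch.tensor_split as in the repro below):

import torch

# N = 5 tokens split across G = 2 ranks: torch.tensor_split yields chunks of length 3 and 2.
chunks = torch.tensor_split(torch.arange(5), 2)
print([c.shape[0] for c in chunks])  # [3, 2] -> the per-rank sequence lengths differ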

To Reproduce
Steps to reproduce the behavior:

  1. Save this as ds_distributed_att_varlen.py
import os
import deepspeed
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh

from deepspeed.sequence.layer import DistributedAttention

def _gather_variable_tensor(t: torch.Tensor, dim: int = 0, group=None):
    """Gather tensors with different size in ``dim`` using padding + all_gather.

    Returns the concatenated tensor on rank-0 and ``None`` on the others.
    """
    if group is None:
        group = dist.group.WORLD

    world_size_local = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # 1. Share the size of ``dim`` across ranks.
    local_len = torch.tensor([t.shape[dim]], device=t.device, dtype=torch.long)
    len_list = [torch.zeros_like(local_len) for _ in range(world_size_local)]
    dist.all_gather(len_list, local_len, group=group)
    lens = [int(l.item()) for l in len_list]
    max_len = max(lens)

    # 2. Pad to the maximum length so that shapes match.
    pad_shape = list(t.shape)
    pad_shape[dim] = max_len
    t_padded = torch.zeros(pad_shape, dtype=t.dtype, device=t.device)
    slc = [slice(None)] * t.ndim
    slc[dim] = slice(0, t.shape[dim])
    t_padded[tuple(slc)] = t

    # 3. All-gather the padded tensors.
    gather_list = [torch.empty_like(t_padded) for _ in range(world_size_local)]
    dist.all_gather(gather_list, t_padded, group=group)

    # 4. On rank-0, unpad and concatenate. Other ranks drop the data.
    if rank == 0:
        parts = []
        for i, g in enumerate(gather_list):
            slc_i = [slice(None)] * t.ndim
            slc_i[dim] = slice(0, lens[i])
            # Keep the data on the original device (no .cpu()).
            parts.append(g[tuple(slc_i)])
        return torch.cat(parts, dim=dim)
    else:
        return None


def context_parallel_sdpa_example(world_size: int, rank: int):
    print(f"Rank {rank}, world_size {world_size}")
    assert torch.cuda.is_available()
    assert dist.is_nccl_available()
    torch.cuda.manual_seed(0)

    deepspeed.init_distributed("nccl")

    device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
    )
    group = device_mesh.get_group("cp")

    device = torch.device("cuda", rank)

    batch = 8
    nheads = 8
    qkv_len = world_size * 2 + 1  # Deliberately not a multiple of the number of GPUs.
    dim = 64
    # backend = SDPBackend.CUDNN_ATTENTION
    torch.backends.cuda.enable_math_sdp(False)

    dtype = torch.float32

    qkv = [
        torch.rand(
            (batch, nheads, qkv_len, dim),
            dtype=dtype,
            requires_grad=True,
            device=device,
        )
        for _ in range(3)
    ]
    out = F.scaled_dot_product_attention(*qkv, is_causal=False)
    # make a clean copy of QKV for output comparison
    cp_qkv = [t.detach().clone() for t in qkv]

    # Broadcast to all ranks
    for t in cp_qkv:
        dist.broadcast(t, src=0, group=group)
    print(f"[RANK {rank}] Broadcasted QKV")

    cp_qkv = [torch.tensor_split(t, world_size, dim=2)[rank] for t in cp_qkv]

    for t in cp_qkv:
        print(f"[RANK {rank}] Chunked QKV shape: {t.shape}")

    # Run distributed attention
    dist_attn = DistributedAttention(F.scaled_dot_product_attention, group, gather_idx=2, scatter_idx=1)
    cp_out = dist_attn(*cp_qkv, is_causal=False, batch_dim_idx=0)
    print(f"[RANK {rank}] DistributedAttention ran.")

    cp_out_full = _gather_variable_tensor(cp_out, dim=2, group=group)
    print(f"[RANK {rank}] Gathered output across ranks.") 

    # Validate equivalence with the baseline implementation.
    if rank == 0:
        print(cp_out_full.device, out.device)
        print(
            f"Rank {rank}: cp_out_full.shape {cp_out_full.shape}, baseline out.shape {out.shape}"
        )
        print(cp_out_full - out)
        assert torch.allclose(
            cp_out_full,
            out,
            atol=(1e-08 if dtype == torch.float32 else 1e-08 * world_size),
        )

        # If we reached here the test has passed on this rank.
        print(f"Rank {rank}: DistributedAttention test passed ✅")


if __name__ == "__main__":
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    context_parallel_sdpa_example(world_size, rank)

    dist.destroy_process_group()
  2. Run the above script with torchrun --nproc_per_node 2 ds_distributed_att_varlen.py.

Expected behavior
I would expect either an error to be raised if this case is outside the specification of the implementation, or the output to be correct and match the single-GPU result.
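
A minimal sketch of the kind of guard I would expect, assuming the per-rank sequence lengths simply need to be equal across the sequence-parallel group (check_even_sequence_sharding is a hypothetical helper, not a DeepSpeed API):

import torch
import torch.distributed as dist

def check_even_sequence_sharding(local_seq_len: int, device: torch.device, group=None) -> None:
    """Hypothetical guard: raise instead of silently corrupting the output when
    the per-rank sequence lengths differ across the sequence-parallel group."""
    world_size = dist.get_world_size(group)
    local = torch.tensor([local_seq_len], dtype=torch.long, device=device)
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local, group=group)
    lens = [int(t.item()) for t in gathered]
    if len(set(lens)) != 1:
        raise ValueError(
            f"DistributedAttention expects equal per-rank sequence lengths, got {lens}; "
            "pad the global sequence length to a multiple of the sequence-parallel world size."
        )

In the script above this could be called as check_even_sequence_sharding(cp_qkv[0].shape[2], device, group) right before the dist_attn(...) call.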

ds_report output
Please run ds_report to give us details about your setup.

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  NVIDIA Inference is only supported on Ampere and newer architectures
 [WARNING]  FP Quantizer is using an untested triton version (3.3.1), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
compiler_compat/ld: 
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.7
 [WARNING]  using untested triton version (3.3.1), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch version .................... 2.7.1+cu126
deepspeed info ................... 0.17.1, unknown, unknown
torch cuda version ............... 12.6
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.3, cuda 12.1

Screenshots
If applicable, add screenshots to help explain your problem.

System info (please complete the following information):

  • OS: Linux version 5.4.0-80-generic (buildd@lcy01-amd64-030) (gcc version 9.3.0 (Ubuntu 9.3.0-17ubuntu1~20.04)) #90-Ubuntu SMP Fri Jul 9 22:49:44 UTC 2021
  • Tested with 2x and 8x V100
  • Interconnects: NVLink
  • Python version: 3.11

Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?
torchrun

Docker context
Are you using a specific docker image that you can share?
Running on a bare-metal node (no Docker).

Additional context

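A possible user-side mitigation, sketched below, would be to pad the sequence dimension up to a multiple of the sequence-parallel world size before splitting, mask the padded key positions, and drop the padded output rows after gathering. pad_seq_to_multiple is a hypothetical helper, not part of DeepSpeed, and the sketch assumes the extra attn_mask keyword is simply forwarded to the local attention call.

import torch

def pad_seq_to_multiple(t: torch.Tensor, multiple: int, dim: int = 2):
    """Right-pad the sequence dimension with zeros so its length is a multiple of ``multiple``.

    Returns the padded tensor and the original length so the padding can be
    stripped again after the attention output has been gathered.
    """
    orig_len = t.shape[dim]
    pad_len = (-orig_len) % multiple
    if pad_len == 0:
        return t, orig_len
    pad_shape = list(t.shape)
    pad_shape[dim] = pad_len
    pad = torch.zeros(pad_shape, dtype=t.dtype, device=t.device)
    return torch.cat([t, pad], dim=dim), orig_len

# Sketch of how it would slot into the repro above (layout [batch, heads, seq, head_dim]):
#   q_pad, orig_len = pad_seq_to_multiple(q, world_size)
#   k_pad, _ = pad_seq_to_multiple(k, world_size)
#   v_pad, _ = pad_seq_to_multiple(v, world_size)
#   attn_mask = torch.zeros(q_pad.shape[2], k_pad.shape[2], dtype=torch.bool, device=q.device)
#   attn_mask[:, :orig_len] = True  # only attend to real (unpadded) key positions
# The padded output rows (beyond orig_len) would then be dropped after gathering.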