Describe the bug
When using deepspeed.sequence.DistributedAttention on sequences of length N where N % G != 0, with G being the number of GPUs in the sequence-parallel group, the output is silently corrupted; no error is thrown.
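For concreteness, the uneven chunking that triggers this can be seen in isolation with torch.tensor_split, which is also how the repro script below splits the sequence (illustrative snippet, not part of the original repro):

```python
import torch

# N = 5 tokens split across G = 2 ranks: 5 % 2 != 0, so the per-rank
# chunks have different sequence lengths along dim 2.
t = torch.zeros(8, 8, 5, 64)  # (batch, heads, seq, head_dim)
chunks = torch.tensor_split(t, 2, dim=2)
print([c.shape[2] for c in chunks])  # -> [3, 2]
```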
To Reproduce
Steps to reproduce the behavior:
- Save this as ds_distributed_att_varlen.py:
```python
import os

import deepspeed
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh

from deepspeed.sequence.layer import DistributedAttention


def _gather_variable_tensor(t: torch.Tensor, dim: int = 0, group=None):
    """Gather tensors with different size in ``dim`` using padding + all_gather.

    Returns the concatenated tensor on rank-0 and ``None`` on the others.
    """
    if group is None:
        group = dist.group.WORLD
    world_size_local = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # 1. Share the size of ``dim`` across ranks.
    local_len = torch.tensor([t.shape[dim]], device=t.device, dtype=torch.long)
    len_list = [torch.zeros_like(local_len) for _ in range(world_size_local)]
    dist.all_gather(len_list, local_len, group=group)
    lens = [int(l.item()) for l in len_list]
    max_len = max(lens)

    # 2. Pad to the maximum length so that shapes match.
    pad_shape = list(t.shape)
    pad_shape[dim] = max_len
    t_padded = torch.zeros(pad_shape, dtype=t.dtype, device=t.device)
    slc = [slice(None)] * t.ndim
    slc[dim] = slice(0, t.shape[dim])
    t_padded[tuple(slc)] = t

    # 3. All-gather the padded tensors.
    gather_list = [torch.empty_like(t_padded) for _ in range(world_size_local)]
    dist.all_gather(gather_list, t_padded, group=group)

    # 4. On rank-0, unpad and concatenate. Other ranks drop the data.
    if rank == 0:
        parts = []
        for i, g in enumerate(gather_list):
            slc_i = [slice(None)] * t.ndim
            slc_i[dim] = slice(0, lens[i])
            # Keep the data on the original device (no .cpu()).
            parts.append(g[tuple(slc_i)])
        return torch.cat(parts, dim=dim)
    else:
        return None


def context_parallel_sdpa_example(world_size: int, rank: int):
    print(f"Rank {rank}, world_size {world_size}")
    assert torch.cuda.is_available()
    assert dist.is_nccl_available()

    torch.cuda.manual_seed(0)
    deepspeed.init_distributed("nccl")
    device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
    )
    group = device_mesh.get_group("cp")
    device = torch.device("cuda", rank)

    batch = 8
    nheads = 8
    qkv_len = world_size * 2 + 1  # Make not a multiple of #GPUs.
    dim = 64

    # backend = SDPBackend.CUDNN_ATTENTION
    torch.backends.cuda.enable_math_sdp(False)
    dtype = torch.float32

    qkv = [
        torch.rand(
            (batch, nheads, qkv_len, dim),
            dtype=dtype,
            requires_grad=True,
            device=device,
        )
        for _ in range(3)
    ]
    out = F.scaled_dot_product_attention(*qkv, is_causal=False)

    # make a clean copy of QKV for output comparison
    cp_qkv = [t.detach().clone() for t in qkv]
    # Broadcast to all ranks
    for t in cp_qkv:
        dist.broadcast(t, src=0, group=group)
    print(f"[RANK {rank}] Broadcasted QKV")

    cp_qkv = [torch.tensor_split(t, world_size, dim=2)[rank] for t in cp_qkv]
    for t in cp_qkv:
        print(f"[RANK {rank}] Chunked QKV shape: {t.shape}")

    # Run distributed attention
    dist_attn = DistributedAttention(F.scaled_dot_product_attention, group, gather_idx=2, scatter_idx=1)
    cp_out = dist_attn(*cp_qkv, is_causal=False, batch_dim_idx=0)
    print(f"[RANK {rank}] DistributedAttention ran.")

    cp_out_full = _gather_variable_tensor(cp_out, dim=2, group=group)
    print(f"[RANK {rank}] Gathered output across ranks.")

    # Validate equivalence with the baseline implementation.
    if rank == 0:
        print(cp_out_full.device, out.device)
        print(
            f"Rank {rank}: cp_out_full.shape {cp_out_full.shape}, baseline out.shape {out.shape}"
        )
        print(cp_out_full - out)
        assert torch.allclose(
            cp_out_full,
            out,
            atol=(1e-08 if dtype == torch.float32 else 1e-08 * world_size),
        )

    # If we reached here the test has passed on this rank.
    print(f"Rank {rank}: DistributedAttention test passed ✅")


if __name__ == "__main__":
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    context_parallel_sdpa_example(world_size, rank)
    dist.destroy_process_group()
```
- Run the above script with:
torchrun --nproc_per_node 2 ds_distributed_att_varlen.py
Expected behavior
I would have expected either an error to be thrown, if this case is outside the specification of the implementation, or the output to actually be correct and coincide with the single-GPU version.
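As a sketch of the kind of guard I would expect (a hypothetical helper written here for illustration, not an existing DeepSpeed API), one could verify that every rank holds the same local sequence length before calling DistributedAttention, turning the silent corruption into a hard error:

```python
import torch
import torch.distributed as dist


def assert_uniform_seq_chunks(local_seq_len: int, group=None):
    """Hypothetical guard: raise if ranks hold unequal sequence chunks.

    Assumes the sequence-parallel all-to-all needs the same number of
    tokens on every rank; with N % G != 0 this fails loudly instead of
    letting the attention output be silently wrong.
    """
    device = torch.device("cuda", torch.cuda.current_device())
    local = torch.tensor([local_seq_len], device=device, dtype=torch.long)
    lens = [torch.zeros_like(local) for _ in range(dist.get_world_size(group))]
    dist.all_gather(lens, local, group=group)
    sizes = [int(l.item()) for l in lens]
    if len(set(sizes)) != 1:
        raise ValueError(
            f"DistributedAttention expects equal sequence chunks per rank, got {sizes}"
        )
```

In the repro above this would be called as assert_uniform_seq_chunks(cp_qkv[0].shape[2], group=group) right before dist_attn(...), and would fail for qkv_len = world_size * 2 + 1 instead of returning wrong values.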
ds_report output
Please run ds_report to give us details about your setup.
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
[WARNING] NVIDIA Inference is only supported on Ampere and newer architectures
[WARNING] FP Quantizer is using an untested triton version (3.3.1), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
compiler_compat/ld:
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.7
[WARNING] using untested triton version (3.3.1), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch version .................... 2.7.1+cu126
deepspeed info ................... 0.17.1, unknown, unknown
torch cuda version ............... 12.6
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.3, cuda 12.1
Screenshots
If applicable, add screenshots to help explain your problem.
System info (please complete the following information):
- OS: Linux version 5.4.0-80-generic (buildd@lcy01-amd64-030) (gcc version 9.3.0 (Ubuntu 9.3.0-17ubuntu1~20.04)) #90-Ubuntu SMP Fri Jul 9 22:49:44 UTC 2021
- Tested with 2x and 8x V100
- Interconnects: NVLink
- Python version: 3.11
Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?
torchrun
Docker context
Are you using a specific docker image that you can share?
No specific Docker image; running directly on a bare-metal node.