The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
- spawn_materialize schedules all owned mappings up front for disk I/O parallelism, but each rank only schedules its own mappings
- source rank materializes to CPU (mmap-backed) and moves to GPU per-mapping right before the broadcast, avoiding piling up many full expert tensors on-device at once
- free the full tensor immediately after sharding on source and receivers (`del full_tensor`, drop the `realized_value` reference)
- `get_packed_weights` slices in native dtype when the input is a `torch.Tensor` and only upcasts the already-small shard (no full-tensor FP8→FP16 upcast)
- `_redistribute_realized_value` is decorated with `@torch.no_grad()`
- `benchmark_v2/benchmark_scripts/tp_loading.py`: `--model-id` flag to drive the path on real checkpoints via torchrun
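The slice-then-upcast idea can be sketched in a few lines. `slice_then_upcast` is an illustrative stand-in (not the real `get_packed_weights` signature), and fp16→fp32 stands in for the FP8→FP16 case:

```python
import torch

def slice_then_upcast(full, rank, world_size, dim=0, out_dtype=torch.float32):
    # Slice in the checkpoint's native dtype first, then upcast only the
    # already-small shard; upcasting the full tensor first would briefly
    # hold a full-size copy at the wider dtype.
    shard_len = full.size(dim) // world_size
    shard = full.narrow(dim, rank * shard_len, shard_len)
    return shard.to(out_dtype)

full = torch.randn(8, 4, dtype=torch.float16)   # fp16→fp32 stands in for FP8→FP16
shard = slice_then_upcast(full, rank=1, world_size=4)
```

At world_size=4 the transient wide-dtype allocation is 1/4 of what a full-tensor upcast would need.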
- clone the sharded view in `_redistribute_realized_value` when it aliases the full broadcast buffer; otherwise `del full_tensor` only decrements a refcount and the 6 GB expert tensor stays pinned by the view, silently blowing up peak memory
- don't re-cast the shard to the top-level dtype inside the TP path; the per-entry `_dtype` was already honored by spawn_materialize, and forcing a dtype here upcasts FP8 shards to BF16, doubling memory
- skip `caching_allocator_warmup` when TP is active: the streaming redistribute+del flow doesn't benefit from a big upfront reserve, and the reservation fights the per-mapping receive-buffer allocations
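The view-pins-storage failure mode is easy to reproduce on CPU; the sizes here are small stand-ins for the 6 GB expert tensor:

```python
import torch

full = torch.zeros(1024, 1024)   # stand-in for a big expert tensor (4 MB here)
view = full[:, :4]               # sharded view aliasing the full buffer
del full                         # only drops a refcount; the view pins the storage
# all 4 MB are still resident behind the tiny view
assert view.untyped_storage().nbytes() == 1024 * 1024 * 4

shard = view.clone()             # fresh allocation sized to the shard
del view                         # now the big storage can actually be freed
```

After the clone, only the 16 KB shard-sized storage remains referenced.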
Each rank owns ~num_mappings/world_size mappings via `partition_mappings_across_ranks`; for its owned mappings it now pre-slices the materialized full tensor into one shard per rank and calls `dist.scatter`, so every other rank receives only its own shard.
- cluster bandwidth drops from (N-1)*sizeof(full) to (N-1)/N*sizeof(full)
- no more view-keeps-full-tensor-alive bug: each scatter recv buffer is a fresh allocation sized to the local shard, no clone needed on the receiving side
- replicated params (`tp_layer is None`) still broadcast; they're small
- world_size == 1 path stays fully local, no comms
- ragged shards (non-divisible sharded dim) raise rather than silently miscompute; every model we ship divides cleanly under tp_plan=auto
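The bandwidth claim is simple arithmetic; a quick sanity check (function names are purely illustrative):

```python
def bytes_on_wire_broadcast(full_bytes, world_size):
    # broadcast ships the full tensor to each of the other N-1 ranks
    return (world_size - 1) * full_bytes

def bytes_on_wire_scatter(full_bytes, world_size):
    # scatter ships each of the other N-1 ranks only its 1/N shard
    return (world_size - 1) * full_bytes // world_size

six_gb = 6 * 1024**3
# at world_size=8: broadcast puts 42 GB on the wire, scatter only 5.25 GB
```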
Replace the per-rank Python loop (N × `tp_layer.rank` mutation + `shard_tensor` dispatch) with a single `_batch_shard_for_scatter()` that uses `torch.chunk` / `torch.split` for GPU-native batched slicing. Falls back to the per-rank loop for unusual TP classes (EmbeddingParallel, MoeIdentityExpertParallel, etc.). No measured perf change on NVLink: the bottleneck is the per-mapping sequential scatter calls, not the pre-slicing.
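A minimal sketch of the batched pre-slicing, assuming a plain sharded dimension; the real helper instead falls back to the per-rank loop for the special TP classes, and this sketch uses `torch.split` for the ragged case only to show the API:

```python
import torch

def batch_shard_for_scatter(full, world_size, dim):
    # One batched slicing call replaces the old per-rank Python loop
    # (N iterations mutating tp_layer.rank and redispatching shard_tensor).
    if full.size(dim) % world_size == 0:
        return list(torch.chunk(full, world_size, dim=dim))
    # explicit per-rank sizes when an equal chunk doesn't apply
    base, rem = divmod(full.size(dim), world_size)
    sizes = [base + (1 if r < rem else 0) for r in range(world_size)]
    return list(torch.split(full, sizes, dim=dim))

w = torch.randn(10, 6)
shards = batch_shard_for_scatter(w, world_size=4, dim=0)  # sizes 3, 3, 2, 2
```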
…broadcast
- `_redistribute_async` returns (work_handles, local_params) with `async_op=True` for all scatter/broadcast calls
- the main loop pipelines: while scatter N is in flight, the source converts mapping N+1 and kicks off its async scatter
- skip signalling uses a single-int broadcast instead of the old heavyweight `broadcast_object_list` of pickled metadata per mapping
- non-source ranks derive target shapes from the shared meta model state dict (no per-mapping metadata round-trip)
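The meta-model shape derivation can be sketched like this (the module and the row-sharded world size of 4 are illustrative): a model built on the meta device allocates no storage, but every rank still sees each parameter's shape and dtype, so a receiver can size its recv buffer without any metadata exchange.

```python
import torch
import torch.nn as nn

# Build the model skeleton on the meta device: no real storage is
# allocated, but shapes and dtypes are known identically on every rank.
with torch.device("meta"):
    meta_model = nn.Linear(4096, 1024, bias=False)

meta_sd = meta_model.state_dict()
shape = meta_sd["weight"].shape        # (out_features, in_features)

# A receiving rank sizes its shard buffer locally, e.g. row-sharded
# across world_size=4:
recv = torch.empty(shape[0] // 4, shape[1])
```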
The thread pool now lands tensors on the TP device in one shot (disk → GPU via safetensors mmap). Removes the CPU staging + per-mapping .to(device) copy that was the dominant cost on every model.
- convert runs in a background thread (`ThreadPoolExecutor(1)`): while scatter N is in flight on NCCL and finalize N-1 writes params, the CPU is already resolving futures and stacking for mapping N+1
- removed the per-mapping skip-flag broadcast (it was a sync barrier on every single mapping); SkipParameters is deterministic across ranks
- materializing directly to GPU (previous commit) plus pipeline overlap means the disk→GPU transfer is hidden behind the previous scatter

Results (8×B200, tp_plan=auto):
- Qwen2.5-7B (ws=4): 11.17s (main: 10.42s)
- Qwen2.5-14B (ws=8): 24.73s (main: 12.54s)
- Qwen3-30B-A3B MoE (ws=4): 13.83s (main: 13.56s)
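The single-worker pipeline shape can be sketched with stand-ins for convert and scatter (both hypothetical here; the real convert resolves futures and stacks tensors, and the real scatter is an async NCCL call):

```python
from concurrent.futures import ThreadPoolExecutor

def convert(mapping):
    # stand-in for mapping.convert(): future resolution + stacking on CPU
    return mapping * 2

def scatter(payload):
    # stand-in for an async NCCL scatter + wait
    return payload

mappings = list(range(5))
results = []
with ThreadPoolExecutor(max_workers=1) as pool:
    # kick off convert for mapping 0, then pipeline: while "scatter N"
    # runs below, the background thread already converts mapping N+1
    fut = pool.submit(convert, mappings[0])
    for nxt in mappings[1:] + [None]:
        payload = fut.result()
        if nxt is not None:
            fut = pool.submit(convert, nxt)   # overlaps with the scatter
        results.append(scatter(payload))
```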
- Wrap each batch of 64 mappings' redistribute calls in `torch.distributed.distributed_c10d._coalescing_manager` with `device=local_device`, which lowers per-call launch overhead via ncclGroupStart/ncclGroupEnd. Falls back to the plain loop on gloo (CPU synthetic test) since gloo has no coalescing primitive.
- Also threads `mapping.convert()` across a small thread pool (up to 4 workers) so the batch's converts run concurrently with each other while the previous batch's scatter is still in flight.
- Added a VIZTRACER_OUTPUT env var to the benchmark script for rank-0 profiling of the loading path.

Results (8×B200, tp_plan=auto):
- Qwen2.5-7B (ws=4): 13.21s (main: 10.42s)
- Qwen2.5-14B (ws=8): 22.94s (main: 12.54s)
- Qwen3-30B-A3B MoE (ws=4): 14.97s (main: 13.56s)

viztracer shows the per-collective launch overhead is gone; the remaining gap vs main is the coalesced NCCL wait itself (~8.5s on 14B). NVLink is not saturated, so further gains likely need grouping multiple tensors into a single scatter payload.
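The convert thread-pool overlap, with dummy stand-ins for `mapping.convert()` and the batched scatter (the numbers and batch layout are illustrative): the next batch's converts are submitted before the current batch's results are consumed, so they run while the previous scatter would be in flight.

```python
from concurrent.futures import ThreadPoolExecutor

def convert(mapping):
    # stand-in for mapping.convert()
    return mapping + 100

batches = [[0, 1, 2], [3, 4, 5]]
sent = []
with ThreadPoolExecutor(max_workers=4) as pool:
    futs = [pool.submit(convert, m) for m in batches[0]]
    for nxt in batches[1:] + [None]:
        ready = [f.result() for f in futs]               # this batch's converts
        if nxt is not None:
            futs = [pool.submit(convert, m) for m in nxt]  # overlap next batch
        sent.extend(ready)                               # stand-in for the scatter
```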
- Each batch of mappings is grouped by source rank, and each source packs all of its owned shards for the batch into one uint8 buffer per destination, then does a single `dist.scatter(async_op=True)`. Receivers allocate a matching uint8 recv buffer and slice it back into typed/shaped views after `wait()`. Shrinks K×world_size tiny scatters per batch to world_size big scatters.
- Removed the `_coalescing_manager` context. Profiling showed the time just shifted from per-call launch to the `_end_coalescing` wait, and it is NCCL-only (gloo had to fall back to nullcontext anyway). Plain async scatter + wait-per-handle works on every backend.
- Each batch's recv_bufs are stashed in `model._tp_recv_buffers` so the param views (which alias them) stay valid for the life of the model.
- profile (14B): dist.scatter calls 480→74, scatter time 7.3s→3.3s; coalescing wait removed.
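A sketch of the pack/unpack round-trip. The helper names are illustrative, and the unpacking assumes each shard's byte offset in the buffer is aligned for its dtype (which holds when shard byte sizes are multiples of the element size):

```python
import torch

def pack_shards(shards):
    # Flatten each shard to raw bytes and concatenate, so a whole batch
    # of tensors rides a single uint8 scatter payload.
    return torch.cat([s.contiguous().view(torch.uint8).flatten() for s in shards])

def unpack_shards(buf, shapes, dtypes):
    # Slice the received byte buffer back into typed, shaped views.
    # The views alias buf, so buf must stay alive (cf. model._tp_recv_buffers).
    out, offset = [], 0
    for shape, dtype in zip(shapes, dtypes):
        itemsize = torch.empty((), dtype=dtype).element_size()
        n = int(torch.Size(shape).numel()) * itemsize
        out.append(buf[offset:offset + n].view(dtype).view(shape))
        offset += n
    return out

a = torch.randn(4, 3)
b = torch.arange(6, dtype=torch.int64)
buf = pack_shards([a, b])
ra, rb = unpack_shards(buf, [a.shape, b.shape], [a.dtype, b.dtype])
```

Since `ra` and `rb` alias `buf`, freeing them does not free the payload; that is exactly why the recv buffers are stashed on the model.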
- Thread pool reads safetensors → CPU `pin_memory()` per batch (not all upfront). The next batch's disk reads overlap with the current batch's scatter.
- Sync DMA (pin_memory + `.to(device)`): async DMA on a dedicated stream caused silent data corruption, so this is reverted to synchronous for correctness. The pipeline overlap still works at the CPU level (thread-pool disk reads run concurrently with the GPU scatter).
- Removed the dead `_redistribute_async` function and the unused dma_stream.
- Correctness: synthetic OK, Qwen2.5-14B generates correctly.
Reverted the CUDA IPC path: multi-process IPC handle mapping overhead (cudaIpcOpenMemHandle ~38ms × 65 = 2.5s, plus all_gather_object 4.8s) made it 37s vs NCCL's 22s. Single-process cudaMemcpyPeer achieves 654 GB/s on NVLink, but multi-process IPC doesn't reach the same bandwidth.

Also tried:
- NCCL_P2P_USE_CUDA_MEMCPY=1 → hangs (known NCCL issue #1774)
- NCCL_ALLOC_P2P_NET_LL_BUFFERS + NCCL_P2P_NET_CHUNKSIZE → no change
- batch_isend_irecv → same as scatter (uses _coalescing_manager internally)
- all_to_all_single → same overhead, higher memory

NCCL packed-scatter at 22s (1.77× main's 12.5s) is the multi-process floor. The architecture is designed for future FSDP integration, where the one-rank-reads + scatter pattern becomes essential for cross-node loading.
Set BATCH_SIZE to all mappings so each source does ONE scatter per model load instead of one per batch. Reduces scatter calls from 96 → 8 for 70B.

Extensively benchmarked scatter vs all_to_all_single vs batch_isend_irecv vs CUDA IPC on Llama 3.1 70B with 8×B200. All produce the same ~11s of NCCL transfer: the floor is NVLink bisection bandwidth with 8 ranks each sending 17.5 GB of cross-traffic, not per-op overhead.

NCCL flags tested (no significant effect): NCCL_MAX_NCHANNELS=32, NCCL_PROTO=Simple, NCCL_P2P_NET_CHUNKSIZE=4M, NCCL_NCHANNELS_PER_NET_PEER=8

Results (8×B200, tp_plan=auto):
- Llama-3.1-70B: main 16.49s, refactor 25.96s (1.57×)
- Qwen2.5-72B: main 17.62s, refactor 26.92s (1.53×)
- Generate is 2.5× faster on the refactor (both 70B models)
- batch_isend_irecv (one ncclGroup per batch, no individual deadlocks)
- Pipeline: while batch N's P2P ops run on NCCL, finalize(N-1) unpacks params and the next batch's disk reads start
- BATCH_SIZE=max(len/4, 64) for ~4 pipeline stages
- Individual isend/irecv deadlocks (NCCL requires an ncclGroup for concurrent P2P, confirmed by testing and the NCCL docs)

Results (8×B200, tp_plan=auto):
- Llama-3.1-70B: 26.31s load + 2.48s gen (main: 16.49s + 6.59s)
- Generate 2.5× faster; load 1.6× slower (10s is the NCCL transfer floor)
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45453&sha=63d748
What does this PR do?
Ai init