Add: L3 broadcast and all-to-all distributed collectives#888
Add: L3 broadcast and all-to-all distributed collectives#888georgebisbas wants to merge 3 commits into
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughAdded two new L3 worker examples demonstrating distributed communication patterns. All-to-all implements symmetric 3-phase exchange with per-rank staging, cross-rank synchronization, and remote scratch reads. Broadcast implements 3-phase communication where root stages data, all ranks synchronize, then read from root. Both include AICORE kernels, C++ orchestration shims, Python drivers with CLI interfaces, and parametrized integration tests. ChangesAll-to-All Distributed Exchange
Broadcast Distributed Communication
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces new distributed communication examples for Level 3 workers, specifically implementing all_to_all_distributed and broadcast_distributed. Each example includes a symmetric 3-phase C++ kernel utilizing HCCL-window scratch patterns, an orchestration shim, a Python main script for end-to-end execution, and corresponding pytest suites. The main README has also been updated to document these additions. There are no review comments, so no feedback is provided.
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In
`@examples/workers/l3/all_to_all_distributed/kernels/aiv/all_to_all_kernel.cpp`:
- Around line 108-118: The barrier can pass early because signal slots may be
non-zero; fix by making waits monotonic: read the current local counter from
pto::comm::Signal(signal_base + my_rank) to compute a per-phase target = current
+ 1, then perform the remote increments with CommRemotePtr/pto::comm::TNOTIFY as
before and change the waits to TWAIT(..., target, ...) against each peer's
signal_base slot (use the same target for all peers) instead of waiting for >=
1; reference pto::comm::Signal, CommRemotePtr, TNOTIFY and TWAIT to locate and
update the logic.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: ad889782-193a-43a0-8a91-0183409f4553
📒 Files selected for processing (11)
examples/workers/l3/README.mdexamples/workers/l3/all_to_all_distributed/__init__.pyexamples/workers/l3/all_to_all_distributed/kernels/aiv/all_to_all_kernel.cppexamples/workers/l3/all_to_all_distributed/kernels/orchestration/all_to_all_orch.cppexamples/workers/l3/all_to_all_distributed/main.pyexamples/workers/l3/all_to_all_distributed/test_all_to_all.pyexamples/workers/l3/broadcast_distributed/__init__.pyexamples/workers/l3/broadcast_distributed/kernels/aiv/broadcast_kernel.cppexamples/workers/l3/broadcast_distributed/kernels/orchestration/broadcast_orch.cppexamples/workers/l3/broadcast_distributed/main.pyexamples/workers/l3/broadcast_distributed/test_broadcast.py
| for (int peer = 0; peer < nranks; ++peer) { | ||
| if (peer == my_rank) continue; | ||
| __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer); | ||
| pto::comm::Signal sig(remote_signal); | ||
| pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd); | ||
| } | ||
| for (int peer = 0; peer < nranks; ++peer) { | ||
| if (peer == my_rank) continue; | ||
| pto::comm::Signal sig(signal_base + peer); | ||
| pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE); | ||
| } |
There was a problem hiding this comment.
Phase-2 barrier can pass early due to stale signal counters.
At Line [117], waiting for >= 1 assumes each signal_base[...] slot starts at zero. These slots are never initialized here, so reused/non-zero tail memory can satisfy waits immediately and break cross-rank synchronization.
💡 Proposed fix (monotonic wait targets)
+ int32_t wait_target[kMaxSupportedRanks];
+ for (int peer = 0; peer < nranks; ++peer) {
+ wait_target[peer] = signal_base[peer];
+ if (peer != my_rank) {
+ wait_target[peer] += 1;
+ }
+ }
+
for (int peer = 0; peer < nranks; ++peer) {
if (peer == my_rank) continue;
__gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer);
pto::comm::Signal sig(remote_signal);
pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd);
}
for (int peer = 0; peer < nranks; ++peer) {
if (peer == my_rank) continue;
pto::comm::Signal sig(signal_base + peer);
- pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE);
+ pto::comm::TWAIT(sig, wait_target[peer], pto::comm::WaitCmp::GE);
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| for (int peer = 0; peer < nranks; ++peer) { | |
| if (peer == my_rank) continue; | |
| __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer); | |
| pto::comm::Signal sig(remote_signal); | |
| pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd); | |
| } | |
| for (int peer = 0; peer < nranks; ++peer) { | |
| if (peer == my_rank) continue; | |
| pto::comm::Signal sig(signal_base + peer); | |
| pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE); | |
| } | |
| int32_t wait_target[kMaxSupportedRanks]; | |
| for (int peer = 0; peer < nranks; ++peer) { | |
| wait_target[peer] = signal_base[peer]; | |
| if (peer != my_rank) { | |
| wait_target[peer] += 1; | |
| } | |
| } | |
| for (int peer = 0; peer < nranks; ++peer) { | |
| if (peer == my_rank) continue; | |
| __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer); | |
| pto::comm::Signal sig(remote_signal); | |
| pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd); | |
| } | |
| for (int peer = 0; peer < nranks; ++peer) { | |
| if (peer == my_rank) continue; | |
| pto::comm::Signal sig(signal_base + peer); | |
| pto::comm::TWAIT(sig, wait_target[peer], pto::comm::WaitCmp::GE); | |
| } |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@examples/workers/l3/all_to_all_distributed/kernels/aiv/all_to_all_kernel.cpp`
around lines 108 - 118, The barrier can pass early because signal slots may be
non-zero; fix by making waits monotonic: read the current local counter from
pto::comm::Signal(signal_base + my_rank) to compute a per-phase target = current
+ 1, then perform the remote increments with CommRemotePtr/pto::comm::TNOTIFY as
before and change the waits to TWAIT(..., target, ...) against each peer's
signal_base slot (use the same target for all peers) instead of waiting for >=
1; reference pto::comm::Signal, CommRemotePtr, TNOTIFY and TWAIT to locate and
update the logic.
Complete the canonical collective set with two new examples that follow the existing scratch-window + TNOTIFY/TWAIT pattern used by allgather and reduce-scatter. - broadcast_distributed: root stages, barrier, all ranks read root scratch - all_to_all_distributed: dest-indexed scratch staging and peer gather - pytest wrappers parametrize 2 and 4 devices on a2a3sim/a2a3/a5sim - README: index allgather, reduce_scatter, broadcast, and all_to_all rows
910d8ff to
e7c8e25
Compare
L3 subprocesses fork chip children and load torch/libomp; running several in parallel on macos-latest has caused sporadic SIGABRT flakes in unrelated collectives. Linux sim jobs keep --max-parallel auto.
Pin Linux st-sim jobs below auto to reduce L3 resource-phase native flakes while keeping macOS at --max-parallel 1. Document both caps in docs/ci.md.
Complete the canonical collective set with two new examples that follow the existing scratch-window + TNOTIFY/TWAIT pattern used by allgather and reduce-scatter.