diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index f517d4e6a..50ad3e291 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -622,6 +622,58 @@ mod tests { assert!(done_rx.recv().await.unwrap()); } + #[tokio::test] + async fn test_pingpong_full_mesh() { + use hyperactor::test_utils::pingpong::PingPongActor; + use hyperactor::test_utils::pingpong::PingPongActorParams; + use hyperactor::test_utils::pingpong::PingPongMessage; + + use futures::future::join_all; + + const X: usize = 3; + const Y: usize = 3; + const Z: usize = 3; + let alloc = $allocator + .allocate(AllocSpec { + shape: shape! { x = X, y = Y, z = Z }, + constraints: Default::default(), + }) + .await + .unwrap(); + + let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); + let (undeliverable_tx, _undeliverable_rx) = proc_mesh.client().open_port(); + let params = PingPongActorParams::new(undeliverable_tx.bind(), None); + let actor_mesh = proc_mesh.spawn::("pingpong", ¶ms).await.unwrap(); + let slice = actor_mesh.shape().slice(); + + let mut futures = Vec::new(); + for rank in slice.iter() { + let actor = actor_mesh.get(rank).unwrap(); + let coords = (&slice.coordinates(rank).unwrap()[..]).try_into().unwrap(); + let sizes = (&slice.sizes())[..].try_into().unwrap(); + let neighbors = ndslice::utils::stencil::moore_neighbors::<3>(); + for neighbor_coords in ndslice::utils::apply_stencil(&coords, sizes, &neighbors) { + if let Ok(neighbor_rank) = slice.location(&neighbor_coords) { + let neighbor = actor_mesh.get(neighbor_rank).unwrap(); + let (done_tx, done_rx) = proc_mesh.client().open_once_port(); + actor + .send( + proc_mesh.client(), + PingPongMessage(4, neighbor.clone(), done_tx.bind()), + ) + .unwrap(); + futures.push(done_rx.recv()); + } + } + } + let results = join_all(futures).await; + assert_eq!(results.len(), 316); // 5180 messages + for result in results { + assert_eq!(result.unwrap(), true); + } + } + #[tokio::test] async fn test_cast() { let alloc = $allocator