From 83a56c0ccce71b133fd44816997372cd8b54e25c Mon Sep 17 00:00:00 2001 From: Dennis van der Staay Date: Thu, 20 Nov 2025 19:55:18 -0800 Subject: [PATCH 1/2] support concurrency (#1943) Summary: TL;DR: BEFORE: controlled flow by requiring python caller to obtain a QP ownership and hold for duration of call (.read_from/.write_into) AFTER: now we can cheaply clone QPs, and just use atomics to generate wr_id, and rely on ibverbs internal locks (ibv_post_send is thread-safe). Complexity introduced by Work completion events which may be returned out of order and only delivered once, so need to store any WC in seperate cache. ### Atomic Counters in rdmaxcel_qp_t for Lock-Free Operations The rdmaxcel_qp_t wrapper uses atomic counters to enable concurrent, lock-free work request posting: ``` typedef struct rdmaxcel_qp { struct ibv_qp* ibv_qp; struct ibv_cq* send_cq; struct ibv_cq* recv_cq; // Atomic counters for lock-free concurrent access _Atomic uint64_t send_wqe_idx; // Next send WQE slot _Atomic uint64_t send_db_idx; // Last doorbell rung _Atomic uint64_t recv_wqe_idx; // Next recv WQE slot _Atomic uint64_t recv_db_idx; // Last recv doorbell _Atomic uint64_t rts_timestamp; // Ready-to-send timestamp // Completion caches for efficient polling completion_cache_t* send_completion_cache; completion_cache_t* recv_completion_cache; } rdmaxcel_qp_t; ``` Key Benefits: Multiple threads can post work requests concurrently using fetch_add on atomic indices No locks needed for the hot path (posting operations) Each thread gets a unique WQE slot atomically Completion polling uses cached results to avoid redundant CQ polls ### Mutex-Protected Queue Pair Creation While operations are lock-free, QP creation is serialized using Rust Arc>: ``` pub struct RdmaManagerActor { // Track QPs currently being created to prevent duplicate creation pending_qp_creation: Arc>>, // ... } ``` Creation Flow: Thread checks if QP exists (lock-free read from HashMap) If not, acquires mutex and checks pending_qp_creation set If another thread is creating it, waits without holding lock Otherwise, inserts key into set, releases lock, and creates QP After creation, removes key from set This prevents race conditions where multiple threads try to create the same QP simultaneously while keeping the common path (using existing QPs) lock-free. ### Resource Lifecycle Management Simplified cleanup via rdmaxcel_qp_destroy: Previously: Rust manually destroyed ibv_qp and CQs separately (error-prone with concurrent access) Now: Single C function destroys all resources atomically Changed register_segments(pd, rdmaxcel_qp_t*) to work with wrapper instead of raw ibv_qp Reviewed By: casteryh Differential Revision: D87021168 --- monarch_rdma/Cargo.toml | 1 + monarch_rdma/src/rdma_components.rs | 511 +++++++++++-------- monarch_rdma/src/rdma_manager_actor.rs | 451 ++++++++-------- monarch_rdma/src/rdma_manager_actor_tests.rs | 212 ++++++-- monarch_rdma/src/test_utils.rs | 191 ++++--- rdmaxcel-sys/build.rs | 9 + rdmaxcel-sys/src/rdmaxcel.c | 367 ++++++++++++- rdmaxcel-sys/src/rdmaxcel.cpp | 10 +- rdmaxcel-sys/src/rdmaxcel.h | 131 ++++- 9 files changed, 1339 insertions(+), 544 deletions(-) diff --git a/monarch_rdma/Cargo.toml b/monarch_rdma/Cargo.toml index d4a70a5c9..d9176026d 100644 --- a/monarch_rdma/Cargo.toml +++ b/monarch_rdma/Cargo.toml @@ -14,6 +14,7 @@ edition = "2024" anyhow = "1.0.98" async-trait = "0.1.86" cuda-sys = { path = "../cuda-sys" } +futures = { version = "0.3.31", features = ["async-await", "compat"] } hyperactor = { version = "0.0.0", path = "../hyperactor" } rand = { version = "0.8", features = ["small_rng"] } rdmaxcel-sys = { path = "../rdmaxcel-sys" } diff --git a/monarch_rdma/src/rdma_components.rs b/monarch_rdma/src/rdma_components.rs index eaac40b86..7c600fe9e 100644 --- a/monarch_rdma/src/rdma_components.rs +++ b/monarch_rdma/src/rdma_components.rs @@ -144,9 +144,9 @@ impl RdmaBuffer { ) .await?; - qp.put(self.clone(), remote)?; + let wr_id = qp.put(self.clone(), remote)?; let result = self - .wait_for_completion(&mut qp, PollTarget::Send, timeout) + .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout) .await; // Release the queue pair back to the actor @@ -197,9 +197,9 @@ impl RdmaBuffer { remote_device.clone(), ) .await?; - qp.get(self.clone(), remote)?; + let wr_id = qp.get(self.clone(), remote)?; let result = self - .wait_for_completion(&mut qp, PollTarget::Send, timeout) + .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout) .await; // Release the queue pair back to the actor @@ -209,34 +209,47 @@ impl RdmaBuffer { result } - /// Waits for the completion of an RDMA operation. + /// Waits for the completion of RDMA operations. /// - /// This method polls the completion queue until the specified work request completes + /// This method polls the completion queue until all specified work requests complete /// or until the timeout is reached. /// /// # Arguments /// * `qp` - The RDMA Queue Pair to poll for completion + /// * `poll_target` - Which CQ to poll (Send or Recv) + /// * `expected_wr_ids` - The work request IDs to wait for /// * `timeout` - Timeout in seconds for the RDMA operation to complete. /// /// # Returns - /// `Ok(true)` if the operation completes successfully within the timeout, + /// `Ok(true)` if all operations complete successfully within the timeout, /// or an error if the timeout is reached async fn wait_for_completion( &self, qp: &mut RdmaQueuePair, poll_target: PollTarget, + expected_wr_ids: &[u64], timeout: u64, ) -> Result { let timeout = Duration::from_secs(timeout); let start_time = std::time::Instant::now(); + let mut remaining: std::collections::HashSet = + expected_wr_ids.iter().copied().collect(); + while start_time.elapsed() < timeout { - match qp.poll_completion_target(poll_target) { - Ok(Some(_wc)) => { - tracing::debug!("work completed"); - return Ok(true); - } - Ok(None) => { + if remaining.is_empty() { + return Ok(true); + } + + let wr_ids_to_poll: Vec = remaining.iter().copied().collect(); + match qp.poll_completion(poll_target, &wr_ids_to_poll) { + Ok(completions) => { + for (wr_id, _wc) in completions { + remaining.remove(&wr_id); + } + if remaining.is_empty() { + return Ok(true); + } RealClock.sleep(Duration::from_millis(1)).await; } Err(e) => { @@ -251,10 +264,14 @@ impl RdmaBuffer { } } } - tracing::error!("timed out while waiting on request completion"); + tracing::error!( + "timed out while waiting on request completion for wr_ids={:?}", + remaining + ); Err(anyhow::anyhow!( - "[buffer({:?})] rdma operation did not complete in time", - self + "[buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})", + self, + expected_wr_ids )) } @@ -287,6 +304,7 @@ impl RdmaBuffer { /// /// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device. /// * `pd`: A pointer to the protection domain, which provides isolation between different connections. +#[derive(Clone)] pub struct RdmaDomain { pub context: *mut rdmaxcel_sys::ibv_context, pub pd: *mut rdmaxcel_sys::ibv_pd, @@ -446,24 +464,23 @@ pub enum PollTarget { /// 3. Exchange connection info with remote peer (application must handle this) /// 4. Connect to remote endpoint with `connect()` /// 5. Perform RDMA operations with `put()` or `get()` -/// 6. Poll for completions with `poll_send_completion()` or `poll_recv_completion()` +/// 6. Poll for completions with `poll_send_completion(wr_id)` or `poll_recv_completion(wr_id)` +/// +/// # Notes +/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`) +/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally +/// - This makes RdmaQueuePair trivially Clone and Serialize +/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer #[derive(Debug, Serialize, Deserialize, Named, Clone)] pub struct RdmaQueuePair { pub send_cq: usize, // *mut rdmaxcel_sys::ibv_cq, pub recv_cq: usize, // *mut rdmaxcel_sys::ibv_cq, - pub qp: usize, // *mut rdmaxcel_sys::ibv_qp, + pub qp: usize, // *mut rdmaxcel_sys::rdmaxcel_qp_t pub dv_qp: usize, // *mut rdmaxcel_sys::mlx5dv_qp, pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq, pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq, context: usize, // *mut rdmaxcel_sys::ibv_context, config: IbverbsConfig, - pub send_wqe_idx: u64, - pub send_db_idx: u64, - pub send_cq_idx: u64, - pub recv_wqe_idx: u64, - pub recv_db_idx: u64, - pub recv_cq_idx: u64, - rts_timestamp: u64, } impl RdmaQueuePair { @@ -472,22 +489,26 @@ impl RdmaQueuePair { /// This ensures the hardware has sufficient time to settle after reaching /// Ready-to-Send state before the first actual operation. fn apply_first_op_delay(&self, wr_id: u64) { - if wr_id == 0 { - assert!( - self.rts_timestamp != u64::MAX, - "First operation attempted before queue pair reached RTS state! Call connect() first." - ); - let current_nanos = RealClock - .system_time_now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_nanos() as u64; - let elapsed_nanos = current_nanos - self.rts_timestamp; - let elapsed = Duration::from_nanos(elapsed_nanos); - let init_delay = Duration::from_millis(self.config.hw_init_delay_ms); - if elapsed < init_delay { - let remaining_delay = init_delay - elapsed; - sleep(remaining_delay); + unsafe { + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + if wr_id == 0 { + let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp); + assert!( + rts_timestamp != u64::MAX, + "First operation attempted before queue pair reached RTS state! Call connect() first." + ); + let current_nanos = RealClock + .system_time_now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() as u64; + let elapsed_nanos = current_nanos - rts_timestamp; + let elapsed = Duration::from_nanos(elapsed_nanos); + let init_delay = Duration::from_millis(self.config.hw_init_delay_ms); + if elapsed < init_delay { + let remaining_delay = init_delay - elapsed; + sleep(remaining_delay); + } } } } @@ -521,8 +542,7 @@ impl RdmaQueuePair { unsafe { // Resolve Auto to a concrete QP type based on device capabilities let resolved_qp_type = resolve_qp_type(config.qp_type); - - let qp = rdmaxcel_sys::create_qp( + let qp = rdmaxcel_sys::rdmaxcel_qp_create( context, pd, config.cq_entries, @@ -541,18 +561,18 @@ impl RdmaQueuePair { )); } - let send_cq = (*qp).send_cq; - let recv_cq = (*qp).recv_cq; + let send_cq = (*(*qp).ibv_qp).send_cq; + let recv_cq = (*(*qp).ibv_qp).recv_cq; // mlx5dv provider APIs - let dv_qp = rdmaxcel_sys::create_mlx5dv_qp(qp); - let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq(qp); - let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq(qp); + let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp); + let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp); + let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp); if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() { - rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq); - rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq); - rdmaxcel_sys::ibv_destroy_qp(qp); + rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq); + rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq); + rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp); return Err(anyhow::anyhow!( "failed to init mlx5dv_qp or completion queues" )); @@ -562,9 +582,9 @@ impl RdmaQueuePair { if config.use_gpu_direct { let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq); if ret != 0 { - rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq); - rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq); - rdmaxcel_sys::ibv_destroy_qp(qp); + rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq); + rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq); + rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp); return Err(anyhow::anyhow!( "failed to register GPU Direct RDMA memory: {:?}", ret @@ -575,18 +595,11 @@ impl RdmaQueuePair { send_cq: send_cq as usize, recv_cq: recv_cq as usize, qp: qp as usize, - dv_qp: dv_qp as usize, + dv_qp: qp as usize, dv_send_cq: dv_send_cq as usize, dv_recv_cq: dv_recv_cq as usize, context: context as usize, config, - recv_db_idx: 0, - recv_wqe_idx: 0, - recv_cq_idx: 0, - send_db_idx: 0, - send_wqe_idx: 0, - send_cq_idx: 0, - rts_timestamp: u64::MAX, }) } } @@ -615,7 +628,7 @@ impl RdmaQueuePair { // - The memory address provided is only stored, not dereferenced in this function unsafe { let context = self.context as *mut rdmaxcel_sys::ibv_context; - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let mut port_attr = rdmaxcel_sys::ibv_port_attr::default(); let errno = rdmaxcel_sys::ibv_query_port( context, @@ -642,7 +655,7 @@ impl RdmaQueuePair { } Ok(RdmaQpInfo { - qp_num: (*qp).qp_num, + qp_num: (*(*qp).ibv_qp).qp_num, lid: port_attr.lid, gid: Some(gid), psn: self.config.psn, @@ -653,7 +666,7 @@ impl RdmaQueuePair { pub fn state(&mut self) -> Result { // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls. unsafe { - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let mut qp_attr = rdmaxcel_sys::ibv_qp_attr { ..Default::default() }; @@ -661,8 +674,12 @@ impl RdmaQueuePair { ..Default::default() }; let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE; - let errno = - rdmaxcel_sys::ibv_query_qp(qp, &mut qp_attr, mask.0 as i32, &mut qp_init_attr); + let errno = rdmaxcel_sys::ibv_query_qp( + (*qp).ibv_qp, + &mut qp_attr, + mask.0 as i32, + &mut qp_init_attr, + ); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!("failed to query QP state: {}", os_error)); @@ -687,7 +704,7 @@ impl RdmaQueuePair { // 4. Memory access is properly bounded by the registered memory regions unsafe { // Transition to INIT - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE @@ -707,7 +724,7 @@ impl RdmaQueuePair { | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS; - let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -755,7 +772,7 @@ impl RdmaQueuePair { | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER; - let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -782,7 +799,7 @@ impl RdmaQueuePair { | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC; - let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -796,56 +813,66 @@ impl RdmaQueuePair { ); // Record RTS time now that the queue pair is ready to send - self.rts_timestamp = RealClock + let rts_timestamp_nanos = RealClock .system_time_now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_nanos() as u64; + rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos); Ok(()) } } - pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> { - let idx = self.recv_wqe_idx; - self.recv_wqe_idx += 1; - self.send_wqe( - 0, - lhandle.lkey, - 0, - idx, - true, - RdmaOperation::Recv, - 0, - rhandle.rkey, - ) - .unwrap(); - Ok(()) + pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result { + unsafe { + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp); + self.post_op( + 0, + lhandle.lkey, + 0, + idx, + true, + RdmaOperation::Recv, + 0, + rhandle.rkey, + ) + .unwrap(); + rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp); + Ok(idx) + } } pub fn put_with_recv( &mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer, - ) -> Result<(), anyhow::Error> { - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; - self.post_op( - lhandle.addr, - lhandle.lkey, - lhandle.size, - idx, - true, - RdmaOperation::WriteWithImm, - rhandle.addr, - rhandle.rkey, - ) - .unwrap(); - self.send_db_idx += 1; - Ok(()) + ) -> Result, anyhow::Error> { + unsafe { + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp); + self.post_op( + lhandle.addr, + lhandle.lkey, + lhandle.size, + idx, + true, + RdmaOperation::WriteWithImm, + rhandle.addr, + rhandle.rkey, + ) + .unwrap(); + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp); + Ok(vec![idx]) + } } - pub fn put(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> { + pub fn put( + &mut self, + lhandle: RdmaBuffer, + rhandle: RdmaBuffer, + ) -> Result, anyhow::Error> { let total_size = lhandle.size; if rhandle.size < total_size { return Err(anyhow::anyhow!( @@ -857,10 +884,15 @@ impl RdmaQueuePair { let mut remaining = total_size; let mut offset = 0; + let mut wr_ids = Vec::new(); while remaining > 0 { let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE); - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + wr_ids.push(idx); self.post_op( lhandle.addr + offset, lhandle.lkey, @@ -871,13 +903,17 @@ impl RdmaQueuePair { rhandle.addr + offset, rhandle.rkey, )?; - self.send_db_idx += 1; + unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ); + } remaining -= chunk_size; offset += chunk_size; } - Ok(()) + Ok(wr_ids) } /// Get a doorbell for the queue pair. @@ -890,19 +926,23 @@ impl RdmaQueuePair { /// * `Result` - A doorbell for the queue pair pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> { unsafe { + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; let base_ptr = (*dv_qp).sq.buf as *mut u8; let wqe_cnt = (*dv_qp).sq.wqe_cnt; let stride = (*dv_qp).sq.stride; - if (wqe_cnt as u64) < (self.send_wqe_idx - self.send_db_idx) { + let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp); + let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp); + if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) { return Err(anyhow::anyhow!("Overflow of WQE, possible data loss")); } - self.apply_first_op_delay(self.send_db_idx); - while self.send_db_idx < self.send_wqe_idx { - let offset = (self.send_db_idx % wqe_cnt as u64) * stride as u64; + self.apply_first_op_delay(send_db_idx); + while send_db_idx < send_wqe_idx { + let offset = (send_db_idx % wqe_cnt as u64) * stride as u64; let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize); rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void); - self.send_db_idx += 1; + send_db_idx += 1; + rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx); } Ok(()) } @@ -920,14 +960,18 @@ impl RdmaQueuePair { /// /// # Returns /// - /// * `Result<(), anyhow::Error>` - Success or error + /// * `Result, anyhow::Error>` - The work request IDs or error pub fn enqueue_put( &mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer, - ) -> Result<(), anyhow::Error> { - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + ) -> Result, anyhow::Error> { + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + self.send_wqe( lhandle.addr, lhandle.lkey, @@ -938,7 +982,7 @@ impl RdmaQueuePair { rhandle.addr, rhandle.rkey, )?; - Ok(()) + Ok(vec![idx]) } /// Enqueues a put with receive operation without ringing the doorbell. @@ -953,14 +997,18 @@ impl RdmaQueuePair { /// /// # Returns /// - /// * `Result<(), anyhow::Error>` - Success or error + /// * `Result, anyhow::Error>` - The work request IDs or error pub fn enqueue_put_with_recv( &mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer, - ) -> Result<(), anyhow::Error> { - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + ) -> Result, anyhow::Error> { + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + self.send_wqe( lhandle.addr, lhandle.lkey, @@ -971,7 +1019,7 @@ impl RdmaQueuePair { rhandle.addr, rhandle.rkey, )?; - Ok(()) + Ok(vec![idx]) } /// Enqueues a get operation without ringing the doorbell. @@ -986,14 +1034,18 @@ impl RdmaQueuePair { /// /// # Returns /// - /// * `Result<(), anyhow::Error>` - Success or error + /// * `Result, anyhow::Error>` - The work request IDs or error pub fn enqueue_get( &mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer, - ) -> Result<(), anyhow::Error> { - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + ) -> Result, anyhow::Error> { + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + self.send_wqe( lhandle.addr, lhandle.lkey, @@ -1004,10 +1056,14 @@ impl RdmaQueuePair { rhandle.addr, rhandle.rkey, )?; - Ok(()) + Ok(vec![idx]) } - pub fn get(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> { + pub fn get( + &mut self, + lhandle: RdmaBuffer, + rhandle: RdmaBuffer, + ) -> Result, anyhow::Error> { let total_size = lhandle.size; if rhandle.size < total_size { return Err(anyhow::anyhow!( @@ -1019,11 +1075,17 @@ impl RdmaQueuePair { let mut remaining = total_size; let mut offset = 0; + let mut wr_ids = Vec::new(); while remaining > 0 { let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE); - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + wr_ids.push(idx); + self.post_op( lhandle.addr + offset, lhandle.lkey, @@ -1034,13 +1096,17 @@ impl RdmaQueuePair { rhandle.addr + offset, rhandle.rkey, )?; - self.send_db_idx += 1; + unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ); + } remaining -= chunk_size; offset += chunk_size; } - Ok(()) + Ok(wr_ids) } /// Posts a request to the queue pair. @@ -1073,7 +1139,7 @@ impl RdmaQueuePair { // - The ibverbs post_send operation follows the documented API contract // - Error codes from the device are properly checked and propagated unsafe { - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let context = self.context as *mut rdmaxcel_sys::ibv_context; let ops = &mut (*context).ops; let errno; @@ -1090,7 +1156,8 @@ impl RdmaQueuePair { ..Default::default() }; let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut(); - errno = ops.post_recv.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr); + errno = + ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr); } else if op_type == RdmaOperation::Write || op_type == RdmaOperation::Read || op_type == RdmaOperation::WriteWithImm @@ -1124,7 +1191,8 @@ impl RdmaQueuePair { wr.wr.rdma.rkey = rkey; let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut(); - errno = ops.post_send.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr); + errno = + ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr); } else { panic!("Not Implemented"); } @@ -1166,7 +1234,7 @@ impl RdmaQueuePair { RdmaOperation::Recv => 0, }; - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; let _dv_cq = if op_type == RdmaOperation::Recv { self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq @@ -1191,7 +1259,7 @@ impl RdmaQueuePair { op_type: op_type_val, raddr, rkey, - qp_num: (*qp).qp_num, + qp_num: (*(*qp).ibv_qp).qp_num, buf, dbrec: (*dv_qp).dbrec, wqe_cnt: (*dv_qp).sq.wqe_cnt, @@ -1214,114 +1282,123 @@ impl RdmaQueuePair { } } - /// Poll for completions on the specified completion queue(s) + /// Poll for work completions by wr_ids. /// /// # Arguments /// - /// * `target` - Which completion queue(s) to poll (Send, Receive) + /// * `target` - Which completion queue to poll (Send, Receive) + /// * `expected_wr_ids` - Slice of work request IDs to wait for /// /// # Returns /// - /// * `Ok(Some(wc))` - A completion was found - /// * `Ok(None)` - No completion was found + /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs that were found /// * `Err(e)` - An error occurred - pub fn poll_completion_target( + pub fn poll_completion( &mut self, target: PollTarget, - ) -> Result, anyhow::Error> { + expected_wr_ids: &[u64], + ) -> Result, anyhow::Error> { + if expected_wr_ids.is_empty() { + return Ok(Vec::new()); + } + unsafe { - let context = self.context as *mut rdmaxcel_sys::ibv_context; - let _outstanding_wqe = - self.send_db_idx + self.recv_db_idx - self.send_cq_idx - self.recv_cq_idx; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + let qp_num = (*(*qp).ibv_qp).qp_num; + + let (cq, cache, cq_type) = match target { + PollTarget::Send => ( + self.send_cq as *mut rdmaxcel_sys::ibv_cq, + rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp), + "send", + ), + PollTarget::Recv => ( + self.recv_cq as *mut rdmaxcel_sys::ibv_cq, + rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp), + "recv", + ), + }; - // Check for send completions if requested - if (target == PollTarget::Send) && self.send_db_idx > self.send_cq_idx { - let send_cq = self.send_cq as *mut rdmaxcel_sys::ibv_cq; - let ops = &mut (*context).ops; - let mut wc = std::mem::MaybeUninit::::zeroed().assume_init(); - let ret = ops.poll_cq.as_mut().unwrap()(send_cq, 1, &mut wc); + let mut results = Vec::new(); - if ret < 0 { - return Err(anyhow::anyhow!( - "Failed to poll send CQ: {}", - Error::last_os_error() - )); - } + // Single-shot poll: check each wr_id once and return what we find + for &expected_wr_id in expected_wr_ids { + let mut poll_ctx = rdmaxcel_sys::poll_context_t { + expected_wr_id, + expected_qp_num: qp_num, + cache, + cq, + }; - if ret > 0 { - if !wc.is_valid() { - if let Some((status, vendor_err)) = wc.error() { - return Err(anyhow::anyhow!( - "Send work completion failed with status: {:?}, vendor error: {}, wr_id: {}, send_cq_idx: {}", - status, - vendor_err, - wc.wr_id(), - self.send_cq_idx, - )); + let mut wc = std::mem::MaybeUninit::::zeroed().assume_init(); + let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc); + + match ret { + 1 => { + // Found completion + if !wc.is_valid() { + if let Some((status, vendor_err)) = wc.error() { + return Err(anyhow::anyhow!( + "{} completion failed for wr_id={}: status={:?}, vendor_err={}", + cq_type, + expected_wr_id, + status, + vendor_err, + )); + } } + results.push((expected_wr_id, IbvWc::from(wc))); } - - // This should be a send completion - verify it's the one we're waiting for - if wc.wr_id() == self.send_cq_idx { - self.send_cq_idx += 1; - } - // finished polling, return the last completion - if self.send_cq_idx == self.send_db_idx { - return Ok(Some(IbvWc::from(wc))); + 0 => { + // Not found yet - this is fine for single-shot poll } - } - } - - // Check for receive completions if requested - if (target == PollTarget::Recv) && self.recv_db_idx > self.recv_cq_idx { - let recv_cq = self.recv_cq as *mut rdmaxcel_sys::ibv_cq; - let ops = &mut (*context).ops; - let mut wc = std::mem::MaybeUninit::::zeroed().assume_init(); - let ret = ops.poll_cq.as_mut().unwrap()(recv_cq, 1, &mut wc); - - if ret < 0 { - return Err(anyhow::anyhow!( - "Failed to poll receive CQ: {}", - Error::last_os_error() - )); - } - - if ret > 0 { - if !wc.is_valid() { + -17 => { + // RDMAXCEL_COMPLETION_FAILED: Completion found but failed - wc contains the error details + let error_msg = + std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret)) + .to_str() + .unwrap_or("Unknown error"); if let Some((status, vendor_err)) = wc.error() { return Err(anyhow::anyhow!( - "Recv work completion failed with status: {:?}, vendor error: {}, wr_id: {}, send_cq_idx: {}", + "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]", + cq_type, + expected_wr_id, + error_msg, status, vendor_err, - wc.wr_id(), - self.recv_cq_idx, + wc.qp_num, + wc.len(), + )); + } else { + return Err(anyhow::anyhow!( + "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]", + cq_type, + expected_wr_id, + error_msg, + wc.qp_num, + wc.len(), )); } } - - // This should be a send completion - verify it's the one we're waiting for - if wc.wr_id() == self.recv_cq_idx { - self.recv_cq_idx += 1; - } - // finished polling, return the last completion - if self.recv_cq_idx == self.recv_db_idx { - return Ok(Some(IbvWc::from(wc))); + _ => { + // Other errors + let error_msg = + std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret)) + .to_str() + .unwrap_or("Unknown error"); + return Err(anyhow::anyhow!( + "Failed to poll {} CQ for wr_id={}: {}", + cq_type, + expected_wr_id, + error_msg + )); } } } - // No completion found - Ok(None) + Ok(results) } } - - pub fn poll_send_completion(&mut self) -> Result, anyhow::Error> { - self.poll_completion_target(PollTarget::Send) - } - - pub fn poll_recv_completion(&mut self) -> Result, anyhow::Error> { - self.poll_completion_target(PollTarget::Recv) - } } /// Utility to validate execution context. diff --git a/monarch_rdma/src/rdma_manager_actor.rs b/monarch_rdma/src/rdma_manager_actor.rs index 2083d3fa3..936e2620f 100644 --- a/monarch_rdma/src/rdma_manager_actor.rs +++ b/monarch_rdma/src/rdma_manager_actor.rs @@ -28,8 +28,13 @@ //! //! See test examples: `test_rdma_write_loopback` and `test_rdma_read_loopback`. use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; use async_trait::async_trait; +use futures::lock::Mutex; use hyperactor::Actor; use hyperactor::ActorId; use hyperactor::ActorRef; @@ -40,6 +45,7 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::RefClient; +use hyperactor::clock::Clock; use hyperactor::supervision::ActorSupervisionEvent; use serde::Deserialize; use serde::Serialize; @@ -48,6 +54,7 @@ use crate::ibverbs_primitives::IbverbsConfig; use crate::ibverbs_primitives::RdmaMemoryRegionView; use crate::ibverbs_primitives::RdmaQpInfo; use crate::ibverbs_primitives::ibverbs_supported; +use crate::ibverbs_primitives::mlx5dv_supported; use crate::ibverbs_primitives::resolve_qp_type; use crate::rdma_components::RdmaBuffer; use crate::rdma_components::RdmaDomain; @@ -55,13 +62,6 @@ use crate::rdma_components::RdmaQueuePair; use crate::rdma_components::get_registered_cuda_segments; use crate::validate_execution_context; -/// Represents the state of a queue pair in the manager, either available or checked out. -#[derive(Debug, Clone)] -pub enum QueuePairState { - Available(RdmaQueuePair), - CheckedOut, -} - /// Helper function to get detailed error messages from RDMAXCEL error codes pub fn get_rdmaxcel_error_message(error_code: i32) -> String { unsafe { @@ -128,6 +128,14 @@ pub enum RdmaManagerMessage { /// `qp` - The queue pair to return (ownership transferred back) qp: RdmaQueuePair, }, + GetQpState { + other: ActorRef, + self_device: String, + other_device: String, + #[reply] + /// `reply` - Reply channel to return the QP state + reply: OncePortRef, + }, } #[derive(Debug)] @@ -138,12 +146,16 @@ pub enum RdmaManagerMessage { ], )] pub struct RdmaManagerActor { - // Nested map: local_device -> (ActorId, remote_device) -> QueuePairState - device_qps: HashMap>, + // Nested map: local_device -> (ActorId, remote_device) -> RdmaQueuePair + device_qps: HashMap>, + + // Track QPs currently being created to prevent duplicate creation + // Wrapped in Arc to allow safe concurrent access + pending_qp_creation: Arc>>, // Map of RDMA device names to their domains and loopback QPs // Created lazily when memory is registered for a specific device - device_domains: HashMap, + device_domains: HashMap)>, config: IbverbsConfig, @@ -171,14 +183,8 @@ impl Drop for RdmaManagerActor { fn destroy_queue_pair(qp: &RdmaQueuePair, context: &str) { unsafe { if qp.qp != 0 { - let result = rdmaxcel_sys::ibv_destroy_qp(qp.qp as *mut rdmaxcel_sys::ibv_qp); - if result != 0 { - tracing::debug!( - "ibv_destroy_qp returned {} for {} (may be busy during shutdown)", - result, - context - ); - } + let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp); } if qp.send_cq != 0 { let result = @@ -206,30 +212,17 @@ impl Drop for RdmaManagerActor { } // 1. Clean up all queue pairs (both regular and loopback) - for (device_name, device_map) in self.device_qps.drain() { - for ((actor_id, remote_device), qp_state) in device_map { - match qp_state { - QueuePairState::Available(qp) => { - destroy_queue_pair(&qp, &format!("actor {:?}", actor_id)); - } - QueuePairState::CheckedOut => { - tracing::warn!( - "QP for actor {:?} (device {} -> {}) was checked out during cleanup", - actor_id, - device_name, - remote_device - ); - } - } + for (_device_name, device_map) in self.device_qps.drain() { + for ((actor_id, _remote_device), qp) in device_map { + destroy_queue_pair(&qp, &format!("actor {:?}", actor_id)); } } // 2. Clean up device domains (which contain PDs and loopback QPs) - for (device_name, (domain, loopback_qp)) in self.device_domains.drain() { - destroy_queue_pair( - &loopback_qp, - &format!("loopback QP on device {}", device_name), - ); + for (device_name, (domain, qp)) in self.device_domains.drain() { + if let Some(qp) = qp { + destroy_queue_pair(&qp, &format!("loopback QP on device {}", device_name)); + } drop(domain); } @@ -278,10 +271,10 @@ impl RdmaManagerActor { &mut self, device_name: &str, rdma_device: &crate::ibverbs_primitives::RdmaDevice, - ) -> Result<(*mut rdmaxcel_sys::ibv_pd, *mut rdmaxcel_sys::ibv_qp), anyhow::Error> { + ) -> Result<(RdmaDomain, Option), anyhow::Error> { // Check if we already have a domain for this device if let Some((domain, qp)) = self.device_domains.get(device_name) { - return Ok((domain.pd, qp.qp as *mut rdmaxcel_sys::ibv_qp)); + return Ok((domain.clone(), qp.clone())); } // Create new domain for this device @@ -292,43 +285,38 @@ impl RdmaManagerActor { // Print device info if MONARCH_DEBUG_RDMA=1 is set (before initial QP creation) crate::print_device_info_if_debug_enabled(domain.context); - // Create loopback QP for this domain - let mut loopback_qp = RdmaQueuePair::new(domain.context, domain.pd, self.config.clone()) - .map_err(|e| { + // Create loopback QP for this domain if mlx5dv is supported + let qp = if mlx5dv_supported() { + let mut qp = RdmaQueuePair::new(domain.context, domain.pd, self.config.clone()) + .map_err(|e| { + anyhow::anyhow!( + "could not create loopback QP for device {}: {}", + device_name, + e + ) + })?; + + // Get connection info and connect to itself + let endpoint = qp.get_qp_info().map_err(|e| { + anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e) + })?; + + qp.connect(&endpoint).map_err(|e| { anyhow::anyhow!( - "could not create loopback QP for device {}: {}", + "could not connect loopback QP for device {}: {}", device_name, e ) })?; - // Get connection info and connect to itself - let endpoint = loopback_qp.get_qp_info().map_err(|e| { - anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e) - })?; - - loopback_qp.connect(&endpoint).map_err(|e| { - anyhow::anyhow!( - "could not connect loopback QP for device {}: {}", - device_name, - e - ) - })?; - - tracing::debug!( - "Created domain and loopback QP for RDMA device: {}", - device_name - ); - - // Store PD and QP pointers before inserting - let pd = domain.pd; - let qp = loopback_qp.qp as *mut rdmaxcel_sys::ibv_qp; + Some(qp) + } else { + None + }; - // Store the domain and QP self.device_domains - .insert(device_name.to_string(), (domain, loopback_qp)); - - Ok((pd, qp)) + .insert(device_name.to_string(), (domain.clone(), qp.clone())); + Ok((domain, qp)) } fn find_cuda_segment_for_address( @@ -417,8 +405,7 @@ impl RdmaManagerActor { ); // Get or create domain and loopback QP for this device - let (domain_pd, loopback_qp_ptr) = - self.get_or_create_device_domain(&device_name, &rdma_device)?; + let (domain, qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?; let access = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE @@ -433,7 +420,10 @@ impl RdmaManagerActor { let mut maybe_mrv = self.find_cuda_segment_for_address(addr, size); // not found, lets re-sync with caching allocator and retry if maybe_mrv.is_none() { - let err = rdmaxcel_sys::register_segments(domain_pd, loopback_qp_ptr); + let err = rdmaxcel_sys::register_segments( + domain.pd, + qp.unwrap().qp as *mut rdmaxcel_sys::rdmaxcel_qp_t, + ); if err != 0 { let error_msg = get_rdmaxcel_error_message(err); return Err(anyhow::anyhow!( @@ -464,7 +454,7 @@ impl RdmaManagerActor { rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0, ); - mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(domain_pd, 0, size, 0, fd, access.0 as i32); + mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32); if mr.is_null() { return Err(anyhow::anyhow!("Failed to register dmabuf MR")); } @@ -480,7 +470,7 @@ impl RdmaManagerActor { } else { // CPU memory path mr = rdmaxcel_sys::ibv_reg_mr( - domain_pd, + domain.pd, addr as *mut std::ffi::c_void, size, access.0 as i32, @@ -561,6 +551,7 @@ impl Actor for RdmaManagerActor { Ok(Self { device_qps: HashMap::new(), + pending_qp_creation: Arc::new(Mutex::new(HashSet::new())), device_domains: HashMap::new(), config, pt_cuda_alloc, @@ -664,121 +655,170 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { &mut self, cx: &Context, other: ActorRef, - self_device: String, other_device: String, ) -> Result { let other_id = other.actor_id().clone(); - // Use the nested map structure: local_device -> (actor_id, remote_device) -> QueuePairState + // Use the nested map structure: local_device -> (actor_id, remote_device) -> RdmaQueuePair let inner_key = (other_id.clone(), other_device.clone()); // Check if queue pair exists in map if let Some(device_map) = self.device_qps.get(&self_device) { - if let Some(qp_state) = device_map.get(&inner_key).cloned() { - match qp_state { - QueuePairState::Available(qp) => { - // Queue pair exists and is available - return it - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::CheckedOut); - return Ok(qp); - } - QueuePairState::CheckedOut => { - return Err(anyhow::anyhow!( - "queue pair for actor {} on device {} is already checked out", - other_id, - other_device - )); + if let Some(qp) = device_map.get(&inner_key) { + return Ok(qp.clone()); + } + } + + // Try to acquire lock and mark as pending (hold lock only once!) + let pending_key = (self_device.clone(), other_id.clone(), other_device.clone()); + let mut pending = self.pending_qp_creation.lock().await; + + if pending.contains(&pending_key) { + // Another task is creating this QP, release lock and wait + drop(pending); + + // Loop checking device_qps until QP is created (no more locks needed) + // Timeout after 1 second + let start = Instant::now(); + let timeout = Duration::from_secs(1); + + loop { + cx.clock().sleep(Duration::from_micros(200)).await; + + // Check if QP was created while we waited + if let Some(device_map) = self.device_qps.get(&self_device) { + if let Some(qp) = device_map.get(&inner_key) { + return Ok(qp.clone()); } } + + // Check for timeout + if start.elapsed() >= timeout { + return Err(anyhow::anyhow!( + "Timeout waiting for QP creation (device {} -> actor {} device {}). \ + Another task is creating it but hasn't completed in 1 second", + self_device, + other_id, + other_device + )); + } } + } else { + // Not pending, add to set and proceed with creation + pending.insert(pending_key.clone()); + drop(pending); + // Fall through to create QP } // Queue pair doesn't exist - need to create connection - let is_loopback = other_id == cx.bind::().actor_id().clone() - && self_device == other_device; - - if is_loopback { - // Loopback connection setup - self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - let endpoint = self - .connection_info(cx, other.clone(), other_device.clone(), self_device.clone()) - .await?; - self.connect( - cx, - other.clone(), - self_device.clone(), - other_device.clone(), - endpoint, - ) - .await?; - } else { - // Remote connection setup - self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - other - .initialize_qp( + let result = async { + let is_loopback = other_id == cx.bind::().actor_id().clone() + && self_device == other_device; + + if is_loopback { + // Loopback connection setup + self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + let endpoint = self + .connection_info(cx, other.clone(), other_device.clone(), self_device.clone()) + .await?; + self.connect( cx, - cx.bind().clone(), - other_device.clone(), + other.clone(), self_device.clone(), - ) - .await?; - let other_endpoint: RdmaQpInfo = other - .connection_info( - cx, - cx.bind().clone(), other_device.clone(), - self_device.clone(), + endpoint, ) .await?; - self.connect( - cx, - other.clone(), - self_device.clone(), - other_device.clone(), - other_endpoint, - ) - .await?; - let local_endpoint = self - .connection_info(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - other - .connect( + } else { + // Remote connection setup + self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + other + .initialize_qp( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + ) + .await?; + let other_endpoint: RdmaQpInfo = other + .connection_info( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + ) + .await?; + self.connect( cx, - cx.bind().clone(), - other_device.clone(), + other.clone(), self_device.clone(), - local_endpoint, + other_device.clone(), + other_endpoint, ) .await?; - } + let local_endpoint = self + .connection_info(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + other + .connect( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + local_endpoint, + ) + .await?; + + // BARRIER: Ensure remote side has completed its connection and is ready + let remote_state = other + .get_qp_state( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + ) + .await?; + + if remote_state != rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS { + return Err(anyhow::anyhow!( + "Remote QP not in RTS state after connection setup. \ + Local is ready but remote is in state {}. \ + This indicates a synchronization issue in connection setup.", + remote_state + )); + } + } - // Now that connection is established, get the queue pair - if let Some(device_map) = self.device_qps.get(&self_device) { - if let Some(QueuePairState::Available(qp)) = device_map.get(&inner_key).cloned() { - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::CheckedOut); - Ok(qp) + // Now that connection is established, get and clone the queue pair + if let Some(device_map) = self.device_qps.get(&self_device) { + if let Some(qp) = device_map.get(&inner_key) { + Ok(qp.clone()) + } else { + Err(anyhow::anyhow!( + "Failed to create connection for actor {} on device {}", + other_id, + other_device + )) + } } else { Err(anyhow::anyhow!( - "Failed to create connection for actor {} on device {}", + "Failed to create connection for actor {} on device {} - no device map", other_id, other_device )) } - } else { - Err(anyhow::anyhow!( - "Failed to create connection for actor {} on device {} - no device map", - other_id, - other_device - )) } + .await; + + // Always remove from pending set when done (success or failure) + let mut pending = self.pending_qp_creation.lock().await; + pending.remove(&pending_key); + drop(pending); + + result } async fn initialize_qp( @@ -813,13 +853,7 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { // Get or create domain and extract pointers to avoid borrowing issues let (domain_context, domain_pd) = { // Check if we already have a domain for the device - if !self.device_domains.contains_key(&self_device) { - // Create domain first if it doesn't exist - self.get_or_create_device_domain(&self_device, &rdma_device)?; - } - - // Now get the domain context and PD safely - let (domain, _qp) = self.device_domains.get(&self_device).unwrap(); + let (domain, _) = self.get_or_create_device_domain(&self_device, &rdma_device)?; (domain.context, domain.pd) }; @@ -830,7 +864,7 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { self.device_qps .entry(self_device.clone()) .or_insert_with(HashMap::new) - .insert(inner_key, QueuePairState::Available(qp)); + .insert(inner_key, qp); tracing::debug!( "successfully created a connection with {:?} for local device {} -> remote device {}", @@ -858,21 +892,16 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { tracing::debug!("connecting with {:?}", other); let other_id = other.actor_id().clone(); - // For backward compatibility, use default device let inner_key = (other_id.clone(), other_device.clone()); if let Some(device_map) = self.device_qps.get_mut(&self_device) { match device_map.get_mut(&inner_key) { - Some(QueuePairState::Available(qp)) => { + Some(qp) => { qp.connect(&endpoint).map_err(|e| { anyhow::anyhow!("could not connect to RDMA endpoint: {}", e) })?; Ok(()) } - Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!( - "Cannot connect: queue pair for actor {} is checked out", - other_id - )), None => Err(anyhow::anyhow!( "No connection found for actor {}", other_id @@ -907,14 +936,10 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { if let Some(device_map) = self.device_qps.get_mut(&self_device) { match device_map.get_mut(&inner_key) { - Some(QueuePairState::Available(qp)) => { + Some(qp) => { let connection_info = qp.get_qp_info()?; Ok(connection_info) } - Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!( - "Cannot get connection info: queue pair for actor {} is checked out", - other_id - )), None => Err(anyhow::anyhow!( "No connection found for actor {}", other_id @@ -930,47 +955,59 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { /// Releases a queue pair back to the HashMap /// - /// This method returns a queue pair to the HashMap after the caller has finished - /// using it. This completes the request/release cycle, similar to RdmaBuffer. + /// This method is now a no-op since RdmaQueuePair is Clone and can be safely shared. + /// The queue pair is not actually checked out, so there's nothing to release. + /// This method is kept for API compatibility. /// /// # Arguments /// * `remote` - The ActorRef of the remote actor to return the queue pair for - /// * `qp` - The queue pair to release + /// * `qp` - The queue pair to release (ignored) async fn release_queue_pair( + &mut self, + _cx: &Context, + _other: ActorRef, + _self_device: String, + _other_device: String, + _qp: RdmaQueuePair, + ) -> Result<(), anyhow::Error> { + // No-op: Queue pairs are now cloned and shared via atomic counters + // Nothing needs to be released + Ok(()) + } + + /// Gets the state of a queue pair + /// + /// # Arguments + /// * `other` - The ActorRef to get the QP state for + /// * `self_device` - Local device name + /// * `other_device` - Remote device name + /// + /// # Returns + /// * `u32` - The QP state (e.g., IBV_QPS_RTS = Ready To Send) + async fn get_qp_state( &mut self, _cx: &Context, other: ActorRef, self_device: String, other_device: String, - qp: RdmaQueuePair, - ) -> Result<(), anyhow::Error> { - let inner_key = (other.actor_id().clone(), other_device.clone()); - - match self - .device_qps - .get_mut(&self_device) - .unwrap() - .get_mut(&inner_key) - { - Some(QueuePairState::CheckedOut) => { - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::Available(qp)); - Ok(()) + ) -> Result { + let other_id = other.actor_id().clone(); + let inner_key = (other_id.clone(), other_device.clone()); + + if let Some(device_map) = self.device_qps.get_mut(&self_device) { + match device_map.get_mut(&inner_key) { + Some(qp) => qp.state(), + None => Err(anyhow::anyhow!( + "No connection found for actor {} on device {}", + other_id, + other_device + )), } - Some(QueuePairState::Available(_)) => Err(anyhow::anyhow!( - "Cannot release queue pair: queue pair for actor {} is already available between devices {} and {}", - other.actor_id(), - self_device, - other_device, - )), - None => Err(anyhow::anyhow!( - "No queue pair found for actor {}, between devices {} and {}", - other.actor_id(), - self_device, - other_device, - )), + } else { + Err(anyhow::anyhow!( + "No device map found for self device {}", + self_device + )) } } } diff --git a/monarch_rdma/src/rdma_manager_actor_tests.rs b/monarch_rdma/src/rdma_manager_actor_tests.rs index 4e6ba4648..01938d83c 100644 --- a/monarch_rdma/src/rdma_manager_actor_tests.rs +++ b/monarch_rdma/src/rdma_manager_actor_tests.rs @@ -41,10 +41,10 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; // Poll for completion - wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 2).await?; env.actor_1 .release_queue_pair( @@ -56,7 +56,7 @@ mod tests { ) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -79,9 +79,9 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 2).await?; env.actor_1 .release_queue_pair( @@ -93,7 +93,7 @@ mod tests { ) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -118,12 +118,12 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.get(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.get(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; // Poll for completion - wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 2).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -148,10 +148,10 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 2).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -185,10 +185,10 @@ mod tests { env.rdma_handle_1.device_name.clone(), ) .await?; - qp_2.put_with_recv(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; qp_1.recv(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_2, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + let wr_id = qp_2.put_with_recv(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; + wait_for_completion(&mut qp_2, PollTarget::Send, &wr_id, 5).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -214,12 +214,12 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.enqueue_put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.enqueue_put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; qp_1.ring_doorbell()?; // Poll for completion - wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -244,12 +244,12 @@ mod tests { env.rdma_handle_1.device_name.clone(), ) .await?; - qp_2.enqueue_get(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; + let wr_id = qp_2.enqueue_get(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; qp_2.ring_doorbell()?; // Poll for completion - wait_for_completion(&mut qp_2, PollTarget::Send, 5).await?; + wait_for_completion(&mut qp_2, PollTarget::Send, &wr_id, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -270,7 +270,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -289,10 +289,10 @@ mod tests { let env = RdmaManagerTestEnv::setup(BSIZE, "cpu:0", "cpu:1").await?; let /*mut*/ rdma_handle_1 = env.rdma_handle_1.clone(); rdma_handle_1 - .write_from(env.client_1, env.rdma_handle_2.clone(), 2) + .write_from(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -342,7 +342,7 @@ mod tests { // Poll for completion wait_for_completion_gpu(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -377,11 +377,11 @@ mod tests { ) .await?; qp_1.enqueue_get(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - ring_db_gpu(&mut qp_1).await?; + ring_db_gpu(&qp_1).await?; // Poll for completion wait_for_completion_gpu(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -438,9 +438,9 @@ mod tests { ) .await?; ring_db_gpu(&mut qp_2).await?; - wait_for_completion_gpu(&mut qp_1, PollTarget::Recv, 10).await?; + wait_for_completion_gpu(&mut qp_1, PollTarget::Send, 10).await?; wait_for_completion_gpu(&mut qp_2, PollTarget::Send, 10).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -471,11 +471,11 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -506,11 +506,11 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -538,7 +538,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -564,7 +564,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -590,7 +590,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -616,7 +616,53 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; + env.cleanup().await?; + Ok(()) + } + + #[timed_test::async_timed_test(timeout_secs = 30)] + async fn test_concurrent_send_to_same_target() -> Result<(), anyhow::Error> { + const BSIZE: usize = 2 * 1024 * 1024; + let devices = get_all_devices(); + if devices.is_empty() { + println!("Skipping test: RDMA devices not available"); + return Ok(()); + } + + let env = RdmaManagerTestEnv::setup(BSIZE, "cuda:0", "cuda:1").await?; + + let rdma_handle_1 = env.rdma_handle_1.clone(); + let rdma_handle_2 = env.rdma_handle_2.clone(); + let rdma_handle_3 = env.rdma_handle_1.clone(); + let rdma_handle_4 = env.rdma_handle_2.clone(); + let rdma_handle_5 = env.rdma_handle_1.clone(); + let rdma_handle_6 = env.rdma_handle_2.clone(); + let client = env.client_1; + + let task1 = async { + let result = rdma_handle_1 + .write_from(client, rdma_handle_2.clone(), 2) + .await; + result + }; + + let task2 = async { + let result = rdma_handle_3 + .write_from(client, rdma_handle_4.clone(), 2) + .await; + result + }; + + let task3 = async { + let result = rdma_handle_5 + .write_from(client, rdma_handle_6.clone(), 2) + .await; + result + }; + let (result1, result2, result3) = tokio::join!(task1, task2, task3); + + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -644,7 +690,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -672,7 +718,7 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -704,7 +750,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -736,8 +782,88 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; + env.cleanup().await?; + Ok(()) + } + + #[timed_test::async_timed_test(timeout_secs = 120)] + async fn test_rdma_read_chunk() -> Result<(), anyhow::Error> { + if is_cpu_only_mode() { + println!("Skipping CUDA test in CPU-only mode"); + return Ok(()); + } + // 2GB tensor size + const BSIZE: usize = 2 * 1024 * 1024 * 1024; + const CHUNK_SIZE: usize = 2 * 1024 * 1024; // 2MB + let devices = get_all_devices(); + if devices.len() < 5 { + println!( + "Skipping test: requires H100 nodes with backend network (found {} devices)", + devices.len() + ); + return Ok(()); + } + + println!("Setting up 2GB CUDA tensors for RDMA read test..."); + let env = RdmaManagerTestEnv::setup(BSIZE, "cuda:0", "cuda:1").await?; + + println!("Performing RDMA read operation on 2GB tensor..."); + let rdma_handle_1 = env.rdma_handle_1.clone(); + rdma_handle_1 + .read_into(env.client_1, env.rdma_handle_2.clone(), 30) + .await?; + + println!("Verifying first 2MB..."); + env.verify_buffers(CHUNK_SIZE, 0).await?; + + println!("Verifying last 2MB..."); + env.verify_buffers(CHUNK_SIZE, BSIZE - CHUNK_SIZE).await?; + + println!("Cleaning up..."); + env.cleanup().await?; + + println!("2GB RDMA read test completed successfully"); + Ok(()) + } + + #[timed_test::async_timed_test(timeout_secs = 60)] + async fn test_rdma_write_chunk() -> Result<(), anyhow::Error> { + if is_cpu_only_mode() { + println!("Skipping CUDA test in CPU-only mode"); + return Ok(()); + } + // 2GB tensor size + const BSIZE: usize = 2 * 1024 * 1024 * 1024; + const CHUNK_SIZE: usize = 2 * 1024 * 1024; // 2MB + let devices = get_all_devices(); + if devices.len() < 5 { + println!( + "Skipping test: requires H100 nodes with backend network (found {} devices)", + devices.len() + ); + return Ok(()); + } + + println!("Setting up 2GB CUDA tensors for RDMA write test..."); + let env = RdmaManagerTestEnv::setup(BSIZE, "cuda:0", "cuda:1").await?; + + println!("Performing RDMA write operation on 2GB tensor..."); + let rdma_handle_1 = env.rdma_handle_1.clone(); + rdma_handle_1 + .write_from(env.client_1, env.rdma_handle_2.clone(), 30) + .await?; + + println!("Verifying first 2MB..."); + env.verify_buffers(CHUNK_SIZE, 0).await?; + + println!("Verifying last 2MB..."); + env.verify_buffers(CHUNK_SIZE, BSIZE - CHUNK_SIZE).await?; + + println!("Cleaning up..."); env.cleanup().await?; + + println!("2GB RDMA write test completed successfully"); Ok(()) } } diff --git a/monarch_rdma/src/test_utils.rs b/monarch_rdma/src/test_utils.rs index 96c3cce30..58d2a88ea 100644 --- a/monarch_rdma/src/test_utils.rs +++ b/monarch_rdma/src/test_utils.rs @@ -307,26 +307,47 @@ pub mod test_utils { } } - // Waits for the completion of an RDMA operation. - - // This function polls for the completion of an RDMA operation by repeatedly - // sending a `PollCompletion` message to the specified actor mesh and checking - // the returned work completion status. It continues polling until the operation - // completes or the specified timeout is reached. - + /// Waits for the completion of RDMA operations. + /// + /// This function polls for the completion of RDMA operations by repeatedly + /// checking the completion queue until all expected work requests complete + /// or the specified timeout is reached. + /// + /// # Arguments + /// * `qp` - The RDMA Queue Pair to poll for completion + /// * `poll_target` - Which CQ to poll (Send or Recv) + /// * `expected_wr_ids` - Slice of work request IDs to wait for + /// * `timeout_secs` - Timeout in seconds + /// + /// # Returns + /// `Ok(true)` if all operations complete successfully within the timeout, + /// or an error if the timeout is reached pub async fn wait_for_completion( qp: &mut RdmaQueuePair, poll_target: PollTarget, + expected_wr_ids: &[u64], timeout_secs: u64, ) -> Result { let timeout = Duration::from_secs(timeout_secs); let start_time = Instant::now(); + + let mut remaining: std::collections::HashSet = + expected_wr_ids.iter().copied().collect(); + while start_time.elapsed() < timeout { - match qp.poll_completion_target(poll_target) { - Ok(Some(_wc)) => { - return Ok(true); - } - Ok(None) => { + if remaining.is_empty() { + return Ok(true); + } + + let wr_ids_to_poll: Vec = remaining.iter().copied().collect(); + match qp.poll_completion(poll_target, &wr_ids_to_poll) { + Ok(completions) => { + for (wr_id, _wc) in completions { + remaining.remove(&wr_id); + } + if remaining.is_empty() { + return Ok(true); + } RealClock.sleep(Duration::from_millis(1)).await; } Err(e) => { @@ -334,7 +355,10 @@ pub mod test_utils { } } } - Err(anyhow::Error::msg("Timeout while waiting for completion")) + Err(anyhow::Error::msg(format!( + "Timeout while waiting for completion of wr_ids: {:?}", + remaining + ))) } /// Posts a work request to the send queue of the given RDMA queue pair. @@ -345,25 +369,26 @@ pub mod test_utils { op_type: u32, ) -> Result<(), anyhow::Error> { unsafe { - let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp; + let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; + let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(ibv_qp); let params = rdmaxcel_sys::wqe_params_t { laddr: lhandle.addr, length: lhandle.size, lkey: lhandle.lkey, - wr_id: qp.send_wqe_idx, + wr_id: send_wqe_idx, signaled: true, op_type, raddr: rhandle.addr, rkey: rhandle.rkey, - qp_num: (*ibv_qp).qp_num, + qp_num: (*(*ibv_qp).ibv_qp).qp_num, buf: (*dv_qp).sq.buf as *mut u8, wqe_cnt: (*dv_qp).sq.wqe_cnt, dbrec: (*dv_qp).dbrec, ..Default::default() }; rdmaxcel_sys::launch_send_wqe(params); - qp.send_wqe_idx += 1; + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(ibv_qp); } Ok(()) } @@ -377,43 +402,53 @@ pub mod test_utils { ) -> Result<(), anyhow::Error> { // Populate params using lhandle and rhandle unsafe { - let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp; + let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; + let recv_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp); let params = rdmaxcel_sys::wqe_params_t { laddr: lhandle.addr, length: lhandle.size, lkey: lhandle.lkey, - wr_id: qp.recv_wqe_idx, + wr_id: recv_wqe_idx, op_type, signaled: true, - qp_num: (*ibv_qp).qp_num, + qp_num: (*(*rdmaxcel_qp).ibv_qp).qp_num, buf: (*dv_qp).rq.buf as *mut u8, wqe_cnt: (*dv_qp).rq.wqe_cnt, dbrec: (*dv_qp).dbrec, ..Default::default() }; rdmaxcel_sys::launch_recv_wqe(params); - qp.recv_wqe_idx += 1; - qp.recv_db_idx += 1; + rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp); + rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp); } Ok(()) } - pub async fn ring_db_gpu(qp: &mut RdmaQueuePair) -> Result<(), anyhow::Error> { + pub async fn ring_db_gpu(qp: &RdmaQueuePair) -> Result<(), anyhow::Error> { RealClock.sleep(Duration::from_millis(2)).await; unsafe { let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; let base_ptr = (*dv_qp).sq.buf as *mut u8; let wqe_cnt = (*dv_qp).sq.wqe_cnt; let stride = (*dv_qp).sq.stride; - if (wqe_cnt as u64) < (qp.send_wqe_idx - qp.send_db_idx) { + let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx( + qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ); + let mut send_db_idx = + rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp); + if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) { return Err(anyhow::anyhow!("Overflow of WQE, possible data loss")); } - while qp.send_db_idx < qp.send_wqe_idx { - let offset = (qp.send_db_idx % wqe_cnt as u64) * stride as u64; + while send_db_idx < send_wqe_idx { + let offset = (send_db_idx % wqe_cnt as u64) * stride as u64; let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize); rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void); - qp.send_db_idx += 1; + send_db_idx += 1; + rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx( + qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + send_db_idx, + ); } } Ok(()) @@ -425,48 +460,54 @@ pub mod test_utils { poll_target: PollTarget, timeout_secs: u64, ) -> Result { - let timeout = Duration::from_secs(timeout_secs); - let start_time = Instant::now(); - - while start_time.elapsed() < timeout { - // Get the appropriate completion queue and index based on the poll target - let (cq, idx, cq_type_str) = match poll_target { - PollTarget::Send => ( - qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq, - qp.send_cq_idx, - "send", - ), - PollTarget::Recv => ( - qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq, - qp.recv_cq_idx, - "receive", - ), - }; - - // Poll the completion queue - let result = - unsafe { rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32) }; - - match result { - rdmaxcel_sys::CQE_POLL_TRUE => { - // Update the appropriate index based on the poll target - match poll_target { - PollTarget::Send => qp.send_cq_idx += 1, - PollTarget::Recv => qp.recv_cq_idx += 1, + unsafe { + let start_time = Instant::now(); + let timeout = Duration::from_secs(timeout_secs); + let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + + while start_time.elapsed() < timeout { + // Get the appropriate completion queue and index based on the poll target + let (cq, idx, cq_type_str) = match poll_target { + PollTarget::Send => ( + qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq, + rdmaxcel_sys::rdmaxcel_qp_load_send_cq_idx(ibv_qp), + "send", + ), + PollTarget::Recv => ( + qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq, + rdmaxcel_sys::rdmaxcel_qp_load_recv_cq_idx(ibv_qp), + "receive", + ), + }; + + // Poll the completion queue + let result = rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32); + + match result { + rdmaxcel_sys::CQE_POLL_TRUE => { + // Update the appropriate index based on the poll target + match poll_target { + PollTarget::Send => { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_cq_idx(ibv_qp); + } + PollTarget::Recv => { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_cq_idx(ibv_qp); + } + } + return Ok(true); + } + rdmaxcel_sys::CQE_POLL_ERROR => { + return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str)); + } + _ => { + // No completion yet, sleep and try again + RealClock.sleep(Duration::from_millis(1)).await; } - return Ok(true); - } - rdmaxcel_sys::CQE_POLL_ERROR => { - return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str)); - } - _ => { - // No completion yet, sleep and try again - RealClock.sleep(Duration::from_millis(1)).await; } } - } - Err(anyhow::Error::msg("Timeout while waiting for completion")) + Err(anyhow::Error::msg("Timeout while waiting for completion")) + } } pub struct RdmaManagerTestEnv<'a> { @@ -498,6 +539,7 @@ pub mod test_utils { if backend == "cuda" { config.use_gpu_direct = validate_execution_context().await.is_ok(); + eprintln!("Using GPU Direct: {}", config.use_gpu_direct); } (backend.to_string(), parsed_idx) @@ -716,7 +758,11 @@ pub mod test_utils { .await } - pub async fn verify_buffers(&self, size: usize) -> Result<(), anyhow::Error> { + pub async fn verify_buffers( + &self, + size: usize, + offset: usize, + ) -> Result<(), anyhow::Error> { let mut temp_buffer_1 = vec![0u8; size]; let mut temp_buffer_2 = vec![0u8; size]; @@ -726,14 +772,14 @@ pub mod test_utils { .verify_buffer( self.client_1, temp_buffer_1.as_mut_ptr() as usize, - self.device_ptr_1.unwrap(), + self.device_ptr_1.unwrap() + offset, size, ) .await?; } else { unsafe { std::ptr::copy_nonoverlapping( - self.buffer_1.ptr as *const u8, + (self.buffer_1.ptr + offset as u64) as *const u8, temp_buffer_1.as_mut_ptr(), size, ); @@ -746,14 +792,14 @@ pub mod test_utils { .verify_buffer( self.client_2, temp_buffer_2.as_mut_ptr() as usize, - self.device_ptr_2.unwrap(), + self.device_ptr_2.unwrap() + offset, size, ) .await?; } else { unsafe { std::ptr::copy_nonoverlapping( - self.buffer_2.ptr as *const u8, + (self.buffer_2.ptr + offset as u64) as *const u8, temp_buffer_2.as_mut_ptr(), size, ); @@ -763,7 +809,10 @@ pub mod test_utils { // Compare buffers for i in 0..size { if temp_buffer_1[i] != temp_buffer_2[i] { - return Err(anyhow::anyhow!("Buffers are not equal at index {}", i)); + return Err(anyhow::anyhow!( + "Buffers are not equal at index {}", + offset + i + )); } } Ok(()) diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs index 036122a66..0a78d67de 100644 --- a/rdmaxcel-sys/build.rs +++ b/rdmaxcel-sys/build.rs @@ -97,6 +97,9 @@ fn main() { .allowlist_function("get_cuda_pci_address_from_ptr") .allowlist_function("rdmaxcel_print_device_info") .allowlist_function("rdmaxcel_error_string") + .allowlist_function("rdmaxcel_qp_.*") + .allowlist_function("poll_cq_with_cache") + .allowlist_function("completion_cache_.*") .allowlist_type("ibv_.*") .allowlist_type("mlx5dv_.*") .allowlist_type("mlx5_wqe_.*") @@ -104,6 +107,12 @@ fn main() { .allowlist_type("wqe_params_t") .allowlist_type("cqe_poll_params_t") .allowlist_type("rdma_segment_info_t") + .allowlist_type("rdmaxcel_qp_t") + .allowlist_type("rdmaxcel_qp") + .allowlist_type("completion_cache_t") + .allowlist_type("completion_cache") + .allowlist_type("poll_context_t") + .allowlist_type("poll_context") .allowlist_var("MLX5_.*") .allowlist_var("IBV_.*") // Block specific types that are manually defined in lib.rs diff --git a/rdmaxcel-sys/src/rdmaxcel.c b/rdmaxcel-sys/src/rdmaxcel.c index c6fdf57c8..7a4021fcc 100644 --- a/rdmaxcel-sys/src/rdmaxcel.c +++ b/rdmaxcel-sys/src/rdmaxcel.c @@ -7,10 +7,375 @@ */ #include "rdmaxcel.h" - +#include +#include #include #include #include +#include +#include + +// ============================================================================ +// RDMAXCEL QP Wrapper Implementation +// ============================================================================ + +rdmaxcel_qp_t* rdmaxcel_qp_create( + struct ibv_context* context, + struct ibv_pd* pd, + int cq_entries, + int max_send_wr, + int max_recv_wr, + int max_send_sge, + int max_recv_sge, + rdma_qp_type_t qp_type) { + // Allocate wrapper structure + rdmaxcel_qp_t* qp = (rdmaxcel_qp_t*)calloc(1, sizeof(rdmaxcel_qp_t)); + if (!qp) { + fprintf(stderr, "ERROR: Failed to allocate rdmaxcel_qp_t\n"); + return NULL; + } + + // Create underlying ibverbs QP + qp->ibv_qp = create_qp( + context, + pd, + cq_entries, + max_send_wr, + max_recv_wr, + max_send_sge, + max_recv_sge, + qp_type); + if (!qp->ibv_qp) { + free(qp); + return NULL; + } + + // Store CQ pointers + qp->send_cq = qp->ibv_qp->send_cq; + qp->recv_cq = qp->ibv_qp->recv_cq; + + // Initialize atomic counters + atomic_init(&qp->send_wqe_idx, 0); + atomic_init(&qp->send_db_idx, 0); + atomic_init(&qp->send_cq_idx, 0); + atomic_init(&qp->recv_wqe_idx, 0); + atomic_init(&qp->recv_db_idx, 0); + atomic_init(&qp->recv_cq_idx, 0); + atomic_init(&qp->rts_timestamp, UINT64_MAX); + + // Initialize completion caches + qp->send_completion_cache = + (completion_cache_t*)calloc(1, sizeof(completion_cache_t)); + qp->recv_completion_cache = + (completion_cache_t*)calloc(1, sizeof(completion_cache_t)); + + if (!qp->send_completion_cache || !qp->recv_completion_cache) { + if (qp->send_completion_cache) + free(qp->send_completion_cache); + if (qp->recv_completion_cache) + free(qp->recv_completion_cache); + ibv_destroy_qp(qp->ibv_qp); + free(qp); + fprintf(stderr, "ERROR: Failed to allocate completion caches\n"); + return NULL; + } + + completion_cache_init(qp->send_completion_cache); + completion_cache_init(qp->recv_completion_cache); + + return qp; +} + +void rdmaxcel_qp_destroy(rdmaxcel_qp_t* qp) { + if (!qp) { + return; + } + + // Clean up completion caches + if (qp->send_completion_cache) { + completion_cache_destroy(qp->send_completion_cache); + free(qp->send_completion_cache); + } + if (qp->recv_completion_cache) { + completion_cache_destroy(qp->recv_completion_cache); + free(qp->recv_completion_cache); + } + + // Destroy the underlying ibv_qp and its CQs + if (qp->ibv_qp) { + ibv_destroy_qp(qp->ibv_qp); + } + if (qp->send_cq) { + ibv_destroy_cq(qp->send_cq); + } + if (qp->recv_cq) { + ibv_destroy_cq(qp->recv_cq); + } + + free(qp); +} + +struct ibv_qp* rdmaxcel_qp_get_ibv_qp(rdmaxcel_qp_t* qp) { + return qp ? qp->ibv_qp : NULL; +} + +// Atomic fetch_add operations +uint64_t rdmaxcel_qp_fetch_add_send_wqe_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->send_wqe_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_send_db_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->send_db_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_send_cq_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->send_cq_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->recv_wqe_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->recv_db_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_recv_cq_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->recv_cq_idx, 1) : 0; +} + +// Atomic load operations +uint64_t rdmaxcel_qp_load_send_wqe_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->send_wqe_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_send_db_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->send_db_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_send_cq_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->send_cq_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->recv_wqe_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_recv_cq_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->recv_cq_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_rts_timestamp(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->rts_timestamp) : UINT64_MAX; +} + +// Atomic store operations +void rdmaxcel_qp_store_send_db_idx(rdmaxcel_qp_t* qp, uint64_t value) { + if (qp) { + atomic_store(&qp->send_db_idx, value); + } +} + +void rdmaxcel_qp_store_rts_timestamp(rdmaxcel_qp_t* qp, uint64_t value) { + if (qp) { + atomic_store(&qp->rts_timestamp, value); + } +} + +// Get completion caches +completion_cache_t* rdmaxcel_qp_get_send_cache(rdmaxcel_qp_t* qp) { + return qp ? qp->send_completion_cache : NULL; +} + +completion_cache_t* rdmaxcel_qp_get_recv_cache(rdmaxcel_qp_t* qp) { + return qp ? qp->recv_completion_cache : NULL; +} + +// ============================================================================ +// Completion Cache Implementation +// ============================================================================ + +void completion_cache_init(completion_cache_t* cache) { + if (!cache) { + return; + } + cache->head = -1; + cache->tail = -1; + cache->count = 0; + + // Initialize free list + cache->free_head = 0; + for (int i = 0; i < MAX_CACHED_COMPLETIONS - 1; i++) { + cache->nodes[i].next = i + 1; + } + cache->nodes[MAX_CACHED_COMPLETIONS - 1].next = -1; + + pthread_mutex_init(&cache->lock, NULL); +} + +void completion_cache_destroy(completion_cache_t* cache) { + if (!cache) { + return; + } + + // Warn if cache still has entries when being destroyed + if (cache->count > 0) { + fprintf( + stderr, + "WARNING: Destroying completion cache with %zu unretrieved entries! " + "Possible missing poll operations or leaked work requests.\n", + cache->count); + int curr = cache->head; + fprintf(stderr, " Cached wr_ids:"); + while (curr != -1 && cache->count > 0) { + fprintf(stderr, " %lu", cache->nodes[curr].wc.wr_id); + curr = cache->nodes[curr].next; + } + fprintf(stderr, "\n"); + } + + pthread_mutex_destroy(&cache->lock); + cache->count = 0; +} + +int completion_cache_add(completion_cache_t* cache, struct ibv_wc* wc) { + if (!cache || !wc) { + return 0; + } + + pthread_mutex_lock(&cache->lock); + + if (cache->free_head == -1) { + pthread_mutex_unlock(&cache->lock); + fprintf( + stderr, + "WARNING: Completion cache full (%zu entries)! Dropping completion " + "for wr_id=%lu, qp=%u\n", + cache->count, + wc->wr_id, + wc->qp_num); + return 0; + } + + // Pop from free list + int idx = cache->free_head; + cache->free_head = cache->nodes[idx].next; + + // Store completion + cache->nodes[idx].wc = *wc; + cache->nodes[idx].next = -1; + + // Append to tail of used list + if (cache->head == -1) { + cache->head = idx; + cache->tail = idx; + } else { + cache->nodes[cache->tail].next = idx; + cache->tail = idx; + } + + cache->count++; + pthread_mutex_unlock(&cache->lock); + return 1; +} + +int completion_cache_find( + completion_cache_t* cache, + uint64_t wr_id, + uint32_t qp_num, + struct ibv_wc* out_wc) { + if (!cache || !out_wc) { + return 0; + } + + pthread_mutex_lock(&cache->lock); + + int prev = -1; + int curr = cache->head; + + while (curr != -1) { + if (cache->nodes[curr].wc.wr_id == wr_id && + cache->nodes[curr].wc.qp_num == qp_num) { + // Found it! Copy out + *out_wc = cache->nodes[curr].wc; + + // Remove from used list + if (prev == -1) { + // Removing head (O(1) for typical case!) + cache->head = cache->nodes[curr].next; + if (cache->head == -1) { + cache->tail = -1; + } + } else { + // Removing from middle/tail + cache->nodes[prev].next = cache->nodes[curr].next; + if (cache->nodes[curr].next == -1) { + cache->tail = prev; + } + } + + // Add to free list + cache->nodes[curr].next = cache->free_head; + cache->free_head = curr; + + cache->count--; + pthread_mutex_unlock(&cache->lock); + return 1; + } + + prev = curr; + curr = cache->nodes[curr].next; + } + + pthread_mutex_unlock(&cache->lock); + return 0; +} + +int poll_cq_with_cache(poll_context_t* ctx, struct ibv_wc* out_wc) { + if (!ctx || !out_wc) { + return RDMAXCEL_INVALID_PARAMS; + } + + if (completion_cache_find( + ctx->cache, ctx->expected_wr_id, ctx->expected_qp_num, out_wc)) { + if (out_wc->status != IBV_WC_SUCCESS) { + return RDMAXCEL_COMPLETION_FAILED; + } + return 1; + } + + struct ibv_wc wc; + int ret = ibv_poll_cq(ctx->cq, 1, &wc); + + if (ret < 0) { + return RDMAXCEL_CQ_POLL_FAILED; + } + + if (ret == 0) { + return 0; + } + + if (wc.status != IBV_WC_SUCCESS) { + if (wc.wr_id == ctx->expected_wr_id && wc.qp_num == ctx->expected_qp_num) { + *out_wc = wc; + return RDMAXCEL_COMPLETION_FAILED; + } + completion_cache_add(ctx->cache, &wc); + return 0; + } + + if (wc.wr_id == ctx->expected_wr_id && wc.qp_num == ctx->expected_qp_num) { + *out_wc = wc; + return 1; + } + + completion_cache_add(ctx->cache, &wc); + return 0; +} + +// ============================================================================ +// End of Completion Cache Implementation +// ============================================================================ cudaError_t register_mmio_to_cuda(void* bf, size_t size) { cudaError_t result = cudaHostRegister( diff --git a/rdmaxcel-sys/src/rdmaxcel.cpp b/rdmaxcel-sys/src/rdmaxcel.cpp index 9e359895b..6f55824b3 100644 --- a/rdmaxcel-sys/src/rdmaxcel.cpp +++ b/rdmaxcel-sys/src/rdmaxcel.cpp @@ -260,8 +260,8 @@ int compact_mrs(struct ibv_pd* pd, SegmentInfo& seg, int access_flags) { } // Register memory region for a specific segment address, assume cuda -int register_segments(struct ibv_pd* pd, struct ibv_qp* qp) { - if (!pd) { +int register_segments(struct ibv_pd* pd, rdmaxcel_qp_t* qp) { + if (!pd || !qp) { return RDMAXCEL_INVALID_PARAMS; // Invalid parameter } scan_existing_segments(); @@ -334,7 +334,7 @@ int register_segments(struct ibv_pd* pd, struct ibv_qp* qp) { seg.mr_size = seg.phys_size; // Create vector of GPU addresses for bind_mrs - auto err = bind_mrs(pd, qp, access_flags, seg); + auto err = bind_mrs(pd, qp->ibv_qp, access_flags, seg); if (err != 0) { return err; // Bind MR's failed } @@ -535,6 +535,10 @@ const char* rdmaxcel_error_string(int error_code) { return "[RdmaXcel] Output buffer too small"; case RDMAXCEL_QUERY_DEVICE_FAILED: return "[RdmaXcel] Failed to query device attributes"; + case RDMAXCEL_CQ_POLL_FAILED: + return "[RdmaXcel] CQ polling failed"; + case RDMAXCEL_COMPLETION_FAILED: + return "[RdmaXcel] Completion status not successful"; default: return "[RdmaXcel] Unknown error code"; } diff --git a/rdmaxcel-sys/src/rdmaxcel.h b/rdmaxcel-sys/src/rdmaxcel.h index 01b01d575..4a3cf316d 100644 --- a/rdmaxcel-sys/src/rdmaxcel.h +++ b/rdmaxcel-sys/src/rdmaxcel.h @@ -15,8 +15,13 @@ #include #include "driver_api.h" +// Handle atomics for both C and C++ #ifdef __cplusplus +#include +#define _Atomic(T) std::atomic extern "C" { +#else +#include #endif typedef enum { @@ -136,7 +141,9 @@ typedef enum { -12, // Failed to get CUDA device attribute RDMAXCEL_CUDA_GET_DEVICE_FAILED = -13, // Failed to get CUDA device handle RDMAXCEL_BUFFER_TOO_SMALL = -14, // Output buffer too small - RDMAXCEL_QUERY_DEVICE_FAILED = -15 // Failed to query device attributes + RDMAXCEL_QUERY_DEVICE_FAILED = -15, // Failed to query device attributes + RDMAXCEL_CQ_POLL_FAILED = -16, // CQ polling failed + RDMAXCEL_COMPLETION_FAILED = -17 // Completion status not successful } rdmaxcel_error_code_t; // Error/Debugging functions @@ -147,7 +154,6 @@ const char* rdmaxcel_error_string(int error_code); int rdma_get_active_segment_count(); int rdma_get_all_segment_info(rdma_segment_info_t* info_array, int max_count); bool pt_cuda_allocator_compatibility(); -int register_segments(struct ibv_pd* pd, struct ibv_qp* qp); int deregister_segments(); // CUDA utility functions @@ -156,6 +162,127 @@ int get_cuda_pci_address_from_ptr( char* pci_addr_out, size_t pci_addr_size); +cudaError_t register_host_mem(void** buf, size_t size); + +// Forward declarations +typedef struct completion_cache completion_cache_t; + +// RDMA Queue Pair wrapper with atomic counters and completion caches +typedef struct rdmaxcel_qp { + struct ibv_qp* ibv_qp; // Underlying ibverbs QP + struct ibv_cq* send_cq; // Send completion queue + struct ibv_cq* recv_cq; // Receive completion queue + + // Atomic counters + _Atomic(uint64_t) send_wqe_idx; + _Atomic(uint64_t) send_db_idx; + _Atomic(uint64_t) send_cq_idx; + _Atomic(uint64_t) recv_wqe_idx; + _Atomic(uint64_t) recv_db_idx; + _Atomic(uint64_t) recv_cq_idx; + _Atomic(uint64_t) rts_timestamp; + + // Completion caches + completion_cache_t* send_completion_cache; + completion_cache_t* recv_completion_cache; +} rdmaxcel_qp_t; + +// Create and initialize an rdmaxcel QP (wraps create_qp + initializes +// counters/caches) +rdmaxcel_qp_t* rdmaxcel_qp_create( + struct ibv_context* context, + struct ibv_pd* pd, + int cq_entries, + int max_send_wr, + int max_recv_wr, + int max_send_sge, + int max_recv_sge, + rdma_qp_type_t qp_type); + +// Destroy rdmaxcel QP and clean up resources +void rdmaxcel_qp_destroy(rdmaxcel_qp_t* qp); + +// Get underlying ibv_qp pointer (for compatibility with existing ibverbs calls) +struct ibv_qp* rdmaxcel_qp_get_ibv_qp(rdmaxcel_qp_t* qp); + +// Atomic fetch_add operations +uint64_t rdmaxcel_qp_fetch_add_send_wqe_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_send_db_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_send_cq_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_recv_cq_idx(rdmaxcel_qp_t* qp); + +// Atomic load operations (minimal API surface) +// Send side: needed for doorbell ring iteration [db_idx, wqe_idx) +uint64_t rdmaxcel_qp_load_send_wqe_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_load_send_db_idx(rdmaxcel_qp_t* qp); +// Receive side: needed for receive operations +uint64_t rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp_t* qp); +// Completion queue indices: needed for polling without modifying +uint64_t rdmaxcel_qp_load_send_cq_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_load_recv_cq_idx(rdmaxcel_qp_t* qp); +// Connection state validation +uint64_t rdmaxcel_qp_load_rts_timestamp(rdmaxcel_qp_t* qp); + +// Atomic store operations +void rdmaxcel_qp_store_send_db_idx(rdmaxcel_qp_t* qp, uint64_t value); +void rdmaxcel_qp_store_rts_timestamp(rdmaxcel_qp_t* qp, uint64_t value); + +// Get completion caches +completion_cache_t* rdmaxcel_qp_get_send_cache(rdmaxcel_qp_t* qp); +completion_cache_t* rdmaxcel_qp_get_recv_cache(rdmaxcel_qp_t* qp); + +// Segment registration (uses rdmaxcel_qp_t, so must come after type definition) +int register_segments(struct ibv_pd* pd, rdmaxcel_qp_t* qp); + +// Completion Cache Structures and Functions +#define MAX_CACHED_COMPLETIONS 128 + +// Linked list node for cached completions +typedef struct completion_node { + struct ibv_wc wc; + int next; // Index of next node, or -1 for end of list +} completion_node_t; + +// Cache for "unmatched" completions using embedded linked list +typedef struct completion_cache { + completion_node_t nodes[MAX_CACHED_COMPLETIONS]; + int head; // Index of first used node, or -1 if empty + int tail; // Index of last used node + int free_head; // Index of first free node, or -1 if full + size_t count; + pthread_mutex_t lock; +} completion_cache_t; + +// Context for polling with cache +typedef struct poll_context { + uint64_t expected_wr_id; // What wr_id am I looking for? + uint32_t expected_qp_num; // What QP am I expecting? + completion_cache_t* cache; // Shared completion cache + struct ibv_cq* cq; // The CQ to poll +} poll_context_t; + +// Initialize completion cache +void completion_cache_init(completion_cache_t* cache); + +// Destroy completion cache +void completion_cache_destroy(completion_cache_t* cache); + +// Add completion to cache +int completion_cache_add(completion_cache_t* cache, struct ibv_wc* wc); + +// Find and remove completion from cache +int completion_cache_find( + completion_cache_t* cache, + uint64_t wr_id, + uint32_t qp_num, + struct ibv_wc* out_wc); + +// Poll with cache support +// Returns: 1 = found, 0 = not found, -1 = error +int poll_cq_with_cache(poll_context_t* ctx, struct ibv_wc* out_wc); + #ifdef __cplusplus } #endif From 2a1f86e5cdddd5f0395efde30e99264653e04f9a Mon Sep 17 00:00:00 2001 From: Dennis van der Staay Date: Thu, 20 Nov 2025 19:55:18 -0800 Subject: [PATCH 2/2] Update load test to support concurrency (#1944) Summary: Update script to support concurrency, with relevant benchmarks: buck run @//mode/dev-nosan //monarch/python/tests:rdma_load_test -- --device cuda:0 cuda:1 --operation write --iterations 5 --size 500 --expandable-segments true --concurrency 4 sample output ``` ================================================================== CONCURRENT BATCH TIMING (wall-clock for all concurrent ops): Average batch time: 48.681 ms Minimum batch time: 25.463 ms Maximum batch time: 230.379 ms Standard deviation: 20.382 ms Average data per batch: 1982.5 MB AGGREGATE BANDWIDTH (concurrency=4): Average aggregate bandwidth: 341.62 Gbps Maximum aggregate bandwidth: 653.13 Gbps Minimum aggregate bandwidth: 72.19 Gbps TOTAL SUSTAINED THROUGHPUT: Total wall-clock time: 5.094 s Total data transferred: 198250.0 MB Sustained throughput: 326.47 Gbps (Accounts for 4x concurrent overlapping operations) ============================================================ RDMA WRITE LOAD TEST RESULTS (CUDA:1) ============================================================ INDIVIDUAL OPERATION TIMING: Average time per operation: 29.031 ms Minimum time per operation: 6.103 ms Maximum time per operation: 191.391 ms Standard deviation: 19.436 ms Total iterations completed: 400 Average data per operation: 495.6 MB Total data transferred: 198250.0 MB INDIVIDUAL OPERATION BANDWIDTH: Average bandwidth: 143.21 Gbps Maximum bandwidth: 681.26 Gbps Minimum bandwidth: 21.72 Gbps ``` Reviewed By: casteryh Differential Revision: D87475053 --- python/tests/rdma_load_test.py | 126 +++++++++++++++++++++++++++------ 1 file changed, 106 insertions(+), 20 deletions(-) diff --git a/python/tests/rdma_load_test.py b/python/tests/rdma_load_test.py index 3f6df8364..ad20efce2 100644 --- a/python/tests/rdma_load_test.py +++ b/python/tests/rdma_load_test.py @@ -56,6 +56,12 @@ default=10, help="Number of warmup iterations (default: 5)", ) + parser.add_argument( + "--concurrency", + type=int, + default=1, + help="Number of concurrent RDMA operations (default: 1)", + ) args = parser.parse_args() @@ -85,13 +91,38 @@ def __init__( # Timing data storage self.timing_data = [] self.size_data = [] + self.batch_timing_data = [] + self.batch_size_data = [] @endpoint async def set_other_actor(self, other_actor): self.other_actor = other_actor @endpoint - async def send(self, is_warmup=False) -> None: + async def send(self, is_warmup=False, concurrency: int = 1) -> None: + # Track wall-clock time for the entire concurrent batch + batch_start = time.time() + + tasks = [] + for _ in range(concurrency): + tasks.append(self._send_single(is_warmup)) + await asyncio.gather(*tasks) + + batch_end = time.time() + batch_elapsed = batch_end - batch_start + + if not is_warmup: + batch_size = ( + sum(self.size_data[-concurrency:]) + if len(self.size_data) >= concurrency + else 0 + ) + self.batch_timing_data.append(batch_elapsed) + self.batch_size_data.append(batch_size) + + self.i += 1 + + async def _send_single(self, is_warmup=False) -> None: shape = int( 1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3)) ) # Random size with +/- 50% variation based on user size @@ -104,7 +135,7 @@ async def send(self, is_warmup=False) -> None: # Critical validation - this should catch the null pointer issue assert ( tensor_addr != 0 - ), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}" + ), f"CRITICAL: Tensor has null pointer! Device: {self.device}, Shape: {shape}" assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}" byte_view = tensor.view(torch.uint8).flatten() @@ -136,8 +167,6 @@ async def send(self, is_warmup=False) -> None: # cleanup await buffer.drop() - self.i += 1 - @endpoint async def recv(self, rdma_buffer, shape, dtype, is_warmup): # Create receiving tensor on the same device @@ -167,7 +196,7 @@ async def recv(self, rdma_buffer, shape, dtype, is_warmup): @endpoint async def print_statistics(self, calc_bwd: bool = False): - """Calculate and print timing statistics""" + """Calculate and print timing statistics for individual operations""" if not self.timing_data: print("No timing data collected!") return @@ -175,7 +204,7 @@ async def print_statistics(self, calc_bwd: bool = False): timings = self.timing_data sizes = self.size_data - # Calculate statistics + # Calculate statistics for individual operations avg_time = statistics.mean(timings) min_time = min(timings) max_time = max(timings) @@ -184,7 +213,12 @@ async def print_statistics(self, calc_bwd: bool = False): avg_size = statistics.mean(sizes) total_data = sum(sizes) - print("TIMING RESULTS:") + device_type = self.device.upper() if self.device != "cpu" else "CPU" + print("\n" + "=" * 60) + print(f"RDMA {self.operation.upper()} LOAD TEST RESULTS ({device_type})") + print("=" * 60) + + print("INDIVIDUAL OPERATION TIMING:") print(f" Average time per operation: {avg_time * 1000:.3f} ms") print(f" Minimum time per operation: {min_time * 1000:.3f} ms") print(f" Maximum time per operation: {max_time * 1000:.3f} ms") @@ -202,24 +236,70 @@ def calc_bandwidth_gbps(size_bytes: int, time_seconds: float) -> float: max_bandwidth = calc_bandwidth_gbps(avg_size, min_time) min_bandwidth = calc_bandwidth_gbps(avg_size, max_time) - device_type = self.device.upper() if self.device != "cpu" else "CPU" - # Print results - print("\n" + "=" * 60) - print(f"RDMA {self.operation.upper()} LOAD TEST RESULTS ({device_type})") - print("=" * 60) print(f"Total iterations completed: {len(timings)}") print(f"Average data per operation: {avg_size / (1024*1024):.1f} MB") print(f"Total data transferred: {total_data / (1024*1024):.1f} MB") print() - print() - print("BANDWIDTH RESULTS:") + print("INDIVIDUAL OPERATION BANDWIDTH:") print(f" Average bandwidth: {avg_bandwidth:.2f} Gbps") print(f" Maximum bandwidth: {max_bandwidth:.2f} Gbps") print(f" Minimum bandwidth: {min_bandwidth:.2f} Gbps") print("=" * 60) + @endpoint + async def print_batch_statistics( + self, concurrency: int = 1, total_elapsed_time: float = 0.0 + ): + """Calculate and print batch-level statistics for concurrent operations""" + if not self.batch_timing_data: + print("No batch timing data collected!") + return + + batch_timings = self.batch_timing_data + batch_sizes = self.batch_size_data + total_data = sum(self.size_data) + + avg_batch_time = statistics.mean(batch_timings) + min_batch_time = min(batch_timings) + max_batch_time = max(batch_timings) + std_batch_time = ( + statistics.stdev(batch_timings) if len(batch_timings) > 1 else 0.0 + ) + avg_batch_size = statistics.mean(batch_sizes) + + print("\nCONCURRENT BATCH TIMING (wall-clock for all concurrent ops):") + print(f" Average batch time: {avg_batch_time * 1000:.3f} ms") + print(f" Minimum batch time: {min_batch_time * 1000:.3f} ms") + print(f" Maximum batch time: {max_batch_time * 1000:.3f} ms") + print(f" Standard deviation: {std_batch_time * 1000:.3f} ms") + print(f" Average data per batch: {avg_batch_size / (1024*1024):.1f} MB") + + # Calculate bandwidth (Gbps) + def calc_bandwidth_gbps(size_bytes: int, time_seconds: float) -> float: + if time_seconds == 0: + return 0.0 + bits_transferred = size_bytes * 8 + return bits_transferred / (time_seconds * 1e9) + + avg_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, avg_batch_time) + max_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, min_batch_time) + min_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, max_batch_time) + + print(f"\nAGGREGATE BANDWIDTH (concurrency={concurrency}):") + print(f" Average aggregate bandwidth: {avg_aggregate_bw:.2f} Gbps") + print(f" Maximum aggregate bandwidth: {max_aggregate_bw:.2f} Gbps") + print(f" Minimum aggregate bandwidth: {min_aggregate_bw:.2f} Gbps") + + total_throughput = calc_bandwidth_gbps(total_data, total_elapsed_time) + print("\nTOTAL SUSTAINED THROUGHPUT:") + print(f" Total wall-clock time: {total_elapsed_time:.3f} s") + print(f" Total data transferred: {total_data / (1024*1024):.1f} MB") + print(f" Sustained throughput: {total_throughput:.2f} Gbps") + if concurrency > 1: + print(f" (Accounts for {concurrency}x concurrent overlapping operations)") + async def main( devices: list[str], @@ -227,6 +307,7 @@ async def main( operation: str = "write", size_mb: int = 64, warmup_iterations: int = 10, + concurrency: int = 1, ): # Adjust GPU allocation based on the device types device_0, device_1 = devices[0], devices[1] @@ -245,16 +326,20 @@ async def main( await actor_0.set_other_actor.call(actor_1) for i in range(warmup_iterations): - await actor_0.send.call(is_warmup=True) + await actor_0.send.call(is_warmup=True, concurrency=concurrency) + total_start_time = time.time() for i in range(iterations): - await actor_0.send.call() + await actor_0.send.call(concurrency=concurrency) + total_end_time = time.time() + total_elapsed_time = total_end_time - total_start_time - # Have both actors print their statistics - print("\n=== ACTOR 0 (Create Buffer) STATISTICS ===") - await actor_0.print_statistics.call() + # Actor 0: Print batch statistics (concurrency orchestration) + await actor_0.print_batch_statistics.call( + concurrency=concurrency, total_elapsed_time=total_elapsed_time + ) - print("\n=== ACTOR 1 (Create Buffer+Transmit) STATISTICS ===") + # Actor 1: Print individual RDMA transfer statistics await actor_1.print_statistics.call(calc_bwd=True) await mesh_0.stop() @@ -313,5 +398,6 @@ async def main( args.operation, args.size, args.warmup_iterations, + args.concurrency, ) )