diff --git a/monarch_rdma/Cargo.toml b/monarch_rdma/Cargo.toml index d4a70a5c9..d9176026d 100644 --- a/monarch_rdma/Cargo.toml +++ b/monarch_rdma/Cargo.toml @@ -14,6 +14,7 @@ edition = "2024" anyhow = "1.0.98" async-trait = "0.1.86" cuda-sys = { path = "../cuda-sys" } +futures = { version = "0.3.31", features = ["async-await", "compat"] } hyperactor = { version = "0.0.0", path = "../hyperactor" } rand = { version = "0.8", features = ["small_rng"] } rdmaxcel-sys = { path = "../rdmaxcel-sys" } diff --git a/monarch_rdma/src/rdma_components.rs b/monarch_rdma/src/rdma_components.rs index eaac40b86..7c600fe9e 100644 --- a/monarch_rdma/src/rdma_components.rs +++ b/monarch_rdma/src/rdma_components.rs @@ -144,9 +144,9 @@ impl RdmaBuffer { ) .await?; - qp.put(self.clone(), remote)?; + let wr_id = qp.put(self.clone(), remote)?; let result = self - .wait_for_completion(&mut qp, PollTarget::Send, timeout) + .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout) .await; // Release the queue pair back to the actor @@ -197,9 +197,9 @@ impl RdmaBuffer { remote_device.clone(), ) .await?; - qp.get(self.clone(), remote)?; + let wr_id = qp.get(self.clone(), remote)?; let result = self - .wait_for_completion(&mut qp, PollTarget::Send, timeout) + .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout) .await; // Release the queue pair back to the actor @@ -209,34 +209,47 @@ impl RdmaBuffer { result } - /// Waits for the completion of an RDMA operation. + /// Waits for the completion of RDMA operations. /// - /// This method polls the completion queue until the specified work request completes + /// This method polls the completion queue until all specified work requests complete /// or until the timeout is reached. /// /// # Arguments /// * `qp` - The RDMA Queue Pair to poll for completion + /// * `poll_target` - Which CQ to poll (Send or Recv) + /// * `expected_wr_ids` - The work request IDs to wait for /// * `timeout` - Timeout in seconds for the RDMA operation to complete. 
/// /// # Returns - /// `Ok(true)` if the operation completes successfully within the timeout, + /// `Ok(true)` if all operations complete successfully within the timeout, /// or an error if the timeout is reached async fn wait_for_completion( &self, qp: &mut RdmaQueuePair, poll_target: PollTarget, + expected_wr_ids: &[u64], timeout: u64, ) -> Result<bool, anyhow::Error> { let timeout = Duration::from_secs(timeout); let start_time = std::time::Instant::now(); + let mut remaining: std::collections::HashSet<u64> = + expected_wr_ids.iter().copied().collect(); + while start_time.elapsed() < timeout { - match qp.poll_completion_target(poll_target) { - Ok(Some(_wc)) => { - tracing::debug!("work completed"); - return Ok(true); - } - Ok(None) => { + if remaining.is_empty() { + return Ok(true); + } + + let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect(); + match qp.poll_completion(poll_target, &wr_ids_to_poll) { + Ok(completions) => { + for (wr_id, _wc) in completions { + remaining.remove(&wr_id); + } + if remaining.is_empty() { + return Ok(true); + } RealClock.sleep(Duration::from_millis(1)).await; } Err(e) => { @@ -251,10 +264,14 @@ impl RdmaBuffer { } } } - tracing::error!("timed out while waiting on request completion"); + tracing::error!( + "timed out while waiting on request completion for wr_ids={:?}", + remaining + ); Err(anyhow::anyhow!( - "[buffer({:?})] rdma operation did not complete in time", - self + "[buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})", + self, + expected_wr_ids )) } @@ -287,6 +304,7 @@ impl RdmaBuffer { /// /// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device. /// * `pd`: A pointer to the protection domain, which provides isolation between different connections. +#[derive(Clone)] pub struct RdmaDomain { pub context: *mut rdmaxcel_sys::ibv_context, pub pd: *mut rdmaxcel_sys::ibv_pd, @@ -446,24 +464,23 @@ pub enum PollTarget { /// 3. Exchange connection info with remote peer (application must handle this) /// 4. Connect to remote endpoint with `connect()` /// 5. Perform RDMA operations with `put()` or `get()` -/// 6. Poll for completions with `poll_send_completion()` or `poll_recv_completion()` +/// 6. Poll for completions with `poll_completion(target, wr_ids)` +/// +/// # Notes +/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`) +/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally +/// - This makes RdmaQueuePair trivially Clone and Serialize +/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer #[derive(Debug, Serialize, Deserialize, Named, Clone)] pub struct RdmaQueuePair { pub send_cq: usize, // *mut rdmaxcel_sys::ibv_cq, pub recv_cq: usize, // *mut rdmaxcel_sys::ibv_cq, - pub qp: usize, // *mut rdmaxcel_sys::ibv_qp, + pub qp: usize, // *mut rdmaxcel_sys::rdmaxcel_qp_t pub dv_qp: usize, // *mut rdmaxcel_sys::mlx5dv_qp, pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq, pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq, context: usize, // *mut rdmaxcel_sys::ibv_context, config: IbverbsConfig, - pub send_wqe_idx: u64, - pub send_db_idx: u64, - pub send_cq_idx: u64, - pub recv_wqe_idx: u64, - pub recv_db_idx: u64, - pub recv_cq_idx: u64, - rts_timestamp: u64, } impl RdmaQueuePair { @@ -472,22 +489,26 @@ impl RdmaQueuePair { /// This ensures the hardware has sufficient time to settle after reaching /// Ready-to-Send state before the first actual operation.
fn apply_first_op_delay(&self, wr_id: u64) { - if wr_id == 0 { - assert!( - self.rts_timestamp != u64::MAX, - "First operation attempted before queue pair reached RTS state! Call connect() first." - ); - let current_nanos = RealClock - .system_time_now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_nanos() as u64; - let elapsed_nanos = current_nanos - self.rts_timestamp; - let elapsed = Duration::from_nanos(elapsed_nanos); - let init_delay = Duration::from_millis(self.config.hw_init_delay_ms); - if elapsed < init_delay { - let remaining_delay = init_delay - elapsed; - sleep(remaining_delay); + unsafe { + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + if wr_id == 0 { + let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp); + assert!( + rts_timestamp != u64::MAX, + "First operation attempted before queue pair reached RTS state! Call connect() first." + ); + let current_nanos = RealClock + .system_time_now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() as u64; + let elapsed_nanos = current_nanos - rts_timestamp; + let elapsed = Duration::from_nanos(elapsed_nanos); + let init_delay = Duration::from_millis(self.config.hw_init_delay_ms); + if elapsed < init_delay { + let remaining_delay = init_delay - elapsed; + sleep(remaining_delay); + } } } } @@ -521,8 +542,7 @@ impl RdmaQueuePair { unsafe { // Resolve Auto to a concrete QP type based on device capabilities let resolved_qp_type = resolve_qp_type(config.qp_type); - - let qp = rdmaxcel_sys::create_qp( + let qp = rdmaxcel_sys::rdmaxcel_qp_create( context, pd, config.cq_entries, @@ -541,18 +561,18 @@ impl RdmaQueuePair { )); } - let send_cq = (*qp).send_cq; - let recv_cq = (*qp).recv_cq; + let send_cq = (*(*qp).ibv_qp).send_cq; + let recv_cq = (*(*qp).ibv_qp).recv_cq; // mlx5dv provider APIs - let dv_qp = rdmaxcel_sys::create_mlx5dv_qp(qp); - let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq(qp); - let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq(qp); + let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp); + let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp); + let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp); if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() { - rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq); - rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq); - rdmaxcel_sys::ibv_destroy_qp(qp); + rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq); + rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq); + rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp); return Err(anyhow::anyhow!( "failed to init mlx5dv_qp or completion queues" )); @@ -562,9 +582,9 @@ impl RdmaQueuePair { if config.use_gpu_direct { let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq); if ret != 0 { - rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq); - rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq); - rdmaxcel_sys::ibv_destroy_qp(qp); + rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq); + rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq); + rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp); return Err(anyhow::anyhow!( "failed to register GPU Direct RDMA memory: {:?}", ret @@ -575,18 +595,11 @@ impl RdmaQueuePair { send_cq: send_cq as usize, recv_cq: recv_cq as usize, qp: qp as usize, dv_qp: dv_qp as usize, dv_send_cq: dv_send_cq as usize, dv_recv_cq: dv_recv_cq as usize, context: context as usize, config, - recv_db_idx: 0, - recv_wqe_idx: 0, - recv_cq_idx: 0, - send_db_idx: 0, - send_wqe_idx: 0, - send_cq_idx: 0, - rts_timestamp:
u64::MAX, }) } } @@ -615,7 +628,7 @@ impl RdmaQueuePair { // - The memory address provided is only stored, not dereferenced in this function unsafe { let context = self.context as *mut rdmaxcel_sys::ibv_context; - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let mut port_attr = rdmaxcel_sys::ibv_port_attr::default(); let errno = rdmaxcel_sys::ibv_query_port( context, @@ -642,7 +655,7 @@ impl RdmaQueuePair { } Ok(RdmaQpInfo { - qp_num: (*qp).qp_num, + qp_num: (*(*qp).ibv_qp).qp_num, lid: port_attr.lid, gid: Some(gid), psn: self.config.psn, @@ -653,7 +666,7 @@ impl RdmaQueuePair { pub fn state(&mut self) -> Result<u32, anyhow::Error> { // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls. unsafe { - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let mut qp_attr = rdmaxcel_sys::ibv_qp_attr { ..Default::default() }; let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr { ..Default::default() }; let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE; - let errno = - rdmaxcel_sys::ibv_query_qp(qp, &mut qp_attr, mask.0 as i32, &mut qp_init_attr); + let errno = rdmaxcel_sys::ibv_query_qp( + (*qp).ibv_qp, + &mut qp_attr, + mask.0 as i32, + &mut qp_init_attr, + ); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!("failed to query QP state: {}", os_error)); } @@ -687,7 +704,7 @@ impl RdmaQueuePair { // 4. Memory access is properly bounded by the registered memory regions unsafe { // Transition to INIT - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE @@ -707,7 +724,7 @@ impl RdmaQueuePair { | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS; - let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -755,7 +772,7 @@ impl RdmaQueuePair { | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER; - let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -782,7 +799,7 @@ impl RdmaQueuePair { | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC; - let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -796,56 +813,66 @@ impl RdmaQueuePair { ); // Record RTS time now that the queue pair is ready to send - self.rts_timestamp = RealClock + let rts_timestamp_nanos = RealClock .system_time_now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_nanos() as u64; + rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos); Ok(()) } } - pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> { - let idx = self.recv_wqe_idx; - self.recv_wqe_idx += 1; - self.send_wqe( - 0, - lhandle.lkey, - 0, - idx, - true, -
RdmaOperation::Recv, - 0, - rhandle.rkey, - ) - .unwrap(); - Ok(()) + pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<u64, anyhow::Error> { + unsafe { + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp); + self.post_op( + 0, + lhandle.lkey, + 0, + idx, + true, + RdmaOperation::Recv, + 0, + rhandle.rkey, + ) + .unwrap(); + rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp); + Ok(idx) + } } pub fn put_with_recv( &mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer, - ) -> Result<(), anyhow::Error> { - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; - self.post_op( - lhandle.addr, - lhandle.lkey, - lhandle.size, - idx, - true, - RdmaOperation::WriteWithImm, - rhandle.addr, - rhandle.rkey, - ) - .unwrap(); - self.send_db_idx += 1; - Ok(()) + ) -> Result<Vec<u64>, anyhow::Error> { + unsafe { + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp); + self.post_op( + lhandle.addr, + lhandle.lkey, + lhandle.size, + idx, + true, + RdmaOperation::WriteWithImm, + rhandle.addr, + rhandle.rkey, + ) + .unwrap(); + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp); + Ok(vec![idx]) + } } - pub fn put(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> { + pub fn put( + &mut self, + lhandle: RdmaBuffer, + rhandle: RdmaBuffer, + ) -> Result<Vec<u64>, anyhow::Error> { let total_size = lhandle.size; if rhandle.size < total_size { return Err(anyhow::anyhow!( @@ -857,10 +884,15 @@ impl RdmaQueuePair { let mut remaining = total_size; let mut offset = 0; + let mut wr_ids = Vec::new(); while remaining > 0 { let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE); - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + wr_ids.push(idx); self.post_op( lhandle.addr + offset, lhandle.lkey, @@ -871,13 +903,17 @@ impl RdmaQueuePair { rhandle.addr + offset, rhandle.rkey, )?; - self.send_db_idx += 1; + unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ); + } remaining -= chunk_size; offset += chunk_size; } - Ok(()) + Ok(wr_ids) } /// Get a doorbell for the queue pair.
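To make the new `put()`/`get()` contract concrete: each chunked operation now claims its work-request IDs from a shared atomic counter and returns them to the caller. Below is a minimal, self-contained sketch of that pattern; the `AtomicU64` stands in for the counter that `rdmaxcel_qp_t` keeps on the C side, and the chunk-size constant here is illustrative rather than the crate's real value.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Illustrative stand-in for the crate's MAX_RDMA_MSG_SIZE constant.
const MAX_CHUNK: usize = 1 << 30;

/// Models rdmaxcel_qp_fetch_add_send_wqe_idx: every chunk of a transfer
/// claims one unique wr_id, and the full list is handed back to the caller
/// so completions can be matched by id instead of by queue position.
fn claim_chunk_wr_ids(total_size: usize, send_wqe_idx: &AtomicU64) -> Vec<u64> {
    let mut wr_ids = Vec::new();
    let mut remaining = total_size;
    while remaining > 0 {
        let chunk = remaining.min(MAX_CHUNK);
        // fetch_add returns the pre-increment value, so concurrent callers
        // interleave freely without ever sharing a wr_id.
        wr_ids.push(send_wqe_idx.fetch_add(1, Ordering::SeqCst));
        remaining -= chunk;
    }
    wr_ids
}
```

Because the counter lives in the shared `rdmaxcel_qp_t` rather than in each `RdmaQueuePair` clone, two clones issuing puts concurrently still receive disjoint wr_id sets.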
@@ -890,19 +926,23 @@ impl RdmaQueuePair { /// * `Result` - A doorbell for the queue pair pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> { unsafe { + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; let base_ptr = (*dv_qp).sq.buf as *mut u8; let wqe_cnt = (*dv_qp).sq.wqe_cnt; let stride = (*dv_qp).sq.stride; - if (wqe_cnt as u64) < (self.send_wqe_idx - self.send_db_idx) { + let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp); + let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp); + if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) { return Err(anyhow::anyhow!("Overflow of WQE, possible data loss")); } - self.apply_first_op_delay(self.send_db_idx); - while self.send_db_idx < self.send_wqe_idx { - let offset = (self.send_db_idx % wqe_cnt as u64) * stride as u64; + self.apply_first_op_delay(send_db_idx); + while send_db_idx < send_wqe_idx { + let offset = (send_db_idx % wqe_cnt as u64) * stride as u64; let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize); rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void); - self.send_db_idx += 1; + send_db_idx += 1; + rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx); } Ok(()) } @@ -920,14 +960,18 @@ impl RdmaQueuePair { /// /// # Returns /// - /// * `Result<(), anyhow::Error>` - Success or error + /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error pub fn enqueue_put( &mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer, - ) -> Result<(), anyhow::Error> { - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + ) -> Result<Vec<u64>, anyhow::Error> { + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + self.send_wqe( lhandle.addr, lhandle.lkey, @@ -938,7 +982,7 @@ impl RdmaQueuePair { rhandle.addr, rhandle.rkey, )?; - Ok(()) + Ok(vec![idx]) } /// Enqueues a put with receive operation without ringing the doorbell. @@ -953,14 +997,18 @@ impl RdmaQueuePair { /// /// # Returns /// - /// * `Result<(), anyhow::Error>` - Success or error + /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error pub fn enqueue_put_with_recv( &mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer, - ) -> Result<(), anyhow::Error> { - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + ) -> Result<Vec<u64>, anyhow::Error> { + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + self.send_wqe( lhandle.addr, lhandle.lkey, @@ -971,7 +1019,7 @@ impl RdmaQueuePair { rhandle.addr, rhandle.rkey, )?; - Ok(()) + Ok(vec![idx]) } /// Enqueues a get operation without ringing the doorbell.
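Taken together, the enqueue/doorbell/poll split gives callers an explicit three-step flow. A hypothetical caller, assuming the signatures this diff introduces (`enqueue_put` returning `Vec<u64>`, `ring_doorbell`, and the test-utils `wait_for_completion` changed later in this patch), might look like this sketch:

```rust
// Sketch only: RdmaQueuePair, RdmaBuffer, PollTarget and wait_for_completion
// are the monarch_rdma items touched by this diff, not new API.
async fn put_and_wait(
    qp: &mut RdmaQueuePair,
    local: RdmaBuffer,
    remote: RdmaBuffer,
) -> Result<(), anyhow::Error> {
    // Stage the WQE(s); nothing is visible to the HCA yet.
    let wr_ids = qp.enqueue_put(local, remote)?;
    // Publish every staged WQE by advancing the doorbell index.
    qp.ring_doorbell()?;
    // Block (with a 5s budget) until each returned wr_id has completed.
    wait_for_completion(qp, PollTarget::Send, &wr_ids, 5).await?;
    Ok(())
}
```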
@@ -986,14 +1034,18 @@ impl RdmaQueuePair { /// /// # Returns /// - /// * `Result<(), anyhow::Error>` - Success or error + /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error pub fn enqueue_get( &mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer, - ) -> Result<(), anyhow::Error> { - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + ) -> Result<Vec<u64>, anyhow::Error> { + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + self.send_wqe( lhandle.addr, lhandle.lkey, @@ -1004,10 +1056,14 @@ impl RdmaQueuePair { rhandle.addr, rhandle.rkey, )?; - Ok(()) + Ok(vec![idx]) } - pub fn get(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> { + pub fn get( + &mut self, + lhandle: RdmaBuffer, + rhandle: RdmaBuffer, + ) -> Result<Vec<u64>, anyhow::Error> { let total_size = lhandle.size; if rhandle.size < total_size { return Err(anyhow::anyhow!( @@ -1019,11 +1075,17 @@ impl RdmaQueuePair { let mut remaining = total_size; let mut offset = 0; + let mut wr_ids = Vec::new(); while remaining > 0 { let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE); - let idx = self.send_wqe_idx; - self.send_wqe_idx += 1; + let idx = unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ) + }; + wr_ids.push(idx); + self.post_op( lhandle.addr + offset, lhandle.lkey, @@ -1034,13 +1096,17 @@ impl RdmaQueuePair { rhandle.addr + offset, rhandle.rkey, )?; - self.send_db_idx += 1; + unsafe { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx( + self.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ); + } remaining -= chunk_size; offset += chunk_size; } - Ok(()) + Ok(wr_ids) } /// Posts a request to the queue pair. @@ -1073,7 +1139,7 @@ impl RdmaQueuePair { // - The ibverbs post_send operation follows the documented API contract // - Error codes from the device are properly checked and propagated unsafe { - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let context = self.context as *mut rdmaxcel_sys::ibv_context; let ops = &mut (*context).ops; let errno; @@ -1090,7 +1156,8 @@ impl RdmaQueuePair { ..Default::default() }; let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut(); - errno = ops.post_recv.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr); + errno = + ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr); } else if op_type == RdmaOperation::Write || op_type == RdmaOperation::Read || op_type == RdmaOperation::WriteWithImm @@ -1124,7 +1191,8 @@ impl RdmaQueuePair { wr.wr.rdma.rkey = rkey; let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut(); - errno = ops.post_send.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr); + errno = + ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr); } else { panic!("Not Implemented"); } @@ -1166,7 +1234,7 @@ impl RdmaQueuePair { RdmaOperation::Recv => 0, }; - let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; let _dv_cq = if op_type == RdmaOperation::Recv { self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq @@ -1191,7 +1259,7 @@ impl RdmaQueuePair { op_type: op_type_val, raddr, rkey, - qp_num: (*qp).qp_num, + qp_num: (*(*qp).ibv_qp).qp_num, buf, dbrec: (*dv_qp).dbrec, wqe_cnt: (*dv_qp).sq.wqe_cnt, @@ -1214,114 +1282,123 @@ impl RdmaQueuePair { } } - /// Poll for completions on the specified
completion queue(s) + /// Poll for work completions by wr_ids. /// /// # Arguments /// - /// * `target` - Which completion queue(s) to poll (Send, Receive) + /// * `target` - Which completion queue to poll (Send, Receive) + /// * `expected_wr_ids` - Slice of work request IDs to wait for /// /// # Returns /// - /// * `Ok(Some(wc))` - A completion was found - /// * `Ok(None)` - No completion was found + /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs that were found /// * `Err(e)` - An error occurred - pub fn poll_completion_target( + pub fn poll_completion( &mut self, target: PollTarget, - ) -> Result<Option<IbvWc>, anyhow::Error> { + expected_wr_ids: &[u64], + ) -> Result<Vec<(u64, IbvWc)>, anyhow::Error> { + if expected_wr_ids.is_empty() { + return Ok(Vec::new()); + } + unsafe { - let context = self.context as *mut rdmaxcel_sys::ibv_context; - let _outstanding_wqe = - self.send_db_idx + self.recv_db_idx - self.send_cq_idx - self.recv_cq_idx; + let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + let qp_num = (*(*qp).ibv_qp).qp_num; + + let (cq, cache, cq_type) = match target { + PollTarget::Send => ( + self.send_cq as *mut rdmaxcel_sys::ibv_cq, + rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp), + "send", + ), + PollTarget::Recv => ( + self.recv_cq as *mut rdmaxcel_sys::ibv_cq, + rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp), + "recv", + ), + }; - // Check for send completions if requested - if (target == PollTarget::Send) && self.send_db_idx > self.send_cq_idx { - let send_cq = self.send_cq as *mut rdmaxcel_sys::ibv_cq; - let ops = &mut (*context).ops; - let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init(); - let ret = ops.poll_cq.as_mut().unwrap()(send_cq, 1, &mut wc); + let mut results = Vec::new(); - if ret < 0 { - return Err(anyhow::anyhow!( - "Failed to poll send CQ: {}", - Error::last_os_error() - )); - } + // Single-shot poll: check each wr_id once and return what we find + for &expected_wr_id in expected_wr_ids { + let mut poll_ctx = rdmaxcel_sys::poll_context_t { + expected_wr_id, + expected_qp_num: qp_num, + cache, + cq, + }; - if ret > 0 { - if !wc.is_valid() { - if let Some((status, vendor_err)) = wc.error() { - return Err(anyhow::anyhow!( - "Send work completion failed with status: {:?}, vendor error: {}, wr_id: {}, send_cq_idx: {}", - status, - vendor_err, - wc.wr_id(), - self.send_cq_idx, - )); + let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init(); + let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc); + + match ret { + 1 => { + // Found completion + if !wc.is_valid() { + if let Some((status, vendor_err)) = wc.error() { + return Err(anyhow::anyhow!( + "{} completion failed for wr_id={}: status={:?}, vendor_err={}", + cq_type, + expected_wr_id, + status, + vendor_err, + )); + } } + results.push((expected_wr_id, IbvWc::from(wc))); } - - // This should be a send completion - verify it's the one we're waiting for - if wc.wr_id() == self.send_cq_idx { - self.send_cq_idx += 1; - } - // finished polling, return the last completion - if self.send_cq_idx == self.send_db_idx { - return Ok(Some(IbvWc::from(wc))); + 0 => { + // Not found yet - this is fine for single-shot poll } - } - } - - // Check for receive completions if requested - if (target == PollTarget::Recv) && self.recv_db_idx > self.recv_cq_idx { - let recv_cq = self.recv_cq as *mut rdmaxcel_sys::ibv_cq; - let ops = &mut (*context).ops; - let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init(); - let ret = ops.poll_cq.as_mut().unwrap()(recv_cq, 1, &mut wc); - - if ret < 0 { - return
Err(anyhow::anyhow!( - "Failed to poll receive CQ: {}", - Error::last_os_error() - )); - } - - if ret > 0 { - if !wc.is_valid() { + -17 => { + // RDMAXCEL_COMPLETION_FAILED: Completion found but failed - wc contains the error details + let error_msg = + std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret)) + .to_str() + .unwrap_or("Unknown error"); if let Some((status, vendor_err)) = wc.error() { return Err(anyhow::anyhow!( - "Recv work completion failed with status: {:?}, vendor error: {}, wr_id: {}, send_cq_idx: {}", + "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]", + cq_type, + expected_wr_id, + error_msg, status, vendor_err, - wc.wr_id(), - self.recv_cq_idx, + wc.qp_num, + wc.len(), + )); + } else { + return Err(anyhow::anyhow!( + "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]", + cq_type, + expected_wr_id, + error_msg, + wc.qp_num, + wc.len(), )); } } - - // This should be a send completion - verify it's the one we're waiting for - if wc.wr_id() == self.recv_cq_idx { - self.recv_cq_idx += 1; - } - // finished polling, return the last completion - if self.recv_cq_idx == self.recv_db_idx { - return Ok(Some(IbvWc::from(wc))); + _ => { + // Other errors + let error_msg = + std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret)) + .to_str() + .unwrap_or("Unknown error"); + return Err(anyhow::anyhow!( + "Failed to poll {} CQ for wr_id={}: {}", + cq_type, + expected_wr_id, + error_msg + )); } } } - // No completion found - Ok(None) + Ok(results) } } - - pub fn poll_send_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> { - self.poll_completion_target(PollTarget::Send) - } - - pub fn poll_recv_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> { - self.poll_completion_target(PollTarget::Recv) - } } /// Utility to validate execution context. diff --git a/monarch_rdma/src/rdma_manager_actor.rs b/monarch_rdma/src/rdma_manager_actor.rs index 2083d3fa3..936e2620f 100644 --- a/monarch_rdma/src/rdma_manager_actor.rs +++ b/monarch_rdma/src/rdma_manager_actor.rs @@ -28,8 +28,13 @@ //! //! See test examples: `test_rdma_write_loopback` and `test_rdma_read_loopback`. use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; use async_trait::async_trait; +use futures::lock::Mutex; use hyperactor::Actor; use hyperactor::ActorId; use hyperactor::ActorRef; @@ -40,6 +45,7 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::RefClient; +use hyperactor::clock::Clock; use hyperactor::supervision::ActorSupervisionEvent; use serde::Deserialize; use serde::Serialize; @@ -48,6 +54,7 @@ use crate::ibverbs_primitives::IbverbsConfig; use crate::ibverbs_primitives::RdmaMemoryRegionView; use crate::ibverbs_primitives::RdmaQpInfo; use crate::ibverbs_primitives::ibverbs_supported; +use crate::ibverbs_primitives::mlx5dv_supported; use crate::ibverbs_primitives::resolve_qp_type; use crate::rdma_components::RdmaBuffer; use crate::rdma_components::RdmaDomain; @@ -55,13 +62,6 @@ use crate::rdma_components::RdmaQueuePair; use crate::rdma_components::get_registered_cuda_segments; use crate::validate_execution_context; -/// Represents the state of a queue pair in the manager, either available or checked out.
-#[derive(Debug, Clone)] -pub enum QueuePairState { - Available(RdmaQueuePair), - CheckedOut, -} - /// Helper function to get detailed error messages from RDMAXCEL error codes pub fn get_rdmaxcel_error_message(error_code: i32) -> String { unsafe { @@ -128,6 +128,14 @@ pub enum RdmaManagerMessage { /// `qp` - The queue pair to return (ownership transferred back) qp: RdmaQueuePair, }, + GetQpState { + other: ActorRef<RdmaManagerActor>, + self_device: String, + other_device: String, + #[reply] + /// `reply` - Reply channel to return the QP state + reply: OncePortRef<u32>, + }, } #[derive(Debug)] @@ -138,12 +146,16 @@ pub enum RdmaManagerMessage { ], )] pub struct RdmaManagerActor { - // Nested map: local_device -> (ActorId, remote_device) -> QueuePairState - device_qps: HashMap<String, HashMap<(ActorId, String), QueuePairState>>, + // Nested map: local_device -> (ActorId, remote_device) -> RdmaQueuePair + device_qps: HashMap<String, HashMap<(ActorId, String), RdmaQueuePair>>, + + // Track QPs currently being created to prevent duplicate creation + // Wrapped in Arc to allow safe concurrent access + pending_qp_creation: Arc<Mutex<HashSet<(String, ActorId, String)>>>, // Map of RDMA device names to their domains and loopback QPs // Created lazily when memory is registered for a specific device - device_domains: HashMap<String, (RdmaDomain, RdmaQueuePair)>, + device_domains: HashMap<String, (RdmaDomain, Option<RdmaQueuePair>)>, config: IbverbsConfig, @@ -171,14 +183,8 @@ impl Drop for RdmaManagerActor { fn destroy_queue_pair(qp: &RdmaQueuePair, context: &str) { unsafe { if qp.qp != 0 { - let result = rdmaxcel_sys::ibv_destroy_qp(qp.qp as *mut rdmaxcel_sys::ibv_qp); - if result != 0 { - tracing::debug!( - "ibv_destroy_qp returned {} for {} (may be busy during shutdown)", - result, - context - ); - } + let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp); } if qp.send_cq != 0 { let result = @@ -206,30 +212,17 @@ impl Drop for RdmaManagerActor { } // 1. Clean up all queue pairs (both regular and loopback) - for (device_name, device_map) in self.device_qps.drain() { - for ((actor_id, remote_device), qp_state) in device_map { - match qp_state { - QueuePairState::Available(qp) => { - destroy_queue_pair(&qp, &format!("actor {:?}", actor_id)); - } - QueuePairState::CheckedOut => { - tracing::warn!( - "QP for actor {:?} (device {} -> {}) was checked out during cleanup", - actor_id, - device_name, - remote_device - ); - } - } + for (_device_name, device_map) in self.device_qps.drain() { + for ((actor_id, _remote_device), qp) in device_map { + destroy_queue_pair(&qp, &format!("actor {:?}", actor_id)); } } // 2.
Clean up device domains (which contain PDs and loopback QPs) - for (device_name, (domain, loopback_qp)) in self.device_domains.drain() { - destroy_queue_pair( - &loopback_qp, - &format!("loopback QP on device {}", device_name), - ); + for (device_name, (domain, qp)) in self.device_domains.drain() { + if let Some(qp) = qp { + destroy_queue_pair(&qp, &format!("loopback QP on device {}", device_name)); + } drop(domain); } @@ -278,10 +271,10 @@ impl RdmaManagerActor { &mut self, device_name: &str, rdma_device: &crate::ibverbs_primitives::RdmaDevice, - ) -> Result<(*mut rdmaxcel_sys::ibv_pd, *mut rdmaxcel_sys::ibv_qp), anyhow::Error> { + ) -> Result<(RdmaDomain, Option<RdmaQueuePair>), anyhow::Error> { // Check if we already have a domain for this device if let Some((domain, qp)) = self.device_domains.get(device_name) { - return Ok((domain.pd, qp.qp as *mut rdmaxcel_sys::ibv_qp)); + return Ok((domain.clone(), qp.clone())); } // Create new domain for this device @@ -292,43 +285,38 @@ impl RdmaManagerActor { // Print device info if MONARCH_DEBUG_RDMA=1 is set (before initial QP creation) crate::print_device_info_if_debug_enabled(domain.context); - // Create loopback QP for this domain - let mut loopback_qp = RdmaQueuePair::new(domain.context, domain.pd, self.config.clone()) - .map_err(|e| { + // Create loopback QP for this domain if mlx5dv is supported + let qp = if mlx5dv_supported() { + let mut qp = RdmaQueuePair::new(domain.context, domain.pd, self.config.clone()) + .map_err(|e| { + anyhow::anyhow!( + "could not create loopback QP for device {}: {}", + device_name, + e + ) + })?; + + // Get connection info and connect to itself + let endpoint = qp.get_qp_info().map_err(|e| { + anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e) + })?; + + qp.connect(&endpoint).map_err(|e| { anyhow::anyhow!( - "could not create loopback QP for device {}: {}", + "could not connect loopback QP for device {}: {}", device_name, e ) })?; - // Get connection info and connect to itself - let endpoint = loopback_qp.get_qp_info().map_err(|e| { - anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e) - })?; - - loopback_qp.connect(&endpoint).map_err(|e| { - anyhow::anyhow!( - "could not connect loopback QP for device {}: {}", - device_name, - e - ) - })?; - - tracing::debug!( - "Created domain and loopback QP for RDMA device: {}", - device_name - ); - - // Store PD and QP pointers before inserting - let pd = domain.pd; - let qp = loopback_qp.qp as *mut rdmaxcel_sys::ibv_qp; + Some(qp) + } else { + None + }; - // Store the domain and QP self.device_domains - .insert(device_name.to_string(), (domain, loopback_qp)); - - Ok((pd, qp)) + .insert(device_name.to_string(), (domain.clone(), qp.clone())); + Ok((domain, qp)) } fn find_cuda_segment_for_address( @@ -417,8 +405,7 @@ impl RdmaManagerActor { ); // Get or create domain and loopback QP for this device - let (domain_pd, loopback_qp_ptr) = - self.get_or_create_device_domain(&device_name, &rdma_device)?; + let (domain, qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?; let access = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE @@ -433,7 +420,10 @@ impl RdmaManagerActor { let mut maybe_mrv = self.find_cuda_segment_for_address(addr, size); // not found, let's re-sync with caching allocator and retry if maybe_mrv.is_none() { - let err = rdmaxcel_sys::register_segments(domain_pd, loopback_qp_ptr); + let err = rdmaxcel_sys::register_segments( + domain.pd,
qp.unwrap().qp as *mut rdmaxcel_sys::rdmaxcel_qp_t, + ); if err != 0 { let error_msg = get_rdmaxcel_error_message(err); return Err(anyhow::anyhow!( @@ -464,7 +454,7 @@ impl RdmaManagerActor { rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0, ); - mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(domain_pd, 0, size, 0, fd, access.0 as i32); + mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32); if mr.is_null() { return Err(anyhow::anyhow!("Failed to register dmabuf MR")); } @@ -480,7 +470,7 @@ impl RdmaManagerActor { } else { // CPU memory path mr = rdmaxcel_sys::ibv_reg_mr( - domain_pd, + domain.pd, addr as *mut std::ffi::c_void, size, access.0 as i32, @@ -561,6 +551,7 @@ impl Actor for RdmaManagerActor { Ok(Self { device_qps: HashMap::new(), + pending_qp_creation: Arc::new(Mutex::new(HashSet::new())), device_domains: HashMap::new(), config, pt_cuda_alloc, @@ -664,121 +655,170 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { &mut self, cx: &Context<Self>, other: ActorRef<RdmaManagerActor>, self_device: String, other_device: String, ) -> Result<RdmaQueuePair, anyhow::Error> { let other_id = other.actor_id().clone(); - // Use the nested map structure: local_device -> (ActorId, remote_device) -> QueuePairState + // Use the nested map structure: local_device -> (ActorId, remote_device) -> RdmaQueuePair let inner_key = (other_id.clone(), other_device.clone()); // Check if queue pair exists in map if let Some(device_map) = self.device_qps.get(&self_device) { - if let Some(qp_state) = device_map.get(&inner_key).cloned() { - match qp_state { - QueuePairState::Available(qp) => { - // Queue pair exists and is available - return it - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::CheckedOut); - return Ok(qp); - } - QueuePairState::CheckedOut => { - return Err(anyhow::anyhow!( - "queue pair for actor {} on device {} is already checked out", - other_id, - other_device - )); + if let Some(qp) = device_map.get(&inner_key) { + return Ok(qp.clone()); + } + } + + // Try to acquire lock and mark as pending (hold lock only once!) + let pending_key = (self_device.clone(), other_id.clone(), other_device.clone()); + let mut pending = self.pending_qp_creation.lock().await; + + if pending.contains(&pending_key) { + // Another task is creating this QP, release lock and wait + drop(pending); + + // Loop checking device_qps until QP is created (no more locks needed) + // Timeout after 1 second + let start = Instant::now(); + let timeout = Duration::from_secs(1); + + loop { + cx.clock().sleep(Duration::from_micros(200)).await; + + // Check if QP was created while we waited + if let Some(device_map) = self.device_qps.get(&self_device) { + if let Some(qp) = device_map.get(&inner_key) { + return Ok(qp.clone()); } } + + // Check for timeout + if start.elapsed() >= timeout { + return Err(anyhow::anyhow!( + "Timeout waiting for QP creation (device {} -> actor {} device {}).
\ Another task is creating it but hasn't completed in 1 second", + self_device, + other_id, + other_device + )); + } + } + } else { + // Not pending, add to set and proceed with creation + pending.insert(pending_key.clone()); + drop(pending); + // Fall through to create QP } // Queue pair doesn't exist - need to create connection - let is_loopback = other_id == cx.bind::<RdmaManagerActor>().actor_id().clone() - && self_device == other_device; - - if is_loopback { - // Loopback connection setup - self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - let endpoint = self - .connection_info(cx, other.clone(), other_device.clone(), self_device.clone()) - .await?; - self.connect( - cx, - other.clone(), - self_device.clone(), - other_device.clone(), - endpoint, - ) - .await?; - } else { - // Remote connection setup - self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - other - .initialize_qp( + let result = async { + let is_loopback = other_id == cx.bind::<RdmaManagerActor>().actor_id().clone() + && self_device == other_device; + + if is_loopback { + // Loopback connection setup + self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + let endpoint = self + .connection_info(cx, other.clone(), other_device.clone(), self_device.clone()) + .await?; + self.connect( cx, - cx.bind().clone(), - other_device.clone(), + other.clone(), self_device.clone(), - ) - .await?; - let other_endpoint: RdmaQpInfo = other - .connection_info( - cx, - cx.bind().clone(), other_device.clone(), - self_device.clone(), + endpoint, ) .await?; - self.connect( - cx, - other.clone(), - self_device.clone(), - other_device.clone(), - other_endpoint, - ) - .await?; - let local_endpoint = self - .connection_info(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - other - .connect( + } else { + // Remote connection setup + self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + other + .initialize_qp( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + ) + .await?; + let other_endpoint: RdmaQpInfo = other + .connection_info( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + ) + .await?; + self.connect( cx, - cx.bind().clone(), - other_device.clone(), + other.clone(), self_device.clone(), - local_endpoint, + other_device.clone(), + other_endpoint, ) .await?; + let local_endpoint = self + .connection_info(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + other + .connect( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + local_endpoint, + ) + .await?; + + // BARRIER: Ensure remote side has completed its connection and is ready + let remote_state = other + .get_qp_state( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + ) + .await?; + + if remote_state != rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS { + return Err(anyhow::anyhow!( + "Remote QP not in RTS state after connection setup. \ + Local is ready but remote is in state {}.
\ + This indicates a synchronization issue in connection setup.", + remote_state + )); + } + } - // Now that connection is established, get the queue pair - if let Some(device_map) = self.device_qps.get(&self_device) { - if let Some(QueuePairState::Available(qp)) = device_map.get(&inner_key).cloned() { - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::CheckedOut); - Ok(qp) + // Now that connection is established, get and clone the queue pair + if let Some(device_map) = self.device_qps.get(&self_device) { + if let Some(qp) = device_map.get(&inner_key) { + Ok(qp.clone()) + } else { + Err(anyhow::anyhow!( + "Failed to create connection for actor {} on device {}", + other_id, + other_device + )) + } } else { Err(anyhow::anyhow!( - "Failed to create connection for actor {} on device {}", + "Failed to create connection for actor {} on device {} - no device map", other_id, other_device )) } - } else { - Err(anyhow::anyhow!( - "Failed to create connection for actor {} on device {} - no device map", - other_id, - other_device - )) } + .await; + + // Always remove from pending set when done (success or failure) + let mut pending = self.pending_qp_creation.lock().await; + pending.remove(&pending_key); + drop(pending); + + result } async fn initialize_qp( @@ -813,13 +853,7 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { // Get or create domain and extract pointers to avoid borrowing issues let (domain_context, domain_pd) = { // Check if we already have a domain for the device - if !self.device_domains.contains_key(&self_device) { - // Create domain first if it doesn't exist - self.get_or_create_device_domain(&self_device, &rdma_device)?; - } - - // Now get the domain context and PD safely - let (domain, _qp) = self.device_domains.get(&self_device).unwrap(); + let (domain, _) = self.get_or_create_device_domain(&self_device, &rdma_device)?; (domain.context, domain.pd) }; @@ -830,7 +864,7 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { self.device_qps .entry(self_device.clone()) .or_insert_with(HashMap::new) - .insert(inner_key, QueuePairState::Available(qp)); + .insert(inner_key, qp); tracing::debug!( "successfully created a connection with {:?} for local device {} -> remote device {}", @@ -858,21 +892,16 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { tracing::debug!("connecting with {:?}", other); let other_id = other.actor_id().clone(); - // For backward compatibility, use default device let inner_key = (other_id.clone(), other_device.clone()); if let Some(device_map) = self.device_qps.get_mut(&self_device) { match device_map.get_mut(&inner_key) { - Some(QueuePairState::Available(qp)) => { + Some(qp) => { qp.connect(&endpoint).map_err(|e| { anyhow::anyhow!("could not connect to RDMA endpoint: {}", e) })?; Ok(()) } - Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!( - "Cannot connect: queue pair for actor {} is checked out", - other_id - )), None => Err(anyhow::anyhow!( "No connection found for actor {}", other_id @@ -907,14 +936,10 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { if let Some(device_map) = self.device_qps.get_mut(&self_device) { match device_map.get_mut(&inner_key) { - Some(QueuePairState::Available(qp)) => { + Some(qp) => { let connection_info = qp.get_qp_info()?; Ok(connection_info) } - Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!( - "Cannot get connection info: queue pair for actor {} is checked out", - other_id - )), None => Err(anyhow::anyhow!( "No connection found for actor {}", 
other_id )) } } else { Err(anyhow::anyhow!( "No device map found for self device {}", self_device )) } } @@ -930,47 +955,59 @@ /// Releases a queue pair back to the HashMap /// - /// This method returns a queue pair to the HashMap after the caller has finished - /// using it. This completes the request/release cycle, similar to RdmaBuffer. + /// This method is now a no-op since RdmaQueuePair is Clone and can be safely shared. + /// The queue pair is not actually checked out, so there's nothing to release. + /// This method is kept for API compatibility. /// /// # Arguments /// * `remote` - The ActorRef of the remote actor to return the queue pair for - /// * `qp` - The queue pair to release + /// * `qp` - The queue pair to release (ignored) async fn release_queue_pair( + &mut self, + _cx: &Context<Self>, + _other: ActorRef<RdmaManagerActor>, + _self_device: String, + _other_device: String, + _qp: RdmaQueuePair, + ) -> Result<(), anyhow::Error> { + // No-op: Queue pairs are now cloned and shared via atomic counters + // Nothing needs to be released + Ok(()) + } + + /// Gets the state of a queue pair + /// + /// # Arguments + /// * `other` - The ActorRef to get the QP state for + /// * `self_device` - Local device name + /// * `other_device` - Remote device name + /// + /// # Returns + /// * `u32` - The QP state (e.g., IBV_QPS_RTS = Ready To Send) + async fn get_qp_state( &mut self, _cx: &Context<Self>, other: ActorRef<RdmaManagerActor>, self_device: String, other_device: String, - qp: RdmaQueuePair, - ) -> Result<(), anyhow::Error> { - let inner_key = (other.actor_id().clone(), other_device.clone()); - - match self - .device_qps - .get_mut(&self_device) - .unwrap() - .get_mut(&inner_key) - { - Some(QueuePairState::CheckedOut) => { - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::Available(qp)); - Ok(()) + ) -> Result<u32, anyhow::Error> { + let other_id = other.actor_id().clone(); + let inner_key = (other_id.clone(), other_device.clone()); + + if let Some(device_map) = self.device_qps.get_mut(&self_device) { + match device_map.get_mut(&inner_key) { + Some(qp) => qp.state(), + None => Err(anyhow::anyhow!( + "No connection found for actor {} on device {}", + other_id, + other_device + )), } - Some(QueuePairState::Available(_)) => Err(anyhow::anyhow!( - "Cannot release queue pair: queue pair for actor {} is already available between devices {} and {}", - other.actor_id(), - self_device, - other_device, - )), - None => Err(anyhow::anyhow!( - "No queue pair found for actor {}, between devices {} and {}", - other.actor_id(), - self_device, - other_device, - )), + } else { + Err(anyhow::anyhow!( + "No device map found for self device {}", + self_device )) } } } diff --git a/monarch_rdma/src/rdma_manager_actor_tests.rs b/monarch_rdma/src/rdma_manager_actor_tests.rs index 4e6ba4648..01938d83c 100644 --- a/monarch_rdma/src/rdma_manager_actor_tests.rs +++ b/monarch_rdma/src/rdma_manager_actor_tests.rs @@ -41,10 +41,10 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; // Poll for completion - wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 2).await?; env.actor_1 .release_queue_pair( @@ -56,7 +56,7 @@ mod tests { ) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -79,9 +79,9 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(),
env.rdma_handle_2.clone())?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 2).await?; env.actor_1 .release_queue_pair( @@ -93,7 +93,7 @@ mod tests { ) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -118,12 +118,12 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.get(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.get(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; // Poll for completion - wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 2).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -148,10 +148,10 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 2).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -185,10 +185,10 @@ mod tests { env.rdma_handle_1.device_name.clone(), ) .await?; - qp_2.put_with_recv(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; qp_1.recv(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_2, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + let wr_id = qp_2.put_with_recv(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; + wait_for_completion(&mut qp_2, PollTarget::Send, &wr_id, 5).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -214,12 +214,12 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.enqueue_put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.enqueue_put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; qp_1.ring_doorbell()?; // Poll for completion - wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -244,12 +244,12 @@ mod tests { env.rdma_handle_1.device_name.clone(), ) .await?; - qp_2.enqueue_get(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; + let wr_id = qp_2.enqueue_get(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; qp_2.ring_doorbell()?; // Poll for completion - wait_for_completion(&mut qp_2, PollTarget::Send, 5).await?; + wait_for_completion(&mut qp_2, PollTarget::Send, &wr_id, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -270,7 +270,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -289,10 +289,10 @@ mod tests { let env = RdmaManagerTestEnv::setup(BSIZE, "cpu:0", "cpu:1").await?; let /*mut*/ rdma_handle_1 = env.rdma_handle_1.clone(); rdma_handle_1 - .write_from(env.client_1, env.rdma_handle_2.clone(), 2) + .write_from(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -342,7 +342,7 @@ mod tests { // Poll for completion 
wait_for_completion_gpu(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -377,11 +377,11 @@ mod tests { ) .await?; qp_1.enqueue_get(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - ring_db_gpu(&mut qp_1).await?; + ring_db_gpu(&qp_1).await?; // Poll for completion wait_for_completion_gpu(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; Ok(()) } @@ -438,9 +438,9 @@ mod tests { ) .await?; ring_db_gpu(&mut qp_2).await?; - wait_for_completion_gpu(&mut qp_1, PollTarget::Recv, 10).await?; + wait_for_completion_gpu(&mut qp_1, PollTarget::Send, 10).await?; wait_for_completion_gpu(&mut qp_2, PollTarget::Send, 10).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -471,11 +471,11 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -506,11 +506,11 @@ mod tests { env.rdma_handle_2.device_name.clone(), ) .await?; - qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; + let wr_id = qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; - wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; + wait_for_completion(&mut qp_1, PollTarget::Send, &wr_id, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -538,7 +538,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -564,7 +564,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -590,7 +590,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -616,7 +616,53 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; + env.cleanup().await?; + Ok(()) + } + + #[timed_test::async_timed_test(timeout_secs = 30)] + async fn test_concurrent_send_to_same_target() -> Result<(), anyhow::Error> { + const BSIZE: usize = 2 * 1024 * 1024; + let devices = get_all_devices(); + if devices.is_empty() { + println!("Skipping test: RDMA devices not available"); + return Ok(()); + } + + let env = RdmaManagerTestEnv::setup(BSIZE, "cuda:0", "cuda:1").await?; + + let rdma_handle_1 = env.rdma_handle_1.clone(); + let rdma_handle_2 = env.rdma_handle_2.clone(); + let rdma_handle_3 = env.rdma_handle_1.clone(); + let rdma_handle_4 = env.rdma_handle_2.clone(); + let rdma_handle_5 = env.rdma_handle_1.clone(); + let rdma_handle_6 = env.rdma_handle_2.clone(); + let client = env.client_1; + + let task1 = async { + let result = rdma_handle_1 + .write_from(client, rdma_handle_2.clone(), 2) + .await; + result + }; + + let task2 = async { + let result = rdma_handle_3 + .write_from(client, 
rdma_handle_4.clone(), 2) .await; result }; + + let task3 = async { + let result = rdma_handle_5 + .write_from(client, rdma_handle_6.clone(), 2) + .await; + result + }; + let (result1, result2, result3) = tokio::join!(task1, task2, task3); + result1?; + result2?; + result3?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -644,7 +690,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -672,7 +718,7 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -704,7 +750,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; env.cleanup().await?; Ok(()) } @@ -736,8 +782,88 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers(BSIZE, 0).await?; + env.cleanup().await?; + Ok(()) + } + + #[timed_test::async_timed_test(timeout_secs = 120)] + async fn test_rdma_read_chunk() -> Result<(), anyhow::Error> { + if is_cpu_only_mode() { + println!("Skipping CUDA test in CPU-only mode"); + return Ok(()); + } + // 2GB tensor size + const BSIZE: usize = 2 * 1024 * 1024 * 1024; + const CHUNK_SIZE: usize = 2 * 1024 * 1024; // 2MB + let devices = get_all_devices(); + if devices.len() < 5 { + println!( + "Skipping test: requires H100 nodes with backend network (found {} devices)", + devices.len() + ); + return Ok(()); + } + + println!("Setting up 2GB CUDA tensors for RDMA read test..."); + let env = RdmaManagerTestEnv::setup(BSIZE, "cuda:0", "cuda:1").await?; + + println!("Performing RDMA read operation on 2GB tensor..."); + let rdma_handle_1 = env.rdma_handle_1.clone(); + rdma_handle_1 + .read_into(env.client_1, env.rdma_handle_2.clone(), 30) + .await?; + + println!("Verifying first 2MB..."); + env.verify_buffers(CHUNK_SIZE, 0).await?; + + println!("Verifying last 2MB..."); + env.verify_buffers(CHUNK_SIZE, BSIZE - CHUNK_SIZE).await?; + + println!("Cleaning up..."); + env.cleanup().await?; + + println!("2GB RDMA read test completed successfully"); + Ok(()) + } + + #[timed_test::async_timed_test(timeout_secs = 60)] + async fn test_rdma_write_chunk() -> Result<(), anyhow::Error> { + if is_cpu_only_mode() { + println!("Skipping CUDA test in CPU-only mode"); + return Ok(()); + } + // 2GB tensor size + const BSIZE: usize = 2 * 1024 * 1024 * 1024; + const CHUNK_SIZE: usize = 2 * 1024 * 1024; // 2MB + let devices = get_all_devices(); + if devices.len() < 5 { + println!( + "Skipping test: requires H100 nodes with backend network (found {} devices)", + devices.len() + ); + return Ok(()); + } + + println!("Setting up 2GB CUDA tensors for RDMA write test..."); + let env = RdmaManagerTestEnv::setup(BSIZE, "cuda:0", "cuda:1").await?; + + println!("Performing RDMA write operation on 2GB tensor..."); + let rdma_handle_1 = env.rdma_handle_1.clone(); + rdma_handle_1 + .write_from(env.client_1, env.rdma_handle_2.clone(), 30) + .await?; + + println!("Verifying first 2MB..."); + env.verify_buffers(CHUNK_SIZE, 0).await?; + + println!("Verifying last 2MB..."); + env.verify_buffers(CHUNK_SIZE, BSIZE - CHUNK_SIZE).await?; + + println!("Cleaning up..."); env.cleanup().await?; + + println!("2GB RDMA write test completed successfully"); Ok(()) } } diff --git a/monarch_rdma/src/test_utils.rs
b/monarch_rdma/src/test_utils.rs index 96c3cce30..58d2a88ea 100644 --- a/monarch_rdma/src/test_utils.rs +++ b/monarch_rdma/src/test_utils.rs @@ -307,26 +307,47 @@ pub mod test_utils { } } - // Waits for the completion of an RDMA operation. - - // This function polls for the completion of an RDMA operation by repeatedly - // sending a `PollCompletion` message to the specified actor mesh and checking - // the returned work completion status. It continues polling until the operation - // completes or the specified timeout is reached. - + /// Waits for the completion of RDMA operations. + /// + /// This function polls for the completion of RDMA operations by repeatedly + /// checking the completion queue until all expected work requests complete + /// or the specified timeout is reached. + /// + /// # Arguments + /// * `qp` - The RDMA Queue Pair to poll for completion + /// * `poll_target` - Which CQ to poll (Send or Recv) + /// * `expected_wr_ids` - Slice of work request IDs to wait for + /// * `timeout_secs` - Timeout in seconds + /// + /// # Returns + /// `Ok(true)` if all operations complete successfully within the timeout, + /// or an error if the timeout is reached pub async fn wait_for_completion( qp: &mut RdmaQueuePair, poll_target: PollTarget, + expected_wr_ids: &[u64], timeout_secs: u64, ) -> Result<bool, anyhow::Error> { let timeout = Duration::from_secs(timeout_secs); let start_time = Instant::now(); + + let mut remaining: std::collections::HashSet<u64> = + expected_wr_ids.iter().copied().collect(); + while start_time.elapsed() < timeout { - match qp.poll_completion_target(poll_target) { - Ok(Some(_wc)) => { - return Ok(true); - } - Ok(None) => { + if remaining.is_empty() { + return Ok(true); + } + + let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect(); + match qp.poll_completion(poll_target, &wr_ids_to_poll) { + Ok(completions) => { + for (wr_id, _wc) in completions { + remaining.remove(&wr_id); + } + if remaining.is_empty() { + return Ok(true); + } RealClock.sleep(Duration::from_millis(1)).await; } Err(e) => { @@ -334,7 +355,10 @@ pub mod test_utils { } } } - Err(anyhow::Error::msg("Timeout while waiting for completion")) + Err(anyhow::Error::msg(format!( + "Timeout while waiting for completion of wr_ids: {:?}", + remaining + ))) } /// Posts a work request to the send queue of the given RDMA queue pair. 
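The rewritten `wait_for_completion` above (like its counterpart in `rdma_components.rs`) reduces to one contract: collect the wr_ids returned when work is enqueued, then drain them from a set as completions are polled. A minimal, self-contained sketch of that loop, for illustration only — `poll` here is a stand-in for `RdmaQueuePair::poll_completion`, not the real binding:

```rust
use std::collections::HashSet;
use std::time::{Duration, Instant};

/// Wait until every expected wr_id completes or the timeout elapses.
/// `poll` is handed the wr_ids still outstanding and returns whichever
/// of them have completed since the last call.
fn wait_for_all(
    expected_wr_ids: &[u64],
    timeout: Duration,
    mut poll: impl FnMut(&[u64]) -> anyhow::Result<Vec<u64>>,
) -> anyhow::Result<bool> {
    let start = Instant::now();
    let mut remaining: HashSet<u64> = expected_wr_ids.iter().copied().collect();
    while start.elapsed() < timeout {
        if remaining.is_empty() {
            return Ok(true);
        }
        let to_poll: Vec<u64> = remaining.iter().copied().collect();
        for wr_id in poll(&to_poll)? {
            remaining.remove(&wr_id);
        }
        // The real loop sleeps ~1ms between polls instead of spinning.
    }
    anyhow::bail!("timed out waiting for wr_ids {:?}", remaining)
}

fn main() -> anyhow::Result<()> {
    // Toy poller: one outstanding wr_id completes per call.
    let done = wait_for_all(&[1, 2, 3], Duration::from_secs(1), |pending| {
        Ok(pending.iter().copied().take(1).collect())
    })?;
    assert!(done);
    Ok(())
}
```

The key property is that a completion for a wr_id you are not waiting on never satisfies the wait; that is what makes concurrent callers on a shared queue pair (as in `test_concurrent_send_to_same_target` above) safe.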
@@ -345,25 +369,26 @@ pub mod test_utils { op_type: u32, ) -> Result<(), anyhow::Error> { unsafe { - let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp; + let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; + let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(rdmaxcel_qp); let params = rdmaxcel_sys::wqe_params_t { laddr: lhandle.addr, length: lhandle.size, lkey: lhandle.lkey, - wr_id: qp.send_wqe_idx, + wr_id: send_wqe_idx, signaled: true, op_type, raddr: rhandle.addr, rkey: rhandle.rkey, - qp_num: (*ibv_qp).qp_num, + qp_num: (*(*rdmaxcel_qp).ibv_qp).qp_num, buf: (*dv_qp).sq.buf as *mut u8, wqe_cnt: (*dv_qp).sq.wqe_cnt, dbrec: (*dv_qp).dbrec, ..Default::default() }; rdmaxcel_sys::launch_send_wqe(params); - qp.send_wqe_idx += 1; + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(rdmaxcel_qp); } Ok(()) } @@ -377,43 +402,53 @@ pub mod test_utils { ) -> Result<(), anyhow::Error> { // Populate params using lhandle and rhandle unsafe { - let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp; + let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp; let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; + let recv_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp); let params = rdmaxcel_sys::wqe_params_t { laddr: lhandle.addr, length: lhandle.size, lkey: lhandle.lkey, - wr_id: qp.recv_wqe_idx, + wr_id: recv_wqe_idx, op_type, signaled: true, - qp_num: (*ibv_qp).qp_num, + qp_num: (*(*rdmaxcel_qp).ibv_qp).qp_num, buf: (*dv_qp).rq.buf as *mut u8, wqe_cnt: (*dv_qp).rq.wqe_cnt, dbrec: (*dv_qp).dbrec, ..Default::default() }; rdmaxcel_sys::launch_recv_wqe(params); - qp.recv_wqe_idx += 1; - qp.recv_db_idx += 1; + rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp); + rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp); } Ok(()) } - pub async fn ring_db_gpu(qp: &mut RdmaQueuePair) -> Result<(), anyhow::Error> { + pub async fn ring_db_gpu(qp: &RdmaQueuePair) -> Result<(), anyhow::Error> { RealClock.sleep(Duration::from_millis(2)).await; unsafe { let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp; let base_ptr = (*dv_qp).sq.buf as *mut u8; let wqe_cnt = (*dv_qp).sq.wqe_cnt; let stride = (*dv_qp).sq.stride; - if (wqe_cnt as u64) < (qp.send_wqe_idx - qp.send_db_idx) { + let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx( + qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + ); + let mut send_db_idx = + rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp); + if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) { return Err(anyhow::anyhow!("Overflow of WQE, possible data loss")); } - while qp.send_db_idx < qp.send_wqe_idx { - let offset = (qp.send_db_idx % wqe_cnt as u64) * stride as u64; + while send_db_idx < send_wqe_idx { + let offset = (send_db_idx % wqe_cnt as u64) * stride as u64; let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize); rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void); - qp.send_db_idx += 1; + send_db_idx += 1; + rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx( + qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp, + send_db_idx, + ); } } Ok(()) @@ -425,48 +460,54 @@ pub mod test_utils { poll_target: PollTarget, timeout_secs: u64, ) -> Result<bool, anyhow::Error> { - let timeout = Duration::from_secs(timeout_secs); - let start_time = Instant::now(); - - while start_time.elapsed() < timeout { - // Get the appropriate completion queue and index based on the poll target - let (cq, idx, cq_type_str) = match poll_target { - PollTarget::Send => ( - qp.dv_send_cq as *mut 
rdmaxcel_sys::mlx5dv_cq, - qp.send_cq_idx, - "send", ), - PollTarget::Recv => ( - qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq, - qp.recv_cq_idx, - "receive", ), - }; - - // Poll the completion queue - let result = - unsafe { rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32) }; - - match result { - rdmaxcel_sys::CQE_POLL_TRUE => { - // Update the appropriate index based on the poll target - match poll_target { - PollTarget::Send => qp.send_cq_idx += 1, - PollTarget::Recv => qp.recv_cq_idx += 1, + unsafe { + let start_time = Instant::now(); + let timeout = Duration::from_secs(timeout_secs); + let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp; + + while start_time.elapsed() < timeout { + // Get the appropriate completion queue and index based on the poll target + let (cq, idx, cq_type_str) = match poll_target { + PollTarget::Send => ( + qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq, + rdmaxcel_sys::rdmaxcel_qp_load_send_cq_idx(rdmaxcel_qp), + "send", + ), + PollTarget::Recv => ( + qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq, + rdmaxcel_sys::rdmaxcel_qp_load_recv_cq_idx(rdmaxcel_qp), + "receive", + ), + }; + + // Poll the completion queue + let result = rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32); + + match result { + rdmaxcel_sys::CQE_POLL_TRUE => { + // Update the appropriate index based on the poll target + match poll_target { + PollTarget::Send => { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_cq_idx(rdmaxcel_qp); + } + PollTarget::Recv => { + rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_cq_idx(rdmaxcel_qp); + } + } + return Ok(true); + } + rdmaxcel_sys::CQE_POLL_ERROR => { + return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str)); + } + _ => { + // No completion yet, sleep and try again + RealClock.sleep(Duration::from_millis(1)).await; + } - return Ok(true); - } - rdmaxcel_sys::CQE_POLL_ERROR => { - return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str)); - } - _ => { - // No completion yet, sleep and try again - RealClock.sleep(Duration::from_millis(1)).await; } } - } - Err(anyhow::Error::msg("Timeout while waiting for completion")) + Err(anyhow::Error::msg("Timeout while waiting for completion")) + } } pub struct RdmaManagerTestEnv<'a> { @@ -498,6 +539,7 @@ pub mod test_utils { if backend == "cuda" { config.use_gpu_direct = validate_execution_context().await.is_ok(); + eprintln!("Using GPU Direct: {}", config.use_gpu_direct); } (backend.to_string(), parsed_idx) @@ -716,7 +758,11 @@ pub mod test_utils { .await } - pub async fn verify_buffers(&self, size: usize) -> Result<(), anyhow::Error> { + pub async fn verify_buffers( + &self, + size: usize, + offset: usize, + ) -> Result<(), anyhow::Error> { let mut temp_buffer_1 = vec![0u8; size]; let mut temp_buffer_2 = vec![0u8; size]; @@ -726,14 +772,14 @@ pub mod test_utils { .verify_buffer( self.client_1, temp_buffer_1.as_mut_ptr() as usize, - self.device_ptr_1.unwrap(), + self.device_ptr_1.unwrap() + offset, size, ) .await?; } else { unsafe { std::ptr::copy_nonoverlapping( - self.buffer_1.ptr as *const u8, + (self.buffer_1.ptr + offset as u64) as *const u8, temp_buffer_1.as_mut_ptr(), size, ); @@ -746,14 +792,14 @@ pub mod test_utils { .verify_buffer( self.client_2, temp_buffer_2.as_mut_ptr() as usize, - self.device_ptr_2.unwrap(), + self.device_ptr_2.unwrap() + offset, size, ) .await?; } else { unsafe { std::ptr::copy_nonoverlapping( - self.buffer_2.ptr as *const u8, + (self.buffer_2.ptr + offset as u64) as *const u8, temp_buffer_2.as_mut_ptr(), size, ); @@ -763,7 +809,10 @@ pub 
mod test_utils { // Compare buffers for i in 0..size { if temp_buffer_1[i] != temp_buffer_2[i] { - return Err(anyhow::anyhow!("Buffers are not equal at index {}", i)); + return Err(anyhow::anyhow!( + "Buffers are not equal at index {}", + offset + i + )); } } Ok(()) diff --git a/python/tests/rdma_load_test.py b/python/tests/rdma_load_test.py index 3f6df8364..ad20efce2 100644 --- a/python/tests/rdma_load_test.py +++ b/python/tests/rdma_load_test.py @@ -56,6 +56,12 @@ default=10, help="Number of warmup iterations (default: 5)", ) + parser.add_argument( + "--concurrency", + type=int, + default=1, + help="Number of concurrent RDMA operations (default: 1)", + ) args = parser.parse_args() @@ -85,13 +91,38 @@ def __init__( # Timing data storage self.timing_data = [] self.size_data = [] + self.batch_timing_data = [] + self.batch_size_data = [] @endpoint async def set_other_actor(self, other_actor): self.other_actor = other_actor @endpoint - async def send(self, is_warmup=False) -> None: + async def send(self, is_warmup=False, concurrency: int = 1) -> None: + # Track wall-clock time for the entire concurrent batch + batch_start = time.time() + + tasks = [] + for _ in range(concurrency): + tasks.append(self._send_single(is_warmup)) + await asyncio.gather(*tasks) + + batch_end = time.time() + batch_elapsed = batch_end - batch_start + + if not is_warmup: + batch_size = ( + sum(self.size_data[-concurrency:]) + if len(self.size_data) >= concurrency + else 0 + ) + self.batch_timing_data.append(batch_elapsed) + self.batch_size_data.append(batch_size) + + self.i += 1 + + async def _send_single(self, is_warmup=False) -> None: shape = int( 1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3)) ) # Random size with +/- 50% variation based on user size @@ -104,7 +135,7 @@ async def send(self, is_warmup=False) -> None: # Critical validation - this should catch the null pointer issue assert ( tensor_addr != 0 - ), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}" + ), f"CRITICAL: Tensor has null pointer! Device: {self.device}, Shape: {shape}" assert size_elem > 0, f"CRITICAL: Tensor has zero size! 
Size: {size_elem}" byte_view = tensor.view(torch.uint8).flatten() @@ -136,8 +167,6 @@ async def send(self, is_warmup=False) -> None: # cleanup await buffer.drop() - self.i += 1 - @endpoint async def recv(self, rdma_buffer, shape, dtype, is_warmup): # Create receiving tensor on the same device @@ -167,7 +196,7 @@ async def recv(self, rdma_buffer, shape, dtype, is_warmup): @endpoint async def print_statistics(self, calc_bwd: bool = False): - """Calculate and print timing statistics""" + """Calculate and print timing statistics for individual operations""" if not self.timing_data: print("No timing data collected!") return @@ -175,7 +204,7 @@ async def print_statistics(self, calc_bwd: bool = False): timings = self.timing_data sizes = self.size_data - # Calculate statistics + # Calculate statistics for individual operations avg_time = statistics.mean(timings) min_time = min(timings) max_time = max(timings) @@ -184,7 +213,12 @@ async def print_statistics(self, calc_bwd: bool = False): avg_size = statistics.mean(sizes) total_data = sum(sizes) - print("TIMING RESULTS:") + device_type = self.device.upper() if self.device != "cpu" else "CPU" + print("\n" + "=" * 60) + print(f"RDMA {self.operation.upper()} LOAD TEST RESULTS ({device_type})") + print("=" * 60) + + print("INDIVIDUAL OPERATION TIMING:") print(f" Average time per operation: {avg_time * 1000:.3f} ms") print(f" Minimum time per operation: {min_time * 1000:.3f} ms") print(f" Maximum time per operation: {max_time * 1000:.3f} ms") @@ -202,24 +236,70 @@ def calc_bandwidth_gbps(size_bytes: int, time_seconds: float) -> float: max_bandwidth = calc_bandwidth_gbps(avg_size, min_time) min_bandwidth = calc_bandwidth_gbps(avg_size, max_time) - device_type = self.device.upper() if self.device != "cpu" else "CPU" - # Print results - print("\n" + "=" * 60) - print(f"RDMA {self.operation.upper()} LOAD TEST RESULTS ({device_type})") - print("=" * 60) print(f"Total iterations completed: {len(timings)}") print(f"Average data per operation: {avg_size / (1024*1024):.1f} MB") print(f"Total data transferred: {total_data / (1024*1024):.1f} MB") print() - print() - print("BANDWIDTH RESULTS:") + print("INDIVIDUAL OPERATION BANDWIDTH:") print(f" Average bandwidth: {avg_bandwidth:.2f} Gbps") print(f" Maximum bandwidth: {max_bandwidth:.2f} Gbps") print(f" Minimum bandwidth: {min_bandwidth:.2f} Gbps") print("=" * 60) + @endpoint + async def print_batch_statistics( + self, concurrency: int = 1, total_elapsed_time: float = 0.0 + ): + """Calculate and print batch-level statistics for concurrent operations""" + if not self.batch_timing_data: + print("No batch timing data collected!") + return + + batch_timings = self.batch_timing_data + batch_sizes = self.batch_size_data + total_data = sum(self.size_data) + + avg_batch_time = statistics.mean(batch_timings) + min_batch_time = min(batch_timings) + max_batch_time = max(batch_timings) + std_batch_time = ( + statistics.stdev(batch_timings) if len(batch_timings) > 1 else 0.0 + ) + avg_batch_size = statistics.mean(batch_sizes) + + print("\nCONCURRENT BATCH TIMING (wall-clock for all concurrent ops):") + print(f" Average batch time: {avg_batch_time * 1000:.3f} ms") + print(f" Minimum batch time: {min_batch_time * 1000:.3f} ms") + print(f" Maximum batch time: {max_batch_time * 1000:.3f} ms") + print(f" Standard deviation: {std_batch_time * 1000:.3f} ms") + print(f" Average data per batch: {avg_batch_size / (1024*1024):.1f} MB") + + # Calculate bandwidth (Gbps) + def calc_bandwidth_gbps(size_bytes: int, time_seconds: float) -> 
float: + if time_seconds == 0: + return 0.0 + bits_transferred = size_bytes * 8 + return bits_transferred / (time_seconds * 1e9) + + avg_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, avg_batch_time) + max_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, min_batch_time) + min_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, max_batch_time) + + print(f"\nAGGREGATE BANDWIDTH (concurrency={concurrency}):") + print(f" Average aggregate bandwidth: {avg_aggregate_bw:.2f} Gbps") + print(f" Maximum aggregate bandwidth: {max_aggregate_bw:.2f} Gbps") + print(f" Minimum aggregate bandwidth: {min_aggregate_bw:.2f} Gbps") + + total_throughput = calc_bandwidth_gbps(total_data, total_elapsed_time) + print("\nTOTAL SUSTAINED THROUGHPUT:") + print(f" Total wall-clock time: {total_elapsed_time:.3f} s") + print(f" Total data transferred: {total_data / (1024*1024):.1f} MB") + print(f" Sustained throughput: {total_throughput:.2f} Gbps") + if concurrency > 1: + print(f" (Accounts for {concurrency}x concurrent overlapping operations)") + async def main( devices: list[str], @@ -227,6 +307,7 @@ async def main( operation: str = "write", size_mb: int = 64, warmup_iterations: int = 10, + concurrency: int = 1, ): # Adjust GPU allocation based on the device types device_0, device_1 = devices[0], devices[1] @@ -245,16 +326,20 @@ async def main( await actor_0.set_other_actor.call(actor_1) for i in range(warmup_iterations): - await actor_0.send.call(is_warmup=True) + await actor_0.send.call(is_warmup=True, concurrency=concurrency) + total_start_time = time.time() for i in range(iterations): - await actor_0.send.call() + await actor_0.send.call(concurrency=concurrency) + total_end_time = time.time() + total_elapsed_time = total_end_time - total_start_time - # Have both actors print their statistics - print("\n=== ACTOR 0 (Create Buffer) STATISTICS ===") - await actor_0.print_statistics.call() + # Actor 0: Print batch statistics (concurrency orchestration) + await actor_0.print_batch_statistics.call( + concurrency=concurrency, total_elapsed_time=total_elapsed_time + ) - print("\n=== ACTOR 1 (Create Buffer+Transmit) STATISTICS ===") + # Actor 1: Print individual RDMA transfer statistics await actor_1.print_statistics.call(calc_bwd=True) await mesh_0.stop() @@ -313,5 +398,6 @@ async def main( args.operation, args.size, args.warmup_iterations, + args.concurrency, ) ) diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs index 036122a66..0a78d67de 100644 --- a/rdmaxcel-sys/build.rs +++ b/rdmaxcel-sys/build.rs @@ -97,6 +97,9 @@ fn main() { .allowlist_function("get_cuda_pci_address_from_ptr") .allowlist_function("rdmaxcel_print_device_info") .allowlist_function("rdmaxcel_error_string") + .allowlist_function("rdmaxcel_qp_.*") + .allowlist_function("poll_cq_with_cache") + .allowlist_function("completion_cache_.*") .allowlist_type("ibv_.*") .allowlist_type("mlx5dv_.*") .allowlist_type("mlx5_wqe_.*") @@ -104,6 +107,12 @@ fn main() { .allowlist_type("wqe_params_t") .allowlist_type("cqe_poll_params_t") .allowlist_type("rdma_segment_info_t") + .allowlist_type("rdmaxcel_qp_t") + .allowlist_type("rdmaxcel_qp") + .allowlist_type("completion_cache_t") + .allowlist_type("completion_cache") + .allowlist_type("poll_context_t") + .allowlist_type("poll_context") .allowlist_var("MLX5_.*") .allowlist_var("IBV_.*") // Block specific types that are manually defined in lib.rs diff --git a/rdmaxcel-sys/src/rdmaxcel.c b/rdmaxcel-sys/src/rdmaxcel.c index c6fdf57c8..7a4021fcc 100644 --- a/rdmaxcel-sys/src/rdmaxcel.c +++ 
b/rdmaxcel-sys/src/rdmaxcel.c @@ -7,10 +7,377 @@ */ #include "rdmaxcel.h" - +#include +#include #include #include #include +#include +#include + +// ============================================================================ +// RDMAXCEL QP Wrapper Implementation +// ============================================================================ + +rdmaxcel_qp_t* rdmaxcel_qp_create( + struct ibv_context* context, + struct ibv_pd* pd, + int cq_entries, + int max_send_wr, + int max_recv_wr, + int max_send_sge, + int max_recv_sge, + rdma_qp_type_t qp_type) { + // Allocate wrapper structure + rdmaxcel_qp_t* qp = (rdmaxcel_qp_t*)calloc(1, sizeof(rdmaxcel_qp_t)); + if (!qp) { + fprintf(stderr, "ERROR: Failed to allocate rdmaxcel_qp_t\n"); + return NULL; + } + + // Create underlying ibverbs QP + qp->ibv_qp = create_qp( + context, + pd, + cq_entries, + max_send_wr, + max_recv_wr, + max_send_sge, + max_recv_sge, + qp_type); + if (!qp->ibv_qp) { + free(qp); + return NULL; + } + + // Store CQ pointers + qp->send_cq = qp->ibv_qp->send_cq; + qp->recv_cq = qp->ibv_qp->recv_cq; + + // Initialize atomic counters + atomic_init(&qp->send_wqe_idx, 0); + atomic_init(&qp->send_db_idx, 0); + atomic_init(&qp->send_cq_idx, 0); + atomic_init(&qp->recv_wqe_idx, 0); + atomic_init(&qp->recv_db_idx, 0); + atomic_init(&qp->recv_cq_idx, 0); + atomic_init(&qp->rts_timestamp, UINT64_MAX); + + // Initialize completion caches + qp->send_completion_cache = + (completion_cache_t*)calloc(1, sizeof(completion_cache_t)); + qp->recv_completion_cache = + (completion_cache_t*)calloc(1, sizeof(completion_cache_t)); + + if (!qp->send_completion_cache || !qp->recv_completion_cache) { + if (qp->send_completion_cache) + free(qp->send_completion_cache); + if (qp->recv_completion_cache) + free(qp->recv_completion_cache); + ibv_destroy_qp(qp->ibv_qp); + ibv_destroy_cq(qp->send_cq); + ibv_destroy_cq(qp->recv_cq); + free(qp); + fprintf(stderr, "ERROR: Failed to allocate completion caches\n"); + return NULL; + } + + completion_cache_init(qp->send_completion_cache); + completion_cache_init(qp->recv_completion_cache); + + return qp; +} + +void rdmaxcel_qp_destroy(rdmaxcel_qp_t* qp) { + if (!qp) { + return; + } + + // Clean up completion caches + if (qp->send_completion_cache) { + completion_cache_destroy(qp->send_completion_cache); + free(qp->send_completion_cache); + } + if (qp->recv_completion_cache) { + completion_cache_destroy(qp->recv_completion_cache); + free(qp->recv_completion_cache); + } + + // Destroy the underlying ibv_qp and its CQs + if (qp->ibv_qp) { + ibv_destroy_qp(qp->ibv_qp); + } + if (qp->send_cq) { + ibv_destroy_cq(qp->send_cq); + } + if (qp->recv_cq) { + ibv_destroy_cq(qp->recv_cq); + } + + free(qp); +} + +struct ibv_qp* rdmaxcel_qp_get_ibv_qp(rdmaxcel_qp_t* qp) { + return qp ? qp->ibv_qp : NULL; +} + +// Atomic fetch_add operations +uint64_t rdmaxcel_qp_fetch_add_send_wqe_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->send_wqe_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_send_db_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->send_db_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_send_cq_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->send_cq_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->recv_wqe_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_fetch_add(&qp->recv_db_idx, 1) : 0; +} + +uint64_t rdmaxcel_qp_fetch_add_recv_cq_idx(rdmaxcel_qp_t* qp) { + return qp ? 
atomic_fetch_add(&qp->recv_cq_idx, 1) : 0; +} + +// Atomic load operations +uint64_t rdmaxcel_qp_load_send_wqe_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->send_wqe_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_send_db_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->send_db_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_send_cq_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->send_cq_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->recv_wqe_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_recv_cq_idx(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->recv_cq_idx) : 0; +} + +uint64_t rdmaxcel_qp_load_rts_timestamp(rdmaxcel_qp_t* qp) { + return qp ? atomic_load(&qp->rts_timestamp) : UINT64_MAX; +} + +// Atomic store operations +void rdmaxcel_qp_store_send_db_idx(rdmaxcel_qp_t* qp, uint64_t value) { + if (qp) { + atomic_store(&qp->send_db_idx, value); + } +} + +void rdmaxcel_qp_store_rts_timestamp(rdmaxcel_qp_t* qp, uint64_t value) { + if (qp) { + atomic_store(&qp->rts_timestamp, value); + } +} + +// Get completion caches +completion_cache_t* rdmaxcel_qp_get_send_cache(rdmaxcel_qp_t* qp) { + return qp ? qp->send_completion_cache : NULL; +} + +completion_cache_t* rdmaxcel_qp_get_recv_cache(rdmaxcel_qp_t* qp) { + return qp ? qp->recv_completion_cache : NULL; +} + +// ============================================================================ +// Completion Cache Implementation +// ============================================================================ + +void completion_cache_init(completion_cache_t* cache) { + if (!cache) { + return; + } + cache->head = -1; + cache->tail = -1; + cache->count = 0; + + // Initialize free list + cache->free_head = 0; + for (int i = 0; i < MAX_CACHED_COMPLETIONS - 1; i++) { + cache->nodes[i].next = i + 1; + } + cache->nodes[MAX_CACHED_COMPLETIONS - 1].next = -1; + + pthread_mutex_init(&cache->lock, NULL); +} + +void completion_cache_destroy(completion_cache_t* cache) { + if (!cache) { + return; + } + + // Warn if cache still has entries when being destroyed + if (cache->count > 0) { + fprintf( + stderr, + "WARNING: Destroying completion cache with %zu unretrieved entries! " + "Possible missing poll operations or leaked work requests.\n", + cache->count); + int curr = cache->head; + fprintf(stderr, " Cached wr_ids:"); + while (curr != -1 && cache->count > 0) { + fprintf(stderr, " %lu", cache->nodes[curr].wc.wr_id); + curr = cache->nodes[curr].next; + } + fprintf(stderr, "\n"); + } + + pthread_mutex_destroy(&cache->lock); + cache->count = 0; +} + +int completion_cache_add(completion_cache_t* cache, struct ibv_wc* wc) { + if (!cache || !wc) { + return 0; + } + + pthread_mutex_lock(&cache->lock); + + if (cache->free_head == -1) { + pthread_mutex_unlock(&cache->lock); + fprintf( + stderr, + "WARNING: Completion cache full (%zu entries)! 
Dropping completion " + "for wr_id=%lu, qp=%u\n", + cache->count, + wc->wr_id, + wc->qp_num); + return 0; + } + + // Pop from free list + int idx = cache->free_head; + cache->free_head = cache->nodes[idx].next; + + // Store completion + cache->nodes[idx].wc = *wc; + cache->nodes[idx].next = -1; + + // Append to tail of used list + if (cache->head == -1) { + cache->head = idx; + cache->tail = idx; + } else { + cache->nodes[cache->tail].next = idx; + cache->tail = idx; + } + + cache->count++; + pthread_mutex_unlock(&cache->lock); + return 1; +} + +int completion_cache_find( + completion_cache_t* cache, + uint64_t wr_id, + uint32_t qp_num, + struct ibv_wc* out_wc) { + if (!cache || !out_wc) { + return 0; + } + + pthread_mutex_lock(&cache->lock); + + int prev = -1; + int curr = cache->head; + + while (curr != -1) { + if (cache->nodes[curr].wc.wr_id == wr_id && + cache->nodes[curr].wc.qp_num == qp_num) { + // Found it! Copy out + *out_wc = cache->nodes[curr].wc; + + // Remove from used list + if (prev == -1) { + // Removing head (O(1) for typical case!) + cache->head = cache->nodes[curr].next; + if (cache->head == -1) { + cache->tail = -1; + } + } else { + // Removing from middle/tail + cache->nodes[prev].next = cache->nodes[curr].next; + if (cache->nodes[curr].next == -1) { + cache->tail = prev; + } + } + + // Add to free list + cache->nodes[curr].next = cache->free_head; + cache->free_head = curr; + + cache->count--; + pthread_mutex_unlock(&cache->lock); + return 1; + } + + prev = curr; + curr = cache->nodes[curr].next; + } + + pthread_mutex_unlock(&cache->lock); + return 0; +} + +int poll_cq_with_cache(poll_context_t* ctx, struct ibv_wc* out_wc) { + if (!ctx || !out_wc) { + return RDMAXCEL_INVALID_PARAMS; + } + + if (completion_cache_find( + ctx->cache, ctx->expected_wr_id, ctx->expected_qp_num, out_wc)) { + if (out_wc->status != IBV_WC_SUCCESS) { + return RDMAXCEL_COMPLETION_FAILED; + } + return 1; + } + + struct ibv_wc wc; + int ret = ibv_poll_cq(ctx->cq, 1, &wc); + + if (ret < 0) { + return RDMAXCEL_CQ_POLL_FAILED; + } + + if (ret == 0) { + return 0; + } + + if (wc.status != IBV_WC_SUCCESS) { + if (wc.wr_id == ctx->expected_wr_id && wc.qp_num == ctx->expected_qp_num) { + *out_wc = wc; + return RDMAXCEL_COMPLETION_FAILED; + } + completion_cache_add(ctx->cache, &wc); + return 0; + } + + if (wc.wr_id == ctx->expected_wr_id && wc.qp_num == ctx->expected_qp_num) { + *out_wc = wc; + return 1; + } + + completion_cache_add(ctx->cache, &wc); + return 0; +} + +// ============================================================================ +// End of Completion Cache Implementation +// ============================================================================ cudaError_t register_mmio_to_cuda(void* bf, size_t size) { cudaError_t result = cudaHostRegister( diff --git a/rdmaxcel-sys/src/rdmaxcel.cpp b/rdmaxcel-sys/src/rdmaxcel.cpp index 9e359895b..6f55824b3 100644 --- a/rdmaxcel-sys/src/rdmaxcel.cpp +++ b/rdmaxcel-sys/src/rdmaxcel.cpp @@ -260,8 +260,8 @@ int compact_mrs(struct ibv_pd* pd, SegmentInfo& seg, int access_flags) { } // Register memory region for a specific segment address, assume cuda -int register_segments(struct ibv_pd* pd, struct ibv_qp* qp) { - if (!pd) { +int register_segments(struct ibv_pd* pd, rdmaxcel_qp_t* qp) { + if (!pd || !qp) { return RDMAXCEL_INVALID_PARAMS; // Invalid parameter } scan_existing_segments(); @@ -334,7 +334,7 @@ int register_segments(struct ibv_pd* pd, struct ibv_qp* qp) { seg.mr_size = seg.phys_size; // Create vector of GPU addresses for bind_mrs - auto err = 
bind_mrs(pd, qp, access_flags, seg); + auto err = bind_mrs(pd, qp->ibv_qp, access_flags, seg); if (err != 0) { return err; // Bind MR's failed } @@ -535,6 +535,10 @@ const char* rdmaxcel_error_string(int error_code) { return "[RdmaXcel] Output buffer too small"; case RDMAXCEL_QUERY_DEVICE_FAILED: return "[RdmaXcel] Failed to query device attributes"; + case RDMAXCEL_CQ_POLL_FAILED: + return "[RdmaXcel] CQ polling failed"; + case RDMAXCEL_COMPLETION_FAILED: + return "[RdmaXcel] Completion status not successful"; default: return "[RdmaXcel] Unknown error code"; } diff --git a/rdmaxcel-sys/src/rdmaxcel.h index 01b01d575..4a3cf316d 100644 --- a/rdmaxcel-sys/src/rdmaxcel.h +++ b/rdmaxcel-sys/src/rdmaxcel.h @@ -15,8 +15,13 @@ #include #include "driver_api.h" +// Handle atomics for both C and C++ #ifdef __cplusplus +#include <atomic> +#define _Atomic(T) std::atomic<T> extern "C" { +#else +#include <stdatomic.h> #endif typedef enum { @@ -136,7 +141,9 @@ -12, // Failed to get CUDA device attribute RDMAXCEL_CUDA_GET_DEVICE_FAILED = -13, // Failed to get CUDA device handle RDMAXCEL_BUFFER_TOO_SMALL = -14, // Output buffer too small - RDMAXCEL_QUERY_DEVICE_FAILED = -15 // Failed to query device attributes + RDMAXCEL_QUERY_DEVICE_FAILED = -15, // Failed to query device attributes + RDMAXCEL_CQ_POLL_FAILED = -16, // CQ polling failed + RDMAXCEL_COMPLETION_FAILED = -17 // Completion status not successful } rdmaxcel_error_code_t; // Error/Debugging functions @@ -147,7 +154,6 @@ const char* rdmaxcel_error_string(int error_code); int rdma_get_active_segment_count(); int rdma_get_all_segment_info(rdma_segment_info_t* info_array, int max_count); bool pt_cuda_allocator_compatibility(); -int register_segments(struct ibv_pd* pd, struct ibv_qp* qp); int deregister_segments(); // CUDA utility functions @@ -156,6 +162,127 @@ int get_cuda_pci_address_from_ptr( char* pci_addr_out, size_t pci_addr_size); +cudaError_t register_host_mem(void** buf, size_t size); + +// Forward declarations +typedef struct completion_cache completion_cache_t; + +// RDMA Queue Pair wrapper with atomic counters and completion caches +typedef struct rdmaxcel_qp { + struct ibv_qp* ibv_qp; // Underlying ibverbs QP + struct ibv_cq* send_cq; // Send completion queue + struct ibv_cq* recv_cq; // Receive completion queue + + // Atomic counters + _Atomic(uint64_t) send_wqe_idx; + _Atomic(uint64_t) send_db_idx; + _Atomic(uint64_t) send_cq_idx; + _Atomic(uint64_t) recv_wqe_idx; + _Atomic(uint64_t) recv_db_idx; + _Atomic(uint64_t) recv_cq_idx; + _Atomic(uint64_t) rts_timestamp; + + // Completion caches + completion_cache_t* send_completion_cache; + completion_cache_t* recv_completion_cache; +} rdmaxcel_qp_t; + +// Create and initialize an rdmaxcel QP (wraps create_qp + initializes +// counters/caches) +rdmaxcel_qp_t* rdmaxcel_qp_create( + struct ibv_context* context, + struct ibv_pd* pd, + int cq_entries, + int max_send_wr, + int max_recv_wr, + int max_send_sge, + int max_recv_sge, + rdma_qp_type_t qp_type); + +// Destroy rdmaxcel QP and clean up resources +void rdmaxcel_qp_destroy(rdmaxcel_qp_t* qp); + +// Get underlying ibv_qp pointer (for compatibility with existing ibverbs calls) +struct ibv_qp* rdmaxcel_qp_get_ibv_qp(rdmaxcel_qp_t* qp); + +// Atomic fetch_add operations +uint64_t rdmaxcel_qp_fetch_add_send_wqe_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_send_db_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_send_cq_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp_t* qp); 
+uint64_t rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_fetch_add_recv_cq_idx(rdmaxcel_qp_t* qp); + +// Atomic load operations (minimal API surface) +// Send side: needed for doorbell ring iteration [db_idx, wqe_idx) +uint64_t rdmaxcel_qp_load_send_wqe_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_load_send_db_idx(rdmaxcel_qp_t* qp); +// Receive side: needed for receive operations +uint64_t rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp_t* qp); +// Completion queue indices: needed for polling without modifying +uint64_t rdmaxcel_qp_load_send_cq_idx(rdmaxcel_qp_t* qp); +uint64_t rdmaxcel_qp_load_recv_cq_idx(rdmaxcel_qp_t* qp); +// Connection state validation +uint64_t rdmaxcel_qp_load_rts_timestamp(rdmaxcel_qp_t* qp); + +// Atomic store operations +void rdmaxcel_qp_store_send_db_idx(rdmaxcel_qp_t* qp, uint64_t value); +void rdmaxcel_qp_store_rts_timestamp(rdmaxcel_qp_t* qp, uint64_t value); + +// Get completion caches +completion_cache_t* rdmaxcel_qp_get_send_cache(rdmaxcel_qp_t* qp); +completion_cache_t* rdmaxcel_qp_get_recv_cache(rdmaxcel_qp_t* qp); + +// Segment registration (uses rdmaxcel_qp_t, so must come after type definition) +int register_segments(struct ibv_pd* pd, rdmaxcel_qp_t* qp); + +// Completion Cache Structures and Functions +#define MAX_CACHED_COMPLETIONS 128 + +// Linked list node for cached completions +typedef struct completion_node { + struct ibv_wc wc; + int next; // Index of next node, or -1 for end of list +} completion_node_t; + +// Cache for "unmatched" completions using embedded linked list +typedef struct completion_cache { + completion_node_t nodes[MAX_CACHED_COMPLETIONS]; + int head; // Index of first used node, or -1 if empty + int tail; // Index of last used node + int free_head; // Index of first free node, or -1 if full + size_t count; + pthread_mutex_t lock; +} completion_cache_t; + +// Context for polling with cache +typedef struct poll_context { + uint64_t expected_wr_id; // What wr_id am I looking for? + uint32_t expected_qp_num; // What QP am I expecting? + completion_cache_t* cache; // Shared completion cache + struct ibv_cq* cq; // The CQ to poll +} poll_context_t; + +// Initialize completion cache +void completion_cache_init(completion_cache_t* cache); + +// Destroy completion cache +void completion_cache_destroy(completion_cache_t* cache); + +// Add completion to cache; returns 1 if stored, 0 if dropped (cache full) +int completion_cache_add(completion_cache_t* cache, struct ibv_wc* wc); + +// Find and remove completion from cache +int completion_cache_find( + completion_cache_t* cache, + uint64_t wr_id, + uint32_t qp_num, + struct ibv_wc* out_wc); + +// Poll with cache support +// Returns: 1 = found, 0 = not found yet, negative rdmaxcel_error_code_t on error +int poll_cq_with_cache(poll_context_t* ctx, struct ibv_wc* out_wc); + #ifdef __cplusplus } #endif
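The polling contract this header declares pairs with the cache: each waiter polls the shared CQ, keeps its own completion if it surfaces, and stashes anyone else's so their poller can claim it later. A compact Rust rendering of that protocol, for illustration only — the authoritative version is the C `poll_cq_with_cache` above, and all names below are invented:

```rust
use std::collections::VecDeque;

#[derive(Clone, Copy, Debug, PartialEq)]
struct Completion {
    wr_id: u64,
    qp_num: u32,
    ok: bool, // stands in for wc.status == IBV_WC_SUCCESS
}

/// Stash for completions that belong to other waiters.
struct Cache {
    stashed: VecDeque<Completion>,
}

impl Cache {
    /// Find and remove the completion for (wr_id, qp_num), if cached.
    fn take(&mut self, wr_id: u64, qp_num: u32) -> Option<Completion> {
        let pos = self
            .stashed
            .iter()
            .position(|c| c.wr_id == wr_id && c.qp_num == qp_num)?;
        self.stashed.remove(pos)
    }
}

/// One polling step: Ok(1) = matched, Ok(0) = nothing yet, Err = failed completion.
fn poll_step(
    cache: &mut Cache,
    cq_pop: &mut dyn FnMut() -> Option<Completion>,
    wr_id: u64,
    qp_num: u32,
) -> Result<i32, String> {
    // 1. A previous poller may already have stashed our completion.
    if let Some(c) = cache.take(wr_id, qp_num) {
        return if c.ok { Ok(1) } else { Err("cached completion failed".into()) };
    }
    // 2. Otherwise poll the CQ once.
    match cq_pop() {
        None => Ok(0),
        Some(c) if c.wr_id == wr_id && c.qp_num == qp_num => {
            if c.ok { Ok(1) } else { Err("completion failed".into()) }
        }
        Some(other) => {
            // Someone else's completion: stash it for their poller.
            cache.stashed.push_back(other);
            Ok(0)
        }
    }
}

fn main() {
    let mut cache = Cache { stashed: VecDeque::new() };
    // The CQ yields someone else's completion first, then ours.
    let mut cqe = vec![
        Completion { wr_id: 7, qp_num: 1, ok: true },
        Completion { wr_id: 42, qp_num: 1, ok: true },
    ]
    .into_iter();
    let mut pop = || cqe.next();
    assert_eq!(poll_step(&mut cache, &mut pop, 42, 1), Ok(0)); // stashes wr_id 7
    assert_eq!(poll_step(&mut cache, &mut pop, 42, 1), Ok(1)); // our turn
    assert!(cache.take(7, 1).is_some()); // wr_id 7 still claimable
}
```

Unlike this unbounded `VecDeque`, the C cache is a fixed array of `MAX_CACHED_COMPLETIONS` nodes threaded through an embedded free list, so it never allocates on the polling path and drops new entries (with a warning) once full.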