diff --git a/.github/workflows/test-gpu-rust.yml b/.github/workflows/test-gpu-rust.yml index 878e9167f..b30c59149 100644 --- a/.github/workflows/test-gpu-rust.yml +++ b/.github/workflows/test-gpu-rust.yml @@ -66,8 +66,9 @@ jobs: timeout 12m cargo nextest run --workspace --profile ci \ --exclude monarch_messages \ --exclude monarch_tensor_worker \ - --exclude torch-sys \ - --exclude torch-sys-cuda + --exclude torch-sys-cuda \ + --exclude monarch_rdma \ + --exclude torch-sys # Copy the test results to the expected location # TODO: error in pytest-results-action, TypeError: results.testsuites.testsuite.testcase is not iterable # Don't try to parse these results for now. diff --git a/Cargo.toml b/Cargo.toml index 90b2dc135..675e482c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,6 @@ resolver = "2" members = [ "build_utils", - "cuda-sys", "erased_lifetime", "hyper", "hyperactor", diff --git a/cuda-sys/Cargo.toml b/cuda-sys/Cargo.toml deleted file mode 100644 index 23c0149b8..000000000 --- a/cuda-sys/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "cuda-sys" -version = "0.0.0" -authors = ["Facebook"] -edition = "2021" -license = "MIT" -links = "cuda" -description = "Rust FFI bindings for CUDA libraries" - -[dependencies] -cxx = "1.0.119" -serde = { version = "1.0.185", features = ["derive", "rc"] } - -[build-dependencies] -bindgen = "0.70.1" -build_utils = { path = "../build_utils" } diff --git a/cuda-sys/build.rs b/cuda-sys/build.rs deleted file mode 100644 index 22b8b4675..000000000 --- a/cuda-sys/build.rs +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use std::env; -use std::path::PathBuf; - -#[cfg(target_os = "macos")] -fn main() {} - -#[cfg(not(target_os = "macos"))] -fn main() { - // Discover CUDA configuration including include and lib directories - let cuda_config = match build_utils::discover_cuda_config() { - Ok(config) => config, - Err(_) => { - build_utils::print_cuda_error_help(); - std::process::exit(1); - } - }; - - // Start building the bindgen configuration - let mut builder = bindgen::Builder::default() - // The input header we would like to generate bindings for - .header("src/wrapper.h") - .clang_arg("-x") - .clang_arg("c++") - .clang_arg("-std=gnu++20") - .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) - // Allow the specified functions and types (CUDA Runtime API only) - .allowlist_function("cuda.*") - .allowlist_function("CUDA.*") - .allowlist_type("cuda.*") - .allowlist_type("CUDA.*") - // Use newtype enum style - .default_enum_style(bindgen::EnumVariation::NewType { - is_bitfield: false, - is_global: false, - }); - - // Add CUDA include paths from the discovered configuration - for include_dir in &cuda_config.include_dirs { - builder = builder.clang_arg(format!("-I{}", include_dir.display())); - } - - // Include headers and libs from the active environment. 
- let python_config = match build_utils::python_env_dirs_with_interpreter("python3") { - Ok(config) => config, - Err(_) => { - eprintln!("Warning: Failed to get Python environment directories"); - build_utils::PythonConfig { - include_dir: None, - lib_dir: None, - } - } - }; - - if let Some(include_dir) = &python_config.include_dir { - builder = builder.clang_arg(format!("-I{}", include_dir)); - } - if let Some(lib_dir) = &python_config.lib_dir { - println!("cargo::rustc-link-search=native={}", lib_dir); - // Set cargo metadata to inform dependent binaries about how to set their - // RPATH (see controller/build.rs for an example). - println!("cargo::metadata=LIB_PATH={}", lib_dir); - } - - // Get CUDA library directory and emit link directives - let cuda_lib_dir = match build_utils::get_cuda_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_cuda_lib_error_help(); - std::process::exit(1); - } - }; - println!("cargo:rustc-link-search=native={}", cuda_lib_dir); - println!("cargo:rustc-link-lib=cudart"); - - // Generate bindings - fail fast if this doesn't work - let bindings = builder.generate().expect("Unable to generate bindings"); - - // Write the bindings to the $OUT_DIR/bindings.rs file - match env::var("OUT_DIR") { - Ok(out_dir) => { - let out_path = PathBuf::from(out_dir); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings"); - - println!("cargo::rustc-cfg=cargo"); - println!("cargo::rustc-check-cfg=cfg(cargo)"); - } - Err(_) => { - println!("Note: OUT_DIR not set, skipping bindings file generation"); - } - } -} diff --git a/cuda-sys/src/lib.rs b/cuda-sys/src/lib.rs deleted file mode 100644 index a99325278..000000000 --- a/cuda-sys/src/lib.rs +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use cxx::ExternType; -use cxx::type_id; - -/// SAFETY: bindings -unsafe impl ExternType for CUstream_st { - type Id = type_id!("CUstream_st"); - type Kind = cxx::kind::Opaque; -} - -// When building with cargo, this is actually the lib.rs file for a crate. -// Include the generated bindings.rs and suppress lints. -#[allow(non_camel_case_types)] -#[allow(non_upper_case_globals)] -#[allow(non_snake_case)] -mod inner { - #[cfg(cargo)] - include!(concat!(env!("OUT_DIR"), "/bindings.rs")); -} - -pub use inner::*; diff --git a/cuda-sys/src/wrapper.h b/cuda-sys/src/wrapper.h deleted file mode 100644 index 02b4028d6..000000000 --- a/cuda-sys/src/wrapper.h +++ /dev/null @@ -1,11 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#pragma once - -#include <cuda_runtime.h> diff --git a/monarch_rdma/Cargo.toml b/monarch_rdma/Cargo.toml index a4c1bf6ca..655923956 100644 --- a/monarch_rdma/Cargo.toml +++ b/monarch_rdma/Cargo.toml @@ -13,7 +13,6 @@ edition = "2024" [dependencies] anyhow = "1.0.98" async-trait = "0.1.86" -cuda-sys = { path = "../cuda-sys" } futures = { version = "0.3.31", features = ["async-await", "compat"] } hyperactor = { version = "0.0.0", path = "../hyperactor" } rand = { version = "0.8", features = ["small_rng"] } @@ -28,9 +27,6 @@ ndslice = { version = "0.0.0", path = "../ndslice" } timed_test = { version = "0.0.0", path = "../timed_test" } tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] } -[build-dependencies] -build_utils = { path = "../build_utils" } - [features] cuda = [] default = ["cuda"] diff --git a/monarch_rdma/build.rs b/monarch_rdma/build.rs deleted file mode 100644 index d76dab156..000000000 --- a/monarch_rdma/build.rs +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#[cfg(target_os = "macos")] -fn main() {} - -#[cfg(not(target_os = "macos"))] -fn main() { - // Validate CUDA installation and get CUDA home path - let _cuda_home = match build_utils::validate_cuda_installation() { - Ok(home) => home, - Err(_) => { - build_utils::print_cuda_error_help(); - std::process::exit(1); - } - }; - - // Include headers and libs from the active environment. - let python_config = match build_utils::python_env_dirs_with_interpreter("python3") { - Ok(config) => config, - Err(_) => { - eprintln!("Warning: Failed to get Python environment directories"); - build_utils::PythonConfig { - include_dir: None, - lib_dir: None, - } - } - }; - - if let Some(lib_dir) = &python_config.lib_dir { - println!("cargo:rustc-link-search=native={}", lib_dir); - // Set cargo metadata to inform dependent binaries about how to set their - // RPATH (see controller/build.rs for an example). 
- println!("cargo:metadata=LIB_PATH={}", lib_dir); - } - - // Get CUDA library directory and emit link directives - let cuda_lib_dir = match build_utils::get_cuda_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_cuda_lib_error_help(); - std::process::exit(1); - } - }; - println!("cargo:rustc-link-search=native={}", cuda_lib_dir); - println!("cargo:rustc-link-lib=cuda"); - println!("cargo:rustc-link-lib=cudart"); - - // Link against the ibverbs and mlx5 libraries (used by rdmaxcel-sys) - println!("cargo:rustc-link-lib=ibverbs"); - println!("cargo:rustc-link-lib=mlx5"); - - // Link PyTorch libraries needed for C10 symbols used by rdmaxcel-sys - let use_pytorch_apis = build_utils::get_env_var_with_rerun("TORCH_SYS_USE_PYTORCH_APIS") - .unwrap_or_else(|_| "1".to_owned()); - if use_pytorch_apis == "1" { - // Get PyTorch library directory using build_utils - let python_interpreter = std::path::PathBuf::from("python"); - if let Ok(output) = std::process::Command::new(&python_interpreter) - .arg("-c") - .arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS) - .output() - { - if output.status.success() { - for line in String::from_utf8_lossy(&output.stdout).lines() { - if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") { - // Add library search path - println!("cargo:rustc-link-search=native={}", path); - // Set rpath so runtime linker can find the libraries - println!("cargo::rustc-link-arg=-Wl,-rpath,{}", path); - } - } - } - } - - // Link core PyTorch libraries needed for C10 symbols - println!("cargo:rustc-link-lib=torch_cpu"); - println!("cargo:rustc-link-lib=torch"); - println!("cargo:rustc-link-lib=c10"); - println!("cargo:rustc-link-lib=c10_cuda"); - } else { - // Fallback to torch-sys links metadata if available - if let Ok(torch_lib_path) = std::env::var("DEP_TORCH_LIB_PATH") { - println!("cargo::rustc-link-arg=-Wl,-rpath,{}", torch_lib_path); - } - } - - // Set rpath for NCCL libraries if available - if let Ok(nccl_lib_path) = std::env::var("DEP_NCCL_LIB_PATH") { - println!("cargo::rustc-link-arg=-Wl,-rpath,{}", nccl_lib_path); - } - - // Disable new dtags, as conda envs generally use `RPATH` over `RUNPATH` - println!("cargo::rustc-link-arg=-Wl,--disable-new-dtags"); - - // Link the static libraries from rdmaxcel-sys - // Try the Cargo dependency mechanism first, then fall back to fixed paths - if let Ok(rdmaxcel_out_dir) = std::env::var("DEP_RDMAXCEL_SYS_OUT_DIR") { - println!("cargo:rustc-link-search=native={}", rdmaxcel_out_dir); - println!("cargo:rustc-link-lib=static=rdmaxcel"); - println!("cargo:rustc-link-lib=static=rdmaxcel_cpp"); - println!("cargo:rustc-link-lib=static=rdmaxcel_cuda"); - } else { - eprintln!("Warning: DEP_RDMAXCEL_SYS_OUT_DIR not found. 
Using fallback paths."); - - // Use relative paths to the known locations - let cuda_build_dir = "../rdmaxcel-sys/target/cuda_build"; - println!("cargo:rustc-link-search=native={}", cuda_build_dir); - println!("cargo:rustc-link-lib=static=rdmaxcel_cuda"); - - // Find the most recent rdmaxcel-sys build directory for C/C++ libraries - let monarch_target_dir = "../target/debug/build"; - if let Ok(entries) = std::fs::read_dir(monarch_target_dir) { - let mut rdmaxcel_dirs: Vec<_> = entries - .filter_map(|entry| entry.ok()) - .filter(|entry| { - entry - .file_name() - .to_string_lossy() - .starts_with("rdmaxcel-sys-") - }) - .collect(); - - // Sort by modification time and use the most recent - rdmaxcel_dirs - .sort_by_key(|entry| entry.metadata().ok().and_then(|m| m.modified().ok())); - - if let Some(most_recent) = rdmaxcel_dirs.last() { - let out_dir = most_recent.path().join("out"); - if out_dir.exists() { - println!("cargo:rustc-link-search=native={}", out_dir.display()); - println!("cargo:rustc-link-lib=static=rdmaxcel"); - println!("cargo:rustc-link-lib=static=rdmaxcel_cpp"); - } - } else { - eprintln!("Warning: No rdmaxcel-sys build directories found"); - } - } - } - - // Set build configuration flags - println!("cargo::rustc-cfg=cargo"); - println!("cargo::rustc-check-cfg=cfg(cargo)"); -} diff --git a/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs b/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs index b65993b4b..70233bcf7 100644 --- a/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs +++ b/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs @@ -407,7 +407,7 @@ impl Handler for CudaRdmaActor { self.device_id as i32 )); cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(context)); - cuda_sys::cudaDeviceSynchronize(); + rdmaxcel_sys::rdmaxcel_cuCtxSynchronize(); cu_check!(rdmaxcel_sys::rdmaxcel_cuMemcpyHtoD_v2( self.cu_ptr as u64, self.cpu_buffer.as_ptr() as *const std::ffi::c_void, @@ -554,7 +554,7 @@ impl Handler for CudaRdmaActor { self.device_id as i32 )); cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(context)); - cuda_sys::cudaDeviceSynchronize(); + rdmaxcel_sys::rdmaxcel_cuCtxSynchronize(); cu_check!(rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2( self.cpu_buffer.as_mut_ptr() as *mut std::ffi::c_void, self.cu_ptr as rdmaxcel_sys::CUdeviceptr, diff --git a/monarch_rdma/extension/lib.rs b/monarch_rdma/extension/lib.rs index a350b0b62..db3428037 100644 --- a/monarch_rdma/extension/lib.rs +++ b/monarch_rdma/extension/lib.rs @@ -24,15 +24,99 @@ use monarch_rdma::RdmaBuffer; use monarch_rdma::RdmaManagerActor; use monarch_rdma::RdmaManagerMessageClient; use monarch_rdma::rdma_supported; +use monarch_rdma::register_segment_scanner; use pyo3::IntoPyObjectExt; use pyo3::exceptions::PyException; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::types::PyAny; use pyo3::types::PyTuple; use pyo3::types::PyType; use serde::Deserialize; use serde::Serialize; +/// Segment scanner callback that uses PyTorch's memory snapshot API. +/// +/// This function calls torch.cuda.memory._snapshot() to get CUDA memory segments +/// and fills the provided buffer with segment information. +/// +/// # Safety +/// This function is called from C code as a callback. 
+unsafe extern "C" fn pytorch_segment_scanner( segments_out: *mut monarch_rdma::rdmaxcel_sys::rdmaxcel_scanned_segment_t, max_segments: usize, ) -> usize { // Acquire the GIL to call Python code let result = Python::with_gil(|py| -> PyResult<usize> { // Check if torch is already imported - don't import it ourselves let sys = py.import("sys")?; let modules = sys.getattr("modules")?; // Try to get torch from sys.modules let torch = match modules.get_item("torch") { Ok(torch_module) => torch_module, Err(_) => { // torch not imported yet, return 0 segments return Ok(0); } }; // Check if CUDA is available let cuda_available: bool = torch .getattr("cuda")? .getattr("is_available")? .call0()? .extract()?; if !cuda_available { return Ok(0); } // Call torch.cuda.memory._snapshot() let snapshot = torch .getattr("cuda")? .getattr("memory")? .getattr("_snapshot")? .call0()?; // Get the segments list from the snapshot dict let segments = snapshot.get_item("segments")?; let segments_list: Vec<Bound<'_, PyAny>> = segments.extract()?; let num_segments = segments_list.len(); // Fill the output buffer with as many segments as will fit let segments_to_write = num_segments.min(max_segments); for (i, segment) in segments_list.iter().take(segments_to_write).enumerate() { // Extract fields from the segment dict let address: u64 = segment.get_item("address")?.extract()?; let total_size: usize = segment.get_item("total_size")?.extract()?; let device: i32 = segment.get_item("device")?.extract()?; let is_expandable: bool = segment.get_item("is_expandable")?.extract()?; // Write to the output buffer - only the fields the scanner needs to provide let seg_info = &mut *segments_out.add(i); seg_info.address = address as usize; seg_info.size = total_size; seg_info.device = device; seg_info.is_expandable = if is_expandable { 1 } else { 0 }; } // Return total number of segments found (may be > max_segments) Ok(num_segments) }); match result { Ok(count) => count, Err(e) => { // Log the specific error for debugging eprintln!("[monarch_rdma] pytorch_segment_scanner failed: {}", e); 0 } } } fn setup_rdma_context( rdma_buffer: &PyRdmaBuffer, local_proc_id: String, @@ -115,11 +199,6 @@ impl PyRdmaBuffer { rdma_supported() } - #[classmethod] - fn pt_cuda_allocator_compatibility<'py>(_cls: &Bound<'_, PyType>, _py: Python<'py>) -> bool { - monarch_rdma::pt_cuda_allocator_compatibility() - } - #[pyo3(name = "__repr__")] fn repr(&self) -> String { format!("<RdmaBuffer {:?}>", self.buffer) } @@ -314,6 +393,10 @@ impl PyRdmaManager { } pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { + // Register the PyTorch segment scanner callback. + // This calls torch.cuda.memory._snapshot() to get CUDA memory segments. 
+ register_segment_scanner(Some(pytorch_segment_scanner)); + module.add_class::<PyRdmaBuffer>()?; module.add_class::<PyRdmaManager>()?; Ok(()) diff --git a/monarch_rdma/src/lib.rs b/monarch_rdma/src/lib.rs index a16fbf407..05da045ff 100644 --- a/monarch_rdma/src/lib.rs +++ b/monarch_rdma/src/lib.rs @@ -18,8 +18,13 @@ mod rdma_manager_actor; mod macros; pub use ibverbs_primitives::*; +pub use rdma_components::SegmentScannerFn; +// Re-export segment scanner types for extension crate +pub use rdma_components::register_segment_scanner; pub use rdma_components::*; pub use rdma_manager_actor::*; +// Re-export rdmaxcel_sys for extension crate to access types +pub use rdmaxcel_sys; pub use test_utils::is_cuda_available; /// Print comprehensive RDMA device information for debugging. diff --git a/monarch_rdma/src/rdma_components.rs b/monarch_rdma/src/rdma_components.rs index 7c600fe9e..ed689d112 100644 --- a/monarch_rdma/src/rdma_components.rs +++ b/monarch_rdma/src/rdma_components.rs @@ -1471,19 +1471,27 @@ pub fn get_registered_cuda_segments() -> Vec } } -/// Check if PyTorch CUDA caching allocator has expandable segments enabled. +/// Segment scanner callback type alias for convenience. +pub type SegmentScannerFn = rdmaxcel_sys::RdmaxcelSegmentScannerFn; + +/// Register a segment scanner callback. /// -/// This function calls the C++ implementation that directly accesses the -/// PyTorch C10 CUDA allocator configuration to check if expandable segments -/// are enabled, which is required for RDMA operations with CUDA tensors. +/// The scanner callback is called during RDMA segment registration to discover +/// CUDA memory segments. The callback should fill the provided buffer with +/// segment information and return the total count of segments found. /// -/// # Returns +/// If the returned count exceeds the buffer size, the caller will allocate +/// a larger buffer and retry. +/// +/// Pass `None` to unregister the scanner. +/// +/// # Safety /// -/// `true` if both CUDA caching allocator is enabled AND expandable segments are enabled, -/// `false` otherwise. -pub fn pt_cuda_allocator_compatibility() -> bool { - // SAFETY: We are calling a C++ function from rdmaxcel that accesses PyTorch C10 APIs. - unsafe { rdmaxcel_sys::pt_cuda_allocator_compatibility() } +/// The provided callback function must be safe to call from C code and must +/// properly handle the segment buffer. +pub fn register_segment_scanner(scanner: SegmentScannerFn) { + // SAFETY: We are registering a callback function pointer with rdmaxcel. + unsafe { rdmaxcel_sys::rdmaxcel_register_segment_scanner(scanner) } } #[cfg(test)] diff --git a/monarch_rdma/src/rdma_manager_actor.rs b/monarch_rdma/src/rdma_manager_actor.rs index 290771232..50f9e3b1e 100644 --- a/monarch_rdma/src/rdma_manager_actor.rs +++ b/monarch_rdma/src/rdma_manager_actor.rs @@ -160,10 +160,6 @@ pub struct RdmaManagerActor { config: IbverbsConfig, - // Flag indicating PyTorch CUDA allocator compatibility - // True if both C10 CUDA allocator is enabled AND expandable segments are enabled - pt_cuda_alloc: bool, - mlx5dv_enabled: bool, // Map of unique RdmaMemoryRegionView to ibv_mr*. In case of cuda w/ pytorch its -1 @@ -244,8 +240,9 @@ impl Drop for RdmaManagerActor { } } - // 4. 
Deregister all CUDA segments (if using mlx5dv) + // The segment scanner in Python handles compatibility checks + if self.mlx5dv_enabled { unsafe { let result = rdmaxcel_sys::deregister_segments(); if result != 0 { @@ -262,11 +259,6 @@ impl Drop for RdmaManagerActor { } impl RdmaManagerActor { - /// Whether to register all memory regions allocated by the PyTorch CUDA allocator - /// True if both `pt_cuda_alloc` and `mlx5dv_enabled` are true - fn cuda_pt_alloc_enabled(&self) -> bool { - self.pt_cuda_alloc && self.mlx5dv_enabled - } /// Get or create a domain and loopback QP for the specified RDMA device fn get_or_create_device_domain( &mut self, @@ -416,58 +408,55 @@ impl RdmaManagerActor { let mut mr: *mut rdmaxcel_sys::ibv_mr = std::ptr::null_mut(); let mrv; - if is_cuda && self.cuda_pt_alloc_enabled() { - // Get registered segments and check if our memory range is covered - let mut maybe_mrv = self.find_cuda_segment_for_address(addr, size); - // not found, lets re-sync with caching allocator and retry - if maybe_mrv.is_none() { - let err = rdmaxcel_sys::register_segments( - domain.pd, - qp.unwrap().qp as *mut rdmaxcel_sys::rdmaxcel_qp_t, - ); - if err != 0 { - let error_msg = get_rdmaxcel_error_message(err); - return Err(anyhow::anyhow!( - "RdmaXcel register_segments failed (addr: 0x{:x}, size: {}): {}", - addr, - size, - error_msg - )); + if is_cuda { + // First, try to use segment scanning if mlx5dv is enabled + let mut segment_mrv = None; + if self.mlx5dv_enabled { + // Try to find in already registered segments + segment_mrv = self.find_cuda_segment_for_address(addr, size); + + // If not found, trigger a re-sync with the allocator and retry + if segment_mrv.is_none() { + let err = rdmaxcel_sys::register_segments( + domain.pd, + qp.unwrap().qp as *mut rdmaxcel_sys::rdmaxcel_qp_t, + ); + // Only retry if register_segments succeeded + // If it fails (e.g., scanner returns 0 segments), we'll fall back to dmabuf + if err == 0 { + segment_mrv = self.find_cuda_segment_for_address(addr, size); + } } - - maybe_mrv = self.find_cuda_segment_for_address(addr, size); } - // if still not found, throw exception - if maybe_mrv.is_none() { - return Err(anyhow::anyhow!( - "MR registration failed for cuda (addr: 0x{:x}, size: {}), unable to find segment in CudaCachingAllocator", - addr, - size - )); - } - mrv = maybe_mrv.unwrap(); - } else if is_cuda { - let mut fd: i32 = -1; - rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange( - &mut fd, - addr as rdmaxcel_sys::CUdeviceptr, - size, - rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, - 0, - ); - mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32); - if mr.is_null() { - return Err(anyhow::anyhow!("Failed to register dmabuf MR")); + + // Use segment if found, otherwise fall back to direct dmabuf registration + if let Some(mrv_from_segment) = segment_mrv { + mrv = mrv_from_segment; + } else { + // Dmabuf path: used when mlx5dv is disabled OR scanner returns no segments + let mut fd: i32 = -1; + rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange( + &mut fd, + addr as rdmaxcel_sys::CUdeviceptr, + size, + rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, + 0, + ); + mr = + rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32); + if mr.is_null() { + return Err(anyhow::anyhow!("Failed to register dmabuf MR")); + } + mrv = RdmaMemoryRegionView { + id: self.mrv_id, + virtual_addr: addr, + rdma_addr: (*mr).addr as usize, + size, + lkey: (*mr).lkey, + rkey: (*mr).rkey, + }; + self.mrv_id += 1; } - mrv = 
RdmaMemoryRegionView { - id: self.mrv_id, - virtual_addr: addr, - rdma_addr: (*mr).addr as usize, - size, - lkey: (*mr).lkey, - rkey: (*mr).rkey, - }; - self.mrv_id += 1; } else { // CPU memory path mr = rdmaxcel_sys::ibv_reg_mr( @@ -523,8 +512,6 @@ impl RemoteSpawn for RdmaManagerActor { let mut config = params.unwrap_or_default(); tracing::debug!("rdma is enabled, config device hint: {}", config.device); - let pt_cuda_alloc = crate::rdma_components::pt_cuda_allocator_compatibility(); - let mlx5dv_enabled = resolve_qp_type(config.qp_type) == rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV; // check config and hardware support align @@ -555,7 +542,6 @@ impl RemoteSpawn for RdmaManagerActor { pending_qp_creation: Arc::new(Mutex::new(HashSet::new())), device_domains: HashMap::new(), config, - pt_cuda_alloc, mlx5dv_enabled, mr_map: HashMap::new(), mrv_id: 0, diff --git a/python/monarch/_rust_bindings/rdma.pyi b/python/monarch/_rust_bindings/rdma.pyi index 0f1f4b57f..3fb09b033 100644 --- a/python/monarch/_rust_bindings/rdma.pyi +++ b/python/monarch/_rust_bindings/rdma.pyi @@ -61,5 +61,3 @@ class _RdmaBuffer: def new_from_json(json: str) -> _RdmaBuffer: ... @classmethod def rdma_supported(cls) -> bool: ... - @classmethod - def pt_cuda_allocator_compatibility(cls) -> bool: ... diff --git a/python/monarch/_src/rdma/rdma.py b/python/monarch/_src/rdma/rdma.py index e3f831c3a..7a61a08b2 100644 --- a/python/monarch/_src/rdma/rdma.py +++ b/python/monarch/_src/rdma/rdma.py @@ -145,20 +145,38 @@ async def create_manager() -> _RdmaManager: await self._manager_futures[proc_mesh] +def pt_cuda_allocator_compatibility() -> bool: + """ + Check if PyTorch CUDA caching allocator is compatible with RDMA. + + This checks if both the CUDA caching allocator is enabled AND expandable + segments are enabled, which is required for RDMA operations with CUDA tensors. + + Returns: + bool: True if both conditions are met, False otherwise + """ + if not torch.cuda.is_available(): + return False + + # Get allocator snapshot which contains settings + snapshot = torch.cuda.memory._snapshot() + allocator_settings = snapshot.get("allocator_settings", {}) + + # Check if expandable_segments is enabled + return allocator_settings.get("expandable_segments", False) + + @functools.cache def _check_cuda_expandable_segments_enabled() -> bool: """ Check if PyTorch CUDA caching allocator is using expandable segments. - Uses the Rust extension which calls the C++ implementation from rdmaxcel-sys - that directly accesses the PyTorch C10 CUDA allocator configuration. 
- Returns: bool: True if expandable segments are enabled, False otherwise """ try: - # Use the new Rust utility function that calls the C++ pt_cuda_allocator_compatibility() - pt_cuda_compat = _RdmaBuffer.pt_cuda_allocator_compatibility() + # Call the Python implementation of pt_cuda_allocator_compatibility + pt_cuda_compat = pt_cuda_allocator_compatibility() if not pt_cuda_compat: warnings.warn( diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs index 0a78d67de..d9d465428 100644 --- a/rdmaxcel-sys/build.rs +++ b/rdmaxcel-sys/build.rs @@ -71,7 +71,7 @@ fn main() { .header(&header_path) .clang_arg("-x") .clang_arg("c++") - .clang_arg("-std=gnu++20") + .clang_arg("-std=gnu++14") .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) // Allow the specified functions, types, and variables .allowlist_function("ibv_.*") @@ -90,7 +90,6 @@ fn main() { .allowlist_function("launch_recv_wqe") .allowlist_function("rdma_get_active_segment_count") .allowlist_function("rdma_get_all_segment_info") - .allowlist_function("pt_cuda_allocator_compatibility") .allowlist_function("register_segments") .allowlist_function("deregister_segments") .allowlist_function("rdmaxcel_cu.*") @@ -98,6 +97,7 @@ fn main() { .allowlist_function("rdmaxcel_print_device_info") .allowlist_function("rdmaxcel_error_string") .allowlist_function("rdmaxcel_qp_.*") + .allowlist_function("rdmaxcel_register_segment_scanner") .allowlist_function("poll_cq_with_cache") .allowlist_function("completion_cache_.*") .allowlist_type("ibv_.*") @@ -107,12 +107,14 @@ fn main() { .allowlist_type("wqe_params_t") .allowlist_type("cqe_poll_params_t") .allowlist_type("rdma_segment_info_t") + .allowlist_type("rdmaxcel_scanned_segment_t") .allowlist_type("rdmaxcel_qp_t") .allowlist_type("rdmaxcel_qp") .allowlist_type("completion_cache_t") .allowlist_type("completion_cache") .allowlist_type("poll_context_t") .allowlist_type("poll_context") + .allowlist_type("rdmaxcel_segment_scanner_fn") .allowlist_var("MLX5_.*") .allowlist_var("IBV_.*") // Block specific types that are manually defined in lib.rs @@ -171,29 +173,8 @@ fn main() { // Only link cudart (CUDA Runtime API) println!("cargo:rustc-link-lib=cudart"); - // Link PyTorch C++ libraries for c10 symbols - let use_pytorch_apis = build_utils::get_env_var_with_rerun("TORCH_SYS_USE_PYTORCH_APIS") - .unwrap_or_else(|_| "1".to_owned()); - if use_pytorch_apis == "1" { - // Try to get PyTorch library directory - let python_interpreter = std::path::PathBuf::from("python"); - if let Ok(output) = std::process::Command::new(&python_interpreter) - .arg("-c") - .arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS) - .output() - { - for line in String::from_utf8_lossy(&output.stdout).lines() { - if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") { - println!("cargo:rustc-link-search=native={}", path); - break; - } - } - } - // Link core PyTorch libraries needed for C10 symbols - println!("cargo:rustc-link-lib=torch_cpu"); - println!("cargo:rustc-link-lib=torch"); - println!("cargo:rustc-link-lib=c10"); - } + // Note: We no longer link against libtorch/c10 since segment scanning + // is now done via a callback registered from the extension crate. 
// Generate bindings let bindings = builder.generate().expect("Unable to generate bindings"); @@ -230,42 +211,10 @@ fn main() { panic!("C source file not found at {}", c_source_path); } - // Compile the C++ source file for CUDA allocator compatibility + // Compile the C++ source file let cpp_source_path = format!("{}/src/rdmaxcel.cpp", manifest_dir); let driver_api_cpp_path = format!("{}/src/driver_api.cpp", manifest_dir); if Path::new(&cpp_source_path).exists() && Path::new(&driver_api_cpp_path).exists() { - let mut libtorch_include_dirs: Vec<PathBuf> = vec![]; - - // Use the same approach as torch-sys: Python discovery first, env vars as fallback - let use_pytorch_apis = - build_utils::get_env_var_with_rerun("TORCH_SYS_USE_PYTORCH_APIS") - .unwrap_or_else(|_| "1".to_owned()); - - if use_pytorch_apis == "1" { - // Use Python to get PyTorch include paths (same as torch-sys) - let python_interpreter = PathBuf::from("python"); - let output = std::process::Command::new(&python_interpreter) - .arg("-c") - .arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS) - .output() - .unwrap_or_else(|_| panic!("error running {python_interpreter:?}")); - - for line in String::from_utf8_lossy(&output.stdout).lines() { - if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") { - libtorch_include_dirs.push(PathBuf::from(path)); - } - } - } else { - // Use environment variables (fallback approach) - libtorch_include_dirs.extend( - build_utils::get_env_var_with_rerun("LIBTORCH_INCLUDE") - .unwrap_or_default() - .split(':') - .filter(|s| !s.is_empty()) - .map(PathBuf::from), - ); - } - let mut cpp_build = cc::Build::new(); cpp_build .file(&cpp_source_path) @@ -273,17 +222,11 @@ fn main() { .include(format!("{}/src", manifest_dir)) .flag("-fPIC") .cpp(true) - .flag("-std=gnu++20") - .define("PYTORCH_C10_DRIVER_API_SUPPORTED", "1"); + .flag("-std=gnu++14"); // Add CUDA include paths cpp_build.include(&cuda_include_path); - // Add PyTorch/C10 include paths - for include_dir in &libtorch_include_dirs { - cpp_build.include(include_dir); - } - // Add Python include path if available if let Some(include_dir) = &python_config.include_dir { cpp_build.include(include_dir); @@ -324,7 +267,7 @@ fn main() { &cuda_obj_path, "--compiler-options", "-fPIC", - "-std=c++20", + "-std=c++14", "--expt-extended-lambda", "-Xcompiler", "-fPIC", diff --git a/rdmaxcel-sys/src/driver_api.cpp b/rdmaxcel-sys/src/driver_api.cpp index 63d1aff7d..d65e515c7 100644 --- a/rdmaxcel-sys/src/driver_api.cpp +++ b/rdmaxcel-sys/src/driver_api.cpp @@ -32,6 +32,7 @@ _(cuDeviceGetAttribute) \ _(cuCtxCreate_v2) \ _(cuCtxSetCurrent) \ + _(cuCtxSynchronize) \ _(cuGetErrorString) namespace rdmaxcel { @@ -214,6 +215,10 @@ CUresult rdmaxcel_cuCtxSetCurrent(CUcontext ctx) { return rdmaxcel::DriverAPI::get()->cuCtxSetCurrent_(ctx); } +CUresult rdmaxcel_cuCtxSynchronize(void) { + return rdmaxcel::DriverAPI::get()->cuCtxSynchronize_(); +} + // Error handling CUresult rdmaxcel_cuGetErrorString(CUresult error, const char** pStr) { return rdmaxcel::DriverAPI::get()->cuGetErrorString_(error, pStr); } diff --git a/rdmaxcel-sys/src/driver_api.h b/rdmaxcel-sys/src/driver_api.h index f0976459a..78f3709c1 100644 --- a/rdmaxcel-sys/src/driver_api.h +++ b/rdmaxcel-sys/src/driver_api.h @@ -96,6 +96,8 @@ rdmaxcel_cuCtxCreate_v2(CUcontext* pctx, unsigned int flags, CUdevice dev); CUresult rdmaxcel_cuCtxSetCurrent(CUcontext ctx); +CUresult rdmaxcel_cuCtxSynchronize(void); + // Error handling CUresult rdmaxcel_cuGetErrorString(CUresult error, const char** pStr); diff --git 
a/rdmaxcel-sys/src/lib.rs b/rdmaxcel-sys/src/lib.rs index 546fd84ad..f493273ba 100644 --- a/rdmaxcel-sys/src/lib.rs +++ b/rdmaxcel-sys/src/lib.rs @@ -238,7 +238,11 @@ mod inner { pub use inner::*; -// RDMA error string function and CUDA utility functions +// Segment scanner callback type - type alias for the bindgen-generated type +pub type RdmaxcelSegmentScannerFn = rdmaxcel_segment_scanner_fn; + +// Additional extern "C" declarations for functions that are also auto-generated by bindgen. +// These provide a place for doc comments and explicit signatures. unsafe extern "C" { pub fn rdmaxcel_error_string(error_code: std::os::raw::c_int) -> *const std::os::raw::c_char; pub fn get_cuda_pci_address_from_ptr( diff --git a/rdmaxcel-sys/src/rdmaxcel.cpp b/rdmaxcel-sys/src/rdmaxcel.cpp index 6f55824b3..98312ae87 100644 --- a/rdmaxcel-sys/src/rdmaxcel.cpp +++ b/rdmaxcel-sys/src/rdmaxcel.cpp @@ -7,13 +7,12 @@ */ #include "rdmaxcel.h" -#include <c10/cuda/CUDACachingAllocator.h> -#include <c10/cuda/CUDAAllocatorConfig.h> #include #include #include #include #include +#include <vector> #include "driver_api.h" // MR size must be a multiple of 2MB @@ -62,20 +61,52 @@ struct SegmentInfo { static std::unordered_map<size_t, SegmentInfo> activeSegments; static std::mutex segmentsMutex; +// Segment scanner callback - set via rdmaxcel_register_segment_scanner() +static rdmaxcel_segment_scanner_fn g_segment_scanner = nullptr; + +// Initial buffer size for segment scanning (will grow if needed) +static size_t g_segment_buffer_size = 64; + // Helper function to scan existing segments from allocator snapshot void scan_existing_segments() { + // If no scanner is registered, nothing to do + if (!g_segment_scanner) { + return; + } + std::lock_guard<std::mutex> lock(segmentsMutex); - // Get current snapshot from the allocator - auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); + // Allocate a buffer for the scanner to fill + std::vector<rdmaxcel_scanned_segment_t> scanned_segments( + g_segment_buffer_size); + + // Call the scanner + size_t segment_count = + g_segment_scanner(scanned_segments.data(), g_segment_buffer_size); + + // If we need more space, grow the buffer and retry + if (segment_count > g_segment_buffer_size) { + // Round up to next power of 2 for efficiency + size_t new_size = g_segment_buffer_size; + while (new_size < segment_count) { + new_size *= 2; + } + g_segment_buffer_size = new_size; + scanned_segments.resize(g_segment_buffer_size); + + // Retry with larger buffer + segment_count = + g_segment_scanner(scanned_segments.data(), g_segment_buffer_size); + } - // Create a set to track snapshot segments + // Create a set to track scanned segments std::set<std::pair<size_t, int32_t>> snapshotSegments; - // Process snapshot segments - for (const auto& segment : snapshot.segments) { - size_t segment_address = reinterpret_cast<size_t>(segment.address); - int32_t device = segment.device; + // Process scanned segments + for (size_t i = 0; i < segment_count; i++) { + const auto& scanned = scanned_segments[i]; + size_t segment_address = scanned.address; + int32_t device = scanned.device; snapshotSegments.insert({segment_address, device}); @@ -83,13 +114,13 @@ void scan_existing_segments() { auto it = activeSegments.find(segment_address); if (it != activeSegments.end() && it->second.device == device) { // Existing segment found - update total_size if needed - if (it->second.phys_size != segment.total_size) { - it->second.phys_size = segment.total_size; + if (it->second.phys_size != scanned.size) { + it->second.phys_size = scanned.size; } } else { // New segment - add it SegmentInfo segInfo( - segment_address, segment.total_size, device, segment.is_expandable); + 
segment_address, scanned.size, device, scanned.is_expandable != 0); activeSegments[segment_address] = segInfo; } @@ -114,12 +145,9 @@ void scan_existing_segments() { extern "C" { -// Simple check for PyTorch CUDA allocator compatibility -bool pt_cuda_allocator_compatibility() { - return ( - c10::cuda::CUDACachingAllocator::isEnabled() && - c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: - expandable_segments()); +// Register a segment scanner callback +void rdmaxcel_register_segment_scanner(rdmaxcel_segment_scanner_fn scanner) { + g_segment_scanner = scanner; } // Get count of active segments diff --git a/rdmaxcel-sys/src/rdmaxcel.h b/rdmaxcel-sys/src/rdmaxcel.h index 4a3cf316d..fe8dd422b 100644 --- a/rdmaxcel-sys/src/rdmaxcel.h +++ b/rdmaxcel-sys/src/rdmaxcel.h @@ -153,9 +153,38 @@ const char* rdmaxcel_error_string(int error_code); // Active segment tracking functions (implemented in C++) int rdma_get_active_segment_count(); int rdma_get_all_segment_info(rdma_segment_info_t* info_array, int max_count); -bool pt_cuda_allocator_compatibility(); int deregister_segments(); +// Scanned segment information - minimal fields needed from external scanner +// This is what the scanner callback fills in, separate from internal +// rdma_segment_info_t +typedef struct { + size_t address; // Physical memory address of the segment + size_t size; // Size of the segment in bytes + int32_t device; // CUDA device ID + int is_expandable; // Boolean: 1 if expandable, 0 if not +} rdmaxcel_scanned_segment_t; + +// Segment scanner callback type +// This callback is used to scan for CUDA memory segments from an external +// source (e.g., Python/PyTorch allocator API). +// +// Parameters: +// segments_out: Buffer to write segment info into +// max_segments: Maximum number of segments that can fit in the buffer +// +// Returns: +// Number of segments found. If the return value is greater than max_segments, +// the caller should reallocate a larger buffer and call again. +typedef size_t (*rdmaxcel_segment_scanner_fn)( + rdmaxcel_scanned_segment_t* segments_out, + size_t max_segments); + +// Register a segment scanner callback. +// The scanner will be called by scan_existing_segments() to discover CUDA +// segments. Pass NULL to unregister the scanner. +void rdmaxcel_register_segment_scanner(rdmaxcel_segment_scanner_fn scanner); + // CUDA utility functions int get_cuda_pci_address_from_ptr( CUdeviceptr cuda_ptr,
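
Reviewer note (not part of the patch): the new callback contract is easiest to see in isolation. Below is a minimal sketch of registering a stub scanner from Rust, assuming the re-exports added above in monarch_rdma/src/lib.rs (`register_segment_scanner`, `rdmaxcel_sys`). bindgen renders the C typedef `rdmaxcel_segment_scanner_fn` as an `Option` of an `unsafe extern "C" fn`, which is why registration passes `Some(...)` and unregistration passes `None`; the name `noop_scanner` is hypothetical:

```rust
use monarch_rdma::{register_segment_scanner, rdmaxcel_sys};

/// Hypothetical stub scanner that reports zero segments. With this installed,
/// CUDA registrations take the dmabuf fallback path in RdmaManagerActor
/// ("used when mlx5dv is disabled OR scanner returns no segments").
unsafe extern "C" fn noop_scanner(
    _segments_out: *mut rdmaxcel_sys::rdmaxcel_scanned_segment_t,
    _max_segments: usize,
) -> usize {
    0
}

fn main() {
    // Install the stub; the C++ side stores the raw function pointer.
    register_segment_scanner(Some(noop_scanner));
    // ... exercise code paths that call register_segments() ...
    // Unregister so scan_existing_segments() becomes a no-op again.
    register_segment_scanner(None);
}
```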
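The grow-and-retry protocol from the scanner's side, sketched with a hypothetical fixed inventory (the name `fixed_inventory_scanner` and all addresses/counts are illustrative, not from the diff): the callback writes at most `max_segments` entries but always returns the total it found, so `scan_existing_segments()` knows to resize its buffer and call again:

```rust
use monarch_rdma::rdmaxcel_sys;

/// Pretends the allocator holds 128 segments of 2 MiB each. With the default
/// 64-slot buffer, the first call returns 128, prompting the C++ caller to
/// grow the buffer (next power of two) and call a second time.
unsafe extern "C" fn fixed_inventory_scanner(
    segments_out: *mut rdmaxcel_sys::rdmaxcel_scanned_segment_t,
    max_segments: usize,
) -> usize {
    const TOTAL: usize = 128;
    const SEG_SIZE: usize = 2 << 20; // 2 MiB, matching the documented MR granularity
    for i in 0..TOTAL.min(max_segments) {
        let seg = &mut *segments_out.add(i);
        seg.address = 0x7f00_0000_0000 + i * SEG_SIZE; // fake base address
        seg.size = SEG_SIZE;
        seg.device = 0;
        seg.is_expandable = 1;
    }
    TOTAL // may exceed max_segments; the caller resizes and retries
}
```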
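And the same protocol from the caller's side, mirrored in Rust purely for illustration; the authoritative implementation is the C++ `scan_existing_segments()` above. `collect_segments` is a hypothetical helper, and this sketch assumes the bindgen-generated segment struct derives `Clone` (as bindgen does for plain C structs):

```rust
use monarch_rdma::rdmaxcel_sys::{RdmaxcelSegmentScannerFn, rdmaxcel_scanned_segment_t};

/// Drive a scanner with a growable buffer, mirroring the C++ caller.
fn collect_segments(scanner: RdmaxcelSegmentScannerFn) -> Vec<rdmaxcel_scanned_segment_t> {
    let Some(scan) = scanner else {
        return Vec::new(); // no scanner registered: nothing to report
    };
    // Start with the same default capacity the C++ side uses (64 slots).
    let mut buf: Vec<rdmaxcel_scanned_segment_t> = vec![unsafe { std::mem::zeroed() }; 64];
    // SAFETY: `buf` is valid for `buf.len()` writes for the duration of the call.
    let mut count = unsafe { scan(buf.as_mut_ptr(), buf.len()) };
    if count > buf.len() {
        // The scanner found more segments than fit: grow to the next power of
        // two, as the C++ code does, and retry once with the larger buffer.
        buf.resize(count.next_power_of_two(), unsafe { std::mem::zeroed() });
        count = unsafe { scan(buf.as_mut_ptr(), buf.len()) };
    }
    buf.truncate(count.min(buf.len()));
    buf
}
```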