From a2e19c5d1c1edff68a9b297f1c54738b78a3c68e Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 26 Oct 2022 15:08:53 +0200
Subject: [PATCH] Revert "Removing unsafe GPU accesses."

This reverts commit 5325ba2b73fffc16416130193da1690353e0a7db.
---
 bindings/python/Cargo.toml                  |   1 +
 bindings/python/src/lib.rs                  | 233 +++++++++++++++++---
 bindings/python/tests/test_pt_comparison.py |  24 +-
 3 files changed, 219 insertions(+), 39 deletions(-)
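
Note (not part of the applied diff): the revert below re-enables a direct
cudaMemcpy fast path when a file is opened with framework="pt" on a CUDA
device. A rough usage sketch follows; the file name and tensor shape are
invented, and a CUDA-enabled torch build is assumed.

    import torch
    from safetensors import safe_open
    from safetensors.torch import save_file

    save_file({"weight": torch.zeros((1024, 1024))}, "model.safetensors")

    # With device="cuda:0" the bindings look up torch's bundled libcudart
    # and cudaMemcpy the mmapped bytes straight into freshly allocated GPU
    # tensors, instead of routing them through a CPU bytearray first.
    with safe_open("model.safetensors", framework="pt", device="cuda:0") as f:
        full = f.get_tensor("weight")        # one cudaMemcpy per tensor
        part = f.get_slice("weight")[0:512]  # chunked cudaMemcpy per slice
    assert full.device.type == "cuda"
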
diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 8c66059a..b62cb9a6 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -12,6 +12,7 @@ crate-type = ["cdylib"]
 pyo3 = { version = "0.17.1", features = ["extension-module"] }
 memmap = "0.7"
 serde_json = "1.0"
+libloading = "0.7"
 
 [dependencies.safetensors]
 version = "*"
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 0884dc40..80a5ce5e 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -204,6 +204,133 @@ struct safe_open {
     framework: Framework,
     device: Device,
     mmap: Arc<Mmap>,
+    cudart: Option<libloading::Library>,
+}
+
+fn create_empty_tensor_pt<'a>(
+    module: &'a PyModule,
+    shape: &[usize],
+    dtype: Dtype,
+    device: &Device,
+) -> PyResult<&'a PyAny> {
+    let py = module.py();
+    let shape = shape.to_vec();
+    let empty = module.getattr("empty")?;
+    let dtype: PyObject = get_pydtype(module, dtype)?;
+    let shape: PyObject = shape.into_py(py);
+    let device: PyObject = device.clone().into_py(py);
+    let kwargs = [("dtype", dtype), ("device", device)].into_py_dict(py);
+    let tensor = empty.call((shape,), Some(kwargs))?;
+    Ok(tensor)
+}
+
+fn find_cudart(module: &PyModule) -> PyResult<libloading::Library> {
+    let mut path: std::path::PathBuf = module
+        .getattr(intern!(module.py(), "__file__"))?
+        .extract()?;
+    path.pop();
+    path.push("lib");
+    for file in path.read_dir()?.flatten() {
+        let path = file.path();
+
+        let filename = path
+            .file_name()
+            .ok_or_else(|| exceptions::PyException::new_err("Couldn't read filename "))?;
+        let filename = filename
+            .to_str()
+            .ok_or_else(|| exceptions::PyException::new_err("Couldn't read filename "))?;
+        if filename.starts_with("libcudart") {
+            unsafe {
+                let cudart = file.path();
+                let lib = libloading::Library::new(cudart).map_err(|e| {
+                    exceptions::PyException::new_err(format!("Couldn't load cuda {e:?}",))
+                })?;
+                return Ok(lib);
+            }
+        }
+    }
+    Err(exceptions::PyException::new_err("Couldn't find cuda"))
+}
+
+fn create_cuda_unsafe_tensor(
+    module: &PyModule,
+    cudart: &libloading::Library,
+    info: &TensorInfo,
+    device: &Device,
+    data: &[u8],
+) -> PyResult<PyObject> {
+    let tensor = create_empty_tensor_pt(module, &info.shape, info.dtype, device)?;
+
+    let data_ptr_fn = tensor.getattr("data_ptr")?;
+    let data_ptr: usize = data_ptr_fn.call0()?.extract()?;
+
+    let out = unsafe {
+        let cuda_memcpy: libloading::Symbol<
+            unsafe extern "C" fn(
+                device_ptr: u64,
+                src_ptr: *const std::ffi::c_void,
+                src_len: usize,
+            ) -> u32,
+        > = cudart.get(b"cudaMemcpy").map_err(|e| {
+            exceptions::PyException::new_err(format!("Couldn't find cudaMemcpy {e:?}",))
+        })?;
+        cuda_memcpy(
+            data_ptr as u64,
+            data.as_ptr() as *const std::ffi::c_void,
+            data.len(),
+        )
+    };
+    if out != 0 {
+        panic!(
+            "We tried to set your tensor fast, but there was a cuda error, This could
+        have corrupted your GPU ram, aborting to prevent further errors"
+        )
+    }
+    let tensor: PyObject = tensor.into_py(module.py());
+    Ok(tensor)
+}
+
+fn create_cuda_unsafe_tensor_from_slice(
+    module: &PyModule,
+    cudart: &libloading::Library,
+    shape: &[usize],
+    dtype: Dtype,
+    device: &Device,
+    iterator: safetensors::slice::SliceIterator,
+) -> PyResult<PyObject> {
+    let tensor = create_empty_tensor_pt(module, shape, dtype, device)?;
+
+    let data_ptr_fn = tensor.getattr("data_ptr")?;
+    let data_ptr: usize = data_ptr_fn.call0()?.extract()?;
+    let mut offset = 0;
+    unsafe {
+        let cuda_memcpy: libloading::Symbol<
+            unsafe extern "C" fn(
+                device_ptr: u64,
+                src_ptr: *const std::ffi::c_void,
+                src_len: usize,
+            ) -> u32,
+        > = cudart.get(b"cudaMemcpy").map_err(|e| {
+            exceptions::PyException::new_err(format!("Couldn't find cudaMemcpy {e:?}",))
+        })?;
+        for slice in iterator {
+            let len = slice.len();
+            let out = cuda_memcpy(
+                (data_ptr + offset) as u64,
+                slice.as_ptr() as *const std::ffi::c_void,
+                len,
+            );
+            if out != 0 {
+                panic!(
+                    "We tried to set your tensor fast, but there was a cuda error, This could
+                have corrupted your GPU ram, aborting to prevent further errors"
+                )
+            }
+            offset += len;
+        }
+    }
+    let tensor: PyObject = tensor.into_py(module.py());
+    Ok(tensor)
 }
 
 #[pymethods]
@@ -229,12 +356,24 @@ impl safe_open {
 
         let offset = n + 8;
 
+        let cudart = match (&device, &framework) {
+            (Device::Cuda(_), Framework::Pytorch) => {
+                Python::with_gil(|py| -> PyResult<Option<libloading::Library>> {
+                    let module_name = intern!(py, "torch");
+                    let module = PyModule::import(py, module_name)?;
+                    Ok(Some(find_cudart(module)?))
+                })?
+            }
+            _ => None,
+        };
+
         Ok(Self {
             metadata,
             offset,
             framework,
             device,
             mmap: Arc::new(buffer),
+            cudart,
         })
     }
 
@@ -255,18 +394,39 @@ impl safe_open {
         let data =
             &self.mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset];
 
-        let array: PyObject = Python::with_gil(|py| PyByteArray::new(py, data).into_py(py));
-
-        create_tensor(
-            &self.framework,
-            info.dtype,
-            &info.shape,
-            array,
-            &self.device,
-        )
+        match (&self.device, &self.framework, &self.cudart) {
+            (Device::Cuda(_), Framework::Pytorch, Some(cudart)) => {
+                Python::with_gil(|py| -> PyResult<PyObject> {
+                    let module = PyModule::import(py, intern!(py, "torch"))?;
+                    create_cuda_unsafe_tensor(module, cudart, info, &self.device, data)
+                })
+            }
+            _ => {
+                let array: PyObject = Python::with_gil(|py| PyByteArray::new(py, data).into_py(py));
+
+                create_tensor(
+                    &self.framework,
+                    info.dtype,
+                    &info.shape,
+                    array,
+                    &self.device,
+                )
+            }
+        }
     }
 
     pub fn get_slice(&self, name: &str) -> PyResult<PySafeSlice> {
+        let cudart = match (&self.device, &self.framework) {
+            (Device::Cuda(_), Framework::Pytorch) => {
+                Python::with_gil(|py| -> PyResult<Option<libloading::Library>> {
+                    let module_name = intern!(py, "torch");
+                    let module = PyModule::import(py, module_name)?;
+                    Ok(Some(find_cudart(module)?))
+                })?
+            }
+            _ => None,
+        };
+
         if let Some(info) = self.metadata.tensors().get(name) {
             Ok(PySafeSlice {
                 info: info.clone(),
@@ -274,6 +434,7 @@ impl safe_open {
                 offset: self.offset,
                 device: self.device.clone(),
                 mmap: self.mmap.clone(),
+                cudart,
             })
         } else {
             Err(exceptions::PyException::new_err(format!(
@@ -296,6 +457,7 @@ struct PySafeSlice {
     offset: usize,
     device: Device,
     mmap: Arc<Mmap>,
+    cudart: Option<libloading::Library>,
 }
 
 #[derive(FromPyObject)]
@@ -341,24 +503,41 @@ impl PySafeSlice {
         let mut offset = 0;
         let length = iterator.remaining_byte_len();
 
-        let array: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
-            Ok(PyByteArray::new_with(py, length, |bytes: &mut [u8]| {
-                for slice in iterator {
-                    let len = slice.len();
-                    bytes[offset..offset + slice.len()].copy_from_slice(slice);
-                    offset += len;
-                }
-                Ok(())
-            })?
-            .into_py(py))
-        })?;
-        create_tensor(
-            &self.framework,
-            self.info.dtype,
-            &newshape,
-            array,
-            &self.device,
-        )
+        match (&self.device, &self.framework, &self.cudart) {
+            (Device::Cuda(_), Framework::Pytorch, Some(cudart)) => {
+                Python::with_gil(|py| -> PyResult<PyObject> {
+                    let module = PyModule::import(py, intern!(py, "torch"))?;
+                    create_cuda_unsafe_tensor_from_slice(
+                        module,
+                        cudart,
+                        &newshape,
+                        self.info.dtype,
+                        &self.device,
+                        iterator,
+                    )
+                })
+            }
+            _ => {
+                let array: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
+                    Ok(PyByteArray::new_with(py, length, |bytes: &mut [u8]| {
+                        for slice in iterator {
+                            let len = slice.len();
+                            bytes[offset..offset + slice.len()].copy_from_slice(slice);
+                            offset += len;
+                        }
+                        Ok(())
+                    })?
+                    .into_py(py))
+                })?;
+                create_tensor(
+                    &self.framework,
+                    self.info.dtype,
+                    &newshape,
+                    array,
+                    &self.device,
+                )
+            }
+        }
     }
 }
 
diff --git a/bindings/python/tests/test_pt_comparison.py b/bindings/python/tests/test_pt_comparison.py
index 300cec82..409340ce 100644
--- a/bindings/python/tests/test_pt_comparison.py
+++ b/bindings/python/tests/test_pt_comparison.py
@@ -53,18 +53,8 @@ def setUp(self):
         save_file(data.copy(), self.local)
 
     def test_deserialization_safe(self):
-        W = 1
-        N = 1
-
-        for i in range(W):
-            torch.load(self.filename)
-        pt_timing = datetime.timedelta(0)
-        for i in range(N):
-            start = datetime.datetime.now()
-            tweights = torch.load(self.filename)
-            pt_time = datetime.datetime.now() - start
-            pt_timing += pt_time
-        pt_time = pt_timing / N
+        W = 10
+        N = 50
 
         for i in range(W):
             load_file(self.local)
@@ -76,6 +66,16 @@ def test_deserialization_safe(self):
             safe_timing += safe_time
         safe_time = safe_timing / N
 
+        for i in range(W):
+            torch.load(self.filename)
+        pt_timing = datetime.timedelta(0)
+        for i in range(N):
+            start = datetime.datetime.now()
+            tweights = torch.load(self.filename)
+            pt_time = datetime.datetime.now() - start
+            pt_timing += pt_time
+        pt_time = pt_timing / N
+
         print()
         print(f"Deserialization (Safe) took {safe_time}")
         print(f"Deserialization (PT) took {pt_time} (Safe is {pt_time/safe_time} faster)")
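
Note (not part of the applied diff): find_cudart() above locates torch's
bundled libcudart by popping torch's __file__ path and scanning torch/lib
for a file whose name starts with "libcudart". A small Python analogue,
handy for checking what would be picked up on a given machine, is sketched
below; it assumes a pip torch wheel built with CUDA (CPU-only wheels ship
no libcudart, in which case the Rust code errors with "Couldn't find cuda").

    import ctypes
    import os

    import torch

    # Mirror find_cudart(): start from torch's __file__, drop the file name,
    # descend into "lib", keep entries whose name starts with "libcudart".
    lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")
    candidates = [f for f in os.listdir(lib_dir) if f.startswith("libcudart")]
    print("libcudart candidates:", candidates)

    # ctypes.CDLL plays the role of libloading::Library::new, and attribute
    # lookup resolves the symbol the way cudart.get(b"cudaMemcpy") does.
    cudart = ctypes.CDLL(os.path.join(lib_dir, candidates[0]))
    print("resolved cudaMemcpy:", cudart.cudaMemcpy)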