Revert "Removing unsafe GPU accesses."
This reverts commit 5325ba2.
Narsil committed Oct 26, 2022
1 parent b1664eb commit a2e19c5
Showing 3 changed files with 219 additions and 39 deletions.
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
@@ -12,6 +12,7 @@ crate-type = ["cdylib"]
 pyo3 = { version = "0.17.1", features = ["extension-module"] }
 memmap = "0.7"
 serde_json = "1.0"
+libloading = "0.7"

 [dependencies.safetensors]
 version = "*"
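
The only dependency change is libloading, which lets the compiled extension dlopen torch's bundled CUDA runtime at run time instead of linking against CUDA at build time. As a rough illustration of what the find_cudart helper in the lib.rs diff below scans for, here is a hypothetical Python equivalent (it assumes a pip-installed torch build that ships its own libcudart; the exact file name varies by platform and CUDA version):

import os
import torch

# Mirror find_cudart: resolve torch's __file__, drop the file name,
# and look in the package's "lib" directory for libcudart*.
lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")
candidates = [f for f in os.listdir(lib_dir) if f.startswith("libcudart")]
print(candidates)
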
233 changes: 206 additions & 27 deletions bindings/python/src/lib.rs
@@ -204,6 +204,133 @@ struct safe_open {
     framework: Framework,
     device: Device,
     mmap: Arc<Mmap>,
+    cudart: Option<libloading::Library>,
 }

+fn create_empty_tensor_pt<'a>(
+    module: &'a PyModule,
+    shape: &[usize],
+    dtype: Dtype,
+    device: &Device,
+) -> PyResult<&'a PyAny> {
+    let py = module.py();
+    let shape = shape.to_vec();
+    let empty = module.getattr("empty")?;
+    let dtype: PyObject = get_pydtype(module, dtype)?;
+    let shape: PyObject = shape.into_py(py);
+    let device: PyObject = device.clone().into_py(py);
+    let kwargs = [("dtype", dtype), ("device", device)].into_py_dict(py);
+    let tensor = empty.call((shape,), Some(kwargs))?;
+    Ok(tensor)
+}

+fn find_cudart(module: &PyModule) -> PyResult<libloading::Library> {
+    let mut path: std::path::PathBuf = module
+        .getattr(intern!(module.py(), "__file__"))?
+        .extract()?;
+    path.pop();
+    path.push("lib");
+    for file in path.read_dir()?.flatten() {
+        let path = file.path();

+        let filename = path
+            .file_name()
+            .ok_or_else(|| exceptions::PyException::new_err("Couldn't read filename"))?;
+        let filename = filename
+            .to_str()
+            .ok_or_else(|| exceptions::PyException::new_err("Couldn't read filename"))?;
+        if filename.starts_with("libcudart") {
+            unsafe {
+                let cudart = file.path();
+                let lib = libloading::Library::new(cudart).map_err(|e| {
+                    exceptions::PyException::new_err(format!("Couldn't load cuda {e:?}"))
+                })?;
+                return Ok(lib);
+            }
+        }
+    }
+    Err(exceptions::PyException::new_err("Couldn't find cuda"))
+}

+fn create_cuda_unsafe_tensor(
+    module: &PyModule,
+    cudart: &libloading::Library,
+    info: &TensorInfo,
+    device: &Device,
+    data: &[u8],
+) -> PyResult<PyObject> {
+    let tensor = create_empty_tensor_pt(module, &info.shape, info.dtype, device)?;

+    let data_ptr_fn = tensor.getattr("data_ptr")?;
+    let data_ptr: usize = data_ptr_fn.call0()?.extract()?;

+    let out = unsafe {
+        let cuda_memcpy: libloading::Symbol<
+            unsafe extern "C" fn(
+                device_ptr: u64,
+                src_ptr: *const std::ffi::c_void,
+                src_len: usize,
+            ) -> u32,
+        > = cudart.get(b"cudaMemcpy").map_err(|e| {
+            exceptions::PyException::new_err(format!("Couldn't find cudaMemcpy {e:?}"))
+        })?;
+        cuda_memcpy(
+            data_ptr as u64,
+            data.as_ptr() as *const std::ffi::c_void,
+            data.len(),
+        )
+    };
+    if out != 0 {
+        panic!(
+            "We tried to set your tensor fast, but there was a CUDA error. This could \
+             have corrupted your GPU RAM; aborting to prevent further errors."
+        )
+    }
+    let tensor: PyObject = tensor.into_py(module.py());
+    Ok(tensor)
+}

+fn create_cuda_unsafe_tensor_from_slice(
+    module: &PyModule,
+    cudart: &libloading::Library,
+    shape: &[usize],
+    dtype: Dtype,
+    device: &Device,
+    iterator: safetensors::slice::SliceIterator,
+) -> PyResult<PyObject> {
+    let tensor = create_empty_tensor_pt(module, shape, dtype, device)?;

+    let data_ptr_fn = tensor.getattr("data_ptr")?;
+    let data_ptr: usize = data_ptr_fn.call0()?.extract()?;
+    let mut offset = 0;
+    unsafe {
+        let cuda_memcpy: libloading::Symbol<
+            unsafe extern "C" fn(
+                device_ptr: u64,
+                src_ptr: *const std::ffi::c_void,
+                src_len: usize,
+            ) -> u32,
+        > = cudart.get(b"cudaMemcpy").map_err(|e| {
+            exceptions::PyException::new_err(format!("Couldn't find cudaMemcpy {e:?}"))
+        })?;
+        for slice in iterator {
+            let len = slice.len();
+            let out = cuda_memcpy(
+                (data_ptr + offset) as u64,
+                slice.as_ptr() as *const std::ffi::c_void,
+                len,
+            );
+            if out != 0 {
+                panic!(
+                    "We tried to set your tensor fast, but there was a CUDA error. This could \
+                     have corrupted your GPU RAM; aborting to prevent further errors."
+                )
+            }
+            offset += len;
+        }
+    }
+    let tensor: PyObject = tensor.into_py(module.py());
+    Ok(tensor)
+}

 #[pymethods]
@@ -229,12 +356,24 @@ impl safe_open {

         let offset = n + 8;

+        let cudart = match (&device, &framework) {
+            (Device::Cuda(_), Framework::Pytorch) => {
+                Python::with_gil(|py| -> PyResult<Option<libloading::Library>> {
+                    let module_name = intern!(py, "torch");
+                    let module = PyModule::import(py, module_name)?;
+                    Ok(Some(find_cudart(module)?))
+                })?
+            }
+            _ => None,
+        };

         Ok(Self {
             metadata,
             offset,
             framework,
             device,
             mmap: Arc::new(buffer),
+            cudart,
         })
     }

@@ -255,25 +394,47 @@ impl safe_open {

         let data = &self.mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset];

-        let array: PyObject = Python::with_gil(|py| PyByteArray::new(py, data).into_py(py));

-        create_tensor(
-            &self.framework,
-            info.dtype,
-            &info.shape,
-            array,
-            &self.device,
-        )
+        match (&self.device, &self.framework, &self.cudart) {
+            (Device::Cuda(_), Framework::Pytorch, Some(cudart)) => {
+                Python::with_gil(|py| -> PyResult<PyObject> {
+                    let module = PyModule::import(py, intern!(py, "torch"))?;
+                    create_cuda_unsafe_tensor(module, cudart, info, &self.device, data)
+                })
+            }
+            _ => {
+                let array: PyObject = Python::with_gil(|py| PyByteArray::new(py, data).into_py(py));

+                create_tensor(
+                    &self.framework,
+                    info.dtype,
+                    &info.shape,
+                    array,
+                    &self.device,
+                )
+            }
+        }
     }

     pub fn get_slice(&self, name: &str) -> PyResult<PySafeSlice> {
+        let cudart = match (&self.device, &self.framework) {
+            (Device::Cuda(_), Framework::Pytorch) => {
+                Python::with_gil(|py| -> PyResult<Option<libloading::Library>> {
+                    let module_name = intern!(py, "torch");
+                    let module = PyModule::import(py, module_name)?;
+                    Ok(Some(find_cudart(module)?))
+                })?
+            }
+            _ => None,
+        };

         if let Some(info) = self.metadata.tensors().get(name) {
             Ok(PySafeSlice {
                 info: info.clone(),
                 framework: self.framework.clone(),
                 offset: self.offset,
                 device: self.device.clone(),
                 mmap: self.mmap.clone(),
+                cudart,
             })
         } else {
             Err(exceptions::PyException::new_err(format!(
@@ -296,6 +457,7 @@ struct PySafeSlice {
     offset: usize,
     device: Device,
     mmap: Arc<Mmap>,
+    cudart: Option<libloading::Library>,
 }

 #[derive(FromPyObject)]
@@ -341,24 +503,41 @@ impl PySafeSlice {
         let mut offset = 0;
         let length = iterator.remaining_byte_len();

-        let array: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
-            Ok(PyByteArray::new_with(py, length, |bytes: &mut [u8]| {
-                for slice in iterator {
-                    let len = slice.len();
-                    bytes[offset..offset + slice.len()].copy_from_slice(slice);
-                    offset += len;
-                }
-                Ok(())
-            })?
-            .into_py(py))
-        })?;
-        create_tensor(
-            &self.framework,
-            self.info.dtype,
-            &newshape,
-            array,
-            &self.device,
-        )
+        match (&self.device, &self.framework, &self.cudart) {
+            (Device::Cuda(_), Framework::Pytorch, Some(cudart)) => {
+                Python::with_gil(|py| -> PyResult<PyObject> {
+                    let module = PyModule::import(py, intern!(py, "torch"))?;
+                    create_cuda_unsafe_tensor_from_slice(
+                        module,
+                        cudart,
+                        &newshape,
+                        self.info.dtype,
+                        &self.device,
+                        iterator,
+                    )
+                })
+            }
+            _ => {
+                let array: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
+                    Ok(PyByteArray::new_with(py, length, |bytes: &mut [u8]| {
+                        for slice in iterator {
+                            let len = slice.len();
+                            bytes[offset..offset + slice.len()].copy_from_slice(slice);
+                            offset += len;
+                        }
+                        Ok(())
+                    })?
+                    .into_py(py))
+                })?;
+                create_tensor(
+                    &self.framework,
+                    self.info.dtype,
+                    &newshape,
+                    array,
+                    &self.device,
+                )
+            }
+        }
     }
 }

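In practice, the restored fast path is exercised through the Python bindings whenever a file is opened with the PyTorch framework on a CUDA device. A hypothetical usage sketch (the file and tensor names are placeholders; safe_open, get_tensor, and get_slice are the bindings defined in the diff above):

from safetensors import safe_open

# Opening with framework="pt" and a CUDA device is what populates `cudart`
# in the safe_open struct above and enables the direct cudaMemcpy path.
with safe_open("model.safetensors", framework="pt", device="cuda:0") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)  # goes through create_cuda_unsafe_tensor

    # Slicing routes through PySafeSlice / create_cuda_unsafe_tensor_from_slice:
    # first_rows = f.get_slice(name)[0:2]
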
24 changes: 12 additions & 12 deletions bindings/python/tests/test_pt_comparison.py
@@ -53,18 +53,8 @@ def setUp(self):
         save_file(data.copy(), self.local)

     def test_deserialization_safe(self):
-        W = 1
-        N = 1

-        for i in range(W):
-            torch.load(self.filename)
-        pt_timing = datetime.timedelta(0)
-        for i in range(N):
-            start = datetime.datetime.now()
-            tweights = torch.load(self.filename)
-            pt_time = datetime.datetime.now() - start
-            pt_timing += pt_time
-        pt_time = pt_timing / N
+        W = 10
+        N = 50

         for i in range(W):
             load_file(self.local)
@@ -76,6 +66,16 @@ def test_deserialization_safe(self):
             safe_timing += safe_time
         safe_time = safe_timing / N

+        for i in range(W):
+            torch.load(self.filename)
+        pt_timing = datetime.timedelta(0)
+        for i in range(N):
+            start = datetime.datetime.now()
+            tweights = torch.load(self.filename)
+            pt_time = datetime.datetime.now() - start
+            pt_timing += pt_time
+        pt_time = pt_timing / N

         print()
         print(f"Deserialization (Safe) took {safe_time}")
         print(f"Deserialization (PT) took {pt_time} (Safe is {pt_time/safe_time} faster)")
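The reworked benchmark warms each loader up W times, then averages N timed runs, and it now times the safetensors loader before torch.load rather than after. Distilled, the pattern is the helper below (the function name average_time is an illustrative assumption, not part of the diff; the datetime-based timing matches the test's own approach):

import datetime

def average_time(fn, warmup=10, iters=50):
    # Warm up OS page cache / allocator so the measured runs are steady-state.
    for _ in range(warmup):
        fn()
    total = datetime.timedelta(0)
    for _ in range(iters):
        start = datetime.datetime.now()
        fn()
        total += datetime.datetime.now() - start
    return total / iters
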
