Make GPU loading faster by removing all extra CPU copies. #33
Conversation
Thanks for working on this!
bindings/python/src/lib.rs
Outdated
```rust
let cuda_memcpy: libloading::Symbol<
    unsafe extern "C" fn(
        device_ptr: u64,
        src_ptr: *const std::ffi::c_void,
        src_len: usize,
    ) -> u32,
> = cudart.get(b"cudaMemcpy").map_err(|e| {
    exceptions::PyException::new_err(format!("Couldn't find cudaMemcpy {e:?}"))
})?;
```
I'm not sure of the cost of that, but shouldn't it be done only once? In `__enter__` maybe.
This is a reference, and references with a lifetime are disallowed in `pyclass`. There might be ways around that.
FYI I timed the actual libloading; it takes roughly 70µs. (That's actually fetching the entire library, not just this method.)
If you store the cudart library in a write-once static like `once_cell::Lazy` or a `GILOnceCell` (or just `Box::leak` it), then you could get a `&'static` reference to `Library`. You could even put `cuda_memcpy` as a `Symbol<'static, ...>` in a second static. That would also allow you to only make this `unsafe` call once per process.

(With `once_cell::Lazy` you should take care to release the GIL before interacting with it, as IIRC you can deadlock because of the two locking mechanisms.)
I tried to store `cuda_memcpy` directly within the `GILOnceCell`, but I wasn't able to. Well, I was able to compile it, but I triggered the panic when using `cuda_memcpy` afterwards, so something must have been very wrong. It was something like:
```rust
fn find_memcpy() -> Symbol<'static, MemcpyFn> {
    let symbol = Python::with_gil(|py| {
        let module = get_module(py, &TORCH_MODULE);
        let lib = find_cudart(module);
        let lib = Box::leak(Box::new(lib));
        let symbol = unsafe { lib.get(b"cudaMemcpy") }.unwrap();
        symbol
    });
    symbol
}

// ...

// At init
CUDA_MEMCPY.get_or_init(py, || find_memcpy());
```
However somehow it fails to properly call it.

`Box::leak` prevents the `Drop` from being run, right? So the lib should not get unloaded?

For instance:

> The implementation of thread-local variables is extremely platform specific and uses of such variables that work on e.g. Linux may have unintended behaviour on other targets.

So maybe we can't really leak. `Library` itself doesn't seem to be thread safe; maybe we should load on every access to be safe? (It's ok speed-wise.)
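On the `Drop` question: `Box::leak` does indeed prevent the destructor from ever running, so a leaked library handle is never unloaded. A torch-free sketch, with a hypothetical `FakeLibrary` standing in for `libloading::Library` (whose `Drop` is what calls `dlclose`):

```rust
use std::sync::atomic::{AtomicBool, Ordering};

static DROPPED: AtomicBool = AtomicBool::new(false);

// Stand-in for libloading::Library: unloading happens in Drop.
struct FakeLibrary;

impl Drop for FakeLibrary {
    fn drop(&mut self) {
        // With libloading, this is where dlclose would unload the lib.
        DROPPED.store(true, Ordering::SeqCst);
    }
}

fn main() {
    // Dropped normally: Drop runs and the "library" is unloaded.
    drop(FakeLibrary);
    assert!(DROPPED.swap(false, Ordering::SeqCst));

    // Leaked: Drop never runs, so the "library" stays loaded and the
    // returned &'static reference stays valid for the whole process.
    let leaked: &'static FakeLibrary = Box::leak(Box::new(FakeLibrary));
    let _ = leaked;
    assert!(!DROPPED.load(Ordering::SeqCst));
    println!("leaked library never dropped");
}
```

So if leaking "fails", the suspect is more likely what the leaked handle points at (or the symbol's signature) than the leak itself.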
Hmm, that's unfortunate. It's not clear to me what might be going wrong, that looks like it should be fine to me.
I'm not aware of any thread-local variable here?
If torch is dynamically loading the cuda runtime, I wondered if it's possible that initialising too early here might interact badly with torch? I'm just speculating wildly at this point.
If you're not worried about the loading call, you could always cache what you wanted from torch / filesystem (e.g. cuda location) and then the per-call cost is just load & lookup the symbol.
Yes, this is what I'll end up doing I think.
Also, in `torch==1.13` they split cuda into its own lib, so ....
Ok, figured it out: loading `torch._C` directly actually works. `libloading` uses `dlopen` under the hood and is able to resolve `cudaMemcpy` directly, so no need for ELF shenanigans, and it's safe since it's already loaded in the torch python namespace!
No problem, happy to offer ideas (although I can't always guarantee a speedy reply!).
FFI is both inherently unsafe and difficult to verify. For pure-Rust `unsafe` there are tools such as Miri to check for UB; with FFI I think the best you can do is run tools like Valgrind, but I'm unconvinced they'll give much value here.

One suggestion I do have: the way you search for the cuda runtime lib is a bit spooky, and I think it might not match the one that torch has already loaded. I had a quick look but didn't see a torch API to get the loaded cuda runtime name (although it might still exist). I wonder if it's possible to enumerate already-loaded libraries to find the cuda lib; you could then use `libloading` to get a new handle, and I think that'll avoid arbitrary code execution because the lib will already have been initialised.
bindings/python/src/lib.rs
Outdated
```rust
// but somehow the call failed. This is really worrying since Pytorch is
// responsible for allocating the memory.
if out != 0 {
    panic!(
```
Should this be `std::process::abort`? It is possible (although discouraged) to catch the `PanicException` translated by PyO3. We don't allow users to name it directly, but it derives from `BaseException`, so some catch-all `except` blocks can intercept it.
Thanks for the tip! `std::process::abort` it is, I think.
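The same distinction holds in pure Rust: a panic unwinds and can be intercepted, while `std::process::abort` terminates unconditionally. A small illustration of the catchable half (assuming the default `panic = "unwind"` profile; the `check` helper and its status code are made up for the example):

```rust
use std::panic;

// Hypothetical check that panics on a non-zero "cuda" status, standing
// in for the real error path (no cuda involved here).
fn check(status: u32) {
    if status != 0 {
        panic!("cudaMemcpy failed with status {status}");
    }
}

fn main() {
    // The panic can be intercepted, just like PyO3's PanicException can
    // be caught by a broad `except BaseException:` on the Python side.
    let result = panic::catch_unwind(|| check(1));
    assert!(result.is_err());

    // std::process::abort() would terminate right here with no chance
    // of being caught; it is not called so the example can finish.
    println!("panic was caught");
}
```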
bindings/python/src/lib.rs
Outdated
```rust
let out = cuda_memcpy(
    (data_ptr + offset) as u64,
    slice.as_ptr() as *const std::ffi::c_void,
    len,
);
```
Personally I like to reduce the scope of the `unsafe`, so I'd just put it around here and the `cudart.get` call.
Suggested change:

```diff
-let out = cuda_memcpy(
-    (data_ptr + offset) as u64,
-    slice.as_ptr() as *const std::ffi::c_void,
-    len,
-);
+// Safety: pointer and len are guaranteed valid as they originate from the same slice
+let out = unsafe {
+    cuda_memcpy(
+        (data_ptr + offset) as u64,
+        slice.as_ptr() as *const std::ffi::c_void,
+        len,
+    )
+};
```
I agree in general. Here I tried to encapsulate both fetching the symbol and the various calls, as basically they all influence each other.
100% agree. I tinkered with "cleaner approaches", but I found it quite hard to code. Afaik, the "link" is resolved at runtime in the actual process and is unknown to the program here, so there is still some resolution to be done manually, meaning reading the environment `LD_LIBRARY_PATH` and so on, which I don't feel very confident about. This current resolution is not great, because if someone manages to hijack this location and send a bogus library named ...
I mostly agree with this point; I guess what this might enable is something like a privilege escalation, if an attacker used user X to write the malicious lib and then your process with user Y loads it. Is it perhaps worth asking upstream in torch if they have, or are willing to expose, a way to get the currently loaded lib?
I'll try to do that. Old torch libs still won't work though.
I am finding super weird issues (and slightly worrying). This works:

```rust
let data_ptr_fn = tensor.getattr("data_ptr")?;
let data_ptr: usize = data_ptr_fn.call0()?.extract()?;
let out = unsafe {
    cuda_memcpy(
        data_ptr.try_into()?,
```

but this doesn't:

```rust
let data_ptr_fn = tensor.getattr("data_ptr")?;
let data_ptr: u64 = data_ptr_fn.call0()?.extract()?;
let out = unsafe {
    cuda_memcpy(
        data_ptr,
```

It fails with some weird cuda error: either invalid argument, or memcpy invalid direction (??). Is there anything you might be aware of? Also referencing the ...
Where did you get the definition of the ...?

Looking at the docs on Nvidia (which might be wrong, I'm not an expert here), it seems like the definition you want is:

```rust
#[repr(C)]
enum cudaMemcpyKind {
    cudaMemcpyHostToHost = 0,
    cudaMemcpyHostToDevice = 1,
    cudaMemcpyDeviceToHost = 2,
    cudaMemcpyDeviceToDevice = 3,
    cudaMemcpyDefault = 4,
}

type MemcpyFn = unsafe extern "C" fn(
    dest: *mut c_void,
    src_ptr: *const c_void,
    src_len: libc::size_t,
    kind: cudaMemcpyKind,
) -> u32;
```

In particular you should be using a pointer for the first argument, in case you are running on a 32-bit system, and additionally you are missing the .... Looks like with ...
You might want to consider creating a ...
Thanks! Super helpful. For some reason I was using https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g4d32266788c440b0220b1a9ba5795169 and using its signature instead of .... That does explain the corrupted stack! And the issues when I tried to leak the reference: the corruption was not because of the symbol but because the stack changed and the signature was wrong.
Commits:

- Use cuda-sys directly.
- For slice too.
- Fun unsafe.
- Reduce unsafe.
- Removing CPU unsafe.
- Using shared `cuda-sys` (temporary, we need to use torch's cuda version).
- Tmp rework.
- Cleaner device.
- Adding some statistics...
- Warmup steps.
- Removing unsafe GPU accesses.
- Removing dead code.
- Removing libloading.
- Revert "Removing unsafe GPU accesses." This reverts commit 5325ba2.
- Unsafe comments.
- Using GILOnceCell for module reuse.
- Finding the lib through the real python workload. Still requires transitive library parsing.
- Stable with global lib ref.
- Abort.
- Weird bug on torch 1.13. We **need** to get the data_ptr within the loop? Some very weird errors in calling the `cudaMemcpy` that fail depending on order. Even aliasing the `Symbol` seems to be unsafe.
Loading models on GPU roughly 2x as fast (depends on hardware).

It works by not allocating on CPU, and directly allocating + memcpying the memory on the GPU. Afaik, there's no way to do that in Pytorch: all the storage methods have an intermediary CPU alloc.

@davidhewitt Sorry to ping you out of the blue, but you have been a huge help on `tokenizers`. The present library aims to prevent arbitrary code execution when loading weights (in particular Pytorch): https://www.youtube.com/watch?v=2ethDz9KnLk

At HF, we haven't yet fully committed to actually switching format, but multiple other nice features could come out of it, namely this handy 2x speedup when loading tensors on GPU (because we can remove the CPU allocations entirely). However, the present solution uses a good dose of `unsafe`.

Currently there is no way (afaik) to access an equivalent of `cudaMemcpy` directly from torch. That indirection would help put the safety back in pytorch and not in this crate, but after a healthy dose of looking I couldn't find anything. Still, since pytorch is in Python world, I'm guessing it's always going to require passing a `PyBufferArray`, which we have to reallocate (because of the trailing `'\0'` at least).

The current PR does the following:

- Finds the `libcudart` being used by pytorch itself
- Creates the tensor with `torch.empty(shape, dtype, device)`
- Copies the data directly with `cudaMemcpyHostToDevice`