
Make GPU loading faster by removing all extra CPU copies. #33

Merged
Narsil merged 7 commits into main from faster_gpu_load on Nov 2, 2022

Conversation

Narsil
Collaborator

@Narsil Narsil commented Oct 24, 2022

Loads models onto GPU roughly 2x as fast (depends on hardware).

It works by not allocating on CPU, and instead allocating + filling the memory directly on the GPU.
Afaik, there's no way to do that in Pytorch: all the storage methods go through an intermediary CPU alloc.

@davidhewitt Sorry to ping you out of the blue, but you have been a huge help on tokenizers.

The present library aims to prevent arbitrary code execution when loading weights (in particular Pytorch ones). https://www.youtube.com/watch?v=2ethDz9KnLk

At HF, we haven't yet fully committed to actually switching formats, but multiple other nice features could come out of it, namely this handy 2x speedup when loading tensors on GPU (because we can remove the CPU allocations entirely).
However, the present solution uses a good dose of unsafe.

  • Do you have ideas on removing this unsafe?
  • Or on validating it?

Currently there is no way (afaik) to access an equivalent of cudaMemcpy directly from torch. That indirection would help put the safety back in Pytorch and not in this crate. However, after a healthy dose of looking, I couldn't find anything. Still, since Pytorch lives in Python world, I'm guessing it would always require passing a PyBufferArray, which we have to reallocate (because of the trailing '\0' at least).

The current PR does the following:

  • It figures out which libcudart Pytorch itself is using
  • Loads it
  • When creating tensors on GPU, it allocates an empty buffer (managed by Pytorch) through torch.empty(shape, dtype, device)
  • Then looks up the cudaMemcpy symbol
  • Calls it directly to fill the GPU RAM with the actual tensors from disk (a sketch follows below)
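Roughly, a minimal sketch of that flow (not the exact PR code: the cudart Library handle is assumed to be resolved elsewhere, dtype handling is omitted, and load_on_gpu is a made-up name):

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;

fn load_on_gpu(py: Python, cudart: &libloading::Library, data: &[u8], shape: Vec<usize>) -> PyResult<PyObject> {
    // 1. Let Pytorch own the allocation: torch.empty(shape, device="cuda").
    let torch = py.import("torch")?;
    let kwargs = [("device", "cuda")].into_py_dict(py);
    let tensor = torch.call_method("empty", (shape,), Some(kwargs))?;
    let data_ptr: usize = tensor.call_method0("data_ptr")?.extract()?;

    // 2. Resolve cudaMemcpy from the cudart Pytorch is using (kind 1 == cudaMemcpyHostToDevice).
    let cuda_memcpy: libloading::Symbol<
        unsafe extern "C" fn(*mut std::ffi::c_void, *const std::ffi::c_void, usize, u32) -> u32,
    > = unsafe { cudart.get(b"cudaMemcpy") }
        .map_err(|e| exceptions::PyException::new_err(format!("Couldn't find cudaMemcpy {e:?}")))?;

    // 3. Copy the bytes read from disk straight into the GPU buffer, no CPU tensor in between.
    let out = unsafe {
        cuda_memcpy(
            data_ptr as *mut std::ffi::c_void,
            data.as_ptr() as *const std::ffi::c_void,
            data.len(),
            1,
        )
    };
    if out != 0 {
        return Err(exceptions::PyException::new_err(format!("cudaMemcpy failed: {out}")));
    }
    Ok(tensor.to_object(py))
}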

Contributor

@sgugger sgugger left a comment

Thanks for working on this!

Comment on lines 322 to 363
let cuda_memcpy: libloading::Symbol<
    unsafe extern "C" fn(
        device_ptr: u64,
        src_ptr: *const std::ffi::c_void,
        src_len: usize,
    ) -> u32,
> = cudart.get(b"cudaMemcpy").map_err(|e| {
    exceptions::PyException::new_err(format!("Couldn't find cudaMemcpy {e:?}",))
})?;

I'm not sure of the cost of that, but shouldn't it be done only once? In __enter__ maybe.

Collaborator Author

This is a reference with a lifetime, and those are disallowed in a pyclass.
There might be ways around that though.

Collaborator Author

FYI I timed the actual libloading: it takes roughly 70µs. (That's actually fetching the entire library, not just this method.)


If you store the cudart library in a write-once static like once_cell::Lazy or a GILOnceCell (or just Box::leak it), then you could get a &'static reference to the Library. You could even put cuda_memcpy as a Symbol<'static, ...> in a second static. That would also allow you to only make this unsafe call once per process.

(With once_cell::Lazy you should take care to release the GIL before interacting with it, as IIRC you can deadlock because of the two locking mechanisms.)
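For example, a sketch of that shape, assuming a hypothetical cudart_path() helper that returns the path of the libcudart Pytorch loaded (Lazy is once_cell::sync::Lazy):

use libloading::{Library, Symbol};
use once_cell::sync::Lazy;

// Same signature as in the snippet above (the exact cudaMemcpy signature is refined later in the thread).
type MemcpyFn = unsafe extern "C" fn(
    device_ptr: u64,
    src_ptr: *const std::ffi::c_void,
    src_len: usize,
) -> u32;

// Statics live for the whole process, so borrowing CUDART here yields 'static references,
// and the unsafe load/lookup only happens once per process.
static CUDART: Lazy<Library> =
    Lazy::new(|| unsafe { Library::new(cudart_path()) }.expect("Couldn't load libcudart"));

static CUDA_MEMCPY: Lazy<Symbol<'static, MemcpyFn>> =
    Lazy::new(|| unsafe { CUDART.get(b"cudaMemcpy") }.expect("Couldn't find cudaMemcpy"));

// Per the note above: release the GIL before first touching these Lazy statics to avoid
// deadlocking between the GIL and the Lazy initialization lock.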

Collaborator Author

I tried to store cuda_memcpy directly within the GILOnceCell, but I wasn't able to make it work. Well, I was able to compile it, but I triggered a panic when using cuda_memcpy afterwards. So something must have been very wrong.

It was something like:

fn find_memcpy() -> Symbol<'static, MemcpyFn> {
    Python::with_gil(|py| {
        let module = get_module(py, &TORCH_MODULE);
        let lib = find_cudart(module);
        // Leak the Library so the Symbol can borrow it for 'static.
        let lib: &'static Library = Box::leak(Box::new(lib));
        let symbol = unsafe { lib.get(b"cudaMemcpy") }.unwrap();
        symbol
    })
}

// ...
// At init
CUDA_MEMCPY.get_or_init(py, || find_memcpy());

However somehow it fails to properly call it afterwards.
Box::leak prevents the Drop from being run, right? So the lib should not get unloaded?

For instance:

The implementation of thread-local variables is extremely platform specific and uses of such variables that work on e.g. Linux may have unintended behaviour on other targets.

So maybe we can't really leak.

Library itself doesn't seem to be thread safe; maybe we should load on every access to be safe? (It's OK speed-wise.)


Hmm, that's unfortunate. It's not clear to me what might be going wrong, that looks like it should be fine to me.

I'm not aware of any thread-local variable here?

If torch is dynamically loading the cuda runtime, I wondered if it's possible that initialising too early here might interact badly with torch? I'm just speculating wildly at this point.


If you're not worried about the loading call, you could always cache what you wanted from torch / filesystem (e.g. cuda location) and then the per-call cost is just load & lookup the symbol.
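i.e. something like the following sketch, where find_cudart_path(py) is a hypothetical helper that asks torch / the filesystem for the cudart location once, and the per-call work is just the (cheap) load + symbol lookup:

use libloading::Library;
use pyo3::exceptions;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use std::path::PathBuf;

static CUDART_PATH: GILOnceCell<PathBuf> = GILOnceCell::new();

fn open_cudart(py: Python) -> PyResult<Library> {
    // Cache only the resolved path; re-loading the lib per call was measured at ~70µs above.
    let path = CUDART_PATH.get_or_init(py, || find_cudart_path(py));
    unsafe { Library::new(path) }
        .map_err(|e| exceptions::PyException::new_err(format!("Couldn't load cudart {e:?}")))
}
// Each call then does `cudart.get(b"cudaMemcpy")` exactly as in the snippet above.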

Collaborator Author

Yes, this is what I'll end up doing I think.

Collaborator Author

Also, in torch==1.13 they split cuda into its own lib, so ...

Collaborator Author

Ok, figured it out: loading torch._C directly actually works. libloading uses dlopen under the hood and is able to resolve cudaMemcpy directly, so no need for ELF shenanigans, and it's safe since it's already loaded in the torch python namespace!
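A sketch of that idea, assuming we take torch._C's __file__ from Python (the PR may resolve it slightly differently): dlopen-ing a library the process has already loaded just hands back the existing handle, and dlsym can then resolve cudaMemcpy through torch's own dependencies.

use libloading::Library;
use pyo3::exceptions;
use pyo3::prelude::*;

fn open_torch_c(py: Python) -> PyResult<Library> {
    // torch._C is the extension module Pytorch has already loaded into this process.
    let path: String = py
        .import("torch")?
        .getattr("_C")?
        .getattr("__file__")?
        .extract()?;
    // dlopen on an already-loaded shared object doesn't run new initialization code;
    // it returns a handle whose symbol lookup sees libcudart transitively.
    unsafe { Library::new(path) }
        .map_err(|e| exceptions::PyException::new_err(format!("Couldn't open torch._C {e:?}")))
}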

@davidhewitt davidhewitt left a comment

No problem, happy to offer ideas (although I can't always guarantee a speedy reply!).

FFI is both inherently unsafe and also difficult to verify. For pure-Rust unsafe there are tools such as miri to check for UB; with FFI I think the best you can do is run tools like Valgrind, but I'm unconvinced they'll give much value here.

One suggestion I do have: the way you search for the cuda runtime lib is a bit spooky, and I think it might not match the one that torch has already loaded. I had a quick look but didn't see a torch API to get the loaded cuda runtime name (although it might still exist). I wonder if it's possible to enumerate already-loaded libraries to find the cuda lib - you could then use libloading to get a new handle, and I think that would avoid arbitrary code execution because the lib will already have been initialised?


// but somehow the call failed. This is really worrying since Pytorch is
// responsible for allocating the memory.
if out != 0 {
panic!(


Should this be std::process::abort? It is possible (although discouraged) to catch the PanicException translated by PyO3. We don't allow users to name it directly, but it derives from BaseException, so some catch blocks can still catch it.

Collaborator Author

Thanks for the tip!
std::process::abort it is, I think.
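i.e. roughly the following sketch at the call site:

// The GPU buffer is owned by Pytorch; if cudaMemcpy itself reports an error here,
// the process state is suspect, so abort() rather than panic!() so that no Python
// `except BaseException` block can swallow it and keep running.
if out != 0 {
    eprintln!("cudaMemcpy failed with error code {out}, aborting");
    std::process::abort();
}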

Comment on lines 333 to 384
let out = cuda_memcpy(
    (data_ptr + offset) as u64,
    slice.as_ptr() as *const std::ffi::c_void,
    len,
);


Personally I like to reduce the scope of the unsafe, so I'd just put it around here and the cudart.get call.

Suggested change
-let out = cuda_memcpy(
-    (data_ptr + offset) as u64,
-    slice.as_ptr() as *const std::ffi::c_void,
-    len,
-);
+// Safety: pointer and len are guaranteed valid as they originate from the same slice
+let out = unsafe {
+    cuda_memcpy(
+        (data_ptr + offset) as u64,
+        slice.as_ptr() as *const std::ffi::c_void,
+        len,
+    )
+};

Collaborator Author

I agree in general. Here I tried to encapsulate both fetching the symbol and the various calls, as they basically all influence each other.

@Narsil Narsil changed the title from "[WIP] Make GPU loading faster by removing all extra CPU copies." to "Make GPU loading faster by removing all extra CPU copies." on Oct 27, 2022
@Narsil
Collaborator Author

Narsil commented Oct 27, 2022

the way you search for the cuda runtime lib is a bit spooky

100% agree, I tinkered with "cleaner approaches", but I found it quite hard to code.
torch._C is their c library.
It is linked to libtorch_python.so, which is linked to libtorch_cuda_cpp.so which is linked to libcudart-xxxx.so.xxx.

Afaik, the "link" is resolved at runtime in the actual process and unknown to the program here. So there is still some resolution to be done manually here, meaning reading environment variables like LD_LIBRARY_PATH and so on, which I don't feel very confident about.

The current resolution is not great, because if someone manages to hijack this location and drop in a bogus library named libcudart, then it will indeed enable an escape.
But I feel like if an attacker managed to do that, security is gone already.

@davidhewitt

But I feel like if an attacker managed to do this, well security is gone already.

I mostly agree with this point; I guess what this might enable is something like privilege escalation, if an attacker used user X to write the malicious lib and then your process running as user Y loads it.

Is it perhaps worth asking upstream in torch if they have, or are willing to expose, a way to get the currently loaded lib?

@Narsil
Collaborator Author

Narsil commented Oct 31, 2022

Is it perhaps asking upstream in torch if they have or are willing to expose a way to get the currently loaded lib?

I'll try to do that. Old torch libs still won't work though.
Good point about privilege escalation, didn't think about that.

@Narsil Narsil force-pushed the faster_gpu_load branch 3 times, most recently from 55f805f to 1f8d58f, on November 1, 2022 08:21
@Narsil
Collaborator Author

Narsil commented Nov 1, 2022

@davidhewitt

I am finding super weird issues (and slightly worrying).

let data_ptr_fn = tensor.getattr("data_ptr")?;
let data_ptr: usize = data_ptr_fn.call0()?.extract()?;

let out = unsafe {
    cuda_memcpy(
        data_ptr.try_into()?,

This works, but this doesn't:

let data_ptr_fn = tensor.getattr("data_ptr")?;
let data_ptr: u64 = data_ptr_fn.call0()?.extract()?;

let out = unsafe {
    cuda_memcpy(
        data_ptr,

It fails with some weird cuda error. Either invalid argument, or memcpy invalid direction (??).
This doesn't make a lot of sense to me.
The values look exactly the same when I compare them. Everything really seems to be equivalent.

Is there anything you might be aware of?

Also, referencing the Symbol<_> seems to be unsafe (just changing the calls from cuda_memcpy to &cuda_memcpy in the loop breaks things). I'm guessing that might be linked to how the compiler can optimize references, with libloading not upholding some invariant.

@davidhewitt

Where did you get the definition of the cudaMemcpy symbol?

Looking at the docs on Nvidia (which might be wrong, I'm not an expert here), it seems like the definition you want is:

#[repr(C)]
enum cudaMemcpyKind {
    cudaMemcpyHostToHost = 0,
    cudaMemcpyHostToDevice = 1,
    cudaMemcpyDeviceToHost = 2,
    cudaMemcpyDeviceToDevice = 3,
    cudaMemcpyDefault = 4,
}

type MemcpyFn =
    unsafe extern "C" fn(dest: *mut c_void, src_ptr: *const c_void, src_len: libc::size_t, kind: cudaMemcpyKind) -> u32;

In particular you should be using a pointer for the first argument (in case you are running on a 32-bit system), and additionally you are missing the kind argument. Without it you've got UB, as the FFI call will be setting up an incorrect stack frame for the C function arguments. The kind parameter is left unset and will likely generate exactly the kind of errors you are reporting.

Looks like with the usize input you're just getting lucky that the stack is set up correctly; with u64, not so lucky.
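With that definition, the call site (a sketch, reusing the names from the snippets above) would look like:

// Safety: dst points into the GPU buffer Pytorch allocated via torch.empty;
// src/len come from the same slice read off disk.
let out = unsafe {
    cuda_memcpy(
        (data_ptr + offset) as *mut c_void,
        slice.as_ptr() as *const c_void,
        len,
        cudaMemcpyKind::cudaMemcpyHostToDevice,
    )
};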

@davidhewitt

You might want to consider creating a #[repr(C)] enum for the return value too, as I'm unsure if the size is guaranteed to be u32 on all platforms. https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3f51e3575c2178246db0a94a430e0038
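e.g. a sketch with just the first few variants (values per the CUDA 10.1+ numbering; the real cudaError_t has many more, see the link):

#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum cudaError {
    cudaSuccess = 0,
    cudaErrorInvalidValue = 1,
    cudaErrorMemoryAllocation = 2,
    // ... truncated: the full list is in the CUDART_TYPES docs linked above
}

MemcpyFn would then return cudaError instead of u32.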

@Narsil
Collaborator Author

Narsil commented Nov 1, 2022

Thanks! Super helpful.

For some reason I was looking at https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g4d32266788c440b0220b1a9ba5795169 and using its signature instead of cudaMemcpy's. (My initial testing setup was using cuda.h + bindgen, which generated cuMemcpy and not cudaMemcpy.)

That does explain the corrupted stack! And the issues when I tried to leak the reference. The corruption was not because of the symbol but because the stack changed and the signature was wrong.

Use cuda-sys directly.

For slice too.

Fun unsafe.

Reduce unsafe.

Removing CPU unsafe.

Using shared `cuda-sys` (temporary, we need to use torch's cuda version).

Tmp rework

Cleaner device.

Adding some statistics...

Warmup steps.

Removing unsafe GPU accesses.

Removing dead code.

Removing libloading.

Revert "Removing unsafe GPU accesses."

This reverts commit 5325ba2.

Unsafe comments.

Using GILOnceCell for module reuse.

Finding the lib through the real python workload.
Still requires transitive library parsing.

Stable with global lib ref.

Abort.

Weird bug on torch 1.13.

We **need** to get the data_ptr within the loop?
Some very weird errors in calling `cudaMemcpy` that fail depending
on order. Even aliasing the `Symbol` seems to be unsafe.
@Narsil Narsil merged commit be0683f into main Nov 2, 2022
@Narsil Narsil deleted the faster_gpu_load branch November 2, 2022 15:53