
Numba bridge #1870

Open
shoyer opened this issue Dec 16, 2019 · 18 comments
Labels
contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. NVIDIA GPU Issues specific to NVIDIA GPUs P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional)

Comments

@shoyer
Collaborator

shoyer commented Dec 16, 2019

It would be great to support calls into Numba via an XLA CustomCall (which works inside jax.jit). This would let you use Numba as an alternative for writing low-level kernels in JAX.

xref #81, #1100

@seibert

seibert commented Dec 16, 2019

We are also interested in how this would work. (cc: @sklam, @esc, @stuartarchibald)

@seibert

seibert commented Dec 16, 2019

Looks like this could go via Numba's @cfunc machinery.
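For readers unfamiliar with this route: a `@cfunc` compiles to a plain C function pointer, and a raw pointer is what XLA's CPU CustomCall registry accepts once it is wrapped in a PyCapsule. The stdlib-only sketch below shows those two steps, with `ctypes.CFUNCTYPE` standing in for `numba.cfunc` (assumption: both produce an ordinary C-ABI function pointer). The capsule name `xla._CUSTOM_CALL_TARGET` matches the one used in the CPU prototype posted later in this thread; `function_address` and `noop_target` are hypothetical helpers.

```python
import ctypes

# XLA CPU custom-call signature: void(void* out, void** ins).
# ctypes.CFUNCTYPE stands in for numba.cfunc here (assumption: same C ABI).
XlaCpuTarget = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p))

@XlaCpuTarget
def noop_target(out, ins):
    pass  # a real kernel would read from ins[i] and write to out

def function_address(cfunc_obj):
    """Raw address of a ctypes callback (a numba cfunc exposes `.address`)."""
    return ctypes.cast(cfunc_obj, ctypes.c_void_p).value

def encapsulate(address, name=b"xla._CUSTOM_CALL_TARGET"):
    """Wrap a raw function pointer in a named PyCapsule.

    PyCapsule keeps a pointer to `name` rather than a copy, so the bytes
    object must outlive the capsule; the default argument guarantees that.
    """
    destructor_t = ctypes.CFUNCTYPE(None, ctypes.py_object)
    new = ctypes.pythonapi.PyCapsule_New
    new.restype = ctypes.py_object
    new.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor_t)
    return new(address, name, destructor_t(0))  # NULL destructor

capsule = encapsulate(function_address(noop_target))
```

Keeping the callback object (`noop_target`) alive at module level matters: if it were garbage-collected, the address inside the capsule would dangle.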

@shoyer
Collaborator Author

shoyer commented Dec 16, 2019

For examples of how JAX wraps XLA's CustomCall interface, take a look at lapack.pyx and cusolver.cc/cusolver.py in jaxlib:
https://github.com/google/jax/tree/master/jaxlib

@saulshanabrook

Ah, very cool! Does this work on TPUs as well?

@seibert

seibert commented Dec 16, 2019

We've been a little curious how one generates custom functions for the TPU as well. We haven't seen an LLVM backend for the TPU anywhere, so I assume the toolchain is not public?

@SamPruden

I've been having a little play to see how doable this is. I've got the basics of Numba for CPU working.

import jax
import jax.numpy as jnp
from jax import lax
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
import functools
import numba

prim = jax.core.Primitive("Foo")
prim.def_impl(functools.partial(jax.interpreters.xla.apply_primitive, prim))

def shape_rule(aval): return aval.shape
def dtype_rule(aval): return aval.dtype

prim.def_abstract_eval(functools.partial(jax.lax.standard_abstract_eval, prim, shape_rule, dtype_rule))

def encapsulate(address):
  import ctypes
  PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
  PyCapsule_New = ctypes.pythonapi.PyCapsule_New
  PyCapsule_New.restype = ctypes.py_object
  PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor)
  capsule = PyCapsule_New(address, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0))
  return capsule

# Is it called a kernel on CPU? 
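# XLA calls CPU custom-call targets as void(void* output, void** inputs).
@numba.cfunc(numba.types.void(numba.types.voidptr, numba.types.CPointer(numba.types.voidptr)))
def kernel_cpu(output, inbuffers):
  # inbuffers[1] is the second operand: the int32 length added in translation_cpu
  size = numba.carray(inbuffers[1], 1, numba.types.int32)[0] # TODO: Find the proper way to do this
  # inbuffers[0] is the first operand: the float32 input vector
  input = numba.carray(inbuffers[0], size, numba.types.float32)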
  output = numba.carray(output, size, dtype = numba.types.float32)
  for i in range(size): output[i] = input[i] + 1

xc.register_custom_call_target("test", encapsulate(kernel_cpu.address), platform = 'cpu')

def translation_cpu(builder: jax.lib.xla_client.XlaBuilder, op):
  shape = builder.get_shape(op)
  assert shape.rank() == 1
  # Pass the vector length as an extra operand so the kernel knows how many elements to process
  size_const = xb.constant(builder, shape.dimensions()[0])
  return xc.ops.CustomCall(builder, b'test', operands = (op, size_const), shape = shape)

jax.interpreters.xla.backend_specific_translations['cpu'][prim] = translation_cpu

# TEST
prim.bind(7 * jnp.ones(20, dtype=jnp.float32))

It took a bit of fiddling about, but it's pretty easy in the end!
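To make the calling convention behind `kernel_cpu` concrete, it can be emulated without JAX or Numba: XLA passes the output buffer first and then an array of input-buffer pointers in operand order. With the operands `(op, size_const)` from `translation_cpu`, `inbuffers[1]` therefore points at the length, which the kernel reads as int32. In the sketch below a `ctypes` callback replaces the Numba `cfunc`, and `invoke_like_xla` is a hypothetical stand-in for what XLA does when it runs the CustomCall.

```python
import ctypes

f32p = ctypes.POINTER(ctypes.c_float)
i32p = ctypes.POINTER(ctypes.c_int32)

# void(void* out, void** ins): the CPU custom-call signature used by kernel_cpu.
Target = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p))

@Target
def add_one(out, ins):
    # ins[0] -> float32 data, ins[1] -> int32 length (mirrors kernel_cpu)
    size = ctypes.cast(ins[1], i32p)[0]
    src = ctypes.cast(ins[0], f32p)
    dst = ctypes.cast(out, f32p)
    for i in range(size):
        dst[i] = src[i] + 1.0

def invoke_like_xla(target, values):
    """Hypothetical stand-in for XLA: lay out the buffers, then call the target."""
    n = len(values)
    x = (ctypes.c_float * n)(*values)   # operand 0: the input vector
    size = ctypes.c_int32(n)            # operand 1: its length
    out = (ctypes.c_float * n)()        # result buffer
    ins = (ctypes.c_void_p * 2)(ctypes.addressof(x), ctypes.addressof(size))
    target(ctypes.addressof(out), ins)
    return list(out)

print(invoke_like_xla(add_one, [7.0] * 3))
```

The same driver works for any target with this signature, which makes it a cheap way to debug a kernel's pointer handling before registering it with XLA.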


GPU, however, is proving a little more tricky.

Using Numba, I can make a kernel and get a handle for it.

@numba.cuda.jit("void(float32[:], float32[:])")
def kernel_gpu(input, output):
  i = numba.cuda.grid(1)
  output[i] = input[i] + 1

handle = kernel_gpu[1,10]._func.get().handle

I can also apply this kernel directly to JAX arrays through the standard Numba API. However, because JAX arrays are immutable and kernels can't return values, I can't get any output this way. Using a NumPy array for the output and a JAX array for the input does seem to work fine.

The step required to make this all work together is invoking the GPU kernel from CustomCall. This XLA page gives a good overview of the basic structure, and this cusolver source file in the JAX source is an example implementation.

I'm not sure what the best next step is here. I presume that the kernel invocation is going to need to be implemented in C++ proper rather than Numba, although I'm curious as to whether I can make Cython work. The only reason for my reluctance to do this properly is that I've done this all in Colab so far, so getting multiple files and compilers involved is a (probably necessary) step up in project complexity.

The cusolver.cc reference makes it all look quite complicated. There's manual device memory copying and stuff going on and I was hoping that things wouldn't need to get that complex, although maybe they do. I haven't grokked that file in detail yet. I was hoping that there may be a way to hijack the fact that Numba already does most of what we want under the hood - it already knows how to call the kernel with JAX inputs.

In theory all we need to implement is the equivalent of this sample from the XLA page that I linked:

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64 block_dim = 64;
  const int64 grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}

This makes it seem like it may actually be quite easy! However, that sample is running through the NVCC compiler for the invocation and it's all getting a bit too much for my current approach where I'm working out of a Jupyter notebook! I think we're probably at the point where this needs to be a JAX source modification to be practical, although I'm not sure.
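The GPU convention in that sample differs from the CPU one: inputs and outputs share a single `buffers` array (inputs first, then outputs), and the target also receives the stream and an opaque byte string. As a host-only sketch with no CUDA involved, the code below emulates the sample's buffer layout and its `grid_dim = 2048 / block_dim` launch by looping over every (block, thread) pair. The elementwise-add kernel body is an assumption made for illustration, since the sample's kernel itself is not shown here.

```python
import ctypes

f32p = ctypes.POINTER(ctypes.c_float)

def custom_call_kernel(block_idx, block_dim, thread_idx, in0, in1, out):
    # Hypothetical kernel body: one element per thread.
    idx = block_idx * block_dim + thread_idx
    out[idx] = in0[idx] + in1[idx]

def do_custom_call(stream, buffers, opaque, opaque_len, n=2048, block_dim=64):
    """Host emulation of the XLA GPU sample; stream and opaque are unused."""
    in0 = ctypes.cast(buffers[0], f32p)
    in1 = ctypes.cast(buffers[1], f32p)
    out = ctypes.cast(buffers[2], f32p)  # outputs follow the inputs
    grid_dim = n // block_dim  # 2048 // 64 == 32 blocks; assumes n % block_dim == 0
    for b in range(grid_dim):  # a real launch runs blocks and threads in parallel
        for t in range(block_dim):
            custom_call_kernel(b, block_dim, t, in0, in1, out)

def run(n=2048):
    a_buf = (ctypes.c_float * n)(*range(n))
    b_buf = (ctypes.c_float * n)(*([1.0] * n))
    out_buf = (ctypes.c_float * n)()
    buffers = (ctypes.c_void_p * 3)(*(ctypes.addressof(x) for x in (a_buf, b_buf, out_buf)))
    do_custom_call(None, buffers, b"", 0, n=n)
    return list(out_buf)
```

In a real target, the only job of `do_custom_call` is to enqueue the launch on `stream`; it must not wait for the kernel to finish.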

I'd appreciate feedback from knowledgeable people about what the best direction to take this in next is! Is it feasible to write a small C++ utility to handle the kernel invocation? Does this machinery already exist in JAX somewhere? Do I need to get NVCC involved? Is this something that it makes sense to attempt externally as I'm doing, or would it be much smoother to patch this into the JAX source?

This may be as far as I can justify taking this distraction for now, so I thought I would write up what I've done to at least inspire somebody else to take it further. I might do more if it feels approachable, although I'm quite far off track from the model I'm supposed to be training!

@josipd

josipd commented Nov 30, 2020

Hello all,

I have been coincidentally also working on this, just pushed the WIP to my fork

https://github.com/josipd/jax/blob/master/jax/experimental/jambax.py

Would it make sense for us to collaborate to avoid duplicate work? There are quite a few additional things that can be done (or done better), e.g. better automatic batching, working out whether the interface can be improved, and, quite importantly, CUDA support, which I have not touched at all but where you have made some progress.

@SamPruden

That's interesting @josipd! It looks like we've done roughly the same thing, although you've done a bit more towards an API - mine was only a technical exploration.

I have some vague ideas for what a nice, slightly higher-level API might look like, but I've not put them together or tested them yet. I have an idea in the back of my mind that for certain types of kernels, particularly pointwise and pointwise-like ones, it may be possible to share an implementation between CPU and GPU. I'd like to play with that, but I haven't got that far in my explorations yet.
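One possible reading of the shared CPU/GPU idea for pointwise kernels: write the scalar operation once and wrap it in two drivers, a serial loop for CPU and a one-thread-per-element launch for GPU. The sketch below emulates the GPU-style driver on the host; `scalar_op`, `cpu_driver` and `gpu_style_driver` are hypothetical names, and how each driver would actually be compiled by Numba (as a `cfunc` and as a CUDA kernel) is left as an assumption.

```python
def scalar_op(x):
    # The pointwise operation, written once.
    return x + 1.0

def cpu_driver(op, xs):
    # CPU: a plain serial loop over all elements.
    out = [0.0] * len(xs)
    for i in range(len(xs)):
        out[i] = op(xs[i])
    return out

def gpu_style_driver(op, xs, block_dim=4):
    # GPU-style: each (block, thread) pair handles one index.
    n = len(xs)
    out = [0.0] * n
    grid_dim = (n + block_dim - 1) // block_dim  # round up to cover every element
    for block in range(grid_dim):
        for thread in range(block_dim):
            i = block * block_dim + thread
            if i < n:  # bounds check for the last, partially filled block
                out[i] = op(xs[i])
    return out
```

Only the drivers know about the execution model, which is what would make the scalar body reusable across backends.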

For me this is only actually useful if it runs on GPU. My CPU implementation was only an exploratory stepping stone to that. My CUDA progress so far is minimal - I've basically just read some documentation.

I'm not familiar with how NVCC works and that whole side of setting things up, or indeed if that's actually necessary. That became a significant hurdle for me building an independent implementation, but if we went down the route of forking it may turn out to be very simple to just add extra CUDA code into the existing project infrastructure.

I'm on a Windows machine and that's not supported for building, which was another reason I stopped where I did.

I'm interested in collaborating if we're both working on this, but I haven't yet decided to commit more time to it.

@mblondel

mblondel commented Dec 1, 2020

On our side, we are quite interested in CPU support. One of our motivations for a Numba/JAX bridge was that Numba can generate very efficient code for nested loops (typically, code like this or this). Also, Numba supports native Python for loops, while in JAX one would need to use jax.lax or jax.experimental.loops. Is your main use case to write low-level CUDA kernels in Numba?

@SamPruden

SamPruden commented Dec 1, 2020

That's interesting @mblondel. I was almost wondering whether CPU support is even worth doing, but clearly from your use case it is. It sounds like it's definitely worth tidying my and @josipd's experiments up into a nice API.

For me, my main interest in custom kernels is performance. There are situations where JAX's JIT is significantly slower than what I can hand code and I'd like to be able to experiment with throwing a custom solution together. These would usually be quite small kernels that represent custom layers. I'd be defining them inline in a notebook as part of a model definition so brevity and ease are important. That's what makes Numba ideal.

I think that I may have actually made some progress on GPU support! I've got some separate parts that I think should work in theory, but I haven't put them all together and tested it yet, so we'll see. It's probably too much of a hack to be used as an official solution; I'm patching into Numba's internal kernel invocation machinery and I don't think it's considered a public or stable API. It may be worth publishing as a plugin until somebody can come along and put a more official solution together in JAX itself. I'll update this thread with details when I've taken it a little further.

@SamPruden

I'm currently blocked on this Numba issue and a few other fiddly bits trying to wrangle Numba into doing what I want. This hack may prove more trouble than it's worth. I think it's clear that the proper way to do this is to integrate the kernel launching code into JAX properly but I don't think I'm the right person for that task. It's probably quite easy, so if somebody from the team with the right experience to pull it off would like to add that to jaxlib that would be great! Although...

Even the process of obtaining the kernel handle from Numba is proving problematic. The handle = kernel_gpu[1,10]._func.get().handle method I posted above breaks when upgrading Numba from 0.48.0 to 0.52.0. Clearly that's not a stable API. I think we would need Numba to expose something for us.

If we're relying on getting changes made to Numba anyway, there may be a smoother way forward. About a year ago, @seibert mentioned here that they would like to enable Numba cfuncs to launch Numba CUDA kernels. I believe that would cover everything that we need to do this. The CustomCall could go into a standard Numba cfunc in the same way as I've done for the CPU, and that cfunc would call the Numba CUDA kernel. It's neat and allows this to be implemented in JAX with only Python changes.

@seibert, if you'd still like to get that put in then I think this becomes very easy. I don't think there's a feature request open for that on the Numba github yet so I'll probably add one soon and link back to here.

@SamPruden

If that Numba change happened, I would probably be able to take on putting in a JAX PR for a nice wrapper API if we want that, as well as documentation. That appears to be what @josipd is also working on, so maybe we could collaborate on that.

@josipd

josipd commented Jan 12, 2021

Hey all,

Should we agree on the API and have at least the CPU version checked in soon? Looking at the XLA GPU convention, I think what I suggested could extend there as well once we have the necessary visibility changes in numba.

@shoyer
Collaborator Author

shoyer commented Jan 31, 2021

Another way to do this (other than CustomCall) would be to call back from JAX into Python using jax.experimental.host_callback.call: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html (cc @gnecula)

It would be nice if host_callback.call() supported passing arrays on GPUs, but right now everything seems to be passed to the CPU as NumPy arrays.

@SamPruden

Would the host_callback approach be competitive for speed when going through Python? Speed is one of the main reasons to want this functionality, so I think implementing it in the fastest possible way is important.

The GPU thing is similar - CPU is a good start for hacking something together, but GPU is probably required to make this actually useful for most people.

I've never used TPUs, but I suppose they would ideally be supported too... Numba's roadmap suggests that TPU support may happen at some point, but that's listed as "2020 and beyond" and hasn't been updated since 2018, so I'm not sure if that's actually happening. I can't find any relevant GitHub issues over there. Is it even possible for Numba to support TPUs given Google's secrecy about them?

@shoyer
Collaborator Author

shoyer commented Feb 1, 2021

> Would the host_callback approach be competitive for speed when going through Python? Speed is one of the main reasons to want this functionality, so I think implementing it in the fastest possible way is important.

It looks like the current version of host_callback adds about 1 ms of overhead. This isn't something you would want to put in your inner loop, but could be fine in some contexts.
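Whether ~1 ms per call matters depends on how often the callback runs: at that rate, 1,000 calls inside a loop would add roughly one second. A rough way to check any given mechanism is to time many calls and divide. `mean_call_overhead` below is a hypothetical, generic stdlib harness; plugging in a host_callback-based function (and blocking on its result so asynchronous dispatch is included) is left as an assumption.

```python
import time

def mean_call_overhead(fn, *args, repeats=1000):
    """Average wall-clock seconds per call of fn(*args)."""
    fn(*args)  # warm-up call (e.g. to trigger compilation) is excluded from timing
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats

def baseline(x):
    # A trivial function: its cost approximates pure Python call overhead.
    return x

per_call = mean_call_overhead(baseline, 1)
```

Comparing a callback-based function against `baseline` gives the extra cost per call attributable to the callback machinery itself.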

@gnecula
Collaborator

gnecula commented Feb 1, 2021

I am working on an alternative implementation of host_callback for CPU/GPU using CustomCall.

@PhilipVinc
Contributor

For what it's worth, I've recently been experimenting with @josipd's jambax and found that it has extremely low overhead on the CPU.

However, I have some complicated code that must run on the CPU inside a larger jitted GPU block. I tried using host_callback, but the overhead is very large.
Since GPU custom calls are actually CPU (host) functions whose job is to enqueue kernels onto the stream, my idea was to memcpy the data from GPU to CPU, call the Numba-jitted function, and then memcpy the results back to the GPU.

This is not exactly your standard use case, but it might be an interesting one to support, since, as far as I know, JAX does not support mixed CPU-GPU jitted functions.

I was experimenting with adding support for this to jambax by doing something like this:

    def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len): 
        # allocate temporary output cpu buffer
        if n_out == 1:
            args_out = (
                np.empty(output_shapes[0], dtype=output_dtypes[0]),
            )
        ...

        # allocate temporary input cpu buffer
        if n_in == 1:
            args_in = (
                np.empty(input_dimensions[0], dtype=input_dtypes[0]),
            )
            cudaMemcpyAsync(args_in[0].ctypes.data, inout_gpu_ptrs[0], input_byte_size[0], nb_types.int32(cupy_backends.cuda.api.runtime.memcpyDeviceToHost), stream)
        ...
     
        cudaStreamSynchronize(stream)
        numba_fn(args_out + args_in)

        if n_out == 1:
            cudaMemcpy(inout_gpu_ptrs[n_in+0], args_out[0].ctypes.data, output_byte_size[0], nb_types.int32(cupy_backends.cuda.api.runtime.memcpyHostToDevice))

However, I'm seeing GIL crashes such as this.
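The staging pattern described in this comment (copy device inputs to host buffers, wait for the copies, run the host function, copy results back) can be sketched with ordinary memory standing in for device memory. Here `ctypes.memmove` stands in for `cudaMemcpyAsync`/`cudaMemcpy`, there is no real stream, and `stage_through_host` and `add_one` are hypothetical names. The sketch only illustrates the data flow and ordering; it runs entirely in ordinary Python, so it says nothing about the GIL crashes.

```python
import ctypes

FLOAT = ctypes.sizeof(ctypes.c_float)

def stage_through_host(inout_ptrs, in_sizes, out_sizes, host_fn):
    """inout_ptrs: input buffer addresses followed by output buffer addresses
    (the GPU custom-call layout); sizes are counts of float32 elements."""
    n_in = len(in_sizes)
    # 1. "Device -> host": copy each input into a temporary host buffer.
    host_in = []
    for k, n in enumerate(in_sizes):
        buf = (ctypes.c_float * n)()
        ctypes.memmove(buf, inout_ptrs[k], n * FLOAT)
        host_in.append(buf)
    host_out = [(ctypes.c_float * n)() for n in out_sizes]
    # 2. A real implementation must synchronize the stream here, so the
    #    host function never reads a buffer whose copy is still in flight.
    # 3. Run the host (e.g. Numba-jitted) function on host memory.
    host_fn(host_out, host_in)
    # 4. "Host -> device": copy the results into the output slots.
    for k, buf in enumerate(host_out):
        ctypes.memmove(inout_ptrs[n_in + k], buf, ctypes.sizeof(buf))

def add_one(outs, ins):
    # Example host function: outs[0] = ins[0] + 1, elementwise.
    for i in range(len(ins[0])):
        outs[0][i] = ins[0][i] + 1.0
```

Step 2 is the same ordering constraint that `cudaStreamSynchronize(stream)` enforces in the snippet above.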

@sudhakarsingh27 sudhakarsingh27 added NVIDIA GPU Issues specific to NVIDIA GPUs P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional) labels Aug 10, 2022
@sudhakarsingh27 sudhakarsingh27 added the contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. label Oct 5, 2022

10 participants