Extremely slow GPU execution #7024
Thanks for raising this, and for working so hard to minimize it! The best tool here is to use profiling. If you can get a profile showing a realistic workload, we can really dig into what improvements can be made (either to your code, to JAX itself, or to XLA:GPU).

There's one effect that would explain one of your comments, though I don't think it would explain the code as written being slow. General while_loops can require returning control to the host on each iteration just to decide whether to dispatch another iteration of the loop body on the GPU, incurring expensive synchronization and transfer overheads (which would loom large when the loop body itself is cheap). But in XLA:GPU there's a "for loop optimization" which is meant to notice when the loop actually has a statically fixed trip count (as it does here, at least with the code as written!) so that control need not be returned to the host on each iteration.

Could you share a profile of the execution so we can dig in?
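To make the while_loop point concrete, here is a toy sketch (not code from this issue; the body function and sizes are made up) of the same fixed-count loop written with `lax.while_loop`, `lax.fori_loop`, and `lax.scan`. With a traced trip count the loop condition is data-dependent; with a Python-int trip count, XLA sees the loop bounds statically.

```python
import jax
import jax.numpy as jnp
from jax import lax

def body(x):
    # stand-in for a cheap loop body
    return 0.5 * x + 1.0

def with_while(x, n):
    # n is a traced value: the loop condition is data-dependent.
    def cond_fun(carry):
        i, _ = carry
        return i < n
    def body_fun(carry):
        i, x = carry
        return i + 1, body(x)
    return lax.while_loop(cond_fun, body_fun, (0, x))[1]

def with_fori(x, n_static):
    # n_static is a Python int: the trip count is fixed at trace time.
    return lax.fori_loop(0, n_static, lambda i, x: body(x), x)

def with_scan(x, n_static):
    # scan also fixes the trip count statically.
    return lax.scan(lambda x, _: (body(x), None), x, None, length=n_static)[0]

x = jnp.ones(())
print(jax.jit(with_fori, static_argnums=1)(x, 1000))
```

The last line mirrors the `static_argnums` pattern used later in the thread: the iteration count is passed as a static argument so the compiled loop has a fixed trip count.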
Thank you so much for the lightning-fast reply! The first thing I tried was profiling, but I just couldn't make sense of all the information. I even tried using …
Thanks for explaining that! I can definitely convert these loops to something else. I guess I should convert these to … I wonder if your comment here is related?
Well, I wouldn't suggest using … Yes, that comment is related.

```python
# Calling softplus costs 32%!
return jnp.dot(inputs, softplus(w))
```

For this part: because XLA:GPU is likely calling into pre-packaged CUDA kernels for the dot (unless it knows it can generate better code than the closed-source Nvidia-provided kernels, which is rare), adding the softplus may mean that you have to launch two kernels per call to the MLP (one for the dot, one for the presumably fused XLA-generated softplus computation), i.e. at least two per loop iteration, rather than just one (just the dot). (Tangentially, XLA:TPU has much more flexibility here: since it generates the dot routine too, it can fuse things like elementwise operations into the loads and stores of the dot operation, and indeed on TPU any …)

By the way, if instead of splitting the RNG on every iteration, you just split it once into a big array (with leading axis size …) …

The overall theme here is to try to minimize the number of kernel launches per loop iteration. I haven't looked at your profile yet, but I'll try to get the chance soon!
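A rough sketch of that RNG suggestion (made-up names, not the code from this thread): draw all the noise up front in one call and scan over it, instead of splitting the key and sampling inside the loop body.

```python
import jax.numpy as jnp
from jax import lax, random

def step(x, noise):
    # stand-in for the real per-iteration update
    return x + 0.1 * noise, None

def run_split_inside_loop(key, x, num_steps):
    # splits the key and samples on every iteration: extra kernel launches per step
    def body(carry, _):
        x, key = carry
        key, subkey = random.split(key)
        noise = random.normal(subkey, x.shape)
        x, _ = step(x, noise)
        return (x, key), None
    (x, _), _ = lax.scan(body, (x, key), None, length=num_steps)
    return x

def run_presampled(key, x, num_steps):
    # one big normal() call up front; the loop body only does the update
    noise = random.normal(key, (num_steps,) + x.shape)
    x, _ = lax.scan(step, x, noise)
    return x
```

The updated code below does essentially this with its `diffusion` array.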
About the profile: would it be possible to share a screenshot of the TensorBoard visualization? A screenshot is easy to act on, and to show to others!
Okay, that explains why the speedup matches just making it static. One benefit to scan is that it's an assertion that the iteration limit is static. There's no …
That's a great idea. I'll try that.
Sure, which tab do you want me to screenshot?
I made your two suggested changes (below), and the runtime went from 0.69s down to 0.13s:
It's still much slower than running this on the CPU, and I don't understand why. In my real code, I can unroll the …

```python
from functools import partial
from typing import Any

import haiku as hk
import jax.numpy as jnp
from contexttimer import Timer
from jax import jit
from jax.experimental import enable_x64
from jax.lax import scan
from jax.nn import sigmoid, softplus
from jax.random import PRNGKey, normal, split
from tjax.dataclasses import dataclass


class Linear(hk.Module):
    def __init__(self, output_size: int):
        super().__init__()
        self.output_size = output_size

    def __call__(self, inputs):
        w = hk.get_parameter("w", [inputs.shape[-1], self.output_size],
                             inputs.dtype,  # Passing dtype costs 23%!
                             init=jnp.zeros)
        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))


class NoisyMLP(hk.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = [Linear(output_size) for output_size in layer_sizes]

    def __call__(self, inputs):
        out = inputs
        for layer in self.layers:
            out = layer(out)
            out = sigmoid(out)  # Sigmoid costs 10%!
        return out


@dataclass
class SamplerState:
    code_momentum: Any
    rng: Any
    iterations: Any


shape = (1,)


def nat_to_exp(natural_explanation):
    mlp = NoisyMLP((12, *shape))
    return mlp(natural_explanation)


def haiku_weight_initializer() -> None:
    nat_to_exp(jnp.zeros(shape))


def update_state(weights, state, diffusion):
    nat_to_exp_f = hk.transform(nat_to_exp).apply
    force = nat_to_exp_f(weights, None, state.code_momentum)
    new_code_momentum = force + diffusion
    return SamplerState(new_code_momentum, state.rng, state.iterations + 1)


def find_fixed_point(weights, initial_state, maximum_iterations):
    def f(state, diffusion):
        return update_state(weights, state, diffusion), None

    leak_rng, new_rng = split(initial_state.rng)
    diffusion = normal(leak_rng, (maximum_iterations,) + initial_state.code_momentum.shape)
    retval = scan(f, initial_state, diffusion, length=maximum_iterations, unroll=2)[0]
    retval = retval.replace(rng=new_rng)
    return retval


@partial(jit, static_argnums=(2,))
def infer_encoding(weights, initial_rng, maximum_iterations):
    initial_sampler_state = SamplerState(jnp.zeros(shape), initial_rng, 0)
    return find_fixed_point(weights, initial_sampler_state, maximum_iterations)


with enable_x64():  # Enabling 64-bit costs 50%.
    rng = PRNGKey(12)
    weight_rng, inference_rng = split(rng)
    weights = hk.transform(haiku_weight_initializer).init(weight_rng)
    for _ in range(10):
        with Timer() as timer:
            infer_encoding(weights, inference_rng, 8000)
        print(timer.elapsed)
```
By the way, https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html (I should've mentioned that earlier!)
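As a reminder of what `scan` computes, its documentation gives a rough pure-Python equivalent along these lines (ignoring options like `length` and `unroll`):

```python
import jax.numpy as jnp

def scan_reference(f, init, xs):
    # rough Python equivalent of lax.scan(f, init, xs)
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    # (when the ys are arrays; lax.scan also allows y to be None)
    return carry, jnp.stack(ys)
```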
That's right.
Basically the one that looks something like this (I think it's called the Trace Viewer, described here in the TensorBoard docs).
I think once we look at the right part of the Trace Viewer output, we should be able to figure out what the computer is spending its time doing.
There's good documentation on how to use the trace viewer, but unfortunately it looks like it's all Google-internal, and hasn't been open-sourced yet...
I know that … It's also possible to create an unrolling while loop that doesn't change its behavior.
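One way to read "an unrolling while loop" (a guess at the construction, not code from this thread): guard each body application with the loop condition inside the body, so several applications run per `while_loop` trip and the result matches the un-unrolled loop exactly.

```python
from jax import lax

def unrolled_while(cond_fun, body_fun, init_val, unroll=8):
    def chunk(val):
        for _ in range(unroll):  # unrolled at trace time
            # apply the body only while the condition still holds, so extra
            # applications past the exit point are no-ops
            val = lax.cond(cond_fun(val), body_fun, lambda v: v, val)
        return val
    return lax.while_loop(cond_fun, chunk, init_val)
```

The trade-off is that each trip stages `unroll` guarded body evaluations, but the while_loop's own condition is evaluated only between chunks rather than between every body application.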
That would definitely explain some of the slowness, but I don't see how that's better than just generating one kernel to do both things at once? Even if this were applied to an array with a million elements, it still seems like it would be faster to spawn one kernel instead of two. Is this XLA generator code something I can look at? I guess it's not here?
XLA:GPU is all open-source!
XLA:GPU's hands are tied here, because the ways to generate the fastest GPU kernels for matmul and conv are proprietary. Only Nvidia can do it, and they release those kernels as binaries. That's what cuBLAS, cuDNN, etc. are. There are many such kernels; XLA:GPU does autotuning and kernel selection to choose the best routines for your array shapes and specific GPU hardware. See here for an example. But because these are proprietary pre-built kernels, it can't e.g. fuse operations into them. That's why loop bodies may have to be separated into multiple separate kernels.
So, in the end, is there anything I can do to make my program run faster?
I hope you don't mind that I'm looking at this again. From what I understand, https://github.com/openai/triton tries to produce single GPU kernels. Is there any hope of JAX doing something like this in the future, via XLA? What a dream it would be to have fast GPU execution with the convenience of JAX's compiler.
Hi Neil, did you see this jax/triton project: https://github.com/jax-ml/jax-triton? Also, what are your shapes? In this example, you have: …
No, I had not! Thank you for sharing that. I was aware of Triton, so this is very exciting!
Yes, for now my shapes are less than 100 as I work on getting my ideas working. I still feel like, in an ideal world, it would not require bouncing between the CPU and GPU no matter what the shapes are?
With those very small sizes, it will be hard for the GPU to be efficient automatically in the short term. I think many optimizations would be needed in addition to fixing the bouncing between the CPU and GPU. For example, it would probably need more aggressive fusion than what XLA currently does. The quickest path would be a manual kernel via custom ops, or maybe via Triton.
You know better than I do 😄 If you say that's the quickest path, I believe you. I will look into Triton within the next couple weeks and get back to you? How do I learn about writing custom ops? |
This is the best documentation about JAX custom CUDA ops: https://github.com/dfm/extending-jax
Thanks. I will look into that! Would I have access to Jax's automatic differentiation? Or would I need to do the differentiation myself and then implement that in CUDA? |
I have only very briefly looked at custom ops myself. @mattjj, do you know if we can provide a forward graph from which the gradient will be derived, to serve as the custom op's gradient? Neil, at worst, you can print your forward and backward graphs. From those, you can find what the gradient graph is. Then you can create a JAX graph that does it and ask JAX to use it for the gradient: …
You'd need to write differentiation rules for any new Primitives introduced. In the dfm tutorial there's code in src/kepler_jax/kepler_jax.py which shows how to define a JVP rule for the primitive introduced. These custom differentiation rules are not meant for custom kernels (though they kinda work there...) but rather for JAX-traceable code. When introducing a new Primitive, as in the dfm tutorial (and as I would recommend for a custom kernel), you just attach transformation rules to the primitive directly. Once you have a differentiation rule for your primitive, you can differentiate any function that applies it along with other JAX primitives.

(Someday JAX may be able to generate derivatives for Triton code automatically. It's something we're looking at, but it's a long way off.)
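A bare-bones sketch of that pattern (hypothetical names; the dfm tutorial wires the impl to a custom CUDA kernel, whereas here it is just `tanh`): define a new Primitive, give it an impl and an abstract-eval rule, and register a JVP rule so `jax.grad` works through it.

```python
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad

my_kernel_p = core.Primitive("my_kernel")

def my_kernel(x):
    return my_kernel_p.bind(x)

@my_kernel_p.def_impl
def _my_kernel_impl(x):
    # in a real custom op this would dispatch to the hand-written kernel
    return jnp.tanh(x)

@my_kernel_p.def_abstract_eval
def _my_kernel_abstract_eval(x):
    return core.ShapedArray(x.shape, x.dtype)

def _my_kernel_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = my_kernel(x)
    return y, (1.0 - y * y) * x_dot  # d/dx tanh(x) = 1 - tanh(x)^2

ad.primitive_jvps[my_kernel_p] = _my_kernel_jvp

print(jax.grad(lambda x: my_kernel(x).sum())(jnp.ones(3)))
```

Once the JVP rule is attached, any JAX function that calls `my_kernel` alongside ordinary primitives can be differentiated.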
First of all, JAX-Triton looks amazing! Yes, it should solve my problem with quite a bit of work on my side, so thank you for that. However, I have some thoughts that I'd like to get feedback on. My problem boils down to an internal scan that evaluates something like

x[i+1] = x[i] + k * f_bwd(z - f(x[i]))

where … If …

I also thought about how I would write this in Triton. I could just manually write every fused kernel I need. And at the end of it, I'd have a library of pieces of kernels that I could compose to do what I need. These would probably be extra methods on "modules" (from Haiku or Flax) that would do things like: …
Then I would have some way of composing multiple modules into a single fused kernel. This entails two functions: …
Then I thought: why am I doing all this? Wouldn't it make much more sense to have a conversion from XLA to Triton? I understand that Triton is a very limited language, and that it may not be possible to convert everything that XLA can do to Triton. But I'm not doing anything that crazy. If the converter wants to bail out when I try to do something like take a hyperbolic sine, that's fine! I'm just doing ordinary multiplications, exponentiations, additions, etc.

I remember Matt explaining to me that Nvidia's kernels (e.g. matrix multiplication) are better optimized than anything the user can do. But I'm pretty sure that the last time I looked at this, my runtime was dominated by kernel launches. Even if Triton is 50% as fast as Nvidia's hand-crafted kernels, the ability to fuse literally hundreds of kernels together would more than compensate. And the reason it's hundreds is that I have a scan (described above), and each iteration of the scan is a whole new set of kernel launches.

So, my question boils down to: why have we decided on JAX-Triton as the solution? Why not convert XLA to Triton as best you can, and then we can keep programming in the JAX we love?
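For concreteness, the inner scan being described might look roughly like this in plain JAX (the names, step size, and iteration count are illustrative, not the real code):

```python
import jax
import jax.numpy as jnp
from jax import lax

def fixed_point_scan(f, z, x0, k=0.1, num_steps=8000):
    # x[i+1] = x[i] + k * f_bwd(z - f(x[i])), with f_bwd obtained from jax.vjp
    def step(x, _):
        y, f_vjp = jax.vjp(f, x)
        (delta,) = f_vjp(z - y)
        return x + k * delta, None
    x_final, _ = lax.scan(step, x0, None, length=num_steps)
    return x_final
```

Each iteration involves a forward evaluation of `f`, its backward pass, and a few elementwise ops, which is exactly where the per-iteration kernel launches add up.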
You can use JAX to compute the grad graph and print it. That way, you do not need to do it manually; it would be semi-manual, so a little better, but clearly not ideal.

You "only" need to write some custom CUDA kernel. You do not need to do any memory allocation. You still use JAX. You write the custom kernel that XLA doesn't generate fast code for; all the rest stays in JAX. You can use jax-triton to write the custom kernel. It doesn't need to do anything else, to my understanding; JAX-Triton helps make the bridge between the two tools. I think it is less work than what you describe above.

Your idea of having an XLA/Triton backend would be hard to implement, as we currently can't guide what to do in Triton vs XLA. Maybe a simpler thing could be a decorator on a JAX expression that gets converted to Triton. This is highly hypothetical, as I haven't looked enough at Triton. But I suppose this doesn't exist now.
I understand. What I'm trying to say is that the custom kernel writing that I would be doing is tantamount to compilation, which is already done by the XLA compiler. I want to simply write in Jax using its primitives. Triton is a different language. As you point out in your comment, it is possible to convert Jax to Triton, so that's what I'm asking for when I say "I want to program in the Jax I love".
Just so we're clear, I'm suggesting that instead of writing Triton code, I would write a rudimentary module-to-Triton converter. You are right that writing a single Triton kernel is less work. However, I don't think this is a good approach in my case: …
Can you elaborate on this? I'd like to generalize my suggestion. What I want is for XLA to produce fused kernels rather than the kernels it's producing now. Why can't it produce fused kernels? If I can write fused kernels in Triton, surely the XLA compiler can produce such kernels? The benefit of XLA producing such kernels is that I wouldn't have to worry about producing backward passes, which involves: …
Yes, I considered something like this. This would be a fantastic step in the right direction. If you consider my more general suggestion, then what I really want is a decorator to demand that the XLA compiler produce a fused kernel for a decorated function.

P.S. Salut Frederic! We met ages ago when you were working on Theano. Very nice to see you here 😄
I agree that having XLA do the fusion for you would be great. I suppose you won't learn XLA and implement it yourself. A jaxpr->Triton converter could be done more easily than an XLA->Triton converter. Hopefully this is something you could take on, but I'm not able to estimate the time it would require.

For the gradient, the simplest way that I see is to use JAX's grad on the forward graph, then reuse the jaxpr->Triton converter a second time on that gradient graph to get a second operation for the gradient.
Okay! That makes perfect sense. Thank you for explaining. But please do keep me posted if you learn about something like this being in the works, so that I can put my time into other things.
Interesting! I honestly thought jaxpr was XLA 😆 Isn't it true that jitted Jax code has a Jaxpr? If so, this would solve my problem, and if you think it's reasonable, then it's definitely worth examining. Are there any tools for working with Jaxpr? Is there any documentation?
Yes, that's what I was thinking too. That's the beauty of having such a converter.
JAX builds a jaxpr expression. Then this is converted to XLA.
There is some information on jaxprs here: https://jax.readthedocs.io/en/latest/jaxpr.html |
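For example (a toy function, not from this thread), `jax.make_jaxpr` shows the jaxpr for a function and for its gradient:

```python
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.tanh(jnp.dot(x, w)).sum()

w = jnp.ones((3, 2))
x = jnp.ones((4, 3))
print(jax.make_jaxpr(f)(w, x))            # forward jaxpr
print(jax.make_jaxpr(jax.grad(f))(w, x))  # jaxpr of the gradient (w.r.t. w)
```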
The following code is almost instantaneous (<1ms) on the CPU, but is extremely slow on the GPU (7s). I'm trying to track down the source of the problem. I have pared my code down from 5000 lines to 80 lines, and I don't think I can remove any more. I have added comments in the places that I found to have surprising (to me) effects on the GPU run time.
How can I make this code run faster on the GPU than it does on the CPU? What am I doing wrong?