cuda failed to allocate errors #788
Comments
That seems like it might be a bug, but it's hard to say without a repro.
Does your workload use convolutions? I think I see a bug where JAX isn't giving the correct GPU allocator to the XLA convolution autotuning code. That would explain why the errors are non-fatal; XLA will fall back to convolution algorithms that need less scratch space. In particular, there are some allocator options that don't appear to be set. @skye, mind taking a look?
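(For context, a minimal, hypothetical convolution workload, not the reporter's actual model, would exercise this cuDNN autotuning path on GPU; the shapes below are arbitrary:)

```python
from jax import jit, random, lax

key = random.PRNGKey(0)
x = random.normal(key, (8, 32, 32, 3))   # NHWC image batch (arbitrary shapes)
w = random.normal(key, (3, 3, 3, 16))    # HWIO convolution filters

@jit
def conv(x, w):
    # On GPU this lowers to cuDNN; XLA autotunes over candidate algorithms,
    # and the autotuner's scratch-space allocations are where the non-fatal
    # OOM messages described above would come from.
    return lax.conv_general_dilated(
        x, w, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

print(conv(x, w).shape)  # (8, 32, 32, 16)
```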
Nice find! Yes, this is using the model from https://arxiv.org/abs/1802.01561 (which includes convolutions). I haven't been able to make a small repro yet that doesn't depend on a bunch of internal libraries, but I can give that a try if it is helpful.
Thanks for the trick
@fehiepsi if you've seen this in simpler contexts, would you be able to contribute a small repro? We're eager to debug this, but need a way to make progress. @christopherhesse yes a small repro would be helpful :)
A repro would be helpful, but in the meantime, I can also look into the unset options that @hawkinsp identified above. If the missing options are obviously a bug, I can go ahead and set them and we can see if that helps.
Yup, I will make a repro example soon (currently my scripts depend on numpyro). |
@mattjj @skye Here you can find a small repro script, which triggers the error at iteration […]:

```python
import numpy as onp
from sklearn.datasets import fetch_covtype

import jax.numpy as np
from jax import jit, random
from jax.config import config; config.update("jax_platform_name", "gpu")

data = fetch_covtype()
features = data.data

def get_f(features):
    @jit
    def f(x):
        return np.dot(features, x).sum()
    return f

for i in range(10):
    f = get_f(features)
    print(f(np.ones(54)))
```

In addition, without […], it will sometimes trigger the following error:
I made a repro script too, though it's not as small: https://gist.github.com/christopherhesse/1808bbe01824c7a23d9af59dc6376961
Same here. In my script with a convolution operation, I run into the same "out of memory" error (though it's not fatal).
@fehiepsi @christopherhesse @gd-zhang can you try updating to the latest jaxlib (0.1.18) and see if you still see the errors? I suspect tensorflow/tensorflow@805b7cc will alleviate this problem, although it doesn't actually fix it (i.e. if the BFCAllocator still ends up using most of your GPU memory, you could still see the errors when something uses a different allocator. The BFCAllocator is now less likely to be using all the memory though).
@skye I upgraded jaxlib (but not jax) and still see the errors:
@skye I have not seen the error any more. Thanks, that's great news!
@fehiepsi glad to hear this issue is resolved for you! Please let me know if it pops up again though. @christopherhesse unfortunately neither repro works for me, even before tensorflow/tensorflow@805b7cc (I'm guessing it depends on what GPU you have). I'll keep trying to trigger the error, but in the meantime I can also make the change described above, and we can see if that helps. I'll let you know when there's a new jaxlib to try.
@skye thanks for investigating! The script I posted before works reliably for me with jaxlib 0.1.16 on a GCE instance using V100 GPUs, so nothing too exotic there. After upgrading using […]. It could be sensitive to the exact versions of things though, so let me know if there are any jaxlib dependencies that you want information on. Another script I run still occasionally prints the errors, though; I've posted an updated version of my original script: https://gist.github.com/christopherhesse/3fa507c7b1d50dceede20b60653d307f The output looks like this:
Oddly, it doesn't seem to happen if I just do batch size […]. Here's the full output: https://gist.github.com/christopherhesse/8ec37c3cda18851bc6eb8621bec76a23 It's possible that these errors are expected, but if that's true it seems like they should be warning messages or silent.
@christopherhesse I'm able to repro with your updated script, thanks! Agreed that these "errors" aren't necessary: they're way too noisy and not actionable (since the script still runs, at least for a while). Now I can find out exactly where they're coming from and hopefully put a stop to them :)
@christopherhesse if you update to the latest jaxlib (0.1.20, currently only available on Linux; let me know if you need the Mac build), you should see fewer OOM messages. (tensorflow/tensorflow@701f7e5 reduces the amount of GPU memory needed in your script, and tensorflow/tensorflow@84e3ae1 suppresses some spurious OOM log messages.) Give it a try? There's another issue that I haven't addressed yet, which is that tensorflow/tensorflow@805b7cc reduces GPU memory utilization (with the upshot that jax no longer allocates all your GPU memory up-front). I noticed that this makes your script OOM sooner than it did prior to that change. This is harder to fix; I might just add a toggle to re-enable the old behavior for now. I'll file a separate issue for this once I can better quantify how much worse the utilization is.
@skye the errors are gone, thanks for fixing this!
Actually, I'm immediately running into what I suspect is the OOM issue:
(and then the program exits) |
So I'll have to downgrade to 0.1.18 for now :/ |
When you say immediately, you mean it makes less progress than it did when you first reported this issue? |
I mean that it prints this error but it's actually fatal this time (on my training script). Before it would print errors (under some conditions) but this script did not fail due to OOM. |
Ok, I think you're hitting tensorflow/tensorflow@805b7cc then. I'll create a toggle to revert to the old behavior as a workaround for now. |
Hey, forgot to update this issue, oops! As of jaxlib 0.1.21, I've reverted the default behavior back to allocating 90% of your GPU memory up-front, which avoids the fragmentation issue. @christopherhesse your script should be able to run successfully now, give it a shot? FYI you can set the env var XLA_PYTHON_CLIENT_PREALLOCATE=false to start with a small footprint again, or set XLA_PYTHON_CLIENT_MEM_FRACTION=.5 to limit JAX to using 50% of available GPU memory (or the fraction of your choice). |
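For example (a minimal sketch: these variables are read when JAX brings up its GPU backend, so the safe approach is to set them in the shell environment, or from Python before the first jax import):

```python
import os

# Opt back into on-demand allocation (small initial footprint, but more
# prone to the fragmentation issue discussed above):
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# ...or instead cap the up-front allocation at 50% of GPU memory:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"

import jax.numpy as np
print(np.ones(3))  # the first operation initializes the GPU backend with this policy
```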
I don't see any errors with this version, and it doesn't crash. Thanks! Should I close this issue? |
Awesome, thanks for your patience with this! I'll go ahead and close the issue. |
I did see one more of these error messages, though I believe this one was non-fatal:
All the errors except this one seem to be gone. |
Looks like an internal "error" log message that should be downgraded to "info". Safe to ignore, but I'll leave this open until we get rid of the spurious error message. |
I also got this error message; is this a real memory issue or a bug? If it's a bug, does anyone know how to suppress it until fixed? |
I believe this is a bug in that it's not a real memory issue (XLA is using too much memory trying to pick the best cuDNN algorithm, which may result in non-optimal performance but otherwise isn't a big deal). I got caught up with other things, but will downgrade this log level to INFO. |
@skye in my case, the error message appears when trying to allocate memory for an internal variable in a jitted function, i.e. on the GPU. This out-of-memory error is related to the batch size. I don't know if this will be useful:
Some jax users are hitting this case (google/jax#788), and are confused as to whether it's an actual error. Given that this doesn't affect correctness and is somewhat internal to the compiler, I would argue it's not an error from the user's perspective.
PiperOrigin-RevId: 279991083
Change-Id: I3c893179f805c37f6a66cae0b9674337b1693314
I ended up making it a WARNING, since it can have a significant performance impact. The change is committed to XLA in tensorflow/tensorflow@1423eab, and will be included in the next jaxlib. @mgbukov the error is referring to GPU memory and GPU convolution algorithms, so you won't see it on CPU. You might also try the techniques for reducing GPU memory usage as described in https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.
When running a training script using the new memory allocation backend (#417), I see a bunch of non-fatal errors like this:
Is this a known issue? The errors go away when using XLA_PYTHON_CLIENT_ALLOCATOR=platform.
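(For reference, a sketch of that workaround; XLA_PYTHON_CLIENT_ALLOCATOR=platform swaps the default BFC allocator for a simple on-demand allocator, which is typically much slower but, per this report, avoids the spurious OOM messages:)

```python
import os

# Workaround from the original report: use the "platform" allocator instead
# of the default BFC allocator.  The variable is read when JAX brings up its
# GPU backend, so set it before importing jax.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax.numpy as np
print(np.dot(np.ones((4, 4)), np.ones(4)))
```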