
cuda failed to allocate errors #788

Closed
christopherhesse opened this issue May 30, 2019 · 32 comments
@christopherhesse

When running a training script using the new memory allocation backend (#417), I see a bunch of non-fatal errors like this:

[1] 2019-05-29 23:55:55.555823: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 528.00M (553648128 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
[1] 2019-05-29 23:55:55.581962: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:525] Resource exhausted: Failed to allocate request for 528.00MiB (553648128B) on device ordinal 0
[7] 2019-05-29 23:55:55.594693: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 528.00M (553648128 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
[7] 2019-05-29 23:55:55.606314: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:525] Resource exhausted: Failed to allocate request for 528.00MiB (553648128B) on device ordinal 0
[1] 2019-05-29 23:55:55.633261: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 1.14G (1224736768 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
[1] 2019-05-29 23:55:55.635169: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 1.05G (1132822528 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
[1] 2019-05-29 23:55:55.646031: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 561.11M (588365824 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
[1] 2019-05-29 23:55:55.647926: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 592.04M (620793856 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
[7] 2019-05-29 23:55:55.655470: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 1.14G (1224736768 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory

Is this a known issue? The errors go away when using XLA_PYTHON_CLIENT_ALLOCATOR=platform.
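
For reference, here's roughly how I'm setting it from Python (a minimal sketch; I'm assuming the variable only needs to be in the environment before JAX first initializes its GPU backend):

import os
# Assumption: set this before JAX initializes its GPU backend (i.e. before the first jax import / GPU op).
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax.numpy as np
print(np.ones(3))  # the first GPU computation now goes through the platform allocator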

@hawkinsp
Member

That seems like it might be a bug, but it's hard to say without a repro.

@hawkinsp
Member

Does your workload use convolutions? I think I see a bug where JAX isn't giving the correct GPU allocator to the XLA convolution autotuning code. That would explain why the errors are non-fatal; XLA will fall back to convolution algorithms that need less scratch space.

In particular the options structure at the line linked below accepts an optional device ordinal and device allocator for autotuning at compile time. We should be setting both but currently we are setting neither.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/local_client.cc#L880

@skye Mind taking a look?

@hawkinsp added the bug label on May 30, 2019
@christopherhesse
Author

christopherhesse commented May 30, 2019

Nice find! Yes, this is using the model from https://arxiv.org/abs/1802.01561 (which includes convolutions). I haven't been able to make a small repro yet that doesn't depend on a bunch of internal libraries, but I can give that a try if it would be helpful.

@fehiepsi
Member

Thanks for the XLA_PYTHON_CLIENT_ALLOCATOR=platform trick, @christopherhesse! I am facing this issue frequently even though there is no convolution in my script.

@mattjj
Member

mattjj commented May 30, 2019

@fehiepsi if you've seen this in simpler contexts, would you be able to contribute a small repro? We're eager to debug this, but need a way to make progress.

@christopherhesse yes a small repro would be helpful :)

@skye
Collaborator

skye commented May 30, 2019

A repro would be helpful, but in the meantime, I can also look into the unset options that @hawkinsp identified above. If the missing options are obviously a bug, I can go ahead and set them and we can see if that helps.

@fehiepsi
Member

Yup, I will make a repro example soon (currently my scripts depend on numpyro).

@fehiepsi
Member

fehiepsi commented May 30, 2019

@mattjj @skye Here you can find a small repro script, which triggers

2019-05-30 17:00:02.646858: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:226] failed to load CUBIN: Internal: failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2019-05-30 17:00:02.646902: F external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:888] Check failed: module != nullptr 

at iteration i=5 on my system (RTX 2070, 8 GB):

import numpy as onp
from sklearn.datasets import fetch_covtype

import jax.numpy as np
from jax import jit, random
from jax.config import config; config.update("jax_platform_name", "gpu")

data = fetch_covtype()
features = data.data

def get_f(features):
    @jit
    def f(x):
        return np.dot(features, x).sum()

    return f

for i in range(10):
    f = get_f(features)
    print(f(np.ones(54)))

In addition, without XLA_PYTHON_CLIENT_ALLOCATOR=platform, this script will allocate all of the memory on my GPU.

Sometimes it will trigger the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-8bd343fefd9f> in <module>
     18 for i in range(10):
     19     f = get_f(features)
---> 20     print(f(np.ones(54)))

~/jax/jax/api.py in f_jitted(*args, **kwargs)
    119     _check_args(args_flat)
    120     flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 121     out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    122     return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
    123 

~/jax/jax/core.py in call_bind(primitive, f, *args, **params)
    653   if top_trace is None:
    654     with new_sublevel():
--> 655       ans = primitive.impl(f, *args, **params)
    656   else:
    657     tracers = map(top_trace.full_raise, args)

~/jax/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
    611   compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
    612   try:
--> 613     return compiled_fun(*args)
    614   except FloatingPointError:
    615     print("Invalid value encountered in the output of a jit function. "

~/jax/jax/interpreters/xla.py in execute_compiled(compiled, pval, handle_result, *args)
    634 def execute_compiled(compiled, pval, handle_result, *args):
    635   input_bufs = [device_put(x) for x in args]
--> 636   out_buf = compiled.Execute(input_bufs)
    637   check_nans("jit-compiled computation", out_buf)
    638   return pe.merge_pvals(handle_result(out_buf), pval)

RuntimeError: Internal: Unable to load kernel 'fusion'

@christopherhesse
Author

I made a repro script too, though it's not as small: https://gist.github.com/christopherhesse/1808bbe01824c7a23d9af59dc6376961

@gd-zhang

gd-zhang commented Jun 2, 2019

Same here. In my script with a convolution operation, I run into the same "out of memory" errors (though they're not fatal).

@skye
Collaborator

skye commented Jun 7, 2019

@fehiepsi @christopherhesse @gd-zhang can you try updating to the latest jaxlib (0.1.18) and see if you still see the errors? I suspect tensorflow/tensorflow@805b7cc will alleviate this problem, although it doesn't actually fix it (i.e. if the BFCAllocator still ends up using most of your GPU memory, you could still see the errors when something uses a different allocator; the BFCAllocator is now less likely to be using all the memory, though).

@christopherhesse
Author

@skye I upgraded jaxlib (but not jax) and still see the errors:

pip show jaxlib
Name: jaxlib
Version: 0.1.18
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /opt/conda/lib/python3.7/site-packages
Requires: absl-py, numpy, protobuf, scipy, six
Required-by: 
2019-06-08 03:44:36.084180: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 3.39G (3640655872 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2019-06-08 03:44:36.087182: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:555] Resource exhausted: Failed to allocate request for 3.39GiB (3640655872B) on device ordinal 0
2019-06-08 03:44:36.095983: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 880.00M (922746880 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2019-06-08 03:44:36.096050: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:555] Resource exhausted: Failed to allocate request for 880.00MiB (922746880B) on device ordinal 0
2019-06-08 03:44:36.097871: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 880.00M (922746880 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2019-06-08 03:44:36.097929: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:555] Resource exhausted: Failed to allocate request for 880.00MiB (922746880B) on device ordinal 0
2019-06-08 03:44:36.099827: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 880.00M (922746880 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory

@fehiepsi
Member

@skye I have not seen the error any more. Thanks, that's great news!

@skye
Collaborator

skye commented Jun 10, 2019

@fehiepsi glad to hear this issue is resolved for you! Please let me know if it pops up again though.

@christopherhesse unfortunately neither repro works for me, even before tensorflow/tensorflow@805b7cc (I'm guessing it depends on what GPU you have). I'll keep trying to trigger the error, but in the meantime I can also make the change described above, and we can see if that helps. I'll let you know when there's a new jaxlib to try.

@christopherhesse
Author

@skye thanks for investigating!

The repro script I posted before triggers the errors reliably for me with jaxlib 0.1.16 on a GCE instance using V100 GPUs, so nothing too exotic there.

After upgrading using pip install https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.18-cp37-none-linux_x86_64.whl, that original script stops printing errors (so this definitely fixes at least one issue).

It could be sensitive to the exact versions of things though, so let me know if there are any jaxlib dependencies that you want information on.

Another script I run still occasionally prints the errors, though; I've posted an updated version of my original script: https://gist.github.com/christopherhesse/3fa507c7b1d50dceede20b60653d307f

The output looks like this:

batch_size 1
batch_size 2
batch_size 4
batch_size 8
batch_size 16
batch_size 32
batch_size 64
batch_size 128
batch_size 256
batch_size 512
2019-06-10 22:17:12.532436: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 7.34G (7886798848 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory

Oddly, it doesn't seem to happen if I run batch size 512 only; it's unclear whether memory is not being freed from the previous iterations of the loop.

Here's the full output: https://gist.github.com/christopherhesse/8ec37c3cda18851bc6eb8621bec76a23

It's possible that these errors are expected, but if that's true it seems like they should be warning messages or silent.

@skye
Collaborator

skye commented Jun 11, 2019

@christopherhesse I'm able to repro with your updated script, thanks! Agreed that these "errors" aren't necessary; they're way too noisy and not actionable (since the script still runs, at least for a while). Now I can find out exactly where they're coming from and hopefully put a stop to them :)

@skye
Collaborator

skye commented Jun 13, 2019

@christopherhesse if you update to the latest jaxlib (0.1.20, currently Linux-only; let me know if you need the Mac build), you should see fewer OOM messages. (tensorflow/tensorflow@701f7e5 reduces the amount of GPU memory needed in your script, and tensorflow/tensorflow@84e3ae1 suppresses some spurious OOM log messages.) Give it a try?

There's another issue that I haven't addressed yet, which is that tensorflow/tensorflow@805b7cc reduces GPU memory utilization (with the upshot that jax no longer allocates all your GPU memory up-front). I noticed that this makes your script OOM sooner than it did prior to that change. This is harder to fix; I might just add a toggle to reenable the old behavior for now. I'll file a separate issue for this once I can better quantify how much worse the utilization is.

@christopherhesse
Author

@skye the errors are gone, thanks for fixing this!

@christopherhesse
Author

christopherhesse commented Jun 13, 2019

Actually, I'm immediately running into what I suspect is the OOM issue:

[1] 2019-06-13 01:17:08.886805: E external/org_tensorflow/tensorflow/compiler/xla/python/local_client.cc:672] Execution of replica 0 failed: Resource exhausted: Out of memory while trying to allocate 5520758552 bytes.

(and then the program exits)

@christopherhesse
Author

So I'll have to downgrade to 0.1.18 for now :/

@skye reopened this on Jun 13, 2019
@skye
Collaborator

skye commented Jun 13, 2019

When you say immediately, you mean it makes less progress than it did when you first reported this issue?

@christopherhesse
Author

I mean that it prints this error but it's actually fatal this time (in my training script). Before, it would print errors (under some conditions) but the script did not fail due to OOM.

@skye
Collaborator

skye commented Jun 13, 2019

Ok, I think you're hitting tensorflow/tensorflow@805b7cc then. I'll create a toggle to revert to the old behavior as a workaround for now.

@skye
Collaborator

skye commented Jun 21, 2019

Hey, forgot to update this issue, oops! As of jaxlib 0.1.21, I've reverted the default behavior back to allocating 90% of your GPU memory up-front, which avoids the fragmentation issue. @christopherhesse your script should be able to run successfully now, give it a shot?

FYI you can set the env var XLA_PYTHON_CLIENT_PREALLOCATE=false to start with a small footprint again, or set XLA_PYTHON_CLIENT_MEM_FRACTION=.5 to limit JAX to using 50% of available GPU memory (or the fraction of your choice).
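
For example, from Python (a minimal sketch; I'm assuming the variables just need to be set before JAX initializes its GPU backend):

import os
# Pick one of these; both are assumed to be read when the GPU backend is initialized.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # start with a small footprint and grow as needed
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"  # or preallocate only 50% of GPU memory

import jax.numpy as np
print(np.ones(3))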

@christopherhesse
Author

I don't see any errors with this version, and it doesn't crash. Thanks! Should I close this issue?

@skye
Collaborator

skye commented Jun 21, 2019

Awesome, thanks for your patience with this! I'll go ahead and close the issue.

@skye closed this as completed on Jun 21, 2019
@christopherhesse
Author

christopherhesse commented Jun 21, 2019

I did see one more of these error messages, though I believe this one was non-fatal:

2019-06-21 23:26:20.942813: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:569] Resource exhausted: Allocating 7247757312 bytes exceeds the memory limit of 4294967296 bytes.

All the errors except this one seem to be gone.

@skye reopened this on Jun 21, 2019
@skye
Collaborator

skye commented Jun 21, 2019

Looks like an internal "error" log message that should be downgraded to "info". Safe to ignore, but I'll leave this open until we get rid of the spurious error message.

@mgbukov

mgbukov commented Nov 6, 2019

I did see one more of these error messages, though I believe this one was non-fatal:

2019-06-21 23:26:20.942813: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:569] Resource exhausted: Allocating 7247757312 bytes exceeds the memory limit of 4294967296 bytes.

All the errors except this one seem to be gone.

I also got this error message; is this a real memory issue or a bug? If it's a bug, does anyone know how to suppress it until fixed?

@skye
Collaborator

skye commented Nov 11, 2019

I believe this is a bug in that it's not a real memory issue (XLA is using too much memory trying to pick the best cuDNN algorithm, which may result in non-optimal performance but otherwise isn't a big deal). I got caught up with other things, but will downgrade this log level to INFO.

@mgbukov

mgbukov commented Nov 11, 2019

@skye in my case, the error message appears when trying to allocate memory for an internal variable in a jitted function, i.e. on the GPU. This out-of-memory error is related to the batch size. I don't know if this will be useful:

  • when I ran the same code on the CPU it was OK.
  • when I started two MPI processes with half the batch size each and executed them on the same GPU device, they ran (but I wasn't able to tell if this is because the two half-batches ran in series).

tensorflow-copybara pushed a commit to tensorflow/tensorflow that referenced this issue Nov 12, 2019
Some jax users are hitting this case (google/jax#788), and are confused as to whether it's an actual error. Given that this doesn't affect correctness and is somewhat internal to the compiler, I would argue it's not an error from the user's perspective.

PiperOrigin-RevId: 279991083
Change-Id: I3c893179f805c37f6a66cae0b9674337b1693314
@skye
Collaborator

skye commented Nov 12, 2019

I ended up making it a WARNING, since it can have a significant performance impact. The change is committed to XLA in tensorflow/tensorflow@1423eab, and will be included in the next jaxlib.

@mgbukov the error is referring to GPU memory and GPU convolution algorithms, so you won't see it on CPU. You might also try the techniques for reducing GPU memory usage as described in https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.
