Fatal Python error: Aborted when running mnist.py example #37

Closed
iobtl opened this issue May 11, 2020 · 3 comments

@iobtl

iobtl commented May 11, 2020

Hi there! I've been trying to get familiar with the library by running some examples in the examples/ folder. My environment was set up according to the instructions on https://github.com/google/jax#installation and https://github.com/deepmind/dm-haiku#installation.

When running the mnist.py example with TensorFlow 2.1.0, a Fatal Python error: Aborted occurs. The full error output is below:

2020-05-11 20:08:28.772679: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2020-05-11 20:08:28.772772: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2020-05-11 20:08:28.772790: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
I0511 20:08:29.945827 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:29.947907 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:29.948163 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:29.948284 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split train, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.869276 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.870402 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:30.870607 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:30.870710 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split train, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.915239 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.916428 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:30.916637 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:30.916740 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split test, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
2020-05-11 20:08:32.182307: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:236] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2020-05-11 20:08:32.182349: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:113] Check failed: stream->parent()->GetBlasGemmAlgorithms(&algorithms) 
Fatal Python error: Aborted

Current thread 0x00007f712fd8f740 (most recent call first):
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jaxlib/xla_client.py", line 156 in compile
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jaxlib/xla_client.py", line 576 in Compile
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/interpreters/xla.py", line 197 in xla_primitive_callable
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/interpreters/xla.py", line 166 in apply_primitive
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/core.py", line 199 in bind
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/lax/lax.py", line 626 in dot_general
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/lax/lax.py", line 564 in dot
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 2484 in dot
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/basic.py", line 161 in __call__
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/module.py", line 155 in wrapped
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/basic.py", line 120 in __call__
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/module.py", line 155 in wrapped
  File "mnist.py", line 41 in net_fn
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/transform.py", line 271 in init_fn
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/transform.py", line 106 in init_fn
  File "mnist.py", line 112 in main
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/absl/app.py", line 250 in _run_main
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/absl/app.py", line 299 in run
  File "mnist.py", line 131 in <module>
Aborted (core dumped)

One workaround I've found is the usual TensorFlow fix of limiting TensorFlow's GPU memory allocation by inserting the following code:

from tensorflow.compat.v1 import ConfigProto, InteractiveSession

config = ConfigProto()
# Let TensorFlow grow its GPU memory usage on demand instead of
# pre-allocating, and cap it at 70% of the GPU's memory.
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.7
sess = InteractiveSession(config=config)

However, this rather defeats the purpose if one is simply trying to use JAX/NumPy instead of TensorFlow. I'm not sure what else I can provide to help; please do let me know!

@trevorcai
Contributor

2020-05-11 20:08:32.182307: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:236] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED

This error indicates that CUDA didn't initialize properly.
As you point out, one cause of this can be that multiple processes are concurrently trying to reserve GPU memory. Do you get this error when no other processes are using your GPU? (You can check with nvidia-smi.)

The JAX equivalent of the above TF configuration can be found here:
https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
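
For reference, a minimal sketch of the approach described on that page, using the XLA memory-allocation environment variables it documents; the 0.7 fraction simply mirrors the TF snippet above and is only an example:

import os

# Stop XLA from pre-allocating most of the GPU memory at startup,
# leaving room for TensorFlow to create its cuBLAS handle.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternatively, cap the fraction of GPU memory JAX pre-allocates:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import jax.numpy as jnp  # import JAX only after these variables are set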

One other thing worth trying is wrapping all of the TFDS logic in a with tf.device('cpu'): block, to make sure the TF executor and XLA aren't fighting over the GPU in your setup; a rough sketch follows.
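
For example (a sketch only; the dataset name and batch size are placeholders, not the exact values from mnist.py):

import tensorflow as tf
import tensorflow_datasets as tfds

# Keep the tf.data/TFDS input pipeline on the CPU so the TF executor
# doesn't compete with XLA/JAX for GPU memory.
with tf.device('cpu'):
    train_ds = tfds.load("mnist", split="train", as_supervised=True)
    train_ds = train_ds.cache().shuffle(10_000).batch(128).prefetch(1)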

I'm fairly confident that this is a configuration issue; the folks over at google/jax may be better able to pinpoint the problem than I can.

@iobtl
Author

iobtl commented May 14, 2020

Following the configuration instructions at https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html worked.
Thank you! I'll keep this requirement in mind in the future. I don't have concurrent processes running, so this may just be down to my system setup.

iobtl closed this as completed May 14, 2020
@shakes76

This is the fix for the case where JAX is used with TensorFlow Datasets, which rely on TensorFlow for preprocessing: XLA and TF both fight for GPU memory. The documentation linked above also points out that you can hide the GPU from TensorFlow entirely with tf.config.experimental.set_visible_devices([], "GPU"); a sketch of that approach is below.
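
For example (a minimal sketch; the call must run before any TF op touches the GPU):

import tensorflow as tf

# Hide every GPU from TensorFlow so tf.data/TFDS preprocessing stays on
# the CPU and the GPU memory is left entirely to JAX/XLA.
tf.config.experimental.set_visible_devices([], "GPU")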
