Unimplemented: DNN library is not found. #4920

Closed
milmor opened this issue Nov 17, 2020 · 27 comments
Labels
needs info More information is required to diagnose & prioritize the issue. NVIDIA GPU Issues specific to NVIDIA GPUs P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional)

Comments

@milmor

milmor commented Nov 17, 2020

Working on a local RTX 2060 Super GPU with CUDA 11.1, I got this error.

JAX was installed successfully with the following:

pip install --upgrade jax jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

and the symlink:

sudo ln -s /path/to/cuda /usr/local/cuda-11.1

JAX reports the GPU platform with:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

and I can do basic math, such as:

from jax import random
rng_key = random.PRNGKey(0)

However, I still can't train the model:

evaluate(model, test_ds)

FilteredStackTrace: RuntimeError: Unimplemented: DNN library is not found.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-8-0f8618edbb7d> in <module>()
     13   return compute_metrics(logits, eval_ds['label'])
     14 
---> 15 evaluate(model, test_ds)

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    131   def reraise_with_filtered_traceback(*args, **kwargs):
    132     try:
--> 133       return fun(*args, **kwargs)
    134     except Exception as e:
    135       if not is_under_reraiser(e):

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    221         backend=backend,
    222         name=flat_fun.__name__,
--> 223         donated_invars=donated_invars)
    224     return tree_unflatten(out_tree(), out)
    225 

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1175 
   1176   def bind(self, fun, *args, **params):
-> 1177     return call_bind(self, fun, *args, **params)
   1178 
   1179   def process(self, trace, fun, tracers, params):

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1166   tracers = map(top_trace.full_raise, args)
   1167   with maybe_new_sublevel(top_trace):
-> 1168     outs = primitive.process(top_trace, fun, tracers, params)
   1169   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1170 

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1178 
   1179   def process(self, trace, fun, tracers, params):
-> 1180     return trace.process_call(self, fun, tracers, params)
   1181 
   1182   def post_process(self, trace, out_tracers, params):

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    577 
    578   def process_call(self, primitive, f, tracers, params):
--> 579     return primitive.impl(f, *tracers, **params)
    580   process_map = process_call
    581 

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    557                                *unsafe_map(arg_spec, args))
    558   try:
--> 559     return compiled_fun(*args)
    560   except FloatingPointError:
    561     assert FLAGS.jax_debug_nans  # compiled_fun can only raise in this case

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, avals, handlers, *args)
    805   device, = compiled.local_devices()
    806   input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 807   out_bufs = compiled.execute(input_bufs)
    808   if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
    809   return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

RuntimeError: Unimplemented: DNN library is not found.
@hawkinsp
Member

That means that CuDNN is not in your library path. Can you try adding your CUDA lib path to LD_LIBRARY_PATH?

@tomhennigan
Member

This issue seems related to google-deepmind/dm-haiku#83; perhaps something changed recently?

@hawkinsp
Member

About 5 months ago (a141cc6) we switched how we link GPU libraries to be the same as TensorFlow, namely, we use dlopen() to find libraries like CuDNN rather than linking against them directly. dlopen() looks for libraries using LD_LIBRARY_PATH, so that's ultimately the cause of this error: we can't find the libraries.
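One way to test the dlopen() explanation directly is to ask the dynamic loader to open the libraries by soname from Python. The minimal sketch below uses only the standard library: ctypes.CDLL calls dlopen() on POSIX systems, so it follows the same search rules (LD_LIBRARY_PATH, the ldconfig cache, the default system directories). The sonames in CANDIDATE_LIBS are assumptions for a CUDA 11 / cuDNN 8 setup; substitute whatever your jaxlib build actually asks for. Note that the loader reads LD_LIBRARY_PATH when the process starts, so export it in the shell before launching Python rather than changing os.environ inside the script.

```python
import ctypes

# Hypothetical sonames for a CUDA 11 / cuDNN 8 install; adjust to your versions.
CANDIDATE_LIBS = ["libcudnn.so.8", "libcudart.so.11.0", "libcublas.so.11"]


def check_loadable(names):
    """Return {soname: bool} saying whether dlopen() can open each library.

    ctypes.CDLL uses dlopen() on POSIX, so a failure here should correspond
    to the "Could not load dynamic library" lines in JAX/TensorFlow logs.
    """
    results = {}
    for name in names:
        try:
            ctypes.CDLL(name)
            results[name] = True
        except OSError as err:
            # err usually carries the same dlerror() text the logs show.
            print(f"cannot load {name}: {err}")
            results[name] = False
    return results


if __name__ == "__main__":
    for name, ok in check_loadable(CANDIDATE_LIBS).items():
        print(f"{name}: {'OK' if ok else 'NOT FOUND'}")
```

If libcudnn.so.8 reports NOT FOUND here while the CUDA libraries load, that points at the cuDNN search path rather than at JAX itself.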

I suspect you would see the exact same behavior with TensorFlow with GPU support: as far as I am aware, it uses the same code to find the GPU libraries. It might be interesting to verify that hypothesis: install a GPU version of TF and try running a convolution. You should see the same error as JAX (if you haven't set LD_LIBRARY_PATH).

I also suspect that if you set TF_CPP_MIN_LOG_LEVEL=0, you may see better logging that more clearly indicates the real problem.

I agree the error message isn't very helpful; we should probably fix that.

@mil-ad

mil-ad commented Nov 17, 2020

That does get rid of that error (although some other issues still remain in google-deepmind/dm-haiku#83).

It'd be great if the cuDNN dependency were documented in more detail in the installation guide. The CUDA part is clear (the symbolic link and environment variable), but I didn't know about cuDNN.

@milmor
Author

milmor commented Nov 17, 2020

That means that CuDNN is not in your library path. Can you try adding your CUDA lib path to LD_LIBRARY_PATH?

After running:

$ export PATH=/usr/local/cuda-11.1/bin${PATH:+:${PATH}}
$ export LD_LIBRARY_PATH=/usr/local/cuda-11.1/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

it still outputs the same error.
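The ${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} idiom used in these exports expands to a colon followed by the old value only when the variable is already set and non-empty, so no stray leading or trailing colon is introduced. The sketch below mirrors that logic in Python; the helper name join_search_path is hypothetical. Because the dynamic loader reads LD_LIBRARY_PATH at process start, the usage example passes the result to a child process instead of editing the current one.

```python
import os
import subprocess
import sys


def join_search_path(new_dir, current, prepend=True):
    """Combine new_dir with an existing colon-separated search path.

    prepend=True mirrors NEW${VAR:+:${VAR}}; prepend=False mirrors
    ${VAR:+${VAR}:}NEW. A separator is added only if `current` is non-empty.
    """
    if not current:
        return new_dir
    return f"{new_dir}:{current}" if prepend else f"{current}:{new_dir}"


if __name__ == "__main__":
    env = dict(os.environ)
    env["LD_LIBRARY_PATH"] = join_search_path(
        "/usr/local/cuda-11.1/lib64", env.get("LD_LIBRARY_PATH", "")
    )
    # The loader only sees the new value in a freshly started process.
    subprocess.run(
        [sys.executable, "-c", "import os; print(os.environ['LD_LIBRARY_PATH'])"],
        env=env,
        check=True,
    )
```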

@milmor
Author

milmor commented Nov 17, 2020

TensorFlow 2.3 works perfectly, with no error.

@mil-ad

mil-ad commented Nov 17, 2020

Tensorflow 2.3 works perfect, no error.

Do you mean upgrading to 2.3 or downgrading to 2.3?

@milmor
Author

milmor commented Nov 17, 2020

Tensorflow 2.3 works perfect, no error.

Do you mean upgrading to 2.3 or downgrading to 2.3?

After creating a new conda environment and installing tensorflow-gpu==2.3 with pip, there's no error with CUDA or TensorFlow, and I can train successfully. However, JAX still fails.

@tomhennigan
Member

https://groups.google.com/a/tensorflow.org/g/discuss/c/TiWgve-KERo/m/NgUohfTiAgAJ

^ I think this thread in TF is relevant too.

@mil-ad

mil-ad commented Nov 23, 2020

Looks like this may have been fixed in 0.2.6. @milmor can you confirm?

@milmor
Author

milmor commented Nov 24, 2020

Looks like this may have been fixed in 0.2.6. @milmor can you confirm?

The issue has not been fixed in 0.2.6.

I found that, although the jaxlib wheel is built for cuda111 (JAX version 0.2.6) and was installed with:

pip install --upgrade jax jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

it seems to be looking for a different CUDA version, as shown below:

2020-11-23 22:06:33.769111: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.1/lib64:
2020-11-23 22:06:34.835004: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.1/lib64:
2020-11-23 22:06:34.835077: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcublas.so.10'; dlerror: libcublas.so.10: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.1/lib64:
2020-11-23 22:06:34.835874: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.1/lib64:
2020-11-23 22:06:34.835936: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcusparse.so.10'; dlerror: libcusparse.so.10: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.1/lib64:
2020-11-23 22:06:34.835992: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudnn.so.7'; dlerror: libcudnn.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.1/lib64:
2020-11-23 22:06:34.835999: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1753] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2020-11-23 22:06:44.328599: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.1/lib64:
2020-11-23 22:06:44.328623: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2020-11-23 22:06:44.328778: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2020-11-23 22:06:44.328809: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2020-11-23 22:06:44.328851: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:772] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call.2 = (f32[10000,28,28,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[10000,28,28,1]{2,1,3,0} %multiply, f32[3,3,1,32]{1,0,2,3} %copy.5, f32[32]{0} %parameter.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_type="conv_general_dilated" op_name="jit(evaluate)/conv_general_dilated[ batch_group_count=1\n                                    dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n                                    feature_group_count=1\n                                    lhs_dilation=(1, 1)\n                                    lhs_shape=(10000, 28, 28, 1)\n                                    padding=((1, 1), (1, 1))\n                                    precision=None\n                                    rhs_dilation=(1, 1)\n                                    rhs_shape=(3, 3, 1, 32)\n                                    window_strides=(1, 1) ]" source_file="/home/emam/anaconda3/envs/flax/lib/python3.7/site-packages/flax/nn/linear.py" source_line=247}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.
2020-11-23 22:06:44.400440: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2020-11-23 22:06:44.400469: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2020-11-23 22:06:44.400488: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2020-11-23 22:06:44.400516: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:772] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call.3 = (f32[10000,14,14,64]{2,1,3,0}, u8[0]{0}) custom-call(f32[10000,14,14,32]{2,1,3,0} %multiply.3, f32[3,3,32,64]{1,0,2,3} %copy.6, f32[64]{0} %parameter.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_type="conv_general_dilated" op_name="jit(evaluate)/conv_general_dilated[ batch_group_count=1\n                                    dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n                                    feature_group_count=1\n                                    lhs_dilation=(1, 1)\n                                    lhs_shape=(10000, 14, 14, 32)\n                                    padding=((1, 1), (1, 1))\n                                    precision=None\n                                    rhs_dilation=(1, 1)\n                                    rhs_shape=(3, 3, 32, 64)\n                                    window_strides=(1, 1) ]" source_file="/home/emam/anaconda3/envs/flax/lib/python3.7/site-packages/flax/nn/linear.py" source_line=247}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.
2020-11-23 22:06:45.077602: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2020-11-23 22:06:45.077637: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_client.cc:1809] Execution of replica 0 failed: Unimplemented: DNN library is not found.

@Robert-Lu

Thank you for your comments on this thread. I ran into a similar problem and solved it by manually installing cuDNN. You can refer to this official guide for installing it. I agree that the cuDNN dependency should be mentioned in the installation instructions for the new version.

@MrinankSharma

Hi everyone,

I'm having a similar issue. I also get the error RuntimeError: Unimplemented: DNN library is not found.

However, I see different log output suggesting that cuDNN was loaded:

2021-01-13 17:12:23.262374: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.7
2021-01-13 17:12:24.491302: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:336] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR

Any pointers?

@falesiani

falesiani commented Feb 5, 2021

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64

I added this to my .bashrc, based on:
https://forums.developer.nvidia.com/t/path-ld-library-path/48080

but the error persists when using the CNN modules from stax or Haiku:

RuntimeError: Unimplemented: DNN library is not found.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

@AntreasAntoniou

I am also having the same issue. I can confirm that my LD_LIBRARY_PATH is correctly configured.

@Ir1d

Ir1d commented Dec 20, 2021

I pointed LD_LIBRARY_PATH to the CUDA path, and libcudnn.so.7 is there. But I get:

2021-12-20 00:51:16.501936: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.7'; dlerror: libcudnn.so.7: cannot open shared object file: No such file or directory;
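When a library "is there" but dlopen() still reports "No such file or directory", some possible causes are a broken symlink, the file sitting in a different directory than the one listed (for example lib vs lib64), or only a fully versioned file such as libcudnn.so.7.x.y existing without the libcudnn.so.7 link. The stdlib-only sketch below lists, for each LD_LIBRARY_PATH entry, the files matching a pattern and whether each one actually resolves; the helper name find_in_search_path is hypothetical.

```python
import glob
import os


def find_in_search_path(pattern, search_path=None):
    """Map each directory in a colon-separated path to (file, resolves) pairs.

    Only the top level of each directory is checked, since the loader does not
    recurse into subdirectories. `resolves` is False for a broken symlink.
    """
    if search_path is None:
        search_path = os.environ.get("LD_LIBRARY_PATH", "")
    hits = {}
    for directory in filter(None, search_path.split(":")):  # skip empty entries
        matches = sorted(glob.glob(os.path.join(directory, pattern)))
        hits[directory] = [(path, os.path.exists(path)) for path in matches]
    return hits


if __name__ == "__main__":
    for directory, files in find_in_search_path("libcudnn.so*").items():
        print(directory, files or "no matches")
```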

@gabehope

gabehope commented Feb 15, 2022

I was also getting this error. I don't know the details of what was happening, but for me the issue seemed to stem from JAX and TensorFlow not sharing the GPU nicely. When I added this snippet (taken from the Flax MNIST example) to the top of my code, it seems to run:

import tensorflow as tf

# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], 'GPU')

The comment suggests this is a known issue, but a quick Google search only turns up an old, closed issue, #120. I don't get the same issue running the same code on Colab (without the above snippet), so it may be specific to my machine's configuration.

Edit: Ah, I now see there is a whole page on this in the JAX documentation. It would be very useful if JAX could detect this issue and give a helpful error message. Based on the "DNN library not found" error, I went down the rabbit hole of thinking I had the wrong version of CUDA/cuDNN.

@Ir1d

Ir1d commented Feb 15, 2022

I also believe this is due to a wrong CUDA/cuDNN version. I was able to overcome the issue by recreating the conda environment.
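Since several reports in this thread come down to a cuDNN major version mismatch (libcudnn.so.7 vs libcudnn.so.8), it can help to confirm which cuDNN headers are actually installed. cuDNN 8 defines CUDNN_MAJOR, CUDNN_MINOR and CUDNN_PATCHLEVEL in cudnn_version.h, while cuDNN 7 defined them in cudnn.h. The sketch below parses those macros; the default include directories are assumptions (a conda install would typically use $CONDA_PREFIX/include instead), and the headers only describe the installed package, not which shared library the loader will pick up.

```python
import re
from pathlib import Path

_VERSION_MACRO = re.compile(r"#define\s+CUDNN_(MAJOR|MINOR|PATCHLEVEL)\s+(\d+)")


def parse_cudnn_version(header_text):
    """Return (major, minor, patch) from cuDNN header text, or None if absent."""
    found = dict(_VERSION_MACRO.findall(header_text))
    try:
        return int(found["MAJOR"]), int(found["MINOR"]), int(found["PATCHLEVEL"])
    except KeyError:
        return None


def installed_cudnn_version(include_dirs=("/usr/local/cuda/include", "/usr/include")):
    """Search assumed include dirs for cuDNN 8 (cudnn_version.h) or 7 (cudnn.h)."""
    for directory in include_dirs:
        for name in ("cudnn_version.h", "cudnn.h"):
            path = Path(directory) / name
            if path.is_file():
                version = parse_cudnn_version(path.read_text(errors="ignore"))
                if version:
                    return version
    return None


if __name__ == "__main__":
    # A major version of 8 corresponds to the soname libcudnn.so.8.
    print(installed_cudnn_version())
```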

@Waterkin

Waterkin commented May 14, 2022

1. Under CUDA 11.2, installed cuDNN > 8.2, following:
https://stackoverflow.com/questions/55256671/how-to-install-latest-cudnn-to-conda
2. Added the path to .bashrc.

It still doesn't work.

@xidulu

xidulu commented Jun 21, 2022

I was able to solve this problem by adding these four lines of code at the top of the file:

import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] ='false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
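These variables only help if they are in place before JAX or TensorFlow sets up its GPU allocator, so the safest spot is before either library is imported. The sketch below packages the same three settings and warns if either library has already been imported; the helper name apply_gpu_memory_settings is hypothetical. Per the JAX GPU memory allocation docs, XLA_PYTHON_CLIENT_PREALLOCATE=false stops JAX from reserving most of the GPU memory up front, and XLA_PYTHON_CLIENT_ALLOCATOR=platform allocates and frees on demand at some speed cost.

```python
import os
import sys
import warnings

# The three settings from the comment above.
GPU_MEMORY_SETTINGS = {
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",  # JAX: don't reserve memory up front
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",  # JAX: allocate/free on demand (slower)
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",  # TensorFlow: grow GPU memory as needed
}


def apply_gpu_memory_settings(env=None):
    """Write the settings into `env` (default: os.environ) and return it.

    Warns if JAX or TensorFlow is already imported, since a GPU backend that
    has already been initialised may not pick up the new values.
    """
    if env is None:
        env = os.environ
    for module in ("jax", "tensorflow"):
        if module in sys.modules:
            warnings.warn(f"{module} is already imported; set these before importing it")
    env.update(GPU_MEMORY_SETTINGS)
    return env
```

Usage: call apply_gpu_memory_settings() as the first statement of the script, then import jax and tensorflow.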

@gigadeplex

What worked for me:

conda install -c anaconda cudnn=8.2.1 cudatoolkit=11.3

Check whether LD_LIBRARY_PATH is empty with echo $LD_LIBRARY_PATH.
If it is empty: export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/
Otherwise: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/

Finally: pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

@sudhakarsingh27 added the NVIDIA GPU and P0 (urgent) labels on Aug 10, 2022
@sudhakarsingh27
Collaborator

@milmor is this resolved now?

@sudhakarsingh27 added the P2 (eventual) and needs info labels and removed the P0 (urgent) label on Aug 12, 2022
@bijanx

bijanx commented Oct 6, 2022

Fixed for me:

# install cuDNN first
pip uninstall jax
pip install "jax[cuda]"

@HoldOffHunger

Still broken.

Traceback (most recent call last):
  File "C:\Users\Makhno\faceswap\lib\cli\launcher.py", line 222, in execute_script
    process.process()
  File "C:\Users\Makhno\faceswap\scripts\extract.py", line 165, in process
    extract.process()
  File "C:\Users\Makhno\faceswap\scripts\extract.py", line 689, in process
    self._run_extraction()
  File "C:\Users\Makhno\faceswap\scripts\extract.py", line 709, in _run_extraction
    self._extractor.launch()
  File "C:\Users\Makhno\faceswap\plugins\extract\pipeline.py", line 271, in launch
    self._launch_plugin(phase)
  File "C:\Users\Makhno\faceswap\plugins\extract\pipeline.py", line 700, in _launch_plugin
    plugin.initialize(**kwargs)
  File "C:\Users\Makhno\faceswap\plugins\extract\align\_base\aligner.py", line 199, in initialize
    super().initialize(*args, **kwargs)
  File "C:\Users\Makhno\faceswap\plugins\extract\_base.py", line 482, in initialize
    self.init_model()
  File "C:\Users\Makhno\faceswap\plugins\extract\align\fan.py", line 50, in init_model
    self.model.predict(placeholder)
  File "C:\Users\Makhno\faceswap\lib\model\session.py", line 105, in predict
    return self._model.predict(feed, verbose=0, batch_size=batch_size)
  File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\tensorflow\python\eager\execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.UnimplementedError: Graph execution error:

Detected at node 'model_1/conv1/Conv2D' defined at (most recent call last):
    File "C:\Users\Makhno\faceswap\faceswap.py", line 56, in <module>
      _main()
    File "C:\Users\Makhno\faceswap\faceswap.py", line 52, in _main
      arguments.func(arguments)
    File "C:\Users\Makhno\faceswap\lib\cli\launcher.py", line 222, in execute_script
      process.process()
    File "C:\Users\Makhno\faceswap\scripts\extract.py", line 165, in process
      extract.process()
    File "C:\Users\Makhno\faceswap\scripts\extract.py", line 689, in process
      self._run_extraction()
    File "C:\Users\Makhno\faceswap\scripts\extract.py", line 709, in _run_extraction
      self._extractor.launch()
    File "C:\Users\Makhno\faceswap\plugins\extract\pipeline.py", line 271, in launch
      self._launch_plugin(phase)
    File "C:\Users\Makhno\faceswap\plugins\extract\pipeline.py", line 700, in _launch_plugin
      plugin.initialize(**kwargs)
    File "C:\Users\Makhno\faceswap\plugins\extract\align\_base\aligner.py", line 199, in initialize
      super().initialize(*args, **kwargs)
    File "C:\Users\Makhno\faceswap\plugins\extract\_base.py", line 482, in initialize
      self.init_model()
    File "C:\Users\Makhno\faceswap\plugins\extract\align\fan.py", line 50, in init_model
      self.model.predict(placeholder)
    File "C:\Users\Makhno\faceswap\lib\model\session.py", line 105, in predict
      return self._model.predict(feed, verbose=0, batch_size=batch_size)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\training.py", line 2253, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\training.py", line 2041, in predict_function
      return step_function(self, iterator)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\training.py", line 2027, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\training.py", line 2015, in run_step
      outputs = model.predict_step(data)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\training.py", line 1983, in predict_step
      return self(x, training=False)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\layers\convolutional\base_conv.py", line 283, in call
      outputs = self.convolution_op(inputs, self.kernel)
    File "C:\Users\Makhno\MiniConda3\envs\faceswap\lib\site-packages\keras\layers\convolutional\base_conv.py", line 255, in convolution_op
      return tf.nn.convolution(
Node: 'model_1/conv1/Conv2D'
DNN library is not found.
	 [[{{node model_1/conv1/Conv2D}}]] [Op:__inference_predict_function_18866]
01/07/2023 18:48:43 CRITICAL An unexpected crash has occurred. Crash report written to 'C:\Users\Makhno\faceswap\crash_report.2023.01.07.184843827561.log'. You MUST provide this file if seeking assistance. Please verify you are running the latest version of faceswap before reporting

@thekevinscott

thekevinscott commented Jan 29, 2023

I arrived here after googling, having run into the same DNN library error.

The comment from @gabehope helped me resolve my problem. Specifically, I was running both TensorFlow and JAX in the same script, and presumably they were both fighting for GPU memory.

For reference, here's the (quite helpful!) page on memory allocation with JAX.

It would be helpful if the error indicated more clearly that it's a memory issue, though it sounds like others may be hitting a different problem from the one Gabe and I ran into.

@ZhenhuiL1n

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/


(Quoting @gabehope's earlier comment about hiding GPUs from TensorFlow with tf.config.experimental.set_visible_devices([], 'GPU').)

I solved it by doing this. Thanks a lot!

@Micky774
Collaborator

We've added an FAQ section addressing various CUDA library loading issues and solutions/workarounds, and have (hopefully) made it easier to find by including it in some error messages that often correlate to these memory starvation issues.

I'm going to go ahead and close this specific issue, since the FAQ documentation should provide proper workarounds.
