
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm #8506

Closed
pseudo-rnd-thoughts opened this issue Nov 10, 2021 · 16 comments
Labels: bug (Something isn't working)

@pseudo-rnd-thoughts commented Nov 10, 2021

Running convolutional layers causes an error where JAX fails to determine which cuDNN convolution algorithm to use.
The error appears to be JAX-specific: I have replicated the code with TensorFlow and no error occurs.

My jax version is 0.2.24 and my jaxlib version is 0.1.74+cuda11.cudnn82, running on an Nvidia 3080.

The example is taken from the flax readme (https://github.com/google/flax).
The bug appears to affect only convolutions, as the error does not occur for the MLP example.

I haven't been able to check whether the error reproduces elsewhere, as I don't have another GPU to test with.
I did find a similar issue from someone who also uses a 3080 (#7953).

import jax
import jax.numpy as jnp
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
# output = model.apply(variables, batch)
Traceback (most recent call last):
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
    variables = model.init(jax.random.PRNGKey(0), batch)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 998, in init
    _, v_out = self.init_with_output(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 968, in init_with_output
    return self.apply(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 936, in apply
    return apply(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/core/scope.py", line 687, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1178, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 270, in __call__
    y = lax.conv_general_dilated(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 653, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 624, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 311, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 187, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 180, in cached
    return f(*args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 334, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 653, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 769, in compile
    self._executable = XlaCompiledComputation.from_xla_computation(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 798, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 87, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 369, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n  batch_group_count=1\n  dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n  feature_group_count=1\n  lhs_dilation=(1, 1)\n  lhs_shape=(32, 64, 64, 10)\n  padding=((1, 1), (1, 1))\n  precision=None\n  preferred_element_type=None\n  rhs_dilation=(1, 1)\n  rhs_shape=(3, 3, 10, 32)\n  window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=270}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.  To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

Process finished with exit code 1
pseudo-rnd-thoughts added the bug (Something isn't working) label on Nov 10, 2021
@zhangqiaorjc (Member) commented Nov 17, 2021

I couldn't repro this in my environment; it's likely RTX 3080 specific. I'm asking our GPU experts to take a look.

zhangqiaorjc@skyewm-gpu-vm2:~$ cat issue_8506.py
import jax
import jax.numpy as jnp
import flax.linen as nn
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
# output = model.apply(variables, batch)
zhangqiaorjc@skyewm-gpu-vm2:~$ python3 issue_8506.py
zhangqiaorjc@skyewm-gpu-vm2:~$ python3
Python 3.8.10 (default, Sep 28 2021, 16:10:42) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jaxlib
>>> jax.__version__
'0.2.25'
>>> jaxlib.__version__
'0.1.73'
>>> 
zhangqiaorjc@skyewm-gpu-vm2:~$ nvidia-smi
Wed Nov 17 19:31:07 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    45W / 300W |    109MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   37C    P0    45W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:00:06.0 Off |                    0 |
| N/A   39C    P0    46W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:00:07.0 Off |                    0 |
| N/A   37C    P0    44W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  Tesla V100-SXM2...  Off  | 00000000:00:08.0 Off |                    0 |
| N/A   37C    P0    44W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   5  Tesla V100-SXM2...  Off  | 00000000:00:09.0 Off |                    0 |
| N/A   37C    P0    43W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   6  Tesla V100-SXM2...  Off  | 00000000:00:0A.0 Off |                    0 |
| N/A   39C    P0    43W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   7  Tesla V100-SXM2...  Off  | 00000000:00:0B.0 Off |                    0 |
| N/A   40C    P0    43W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

@atondwal (Contributor) commented Nov 17, 2021

Yep, this is Ampere-specific, and I was able to repro on an A6000 using the previous release. Yesterday's release of jaxlib 0.1.74 fixes it on my machine: can you try that?

@pseudo-rnd-thoughts (Author) commented Nov 18, 2021

The latest release doesn't fix it

Traceback (most recent call last):
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
    variables = model.init(jax.random.PRNGKey(0), batch)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 884, in init
    _, v_out = self.init_with_output(rngs, *args, method=method, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 862, in init_with_output
    return self.apply(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 841, in apply
    return apply(fn, mutable=mutable)(variables, rngs=rngs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/core/scope.py", line 608, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 834, in <lambda>
    fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 269, in __call__
    y = lax.conv_general_dilated(
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/_src/lax/lax.py", line 695, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/core.py", line 274, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/core.py", line 626, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 419, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/_src/util.py", line 201, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/_src/util.py", line 194, in cached
    return f(*args, **kwargs)
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 442, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 768, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 903, in compile
    self._executable = XlaCompiledComputation.from_xla_computation(
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 932, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options)
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 871, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 478, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n  batch_group_count=1\n  dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n  feature_group_count=1\n  lhs_dilation=(1, 1)\n  lhs_shape=(32, 64, 64, 10)\n  padding=((1, 1), (1, 1))\n  precision=None\n  preferred_element_type=None\n  rhs_dilation=(1, 1)\n  rhs_shape=(3, 3, 10, 32)\n  window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=269}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n  batch_group_count=1\n  dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n  feature_group_count=1\n  lhs_dilation=(1, 1)\n  lhs_shape=(32, 64, 64, 10)\n  padding=((1, 1), (1, 1))\n  precision=None\n  preferred_element_type=None\n  rhs_dilation=(1, 1)\n  rhs_shape=(3, 3, 10, 32)\n  window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=269}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine 1#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 1: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 0#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 0: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 2#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 2: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 4#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 4: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 6#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 6: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 5#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 5: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 7#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 7: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.
 jax.__version__ = '0.2.26'
 jaxlib.__version__ = '0.1.74'

$ nvidia-smi
Thu Nov 18 22:12:41 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.00    Driver Version: 470.82.00    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:07:00.0  On |                  N/A |
|  0%   45C    P8    32W / 340W |    344MiB / 10014MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1276      G   /usr/lib/xorg/Xorg                 35MiB |
|    0   N/A  N/A      1824      G   /usr/lib/xorg/Xorg                129MiB |
|    0   N/A  N/A      1959      G   /usr/bin/gnome-shell               53MiB |
|    0   N/A  N/A     81757      G   ...AAAAAAAAA= --shared-files      105MiB |
|    0   N/A  N/A     84116      G   ..._82451.log --shared-files        3MiB |
+-----------------------------------------------------------------------------+

@hawkinsp (Member)

What version of CuDNN do you have installed?

@pseudo-rnd-thoughts (Author)

Sorry, I hadn't seen your reply.

I have the latest CUDA and cuDNN versions, 11.5 and 8.3.1.
I'm happy to test any other versions that you suggest.

I have tried building JAX from source for my machine, but that hasn't worked either.

Any other suggestions to try?

@mattiasmar commented Dec 9, 2021

@pseudo-rnd-thoughts:
Issue #8302 solved this problem for me when running the Flax ImageNet example (set the environment variable TF_FORCE_GPU_ALLOW_GROWTH before the tf datasets are created; see the sketch below).
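A minimal sketch of that workaround, assuming a TensorFlow/TFDS input pipeline like the Flax ImageNet example uses (the dataset name is only for illustration); the key point is that the variable must be set before TensorFlow initialises the GPU:

import os

# Must be set before TensorFlow touches the GPU; otherwise TF preallocates
# GPU memory and cuDNN autotuning in JAX can run out of room.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

import tensorflow as tf
import tensorflow_datasets as tfds

# Optionally hide the GPU from TensorFlow entirely if it only runs the input pipeline.
tf.config.set_visible_devices([], "GPU")

ds = tfds.load("mnist", split="train")  # placeholder dataset, for illustration only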

@pseudo-rnd-thoughts (Author) commented Dec 13, 2021

Sadly that hasn't fixed it either; I get the same error about cuDNN convolutions.

This is really strange, as a couple of people have had this bug but have all got their code working. I'm not sure what is unusual about my system.
The equivalent TensorFlow code doesn't throw an error.
Any ideas on what to try?

== Version information
Jax version: 0.2.25
Jaxlib version: 0.1.73
Cuda: 11.5
Cudnn: 8.3.1
TF_FORCE_GPU_ALLOW_GROWTH - true
Ubuntu 20.04.3 LTS

== Error

Traceback (most recent call last):
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
    variables = model.init(jax.random.PRNGKey(0), batch)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1122, in init
    _, v_out = self.init_with_output(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1091, in init_with_output
    return self.apply(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1058, in apply
    return apply(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/core/scope.py", line 706, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1313, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/transforms.py", line 883, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 318, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 603, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/transforms.py", line 883, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 318, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 603, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 282, in __call__
    y = lax.conv_general_dilated(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 653, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 624, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 416, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 187, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 180, in cached
    return f(*args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 439, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 759, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 892, in compile
    self._executable = XlaCompiledComputation.from_xla_computation(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 921, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 863, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 474, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n  batch_group_count=1\n  dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n  feature_group_count=1\n  lhs_dilation=(1, 1)\n  lhs_shape=(32, 64, 64, 10)\n  padding=((1, 1), (1, 1))\n  precision=None\n  preferred_element_type=None\n  rhs_dilation=(1, 1)\n  rhs_shape=(3, 3, 10, 32)\n  window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.  To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
    variables = model.init(jax.random.PRNGKey(0), batch)
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 282, in __call__
    y = lax.conv_general_dilated(
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n  batch_group_count=1\n  dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n  feature_group_count=1\n  lhs_dilation=(1, 1)\n  lhs_shape=(32, 64, 64, 10)\n  padding=((1, 1), (1, 1))\n  precision=None\n  preferred_element_type=None\n  rhs_dilation=(1, 1)\n  rhs_shape=(3, 3, 10, 32)\n  window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.  To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

Process finished with exit code 1

== Jax / Flax Code
This is the example from the github.com/google/flax README.

import jax
import jax.numpy as jnp
import flax.linen as nn


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x


model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

== TensorFlow code

import numpy as np
import tensorflow as tf


model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3)),
    tf.keras.layers.ReLU(),
    tf.keras.layers.AvgPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3)),
    tf.keras.layers.ReLU(),
    tf.keras.layers.AvgPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10, activation='softplus')
])

batch = np.ones((32, 64, 64, 10))
output = model(batch)

Thanks

@mattiasmar

@pseudo-rnd-thoughts Go to the folder /usr/local and check which CUDA installation you have (in my case it is cuda-11.3, as I work with the docker image nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04).
Then, under the user you are working with (e.g. su pseudo-rnd-thoughts), type:

export PATH=/usr/local/cuda-11.3/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11.3/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

Replace 11.3 above with whichever CUDA version you have installed.

@rems75 commented Jan 6, 2022

I'm seeing the same issue on a Quadro T2000; I tried the various fixes above and none of them worked.

== Version information
Jax version: 0.2.26
Jaxlib version: 0.1.75
Cuda: 11.5
Cudnn: 8.3.0

@pseudo-rnd-thoughts (Author)

Fixed it. It is a memory allocation issue like the one suggested below, but slightly different:
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.7

I found a previous discussion with a very similar problem to mine:
#6332

The discussion explains how JAX allocates GPU memory: by default it preallocates 90% of it on the first JAX operation, which in our case is the convolution.
As the GPU also drives my display, I think there isn't enough memory left for JAX to allocate 90% of it. A minimal sketch of the workaround follows.
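A minimal sketch of the workaround, setting the fraction from Python before JAX is imported (the exact value is machine-dependent; 0.7 works on my 10GB 3080):

import os

# Limit JAX's GPU preallocation so cuDNN autotuning has memory left over.
# This must be set before jax is imported.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import jax
import jax.numpy as jnp

print(jax.devices())            # confirm the GPU backend is active
x = jnp.ones((32, 64, 64, 10))  # same input shape as the failing example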

@rems75 does this fix the issue for you?
If so, I think we can close the issue

@hawkinsp (Member)

@pseudo-rnd-thoughts I think the fix is for us to reserve a minimum absolute amount of GPU RAM for cuDNN. How much GPU RAM do you have? Is 0.7 the largest value that works? For example, does 0.8 work?

@pseudo-rnd-thoughts (Author)

@hawkinsp I have an Nvidia 3080 with 10GB of RAM.

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:07:00.0  On |                  N/A |
|  0%   52C    P5    43W / 340W |    287MiB / 10016MiB |     31%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1195      G   /usr/lib/xorg/Xorg                 35MiB |
|    0   N/A  N/A      1785      G   /usr/lib/xorg/Xorg                153MiB |
|    0   N/A  N/A      1913      G   /usr/bin/gnome-shell               42MiB |
|    0   N/A  N/A     27970      G   ...AAAAAAAAA= --shared-files       40MiB |

I did a bit of testing: 80% and 85% are fine, while 90% causes the crash.
So I don't think the issue is a minimum absolute amount of GPU RAM, because requiring 9GB (90%) seems too much to me.
However, if the nvidia-smi output is correct, my system was only using ~300MiB of RAM while testing, i.e. ~3% of the total, so I don't understand why 90% usage causes a problem.

@hawkinsp do you have any other questions?
I'm happy to be a guinea pig to see if there is a larger underlying issue.

@hawkinsp (Member)

@pseudo-rnd-thoughts No, that seems roughly in line with what I'd expect. You have 10016MiB, of which JAX claims 90% (~9014MiB). Your system processes claim another ~300MiB, for ~9314MiB in total, leaving only ~700MiB for cuDNN. That is apparently not enough. I think the way to fix this is for JAX to ensure that at least, say, 1GiB is left free after its allocation so that cuDNN can work. I don't know what the right value for "1GiB" is, but clearly ~700MiB is too low. A back-of-envelope sketch of the arithmetic follows.
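For reference, a rough check of the numbers above (values in MiB, taken from the nvidia-smi output earlier in the thread):

total = 10016                   # total GPU memory on the 3080
jax_prealloc = 0.9 * total      # ~9014 MiB claimed by JAX by default
system = 300                    # ~300 MiB used by Xorg / gnome-shell etc.
left_for_cudnn = total - jax_prealloc - system
print(left_for_cudnn)           # ~702 MiB left, apparently too little for cuDNN autotuning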

@pseudo-rnd-thoughts (Author)

@hawkinsp Thanks, I had imagined that the cuDNN memory usage would come out of the JAX-preallocated amount. That makes a lot of sense now.

@rems75 commented Jan 13, 2022

Worked for me as well. Cool.

The numbers are different in my case. This is the GPU memory usage with XLA_PYTHON_CLIENT_MEM_FRACTION=0.8:
before: 609MiB / 4096MiB
during: 4027MiB / 4096MiB
So it seems like my cuDNN only uses ~130MiB?

Note that this is through WSL2 on a laptop running Windows 11.

@sg879 commented Feb 15, 2022

Hi everyone,

I had the exact same issue described above. I am running on WSL2 on Windows 10. I installed CUDA and cuDNN and then installed jax[gpu] via pip. After setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.87, my program works perfectly on a 3080, but with 0.9 the same RuntimeError is thrown.

nvidia-smi before updating the memory fraction:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.52       Driver Version: 511.79       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   42C    P8    12W /  N/A |   7530MiB /  8192MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      7754      C   /python3.8                      N/A      |
+-----------------------------------------------------------------------------+

nvidia-smi after updating the memory fraction:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.52       Driver Version: 511.79       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   53C    P0    57W /  N/A |   7476MiB /  8192MiB |     48%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      7958      C   /python3.8                      N/A      |
+-----------------------------------------------------------------------------+

I am very new to any sort of collaboration on repositories, so apologies if my etiquette is somewhat off, but I was wondering whether there have been any updates on this? Are there any "best practice" ways to correct it? I am currently setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.87 in my bash ~/.profile file and then just running JAX as is.

Also, slightly unrelated, but what would be the best way to keep up with updates to this repository? I will be using JAX pretty religiously to build SVGP models, as I love its flexibility, and so would like to keep up to date.

Thanks for any help!
