In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random


In [2]:
# Seed JAX's explicit PRNG state and draw ten standard-normal samples from it.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)


[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU


9.75 ms ± 1.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
def sum_logistic(x):
  """Return the sum over all elements of the logistic sigmoid of ``x``.

  The sigmoid is written as ``0.5 * (1 + tanh(x / 2))``, which is
  mathematically identical to ``1 / (1 + exp(-x))``. The tanh form keeps both
  the value and its gradient finite for very negative ``x``. With the exp
  form, ``exp(-x)`` overflows to ``inf`` (float32, x below about -88), and
  reverse-mode ``grad`` then multiplies ``0 * inf`` and returns ``nan``.
  """
  return jnp.sum(0.5 * (jnp.tanh(x / 2) + 1.0))

x_small = jnp.arange(3.)
# grad() returns a new function computing d(sum_logistic)/dx element-wise.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [5]:
# Display the inputs at which the gradient above was evaluated.
x_small

DeviceArray([0., 1., 2.], dtype=float32)

In [6]:
# Show what grad() returned: a plain Python function whose repr reuses the
# name of the wrapped `sum_logistic`.
derivative_fn

<function __main__.sum_logistic(x)>

In [7]:
# Draw the matrix and the batch from independent keys. JAX's PRNG is
# functional, and reusing one key for several draws makes them statistically
# dependent. Split the key once per independent use instead.
key_mat, key_batch = random.split(key)
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_batch, (10, 100))

def apply_matrix(v):
  """Return ``mat @ v`` for a single vector ``v`` of length 100.

  ``mat`` has shape (150, 100), so a length-100 ``v`` maps to a length-150
  result. The batched variants below apply this to each row of a (batch, 100)
  array.
  """
  return jnp.dot(mat, v)

In [9]:
def naively_batched_apply_matrix(v_batched):
  """Unvectorized baseline: call ``apply_matrix`` once per row, then stack.

  Iterating over ``v_batched`` yields its rows one at a time, so this issues
  one separate ``apply_matrix`` call per batch element from Python.
  """
  per_row = map(apply_matrix, v_batched)
  return jnp.stack(list(per_row))

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
3.07 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
@jit
def batched_apply_matrix(v_batched):
  """Apply ``mat`` to every row of ``v_batched`` with one matrix product.

  For a 2-D ``v_batched`` of shape (batch, 100), row ``i`` of the result is
  ``mat @ v_batched[i]``. The function is jitted, so the batch runs as a
  single compiled computation.
  """
  mat_t = jnp.transpose(mat)
  return jnp.dot(v_batched, mat_t)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
56.1 µs ± 3.02 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
@jit
def vmap_batched_apply_matrix(v_batched):
  """Vectorize ``apply_matrix`` over the leading axis of ``v_batched``.

  ``vmap`` maps the single-vector function over axis 0, and ``jit`` compiles
  the result, so no Python-level loop runs per batch element.
  """
  per_row_fn = vmap(apply_matrix, in_axes=0)
  return per_row_fn(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
58 µs ± 3.61 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
# NOTE: `! #export XLA_FLAGS=...` does nothing, for three reasons:
#   - the shell line is commented out;
#   - `fals` must be `false`;
#   - even a working `!export` only changes a throwaway subshell, never this
#     kernel's environment.
# XLA_FLAGS must also be in place before JAX initialises its backend, i.e.
# before the first JAX computation. To enable the fallback, make this the
# notebook's first cell:
#
#   %env XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false
#
# That flag only lets XLA fall back to a possibly slower conv algorithm. The
# "DNN library is not found" error reported by the convolution cells suggests
# cuDNN itself is missing or not visible, which likely needs fixing in the
# environment.

In [18]:
# Full 1-D convolution of a short integer signal with a length-10 box filter;
# jnp.convolve promotes the integer input to match the float `y`.
x, y = jnp.array([1, 2, 1]), jnp.ones(10)
jnp.convolve(x, y)

2023-01-10 01:49:49.773081: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-01-10 01:49:49.776399: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-01-10 01:49:49.778660: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR


XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[1,1,12]{2,1,0}, u8[0]{0}) custom-call(f32[1,1,10]{2,1,0} %bitcast, f32[1,1,3]{2,1,0} %bitcast.1), window={size=3 pad=2_2}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(convolve)/jit(main)/jit(_conv)/conv_general_dilated[window_strides=(1,) padding=((2, 2),) lhs_dilation=(1,) rhs_dilation=(1,) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2), rhs_spec=(0, 1, 2), out_spec=(0, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 1, 10) rhs_shape=(1, 1, 3) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_522/2973065201.py" source_line=3}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: UNIMPLEMENTED: DNN library is not found.

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

In [20]:
from jax import lax

# Reproduce jnp.convolve(x, y) with the lower-level lax primitive.
# - lax.conv_general_dilated computes a cross-correlation (the kernel is not
#   flipped), while jnp.convolve flips it. Reversing `y` makes this match for
#   any kernel, not only a symmetric one like all-ones.
# - Operands use the default (batch, channel, width) layout; reshape(1, 1, -1)
#   adapts to the actual lengths instead of hard-coding them.
# - Both operands must share a dtype, hence the explicit promotion of `x`.
result = lax.conv_general_dilated(
    x.reshape(1, 1, -1).astype(float),  # note: explicit promotion
    y[::-1].reshape(1, 1, -1),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]


2023-01-10 01:50:22.382976: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-01-10 01:50:22.387506: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-01-10 01:50:22.389924: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR


XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[1,1,12]{2,1,0}, u8[0]{0}) custom-call(f32[1,1,3]{2,1,0} %Arg_0.1, f32[1,1,10]{2,1,0} %Arg_1.2), window={size=10 pad=9_9}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1,) padding=((9, 9),) lhs_dilation=(1,) rhs_dilation=(1,) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2), rhs_spec=(0, 1, 2), out_spec=(0, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 1, 3) rhs_shape=(1, 1, 10) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_522/4009118831.py" source_line=2}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: UNIMPLEMENTED: DNN library is not found.

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.