# System Info

Get GPU, CUDA, and cuDNN info:

In [1]:
!nvidia-smi

Fri Feb 26 09:20:38 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            On   | 00000000:00:1E.0 Off |                    0 |
| N/A   31C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces
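The same GPU details can be read in machine-readable form through `nvidia-smi`'s query mode (`--query-gpu` with `--format=csv,noheader`). The following is a minimal sketch. The field list and the helper names `parse_gpu_csv` and `query_gpus` are illustrative choices, not part of the original notebook.

```python
import csv
import io
import shutil
import subprocess

# Fields passed to `nvidia-smi --query-gpu`; adjust as needed (illustrative choice).
FIELDS = ["name", "driver_version", "memory.total"]


def parse_gpu_csv(text):
    """Parse `--format=csv,noheader` output into one dict per GPU."""
    rows = csv.reader(io.StringIO(text.strip()))
    return [dict(zip(FIELDS, (col.strip() for col in row))) for row in rows if row]


def query_gpus():
    """Return a list of GPU info dicts, or [] if nvidia-smi is not on PATH."""
    if shutil.which("nvidia-smi") is None:
        return []
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={','.join(FIELDS)}", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return parse_gpu_csv(out)


print(query_gpus())
```

On the machine above, each row would carry the same name, driver version and total memory shown in the table (Tesla T4, 450.80.02, 15109 MiB).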

In [2]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0


In [3]:
# from https://stackoverflow.com/questions/31326015/how-to-verify-cudnn-installation/36978616
!cat /usr/include/cudnn_version.h | grep '#define CUDNN_MAJOR' -A 2

#define CUDNN_MAJOR 8
#define CUDNN_MINOR 0
#define CUDNN_PATCHLEVEL 5
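The `grep` above prints the three `#define` lines. To get the version as a tuple in Python, you can parse the header directly. This is a sketch that assumes the header lives at the same path used in the cell above. cuDNN 7 and earlier define these macros in `cudnn.h` instead. The helper name `parse_cudnn_version` is hypothetical.

```python
import re
from pathlib import Path

# Path used in the cell above; other installations may place it elsewhere (assumption).
CUDNN_HEADER = Path("/usr/include/cudnn_version.h")


def parse_cudnn_version(header_text):
    """Extract (major, minor, patchlevel) from the #define lines of the header."""
    parts = []
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        # \s+ after the name avoids matching longer macro names.
        match = re.search(rf"#define\s+{name}\s+(\d+)", header_text)
        if match is None:
            raise ValueError(f"{name} not found in header")
        parts.append(int(match.group(1)))
    return tuple(parts)


if CUDNN_HEADER.exists():
    print("cuDNN %d.%d.%d" % parse_cudnn_version(CUDNN_HEADER.read_text()))
```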


Check whether JAX can detect the GPU:

In [4]:
# from https://github.com/google/jax/issues/971
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu
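More recent JAX versions provide public functions that report the same information without going through `jax.lib.xla_bridge`. `jax.default_backend()` returns the platform name of the default backend, and `jax.devices()` lists the devices visible to it. This is a minimal sketch. The exact device string format varies between JAX versions.

```python
import jax

# Platform of the default backend: "gpu" on the machine above, "cpu" without an accelerator.
print(jax.default_backend())

# Every device JAX can see on that backend, with its platform and index.
for device in jax.devices():
    print(device.platform, device.id, device)
```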


# Demo

In [5]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

In [6]:
key = random.PRNGKey(0)
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)

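Time `x @ x.T` on the GPU with JAX. `block_until_ready()` is needed because JAX dispatches work asynchronously. Without it, `%timeit` would measure only the dispatch: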

In [7]:
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

13 ms ± 3.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
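`%timeit` settled on a single loop per run, and the spread is large (±3.71 ms). Both are consistent with the first call paying a one-off XLA compilation cost for the `dot` operation, although the output alone does not prove it. The sketch below wraps the product in `jax.jit` and runs one warm-up call before timing, so compilation is excluded. The helper names `gram` and `benchmark` are hypothetical, and the repeat count is arbitrary.

```python
import time

import jax
import jax.numpy as jnp
from jax import random


@jax.jit
def gram(x):
    # Same computation as the %timeit cell: x multiplied by its transpose.
    return jnp.dot(x, x.T)


def benchmark(fn, x, repeats=10):
    """Mean wall-clock seconds per call, excluding the first (compiling) call."""
    fn(x).block_until_ready()  # warm-up: triggers tracing and compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()  # wait for the asynchronously dispatched result
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    key = random.PRNGKey(0)
    x = random.normal(key, (3000, 3000), dtype=jnp.float32)
    print(f"{benchmark(gram, x) * 1e3:.1f} ms per call")
```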


Time the same product on the CPU with NumPy:

In [8]:
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)

122 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
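To put the two timings on a common scale, note that multiplying two n×n matrices takes about 2n³ floating-point operations (n³ multiplies and n³ adds). The following sketch converts the mean times reported above into throughput. The helper names are illustrative.

```python
def matmul_flops(n):
    """Approximate FLOPs for an (n x n) @ (n x n) product: n^3 multiplies + n^3 adds."""
    return 2 * n**3


def tflops(n, seconds):
    """Throughput in TFLOP/s for one n x n product taking `seconds`."""
    return matmul_flops(n) / seconds / 1e12


n = 3000
gpu_s, cpu_s = 13e-3, 122e-3  # mean times reported by %timeit above
print(f"GPU (JAX):   {tflops(n, gpu_s):.2f} TFLOP/s")
print(f"CPU (NumPy): {tflops(n, cpu_s):.2f} TFLOP/s")
print(f"speedup:     {cpu_s / gpu_s:.1f}x")
```

With these numbers, the GPU reaches roughly 4.15 TFLOP/s and the CPU roughly 0.44 TFLOP/s, a speedup of about 9.4x. Treat these as rough figures. The GPU mean carries a large spread (±3.71 ms), and the CPU figure depends on which BLAS library NumPy is linked against.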
