In [1]:
# Show the version of the CUDA compiler (nvcc) available on this machine.
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0


In [2]:
# Optional one-time setup: uncomment the next line to install/upgrade JAX with CUDA support.
#!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [2]:
# Run nvidia-smi and either show its report or, when the command reports a
# failure, tell the user how to enable a GPU runtime.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if 'failed' not in gpu_info:
  print(gpu_info)
else:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')

Sun Aug  4 12:16:42 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.27                 Driver Version: 560.70         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 ...    On  |   00000000:01:00.0  On |                  N/A |
|  0%   42C    P8              9W /  285W |    2029MiB /  16376MiB |     48%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
# Report which platform JAX picked as its default backend ('cpu', 'gpu' or 'tpu').
# jax.default_backend() is the public API for this query; the previously used
# jax.lib.xla_bridge module is internal and deprecated in recent JAX releases.
import jax

print(jax.default_backend())

gpu


In [4]:
import jax

In [5]:
# List the devices of JAX's default backend.
jax.devices()

[cuda(id=0)]

In [6]:
# List the CPU devices; the CPU backend is queried explicitly by name.
jax.devices("cpu")

[CpuDevice(id=0)]

In [7]:
# List the GPU devices; the GPU backend is queried explicitly by name.
jax.devices("gpu")

[cuda(id=0)]

In [8]:
import numpy as np
import jax.numpy as jnp

In [9]:
# a function with some amount of calculations
def f(x):
    """Return (x + x*x + 3) * (x*x + x*x.T), evaluated elementwise.

    `x` is any array-like object that supports elementwise `+` and `*`
    and exposes a `.T` attribute (e.g. a NumPy or JAX array).
    """
    polynomial_term = (x + (x * x)) + 3
    transpose_term = (x * x) + (x * x.T)
    return polynomial_term * transpose_term

# generate some random data (seeded, so every rerun benchmarks the same input)
MATRIX_SIZE = 3000
SEED = 0
x = np.random.default_rng(SEED).standard_normal(
    (MATRIX_SIZE, MATRIX_SIZE), dtype=np.float32
)

# Place one copy of the data on each device. The NumPy array is handed straight
# to device_put: wrapping it in jnp.array first would materialise it on the
# default device and then copy it again to the requested one.
jax_x_gpu = jax.device_put(x, jax.devices('gpu')[0])
jax_x_cpu = jax.device_put(x, jax.devices('cpu')[0])

# compile function to CPU and GPU backends with JAX
jax_f_cpu = jax.jit(f, backend='cpu')
jax_f_gpu = jax.jit(f, backend='gpu')

# warm-up: the first call triggers compilation; block_until_ready() waits for the
# asynchronously dispatched computation so none of it leaks into later timings.
jax_f_cpu(jax_x_cpu).block_until_ready()
jax_f_gpu(jax_x_gpu).block_until_ready()

Array([[ 1.71871048e+02,  8.19977474e+00,  1.06485176e+01, ...,
         1.96896732e+00,  5.80587924e-01,  4.19012547e+00],
       [-2.37301445e+00,  1.43541393e+01,  7.59780049e-01, ...,
         3.30399424e-01,  3.34734589e-01, -5.36943972e-02],
       [ 2.06122375e+00,  1.69826183e+01,  8.65360451e+00, ...,
         4.29970074e+00, -1.19282775e-01,  8.49998474e+00],
       ...,
       [ 2.40623260e+00,  2.07550123e-01,  6.37984180e+00, ...,
         2.15799923e+01,  8.03496659e-01,  7.38873065e-01],
       [ 2.94676375e+00, -1.61128402e-01,  2.05881566e-01, ...,
         6.57972383e+00,  3.36179519e+00,  2.01247334e+00],
       [ 3.71754098e+00,  8.14146709e+00, -4.31723166e+00, ...,
         3.63160610e-01, -1.76449144e+00,  6.14816017e+01]],      dtype=float32)

In [10]:
# Baseline: f evaluated on the NumPy array x.
%timeit -n100 f(x)

48 ms ± 320 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
# f called directly (not through jax.jit) on the array placed on the CPU device.
# block_until_ready() makes the measurement wait for the computed result.
%timeit -n100 f(jax_x_cpu).block_until_ready()

41.6 ms ± 415 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# jit-compiled CPU version of f on the array placed on the CPU device.
%timeit -n100 jax_f_cpu(jax_x_cpu).block_until_ready()

7.52 ms ± 64.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
# f called directly (not through jax.jit) on the array placed on the GPU device.
# block_until_ready() makes the measurement wait for the computed result.
%timeit -n100 f(jax_x_gpu).block_until_ready()

1.94 ms ± 558 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
# jit-compiled GPU version of f on the array placed on the GPU device.
%timeit -n100 jax_f_gpu(jax_x_gpu).block_until_ready()

1.41 ms ± 379 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
# Show the set of devices holding the CPU copy of the input.
jax_x_cpu.devices()

{CpuDevice(id=0)}

In [17]:
# Show the set of devices holding the GPU copy of the input.
jax_x_gpu.devices()

{cuda(id=0)}