In [1]:
# Optional: uncomment to print the installed CUDA compiler (nvcc) version.
# !nvcc --version

In [2]:
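# Optional one-time setup: uncomment to install/upgrade the CUDA build of JAX.
# (Prefer `%pip` over `!pip` so the install targets this kernel's environment.)
#!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html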

In [3]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Thu Aug 21 22:59:45 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.64.03              Driver Version: 575.64.03      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1080 Ti     Off |   00000000:05:00.0  On |                  N/A |
| 39%   63C    P8             22W /  250W |     368MiB /  11264MiB |     30%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce GTX 1080 Ti     Off |   00

In [4]:
# from jax.lib import xla_bridge
# print(xla_bridge.get_backend().platform)
from jax.extend import backend
print(backend.get_backend().platform)


gpu


In [5]:
import jax

In [6]:
# Show the devices JAX reports when no backend is named.
jax.devices()

[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]

In [7]:
# Show the devices JAX reports for the "cpu" backend.
jax.devices("cpu")

[CpuDevice(id=0)]

In [8]:
# Show the devices JAX reports for the "gpu" backend.
jax.devices("gpu")

[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]

In [9]:
import numpy as np
import jax.numpy as jnp

In [None]:
# Benchmark workload: a handful of elementwise array operations.
def f(x):
  """Return (x + x*x + 3) * (x*x + x*x.T), with every product elementwise.

  Only `+`, `*` and `.T` are used, so `x` can be any array type supporting
  them; this notebook calls it with both NumPy and JAX arrays. Because `x`
  is multiplied elementwise by its own transpose, the two shapes must
  broadcast together (e.g. a square matrix).
  """
  poly_term = x + x*x + 3
  mixed_term = x*x + x*x.T
  return poly_term * mixed_term

# Generate some random data. The generator is seeded so that every
# Restart & Run All benchmarks exactly the same float32 input.
SEED = 0
rng = np.random.default_rng(SEED)
x = rng.standard_normal((3000, 3000), dtype=np.float32)

# Place one copy of the data on the first GPU and one on the CPU device.
# device_put takes the host NumPy array directly; wrapping it in jnp.array
# first would create it on the default device and then copy it a second time.
jax_x_gpu = jax.device_put(x, jax.devices('gpu')[0])
jax_x_cpu = jax.device_put(x, jax.devices('cpu')[0])

# Compile function to CPU and GPU backends with JAX
jax_f_cpu = jax.jit(f, backend='cpu')
jax_f_gpu = jax.jit(f, backend='gpu')

# Warm-up: the first call triggers compilation. JAX dispatches work
# asynchronously, so block until each result is ready; otherwise leftover
# warm-up work could still be running when the timing cells start.
jax_f_cpu(jax_x_cpu).block_until_ready()
jax_f_gpu(jax_x_gpu).block_until_ready();

In [11]:
# Baseline: plain (un-jitted) f applied to the NumPy array x.
%timeit -n100 f(x)

89.6 ms ± 602 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# Un-jitted f on the JAX array placed on the CPU device; block_until_ready()
# makes the timing wait for the result instead of just the dispatch.
%timeit -n100 f(jax_x_cpu).block_until_ready()

85.9 ms ± 2.99 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
# jit-compiled f (backend='cpu') on the CPU-resident array.
%timeit -n100 jax_f_cpu(jax_x_cpu).block_until_ready()

17.6 ms ± 525 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
# Un-jitted f on the JAX array placed on the first GPU device.
%timeit -n100 f(jax_x_gpu).block_until_ready()

2.51 ms ± 655 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
# jit-compiled f (backend='gpu') on the GPU-resident array.
%timeit -n100 jax_f_gpu(jax_x_gpu).block_until_ready()

608 μs ± 12.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
# Check which device holds the CPU copy of the data.
jax_x_cpu.device

CpuDevice(id=0)

In [17]:
# Check which device holds the GPU copy of the data.
jax_x_gpu.device

CudaDevice(id=0)