# JAX on GPU

In this notebook we run JAX code on a GPU and compare its performance with the CPU. The example is the Euclidean distance function we already know.

An advantage of JAX is that running on different devices, like GPU or TPU, is quite seamless. By default, JAX places arrays and computations on an accelerator (GPU or TPU) if one is available and falls back to the CPU otherwise.

## Caveats of GPU Computing with JAX

While running JAX code on a GPU can provide significant speed-ups for large-scale numerical computations, there are several important caveats to keep in mind:

- **Data Transfer Overhead**: Moving data between the CPU (host) and GPU (device) can be slow. For small datasets or frequent transfers, this overhead may outweigh the performance gains of GPU acceleration.

- **JAX Array Semantics**: JAX operations are performed on `jax.Array` objects (called `DeviceArray` in older JAX releases), which reside on one or more specific devices. Mixing NumPy arrays (on CPU) and JAX arrays (on GPU) can lead to implicit, costly data transfers. Always ensure your data is on the correct device before computation.

- **GPU Memory Limitations**: GPUs typically have much less memory than the host system's RAM. Large datasets may not fit in GPU memory, leading to out-of-memory errors.

- **Non-Universal Speedup**: Not all algorithms benefit equally from GPU acceleration. Simple or memory-bound operations may not see significant improvements.

- **Device Availability**: Code that runs on GPU in one environment (such as Colab) may fall back to CPU elsewhere. Always check which devices are available and handle the absence of a GPU gracefully.

- **Reproducibility**: Floating-point computations on GPU may yield slightly different results compared to CPU, due to differences in hardware and parallel execution order.
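
The data-transfer and array-semantics caveats can be made concrete with explicit placement. The following is a minimal sketch, assuming a recent JAX release in which arrays are `jax.Array` objects. The variable names are illustrative, and `jax.devices()[0]` is simply whichever device JAX lists first (a GPU when one is present).

```python
import jax
import jax.numpy as jnp
import numpy as np

# A NumPy array lives in host (CPU) memory.
host_array = np.arange(6, dtype=np.float32).reshape(2, 3)

# Copy it to a device once, explicitly, instead of relying on an implicit
# transfer every time a NumPy array is passed to a JAX function.
device = jax.devices()[0]
device_array = jax.device_put(host_array, device)

# The computation runs on the device that holds the input.
sum_of_squares = jnp.sum(device_array**2)

# Copying the result back to the host is again an explicit step.
sum_on_host = np.asarray(sum_of_squares)
print(device_array.devices(), sum_on_host)
```

Doing the transfer once up front keeps it out of any later timing loop.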


## Running in Colab

Make sure to select the GPU runtime in the top right corner of the notebook.

![How to change runtime to GPU in Colab](images/change-runtime-1.png)
![How to change runtime to GPU in Colab](images/change-runtime-2.png)


In [1]:
import numpy as np
import pandas as pd
import jax

import plotly.express as px
import plotly.graph_objects as go


In [3]:
print(f"Jax devices: {jax.devices()}")
if not any(device.platform == "gpu" for device in jax.devices()):
    print("WARNING: No GPU found!")

Jax devices: [CpuDevice(id=0)]


You should see something like
```
Jax devices: [CudaDevice(id=0)]
```
when running properly with a GPU device. There will be a warning if there is no GPU device.
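
If the notebook should keep working without a GPU, the device can also be chosen programmatically. This is a small sketch, assuming (as in the check above) that GPU devices report `platform == "gpu"`; `pick_device` is a hypothetical helper name, not part of JAX.

```python
import jax
import jax.numpy as jnp


def pick_device():
    """Return the first GPU if JAX sees one, otherwise the first available device."""
    gpus = [device for device in jax.devices() if device.platform == "gpu"]
    return gpus[0] if gpus else jax.devices()[0]


device = pick_device()
# Place the data on the chosen device; later computations follow the data.
x = jax.device_put(jnp.arange(5.0), device)
print(f"Running on: {device.platform}")
```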


### Performance

Let's measure the run time of our algorithm with the `%timeit` magic command.

In [7]:
n_dataset_points: int = 10_000
n_query_points: int = 100
n_dim: int = 3
k: int = 5


def create_random_data(
    n_points: int, n_dim: int, *, seed: int = 42
) -> np.ndarray:
    np.random.seed(seed)
    return np.random.sample((n_points, n_dim)).astype(np.float32)

dataset = create_random_data(n_dataset_points, n_dim, seed=420)


## JAX on GPU



In [8]:
import jax
import jax.numpy as jnp

def selu(x: jnp.ndarray, alpha: float = 1.67, lmbda: float = 1.05) -> jnp.ndarray:
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


In [9]:
x.device

CudaDevice(id=0)

In [10]:
key = jax.random.key(1701)
x = jax.random.normal(key, (1_000_000,))

%timeit selu(x).block_until_ready()

1.11 ms ± 95.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
selu_jit = jax.jit(selu)

print(selu_jit(x)[:3])

[-0.83556366  0.33142313 -0.9244633 ]


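Two important things happened above:

1. `jax.jit` wrapped the function so that it is Just-In-Time (JIT) compiled when it is first called. Wrapping alone does not compile anything.
2. The function *was* compiled during its first call, inside the `print` statement. It was compiled for the *concrete input type and shape*, so a call with a different shape or dtype triggers a new compilation.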
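
The dependence on input shape can be observed directly. The sketch below relies on the fact that Python side effects inside a jitted function run only while JAX traces it; `double` and `traced_shapes` are illustrative names, not part of this notebook.

```python
import jax
import jax.numpy as jnp

traced_shapes = []


def double(x):
    # This append runs only during tracing, so the list records
    # every input shape the function was traced (and compiled) for.
    traced_shapes.append(x.shape)
    return 2 * x


double_jit = jax.jit(double)

double_jit(jnp.ones(3))  # first call: traced and compiled for shape (3,)
double_jit(jnp.ones(3))  # same shape and dtype: the cached version is reused
double_jit(jnp.ones(4))  # new shape: traced and compiled again

print(traced_shapes)
```

Three calls leave only two entries in the list, one per distinct input shape.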







In [12]:
%timeit selu_jit(x).block_until_ready()

175 µs ± 29.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


The compiled function is significantly faster than the uncompiled one: about 175 µs instead of 1.11 ms, roughly 6 times faster.

It is also much faster than the CPU version (which was around 1 ms).
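
The first call to a jitted function includes tracing and compilation, which is why a benchmark should warm up before measuring. Below is a rough sketch of separating the two costs with `time.perf_counter`, as an alternative to `%timeit` that also works outside IPython. The number of repetitions is an arbitrary choice, and the actual timings depend on the hardware.

```python
import time

import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jax.jit(selu)
x = jax.random.normal(jax.random.key(1701), (1_000_000,))

# First call: tracing + compilation + execution.
start = time.perf_counter()
selu_jit(x).block_until_ready()
first_call_seconds = time.perf_counter() - start

# Later calls reuse the compiled code. block_until_ready() waits for the
# asynchronously dispatched computation to finish before the clock stops.
n_runs = 100
start = time.perf_counter()
for _ in range(n_runs):
    selu_jit(x).block_until_ready()
steady_state_seconds = (time.perf_counter() - start) / n_runs

print(f"first call:   {first_call_seconds * 1e3:.3f} ms")
print(f"steady state: {steady_state_seconds * 1e3:.3f} ms")
```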

### Exercise: JIT compiling Euclidean distance

1. Create a JIT-compiled version of the Euclidean distance function for GPU.
2. Compare the performance of the CPU and GPU versions.

Optionally:

3. Check that the GPU version yields the same result as the CPU version.
4. Compare the scaling of the performance of CPU and GPU versions with respect to the number of query points or the number of dimensions.
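
The comparison needs a NumPy reference. The cells below call `euclidean_distances_numpy`, which is not defined in this notebook. The following is a minimal sketch of what it might look like, assuming it uses the same broadcasting layout and output shape `(n_samples, n_queries)` as the JAX solution; the example arrays are made up for illustration.

```python
import numpy as np


def euclidean_distances_numpy(
    query_points: np.ndarray, dataset: np.ndarray
) -> np.ndarray:
    """Pairwise Euclidean distances, shape (n_samples, n_queries).

    Assumed reference implementation; the original definition is not shown here.
    """
    # (n_samples, 1, n_features) - (n_queries, n_features)
    # broadcasts to (n_samples, n_queries, n_features).
    differences = dataset[:, np.newaxis, :] - query_points
    return np.sqrt(np.sum(differences**2, axis=-1))


# Tiny check: the points (0, 0) and (3, 4) against the query (0, 0).
dataset_example = np.array([[0.0, 0.0], [3.0, 4.0]], dtype=np.float32)
query_example = np.array([[0.0, 0.0]], dtype=np.float32)
distances_example = euclidean_distances_numpy(query_example, dataset_example)
print(distances_example)
```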

In [17]:
import jax
import jax.numpy as jnp


def euclidean_distances_jax(
    query_points: jnp.ndarray, dataset: jnp.ndarray
) -> jnp.ndarray:
    """
    Calculates the Euclidean distance between a set of query points and a dataset of points.

    Args:
        query_points (jnp.ndarray): Array of shape (n_queries, n_features).
        dataset (jnp.ndarray): Array of shape (n_samples, n_features).

    Returns:
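        jnp.ndarray: Array of shape (n_samples, n_queries); entry [i, j] is the
            distance between dataset[i] and query_points[j].
    """
    # dataset[:, jnp.newaxis, :] has shape (n_samples, 1, n_features). Broadcasting it
    # against query_points gives all pairwise differences: (n_samples, n_queries, n_features).
    return jnp.sqrt(jnp.sum((dataset[:, jnp.newaxis, :] - query_points) ** 2, axis=-1))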


euclidean_distances_jax_jit = jax.jit(euclidean_distances_jax)

In [19]:
n_dataset_points: int = 10_000
n_query_points: int = 100
n_dim: int = 3
k: int = 5


dataset = create_random_data(n_dataset_points, n_dim, seed=420)
query_points = create_random_data(n_query_points, n_dim, seed=421)

dataset_jax = jnp.array(dataset)
query_points_jax = jnp.array(query_points)

np.testing.assert_allclose(euclidean_distances_numpy(query_points, dataset), euclidean_distances_jax(query_points_jax, dataset_jax), rtol=1e-6)
np.testing.assert_allclose(euclidean_distances_numpy(query_points, dataset), euclidean_distances_jax_jit(query_points_jax, dataset_jax), rtol=1e-6)

In [20]:
%timeit euclidean_distances_numpy(query_points, dataset)
%timeit euclidean_distances_jax(query_points_jax, dataset_jax).block_until_ready()
%timeit euclidean_distances_jax_jit(query_points_jax, dataset_jax).block_until_ready()

28.8 ms ± 823 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.14 ms ± 67.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
169 µs ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
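
Before scaling the inputs up, it helps to relate the broadcasting approach to the GPU memory caveat. The subtraction conceptually produces an array of shape `(n_samples, n_queries, n_features)`. The sketch below estimates its size for `float32` data; `broadcast_intermediate_bytes` is a hypothetical helper. Under `jax.jit` the compiler may fuse operations so that the full array is never materialized, so treat the number as an estimate for the unfused case.

```python
def broadcast_intermediate_bytes(
    n_samples: int, n_queries: int, n_features: int, bytes_per_element: int = 4
) -> int:
    """Size in bytes of the (n_samples, n_queries, n_features) difference array."""
    return n_samples * n_queries * n_features * bytes_per_element


# 10,000 dataset points, 10,000 query points, 4 dimensions, float32.
largest_case_bytes = broadcast_intermediate_bytes(10_000, 10_000, 4)
print(f"{largest_case_bytes / 1e9:.1f} GB")  # prints "1.6 GB"
```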


In [21]:
euclidean_execution_times = []
euclidean_distances_jax_jit.__name__ = "euclidean_distances_jax_jit"

for function in [euclidean_distances_numpy, euclidean_distances_jax, euclidean_distances_jax_jit]:
    for n_dim in [2, 4]:
        dataset = create_random_data(n_dataset_points, n_dim)
        for n_query_points in [1, 100, 1_000, 10_000]:
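            query_points = create_random_data(n_query_points, n_dim)
            if "numpy" in function.__name__:
                execution_time = %timeit -o function(query_points, dataset)
            else:
                # Move the inputs to the device once, so the timing excludes host-to-device transfers.
                query_points_jax = jnp.asarray(query_points)
                dataset_jax = jnp.asarray(dataset)
                # Warm-up call: triggers compilation for this input shape.
                function(query_points_jax, dataset_jax).block_until_ready()
                execution_time = %timeit -o function(query_points_jax, dataset_jax).block_until_ready()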
            euclidean_execution_times.append(
                {
                    "n_query_points": n_query_points,
                    "n_dataset_points": n_dataset_points,
                    "n_dim": n_dim,
                    "execution_time": execution_time.average,
                    "function": function.__name__,
                }
            )


249 µs ± 6.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
28.4 ms ± 1.21 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
308 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.52 s ± 294 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
364 µs ± 131 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
36.3 ms ± 5.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
373 ms ± 37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.89 s ± 258 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
466 µs ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
14.7 ms ± 2.62 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
157 ms ± 21.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.63 s ± 224 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
668 µs ± 199 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
19.6 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
206 ms ± 5.92 ms per lo

In [None]:
px.line(
    pd.DataFrame(euclidean_execution_times),
    x="n_query_points",
    y="execution_time",
    title="Execution Time: Euclidean Distance",
    labels={"n_query_points": "Number of Query Points", "execution_time": "Execution Time (s)"},
    log_x=True,
    log_y=True,
    markers=True,
    color="function",
    symbol="n_dim",
)