# JAX on GPU

In this notebook we will see how to run JAX code on a GPU and compare its performance with the CPU. We will demonstrate this on the Euclidean distance function we already know.

An advantage of JAX is that running on different devices, like GPU or TPU, is quite seamless. By default, JAX places arrays and runs computations on the first available accelerator and falls back to the CPU if none is found.

## Caveats of GPU Computing with JAX

While running JAX code on a GPU can provide significant speed-ups for large-scale numerical computations, there are several important caveats to keep in mind:

- **Data Transfer Overhead**: Moving data between the CPU (host) and GPU (device) can be slow. For small datasets or frequent transfers, this overhead may outweigh the performance gains of GPU acceleration.

- **JAX Array Semantics**: JAX operations work on `jax.Array` objects (called `DeviceArray` in older JAX releases), which reside on a specific device. Mixing NumPy arrays (in host memory) and JAX arrays (on GPU) can lead to implicit, costly data transfers. Make sure your data is on the intended device before computation.

- **GPU Memory Limitations**: GPUs typically have much less memory than the host system's RAM. Large datasets may not fit in GPU memory, leading to out-of-memory errors.

- **Non-Universal Speedup**: Not all algorithms benefit equally from GPU acceleration. Simple or memory-bound operations may not see significant improvements.

- **Device Availability**: Code that runs on GPU in one environment (such as Colab) may fall back to CPU elsewhere. Always check which devices are available and handle the absence of a GPU gracefully.

- **Reproducibility**: Floating-point computations on GPU may yield slightly different results compared to CPU, due to differences in hardware and parallel execution order.
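
To make the data-transfer and device-availability caveats concrete, here is a minimal sketch. It assumes a small random dataset (the array sizes and variable names are illustrative, not part of the original notebook), picks a GPU if one is present, and otherwise falls back to whatever device JAX reports first. The data is copied to the device once with `jax.device_put`, and only a scalar result is brought back to the host.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Prefer a GPU if JAX reports one; otherwise use the first available device (CPU).
gpus = [d for d in jax.devices() if d.platform == "gpu"]
device = gpus[0] if gpus else jax.devices()[0]

# Illustrative host-side data (a NumPy array living in host memory).
host_data = np.random.default_rng(0).random((10_000, 3), dtype=np.float32)

# One explicit host-to-device copy instead of repeated implicit transfers.
device_data = jax.device_put(host_data, device)

# Subsequent operations reuse the array that already lives on the device.
norms = jnp.linalg.norm(device_data, axis=1)
total = jnp.sum(norms)

# Only a single scalar is transferred back to the host.
print(device.platform, float(total))
```

Keeping intermediate results such as `norms` on the device and converting only the final, small result avoids most of the transfer overhead described above.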


## Running in Colab

Make sure to select the GPU runtime in the top right corner of the notebook.

![How to change runtime to GPU in Colab](images/change-runtime-1.png)
![How to change runtime to GPU in Colab](images/change-runtime-2.png)


In [None]:
import numpy as np
import pandas as pd
import jax

import plotly.express as px
import plotly.graph_objects as go


In [None]:
print(f"Jax devices: {jax.devices()}")
if not any(device.platform == "gpu" for device in jax.devices()):
    print("WARNING: No GPU found!")

You should see something like
```
Jax devices: [CudaDevice(id=0)]
```
when running properly with a GPU device. There will be a warning if there is no GPU device.


### Performance

Let's measure the performance (run time) of our algorithm. We will use the `%timeit` magic command to measure the time it takes to run the algorithm.

In [None]:
n_dataset_points: int = 10_000
n_query_points: int = 100
n_dim: int = 3
k: int = 5


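def create_random_data(
    n_points: int, n_dim: int, *, seed: int = 42
) -> np.ndarray:
    """Return `n_points` random points in [0, 1)^n_dim as a float32 array."""
    # Seed NumPy's global generator so the data is reproducible.
    np.random.seed(seed)
    return np.random.sample((n_points, n_dim)).astype(np.float32)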

dataset = create_random_data(n_dataset_points, n_dim, seed=420)


## JAX on GPU



In [None]:
import jax
import jax.numpy as jnp

def selu(x: jnp.ndarray, alpha: float = 1.67, lmbda: float = 1.05) -> jnp.ndarray:
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jnp.arange(5.0)
print(selu(x))

In [None]:
x.device

In [None]:
key = jax.random.key(1701)
x = jax.random.normal(key, (1_000_000,))

%timeit selu(x).block_until_ready()

In [None]:
selu_jit = jax.jit(selu)

print(selu_jit(x)[:3])

Two important things happened above:

1. We instructed JAX to just-in-time (JIT) compile the function when it is called.
2. The function *was* compiled during the call inside `print`. It was compiled for the *concrete input type and shape*, so a call with a different shape or dtype triggers a new compilation.
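
The second point can be observed directly. The sketch below is an illustration rather than part of the original notebook: it adds a hypothetical `trace_count` counter to `selu`. Python side effects inside a jitted function run only while JAX traces it, so the counter increases only when a new compilation is needed.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only for this illustration


def selu(x, alpha=1.67, lmbda=1.05):
    global trace_count
    trace_count += 1  # Python side effect: executed only during tracing
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jax.jit(selu)

selu_jit(jnp.ones(5))   # first call: traced and compiled for this shape and dtype
selu_jit(jnp.zeros(5))  # same shape and dtype: the cached compilation is reused
count_same_shape = trace_count

selu_jit(jnp.ones(7))   # new shape: traced and compiled again
count_new_shape = trace_count

print(count_same_shape, count_new_shape)  # prints: 1 2
```

In practice this means that calling a jitted function with many different input shapes can cost more in compilation time than it saves in run time.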







In [None]:
%timeit selu_jit(x).block_until_ready()

The compiled function is significantly faster than the uncompiled one, and also much faster than the CPU version (which took around 1 ms)!
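
The `%timeit` magic is only available in IPython and Jupyter. For a plain Python script, a rough equivalent can be sketched as follows. The helper `time_call` is a hypothetical name introduced here, not a JAX function. JAX dispatches work asynchronously, so `block_until_ready()` is needed for the measured time to include the actual computation, and a warm-up call keeps the one-time compilation out of the measurement.

```python
import time

import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jax.jit(selu)


def time_call(fn, *args, n_runs=20):
    """Average wall-clock seconds per call, excluding one warm-up call."""
    fn(*args).block_until_ready()  # warm-up: triggers compilation for jitted functions
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(*args).block_until_ready()  # wait until the result is actually computed
    return (time.perf_counter() - start) / n_runs


x = jax.random.normal(jax.random.key(1701), (1_000_000,))
t_eager = time_call(selu, x)
t_jit = time_call(selu_jit, x)
print(f"eager: {t_eager * 1e3:.3f} ms, jit: {t_jit * 1e3:.3f} ms")
```

The absolute numbers depend on the hardware, so they are best read as a comparison between the two variants on the same machine.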

### Exercise: JIT compiling Euclidean distance

1. Create a JIT-compiled version of the Euclidean distance function for GPU.
2. Compare the performance of the CPU and GPU versions.

Optionally:

3. Check that the GPU version yields the same result as the CPU version.
4. Compare the scaling of the performance of CPU and GPU versions with respect to the number of query points or the number of dimensions.
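
As a possible starting point for items 2 and 3, the sketch below shows one way to run the same jitted function on the CPU and on the default device and to compare the results with a tolerance. The `euclidean_distance` function here is a hypothetical stand-in and should be replaced with the version from the earlier notebook. The tolerances are assumptions as well.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-in: pairwise distances between query and dataset points.
def euclidean_distance(queries, dataset):
    diff = queries[:, None, :] - dataset[None, :, :]
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))


distance_jit = jax.jit(euclidean_distance)

rng = np.random.default_rng(0)
dataset = rng.random((10_000, 3), dtype=np.float32)
queries = rng.random((100, 3), dtype=np.float32)

cpu = jax.devices("cpu")[0]
default = jax.devices()[0]  # a GPU when one is available, otherwise the CPU

# Inputs committed to a device make the jitted computation run on that device.
result_cpu = distance_jit(jax.device_put(queries, cpu), jax.device_put(dataset, cpu))
result_default = distance_jit(
    jax.device_put(queries, default), jax.device_put(dataset, default)
)

# Compare with a tolerance, since floating-point results may differ slightly.
np.testing.assert_allclose(
    np.asarray(result_cpu), np.asarray(result_default), rtol=1e-5, atol=1e-6
)
```

Timing each variant with `%timeit` and `block_until_ready()`, as done for `selu` above, then gives the CPU versus GPU comparison.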