# GPU acceleration and GradVAR training

A GPU (Graphics Processing Unit) is a specialized processor originally designed for rendering graphics. It is now widely used in scientific computing and machine learning because it can run many computations in parallel at high speed. CPUs have a few cores optimized for fast sequential execution. GPUs instead contain thousands of simpler cores that execute many operations simultaneously, which makes them well suited to vectorized, matrix-heavy computations.

When run on compatible hardware, GradVAR can transparently offload its computations to the GPU. This often improves performance substantially, especially for large-scale linear algebra of the kind found in deep learning and scientific simulation.
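As a minimal illustration of what "transparent" offloading means in JAX, the sketch below jit-compiles a matrix product. It is plain JAX, not GradVAR's own code, and the function name `gram` is purely illustrative. The same code runs unchanged on CPU or GPU, because JAX places the computation on its default backend.

```python
import jax
import jax.numpy as jnp


@jax.jit
def gram(x):
    # X^T X: a matrix-heavy operation that XLA compiles for the default backend
    return x.T @ x


x = jnp.ones((1000, 50))
g = gram(x)

print(jax.default_backend())  # 'gpu' with a CUDA-enabled JAX, otherwise e.g. 'cpu'
print(g.devices())            # the device(s) holding the result
```

No device-specific code is needed. Installing a GPU-enabled JAX is enough to change where `gram` runs.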

To use GradVAR with a GPU, install it with:

```sh
pip install "gradvar[gpu]"
# CUDA 12 build of JAX; see the JAX installation guide for other CUDA versions
pip install --upgrade "jax[cuda12]"
```

Alternatively, make sure that JAX is installed with GPU support before running a regular `pip install gradvar`. See the JAX [Installation](https://docs.jax.dev/en/latest/installation.html#pip-installation-gpu-cuda) guide for details.

## Ensuring that a GPU is available

If JAX is configured correctly, the following two checks return a `CudaDevice` entry and `'gpu'`, respectively:

```python
import jax

# List available devices (should include a GPU)
jax.devices()
```

```
[CudaDevice(id=0)]
```

```python
# Platform of the default backend ('gpu' when a GPU is in use)
jax.default_backend()
```

```
'gpu'
```
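For scripts that should fall back gracefully when no GPU is present, the two checks can be folded into a single boolean test. The helper below is a sketch: `gpu_available` is a hypothetical name, not part of GradVAR or JAX. It relies on `jax.devices("gpu")` raising `RuntimeError` when no GPU backend is available.

```python
import jax


def gpu_available() -> bool:
    """Return True if JAX can see at least one GPU device (hypothetical helper)."""
    try:
        # Raises RuntimeError when no GPU backend is initialized
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        return False


if gpu_available():
    print("GPU found:", jax.devices("gpu"))
else:
    print("No GPU visible to JAX; computations will run on", jax.default_backend())
```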

When a GPU is available, GradVAR automatically uses it for the gradient calculations.
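To see why no extra configuration is needed, consider a gradient computed with plain JAX. The loss below is a hypothetical VAR(1) least-squares objective chosen for illustration. It is not GradVAR's actual loss or API. `jax.grad` builds the gradient function, and `jax.jit` compiles it for whatever backend JAX selected by default.

```python
import jax
import jax.numpy as jnp


def var1_loss(A, y):
    """Hypothetical VAR(1) least-squares loss: y[t] is predicted by A @ y[t-1]."""
    pred = y[:-1] @ A.T   # one-step-ahead predictions, shape (T-1, k)
    resid = y[1:] - pred  # prediction errors
    return jnp.mean(resid ** 2)


# jax.grad differentiates with respect to the first argument (A);
# jax.jit compiles the gradient for the default backend (GPU if available).
grad_fn = jax.jit(jax.grad(var1_loss))

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (500, 3))  # 500 time steps, 3 variables
A = jnp.zeros((3, 3))

g = grad_fn(A, y)
print(jax.default_backend(), g.devices())
```

To compare against CPU timings, set the environment variable `JAX_PLATFORMS=cpu` before JAX is imported. This forces JAX, and any code built on it, onto the CPU.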