
# GPUs and TPUs
:label:`sec_use_gpu`

In :numref:`tab_intro_decade`, we illustrated the rapid growth
of computation over the past two decades.
In a nutshell, GPU performance has increased
by a factor of 1000 every decade since 2000.
This offers great opportunities but it also suggests
that there was significant demand for such performance.


In this section, we begin to discuss how to harness
this computational performance for your research:
first by using a single GPU and, at a later point,
by using multiple GPUs and multiple servers (with multiple GPUs).

Specifically, we will discuss how
to use a single NVIDIA GPU for calculations.
First, make sure you have at least one NVIDIA GPU installed.
Then, download the [NVIDIA driver and CUDA](https://developer.nvidia.com/cuda-downloads)
and follow the prompts to set the appropriate path.
For JAX to use the GPU, you also need a CUDA-enabled `jaxlib`;
otherwise JAX warns and falls back to the CPU.
Once these preparations are complete,
the `nvidia-smi` command can be used
to (**view the graphics card information**).
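If you prefer to check from Python, a minimal sketch is to call `nvidia-smi` through the standard library, assuming the tool is on your `PATH` (it ships with the NVIDIA driver). The helper name `gpu_info` is hypothetical.

```python
import shutil
import subprocess

def gpu_info():
    """Hypothetical helper: return `nvidia-smi` output, or None if it is missing."""
    if shutil.which('nvidia-smi') is None:
        return None  # no NVIDIA driver installed, or the tool is not on PATH
    # check=False: a failing call returns its (possibly empty) output instead of raising
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=False)
    return result.stdout

print(gpu_info() or 'nvidia-smi not found')
```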


To run the programs in this section as intended,
you need at least two GPUs.
Note that this might be extravagant for most desktop computers,
but it is easily available in the cloud, e.g.,
by using the AWS EC2 multi-GPU instances.
Almost all other sections do *not* require multiple GPUs; here we simply wish to illustrate data flow between different devices.
Without accelerators, the helper functions defined below fall back to the CPU,
so the code still runs, but all data stays on a single device.


In [1]:
import jax
import jax.numpy as jnp

## [**Computing Devices**]

We can specify devices, such as CPUs and GPUs,
for storage and calculation.
By default, JAX creates arrays on the first available accelerator (TPU or GPU)
and only falls back to main memory and the CPU if none is present.
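To see which devices JAX can use, we can ask for the default backend and the devices it exposes. This is a small sketch using `jax.default_backend()` and `jax.devices()`; on a machine without accelerators both report the CPU.

```python
import jax

backend = jax.default_backend()  # 'cpu', 'gpu', or 'tpu'
devices = jax.devices()          # devices of the default backend
print(backend, devices)

# New arrays land on the first of these devices unless told otherwise
default_device = devices[0]
print(default_device.platform, default_device.id)
```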


In [2]:
def cpu():
    """Get the CPU device."""
    return jax.devices('cpu')[0]

def gpu(i=0):
    """Get a GPU device."""
    return jax.devices('gpu')[i]

def tpu(i=0):
    """Get a TPU device."""
    return jax.devices('tpu')[i]

# cpu(), gpu(), gpu(1)  # gpu(i)/tpu(i) fail if that backend or the i-th device is unavailable

We can (**query the number of available GPUs.**)


In [3]:
def num_gpus():
    """Get the number of available GPUs."""
    try:
        return jax.device_count('gpu')
    except RuntimeError:
        return 0  # No GPU backend found

num_gpus()

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


0

In [4]:
def num_tpu():
    """Get the number of available TPUs."""
    try:
        return jax.device_count('tpu')
    except RuntimeError:
        return 0  # No TPU backend found

num_tpu()

0
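`num_gpus` and `num_tpu` differ only in the platform string. A hypothetical helper, `device_counts`, can query every backend in one go. It assumes, as is the case in current JAX releases, that asking for a backend that is not installed raises a `RuntimeError`.

```python
import jax

def device_counts(platforms=('cpu', 'gpu', 'tpu')):
    """Hypothetical helper: map each platform name to its number of devices."""
    counts = {}
    for platform in platforms:
        try:
            counts[platform] = jax.device_count(platform)
        except RuntimeError:  # backend not installed or failed to initialize
            counts[platform] = 0
    return counts

print(device_counts())
```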

Now we [**define two convenient functions that allow us
to run code even if the requested GPUs or TPUs do not exist.**]


In [5]:
def try_tpu_gpu(i=0):
    """Return tpu(i) if it exists, else gpu(i) if it exists, otherwise cpu()."""
    if num_tpu() >= i + 1:
        return tpu(i)
    elif num_gpus() >= i + 1:
        return gpu(i)
    return cpu()

def try_all_gpus():
    """Return all available GPUs, or [] if no GPU exists."""
    return [gpu(i) for i in range(num_gpus())]

try_tpu_gpu(), try_tpu_gpu(10), try_all_gpus()

(CpuDevice(id=0), CpuDevice(id=0), [])

## Tensors and GPUs


By default, JAX creates tensors on a GPU or TPU if one is available,
and on the CPU otherwise.
We can [**query the device where the tensor is located.**]


In [6]:
x = jnp.array([1, 2, 3])
x.devices()

{CpuDevice(id=0)}
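`devices()` returns a set, because in JAX an array may be sharded across several devices. For an ordinary single-device array, a small sketch like the following unpacks the set and inspects the device's `platform` and `id` attributes.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
(device,) = x.devices()  # a single-device array has exactly one entry
print(device.platform, device.id)
print(device == jax.devices()[0])  # arrays created without a placement go to the default device
```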

It is important to note that whenever we want
to operate on multiple terms,
they need to be on the same device.
For instance, if we sum two tensors,
we need to make sure that both arguments
live on the same device---otherwise the framework
would not know where to store the result
or even how to decide where to perform the computation.
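One way to follow this rule, sketched below, is to move every operand to the device of the first one before combining them. The helper `to_common_device` is hypothetical and assumes each array lives on a single device.

```python
import jax
import jax.numpy as jnp

def to_common_device(*arrays):
    """Hypothetical helper: place every array on the device of the first one."""
    (target,) = arrays[0].devices()  # assumes arrays[0] is not sharded
    return [jax.device_put(a, target) for a in arrays]

a, b = to_common_device(jnp.arange(3.0), jnp.ones(3))
c = a + b  # both operands now share a device, so the result lives there too
print(c, c.devices())
```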

### Storage on the GPU

There are several ways to [**store a tensor on the GPU.**]
For example, we can create an array and then place it on a device with `jax.device_put`,
or create it directly under a `jax.default_device(...)` context.
Next, we create the tensor variable `X` on the first accelerator (TPU or GPU, falling back to the CPU).
The tensor created on a GPU only consumes the memory of this GPU.
We can use the `nvidia-smi` command to view GPU memory usage.
In general, we need to make sure that we do not create data that exceeds the GPU memory limit.

The syntax is
```
x_on_device = jax.device_put(x, device)  # device: a jax.Device, e.g., cpu() or gpu(0)
```
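Besides `jax.device_put`, the `jax.default_device` context manager (used later in this notebook) makes newly created arrays land on a given device. The sketch below uses the CPU so that it runs anywhere.

```python
import jax
import jax.numpy as jnp

cpu_device = jax.devices('cpu')[0]

# Option 1: create the array, then place it on a device explicitly
a = jax.device_put(jnp.zeros((2, 3)), cpu_device)

# Option 2: create the array directly on a device via a context manager
with jax.default_device(cpu_device):
    b = jnp.zeros((2, 3))

print(a.devices(), b.devices())
```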


In [7]:
# jnp.ones already lands on a GPU/TPU if available; device_put makes the placement explicit
X = jax.device_put(jnp.ones((2, 3)), try_tpu_gpu())
X, X.devices()

(Array([[1., 1., 1.],
        [1., 1., 1.]], dtype=float32),
 {CpuDevice(id=0)})

Assuming that you have at least two GPUs, the following code will (**create a random tensor, `Y`, on the second GPU.**)


In [8]:
Y = jax.device_put(jax.random.uniform(jax.random.PRNGKey(0), (2, 3)),
                   try_tpu_gpu(1))
Y, Y.devices()

(Array([[0.57450044, 0.09968603, 0.7419659 ],
        [0.8941783 , 0.59656656, 0.45325184]], dtype=float32),
 {CpuDevice(id=0)})

### Copying

[**If we want to compute `X + Y`,
we need to decide where to perform this operation.**]
For instance, as shown in :numref:`fig_copyto`,
we can transfer `X` to the second GPU
and perform the operation there.
*Do not* simply add `X` and `Y` when they live on different devices,
since this will result in an exception.
The runtime engine would not know what to do:
it cannot find data on the same device and it fails.
Since `Y` lives on the second GPU,
we need to move `X` there before we can add the two.
(In the run below, no accelerator was available, so `X` and `Y` both fell back to the CPU and `X + Y` happens to succeed.)

![Copy data to perform an operation on the same device.](http://d2l.ai/_images/copyto.svg)
:label:`fig_copyto`


In [9]:
X + Y

Array([[1.5745004, 1.099686 , 1.7419659],
       [1.8941783, 1.5965666, 1.4532518]], dtype=float32)

To put `X` on the second device, we move it with `jax.device_put`. This is functional rather than stateful: `X` itself stays where it is, and we get back a new, moved array `Z`.

In [10]:
Z = jax.device_put(X, try_tpu_gpu(1))
print(X, X.devices())
print(Z, Z.devices())

[[1. 1. 1.]
 [1. 1. 1.]] {CpuDevice(id=0)}
[[1. 1. 1.]
 [1. 1. 1.]] {CpuDevice(id=0)}
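Because `jax.device_put` is functional, it also accepts whole pytrees (nested lists, tuples, and dicts of arrays) and returns a new pytree with the same structure. As a sketch, moving a hypothetical dictionary of model parameters looks like this; the target here is the CPU so the snippet runs without accelerators.

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}  # hypothetical parameters
target = jax.devices('cpu')[0]

moved = jax.device_put(params, target)  # one call moves every leaf

# The original dictionary is untouched; `moved` is a new dictionary
print(moved is params, {k: v.devices() for k, v in moved.items()})
```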


Now that [**the data (both `Z` and `Y`) are on the same GPU, we can add them up.**]


In [11]:
Y + Z

Array([[1.5745004, 1.099686 , 1.7419659],
       [1.8941783, 1.5965666, 1.4532518]], dtype=float32)

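Imagine that your variable `Z` already lives on your second GPU.
What happens if we call `jax.device_put` on it again with that same device?
As the output below shows, JAX returns a new `Array` object rather than `Z` itself, so `Z2 is Z` is `False`.
(Whether the underlying device memory is actually copied cannot be seen from this check.)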


In [12]:
Z2 = jax.device_put(Z, try_tpu_gpu(1))
Z2 is Z

False
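If you want `Z` itself back when it already lives on the target device, a hypothetical helper can check the placement first and skip the call. This is a sketch, not a JAX API.

```python
import jax
import jax.numpy as jnp

def put_if_needed(x, device):
    """Hypothetical helper: return x itself if it already lives only on `device`."""
    if x.devices() == {device}:
        return x
    return jax.device_put(x, device)

target = jax.devices('cpu')[0]
z = jax.device_put(jnp.ones(3), target)
z2 = put_if_needed(z, target)
print(z2 is z)  # no new array object was created
```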

### The bottleneck

Transferring variables between devices is slow: much slower than computation.
So we want you to be 100% certain
that you want to do something slow before we let you do it.
If the deep learning framework just did the copy automatically
without crashing then you might not realize
that you had written some slow code.

Transferring data is not only slow, it also makes parallelization a lot more difficult,
since we have to wait for data to be sent (or rather to be received)
before we can proceed with more operations.
This is why copy operations should be taken with great care.
As a rule of thumb, many small operations
are much worse than one big operation.
Moreover, several operations at a time
are much better than many single operations interspersed in the code
unless you know what you are doing.
This is the case since such operations can block if one device
has to wait for the other before it can do something else.
It is a bit like ordering your coffee in a queue
rather than pre-ordering it by phone
and finding out that it is ready when you are.
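The rule of thumb about many small versus one big operation applies directly to host-to-device transfers. A minimal sketch, assuming the data starts as NumPy arrays in main memory: stack them on the host first, then transfer once, instead of issuing one `jax.device_put` per piece.

```python
import jax
import numpy as np

target = jax.devices()[0]  # default device; the CPU if no accelerator exists
pieces = [np.full(4, i, dtype=np.float32) for i in range(100)]  # 100 tiny host arrays

# Many small transfers: one device_put per piece
one_by_one = [jax.device_put(p, target) for p in pieces]

# One big transfer: stack on the host (cheap), then move a single (100, 4) array
batched = jax.device_put(np.stack(pieces), target)

print(batched.shape, batched.devices())
```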

Last, when we print tensors or convert tensors to the NumPy format,
if the data is not in the main memory,
the framework will copy it to the main memory first,
resulting in additional transmission overhead.
Even worse, it is now subject to the dreaded global interpreter lock
that makes everything wait for Python to complete.
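The copy back to main memory can also be made explicit with `jax.device_get`; converting with `np.asarray` triggers the same kind of transfer implicitly. A small sketch:

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6.0).reshape(2, 3)  # lives on the default device

host_a = jax.device_get(x)  # explicit device-to-host transfer
host_b = np.asarray(x)      # implicit transfer when converting to NumPy

print(type(host_a).__name__, type(host_b).__name__)
```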


## Asynchronous execution

There is another trick that libraries may use to speed up computation: while the actual computation is performed, the Python code continues _until the value of the computation is needed_ (a barrier).

At this point, the Python code stops until the needed computation is ready. In JAX, this mechanism is called _asynchronous dispatch_.
Usually, we don't notice this (great! That's how it should be!), except when we're profiling code: after running something like `y = jax.something(x)`, `y` is _technically_ not yet computed but a placeholder (a future) that says "I will hold a value, I promise".
If we run Python code that _needs_ this value (e.g., by calling `np.array(y)`), it waits until the computation is finished.

Therefore, to profile code, we're not interested in the time it takes to return a placeholder (which can be nearly instant) but in how long it takes until the computation is actually performed.

`block_until_ready()` forces the Python code to wait until the computation has finished. Its main use case is benchmarking; it should not be needed otherwise.

_Note that this does not let us skip work: the computation is still executed, JAX merely has the freedom to run it in the background._
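The effect can be seen by timing the same call twice: once measuring only how long it takes to get the placeholder back, and once waiting with `block_until_ready()`. A minimal sketch; the helper `timed` is hypothetical, and the absolute numbers (and how far apart they are) depend on your hardware.

```python
import time
import jax.numpy as jnp

def timed(fn, *args, wait=True):
    """Hypothetical helper: run fn(*args) and return (result, seconds elapsed)."""
    start = time.perf_counter()
    out = fn(*args)
    if wait:
        out.block_until_ready()  # wait until the device has finished the work
    return out, time.perf_counter() - start

def f(v):
    return jnp.sqrt(v) - jnp.log(v)

x = jnp.linspace(1.0, 100.0, num=1_000_000)
f(x).block_until_ready()  # warm-up, so one-time compilation is not timed below

y, t_dispatch = timed(f, x, wait=False)  # may return before y is computed
y.block_until_ready()                    # let that work finish before the next timing
_, t_done = timed(f, x, wait=True)       # includes the actual computation
print(f'dispatch only: {t_dispatch:.4f} s, until ready: {t_done:.4f} s')
```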

In [13]:
import random

In [14]:
%%timeit -n1 -r3
with jax.default_device(cpu()):
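    rand = random.randint(0, 10_000)  # randint needs ints; floats raise TypeError on Python 3.12+
    x = jnp.linspace(0, 100, num=10_000_000 + rand)
    y = x ** 0.5 - jnp.log(jnp.abs(x))
    # y.block_until_ready()  # uncomment to time the computation itself, not just its dispatch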

186 ms ± 33.1 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


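The next cells compare three variants of the same computation on a large array: moving data back and forth between the accelerator and the CPU, computing everything on the CPU, and computing everything on the accelerator. Each cell calls `block_until_ready()` so that the timing includes the actual computation. In the run shown here no accelerator was available, so all three variants ran on the CPU and took similar time.

In [15]: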
n = 100_000_000
# n = 100

In [16]:
%%timeit -n5
x = jax.random.uniform(jax.random.PRNGKey(random.randint(0, 10**12)), (n,))  # on GPU/TPU if available
xcpu = jax.device_put(x, cpu())  # move to CPU
ycpu = xcpu ** 2 - 0.5  # calculate on CPU
yacc = jax.device_put(ycpu, try_tpu_gpu())  # move to GPU/TPU
zacc = yacc ** 3 - 0.2  # calculate on GPU/TPU
zacc.block_until_ready()

1.09 s ± 84.1 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [17]:
%%timeit -n5
with jax.default_device(cpu()):
    xcpu = jax.random.uniform(jax.random.PRNGKey(random.randint(0, 10**12)), (n,))  # on CPU
    ycpu = xcpu ** 2 - 0.5
    zcpu = ycpu ** 3 - 0.2
    zcpu.block_until_ready()

1.2 s ± 20.1 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [18]:
%%timeit -n5
with jax.default_device(try_tpu_gpu()):
    xacc = jax.random.uniform(jax.random.PRNGKey(random.randint(0, 10**12)), (n,))  # on GPU/TPU if available
    yacc = xacc ** 2 - 0.5
    zacc = yacc ** 3 - 0.2
    zacc.block_until_ready()

1.09 s ± 35.6 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


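For comparison, TensorFlow and PyTorch express the same ideas with their own APIs: TensorFlow places tensors through a `tf.device` context, and PyTorch moves tensors with `.to(device)`.

In [19]: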
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import torch

In [20]:
tf.config.get_visible_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]

In [21]:

with tf.device('cpu'):
    x = tnp.linspace(0, 10, num=100)
x.device

'/job:localhost/replica:0/task:0/device:CPU:0'

In [22]:
torch.cuda.device_count()

1

In [23]:
x = torch.linspace(0., 10, 100).to('cpu')  # torch.linspace names the count argument `steps`, not `num` as in NumPy/JAX
x.device

device(type='cpu')