# GPUs and TPUs

We will discuss how
to use a single NVIDIA GPU and a Google TPU for calculations.

In [None]:
import jax
import jax.numpy as jnp

## Computing Devices

We can specify devices, such as CPUs and GPUs,
for storage and calculation.
By default, JAX creates tensors on its default device:
the first GPU or TPU if one is available, otherwise the CPU.
Calculations run on the device that holds the data.
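
As a quick orientation, the following sketch lists what JAX sees on the current machine. It only uses `jax.default_backend()` and `jax.devices()`; the variable names are chosen for illustration, and the output depends on the hardware at hand.

```python
import jax

# Name of the platform JAX picked as default: 'cpu', 'gpu' or 'tpu'.
backend = jax.default_backend()

# All devices of the default backend; the first one is the default device.
devices = jax.devices()

print(backend)
print(devices)
```

On a machine without an accelerator this reports the CPU backend and a single CPU device.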


In [None]:
def cpu():
    """Get the CPU device."""
    return jax.devices('cpu')[0]

def gpu(i=0):
    """Get a GPU device."""
    return jax.devices('gpu')[i]

def tpu(i=0):
    """Get a TPU device."""
    return jax.devices('tpu')[i]

# cpu(), gpu(), gpu(1)

We can query the number of available GPUs and TPUs.

In [None]:
def num_gpus():
    """Get the number of available GPUs."""
    try:
        return jax.device_count('gpu')
    except RuntimeError:
        return 0  # No GPU backend found

num_gpus()

In [None]:
def num_tpu():
    """Get the number of available TPUs."""
    try:
        return jax.device_count('tpu')
    except RuntimeError:
        return 0  # No TPU backend found
num_tpu()

Now we define two convenient functions that allow us
to run code even if the requested GPUs or TPUs do not exist.


In [None]:
def try_tpu_gpu(i=0):
    """Return tpu(i) if it exists, else gpu(i) if it exists, otherwise cpu()."""
    if num_tpu() >= i + 1:
        return tpu(i)
    elif num_gpus() >= i + 1:
        return gpu(i)
    return cpu()

def try_all_gpus():
    """Return all available GPUs, or [cpu(),] if no GPU exists."""
    devices = [gpu(i) for i in range(num_gpus())]
    return devices if devices else [cpu()]

try_tpu_gpu(), try_tpu_gpu(10), try_all_gpus()

## Tensors and GPUs/TPUs


By default, tensors are created on the GPU/TPU if one is available;
otherwise they are created on the CPU.
We can query the device where the tensor is located.


In [None]:
x = jnp.array([1, 2, 3])
x.devices()

It is important to note that whenever we want
to operate on multiple terms,
they need to be on the same device.
For instance, if we sum two tensors,
we need to make sure that both arguments
live on the same device---otherwise the framework
would not know where to store the result
or even how to decide where to perform the computation.
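
A minimal sketch of this rule, assuming a hypothetical helper `add_on` (not part of JAX) that places both operands on one device with `jax.device_put` before adding them. It targets the CPU device so that it runs on any machine.

```python
import jax
import jax.numpy as jnp

def add_on(device, a, b):
    """Hypothetical helper: put both operands on `device`, then add them."""
    a = jax.device_put(a, device)
    b = jax.device_put(b, device)
    return a + b

cpu0 = jax.devices('cpu')[0]
total = add_on(cpu0, jnp.ones(3), jnp.arange(3.0))
print(total, total.devices())
```

Because both inputs were explicitly placed on `cpu0`, the result lives there as well.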

### Storage on the GPU

There are several ways to store a tensor on the GPU.
For example, we can specify a storage device when creating a tensor.
Next, we create the tensor variable `X` on the first `gpu`.
The tensor created on a GPU only consumes the memory of this GPU.
We can use the `nvidia-smi` command to view GPU memory usage.
In general, we need to make sure that we do not create data that exceeds the GPU memory limit.

The syntax is as follows, where `device` is a device object such as the one returned by `try_tpu_gpu()`:
```
x_ondevice = jax.device_put(x, device)
```


In [None]:
# By default JAX puts arrays to GPUs or TPUs if available
X = jax.device_put(jnp.ones((2, 3)), try_tpu_gpu())
X, X.devices()

Assuming that you have at least two accelerators (GPUs or TPUs), the following code will create a random tensor, `Y`, on the second one. With fewer devices, `try_tpu_gpu(1)` falls back to the CPU.


In [None]:
Y = jax.device_put(jax.random.uniform(jax.random.PRNGKey(0), (2, 3)),
                   try_tpu_gpu(1))
Y, Y.devices()

### Copying

If we want to compute `X + Y`,
we need to decide where to perform this operation.
Simply adding `X` and `Y` raises an error when they sit on different devices:
the runtime engine cannot find both operands on the same device,
so it does not know what to do and fails.
Since `Y` lives on the second accelerator,
we need to move `X` there before we can add the two,
as shown below.

![Copy data to perform an operation on the same device.](http://d2l.ai/_images/copyto.svg)



In [None]:
X + Y

To get `X` onto another device we need to move it. As usual in JAX, this is functional rather than stateful: `jax.device_put` does not change `X` in place but returns a new tensor on the target device.

In [None]:
Z = jax.device_put(X, try_tpu_gpu(1))
print(X, X.devices())
print(Z, Z.devices())

Now that the data (both `Z` and `Y`) are on the same device, we can add them up.


In [None]:
Y + Z

Imagine that your variable `Z` already lives on your second GPU.
What happens if we call `jax.device_put` on `Z` again with the same target device?
The data is already in place, so nothing needs to be copied and no new memory has to be allocated.
The cell below checks whether we get `Z` itself back.


In [None]:
Z2 = jax.device_put(Z, try_tpu_gpu(1))
Z2 is Z

### The bottleneck

Transferring variables between devices is slow: much slower than computation.
So we want you to be 100% certain
that you want to do something slow before we let you do it.
If the deep learning framework just did the copy automatically
without crashing then you might not realize
that you had written some slow code.

Transferring data is not only slow, it also makes parallelization a lot more difficult,
since we have to wait for data to be sent (or rather to be received)
before we can proceed with more operations.
This is why copy operations should be taken with great care.
As a rule of thumb, many small operations
are much worse than one big operation.
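
The rule of thumb can be tried out with a rough sketch like the one below, which compares 200 small host-to-device copies with a single copy of the stacked data. The sizes and variable names are arbitrary choices for illustration, and the measured times depend heavily on the hardware; on a CPU-only machine the difference may be small.

```python
import time
import numpy as np
import jax

device = jax.devices()[0]  # default device: accelerator if present, else CPU
chunks = [np.ones(1_000, dtype=np.float32) for _ in range(200)]

# Many small transfers: one device_put call per chunk.
start = time.perf_counter()
small = [jax.device_put(c, device) for c in chunks]
for s in small:
    s.block_until_ready()
t_small = time.perf_counter() - start

# One big transfer: stack on the host first, then copy once.
start = time.perf_counter()
big = jax.device_put(np.stack(chunks), device)
big.block_until_ready()
t_big = time.perf_counter() - start

print(f"200 small transfers: {t_small:.4f}s, one big transfer: {t_big:.4f}s")
```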

Last, when we print tensors or convert tensors to the NumPy format,
if the data is not in the main memory,
the framework will copy it to the main memory first,
resulting in additional transmission overhead.
Even worse, it is now subject to the dreaded global interpreter lock
that makes everything wait for Python to complete.
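
As a small illustration of such a conversion, the sketch below brings an array back to host memory in two ways: implicitly via `np.asarray` and explicitly via `jax.device_get`. The variable names are made up for this example; on a CPU-only machine no accelerator transfer is involved, but the calls are the same.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)  # lives on the default device

# Converting to NumPy waits for pending work and brings the data to the host.
x_host = np.asarray(x)

# jax.device_get makes the same transfer explicit.
x_host2 = jax.device_get(x)

print(type(x_host), type(x_host2))
```

Both results are plain NumPy arrays, so any further work on them runs on the CPU.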


## Lazy execution

There is another trick that libraries may use to speed up the computation: Python does not wait for a computation to finish. While the actual computation is being performed, the Python code continues _until the value of the computation is needed_ (barrier). JAX calls this asynchronous dispatch.

At this point, the Python code stops until the needed computation is ready.
Usually, we don't notice this (great! That's how it should be!), except if we're profiling code: when running something like `y = jax.something(x)`, `y` is _technically_ not yet computed but a _lazy object_, a placeholder that says "I am holding a value, I promise".
If we run Python code that _needs_ this value (e.g. by calling `np.array(y)`), it will wait until the computation is finished.

Therefore, to profile code, we're not interested in the time it takes to return a placeholder (which can be instant) but in how long it takes until the computation has actually been performed.

`block_until_ready()` forces the Python code to wait until the computation has been executed (its only use case is benchmarking; it should not be needed otherwise).

_Note that we cannot prevent the computation from being executed; JAX merely has the freedom to delay it._
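
To make the difference visible, here is a rough sketch that times the same matrix product twice: once up to the point where the call returns, and once until the result is ready. The matrix size and variable names are arbitrary, and the gap between the two numbers depends on the hardware.

```python
import time
import jax.numpy as jnp

x = jnp.ones((2000, 2000))
jnp.dot(x, x).block_until_ready()  # warm-up, so compilation is not timed

start = time.perf_counter()
y = jnp.dot(x, x)                  # returns once the work has been dispatched
t_dispatch = time.perf_counter() - start

y.block_until_ready()              # wait until the value actually exists
t_total = time.perf_counter() - start

print(f"dispatch: {t_dispatch:.5f}s, until ready: {t_total:.5f}s")
```

Only the second number says something about how long the computation takes.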

In [None]:
import random

In [None]:
%%timeit -n1 -r3
with jax.default_device(cpu()):
    rand = random.randint(0, 10_000)
    x = jnp.linspace(0, 100, num=10_000_000 + rand)
    y = x ** 0.5 - jnp.log(jnp.abs(x))
    # y.block_until_ready()  # uncomment to time the computation, not just the dispatch

In [None]:
n = 100_000_000
# n = 100

## Pseudo random numbers

How hard is it to get random numbers right?

[Hard. JAX has an extensive, in-depth explanation for their choice of randomness](https://jax.readthedocs.io/en/latest/random-numbers.html)

In short, we need to feed a seed/generator to every call (_before you complain, read the above. You may still complain, but not as loudly_).
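
A minimal sketch of what this means in practice, using only `jax.random.PRNGKey`, `jax.random.split` and `jax.random.uniform`; the seed and the variable names are arbitrary.

```python
import jax

key = jax.random.PRNGKey(42)

# The same key always yields the same numbers.
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))

# For fresh numbers, split the key and use each subkey only once.
key, subkey1, subkey2 = jax.random.split(key, 3)
c = jax.random.uniform(subkey1, (3,))
d = jax.random.uniform(subkey2, (3,))

print(a, b)
print(c, d)
```

Here `a` and `b` are identical, while `c` and `d` come from different subkeys and differ.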

In [None]:
%%timeit -n5
x = jax.random.uniform(jax.random.PRNGKey(random.randint(0, 10**12)), (n,))  # on GPU/TPU
xcpu = jax.device_put(x, cpu())  # move to CPU
ycpu = xcpu ** 2 - 0.5  # calculate on CPU
yacc = jax.device_put(ycpu, try_tpu_gpu())  # move to GPU/TPU
zacc = yacc ** 3 - 0.2  # calculate on GPU/TPU
zacc.block_until_ready()

In [None]:
%%timeit -n5
with jax.default_device(cpu()):
    xcpu = jax.random.uniform(jax.random.PRNGKey(random.randint(0, 10**12)), (n,))  # on CPU
    ycpu = xcpu ** 2 - 0.5
    zcpu = ycpu ** 3 - 0.2
    zcpu.block_until_ready()

In [None]:
%%timeit -n5
with jax.default_device(try_tpu_gpu()):
    xacc = jax.random.uniform(jax.random.PRNGKey(random.randint(0, 10**12)), (n,))  # on GPU/TPU
    yacc = xacc ** 2 - 0.5
    zacc = yacc ** 3 - 0.2
    zacc.block_until_ready()

As mentioned in the talk, TF and torch are similar. The concepts are about the same; each library just has a slightly different syntax and way of doing things.

In [None]:
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import torch

In [None]:
tf.config.get_visible_devices()

In [None]:

with tf.device('cpu'):
    x = tnp.linspace(0, 10, num=100)
x.device

In [None]:
torch.cuda.device_count()

In [None]:
x = torch.linspace(0., 10, 100).to('cpu')  # why I dislike torch: `num=100` fails, the argument is called `steps`.
x.device                                   # torch is not very pythonic sometimes...