<a href="https://colab.research.google.com/github/google/jax/blob/main/tests/notebooks/colab_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JAX Colab TPU Test

This notebook is meant to be run in a [Colab](http://colab.research.google.com) TPU runtime as a basic smoke test after JAX updates.

In [1]:
import jax
import jaxlib

!cat /var/colab/hostname
print(jax.__version__)
print(jaxlib.__version__)

tpu-s-2dna7uebo6z96
0.1.64
0.1.45


## TPU Setup

In [2]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

## Confirm Device

In [5]:
from jaxlib import tpu_client_extension
import jax
key = jax.random.PRNGKey(1701)
arr = jax.random.normal(key, (1000,))
device = arr.device_buffer.device()
print(f"JAX device type: {device}")
assert isinstance(device, tpu_client_extension.TpuDevice), "unexpected JAX device type"

JAX device type: TPU_0(host=0,(0,0,0,0))


## Matrix Multiplication

In [6]:
import jax
import numpy as np

# matrix multiplication on TPU
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (3000, 3000))
result = jax.numpy.dot(x, x.T).mean()
print(result)

1.021576
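The printed mean is close to 1, which is what standard normal inputs should give: the sum of all entries of `x @ x.T` equals the sum of the squared column sums of `x`, and each squared column sum has expectation `n`. The mean therefore has expectation 1, with a standard deviation of roughly `sqrt(2 / n)` (about 0.026 for `n = 3000`). The following is a minimal NumPy sketch of that reasoning, not part of the original test; the reduced size `n = 500` and the seed are assumptions chosen so it runs quickly on a CPU.

```python
import numpy as np

# Assumption: n is reduced from 3000 to 500 so the check runs quickly on a CPU.
n = 500
rng = np.random.default_rng(0)
x = rng.standard_normal((n, n))

# Direct computation, mirroring the cell above.
direct = np.dot(x, x.T).mean()

# Equivalent form: the sum of all entries of x @ x.T equals
# the sum of the squared column sums of x.
col_sums = x.sum(axis=0)
via_col_sums = (col_sums ** 2).sum() / n**2

print(direct, via_col_sums)
```

Both values should agree to floating-point precision and lie near 1; a TPU result far from 1 would point to a problem with the runtime rather than with the random draw.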


## XLA Compilation

In [8]:
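@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
  # SELU: lmbda * x for x > 0, lmbda * alpha * (exp(x) - 1) otherwise
  return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)

x = jax.random.normal(key, (5000,))
# block_until_ready() waits for the asynchronously dispatched computation to finish
result = selu(x).block_until_ready()
print(result)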

[ 0.34676817 -0.7532211   1.7060809  ...  2.120809   -0.42622015
  0.13093244]
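One way to sanity-check the jitted output is to compare it against a plain NumPy version of the same formula. The sketch below is an illustration under assumptions rather than part of the original test: `selu_reference` is a hypothetical helper name, and the sample inputs are made up to show both branches.

```python
import numpy as np

def selu_reference(x, alpha=1.67, lmbda=1.05):
    """NumPy version of the SELU formula from the jitted cell (hypothetical helper)."""
    return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

# Illustrative inputs covering the negative branch, zero, and the positive branch.
x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
y = selu_reference(x)
print(y)

# Positive inputs are scaled by lmbda; negative inputs stay above -lmbda * alpha.
assert np.allclose(y[x > 0], 1.05 * x[x > 0])
assert np.all(y > -1.05 * 1.67)
```

In the notebook itself, the same helper could be applied to `np.asarray(x)` from the cell above and compared to `np.asarray(result)` with `np.allclose`, using a loose tolerance, since TPU float32 arithmetic may differ slightly from NumPy on the host.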
