In [3]:
import jax
devices = jax.devices()

for item in devices:
  print(item)

TPU_0(process=0,(0,0,0,0))
TPU_1(process=0,(0,0,0,1))
TPU_2(process=0,(1,0,0,0))
TPU_3(process=0,(1,0,0,1))
TPU_4(process=0,(0,1,0,0))
TPU_5(process=0,(0,1,0,1))
TPU_6(process=0,(1,1,0,0))
TPU_7(process=0,(1,1,0,1))


In [4]:
import jax.numpy as jnp
from jax import random
from jax import grad, jit
import numpy as np

key = random.PRNGKey(0)

In [5]:
# runs on CPU - numpy
size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)

138 ms ± 584 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
# runs on CPU - JAX
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

13.4 ms ± 65.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:

# runs on TPU
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%time x_jax = jax.device_put(x)  # 1. measure JAX device transfer time
%time jnp.dot(x_jax, x_jax.T).block_until_ready()  # 2. measure JAX compilation time
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time

CPU times: user 75 µs, sys: 110 µs, total: 185 µs
Wall time: 231 µs
CPU times: user 1.92 ms, sys: 2.02 ms, total: 3.94 ms
Wall time: 14.8 ms
13.4 ms ± 23.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
from time import time

a = time()
x = jnp.dot(x_jax, x_jax.T)
b = time()

m = b-a
print(b-a)

0.001363992691040039


In [20]:
a = time()
y = np.dot(x, x.T)
b = time()

n = b-a
print(b-a)

0.309955358505249


In [22]:
print(n/m)

227.24121657052962


In [35]:
def selu_np(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

@jax.jit
def selu_jax(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [36]:
# runs on the CPU - numpy


x = np.random.normal(size=(1000000,)).astype(np.float32)

a = time()
selu_np(x)
b = time()

print(b-a)

0.007717132568359375


In [38]:
# runs on the CPU - JAX
x = random.normal(key, (1000000,))
x = jax.device_put(x)
selu_jax(x).block_until_ready() # 1. measure JAX compilation time

a = time()
selu_jax(x).block_until_ready() # 2. measure JAX runtime
b = time()

print(b-a)

0.0007071495056152344
