In [22]:
# Colab-only runtime setup: presumably attaches this kernel to the Colab TPU so that
# the JAX cells below run on TPU devices.
# NOTE(review): jax.tools.colab_tpu targets the legacy Colab TPU runtime — TODO confirm it
# is still available/required for the JAX version in use (newer runtimes may not need it).
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [23]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg  # explicit: jax.scipy.linalg.lu is used below
from jax import random
import numpy as np
import scipy as sp
import scipy.linalg  # explicit: `import scipy` alone does not load the linalg submodule on older SciPy

In [29]:
print('LU Decomp, SciPy ====')
A = np.random.rand(1000, 1000)
%timeit sp.linalg.lu(A)

print('LU Decomp, Jax (TPU) ====')
key = random.PRNGKey(42)
A = random.normal(key, (1000,1000))
%timeit jax.scipy.linalg.lu(A)

LU Decomp, Numpy ====
10 loops, best of 5: 76.5 ms per loop
LU Decomp, Jax (TPU) ====
100 loops, best of 5: 5.76 ms per loop


In [25]:
print('QR Decomp, Numpy ====')
A = np.random.rand(1000, 1000)
%timeit np.linalg.qr(A)

print('QR Decomp, Jax (TPU) ====')
A = random.normal(key, (1000,1000))
%timeit jnp.linalg.qr(A)

QR Decomp, Numpy ====
1 loop, best of 5: 338 ms per loop
QR Decomp, Jax (TPU) ====
100 loops, best of 5: 6.01 ms per loop


In [26]:
print('Matrix Multiplication, Numpy ====')
A = np.random.rand(3000, 3000)
B = np.random.rand(3000, 3000)
%timeit np.dot(A, B)

print('Matrix Multiplication, Jax (TPU) ====')
A = random.normal(key, (3000,3000))
B = random.normal(key, (3000,3000))
A_jax = jax.device_put(A)
B_jax = jax.device_put(B)
jnp.dot(A_jax, B_jax).block_until_ready()
%timeit jnp.dot(A_jax, B_jax).block_until_ready()

Matrix Multiplication, Numpy ====
1 loop, best of 5: 2.73 s per loop
Matrix Multiplication, Jax (TPU) ====
100 loops, best of 5: 4.47 ms per loop


In [28]:
def element_wise_ops(A, B):
  """Compute ``A * B + B`` element-wise; works for NumPy and JAX arrays alike."""
  product = A * B
  return product + B

print('Element-wise Operations, Numpy ====')
A = np.random.rand(1000, 1000)
B = np.random.rand(1000, 1000)
%timeit element_wise_ops(A, B)

print('Element-wise Operations, Jax (TPU) ====')
A = random.normal(key, (1000, 1000))
B = random.normal(key, (1000, 1000))
A_jax = jax.device_put(A)
B_jax = jax.device_put(B)
element_wise_jit = jax.jit(element_wise_ops)
%timeit element_wise_jit(A_jax, B_jax)

Element-wise Operations, Numpy ====
100 loops, best of 5: 3.45 ms per loop
Element-wise Operations, Jax (TPU) ====
The slowest run took 33.16 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 1.82 ms per loop
