In [0]:
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [0]:
!pip install --upgrade jax

Collecting jax
[?25l  Downloading https://files.pythonhosted.org/packages/50/f4/d90107c22334c267ccb64e0ea8039018a4740b5dfad1576dd868aac45254/jax-0.1.59.tar.gz (270kB)
[K     |█▏                              | 10kB 26.8MB/s eta 0:00:01[K     |██▍                             | 20kB 6.3MB/s eta 0:00:01[K     |███▋                            | 30kB 8.8MB/s eta 0:00:01[K     |████▉                           | 40kB 5.8MB/s eta 0:00:01[K     |██████                          | 51kB 4.8MB/s eta 0:00:01[K     |███████▎                        | 61kB 5.7MB/s eta 0:00:01[K     |████████▌                       | 71kB 6.2MB/s eta 0:00:01[K     |█████████▊                      | 81kB 7.0MB/s eta 0:00:01[K     |███████████                     | 92kB 7.8MB/s eta 0:00:01[K     |████████████▏                   | 102kB 7.7MB/s eta 0:00:01[K     |█████████████▎                  | 112kB 7.7MB/s eta 0:00:01[K     |██████████████▌                 | 122kB 7.7MB/s eta 0:00:01[K     |██

# JAX 1. Numpy Wrapper

In [0]:
import numpy as np

x = np.ones((5000, 5000))
y = np.arange(5000)

%timeit z = np.sin(x) + np.cos(y)

1 loop, best of 3: 401 ms per loop


In [0]:
import jax.numpy as jnp
x = jnp.ones((5000, 5000))
y = jnp.arange(5000)

%timeit z = jnp.sin(x) + jnp.cos(y)

100 loops, best of 3: 2.15 ms per loop


# JAX 2. JIT Compiler

In [0]:
from jax import jit
import tensorflow as tf

@jit
def fn(x, y):
  z = jnp.sin(x)
  w = jnp.cos(y)
  return z + w

def fn2(x, y):
  z = np.sin(x)
  w = np.cos(y)
  return z + w

@tf.function
def fn3(x, y):
  z = tf.sin(x)
  w = tf.cos(y)
  return z + w

def fn4(x, y):
  z = tf.sin(x)
  w = tf.cos(y)
  return z + w

In [0]:
jx = jnp.ones((5000, 5000))
jy = jnp.ones((5000, 5000))
%timeit fn(jx, jy)

100 loops, best of 3: 2.12 ms per loop


In [0]:
x = np.ones((5000, 5000))
y = np.ones((5000, 5000))
%timeit fn2(x, y)

1 loop, best of 3: 780 ms per loop


In [0]:
tx = tf.ones((5000, 5000))
ty = tf.ones((5000, 5000))
%timeit fn3(tx, ty)

The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 3.36 ms per loop


In [0]:
tx = tf.ones((5000, 5000))
ty = tf.ones((5000, 5000))
%timeit fn4(tx, ty)

1000 loops, best of 3: 3.36 ms per loop


# JAX 3. grad

In [0]:
from jax import grad

@jit
def simple_fun(x):
  return jnp.sin(x) / x

In [0]:
grad_simple_fun = grad(simple_fun)

In [0]:
%timeit grad_simple_fun(1.0)

1000 loops, best of 3: 1.22 ms per loop


In [0]:
x_range = jnp.arange(10, dtype=jnp.float32)
[grad_simple_fun(xi) for xi in x_range]

[DeviceArray(nan, dtype=float32),
 DeviceArray(-0.30116874, dtype=float32),
 DeviceArray(-0.43539774, dtype=float32),
 DeviceArray(-0.3456775, dtype=float32),
 DeviceArray(-0.11611074, dtype=float32),
 DeviceArray(0.09508941, dtype=float32),
 DeviceArray(0.16778992, dtype=float32),
 DeviceArray(0.09429243, dtype=float32),
 DeviceArray(-0.03364623, dtype=float32),
 DeviceArray(-0.10632458, dtype=float32)]

In [0]:
grad_grad_simple_fun = grad(grad(simple_fun))

In [0]:
%timeit grad_grad_simple_fun(1.0)

The slowest run took 93.35 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 3.19 ms per loop


In [0]:
grad_grad_simple_fun(1.0)

DeviceArray(-0.23913354, dtype=float32)

In [0]:
x_range = jnp.arange(10, dtype=jnp.float32)
[grad_grad_simple_fun(xi) for xi in x_range]

[DeviceArray(nan, dtype=float32),
 DeviceArray(-0.23913354, dtype=float32),
 DeviceArray(-0.01925094, dtype=float32),
 DeviceArray(0.18341166, dtype=float32),
 DeviceArray(0.247256, dtype=float32),
 DeviceArray(0.1537491, dtype=float32),
 DeviceArray(-0.00936072, dtype=float32),
 DeviceArray(-0.12079593, dtype=float32),
 DeviceArray(-0.11525822, dtype=float32),
 DeviceArray(-0.02216326, dtype=float32)]