<a href="https://colab.research.google.com/github/kbrezinski/JAX-Practice/blob/main/introduction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
import numpy as np

# special transformations
from jax import grad, jit, vmap, pmap
# JAX low level APIs
from jax import lax, make_jaxpr, random, device_put

In [None]:
# show comptability with numpy arrays
x = jnp.linspace(0, 10, 1000)
y = 2 * np.sin(x) * np.cos(x)

In [None]:
# how to set values on immutable JAX arrays
x = jnp.arange(10)
y = x.at[4].set(-1) 
y

DeviceArray([ 0,  1,  2,  3, -1,  5,  6,  7,  8,  9], dtype=int32)

In [3]:
# in Jax random numbers are stateful, need to be passed in as args
k = random.PRNGKey(2021)
x = random.normal(k, (10,))
x

DeviceArray([-1.9064336 ,  0.9475057 , -0.0449216 , -0.5956921 ,
              0.92937005, -0.33263102, -2.6711109 ,  0.6385125 ,
              0.55228376,  1.7470944 ], dtype=float32)

In [None]:
x = random.normal(k, (4, 4), dtype=jnp.float32) # automatically cast to acccelerated by default
%timeit jnp.dot(x, x.T).block_until_ready() # runs faster than numpy because its being accelerated on GPU and no overhead
x = device_put(np.random.normal(size=(4, 4))).astype(np.float32) # push numpy to GPU

# block_until_ready() waits until the completion is done using asynchronous dispatch

The slowest run took 8.31 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 322 µs per loop


In [None]:
# jit caches the intermediate results and make computation faster
def selu(x, a=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, a * jnp.exp(x) - a)

selu_jit = jit(selu)
x = random.normal(k, (1_000_000,))
%timeit selu_jit(x).block_until_ready()

The slowest run took 673.03 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 113 µs per loop


In [None]:
def sum_logistic(x):
  #return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
  return jnp.sum(x**2) # 2*x1 + 2*x2 + 2*x3

# [1, 2, 3]
x = jnp.arange(3.)
loss = sum_logistic # rename

# wrap loss fn around grad
grad_loss = grad(loss)
# determine loss
grad_loss(x) # why does it output as an array?

DeviceArray([0., 2., 4.], dtype=float32)

In [None]:
## Manual gradients
x = 1.
y = 1.

# function is a product of x and y
f = lambda x, y: x**2 + x + 4 + y**2

# passing in argnums will tell grad to differentiate wrt y instead
dfdx = grad(f, argnums=(1)) # 2*x + 1
d2fdx = grad(dfdx, argnums=(1)) # 2
d3fdx = grad(d2fdx, argnums=(1)) # 0

print(f(x, y), dfdx(x, y), d2fdx(x, y), d3fdx(x, y))

7.0 2.0 2.0 0.0


In [None]:
from jax import jacfwd, jacrev # find jacobians for deritives

f = lambda x, y: x**2 + y**2
# Jacobian = [df/dx, df/dy]
# Hessian = [[d2f/dx, d2f/dxdy], [d2f/dydx, d2f/dy]]

def hessian(f):
  return jit(jacfwd(jacrev(f, argnums=(0,1)), argnums=(0,1)))

jacrev(f, argnums=(0,1))(1.,1.) # Jacobian = [2, 2]
hessian(f)(1.,1.) # Hessian [[2,0],[0,2]]

((DeviceArray(2., dtype=float32), DeviceArray(0., dtype=float32)),
 (DeviceArray(0., dtype=float32), DeviceArray(2., dtype=float32)))

In [None]:
f = lambda x: abs(x)

# grad is smart enough to take the derivitive wrt x of the abs(0.) and return 1 even though its not differentiable
dfdx = grad(f)
print(dfdx(0.), dfdx(-1.))

1.0 -1.0


In [5]:
W = random.normal(k, (150, 100)) # weights that are of size (150, 100)
batched_x = random.normal(k, (10, 100)) # 10 samples of 100 features

# slow implementation, will not work because of sizes
def apply_matrix(W, x):
  return jnp.dot(W, x)

# very slow using loops, need to vectorize
def naively_batched_apply_matrix(batched_x):
  return jnp.stack([apply_matrix(x) for x in batched_x])

# really fast jit; does not account for single or batched sizes of 1 or 100
@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, W.T)

# apply vmap to single instance fn, then use batched on it.
@jit
def vmap_batched_apply_matrix(W, batched_x):
  return vmap(apply_matrix, in_axes=(None, 0))(W, batched_x)  # None because W has no batched dimension, 0 because the zeroth dim is the batch

%timeit vmap_batched_apply_matrix(W, batched_x).block_until_ready()

The slowest run took 80.35 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 93.1 µs per loop


In [7]:
## Lax is stricter

jnp.add(1, 1.)
#lax.add(1, 1.)  # error diff data types


TypeError: ignored