JAX As Accelerated NumPy

https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html

In [None]:
import jax
import jax.numpy as jnp

In [None]:
# jnp arrays mirror NumPy's API; the bare expression makes the cell display it.
x = jnp.arange(0, 10)
x

In [None]:
big_vector = jnp.arange(int(1e7))

%timeit jnp.dot(big_vector, big_vector)

In [None]:
def mse(y, t):
    """Return the mean squared error between predictions ``y`` and targets ``t``."""
    return jnp.mean((y - t) ** 2)


key_seed = jax.random.PRNGKey(420)
# JAX's guidance is to never reuse a key: draws made from the same key are not
# guaranteed to be independent. Split it into one subkey per random draw.
key_y, key_noise = jax.random.split(key_seed)
y = jax.random.randint(key_y, (1, 4), -2, 2).astype(jnp.float32)
t = y + 0.01 * jax.random.normal(key_noise, (1, 4))
print(f'y = {y} \nt = {t}')

print(f'mse {mse(y, t)}')

# argnums selects which positional argument(s) to differentiate with respect
# to; a tuple of argnums returns a tuple with one gradient per argument.
dmse_dy = jax.grad(mse, argnums=0)
dmse_dt = jax.grad(mse, argnums=1)
dmse_dydt = jax.grad(mse, argnums=(0, 1))

print(f'dmse_dy {dmse_dy(y,t)}')
print(f'dmse_dt {dmse_dt(y,t)}')
print(f'dmse_dydt {dmse_dydt(y,t)}')


In [None]:
# One call yields (mse(y, t), d mse / d y); argnums=0 is also the default.
jax.value_and_grad(mse, argnums=0)(y, t)

In [None]:
def mse_with_aux_loss(y, t):
    """Return ``(mse(y, t), |t - y|)``: a scalar loss plus an auxiliary output.

    With ``has_aux=True`` only the first element is differentiated; the
    second element is returned alongside it unchanged.
    """
    primary = mse(y, t)
    aux = jnp.abs(t - y)
    return primary, aux

# has_aux=True: the function returns (loss, aux), and the call yields
# ((loss, aux), gradient of loss w.r.t. y).
jax.value_and_grad(mse_with_aux_loss, has_aux=True)(y, t)

In [None]:
def illegal_inplace(x):
    """Mutate ``x`` in place, NumPy-style, by setting element 0 to 1.

    This works on mutable containers such as lists or NumPy arrays, but JAX
    arrays are immutable and raise ``TypeError`` on item assignment.
    """
    x[0] = 1
    return None


def legal_inplace(x):
    """Return a copy of ``x`` with element 0 set to 1; ``x`` itself is untouched.

    ``x.at[idx].set(value)`` is JAX's functional replacement for ``x[idx] = value``.
    """
    return x.at[0].set(1)


x = jnp.asarray([0, 0, 0])

# Show the failure mode without aborting the cell.
try:
    illegal_inplace(x)
except TypeError as err:
    print(f'illegal_inplace raised TypeError: {err}')

legal_inplace(x)

In [None]:
import numpy as np
import matplotlib.pyplot as plt 

In [None]:
mu = 0.5
std = 0.1
key_seed = jax.random.PRNGKey(69)
# Use independent subkeys. Drawing x and noise from the same key with the same
# shape scales the *same* normal sample, so "noise" would equal 0.1 * (x - mu)
# and y would be an exact line (2.9x + 0.05) instead of a noisy 3x.
key_x, key_noise = jax.random.split(key_seed)
x = mu + std * jax.random.normal(key_x, (100,))
noise = 0.01 * jax.random.normal(key_noise, (100,))

# Simple linear function y = 3x plus small noise; let's see if we can learn it.
y = x * 3 - noise
plt.scatter(x, y)

In [None]:
def model(params, input):
    """Evaluate the linear model ``w * input + b`` with ``params = [w, b]``."""
    slope = params[0]
    intercept = params[1]
    return slope * input + intercept

def loss(params, input, target):
    """Mean squared error of the model's predictions against ``target``."""
    residual = model(params, input) - target
    return jnp.mean(residual ** 2)

def update(params, input, target, lr=0.1):
    """Take one plain gradient-descent step on ``loss`` with step size ``lr``."""
    grads = jax.grad(loss)(params, input, target)
    return params - lr * grads

# Initialize [w, b] to zeros; float32 is required so jax.grad can differentiate.
params = jnp.zeros(2, dtype=jnp.float32)

num_steps = 1000
plt.scatter(x, y)
for i in range(num_steps):
    if i % 50 == 0:
        print(f'step {i} params {params}')
        plt.plot(x, model(params, x))
    params = update(params, x, y)

# The loop logs *before* each update, so the last report would be step 950;
# show the fully trained parameters and fit as well.
print(f'step {num_steps} params {params}')
plt.plot(x, model(params, x), 'k', linewidth=2)
