In [54]:
import jax
import jax.numpy as jnp
import jax.random as random
from jax import grad, vmap

In [45]:
# Experiment sizes: presumably the mini-batch size used for training, then
# the number of synthetic training / validation rows passed to make_data.
batch_size, n_train, n_valid = 32, 5_000, 1_000

In [46]:
# First define a basic neural net
def relu_layer(params, x):
    """Apply one fully-connected layer followed by a ReLU.

    Args:
        params: indexable pair; ``params[0]`` is the weight matrix and
            ``params[1]`` the bias (extra trailing entries are ignored).
        x: input to the layer, contracted against the weight with ``jnp.dot``.

    Returns:
        ``max(0, params[0] . x + params[1])`` computed elementwise.
    """
    weight = params[0]
    bias = params[1]
    affine = jnp.dot(weight, x) + bias
    return jnp.maximum(0, affine)

def forward_pass(params, in_array):
    """Run an MLP on a single input.

    Every ``(weight, bias)`` pair except the last is an affine map followed by
    a ReLU; the final pair is a plain affine map (no activation), so the
    output is unbounded.

    Args:
        params: sequence of ``(weight, bias)`` pairs, one per layer; must be
            non-empty.
        in_array: a single example (presumably a 1-D feature vector, since
            ``batch_forward_pass`` vmaps this over axis 0 — confirm).

    Returns:
        The network output ``w_out . h + b_out`` for the last hidden state ``h``.
    """
    *hidden_layers, (out_w, out_b) = params
    activations = in_array
    for weight, bias in hidden_layers:
        # Same computation as relu_layer([weight, bias], activations).
        activations = jnp.maximum(0, jnp.dot(weight, activations) + bias)
    return jnp.dot(out_w, activations) + out_b

# Batched version of forward_pass: params are shared across the batch
# (in_axes None), inputs are mapped over their leading axis, and outputs are
# stacked along axis 0 (vmap's default out_axes).
batch_forward_pass = vmap(forward_pass, in_axes=(None, 0))

In [81]:
# Make some synthetic data
def make_data(rng_key, n_rows, noise_scale=0.1):
    """Generate a synthetic two-feature regression dataset.

    Features:
        x_1 ~ Normal(0, 1)
        x_2 ~ Uniform(-0.5, 0.5)
    Target:
        y = 0.5 * x_1 + cos(x_2) + eps,   eps ~ Normal(0, noise_scale**2)

    The noise ``eps`` is drawn independently for every row. Drawing a single
    scalar would only shift all targets by the same constant.

    Args:
        rng_key: JAX PRNG key; it is split internally and never reused directly.
        n_rows: number of samples to generate.
        noise_scale: standard deviation of the per-row Gaussian noise
            (default 0.1).

    Returns:
        Tuple ``(x_features, y)``: ``x_features`` has shape ``(n_rows, 2)``
        with columns ``[x_1, x_2]``; ``y`` has shape ``(n_rows, 1)``.
    """
    # Sequential splitting (kept as before) so x_1 and x_2 are the same draws
    # for a given rng_key.
    rng_key, subkey = random.split(rng_key)
    x_1 = random.normal(key=subkey, shape=(n_rows, 1))
    rng_key, subkey = random.split(rng_key)
    x_2 = random.uniform(key=subkey, shape=(n_rows, 1), minval=-0.5, maxval=+0.5)

    y_expected = 0.5 * x_1 + jnp.cos(x_2)

    _, subkey = random.split(rng_key)
    # One noise sample per row, matching y_expected's (n_rows, 1) shape.
    eps = noise_scale * random.normal(key=subkey, shape=(n_rows, 1))
    y = y_expected + eps

    x_features = jnp.hstack([x_1, x_2])

    return (x_features, y)

In [84]:
# Derive two independent PRNG keys from one fixed seed so the training and
# validation sets are reproducible and built from different random draws.
key = random.PRNGKey(123)
train_key, valid_key = random.split(key)
training_data, validation_data = (
    make_data(train_key, n_train),
    make_data(valid_key, n_valid),
)