In [4]:
import jax.numpy as jnp
from jax import vmap, jit, grad
from jax import random

# This is the activation function that 
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

# Define MLP model
def MLP(params, x):
    for W, b in params:
        x = jnp.dot(x, W) + b
        x = sigmoid(x)
    return x

# Define loss function (mean squared error)
def loss(params, x, y):
    y_pred = MLP(params, x)
    return jnp.mean((y - y_pred)**2)

# Initialize random parameters
def init_params(layer_sizes, key):
    keys = random.split(key, len(layer_sizes))
    return [(random.normal(k, (n_in, n_out)),  # weight matrix
             random.normal(k, (n_out,)))      # bias vector
            for k, (n_in, n_out) in zip(keys, zip(layer_sizes[:-1], layer_sizes[1:]))]

# Define update function for gradient descent
@jit  
def update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(W - lr*dW, b - lr*db) for (W, b), (dW, db) in zip(params, grads)]

# Generate fake data for training
key = random.PRNGKey(0)
x = random.normal(key, (100, 10))
y = random.normal(key, (100, 1))

# Define network architecture of the network. 
layer_sizes = [10, 10, 1]

# Initialize parameters to a default setting
params = init_params(layer_sizes, key)
print("Using Autodiff")
print()
# Train network using gradient descent and we will update the parameters after each iteration
for i in range(1000):
    params = update(params, x, y, 0.1)
    if i % 100 == 0:
        print(loss(params, x, y))
print()


Using Autodiff

1.16791
0.89074874
0.87983733
0.87559766
0.8721167
0.8686884
0.8651581
0.86143035
0.85739547
0.8529395

