# JAX implementation of XOR

In [None]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, value_and_grad, vmap

In [None]:
# Network topology: 2 input nodes -> 2 hidden nodes -> 1 output node.
layer_sizes = [2, 2, 1]

def init_mlp_params(layer_sizes, key):
    """Build the initial parameters of a fully connected network.

    Args:
        layer_sizes: Sequence of layer widths, input first and output last.
        key: JAX PRNG key; it is split into one subkey per layer.

    Returns:
        A list with one dict per layer. ``'w'`` is drawn from a standard
        normal with shape ``(nout, nin)`` and ``'b'`` is a zero vector of
        shape ``(nout,)``.
    """
    num_layers = len(layer_sizes) - 1
    per_layer_keys = random.split(key, num_layers)
    fan_pairs = zip(layer_sizes[:-1], layer_sizes[1:])

    def make_layer(fan_in, fan_out, layer_key):
        # The second subkey is split off but not consumed: biases start at zero.
        weight_key, _bias_key = random.split(layer_key)
        return {
            'w': random.normal(weight_key, (fan_out, fan_in)),
            'b': jnp.zeros((fan_out,)),
        }

    return [
        make_layer(fan_in, fan_out, layer_key)
        for (fan_in, fan_out), layer_key in zip(fan_pairs, per_layer_keys)
    ]

def mlp_apply(params, inputs):
    """Forward pass of the MLP for a single input.

    Every layer except the last applies an affine map followed by tanh;
    the last layer is affine only, so the output is unbounded.

    Args:
        params: Sequence of layer dicts with ``'w'`` and ``'b'`` entries.
        inputs: Input fed to the first layer's ``w @ x`` product.

    Returns:
        The output of the final affine layer.
    """
    def affine(layer, vec):
        return layer['w'] @ vec + layer['b']

    hidden_layers, output_layer = params[:-1], params[-1]

    activation = inputs
    for layer in hidden_layers:
        activation = jax.nn.tanh(affine(layer, activation))

    # No non-linearity on the output layer.
    return affine(output_layer, activation)

def batched_loss_fn(params, inputs, targets):
    """Mean squared error of the network over a batch of examples.

    ``mlp_apply`` is mapped over the leading axis of ``inputs`` while
    ``params`` is shared by every example. Size-1 axes are squeezed out of
    the predictions before they are compared with ``targets``.

    Args:
        params: Network parameters, passed unchanged to ``mlp_apply``.
        inputs: Batch of inputs; axis 0 indexes the examples.
        targets: Target values subtracted from the squeezed predictions.

    Returns:
        The mean of the squared residuals.
    """
    forward_over_batch = vmap(mlp_apply, in_axes=(None, 0))
    residuals = jnp.squeeze(forward_over_batch(params, inputs)) - targets
    return jnp.mean(residuals ** 2)

@jit
def batched_train_step(params, inputs, targets, learning_rate):
    """Run one full-batch gradient-descent update.

    Args:
        params: Current network parameters (a pytree).
        inputs: Batch of inputs passed to ``batched_loss_fn``.
        targets: Targets passed to ``batched_loss_fn``.
        learning_rate: Step size that scales each gradient.

    Returns:
        ``(new_params, loss)``, where ``loss`` was evaluated at the
        parameters passed in, i.e. before the update.
    """
    loss_and_grad = value_and_grad(batched_loss_fn)
    loss, gradients = loss_and_grad(params, inputs, targets)

    # Naive hand-written update rule; a proper optimizer would replace this.
    def descend(param, gradient):
        return param - learning_rate * gradient

    new_params = jax.tree_util.tree_map(descend, params, gradients)
    return new_params, loss

In [3]:
def sgd_loss_fn(params, inputs, target):
    """Squared error of the network on a single example.

    Args:
        params: Network parameters, passed unchanged to ``mlp_apply``.
        inputs: One input example.
        target: Target value for that example.

    Returns:
        The squared difference between the squeezed prediction and ``target``.
    """
    residual = jnp.squeeze(mlp_apply(params, inputs)) - target
    return residual ** 2

@jit
def sgd_train_step(params, inputs, target, learning_rate):
    """Run one stochastic gradient-descent update on a single example.

    Args:
        params: Current network parameters (a pytree).
        inputs: One input example passed to ``sgd_loss_fn``.
        target: Target for that example.
        learning_rate: Step size that scales each gradient.

    Returns:
        ``(new_params, loss)``, where ``loss`` was evaluated at the
        parameters passed in, i.e. before the update.
    """
    loss_and_grad = value_and_grad(sgd_loss_fn)
    loss, gradient = loss_and_grad(params, inputs, target)

    def descend(param, gradient_leaf):
        return param - learning_rate * gradient_leaf

    new_params = jax.tree_util.tree_map(descend, params, gradient)
    return new_params, loss

In [6]:
# --- Experiment configuration ---
key = random.PRNGKey(42)
epochs = 1000
batch_learning_rate = 0.1
sgd_learning_rate = 0.02

# XOR truth table: four input pairs and their labels.
X_train = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
Y_train = jnp.array([0, 1, 1, 0])

# Both runs are initialised from the same key, so they start from
# identical weights.
batched_params = init_mlp_params(layer_sizes, key)
sgd_params = init_mlp_params(layer_sizes, key)

# --- Full-batch gradient descent: one update per epoch on all examples ---
print("____Batched Gradient Descent____")
for epoch in range(epochs):
    batched_params, loss = batched_train_step(
        batched_params, X_train, Y_train, batch_learning_rate
    )
    if epoch % 100 != 0:
        continue
    print(f"Epoch {epoch}, Loss: {loss:.4f}")

# --- Stochastic gradient descent: one update per example, reshuffled each epoch ---
print("\n____SGD Gradient Descent____")
for epoch in range(epochs):
    # Fresh subkey per epoch so every epoch visits the examples in a new order.
    key, subkey = random.split(key)
    order = random.permutation(subkey, 4)
    X_shuffled, Y_shuffled = X_train[order], Y_train[order]

    for x_input, y_target in zip(X_shuffled, Y_shuffled):
        sgd_params, loss = sgd_train_step(
            sgd_params, x_input, y_target, sgd_learning_rate
        )

    if epoch % 100 == 0:
        # The per-example loss above only reflects the last example, so
        # report the mean squared error over the whole dataset instead.
        predictions = vmap(mlp_apply, in_axes=(None, 0))(sgd_params, X_train)
        epoch_loss = jnp.mean((jnp.squeeze(predictions) - Y_train) ** 2)
        print(f"Epoch {epoch}, Average Loss: {epoch_loss:.4f}")

____Batched Gradient Descent____
Epoch 0, Loss: 2.8005
Epoch 100, Loss: 0.2331
Epoch 200, Loss: 0.1599
Epoch 300, Loss: 0.0455
Epoch 400, Loss: 0.0029
Epoch 500, Loss: 0.0001
Epoch 600, Loss: 0.0000
Epoch 700, Loss: 0.0000
Epoch 800, Loss: 0.0000
Epoch 900, Loss: 0.0000

____SGD Gradient Descent____
Epoch 0, Average Loss: 1.2757
Epoch 100, Average Loss: 0.2412
Epoch 200, Average Loss: 0.2096
Epoch 300, Average Loss: 0.1463
Epoch 400, Average Loss: 0.0538
Epoch 500, Average Loss: 0.0048
Epoch 600, Average Loss: 0.0002
Epoch 700, Average Loss: 0.0000
Epoch 800, Average Loss: 0.0000
Epoch 900, Average Loss: 0.0000


In [7]:
# Compare both trained models against the XOR truth table.
# The forward pass is vectorised over the input rows once and reused for
# both parameter sets instead of rebuilding the same vmap twice.
predict_all = vmap(mlp_apply, in_axes=(None, 0))

# --- Model trained with full-batch gradient descent ---
final_predictions = jnp.squeeze(predict_all(batched_params, X_train))
# Round the continuous outputs to get a 0/1 class label.
binary_predictions = jnp.round(final_predictions)
for inputs, pred, actual in zip(X_train, binary_predictions, Y_train):
    print(f"Input: {inputs}, Prediction: {pred}, Actual: {actual}")

# You can also see the raw (unrounded) output values
print("\nBatched raw model outputs (logits):")
print(f"{final_predictions} \n")

# --- Model trained with per-example SGD ---
final_predictions = jnp.squeeze(predict_all(sgd_params, X_train))
binary_predictions = jnp.round(final_predictions)
for inputs, pred, actual in zip(X_train, binary_predictions, Y_train):
    print(f"Input: {inputs}, Prediction: {pred}, Actual: {actual}")

# You can also see the raw (unrounded) output values
print("\nSGD raw model outputs (logits):")
print(final_predictions)

Input: [0. 0.], Prediction: 0.0, Actual: 0
Input: [0. 1.], Prediction: 1.0, Actual: 1
Input: [1. 0.], Prediction: 1.0, Actual: 1
Input: [1. 1.], Prediction: 0.0, Actual: 0

Batched raw model outputs (logits):
[8.3446503e-07 9.9999881e-01 9.9999851e-01 5.9604645e-07] 

Input: [0. 0.], Prediction: 0.0, Actual: 0
Input: [0. 1.], Prediction: 1.0, Actual: 1
Input: [1. 0.], Prediction: 1.0, Actual: 1
Input: [1. 1.], Prediction: 0.0, Actual: 0

SGD raw model outputs (logits):
[7.3909760e-06 9.9999356e-01 9.9999011e-01 6.0796738e-06]
