# Tutorial #4: Multilayer Perceptrons

Multilayer perceptrons (MLPs), also known as feedforward neural networks, apply a sequence of affine (linear plus bias) transformations, each followed by a non-linear activation function. At each layer of an MLP, the hidden units (neurons) $\boldsymbol{h}$ are given by $$\boldsymbol{h} = g(\boldsymbol{W} \boldsymbol{x} + \boldsymbol{b})$$ where $\boldsymbol{x}$ is the output of the previous layer (or the network input, for the first layer), $\boldsymbol{W}$ and $\boldsymbol{b}$ are the layer's weight matrix and bias vector, and $g$ is a non-linear activation function applied elementwise. MLPs can be used for regression, where the output layer is linear, or for classification, where a softmax is applied to the output layer to produce class probabilities.
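To make the layer equation concrete, here is a minimal sketch of a single layer and of the two kinds of output layer in plain JAX. The helper `dense_layer` and the sizes (3 inputs, 4 hidden units, 2 outputs) are arbitrary choices for illustration. The sketch follows the $\boldsymbol{W}\boldsymbol{x}$ convention of the equation, whereas Flax's `nnx.Linear` (used below) stores its kernel with shape `(in_features, out_features)` and computes `x @ kernel + bias`.

```python
import jax
import jax.numpy as jnp

def dense_layer(W, b, x, g=jax.nn.relu):
    """One MLP layer, h = g(W x + b), with W of shape (d_out, d_in)."""
    return g(W @ x + b)

key = jax.random.PRNGKey(0)
k_x, k_w, k_out = jax.random.split(key, 3)

x = jax.random.normal(k_x, (3,))      # input vector (d_in = 3)
W = jax.random.normal(k_w, (4, 3))    # weight matrix (d_out = 4)
b = jnp.zeros(4)                      # bias vector
h = dense_layer(W, b, x)              # hidden units, shape (4,), all >= 0 under ReLU

# Output layer: the same affine map, followed either by nothing or by a softmax
W_out = jax.random.normal(k_out, (2, 4))
b_out = jnp.zeros(2)
y_regression = W_out @ h + b_out               # linear output: unconstrained real values
p_classes = jax.nn.softmax(W_out @ h + b_out)  # softmax output: probabilities summing to 1
print(h.shape, y_regression.shape, p_classes.sum())
```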

A good resource on MLPs is Chapter 6 of the [Deep Learning Book](https://www.deeplearningbook.org/).

We'll start by replacing the linear regression model on the synthetic dataset from tutorial #1 with an MLP. We'll then train an MLP to perform classification on the MNIST dataset of handwritten digits. 

In [1]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
import jax.random as random

### 3.1: MLP for 1D regression on a synthetic dataset

We'll begin by recreating the data used in tutorial #1.

In [None]:
def ground_truth(x):
    return 3*x - 0.2*x**2 - 0.05 * x**3

def generate_data(key, N_data, L):
    key1, key2 = random.split(key)
    x = random.uniform(key1,(N_data,)) * L
    y = ground_truth(x) + random.normal(key2, (N_data,))
    return x, y

In [None]:
# plot ground truth and data

L = 5 # domain is from 0 to 5
N_data = 20
N_plot = 100
x_plot = jnp.linspace(0,L,N_plot)

key = random.PRNGKey(0)
x_data, y_data = generate_data(key, N_data=N_data, L=L)

plt.plot(x_plot, ground_truth(x_plot), label='ground truth')
plt.scatter(x_data, y_data, color='red', marker='x', label='data')
plt.legend()
plt.show()

We'll create an MLP with one input variable $x$, three hidden layers with five hidden units each, ReLU activation functions, and one output variable $y$.
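As a quick check on the size of this model, the number of trainable parameters follows from the layer widths $1 \to 5 \to 5 \to 5 \to 1$: a fully connected layer from $d_\text{in}$ to $d_\text{out}$ units has $d_\text{in} d_\text{out}$ weights and $d_\text{out}$ biases. The helper `count_mlp_params` is a hypothetical name used only for this sketch.

```python
def count_mlp_params(layer_sizes):
    """Count weights plus biases in a fully connected MLP.

    layer_sizes lists every layer width, input and output included,
    e.g. [1, 5, 5, 5, 1] for the network described above.
    """
    total = 0
    for d_in, d_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += d_in * d_out + d_out  # weight matrix plus bias vector
    return total

print(count_mlp_params([1, 5, 5, 5, 1]))  # 10 + 30 + 30 + 6 = 76
```

The same function gives 22 parameters for the smallest network in section 3.1.1 (`[1, 3, 3, 1]`) and 388 for the largest (`[1, 9, 9, 9, 9, 9, 1]`).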

In [None]:
from flax import nnx
import flax
import optax

In [None]:
class scalarMLP(nnx.Module):
    """MLP mapping an input of shape (1,) to a scalar, with ReLU hidden layers."""
    def __init__(self, dhiddens: list[int], rngs: nnx.Rngs):
        # dhiddens[j] is the number of hidden units in hidden layer j
        self.linear_in = nnx.Linear(1, dhiddens[0], rngs=rngs)
        self.layers = []
        for j in range(len(dhiddens)-1):
            self.layers.append(nnx.Linear(dhiddens[j], dhiddens[j+1], rngs=rngs))
        self.linear_out = nnx.Linear(dhiddens[-1], 1, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.linear_in(x))
        for layer in self.layers:
            x = nnx.relu(layer(x))
        return self.linear_out(x)[0]  # drop the size-1 output axis to return a scalar

In [None]:
rngs = nnx.Rngs(0)
model = scalarMLP([5, 5, 5], rngs)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
nnx.display(model)

In [None]:
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = nnx.vmap(model)(x)
        return jnp.mean((y - y_pred)**2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)

    return loss

In [None]:
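# The model maps an input of shape (1,) to a scalar, so inputs need shape (N_data, 1)
X = x_data.reshape(N_data, 1)
Y = y_data.reshape(N_data)  # targets: one scalar per data point, shape (N_data,)

# Sanity check that the shapes are compatible with train_step.
# Note that this call also applies one optimizer update to the model.
try:
    loss = train_step(model, optimizer, X, Y)
    print(loss)
except Exception as e:
    print("shape of X or Y incorrect:", e)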

In [None]:
N_train = 20000
losses = []
for _ in range(N_train):
    loss = train_step(model, optimizer, X, Y)
    losses.append(loss)

In [None]:
plt.plot(losses)
plt.show()

In [None]:
plt.plot(x_plot, ground_truth(x_plot), color='blue', label='ground truth')
plt.scatter(x_data, y_data, color='red', marker='x', label='data')
plt.plot(x_plot, nnx.vmap(model)(x_plot.reshape(N_plot,1)), color='green', label='Trained MLP')
plt.legend()
plt.show()

#### 3.1.1: Varying the complexity of MLPs

We saw that our MLP with three hidden layers of five hidden units each gives a piecewise linear function. What happens as we change the number of hidden units and the number of layers?
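The piecewise linear shape is expected: ReLU is piecewise linear, and sums and compositions of piecewise linear functions are again piecewise linear, so an MLP with ReLU activations and a scalar input is piecewise linear in $x$. A rough numerical way to see this is to evaluate a function on a fine grid and count where its finite-difference slope jumps. The sketch below does this for a hand-built ReLU function with known kinks at $x=1$ and $x=3$; `count_linear_pieces` and `relu_example` are hypothetical helpers written for this illustration, not part of JAX or Flax, and the threshold `tol` is an assumption.

```python
import jax
import jax.numpy as jnp

def count_linear_pieces(f, xs, tol=1e-2):
    """Estimate how many linear pieces a 1D function f has on the sorted grid xs.

    Slopes are finite differences between neighbouring grid points; a kink shows
    up as a jump in slope larger than tol. A kink strictly inside a grid cell can
    flag two neighbouring points, so each run of flagged points counts as one kink.
    """
    ys = f(xs)
    slopes = jnp.diff(ys) / jnp.diff(xs)
    jumps = jnp.abs(jnp.diff(slopes)) > tol  # True where the slope changes
    # Number of runs of consecutive True values = number of kinks
    n_kinks = int(jumps[0]) + int(jnp.sum(jumps[1:] & ~jumps[:-1]))
    return n_kinks + 1

# A hand-built ReLU function with kinks at x = 1 and x = 3 (slopes 0, 1, -1)
def relu_example(x):
    return jax.nn.relu(x - 1.0) - 2.0 * jax.nn.relu(x - 3.0)

xs = jnp.linspace(0.0, 5.0, 501)
print(count_linear_pieces(relu_example, xs))  # 3
```

To estimate the count for a trained network from this notebook, one could wrap the model so that it accepts a 1D grid, e.g. `count_linear_pieces(lambda xs: nnx.vmap(model)(xs.reshape(-1, 1)), jnp.linspace(0, L, 2001))`, and repeat this for each entry of `models` in the next cell. The result is only an estimate: two kinks inside neighbouring grid cells merge into one, and kinks whose slope change is below `tol` are missed.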

In [None]:
N_train = 20000
rngs = nnx.Rngs(0)
hdims_list = [[3,3],[5,5,5],[7,7,7,7],[9,9,9,9,9]]  # hidden-layer widths for each model
models = []
optimizers = []

for hdims in hdims_list:
    model = scalarMLP(hdims, rngs)
    models.append(model)
    optimizers.append(nnx.Optimizer(model, optax.adam(1e-3)))

# Train each model in turn on the same data, reporting progress
for hdims, model, optimizer in zip(hdims_list, models, optimizers):
    print(f"training MLP with hidden layers {hdims}")
    for _ in range(N_train):
        loss = train_step(model, optimizer, X, Y)

In [None]:
labels = []
for hdims in hdims_list:
    labels.append("MLP: {} units, {} layers".format(hdims[0], len(hdims)))

plt.plot(x_plot, ground_truth(x_plot), color='blue', label='ground truth')
plt.scatter(x_data, y_data, color='red', marker='x', label='data')
for j, model in enumerate(models):
    plt.plot(x_plot, nnx.vmap(model)(x_plot.reshape(N_plot,1)), label=labels[j])
plt.legend()
plt.show()

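As we can see, the larger networks (more hidden layers and more hidden units per layer) produce more complex functions than the smaller ones. Because this comparison varies depth and width together, it does not show which of the two matters more. The MLP with two hidden layers of 3 units each gives a fit that is essentially a single straight line over the plotted range, so it behaves like a linear model.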

### 3.2: MLP for MNIST

First we need to load the MNIST dataset. We'll use `tfds` (TensorFlow Datasets) and prepare the dataset for `num_epochs` training epochs with a batch size of 32.