# Straight Lines & JAX

First, let's make some _messy_ straight-line data and see if we can recover the original line's parameters.

We'll use JAX. JAX is a toolkit from Google & DeepMind that is widely used for NN research. Unlike `pytorch` and others, it isn't a framework; it is more like a library plus an ecosystem of libraries built on top of it. This means we can see "inside" it. At a very basic level, you can think of it as `numpy`, but differentiable.

In [None]:
import jax # Access to the library
import jax.numpy as jnp # Easy access to numpy like functions in jax
import matplotlib.pyplot as plt

Let's create some data that lies on a straight line, but with some random jitter added.

In [None]:
# JAX has no global random state: we pass explicit keys, and split a key to get new independent ones
rng = jax.random.PRNGKey(0)
rng, new_key = jax.random.split(rng)

slope = 3
intercept = 2

# Straight line with jitter
n_items = 100
x = jax.random.normal(rng, (n_items,))
jitter = jax.random.normal(new_key, (n_items,))
y = slope * x + intercept + 0.5 * jitter

We have a slope of **3** and an intercept of **2**.
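JAX has no hidden global random state: every random call takes an explicit key, and the same key always gives the same numbers. A minimal sketch of that behaviour, using only `jax.random.PRNGKey`, `jax.random.split` and `jax.random.normal` (the names `demo_key`, `key_1`, `key_2`, `a`, `b`, `c` are illustrative):

```python
import jax

demo_key = jax.random.PRNGKey(0)
key_1, key_2 = jax.random.split(demo_key)  # two independent keys

a = jax.random.normal(key_1, (3,))
b = jax.random.normal(key_1, (3,))  # same key -> identical numbers (usually a bug!)
c = jax.random.normal(key_2, (3,))  # different key -> different numbers

print(a)
print(b)
print(c)
```

This is why the cell above splits `rng` before drawing `x` and `jitter`: drawing both from one key would reuse the same random stream.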

To get comfortable with the data, let's look at the first few points.

In [None]:
[f"({f_x:0.2f}, {f_y:0.2f})" for f_x, f_y in list(zip(x,y))[0:5]]

Of course - when we have this many points, our brain is not built to understand a sequence of numbers. Our eyes, however, are excellent big-data sensors!

In [None]:
plt.scatter(x, y)
# plt.plot(x, 3 * x + 2, color="red")
plt.xlabel("x")
plt.ylabel("y")
plt.show()

## "Exact" solution

We can code up the standard closed-form least-squares solution for the line $y = \beta_0 + \beta_1 x$:

In [None]:
n = len(x)
beta_1 = (n * jnp.sum(x * y) - jnp.sum(x) * jnp.sum(y)) / (
    n * jnp.sum(x**2) - jnp.sum(x) ** 2
)
beta_0 = (jnp.sum(y) - beta_1 * jnp.sum(x)) / n
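For reference, these are the standard closed-form formulas the cell above implements. They come from minimizing the sum of squared residuals $r$ and setting both partial derivatives to zero:

$$
r(\beta_0, \beta_1) = \sum_{i=1}^{n} \left(y_i - \beta_0 - \beta_1 x_i\right)^2,
\qquad
\frac{\partial r}{\partial \beta_0} = 0,
\quad
\frac{\partial r}{\partial \beta_1} = 0
$$

which gives the two "normal equations"

$$
\sum_i y_i = n \beta_0 + \beta_1 \sum_i x_i,
\qquad
\sum_i x_i y_i = \beta_0 \sum_i x_i + \beta_1 \sum_i x_i^2 .
$$

Solving the first for $\beta_0$ and substituting into the second:

$$
\beta_1 = \frac{n \sum_i x_i y_i - \sum_i x_i \sum_i y_i}{n \sum_i x_i^2 - \left(\sum_i x_i\right)^2},
\qquad
\beta_0 = \frac{\sum_i y_i - \beta_1 \sum_i x_i}{n} .
$$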

And the values of the fit:

In [None]:
print(f"beta_0: {beta_0:.2f}")
print(f"beta_1: {beta_1:.2f}")
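As a cross-check, the same fit can be phrased as a linear least-squares problem $A\beta \approx y$, where $A$ has a column of ones (for the intercept) and a column of $x$ values, and handed to `jnp.linalg.lstsq`. A minimal sketch that regenerates data with the same recipe as above (slope 3, intercept 2, jitter scale 0.5); all variable names are illustrative, chosen so they don't overwrite the notebook's own `x`, `y`, `beta_0`, `beta_1`:

```python
import jax
import jax.numpy as jnp

# Same recipe as above: slope 3, intercept 2, Gaussian jitter of scale 0.5
key_x, key_j = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (100,))
ys = 3 * xs + 2 + 0.5 * jax.random.normal(key_j, (100,))

# Closed-form formulas, as in the cell above
n_pts = len(xs)
b1_closed = (n_pts * jnp.sum(xs * ys) - jnp.sum(xs) * jnp.sum(ys)) / (
    n_pts * jnp.sum(xs**2) - jnp.sum(xs) ** 2
)
b0_closed = (jnp.sum(ys) - b1_closed * jnp.sum(xs)) / n_pts

# Matrix form: columns [1, x] so that design @ [b0, b1] approximates ys
design = jnp.stack([jnp.ones_like(xs), xs], axis=1)
coef, *_ = jnp.linalg.lstsq(design, ys)
b0_lstsq, b1_lstsq = coef

print(f"closed form: beta_0={b0_closed:.3f}, beta_1={b1_closed:.3f}")
print(f"lstsq:       beta_0={b0_lstsq:.3f}, beta_1={b1_lstsq:.3f}")
```

Both routes solve the same minimization, so they should agree up to floating-point round-off.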

Again - our eyes are a lot better here!

In [None]:
plt.scatter(x, y, label="Data", color="black")
plt.plot(x, 3 * x + 2, color="green", label="Real Line")
plt.plot(x, beta_1 * x + beta_0, color="red", label="Fitted Line")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
# Add beta_0 and beta_1 values as text on the plot
plt.text(
    0.05, 0.95,
    f"beta_0 = {beta_0:.2f} (exact: {intercept})\nbeta_1 = {beta_1:.2f} (exact: {slope})",
    transform=plt.gca().transAxes,
    fontsize=10,
    verticalalignment="top",
    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
)
plt.show()

## The Gradient

Let's try gradient descent to fit this straight line. First, let's look at how gradients are calculated in JAX.

In [None]:
def x2(beta):
    x, y = beta
    return x**2 + 2*y**2

In [None]:
print(f"x2(1, 1)= {x2((1,1))}")
print(f"x2(2, 2)= {x2((2,2))}")
print(f"x2(3, 3)= {x2((3,3))}")

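The `jax.grad` function takes any scalar-valued function built from JAX operations and gives back a new function that computes its gradient. That is all! A lot of magic happens behind the scenes:

* JAX *traces* the function, recording the sequence of primitive operations it performs (much like a *tape*).
* It then takes the derivative of every one of those operations.
* And combines them with the chain rule. A lot.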
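To peek at the recorded "tape", `jax.make_jaxpr` prints the traced program (a *jaxpr*) for a function and example inputs. A minimal sketch; `f_xy` is an illustrative copy of `x2` from above, renamed so it doesn't clash with anything in the notebook:

```python
import jax

def f_xy(beta):
    # Same function as `x2` above: f(x, y) = x**2 + 2 * y**2
    x, y = beta
    return x**2 + 2 * y**2

# The traced program: the sequence of primitive operations JAX recorded
jaxpr_f = jax.make_jaxpr(f_xy)((1.0, 1.0))
print(jaxpr_f)

# The gradient is itself a traced program, built from the derivative
# rules of those primitives combined with the chain rule
jaxpr_grad = jax.make_jaxpr(jax.grad(f_xy))((1.0, 1.0))
print(jaxpr_grad)
```

The first printout lists a handful of primitives (powers, a multiply, an add); the second is longer because it contains the derivative of each one.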

In [None]:
grad_of_x2 = jax.grad(x2)

print(f"grad_of_x2(1, 0)= {grad_of_x2((1.0,0.0))}")
print(f"grad_of_x2(2, 0)= {grad_of_x2((2.0,0.0))}")
print(f"grad_of_x2(3, 0)= {grad_of_x2((3.0,0.0))}")
print(f"grad_of_x2(0, 1)= {grad_of_x2((0.0,1.0))}")

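* These match the expected derivatives: $\partial_x (x^2 + 2y^2) = 2x$ and $\partial_y (x^2 + 2y^2) = 4y$, so the point `(0, 1)` gives `(0, 4)`.
* `jax.grad` differentiates with respect to the function's first argument. Here that argument is the tuple `beta`, so we get the derivative with respect to each of its elements, both $x$ and $y$!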
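A minimal sketch checking `jax.grad` against derivatives worked out by hand, plus the equivalent `argnums` form when $x$ and $y$ are separate arguments. `f_xy`, `f_xy_separate` and `grad_by_hand` are illustrative names. It also shows why the cell above passes `1.0` rather than `1`: `jax.grad` rejects integer inputs with a `TypeError`.

```python
import jax

def f_xy(beta):
    # Same function as `x2` above: f(x, y) = x**2 + 2 * y**2
    x, y = beta
    return x**2 + 2 * y**2

def grad_by_hand(beta):
    # Partial derivatives worked out on paper: (2x, 4y)
    x, y = beta
    return (2 * x, 4 * y)

# The gradient has the same structure (here a tuple) as the first argument
g_auto = jax.grad(f_xy)((3.0, 2.0))
g_hand = grad_by_hand((3.0, 2.0))
print(g_auto, g_hand)

# Equivalent form: separate arguments, and ask for both via argnums
def f_xy_separate(x, y):
    return x**2 + 2 * y**2

dx, dy = jax.grad(f_xy_separate, argnums=(0, 1))(3.0, 2.0)

# jax.grad only accepts floating-point (or complex) inputs
int_rejected = False
try:
    jax.grad(f_xy)((3, 2))
except TypeError:
    int_rejected = True
print("integer inputs rejected:", int_rejected)
```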

### Straight Line Fit

Let's now use this to solve the straight-line fit from earlier.

First, let's define our $f(x)$. We'll call it `network`, because it will become a NN in the end.

In [None]:
def network(beta, x):
    b0, b1 = beta
    return b0 + b1 * x

And our loss: the mean of the squared residuals between the prediction and the data, which we'll call `loss`.

In [None]:
def loss(params, x, y):
    '''Calculate the mean squared error loss.

    Args:
        params (tuple): A tuple containing the model parameters (b0, b1).
        x (array): The input features.
        y (array): The true labels (ground truth).

    Returns:
        float: The mean squared error loss.
    '''
    y_pred = network(params, x)
    return jnp.mean((y_pred - y) ** 2)

Above, we defined our *network* and *loss* functions.

* Note the loss is the same least-squares objective behind the exact solution above, just averaged instead of summed (which doesn't move the minimum).
* Note how easily we can read this code and see what it does, unlike the closed-form formulas above!

Now that we know what JAX is going to do, the next step is a bit anticlimactic.

But this is a much more sophisticated function than the previous simple `x2`!
* So we should still be impressed!

In [None]:
grad_of_loss = jax.grad(loss)
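For this loss the gradient can also be derived by hand: with residuals $r_i = \beta_0 + \beta_1 x_i - y_i$ and $L = \frac{1}{n}\sum_i r_i^2$, we get $\partial L/\partial \beta_0 = \frac{2}{n}\sum_i r_i$ and $\partial L/\partial \beta_1 = \frac{2}{n}\sum_i r_i x_i$. A minimal sketch comparing that against `jax.grad`, on a tiny made-up dataset; `line_model` and `line_mse` are illustrative copies of `network` and `loss`, renamed so the notebook's own data and functions aren't overwritten:

```python
import jax
import jax.numpy as jnp

def line_model(beta, x):
    # Same as `network` above: b0 + b1 * x
    b0, b1 = beta
    return b0 + b1 * x

def line_mse(beta, x, y):
    # Same as `loss` above: mean squared error
    return jnp.mean((line_model(beta, x) - y) ** 2)

# Tiny made-up dataset, purely for illustration
x_small = jnp.array([0.0, 1.0, 2.0, 3.0])
y_small = jnp.array([2.1, 4.9, 8.2, 10.8])
beta_start = jnp.array([1.0, 1.0])

auto = jax.grad(line_mse)(beta_start, x_small, y_small)

# Hand-derived gradient: dL/db0 = 2 * mean(r), dL/db1 = 2 * mean(r * x)
r = line_model(beta_start, x_small) - y_small
manual = jnp.array([2 * jnp.mean(r), 2 * jnp.mean(r * x_small)])
print(auto, manual)
```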

### Update the parameters

Each iteration we calculate the gradient, and adjust the parameters.

In [None]:
def one_epoch(beta, x, y, i_epoch):
    g = grad_of_loss(beta, x, y)
    beta -= 0.1 * g
    print(f"Step {i_epoch}, loss {loss(beta, x, y)}")
    return beta

Note the `0.1`:

* This is the _learning rate_: the size of each step along the negative gradient.
* Adjust it to converge more or less quickly.
* Too large, and the steps overshoot the minimum (and can even diverge); too small, and training crawls.
* There are more sophisticated optimizers that adapt the learning rate on the fly.
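The effect of the learning rate is easiest to see on the simplest possible loss, $f(w) = w^2$. Its gradient is $2w$, so one gradient-descent step maps $w \to w - \text{lr}\cdot 2w = (1 - 2\,\text{lr})\,w$. With `lr = 0.01` each step multiplies $w$ by 0.98 (slow), with `0.1` by 0.8 (steady), with `0.5` by 0 (this particular bowl is solved in one step), and with `1.1` by $-1.2$ (each step overshoots further and it diverges). A minimal sketch; `bowl` and `descend` are illustrative names:

```python
import jax

def bowl(w):
    return w**2  # minimum at w = 0

def descend(lr, steps=20, w=1.0):
    """Plain gradient descent on bowl(w), starting from w."""
    grad_bowl = jax.grad(bowl)
    for _ in range(steps):
        w = w - lr * grad_bowl(w)
    return float(w)

for lr in (0.01, 0.1, 0.5, 1.1):
    print(f"learning rate {lr}: w after 20 steps = {descend(lr):.4g}")
```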

### Training

Let's loop 10 times:

In [None]:
beta = jnp.array([1.0, 1.0])

for i_epoch in range(10):
    beta = one_epoch(beta, x, y, i_epoch)

In [None]:
print(f"Final parameters: {beta}")

Let's do a longer training run and track the parameters.

In [None]:
beta = jnp.array([1.0, 1.0])
beta_history = [beta]
for i in range(50):
    beta = one_epoch(beta, x, y, i)
    beta_history.append(beta.copy())

### Results

In [None]:
nn_beta_0, nn_beta_1 = beta
print(f"beta_0 (b): {nn_beta_0:.2f} - least squares: {beta_0:.2f}")
print(f"beta_1 (m): {nn_beta_1:.2f} - least squares: {beta_1:.2f}")

In [None]:
# Extract the offset (beta_0) from beta_history
offset_history = [b[0] for b in beta_history]

# Plot the offset as a function of the epoch number
plt.plot(range(len(offset_history)), offset_history)
plt.xlabel("Epoch Number")
plt.ylabel(r"$\beta_0$ - Offset Coordinate")
plt.title("Offset Coordinate vs. Epoch Number")
plt.show()

# Extract the slope (beta_1) from beta_history
slope_history = [b[1] for b in beta_history]

# Plot the slope as a function of the epoch number
plt.plot(range(len(slope_history)), slope_history)
plt.xlabel("Epoch Number")
plt.ylabel(r"$\beta_1$ - Slope Coordinate")
plt.title("Slope Coordinate vs. Epoch Number")
plt.show()

## With a NN

Let's use a very simple fully connected NN, built with Flax and Optax on top of JAX, and fit it to the same data.

In [None]:
%pip install flax optax

In [None]:
import flax.linen as nn
import optax
from flax.training import train_state

Let's define a module that builds our fully connected layers:

In [None]:
class MLP(nn.Module):
    '''Generate a fully connected multi-layer perceptron.

       Pass in `features` as a list of integers, where each integer
       specifies the number of neurons in that layer. Hidden layers
       use a ReLU activation; the final layer is linear.
    '''
    features: list

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.Dense(feat)(x)
            x = nn.relu(x)
        x = nn.Dense(self.features[-1])(x)
        return x

mlp = MLP(features=[16, 16, 1])

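Next, initialize the network's parameters with random numbers. The inputs also need a slightly different shape: Flax's `Dense` layers treat the last axis as the feature axis, so each of the N samples must be a length-1 vector, giving an array of shape `[N, 1]` instead of `[N]`.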

In [None]:
# Prepare data for Flax (needs shape [N,1])
x_train = x.reshape(-1, 1)
y_train = y.reshape(-1, 1)

# Init the parameters
mlp_key = jax.random.PRNGKey(0)
params = mlp.init(mlp_key, x_train)

Before, we used a fixed learning rate of 0.1. Now we will use the Adam optimizer (`optax.adam`), which adapts the step size for each parameter as training progresses, using running averages of the gradients and of their squares.

In [None]:
optimizer = optax.adam(learning_rate=0.01)

# TrainState bundles the apply function, the parameters, and the optimizer state
# (for Adam, the running averages of the gradients).
state = train_state.TrainState.create(apply_fn=mlp.apply, params=params, tx=optimizer)

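Finally, the loss function, which looks just like it did before.

* `@jax.jit` means "Just In Time" compiling: the first call traces the function and compiles it with XLA, and later calls with the same input shapes and dtypes reuse the compiled version.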

In [None]:
@jax.jit
def loss_fn(params, x, y):
    preds = mlp.apply(params, x)
    return jnp.mean((preds - y) ** 2)
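One way to see the "trace once, reuse many times" behaviour is a Python side effect inside a jitted function: it runs only while JAX traces, not on later cached calls. A minimal sketch; `jitted_mse` and `trace_count` are illustrative names, and the counter is a demonstration trick rather than something to rely on in real code:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def jitted_mse(pred, y):
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX is tracing
    return jnp.mean((pred - y) ** 2)

jitted_mse(jnp.ones(5), jnp.zeros(5))      # first call: trace and compile
jitted_mse(2 * jnp.ones(5), jnp.zeros(5))  # same shapes/dtypes: reuse compiled code
jitted_mse(jnp.ones(7), jnp.zeros(7))      # new shape: trace and compile again
print(trace_count)  # 2
```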

And a single training step, which computes the loss and its gradient and then updates the parameters:

In [None]:
@jax.jit
def train_step(state, x, y):
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss
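`jax.value_and_grad` returns the loss and its gradient from one call, so the training step doesn't need to evaluate the loss twice. A minimal worked example on a one-parameter model $y \approx w x$ (`scale_loss`, `x_vg`, `y_vg` are illustrative names): at $w = 1$ with $x = (1, 2, 3)$ and $y = (2, 4, 6)$, the residuals are $(-1, -2, -3)$, so the loss is $(1 + 4 + 9)/3 = 14/3 \approx 4.667$ and the derivative is $\frac{2}{3}\sum_i r_i x_i = \frac{2}{3}(-1 - 4 - 9) = -28/3 \approx -9.333$.

```python
import jax
import jax.numpy as jnp

def scale_loss(w, x, y):
    # MSE of the one-parameter model y ≈ w * x
    return jnp.mean((w * x - y) ** 2)

x_vg = jnp.array([1.0, 2.0, 3.0])
y_vg = jnp.array([2.0, 4.0, 6.0])

# One call returns (loss, d loss / d w), differentiating w.r.t. the first argument
value, grad_w = jax.value_and_grad(scale_loss)(1.0, x_vg, y_vg)
print(value, grad_w)
```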

Which we can now run...

In [None]:
num_epochs = 100
for epoch in range(num_epochs):
    # Use `train_loss` so we don't overwrite the `loss` function defined earlier
    state, train_loss = train_step(state, x_train, y_train)
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        print(f"[Flax] Epoch {epoch}, Loss: {train_loss:.4f}")

And the results...

In [None]:
# Plot the fit
x_plot = jnp.linspace(jnp.min(x), jnp.max(x), 100).reshape(-1, 1)
y_pred = mlp.apply(state.params, x_plot).flatten()
plt.scatter(x, y, label="Data", color="black", alpha=0.5)
plt.plot(x_plot.flatten(), y_pred, color="purple", label="Flax NN Fit")
plt.plot(x_plot.flatten(), 3 * x_plot.flatten() + 2, color="green", label="Real Line")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.title("Flax NN Fit vs. True Line")
plt.show()

NOTE: We have lost the interpretation of what is going on: unlike `beta`, the network's parameters contain no slope (or intercept) we can simply read off!

In [None]:
# Plot the same fit without the data points, to see its shape
x_plot = jnp.linspace(jnp.min(x), jnp.max(x), 100).reshape(-1, 1)
y_pred = mlp.apply(state.params, x_plot).flatten()
plt.plot(x_plot.flatten(), y_pred, color="purple", label="Flax NN Fit")
plt.plot(x_plot.flatten(), 3 * x_plot.flatten() + 2, color="green", label="Real Line")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.title("Flax NN Fit vs. True Line")
plt.show()

And the fit isn't a straight line! Where is the PHYSICS?!

## Overfitting

Let's create a second dataset that is _independent_ of the first (a test set), and see how the loss on it changes as we train.

In [None]:
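# Fresh, independent keys for the test set: reusing keys that were already
# used or split above could correlate the test set with the training set
test_key_x, test_key_jitter = jax.random.split(jax.random.PRNGKey(1))

x_new = jax.random.normal(test_key_x, (1000,))
jitter_new = jax.random.normal(test_key_jitter, (1000,))
y_new = slope * x_new + intercept + 0.5 * jitter_new

x_test = x_new.reshape(-1, 1)
y_test = y_new.reshape(-1, 1)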

In [None]:
# Re-initialize the parameters (the same random initialization as before) and create a fresh training state
params2 = mlp.init(mlp_key, x_train)
state2 = train_state.TrainState.create(apply_fn=mlp.apply, params=params2, tx=optimizer)

# Run the training, tracking the training and test losses for plotting.
num_epochs = 500
training_loss_by_epoch = []
testing_loss_by_epoch = []
for epoch in range(num_epochs):
    state2, train_loss = train_step(state2, x_train, y_train)
    training_loss_by_epoch.append(train_loss)

    # Evaluate on the held-out test set (no parameter update)
    test_loss = loss_fn(state2.params, x_test, y_test)
    testing_loss_by_epoch.append(test_loss)

    if epoch % 10 == 0 or epoch == num_epochs - 1:
        print(f"[Flax] Epoch {epoch}, Loss: {train_loss:.4f}, test loss: {test_loss:.4f}")

Let's plot the training and test losses:

In [None]:
# Plot training and testing loss over epochs
plt.plot(range(num_epochs), training_loss_by_epoch, label="Training Loss")
plt.plot(range(num_epochs), testing_loss_by_epoch, label="Testing Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.yscale('log')
plt.title("Training and Testing Loss vs. Epoch")
plt.legend()
plt.show()
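To read the loss plot quantitatively, it helps to locate the epoch where the test loss is lowest and to measure the gap between test and training loss at the end of training. Keeping the parameters from the lowest-test-loss epoch is the idea behind *early stopping*. A minimal sketch; `summarize_losses` is a hypothetical helper, and the loss curves below are made up purely for illustration:

```python
import jax.numpy as jnp

def summarize_losses(train_losses, test_losses):
    """Hypothetical helper: epoch of minimum test loss, that loss, and final test-train gap."""
    train = jnp.asarray(train_losses)
    test = jnp.asarray(test_losses)
    best_epoch = int(jnp.argmin(test))
    final_gap = float(test[-1] - train[-1])
    return best_epoch, float(test[best_epoch]), final_gap

# Made-up curves: training loss keeps falling, while the test loss
# bottoms out at epoch 2 and then creeps back up (overfitting)
train_demo = [1.0, 0.5, 0.3, 0.2, 0.1]
test_demo = [1.1, 0.6, 0.45, 0.5, 0.6]
best_epoch, best_test, gap = summarize_losses(train_demo, test_demo)
print(best_epoch, best_test, gap)
```

In the notebook, `summarize_losses(training_loss_by_epoch, testing_loss_by_epoch)` would apply it to the run above.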

In [None]:
# Plot the new fit...
x_plot = jnp.linspace(jnp.min(x), jnp.max(x), 100).reshape(-1, 1)
y_pred_train = mlp.apply(state.params, x_plot).flatten()
y_pred_extra_fit = mlp.apply(state2.params, x_plot).flatten()
plt.plot(x_plot.flatten(), y_pred_train, color="purple", label="Flax NN Fit - Short Training")
plt.plot(x_plot.flatten(), y_pred_extra_fit, color="blue", label="Flax NN Fit - Long Training")
plt.plot(x_plot.flatten(), 3 * x_plot.flatten() + 2, color="green", label="Real Line")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.title("Flax NN Fit vs. True Line")
plt.show()