# Linear regression

In [None]:
# IPython magics: enable the autoreload extension in mode 2 (re-import
# modules before each cell execution, so edits to imported source files
# take effect without restarting the kernel) and render matplotlib
# figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
# Add the parent directory (relative to the current working directory) to
# the module search path -- presumably so the `utils` module imported in
# the next cell can be found; verify against the repository layout.
sys.path.append('..')

In [None]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax

from utils import (
    make_linear_data,
    MLP,
    AutoEncoder,
    mse
)

## Setup

In [None]:
# root PRNG key (fixed seed for reproducibility); the cells below split
# fresh subkeys off this key instead of reusing it directly
key = jax.random.key(seed=0)

In [None]:
# problem dimensions: W has shape (x_dim, y_dim) and b has shape (y_dim,)
x_dim, y_dim = 10, 2

# advance the root key and draw one subkey per ground-truth parameter
key, *subkeys = jax.random.split(key, num=3)
w_key, b_key = subkeys

# ground-truth weight matrix and bias vector, sampled from a standard normal
W = jax.random.normal(w_key, shape=(x_dim, y_dim))
b = jax.random.normal(b_key, shape=(y_dim,))

In [None]:
# sizes of the training and validation sets
num_train, num_val = 200, 30

# separate subkeys so the two sets are sampled independently
key, *subkeys = jax.random.split(key, num=3)
train_key, val_key = subkeys

# both sets are generated from the same ground-truth W and b
x_train, y_train = make_linear_data(train_key, num_train, W, b)
x_val, y_val = make_linear_data(val_key, num_val, W, b)

# report array shapes as a sanity check
print(f'X shape (train): {x_train.shape}')
print(f'y shape (train): {y_train.shape}')

print(f'\nX shape (val.): {x_val.shape}')
print(f'y shape (val.): {y_val.shape}')

In [None]:
# ground-truth parameters as a nested dict: a 'params' entry holding the
# 'kernel' (W) and 'bias' (b) arrays -- presumably chosen to match the
# layout of the parameters created by `model.init` further below
linear_layer = {'kernel': W, 'bias': b}
nested_dict = {'params': linear_layer}

# freeze into an immutable pytree
true_params = flax.core.freeze(nested_dict)

# show the shape of every leaf array
true_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, true_params)
print(true_shapes)

## Model

In [None]:
# linear model: a single dense layer with y_dim output features
# (`features` is passed positionally)
model = nn.Dense(y_dim)

In [None]:
# one subkey for parameter initialization, one for the example input
key, *subkeys = jax.random.split(key, num=3)
init_key, input_key = subkeys

# example input of shape (1, x_dim) handed to `model.init`
example_inputs = jax.random.normal(input_key, (1, x_dim))

# initialize the model parameters
params = model.init(init_key, example_inputs)

# show the shape of every leaf array
param_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
print(param_shapes)

## Training

In [None]:
@jax.jit
def mse_loss(params, x, y):
    """Evaluate the model on `x` and score the predictions against `y`.

    Args:
        params: Parameter pytree passed to `model.apply`.
        x: Model inputs.
        y: Targets, passed as the second argument to `mse`.

    Returns:
        The value of `mse(model.apply(params, x), y)` -- presumably the
        mean squared error as a scalar (confirm in `utils.mse`).
    """
    return mse(model.apply(params, x), y)

# function returning the pair (loss, gradients); `jax.value_and_grad`
# differentiates w.r.t. the first positional argument (`params`) by default
loss_and_grad = jax.value_and_grad(mse_loss)

In [None]:
# training hyperparameters
num_epochs = 200
learning_rate = 0.03

# compute initial loss
val_loss = mse_loss(params, x_val, y_val)
print('Before training, val. loss: {:.2e}'.format(val_loss))

# initialize optimizer
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# perform training epochs; each epoch is one gradient step on the full
# training set
for idx in range(num_epochs):

    # compute loss and gradients
    loss, grads = loss_and_grad(params, x_train, y_train)

    # update parameters and optimizer; the current params are passed along
    # so the loop keeps working if the optimizer is replaced by one whose
    # update rule reads the parameters (e.g. one applying weight decay)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    # print losses every 10 epochs and after the final epoch
    # NOTE: `loss` was evaluated with the parameters from before this
    # epoch's update, whereas `val_loss` uses the updated parameters
    if (idx + 1) % 10 == 0 or (idx + 1) == num_epochs:
        val_loss = mse_loss(params, x_val, y_val)
        print('Epoch {}, batch loss: {:.2e}, val. loss: {:.2e}'.format(idx + 1, loss, val_loss))