# Linear regression
Adapted from the [full linear-regression example](https://flax.readthedocs.io/en/v0.6.10/guides/jax_for_the_impatient.html#full-example-linear-regression) in the Flax guide *JAX for the impatient*.

In [None]:
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

# Enable 64-bit (double-precision) JAX arrays for this notebook, presumably so the
# gradient-descent and least-squares estimates below can be compared at higher precision.
jax.config.update("jax_enable_x64", True)

In [None]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Split the root key ONCE into four independent subkeys, one per random draw.
# JAX keys must never be reused: a key that has been consumed by a sampler
# (e.g. random.normal) must not also be split, since the resulting streams are
# not guaranteed to be independent.
key = random.PRNGKey(0)
k1, k2, key_sample, key_noise = random.split(key, 4)

# Generate random ground truth W with shape (x_dim, y_dim) and b with shape (y_dim,).
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Generate samples with additional Gaussian noise (scaled by 0.1).
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = x_samples @ W + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
print("x shape:", x_samples.shape, "; y shape:", y_samples.shape)

## Without pytrees

In [None]:
# Loss function: Mean squared error.
@jax.jit
def mse(W, b, x_batched, y_batched):
    """Half squared error summed over outputs, then averaged over the batch.

    Args:
        W: weight matrix, applied as ``x_batched @ W``.
        b: bias vector added to every prediction row.
        x_batched: inputs, one sample per row (N rows).
        y_batched: targets with the same shape as the predictions (N x P).

    Returns:
        Scalar mean over samples of ``0.5 * ||y_hat_i - y_i||^2``.
    """
    residual = (x_batched @ W + b) - y_batched  # (N, P)
    per_sample = 0.5 * (residual**2).sum(axis=1)  # (N, P) -> (N,)
    return per_sample.mean()  # (N,) -> ()


# Ensure we jit the largest-possible jittable block.
@jax.jit
def update_params(W, b, x, y, lr):
    """Take one gradient-descent step on ``mse`` with respect to W and b.

    Returns:
        ``(W_new, b_new, loss)`` where ``loss`` is ``mse`` evaluated at the
        incoming (pre-update) W and b.
    """
    # value_and_grad hands back the loss alongside the gradients for free.
    loss, (dW, db) = jax.value_and_grad(mse, argnums=(0, 1))(W, b, x, y)
    return W - lr * dW, b - lr * db, loss

In [None]:
# Gradient descent starts from all-zero estimates of W and b.
W_hat, b_hat = jnp.zeros_like(W), jnp.zeros_like(b)

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(W, b, x_samples, y_samples))
for i in range(101):
    W_hat, b_hat, loss = update_params(
        W_hat, b_hat, x_samples, y_samples, learning_rate
    )
    if i % 5:
        continue
    # Fetching `loss` for printing syncs the device with the host (Python
    # process), so only do it every 5th step.
    print(f"Loss step {i}: ", loss)

# Compare the ground truth with the gradient-descent estimates.
print(f"\nW:{W}\nW_hat:{W_hat}\n\nb:{b}\nb_hat:{b_hat}")
print("Diff:", jnp.linalg.norm(W - W_hat))

In [None]:
## SOLUTION
# Reference solution from the linked Flax guide, kept verbatim inside a string
# literal so it is NOT executed. Running it would redefine mse/update_params and
# regenerate W, b and the samples used by the other cells.
"""
# Linear feed-forward.
def predict(W, b, x):
  return jnp.dot(x, W) + b

# Loss function: Mean squared error.
def mse(W, b, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    y_pred = predict(W, b, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  # We vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = predict(W, b, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

# Initialize estimated W and b with zeros.
W_hat = jnp.zeros_like(W)
b_hat = jnp.zeros_like(b)

# Ensure we jit the largest-possible jittable block.
@jax.jit
def update_params(W, b, x, y, lr):
  W, b = W - lr * jax.grad(mse, 0)(W, b, x, y), b - lr * jax.grad(mse, 1)(W, b, x, y)
  return W, b

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(W, b, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  W_hat, b_hat = update_params(W_hat, b_hat, x_samples, y_samples, learning_rate)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", mse(W_hat, b_hat, x_samples, y_samples))
"""
print("Original solution")

## With pytrees

In [None]:
# Loss function: Mean squared error.
@jax.jit
def mse(params, x_batched, y_batched):
    """Half squared error summed over outputs, then averaged over the batch.

    ``params`` is a dict pytree with keys ``"W"`` (weights, applied as
    ``x_batched @ params["W"]``) and ``"b"`` (bias added to every row).

    Returns:
        Scalar mean over samples of ``0.5 * ||y_hat_i - y_i||^2``.
    """
    W, b = params["W"], params["b"]
    residual = x_batched @ W + b - y_batched  # (N, P)
    return (0.5 * (residual**2).sum(axis=1)).mean()  # (N, P) -> (N,) -> ()


# Ensure we jit the largest-possible jittable block.
@jax.jit
def update_params(params, x, y, lr):
    """Take one gradient-descent step over every leaf of the ``params`` pytree.

    Returns:
        ``(new_params, loss)``: params with each leaf moved by ``-lr * grad``,
        and the ``mse`` loss evaluated at the incoming (pre-update) params.
    """
    # value_and_grad hands back the loss alongside the gradients for free.
    loss, grads = jax.value_and_grad(mse)(params, x, y)

    def sgd_step(leaf, leaf_grad):
        return leaf - lr * leaf_grad

    return jax.tree_util.tree_map(sgd_step, params, grads), loss

In [None]:
# Start gradient descent from all-zero parameters, stored as a dict pytree.
params = {"W": jnp.zeros_like(W), "b": jnp.zeros_like(b)}

learning_rate = 0.3  # Gradient step size.
# Evaluate the reference loss at the ground-truth W and b (as the label says),
# not at the zero-initialized `params`.
true_params = {"W": W, "b": b}
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
for i in range(101):
    # need to pipe the params in and out of the update because JAX works
    # with stateless computations https://docs.jax.dev/en/latest/stateful-computations.html
    params, loss = update_params(params, x_samples, y_samples, learning_rate)
    if i % 5 == 0:
        # printing the loss introduces overhead of the device syncing with the host (Python process)
        print(f"Loss step {i}: ", loss)
# Report the estimates trained in THIS cell (`params`), not the W_hat/b_hat
# globals left over from the non-pytree run.
print(f"\nW:{W}\nW_hat:{params['W']}\n\nb:{b}\nb_hat:{params['b']}")
print("Diff:", jnp.linalg.norm(W - params["W"]))

## Least squares

In [None]:
# Closed-form least squares: append a column of ones so the bias b is solved
# for jointly with W in a single linear system.
x_samples_ones_stacked = jnp.column_stack([x_samples, jnp.ones(x_samples.shape[0])])
assert x_samples_ones_stacked.shape == (x_samples.shape[0], x_samples.shape[1] + 1)
# lstsq returns (solution, residuals, rank, singular_values); keep the solution.
# Its last row corresponds to the ones column, i.e. the bias estimate.
W_hat_b_hat = jnp.linalg.lstsq(x_samples_ones_stacked, y_samples)[0]
W_hat = W_hat_b_hat[:-1]  # drop the bias row, leaving the weight rows
print("Diff:", jnp.linalg.norm(W - W_hat))