# Tutorial \#1: Linear Regression

In this tutorial, we'll demonstrate the most basic example of statistical learning: linear regression. We'll perform linear regression in three ways: (1) analytically, (2) using `scikit-learn`, and (3) using `flax`.

Linear regression assumes that the data is generated from the equation $$y = w x + b + \epsilon$$ where $w$ and $b$ are the parameters of the model and $\epsilon$ is noise with an expected value of zero. The goal is to use the observed data to find the parameters $w$ and $b$ that best describe future data.

Good explanatory resources on linear regression are the lecture notes from [COS324](https://www.cs.princeton.edu/courses/archive/spring19/cos324/) at Princeton:
* [Ordinary Least Squares Linear Regression](https://www.cs.princeton.edu/courses/archive/spring19/cos324/files/linear-regression.pdf)
* [Maximum Likelihood Linear Regression](https://www.cs.princeton.edu/courses/archive/spring19/cos324/files/mle-regression.pdf)
* [Least squares regression with non-linear features](https://www.cs.princeton.edu/courses/archive/spring19/cos324/files/basis-functions.pdf)

In [None]:
import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt

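We'll assume that our data is 1D and is generated by adding standard Gaussian noise to a non-linear ground-truth function. Because the ground truth is a cubic, no straight line can match it exactly; linear regression will find the best linear approximation.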

In [None]:
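def ground_truth(x):
    # Non-linear (cubic) ground-truth function
    return 3*x - 0.2*x**2 - 0.05*x**3

def generate_data(key, N_data, L):
    # Sample N_data inputs uniformly from [0, L] and add standard Gaussian noise to the targets
    key1, key2 = random.split(key)
    x = random.uniform(key1, (N_data,)) * L
    y = ground_truth(x) + random.normal(key2, (N_data,))
    return x, y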

We draw a sample from the above distribution and plot the data below.

In [None]:
# plot ground truth and data

L = 5 # domain is from 0 to 5
N_data = 20
x_plot = jnp.linspace(0,L,100)

key = random.PRNGKey(0)
x_data, y_data = generate_data(key, N_data=N_data, L=L)

plt.plot(x_plot, ground_truth(x_plot), label='ground truth')
plt.scatter(x_data, y_data, color='red', marker='x', label='data')
plt.legend()
plt.show()

### Method 1: Analytically calculate weights and bias

We can combine the parameters $w$ and $b$ into a single vector $\boldsymbol{w} = [w, b]^T$ and append a $1$ to each data point, so that the $N$ data points are described by the $N \times 2$ design matrix $\boldsymbol{X} = [\boldsymbol{x}, \boldsymbol{1}]$ and the model's predictions are $\boldsymbol{X}\boldsymbol{w}$. For the squared-error loss $$L = ||\boldsymbol{X}\boldsymbol{w}-\boldsymbol{y}||^2$$ the optimal value of $\boldsymbol{w}$ is $$\boldsymbol{w} = (\boldsymbol{X}^T\boldsymbol{X})^{-1}\boldsymbol{X}^T \boldsymbol{y}$$
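For completeness, the closed-form solution follows from setting the gradient of the loss to zero. This assumes $\boldsymbol{X}^T\boldsymbol{X}$ is invertible, which holds whenever the $x$ values are not all equal:

$$
\begin{aligned}
L(\boldsymbol{w}) &= (\boldsymbol{X}\boldsymbol{w}-\boldsymbol{y})^T(\boldsymbol{X}\boldsymbol{w}-\boldsymbol{y}) \\
\nabla_{\boldsymbol{w}} L &= 2\boldsymbol{X}^T(\boldsymbol{X}\boldsymbol{w}-\boldsymbol{y}) = \boldsymbol{0} \\
\Rightarrow \quad \boldsymbol{X}^T\boldsymbol{X}\,\boldsymbol{w} &= \boldsymbol{X}^T\boldsymbol{y} \\
\Rightarrow \quad \boldsymbol{w} &= (\boldsymbol{X}^T\boldsymbol{X})^{-1}\boldsymbol{X}^T\boldsymbol{y}
\end{aligned}
$$

Since $L$ is a convex quadratic in $\boldsymbol{w}$, this stationary point is the global minimum.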

In [None]:
X = jnp.concatenate([x_data[:,None], jnp.ones(N_data)[:,None]],axis=1)
print(X.shape)
print(y_data.shape)

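Under the additional assumption that the noise is Gaussian, this least-squares solution is also the maximum likelihood estimate (MLE) of $\boldsymbol{w}$. We can now compute it using the above equation.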
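To see why the least-squares solution is also the MLE, assume the noise is Gaussian, $\epsilon \sim \mathcal{N}(0, \sigma^2)$, and independent across the $N$ data points. The negative log-likelihood is then

$$-\log p(\boldsymbol{y} \mid \boldsymbol{X}, \boldsymbol{w}) = \frac{1}{2\sigma^2}||\boldsymbol{X}\boldsymbol{w}-\boldsymbol{y}||^2 + \frac{N}{2}\log(2\pi\sigma^2)$$

The second term does not depend on $\boldsymbol{w}$, so maximizing the likelihood is equivalent to minimizing the squared-error loss $L$.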

In [None]:
w_mle = jnp.linalg.inv(X.T @ X) @ X.T @ y_data
print(w_mle)
w, b = w_mle

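Although we computed the inverse directly here, it is generally better to solve the linear system $\boldsymbol{X}^T\boldsymbol{X}\boldsymbol{w} = \boldsymbol{X}^T\boldsymbol{y}$ with `jax.scipy.linalg.solve()` (or `jax.numpy.linalg.solve()`) than to form the inverse with `jax.numpy.linalg.inv()`. Solving is cheaper and numerically more stable.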
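As a sketch of that advice, the cell below solves the normal equations with `jax.scipy.linalg.solve()` and, as an alternative, solves the least-squares problem directly with `jnp.linalg.lstsq()`. The toy data (an exact line with $w = 2$, $b = 1$) is an illustrative assumption, not the tutorial's data.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Toy data on an exact line y = 2x + 1 (illustrative assumption, no noise)
x_toy = jnp.array([0.0, 1.0, 2.0, 3.0])
y_toy = 2.0 * x_toy + 1.0

# Design matrix [x, 1], same layout as X above
X_toy = jnp.stack([x_toy, jnp.ones_like(x_toy)], axis=1)

# Solve (X^T X) w = X^T y without explicitly forming the inverse
w_solve = jax.scipy.linalg.solve(X_toy.T @ X_toy, X_toy.T @ y_toy)

# Or minimize ||X w - y||^2 directly; lstsq also returns residuals, rank and singular values
w_lstsq, _, _, _ = jnp.linalg.lstsq(X_toy, y_toy)

print(w_solve)  # approximately [2, 1]
print(w_lstsq)  # approximately [2, 1]
```

Forming $\boldsymbol{X}^T\boldsymbol{X}$ squares the condition number of $\boldsymbol{X}$, so for badly conditioned data `lstsq`, which works on $\boldsymbol{X}$ directly, is usually the more robust choice.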

We can now plot the learned linear model $y = wx + b$ against the ground truth and the data.

In [None]:
plt.plot(x_plot, ground_truth(x_plot), color='blue', label='ground truth')
plt.scatter(x_data, y_data, color='red', marker='x', label='data')
plt.plot(x_plot, w * x_plot + b, color='green', label='Learned linear model')
plt.legend()
plt.show()

### Method 2: Linear Regression with scikit-learn

`scikit-learn` makes it extremely easy to fit linear regression models. 

In [None]:
from sklearn.linear_model import LinearRegression

We'll first try fitting the data using 1D arrays.

In [None]:
try:
    reg = LinearRegression().fit(x_data, y_data)
except ValueError as e:
    # scikit-learn rejects 1D feature arrays
    print("ValueError:", e)

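`scikit-learn` expects the input data $X$ to be a 2D array of shape `(n_samples, n_features)`, here `(N_data, 1)`. Let's add a feature axis and fit again. Unlike in Method 1, we don't need to append a column of ones: `LinearRegression` fits the intercept $b$ by default (`fit_intercept=True`).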

In [None]:
X = x_data[:, None]
reg = LinearRegression().fit(X, y_data)

In [None]:
w_skl, b_skl = reg.coef_, reg.intercept_
print(w_skl, b_skl)

We get the same values of $w$ and $b$ as before, up to floating-point precision.
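To make the comparison concrete, the sketch below fits `scikit-learn`'s `LinearRegression` in two ways: with its built-in intercept, and on a Method 1 style design matrix (column of ones appended, `fit_intercept=False`). It checks both against the normal-equation solution. The synthetic data and seed are illustrative assumptions.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Illustrative synthetic data (assumption): noisy samples around y = 3x - 1
rng = np.random.default_rng(0)
x_toy = rng.uniform(0.0, 5.0, size=20)
y_toy = 3.0 * x_toy - 1.0 + rng.normal(size=20)

# Option A: scikit-learn fits the intercept itself (default fit_intercept=True)
reg_a = LinearRegression().fit(x_toy[:, None], y_toy)

# Option B: append a column of ones as in Method 1 and turn off the built-in intercept
X_toy = np.stack([x_toy, np.ones_like(x_toy)], axis=1)
reg_b = LinearRegression(fit_intercept=False).fit(X_toy, y_toy)

# Option C: solve the normal equations directly, as in Method 1
w_normal = np.linalg.solve(X_toy.T @ X_toy, X_toy.T @ y_toy)

print(reg_a.coef_[0], reg_a.intercept_)  # w, b
print(reg_b.coef_)                       # [w, b]
print(w_normal)                          # [w, b]
```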

In [None]:
plt.plot(x_plot, ground_truth(x_plot), color='blue', label='ground truth')
plt.scatter(x_data, y_data, color='red', marker='x', label='data')
plt.plot(x_plot, w_skl[0] * x_plot + b_skl, color='green', label='Learned linear model, scikit-learn')
plt.legend()
plt.show()

### Method 3: Minimizing the MSE loss with gradient descent in `flax`

The mean squared error (MSE) is the loss $L$ above divided by $N$, so it has the same minimizer. Since it is a convex quadratic function of the parameters with a closed-form minimizer, we don't need gradient descent to find its minimum. However, we can also minimize it with gradient descent, which, with a suitable learning rate and enough steps, converges to the same parameters as the analytical solution.
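Before turning to `flax`, here is a minimal plain-JAX sketch of that claim: full-batch gradient descent on the MSE, using `jax.grad`, converges to the closed-form solution. The toy data, learning rate, and step count are illustrative assumptions chosen so that this small problem converges.

```python
import jax
import jax.numpy as jnp

# Illustrative toy data (assumption): noisy samples around y = 3x - 1 on [0, 5]
key_x, key_eps = jax.random.split(jax.random.PRNGKey(0))
x_toy = jax.random.uniform(key_x, (20,)) * 5.0
y_toy = 3.0 * x_toy - 1.0 + jax.random.normal(key_eps, (20,))
X_toy = jnp.stack([x_toy, jnp.ones_like(x_toy)], axis=1)  # design matrix [x, 1]

def mse(w):
    # Mean squared error of the linear model X_toy @ w, where w = [w, b]
    return jnp.mean((X_toy @ w - y_toy) ** 2)

grad_mse = jax.jit(jax.grad(mse))

w_gd = jnp.zeros(2)
learning_rate = 0.02
for _ in range(5000):
    w_gd = w_gd - learning_rate * grad_mse(w_gd)  # one full-batch gradient descent step

w_closed = jnp.linalg.solve(X_toy.T @ X_toy, X_toy.T @ y_toy)  # normal equations
print(w_gd, w_closed)
```

The step size matters: for this quadratic loss, gradient descent converges only if the learning rate is below $2/\lambda_{\max}$, where $\lambda_{\max}$ is the largest eigenvalue of the Hessian $\frac{2}{N}\boldsymbol{X}^T\boldsymbol{X}$.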

We'll use `flax` to implement and train a linear model. We will use the same code structure to implement and train more complicated neural network models in later tutorials.

In [None]:
from flax import nnx
import optax

#### 3.1: Understanding `Rngs` in `flax`

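Pseudorandom number generation (PRNG) works a little differently in `flax` than in plain JAX. Instead of passing PRNG keys around explicitly, we create an `nnx.Rngs` object with `rngs = nnx.Rngs(seed)`, where `seed` is an `int`, and draw a fresh key from a named stream whenever we need one, e.g. `rngs.params()`.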

The code below shows how `nnx.Rngs` works.

In [None]:
rngs = nnx.Rngs(0)  # seed 0 for the default stream, which any undeclared stream falls back to
print("The first time rngs is called, it gives a key")
print(rngs.params())
print("The key automatically changes each time rngs is called")
print(rngs.dropout())
print("Any stream name can be requested")
print(rngs.random_stream_whatever_I_want())

# Different streams can have different seeds
rngs = nnx.Rngs(0, params=1)  # seed 0 for the default stream, seed 1 for the params stream
print("The params stream has seed 1, so it gives a different key")
print(rngs.params())  # params stream
print("Other streams fall back to the default stream (seed 0), which starts counting again,")
print("so they produce the same sequence of keys as the first block")
print(rngs.dropout())  # default stream
print(rngs.random_stream_whatever_I_want())  # default stream
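For contrast, in plain JAX a PRNG key is an explicit value. Calling a sampler twice with the same key gives the same numbers, so new keys must be derived by hand, for example with `jax.random.split`. Roughly speaking, `nnx.Rngs` does this bookkeeping for us. The names below (`params_key`, `dropout_key`) are illustrative only.

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing a key reproduces exactly the same sample
sample_a = jax.random.normal(key, (3,))
sample_b = jax.random.normal(key, (3,))

# Splitting produces fresh keys, e.g. one for parameters and one for dropout
params_key, dropout_key = jax.random.split(key)
sample_c = jax.random.normal(params_key, (3,))
sample_d = jax.random.normal(dropout_key, (3,))

print(bool((sample_a == sample_b).all()))  # True: same key, same numbers
print(bool((sample_c == sample_d).all()))  # False: different keys, different numbers
```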

#### 3.2: Create linear model in `flax`

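We'll now create a subclass of `nnx.Module` to represent our linear regression model. We need to implement two methods: `__init__`, which creates the parameters (wrapped in `nnx.Param` so that `flax` treats them as trainable), and `__call__`, which computes the prediction $\boldsymbol{x} \cdot \boldsymbol{w} + b$. Note that this class shadows the `scikit-learn` `LinearRegression` imported in Method 2.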

In [None]:
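class LinearRegression(nnx.Module):
    def __init__(self, din: int, rngs: nnx.Rngs):
        key = rngs.params()  # draw a key from the params stream
        self.w = nnx.Param(random.normal(key, (din,)))  # weights, shape (din,)
        self.b = nnx.Param(jnp.zeros(()))  # scalar bias, initialized to zero

    def __call__(self, x: jax.Array):
        # x has shape (din,) for a single example or (N, din) for a batch
        return x @ self.w + self.b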

In [None]:
model = LinearRegression(1, rngs = nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-3))
y = model(jnp.asarray([1.0]))
print(y)
nnx.display(model)
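The forward pass `x @ self.w + self.b` works both for a single example of shape `(din,)`, giving a scalar, and for a batch of shape `(N, din)`, giving a vector of shape `(N,)`. The sketch below checks this with plain `jnp` arrays standing in for the module's parameters (an assumption made so the example does not depend on `flax`), and shows that `jax.vmap` over single examples gives the same result as the batched call.

```python
import jax
import jax.numpy as jnp

din = 1
w_demo = jax.random.normal(jax.random.PRNGKey(0), (din,))  # stands in for self.w
b_demo = jnp.zeros(())                                      # stands in for self.b

def forward(x):
    # Same computation as LinearRegression.__call__
    return x @ w_demo + b_demo

x_single = jnp.array([1.0])                    # shape (din,)
x_batch = jnp.linspace(0.0, 5.0, 4)[:, None]   # shape (N, din) with N = 4

print(forward(x_single).shape)           # ()
print(forward(x_batch).shape)            # (4,)
print(jax.vmap(forward)(x_batch).shape)  # (4,)
```

So for this particular model, `nnx.vmap(model)(x)` in the training step could be replaced by `model(x)`; the `vmap` form is useful for models whose `__call__` is written for a single example.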

#### 3.3: Write loss function and train model

Our training step relies on the fact that `nnx.Module` objects are mutable: the parameters are stored in the module instance and updated in place. Calling `optimizer.update(grads)` applies one optimizer step (here SGD) to the model's parameters, so `train_step` does not need to return the updated model.

In [None]:
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model: LinearRegression):
        y_pred = nnx.vmap(model)(x)
        return jnp.mean((y - y_pred)**2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)

    return loss

In [None]:
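X = x_data.reshape(N_data, 1)  # inputs, shape (N_data, din) with din = 1
Y = y_data.reshape(N_data)     # targets, shape (N_data,)

# Predictions and targets must have the same shape, otherwise the MSE broadcasts incorrectly
print(nnx.vmap(model)(X).shape)
print(Y.shape)

# Take a single training step to check that train_step runs
loss = train_step(model, optimizer, X, Y)
print(loss)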
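Checking these shapes matters because the MSE in `loss_fn` broadcasts silently. If the targets had shape `(N, 1)` while the predictions have shape `(N,)`, then `y - y_pred` would have shape `(N, N)` and the loss would be wrong without any error being raised. A small sketch with hypothetical arrays:

```python
import jax.numpy as jnp

y_pred = jnp.array([1.0, 2.0, 3.0])  # predictions, shape (3,)
y_good = jnp.array([1.0, 2.0, 3.0])  # targets, shape (3,)
y_bad = y_good[:, None]              # the same targets, but shape (3, 1)

mse_good = jnp.mean((y_good - y_pred) ** 2)
mse_bad = jnp.mean((y_bad - y_pred) ** 2)  # (3, 1) - (3,) broadcasts to (3, 3)

print((y_bad - y_pred).shape)  # (3, 3)
print(mse_good)                # 0.0
print(mse_bad)                 # 4/3, although the targets equal the predictions
```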

In [None]:
N_train = 10000
losses = []
for _ in range(N_train):
    loss = train_step(model, optimizer, X, Y)
    losses.append(loss)

In [None]:
plt.plot(losses)
plt.xlabel('training step')
plt.ylabel('MSE loss')
plt.show()

In [None]:
print(model.w, model.b)

The values of $w$ and $b$ found by gradient descent are close to the exact least-squares values computed in Methods 1 and 2; the remaining gap shrinks as we train for more steps.

In [None]:
plt.plot(x_plot, ground_truth(x_plot), color='blue', label='ground truth')
plt.scatter(x_data, y_data, color='red', marker='x', label='data')
plt.plot(x_plot, model.w[0] * x_plot + model.b, color='green', label='Learned linear model, gradient descent')
plt.legend()
plt.show()