# Linear Regression with Flax

In [1]:
# Optional one-time environment setup: uncomment the lines below to install
# the packages this notebook imports.
# !pip install --upgrade -q pip jax jaxlib
# !pip install --upgrade -q git+https://github.com/google/flax.git
# !pip install pandas
# !pip install plotnine
# !pip install optax  # imported by the optimizer cell further down

Collecting pandas
  Using cached pandas-2.1.4-cp39-cp39-macosx_10_9_x86_64.whl.metadata (18 kB)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2023.3.post1-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.1 (from pandas)
  Using cached tzdata-2023.4-py2.py3-none-any.whl.metadata (1.4 kB)
Using cached pandas-2.1.4-cp39-cp39-macosx_10_9_x86_64.whl (11.8 MB)
Using cached pytz-2023.3.post1-py2.py3-none-any.whl (502 kB)
Using cached tzdata-2023.4-py2.py3-none-any.whl (346 kB)
Installing collected packages: pytz, tzdata, pandas
Successfully installed pandas-2.1.4 pytz-2023.3.post1 tzdata-2023.4
Collecting plotnine
  Using cached plotnine-0.12.4-py3-none-any.whl.metadata (8.9 kB)
Collecting matplotlib>=3.6.0 (from plotnine)
  Using cached matplotlib-3.8.2-cp39-cp39-macosx_10_12_x86_64.whl.metadata (5.8 kB)
Collecting mizani<0.10.0,>0.9.0 (from plotnine)
  Using cached mizani-0.9.3-py3-none-any.whl.metadata (4.6 kB)
Collecting patsy>=0.5.1 (from plotnine)
  Using cached

In [5]:
import numpy as np
import pandas as pd
from plotnine import *

In [6]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

In [4]:
# Build a single dense layer: a linear transformation applied over the last
# dimension of the input, producing 5 output features.
model = nn.Dense(features = 5)
# Split one root key into two keys: key1 draws the dummy input below,
# key2 seeds the parameter initialization.
key1, key2 = random.split(random.key(0))
# Dummy input vector with 10 entries
x = random.normal(key1, (10, ))
# model.init runs the module on the example input and returns the newly
# created variables as a pytree; the length of x fixes the kernel's input
# size (10).
params = model.init(key2, x)
# tree_map applies a function to every leaf of a pytree and returns a new
# pytree with the same structure. Here each parameter array is replaced by
# its shape. The lambda's `x` is a local name, unrelated to the input `x`.
jax.tree_util.tree_map(lambda x: x.shape, params)

{'params': {'bias': (5,), 'kernel': (10, 5)}}

In [5]:
# model.apply runs the module's forward pass with the variables passed in
# explicitly and returns the layer output (5 values for this input).
model.apply(params, x)

Array([-1.3721193 ,  0.61131495,  0.6442836 ,  2.2192965 , -1.1271116 ],      dtype=float32)

In [6]:
# Problem size: 20 samples, inputs of dimension 10, targets of dimension 5
n_samples = 20
x_dim = 10
y_dim = 5

# Draw the "true" kernel W and bias b that will generate the data
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim, ))
# Stored with the same {'params': {'bias', 'kernel'}} layout that model.init
# produced, so it can be handed to model.apply / mse like the learned params.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# NOTE(review): k1 was already consumed to draw W above and is split again
# here. JAX guidance is to use each key only once; splitting `key` into more
# sub-keys up front would avoid the reuse, but doing so changes the generated
# data and every recorded output below, so the code is left unchanged.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
# Targets: y = x W + b plus normal noise scaled by 0.1
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
f'x shape: {x_samples.shape}; y shape: {y_samples.shape}'

'x shape: (20, 10); y shape: (20, 5)'

In [7]:
@jax.jit
def mse(params, x_batched, y_batched):
    """Mean over the batch of the per-sample half squared error.

    Relies on the notebook-level ``model`` for the forward pass.

    Args:
        params: variables pytree handed to ``model.apply``.
        x_batched: inputs stacked along axis 0, one sample per row.
        y_batched: targets stacked along axis 0, aligned with ``x_batched``.

    Returns:
        The mean over axis 0 of ``||y_i - model(x_i)||^2 / 2``.
    """
    def half_sq_err(x, y):
        # Loss for one (x, y) pair; vmap below lifts it over the batch axis.
        residual = y - model.apply(params, x)
        return jnp.inner(residual, residual) / 2.0

    per_sample = jax.vmap(half_sq_err)(x_batched, y_batched)
    return jnp.mean(per_sample, axis = 0)

In [8]:
# Step size shared by the hand-written gradient-descent loop and the optax
# optimizer further down
learning_rate = 0.3
# Baseline: the loss at the data-generating parameters is small but non-zero
# because y_samples contains additive noise.
print(f'Loss for "true" W, b: {mse(true_params, x_samples, y_samples)}')
# value_and_grad wraps mse into a function that returns (loss, gradient).
# By default the gradient is taken with respect to the first argument
# (params) and comes back as a pytree with the same structure.
loss_grad_fn = jax.value_and_grad(mse)

Loss for "true" W, b: 0.02363979071378708


In [9]:
@jax.jit
def update_params(params, learning_rate, grads):
    """Take one gradient-descent step on every leaf of ``params``.

    Args:
        params: pytree of parameter arrays.
        learning_rate: scalar step size.
        grads: pytree with the same structure as ``params`` holding the
            gradient of the loss with respect to each leaf.

    Returns:
        A new pytree in which each leaf is ``p - learning_rate * g``.
    """
    # The gradient points in the direction in which the loss increases
    # fastest, so to minimize the mse we step against it: where g is positive
    # the parameter is decreased, where g is negative it is increased.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

# Plain gradient descent on the current `params` (on a fresh top-to-bottom
# run these are the values returned by model.init).
# NOTE: the loop rebinds the global `params`, so re-running this cell resumes
# from the already-updated values instead of starting over.
for i in range(101):
    # Perform one gradient update
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    # Report every 10th step; loss_val is the loss evaluated before this
    # step's update was applied.
    if i % 10 == 0:
        print(f'Loss step {i}: {loss_val}')

Loss step 0: 35.343875885009766
Loss step 10: 0.5143468976020813
Loss step 20: 0.11384159326553345
Loss step 30: 0.0393267385661602
Loss step 40: 0.01991621032357216
Loss step 50: 0.014209136366844177
Loss step 60: 0.012425652705132961
Loss step 70: 0.011850389651954174
Loss step 80: 0.011661785654723644
Loss step 90: 0.011599408462643623
Loss step 100: 0.011578695848584175


In [10]:
import optax
# optax.adam returns a gradient transformation exposing `init` and `update`
tx = optax.adam(learning_rate = learning_rate)
# Build the optimizer state from the current params
opt_state = tx.init(params)
# Same construction as in the earlier cell, repeated here
loss_grad_fn = jax.value_and_grad(mse)

# NOTE: `params` is not re-initialized in this cell, so on a top-to-bottom run
# Adam starts from the result of the gradient-descent loop above. This is
# consistent with step 0 already reporting the converged loss; the rise at
# step 10 is presumably caused by the large step size (not verified).
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    # Transform the raw gradients into updates and advance the optimizer state
    updates, opt_state = tx.update(grads, opt_state)
    # Combine the updates with the current params to get the new params
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: {loss_val}')

Loss step 0: 0.011577627621591091
Loss step 10: 0.2614315450191498
Loss step 20: 0.0767502710223198
Loss step 30: 0.03644055128097534
Loss step 40: 0.022012805566191673
Loss step 50: 0.016178598627448082
Loss step 60: 0.01300280075520277
Loss step 70: 0.01202614326030016
Loss step 80: 0.01176451425999403
Loss step 90: 0.011646044440567493
Loss step 100: 0.011585528962314129


In [11]:
# Inspect the current parameter pytree (bias and kernel) after training
params

{'params': {'bias': Array([-1.4555768 , -2.0277991 ,  2.0790975 ,  1.2186145 , -0.99809754],      dtype=float32),
  'kernel': Array([[ 1.0098814 ,  0.18934374,  0.04454996, -0.9280221 ,  0.3478402 ],
         [ 1.7298453 ,  0.9879368 ,  1.1640464 ,  1.1006076 , -0.10653935],
         [-1.2029463 ,  0.28635228,  1.4155979 ,  0.11870951, -1.3141483 ],
         [-1.1941489 , -0.18958491,  0.03413862,  1.3169426 ,  0.0806038 ],
         [ 0.1385241 ,  1.3713038 , -1.3187183 ,  0.53152674, -2.2404997 ],
         [ 0.56294024,  0.8122311 ,  0.3175201 ,  0.53455096,  0.9050039 ],
         [-0.37926027,  1.7410393 ,  1.0790287 , -0.5039833 ,  0.9283062 ],
         [ 0.9706492 , -1.3153403 ,  0.33681503,  0.8099344 , -1.2018458 ],
         [ 1.0194312 , -0.6202479 ,  1.0818833 , -1.838974  , -0.45805007],
         [-0.6436537 ,  0.45666698, -1.1329137 , -0.6853864 ,  0.16829035]],      dtype=float32)}}

So in this transformation we took a $20 \times 10$ matrix and multiplied it by a $10 \times 5$ matrix to get a $20 \times 5$ result. Then we took the difference between this prediction matrix and the actual values. Then we took the inner product of each row of errors with itself to get a squared error per observation, vectorized with `vmap` to work across the full batch dimension of X. This returns a 1-D array of shape `(20,)`. Finally we took the mean over the observations (dimension 0) to get the MSE. My question is: why can't we use matrix multiplication? Our errors are stored by row; if we transpose the error matrix we get them by column, and a $5 \times 20$ times a $20 \times 5$ gives us a $5 \times 5$. At first glance the diagonal elements are what we want and the off-diagonal elements aren't. But no: that diagonal sums over observations for each output dimension, whereas we need the errors by observation, which is $20 \times 5$ times $5 \times 20$. The diagonal of that $20 \times 20$ matrix holds the squared error of each observation, but the product also creates a ton of off-diagonal elements we don't need.

In [12]:
# Loss at the current params, computed with the vmap-based mse
mse(params, x_samples, y_samples)

Array(0.01159014, dtype=float32)

In [13]:
# Naive batched attempt. On two (20, 5) arrays jnp.inner contracts the last
# axis for every pair of rows, giving a (20, 20) matrix rather than 20
# per-row values. The mean therefore also averages the off-diagonal cross
# terms and does not reproduce the mse value.
jnp.mean(jnp.inner(y_samples - model.apply(params, x_samples), y_samples - model.apply(params, x_samples)) / 2.0)

Array(2.2043007e-06, dtype=float32)

In [14]:
# Forward pass on a single sample (the first row of x_samples)
model.apply(params, x_samples[0])

Array([-0.48579866, -4.074929  , -2.0214841 , -0.20061362, -0.59897006],      dtype=float32)

In [15]:
# Forward pass on the whole batch, then take row 0. The layer acts on the
# last axis, so this agrees with the single-sample call up to float32
# rounding in the last digits.
model.apply(params, x_samples)[0]

Array([-0.48579878, -4.0749288 , -2.0214846 , -0.20061398, -0.59897006],      dtype=float32)

In [16]:
# Residual matrix, one row per sample
errors = y_samples - model.apply(params, x_samples)
# The diagonal of errors @ errors.T holds each row's inner product with
# itself; divided by 2 this is the per-sample half squared error. The
# off-diagonal entries are computed and then discarded.
jnp.diag(jnp.dot(errors, errors.T) / 2)

Array([0.02080444, 0.00846512, 0.01537622, 0.0055256 , 0.0069695 ,
       0.00957341, 0.01018183, 0.009361  , 0.01557972, 0.01286128,
       0.00875765, 0.00492677, 0.00321356, 0.02069209, 0.02168619,
       0.00257889, 0.00533569, 0.0045466 , 0.01921797, 0.02614925],      dtype=float32)

In [17]:
# The scalar loss again, for comparison: it is the mean of the 20 per-sample
# values on the diagonal
mse(params, x_samples, y_samples)

Array(0.01159014, dtype=float32)

In [18]:
@jax.jit
def mse2(params, x_batched, y_batched):
    """Per-sample half squared errors, without averaging over the batch.

    Relies on the notebook-level ``model`` for the forward pass.

    Args:
        params: variables pytree handed to ``model.apply``.
        x_batched: inputs stacked along axis 0, one sample per row.
        y_batched: targets stacked along axis 0, aligned with ``x_batched``.

    Returns:
        One value ``||y_i - model(x_i)||^2 / 2`` per sample along axis 0.
    """
    def half_sq_err(x, y):
        # Loss for one (x, y) pair
        residual = y - model.apply(params, x)
        return jnp.inner(residual, residual) / 2.0

    # vmap maps the single-pair loss over axis 0 of both inputs; no mean is
    # taken afterwards, so the per-sample values are returned as they are.
    return jax.vmap(half_sq_err)(x_batched, y_batched)

# One loss value per sample; compare with the diagonal computed earlier
mse2(params, x_samples, y_samples)

Array([0.02080444, 0.00846512, 0.01537622, 0.0055256 , 0.0069695 ,
       0.00957341, 0.01018183, 0.009361  , 0.01557972, 0.01286128,
       0.00875765, 0.00492677, 0.00321356, 0.02069209, 0.02168619,
       0.00257889, 0.00533569, 0.0045466 , 0.01921797, 0.02614925],      dtype=float32)

Ok so things to do to take this to our problem:

1. Get our X and Y dimensions
    [20, 60] => [20, 90]
2. Build up our internal layers
3. Compute our loss function 
  + squared loss: inner product on the errors
  + negative log likelihood: apply the function to each observation, take the mean

 

### Notes

`optax.GradientTransformation` aka `tx = optax.adam(...)`
* Defined as a pair of pure functions (`init` & `update`)
* Each time a gradient transformation is applied, a new state is computed and returned
* `init`: call it with an example of the model params whose gradients will be transformed; it returns a corresponding pytree containing the initial value of the optimizer state
* `update`: call it with the new gradients and the old state; it returns the transformed updates and the new state

In [8]:
# Build an initializer with scale 0.1 and call it directly with a PRNG key,
# a shape and a dtype to see what it produces. The sampled values in the
# output all lie between 0 and 0.1.
inits = nn.initializers.uniform(.1)
inits(jax.random.PRNGKey(42), (2, 3), jnp.float32)

Array([[0.07298189, 0.08691938, 0.08723002],
       [0.02081857, 0.01866242, 0.05502256]], dtype=float32)