# Linear Regression

In [1]:
# Install dm-haiku and optax directly from GitHub, pinned to the v0.0.4 and v0.0.9 tags.
%pip install \
    git+https://github.com/deepmind/dm-haiku@v0.0.4 \
    git+https://github.com/deepmind/optax@v0.0.9

Collecting git+https://github.com/deepmind/dm-haiku@v0.0.4
  Cloning https://github.com/deepmind/dm-haiku (to revision v0.0.4) to /tmp/pip-req-build-x4u3u_dq
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-x4u3u_dq
  Running command git checkout -q 005187ca7825c25aedcfd73d828214aee6eebab2
Collecting git+https://github.com/deepmind/optax@v0.0.9
  Cloning https://github.com/deepmind/optax (to revision v0.0.9) to /tmp/pip-req-build-ir4i313t
  Running command git clone -q https://github.com/deepmind/optax /tmp/pip-req-build-ir4i313t
  Running command git checkout -q 989b755ca7cf0b42f30612edfb6e90ba53cef7e1
Collecting chex>=0.0.4
  Downloading chex-0.0.8-py3-none-any.whl (57 kB)
     |████████████████████████████████| 57 kB 2.5 MB/s 
Building wheels for collected packages: dm-haiku, optax
  Building wheel for dm-haiku (setup.py) ... done
  Created wheel for dm-haiku: filename=dm_haiku-0.0.4-py3-none-any.whl size=545777 sha256=247b534580e

In [2]:
import haiku as hk

from jax import jit, partial, vmap, grad, value_and_grad
import jax.lax as lax
import jax.numpy as np
from jax import random

import optax

In [3]:
@hk.without_apply_rng
@hk.transform
def model(x):
    """Apply a single `hk.Linear` layer with one output unit to `x`.

    After the two decorators, `model` exposes `init(rng, x)` to create the
    parameters and `apply(params, x)` (no RNG argument) to evaluate the layer.
    """
    return hk.Linear(output_size=1)(x)

In [4]:
# Root PRNG key; every consumer below draws its own subkey via random.split.
rng = random.PRNGKey(42)

# Shape of a single input example.
x_shape = (2,)

# Parameters of the "ground truth" model used to synthesise the targets.
rng, init_key = random.split(rng)
generating_model_state = model.init(init_key, np.zeros(x_shape))

# 1024 standard-normal input examples, each of shape x_shape.
rng, data_key = random.split(rng)
x = random.normal(data_key, (1024, *x_shape))

# Targets: the generating model evaluated on every row of x
# (parameters are shared across the batch, x is mapped over axis 0).
y = vmap(model.apply, in_axes=(None, 0))(generating_model_state, x)

@jit
def loss_fn(model_state):
    """Return the mean `optax.l2_loss` of the model on the dataset `(x, y)`.

    Args:
        model_state: parameters accepted by `model.apply`.

    Returns:
        Scalar mean of the per-example L2 losses between the model's
        predictions on the module-level `x` and the module-level targets `y`.
    """
    # Share the parameters across the batch; map only over rows of x.
    batched_apply = vmap(model.apply, in_axes=(None, 0))
    per_example_loss = optax.l2_loss(batched_apply(model_state, x), y)
    return np.mean(per_example_loss)



In [5]:
# Number of optimiser updates performed by one call to `train`.
steps = 10
# Learning rate handed to Adam.
start_learning_rate = 0.1
optimizer = optax.adam(learning_rate=start_learning_rate)
@jit
def train(model_state, optimizer_state):
    """Run `steps` optimiser updates and return the resulting states.

    Uses the module-level `steps`, `optimizer` and `loss_fn`.

    Args:
        model_state: current model parameters.
        optimizer_state: optimiser state matching `model_state`.

    Returns:
        A `(model_state, optimizer_state)` tuple after `steps` updates.
    """
    def train_step(_, carry):
        # One gradient step; the loop index is not needed.
        params, opt_state = carry
        gradients = grad(loss_fn)(params)
        updates, opt_state = optimizer.update(gradients, opt_state)
        return optax.apply_updates(params, updates), opt_state

    return lax.fori_loop(0, steps, train_step, (model_state, optimizer_state))

In [6]:
# Fresh parameters for the model being fitted, drawn from a new subkey,
# plus the optimiser state that goes with them.
rng, init_key = random.split(rng)
inferred_model_state = model.init(init_key, np.zeros_like(x[0]))
optimizer_state = optimizer.init(inferred_model_state)

In [7]:
# Ten rounds of `train`, printing the dataset loss after each round.
for _round in range(10):
    inferred_model_state, optimizer_state = train(
        inferred_model_state, optimizer_state
    )
    round_loss = loss_fn(inferred_model_state)
    print(round_loss)

0.23128243
0.0014280316
0.039447226
0.009026997
0.0010691732
0.0021785079
1.8527819e-05
0.00027094266
1.5056953e-05
3.0225212e-05
