Create a dataset:

In [1]:
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [2]:
# Seed both the data generator and the split so that "Restart & Run All"
# reproduces the same dataset (and therefore the same loss curve) every time.
X_full, y_full = make_regression(n_features=3, random_state=0)

# Keep the full dataset under its own name instead of overwriting it;
# `X`/`y` are the training split used by the cells below,
# `X_test`/`y_test` the held-out split.
X, X_test, y, y_test = train_test_split(X_full, y_full, random_state=0)
X.shape, y.shape

((75, 3), (75,))

Define the forward pass of a linear regression model using Haiku:

In [3]:
import jax
import jax.numpy as jnp
import haiku as hk

def forward(X):
    """Forward pass of a linear regression model.

    Applies a single Haiku linear layer with one output unit to `X` and
    flattens the result to a 1-D array of predictions.
    """
    linear = hk.Linear(1)
    predictions = linear(X)
    return predictions.ravel()


# Rebind `forward` to the transformed object exposing `.init` and `.apply`.
forward = hk.transform(forward)

Initialize the model parameters:

In [4]:
# Fixed PRNG key; it is passed to `forward.init` here and reused by later cells.
rng = jax.random.PRNGKey(13)

# Initialise the parameters from the key and the training inputs.
params = forward.init(rng, X)
params

  lax._check_user_dtype_supported(dtype, "zeros")


FlatMapping({
  'linear': FlatMapping({
              'w': DeviceArray([[0.3705009 ],
                                [0.2911511 ],
                                [0.56166327]], dtype=float32),
              'b': DeviceArray([0.], dtype=float32),
            }),
})

Define a function that performs the forward pass on a given input:

In [5]:
# Short alias for the transformed model's apply function:
# f(params, rng, inputs) -> predictions.
f = forward.apply

# Sanity check: one prediction per training row.
train_predictions = f(params, rng, X)
train_predictions.shape

(75,)

Define the loss function and its gradient, then run the training loop:

In [6]:
def loss_fn(params, X, y_true):
    """Mean squared error of the model's predictions on `X` against `y_true`.

    Predictions come from the notebook-level `f` (the transformed model's
    apply function), called with the notebook-level `rng` key.
    """
    residuals = f(params, rng, X) - y_true
    squared_residuals = jnp.square(residuals)
    return squared_residuals.mean()


# Gradient of the loss with respect to its first argument, `params`.
grad_fn = jax.grad(loss_fn)

In [7]:
lr = 0.1  # gradient-descent step size
n_steps = 50  # number of parameter updates

for i in range(n_steps):
    # Report the loss on the held-out test set before each update.
    loss = loss_fn(params, X_test, y_test)
    print(f"Loss at iteration {i}: {loss:.4f}")

    # Compute the gradient of the training loss w.r.t. the parameters and
    # take one gradient-descent step on every leaf of the parameter tree.
    grads = grad_fn(params, X, y)
    # `jax.tree_multimap` was deprecated and later removed from JAX;
    # `jax.tree_util.tree_map` accepts multiple trees and replaces it.
    params = jax.tree_util.tree_map(
        lambda p, g: p - lr * g,
        params,
        grads,
    )

Loss at iteration 0: 15677.8184
Loss at iteration 1: 9501.0879
Loss at iteration 2: 5833.2993
Loss at iteration 3: 3624.5974
Loss at iteration 4: 2276.9775
Loss at iteration 5: 1444.7137
Loss at iteration 6: 924.9911
Loss at iteration 7: 597.1456
Loss at iteration 8: 388.4287
Loss at iteration 9: 254.4359
Loss at iteration 10: 167.7528
Loss at iteration 11: 111.2772
Loss at iteration 12: 74.2397
Loss at iteration 13: 49.8000
Loss at iteration 14: 33.5794
Loss at iteration 15: 22.7540
Loss at iteration 16: 15.4914
Loss at iteration 17: 10.5944
Loss at iteration 18: 7.2765
Loss at iteration 19: 5.0180
Loss at iteration 20: 3.4739
Loss at iteration 21: 2.4138
Loss at iteration 22: 1.6829
Loss at iteration 23: 1.1771
Loss at iteration 24: 0.8258
Loss at iteration 25: 0.5810
Loss at iteration 26: 0.4098
Loss at iteration 27: 0.2898
Loss at iteration 28: 0.2054
Loss at iteration 29: 0.1458
Loss at iteration 30: 0.1038
Loss at iteration 31: 0.0739
Loss at iteration 32: 0.0528
Loss at iteratio