Create a synthetic regression dataset with three features and split it into training and test sets:

In [1]:
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [2]:
# Fix the random seeds so the generated data, the train/test split and every
# downstream output (gradients, loss curve) are reproducible on Restart & Run All.
# Note: after the split, `X` and `y` hold the *training* portion only.
X, y = make_regression(n_features=3, random_state=0)
X, X_test, y, y_test = train_test_split(X, y, random_state=0)
X.shape, y.shape

((75, 3), (75,))

Define the parameters of the linear regression model, i.e. the weights and the bias:

In [3]:
import jax
import jax.numpy as jnp

# Model parameters, all starting at zero: a weight vector with one entry per
# input feature (X.shape[1:] is (n_features,)) and a scalar bias.
params = dict(
    w=jnp.zeros(X.shape[1:]),
    b=0.0,
)

params



{'w': DeviceArray([0., 0., 0.], dtype=float32), 'b': 0.0}

Define the forward pass, which computes the predictions $\hat{y} = Xw + b$:

In [4]:
def forward(params, X):
    """Predict targets with the linear model ``X · w + b``.

    params: dict holding the weight vector under 'w' and the bias under 'b'.
    X: array of samples; its last axis must match the length of params['w']
       (presumably 2-D, samples x features, as used in this notebook).
    Returns one prediction per sample (row) of X.
    """
    weights, bias = params['w'], params['b']
    linear_term = jnp.dot(X, weights)
    return linear_term + bias

# Sanity check: one prediction per training sample. The sample count is taken
# from X rather than hard-coded, so the check survives changes to the dataset
# size or the train/test split ratio.
assert forward(params, X).shape == (X.shape[0],)

Define the loss function, the mean squared error (MSE):

In [5]:
# jax.jit traces and compiles the function, so repeated calls run faster.
@jax.jit
def loss_fn(params, X, y_true):
    """Mean squared error between the model's predictions on ``X`` and ``y_true``.

    Returns a scalar array: the average of the squared residuals.
    """
    residuals = forward(params, X) - y_true
    return (residuals ** 2).mean()

Define a function to compute the gradient of the loss function with respect to the model parameters:

In [6]:
# Differentiate loss_fn with respect to its first positional argument
# (argnums=0, which is also jax.grad's default), i.e. `params`. The returned
# gradients mirror the structure of `params`: a dict with 'w' and 'b' entries.
grad_fn = jax.grad(loss_fn, argnums=0)

In [7]:
# Gradient of the loss on the test set at the current parameters (all zeros
# when run straight after the cells above); displayed as the cell output.
test_grads = grad_fn(params, X_test, y_test)
test_grads

{'b': DeviceArray(-27.83489, dtype=float32),
 'w': DeviceArray([ -74.10313, -186.97112, -158.06946], dtype=float32)}

Define the training loop (plain gradient descent on the training set), printing the loss on the test set at each iteration:

In [8]:
# Gradient-descent hyperparameters.
lr = 0.1        # learning rate (step size)
n_iters = 50    # number of gradient-descent steps

# NOTE: this cell rebinds the global `params`, so re-running it continues
# training from the current values; re-run the parameter-definition cell
# first to start again from zeros.
for i in range(n_iters):
    # Report performance on the held-out test set (before this step's update).
    loss = loss_fn(params, X_test, y_test)
    print(f"Loss at iteration {i}: {loss:.4f}")

    # Gradients are computed on the training set only, so the test set is
    # never used to fit the parameters.
    grads = grad_fn(params, X, y)

    # One gradient-descent step applied to every leaf of the params pytree,
    # building a new dict instead of mutating the old one in place.
    # (`jax.tree_multimap` was deprecated and removed from JAX;
    # `jax.tree_util.tree_map` accepts several trees and replaces it.)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# The loop reports the loss *before* each update, so evaluate once more to
# show the effect of the final step.
print(f"Final test loss: {loss_fn(params, X_test, y_test):.4f}")

Loss at iteration 0: 13366.9443
Loss at iteration 1: 7959.3882
Loss at iteration 2: 4775.9360
Loss at iteration 3: 2886.5881
Loss at iteration 4: 1756.5554
Loss at iteration 5: 1075.6941
Loss at iteration 6: 662.6248
Loss at iteration 7: 410.4019
Loss at iteration 8: 255.4695
Loss at iteration 9: 159.7708
Loss at iteration 10: 100.3560
Loss at iteration 11: 63.2931
Loss at iteration 12: 40.0712
Loss at iteration 13: 25.4617
Loss at iteration 14: 16.2350
Loss at iteration 15: 10.3866
Loss at iteration 16: 6.6668
Loss at iteration 17: 4.2928
Loss at iteration 18: 2.7730
Loss at iteration 19: 1.7969
Loss at iteration 20: 1.1680
Loss at iteration 21: 0.7617
Loss at iteration 22: 0.4982
Loss at iteration 23: 0.3270
Loss at iteration 24: 0.2153
Loss at iteration 25: 0.1422
Loss at iteration 26: 0.0942
Loss at iteration 27: 0.0626
Loss at iteration 28: 0.0418
Loss at iteration 29: 0.0280
Loss at iteration 30: 0.0188
Loss at iteration 31: 0.0126
Loss at iteration 32: 0.0085
Loss at iteration 3