In [1]:
from jax import grad
from jax import numpy as jnp

In [2]:
# Three noise-free samples from the line y = 2x + 1.
x = jnp.array([1.0, 2.0, 3.0])
y = 1.0 + 2.0 * x

# The first N points are used for training; whatever remains is held out as the test set.
N = 2
x_train, x_test = x[:N], x[N:]
y_train, y_test = y[:N], y[N:]
print(x_train, y_train, x_test, y_test)

[1. 2.] [3. 5.] [3.] [7.]


In [3]:
def model(w, b, x):
    """Affine model: return the prediction ``w * x + b`` (elementwise when ``x`` is an array)."""
    return b + w * x

def loss(w, b, x, y):
    """Mean squared error between ``model(w, b, x)`` and the targets ``y``.

    For scalar ``x``/``y`` this is just the squared error of that single point.
    """
    residual = model(w, b, x) - y
    return jnp.mean(jnp.square(residual))

# Partial derivatives of `loss` w.r.t. the slope (positional arg 0) and the intercept (arg 1);
# each is a function with the same signature as `loss`.
dloss_dw, dloss_db = [grad(loss, argnums=k) for k in (0, 1)]

In [4]:
# Per-sample SGD on the training split, starting from w = b = 0.
w, b = 0.0, 0.0
N_epoch = 100
learning_rate = 0.01  # single step size shared by the w and b updates (was a duplicated literal)
log_every = 10        # report losses every `log_every` epochs, plus the final epoch

for epoch in range(N_epoch):
    # One SGD step per training point, visited in dataset order.
    for x_i, y_i in zip(x_train, y_train):
        # Both partials are evaluated at the pre-step (w, b), so the two updates are simultaneous.
        dl_dw = dloss_dw(w, b, x_i, y_i)
        dl_db = dloss_db(w, b, x_i, y_i)
        w -= learning_rate * dl_dw
        b -= learning_rate * dl_db
    # Losses are only needed for reporting, so skip evaluating them on silent epochs.
    # The last epoch is always reported so the log ends on the parameters training produced
    # (and `l`/`err` still hold the final-epoch values after the loop).
    if epoch % log_every == 0 or epoch == N_epoch - 1:
        l = loss(w, b, x_train, y_train)   # training MSE
        err = loss(w, b, x_test, y_test)   # held-out MSE
        print(f'epoch {epoch}: test loss {err:.3e}, train loss {l:.3e}, w {w:.2f}, b {b:.2f}')

epoch 0: test loss 3.703e+01, train loss 1.277e+01, w 0.25, b 0.16
epoch 10: test loss 2.433e+00, train loss 7.280e-01, w 1.50, b 0.93
epoch 20: test loss 2.264e-01, train loss 4.338e-02, w 1.81, b 1.11
epoch 30: test loss 4.571e-02, train loss 4.734e-03, w 1.88, b 1.15
epoch 40: test loss 2.197e-02, train loss 2.519e-03, w 1.90, b 1.15
epoch 50: test loss 1.682e-02, train loss 2.290e-03, w 1.91, b 1.15
epoch 60: test loss 1.500e-02, train loss 2.158e-03, w 1.91, b 1.15
epoch 70: test loss 1.394e-02, train loss 2.034e-03, w 1.91, b 1.14
epoch 80: test loss 1.308e-02, train loss 1.916e-03, w 1.92, b 1.14
epoch 90: test loss 1.231e-02, train loss 1.804e-03, w 1.92, b 1.13
