In [2]:
!pip install jax jaxlib



In [3]:
import jax.numpy as jnp
from jax import grad, jit
from jax import random

In [4]:
# Data
key = random.PRNGKey(0)
X = random.normal(key, (100, 1))
true_w = 2.0
true_b = 1.0
y = true_w * X + true_b + random.normal(key, (100, 1)) * 0.1

# Initializations
w = random.normal(key, (1,))
b = random.normal(key, ())

In [5]:
# Linear model
def predict(X, w, b):
    return jnp.dot(X, w) + b

# Mean squared error loss
def loss_fn(w, b, X, y):
    preds = predict(X, w, b)
    return jnp.mean((preds - y) ** 2)

In [6]:
# Gradients
grad_fn = grad(loss_fn, argnums=(0, 1))

# Training
learning_rate = 0.1
for i in range(1000):
    grads = grad_fn(w, b, X, y)
    w -= learning_rate * grads[0]
    b -= learning_rate * grads[1]

    if i % 100 == 0:
        current_loss = loss_fn(w, b, X, y)
        print(f"Iteration {i}: loss {current_loss}")

print(f"Trained parameters: w = {w}, b = {b}")

Iteration 0: loss 5.156553745269775
Iteration 100: loss 3.9025702476501465
Iteration 200: loss 3.9025702476501465
Iteration 300: loss 3.9025702476501465
Iteration 400: loss 3.9025702476501465
Iteration 500: loss 3.9025702476501465
Iteration 600: loss 3.9025702476501465
Iteration 700: loss 3.9025702476501465
Iteration 800: loss 3.9025702476501465
Iteration 900: loss 3.9025702476501465
Trained parameters: w = [3.394465e-08], b = 1.172472596168518
