# Simple Gradient Descent Implementation

Libraries: NumPy | PyTorch | JAX  
Purpose: Demonstrate linear regression using gradient descent with fully calculated gradients.  
Scope: Single/multi-feature input, scalar output, MSE loss.

In [26]:
import numpy as np
import torch
import jax
import jax.numpy as jnp

In [32]:
# Simple gradient descent example in NumPy


# Input features
x = np.array([2, 2])

# Parameters (weights and bias)
w = np.array([0.1, 0.01])
b = np.array([0.0])

# Target output
y = 0.1

# Linear model: y_pred = w*x + b
simple_node = lambda x, w, b: np.dot(x, w) + b

# Learning rate
lr = 0.1

# Number of training iterations
epochs = 100

for r in range(epochs):
    # Forward pass: compute predicted output
    y_pred = simple_node(x, w, b)
    print(f"{y_pred} is the prediction for epoch {r}")

    # Compute mean squared error (MSE) loss
    error = np.mean((y - y_pred) ** 2)

    # Compute gradients using chain rule:
    # Let u = (y - y_pred)
    # dL/du = 2 * u
    # du/dw = -x, du/db = -1
    # Therefore:
    # dL/dw = -2 * (y - y_pred) * x
    # dL/db = -2 * (y - y_pred)
    dw = -2 * (y - y_pred) * x
    db = -2 * (y - y_pred)

    # Update parameters using gradient descent
    w = w - lr * dw
    b = b - lr * db

    # Print current loss
    print(f"error is {error}")


[0.22] is the prediction for epoch 0
error is 0.0144
[0.004] is the prediction for epoch 1
error is 0.009216
[0.1768] is the prediction for epoch 2
error is 0.005898240000000005
[0.03856] is the prediction for epoch 3
error is 0.0037748736000000043
[0.149152] is the prediction for epoch 4
error is 0.002415919104000003
[0.0606784] is the prediction for epoch 5
error is 0.0015461882265600032
[0.13145728] is the prediction for epoch 6
error is 0.000989560464998402
[0.07483418] is the prediction for epoch 7
error is 0.0006333186975989775
[0.12013266] is the prediction for epoch 8
error is 0.0004053239664633469
[0.08389387] is the prediction for epoch 9
error is 0.0002594073385365428
[0.1128849] is the prediction for epoch 10
error is 0.0001660206966633877
[0.08969208] is the prediction for epoch 11
error is 0.00010625324586456811
[0.10824634] is the prediction for epoch 12
error is 6.800207735332346e-05
[0.09340293] is the prediction for epoch 13
error is 4.352132950612701e-05
[0.10527766]

In [33]:
y_pred

array([0.1])

# PyTorch

In [35]:
# Gradient Descent in Pytorch
# Input features
x = torch.tensor([2, 2], dtype=torch.float64)

# Parameters (weights and bias), set requires_grad=True for autograd
w = torch.tensor([0.1, 0.01], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.0], dtype=torch.float64, requires_grad=True)

# Target output
y = torch.tensor([0.1])

# Learning rate
lr = 0.1

# Number of features (for averaging loss if needed)
m = x.shape[0]

# Training loop
for i in range(100):
    # Define linear model: y_pred = w*x + b
    node = lambda w, b: torch.dot(x, w) + b
    y_pred = node(w, b)

    # Compute mean squared error (MSE) loss
    loss = ((y - y_pred) ** 2) / m

    # Backward pass: compute gradients automatically
    loss.backward()

    # Retrieve gradients
    dw = w.grad
    db = b.grad

    # Gradient descent parameter update
    w.data -= lr * dw
    b.data -= lr * db

    # Clear gradients for next iteration
    w.grad.zero_()
    b.grad.zero_()

    # Print current loss
    print(f"Epoch {i}: loss = {loss.item()}")

# Final prediction
print(f"Final prediction: {y_pred.item()}")


Epoch 0: loss = 0.0071999998211860665
Epoch 1: loss = 7.199999821186084e-05
Epoch 2: loss = 7.19999982118595e-07
Epoch 3: loss = 7.199999821185784e-09
Epoch 4: loss = 7.19999982119078e-11
Epoch 5: loss = 7.199999821224086e-13
Epoch 6: loss = 7.19999982089102e-15
Epoch 7: loss = 7.199999842540368e-17
Epoch 8: loss = 7.199999859193714e-19
Epoch 9: loss = 7.199997860792427e-21
Epoch 10: loss = 7.200014514144952e-23
Epoch 11: loss = 7.200147741658488e-25
Epoch 12: loss = 7.200147741658488e-27
Epoch 13: loss = 7.18849499882647e-29
Epoch 14: loss = 7.288681874533494e-31
Epoch 15: loss = 7.800016274768305e-33
Epoch 16: loss = 0.0
Epoch 17: loss = 0.0
Epoch 18: loss = 0.0
Epoch 19: loss = 0.0
Epoch 20: loss = 0.0
Epoch 21: loss = 0.0
Epoch 22: loss = 0.0
Epoch 23: loss = 0.0
Epoch 24: loss = 0.0
Epoch 25: loss = 0.0
Epoch 26: loss = 0.0
Epoch 27: loss = 0.0
Epoch 28: loss = 0.0
Epoch 29: loss = 0.0
Epoch 30: loss = 0.0
Epoch 31: loss = 0.0
Epoch 32: loss = 0.0
Epoch 33: loss = 0.0
Epoch 34: lo

JAX: Simple Gradient Descent

In [31]:
#Gradient Descent in Jax

# Input features
x = jnp.array([2.0, 2.0])

# Parameters (weights and bias)
w = jnp.array([0.1, 0.01])
b = jnp.array([0.0])

# Target output
y = 0.1

# Learning rate
lr = 0.1

# Number of features (for averaging loss)
m = x.shape[0]

# Linear model
node = lambda w, b: jnp.dot(x, w) + b

# Loss function (mean squared error)
error = lambda w, b: jnp.mean((y - node(w, b)) ** 2)

# Training loop
epochs = 100
for r in range(epochs):
    # Compute gradients w.r.t weights and bias NB: grad was used here unlike vmap seen in previous codes. b is a scalar, w vector.
    dw, db = jax.grad(error, argnums=(0, 1))(w, b)

    # Gradient descent update
    w = w - lr * dw
    b = b - lr * db

    # Compute loss and prediction after update
    loss = error(w, b)
    y_pred = node(w, b)

    # Print current loss and prediction
    print(f"Epoch {r}: loss = {loss}, y_pred = {y_pred}")
# to understand the loss function better in jax, see Partial derivative Chain rule in Computational Multivariate Calculus Notebook.

loss is 0.009216000325977802 and y_pred is [0.004]
loss is 0.005898241885006428 and y_pred is [0.17680001]
loss is 0.003774875309318304 and y_pred is [0.03855999]
loss is 0.002415921539068222 and y_pred is [0.14915203]
loss is 0.001546190120279789 and y_pred is [0.06067838]
loss is 0.000989562482573092 and y_pred is [0.13145731]
loss is 0.0006333203054964542 and y_pred is [0.07483415]
loss is 0.0004053249431308359 and y_pred is [0.12013268]
loss is 0.0002594076213426888 and y_pred is [0.08389387]
loss is 0.00016602102550677955 and y_pred is [0.11288492]
loss is 0.00010625358117977157 and y_pred is [0.08969206]
loss is 6.800224218750373e-05 and y_pred is [0.10824635]
loss is 4.352135874796659e-05 and y_pred is [0.09340293]
loss is 2.7853731808136217e-05 and y_pred is [0.10527766]
loss is 1.7826338080340065e-05 and y_pred is [0.09577788]
loss is 1.1408896170905791e-05 and y_pred is [0.10337771]
loss is 7.301718142116442e-06 and y_pred is [0.09729783]
loss is 4.673157491197344e-06 and y_p