Gradient descent implementation for linear regression models
- yhat = wx + b
- MSE = sum((y-yhat)**2) / N
- Minimize the loss using gradient descent

Requirements:
- Loss function
- Gradient with respect to params
- Training procedure
- Testing procedure

In [1]:
import numpy as np

In [12]:
# Draw the starting weight and bias from a standard normal distribution.
w = np.random.standard_normal()
b = np.random.standard_normal()

# Show the bias; it is also used as the intercept of the data generated below.
print(b)


# Synthetic observations: 100 single-feature samples lying exactly on the
# line y = 2*x + b.
# NOTE: the targets reuse the initial bias `b`, so the true slope is 2 while
# the intercept already equals the starting value of `b`.
X = np.random.standard_normal(size=(100, 1))
y = 2 * X + b

1.273655985445401


In [9]:
def forward(X):
    """Return the linear model's prediction ``w*X + b`` for input ``X``.

    ``w`` and ``b`` are not parameters of this function; they are looked up
    in the enclosing module scope each time it is called.
    """
    scaled = w * X
    return scaled + b

def gradient_descent(X, y, w, b, alpha):
    """Perform one gradient-descent update of ``w`` and ``b``.

    The gradient of the squared error of the model ``w*x + b`` is
    accumulated one sample at a time over ``zip(X, y)``, divided by
    ``X.shape[0]`` and scaled by ``alpha`` before being subtracted from the
    current parameters.

    Args:
        X: Inputs; must expose ``shape`` and be iterable sample by sample.
        y: Targets, iterated in lockstep with ``X``.
        w: Current weight.
        b: Current bias.
        alpha: Learning rate.

    Returns:
        Tuple ``(w, b)`` holding the updated weight and bias.
    """
    n_samples = X.shape[0]
    grad_w = 0.0
    grad_b = 0.0

    for sample, target in zip(X, y):
        # Prediction error for this sample under the current parameters.
        residual = target - (w*sample + b)
        grad_w += -2*sample*residual
        grad_b += -2*residual

    # Learning rate combined with the 1/N averaging factor.
    step = alpha*(1/n_samples)
    return w - step*grad_w, b - step*grad_b

def fit(epochs, X, y, w, b, alpha, *, log_every=1):
    """Train the linear model with repeated gradient-descent steps.

    Each epoch calls ``gradient_descent`` once on the full data set and,
    when logging is due, prints the mean squared error together with the
    current parameters.

    Args:
        epochs: Number of update steps to run.
        X: Inputs, forwarded unchanged to ``gradient_descent``.
        y: Targets, forwarded unchanged to ``gradient_descent``.
        w: Initial weight.
        b: Initial bias.
        alpha: Learning rate forwarded to ``gradient_descent``.
        log_every: Print progress every ``log_every`` epochs. The default of
            1 prints every epoch; ``0`` or ``None`` disables printing.

    Returns:
        Tuple ``(w, b)`` with the parameters after the last epoch (the
        initial values if no epoch was run).
    """
    for epoch in range(epochs):
        w, b = gradient_descent(X, y, w, b, alpha)
        if log_every and epoch % log_every == 0:
            yhat = w*X + b
            loss = np.mean((y-yhat)**2)
            print(f"epoch: {epoch} Results: {loss}, w: {w} b: {b}")

    # w and b are rebound locally on every step; return them so the caller
    # actually receives the trained parameters instead of losing them.
    return w, b


In [14]:
# Run 5000 epochs with learning rate 0.001, starting from the w and b drawn above.
fit(epochs=5000, X=X, y=y, w=w, b=b, alpha=0.001)

epoch: 0 Results: 0.8076896348767291, w: [1.13587401] b: [1.27385305]
epoch: 1 Results: 0.8041602781473342, w: [1.13774344] b: [1.2740493]
epoch: 2 Results: 0.8006464958678384, w: [1.13960877] b: [1.27424473]
epoch: 3 Results: 0.7971482187163161, w: [1.14147003] b: [1.27443934]
epoch: 4 Results: 0.7936653776816911, w: [1.14332721] b: [1.27463314]
epoch: 5 Results: 0.7901979040623368, w: [1.14518033] b: [1.27482613]
epoch: 6 Results: 0.7867457294646782, w: [1.1470294] b: [1.27501831]
epoch: 7 Results: 0.783308785801802, w: [1.14887443] b: [1.27520969]
epoch: 8 Results: 0.7798870052920741, w: [1.15071542] b: [1.27540026]
epoch: 9 Results: 0.7764803204577615, w: [1.15255238] b: [1.27559004]
epoch: 10 Results: 0.7730886641236601, w: [1.15438533] b: [1.27577902]
epoch: 11 Results: 0.7697119694157303, w: [1.15621427] b: [1.2759672]
epoch: 12 Results: 0.7663501697597402, w: [1.1580392] b: [1.27615459]
epoch: 13 Results: 0.7630031988799094, w: [1.15986015] b: [1.27634119]
epoch: 14 Results: 0.