In [18]:
from typing import List

class LinearRegression:
    def __init__(self, w: float = 1, b: float = 0, learning_rate: float = 0.01) -> None:
        self.w = w
        self.b = b
        self.learning_rate = learning_rate

    def fit(self, x_train: List[float], y_train: List[float], epochs: int = 1000) -> None:
        m = len(x_train)
        for epoch in range(epochs):
            y_pred = self.predict(x_train)
            mse = LinearRegression.compute_mse(y_pred, y_train)

            # compute the gradients (partial derivatives of the loss function J - MSE)
            errors = [y_pred[i] - y_train[i] for i in range(m)]
            dw = (1 / m) * sum(errors[i] * x_train[i] for i in range(m))
            db = (1 / m) * sum(errors)

            self.w -= self.learning_rate * dw 
            self.b -= self.learning_rate * db

            if epoch % 100 == 0:
                print(f"Epoch: {epoch}, MSE: {mse}")

    def predict(self, x_predict: List[float]) -> List[float]:
        return [x * self.w + self.b for x in x_predict]

    @staticmethod
    def compute_mse(y_train: List[float], y_prediction: List[float]) -> float:
        m = len(y_train)
        return (1 / (2 * m)) * sum((y_train[i] - y_prediction[i]) ** 2 for i in range(m))

In [19]:
x_train = [i for i in range(5)]
y_train = [i*5 for i in range(5)]

x_test = [i for i in range(101)]

model = LinearRegression()
model.fit(x_train, y_train)

y_pred = model.predict(x_test)
print(y_pred)

Epoch: 0, MSE: 48.0
Epoch: 100, MSE: 0.14392846140158663
Epoch: 200, MSE: 0.07914131916165541
Epoch: 300, MSE: 0.043530720782905014
Epoch: 400, MSE: 0.023943543941019975
Epoch: 500, MSE: 0.013169855360646542
Epoch: 600, MSE: 0.007243918888849419
Epoch: 700, MSE: 0.003984429550003297
Epoch: 800, MSE: 0.002191587051502882
Epoch: 900, MSE: 0.0012054558234844766
[0.06290148900525033, 5.04083683805678, 10.01877218710831, 14.996707536159839, 19.974642885211367, 24.952578234262894, 29.930513583314426, 34.90844893236596, 39.886384281417484, 44.86431963046901, 49.84225497952054, 54.820190328572075, 59.7981256776236, 64.77606102667514, 69.75399637572667, 74.73193172477819, 79.70986707382973, 84.68780242288126, 89.66573777193278, 94.64367312098432, 99.62160847003584, 104.59954381908737, 109.57747916813891, 114.55541451719043, 119.53334986624196, 124.5112852152935, 129.489220564345, 134.46715591339654, 139.44509126244807, 144.4230266114996, 149.40096196055111, 154.37889730960265, 159.3568326586541