In [1]:
from torch import nn
import torch

In [8]:
class LinearRegression(nn.Module):
  """Single-feature linear model computing ``weights * x + bias``.

  Both ``weights`` and ``bias`` are learnable one-element parameters that
  start from values drawn with ``torch.randn``.
  """

  def __init__(self):
    """Create the two learnable parameters (weights first, then bias)."""
    super().__init__()
    # nn.Parameter enables gradient tracking by default, so the raw tensors
    # do not need requires_grad set on them explicitly.
    initial_weights = torch.randn(1, dtype=torch.float)
    self.weights = nn.Parameter(initial_weights)
    initial_bias = torch.randn(1, dtype=torch.float)
    self.bias = nn.Parameter(initial_bias)

  def forward(self, x : torch.Tensor) -> torch.Tensor:
    """Apply the linear map elementwise to ``x``.

    Args:
      x: Input tensor; the one-element parameters broadcast against it.

    Returns:
      Tensor holding ``weights * x + bias``.
    """
    scaled = self.weights * x
    return scaled + self.bias

In [10]:
# Seed PyTorch's RNG *before* building the model: LinearRegression draws its
# initial weights and bias with torch.randn, so the seed has to be set first
# for the starting parameters (and every loss value derived from them) to be
# the same on every Restart & Run All.
torch.manual_seed(42)
model_0 = LinearRegression()

# L1Loss with its default reduction is the mean absolute error.
loss_fn = nn.L1Loss()

# Plain stochastic gradient descent over every learnable parameter of model_0.
optimizer = torch.optim.SGD(params=model_0.parameters(), lr=0.01)

In [11]:
# Range of the input feature: values from `start` up to (not including) `end`.
start, end, step = 0, 1, 0.02

# Known parameters of the line used to generate the targets.
weight, bias = 0.7, 0.3

# Column of inputs (one feature per row) and the matching targets.
X = torch.arange(start, end, step).reshape(-1, 1)
y = weight * X + bias

# The first 80% of the rows are used for training, the remainder for testing.
train_split = int(0.8 * len(X))
X_train, X_test = X[:train_split], X[train_split:]
y_train, y_test = y[:train_split], y[train_split:]

In [12]:
# Seeding at this point only influences random operations executed from here
# on; the parameters of model_0 were already drawn when it was constructed.
torch.manual_seed(42)

# Bookkeeping for the loss curves, sampled every 10th epoch.
epoch_count = []
train_loss_values = []
test_loss_values = []

epochs = 100
for epoch in range(epochs):
  # --- Training step on the training split ---
  model_0.train()
  y_pred = model_0(X_train)
  loss = loss_fn(y_pred, y_train)
  optimizer.zero_grad()  # clear gradients left over from the previous epoch
  loss.backward()
  optimizer.step()

  # --- Evaluation on the held-out split ---
  model_0.eval()
  with torch.inference_mode():  # no autograd tracking for the test forward pass
    test_pred = model_0(X_test)
    test_loss = loss_fn(test_pred, y_test)

  if epoch % 10 == 0:
    epoch_count.append(epoch)
    # .item() yields a plain Python float on any device, instead of a 0-d
    # NumPy array (and .numpy() raises for tensors that are not on the CPU).
    train_loss_values.append(loss.item())
    test_loss_values.append(test_loss.item())
    print(f"Epoch : {epoch} | Loss : {loss} | Test Loss : {test_loss}")

Epoch : 0 | Loss : 0.31288138031959534 | Test Loss : 0.48106518387794495
Epoch : 10 | Loss : 0.1976713240146637 | Test Loss : 0.3463551998138428
Epoch : 20 | Loss : 0.08908725529909134 | Test Loss : 0.21729660034179688
Epoch : 30 | Loss : 0.053148526698350906 | Test Loss : 0.14464017748832703
Epoch : 40 | Loss : 0.04543796554207802 | Test Loss : 0.11360953003168106
Epoch : 50 | Loss : 0.04167863354086876 | Test Loss : 0.09919948130846024
Epoch : 60 | Loss : 0.03818932920694351 | Test Loss : 0.08886633068323135
Epoch : 70 | Loss : 0.03476089984178543 | Test Loss : 0.0805937647819519
Epoch : 80 | Loss : 0.03132382780313492 | Test Loss : 0.07232122868299484
Epoch : 90 | Loss : 0.02788739837706089 | Test Loss : 0.06473556160926819
