# Linear Regression with a PyTorch model
1. Design the model (input size, output size, forward pass)
2. Construct the loss and optimizer
3. Training loop
    - forward propagation: compute the prediction
    - backward propagation: compute the gradients
    - update the weights

In [13]:
import torch 
import torch.nn as nn

# Toy data for the target function f = 2 * x (the model below fits w*x + b).
X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)
X_test = torch.tensor([[5]], dtype=torch.float32)

# Each row is one sample: X is (n_samples, n_features), Y is (n_samples, n_targets).
n_samples, n_features = X.shape

input_size = n_features
# The output width is the number of target columns, not the number of input
# features; the two only coincide here because both happen to be 1.
output_size = Y.shape[1]

# A bare nn.Linear(input_size, output_size) would be equivalent to the
# LinearRegression module defined below.

class LinearRegression(nn.Module):
    """Linear regression model consisting of one affine layer.

    This is a thin ``nn.Module`` wrapper around a single ``nn.Linear`` layer,
    so calling it is equivalent to calling that layer directly.

    Args:
        input_dim: number of input features per sample.
        output_dim: number of values predicted per sample.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        # The sole layer; its weight and bias are the model's trainable parameters.
        self.lin = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine layer to ``x`` and return the prediction."""
        prediction = self.lin(x)
        return prediction

model = LinearRegression(input_size, output_size)

# Only the number is needed here, so skip building an autograd graph.
with torch.no_grad():
    print('Prediction before training: f(5) = {:.3f}'.format(model(X_test).item()))

# training
learning_rate = 0.01
n_iters = 1000

# Mean-squared-error criterion, called as loss(prediction, target).
loss = nn.MSELoss()
# Plain stochastic gradient descent over all of the model's parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(n_iters):
    # forward pass: prediction on the full training set
    y_pred = model(X)

    # loss -- PyTorch loss modules take (input, target). MSE is symmetric, but
    # the documented order matters as soon as an asymmetric loss is swapped in.
    l = loss(y_pred, Y)

    # back propagation: populate .grad on every parameter
    l.backward()

    # update weights from the gradients just computed
    optimizer.step()

    # zero gradients; otherwise the next backward() accumulates into .grad
    optimizer.zero_grad()

    if epoch % 10 == 0:
        # the single nn.Linear layer yields its weight first, then its bias
        [w, b] = model.parameters()
        print(f'epoch {epoch+1}: w = {w[0][0].item():.3f}, b = {b[0].item():.3f}, loss = {l.item():.8f}')

with torch.no_grad():
    print(f'Prediction after training: f(5) = {model(X_test).item():.3f}')



Prediction before training: f(5) = -2.277
epoch 1: w = -0.143, b = 0.376, loss = 43.97863770
epoch 11: w = 1.383, b = 0.862, loss = 1.28878391
epoch 21: w = 1.636, b = 0.917, loss = 0.17550947
epoch 31: w = 1.685, b = 0.903, loss = 0.13843153
epoch 41: w = 1.700, b = 0.878, loss = 0.12967908
epoch 51: w = 1.710, b = 0.852, loss = 0.12211309
epoch 61: w = 1.719, b = 0.827, loss = 0.11500504
epoch 71: w = 1.727, b = 0.803, loss = 0.10831112
epoch 81: w = 1.735, b = 0.779, loss = 0.10200687
epoch 91: w = 1.743, b = 0.756, loss = 0.09606957
epoch 101: w = 1.750, b = 0.734, loss = 0.09047784
epoch 111: w = 1.758, b = 0.712, loss = 0.08521153
epoch 121: w = 1.765, b = 0.691, loss = 0.08025177
epoch 131: w = 1.772, b = 0.671, loss = 0.07558069
epoch 141: w = 1.779, b = 0.651, loss = 0.07118146
epoch 151: w = 1.785, b = 0.632, loss = 0.06703838
epoch 161: w = 1.792, b = 0.613, loss = 0.06313640
epoch 171: w = 1.798, b = 0.595, loss = 0.05946150
epoch 181: w = 1.804, b = 0.577, loss = 0.0560004