In [1]:
import torch
import numpy as np

# Model definition

Suppose we want to train a model to predict the price of a house (`y`) based on two variables: the number of rooms in the house (`x1`) and the number of bathrooms (`x2`). We can define the dataset as follows:

In [2]:
# Each sample pairs a feature tensor with a target tensor:
#   [x1, x2]               --> [y]
#   [NUM_ROOMS, NUM_BATHS] --> [PRICE]

training_data = [
    [
        torch.tensor(features, dtype=torch.float),
        torch.tensor(price, dtype=torch.float),
    ]
    for features, price in (
        ([6, 2], [15]),
        ([5, 2], [12]),
        ([5, 1], [10]),
        ([3, 1], [7]),
        ([2, 1], [4.5]),
        ([2, 0], [4]),
        ([1, 0], [2]),
    )
]

Suppose we know that the relation between the two variables (`x1` and `x2`) and the target variable (`y`) is linear, i.e.,

$$y = w_1 \cdot x_1 + w_2 \cdot x_2 + b$$

where `w1` and `w2` are the weights of the model, and `b` is the bias term.

We want to train the model using gradient descent to find the optimal values of `w1`, `w2`, and `b` that minimize the mean squared error (MSE) between the predicted values and the actual values.

In [3]:
# Define the model parameters
class ModelParameters:
    """Trainable parameters of the linear model ``y = w1 * x1 + w2 * x2 + b``.

    Each parameter is stored as a scalar ``torch.float`` tensor created with
    ``requires_grad=True`` so that gradients can be computed for it.
    """

    def __init__(self, w1: float = 0.773, w2: float = 0.321, b: float = 0.067):
        """Create a fresh set of parameters.

        Args:
            w1: Initial weight applied to ``x1``.
            w2: Initial weight applied to ``x2``.
            b: Initial bias term.

        The defaults reproduce the fixed starting point used throughout this
        notebook, so ``ModelParameters()`` behaves exactly as before.
        """
        # float(...) lets callers pass ints or other scalar-like values.
        self.w1 = torch.tensor(float(w1), dtype=torch.float, requires_grad=True)
        self.w2 = torch.tensor(float(w2), dtype=torch.float, requires_grad=True)
        self.b = torch.tensor(float(b), dtype=torch.float, requires_grad=True)

# Two independent parameter sets with identical starting values: the first is
# for the training loop without gradient accumulation, the second for the loop
# with gradient accumulation.
params_no_accumulate, params_accumulate = ModelParameters(), ModelParameters()

## Training loop (without gradient accumulation)

We run gradient descent using one data item at a time: for each item, we calculate the gradient of the loss function w.r.t. the parameters and immediately update the parameters using those gradients.

In [4]:
def train_no_accumulate(params: "ModelParameters", num_epochs: int = 10, learning_rate: float = 1e-3, data=None):
    """Fit the linear model with per-sample gradient descent (no accumulation).

    For every sample the squared error is computed and back-propagated, the
    parameters are updated immediately, and their gradients are reset to zero.

    Args:
        params: Object exposing the tensors ``w1``, ``w2`` and ``b`` (created
            with ``requires_grad=True``). They are updated in place.
        num_epochs: Number of full passes over the dataset.
        learning_rate: Step size multiplied with each gradient.
        data: Optional iterable of ``[features, target]`` pairs, where
            ``features`` unpacks into ``(x1, x2)`` and ``target`` holds a
            single value. When omitted, the module-level ``training_data``
            is used.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    # Materialise once so that one-shot iterables survive several epochs.
    samples = list(training_data if data is None else data)
    if not samples:
        raise ValueError("training data must contain at least one sample")

    print(f'Initial parameters: w1: {params.w1.item():.3f}, w2: {params.w2.item():.3f}, b: {params.b.item():.3f}')
    for epoch in range(1, num_epochs+1):
        for (x1, x2), y_target in samples:
            # Calculate the output of the model
            y_pred = x1 * params.w1 + x2 * params.w2 + params.b
            loss = (y_pred - y_target) ** 2

            # Calculate the gradients of the loss w.r.t. the parameters
            loss.backward()

            # Update the parameters (at each iteration)
            with torch.no_grad():
                # Equivalent to calling optimizer.step()
                params.w1 -= learning_rate * params.w1.grad
                params.w2 -= learning_rate * params.w2.grad
                params.b -= learning_rate * params.b.grad

                # Reset the gradients to zero
                # Equivalent to calling optimizer.zero_grad()
                params.w1.grad.zero_()
                params.w2.grad.zero_()
                params.b.grad.zero_()
        # Note: this is the loss of the last sample processed in the epoch,
        # not an average over the whole epoch.
        print(f"Epoch {epoch:>3} - Loss: {np.round(loss.item(),4):>10}")
    print(f'Final parameters: w1: {params.w1.item():.3f}, w2: {params.w2.item():.3f}, b: {params.b.item():.3f}')
        
# Run the per-sample training loop on its dedicated parameter set.
train_no_accumulate(params=params_no_accumulate)

Initial parameters: w1: 0.773, w2: 0.321, b: 0.067
Epoch   1 - Loss:     0.7001
Epoch   2 - Loss:     0.3374
Epoch   3 - Loss:     0.1454
Epoch   4 - Loss:      0.051
Epoch   5 - Loss:      0.011
Epoch   6 - Loss:     0.0001
Epoch   7 - Loss:     0.0039
Epoch   8 - Loss:     0.0142
Epoch   9 - Loss:     0.0265
Epoch  10 - Loss:     0.0387
Final parameters: w1: 1.897, w2: 0.692, b: 0.299


## Training loop (with gradient accumulation)

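We run gradient descent using one data item at a time, but we accumulate the gradients over a fixed number of iterations (the batch size) before updating the parameters. If the dataset size is not a multiple of the batch size, the gradients of the final, smaller group of items are applied at the end of each epoch.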

In [5]:
def train_accumulate(params: "ModelParameters", num_epochs: int = 10, learning_rate: float = 1e-3, batch_size: int = 2, data=None):
    """Fit the linear model with gradient descent using gradient accumulation.

    Samples are processed one at a time, but the gradients are left to add up
    in ``.grad`` and the parameters are only updated once ``batch_size``
    samples have been processed, or when the end of the dataset is reached.
    The accumulated gradients are summed, not averaged.

    Args:
        params: Object exposing the tensors ``w1``, ``w2`` and ``b`` (created
            with ``requires_grad=True``). They are updated in place.
        num_epochs: Number of full passes over the dataset.
        learning_rate: Step size multiplied with each accumulated gradient.
        batch_size: Number of samples whose gradients are accumulated before
            each parameter update. Must be at least 1.
        data: Optional iterable of ``[features, target]`` pairs, where
            ``features`` unpacks into ``(x1, x2)`` and ``target`` holds a
            single value. When omitted, the module-level ``training_data``
            is used.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or the dataset
            contains no samples.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    # Materialise once so that one-shot iterables survive several epochs and
    # the end of the dataset can be detected by index.
    samples = list(training_data if data is None else data)
    if not samples:
        raise ValueError("training data must contain at least one sample")
    last_index = len(samples) - 1

    print(f'Initial parameters: w1: {params.w1.item():.3f}, w2: {params.w2.item():.3f}, b: {params.b.item():.3f}')
    for epoch in range(1, num_epochs+1):
        for index, ((x1, x2), y_target) in enumerate(samples):
            # Calculate the output of the model
            y_pred = x1 * params.w1 + x2 * params.w2 + params.b
            loss = (y_pred - y_target) ** 2

            # We can also divide the loss by the batch size (equivalent to using nn.MSELoss with the parameter reduction='mean')
            # If we don't divide by the batch size, then it is equivalent to using nn.MSELoss with the parameter reduction='sum'

            # Calculate the gradients of the loss w.r.t. the parameters
            # If we didn't call zero_() on the gradients on the previous iteration, then the gradients will accumulate (add up) over each iteration
            loss.backward()

            # Every time we reach the batch size or the end of the dataset, update the parameters
            batch_complete = (index + 1) % batch_size == 0
            if batch_complete or index == last_index:
                with torch.no_grad():
                    # Equivalent to calling optimizer.step()
                    params.w1 -= learning_rate * params.w1.grad
                    params.w2 -= learning_rate * params.w2.grad
                    params.b -= learning_rate * params.b.grad

                    # Reset the gradients to zero
                    # Equivalent to calling optimizer.zero_grad()
                    params.w1.grad.zero_()
                    params.w2.grad.zero_()
                    params.b.grad.zero_()

        # Note: this is the loss of the last sample processed in the epoch,
        # not an average over the whole epoch.
        print(f"Epoch {epoch:>3} - Loss: {np.round(loss.item(),4):>10}")
    print(f'Final parameters: w1: {params.w1.item():.3f}, w2: {params.w2.item():.3f}, b: {params.b.item():.3f}')

# Run the gradient-accumulation training loop on its dedicated parameter set.
train_accumulate(params=params_accumulate)

Initial parameters: w1: 0.773, w2: 0.321, b: 0.067
Epoch   1 - Loss:     0.6857
Epoch   2 - Loss:     0.3218
Epoch   3 - Loss:     0.1335
Epoch   4 - Loss:     0.0438
Epoch   5 - Loss:     0.0078
Epoch   6 - Loss:        0.0
Epoch   7 - Loss:     0.0059
Epoch   8 - Loss:     0.0174
Epoch   9 - Loss:     0.0303
Epoch  10 - Loss:     0.0427
Final parameters: w1: 1.905, w2: 0.698, b: 0.300
