In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import Adam

In [2]:
class MLP(nn.Module):
    def __init__(self, input_size):
        super(MLP, self).__init__()
        self.hidden_layer = nn.Linear(input_size, 64)
        self.output_layer = nn.Linear(64, 1)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.hidden_layer(x))
        return self.output_layer(x)

In [3]:
class NumberSumDataset(Dataset):
    def __init__(self, data_range=(0, 10)):
        self.numbers = list(range(data_range[0], data_range[1]))
    
    def __len__(self):
        return len(self.numbers) ** 2
    
    def __getitem__(self, index):
        nr1 = float(self.numbers[index // len(self.numbers)])
        # print(f"nr1: {nr1}")
        nr2 = float(self.numbers[index % len(self.numbers)])
        # print(f"nr2: {nr2}")
        # print(f"len(self.numbers): {len(self.numbers)}")
        # print(f"index: {index}")
        return torch.tensor([nr1, nr2]), torch.tensor([nr1 + nr2])

In [4]:
dataset = NumberSumDataset(data_range=(1, 100))
for i in range(5):
    print(dataset[i])

(tensor([1., 1.]), tensor([2.]))
(tensor([1., 2.]), tensor([3.]))
(tensor([1., 3.]), tensor([4.]))
(tensor([1., 4.]), tensor([5.]))
(tensor([1., 5.]), tensor([6.]))


In [5]:
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)
model = MLP(input_size = 2)
loss_fn = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)

In [6]:
# Train the model for 10 epochs
for epoch in range(10):
    loss = 0.0
    for number_pairs, sums in dataloader: # iterate over the batches
        predictions = model(number_pairs) # Compute model output
        loss = loss_fn(predictions, sums) # Compute loss
        loss.backward() # Compute gradients
        optimizer.step() # Update model parameters
        optimizer.zero_grad() # Reset gradients
        
        loss += loss.item() # Add the loss to all batches
    print("Epoch {}: Sum of Batch Losses = {:.5f}".format(epoch, loss))

Epoch 0: Sum of Batch Losses = 2700.63599
Epoch 1: Sum of Batch Losses = 295.21207
Epoch 2: Sum of Batch Losses = 30.89520
Epoch 3: Sum of Batch Losses = 4.36074
Epoch 4: Sum of Batch Losses = 0.05190
Epoch 5: Sum of Batch Losses = 5.36319
Epoch 6: Sum of Batch Losses = 0.74097
Epoch 7: Sum of Batch Losses = 0.40672
Epoch 8: Sum of Batch Losses = 0.09687
Epoch 9: Sum of Batch Losses = 3.95036
