In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [15]:
class AddNet(nn.Module):
    def __init__(self):
        super(AddNet, self).__init__()
        # two inputs, one output
        self.fc = nn.Linear(2, 1)
    
    def forward(self, x):
        return self.fc(x)

In [3]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

In [4]:
torch.rand(5, 2)


tensor([[0.4103, 0.2921],
        [0.0812, 0.7048],
        [0.6621, 0.2070],
        [0.3180, 0.5248],
        [0.9884, 0.7344]])

In [10]:
def generate_normalized_data(num_samples):
    x = torch.rand((10, 2))
    x_mean = x.mean(dim=0, keepdim=True)
    x_std = x.std(dim=0, keepdim=True)
    x_normalized = (x - x_mean) / x_std

    y = x.sum(dim=1, keepdim=True)
    y_normalized = (y - y.mean()) / y.std()
    return x_normalized, y_normalized, x_mean, x_std, y.mean(), y.std()


In [11]:
def eval(model, x_train_mean, x_train_std, y_train_mean, y_train_std, test_samples=10):
    x_test = torch.rand((test_samples, 2)) * 100
    y_actual = x_test.sum(dim=1, keepdim=True)
    x_test_normalized = (x_test - x_train_mean) / x_train_std

    with torch.no_grad():
        y_pred_normalized = model(x_test_normalized)
        y_pred = y_pred_normalized * y_train_std + y_train_mean
    
    print(f"{'Input':<25}{'Predicted':<15}{'Actual':<15}{'Error':<10}")
    print("-" * 65)

    for i in range(test_samples):
        input_vals = f"{x_test[i,0].item():.2f}, {x_test[i,1].item():.2f}"
        predicted = y_pred[i].item()
        actual = y_actual[i].item()
        error = abs(predicted - actual)
        print(f"{input_vals:<25}{predicted:<15.2f}{actual:<15.2f}{error:<10.2f}")

In [17]:
model = AddNet()
model.apply(init_weights)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

In [18]:
x_train, y_train, x_mean, x_std, y_mean, y_std = generate_normalized_data(30000)
epochs = 5000

In [22]:
for epoch in range(epochs):
    # forward pass
    predictions = model(x_train)
    loss = criterion(predictions, y_train)

    optimizer.zero_grad()
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


Epoch 0, Loss: 1.5166
Epoch 100, Loss: 1.3292
Epoch 200, Loss: 1.1566
Epoch 300, Loss: 0.9968
Epoch 400, Loss: 0.8491
Epoch 500, Loss: 0.7134
Epoch 600, Loss: 0.5896
Epoch 700, Loss: 0.4775
Epoch 800, Loss: 0.3788
Epoch 900, Loss: 0.2987
Epoch 1000, Loss: 0.2341
Epoch 1100, Loss: 0.1820
Epoch 1200, Loss: 0.1401
Epoch 1300, Loss: 0.1067
Epoch 1400, Loss: 0.0802
Epoch 1500, Loss: 0.0594
Epoch 1600, Loss: 0.0433
Epoch 1700, Loss: 0.0311
Epoch 1800, Loss: 0.0219
Epoch 1900, Loss: 0.0151
Epoch 2000, Loss: 0.0103
Epoch 2100, Loss: 0.0068
Epoch 2200, Loss: 0.0044
Epoch 2300, Loss: 0.0028
Epoch 2400, Loss: 0.0017
Epoch 2500, Loss: 0.0010
Epoch 2600, Loss: 0.0006
Epoch 2700, Loss: 0.0003
Epoch 2800, Loss: 0.0002
Epoch 2900, Loss: 0.0001
Epoch 3000, Loss: 0.0000
Epoch 3100, Loss: 0.0000
Epoch 3200, Loss: 0.0000
Epoch 3300, Loss: 0.0000
Epoch 3400, Loss: 0.0000
Epoch 3500, Loss: 0.0000
Epoch 3600, Loss: 0.0000
Epoch 3700, Loss: 0.0000
Epoch 3800, Loss: 0.0000
Epoch 3900, Loss: 0.0000
Epoch 4000, 

In [24]:
eval(model, x_mean, x_std, y_mean, y_std, test_samples=100)

Input                    Predicted      Actual         Error     
-----------------------------------------------------------------
17.23, 48.11             65.33          65.34          0.01      
15.28, 36.39             51.65          51.66          0.01      
61.28, 66.13             127.39         127.41         0.02      
68.16, 59.95             128.09         128.11         0.02      
94.89, 27.54             122.42         122.44         0.02      
72.41, 78.74             151.13         151.15         0.02      
14.60, 93.18             107.77         107.78         0.02      
30.31, 97.24             127.53         127.55         0.02      
39.69, 73.55             113.22         113.23         0.02      
65.87, 52.52             118.38         118.39         0.02      
57.10, 91.77             148.86         148.88         0.02      
46.64, 30.56             77.19          77.20          0.01      
32.05, 53.16             85.20          85.21          0.01      
19.88, 17.