# Neural Networks: Learn a Linear Function

This notebook teaches a small neural network to learn the linear function f(a, b) = 3a + 4b.

We'll go step by step:
- Setup (import libraries)
- Define the target function
- Build a simple model
- Train the model
- Test with a few examples


In [None]:
import torch
from torch import nn
import random
import matplotlib.pyplot as plt
print(torch.__version__)


## Define the target function
Our goal is to learn `f(a, b) = 3a + 4b`.


In [None]:
def mystery(a, b):
    return torch.tensor(3*a + 4*b)
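
A quick check of the target's output type. This self-contained sketch redefines `mystery` exactly as above. It returns a 0-dimensional `float32` tensor, assuming PyTorch's default dtype has not been changed. The model's output, by contrast, has shape `(1,)`, which is why the training loop squeezes it before comparing it with the target.

```python
import torch

def mystery(a, b):
    # Same target as in the notebook: f(a, b) = 3a + 4b, returned as a 0-dim tensor
    return torch.tensor(3 * a + 4 * b)

y = mystery(1.0, 2.0)
print(y, y.shape, y.dtype)  # tensor(11.) torch.Size([]) torch.float32
```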


## Build a simple model
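A single `nn.Linear(2, 1)` layer computes a weighted sum of its two inputs plus a bias. Because f is exactly such a sum (with bias 0), one layer can represent it perfectly, so no hidden layers or activation functions are needed.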
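
Written out, `nn.Linear(2, 1)` holds a 1×2 weight matrix $[w_1, w_2]$ and a bias $c$, and computes

$$
\hat{y} = w_1 a + w_2 b + c .
$$

Requiring $\hat{y} = 3a + 4b$ for every $(a, b)$ forces the coefficients to match:

$$
w_1 = 3, \qquad w_2 = 4, \qquad c = 0 .
$$

For example, at $(a, b) = (1, 2)$ this gives $3 \cdot 1 + 4 \cdot 2 + 0 = 11$, matching $f(1, 2)$. Training should therefore drive the layer's parameters toward these values.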


In [None]:
model = nn.Sequential(nn.Linear(2, 1))
model


## Train the model
We sample random `(a, b)` pairs uniformly from [0, 1), one at a time, and train the network to predict `mystery(a, b)` by minimizing the mean squared error with SGD (with momentum).
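
Each iteration of the training loop calls `loss.backward()` and then `optimizer.step()`. For a single sample, `nn.MSELoss` reduces to the squared error $(\hat{y} - y)^2$ with $\hat{y} = w \cdot x + c$. Its gradients are $2(\hat{y} - y)\,x$ for the weights and $2(\hat{y} - y)$ for the bias. The sketch below checks that autograd agrees with this hand derivation on one arbitrary sample, (0.3, 0.7), which is chosen here for illustration. It uses a bare `nn.Linear` instead of the notebook's `Sequential`. It then shows a plain SGD step without momentum. According to the PyTorch docs, with `momentum=0.9` (and default dampening) the optimizer instead keeps a velocity buffer $v \leftarrow 0.9\,v + g$ and steps $p \leftarrow p - \text{lr}\cdot v$.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(2, 1)                 # same layer type as the notebook's model
x = torch.tensor([0.3, 0.7])            # one arbitrary (a, b) sample
y = torch.tensor(3 * 0.3 + 4 * 0.7)     # target f(a, b) = 3a + 4b

pred = layer(x).squeeze()               # 0-dim prediction
loss = nn.MSELoss()(pred, y)            # for a single element: (pred - y)^2
loss.backward()

# Hand-derived gradients of (pred - y)^2 with pred = w . x + c
err = (pred - y).detach()
grad_w = 2 * err * x                    # dL/dw, shape (2,)
grad_c = 2 * err                        # dL/dc, 0-dim

print(layer.weight.grad.squeeze(0), grad_w)
print(layer.bias.grad, grad_c)

# One plain SGD step (no momentum) with the notebook's learning rate
lr = 0.01
with torch.no_grad():
    new_w = layer.weight - lr * layer.weight.grad
    new_c = layer.bias - lr * layer.bias.grad
print(new_w, new_c)
```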


In [None]:
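criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_history = []
random.seed(0)  # make the sampled (a, b) pairs reproducible

for i in range(20000):
    # Draw one training example with a and b uniform in [0, 1)
    a = random.random()
    b = random.random()
    desired = mystery(a, b)                      # 0-dim target tensor

    output = model(torch.tensor([a, b]))         # shape (1,)
    loss = criterion(output.squeeze(), desired)  # squeeze to 0-dim to match the target
    loss_history.append(loss.item())

    if i % 1000 == 0:
        print(f"Step {i:5d}  loss: {loss.item():.6f}")

    # Clear old gradients, backpropagate, then update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()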
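
The loop above updates the model on one sample at a time. A common alternative, swapped in here rather than taken from the original notebook, is mini-batch training. It samples many `(a, b)` pairs at once with `torch.rand` and computes the target for the whole batch in one vectorized expression. The batch size (64), learning rate (0.1) and step count (1000) are assumptions chosen for this sketch, not tuned values.

```python
import torch
from torch import nn

torch.manual_seed(0)  # reproducible initialization and sampling

model = nn.Sequential(nn.Linear(2, 1))
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

batch_size = 64  # illustrative choice
for step in range(1000):
    x = torch.rand(batch_size, 2)        # rows are (a, b) pairs, uniform in [0, 1)
    target = 3 * x[:, 0] + 4 * x[:, 1]   # f(a, b) for every row, shape (N,)
    output = model(x).squeeze(1)         # (N, 1) -> (N,)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(f"Step {step:4d}  loss: {loss.item():.6f}")

print(model[0].weight.data, model[0].bias.data)
```

Each step averages the gradient over 64 examples, so individual updates are less noisy than in the single-sample loop. Because the target has no noise and the layer can represent f exactly, the per-example gradients all vanish at the optimum.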


In [None]:
plt.figure(figsize=(6, 4))
plt.plot(loss_history)
plt.yscale('log')  # losses shrink by orders of magnitude; a log scale keeps late training visible
plt.title('Training Loss')
plt.xlabel('Step')
plt.ylabel('MSE Loss')
plt.grid(True)
plt.show()
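
Because each recorded loss comes from a single random sample, the curve is noisy. A minimal sketch of smoothing it with a moving average follows. `moving_average` and any window size you pass it are illustrative choices and are not part of the original notebook.

```python
from itertools import accumulate

def moving_average(values, window):
    """Mean of every run of `window` consecutive values (len(values) - window + 1 results)."""
    if not 1 <= window <= len(values):
        raise ValueError("window must be between 1 and len(values)")
    prefix = [0.0, *accumulate(values)]  # prefix[i] = sum of the first i values
    return [(prefix[i + window] - prefix[i]) / window
            for i in range(len(values) - window + 1)]

print(moving_average([1.0, 2.0, 3.0, 4.0], 2))  # [1.5, 2.5, 3.5]
```

In the notebook this could be used as `plt.plot(moving_average(loss_history, 200))`, optionally with `plt.yscale('log')`. A larger window gives a smoother curve but reacts more slowly to changes.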


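## Test the model
Compare the trained model's predictions with the true function values. All training pairs lie in [0, 1) × [0, 1), so a point such as (1, -1), where f = 3 - 4 = -1, is outside the training range. Because the model is linear, it should still be accurate there once its weights are close to (3, 4) and its bias is close to 0.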


In [None]:
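with torch.no_grad():  # inference only, so no gradients are needed
    for a, b in [(1.0, -1.0), (0.5, 0.25), (2.0, 3.0)]:
        pred = model(torch.tensor([a, b])).item()
        expected = mystery(a, b).item()
        print(f"f({a}, {b}): model {pred:.4f}, expected {expected:.4f}")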
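
Because the layer should end up with weights near (3, 4) and a bias near 0, reading its parameters directly is another check. This is a minimal sketch. `parameter_error` is a hypothetical helper that returns the largest absolute deviation from those target values. Here it runs on a layer whose parameters are set by hand, so the snippet works on its own. In the notebook you would call it as `parameter_error(model[0])`.

```python
import torch
from torch import nn

def parameter_error(linear: nn.Linear) -> float:
    """Largest absolute gap between the layer's parameters and f's coefficients (3, 4, bias 0)."""
    target_w = torch.tensor([[3.0, 4.0]])
    target_b = torch.tensor([0.0])
    with torch.no_grad():
        w_err = (linear.weight - target_w).abs().max()
        b_err = (linear.bias - target_b).abs().max()
    return max(w_err.item(), b_err.item())

# Hand-built stand-in for the trained model[0]: exact weights, slightly-off bias
layer = nn.Linear(2, 1)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[3.0, 4.0]]))
    layer.bias.fill_(0.01)

print(layer.weight.data, layer.bias.data)
print(parameter_error(layer))
```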
