# Optimizing Model Parameters
Now that we have a model and data, let's use this notebook to *train*, *validate* and *test* our model by optimizing its parameters on our data.

Training is an iterative process. In each iteration, the model makes a guess about the output, calculates the error of that guess (the *loss*), computes the gradients of the loss with respect to its parameters, and updates the parameters using gradient descent.
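As a minimal sketch of that guess, loss, update cycle, the pure-Python loop below fits a one-parameter model `y_hat = w * x` to toy data generated from `y = 3x`, using mean squared error and a hand-derived gradient. The data, model and step size `lr` are illustrative assumptions, not part of the FashionMNIST setup; in the real loop, PyTorch's autograd and `torch.optim.SGD` take over the gradient and update steps.

```python
# Toy data generated from y = 3x (illustrative assumption).
xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 6.0, 9.0, 12.0]

w = 0.0      # initial guess for the single parameter of y_hat = w * x
lr = 0.01    # step size for gradient descent

for step in range(200):
    # 1. Guess: predict outputs with the current parameter
    preds = [w * x for x in xs]
    # 2. Error: mean squared error between predictions and targets
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # 3. Gradient of the loss with respect to w, then a gradient-descent step
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
    w -= lr * grad

print(f"w = {w:.4f}, loss = {loss:.2e}")  # w converges to 3.0000
```

Each step here shrinks the distance `w - 3` by a constant factor, which is why a fixed number of iterations is enough for this toy problem.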

## Load Prerequisite Code

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()


## Hyperparameters
*Hyperparameters* are adjustable parameters that let you control the model optimization process. They are different from the model parameters, which are learned during training.
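To make the distinction concrete, the sketch below counts by hand the learned parameters of the `NeuralNetwork` defined above: each `nn.Linear(in_features, out_features)` holds an `out_features x in_features` weight matrix plus an `out_features` bias, while `nn.Flatten` and `nn.ReLU` have no parameters. Evaluating `sum(p.numel() for p in model.parameters())` on that model should report the same total. None of the hyperparameters listed below are part of this count; they are chosen by hand, not learned.

```python
# (in_features, out_features) of the three nn.Linear layers in NeuralNetwork
layer_shapes = [(28 * 28, 512), (512, 512), (512, 10)]

def linear_param_count(in_features, out_features):
    # weight matrix entries + bias entries
    return in_features * out_features + out_features

total = sum(linear_param_count(i, o) for i, o in layer_shapes)
print(total)  # 669706
```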

In this example, we will define the following hyperparameters:
- `Number of Epochs`: The number of times to iterate over the training dataset.
- `Batch Size`: The number of data samples propagated through the network before the parameters are updated.
- `Learning Rate`: How much to update the model parameters at each batch based on the loss gradient. A small learning rate makes the model learn slowly, while a large learning rate speeds learning up but can overshoot the optimal parameters, causing the loss to oscillate or even diverge.
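The learning-rate trade-off is easy to see on a toy problem. This sketch, an illustrative assumption rather than the FashionMNIST loss, runs plain gradient descent on `f(w) = w**2`, whose minimum is at `w = 0`. Each step multiplies `w` by `1 - 2*lr`, so a tiny rate barely moves, a moderate rate converges quickly, and a rate above 1 overshoots the minimum by more on every step and diverges.

```python
def gradient_descent(lr, steps=20, w=1.0):
    """Minimize f(w) = w**2 (minimum at w = 0) with plain gradient descent."""
    for _ in range(steps):
        grad = 2 * w         # derivative of w**2
        w = w - lr * grad    # the same update rule SGD applies to each parameter
    return w

for lr in (0.01, 0.4, 1.1):
    print(f"lr={lr}: w after 20 steps = {gradient_descent(lr):.4g}")
# lr=0.01 -> about 0.668 (slow), lr=0.4 -> about 1e-14 (converged), lr=1.1 -> about 38 (diverged)
```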

In [2]:
learning_rate = 1e-3
batch_size = 64
epochs = 5
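Batch size also determines how many parameter updates happen per epoch. The arithmetic below assumes the standard FashionMNIST split of 60,000 training and 10,000 test images (the training count matches the `/60000` in the logged output) and the `DataLoader` default `drop_last=False`, which keeps a final partial batch. It also explains the last progress line printed in each epoch: batch 900 reports `900 * 64 + 64 = 57664` samples.

```python
import math

num_train_samples = 60_000   # FashionMNIST training images
num_test_samples = 10_000    # FashionMNIST test images
batch_size = 64
epochs = 5

# With drop_last=False the final partial batch is kept, so round up.
train_batches = math.ceil(num_train_samples / batch_size)
last_batch = num_train_samples - (train_batches - 1) * batch_size
test_batches = math.ceil(num_test_samples / batch_size)

print(train_batches, last_batch)   # 938 32
print(test_batches)                # 157
print(train_batches * epochs)      # 4690 parameter updates in total
```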

## Optimization Loop
Once our hyperparameters are set, we can train our model with an optimization loop. Each iteration of the loop is called an *epoch*. Each epoch consists of two main parts:
1. **Training**: Iterate over the training dataset and update the model parameters to reduce the loss.
2. **Validation / Testing**: Iterate over the validation/test dataset to check if model performance is improving.
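In the validation/testing part, one common way to check performance is accuracy: the predicted class is the index of the largest logit (`pred.argmax(1)` in `test_loop`), and accuracy is the fraction of predictions that match the labels. The pure-Python sketch below uses made-up logits and labels for four samples and three classes, purely for illustration.

```python
# Hypothetical logits for 4 samples and 3 classes (rows = samples).
logits = [
    [2.0, 0.5, -1.0],
    [0.1, 0.3, 0.2],
    [-0.5, 1.5, 0.0],
    [0.0, 0.0, 3.0],
]
labels = [0, 2, 1, 2]

def argmax(row):
    # index of the largest value, like tensor.argmax along the class dimension
    return max(range(len(row)), key=lambda i: row[i])

preds = [argmax(row) for row in logits]
correct = sum(p == y for p, y in zip(preds, labels))  # count matching predictions
accuracy = correct / len(labels)
print(preds, accuracy)   # [0, 1, 1, 2] 0.75
```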

In [3]:
def train_loop(dataloader, model, loss_fn, optimizer):
    '''
    Parameters
    ----------
    dataloader : DataLoader
        The data loader for training data.
    model : nn.Module
        The neural network model to train.
    loss_fn : nn.Module
        The loss function to use.
    optimizer : torch.optim.Optimizer
        The optimizer to use for training.
    '''
    size = len(dataloader.dataset)          # total number of training samples

    # Set the model to training mode; this affects layers such as dropout and batch norm.
    # Unnecessary for this simple model, but good practice.
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()   # clear previous gradients
        loss.backward()         # compute gradients
        optimizer.step()        # update weights

        if batch % 100 == 0:
            # Print the loss every 100 batches
            loss_value = loss.item()
            current = batch * dataloader.batch_size + len(X)   # samples processed so far
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
    '''
    Parameters
    ----------
    dataloader : DataLoader
        The data loader for test data.
    model : nn.Module
        The neural network model to test.
    loss_fn : nn.Module
        The loss function to use.
    '''
    size = len(dataloader.dataset)          # total number of test samples
    num_batches = len(dataloader)          # total number of batches

    # Set the model to evaluation mode; dropout is turned off and batch norm uses its running statistics.
    # Unnecessary for this simple model, but good practice.
    model.eval()
    test_loss, correct = 0, 0

    # Evaluating under torch.no_grad() skips gradient tracking, which reduces
    # memory usage and speeds up computation; no gradients are needed here.
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()   # accumulate loss
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # count correct predictions

    test_loss /= num_batches                # average loss per batch
    correct /= size                         # fraction of samples classified correctly

    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
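`train_loop` calls `optimizer.zero_grad()` before `loss.backward()` because PyTorch accumulates gradients: each `backward()` call adds to a parameter's `.grad` rather than overwriting it. The small sketch below, using a single hypothetical scalar parameter `w`, shows the accumulation and the reset. Setting `.grad` to `None` mirrors what `optimizer.zero_grad()` does by default in recent PyTorch releases (`set_to_none=True`); older releases filled the gradient with zeros instead.

```python
import torch

# A single hypothetical scalar parameter
w = torch.tensor(2.0, requires_grad=True)
grads = []

(w ** 2).backward()        # d(w**2)/dw = 2w = 4
grads.append(w.grad.item())

(w ** 2).backward()        # not reset: the new 4 is added to the stored 4
grads.append(w.grad.item())

w.grad = None              # reset, as optimizer.zero_grad() does by default
(w ** 2).backward()
grads.append(w.grad.item())

print(grads)               # [4.0, 8.0, 4.0]
```

Without the reset, every batch would update the weights with the sum of all gradients seen so far in the epoch instead of the current batch's gradient.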


### Initialize the Loss Function and Optimizer
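`nn.CrossEntropyLoss` takes raw logits, applies log-softmax, and returns the negative log-probability of the target class, averaged over the batch under the default `reduction='mean'`. A minimal pure-Python version for a single sample is sketched below. One consequence: an untrained 10-class model with roughly equal logits has a loss near `ln(10)`, about 2.3026, which is consistent with the first logged loss of about 2.30 in Epoch 1.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)   # subtract the max before exp for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

uniform = [0.0] * 10                 # equal logits, as from an untrained model
confident = [0.0] * 9 + [5.0]        # strongly favors class 9

print(cross_entropy(uniform, target=3))     # about 2.302585, i.e. ln(10)
print(cross_entropy(confident, target=9))   # small loss: confident and correct
print(cross_entropy(confident, target=0))   # large loss: confident and wrong
```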

In [4]:
loss_fn = nn.CrossEntropyLoss()  # loss function for multi-class classification
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  # SGD optimizer

epochs = 10  # overrides epochs = 5 from the hyperparameter cell
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.304772  [   64/60000]
loss: 2.287111  [ 6464/60000]
loss: 2.266696  [12864/60000]
loss: 2.268206  [19264/60000]
loss: 2.252316  [25664/60000]
loss: 2.220776  [32064/60000]
loss: 2.235157  [38464/60000]
loss: 2.203969  [44864/60000]
loss: 2.193308  [51264/60000]
loss: 2.159923  [57664/60000]
Test Error: 
 Accuracy: 38.1%, Avg loss: 2.157460 

Epoch 2
-------------------------------
loss: 2.163246  [   64/60000]
loss: 2.151571  [ 6464/60000]
loss: 2.092443  [12864/60000]
loss: 2.116298  [19264/60000]
loss: 2.065917  [25664/60000]
loss: 1.997988  [32064/60000]
loss: 2.029382  [38464/60000]
loss: 1.955623  [44864/60000]
loss: 1.952384  [51264/60000]
loss: 1.879130  [57664/60000]
Test Error: 
 Accuracy: 57.8%, Avg loss: 1.884183 

Epoch 3
-------------------------------
loss: 1.909723  [   64/60000]
loss: 1.883273  [ 6464/60000]
loss: 1.765719  [12864/60000]
loss: 1.814987  [19264/60000]
loss: 1.701680  [25664/60000]
loss: 1.646752  [32064/600