## Training a Neural Network

In this exercise, you will implement every part of a neural network training script: data loading, model training, and model testing. Complete the following steps:

1. Finish the `train()` and `test()` functions. These should contain loops to train and test your model respectively.
2. Finish the `create_model()` function. You can use the same model from a previous exercise.
3. Create data transforms and data loaders to download, read, and preprocess your data.
4. Add a cost function and an optimizer. You can use the same cost function and optimizer from the previous exercise.
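
Step 4 leaves the choice of cost function open. One common pairing, and the one the solution code in this notebook uses (`nn.LogSoftmax` followed by `nn.NLLLoss`), is a log-softmax output with a negative log-likelihood loss. The sketch below shows that arithmetic for a single example in plain Python. It is an illustration only: the function names and the example scores are hypothetical, and it is not how PyTorch implements these layers.

```python
import math

def log_softmax(logits):
    """Turn one example's raw class scores into log-probabilities."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll_loss(log_probs, target):
    """Negative log-likelihood: minus the log-probability of the correct class."""
    return -log_probs[target]

logits = [2.0, 1.0, 0.1]          # hypothetical scores for three classes
log_probs = log_softmax(logits)
loss_if_class_0 = nll_loss(log_probs, target=0)  # highest score, so smallest loss
loss_if_class_2 = nll_loss(log_probs, target=2)  # lowest score, so largest loss
```

The loss is small when the model assigns a high probability to the correct class and grows as that probability shrinks, which is what gradient descent then tries to reduce.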

**Note**: Downloading the datasets may take 5 to 10 minutes.

In case you get stuck, you can look at the solution notebook.
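
As a hint for step 3: once `transforms.ToTensor()` has scaled pixel values to the range [0, 1], `transforms.Normalize(mean, std)` computes `(x - mean) / std` for each channel. The sketch below reproduces that arithmetic, plus the flattening of a 28x28 image into 784 inputs, in plain Python. The constants 0.1307 and 0.3081 are the ones used in the solution code; the function names `normalize` and `flatten` are illustrative and not part of torchvision.

```python
def normalize(pixel, mean=0.1307, std=0.3081):
    """Shift and scale one pixel value that is already in the range [0, 1]."""
    return (pixel - mean) / std

def flatten(image):
    """Turn a 2-D image (a list of rows) into a flat list of pixel values."""
    return [pixel for row in image for pixel in row]

# A hypothetical all-black 28x28 image: every pixel is 0.0 after ToTensor()
image = [[0.0] * 28 for _ in range(28)]
inputs = [normalize(p) for p in flatten(image)]

darkest = normalize(0.0)    # negative, because 0.0 is below the dataset mean
brightest = normalize(1.0)  # positive, because 1.0 is above the dataset mean
```

After this step the inputs are centered near zero, and the flattened length matches the `input_size = 784` expected by the first linear layer.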

In [1]:
import numpy as np
import torch
from torchvision import datasets, transforms
from torch import nn, optim

def train(model, train_loader, cost, optimizer, epoch):
    model.train()
    for e in range(epoch):
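        running_loss = 0.0
        correct = 0
        for data, target in train_loader:
            # Flatten each image in the batch into a 784-element vector
            data = data.view(data.shape[0], -1)
            optimizer.zero_grad()
            pred = model(data)
            loss = cost(pred, target)
            # .item() returns a Python float, so the autograd graph is not kept alive
            running_loss += loss.item()
            loss.backward()
            optimizer.step()
            # The predicted class is the index of the largest log-probability
            pred = pred.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()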
        print(f"Epoch {e}: Loss {running_loss/len(train_loader.dataset)}, Accuracy {100*(correct/len(train_loader.dataset))}%")

def test(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.shape[0], -1)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    print(f'Test set: Accuracy: {correct}/{len(test_loader.dataset)} = {100*(correct/len(test_loader.dataset))}%)')

def create_model():
    # Build a feed-forward network
    input_size = 784
    output_size = 10
    model = nn.Sequential(nn.Linear(input_size, 128),
                          nn.ReLU(),
                          nn.Linear(128, 64),
                          nn.ReLU(),
                          nn.Linear(64, output_size),
                          nn.LogSoftmax(dim=1))

    return model

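# Training transform: random augmentation, tensor conversion, then normalization.
# Caution: a mirrored digit is usually not a valid digit, so the horizontal flip
# may hurt rather than help on MNIST.
training_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Testing transform: no augmentation, only tensor conversion and normalization
testing_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])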

# Set Hyperparameters
batch_size = 64
epoch = 10

# Download and load the training and test data
trainset = datasets.MNIST('data/', download=True, train=True, transform=training_transform)
testset = datasets.MNIST('data/', download=True, train=False, transform=testing_transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
# Shuffling is not needed for evaluation; accuracy does not depend on batch order
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)


model = create_model()

cost = nn.NLLLoss()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


train(model, train_loader, cost, optimizer, epoch)
test(model, test_loader)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 100029861.07it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 71677925.34it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 28084039.83it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 11455519.40it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Epoch 0: Loss 0.016076793894171715, Accuracy 70.215%
Epoch 1: Loss 0.00779104745015502, Accuracy 84.38333333333333%
Epoch 2: Loss 0.0063056848011910915, Accuracy 87.64999999999999%
Epoch 3: Loss 0.0051855361089110374, Accuracy 90.075%
Epoch 4: Loss 0.004431516397744417, Accuracy 91.72166666666666%
Epoch 5: Loss 0.0039262305945158005, Accuracy 92.58833333333332%
Epoch 6: Loss 0.0035564308054745197, Accuracy 93.24166666666667%
Epoch 7: Loss 0.0032825872767716646, Accuracy 93.71166666666667%
Epoch 8: Loss 0.003045401070266962, Accuracy 94.17166666666667%
Epoch 9: Loss 0.0028317668475210667, Accuracy 94.64833333333334%
Test set: Accuracy: 9446/10000 = 94.46%)
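
The loss values printed above look small for a 10-class problem in its first epoch. A likely explanation, assuming `nn.NLLLoss` is left at its default mean reduction, is that `train()` adds up per-batch mean losses and then divides by the number of samples rather than the number of batches, so the printed value is roughly the average loss divided by the batch size. The sketch below compares the two ways of averaging in plain Python; the function names and the constant loss of 1.0 per batch are hypothetical.

```python
import math

def summed_batch_means_over_samples(batch_means, num_samples):
    """What train() prints: the sum of per-batch mean losses over the dataset size."""
    return sum(batch_means) / num_samples

def per_sample_average(batch_means, batch_sizes):
    """Sample-weighted average: each batch mean counts in proportion to its size."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

# 60000 training images in batches of 64: 937 full batches plus one batch of 32
num_samples, batch_size = 60000, 64
num_batches = math.ceil(num_samples / batch_size)
batch_sizes = [batch_size] * (num_samples // batch_size) + [num_samples % batch_size]

# Hypothetical case: every batch has a mean loss of exactly 1.0
batch_means = [1.0] * num_batches
printed_style = summed_batch_means_over_samples(batch_means, num_samples)
weighted = per_sample_average(batch_means, batch_sizes)
```

Here `weighted` recovers the true per-sample loss of 1.0, while `printed_style` is about 64 times smaller. The trend across epochs is unaffected, so the printed numbers are still useful for checking that training is converging.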
