# Odd vs Even MNIST classifier
TASK: Classify digit images into two classes: even (0, 2, 4, 6, 8) and odd (1, 3, 5, 7, 9)
- This file is on ILIAS: download it and open it in Jupyter Notebook
- Make sure you understand the code: the **data loading**, **preprocessing**, **model** and **loss function** definitions are already provided
- **Implement functions _train()_ and _test()_**
- Train the classifier using these functions

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [9]:
# Label transformer function: maps digit labels 0, ..., 9 to parity labels 0 / 1
# Label 0 - even, 1 - odd
def label2parity(label):
    return 0 if label % 2 == 0 else 1
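A quick sanity check of the mapping in plain Python (no PyTorch needed). For non-negative integer labels, `label % 2` produces exactly the same values, so the function is equivalent to that expression.

```python
def label2parity(label):
    # Same mapping as above: 0 for even digits, 1 for odd digits
    return 0 if label % 2 == 0 else 1

parities = [label2parity(digit) for digit in range(10)]
print(parities)  # [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

# For non-negative integer labels the mapping is identical to label % 2
assert all(label2parity(digit) == digit % 2 for digit in range(10))
```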

In [3]:
transforms = Compose([ToTensor(),
                      Normalize(mean=(0.5,), std=(0.5,))]) # ToTensor scales pixels to [0, 1], Normalize then maps them to [-1, 1]

# Load train dataset
train_dataset = MNIST(root='data', train=True, download=True,
                      transform=transforms,
                      target_transform=label2parity # transforms labels 0,..,9 to 0 - even, 1 - odd
                     )

# Load test dataset
test_dataset = MNIST(root='data', train=False, download=True,
                      transform=transforms,
                      target_transform=label2parity
                     )

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
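`ToTensor()` scales 8-bit pixel values from 0..255 into [0, 1], and `Normalize(mean=(0.5,), std=(0.5,))` then computes (x - mean) / std, which maps [0, 1] onto [-1, 1]. The plain-Python sketch below reproduces that arithmetic for single pixel values; the two helper names are hypothetical and only mimic what the transforms do to one value.

```python
def to_unit_range(pixel):
    # Mimics ToTensor() on one uint8 pixel: 0..255 -> 0.0..1.0
    return pixel / 255.0

def normalize(x, mean=0.5, std=0.5):
    # Mimics Normalize(mean=(0.5,), std=(0.5,)) on one value: (x - mean) / std
    return (x - mean) / std

for pixel in (0, 128, 255):
    print(pixel, normalize(to_unit_range(pixel)))
```

Black background pixels (0) end up at -1.0 and full-intensity strokes (255) at 1.0, so the network sees inputs centred around zero.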

In [4]:
class ParityModel(nn.Module):
    def __init__(self):
        super(ParityModel, self).__init__()
        self.main = nn.Sequential(nn.Linear(28*28, 128),
                                 nn.ReLU(),
                                 nn.Linear(128, 2))
    
    def forward(self, x):
        out = x.view(x.size(0), 28*28)
        out = self.main(out)
        return out
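To check the shapes, a dummy batch can be passed through the model. Batches from the loader have shape `(batch, 1, 28, 28)`; `forward()` flattens each image to 784 values and the network returns two logits per image, one per class. The snippet repeats the class definition so it runs on its own, and the random tensor is only a stand-in for real images.

```python
import torch
import torch.nn as nn

class ParityModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(nn.Linear(28*28, 128),
                                  nn.ReLU(),
                                  nn.Linear(128, 2))

    def forward(self, x):
        out = x.view(x.size(0), 28*28)  # (batch, 1, 28, 28) -> (batch, 784)
        return self.main(out)           # (batch, 2) logits

model = ParityModel()
dummy_images = torch.randn(4, 1, 28, 28)  # stand-in for a batch of 4 images
logits = model(dummy_images)
print(logits.shape)  # torch.Size([4, 2])

# Layer 1: 784*128 weights + 128 biases; layer 2: 128*2 weights + 2 biases
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 100738
```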

In [5]:
model = ParityModel()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
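`nn.CrossEntropyLoss` takes raw logits (the model has no softmax layer) and integer class indices. For one sample it computes -log(softmax(logits)[target]), and with the default reduction it averages over the batch. The plain-Python sketch below reproduces that per-sample formula for intuition only; it is not PyTorch's implementation. One consequence: an untrained model whose two logits are roughly equal gives a loss close to ln 2 ≈ 0.693, which is in line with the first training loss printed further down (0.6824).

```python
import math

def cross_entropy(logits, target):
    # -log softmax(logits)[target] = logsumexp(logits) - logits[target]
    m = max(logits)  # subtract the max before exp() for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def batch_cross_entropy(batch_logits, targets):
    # Default reduction='mean': average of the per-sample losses
    losses = [cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

print(cross_entropy([0.0, 0.0], 0))   # ln 2, about 0.6931
print(cross_entropy([3.0, -1.0], 0))  # confident and correct: small loss
print(cross_entropy([3.0, -1.0], 1))  # confident and wrong: large loss
```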

### Implement train() function
Implement the *train()* function, which trains the model for one epoch (a single pass over the entire training dataset)
- iterate over train dataset
    - get predictions with **model**
    - compute loss with **loss_fn**, predictions and labels
    - update parameters with **optimizer**
    
Previous tutorials and the PyTorch documentation may help:
- https://pytorch.org/docs/stable/optim.html#taking-an-optimization-step
- https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#train-the-network
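The order of calls in a training step matters because PyTorch accumulates gradients: each `loss.backward()` adds into `.grad` instead of overwriting it, which is why `optimizer.zero_grad()` is called once per iteration before `backward()`. The minimal sketch below uses a single scalar parameter instead of the MNIST model so the numbers are easy to follow.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Two backward() calls without zero_grad(): the gradients are summed
for _ in range(2):
    loss = (3.0 * w).sum()  # d(loss)/dw = 3
    loss.backward()
print(w.grad)  # tensor([6.])

# A correct step: clear gradients, forward, backward, update
optimizer.zero_grad()
loss = (3.0 * w).sum()
loss.backward()   # w.grad is now 3
optimizer.step()  # plain SGD: w <- w - lr * grad = 1.0 - 0.1 * 3, about 0.7
print(w.item())
```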

In [6]:
def train():
    model.train()
    for iteration, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        output = model(images)
        optimizer.zero_grad()
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        # The code below prints the loss every 100 iterations. Make sure the loss computed by loss_fn is stored in the variable "loss"
        if iteration % 100 == 0:
            print('Training iteration {}: loss {:.4f}'.format(iteration, loss.item()))
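`train()` reads the global `model`, `train_loader`, `loss_fn` and `optimizer`. When trying the improvements suggested at the end of this notebook, a parameterized variant may be handier. The sketch below (the name `train_one_epoch` is hypothetical) performs the same steps with explicit arguments and returns the mean training loss over the epoch. It is exercised here on a small random dataset and a single-layer stand-in model, not on MNIST.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, loss_fn, optimizer, device):
    # Same steps as train(), but with explicit arguments and a returned value
    model.train()
    total_loss, n_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        # loss is a batch mean, so weight it by the batch size
        total_loss += loss.item() * labels.size(0)
        n_samples += labels.size(0)
    return total_loss / n_samples

# Stand-in data: 256 random "images" with random 0/1 labels
torch.manual_seed(0)
data = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 2, (256,)))
loader = DataLoader(data, batch_size=64, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(28*28, 2))  # tiny stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
device = torch.device('cpu')

mean_loss = train_one_epoch(model, loader, nn.CrossEntropyLoss(), optimizer, device)
print(mean_loss)
```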

### Implement test() function
- iterate over test dataset
    - get predictions with **model**
    - compute loss with **loss_fn**, predictions and labels
    - compute accuracy for the test dataset  
    - assign the average test-set loss and the accuracy to the *average_loss* and *accuracy* variables
    
Previous tutorials and the PyTorch documentation may help:
- https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#test-the-network-on-the-test-data
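The dataset sizes affect how the averages come out. With 60,000 training and 10,000 test images and `batch_size=64`, the loaders yield 938 and 157 batches, and the last batch of each is smaller (32 and 16 images). That is why the training log below stops at iteration 900. It also means that dividing the summed batch losses by the number of batches (as `test()` below does with `len(test_loader)`) gives a mean of batch means, which can differ slightly from the true per-image average. The arithmetic below illustrates this with made-up per-batch losses.

```python
import math

n_train, n_test, batch_size = 60000, 10000, 64
train_batches = math.ceil(n_train / batch_size)              # 938 (iterations 0..937)
test_batches = math.ceil(n_test / batch_size)                # 157
last_test_batch = n_test - (test_batches - 1) * batch_size   # 16 images

# Made-up per-batch mean losses: 156 full batches at 0.30, a last batch at 0.90
batch_losses = [0.30] * (test_batches - 1) + [0.90]
batch_sizes = [batch_size] * (test_batches - 1) + [last_test_batch]

mean_of_batch_means = sum(batch_losses) / len(batch_losses)
per_image_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)
print(round(mean_of_batch_means, 5), round(per_image_mean, 5))
```

With real test losses the gap is usually small, since only one batch out of 157 is short.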

In [7]:
def test():
    model.eval()
    test_loss = 0
    n_correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            loss = loss_fn(output, labels)
            test_loss += loss.item()
            n_correct += torch.sum(output.argmax(1) == labels).item()

    average_loss = test_loss / len(test_loader)
    accuracy = 100.0 * n_correct / len(test_loader.dataset)
    print('Test average loss: {:.4f}, accuracy: {:.3f}'.format(average_loss, accuracy))
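An alternative sketch that computes a true per-image average: it sums per-sample losses with `reduction='sum'` and divides by the number of images, so the smaller final batch is weighted correctly. The name `evaluate` is hypothetical, and it calls `torch.nn.functional.cross_entropy` directly instead of taking `loss_fn`, because the reduction has to be `'sum'`. The toy check uses an identity "model" on 2-D inputs whose labels are the index of the larger value, so accuracy is 100 % by construction.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, device):
    model.eval()
    total_loss, n_correct, n_samples = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            # reduction='sum' adds up per-sample losses instead of averaging per batch
            total_loss += F.cross_entropy(output, labels, reduction='sum').item()
            n_correct += (output.argmax(1) == labels).sum().item()
            n_samples += labels.size(0)
    return total_loss / n_samples, 100.0 * n_correct / n_samples

# Toy check: 100 samples split into batches of 64 and 36
inputs = torch.randn(100, 2)
labels = inputs.argmax(1)
loader = DataLoader(TensorDataset(inputs, labels), batch_size=64)
model = nn.Identity()  # logits are the inputs themselves
avg_loss, accuracy = evaluate(model, loader, torch.device('cpu'))
print(avg_loss, accuracy)
```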

### Once you've completed the train() and test() functions, run the cell below to train and test for 10 epochs
You should be able to reach a test accuracy above 95 %.

In [8]:
n_epochs = 10
for epoch in range(n_epochs):
    print('Epoch {}'.format(epoch+1))
    train()
    test()

Epoch 1
Training iteration 0: loss 0.6824
Training iteration 100: loss 0.4411
Training iteration 200: loss 0.4094
Training iteration 300: loss 0.3652
Training iteration 400: loss 0.3947
Training iteration 500: loss 0.2979
Training iteration 600: loss 0.2927
Training iteration 700: loss 0.3907
Training iteration 800: loss 0.3239
Training iteration 900: loss 0.3174
Test average loss: 0.2864, accuracy: 88.580
Epoch 2
Training iteration 0: loss 0.2550
Training iteration 100: loss 0.2862
Training iteration 200: loss 0.2440
Training iteration 300: loss 0.1732
Training iteration 400: loss 0.2514
Training iteration 500: loss 0.2322
Training iteration 600: loss 0.3538
Training iteration 700: loss 0.1511
Training iteration 800: loss 0.3168
Training iteration 900: loss 0.2264
Test average loss: 0.2369, accuracy: 90.470
Epoch 3
Training iteration 0: loss 0.2435
Training iteration 100: loss 0.2675
Training iteration 200: loss 0.1746
Training iteration 300: loss 0.3510
Training iteration 400: loss 0

#### If you finished early, try to improve the model to increase test accuracy
Things you could try:
- train for more epochs
- play with learning rate value
- increase model size
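One possible starting point for the last two suggestions, sketched under assumptions: a wider model with a second hidden layer (`WiderParityModel` is a hypothetical name) and SGD with momentum plus a step learning-rate schedule (`torch.optim.lr_scheduler.StepLR`). The hidden sizes, `lr=0.05`, `momentum=0.9` and the schedule values are untested guesses to tune, not reported results. These definitions would replace `model` and `optimizer` above; calling `scheduler.step()` once per epoch, after `train()`, applies the decay.

```python
import torch
import torch.nn as nn

class WiderParityModel(nn.Module):
    # Hypothetical larger variant of ParityModel: two hidden layers instead of one
    def __init__(self, hidden=512):
        super().__init__()
        self.main = nn.Sequential(nn.Flatten(),  # (batch, 1, 28, 28) -> (batch, 784)
                                  nn.Linear(28*28, hidden), nn.ReLU(),
                                  nn.Linear(hidden, hidden // 2), nn.ReLU(),
                                  nn.Linear(hidden // 2, 2))

    def forward(self, x):
        return self.main(x)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = WiderParityModel().to(device)
# Momentum and a larger starting learning rate are guesses to tune
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
# Multiply the learning rate by 0.1 every 5 epochs (call scheduler.step() once per epoch)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

logits = model(torch.randn(3, 1, 28, 28, device=device))
n_params = sum(p.numel() for p in model.parameters())
print(logits.shape, n_params)
```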