# MNIST with CNN

In [None]:
import copy
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from mlcourse.config import Config
from mlcourse.utils.data import show_dataset
from pytorch_model_summary import summary
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    classification_report,
    confusion_matrix,
)
from skorch import NeuralNetClassifier
from skorch.callbacks import Checkpoint, EarlyStopping, LRScheduler
from skorch.helper import predefined_split
from torch.utils.data import Subset
from torchvision.transforms.functional import InterpolationMode

In [None]:
# Course-wide settings object. Below, `config.model_dir_path` is used both as
# the checkpoint directory and as the location the pickled classifier is
# loaded from.
config = Config()

In [None]:
input_size = 28 * 28
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.005
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Preprocessing for the non-augmented datasets: force a 28x28 size, then
# convert each image to a tensor.
mnist_transforms = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ]
)

In [None]:
# MNIST train and test splits, both stored under ./data (download=True lets
# torchvision fetch the files) and preprocessed with `mnist_transforms`.
mnist_common_args = dict(root="./data", transform=mnist_transforms, download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_common_args)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_common_args)

In [None]:
# Convolutional part of `conv_model` only, without its dense head
# (Linear(320, 60) -> ReLU -> Linear(60, 10)). The summaries show the size of
# the flattened feature vector, 20 channels x 4 x 4 = 320, which is the input
# size the first Linear layer of the full model must accept.
partial_model = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(10, 20, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
)
dummy_image = torch.zeros((1, 1, 28, 28))  # a batch of one 1x28x28 image
print(summary(partial_model, dummy_image, show_input=True))
print(summary(partial_model, dummy_image))

In [None]:
# CNN for 1x28x28 inputs: two (5x5 conv -> ReLU -> 2x2 max-pool) stages,
# then a two-layer dense head producing 10 class scores.
# Spatial size: 28 -> 24 -> 12 -> 8 -> 4, so Flatten yields 20 * 4 * 4 = 320.
# There is no final Softmax, presumably because nn.CrossEntropyLoss (used by
# the classifiers below) works on raw logits.
feature_layers = [
    nn.Conv2d(1, 10, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(10, 20, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(2),
]
head_layers = [
    nn.Flatten(),
    nn.Linear(320, 60),
    nn.ReLU(),
    nn.Linear(60, 10),
]
conv_model = nn.Sequential(*feature_layers, *head_layers)

In [None]:
# Print the layer summary of the full model, first without and then with the
# input-shape column.
example_batch = torch.zeros((1, 1, 28, 28))
for show_input in (False, True):
    print(summary(conv_model, example_batch, show_input=show_input))

In [None]:
# First CNN classifier: skorch's default optimizer on the non-augmented
# training data, scored on the test split after every epoch.
# The batch size comes from the shared `batch_size` hyperparameter.
# NOTE: `max_epochs` and `lr` are set explicitly here and do not use the
# `num_epochs` / `learning_rate` variables.
cnn_classifier = NeuralNetClassifier(
    conv_model,
    criterion=nn.CrossEntropyLoss,
    batch_size=batch_size,
    max_epochs=2,
    lr=0.1,
    iterator_train__shuffle=True,
    train_split=predefined_split(test_dataset),
    device=device,
)

In [None]:
# Train on the dataset directly; labels come from the dataset's (image, label)
# pairs, so no separate y is passed.
cnn_classifier.fit(train_dataset, y=None)

In [None]:
# Continue training the already fitted classifier on the same data, starting
# from its current weights.
cnn_classifier.partial_fit(train_dataset, y=None)

In [None]:
# Predicted class label for every test image.
y_pred_cnn = cnn_classifier.predict(X=test_dataset)

In [None]:
# True test labels as a NumPy array. They are read from the dataset's
# `targets` instead of iterating `test_dataset`. Iterating would needlessly
# run Resize + ToTensor on every image just to discard it.
# np.array copies, so y_test does not share memory with the dataset.
y_test = np.array(test_dataset.targets)

In [None]:
# Per-class precision / recall / F1 of the first classifier on the test set.
cnn_report = classification_report(y_true=y_test, y_pred=y_pred_cnn)
print(cnn_report)

In [None]:
# Raw confusion matrix (rows: true labels, columns: predicted labels).
print(confusion_matrix(y_true=y_test, y_pred=y_pred_cnn))

In [None]:
# Plot the confusion matrix on a single 10x8-inch axes.
fig, ax = plt.subplots(figsize=(10, 8))
ConfusionMatrixDisplay.from_predictions(y_true=y_test, y_pred=y_pred_cnn, ax=ax)


## Finding Misclassified Images

In [None]:
def find_misclassified_images(y_pred=y_pred_cnn, y_true=None):
    """Return the indices of samples whose predicted label differs from the true one.

    Args:
        y_pred: Predicted labels, one per sample. The default is the value of
            ``y_pred_cnn`` at the time this cell was executed (Python binds
            default values at definition time).
        y_true: True labels, aligned with ``y_pred``. If omitted, the global
            ``y_test`` is used (looked up at call time).

    Returns:
        A 1-D NumPy integer array with the positions where the labels differ,
        suitable as indices for ``torch.utils.data.Subset``.
    """
    if y_true is None:
        y_true = y_test
    mismatch = np.asarray(y_true) != np.asarray(y_pred)
    return np.nonzero(mismatch)[0]

In [None]:
# Show which test indices the first classifier got wrong.
find_misclassified_images(y_pred=y_pred_cnn)

In [None]:
# View of the test set restricted to the misclassified samples.
misclassified_indices = find_misclassified_images()
misclassified_ds = Subset(test_dataset, misclassified_indices)

In [None]:
# Display the misclassified test images using the course helper
# `show_dataset` (from mlcourse.utils.data).
show_dataset(misclassified_ds)


## Data Augmentation (V2)

In [None]:
# Random geometric augmentation for the training images:
# - With probability 0.5 (RandomApply's default), an image is upscaled to
#   56x56. A random crop covering 80-100 % of its area is then resized to
#   28x28 with bicubic interpolation. Inside that branch, and again with
#   probability 0.5, a random affine transform is added: rotation of up to
#   +/-15 degrees and a shift of up to 8 % of the image width and height.
# - Every image is finally resized to 28x28 and converted to a tensor.
# NOTE: the vertical translate bound is 0.08, matching the horizontal one.
# A bound of 0.8 would allow shifts of up to 80 % of the height and push
# digits largely out of the frame.
augmented_transforms = transforms.Compose(
    [
        transforms.RandomApply(
            [
                transforms.Resize((56, 56)),
                transforms.RandomResizedCrop(
                    28, (0.8, 1.0), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.RandomApply(
                    [
                        transforms.RandomAffine(
                            degrees=15.0,
                            translate=(0.08, 0.08),
                            interpolation=InterpolationMode.BICUBIC,
                        )
                    ],
                    0.5,
                ),
            ]
        ),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ]
)

In [None]:
# Same MNIST training split as `train_dataset`, but every sample is drawn
# through the random `augmented_transforms` pipeline.
augmented_train_dataset = torchvision.datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=augmented_transforms,
)

In [None]:
# NOTE: `conv_model` is the module instance that the first classifier already
# trained. Per the skorch docs, an instantiated module is used as-is and is
# not re-initialized by `fit`. Passing `conv_model` here would therefore
# continue from those trained weights and blur the effect of the data
# augmentation. Instead, train a copy whose weights are freshly re-initialized;
# `conv_model` itself is left untouched.
augmented_model = copy.deepcopy(conv_model)
for layer in augmented_model.modules():
    # Conv2d and Linear provide reset_parameters(); ReLU, MaxPool2d, Flatten
    # and the Sequential container have no parameters of their own.
    if hasattr(layer, "reset_parameters"):
        layer.reset_parameters()

# Same architecture, now trained with Adam on the augmented data. The
# (non-augmented) test split serves as the validation set.
cnn_classifier = NeuralNetClassifier(
    augmented_model,
    criterion=nn.CrossEntropyLoss,
    batch_size=batch_size,
    max_epochs=2,
    optimizer=torch.optim.Adam,
    lr=1e-3,
    iterator_train__shuffle=True,
    train_split=predefined_split(test_dataset),
    device=device,
)

In [None]:
# Train on the augmented training data; labels come from the dataset itself.
cnn_classifier.fit(X=augmented_train_dataset, y=None)


## Callbacks

In [None]:
# Learning-rate schedule: PyTorch's StepLR, multiplying the learning rate by
# 0.1 every 5 scheduler steps.
step_lr_scheduler = LRScheduler(
    policy="StepLR",
    step_size=5,
    gamma=0.1,
)

In [None]:
# Whenever validation accuracy reaches a new best, pickle the whole classifier
# to <config.model_dir_path>/mnist_cnn.pkl (loaded again further below).
checkpoint_dir = config.model_dir_path.as_posix()
checkpoint = Checkpoint(
    monitor="valid_acc_best",
    f_pickle="mnist_cnn.pkl",
    dirname=checkpoint_dir,
)

In [None]:
# Stop training once validation accuracy has not improved for 5 epochs
# (accuracy: higher is better, hence lower_is_better=False).
early_stopping = EarlyStopping(
    monitor="valid_acc",
    lower_is_better=False,
    patience=5,
)

In [None]:
# NOTE: `conv_model` already holds weights trained by the first classifier.
# Per the skorch docs, an instantiated module is used as-is and not
# re-initialized by `fit`. So train a copy with freshly re-initialized
# weights instead of continuing from those.
callback_model = copy.deepcopy(conv_model)
for layer in callback_model.modules():
    # Only the Conv2d / Linear layers have parameters to reset.
    if hasattr(layer, "reset_parameters"):
        layer.reset_parameters()

# Long training run on the augmented data. The run is controlled by the
# callbacks: the LR schedule, checkpointing of the best model, and early
# stopping.
# NOTE(review): the validation split here is the test set. Checkpoint
# selection (valid_acc_best) and early stopping (valid_acc) therefore both
# tune on it, so the test metrics reported for `loaded_classifier` are
# optimistically biased. Consider holding out part of the training set for
# validation instead.
cnn_classifier = NeuralNetClassifier(
    callback_model,
    criterion=nn.CrossEntropyLoss,
    batch_size=batch_size,
    max_epochs=200,
    optimizer=torch.optim.Adam,
    lr=1e-3,
    iterator_train__shuffle=True,
    train_split=predefined_split(test_dataset),
    callbacks=[step_lr_scheduler, checkpoint, early_stopping],
    device=device,
)

In [None]:
# Train with the callbacks attached; labels come from the dataset itself.
cnn_classifier.fit(X=augmented_train_dataset, y=None)

In [None]:
# Reload the classifier that the `checkpoint` callback pickled at its best
# validation accuracy. Unpickling is acceptable here only because this
# notebook wrote the file itself; never unpickle untrusted files.
checkpoint_file = config.model_dir_path / "mnist_cnn.pkl"
with open(checkpoint_file, "rb") as pickle_file:
    loaded_classifier = pickle.load(pickle_file)

In [None]:
# Test-set predictions of the reloaded (checkpointed) classifier.
y_pred_loaded = loaded_classifier.predict(X=test_dataset)

In [None]:
# Per-class precision / recall / F1 of the reloaded classifier.
loaded_report = classification_report(y_true=y_test, y_pred=y_pred_loaded)
print(loaded_report)

In [None]:
# Confusion matrix of the reloaded classifier (rows: true, columns: predicted).
print(confusion_matrix(y_true=y_test, y_pred=y_pred_loaded))

## Workshop: Fashion-MNIST mit CNN

Trainieren Sie ein Konvolutionsnetz (CNN), das Bilder aus dem Fashion-MNIST-Datensatz
klassifizieren kann.

(Zur Erinnerung: Das PyTorch-`Dataset` für Fashion-MNIST kann mit der Klasse
`torchvision.datasets.FashionMNIST` erzeugt werden.)