In [None]:
import sys
import os

import torch
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

%load_ext autoreload
%autoreload 2

# Put the notebook's working directory on the import path so that the local
# modules imported further down (`transformer`, `CNN`) can be resolved.
module_path = os.path.abspath(os.curdir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)


## Download MNIST

In [None]:
import torchvision

# Preprocessing shared by both splits: convert to a tensor, then normalize
# with mean 0.5 and std 1.0.
to_tensor = torchvision.transforms.ToTensor()
normalize = torchvision.transforms.Normalize(0.5, 1.0)
transform = torchvision.transforms.Compose([to_tensor, normalize])

# Both splits use the same storage directory and preprocessing; only the
# `train` flag differs.
mnist_kwargs = dict(root='data', transform=transform, download=True)
train_data = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_data = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Dataset geometry handed to the model constructors further down.
image_size = 28
num_classes = 10


## View Data

In [None]:
import matplotlib.pyplot as plt

# Draw a 5x5 grid of randomly chosen training samples, each titled with its label.
cols, rows = 5, 5
figure = plt.figure(figsize=(10, 8))
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    ax = figure.add_subplot(rows, cols, i)
    ax.set_title(label)
    ax.axis("off")
    ax.imshow(img.numpy().squeeze(), cmap="gray")
plt.show()


## Data loader

In [None]:
# Fraction of the original training set kept for training; the remainder
# becomes the validation set.
train_valid_ratio = 0.9

# `train_data` is rebound to the training subset below. If this cell is run
# again, recover the full dataset from that subset first; otherwise every
# re-run would split an already-split subset and shrink the training set.
full_train_data = train_data.dataset if isinstance(
    train_data, torch.utils.data.Subset) else train_data

train_data_nr = int(len(full_train_data) * train_valid_ratio)
valid_data_nr = len(full_train_data) - train_data_nr

# The fixed generator seed makes the train/validation split reproducible.
train_data, valid_data = torch.utils.data.random_split(
    full_train_data, [train_data_nr, valid_data_nr],
    generator=torch.Generator().manual_seed(0))

batch_size = 500

# Only the training loader shuffles; validation and test keep dataset order.
loaders = {
    'train': torch.utils.data.DataLoader(train_data,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=2),
    'valid': torch.utils.data.DataLoader(valid_data,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=2),
    'test': torch.utils.data.DataLoader(test_data,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        num_workers=1)
}


## Training

In [None]:
from torch.autograd import Variable


def train_model(loaders, num_epochs: int, model: torch.nn.Module, loss_func, optimizer: torch.optim.Optimizer, writer=None, lr_scheduler=None):
    """Train `model` and evaluate it on the validation split after every epoch.

    :param loaders: mapping with the keys 'train' and 'valid'; each value is an
        iterable of (images, labels) batches that also supports `len()`.
    :param num_epochs: number of passes over the training loader.
    :param model: network to optimize; switched between train and eval mode.
    :param loss_func: callable `loss_func(output, labels)` returning a scalar tensor.
    :param optimizer: optimizer that updates the parameters of `model`.
    :param writer: optional object with `add_scalar` and `flush` (e.g. a
        SummaryWriter) that receives the average train/validation loss per epoch.
    :param lr_scheduler: optional scheduler; `step()` is called once per epoch.
    :return: average validation loss of the last epoch (`inf` when `num_epochs` is 0).
    """
    # Send batches to the device that holds the model's parameters instead of
    # assuming "CUDA is available" implies "the model is on CUDA".
    device = next(model.parameters(), torch.empty(0)).device
    last_valid_loss = float("Inf")

    for epoch in range(num_epochs):

        model.train()
        sum_loss_train = 0
        for images, labels in loaders['train']:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            output = model(images)
            loss = loss_func(output, labels)

            # clear gradients, backpropagate, then apply the update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss_train += loss.item()

        avg_train_loss = sum_loss_train / len(loaders['train'])
        if writer:
            writer.add_scalar("AvgTrainLoss/Epoch", avg_train_loss, epoch)
            writer.flush()

        model.eval()
        sum_loss_valid = 0
        # No gradients are needed for validation; disabling autograd avoids
        # building a graph for every validation batch.
        with torch.no_grad():
            for images, labels in loaders['valid']:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                output = model(images)
                sum_loss_valid += loss_func(output, labels).item()

        last_valid_loss = sum_loss_valid / len(loaders['valid'])

        print(f"Epoch [{epoch + 1:2d}/{num_epochs}], TrainLoss: {avg_train_loss:.4f}, ValidLoss: {last_valid_loss:.4f}, LearningRate: {optimizer.param_groups[0]['lr']:.5f}")

        if writer:
            writer.add_scalar("AvgValidLoss/Epoch",
                              last_valid_loss, epoch)
            writer.flush()

        if lr_scheduler:
            lr_scheduler.step()

    return last_valid_loss


In [None]:
from transformer import GrayscaleVisionTransformer, ConvStemConfig


# Transformer hyperparameters (the training cell logs them via `add_hparams`).
num_heads = 8
hidden_dim = 8 * num_heads
patch_size = 7
num_layers = 6
mlp_dim = 128
dropout = 0.5
attention_dropout = 0.5

transformer_config = dict(
    image_size=image_size,
    patch_size=patch_size,
    num_layers=num_layers,
    num_heads=num_heads,
    hidden_dim=hidden_dim,
    mlp_dim=mlp_dim,
    dropout=dropout,
    attention_dropout=attention_dropout,
    num_classes=num_classes,
)
model = GrayscaleVisionTransformer(**transformer_config)
# Optional extra constructor argument (currently disabled):
# conv_stem_configs=[ConvStemConfig(out_channels=32, kernel_size=5, padding=2, stride=1), ConvStemConfig(out_channels=32, kernel_size=patch_size, stride=patch_size)]

# Total number of elements over all model parameters.
number_params = sum(param.numel() for param in model.parameters())

number_params

In [None]:
# or CNN

from CNN import CNN

# CNN hyperparameters. NOTE: `hidden_dim` and `dropout` reuse the names of the
# transformer hyperparameters, so whichever model cell ran last defines them.
hidden_dim = [32, 32, 64]
dropout = 0.5

cnn_config = dict(
    image_size=image_size,
    hidden_dim=hidden_dim,
    dropout=dropout,
    num_classes=num_classes,
)
model = CNN(**cnn_config)

# Total number of elements over all model parameters.
number_params = sum(param.numel() for param in model.parameters())

number_params


In [None]:
# train Transformer
learning_rate = 0.0025
epochs = 40
weight_decay = 0
scheduler_gamma = 0.96

loss_func = torch.nn.CrossEntropyLoss()
if torch.cuda.is_available():
    # Wrap only once: re-running this cell must not nest DataParallel wrappers,
    # which would add a further 'module.' prefix to the saved state-dict keys.
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model = model.cuda()
    loss_func = loss_func.cuda()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# The learning rate is multiplied by `scheduler_gamma` after every epoch.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer=optimizer, gamma=scheduler_gamma)
writer = SummaryWriter()
valid_loss = train_model(loaders, epochs, model,
                         loss_func, optimizer, writer, lr_scheduler)
# Log the hyperparameters together with the final metrics of this run.
writer.add_hparams({'learning_rate': learning_rate, 'batch_size': batch_size, 'patch_size': patch_size, 'num_layers': num_layers, 'num_heads': num_heads,
                   'hidden_dim': hidden_dim, 'mlp_dim': mlp_dim, 'dropout': dropout, 'attention_dropout': attention_dropout, 'weight_decay': weight_decay, 'scheduler_gamma': scheduler_gamma}, {'number_params': number_params, 'valid_loss': valid_loss})
writer.close()
# When CUDA is available the saved keys carry the DataParallel 'module.' prefix.
torch.save(model.state_dict(), 'saved_model_transformer.pth')
torch.cuda.empty_cache()


In [None]:
# or train cnn

learning_rate = 0.002
epochs = 40
weight_decay = 1e-3
scheduler_gamma = 0.96

loss_func = torch.nn.CrossEntropyLoss()
if torch.cuda.is_available():
    # Wrap only once: re-running this cell must not nest DataParallel wrappers,
    # which would add a further 'module.' prefix to the saved state-dict keys.
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model = model.cuda()
    loss_func = loss_func.cuda()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# The learning rate is multiplied by `scheduler_gamma` after every epoch.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer=optimizer, gamma=scheduler_gamma)
writer = SummaryWriter()
valid_loss = train_model(loaders, epochs, model,
                         loss_func, optimizer, writer, lr_scheduler)
# Log the hyperparameters together with the final metrics of this run.
writer.add_hparams({'learning_rate': learning_rate, 'batch_size': batch_size, 'dropout': dropout, 'weight_decay': weight_decay,
                   'scheduler_gamma': scheduler_gamma}, {'number_params': number_params, 'valid_loss': valid_loss})
writer.close()
# When CUDA is available the saved keys carry the DataParallel 'module.' prefix.
torch.save(model.state_dict(), 'saved_model_cnn.pth')
torch.cuda.empty_cache()


## Loading

In [None]:
# or load transformer
from transformer import GrayscaleVisionTransformer
# Rebuild the architecture with the hyperparameters of the transformer cell;
# they must match the ones the checkpoint was trained with.
model = GrayscaleVisionTransformer(image_size=image_size, patch_size=patch_size, num_layers=num_layers, num_heads=num_heads,
                                   hidden_dim=hidden_dim, mlp_dim=mlp_dim, dropout=dropout, attention_dropout=attention_dropout, num_classes=num_classes)
# Load onto the CPU first so that a checkpoint written on a GPU machine can
# also be restored where no GPU is available.
state_dict = torch.load('saved_model_transformer.pth', map_location='cpu')
# The training cell saves the DataParallel-wrapped model when CUDA is
# available, which prefixes every key with 'module.'. Strip that prefix so the
# bare model accepts checkpoints saved with or without the wrapper.
state_dict = {(key[len('module.'):] if key.startswith('module.') else key): value
              for key, value in state_dict.items()}
load_result = model.load_state_dict(state_dict)
if torch.cuda.is_available():
    model = torch.nn.DataParallel(model).cuda()
load_result


In [None]:
# or load cnn
from CNN import CNN
# NOTE(review): the CNN is rebuilt with its constructor defaults, whereas the
# model cell passes image_size/hidden_dim/dropout/num_classes explicitly.
# Confirm the defaults match the trained configuration.
model = CNN()
# Load onto the CPU first so that a checkpoint written on a GPU machine can
# also be restored where no GPU is available.
state_dict = torch.load('saved_model_cnn.pth', map_location='cpu')
# The training cell saves the DataParallel-wrapped model when CUDA is
# available, which prefixes every key with 'module.'. Strip that prefix so the
# bare model accepts checkpoints saved with or without the wrapper.
state_dict = {(key[len('module.'):] if key.startswith('module.') else key): value
              for key, value in state_dict.items()}
load_result = model.load_state_dict(state_dict)
if torch.cuda.is_available():
    model = torch.nn.DataParallel(model).cuda()
load_result

## Evaluation

In [None]:
from typing import List
import numpy as np


def images_to_probs(model, images):
    '''
    Run `model` on a batch of images and return, for every image, the
    predicted class together with the softmax probability of that class.

    :param model: callable that maps `images` to a 2-D tensor of per-class
        scores with one row per image
    :param images: batch that is passed unchanged to `model`
    :return: tuple `(preds, probs)`; `preds` is a 1-D numpy array of predicted
        class indices and `probs` a list of floats holding the probability of
        each predicted class
    '''
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = model(images)
        _, preds_tensor = torch.max(output, 1)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        # Pick, per row, the probability that belongs to the predicted class.
        top_probs = probabilities.gather(1, preds_tensor.unsqueeze(1)).squeeze(1)

    # `preds_tensor` is already 1-D. It is deliberately not squeezed, because
    # squeezing a batch of one yields a 0-d array that cannot be iterated or
    # indexed by the callers.
    preds = preds_tensor.cpu().numpy()
    return preds, top_probs.tolist()


def plot_classes_preds(net, images, labels, mean=0.5, std=1.0):
    '''
    Build a matplotlib figure with up to 6x6 images of a batch. Every image is
    titled with the network's top prediction, its probability and the actual
    label; the title is green for a correct prediction and red otherwise.
    Uses the "images_to_probs" function.

    :param net: trained network, forwarded to `images_to_probs`
    :param images: batch of image tensors, indexed along the first dimension
    :param labels: tensor with the true class of every image
    :param mean: mean that was subtracted by the input normalization
    :param std: std the input was divided by; the defaults undo the
        notebook's `Normalize(0.5, 1.0)`
    :return: the created figure
    '''
    preds, probs = images_to_probs(net, images)
    # Never index past the end of a batch that has fewer than 36 images.
    num_shown = min(len(images), 6 * 6)
    fig = plt.figure(figsize=(15, 15))
    for idx in range(num_shown):
        ax = fig.add_subplot(6, 6, idx + 1, xticks=[], yticks=[])
        # Average over the channel dimension, then undo the normalization.
        img = images[idx].mean(dim=0) * std + mean
        ax.imshow(img.cpu().numpy().squeeze(), cmap="gray")
        # `.item()` keeps "tensor(..., device=...)" out of the title.
        true_label = labels[idx].item()
        ax.set_title(f"pred: {preds[idx]}, prob: {probs[idx] * 100.0:.1f}%, gt: {true_label}",
                     color=("green" if preds[idx] == true_label else "red"))
    plt.tight_layout()
    return fig


def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> List[torch.FloatTensor]:
    """
    Top-k accuracy for every k in `topk`.

    A sample counts as correct for a given k when its true class is among the
    k highest-scoring classes of the model output.

    ref:
    - https://pytorch.org/docs/stable/generated/torch.topk.html
    - https://discuss.pytorch.org/t/imagenet-example-accuracy-calculation/7840

    :param output: model scores of shape (batch, classes), e.g. raw logits
    :param target: true class index per sample, shape (batch,)
    :param topk: tuple of k values to evaluate, e.g. (1, 5)
    :return: list with one tensor of shape (1,) per entry of `topk`, holding
        the fraction of samples whose true class is within the top k
    """
    with torch.no_grad():
        num_samples = target.size(0)
        # Indices of the max(topk) best classes per sample, transposed to
        # (maxk, batch) so that row r holds every sample's rank-r guess.
        ranked_classes = output.topk(k=max(topk), dim=1)[1].t()
        hits = ranked_classes == target.view(1, -1).expand_as(ranked_classes)
        # The first k rows mark the samples that were hit within the top k.
        return [
            hits[:k].reshape(-1).float().sum(dim=0, keepdim=True) / num_samples
            for k in topk
        ]


In [None]:
# calc top1 and top5 accuracy
model.eval()
score_batches, label_batches = [], []
with torch.no_grad():
    for images, labels_batch in loaders['test']:
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
        # Collect the scores on the CPU; the labels never leave the CPU.
        score_batches.append(model(images).cpu())
        label_batches.append(labels_batch)
predictions = torch.cat(score_batches)
labels = torch.cat(label_batches)
acc = accuracy(predictions, labels, (1, 5))
torch.cuda.empty_cache()
acc


In [None]:
# plot some images + predictions
# Switch to inference mode here instead of relying on an earlier cell having
# done it.
model.eval()
images, labels = next(iter(loaders['test']))
if torch.cuda.is_available():
    images, labels = images.cuda(), labels.cuda()
# The context manager closes the writer even if plotting raises.
with SummaryWriter() as writer:
    writer.add_figure('predictions vs. actuals',
                      plot_classes_preds(model, images, labels))
    writer.flush()


In [None]:
# plot wrong predictions
# Switch to inference mode here instead of relying on an earlier cell having
# done it.
model.eval()
images_wrong = []
predictions_wrong = []
labels_wrong = []
# Gradients are not needed while scanning the test set.
with torch.no_grad():
    for images, labels in loaders['test']:
        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()
        predictions, probs = images_to_probs(model, images)

        # Keep every misclassified sample (on the CPU) for plotting.
        for pred, label, image in zip(predictions, labels.cpu(), images.cpu()):
            if pred != label.item():
                images_wrong.append(image)
                predictions_wrong.append(pred)
                labels_wrong.append(label)

# There may be fewer than 6x6 misclassified images; never index past the end.
num_shown = min(len(images_wrong), 6 * 6)
fig = plt.figure(figsize=(15, 15))
for idx in range(num_shown):
    ax = fig.add_subplot(6, 6, idx + 1, xticks=[], yticks=[])
    # Average over the channel dimension and undo the notebook's
    # Normalize(0.5, 1.0), i.e. add the mean back.
    img = images_wrong[idx].mean(dim=0) + 0.5
    ax.imshow(img.numpy().squeeze(), cmap="gray")
    # `.item()` keeps "tensor(...)" out of the title.
    ax.set_title(f"pred: {predictions_wrong[idx]}, gt: {labels_wrong[idx].item()}",
                 color="red")
plt.tight_layout()

# The context manager closes the writer even if logging raises.
with SummaryWriter() as writer:
    writer.add_figure('wrong predictions', fig)
    writer.flush()