# Example of an Image Classifier in PyTorch

In [None]:
from time import time
from pathlib import Path
import itertools

from context import uncertify
from uncertify.common import DATA_DIR_PATH
from uncertify.utils.date_utils import get_date_time_tag

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import numpy as np
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.tensorboard import SummaryWriter

from typing import Tuple

In [None]:
BATCH_SIZE = 64
NUM_WORKERS = 4

# Short name -> torch device used by the training cell.
DEVICE_MAP = {
    'gpu': torch.device('cuda:0'),
    'cpu': torch.device('cpu'),
}

# Class index -> human-readable CIFAR-10 class name.
CIFAR10_CLASSES = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# ToTensor followed by Normalize with mean 0 and std 1: (x - 0) / 1 leaves the values unchanged.
CIFAR_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0, 0, 0), std=(1.0, 1.0, 1.0)),
])
MNIST_TRANSFORM = transforms.Compose([transforms.ToTensor()])

In [None]:
def get_cifar10_data_loaders(transform: transforms.Compose,
                             data_path: Path,
                             batch_size: int,
                             num_workers: int) -> Tuple[DataLoader, DataLoader]:
    """Build CIFAR-10 train and test data loaders, downloading the dataset if needed.

    Args:
        transform: transform applied to every image.
        data_path: root directory in which torchvision stores / looks for the dataset.
        batch_size: samples per batch, used for BOTH the train and the test loader.
        num_workers: DataLoader worker processes, used for BOTH loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    def make_loader(train: bool) -> DataLoader:
        # Build one split; the train loader previously ignored batch_size/num_workers and
        # used the module-level BATCH_SIZE/NUM_WORKERS instead.
        dataset = torchvision.datasets.CIFAR10(root=data_path,
                                               train=train,
                                               download=True,
                                               transform=transform)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=train,
                          num_workers=num_workers)

    train_loader = make_loader(train=True)
    test_loader = make_loader(train=False)
    return train_loader, test_loader

def get_mnist_data_loaders(transform: transforms.Compose,
                           data_path: Path,
                           batch_size: int,
                           num_workers: int) -> Tuple[DataLoader, DataLoader]:
    """Build MNIST train and test data loaders, downloading the dataset if needed.

    Args:
        transform: transform applied to every image.
        data_path: root directory in which torchvision stores / looks for the dataset.
        batch_size: samples per batch, used for BOTH the train and the test loader.
        num_workers: DataLoader worker processes, used for BOTH loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    def make_loader(train: bool) -> DataLoader:
        # Build one split; the train loader previously ignored batch_size/num_workers and
        # used the module-level BATCH_SIZE/NUM_WORKERS instead.
        dataset = torchvision.datasets.MNIST(root=data_path,
                                             train=train,
                                             download=True,
                                             transform=transform)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=train,
                          num_workers=num_workers)

    train_loader = make_loader(train=True)
    test_loader = make_loader(train=False)
    return train_loader, test_loader


# Select exactly ONE dataset. Previously the CIFAR-10 loaders were built (and downloaded) and
# then immediately overwritten by MNIST loaders, while the rest of the notebook (3-channel `Net`,
# CIFAR10_CLASSES, the cifar10_runs / cifar_net.pth paths) expects CIFAR-10.
# Use 'mnist' for the greyscale-only experiments (see the GradientNet cell).
DATASET = 'cifar10'
if DATASET == 'cifar10':
    train_loader, test_loader = get_cifar10_data_loaders(transform=CIFAR_TRANSFORM,
                                                         data_path=DATA_DIR_PATH / 'cifar10_data',
                                                         batch_size=BATCH_SIZE,
                                                         num_workers=NUM_WORKERS)
elif DATASET == 'mnist':
    train_loader, test_loader = get_mnist_data_loaders(transform=MNIST_TRANSFORM,
                                                       data_path=DATA_DIR_PATH / 'mnist_data',
                                                       batch_size=BATCH_SIZE,
                                                       num_workers=NUM_WORKERS)
else:
    raise ValueError(f'Unknown DATASET {DATASET!r}; expected "cifar10" or "mnist".')

In [None]:
def matplotlib_imshow(img, one_channel=False, plt_show=False):
    """Draw a channel-first (C, H, W) image tensor with matplotlib.

    Args:
        img: image tensor, e.g. the output of ``torchvision.utils.make_grid`` or a single
            element of a batch. May live on the GPU or be part of an autograd graph.
        one_channel: if True, average over the channel axis and draw with the "Greys" colormap.
        plt_show: if True, call ``plt.show()`` after drawing.
    """
    # Tensor.numpy() refuses CUDA tensors and tensors that require grad, so move the
    # data to a plain CPU tensor first.
    img = img.detach().cpu()
    if one_channel:
        img = img.mean(dim=0)
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        # matplotlib expects channel-last (H, W, C)
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
    if plt_show:
        plt.show()

def plot_n_batches(data_loader: DataLoader, n_first_batches: int, classes: Tuple[str, ...]) -> None:
    """Show the first ``n_first_batches`` batches of ``data_loader`` as image grids.

    For each batch: print the image tensor shape, draw all images as one grid and print the
    class names of the labels joined by " - ".

    Args:
        data_loader: loader yielding ``(images, labels)`` batches.
        n_first_batches: how many batches to show.
        classes: class index -> class name (``Tuple[str, ...]``: any length, not a 1-tuple).
    """
    for images, labels in itertools.islice(data_loader, n_first_batches):
        print(images.shape)
        grid = torchvision.utils.make_grid(images)
        matplotlib_imshow(grid, plt_show=True)
        # convert the label tensor once instead of indexing the tuple with 0-d tensors
        print(' - '.join(classes[idx] for idx in labels.tolist()))


plot_n_batches(train_loader, n_first_batches=1, classes=CIFAR10_CLASSES)

In [None]:
from uncertify.models.gradient import GradientNet

# Visualise GradientNet's output on the first training batch.
# GradientNet only works with 1-channel (greyscale) images, e.g. the MNIST loaders.
gradient_net = GradientNet()
# Visualisation only: no_grad keeps the output detached from autograd.
# NOTE(review): without it, if GradientNet has trainable parameters, converting the grid to numpy would fail.
with torch.no_grad():
    for batch, _ in itertools.islice(train_loader, 1):
        grad = gradient_net(batch)
        grid = torchvision.utils.make_grid(grad, normalize=True)
        matplotlib_imshow(grid)

In [None]:
class Net(nn.Module):
    """LeNet-style CNN classifier for 32x32 images (CIFAR-10 shaped by default).

    Two ``conv(5x5) -> ReLU -> 2x2 average pool`` stages followed by three fully connected
    layers. The ``16 * 5 * 5`` input size of ``fc_1`` fixes the spatial input size to 32x32:
    32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5.

    Args:
        in_channels: number of input image channels (default 3, i.e. RGB).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 10) -> None:
        super().__init__()
        self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv_2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc_1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc_2 = nn.Linear(in_features=120, out_features=84)
        self.fc_3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) logits."""
        x = self.pool(F.relu(self.conv_1(x)))  # conv -> ReLU -> pool
        x = self.pool(F.relu(self.conv_2(x)))
        # Flatten per sample. Unlike ``view(-1, 16 * 5 * 5)`` this can never silently fold a
        # wrongly sized input into a different batch size; fc_1 reports the mismatch instead.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x))
        return self.fc_3(x)

In [None]:
LEARNING_RATE = 0.001

# Model, classification loss and SGD-with-momentum optimiser.
net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9)
loss_func = nn.CrossEntropyLoss()

In [None]:
def images_to_probs(network, images):
    """Predict a class for each image and return the softmax probability of that prediction.

    Args:
        network: callable mapping ``images`` to a (batch, num_classes) tensor of logits.
        images: batch of inputs accepted by ``network``.

    Returns:
        ``(preds, probs)``: ``preds`` is a 1-D numpy array with the arg-max class index of
        each image (also for a batch of one); ``probs`` is a list of Python floats holding the
        softmax probability of the predicted class.
    """
    with torch.no_grad():  # reporting only, no autograd graph needed
        logits = network(images)
        _, preds_tensor = torch.max(logits, 1)
        # softmax over the class axis for the whole batch, then pick each row's predicted entry
        probs_tensor = F.softmax(logits, dim=1).gather(1, preds_tensor.unsqueeze(1)).squeeze(1)
    # reshape(-1) rather than np.squeeze: a batch of one must stay 1-D (a 0-d array can't be zipped)
    preds = preds_tensor.cpu().numpy().reshape(-1)
    return preds, probs_tensor.cpu().tolist()


def plot_classes_preds(net, images, labels, classes, n_images=4):
    """Build a figure with the network's top prediction for the first images of a batch.

    Each of the first ``n_images`` images (fewer if the batch is smaller) gets its own subplot.
    The title shows the predicted class, its probability in percent and the true label. It is
    green when the prediction matches the label and red otherwise. Uses ``images_to_probs``.

    Args:
        net: network mapping ``images`` to class logits.
        images: batch of channel-first (N, C, H, W) images.
        labels: tensor of true class indices for ``images``.
        classes: class index -> class name.
        n_images: maximum number of images to show (default 4).

    Returns:
        The matplotlib Figure.
    """
    preds, probs = images_to_probs(net, images)
    n_shown = min(n_images, len(images))  # a batch may hold fewer than n_images samples
    # greyscale images are drawn with a colormap; RGB images keep their colours
    one_channel = images.shape[1] == 1
    fig = plt.figure(figsize=(12, 48))
    for idx in range(n_shown):
        ax = fig.add_subplot(1, n_shown, idx + 1, xticks=[], yticks=[])
        matplotlib_imshow(images[idx].detach().cpu(), one_channel=one_channel)
        true_label = labels[idx].item()
        ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
            classes[preds[idx]],
            probs[idx] * 100.0,
            classes[true_label]),
            color=("green" if preds[idx] == true_label else "red"))
    return fig

In [None]:
N_EPOCHS = 2
# Must be <= len(train_loader): with BATCH_SIZE = 64 an epoch has far fewer than 1000 batches,
# so the previous value of 1000 meant nothing was ever printed or logged.
PRINT_STEPS = 200
# Use the GPU when one is available instead of crashing on a CUDA-less machine.
GPU_ON = torch.cuda.is_available()

writer = SummaryWriter(DATA_DIR_PATH / "cifar10_runs" / get_date_time_tag())
device = DEVICE_MAP['gpu' if GPU_ON else 'cpu']
net = Net().to(device)  # move the model before handing its parameters to the optimiser
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9)
start_time = time()
try:
    for epoch_idx in range(N_EPOCHS):
        running_loss = 0.0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # .to(device) is a no-op for tensors already on that device, so no GPU_ON check is needed
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()  # zero all parameter gradients
            outputs = net(inputs)  # call the module (not .forward) so registered hooks run
            loss = loss_func(outputs, labels)  # compute loss
            loss.backward()  # perform back propagation
            optimizer.step()  # perform an update step on the parameters
            running_loss += loss.item()
            if (batch_idx + 1) % PRINT_STEPS == 0:
                avg_loss = running_loss / PRINT_STEPS  # same average for console and TensorBoard
                global_step = epoch_idx * len(train_loader) + batch_idx
                print(f'epoch {epoch_idx + 1:<2} | batch {batch_idx + 1:5}  >>>  loss: {avg_loss:.3f}')
                writer.add_scalar('training loss', avg_loss, global_step)
                writer.add_figure('predictions vs. actuals',
                                  plot_classes_preds(net, inputs, labels, CIFAR10_CLASSES),
                                  global_step=global_step)
                running_loss = 0.0
finally:
    writer.close()  # flush pending TensorBoard events, even if training is interrupted
print(f'Training on {"gpu" if GPU_ON else "cpu"} took {time()-start_time:.2f} seconds.')

In [None]:
MODEL_PATH = DATA_DIR_PATH / 'cifar10_data' / 'cifar_net.pth'
# torch.save does not create missing parent directories.
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(net.state_dict(), MODEL_PATH)

In [None]:
import itertools

# Reload the saved weights into a fresh CPU model. map_location='cpu' allows weights saved
# during a GPU run to be loaded on a machine without CUDA.
net = Net()
net.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
net.eval()

# Show the first test batch and compare ground truth with predictions for its first samples.
with torch.no_grad():
    for images, labels in itertools.islice(test_loader, 1):
        matplotlib_imshow(torchvision.utils.make_grid(images))
        n_shown = min(4, len(labels))  # the batch may hold fewer than four samples
        print('GroundTruth: ', ' '.join('%5s' % CIFAR10_CLASSES[labels[j]] for j in range(n_shown)))
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        print('Predicted: ', ' '.join('%5s' % CIFAR10_CLASSES[predicted[j]]
                                      for j in range(n_shown)))


In [None]:
# Overall top-1 accuracy of `net` over every batch of the test loader.
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        _, predicted = torch.max(net(images), 1)
        total += labels.size(0)
        is_hit = predicted == labels
        correct += is_hit.sum().item()

accuracy_pct = 100 * correct / total
print('Accuracy of the network on the 10000 test images: %d %%' % (accuracy_pct,))

In [None]:
# Per-class accuracy of `net` over the whole test set.
num_classes = len(CIFAR10_CLASSES)
class_correct = [0] * num_classes
class_total = [0] * num_classes
with torch.no_grad():
    for images, labels in test_loader:
        _, predicted = torch.max(net(images), 1)
        matches = (predicted == labels).tolist()
        # Count every sample of the batch. A fixed range(4) skipped most of each 64-sample
        # batch, and squeeze() broke on a batch of one.
        for label, is_correct in zip(labels.tolist(), matches):
            class_correct[label] += int(is_correct)
            class_total[label] += 1

for i, class_name in enumerate(CIFAR10_CLASSES):
    if class_total[i] == 0:  # avoid ZeroDivisionError for a class absent from the test set
        print('Accuracy of %5s : n/a (no test samples)' % class_name)
        continue
    print('Accuracy of %5s : %2d %%' % (
        class_name, 100 * class_correct[i] / class_total[i]))