In [None]:
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import os

# Root directory handed to get_mnist_loaders as the MNIST download/cache location.
# NOTE(review): machine-specific, home-relative path; presumably "~" is expanded
# by torchvision when the dataset is created — confirm before running elsewhere.
data_path = "~/code/personal/adversarial-attack/data"

In [None]:
def get_mnist_loaders(batch_size=16, data_path=None):
    """
    Download MNIST (if not already present) and return train- and test-loaders.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        data_path: Root directory for the dataset files. When ``None`` the
            dataset is stored under ``./data`` rather than forwarding ``None``
            to torchvision.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader reshuffles
        its samples; the test loader keeps the dataset order.
    """
    if data_path is None:
        # torchvision's MNIST needs a concrete root directory to build its
        # folder paths, so a None root cannot be passed through.
        data_path = "./data"

    # The only preprocessing step is the PIL-image -> tensor conversion.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = MNIST(root=data_path, train=True, download=True, transform=transform)
    test_dataset = MNIST(root=data_path, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # No shuffling here, so predictions collected over the loader line up with
    # the dataset's own sample order.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# Build both loaders once; the cells below reuse `train` and `test`.
train, test = get_mnist_loaders(batch_size=16, data_path=data_path)

In [None]:
import matplotlib.pyplot as plt

# Pull the first batch from the (shuffled) train loader so it can be displayed
train_iterator = iter(train)
batch = next(train_iterator)

def display_n_images(batch, num_to_display : int = 8):
    """Show the first images of a batch side by side, titled with their labels.

    Args:
        batch: An ``(images, labels)`` pair as yielded by a data loader.
        num_to_display: Maximum number of images to draw. If the batch holds
            fewer images, only those are drawn; nothing is drawn when the
            resulting count is zero or negative.
    """
    images, labels = batch

    # Never index past the end of the batch.
    count = min(num_to_display, len(images))
    if count <= 0:
        return

    # Build the converter once instead of once per image.
    to_pil = transforms.ToPILImage()

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so the same indexing works for count == 1.
    _, axes = plt.subplots(1, count, figsize=(12, 4), squeeze=False)
    for ax, image, label in zip(axes[0], images, labels):
        ax.imshow(to_pil(image), cmap='gray')
        ax.axis('off')
        ax.set_title(f"L: {label.item()}")
    plt.show()

display_n_images(batch, 8)

# First image of the batch (batch[0] holds the images, batch[1] the labels);
# kept as a single sample for the classifier smoke test further down.
img = batch[0][0]

In [None]:
from torch import nn

class Classifier(nn.Module):
    """Small convolutional network mapping 1-channel images to 10 class scores.

    Two conv -> ReLU -> 2x2 max-pool stages (16, then 32 channels) feed a
    single linear layer. The linear layer expects 32 feature maps of 7x7,
    which is what two halvings of a 28x28 input produce.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 16 channels. A 5x5 kernel with padding 2 keeps the
        # spatial size; the pooling layer then halves it.
        self.l1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Stage 2: 16 -> 32 channels, same kernel/padding/pooling scheme.
        self.l2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Classification head over the flattened feature maps.
        self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=10)

    def forward(self, x):
        """Return one row of 10 raw class scores per flattened feature vector."""
        feature_maps = self.l2(self.l1(x))
        # Flatten to rows of 32*7*7 values; -1 lets the row count be inferred.
        flattened = feature_maps.view(-1, 32 * 7 * 7)
        return self.fc1(flattened)

In [None]:
# Smoke test: run a single image through a freshly initialised (untrained) model.
c = Classifier()
# `img` is one sample without a batch axis; add one explicitly so the
# convolution layers receive batched input regardless of whether the
# installed PyTorch accepts unbatched images.
pred = c(img.unsqueeze(0))
print(pred)

In [None]:
import torch.optim as optim

# Cross-entropy loss on the classifier's raw outputs; SGD with momentum
# updates the parameters of `c`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(c.parameters(), lr=0.001, momentum=0.9)

In [None]:
import torch

# Use CUDA when it is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Move the model to the selected device (CUDA or CPU)
c.to(device)

In [None]:
# Train for 10 passes over the train loader, reporting the mean loss of
# every 500 consecutive mini-batches.
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train):
        # Unpack the batch and place both tensors on `device`.
        inputs, labels = (tensor.to(device) for tensor in data)

        # Clear old gradients, then forward -> loss -> backward -> update.
        optimizer.zero_grad()
        outputs = c(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 500 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 500:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
# Persist only the model's state_dict (its parameters), not the module object itself
torch.save(c.state_dict(), './weights.pth')

In [None]:
import torchvision
import numpy as np

# Take the first batch of the test loader; `images` and `labels` are reused below
dataiter = iter(test)
images, labels = next(dataiter)

def imshow(img, mean=0.0, std=1.0):
    """Display a (channels, height, width) image tensor with matplotlib.

    Args:
        img: Image tensor in channel-first layout, e.g. a grid built by
            ``torchvision.utils.make_grid``.
        mean: Offset that was subtracted when the data was normalised.
        std: Scale the data was divided by when it was normalised.

    The defaults leave the pixel values untouched, which matches loaders whose
    only transform is ``ToTensor``. Pass the values used in a ``Normalize``
    transform to undo it before display.
    """
    # Undo normalisation (identity with the defaults).
    img = img * std + mean
    # detach/cpu so tensors that track gradients or live on a GPU still convert.
    npimg = img.detach().cpu().numpy()
    # matplotlib wants channel-last (H, W, C) data.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Digit class names, indexed by label value.
classes = [str(digit) for digit in range(10)]

# Draw the whole batch as one image grid, then list the labels of its first 8 images.
imshow(torchvision.utils.make_grid(images))
ground_truth = ' '.join(f'{classes[labels[j]]:5s}' for j in range(8))
print('GroundTruth: ', ground_truth)

In [None]:
# Rebuild the network (on the CPU) and restore the trained parameters.
c = Classifier()
# map_location='cpu': the checkpoint may have been written from a CUDA model,
# and this fresh model and the test batches live on the CPU.
c.load_state_dict(torch.load('./weights.pth', map_location='cpu'))
c.eval()

# Inference only, so skip gradient tracking.
with torch.no_grad():
    outputs = c(images)

# Index of the highest score per row is the predicted class.
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(8)))

In [None]:
# Overall accuracy: count matching predictions across the whole test loader.
correct = 0
total = 0
# evaluation only, so gradient tracking is switched off
with torch.no_grad():
    for data in test:
        images, labels = data
        # forward pass through the network
        outputs = c(images)
        # the prediction is the class with the highest score in each row
        predicted = torch.max(outputs, dim=1).indices
        total += len(labels)
        correct += int((predicted == labels).sum())

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

In [None]:
# Per-class tallies: how many samples of each class were seen, and how many
# of those were predicted correctly.
correct_pred = dict.fromkeys(classes, 0)
total_pred = dict.fromkeys(classes, 0)

# evaluation only, so gradient tracking is switched off
with torch.no_grad():
    for data in test:
        images, labels = data
        outputs = c(images)
        predictions = torch.max(outputs, dim=1).indices
        # update both tallies sample by sample
        for label, prediction in zip(labels, predictions):
            classname = classes[label]
            total_pred[classname] += 1
            if label == prediction:
                correct_pred[classname] += 1


# report the accuracy of every class as a percentage
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

# Model Evaluation

In [None]:
# Put the model in evaluation mode before collecting predictions
c.eval()

def get_predictions(model, dataloader):
    """Run ``model`` over every batch of ``dataloader`` and collect class predictions.

    Args:
        model: Callable mapping a batch of inputs to a 2-D tensor of scores
            (one row per sample).
        dataloader: Iterable of ``(inputs, target)`` pairs; the targets are ignored.

    Returns:
        A flat list with the index of the highest score for each sample, in
        iteration order.
    """
    collected = []

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        for batch_inputs, _ignored in dataloader:
            scores = model(batch_inputs)
            best_class = torch.max(scores, 1)[1]

            # Move to the CPU and convert to numpy before appending the values.
            collected += list(best_class.cpu().numpy())

    return collected

# Collect the model's predictions for the whole test loader, in loader order
test_predictions = get_predictions(c, test)

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Ground-truth digits (0-9), read from the test dataset in its own order.
# The test loader does not shuffle, so these line up with test_predictions.
actual_labels = [target for _image, target in test.dataset]

# Rows are actual classes, columns are predicted classes.
cm = confusion_matrix(actual_labels, test_predictions)

# Draw the matrix as an annotated heatmap and label the returned axes.
heatmap_ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
heatmap_ax.set_xlabel("Predicted")
heatmap_ax.set_ylabel("Actual")
plt.show()
