In [None]:
import os

import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Compose, Normalize

# Seed PyTorch's global RNG so that the random pairing of samples, the weight
# initialisation and the batch order are repeatable on Restart & Run All
# (some GPU kernels may still introduce small run-to-run differences).
SEED = 0
torch.manual_seed(SEED)

# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Folder the torchvision datasets are downloaded to / loaded from.
DATA_ROOT = "./data"

# No torchvision transform is attached; later cells read the raw
# `.data` / `.targets` tensors of these datasets directly.
transform = None

training_mnist = MNIST(DATA_ROOT, download=True, train=True, transform=transform)
test_mnist = MNIST(DATA_ROOT, download=True, train=False, transform=transform)


training_fashion_mnist = FashionMNIST(
    DATA_ROOT, download=True, train=True, transform=transform
)
test_fashion_mnist = FashionMNIST(
    DATA_ROOT, download=True, train=False, transform=transform
)

In [None]:
def shuffle_dataset(dataset):
    """Return the dataset's samples and labels in a shared random order.

    A single permutation of ``range(len(dataset))`` is drawn and applied to both
    ``dataset.data`` and ``dataset.targets``, so every sample keeps its label.

    Args:
        dataset: Object exposing ``__len__`` plus indexable ``data`` and
            ``targets`` attributes (tensors in this notebook).

    Returns:
        Tuple ``(shuffled_data, shuffled_targets)``; the dataset itself is not
        modified.
    """
    order = torch.randperm(len(dataset))
    shuffled_data = dataset.data[order]
    shuffled_targets = dataset.targets[order]
    return shuffled_data, shuffled_targets


def generate_dataset(dataset_1, dataset_2):
    """Build a three-panel dataset: ``dataset_1`` | ``dataset_2`` | shuffled ``dataset_1``.

    Each sample joins, along dimension 2 of the ``.data`` tensors, an image from
    ``dataset_1`` (left), the image at the same index in ``dataset_2`` (centre)
    and an image from a random permutation of ``dataset_1`` (right). The label
    is the left image's label when the centre label is even, otherwise the
    right image's label.

    Args:
        dataset_1: Dataset with ``data`` / ``targets`` tensors providing the
            left and right panels and the candidate labels.
        dataset_2: Dataset with ``data`` / ``targets`` tensors providing the
            centre panel; only the parity of its labels is used.
            Presumably the same length and image size as ``dataset_1`` -
            ``torch.cat`` / ``torch.where`` need matching shapes.

    Returns:
        ``torch.utils.data.TensorDataset`` of ``(images, labels)`` held on
        ``device``; images are float, scaled by 1/255, with a singleton
        dimension inserted at position 1.
    """
    # Right-hand panels: dataset_1 under one random permutation (image/label pairs stay aligned).
    shuffled_images, shuffled_targets = shuffle_dataset(dataset_1)

    # Left | centre | right, concatenated along dimension 2.
    panels = (dataset_1.data, dataset_2.data, shuffled_images)
    images = torch.cat(panels, 2)
    images = images.float().unsqueeze(1).to(device)
    images = images / 255

    # The parity of the centre label decides which side panel provides the target.
    centre_is_even = dataset_2.targets % 2 == 0
    targets = torch.where(centre_is_even, dataset_1.targets, shuffled_targets)
    return torch.utils.data.TensorDataset(images, targets.to(device))


# Fashion-MNIST provides the side panels and labels (dataset_1); the MNIST digit
# in the centre (dataset_2) selects which side's label is the target.
train_dataset = generate_dataset(
    dataset_1=training_fashion_mnist,
    dataset_2=training_mnist,
)
test_dataset = generate_dataset(
    dataset_1=test_fashion_mnist,
    dataset_2=test_mnist,
)

In [None]:
# Small CNN: three conv/ReLU/max-pool stages followed by a dropout-regularised
# two-layer classifier with 10 output logits.
# NOTE(review): Linear(512, ...) assumes 1x28x84 inputs (three 28x28 panels side
# by side): 28x84 -> 13x41 -> 5x19 -> 1x8 after the three stages, i.e.
# 64 * 1 * 8 = 512 features. Adjust it if the input size changes.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, 3),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(32, 64, 3),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(64, 64, 3),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Flatten(),
    torch.nn.Dropout(0.5),
    torch.nn.Linear(512, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 10),
).to(device)


# Training hyper-parameters, named so they can be tuned in one place.
learning_rate = 0.001
batch_size = 128
n_epochs = 15

loss_fn = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
# Evaluation does not depend on sample order, so the test loader is not
# shuffled; this also keeps it from consuming the global RNG every epoch.
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

In [None]:
def get_accuracy(model, testloader):
    """Return the fraction of samples in ``testloader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode for the duration of the call so
    that layers such as dropout are disabled, and its previous training/eval
    mode is restored afterwards. Gradients are not tracked.

    Args:
        model: ``torch.nn.Module`` mapping a batch of inputs to per-class scores
            of shape ``(batch, n_classes)``.
        testloader: Iterable of ``(inputs, labels)`` batches, already on the
            model's device.

    Returns:
        0-dim tensor holding ``correct / total``.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    acc = 0
    count = 0
    was_training = model.training
    model.eval()  # dropout must be off, otherwise the accuracy is noisy and too low
    try:
        with torch.no_grad():
            for inputs, labels in testloader:
                y_pred = model(inputs)
                acc += (y_pred.argmax(1) == labels).sum()
                count += len(labels)
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return acc / count


# Test accuracy measured after each epoch (one entry per epoch, as returned by
# get_accuracy).
test_acc = []
# Take the batch total from the loader instead of hard-coding it, so the progress
# line stays correct when the batch size or dataset changes.
n_batches = len(trainloader)
for epoch in range(n_epochs):
    # Make sure training-only layers (dropout) are active for the updates,
    # whatever mode the previous evaluation left the model in.
    model.train()
    for batch, (inputs, labels) in enumerate(trainloader, start=1):
        print("Processing batch %d of %d" % (batch, n_batches), end="\r")
        optimizer.zero_grad()
        y_pred = model(inputs)
        loss = loss_fn(y_pred, labels)
        loss.backward()
        optimizer.step()

    acc = get_accuracy(model, testloader)
    print("Epoch: %d, accuracy: %.3f%%" % (epoch, acc * 100))
    test_acc.append(acc)

In [None]:
# Fashion-MNIST class names; the list index is the integer class label.
fmnist_labels = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def plot_mnist_dataset(dataset, fname, rows=3, cols=3):
    """Save a grid showing the first ``rows * cols`` samples of ``dataset``.

    Each panel shows the sample image in grayscale with its raw label as title.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs; the
            image is handed to ``imshow`` unchanged. Must hold at least
            ``rows * cols`` samples.
        fname: Destination passed to ``savefig``.
        rows: Number of grid rows (default 3).
        cols: Number of grid columns (default 3).
    """
    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    # `axes.flat` walks the grid row by row, matching index = row * cols + col.
    for index, panel in enumerate(axes.flat):
        sample = dataset[index]  # fetch (and decode) each sample only once
        image, label = sample[0], sample[1]
        panel.imshow(image, cmap="gray")
        panel.set_title(label)
    fig.savefig(fname)


def plot_fmnist_dataset(dataset, fname, rows=3, cols=3):
    """Save a grid showing the first ``rows * cols`` samples of a Fashion-MNIST dataset.

    Each panel shows the sample image in grayscale, titled with the class name
    looked up in ``fmnist_labels``.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs,
            where ``label`` is a valid index into ``fmnist_labels``. Must hold
            at least ``rows * cols`` samples.
        fname: Destination passed to ``savefig``.
        rows: Number of grid rows (default 3).
        cols: Number of grid columns (default 3).
    """
    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    # `axes.flat` walks the grid row by row, matching index = row * cols + col.
    for index, panel in enumerate(axes.flat):
        sample = dataset[index]  # fetch (and decode) each sample only once
        image, label = sample[0], sample[1]
        panel.imshow(image, cmap="gray")
        panel.set_title(fmnist_labels[label])
    fig.savefig(fname)


def plot_my_dataset(dataset, fname, rows=3, cols=3):
    """Save a grid showing the first ``rows * cols`` samples of a tensor dataset.

    Each panel shows the first channel of the sample image in grayscale, titled
    with the class name looked up in ``fmnist_labels``.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` tensor
            pairs; ``image[0]`` is the 2-D picture to draw and ``label`` is a
            one-element tensor indexing ``fmnist_labels``. Must hold at least
            ``rows * cols`` samples.
        fname: Destination passed to ``savefig``.
        rows: Number of grid rows (default 3).
        cols: Number of grid columns (default 3).
    """
    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    # `axes.flat` walks the grid row by row, matching index = row * cols + col.
    for index, panel in enumerate(axes.flat):
        sample = dataset[index]  # fetch each sample only once
        image = sample[0][0].cpu()  # matplotlib needs the pixels in host memory
        label = sample[1].item()
        panel.imshow(image, cmap="gray")
        panel.set_title(fmnist_labels[label])
    fig.savefig(fname)

def plot_model_pred(dataset, model, fname, rows=3, cols=3):
    """Save a grid of the first ``rows * cols`` samples with true and predicted classes.

    Each panel shows the first channel of the sample image in grayscale. The
    title is the true class name and the x-axis label is the class predicted by
    ``model``. Prediction runs in evaluation mode without gradient tracking, so
    dropout cannot randomise the displayed result; the model's previous
    training/eval mode is restored afterwards.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` tensor
            pairs; ``label`` is a one-element tensor indexing ``fmnist_labels``.
            Must hold at least ``rows * cols`` samples.
        model: ``torch.nn.Module`` returning per-class scores for a batch of
            images; it is fed one sample at a time on ``device``.
        fname: Destination passed to ``savefig``.
        rows: Number of grid rows (default 3).
        cols: Number of grid columns (default 3).
    """
    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    was_training = model.training
    model.eval()  # disable dropout so predictions are deterministic
    try:
        with torch.no_grad():
            # `axes.flat` walks the grid row by row (index = row * cols + col).
            for index, panel in enumerate(axes.flat):
                sample = dataset[index]  # fetch each sample only once
                inputs = sample[0]
                label = sample[1].item()
                panel.imshow(inputs[0].cpu(), cmap="gray")
                panel.set_title(fmnist_labels[label])
                # Add a batch dimension: the model expects a batch of images.
                y_pred = model(inputs.unsqueeze(0).to(device))
                panel.set_xlabel(fmnist_labels[y_pred.argmax(1).item()])
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    fig.savefig(fname)

def plot_mnist_accuracy(test_acc, fname):
    """Save a line plot of test accuracy against epoch index.

    Args:
        test_acc: Sequence with one accuracy value per epoch. Entries may be
            one-element tensors (on any device) or plain Python numbers.
        fname: Destination passed to ``savefig``.
    """
    fig, ax = plt.subplots()
    # float() converts one-element tensors as well as ordinary numbers, so the
    # function does not depend on how the accuracies were collected.
    accuracies = [float(value) for value in test_acc]
    ax.plot(accuracies)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_title("Test accuracy")
    fig.savefig(fname)



# All figures are written to the same folder. Create it first: savefig does not
# create missing directories and would fail on a fresh checkout.
os.makedirs("report/images", exist_ok=True)

plot_mnist_dataset(training_mnist, "report/images/training_mnist.png")
plot_fmnist_dataset(training_fashion_mnist, "report/images/training_fashion_mnist.png")
plot_my_dataset(train_dataset, "report/images/train_dataset.png")
plot_model_pred(test_dataset, model, "report/images/test_dataset.png")
plot_mnist_accuracy(test_acc, "report/images/test_acc.png")
