In [None]:
import matplotlib.pyplot as plt
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from torchsummary import summary

In [None]:
# ---- Configuration: everything a reader might want to tune ----

# Number of passes over the training split.
EPOCHS = 50
# Samples per mini-batch for both data loaders.
BATCH_SIZE = 64
# Fraction of the dataset used for training; the remainder is held out.
TRAIN_TEST_SPLIT = 0.9

# Per-image preprocessing pipeline: resize to 32x32, convert to a tensor,
# then normalize as (x - 0.5) / 0.5. The single-element mean/std tuples are
# broadcast across all channels.
T = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

Part 1. Loading the Dataset

In [None]:
# Download (if needed) and load CIFAR10 with the preprocessing pipeline `T`.
dataset = datasets.CIFAR10(root="./CIFAR10", download=True, transform=T)

# Split by explicit integer lengths instead of fractions.
# `1 - TRAIN_TEST_SPLIT` evaluates to 0.09999999999999998 in floating point, so
# a fractional split floors the held-out part one sample short (e.g. 4999 of
# 50,000) and hands the leftover sample to the training split.
numTrain = round(len(dataset) * TRAIN_TEST_SPLIT)
numTest = len(dataset) - numTrain

# A seeded generator keeps the train/test membership identical across
# Restart Kernel -> Run All, so results are comparable between runs.
trainDataset, testDataset = random_split(
    dataset,
    [numTrain, numTest],
    generator=torch.Generator().manual_seed(42),
)

# Training batches are reshuffled every epoch.
trainLoader = DataLoader(trainDataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True)
# The held-out loader keeps a fixed order so evaluation and the previewed
# batch are the same on every run.
testLoader  = DataLoader(testDataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False)

Part 2. Defining the Neural Network Architecture

In [None]:
class AutoEncoder(nn.Module):
    """Small convolutional autoencoder for 3x32x32 images.

    The encoder applies two stride-2 convolutions, compressing the input to a
    1x8x8 code. The decoder mirrors it with two stride-2 transposed
    convolutions that map the code back to 3x32x32.
    """

    def __init__(self):
        super().__init__()

        # Conv2d: size_out = floor((size_in + 2 * padding - kernel) / stride) + 1
        self.encoder = nn.Sequential(
            # 3x32x32 -> 3x16x16
            nn.Conv2d(in_channels=3, out_channels=3,
                      kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # 3x16x16 -> 1x8x8
            nn.Conv2d(in_channels=3, out_channels=1,
                      kernel_size=3, stride=2, padding=1),
        )

        # ConvTranspose2d: size_out = (size_in - 1) * stride - 2 * padding + kernel
        self.decoder = nn.Sequential(
            # 1x8x8 -> 3x16x16
            nn.ConvTranspose2d(in_channels=1, out_channels=3,
                               kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 3x16x16 -> 3x32x32
            nn.ConvTranspose2d(in_channels=3, out_channels=3,
                               kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        """Encode `x` to the latent code and return its reconstruction."""
        return self.decoder(self.encoder(x))

# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network, move its parameters to the chosen device, and print a
# layer-by-layer summary for a 3x32x32 input.
model = AutoEncoder()
model = model.to(device)
summary(model, (3, 32, 32))

Part 3. Training the Model

In [None]:
# Qualitative check: top row shows test images, bottom row shows the model's
# reconstructions of the same images.
with torch.no_grad():
    # Only the first test batch is needed for the preview.
    data, _ = next(iter(testLoader))
    data = data.to(device)
    recon = model(data)

    # Never index past the end of the batch if it holds fewer than 7 images.
    numShown = min(7, data.shape[0])

    # Pass dpi to subplots directly: a separate plt.figure(dpi=...) call would
    # open an extra empty figure and leave this one at the default dpi.
    # squeeze=False keeps `ax` two-dimensional even for a single column.
    fig, ax = plt.subplots(2, numShown, figsize=(15, 4), dpi=250, squeeze=False)
    for i in range(numShown):
        for row, batch in enumerate((data, recon)):
            # Undo Normalize((0.5,), (0.5,)) so pixel values return to [0, 1];
            # imshow clips float RGB data outside that range. The clamp also
            # bounds the decoder output, which has no final activation.
            img = (batch[i] * 0.5 + 0.5).clamp(0, 1)
            # imshow expects HxWxC, tensors are CxHxW.
            ax[row, i].imshow(img.cpu().numpy().transpose((1, 2, 0)))
            ax[row, i].axis("off")
    plt.show()

Part 6. Unsupervised Classification