In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from torchsummary import summary

In [None]:
# CONFIGURATION
# Training hyperparameters.
BATCH_SIZE = 64
EPOCHS = 50
LEARNING_RATE = 1e-3
# Fraction of the downloaded dataset used for training; the rest is held out.
TRAIN_TEST_SPLIT = 0.9

# Preprocessing applied to every image: resize to 32x32, convert to a tensor,
# then normalize with out = (in - 0.5) / 0.5.
T = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

Part 1. Loading the Dataset

In [None]:
dataset = datasets.CIFAR10(root="./CIFAR10", download=True, transform=T)

# Split by explicit integer lengths instead of fractions: `1 - TRAIN_TEST_SPLIT`
# is not exactly representable in floating point, so the fractional form can
# round one sample into the wrong subset.
numTrain = round(len(dataset) * TRAIN_TEST_SPLIT)
numTest = len(dataset) - numTrain
# A dedicated, seeded generator keeps the train/test membership identical
# across kernel restarts without touching the global RNG.
trainDataset, testDataset = random_split(
    dataset,
    [numTrain, numTest],
    generator=torch.Generator().manual_seed(42),
)

trainLoader = DataLoader(trainDataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True)
# The held-out loader is not shuffled, so its first batch is always the same
# images and reconstructions plotted at different epochs can be compared.
testLoader  = DataLoader(testDataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False)

Part 2. Defining the Neural Network Architecture

In [None]:
class AutoEncoder(nn.Module):
    """Small convolutional autoencoder for 3x32x32 images.

    The encoder halves the spatial size twice with strided convolutions and
    reduces the channels to one, producing a 1x8x8 code. The decoder mirrors it
    with transposed convolutions and ends in ``Tanh``, so every output value
    lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Conv2d: size_out = floor((size_in + 2 * padding - kernel) / stride) + 1
        down1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)   # 3x32x32  -> 32x16x16
        down2 = nn.Conv2d(32, 1, kernel_size=3, stride=2, padding=1)   # 32x16x16 -> 1x8x8
        self.encoder = nn.Sequential(down1, nn.ReLU(), down2)

        # ConvTranspose2d: size_out = (size_in - 1) * stride - 2 * padding + kernel
        up1 = nn.ConvTranspose2d(1, 32, kernel_size=4, stride=2, padding=1)   # 1x8x8    -> 32x16x16
        up2 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)   # 32x16x16 -> 3x32x32
        self.decoder = nn.Sequential(up1, nn.ReLU(), up2, nn.Tanh())

    def forward(self, x):
        """Encode ``x`` to the latent code and return its reconstruction."""
        return self.decoder(self.encoder(x))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the network on the chosen device and print its layer-by-layer summary
# for a 3x32x32 input.
model = AutoEncoder().to(device)
summary(model, (3, 32, 32))

Part 3. Training the Model

In [None]:
def _show_reconstructions(model: nn.Module, epoch: int, num_images: int = 7) -> None:
    """Plot test images (top row) above the model's reconstructions (bottom row).

    Takes the first batch of ``testLoader``, runs it through ``model`` without
    tracking gradients, and shows up to ``num_images`` input/output pairs.
    """
    with torch.no_grad():
        x, _ = next(iter(testLoader))
        x = x.to(device)
        y = model(x)

    def to_displayable(batch: torch.Tensor):
        # Undo the normalization for viewing:
        #   normalization is out = (in - 0.5) / 0.5, so in = out * 0.5 + 0.5.
        # clamp() removes tiny float overshoots outside [0, 1].
        batch = (batch * 0.5 + 0.5).clamp(0, 1)
        # torch images are BxCxHxW but matplotlib expects HxWxC; matplotlib
        # also needs host memory, so move the data off the GPU before plotting.
        return batch.permute(0, 2, 3, 1).cpu().numpy()

    x = to_displayable(x)
    y = to_displayable(y)

    num_images = min(num_images, x.shape[0])
    # Pass dpi straight to subplots(); a separate plt.figure(dpi=...) call
    # would only open an additional, empty figure.
    fig, ax = plt.subplots(2, num_images, figsize=(15, 4), dpi=250, squeeze=False)
    ax[0, num_images // 2].set_title(f"Epoch {epoch}")
    for i in range(num_images):
        ax[0, i].imshow(x[i])
        ax[1, i].imshow(y[i])
        ax[0, i].axis("off")
        ax[1, i].axis("off")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def train_model(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer) -> list[float]:
    """Train ``model`` to reconstruct its own input and return per-epoch losses.

    Runs ``EPOCHS`` passes over ``trainLoader`` on ``device``. After every 10th
    epoch a batch from ``testLoader`` is reconstructed and plotted.

    Args:
        model: Network whose output has the same shape as its input.
        criterion: Loss called as ``criterion(reconstruction, input)``; must
            return a scalar tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.

    Returns:
        One value per epoch: the sum of the per-batch losses of that epoch.
    """
    losses = []
    for epoch in range(1, EPOCHS + 1):
        epoch_loss = 0.0
        for x, _ in trainLoader:
            x = x.to(device)
            optimizer.zero_grad()
            # forward pass
            recon = model(x)
            # reconstruction error against the input itself
            loss: torch.Tensor = criterion(recon, x)
            # backward pass: compute gradients
            loss.backward()
            # update the weights
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss)
        # display results after every 10 epochs
        if epoch % 10 == 0:
            _show_reconstructions(model, epoch)
    return losses

# Train with mean-squared reconstruction error and Adam; keep the per-epoch
# losses for plotting.
losses = train_model(
    model=model,
    criterion=nn.MSELoss(),
    optimizer=optim.Adam(params=model.parameters(), lr=LEARNING_RATE),
)

In [None]:
# Loss curve: one point per epoch, taken from the list returned by train_model.
_, ax = plt.subplots()
ax.plot(losses)
ax.grid()
ax.set(title="Autoencoder Loss", xlabel="epoch", ylabel="loss")
plt.show()


Part 6. Unsupervised Classification