In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from torchsummary import summary

In [None]:
# HYPERPARAMETERS
# Seed torch's global RNG so that the random train/test split, the weight
# initialisation and the DataLoader shuffling draw the same random numbers
# on every Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Preprocessing applied to every image: resize to 32x32, convert to a tensor,
# then normalise with out = (in - 0.5) / 0.5.
T = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,))
])
TRAIN_TEST_SPLIT = 0.9   # fraction of the dataset used for training
BATCH_SIZE = 64          # images per mini-batch for both loaders
EPOCHS = 50              # upper bound on training epochs (training may stop earlier)
LEARNING_RATE = 0.001    # step size handed to the optimizer

Part 1. Loading the Dataset

In [None]:
# Fetch CIFAR10 into ./CIFAR10 (downloading when needed) with the preprocessing T attached.
dataset = datasets.CIFAR10("./CIFAR10", transform=T, download=True)

# Randomly partition the images into a training part and a held-out part,
# sized by the TRAIN_TEST_SPLIT fraction and its complement.
trainDataset, testDataset = random_split(
    dataset,
    lengths=[TRAIN_TEST_SPLIT, 1 - TRAIN_TEST_SPLIT],
)

# Both loaders serve shuffled mini-batches of BATCH_SIZE images.
trainLoader = DataLoader(dataset=trainDataset, shuffle=True, batch_size=BATCH_SIZE)
testLoader = DataLoader(dataset=testDataset, shuffle=True, batch_size=BATCH_SIZE)

Part 2. Defining the Neural Network Architecture

In [None]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder for 3x32x32 images.

    The encoder halves the spatial size twice and ends in a single channel,
    giving a 1x8x8 latent code. The decoder mirrors it with two transposed
    convolutions and finishes with a Tanh.
    """

    def __init__(self):
        super().__init__()
        # Conv2d: size_out = floor((size_in + 2 * padding - kernel) / stride) + 1
        self.encoder = nn.Sequential(
            # INPUT 3x32x32
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1),   # 32x16x16
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=2, padding=1),   # 1x8x8
        )
        # ConvTranspose2d: size_out = (size_in - 1) * stride - 2 * padding + kernel
        self.decoder = nn.Sequential(
            # INPUT 1x8x8
            nn.ConvTranspose2d(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=1),  # 32x16x16
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=4, stride=2, padding=1),  # 3x32x32
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode `x` into its latent code and return the decoded reconstruction."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction

# Prefer the GPU when CUDA is available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Module.to() moves the parameters and returns the module itself, so it can be chained.
model = AutoEncoder().to(device)
# Print a per-layer summary for a single 3x32x32 input.
summary(model, (3,32,32))

Part 3. Train the Model

In [None]:
def _show_reconstructions(model: nn.Module, epoch: int, n_images: int = 7) -> None:
    """Plot up to `n_images` images from one `testLoader` batch above their reconstructions."""
    model.eval()
    with torch.no_grad():
        # take a single batch of images
        x, _ = next(iter(testLoader))
        x = x.to(device)
        y = model(x)
    # torch images are BxCxHxW, matplotlib expects HxWxC.
    # Undo the normalisation for viewing: out = (in - 0.5) / 0.5  =>  in = out * 0.5 + 0.5
    x = (x.permute(0, 2, 3, 1) * 0.5 + 0.5).cpu()
    y = (y.permute(0, 2, 3, 1) * 0.5 + 0.5).cpu()
    # Never index past the end of a short batch.
    n_shown = min(n_images, x.size(0))
    # The dpi is passed to the figure that is actually drawn; a separate
    # plt.figure(dpi=...) call would only open an additional, empty figure.
    fig, ax = plt.subplots(2, n_shown, figsize=(15, 4), dpi=250, squeeze=False)
    ax[0, n_shown // 2].set_title(f"Epoch {epoch}")
    for i in range(n_shown):
        ax[0, i].imshow(x[i])   # top row: originals
        ax[1, i].imshow(y[i])   # bottom row: reconstructions
        ax[0, i].axis("off")
        ax[1, i].axis("off")
    plt.show()
    plt.close(fig)  # release the figure so repeated previews do not accumulate


def train_model(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer) -> list[float]:
    """Train `model` as an autoencoder on `trainLoader` for at most `EPOCHS` epochs.

    Every 10th epoch one batch from `testLoader` is reconstructed and plotted.
    Training stops early once the summed epoch loss changes by less than 0.001
    between two consecutive epochs. The model is left in eval mode on return.

    Relies on the notebook-level names `EPOCHS`, `trainLoader`, `testLoader`
    and `device`.

    Args:
        model: network whose output is compared against its own input.
        criterion: loss called as ``criterion(reconstruction, input)``.
        optimizer: optimizer used to update the model after each batch.

    Returns:
        One entry per completed epoch: the sum of that epoch's per-batch losses.
    """
    losses = []
    for epoch in range(1, EPOCHS + 1):
        model.train()
        epoch_loss = 0.0
        for x, _ in trainLoader:
            x = x.to(device)
            optimizer.zero_grad()
            # forward pass
            recon = model(x)
            loss: torch.Tensor = criterion(recon, x)
            # backward pass: compute gradients
            loss.backward()
            # update the parameters
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss)
        # display results after every 10 epochs
        if epoch % 10 == 0:
            _show_reconstructions(model, epoch)
        # check for convergence
        if epoch > 1 and abs(losses[-1] - losses[-2]) < 0.001:
            break
    # Leave the model in eval mode however the loop ended.
    model.eval()
    return losses

# Train with a mean-squared reconstruction error and Adam; keep the per-epoch loss history.
losses = train_model(
    model=model,
    criterion=nn.MSELoss(),
    optimizer=optim.Adam(params=model.parameters(), lr=LEARNING_RATE),
)

In [None]:
# Line chart of the loss recorded for each training epoch.
_, ax = plt.subplots()
ax.plot(losses)
ax.set_title("Autoencoder Loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.grid()
plt.show()

Part 5. Mean and Variance of Latent Vectors per Class

In [None]:
def plot_class_distances(encoder: nn.Module, n_classes: int, dataset: Dataset, class_names=None) -> None:
    """Bar-plot, per class, the mean squared distance of latent codes to their class mean.

    Each sample of `dataset` is encoded, the codes are grouped by class label,
    and for every class the average of ||z - mean_c||^2 is computed (the total
    variance of that class in latent space).

    Relies on the notebook-level name `device`.

    Args:
        encoder: module mapping one image tensor to its latent code.
        n_classes: number of distinct class labels (labels are 0..n_classes-1).
        dataset: iterable of (image, label) pairs.
        class_names: bar labels, one per class; defaults to the ten CIFAR10
            names used by this notebook.

    Raises:
        ValueError: if the number of class names differs from `n_classes`.
    """
    if class_names is None:
        class_names = ["airplane","car","bird","cat","deer","dog","frog","horse","ship","truck"]
    if len(class_names) != n_classes:
        raise ValueError(f"expected {n_classes} class names, got {len(class_names)}")

    # Apply the encoder to every sample and group the flattened codes by class.
    Z: list[list[torch.Tensor]] = [[] for _ in range(n_classes)]
    with torch.no_grad():
        for x, c in dataset:
            Z[c].append(encoder(x.to(device)).cpu().flatten())

    # Mean squared Euclidean distance to the class mean.
    dist: list[float] = []
    for latents in Z:
        if not latents:
            # No sample of this class: nothing to measure, avoid dividing by zero.
            dist.append(0.0)
            continue
        stacked = torch.stack(latents)       # one row per sample
        centroid = stacked.mean(dim=0)
        # Square each deviation BEFORE summing; summing first would let positive
        # and negative deviations cancel and would not measure a distance.
        sq_dist = (stacked - centroid).pow(2).sum(dim=1)
        dist.append(sq_dist.mean().item())

    # Plot distances
    plt.bar(class_names, dist)
    plt.ylabel("mean squared distance to class mean")
    plt.title("Latent spread per class")
    plt.show()

# Inference only: no autograd graph is needed while encoding the held-out split.
with torch.no_grad():
    plot_class_distances(
        encoder=model.encoder,
        n_classes=10,
        dataset=testDataset,
    )

Part 6. Unsupervised Classification

In [None]:
from sklearn.cluster import KMeans
import numpy as np

def unsupervised_classification(encoder: nn.Module, n_classes: int, dataset: Dataset, pie_labels: list[str]) -> None:
    """Cluster a 2-feature summary of the latent codes with k-means and compare to the true classes.

    Every sample is encoded and reduced to two scalars (the sums of rows 0 and 1
    of its squeezed latent map). K-means with `n_classes` clusters is fitted on
    these points, the clusters are drawn as a scatter plot, and a table is
    printed showing, for each cluster (column), what percentage of its members
    belongs to each class (row).

    Relies on the notebook-level name `device`.

    Args:
        encoder: module mapping one image tensor to its latent code.
        n_classes: number of clusters to fit.
        dataset: sized iterable of (image, label) pairs.
        pie_labels: class names; entry i labels the table row for class i.
    """
    # NOTE(review): KMeans takes an (n_samples, n_features) array with any number
    # of features; only the scatter plot needs two dimensions. The two-feature
    # reduction is kept to preserve existing results -- consider clustering on the
    # flattened latent code instead.
    n_samples = len(dataset)
    classes = np.empty(n_samples, dtype=int)     # true class of each sample
    V = np.empty((n_samples, 2), dtype=float)    # two features per sample
    for idx, (x, c) in enumerate(dataset):
        z = encoder(x.to(device)).cpu().squeeze()  # 3x32x32 image -> 8x8 latent map
        V[idx, 0] = z[0].sum().item()
        V[idx, 1] = z[1].sum().item()
        classes[idx] = c

    # Fit k-means with one cluster per class.
    kmeans = KMeans(n_clusters=n_classes, random_state=0)
    kmeans.fit(V)
    labels = kmeans.labels_              # assigned cluster for each sample
    centroids = kmeans.cluster_centers_  # cluster centres (not real datapoints)

    # Visualize the clusters.
    for k in range(n_classes):
        members = V[labels == k]
        plt.scatter(members[:, 0], members[:, 1], label=f"Cluster {k}")
    plt.scatter(centroids[:, 0], centroids[:, 1], c='red', marker='x', label="Centroids")
    plt.legend()
    plt.show()

    # counts[k][i] = number of samples of class i that landed in cluster k.
    cluster_sizes = [0] * n_classes
    counts = [[0] * len(pie_labels) for _ in range(n_classes)]
    for k, c in zip(labels, classes):
        cluster_sizes[k] += 1
        if 0 <= c < len(pie_labels):
            counts[k][c] += 1

    # Print cluster percentages: header row of cluster indices, then one row per class.
    print(f"{'':<10}",end='')
    for k in range(n_classes):
        print(f"{k:^9}",end='')
    print()

    for i, lbl in enumerate(pie_labels):
        print(f"{lbl:<10}",end='')
        for k in range(n_classes):
            size = cluster_sizes[k]
            # An empty cluster has no composition; report 0% instead of dividing by zero.
            share = 100 * counts[k][i] / size if size else 0.0
            per_classes = f"{share:.2f}%"
            print(f"{per_classes:^9}",end='')
        print()


# Inference only: cluster the latent codes of the held-out split without tracking gradients.
with torch.no_grad():
    unsupervised_classification(
        encoder=model.encoder,
        n_classes=10,
        dataset=testDataset,
        pie_labels=["airplane","car","bird","cat","deer","dog","frog","horse","ship","truck"],
    )