In [4]:
from scipy.optimize import linear_sum_assignment as linear_assignment
from torch.utils.data import DataLoader
from torchvision import transforms

import matplotlib.pyplot as plt
import numpy as np 

import torch.nn as nn
import torchvision
import torch

# Notebook-wide settings; tune these here rather than editing later cells.
BATCH_SIZE = 128          # batch_size handed to both DataLoaders
NUM_CLUSTER = 10          # presumably the number of k-means centroids -- confirm at the call site
LATENT_SIZE = 10          # presumably the encoder/decoder latent width -- confirm at the call site
DATASET_SHUFFLE = True    # shuffle flag handed to both DataLoaders
SEED = 0                  # seed for the global torch / numpy random generators

# Seed once, up front, so that Restart & Run All draws the same random numbers
# (e.g. any torch.rand initialisation) on every run.
torch.manual_seed(SEED);
np.random.seed(SEED)

In [5]:
# MNIST train and test splits, stored under ../data/ (downloaded when missing);
# transforms.ToTensor() is applied to every image.
trainset = torchvision.datasets.MNIST(
    root='../data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
testset = torchvision.datasets.MNIST(
    root='../data/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Both loaders share the batch size and shuffle flag from the config cell.
trainloader = DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=DATASET_SHUFFLE,
)
testloader = DataLoader(
    dataset=testset,
    batch_size=BATCH_SIZE,
    shuffle=DATASET_SHUFFLE,
)

In [6]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        # Keep dim 0 as-is and let view() infer the size of the merged remainder.
        return x.view(x.size(0), -1)

class DeFlatten(nn.Module):
    """Reshape a flat activation into ``k`` square feature maps.

    The input is viewed as ``(x.size(0), k, side, side)`` where ``side`` is the
    truncated square root of ``x.size(1) // k``.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, x):
        batch, width = x.size(0), x.size(1)
        # Spatial edge length implied by splitting `width` over k channels.
        side = int((width // self.k) ** 0.5)
        return x.view(batch, self.k, side, side)

In [7]:
class Kmeans(nn.Module):
    """K-means layer with learnable centroids.

    Holds ``num_cluster`` centroids of width ``latent_size`` as an
    ``nn.Parameter`` and assigns each input row to its nearest centroid
    (squared L2 distance).

    Args:
        num_cluster: Number of centroids.
        latent_size: Width of each centroid.
        device: Optional device for the centroids. When omitted, "mps" is used
            if this PyTorch build reports it as available, otherwise "cpu".
    """

    def __init__(self, num_cluster, latent_size, device=None):
        super(Kmeans, self).__init__()
        if device is None:
            # Older PyTorch builds have no torch.backends.mps at all.
            mps = getattr(torch.backends, "mps", None)
            has_mps = mps is not None and mps.is_available()
            device = torch.device("mps" if has_mps else "cpu")
        self.num_cluster = num_cluster
        # The shape tuple goes to torch.rand; the device is a keyword of the
        # factory function (a tuple has no .to()).
        self.centroids = nn.Parameter(
            torch.rand((self.num_cluster, latent_size), device=device)
        )

    def argminl2distance(self, a, b):
        """Return the index along dim 0 minimising ``sum((a - b) ** 2, dim=1)``."""
        return torch.argmin(
            torch.sum((a - b) ** 2, dim=1), dim=0
        )

    def forward(self, x):
        """Assign every row of ``x`` to its nearest centroid.

        Args:
            x: 2-D tensor with one sample per row, on the same device as the
                centroids.

        Returns:
            A tuple ``(y_assign, assigned)`` where ``y_assign`` is a list of
            int centroid indices (one per row of ``x``) and ``assigned`` is
            ``self.centroids`` indexed by those assignments.
        """
        # (B, 1, D) - (1, K, D) -> (B, K, D); summing the last dim gives the
        # squared distance of every sample to every centroid in one shot,
        # instead of one argmin + host sync (.item()) per sample.
        diff = x.unsqueeze(1) - self.centroids.unsqueeze(0)
        sq_dist = torch.sum(diff ** 2, dim=2)
        assign = torch.argmin(sq_dist, dim=1)
        return assign.tolist(), self.centroids[assign]

In [8]:
class Encoder(nn.Module):
    """Convolutional encoder mapping single-channel images to a latent vector.

    Three unpadded 3x3 convolutions (strides 2, 2, 1) are followed by a flatten
    and a linear projection to ``latent_size`` with a final ReLU. The linear
    layer expects 1024 input features, i.e. 64 channels x 4 x 4, which is what
    the convolutions yield for a 28x28 input (28 -> 13 -> 6 -> 4).

    Args:
        latent_size: Width of the latent vector produced by ``forward``.
    """

    def __init__(self, latent_size):
        super(Encoder, self).__init__()

        k = 16
        self.encoder = nn.Sequential(
            nn.Conv2d(1, k, 3, stride=2),           nn.ReLU(),
            nn.Conv2d(k, 2 * k, 3, stride=2),       nn.ReLU(),
            nn.Conv2d(2 * k, 4 * k, 3, stride=1),   nn.ReLU(),
            Flatten(),
            nn.Linear(1024, latent_size),           nn.ReLU()
        )

    def forward(self, x):
        """Encode ``x`` and return the ``(batch, latent_size)`` latent tensor."""
        # Return the encoder output itself; reshaping the *input* with a
        # non-existent self.k (as before) raised AttributeError on every call.
        return self.encoder(x)

class Decoder(nn.Module):
    """Transposed-convolution decoder turning a latent vector into an image.

    A linear layer widens the latent vector to 1024 features, DeFlatten views
    them as 64 square feature maps, and three unpadded 3x3 transposed
    convolutions (strides 1, 2, 2) reduce the channels to one. A final Sigmoid
    bounds the output to (0, 1).

    Args:
        latent_size: Width of the latent vector accepted by ``forward``.
    """

    def __init__(self, latent_size):
        super().__init__()
        base = 16  # channel multiplier shared by all stages
        stages = [
            nn.Linear(latent_size, 1024),
            nn.ReLU(),
            DeFlatten(4 * base),
            nn.ConvTranspose2d(4 * base, 2 * base, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * base, base, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(base, 1, kernel_size=3, stride=2, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the decoder stack and return the result."""
        reconstruction = self.decoder(x)
        return reconstruction