In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from scipy.optimize import linear_sum_assignment as linear_assignment
from torch.utils.data import DataLoader
from torchvision import transforms

# Prefer Apple-silicon MPS (the original target of this notebook), then CUDA,
# then CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Seed torch's global RNG so parameter/centroid initialisation and the
# DataLoader's default shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 128
NUM_CLUSTER = 10   # number of k-means clusters (MNIST has 10 digit classes)
LATENT_SIZE = 10   # dimensionality of the encoder output / centroids
DATASET_SHUFFLE = True

  Referenced from: <F0D48035-EF9E-3141-9F63-566920E60D7C> /Users/bahk_insung/miniconda3/lib/python3.10/site-packages/torchvision/image.so
  Expected in:     <44B645FB-F027-3EE5-86D7-DBF8E2FC6264> /Users/bahk_insung/miniconda3/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [2]:
# MNIST train/test splits, converted with transforms.ToTensor(); downloaded to ../data/ if missing.
trainset = torchvision.datasets.MNIST('../data/', download=True, train=True, transform=transforms.ToTensor())
testset = torchvision.datasets.MNIST('../data/', download=True, train=False, transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=DATASET_SHUFFLE)
# The test set is only scored, never trained on: shuffling it buys nothing and
# makes per-sample outputs (e.g. collected latent features) change order between runs.
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [3]:
class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) dimension into one."""

    def forward(self, x):
        return x.view(x.shape[0], -1)

class DeFlatten(nn.Module):
    """Inverse of ``Flatten`` for square feature maps.

    Reshapes a 2-D tensor ``(batch, k * s * s)`` into ``(batch, k, s, s)``,
    where ``s`` is recovered as the integer square root of ``features // k``.

    Args:
        k: number of output channels.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, x):
        batch, flat = x.size(0), x.size(1)
        side = int((flat // self.k) ** 0.5)
        return x.view(batch, self.k, side, side)

In [4]:
class Kmeans(nn.Module):
    """Learnable k-means layer: assigns each latent vector to its nearest centroid.

    The centroids are an ``nn.Parameter``, so an optimizer over this module's
    parameters moves them (e.g. through a latent-to-centroid loss).

    Args:
        num_cluster: number of centroids.
        latent_size: dimensionality of each centroid / input vector.
    """

    def __init__(self, num_cluster, latent_size):
        super(Kmeans, self).__init__()
        self.num_cluster = num_cluster
        # Created on CPU; move the module with `.to(device)`. Hard-coding "mps"
        # here made the class unusable on machines without Apple MPS.
        self.centroids = nn.Parameter(torch.rand((self.num_cluster, latent_size)))

    def argminl2distance(self, a, b):
        """Return the row index where ``sum((a - b) ** 2, dim=1)`` is smallest."""
        return torch.argmin(
            torch.sum((a - b) ** 2, dim=1), dim=0
        )

    def forward(self, x):
        """Assign every row of ``x`` to its nearest centroid (squared L2).

        Args:
            x: 2-D tensor ``(batch, latent_size)`` on the same device as the centroids.

        Returns:
            ``(assignments, centroids)``. ``assignments`` is a Python list of
            ``batch`` cluster indices. ``centroids`` is a ``(batch, latent_size)``
            tensor holding the assigned centroid for each row; it carries gradients
            to ``self.centroids``. Ties go to the lowest centroid index.
        """
        # (batch, 1, latent) - (1, num_cluster, latent) -> (batch, num_cluster)
        distances = torch.sum((x.unsqueeze(1) - self.centroids.unsqueeze(0)) ** 2, dim=2)
        assign = torch.argmin(distances, dim=1)
        # One host transfer per batch instead of one `.item()` sync per sample.
        return assign.tolist(), self.centroids[assign]

In [5]:
class Encoder(nn.Module):
    """Convolutional encoder mapping single-channel images to a latent vector.

    Three stride-2/2/1 3x3 convolutions with ReLU, then flatten and a Linear
    projection followed by ReLU, so every latent coordinate is non-negative.
    The Linear layer's input size assumes 28x28 inputs, as in this notebook's
    MNIST data.

    Args:
        latent_size: dimensionality of the output vector.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.k = 16
        k = self.k
        # Spatial size for a 28x28 input: 28 -> 13 -> 6 -> 4, so the last conv
        # emits 4*k channels of 4x4 = 1024 features.
        flat_features = 4 * k * 4 * 4
        self.encoder = nn.Sequential(
            nn.Conv2d(1, k, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(k, 2 * k, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(2 * k, 4 * k, 3, stride=1),
            nn.ReLU(),
            Flatten(),
            nn.Linear(flat_features, latent_size),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.encoder(x)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector back to a 1-channel image.

    A Linear layer expands to 1024 features, which ``DeFlatten(4*k)`` reshapes
    to ``(4*k, 4, 4)``. Three transposed convolutions then upsample 4 -> 6 -> 13 -> 28,
    and a final sigmoid squashes pixels into (0, 1).

    Args:
        latent_size: dimensionality of the input vector.
    """

    def __init__(self, latent_size):
        super().__init__()
        k = 16
        layers = [
            nn.Linear(latent_size, 1024),
            nn.ReLU(),
            DeFlatten(4 * k),
            nn.ConvTranspose2d(4 * k, 2 * k, 3, stride=1),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * k, k, 3, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(k, 1, 3, stride=2, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)

In [6]:
def cluster_acc(y_true, y_pred):
    """Unsupervised clustering accuracy.

    Finds the one-to-one mapping between predicted cluster ids and true labels
    that maximises the number of agreeing samples (Hungarian algorithm via
    ``linear_assignment``). Returns the fraction of samples that agree under
    that mapping.

    Args:
        y_true: sequence of non-negative integer ground-truth labels.
        y_pred: sequence of non-negative integer cluster ids, same length as y_true.

    Returns:
        float in [0, 1].

    Raises:
        ValueError: if the inputs are empty or have different lengths.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    if y_pred.size != y_true.size:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {y_true.size} and {y_pred.size}"
        )
    if y_pred.size == 0:
        raise ValueError("cluster_acc needs at least one sample")

    D = max(y_pred.max(), y_true.max()) + 1
    # w[p, t] = number of samples with predicted cluster p and true label t.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_pred, y_true), 1)

    # linear_assignment minimises total cost, so turn counts into costs.
    row_ind, col_ind = linear_assignment(w.max() - w)
    return float(w[row_ind, col_ind].sum()) / y_pred.size

In [7]:
def evaluation(testloader, encoder, kmeans, device):
    """Clustering accuracy of ``encoder`` + ``kmeans`` over a labelled loader.

    Args:
        testloader: iterable of ``(images, labels)`` batches.
        encoder: module mapping images to latent vectors.
        kmeans: module returning ``(assignments_list, centroids)`` for latents.
        device: device the images are moved to before encoding.

    Returns:
        The value of ``cluster_acc(actual_labels, predicted_clusters)``.
    """
    predictions, actual = [], []
    with torch.no_grad():
        for images, labels in testloader:
            latent_var = encoder(images.to(device))
            y_pred, _ = kmeans(latent_var)

            predictions.extend(y_pred)
            # Labels are only used on the host for scoring, so skip the
            # needless host -> device -> host round trip.
            actual.extend(labels.tolist())

    return cluster_acc(actual, predictions)

In [8]:
# Build the three trainable components on the selected device.
encoder = Encoder(latent_size=LATENT_SIZE).to(device)
decoder = Decoder(latent_size=LATENT_SIZE).to(device)
kmeans = Kmeans(num_cluster=NUM_CLUSTER, latent_size=LATENT_SIZE).to(device)

# criterion1 scores reconstruction; criterion2 scores latent-to-centroid distance.
criterion1, criterion2 = torch.nn.MSELoss(), torch.nn.MSELoss()

# One Adam optimizer updates encoder, decoder and centroids jointly.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *decoder.parameters(), *kmeans.parameters()],
    lr=1e-3,
)

In [9]:
# Schedule for alpha, the weight of the clustering loss in the training loop:
# small constant up to epoch T1, linear ramp until T2, then fixed at lam.
T1, T2 = 50, 200
lam = 1e-3
# Checkpoint threshold: models are saved only when an epoch's average loss
# falls below `ls`, which then tracks the best loss seen so far.
ls = 0.05

In [10]:
NUM_EPOCHS = 300

# torch.save does not create missing directories; without this the first
# checkpoint raises FileNotFoundError on a fresh checkout.
os.makedirs('./models', exist_ok=True)

for ep in range(NUM_EPOCHS):
    # Annealing schedule for the clustering-loss weight alpha:
    #   ep <= T1      : small constant lam / (T2 - T1)
    #   T1 < ep < T2  : linear ramp lam * (ep - T1) / (T2 - T1)
    #   ep >= T2      : lam
    if T1 < ep < T2:
        alpha = lam * (ep - T1) / (T2 - T1)
    elif ep >= T2:
        alpha = lam
    else:
        alpha = lam / (T2 - T1)

    running_loss = 0.0
    for images, _ in trainloader:
        inputs = images.to(device)
        optimizer.zero_grad()
        latent_var = encoder(inputs)
        # Cluster assignment uses detached latents, but the returned centroids
        # are indexed from the centroid Parameter and still receive gradients.
        _, centroids = kmeans(latent_var.detach())
        outputs = decoder(latent_var)

        # Loss modules take (prediction, target).
        l_rec = criterion1(outputs, inputs)
        l_clt = criterion2(latent_var, centroids)
        loss = l_rec + alpha * l_clt
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_loss = running_loss / len(trainloader)
    print(f"Epoch : {ep}\tAverage loss : {avg_loss}")

    # NOTE(review): avg_loss includes alpha * l_clt and alpha grows with the
    # epoch, so losses from different schedule phases are not strictly comparable.
    if avg_loss < ls:
        ls = avg_loss
        torch.save(encoder.state_dict(), './models/dkm_en.pth')
        torch.save(decoder.state_dict(), './models/dkm_de.pth')
        torch.save(kmeans.state_dict(),  './models/dkm_clt.pth')
        print(f"\tEpoch {ep} model has saved!")

Epoch : 0	Average loss : 0.07459391468464693
Epoch : 1	Average loss : 0.04309159790529117
	Epoch 1 model has saved!
Epoch : 2	Average loss : 0.03626011261569539
	Epoch 2 model has saved!
Epoch : 3	Average loss : 0.03220169237459392
	Epoch 3 model has saved!
Epoch : 4	Average loss : 0.02957155457787168
	Epoch 4 model has saved!
Epoch : 5	Average loss : 0.027734393587530548
	Epoch 5 model has saved!


In [None]:
# Reload the checkpoint written by the training loop (lowest average training
# loss seen). map_location lets a checkpoint saved on one device (e.g. MPS)
# load onto whatever `device` is now.
encoder.load_state_dict(torch.load('./models/dkm_en.pth', map_location=device))
decoder.load_state_dict(torch.load('./models/dkm_de.pth', map_location=device))
kmeans.load_state_dict(torch.load('./models/dkm_clt.pth', map_location=device))

predictions, actual, latent_features = [], [], []
with torch.no_grad():
    for images, labels in testloader:
        inputs = images.to(device)
        latent_var = encoder(inputs)
        y_pred, _ = kmeans(latent_var)

        predictions += y_pred
        latent_features += latent_var.cpu().tolist()
        # Labels are only needed on the host; no device round trip.
        actual += labels.tolist()

# Backward-compatible alias for the previously misspelled name, in case later cells use it.
predicitions = predictions

print(cluster_acc(actual, predictions))