In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from icecream import ic
from typing import Literal



In [None]:
# Hyperparameters
batch_size = 128
latent_dim = 64
num_embeddings = 512
learning_rate = 1e-3
num_epochs = 10
patch_size = 7
seed = 42

# Seed torch's global RNG so weight/codebook init and DataLoader shuffling are reproducible on re-run.
torch.manual_seed(seed)

# Data Loading
# ToTensor() scales pixels to [0, 1], which is exactly the range of the decoder's sigmoid output.
# Do not add Normalize((0.5,), (0.5,)) here: it maps inputs to [-1, 1], half of which a
# sigmoid-output decoder can never reproduce under the MSE reconstruction loss.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
device

In [None]:
# Grab one batch to inspect shapes and to drive the small sanity checks below.
first_batch = next(iter(train_loader))
tr_x, tr_y = first_batch
tr_x.shape

## Patching logic

In [None]:
x = tr_x[:1]
# unfold(dim, size, step) with size == step cuts `dim` into non-overlapping windows of `size`.
u_x = x.unfold(2, patch_size, patch_size)
print(x.shape, "->", u_x.shape, "4x7(patch_size) = 28")
# Unfolding the width as well gives a (rows, cols) grid of patch_size x patch_size patches;
# this is the shape the Encoder works with, so make it the cell's displayed result.
patch_grid = u_x.unfold(3, patch_size, patch_size)
patch_grid.shape

In [None]:
# Create a 6x6 tensor with values from 0 to 35, shaped (1, 1, 6, 6) like a one-image, single-channel batch
t_size = 6
t = torch.arange(0, t_size * t_size).view(t_size, t_size).unsqueeze(0).unsqueeze(0)
t_bs = 1
print(t)
t_patch_size = 3
# Tensor.unfold(dimension, size, step)
# Returns a view of the original tensor which contains all slices of size `size` from self tensor in the tensor dimension.
# `step` acts like a stride; size == step gives non-overlapping windows.
ut = t.unfold(2, t_patch_size, t_patch_size)
print(ut)
ut = ut.unfold(3, t_patch_size, t_patch_size)
print(ut)
utv = ut.contiguous().view(t_bs, 1, -1, t_patch_size, t_patch_size)
bputv = utv.view(-1, 1, t_patch_size, t_patch_size)
# It should create 4 examples of 3x3 patches.
n_t_patches_per_side = t_size // t_patch_size
assert bputv.shape == (n_t_patches_per_side * n_t_patches_per_side, 1, t_patch_size, t_patch_size)
# Patches come out row-major over the patch grid: patch k covers grid cell (k // n, k % n).
for k in range(bputv.size(0)):
    grid_row, grid_col = divmod(k, n_t_patches_per_side)
    rows = slice(grid_row * t_patch_size, (grid_row + 1) * t_patch_size)
    cols = slice(grid_col * t_patch_size, (grid_col + 1) * t_patch_size)
    assert torch.equal(bputv[k, 0], t[0, 0, rows, cols])



In [None]:
class Encoder(nn.Module):
    """Split single-channel images into non-overlapping square patches and embed each patch.

    Args:
        latent_dim: size of the latent vector produced for every patch.
        patch_size: side length of the square patches; the input height and
            width must be multiples of it.

    Input:  (B, 1, H, W) tensor.
    Output: (B * (H // patch_size) * (W // patch_size), latent_dim) tensor with one
        row per patch, ordered image by image and, within an image, row-major over
        the patch grid.

    Raises:
        ValueError: if the input has more than one channel or H/W is not a
            multiple of ``patch_size``.
    """

    def __init__(self, latent_dim: int, patch_size: int):
        super().__init__()
        self.patch_size = patch_size
        self.latent_dim = latent_dim
        # kernel 3 / stride 1 / padding 1 keeps every patch at patch_size x patch_size,
        # which the fc input size below relies on.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(64 * patch_size * patch_size, latent_dim)

    def forward(self, x):
        _, channels, height, width = x.size()
        p = self.patch_size
        if channels != 1:
            raise ValueError(f"Encoder expects single-channel input, got {channels} channels")
        if height % p or width % p:
            # unfold would otherwise silently drop the trailing rows/columns.
            raise ValueError(f"Image size {height}x{width} is not divisible by patch_size={p}")
        # (B, 1, H, W) -> (B, 1, H//p, W//p, p, p)
        x = x.unfold(2, p, p).unfold(3, p, p)
        # Each patch of each image becomes its own single-channel sample: (B * n_patches, 1, p, p).
        x = x.contiguous().view(-1, 1, p, p)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.flatten(start_dim=1)
        return self.fc(x)

class Decoder(nn.Module):
    """Reconstruct one single-channel patch from each latent vector.

    Args:
        latent_dim: size of each input latent vector.
        patch_size: side length of the output patches. Optional so existing
            ``Decoder(latent_dim)`` calls keep working; when omitted, the
            notebook-level ``patch_size`` variable is read once here, as the
            original implementation did.

    Input:  (N, latent_dim) tensor.
    Output: (N, 1, patch_size, patch_size) tensor squashed into [0, 1] by a sigmoid.
    """

    def __init__(self, latent_dim: int, patch_size: "int | None" = None):
        super().__init__()
        if patch_size is None:
            # Backward compatibility with the original signature, which relied on the global.
            patch_size = globals()["patch_size"]
        self.patch_size = patch_size
        self.latent_dim = latent_dim
        self.fc = nn.Linear(latent_dim, 64 * patch_size * patch_size)
        self.conv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        p = self.patch_size
        x = self.fc(x)
        x = x.view(x.size(0), 64, p, p)
        x = F.relu(self.conv1(x))
        return torch.sigmoid(self.conv2(x))

In [None]:
# check encode/decode logic works
sm_latent_dim = 2
img_side_size = tr_x.size(-1)  # 28 for MNIST
# Integer division: the number of patches is a count, not a float like 16.0.
nb_patches_per_img = (img_side_size // patch_size) ** 2
print("nb_patches_per_img:", nb_patches_per_img)

enc = Encoder(sm_latent_dim, patch_size=patch_size)
dec = Decoder(sm_latent_dim)
sm_x = tr_x[:batch_size]
# Keep the actual batch size under its own name rather than overwriting the `batch_size` hyperparameter.
sm_batch_size = sm_x.size(0)
print("batch_size:", sm_batch_size)
print("sm_x:", sm_x.shape)
sm_enc_x = enc(sm_x)
print("enc_x:", sm_enc_x.shape)
assert sm_enc_x.shape == (nb_patches_per_img * sm_batch_size, sm_latent_dim)
sm_dec_x = dec(sm_enc_x)
assert sm_dec_x.shape == (nb_patches_per_img * sm_batch_size, 1, patch_size, patch_size), "We should reconstruct each patch based on its latent representation"
print("dec_x:", sm_dec_x.shape)

## Vector Quantizer:
Encoder maps x -> z_e (latent space)

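The vector quantizer replaces each latent vector with its nearest codebook vector. For an encoder output $z_e$ it selects the index $k = \arg\min_j \lVert z_e - e_j \rVert_2$ of the closest embedding $e_j$ (equivalently, a one-hot vector over the codebook) and outputs $z_q = e_k$, i.e. it maps $z_e \to z_q$.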

Decoder maps z_q -> x_hat to reconstruct the original image.

In [None]:
# Walk through vector quantization by hand on the small encoder output `sm_enc_x`.
sm_num_embeddings = 5
# Codebook: sm_num_embeddings vectors, each of size sm_latent_dim.
embedding = nn.Embedding(sm_num_embeddings, sm_latent_dim)
ic(embedding.weight.data)
# Overwrite nn.Embedding's default init with a small uniform range [-1/K, 1/K] (K = codebook size).
embedding.weight.data.uniform_(-1/sm_num_embeddings, 1/sm_num_embeddings)
ic(embedding.weight.data)

z = sm_enc_x
ic(z.shape)
# One row per latent vector.
z_flattened = z.view(-1, sm_latent_dim)
ic(z_flattened.shape)
# Pairwise Euclidean distances: (number of latent vectors, sm_num_embeddings).
distances = torch.cdist(z_flattened, embedding.weight)
ic(distances.shape)

# Index of the nearest codebook vector for each latent, kept as a column: (N, 1).
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
ic(encoding_indices.shape)

# Look up the chosen codebook vectors and restore z's shape: this is z_q.
z_q = embedding(encoding_indices).view(z.shape)
# Spot-check a few rows: each z_q row is the codebook row at its encoding index.
ic(z_flattened[3:6])
ic(encoding_indices[3:6])
ic(z_q[3:6])


In [None]:
# Vector Quantizer
# wrap the logic into a nn.Module


class VectorQuantizer(nn.Module):
    """Map each latent vector to its nearest codebook entry.

    Args:
        num_embeddings: number of codebook vectors (K).
        embedding_dim: size of each codebook vector; the last dimension of the
            input ``z`` must equal it.
        dist_type: ``'euclidean'`` picks the entry with the smallest L2 distance,
            ``'cosine'`` the entry with the largest cosine similarity.

    Raises:
        ValueError: if ``dist_type`` is not one of the supported values.
    """

    _DIST_TYPES = ('cosine', 'euclidean')

    def __init__(self, num_embeddings: int, embedding_dim: int, dist_type: Literal['cosine', 'euclidean'] = 'euclidean'):
        super().__init__()
        if dist_type not in self._DIST_TYPES:
            # Fail loudly instead of silently falling back to euclidean on a typo.
            raise ValueError(f"dist_type must be one of {self._DIST_TYPES}, got {dist_type!r}")
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Small symmetric init in [-1/K, 1/K].
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)
        self.dist_type = dist_type

    def _flatten(self, z):
        """Collapse all leading dimensions so ``z`` becomes (N, embedding_dim)."""
        return z.reshape(-1, self.embedding_dim)

    def get_cosine_similarity_indices(self, z):
        """Return (N, 1) indices of the codebook entries most cosine-similar to each latent vector."""
        # Normalize latents and codebook to unit length so the dot product is the cosine similarity.
        z_normalized = F.normalize(self._flatten(z), dim=-1)
        embedding_normalized = F.normalize(self.embedding.weight, dim=-1)
        similarity = z_normalized @ embedding_normalized.t()  # (N, K)
        return torch.argmax(similarity, dim=1).unsqueeze(1)

    def get_euclid_dist_indices(self, z):
        """Return (N, 1) indices of the codebook entries closest (L2) to each latent vector."""
        distances = torch.cdist(self._flatten(z), self.embedding.weight)  # (N, K)
        return torch.argmin(distances, dim=1).unsqueeze(1)

    def forward(self, z):
        """Quantize ``z``.

        Returns:
            z_q: codebook vectors, same shape as ``z``.
            encoding_indices: (N, 1) tensor of chosen codebook indices, where N is
                the number of latent vectors in ``z`` (all leading dims flattened),
                for both distance types.
        """
        if self.dist_type == 'cosine':
            encoding_indices = self.get_cosine_similarity_indices(z)
        else:
            encoding_indices = self.get_euclid_dist_indices(z)
        z_q = self.embedding(encoding_indices).view(z.shape)
        return z_q, encoding_indices


In [None]:

vq = VectorQuantizer(num_embeddings, sm_latent_dim, dist_type='euclidean')
sm_z_q, sm_z_q_idx = vq(sm_enc_x)
assert sm_z_q.shape == sm_enc_x.shape
# One code index per latent vector, each a valid row of the codebook ...
assert sm_z_q_idx.shape == (sm_enc_x.size(0), 1)
assert bool(((sm_z_q_idx >= 0) & (sm_z_q_idx < num_embeddings)).all())
# ... and every quantized vector is exactly the codebook row it was assigned to.
assert torch.equal(sm_z_q, vq.embedding.weight[sm_z_q_idx.squeeze(1)])
sm_z_q[:4], sm_z_q_idx[:4]

In [None]:
# VQ-VAE Model
class VQVAE(nn.Module):
    """Patch-based VQ-VAE: encode patches, quantize each latent, decode, and reassemble the image.

    Args:
        latent_dim: size of each patch latent / codebook vector.
        num_embeddings: codebook size.
        patch_size: side length of the square patches; input H and W must be multiples of it.
        dist_type: distance used by the quantizer ('euclidean' or 'cosine').

    ``forward(x)`` takes a (B, 1, H, W) batch and returns ``(x_recon, z, z_q)``:
        x_recon: (B, 1, H, W) reconstruction, same shape as ``x`` so it can be compared
            with the input directly (e.g. ``F.mse_loss(x_recon, x)``).
        z: (B * n_patches, latent_dim) encoder outputs.
        z_q: quantized latents, same shape as ``z`` and NOT detached, so a codebook
            loss such as ``F.mse_loss(z_q, z.detach())`` still updates the codebook.

    Raises:
        ValueError: at construction if the decoder's patch size disagrees with the
            encoder's; in ``forward`` if H or W is not a multiple of ``patch_size``.
    """

    def __init__(self, latent_dim: int, num_embeddings: int, patch_size: int, dist_type: Literal['cosine', 'euclidean'] = 'euclidean'):
        super().__init__()
        self.patch_size = patch_size
        self.encoder = Encoder(latent_dim, patch_size=patch_size)
        self.decoder = Decoder(latent_dim)
        self.vector_quantizer = VectorQuantizer(num_embeddings, latent_dim, dist_type=dist_type)
        # Decoder(latent_dim) may take its patch size from elsewhere; both fc layers size
        # themselves from the patch size, so a mismatch shows up here.
        if self.decoder.fc.out_features != self.encoder.fc.in_features:
            raise ValueError("Decoder patch size does not match the encoder's patch_size")

    def forward(self, x):
        batch, _, height, width = x.shape
        p = self.patch_size
        if height % p or width % p:
            raise ValueError(f"Image size {height}x{width} is not divisible by patch_size={p}")
        z = self.encoder(x)
        z_q, _ = self.vector_quantizer(z)
        # Straight-through estimator: the decoder sees z_q in the forward pass, but the
        # reconstruction gradient is copied onto z. Without it, the argmin/argmax inside
        # the quantizer blocks every gradient to the encoder.
        z_st = z + (z_q - z).detach()
        patches = self.decoder(z_st)
        x_recon = self._fold_patches(patches, batch, height // p, width // p)
        return x_recon, z, z_q

    @staticmethod
    def _fold_patches(patches, batch_size, n_rows, n_cols):
        """Reassemble (B * n_rows * n_cols, 1, p, p) patches, ordered as the Encoder emits
        them (image by image, row-major over the patch grid), into (B, 1, n_rows*p, n_cols*p)."""
        p = patches.size(-1)
        grid = patches.reshape(batch_size, n_rows, n_cols, p, p)
        # (B, rows, cols, p, p) -> (B, rows, p, cols, p) so pixel rows of neighbouring patches line up.
        return grid.permute(0, 1, 3, 2, 4).reshape(batch_size, 1, n_rows * p, n_cols * p)

In [None]:
# Smoke-test the full model on the small batch with the tiny latent/codebook sizes and print the output shapes.
# NOTE(review): training compares x_recon with the input via F.mse_loss, so x_recon.shape should equal
# sm_x.shape — check the printed shape.
vqvae = VQVAE(latent_dim=sm_latent_dim, num_embeddings=sm_num_embeddings, patch_size=patch_size)
x_recon, z, z_q = vqvae(sm_x)
ic(x_recon.shape, z.shape, z_q.shape)

In [None]:
# Training the Model
model = VQVAE(latent_dim, num_embeddings, patch_size=patch_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def train(model, data_loader, optimizer, num_epochs, beta: float = 0.25):
    """Train a VQ-VAE with the standard three-term VQ-VAE objective.

    loss = recon + codebook + beta * commitment, where
      - recon      = MSE(x_recon, x)
      - codebook   = MSE(z_q, sg[z])  -> moves the selected codes toward encoder outputs
      - commitment = MSE(z, sg[z_q])  -> keeps encoder outputs close to their codes
    (sg = stop-gradient, i.e. ``.detach()``).

    Args:
        model: module returning ``(x_recon, z, z_q)`` for an input batch.
        data_loader: iterable of ``(x, label)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``data_loader``.
        beta: weight of the commitment term.

    Returns:
        List with the mean loss of each epoch.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # Move batches to wherever the model lives rather than relying on a global device.
    model_device = next(model.parameters()).device
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss, n_batches = 0.0, 0
        for x, _ in data_loader:
            x = x.to(model_device)
            optimizer.zero_grad()
            x_recon, z, z_q = model(x)
            recon_loss = F.mse_loss(x_recon, x)
            codebook_loss = F.mse_loss(z_q, z.detach())
            commitment_loss = F.mse_loss(z, z_q.detach())
            loss = recon_loss + codebook_loss + beta * commitment_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError("data_loader yielded no batches")
        mean_loss = total_loss / n_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")
    return epoch_losses

epoch_losses = train(model, train_loader, optimizer, num_epochs)

# Visualization of reconstructed images
def show_reconstructed(model, data_loader, n_images: int = 1):
    """Plot the first ``n_images`` inputs of one batch next to the model's reconstructions.

    Args:
        model: module returning ``(x_recon, z, z_q)``; ``x_recon`` is expected to have
            the same (B, 1, H, W) layout as the input.
        data_loader: iterable of ``(x, label)`` batches; only the first batch is used.
        n_images: number of rows (original / reconstructed pairs) to draw.
    """
    model_device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        x, _ = next(iter(data_loader))
        x = x.to(model_device)
        x_recon, _, _ = model(x)
    x = x.cpu().numpy()
    x_recon = x_recon.cpu().numpy()
    n_images = max(1, min(n_images, x.shape[0]))
    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_images, 2, squeeze=False)
    for row in range(n_images):
        axes[row, 0].imshow(x[row][0], cmap='gray')
        axes[row, 0].set_title('Original')
        axes[row, 1].imshow(x_recon[row][0], cmap='gray')
        axes[row, 1].set_title('Reconstructed')
        for ax in axes[row]:
            ax.axis('off')
    plt.show()

show_reconstructed(model, train_loader)