<a href="https://colab.research.google.com/github/harryypham/MyMLPractice/blob/main/VQVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

%matplotlib inline

In [18]:
class ResidualBlock(nn.Module):
  """Residual unit computing ``Conv1x1(ReLU(Conv3x3(ReLU(x)))) + x``.

  A ReLU is applied before each convolution. The 3x3 convolution uses
  stride 1 and padding 1, and the 1x1 convolution maps back to ``in_dims``
  channels, so the branch output has the same shape as the input and can be
  added to it element-wise.

  Args:
    in_dims: Number of channels of the input (and of the output).
    hidden_dims: Number of channels between the two convolutions.
  """

  def __init__(self, in_dims, hidden_dims):
    super().__init__()
    # Build the convolutions in the same order they sit in the Sequential
    # so parameter initialisation consumes the RNG in the same sequence.
    conv3x3 = nn.Conv2d(in_dims, hidden_dims, kernel_size=3, stride=1, padding=1)
    conv1x1 = nn.Conv2d(hidden_dims, in_dims, kernel_size=1)
    self.block = nn.Sequential(nn.ReLU(), conv3x3, nn.ReLU(), conv1x1)

  def forward(self, x):
    """Return the residual branch output added to the unchanged input."""
    branch = self.block(x)
    return branch + x

class ResidualStack(nn.Module):
  """Sequential chain of ``num_layers`` ``ResidualBlock`` modules.

  Every block is built with the same ``in_dims`` / ``hidden_dims``, so the
  tensor shape is unchanged from input to output.

  Args:
    in_dims: Channel count passed to every block as its ``in_dims``.
    hidden_dims: Channel count passed to every block as its ``hidden_dims``.
    num_layers: How many blocks to chain; 0 yields an empty ``nn.Sequential``.
  """

  def __init__(self, in_dims, hidden_dims, num_layers):
    super().__init__()
    self.stack = self._make_layers(in_dims, hidden_dims, num_layers)

  def _make_layers(self, in_dims, hidden_dims, num_layers):
    """Build the blocks in order and wrap them in an ``nn.Sequential``."""
    blocks = (ResidualBlock(in_dims, hidden_dims) for _ in range(num_layers))
    return nn.Sequential(*blocks)

  def forward(self, x):
    """Run ``x`` through every block in order."""
    return self.stack(x)

class Encoder(nn.Module):
  """Convolutional encoder: two stride-2 convolutions followed by a residual stack.

  Each 4x4 / stride-2 / padding-1 convolution halves the spatial size for
  even inputs, so a 32x32 input comes out as 8x8 with ``hidden_dims``
  channels.

  Args:
    in_dims: Number of channels of the input image.
    hidden_dims: Number of channels of the encoder output. The first
      convolution produces ``hidden_dims // 2`` channels.
    res_hidden_dims: Inner channel count of each residual block. Defaults to
      256, the value that was previously hard-coded.
    num_res_layers: Number of residual blocks. Defaults to 2, the value that
      was previously hard-coded.
  """

  def __init__(self, in_dims, hidden_dims, res_hidden_dims=256, num_res_layers=2):
    super(Encoder, self).__init__()
    # Downsampling path: in_dims -> hidden_dims // 2 -> hidden_dims.
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_dims, hidden_dims//2, 4, 2, 1),
        nn.ReLU(),
        nn.Conv2d(hidden_dims//2, hidden_dims, 4, 2, 1),
        nn.ReLU()
    )
    # Shape-preserving refinement at the downsampled resolution.
    self.layer2 = ResidualStack(hidden_dims, res_hidden_dims, num_res_layers)

  def forward(self, x):
    """Downsample ``x`` and pass the result through the residual stack."""
    x = self.layer1(x)
    x = self.layer2(x)
    return x

class Decoder(nn.Module):
  """Convolutional decoder: a residual stack followed by two stride-2 transposed convolutions.

  Each 4x4 / stride-2 / padding-1 transposed convolution doubles the spatial
  size, so an 8x8 input comes out as 32x32 with ``out_dims`` channels. No
  activation is applied after the last layer.

  Args:
    in_dims: Number of channels of the latent input.
    out_dims: Number of channels of the reconstructed output.
    res_hidden_dims: Inner channel count of each residual block. Defaults to
      256, the value that was previously hard-coded.
    num_res_layers: Number of residual blocks. Defaults to 2, the value that
      was previously hard-coded.
    mid_dims: Channel count between the two transposed convolutions.
      Defaults to 32, the value that was previously hard-coded.
  """

  def __init__(self, in_dims, out_dims, res_hidden_dims=256, num_res_layers=2,
               mid_dims=32):
    super(Decoder, self).__init__()
    # Shape-preserving refinement at the latent resolution.
    self.layer1 = ResidualStack(in_dims, res_hidden_dims, num_res_layers)
    # Upsampling path: in_dims -> mid_dims -> out_dims.
    self.layer2 = nn.Sequential(
        nn.ConvTranspose2d(in_dims, mid_dims, 4, 2, 1),
        nn.ReLU(),
        nn.ConvTranspose2d(mid_dims, out_dims, 4, 2, 1)
    )

  def forward(self, x):
    """Refine ``x`` with the residual stack, then upsample it."""
    x = self.layer1(x)
    x = self.layer2(x)
    return x



In [None]:
class VQEmbedding(nn.Module):
  """Vector-quantization codebook that snaps each spatial feature vector to its nearest code.

  Args:
    emb_size: Number of entries in the codebook.
    emb_dims: Dimensionality of each entry; must equal the channel count of
      the tensors passed to ``forward``.
  """

  def __init__(self, emb_size, emb_dims):
    # nn.Module has to be initialised before a sub-module can be assigned;
    # without this call the nn.Embedding assignment below raises.
    super(VQEmbedding, self).__init__()
    self.embedding = nn.Embedding(emb_size, emb_dims)
    self.embedding.weight.data.uniform_(-1./emb_size, 1./emb_size)

  def forward(self, x):
    """Replace every spatial position of ``x`` with its nearest codebook entry.

    Args:
      x: Tensor of shape ``(B, emb_dims, H, W)``.

    Returns:
      Tensor of shape ``(B, emb_dims, H, W)`` in which the channel vector at
      each position is the codebook row with the smallest squared Euclidean
      distance to the corresponding vector of ``x``.

    Note:
      The nearest-code search is not differentiable, so the result carries
      gradients to the codebook only. Any straight-through estimator or
      auxiliary loss is left to the caller.
    """
    codebook = self.embedding.weight
    permuted_x = x.permute(0, 2, 3, 1)
    # reshape rather than view: the permuted tensor is non-contiguous, and
    # view() rejects it whenever both the batch and channel sizes exceed 1.
    permuted_x_flatten = permuted_x.reshape(-1, self.embedding.embedding_dim)

    # argmin has no gradient, so skip building an autograd graph for the
    # (N, emb_size) distance matrix.
    with torch.no_grad():
      # ||x||^2 - 2 x.e + ||e||^2, broadcast to one row per position and
      # one column per codebook entry.
      distances = (
          (permuted_x_flatten ** 2).sum(1, keepdim=True)
          - 2 * permuted_x_flatten @ codebook.t()
          + (codebook ** 2).sum(1).unsqueeze(0)
      )
      encoding_indices = distances.argmin(1)

    # Look the rows up outside no_grad so the codebook still receives grads.
    quantized = self.embedding(encoding_indices).view(permuted_x.shape)
    # Back to channels-first so the result can be fed to a conv decoder.
    return quantized.permute(0, 3, 1, 2).contiguous()

In [19]:
# Smoke test: push one random 1-channel 32x32 image through the encoder and
# the decoder and print the tensor shape after each stage.
x = torch.rand((1, 1, 32, 32))
encode = Encoder(1, 64)
out = encode(x)
# Two stride-2 convolutions: 32x32 -> 8x8, with 64 channels.
print(out.shape)
# NOTE(review): the decoder emits 3 channels although the input has 1 —
# confirm this mismatch is intentional for the shape check.
decode = Decoder(64, 3)
out = decode(out)
# Two stride-2 transposed convolutions: 8x8 -> 32x32.
print(out.shape)

torch.Size([1, 64, 8, 8])
torch.Size([1, 3, 32, 32])


In [39]:
# Stand-in for an encoder output: batch of 16, 64 channels, 8x8 grid.
x = torch.rand((16, 64, 8, 8))
# Move channels last and flatten so that every row is the 64-dim feature
# vector of one spatial position (16 * 8 * 8 = 1024 rows).
flat_x = x.permute(0, 2, 3, 1).reshape(-1, 64)
print(flat_x.shape)
# Think of x as an image with 64 channels: for each pixel we look up the
# codebook vector that is closest to that pixel's 64-dim feature vector.

# Codebook of 512 entries, each 64-dimensional, initialised uniformly in [-1, 1).
emb_table = nn.Embedding(512, 64)
emb_table.weight.data.uniform_(-1, 1)

# Squared Euclidean distance between every row of flat_x and every codebook
# entry, expanded as ||x||^2 - 2 x.e + ||e||^2 so it is one matrix product.
distances = (
            (flat_x ** 2).sum(1, keepdim=True)
            - 2 * flat_x @ emb_table.weight.t()
            + (emb_table.weight.t() ** 2).sum(0, keepdim=True)
        )

print(distances.shape)
# Index of the nearest codebook entry for each row of flat_x.
encoding_indices = distances.argmin(1)
# Gather the selected codebook rows: one quantized vector per position.
quantizied_x = F.embedding(encoding_indices, emb_table.weight)
print(emb_table.weight.shape)
print(quantizied_x.shape)

torch.Size([1024, 64])
torch.Size([1024, 512])
torch.Size([512, 64])
torch.Size([1024, 64])
