<a href="https://colab.research.google.com/github/harryypham/MyMLPractice/blob/main/VQVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

%matplotlib inline

In [3]:
trainset = torchvision.datasets.CIFAR10(root="data", train=True, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                  ]))

testset = torchvision.datasets.CIFAR10(root="data", train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                  ]))

batch_size = 32


trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=2)
testloader = DataLoader(testset, batch_size=batch_size, num_workers=2)

data_variance = np.var(trainset.data / 255.0)

Files already downloaded and verified
Files already downloaded and verified


In [32]:
class ResidualBlock(nn.Module):
  def __init__(self, in_dims, hidden_dims):
    super(ResidualBlock, self).__init__()
    self.block = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(in_dims, hidden_dims, 3, 1, 1, bias=False),
        nn.ReLU(),
        nn.Conv2d(hidden_dims, in_dims, 1, bias=False)
    )

  def forward(self, x):
    return self.block(x) + x

class ResidualStack(nn.Module):
  def __init__(self, in_dims, hidden_dims, num_layers):
    super(ResidualStack, self).__init__()
    self.stack = self._make_layers(in_dims, hidden_dims, num_layers)


  def _make_layers(self, in_dims, hidden_dims, num_layers):
    ls = []
    for i in range(num_layers):
      ls.append(ResidualBlock(in_dims, hidden_dims))

    return nn.Sequential(*ls)

  def forward(self, x):
    return self.stack(x)

class Encoder(nn.Module):
  def __init__(self, in_dims, hidden_dims, residual_hidden_dims=64):
    super(Encoder, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_dims, hidden_dims//2, 4, 2, 1),
        nn.ReLU(),
        nn.Conv2d(hidden_dims//2, hidden_dims, 4, 2, 1),
        nn.ReLU()
    )
    self.layer2 = ResidualStack(hidden_dims, residual_hidden_dims, 2)

  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    return x

class Decoder(nn.Module):
  def __init__(self, in_dims, out_dims, residual_hidden_dims=64):
    super(Decoder, self).__init__()
    self.layer1 = ResidualStack(in_dims, residual_hidden_dims, 2)
    self.layer2 = nn.Sequential(
        nn.ConvTranspose2d(in_dims, 32, 4, 2, 1),
        nn.ReLU(),
        nn.ConvTranspose2d(32, out_dims, 4, 2, 1)
    )

  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    return x



In [5]:
torch.manual_seed(0)

<torch._C.Generator at 0x7d149381e890>

In [18]:
class VQEmbedding(nn.Module):
  def __init__(self, emb_size, emb_dims, beta=0.25):
    super().__init__()
    self.emb_size = emb_size
    self.emb_dims = emb_dims
    self.beta = beta
    self.embedding = nn.Embedding(emb_size, emb_dims)
    self.embedding.weight.data.uniform_(-1./emb_size, 1./emb_size)

  def forward(self, x):
    # Current: https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb
    # https://juliusruseckas.github.io/ml/vq-vae.html
    # Work out how the formula to calculate the distance work
    x = x.permute(0, 2, 3, 1).contiguous()
    input_shape = x.shape
    flatten_x = x.view(-1, self.emb_dims)

    distances = torch.sum(flatten_x ** 2, dim=1, keepdim=True) \
                + torch.sum(self.embedding.weight**2, dim=1) \
                - 2 * torch.matmul(flatten_x, self.embedding.weight.t())

    indices = torch.argmin(distances, dim=1).unsqueeze(1)
    encodings = torch.zeros(indices.shape[0], self.emb_size).to(device)
    encodings.scatter_(1, indices, 1)


    # Quantize and unflatten
    quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)

    e_latent_loss = F.mse_loss(quantized.detach(), x)
    q_latent_loss = F.mse_loss(quantized, x.detach())
    loss = q_latent_loss + self.beta * e_latent_loss

    # Make the gradient with respect to inputs be equal to the gradient with respect to quantized latents (cool trick!)
    quantized = x + (quantized - x).detach()
    quantized = quantized.permute(0, 3, 1, 2).contiguous()

    return quantized, loss

In [34]:
class VQVAE(nn.Module):
  def __init__(self, emb_size, emb_dims, hidden_dims, residual_hidden_dims):
    super().__init__()
    self.encoder = Encoder(3, hidden_dims)
    self.pre_vq_conv = nn.Conv2d(hidden_dims, emb_dims, 1, 1)
    self.decoder = Decoder(emb_dims, 3)
    self.vq = VQEmbedding(emb_size, emb_dims)

  def forward(self, x):
    z_e = self.encoder(x)
    z_e = self.pre_vq_conv(z_e)
    z_q, loss = self.vq(z_e)
    x_recon = self.decoder(z_q)
    return x_recon, loss

In [10]:
x = torch.rand((16, 2, 8, 8))
targets = torch.rand((1024, 2))

In [21]:

embed = VQEmbedding(20, 2)
embed.zero_grad()
out = embed(x)
criterion = nn.MSELoss()
loss = criterion(out, targets)
loss.backward()
print(embed.embedding.weight.grad)

tensor([[ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [-0.0014, -0.0018],
        [-0.0029, -0.0044],
        [ 0.0000,  0.0000],
        [-0.1275, -0.1530],
        [ 0.0000,  0.0000],
        [-0.2329, -0.2252],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [-0.1104, -0.0966],
        [ 0.0000,  0.0000]])


In [None]:
x = torch.rand((1, 1, 32, 32))
encode = Encoder(1, 64)
out = encode(x)
print(out.shape)
decode = Decoder(64, 3)
out = decode(out)
print(out.shape)

torch.Size([1, 64, 8, 8])
torch.Size([1, 3, 32, 32])


In [None]:
x = torch.rand((16, 64, 8, 8))
flat_x = x.permute(0, 2, 3, 1).reshape(-1, 64)
print(flat_x.shape)
#Imagine image have 64 channels and for each pixel of that image, we find the vector closest to it

emb_table = nn.Embedding(512, 64)
emb_table.weight.data.uniform_(-1, 1)

distances = (
            (flat_x ** 2).sum(1, keepdim=True)
            - 2 * flat_x @ emb_table.weight.t()
            + (emb_table.weight.t() ** 2).sum(1)
        )

print(distances.shape)
encoding_indices = distances.argmin(1)
quantizied_x = F.embedding(encoding_indices, emb_table.weight)
print(emb_table.weight.shape)
print(quantizied_x.shape)

torch.Size([1024, 64])
torch.Size([1024, 512])
torch.Size([512, 64])
torch.Size([1024, 64])


In [8]:
num_hiddens = 128
num_residual_hiddens = 256
embedding_dim = 64
num_embeddings = 512
num_epochs = 20
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [35]:
model = VQVAE(num_embeddings, embedding_dim, num_hiddens, num_residual_hiddens).to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-4)

loss_history = []

for epoch in range(num_epochs):
  train_loss = 0
  for batch_idx, (inputs, _) in enumerate(trainloader):
    inputs = inputs.to(device)
    optimizer.zero_grad()

    outputs, vq_loss = model(inputs)
    loss = vq_loss + (F.mse_loss(outputs, inputs) / data_variance)
    loss.backward()

    optimizer.step()
    loss_history.append(loss.item())
    train_loss += loss.item()

    if (not batch_idx % 200) and batch_idx != 0:
            print ('Batch %03d | Cost: %.6f'
                  %(batch_idx, train_loss/(batch_idx+1)))



Batch 200 | Cost: 110525.076168
Batch 400 | Cost: 101228.419506
Batch 600 | Cost: 82433.370400
Batch 800 | Cost: 69752.308293
Batch 1000 | Cost: 61206.350148
Batch 1200 | Cost: 53704.884954
Batch 1400 | Cost: 47308.566809
Batch 200 | Cost: 6037.409195
Batch 400 | Cost: 6717.036955
Batch 600 | Cost: 7551.921424
Batch 800 | Cost: 7667.910918
Batch 1000 | Cost: 8307.086652
Batch 1200 | Cost: 8130.532272
Batch 1400 | Cost: 7829.706780
Batch 200 | Cost: 3401.278342
Batch 400 | Cost: 3322.945808
Batch 600 | Cost: 3308.120538
Batch 800 | Cost: 3116.768793
Batch 1000 | Cost: 2931.333768
Batch 1200 | Cost: 2756.529744
Batch 1400 | Cost: 2603.009650
Batch 200 | Cost: 1300.245091
Batch 400 | Cost: 1256.102798
Batch 600 | Cost: 1241.405489
Batch 800 | Cost: 1306.290302
Batch 1000 | Cost: 1282.126504
Batch 1200 | Cost: 1236.776047
Batch 1400 | Cost: 1201.608634
Batch 200 | Cost: 963.982366
Batch 400 | Cost: 986.937480
Batch 600 | Cost: 1003.591512
Batch 800 | Cost: 982.633514
Batch 1000 | Cost: 954