# VQ-GAN

## 1. Setup

In [1]:
!pip -q install lightning einops datasets tokenizers

In [2]:
from typing import Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor, ToPILImage
from torchvision.utils import save_image

from torch.optim import Adam

import lightning as L

from einops import rearrange

from tqdm import tqdm

## 2. Model

### 2.1. Normalization

In [3]:
Normalization = lambda in_channels: nn.GroupNorm(
    num_groups=32,
    num_channels=in_channels,
)

### 2.2. Convolution

In [4]:
Convolution = lambda in_channels, out_channels: nn.Conv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
)

### 2.3. Pointwise Convolution

In [5]:
PointwiseConvolution = lambda in_channels: nn.Conv2d(
    in_channels=in_channels,
    out_channels=in_channels,
    kernel_size=1,
    stride=1,
    padding=0,
)

### 2.4. Downsample Convolution

In [6]:
DownsampleConvolution = lambda in_channels, out_channels: nn.Conv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=4,
    stride=2,
    padding=1,
)

### 2.5. Activation

In [7]:
Activation = lambda: nn.LeakyReLU(0.2)

### 2.6. ResNet Block

In [8]:
class ResNetBlock(nn.Module):
    """ResNet Block."""

    def __init__(self, *, in_channels: int) -> None:
        """Initialize the module."""

        super().__init__()

        self.sequential = nn.Sequential(
            Normalization(in_channels=in_channels),
            Activation(),
            Convolution(in_channels=in_channels, out_channels=in_channels),
            Normalization(in_channels=in_channels),
            Activation(),
            Convolution(in_channels=in_channels, out_channels=in_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""

        return x + self.sequential(x)

### 2.7. Attention Block

In [9]:
class AttentionBlock(nn.Module):
    """Attention block."""

    def __init__(self, *, in_channels: int) -> None:
        """Initialize the module."""

        super().__init__()

        self.normalization = Normalization(in_channels=in_channels)
        self.project_q = PointwiseConvolution(in_channels=in_channels)
        self.project_k = PointwiseConvolution(in_channels=in_channels)
        self.project_v = PointwiseConvolution(in_channels=in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""

        b, c, h, w = x.shape
        z = self.normalization(x)

        q = self.project_q(z)
        k = self.project_k(z)
        v = self.project_v(z)

        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        v = rearrange(v, 'b c h w -> b (h w) c')

        z = F.softmax(q @ k, dim=-1) @ v
        z = rearrange(z, 'b (h w) c -> b c h w', h=h, w=w)

        return x + z

### 2.8. Downsample Block

In [10]:
class DownsampleBlock(nn.Module):
    """Downsample block."""

    def __init__(self, *, in_channels: int, out_channels: int) -> None:
        """Initialize the module."""

        super().__init__()

        self.sequential = nn.Sequential(
            Normalization(in_channels=in_channels),
            DownsampleConvolution(
                in_channels=in_channels,
                out_channels=out_channels,
            ),
            Activation(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""

        return self.sequential(x)

### 2.9. Upsample Block

In [11]:
class UpsampleBlock(nn.Module):
    """Upsample block."""

    def __init__(self, *, in_channels: int, out_channels: int) -> None:
        """Initialize the module."""

        super().__init__()

        self.sequential = nn.Sequential(
            Normalization(in_channels=in_channels),
            nn.Upsample(scale_factor=2),
            Convolution(in_channels=in_channels, out_channels=out_channels),
            Activation(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""

        return self.sequential(x)

### 2.10. Encoder

In [12]:
class Encoder(nn.Module):
    """Encoder."""

    def __init__(
        self,
        *,
        in_channels: int,
        hidden_channels: int,
        embedding_channels: int,
    ) -> None:
        """Initialize the module."""

        super().__init__()

        self.sequential = nn.Sequential(

            # Downsampling.

            Convolution(
                in_channels=in_channels,
                out_channels=hidden_channels,
            ),

            ResNetBlock(in_channels=hidden_channels),
            DownsampleBlock(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
            ),

            ResNetBlock(in_channels=hidden_channels),
            DownsampleBlock(
                in_channels=hidden_channels,
                out_channels=hidden_channels*2,
            ),

            ResNetBlock(in_channels=hidden_channels*2),
            DownsampleBlock(
                in_channels=hidden_channels*2,
                out_channels=hidden_channels*2,
            ),

            ResNetBlock(in_channels=hidden_channels*2),
            DownsampleBlock(
                in_channels=hidden_channels*2,
                out_channels=hidden_channels*4,
            ),

            ResNetBlock(in_channels=hidden_channels*4),

            # Attention.

            ResNetBlock(in_channels=hidden_channels*4),
            AttentionBlock(in_channels=hidden_channels*4),
            ResNetBlock(in_channels=hidden_channels*4),

            # Embedding.

            Normalization(in_channels=hidden_channels*4),
            Convolution(
                in_channels=hidden_channels*4,
                out_channels=embedding_channels,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""

        return self.sequential(x)

### 2.11. Decoder

In [13]:
class Decoder(nn.Module):
    """Decoder."""

    def __init__(
        self,
        *,
        out_channels: int,
        hidden_channels: int,
        embedding_channels: int,
    ) -> None:
        """Initialize the module."""

        super().__init__()

        self.sequential = nn.Sequential(

            # Unembedding.

            Convolution(
                in_channels=embedding_channels,
                out_channels=hidden_channels*4,
            ),

            # Attention.

            ResNetBlock(in_channels=hidden_channels*4),
            AttentionBlock(in_channels=hidden_channels*4),
            ResNetBlock(in_channels=hidden_channels*4),

            # Upsampling.

            ResNetBlock(in_channels=hidden_channels*4),
            UpsampleBlock(
                in_channels=hidden_channels*4,
                out_channels=hidden_channels*2,
            ),

            ResNetBlock(in_channels=hidden_channels*2),
            UpsampleBlock(
                in_channels=hidden_channels*2,
                out_channels=hidden_channels*2,
            ),

            ResNetBlock(in_channels=hidden_channels*2),
            UpsampleBlock(
                in_channels=hidden_channels*2,
                out_channels=hidden_channels,
            ),

            ResNetBlock(in_channels=hidden_channels),
            UpsampleBlock(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
            ),

            ResNetBlock(in_channels=hidden_channels),

            Normalization(in_channels=hidden_channels),
            Convolution(
                in_channels=hidden_channels,
                out_channels=out_channels,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""

        return self.sequential(x)

### 2.12. Vector Quantizer

In [14]:
class VectorQuantizer(nn.Module):
    """Vector quantizer."""

    def __init__(
        self,
        codebook_size: int,
        embedding_channels: int,
    ) -> None:
        """Initialize the module."""

        super().__init__()

        self.codebook_size = codebook_size
        self.embedding_dimension = embedding_channels

        self.embedding = nn.Embedding(
            num_embeddings=codebook_size,
            embedding_dim=embedding_channels,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass."""

        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> (b h w) c')

        distance = torch.sum(x ** 2, dim=1, keepdim=True) \
            + torch.sum(self.embedding.weight ** 2, dim=1) \
            - 2*(x @ self.embedding.weight.T)

        tokens = distance.argmin(dim=1).detach()

        quantized = self.embedding(tokens)
        codebook_loss = F.mse_loss(quantized, x.detach())
        commitment_loss = F.mse_loss(x, quantized.detach())

        quantized = x + (quantized - x).detach()
        quantized = rearrange(quantized, '(b h w) c -> b c h w', h=H, w=W)
        tokens = tokens.view(B, H, W)

        return quantized, tokens, codebook_loss, commitment_loss

### 2.13. Gumbel-softmax Vector Quantizer

In [15]:
class GumbelSoftmaxVectorQuantizer(nn.Module):
    ...

    # TODO

### 2.14. VQ-VAE

In [16]:
class VQVAE(nn.Module):
    """VQ-VAE."""

    def __init__(
        self,
        *,
        in_channels: int,
        hidden_channels: int,
        embedding_channels: int,
        codebook_size: int,
    ) -> None:
        """Initialize the module."""

        super().__init__()

        self.encoder = Encoder(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            embedding_channels=embedding_channels,
        )

        self.decoder = Decoder(
            out_channels=in_channels,
            hidden_channels=hidden_channels,
            embedding_channels=embedding_channels,
        )

        self.codebook = VectorQuantizer(
            embedding_channels=embedding_channels,
            codebook_size=codebook_size,
        )

    def encode(
        self,
        x: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Encode an example."""

        return self.codebook(self.encoder(x))

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode an example."""

        return self.decoder(z)

    def forward(
        self,
        x: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Forward pass."""

        quantized, tokens, codebook_loss, commitment_loss = self.encode(x)

        x = self.decode(quantized)

        return x, quantized, tokens, codebook_loss, commitment_loss

### 2.15. Patch Discriminator

In [17]:
def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Patchify an image."""

    return rearrange(x, 'b c (h ph) (w pw) -> b (h w) c ph pw', ph=patch_size, pw=patch_size)

In [18]:
class PatchDiscriminator(nn.Module):
    """Patch discriminator."""

    def __init__(
        self,
        *,
        in_channels: int,
        hidden_channels: int,
        patch_size: int,
    ) -> None:
        """Initialize the module."""

        super().__init__()

        self.patch_size = patch_size

        self.sequential = nn.Sequential(

            Convolution(in_channels=in_channels, out_channels=hidden_channels),

            ResNetBlock(in_channels=hidden_channels),
            DownsampleBlock(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
            ),

            ResNetBlock(in_channels=hidden_channels),
            DownsampleBlock(
                in_channels=hidden_channels,
                out_channels=hidden_channels*2,
            ),

            ResNetBlock(in_channels=hidden_channels*2),
            DownsampleBlock(
                in_channels=hidden_channels*2,
                out_channels=hidden_channels*2,
            ),

            ResNetBlock(in_channels=hidden_channels*2),
            DownsampleBlock(
                in_channels=hidden_channels*2,
                out_channels=hidden_channels*4,
            ),

            ResNetBlock(in_channels=hidden_channels*4),
            Normalization(in_channels=hidden_channels*4),

            nn.Flatten(),
            nn.LazyLinear(out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""

        b, c, h, w = x.shape

        x = patchify(x, patch_size=self.patch_size)
        x = rearrange(x, 'b p c h w -> (b p) c h w')
        x = self.sequential(x)
        x = rearrange(x, '(b p) o -> b p o', b=b)
        x = x.mean(dim=1)  # Average scores accross patches.

        return x

## 3. Training

In [19]:
resolution = 256
channels = 1

transform = Compose([
    Resize((resolution, resolution), interpolation=0),
    ToTensor(),
])

dataset = MNIST(root='.', train=True, download=True, transform=transform)

In [20]:
#model = VQVAE(in_channels=3, hidden_channels=128, embedding_channels=256, codebook_size=1024).cuda()

device = 'cuda:0'

generator = VQVAE(
    in_channels=1,
    hidden_channels=32,
    embedding_channels=256,
    codebook_size=32,
).to(device)

discriminator = PatchDiscriminator(
    in_channels=1,
    hidden_channels=32,
    patch_size=32,
).to(device)



In [21]:
generator_optimizer = Adam(generator.parameters(), lr=1e-3)
discriminator_optimizer = Adam(discriminator.parameters(), lr=1e-3)

In [22]:
!rm -rf ./reconstruction-*.png

In [26]:
batch_size = 16
epochs = 20
batches = len(dataset) // batch_size
accumulate_steps = 1

generator_dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

discriminator_dataloader = (x for x in DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
))

beta = 0.25
step = 0

for epoch in range(epochs):
    for batch, data in enumerate(generator_dataloader):

        x = data[0].view(-1, channels, resolution, resolution)
        x = x.to(device)

        # VQ-VAE training.

        generator_optimizer.zero_grad()

        generator_output, quantized, tokens, codebook_loss, commitment_loss = generator(x)
        generator_output = F.sigmoid(generator_output)

        perceptual_loss = F.binary_cross_entropy(generator_output, x)  # TODO: use an actual perceptual loss.
        vq_vae_loss = perceptual_loss + codebook_loss + (beta * commitment_loss)
        vq_vae_loss.backward()

        generator_optimizer.step()

        # GAN training.

        # Train the discriminator.

        discriminator_optimizer.zero_grad()

        fake_input = generator_output[: batch_size // 2].detach()  # Take first half of fake inputs and second half of real inputs.
        real_input = x[batch_size // 2 :]

        discriminator_input = torch.cat((fake_input, real_input))
        discriminator_label = torch.cat((torch.zeros(batch_size // 2), torch.ones(batch_size // 2))).to(device)
        discriminator_output = discriminator(discriminator_input).flatten()

        discriminator_loss = F.binary_cross_entropy(discriminator_output, discriminator_label)
        discriminator_loss.backward()

        discriminator_optimizer.step()

        # Train the generator.

        generator_optimizer.zero_grad()

        generator_output, *_ = generator(x)

        discriminator_input = generator_output[batch_size // 2 :]  # Take second half of fake inputs (not yet seen by the discriminator).
        discriminator_label = torch.ones(batch_size // 2).to(device)
        discriminator_output = discriminator(discriminator_input).flatten()

        generator_loss = F.binary_cross_entropy(discriminator_output, discriminator_label)
        generator_loss.backward()

        generator_optimizer.step()

        if (step % 50) == 0:

            vq_vae_loss = vq_vae_loss.detach().item()
            discriminator_loss = discriminator_loss.detach().item()
            generator_loss = generator_loss.detach().item()

            print(f'epoch: {epoch:06d}/{epochs}, batch: {batch:06d}/{batches}, step: {step:06d} - vq-vae loss: {vq_vae_loss:0.3f}, discriminator loss: {discriminator_loss:0.3f}, generator loss: {generator_loss:0.3f}')

            save_image(generator_output, f'./reconstruction-{step:06d}.png')

        step += 1

epoch: 000000/20, batch: 000000/3750, step: 000000 - vq-vae loss: 1.358, discriminator loss: 0.133, generator loss: 0.691
epoch: 000000/20, batch: 000050/3750, step: 000050 - vq-vae loss: 0.782, discriminator loss: 0.013, generator loss: 0.000
epoch: 000000/20, batch: 000100/3750, step: 000100 - vq-vae loss: 0.581, discriminator loss: 0.005, generator loss: 0.000
epoch: 000000/20, batch: 000150/3750, step: 000150 - vq-vae loss: 0.427, discriminator loss: 0.002, generator loss: 0.000
epoch: 000000/20, batch: 000200/3750, step: 000200 - vq-vae loss: 0.589, discriminator loss: 0.006, generator loss: 0.005
epoch: 000000/20, batch: 000250/3750, step: 000250 - vq-vae loss: 0.376, discriminator loss: 0.000, generator loss: 0.000
epoch: 000000/20, batch: 000300/3750, step: 000300 - vq-vae loss: 0.361, discriminator loss: 0.027, generator loss: 0.002
epoch: 000000/20, batch: 000350/3750, step: 000350 - vq-vae loss: 0.314, discriminator loss: 0.000, generator loss: 0.002
epoch: 000000/20, batch:

KeyboardInterrupt: ignored