# Autoencoders

Autoencoders constitute a family of neural network architectures designed to learn
compressed representations of data in an unsupervised manner. The fundamental structure
of an autoencoder is organized into two main blocks: an encoder, which transforms the
original input into a lower-dimensional latent representation, and a decoder, which takes
this latent representation and reconstructs from it an approximation of the original
input. The training objective consists of minimizing the discrepancy between the
reconstructed output and the input, so that the model is forced to capture the most
relevant characteristics of the data in the latent space.

This document presents several variants of autoencoders, from basic dense architectures
to more advanced models such as variational autoencoders (VAE), Beta-VAE, and VQ-VAE. All
implementations are developed on the MNIST dataset and are provided as fully functional
code, ready to be executed from start to finish in a Jupyter Notebook environment.

## Vanilla Autoencoder with Dense Layers

The vanilla autoencoder uses exclusively dense (fully connected) layers to encode and
decode MNIST images. Each image of size $28 \times 28$ is flattened into a vector of
dimension 784 and projected into a lower-dimensional latent space. The encoder applies a
sequence of linear transformations and nonlinear activation functions until it reaches
the latent space, whereas the decoder performs the inverse process to reconstruct the
image.

This configuration introduces the central idea of autoencoders but exhibits clear
limitations. Dense layers do not explicitly exploit the spatial structure of the image,
which leads to a large number of parameters due to the full connectivity between neurons.
In addition, since local relationships between pixels are not modeled explicitly,
reconstructions tend to be blurrier and less detailed.

The following code presents a basic functional implementation on MNIST.

In [None]:
# 3pps
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class VanillaAutoencoder(nn.Module):
    """Fully connected autoencoder operating on flattened inputs.

    Args:
        input_dim: Number of features per sample after flattening
            (784 for a 1x28x28 MNIST image).
        latent_dim: Size of the latent code produced by the encoder.
    """

    def __init__(self, input_dim: int = 784, latent_dim: int = 32) -> None:
        super().__init__()
        # Encoder: Progressively reduces dimensionality
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        # Decoder: Reconstructs from the latent space
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid(),  # Output in [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct ``x``.

        Args:
            x: Batch whose first dimension is the batch size and whose remaining
                dimensions flatten to ``input_dim`` features, e.g. (B, 1, 28, 28).

        Returns:
            Reconstruction with the same shape as ``x`` and values in [0, 1].
        """
        input_shape = x.shape
        latent = self.encode(x)
        reconstructed = self.decoder(latent)
        # Restore the caller's layout instead of assuming 1x28x28, so the model
        # works for any ``input_dim`` and for already-flattened input.
        return reconstructed.reshape(input_shape)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` per sample and return its latent code (B, latent_dim)."""
        x = x.reshape(x.size(0), -1)
        return self.encoder(x)

In [None]:
def prepare_mnist_data(batch_size: int = 128):
    """Download MNIST into ``./data`` (if missing) and wrap both splits in DataLoaders.

    Images go through ``transforms.ToTensor()``. The training loader reshuffles
    every epoch; the test loader keeps a fixed order.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    def _split_loader(is_train: bool) -> DataLoader:
        # Shuffling is enabled exactly for the training split.
        split = datasets.MNIST(
            root="./data", train=is_train, download=True, transform=to_tensor
        )
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    return _split_loader(True), _split_loader(False)

In [None]:
def train_autoencoder(
    model: nn.Module,
    train_loader: DataLoader,
    num_epochs: int = 10,
    device: str = "cuda",
    lr: float = 1e-3,
) -> nn.Module:
    """Train a deterministic autoencoder with a pixel-wise MSE reconstruction loss.

    Args:
        model: Module whose output has the same shape as its input batch.
        train_loader: Yields ``(data, label)`` pairs; labels are ignored.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and each batch are moved to.
        lr: Learning rate for Adam.

    Returns:
        The trained model (moved to ``device``).
    """
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_samples = 0
        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            reconstructed = model(data)
            loss = criterion(reconstructed, data)
            loss.backward()
            optimizer.step()
            # ``loss`` is a per-batch mean; weight it by batch size so a smaller
            # final batch does not skew the epoch average.
            total_loss += loss.item() * data.size(0)
            num_samples += data.size(0)

        # max(..., 1) avoids ZeroDivisionError for an empty loader.
        avg_loss = total_loss / max(num_samples, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")

    return model

In [None]:
def visualize_reconstructions(
    model: nn.Module,
    test_loader: DataLoader,
    num_images: int = 10,
    device: str = "cuda",
) -> None:
    """Plot the first test images (top row) above their reconstructions (bottom row).

    Args:
        model: Autoencoder whose output has the same shape as its input. It is not
            moved by this function, so it must already live on ``device``.
        test_loader: Only its first batch is used.
        num_images: Number of columns to plot; capped at the first batch's size.
        device: Device the batch is moved to before the forward pass.
    """
    model.eval()
    data, _ = next(iter(test_loader))
    data = data[:num_images].to(device)
    # The first batch may contain fewer images than requested.
    num_images = data.size(0)
    if num_images == 0:
        return

    with torch.no_grad():
        reconstructed = model(data)

    data = data.cpu()
    reconstructed = reconstructed.cpu()

    # squeeze=False keeps ``axes`` 2-D even when a single column is plotted.
    fig, axes = plt.subplots(2, num_images, figsize=(15, 3), squeeze=False)
    for i in range(num_images):
        axes[0, i].imshow(data[i].squeeze(), cmap="gray")
        axes[0, i].axis("off")
        axes[0, i].set_title("Original")

        axes[1, i].imshow(reconstructed[i].squeeze(), cmap="gray")
        axes[1, i].axis("off")
        axes[1, i].set_title("Reconstructed")

    plt.tight_layout()
    plt.show()

In [None]:
# Vanilla autoencoder execution: train a dense autoencoder on MNIST and plot
# a few test reconstructions.
device = "cuda" if torch.cuda.is_available() else "cpu"
train_loader, test_loader = prepare_mnist_data()
vanilla_ae = train_autoencoder(
    VanillaAutoencoder(input_dim=784, latent_dim=32),
    train_loader,
    num_epochs=10,
    device=device,
)
visualize_reconstructions(vanilla_ae, test_loader, device=device)

## Denoising Autoencoder

The denoising autoencoder extends the previous approach by introducing noise into the
input during training. In this case, the encoder receives a corrupted version of the
image, while the loss function compares the decoder output with the clean original image.
This mechanism forces the model to learn robust latent representations that capture the
underlying structure of the data, rather than merely approximating the identity function.

Noise is usually introduced as additive Gaussian noise, and values are subsequently
clipped to keep them in the range $[0, 1]$. In this way, the model learns to "undo" the
corruption, acting as a filter that preserves relevant content and discards spurious
details.

The following code illustrates an implementation of this variant on MNIST.

In [None]:
class DenoisingAutoencoder(nn.Module):
    """Dense autoencoder meant to map corrupted inputs back to clean ones.

    The layer sizes match ``VanillaAutoencoder``; the encoder additionally has a
    dropout layer, which is only active in training mode.

    Args:
        input_dim: Number of features per sample after flattening (784 for MNIST).
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_dim: int = 784, latent_dim: int = 32) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a reconstruction with the same shape as ``x`` and values in [0, 1]."""
        input_shape = x.shape
        latent = self.encode(x)
        reconstructed = self.decoder(latent)
        # Restore the caller's layout instead of assuming 1x28x28.
        return reconstructed.reshape(input_shape)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` per sample and return its latent code (B, latent_dim)."""
        return self.encoder(x.reshape(x.size(0), -1))

In [None]:
def add_noise(images: torch.Tensor, noise_factor: float = 0.3) -> torch.Tensor:
    """Corrupt ``images`` with additive zero-mean Gaussian noise, then clip to [0, 1].

    Args:
        images: Tensor of pixel intensities (any shape).
        noise_factor: Standard deviation of the Gaussian noise.

    Returns:
        A new tensor with the same shape as ``images`` and values in [0, 1].
    """
    gaussian = torch.randn_like(images) * noise_factor
    return (images + gaussian).clamp(min=0.0, max=1.0)

In [None]:
def train_denoising_ae(
    model: nn.Module,
    train_loader: DataLoader,
    num_epochs: int = 10,
    device: str = "cuda",
    noise_factor: float = 0.3,
    lr: float = 1e-3,
) -> nn.Module:
    """Train ``model`` to reconstruct clean images from noise-corrupted copies.

    Each batch is corrupted with ``add_noise`` before the forward pass, while the
    MSE loss is measured against the clean batch.

    Args:
        model: Module whose output has the same shape as its input batch.
        train_loader: Yields ``(data, label)`` pairs; labels are ignored.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and each batch are moved to.
        noise_factor: Standard deviation of the Gaussian corruption.
        lr: Learning rate for Adam.

    Returns:
        The trained model (moved to ``device``).
    """
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_samples = 0
        for data, _ in train_loader:
            clean_data = data.to(device)
            # Fresh noise every batch/epoch, so the model never sees the same
            # corruption twice.
            noisy_data = add_noise(clean_data, noise_factor)

            optimizer.zero_grad()
            reconstructed = model(noisy_data)
            loss = criterion(reconstructed, clean_data)
            loss.backward()
            optimizer.step()
            # Weight each batch-mean loss by its batch size for a true
            # per-sample epoch average.
            total_loss += loss.item() * clean_data.size(0)
            num_samples += clean_data.size(0)

        # max(..., 1) avoids ZeroDivisionError for an empty loader.
        avg_loss = total_loss / max(num_samples, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")

    return model

In [None]:
def visualize_denoising(
    model: nn.Module,
    test_loader: DataLoader,
    noise_factor: float = 0.3,
    num_images: int = 10,
    device: str = "cuda",
) -> None:
    """Plot clean, noisy and denoised versions of the first test images in three rows.

    Args:
        model: Denoising autoencoder. It is not moved by this function, so it
            must already live on ``device``.
        test_loader: Only its first batch is used.
        noise_factor: Standard deviation of the Gaussian corruption.
        num_images: Number of columns to plot; capped at the first batch's size.
        device: Device the batch is moved to before the forward pass.
    """
    model.eval()
    data, _ = next(iter(test_loader))
    data = data[:num_images].to(device)
    num_images = data.size(0)
    if num_images == 0:
        return
    noisy_data = add_noise(data, noise_factor)

    with torch.no_grad():
        reconstructed = model(noisy_data)

    rows = (
        ("Original", data.cpu()),
        ("Noisy", noisy_data.cpu()),
        ("Denoised", reconstructed.cpu()),
    )

    # squeeze=False keeps ``axes`` 2-D even when a single column is plotted.
    fig, axes = plt.subplots(len(rows), num_images, figsize=(15, 5), squeeze=False)
    for row, (label, images) in enumerate(rows):
        for i in range(num_images):
            ax = axes[row, i]
            ax.imshow(images[i].squeeze(), cmap="gray")
            # ``ax.axis("off")`` would also hide the y-label, so strip ticks and
            # spines individually to keep the row label visible.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        axes[row, 0].set_ylabel(label, rotation=0, labelpad=40)

    plt.tight_layout()
    plt.show()

In [None]:
# Denoising autoencoder execution (reuses the loaders and device defined earlier)
denoising_ae = train_denoising_ae(
    DenoisingAutoencoder(input_dim=784, latent_dim=32),
    train_loader,
    num_epochs=10,
    device=device,
)
visualize_denoising(denoising_ae, test_loader, device=device)

## Convolutional Autoencoder

Convolutional autoencoders are better suited to image data because they explicitly
exploit spatial structure. The encoder applies convolutions with shared weights and local
filters; spatial dimensionality is reduced through stride and the stacking of layers. The
decoder uses transposed convolutions to perform upsampling and reconstruct the original
resolution.

In this context, convolutions provide several advantages. They significantly reduce the
number of parameters compared with dense layers, due to weight sharing across different
spatial positions. They also capture local patterns and hierarchical structures (edges,
digit parts, whole digits), which leads to sharper reconstructions that are more
consistent with image content.

The following implementation illustrates a convolutional autoencoder with a linear
bottleneck.

In [None]:
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for 1x28x28 images with a dense bottleneck.

    Three stride-2 convolutions shrink the spatial size 28 -> 14 -> 7 -> 4. The
    resulting 128x4x4 feature map is projected to ``latent_dim`` values and back,
    and three transposed convolutions upsample 4 -> 7 -> 14 -> 28.

    Args:
        latent_dim: Size of the latent code.
    """

    def __init__(self, latent_dim: int = 128) -> None:
        super().__init__()
        feature_shape = (128, 4, 4)
        feature_size = 128 * 4 * 4

        # Each stage: stride-2 convolution, ReLU, then batch norm.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # 28 -> 14
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 14 -> 7
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 7 -> 4
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )

        # Dense bottleneck between the 128x4x4 feature map and the latent code.
        self.flatten = nn.Flatten()
        self.fc_encode = nn.Linear(feature_size, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, feature_size)
        self.unflatten = nn.Unflatten(1, feature_shape)

        # output_padding selects the exact target size: 4 -> 7 needs none,
        # 7 -> 14 and 14 -> 28 need one extra row/column.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(
                128, 64, kernel_size=3, stride=2, padding=1, output_padding=0
            ),  # 4 -> 7
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(
                64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
            ),  # 7 -> 14
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(
                32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
            ),  # 14 -> 28
            nn.Sigmoid(),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, 1, 28, 28) batch to latent codes of shape (B, latent_dim)."""
        features = self.flatten(self.encoder(x))
        return self.fc_encode(features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode ``x``; output is (B, 1, 28, 28) with values in [0, 1]."""
        latent = self.encode(x)
        grid = self.unflatten(self.fc_decode(latent))
        return self.decoder(grid)

In [None]:
# Train the convolutional autoencoder and inspect its reconstructions
conv_ae = train_autoencoder(
    ConvAutoencoder(latent_dim=128),
    train_loader,
    num_epochs=10,
    device=device,
)
visualize_reconstructions(conv_ae, test_loader, device=device)

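Transposed convolutions can introduce characteristic artifacts known as checkerboard
artifacts, which arise when the combination of kernel size and stride produces uneven
overlaps during the upsampling operation. This happens in particular when the kernel size
is not divisible by the stride, as with the kernel size 3 and stride 2 used in the decoder
above: some output pixels then receive contributions from more input positions than their
neighbors, producing a regular grid-like pattern.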

## Autoencoder with Interpolation-Based Upsampling

To mitigate checkerboard artifacts, it is common to replace transposed convolutions with
an upsampling strategy based on interpolation followed by standard convolutions. In this
configuration, the spatial resolution is first increased by interpolation (bilinear,
bicubic, etc.), and then a convolution is applied to refine the result and learn filters
over the rescaled image.

This procedure tends to produce smoother and visually more coherent reconstructions,
significantly reducing undesired patterns at the price of some additional computation.

The following model preserves the same convolutional encoder as the previous autoencoder
but replaces the `ConvTranspose2d`-based decoder with a decoder that combines `Upsample`
and `Conv2d`.

In [None]:
class UpsamplingAutoencoder(nn.Module):
    """Convolutional autoencoder whose decoder uses interpolation + convolution.

    The encoder and dense bottleneck match ``ConvAutoencoder``. The decoder
    replaces transposed convolutions with bilinear ``Upsample`` followed by a
    3x3 ``Conv2d`` (4 -> 7 -> 14 -> 28). The goal is to reduce checkerboard
    artifacts.

    Args:
        latent_dim: Size of the latent code.
    """

    def __init__(self, latent_dim: int = 128) -> None:
        super().__init__()
        # Encoder identical to the convolutional autoencoder (28 -> 14 -> 7 -> 4)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )

        self.flatten = nn.Flatten()
        self.fc_encode = nn.Linear(128 * 4 * 4, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, 128 * 4 * 4)
        self.unflatten = nn.Unflatten(1, (128, 4, 4))

        # Decoder with upsampling + convolution
        self.decoder = nn.Sequential(
            # 4x4 -> 7x7 (explicit size: 7 is not a multiple of 4)
            nn.Upsample(size=(7, 7), mode="bilinear", align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # 7x7 -> 14x14
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            # 14x14 -> 28x28
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode ``x``; output is (B, 1, 28, 28) with values in [0, 1]."""
        latent = self.encode(x)
        x = self.fc_decode(latent)
        x = self.unflatten(x)
        reconstructed = self.decoder(x)
        return reconstructed

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, 1, 28, 28) batch to latent codes of shape (B, latent_dim)."""
        x = self.encoder(x)
        x = self.flatten(x)
        return self.fc_encode(x)

In [None]:
# Train the interpolation-based autoencoder and compare its reconstructions
upsampling_ae = train_autoencoder(
    UpsamplingAutoencoder(latent_dim=128),
    train_loader,
    num_epochs=10,
    device=device,
)
visualize_reconstructions(upsampling_ae, test_loader, device=device)

The use of bilinear or bicubic interpolation followed by standard convolutions generally
produces visually more pleasant reconstructions and significantly reduces checkerboard
artifacts, while preserving the model's ability to capture high-level patterns.

## Variational Autoencoder (VAE)

The variational autoencoder (VAE) introduces an important conceptual change with respect
to deterministic autoencoders. Instead of learning a direct mapping from the input to a
fixed latent vector, the encoder learns the parameters of a probability distribution over
the latent space. It is usually assumed that each latent dimension follows an independent
Gaussian distribution, so the encoder produces a mean $\mu$ and a logarithm of the
variance $\log \sigma^2$ for each dimension.

During training, a sample $z$ is drawn from the latent space using the reparameterization
trick:

$$z = \mu + \sigma \odot \varepsilon$$

where $\varepsilon \sim \mathcal{N}(0, I)$ and

$$\sigma = \exp\left(\tfrac{1}{2} \log\sigma^2\right)$$

This formulation allows gradients to be backpropagated through the sampling operation.

The VAE loss function includes two terms. The first is the reconstruction loss, which
measures the discrepancy between the original and reconstructed images (for example,
using binary cross-entropy). The second is a regularization term based on the
Kullback–Leibler (KL) divergence between the learned latent distribution and a standard
normal distribution $\mathcal{N}(0, I)$:

$$\mathcal{L}_{\text{KL}} = -\frac{1}{2}\sum_{i} \left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right)$$

This term forces the latent space to adopt a well-structured distribution, facilitating
sampling and the generation of new examples.

The following code presents a convolutional VAE implementation for MNIST.

In [None]:
# 3pps
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class VAE(nn.Module):
    """Convolutional variational autoencoder for 1x28x28 images.

    The encoder predicts the mean and log-variance of a diagonal Gaussian over a
    ``latent_dim``-dimensional code. The decoder maps a code back to a
    (1, 28, 28) image with values in [0, 1].

    Args:
        latent_dim: Dimensionality of the latent code.
    """

    def __init__(self, latent_dim: int = 20) -> None:
        super().__init__()
        hidden = 64 * 7 * 7  # flattened size of the 64x7x7 encoder feature map

        # Two stride-2 convolutions: 28x28 -> 14x14 -> 7x7, then flatten.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Separate heads for the mean and the log-variance.
        self.fc_mu = nn.Linear(hidden, latent_dim)
        self.fc_logvar = nn.Linear(hidden, latent_dim)

        # Mirror of the encoder: 7x7 -> 14x14 -> 28x28.
        self.fc_decode = nn.Linear(latent_dim, hidden)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(
                64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.Sigmoid(),
        )

    def encode(self, x: torch.Tensor):
        """Return ``(mu, logvar)``, each of shape (B, latent_dim)."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (differentiable in mu, logvar)."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent codes (B, latent_dim) to images (B, 1, 28, 28)."""
        return self.decoder(self.fc_decode(z))

    def forward(self, x: torch.Tensor):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        reconstructed = self.decode(self.reparameterize(mu, logvar))
        return reconstructed, mu, logvar

In [None]:
def vae_loss(
    reconstructed: torch.Tensor,
    original: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
) -> torch.Tensor:
    """Negative ELBO summed over the batch: BCE reconstruction + KL to N(0, I).

    Args:
        reconstructed: Decoder output with values in [0, 1].
        original: Target images with values in [0, 1], same shape as ``reconstructed``.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor (summed, not averaged, over all elements).
    """
    reconstruction_term = nn.functional.binary_cross_entropy(
        reconstructed, original, reduction="sum"
    )
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over all dims.
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction_term + kl_term

In [None]:
def train_vae(
    model: nn.Module,
    train_loader: DataLoader,
    num_epochs: int = 10,
    device: str = "cuda",
) -> nn.Module:
    """Optimize a VAE with Adam (lr 1e-3) on ``vae_loss``.

    ``vae_loss`` sums over the batch, so after each epoch the accumulated loss is
    divided by the dataset size and printed as a per-sample average.

    Args:
        model: VAE returning ``(reconstruction, mu, logvar)``.
        train_loader: Yields ``(images, label)`` pairs; labels are ignored.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and batches are moved to.

    Returns:
        The trained model.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch_idx in range(1, num_epochs + 1):
        model.train()
        running = 0.0
        for batch, _ in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            batch_loss = vae_loss(recon, batch, mu, logvar)
            batch_loss.backward()
            optimizer.step()
            running += batch_loss.item()

        per_sample = running / len(train_loader.dataset)
        print(f"Epoch [{epoch_idx}/{num_epochs}], Loss: {per_sample:.4f}")

    return model

In [None]:
def visualize_latent_space_tsne(
    model: VAE, data_loader: DataLoader, device: str = "cuda", n_samples: int = 5000
) -> None:
    """Visualize the latent space using t-SNE.

    Encodes up to ``n_samples`` inputs and projects their posterior means to 2-D
    with t-SNE. The result is a scatter plot colored by label.

    Args:
        model: VAE whose ``encode`` returns ``(mu, logvar)``; it must already be
            on ``device``.
        data_loader: Yields ``(images, labels)`` batches.
        device: Device the batches are moved to.
        n_samples: Maximum number of points to embed.
    """
    model.eval()
    latent_vectors = []
    labels = []
    collected = 0

    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            mu, _ = model.encode(data)
            latent_vectors.append(mu.cpu().numpy())
            labels.append(label.numpy())
            # Count actual rows: batches (especially the last) may differ in size.
            collected += data.size(0)
            if collected >= n_samples:
                break

    if not latent_vectors:
        print("Not enough samples for t-SNE.")
        return
    latent_vectors = np.concatenate(latent_vectors, axis=0)[:n_samples]
    labels = np.concatenate(labels, axis=0)[:n_samples]
    if len(latent_vectors) < 2:
        print("Not enough samples for t-SNE.")
        return

    print("Applying t-SNE...")
    # scikit-learn requires perplexity < number of samples.
    perplexity = min(30.0, len(latent_vectors) - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    latent_2d = tsne.fit_transform(latent_vectors)

    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(
        latent_2d[:, 0], latent_2d[:, 1], c=labels, cmap="tab10", alpha=0.6, s=5
    )
    plt.colorbar(scatter, label="Digit")
    plt.title("t-SNE Visualization of the VAE Latent Space")
    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    plt.tight_layout()
    plt.show()

In [None]:
def generate_samples(
    model: VAE, num_samples: int = 16, latent_dim: int = 20, device: str = "cuda"
) -> None:
    """Decode codes drawn from N(0, I) and show them in a near-square grid.

    Args:
        model: VAE with a ``decode`` method; it must already be on ``device``.
        num_samples: Number of images to generate (16 gives the original 4x4 grid).
        latent_dim: Dimensionality of the sampled codes; must match the model.
        device: Device the codes are moved to before decoding.
    """
    import math

    if num_samples <= 0:
        return

    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim).to(device)
        samples = model.decode(z).cpu()

    # Size the grid from num_samples. The previous fixed 4x4 grid raised
    # IndexError for fewer than 16 samples and silently dropped extra ones.
    ncols = math.ceil(math.sqrt(num_samples))
    nrows = math.ceil(num_samples / ncols)
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        if i < num_samples:
            ax.imshow(samples[i].squeeze(), cmap="gray")
        ax.axis("off")  # also blanks any unused trailing cells
    plt.tight_layout()
    plt.show()

In [None]:
def prepare_mnist_data(batch_size: int = 128):
    """Build MNIST DataLoaders: shuffled for training, fixed order for testing.

    The dataset is downloaded into ``./data`` if it is not already present, and
    images go through ``transforms.ToTensor()``.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    loaders = []
    for is_train in (True, False):
        split = datasets.MNIST(
            root="./data", train=is_train, download=True, transform=transform
        )
        loaders.append(DataLoader(split, batch_size=batch_size, shuffle=is_train))
    train_loader, test_loader = loaders
    return train_loader, test_loader

In [None]:
# Prepare data
train_loader, test_loader = prepare_mnist_data()

# Train VAE
vae = VAE(latent_dim=20)
device = "cuda" if torch.cuda.is_available() else "cpu"
vae = train_vae(vae, train_loader, num_epochs=20, device=device)

# Visualize latent space with t-SNE
visualize_latent_space_tsne(vae, test_loader, device=device)

# Generate synthetic samples
generate_samples(vae, latent_dim=20, device=device)

# Visualize reconstructions (explicit eval mode, no gradient tracking)
vae.eval()
with torch.no_grad():
    data, _ = next(iter(test_loader))
    data = data[:10].to(device)
    reconstructed, _, _ = vae(data)

fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for row, (row_label, images) in enumerate(
    (("Original", data), ("Reconstructed", reconstructed))
):
    for i in range(10):
        ax = axes[row, i]
        ax.imshow(images[i].cpu().squeeze(), cmap="gray")
        # ``ax.axis("off")`` would also hide the row label set below, so only
        # remove ticks and spines.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(row_label, size=12)
plt.tight_layout()
plt.show()

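VAEs are particularly useful for generating synthetic data by direct sampling in the
latent space. They are also used for anomaly detection, since out-of-distribution
examples tend to receive poor reconstructions or a low likelihood under the model.
However, VAEs can suffer from posterior collapse, in which the approximate posterior
collapses to the prior. The latent code then carries little information about the input,
and the decoder largely ignores it and reconstructs from local patterns alone. This
reduces the quality and informativeness of latent representations and is most pronounced
with very expressive (for example, autoregressive) decoders.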

## Beta-VAE

The Beta-VAE introduces a hyperparameter $\beta$ in the VAE loss function to weight the
KL divergence term:

$$\mathcal{L}_{\beta\text{-VAE}} = \mathcal{L}_{\text{recon}} + \beta \,\mathcal{L}_{\text{KL}}$$

When $\beta > 1$, the model is forced to align the latent distribution more strongly with
the standard normal distribution, which tends to produce more disentangled
representations. In a disentangled latent space, each dimension preferentially captures
an independent factor of variation in the data (for example, stroke thickness, slant, or
size), improving interpretability and control over generated samples.

Excessively high values of $\beta$ can degrade reconstruction quality by penalizing
latent code complexity too strongly.

The following code shows how to adapt the loss and training procedure for a Beta-VAE
using the VAE architecture defined above.

In [None]:
def beta_vae_loss(
    reconstructed: torch.Tensor,
    original: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 4.0,
) -> torch.Tensor:
    """Beta-VAE objective: summed BCE reconstruction plus ``beta`` times the KL term.

    With ``beta == 1`` this equals the standard VAE negative ELBO.

    Args:
        reconstructed: Decoder output with values in [0, 1].
        original: Target images with values in [0, 1].
        mu: Posterior means.
        logvar: Posterior log-variances.
        beta: Weight applied to the KL divergence.

    Returns:
        Scalar tensor (summed over all elements).
    """
    bce = nn.functional.binary_cross_entropy(reconstructed, original, reduction="sum")
    # KL( N(mu, exp(logvar)) || N(0, I) ) in closed form.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    weighted_kl = beta * kl
    return bce + weighted_kl

In [None]:
def train_beta_vae(
    model: VAE,
    train_loader: DataLoader,
    num_epochs: int = 10,
    beta: float = 4.0,
    device: str = "cuda",
) -> VAE:
    """Train ``model`` as a Beta-VAE, i.e. with the KL term scaled by ``beta``.

    Uses Adam (lr 1e-3) on ``beta_vae_loss``. After every epoch it prints the
    summed loss divided by the number of training samples.

    Args:
        model: VAE returning ``(reconstruction, mu, logvar)``.
        train_loader: Yields ``(images, label)`` pairs; labels are ignored.
        num_epochs: Number of passes over ``train_loader``.
        beta: KL weight forwarded to ``beta_vae_loss``.
        device: Device the model and batches are moved to.

    Returns:
        The trained model.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    model.train()

    for epoch_number in range(1, num_epochs + 1):
        epoch_loss = 0.0
        for images, _labels in train_loader:
            images = images.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(images)
            batch_loss = beta_vae_loss(recon, images, mu, logvar, beta)
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()

        per_sample = epoch_loss / len(train_loader.dataset)
        print(f"Epoch [{epoch_number}/{num_epochs}], Loss: {per_sample:.4f}")

    return model

To explore the effect of controlled variations along individual latent dimensions, the
latent traversal technique is used. It consists of systematically modifying a single
latent coordinate while keeping the remaining ones fixed.

In [None]:
def visualize_latent_traversal(
    model: VAE,
    test_loader: DataLoader,
    latent_dim: int = 20,
    dim_to_vary: int = 0,
    device: str = "cuda",
    num_steps: int = 10,
    value_range: tuple = (-3.0, 3.0),
) -> None:
    """Decode the first test image while sweeping one latent coordinate.

    The image is encoded to its posterior mean. Coordinate ``dim_to_vary`` is
    then replaced by ``num_steps`` evenly spaced values in ``value_range``, all
    other coordinates are kept fixed, and the decoded images are plotted in a row.

    Args:
        model: VAE with ``encode``/``decode``; it must already be on ``device``.
        test_loader: Only the first image of its first batch is used.
        latent_dim: Unused; the latent size is read from the encoder output.
            Kept for backward compatibility.
        dim_to_vary: Index of the latent coordinate to sweep.
        device: Device the image is moved to.
        num_steps: Number of values in the sweep (at least 1).
        value_range: ``(start, end)`` of the sweep, both inclusive.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
        IndexError: If ``dim_to_vary`` is out of range for the latent code.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    model.eval()
    data, _ = next(iter(test_loader))
    data = data[0:1].to(device)

    with torch.no_grad():
        mu, _ = model.encode(data)
        code_size = mu.size(1)
        if not -code_size <= dim_to_vary < code_size:
            raise IndexError(
                f"dim_to_vary={dim_to_vary} is out of range for a latent code "
                f"of size {code_size}"
            )
        values = torch.linspace(value_range[0], value_range[1], num_steps)
        # One row per step, all starting from the same posterior mean, so the
        # whole sweep is decoded in a single batched call.
        z = mu.repeat(num_steps, 1)
        z[:, dim_to_vary] = values.to(device=z.device, dtype=z.dtype)
        samples = model.decode(z).cpu()

    # squeeze=False keeps ``axes`` 2-D even for a single step.
    fig, axes = plt.subplots(
        1, num_steps, figsize=(1.5 * num_steps, 2), squeeze=False
    )
    for ax, image, value in zip(axes.flat, samples, values):
        ax.imshow(image.squeeze(), cmap="gray")
        ax.axis("off")
        ax.set_title(f"{value.item():.1f}")
    plt.tight_layout()
    plt.show()

In [None]:
# Training the Beta-VAE (beta = 4 strengthens the KL penalty)
beta_vae = train_beta_vae(
    VAE(latent_dim=20),
    train_loader,
    num_epochs=20,
    beta=4.0,
    device=device,
)

# Sweep each of the first five latent coordinates in turn
for dim in range(5):
    visualize_latent_traversal(
        beta_vae,
        test_loader,
        dim_to_vary=dim,
        device=device,
    )

Latent traversal enables inspection of the influence of each latent dimension on
generated samples, facilitating the interpretation of disentangled representations and
the design of controlled manipulations over specific attributes.

## VQ-VAE (Vector Quantized VAE)

VQ-VAE introduces a fundamental modification in the treatment of the latent space.
Instead of continuous codes, it uses a discrete representation based on a learned
codebook of embeddings. The encoder projects the input into a continuous latent tensor of
dimension $C$; each latent vector is then quantized by selecting the closest embedding
from the codebook, that is, by assigning a discrete index. The decoder receives the
quantized embeddings and reconstructs the input.

This discretization offers several advantages. It largely avoids the posterior collapse
problem that affects some VAEs. There is no KL term pulling each posterior toward an
uninformative prior, so the decoder must rely on the selected codes to reconstruct the input.
Moreover, the discrete representation is particularly well suited to be modeled later
using autoregressive models (for example, transformers), which has been crucial in
generative architectures such as DALL·E. In this context, latent indices act as tokens on
which language-modeling techniques can be applied.

The following code presents a simple VQ-VAE implementation for MNIST, including the
vector quantization module.

In [None]:
"""VQ-VAE (Vector Quantized Variational Autoencoder) Implementation"""

# 3pps
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
class VectorQuantizer(nn.Module):
    """
    Vector quantization layer for VQ-VAE.

    Every C-dimensional vector of the input feature map is replaced by its
    nearest entry (squared Euclidean distance) in a learned codebook of
    ``num_embeddings`` vectors. Gradients bypass the non-differentiable lookup
    via the straight-through estimator.
    """

    def __init__(
        self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Learned codebook, initialized uniformly in [-1/K, 1/K]
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        bound = 1 / num_embeddings
        self.embeddings.weight.data.uniform_(-bound, bound)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs: Tensor of shape (B, C, H, W) with C == embedding_dim.
        Returns:
            quantized: Quantized tensor (B, C, H, W).
            loss: Codebook loss + commitment_cost * commitment loss.
            encoding_indices: (B*H*W, 1) indices of the selected codebook vectors.
        """
        # Channels last, then one row per spatial position: (B*H*W, C)
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(-1, self.embedding_dim)
        codebook = self.embeddings.weight

        # ||x||^2 + ||e||^2 - 2 x.e for every (vector, codebook entry) pair
        sq_dist = (
            torch.sum(flat**2, dim=1, keepdim=True)
            + torch.sum(codebook**2, dim=1)
            - 2 * torch.matmul(flat, codebook.t())
        )
        encoding_indices = sq_dist.argmin(dim=1, keepdim=True)

        # One-hot rows select codebook vectors via a matrix product.
        one_hot = torch.zeros(
            encoding_indices.shape[0], self.num_embeddings, device=inputs.device
        ).scatter_(1, encoding_indices, 1)
        quantized = torch.matmul(one_hot, codebook).view(channels_last.shape)

        # Codebook term moves embeddings toward encoder outputs; commitment term
        # keeps encoder outputs close to their chosen embeddings.
        codebook_loss = nn.functional.mse_loss(quantized, channels_last.detach())
        commitment_loss = nn.functional.mse_loss(quantized.detach(), channels_last)
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: forward uses quantized values, backward
        # passes gradients to the encoder output unchanged.
        quantized = channels_last + (quantized - channels_last).detach()

        return quantized.permute(0, 3, 1, 2).contiguous(), loss, encoding_indices

In [None]:
class VQVAE(nn.Module):
    """
    Convolutional VQ-VAE for single-channel 28x28 images.

    The encoder uses strided convolutions to map an image to a 7x7 grid of
    ``embedding_dim``-dimensional vectors. A ``VectorQuantizer`` with
    ``num_embeddings`` codes discretizes that grid. The decoder uses transposed
    convolutions to map the quantized grid back to a sigmoid-activated image.
    """

    def __init__(self, num_embeddings: int = 512, embedding_dim: int = 64) -> None:
        super().__init__()

        # (1, 28, 28) -> (32, 14, 14) -> (64, 7, 7) -> (embedding_dim, 7, 7)
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, embedding_dim, kernel_size=1),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Snaps each latent vector of the 7x7 grid to a codebook entry
        self.vq = VectorQuantizer(num_embeddings, embedding_dim)

        # (embedding_dim, 7, 7) -> (64, 14, 14) -> (32, 28, 28) -> (1, 28, 28)
        decoder_layers = [
            nn.ConvTranspose2d(embedding_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x: torch.Tensor):
        """
        Run encode -> quantize -> decode on a batch.

        Args:
            x: Input tensor (B, 1, 28, 28)
        Returns:
            reconstructed: Reconstruction (B, 1, 28, 28)
            vq_loss: Loss reported by the vector quantizer
        """
        quantized, vq_loss, _ = self.vq(self.encoder(x))
        return self.decoder(quantized), vq_loss

    def encode(self, x: torch.Tensor):
        """Return the quantized latent grid and the chosen codebook indices."""
        quantized, _, indices = self.vq(self.encoder(x))
        return quantized, indices

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map a quantized latent grid back to image space."""
        reconstruction = self.decoder(z)
        return reconstruction

In [None]:
def train_vqvae(
    model: VQVAE,
    train_loader: DataLoader,
    num_epochs: int = 10,
    lr: float = 1e-3,
    device: str = "cuda",
) -> VQVAE:
    """
    Fit a VQ-VAE with Adam on reconstruction MSE plus the quantizer loss.

    Every optimization step minimizes ``mse(reconstruction, images) + vq_loss``.
    After each epoch, the mean per-batch reconstruction and VQ losses are printed.

    Args:
        model: VQVAE model; it is moved to ``device``.
        train_loader: DataLoader yielding ``(images, labels)`` batches; the
            labels are ignored.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: "cuda" or "cpu".

    Returns:
        The trained model.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(num_epochs):
        recon_history = []
        vq_history = []

        for images, _labels in train_loader:
            images = images.to(device)
            optimizer.zero_grad()

            reconstructed, vq_loss = model(images)
            recon_loss = nn.functional.mse_loss(reconstructed, images)
            (recon_loss + vq_loss).backward()
            optimizer.step()

            recon_history.append(recon_loss.item())
            vq_history.append(vq_loss.item())

        # Mean of the per-batch losses over this epoch
        num_batches = len(train_loader)
        avg_recon = sum(recon_history, 0.0) / num_batches
        avg_vq = sum(vq_history, 0.0) / num_batches
        print(
            f"Epoch [{epoch+1}/{num_epochs}] | "
            f"Recon Loss: {avg_recon:.6f} | "
            f"VQ Loss: {avg_vq:.6f}"
        )

    return model

In [None]:
def visualize_vqvae_reconstructions(
    model: VQVAE, test_loader: DataLoader, device: str = "cuda", num_images: int = 8
) -> None:
    """
    Plot original test images (top row) above their VQ-VAE reconstructions.

    Only the first batch of ``test_loader`` is used. The model is put in eval
    mode and is left in that mode.

    Args:
        model: VQVAE to evaluate; assumed to already be on ``device``.
        test_loader: DataLoader yielding ``(images, labels)`` batches.
        device: Device the input batch is moved to.
        num_images: Number of columns to show; clipped to the batch size.

    Raises:
        ValueError: If there is nothing to plot (``num_images`` < 1 or an
            empty first batch).
    """
    model.eval()
    data, _ = next(iter(test_loader))
    # Never index past the end of a short batch
    num_images = min(num_images, data.size(0))
    if num_images < 1:
        raise ValueError("num_images must be positive and the batch non-empty")
    data = data[:num_images].to(device)

    with torch.no_grad():
        reconstructed, _ = model(data)

    data = data.cpu()
    reconstructed = reconstructed.cpu()

    # squeeze=False keeps `axes` 2-D even when num_images == 1
    _, axes = plt.subplots(2, num_images, figsize=(12, 3), squeeze=False)
    for row, images in enumerate((data, reconstructed)):
        for col in range(num_images):
            ax = axes[row, col]
            ax.imshow(images[col].squeeze(), cmap="gray")
            # Hide ticks and frame manually: ax.axis("off") would also hide
            # the row labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

    axes[0, 0].set_ylabel("Original", size=12)
    axes[1, 0].set_ylabel("Reconstructed", size=12)
    plt.tight_layout()
    plt.show()

In [None]:
# Main VQ-VAE execution: build, train briefly, then inspect reconstructions
NUM_EMBEDDINGS = 512  # codebook size (number of discrete codes)
EMBEDDING_DIM = 64  # dimensionality of each code vector
NUM_EPOCHS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

vqvae = train_vqvae(
    VQVAE(num_embeddings=NUM_EMBEDDINGS, embedding_dim=EMBEDDING_DIM),
    train_loader,
    num_epochs=NUM_EPOCHS,
    device=DEVICE,
)
visualize_vqvae_reconstructions(vqvae, test_loader, device=DEVICE)

VQ-VAE provides a discrete latent space that is particularly well suited to multimodal
systems and complex generative models, in which tokenization of data is essential.
Because every image is encoded as a grid of integer codebook indices (a $7 \times 7$
grid in the model above), it can be treated as a short sequence of tokens. Vector
quantization therefore offers a robust foundation for applying advanced sequence-modeling
techniques to image representations. It also facilitates integration with language
architectures that operate on discrete token sequences.