# Variational Autoencoder (VAE): An In-Depth Implementation

**Covers:** full mathematical derivation, closed-form KL divergence, and latent-manifold visualization

---

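## 1. Mathematical Foundations

### 1.1 Core Idea

**Problem**: The data $x$ is assumed to be generated from a latent variable $z$, but $z$ is unobserved. How can we learn the posterior $p(z|x)$?

**Variational inference**: Instead of computing $p(z|x)$ directly (the evidence $p(x) = \int p(x, z)\,dz$ is generally intractable), approximate it with a parameterized family of distributions $q_\phi(z|x)$.

### 1.2 Deriving the ELBO ⭐

**Goal**: maximize the log-likelihood $\log p(x)$.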

$$\begin{aligned}
\log p(x) &= \log \int_z p(x, z) dz \\
&= \log \int_z q(z|x) \frac{p(x, z)}{q(z|x)} dz \\
&= \log \mathbb{E}_{z \sim q}[\frac{p(x, z)}{q(z|x)}] \\
&\geq \mathbb{E}_{z \sim q}[\log \frac{p(x, z)}{q(z|x)}] \quad \text{(Jensen's inequality)} \\
&= \mathbb{E}_{z \sim q}[\log \frac{p(x|z)p(z)}{q(z|x)}] \\
&= \mathbb{E}_{z \sim q}[\log p(x|z)] + \mathbb{E}_{z \sim q}[\log \frac{p(z)}{q(z|x)}] \\
&= \mathbb{E}_{z \sim q}[\log p(x|z)] - D_{KL}(q(z|x) \| p(z))
\end{aligned}$$

This is the **ELBO (Evidence Lower Bound)**:
$$\mathcal{L}_{ELBO} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{reconstruction term}} - \underbrace{D_{KL}(q(z|x) \| p(z))}_{\text{KL divergence}}$$
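
As a supplementary step that the Jensen argument leaves implicit, the slack in the bound can be written exactly. Assuming the true posterior $p(z|x)$ exists, substituting $p(x, z) = p(z|x)\,p(x)$ into the ELBO gives:

$$\begin{aligned}
\log p(x) - \mathcal{L}_{ELBO}
&= \log p(x) - \mathbb{E}_{z \sim q}\left[\log \frac{p(z|x)\,p(x)}{q(z|x)}\right] \\
&= \mathbb{E}_{z \sim q}\left[\log \frac{q(z|x)}{p(z|x)}\right] \\
&= D_{KL}(q(z|x) \| p(z|x)) \geq 0
\end{aligned}$$

Since $\log p(x)$ does not depend on $q$, maximizing the ELBO over $q$ is the same as minimizing the KL divergence from $q(z|x)$ to the true posterior.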

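**Intuition**:
- **Reconstruction term**: $z$ must carry enough information to decode $x$, so the encoder is pushed to extract useful features.
- **KL term**: keeps $q(z|x)$ close to the prior $p(z)$, acting as a regularizer on the latent space.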
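
The bound can be checked numerically on a toy model where everything is available in closed form. The sketch below assumes a hypothetical one-dimensional model (not the MNIST model used later): $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$, so that $p(x) = \mathcal{N}(0, 2)$ and the exact posterior is $\mathcal{N}(x/2, 1/2)$. The function names are illustrative only.

```python
import math


def log_evidence(x: float) -> float:
    """Exact log p(x) for the toy model: marginally x ~ N(0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x * x / 4.0


def elbo(x: float, m: float, s2: float) -> float:
    """ELBO for q(z|x) = N(m, s2), with both terms in closed form."""
    # E_q[log N(x; z, 1)] = -0.5*log(2*pi) - 0.5*((x - m)^2 + s2)
    recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # KL(N(m, s2) || N(0, 1))
    kl = -0.5 * (1 + math.log(s2) - m * m - s2)
    return recon - kl


x = 1.5
print(f"log p(x)            = {log_evidence(x):.4f}")
print(f"ELBO, q = N(0, 1)   = {elbo(x, 0.0, 1.0):.4f}")      # a poor q: strictly below log p(x)
print(f"ELBO, q = posterior = {elbo(x, x / 2, 0.5):.4f}")    # the exact posterior: bound is tight
```

Any choice of `m` and `s2` stays at or below `log_evidence(x)`, and the two coincide only when $q$ equals the exact posterior.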

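### 1.3 Closed-Form KL Divergence ⭐

Let:
- $q(z|x) = \mathcal{N}(\mu, \sigma^2)$ (diagonal covariance in the multivariate case)
- $p(z) = \mathcal{N}(0, 1)$ (that is, $\mathcal{N}(0, I)$ in the multivariate case)

**KL divergence between a diagonal Gaussian and the standard normal** ($J$ is the latent dimension):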

$$\begin{aligned}
D_{KL}(q \| p) &= \int q(z) \log \frac{q(z)}{p(z)} dz \\
&= -\frac{1}{2} \sum_{j=1}^J \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\end{aligned}$$

**Derivation** (one dimension):

$$\begin{aligned}
\log \frac{q(z)}{p(z)} &= \log \frac{\frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(z-\mu)^2}{2\sigma^2}}}{\frac{1}{\sqrt{2\pi}}e^{-\frac{z^2}{2}}} \\
&= \log \left(\frac{1}{\sigma} \cdot e^{-\frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2}}\right) \\
&= -\log \sigma - \frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2}
\end{aligned}$$

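Taking the expectation $\mathbb{E}_q[\cdot]$, and using $\mathbb{E}_q[(Z-\mu)^2] = \sigma^2$ and $\mathbb{E}_q[Z^2] = \mu^2 + \sigma^2$: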

$$\begin{aligned}
D_{KL}(q \| p) &= \mathbb{E}_q\left[-\log \sigma - \frac{(Z-\mu)^2}{2\sigma^2} + \frac{Z^2}{2}\right] \\
&= -\log \sigma - \frac{\mathbb{E}_q[(Z-\mu)^2]}{2\sigma^2} + \frac{\mathbb{E}_q[Z^2]}{2} \\
&= -\log \sigma - \frac{\sigma^2}{2\sigma^2} + \frac{\mu^2 + \sigma^2}{2} \\
&= -\frac{1}{2}\left(1 + 2\log \sigma - \mu^2 - \sigma^2\right) \\
&= -\frac{1}{2}\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)
\end{aligned}$$

Because the dimensions are independent, the multivariate KL is the sum over dimensions:
$$D_{KL}(q \| p) = -\frac{1}{2}\sum_{j=1}^{J}\left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

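**Intuition**:
- When $\mu \to 0$ and $\sigma \to 1$, the KL divergence $\to 0$ ($q$ matches $p$).
- The KL term acts as a regularizer, pulling each approximate posterior toward the standard normal.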
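
One way to sanity-check the closed form is to compare it against a direct numerical integration of $\int q \log (q/p)\,dz$. The following is a minimal sketch using only the standard library; the helper names and the example values of `mu` and `sigma` are arbitrary choices for illustration.

```python
import math


def kl_closed_form(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


def kl_numeric_1d(mu, sigma, lo=-12.0, hi=12.0, steps=48_000):
    """Midpoint-rule estimate of the integral of q(z) * log(q(z) / p(z)) in one dimension."""
    dz = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        z = lo + (i + 0.5) * dz
        log_q = -0.5 * math.log(2 * math.pi * sigma**2) - (z - mu) ** 2 / (2 * sigma**2)
        log_p = -0.5 * math.log(2 * math.pi) - z * z / 2
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total


mu, sigma = [0.5, -1.0], [0.8, 1.5]
log_var = [math.log(s * s) for s in sigma]
# Independent dimensions: integrate each one separately and add the results.
numeric = sum(kl_numeric_1d(m, s) for m, s in zip(mu, sigma))
print(f"closed form: {kl_closed_form(mu, log_var):.6f}")
print(f"numeric:     {numeric:.6f}")
```

The closed form takes `log_var` rather than `sigma`, mirroring the parameterization used by the encoder in the implementation section.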

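### 1.4 Reparameterization Trick

**Problem**: We need to sample from $q(z|x) = \mathcal{N}(\mu, \sigma^2)$, but gradients cannot flow through a sampling operation back to $\mu$ and $\sigma$.

**Solution**:
$$z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

**Key points**:
- $\epsilon$ is the only source of randomness and does not depend on the parameters.
- $\mu$ and $\sigma$ are network outputs, and $z$ is a differentiable function of them.
- Gradients can therefore backpropagate to $\mu$ and $\sigma$.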
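
To see why this makes gradients available, consider the illustrative objective $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, whose exact gradients are $2\mu$ and $2\sigma$. The sketch below estimates them by hand with the chain rule through $z = \mu + \sigma\epsilon$, without an autograd library; the test function $f(z) = z^2$ and the helper name are assumptions chosen so the answer can be verified.

```python
import random


def pathwise_grads(mu: float, sigma: float, n: int = 200_000, seed: int = 0):
    """Monte Carlo estimate of d/dmu and d/dsigma of E[z^2], with z = mu + sigma * eps."""
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)  # noise drawn independently of (mu, sigma)
        z = mu + sigma * eps       # reparameterized sample
        df_dz = 2.0 * z            # derivative of f(z) = z^2
        g_mu += df_dz * 1.0        # chain rule: dz/dmu = 1
        g_sigma += df_dz * eps     # chain rule: dz/dsigma = eps
    return g_mu / n, g_sigma / n


g_mu, g_sigma = pathwise_grads(mu=1.0, sigma=0.5)
print(f"d/dmu:    estimate {g_mu:.3f}, exact {2 * 1.0:.3f}")
print(f"d/dsigma: estimate {g_sigma:.3f}, exact {2 * 0.5:.3f}")
```

An autograd framework performs the same chain rule automatically once the sample is written as a deterministic function of $\mu$, $\sigma$, and external noise.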

---

## 2. Implementation

In [None]:
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

In [None]:
@dataclass
class VAEConfig:
    """Configuration for the VAE."""

    input_dim: int = 784
    hidden_dim: int = 400
    latent_dim: int = 2  # 2D for visualization
    lr: float = 1e-3
    batch_size: int = 128
    epochs: int = 20
    beta: float = 1.0  # KL weight (beta-VAE)

In [None]:
class VAE(nn.Module):
    """Variational autoencoder.

    Core idea:
        The encoder learns q_phi(z|x) ~ N(mu, sigma^2)
        The decoder learns p_theta(x|z)
        Trained end to end by maximizing ELBO = E[log p(x|z)] - KL(q||p)

    Math:
        Reparameterization: z = mu + sigma * epsilon, epsilon ~ N(0, I)
        KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    """

    def __init__(self, config: VAEConfig) -> None:
        super().__init__()
        self.config = config

        # Encoder: x -> hidden -> (mu, log_var)
        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(config.hidden_dim, config.latent_dim)
        self.fc_logvar = nn.Linear(config.hidden_dim, config.latent_dim)

        # Decoder: z -> hidden -> x
        self.decoder = nn.Sequential(
            nn.Linear(config.latent_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.input_dim),
            nn.Sigmoid(),  # MNIST pixels in [0, 1]
        )

    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode x into mu and log_var.

        Predicting log_var rather than var improves numerical stability
        and guarantees a positive standard deviation:
            sigma = exp(0.5 * log_var) > 0
        """
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        return mu, log_var

    def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
        """Reparameterization trick: z = mu + sigma * epsilon.

        Key point: epsilon is drawn from a standard normal, independently of
        mu and sigma, so gradients can backpropagate to mu and sigma.

        Derivation:
            If E ~ N(0, 1) and Z = mu + sigma * E,
            then Z ~ N(mu, sigma^2).
        """
        std = torch.exp(0.5 * log_var)  # sigma = exp(0.5 * log(sigma^2))
        eps = torch.randn_like(std)  # epsilon ~ N(0, I)
        return mu + std * eps

    def decode(self, z: Tensor) -> Tensor:
        """Reconstruct x from the latent variable z."""
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward pass: x -> (x_recon, mu, log_var)."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def loss_function(
        self,
        x_recon: Tensor,
        x: Tensor,
        mu: Tensor,
        log_var: Tensor,
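    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the negative ELBO, used as the training loss.

        ELBO = E[log p(x|z)] - beta * KL(q(z|x) || p(z))

        Where:
            Reconstruction term: binary cross-entropy, i.e. a Bernoulli
                likelihood with pixel intensities in [0, 1] as targets
                (MNIST pixels are grayscale, not strictly binary)
            KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

        Returns (total_loss, recon_loss, kl_loss), each summed over the batch.
        """
        # Reconstruction loss (binary cross-entropy)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction="sum")

        # KL divergence: D_KL(N(mu, sigma^2) || N(0, 1))
        # See the closed-form KL derivation in Section 1.3
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())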

        total_loss = recon_loss + self.config.beta * kl_loss

        return total_loss, recon_loss, kl_loss


# Create the config and model
config = VAEConfig(latent_dim=2, epochs=5)  # fewer epochs to keep the demo short
model = VAE(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

print(f"VAE parameter count: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Load the data
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST("./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST("./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

print(f"Training set: {len(train_dataset):,}, test set: {len(test_dataset):,}")

---

## 3. Training Loop

In [None]:
def train(model, loader, optimizer) -> Tuple[float, float, float]:
    model.train()
    total_loss = total_recon = total_kl = 0

    for x, _ in loader:
        x = x.view(-1, config.input_dim).to(device)

        x_recon, mu, log_var = model(x)
        loss, recon_loss, kl_loss = model.loss_function(x_recon, x, mu, log_var)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon_loss.item()
        total_kl += kl_loss.item()

    n = len(loader.dataset)
    return total_loss / n, total_recon / n, total_kl / n


@torch.no_grad()
def test(model, loader) -> float:
    model.eval()
    total_loss = 0

    for x, _ in loader:
        x = x.view(-1, config.input_dim).to(device)
        x_recon, mu, log_var = model(x)
        loss, _, _ = model.loss_function(x_recon, x, mu, log_var)
        total_loss += loss.item()

    return total_loss / len(loader.dataset)


# Train
print("Training the VAE...")
print("=" * 70)

history = {"train": [], "test": [], "recon": [], "kl": []}

for epoch in range(1, config.epochs + 1):
    train_loss, recon_loss, kl_loss = train(model, train_loader, optimizer)
    test_loss = test(model, test_loader)

    history["train"].append(train_loss)
    history["test"].append(test_loss)
    history["recon"].append(recon_loss)
    history["kl"].append(kl_loss)

    print(
        f"Epoch {epoch:2d}/{config.epochs} | "
        f"Loss: {train_loss:.4f} | "
        f"Recon: {recon_loss:.4f} | "
        f"KL: {kl_loss:.4f} | "
        f"Test: {test_loss:.4f}"
    )

print("=" * 70)
print("Training finished!")

In [None]:
# Plot the training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Left: total loss
axes[0].plot(history["train"], label="Train Loss", marker="o")
axes[0].plot(history["test"], label="Test Loss", marker="s")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Total Loss")
axes[0].set_title("VAE Training Loss")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Right: reconstruction vs. KL
axes[1].plot(history["recon"], label="Reconstruction", marker="o")
axes[1].plot(history["kl"], label="KL Divergence", marker="s")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Loss")
axes[1].set_title("Reconstruction vs KL Loss")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## 4. Visualizing the Results

In [None]:
@torch.no_grad()
def show_reconstruction(model, loader, n=10):
    """Show original images next to their reconstructions."""
    model.eval()
    x, _ = next(iter(loader))
    x = x[:n].to(device)

    x_flat = x.view(-1, config.input_dim)
    x_recon, _, _ = model(x_flat)
    x_recon = x_recon.view(-1, 1, 28, 28)

    fig, axes = plt.subplots(2, n, figsize=(15, 3))

    for i in range(n):
        axes[0, i].imshow(x[i].cpu().squeeze(), cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].imshow(x_recon[i].cpu().squeeze(), cmap="gray")
        axes[1, i].axis("off")

    axes[0, 0].set_title("Original", fontsize=12)
    axes[1, 0].set_title("Reconstructed", fontsize=12)
    plt.suptitle("VAE Reconstruction", fontsize=14)
    plt.tight_layout()
    plt.show()


show_reconstruction(model, test_loader)

In [None]:
@torch.no_grad()
def visualize_latent_space(model, loader):
    """Visualize the 2D latent space."""
    model.eval()

    z_all, labels_all = [], []
    for x, labels in loader:
        x = x.view(-1, config.input_dim).to(device)
        mu, _ = model.encode(x)
        z_all.append(mu.cpu())
        labels_all.append(labels)

    z = torch.cat(z_all).numpy()
    labels = torch.cat(labels_all).numpy()

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(z[:, 0], z[:, 1], c=labels, cmap="tab10", alpha=0.6, s=5)
    plt.colorbar(scatter, label="Digit")
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.title("VAE Latent Space (colored by digit)")
    plt.grid(True, alpha=0.3)
    plt.show()


if config.latent_dim == 2:
    visualize_latent_space(model, test_loader)

In [None]:
@torch.no_grad()
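def visualize_manifold(model, n=20, range_val=3):
    """Decode a regular grid of points in the 2D latent space into a manifold plot.

    This shows the continuous latent structure the VAE has learned.
    """
    if config.latent_dim != 2:
        print("Manifold visualization requires latent_dim=2")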
        return

    model.eval()

    grid_x = np.linspace(-range_val, range_val, n)
    grid_y = np.linspace(-range_val, range_val, n)[::-1]

    figure = np.zeros((28 * n, 28 * n))

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z = torch.tensor([[xi, yi]], dtype=torch.float32).to(device)
            x_decoded = model.decode(z).cpu().numpy().reshape(28, 28)
            figure[i * 28 : (i + 1) * 28, j * 28 : (j + 1) * 28] = x_decoded

    plt.figure(figsize=(12, 12))
    plt.imshow(figure, cmap="gray")
    plt.title("VAE Latent Manifold", fontsize=14)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")

    # Axis ticks
    ticks = np.linspace(0, 28 * n, 7)
    tick_labels = [f"{x:.1f}" for x in np.linspace(-range_val, range_val, 7)]
    plt.xticks(ticks, tick_labels)
    plt.yticks(ticks, tick_labels[::-1])

    plt.show()


visualize_manifold(model)

In [None]:
@torch.no_grad()
def generate_samples(model, n=20):
    """Generate new images by sampling z ~ N(0, I)."""
    model.eval()

    z = torch.randn(n, config.latent_dim).to(device)
    samples = model.decode(z).view(-1, 1, 28, 28)

    fig, axes = plt.subplots(2, n // 2, figsize=(15, 4))
    for i in range(n):
        row, col = i // (n // 2), i % (n // 2)
        axes[row, col].imshow(samples[i].cpu().squeeze(), cmap="gray")
        axes[row, col].axis("off")

    plt.suptitle("Generated Samples (z ~ N(0, I))", fontsize=14)
    plt.tight_layout()
    plt.show()


generate_samples(model)

---

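## 5. β-VAE: Learning Disentangled Representations ⭐⭐

### 5.1 Core Idea

**Problem**: The latent space of a standard VAE can be entangled: a single $z_i$ may influence several generative factors at once.

**β-VAE**: Increase the KL weight to $\beta > 1$. The stronger pull toward the factorized, isotropic Gaussian prior encourages (but does not guarantee) disentangled latent dimensions.

$$\mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}[\log p(x|z)] - \beta \cdot D_{KL}(q(z|x) \| p(z))$$

**Trade-offs**:
- $\beta = 1$: standard VAE
- $\beta > 1$: more disentangled, at the cost of reconstruction quality; a very large $\beta$ can push $q(z|x)$ onto the prior so that the latent code carries little information
- $\beta < 1$: better reconstructions, but the latent space is less regularized, so $q(z|x)$ can drift from the prior and samples drawn from $p(z)$ may decode poorly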
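
The trade-off can be made concrete on a toy model where the optimal $q$ is available in closed form. The sketch below assumes a hypothetical one-dimensional model, $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$, with $q = \mathcal{N}(m, s^2)$; setting the derivatives of the β-weighted objective to zero gives $m = x/(1+\beta)$ and $s^2 = \beta/(1+\beta)$. It is an illustration of the trend only, not a statement about MNIST results.

```python
import math


def optimal_q(x: float, beta: float):
    """Maximizer (m, s2) of E_q[log p(x|z)] - beta * KL(q || p) for the toy model."""
    return x / (1 + beta), beta / (1 + beta)


def kl_to_prior(m: float, s2: float) -> float:
    """KL(N(m, s2) || N(0, 1))."""
    return -0.5 * (1 + math.log(s2) - m * m - s2)


def expected_sq_error(x: float, m: float, s2: float) -> float:
    """E_q[(x - z)^2]: reconstruction error under the unit-variance toy decoder."""
    return (x - m) ** 2 + s2


def objective(x: float, m: float, s2: float, beta: float) -> float:
    """Beta-weighted ELBO, dropping constants that do not depend on (m, s2)."""
    return -0.5 * expected_sq_error(x, m, s2) - beta * kl_to_prior(m, s2)


x = 2.0
for beta in (0.1, 1.0, 4.0, 10.0):
    m, s2 = optimal_q(x, beta)
    print(
        f"beta={beta:>4}: m={m:.3f}, s2={s2:.3f}, "
        f"KL={kl_to_prior(m, s2):.3f}, E[(x-z)^2]={expected_sq_error(x, m, s2):.3f}"
    )
```

As $\beta$ grows, the optimal $q$ moves toward the prior (smaller KL) while the expected reconstruction error rises; at $\beta = 1$ it coincides with the exact posterior $\mathcal{N}(x/2, 1/2)$ of this toy model.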

In [None]:
def compare_beta_values() -> None:
    """Compare how different beta values affect the latent space."""
    betas = [0.1, 1.0, 4.0, 10.0]
    results = {}

    for beta in betas:
        cfg = VAEConfig(latent_dim=2, epochs=3, beta=beta)
        vae = VAE(cfg).to(device)
        opt = torch.optim.Adam(vae.parameters(), lr=cfg.lr)

        # Quick training run
        for _ in range(cfg.epochs):
            for x, _ in train_loader:
                x = x.view(-1, cfg.input_dim).to(device)
                x_recon, mu, log_var = vae(x)
                loss, _, _ = vae.loss_function(x_recon, x, mu, log_var)
                opt.zero_grad()
                loss.backward()
                opt.step()

        # Collect latent representations
        vae.eval()
        z_list, y_list = [], []
        with torch.no_grad():
            for x, y in test_loader:
                x = x.view(-1, cfg.input_dim).to(device)
                mu, _ = vae.encode(x)
                z_list.append(mu.cpu())
                y_list.append(y)
                if len(z_list) > 10:
                    break

        results[beta] = (torch.cat(z_list).numpy(), torch.cat(y_list).numpy())

    # Plot
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, beta in zip(axes, betas):
        z, y = results[beta]
        ax.scatter(z[:, 0], z[:, 1], c=y, cmap="tab10", alpha=0.6, s=10)
        ax.set_title(f"β = {beta}")
        ax.set_xlabel("z[0]")
        ax.set_ylabel("z[1]")

    plt.suptitle("β-VAE: effect of β on the latent space", fontsize=14)
    plt.tight_layout()
    plt.show()

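    print("\nTypical observations (your run may differ):")
    print("  β=0.1: latent codes spread out; good reconstructions but an irregular layout")
    print("  β=1.0: standard VAE, a balance between the two terms")
    print("  β=4.0: more compact latent space, may start to disentangle")
    print("  β=10.0: heavily compressed, possibly over-regularized")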


compare_beta_values()

---

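## 6. CNN-VAE: Stronger Feature Extraction ⭐⭐

Replacing the fully connected layers with convolutional layers lets the model capture the spatial structure of images more effectively.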
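
The layer shapes in a convolutional encoder and decoder follow from the standard output-size formulas. The helper functions below are a small sketch (hypothetical names, assuming dilation 1 and `output_padding=0`) that traces a 28×28 input through two stride-2 convolutions and back up through two stride-2 transposed convolutions.

```python
def conv2d_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a 2D convolution along one spatial axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a 2D transposed convolution along one axis (dilation 1, output_padding 0)."""
    return (n - 1) * stride - 2 * padding + kernel


# Encoder: kernel 3, stride 2, padding 1 halves the resolution at each layer.
h1 = conv2d_out(28, kernel=3, stride=2, padding=1)
h2 = conv2d_out(h1, kernel=3, stride=2, padding=1)

# Decoder: kernel 4, stride 2, padding 1 doubles the resolution exactly.
u1 = conv_transpose2d_out(h2, kernel=4, stride=2, padding=1)
u2 = conv_transpose2d_out(u1, kernel=4, stride=2, padding=1)

# With kernel 3 instead, the transposed convolution would fall one pixel short.
short = conv_transpose2d_out(h2, kernel=3, stride=2, padding=1)

print(f"encoder: 28 -> {h1} -> {h2}")
print(f"decoder: {h2} -> {u1} -> {u2}")
print(f"kernel-3 transposed conv from {h2}: {short}")
```

This is why the decoder in the class below pairs stride 2 with kernel size 4: the reconstruction returns to exactly 28×28 without needing `output_padding`.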

In [None]:
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder.

    Uses a CNN instead of an MLP to better capture the spatial structure of images.
    """

    def __init__(self, latent_dim: int = 32, beta: float = 1.0) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.beta = beta

        # Encoder: (1, 28, 28) -> (64, 7, 7) -> latent
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),  # -> (32, 14, 14)
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # -> (64, 7, 7)
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
        )

        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

        # Decoder: latent -> (64, 7, 7) -> (1, 28, 28)
        self.fc_decode = nn.Linear(latent_dim, 64 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # -> (32, 14, 14)
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),  # -> (1, 28, 28)
            nn.Sigmoid(),
        )

    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z: Tensor) -> Tensor:
        h = self.fc_decode(z).view(-1, 64, 7, 7)
        return self.decoder(h)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def loss_function(self, x_recon: Tensor, x: Tensor, mu: Tensor, log_var: Tensor) -> Tensor:
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction="sum")
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + self.beta * kl_loss


# Instantiate the CNN-VAE
conv_vae = ConvVAE(latent_dim=32).to(device)
print(f"CNN-VAE parameter count: {sum(p.numel() for p in conv_vae.parameters()):,}")

# Smoke-test the forward pass (torch.rand keeps inputs in [0, 1], like MNIST pixels)
x_test = torch.rand(4, 1, 28, 28).to(device)
x_recon, mu, log_var = conv_vae(x_test)
print(f"Input: {x_test.shape} -> reconstruction: {x_recon.shape}, latent: {mu.shape}")

---

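## 7. A Brief Introduction to VQ-VAE: Discrete Latent Spaces ⭐⭐⭐

### 7.1 Core Idea

**Problem**: A VAE with a continuous latent space and a powerful decoder can suffer from "posterior collapse", where the decoder learns to ignore $z$.

**VQ-VAE**: Use a discrete codebook $\mathcal{E} = \{e_k\}$ and quantize each continuous encoder output $z_e$ to its nearest code vector.

$$z_q = \text{argmin}_{e_k \in \mathcal{E}} \|z_e - e_k\|_2$$

**Loss function**:
$$\mathcal{L} = \|x - \hat{x}\|^2 + \|\text{sg}[z_e] - e\|^2 + \beta\|z_e - \text{sg}[e]\|^2$$

Here $e$ is the selected code vector and $\text{sg}[\cdot]$ is the stop-gradient operator. The second term (codebook loss) moves the code vector toward the encoder output; the third term (commitment loss) keeps the encoder output close to its code. This $\beta$ is the commitment weight (`commitment_cost` in the code below) and is unrelated to the $\beta$ of β-VAE.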
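
The nearest-neighbor lookup itself is simple enough to trace by hand. The sketch below uses plain Python lists and a hypothetical three-entry, two-dimensional codebook; it computes the numerical values of the two quantization terms but cannot express stop-gradient, since plain floats carry no gradients.

```python
def quantize(z_e, codebook):
    """Return (index, code vector) of the codebook entry nearest to z_e in squared L2 distance."""

    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

    k = min(range(len(codebook)), key=lambda i: sq_dist(z_e, codebook[i]))
    return k, codebook[k]


def vq_loss_terms(z_e, e, beta=0.25):
    """Values of the codebook term and the beta-weighted commitment term for one vector.

    Both terms share the same squared distance; in a real model they differ
    only in which argument receives gradients (sg[.] blocks the other one).
    """
    d = sum((zi - ei) ** 2 for zi, ei in zip(z_e, e))
    return d, beta * d


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]  # hypothetical codebook for illustration
z_e = [0.8, 1.3]                                  # a single encoder output vector
k, z_q = quantize(z_e, codebook)
codebook_term, commitment_term = vq_loss_terms(z_e, z_q)
print(f"nearest index: {k}, z_q: {z_q}")
print(f"codebook term: {codebook_term:.4f}, commitment term: {commitment_term:.4f}")
```

The default `beta=0.25` mirrors the `commitment_cost` default used in the `VectorQuantizer` class below.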

In [None]:
class VectorQuantizer(nn.Module):
    """Vector quantization layer (the core of VQ-VAE).

    Maps each continuous encoding to its nearest vector in a discrete codebook.
    """

    def __init__(
        self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25
    ) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        # Codebook
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Quantize the encoder output.

        Args:
            z: encoder output of shape (B, C, H, W)

        Returns:
            z_q: quantized tensor, same shape as z
            loss: VQ loss
            indices: codebook indices
        """
        # (B, C, H, W) -> (B, H, W, C) -> (B*H*W, C)
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flat = z.view(-1, self.embedding_dim)

        # Squared distance to every code vector: |z|^2 + |e|^2 - 2 z.e
        distances = (
            torch.sum(z_flat**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(z_flat, self.embedding.weight.t())
        )

        # Pick the nearest code vector
        indices = torch.argmin(distances, dim=1)
        z_q = self.embedding(indices).view(z.shape)

        # VQ loss
        codebook_loss = F.mse_loss(z_q, z.detach())  # moves code vectors toward encoder outputs
        commitment_loss = F.mse_loss(z, z_q.detach())  # keeps encoder outputs near their codes
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: forward pass uses z_q, backward pass copies gradients to z
        z_q = z + (z_q - z).detach()

        # Restore the layout: (B, H, W, C) -> (B, C, H, W)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, indices


# Smoke-test the VQ layer
vq = VectorQuantizer(num_embeddings=512, embedding_dim=64)
z_test = torch.randn(4, 64, 7, 7)
z_q, vq_loss, indices = vq(z_test)
print(f"VQ input: {z_test.shape} -> output: {z_q.shape}")
print(f"VQ loss: {vq_loss.item():.4f}")
print(f"Codebook usage: {len(indices.unique())}/{vq.num_embeddings}")

---

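## 8. Comparison of VAE Variants

| Variant | Latent space | Strengths | Weaknesses |
|:-----|:---------|:-----|:-----|
| **VAE** | Continuous Gaussian | Simple, end-to-end differentiable | Samples can be blurry |
| **β-VAE** | Continuous Gaussian | Disentangled representations | Lower reconstruction quality |
| **VQ-VAE** | Discrete codebook | Sharper outputs, avoids posterior collapse | Requires codebook management |
| **CVAE** | Conditional Gaussian | Controllable generation | Requires labels |

**Further reading**: VQ-VAE-2, DALL-E (dVAE), Hierarchical VAE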

---

## 9. Summary

| Component | Description |
|:-----|:-----|
| **ELBO** | $\mathbb{E}[\log p(x|z)] - D_{KL}(q \| p)$ |
| **Reparameterization** | $z = \mu + \sigma \cdot \epsilon$ |
| **Closed-form KL** | $-0.5\sum(1 + \log\sigma^2 - \mu^2 - \sigma^2)$ |
| **β-VAE** | $\beta > 1$ encourages disentanglement |
| **VQ-VAE** | Discrete codebook; avoids posterior collapse |
| **CNN-VAE** | Convolutional networks; better image features |