### kld

```python
kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
```

-  $D_{KL}(Q(z|X) \,\|\, P(z))$, where $P(z) = \mathcal{N}(0, I)$
    -  $-\frac12\sum (1+\log\sigma^2-\mu^2-\sigma^2)$
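    -  Diagonal Gaussian: to simplify the computation, we make an important assumption. The dimensions of the latent space are mutually independent (given $X$), so the covariance matrix of this multivariate Gaussian is diagonal: it has nonzero entries only on the diagonal and zeros everywhere else.
    -  $D_{KL}\left( \mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\right)$
        -  For two $d$-dimensional Gaussians $P_1 = \mathcal{N}(\mu_1, \Sigma_1)$ and $P_2 = \mathcal{N}(\mu_2, \Sigma_2)$:
        -  $D_{KL}(P_1 \,\|\, P_2) = \frac{1}{2} \left( \text{tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_2 - \mu_1)^T \Sigma_2^{-1} (\mu_2 - \mu_1) - d + \ln\left(\frac{\det \Sigma_2}{\det \Sigma_1}\right) \right)$
        - Here $\mu_1 = \mu$, $\Sigma_1 = \operatorname{diag}(\sigma_1^2, \dots, \sigma_d^2)$, $\mu_2 = 0$, $\Sigma_2 = I$.
        - Substituting term by term: $\text{tr}(\Sigma_2^{-1}\Sigma_1) = \sum_{i} \sigma_i^2$, $(\mu_2 - \mu_1)^T \Sigma_2^{-1} (\mu_2 - \mu_1) = \sum_{i} \mu_i^2$, $d = \sum_{i} 1$, and $\ln\frac{\det \Sigma_2}{\det \Sigma_1} = -\sum_{i} \ln \sigma_i^2$.
        - $D_{KL} = \frac{1}{2} \sum_{i=1}^{d} (\sigma_i^2 + \mu_i^2 - 1 - \ln\sigma_i^2)$, which is the expression above with the minus sign pulled outside. In code, $\ln\sigma_i^2$ is `log_var`.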
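To sanity-check the closed form, the sketch below compares the expression from the code above against PyTorch's `torch.distributions.kl_divergence` applied to univariate `Normal` distributions and summed over the latent dimensions. The batch size, latent size and random seed are arbitrary choices for illustration.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch_size, latent_dim = 4, 3
mu = torch.randn(batch_size, latent_dim)
log_var = torch.randn(batch_size, latent_dim)

# Closed form from the notes, summed over latent dims -> shape (batch_size,)
closed_form = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)

# Reference: KL between univariate Normals for each latent dim, then summed.
# Summing is valid because a diagonal Gaussian factorizes into independent 1-D Gaussians.
std = torch.exp(0.5 * log_var)
q = Normal(mu, std)
p = Normal(torch.zeros_like(mu), torch.ones_like(std))
reference = kl_divergence(q, p).sum(dim=1)

print(torch.allclose(closed_form, reference, atol=1e-5))  # True
```

The agreement relies on the diagonal-covariance assumption: with a full covariance the per-dimension sum would no longer equal the multivariate KL.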

### reparameterization trick

```python
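import torch
from torch import Tensor

def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    """
    Sample z ~ N(mu, sigma^2) as z = mu + sigma * eps, with eps ~ N(0, I).
    :param mu: (Tensor) Mean of the latent Gaussian
    :param logvar: (Tensor) Log-variance log(sigma^2) of the latent Gaussian
    :return: (Tensor) Sampled latent z, same shape as mu
    """
    std = torch.exp(0.5 * logvar)  # sigma = exp(0.5 * log(sigma^2))
    eps = torch.randn_like(std)    # eps ~ N(0, I), same shape as std
    return eps * std + mu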
```
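Sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly is not differentiable with respect to $\mu$ and $\sigma$. Writing $z = \mu + \sigma \cdot \epsilon$ moves all of the randomness into $\epsilon$, which does not depend on the parameters, so gradients can flow back to `mu` and `logvar`. The standalone sketch below (no `self`, with arbitrary shapes and a plain `z.sum()` standing in for a real downstream loss) checks the gradients by hand.

```python
import torch

torch.manual_seed(0)
mu = torch.zeros(2, 3, requires_grad=True)
logvar = torch.zeros(2, 3, requires_grad=True)

# z = mu + eps * sigma: eps is drawn without gradient tracking
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = eps * std + mu

# Stand-in for a downstream loss
z.sum().backward()

# dz/dmu = 1, so mu.grad is all ones
print(mu.grad)
# dz/dlogvar = eps * 0.5 * exp(0.5 * logvar), which is 0.5 * eps at logvar = 0
print(torch.allclose(logvar.grad, 0.5 * eps))
```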

### loss scale??

```python
reconstruction_loss = F.mse_loss(x_hat, x, reduction='mean')
```
$$
L_{\text{MSE}} = \frac{1}{B \cdot C \cdot H \cdot W} \sum_{b=1}^{B} \sum_{c=1}^{C} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{b, c, h, w} - \hat{x}_{b, c, h, w})^2
$$

- kld
    - Sum over the latent dimensions first, then average over the batch dimension. This gives the average KL divergence per data point.

```python
kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1) # sum over the latent dimensions of each data point
kl_loss = torch.mean(kl_loss) # average over the batch dimension
```
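The order of the reductions sets the scale of the KL term. The sketch below, with arbitrary sizes and seed, checks that "sum over latent dims, then mean over batch" equals the total sum divided by the batch size, and that taking a plain mean over all elements would shrink the term by a factor of `latent_dim`.

```python
import torch

torch.manual_seed(0)
batch_size, latent_dim = 8, 10
mu = torch.randn(batch_size, latent_dim)
log_var = torch.randn(batch_size, latent_dim)

# Per-element KL terms, shape (batch_size, latent_dim)
kl_terms = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())

per_datapoint = kl_terms.sum(dim=1).mean()       # sum over latent dims, mean over batch
total_over_batch = kl_terms.sum() / batch_size   # same quantity, different order
mean_over_everything = kl_terms.mean()           # mean over both dims

print(torch.allclose(per_datapoint, total_over_batch))                    # True
print(torch.allclose(per_datapoint, mean_over_everything * latent_dim))  # True
```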

```python
import torch
import torch.nn.functional as F

# Define the input tensors
# Shape (B, C, H, W) = (2, 1, 1, 2)
x = torch.tensor([[[[1.0, 2.0]]],
                  [[[3.0, 4.0]]]], dtype=torch.float32)

x_hat = torch.tensor([[[[1.5, 2.5]]],
                      [[[2.0, 5.0]]]], dtype=torch.float32)
```

```python
x.shape
# torch.Size([2, 1, 1, 2])

F.mse_loss(x_hat, x, reduction='mean')
# tensor(0.6250)

# Same value by hand: sum of squared errors divided by the number of elements (B*C*H*W)
((x_hat - x) ** 2).sum() / x.numel()
# tensor(0.6250)
```

```python
import torch
import torch.nn.functional as F

def vae_loss_function(x_hat, x, mu, log_var, beta=1.0):
    # 1. Reconstruction loss: squared error averaged over every element (B*C*H*W)
    recon_loss = F.mse_loss(x_hat, x, reduction='mean')

    # 2. KL divergence loss
    # D_KL(q(z|x) || p(z)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # First sum over the latent dimensions of each data point (dim=1),
    # then average over the batch dimension
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    kl_loss = torch.mean(kl_div, dim=0)

    # 3. Total loss
    total_loss = recon_loss + beta * kl_loss

    return total_loss, recon_loss, kl_loss
```

```python
batch_size = 64
latent_dim = 10
img_channels = 1
# img_size = 28
img_size = 64*64  # side length 4096: 4096 x 4096 images, about 4 GiB per float32 tensor at this batch size

# Simulate model outputs
x = torch.rand(batch_size, img_channels, img_size, img_size)
x_hat = torch.rand(batch_size, img_channels, img_size, img_size)
mu = torch.randn(batch_size, latent_dim)
log_var = torch.randn(batch_size, latent_dim)

# Set beta (start with a small value, or use KL annealing)
beta_value = 1

total_loss, recon_loss, kl_loss = vae_loss_function(x_hat, x, mu, log_var, beta=beta_value)

print(f"Beta: {beta_value}")
print(f"Reconstruction Loss (per-pixel mean): {recon_loss.item():.4f}")
print(f"KL Loss (per-datapoint mean): {kl_loss.item():.4f}")
print(f"Weighted KL Loss: {(beta_value * kl_loss).item():.4f}")
print(f"Total Loss: {total_loss.item():.4f}")
```

```
Beta: 1
Reconstruction Loss (per-pixel mean): 0.1667
KL Loss (per-datapoint mean): 8.3359
Weighted KL Loss: 8.3359
Total Loss: 8.5025
```
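The reconstruction value matches what random inputs should give. For independent $x, \hat{x} \sim \text{Uniform}(0, 1)$, $\mathbb{E}[(x - \hat{x})^2] = \tfrac{1}{12} + \tfrac{1}{12} = \tfrac{1}{6} \approx 0.1667$, regardless of image size. With `reduction='mean'` the reconstruction term is a per-pixel average, while the KL term is a per-data-point sum over latent dimensions, so the KL dominates (8.3359 vs 0.1667 above). Up to an overall factor, this is equivalent to the per-sample loss with $\beta$ multiplied by $C \cdot H \cdot W$. One common alternative, shown below as an assumption and not what these notes do, sums the squared error over pixels as well, so that both terms are per-sample quantities. Up to an additive constant, the summed squared error is twice the negative log-likelihood of a Gaussian decoder with unit variance. The function name `vae_loss_per_sample` and the sizes are hypothetical.

```python
import torch
import torch.nn.functional as F

def vae_loss_per_sample(x_hat, x, mu, log_var, beta=1.0):
    """Hypothetical variant: both terms are per data point, averaged over the batch."""
    batch_size = x.size(0)
    # Sum squared error over all pixels, divide by batch size -> per-sample reconstruction error
    recon_loss = F.mse_loss(x_hat, x, reduction='sum') / batch_size
    # Sum over latent dims, mean over batch -> per-sample KL (as in vae_loss_function)
    kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
    return recon_loss + beta * kl_loss, recon_loss, kl_loss

torch.manual_seed(0)
batch_size, latent_dim, img_size = 16, 10, 28
x = torch.rand(batch_size, 1, img_size, img_size)
x_hat = torch.rand(batch_size, 1, img_size, img_size)
mu = torch.randn(batch_size, latent_dim)
log_var = torch.randn(batch_size, latent_dim)

total, recon, kl = vae_loss_per_sample(x_hat, x, mu, log_var)
pixels_per_image = x[0].numel()  # C * H * W
# Per-sample reconstruction = per-pixel mean * pixels per image
print(torch.allclose(recon, F.mse_loss(x_hat, x, reduction='mean') * pixels_per_image))
print(f"recon: {recon.item():.2f}, kl: {kl.item():.2f}")
```

With this scaling the reconstruction term grows with the number of pixels, so $\beta$ values tuned for one reduction do not transfer directly to the other.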
