### A8: Generative Models 生成式模型
- 在上个章节中，我们介绍了自动编码机及其变种。`自动编码机(AutoEncoders)`通过学习y=x的恒等变换，在无监督训练中自动提取特征。自动编码机的变种，包括`稀疏自动编码机（SparseAEs）`、`去噪自动编码机（DenoisingAEs）`，通过训练技巧提升了编码的稀疏性和鲁棒性。这一类模型常常被用于特征提取和降维任务。但是对于生成任务而言，它们并不合适，原因在于原始数据分布在编码器复杂非线性变换的投射下未必是规则的，这不利于我们在生成采样时令解码器得到有意义的结果。
- 变分自动编码机（VAE），借鉴了贝叶斯方法中的变分推断技术，在训练模型学习恒等变换的同时通过引入对编码空间的KL散度进行正则化，使得编码空间在编码过程中保持良好的规范性，有利于在编码空间的采样与插值。
- VAE也可以实现一种有监督的变种：条件变分自动编码机（ConditionalVAE，CVAE），可引入标签数据，实现有标签的生成任务。

VAE的原理：
- 首先，它在AutoEncoder的基础上，假定编码器的结果是一个高斯分布，用于近似后验条件分布P(z|x)，前向传播时对它的均值与方差进行预测。
- VAE在训练时采用重参数化技巧，计算 z = mu + std * eps（其中eps从标准高斯分布N(0, I)中采样得到），使得梯度可以顺利回传至编码器。
- 使用重构误差（y与x的误差）和编码空间的KL散度损失同时对模型进行约束。

本实验工作内容如下：
- 以`CIFAR-10`和`MNIST`数据集为例，展示条件变分自动编码机（ConditionalVAE，CVAE）的训练和评估
- 提供了美丽的UI界面，用于展示潜在空间的连续性
- 复现了CVAE作为贝叶斯模型，在训练时遇到的“后验坍缩”问题

### 零、安装依赖

In [1]:
!pip install torch torchvision
!pip install numpy
!pip install matplotlib
!pip install tqdm


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip ins

### 一、导入模块

In [2]:
import os
from enum import Enum

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [8]:
class ConditionalConvVAE(nn.Module):
    """Conditional convolutional variational autoencoder (CVAE).

    The encoder is a stack of three stride-2 convolutions followed by two
    linear heads that predict the mean and log-variance of the latent
    Gaussian. The decoder is a linear layer followed by three stride-2
    transposed convolutions and a sigmoid. Class labels are embedded and fed
    to both the encoder (as extra input channels) and the decoder (appended
    to the latent vector).

    Args:
        potential_dim: Size of the latent vector z.
        channels: Number of image channels.
        num_classes: Number of distinct class labels.
        image_size: Height/width of the (square) input images. Must be a
            positive multiple of 8 because the encoder halves the spatial
            size three times. Defaults to 32.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """
    NAME = 'ConditionalConvVAE'

    def __init__(self, potential_dim, channels, num_classes=10, image_size=32):
        super().__init__()
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(
                f'image_size must be a positive multiple of 8, got {image_size}'
            )
        self.potential_dim = potential_dim
        self.channels = channels
        self.num_classes = num_classes
        self.image_size = image_size

        # Learned embedding of the class label (one num_classes-sized vector per class).
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        # Each stride-2 convolution halves the spatial size, so three of them
        # reduce image_size to image_size // 8 (4 for the default 32x32 input).
        feature_size = image_size // 8
        output_shape = (1024, feature_size, feature_size)
        output_dim = output_shape[0] * output_shape[1] * output_shape[2]

        # Encoder: image channels + label channels -> flattened feature vector.
        self.encoder = nn.Sequential(
            nn.Conv2d(channels + num_classes, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 1024, kernel_size=3, stride=2, padding=1),  # output: 1024 x (image_size/8) x (image_size/8)
            nn.ReLU(),
            nn.Flatten(),
        )

        self.enc_mu = nn.Linear(output_dim, potential_dim)       # mean of q(z|x, y)
        self.enc_log_var = nn.Linear(output_dim, potential_dim)  # log-variance of q(z|x, y)

        # Decoder: latent vector + label embedding -> image.
        self.decoder_fc = nn.Linear(potential_dim + num_classes, output_dim)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, output_shape),
            nn.ConvTranspose2d(1024, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, channels, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: channels x image_size x image_size
            nn.Sigmoid()
        )

    @staticmethod
    def _reparameterize(mu, log_var):
        """Sample z = mu + std * eps with eps ~ N(0, I) so gradients reach mu and log_var."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x, labels):
        """Encode images and labels into a sampled latent vector.

        Args:
            x: Image batch of shape (batch, channels, H, W).
            labels: Integer class labels of shape (batch,).

        Returns:
            Tuple ``(z, mu, log_var)``, each of shape (batch, potential_dim).
        """
        # Broadcast the label embedding over the spatial dimensions:
        # (batch, num_classes) -> (batch, num_classes, 1, 1) -> (batch, num_classes, H, W)
        labels = self.label_embedding(labels).unsqueeze(2).unsqueeze(3)
        labels = labels.expand(labels.size(0), labels.size(1), x.size(2), x.size(3))

        # Stack the label planes onto the image as extra channels.
        x = torch.cat((x, labels), dim=1)

        features = self.encoder(x)
        mu = self.enc_mu(features)
        log_var = self.enc_log_var(features)

        z = self._reparameterize(mu, log_var)
        return z, mu, log_var

    def decode(self, z, labels):
        """Decode latent vectors into images conditioned on class labels.

        Args:
            z: Latent vectors of shape (batch, potential_dim).
            labels: Integer class labels of shape (batch,).

        Returns:
            Image batch with values in [0, 1] (sigmoid output).
        """
        # Append the label embedding to the latent vector.
        labels = self.label_embedding(labels)
        z = torch.cat((z, labels), dim=1)

        x = self.decoder_fc(z)
        x = self.decoder(x)
        return x

    def forward(self, x, labels):
        """Run encode then decode; return ``(reconstructed_x, mu, log_var)``."""
        z, mu, log_var = self.encode(x, labels)
        reconstructed_x = self.decode(z, labels)
        return reconstructed_x, mu, log_var


# Loss function
def vae_loss(recon_x, x, mu, log_var):
    """Compute the VAE training loss: summed reconstruction error plus KL divergence.

    Args:
        recon_x: Reconstructed batch produced by the decoder.
        x: Original batch; must have the same number of elements as ``recon_x``.
        mu: Predicted latent means.
        log_var: Predicted latent log-variances.

    Returns:
        Tuple ``(loss, parts)`` where ``loss`` is the scalar tensor to
        back-propagate and ``parts`` is ``np.array([total, mse, kld])`` holding
        the same values as Python floats for logging.
    """
    # Reconstruction term: squared error summed over every element. Flattening
    # (instead of reshaping to a fixed 32*32 width) keeps this valid for any
    # image size and for non-contiguous tensors; the sum itself is unchanged.
    mse = nn.functional.mse_loss(
        torch.flatten(recon_x), torch.flatten(x), reduction='sum'
    )
    # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    loss = mse + kld
    return loss, np.array([loss.item(), mse.item(), kld.item()])

class DatasetType(Enum):
    """Datasets selectable for training; each member's value is its lowercase name."""
    cifar10 = 'cifar10'
    mnist = 'mnist'
    fashion_mnist = 'fashion_mnist'
    svhn = 'svhn'

### 三、模型训练

In [9]:
# Pick the training device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [10]:
# Training configuration
batch_size = 128
epochs = 50
potential_dim = 8
dataset_type = DatasetType.mnist  # choose the dataset to train on here

In [11]:
# Per-dataset checkpoint and loss-history locations.
MODEL_PATH = f'models/cvae_{dataset_type.value}.pth'
HISTORY_PATH = f'history/cvae_{dataset_type.value}_history.npy'
# MNIST and Fashion-MNIST are grayscale; the other datasets are RGB.
CHANNELS = 1 if dataset_type in (DatasetType.mnist, DatasetType.fashion_mnist) else 3

# Build the model and move it to the training device.
vae = ConditionalConvVAE(potential_dim=potential_dim, channels=CHANNELS)
vae.to(device)

# Resume from an earlier run when a compatible checkpoint exists.
running_losses = []
if os.path.exists(MODEL_PATH):
    try:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        vae.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    except RuntimeError as err:
        # The checkpoint does not match the current architecture (for example
        # a different potential_dim). Train from scratch, but say so instead
        # of failing silently.
        print(f'Ignoring incompatible checkpoint {MODEL_PATH}: {err}')
    else:
        # Only reuse the loss history when the weights were actually restored.
        if os.path.exists(HISTORY_PATH):
            running_losses = list(np.load(HISTORY_PATH))

In [12]:
# Resize every image to 32x32 and convert it to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Load (and download if necessary) the training split of the selected dataset.
if dataset_type == DatasetType.mnist:
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
elif dataset_type == DatasetType.cifar10:
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
elif dataset_type == DatasetType.fashion_mnist:
    train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
elif dataset_type == DatasetType.svhn:
    train_dataset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
else:
    raise NotImplementedError(f'Unsupported dataset: {dataset_type}')

optimizer = optim.Adam(vae.parameters(), lr=0.001)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Create the output directories up front; otherwise the first save would raise
# only after a whole epoch of training has already been spent.
for output_path in (MODEL_PATH, HISTORY_PATH):
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

p_bar = tqdm(range(epochs))
for epoch in p_bar:
    # Running mean of the per-batch [total, MSE, KLD] losses for this epoch.
    running_loss = np.array([0., 0., 0.])
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(data, labels)
        loss, losses = vae_loss(recon_batch, data, mu, log_var)
        loss.backward()

        optimizer.step()
        # Incremental mean update: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
        running_loss += 1/(batch_idx + 1) * (losses - running_loss)
        p_bar.set_postfix(progress=f'{(batch_idx+1)/len(train_loader)*100:.2f}%', totalLoss=f'{running_loss[0]:.3f}', MSELoss=f'{running_loss[1]:.3f}', KLDLoss=f'{running_loss[2]:.3f}')

    # Persist the history and weights after every epoch so training can resume.
    running_losses.append(running_loss)
    np.save(HISTORY_PATH, np.array(running_losses))
    torch.save(vae.state_dict(), MODEL_PATH)

  0%|          | 0/50 [00:00<?, ?it/s, KLDLoss=1166.958, MSELoss=1523.447, progress=2.56%, totalLoss=2690.405]

tensor([[ 0.6900,  0.1343, -0.2119,  ...,  0.5840, -1.0846, -1.1241],
        [-0.3916,  0.3265, -0.2579,  ..., -0.8676, -0.3159, -1.3930],
        [ 0.2934,  1.0896,  2.2991,  ...,  1.0957,  0.4492, -1.4079],
        ...,
        [ 0.5124,  0.5268, -0.3763,  ..., -0.0225,  0.0106, -1.1052],
        [ 0.2934,  1.0896,  2.2991,  ...,  1.0957,  0.4492, -1.4079],
        [ 0.5124,  0.5268, -0.3763,  ..., -0.0225,  0.0106, -1.1052]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.5112,  0.2828, -1.0068,  ...,  0.2114, -0.5222, -0.3816],
        [-0.9997,  0.4582,  0.2789,  ...,  0.3395, -0.6903,  0.3455],
        [ 0.2944,  1.0906,  2.2981,  ...,  1.0967,  0.4482, -1.4089],
        ...,
        [ 0.6890,  0.1353, -0.2129,  ...,  0.5830, -1.0836, -1.1231],
        [ 0.2944,  1.0906,  2.2981,  ...,  1.0967,  0.4482, -1.4089],
        [ 0.6890,  0.1353, -0.2129,  ...,  0.5830, -1.0836, -1.1231]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.9689,  1.4733, -

  0%|          | 0/50 [00:00<?, ?it/s, KLDLoss=1174.703, MSELoss=1521.132, progress=5.54%, totalLoss=2695.835]

tensor([[ 0.6870,  0.1365, -0.2158,  ...,  0.5839, -1.0859, -1.1254],
        [-0.9953,  0.4572,  0.2770,  ...,  0.3399, -0.6940,  0.3439],
        [ 0.6870,  0.1365, -0.2158,  ...,  0.5839, -1.0859, -1.1254],
        ...,
        [-0.9953,  0.4572,  0.2770,  ...,  0.3399, -0.6940,  0.3439],
        [ 0.2954, -0.5780,  0.0478,  ...,  0.6984, -1.4237,  0.3503],
        [ 0.2959,  1.0906,  2.2980,  ...,  1.0991,  0.4462, -1.4121]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.5136,  0.2841, -1.0094,  ...,  0.2107, -0.5279, -0.3794],
        [-0.3899,  0.3242, -0.2577,  ..., -0.8682, -0.3137, -1.3943],
        [ 0.5148,  0.5264, -0.3797,  ..., -0.0217,  0.0079, -1.0988],
        ...,
        [ 0.8395, -0.7254, -0.6327,  ...,  0.5298,  0.2522, -0.4445],
        [ 0.9695,  1.4755, -0.2629,  ..., -0.2572, -0.6133,  0.4683],
        [-0.9953,  0.4574,  0.2768,  ...,  0.3400, -0.6940,  0.3438]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2962, -0.5789,  

  0%|          | 0/50 [00:00<?, ?it/s, KLDLoss=1182.697, MSELoss=1507.867, progress=8.96%, totalLoss=2690.564]

tensor([[ 0.9692,  1.4734, -0.2585,  ..., -0.2603, -0.6098,  0.4708],
        [-0.9957,  0.4564,  0.2751,  ...,  0.3408, -0.6921,  0.3437],
        [ 0.5167,  0.5273, -0.3772,  ..., -0.0222,  0.0093, -1.1007],
        ...,
        [-0.5148,  0.2847, -1.0126,  ...,  0.2081, -0.5262, -0.3794],
        [ 0.2946, -0.5799,  0.0501,  ...,  0.6961, -1.4226,  0.3540],
        [ 0.8386, -0.7247, -0.6332,  ...,  0.5323,  0.2487, -0.4453]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2944, -0.5799,  0.0501,  ...,  0.6962, -1.4227,  0.3541],
        [-0.3934,  0.3254, -0.2612,  ..., -0.8688, -0.3165, -1.3923],
        [-0.9955,  0.4561,  0.2752,  ...,  0.3404, -0.6921,  0.3438],
        ...,
        [-0.9955,  0.4561,  0.2752,  ...,  0.3404, -0.6921,  0.3438],
        [ 0.5168,  0.5270, -0.3768,  ..., -0.0221,  0.0097, -1.1007],
        [ 0.2944, -0.5799,  0.0501,  ...,  0.6962, -1.4227,  0.3541]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.9955,  0.4557,  

  0%|          | 0/50 [00:00<?, ?it/s, KLDLoss=1179.559, MSELoss=1506.258, progress=12.15%, totalLoss=2685.817]

tensor([[-0.3935,  0.3263, -0.2584,  ..., -0.8669, -0.3151, -1.3956],
        [ 0.2965,  1.0892,  2.2981,  ...,  1.0974,  0.4444, -1.4104],
        [-0.4508,  0.5842, -0.4474,  ...,  1.8282,  0.8044, -0.1612],
        ...,
        [ 0.9688,  1.4738, -0.2596,  ..., -0.2627, -0.6097,  0.4714],
        [-0.4508,  0.5842, -0.4474,  ...,  1.8282,  0.8044, -0.1612],
        [ 0.8387, -0.7249, -0.6338,  ...,  0.5323,  0.2508, -0.4440]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.4508,  0.5845, -0.4476,  ...,  1.8282,  0.8042, -0.1612],
        [ 0.6863,  0.1349, -0.2126,  ...,  0.5841, -1.0846, -1.1239],
        [ 0.5163,  0.5262, -0.3771,  ..., -0.0221,  0.0131, -1.1007],
        ...,
        [-0.3936,  0.3264, -0.2586,  ..., -0.8667, -0.3155, -1.3955],
        [-0.5154,  0.2832, -1.0086,  ...,  0.2076, -0.5243, -0.3825],
        [ 0.2963,  1.0892,  2.2982,  ...,  1.0976,  0.4439, -1.4104]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.9969,  0.4548,  

  0%|          | 0/50 [00:01<?, ?it/s, KLDLoss=1178.865, MSELoss=1507.335, progress=15.14%, totalLoss=2686.201]

tensor([[-0.4523,  0.5851, -0.4475,  ...,  1.8280,  0.8048, -0.1605],
        [ 0.9680,  1.4719, -0.2585,  ..., -0.2619, -0.6070,  0.4700],
        [-0.5168,  0.2829, -1.0108,  ...,  0.2074, -0.5255, -0.3812],
        ...,
        [ 0.5146,  0.5272, -0.3749,  ..., -0.0226,  0.0106, -1.0991],
        [-0.9949,  0.4565,  0.2751,  ...,  0.3433, -0.6922,  0.3447],
        [ 0.9680,  1.4719, -0.2585,  ..., -0.2619, -0.6070,  0.4700]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2976,  1.0892,  2.2968,  ...,  1.0974,  0.4427, -1.4112],
        [ 0.2960, -0.5799,  0.0460,  ...,  0.6948, -1.4251,  0.3530],
        [-0.5166,  0.2828, -1.0106,  ...,  0.2072, -0.5251, -0.3811],
        ...,
        [-0.5166,  0.2828, -1.0106,  ...,  0.2072, -0.5251, -0.3811],
        [-0.5166,  0.2828, -1.0106,  ...,  0.2072, -0.5251, -0.3811],
        [-0.9947,  0.4562,  0.2751,  ...,  0.3429, -0.6920,  0.3447]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.5166,  0.2826, -

  0%|          | 0/50 [00:01<?, ?it/s, KLDLoss=1182.597, MSELoss=1508.721, progress=18.34%, totalLoss=2691.318]

tensor([[ 0.9676,  1.4750, -0.2602,  ..., -0.2596, -0.6105,  0.4684],
        [ 0.2943, -0.5808,  0.0454,  ...,  0.6973, -1.4230,  0.3541],
        [ 0.2989,  1.0892,  2.2984,  ...,  1.0963,  0.4442, -1.4112],
        ...,
        [ 0.2943, -0.5808,  0.0454,  ...,  0.6973, -1.4230,  0.3541],
        [-0.9923,  0.4549,  0.2748,  ...,  0.3433, -0.6910,  0.3441],
        [-0.3935,  0.3263, -0.2573,  ..., -0.8649, -0.3151, -1.3954]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2990,  1.0892,  2.2980,  ...,  1.0963,  0.4441, -1.4111],
        [-0.9924,  0.4547,  0.2753,  ...,  0.3432, -0.6906,  0.3444],
        [ 0.9674,  1.4751, -0.2599,  ..., -0.2594, -0.6106,  0.4683],
        ...,
        [ 0.5146,  0.5264, -0.3753,  ..., -0.0204,  0.0088, -1.0981],
        [ 0.9674,  1.4751, -0.2599,  ..., -0.2594, -0.6106,  0.4683],
        [-0.5143,  0.2799, -1.0098,  ...,  0.2048, -0.5212, -0.3813]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.9924,  0.4545,  

  0%|          | 0/50 [00:01<?, ?it/s, KLDLoss=1184.481, MSELoss=1511.078, progress=21.11%, totalLoss=2695.560]

tensor([[ 0.5157,  0.5257, -0.3783,  ..., -0.0206,  0.0085, -1.0980],
        [ 0.2952, -0.5797,  0.0457,  ...,  0.6968, -1.4228,  0.3542],
        [-0.4493,  0.5852, -0.4487,  ...,  1.8261,  0.8057, -0.1610],
        ...,
        [ 0.5157,  0.5257, -0.3783,  ..., -0.0206,  0.0085, -1.0980],
        [ 0.8390, -0.7253, -0.6350,  ...,  0.5335,  0.2471, -0.4432],
        [-0.4493,  0.5852, -0.4487,  ...,  1.8261,  0.8057, -0.1610]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2951, -0.5796,  0.0457,  ...,  0.6967, -1.4227,  0.3545],
        [ 0.6891,  0.1322, -0.2147,  ...,  0.5820, -1.0851, -1.1250],
        [ 0.6891,  0.1322, -0.2147,  ...,  0.5820, -1.0851, -1.1250],
        ...,
        [ 0.6891,  0.1322, -0.2147,  ...,  0.5820, -1.0851, -1.1250],
        [ 0.2951, -0.5796,  0.0457,  ...,  0.6967, -1.4227,  0.3545],
        [ 0.2956,  1.0889,  2.2996,  ...,  1.0985,  0.4444, -1.4107]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.5159,  0.5259, -

  0%|          | 0/50 [00:01<?, ?it/s, KLDLoss=1186.368, MSELoss=1513.074, progress=24.09%, totalLoss=2699.443]

tensor([[-0.5125,  0.2823, -1.0058,  ...,  0.2081, -0.5262, -0.3804],
        [-0.4463,  0.5850, -0.4500,  ...,  1.8269,  0.8042, -0.1623],
        [ 0.5180,  0.5257, -0.3780,  ..., -0.0221,  0.0113, -1.1014],
        ...,
        [ 0.2968, -0.5788,  0.0446,  ...,  0.6997, -1.4243,  0.3534],
        [-0.5125,  0.2823, -1.0058,  ...,  0.2081, -0.5262, -0.3804],
        [-0.4463,  0.5850, -0.4500,  ...,  1.8269,  0.8042, -0.1623]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.5178,  0.5257, -0.3780,  ..., -0.0225,  0.0115, -1.1013],
        [-0.4461,  0.5849, -0.4498,  ...,  1.8268,  0.8042, -0.1622],
        [ 0.9666,  1.4768, -0.2634,  ..., -0.2588, -0.6141,  0.4682],
        ...,
        [-0.3919,  0.3210, -0.2575,  ..., -0.8700, -0.3115, -1.3958],
        [ 0.2970, -0.5786,  0.0447,  ...,  0.6999, -1.4244,  0.3532],
        [ 0.6884,  0.1329, -0.2120,  ...,  0.5851, -1.0853, -1.1207]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2936,  1.0874,  

  0%|          | 0/50 [00:01<?, ?it/s, KLDLoss=1186.776, MSELoss=1514.105, progress=27.08%, totalLoss=2700.882]

tensor([[-0.5110,  0.2829, -1.0048,  ...,  0.2109, -0.5267, -0.3808],
        [-0.3929,  0.3218, -0.2581,  ..., -0.8697, -0.3132, -1.3955],
        [ 0.6871,  0.1331, -0.2119,  ...,  0.5847, -1.0844, -1.1215],
        ...,
        [ 0.8350, -0.7241, -0.6343,  ...,  0.5340,  0.2494, -0.4448],
        [ 0.5169,  0.5246, -0.3785,  ..., -0.0238,  0.0125, -1.0998],
        [ 0.2947,  1.0858,  2.2988,  ...,  1.0953,  0.4455, -1.4082]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.6870,  0.1332, -0.2119,  ...,  0.5848, -1.0844, -1.1219],
        [ 0.9664,  1.4752, -0.2663,  ..., -0.2566, -0.6165,  0.4652],
        [ 0.5167,  0.5243, -0.3781,  ..., -0.0237,  0.0125, -1.0998],
        ...,
        [-0.5107,  0.2828, -1.0048,  ...,  0.2109, -0.5266, -0.3810],
        [ 0.5167,  0.5243, -0.3781,  ..., -0.0237,  0.0125, -1.0998],
        [ 0.2948,  1.0858,  2.2987,  ...,  1.0956,  0.4454, -1.4084]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.3923,  0.3216, -

  0%|          | 0/50 [00:02<?, ?it/s, KLDLoss=1187.436, MSELoss=1512.426, progress=30.28%, totalLoss=2699.862]

tensor([[ 0.8335, -0.7222, -0.6347,  ...,  0.5332,  0.2473, -0.4436],
        [ 0.2925, -0.5803,  0.0438,  ...,  0.6973, -1.4235,  0.3559],
        [ 0.9646,  1.4719, -0.2639,  ..., -0.2594, -0.6143,  0.4661],
        ...,
        [ 0.5177,  0.5273, -0.3755,  ..., -0.0230,  0.0153, -1.1000],
        [ 0.5177,  0.5273, -0.3755,  ..., -0.0230,  0.0153, -1.1000],
        [ 0.2925, -0.5803,  0.0438,  ...,  0.6973, -1.4235,  0.3559]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.9978,  0.4536,  0.2754,  ...,  0.3411, -0.6900,  0.3438],
        [ 0.9645,  1.4719, -0.2641,  ..., -0.2595, -0.6145,  0.4660],
        [-0.3931,  0.3203, -0.2550,  ..., -0.8677, -0.3139, -1.3924],
        ...,
        [-0.3931,  0.3203, -0.2550,  ..., -0.8677, -0.3139, -1.3924],
        [-0.9978,  0.4536,  0.2754,  ...,  0.3411, -0.6900,  0.3438],
        [-0.3931,  0.3203, -0.2550,  ..., -0.8677, -0.3139, -1.3924]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.8335, -0.7225, -

  0%|          | 0/50 [00:02<?, ?it/s, KLDLoss=1186.899, MSELoss=1509.596, progress=33.48%, totalLoss=2696.495]

tensor([[ 0.9644,  1.4716, -0.2643,  ..., -0.2601, -0.6134,  0.4654],
        [ 0.5194,  0.5272, -0.3723,  ..., -0.0232,  0.0128, -1.0986],
        [-0.9978,  0.4553,  0.2734,  ...,  0.3418, -0.6898,  0.3429],
        ...,
        [-0.4465,  0.5827, -0.4443,  ...,  1.8276,  0.8039, -0.1600],
        [-0.9978,  0.4553,  0.2734,  ...,  0.3418, -0.6898,  0.3429],
        [-0.5145,  0.2855, -1.0112,  ...,  0.2097, -0.5306, -0.3821]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.8366, -0.7231, -0.6337,  ...,  0.5293,  0.2500, -0.4435],
        [ 0.2979,  1.0855,  2.3001,  ...,  1.0952,  0.4467, -1.4069],
        [-0.9982,  0.4555,  0.2732,  ...,  0.3417, -0.6902,  0.3429],
        ...,
        [-0.9982,  0.4555,  0.2732,  ...,  0.3417, -0.6902,  0.3429],
        [-0.3933,  0.3212, -0.2560,  ..., -0.8660, -0.3141, -1.3939],
        [ 0.9644,  1.4719, -0.2643,  ..., -0.2606, -0.6133,  0.4653]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2948, -0.5808,  

  0%|          | 0/50 [00:02<?, ?it/s, KLDLoss=1186.958, MSELoss=1508.833, progress=36.67%, totalLoss=2695.791]

tensor([[ 0.3010,  1.0845,  2.3001,  ...,  1.0961,  0.4470, -1.4070],
        [-0.4495,  0.5813, -0.4444,  ...,  1.8295,  0.8051, -0.1601],
        [ 0.2944, -0.5780,  0.0461,  ...,  0.6989, -1.4215,  0.3535],
        ...,
        [ 0.9636,  1.4739, -0.2621,  ..., -0.2594, -0.6115,  0.4652],
        [ 0.8350, -0.7232, -0.6336,  ...,  0.5329,  0.2469, -0.4443],
        [ 0.6901,  0.1296, -0.2147,  ...,  0.5798, -1.0805, -1.1245]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.9964,  0.4546,  0.2736,  ...,  0.3388, -0.6898,  0.3421],
        [ 0.6901,  0.1300, -0.2149,  ...,  0.5801, -1.0806, -1.1245],
        [ 0.2941, -0.5780,  0.0463,  ...,  0.6987, -1.4215,  0.3538],
        ...,
        [ 0.5202,  0.5264, -0.3726,  ..., -0.0242,  0.0134, -1.0975],
        [ 0.8350, -0.7234, -0.6336,  ...,  0.5330,  0.2471, -0.4443],
        [ 0.3012,  1.0848,  2.2999,  ...,  1.0964,  0.4466, -1.4074]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2939, -0.5781,  

  0%|          | 0/50 [00:02<?, ?it/s, KLDLoss=1187.326, MSELoss=1509.928, progress=39.87%, totalLoss=2697.254]

tensor([[-0.3958,  0.3227, -0.2559,  ..., -0.8638, -0.3159, -1.3924],
        [-0.5137,  0.2829, -1.0120,  ...,  0.2125, -0.5265, -0.3833],
        [ 0.6890,  0.1302, -0.2159,  ...,  0.5805, -1.0802, -1.1236],
        ...,
        [-0.9974,  0.4547,  0.2740,  ...,  0.3398, -0.6909,  0.3437],
        [ 0.2955, -0.5783,  0.0472,  ...,  0.6974, -1.4204,  0.3550],
        [-0.9974,  0.4547,  0.2740,  ...,  0.3398, -0.6909,  0.3437]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.9975,  0.4547,  0.2739,  ...,  0.3400, -0.6911,  0.3435],
        [ 0.5209,  0.5248, -0.3725,  ..., -0.0248,  0.0125, -1.0999],
        [-0.5140,  0.2831, -1.0124,  ...,  0.2125, -0.5268, -0.3834],
        ...,
        [-0.9975,  0.4547,  0.2739,  ...,  0.3400, -0.6911,  0.3435],
        [ 0.6891,  0.1303, -0.2158,  ...,  0.5807, -1.0804, -1.1236],
        [-0.3956,  0.3229, -0.2564,  ..., -0.8636, -0.3162, -1.3924]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.5144,  0.2833, -

  0%|          | 0/50 [00:02<?, ?it/s, KLDLoss=1186.091, MSELoss=1510.353, progress=43.28%, totalLoss=2696.444]

tensor([[-0.3934,  0.3233, -0.2576,  ..., -0.8626, -0.3157, -1.3931],
        [-0.9965,  0.4517,  0.2750,  ...,  0.3388, -0.6901,  0.3429],
        [-0.3934,  0.3233, -0.2576,  ..., -0.8626, -0.3157, -1.3931],
        ...,
        [ 0.2951, -0.5774,  0.0459,  ...,  0.6974, -1.4213,  0.3534],
        [ 0.2951, -0.5774,  0.0459,  ...,  0.6974, -1.4213,  0.3534],
        [-0.4462,  0.5820, -0.4448,  ...,  1.8268,  0.8034, -0.1611]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.4459,  0.5820, -0.4447,  ...,  1.8269,  0.8032, -0.1612],
        [ 0.9649,  1.4745, -0.2652,  ..., -0.2595, -0.6117,  0.4648],
        [ 0.8385, -0.7262, -0.6323,  ...,  0.5328,  0.2519, -0.4427],
        ...,
        [ 0.6866,  0.1316, -0.2154,  ...,  0.5830, -1.0824, -1.1211],
        [-0.3932,  0.3234, -0.2574,  ..., -0.8625, -0.3156, -1.3932],
        [-0.9960,  0.4515,  0.2748,  ...,  0.3388, -0.6901,  0.3427]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2980,  1.0839,  

  0%|          | 0/50 [00:03<?, ?it/s, KLDLoss=1186.085, MSELoss=1509.482, progress=46.27%, totalLoss=2695.567]

tensor([[ 0.9643,  1.4759, -0.2622,  ..., -0.2619, -0.6123,  0.4663],
        [-0.3943,  0.3212, -0.2562,  ..., -0.8637, -0.3147, -1.3903],
        [ 0.6875,  0.1309, -0.2129,  ...,  0.5824, -1.0800, -1.1213],
        ...,
        [-0.9945,  0.4521,  0.2723,  ...,  0.3408, -0.6915,  0.3422],
        [ 0.8386, -0.7270, -0.6297,  ...,  0.5324,  0.2531, -0.4404],
        [-0.4461,  0.5805, -0.4442,  ...,  1.8279,  0.8044, -0.1596]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[-0.9946,  0.4522,  0.2722,  ...,  0.3409, -0.6916,  0.3424],
        [-0.9946,  0.4522,  0.2722,  ...,  0.3409, -0.6916,  0.3424],
        [ 0.2945, -0.5765,  0.0444,  ...,  0.6985, -1.4216,  0.3526],
        ...,
        [ 0.5208,  0.5255, -0.3760,  ..., -0.0202,  0.0119, -1.0995],
        [ 0.2958,  1.0880,  2.2964,  ...,  1.0988,  0.4444, -1.4063],
        [-0.4463,  0.5803, -0.4443,  ...,  1.8277,  0.8043, -0.1595]],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor([[ 0.2946, -0.5763,  




KeyboardInterrupt: 

### 四、绘制损失曲线

In [None]:
# Plot the per-epoch loss history. Each entry of running_losses holds
# [total, MSE, KLD] for one completed epoch.
if running_losses:
    total_loss, mse_loss, kld_loss = np.array(running_losses).T
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(total_loss, label='Training Loss')
    plt.plot(mse_loss, label='Training MSE Loss')
    plt.plot(kld_loss, label='Training KLD Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
else:
    # No epoch has finished yet (e.g. training was interrupted during the
    # first epoch), so there is nothing to unpack or draw.
    print('No completed epoch in the loss history; nothing to plot.')

### 五、模型推理

In [None]:
def generate(model, labels, device=None):
    """Sample one image per label from the model's decoder.

    Args:
        model: Object exposing ``potential_dim`` and ``decode(z, labels)``.
        labels: Sequence, array or tensor of integer class labels.
        device: Device to sample on. When ``None`` (the default) the device
            of the model's first parameter is used, falling back to the CPU
            for a parameter-free model, so the call also works on machines
            without CUDA.

    Returns:
        The tensor returned by ``model.decode``, computed without gradients.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # as_tensor accepts lists, NumPy arrays and tensors (including GPU tensors).
    labels = torch.as_tensor(labels, dtype=torch.long, device=device)
    with torch.no_grad():
        # Draw latent vectors from the standard normal prior.
        z = torch.randn(labels.size(0), model.potential_dim).to(device)
        # Decode the latent vectors into images.
        generated_images = model.decode(z, labels)
    return generated_images

def plot(images):
    """Show a batch of images in a grid that is 8 columns wide.

    Args:
        images: Tensor of shape (batch, channels, height, width). Single-channel
            images are drawn in grayscale; multi-channel images are drawn as
            colour images.
    """
    channels = images.size(1)
    columns = 8
    # Keep the original 8x8 layout for up to 64 images and add rows beyond that
    # so larger batches no longer overflow the grid.
    rows = max(8, -(-len(images) // columns))
    plt.figure(figsize=(8, rows))
    for i, image in enumerate(images):
        plt.subplot(rows, columns, i + 1)
        img = image.detach().cpu().numpy()
        if channels == 1:
            # Drop the channel axis: (1, H, W) -> (H, W).
            img = img[0]
        else:
            # imshow expects channels last: (C, H, W) -> (H, W, C). The axis
            # order (1, 2, 0) keeps height and width in place.
            img = np.transpose(img, (1, 2, 0))

        plt.imshow(img, cmap='gray' if channels == 1 else None)
        plt.axis('off')
    plt.show()

In [None]:
# Generate one sample for each of the ten class labels 0-9 and display them.
images = generate(vae, list(range(10)), device)
plot(images)

In [None]:
!python cvae_visualizer.py --mnist
# !python cvae_visualizer.py --cifar10
# !python cvae_visualizer.py --fashion_mnist
# !python cvae_visualizer.py --svhn