# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Hyperparameters.
batch_size = 64                            # samples per mini-batch
learning_rate_g = learning_rate_d = 0.001  # Adam learning rates (generator / discriminator)
epochs = 50                                # full passes over the training set
image_size = 32                            # images are resized to this size before training
latent_dim = 100                           # length of the generator's noise vector
channels = 3                               # number of image channels produced by the generator
num_classes = 10                           # number of classes in the CIFAR10 dataset

# Data pipeline: resize to image_size, convert to a float tensor in [0, 1], then
# map every channel to [-1, 1] so real images share the range of the generator's
# Tanh output.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # CIFAR10 images are RGB, so give one mean/std per channel explicitly.
    # A single-element tuple only works when torchvision broadcasts it across
    # channels; NOTE(review): older releases reportedly normalized just the
    # first channel in that case, so the explicit form is the safe one.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Downloads CIFAR10 into ./data on first use; shuffled mini-batches of batch_size.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 生成器
class Generator(nn.Module):
    """Class-conditional generator.

    A noise vector and a learned label embedding are concatenated, projected
    by a linear layer to a 256 x 4 x 4 feature map, and upsampled by three
    stride-2 transposed convolutions (4 -> 8 -> 16 -> 32) into an image with
    ``channels`` channels. The final Tanh keeps outputs in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # One learned vector of length num_classes per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        # Fully connected layer that lifts [noise, label embedding] to 256*4*4 features.
        self.fc = nn.Linear(latent_dim + num_classes, 4 * 4 * 256)

        # Reshape to a 4D tensor so the transposed convolutions can consume it.
        layers = [
            nn.Unflatten(1, (256, 4, 4)),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        ]
        # Two "upsample + normalize + activate" stages: 4x4 -> 8x8 -> 16x16.
        for in_ch, out_ch in ((256, 128), (128, 64)):
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))
        # Last upsample to 32x32 with the requested number of image channels.
        layers.append(nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor holding ``latent_dim`` values per sample, e.g. shaped
                (batch, latent_dim) or (batch, latent_dim, 1, 1).
            labels: integer class indices, one per sample.

        Returns:
            Tensor of shape (batch, channels, 32, 32) with values in [-1, 1].
        """
        batch = noise.size(0)
        flat_noise = noise.view(batch, latent_dim)
        label_vec = self.label_emb(labels)
        # Noise first, label embedding second, along the feature axis.
        conditioned = torch.cat((flat_noise, label_vec), 1)
        return self.model(self.fc(conditioned))

# 判别器
class Discriminator(nn.Module):
    """Class-conditional discriminator.

    The flattened image (3 * 32 * 32 values) and a learned label embedding are
    concatenated and projected by a linear layer to a 128 x 16 x 16 feature
    map. Two stride-2 convolutions reduce it to 4 x 4, and a final 4 x 4
    convolution followed by a Sigmoid produces one score in (0, 1) per sample.
    """

    def __init__(self):
        super().__init__()
        # One learned vector of length num_classes per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        # Output size of the linear layer is chosen to match the conv stack's input.
        self.fc = nn.Linear(3 * 32 * 32 + num_classes, 16 * 16 * 128)

        # Reshape to (128, 16, 16) so the convolutions can consume the features.
        layers = [
            nn.Unflatten(1, (128, 16, 16)),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two downsampling stages: 16x16 -> 8x8 -> 4x4.
        for in_ch, out_ch in ((128, 256), (256, 512)):
            layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Collapse the 4x4 map into a single value and squash it with a Sigmoid.
        layers.append(nn.Conv2d(512, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x, labels):
        """Score a batch of images given their class labels.

        Args:
            x: image batch holding 3 * 32 * 32 values per sample.
            labels: integer class indices, one per sample.

        Returns:
            1-D tensor with one Sigmoid score per sample.
        """
        flat_pixels = x.view(x.size(0), -1)
        label_vec = self.label_emb(labels)
        # Combine the image and its label along the feature axis.
        features = self.fc(torch.cat((flat_pixels, label_vec), -1))
        scores = self.model(features)
        return scores.view(-1)


# Instantiate both networks (generator first) and move them to the selected device.
netG, netD = Generator().to(device), Discriminator().to(device)

# Loss and optimizers: binary cross-entropy on the discriminator's Sigmoid
# output, and one Adam optimizer per network with its own learning rate.
criterion = nn.BCELoss()
optimizerD = optim.Adam(params=netD.parameters(), lr=learning_rate_d)
optimizerG = optim.Adam(params=netG.parameters(), lr=learning_rate_g)

# Training loop: for every mini-batch, take one discriminator step followed by
# one generator step.
for epoch in range(epochs):
    for i, (data, labels) in enumerate(train_loader):
        real_data, labels = data.to(device), labels.to(device)
        b_size = real_data.size(0)

        # ---------------------------------------------------------------
        # (1) Update the discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        # ---------------------------------------------------------------
        netD.zero_grad()

        # Real batch: target is 1 for every sample.
        label = torch.ones(b_size, dtype=torch.float, device=device)
        output = netD(real_data, labels)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Generated batch with randomly drawn class labels: target is 0.
        noise = torch.randn(b_size, latent_dim, device=device)
        fake_labels = torch.randint(0, num_classes, (b_size,), device=device)
        fake = netG(noise, fake_labels)
        label = torch.zeros_like(label)
        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach(), fake_labels)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        # Gradients of both halves have accumulated; apply them together.
        optimizerD.step()
        errD = errD_real + errD_fake

        # ---------------------------------------------------------------
        # (2) Update the generator: maximize log(D(G(z)))
        # ---------------------------------------------------------------
        netG.zero_grad()
        label = torch.ones_like(label)  # fake labels are real for generator cost
        # Re-score the same fakes with the just-updated discriminator.
        output = netD(fake, fake_labels)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Report losses every 50 batches only.
        if i % 50 != 0:
            continue
        stats = (epoch, epochs, i, len(train_loader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)
        print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % stats)


# Save the trained weights (state dicts only) of both networks to the working directory.
torch.save(obj=netG.state_dict(), f='generator.pth')
torch.save(obj=netD.state_dict(), f='discriminator.pth')


# Generate new images: num_images_per_class samples for every class, shown as
# a grid with one class per row.
num_images_per_class = 10  # number of images generated per class
fake_images = []

# Sample in eval mode so BatchNorm normalizes with its running statistics
# instead of the statistics of each small single-class batch (and does not
# update those running statistics). The previous mode is restored afterwards
# so any later training continues unaffected.
was_training = netG.training
netG.eval()
try:
    with torch.no_grad():
        for class_label in range(num_classes):
            # Every image in this batch is conditioned on the same class.
            labels = torch.full((num_images_per_class,), class_label,
                                dtype=torch.long, device=device)
            # Same (batch, latent_dim) noise layout as used during training.
            noise = torch.randn(num_images_per_class, latent_dim, device=device)
            fake_images.append(netG(noise, labels))
finally:
    netG.train(was_training)

# Merge all per-class batches into one tensor on the CPU for plotting.
fake_images = torch.cat(fake_images, dim=0).cpu()

# nrow is the number of images per grid row, so each row holds one class;
# normalize=True rescales pixel values into [0, 1] for display.
grid = vutils.make_grid(fake_images, nrow=num_images_per_class, padding=2, normalize=True)

# Show the generated images (channels-last layout for matplotlib).
plt.figure(figsize=(15, 15))
plt.axis("off")
plt.title("Generated Images")
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.show()