In [1]:
import torch
import torch.nn as nn
from torch import optim
# Hyper-parameters shared by the model and training cells below.
num_classes = 10     # number of label classes (label-embedding table size, classifier width)
batch_size = 128     # nominal mini-batch size
embedding_dim = 64   # width of the learned label embedding
z_dim = 100          # width of the generator's input noise vector
img_size = 28        # side length of the square, single-channel images

# Seed PyTorch's global RNG so weight initialisation, noise sampling and
# dropout masks are reproducible across "Restart & Run All".
seed = 0
torch.manual_seed(seed)

In [None]:
# Discriminator with an auxiliary classifier head.
class Discriminator(nn.Module):
    """Label-conditioned discriminator with an adversarial and an auxiliary head.

    The flattened image is concatenated with a learned embedding of its label
    and passed through a shared MLP trunk (LeakyReLU + dropout) that yields a
    256-d feature vector. Two linear heads read those features:

    * ``validity``   -- one score per sample, passed through a sigmoid, so it
      is meant to be paired with ``nn.BCELoss`` (not ``BCEWithLogitsLoss``);
    * ``classifier`` -- ``num_classes`` raw logits for ``nn.CrossEntropyLoss``.

    NOTE(review): the true label is also an *input* here, so the auxiliary
    classifier can learn to read it back from the embedding instead of the
    image pixels. Confirm that both conditioning and an auxiliary classifier
    are intended.
    """

    def __init__(self):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        self.features = nn.Sequential(
            nn.Linear(img_size * img_size + embedding_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
        )
        self.validity = nn.Linear(256, 1)  # adversarial (real vs. fake) head
        self.classifier = nn.Linear(256, num_classes)  # auxiliary class-logit head

        self.sigmoid = nn.Sigmoid()

    def forward(self, img, labels):
        """Score a batch of (image, label) pairs.

        Args:
            img: image batch with ``img_size * img_size`` elements per sample,
                e.g. ``(N, 1, img_size, img_size)``.
            labels: integer class indices of shape ``(N,)``.

        Returns:
            ``(validity, class_output)``: sigmoid scores of shape ``(N, 1)``
            and raw class logits of shape ``(N, num_classes)``.
        """
        # flatten() (unlike view()) also accepts non-contiguous batches.
        img_flat = img.flatten(start_dim=1)
        label_vec = self.label_embedding(labels)  # (N, embedding_dim)
        trunk_in = torch.cat([img_flat, label_vec], dim=1)
        feats = self.features(trunk_in)  # (N, 256)

        validity = self.sigmoid(self.validity(feats))
        class_output = self.classifier(feats)
        return validity, class_output

# Generator: maps (noise, class label) pairs to images.
class Generator(nn.Module):
    """Label-conditioned MLP generator.

    A noise vector of width ``z_dim`` is concatenated with a learned embedding
    of the requested label. An MLP (LeakyReLU hidden layers) maps this to
    ``img_size * img_size`` values, and a final ``tanh`` squashes them into
    [-1, 1]. The result is reshaped into single-channel square images.
    """

    def __init__(self):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        self.model = nn.Sequential(
            nn.Linear(z_dim + embedding_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_size * img_size),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        Args:
            z: noise batch of shape ``(N, z_dim)``.
            labels: integer class indices of shape ``(N,)``.

        Returns:
            Images of shape ``(N, 1, img_size, img_size)`` with values in [-1, 1].
        """
        cond_input = torch.cat([z, self.label_embedding(labels)], dim=1)
        flat = self.model(cond_input)
        return flat.view(flat.size(0), 1, img_size, img_size)

# Build the two networks. The generator is constructed first; construction
# order determines which random numbers each network's initial weights get.
generator, discriminator = Generator(), Discriminator()

# Losses: BCE for the discriminator's sigmoid-activated validity score and
# cross-entropy for its auxiliary classifier's raw class logits.
adversarial_loss, classification_loss = nn.BCELoss(), nn.CrossEntropyLoss()

# One Adam optimizer per network, both with learning rate 0.0002.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002) for net in (generator, discriminator)
)



In [None]:
# Placeholder for the training data: no DataLoader has been set up yet.
# Assign an iterable yielding (real_imgs, labels) batches before calling
# train(), or pass one in via train(loader=...).
train_loader = None


def train(loader=None, num_epochs=10):
    """Run the simplified AC-GAN training loop.

    Each step does two updates:

    * Discriminator: on the real batch and on an equally sized generated
      batch, using the adversarial BCE loss plus the auxiliary cross-entropy
      loss for both.
    * Generator: so that its images are judged real and are classified as
      the labels they were conditioned on.

    Args:
        loader: iterable of ``(real_imgs, labels)`` batches. Defaults to the
            module-level ``train_loader``.
        num_epochs: number of passes over ``loader`` (default 10).

    Raises:
        RuntimeError: if no loader is given and ``train_loader`` is unset.
    """
    data = train_loader if loader is None else loader
    if data is None:
        raise RuntimeError(
            "No training data: set train_loader or call train(loader=...) "
            "with an iterable of (real_imgs, labels) batches."
        )

    for epoch in range(num_epochs):
        for i, (real_imgs, labels) in enumerate(data):
            # Size targets and noise from the actual batch. The last batch of
            # a DataLoader is usually smaller than `batch_size`, and BCELoss
            # rejects a target whose shape differs from its input.
            n = real_imgs.size(0)
            device = real_imgs.device
            valid = torch.ones(n, 1, device=device)
            fake = torch.zeros(n, 1, device=device)

            # ---------------------
            #  Train the discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Real images: should be judged real and classified as their label.
            real_validity, real_class_output = discriminator(real_imgs, labels)
            real_loss = (
                adversarial_loss(real_validity, valid)
                + classification_loss(real_class_output, labels)
            )

            # Generate a fake batch of the same size with random target labels.
            z = torch.randn(n, z_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (n,), device=device)
            gen_imgs = generator(z, gen_labels)

            # Fake images; detach() stops gradients from flowing into the generator.
            fake_validity, fake_class_output = discriminator(gen_imgs.detach(), gen_labels)
            fake_loss = (
                adversarial_loss(fake_validity, fake)
                + classification_loss(fake_class_output, gen_labels)
            )

            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # ---------------------
            #  Train the generator
            # ---------------------
            optimizer_G.zero_grad()

            # The generator wants its images judged real and classified as
            # gen_labels. This backward pass also fills the discriminator's
            # .grad buffers; optimizer_D.zero_grad() clears them next step.
            gen_validity, gen_class_output = discriminator(gen_imgs, gen_labels)
            g_loss = (
                adversarial_loss(gen_validity, valid)
                + classification_loss(gen_class_output, gen_labels)
            )
            g_loss.backward()
            optimizer_G.step()

            # Periodic progress report.
            if i % 100 == 0:
                print(
                    f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(data)}] "
                    f"[D loss: {d_loss.item()}] [G loss: {g_loss.item()}]"
                )