In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from resnet18_32x32 import ResNet18_32x32

from torch.utils.data import TensorDataset, DataLoader
import pickle

device = torch.device("cuda" if torch.cuda.is_available() else False)

  warn(


In [2]:
# 加载保存的label_list
with open('../logs/res/cifar/label_tensor.pickle', 'rb') as f:
    label_list = pickle.load(f)

# 加载保存的probs_list
with open('../logs/res/cifar/probs_tensor.pickle', 'rb') as f:
    probs_list = pickle.load(f)

In [3]:
#参数加载
latent_dim = 10
lr = 0.0002
batch_size = 1
num_epochs = 200
num_classes = 10

class Generator(nn.Module):
    

    
    def __init__(self, latent_dim, num_classes):
        super(Generator, self).__init__()
        self.label_embed = nn.Embedding(num_classes, latent_dim)
        self.generator = nn.Sequential(
            nn.Linear(latent_dim * 2 , 128),  # 将噪音维度加入生成器输入维度
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, labels, noise):
        gen_input = self.label_embed(labels)
        gen_input_with_noise = torch.cat((gen_input, noise), -1)  # 将噪音和标签嵌入向量连接起来
        class_probs = self.generator(gen_input_with_noise)
        return class_probs



class Discriminator(nn.Module):
    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        self.label_embed = nn.Embedding(num_classes, num_classes)
        self.discriminator = nn.Sequential(
            nn.Linear(num_classes * 2 , 512),  # 输入维度为类别数乘以2
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        

    def forward(self, class_probs, labels):
        label_embed = self.label_embed(labels)
        input_tensor = torch.cat((class_probs, label_embed), dim=1)
        validity = self.discriminator(input_tensor)
        return validity

In [4]:
# 创建生成器和判别器实例
generator = Generator(latent_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)

#generator.load_state_dict(torch.load('./lenet/generator200+100.pth'))

#discriminator.load_state_dict(torch.load('./lenet/discriminator200+100.pth'))


# 定义损失函数和优化器
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)

optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)


In [None]:
d_losses = []  # 存储判别器损失值
g_losses = []  # 存储生成器损失值
# 训练生成对抗网络

# 将label_list和probs_list转换为Tensor对象
# 将label_list和probs_list转换为整数类型的Tensor对象
# 将label_list和probs_list转换为Tensor对象
label_tensor = torch.cat(label_list, dim=0)
probs_tensor = torch.cat(probs_list, dim=0)

dataset = TensorDataset(label_tensor, probs_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


for epoch in range(num_epochs):
    for i, (predicted, probabilities) in enumerate(dataloader):
        # 将数据移动到GPU
        predicted = predicted.to(device)
        probabilities = probabilities.to(device)

        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        noise = torch.randn(batch_size, latent_dim).to(device)
        gen_class_probs = generator(gen_labels, noise)
        
        
        # 训练判别器
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        optimizer_D.zero_grad()
        # 判别器判断真实类别概率分布
        real_validity = discriminator(probabilities, predicted)
        real_loss = adversarial_loss(real_validity, real_labels)

        # 判别器判断生成的类别概率分布
        fake_validity = discriminator(gen_class_probs.detach(), gen_labels)
        fake_loss = adversarial_loss(fake_validity, fake_labels)

        d_loss = (real_loss + fake_loss) /2 
        d_loss.backward()
        optimizer_D.step()
        
        
        # 训练生成器
        optimizer_G.zero_grad()

        # 生成器生成类别概率分布，并判别器判断生成的类别概率分布
        gen_validity = discriminator(gen_class_probs, gen_labels)
        g_loss = adversarial_loss(gen_validity, real_labels)

        g_loss.backward()
        optimizer_G.step()

    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())
    print(f"[Epoch {epoch + 1}/{num_epochs}] Label: {gen_labels}")
    print(f"[Epoch {epoch + 1}/{num_epochs}] Generated Probs: {gen_class_probs.detach().cpu().numpy()}")
    print(f"[Epoch {epoch + 1}/{num_epochs}] D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
    if (epoch + 1) % 10 == 0:
        # 保存生成器和判别器的状态字典
        torch.save(generator.state_dict(), f'./res/cifar/1/generator{epoch + 1}.pth')
        torch.save(discriminator.state_dict(), f'./res/cifar/1/discriminator{epoch + 1}.pth')

[Epoch 1/200] Label: tensor([0], device='cuda:0')
[Epoch 1/200] Generated Probs: [[9.9995649e-01 2.8665406e-09 6.6630737e-06 2.4010720e-05 5.0542135e-06
  7.5355624e-06 4.9646875e-08 1.7587404e-07 1.7482103e-07 2.3944732e-08]]
[Epoch 1/200] D Loss: 0.7582, G Loss: 0.5329
[Epoch 2/200] Label: tensor([1], device='cuda:0')
[Epoch 2/200] Generated Probs: [[9.0327444e-14 9.9998903e-01 2.4461296e-11 1.3387374e-07 3.7350335e-12
  5.5724749e-14 7.2368822e-09 1.0376064e-05 4.4830364e-10 4.3837483e-07]]
[Epoch 2/200] D Loss: 0.3789, G Loss: 0.6399
[Epoch 3/200] Label: tensor([8], device='cuda:0')
[Epoch 3/200] Generated Probs: [[2.4568748e-12 1.0204633e-08 1.9597163e-14 4.6900666e-05 8.6081450e-08
  1.0628789e-03 1.6555949e-05 8.6372147e-06 9.9886489e-01 1.4501108e-08]]
[Epoch 3/200] D Loss: 0.6963, G Loss: 0.7369
[Epoch 4/200] Label: tensor([0], device='cuda:0')
[Epoch 4/200] Generated Probs: [[9.9795258e-01 5.3288179e-05 4.5512381e-04 5.0296984e-04 7.7111495e-04
  8.4953346e-05 1.0441585e-05 6

In [None]:
# 可视化损失值
plt.plot(range(1, num_epochs + 1), d_losses, label='Discriminator Loss')

plt.plot(range(1, num_epochs + 1), g_losses, label='Generator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()