In [9]:
import torch
import torch.nn as nn
from numpy.ma.core import zeros_like
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from torchvision import transforms, datasets

# Prefer the first CUDA device when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Sizes shared by the model definitions and the sampling / training code.
z_dim = 100        # channel count of the (N, z_dim, 1, 1) noise fed to the generator
batch_size = 128   # number of noise vectors drawn per step
ngf = 64           # base channel width used when building the conv stacks
ndf = 64           # companion width constant (same value as ngf)
img_channels = 1   # channels of the images the generator emits and the critic reads

In [10]:
class Generator(nn.Module):
    """DCGAN-style generator that upsamples latent noise with transposed convolutions.

    The input is a ``(N, z_dim, 1, 1)`` noise tensor. Five stages grow the spatial
    size 1 -> 4 -> 8 -> 16 -> 32 -> 64 while the channel count shrinks from
    ``ngf * 8`` down to ``img_channels``. The final ``Tanh`` keeps every output
    value inside ``[-1, 1]``.
    """

    @staticmethod
    def _upsample_stage(in_channels, out_channels, stride, padding):
        """Kernel-4 transposed conv (no bias), then batch norm and in-place ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def __init__(self):
        super().__init__()

        # Stride 1 / padding 0 turns the 1x1 latent "pixel" into a 4x4 map.
        self.block1 = self._upsample_stage(z_dim, ngf * 8, 1, 0)
        # Every stride-2 / padding-1 stage doubles the spatial resolution.
        self.block2 = self._upsample_stage(ngf * 8, ngf * 4, 2, 1)
        self.block3 = self._upsample_stage(ngf * 4, ngf * 2, 2, 1)
        self.block4 = self._upsample_stage(ngf * 2, ngf, 2, 1)
        # Output stage: no normalisation, Tanh squashes values into [-1, 1].
        self.block5 = nn.Sequential(
            nn.ConvTranspose2d(ngf, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` to images by applying the five stages in order."""
        out = z
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            out = stage(out)
        return out

In [11]:
class Critic(nn.Module):
    """WGAN critic that scores images with a stack of strided convolutions.

    Four kernel-4 / stride-2 / padding-1 convolutions each halve the spatial size
    (64 -> 32 -> 16 -> 8 -> 4 for a 64x64 input), and a final kernel-4 convolution
    collapses the 4x4 map to a single value, giving an output of shape
    ``(N, 1, 1, 1)``. There is no sigmoid and no normalisation layer, so the score
    is unbounded and each sample is scored independently of the rest of the batch.

    Channel widths are derived from ``ndf`` (the critic-side width constant) rather
    than from the generator's ``ngf``, so resizing one network does not silently
    resize the other.

    NOTE(review): widths shrink from ``ndf * 8`` to ``ndf`` as resolution drops,
    the reverse of the usual DCGAN discriminator layout. Kept as-is; confirm
    this is intended.
    """

    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(img_channels, ndf * 8, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(ndf * 8, ndf * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(ndf * 4, ndf * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(ndf * 2, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.block5 = nn.Sequential(
            # Kernel 4, stride 1, no padding: a 4x4 map becomes one scalar score.
            nn.Conv2d(ndf, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        """Return the raw critic score for each image in ``x``."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return self.block5(x)

In [12]:
def gradient_penalty(critic, real_samples, fake_samples, device=None):
    """Compute the WGAN-GP gradient penalty ``mean((||grad_x critic(x)||_2 - 1)^2)``.

    ``x`` is a random per-sample interpolation between ``real_samples`` and
    ``fake_samples``.

    Args:
        critic: Callable mapping a batch of samples to a batch of scores.
        real_samples: Tensor of real data, batch dimension first.
        fake_samples: Tensor of generated data with the same shape as
            ``real_samples``.
        device: Device on which the interpolation coefficients are created.
            Defaults to ``real_samples.device``, so the function works on CPU and
            on any GPU without the caller having to pass it.

    Returns:
        Scalar tensor holding the penalty. It is differentiable with respect to
        the critic's parameters, but not with respect to either input.

    Raises:
        ValueError: If ``real_samples`` and ``fake_samples`` differ in shape.
    """
    if real_samples.shape != fake_samples.shape:
        raise ValueError(
            f"real_samples {tuple(real_samples.shape)} and fake_samples "
            f"{tuple(fake_samples.shape)} must have the same shape"
        )

    if device is None:
        device = real_samples.device

    batch_size = real_samples.shape[0]

    # One mixing coefficient per sample, broadcast over every non-batch dimension.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device, dtype=real_samples.dtype)

    # Detach both inputs: the penalty regularises the critic only, so it must not
    # push gradients back into whatever produced the fake samples.
    real_data = real_samples.detach()
    fake_data = fake_samples.detach()
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

    scores = critic(interpolates)

    # create_graph=True keeps the gradient itself differentiable, so that the
    # penalty can be backpropagated into the critic's weights.
    gradients = torch.autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample L2 norm of the flattened gradient.
    gradient_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)

    return ((gradient_norm - 1) ** 2).mean()



def d_loss_function(real, fake):
    """Wasserstein critic loss: mean score on fakes minus mean score on reals."""
    mean_fake_score = torch.mean(fake)
    mean_real_score = torch.mean(real)
    return mean_fake_score - mean_real_score

def g_loss_function(fake):
    """Generator loss: the negated mean critic score of the generated samples."""
    mean_score = torch.mean(fake)
    return torch.neg(mean_score)

In [13]:
# Instantiate both networks on the selected device.
critic = Critic().to(device)
generator = Generator().to(device)

# Each optimizer must own the parameters of the network it is meant to update.
# The critic optimizer therefore steps the critic's weights; building it over the
# generator's parameters would leave the critic untrained.
# NOTE(review): the WGAN-GP paper uses Adam with lr=1e-4, betas=(0.0, 0.9); the
# values below are kept as configured — confirm they are intentional.
d_optimizer = Adam(critic.parameters(), lr=0.0002, betas=(0.1, 0.999))
g_optimizer = Adam(generator.parameters(), lr=0.0002, betas=(0.1, 0.999))

In [18]:
# Shape smoke test: push one batch of Gaussian noise through the generator, feed
# the generated batch to the critic, and print the shapes of both outputs.
test = torch.randn(batch_size, z_dim,1,1).to(device)
middle = generator(test)  # generator output for the noise batch
out = critic(middle)  # critic scores for the generated batch
print(f"out shape: {out.shape},middle shape: {middle.shape}")

out shape: torch.Size([128, 1, 1, 1]),middle shape: torch.Size([128, 1, 64, 64])


In [15]:
def train(num_epochs, train_loader, n_critic=50):
    """Train the module-level ``generator`` and ``critic`` with the WGAN-GP objective.

    The critic is updated on every batch, using the Wasserstein loss plus a
    gradient penalty weighted by 10. The generator is updated once every
    ``n_critic`` batches. After the last epoch the per-epoch mean losses are
    plotted.

    Uses the module-level ``generator``, ``critic``, ``d_optimizer``,
    ``g_optimizer``, ``device`` and ``z_dim``.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable yielding ``(images, _)`` pairs; the second element
            is ignored. Assumes the images match the generator's output shape —
            TODO confirm against the dataset transform.
        n_critic: Number of critic updates per generator update. Defaults to 50,
            the cadence previously hard-coded here (the WGAN-GP paper uses 5).

    Raises:
        ValueError: If ``n_critic`` is smaller than 1.
    """
    if n_critic < 1:
        raise ValueError(f"n_critic must be >= 1, got {n_critic}")

    generator.train()
    critic.train()
    g_losses = []
    d_losses = []

    for epoch in range(num_epochs):
        g_running_loss = 0.0
        d_running_loss = 0.0
        g_steps = 0  # generator updates performed this epoch
        d_steps = 0  # critic updates performed this epoch

        for i, (real_images, _) in enumerate(train_loader):
            real_images = real_images.to(device)
            # Size the noise batch from the data: the final batch of an epoch can
            # be smaller than the configured batch size.
            current_batch = real_images.size(0)

            # ---- Critic update ----
            d_optimizer.zero_grad()
            real_outputs = critic(real_images)

            z = torch.randn(current_batch, z_dim, 1, 1).to(device)
            fake_images = generator(z)
            # The critic step must not backpropagate into the generator.
            fake_detached = fake_images.detach()
            fake_outputs = critic(fake_detached)

            d_loss = d_loss_function(real_outputs, fake_outputs) + 10 * gradient_penalty(
                critic, real_images, fake_detached, device=real_images.device
            )

            # The generator graph is untouched by this backward pass, so there is
            # nothing to retain for the generator step below.
            d_loss.backward()
            d_optimizer.step()

            d_running_loss += d_loss.item()
            d_steps += 1

            # ---- Generator update (once every n_critic critic updates) ----
            if (i + 1) % n_critic == 0:
                g_optimizer.zero_grad()
                # Re-score the same fakes with the freshly updated critic.
                g_outputs = critic(fake_images)
                g_loss = g_loss_function(g_outputs)
                g_loss.backward()
                g_optimizer.step()

                g_running_loss += g_loss.item()
                g_steps += 1

            if (i + 1) % 100 == 0:
                print(
                    f"Epoch : [{epoch + 1}/{num_epochs}]\tIter : [{i + 1}/{len(train_loader)}]\tGenerator Loss: {g_running_loss / max(g_steps, 1):.3f}\tDiscriminator Loss: {d_running_loss / max(d_steps, 1):.3f} ",
                    '\n')

        # Average each loss over the number of updates that actually produced it.
        g_losses.append(g_running_loss / max(g_steps, 1))
        d_losses.append(d_running_loss / max(d_steps, 1))

    _plot_loss_curves(g_losses, d_losses)


def _plot_loss_curves(g_losses, d_losses):
    """Plot per-epoch generator and critic losses in two side-by-side panels."""
    epochs = list(range(1, len(g_losses) + 1))
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    # Generator loss panel
    axes[0].plot(epochs, g_losses, marker='o', color='blue', label='Generator Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('G Loss')
    axes[0].set_title('Generator Loss Curve')
    axes[0].legend()

    # Critic (discriminator) loss panel
    axes[1].plot(epochs, d_losses, marker='s', color='red', label='Discriminator Loss')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('D Loss')
    axes[1].set_title('Discriminator Loss Curve')
    axes[1].legend()

    plt.tight_layout()
    plt.show()


In [16]:
from torchvision.utils import make_grid
def show_generated_images(generator, z_dim, device, num_images=64):
    """Sample ``num_images`` images from ``generator`` and display them as one grid.

    Args:
        generator: Module mapping ``(N, z_dim, 1, 1)`` noise to an image batch.
        z_dim: Channel count of the noise tensor.
        device: Device on which the noise is created.
        num_images: Number of images to sample.
    """
    latent = torch.randn(num_images, z_dim, 1, 1, device=device)

    # Sampling only: no autograd graph is needed.
    with torch.no_grad():
        samples = generator(latent)
    samples = samples.detach().cpu()

    # Tile the batch into a single image, rescaling values for display.
    grid = make_grid(samples, padding=2, normalize=True)

    # matplotlib expects (H, W, C); the grid is (C, H, W).
    grid_hwc = grid.numpy().transpose((1, 2, 0))

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(grid_hwc)
    ax.axis("off")
    plt.show()