# 第8课：生成对抗网络 (GAN)

## 学习目标
- 理解 GAN 的基本原理
- 实现简单的 GAN
- 了解常见的 GAN 变体
- 学习 GAN 的训练技巧

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: CUDA when a GPU is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## 1. GAN 原理

GAN (Generative Adversarial Network) 由两个网络组成：

### 生成器 (Generator)
- 输入：随机噪声 z
- 输出：生成的假样本
- 目标：生成以假乱真的样本

### 判别器 (Discriminator)
- 输入：真实样本或生成样本
- 输出：样本为真的概率
- 目标：区分真假样本

### 博弈过程
- 生成器试图欺骗判别器
- 判别器试图识别假样本
- 两者相互竞争，共同进步

In [None]:
# Draw the GAN data flow as a simple box-and-arrow diagram.
fig, ax = plt.subplots(figsize=(12, 6))


def _draw_box(x, y, label, color):
    """Place a rounded, filled text box centred horizontally at (x, y)."""
    ax.text(x, y, label, fontsize=14, ha='center',
            bbox=dict(boxstyle='round', facecolor=color))


def _draw_arrow(start, end):
    """Draw a black arrow pointing from ``start`` to ``end``."""
    ax.annotate('', xy=end, xytext=start,
                arrowprops=dict(arrowstyle='->', color='black'))


# Top row: noise -> generator -> fake image.
_draw_box(0.1, 0.7, 'Noise z', 'lightblue')
_draw_arrow((0.15, 0.7), (0.25, 0.7))

_draw_box(0.35, 0.7, 'Generator', 'lightgreen')
_draw_arrow((0.45, 0.7), (0.5, 0.7))

_draw_box(0.6, 0.7, 'Fake Image', 'lightyellow')

# Bottom row: the real image that is also fed to the discriminator.
_draw_box(0.6, 0.3, 'Real Image', 'lightyellow')

# Both image sources converge on the discriminator.
_draw_arrow((0.7, 0.7), (0.75, 0.5))
_draw_arrow((0.7, 0.3), (0.75, 0.5))

_draw_box(0.85, 0.5, 'Discriminator', 'lightcoral')

# Plain caption (no box) for the discriminator's decision.
ax.text(0.85, 0.2, 'Real/Fake?', fontsize=12, ha='center')

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('off')
ax.set_title('GAN Architecture', fontsize=16)
plt.show()

## 2. 数据准备

In [None]:
# Hyperparameters
batch_size = 128   # samples per mini-batch
image_size = 28    # MNIST images are 28 x 28 pixels
latent_dim = 100   # dimensionality of the noise vector
lr = 0.0002        # learning rate
beta1 = 0.5        # first Adam beta coefficient
num_epochs = 50

# Preprocessing: convert to a tensor, then normalise to the range [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Load the MNIST training split (downloaded to ./data when missing).
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Note on num_workers: if loading hangs on Windows or inside Jupyter, set num_workers to 0.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=2)

print(f"数据集大小: {len(dataset)}")
print(f"批次数: {len(dataloader)}")

## 3. 定义生成器

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a one-channel image.

    Three hidden stages (Linear -> LeakyReLU(0.2) -> BatchNorm1d) of widths
    256, 512 and 1024 are followed by a Linear projection to
    ``image_size * image_size`` values and a Tanh, so outputs lie in [-1, 1].

    Args:
        latent_dim: Size of the input noise vector.
        image_size: Side length of the square output image.
    """

    def __init__(self, latent_dim, image_size):
        super().__init__()
        self.image_size = image_size

        # Assemble the hidden stages in order so the Sequential indices stay stable.
        stages = []
        in_features = latent_dim
        for width in (256, 512, 1024):
            stages += [
                nn.Linear(in_features, width),
                nn.LeakyReLU(0.2, inplace=True),
                nn.BatchNorm1d(width),
            ]
            in_features = width

        # Output head: one value per pixel, squashed into [-1, 1].
        stages += [nn.Linear(in_features, image_size * image_size), nn.Tanh()]
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Return images of shape (batch, 1, image_size, image_size) for noise ``z``."""
        flat = self.model(z)
        side = self.image_size
        return flat.view(flat.size(0), 1, side, side)

# Build the generator, then move it to the selected device.
generator = Generator(latent_dim, image_size)
generator = generator.to(device)
print(generator)

## 4. 定义判别器

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring how likely an image is real.

    The flattened image passes through two hidden stages
    (Linear -> LeakyReLU(0.2) -> Dropout(0.3)) of widths 512 and 256, then a
    single-unit Linear layer with a Sigmoid, giving a probability per sample.

    Args:
        image_size: Side length of the square input image.
    """

    def __init__(self, image_size):
        super().__init__()

        stages = []
        in_features = image_size * image_size
        for width in (512, 256):
            stages.extend((
                nn.Linear(in_features, width),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ))
            in_features = width

        # Output head: a single probability.
        stages.extend((nn.Linear(in_features, 1), nn.Sigmoid()))
        self.model = nn.Sequential(*stages)

    def forward(self, img):
        """Flatten each image in the batch and return its (batch, 1) validity score."""
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)

# Build the discriminator, then move it to the selected device.
discriminator = Discriminator(image_size)
discriminator = discriminator.to(device)
print(discriminator)

## 5. 损失函数和优化器

In [None]:
# Loss: binary cross-entropy on the discriminator's probability output.
criterion = nn.BCELoss()

# Optimizers: one Adam instance per network with the same lr and betas.
optimizer_G = optim.Adam(params=generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Fixed noise batch (64 vectors) used for visualisation.
fixed_noise = torch.randn((64, latent_dim), device=device)

## 6. 训练 GAN

In [None]:
def train_gan(generator, discriminator, dataloader, num_epochs):
    """Alternately train the discriminator and the generator on ``dataloader``.

    Uses the module-level ``optimizer_D``, ``optimizer_G``, ``criterion``,
    ``latent_dim``, ``device`` and ``fixed_noise``. After every epoch the
    losses of the last batch are printed, and every 10th epoch the samples
    generated from ``fixed_noise`` are displayed.

    Args:
        generator: Network mapping noise to images.
        discriminator: Network scoring images as real or fake.
        dataloader: Iterable yielding ``(images, labels)`` batches; labels are ignored.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        Tuple ``(G_losses, D_losses)`` holding one float per training iteration.
    """
    G_losses, D_losses = [], []

    for epoch in range(num_epochs):
        for real_batch, _ in dataloader:
            n_samples = real_batch.size(0)
            real_batch = real_batch.to(device)

            # Targets: 1 for real samples, 0 for generated samples.
            ones = torch.ones(n_samples, 1, device=device)
            zeros = torch.zeros(n_samples, 1, device=device)

            # ---------------------
            # Discriminator step
            # ---------------------
            optimizer_D.zero_grad()

            # Loss on real images.
            loss_on_real = criterion(discriminator(real_batch), ones)

            # Loss on generated images; detach() keeps this backward pass
            # from propagating into the generator.
            noise = torch.randn(n_samples, latent_dim, device=device)
            generated = generator(noise)
            loss_on_fake = criterion(discriminator(generated.detach()), zeros)

            d_loss = loss_on_real + loss_on_fake
            d_loss.backward()
            optimizer_D.step()

            # ---------------------
            # Generator step
            # ---------------------
            optimizer_G.zero_grad()

            # The generator wants the discriminator to label its images as real.
            g_loss = criterion(discriminator(generated), ones)
            g_loss.backward()
            optimizer_G.step()

            # Record the per-iteration losses.
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())

        # Progress report: losses of the final batch of this epoch.
        print(f'Epoch [{epoch+1}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}')

        # Show generated samples every 10 epochs.
        finished_epochs = epoch + 1
        if finished_epochs % 10 == 0:
            show_generated_images(generator, fixed_noise, finished_epochs)

    return G_losses, D_losses

def show_generated_images(generator, noise, epoch):
    """Display generator samples for ``noise`` on an 8 x 8 grid.

    The generator is switched to eval mode for the forward pass and afterwards
    put back into whatever mode it was in before the call, even if the forward
    pass raises.

    Args:
        generator: Network mapping noise to images in the range [-1, 1].
        noise: Batch of latent vectors. At most the first 64 samples are
            shown; with fewer than 64 the remaining grid cells stay blank.
        epoch: Epoch number shown in the figure title.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_images = generator(noise).cpu()
    finally:
        # Restore the caller's mode instead of unconditionally enabling training.
        generator.train(was_training)

    # Undo the normalisation: map [-1, 1] back to [0, 1] for display.
    fake_images = (fake_images + 1) / 2

    n_available = fake_images.size(0)
    fig, axes = plt.subplots(8, 8, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        if i < n_available:
            ax.imshow(fake_images[i, 0], cmap='gray')
        ax.axis('off')
    plt.suptitle(f'Epoch {epoch}')
    plt.tight_layout()
    plt.show()

In [None]:
# Train (a reduced number of epochs is used here for demonstration).
num_epochs = 30
G_losses, D_losses = train_gan(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    num_epochs=num_epochs,
)

In [None]:
# Plot the raw and the smoothed training losses side by side.
plt.figure(figsize=(12, 5))

# Left panel: raw per-iteration losses.
plt.subplot(1, 2, 1)
plt.plot(G_losses, label='Generator', alpha=0.7)
plt.plot(D_losses, label='Discriminator', alpha=0.7)
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('GAN Training Loss')
plt.legend()
plt.grid(True, alpha=0.3)

# Moving-average smoothing. The window is capped at the number of recorded
# iterations: with mode='valid', np.convolve swaps its operands when the kernel
# is longer than the signal, which would plot a meaningless curve for short runs.
window = max(1, min(100, len(G_losses)))
kernel = np.ones(window) / window
G_smooth = np.convolve(G_losses, kernel, mode='valid')
D_smooth = np.convolve(D_losses, kernel, mode='valid')

# Right panel: smoothed losses.
plt.subplot(1, 2, 2)
plt.plot(G_smooth, label='Generator')
plt.plot(D_smooth, label='Discriminator')
plt.xlabel('Iterations')
plt.ylabel('Smoothed Loss')
plt.title('Smoothed Training Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. 生成新图像

In [None]:
# 生成新图像
def generate_images(generator, num_images=100):
    """Sample ``num_images`` images from ``generator`` without tracking gradients.

    Noise is drawn with the module-level ``latent_dim`` on the module-level
    ``device``. The generator runs in eval mode and is afterwards put back into
    whatever mode it was in before the call, even if the forward pass raises.

    Args:
        generator: Network mapping ``(num_images, latent_dim)`` noise to images.
        num_images: Number of images to generate.

    Returns:
        The generator output for the sampled noise, moved to the CPU.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_images, latent_dim, device=device)
            generated = generator(noise).cpu()
    finally:
        # Restore the caller's mode instead of unconditionally enabling training.
        generator.train(was_training)
    return generated

# Sample 64 images and map them from [-1, 1] back to [0, 1] for display.
new_images = (generate_images(generator, 64) + 1) / 2

# Show the samples on an 8 x 8 grid, one image per axis.
fig, axes = plt.subplots(8, 8, figsize=(10, 10))
for ax, image in zip(axes.flat, new_images):
    ax.imshow(image[0], cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Images')
plt.tight_layout()
plt.show()

## 8. 潜在空间插值

In [None]:
def interpolate(generator, z1, z2, steps=10):
    """Generate outputs along the straight line between two latent vectors.

    The generator runs in eval mode without gradient tracking and is afterwards
    put back into whatever mode it was in before the call, even if a forward
    pass raises.

    Args:
        generator: Network called on one ``z.unsqueeze(0)`` batch per step.
        z1: Latent tensor at the start of the path (weight 0).
        z2: Latent tensor at the end of the path (weight 1).
        steps: Number of evenly spaced interpolation weights in [0, 1].

    Returns:
        The ``steps`` generator outputs concatenated along dimension 0.
    """
    was_training = generator.training
    generator.eval()
    try:
        # Create the weights on z1's device so the blend needs no cross-device math.
        alphas = torch.linspace(0, 1, steps, device=z1.device)
        with torch.no_grad():
            frames = [
                generator(((1 - alpha) * z1 + alpha * z2).unsqueeze(0))
                for alpha in alphas
            ]
    finally:
        # Restore the caller's mode instead of unconditionally enabling training.
        generator.train(was_training)
    return torch.cat(frames, dim=0)

# Draw two random endpoints in latent space.
z1, z2 = (torch.randn(latent_dim, device=device) for _ in range(2))

# Interpolate between them and map the outputs from [-1, 1] back to [0, 1].
interpolated_images = (interpolate(generator, z1, z2, steps=10).cpu() + 1) / 2

# Show the ten interpolation steps in a single row.
fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for ax, frame in zip(axes, interpolated_images):
    ax.imshow(frame[0], cmap='gray')
    ax.axis('off')
plt.suptitle('Latent Space Interpolation')
plt.tight_layout()
plt.show()

## 9. DCGAN (Deep Convolutional GAN)

DCGAN 用卷积层（判别器）和转置卷积层（生成器）取代全连接层，能够利用图像的空间结构，在图像生成任务上通常比全连接 GAN 效果更好。

In [None]:
class DCGenerator(nn.Module):
    """Transposed-convolution generator producing 1 x 28 x 28 images.

    Three upsampling blocks (ConvTranspose2d -> BatchNorm2d -> ReLU) grow the
    latent vector from 1 x 1 to 16 x 16; a final ConvTranspose2d plus Tanh
    yields a single-channel 28 x 28 image.

    Args:
        latent_dim: Number of channels of the input noise vector.
        ngf: Base number of generator feature maps.
    """

    def __init__(self, latent_dim, ngf=64):
        super().__init__()

        def up_block(in_channels, out_channels, stride, padding):
            # One upsampling stage: 4x4 transposed conv (no bias), batch norm, ReLU.
            return [
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            ]

        stages = [
            # latent_dim x 1 x 1 -> (ngf*4) x 4 x 4
            *up_block(latent_dim, ngf * 4, 1, 0),
            # -> (ngf*2) x 8 x 8
            *up_block(ngf * 4, ngf * 2, 2, 1),
            # -> ngf x 16 x 16
            *up_block(ngf * 2, ngf, 2, 1),
            # -> 1 x 28 x 28
            nn.ConvTranspose2d(ngf, 1, 4, 2, 3, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, z):
        """Reshape ``z`` to (batch, channels, 1, 1) and run the network on it."""
        batch, channels = z.size(0), z.size(1)
        latent_map = z.view(batch, channels, 1, 1)
        return self.main(latent_map)

class DCDiscriminator(nn.Module):
    """Convolutional discriminator for 1 x 28 x 28 images.

    Three stride-2 convolutions reduce the input 28 -> 14 -> 7 -> 3 while the
    channel count grows ndf -> ndf*2 -> ndf*4 (BatchNorm2d on the last two);
    a final 3 x 3 convolution with Sigmoid outputs one probability per image.

    Args:
        ndf: Base number of discriminator feature maps.
    """

    def __init__(self, ndf=64):
        super().__init__()

        # 1 x 28 x 28 -> ndf x 14 x 14 (no batch norm on the input stage)
        stages = [
            nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # ndf x 14 x 14 -> (ndf*2) x 7 x 7 -> (ndf*4) x 3 x 3
        channels = ndf
        for _ in range(2):
            stages += [
                nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels = channels * 2

        # (ndf*4) x 3 x 3 -> 1 x 1 x 1 probability
        stages += [nn.Conv2d(channels, 1, 3, 1, 0, bias=False), nn.Sigmoid()]
        self.main = nn.Sequential(*stages)

    def forward(self, img):
        """Return the network output for ``img`` reshaped to (-1, 1)."""
        scores = self.main(img)
        return scores.view(-1, 1)

# Print both DCGAN architectures, each under its own heading.
print("DCGAN Generator:", DCGenerator(latent_dim), sep="\n")
print("\nDCGAN Discriminator:", DCDiscriminator(), sep="\n")

## 10. GAN 的常见问题和技巧

### 模式崩溃 (Mode Collapse)
- 生成器只生成有限几种样本
- 解决方案：Wasserstein GAN、特征匹配

### 训练不稳定
- 判别器太强或太弱
- 解决方案：调整学习率、标签平滑

### 训练技巧
1. 判别器使用 LeakyReLU 而不是 ReLU（DCGAN 的生成器隐藏层通常仍使用 ReLU，如上文的 DCGenerator）
2. 使用 BatchNorm
3. 使用 Adam 优化器
4. 判别器多训练几步

In [None]:
# 标签平滑示例
def smooth_labels(real_labels, fake_labels, smoothing=0.1):
    """Pull hard labels toward 0.5 to make training more stable.

    Each label ``y`` becomes ``y * (1 - smoothing) + 0.5 * smoothing``, so with
    the default ``smoothing`` a label of 1 becomes 0.95 and a label of 0
    becomes 0.05.

    Args:
        real_labels: Labels for real samples.
        fake_labels: Labels for generated samples.
        smoothing: Fraction of each label replaced by 0.5.

    Returns:
        Tuple ``(real_labels, fake_labels)`` with the smoothed values; the
        inputs themselves are not modified.
    """
    keep = 1 - smoothing
    shift = 0.5 * smoothing
    return real_labels * keep + shift, fake_labels * keep + shift

# Example: smooth a small batch of hard labels and compare before and after.
real, fake = torch.ones(5, 1), torch.zeros(5, 1)
real_smooth, fake_smooth = smooth_labels(real_labels=real, fake_labels=fake)

print(f"原始真实标签: {real.flatten().numpy()}")
print(f"平滑真实标签: {real_smooth.flatten().numpy()}")
print(f"原始假标签: {fake.flatten().numpy()}")
print(f"平滑假标签: {fake_smooth.flatten().numpy()}")

## 11. 练习题

### 练习1：训练 DCGAN
使用 DCGAN 架构训练生成器

In [None]:
# 在这里编写代码


### 练习2：条件 GAN
实现一个可以指定生成数字类别的条件 GAN

In [None]:
# 在这里编写代码


## 12. 本课小结

### GAN 基本原理

1. **生成器**：将随机噪声转换为图像
2. **判别器**：区分真假图像
3. **对抗训练**：两个网络相互竞争

### 常见 GAN 变体

| 名称 | 特点 |
|------|------|
| DCGAN | 使用卷积层 |
| WGAN | Wasserstein 距离 |
| cGAN | 条件生成 |
| StyleGAN | 高质量图像生成 |
| CycleGAN | 图像风格转换 |

### 训练要点

1. 平衡生成器和判别器的能力
2. 使用合适的网络架构
3. 仔细调整超参数
4. 监控训练过程，观察生成质量