In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Define the generator
class Generator(nn.Module):
    """Transposed-convolution generator that maps latent noise to images.

    The network is a stack of ``ConvTranspose2d`` layers (kernel size 4, no
    bias). Each hidden stage is followed by ``BatchNorm2d`` and an in-place
    ``ReLU``; the final stage is followed by ``Tanh``, so outputs lie in
    [-1, 1]. For an input of shape ``(N, latent_dim, 1, 1)`` the spatial size
    grows 1 -> 4 -> 8 -> 16 -> 32 -> 64, giving ``(N, channels, 64, 64)``.

    Args:
        latent_dim: Number of channels of the latent input tensor.
        img_size: Accepted for interface symmetry but not used; the output
            resolution is fixed by the layer stack.
        channels: Number of channels of the generated image.
    """

    def __init__(self, latent_dim, img_size, channels):
        super().__init__()

        # (in_channels, out_channels, stride, padding) of every hidden stage.
        hidden_stages = (
            (latent_dim, 512, 1, 0),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
        )

        layers = []
        for in_ch, out_ch, stride, padding in hidden_stages:
            layers += [
                nn.ConvTranspose2d(
                    in_ch, out_ch, kernel_size=4, stride=stride,
                    padding=padding, bias=False,
                ),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Output stage: project to the requested channel count and squash
        # the values into [-1, 1].
        layers += [
            nn.ConvTranspose2d(
                64, channels, kernel_size=4, stride=2, padding=1, bias=False,
            ),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, z):
        """Run the latent tensor ``z`` through the layer stack."""
        generated = self.main(z)
        return generated

# Define the discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator that scores images as real or fake.

    The network is a stack of ``Conv2d`` layers (kernel size 4, no bias) with
    ``LeakyReLU(0.2)`` activations. Every hidden stage except the first also
    applies ``BatchNorm2d``. The last convolution reduces to one channel and
    is followed by ``Sigmoid``, so each score lies in [0, 1]. For a
    ``(N, channels, 64, 64)`` input the spatial size shrinks
    64 -> 32 -> 16 -> 8 -> 4 -> 1, giving an output of shape ``(N, 1, 1, 1)``.

    Args:
        img_size: Accepted for interface symmetry but not used by the layers.
        channels: Number of channels of the input image.
    """

    def __init__(self, img_size, channels):
        super().__init__()

        # Input stage: no batch normalisation directly on the image.
        layers = [
            nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Down-sampling stages; each one halves the spatial resolution.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Output stage: collapse to a single value per image.
        layers += [
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, img):
        """Return the discriminator score tensor for the batch ``img``."""
        score = self.main(img)
        return score

# Hyperparameters
latent_dim = 100   # channel count of the noise tensor fed to the generator
img_size = 64      # images are resized and center-cropped to this size
channels = 3       # image channels passed to both networks
batch_size = 128   # batch size requested from the DataLoader
lr = 0.0002        # learning rate shared by both Adam optimizers
num_epochs = 100   # number of passes over the dataloader
beta1 = 0.5        # first Adam beta coefficient for both optimizers

# Load the dataset (CelebA as the example).
# Normalize with mean 0.5 / std 0.5 per channel maps ToTensor's [0, 1] range
# to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# NOTE(review): the captured run of this cell ended with a ProxyError while
# contacting drive.google.com, i.e. the download was attempted and failed.
# Confirm network access or that the files already exist in the layout
# torchvision expects under `root`.
dataset = torchvision.datasets.CelebA(
    root='./data/img_align_celeba',
    split='train',
    download=True,
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Create the generator and the discriminator on the GPU when one is
# available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
generator = Generator(latent_dim, img_size, channels).to(device)
discriminator = Discriminator(img_size, channels).to(device)

# Optimizers and loss function: one Adam optimizer per network, and binary
# cross-entropy on the discriminator's sigmoid output.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)
criterion = nn.BCELoss()

# Training loop: for every batch, update the discriminator on real and
# generated images first, then update the generator.
for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # Move the batch to the training device.
        real_imgs = real_imgs.to(device)
        # Size of this particular batch. It is kept under its own name:
        # assigning it to `batch_size` would overwrite the module-level
        # hyperparameter and leave it at the size of the last batch seen.
        cur_batch_size = real_imgs.size(0)

        # ---- Train the discriminator ----
        optimizer_D.zero_grad()

        # Loss on real images (target label 1).
        label = torch.full((cur_batch_size,), 1., device=device)
        output = discriminator(real_imgs).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()

        # Loss on generated images (target label 0). detach() stops this
        # backward pass from propagating into the generator.
        noise = torch.randn(cur_batch_size, latent_dim, 1, 1, device=device)
        fake_imgs = generator(noise)
        label.fill_(0.)
        output = discriminator(fake_imgs.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake  # combined value, only used for logging
        optimizer_D.step()

        # ---- Train the generator ----
        # The generator is trained to make the freshly updated discriminator
        # output 1 for the same fake batch.
        optimizer_G.zero_grad()
        label.fill_(1.)
        output = discriminator(fake_imgs).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimizer_G.step()

        # Report progress every 50 batches.
        if i % 50 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {errD.item()}] [G loss: {errG.item()}]")

# Generate new images
def generate_images(num_images):
    """Sample ``num_images`` images from the module-level ``generator``.

    Noise of shape ``(num_images, latent_dim, 1, 1)`` is drawn on ``device``
    and passed through ``generator``. The generator's train/eval mode is left
    as it is.

    Args:
        num_images: Number of images to generate.

    Returns:
        A NumPy array holding the generator output rescaled with
        ``(x + 1) / 2``, i.e. from the Tanh range [-1, 1] to [0, 1].
    """
    noise = torch.randn(num_images, latent_dim, 1, 1, device=device)
    # Sampling is inference only: disable autograd so no computation graph
    # is built and held in memory for the generated batch.
    with torch.no_grad():
        generated_imgs = generator(noise)
    generated_imgs = generated_imgs.cpu().numpy()
    generated_imgs = (generated_imgs + 1) / 2  # undo the normalisation
    return generated_imgs

# Visualize the generated images
def visualize_images(generated_imgs):
    """Show a batch of images in a 4-column grid of subplots.

    Up to 16 images are laid out on a 4x4 grid in a 10x10 inch figure. For
    larger batches rows are added (and the figure height grows with them)
    instead of failing on the 17th subplot.

    Args:
        generated_imgs: Array indexed as ``[image, channel, row, column]``;
            each image is transposed to channel-last before display.
            Single-channel images are drawn as grayscale.
    """
    num_images = generated_imgs.shape[0]
    cols = 4
    # Ceiling division, but never fewer than the 4 rows of the 4x4 layout.
    rows = max(4, -(-num_images // cols))

    plt.figure(figsize=(10, 2.5 * rows))
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        img = np.transpose(generated_imgs[i], (1, 2, 0))
        if img.shape[-1] == 1:
            # imshow does not accept (H, W, 1); drop the channel axis.
            plt.imshow(img[..., 0], cmap='gray')
        else:
            plt.imshow(img)
        plt.axis('off')
    plt.show()

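# Tuning suggestions
# 1. Adjust the learning rate and the optimizer parameters (e.g. the Adam betas)
# 2. Increase the depth and capacity of the networks
# 3. Try different activation functions
# 4. Apply data augmentation
# 5. Adjust the number of training epochs and the batch size
# 6. Try different generator and discriminator architectures
# 7. Use techniques such as gradient penalty (WGAN-GP)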

# Usage example: sample 16 images from the generator and display them.
generated_imgs = generate_images(num_images=16)
visualize_images(generated_imgs=generated_imgs)

ProxyError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM (Caused by ProxyError('Unable to connect to proxy', FileNotFoundError(2, 'No such file or directory')))