In [1]:
import math
import pickle
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
import torchvision
from torchvision import transforms, datasets
from torch import optim
from torchnet import meter
from tqdm import tqdm
from PIL import Image

import matplotlib.pyplot as plt

In [2]:
class Config(object):
    data_path = 'data/'  # 数据集路径
    image_size = 96  # 图像尺寸
    batch_size = 32  # 批大小
    epochs = 20  # 训练轮数
    lr1 = 2e-3  # 生成器学习率
    lr2 = 2e-4  # 判别器学习率
    beta1 = 0.5  # Adam优化器参数
    gpu = False  # 是否使用GPU (已废弃)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 自动选择设备
    nz = 100  # 噪声向量维度
    ngf = 64  # 生成器特征图基数
    ndf = 64  # 判别器特征图基数
    save_path = './images'  # 图像保存路径
    generator_path = './generator.pkl'  # 生成器保存路径
    discriminator_path = './discriminator.pkl'  # 判别器保存路径
    gen_img = 'result.png'  # 输出图像文件名
    gen_num = 64  # 生成图像数量
    gen_search_num = 10000  # 候选图像数量
    gen_mean = 0  # 噪声均值
    gen_std = 1  # 噪声标准差

config = Config()  # 实例化配置

In [3]:
data_transform = transforms.Compose([
    transforms.Resize(config.image_size),
    transforms.CenterCrop(config.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

In [4]:
# 2.形成训练集
train_dataset = datasets.ImageFolder(root=os.path.join(config.data_path),
                                     transform=data_transform)

In [5]:
# 3.形成迭代器
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           config.batch_size,
                                           True,
                                           drop_last=True)

print('using {} images for training.'.format(len(train_dataset)))



using 63565 images for training.


In [6]:
class Generator(nn.Module):
    def __init__(self, config):
        super().__init__()

        ngf = config.ngf

        self.model = nn.Sequential(
            nn.ConvTranspose2d(config.nz, ngf * 8, 4, 1, 0),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1),
            nn.Tanh()
        )

    def forward(self, x):
        output = self.model(x)
        return output



In [7]:
class Discriminator(nn.Module):
    def __init__(self, config):
        super().__init__()

        ndf = config.ndf

        self.model = nn.Sequential(
            nn.Conv2d(3, ndf, 5, 3, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 8, 1, 4, 1, 0)
        )

    def forward(self, x):
        output = self.model(x)
        return output.view(-1)


In [None]:
# 1. 确保设备设置正确
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {config.device}")

# 2. 初始化模型并移动到设备
generator = Generator(config).to(config.device)
discriminator = Discriminator(config).to(config.device)

# 3. 设置优化器
optimizer_generator = torch.optim.Adam(generator.parameters(), config.lr1, betas=(config.beta1, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), config.lr2, betas=(config.beta1, 0.999))

# 4. 标签和噪声在设备上创建
true_labels = torch.ones(config.batch_size, device=config.device)
fake_labels = torch.zeros(config.batch_size, device=config.device)
fix_noises = torch.randn(config.batch_size, config.nz, 1, 1, device=config.device)
noises = torch.randn(config.batch_size, config.nz, 1, 1, device=config.device)

# 5. 训练循环
for epoch in range(config.epochs):
    for ii, (img, _) in tqdm(enumerate(train_loader)):
        if epoch % 10 == 0:  # 每10个epoch保存一次
            with torch.no_grad():
                test_img = generator(fix_noises)
        save_path = os.path.join(config.save_path, f"epoch_{epoch}.png")
        torchvision.utils.save_image(
            test_img, 
            save_path,
            normalize=True,
            nrow=8,  # 每行8张图
            value_range=(-1, 1)
        )
        # 确保数据移动到设备 - 添加错误检查
        try:
            real_img = img.to(config.device)
        except RuntimeError as e:
            print(f"移动数据到设备失败: {str(e)}")
            # 尝试释放内存
            torch.cuda.empty_cache()
            real_img = img.to(config.device)
        
        # 创建新的噪声张量（直接在设备上）
        new_noise = torch.randn(config.batch_size, config.nz, 1, 1, device=config.device)
        
        if ii % 2 == 0:
            optimizer_discriminator.zero_grad()
            
            # 更新噪声
            noises.data.copy_(new_noise)
            
            # 前向传播 - 添加设备检查
            # print(f"判别器设备: {next(discriminator.parameters()).device}")
            # print(f"输入数据设备: {real_img.device}")
            
            r_preds = discriminator(real_img)
            fake_img = generator(noises).detach()
            f_preds = discriminator(fake_img)
            
            r_f_diff = (r_preds - f_preds.mean()).clamp(max=1)  # 真实与假图像评分的差异
            f_r_diff = (f_preds - r_preds.mean()).clamp(min=-1)  # 假图像与真实评分的差异
            loss_d_real = (1 - r_f_diff).mean()  # 真实图像损失
            loss_d_fake = (1 + f_r_diff).mean()  # 假图像损失
            loss_d = loss_d_real + loss_d_fake  # 判别器总损失

            loss_d.backward()  # 反向传播计算梯度
            optimizer_discriminator.step()  # 更新判别器参数

        else:
            optimizer_generator.zero_grad()  # 清除生成器梯度
            noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))  # 更新噪声
            fake_img = generator(noises)  # 生成假图像
            f_preds = discriminator(fake_img)  # 判别器对假图像的评分
            r_preds = discriminator(real_img)  # 判别器对真实图像的评分

            # 8. 计算生成器损失
            r_f_diff = r_preds - torch.mean(f_preds)  # 真实与假图像评分的差异
            f_r_diff = f_preds - torch.mean(r_preds)  # 假图像与真实评分的差异
            # 使用ReLU激活的损失函数
            loss_g = torch.mean(F.relu(1 + r_f_diff)) + torch.mean(F.relu(1 - f_r_diff))

            loss_g.backward()  # 反向传播计算梯度
            optimizer_generator.step()  # 更新生成器参数
              # 9. 在最后一个epoch保存模型
    if epoch == config.epochs - 1:
        torch.save(discriminator.state_dict(), config.discriminator_path)  # 保存判别器参数
        torch.save(generator.state_dict(), config.generator_path)  # 保存生成器参数

print('Finished Training')  # 训练完成提示


使用设备: cuda


1986it [02:50, 11.63it/s]
1986it [02:22, 13.92it/s]
1986it [02:23, 13.80it/s]
1986it [02:23, 13.84it/s]
1986it [02:22, 13.96it/s]
1986it [02:22, 13.92it/s]
1986it [02:22, 13.90it/s]
1986it [02:22, 13.95it/s]
1986it [02:23, 13.89it/s]
1986it [02:22, 13.94it/s]
1986it [02:40, 12.34it/s]
1986it [02:37, 12.58it/s]
1986it [02:39, 12.48it/s]
1986it [02:38, 12.53it/s]
1986it [02:37, 12.57it/s]
1986it [02:37, 12.64it/s]
1986it [02:37, 12.59it/s]
1986it [02:37, 12.61it/s]
1986it [02:39, 12.46it/s]
1928it [02:34, 12.59it/s]

In [None]:
generator = Generator(config)  # 重新创建生成器实例
discriminator = Discriminator(config)  # 重新创建判别器实例

# 11. 准备大量噪声用于生成候选图像
noises = torch.randn(config.gen_search_num, config.nz, 1, 1).normal_(config.gen_mean, config.gen_std)
noises = noises.to(config.device)  # 将噪声移动到设备

# 12. 加载训练好的模型参数
generator.load_state_dict(torch.load(config.generator_path, map_location='cpu'))  # 加载生成器参数到CPU
discriminator.load_state_dict(torch.load(config.discriminator_path, map_location='cpu'))  # 加载判别器参数到CPU
generator.to(config.device)  # 将生成器移动到设备
discriminator.to(config.device)  # 将判别器移动到设备

# 13. 生成图像并评分
fake_img = generator(noises)  # 使用生成器生成图像
scores = discriminator(fake_img).detach()  # 使用判别器评分并分离梯度

# 14. 选择评分最高的图像
indexs = scores.topk(config.gen_num)[1]  # 获取最高分的索引
result = []  # 存储结果图像
for ii in indexs:
    result.append(fake_img.data[ii])  # 收集最佳图像

# 15. 保存结果图像
torchvision.utils.save_image(torch.stack(result), config.gen_img, normalize=True, value_range=(-1, 1))