In [1]:
# coding:utf8
import os
# import ipdb
import torch as t
import torchvision as tv
import tqdm
from model import NetG, NetD
from torch.autograd import Variable
# from torchnet.meter import AverageValueMeter

In [2]:
class Config(object):
    """Hyper-parameters, paths and switches for DCGAN training / generation.

    Every attribute is a class attribute; the notebook reads them through the
    single instance ``opt = Config()`` created right after the class.
    """
    data_path = './data/'  # dataset root directory
    num_workers = 1  # number of worker processes used for data loading
    image_size = 96  # image size (images are 3*96*96)
    batch_size = 128
    max_epoch = 31
    lr1 = 2e-4  # generator learning rate
    lr2 = 2e-4  # discriminator learning rate
    beta1 = 0.5  # beta1 parameter of the Adam optimizers
    # Use the GPU only when CUDA is actually available, so the notebook also
    # runs on CPU-only machines instead of failing at the first .cuda() call.
    # Set this to False explicitly to force CPU training.
    gpu = t.cuda.is_available()
    nz = 100  # noise (latent vector) dimension
    ngf = 64  # number of generator feature maps
    ndf = 64  # number of discriminator feature maps

    save_path = './imgs/'  # where generated images are saved

    vis = True  # whether to visualise with visdom
    env = 'GAN'  # visdom env name
    plot_every = 20  # plot to visdom once every 20 batches

    debug_file = './tmp/debuggan'  # author's note: debug mode is entered if this file exists
    d_every = 1  # train the discriminator once every batch
    g_every = 5  # train the generator once every 5 batches
    decay_every = 10  # save checkpoints once every 10 epochs
    netd_path = None  # e.g. 'checkpoints/netd_.pth' -- pre-trained discriminator weights to load
    netg_path = None  # e.g. 'checkpoints/netg_211.pth' -- pre-trained generator weights to load

    # Generation-only (test, no training) settings
    gen_img = 'result.png'
    # author's note: out of 512 generated images, keep the best 64
    gen_num = 64
    gen_search_num = 512
    gen_mean = 0  # mean of the noise
    gen_std = 1  # presumably the noise standard deviation (the name says std; an earlier note said variance) -- confirm against the generation code

opt = Config()

## 数据加载

In [3]:
# Preprocessing pipeline: resize the shorter side to image_size, center-crop to
# a square, convert to a tensor, then map pixel values from [0, 1] to [-1, 1].
# transforms.Resize replaces the deprecated transforms.Scale (same behaviour
# for an int size); Scale emitted a UserWarning and is removed in later
# torchvision releases.
transforms = tv.transforms.Compose([
    tv.transforms.Resize(opt.image_size),
    tv.transforms.CenterCrop(opt.image_size),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# ImageFolder applies the transforms while loading; the class labels it
# returns are ignored by the training loop.
dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
# drop_last=True keeps every batch at exactly batch_size, which the
# fixed-size label/noise tensors used during training rely on.
dataloader = t.utils.data.DataLoader(dataset,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     num_workers=opt.num_workers,
                                     drop_last=True)

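UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  "please use transforms.Resize instead.")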


## 定义网络

In [4]:
# Build the generator first, then the discriminator (same order as before,
# so any random weight initialisation consumes the RNG identically).
netg = NetG(opt)
netd = NetD(opt)


def map_location(storage, loc):
    """Used as ``t.load``'s map_location: keep each stored tensor on the CPU,
    whatever device it was saved from."""
    return storage


# Load pre-trained parameters when a checkpoint path is configured
# (discriminator first, then generator).
for _net, _ckpt in ((netd, opt.netd_path), (netg, opt.netg_path)):
    if _ckpt:
        _net.load_state_dict(t.load(_ckpt, map_location=map_location))
del _net, _ckpt

## 定义优化器和损失

In [5]:
# Loss shared by both players. BCELoss expects netd to output probabilities in
# [0, 1] (presumably NetD ends with a sigmoid -- confirm in model.py).
criterion = t.nn.BCELoss()
# One Adam optimizer per network, with its own learning rate and beta1 from opt.
optimizer_g = t.optim.Adam(
    params=netg.parameters(),
    lr=opt.lr1,
    betas=(opt.beta1, 0.999),
)
optimizer_d = t.optim.Adam(
    params=netd.parameters(),
    lr=opt.lr2,
    betas=(opt.beta1, 0.999),
)

## 标签与噪声
真图片的label为1，假图片的label为0；noises为生成网络的输入。

In [6]:
# Discriminator targets for one batch: 1 for real images, 0 for generated ones.
true_labels, fake_labels = (Variable(t.ones(opt.batch_size)),
                            Variable(t.zeros(opt.batch_size)))
# Generator inputs of shape (batch_size, nz, 1, 1) drawn from a standard normal.
# fix_noises is drawn before noises; the training loop refills noises in place
# for every step.
fix_noises, noises = (Variable(t.randn(opt.batch_size, opt.nz, 1, 1)),
                      Variable(t.randn(opt.batch_size, opt.nz, 1, 1)))

## 加入gpu运算

In [None]:
# Move both networks, the loss module and the pre-built label/noise tensors
# onto the GPU when GPU training is enabled.
if opt.gpu:
    for _module in (netd, netg, criterion):
        _module.cuda()
    del _module
    true_labels = true_labels.cuda()
    fake_labels = fake_labels.cuda()
    fix_noises = fix_noises.cuda()
    noises = noises.cuda()

## 训练

In [None]:
# Main GAN training loop.
#  * Discriminator step every opt.d_every batches: push D(real) -> 1 and
#    D(G(z)) -> 0.
#  * Generator step every opt.g_every batches: push D(G(z)) -> 1.
#  * Every opt.decay_every epochs: checkpoint both networks and re-create the
#    two Adam optimizers.
epochs = range(opt.max_epoch)
# t.save below raises if the target directory is missing, so create it up front.
os.makedirs('./checkpoints', exist_ok=True)
for epoch in epochs:
    # enumerate() hides the loader's length; pass total= so tqdm draws a real bar.
    batches = tqdm.tqdm(enumerate(dataloader), total=len(dataloader),
                        desc='epoch %d' % epoch)
    for ii, (img, _) in batches:
        real_img = Variable(img)
        if opt.gpu:
            real_img = real_img.cuda()

        if ii % opt.d_every == 0:
            # --- Discriminator step ---
            optimizer_d.zero_grad()
            # Real images should be classified as real (label 1).
            output = netd(real_img)
            error_d_real = criterion(output, true_labels)
            error_d_real.backward()

            # Generated images should be classified as fake (label 0).
            # detach() keeps this loss from back-propagating into netg.
            noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
            fake_img = netg(noises).detach()
            output = netd(fake_img)
            error_d_fake = criterion(output, fake_labels)
            error_d_fake.backward()
            # Gradients of the real and fake losses were accumulated above.
            optimizer_d.step()

        if ii % opt.g_every == 0:
            # --- Generator step ---
            # Make netd label fresh fakes as real. This backward pass also fills
            # netd's .grad, but only optimizer_g steps here, and
            # optimizer_d.zero_grad() clears those gradients before the next
            # discriminator step.
            optimizer_g.zero_grad()
            noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
            fake_img = netg(noises)
            output = netd(fake_img)
            error_g = criterion(output, true_labels)
            error_g.backward()
            optimizer_g.step()

    if (epoch + 1) % opt.decay_every == 0:
        # Checkpoint both networks (file names use the 0-based epoch index).
        print('保存模型参数...')
        # TODO: fix_noises is prepared but never used. To save a sample grid,
        # compute fix_fake_imgs = netg(fix_noises) here and pass it to
        # tv.utils.save_image (the earlier commented-out call referenced an
        # undefined fix_fake_imgs).
        t.save(netd.state_dict(), './checkpoints/netd_%s.pth' % epoch)
        t.save(netg.state_dict(), './checkpoints/netg_%s.pth' % epoch)
        # Re-create both optimizers with their initial learning rates. This
        # discards Adam's running moment estimates; it does NOT decay the LR.
        optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
        optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
        print('next')

64it [02:48,  2.63s/it]
64it [02:52,  2.69s/it]
64it [02:48,  2.64s/it]
64it [02:57,  2.77s/it]
64it [02:54,  2.72s/it]
64it [02:26,  2.29s/it]
64it [02:20,  2.19s/it]
64it [02:27,  2.30s/it]
64it [02:07,  2.00s/it]
64it [02:17,  2.14s/it]


保存模型参数...
next


64it [02:17,  2.14s/it]
64it [02:21,  2.21s/it]
64it [02:28,  2.32s/it]
64it [02:20,  2.20s/it]
64it [02:30,  2.34s/it]
64it [02:36,  2.44s/it]
64it [03:25,  3.22s/it]
64it [02:35,  2.43s/it]
64it [02:36,  2.44s/it]
64it [02:07,  1.99s/it]


保存模型参数...
next


64it [02:12,  2.08s/it]
64it [02:09,  2.02s/it]
64it [03:15,  3.06s/it]
64it [03:31,  3.30s/it]
64it [03:39,  3.44s/it]
64it [03:38,  3.42s/it]
64it [03:29,  3.27s/it]
64it [03:28,  3.27s/it]
64it [03:53,  3.64s/it]
64it [04:42,  4.41s/it]


保存模型参数...
next


35it [07:51, 13.48s/it]