## 使用 GAN 生成动漫头像
#### 参考：
* 《深度学习框架 PyTorch 入门与实践》
* [PyTorch nn document](https://pytorch.org/docs/stable/nn.html)
* [pytorch中ConvTranspose2d的计算公式](https://zhuanlan.zhihu.com/p/39240159)

In [1]:
import os
import tqdm
import torch as t
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable

#### 生成器
* 采用"上卷积"（转置卷积 `ConvTranspose2d`），根据噪声输出一张 96\*96\*3 的图片
* `torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)`
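* 转置卷积中尺寸变化的公式（dilation=1、output_padding=0 时）：`H_{out} = (H_{in} - 1) * stride - 2 * padding + kernel_size`
  * 当 kernel size、stride、padding 分别为 4、2、1 时，输出尺寸刚好变成输入的两倍：`(H_{in} - 1) * 2 - 2 + 4 = 2 * H_{in}`
  * 最后一层取 5、3、1，把 32 放大到 96：`(32 - 1) * 3 - 2 + 5 = 96`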

In [2]:
class NetG(nn.Module):
    """
    DCGAN-style generator.

    Maps a noise batch of shape (N, opt.nz, 1, 1) to images of shape
    (N, 3, 96, 96) whose pixel values lie in [-1, 1] (Tanh output).

    ``opt`` must provide ``nz`` (number of noise channels) and ``ngf``
    (base number of generator feature maps).
    """
    def __init__(self, opt):
        super(NetG, self).__init__()
        base = opt.ngf
        # 1 x 1 noise -> (ngf*8) x 4 x 4   (kernel 4, stride 1, padding 0)
        stages = [
            nn.ConvTranspose2d(opt.nz, base * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base * 8),
            nn.ReLU(True),
        ]
        # Every (kernel 4, stride 2, padding 1) stage doubles H/W and halves
        # the channel count: 4 -> 8 -> 16 -> 32 pixels,
        # ngf*8 -> ngf*4 -> ngf*2 -> ngf feature maps.
        for src, dst in ((8, 4), (4, 2), (2, 1)):
            stages.extend([
                nn.ConvTranspose2d(base * src, base * dst, 4, 2, 1, bias=False),
                nn.BatchNorm2d(base * dst),
                nn.ReLU(True),
            ])
        # (ngf) x 32 x 32 -> 3 x 96 x 96 (kernel 5, stride 3, padding 1):
        # (32 - 1) * 3 - 2 + 5 = 96. Tanh squashes pixels into [-1, 1].
        stages.extend([nn.ConvTranspose2d(base, 3, 5, 3, 1, bias=False), nn.Tanh()])
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        generated = self.main(input)
        return generated

#### 判别器
* 采用"下卷积"（普通卷积），根据输入的 96\*96\*3 的图片，输出该图片属于正样本（真实图片）的概率

In [3]:
class NetD(nn.Module):
    """
    DCGAN-style discriminator.

    Takes an image batch of shape (N, 3, 96, 96) and returns a 1-D tensor of
    N scores in (0, 1) (Sigmoid output), one per image.

    ``opt`` must provide ``ndf`` (base number of discriminator feature maps).
    """
    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        # 3 x 96 x 96 -> (ndf) x 32 x 32; no BatchNorm on the first layer
        layers = [
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each (kernel 4, stride 2, padding 1) stage halves H/W and doubles
        # the channels: 32 -> 16 -> 8 -> 4 pixels. The third tuple element
        # reproduces the original layout, where only the activation of the
        # first of these stages is not in-place.
        for in_mult, out_mult, in_place in ((1, 2, False), (2, 4, True), (4, 8, True)):
            layers.append(nn.Conv2d(ndf * in_mult, ndf * out_mult, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(ndf * out_mult))
            layers.append(nn.LeakyReLU(0.2, inplace=in_place))
        # (ndf*8) x 4 x 4 -> 1 x 1 x 1, squashed into (0, 1)
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        scores = self.main(input)
        # (N, 1, 1, 1) -> (N,)
        return scores.view(-1)

#### visdom 可视化封装

In [4]:
import visdom
import time
import torchvision as tv
import numpy as np


class Visualizer:
    """
    Thin wrapper around basic visdom operations.

    Any visdom method not defined here is still reachable, either through
    ``self.vis.function`` or directly on the wrapper (``__getattr__``
    forwards unknown attributes to the underlying ``visdom.Visdom``).
    """

    def __init__(self, env='default', **kwargs):
        import visdom
        self.vis = visdom.Visdom(env=env, **kwargs)

        # Next x coordinate per line-plot window, e.g. {'loss': 23} means the
        # next 'loss' point is drawn at x = 23.
        self.index = {}
        # Accumulated HTML shown in the text window by ``log``.
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        """
        Reconnect to visdom with a new configuration and return ``self``.
        """
        # Local import, like __init__, so the class does not depend on a
        # module-level ``visdom`` name being present.
        import visdom
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        """
        Plot several scalars at once.

        @params d: dict name -> value, e.g. {'loss': 0.11}
        """
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        """
        Show several images at once; ``d`` maps window name -> image tensor.
        """
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y):
        """
        Append one point to the line plot in window ``name``,
        e.g. ``self.plot('loss', 1.00)``.
        """
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=(name),
                      opts=dict(title=name),
                      # the first point creates the window, later ones append
                      update=None if x == 0 else 'append'
                      )
        self.index[name] = x + 1

    def img(self, name, img_):
        """
        Show a single image tensor in window ``name``,
        e.g. ``self.img('input_img', t.Tensor(64, 64))``.
        """
        # A 2-D (H, W) tensor gets a leading channel axis -> (1, H, W).
        if len(img_.size()) < 3:
            img_ = img_.cpu().unsqueeze(0)
        self.vis.image(img_.cpu(), win=(name), opts=dict(title=name) )

    def img_grid_many(self, d):
        """
        Call ``img_grid`` for every (window name, batch) pair in ``d``.
        """
        for k, v in d.items():
            self.img_grid(k, v)

    def img_grid(self, name, input_3d):
        """
        Tile a batch of single-channel images into one grid image.

        An input of shape (N, H, W), e.g. (36, 64, 64), is treated as N
        grayscale images of H x W and laid out with torchvision's
        ``make_grid`` default layout (8 images per row). Values are clamped
        to [0, 1]. A 4-D input keeps the previous behaviour: the channels of
        its first sample are tiled.
        """
        batch = input_3d.cpu()
        # Only drop a leading axis for 4-D input; indexing [0] on the
        # documented 3-D input would throw away all but one image.
        if batch.dim() == 4:
            batch = batch[0]
        self.img(name, tv.utils.make_grid(
            batch.unsqueeze(1).clamp(max=1, min=0)))

    def log(self, info, win='log_text'):
        """
        Append a timestamped line to the text window ``win``,
        e.g. ``self.log({'loss': 1, 'lr': 0.0001})``.
        """
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info))
        self.vis.text(self.log_text, win=win)

    def __getattr__(self, name):
        """
        Forward attributes not found on the wrapper to ``self.vis``.
        """
        # __getattr__ only runs when normal lookup fails. If ``vis`` itself is
        # missing (instance created via __new__, or during unpickling/copy
        # before __init__ ran), ``self.vis`` would re-enter this method
        # forever; fail with a normal AttributeError instead.
        if name == 'vis':
            raise AttributeError(name)
        return getattr(self.vis, name)

#### 参数、数据集加载器

In [5]:
import os
import tqdm
import torch as t
import torchvision as tv
from torch.autograd import Variable
from torchnet.meter import AverageValueMeter

In [6]:
class Config(object):
    """
    Hyper-parameters and file paths for training and sampling.

    Every setting is a plain class attribute, so an instance can override
    any of them with ``setattr``.
    """
    # --- data ---------------------------------------------------------------
    data_path = '/home/centos/leon/gan_data/'  # root folder of the image dataset
    num_workers = 1  # number of worker processes used to load data
    image_size = 96  # side length of the square training images
    batch_size = 128

    # --- optimisation -------------------------------------------------------
    max_epoch = 200
    lr1 = 2e-4  # generator learning rate
    lr2 = 2e-4  # discriminator learning rate
    beta1 = 0.5  # beta1 parameter of the Adam optimizer
    use_gpu = False  # whether to use the GPU

    # --- model size ---------------------------------------------------------
    nz = 100  # dimensionality of the input noise
    ngf = 64  # base number of generator feature maps
    ndf = 64  # base number of discriminator feature maps

    # --- outputs ------------------------------------------------------------
    save_path = '/home/centos/leon/machine_learning_jupyter/neural_network/gan_demo/imgs/'  # where generated images are saved

    # --- visualisation ------------------------------------------------------
    vis = True  # whether to visualise with visdom
    env = 'GAN'  # visdom environment name
    plot_every = 20  # plot to visdom every 20 batches

    # --- training schedule / checkpoints ------------------------------------
    # Intended: switch to debug mode when this file exists.
    # NOTE(review): nothing in the training code shown here reads it.
    debug_file = '/tmp/debuggan'
    d_every = 1  # train the discriminator every 1 batch
    g_every = 5  # train the generator every 5 batches
    decay_every = 10  # save the models every 10 epochs
    netd_path = '/home/centos/leon/machine_learning_jupyter/neural_network/gan_demo/checkpoints/netd.pth'  # pre-trained discriminator
    netg_path = '/home/centos/leon/machine_learning_jupyter/neural_network/gan_demo/checkpoints/netg.pth'  # pre-trained generator

    # --- sampling (test time) -----------------------------------------------
    gen_img = 'result.png'
    # keep the best 64 out of 512 generated images
    gen_num = 64
    gen_search_num = 512
    gen_mean = 0  # mean of the sampling noise
    gen_std = 1  # standard deviation of the sampling noise

opt = Config()
# Data loading: resize the shorter side to image_size, center-crop to a
# square, convert to a tensor and map pixel values from [0, 1] to [-1, 1]
# (the range of the generator's Tanh output).
# ``transforms.Scale`` is deprecated (and removed in newer torchvision);
# ``Resize`` with an int size is its drop-in replacement.
transforms = tv.transforms.Compose([
    tv.transforms.Resize(opt.image_size),
    tv.transforms.CenterCrop(opt.image_size),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
# drop_last=True: every batch has exactly opt.batch_size images.
dataloader = t.utils.data.DataLoader(dataset=dataset,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     num_workers=opt.num_workers,
                                     drop_last=True
                                    )

#### 网络训练

In [None]:
def train(**kwargs):
    """
    Train the GAN on the images found under ``opt.data_path``.

    Keyword arguments override attributes of the global ``opt`` config
    before training starts, e.g. ``train(use_gpu=True, max_epoch=50)``.

    Within an epoch, batch ``ii`` updates the discriminator when
    ``ii % opt.d_every == 0`` and the generator when ``ii % opt.g_every == 0``.
    Every ``opt.decay_every`` epochs, images generated from a fixed noise
    batch and both networks' weights are written to disk.
    """
    for key, value in kwargs.items():
        setattr(opt, key, value)
    if opt.vis:
        vis = Visualizer(opt.env)

    # Resize the shorter side, center-crop to a square, and map pixels to
    # [-1, 1] to match the generator's Tanh output. ``transforms.Scale`` is
    # deprecated (removed in newer torchvision); ``Resize`` with an int size
    # is the drop-in replacement.
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    # drop_last=True keeps every batch at exactly opt.batch_size images, so
    # the fixed-length label tensors below always match the D output.
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

    # Networks, warm-started from checkpoints when the files exist. The
    # map_location hook keeps the loaded tensors on the CPU.
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if os.path.exists(opt.netd_path):
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if os.path.exists(opt.netg_path):
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # Optimizers and loss
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # Real images are labelled 1, fake images 0. ``fix_noises`` never
    # changes, so samples generated from it are comparable over time;
    # ``noises`` is refilled with fresh samples before each generator pass.
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    for epoch in range(opt.max_epoch):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
            real_img = Variable(img)
            if opt.use_gpu:
                real_img = real_img.cuda()

            if ii % opt.d_every == 0:
                # Train the discriminator
                optimizer_d.zero_grad()
                # push real images towards label 1
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # push fake images towards label 0; detach() keeps this
                # backward pass from reaching the generator
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real
                # .item() instead of ``.data[0]``: indexing a 0-dim tensor is
                # an error in PyTorch >= 0.5
                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # Train the generator: make D score freshly generated fakes as real
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == 0:
                # Visualisation only; no autograd graph is needed.
                with t.no_grad():
                    fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if epoch % opt.decay_every == 0:
            # Save models and sample images. The samples are regenerated here
            # because ``fix_fake_imgs`` was previously only assigned in the
            # visualisation branch, so this raised NameError with opt.vis off.
            with t.no_grad():
                fix_fake_imgs = netg(fix_noises)
            # Map the Tanh range [-1, 1] to [0, 1] explicitly. This equals the
            # former normalize=True, range=(-1, 1) call without relying on the
            # ``range`` keyword (renamed ``value_range`` in newer torchvision).
            tv.utils.save_image(fix_fake_imgs.data[:64] * 0.5 + 0.5,
                                os.path.join(opt.save_path, '%s.png' % epoch))
            t.save(netd.state_dict(), opt.netd_path.replace('netd.pth', 'netd_{}.pth').format(epoch))
            t.save(netg.state_dict(), opt.netg_path.replace('netg.pth', 'netg_{}.pth').format(epoch))
            errord_meter.reset()
            errorg_meter.reset()
            # Re-creating the optimizers discards Adam's running moment
            # estimates; the learning rates stay opt.lr1 / opt.lr2.
            optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))


# Run training with the current configuration.
train()


0it [00:00, ?it/s][A
1it [00:22, 22.05s/it][A
2it [00:34, 19.15s/it][A
3it [00:46, 17.14s/it][A
4it [00:58, 15.58s/it][A