# 利用CycleGAN进行风格迁移

### 功能简述

可以实现A类图片和B类图片之间相互的风格迁移。

比如A类图片为马的图片，B类图片为斑马的图片，可以实现将马转化成斑马，也可以将斑马转化成马。

整体架构使用的是CycleGAN，即同时训练将A转化为B风格的GAN和将B转化为A风格的GAN。

训练耗费的时间比较长，但是一旦有了训练好的模型，生成图片的速度比较快。

### 目录结构：

    /--+-- snapshot/                  ...存放快照
       |
       +-- models/                    ...存放训练出来的模型
       |
       +-- data/                      ...存放数据
             |
             +-- vangogh2photo/       ...某个数据集
                       |
                       +-- trainA/    ...类别为A的图片
                       |
                       +-- trainB/    ...类别为B的图片
                 

## 数据集相关

在我们的实现中，使用了 Python 中“类”的概念。可以把类看作一种工具：每个类负责完成一种任务，其中包含完成该任务所需的数据，以及处理这些数据的方法。


In [None]:
from PIL import Image
import numpy as np


def load_image(fn, image_size):
    """
    Load one image file as a square float array.

    The image is converted to RGB, center-cropped to the largest square that
    fits inside it, resized to image_size x image_size with bilinear
    interpolation, and rescaled from [0, 255] to [-1, 1].

    fn: path of the image file
    image_size: side length, in pixels, of the returned square image
    """
    picture = Image.open(fn).convert('RGB')
    width, height = picture.size

    # Center crop: trim the longer dimension symmetrically around the middle.
    if width >= height:
        box = ((width - height) // 2, 0, (width + height) // 2, height)
    else:
        box = (0, (height - width) // 2, width, (width + height) // 2)
    square = picture.crop(box)
    square = square.resize((image_size, image_size), Image.BILINEAR)

    # Map 8-bit channel values from [0, 255] onto [-1, 1].
    return np.array(square) / 255 * 2 - 1

import glob
import random

class DataSet(object):
    """
    Serve the image files matched by a glob pattern in shuffled batches.

    Attributes:
        data_path:  glob pattern used to find the image files
        image_size: side length passed to load_image for every image
        data_list:  shuffled list of matching file paths for the current pass
        ptr:        index in data_list of the next file to hand out
        epoch:      number of completed passes over data_list
    """
    def __init__(self, data_path, image_size = 256):
        self.data_path = data_path
        self.image_size = image_size
        self.epoch = 0
        self.__init_list()

    def __init_list(self):
        """Re-run the glob, shuffle the resulting paths and rewind the pointer."""
        self.data_list = glob.glob(self.data_path)
        random.shuffle(self.data_list)
        self.ptr = 0

    def get_batch(self, batchsize):
        """
        Take the next batchsize images from the queue.

        When the end of the file list is reached, the list is re-read and
        reshuffled, the epoch counter is incremented, and the batch is topped
        up from the start of the new list. This repeats as often as needed,
        so a batch larger than the dataset is still returned in full.

        Returns a tuple (epoch, batch) where batch is a list of the arrays
        produced by load_image.

        Raises ValueError if images are requested but no file matches
        data_path.
        """
        paths = []
        while len(paths) < batchsize:
            if not self.data_list:
                # Without this guard an empty dataset would yield empty
                # batches forever while the epoch counter kept growing.
                raise ValueError(
                    "no files match pattern {!r}".format(self.data_path))

            chunk = self.data_list[self.ptr:self.ptr + batchsize - len(paths)]
            paths.extend(chunk)
            self.ptr += len(chunk)

            # Reaching the end of the list completes one pass over the data.
            if self.ptr >= len(self.data_list):
                self.__init_list()
                self.epoch += 1

        batch = [load_image(path, self.image_size) for path in paths]
        return self.epoch, batch

    def get_pics(self, num):
        """
        Load num randomly sampled images as one array, for snapshots.

        Does not advance the batch queue. random.sample raises ValueError
        when num exceeds the number of available files.
        """
        return np.array([load_image(x, self.image_size) for x in random.sample(self.data_list, num)])

def arr2image(X):
    """
    Convert an array with values in [-1, 1] back to an 8-bit image.

    Values are rescaled to [0, 255], clipped to that range, cast to uint8
    and wrapped with Image.fromarray.
    """
    scaled = (X + 1) / 2 * 255
    pixels = scaled.clip(0, 255).astype('uint8')
    return Image.fromarray(pixels)

def generate(img, fn):
    """
    Run a single image img through the generator function fn.

    The image is wrapped into a batch of one, fn is called with a
    one-element input list, and the first item of its first output is
    converted to an image with arr2image.
    """
    batch = np.array([img])
    outputs = fn([batch])
    first_output = outputs[0]
    return arr2image(np.array(first_output[0]))


## 构建网络

In [None]:
#导入必要的库
import keras.backend as K

from keras.models import Sequential, Model
from keras.layers import Conv2D, BatchNormalization, Input, Dropout, Add
from keras.layers import Conv2DTranspose, Reshape, Activation, Cropping2D, Flatten
from keras.layers import Concatenate
from keras.optimizers import RMSprop, SGD, Adam

from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu,tanh
from keras.initializers import RandomNormal

In [None]:
# Shared kernel initializer (RandomNormal with arguments 0 and 0.02), used by
# the conv2d helper and by the transposed convolutions in NET_G
conv_init = RandomNormal(0, 0.02)

def conv2d(f, *a, **k):
    """
    Build a Conv2D layer with f filters and the shared conv_init kernel
    initializer. All other arguments are forwarded to Conv2D unchanged.
    """
    layer = Conv2D(f, *a, kernel_initializer=conv_init, **k)
    return layer
def batchnorm():
    """
    Build a BatchNormalization layer over the last axis with the settings
    used throughout this file.
    """
    settings = dict(momentum=0.9, epsilon=1.01e-5, axis=-1)
    return BatchNormalization(**settings)

def res_block(x, dim):
    """
    Residual block: two size-3 convolutions with an identity shortcut.

    [x] --> [conv] --> [batchnorm] --> [relu] --> [conv] --> [batchnorm] --> [+] --> [relu]
     |                                                                       ^
     |                                                                       |
     +-----------------------------------------------------------------------+

    x: input tensor
    dim: number of filters of both convolutions
    """
    branch = conv2d(dim, 3, padding="same", use_bias=True)(x)
    branch = batchnorm()(branch, training=1)
    branch = Activation('relu')(branch)

    branch = conv2d(dim, 3, padding="same", use_bias=True)(branch)
    branch = batchnorm()(branch, training=1)

    # Add the shortcut to the convolution branch, then apply the final ReLU.
    merged = Add()([x, branch])
    return Activation("relu")(merged)

#### 生成网络


In [None]:
def NET_G(ngf=64, block_n=6, downsampling_n=2, upsampling_n=2, image_size = 256):
    """
    Generator network with a ResNet-style structure.

    ngf:            number of filters of the first layer
    block_n:        number of stacked residual blocks (the original notes
                    say the paper uses 6 for 128px images and 9 for 256px)
    downsampling_n: number of stride-2 convolution stages
    upsampling_n:   number of stride-2 transposed-convolution stages
    image_size:     side length of the square RGB input

    Layout:
    [first layer]   size-7 convolution, channels 3 -> ngf
    [downsampling]  size-3 convolutions with stride 2, channels doubled each stage
    [residual part] block_n residual blocks
    [upsampling]    size-3 transposed convolutions with stride 2, channels halved each stage
    [last layer]    size-7 convolution back to 3 channels, followed by tanh
    """
    input_t = Input(shape=(image_size, image_size, 3))
    channels = ngf

    # First layer
    net = conv2d(channels, 7, padding="same")(input_t)
    net = batchnorm()(net, training = 1)
    net = Activation("relu")(net)

    # Downsampling
    for _ in range(downsampling_n):
        channels *= 2
        net = conv2d(channels, 3, strides = 2, padding="same")(net)
        net = batchnorm()(net, training = 1)
        net = Activation('relu')(net)

    # Residual blocks
    for _ in range(block_n):
        net = res_block(net, channels)

    # Upsampling
    for _ in range(upsampling_n):
        channels = channels // 2
        net = Conv2DTranspose(channels, 3, strides = 2, kernel_initializer = conv_init, padding="same")(net)
        net = batchnorm()(net, training = 1)
        net = Activation('relu')(net)

    # Last layer: back to 3 channels
    net = conv2d(3, 7, padding="same")(net)
    net = Activation("tanh")(net)

    return Model(inputs=input_t, outputs=net)


#### 判别网络


In [None]:
def NET_D(ndf=64, max_layers = 3, image_size = 256):
    """
    Discriminator network.

    ndf:        number of filters of the first convolution
    max_layers: number of stride-2 convolution stages, counting the first one
    image_size: side length of the square RGB input

    A stride-2 convolution with LeakyReLU is followed by max_layers - 1
    stride-2 stages that double the channel count, one stride-1 stage, and a
    final 1-filter convolution with sigmoid activation.
    """
    input_t = Input(shape=(image_size, image_size, 3))
    channels = ndf

    # First stage: no batch normalization
    net = conv2d(channels, 4, padding="same", strides=2)(input_t)
    net = LeakyReLU(alpha = 0.2)(net)

    # Remaining stride-2 stages, doubling the channels each time
    for _ in range(max_layers - 1):
        channels *= 2
        net = conv2d(channels, 4, padding="same", strides=2, use_bias=False)(net)
        net = batchnorm()(net, training=1)
        net = LeakyReLU(alpha = 0.2)(net)

    # Stride-1 stage at the final channel count
    net = conv2d(channels, 4, padding="same")(net)
    net = batchnorm()(net, training=1)
    net = LeakyReLU(alpha = 0.2)(net)

    # Single-filter output convolution with sigmoid activation
    net = conv2d(1, 4, padding="same", activation = "sigmoid")(net)
    return Model(inputs=input_t, outputs=net)


In [None]:
def loss_func(output, target):
    """
    Mean squared difference between output and target.

    The original notes say the paper reports that a squared loss works
    better.
    """
    squared_error = K.square(output - target)
    # K.abs is kept from the original; it leaves a squared value unchanged.
    return K.mean(K.abs(squared_error))

#### 网络结构的搭建
我们采用“类”的概念来组织GAN的网络结构：

In [None]:
class CycleGAN(object):
    """
    Two generators (GA, GB) and two discriminators (DA, DB) trained together.

    GB maps class-A images to class B and GA maps class-B images to class A;
    DA and DB score real and generated images of their own class.
    """
    def __init__(self, image_size=256, lambda_cyc=10, lrD = 2e-4, lrG = 2e-4, ndf = 64, ngf = 64, resnet_blocks = 9):
        """
        Build the network structure.

        image_size:    side length of the square input images
        lambda_cyc:    weight of the cycle-consistency loss in the generator loss
        lrD, lrG:      learning rates of the discriminator / generator optimizers
        ndf, ngf:      base filter counts of the discriminators / generators
        resnet_blocks: number of residual blocks in each generator

                      cyc loss
         +---------------------------------+      
         |            (CycleA)             |       
         v                                 |
        realA -> [GB] -> fakeB -> [GA] -> recA          
         |                 |
         |                 +---------------+
         |                                 |
         v                                 v
        [DA]         <CycleGAN>           [DB]
         ^                                 ^
         |                                 |
         +----------------+                |
                          |                |
        recB <- [GB] <- fakeA <- [GA] <- realB          
         |                                 ^
         |            (CycleB)             |
         +---------------------------------+
                        cyc loss
        """
        
        # Generators
        self.GA = NET_G(image_size = image_size, ngf = ngf, block_n = resnet_blocks)
        self.GB = NET_G(image_size = image_size, ngf = ngf, block_n = resnet_blocks)
        
        # Discriminators
        self.DA = NET_D(image_size = image_size, ndf = ndf)
        self.DB = NET_D(image_size = image_size, ndf = ndf)

        # Symbolic tensors for the real, generated and reconstructed images
        realA, realB = self.GB.inputs[0],  self.GA.inputs[0]
        fakeB, fakeA = self.GB.outputs[0], self.GA.outputs[0]
        recA,  recB  = self.GA([fakeB]),   self.GB([fakeA])

        # Functions mapping a real image to its generated and reconstructed versions
        self.cycleA = K.function([realA], [fakeB,recA])
        self.cycleB = K.function([realB], [fakeA,recB])

        # Discriminator scores for real and generated images
        DrealA, DrealB = self.DA([realA]), self.DB([realB])
        DfakeA, DfakeB = self.DA([fakeA]), self.DB([fakeB])

        # Losses derived from the generator and discriminator outputs
        lossDA, lossGA, lossCycA = self.get_loss(DrealA, DfakeA, realA, recA)
        lossDB, lossGB, lossCycB = self.get_loss(DrealB, DfakeB, realB, recB)

        lossG = lossGA + lossGB + lambda_cyc * (lossCycA + lossCycB)
        lossD = lossDA + lossDB

        # Parameter updates: one Adam optimizer for both generators, one for both discriminators
        updaterG = Adam(lr = lrG, beta_1=0.5).get_updates(self.GA.trainable_weights + self.GB.trainable_weights, [], lossG)
        updaterD = Adam(lr = lrD, beta_1=0.5).get_updates(self.DA.trainable_weights + self.DB.trainable_weights, [], lossD)
        
        # Training functions: each call applies one update step and returns the losses
        self.trainG = K.function([realA, realB], [lossGA, lossGB, lossCycA, lossCycB], updaterG)
        self.trainD = K.function([realA, realB], [lossDA, lossDB], updaterD)
    
    
    def get_loss(self, Dreal, Dfake, real , rec):
        """
        Build the loss tensors for one image class.

        Dreal, Dfake: discriminator outputs for real / generated images
        real, rec:    a real image tensor and its reconstruction

        Returns (lossD, lossG, lossCyc): the discriminator loss (real scored
        against ones, generated against zeros), the generator loss (generated
        scored against ones), and the mean absolute reconstruction error.
        """
        lossD = loss_func(Dreal, K.ones_like(Dreal)) + loss_func(Dfake, K.zeros_like(Dfake))
        lossG = loss_func(Dfake, K.ones_like(Dfake))
        lossCyc = K.mean(K.abs(real - rec))
        return lossD, lossG, lossCyc
    
    def save(self, path="./models/model"):
        """
        Save the four networks to "<path>-GA.h5", "<path>-GB.h5",
        "<path>-DA.h5" and "<path>-DB.h5".

        The directory part of path is created first if it does not exist, so
        a missing output folder cannot abort a long training run at the first
        checkpoint.
        """
        import os

        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        networks = (("GA", self.GA), ("GB", self.GB), ("DA", self.DA), ("DB", self.DB))
        for suffix, network in networks:
            network.save("{}-{}.h5".format(path, suffix))

    def train(self, A, B):
        """
        Run one discriminator step followed by one generator step on the
        batches A and B.

        Returns (errDA, errDB, errGA, errGB, errCycA, errCycB).
        """
        errDA, errDB = self.trainD([A, B])
        errGA, errGB, errCycA, errCycB = self.trainG([A, B])
        return errDA, errDB, errGA, errGB, errCycA, errCycB

## 训练相关代码

在这里，我们提供了一个snapshot函数，可以在训练的过程中生成预览效果。

In [None]:
# Side length of the square images fed to the networks
IMG_SIZE = 128

# Dataset name
DATASET = "vangogh2photo"

# Dataset directory
dataset_path = "./data/{}/".format(DATASET)


# Glob patterns matching the class-A and class-B training images
trainA_path = dataset_path + "trainA/*.jpg"
trainB_path = dataset_path + "trainB/*.jpg"

In [None]:
# One DataSet per image class, both loading images at IMG_SIZE
train_A = DataSet(trainA_path, image_size = IMG_SIZE)
train_B = DataSet(trainB_path, image_size = IMG_SIZE)

def train_batch(batchsize):
    """
    Fetch one batch from each of the two datasets.

    Returns (epoch, a, b): the larger of the two datasets' epoch counters,
    followed by the class-A batch and the class-B batch.
    """
    epoch_a, images_a = train_A.get_batch(batchsize)
    epoch_b, images_b = train_B.get_batch(batchsize)
    latest_epoch = max(epoch_a, epoch_b)
    return latest_epoch, images_a, images_b

In [None]:
def gen(generator, X):
    """
    Run every image of X through generator, one image per call.

    generator is called with a one-element list holding a batch of one image
    and must return two outputs. The first item of each output is collected.

    Returns (g, rec): the stacked first outputs and the stacked second
    outputs.
    """
    per_image = [generator([np.array([sample])]) for sample in X]
    stacked = np.array(per_image)
    # Axis 1 selects the output; the following 0 drops the batch-of-one axis.
    return stacked[:, 0, 0], stacked[:, 1, 0]

def snapshot(cycleA, cycleB, A, B):
    """
    Render a preview image.

    A and B are two arrays of images.
    cycleA is the A->B->A cycle function.
    cycleB is the B->A->B cycle function.

    The result is a single image laid out as:
    +-----------+     +-----------+
    | X (in A)  | ... |  Y (in B) | ...
    +-----------+     +-----------+
    |   GB(X)   | ... |   GA(Y)   | ...
    +-----------+     +-----------+
    | GA(GB(X)) | ... | GB(GA(Y)) | ...
    +-----------+     +-----------+
    """
    def tile_row(left, right):
        # Place all images of both arrays side by side along axis 1.
        return np.concatenate(left.tolist() + right.tolist(), axis = 1)

    generated_a, restored_a = gen(cycleA, A)
    generated_b, restored_b = gen(cycleB, B)

    rows = [
        tile_row(A, B),
        tile_row(generated_a, generated_b),
        tile_row(restored_a, restored_b),
    ]

    # Stack the three rows on top of each other and convert to an image.
    return arr2image(np.concatenate(rows))

In [None]:
# Create the model for IMG_SIZE x IMG_SIZE inputs
model = CycleGAN(image_size = IMG_SIZE)

In [None]:
# Training script
import time
# Start timestamp; not read again in the code shown here
start_t = time.time()

# The loop below ends once train_batch reports an epoch count of EPOCH_NUM or more
EPOCH_NUM = 100
epoch = 0

# Number of iterations between loss printouts, preview snapshots and model saves
DISPLAY_INTERVAL = 5
SNAPSHOT_INTERVAL = 50
SAVE_INTERVAL = 200

BATCH_SIZE = 1

iter_cnt = 0
# Running sums of the six values returned by model.train, reset after every printout
err_sum = np.zeros(6)

while epoch < EPOCH_NUM:       
    epoch, A, B = train_batch(BATCH_SIZE) 
    err  = model.train(A, B)
    err_sum += np.array(err)

    iter_cnt += 1

    # Print the losses averaged over the last DISPLAY_INTERVAL iterations
    if (iter_cnt % DISPLAY_INTERVAL == 0):
        err_avg = err_sum / DISPLAY_INTERVAL
        print('[迭代%d] 判别损失: A %f B %f 生成损失: A %f B %f 循环损失: A %f B %f'
        % (iter_cnt, 
        err_avg[0], err_avg[1], err_avg[2], err_avg[3], err_avg[4], err_avg[5]),
        )      
        err_sum = np.zeros_like(err_sum)


    # Show a preview built from 4 random images of each class (A and B are rebound here).
    # display is not defined in the code shown; presumably the IPython/Jupyter built-in - confirm
    if (iter_cnt % SNAPSHOT_INTERVAL == 0):
        A = train_A.get_pics(4)
        B = train_B.get_pics(4)
        display(snapshot(model.cycleA, model.cycleB, A, B))

    # Save all four networks under a path tagged with the iteration count
    if (iter_cnt % SAVE_INTERVAL == 0):
        model.save(path = "./models/model-{}".format(iter_cnt))
