### [机器之心GitHub项目：GAN完整理论推导与实现](https://zhuanlan.zhihu.com/p/29837245)

[原notebookGITHUB](https://github.com/jiqizhixin/ML-Tutorial-Experiment/blob/master/Experiments/Keras_GAN.ipynb)

In [1]:
import argparse
import math
import os

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image

Using TensorFlow backend.


In [2]:
def generator_model():
    """Build the generator network.

    The model maps a 100-dimensional noise vector to a single-channel image:
    two dense layers, a reshape to a 7x7x128 tensor, then two rounds of
    2x upsampling followed by a 5x5 'same'-padded convolution (7 -> 14 -> 28).
    The final tanh keeps every output value inside (-1, 1).

    NOTE(review): the Reshape target (7, 7, 128) assumes the Keras
    `channels_last` image data format -- confirm against the backend config.

    Returns:
        An uncompiled Keras ``Sequential`` model.
    """
    # Sequential model: a linear stack of layers.
    model = Sequential()
    # Fully connected layer: 100-dim input, 1024 units out. The unit count is
    # passed positionally instead of through the deprecated Keras 1
    # `output_dim=` keyword, matching the Keras 2 API used everywhere else.
    model.add(Dense(1024, input_dim=100))
    model.add(Activation('tanh'))
    # Fully connected layer producing 128*7*7 values.
    model.add(Dense(128 * 7 * 7))
    # Batch normalization re-normalizes the previous layer's activations
    # per batch.
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    # Turn the flat 128*7*7 vector into a 7x7x128 tensor. No `input_shape` is
    # given here: this is not the first layer, so the input shape is already
    # known from the preceding Dense layer.
    model.add(Reshape((7, 7, 128)))
    # 2-D upsampling: repeat rows and columns twice (7x7 -> 14x14).
    model.add(UpSampling2D(size=(2, 2)))
    # 64 filters of size 5x5; 'same' padding keeps the spatial size.
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    # Second upsampling step (14x14 -> 28x28).
    model.add(UpSampling2D(size=(2, 2)))
    # A single filter, so the output has exactly one channel.
    model.add(Conv2D(1, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    return model

In [3]:
def discriminator_model():
    """Build the discriminator network.

    The model takes a (28, 28, 1) input and produces one sigmoid output.
    It is two conv/tanh/max-pool stages followed by a 1024-unit dense layer
    and a single-unit sigmoid classifier.

    Returns:
        An uncompiled Keras ``Sequential`` model.
    """
    # Layers are listed in the order they are stacked.
    stack = [
        # 64 filters of size 5x5 with 'same' padding; the input shape is given
        # in channels-last order (rows, cols, channels).
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        # Max pooling with a (2, 2) window halves both spatial dimensions.
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Flatten the feature maps for the transition to dense layers.
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        # One unit with a sigmoid: the output is read as a probability for
        # binary classification.
        Dense(1),
        Activation('sigmoid'),
    ]

    net = Sequential()
    for layer in stack:
        net.add(layer)
    return net

In [4]:
def generator_containing_discriminator(g, d):
    """Stack generator ``g`` and discriminator ``d`` into one model.

    The combined model feeds the generator's output straight into the
    discriminator, so the generator can be optimized against the
    discriminator's verdict.

    Side effect: ``d.trainable`` is set to ``False`` and left that way, so the
    discriminator is held fixed inside the stacked model.

    Args:
        g: the generator model.
        d: the discriminator model.

    Returns:
        An uncompiled ``Sequential`` model equivalent to ``d(g(z))``.
    """
    # Freeze the discriminator before it is stacked behind the generator.
    d.trainable = False
    stacked = Sequential()
    for part in (g, d):
        stacked.add(part)
    return stacked

In [5]:
def combine_images(generated_images):
    """Tile a batch of images into a single 2-D grid image.

    The grid is ``int(sqrt(n))`` tiles wide and has as many rows as needed to
    hold all ``n`` images; unused cells in the last row stay zero. Only
    channel 0 of every image is copied.

    Args:
        generated_images: array indexed as (n, rows, cols, channels).

    Returns:
        A 2-D array of shape (grid_rows * rows, grid_cols * cols) with the
        same dtype as the input. An empty batch yields an empty (0, 0) array.
    """
    num = generated_images.shape[0]
    if num == 0:
        # With no images the grid width would be 0 and the row computation
        # below would divide by zero; return an empty canvas instead.
        return np.zeros((0, 0), dtype=generated_images.dtype)

    width = int(math.sqrt(num))
    height = int(math.ceil(float(num) / width))
    tile_rows, tile_cols = generated_images.shape[1:3]
    image = np.zeros((height * tile_rows, width * tile_cols),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        # Row-major placement: i is the grid row, j the grid column.
        i, j = divmod(index, width)
        image[i * tile_rows:(i + 1) * tile_rows,
              j * tile_cols:(j + 1) * tile_cols] = img[:, :, 0]
    return image

#### 对于每一次迭代：

-  从真实数据分布 P_data 抽取 m 个样本
-  从先验分布 P_prior(z) 抽取 m 个噪声样本
-  将噪声样本投入 G 而生成数据，即 x̃^i = G(z^i)，i = 1,...,m；通过最大化 V 的近似 V^tilde 而更新判别器参数 θ_d

以上是学习判别器 D 的过程。因为学习 D 的过程是计算 JS 散度的过程，并且我们希望能最大化价值函数，所以该步骤会重复 k 次。

-  从先验分布 P_prior(z) 中抽取另外 m 个噪声样本 {z^1,...,z^m}
-  通过极小化 V^tilde 而更新生成器参数θ_g

In [6]:
def train(BATCH_SIZE, epochs=30, output_dir="./GAN/"):
    """Train the GAN on MNIST, alternating discriminator and generator updates.

    Each batch does one discriminator update on real plus generated images,
    then one generator update through the stacked model. Every 100 batches
    (index % 100 == 0) a grid of generated images is written to
    ``output_dir`` as ``<epoch>_<index>.png``. Whenever index % 100 == 9 the
    generator and discriminator weights are written to the files
    'generator' and 'discriminator' in the working directory.

    Args:
        BATCH_SIZE: number of real images (and of noise vectors) per batch.
        epochs: number of passes over the training set (default 30).
        output_dir: directory for the sample image grids; it is created if
            it does not exist (default "./GAN/").
    """
    # If the dataset cannot be downloaded automatically, place mnist.npz
    # (https://s3.amazonaws.com/img-datasets/mnist.npz) under
    # ~/.keras/datasets/ beforehand.
    # Only the training images are used; labels and the test split are not.
    (X_train, _), _ = mnist.load_data()

    # Convert to float32 and rescale with (x - 127.5) / 127.5, then append a
    # trailing channel axis so each image is (rows, cols, 1).
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train[:, :, :, None]

    # Make sure the sample-image directory exists before the first save.
    os.makedirs(output_dir, exist_ok=True)

    # Build the three models.
    d = discriminator_model()
    g = generator_model()
    d_on_g = generator_containing_discriminator(g, d)

    # Optimizers and hyper-parameters for the discriminator and generator.
    d_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)

    # Compile all three networks with binary cross-entropy loss.
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)

    # The stacked model was built with the discriminator frozen, so re-enable
    # training on it before compiling the stand-alone discriminator.
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_optim)

    # Batches per epoch; a trailing partial batch is dropped.
    num_batches = X_train.shape[0] // BATCH_SIZE

    # Labels are loop-invariant: 1 for real images, 0 for generated ones.
    # The discriminator batch is "real first, generated second".
    d_labels = np.array([1] * BATCH_SIZE + [0] * BATCH_SIZE)
    # The generator is trained to make the discriminator answer "real".
    g_labels = np.array([1] * BATCH_SIZE)

    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)

        for index in range(num_batches):
            # Uniform noise in [-1, 1), one 100-dim vector per sample.
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))

            # One batch of real images.
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]

            # Generate fake images from the noise (verbose=0: no progress log).
            generated_images = g.predict(noise, verbose=0)

            # Save a grid of the current generator output every 100 batches,
            # mapping values back from [-1, 1] to [0, 255].
            if index % 100 == 0:
                image = combine_images(generated_images)
                image = image*127.5+127.5
                image_path = os.path.join(
                    output_dir, str(epoch)+"_"+str(index)+".png")
                Image.fromarray(image.astype(np.uint8)).save(image_path)

            # Real images first, generated images second (matches d_labels).
            X = np.concatenate((image_batch, generated_images))

            # One discriminator update on this batch.
            d_loss = d.train_on_batch(X, d_labels)
            print("batch %d d_loss : %f" % (index, d_loss))

            # Fresh noise for the generator update.
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))

            # Hold the discriminator fixed during the generator update.
            # NOTE(review): Keras appears to capture `trainable` at compile
            # time, so toggling it here is probably a no-op for the already
            # compiled models -- confirm for the Keras version in use.
            d.trainable = False

            # One generator update through the stacked model.
            g_loss = d_on_g.train_on_batch(noise, g_labels)

            d.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))

            # Checkpoint once per 100 batches (at 9, 109, 209, ...),
            # overwriting any existing weight files.
            if index % 100 == 9:
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)

In [17]:
# Run training with a batch size of 128.
train(128)

  """


Epoch is 0
Number of batches 468
batch 0 d_loss : 0.651978
batch 0 g_loss : 0.653149
batch 1 d_loss : 0.639725
batch 1 g_loss : 0.645179
batch 2 d_loss : 0.624481
batch 2 g_loss : 0.633186
batch 3 d_loss : 0.600257
batch 3 g_loss : 0.619147
batch 4 d_loss : 0.581192
batch 4 g_loss : 0.603321
batch 5 d_loss : 0.554355
batch 5 g_loss : 0.597126
batch 6 d_loss : 0.537645
batch 6 g_loss : 0.584211
batch 7 d_loss : 0.520177
batch 7 g_loss : 0.567349
batch 8 d_loss : 0.503279
batch 8 g_loss : 0.559891
batch 9 d_loss : 0.493053
batch 9 g_loss : 0.545466
batch 10 d_loss : 0.484570
batch 10 g_loss : 0.540249
batch 11 d_loss : 0.469190
batch 11 g_loss : 0.531435
batch 12 d_loss : 0.463894
batch 12 g_loss : 0.526381
batch 13 d_loss : 0.453370
batch 13 g_loss : 0.520574
batch 14 d_loss : 0.458233
batch 14 g_loss : 0.508590
batch 15 d_loss : 0.457384
batch 15 g_loss : 0.507561
batch 16 d_loss : 0.458243
batch 16 g_loss : 0.506943
batch 17 d_loss : 0.461206
batch 17 g_loss : 0.498258
batch 18 d_loss

batch 150 d_loss : 0.167415
batch 150 g_loss : 0.651569
batch 151 d_loss : 0.137605
batch 151 g_loss : 0.665578
batch 152 d_loss : 0.173832
batch 152 g_loss : 0.636862
batch 153 d_loss : 0.166983
batch 153 g_loss : 0.692957
batch 154 d_loss : 0.196354
batch 154 g_loss : 0.675350
batch 155 d_loss : 0.231310
batch 155 g_loss : 0.722934
batch 156 d_loss : 0.242196
batch 156 g_loss : 0.643948
batch 157 d_loss : 0.230547
batch 157 g_loss : 0.592723
batch 158 d_loss : 0.274976
batch 158 g_loss : 0.459430
batch 159 d_loss : 0.269870
batch 159 g_loss : 0.470483
batch 160 d_loss : 0.263416
batch 160 g_loss : 0.539096
batch 161 d_loss : 0.268544
batch 161 g_loss : 0.704852
batch 162 d_loss : 0.343421
batch 162 g_loss : 0.596352
batch 163 d_loss : 0.365035
batch 163 g_loss : 0.710933
batch 164 d_loss : 0.514752
batch 164 g_loss : 0.526133
batch 165 d_loss : 0.479948
batch 165 g_loss : 0.599979
batch 166 d_loss : 0.506345
batch 166 g_loss : 0.443901
batch 167 d_loss : 0.550222
batch 167 g_loss : 0

batch 296 g_loss : 1.060447
batch 297 d_loss : 0.483550
batch 297 g_loss : 1.064798
batch 298 d_loss : 0.549521
batch 298 g_loss : 1.089615
batch 299 d_loss : 0.555634
batch 299 g_loss : 1.092007
batch 300 d_loss : 0.525216
batch 300 g_loss : 1.159459
batch 301 d_loss : 0.566051
batch 301 g_loss : 1.063646
batch 302 d_loss : 0.521454
batch 302 g_loss : 1.034525
batch 303 d_loss : 0.508925
batch 303 g_loss : 1.084675
batch 304 d_loss : 0.512438
batch 304 g_loss : 1.098975
batch 305 d_loss : 0.523531
batch 305 g_loss : 1.126875
batch 306 d_loss : 0.529770
batch 306 g_loss : 1.128532
batch 307 d_loss : 0.586404
batch 307 g_loss : 1.102218
batch 308 d_loss : 0.598007
batch 308 g_loss : 1.010792
batch 309 d_loss : 0.532905
batch 309 g_loss : 1.073311
batch 310 d_loss : 0.561134
batch 310 g_loss : 0.985234
batch 311 d_loss : 0.552264
batch 311 g_loss : 0.957160
batch 312 d_loss : 0.495653
batch 312 g_loss : 1.010074
batch 313 d_loss : 0.516226
batch 313 g_loss : 1.091649
batch 314 d_loss : 0

batch 443 d_loss : 0.466241
batch 443 g_loss : 1.235449
batch 444 d_loss : 0.432919
batch 444 g_loss : 1.237295
batch 445 d_loss : 0.467291
batch 445 g_loss : 1.117564
batch 446 d_loss : 0.454384
batch 446 g_loss : 1.133072
batch 447 d_loss : 0.442760
batch 447 g_loss : 1.202503
batch 448 d_loss : 0.446953
batch 448 g_loss : 1.182075
batch 449 d_loss : 0.440620
batch 449 g_loss : 1.158328
batch 450 d_loss : 0.427054
batch 450 g_loss : 1.304453
batch 451 d_loss : 0.390578
batch 451 g_loss : 1.255840
batch 452 d_loss : 0.463106
batch 452 g_loss : 1.161743
batch 453 d_loss : 0.440882
batch 453 g_loss : 1.329164
batch 454 d_loss : 0.384138
batch 454 g_loss : 1.230767
batch 455 d_loss : 0.446092
batch 455 g_loss : 1.262440
batch 456 d_loss : 0.494039
batch 456 g_loss : 1.256007
batch 457 d_loss : 0.429485
batch 457 g_loss : 1.261782
batch 458 d_loss : 0.329171
batch 458 g_loss : 1.358471
batch 459 d_loss : 0.429144
batch 459 g_loss : 1.215102
batch 460 d_loss : 0.414207
batch 460 g_loss : 1

batch 125 d_loss : 0.386293
batch 125 g_loss : 1.228069
batch 126 d_loss : 0.348796
batch 126 g_loss : 1.888906
batch 127 d_loss : 0.354609
batch 127 g_loss : 1.376099
batch 128 d_loss : 0.294960
batch 128 g_loss : 1.840254
batch 129 d_loss : 0.299020
batch 129 g_loss : 1.371385
batch 130 d_loss : 0.319235
batch 130 g_loss : 1.799068
batch 131 d_loss : 0.376376
batch 131 g_loss : 0.890320
batch 132 d_loss : 0.398667
batch 132 g_loss : 2.203684
batch 133 d_loss : 0.413587
batch 133 g_loss : 0.852041
batch 134 d_loss : 0.401793
batch 134 g_loss : 2.286932
batch 135 d_loss : 0.392231
batch 135 g_loss : 0.995199
batch 136 d_loss : 0.398062
batch 136 g_loss : 2.182449
batch 137 d_loss : 0.568440
batch 137 g_loss : 0.301770
batch 138 d_loss : 0.813412
batch 138 g_loss : 3.461815
batch 139 d_loss : 0.778626
batch 139 g_loss : 0.420164
batch 140 d_loss : 0.647006
batch 140 g_loss : 3.328544
batch 141 d_loss : 0.601103
batch 141 g_loss : 0.734650
batch 142 d_loss : 0.485911
batch 142 g_loss : 2

batch 271 g_loss : 0.918489
batch 272 d_loss : 0.527755
batch 272 g_loss : 1.268314
batch 273 d_loss : 0.468476
batch 273 g_loss : 1.588399
batch 274 d_loss : 0.487079
batch 274 g_loss : 0.941471
batch 275 d_loss : 0.468743
batch 275 g_loss : 1.553002
batch 276 d_loss : 0.536337
batch 276 g_loss : 0.975147
batch 277 d_loss : 0.535105
batch 277 g_loss : 1.390451
batch 278 d_loss : 0.443892
batch 278 g_loss : 1.115123
batch 279 d_loss : 0.421680
batch 279 g_loss : 1.577903
batch 280 d_loss : 0.516056
batch 280 g_loss : 0.786304
batch 281 d_loss : 0.518448
batch 281 g_loss : 1.690316
batch 282 d_loss : 0.636974
batch 282 g_loss : 0.981885
batch 283 d_loss : 0.581746
batch 283 g_loss : 1.355654
batch 284 d_loss : 0.437106
batch 284 g_loss : 1.270923
batch 285 d_loss : 0.472065
batch 285 g_loss : 1.297708
batch 286 d_loss : 0.467554
batch 286 g_loss : 1.413678
batch 287 d_loss : 0.438006
batch 287 g_loss : 1.309584
batch 288 d_loss : 0.426302
batch 288 g_loss : 1.428411
batch 289 d_loss : 0

batch 418 d_loss : 0.411667
batch 418 g_loss : 1.580013
batch 419 d_loss : 0.340878
batch 419 g_loss : 1.542480
batch 420 d_loss : 0.302454
batch 420 g_loss : 1.610459
batch 421 d_loss : 0.289563
batch 421 g_loss : 1.764639
batch 422 d_loss : 0.366280
batch 422 g_loss : 1.823069
batch 423 d_loss : 0.278884
batch 423 g_loss : 1.847920
batch 424 d_loss : 0.329258
batch 424 g_loss : 1.763125
batch 425 d_loss : 0.398736
batch 425 g_loss : 1.786752
batch 426 d_loss : 0.378725
batch 426 g_loss : 1.878575
batch 427 d_loss : 0.486311
batch 427 g_loss : 1.576686
batch 428 d_loss : 0.382331
batch 428 g_loss : 1.690217
batch 429 d_loss : 0.433445
batch 429 g_loss : 1.712208
batch 430 d_loss : 0.359642
batch 430 g_loss : 1.746453
batch 431 d_loss : 0.425347
batch 431 g_loss : 1.652702
batch 432 d_loss : 0.392898
batch 432 g_loss : 1.776468
batch 433 d_loss : 0.512220
batch 433 g_loss : 1.169250
batch 434 d_loss : 0.624500
batch 434 g_loss : 1.552957
batch 435 d_loss : 0.589738
batch 435 g_loss : 1

batch 100 d_loss : 0.508856
batch 100 g_loss : 0.657094
batch 101 d_loss : 0.688800
batch 101 g_loss : 2.557090
batch 102 d_loss : 0.558131
batch 102 g_loss : 0.923550
batch 103 d_loss : 0.407503
batch 103 g_loss : 2.474998
batch 104 d_loss : 0.364268
batch 104 g_loss : 1.392776
batch 105 d_loss : 0.331237
batch 105 g_loss : 2.363018
batch 106 d_loss : 0.442097
batch 106 g_loss : 0.677369
batch 107 d_loss : 0.530164
batch 107 g_loss : 2.934262
batch 108 d_loss : 0.467695
batch 108 g_loss : 1.056248
batch 109 d_loss : 0.393472
batch 109 g_loss : 2.199952
batch 110 d_loss : 0.349373
batch 110 g_loss : 1.699593
batch 111 d_loss : 0.340389
batch 111 g_loss : 1.577768
batch 112 d_loss : 0.349974
batch 112 g_loss : 1.732228
batch 113 d_loss : 0.375946
batch 113 g_loss : 1.099101
batch 114 d_loss : 0.390269
batch 114 g_loss : 2.545557
batch 115 d_loss : 0.411453
batch 115 g_loss : 1.069388
batch 116 d_loss : 0.393594
batch 116 g_loss : 1.717600
batch 117 d_loss : 0.409206
batch 117 g_loss : 1

KeyboardInterrupt: 

跑 30 个 epoch、每个 epoch 有 400 多个 batch，时间太长，所以在第一个100处保存了就先退出了。

但是看了下下面生成的效果不好，还是再继续跑跑吧..吃饭去了先

In [18]:
def generate(BATCH_SIZE, nice=False,
             output_path="./MNIST_data/generated_image.png"):
    """Generate a grid of images from the saved generator weights.

    Loads the weights file 'generator' (written by ``train``), produces
    ``BATCH_SIZE`` images, tiles them with ``combine_images`` and saves the
    result as an 8-bit image.

    Args:
        BATCH_SIZE: number of images in the saved grid.
        nice: if true, generate ``BATCH_SIZE * 20`` candidates, score them
            with the saved discriminator (weights file 'discriminator') and
            keep only the ``BATCH_SIZE`` highest-scoring ones.
        output_path: where the image is written; its parent directory is
            created if missing (default "./MNIST_data/generated_image.png").
    """
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if nice:
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        noise = np.random.uniform(-1, 1, (BATCH_SIZE*20, 100))
        generated_images = g.predict(noise, verbose=1)
        d_pret = d.predict(generated_images, verbose=1)
        # Rank candidates by discriminator score, best first. Mergesort is
        # stable, so equal scores keep their original relative order.
        ranking = np.argsort(-d_pret[:, 0], kind='mergesort')
        nice_images = generated_images[ranking[:BATCH_SIZE]].astype(np.float32)
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = g.predict(noise, verbose=0)
        image = combine_images(generated_images)
    # Map values back from [-1, 1] to [0, 255] before the uint8 conversion.
    image = image*127.5+127.5
    # Create the destination directory first so the save cannot fail on a
    # missing folder.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(image.astype(np.uint8)).save(output_path)

In [19]:
# Generate and save a grid of 128 images (nice=False: no discriminator ranking).
generate(128)

  """


初始：
![start](./data/0_0.png)

没跑完，结果还是可以看出一个大致的轮廓：

![result](./data/2_100.png)