In [1]:
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import sys, os

from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, ZeroPadding2D
from keras.layers import Input, Dense, Flatten, Reshape, Dropout
from keras.layers import BatchNormalization, Activation
from keras.layers import LeakyReLU
from keras.callbacks import EarlyStopping
from keras.optimizers import SGD, Adam, RMSprop
from keras.utils import np_utils
import keras.backend as K

%matplotlib inline

Using TensorFlow backend.


WGAN对GAN的改进:

- 判别器最后一层去掉sigmoid

- 生成器和判别器的loss不取log

- 每次更新后，将判别器的权重强制截断到一定范围内，比如 [-0.01, 0.01]，以满足论文中提到的 Lipschitz 连续性条件。

- 论文中推荐使用SGD、RMSprop等优化器，不要使用基于动量的优化算法（包括momentum和Adam）。



WGAN的作用：

- WGAN理论上给出了GAN训练不稳定的原因，即交叉熵（JS散度）不适合衡量具有不相交部分的分布之间的距离，转而使用Wasserstein距离去衡量生成数据分布和真实数据分布之间的距离，理论上解决了训练不稳定的问题。
- 解决了模式崩溃（mode collapse）的问题，生成结果多样性更丰富。
- 对GAN的训练提供了一个指标，这个数值越小代表GAN训练得越好，代表生成器产生的图像质量越高。

### generator

In [2]:
def build_generator(input_shape, channels):
    """Build the generator: a noise vector in, an image tensor out.

    A dense layer projects the noise to 128 * 7 * 7 units, which are reshaped
    to a (7, 7, 128) feature map. That map passes through two ``UpSampling2D``
    layers interleaved with 4x4 same-padded convolutions, and a final
    convolution with ``channels`` filters followed by ``tanh``.

    Args:
        input_shape: Shape tuple of the noise input (batch axis excluded).
        channels: Number of filters of the last convolution, i.e. the number
            of channels of the generated image.

    Returns:
        A Keras ``Model`` that feeds an ``Input`` of ``input_shape`` through
        the layer stack.
    """
    def conv_bn_relu(filters):
        # 4x4 same-padded convolution -> batch norm (momentum 0.8) -> ReLU
        return [
            Conv2D(filters, kernel_size=4, padding="same"),
            BatchNormalization(momentum=0.8),
            Activation("relu"),
        ]

    net = Sequential()

    # Layers are instantiated in exactly the order they are stacked.
    projection = [
        Dense(128 * 7 * 7, activation="relu", input_shape=input_shape),
        Reshape((7, 7, 128)),
        BatchNormalization(),
        UpSampling2D(),
    ]
    first_stage = conv_bn_relu(128) + [UpSampling2D()]
    second_stage = conv_bn_relu(64)
    output_head = [
        Conv2D(channels, kernel_size=4, padding="same"),
        Activation("tanh"),
    ]

    for layer in projection + first_stage + second_stage + output_head:
        net.add(layer)
    net.summary()

    # Wrap the stack in a functional Model with an explicit noise input.
    noise_in = Input(shape=input_shape)
    return Model(noise_in, net(noise_in))

### discriminator

In [3]:
def build_discriminator(input_shape):
    """Build the WGAN critic: an image tensor in, one unbounded score out.

    Four stride-2, 3x3, same-padded convolutions (16, 32, 64, 128 filters),
    each followed by LeakyReLU(0.2) and Dropout(0.25); all but the first are
    batch-normalised (momentum 0.8). The result is flattened and fed to a
    single dense unit with no activation argument, in line with the
    notebook's note that the critic's last layer drops the sigmoid.

    Args:
        input_shape: Shape tuple of the input images (batch axis excluded).

    Returns:
        A Keras ``Model`` that feeds an ``Input`` of ``input_shape`` through
        the layer stack.
    """
    model = Sequential()
    model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=input_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
    # Pad one extra row/column at the bottom/right before the next stage.
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(1))
    model.summary()

    # Build the wrapper from the function's own argument. The previous code
    # read the module-level `img_shape`, which made the function fail (or
    # silently mismatch the first conv layer) unless that global happened to
    # exist and equal `input_shape`.
    input_img = Input(shape=input_shape)
    validity = model(input_img)
    model_discriminator = Model(input_img, validity)
    return model_discriminator

In [4]:
def wasserstein_loss(y_true, y_pred):
    """Return the mean of the element-wise product ``y_true * y_pred``.

    The training loop in this notebook passes labels of -1 for real images
    and +1 for generated ones, so the label only selects the sign applied to
    the critic's output before averaging.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [5]:
# Image geometry and size of the generator's noise input
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100
shape = (latent_dim,)

# Critic steps per generator step, weight-clipping bound and optimizer,
# set as recommended in the paper
n_critic, clip_value = 5, 0.01
optimizer = RMSprop(lr=5e-5)

# Critic: build and compile it on its own first.
# The 'accuracy' metric is kept because train() indexes the values returned
# by train_on_batch with [0].
critic = build_discriminator(img_shape)
critic.compile(loss=wasserstein_loss, optimizer=optimizer, metrics=['accuracy'])

# Generator: only built here; it is trained through the stacked model below
generator = build_generator(shape, channels)

# Noise -> generated image
z = Input(shape=shape)
img = generator(z)

# Freeze the critic before stacking so that only the generator is trained
# through `combined`
critic.trainable = False

# Generated image -> critic score
valid = critic(img)

# Stacked generator + critic
combined = Model(z, valid)
combined.compile(loss=wasserstein_loss, optimizer=optimizer, metrics=['accuracy'])

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 14, 14, 16)        160       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 16)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 32)          4640      
_________________________________________________________________
zero_padding2d_1 (ZeroPaddin (None, 8, 8, 32)          0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 32)          128       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 32)          0         
__________

In [6]:
def save_imgs(epoch):
    """Sample a 5x5 grid of generated images and save it as a PNG.

    Uses the module-level ``generator``. The figure is written to
    ``images/mnist_<epoch>.png``; the ``images`` directory is created if it
    does not exist.

    Args:
        epoch: Integer used in the output file name.
    """
    r, c = 5, 5
    # Take the noise width from the generator instead of hard-coding 100.
    latent_size = generator.input_shape[-1]
    noise = np.random.normal(0, 1, (r * c, latent_size))
    gen_imgs = generator.predict(noise)

    # Rescale images from the generator's tanh range [-1, 1] to [0, 1].
    # (The previous offset of +1 produced [0.5, 1.5] instead.)
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c, figsize=(12,8))
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[i * c + j, :,:,0], cmap='gray')
            axs[i,j].axis('off')

    if not os.path.exists('images'):
        os.makedirs('images')

    fig.savefig("images/mnist_%d.png" % epoch)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [7]:
def train(X_train, model, epochs=100, batch_size=64, save_interval=50):
    """Run the WGAN training loop.

    Each epoch performs ``n_critic`` critic updates (each followed by
    clipping every critic weight to ``[-clip_value, clip_value]``) and then
    one generator update through the stacked model. ``n_critic``,
    ``clip_value`` and ``save_imgs`` are read from module level.

    Args:
        X_train: Array of training images indexed along axis 0.
        model: Tuple ``(generator, critic, combined)``.
        epochs: Number of generator updates to perform.
        batch_size: Number of real and of generated images per critic update.
        save_interval: Sample images are saved and losses printed at epoch 1
            and whenever ``epoch % save_interval == 0``.
    """
    (generator, critic, combined) = model

    # Width of the noise vector, taken from the generator that was passed in
    # rather than hard-coded to 100.
    latent_size = generator.input_shape[-1]

    # Adversarial ground truths: -1 for real images, +1 for generated ones
    valid = -np.ones((batch_size, 1))
    fake = np.ones((batch_size, 1))

    for epoch in range(1, epochs+1):
        for step in range(n_critic):
            #  Train Discriminator
            # Random mini-batch of real images (sampled with replacement)
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, latent_size))

            # Generate a batch of new images
            gen_imgs = generator.predict(noise)

            # Train the critic
            d_loss_real = critic.train_on_batch(imgs, valid)
            d_loss_fake = critic.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            # Clip critic weights
            for layer in critic.layers:
                clipped = [np.clip(w, -clip_value, clip_value) for w in layer.get_weights()]
                layer.set_weights(clipped)

        #  Train Generator
        # Reuses the noise batch from the last critic step, labelled as real.
        g_loss = combined.train_on_batch(noise, valid)

        if (epoch % save_interval == 0) or (epoch == 1):
            # save_imgs samples from the module-level generator.
            save_imgs(epoch)
            # NOTE(review): the printed values are 1 minus the raw losses;
            # confirm this offset is intended.
            print("Epoch: %d, Discriminator Loss: %f, Generator loss: %f" % (epoch, 1-d_loss[0], 1-g_loss[0]))


In [8]:
# Load the dataset (`mnist` is imported in the notebook's first cell, so it
# is not re-imported here)
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

# Rescale -1 to 1 (assumes 8-bit pixel values in [0, 255])
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# Append a trailing channel axis
X_train = np.expand_dims(X_train, axis=3)
print(X_train.shape)

(60000, 28, 28, 1)


In [None]:
# Bundle the three networks in the order train() unpacks them
model = generator, critic, combined

# Start training
train(
    X_train,
    model,
    epochs=8000,
    batch_size=32,
    save_interval=50,
)