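Adapted from the Machine Learning Mastery tutorial "How to Train a Progressive Growing GAN in Keras for Synthesizing Faces": https://machinelearningmastery.com/how-to-train-a-progressive-growing-gan-in-keras-for-synthesizing-faces/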

In [None]:
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, AveragePooling2D, Add, Layer
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from tensorflow.python.keras.layers.convolutional import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model, save_model, load_model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
from skimage.transform import resize
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend
import IPython.display as ipd

In [None]:
# --- experiment configuration -------------------------------------------
# resolution used at each growth stage, smallest first
model_sizes = [(7, 7, 1), (14, 14, 1), (28, 28, 1)]
# full-resolution image shape (height, width, channels)
img_shape = (28, 28, 1)
# length of the latent noise vector fed to the generator
latent_dim = 100
# number of samples per training batch
batch_size = 64
# training epochs per growth stage, indexed like model_sizes
epochs = [10, 20, 30]
# Adam learning rates for the composite (generator) and discriminator models
g_lr = 2e-5
d_lr = 2e-5
# dataset label; selects '<label>-train.npy' and names the generated output
label = '4'

In [None]:
# Load the training images for the selected label and append a trailing
# channel axis so each image has shape (28, 28, 1).
X_train = np.expand_dims(np.load('./data/fashion/' + label + '-train.npy'), axis=3)

In [None]:
# Preview one randomly chosen training image in grayscale.
plt.imshow(np.squeeze(X_train[np.random.randint(0, X_train.shape[0])]), cmap='gray')

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Return the mean of ``y_true * y_pred`` (Wasserstein critic loss).

    The sample generators in this file use +1 targets for real images and
    -1 targets for generated ones, so the sign of each prediction's
    contribution depends on which kind of sample it belongs to.
    """
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)

In [None]:
def scale_dataset(images, new_shape):
    """Resize every image in ``images`` to ``new_shape``.

    Interpolation order 0 (nearest neighbour) is used. The resized images
    are returned stacked in a single NumPy array, in input order.
    """
    # order=0 selects nearest-neighbour interpolation
    resized = [resize(img, new_shape, order=0) for img in images]
    return np.asarray(resized)

In [None]:
class WeightedSum(Add):
    """Merge layer that blends two inputs as ``(1 - alpha) * a + alpha * b``.

    ``alpha`` is held in a backend variable so its value can be changed
    between training steps (``update_fadein`` does this with
    ``backend.set_value``) without rebuilding the model.
    """

    def __init__(self, alpha=0.0, **kwargs):
        super(WeightedSum, self).__init__(**kwargs)
        self.alpha = backend.variable(alpha, name='ws_alpha')

    def _merge_function(self, inputs):
        """Return the alpha-weighted sum of exactly two input tensors.

        Raises:
            ValueError: if ``inputs`` does not contain exactly two tensors.
        """
        # Validate explicitly: a bare ``assert`` is stripped under
        # ``python -O`` and would let a wrong input count through silently.
        if len(inputs) != 2:
            raise ValueError(
                'WeightedSum expects exactly 2 inputs, got %d' % len(inputs))
        return ((1.0 - self.alpha) * inputs[0]) + (self.alpha * inputs[1])

In [None]:
class PixelNormalization(Layer):
    """Pixel-wise feature-vector normalization.

    Each spatial position's feature vector is rescaled to unit
    root-mean-square over the last (channel) axis:
    ``x / sqrt(mean(x ** 2, axis=-1) + eps)``.

    The earlier implementation computed a normalized tensor but returned
    its input unchanged, so the layer had no effect.
    """

    # small constant that keeps the division finite for all-zero features
    EPSILON = 1.0e-8

    def __init__(self, **kwargs):
        super(PixelNormalization, self).__init__(**kwargs)

    def call(self, inputs):
        """Return ``inputs`` normalized across the channel axis."""
        mean_square = backend.mean(backend.square(inputs), axis=-1, keepdims=True)
        return inputs / backend.sqrt(mean_square + self.EPSILON)

    def compute_output_shape(self, input_shape):
        """Normalization is element-wise, so the shape is unchanged."""
        return input_shape

In [None]:
def define_discriminator():
    """Build the discriminator models for every growth stage.

    Returns:
        A list with one ``[straight_through, fade_in]`` pair per entry of
        ``model_sizes``. The first stage has nothing to fade in, so both
        entries of its pair are the same model.
    """
    in_image = Input(shape=model_sizes[0])

    d = Conv2D(64, (1, 1), padding='same')(in_image)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(64, (3, 3), padding='same')(d)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(64, (4, 4), padding='same')(d)
    d = LeakyReLU(alpha=0.2)(d)

    d = Flatten()(d)
    # single unbounded score (no activation), as used with the Wasserstein loss
    out_class = Dense(1)(d)

    model = Model(in_image, out_class)
    # ``learning_rate`` replaces the deprecated ``lr`` keyword
    model.compile(
        loss=wasserstein_loss,
        optimizer=Adam(learning_rate=d_lr, beta_1=0, beta_2=0.99, epsilon=10e-8),
    )
    model_list = [[model, model]]

    # grow one block per remaining resolution
    for i in range(1, len(model_sizes)):
        # build on the previous stage's straight-through (non fade-in) model
        old_model = model_list[i - 1][0]
        model_list.append(add_discriminator_block(old_model, i))
    return model_list

In [None]:
def add_discriminator_block(old_model, i, n_input_layers=3):
    """Grow the discriminator by one block for the next resolution.

    Args:
        old_model: the previous stage's straight-through discriminator.
        i: index into ``model_sizes`` giving the new input resolution.
        n_input_layers: number of leading layers of ``old_model`` to skip
            when reusing it (its input, 1x1 convolution and activation).

    Returns:
        ``[straight_through, fade_in]``: two compiled models sharing the new
        block and the reused layers of ``old_model``. The fade-in model
        blends the new block's output with the old input processing applied
        to a downsampled copy of the image, via a ``WeightedSum`` layer.
    """
    in_image = Input(shape=model_sizes[i])

    # new input-processing block for the larger image
    d = Conv2D(64, (1, 1), padding='same')(in_image)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(64, (3, 3), padding='same')(d)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(64, (3, 3), padding='same')(d)
    d = LeakyReLU(alpha=0.2)(d)

    d = AveragePooling2D()(d)
    block_new = d

    # layers of the old model that follow its own input processing
    reused_layers = old_model.layers[n_input_layers:]

    # straight-through model: new block followed by the reused old layers
    for layer in reused_layers:
        d = layer(d)
    model1 = Model(in_image, d)
    model1.compile(
        loss=wasserstein_loss,
        optimizer=Adam(learning_rate=d_lr, beta_1=0, beta_2=0.99, epsilon=10e-8),
    )

    # downsample the new larger image
    downsample = AveragePooling2D()(in_image)

    # connect the old input processing (1x1 conv + activation) to it
    block_old = old_model.layers[1](downsample)
    block_old = old_model.layers[2](block_old)

    # fade in output of old model input layer with new input
    d = WeightedSum()([block_old, block_new])

    # fade-in model: blended features followed by the reused old layers
    for layer in reused_layers:
        d = layer(d)
    model2 = Model(in_image, d)
    model2.compile(
        loss=wasserstein_loss,
        optimizer=Adam(learning_rate=d_lr, beta_1=0, beta_2=0.99, epsilon=10e-8),
    )
    return [model1, model2]

In [None]:
def define_generator():
    """Build the generator models for every growth stage.

    Returns a list holding one ``[straight_through, fade_in]`` pair per
    entry of ``model_sizes``; the first pair contains the same base model
    twice because the base stage has nothing to fade in.
    """
    base_rows, base_cols = model_sizes[0][0], model_sizes[0][1]

    # project the latent vector onto a base_rows x base_cols x 64 feature map
    in_latent = Input(shape=(latent_dim,))
    g = Dense(64 * base_rows * base_cols)(in_latent)
    g = Reshape((base_rows, base_cols, 64))(g)

    # two identical conv -> pixel-norm -> leaky-ReLU stages
    for _ in range(2):
        g = Conv2D(64, (3, 3), padding='same')(g)
        g = PixelNormalization()(g)
        g = LeakyReLU(alpha=0.2)(g)

    # 1x1 convolution down to a single sigmoid-activated channel
    out_image = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(g)
    base_model = Model(in_latent, out_image)

    model_list = [[base_model, base_model]]
    for stage in range(1, len(model_sizes)):
        # each stage grows from the previous straight-through model
        previous = model_list[stage - 1][0]
        model_list.append(add_generator_block(previous))
    return model_list

In [None]:
def add_generator_block(old_model):
    """Grow the generator by one upsampling block.

    Args:
        old_model: the previous stage's straight-through generator.

    Returns:
        ``[straight_through, fade_in]``: two models sharing ``old_model``'s
        input. The straight-through model ends in the new block's own output
        layer; the fade-in model blends that output with the old output
        layer applied to the upsampled features, via ``WeightedSum``.
    """
    # features just before the old model's output layer
    block_end = old_model.layers[-2].output

    # upsample, and define new block
    upsampling = UpSampling2D()(block_end)
    g = Conv2D(64, (3, 3), padding='same')(upsampling)
    g = PixelNormalization()(g)
    g = LeakyReLU(alpha=0.2)(g)

    g = Conv2D(64, (3, 3), padding='same')(g)
    g = PixelNormalization()(g)
    g = LeakyReLU(alpha=0.2)(g)

    out_image = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(g)

    model1 = Model(old_model.input, out_image)

    # reuse the old output layer on the upsampled features
    out_old = old_model.layers[-1]
    out_image2 = out_old(upsampling)
    # new output image is the weighted sum of the old and new outputs
    merged = WeightedSum()([out_image2, out_image])

    model2 = Model(old_model.input, merged)
    return [model1, model2]

In [None]:
def define_composite(discriminators, generators):
    """Build generator -> discriminator composite models for every stage.

    Args:
        discriminators: list of ``[straight_through, fade_in]`` discriminator
            pairs, one per growth stage.
        generators: list of generator pairs, indexed like ``discriminators``.

    Returns:
        A list of ``[straight_through, fade_in]`` compiled composite models,
        one pair per entry of ``discriminators``.
    """
    def _stack(g_model, d_model):
        # Mark the discriminator non-trainable before adding it, so that
        # training the composite is meant to update the generator only.
        d_model.trainable = False
        composite = Sequential()
        composite.add(g_model)
        composite.add(d_model)
        # ``learning_rate`` replaces the deprecated ``lr`` keyword
        composite.compile(
            loss=wasserstein_loss,
            optimizer=Adam(learning_rate=g_lr, beta_1=0, beta_2=0.99, epsilon=10e-8),
        )
        return composite

    model_list = []
    for i, d_models in enumerate(discriminators):
        g_models = generators[i]
        straight_through = _stack(g_models[0], d_models[0])
        fade_in = _stack(g_models[1], d_models[1])
        model_list.append([straight_through, fade_in])
    return model_list

In [None]:
def generate_real_samples(dataset, n_samples):
    """Draw ``n_samples`` random entries (with replacement) from ``dataset``.

    Returns ``(X, y)`` where ``X`` holds the selected entries and ``y`` is an
    ``(n_samples, 1)`` array of ones, the target used for real images.
    """
    picks = np.random.randint(0, dataset.shape[0], n_samples)
    return dataset[picks], np.ones((n_samples, 1))

In [None]:
def generate_fake_samples(generator, n_samples):
    """Generate ``n_samples`` images from standard-normal latent vectors.

    Returns ``(images, y)`` where ``images`` is ``generator.predict`` applied
    to the sampled noise and ``y`` is an ``(n_samples, 1)`` array of -1, the
    target used for generated images.
    """
    latent_points = np.random.normal(0, 1, (n_samples, latent_dim))
    fake_targets = -np.ones((n_samples, 1))
    return generator.predict(latent_points), fake_targets

In [None]:
def update_fadein(models, step, n_steps):
    """Set ``alpha`` on every ``WeightedSum`` layer of the given models.

    Args:
        models: iterable of models whose top-level layers are scanned.
        step: zero-based index of the current training step.
        n_steps: total number of steps in the fade-in phase.

    ``alpha`` ramps linearly from 0 at the first step to 1 at the last.
    """
    # With a single step there is nothing to ramp over; use 1.0 (new block
    # fully faded in) rather than dividing by zero.
    alpha = step / float(n_steps - 1) if n_steps > 1 else 1.0
    for model in models:
        for layer in model.layers:
            if isinstance(layer, WeightedSum):
                backend.set_value(layer.alpha, alpha)

In [None]:
def train(g_models, d_models, gan_models, dataset, latent_dim):
    """Train every growth stage in order, smallest resolution first.

    The base stage trains its straight-through models only. Every later
    stage first trains the fade-in models and then the straight-through
    models, with the dataset rescaled to that stage's generator output
    shape. A sample grid is saved after each phase.

    ``latent_dim`` is accepted for interface compatibility; the body does
    not read it.
    """
    def _rescale_for(generator):
        # resize the dataset to this generator's output resolution
        shape = generator.output_shape
        data = scale_dataset(dataset, shape[1:])
        print('Data Size:', data.shape)
        return shape, data

    # baseline stage: straight-through models only
    base_g, base_d, base_gan = g_models[0][0], d_models[0][0], gan_models[0][0]
    gen_shape, scaled_data = _rescale_for(base_g)
    train_epochs(base_g, base_d, base_gan, scaled_data, epochs[0])
    save_images('tuned', gen_shape, base_g)

    for level in range(1, len(g_models)):
        g_normal, g_fadein = g_models[level]
        d_normal, d_fadein = d_models[level]
        gan_normal, gan_fadein = gan_models[level]

        gen_shape, scaled_data = _rescale_for(g_normal)

        # phase 1: fade the new block in
        train_epochs(g_fadein, d_fadein, gan_fadein, scaled_data, epochs[level], True)
        save_images('faded', gen_shape, g_fadein)

        # phase 2: tune the straight-through models
        train_epochs(g_normal, d_normal, gan_normal, scaled_data, epochs[level])
        save_images('tuned', gen_shape, g_normal)

In [None]:
def train_epochs(g_model, d_model, gan_model, dataset, n_epochs, fadein=False):
    """Run the adversarial training loop for one phase of one stage.

    Args:
        g_model: generator used to produce fake samples.
        d_model: discriminator trained on real and fake batches.
        gan_model: composite model used to update the generator.
        dataset: array of real images at this stage's resolution.
        n_epochs: number of passes; each pass is
            ``dataset.shape[0] // batch_size`` steps.
        fadein: when True, ``update_fadein`` is called before every step so
            the ``WeightedSum`` alpha ramps across the phase.
    """
    steps_per_epoch = dataset.shape[0] // batch_size
    n_steps = steps_per_epoch * n_epochs
    if n_steps <= 0:
        # Nothing to train on. Without this guard the final summary print
        # would fail because the loss variables were never assigned.
        print('train_epochs: no training steps (dataset smaller than batch_size or n_epochs == 0)')
        return

    status = 'faded' if fadein else 'tuned'
    # snapshot at step 0 and then every 150 full epochs
    snapshot_every = 150 * steps_per_epoch

    for i in range(n_steps):
        # update alpha for all WeightedSum layers when fading in new blocks
        if fadein:
            update_fadein([g_model, d_model, gan_model], i, n_steps)

        X_real, y_real = generate_real_samples(dataset, batch_size)
        X_fake, y_fake = generate_fake_samples(g_model, batch_size)

        # update discriminator model
        d_loss_real = d_model.train_on_batch(X_real, y_real)
        d_loss_fake = d_model.train_on_batch(X_fake, y_fake)

        # update the generator via the discriminator's error
        z_input = np.random.normal(0, 1, (batch_size, latent_dim))
        y_real2 = np.ones((batch_size, 1))
        g_loss = gan_model.train_on_batch(z_input, y_real2)

        if i % snapshot_every == 0:
            epoch = i // steps_per_epoch
            save_images(status + str(epoch), g_model.output_shape, g_model)
            print('dreal=%.3f, dfake=%.3f, g=%.3f' % (d_loss_real, d_loss_fake, g_loss))

    # losses from the final step of this phase
    print('dreal=%.3f, dfake=%.3f, g=%.3f' % (d_loss_real, d_loss_fake, g_loss))

In [None]:
def save_images(status, gen_shape, generator):
    """Save a 2x2 grid of freshly generated images under ``images/``.

    Args:
        status: suffix used in the output file name.
        gen_shape: generator output shape; elements 1 and 2 (height and
            width) are used in the file name.
        generator: model whose ``predict`` turns latent noise into images.
    """
    rows, cols = 2, 2
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    gen_imgs = generator.predict(noise)

    # draw on the axes created by subplots instead of re-creating each
    # one with plt.subplot
    fig, axs = plt.subplots(rows, cols)
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0])

    # savefig does not create missing directories
    os.makedirs('images', exist_ok=True)
    name = 'images/%02dx%02d-%s' % (gen_shape[1], gen_shape[2], status)
    fig.savefig(name)
    # close this specific figure rather than whichever one is current
    plt.close(fig)

In [None]:
# Build the per-stage discriminator and generator model pairs, then the
# composite (generator -> discriminator) models for each stage.
d_models = define_discriminator()
g_models = define_generator()
gan_models = define_composite(d_models, g_models)

In [None]:
# Run progressive training over all growth stages on the loaded dataset.
train(g_models, d_models, gan_models, X_train, latent_dim)

In [None]:
# Sample one latent vector and display the image produced by the
# third-stage straight-through generator (g_models[2][0]).
noise = np.random.normal(0, 1, (1, latent_dim))
generated_image = g_models[2][0].predict(noise)
plt.imshow(np.squeeze(generated_image), cmap='gray')

In [None]:
# Short aliases for the third-stage straight-through models, used by the
# cells below.
generator = g_models[2][0]
discriminator = d_models[2][0]
gan = gan_models[2][0]

In [None]:
def train_final_model(epochs):
    """Continue training the final models and record the loss history.

    Note that each of the ``epochs`` iterations trains on a single batch,
    not a full pass over the data.

    Args:
        epochs: number of single-batch training iterations to run.

    Returns:
        ``(g_loss_list, dreal_list, dfake_list)``: per-iteration generator
        loss, discriminator loss on real images, and discriminator loss on
        generated images, in that order.
    """
    g_loss_list = []
    dfake_list = []
    dreal_list = []
    scaled_data = scale_dataset(X_train, (28, 28, 1))
    for epoch in range(epochs):
        X_real, y_real = generate_real_samples(scaled_data, batch_size)
        X_fake, y_fake = generate_fake_samples(generator, batch_size)

        # update discriminator model
        d_loss_real = discriminator.train_on_batch(X_real, y_real)
        d_loss_fake = discriminator.train_on_batch(X_fake, y_fake)

        # update the generator via the discriminator's error
        z_input = np.random.normal(0, 1, (batch_size, latent_dim))
        y_real2 = np.ones((batch_size, 1))
        g_loss = gan.train_on_batch(z_input, y_real2)

        g_loss_list.append(g_loss)
        dreal_list.append(d_loss_real)
        dfake_list.append(d_loss_fake)

        # report progress and save a sample grid every 150 iterations
        if epoch % 150 == 0:
            print(str(epoch) + ' dreal=%.5f, dfake=%.5f, g=%.5f' % (d_loss_real, d_loss_fake, g_loss))
            save_images('final' + str(epoch), generator.output_shape, generator)

    return g_loss_list, dreal_list, dfake_list

In [None]:
# Score one randomly chosen training image with the discriminator; a
# leading axis is added so the input is a batch of one.
real = np.expand_dims(X_train[np.random.randint(0, X_train.shape[0])], axis=0)
print(discriminator.predict(real))

In [None]:
# train_final_model returns (generator, real, fake) loss histories in that
# order; unpack them under matching names so later plots are labelled
# correctly.
g_loss, real_loss, fake_loss = train_final_model(551)

In [None]:
# Generate one image from random latent noise, print the discriminator's
# score for it, and display it in grayscale.
noise = np.random.normal(0, 1, (1, latent_dim))
generated_image = generator.predict(noise)

print(discriminator.predict(generated_image))
plt.imshow(np.squeeze(generated_image), cmap='gray')

In [None]:
# Plot the absolute value of each recorded loss history on one set of axes.
plt.plot(np.abs(g_loss))
plt.plot(np.abs(fake_loss))
plt.plot(np.abs(real_loss))
plt.legend(['g_loss', 'fake_loss', 'real_loss'])

In [None]:
def generate_images(n, generator):
    """Generate ``n`` images and save them as a ``.npy`` file.

    Args:
        n: number of images to generate.
        generator: model whose ``predict`` turns latent noise into images.

    The array is written to ``data/fashion/<label>-generated`` via
    ``np.save``. A trailing channel axis of size 1 is dropped before saving.
    """
    noise = np.random.normal(0, 1, (n, latent_dim))
    generated_images = generator.predict(noise)
    # Drop only the trailing channel axis. A bare np.squeeze would also
    # remove the batch axis when n == 1, changing the saved array's rank.
    if generated_images.shape[-1] == 1:
        generated_images = np.squeeze(generated_images, axis=-1)
    np.save('data/fashion/' + label + '-generated',  generated_images)
    
# Generate and save 1500 images with the final generator.
generate_images(1500, generator)    