Skip to content
Permalink
master
Switch branches/tags
Go to file
 
 
Cannot retrieve contributors at this time
136 lines (103 sloc) 3.79 KB
import keras.backend as K
import matplotlib.pyplot as plot
import numpy as np
import tensorflow as tf
from keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
Flatten, Input, LeakyReLU, Reshape, UpSampling2D)
from keras.models import Model
from keras.optimizers import Adam
from tensorflow.examples.tutorials.mnist import input_data
gpu_options = tf.GPUOptions(allow_growth=True)
session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
K.set_session(session)
# Supress warnings about wrong compilation of TensorFlow.
tf.logging.set_verbosity(tf.logging.ERROR)
noise_size = 100
## G
z = Input(shape=[noise_size])
G = Dense(7 * 7 * 256)(z)
G = BatchNormalization(momentum=0.9)(G)
G = LeakyReLU(alpha=0.2)(G)
G = Reshape((7, 7, 256))(G)
G = UpSampling2D()(G)
G = Conv2D(128, (5, 5), padding='same')(G)
G = BatchNormalization(momentum=0.9)(G)
G = LeakyReLU(alpha=0.2)(G)
G = UpSampling2D()(G)
G = Conv2D(64, (5, 5), padding='same')(G)
G = BatchNormalization(momentum=0.9)(G)
G = LeakyReLU(alpha=0.2)(G)
G = Conv2D(32, (5, 5), padding='same')(G)
G = BatchNormalization(momentum=0.9)(G)
G = LeakyReLU(alpha=0.2)(G)
G = Conv2D(1, (5, 5), padding='same')(G)
G = Activation('tanh')(G)
## D
x = Input(shape=(28, 28, 1))
D = Conv2D(32, (5, 5), strides=(2, 2), padding='same')(x)
D = LeakyReLU(alpha=0.2)(D)
D = Conv2D(64, (5, 5), strides=(2, 2), padding='same')(D)
D = LeakyReLU(alpha=0.2)(D)
D = Conv2D(128, (5, 5), strides=(2, 2), padding='same')(D)
D = LeakyReLU(alpha=0.2)(D)
D = Conv2D(256, (5, 5), padding='same')(D)
D = LeakyReLU(alpha=0.2)(D)
D = Flatten()(D)
D = Dense(1)(D)
# No Sigmoid for LSGAN
generator = Model(z, G)
discriminator = Model(x, D)
discriminator.compile(
loss='mean_squared_error',
optimizer=Adam(lr=5e-4, beta_1=0.5, decay=2e-7))
discriminator.trainable = False
gan = Model(z, discriminator(G))
gan.compile(
loss='mean_squared_error',
optimizer=Adam(lr=2e-4, beta_1=0.5, decay=1e-7))
discriminator.trainable = True
generator.summary()
discriminator.summary()
data = input_data.read_data_sets('MNIST_data').train.images
data = data.reshape(-1, 28, 28, 1) * 2 - 1
number_of_epochs = 30
batch_size = 256
label_smoothing = 0.9
def noise(size):
return np.random.randn(size, noise_size)
def smooth_labels(size):
return np.random.uniform(low=0.8, high=1.0, size=size)
try:
for epoch in range(number_of_epochs):
print('Epoch: {0}/{1}'.format(epoch + 1, number_of_epochs))
for batch_start in range(0, len(data) - batch_size + 1, batch_size):
generated_images = generator.predict(noise(batch_size))
real_images = data[batch_start:batch_start + batch_size]
all_images = np.concatenate(
[generated_images, real_images], axis=0)
all_images += np.random.normal(0, 0.1, all_images.shape)
labels = np.zeros(len(all_images))
labels[batch_size:] = smooth_labels(batch_size)
d_loss = discriminator.train_on_batch(all_images, labels)
labels = np.ones(batch_size)
g_loss = gan.train_on_batch(noise(batch_size), labels)
batch_index = batch_start // batch_size + 1
message = '\rBatch: {0} | D: {1:.10f} | G: {2:.10f}'
print(message.format(batch_index, d_loss, g_loss), end='')
print()
np.random.shuffle(data)
except KeyboardInterrupt:
print()
print('Training complete!')
display_images = 100
images = generator.predict(noise(display_images))
images = (images + 1) / 2
plot.switch_backend('Agg')
plot.figure(figsize=(10, 4))
for i in range(display_images):
axis = plot.subplot(10, 10, i + 1)
plot.imshow(images[i].reshape(28, 28), cmap='gray')
axis.get_xaxis().set_visible(False)
axis.get_yaxis().set_visible(False)
print('Saving fig.png')
plot.savefig('fig.png')