Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make aae.py work with Keras 2.3.1 #218

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions aae/aae.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import MaxPooling2D, merge
from keras.layers import Input, Dense, Reshape, Flatten, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras import losses
from keras.utils import to_categorical
import keras.backend as K

import matplotlib.pyplot as plt

import numpy as np


class AdversarialAutoencoder():
def __init__(self):
self.img_rows = 28
Expand All @@ -26,11 +22,8 @@ def __init__(self):

optimizer = Adam(0.0002, 0.5)

# Build and compile the discriminator
# Build the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])

# Build the encoder / decoder
self.encoder = self.build_encoder()
Expand All @@ -44,17 +37,21 @@ def __init__(self):

# For the adversarial_autoencoder model we will only train the generator
self.discriminator.trainable = False
self.discriminator.compile(
loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])

# The discriminator determines validity of the encoding
validity = self.discriminator(encoded_repr)

# The adversarial_autoencoder model (stacked generator and discriminator)
self.adversarial_autoencoder = Model(img, [reconstructed_img, validity])
self.adversarial_autoencoder.compile(loss=['mse', 'binary_crossentropy'],
self.adversarial_autoencoder.compile(
loss=['mse', 'binary_crossentropy'],
loss_weights=[0.999, 0.001],
optimizer=optimizer)


def build_encoder(self):
# Encoder

Expand All @@ -67,12 +64,15 @@ def build_encoder(self):
h = LeakyReLU(alpha=0.2)(h)
mu = Dense(self.latent_dim)(h)
log_var = Dense(self.latent_dim)(h)
latent_repr = merge([mu, log_var],
mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2),
output_shape=lambda p: p[0])
latent_repr = Lambda(self.latent, output_shape=(self.latent_dim, ))([mu, log_var])

return Model(img, latent_repr)

def latent(self, p):
"""Sample based on `mu` and `log_var`"""
mu, log_var = p
return mu + K.random_normal(K.shape(mu)) * K.exp(log_var / 2)

def build_decoder(self):

model = Sequential()
Expand Down Expand Up @@ -146,7 +146,7 @@ def train(self, epochs, batch_size=128, sample_interval=50):
g_loss = self.adversarial_autoencoder.train_on_batch(imgs, [imgs, valid])

# Plot the progress
print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))
print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss[0], g_loss[1]))

# If at save interval => save generated image samples
if epoch % sample_interval == 0:
Expand All @@ -155,7 +155,7 @@ def train(self, epochs, batch_size=128, sample_interval=50):
def sample_images(self, epoch):
r, c = 5, 5

z = np.random.normal(size=(r*c, self.latent_dim))
z = np.random.normal(size=(r * c, self.latent_dim))
gen_imgs = self.decoder.predict(z)

gen_imgs = 0.5 * gen_imgs + 0.5
Expand All @@ -164,8 +164,8 @@ def sample_images(self, epoch):
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
axs[i, j].imshow(gen_imgs[cnt, : , :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
fig.savefig("images/mnist_%d.png" % epoch)
plt.close()
Expand All @@ -175,8 +175,9 @@ def save_model(self):
def save(model, model_name):
model_path = "saved_model/%s.json" % model_name
weights_path = "saved_model/%s_weights.hdf5" % model_name
options = {"file_arch": model_path,
"file_weight": weights_path}
options = {
"file_arch": model_path,
"file_weight": weights_path}
json_string = model.to_json()
open(options['file_arch'], 'w').write(json_string)
model.save_weights(options['file_weight'])
Expand Down