In [11]:
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K

In [None]:
import keras.backend.tensorflow_backend as KTF
import utils  # imported here so this cell also works on a fresh kernel, before the later `import utils` cell

# Make Keras use the TensorFlow session returned by utils.get_session()
# (presumably custom GPU/memory options — see utils to confirm).
KTF.set_session(utils.get_session())

In [12]:
import utils
import numpy as np
import matplotlib.pyplot as plt

In [13]:
from keras.datasets import cifar10
# load_data() returns ((x_train, y_train), (x_test, y_test)); only the images are
# kept, because the autoencoder's targets are images, not class labels.
x_train, x_test = (images for images, _labels in cifar10.load_data())

In [15]:
tuple(arr.shape for arr in (x_train, x_test))

((50000, 32, 32, 3), (10000, 32, 32, 3))

In [None]:
# Cast the uint8 pixels to float32 and scale from [0, 255] to [0, 1].
# The reshape assumes channels_last (N, 32, 32, 3) data, for which it leaves the
# array unchanged — adapt it if using the `channels_first` image data format.
x_train, x_test = (
    np.reshape(images.astype('float32') / 255., (len(images), 32, 32, 3))
    for images in (x_train, x_test)
)

In [None]:
# Add JPEG compression artifacts. utils.cifar10_jpeg returns a pair that is unpacked
# as (clean, distorted). The clean arrays are rebound too, so the helper may also
# transform them — TODO confirm in utils. NOTE(review): re-running this cell feeds
# the already-returned clean arrays back into the helper.
train_clean_and_noisy = utils.cifar10_jpeg(x_train)
x_train, x_train_noisy = train_clean_and_noisy

test_clean_and_noisy = utils.cifar10_jpeg(x_test)
x_test, x_test_noisy = test_clean_and_noisy

In [None]:
tuple(arr.shape for arr in (x_train, x_train_noisy, x_test, x_test_noisy))

In [None]:
# Display the first n training images (top row) above their JPEG-distorted
# versions (bottom row).
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):  # subplot indices are 1-based, hence i + 1; this fills all n columns
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_train[i].reshape(32, 32, 3))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display distorted (JPEG) version
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(x_train_noisy[i].reshape(32, 32, 3))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

In [None]:
import keras

# Both branches read the same distorted image. With padding='same', each 2x2
# max-pool halves the spatial size: (32, 32, 3) -> (16, 16, 32) -> (8, 8, 32).
dist_input = Input(shape=(32, 32, 3))     # distorted image


def _conv_encoder(inp, filters=32):
    """Apply Conv(3x3, ReLU) + MaxPool(2x2) twice: (32, 32, C) -> (8, 8, filters)."""
    h = Conv2D(filters, (3, 3), activation='relu', padding='same')(inp)
    h = MaxPooling2D((2, 2), padding='same')(h)
    h = Conv2D(filters, (3, 3), activation='relu', padding='same')(h)
    return MaxPooling2D((2, 2), padding='same')(h)


def _conv_decoder(code, filters=32):
    """Apply Conv(3x3, ReLU) + UpSample(2x) twice, then a 3-channel sigmoid conv.

    Maps (8, 8, filters) -> (32, 32, 3) with values in (0, 1).
    """
    h = Conv2D(filters, (3, 3), activation='relu', padding='same')(code)
    h = UpSampling2D((2, 2))(h)
    h = Conv2D(filters, (3, 3), activation='relu', padding='same')(h)
    h = UpSampling2D((2, 2))(h)
    return Conv2D(3, (3, 3), activation='sigmoid', padding='same')(h)


############# Autoencoder for reconstructing the clean image ###############
clean_inp_enc = _conv_encoder(dist_input)   # (8, 8, 32)
clean_recon = _conv_decoder(clean_inp_enc)

############# Autoencoder for reconstructing the distorted image ############
dist_enc = _conv_encoder(dist_input)        # (8, 8, 32), separate weights
# Adding distortion to the encoded clean input: the distorted branch decodes the
# element-wise sum of the clean code and its own code.
dist_inp_enc = keras.layers.add([clean_inp_enc, dist_enc])
dist_recon = _conv_decoder(dist_inp_enc)

# Two outputs, both trained with the same MSE loss.
autoencoder = Model(inputs=[dist_input], outputs=[clean_recon, dist_recon])
autoencoder.compile(optimizer='adadelta', loss='mean_squared_error')

In [None]:
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

# Render the model graph inline as SVG via Graphviz's `dot` program.
model_dot = model_to_dot(autoencoder)
svg_markup = model_dot.create(prog='dot', format='svg')
SVG(svg_markup)

In [None]:
################## Model Flow Diagram ####################
from keras.utils import plot_model
import os

# Save the model graph as a PNG. Create the output directory first, because
# writing into a directory that does not exist fails.
os.makedirs('my_models', exist_ok=True)
plot_model(autoencoder, to_file='my_models/cdA_jpeg_model.png')

In [None]:
from keras.callbacks import TensorBoard
import os
import sys
from contextlib import redirect_stdout

# Read back by the "Testing the model" section via load_model().
MODEL_PATH = 'my_models/cdA_jpeg.h5'

# Train on distorted inputs against two targets, in the model's output order:
# the clean image (clean_recon) and the distorted image itself (dist_recon).
# Keras progress output goes to cdA_jpeg_output.txt. redirect_stdout restores
# the previous sys.stdout (the notebook's output stream) even if fit() raises,
# and the `with` closes the log file. Resetting to sys.__stdout__ instead would
# send all later prints to the kernel process's terminal, not the notebook.
with open('cdA_jpeg_output.txt', 'w') as log_file, redirect_stdout(log_file):
    history = autoencoder.fit([x_train_noisy],
                              [x_train, x_train_noisy],
                              epochs=100,
                              batch_size=128,
                              shuffle=True,
                              validation_data=([x_test_noisy], [x_test, x_test_noisy]),
                              callbacks=[TensorBoard(log_dir='/tmp/autoencoder',
                                                     histogram_freq=0,
                                                     write_graph=False)])

# Persist the trained model so the testing section can reload it.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
autoencoder.save(MODEL_PATH)

In [None]:
# Predict both outputs for the distorted test set, then show the first n test
# images as four rows.
decoded_imgs = autoencoder.predict([x_test_noisy])
clean_recon_imgs, dist_recon_imgs = decoded_imgs  # model output order: [clean_recon, dist_recon]

n = 20
rows = [
    x_test,            # original
    x_test_noisy,      # distorted (model input)
    clean_recon_imgs,  # clean image reconstruction
    dist_recon_imgs,   # distorted image reconstruction
]
plt.figure(figsize=(20, 4))
for row, images in enumerate(rows):
    for i in range(n):  # subplot indices are 1-based; this fills all n columns
        ax = plt.subplot(len(rows), n, row * n + i + 1)
        plt.imshow(images[i].reshape(32, 32, 3))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

In [1]:
################## Testing the model ##################

In [19]:
import utils
import numpy as np
import matplotlib.pyplot as plt

import keras.backend.tensorflow_backend as KTF
from keras.datasets import cifar10

# Fresh-kernel setup for testing: install the TensorFlow session from
# utils.get_session(), then reload CIFAR-10 and scale it to float32 in [0, 1].
KTF.set_session(utils.get_session())


def _scale_to_unit(images):
    """Cast uint8 images to float32 in [0, 1], shaped (N, 32, 32, 3).

    The reshape assumes channels_last data — adapt it if using the
    `channels_first` image data format.
    """
    return np.reshape(images.astype('float32') / 255., (len(images), 32, 32, 3))


(x_train, _), (x_test, _) = cifar10.load_data()
x_train, x_test = _scale_to_unit(x_train), _scale_to_unit(x_test)

In [20]:
x_train.shape, x_test.shape

((50000, 32, 32, 3), (10000, 32, 32, 3))

In [21]:
# Second argument to utils.add_jpeg; presumably the JPEG quality — TODO confirm in utils.
# NOTE(review): training built its distorted set with utils.cifar10_jpeg instead;
# confirm both helpers produce the same kind of distortion.
JPEG_QUALITY = 50
x_test_noisy = utils.add_jpeg(x_test, JPEG_QUALITY)

In [None]:
from keras.models import load_model

# Reload the trained model, predict both outputs for the distorted test set,
# then show the first n test images as four rows.
autoencoder = load_model('my_models/cdA_jpeg.h5')
decoded_imgs = autoencoder.predict([x_test_noisy])
# Output order presumably matches the training definition: [clean_recon, dist_recon].
clean_recon_imgs, dist_recon_imgs = decoded_imgs

n = 20
rows = [
    x_test,            # original
    x_test_noisy,      # distorted (model input)
    clean_recon_imgs,  # clean image reconstruction
    dist_recon_imgs,   # distorted image reconstruction
]
plt.figure(figsize=(20, 4))
for row, images in enumerate(rows):
    for i in range(n):  # subplot indices are 1-based; this fills all n columns
        ax = plt.subplot(len(rows), n, row * n + i + 1)
        plt.imshow(images[i].reshape(32, 32, 3))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()