In [58]:
from keras.layers import Input, Dense, Conv2D, MaxPool2D, Conv2DTranspose, Flatten, Reshape, InputLayer
from keras.optimizers import Adam
from keras.models import Sequential, Model
from keras import backend as K
from keras.datasets import cifar10
from keras.callbacks import TensorBoard
from time import clock
from random import randint, seed
from sklearn.model_selection import train_test_split
from keras_utils import ModelSaveCallback, TqdmProgressCallback
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
def AutoEncoder(img_shape, code_size):
    """Build a convolutional encoder/decoder pair for images of shape ``img_shape``.

    Encoder: four stages of Conv2D(3x3, "same", ELU) + MaxPool2D(3x3, "same") with
    32, 64, 128 and 512 filters, then Flatten and a Dense bottleneck of ``code_size``
    ELU units. The pool stride defaults to the pool size (3), so each stage shrinks
    height and width to ceil(x / 3), e.g. 32 -> 11 -> 4 -> 2 -> 1.

    Decoder: Dense -> Reshape to an (H/16, W/16, 128) grid, then four stride-2
    Conv2DTranspose layers that each double height and width, so the output has
    exactly the input shape (H, W, C). The final layer has no activation (linear).

    Args:
        img_shape: (height, width, channels) of a single image. Height and width
            must be positive multiples of 16 so the decoder can upsample back to them.
        code_size: number of units in the bottleneck code vector.

    Returns:
        (encoder, decoder): two ``Sequential`` models; the encoder maps images to
        codes and the decoder maps codes back to images.

    Raises:
        ValueError: if height or width is not a positive multiple of 16.
    """
    height, width, channels = img_shape

    # The decoder upsamples by 2**4 = 16, so it can only reproduce sizes that are
    # multiples of 16. Fail here instead of with a shape mismatch at fit time.
    if min(height, width) < 16 or height % 16 or width % 16:
        raise ValueError(
            "img_shape height and width must be positive multiples of 16, "
            "got {}".format(img_shape))

    encoder = Sequential()
    encoder.add(InputLayer(input_shape=img_shape))

    for filters in (32, 64, 128, 512):
        encoder.add(Conv2D(filters=filters, kernel_size=(3, 3), padding="same", activation="elu"))
        encoder.add(MaxPool2D(pool_size=(3, 3), padding="same"))

    encoder.add(Flatten())
    encoder.add(Dense(units=code_size, activation="elu"))

    ###########################################################################################

    # Spatial size of the grid the decoder starts from (2x2 for 32x32 inputs, i.e.
    # the original hard-coded Dense(512) -> Reshape((2, 2, 128))).
    seed_h, seed_w = height // 16, width // 16

    decoder = Sequential()
    decoder.add(InputLayer(input_shape=(code_size,)))

    decoder.add(Dense(units=seed_h * seed_w * 128, activation="elu"))
    decoder.add(Reshape(target_shape=(seed_h, seed_w, 128)))

    for filters in (128, 64, 32):
        decoder.add(Conv2DTranspose(filters=filters, kernel_size=(3, 3), activation="elu", strides=2, padding="same"))

    # Output layer: one filter per input channel (3 for RGB), linear activation.
    decoder.add(Conv2DTranspose(filters=channels, kernel_size=(3, 3), strides=2, padding="same"))

    return encoder, decoder

In [3]:
# Only the images are kept; labels go to `_` because the autoencoder's target is its own input.
[(x_train, _), (x_test, _)] = cifar10.load_data()

In [59]:
def show_samples(images=None, n=30, ncols=10, figsize=(20, 8)):
    """Plot ``n`` consecutive images, starting at a random offset, in a grid ``ncols`` wide.

    Args:
        images: sequence/array of images to sample from. Defaults to the global
            ``x_train``, looked up at call time.
        n: number of images to show (default 30). It is capped at ``len(images)``.
        ncols: number of grid columns (default 10, giving the original 3x10 layout).
        figsize: figure size in inches. It is applied to this figure only instead of
            overwriting the global ``plt.rcParams``.
    """
    if images is None:
        images = x_train

    n = min(n, len(images))
    if n <= 0:
        return

    # The module-level `random` generator is already seeded from OS entropy, so no
    # re-seeding is needed. (The previous `seed(clock())` also crashes on Python
    # >= 3.8, where time.clock was removed.)
    offset = randint(0, len(images) - n)
    nrows = -(-n // ncols)  # ceil(n / ncols)

    plt.figure(figsize=figsize)
    for i in range(n):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(images[offset + i])

In [64]:
def visualize(img, encoder, decoder):
    """Draw the original image, its bottleneck code, and the decoder's reconstruction.

    Args:
        img: a single image as an array (no batch axis); a batch axis is added
            before calling the models.
        encoder: model whose ``predict`` maps a batch of images to a batch of codes.
        decoder: model whose ``predict`` maps a batch of codes to a batch of images.
    """
    import numpy as np

    code = encoder.predict(img[None])[0]  # img[None] is the same as img[np.newaxis, :]
    reco = decoder.predict(code[None])[0]

    # The decoder output is unbounded float. imshow clips float RGB data to [0, 1],
    # which would saturate a reconstruction of 0-255 pixels. Map it back into the
    # input's value range and dtype so both panels are drawn the same way.
    if img.dtype.kind in "ui":
        limits = np.iinfo(img.dtype)
        reco = np.clip(reco, limits.min, limits.max).astype(img.dtype)
    else:
        reco = np.clip(reco, 0.0, 1.0)

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.title("Original")
    plt.imshow(img)

    plt.subplot(1, 3, 2)
    plt.title("Code")
    # Show the 1-D code as a (size/2, 2) grid; odd sizes fall back to a single row.
    code_rows = code.size // 2 if code.size % 2 == 0 else 1
    plt.imshow(code.reshape(code_rows, -1))

    plt.subplot(1, 3, 3)
    plt.title("Reconstructed")
    plt.imshow(reco)

In [5]:
# Per-image shape: every axis of x_train except the leading sample axis.
IMG_SHAPE = x_train.shape[1:]
encoder, decoder = AutoEncoder(IMG_SHAPE, code_size=32)

In [6]:
# Chain encoder and decoder into one model that maps an image to its reconstruction.
image_input = Input(shape=IMG_SHAPE)
encoding = encoder(image_input)
reconstructed_image = decoder(encoding)

autoenc_model = Model(inputs=image_input, outputs=reconstructed_image)
# NOTE(review): Keras' MSLE clips predictions at ~0 before taking the log, and the
# decoder's last layer is linear, so negative outputs may receive no gradient —
# consider a non-negative output activation or plain MSE if training stalls.
autoenc_model.compile(loss="mean_squared_logarithmic_error", optimizer="adam")

In [7]:
# Train the autoencoder to reproduce its input (targets are the inputs themselves).
# Each run logs to its own TensorBoard directory, named by the wall-clock start time.
# time.clock, used here before, was removed in Python 3.8.
from time import strftime

run_log_dir = "logs/final/{}".format(strftime("%Y%m%d-%H%M%S"))

history = autoenc_model.fit(
    x=x_train,
    y=x_train,
    epochs=30,
    validation_data=(x_test, x_test),  # documented form is a tuple (x_val, y_val)
    callbacks=[
        # histogram_freq > 0 requires validation_data, supplied above
        TensorBoard(log_dir=run_log_dir, histogram_freq=1),
        TqdmProgressCallback(),
    ],
    verbose=0,  # progress is reported by TqdmProgressCallback instead
)

Epoch 1/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 2/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 3/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 4/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 5/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 6/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 7/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 8/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 9/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 10/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 11/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 12/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 13/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 14/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 15/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 16/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 17/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 18/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 19/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 20/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 21/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 22/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 23/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 24/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 25/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 26/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 27/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 28/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 29/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch 30/30


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




<keras.callbacks.History at 0x7f571c0f2898>

In [8]:
# Only weights are persisted; rebuild the same architecture with AutoEncoder(...) before loading them.
for net, weights_path in ((encoder, "autoencoder3.h5"), (decoder, "autodecoder3.h5")):
    net.save_weights(weights_path)

In [53]:
def verify(img, model_no=1):
    """Load a saved encoder/decoder weight pair and visualize how it reconstructs ``img``.

    Weights are read from ``autoencoder<model_no>.h5`` and ``autodecoder<model_no>.h5``
    into the global ``encoder`` and ``decoder``. This replaces whatever weights they
    currently hold, including freshly trained ones. The models' architecture must
    match the one the files were saved from.

    Args:
        img: a single image to encode and reconstruct.
        model_no: which saved run to load (1, 2, 3, ...); defaults to 1.
    """
    encoder.load_weights("autoencoder{}.h5".format(model_no))
    decoder.load_weights("autodecoder{}.h5".format(model_no))

    visualize(img, encoder, decoder)