## autoencoder

# In [None]:
from __future__ import print_function
from numpy import random
import numpy as np
from matplotlib.lines import Line2D  
from PIL import Image
import argparse

# Seed NumPy's global RNG (`random` here is `numpy.random`, imported above).
# NOTE(review): no TensorFlow/Keras seed is set in the visible code, so weight
# initialisation and Keras-side shuffling may still differ between runs - confirm.
random.seed(42)  # @UndefinedVariable

from tensorflow.keras.datasets import mnist , fashion_mnist, cifar100
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam

"""Autoencoder Simples Model
    https://elix-tech.github.io/ja/2016/07/17/autoencoder.html
    Reference paper: https://arxiv.org/pdf/1812.11262.pdf
      ["We propose an autoencoder-based residual deep network for robust prediction"]
"""

# Dataset selector. Valid values: mnist, fashion_mnist, cifar100
# ("fashion" is kept as an alias so existing settings keep working).
CASE = "mnist"
OBJ = {
    "mnist": mnist,
    "fashion": fashion_mnist,
    "fashion_mnist": fashion_mnist,
    "cifar100": cifar100,
}

# Load the selected dataset; labels are discarded because the autoencoder is
# trained to reproduce its own input.
(x_train, _), (x_test, _) = OBJ[CASE].load_data()

# Scale pixel values from [0, 255] to [0, 1] as float32.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

# Activations used by the encoder and decoder Dense layers respectively.
ENC_ACT, DEC_ACT = 'relu', 'sigmoid'

# Per-image geometry, taken from axes 1 and 2 of the loaded array.
WIDTH, HEIGHT = x_train.shape[1], x_train.shape[2]
# Grayscale sets are expected to load as 3-D arrays (no channel axis) and
# colour sets as 4-D arrays with the channel count on the last axis; deriving
# CL from the data keeps the flattened size correct for both.
CL = x_train.shape[3] if x_train.ndim == 4 else 1

encoding_dim, decoding_dim, EPOCH, BATCH_SIZE = 32, WIDTH * HEIGHT * CL, 60, 64

# Flatten every image to a single vector of decoding_dim values.
x_train = np.reshape(x_train, [-1, decoding_dim])
x_test = np.reshape(x_test, [-1, decoding_dim])

print("train.shape = {}, test.shape= {}".format(x_train.shape, x_test.shape))

    
# Build the fully connected autoencoder graph.
# Input: one flattened image per sample.
input_img = Input(shape=(x_train.shape[1],), name="autoencoder_input")

#encoded      = Dropout(0.5)(input_img)

# Encoder: five Dense(encoding_dim) layers; the first four are each followed
# by a named Flatten layer.
# NOTE(review): the tensors here are already (batch, features), so the Flatten
# layers look like shape no-ops; they are kept so the graph and layer names
# stay exactly as before - confirm before removing.
encoded = input_img
for _stage in range(1, 5):
    encoded = Dense(encoding_dim, activation=ENC_ACT)(encoded)
    encoded = Flatten(name='flatten_e{}'.format(_stage))(encoded)
encoded = Dense(encoding_dim, activation=ENC_ACT)(encoded)

# Decoder: mirrors the encoder with five Dense(decoding_dim) layers, so the
# final output has the same width as the flattened input.
decoded = encoded
for _stage in range(1, 5):
    decoded = Dense(decoding_dim, activation=DEC_ACT)(decoded)
    decoded = Flatten(name='flatten_d{}'.format(_stage))(decoded)
decoded = Dense(decoding_dim, activation=DEC_ACT)(decoded)

# End-to-end model: flattened image in, reconstructed flattened image out.
autoencoder = Model(input_img, decoded)

# Optimizer
# Adam arguments:
#   learning_rate: float >= 0. Learning rate (older Keras releases named this `lr`).
#   beta_1: float, 0 < beta < 1. Generally close to 1.
#   beta_2: float, 0 < beta < 1. Generally close to 1.
#   epsilon: float >= 0. Fuzz factor. If None, defaults to K.epsilon().
#   decay: float >= 0. Learning-rate decay applied at each update.
#   amsgrad: whether to apply AMSGrad, the Adam variant from the paper
#            "On the Convergence of Adam and Beyond".

# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# versions no longer accept.
opt = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999)

# Binary cross-entropy is computed between the sigmoid outputs and the
# [0, 1]-scaled input pixels.
autoencoder.compile(optimizer=opt, loss='binary_crossentropy')

# Reference for ResourceExhaustedError:
# https://testpy.hatenablog.com/entry/2017/05/07/122323
# The input doubles as the target; the test split is used for validation.
hist = autoencoder.fit(
    x_train,
    x_train,
    epochs=EPOCH,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test, x_test),
)

def results_draw(x_test, decode_imgs, d_size, n=10, out_path="autoencoder_results.png"):
    """Save a two-row image grid comparing inputs with their reconstructions.

    The top row shows the first ``n`` samples of ``x_test``; the bottom row
    shows the samples at the same indices of ``decode_imgs``. The figure is
    written to ``out_path``.

    Args:
        x_test: Indexable collection of arrays, each reshapeable to ``d_size``.
        decode_imgs: Reconstructions, indexed the same way as ``x_test``.
        d_size: Sequence ``(d0, d1, channels)`` giving the per-image shape.
        n: Maximum number of image pairs to draw (default 10).
        out_path: Destination file passed to ``plt.savefig``.
    """
    import matplotlib.pyplot as plt

    # Never index past the end of either collection.
    count = min(n, len(x_test), len(decode_imgs))

    # Single-channel images are drawn as 2-D arrays: not every matplotlib
    # release accepts a trailing channel axis of size 1 in imshow.
    if d_size[2] == 1:
        shape = (d_size[0], d_size[1])
    else:
        shape = (d_size[0], d_size[1], d_size[2])

    plt.figure(figsize=(20, 4))
    for i in range(count):
        # Row 0 holds the original image, row 1 its reconstruction. The
        # reconstruction comes from the `decode_imgs` argument, not a global.
        for row, images in enumerate((x_test, decode_imgs)):
            ax = plt.subplot(2, count, row * count + i + 1)
            # cmap is set per image instead of calling plt.gray(), which
            # would also change the global default colormap.
            ax.imshow(images[i].reshape(shape), cmap="gray")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    plt.savefig(out_path)

# Reconstruct the test images with the fitted model and save the
# original-versus-reconstruction comparison figure.
image_shape = (WIDTH, HEIGHT, CL)
decoded_imgs = autoencoder.predict(x_test)
results_draw(x_test, decoded_imgs, image_shape)

    