**합성곱 AE 모델링**

In [1]:
from keras import layers, models

In [3]:
def Conv2D(filters, kernel_size, padding='same', activation='relu'):
    """Create a ``keras.layers.Conv2D`` layer with this notebook's defaults.

    Args:
        filters: number of output channels.
        kernel_size: convolution window size, e.g. ``(3, 3)``.
        padding: padding mode forwarded to Keras (default ``'same'``).
        activation: activation forwarded to Keras (default ``'relu'``).

    Returns:
        An un-applied Keras ``Conv2D`` layer instance.
    """
    conv_options = dict(padding=padding, activation=activation)
    return layers.Conv2D(filters=filters, kernel_size=kernel_size, **conv_options)

In [4]:
class AE(models.Model):
    """Convolutional autoencoder, compiled as soon as it is constructed.

    Encoder: two (3x3 conv + 2x2 max-pool) stages, then a single-filter 7x7
    conv that produces the code ``z``. Decoder: two (3x3 conv + 2x upsample)
    stages, a 4-filter 3x3 conv, and a 1-filter sigmoid conv output. For a
    (28, 28, 1) input the code is 7x7x1 and the output is 28x28x1.

    The model is compiled with Adam and binary cross-entropy. This presumably
    assumes inputs scaled to [0, 1] — confirm for non-MNIST data.

    Args:
        orig_shape: input shape passed to ``layers.Input``, e.g. ``(28, 28, 1)``.
    """

    def __init__(self, orig_shape):
        inputs = layers.Input(shape=orig_shape)

        # Encoder: each stage halves the spatial size ('same' pooling rounds up).
        h = inputs
        for n_filters in (4, 8):
            h = Conv2D(n_filters, (3, 3))(h)
            h = layers.MaxPooling2D((2, 2), padding='same')(h)

        # Bottleneck: one-channel code map (7x7 for 28x28 inputs).
        code = Conv2D(1, (7, 7))(h)

        # Decoder: each stage doubles the spatial size back.
        h = code
        for n_filters in (16, 8):
            h = Conv2D(n_filters, (3, 3))(h)
            h = layers.UpSampling2D((2, 2))(h)
        h = Conv2D(4, (3, 3))(h)

        # Single-channel reconstruction squashed into (0, 1).
        reconstruction = Conv2D(1, (3, 3), activation='sigmoid')(h)

        super().__init__(inputs, reconstruction)
        self.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

        # Keep the graph endpoints so Encoder() can build a sub-model later.
        self.original = inputs
        self.z = code

    def Encoder(self):
        """Return a new Keras model mapping the input to the code ``z``."""
        return models.Model(inputs=self.original, outputs=self.z)

In [5]:
from keras.datasets import mnist

In [7]:
# Labels are discarded: the autoencoder's target is its own input.
(X_train, _), (X_test, _) = mnist.load_data()

# Scale 8-bit pixel values into [0, 1] and append a single channel axis
# so each sample is shaped (28, 28, 1) for the Conv2D input.
X_train = (X_train.astype('float32') / 255.).reshape((len(X_train), 28, 28, 1))
X_test = (X_test.astype('float32') / 255.).reshape((len(X_test), 28, 28, 1))

print(X_train.shape)
print(X_test.shape)

(60000, 28, 28, 1)
(10000, 28, 28, 1)


In [8]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [9]:
import matplotlib.pyplot as plt

In [11]:
def plot_loss(history):
    """Plot training and validation loss per epoch.

    Args:
        history: the object returned by Keras ``Model.fit``. Its ``history``
            dict must contain ``'loss'`` and ``'val_loss'`` (the latter only
            exists when ``fit`` was given validation data).
    """
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    # loc=0 lets matplotlib pick the 'best' legend position.
    plt.legend(['Train', 'Test'], loc=0)

def plot_acc(history):
    """Plot training and validation accuracy per epoch.

    Args:
        history: the object returned by Keras ``Model.fit``. Its ``history``
            dict must contain ``'accuracy'`` and ``'val_accuracy'`` (present
            when the model was compiled with ``metrics=['accuracy']`` and
            ``fit`` was given validation data).
    """
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    # loc=0 lets matplotlib pick the 'best' legend position.
    plt.legend(['Train', 'Test'], loc=0)

def show_ae(autoencoder, images=None, n=10):
    """Show originals, latent codes and reconstructions side by side.

    Draws a 3-row grid with one column per sample:
    - row 1: the input image,
    - row 2: the encoder output ``z``,
    - row 3: the autoencoder's reconstruction.

    Args:
        autoencoder: trained ``AE`` instance. It must provide ``Encoder()``
            and ``predict``.
        images: batch of inputs shaped like the model input. Defaults to the
            notebook-global ``X_test``.
        n: maximum number of samples (columns) to display.
    """
    if images is None:
        images = X_test
    samples = images[:n]
    n = len(samples)  # may be smaller than requested
    if n == 0:
        return

    encoder = autoencoder.Encoder()
    # Predict only the samples that are displayed, not the whole dataset.
    encoded_imgs = encoder.predict(samples)
    decoded_imgs = autoencoder.predict(samples)

    rows = (samples, encoded_imgs, decoded_imgs)
    for r, row_imgs in enumerate(rows):
        for i, img in enumerate(row_imgs):
            ax = plt.subplot(len(rows), n, r * n + i + 1)
            # imshow documents (H, W) for single-channel data, so drop the
            # size-1 channel axis. Pass the gray colormap explicitly instead
            # of changing the global default with plt.gray().
            plt.imshow(img.squeeze(), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    plt.show()

In [13]:
def main(epochs=50, batch_size=128):
    """Train the convolutional autoencoder on MNIST and visualize the results.

    Fits ``AE`` to reconstruct ``X_train``, validating on ``X_test``. Then it
    plots the loss and accuracy curves and a grid of originals, codes and
    reconstructions.

    Args:
        epochs: number of training epochs.
        batch_size: mini-batch size passed to ``fit``.
    """
    autoencoder = AE(orig_shape=(28, 28, 1))

    # An autoencoder is trained to reproduce its own input.
    history = autoencoder.fit(
        X_train, X_train,
        epochs=epochs, batch_size=batch_size,
        validation_data=(X_test, X_test),
    )

    plot_loss(history)
    plt.show()

    plot_acc(history)
    plt.show()

    # show_ae needs the trained model and calls plt.show() itself.
    show_ae(autoencoder)

In [14]:
# Run the full pipeline defined in main(): train the autoencoder, then plot the results.
main()

Epoch 1/50
 78/469 [===>..........................] - ETA: 29s - loss: 0.4391 - accuracy: 0.8057

KeyboardInterrupt: 