In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from six.moves import cPickle as pickle
from keras.layers import Dense, Input, Conv2D, LSTM, MaxPool2D, UpSampling2D
from sklearn.model_selection import train_test_split
from keras.callbacks import EarlyStopping
from keras.utils import to_categorical
from numpy import argmax, array_equal
from keras.models import Model
from imgaug import augmenters
from random import randint

In [None]:
# Load the pickled dataset dictionary, then report the array shape of every
# split it contains (images and labels for train / validation / test).
# NOTE(review): unpickling can execute arbitrary code -- only load this file
# if it comes from a trusted source.
pkl_files = 'full_data.pickle'
with open(pkl_files, 'rb') as fh:
    data = pickle.load(fh)

for split in ('train', 'valid', 'test'):
    for part in ('dataset', 'labels'):
        print(data[split + '_' + part].shape)

In [None]:
# Flatten every sample of each split into a 784-element row vector, which is
# the input width of the dense network built below (784 = 28 * 28).
train_x, val_x, test_x = (
    data[key].reshape(-1, 784)
    for key in ('train_dataset', 'valid_dataset', 'test_dataset')
)

# Sanity check: each printed shape should read (n_samples, 784).
for flattened in (train_x, val_x, test_x):
    print(flattened.shape)

In [None]:
## Fully connected autoencoder:
## 784 -> 1500 -> 1000 -> 500 -> [10] -> 500 -> 1000 -> 1500 -> 784

## input: one flattened 784-element sample
input_layer = Input(shape=(784,))

## encoder: three ReLU layers that progressively narrow the representation
encode_layer1 = Dense(units=1500, activation='relu')(input_layer)
encode_layer2 = Dense(units=1000, activation='relu')(encode_layer1)
encode_layer3 = Dense(units=500, activation='relu')(encode_layer2)

## bottleneck: 10 sigmoid units, so every latent value lies in (0, 1)
latent_view = Dense(units=10, activation='sigmoid')(encode_layer3)

## decoder: mirror image of the encoder
decode_layer1 = Dense(units=500, activation='relu')(latent_view)
decode_layer2 = Dense(units=1000, activation='relu')(decode_layer1)
decode_layer3 = Dense(units=1500, activation='relu')(decode_layer2)

## output: 784 units with no activation (linear), matching the input width
output_layer = Dense(units=784)(decode_layer3)

model = Model(inputs=input_layer, outputs=output_layer)

In [None]:
# Print a text summary of the assembled Keras model as a quick sanity check.
model.summary()

In [None]:
# Train the network to reproduce its own input: mean squared error between
# the input vector and the reconstruction, optimised with Adam.
model.compile(loss='mse', optimizer='adam')

# Halt training once the validation loss has gone `patience` epochs without
# improving (min_delta=0 means any decrease counts as an improvement).
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='auto',
    patience=10,
    min_delta=0,
    verbose=1,
)

# Inputs double as targets, for both the training and the validation data.
model.fit(
    x=train_x,
    y=train_x,
    batch_size=2048,
    epochs=20,
    validation_data=(val_x, val_x),
    callbacks=[early_stopping],
)

In [None]:
# Run the trained model on the flattened test samples; because the model was
# fit with its input as the target, `preds` holds reconstructions of `test_x`.
preds = model.predict(test_x)

In [None]:
# Fold the flat 784-element vectors (inputs and reconstructions alike) back
# into 28x28 grids so they can be drawn as images.
test_x_pics, preds_pics = (
    flat.reshape(-1, 28, 28) for flat in (test_x, preds)
)

In [None]:
# Show a handful of randomly chosen test images (top row) above the
# autoencoder's reconstruction of each one (bottom row).

# Never request more samples than the test set holds; sampling without
# replacement would otherwise raise.
n_show = min(5, test_x_pics.shape[0])

# A dedicated, seeded generator makes the selection identical on every
# Restart & Run All without touching NumPy's global random state.
sample_rng = np.random.RandomState(42)
indices = sample_rng.choice(test_x_pics.shape[0], n_show, replace=False)

test_gt = [test_x_pics[idx] for idx in indices]
test_pred = [preds_pics[idx] for idx in indices]

# squeeze=False keeps `ax` two-dimensional even if only one column is drawn.
fig, ax = plt.subplots(2, n_show, figsize=(5, 5), squeeze=False)
for row, images in enumerate((test_gt, test_pred)):
    for col, image in enumerate(images):
        panel = ax[row, col]
        panel.imshow(image, cmap='gray')
        # Pixel coordinates carry no information here: hide ticks and labels.
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)

fig.suptitle('Top: test inputs, bottom: reconstructions')
plt.show()