In [3]:
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.layers import Input,Dense,Conv2D,MaxPooling2D,UpSampling2D,Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential

This class builds two different autoencoder models:
<br>
    1. Flat Model:
    <br>
        -----Consists only of Dense layers
        <br>
    2. Conv Model:
    <br>
        -----Consists of Convolutional layers (with MaxPooling and UpSampling layers)
        <br>

In [8]:
class ENCODERS:
    """Build, train and inspect two small Keras autoencoders.

    Two variants are created at construction time and selected everywhere
    else by a string:

    * ``'flat'`` -- Dense layers only (128 -> 64 -> 64 -> 128 -> output).
    * ``'conv'`` -- Conv2D / MaxPooling2D encoder mirrored by a
      Conv2D / UpSampling2D decoder with a single-filter output layer.

    Both models are compiled with the Adam optimizer, binary cross-entropy
    loss and the ``accuracy`` metric.

    The dunder-style method names (``__train__`` etc.) are kept as-is so that
    existing callers keep working; they are ordinary methods, not Python
    special methods.
    """

    # The only variant selectors understood by the methods below.
    _VARIANTS = ('flat', 'conv')

    def __init__(self, flat_input_shape, conv_input_shape, epochs, activation, batch_size):
        """Create the input tensors, both encoder/decoder graphs and both models.

        Args:
            flat_input_shape: Number of features of a flattened sample, given
                either as an int or as a 1-element tuple/list.
            conv_input_shape: Shape passed to ``Input`` for the conv model.
            epochs: Number of epochs used by ``__train__``.
            activation: Activation of the final (output) layer of both models.
            batch_size: Batch size used by ``__train__``.

        Raises:
            ValueError: If ``flat_input_shape`` is a tuple/list that does not
                have exactly one element.
        """
        self.flat_input_shape_ = flat_input_shape
        self.conv_input_shape_ = conv_input_shape
        # Dense(units=...) needs a plain int while Input(shape=...) wants a
        # tuple, so normalise once and feed each layer the form it expects.
        self.flat_input_ = Input(shape=(self._flat_units(flat_input_shape),))
        self.conv_input_ = Input(shape=conv_input_shape)
        self.epochs_ = epochs
        self.activation_ = activation
        self.batch_size_ = batch_size
        # Encoder sub-models are created lazily by __get_encoded_data__ and
        # reused afterwards instead of being rebuilt on every call.
        self._encoder_models_ = {}
        # Order matters: each decoder is stacked on the matching encoder output.
        self.flat_encoder_ = self.__build_encoder__('flat')
        self.flat_decoder_ = self.__build_decoder__('flat')
        self.conv_encoder_ = self.__build_encoder__('conv')
        self.conv_decoder_ = self.__build_decoder__('conv')
        self.flat_model_ = self.__build_model__('flat')
        self.conv_model_ = self.__build_model__('conv')

    @staticmethod
    def _flat_units(shape):
        """Return the feature count described by ``shape`` (int or 1-element sequence)."""
        if isinstance(shape, (tuple, list)):
            if len(shape) != 1:
                raise ValueError(
                    'flat_input_shape must be an int or a 1-element sequence, got {!r}'.format(shape))
            return int(shape[0])
        return int(shape)

    @classmethod
    def _check_variant(cls, v):
        """Raise ``ValueError`` unless ``v`` is one of the supported variant names."""
        if v not in cls._VARIANTS:
            raise ValueError(
                "unknown model variant {!r}; expected 'flat' or 'conv'".format(v))

    def __build_encoder__(self, v):
        """Return the encoder output tensor for variant ``v`` ('flat' or 'conv').

        Raises:
            ValueError: If ``v`` is not a supported variant.
        """
        self._check_variant(v)
        if v == 'flat':
            enc = Dense(units=128, activation='relu')(self.flat_input_)
            return Dense(units=64, activation='relu')(enc)
        enc = Conv2D(filters=8, kernel_size=(3, 3), activation='relu', padding='same')(self.conv_input_)
        enc = MaxPooling2D((2, 2), padding='same')(enc)
        enc = Conv2D(filters=4, kernel_size=(3, 3), activation='relu', padding='same')(enc)
        return MaxPooling2D((2, 2), padding='same')(enc)

    def __build_decoder__(self, v):
        """Return the decoder output tensor for variant ``v``, stacked on its encoder.

        The final layer uses the activation given to the constructor.

        Raises:
            ValueError: If ``v`` is not a supported variant.
        """
        self._check_variant(v)
        if v == 'flat':
            dec = Dense(units=64, activation='relu')(self.flat_encoder_)
            dec = Dense(units=128, activation='relu')(dec)
            # The output width equals the input feature count.
            return Dense(units=self._flat_units(self.flat_input_shape_),
                         activation=self.activation_)(dec)
        dec = Conv2D(filters=4, kernel_size=(3, 3), activation='relu', padding='same')(self.conv_encoder_)
        dec = UpSampling2D((2, 2))(dec)
        dec = Conv2D(filters=8, kernel_size=(3, 3), activation='relu', padding='same')(dec)
        dec = UpSampling2D((2, 2))(dec)
        return Conv2D(filters=1, kernel_size=(3, 3), activation=self.activation_, padding='same')(dec)

    def __build_model__(self, v):
        """Assemble and compile the full autoencoder ``Model`` for variant ``v``.

        Raises:
            ValueError: If ``v`` is not a supported variant.
        """
        self._check_variant(v)
        if v == 'flat':
            input_, output_ = self.flat_input_, self.flat_decoder_
        else:
            input_, output_ = self.conv_input_, self.conv_decoder_
        model = Model(inputs=input_, outputs=output_)
        model.compile(
            optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
        return model

    def __model_summary__(self, m):
        """Call ``summary()`` on the model selected by ``m`` and return its result.

        Raises:
            ValueError: If ``m`` is not a supported variant.
        """
        self._check_variant(m)
        model = self.flat_model_ if m == 'flat' else self.conv_model_
        return model.summary()

    def __train__(self, X, y, m):
        """Fit the model selected by ``m`` on ``X`` with targets ``y``.

        Uses the epochs and batch size given to the constructor, with
        shuffling enabled and verbose progress output.

        Raises:
            ValueError: If ``m`` is not a supported variant (raised before
                anything is printed, so an invalid selector can no longer
                report "Done Training" without having trained).
        """
        self._check_variant(m)
        model = self.flat_model_ if m == 'flat' else self.conv_model_
        print('Training...')
        model.fit(X, y, epochs=self.epochs_, batch_size=self.batch_size_, shuffle=True, verbose=1)
        print('Done Training')

    def __get_encoded_data__(self, X, m):
        """Return the encoder output (``predict`` result) of variant ``m`` for ``X``.

        The encoder-only ``Model`` shares its layers with the full
        autoencoder; it is built on first use and cached per variant.

        Raises:
            ValueError: If ``m`` is not a supported variant.
        """
        self._check_variant(m)
        encoder = self._encoder_models_.get(m)
        if encoder is None:
            if m == 'flat':
                encoder = Model(inputs=self.flat_input_, outputs=self.flat_encoder_)
            else:
                encoder = Model(inputs=self.conv_input_, outputs=self.conv_encoder_)
            self._encoder_models_[m] = encoder
        return encoder.predict(X)

    def __get_decoded_data__(self, X, m):
        """Return the full autoencoder's ``predict`` output of variant ``m`` for ``X``.

        Raises:
            ValueError: If ``m`` is not a supported variant.
        """
        self._check_variant(m)
        model = self.flat_model_ if m == 'flat' else self.conv_model_
        return model.predict(X)

    def __compare_original_decoded__(self, original, decoded, n, image_shape=(28, 28)):
        """Plot ``n`` randomly chosen originals (top row) above their decoded versions.

        Args:
            original: Indexable collection of original samples; each sample
                must support ``.reshape``.
            decoded: Decoded samples, indexed the same way as ``original``.
            n: Number of sample pairs to draw. Indices are drawn with
                ``np.random.choice`` (with replacement, so duplicates can occur).
            image_shape: Shape each sample is reshaped to before display.
                Defaults to ``(28, 28)``.
        """
        random_index = np.random.choice(len(original), n)
        plt.figure(figsize=(10, 4), dpi=100)
        for col, idx in enumerate(random_index):
            # Top row: original image.
            ax = plt.subplot(2, n, col + 1)
            plt.imshow(original[idx].reshape(image_shape))
            ax.set_axis_off()

            # Bottom row: reconstruction of the same sample.
            ax = plt.subplot(2, n, col + n + 1)
            plt.imshow(decoded[idx].reshape(image_shape))
            ax.set_axis_off()
        plt.show()