In [70]:
import tensorflow as tf
import tensorflow.keras
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import BatchNormalization
# from tensorflow.keras.utils import np_utils
from tensorflow.keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D, GlobalAveragePooling2D, UpSampling2D, Cropping2D
# from tensorflow.keras.layers.advanced_activations import LeakyReLU 
from tensorflow.keras.preprocessing.image import ImageDataGenerator


In [78]:
class PINE():

  def __init__(self, batch_size):
    
      # Input shape
      self.img_rows = 28
      self.img_cols = 28
      self.channels = 1
      self.img_shape = (self.img_rows, self.img_cols, self.channels)
      self.latent_dim = 100 
      self.batch_size = batch_size  
 

      optimizer = Adam(0.0002, 0.5)

      # Build and compile the interpreter
      self.interpreter = self.build_interpreter()
      # self.interpreter.compile(loss='binary_crossentropy',
      #     optimizer=optimizer,
      #     metrics=['accuracy'])

      # Build the mian model
      self.main_model = self.build_main_model()

      # The interpreter takes an image as input and generates an interpretation
      # img = self.interpreter(img)

      # For the combined model we will only train the generator
      #self.interpreter.trainable = False

      # The main model takes generated interpretation as input and determines validity
      # valid = self.main_model(img)

      # The combined model  (stacked main model and interpreter)
      # Trains the main model to fool the interpreter
      # self.combined = Model(z, valid)
      # self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

  def build_main_model(self):
    
    #     __________
    #   /           \
    #  / MAIN  MODEL \
    # /_______________\
    # model: https://github.com/yashk2810/MNIST-Keras/blob/master/Notebook/MNIST_keras_CNN-99.55%25.ipynb
    
    
    imgs = keras.Input(shape=(28, 28, 1), name="img")
    x = tf.keras.layers.Conv2D(32, 3, activation="relu")(imgs)
    x = BatchNormalization(axis=-1)(x)
    x = tf.keras.layers.Conv2D(32, 3, activation="relu")(x)
    x = tf.keras.layers.MaxPooling2D(2)(x)
    x = BatchNormalization(axis=-1)(x)
    x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)
    x = BatchNormalization(axis=-1)(x)
    x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)

    x = tf.keras.layers.MaxPooling2D(2)(x)
    x = Faltten()(x)

    # Fully connected layer
    x = BatchNormalization()(x)
    out_logit = tf.keras.layers.Dense(512, activation="relu")(x)
    x = BatchNormalization()(out_logit)
    x = Dropout(0.2)(x)
    out = tf.keras.layers.Dense(10, activation="softmax")(x)

    model = Model(inputs = imgs, outputs = (out,out_logit))
    print(model.summary())
    return model

    # model.add(tf.keras.layers.Conv2D(32, kernel_size=3, input_shape=(28,28,1)))
    # model.add(tf.keras.layers.Activation('relu'))
    # BatchNormalization(axis=-1)
    # model.add(tf.keras.layers.Conv2D(32, (3, 3)))
    # model.add(tf.keras.layers.Activation('relu'))
    # model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))

    # BatchNormalization(axis=-1)
    # model.add(tf.keras.layers.Conv2D(64,(3, 3)))
    # model.add(tf.keras.layers.Activation('relu'))
    # BatchNormalization(axis=-1)
    # model.add(tf.keras.layers.Conv2D(64, (3, 3)))
    # model.add(tf.keras.layers.Activation('relu'))
    # model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))

    # model.add(Flatten())
    # Fully connected layer

    # BatchNormalization()
    # model.add(tf.keras.layers.Dense(512))
    # model.add(tf.keras.layers.Activation('relu'))
    # BatchNormalization()
    # model.add(tf.keras.layers.Dropout(0.2))
    # model.add(tf.keras.layers.Dense(10))
    # model.add(tf.keras.layers.Activation('softmax'))

    # model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
    # out = model.layers[-1].output
    # out_logit = model.layers[-2].output

  def build_interpreter(self):
    # ________________
    # \               /
    #  \             /
    #   \           /
    #    INTERPRETER
    #   /           \
    #  /             \
    # /_______________\

    # model = https://github.com/nathanhubens/Autoencoders/blob/master/Autoencoders.ipynb

    # # Encoder
    encoder_input = tf.keras.Input(shape=(28, 28, 1), name="img")
    x = tf.keras.layers.Conv2D(16, 3, activation="relu", padding='same')(encoder_input)
    x = tf.keras.layers.MaxPooling2D(2, padding='same')(x)
    x = tf.keras.layers.Conv2D(8, 3, activation="relu", padding='same')(x)
    x = tf.keras.layers.MaxPooling2D(2, padding='same')(x)
    x = tf.keras.layers.Conv2D(8, 3, activation="relu", padding='same')(x)
    x = tf.keras.layers.MaxPooling2D(2, padding='same')(x)
    # model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same', input_shape=(28,28,1)))
    # model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2), padding='same'))
    # model.add(tf.keras.layers.Conv2D(8,(3, 3), padding='same'))
    # model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2), padding='same'))
    # model.add(tf.keras.layers.Conv2D(8,(3, 3), padding='same'))
    # model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2), padding='same'))
    # # Decoder
    x = tf.keras.layers.Conv2D(8, 3, activation="relu", padding='same')(x)
    x = tf.keras.layers.UpSampling2D(2)(x)
    x = tf.keras.layers.Conv2D(8, 3, activation="relu", padding='same')(x)
    x = tf.keras.layers.UpSampling2D(2)(x)
    x = tf.keras.layers.Conv2D(16, 3, activation="relu")(x)
    x = tf.keras.layers.UpSampling2D(2)(x)
    out = tf.keras.layers.Conv2D(1, 3, activation="sigmoid", padding='same')(x)
    # model.add(tf.keras.layers.Conv2D(8, (3, 3), activation='relu', padding='same'))  
    # model.add(tf.keras.layers.UpSampling2D((2, 2)))
    # model.add(tf.keras.layers.Cropping2D(cropping = ((1,0), (1,0))))
    # model.add(tf.keras.layers.Conv2D(8, (3, 3), activation='relu', padding='same'))
    # model.add(tf.keras.layers.UpSampling2D((2, 2)))   
    # model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same'))
    # model.add(tf.keras.layers.UpSampling2D((2, 2)))
    # model.add(tf.keras.layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same'))
    recon_error = tf.sqrt(2 * tf.nn.l2_loss(out - encoder_input)) / self.batch_size
    model = Model(inputs = encoder_input, outputs = (out, recon_error))
    print(model.summary())
    # out = model.layers[-1].output
    return model


  
  def train(self, epochs, batch_size=128, save_interval=50):
    # # restore check-point if it exits
    # self.batch_size = batch_size
    # could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    # if could_load:
    #     start_epoch = (int)(checkpoint_counter / self.num_batches)
    #     start_batch_id = checkpoint_counter - start_epoch * self.num_batches
    #     counter = checkpoint_counter
    #     print(" [*] Load SUCCESS")
    # else:
    #     start_epoch = 0
    #     start_batch_id = 0
    #     counter = 1
    #     print(" [!] Load failed...")
    self.epochs = epochs

    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(self.epochs):
        #################################################### 
        #                                ________________  #
        #         ___________           \               /  #
        #        /           \           \             /   #
        # Train / MAIN  MODEL \           \           /    #
        #      /_______________\           INTERPRETER     #
        #                                 /           \    #
        #                                /             \   #
        #                               /_______________\  #
        ####################################################
        

        # Select a random half of images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        # Sample noise and generate a batch of new images
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        # gen_imgs = self.interpreter.predict(noise)

        # Train the interpreter (real classified as ones and generated as zeros)
        ints, int_error = self.interpreter.train_on_batch(imgs)
        out_int, out_logit_int = self.main_model.train_on_batch(ints)
        out_img, out_logit_img = self.main_model.train_on_batch(ints)
        int_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        self.kcc = tf.keras.losses.CategoricalCrossentropy()
        self.main_model_loss = self.kcc(out_img,self.y)
        out_sqrt = tf.keras.backend.sqrt(out_int)
        sumi = tf.keras.backend.sum(out_sqrt)**2
        self.interpreter_loss = int_error + self.kcc(out_int, self.y) +0.0002*(sumi)

        # ---------------------
        #  Train Generator
        # ---------------------

        # Plot the progress
        print ("%d [Intepretor loss: %f, acc.: %.2f%%] [Main Model loss: %f]" % (epoch, int_loss[0], 100*int_loss[1], main_model_loss))

        # If at save interval => save generated image samples
        if epoch % save_interval == 0:
            self.save_imgs(epoch)

        self.save(self.checkpoint_dir, counter)

          # show temporal results
        self.visualize_results(epoch)

    # save model for final step
    self.save(self.checkpoint_dir, counter)        

    def save_imgs(self, epoch):
        r, c = 5, 5
        imgs = np.random.normal(0, 1, (r * c, self.latent_dim))
        # gen_imgs = self.interpreter.predict(imgs)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()

In [79]:
if __name__ == '__main__':
    pine = PINE(batch_size=32)
    pine.train(epochs=5, save_interval=50)

AttributeError: ignored