In [None]:
import tensorflow as tf
import tensorflow.keras
import matplotlib.pyplot as plt
%matplotlib inline
import os
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D, GlobalAveragePooling2D, UpSampling2D, Cropping2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import accuracy_score

In [58]:
class PINE():
  """Jointly trains an MNIST classifier ("main model") and a conv "interpreter".

  The interpreter is an encoder/decoder CNN that maps a 28x28x1 image to
  another 28x28x1 image. The main model is a CNN classifier over 10 classes.
  `train` updates the main model with categorical cross-entropy on the real
  images, and updates the interpreter so that the main model's prediction on
  the interpreter output stays close to its prediction on the real image.
  """
  model_name = "pine_mnist"     # name for checkpoint
  dataset_name = "mnist"

  def __init__(self, batch_size, dataset_name):
    """Store hyper-parameters and build both networks.

    Args:
      batch_size: number of samples per training step used by `train`.
      dataset_name: stored on the instance; `train` always loads MNIST.
    """
    # Input shape
    self.img_rows = 28
    self.img_cols = 28
    self.channels = 1
    self.y_dim = 10
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    self.latent_dim = 100
    self.batch_size = batch_size
    self.learning_rate_main_model = 0.0001
    self.learning_rate_interpreter = 0.0001
    self.checkpoint_dir = 'checkpoint'
    self.dataset_name = dataset_name

    # Build the interpreter (no compile step: training uses a custom loop)
    self.interpreter = self.build_interpreter()

    # Build the main model
    self.main_model = self.build_main_model()

  @staticmethod
  def _one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) float64 one-hot matrix."""
    # np.float was removed in NumPy 1.24; use the explicit float64 dtype.
    identity = np.eye(num_classes, dtype=np.float64)
    return identity[np.asarray(labels, dtype=np.intp)]

  @staticmethod
  def _prepare_images(images):
    """Rescale pixel values from [0, 255] to [-1, 1] and append a channel axis."""
    scaled = np.asarray(images) / 127.5 - 1.
    # float32 matches the dtype the Keras layers compute in.
    return np.expand_dims(scaled, axis=-1).astype(np.float32)

  @staticmethod
  def _categorical_accuracy(y_true, y_pred):
    """Fraction of rows whose argmax agrees between y_true and y_pred."""
    hits = tf.equal(tf.math.argmax(y_true, axis=-1), tf.math.argmax(y_pred, axis=-1))
    return tf.reduce_mean(tf.cast(hits, tf.float32))

  def build_main_model(self):
    """Build the CNN classifier: 28x28x1 image -> 10-way softmax.

    Returns:
      An uncompiled `tf.keras.Model`.
    """
    #    ___________
    #   /           \
    #  / MAIN  MODEL \
    # /_______________\
    # model: https://github.com/yashk2810/MNIST-Keras/blob/master/Notebook/MNIST_keras_CNN-99.55%25.ipynb

    imgs = tf.keras.Input(shape=(28, 28, 1), name="img")
    x = tf.keras.layers.Conv2D(32, 3, activation="relu")(imgs)
    x = BatchNormalization(axis=-1)(x)
    x = tf.keras.layers.Conv2D(32, 3, activation="relu")(x)
    x = tf.keras.layers.MaxPooling2D(2)(x)
    x = BatchNormalization(axis=-1)(x)
    x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)
    x = BatchNormalization(axis=-1)(x)
    x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)

    x = tf.keras.layers.MaxPooling2D(2)(x)
    x = tf.keras.layers.Flatten()(x)

    # Fully connected layer
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dense(512, activation="relu")(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    out = tf.keras.layers.Dense(10, activation="softmax")(x)

    model = tf.keras.Model(inputs = imgs, outputs = out)

    # summary() prints the table itself and returns None, so it is not wrapped in print().
    model.summary()
    return model

  def build_interpreter(self):
    """Build the encoder/decoder CNN: 28x28x1 image -> 28x28x1 sigmoid image.

    Returns:
      An uncompiled `tf.keras.Model`.
    """
    # _________________
    # \               /
    #  \             /
    #   \           /
    #    INTERPRETER
    #   /           \
    #  /             \
    # /_______________\

    # model = https://github.com/nathanhubens/Autoencoders/blob/master/Autoencoders.ipynb

    # Encoder: 28 -> 14 -> 7 -> 4 spatial resolution
    encoder_input = tf.keras.Input(shape=(28, 28, 1), name="img")
    x = tf.keras.layers.Conv2D(16, 3, activation="relu", padding='same')(encoder_input)
    x = tf.keras.layers.MaxPooling2D(2, padding='same')(x)
    x = tf.keras.layers.Conv2D(8, 3, activation="relu", padding='same')(x)
    x = tf.keras.layers.MaxPooling2D(2, padding='same')(x)
    x = tf.keras.layers.Conv2D(8, 3, activation="relu", padding='same')(x)
    x = tf.keras.layers.MaxPooling2D(2, padding='same')(x)

    # Decoder: 4 -> 8 -> 16 -> (valid conv) 14 -> 28
    x = tf.keras.layers.Conv2D(8, 3, activation="relu", padding='same')(x)
    x = tf.keras.layers.UpSampling2D(2)(x)
    x = tf.keras.layers.Conv2D(8, 3, activation="relu", padding='same')(x)
    x = tf.keras.layers.UpSampling2D(2)(x)
    # Deliberately unpadded: shrinks 16x16 to 14x14 so the last upsampling
    # lands exactly on 28x28.
    x = tf.keras.layers.Conv2D(16, 3, activation="relu")(x)
    x = tf.keras.layers.UpSampling2D(2)(x)
    out = tf.keras.layers.Conv2D(1, 3, activation="sigmoid", padding='same')(x)

    model = tf.keras.Model(inputs = encoder_input, outputs = out)

    # summary() prints the table itself and returns None, so it is not wrapped in print().
    model.summary()

    return model

  def train(self, epochs, batch_size=128):
    """Train both networks on MNIST, then evaluate and save after every epoch.

    Each step uses `self.batch_size` samples (set in `__init__`); a trailing
    partial batch is dropped.

    Args:
      epochs: number of passes over the training split.
      batch_size: unused; kept so existing callers keep working. The step
        size is `self.batch_size`.
    """
    #################################################### 
    #                                ________________  #
    #         ___________           \               /  #
    #        /           \           \             /   #
    # Train / MAIN  MODEL \           \           /    #
    #      /_______________\           INTERPRETER     #
    #                                 /           \    #
    #                                /             \   #
    #                               /_______________\  #
    ####################################################

    start_batch_id = 0
    self.epochs = epochs

    # Load the dataset
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # Rescale to [-1, 1] and add the channel axis.
    # NOTE(review): the interpreter ends in a sigmoid, so its output lies in
    # [0, 1] while the real images lie in [-1, 1] -- confirm this is intended.
    X_train = self._prepare_images(X_train)
    X_test = self._prepare_images(X_test)

    y_vec = self._one_hot(y_train, self.y_dim)
    y_vec_test = self._one_hot(y_test, self.y_dim)

    opt_interpreter = tf.keras.optimizers.Adam(self.learning_rate_interpreter)
    opt_main_model = tf.keras.optimizers.Adam(self.learning_rate_main_model)
    cat_cross_ent = tf.keras.losses.CategoricalCrossentropy()

    # Each model gets its own directory; saving both to `checkpoint_dir`
    # would make the second save overwrite the first.
    os.makedirs(self.checkpoint_dir, exist_ok=True)
    main_model_dir = os.path.join(self.checkpoint_dir, "main_model")
    interpreter_dir = os.path.join(self.checkpoint_dir, "interpreter")

    self.num_batches = len(X_train) // self.batch_size
    for epoch in range(self.epochs):
      for idx in range(start_batch_id, self.num_batches):
        print('Batch:',idx,'/',self.num_batches, end=" ")
        batch = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        imgs = X_train[batch]
        self.y = y_vec[batch]

        # Persistent tape: two separate gradients are taken from one forward pass.
        with tf.GradientTape(persistent=True) as tape:
          ints = self.interpreter(imgs, training=True)
          out_img = self.main_model(imgs, training=True)
          # Inference mode, so interpreter outputs do not leak into the main
          # model's BatchNorm moving statistics.
          out_int = self.main_model(ints, training=False)

          # Main model loss
          loss_eval = cat_cross_ent(self.y, out_img)

          # Interpreter loss:
          #   l1 = ||out_img - out_int||_2 / batch_size
          #   l2 = cross-entropy of the main model's prediction on `ints`
          #   l3 = 0.0002 * (sum of sqrt(out_int))^2
          # NOTE(review): sqrt has an unbounded gradient at 0; a softmax
          # output that underflows to exactly 0 would yield NaN gradients.
          int_error = tf.sqrt(2 * tf.nn.l2_loss(out_img - out_int)) / self.batch_size
          l1 = tf.cast(int_error, tf.float32)
          l2 = tf.cast(cat_cross_ent(self.y, out_int), tf.float32)
          l3 = tf.cast(0.0002 * tf.reduce_sum(tf.sqrt(out_int)) ** 2, tf.float32)
          self.interpreter_loss = l1 + l2 + l3

        # Gradients are taken outside the tape context so the tape does not
        # also record the backward ops.
        main_model_grads = tape.gradient(loss_eval, self.main_model.trainable_weights)
        interpreter_grads = tape.gradient(self.interpreter_loss, self.interpreter.trainable_weights)
        del tape  # release resources held by the persistent tape

        # Update the weights of the models.
        opt_interpreter.apply_gradients(zip(interpreter_grads, self.interpreter.trainable_weights))
        opt_main_model.apply_gradients(zip(main_model_grads, self.main_model.trainable_weights))

        print('Main Model Loss:',loss_eval.numpy(), end=" ")
        print('Interpreter Loss:',self.interpreter_loss.numpy())

      # End of epoch: evaluate both paths on the test split, then checkpoint.
      ints_test = self.interpreter(X_test, training=False)
      out_int_test = self.main_model(ints_test, training=False)
      out_img_test = self.main_model(X_test, training=False)

      main_model_acc = self._categorical_accuracy(y_vec_test, out_img_test)
      interpreter_acc = self._categorical_accuracy(y_vec_test, out_int_test)
      print(' Main Model Acc: ', main_model_acc.numpy(), 'Interpreter Acc: ', interpreter_acc.numpy())
      self.main_model.save(main_model_dir)
      self.interpreter.save(interpreter_dir)


In [None]:
if __name__ == '__main__':
    # Build the interpreter and the main model, then run the joint
    # training loop for five passes over the MNIST training split.
    pine = PINE(dataset_name="mnist", batch_size=64)
    pine.train(5)

Model: "model_50"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
img (InputLayer)             [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_275 (Conv2D)          (None, 28, 28, 16)        160       
_________________________________________________________________
max_pooling2d_125 (MaxPoolin (None, 14, 14, 16)        0         
_________________________________________________________________
conv2d_276 (Conv2D)          (None, 14, 14, 8)         1160      
_________________________________________________________________
max_pooling2d_126 (MaxPoolin (None, 7, 7, 8)           0         
_________________________________________________________________
conv2d_277 (Conv2D)          (None, 7, 7, 8)           584       
_________________________________________________________________
max_pooling2d_127 (MaxPoolin (None, 4, 4, 8)           0  