# Model Definitions

In [None]:
from abc import ABC, abstractmethod
import os
import shutil
import tensorflow as tf

# Base Class

In [None]:
class ImageClassifier(ABC):
    """Abstract base class for TensorFlow (graph-mode) image classifiers.

    The class lazily owns one ``tf.Graph`` and one ``tf.Session`` and offers
    training, prediction, checkpointing and SavedModel export on top of them.

    Subclasses implement :meth:`build_graph`, which must set these tensor
    attributes on the instance because the methods below read them:

    * ``self.x`` -- input placeholder fed with image batches.
    * ``self.y_hat`` -- logits, passed to the softmax cross-entropy loss.
    * ``self.y_prob`` -- class probabilities, arg-maxed to compute accuracy.
    * ``self.y_argmax_prob`` -- predicted class index, returned by
      :meth:`predict` and exported as the ``"labels"`` output.
    """

    # Number of columns of the image input.
    IMAGE_NCOLS = 784
    # Number of columns of the label placeholder created in train().
    LABELS_NCOLS = 10

    def __init__(self):
        # Both are created on first use by get_graph() / get_session().
        self.graph = None
        self.session = None

    def close(self):
        """Close the session (if any) and drop the graph reference."""
        if self.session is not None:
            self.session.close()
            self.session = None
        self.graph = None

    @abstractmethod
    def build_graph(self):
        """Create the model ops; called with the model graph as default."""

    def get_graph(self):
        """Return the model graph, building it on first access."""
        if self.graph is None:
            self.graph = tf.Graph()
            with self.graph.as_default():
                self.build_graph()
        return self.graph

    def get_session(self):
        """Return the session, creating it and initializing variables once."""
        if self.session is None:
            graph = self.get_graph()
            # The initializer op has to be created inside the model graph.
            with graph.as_default():
                init = tf.global_variables_initializer()
            self.session = tf.Session(graph=graph)
            self.session.run(init)
        return self.session

    def train(self, train_set, test_set, learning_rate=0.01, num_epochs=50, batch_size=20):
        """Train with mini-batch gradient descent and report accuracy.

        Prints the average loss / accuracy after every epoch and the accuracy
        on ``test_set`` once training has finished.

        Note that every call adds a fresh set of training ops (label
        placeholder, loss, optimizer step, accuracy) to the graph.

        Args:
            train_set: dataset exposing ``num_examples`` and
                ``next_batch(batch_size)`` returning ``(images, labels)``.
            test_set: dataset exposing ``images`` and ``labels``.
            learning_rate: step size handed to the gradient descent optimizer.
            num_epochs: number of passes over ``train_set``.
            batch_size: number of examples requested per training step.
        """
        session = self.get_session()

        # add training scaffolding
        with session.graph.as_default():
            with tf.name_scope("training"):
                y = tf.placeholder(tf.float32, [None, ImageClassifier.LABELS_NCOLS])

                # cross entropy loss function
                cross_entropy_loss = tf.reduce_mean(
                    tf.nn.softmax_cross_entropy_with_logits_v2(
                        logits=self.y_hat,
                        labels=y
                    )
                )

                train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy_loss)

                # model evaluation
                actual_class, predicted_class = tf.argmax(y, 1), tf.argmax(self.y_prob, 1)
                correct_prediction = tf.cast(tf.equal(predicted_class, actual_class), tf.float32)
                classification_accuracy = tf.reduce_mean(correct_prediction)

        # run training cycle
        for epoch in range(num_epochs):
            avg_cost = 0.
            avg_accuracy = 0.
            # only whole batches are counted; a trailing partial batch is skipped
            total_batch = train_set.num_examples // batch_size

            # loop over all batches
            for _ in range(total_batch):
                # draw from the dataset passed in, not from a module-level global
                batch_x, batch_y = train_set.next_batch(batch_size)

                # run optimization op (backprop), cost op and accuracy op (to get training losses)
                _, c, a = session.run(
                    [train_step, cross_entropy_loss, classification_accuracy],
                    feed_dict={
                        self.x: batch_x,
                        y: batch_y
                    }
                )

                # accumulate the per-epoch averages of training loss and accuracy
                avg_cost += c / total_batch
                avg_accuracy += a / total_batch

            # display info per epoch step
            print(
                "Epoch {}: cross-entropy-loss = {:.4f}, training-accuracy = {:.3f}%".format(
                    epoch + 1,
                    avg_cost,
                    avg_accuracy * 100
                )
            )

        print("Training Completed!")

        # calculate accuracy on the test set that was passed in
        test_accuracy = session.run(
            classification_accuracy,
            feed_dict={
                self.x: test_set.images,
                y: test_set.labels
            }
        )

        print("Accuracy on test set = {:.3f}%".format(test_accuracy * 100))

    def predict(self, input):
        """Return the value of ``self.y_argmax_prob`` for ``input``.

        Args:
            input: value fed to the ``self.x`` placeholder. The name shadows
                the builtin but is kept for keyword-argument compatibility.
        """
        session = self.get_session()
        return session.run(
            self.y_argmax_prob,
            feed_dict={self.x: input}
        )

    def restore(self, path):
        """Load variable values from the checkpoint at ``path``."""
        session = self.get_session()
        # The Saver must be constructed inside the graph that owns the variables.
        with session.graph.as_default():
            saver = tf.train.Saver()
        saver.restore(session, path)

    def save(self, path):
        """Write the current variable values to a checkpoint at ``path``."""
        session = self.get_session()
        # The Saver must be constructed inside the graph that owns the variables.
        with session.graph.as_default():
            saver = tf.train.Saver()
        saver.save(session, path)

    def export_for_model_server(self, path, overwrite=True):
        """Export the model to ``path`` with ``tf.saved_model.simple_save``.

        The export maps input ``"images"`` to ``self.x`` and output
        ``"labels"`` to ``self.y_argmax_prob``.

        Args:
            path: export directory.
            overwrite: when true, an existing directory at ``path`` is removed
                first; when false it is left in place and the export is
                attempted as-is.
        """
        # isdir() is already false for a missing path, so no exists() check is needed.
        if overwrite and os.path.isdir(path):
            shutil.rmtree(path)
        session = self.get_session()
        with session.graph.as_default():
            tf.saved_model.simple_save(
                session,
                path,
                {"images": self.x},
                {"labels": self.y_argmax_prob}
            )

# Linear Image Classifier Class

In [None]:
class LinearImageClassifier(ImageClassifier):
    """Classifier made of one affine layer followed by a softmax."""

    def build_graph(self):
        """Create the input placeholder, the affine layer and the output ops."""
        n_inputs = ImageClassifier.IMAGE_NCOLS
        n_classes = ImageClassifier.LABELS_NCOLS

        with tf.name_scope("model"):
            self.x = tf.placeholder(tf.float32, [None, n_inputs])

            # weights start from random normal values, biases from ones
            weights = tf.Variable(tf.random_normal([n_inputs, n_classes]))
            bias = tf.Variable(tf.ones([n_classes]))

            logits = tf.add(tf.matmul(self.x, weights), bias)
            self.y_hat = logits
            self.y_prob = tf.nn.softmax(logits)
            self.y_argmax_prob = tf.argmax(self.y_prob, 1)

# Neural Network Image Classifier Class

In [None]:
class NeuralNetworkImageClassifier(ImageClassifier):
    """Classifier with one ReLU hidden layer followed by a softmax output."""

    def build_graph(self):
        """Create the input placeholder, both layers and the output ops."""
        n_inputs = ImageClassifier.IMAGE_NCOLS
        n_classes = ImageClassifier.LABELS_NCOLS
        # the hidden layer has as many units as there are input columns
        n_hidden = n_inputs

        def affine(inputs, fan_in, fan_out):
            # weights start from random normal values, biases from ones
            weights = tf.Variable(tf.random_normal([fan_in, fan_out]))
            bias = tf.Variable(tf.ones([fan_out]))
            return tf.add(tf.matmul(inputs, weights), bias)

        with tf.name_scope("model"):
            self.x = tf.placeholder(tf.float32, [None, n_inputs])
            hidden = tf.nn.relu(affine(self.x, n_inputs, n_hidden))
            self.y_hat = affine(hidden, n_hidden, n_classes)
            self.y_prob = tf.nn.softmax(self.y_hat)
            self.y_argmax_prob = tf.argmax(self.y_prob, 1)