In [1]:
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST (downloading into the directory on first use); labels are
# returned one-hot encoded to match the softmax cross-entropy loss below.
mnist = input_data.read_data_sets(train_dir="/tmp/mnist/", one_hot=True)

Extracting /tmp/mnist/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/t10k-labels-idx1-ubyte.gz


In [2]:
import os
import tensorflow as tf
from tqdm import tqdm

from models import modelA, modelB, modelC, modelD

class Trainer(object):
    """Build, train, checkpoint and run one MNIST CNN variant.

    The network is built in the current default graph under the variable
    scope ``naive_{type}``; a ``tf.train.Saver`` restricted to that scope's
    trainable variables writes/reads ``./checkpoints/mnist_cnn_weight_type{type}.ckpt``.
    Several trainers can therefore share one graph and one session.

    Args:
        sess: ``tf.Session`` used for initialization, training and inference.
        mnist: dataset object as returned by ``input_data.read_data_sets``;
            ``.train.next_batch``, ``.train.num_examples``, ``.test.images``
            and ``.test.labels`` are used.
        type: index into ``[modelA, modelB, modelC, modelD]`` selecting the
            architecture; also used in the scope and checkpoint file names.
    """

    def __init__(self, sess, mnist, type=0):
        self.mnist = mnist
        self.sess = sess

        # Training hyper-parameters
        self.learning_rate = 0.001
        self.total_epoch = 5
        self.batch_size = 128

        # Network parameters: flattened 28x28 single-channel images, 10 classes
        self.n_input = 784
        self.n_classes = 10
        self.n_size = 28
        self.n_channel = 1
        self.dropout = 0.75  # NOTE(review): not read anywhere in this class
        self.scope = 'naive_{}'.format(type)

        self.checkpoint_dir = './checkpoints'
        self.save_file_name = 'mnist_cnn_weight_type{}.ckpt'.format(type)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        models = [modelA, modelB, modelC, modelD]
        self.conv_net = models[type]

        # Snapshot the graph's global variables before building so train() can
        # initialize only the variables this trainer creates (network weights
        # plus Adam slot/power variables). Names are unique within a graph.
        existing_var_names = {v.name for v in tf.global_variables()}
        self.build()
        self.model_vars = [v for v in tf.global_variables()
                           if v.name not in existing_var_names]

        # Only trainable variables under this model's scope are checkpointed.
        # NOTE(review): non-trainable state the model may create (e.g. batch-norm
        # moving statistics, if any) is not saved - confirm against models.py.
        self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope)
        self.saver = tf.train.Saver(var_list=self.vars)

        print('setting done!')

    def restore(self):
        """Load this model's trainable variables from its checkpoint file."""
        self.saver.restore(self.sess, os.path.join(self.checkpoint_dir, self.save_file_name))

    def build(self):
        """Create placeholders, the network, loss, Adam train op and accuracy op."""
        # Inputs: flattened images and one-hot labels
        self.X = tf.placeholder(tf.float32, [None, self.n_input], name='cnn_X')
        self.Y = tf.placeholder(tf.float32, [None, self.n_classes], name='cnn_Y')
        self.is_training = tf.placeholder(tf.bool, name='cnn_placeholder')

        self.X_img = tf.reshape(self.X, (-1, self.n_size, self.n_size, self.n_channel))
        # Presumably the model returns unnormalized logits (they are fed to
        # softmax_cross_entropy_with_logits below) - TODO confirm in models.py.
        self.pred = self.conv_net(self.X_img, self.is_training, self.scope)
        self.cost = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.pred, labels=self.Y))

        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)

        correct_pred = tf.equal(tf.argmax(self.pred, 1), tf.argmax(self.Y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        print('build done!')

    def train(self):
        """Train this model from a fresh initialization, then save a checkpoint.

        After every epoch, accuracy on the first 256 test images is printed.
        """
        # Initialize only this trainer's variables; a global initializer would
        # also reset every other model already trained in the same session.
        self.sess.run(tf.variables_initializer(self.model_vars))

        # Use the dataset given to the constructor, not a notebook global.
        total_iter = self.mnist.train.num_examples // self.batch_size

        for _epoch in range(self.total_epoch):
            for _ in tqdm(range(total_iter)):
                batch_x, batch_y = self.mnist.train.next_batch(self.batch_size)
                self.sess.run(self.optimizer,
                              feed_dict={self.X: batch_x, self.Y: batch_y, self.is_training: True})

            print("finished!")

            test_accuracy = self.sess.run(self.accuracy,
                                          feed_dict={self.X: self.mnist.test.images[:256],
                                                     self.Y: self.mnist.test.labels[:256],
                                                     self.is_training: False})
            print("Testing Accuracy:", test_accuracy)

        saved_path = self.saver.save(self.sess, os.path.join(self.checkpoint_dir, self.save_file_name))
        print("Model saved in {}".format(saved_path))

    def test(self, X):
        """Run the network in inference mode on flattened images ``X`` and return ``self.pred``."""
        return self.sess.run(self.pred, feed_dict={self.X: X, self.is_training: False})

In [3]:
# Train and checkpoint each architecture (modelA..modelD) in turn, all in one
# shared session; afterwards `trainer` refers to the last (type 3) model.
N_MODEL_TYPES = 4
sess = tf.Session()
for model_type in range(N_MODEL_TYPES):
    trainer = Trainer(sess, mnist, type=model_type)
    trainer.train()

build done!
setting done!


100%|██████████| 429/429 [00:02<00:00, 144.39it/s]
  0%|          | 1/429 [00:00<00:49,  8.68it/s]

finished!
Testing Accuracy: 0.9921875


100%|██████████| 429/429 [00:01<00:00, 217.17it/s]
  0%|          | 2/429 [00:00<00:21, 19.85it/s]

finished!
Testing Accuracy: 0.99609375


100%|██████████| 429/429 [00:01<00:00, 217.27it/s]
  1%|          | 3/429 [00:00<00:16, 26.30it/s]

finished!
Testing Accuracy: 0.98828125


100%|██████████| 429/429 [00:02<00:00, 212.64it/s]
  1%|          | 3/429 [00:00<00:16, 25.66it/s]

finished!
Testing Accuracy: 0.99609375


100%|██████████| 429/429 [00:02<00:00, 213.32it/s]


finished!
Testing Accuracy: 0.98828125
Model saved in ./checkpoints/mnist_cnn_weight_type0.ckpt


  0%|          | 0/429 [00:00<?, ?it/s]

build done!
setting done!


100%|██████████| 429/429 [00:01<00:00, 231.29it/s]
  1%|          | 5/429 [00:00<00:09, 44.53it/s]

finished!
Testing Accuracy: 0.97265625


100%|██████████| 429/429 [00:01<00:00, 234.20it/s]
  1%|          | 5/429 [00:00<00:09, 44.98it/s]

finished!
Testing Accuracy: 0.97265625


100%|██████████| 429/429 [00:01<00:00, 240.56it/s]
  1%|▏         | 6/429 [00:00<00:07, 57.46it/s]

finished!
Testing Accuracy: 0.97265625


100%|██████████| 429/429 [00:01<00:00, 231.44it/s]
  2%|▏         | 7/429 [00:00<00:07, 52.93it/s]

finished!
Testing Accuracy: 0.984375


100%|██████████| 429/429 [00:01<00:00, 224.98it/s]


finished!
Testing Accuracy: 0.98046875
Model saved in ./checkpoints/mnist_cnn_weight_type1.ckpt


  0%|          | 0/429 [00:00<?, ?it/s]

build done!
setting done!


100%|██████████| 429/429 [00:06<00:00, 70.91it/s]
  0%|          | 0/429 [00:00<?, ?it/s]

finished!
Testing Accuracy: 0.9921875


100%|██████████| 429/429 [00:05<00:00, 72.30it/s]
  2%|▏         | 8/429 [00:00<00:05, 73.45it/s]

finished!
Testing Accuracy: 0.9921875


100%|██████████| 429/429 [00:05<00:00, 72.51it/s]
  2%|▏         | 8/429 [00:00<00:05, 74.49it/s]

finished!
Testing Accuracy: 0.99609375


100%|██████████| 429/429 [00:05<00:00, 72.40it/s]
  2%|▏         | 8/429 [00:00<00:05, 75.01it/s]

finished!
Testing Accuracy: 0.99609375


100%|██████████| 429/429 [00:05<00:00, 72.61it/s]


finished!
Testing Accuracy: 0.99609375
Model saved in ./checkpoints/mnist_cnn_weight_type2.ckpt


  0%|          | 0/429 [00:00<?, ?it/s]

build done!
setting done!


100%|██████████| 429/429 [00:01<00:00, 274.77it/s]
  3%|▎         | 12/429 [00:00<00:04, 84.08it/s]

finished!
Testing Accuracy: 0.97265625


100%|██████████| 429/429 [00:01<00:00, 293.29it/s]
  3%|▎         | 12/429 [00:00<00:04, 95.64it/s]

finished!
Testing Accuracy: 0.97265625


100%|██████████| 429/429 [00:01<00:00, 286.24it/s]
  3%|▎         | 13/429 [00:00<00:04, 94.65it/s]

finished!
Testing Accuracy: 0.9765625


100%|██████████| 429/429 [00:01<00:00, 304.14it/s]
  3%|▎         | 14/429 [00:00<00:03, 113.92it/s]

finished!
Testing Accuracy: 0.98046875


100%|██████████| 429/429 [00:01<00:00, 313.74it/s]


finished!
Testing Accuracy: 0.984375
Model saved in ./checkpoints/mnist_cnn_weight_type3.ckpt
