In [2]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from time import time
import os
print(tf.__version__)

2.3.1


In [4]:
def load(model, checkpoint_dir):
    """Restore the most recent checkpoint in ``checkpoint_dir`` into ``model``.

    Args:
        model: Object to restore. It is attached to a ``tf.train.Checkpoint``
            under the ``dnn`` key before restoring.
        checkpoint_dir: Directory that holds the checkpoint state and files.

    Returns:
        ``(True, counter)`` when a checkpoint was restored, where ``counter`` is
        the step number parsed from the checkpoint file name (``0`` if the name
        carries no numeric step); ``(False, 0)`` when no checkpoint is found.
    """
    print(" [*] Reading checkpoints...")

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    # Also treat a state object with an empty checkpoint path as "not found".
    if not (ckpt and ckpt.model_checkpoint_path):
        print(" [*] Failed to find a checkpoint")
        return False, 0

    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    checkpoint = tf.train.Checkpoint(dnn=model)
    checkpoint.restore(save_path=os.path.join(checkpoint_dir, ckpt_name))

    # The step counter is the first numeric '-'-separated token after the
    # leading one. For names such as 'nn_softmax-468-1' this is the same token
    # the previous fixed index [1] picked, but it no longer breaks when the
    # prefix itself contains a '-' (e.g. 'nn-softmax-468-1').
    numeric_tokens = [token for token in ckpt_name.split('-')[1:] if token.isdecimal()]
    if numeric_tokens:
        counter = int(numeric_tokens[0])
    else:
        print(" [!] No step counter found in {}; starting from 0".format(ckpt_name))
        counter = 0

    print(" [*] Success to read {}".format(ckpt_name))
    return True, counter

def check_folder(dir):
    """Make sure the directory ``dir`` exists and return its path.

    Missing parent directories are created as well.

    Args:
        dir: Path of the directory to create if it is absent.

    Returns:
        The same ``dir`` value that was passed in.

    Raises:
        FileExistsError: If ``dir`` exists but is not a directory.
    """
    # exist_ok=True removes the check-then-create race of
    # `if not os.path.exists(dir): os.makedirs(dir)`.
    os.makedirs(dir, exist_ok=True)
    return dir

In [6]:
def load_mnist():
    """Load MNIST and return it ready for the dense network.

    Images get a trailing channel axis and are passed through ``normalize``;
    labels are one-hot encoded over 10 classes.

    Returns:
        ``(train_data, train_labels, test_data, test_labels)``.
    """
    num_classes = 10
    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()

    # Append a channel axis: [N, 28, 28] -> [N, 28, 28, 1]
    train_data = train_data[..., np.newaxis]
    test_data = test_data[..., np.newaxis]
    train_data, test_data = normalize(train_data, test_data)

    # One-hot encode: [N,] -> [N, 10]
    train_labels = to_categorical(train_labels, num_classes)
    test_labels = to_categorical(test_labels, num_classes)

    return train_data, train_labels, test_data, test_labels

def normalize(train_data, test_data):
    """Convert both arrays to float32 and divide every value by 255.

    Args:
        train_data: Array of training inputs.
        test_data: Array of test inputs.

    Returns:
        A ``(train_data, test_data)`` tuple of new float32 arrays; the inputs
        are not modified.
    """
    scaled_train, scaled_test = (
        split.astype(np.float32) / 255.0 for split in (train_data, test_data)
    )
    return scaled_train, scaled_test

In [19]:
def loss_fn(model, images, labels):
    """Return the batch-mean categorical cross-entropy of ``model`` on ``images``.

    The model is called with ``training=True`` and its outputs are treated as
    logits (``from_logits=True``); ``labels`` are passed as ``y_true``.
    """
    logits = model(images, training=True)
    per_example_loss = tf.keras.losses.categorical_crossentropy(
        y_true=labels,
        y_pred=logits,
        from_logits=True,
    )
    return tf.reduce_mean(per_example_loss)

def accuracy_fn(model, images, labels):
    """Return the fraction of ``images`` whose arg-max output matches ``labels``.

    The model is called with ``training=False``; both the model output and
    ``labels`` are reduced with ``argmax`` over the last axis before comparing.
    """
    predicted_class = tf.argmax(model(images, training=False), -1)
    true_class = tf.argmax(labels, -1)
    hits = tf.cast(tf.equal(predicted_class, true_class), tf.float32)
    return tf.reduce_mean(hits)

def grad(model, images, labels):
    """Return the gradients of ``loss_fn`` with respect to ``model.variables``.

    The returned list is ordered like ``model.variables``, so callers must pair
    it with that same list when applying the gradients.
    """
    with tf.GradientTape() as tape:
        batch_loss = loss_fn(model, images, labels)
    gradients = tape.gradient(batch_loss, model.variables)
    return gradients

In [20]:
def flatten():
    """Create a new ``tf.keras.layers.Flatten`` layer."""
    layer = tf.keras.layers.Flatten()
    return layer

def dense(label_dim, weight_init):
    """Create a biased ``Dense`` layer.

    Args:
        label_dim: Number of output units.
        weight_init: Initializer used for the layer's kernel.
    """
    layer_options = dict(
        units=label_dim,
        use_bias=True,
        kernel_initializer=weight_init,
    )
    return tf.keras.layers.Dense(**layer_options)

def sigmoid():
    """Create a new ``Activation`` layer that applies the Keras sigmoid."""
    activation_fn = tf.keras.activations.sigmoid
    return tf.keras.layers.Activation(activation_fn)

In [21]:
class create_model_class(tf.keras.Model):
    """Subclassed Keras model: Flatten -> 2 x (Dense(256) + sigmoid) -> Dense.

    The final ``Dense`` layer has no activation, so the model outputs logits.
    """

    def __init__(self, lable_dim):
        """Build the inner ``Sequential`` stack.

        Args:
            lable_dim: Number of units in the output layer. The misspelled
                parameter name is kept so existing keyword callers keep working.
        """
        super(create_model_class, self).__init__()
        weight_init = tf.keras.initializers.RandomNormal()

        self.model = tf.keras.Sequential()
        self.model.add(flatten())

        for _ in range(2):
            self.model.add(dense(256, weight_init))
            self.model.add(sigmoid())

        # Use the constructor argument. The body previously read a module-level
        # `label_dim`, which ignored the argument and raised NameError whenever
        # that global had not been defined yet.
        self.model.add(dense(lable_dim, weight_init))

    def call(self, x, training=None, mask=None):
        """Run the inner stack on ``x``, forwarding the ``training`` flag."""
        return self.model(x, training=training)

In [22]:
def create_model_function(label_dim):
    """Build the network as a plain ``tf.keras.Sequential``.

    Layer order: Flatten, then two (Dense(256), sigmoid) pairs, then a final
    Dense with ``label_dim`` units and no activation. All Dense kernels share
    one ``RandomNormal`` initializer instance.
    """
    weight_init = tf.keras.initializers.RandomNormal()

    # Layers are created in the same order they appear in the network.
    layers = [flatten()]
    for _ in range(2):
        layers.extend((dense(256, weight_init), sigmoid()))
    layers.append(dense(label_dim, weight_init))

    return tf.keras.Sequential(layers)

In [23]:
""" dataset """
train_x, train_y, test_x, test_y = load_mnist()

# --- parameters ---
learning_rate = 0.001
batch_size = 128

training_epochs = 1
training_iterations = len(train_x) // batch_size  # full batches per epoch

label_dim = 10

train_flag = True  # True: train and checkpoint; False: restore and evaluate

# --- input pipelines (tf.data) ---
# prefetch comes last so that it buffers ready-made batches; placed before
# batch() it only buffered individual examples.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_x, train_y))
    .shuffle(buffer_size=100000)
    .batch(batch_size, drop_remainder=True)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

# The test set is consumed as one batch holding every example, so shuffling it
# cannot change the measured accuracy and only added work to each evaluation.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_x, test_y))
    .batch(len(test_x))
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

In [24]:
""" Model """
network = create_model_function(label_dim)

# --- Training ---
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# --- Writer ---
# Checkpoints live in checkpoints/<model_dir> (created if missing) and
# TensorBoard logs in logs/<model_dir>.
model_dir = 'nn_softmax'

checkpoint_dir = check_folder(os.path.join('checkpoints', model_dir))
checkpoint_prefix = os.path.join(checkpoint_dir, model_dir)
logs_dir = os.path.join('logs', model_dir)

In [25]:
if train_flag:

    checkpoint = tf.train.Checkpoint(dnn=network)

    # create writer for tensorboard
    summary_writer = tf.summary.create_file_writer(logdir=logs_dir)
    start_time = time()

    # restore checkpoint if it exists
    could_load, checkpoint_counter = load(network, checkpoint_dir)

    if could_load:
        # Resume from the epoch the saved global step falls into.
        start_epoch = checkpoint_counter // training_iterations
        counter = checkpoint_counter
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        counter = 0
        print(" [!] Load failed...")

    # train phase
    with summary_writer.as_default():  # for tensorboard
        for epoch in range(start_epoch, training_epochs):
            for idx, (train_input, train_label) in enumerate(train_dataset):
                grads = grad(network, train_input, train_label)
                optimizer.apply_gradients(grads_and_vars=zip(grads, network.variables))

                # Metrics are measured after the update, on the same batch.
                train_loss = loss_fn(network, train_input, train_label)
                train_accuracy = accuracy_fn(network, train_input, train_label)

                # test_dataset yields a single batch, so the last value is the
                # accuracy over that whole batch.
                for test_input, test_label in test_dataset:
                    test_accuracy = accuracy_fn(network, test_input, test_label)

                tf.summary.scalar(name='train_loss', data=train_loss, step=counter)
                tf.summary.scalar(name='train_accuracy', data=train_accuracy, step=counter)
                tf.summary.scalar(name='test_accuracy', data=test_accuracy, step=counter)

                print(
                    "Epoch: [%2d] [%5d/%5d] time: %4.4f, train_loss: %.8f, train_accuracy: %.4f, test_accuracy: %.4f" \
                    % (epoch, idx, training_iterations, time() - start_time, train_loss, train_accuracy,
                       test_accuracy))
                counter += 1

        # Flush buffered summaries so they are on disk once training ends.
        summary_writer.flush()
        # The global step is embedded in the file name; load() parses it back.
        checkpoint.save(file_prefix=checkpoint_prefix + '-{}'.format(counter))

# test phase
else:
    # The result is not needed here, and not assigning to `_` leaves IPython's
    # last-output variable alone.
    load(network, checkpoint_dir)
    for test_input, test_label in test_dataset:
        test_accuracy = accuracy_fn(network, test_input, test_label)

    print("test_accuracy: %.4f" % (test_accuracy))

 [*] Reading checkpoints...
 [*] Failed to find a checkpoint
 [!] Load failed...
Epoch: [ 0] [    0/  468] time: 0.8954, train_loss: 2.29469705, train_accuracy: 0.1016, test_accuracy: 0.0892
Epoch: [ 0] [    1/  468] time: 1.1734, train_loss: 2.29305005, train_accuracy: 0.1250, test_accuracy: 0.1131
Epoch: [ 0] [    2/  468] time: 1.4804, train_loss: 2.28097892, train_accuracy: 0.0938, test_accuracy: 0.1052
Epoch: [ 0] [    3/  468] time: 1.6749, train_loss: 2.28629804, train_accuracy: 0.2266, test_accuracy: 0.1985
Epoch: [ 0] [    4/  468] time: 1.8609, train_loss: 2.27084947, train_accuracy: 0.1406, test_accuracy: 0.1135
Epoch: [ 0] [    5/  468] time: 2.0381, train_loss: 2.32843471, train_accuracy: 0.1094, test_accuracy: 0.1135
Epoch: [ 0] [    6/  468] time: 2.2131, train_loss: 2.26015353, train_accuracy: 0.1484, test_accuracy: 0.1135
Epoch: [ 0] [    7/  468] time: 2.3935, train_loss: 2.24491167, train_accuracy: 0.1328, test_accuracy: 0.1135
Epoch: [ 0] [    8/  468] time: 2.5721,

Epoch: [ 0] [   75/  468] time: 13.2823, train_loss: 0.99674201, train_accuracy: 0.7500, test_accuracy: 0.7664
Epoch: [ 0] [   76/  468] time: 13.4406, train_loss: 0.90752065, train_accuracy: 0.8125, test_accuracy: 0.7856
Epoch: [ 0] [   77/  468] time: 13.5866, train_loss: 0.87378007, train_accuracy: 0.7891, test_accuracy: 0.7988
Epoch: [ 0] [   78/  468] time: 13.7396, train_loss: 0.95285296, train_accuracy: 0.7812, test_accuracy: 0.8031
Epoch: [ 0] [   79/  468] time: 13.8916, train_loss: 0.87259769, train_accuracy: 0.7734, test_accuracy: 0.8018
Epoch: [ 0] [   80/  468] time: 14.0496, train_loss: 0.87179458, train_accuracy: 0.7969, test_accuracy: 0.7986
Epoch: [ 0] [   81/  468] time: 14.2026, train_loss: 0.83666790, train_accuracy: 0.7812, test_accuracy: 0.7966
Epoch: [ 0] [   82/  468] time: 14.3496, train_loss: 0.86614287, train_accuracy: 0.8281, test_accuracy: 0.7968
Epoch: [ 0] [   83/  468] time: 14.5066, train_loss: 0.89879417, train_accuracy: 0.8203, test_accuracy: 0.8028
E

Epoch: [ 0] [  150/  468] time: 24.6315, train_loss: 0.50186878, train_accuracy: 0.8750, test_accuracy: 0.8772
Epoch: [ 0] [  151/  468] time: 24.7755, train_loss: 0.66060567, train_accuracy: 0.8203, test_accuracy: 0.8788
Epoch: [ 0] [  152/  468] time: 24.9255, train_loss: 0.48041725, train_accuracy: 0.8594, test_accuracy: 0.8798
Epoch: [ 0] [  153/  468] time: 25.0799, train_loss: 0.44753158, train_accuracy: 0.8828, test_accuracy: 0.8813
Epoch: [ 0] [  154/  468] time: 25.2289, train_loss: 0.38754171, train_accuracy: 0.9297, test_accuracy: 0.8825
Epoch: [ 0] [  155/  468] time: 25.3809, train_loss: 0.41756245, train_accuracy: 0.8984, test_accuracy: 0.8837
Epoch: [ 0] [  156/  468] time: 25.5299, train_loss: 0.48584595, train_accuracy: 0.8672, test_accuracy: 0.8850
Epoch: [ 0] [  157/  468] time: 25.6820, train_loss: 0.40935043, train_accuracy: 0.8984, test_accuracy: 0.8860
Epoch: [ 0] [  158/  468] time: 25.8310, train_loss: 0.48108745, train_accuracy: 0.8516, test_accuracy: 0.8863
E

Epoch: [ 0] [  224/  468] time: 35.7318, train_loss: 0.22760452, train_accuracy: 0.9609, test_accuracy: 0.9000
Epoch: [ 0] [  225/  468] time: 35.8838, train_loss: 0.38525555, train_accuracy: 0.9219, test_accuracy: 0.9021
Epoch: [ 0] [  226/  468] time: 36.0328, train_loss: 0.39244986, train_accuracy: 0.8984, test_accuracy: 0.9050
Epoch: [ 0] [  227/  468] time: 36.1969, train_loss: 0.35499340, train_accuracy: 0.8750, test_accuracy: 0.9055
Epoch: [ 0] [  228/  468] time: 36.3468, train_loss: 0.32386550, train_accuracy: 0.8828, test_accuracy: 0.9041
Epoch: [ 0] [  229/  468] time: 36.4978, train_loss: 0.40108085, train_accuracy: 0.8828, test_accuracy: 0.9020
Epoch: [ 0] [  230/  468] time: 36.6608, train_loss: 0.46415585, train_accuracy: 0.8672, test_accuracy: 0.8983
Epoch: [ 0] [  231/  468] time: 36.8128, train_loss: 0.21364611, train_accuracy: 0.9531, test_accuracy: 0.8950
Epoch: [ 0] [  232/  468] time: 36.9668, train_loss: 0.38307959, train_accuracy: 0.8750, test_accuracy: 0.8930
E

Epoch: [ 0] [  299/  468] time: 48.0862, train_loss: 0.34463525, train_accuracy: 0.9141, test_accuracy: 0.9127
Epoch: [ 0] [  300/  468] time: 48.2452, train_loss: 0.30206531, train_accuracy: 0.8906, test_accuracy: 0.9134
Epoch: [ 0] [  301/  468] time: 48.4052, train_loss: 0.32582894, train_accuracy: 0.8828, test_accuracy: 0.9134
Epoch: [ 0] [  302/  468] time: 48.5642, train_loss: 0.20920165, train_accuracy: 0.9531, test_accuracy: 0.9141
Epoch: [ 0] [  303/  468] time: 48.7272, train_loss: 0.37823245, train_accuracy: 0.9219, test_accuracy: 0.9141
Epoch: [ 0] [  304/  468] time: 48.8969, train_loss: 0.37867483, train_accuracy: 0.8906, test_accuracy: 0.9143
Epoch: [ 0] [  305/  468] time: 49.0509, train_loss: 0.32183018, train_accuracy: 0.9141, test_accuracy: 0.9141
Epoch: [ 0] [  306/  468] time: 49.2069, train_loss: 0.28892323, train_accuracy: 0.9141, test_accuracy: 0.9136
Epoch: [ 0] [  307/  468] time: 49.4375, train_loss: 0.28708357, train_accuracy: 0.9219, test_accuracy: 0.9124
E

Epoch: [ 0] [  373/  468] time: 60.5974, train_loss: 0.19355658, train_accuracy: 0.9531, test_accuracy: 0.9192
Epoch: [ 0] [  374/  468] time: 60.7714, train_loss: 0.39026877, train_accuracy: 0.8672, test_accuracy: 0.9187
Epoch: [ 0] [  375/  468] time: 61.0084, train_loss: 0.28259334, train_accuracy: 0.9141, test_accuracy: 0.9192
Epoch: [ 0] [  376/  468] time: 61.1994, train_loss: 0.21019444, train_accuracy: 0.9531, test_accuracy: 0.9194
Epoch: [ 0] [  377/  468] time: 61.3997, train_loss: 0.33684668, train_accuracy: 0.8984, test_accuracy: 0.9180
Epoch: [ 0] [  378/  468] time: 61.5837, train_loss: 0.25681317, train_accuracy: 0.9375, test_accuracy: 0.9168
Epoch: [ 0] [  379/  468] time: 61.7507, train_loss: 0.35283890, train_accuracy: 0.8906, test_accuracy: 0.9160
Epoch: [ 0] [  380/  468] time: 61.9197, train_loss: 0.34645683, train_accuracy: 0.9062, test_accuracy: 0.9173
Epoch: [ 0] [  381/  468] time: 62.0879, train_loss: 0.27883485, train_accuracy: 0.9609, test_accuracy: 0.9178
E

Epoch: [ 0] [  447/  468] time: 73.6101, train_loss: 0.29957855, train_accuracy: 0.9062, test_accuracy: 0.9242
Epoch: [ 0] [  448/  468] time: 73.9101, train_loss: 0.17423511, train_accuracy: 0.9453, test_accuracy: 0.9235
Epoch: [ 0] [  449/  468] time: 74.1411, train_loss: 0.21061787, train_accuracy: 0.9375, test_accuracy: 0.9236
Epoch: [ 0] [  450/  468] time: 74.3275, train_loss: 0.25840476, train_accuracy: 0.9219, test_accuracy: 0.9244
Epoch: [ 0] [  451/  468] time: 74.5375, train_loss: 0.27544758, train_accuracy: 0.9375, test_accuracy: 0.9239
Epoch: [ 0] [  452/  468] time: 74.6965, train_loss: 0.33559895, train_accuracy: 0.9062, test_accuracy: 0.9246
Epoch: [ 0] [  453/  468] time: 74.9075, train_loss: 0.26858750, train_accuracy: 0.9297, test_accuracy: 0.9241
Epoch: [ 0] [  454/  468] time: 75.1375, train_loss: 0.30333722, train_accuracy: 0.9141, test_accuracy: 0.9248
Epoch: [ 0] [  455/  468] time: 75.3175, train_loss: 0.34382814, train_accuracy: 0.9141, test_accuracy: 0.9245
E