ReLU activation function

In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from time import time
import os

In [2]:
# Checkpoint function 
def load(model, checkpoint_dir):
    print(" [*] Reading checkpoints...")
    
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        checkpoint = tf.train.Checkpoint(dnn=model)
        checkpoint.restore(save_path = os.path.join(checkpoint_dir, ckpt_name))
        counter = int(ckpt_name.split('-')[1])
        print(" [*] Success to read {}".format(ckpt_name))
        return True, counter
    else:
        print('[*] Failed to find a checkpoint')
        return False, 0
    
def check_folder(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)
    return dir

In [3]:
# Load mnist
def load_mnist():
    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()
    
    train_data = np.expand_dims(train_data, axis = -1)
    test_data = np.expand_dims(test_data, axis = -1)
    
    train_data, test_data = normalize(train_data, test_data)
    
    train_labels = to_categorical(train_labels, 10)
    test_lables = to_categorical(test_labels, 10)
    
    return train_data, train_labels, test_data, test_lables
    
def normalize(train_data, test_data):
    train_data = train_data.astype(np.float32) / 255.0
    test_data = test_data.astype(np.float32) / 255.0
    
    return train_data, test_data

In [4]:
# Define loss
def loss_fn(model, images, labels):
    logits = model(images, training = True)
    loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_pred = logits,
                                                                   y_true=labels,
                                                                   from_logits=True))
    return loss
    
def accuracy_fn(model, images, labels):
    logits = model(images, training = False)
    prediction = tf.equal(tf.argmax(logits, -1), tf.argmax(labels, -1))
    accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))
    return accuracy

def grad(model, images, labels):
    with  tf.GradientTape() as tape:
        loss = loss_fn(model, images, labels)
    return tape.gradient(loss, model.variables)

In [5]:
#create network
def flatten():
    return tf.keras.layers.Flatten()

def dense(label_dim, weight_init):
    return tf.keras.layers.Dense(units=label_dim, use_bias = True, 
                                 kernel_initializer=weight_init)

# 엥 시그모이드 사용하네
def sigmoid():
    return tf.keras.layers.Activation(tf.keras.activations.sigmoid)

In [6]:
def create_model_function(label_dim):
    weight_init = tf.keras.initializers.RandomNormal()
    
    model = tf.keras.Sequential()
    model.add(flatten())
    
    for i in range(2):
        model.add(dense(256, weight_init))
        model.add(sigmoid())
        
    model.add(dense(label_dim, weight_init))
    
    return model

In [7]:
# Experiments (parameters)
# dataset
train_x, train_y, test_x, test_y = load_mnist()

# parameters
learning_rate = 0.001
batch_size = 128

training_epochs = 1
training_iterations = len(train_x) // batch_size

label_dim = 10

train_flag = True

# Graph Input using Dataset API
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)).\
shuffle(buffer_size = 100000).prefetch(buffer_size = batch_size).\
batch(batch_size, drop_remainder=True)

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y)).\
shuffle(buffer_size = 100000).prefetch(buffer_size = len(test_x)).\
batch(len(test_x))


In [8]:
# Experiments (model)
# Model
network = create_model_function(label_dim)

# Training
optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate)

# Writer
checkpoint_dir = 'checkpoints'
logs_dir = 'logs'

model_dir = 'nn_softmax'

checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
check_folder(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, model_dir)
logs_dir = os.path.join(logs_dir, model_dir)

In [10]:
if train_flag:
    checkpoint = tf.train.Checkpoint(dnn=network)
    
    # create writer for tensorboard
    summary_writer = tf.summary.create_file_writer(logdir = logs_dir)
    start_time = time()
    
    # restore check-point if it exits
    could_load, checkpoint_counter = load(network, checkpoint_dir)
    
    if could_load:
        start_epoch = (int)(checkpoint_counter / training_iterations)
        counter = checkpoint_counter
        print(' [*] Load SUCCESS')
    else:
        start_epoch = 0
        start_iteration = 0
        counter = 0
        print(' [!] Load failed...')
        
    # train phase
    with summary_writer.as_default():
        for epoch in range(start_epoch, training_epochs):
            for idx, (train_input, train_label) in enumerate(train_dataset):
                grads = grad(network, train_input, train_label)
                optimizer.apply_gradients(grads_and_vars=zip(grads, network.variables))
                
                train_loss = loss_fn(network, train_input, train_label)
                train_accuracy = accuracy_fn(network, train_input, train_label)
                
                for test_input, test_label in test_dataset:
                    test_accuracy = accuracy_fn(network, train_input, train_label)
                    
                tf.summary.scalar(name = 'train_loss', data = train_loss, step = counter)
                tf.summary.scalar(name = 'train_accuracy', data = train_accuracy, step=counter)
                tf.summary.scalar(name = 'test_accuracy', data = test_accuracy, step=counter)
                
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, train_loss: %.8f, train_accuracy: %.4f, test_Accuracy: %.4f" \
                    % (epoch, idx, training_iterations, time() - start_time, train_loss, train_accuracy,\
                       test_accuracy))
                
                counter += 1
        checkpoint.save(file_prefix = checkpoint_prefix + '-{}'.format(counter))
        
else:
    _, _ = load(network, checkpoint_dir)
    for test_input, test_label in test_dataset:
        test_accuracy = accuracy_fn(network, test_input, test_label)
        
    print('test_Accuracy: %.4f' %(test_accuracy))

 [*] Reading checkpoints...
[*] Failed to find a checkpoint
 [!] Load failed...
Epoch: [ 0] [    0/  468] time: 0.4308, train_loss: 2.27726769, train_accuracy: 0.1172, test_Accuracy: 0.1172
Epoch: [ 0] [    1/  468] time: 0.5397, train_loss: 2.26573229, train_accuracy: 0.2578, test_Accuracy: 0.2578
Epoch: [ 0] [    2/  468] time: 0.6416, train_loss: 2.29206944, train_accuracy: 0.1328, test_Accuracy: 0.1328
Epoch: [ 0] [    3/  468] time: 0.7678, train_loss: 2.24998283, train_accuracy: 0.1094, test_Accuracy: 0.1094
Epoch: [ 0] [    4/  468] time: 0.8677, train_loss: 2.29259610, train_accuracy: 0.1406, test_Accuracy: 0.1406
Epoch: [ 0] [    5/  468] time: 0.9667, train_loss: 2.26160717, train_accuracy: 0.1875, test_Accuracy: 0.1875
Epoch: [ 0] [    6/  468] time: 1.0616, train_loss: 2.25342131, train_accuracy: 0.2109, test_Accuracy: 0.2109
Epoch: [ 0] [    7/  468] time: 1.1766, train_loss: 2.23372984, train_accuracy: 0.3203, test_Accuracy: 0.3203
Epoch: [ 0] [    8/  468] time: 1.2695, 

Epoch: [ 0] [   76/  468] time: 7.9999, train_loss: 0.93307745, train_accuracy: 0.7656, test_Accuracy: 0.7656
Epoch: [ 0] [   77/  468] time: 8.0918, train_loss: 0.73418045, train_accuracy: 0.8750, test_Accuracy: 0.8750
Epoch: [ 0] [   78/  468] time: 8.1808, train_loss: 0.88368058, train_accuracy: 0.7734, test_Accuracy: 0.7734
Epoch: [ 0] [   79/  468] time: 8.2727, train_loss: 0.90646428, train_accuracy: 0.8281, test_Accuracy: 0.8281
Epoch: [ 0] [   80/  468] time: 8.3897, train_loss: 0.84766883, train_accuracy: 0.8203, test_Accuracy: 0.8203
Epoch: [ 0] [   81/  468] time: 8.4846, train_loss: 0.93237996, train_accuracy: 0.7266, test_Accuracy: 0.7266
Epoch: [ 0] [   82/  468] time: 8.5796, train_loss: 0.83186895, train_accuracy: 0.8047, test_Accuracy: 0.8047
Epoch: [ 0] [   83/  468] time: 8.6735, train_loss: 0.80155134, train_accuracy: 0.8359, test_Accuracy: 0.8359
Epoch: [ 0] [   84/  468] time: 8.8075, train_loss: 0.84109986, train_accuracy: 0.7812, test_Accuracy: 0.7812
Epoch: [ 0

Epoch: [ 0] [  151/  468] time: 15.5451, train_loss: 0.55832493, train_accuracy: 0.8516, test_Accuracy: 0.8516
Epoch: [ 0] [  152/  468] time: 15.6401, train_loss: 0.43582392, train_accuracy: 0.9062, test_Accuracy: 0.9062
Epoch: [ 0] [  153/  468] time: 15.7500, train_loss: 0.44626769, train_accuracy: 0.8906, test_Accuracy: 0.8906
Epoch: [ 0] [  154/  468] time: 15.8430, train_loss: 0.47949812, train_accuracy: 0.8828, test_Accuracy: 0.8828
Epoch: [ 0] [  155/  468] time: 15.9369, train_loss: 0.50942552, train_accuracy: 0.9062, test_Accuracy: 0.9062
Epoch: [ 0] [  156/  468] time: 16.0269, train_loss: 0.51073539, train_accuracy: 0.8672, test_Accuracy: 0.8672
Epoch: [ 0] [  157/  468] time: 16.1489, train_loss: 0.54903072, train_accuracy: 0.8594, test_Accuracy: 0.8594
Epoch: [ 0] [  158/  468] time: 16.2408, train_loss: 0.39435378, train_accuracy: 0.9219, test_Accuracy: 0.9219
Epoch: [ 0] [  159/  468] time: 16.3328, train_loss: 0.44400477, train_accuracy: 0.8750, test_Accuracy: 0.8750
E

Epoch: [ 0] [  227/  468] time: 23.2077, train_loss: 0.26252097, train_accuracy: 0.9297, test_Accuracy: 0.9297
Epoch: [ 0] [  228/  468] time: 23.2967, train_loss: 0.33554783, train_accuracy: 0.9375, test_Accuracy: 0.9375
Epoch: [ 0] [  229/  468] time: 23.3866, train_loss: 0.40609515, train_accuracy: 0.8594, test_Accuracy: 0.8594
Epoch: [ 0] [  230/  468] time: 23.4986, train_loss: 0.35103831, train_accuracy: 0.8594, test_Accuracy: 0.8594
Epoch: [ 0] [  231/  468] time: 23.5885, train_loss: 0.30423057, train_accuracy: 0.9297, test_Accuracy: 0.9297
Epoch: [ 0] [  232/  468] time: 23.6755, train_loss: 0.35185191, train_accuracy: 0.8828, test_Accuracy: 0.8828
Epoch: [ 0] [  233/  468] time: 23.7704, train_loss: 0.35974881, train_accuracy: 0.8984, test_Accuracy: 0.8984
Epoch: [ 0] [  234/  468] time: 23.8994, train_loss: 0.36327314, train_accuracy: 0.8594, test_Accuracy: 0.8594
Epoch: [ 0] [  235/  468] time: 23.9893, train_loss: 0.27900109, train_accuracy: 0.9219, test_Accuracy: 0.9219
E

Epoch: [ 0] [  301/  468] time: 30.3955, train_loss: 0.26829165, train_accuracy: 0.9219, test_Accuracy: 0.9219
Epoch: [ 0] [  302/  468] time: 30.4865, train_loss: 0.33564568, train_accuracy: 0.9297, test_Accuracy: 0.9297
Epoch: [ 0] [  303/  468] time: 30.5774, train_loss: 0.32729343, train_accuracy: 0.8984, test_Accuracy: 0.8984
Epoch: [ 0] [  304/  468] time: 30.6823, train_loss: 0.28260875, train_accuracy: 0.9141, test_Accuracy: 0.9141
Epoch: [ 0] [  305/  468] time: 30.7753, train_loss: 0.23427504, train_accuracy: 0.9219, test_Accuracy: 0.9219
Epoch: [ 0] [  306/  468] time: 30.8682, train_loss: 0.36186346, train_accuracy: 0.8828, test_Accuracy: 0.8828
Epoch: [ 0] [  307/  468] time: 30.9952, train_loss: 0.24731512, train_accuracy: 0.9297, test_Accuracy: 0.9297
Epoch: [ 0] [  308/  468] time: 31.0861, train_loss: 0.43073136, train_accuracy: 0.8594, test_Accuracy: 0.8594
Epoch: [ 0] [  309/  468] time: 31.1761, train_loss: 0.43033370, train_accuracy: 0.8672, test_Accuracy: 0.8672
E

Epoch: [ 0] [  376/  468] time: 37.6152, train_loss: 0.21777323, train_accuracy: 0.9453, test_Accuracy: 0.9453
Epoch: [ 0] [  377/  468] time: 37.7252, train_loss: 0.24246255, train_accuracy: 0.9219, test_Accuracy: 0.9219
Epoch: [ 0] [  378/  468] time: 37.8171, train_loss: 0.27357835, train_accuracy: 0.9375, test_Accuracy: 0.9375
Epoch: [ 0] [  379/  468] time: 37.9081, train_loss: 0.23930971, train_accuracy: 0.9219, test_Accuracy: 0.9219
Epoch: [ 0] [  380/  468] time: 37.9990, train_loss: 0.32219845, train_accuracy: 0.8984, test_Accuracy: 0.8984
Epoch: [ 0] [  381/  468] time: 38.1042, train_loss: 0.27496353, train_accuracy: 0.9219, test_Accuracy: 0.9219
Epoch: [ 0] [  382/  468] time: 38.1931, train_loss: 0.34989673, train_accuracy: 0.8828, test_Accuracy: 0.8828
Epoch: [ 0] [  383/  468] time: 38.2831, train_loss: 0.33843371, train_accuracy: 0.9062, test_Accuracy: 0.9062
Epoch: [ 0] [  384/  468] time: 38.4592, train_loss: 0.33348650, train_accuracy: 0.9141, test_Accuracy: 0.9141
E

Epoch: [ 0] [  449/  468] time: 44.8447, train_loss: 0.22360522, train_accuracy: 0.9297, test_Accuracy: 0.9297
Epoch: [ 0] [  450/  468] time: 44.9578, train_loss: 0.27475917, train_accuracy: 0.8984, test_Accuracy: 0.8984
Epoch: [ 0] [  451/  468] time: 45.0487, train_loss: 0.37616730, train_accuracy: 0.8594, test_Accuracy: 0.8594
Epoch: [ 0] [  452/  468] time: 45.1397, train_loss: 0.19403151, train_accuracy: 0.9453, test_Accuracy: 0.9453
Epoch: [ 0] [  453/  468] time: 45.2296, train_loss: 0.30841595, train_accuracy: 0.8984, test_Accuracy: 0.8984
Epoch: [ 0] [  454/  468] time: 45.3416, train_loss: 0.25389659, train_accuracy: 0.9141, test_Accuracy: 0.9141
Epoch: [ 0] [  455/  468] time: 45.4326, train_loss: 0.35746822, train_accuracy: 0.8906, test_Accuracy: 0.8906
Epoch: [ 0] [  456/  468] time: 45.5245, train_loss: 0.24454370, train_accuracy: 0.9375, test_Accuracy: 0.9375
Epoch: [ 0] [  457/  468] time: 45.6165, train_loss: 0.27046514, train_accuracy: 0.9062, test_Accuracy: 0.9062
E