In [93]:
import warnings
warnings.filterwarnings(action='ignore') 

In [1]:
import sys
import numpy as np
import tensorflow as tf

print("Python version:", sys.version)
print('tensorflow', tf.__version__)
print('numpy', np.__version__)

Python version: 3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)]
tensorflow 2.5.3
numpy 1.19.5


In [None]:
# xor 문제 해결 : 1개의 모델로는 해결 불가

# 로지스틱 모델 3개가 모이면 해결 가능 -> multinomial 모델 포함한 모델 2개로 합치기 가능

K = tf.sigmoid(tf.matmul(X, W1) + b1)
hyp = tf.sigmoid(tf.matmul(K, W2) + b2)

# xor 문제 해결

## xor 문제 with logistic model

In [None]:
x_data = [[0, 0],
          [0, 1],
          [1, 0],
          [1, 1]]
y_data = [[0],
          [1],
          [1],
          [0]]

tf.random.set_seed(777)

dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(len(x_data))

In [9]:
def process_data(features, labels):
    features = tf.cast(features, tf.float32)
    labels = tf.cast(labels, tf.float32) #오타 주의
    return features, labels 

In [72]:
W1 = tf.Variable(tf.random.normal((2, 1)), name='weight1')
b1 = tf.Variable(tf.random.normal((1,)), name='bias1')

W2 = tf.Variable(tf.random.normal((2, 1)), name='weight2')
b2 = tf.Variable(tf.random.normal((1,)), name='bias2')

W3 = tf.Variable(tf.random.normal((2, 1)), name='weight3')
b3 = tf.Variable(tf.random.normal((1,)), name='bias3')

In [11]:
def neural_net(features):
    layer1 = tf.sigmoid(tf.matmul(features, W1) + b1)
    layer2 = tf.sigmoid(tf.matmul(features, W2) + b2) #복붙 주의
    layer3 = tf.concat([layer1, layer2], -1) 
    layer3 =tf.reshape(layer3, shape = [-1,2]) 
    hyp = tf.sigmoid(tf.matmul(layer3, W3) + b3)
    return hyp

def loss_fn(hypothesis, labels):
    cost = -tf.reduce_mean(labels * tf.math.log(hypothesis) + (1 - labels) * tf.math.log(1 - hypothesis))
    return cost
    
optimizer = tf.keras.optimizers.SGD(learning_rate = 0.01)

def accuracy_fn(hyp, labels):
    predicted = tf.cast(hyp > 0.5, dtype=tf.float32)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, labels), dtype=tf.float32))
    return accuracy

def grad(hyp, features, labels):
    with tf.GradientTape() as tape:
        loss_value = loss_fn(neural_net(features), labels)
    return tape.gradient(loss_value, [W1,W2,W3, b1, b2, b3])

In [74]:
epochs = 50000

for step in range(epochs):
    for features, labels in dataset:
        features, labels = process_data(features, labels)
        grads = grad(neural_net(features), features, labels)
        optimizer.apply_gradients(grads_and_vars=zip(grads, [W1, W2, W3, b1, b2, b3]))
        if step % 5000 == 0:
            print("iter : {}, loss : {:.4f}".format(step, loss_fn(neural_net(features), labels)))

x_data, y_data = process_data(x_data,y_data)
test_acc = accuracy_fn(neural_net(x_data), y_data)
print("test acc : {:.4f}".format(test_acc))

iter : 0, loss : 0.8487
iter : 5000, loss : 0.6847
iter : 10000, loss : 0.6610
iter : 15000, loss : 0.6154
iter : 20000, loss : 0.5722
iter : 25000, loss : 0.5433
iter : 30000, loss : 0.5211
iter : 35000, loss : 0.4911
iter : 40000, loss : 0.4416
iter : 45000, loss : 0.3313
test acc : 1.0000


## wide deep learning

1-1 함수 이용

동일하게 사용 가능한 함수는 self.함수명 = 함수명 사용해서 그대로 이용함.

loss_fn의 경우 예제 코드에서 불필요하게 features가 사용되어 해당 내용을 제거하고 그대로 사용함.

In [143]:
x_data = [[0, 0],
          [0, 1],
          [1, 0],
          [1, 1]]
y_data = [[0],
          [1],
          [1],
          [0]]

In [151]:
nb_classes = 10

class wide_deep_nn():
    def __init__(self, nb_classes):
        super(wide_deep_nn, self).__init__()
        
        self.W1 = tf.Variable(tf.random.normal((2, nb_classes)), name='weight1')
        self.b1 = tf.Variable(tf.random.normal((nb_classes,)), name='bias1')
        
        self.W2 = tf.Variable(tf.random.normal((nb_classes, nb_classes)), name='weight2')
        self.b2 = tf.Variable(tf.random.normal((nb_classes,)), name='bias2')
        
        self.W3 = tf.Variable(tf.random.normal((nb_classes, nb_classes)), name='weight3')
        self.b3 = tf.Variable(tf.random.normal((nb_classes,)), name='bias3')
        
        self.W4 = tf.Variable(tf.random.normal((nb_classes, nb_classes)), name='weight4')
        self.b4 = tf.Variable(tf.random.normal((1,)), name='bias4')
        
        self.variables = [self.W1, self.b1, self.W2, self.b2, self.W3, self.b3, self.W4, self.b4]
        
        self.process_data = process_data
        self.neural_net = neural_net
        self.loss_fn = loss_fn
        self.accuracy_fn = accuracy_fn
        
    def deep_nn(self, features):
        layers1 = tf.sigmoid(tf.matmul(features, self.W1) + self.b1)
        layers2 = tf.sigmoid(tf.matmul(layers1, self.W2) + self.b2)
        layers3 = tf.sigmoid(tf.matmul(layers2, self.W3) + self.b3)
        hyp = tf.sigmoid(tf.matmul(layers3, self.W4) + self.b4)
        return hyp
    
    def grad(self, hypothesis, features, labels):
        with tf.GradientTape() as tape:
            loss_value = self.loss_fn(self.deep_nn(features),labels)
        return tape.gradient(loss_value,self.variables)
        
    def fit(self, dataset, epochs = 2000, verbose = 50):
        optimizer = tf.keras.optimizers.SGD(learning_rate = 0.2)
        for step in range(epochs):
            for features, labels in dataset:
                features, labels = self.process_data(features, labels)
                grads = self.grad(self.deep_nn(features), features, labels)
                optimizer.apply_gradients(grads_and_vars=zip(grads, self.variables))
                if step % verbose == 0 :
                    print("iter : {}, loss : {:.4f}".format(step, self.loss_fn(self.deep_nn(features), labels)))

    def test_model(self, x_data, y_data):
        x_data, y_data = self.process_data(x_data, y_data)
        test_acc = self.accuracy_fn(self.deep_nn(x_data), y_data)
        print("test accuracy : {:.4f}".format(test_acc))

In [152]:
model = wide_deep_nn(nb_classes)

In [153]:
model.fit(dataset, epochs=10000, verbose = 500)

iter : 0, loss : 1.0519
iter : 500, loss : 0.6936
iter : 1000, loss : 0.6925
iter : 1500, loss : 0.6908
iter : 2000, loss : 0.6870
iter : 2500, loss : 0.6750
iter : 3000, loss : 0.6334
iter : 3500, loss : 0.5242
iter : 4000, loss : 0.2010
iter : 4500, loss : 0.0727
iter : 5000, loss : 0.0401
iter : 5500, loss : 0.0268
iter : 6000, loss : 0.0199
iter : 6500, loss : 0.0157
iter : 7000, loss : 0.0129
iter : 7500, loss : 0.0109
iter : 8000, loss : 0.0094
iter : 8500, loss : 0.0083
iter : 9000, loss : 0.0074
iter : 9500, loss : 0.0067


In [154]:
model.test_model(x_data, y_data)

test accuracy : 1.0000


## using tensorboard

In [2]:
x_data = [[0, 0],
          [0, 1],
          [1, 0],
          [1, 1]]
y_data = [[0],
          [1],
          [1],
          [0]]

dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(len(x_data))

W1 = tf.Variable(tf.random.normal((2, 10)), name='weight1')
b1 = tf.Variable(tf.random.normal((10,)), name='bias1')

W2 = tf.Variable(tf.random.normal((10, 10)), name='weight2')
b2 = tf.Variable(tf.random.normal((10,)), name='bias2')

W3 = tf.Variable(tf.random.normal((10, 10)), name='weight3')
b3 = tf.Variable(tf.random.normal((10,)), name='bias3')

W4 = tf.Variable(tf.random.normal((10, 1)), name='weight4')
b4 = tf.Variable(tf.random.normal((1,)), name='bias4')

In [4]:
log_path = "./logs/xor"
writer = tf.summary.create_file_writer(log_path)
writer

<tensorflow.python.ops.summary_ops_v2.ResourceSummaryWriter at 0x1e624735b50>

In [13]:
def new_neural_net(features, step):
    layer1 = tf.sigmoid(tf.matmul(features, W1) + b1)
    layer2 = tf.sigmoid(tf.matmul(layer1, W2) + b2)
    layer3 = tf.sigmoid(tf.matmul(layer2, W3) + b3)
    hypothesis = tf.sigmoid(tf.matmul(layer3, W4) + b4)

    with writer.as_default():
        tf.summary.histogram("weights1", W1, step=step)
        tf.summary.histogram("biases1", b1, step=step)
        tf.summary.histogram("layer1", layer1, step=step)

        tf.summary.histogram("weights2", W2, step=step)
        tf.summary.histogram("biases2", b2, step=step)
        tf.summary.histogram("layer2", layer2, step=step)

        tf.summary.histogram("weights3", W3, step=step)
        tf.summary.histogram("biases3", b3, step=step)
        tf.summary.histogram("layer3", layer3, step=step)

        tf.summary.histogram("weights4", W4, step=step)
        tf.summary.histogram("biases4", b4, step=step)
        tf.summary.histogram("hypothesis", hypothesis, step=step)
    return hypothesis

def new_grad(hypothesis, features, labels, step):
    with tf.GradientTape() as tape:
        loss_value = loss_fn(new_neural_net(features, step),labels)
    return tape.gradient(loss_value, [W1, W2, W3, W4, b1, b2, b3, b4])

In [14]:
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

In [None]:
EPOCHS = 300

for step in range(EPOCHS):    
    for features, labels  in dataset:
        features, labels = process_data(features, labels)
        grads = new_grad(new_neural_net(features, step), features, labels, step)
        optimizer.apply_gradients(grads_and_vars=zip(grads,[W1, W2, W3, W4, b1, b2, b3, b4]))
        if step % 50 == 0:
            loss_value = loss_fn(new_neural_net(features, step),labels)
            print("Iter: {}, Loss: {:.4f}".format(step, loss_value))

In [18]:
x_data, y_data = process_data(x_data, y_data)
test_acc = accuracy_fn(new_neural_net(x_data, step),y_data)
print("Testset Accuracy: {:.4f}".format(test_acc))

Testset Accuracy: 1.0000


In [19]:
## (Optional) Jupyter Notebook에서 Tensorboard 실행하기

# Load the TensorBoard notebook extension
%load_ext tensorboard

'''Start TensorBoard through the command line or within a notebook experience. 
The two interfaces are generally the same. In notebooks, use the %tensorboard line magic. 
On the command line, run the same command without "%".'''

%tensorboard --logdir logs/xor

In [24]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(10, input_dim=2, activation = tf.nn.sigmoid),
    tf.keras.layers.Dense(20, activation = tf.nn.sigmoid),
    tf.keras.layers.Dense(40,  activation = tf.nn.sigmoid),
    tf.keras.layers.Dense(20, activation = tf.nn.sigmoid),
    tf.keras.layers.Dense(20, activation = tf.nn.sigmoid),
    tf.keras.layers.Dense(1, activation = tf.nn.sigmoid),
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics = ['binary_accuracy'])

In [25]:
tb_hist = tf.keras.callbacks.TensorBoard(log_dir='./logs/xor_logs', histogram_freq=0, write_graph=True, write_images=True)

model.fit(x_data, y_data, epochs=500, callbacks=[tb_hist])

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

Epoch 80/500
Epoch 81/500
Epoch 82/500
Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500
Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 130/500
Epoch 131/500
Epoch 132/500
Epoch 133/500
Epoch 134/500
Epoch 135/500
Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/50

Epoch 157/500
Epoch 158/500
Epoch 159/500
Epoch 160/500
Epoch 161/500
Epoch 162/500
Epoch 163/500
Epoch 164/500
Epoch 165/500
Epoch 166/500
Epoch 167/500
Epoch 168/500
Epoch 169/500
Epoch 170/500
Epoch 171/500
Epoch 172/500
Epoch 173/500
Epoch 174/500
Epoch 175/500
Epoch 176/500
Epoch 177/500
Epoch 178/500
Epoch 179/500
Epoch 180/500
Epoch 181/500
Epoch 182/500
Epoch 183/500
Epoch 184/500
Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
Epoch 195/500
Epoch 196/500
Epoch 197/500
Epoch 198/500
Epoch 199/500
Epoch 200/500
Epoch 201/500
Epoch 202/500
Epoch 203/500
Epoch 204/500
Epoch 205/500
Epoch 206/500
Epoch 207/500
Epoch 208/500
Epoch 209/500
Epoch 210/500
Epoch 211/500
Epoch 212/500
Epoch 213/500
Epoch 214/500
Epoch 215/500
Epoch 216/500
Epoch 217/500
Epoch 218/500
Epoch 219/500
Epoch 220/500
Epoch 221/500
Epoch 222/500
Epoch 223/500
Epoch 224/500
Epoch 225/500
Epoch 226/500
Epoch 227/500
Epoch 

Epoch 234/500
Epoch 235/500
Epoch 236/500
Epoch 237/500
Epoch 238/500
Epoch 239/500
Epoch 240/500
Epoch 241/500
Epoch 242/500
Epoch 243/500
Epoch 244/500
Epoch 245/500
Epoch 246/500
Epoch 247/500
Epoch 248/500
Epoch 249/500
Epoch 250/500
Epoch 251/500
Epoch 252/500
Epoch 253/500
Epoch 254/500
Epoch 255/500
Epoch 256/500
Epoch 257/500
Epoch 258/500
Epoch 259/500
Epoch 260/500
Epoch 261/500
Epoch 262/500
Epoch 263/500
Epoch 264/500
Epoch 265/500
Epoch 266/500
Epoch 267/500
Epoch 268/500
Epoch 269/500
Epoch 270/500
Epoch 271/500
Epoch 272/500
Epoch 273/500
Epoch 274/500
Epoch 275/500
Epoch 276/500
Epoch 277/500
Epoch 278/500
Epoch 279/500
Epoch 280/500
Epoch 281/500
Epoch 282/500
Epoch 283/500
Epoch 284/500
Epoch 285/500
Epoch 286/500
Epoch 287/500
Epoch 288/500
Epoch 289/500
Epoch 290/500
Epoch 291/500
Epoch 292/500
Epoch 293/500
Epoch 294/500
Epoch 295/500
Epoch 296/500
Epoch 297/500
Epoch 298/500
Epoch 299/500
Epoch 300/500
Epoch 301/500
Epoch 302/500
Epoch 303/500
Epoch 304/500
Epoch 

Epoch 311/500
Epoch 312/500
Epoch 313/500
Epoch 314/500
Epoch 315/500
Epoch 316/500
Epoch 317/500
Epoch 318/500
Epoch 319/500
Epoch 320/500
Epoch 321/500
Epoch 322/500
Epoch 323/500
Epoch 324/500
Epoch 325/500
Epoch 326/500
Epoch 327/500
Epoch 328/500
Epoch 329/500
Epoch 330/500
Epoch 331/500
Epoch 332/500
Epoch 333/500
Epoch 334/500
Epoch 335/500
Epoch 336/500
Epoch 337/500
Epoch 338/500
Epoch 339/500
Epoch 340/500
Epoch 341/500
Epoch 342/500
Epoch 343/500
Epoch 344/500
Epoch 345/500
Epoch 346/500
Epoch 347/500
Epoch 348/500
Epoch 349/500
Epoch 350/500
Epoch 351/500
Epoch 352/500
Epoch 353/500
Epoch 354/500
Epoch 355/500
Epoch 356/500
Epoch 357/500
Epoch 358/500
Epoch 359/500
Epoch 360/500
Epoch 361/500
Epoch 362/500
Epoch 363/500
Epoch 364/500
Epoch 365/500
Epoch 366/500
Epoch 367/500
Epoch 368/500
Epoch 369/500
Epoch 370/500
Epoch 371/500
Epoch 372/500
Epoch 373/500
Epoch 374/500
Epoch 375/500
Epoch 376/500
Epoch 377/500
Epoch 378/500
Epoch 379/500
Epoch 380/500
Epoch 381/500
Epoch 

Epoch 388/500
Epoch 389/500
Epoch 390/500
Epoch 391/500
Epoch 392/500
Epoch 393/500
Epoch 394/500
Epoch 395/500
Epoch 396/500
Epoch 397/500
Epoch 398/500
Epoch 399/500
Epoch 400/500
Epoch 401/500
Epoch 402/500
Epoch 403/500
Epoch 404/500
Epoch 405/500
Epoch 406/500
Epoch 407/500
Epoch 408/500
Epoch 409/500
Epoch 410/500
Epoch 411/500
Epoch 412/500
Epoch 413/500
Epoch 414/500
Epoch 415/500
Epoch 416/500
Epoch 417/500
Epoch 418/500
Epoch 419/500
Epoch 420/500
Epoch 421/500
Epoch 422/500
Epoch 423/500
Epoch 424/500
Epoch 425/500
Epoch 426/500
Epoch 427/500
Epoch 428/500
Epoch 429/500
Epoch 430/500
Epoch 431/500
Epoch 432/500
Epoch 433/500
Epoch 434/500
Epoch 435/500
Epoch 436/500
Epoch 437/500
Epoch 438/500
Epoch 439/500
Epoch 440/500
Epoch 441/500
Epoch 442/500
Epoch 443/500
Epoch 444/500
Epoch 445/500
Epoch 446/500
Epoch 447/500
Epoch 448/500
Epoch 449/500
Epoch 450/500
Epoch 451/500
Epoch 452/500
Epoch 453/500
Epoch 454/500
Epoch 455/500
Epoch 456/500
Epoch 457/500
Epoch 458/500
Epoch 

Epoch 465/500
Epoch 466/500
Epoch 467/500
Epoch 468/500
Epoch 469/500
Epoch 470/500
Epoch 471/500
Epoch 472/500
Epoch 473/500
Epoch 474/500
Epoch 475/500
Epoch 476/500
Epoch 477/500
Epoch 478/500
Epoch 479/500
Epoch 480/500
Epoch 481/500
Epoch 482/500
Epoch 483/500
Epoch 484/500
Epoch 485/500
Epoch 486/500
Epoch 487/500
Epoch 488/500
Epoch 489/500
Epoch 490/500
Epoch 491/500
Epoch 492/500
Epoch 493/500
Epoch 494/500
Epoch 495/500
Epoch 496/500
Epoch 497/500
Epoch 498/500
Epoch 499/500
Epoch 500/500


<tensorflow.python.keras.callbacks.History at 0x1e6703ca220>

# Relu, weight initialzation, Dropout, Batch norm.

In [None]:
# xavier, he weight

np.random.randn(fan_in, fan_out) / np.sqrt(fan_in)

np.random.randn(fan_in, fan_out) / np.sqrt(fan_in/2)

In [2]:
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from time import time
import os

In [34]:
def load(model, checkpoint_dir):
    print(" [*] Reading checkpoints...")
    
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        checkpoint = tf.train.Checkpoint(dnn=model)
        checkpoint.restore(save_path=os.path.join(checkpoint_dir, ckpt_name))
        counter = int(ckpt_name.split('-')[1])
        print( "[*] success to read {}".format(ckpt_name))
        return True, counter
    else:
        print( "[*] failed to find a checkpoint")
        return False, 0
    
def check_folder(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)
    return dir

In [54]:
def flatten():
    return tf.keras.layers.Flatten()

def dense(label_dim, weight_init):
    return tf.keras.layers.Dense(units=label_dim, use_bias=True, kernel_initializer=weight_init)

def sigmoid():
    return tf.keras.layers.Activation(tf.keras.activations.sigmoid)

def relu():
    return tf.keras.layers.Activation(tf.keras.activations.relu)

In [53]:
def load_mnist():
    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()
    train_data = np.expand_dims(train_data, axis=-1) #채널 위치를 마지막에 생성하는 것 = axis = -1
    # [N, 28, 28]  #채널 1이 생략됨 -> [N, 28, 28, 1] #batch_size, height, width, channel
    test_data = np.expand_dims(test_data, axis=-1)
    
    train_data, test_data = normalize(train_data, test_data)
    
    train_labels = to_categorical(train_labels, 10) # [N,] -> [N, 10] #총 라벨 수 one hot encoding
    test_labels = to_categorical(test_labels, 10)
    
    return train_data, train_labels, test_data, test_labels

def normalize(train_data, test_data):
    train_data = train_data.astype(np.float32) / 255.0
    test_data = test_data.astype(np.float32) / 255.0
    
    return train_data, test_data

def loss_fn(model, images, labels):
    logits = model(images, training=True)
    loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_pred=logits, y_true=labels, from_logits=True)) 
    return loss

def accuracy_fn(model, images, labels):
    logits = model(images, training=False)
    prediction = tf.equal(tf.argmax(logits, -1), tf.argmax(labels,-1)) #각 batch_size를 할 때 가장 큰 값들을 추출해서 같은지 비교
    accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32)) #true, false 값을 숫자로 바꿔줌
    return accuracy

def grad(model, images, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(model, images, labels)
    return tape.gradient(loss, model.variables)

In [55]:
class create_model_class(tf.keras.Model): #sigmoid versoin
    def __init__(self, label_dim):
        super(create_model_class, self).__init__()
        weight_init = tf.keras.initializers.RandomNormal() #일반 버전 (평균0, 분산1인 정규분포)
        #weight_init = tf.keras.initializers.GlorotUniform() #xavier 버전 #HeUniform() #HE 버전

        self.model = tf.keras.Sequential()
        self.model.add(flatten()) #[N,28,28,1] -> [N,784] #convolution이 아니므로 flatten 필요
        
        for ii in range(2):
            self.model.add(dense(256, weight_init))
            self.model.add(sigmoid())
            
        self.model.add(dense(label_dim, weight_init))
        
    def call(self, x, training=None, mask=None):
        x = self.model(x)
        return x

In [56]:
def create_model_function(label_dim) : #relu version
    weight_init = tf.keras.initializers.RandomNormal()

    model = tf.keras.Sequential()
    model.add(flatten())

    for i in range(2) :
        model.add(dense(256, weight_init))
        model.add(relu())

    model.add(dense(label_dim, weight_init))

    return model

In [63]:
train_x, train_y, test_x, test_y = load_mnist()

learning_rate = 0.001
batch_size = 128

training_epochs = 1
training_iterations = len(train_x) // batch_size

label_dim = 10
train_flag = True

# shuffle : 주어진 train, test 데이터 수보다 크면 됨. 주어진 데이터가 shuffle이 되는 것
# prefetch : batch_size만큼 데이터 미리 올려놓기
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)).\
    shuffle(buffer_size = 100000).\
    prefetch(buffer_size=batch_size).\
    batch(batch_size, drop_remainder=True)

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y)).\
    shuffle(buffer_size = 100000).\
    prefetch(buffer_size=len(test_x)).\
    batch(len(test_x))

In [24]:
# model
network = create_model_function(label_dim) # create_model_class(label_dim)

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

checkpoint_dir = 'checkpoints'
logs_dir = 'logs'

model_dir = 'nn_deep' #'nn_softmax'

checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
check_folder(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, model_dir)
logs_dir = os.path.join(logs_dir, model_dir)

In [25]:
if train_flag : 
    checkpoint = tf.train.Checkpoint(dnn=network)
    
    summary_writer = tf.summary.create_file_writer(logdir=logs_dir)
    start_time = time()
    
    could_load, checkpoint_counter = load(network, checkpoint_dir)
    
    ## restore check-point 
    if could_load:
        start_epoch = (int)(checkpoint_counter / training_iterations)
        counter = checkpoint_counter
        print(" [*] load success")
    else:
        start_epoch = 0
        start_iteration = 0
        counter = 0
        print(" [!] load failed")
    
    #train
    with summary_writer.as_default(): # for tensorboard
        for epoch in range(start_epoch, training_epochs):
            for idx, (train_input, train_label) in enumerate(train_dataset):
                grads = grad(network, train_input, train_label)
                optimizer.apply_gradients(grads_and_vars=zip(grads, network.variables))
                
                train_loss = loss_fn(network, train_input, train_label)
                train_accuracy = accuracy_fn(network, train_input, train_label)
                
                for test_input, test_label in test_dataset:
                    test_accuracy = accuracy_fn(network, test_input, test_label)
                    
                tf.summary.scalar(name='train_loss', data=train_loss, step=counter)
                tf.summary.scalar(name='train_accuracy', data=train_accuracy, step=counter)
                tf.summary.scalar(name='test_accuracy', data=test_accuracy, step=counter)
                
                print("epoch : [%2d] [%5d/%5d] time : %4.4f, \
                      train_loss : %.8f, train_accuracy : %.4f, test_accuracy : %.4f"\
                      % (epoch, idx, training_iterations, time() - start_time, 
                         train_loss, train_accuracy, test_accuracy))
                counter += 1
        checkpoint.save(file_prefix = checkpoint_prefix + '-{}'.format(counter))    
    
else:
    _, _ = load(network, checkpoint_dir)
    for test_input, test_label in test_dataset:
        test_accuracy = accuracy_fn(network, test_input, test_label)
        
    print("test_Accuracy : %.4f" % (test_accuracy))

 [*] Reading checkpoints...
[*] failed to find a checkpoint
 [!] load failed
epoch : [ 0] [    0/  468] time : 0.4924,                       train_loss : 2.17003441, train_accuracy : 0.3984, test_accuracy : 0.2721
epoch : [ 0] [    1/  468] time : 0.6295,                       train_loss : 2.11864138, train_accuracy : 0.5156, test_accuracy : 0.4543
epoch : [ 0] [    2/  468] time : 0.7689,                       train_loss : 2.03002572, train_accuracy : 0.6094, test_accuracy : 0.5071
epoch : [ 0] [    3/  468] time : 0.9121,                       train_loss : 2.00755143, train_accuracy : 0.6172, test_accuracy : 0.5519
epoch : [ 0] [    4/  468] time : 1.0628,                       train_loss : 1.92455721, train_accuracy : 0.6172, test_accuracy : 0.5997
epoch : [ 0] [    5/  468] time : 1.2362,                       train_loss : 1.79077244, train_accuracy : 0.6484, test_accuracy : 0.6269
epoch : [ 0] [    6/  468] time : 1.3814,                       train_loss : 1.64151549, train_accura

epoch : [ 0] [   60/  468] time : 9.3400,                       train_loss : 0.31363285, train_accuracy : 0.9297, test_accuracy : 0.8972
epoch : [ 0] [   61/  468] time : 9.4985,                       train_loss : 0.23443964, train_accuracy : 0.9297, test_accuracy : 0.8961
epoch : [ 0] [   62/  468] time : 9.6502,                       train_loss : 0.42477441, train_accuracy : 0.8672, test_accuracy : 0.9002
epoch : [ 0] [   63/  468] time : 9.8226,                       train_loss : 0.36430979, train_accuracy : 0.8828, test_accuracy : 0.9029
epoch : [ 0] [   64/  468] time : 9.9927,                       train_loss : 0.28142422, train_accuracy : 0.9141, test_accuracy : 0.9037
epoch : [ 0] [   65/  468] time : 10.1633,                       train_loss : 0.41613072, train_accuracy : 0.8359, test_accuracy : 0.9019
epoch : [ 0] [   66/  468] time : 10.3303,                       train_loss : 0.33749321, train_accuracy : 0.8516, test_accuracy : 0.8945
epoch : [ 0] [   67/  468] time : 10.50

epoch : [ 0] [  120/  468] time : 19.9216,                       train_loss : 0.19459637, train_accuracy : 0.9531, test_accuracy : 0.9215
epoch : [ 0] [  121/  468] time : 20.0911,                       train_loss : 0.22545929, train_accuracy : 0.9297, test_accuracy : 0.9213
epoch : [ 0] [  122/  468] time : 20.2657,                       train_loss : 0.27741760, train_accuracy : 0.8906, test_accuracy : 0.9206
epoch : [ 0] [  123/  468] time : 20.4462,                       train_loss : 0.24888684, train_accuracy : 0.8984, test_accuracy : 0.9230
epoch : [ 0] [  124/  468] time : 20.6294,                       train_loss : 0.14010113, train_accuracy : 0.9688, test_accuracy : 0.9230
epoch : [ 0] [  125/  468] time : 20.8073,                       train_loss : 0.26481646, train_accuracy : 0.8906, test_accuracy : 0.9234
epoch : [ 0] [  126/  468] time : 20.9794,                       train_loss : 0.18137035, train_accuracy : 0.9609, test_accuracy : 0.9223
epoch : [ 0] [  127/  468] time : 

epoch : [ 0] [  181/  468] time : 30.7597,                       train_loss : 0.21983993, train_accuracy : 0.9375, test_accuracy : 0.9338
epoch : [ 0] [  182/  468] time : 30.9347,                       train_loss : 0.16975103, train_accuracy : 0.9453, test_accuracy : 0.9312
epoch : [ 0] [  183/  468] time : 31.1073,                       train_loss : 0.22752656, train_accuracy : 0.9453, test_accuracy : 0.9292
epoch : [ 0] [  184/  468] time : 31.2785,                       train_loss : 0.30254561, train_accuracy : 0.9219, test_accuracy : 0.9281
epoch : [ 0] [  185/  468] time : 31.4504,                       train_loss : 0.24360660, train_accuracy : 0.9297, test_accuracy : 0.9305
epoch : [ 0] [  186/  468] time : 31.6384,                       train_loss : 0.12300718, train_accuracy : 0.9531, test_accuracy : 0.9324
epoch : [ 0] [  187/  468] time : 31.8152,                       train_loss : 0.15905406, train_accuracy : 0.9531, test_accuracy : 0.9323
epoch : [ 0] [  188/  468] time : 

epoch : [ 0] [  241/  468] time : 41.3934,                       train_loss : 0.16551644, train_accuracy : 0.9453, test_accuracy : 0.9467
epoch : [ 0] [  242/  468] time : 41.5650,                       train_loss : 0.17163163, train_accuracy : 0.9453, test_accuracy : 0.9458
epoch : [ 0] [  243/  468] time : 41.7387,                       train_loss : 0.25054002, train_accuracy : 0.9531, test_accuracy : 0.9453
epoch : [ 0] [  244/  468] time : 41.9051,                       train_loss : 0.19649269, train_accuracy : 0.9297, test_accuracy : 0.9445
epoch : [ 0] [  245/  468] time : 42.0811,                       train_loss : 0.17635921, train_accuracy : 0.9609, test_accuracy : 0.9435
epoch : [ 0] [  246/  468] time : 42.2491,                       train_loss : 0.09371889, train_accuracy : 0.9766, test_accuracy : 0.9421
epoch : [ 0] [  247/  468] time : 42.4256,                       train_loss : 0.28298607, train_accuracy : 0.9297, test_accuracy : 0.9414
epoch : [ 0] [  248/  468] time : 

epoch : [ 0] [  301/  468] time : 51.8506,                       train_loss : 0.12200910, train_accuracy : 0.9688, test_accuracy : 0.9424
epoch : [ 0] [  302/  468] time : 52.0254,                       train_loss : 0.12722668, train_accuracy : 0.9688, test_accuracy : 0.9403
epoch : [ 0] [  303/  468] time : 52.1956,                       train_loss : 0.30962256, train_accuracy : 0.9141, test_accuracy : 0.9417
epoch : [ 0] [  304/  468] time : 52.3665,                       train_loss : 0.22113830, train_accuracy : 0.9219, test_accuracy : 0.9447
epoch : [ 0] [  305/  468] time : 52.5412,                       train_loss : 0.23899007, train_accuracy : 0.9141, test_accuracy : 0.9489
epoch : [ 0] [  306/  468] time : 52.7111,                       train_loss : 0.19606918, train_accuracy : 0.9453, test_accuracy : 0.9513
epoch : [ 0] [  307/  468] time : 52.8883,                       train_loss : 0.16218829, train_accuracy : 0.9688, test_accuracy : 0.9488
epoch : [ 0] [  308/  468] time : 

epoch : [ 0] [  361/  468] time : 62.3860,                       train_loss : 0.30768526, train_accuracy : 0.9375, test_accuracy : 0.9539
epoch : [ 0] [  362/  468] time : 62.5667,                       train_loss : 0.05968004, train_accuracy : 0.9922, test_accuracy : 0.9508
epoch : [ 0] [  363/  468] time : 62.7396,                       train_loss : 0.16167252, train_accuracy : 0.9297, test_accuracy : 0.9490
epoch : [ 0] [  364/  468] time : 62.9106,                       train_loss : 0.24195534, train_accuracy : 0.9062, test_accuracy : 0.9481
epoch : [ 0] [  365/  468] time : 63.0812,                       train_loss : 0.16209576, train_accuracy : 0.9688, test_accuracy : 0.9494
epoch : [ 0] [  366/  468] time : 63.2524,                       train_loss : 0.18687978, train_accuracy : 0.9375, test_accuracy : 0.9526
epoch : [ 0] [  367/  468] time : 63.4235,                       train_loss : 0.20791379, train_accuracy : 0.9453, test_accuracy : 0.9545
epoch : [ 0] [  368/  468] time : 

epoch : [ 0] [  422/  468] time : 73.5504,                       train_loss : 0.09635494, train_accuracy : 0.9609, test_accuracy : 0.9581
epoch : [ 0] [  423/  468] time : 73.7252,                       train_loss : 0.12973526, train_accuracy : 0.9688, test_accuracy : 0.9572
epoch : [ 0] [  424/  468] time : 73.9004,                       train_loss : 0.12899821, train_accuracy : 0.9688, test_accuracy : 0.9561
epoch : [ 0] [  425/  468] time : 74.0711,                       train_loss : 0.10344458, train_accuracy : 0.9766, test_accuracy : 0.9554
epoch : [ 0] [  426/  468] time : 74.2404,                       train_loss : 0.22725120, train_accuracy : 0.9219, test_accuracy : 0.9576
epoch : [ 0] [  427/  468] time : 74.4186,                       train_loss : 0.10041623, train_accuracy : 0.9844, test_accuracy : 0.9575
epoch : [ 0] [  428/  468] time : 74.6043,                       train_loss : 0.13157025, train_accuracy : 0.9688, test_accuracy : 0.9582
epoch : [ 0] [  429/  468] time : 

epoch : [ 1] [   14/  468] time : 84.4382,                       train_loss : 0.10116888, train_accuracy : 0.9766, test_accuracy : 0.9621
epoch : [ 1] [   15/  468] time : 84.6181,                       train_loss : 0.08485142, train_accuracy : 0.9844, test_accuracy : 0.9628
epoch : [ 1] [   16/  468] time : 84.7996,                       train_loss : 0.10031894, train_accuracy : 0.9609, test_accuracy : 0.9625
epoch : [ 1] [   17/  468] time : 84.9770,                       train_loss : 0.16654216, train_accuracy : 0.9453, test_accuracy : 0.9615
epoch : [ 1] [   18/  468] time : 85.1503,                       train_loss : 0.03840597, train_accuracy : 1.0000, test_accuracy : 0.9612
epoch : [ 1] [   19/  468] time : 85.3295,                       train_loss : 0.07530543, train_accuracy : 0.9688, test_accuracy : 0.9607
epoch : [ 1] [   20/  468] time : 85.5186,                       train_loss : 0.13625729, train_accuracy : 0.9531, test_accuracy : 0.9608
epoch : [ 1] [   21/  468] time : 

epoch : [ 1] [   75/  468] time : 95.5578,                       train_loss : 0.05337565, train_accuracy : 0.9844, test_accuracy : 0.9604
epoch : [ 1] [   76/  468] time : 95.7279,                       train_loss : 0.12064378, train_accuracy : 0.9609, test_accuracy : 0.9600
epoch : [ 1] [   77/  468] time : 95.9006,                       train_loss : 0.11114351, train_accuracy : 0.9609, test_accuracy : 0.9592
epoch : [ 1] [   78/  468] time : 96.0810,                       train_loss : 0.12150731, train_accuracy : 0.9453, test_accuracy : 0.9602
epoch : [ 1] [   79/  468] time : 96.2608,                       train_loss : 0.19126301, train_accuracy : 0.9375, test_accuracy : 0.9605
epoch : [ 1] [   80/  468] time : 96.4387,                       train_loss : 0.14370783, train_accuracy : 0.9531, test_accuracy : 0.9609
epoch : [ 1] [   81/  468] time : 96.6206,                       train_loss : 0.08821221, train_accuracy : 0.9766, test_accuracy : 0.9613
epoch : [ 1] [   82/  468] time : 

epoch : [ 1] [  135/  468] time : 106.2751,                       train_loss : 0.15643750, train_accuracy : 0.9844, test_accuracy : 0.9651
epoch : [ 1] [  136/  468] time : 106.4530,                       train_loss : 0.10141179, train_accuracy : 0.9531, test_accuracy : 0.9635
epoch : [ 1] [  137/  468] time : 106.6372,                       train_loss : 0.09987684, train_accuracy : 0.9766, test_accuracy : 0.9630
epoch : [ 1] [  138/  468] time : 106.8177,                       train_loss : 0.13259697, train_accuracy : 0.9531, test_accuracy : 0.9627
epoch : [ 1] [  139/  468] time : 106.9897,                       train_loss : 0.15069520, train_accuracy : 0.9609, test_accuracy : 0.9641
epoch : [ 1] [  140/  468] time : 107.1696,                       train_loss : 0.13260543, train_accuracy : 0.9531, test_accuracy : 0.9650
epoch : [ 1] [  141/  468] time : 107.3450,                       train_loss : 0.06460799, train_accuracy : 0.9844, test_accuracy : 0.9654
epoch : [ 1] [  142/  468] 

epoch : [ 1] [  194/  468] time : 117.6260,                       train_loss : 0.10373113, train_accuracy : 0.9688, test_accuracy : 0.9657
epoch : [ 1] [  195/  468] time : 117.7959,                       train_loss : 0.09192932, train_accuracy : 0.9688, test_accuracy : 0.9655
epoch : [ 1] [  196/  468] time : 117.9764,                       train_loss : 0.09297705, train_accuracy : 0.9688, test_accuracy : 0.9659
epoch : [ 1] [  197/  468] time : 118.1505,                       train_loss : 0.06428191, train_accuracy : 0.9766, test_accuracy : 0.9661
epoch : [ 1] [  198/  468] time : 118.3240,                       train_loss : 0.18161784, train_accuracy : 0.9375, test_accuracy : 0.9663
epoch : [ 1] [  199/  468] time : 118.5068,                       train_loss : 0.06718634, train_accuracy : 0.9766, test_accuracy : 0.9658
epoch : [ 1] [  200/  468] time : 118.6824,                       train_loss : 0.02800309, train_accuracy : 1.0000, test_accuracy : 0.9657
epoch : [ 1] [  201/  468] 

epoch : [ 1] [  254/  468] time : 128.2895,                       train_loss : 0.07064471, train_accuracy : 0.9688, test_accuracy : 0.9675
epoch : [ 1] [  255/  468] time : 128.4706,                       train_loss : 0.06221455, train_accuracy : 0.9844, test_accuracy : 0.9668
epoch : [ 1] [  256/  468] time : 128.6504,                       train_loss : 0.05690335, train_accuracy : 0.9844, test_accuracy : 0.9668
epoch : [ 1] [  257/  468] time : 128.8281,                       train_loss : 0.10862181, train_accuracy : 0.9766, test_accuracy : 0.9670
epoch : [ 1] [  258/  468] time : 129.0040,                       train_loss : 0.06842864, train_accuracy : 0.9844, test_accuracy : 0.9672
epoch : [ 1] [  259/  468] time : 129.1791,                       train_loss : 0.05488730, train_accuracy : 0.9844, test_accuracy : 0.9674
epoch : [ 1] [  260/  468] time : 129.3572,                       train_loss : 0.04284908, train_accuracy : 0.9922, test_accuracy : 0.9683
epoch : [ 1] [  261/  468] 

epoch : [ 1] [  314/  468] time : 138.8662,                       train_loss : 0.05031300, train_accuracy : 0.9766, test_accuracy : 0.9665
epoch : [ 1] [  315/  468] time : 139.0428,                       train_loss : 0.08400520, train_accuracy : 0.9844, test_accuracy : 0.9656
epoch : [ 1] [  316/  468] time : 139.2249,                       train_loss : 0.10228336, train_accuracy : 0.9531, test_accuracy : 0.9644
epoch : [ 1] [  317/  468] time : 139.3987,                       train_loss : 0.03948309, train_accuracy : 0.9922, test_accuracy : 0.9638
epoch : [ 1] [  318/  468] time : 139.5868,                       train_loss : 0.11303806, train_accuracy : 0.9688, test_accuracy : 0.9646
epoch : [ 1] [  319/  468] time : 139.7651,                       train_loss : 0.09874314, train_accuracy : 0.9688, test_accuracy : 0.9657
epoch : [ 1] [  320/  468] time : 139.9409,                       train_loss : 0.10915279, train_accuracy : 0.9609, test_accuracy : 0.9684
epoch : [ 1] [  321/  468] 

epoch : [ 1] [  374/  468] time : 149.5184,                       train_loss : 0.04675063, train_accuracy : 0.9922, test_accuracy : 0.9669
epoch : [ 1] [  375/  468] time : 149.7058,                       train_loss : 0.07385892, train_accuracy : 0.9688, test_accuracy : 0.9669
epoch : [ 1] [  376/  468] time : 149.8752,                       train_loss : 0.06809467, train_accuracy : 0.9922, test_accuracy : 0.9681
epoch : [ 1] [  377/  468] time : 150.0468,                       train_loss : 0.09476579, train_accuracy : 0.9766, test_accuracy : 0.9686
epoch : [ 1] [  378/  468] time : 150.2236,                       train_loss : 0.09446399, train_accuracy : 0.9766, test_accuracy : 0.9691
epoch : [ 1] [  379/  468] time : 150.3957,                       train_loss : 0.12047806, train_accuracy : 0.9688, test_accuracy : 0.9698
epoch : [ 1] [  380/  468] time : 150.5806,                       train_loss : 0.09142242, train_accuracy : 0.9766, test_accuracy : 0.9711
epoch : [ 1] [  381/  468] 

epoch : [ 1] [  433/  468] time : 160.2611,                       train_loss : 0.07638073, train_accuracy : 0.9688, test_accuracy : 0.9702
epoch : [ 1] [  434/  468] time : 160.4350,                       train_loss : 0.05805014, train_accuracy : 0.9844, test_accuracy : 0.9706
epoch : [ 1] [  435/  468] time : 160.6146,                       train_loss : 0.08730160, train_accuracy : 0.9844, test_accuracy : 0.9708
epoch : [ 1] [  436/  468] time : 160.8053,                       train_loss : 0.07182668, train_accuracy : 0.9844, test_accuracy : 0.9703
epoch : [ 1] [  437/  468] time : 161.0034,                       train_loss : 0.04581804, train_accuracy : 0.9922, test_accuracy : 0.9705
epoch : [ 1] [  438/  468] time : 161.1981,                       train_loss : 0.15365265, train_accuracy : 0.9688, test_accuracy : 0.9709
epoch : [ 1] [  439/  468] time : 161.3987,                       train_loss : 0.08620869, train_accuracy : 0.9766, test_accuracy : 0.9713
epoch : [ 1] [  440/  468] 

## nn_deep, dropout, batchnorm

In [79]:
def dropout(rate) :
    return tf.keras.layers.Dropout(rate)

In [80]:
def batch_norm() :
    return tf.keras.layers.BatchNormalization()

In [102]:
class nn_create_model_class(tf.keras.Model): #sigmoid versoin
    def __init__(self, label_dim):
        super(nn_create_model_class, self).__init__()
        weight_init = tf.keras.initializers.glorot_uniform() #xavier 버전

        self.model = tf.keras.Sequential()
        self.model.add(flatten())
        
        for ii in range(4):
            # 전부 다 시용시 96.71%
            self.model.add(dense(512, weight_init))
            self.model.add(batch_norm()) # batchnorm만 추가 test acc 96.84% 
            self.model.add(relu()) #batchnorm, dropout 제외시 96.88%
            #self.model.add(dropout(rate=0.2)) #해당 비율을 0으로 만드는 것 # dropout 0.5 추가 test acc 95.91 #0.2로 변경시 96.29
            
            #layer -> norm -> activation
            
        self.model.add(dense(label_dim, weight_init))
        
    def call(self, x, training=None, mask=None):
        x = self.model(x)
        return x

In [103]:
# model
network = nn_create_model_class(label_dim) # create_model_class(label_dim)

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

checkpoint_dir = 'checkpoints'
logs_dir = 'logs'

model_dir = 'nn_deep_batch' #'nn_softmax'

checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
check_folder(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, model_dir)
logs_dir = os.path.join(logs_dir, model_dir)

In [105]:
if train_flag :

    checkpoint = tf.train.Checkpoint(dnn=network)

    # create writer for tensorboard
    summary_writer = tf.summary.create_file_writer(logdir=logs_dir)
    start_time = time()

    # restore check-point if it exits
    could_load, checkpoint_counter = load(network, checkpoint_dir)    

    if could_load:
        start_epoch = (int)(checkpoint_counter / training_iterations)        
        counter = checkpoint_counter        
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        start_iteration = 0
        counter = 0
        print(" [!] Load failed...")
    
    # train phase
    with summary_writer.as_default():  # for tensorboard
        for epoch in range(start_epoch, training_epochs):
            for idx, (train_input, train_label) in enumerate(train_dataset):            
                grads = grad(network, train_input, train_label)
                optimizer.apply_gradients(grads_and_vars=zip(grads, network.variables))

                train_loss = loss_fn(network, train_input, train_label)
                train_accuracy = accuracy_fn(network, train_input, train_label)
                
                for test_input, test_label in test_dataset:                
                    test_accuracy = accuracy_fn(network, test_input, test_label)

                tf.summary.scalar(name='train_loss', data=train_loss, step=counter)
                tf.summary.scalar(name='train_accuracy', data=train_accuracy, step=counter)
                tf.summary.scalar(name='test_accuracy', data=test_accuracy, step=counter)

                print(
                    "Epoch: [%2d] [%5d/%5d] time: %4.4f, train_loss: %.8f, train_accuracy: %.4f, test_Accuracy: %.4f" \
                    % (epoch, idx, training_iterations, time() - start_time, train_loss, train_accuracy,
                       test_accuracy))
                counter += 1                
        checkpoint.save(file_prefix=checkpoint_prefix + '-{}'.format(counter))
        
# test phase      
else :
    _, _ = load(network, checkpoint_dir)
    for test_input, test_label in test_dataset:    
        test_accuracy = accuracy_fn(network, test_input, test_label)

    print("test_Accuracy: %.4f" % (test_accuracy))