<a href="https://colab.research.google.com/github/jaeyukkim/TF-study/blob/main/Untitled46_ipynb%EC%9D%98_%EC%82%AC%EB%B3%B8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from termcolor import colored
import tensorflow_datasets as tfds 

from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D
from tensorflow.keras.layers import Activation, Flatten, Dropout
from tensorflow.keras.metrics import SparseCategoricalAccuracy, Mean
from tensorflow.keras.models import Model, Sequential

class Residualunit(tf.keras.Model):
  """Residual unit: two (BatchNorm -> ReLU -> Conv2D) stages plus a shortcut.

  Args:
    filter_in: Number of channels of the tensor fed to the unit.
    filter_out: Number of filters used by both convolutions.
    kernel_size: Kernel size passed to both convolutions.
  """

  def __init__(self, filter_in, filter_out, kernel_size):
    super(Residualunit, self).__init__()
    # BatchNormalization is reached through `tf` because the file's import
    # list does not bring that name into scope.
    self.bn1 = tf.keras.layers.BatchNormalization()
    self.conv1 = Conv2D(filter_out, kernel_size, padding='same')

    self.bn2 = tf.keras.layers.BatchNormalization()
    self.conv2 = Conv2D(filter_out, kernel_size, padding='same')

    # The shortcut must yield `filter_out` channels so it can be added to the
    # convolution branch: pass the input through unchanged when the channel
    # counts already match, otherwise project it with a 1x1 convolution.
    if filter_in == filter_out:
      self.identity = lambda x: x
    else:
      self.identity = Conv2D(filter_out, (1, 1), padding='same')

  def call(self, x, training=False, mask=None):
    """Apply the unit to `x`.

    Args:
      x: Input tensor.
      training: Forwarded to the batch-normalization layers.
      mask: Unused; accepted for signature compatibility.

    Returns:
      Sum of the shortcut branch and the convolution branch.
    """
    h = self.bn1(x, training=training)
    h = tf.nn.relu(h)
    h = self.conv1(h)

    h = self.bn2(h, training=training)
    h = tf.nn.relu(h)
    h = self.conv2(h)

    return self.identity(x) + h


class ResnetLayer(Model):
  """Stack of `Residualunit`s applied one after another.

  Args:
    filter_in: Number of channels of the tensor fed to the first unit.
    filters: Sequence of output filter counts, one residual unit per entry.
      Each unit takes the previous unit's filter count as its input count.
    kernel_size: Kernel size passed to every unit.
  """

  def __init__(self, filter_in, filters, kernel_size):
    super(ResnetLayer, self).__init__()
    # Pair every output width with the width that precedes it:
    # (filter_in, filters[0]), (filters[0], filters[1]), ...
    # zip() stops at the shorter sequence, so the surplus last entry of the
    # input-width list is ignored.
    input_widths = [filter_in] + list(filters)
    self.sequence = [
        Residualunit(f_in, f_out, kernel_size)
        for f_in, f_out in zip(input_widths, filters)
    ]

  def call(self, x, training=False, mask=None):
    """Run `x` through every residual unit in order.

    Args:
      x: Input tensor.
      training: Forwarded to each unit.
      mask: Unused; accepted for signature compatibility.

    Returns:
      Output of the last residual unit (or `x` itself if there are no units).
    """
    for unit in self.sequence:
      x = unit(x, training=training)
    # Return only after the loop so that all units are applied, not just the
    # first one.
    return x


class ResNet(Model):
  """Small residual CNN ending in a 10-way softmax classifier.

  Layout: Conv2D stem -> three `ResnetLayer` stages (16, 32 and 64 filters)
  with 2x2 max pooling after the first two -> Flatten -> Dense(128, relu)
  -> Dense(10, softmax).

  The shape comments below assume a 28x28 input image.
  """

  def __init__(self):
    super(ResNet, self).__init__()
    self.conv1 = Conv2D(8, (3, 3), padding='same', activation='relu')  # 28x28x8

    self.res1 = ResnetLayer(8, (16, 16), (3, 3))  # 28x28x16
    # MaxPooling2D is the pooling layer name this file actually imports.
    self.pool1 = MaxPooling2D((2, 2))  # 14x14x16

    self.res2 = ResnetLayer(16, (32, 32), (3, 3))  # 14x14x32
    self.pool2 = MaxPooling2D((2, 2))  # 7x7x32

    self.res3 = ResnetLayer(32, (64, 64), (3, 3))  # 7x7x64
    self.flatten = Flatten()

    self.dense1 = Dense(128, activation='relu')
    self.dense2 = Dense(10, activation='softmax')

  def call(self, x, training=False, mask=None):
    """Compute class probabilities for a batch of images.

    Args:
      x: Batch of input images.
      training: Forwarded to the residual stages (which contain batch
        normalization).
      mask: Unused; accepted for signature compatibility.

    Returns:
      Output of the final softmax Dense layer (10 units per example).
    """
    x = self.conv1(x)

    x = self.res1(x, training=training)
    x = self.pool1(x)

    x = self.res2(x, training=training)
    x = self.pool2(x)

    x = self.res3(x, training=training)
    x = self.flatten(x)

    x = self.dense1(x)
    x = self.dense2(x)

    return x


#===============================================================================================

def load_dataset(train_percent=80):
  """Load MNIST via TFDS and split 'train' into training and validation sets.

  The split is requested through the TFDS split-slicing syntax
  ('train[:N%]' / 'train[N%:]') instead of positional take()/skip() on the
  loaded dataset. take()/skip() only gives disjoint, repeatable subsets when
  the element order is identical on every iteration, which is not something
  `shuffle_files=True` promises for datasets stored in several files.

  Args:
    train_percent: Integer percentage (1-99) of the 'train' split used for
      training; the remainder becomes the validation set. Defaults to 80.

  Returns:
    Tuple `(train_ds, validation_ds, test_ds, ds_info)`: three unbatched
    datasets of `(image, label)` pairs and the TFDS dataset info object.

  Raises:
    ValueError: If `train_percent` is not an integer strictly between 0
      and 100.
  """
  is_int = isinstance(train_percent, int) and not isinstance(train_percent, bool)
  if not is_int or not 0 < train_percent < 100:
    raise ValueError(
        'train_percent must be an integer in 1..99, got {!r}'.format(
            train_percent))

  splits = [
      'train[:{}%]'.format(train_percent),
      'train[{}%:]'.format(train_percent),
      'test',
  ]
  (train_ds, validation_ds, test_ds), ds_info = tfds.load(
      name='mnist',
      split=splits,
      as_supervised=True,
      with_info=True,
      shuffle_files=True,
      batch_size=None)

  return train_ds, validation_ds, test_ds, ds_info


def normalization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE):
  """Rescale and batch the global train/validation/test datasets in place.

  Rebinds the module-level `train_ds`, `validation_ds` and `test_ds` to
  pipelines that cast images to float32 and divide them by 255. The training
  pipeline is additionally shuffled with a buffer of 1000 before batching.

  Args:
    TRAIN_BATCH_SIZE: Batch size for the training dataset.
    TEST_BATCH_SIZE: Batch size for the validation and test datasets.
  """
  global train_ds, validation_ds, test_ds

  def scale_pixels(images, labels):
    # Labels pass through untouched; only the images are rescaled.
    scaled = tf.cast(images, tf.float32) / 255.
    return [scaled, labels]

  scaled_train = train_ds.map(scale_pixels)
  train_ds = scaled_train.shuffle(1000).batch(TRAIN_BATCH_SIZE)

  scaled_validation = validation_ds.map(scale_pixels)
  validation_ds = scaled_validation.batch(TEST_BATCH_SIZE)

  scaled_test = test_ds.map(scale_pixels)
  test_ds = scaled_test.batch(TEST_BATCH_SIZE)

#-------------------------------------------------------------------------------

def load_matrics():
  """Create the module-level loss and accuracy metric objects.

  Binds fresh `Mean` instances to `train_loss`, `validation_loss` and
  `test_loss`, and fresh `SparseCategoricalAccuracy` instances to
  `train_acc`, `validation_acc` and `test_acc`.
  """
  global train_loss, validation_loss, test_loss
  global train_acc, validation_acc, test_acc, loss_object

  # Running averages of the per-batch loss, one per data split.
  train_loss, validation_loss, test_loss = Mean(), Mean(), Mean()

  # Accuracy trackers, one per data split.
  train_acc, validation_acc, test_acc = (
      SparseCategoricalAccuracy(),
      SparseCategoricalAccuracy(),
      SparseCategoricalAccuracy(),
  )


@tf.function
def training():
  """Run one pass over the global `train_ds`, updating the global `model`.

  For every batch: compute predictions and loss under a gradient tape, apply
  the gradients with the global `optimizer`, and feed the batch loss and
  predictions into `train_loss` / `train_acc`.
  """
  global train_ds, train_loss, train_acc
  global loss_object, optimizer, model

  for batch_images, batch_labels in train_ds:
    with tf.GradientTape() as tape:
      # training=True is forwarded to the model's call().
      outputs = model(batch_images, training=True)
      batch_loss = loss_object(batch_labels, outputs)

    grads = tape.gradient(batch_loss, model.trainable_variables)
    grads_and_vars = zip(grads, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars)

    train_loss(batch_loss)
    train_acc(batch_labels, outputs)


@tf.function
def validation():
  """Evaluate the global `model` on the global `validation_ds`.

  Accumulates the per-batch loss into `validation_loss` and the predictions
  into `validation_acc`; the model is called with training=False and no
  weights are updated here.
  """
  global validation_ds, validation_acc, validation_loss
  global loss_object, model

  for batch_images, batch_labels in validation_ds:
    outputs = model(batch_images, training=False)
    batch_loss = loss_object(batch_labels, outputs)

    validation_loss(batch_loss)
    validation_acc(batch_labels, outputs)


@tf.function
def tester():
  """Evaluate the global `model` on the global `test_ds`.

  Accumulates the per-batch loss into `test_loss` and the predictions into
  `test_acc`; the model is called with training=False and no weights are
  updated here.
  """
  global test_ds, test_acc, test_loss
  global loss_object, model

  for batch_images, batch_labels in test_ds:
    outputs = model(batch_images, training=False)
    batch_loss = loss_object(batch_labels, outputs)

    test_loss(batch_loss)
    test_acc(batch_labels, outputs)


def train_result_and_reset_state():
  """Print the current epoch's train/validation metrics, then reset them.

  Reads the module-level `epoch` counter (printed 1-based) and the four
  train/validation metric objects, and clears those metrics afterwards so
  that the next epoch starts accumulating from zero.
  """
  global epoch
  global train_loss, train_acc
  global validation_loss, validation_acc

  print(colored('Epochs', 'red', 'on_white'), epoch + 1)

  report = ('Train Loss : {:.4f}\t Train Accuracy : {:.2f}%\n'
            'Validation Loss: {:.4f}\t Validation Accuracy : {:.2f}%\n')
  # Accuracies are fractions; scale them to percentages for display.
  values = (train_loss.result(),
            train_acc.result()*100,
            validation_loss.result(),
            validation_acc.result()*100)
  print(report.format(*values))

  for metric in (train_loss, train_acc, validation_loss, validation_acc):
    metric.reset_states()


# ---- Experiment configuration -------------------------------------------------
EPOCHS = 20
# LR = 0.001  (unused: Adam() below is constructed with its default arguments)
TRAIN_BATCH_SIZE = 100
TEST_BATCH_SIZE = 100

# These module-level objects are read as globals by training(), validation()
# and tester().
optimizer = Adam()
loss_object = SparseCategoricalCrossentropy()

# Load the raw splits, then let normalization() rebind the three dataset
# globals to their rescaled, batched versions.
train_ds, validation_ds, test_ds, ds_info = load_dataset()
normalization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE)

model = ResNet()
model.build(input_shape=(None, 28, 28, 1))
# Create the global loss/accuracy metric objects used below.
load_matrics()

# One training pass and one validation pass per epoch, followed by a report
# that also resets the train/validation metrics.
for epoch in range(EPOCHS):  
  training()
  validation()
  train_result_and_reset_state()

# Final evaluation on the test split. `epoch` still holds the last loop index
# here, so `epoch + 1` prints the total number of epochs run.
tester()
print(colored('Epochs', 'cyan', 'on_white') , epoch + 1)
print('============Test Result============')
temp = 'TEST LOSS : {:.4f}\t TEST ACC : {:.2f}%\n'
print(temp.format(test_loss.result(),
                  test_acc.result()*100))


[47m[31mEpochs[0m 1
Train Loss : 0.1498	 Train Accuracy : 95.60%
Validation Loss: 0.0693	 Validation Accuracy : 97.73%

[47m[31mEpochs[0m 2
Train Loss : 0.0461	 Train Accuracy : 98.61%
Validation Loss: 0.0459	 Validation Accuracy : 98.65%

[47m[31mEpochs[0m 3
Train Loss : 0.0320	 Train Accuracy : 98.98%
Validation Loss: 0.0376	 Validation Accuracy : 98.93%

[47m[31mEpochs[0m 4
Train Loss : 0.0330	 Train Accuracy : 98.95%
Validation Loss: 0.0412	 Validation Accuracy : 98.78%

[47m[31mEpochs[0m 5
Train Loss : 0.0236	 Train Accuracy : 99.26%
Validation Loss: 0.0587	 Validation Accuracy : 98.12%

[47m[31mEpochs[0m 6
Train Loss : 0.0197	 Train Accuracy : 99.40%
Validation Loss: 0.0565	 Validation Accuracy : 98.52%

[47m[31mEpochs[0m 7
Train Loss : 0.0181	 Train Accuracy : 99.43%
Validation Loss: 0.0575	 Validation Accuracy : 98.72%

[47m[31mEpochs[0m 8
Train Loss : 0.0210	 Train Accuracy : 99.40%
Validation Loss: 0.0881	 Validation Accuracy : 98.10%

[47m[31mEpochs