<a href="https://colab.research.google.com/github/jaeyukkim/TF-study/blob/main/DenseNet(MNIST).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import tensorflow as tf
import tensorflow_datasets as tfds 

from termcolor import colored
from tensorflow.keras.layers import BatchNormalization, Concatenate
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.losses import SparseCategoricalCrossentropy

class DenseUnit(Model):
  """One DenseNet unit: BatchNorm -> ReLU -> Conv2D, then concat with the input.

  The output keeps all of the input's feature maps and appends the
  ``filter_out`` maps produced by the convolution (concatenated on the last
  axis). Spatial size is preserved because the convolution uses
  ``padding='same'``.

  Args:
    filter_out: Number of new feature maps produced by the convolution.
    kernel_size: Convolution kernel size, e.g. ``(3, 3)``.
  """

  def __init__(self, filter_out, kernel_size):
    super(DenseUnit, self).__init__()
    self.bn = BatchNormalization()
    self.conv2d = Conv2D(filter_out, kernel_size, padding='same')
    self.concat = Concatenate()

  def call(self, x, training=False, mask=None):
    # Pass `training` by keyword. The previous `training==training` always
    # evaluated to True, so BatchNormalization used per-batch statistics (and
    # kept updating its moving averages) even during validation and testing.
    h = self.bn(x, training=training)
    h = tf.nn.relu(h)
    h = self.conv2d(h)
    # Dense connectivity: keep the input maps and append the newly computed ones.
    return self.concat([x, h])


class DenseLayer(Model):
  """A stack of ``num_unit`` DenseUnits applied one after another.

  Each unit appends ``growth_rate`` feature maps to its input, so the output
  has ``in_channels + num_unit * growth_rate`` channels and the same spatial
  size as the input.

  Args:
    num_unit: Number of DenseUnits in the stack.
    growth_rate: Feature maps added by each unit.
    kernel_size: Convolution kernel size used by every unit.
  """

  def __init__(self, num_unit, growth_rate, kernel_size):
    super(DenseLayer, self).__init__()
    # Keras tracks layers stored in a list attribute, so the units' weights
    # become part of this model's trainable variables.
    self.sequential = [DenseUnit(growth_rate, kernel_size)
                       for _ in range(num_unit)]

  def call(self, x, training=False, mask=None):
    for layer in self.sequential:
      # Forward the flag by keyword; `training==training` was always True and
      # forced BatchNormalization into training mode at inference time.
      x = layer(x, training=training)
    return x


class TransitionLayer(Model):
  """Down-sampling block placed between dense layers.

  A linear Conv2D (``padding='same'``, no activation) sets the channel count to
  ``filters``; it is followed by ``MaxPool2D()`` with its default pool size
  (2x2 in Keras), which reduces height and width.

  Args:
    filters: Output channel count of the convolution.
    kernel_size: Convolution kernel size, e.g. ``(3, 3)``.
  """

  def __init__(self, filters, kernel_size):
    super(TransitionLayer, self).__init__()
    self.conv = Conv2D(filters, kernel_size, padding='same')
    self.pool = MaxPool2D()

  def call(self, x, training=False, Mask=None):
    # `training` and `Mask` are accepted for signature compatibility only;
    # neither the convolution nor the pooling depends on them.
    return self.pool(self.conv(x))



class DenseNet(Model):
  """Small DenseNet-style classifier producing a 10-way softmax.

  Layout, with shapes for a 28x28x1 input (the script builds it with
  ``input_shape=(None, 28, 28, 1)``):

    conv1   3x3 conv, 8 filters, ReLU          -> 28x28x8
    dense1  2 units, growth 4                  -> 28x28x16  (8 + 2*4)
    trans1  3x3 conv to 16 + max-pool          -> 14x14x16
    dense2  2 units, growth 8                  -> 14x14x32  (16 + 2*8)
    trans2  3x3 conv to 16 + max-pool          -> 7x7x16
    dense3  2 units, growth 16                 -> 7x7x48    (16 + 2*16)
    flatten -> Dense(128, ReLU) -> Dense(10, softmax)

  The output is a probability vector (softmax), so pair it with a loss that
  expects probabilities rather than logits.
  """

  def __init__(self):
    super(DenseNet, self).__init__()
    self.conv1 = Conv2D(8, (3,3), padding='same', activation='relu')    #28x28x8

    self.dense1 = DenseLayer(2, 4, (3,3))       #28x28x16
    self.trans1 = TransitionLayer(16,(3,3))     #14x14x16

    self.dense2 = DenseLayer(2, 8, (3,3))       #14x14x32
    self.trans2 = TransitionLayer(16,(3,3))     #7x7x16

    self.dense3 = DenseLayer(2, 16, (3,3))      #7x7x48

    self.flatten = Flatten()
    # Attribute names (including the "conected" spelling) are kept unchanged so
    # variable paths in existing checkpoints still match.
    self.fully_conected1 = Dense(128, activation='relu')
    self.fully_conected2 = Dense(10, activation='softmax')

  def call(self, x, training=False, Mask=None):
    x = self.conv1(x)

    # Forward `training` by keyword so the BatchNormalization layers inside the
    # dense blocks run in inference mode when training=False. The previous
    # `training==training` always passed True.
    x = self.dense1(x, training=training)
    x = self.trans1(x)

    x = self.dense2(x, training=training)
    x = self.trans2(x)

    x = self.dense3(x, training=training)

    x = self.flatten(x)
    x = self.fully_conected1(x)
    return self.fully_conected2(x)


#===============================================================================================

def load_dataset(train_ratio=0.8):
  """Load MNIST and split its official 'train' split into train/validation.

  Args:
    train_ratio: Fraction of the official 'train' split used for training; the
      remaining examples form the validation set. Defaults to 0.8 (the value
      previously hard-coded).

  Returns:
    ``(train_ds, validation_ds, test_ds, ds_info)``: three unbatched datasets of
    ``(image, label)`` pairs (``as_supervised=True``) plus the TFDS info object.
  """
  # Keep file order fixed so the take/skip split below selects the same
  # records on every pass. With shuffled files the order can change between
  # iterations (for multi-shard datasets), which would let validation records
  # drift into the training set. normalization() in this file applies
  # Dataset.shuffle to the training set, so training order is still randomized.
  (train_validation_ds, test_ds), ds_info = tfds.load(name='mnist',
                                                      split=['train', 'test'],
                                                      as_supervised=True,
                                                      with_info=True,
                                                      shuffle_files=False)

  n_train_validation = ds_info.splits['train'].num_examples
  n_train = int(n_train_validation * train_ratio)

  train_ds = train_validation_ds.take(n_train)
  # skip() already yields exactly the remaining examples; no extra take().
  validation_ds = train_validation_ds.skip(n_train)

  return train_ds, validation_ds, test_ds, ds_info


def normalization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE):
  """Scale images to [0, 1] and batch the module-level datasets in place.

  Rebinds the globals ``train_ds``, ``validation_ds`` and ``test_ds`` to
  mapped-and-batched versions of themselves. The training set is also shuffled
  (buffer of 1000 elements) before batching.

  Args:
    TRAIN_BATCH_SIZE: Batch size for the training dataset.
    TEST_BATCH_SIZE: Batch size for the validation and test datasets.
  """
  global train_ds, validation_ds, test_ds

  def norm(images, labels):
    images = tf.cast(images, tf.float32) / 255.
    # Return a tuple, the native tf.data element structure (a returned list is
    # only converted to a tuple implicitly by tf.data).
    return images, labels

  train_ds = train_ds.map(norm).shuffle(1000).batch(TRAIN_BATCH_SIZE)
  validation_ds = validation_ds.map(norm).batch(TEST_BATCH_SIZE)
  test_ds = test_ds.map(norm).batch(TEST_BATCH_SIZE)

#-------------------------------------------------------------------------------

def load_matrics():
  """Create fresh metric objects and publish them as module-level globals.

  Defines three running-mean loss trackers (``train_loss``,
  ``validation_loss``, ``test_loss``) and three sparse categorical accuracy
  trackers (``train_acc``, ``validation_acc``, ``test_acc``). The training and
  evaluation functions in this file update these objects.
  """
  global train_loss, validation_loss, test_loss
  global train_acc, validation_acc, test_acc

  # Tuple unpacking evaluates left to right, so the objects are constructed in
  # the same order as before (all Means first, then the accuracies).
  train_loss, validation_loss, test_loss = Mean(), Mean(), Mean()
  train_acc, validation_acc, test_acc = (SparseCategoricalAccuracy(),
                                         SparseCategoricalAccuracy(),
                                         SparseCategoricalAccuracy())


@tf.function
def training():
  """Run one pass of optimization over every batch in ``train_ds``.

  For each batch: forward pass with ``training=True``, loss from
  ``loss_object``, gradients from a ``tf.GradientTape``, one
  ``optimizer.apply_gradients`` step, then update ``train_loss`` and
  ``train_acc``. All of these objects are module-level globals.
  """
  global train_ds, train_loss, train_acc
  global loss_object, optimizer, model

  for batch_images, batch_labels in train_ds:
    with tf.GradientTape() as tape:
      probs = model(batch_images, training=True)
      batch_loss = loss_object(batch_labels, probs)

    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))

    train_loss(batch_loss)
    train_acc(batch_labels, probs)


@tf.function
def validation():
  """Evaluate ``model`` over ``validation_ds`` and accumulate its metrics.

  The model runs with ``training=False``. Per-batch loss and accuracy are fed
  into the global ``validation_loss`` and ``validation_acc`` objects. They are
  not reset here.
  """
  global validation_ds, validation_acc, validation_loss
  global loss_object, model

  for val_images, val_labels in validation_ds:
    val_probs = model(val_images, training=False)
    validation_loss(loss_object(val_labels, val_probs))
    validation_acc(val_labels, val_probs)


@tf.function
def tester():
  """Evaluate ``model`` over ``test_ds`` and accumulate the test metrics.

  The model runs with ``training=False``. Per-batch loss and accuracy are fed
  into the global ``test_loss`` and ``test_acc`` objects. They are not reset
  here.
  """
  global test_ds, test_acc, test_loss
  global loss_object, model

  for eval_images, eval_labels in test_ds:
    eval_probs = model(eval_images, training=False)
    test_loss(loss_object(eval_labels, eval_probs))
    test_acc(eval_labels, eval_probs)


def train_result_and_reset_state():
  """Print the current epoch's train/validation metrics, then reset them.

  Reads the global loop index ``epoch`` (0-based) and prints it 1-based. Resets
  ``train_loss``, ``train_acc``, ``validation_loss`` and ``validation_acc`` so
  the next epoch starts from zero.
  """
  global epoch
  global train_loss, train_acc
  global validation_loss, validation_acc

  print(colored('Epochs', 'red', 'on_white'), epoch + 1)
  temp = 'Train Loss : {:.4f}\t Train Accuracy : {:.2f}%\n' +\
         'Validation Loss : {:.4f}\t Validation Accuracy : {:.2f}%\n'

  # Convert the scalar tensors returned by result() to Python floats before
  # formatting, independent of the backend tensor type.
  print(temp.format(float(train_loss.result()),
                    float(train_acc.result()) * 100,
                    float(validation_loss.result()),
                    float(validation_acc.result()) * 100))

  for metric in (train_loss, train_acc, validation_loss, validation_acc):
    # Newer Keras renamed Metric.reset_states() to reset_state(), and Keras 3
    # no longer provides reset_states(). Prefer the new name and fall back to
    # the old one on older versions.
    reset = getattr(metric, 'reset_state', None) or metric.reset_states
    reset()


EPOCHS = 20
LR = 0.001  # Learning rate; equal to Adam's default, now stated explicitly.
TRAIN_BATCH_SIZE = 100
TEST_BATCH_SIZE = 100

optimizer = Adam(learning_rate=LR)
# DenseNet ends in a softmax, so the loss receives probabilities, not logits.
loss_object = SparseCategoricalCrossentropy()

train_ds, validation_ds, test_ds, ds_info = load_dataset()
normalization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE)

model = DenseNet()
model.build(input_shape=(None, 28, 28, 1))
load_matrics()

# `epoch` must stay a module-level global: train_result_and_reset_state()
# reads it to label the printed results.
for epoch in range(EPOCHS):
  training()
  validation()
  train_result_and_reset_state()

tester()
# Report the number of epochs run. Using EPOCHS instead of the loop variable
# prints the same value but does not raise NameError when EPOCHS == 0.
print(colored('Epochs', 'cyan', 'on_white') , EPOCHS)
print('============Test Result============')
temp = 'TEST LOSS : {:.4f}\t TEST ACC : {:.2f}%\n'
print(temp.format(float(test_loss.result()),
                  float(test_acc.result()) * 100))


[47m[31mEpochs[0m 1
Train Loss : 0.1453	 Train Accuracy : 95.60%
Validation Loss : 0.0473	 Validation Accuracy : 98.43%

[47m[31mEpochs[0m 2
Train Loss : 0.0479	 Train Accuracy : 98.49%
Validation Loss : 0.0571	 Validation Accuracy : 98.19%

[47m[31mEpochs[0m 3
Train Loss : 0.0313	 Train Accuracy : 99.01%
Validation Loss : 0.0487	 Validation Accuracy : 98.72%

[47m[31mEpochs[0m 4
Train Loss : 0.0261	 Train Accuracy : 99.15%
Validation Loss : 0.0407	 Validation Accuracy : 98.77%

[47m[31mEpochs[0m 5
Train Loss : 0.0201	 Train Accuracy : 99.33%
Validation Loss : 0.0422	 Validation Accuracy : 98.73%

[47m[31mEpochs[0m 6
Train Loss : 0.0185	 Train Accuracy : 99.45%
Validation Loss : 0.0719	 Validation Accuracy : 98.12%

[47m[31mEpochs[0m 7
Train Loss : 0.0205	 Train Accuracy : 99.36%
Validation Loss : 0.0618	 Validation Accuracy : 98.33%

[47m[31mEpochs[0m 8
Train Loss : 0.0172	 Train Accuracy : 99.47%
Validation Loss : 0.0608	 Validation Accuracy : 98.53%

[47m[3