In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

In [2]:
# image normalization 
def normalize(train_data, test_data) :
  """Rescale pixel intensities from the 0~255 range down to 0~1.

  Args:
    train_data: NumPy array of raw pixel values for the training split.
    test_data: NumPy array of raw pixel values for the test split.

  Returns:
    Tuple (train_data, test_data); each is cast to float32 and divided by 255.
  """
  def _to_unit_range(images):
    # Cast before dividing so the result stays float32; dividing an integer
    # array by a Python float would otherwise promote it to float64.
    return images.astype(np.float32) / 255.0

  return _to_unit_range(train_data), _to_unit_range(test_data)

# load mnist with preprocessing
def load_mnist() : 
  """Download MNIST and return it preprocessed for the dense network below.

  Steps: add a trailing channel axis to the images, scale pixels to 0~1 via
  `normalize`, and one-hot encode the labels over 10 classes.

  Returns:
    (train_data, train_labels, test_data, test_labels), in that order.
  """
  (x_train, y_train), (x_test, y_test) = mnist.load_data()

  # TensorFlow image input layout is (batch_size, height, width, channel),
  # so append a channel axis: (N, 28, 28) -> (N, 28, 28, 1).
  x_train, x_test = (np.expand_dims(images, axis = -1) for images in (x_train, x_test))

  # Pixel values 0~255 -> 0~1.
  x_train, x_test = normalize(x_train, x_test)

  # Integer labels (n,) -> one-hot (n, 10); 10 is the number of classes.
  y_train, y_test = (to_categorical(labels, 10) for labels in (y_train, y_test))

  return x_train, y_train, x_test, y_test

In [3]:
# Create network

def flatten() :
  """Return a Keras layer that collapses every non-batch axis into one.

  For the MNIST input used here: (N, 28, 28, 1) -> (N, 784).
  """
  flatten_layer = tf.keras.layers.Flatten()
  return flatten_layer

# Build a fully connected (Dense) layer
def dense(channel,weight_init) :
  """Build a fully connected (Dense) layer that includes a bias term.

  Args:
    channel: number of output units (output channels) of the layer.
    weight_init: kernel (weight) initializer passed to Keras.

  Returns:
    A tf.keras.layers.Dense layer; no activation is set here, so it is linear.
  """
  layer_kwargs = dict(units = channel, use_bias = True, kernel_initializer = weight_init)
  return tf.keras.layers.Dense(**layer_kwargs)

def relu() :
  """Return a stand-alone ReLU activation layer."""
  # The string 'relu' resolves to tf.keras.activations.relu inside Keras.
  return tf.keras.layers.Activation('relu')

def batch_norm() :
  """Return a BatchNormalization layer with Keras' default settings."""
  norm_layer = tf.keras.layers.BatchNormalization()
  return norm_layer

In [4]:
class create_model(tf.keras.Model) : 
  """Fully connected classifier.

  Layer stack: Flatten -> [Dense -> BatchNorm -> ReLU] x num_hidden_layers -> Dense.

  The last Dense layer has no activation, so `call` returns raw logits of
  shape (N, label_dim). Presumably this is paired with a softmax or a
  `from_logits=True` loss in the training cell; confirm there.
  """

  def __init__(self, label_dim, hidden_units=256, num_hidden_layers=2) :
    """Build the layer stack.

    Args:
      label_dim: number of output classes (10 for MNIST).
      hidden_units: width of each hidden Dense layer (default 256, as before).
      num_hidden_layers: number of Dense -> BatchNorm -> ReLU blocks
        (default 2, as before).
    """
    super().__init__()

    # Layers are appended to a Sequential container and applied in order.
    self.model = tf.keras.Sequential()

    # (N, 28, 28, 1) -> (N, 784): Dense layers consume flat feature vectors.
    # A CNN would take the image tensor directly and not need this.
    self.model.add(flatten())

    for _ in range(num_hidden_layers) :
      # (N, 784) -> (N, hidden_units), then (N, hidden_units) -> (N, hidden_units).
      # Xavier (Glorot) uniform init; he_uniform() would give He init instead.
      # A fresh initializer is built per layer because recent Keras versions
      # warn when one unseeded initializer instance is reused across layers:
      # a reused instance may produce correlated or identical initial weights.
      self.model.add(dense(hidden_units, tf.keras.initializers.glorot_uniform()))
      # Usual ordering is layer -> norm -> activation
      # (norm -> activation -> layer is also seen in practice).
      # NOTE(review): dense() keeps use_bias=True, which is redundant right
      # before BatchNormalization (its beta plays the same role). Harmless.
      self.model.add(batch_norm())
      self.model.add(relu())

    # (N, hidden_units) -> (N, label_dim) logits.
    self.model.add(dense(label_dim, tf.keras.initializers.glorot_uniform()))

  def call(self, x, training=None, mask=None) :
    """Run the forward pass and return logits of shape (N, label_dim).

    `training` is forwarded explicitly to the inner Sequential. This makes
    BatchNormalization use batch statistics while training and its moving
    averages at inference. `mask` is accepted for Keras API compatibility
    but is not used.
    """
    return self.model(x, training=training)