In [1]:
# 모듈 임포트
import tensorflow as tf
import numpy as np

## 데이터셋 준비


In [2]:
# Load the MNIST train/validation splits (downloaded on first use).
(x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

# Append a trailing channel axis and convert to float32, dividing by 255.0
# so the pixel values are rescaled before being fed to the network.
x_train = x_train[..., tf.newaxis].astype(np.float32) / 255.0
x_valid = x_valid[..., tf.newaxis].astype(np.float32) / 255.0

# Use a shuffle buffer that covers the whole training set: a buffer smaller
# than the dataset only shuffles locally, so samples would never move far
# from their original position within an epoch.
train_data = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=len(x_train))
    .batch(128)
)
# Validation data is only evaluated, so it is batched without shuffling.
valid_data = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(32)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


## 모델링


### DenseUnit 구현

In [3]:
class DenseUnit(tf.keras.models.Model):
    """One dense unit: BatchNorm -> ReLU -> Conv2D, concatenated with its input.

    Args:
        filter_out: Number of filters produced by the convolution.
        kernel_size: Kernel size forwarded to the convolution.
    """

    def __init__(self, filter_out, kernel_size):
        super().__init__()
        self.bn = tf.keras.layers.BatchNormalization()
        # 'same' padding keeps the spatial size so the result can be
        # concatenated with the unit's input.
        self.conv = tf.keras.layers.Conv2D(
            filter_out, kernel_size, padding='same'
        )
        self.concat = tf.keras.layers.Concatenate()

    def call(self, input_, training=False):
        """Return the input concatenated with the newly computed features.

        Args:
            input_: Input tensor of the unit.
            training: Forwarded to batch normalization.
        """
        normalized = self.bn(input_, training=training)
        new_features = self.conv(tf.nn.relu(normalized))
        return self.concat([input_, new_features])

### DenseLayer 구현

In [4]:
class DenseLayer(tf.keras.models.Model):
    """A stack of ``num_unit`` DenseUnit blocks applied one after another.

    Args:
        num_unit: How many DenseUnit blocks to chain.
        growth_rate: Filter count passed to every DenseUnit.
        kernel_size: Kernel size passed to every DenseUnit.
    """

    def __init__(self, num_unit, growth_rate, kernel_size):
        super().__init__()
        self.sequence = [
            DenseUnit(growth_rate, kernel_size) for _ in range(num_unit)
        ]

    def call(self, x, training=False):
        """Feed ``x`` through every unit in order and return the final output."""
        out = x
        for dense_unit in self.sequence:
            out = dense_unit(out, training=training)
        return out

### Transition Layer 구현

In [5]:
# Used to keep the channel count from growing too quickly when the growth
# rate is large, by resetting the number of channels with a convolution.
class TransitionLayer(tf.keras.models.Model):
    """Convolution followed by max pooling, placed between dense layers.

    Args:
        filter_out: Number of filters (output channels) of the convolution.
        kernel_size: Kernel size forwarded to the convolution.
    """

    def __init__(self, filter_out, kernel_size):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(
            filter_out, kernel_size, padding='same'
        )
        # Pooling uses the layer's default settings.
        self.maxpool = tf.keras.layers.MaxPool2D()

    def call(self, input_):
        """Apply the convolution and then the max pooling to ``input_``."""
        return self.maxpool(self.conv(input_))

### 모델 정의

In [6]:
class DenseNet(tf.keras.models.Model):
    """Small DenseNet-style classifier with a 10-way softmax output.

    Layout: stem convolution, three DenseLayer stages separated by two
    TransitionLayer stages, then flatten and two fully connected layers.
    The shape comments below assume a 28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Stem convolution.
        self.conv1 = tf.keras.layers.Conv2D(
            8, (3, 3), padding='same', activation='relu'
        )  # 28x28x8

        # Stage 1
        self.dl1 = DenseLayer(2, 4, (3, 3))  # 28x28x16
        self.tr1 = TransitionLayer(16, (3, 3))  # 14x14x16

        # Stage 2
        self.dl2 = DenseLayer(2, 8, (3, 3))  # 14x14x32
        self.tr2 = TransitionLayer(32, (3, 3))  # 7x7x32

        # Stage 3 (no transition afterwards)
        self.dl3 = DenseLayer(2, 16, (3, 3))  # 7x7x64

        # Classification head.
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(128, activation='relu')
        self.output_ = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, input_, training=False):
        """Run the forward pass and return the softmax output.

        Args:
            input_: Input image batch.
            training: Forwarded to the DenseLayer stages (which contain
                batch normalization); the transition layers take no flag.
        """
        features = self.conv1(input_)
        features = self.tr1(self.dl1(features, training=training))
        features = self.tr2(self.dl2(features, training=training))
        features = self.dl3(features, training=training)
        hidden = self.dense(self.flatten(features))
        return self.output_(hidden)

In [7]:
model = DenseNet()

# Call the model once on a symbolic 28x28x1 input so its weights are created;
# a subclassed model must be built before summary() can be called.
input_ = tf.keras.Input(shape=(28, 28, 1))
model(input_)

model.summary()

Model: "dense_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  80        
_________________________________________________________________
dense_layer (DenseLayer)     multiple                  808       
_________________________________________________________________
transition_layer (Transition multiple                  2320      
_________________________________________________________________
dense_layer_1 (DenseLayer)   multiple                  3056      
_________________________________________________________________
transition_layer_1 (Transiti multiple                  9248      
_________________________________________________________________
dense_layer_2 (DenseLayer)   multiple                  11872     
_________________________________________________________________
flatten (Flatten)            multiple                  0 

## 모델 컴파일,학습

In [8]:
# Labels are integer class ids and the model already outputs softmax
# probabilities, hence the sparse categorical cross-entropy loss.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['acc'],
)

In [9]:
# Train for 10 epochs, passing the validation dataset to fit() as well.
model.fit(
    train_data,
    validation_data=valid_data,
    epochs=10,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fb86203b7d0>