In [4]:
from keras import models, backend, datasets, utils
from keras.layers import (Input, Conv2D, MaxPool2D, Dropout, Activation, UpSampling2D,
                          BatchNormalization, Concatenate)

# UNET 모델링

In [3]:
class UNET(models.Model):
    """Small two-level U-Net for image-to-image regression.

    Layout (spatial size relative to the input, filter count):
        encoder 1 (1x, 16) -> encoder 2 (1/2, 32) -> bottleneck (1/4, 64)
        -> decoder (1/2, 32, skip from encoder 2)
        -> decoder (1x, 16, skip from encoder 1)
        -> 3x3 conv with sigmoid producing ``n_ch`` channels.

    The model is compiled as part of construction.

    Args:
        org_shape: Per-sample input shape in the backend's image data format,
            e.g. ``(rows, cols, channels)`` for ``channels_last``.
        n_ch: Number of output channels.
        optimizer: Optimizer name or instance passed to ``compile``.
            NOTE(review): in recent Keras releases the string ``"adadelta"``
            builds Adadelta with learning_rate=0.001 (the Keras docs suggest
            1.0 to match the original paper) -- verify for the installed
            version and pass an optimizer instance if training stalls.
        loss: Loss name or callable passed to ``compile``.

    Raises:
        ValueError: if a known spatial dimension of ``org_shape`` is not
            divisible by 4. Two "same"-padded 2x2 pools followed by two 2x
            up-samplings only reproduce the encoder sizes in that case;
            otherwise the skip-connection concatenation would fail.
    """

    def __init__(self, org_shape, n_ch, optimizer="adadelta", loss="mse"):
        channels_last = backend.image_data_format() == "channels_last"
        # Channel axis of a batched 4-D tensor: (N, H, W, C) vs (N, C, H, W).
        ic = 3 if channels_last else 1

        # Fail early with a clear message instead of a Concatenate shape error.
        spatial = tuple(org_shape[:-1]) if channels_last else tuple(org_shape[1:])
        for dim in spatial:
            if dim is not None and dim % 4 != 0:
                raise ValueError(
                    f"UNET needs spatial dimensions divisible by 4 so the skip "
                    f"connections line up, got {spatial}")

        def conv_bn_act(x, n_f):
            """3x3 same-padded convolution -> batch norm -> tanh."""
            x = Conv2D(n_f, (3, 3), padding="same")(x)
            x = BatchNormalization()(x)
            return Activation("tanh")(x)

        def conv(x, n_f, mp_flag=True):
            """Encoder stage: optional 2x2 max-pool, then two conv blocks with light dropout between them."""
            if mp_flag:
                x = MaxPool2D((2, 2), padding="same")(x)
            x = conv_bn_act(x, n_f)
            x = Dropout(0.05)(x)
            return conv_bn_act(x, n_f)

        def deconv_unet(x, e, n_f):
            """Decoder stage: 2x up-sample, concatenate skip tensor ``e`` on the channel axis, two conv blocks."""
            x = UpSampling2D((2, 2))(x)
            x = Concatenate(axis=ic)([x, e])
            x = conv_bn_act(x, n_f)
            return conv_bn_act(x, n_f)

        original = Input(shape=org_shape)

        # Encoding: full resolution, then 1/2, then the 1/4 bottleneck.
        c1 = conv(original, 16, mp_flag=False)
        c2 = conv(c1, 32)
        encoded = conv(c2, 64)

        # Decoding: mirror the encoder, reusing its feature maps via skips.
        x = deconv_unet(encoded, c2, 32)
        x = deconv_unet(x, c1, 16)

        decoded = Conv2D(n_ch, (3, 3), activation="sigmoid", padding="same")(x)

        super().__init__(original, decoded)
        self.compile(optimizer=optimizer, loss=loss)

# 데이터 준비

In [5]:
class DATA:
    """CIFAR-10 images prepared for an image-to-image task.

    Targets (``x_*_out``) are the dataset images converted to float32 and
    divided by 255, laid out in the backend's image data format. Inputs
    (``x_*_in``) are the same arrays, or, when ``in_ch == 1`` and the dataset
    has 3 channels, a weighted grayscale (0.299 R + 0.587 G + 0.114 B).

    Args:
        in_ch: Channel count for the model input. ``None`` means "same as the
            dataset". Supported values: ``None``, the dataset's own channel
            count, or ``1`` when the dataset has 3 channels.

    Attributes:
        input_shape: Per-sample input shape, ``(in_ch, rows, cols)`` for
            ``channels_first`` or ``(rows, cols, in_ch)`` otherwise.
        x_train_in, x_train_out, x_test_in, x_test_out: float32 arrays.
        n_ch: Channel count of the dataset images (and of the targets).
        in_ch: Channel count of the inputs.

    Raises:
        ValueError: if ``in_ch`` is not one of the supported values (it
            previously produced an ``input_shape`` inconsistent with the data).
    """

    def __init__(self, in_ch=None):
        fmt = backend.image_data_format()
        # Labels are discarded: the images themselves are the targets.
        (x_train, _), (x_test, _) = datasets.cifar10.load_data()

        img_rows, img_cols, n_ch = self._image_dims(x_train.shape, fmt)
        in_ch = self._resolve_in_ch(n_ch, in_ch)

        x_train_out = self._to_model_layout(x_train, img_rows, img_cols, n_ch, fmt)
        x_test_out = self._to_model_layout(x_test, img_rows, img_cols, n_ch, fmt)

        if fmt == "channels_first":
            input_shape = (in_ch, img_rows, img_cols)
        else:
            input_shape = (img_rows, img_cols, in_ch)

        # Grayscale inputs for an RGB dataset; otherwise input == target.
        if in_ch == 1 and n_ch == 3:
            x_train_in = self._rgb2gray(x_train_out, fmt)
            x_test_in = self._rgb2gray(x_test_out, fmt)
        else:
            x_train_in, x_test_in = x_train_out, x_test_out

        # Store the prepared arrays and metadata on the instance.
        self.input_shape = input_shape
        self.x_train_in, self.x_train_out = x_train_in, x_train_out
        self.x_test_in, self.x_test_out = x_test_in, x_test_out
        self.n_ch = n_ch
        self.in_ch = in_ch

    @staticmethod
    def _image_dims(shape, fmt):
        """Return ``(rows, cols, channels)`` for a batched image array shape.

        A 3-D shape ``(N, rows, cols)`` is treated as single-channel.
        """
        if len(shape) == 4:
            if fmt == "channels_first":
                n_ch, rows, cols = shape[1:]
            else:
                rows, cols, n_ch = shape[1:]
        else:
            rows, cols = shape[1:]
            n_ch = 1
        return rows, cols, n_ch

    @staticmethod
    def _resolve_in_ch(n_ch, in_ch):
        """Fill in the default input channel count and reject unsupported ones."""
        if in_ch is None or in_ch == n_ch:
            return n_ch
        if in_ch == 1 and n_ch == 3:
            return in_ch
        raise ValueError(
            f"in_ch={in_ch} is not supported for a {n_ch}-channel dataset; "
            f"use None, {n_ch}, or 1 (grayscale, 3-channel datasets only)")

    @staticmethod
    def _to_model_layout(x, rows, cols, n_ch, fmt):
        """Cast to float32, divide by 255 and reshape to an explicit 4-D layout.

        The reshape is a no-op when ``x`` already has a channel axis in the
        requested layout; it adds one for 3-D (single-channel) input.
        """
        if fmt == "channels_first":
            shape = (x.shape[0], n_ch, rows, cols)
        else:
            shape = (x.shape[0], rows, cols, n_ch)
        return (x.astype("float32") / 255).reshape(shape)

    @staticmethod
    def _rgb2gray(X, fmt):
        """Weighted grayscale of a 3-channel batch, keeping a size-1 channel axis."""
        if fmt == "channels_first":
            R, G, B = X[:, 0:1], X[:, 1:2], X[:, 2:3]
        else:
            R, G, B = X[..., 0:1], X[..., 1:2], X[..., 2:3]
        return 0.299 * R + 0.587 * G + 0.114 * B

# 학습

In [6]:
# Feed the network single-channel (grayscale) images; targets keep the dataset's channels.
INPUT_CHANNELS = 1
data = DATA(in_ch=INPUT_CHANNELS)

In [7]:
# Input matches the prepared inputs; output channels match the targets.
unet = UNET(org_shape=data.input_shape, n_ch=data.n_ch)






In [None]:
EPOCHS = 10
BATCH_SIZE = 128
# Fraction of the training arrays Keras holds out for validation.
VALIDATION_SPLIT = 0.2

history = unet.fit(
    data.x_train_in,
    data.x_train_out,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_split=VALIDATION_SPLIT,
)

Epoch 1/10

Epoch 2/10
Epoch 3/10
Epoch 4/10