# 1. Starting point

Input: (256, 256, 3) Chest X-ray image 566개
<br>
Label: 좌우 lung 마스킹되어 있는 흑백 이미지 566개
<br><br>
이미지 기준 좌측 폐: Right lung <br>
이미지 기준 우측 폐: Left lung
<br>
※ Task<br>
&nbsp;&nbsp;&nbsp;    - Left lung, Right lung 분할(2 classes + 1 class(background)) <br>
&nbsp;&nbsp;&nbsp;    - 훈련 및 추론 후, 후처리(보간법)로 노이즈를 제거하여 예측 성능 향상(↑)

## 데이터 구조

data <br>
&nbsp;&nbsp;    └ image/ <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;        resize_CHNCXR_0001_0.png <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;        resize_CHNCXR_0002_0.png <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;        resize_CHNCXR_0003_0.png <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;        ... <br>
&nbsp;&nbsp;    └ label/                     -> 사전 작업 후에는 사용하지 않습니다. <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;        resize_CHNCXR_0001_0.png <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;        resize_CHNCXR_0002_0.png <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;        resize_CHNCXR_0003_0.png <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;        ...    

## 사전 작업

label 폴더의 마스킹 이미지를 각각 좌측 폐('label_rl/l/*'), 우측 폐('label_rl/r/*')로 분할하여 어노테이션 처리하였습니다.<br>
추후 작업 시 라벨링을 배경: 0, 좌측 폐: 1, 우측 폐: 2로 하고, 다시 one-hot encoding 작업으로 분할하여 진행합니다.

# 2. Library Import

In [1]:
import os
import time
import datetime
import pickle
import statistics
from tqdm import tqdm
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

from skimage.io import imread
from skimage.transform import resize
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.layers import Add, Input, Dense, Conv2D, Flatten, MaxPool2D, UpSampling2D
from tensorflow.keras.layers import Conv2DTranspose, Concatenate, BatchNormalization
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# 3. Data Preprocessing

In [2]:
# Fix the random seeds (NumPy and TensorFlow) for reproducibility.
seed = 777
tf.random.set_seed(seed)
np.random.seed(seed)

# Hyperparameters
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 256, 256, 3  # input image shape (H, W, C)
N_CLASSES = 2    # number of mask channels in the label / model output
BATCH_SIZE = 32
EPOCHS = 20

In [3]:
# Build the input array (X_all) and the two-channel lung mask array (y_all).
input_path = './data/image/'
label_path = './data/label/'
# os.listdir() returns entries in arbitrary order. Sort both listings so that
# the i-th image and the i-th mask are paired by file name.
input_files = sorted(os.listdir(input_path))
label_files = sorted(os.listdir(label_path))
if len(input_files) != len(label_files):
    raise ValueError(
        f"Number of images ({len(input_files)}) and masks "
        f"({len(label_files)}) differ")
half_width = IMG_WIDTH // 2

X_all = np.zeros((len(input_files), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),
                 dtype=np.uint8)  # Input images
# Builtin bool instead of np.bool: the np.bool alias was removed in NumPy 1.24.
y_all = np.zeros((len(label_files), IMG_HEIGHT, IMG_WIDTH, N_CLASSES),
                 dtype=bool)      # Label masks


def _largest_foreground_component(component_map):
    """Return the id of the largest non-background component, or None.

    `component_map` is (a slice of) the label map produced by
    cv2.connectedComponents, where id 0 marks the background.
    """
    ids, counts = np.unique(component_map, return_counts=True)
    foreground = ids != 0
    if not foreground.any():
        return None
    return ids[foreground][counts[foreground].argmax()]


for input_count, input_file in enumerate(input_files):
    input_img = cv2.imread(input_path + input_file, cv2.IMREAD_GRAYSCALE)
    if input_img is None:  # cv2.imread signals failure by returning None
        raise ValueError(f"Could not read input image: {input_path + input_file}")
    # (H, W) -> (H, W, 1); the single gray channel is broadcast over
    # IMG_CHANNELS on assignment.
    X_all[input_count] = np.expand_dims(input_img, axis=-1)


cnt_max = []
for label_count, label_file in enumerate(label_files):
    label_img = cv2.imread(label_path + label_file, cv2.IMREAD_GRAYSCALE)
    if label_img is None:
        raise ValueError(f"Could not read label image: {label_path + label_file}")

    # mass holds one integer id per pixel: 0 = background, 1..cnt-1 = blobs.
    cnt, mass = cv2.connectedComponents(label_img)

    # The lung on each side is taken to be the largest foreground blob inside
    # that half of the image. Excluding the background explicitly keeps this
    # correct even when a lung covers more pixels than the background in its
    # half, and avoids an IndexError when a half has no foreground at all.
    left_pick = _largest_foreground_component(mass[:, :half_width])
    right_pick = _largest_foreground_component(mass[:, half_width:])

    left_mask = np.zeros(mass.shape, dtype=bool)
    if left_pick is not None:
        left_mask = mass == left_pick

    right_mask = np.zeros(mass.shape, dtype=bool)
    if right_pick is not None:
        # Keep the two channels mutually exclusive: a blob spanning both
        # halves is assigned to the left channel only.
        right_mask = (mass == right_pick) & ~left_mask

    # Channel 0: lung on the left side of the image, channel 1: lung on the
    # right side of the image. Everything else is background (all False).
    y_all[label_count, :, :, 0] = left_mask
    y_all[label_count, :, :, 1] = right_mask

In [4]:
# Report the shapes of the full input and label arrays.
for _name, _arr in (('X_all', X_all), ('y_all', y_all)):
    print(f'{_name} shape: ', _arr.shape)

X_all shape:  (566, 256, 256, 3)
y_all shape:  (566, 256, 256, 2)


In [5]:
# Convert the masks to float32 and scale the uint8 pixel values to [0, 1].
y_all = y_all.astype(np.float32)
X_all = X_all.astype(np.float32) / 255.

In [6]:
# Split the data into train / validation / test sets
# (20% held out for test, then 20% of the remainder for validation).
# The seed is passed explicitly so the split does not depend on how much of
# NumPy's global random state was consumed before this cell ran. Re-running
# the cell therefore always yields the same test set, which matters when a
# previously trained model is loaded from disk and evaluated on X_test.
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.2, random_state=seed)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train, y_train, test_size=0.2, random_state=seed)

print('X_train shape: ', X_train.shape)
print('X_valid shape: ', X_valid.shape)
print('X_test shape: ', X_test.shape)
print('y_train shape: ', y_train.shape)
print('y_valid shape: ', y_valid.shape)
print('y_test shape: ', y_test.shape)

X_train shape:  (361, 256, 256, 3)
X_valid shape:  (91, 256, 256, 3)
X_test shape:  (114, 256, 256, 3)
y_train shape:  (361, 256, 256, 2)
y_valid shape:  (91, 256, 256, 2)
y_test shape:  (114, 256, 256, 2)


In [7]:
# def plotTrainData(a,b,c):
#     for i in range(3):
#         ix = np.random.randint(0, len(a))
#         plt.subplot(1,2,1)
#         plt.title("X_" + c)
#         plt.imshow(a[ix])
#         plt.axis('off')
#         plt.subplot(1,2,2)
#         plt.title("y_" + c)
#         plt.imshow(np.squeeze(b[ix]))#, 'gray')
#         plt.axis('off')
#         plt.show()
        
# plotTrainData(X_train,y_train, 'train')
# plotTrainData(X_valid,y_valid, 'valid')
# plotTrainData(X_test,y_test, 'test')

# 4. Modeling(U-Net)

![U-Net 구조도](https://www.renom.jp/notebooks/tutorial/image_processing/u-net/unet.png)

## 4-1. Model

In [8]:
# U-Net model
# The network ends in a convolution instead of an MLP, so the image is never
# flattened and its spatial structure is preserved (fully convolutional
# output). The encoder halves the resolution with max-pooling while increasing
# the number of channels; the decoder upsamples back to the input resolution.
# Skip connections concatenate the simpler feature maps of the early layers
# onto the matching decoder stage.
def _conv_block(x, filters):
    """Apply two (3x3 Conv2D + ReLU -> BatchNormalization) stages to `x`."""
    for _ in range(2):
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
    return x


def _up_block(x, skip, filters):
    """Upsample `x` by 2, concatenate it with `skip`, then apply a conv block."""
    upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    merged = Concatenate()([upsampled, skip])
    return _conv_block(merged, filters)


def unet(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the U-Net segmentation model.

    Args:
        input_size: shape passed to the Keras `Input` layer.

    Returns:
        A Keras `Model` whose output is a 1x1 convolution with N_CLASSES
        channels and a sigmoid activation.
    """
    inputs = Input(input_size)

    # Encoder: conv block -> remember the output for the skip connection ->
    # 2x2 max-pooling.
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = _conv_block(x, filters)
        skips.append(x)
        x = MaxPool2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = _conv_block(x, 512)

    # Decoder: consume the stored encoder outputs from deepest to shallowest.
    for filters in (256, 128, 64, 32):
        x = _up_block(x, skips.pop(), filters)

    outputs = Conv2D(N_CLASSES, (1, 1), activation='sigmoid')(x)

    return Model(inputs=[inputs], outputs=[outputs])

## 4-2. Compile & Fit

In [9]:
# Loss function
# The Dice coefficient measures how much the predicted and the true regions
# overlap (their intersection); it is equivalent to the F1 score.
def dice_coef(y_true, y_pred):
    """Return the Dice coefficient of `y_true` and `y_pred`.

    The sums run over every element of the inputs, and 1 is added to both the
    numerator and the denominator as a smoothing term.
    """
    overlap = tf.reduce_sum(y_true * y_pred)
    total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return (2. * overlap + 1) / (total + 1)

def dice_coef_loss(y_true, y_pred):
    """Return the Dice loss, i.e. one minus the Dice coefficient."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

In [None]:
# Create the output directory up front: model.save() below writes into
# ./model/, and failing on a missing directory only after the whole training
# run has finished would waste the run.
os.makedirs('./model', exist_ok=True)

model = unet()
# `learning_rate` replaces the deprecated `lr` alias.
sgd = SGD(learning_rate=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss=dice_coef_loss,
              optimizer=sgd,
              # optimizer='Adam',
              metrics=[dice_coef])

# model.summary()

start = time.time()

# Fit the model (callbacks such as ModelCheckpoint / EarlyStopping are
# currently not used).
hist = model.fit(X_train, y_train,
                 validation_data=(X_valid, y_valid),
                 epochs=EPOCHS,
                 batch_size=BATCH_SIZE,
                 verbose=1)

# Format the elapsed time as H:MM:SS by dropping the fractional seconds.
runtime = str(datetime.timedelta(seconds=time.time() - start)).split(".")[0]
print(f"Fitting Runtime: {runtime}")

# Save under a timestamped file name.
model.save(f"./model/{time.strftime('%Y%m%d-%H%M%S')}.h5")

Epoch 1/20

### i) 저장된 모델이 있는 경우(로드, 컴파일) ii) 저장된 모델이 없는 경우(훈련, 저장)

In [None]:
# Load a previously saved model if it exists; otherwise train and save one.
# NOTE(review): newly trained models are saved under a timestamped name, so
# model_h5_path must be updated by hand to pick up a new model on later runs.
model_h5_path = './model/20211129-151140.h5'

# exist_ok=True creates the directory only if it is missing.
os.makedirs('./model', exist_ok=True)

if os.path.isfile(model_h5_path):
    print("Model already saved and loaded..")
    model = load_model(model_h5_path, custom_objects={'dice_coef':dice_coef, 'dice_coef_loss':dice_coef_loss})

    # Re-compile with a fresh optimizer.
    # `learning_rate` replaces the deprecated `lr` alias.
    sgd = SGD(learning_rate=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss=dice_coef_loss,
                  optimizer=sgd,
                  metrics=[dice_coef])

    model.summary()

else: # model.fit
    print("Model initial setting(compile & fit) start..")

    # Optional callbacks (currently disabled).
    # TODO: find out how to monitor the custom dice_coef_loss.
    # earlystopping = EarlyStopping(monitor='val_loss',
    #                               patience=10)
    # modelcheckpoint = ModelCheckpoint(f"./model_ckpt/{time.strftime('%Y%m%d-%H%M%S')}.h5",
    #                                   monitor='val_loss',
    #                                   verbose=1,
    #                                   save_best_only=True,
    #                                   mode='auto')

    # Build and compile the model.
    model = unet()
    sgd = SGD(learning_rate=1e-2, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss=dice_coef_loss,
                  optimizer=sgd,
                  # optimizer='Adam',
                  metrics=[dice_coef])

    model.summary()

    start = time.time()

    # Fit the model.
    hist = model.fit(X_train, y_train,
                     validation_data=(X_valid, y_valid),
                     epochs=EPOCHS,
                     batch_size=BATCH_SIZE,
                     verbose=1)
    #                  callbacks=[modelcheckpoint, earlystopping])

    # Format the elapsed time as H:MM:SS by dropping the fractional seconds.
    runtime = str(datetime.timedelta(seconds=time.time() - start)).split(".")[0]
    print(f"Fitting Runtime: {runtime}")

    model.save(f"./model/{time.strftime('%Y%m%d-%H%M%S')}.h5")

## 4-3. Fit Visualization

In [17]:
# Plot the training curves. Nothing is drawn when the saved model file exists,
# because in that case the load-or-train cell loads the model instead of
# fitting it.
if not os.path.isfile(model_h5_path):
    fig, loss_ax = plt.subplots()
    acc_ax = loss_ax.twinx()  # second y-axis sharing the same x-axis

    # Loss curves on the left axis.
    for _key, _fmt, _label in (('loss', 'y', 'train loss'),
                               ('val_loss', 'r', 'val loss')):
        loss_ax.plot(hist.history[_key], _fmt, label=_label)

    # Dice coefficient curves on the right axis.
    for _key, _fmt, _label in (('dice_coef', 'b', 'train dice_coef'),
                               ('val_dice_coef', 'g', 'val dice_coef')):
        acc_ax.plot(hist.history[_key], _fmt, label=_label)

    loss_ax.set_xlabel('epoch')
    loss_ax.set_ylabel('loss')
    acc_ax.set_ylabel('dice_coef')

    loss_ax.legend(loc='upper left')
    acc_ax.legend(loc='lower left')

    plt.show()

In [18]:
# Report the shape of every split.
for _name, _arr in (('X_train', X_train), ('X_valid', X_valid), ('X_test', X_test),
                    ('y_train', y_train), ('y_valid', y_valid), ('y_test', y_test)):
    print(_name, _arr.shape)

X_train (361, 256, 256, 3)
X_valid (91, 256, 256, 3)
X_test (114, 256, 256, 3)
y_train (361, 256, 256, 3)
y_valid (91, 256, 256, 3)
y_test (114, 256, 256, 3)
