<a href="https://colab.research.google.com/github/hyunaeee/PR_semantic_image_segmentation_unet/blob/main/Unet_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# unet_model.py
from keras.models import Model
from keras.backend import int_shape
from keras.layers import BatchNormalization, Conv2D, Conv2DTranspose, MaxPooling2D, Dropout, UpSampling2D, Input, concatenate


In [None]:
# train.py
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import sys
from PIL import Image
# ``glob`` returns paths in arbitrary (filesystem) order. Sort both lists so
# that ``zip(orgs, masks)`` pairs each image with its own mask.
# NOTE(review): assumes an image and its mask share the same filename in
# their respective folders — confirm for the dataset.
masks = sorted(glob.glob("./dataset/ct21/train/label/*.png"))
orgs = sorted(glob.glob("./dataset/ct21/train/image/*.png"))


In [None]:
# test.py
from PIL import Image
import numpy as np
import glob
# Evaluate the saved checkpoint on a held-out slice of the training data.
# ``glob`` order is arbitrary, so sort both lists to keep each image paired
# with its mask (assumes matching filenames in the two folders).
masks = sorted(glob.glob("./dataset/ct21/train/label/*.png"))
orgs = sorted(glob.glob("./dataset/ct21/train/image/*.png"))
if not orgs or not masks:
    # Without this check an empty dataset only surfaces later as an opaque
    # ``IndexError`` from ``y.shape[1]`` in the reshape below.
    raise FileNotFoundError(
        "no PNG files found under ./dataset/ct21/train/image or ./dataset/ct21/train/label"
    )
if len(orgs) != len(masks):
    raise ValueError(
        f"found {len(orgs)} images but {len(masks)} masks; they must pair one-to-one"
    )
imgs_list = []
masks_list = []
for image, mask in zip(orgs, masks):
    with Image.open(image) as img:
        imgs_list.append(np.array(img.resize((512, 512))))

    # Nearest-neighbour resampling keeps mask label values intact instead of
    # blending foreground and background along the edges.
    with Image.open(mask) as im:
        masks_list.append(np.array(im.resize((512, 512), resample=Image.NEAREST)))
imgs_np = np.asarray(imgs_list)
masks_np = np.asarray(masks_list)
# Scale to [0, 1] and add a trailing channel axis.
# NOTE(review): assumes single-channel 8-bit PNGs — confirm for this dataset.
x = np.asarray(imgs_np, dtype=np.float32) / 255
y = np.asarray(masks_np, dtype=np.float32) / 255
y = y.reshape(y.shape[0], y.shape[1], y.shape[2], 1)
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)

from sklearn.model_selection import train_test_split
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.1, random_state=0)
from unet_model import unet_model
input_shape = x_train[0].shape
model = unet_model(
    input_shape,
    num_classes=1,
    filters=64,
    dropout=0.2,
    num_layers=4,
    output_activation='sigmoid'
)
model_filename = 'segm_model_v0.h5'
model.load_weights(model_filename)
y_pred = model.predict(x_val)
from utils import plot_imgs
# Clamp so the plot never requests more samples than the validation split holds.
plot_imgs(
    org_imgs=x_val,
    mask_imgs=y_val,
    pred_imgs=y_pred,
    nm_img_to_plot=min(3, len(x_val)),
)

IndexError: ignored

In [None]:
def upsample_conv(filters, kernel_size, strides, padding):
    """Create a learnable upsampling layer (transposed convolution).

    Args:
        filters: Number of output channels.
        kernel_size: Size of the transposed-convolution kernel.
        strides: Upsampling factor per spatial axis.
        padding: Keras padding mode, e.g. ``'same'``.

    Returns:
        An un-applied ``Conv2DTranspose`` layer.
    """
    layer = Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
    )
    return layer
def upsample_simple(filters, kernel_size, strides, padding):
    """Create a parameter-free ``UpSampling2D`` layer scaled by ``strides``.

    ``filters``, ``kernel_size`` and ``padding`` are ignored. They are accepted
    only so this function can be swapped in for ``upsample_conv`` with the
    same call signature.

    Returns:
        An un-applied ``UpSampling2D`` layer.
    """
    upsampling_factor = strides
    return UpSampling2D(upsampling_factor)

In [None]:
def conv2d_block(
    inputs,
    use_batch_norm=True,
    dropout=0.3,
    filters=16,
    kernel_size=(3,3),
    activation='relu',
    kernel_initializer='he_normal',
    padding='same'):
    """Apply two identically configured ``Conv2D`` layers to ``inputs``.

    Layer sequence: Conv2D -> [BatchNorm] -> [Dropout] -> Conv2D -> [BatchNorm].
    Batch normalisation follows each convolution when ``use_batch_norm`` is
    true. Dropout is inserted only between the two convolutions, and only
    when ``dropout`` is positive.

    Returns:
        The output tensor of the final layer in the sequence.
    """
    def _conv(tensor):
        # Both convolutions share every hyper-parameter.
        return Conv2D(
            filters,
            kernel_size,
            activation=activation,
            kernel_initializer=kernel_initializer,
            padding=padding,
        )(tensor)

    out = _conv(inputs)
    if use_batch_norm:
        out = BatchNormalization()(out)
    if dropout > 0.0:
        out = Dropout(dropout)(out)

    out = _conv(out)
    if use_batch_norm:
        out = BatchNormalization()(out)
    return out

In [None]:
def unet_model(
    input_shape,
    num_classes=1,
    use_batch_norm=True,
    upsample_mode='deconv', # 'deconv' = Conv2DTranspose; anything else = UpSampling2D
    use_dropout_on_upsampling=False,
    dropout=0.3,
    dropout_change_per_layer=0.0,
    filters=16,
    num_layers=4,
    output_activation='sigmoid'): # 'sigmoid' or 'softmax'
    """Build a U-Net style encoder/decoder segmentation model.

    Encoder (contracting path):
        ``num_layers`` conv blocks, each followed by 2x2 max-pooling. The
        filter count doubles and the dropout rate grows by
        ``dropout_change_per_layer`` after every level. A bottleneck conv
        block follows the last pooling step.

    Decoder (expanding path):
        Mirrors the encoder. Each level upsamples by 2, concatenates the
        matching encoder output (skip connection), then applies a conv block
        with half the filters. Unless ``use_dropout_on_upsampling`` is true,
        decoder dropout is 0.

    A final 1x1 ``Conv2D`` with ``output_activation`` maps to ``num_classes``
    channels.

    NOTE(review): spatial input sizes presumably need to be divisible by
    ``2 ** num_layers`` so the upsampled tensors match their skip tensors in
    ``concatenate`` — confirm.

    Returns:
        An uncompiled Keras ``Model``.
    """
    upsample = upsample_conv if upsample_mode == 'deconv' else upsample_simple

    rate = dropout
    width = filters

    inputs = Input(input_shape)
    tensor = inputs

    # Contracting path: remember each level's output for the skip connections.
    skips = []
    for _ in range(num_layers):
        tensor = conv2d_block(inputs=tensor, filters=width,
                              use_batch_norm=use_batch_norm, dropout=rate)
        skips.append(tensor)
        tensor = MaxPooling2D((2, 2))(tensor)
        rate += dropout_change_per_layer
        width *= 2

    # Bottleneck.
    tensor = conv2d_block(inputs=tensor, filters=width,
                          use_batch_norm=use_batch_norm, dropout=rate)

    step = dropout_change_per_layer
    if not use_dropout_on_upsampling:
        rate, step = 0.0, 0.0

    # Expanding path: walk the skips from deepest to shallowest.
    for skip in skips[::-1]:
        width //= 2
        rate -= step
        tensor = upsample(width, (2, 2), strides=(2, 2), padding='same')(tensor)
        tensor = concatenate([tensor, skip])
        tensor = conv2d_block(inputs=tensor, filters=width,
                              use_batch_norm=use_batch_norm, dropout=rate)

    outputs = Conv2D(num_classes, (1, 1), activation=output_activation)(tensor)
    return Model(inputs=[inputs], outputs=[outputs])

In [None]:
# Load every image/mask pair resized to 512x512. ``orgs`` and ``masks`` are
# the file lists globbed earlier; they are paired positionally by ``zip``.
imgs_list = []
masks_list = []
for image, mask in zip(orgs, masks):
    with Image.open(image) as img:
        imgs_list.append(np.array(img.resize((512, 512))))

    # Nearest-neighbour resampling keeps mask label values intact instead of
    # blending foreground and background along the edges.
    with Image.open(mask) as im:
        masks_list.append(np.array(im.resize((512, 512), resample=Image.NEAREST)))
if not imgs_list:
    # An empty dataset would otherwise fail later with an opaque IndexError
    # in the reshape below.
    raise FileNotFoundError("no image/mask pairs were found to load")
imgs_np = np.asarray(imgs_list)
masks_np = np.asarray(masks_list)

# The split cell consumes ``x`` and ``y``, so build them here rather than
# relying on another cell having defined them: scale to [0, 1] and add a
# trailing channel axis.
# NOTE(review): assumes single-channel 8-bit PNGs — confirm for this dataset.
x = np.asarray(imgs_np, dtype=np.float32) / 255
y = np.asarray(masks_np, dtype=np.float32) / 255
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)
y = y.reshape(y.shape[0], y.shape[1], y.shape[2], 1)

In [None]:
from sklearn.model_selection import train_test_split
# Hold out half of the samples for validation; a fixed seed keeps the split
# reproducible between runs.
x_train, x_val, y_train, y_val = train_test_split(
    x,
    y,
    test_size=0.5,
    random_state=0,
)

In [None]:
from utils import get_augmented
# Random geometric augmentation settings handed to ``get_augmented``.
# NOTE(review): these keys look like Keras ImageDataGenerator arguments and are
# presumably applied to images and masks alike — confirm in utils.get_augmented.
augmentation_args = {
    'rotation_range': 15.,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 50,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'vertical_flip': True,
    'fill_mode': 'constant',
}
train_gen = get_augmented(
    x_train,
    y_train,
    batch_size=2,
    data_gen_args=augmentation_args,
)

In [None]:
# Pull one augmented batch and show it as a sanity check of the pipeline.
sample_batch = next(train_gen)
xx, yy = sample_batch
print(xx.shape, yy.shape)
# Use the project-local ``utils`` module, like the other cells that import
# ``plot_imgs`` / ``get_augmented``, instead of the external ``keras_unet``
# package, which may not be installed.
from utils import plot_imgs
plot_imgs(org_imgs=xx, mask_imgs=yy, nm_img_to_plot=2, figsize=6)

In [None]:
from unet_model import unet_model
# Build the U-Net for single-channel binary segmentation at the input size
# of the training samples.
input_shape = x_train[0].shape
model = unet_model(
    input_shape,
    num_classes=1,
    filters=64,
    dropout=0.2,
    num_layers=4,
    output_activation='sigmoid'
)
# ``summary()`` prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line.
model.summary()

In [None]:
from keras.callbacks import ModelCheckpoint
model_filename = 'segm_model_v0.h5'
# Save to ``model_filename`` only when validation loss improves on the best so far.
callback_checkpoint = ModelCheckpoint(
    model_filename,
    verbose=1,
    monitor='val_loss',
    save_best_only=True,
)
from keras.optimizers import Adam, SGD
from metrics import iou, iou_thresholded
model.compile(
    # ``learning_rate`` replaces the deprecated ``lr`` alias, which recent
    # Keras optimizer implementations reject outright.
    optimizer=SGD(learning_rate=0.01, momentum=0.99),
    loss='binary_crossentropy',
    metrics=[iou, iou_thresholded]
)

In [None]:
# ``Model.fit`` accepts Python generators directly. ``fit_generator`` is
# deprecated and no longer exists in recent Keras releases.
history = model.fit(
    train_gen,
    steps_per_epoch=100,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[callback_checkpoint]
)