## Abstract

The authors of the paper propose a simple and effective end-to-end image segmentation network architecture for medical images.
The proposed network, called U-Net, relies on three main factors to train well:
- U-shaped network structure with two parts: a contracting path and an expanding path
- Faster training than the sliding-window approach: patch units and the overlap-tile strategy
- Data augmentation: elastic deformation and weighted cross entropy

## Dataset

The dataset we used is the serial section Transmission Electron Microscopy (ssTEM) data set of the Drosophila first instar larva ventral nerve cord (VNC), which is downloaded from [ISBI Challenge: Segmentation of neural structures in EM stacks](http://brainiac2.mit.edu/isbi_challenge/home)


![ISBI](./images/ISBI.gif)


- Black-and-white segmentation of membranes and cells in EM (electron microscopy) images.
- The images are large but few in number, so data augmentation is needed.
- The data set contains 30 images of size 512x512 each for train, train-labels, and test.
- No test-label images are provided for the ISBI competition.
- To compute the competition's evaluation metrics yourself, set aside part of the training set for testing.
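
One way to hold out part of the training set, as the last bullet suggests, is sketched below. The function name `split_train_test`, the 80/20 ratio, and the fixed seed are assumptions for illustration and are not part of the original code.

```python
import numpy as np


def split_train_test(images, labels, test_ratio=0.2, seed=0):
    """Shuffle image/label pairs together and hold out a fraction for testing."""
    assert len(images) == len(labels)
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(images))          # same order for images and labels
    n_test = int(round(len(images) * test_ratio))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return images[train_idx], labels[train_idx], images[test_idx], labels[test_idx]


# Stand-in arrays with the ISBI shape: 30 slices of 512x512
images = np.zeros((30, 512, 512), dtype=np.uint8)
labels = np.zeros((30, 512, 512), dtype=np.uint8)
x_train, y_train, x_test, y_test = split_train_test(images, labels)
print(x_train.shape, x_test.shape)  # (24, 512, 512) (6, 512, 512)
```

Splitting before augmentation keeps augmented copies of a held-out slice from leaking into the training set.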


## Overlap-tile


Sliding Window
![sliding_window](./images/sliding_window.png) 


Patch
![patch](./images/patch.png) 


- The patch method has a low overlap ratio, so detection speed improves.
- However, because a large patch covers a wide area of the image at once, context is captured well but localization performance drops.
- In this paper, the U-Net architecture and the overlap-tile strategy were proposed to solve this localization problem.
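
To make the speed difference concrete, the sketch below counts how many windows each strategy has to evaluate. The image and window sizes are hypothetical, and the sliding window is assumed to move one pixel at a time.

```python
def count_windows(image_size, window_size, stride):
    """Number of window positions over a square image for a given stride."""
    per_axis = (image_size - window_size) // stride + 1
    return per_axis * per_axis


image_size, window_size = 512, 128  # hypothetical sizes

# Sliding window: neighbouring windows overlap almost completely
dense = count_windows(image_size, window_size, stride=1)
# Patch: windows do not overlap at all
tiled = count_windows(image_size, window_size, stride=window_size)

print(dense, tiled)  # 148225 16
```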


Overlap-tile
![overlap_tile](./images/overlap_tile.png)


The idea is simple. Because the EM image is large, it is processed tile by tile, and the input the model needs is larger than the patch being predicted (yellow). Where that input extends beyond the image border, the empty part is filled by mirroring the image.
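
The mirroring step can be sketched with `np.pad`. This is a minimal illustration on a tiny array, not the exact tiling code of the paper; the helper name `mirror_pad` and the margin are assumptions.

```python
import numpy as np


def mirror_pad(tile, margin):
    """Extend a 2-D tile on every side by mirroring the pixels next to its border."""
    return np.pad(tile, margin, mode="reflect")


tile = np.arange(16, dtype=np.float32).reshape(4, 4)
padded = mirror_pad(tile, margin=2)

print(padded.shape)     # (8, 8)
print(padded[0, 2:6])   # [ 8.  9. 10. 11.], a mirrored copy of row 2 of the tile
```

The augmentation code in this document pads with `mode="symmetric"`, which differs from `"reflect"` only in that the edge pixel itself is repeated.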

## Data Augmentation

We preprocessed the images for data augmentation. The following operations are applied:
   * Flip
   * Gaussian noise
   * Uniform noise
   * Brightness
   * Elastic deformation
   * Crop
   * Pad 
   
For an easy-to-follow explanation of these steps, refer to this [page](https://github.com/ugent-korea/pytorch-unet-segmentation/blob/master/README.md#preprocessing)
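
The simpler operations in the list can be sketched with NumPy as follows. These functions are hypothetical stand-ins for the helpers imported from `utills` in the code below, whose actual implementation is not shown here, so signatures and clipping behaviour are assumptions.

```python
import numpy as np


def flip(image, option):
    """0: vertical, 1: horizontal, 2: both, anything else: unchanged."""
    if option == 0:
        return np.flip(image, axis=0)
    if option == 1:
        return np.flip(image, axis=1)
    if option == 2:
        return np.flip(image, axis=(0, 1))
    return image


def add_gaussian_noise(image, mean, std, rng):
    """Add Gaussian noise and clip back to the uint8 range."""
    noise = rng.normal(mean, std, image.shape)
    return np.clip(image.astype(np.float32) + noise, 0, 255).astype(np.uint8)


def change_brightness(image, value):
    """Shift all pixel values by `value` and clip back to the uint8 range."""
    return np.clip(image.astype(np.int16) + value, 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
img = np.full((4, 4), 250, dtype=np.uint8)
img[0, 0] = 0

print(flip(img, 2)[-1, -1])                # 0: the top-left pixel moved to the bottom-right
print(change_brightness(img, 20).max())    # 255: values are clipped, not wrapped around
print(add_gaussian_noise(img, 0, 10, rng).dtype)  # uint8
```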

In [2]:
import glob
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img,array_to_img
from utills import *
from random import randint
from PIL import Image, ImageSequence
import os
import numpy as np


class dataAugment(object):
    def __init__(self, out_rows, out_cols, data_path="./data"):
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.data_path = data_path

    def augmentation(self):
        # read images
        print('-' * 30)
        print('Augment train images...')
        print('-' * 30)

        # Create directory
        if not os.path.exists(self.data_path + "/raw/images"):
            os.makedirs(self.data_path + "/raw/images")
        if not os.path.exists(self.data_path + "/raw/labels"):
            os.makedirs(self.data_path + "/raw/labels")
        if not os.path.exists(self.data_path + "/aug/images"):
            os.makedirs(self.data_path + "/aug/images")
        if not os.path.exists(self.data_path + "/aug/labels"):
            os.makedirs(self.data_path + "/aug/labels")
        if not os.path.exists(self.data_path + "/npy"):
            os.makedirs(self.data_path + "/npy")

        # Split isbi tif image&label to single frame of png images
        isbi_img = Image.open(self.data_path + "/train-volume.tif")  # raw image from isbi dataset
        for i, page in enumerate(ImageSequence.Iterator(isbi_img)):
            page.save(self.data_path+"/raw/images/" + str(i) + ".png")

        isbi_lbl = Image.open(self.data_path + "/train-labels.tif")  # raw image from isbi dataset
        for i, page in enumerate(ImageSequence.Iterator(isbi_lbl)):
            page.save(self.data_path+"/raw/labels/" + str(i) + ".png")

        train_imgs = glob.glob(self.data_path + "/raw/images/*.png")
        label_imgs = glob.glob(self.data_path + "/raw/labels/*.png")
        slices = len(train_imgs)
        if len(train_imgs) != len(label_imgs) or len(train_imgs) == 0:
            print("trains can't match labels")
            return 0

        print('Using real-time data augmentation. len: ', slices)
        # one by one augmentation
        batch_size = 30  # one frame for 30 images augment
        for i in range(slices):
            for b in range(batch_size):
                img_as_img = Image.open(self.data_path + "/raw/images/" + str(i) + ".png")
                lbl_as_img = Image.open(self.data_path + "/raw/labels/" + str(i) + ".png")
                img_as_np = np.asarray(img_as_img)
                lbl_as_np = np.asarray(lbl_as_img)

                # flip {0: vertical, 1: horizontal, 2: both, 3: none}
                flip_num = randint(0, 3)
                img_as_np = flip(img_as_np, flip_num)
                lbl_as_np = flip(lbl_as_np, flip_num)

                # Noise: randomly pick one {1: Gaussian noise, 0: uniform noise}
                if randint(0, 1):
                    # Gaussian_noise
                    gaus_sd, gaus_mean = randint(0, 20), 0
                    img_as_np = add_gaussian_noise(img_as_np, gaus_mean, gaus_sd)
                    lbl_as_np = add_gaussian_noise(lbl_as_np, gaus_mean, gaus_sd)
                else:
                    # uniform_noise
                    l_bound, u_bound = randint(-20, 0), randint(0, 20)
                    img_as_np = add_uniform_noise(img_as_np, l_bound, u_bound)
                    lbl_as_np = add_uniform_noise(lbl_as_np, l_bound, u_bound)

                # Brightness
                pix_add = randint(-20, 20)
                img_as_np = change_brightness(img_as_np, pix_add)
                lbl_as_np = change_brightness(lbl_as_np, pix_add)

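                # Elastic distortion (always applied); sigma is drawn at random
                sigma = randint(6, 12)
                # sigma = 4, alpha = 34
                img_as_np, seed = add_elastic_transform(img_as_np, alpha=34, sigma=sigma, pad_size=20)
                # Reuse the returned seed so the label is warped exactly like the image
                # (assumes add_elastic_transform accepts a `seed` argument)
                lbl_as_np, _ = add_elastic_transform(lbl_as_np, alpha=34, sigma=sigma, seed=seed, pad_size=20)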

                # Crop the image
                in_size = 512
                out_size = 388
                img_height, img_width = img_as_np.shape[0], img_as_np.shape[1]
                pad_size = int((in_size - out_size)/2)
                img_as_np = np.pad(img_as_np, pad_size, mode="symmetric")
                lbl_as_np = np.pad(lbl_as_np, pad_size, mode="symmetric")
                y_loc, x_loc = randint(0, img_height-out_size), randint(0, img_width-out_size)
                img_as_np = cropping(img_as_np, crop_size=in_size, dim1=y_loc, dim2=x_loc)
                lbl_as_np = cropping(lbl_as_np, crop_size=in_size, dim1=y_loc, dim2=x_loc)

                # Normalize the image
                img_as_np = normalization2(img_as_np, max=1, min=0)
                img = img_as_np.reshape(img_as_np.shape[0], img_as_np.shape[1], 1)
                img = array_to_img(img)
                img.save(self.data_path + "/aug/images/" + str(30*i+b) + ".png")

                lbl = lbl_as_np.reshape(lbl_as_np.shape[0], lbl_as_np.shape[1], 1)
                lbl = array_to_img(lbl)
                lbl.save(self.data_path + "/aug/labels/" + str(30*i+b) + ".png")

            print(str(i+1))


if __name__ == "__main__":
    try:
        mydata = dataAugment(512, 512)
        mydata.augmentation()
    except RuntimeError as e:
        print(e)


------------------------------
Augment train images...
------------------------------
Using real-time data augmentation. len:  30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
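
The elastic deformation only helps if the image and its label are warped by the same displacement field. The sketch below shows one common formulation (a random displacement field smoothed with a Gaussian filter) and is not the `add_elastic_transform` helper used above, whose implementation is not shown. The function name, the `order` parameter, and the use of SciPy are assumptions.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_transform(image, alpha, sigma, seed=None, order=1):
    """Warp a 2-D image with a smooth random displacement field.

    The same seed always produces the same field, so an image and its
    label can be deformed consistently.
    """
    rng = np.random.RandomState(seed)
    shape = image.shape
    # Random displacements in [-1, 1], smoothed by sigma and scaled by alpha
    dx = gaussian_filter(rng.rand(*shape) * 2 - 1, sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter(rng.rand(*shape) * 2 - 1, sigma, mode="constant", cval=0) * alpha
    y, x = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    coords = np.array([y + dy, x + dx])
    return map_coordinates(image, coords, order=order, mode="reflect")


rng = np.random.RandomState(1)
image = rng.rand(64, 64)
label = (image > 0.5).astype(np.float64)

# order=0 (nearest neighbour) keeps the label binary
warped_image = elastic_transform(image, alpha=34, sigma=6, seed=42, order=0)
warped_label = elastic_transform(label, alpha=34, sigma=6, seed=42, order=0)
print(np.array_equal(warped_label, (warped_image > 0.5).astype(np.float64)))
```

With a shared seed, thresholding the warped image gives exactly the warped label, which confirms that both went through the same deformation.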


## Data Pre-processing

This step pre-processes the data after converting .tif to .png.

It creates the train, train-label, and test PNG image files.

In [4]:
from tensorflow.keras.preprocessing.image import img_to_array, load_img,array_to_img
import numpy as np 
import glob
import tensorflow as tf

class dataProcess(object):
    def __init__(self, out_rows, out_cols, data_path = "./data", img_type = "tif"):
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.data_path = data_path
        self.img_type = img_type
    
    # Image to numpy 
    def create_train_data(self):
        i = 0  # running index into the output arrays
        print('-'*30)
        print('Creating training images...')
        print('-'*30)
        # NOTE: these directories and img_type must match where the image files actually live;
        # the augmentation step above writes PNG files to raw/images, raw/labels, aug/images and aug/labels.
        imgs = glob.glob(self.data_path+"/raw/train/*."+self.img_type)
        augimgs = glob.glob(self.data_path+"/aug/train/*."+self.img_type)
        print("original images",len(imgs))
        print("augmented images",len(augimgs))
        imgdatas = np.ndarray((len(imgs)+len(augimgs),self.out_rows,self.out_cols,1), dtype=np.uint8)
        imglabels = np.ndarray((len(imgs)+len(augimgs),self.out_rows,self.out_cols,1), dtype=np.uint8)
        # Original images fill the first len(imgs) slots
        for imgname in imgs:
            midname = imgname[imgname.rindex("/")+1:]
            img = load_img(self.data_path + "/raw/train/" + midname,color_mode = "grayscale")
            label = load_img(self.data_path + "/raw/label/" + midname,color_mode = "grayscale")
            img = img_to_array(img)
            label = img_to_array(label)
            imgdatas[i] = img
            imglabels[i] = label
            i += 1
            if i % 30 == 0:
                print('Done: {0}/{1} images'.format(i, len(imgs)))
        # Augmented images continue after the originals instead of overwriting them
        for imgname in augimgs:
            midname = imgname[imgname.rindex("/")+1:]
            img = load_img(self.data_path + "/aug/train/" + midname,color_mode = "grayscale")
            label = load_img(self.data_path + "/aug/label/" + midname,color_mode = "grayscale")
            img = img_to_array(img)
            label = img_to_array(label)
            imgdatas[i] = img
            imglabels[i] = label
            i += 1
            done = i - len(imgs)
            if done % 100 == 0:
                print('Done: {0}/{1} images'.format(done, len(augimgs)))
    
        print('loading done')
        np.save(self.data_path + '/npy/imgs_train.npy', imgdatas)
        np.save(self.data_path + '/npy/imgs_mask_train.npy', imglabels)
        print('Saving to .npy files done.')
    
    def create_test_data(self):
        print('-'*30)
        print('Creating test images...')
        print('-'*30)
        imgs = glob.glob(self.data_path+"/raw/test/*."+self.img_type)
        imgdatas = np.ndarray((len(imgs),self.out_rows,self.out_cols,1), dtype=np.uint8)
        for ind in range(len(imgs)):
            # Test frames are expected to be named 0.<ext>, 1.<ext>, ...
            img = load_img(self.data_path + "/raw/test/" + str(ind) + "." + self.img_type, color_mode = "grayscale")
            img = img_to_array(img)
            imgdatas[ind] = img
        print('loading done')
        np.save(self.data_path + '/npy/imgs_test.npy', imgdatas)
        print('Saving to imgs_test.npy files done.')
    
    
    # Masking and Labeling for training
    def load_train_data(self):
        print('-'*30)
        print('load train images...')
        print('-'*30)
        imgs_train = np.load(self.data_path+"/npy/imgs_train.npy")
        imgs_mask_train = np.load(self.data_path+"/npy/imgs_mask_train.npy")
        imgs_train = imgs_train.astype('float32')
        imgs_mask_train = imgs_mask_train.astype('float32')
        print(imgs_train)
        imgs_train /= 255 # scale grayscale values to 0~1
        imgs_mask_train /= 255
        imgs_mask_train[imgs_mask_train > 0.5] = 1 # Cell (white in the ISBI labels)
        imgs_mask_train[imgs_mask_train <= 0.5] = 0 # Membrane (black in the ISBI labels)
        return imgs_train,imgs_mask_train
    
    def load_test_data(self):
        print('-'*30)
        print('load test images...')
        print('-'*30)
        imgs_test = np.load(self.data_path+"/npy/imgs_test.npy")
        imgs_test = imgs_test.astype('float32')
        imgs_test /= 255
        return imgs_test

if __name__ == "__main__":
    try:
        with tf.device('/device:GPU:1'):
            mydata = dataProcess(512,512)
            mydata.create_train_data()
            mydata.create_test_data()
    except RuntimeError as e:
      print(e)

------------------------------
Creating training images...
------------------------------
original images 0
augmented images 0
loading done
Saving to .npy files done.
------------------------------
Creating test images...
------------------------------
loading done
Saving to imgs_test.npy files done.


## Network Architecture

![Unet](./images/unet.png)

### Contracting Path(Fully Convolution)
- A typical convolutional network.
- 3x3 convolutions, max-pooling, and dropout
- Extracts image features accurately, but the feature map size shrinks
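
How quickly the feature map shrinks can be traced with simple arithmetic. The sketch below is a hypothetical helper, not part of the model code: with `padding="same"` (as in the Keras model in this document) only the pooling reduces the size, while with `padding="valid"` (as in the original paper, with a 572x572 input) every 3x3 convolution also trims 2 pixels.

```python
def contracting_sizes(input_size, depth=4, padding="same"):
    """Spatial size of the feature map after the convolutions at each level."""
    sizes = []
    size = input_size
    for _ in range(depth):
        if padding == "valid":
            size -= 4        # two 3x3 convolutions, each trims 2 pixels
        sizes.append(size)   # this map is also what the skip connection carries
        size //= 2           # 2x2 max-pooling halves the size
    if padding == "valid":
        size -= 4
    sizes.append(size)       # bottom of the U
    return sizes


print(contracting_sizes(512, padding="same"))   # [512, 256, 128, 64, 32]
print(contracting_sizes(572, padding="valid"))  # [568, 280, 136, 64, 28]
```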


### Expanding Path(Deconvolution)
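- Restores the reduced feature map to its original size and outputs the segmentation map
- 2x2 up-convolution, 3x3 convolution, and concatenation
- The drawback of the expanding process is that localization information is lost
- Therefore, the feature map after each up-convolution is combined with the feature map of the same level from the contracting path to provide localization information
- The last layer is a 1x1 convolution mapping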
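
The concatenation step can be illustrated with plain NumPy arrays. This is a shape-only sketch under simplifying assumptions: nearest-neighbour upsampling stands in for the up-convolution, the channel-reducing 2x2 convolution is omitted, and the array names are hypothetical.

```python
import numpy as np


def upsample2x(feature_map):
    """Nearest-neighbour 2x upsampling of an (H, W, C) feature map."""
    return feature_map.repeat(2, axis=0).repeat(2, axis=1)


decoder = np.ones((32, 32, 512), dtype=np.float32)   # coarse map coming up from the level below
encoder = np.zeros((64, 64, 512), dtype=np.float32)  # same-level map from the contracting path

up = upsample2x(decoder)                          # (64, 64, 512)
merged = np.concatenate([encoder, up], axis=-1)   # stack along the channel axis

print(up.shape, merged.shape)  # (64, 64, 512) (64, 64, 1024)
```

The following 3x3 convolutions then see both the coarse context (from `up`) and the fine spatial detail (from `encoder`) at every pixel.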

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, UpSampling2D, concatenate
from tensorflow.keras.optimizers import Adam


def unet(pretrained_weights = None,input_size = (512,512,1)):
    inputs = Input(input_size)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4) # for crop and copy
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3) # Concatenate for localization information
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    model.compile(optimizer = Adam(learning_rate = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

    #model.summary()

    if(pretrained_weights):
        model.load_weights(pretrained_weights)

    return model
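
The abstract lists weighted cross entropy, while the model above is compiled with plain `binary_crossentropy`. A minimal NumPy sketch of a per-pixel weighted binary cross entropy is given below. The function name and the example weight map are hypothetical, and computing the border-emphasizing weight map described in the paper is not covered here.

```python
import numpy as np


def weighted_binary_crossentropy(y_true, y_pred, weight_map, eps=1e-7):
    """Binary cross entropy averaged with a per-pixel weight."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)   # avoid log(0)
    bce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.sum(weight_map * bce) / np.sum(weight_map))


y_true = np.array([[1.0, 0.0], [1.0, 0.0]])
y_pred = np.array([[0.9, 0.1], [0.6, 0.4]])   # second row is predicted less confidently

uniform = weighted_binary_crossentropy(y_true, y_pred, np.ones_like(y_true))
emphasized = weighted_binary_crossentropy(
    y_true, y_pred, np.array([[1.0, 1.0], [5.0, 5.0]])
)
print(uniform, emphasized)  # the second value is larger
```

With uniform weights this reduces to the ordinary mean binary cross entropy; raising the weight of the poorly predicted pixels raises the loss, which is how a weight map steers training toward them.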


## Train and Test

In [None]:
import os
import numpy as np
from preprocessing import *
from model import unet
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tifffile import imwrite as tifsave  # imsave is a deprecated alias of imwrite


class myUnet(object):

    def __init__(self, img_rows=512, img_cols=512, save_path="./results/"):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.save_path = save_path

    def load_data(self):
        mydata = dataProcess(self.img_rows, self.img_cols)
        imgs_train, imgs_mask_train = mydata.load_train_data()
        imgs_test = mydata.load_test_data()
        print(imgs_test)
        return imgs_train, imgs_mask_train, imgs_test

    def train(self, load_pretrained):
        print("loading data")
        model_name = 'my_model.h5'
        log_dir = "logs/000"
        logging = TensorBoard(log_dir=log_dir)
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)
        early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)
        imgs_train, imgs_mask_train, imgs_test = self.load_data()
        print("loading data done")
        if load_pretrained:
            model = load_model(model_name)
            model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
            model_checkpoint = ModelCheckpoint('unet.h5', monitor='val_loss', verbose=1, save_best_only=True)
            model.fit(imgs_train, imgs_mask_train, batch_size=2, epochs=30, verbose=1,
                      validation_split=0.2, shuffle=True, callbacks=[logging, model_checkpoint, reduce_lr])
            model.save(model_name)
        else:
            model = unet()
            model.summary()
            model_checkpoint = ModelCheckpoint('unet.h5', monitor='val_loss', verbose=1, save_best_only=True)
            model.fit(imgs_train, imgs_mask_train, batch_size=2, epochs=30, verbose=1,
                      validation_split=0.2, shuffle=True,
                      callbacks=[logging, model_checkpoint, reduce_lr, early_stopping])
            model.save(model_name)

    def test(self):
        model_name = 'my_model.h5'
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

        imgs_train, imgs_mask_train, imgs_test = self.load_data()
        model = load_model(model_name)
        imgs_mask_test = model.predict(imgs_test, batch_size=2, verbose=1)
        np.save(self.save_path + "imgs_mask_test.npy", imgs_mask_test)

        print("array to image")
        imgs = np.load(self.save_path + "imgs_mask_test.npy")
        total = []
        for i in range(imgs.shape[0]):
            img = imgs[i]
            img[img > 0.5] = 1
            img[img <= 0.5] = 0
            total.append(img)
        np_total = np.array(total)
        tifsave("./prediction.tif", np_total)


if __name__ == '__main__':
    myunet = myUnet()
    myunet.train(load_pretrained=False)
    myunet.test()


## Result

![test-volume](./images/test-volume.tif) ![test](./images/total_result_aug_10.tif)