### UNET-seResNet18 segmentation network


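This notebook trains a U-Net segmentation model on the T0 dataset. The decoder is trained from scratch, while the seResNet18 encoder is initialised with ImageNet weights.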

Feel free to change the *path_images* and *path_masks* parameters in the cell below to process another dataset.

Path to the dataset

In [None]:
# T0 dataset: one folder with the input images, one with the matching masks
data_root = "data/T0/"
path_images = data_root + "images/"
path_masks  = data_root + "masks/"

Imports

In [None]:
import os
import numpy as np
from PIL import Image
import segmentation_models as sm
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import keras.backend as K
from keras.optimizers import Adam


Useful functions

In [None]:
def load_images_from_directory(folder,resize,filextension):
    '''
    Read the images of a directory, resize them and stack them into a Numpy array.

    folder      : path to the folder containing the images
    resize      : target size (rows, cols) of each image in the returned array
    filextension: suffix used to select files, e.g. 'tif' (case-sensitive match
                  on the end of the file path; any suffix is accepted)

    Returns a float32 array of shape (N, rows, cols, C). Single-channel images
    get a trailing channel axis of size 1. Files are read in sorted filename
    order, which is what pairs images with masks when two folders are loaded
    (assumes both folders sort into matching order). If no file matches, an
    empty float32 array is returned.

    Use : X = load_images_from_directory('foo/bar/', (128,128), 'tif')
    '''
    img_rows, img_cols = resize[0], resize[1]
    arrays = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        if not path.endswith(filextension):
            continue
        # The context manager releases the file handle once the pixels are read.
        with Image.open(path) as img:
            # PIL sizes are (width, height) = (cols, rows), so the resulting
            # array has shape (rows, cols). NEAREST keeps mask label values
            # intact (no interpolated intermediate values).
            resized = img.resize((img_cols, img_rows), Image.NEAREST)
        arrays.append(np.array(resized))
    imgarray = np.asarray(arrays).astype('float32')

    # Grayscale images give (N, rows, cols): add an explicit channel axis.
    if imgarray.ndim == 3:
        imgarray = np.expand_dims(imgarray, axis=3)

    return imgarray

Parameters

In [None]:
# Target size passed to load_images_from_directory for every image and mask
target_size = (128, 128)

# Encoder architecture name handed to segmentation_models
BACKBONE = "seresnet18"

# Optimisation settings: batch size and Adam learning rate
bs, lr = 32, 1e-4

# Training schedule
epochs, steps_per_epoch = 20, 400
validation_steps = 1


Read, prepare and split the dataset

In [None]:
# Read images and masks.
# Images and masks are paired by position (sorted filename order in each folder).
X = load_images_from_directory(path_images,target_size,'tif')
y = load_images_from_directory(path_masks,target_size,'tif')

print (X.shape[0], ": total images read from directory")
print (y.shape[0], ": total masks  read from directory")
# Pairing is positional, so a count mismatch would misalign images and masks.
assert X.shape[0] == y.shape[0], "number of images and masks differ"


#Preprocess images and masks

# The ImageNet-initialised encoder expects 3 channels: replicate grayscale images only.
if X.shape[-1] == 1:
    X = np.concatenate((X,X,X),axis=3)
preprocess_input = sm.get_preprocessing(BACKBONE)
X = preprocess_input(X)

# NOTE(review): assumes masks are stored as 0/255 -- this scales them to 0/1.
y = y/255

#Split (fixed random_state so the train/validation split is reproducible)

X_train,X_valid,y_train,y_valid = train_test_split(X,y,test_size=0.1,random_state=1)

print(X_train.shape[0], " : training images")
print(X_valid.shape[0], " : validation images")

Data Augmentation

In [None]:
# Geometric augmentations, shared by images and masks
aug_args = dict(
    horizontal_flip=True,
    vertical_flip=True,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    fill_mode="reflect",
)
datagen = ImageDataGenerator(**aug_args)

# Same generator and same seed for both flows, so each augmented image gets the
# same shuffling/transform as its mask (the usual Keras paired-augmentation pattern).
# NOTE(review): shifts/zoom may interpolate mask values at object borders -- confirm
# this is acceptable for the Dice loss.
it = datagen.flow(X_train, batch_size=bs, seed=1)
it2 = datagen.flow(y_train, batch_size=bs, seed=1)

# Yields (image_batch, mask_batch) pairs
train_generator = zip(it, it2)

Check that augmentation is working (each augmented image must stay aligned with its mask)

In [None]:
# Display the first image of one augmented batch next to its mask, to check they stay aligned.
# Note: this pulls one batch out of train_generator before training starts.
a,b = next(train_generator)

fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))

# one image (first channel)
ax_img.imshow(a[0,:,:,0],cmap="gray")
ax_img.set_title("Augmented image")

# corresponding mask, binarised at 0.5
ax_mask.imshow(b[0,:,:,0]>0.5,cmap="gray")
ax_mask.set_title("Augmented mask (> 0.5)")

plt.show()

Creation of the UNET-seResNet18 model (ImageNet-pretrained encoder) with the `segmentation_models` [library](https://github.com/qubvel/segmentation_models)

In [None]:
# Reset Keras' global state so re-running this cell starts from a clean slate
K.clear_session()

# U-Net whose encoder starts from ImageNet weights; a single sigmoid output
# channel for the binary mask. Input shape comes from the preprocessed images.
model = sm.Unet(
    BACKBONE,
    encoder_weights='imagenet',
    input_shape=X.shape[1:],
    classes=1,
    activation='sigmoid',
)



Loss (Dice loss) and metric (Intersection over Union, IoU)

In [None]:
# Optimise the Dice loss with Adam and report IoU as the monitoring metric.
# NOTE(review): newer Keras releases rename the `lr` keyword to `learning_rate`.
model.compile(
    optimizer=Adam(lr=lr),
    loss=sm.losses.dice_loss,
    metrics=[sm.metrics.iou_score],
)

Model training

In [None]:
# Train on the augmented (image, mask) stream; validate on the un-augmented
# held-out split. H is the returned training history.
H = model.fit_generator(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=(X_valid, y_valid),
    validation_steps=validation_steps,
    verbose=1,
)

Predict on the validation set and check the segmentation quality on one validation image (selected by the index `img`)

In [None]:
# Predicted masks for the whole validation set
p = model.predict(x=X_valid)

In [None]:

# Index of the validation sample to inspect (change it to look at other samples)
img = 1

# Input, ground truth and thresholded prediction side by side
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
panels = [
    (np.squeeze(X_valid[img][:,:,0]), "Input image (channel 0)"),
    (np.squeeze(y_valid[img]), "Ground-truth mask"),
    (np.squeeze(p[img] > 0.5), "Predicted mask (p > 0.5)"),
]
for ax, (data, title) in zip(axes, panels):
    ax.imshow(data, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()

Save the trained segmentation model

In [None]:
# The file name is built from the Parameters cell so it always matches the trained
# configuration (with the default parameters it is
# "models/T0/jerome_128x128_seresnet18_e20_spe400.h5").
model_path = "models/T0/jerome_{}x{}_{}_e{}_spe{}.h5".format(
    target_size[0], target_size[1], BACKBONE, epochs, steps_per_epoch)
# Create the target directory first; saving into a missing directory fails.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model.save(model_path)