This tutorial example applies CNN image segmentation to salt identification. The train and test data were obtained from the 2018 TGS Salt Identification Challenge on Kaggle: https://www.kaggle.com/c/tgs-salt-identification-challenge/ . Most of the code in this notebook was taken from https://www.kaggle.com/keegil/keras-u-net-starter-lb-0-277 by Kjetil Åmdal-Sævik.

The train data consist of 4000 seismic patches (101 by 101 pixels each) and 4000 corresponding masks that mark where salt is present. The test data consist of 18000 seismic patches of the same size. The aim of the competition was to predict, for each test patch, which regions are salt and which are not.

In this example, we will cover how to use the U-Net model (a popular type of image segmentation model) to identify salt regions in seismic image patches. The following steps are covered:

* Loading the data and preparation
* Training the U-Net
* Saving the trained model
* Prediction
* Visualizations

In [None]:
# import required libraries and packages

import os, sys, random, warnings
from tqdm import tqdm_notebook

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label

from keras.models import Model, load_model
from keras.layers import Input
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K
from keras.preprocessing.image import load_img, img_to_array

import tensorflow as tf

In [None]:
# Plot a handful of seismic patches, each followed by its salt mask

ids = ['0a1742c740', '5b7c160d0d', '6c40978ddf', '7dfdf6eeb8', '1ee0d5b4d0']
# Two slots per id (image + mask); the "1 +" leaves two spare slots at the end.
n_cols = 2 * (1 + len(ids))

plt.figure(figsize=(20, 10))
for q, img_name in enumerate(ids, start=1):
    file_name = img_name + '.png'
    img = load_img('./train/images/' + file_name)
    img_mask = load_img('./train/masks/' + file_name)

    # Odd slot: the seismic patch; the even slot right after it: its mask.
    plt.subplot(1, n_cols, 2 * q - 1)
    plt.imshow(img)
    plt.subplot(1, n_cols, 2 * q)
    plt.imshow(img_mask)
plt.show()


# Dark patches indicate no salts in the image

### Loading the data and preparation

Getting the training and test image ids (file names) used to load the image and mask data

In [None]:
# Root folders of the competition data
path_train = './train/'   # training images and masks live below this folder
path_test = './test/'     # test images live below this folder

# The first item produced by os.walk describes the folder itself as
# (dirpath, dirnames, filenames); element 2 is its list of file names.
train_ids = next(os.walk(os.path.join(path_train, "images")))[2]
test_ids = next(os.walk(os.path.join(path_test, "images")))[2]

In [None]:
# Number of training and test image files that were found
len(train_ids), len(test_ids)

In [None]:
# Peek at the first five training file names
train_ids[:5]

Image and mask patches will be resized to 64px by 64px from original 101px by 101px

In [None]:
# Shape of one sample as it is stored in the data arrays and fed to the network
im_height = 64    # new height, in pixels
im_width = 64    # new width, in pixels
im_chan = 1     # number of image channels

Creating empty arrays to hold the resized train images and masks

In [None]:
# Pre-allocate the resized training data: one (im_height, im_width, im_chan)
# slot per training id. Images are stored as 8-bit integers, masks as booleans.
X_train = np.zeros((len(train_ids), im_height, im_width, im_chan), dtype=np.uint8)
# `np.bool` was only an alias of the builtin `bool` and was removed in
# NumPy 1.24; the builtin gives the same dtype on every NumPy version.
Y_train = np.zeros((len(train_ids), im_height, im_width, im_chan), dtype=bool)

In [None]:
# Shapes of the freshly allocated image and mask arrays
X_train.shape, Y_train.shape

In [None]:
# Inspect the first image slot (all zeros when the cells are run in order,
# because nothing has been loaded into the array yet)
X_train[0]

In [None]:
# Inspect the second mask slot (all False when the cells are run in order,
# because nothing has been loaded into the array yet)
Y_train[1]

Resizing train images and masks

In [None]:
# Load every training image and its mask, shrink both to the network input
# size and store them in the pre-allocated X_train / Y_train arrays.
path = path_train
for n, id_ in tqdm_notebook(enumerate(train_ids), total=len(train_ids)):

    img = load_img(path + '/images/' + id_)
    # Keep a single channel (index 1) of the loaded image.
    # NOTE(review): presumably all channels are identical for these
    # grayscale patches - confirm against the data.
    x = img_to_array(img)[:,:,1]
    # Resize to the configured size rather than a hard-coded 64x64, so that
    # changing im_height / im_width keeps this loop consistent with the
    # arrays. The trailing 1 is the single channel extracted above.
    x = resize(x, (im_height, im_width, 1), mode='constant', preserve_range=True)
    X_train[n] = x
    mask = img_to_array(load_img(path + '/masks/' + id_))[:,:,1]
    # Y_train is boolean, so every non-zero pixel of the resized mask is
    # stored as True.
    Y_train[n] = resize(mask, (im_height, im_width, 1), mode='constant', preserve_range=True)

print('Done!')

In [None]:
# Confirm the shape of the training image array after loading
X_train.shape

In [None]:
# Quality check to see if training data looks all right

# random.randint includes its upper bound, so randint(0, len(train_ids))
# could return an index one past the end; randrange excludes the bound.
ix = random.randrange(len(train_ids))
# Stack the single-channel patch into three identical channels for display.
plt.imshow(np.dstack((X_train[ix],X_train[ix],X_train[ix])))
plt.show()
# Convert the boolean mask to floats and show it the same way.
tmp = np.squeeze(Y_train[ix]).astype(np.float32)
plt.imshow(np.dstack((tmp,tmp,tmp)))
plt.show()

Loading and resizing test image patches

In [None]:
# Creating the empty test image array: one (im_height, im_width, im_chan)
# slot of 8-bit integers per test id

X_test = np.zeros((len(test_ids), im_height, im_width, im_chan), dtype=np.uint8)

In [None]:
# Shape of the freshly allocated test image array
X_test.shape

In [None]:
# Load every test image, shrink it to the network input size and store it
# in the pre-allocated X_test array.
path = path_test
for n, id_ in tqdm_notebook(enumerate(test_ids), total=len(test_ids)):

    img = load_img(path + '/images/' + id_)
    # Keep a single channel (index 1) of the loaded image.
    # NOTE(review): presumably all channels are identical for these
    # grayscale patches - confirm against the data.
    x = img_to_array(img)[:,:,1]
    # Resize to the configured size rather than a hard-coded 64x64, so that
    # changing im_height / im_width keeps this loop consistent with X_test.
    # The trailing 1 is the single channel extracted above.
    x = resize(x, (im_height, im_width, 1), mode='constant', preserve_range=True)
    X_test[n] = x

print('Done!')

In [None]:
# Final shapes of the train images, test images and train masks
X_train.shape, X_test.shape, Y_train.shape

### Model Training with the U-Net Architecture

![U-Net architecture.png](images/u-net-architecture.png)

In [None]:
# Build only the input scaling and the first pair of convolutions so that the
# resulting tensor can be inspected before the full model is assembled.
inputs = Input((im_height, im_width, im_chan))
s = Lambda(lambda x: x / 255) (inputs)
c1 = Conv2D(8, (3, 3), activation='relu', padding='same') (s)
c1 = Conv2D(8, (3, 3), activation='relu', padding='same') (c1)
# Last expression of the cell: displays the tensor produced by the second convolution
c1

In [None]:
# Display the tensor produced by the input-scaling Lambda layer
s

In [None]:
# Build U-Net model
# The network takes one (im_height, im_width, im_chan) patch and returns a
# map with one sigmoid value per pixel.
inputs = Input((im_height, im_width, im_chan))
# Divide the pixel values by 255 inside the model
s = Lambda(lambda x: x / 255) (inputs)


# CONTRACTING and ENCODING PART
# Each stage applies two 3x3 ReLU convolutions ('same' padding keeps the
# spatial size) followed by a 2x2 max-pooling; the filter count doubles from
# stage to stage (8, 16, 32, 64).
c1 = Conv2D(8, (3, 3), activation='relu', padding='same') (s)
c1 = Conv2D(8, (3, 3), activation='relu', padding='same') (c1)
p1 = MaxPooling2D((2, 2)) (c1)

c2 = Conv2D(16, (3, 3), activation='relu', padding='same') (p1)
c2 = Conv2D(16, (3, 3), activation='relu', padding='same') (c2)
p2 = MaxPooling2D((2, 2)) (c2)

c3 = Conv2D(32, (3, 3), activation='relu', padding='same') (p2)
c3 = Conv2D(32, (3, 3), activation='relu', padding='same') (c3)
p3 = MaxPooling2D((2, 2)) (c3)

c4 = Conv2D(64, (3, 3), activation='relu', padding='same') (p3)
c4 = Conv2D(64, (3, 3), activation='relu', padding='same') (c4)
p4 = MaxPooling2D(pool_size=(2, 2)) (c4)

# Bottom of the "U": two 128-filter convolutions, no pooling afterwards
c5 = Conv2D(128, (3, 3), activation='relu', padding='same') (p4)
c5 = Conv2D(128, (3, 3), activation='relu', padding='same') (c5)

# EXPANSIVE and DECODING PART
# Each stage upsamples with a stride-2 transposed convolution, concatenates
# the result with the encoder output of the matching stage (c4, c3, c2, c1)
# and applies two 3x3 ReLU convolutions; the filter count halves per stage.
u6 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same') (c5)
u6 = concatenate([u6, c4])
c6 = Conv2D(64, (3, 3), activation='relu', padding='same') (u6)
c6 = Conv2D(64, (3, 3), activation='relu', padding='same') (c6)

u7 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same') (c6)
u7 = concatenate([u7, c3])
c7 = Conv2D(32, (3, 3), activation='relu', padding='same') (u7)
c7 = Conv2D(32, (3, 3), activation='relu', padding='same') (c7)

u8 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same') (c7)
u8 = concatenate([u8, c2])
c8 = Conv2D(16, (3, 3), activation='relu', padding='same') (u8)
c8 = Conv2D(16, (3, 3), activation='relu', padding='same') (c8)

u9 = Conv2DTranspose(8, (2, 2), strides=(2, 2), padding='same') (c8)
# axis=3 is the last axis of these 4-D tensors, i.e. the same axis the other
# concatenations use through the default
u9 = concatenate([u9, c1], axis=3)
c9 = Conv2D(8, (3, 3), activation='relu', padding='same') (u9)
c9 = Conv2D(8, (3, 3), activation='relu', padding='same') (c9)

# OUTPUT LAYER
# A 1x1 convolution with a sigmoid reduces the 8 feature maps to one value
# between 0 and 1 per pixel
outputs = Conv2D(1, (1, 1), activation='sigmoid') (c9)

model = Model(inputs=[inputs], outputs=[outputs])
# Per-pixel binary classification: binary cross-entropy loss, Adam optimizer,
# accuracy reported as the metric
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=["accuracy"])

In [None]:
# Print the layer-by-layer summary of the compiled model
model.summary()

### Saving the trained model

Specify the training hyperparameters and the callbacks used to stop training early and save the best model

In [None]:
# Stop training once the monitored quantity has not improved for 5 epochs.
# NOTE(review): no `monitor` is passed, so the Keras default applies
# (validation loss, per the Keras documentation) - confirm for your version.
earlystopper = EarlyStopping(patience=5, verbose=1)
# Write the model to 'TGS-salt-model-1.h5'; save_best_only=True keeps only
# the best model seen so far instead of overwriting it every epoch.
checkpointer = ModelCheckpoint('TGS-salt-model-1.h5', verbose=1, save_best_only=True)

In [None]:
# Settings passed to model.fit
epochs = 15               # maximum number of training epochs
batch_size = 8            # samples per batch
validation_size = 0.2     # fraction of the training data used for validation

Train the model

In [None]:
# Train the U-Net on the resized patches and masks. The fraction
# `validation_size` of the data is held out for validation, and the two
# callbacks stop training early and checkpoint the best model.
results = model.fit(X_train, Y_train, validation_split=validation_size, 
                    batch_size=batch_size, epochs=epochs, callbacks=[earlystopper, checkpointer])

### Prediction

Using the best saved model to make predictions on the train, validation and test data

In [None]:
# Predict on train, val and test
# Reload the checkpoint written during training (the best model, not the
# model state left after the last epoch).
model = load_model('TGS-salt-model-1.h5')

# Derive the train/validation boundary from `validation_size` instead of a
# hard-coded 0.8, so the two stay in sync if the split is changed. Keras'
# `validation_split` holds out the last samples of the arrays (see the
# Model.fit documentation), hence the leading rows are the training part.
split_at = int(X_train.shape[0] * (1 - validation_size))
preds_train = model.predict(X_train[:split_at], verbose=1)
preds_val = model.predict(X_train[split_at:], verbose=1)

preds_test = model.predict(X_test, verbose=1)    # test data

In [None]:
# Raw model outputs for the training part (sigmoid values, one per pixel)
preds_train

In [None]:
# Binarize the predictions: 1 where the model output exceeds 0.5, 0 elsewhere
preds_train_t = np.greater(preds_train, 0.5).astype(np.uint8)
preds_val_t = np.greater(preds_val, 0.5).astype(np.uint8)
preds_test_t = np.greater(preds_test, 0.5).astype(np.uint8)

In [None]:
# Thresholded (0/1) predictions for the training part
preds_train_t

In [None]:
# Shapes of the thresholded train, validation and test predictions
preds_train_t.shape, preds_val_t.shape, preds_test_t.shape

### Visualizations

Visualizing the predicted samples (train and test)

In [None]:
# Train samples predicted: input patch, true mask, then predicted mask

# random.randint includes its upper bound, so randint(0, len(...)) could pick
# an index one past the end; randrange excludes the bound. preds_train_t was
# predicted from the leading rows of X_train, so the same index addresses the
# matching input patch and mask.
ix = random.randrange(len(preds_train_t))
imshow(np.squeeze(X_train[ix]))
plt.show()
imshow(np.squeeze(Y_train[ix]))
plt.show()
# Stack the binary prediction into three identical channels for display.
tmp = np.squeeze(preds_train_t[ix]).astype(np.float32)
imshow(np.dstack((tmp,tmp,tmp)))
plt.show()

In [None]:
# Train samples predicted: input patch, true mask, then predicted mask

# random.randint includes its upper bound, so randint(0, len(...)) could pick
# an index one past the end; randrange excludes the bound. preds_train_t was
# predicted from the leading rows of X_train, so the same index addresses the
# matching input patch and mask.
ix = random.randrange(len(preds_train_t))
imshow(np.squeeze(X_train[ix]))
plt.show()
imshow(np.squeeze(Y_train[ix]))
plt.show()
# Stack the binary prediction into three identical channels for display.
tmp = np.squeeze(preds_train_t[ix]).astype(np.float32)
imshow(np.dstack((tmp,tmp,tmp)))
plt.show()

In [None]:
# Train samples predicted: input patch, true mask, then predicted mask

# random.randint includes its upper bound, so randint(0, len(...)) could pick
# an index one past the end; randrange excludes the bound. preds_train_t was
# predicted from the leading rows of X_train, so the same index addresses the
# matching input patch and mask.
ix = random.randrange(len(preds_train_t))
imshow(np.squeeze(X_train[ix]))
plt.show()
imshow(np.squeeze(Y_train[ix]))
plt.show()
# Stack the binary prediction into three identical channels for display.
tmp = np.squeeze(preds_train_t[ix]).astype(np.float32)
imshow(np.dstack((tmp,tmp,tmp)))
plt.show()

In [None]:
# Test samples predicted: input patch, then predicted mask

# random.randint includes its upper bound, so randint(0, len(...)) could pick
# an index one past the end; randrange excludes the bound.
ix = random.randrange(len(preds_test_t))
imshow(np.squeeze(X_test[ix]))
plt.show()
# Stack the binary prediction into three identical channels for display.
tmp = np.squeeze(preds_test_t[ix]).astype(np.float32)
imshow(np.dstack((tmp,tmp,tmp)))
plt.show()

In [None]:
# Test samples predicted: input patch, then predicted mask

# random.randint includes its upper bound, so randint(0, len(...)) could pick
# an index one past the end; randrange excludes the bound.
ix = random.randrange(len(preds_test_t))
imshow(np.squeeze(X_test[ix]))
plt.show()
# Stack the binary prediction into three identical channels for display.
tmp = np.squeeze(preds_test_t[ix]).astype(np.float32)
imshow(np.dstack((tmp,tmp,tmp)))
plt.show()

In [None]:
# Test samples predicted: input patch, then predicted mask

# random.randint includes its upper bound, so randint(0, len(...)) could pick
# an index one past the end; randrange excludes the bound.
ix = random.randrange(len(preds_test_t))
imshow(np.squeeze(X_test[ix]))
plt.show()
# Stack the binary prediction into three identical channels for display.
tmp = np.squeeze(preds_test_t[ix]).astype(np.float32)
imshow(np.dstack((tmp,tmp,tmp)))
plt.show()