In [None]:
# Colab-only line magic; it is not valid outside IPython/Colab.
%tensorflow_version 2  # This tells Colab to use TF2

Segment wave-breaking pixels with UNet-like convolutional networks.

This notebook loads manually labelled wave images and classifies
each pixel of every image as "breaking" (1) or "no-breaking" (0).

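The data must be organized as shown below, with one such tree for each of
the `train`, `valid`, and `test` splits. Images and masks are paired purely by
their sorted position in the two folders, so each mask should have the same
filename as its image: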

```
└───train or test or valid
    ├───images
        ├───data
               ├───img1.png
               ├───img2.png
               ...
    ├───masks
        ├───data
               ├───img1.png
               ├───img2.png
               ...
```

The networks are modified U-Nets adapted from:
https://keras.io/examples/vision/oxford_pets_image_segmentation/
https://www.tensorflow.org/tutorials/images/segmentation


In [None]:
# Installs `tensorflow_examples`, which provides the pix2pix upsampling blocks
# imported below and used by the MobileNet decoder.
# NOTE(review): this installs from git HEAD without a pinned commit, so results
# may not be reproducible; consider pinning a specific revision.
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import argparse
import datetime
import math
import os
import platform

import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks
from tensorflow.keras.models import load_model
from tensorflow_examples.models.pix2pix import pix2pix
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import matplotlib.pyplot as plt

# Data

In [None]:
# Mount Google Drive (Colab only) so the dataset archive can be copied from it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the dataset archive from the mounted Google Drive into the working directory.
!cp "/content/drive/My Drive/Colab Notebooks/FEM/data/wave_breaking_detection/segmentation_v1.tar.gz" .

In [None]:
# Extract the archive. Presumably this creates ./segmentation_v1, the `data`
# directory read by the generators below (verify against the archive contents).
!tar -zxf segmentation_v1.tar.gz

## Xception

In [None]:
def xception(img_size, num_classes):
    """Build an Xception-style U-Net that assigns a class to every pixel.

    Encoder: a strided 3x3 entry convolution, followed by three residual blocks
    of depthwise-separable convolutions. Each block ends in a strided max-pool,
    giving x16 downsampling overall.

    Decoder: four residual blocks of transposed convolutions. Each block ends
    in 2x upsampling, giving x16 overall. For inputs whose height and width
    are multiples of 16, the output therefore has the input's spatial size.

    Args:
        img_size: (height, width) tuple; ``img_size + (3,)`` is used as the
            RGB input shape.
        num_classes: Number of classes predicted for each pixel.

    Returns:
        An uncompiled ``keras.Model`` whose output is a per-pixel softmax over
        ``num_classes`` channels.
    """

    def downsampling_block(block_input, filters):
        # Two pre-activation separable convs, then a strided pool. The block
        # input is projected by a strided 1x1 conv and added back (residual).
        out = block_input
        for _ in range(2):
            out = layers.Activation("relu")(out)
            out = layers.SeparableConv2D(filters, 3, padding="same")(out)
            out = layers.BatchNormalization()(out)
        out = layers.MaxPooling2D(3, strides=2, padding="same")(out)
        shortcut = layers.Conv2D(filters, 1, strides=2, padding="same")(block_input)
        return layers.add([out, shortcut])

    def upsampling_block(block_input, filters):
        # Two pre-activation transposed convs, then 2x upsampling. The block
        # input is upsampled, projected by a 1x1 conv and added back.
        out = block_input
        for _ in range(2):
            out = layers.Activation("relu")(out)
            out = layers.Conv2DTranspose(filters, 3, padding="same")(out)
            out = layers.BatchNormalization()(out)
        out = layers.UpSampling2D(2)(out)
        shortcut = layers.UpSampling2D(2)(block_input)
        shortcut = layers.Conv2D(filters, 1, padding="same")(shortcut)
        return layers.add([out, shortcut])

    inputs = keras.Input(shape=img_size + (3,))

    # Entry block: the strided conv halves the resolution.
    features = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    features = layers.BatchNormalization()(features)
    features = layers.Activation("relu")(features)

    # Each block's residual shortcut is taken from that block's own input.
    for depth in (64, 128, 256):
        features = downsampling_block(features, depth)
    for depth in (256, 128, 64, 32):
        features = upsampling_block(features, depth)

    # Per-pixel classification head.
    outputs = layers.Conv2D(
        num_classes, 3, activation="softmax", padding="same")(features)
    return keras.Model(inputs, outputs)

## MobileNet

In [None]:
def mobilenet(img_size, num_classes, model_path="mobilenet.h5"):
    """Build a U-Net with a frozen MobileNetV2 encoder loaded from disk.

    Activations of five named encoder layers are used as skip connections.
    A pix2pix upsampling stack then restores the input resolution.

    Args:
        img_size: (height, width) of the RGB input.
            NOTE(review): it must match the input size of the saved base
            model at ``model_path``; verify against that file.
        num_classes: Number of classes predicted for each pixel.
        model_path: Path to a saved base model readable by ``load_model``. It
            must contain the layer names listed below. It was presumably
            created with the commented-out ``MobileNetV2(include_top=False)``
            call.

    Returns:
        An uncompiled ``tf.keras.Model`` whose output is a per-pixel softmax
        over ``num_classes`` channels. Probabilities (not logits) are what
        ``sparse_categorical_crossentropy`` expects with its default
        ``from_logits=False``.
    """
    # base_model = tf.keras.applications.MobileNetV2(input_shape=[256, 256, 3],
    #                                                include_top=False)
    base_model = load_model(model_path)

    # Encoder activations used as skip connections. Spatial sizes are relative
    # to the input height/width H.
    layer_names = [
        'block_1_expand_relu',   # H/2
        'block_3_expand_relu',   # H/4
        'block_6_expand_relu',   # H/8
        'block_13_expand_relu',  # H/16
        'block_16_project']      # H/32

    # Named `skip_outputs` so it does not shadow the keras `layers` module.
    skip_outputs = [base_model.get_layer(name).output for name in layer_names]

    # Feature-extraction model; its pretrained weights are kept frozen.
    down_stack = tf.keras.Model(inputs=base_model.input, outputs=skip_outputs)
    down_stack.trainable = False

    up_stack = [pix2pix.upsample(512, 3),  # H/32 -> H/16
                pix2pix.upsample(256, 3),  # H/16 -> H/8
                pix2pix.upsample(128, 3),  # H/8 -> H/4
                pix2pix.upsample(64, 3)]   # H/4 -> H/2

    inputs = tf.keras.layers.Input(shape=tuple(img_size) + (3,))

    # Downsample; the deepest activation starts the decoder, and the
    # remaining ones are consumed deepest-first as skips.
    skips = down_stack(inputs)
    x = skips[-1]
    for up, skip in zip(up_stack, reversed(skips[:-1])):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final upsampling (H/2 -> H) with a softmax, so the model outputs class
    # probabilities.
    last = tf.keras.layers.Conv2DTranspose(num_classes, 3, strides=2,
                                           padding='same',
                                           activation='softmax')
    outputs = last(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
def display_mask(val_preds):
    """Convert per-pixel class scores into a label mask (nothing is displayed).

    Args:
        val_preds: Array-like of scores whose last axis indexes the classes,
            e.g. the output of ``model.predict``.

    Returns:
        ``np.ndarray`` holding the argmax class index of each pixel. A trailing
        axis of length 1 is kept, so the result has the input's shape with the
        class axis reduced to size 1.
    """
    return np.expand_dims(np.argmax(val_preds, axis=-1), axis=-1)

In [None]:
def image_mask_generator(image_data_generator, mask_data_generator):
    """Pair image batches with integer-valued mask batches.

    Args:
        image_data_generator: Iterable of batches whose element 0 is the image
            array, e.g. a Keras ``DirectoryIterator`` yielding
            ``(images, labels)``.
        mask_data_generator: Iterable of batches aligned one-to-one with
            ``image_data_generator``. Element 0 is the mask array, expected to
            be (batch, H, W, channels) and rescaled to [0, 1].

    Yields:
        ``(images, masks)``. ``masks`` is channel 0 of each mask batch, with
        shape (batch, H, W), rounded to the nearest integer class.
    """
    for image_batch, mask_batch in zip(image_data_generator, mask_data_generator):
        images = image_batch[0]
        # Augmentation interpolates mask pixels, leaving fractional values
        # along edges. The loss casts labels to int, which truncates toward 0,
        # so snap each pixel to the nearest class instead.
        masks = np.round(mask_batch[0][:, :, :, 0])
        yield (images, masks)

# Parameters

In [None]:
# Dataset root; presumably produced by extracting the tarball above.
data = "segmentation_v1"
# Architecture: "xception" (trained from scratch) or "mobilenet", which
# requires `pre_trained` = path to a saved base model.
backbone = "xception"
pre_trained = None
model_name = "wave_xception"  # names the per-run log directory
img_size = (256, 256)  # (height, width) that images and masks are resized to
batch_size = 32
random_seed = 11
epochs = 16
logdir = "logs"  # root directory for TensorBoard logs and checkpoints
learning_rate = 1e-5  # Adam learning rate
num_classes = 2  # 0 = no-breaking, 1 = breaking

# Seed the global RNGs so weight initialisation is repeatable across runs.
# The data iterators receive `random_seed` separately. Some GPU kernels may
# still be non-deterministic.
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

## Callbacks

In [None]:
# Each run gets its own timestamped directory under `logdir`. `logdir` itself
# is not reassigned, so re-running this cell does not nest run directories.
date = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_logdir = os.path.join(logdir, model_name, date)
os.makedirs(run_logdir, exist_ok=True)

tensorboard = callbacks.TensorBoard(log_dir=run_logdir,
                                    histogram_freq=1,
                                    profile_batch=1)

# Keep only the full model (architecture + weights) with the lowest val_loss.
checkpoint_path = os.path.join(run_logdir, "best.h5")
checkpoint = callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                       save_best_only=True,
                                       save_weights_only=False,
                                       monitor='val_loss',
                                       mode="min",
                                       verbose=1)

# Output directory for predictions of this run.
pred_out = os.path.join(run_logdir, "pred")
os.makedirs(pred_out, exist_ok=True)

# Data Augmentation

In [None]:
# Augmentation shared by the training image and mask generators. Both must
# use IDENTICAL settings and the same seed. Seeded alike, they draw the same
# random numbers, but the resulting transforms only coincide if every
# parameter range (e.g. rotation_range) is the same too.
train_augmentation = dict(
    zoom_range=0.2,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=False,
    rescale=1. / 255.)


def _flow(subdir, batch, **generator_kwargs):
    """Return a seeded, resized directory iterator over ``<data>/<subdir>``."""
    return ImageDataGenerator(**generator_kwargs).flow_from_directory(
        data + "/" + subdir,
        batch_size=batch,
        target_size=img_size,
        seed=random_seed)


# train generators - built from the same settings so images and masks align
image_train_generator = _flow("train/images", batch_size, **train_augmentation)
mask_train_generator = _flow("train/masks", batch_size, **train_augmentation)

# valid/test generators - they need to be identical!
# note that no augmentation is done to this data, only rescaling and a resize
image_valid_generator = _flow("valid/images", batch_size, rescale=1. / 255.)
mask_valid_generator = _flow("valid/masks", batch_size, rescale=1. / 255.)
image_test_generator = _flow("test/images", 1, rescale=1. / 255.)
mask_test_generator = _flow("test/masks", 1, rescale=1. / 255.)

train_generator = image_mask_generator(image_train_generator, mask_train_generator)
valid_generator = image_mask_generator(image_valid_generator, mask_valid_generator)
test_generator = image_mask_generator(image_test_generator, mask_test_generator)
train_size = image_train_generator.n
valid_size = image_valid_generator.n

# Train

In [None]:
# Build the selected architecture.
if backbone.lower() == "xception":
    model = xception(img_size, num_classes)
elif backbone.lower() == "mobilenet":
    if not pre_trained:
        raise ValueError("Pre-trained model is required with {}.".format(backbone))
    model = mobilenet(img_size, num_classes, pre_trained)
else:
    raise NotImplementedError("Backbone {} is not implemented".format(backbone))
model.summary()

# Masks are per-pixel class indices, not one-hot, hence the "sparse" loss.
# With its default from_logits=False, this loss requires the model to output
# softmax probabilities.
optimizer = optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer,
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Use ceil so each epoch consumes exactly one pass over the data, including
# the final partial batch of the (endlessly cycling) directory iterators.
# Every validation run then scores the same full validation set, which keeps
# val_loss comparable across epochs for the best-model checkpoint. It also
# never yields 0 steps when a split is smaller than batch_size.
steps_per_epoch = math.ceil(train_size / batch_size)
validation_steps = math.ceil(valid_size / batch_size)

# Train, validating at the end of each epoch.
history = model.fit(train_generator,
                    epochs=epochs,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=valid_generator,
                    validation_steps=validation_steps,
                    callbacks=[tensorboard, checkpoint])
hist = pd.DataFrame(history.history)
hist["epoch"] = history.epoch

# Test

In [None]:
# Predict on a few test samples. Plot image / ground-truth mask / predicted
# mask side by side and save each figure to `pred_out`.
n_examples = 4
for i in range(n_examples):
    img, msk = next(test_generator)
    val_preds = model.predict(img)
    # per-pixel argmax class labels
    prd = display_mask(val_preds)
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 6),
                                        sharex=True, sharey=True)
    ax1.imshow(np.squeeze(img))
    ax1.set_title("image")
    # Fixed color limits so an all-0 and an all-1 mask do not look alike.
    ax2.imshow(np.squeeze(msk), vmin=0, vmax=1)
    ax2.set_title("mask")
    ax3.imshow(np.squeeze(prd), vmin=0, vmax=num_classes - 1)
    ax3.set_title("prediction")
    fig.tight_layout()
    fig.savefig(os.path.join(pred_out, "prediction_{:02d}.png".format(i)))
    plt.show()
    plt.close(fig)

# TensorBoard

In [None]:
# Launch TensorBoard inside the notebook on the root log directory "logs"
# (the default `logdir`). Each training run writes to
# logs/<model_name>/<timestamp>, so all runs appear side by side.
%load_ext tensorboard
%tensorboard --logdir logs