This U-Net implementation borrows heavily from the Keras documentation tutorial on image segmentation: https://keras.io/examples/vision/oxford_pets_image_segmentation/

**Load datasets**

In [None]:
import os
import matplotlib
import matplotlib.pyplot as plt

# Dataset root lives next to the notebook's working directory.
ROOT_DIR = os.getcwd()
BUOY_DIR = os.path.join(ROOT_DIR, "buoy_data")
# Training images and their machine-readable segmentation masks.
input_dir, target_dir = (
    os.path.join(BUOY_DIR, subdir) for subdir in ("Train/img", "Train/masks_machine")
)

**Prepare paths of input images and target segmentation masks**

In [None]:
img_size = (128, 128)  # (height, width) every image and mask is resized to

# NOTE(review): the UNet output layer has 3 softmax channels, so this value does not
# describe the model; confirm the intended class count.
num_classes = 1
batch_size = 32


def _list_files(directory, suffix):
    """Return sorted full paths of non-hidden files in ``directory`` ending in ``suffix``.

    Hidden files (e.g. macOS ``._foo.jpg`` resource forks) are skipped for BOTH images
    and masks, because the two lists are paired purely by sorted position.
    """
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(suffix) and not fname.startswith(".")
    )


input_img_paths = _list_files(input_dir, ".jpg")
target_img_paths = _list_files(target_dir, ".png")

print("Number of samples:", len(input_img_paths))
# Pairing is by sorted index, so a count mismatch silently misaligns every pair.
assert len(input_img_paths) == len(target_img_paths), (
    f"{len(input_img_paths)} images but {len(target_img_paths)} masks"
)

for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print(input_path, "|", target_path)

**Display an input image and a segmentation mask**

In [None]:
from IPython.display import Image, display
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import PIL
from PIL import ImageOps
import numpy as np


SAMPLE_IDX = 9  # 0-based index into the sorted path lists

# Display input image #SAMPLE_IDX
display(Image(filename=input_img_paths[SAMPLE_IDX]))

# Display an auto-contrast version of the corresponding target (per-pixel categories).
# The mask is loaded once and reused for both the preview and the value listing.
target_preview = load_img(target_img_paths[SAMPLE_IDX])
img = PIL.ImageOps.autocontrast(target_preview)
display(img)
# Distinct pixel values present in the mask as loaded here (default RGB color mode).
print(np.unique(target_preview))

**Prepare Sequence class to load & vectorize batches of data**

In [None]:
from tensorflow import keras
import numpy as np
from tensorflow.keras.preprocessing.image import load_img


class Buoys(keras.utils.Sequence):
    """Keras Sequence that yields (image, mask) batches as NumPy arrays.

    Args:
        batch_size: number of samples per batch.
        img_size: (height, width) tuple every image and mask is resized to.
        input_img_paths: paths of the input images.
        target_img_paths: paths of the mask images, index-aligned with
            ``input_img_paths``.

    Only full batches are produced: ``len()`` floors the division, so a trailing
    partial batch is never requested.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        # Number of full batches; any remainder samples are dropped.
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return the (x, y) tuple for batch ``idx``.

        x: float32 array of shape (batch_size, H, W, 3) holding unscaled pixel values.
        y: uint8 array of shape (batch_size, H, W, 1) holding grayscale mask values.
        """
        start = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[start : start + self.batch_size]
        batch_target_img_paths = self.target_img_paths[start : start + self.batch_size]

        # Size the buffers from the instance's batch size, not the notebook global,
        # so a Sequence built with a different batch_size still works.
        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            x[j] = load_img(path, target_size=self.img_size)

        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="uint8")
        for j, path in enumerate(batch_target_img_paths):
            mask = load_img(path, target_size=self.img_size, color_mode="grayscale")
            y[j] = np.expand_dims(mask, 2)  # add the trailing channel axis
        return x, y

**U-Net blocks**

In [None]:
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Encoder stage: two stacked ReLU convolutions followed by 2x2 max-pooling.

    Returns:
        (features, pooled): ``features`` is the pre-pooling activation, used as
        the skip connection; ``pooled`` is the 2x2/stride-2 pooled tensor fed to
        the next stage.
    """
    features = x
    for _ in range(2):
        features = keras.layers.Conv2D(
            filters, kernel_size, padding=padding, strides=strides, activation="relu"
        )(features)
    pooled = keras.layers.MaxPool2D((2, 2), (2, 2))(features)
    return features, pooled

def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Decoder stage: 2x upsample, concatenate with the skip tensor, then two ReLU convs."""
    upsampled = keras.layers.UpSampling2D((2, 2))(x)
    merged = keras.layers.Concatenate()([upsampled, skip])
    out = merged
    for _ in range(2):
        out = keras.layers.Conv2D(
            filters, kernel_size, padding=padding, strides=strides, activation="relu"
        )(out)
    return out

def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Deepest stage between encoder and decoder: two stacked ReLU convs, no pooling."""
    out = x
    for _ in range(2):
        out = keras.layers.Conv2D(
            filters, kernel_size, padding=padding, strides=strides, activation="relu"
        )(out)
    return out

**U-Net architecture**

In [None]:
def UNet(n_classes=3, input_size=None, filters=(16, 32, 64, 128, 256)):
    """Build a U-Net that outputs a per-pixel softmax over ``n_classes`` classes.

    Args:
        n_classes: number of output channels (classes). With
            ``sparse_categorical_crossentropy`` the integer mask labels must lie
            in ``[0, n_classes)``. Defaults to 3, the original hard-coded value.
        input_size: (height, width) of the RGB input; defaults to the notebook's
            ``img_size``.
        filters: conv widths per level. All entries but the last are encoder/decoder
            levels (one 2x pooling each); the last is the bottleneck width.

    Raises:
        ValueError: if the input height/width is not divisible by
            2 ** (len(filters) - 1). Otherwise pooling floors odd sizes and the
            skip concatenations no longer line up.
    """
    if input_size is None:
        input_size = img_size
    factor = 2 ** (len(filters) - 1)
    if input_size[0] % factor or input_size[1] % factor:
        raise ValueError(f"input_size {tuple(input_size)} must be divisible by {factor}")

    inputs = keras.layers.Input((input_size[0], input_size[1], 3))

    # Encoder: each level halves the spatial size (e.g. 128 -> 64 -> 32 -> 16 -> 8).
    skips = []
    x = inputs
    for level_filters in filters[:-1]:
        skip, x = down_block(x, level_filters)
        skips.append(skip)

    x = bottleneck(x, filters[-1])

    # Decoder: mirror the encoder, doubling the spatial size and merging each skip.
    for level_filters, skip in zip(reversed(filters[:-1]), reversed(skips)):
        x = up_block(x, skip, level_filters)

    outputs = keras.layers.Conv2D(n_classes, (1, 1), padding="same", activation="softmax")(x)
    return keras.models.Model(inputs, outputs)

In [None]:
model = UNet()
# Layer-by-layer overview: output shapes and parameter counts.
model.summary()

**Set aside a validation split**

In [None]:
import random
# Split the image/mask pairs into a training and a validation set.
# Shuffling (image, mask) pairs keeps every image with its own mask. Fresh lists
# are built instead of shuffling in place, so re-running this cell gives the same split.
assert len(input_img_paths) == len(target_img_paths), "every image needs exactly one mask"
assert input_img_paths, "no samples found"

VAL_SAMPLES = 1000  # requested validation-set size
SPLIT_SEED = 1337

pairs = list(zip(input_img_paths, target_img_paths))
random.Random(SPLIT_SEED).shuffle(pairs)

# With VAL_SAMPLES >= len(pairs), slicing [:-VAL_SAMPLES] would leave no training data;
# fall back to holding out ~20% (at least one sample) in that case.
val_samples = VAL_SAMPLES if VAL_SAMPLES < len(pairs) else max(1, len(pairs) // 5)
train_pairs, val_pairs = pairs[:-val_samples], pairs[-val_samples:]

train_input_img_paths = [img for img, _ in train_pairs]
train_target_img_paths = [mask for _, mask in train_pairs]
val_input_img_paths = [img for img, _ in val_pairs]
val_target_img_paths = [mask for _, mask in val_pairs]

# Instantiate data Sequences for each split.
# NOTE: a split with fewer than batch_size samples yields zero batches.
train_gen = Buoys(batch_size, img_size, train_input_img_paths, train_target_img_paths)
val_gen = Buoys(batch_size, img_size, val_input_img_paths, val_target_img_paths)

**Train the model**

In [None]:
# Per-pixel targets are integer class labels, hence the "sparse" cross-entropy.
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")

# Write a checkpoint only when the monitored metric (val_loss by default) improves.
checkpoint_cb = keras.callbacks.ModelCheckpoint("buoy_detection.h5", save_best_only=True)
callbacks = [checkpoint_cb]

# Train, validating at the end of every epoch.
epochs = 1
history = model.fit(
    train_gen,
    epochs=epochs,
    validation_data=val_gen,
    callbacks=callbacks,
)

**Optionally load the model** if it has already been trained, to avoid spending time and resources retraining it on every run.

In [None]:
# Reload the saved checkpoint when present. ModelCheckpoint may not have written one
# (e.g. training was skipped), so fall back to the in-memory model instead of crashing.
CHECKPOINT_PATH = "buoy_detection.h5"
if os.path.exists(CHECKPOINT_PATH):
    model = keras.models.load_model(CHECKPOINT_PATH)
else:
    print(f"No checkpoint found at {CHECKPOINT_PATH!r}; using the in-memory model.")

**Visualize predictions**

In [None]:
# Generate predictions for all images in the validation set

# Rebuild the validation Sequence so this cell does not depend on earlier state.
# __len__ floors, so only full batches are predicted (a trailing remainder is skipped).
val_gen = Buoys(
    batch_size,
    img_size,
    val_input_img_paths,
    val_target_img_paths,
)
val_preds = model.predict(val_gen)

def display_mask(i):
    """Display the predicted class map for validation sample ``i``, auto-contrasted."""
    # Most likely class per pixel, with a trailing channel axis for array_to_img.
    class_map = val_preds[i].argmax(axis=-1)[..., np.newaxis]
    as_image = keras.preprocessing.image.array_to_img(class_map)
    display(PIL.ImageOps.autocontrast(as_image))


In [None]:
# Preview up to the first ten predicted masks; cap at the number of predictions
# so a small validation set does not raise IndexError.
for sample_idx in range(min(10, len(val_preds))):
    display_mask(sample_idx)

**Plot training and validation loss**

In [None]:
# Loss curves per epoch for the training and validation splits.
fig, ax = plt.subplots()
ax.plot(history.history["loss"], label="train")
ax.plot(history.history["val_loss"], label="val")
ax.set(title="UNet Model Loss", xlabel="Epoch", ylabel="Loss")
ax.legend(loc="upper left")
plt.show()

**Compute the Dice coefficient on the validation set**

In [None]:
from keras import backend as K
from PIL import Image, ImageOps
import tensorflow as tf

def dice_coef(y_true, y_pred, smooth=1.0):
    """Sørensen–Dice coefficient between two masks with the same number of elements.

    Computes (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth) on the flattened inputs.
    ``smooth`` keeps the ratio defined (= 1.0) when both masks are empty.

    Raises:
        ValueError: if the flattened inputs have different lengths.
    """
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    if y_true_f.shape != y_pred_f.shape:
        raise ValueError(f"shape mismatch: {y_true_f.shape} vs {y_pred_f.shape}")
    intersection = np.sum(y_true_f * y_pred_f)
    return float((2.0 * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth))


def _foreground(labels):
    """Binary foreground mask as float64.

    NOTE(review): assumes label 0 is background and any non-zero label is
    foreground (buoy); confirm against the mask encoding.
    """
    return (np.asarray(labels) > 0).astype(np.float64)


dice_scores = []
# val_preds is produced in val_target_img_paths order, so zip keeps them aligned
# (and stops at the last predicted sample).
for pred, target_path in zip(val_preds, val_target_img_paths):
    # Load the mask the same way the Buoys generator does so its shape matches pred.
    truth = np.asarray(load_img(target_path, target_size=img_size, color_mode="grayscale"))
    pred_labels = np.argmax(pred, axis=-1)  # softmax probabilities -> class labels
    dice_scores.append(dice_coef(_foreground(truth), _foreground(pred_labels)))

mean_dice = float(np.mean(dice_scores)) if dice_scores else float("nan")
print("Mean Dice:", mean_dice)
print("Dice loss (1 - Dice):", 1 - mean_dice)


**Compute mean IoU on the validation set**

In [None]:
import tensorflow as tf
from keras import backend as K
# One IoU class per model output channel, so labels and predictions share a label space.
# (MeanIoU with num_classes=1 is degenerate.)
n_out_classes = val_preds.shape[-1]
m = tf.keras.metrics.MeanIoU(num_classes=n_out_classes)

# val_preds is produced in val_target_img_paths order; zip keeps them aligned.
for pred, target_path in zip(val_preds, val_target_img_paths):
    # load_img resizes with nearest-neighbour interpolation by default, so label values
    # stay exact. NOTE(review): assumes mask values lie in [0, n_out_classes); confirm.
    truth = np.asarray(load_img(target_path, target_size=img_size, color_mode="grayscale"))
    pred_labels = np.argmax(pred, axis=-1)  # MeanIoU expects class labels, not probabilities
    m.update_state(truth, pred_labels)

print("Mean IoU:", float(m.result()))



**Show masks**

In [None]:
def display_masks(im, mask=None):
    """Show ``im`` and, if given, overlay ``mask`` semi-transparently on top of it.

    Args:
        im: image to display (any array-like or PIL image accepted by ``imshow``).
        mask: optional mask drawn at 50% opacity over the image. It is presumably
            the same height/width as ``im``; confirm for your data.
    """
    fig, ax = plt.subplots()
    ax.imshow(im, cmap="gray", interpolation="none")
    if mask is not None:
        ax.imshow(mask, alpha=0.5, interpolation="none")
    ax.axis("off")
    plt.show()


# The original call passed only the image, which raised TypeError against the
# two-argument signature; show the first sample together with its paired mask.
display_masks(
    load_img(input_img_paths[0]),
    load_img(target_img_paths[0], color_mode="grayscale"),
)