In [0]:
# Installation of dependencies (Colab shell escapes; package versions are not pinned)
!pip install imageio
!pip install -U tensorboardcolab
# Remove any previous checkout so the clone below starts from a clean directory
!rm -rf ./unet
!git clone https://github.com/mariomeissner/unet-segmentation.git ./unet
# Move the repo contents into the working directory; presumably this is what makes
# `import model as unet_model` in the imports cell resolve.
# NOTE(review): on a re-run, mv may fail for sub-directories that already exist here — confirm
!mv ./unet/* .
# Optional, for graph plotting (only needed by the commented-out plot_model call)
#!apt-get install graphviz
#!pip install graphviz pydot


In [0]:
# Imports 
import os
import random

import numpy as np
import skimage.transform as trans
import matplotlib.pyplot as plt
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras.metrics import categorical_accuracy
from keras.preprocessing.image import ImageDataGenerator
from tensorboardcolab import TensorBoardColab, TensorBoardColabCallback
from keras import backend
from keras import backend as K  # explicit: the loss/metric cells use K.*, which must not rely on star imports
from keras.utils import plot_model
from imageio import imread, imwrite
from skimage import transform

import model as unet_model  # local module, presumably from the cloned unet-segmentation repo

In [0]:
# Where the dataset lives: Google Drive (Colab) or a local folder.
USE_DRIVE = True
if USE_DRIVE:
    from google.colab import drive
    folder = '/content/gdrive/My Drive/Projects/datasets/steven2358-larynx_data/'
    # Mount Drive so `folder` becomes readable from the Colab VM
    drive.mount('/content/gdrive')
else:
    # Local copy of the dataset; keep the trailing slash, since sub-paths are appended directly
    folder = 'your/local/dataset/folder/here'

# Hyperparameters
num_images = 154  # total number of image/label pairs expected in the dataset

In [0]:
# Function for plotting model history
def plot_history(history, metric):
  """Plot train vs. validation curves for `metric`, then for the loss.

  Args:
    history: Keras History object (e.g. returned by `fit_generator`).
    metric: key recorded in `history.history`; its 'val_' twin must exist too.
  """
  # One figure per key: the requested metric first, then the loss.
  for key, title in ((metric, 'model ' + metric), ('loss', 'model loss')):
    plt.plot(history.history[key])
    plt.plot(history.history['val_' + key])
    plt.title(title)
    plt.ylabel(key)
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
  
  
def plot_history_comparison(histories, labels, metric='categorical_accuracy'):
  """Overlay `metric` and loss curves of several training runs.

  Draws one figure for `metric` and one for the loss. Each run keeps the same
  colour in both figures. Training curves are solid and validation curves are
  dashed, so the two are distinguishable in the legend.

  Args:
    histories: iterable of Keras History objects.
    labels: display names, one per history (paired with `histories` by zip).
    metric: metric key recorded in every history; its 'val_' twin must exist.
      Defaults to 'categorical_accuracy', the metric the experiments compile with.
  """
  # Materialize once so both figures iterate the same pairs (even for generators).
  runs = list(zip(histories, labels))
  for key in (metric, 'loss'):
    for idx, (history, label) in enumerate(runs):
      # Matplotlib's 'CN' colour cycle is deterministic and always visible.
      # Random hex colours could be near-white and changed between the two figures.
      color = 'C{}'.format(idx % 10)
      plt.plot(history.history[key], color=color, linestyle='-',
               label='{}: {}'.format(label, key))
      plt.plot(history.history['val_' + key], color=color, linestyle='--',
               label='{}: val_{}'.format(label, key))
    plt.title('model ' + key)
    plt.ylabel(key)
    plt.xlabel('epoch')
    plt.legend()
    plt.show()

In [0]:
def load_rgb_stack(directory, expected_count):
  """Read every file in `directory` (sorted by name) as RGB and stack them.

  Args:
    directory: folder path ending in '/' (filenames are appended directly).
    expected_count: number of files the folder must contain.
  Returns:
    float32 array of shape (expected_count, H, W, 3), scaled from [0, 255] to [0, 1].
  Raises:
    ValueError: if the folder does not hold exactly `expected_count` files
      (rather than silently keeping all-zero samples or overflowing a
      preallocated array), or if the images differ in size.
  """
  filenames = sorted(os.listdir(directory))
  if len(filenames) != expected_count:
    raise ValueError('expected {} files in {}, found {}'.format(
        expected_count, directory, len(filenames)))
  # pilmode='RGB' forces 3 channels even for grayscale/palette files.
  stack = np.stack([imread(directory + name, pilmode='RGB') for name in filenames])
  return stack.astype(np.float32) / 255


# Images and labels are paired purely by sorted filename order.
# NOTE(review): assumes both folders use matching filenames — confirm.
images = load_rgb_stack(folder + 'images_cropped/images/', num_images)
print("loaded images")
print(images.shape)

labels = load_rgb_stack(folder + 'labels_cropped/labels/', num_images)
print("loaded labels")
print(labels.shape)

In [0]:
# Have a peek at what they look like
# print(images[3][100][120:140])
# print(labels[3][100][120:140])

In [0]:
# Shuffle images and labels with ONE shared permutation, then carve out the
# test split (first slice), validation split (last slice) and training split (middle).
# NOTE: this cell replaces `images`/`labels` with the training split, so it is
# not idempotent. The length assertion stops a second run from silently
# re-splitting the already-reduced training set; re-run the loading cell first.
SPLIT_SEED = 42  # fixed seed so the split is reproducible across runs
num_test_images = 15
num_val_images = 15
num_train_images = num_images - (num_test_images + num_val_images)
assert len(images) == len(labels) == num_images, \
    "images/labels must hold all {} samples; re-run the loading cell".format(num_images)
assert num_train_images > 0, "test + validation splits leave no training images"
order = np.random.RandomState(SPLIT_SEED).permutation(num_images)
# Explicit end indices (not negative slicing), so num_val_images = 0 still works.
test_idx = order[:num_test_images]
val_idx = order[num_images - num_val_images:]
train_idx = order[num_test_images:num_images - num_val_images]
images_test, labels_test = images[test_idx], labels[test_idx]
images_val, labels_val = images[val_idx], labels[val_idx]
images, labels = images[train_idx], labels[train_idx]

In [0]:
# Data augmentation
# The image and label generators share identical settings. The intent is that,
# when both are driven with the same seed, each label gets the same geometric
# transform as its image.
# NOTE(review): rotation/shift/zoom may interpolate label pixels, giving values
# that are not strictly one-hot near class boundaries — confirm this is acceptable.
data_gen_args = {
    'rotation_range': 5,
    'width_shift_range': 5,
    'height_shift_range': 5,
    'horizontal_flip': True,
    'zoom_range': 0.05,
    'data_format': 'channels_last',
}

image_datagen, label_datagen = (ImageDataGenerator(**data_gen_args) for _ in range(2))


In [0]:
batch_size = 3
# Training: augmented. The image and label iterators share seed=1 so they stay
# aligned batch-for-batch (checked visually in the preview cell).
image_gen = image_datagen.flow(images, seed=1, batch_size=batch_size, shuffle=True)
label_gen = label_datagen.flow(labels, seed=1, batch_size=batch_size, shuffle=True)
# Validation: NO augmentation. Validation metrics (and the checkpoint that
# monitors them) should measure the model on real images, not random distortions.
val_datagen = ImageDataGenerator(data_format='channels_last')
image_val = val_datagen.flow(images_val, seed=1, batch_size=batch_size, shuffle=True)
label_val = val_datagen.flow(labels_val, seed=1, batch_size=batch_size, shuffle=True)
train_gen = zip(image_gen, label_gen)
val_gen = zip(image_val, label_val)
# Ceil division includes the final partial batch, so every image is seen each epoch.
num_train_steps = -(-num_train_images // batch_size)
num_val_steps = -(-num_val_images // batch_size)

In [0]:
# Check that original images and labels match up: show the image, then its label
position = 55
for picture in (images[position], labels[position]):
  plt.imshow(picture)
  plt.show()

In [0]:
# Check that augmented images and labels match up
def preview_generator_pair(generator, pixel=(100, 100)):
  """Draw one batch from an (images, labels) generator and show its first pair.

  Prints the dtype and one pixel value of the image, then of the label, and
  plots both, so a misaligned augmentation is easy to spot by eye.
  Note: this consumes one batch from `generator`.

  Returns:
    The (image, label) pair that was shown.
  """
  image_batch, label_batch = next(generator)
  first_image, first_label = image_batch[0], label_batch[0]
  row, col = pixel
  for array in (first_image, first_label):
    print(array.dtype)
    print(array[row][col])
  for array in (first_image, first_label):
    plt.imshow(array)
    plt.show()
  return first_image, first_label


image_test, label_test = preview_generator_pair(train_gen)
image_test, label_test = preview_generator_pair(val_gen)

In [0]:
def flatten_image(predicted):
  """Convert per-pixel class scores into a one-hot map of the winning class.

  Args:
    predicted: array-like of shape (..., C), e.g. one (H, W, C) softmax output.
  Returns:
    float64 array of the same shape, with 1. at the argmax channel of every
    pixel and 0. elsewhere. Ties go to the lowest channel index (np.argmax).
  """
  predicted = np.asarray(predicted)
  # Vectorized: index the C x C identity matrix with the per-pixel argmax,
  # instead of looping over every pixel in Python.
  winners = np.argmax(predicted, axis=-1)
  return np.eye(predicted.shape[-1])[winners]

In [0]:
def show_prediction(image_batch, label_batch, model, i=0):
  """Plot sample `i` of a batch: input image, true label, then the model's prediction.

  The prediction is shown as a one-hot (argmax) map via `flatten_image`.

  Args:
    image_batch: batch of input images (indexable along axis 0).
    label_batch: matching batch of labels.
    model: Keras model with a `predict` method.
    i: index of the sample to show (negative indices work too).
  """
  image, label = image_batch[i], label_batch[i]
  # Predict only the requested sample (re-add the batch axis), instead of
  # running the whole batch and discarding all but one result.
  predicted = model.predict(image[np.newaxis, ...])[0]

  for title, picture in (('input image', image),
                         ('ground truth', label),
                         ('prediction (argmax)', flatten_image(predicted))):
    plt.imshow(picture)
    plt.title(title)
    plt.show()
  plt.show()

In [0]:
def weighted_categorical_crossentropy(weights):
    """
    Build a class-weighted categorical cross-entropy loss for Keras.

    Adapted from wassname's gist:
    @url: https://gist.github.com/wassname/ce364fddfc8a025bfab4348cf5de852d
    @author: wassname

    Args:
        weights: per-class weights, array-like of shape (C,) where C is the
            number of classes (the last axis of y_true / y_pred).

    Returns:
        A `loss(y_true, y_pred)` function. It renormalizes predictions to sum
        to 1 over the last axis, clips them away from 0 and 1, and sums the
        weighted negative log-likelihood over that axis.

    Usage:
        weights = np.array([0.5,2,10]) # Class one at 0.5, class 2 twice the normal weights, class 3 10x.
        loss = weighted_categorical_crossentropy(weights)
        model.compile(loss=loss,optimizer='adam')
    """
    class_weights = K.variable(weights)

    def loss(y_true, y_pred):
        eps = K.epsilon()
        # Make each sample's class probabilities sum to 1, then clip so that
        # log() never sees 0 or 1 (guards against NaN/Inf).
        probs = K.clip(y_pred / K.sum(y_pred, axis=-1, keepdims=True), eps, 1 - eps)
        weighted_log_likelihood = y_true * K.log(probs) * class_weights
        return -K.sum(weighted_log_likelihood, -1)

    return loss

In [0]:
def iou(y_true, y_pred, label: int):
    """
    Intersection over Union (IoU) of one class, pooled over the whole batch.
    @url: https://gist.github.com/Kautenja/69d306c587ccdf464c45d28c1545e580
    Args:
        y_true: ground truth as one-hot (last axis = classes)
        y_pred: prediction as one-hot or softmax scores (last axis = classes)
        label: index of the class to score
    Returns:
        scalar tensor |true AND pred| / |true OR pred| for `label`, or 1.0 when
        the class is absent from both (empty union)
    """
    def class_mask(tensor):
        # 1.0 wherever the arg-max class equals `label`, else 0.0
        return K.cast(K.equal(K.argmax(tensor), label), K.floatx())

    truth_mask = class_mask(y_true)
    pred_mask = class_mask(y_pred)
    overlap = K.sum(truth_mask * pred_mask)
    # inclusion-exclusion: |A OR B| = |A| + |B| - |A AND B|
    union = K.sum(truth_mask) + K.sum(pred_mask) - overlap
    return K.switch(K.equal(union, 0), 1.0, overlap / union)
  
def mean_iou(y_true, y_pred):
    """
    Return the mean Intersection over Union (IoU) across all classes.
    Args:
        y_true: the expected y values as a one-hot
        y_pred: the predicted y values as a one-hot or softmax output
    Returns:
        the scalar IoU value (mean over all labels). Every class weighs
        equally, and per `iou` a class absent from both inputs scores 1.0.
    Raises:
        ValueError: if the class axis (last dimension) of y_pred is not
            statically known.
    """
    # The number of classes must be known when the metric graph is built.
    num_labels = K.int_shape(y_pred)[-1]
    if not num_labels:
        raise ValueError("mean_iou needs a statically known class dimension")
    # Plain Python sum of tensors. The previous K.variable(0) accumulator created
    # a new backend variable on every metric build, which is needless state and
    # can fail when the metric is traced inside a tf.function.
    total_iou = sum(iou(y_true, y_pred, label) for label in range(num_labels))
    return total_iou / num_labels

In [0]:
# Experimental custom loss function for unbalanced segmentation
def dice_coef_loss(y_true, y_pred, smooth=1.0):
    """Soft Dice loss: 1 - (2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth).

    Both tensors are flattened first, so the overlap is pooled over every
    element (pixels, classes and samples alike). `smooth` keeps the ratio
    defined, and the loss at 0, when both inputs are all zeros.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    numerator = 2. * K.sum(truth * guess) + smooth
    denominator = K.sum(truth) + K.sum(guess) + smooth
    return 1 - numerator / denominator
  
# Let's sanity-check the Dice loss on hand-made inputs.
# Use the first training label as the reference.
label = labels[0]

# given the label as prediction, loss should be 0
print(f"Should be almost 0: {K.eval(dice_coef_loss(label, label))}")

# A few hand-picked pixel changes should give a small loss
pred = np.array(label)  # copy, so `label` itself is untouched
pred[100,100] = [0.5, 0.5, 0.]
pred[105,105] = [0.5, 0., 0.5]
pred[110,110] = [0.4, 0.5, 0.1]
pred[115,115] = [0.5, 0.5, 0.]
pred[120,120] = [0.5, 0.5, 0.]
print(f"Should be close to 0: {K.eval(dice_coef_loss(label, pred))}")

# taking a different sample's label as prediction should give a noticeably higher loss
label2 = labels[44]
print(f"Should be high: {K.eval(dice_coef_loss(label, label2))}")

# Taking a real prediction from a trained model. `model` is created in a later
# cell, so on a fresh top-to-bottom run this step is skipped instead of raising
# NameError. Re-run this cell after training to see it.
if 'model' in globals():
  pred = model.predict(images[:1])[0]
  print(K.eval(dice_coef_loss(label, pred)))
else:
  print("Skipping trained-model check: `model` is not defined yet")

In [0]:
# Input shape (H, W, channels) is taken from the loaded data, so it cannot
# drift from the image crop size.
model = unet_model.unet(input_size=tuple(images.shape[1:]), activation='relu')
#plot_model(model, to_file="model.png")

In [0]:
# Save the model whenever validation mean IoU improves.
# mode='max' is required. With the default mode='auto', Keras only infers
# "higher is better" for monitored names containing 'acc', so 'val_mean_iou'
# would be minimised and the checkpoint would keep the WORST epoch.
checkpoint = ModelCheckpoint(folder + 'unet_checkpoint.hdf5',
                             monitor='val_mean_iou',
                             mode='max',
                             verbose=1,
                             save_best_only=True)

In [0]:
# Compile using experimental custom loss 
#model.compile(optimizer = Adam(lr = 2e-4), loss = dice_coef_loss, metrics = ['accuracy'])

In [0]:
# Compile using crossentropy
# (the training cell below recompiles the model with its own loss and metrics)
model.compile(
    optimizer=Adam(lr=2e-4),
    loss="categorical_crossentropy",
    metrics=['categorical_accuracy'],
)

In [0]:
# Load model weights
# model.load_weights(folder + 'best_unet.hdf5')

In [0]:
# Train the model
# NOTE(review): the class weights [2, 2, 2] are uniform, so this loss is plain
# categorical cross-entropy scaled by 2. Tune the weights per class to counter imbalance.
model.compile(
    optimizer=Adam(lr=1e-4),
    loss=weighted_categorical_crossentropy([2, 2, 2]),
    metrics=[categorical_accuracy, mean_iou],
)

history = model.fit_generator(
    train_gen,
    steps_per_epoch=num_train_steps,
    epochs=50,
    validation_data=val_gen,
    validation_steps=num_val_steps,
    callbacks=[checkpoint],
)

In [0]:
# Plot history
plot_history(history, 'mean_iou')
# Show prediction over some val images (show_prediction picks sample 0 itself)
image_batch, label_batch = next(val_gen)
show_prediction(image_batch, label_batch, model)

In [0]:
# model.save(folder + 'best_unet.hdf5')

In [0]:
# Evaluate on the held-out test split; prints the metric names, then their values
results = model.evaluate(x=images_test, y=labels_test)
for line in (model.metrics_names, results):
  print(line)

In [0]:
# Have a look at test image performance
i = 9  # index into the test split; must be < num_test_images
# show_prediction runs the model itself, so no separate full-batch predict is needed.
# It shows the real image and label, then the predicted one.
show_prediction(images_test, labels_test, model, i=i)

In [0]:
# # Test different learning rates
# histories = []
# labels = []
# models = []
# for lr in (1e-4, 5e-4, 1e-3):
#   model = unet_model.unet(input_size=(160,240,3))
#   model.compile(optimizer = Adam(lr = lr), loss = "categorical_crossentropy", metrics = ['categorical_accuracy'])
#   history = model.fit_generator(train_gen,
#                             validation_data = val_gen,
#                             validation_steps = num_val_steps,
#                             steps_per_epoch = num_train_steps,
#                             epochs = 50
#                             )
#   histories.append(history)
#   labels.append(str(lr))
#   models.append(model)

In [0]:
# plot_history_comparison(histories, labels)