In [0]:
# Environment setup for a fresh runtime: install imageio and fetch the U-Net code.
# The graphviz/pydot installs are left disabled; presumably they are only needed
# for the commented-out plot_model call further down.
#!apt-get install graphviz
#!pip install graphviz pydot
!pip install imageio
# Clone into ./unet so that `import unet.model` in the imports cell can resolve.
!git clone https://github.com/mariomeissner/unet-segmentation.git ./unet

In [0]:
# Imports 
import numpy as np 
import os
import skimage.transform as trans
import matplotlib.pyplot as plt
import unet.model as unet_model
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras.preprocessing.image import ImageDataGenerator
from keras import backend
from keras.utils import plot_model
from imageio import imread, imwrite
from skimage import transform

In [0]:
# Where the dataset lives: a mounted Google Drive (Colab) or a local directory.
USE_DRIVE = True
if not USE_DRIVE:
    # Local run: point this at the directory that holds the dataset.
    folder = 'your/local/dataset/folder/here'
else:
    # Imported lazily because google.colab is only needed on the Drive path.
    from google.colab import drive
    folder = '/content/gdrive/My Drive/Projects/datasets/steven2358-larynx_data/'
    drive.mount('/content/gdrive')

In [0]:
def plot_metrics(history):
  """Plot the accuracy and loss curves stored in a training history.

  Args:
    history: Object exposing a ``history`` dict that maps metric names to
      per-epoch value sequences (e.g. what ``model.fit_generator`` returns).

  Raises:
    KeyError: If the dict has neither an ``'acc'`` nor an ``'accuracy'``
      entry, or has no ``'loss'`` entry.
  """
  metrics = history.history
  # The accuracy series is logged as 'acc' by some Keras releases and as
  # 'accuracy' by others, so accept whichever one is present.
  accuracy_key = next((key for key in ('acc', 'accuracy') if key in metrics), None)
  if accuracy_key is None:
    raise KeyError("history has no 'acc' or 'accuracy' entry; found: %s"
                   % sorted(metrics))

  for key, label in ((accuracy_key, 'accuracy'), ('loss', 'loss')):
    plt.plot(metrics[key])
    # Validation curves exist only when training was given validation data.
    has_validation = ('val_' + key) in metrics
    if has_validation:
      plt.plot(metrics['val_' + key])
    plt.title('model ' + label)
    plt.ylabel(label)
    plt.xlabel('epoch')
    if has_validation:
      plt.legend(['train', 'test'], loc='upper left')
    plt.show()

In [0]:
def _load_rgb_stack(directory, height=160, width=240):
  """Read every file in `directory` as RGB into one float32 array.

  Files are read in sorted filename order. The array is sized from the
  directory listing rather than a fixed count, so no all-zero padding samples
  are produced and extra files do not overflow the array.

  Args:
    directory: Folder whose entries are all image files.
    height: Expected image height in pixels.
    width: Expected image width in pixels.

  Returns:
    Array of shape (file count, height, width, 3), dtype float32, holding the
    pixel values divided by 255.

  Raises:
    ValueError: Raised by NumPy if an image does not match height x width.
  """
  filenames = sorted(os.listdir(directory))
  stack = np.zeros((len(filenames), height, width, 3), dtype=np.float32)
  for index, filename in enumerate(filenames):
    stack[index] = imread(os.path.join(directory, filename), pilmode='RGB')
  return stack / 255


# os.path.join keeps the paths valid whether or not `folder` ends with a slash.
images = _load_rgb_stack(os.path.join(folder, 'images_cropped', 'images'))
print("loaded images")
print(images.shape)

labels = _load_rgb_stack(os.path.join(folder, 'labels_cropped', 'labels'))
print("loaded labels")
print(labels.shape)

# Images and labels are paired purely by position in the sorted listings, so a
# count mismatch would silently misalign every later (image, label) pair.
assert len(images) == len(labels), (
    "found %d images but %d labels" % (len(images), len(labels)))

In [0]:
# Peek at a short run of raw values (sample 3, row 100, columns 120-139) in the
# input image and in its label.
print(images[3, 100, 120:140])
print(labels[3, 100, 120:140])

In [0]:
# Data augmentation
#
# One settings dict feeds both generators so that images and labels are
# configured identically (small rotations, shifts, zoom and horizontal flips).
data_gen_args = {
    'rotation_range': 5,
    'width_shift_range': 5,
    'height_shift_range': 5,
    'horizontal_flip': True,
    'zoom_range': 0.05,
    'data_format': 'channels_last',
    # 'validation_split': 0.2,  (disabled: no validation split is used)
}

image_datagen = ImageDataGenerator(**data_gen_args)
label_datagen = ImageDataGenerator(**data_gen_args)


In [0]:
# Build the image and label iterators from one shared set of flow arguments.
# Both get the same seed so that their shuffling and random transforms stay in
# step, which is what keeps each augmented image paired with its own label.
image_gen, label_gen = (
    datagen.flow(data, seed=1, batch_size=8, shuffle=True)
    for datagen, data in ((image_datagen, images), (label_datagen, labels))
)
# Lazily yields (image_batch, label_batch) tuples.
train_gen = zip(image_gen, label_gen)

In [0]:
# Sanity check on the raw data: show the image and the label stored at the same
# index so they can be compared by eye.
position = 55
for sample_stack in (images, labels):
  plt.imshow(sample_stack[position])
  plt.show()

In [0]:
# Sanity check on the augmented data: pull one batch and inspect its first
# sample, first as numbers (dtype and the pixel at row 100, column 100), then
# as pictures, to confirm image and label received the same transform.
image_test, label_test = next(train_gen)
for augmented_batch in (image_test, label_test):
  first_sample = augmented_batch[0]
  print(first_sample.dtype)
  print(first_sample[100, 100])
for augmented_batch in (image_test, label_test):
  plt.imshow(augmented_batch[0])
  plt.show()

In [0]:
# Keep the best weights seen so far in a file inside the dataset folder.
# os.path.join keeps the path valid whether or not `folder` ends with a slash
# (the local placeholder path has none, so plain concatenation would glue the
# file name onto the last directory name).
# NOTE(review): `monitor` has to match the name under which Keras logs the
# accuracy metric ('acc' or 'accuracy' depending on the release) -- confirm it
# matches the installed version, otherwise save_best_only has no value to rank.
checkpoint = ModelCheckpoint(os.path.join(folder, 'unet_checkpoint.hdf5'),
                             monitor='accuracy',
                             verbose=1,
                             save_best_only=True)

In [0]:
# Build the network from the cloned unet package. The input size matches the
# 160x240 RGB arrays allocated when the dataset is loaded.
# NOTE(review): the output shape is defined inside unet_model.unet; the training
# cell treats each prediction as a (height, width, channels) array -- confirm.
model = unet_model.unet(input_size=(160,240,3))
# Optional architecture diagram (disabled).
#plot_model(model, to_file="model.png")

In [0]:
# Configure training: Adam at a 2e-4 learning rate, categorical cross-entropy
# against the label channels, and accuracy as the reported metric.
model.compile(
    optimizer=Adam(lr=2e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [0]:
#model.save(folder + 'unet.hdf5')

In [0]:
# Train in rounds of 10 epochs. After each round, take one augmented batch and
# show its first sample next to the model's prediction to eyeball progress.
loops = 5
for _ in range(loops):
  history = model.fit_generator(train_gen,
                                # Full batches per pass over the data; the 8
                                # must match the batch_size given to flow().
                                steps_per_epoch=len(images) // 8,
                                epochs=10,
                                callbacks=[checkpoint],
                                )
  image_batch, label_batch = next(train_gen)
  image, label = image_batch[0], label_batch[0]
  predicted = model.predict(image_batch)[0]
  # Raw per-channel scores for a short run of pixels on row 100.
  print(predicted[100][120:140])

  # Flatten prediction: one-hot encode the highest-scoring channel per pixel.
  # Indexing an identity matrix with the argmax map does this in one vectorized
  # step and yields the same float64 array as filling zeros pixel by pixel.
  flat_predicted = np.eye(predicted.shape[-1])[np.argmax(predicted, axis=-1)]

  # Show real image and label, then the raw and the flattened prediction.
  for title, picture in (('input image', image),
                         ('label', label),
                         ('prediction', predicted),
                         ('prediction (argmax)', flat_predicted)):
    plt.imshow(picture)
    plt.title(title)
    plt.show()