# Lab Week 8 Extra Task Code
This extra code shows how semantic segmentation can be run on a small medical dataset of retinal blood vessel images (CHASE_DB1).
Semantic segmentation is the task of assigning a class label to every pixel in an image. In this use case, we have images of the retina, the part of the eye that contains the light-sensitive cells. The goal is to find all pixels that belong to blood vessels, which can help specialists diagnose and monitor conditions of the eye.

<img src=https://blogs.kingston.ac.uk/retinal/files/2016/11/cropped-retina6.jpg width="600">

Image Source: https://blogs.kingston.ac.uk/retinal/files/2016/11/cropped-retina6.jpg  
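To make "a class label for every pixel" concrete, here is a toy sketch using NumPy (an assumption; the notebook itself only imports TensorFlow and Matplotlib). It pairs a made-up 4x4 RGB "image" with a binary vessel mask of the same height and width, where 1 marks a vessel pixel and 0 marks background.

```python
import numpy as np

# Toy 4x4 RGB "image" with shape (height, width, channels); the values are random placeholders
image = np.random.default_rng(0).random((4, 4, 3))

# Binary segmentation mask: exactly one label per pixel (1 = vessel, 0 = background)
mask = np.array([
    [0, 1, 0, 0],
    [0, 1, 0, 0],
    [0, 1, 1, 0],
    [0, 0, 1, 0],
], dtype=np.uint8)

# The mask has the same spatial size as the image, but no colour channels
assert mask.shape == image.shape[:2]

# Fraction of pixels that are labelled as vessel
vessel_fraction = mask.mean()
print(f"vessel pixels: {mask.sum()} of {mask.size} ({vessel_fraction:.0%})")
```

Running it prints `vessel pixels: 5 of 16 (31%)`. In real retinal images the vessel fraction is usually much smaller, which matters later when interpreting accuracy.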



---


The dataset source is as follows:
- Fraz, Muhammad Moazam [Creator], Remagnino, Paolo, Hoppe, Andreas, Uyyanonvara, Bunyarit, Rudnicka, Alicja R [Creator], Owen, Christopher G [Creator] and Barman, Sarah A [Creator] (2012) CHASE_DB1 retinal vessel reference dataset. [Data Collection] [Link](https://blogs.kingston.ac.uk/retinal/chasedb1/).


The dataset has been modified for this exercise and is provided on GCU Learn: **CHASEDB1.zip**. Download and extract this file into the same directory as this Jupyter Notebook.

---


**TASK**: Look up any function calls that are unclear to you in the TensorFlow Keras API documentation: https://www.tensorflow.org/api_docs/python/tf/keras

**NOTE**: Some parts of the code are marked with the keyword `ADVANCED CODE`. You do not need to understand these parts in detail; simply read the comment next to them.

In [None]:
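# import modules
import os

# work around the "OMP: Error #15" duplicate OpenMP runtime crash seen on some setups (commonly with Anaconda);
# it is set before TensorFlow is imported so that the OpenMP runtime sees it when it initialises
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers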


In [None]:
# obtain the dataset train and test paths
train_img_path = './CHASEDB1/train_images/'
train_label_path = './CHASEDB1/train_labels/'

test_img_path = './CHASEDB1/test_images/'
test_label_path = './CHASEDB1/test_labels/'


In [None]:
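# ADVANCED CODE: function to show the image, label and optional prediction
def show_image(img, label, pred=None):

    # the calls in this notebook pass img[0]/255., so multiplying by 255 restores the [0, 1] range that imshow expects for float RGB images
    img = img * 255

    plt.figure(figsize=(15, 5))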
    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.title('Image')
    plt.subplot(1, 3, 2)
    plt.imshow(label, cmap='gray')
    plt.title('Label')
    if pred is not None:
        plt.subplot(1, 3, 3)
        plt.imshow(pred, cmap='gray')
        plt.title('Prediction')
    plt.show()

In [None]:
# using ImageDataGenerator to load the images and labels from the directories
# https://stackoverflow.com/questions/58050113/imagedatagenerator-for-semantic-segmentation

# Using a small batch size to avoid memory issues
BATCH_SIZE = 1
IMG_SIZE = (256, 256)
# the same seed keeps each image generator in sync with its label generator (same file order and same random augmentation)
seed = 42

# training generators: rescale to [0, 1] and apply random rotations and flips as data augmentation
aug_args = dict(rescale=1./255, validation_split=0.2, rotation_range=90, horizontal_flip=True, vertical_flip=True)
img_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(**aug_args)
label_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(**aug_args)

# validation/test generator: rescaling only, so the model is evaluated on unmodified images
eval_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, validation_split=0.2)

# labels are binary vessel masks, so they are loaded as single-channel images to match the model's 1-channel output
train_img_gen = img_data_gen.flow_from_directory(train_img_path, class_mode=None, seed=seed, target_size=IMG_SIZE, batch_size=BATCH_SIZE, subset='training')
train_label_gen = label_data_gen.flow_from_directory(train_label_path, class_mode=None, color_mode='grayscale', seed=seed, target_size=IMG_SIZE, batch_size=BATCH_SIZE, subset='training')
train_gen = zip(train_img_gen, train_label_gen)

val_img_gen = eval_data_gen.flow_from_directory(train_img_path, class_mode=None, seed=seed, target_size=IMG_SIZE, batch_size=BATCH_SIZE, subset='validation')
val_label_gen = eval_data_gen.flow_from_directory(train_label_path, class_mode=None, color_mode='grayscale', seed=seed, target_size=IMG_SIZE, batch_size=BATCH_SIZE, subset='validation')
val_gen = zip(val_img_gen, val_label_gen)

test_img_gen = eval_data_gen.flow_from_directory(test_img_path, class_mode=None, seed=seed, target_size=IMG_SIZE, batch_size=BATCH_SIZE)
test_label_gen = eval_data_gen.flow_from_directory(test_label_path, class_mode=None, color_mode='grayscale', seed=seed, target_size=IMG_SIZE, batch_size=BATCH_SIZE)
test_gen = zip(test_img_gen, test_label_gen)
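Why must the image and label generators share a seed? In segmentation, every random spatial transform applied to an image has to be applied identically to its mask, otherwise the labels no longer line up with the pixels. The NumPy sketch below illustrates the idea with a hypothetical `random_flip_rotate` helper. It is a simplified stand-in: it only rotates by multiples of 90 degrees, whereas Keras' `rotation_range=90` samples any angle between -90 and 90 degrees.

```python
import numpy as np

def random_flip_rotate(array, rng):
    """Randomly flip and rotate (by a multiple of 90 degrees) an array of shape (H, W, C)."""
    if rng.random() < 0.5:
        array = array[:, ::-1]       # horizontal flip
    if rng.random() < 0.5:
        array = array[::-1, :]       # vertical flip
    k = int(rng.integers(0, 4))      # number of 90-degree rotations
    return np.rot90(array, k, axes=(0, 1))

seed = 42
image = np.arange(4 * 4 * 3).reshape(4, 4, 3)   # toy image
mask = image[..., :1] % 2                       # toy mask derived pixel-wise from the image

# Two random generators created with the same seed draw the same numbers,
# so the image and its mask receive exactly the same transformation.
aug_image = random_flip_rotate(image, np.random.default_rng(seed))
aug_mask = random_flip_rotate(mask, np.random.default_rng(seed))

# The transformed mask still matches the transformed image pixel by pixel
print(np.array_equal(aug_image[..., :1] % 2, aug_mask))
```

With different seeds the two calls could pick different flips or rotations, which is exactly the misalignment the shared `seed` in the generator cell is meant to prevent.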



In [None]:
# show a sample from the training set
img, label = next(train_gen)
show_image(img[0]/255., label[0])

# U-Net for Semantic Segmentation
U-Net is a fully convolutional network for image segmentation.
It is used in many segmentation applications, such as medical imaging and satellite imagery.
It consists of an encoder and a decoder: the encoder extracts features from the image while progressively reducing its spatial resolution, and the decoder upsamples these features back to the original image size.
Skip connections pass the feature maps of each encoder level to the matching decoder level, which preserves spatial detail that would otherwise be lost during downsampling.
The output has the same height and width as the input. In this binary task, each output pixel holds the predicted probability that the pixel belongs to a blood vessel; in multi-class segmentation, each pixel would instead receive a score for every class.

- Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015. ([Link to the original paper](https://arxiv.org/abs/1505.04597))
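The encoder/decoder shape bookkeeping can be traced by hand. The sketch below (plain Python, with hypothetical helper names) follows the model built in the next cell: three 2x2 max-pooling steps halve the height and width (rounding odd sizes down), and three upsampling steps double them again. Each upsampled map must exactly match the size of the encoder map it is concatenated with through a skip connection, which only works when the input size is divisible by 2³ = 8.

```python
def unet_feature_sizes(height, width, depth=3):
    """Spatial sizes of the encoder feature maps for a U-Net with `depth` 2x2 max-pooling steps.

    Raises ValueError if upsampling cannot restore the sizes needed by the skip connections.
    """
    sizes = [(height, width)]
    for _ in range(depth):
        h, w = sizes[-1]
        # MaxPool2D with the default pool_size=2 and padding='valid' rounds odd sizes down
        sizes.append((h // 2, w // 2))
    # Decoder: each UpSampling2D doubles the size, which must equal the encoder level it is concatenated with
    for level in range(depth, 0, -1):
        h, w = sizes[level]
        if (2 * h, 2 * w) != sizes[level - 1]:
            raise ValueError(f"skip connection mismatch: {(2 * h, 2 * w)} vs {sizes[level - 1]}")
    return sizes

def fits_unet(height, width, depth=3):
    """True if an input of this size passes through the U-Net without a shape mismatch."""
    try:
        unet_feature_sizes(height, width, depth)
        return True
    except ValueError:
        return False

print(unet_feature_sizes(256, 256))   # [(256, 256), (128, 128), (64, 64), (32, 32)]
print(fits_unet(100, 100))            # False: 100 -> 50 -> 25 -> 12, and 2 * 12 != 25
```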




In [None]:
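# Function to build a U-Net Keras model, followed by the code to compile it.
# Note the unspecified input size (None, None): the model accepts images of different sizes.
# The model will be trained on 256x256 images, but it can also be applied to other sizes, as long as the height and width
# are divisible by 8 (the three 2x2 max-pooling steps must be undone exactly by the three upsampling steps).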

# NOTE: This function uses keras functional API, which differs from the sequential API that you have been using so far. (https://keras.io/guides/functional_api/)

def build_model():
    inputs = layers.Input(shape=(None, None, 3))

    # downsample
    x1 = layers.Conv2D(16, 3, padding='same', activation='relu')(inputs)
    x2 = layers.MaxPool2D()(x1)
    x2 = layers.Conv2D(32, 3, padding='same', activation='relu')(x2)
    x3 = layers.MaxPool2D()(x2)
    x3 = layers.Conv2D(64, 3, padding='same', activation='relu')(x3)
    x4 = layers.MaxPool2D()(x3)
    x4 = layers.Conv2D(128, 3, padding='same', activation='relu')(x4)

    # upsample
    x = layers.UpSampling2D()(x4)
    x = layers.Concatenate()([x, x3])
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.UpSampling2D()(x)
    x = layers.Concatenate()([x, x2])
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
    x = layers.UpSampling2D()(x)
    x = layers.Concatenate()([x, x1])
    x = layers.Conv2D(16, 3, padding='same', activation='relu')(x)

    # output
    outputs = layers.Conv2D(1, 1, padding='same', activation='sigmoid')(x)

    # build functional Model
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.summary()

    return model

# build the model
model = build_model()

# compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
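As a sanity check on `model.summary()`, the trainable parameters can be counted by hand. A `Conv2D` layer with a k×k kernel, `c_in` input channels, `filters` output channels and a bias (the Keras default) has k·k·c_in·filters + filters parameters; `MaxPool2D`, `UpSampling2D` and `Concatenate` have none. The list below is transcribed from `build_model()` above; the decoder inputs are concatenations, so their channel counts are sums.

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Trainable parameters of a Conv2D layer with a square kernel and a bias vector."""
    return kernel_size * kernel_size * in_channels * filters + filters

# (kernel_size, input channels, filters) for every Conv2D layer in build_model()
conv_layers = [
    (3, 3, 16),          # encoder level 1 (RGB input)
    (3, 16, 32),         # encoder level 2
    (3, 32, 64),         # encoder level 3
    (3, 64, 128),        # bottleneck
    (3, 128 + 64, 64),   # decoder: upsampled x4 concatenated with x3
    (3, 64 + 32, 32),    # decoder: upsampled features concatenated with x2
    (3, 32 + 16, 16),    # decoder: upsampled features concatenated with x1
    (1, 16, 1),          # 1x1 output convolution with sigmoid activation
]
total = sum(conv2d_params(*layer) for layer in conv_layers)
print(f"Total params: {total:,}")   # Total params: 242,721
```

This small count (compared with millions of parameters in the original U-Net) is one reason the model trains in minutes on a laptop.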


The following code is used to train the model.
Depending on your computer hardware, training can be faster or slower.
**TASK**: Start with a single epoch to see how long training takes. Then set the number of epochs to a value appropriate for your system. For this task, training should not take longer than about 2-3 minutes.

In [None]:
# train the model
history = model.fit(train_gen, validation_data=val_gen, epochs=1, steps_per_epoch=len(train_img_gen), validation_steps=len(val_img_gen))
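To turn the timing of that first epoch into an epoch count, a small helper like the hypothetical `epochs_for_budget` below can be used; the 150-second default is an assumption chosen as the middle of the suggested 2-3 minutes. The sketch also shows what `len(train_img_gen)` means for `steps_per_epoch`: the number of batches needed to see every image once, i.e. the number of images divided by the batch size, rounded up.

```python
import math

def steps_per_epoch(num_images, batch_size):
    """Batches needed to see every image once (rounded up, as len() of a Keras directory iterator)."""
    return math.ceil(num_images / batch_size)

def epochs_for_budget(seconds_per_epoch, budget_seconds=150):
    """Largest whole number of epochs that fits in the time budget (at least 1)."""
    return max(1, int(budget_seconds // seconds_per_epoch))

# Example: if the first epoch took about 20 seconds
print(epochs_for_budget(20))   # 7
```

The first epoch can be slower than later ones because of one-off setup work, so this estimate tends to be conservative.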


# Training History
Often you want to monitor the progress of your model training. This can be done with dedicated tools such as [TensorBoard](https://www.tensorflow.org/tensorboard) or, very simply, with plotting libraries as shown below.

Monitoring the training process shows whether the model is overfitting, underfitting, or training as expected. We can usually tell whether a model converges by looking at how the training and validation loss (and accuracy) change over the epochs.
The following rules of thumb are a good guide:
- If the training loss stops decreasing (plateaus), the model is no longer learning: it has converged, possibly at a poor local optimum. If both losses are still high at that point, the model is underfitting.
- If the training loss decreases while the validation loss increases, the model is overfitting.
- If both the training loss and the validation loss decrease, the model is still learning and converging.
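The rules of thumb above can be written down as a rough heuristic. The hypothetical `diagnose` function below compares the first and last of the most recent `window` epochs in lists such as `history.history['loss']` and `history.history['val_loss']`; the window size and tolerance are arbitrary assumptions, and a real judgement should still be made by looking at the plot.

```python
def diagnose(loss, val_loss, window=3, tol=1e-3):
    """Apply the rules of thumb to the last `window` epochs of a training history."""
    if len(loss) < 2:
        return "not enough epochs to judge"
    recent = slice(-window, None)
    train_change = loss[recent][-1] - loss[recent][0]
    val_change = val_loss[recent][-1] - val_loss[recent][0]
    if train_change < -tol and val_change > tol:
        return "overfitting"          # training loss falls, validation loss rises
    if train_change < -tol and val_change < -tol:
        return "still improving"      # both losses fall
    if abs(train_change) <= tol:
        return "plateaued"            # training loss no longer changes
    return "unclear"                  # e.g. training loss increasing (learning rate too high?)

print(diagnose([1.0, 0.8, 0.6, 0.4], [1.0, 0.9, 0.95, 1.1]))   # overfitting
```

With the single-epoch run suggested above, `diagnose(history.history['loss'], history.history['val_loss'])` simply reports that there are not enough epochs to judge.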

In [None]:
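# plot the training history (markers keep a single-epoch run visible, since one point draws no line)
plt.plot(history.history['loss'], marker='o', label='loss')
plt.plot(history.history['val_loss'], marker='o', label='val_loss')
plt.xlabel('epoch')
plt.ylabel('binary cross-entropy loss')
plt.legend()
plt.show()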


In [None]:
# NOTE: the focus of this tutorial is on the training process, not on the model performance or metrics. 
# This step is just to show that the model is able to predict something.
model.evaluate(test_gen, steps=len(test_img_gen))
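`model.evaluate` reports the loss and the pixel accuracy requested in `compile`. Pixel accuracy is easy to misread in vessel segmentation, because most pixels are background. The NumPy toy example below (made-up mask, not CHASE data) shows that a "model" that never predicts a vessel still reaches 90% accuracy when 10% of the pixels are vessels.

```python
import numpy as np

# Toy 10x10 mask with one vertical "vessel": 10 of 100 pixels are foreground
mask = np.zeros((10, 10), dtype=np.uint8)
mask[:, 4] = 1

# A useless model that predicts background everywhere
prediction = np.zeros_like(mask)

pixel_accuracy = (prediction == mask).mean()
print(f"pixel accuracy: {pixel_accuracy:.0%}")   # pixel accuracy: 90%
```

Overlap-based metrics such as the Dice coefficient or intersection-over-union (IoU) would score this prediction close to 0, which is why they are commonly preferred for segmentation.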

### Examples
Once the model has been trained, we can observe the prediction outputs.
As this is only meant to show how few lines of code you need to implement a semantic segmentation model, we will not focus on improving performance or on validation metrics for segmentation.

In [None]:
# Show samples from the test set
# This code loops through the test set and shows the image, label, and prediction for each image.
for _ in range(len(test_img_gen)):
    img, label = next(test_gen)
    pred = model.predict(img, verbose=0)
    show_image(img[0]/255., label[0], pred[0])
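The sigmoid output shown as "Prediction" is a per-pixel probability rather than a hard mask. The NumPy sketch below (hypothetical helpers `binarize` and `dice_and_iou`; the 0.5 threshold is a common default, not something this notebook prescribes) turns it into a 0/1 vessel mask and compares it with the label using the Dice coefficient, 2|A∩B| / (|A| + |B|), and IoU, |A∩B| / |A∪B|.

```python
import numpy as np

def binarize(prob, threshold=0.5):
    """Turn sigmoid probabilities into a 0/1 vessel mask."""
    return (prob >= threshold).astype(np.uint8)

def dice_and_iou(pred_mask, true_mask, eps=1e-7):
    """Dice coefficient and intersection-over-union (IoU) for two binary masks of the same shape."""
    pred = pred_mask.astype(bool)
    true = true_mask.astype(bool)
    intersection = np.logical_and(pred, true).sum()
    union = np.logical_or(pred, true).sum()
    # eps avoids division by zero when both masks are empty (both scores are then 1)
    dice = (2 * intersection + eps) / (pred.sum() + true.sum() + eps)
    iou = (intersection + eps) / (union + eps)
    return dice, iou

# Toy example: one of the two predicted vessel pixels is correct
probs = np.array([[0.9, 0.2, 0.7, 0.1]])
label_mask = np.array([[1, 1, 0, 0]])
dice, iou = dice_and_iou(binarize(probs), label_mask)
print(f"Dice: {dice:.2f}, IoU: {iou:.2f}")   # Dice: 0.50, IoU: 0.33
```

Inside the loop above, the same idea would apply to `dice_and_iou(binarize(pred[0]), binarize(label[0]))`, where the label is thresholded as well because augmentation and resizing can produce non-binary values.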

# More Computer Vision Examples
A list of further computer vision examples using different datasets and models can be found here:
 - U-Net-like model for pet segmentation (Oxford-IIIT Pets): https://keras.io/examples/vision/oxford_pets_image_segmentation/
 - 3D image classification from CT scans (3D data): https://keras.io/examples/vision/3D_image_classification/
 - Image classification from scratch: https://keras.io/examples/vision/image_classification_from_scratch/
