# Retinal vessel segmentation
Retinal vessel segmentation is the task of segmenting the blood vessels in retinal images. It is an essential step in developing computer-aided diagnosis systems for retinal diseases.

| ![](https://github.com/orobix/retina-unet/raw/master/test/test_Original_GroundTruth_Prediction3.png) |
|:--:|
| <b>Retinal vessel segmentation</b> |

# U-Net architecture
| ![](https://camo.githubusercontent.com/bc2e09476b5c7db5ea4e19251ac9a19af9ba5a89f16d58f72459059c3cffb969/68747470733a2f2f63646e2d696d616765732d312e6d656469756d2e636f6d2f6d61782f3830302f312a6a716f416d4579516d784b704763416b6250474e4d512e706e67) |
|:--:|
| <b>U-Net model</b> |

The U-Net is a convolutional network architecture for fast and precise image segmentation. It is an encoder-decoder model with skip connections between corresponding encoder and decoder levels. The major advantage of this architecture is its ability to take into account a wider context when making a prediction for a pixel (foreground vs. background). The model also uses a large number of feature channels in the up-sampling path, which lets it propagate context information to the higher-resolution layers.

In [1]:
# Python
import configparser
import os
import random

import cv2
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

# Keras
from keras.models import model_from_json
from keras.models import Model

In [2]:
# Data paths (DRIVE dataset directory layout)
training_1st_manual = './DRIVE/training/1st_manual/'  # presumably the first observer's manual vessel labels
training_images = './DRIVE/training/images/'
training_masks = './DRIVE/training/mask/'
test_images = './DRIVE/test/images/'
test_masks = './DRIVE/test/mask/'

# Patch dimensions, in pixels
patch_height, patch_width = 48, 48

# Training settings: epochs, mini-batch size, total number of training patches
n_epochs, batch_size, n_subimgs = 150, 32, 190000

## 1. Pre-processing images
Before training, the 20 images of the DRIVE training dataset are pre-processed with the following steps:
- Gray-scale conversion
- Normalization
- Contrast-limited adaptive histogram equalization (CLAHE)
- Gamma adjustment

In [3]:
# Gray-scale conversion
def rgb2gray(rgb_imgs):
    """Convert a batch of RGB images to single-channel gray-scale.

    Uses the weighted sum 0.299*R + 0.587*G + 0.114*B (ITU-R BT.601 luma).

    Args:
        rgb_imgs: 4D array of shape (N, 3, H, W), channels ordered R, G, B.

    Returns:
        Array of shape (N, 1, H, W) holding the gray-scale images.
    """
    assert len(rgb_imgs.shape) == 4  # 4D arrays
    assert rgb_imgs.shape[1] == 3  # Check the original RGB images

    bn_imgs = (rgb_imgs[:, 0, :, :] * 0.299
               + rgb_imgs[:, 1, :, :] * 0.587
               + rgb_imgs[:, 2, :, :] * 0.114)
    # The weighted sum drops the channel axis; restore it as a singleton.
    bn_imgs = np.reshape(bn_imgs, (rgb_imgs.shape[0], 1, rgb_imgs.shape[2], rgb_imgs.shape[3]))

    return bn_imgs

In [4]:
# Normalization
def dataset_normalize(imgs):
    """Standardize a gray-scale batch, then rescale every image to [0, 255].

    The whole batch is first z-scored with the dataset-wide mean and standard
    deviation. Each image is then min-max rescaled independently so that its
    darkest pixel maps to 0 and its brightest pixel maps to 255.

    Args:
        imgs: 4D array of shape (N, 1, H, W).

    Returns:
        A new float array with the same shape as ``imgs``. An image with no
        intensity variation (max == min) is returned as all zeros rather than NaN.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] == 1  # Check the gray-scale images

    # Dataset-wide standardization; avoid dividing by zero on a constant dataset.
    std = np.std(imgs)
    normalized_imgs = (imgs - np.mean(imgs)) / (std if std > 0 else 1.0)

    for i in range(imgs.shape[0]):
        img = normalized_imgs[i]
        lo = np.min(img)
        hi = np.max(img)
        if hi > lo:
            normalized_imgs[i] = ((img - lo) / (hi - lo)) * 255
        else:
            # Flat image: min-max rescaling is undefined, so map it to 0.
            normalized_imgs[i] = 0.0

    # Return only after every image in the batch has been rescaled.
    return normalized_imgs

In [5]:
# Histogram equalization
def his_equalized(imgs):
    """Equalize the histogram of each gray-scale image in a batch.

    Args:
        imgs: 4D array of shape (N, 1, H, W). Each image is cast to uint8
            before being handed to ``cv2.equalizeHist``.

    Returns:
        A new float array with the same shape as ``imgs``.
    """
    assert len(imgs.shape) == 4 # 4D arrays
    assert imgs.shape[1] == 1 # Check the gray-scale images

    result = np.empty(imgs.shape)
    for idx, single in enumerate(imgs):
        as_uint8 = np.array(single[0], dtype=np.uint8)
        result[idx, 0] = cv2.equalizeHist(as_uint8)
    return result

In [6]:
# Contrast-limited adaptive histogram equalization (CLAHE)
def clahe(imgs):
    """Apply CLAHE to each gray-scale image in a batch.

    A single OpenCV CLAHE operator (clip limit 2.0, 8x8 tile grid) is built
    once and reused for every image.

    Args:
        imgs: 4D array of shape (N, 1, H, W). Each image is cast to uint8
            before equalization.

    Returns:
        A new float array with the same shape as ``imgs``.
    """
    assert len(imgs.shape) == 4 # 4D arrays
    assert imgs.shape[1] == 1 # Check the gray-scale images

    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    result = np.empty(imgs.shape)
    for idx, single in enumerate(imgs):
        result[idx, 0] = equalizer.apply(np.array(single[0], dtype=np.uint8))
    return result

In [8]:
# Gamma adjustment
def gamma_adjust(imgs, gamma):
    """Apply gamma correction to a batch of 8-bit gray-scale images.

    Each intensity ``v`` in 0..255 is mapped to
    ``int(255 * (v / 255) ** (1 / gamma))`` via a 256-entry lookup table.
    A ``gamma`` greater than 1 brightens mid-tones.

    Args:
        imgs: 4D array of shape (N, 1, H, W). Values are cast to uint8 before
            the lookup, so they are expected to lie in 0..255.
        gamma: Positive gamma value.

    Returns:
        A new float array with the same shape as ``imgs`` holding values in 0..255.

    Raises:
        ValueError: if ``gamma`` is not positive.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] == 1  # Check the gray-scale images
    if gamma <= 0:
        raise ValueError('gamma must be positive, got ' + str(gamma))

    # Lookup table: entry v holds the gamma-adjusted value of intensity v,
    # rescaled back to the 0..255 range before truncation to uint8.
    inv_gamma = 1.0 / gamma
    table = (((np.arange(0, 256) / 255.0) ** inv_gamma) * 255).astype('uint8')

    # Allocate an output buffer of the same shape as the input.
    adjusted_imgs = np.empty(imgs.shape)

    for i in range(imgs.shape[0]):
        # Indexing the table with a uint8 image is the NumPy equivalent of cv2.LUT.
        adjusted_imgs[i, 0] = table[np.array(imgs[i, 0], dtype=np.uint8)]

    return adjusted_imgs

In [9]:
# Pre-processing data
def pre_process(original_imgs):
    """Run the full pre-processing pipeline on a batch of RGB retinal images.

    The steps are: gray-scale conversion -> dataset normalization -> CLAHE ->
    gamma adjustment (gamma = 1.2). The result is finally divided by 255.

    Args:
        original_imgs: 4D array of shape (N, 3, H, W) with RGB channels.

    Returns:
        Float array of shape (N, 1, H, W). It is expected to lie in [0, 1]
        because the gamma step maps values through an 8-bit lookup table.
    """
    assert len(original_imgs.shape) == 4  # 4D arrays
    # rgb2gray requires three colour channels, so check for them up front.
    assert original_imgs.shape[1] == 3  # Check the original RGB images

    imgs = rgb2gray(original_imgs)
    imgs = dataset_normalize(imgs)
    imgs = clahe(imgs)
    imgs = gamma_adjust(imgs, 1.2)
    imgs = imgs / 255.

    return imgs

## 2. Extract image patches
The U-Net model is trained on sub-images (patches of 48x48 pixels) extracted from the pre-processed images. Each patch is obtained by randomly selecting its center within the full image. Patches lying partially or completely outside the field of view (FOV) can also be selected (when `inside=False`); in this way the neural network learns how to discriminate the FOV border from blood vessels.

In [10]:
# Check if the patch is fully contained in the FOV
def is_patch_inside_FOV(x, y, img_width, img_height, patch_height, fov_radius=270):
    """Return True if a square patch centred at (x, y) lies fully inside the FOV.

    The field of view is modelled as a circle of radius ``fov_radius`` pixels
    centred on the image. The default of 270 is the DRIVE value used by the
    original code. A square patch of side ``patch_height`` is fully inside the
    circle when its centre is closer to the image centre than
    ``fov_radius`` minus half the patch diagonal.

    Args:
        x, y: Patch centre in pixel coordinates (column, row).
        img_width, img_height: Full image size in pixels.
        patch_height: Side of the (assumed square) patch.
        fov_radius: Radius of the circular FOV in pixels.

    Returns:
        A Python bool. The comparison is strict, so a centre exactly on the
        limit is reported as outside.
    """
    x_ = x - int(img_width / 2)  # Origin (0,0) is shifted to image center
    y_ = y - int(img_height / 2)  # Origin (0,0) is shifted to image center

    # Half the patch diagonal is the farthest any patch corner can be from its centre.
    r_inside = fov_radius - int((patch_height * np.sqrt(2.0)) / 2)
    radius = np.sqrt((x_ * x_) + (y_ * y_))
    return bool(radius < r_inside)

In [12]:
# Extract patches from original images
def extract_patches(original_imgs, masks, patch_height, patch_width, n_patches, inside=True):
    """Randomly sample aligned patches from a batch of images and their masks.

    Every image contributes the same number of patches (``n_patches / N``).
    For each patch a random centre is drawn so that the whole window lies
    inside the image, and the identical window is cut from ``masks``.

    Args:
        original_imgs: 4D array (N, C, H, W) with C == 1 or C == 3.
        masks: 4D array (N, 1, H, W), spatially aligned with ``original_imgs``.
        patch_height, patch_width: Patch size in pixels (even or odd).
        n_patches: Total number of patches; must be a multiple of N.
        inside: If True, keep only patches that ``is_patch_inside_FOV`` reports
            as fully inside the field of view. If False, keep every sampled patch.

    Returns:
        ``(patches, mask_patches)``, float arrays of shapes
        (n_patches, C, patch_height, patch_width) and
        (n_patches, 1, patch_height, patch_width).

    Raises:
        ValueError: if ``original_imgs`` is empty, or if ``n_patches`` is not a
            multiple of the number of images.
    """
    assert len(original_imgs.shape) == 4 and len(masks.shape) == 4  # 4D arrays
    assert original_imgs.shape[1] == 1 or original_imgs.shape[1] == 3
    assert masks.shape[1] == 1

    n_imgs = original_imgs.shape[0]
    if n_imgs == 0:
        raise ValueError('original_imgs contains no images')
    if n_patches % n_imgs != 0:
        raise ValueError('n_patches (' + str(n_patches) + ') must be a multiple of '
                         'the number of images (' + str(n_imgs) + ')')

    patches = np.empty((n_patches, original_imgs.shape[1], patch_height, patch_width))
    mask_patches = np.empty((n_patches, masks.shape[1], patch_height, patch_width))
    img_height = original_imgs.shape[2]
    img_width = original_imgs.shape[3]

    n_patches_per_img = n_patches // n_imgs
    print('Number of patches per full original image: ' + str(n_patches_per_img))

    # Offsets from a patch centre to its top-left corner. The centre ranges
    # below keep the whole window inside the image for even and odd sizes.
    half_h = patch_height // 2
    half_w = patch_width // 2

    iter_total = 0  # Next free slot in the output arrays
    for i in range(n_imgs):  # Loop over full original images
        k = 0  # Patches accepted so far from image i
        while k < n_patches_per_img:
            x_center = random.randint(half_w, img_width - patch_width + half_w)
            y_center = random.randint(half_h, img_height - patch_height + half_h)

            # NOTE(review): with inside=True, this loop never ends if no sampled
            # centre can satisfy is_patch_inside_FOV (e.g. a patch too large for the FOV).
            if inside and not is_patch_inside_FOV(x_center, y_center, img_width, img_height, patch_height):
                continue

            y0 = y_center - half_h
            x0 = x_center - half_w
            patches[iter_total] = original_imgs[i, :, y0:y0 + patch_height, x0:x0 + patch_width]
            mask_patches[iter_total] = masks[i, :, y0:y0 + patch_height, x0:x0 + patch_width]
            iter_total += 1  # Go to next output slot
            k += 1

    return patches, mask_patches