# Detecting Blindness with Deep Learning
> Using CNNs to detect diabetic retinopathy in retina photos

- toc: true 
- badges: true
- comments: true
- author: Nikita Kozodoi
- categories: [python,deep learning,computer vision]
- image: images/posts/retina.png

# 1. Overview

Can deep learning help to detect blindness?

This blogpost describes a project that develops a convolutional neural network (CNN) for predicting the severity of diabetic retinopathy from a patient's retina photograph. The project was completed within the scope of the [Udacity Machine Learning Engineer](https://confirm.udacity.com/LMMJDA7C) nano-degree program and the [Kaggle competition](https://www.kaggle.com/c/aptos2019-blindness-detection/data) hosted by the Asia Pacific Tele-Ophthalmology Society (APTOS).

The blogpost provides a project walkthrough covering the most important modeling steps:
- data exploration and image preprocessing to normalize images from different clinics
- using transfer learning to pre-train a CNN on a large data set and fine-tune the model on the target domain
- implementing techniques such as a learning rate scheduler, test-time augmentation and others to improve the performance

The modeling is done in `PyTorch`. All Jupyter notebooks and a PDF report are [available on GitHub](https://github.com/kozodoi/Udacity_Blindness_Detection).

# 2. Motivation

Diabetic retinopathy (DR) is one of the leading causes of vision loss. The World Health Organization reports that more than 300 million people worldwide have diabetes (Wong et al. 2016). According to a recent study from the International Diabetes Federation, the global prevalence of DR among individuals with diabetes between 2015 and 2019 exceeded 25% (Thomas et al. 2019). The disease prevalence has been rising rapidly in developing countries.

Early detection and treatment are crucial steps towards preventing DR. The screening procedure requires a trained clinical expert to examine the fundus photographs of the patient's retina, which creates delays in diagnosis and treatment. This is especially relevant for developing countries, which often lack qualified medical staff to perform the diagnosis. Automated detection of DR can improve the efficiency and coverage of screening programs.

# 3. Data preparation

Data exploration and preprocessing are important steps that are frequently underestimated. Regardless of the application domain, the quality of the input data can have a strong impact on the performance of the resulting machine learning models. Therefore, it is crucial to take some time to look at the data and think about possible issues that should be addressed before moving on to the modeling stage.

## Data exploration

The data set is provided by APTOS and is available for download at the [competition's website](https://www.kaggle.com/c/aptos2019-blindness-detection/data). It includes 3,662 labeled retina images of clinical patients and a test set of 1,928 images with unknown labels.

The images are taken using a fundus photography technique and labeled by a clinical expert. The integer labels indicate the severity of DR on a scale from 0 to 4, where 0 indicates no disease and 4 is the proliferative stage of DR.
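For plots and reports, it can be handy to translate the integer labels into stage names. The sketch below assumes the stage names given in the competition's data description; the `DR_STAGES` dictionary and the `stage_name()` helper are illustrative names and are not part of the project code.

```python
# Illustrative mapping from the integer label to the DR stage name.
# The names follow the competition's data description (an assumption here).
DR_STAGES = {
    0: 'No DR',
    1: 'Mild',
    2: 'Moderate',
    3: 'Severe',
    4: 'Proliferative DR',
}


def stage_name(label):
    """Return the stage name for an integer label between 0 and 4."""
    label = int(label)
    if label not in DR_STAGES:
        raise ValueError('label must be an integer from 0 to 4, got {}'.format(label))
    return DR_STAGES[label]


print(stage_name(0), '|', stage_name(4))
```

With the training table loaded, something like `train['diagnosis'].map(DR_STAGES)` would produce a readable column for plotting.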

First, let us import the data and look at the class distribution.

In [None]:
#collapse-hide

############ LIBRARIES

import numpy as np
import pandas as pd

import torch
import torchvision

from torchvision import transforms, datasets
from torch.utils.data import Dataset

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import cv2

from tqdm.notebook import tqdm  # tqdm_notebook is deprecated in recent tqdm versions

import random
import time
import sys
import os
import math

import matplotlib.pyplot as plt
import seaborn as sns



########## SETTINGS

pd.set_option('display.max_columns', None)
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')



############ CHECK DIMENSIONS

# import data
train = pd.read_csv('../input/aptos2019-blindness-detection/train.csv')
test  = pd.read_csv('../input/aptos2019-blindness-detection/sample_submission.csv')

# check shape
print(train.shape, test.shape)
print('-' * 15)
print(train['diagnosis'].value_counts(normalize = True))



############ CLASS DISTRIBUTION

# plot
fig = plt.figure(figsize = (15, 5))
plt.hist(train['diagnosis'])
plt.title('Class Distribution')
plt.ylabel('Number of examples')
plt.xlabel('Diagnosis')

# export
plt.savefig('fig_class_distribution.pdf')

![](images/fig_class_distribution.png)

The data set is imbalanced: 49% of the images come from healthy patients, and the remaining 51% show different stages of DR. The least common class is 3 (severe stage), with only 5% of the total examples.
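The degree of imbalance can also be summarized in a couple of numbers. The following is a minimal sketch that computes class shares and the ratio between the largest and the smallest class. The helper names `class_shares()` and `imbalance_ratio()` are hypothetical, and the toy labels are made up for illustration; they are not the actual APTOS counts.

```python
from collections import Counter


def class_shares(labels):
    """Return {label: share of examples}, ordered by label."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: counts[label] / total for label in sorted(counts)}


def imbalance_ratio(labels):
    """Ratio between the counts of the most and the least frequent class."""
    counts = Counter(labels)
    return max(counts.values()) / min(counts.values())


# hypothetical toy labels, NOT the real APTOS distribution
toy_labels = [0] * 10 + [1] * 2 + [2] * 5 + [3] * 1 + [4] * 2

shares = class_shares(toy_labels)
print({label: round(share, 2) for label, share in shares.items()})
print(imbalance_ratio(toy_labels))
```

On the real data, the same functions could be called with `train['diagnosis'].tolist()`.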

The data is collected from multiple Indian clinics using a variety of camera models, which creates discrepancies in the image resolution, aspect ratio and other parameters. This is demonstrated in the figure below, where we plot the histograms of image width, height and aspect ratio. This variety in the image parameters needs to be accounted for during preprocessing.

In [None]:
#collapse-hide

# import loop
image_stats = []
for index, observation in tqdm(train.iterrows(), total = len(train)):
    
    # import image
    img = cv2.imread('../input/aptos2019-blindness-detection/train_images/{}.png'.format(observation['id_code']))

    # compute stats
    height, width, channels = img.shape
    ratio = width / height
    
    # save
    image_stats.append(np.array((observation['diagnosis'], height, width, channels, ratio)))

# construct DF
image_stats = pd.DataFrame(image_stats)
image_stats.columns = ['diagnosis', 'height', 'width', 'channels', 'ratio']



############ HISTOGRAMS

fig = plt.figure(figsize = (15, 5))

# width
plt.subplot(1, 3, 1)
plt.hist(image_stats['width'])
plt.title('(a) Image Width')
plt.ylabel('Number of examples')
plt.xlabel('Width')

# height
plt.subplot(1, 3, 2)
plt.hist(image_stats['height'])
plt.title('(b) Image Height')
plt.ylabel('Number of examples')
plt.xlabel('Height')

# ratio
plt.subplot(1, 3, 3)
plt.hist(image_stats['ratio'])
plt.title('(c) Aspect Ratio')
plt.ylabel('Number of examples')
plt.xlabel('Ratio')

# export
plt.savefig('fig_size_distribution.pdf')

![](images/fig_size_distribution.png)

Now, let us take a look at the actual images. The code below creates the `EyeData` dataset class to import the images using `opencv` library. We also create a `DataLoader` object to load sample images and visualize the first batch.

In [None]:
#collapse-hide

############ DATASET

# image preprocessing function
def prepare_image(path, image_size = 256):
    
    # import
    image = cv2.imread(path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # resize
    image = cv2.resize(image, (int(image_size), int(image_size)))

    # convert to tensor (H x W x C -> C x H x W, as expected by ToPILImage)
    image = torch.tensor(image)
    image = image.permute(2, 0, 1)
    return image
  
    
# dataset class:
class EyeData(Dataset):
    
    # initialize
    def __init__(self, data, directory, transform = None):
        self.data      = data
        self.directory = directory
        self.transform = transform
        
    # length
    def __len__(self):
        return len(self.data)
    
    # get items    
    def __getitem__(self, idx):
        img_name = os.path.join(self.directory, self.data.loc[idx, 'id_code'] + '.png')
        image    = prepare_image(img_name)  
        if self.transform is not None:
            image = self.transform(image)
        label    = torch.tensor(self.data.loc[idx, 'diagnosis'])
        return {'image': image, 'label': label}
    
    
    
############ EXAMINE SAMPLE BATCH

# transformations
sample_trans = transforms.Compose([transforms.ToPILImage(),
                                   transforms.ToTensor(),
                                  ])

# dataset
sample = EyeData(data       = train, 
                 directory  = '../input/aptos2019-blindness-detection/train_images',
                 transform  = sample_trans)

# data loader
sample_loader = torch.utils.data.DataLoader(dataset     = sample, 
                                            batch_size  = 10, 
                                            shuffle     = False, 
                                            num_workers = 4)

# display images
for batch_i, data in enumerate(sample_loader):

    # extract data
    inputs = data['image']
    labels = data['label'].view(-1, 1)
    
    # create plot
    fig = plt.figure(figsize = (15, 7))
    for i in range(len(labels)):
        ax = fig.add_subplot(2, len(labels) // 2, i + 1, xticks = [], yticks = [])
        plt.imshow(inputs[i].numpy().transpose(1, 2, 0))
        ax.set_title(labels.numpy()[i])

    break

![](images/fig_eye_example.png)

The illustration further emphasizes the difference between the retina photographs in terms of aspect ratio, lighting conditions and camera quality. In practice, the severity of DR is diagnosed by the presence of different visual cues on the retina photographs. This includes signs like abnormal blood vessels, hard exudates and so-called "cotton wool" spots. Comparing the sample images, we can see the presence of exudates and "cotton wool" spots on some of the retina photographs of sick patients.

## Image preprocessing

Visual inspection of the retina images shows that cameras with different aspect ratios leave some images with large black areas around the eye. The black areas do not contain information relevant for prediction and can be cropped. However, the size of the black areas varies from one image to another. To address this, we develop a cropping function that converts the image to grayscale and marks black areas based on the pixel intensity. Next, we build a mask of the image by selecting the rows and columns in which at least one pixel exceeds the intensity threshold. This helps to remove vertical or horizontal rectangles filled with black, similar to the ones observed in the upper-right image of the sample batch above. After removing the black stripes, we resize the images to the same height and width.
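The masking step is easier to follow on a toy array. The sketch below uses NumPy only and keeps the rows and columns that contain at least one pixel brighter than the tolerance, mirroring the `mask.any(...)` logic of the `crop_black()` function used in this project. The function name `crop_dark_border()` and the toy image are made up for illustration.

```python
import numpy as np


def crop_dark_border(gray, tol=7):
    """Keep only rows and columns with at least one pixel brighter than tol."""
    mask = gray > tol
    rows = mask.any(axis=1)   # True for rows that contain a bright pixel
    cols = mask.any(axis=0)   # True for columns that contain a bright pixel
    if not rows.any():        # image is entirely dark: return it unchanged
        return gray
    return gray[np.ix_(rows, cols)]


# toy 5x6 grayscale "image": a bright 3x2 patch surrounded by black
toy = np.zeros((5, 6), dtype=np.uint8)
toy[1:4, 2:4] = 200

cropped = crop_dark_border(toy)
print(toy.shape, '->', cropped.shape)
```

Because whole rows and columns are dropped, the crop removes black bands at the image borders but leaves dark corners around a round retina untouched.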

Another issue is the eye shape. Depending on the image parameters, some eyes appear to have a circular form, whereas others look like ovals. As the size and shape of items located in the retina determine the disease severity, it is crucial to standardize the eye shape as well. To do so, we develop another cropping function that makes a circular crop of a particular radius around the center of the image.
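The idea of the circular crop can be sketched without OpenCV. The project code relies on `cv2.circle` and `cv2.bitwise_and`; the NumPy-only version below is an illustrative alternative with hypothetical helper names (`circular_mask()`, `apply_circular_crop()`) that blacks out everything outside the largest circle centered in the image.

```python
import numpy as np


def circular_mask(height, width):
    """Boolean mask that is True inside the largest centered circle."""
    cy, cx = height // 2, width // 2
    radius = min(cy, cx)
    yy, xx = np.ogrid[:height, :width]
    return (yy - cy) ** 2 + (xx - cx) ** 2 <= radius ** 2


def apply_circular_crop(img):
    """Set all pixels outside the centered circle to zero (black)."""
    mask = circular_mask(img.shape[0], img.shape[1])
    out = img.copy()
    out[~mask] = 0
    return out


# toy 8x8 white RGB image
toy = np.full((8, 8, 3), 255, dtype=np.uint8)
cropped = apply_circular_crop(toy)
print(cropped[0, 0], cropped[4, 4])   # corner pixel vs. center pixel
```

The corner pixel is set to black while the center pixel keeps its value, so every image ends up with the same round region of interest.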

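Finally, we correct for the lighting and brightness discrepancies using a Gaussian filter: a blurred copy of each image is subtracted from the original, which suppresses slow variations in illumination and emphasizes local detail.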
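In `prepare_image()`, this correction is applied as `cv2.addWeighted(image, 4, blurred, -4, 128)`, that is, `4 * image - 4 * blurred + 128` with saturation to the 0 to 255 range. The sketch below reproduces only this arithmetic in NumPy. It assumes that the blurred copy is computed elsewhere (for example with `cv2.GaussianBlur`), and the function name `enhance_contrast()` is hypothetical.

```python
import numpy as np


def enhance_contrast(image, blurred, weight=4.0, offset=128.0):
    """Compute weight * (image - blurred) + offset, clipped to [0, 255]."""
    diff = image.astype(np.float32) - blurred.astype(np.float32)
    out = weight * diff + offset
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


# flat toy image: the blurred copy equals the image, so the result is mid-gray
flat = np.full((4, 4), 90, dtype=np.uint8)
print(enhance_contrast(flat, flat)[0, 0])      # 128

# a pixel 10 levels brighter than its blurred surroundings: 128 + 4 * 10
bright = flat.copy()
bright[2, 2] = 100
print(enhance_contrast(bright, flat)[2, 2])    # 168
```

Uniform regions collapse to the same gray value regardless of the original brightness, while local deviations such as vessels or spots are amplified by the weight.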

The snippet below provides the updated `prepare_image()` function that incorporates the discussed preprocessing steps.

In [5]:
#collapse-hide

############ PREPROCESSING FUNCTIONS

##### image preprocessing function
def prepare_image(path, sigmaX = 10, do_random_crop = False, image_size = 256):
    
    # import image
    image = cv2.imread(path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # perform smart crops
    image = crop_black(image, tol = 7)
    if do_random_crop:
        image = random_crop(image, size = (0.9, 1))
    
    # resize and color: subtract the Gaussian-blurred image to even out lighting
    image = cv2.resize(image, (int(image_size), int(image_size)))
    image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), sigmaX), -4, 128)
    
    # circular crop
    image = circle_crop(image, sigmaX = sigmaX)

    # convert to tensor (H x W x C -> C x H x W, as expected by ToPILImage)
    image = torch.tensor(image)
    image = image.permute(2, 0, 1)
    return image


##### automatic crop of black areas
def crop_black(img, tol = 7):
    
    if img.ndim == 2:
        mask = img > tol
        return img[np.ix_(mask.any(1),mask.any(0))]
    
    elif img.ndim == 3:
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        mask = gray_img > tol
        check_shape = img[:,:,0][np.ix_(mask.any(1),mask.any(0))].shape[0]
        
        if (check_shape == 0): 
            return img 
        else:
            img1 = img[:,:,0][np.ix_(mask.any(1),mask.any(0))]
            img2 = img[:,:,1][np.ix_(mask.any(1),mask.any(0))]
            img3 = img[:,:,2][np.ix_(mask.any(1),mask.any(0))]
            img  = np.stack([img1, img2, img3], axis = -1)
            return img
        
        
##### circular crop around image center
def circle_crop(img, sigmaX = 10):   
        
    height, width, depth = img.shape
    
    largest_side = np.max((height, width))
    img = cv2.resize(img, (largest_side, largest_side))

    height, width, depth = img.shape
    
    x = int(width / 2)
    y = int(height / 2)
    r = np.amin((x,y))
    
    circle_img = np.zeros((height, width), np.uint8)
    cv2.circle(circle_img, (x,y), int(r), 1, thickness = -1)
    
    img = cv2.bitwise_and(img, img, mask = circle_img)
    return img 


##### random crop
def random_crop(img, size = (0.9, 1)):

    height, width, depth = img.shape
    
    cut = 1 - random.uniform(size[0], size[1])
    
    i = random.randint(0, int(cut * height))
    j = random.randint(0, int(cut * width))
    h = i + int((1 - cut) * height)
    w = j + int((1 - cut) * width)

    img = img[i:h, j:w, :]    
    
    return img

Next, we define a new `EyeData` class that uses the new processing functions and visualize a batch of sample images after corrections.

In [None]:
#collapse-hide

############ DATASET
    
# dataset class:
class EyeData(Dataset):

    # initialize
    def __init__(self, data, directory, transform = None):
        self.data      = data
        self.directory = directory
        self.transform = transform
        
    # length
    def __len__(self):
        return len(self.data)
    
    # get items    
    def __getitem__(self, idx):
        img_name = os.path.join(self.directory, self.data.loc[idx, 'id_code'] + '.png')
        image    = prepare_image(img_name)  
        if self.transform is not None:
            image = self.transform(image)
        label    = torch.tensor(self.data.loc[idx, 'diagnosis'])
        return {'image': image, 'label': label}



############ EXAMINE SAMPLE BATCH

image_size = 256


# transformations
sample_trans = transforms.Compose([transforms.ToPILImage(),
                                   transforms.ToTensor(),
                                  ])

# dataset
sample = EyeData(data       = train, 
                 directory  = '../input/aptos2019-blindness-detection/train_images',
                 transform  = sample_trans)

# data loader
sample_loader = torch.utils.data.DataLoader(dataset     = sample, 
                                            batch_size  = 10, 
                                            shuffle     = False, 
                                            num_workers = 4)

# display images
for batch_i, data in enumerate(sample_loader):

    # extract data
    inputs = data['image']
    labels = data['label'].view(-1, 1)
    
    # create plot
    fig = plt.figure(figsize = (15, 7))
    for i in range(len(labels)):
        ax = fig.add_subplot(2, len(labels) // 2, i + 1, xticks = [], yticks = [])
        plt.imshow(inputs[i].numpy().transpose(1, 2, 0))
        ax.set_title(labels.numpy()[i])

    break

# export plot
plt.savefig('../figures/fig_data_example.pdf')

![](images/fig_eye_fixed.png)

This looks much better! Comparing the retina images to the ones before preprocessing, we can see that the apparent discrepancies between the photographs are now fixed. The eyes have a similar circular shape and the color scheme is more consistent. This should help the model to detect the signs of DR in the photographs.

Check out [this notebook](https://www.kaggle.com/ratthachat/aptos-eye-preprocessing-in-diabetic-retinopathy) by Nakhon Ratchasima for more ideas on image preprocessing for retina photographs. The preprocessing functions used in this project are largely inspired by his work during the competition.