<a href="https://colab.research.google.com/github/byunsy/retinal-oct-classification/blob/main/Retinal_OCT_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Retinal OCT (Optical Coherence Tomography) Classification

---



## 01. Import Necessary Packages
We first need to import several packages. We will be using TensorFlow and Keras.

In [1]:
import os
import numpy as np
import glob
import shutil
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [2]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [3]:
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

## 02. Obtain the Dataset from Kaggle

First, install the Kaggle API client with pip.

In [4]:
!pip install -q kaggle

Next, use `google.colab.files` to upload `kaggle.json`, your Kaggle API token, which you can create and download from your Kaggle account settings.

In [None]:
from google.colab import files
files.upload()

Create a `~/.kaggle` directory and copy the `kaggle.json` file into it. The Kaggle CLI reads your API credentials from this location.

In [6]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

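Restrict the file permissions so that only the owner can read the API key; otherwise the Kaggle CLI warns that the key is readable by other users.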

In [7]:
!chmod 600 ~/.kaggle/kaggle.json
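For reference, the three shell steps above (create `~/.kaggle`, copy `kaggle.json`, restrict permissions) can also be done from Python. This is a minimal sketch; the function name `install_kaggle_credentials` is hypothetical, and it assumes `kaggle.json` sits in the current working directory, as it does after `files.upload()` in Colab.

```python
import os
import shutil
import stat
from pathlib import Path


def install_kaggle_credentials(src="kaggle.json", home=None):
    """Hypothetical helper: copy kaggle.json into <home>/.kaggle and make it owner-only."""
    home = Path(home) if home is not None else Path.home()
    kaggle_dir = home / ".kaggle"
    kaggle_dir.mkdir(parents=True, exist_ok=True)   # like `mkdir -p ~/.kaggle`
    dest = kaggle_dir / "kaggle.json"
    shutil.copy(src, dest)                          # like `cp kaggle.json ~/.kaggle/`
    os.chmod(dest, stat.S_IRUSR | stat.S_IWUSR)     # like `chmod 600` (read/write for owner only)
    return dest


# Usage in Colab, after files.upload():
# install_kaggle_credentials()
```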

Download the dataset. The easiest way to get the exact command is to open the dataset page on Kaggle, click the three-dot menu and select "Copy API command".

In [8]:
!kaggle datasets download -d paultimothymooney/kermany2018

Downloading kermany2018.zip to /content
100% 10.8G/10.8G [04:19<00:00, 35.9MB/s]
100% 10.8G/10.8G [04:19<00:00, 44.8MB/s]


Now unzip the downloaded archive.

In [None]:
!unzip /content/kermany2018.zip

The archive also extracts an `oct2017/` directory that this notebook does not use; delete it to save disk space.

In [10]:
!rm -r /content/oct2017/

## 03. Understanding the Dataset

First, build the paths to the base directory and its three subdirectories.

In [11]:
# Create directory paths
base_dir = os.path.join(os.path.dirname('/content/kermany2018.zip'), 'OCT2017')

test_dir  = os.path.join(base_dir, 'test')
train_dir = os.path.join(base_dir, 'train')
val_dir   = os.path.join(base_dir, 'val')

Now, let's count the images in each directory.

**Note:** the unzipped directory is actually named `OCT2017 ` (with a trailing space). Rename it to `OCT2017` before proceeding to the next step; otherwise the paths above will not match.
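One way to remove the trailing space from Python is sketched below, assuming the archive was extracted to `/content` so that the directory is literally `'/content/OCT2017 '`. The helper name `strip_trailing_space` is hypothetical; it does nothing if the directory has already been renamed.

```python
import os


def strip_trailing_space(parent="/content", name="OCT2017"):
    """Hypothetical helper: rename '<parent>/<name> ' to '<parent>/<name>' if needed."""
    src = os.path.join(parent, name + " ")   # directory name with the trailing space
    dst = os.path.join(parent, name)
    if os.path.isdir(src) and not os.path.exists(dst):
        os.rename(src, dst)
    return dst


strip_trailing_space()
```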

In [12]:
# Number of images
num_ts_norm   = len(os.listdir(os.path.join(test_dir, 'NORMAL')))
num_ts_cnv    = len(os.listdir(os.path.join(test_dir, 'CNV')))
num_ts_dme    = len(os.listdir(os.path.join(test_dir, 'DME')))
num_ts_drusen = len(os.listdir(os.path.join(test_dir, 'DRUSEN')))
num_ts = num_ts_norm + num_ts_cnv + num_ts_dme + num_ts_drusen

num_tr_norm   = len(os.listdir(os.path.join(train_dir, 'NORMAL')))
num_tr_cnv    = len(os.listdir(os.path.join(train_dir, 'CNV')))
num_tr_dme    = len(os.listdir(os.path.join(train_dir, 'DME')))
num_tr_drusen = len(os.listdir(os.path.join(train_dir, 'DRUSEN')))
num_tr = num_tr_norm + num_tr_cnv + num_tr_dme + num_tr_drusen

num_vl_norm   = len(os.listdir(os.path.join(val_dir, 'NORMAL')))
num_vl_cnv    = len(os.listdir(os.path.join(val_dir, 'CNV')))
num_vl_dme    = len(os.listdir(os.path.join(val_dir, 'DME')))
num_vl_drusen = len(os.listdir(os.path.join(val_dir, 'DRUSEN')))
num_vl = num_vl_norm + num_vl_cnv + num_vl_dme + num_vl_drusen

# Display number of images in each directory
print("TOTAL NUMBER OF TEST IMAGES:", num_ts)
print("Number of NORMAL - TEST :", num_ts_norm)
print("Number of CNV    - TEST :", num_ts_cnv)
print("Number of DME    - TEST :", num_ts_dme)
print("Number of DRUSEN - TEST :", num_ts_drusen, "\n")

print("TOTAL NUMBER OF TRAIN IMAGES:", num_tr)
print("Number of NORMAL - TRAIN :", num_tr_norm)
print("Number of CNV    - TRAIN :", num_tr_cnv)
print("Number of DME    - TRAIN :", num_tr_dme)
print("Number of DRUSEN - TRAIN :", num_tr_drusen, "\n")

print("TOTAL NUMBER OF VALIDATION IMAGES:", num_vl)
print("Number of NORMAL - VALIDATION :", num_vl_norm)
print("Number of CNV    - VALIDATION :", num_vl_cnv)
print("Number of DME    - VALIDATION :", num_vl_dme)
print("Number of DRUSEN - VALIDATION :", num_vl_drusen, "\n")

print("-"*50)
print("NORMAL :", num_ts_norm + num_tr_norm + num_vl_norm)
print("CNV    :", num_ts_cnv + num_tr_cnv + num_vl_cnv)
print("DME    :", num_ts_dme + num_tr_dme + num_vl_dme)
print("DRUSEN :", num_ts_drusen + num_tr_drusen + num_vl_drusen)
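The twelve `len(os.listdir(...))` calls above follow a single pattern, so a nested loop over splits and classes yields the same counts with less repetition. This is a sketch of a possible refactor, not the notebook's original code; `SPLITS`, `CLASSES`, `count_images` and `print_counts` are names introduced here, and it assumes the `base_dir/<split>/<class>/` layout used above.

```python
import os

SPLITS = ["test", "train", "val"]
CLASSES = ["NORMAL", "CNV", "DME", "DRUSEN"]


def count_images(base_dir, splits=SPLITS, classes=CLASSES):
    """Return {split: {class: number_of_files}} for base_dir/<split>/<class>/."""
    counts = {}
    for split in splits:
        counts[split] = {}
        for cls in classes:
            class_dir = os.path.join(base_dir, split, cls)
            counts[split][cls] = len(os.listdir(class_dir))
    return counts


def print_counts(counts):
    """Print per-split totals, per-class counts, and per-class totals over all splits."""
    for split, per_class in counts.items():
        print(f"TOTAL NUMBER OF {split.upper()} IMAGES:", sum(per_class.values()))
        for cls, n in per_class.items():
            print(f"Number of {cls:<6} - {split.upper()} :", n)
        print()
    print("-" * 50)
    for cls in next(iter(counts.values())):
        print(f"{cls:<6} :", sum(per_class[cls] for per_class in counts.values()))


# Usage (in the notebook): print_counts(count_images(base_dir))
```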


TOTAL NUMBER OF TEST IMAGES: 968
Number of NORMAL - TEST : 242
Number of CNV    - TEST : 242
Number of DME    - TEST : 242
Number of DRUSEN - TEST : 242 

TOTAL NUMBER OF TRAIN IMAGES: 83484
Number of NORMAL - TRAIN : 26315
Number of CNV    - TRAIN : 37205
Number of DME    - TRAIN : 11348
Number of DRUSEN - TRAIN : 8616 

TOTAL NUMBER OF VALIDATION IMAGES: 32
Number of NORMAL - VALIDATION : 8
Number of CNV    - VALIDATION : 8
Number of DME    - VALIDATION : 8
Number of DRUSEN - VALIDATION : 8 

--------------------------------------------------
NORMAL : 26565
CNV    : 37455
DME    : 11598
DRUSEN : 8866


Notice that we have only **32** validation images and **968** test images, compared with **83,484** training images. To get a less extreme division of the dataset, we will first merge the three sets and then randomly split them into training, validation and test sets in an 80:10:10 ratio. 

Note also that the categories (NORMAL, CNV, DME, DRUSEN) are imbalanced: CNV, for example, has more than four times as many images as DRUSEN. To handle this, we will compute class weights for use during model training.
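One common way to build such weights is the "balanced" heuristic, which gives class `c` the weight `N / (K * n_c)` for `N` training images, `K` classes and `n_c` images in class `c` (scikit-learn's `compute_class_weight` with `class_weight='balanced'` uses the same formula). The sketch below applies it to the overall per-class totals printed above for illustration; in practice you would use the counts of the final training split. The class order `['CNV', 'DME', 'DRUSEN', 'NORMAL']` is an assumption based on Keras assigning label indices in alphanumeric order of the subdirectory names.

```python
def balanced_class_weights(counts, class_order):
    """Map class index -> N / (K * n_c), the 'balanced' weighting heuristic."""
    total = sum(counts[c] for c in class_order)
    num_classes = len(class_order)
    return {i: total / (num_classes * counts[c]) for i, c in enumerate(class_order)}


# Per-class totals taken from the counts printed above
totals = {"NORMAL": 26565, "CNV": 37455, "DME": 11598, "DRUSEN": 8866}
class_order = ["CNV", "DME", "DRUSEN", "NORMAL"]   # assumed Keras label order

weights = balanced_class_weights(totals, class_order)
for i, c in enumerate(class_order):
    print(f"Weight for {c:<6} (class {i}): {weights[i]:.2f}")
```

Rarer classes receive larger weights, and by construction the weighted counts `weights[i] * n_c` sum back to `N`.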

## 04. Data Preprocessing

In [None]:
# Move the images to corresponding directories (from val to train)
!mv /content/OCT2017/val/NORMAL/* /content/OCT2017/train/NORMAL/
!mv /content/OCT2017/val/CNV/* /content/OCT2017/train/CNV/
!mv /content/OCT2017/val/DME/* /content/OCT2017/train/DME/
!mv /content/OCT2017/val/DRUSEN/* /content/OCT2017/train/DRUSEN/

# Move the images to corresponding directories (from test to train)
!mv /content/OCT2017/test/NORMAL/* /content/OCT2017/train/NORMAL/
!mv /content/OCT2017/test/CNV/* /content/OCT2017/train/CNV/
!mv /content/OCT2017/test/DME/* /content/OCT2017/train/DME/
!mv /content/OCT2017/test/DRUSEN/* /content/OCT2017/train/DRUSEN/

Check if the files have moved.

In [None]:
# Number of images
num_ts_norm   = len(os.listdir(os.path.join(test_dir, 'NORMAL')))
num_ts_cnv    = len(os.listdir(os.path.join(test_dir, 'CNV')))
num_ts_dme    = len(os.listdir(os.path.join(test_dir, 'DME')))
num_ts_drusen = len(os.listdir(os.path.join(test_dir, 'DRUSEN')))
num_ts = num_ts_norm + num_ts_cnv + num_ts_dme + num_ts_drusen

num_tr_norm   = len(os.listdir(os.path.join(train_dir, 'NORMAL')))
num_tr_cnv    = len(os.listdir(os.path.join(train_dir, 'CNV')))
num_tr_dme    = len(os.listdir(os.path.join(train_dir, 'DME')))
num_tr_drusen = len(os.listdir(os.path.join(train_dir, 'DRUSEN')))
num_tr = num_tr_norm + num_tr_cnv + num_tr_dme + num_tr_drusen

num_vl_norm   = len(os.listdir(os.path.join(val_dir, 'NORMAL')))
num_vl_cnv    = len(os.listdir(os.path.join(val_dir, 'CNV')))
num_vl_dme    = len(os.listdir(os.path.join(val_dir, 'DME')))
num_vl_drusen = len(os.listdir(os.path.join(val_dir, 'DRUSEN')))
num_vl = num_vl_norm + num_vl_cnv + num_vl_dme + num_vl_drusen

# Display number of images in each directory
print("TOTAL NUMBER OF TEST IMAGES:", num_ts)
print("Number of NORMAL - TEST :", num_ts_norm)
print("Number of CNV    - TEST :", num_ts_cnv)
print("Number of DME    - TEST :", num_ts_dme)
print("Number of DRUSEN - TEST :", num_ts_drusen, "\n")

print("TOTAL NUMBER OF TRAIN IMAGES:", num_tr)
print("Number of NORMAL - TRAIN :", num_tr_norm)
print("Number of CNV    - TRAIN :", num_tr_cnv)
print("Number of DME    - TRAIN :", num_tr_dme)
print("Number of DRUSEN - TRAIN :", num_tr_drusen, "\n")

print("TOTAL NUMBER OF VALIDATION IMAGES:", num_vl)
print("Number of NORMAL - VALIDATION :", num_vl_norm)
print("Number of CNV    - VALIDATION :", num_vl_cnv)
print("Number of DME    - VALIDATION :", num_vl_dme)
print("Number of DRUSEN - VALIDATION :", num_vl_drusen, "\n")

We can see that all the validation and test images have been moved into the training directory. We will now split the merged dataset using `train_test_split()` from scikit-learn: first 80:20, then the held-out 20% in half, which gives an 80:10:10 split.

In [None]:
# Get a list of all the image paths in the train directory (one subfolder per class)
all_files = glob.glob(os.path.join(train_dir, '*', '*'))

# Randomly shuffle and split into 80% training and 20% held-out images
train_files, holdout_files = train_test_split(all_files, test_size=0.2, random_state=42)

# Split the held-out images in half: 10% validation, 10% test
val_files, test_files = train_test_split(holdout_files, test_size=0.5, random_state=42)

# Number of training, validation and test images
NUM_TRAIN = len(train_files)
NUM_VALIDATION = len(val_files)
NUM_TEST = len(test_files)

print("TRAIN     :", NUM_TRAIN)
print("VALIDATION:", NUM_VALIDATION)
print("TEST      :", NUM_TEST)

We now have a proper split into training, validation and test sets. 

Let's create a function to identify the label of an image and count the number of **NORMAL**, **CNV**, **DME** and **DRUSEN** images in the training set.

In [None]:
# Class names in alphanumeric order, which is the order Keras uses
# to assign label indices in flow_from_directory()
CLASSES = ['CNV', 'DME', 'DRUSEN', 'NORMAL']

# Parses the file path and returns its class name (the parent folder)
def get_label(filename):
  return filename.split(os.path.sep)[-2]

# Count the training images in each class
class_counts = {c: 0 for c in CLASSES}
for file in train_files:
  class_counts[get_label(file)] += 1

for c in CLASSES:
  print(f"{c:<6} in TRAINING:", class_counts[c])

Organize the newly split training, validation and test sets into their corresponding directories. 

In [None]:
# For each file, determine its label and move it to <split_dir>/<label>/.
# Files that are already in the correct directory are left where they are.
def move_files(file_list, split_dir):
  for file in file_list:
    dest_dir = os.path.join(split_dir, get_label(file))
    if os.path.dirname(file) != dest_dir:
      shutil.move(file, dest_dir)

move_files(train_files, train_dir)
move_files(val_files, val_dir)
move_files(test_files, test_dir)

Check that we have done so successfully and accurately.

In [None]:
# Number of images in each split and class
for split_name, split_dir in [('TRAIN', train_dir), ('VALIDATION', val_dir), ('TEST', test_dir)]:
  counts = {c: len(os.listdir(os.path.join(split_dir, c))) for c in CLASSES}
  print(f"TOTAL NUMBER OF {split_name} IMAGES:", sum(counts.values()))
  for c in CLASSES:
    print(f"Number of {c:<6} - {split_name} :", counts[c])
  print()

Let's now deal with the class imbalance in the training set using class weights. This dictionary of class weights will be used later on. 

According to the TensorFlow documentation, class weights "can be useful to tell the model to 'pay more attention' to samples from an under-represented class." (https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit)

In [None]:
# Compute class weights N / (K * n_c) for K classes and n_c images in class c
NUM_CLASSES = len(CLASSES)
class_weight = {i: NUM_TRAIN / (NUM_CLASSES * class_counts[c])
                for i, c in enumerate(CLASSES)}

for i, c in enumerate(CLASSES):
  print('Weight for {:<6} (class {}): {:.2f}'.format(c, i, class_weight[i]))

## 05. Data Augmentation


We will be using a batch size of 130 and image size of 200x200. 

In [None]:
BATCH_SIZE = 130
IMG_SHAPE = 200

Create a function that will display images.

In [None]:
# This function will plot images in the form of a grid with 1 row and 5 columns 
# where images are placed in each column.
def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()

### Applying Horizontal Flip

We use `ImageDataGenerator` to rescale the pixel values by 1/255 (to the range [0, 1]) and then apply a random horizontal flip.

In [None]:
image_gen = ImageDataGenerator(rescale=1./255, horizontal_flip=True)

train_data_gen = image_gen.flow_from_directory(batch_size=BATCH_SIZE,
                                               directory=train_dir,
                                               shuffle=True, 
                                               target_size=(IMG_SHAPE,IMG_SHAPE))

augmented_images = [train_data_gen[0][0][0] for i in range(5)]
plotImages(augmented_images)
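To make the two settings concrete: `rescale=1./255` multiplies every pixel value by 1/255, mapping 8-bit values 0 to 255 onto 0.0 to 1.0, and a horizontal flip reverses the width (column) axis of a `(height, width, channels)` image. The NumPy sketch below reproduces both on a toy array. The default flip probability of 0.5 mirrors how Keras applies `horizontal_flip=True`, but the function itself is illustrative, not Keras's implementation.

```python
import numpy as np


def rescale_and_maybe_flip(image, rng, flip_prob=0.5):
    """Scale uint8 pixels to [0, 1] and flip left-right with probability flip_prob."""
    x = image.astype(np.float32) * (1.0 / 255)   # same scaling as rescale=1./255
    if rng.random() < flip_prob:
        x = x[:, ::-1, :]                          # reverse the width axis
    return x


rng = np.random.default_rng(0)
toy = (np.arange(6).reshape(2, 3, 1) * 51).astype(np.uint8)   # values 0, 51, ..., 255
flipped = rescale_and_maybe_flip(toy, rng, flip_prob=1.0)      # force a flip
print(flipped[..., 0])
```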

### Applying Rotation

We use `ImageDataGenerator` to rescale the pixel values by 1/255 (to the range [0, 1]) and then apply a random rotation of up to 45 degrees.

In [None]:
image_gen = ImageDataGenerator(rescale=1./255, rotation_range=45)

train_data_gen = image_gen.flow_from_directory(batch_size=BATCH_SIZE,
                                               directory=train_dir,
                                               shuffle=True, 
                                               target_size=(IMG_SHAPE,IMG_SHAPE))

augmented_images = [train_data_gen[0][0][0] for i in range(5)]
plotImages(augmented_images)

### Applying Zoom

We use `ImageDataGenerator` to rescale the pixel values by 1/255 (to the range [0, 1]) and then apply a random zoom.

In [None]:
image_gen = ImageDataGenerator(rescale=1./255, zoom_range=0.25)

train_data_gen = image_gen.flow_from_directory(batch_size=BATCH_SIZE,
                                               directory=train_dir,
                                               shuffle=True, 
                                               target_size=(IMG_SHAPE,IMG_SHAPE))

augmented_images = [train_data_gen[0][0][0] for i in range(5)]
plotImages(augmented_images)
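In Keras's implementation, `rotation_range=45` draws a rotation angle uniformly between -45 and +45 degrees, and a float `zoom_range=0.25` is interpreted as the interval `[1 - 0.25, 1 + 0.25] = [0.75, 1.25]`, from which the horizontal and vertical zoom factors are drawn independently. The sketch below samples parameters the same way for illustration only; it draws the numbers but does not warp any image, and `sample_transform` is a hypothetical name.

```python
import numpy as np


def sample_transform(rng, rotation_range=45.0, zoom_range=0.25):
    """Hypothetical helper: draw one rotation angle and one pair of zoom factors."""
    theta = rng.uniform(-rotation_range, rotation_range)          # degrees
    zx, zy = rng.uniform(1 - zoom_range, 1 + zoom_range, size=2)  # independent per axis
    return theta, zx, zy


rng = np.random.default_rng(42)
for _ in range(3):
    theta, zx, zy = sample_transform(rng)
    print(f"rotate {theta:+6.1f} deg, zoom x={zx:.2f}, y={zy:.2f}")
```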

### Combining the Transformations

Prepare the training data, combining rescaling with random zoom and horizontal flip.

In [None]:
image_gen_train = ImageDataGenerator(
    rescale=1./255, 
    zoom_range=0.25,
    horizontal_flip=True, 
)

train_data_gen = image_gen_train.flow_from_directory(batch_size=BATCH_SIZE,
                                                     directory=train_dir,
                                                     shuffle=True, 
                                                     target_size=(IMG_SHAPE,IMG_SHAPE),
                                                     class_mode='sparse')

augmented_images = [train_data_gen[0][0][0] for i in range(5)]
plotImages(augmented_images)

Prepare the validation data (rescaling only, no augmentation).

In [None]:
image_gen_val = ImageDataGenerator(rescale=1./255)

val_data_gen = image_gen_val.flow_from_directory(batch_size=BATCH_SIZE,
                                                 directory=val_dir,
                                                 target_size=(IMG_SHAPE,IMG_SHAPE),
                                                 class_mode='sparse')

Prepare the test data (rescaling only, no augmentation).

In [None]:
image_gen_test = ImageDataGenerator(rescale=1./255)

test_data_gen = image_gen_test.flow_from_directory(batch_size=BATCH_SIZE,
                                                   directory=test_dir,
                                                   target_size=(IMG_SHAPE,IMG_SHAPE),
                                                   class_mode='sparse')
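Because all three generators infer labels from folder names, it helps to know which integer each class receives. Keras documents that, when `classes` is not given, label indices follow the alphanumeric order of the subdirectory names, and the generator exposes the resulting mapping as `train_data_gen.class_indices`. The sketch below reproduces that ordering with the standard library only; `infer_class_indices` is a hypothetical name.

```python
import os


def infer_class_indices(directory):
    """Map each subdirectory name to an integer label, in sorted (alphanumeric) order."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: i for i, name in enumerate(classes)}


# Usage (in the notebook): print(infer_class_indices(train_dir))
# For the OCT folders this gives CNV -> 0, DME -> 1, DRUSEN -> 2, NORMAL -> 3.
```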

## 06. Creating a CNN Model

The general structure is as follows:

* Three pairs of convolution and max-pooling layers
   - 16, 32, and 64 filters in that order
   - Same padding
   - ReLU activation
* Flatten
* Two Dense layers
    - 512 and 4 nodes in that order (one output per class: CNV, DME, DRUSEN, NORMAL)
    - ReLU activation for the first dense layer; the output layer returns raw logits
    - Dropout with rate 0.2 before each dense layer


In [None]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3,3), padding='same', activation='relu', input_shape=(IMG_SHAPE,IMG_SHAPE,3)),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),

    tf.keras.layers.Conv2D(32, (3,3), padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),

    tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),

    tf.keras.layers.Flatten(),

    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(512, activation='relu'),

    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(4)  # one logit per class: CNV, DME, DRUSEN, NORMAL
])
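A quick check of the layer sizes shows where most of the model's parameters live. With `'same'` padding each 3x3 convolution keeps the spatial size, and each 2x2 max-pool halves it (200 → 100 → 50 → 25), so `Flatten` produces 25 · 25 · 64 = 40,000 features feeding the 512-unit dense layer. The arithmetic below assumes a 200x200x3 input and `NUM_CLASSES = 4` outputs; `model.summary()` should report matching figures when the output layer has `NUM_CLASSES` units.

```python
def conv_params(kernel, in_ch, out_ch):
    return kernel * kernel * in_ch * out_ch + out_ch      # weights + biases


def dense_params(in_units, out_units):
    return in_units * out_units + out_units               # weights + biases


IMG_SHAPE, NUM_CLASSES = 200, 4
filters = [16, 32, 64]

size, in_ch, total = IMG_SHAPE, 3, 0
for f in filters:
    total += conv_params(3, in_ch, f)   # 'same' padding: spatial size unchanged
    size //= 2                          # 2x2 max-pooling halves height and width
    in_ch = f

flat = size * size * in_ch              # features after Flatten
first_dense = dense_params(flat, 512)
total += first_dense + dense_params(512, NUM_CLASSES)

print("Flatten size        :", flat)
print("First Dense params  :", first_dense)
print("Total params        :", total)
print(f"Share in first Dense: {first_dense / total:.1%}")
```

Nearly all of the roughly 20.5 million parameters sit in the first dense layer, which is one reason the dropout before it matters.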

## 07. Compiling the CNN Model

In [None]:
# Compile the model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
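Because the last `Dense` layer has no activation, the model outputs raw scores (logits), which is why the loss is built with `from_logits=True`: the softmax is applied inside the loss. For an integer label `y` and logits `z`, sparse categorical cross-entropy is `-log(softmax(z)[y])`, averaged over the batch. A NumPy sketch of that computation (not TensorFlow's implementation), using the log-sum-exp trick for numerical stability:

```python
import numpy as np


def sparse_ce_from_logits(logits, labels):
    """Mean of -log(softmax(logits)[label]) over a batch."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    shifted = logits - logits.max(axis=1, keepdims=True)          # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


# Two samples, four classes (CNV, DME, DRUSEN, NORMAL)
logits = [[2.0, 0.5, 0.1, -1.0],
          [0.0, 0.0, 0.0, 0.0]]
labels = [0, 3]
print(sparse_ce_from_logits(logits, labels))
```

The second sample has equal logits, so its loss is `log(4)`: an untrained model that guesses uniformly over four classes starts near that value.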

In [None]:
callback_cp = tf.keras.callbacks.ModelCheckpoint("retinal_oct_model.hdf5",
                                                 save_best_only=True)

callback_es = tf.keras.callbacks.EarlyStopping(patience=10,
                                               restore_best_weights=True)
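`EarlyStopping(patience=10)` monitors `val_loss` by default and stops training once the loss has gone 10 consecutive epochs without improving on the best value seen so far; with `restore_best_weights=True`, the weights from that best epoch are restored. The pure-Python sketch below simulates this rule on a list of validation losses. It is a simplification: it ignores `min_delta` (which defaults to 0) and options such as `baseline`, and `epochs_until_stop` is a hypothetical helper.

```python
def epochs_until_stop(val_losses, patience=10):
    """Return (epochs_run, best_epoch) under a simplified early-stopping rule."""
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch + 1, best_epoch     # training stops after this epoch
    return len(val_losses), best_epoch


losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.64]
print(epochs_until_stop(losses, patience=3))
```

With `patience=3`, the three epochs after the best loss at epoch 2 fail to improve, so training would end after 6 epochs and the epoch-2 weights would be kept.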

## 08. Training the CNN Model

In [None]:
epochs = 40

history = model.fit(
    train_data_gen,
    steps_per_epoch=int(np.ceil(train_data_gen.n / float(BATCH_SIZE))),
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=int(np.ceil(val_data_gen.n / float(BATCH_SIZE))),
    # class_weight=class_weight,
    # callbacks=[callback_cp, callback_es]
)

## 09. Visualizing Model Performance

In [None]:
# Accuracy
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

# Loss
loss = history.history['loss']
val_loss = history.history['val_loss']

# x-axis: use the recorded history length, which is shorter than `epochs` if training stops early
epochs_range = range(len(acc))

# First figure: Model Accuracy
plt.figure(figsize=(20, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')

# Second figure: Model Loss
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')

plt.show()

## 10. Model Prediction

In [None]:
class_names = np.array(['CNV', 'DME', 'DRUSEN', 'NORMAL'])  # alphanumeric order used by flow_from_directory

image_batch, label_batch = next(iter(test_data_gen))

predicted_batch = model.predict(image_batch)
predicted_batch = tf.squeeze(predicted_batch).numpy()

predicted_ids = np.argmax(predicted_batch, axis=-1)
predicted_class_names = class_names[predicted_ids]

print(predicted_class_names)

In [None]:
print("Labels:\n", label_batch.astype('int64'))
print("Predicted labels:\n", predicted_ids)
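Comparing the two printed arrays by eye gets tedious for a batch of 130 images. The short NumPy sketch below turns them into a batch accuracy and a confusion matrix (rows are true classes, columns are predictions). It assumes `label_batch` and `predicted_ids` hold class indices as above; `confusion_matrix_np` is a hypothetical name, and scikit-learn's `sklearn.metrics.confusion_matrix` computes the same table.

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, num_classes):
    """counts[i, j] = number of samples with true class i predicted as class j."""
    counts = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(counts, (np.asarray(y_true, dtype=np.int64),
                       np.asarray(y_pred, dtype=np.int64)), 1)
    return counts


# Usage in the notebook:
#   cm = confusion_matrix_np(label_batch, predicted_ids, num_classes=len(class_names))
#   print("Batch accuracy:", np.trace(cm) / cm.sum())
y_true = [0, 1, 2, 3, 3, 2]
y_pred = [0, 1, 1, 3, 0, 2]
cm = confusion_matrix_np(y_true, y_pred, num_classes=4)
print(cm)
print("Batch accuracy:", np.trace(cm) / cm.sum())
```

Off-diagonal cells show which classes are confused with each other, which a single accuracy number hides.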

In [None]:
plt.figure(figsize=(15,15))
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.subplots_adjust(hspace=0.5)
  plt.imshow(image_batch[n])

  if predicted_ids[n] == label_batch[n]:
    color = "blue" 
  else:
    color = "red"

  plt.title(predicted_class_names[n], color=color)
  plt.axis('off')

_ = plt.suptitle("Model Predictions\n (Blue: correct, Red: incorrect)", 
                 fontsize='xx-large', fontweight='bold')

In [None]:
test_loss, test_acc = model.evaluate(test_data_gen)
print("Test accuracy: {:.2%}".format(test_acc))