# <center><b>CNN for Alzheimer's Detection From MRI Images</b></center>

# INTRODUCTION
Alzheimer’s disease (AD) is a neurological disorder that results in diminished cognitive function. AD onset most often occurs when people are in their mid-60s, and AD is the most frequent cause of dementia in seniors. It is currently estimated that over 6 million Americans over 65 have AD [4]. There is currently no cure for AD, but there are some treatment strategies. Early detection and intervention have been shown to slow disease progression and improve the quality of life for individuals with AD [3]. Definitively diagnosing AD while someone is alive remains a challenge for the medical community, and several metrics need to be assessed to determine whether an individual has AD [2]. These methods may include brain scans such as magnetic resonance imaging (MRI); cognitive assessments that test memory, attention, and problem solving; an overall health assessment; and an examination of environmental and biological factors [2]. Models that help detect early-stage AD could be of great help to those suffering from the disease. The focus of this project is to build a convolutional neural network (CNN) to detect AD in MRI scans.

# LIBRARIES

In [None]:
import os
import glob
import numpy as np 
import pandas as pd 
import math
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, BatchNormalization, Activation, Conv2D, MaxPooling2D, Flatten, Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.activations import leaky_relu
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter

# DATA EXPLORATION AND PROCESSING
The dataset used for this project consists of images from MRI brain scans. MRI is non-invasive and can produce detailed images of soft tissue such as brain tissue [1]. Changes in brain structure, such as cerebral atrophy (shrinking of the brain) and abnormal protein build-up, are characteristics of AD [2]. The data comprises four classes: Non Demented, Very Mild Dementia, Mild Dementia, and Moderate Dementia. Because the data is already preprocessed, little cleaning is needed; however, the images will still be normalized using a data generator. The classes are heavily imbalanced: Moderate Dementia is by far the smallest class, with only 488 images according to the class counts used later in this notebook.

**Alzheimer MRI Preprocessed Dataset Available at Kaggle**
https://www.kaggle.com/datasets/borhanitrash/alzheimer-mri-disease-classification-dataset/data


## Directories

In [None]:
dataset_dir = "/kaggle/input/imagesoasis/Data" 

# Check if the directory exists 
if os.path.exists("/kaggle/input/imagesoasis"):
    print("Dataset directory exists")
    print("Contents:", os.listdir("/kaggle/input/imagesoasis"))
else:
    print("Dataset directory not found")
    print("Available inputs:", os.listdir("/kaggle/input"))

### Check Categories

In [None]:
# Sort so the indices match the alphanumeric class order used by flow_from_directory
categories = sorted(os.listdir(dataset_dir))
print("Categories:", categories)

## Load the Data

**Check Number of Images per Category**

In [None]:
image_paths = []
labels = []

for class_idx, class_name in enumerate(categories):
    class_path = os.path.join(dataset_dir, class_name)

    # Collect image files with any of the supported extensions
    files = []
    for pattern in ("*.jpg", "*.jpeg", "*.png"):
        files.extend(glob.glob(os.path.join(class_path, pattern)))
    
    print(f"Category: {class_name}, Files found: {len(files)}")
    
    for file_path in files:
        image_paths.append(file_path)
        labels.append(class_idx)

print(f"Total images found: {len(image_paths)}")


## Data Generator
Normalize the pixel values of the images by dividing them by 255 (`rescale=1./255`), so that pixel values lie in the range [0, 1] instead of [0, 255]. The same generator also splits the data into training and validation sets, holding out 25% of the images in each class for validation.

In [None]:
# Create one ImageDataGenerator that rescales pixels to [0, 1] and holds out 25% of images for validation
datagen = ImageDataGenerator(validation_split=0.25, rescale=1./255)

# Training data generator
train_generator = datagen.flow_from_directory(
    dataset_dir,
    target_size=(256, 256), # Resize images to 256x256 pixels
    batch_size=32,
    class_mode='categorical', # For multi-class classification
    subset='training'
)

# Validation data generator
validation_generator = datagen.flow_from_directory(
    dataset_dir,
    target_size=(256, 256), # Resize images to 256x256 pixels
    batch_size=32,
    class_mode='categorical', # For multi-class classification
    subset='validation',
    shuffle=False # Keep file order fixed so predictions line up with validation_generator.classes
)
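To make the two generator settings concrete, here is a NumPy-only sketch. The rescaling step divides pixel values by 255, which matches `rescale=1./255` up to floating-point rounding. The `split_class_files` helper is hypothetical: it assumes, based on our reading of how `flow_from_directory` applies `validation_split`, that each class's sorted file list is split so that the first 25% goes to validation and the rest to training. Treat that ordering as an assumption rather than documented behavior.

```python
import numpy as np

# Hypothetical 2x2 "image" with 8-bit pixel values, standing in for an MRI slice.
pixels = np.array([[0, 64], [128, 255]], dtype=np.uint8)

# Rescale to [0, 1] by dividing by 255 (what rescale=1./255 achieves).
scaled = pixels.astype(np.float32) / 255.0
print(scaled)


def split_class_files(filenames, validation_split=0.25):
    """Hypothetical helper: split one class's files into (train, validation).

    Assumption: files are sorted and the first int(n * validation_split)
    go to validation, the remainder to training.
    """
    files = sorted(filenames)
    n_val = int(len(files) * validation_split)
    return files[n_val:], files[:n_val]


train_files, val_files = split_class_files([f"slice_{i:03d}.jpg" for i in range(8)])
print(len(train_files), len(val_files))
```

The split is applied per class folder, so each class keeps roughly the same proportion in both subsets.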

## Display Example Images

In [None]:
classes = ["Mild Dementia", "Moderate Dementia", "Non Demented", "Very mild Dementia"]

num_images_display = 2

# Initialize plot
fig, axes = plt.subplots(num_images_display, len(classes), figsize = (len(classes) * 3, num_images_display * 3))

# Loop to get images from each class
for i, class_name in enumerate(classes):
    class_dir = os.path.join(dataset_dir, class_name)
    images = os.listdir(class_dir)[:num_images_display]
    
    for j, img_name in enumerate(images):
        img_path = os.path.join(class_dir, img_name)
        img = image.load_img(img_path, target_size =(256, 256))
        img_array = image.img_to_array(img)/ 255.0
        
        ax = axes[j,i]
        ax.imshow(img_array)
        ax.axis('off')
        ax.set_title(f"{class_name}")
        
plt.tight_layout()
# Save image (optional)
#plt.savefig('example_mri.png')
plt.show()

# MODELS
CNNs will be used to detect AD in brain tissue by analyzing MRI images. CNNs are an appropriate choice for this task for several reasons. First, they use convolutional layers with local receptive fields to recognize patterns such as edges, textures, and shapes. Given the changes in brain structure associated with AD [2], detecting such patterns could help in diagnosing the disease. Second, because CNNs process images through multiple stacked layers, they learn to extract increasingly complex features: early layers detect simple structures, while deeper layers can capture more complex patterns, such as those associated with cerebral atrophy and protein build-up.
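The idea that deeper layers see larger structures can be made concrete by computing the receptive field, i.e. the patch of input pixels that can influence one unit. The sketch below applies the standard recurrence (the field grows by (k - 1) times the current jump per layer, and the jump multiplies by the stride) to the layer pattern used later in this notebook: four blocks of a 3x3 convolution with stride 1 followed by 2x2 max pooling with stride 2. It ignores dropout and batch normalization, which do not change the receptive field.

```python
def receptive_fields(layer_specs):
    """Return (name, receptive field in input pixels) after each layer.

    layer_specs: list of (name, kernel_size, stride) tuples.
    Uses r_out = r_in + (k - 1) * j_in and j_out = j_in * s,
    where j is the distance in input pixels between adjacent units.
    """
    r, j = 1, 1
    out = []
    for name, k, s in layer_specs:
        r = r + (k - 1) * j
        j = j * s
        out.append((name, r))
    return out


# Four blocks of Conv2D 3x3 (stride 1, 'same' padding) + MaxPooling2D 2x2 (stride 2).
block_layers = []
for depth in range(1, 5):
    block_layers.append((f"conv{depth}", 3, 1))
    block_layers.append((f"pool{depth}", 2, 2))

for name, r in receptive_fields(block_layers):
    print(f"{name}: {r}x{r} pixels")
```

A unit after the first convolution sees only a 3x3 patch, while a unit after the fourth block sees a 46x46 patch of the 256x256 input, which is one reason deeper layers can respond to larger anatomical structures.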

### Data Augmentation
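Data augmentation helps prevent overfitting by generating more diverse training samples. This is done by randomly flipping, rotating, zooming, and translating images and by adjusting their contrast and brightness. Defining the augmentation as a separate `Sequential` block makes it easy to add to or remove from models as needed. Keras preprocessing layers such as these are only active during training; at inference time they pass inputs through unchanged.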

In [None]:
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.15, fill_mode='nearest'),
    tf.keras.layers.RandomZoom(0.15, fill_mode='nearest'),
    tf.keras.layers.RandomTranslation(0.1, 0.1, fill_mode='nearest'), # slight translation
    tf.keras.layers.RandomContrast(0.2),
    tf.keras.layers.RandomBrightness(0.1, value_range=(0.0, 1.0)), # inputs are already rescaled to [0, 1]
    #tf.keras.layers.GaussianNoise(0.01), # adds a small amount of random noise
])
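For intuition about what two of these layers do to an image that has already been rescaled to [0, 1], here is a NumPy-only sketch of a random horizontal flip followed by a random brightness shift. It is an illustration, not the Keras implementation; the helper name `random_flip_and_brightness` and its parameters are hypothetical. It also shows why a brightness shift has to be sized relative to the actual value range of the inputs: a shift sized for [0, 255] would swamp pixels in [0, 1].

```python
import numpy as np


def random_flip_and_brightness(img, max_delta=0.1, value_range=(0.0, 1.0), rng=None):
    """Hypothetical illustration of a horizontal flip + brightness shift.

    img: array of shape (height, width) or (height, width, channels),
         with values inside value_range.
    max_delta: largest shift, as a fraction of the width of value_range.
    """
    rng = np.random.default_rng() if rng is None else rng
    out = np.array(img, dtype=np.float32)  # copy so the input is not modified
    if rng.random() < 0.5:
        out = out[:, ::-1]  # horizontal flip: reverse the width axis
    low, high = value_range
    delta = rng.uniform(-max_delta, max_delta) * (high - low)
    return np.clip(out + delta, low, high)  # keep pixels inside the valid range


rng = np.random.default_rng(0)
fake_img = rng.random((4, 4)).astype(np.float32)  # fake 4x4 image in [0, 1)
augmented = random_flip_and_brightness(fake_img, rng=rng)
print(augmented.shape, float(augmented.min()), float(augmented.max()))
```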


## Learning Rate Scheduler
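This callback reduces the learning rate during model training: whenever the validation loss has not improved for 2 consecutive epochs, the learning rate is halved, down to a minimum of 1e-6.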

In [None]:
reduce_lr = ReduceLROnPlateau(
    monitor = 'val_loss', # assess validation loss
    factor = 0.5, # amount to reduce lr by
    patience = 2, # number of epochs to wait with no improvement
    min_lr = 1e-6, # min lr not to go below
    verbose = 1 # print a message when the lr is reduced
)
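The callback's behavior can be traced by hand. The function below is a simplified, hypothetical re-implementation of the plateau rule for a minimized metric (no cooldown), using the same factor, patience, and minimum learning rate as the callback above and the initial learning rate of 0.00005 used when compiling the model. Keras's callback only counts an epoch as an improvement if the loss drops by more than `min_delta` (default 1e-4), which is mirrored here.

```python
def simulate_reduce_lr(val_losses, initial_lr=5e-5, factor=0.5, patience=2,
                       min_lr=1e-6, min_delta=1e-4):
    """Return the learning rate in effect after each epoch (simplified sketch)."""
    lr = initial_lr
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best - min_delta:  # counts as an improvement
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)  # halve, but never below min_lr
                wait = 0
        history.append(lr)
    return history


# Hypothetical validation losses: two improvements, then a plateau.
losses = [1.00, 0.90, 0.95, 0.96, 0.97, 0.98]
print(simulate_reduce_lr(losses))
```

Here the learning rate stays at 5e-5 for three epochs, then halves every second epoch without improvement.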


### Performance Plot Function
Plots the accuracy, loss, and AUC scores from the model's training history.

In [None]:
def model_performance_plots(history, model_name):
    """
    Plot the accuracy, loss, and AUC from the training history.
    Inputs: 
    history - the history of the fit method of the model
    model_name - the name of the model
    """

    plt.figure(figsize=(15, 5))

    # Accuracy plot
    plt.subplot(1, 3, 1)
    plt.plot(history.history['accuracy'], label='Train Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title(f'{model_name} Accuracy')
    plt.legend(loc='lower right')

    # Loss plot
    plt.subplot(1, 3, 2)
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'{model_name} Loss')
    plt.legend(loc='upper right')

    # AUC Score plot
    plt.subplot(1, 3, 3)
    plt.plot(history.history['auc'], label='Train AUC')
    plt.plot(history.history['val_auc'], label='Validation AUC')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title(f'{model_name} AUC')
    plt.legend(loc='lower right')

    plt.tight_layout()
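Because `model_performance_plots` indexes `history.history` with fixed keys, it raises a `KeyError` if a metric was logged under a different name (for example, Keras may auto-name a second AUC metric `auc_1`). A small check like the hypothetical helper below can be run on `history.history` before plotting.

```python
REQUIRED_HISTORY_KEYS = ("accuracy", "val_accuracy", "loss", "val_loss", "auc", "val_auc")


def missing_history_keys(history_dict, required=REQUIRED_HISTORY_KEYS):
    """Return the keys model_performance_plots needs that are absent from history_dict."""
    return [key for key in required if key not in history_dict]


# Hypothetical history where the AUC metric was logged as "auc_1".
example_history = {
    "accuracy": [0.6, 0.7], "val_accuracy": [0.55, 0.65],
    "loss": [1.1, 0.9], "val_loss": [1.2, 1.0],
    "auc_1": [0.8, 0.85], "val_auc_1": [0.78, 0.82],
}
print(missing_history_keys(example_history))
```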


## Model Architecture

**Convolutional Layers:** capture spatial hierarchies in the data

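Conv2D(64, (3, 3), padding='same', activation='leaky_relu', input_shape=(256, 256, 3)): The first convolutional layer, with 64 filters of size 3x3 and a Leaky ReLU activation; 'same' padding keeps the output's height and width equal to the input's. The input shape is 256x256 with 3 channels (RGB). In the code below, the input shape is set by an Input layer, and every convolutional layer also applies L2 weight regularization (l2(0.001)).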

Conv2D(128, (3, 3), activation='leaky_relu',padding='same'): The second convolutional layer with 128 filters. 

Conv2D(256, (3, 3), activation='leaky_relu',padding='same'): The third convolutional layer with 256 filters. 

Conv2D(512, (3, 3), activation='leaky_relu',padding='same'): The fourth convolutional layer, with 512 filters.

**Dropout:** 

Increasing dropout: Each convolutional layer is followed by a Dropout layer whose rate increases with depth (0.05, 0.1, 0.15, 0.2). Dropout randomly sets the designated fraction of input units to 0 during training, which helps prevent overfitting.

Dropout(0.4): Applied after the 256-unit dense layer, with a higher dropout rate to prevent overfitting as the model becomes more complex.


**Batch Normalization:**

Added after each convolutional layer (following its dropout) and after the dense layer to normalize the activations, stabilize the learning process, and potentially reduce overfitting.

**Max Pooling:**

MaxPooling2D((2, 2)) halves the height and width of the feature maps and helps the model focus on the most prominent features.

**GlobalAveragePooling:** 

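GlobalAveragePooling2D, applied before the dense layers, averages each of the 512 feature maps from the last block (16x16 after four rounds of pooling) into a single value, so the first dense layer receives only 512 inputs.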

**Dense Layers:**

Dense(256): A fully connected layer with 256 units, followed by dropout, batch normalization, and a ReLU activation. A larger dense layer can learn more complex and abstract features from the data, allowing the model to capture more detailed information and interactions.

Dense(4, activation='softmax'): The output layer with 4 units (since it’s a 4-class classification problem) and softmax activation to output class probabilities.


In [None]:
# Define input
inputs = Input(shape=(256, 256, 3))
x = inputs
# Data augmentation: Keras preprocessing layers are only active during training (fit)
# and pass inputs through unchanged during evaluation and predict
x = data_augmentation(x)

x = Conv2D(64, (3, 3), padding='same', activation=leaky_relu, kernel_regularizer=l2(0.001))(x)
x = Dropout(0.05)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(128, (3, 3), padding='same', activation=leaky_relu, kernel_regularizer=l2(0.001))(x)
x = Dropout(0.1)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(256, (3, 3), padding='same', activation=leaky_relu, kernel_regularizer=l2(0.001))(x)
x = Dropout(0.15)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(512, (3, 3), padding='same', activation=leaky_relu, kernel_regularizer=l2(0.001))(x)
x = Dropout(0.2)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

# Global Average Pooling to reduce spatial dimensions
x = GlobalAveragePooling2D()(x)

# Fully connected layers
x = Dense(256)(x)
x = Dropout(0.4)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# Output 
outputs = Dense(4, activation='softmax', kernel_regularizer=l2(0.001))(x)

# Build model
model = Model(inputs, outputs)
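The layer sizes described above can be checked with a little arithmetic. The sketch below counts parameters for the four convolutional blocks and the dense head as built in the cell above: a Conv2D layer has k x k x C_in x C_out weights plus C_out biases, a Dense layer has n_in x n_out weights plus n_out biases, and BatchNormalization keeps four values per channel (two trainable, two non-trainable moving statistics). Augmentation, dropout, pooling, and activation layers have no parameters. This is a hand calculation, not output from this notebook; compare it against what `model.summary()` reports.

```python
def conv2d_params(kernel, c_in, c_out):
    """Weights (kernel x kernel x c_in x c_out) plus one bias per filter."""
    return kernel * kernel * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Weight matrix (n_in x n_out) plus one bias per unit."""
    return n_in * n_out + n_out


def batchnorm_params(channels):
    """gamma and beta (trainable) plus moving mean and variance (non-trainable)."""
    return 4 * channels


def summarize_cnn(input_size=256, in_channels=3, filters=(64, 128, 256, 512),
                  dense_units=256, n_classes=4):
    """Return per-block (filters, spatial size after pooling), GAP width, and total params."""
    size, c_in, total = input_size, in_channels, 0
    blocks = []
    for f in filters:
        # 3x3 Conv2D with 'same' padding keeps the size; BatchNormalization follows it.
        total += conv2d_params(3, c_in, f) + batchnorm_params(f)
        size //= 2  # MaxPooling2D((2, 2)) halves height and width
        blocks.append((f, size))
        c_in = f
    # GlobalAveragePooling2D turns size x size x c_in into a vector of length c_in.
    total += dense_params(c_in, dense_units) + batchnorm_params(dense_units)
    total += dense_params(dense_units, n_classes)
    return blocks, c_in, total


blocks, gap_width, total_params = summarize_cnn()
for n_filters, size in blocks:
    print(f"{n_filters:>3} filters -> {size}x{size}x{n_filters}")
print("GAP output:", gap_width, "| total parameters:", total_params)
```

With the default arguments this gives a 16x16x512 feature map before global average pooling, and most of the parameters sit in the last convolutional layer (about 1.18 million of roughly 1.69 million).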

## Class Weight Calculation
The dataset is severely imbalanced, and class weights will be implemented to help with this. 

In [None]:
"""
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_generator.classes),  # Use training data
    y=train_generator.classes
)
class_weights = dict(enumerate(class_weights))
print("Class weights:", class_weights)
"""

The class weight calculation above caused the model to collapse into predicting a single class, 'Non Demented'. Therefore, a customized calculation is used instead: the same balanced formula, with each weight clipped to the range [1, 15].

In [None]:
# Image counts per class; indices follow train_generator.class_indices (alphabetical folder order)
class_counts = {
    0: 5002,  # Mild Demented
    1: 488,   # Moderate Demented
    2: 67222, # Non Demented
    3: 13725  # Very Mild Demented
}

total = sum(class_counts.values())
class_weights = {cls: total/ (len(class_counts) * count) for cls, count in class_counts.items()}

max_weight = 15
min_weight = 1.0
for cls in class_weights:
    class_weights[cls] = np.clip(class_weights[cls], min_weight, max_weight)

print("Class Weights", class_weights)
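To see what the clipping changes, the sketch below reproduces the calculation with the counts from the cell above and compares the unclipped "balanced" weights (the formula scikit-learn's `compute_class_weight` uses: n_samples / (n_classes x count)) with the clipped ones. It is plain Python, so it can be checked without the dataset.

```python
# Same per-class counts as the notebook cell above.
counts = {0: 5002, 1: 488, 2: 67222, 3: 13725}
names = {0: "Mild Dementia", 1: "Moderate Dementia", 2: "Non Demented", 3: "Very mild Dementia"}

total = sum(counts.values())
n_classes = len(counts)

# 'Balanced' weights: rarer classes get proportionally larger weights.
balanced = {c: total / (n_classes * n) for c, n in counts.items()}
# Same weights clipped to [1.0, 15.0], as in the notebook cell above.
clipped = {c: min(max(w, 1.0), 15.0) for c, w in balanced.items()}

for c in counts:
    print(f"{names[c]:<20} share={counts[c] / total:6.2%} "
          f"balanced={balanced[c]:7.3f} clipped={clipped[c]:6.3f}")
```

With these counts, clipping raises the Non Demented weight from about 0.32 to 1.0 and caps the Moderate Dementia weight at 15 instead of about 44, narrowing the ratio between the largest and smallest weights from roughly 138:1 to 15:1.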

In [None]:
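model.compile(optimizer=Adam(learning_rate=0.00005),
              loss='categorical_crossentropy',
              # Name the AUC metric explicitly so the history keys are 'auc' and 'val_auc'
              metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])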

model.summary()

In [None]:
# Fit model
history = model.fit(
    train_generator,
    validation_data = validation_generator,
    epochs=15,
    callbacks=[reduce_lr],
    class_weight=class_weights
)

## Results

In [None]:
# Plot results
model_performance_plots(history, "CNN Alzheimer's Performance")
plt.savefig("/kaggle/working/model_performance_plots.jpg")
plt.show()

### Confusion Matrix

In [None]:
# True labels from the validation generator
validation_labels = validation_generator.classes

# Predict the probabilities for the validation data
predictions = model.predict(validation_generator)

# Convert the probabilities to labels
predicted_classes = np.argmax(predictions, axis=1)

In [None]:
# Generate confusion matrix
conf_matrix = confusion_matrix(validation_labels, predicted_classes)

# Plot 
plt.figure(figsize=(8, 8), dpi=100)
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=validation_generator.class_indices.keys(), yticklabels=validation_generator.class_indices.keys())
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks(rotation = 45, ha = 'right')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.savefig("/kaggle/working/model_confusion_matrix.jpg")
plt.show()
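Because the classes are heavily imbalanced, overall accuracy can look good while the smaller classes are mostly misclassified. Per-class recall and precision can be read off the confusion matrix, since scikit-learn's `confusion_matrix` puts true labels on rows and predicted labels on columns. The sketch below uses a small, made-up matrix for illustration; the same hypothetical `per_class_metrics` function could be applied to `conf_matrix`.

```python
import numpy as np


def per_class_metrics(conf_matrix):
    """Recall, precision, accuracy, and balanced accuracy from a confusion matrix.

    Assumes sklearn's layout: rows are true labels, columns are predicted labels.
    """
    cm = np.asarray(conf_matrix, dtype=float)
    tp = np.diag(cm)
    row_sums = cm.sum(axis=1)  # samples per true class
    col_sums = cm.sum(axis=0)  # predictions per class
    # Divide safely: a class with no samples (or no predictions) gets 0.
    recall = np.divide(tp, row_sums, out=np.zeros_like(tp), where=row_sums > 0)
    precision = np.divide(tp, col_sums, out=np.zeros_like(tp), where=col_sums > 0)
    accuracy = tp.sum() / cm.sum()
    balanced_accuracy = recall.mean()  # mean of per-class recall
    return recall, precision, accuracy, balanced_accuracy


# Made-up 4-class matrix in which the largest class dominates.
example_cm = [[8, 0, 2, 0],
              [1, 3, 1, 0],
              [0, 0, 90, 10],
              [0, 0, 20, 15]]
recall, precision, accuracy, balanced = per_class_metrics(example_cm)
print("recall:", np.round(recall, 3))
print("precision:", np.round(precision, 3))
print(f"accuracy={accuracy:.3f}, balanced accuracy={balanced:.3f}")
```

In this made-up example the overall accuracy is about 0.77, while the balanced accuracy is about 0.68, because recall on the last class is only about 0.43.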

# CONCLUSION
Adult dementia is devastating to the individuals affected and their loved ones. If current population trends continue, there will be a dramatic increase in the number of people suffering from AD [5]. Deep learning techniques such as CNNs could help detect early-stage AD and direct treatment to those afflicted sooner.

**Improvements**

Reducing overfitting while increasing accuracy remains to be addressed. One possible way to improve the model’s performance would be to apply transfer learning. Foster showed that using a pretrained MobileNetV2 architecture to detect AD in MRI images enhanced a deep learning model [3]; however, that work addressed binary AD classification rather than a multi-class problem. Transfer learning could be further enhanced with automated hyperparameter tuning of the pretrained model. Continuing to tune the current model could also help.
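As a starting point for the transfer-learning idea above, here is a minimal sketch that reuses the `MobileNetV2` class imported at the top of the notebook, assuming 256x256 RGB inputs already rescaled to [0, 1] by the data generator. It is not the approach from [3]; the frozen base, the `Rescaling` layer that maps [0, 1] to the [-1, 1] range MobileNetV2 was trained on, and the dropout rate are assumptions to tune. The helper name `build_transfer_model` is hypothetical. With `weights="imagenet"`, Keras downloads pretrained weights and may warn that 256 is not one of the standard input sizes.

```python
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2


def build_transfer_model(input_shape=(256, 256, 3), n_classes=4, weights="imagenet"):
    """Hypothetical transfer-learning model: frozen MobileNetV2 base + new classifier head."""
    base = MobileNetV2(input_shape=input_shape, include_top=False, weights=weights)
    base.trainable = False  # freeze the pretrained convolutional features

    inputs = layers.Input(shape=input_shape)
    # Generator outputs are in [0, 1]; MobileNetV2 expects inputs in [-1, 1].
    x = layers.Rescaling(2.0, offset=-1.0)(inputs)
    x = base(x, training=False)  # keep the base's BatchNormalization in inference mode
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)  # assumed rate; tune as needed
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)
```

The resulting model could then be compiled and trained with the same generators, class weights, and learning rate callback as the CNN above; unfreezing the top of the base for a second, low-learning-rate phase is a common follow-up.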


# REFERENCES

[1] Ashby, K., Adams, B. N., & Shetty, M. (2022, November 14). Appropriate magnetic resonance imaging ordering. StatPearls - NCBI Bookshelf. https://www.ncbi.nlm.nih.gov/books/NBK565857/

[2] Coupé, P., Manjón, J. V., Lanuza, E., & Catheline, G. (2019). Lifespan changes of the human brain in Alzheimer’s disease. Scientific Reports, 9(1). https://doi.org/10.1038/s41598-019-39809-8

[3] Foster, L. (2023, April 18). Identifying Alzheimer’s Disease with Deep Learning: A Transfer Learning Approach. Medium. https://medium.com/@lfoster49203/identifying-alzheimers-disease-with-deep-learning-a-transfer-learning-approach-620abf802631

[4] National Institute on Aging. (2022, December 8). How is Alzheimer’s disease diagnosed? https://www.nia.nih.gov/health/alzheimers-symptoms-and-diagnosis/how-alzheimers-disease-diagnosed

[5] Rasmussen, J., & Langerman, H. (2019). Alzheimer’s Disease – Why We Need Early Diagnosis. Degenerative Neurological and Neuromuscular Disease, Volume 9, 123–130. https://doi.org/10.2147/dnnd.s228939

[6] National Institute on Aging. (2021, July 8). What is Alzheimer’s disease? https://www.nia.nih.gov/health/alzheimers-and-dementia/what-alzheimers-disease


Alzheimer MRI Preprocessed Dataset Available at Kaggle:
https://www.kaggle.com/datasets/sachinkumar413/alzheimer-mri-dataset
