AiApp Miniproject; Leonie Däullary, Ruwen Frick

The task is to categorize brain MRI images into four categories. The main goal is to train a network to detect Alzheimer's disease in brain MRI images. The data can be found here: https://www.kaggle.com/datasets/sachinkumar413/alzheimer-mri-dataset

Needed Packages

In [None]:
pip install opendatasets tensorflow matplotlib

Imports

In [None]:
import tensorflow

from keras import Sequential
from keras_preprocessing.image import ImageDataGenerator
from keras.layers.convolutional import Conv2D
from keras.layers import Flatten, Dense, MaxPooling2D

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import opendatasets as od
import os
from pathlib import Path, PurePath

Downloading the Data

The dataset is hosted on Kaggle. When executing the following cell you will be asked for your Kaggle credentials. These can be obtained by following the steps below:

1. Sign in to https://kaggle.com/ or register a new account, then click on your profile picture on the top right and select "My Account" from the menu.

2. Scroll down to the "API" section and click "Create New API Token". This will download a file kaggle.json with the following contents:
    {"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_KAGGLE_KEY"}

In [None]:
# Fetch the dataset from Kaggle into the current working directory.
dataset_url = 'https://www.kaggle.com/sachinkumar413/alzheimer-mri-dataset'
od.download(dataset_url)

path = Path(".", 'alzheimer-mri-dataset', 'Dataset').absolute()
print(f"Data stored at: {path}")

# Only sub-directories are class folders. Stray files in the dataset root
# (e.g. OS metadata files) must not be counted, because len(classes) is later
# used as the size of the networks' output layer.
classes = sorted(entry.name for entry in path.iterdir() if entry.is_dir())

# NOTE(review): the positional lookup assumes the sorted folder names come out
# in the order mild, moderate, non, very mild -- confirm against the dataset.
mild_path = Path(path, classes[0])
moderate_path = Path(path, classes[1])
non_path = Path(path, classes[2])
verymild_path = Path(path, classes[3])

Print metadata of downloaded dataset

In [None]:
# Walk the dataset tree and report the number of files in every folder
# except the one named 'Dataset' (the dataset root).
for folder, _subfolders, files in os.walk(path):
    folder_name = PurePath(folder).name
    if folder_name == 'Dataset':
        continue
    print(f'{len(files)} images in class {folder_name}')

Print some examples

In [None]:
# Folder and file-name prefix of every class, listed in the same order as
# `classes` so that (position % len(classes)) selects the matching title.
sample_sources = [
    (mild_path, 'mild'),
    (moderate_path, 'moderate'),
    (non_path, 'non'),
    (verymild_path, 'verymild'),
]

# Load images number 2, 3 and 4 of every class, interleaved class by class.
all_imgs = []
for i in range(2, 5):
    for class_dir, prefix in sample_sources:
        all_imgs.append(mpimg.imread(Path(class_dir, f'{prefix}_{i}.jpg')))

plt.figure(figsize=(15, 15))

# 3 rows (one per sample number) x 4 columns (one per class)
for index, img in enumerate(all_imgs):
    plt.subplot(3, 4, index + 1)
    # cmap only affects single-channel arrays (it is ignored for RGB(A) data),
    # so grayscale slices are no longer drawn with the default false-colour map.
    plt.imshow(img, cmap='gray')
    plt.title(classes[index % len(classes)])

plt.show()


Preprocessing and Data Augmentation

In [None]:
# Data generation and augmentation parameters
image_size = (128, 128)
horizontal_flip = True
color_mode = 'rgb'
zoom_range = 0.05
rotation_range = 10
shear_range = 0.1
batch_size = 64
validation_split = 0.15
# Multiply pixel values by 1/255 so the networks are not fed raw 0-255
# intensities (assumes 8-bit source images).
rescale = 1.0 / 255

# Training images are rescaled and randomly augmented (flip, zoom, rotation, shear).
train_data_generator = ImageDataGenerator(
    rescale=rescale,
    horizontal_flip=horizontal_flip,
    zoom_range=zoom_range,
    rotation_range=rotation_range,
    shear_range=shear_range,
    validation_split=validation_split,
)

# Validation images get the same rescaling but no augmentation.
valid_data_generator = ImageDataGenerator(
    rescale=rescale,
    validation_split=validation_split,
)

# Both iterators read from the same directory tree; `subset` selects which
# side of the validation split each one yields.
train_data = train_data_generator.flow_from_directory(
    path,
    target_size=image_size,
    color_mode=color_mode,
    batch_size=batch_size,
    subset='training',
)

valid_data = valid_data_generator.flow_from_directory(
    path,
    target_size=image_size,
    color_mode=color_mode,
    batch_size=batch_size,
    subset='validation',
)

Building the models

In [None]:
# Derive the network input shape from the preprocessing settings instead of
# repeating (128, 128, 3), so the models stay in sync with the generators.
channels_by_color_mode = {'grayscale': 1, 'rgb': 3, 'rgba': 4}
input_shape = (*image_size, channels_by_color_mode[color_mode])

# Deeper network: three convolution/pooling stages followed by two hidden
# dense layers and a softmax output with one unit per class.
deep_model = Sequential([
    Conv2D(64, (4, 4), activation='relu', input_shape=input_shape),
    MaxPooling2D((3, 3)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(32, (2, 2), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(len(classes), activation='softmax'),
])

# Shallower network: two convolution/pooling stages, one hidden dense layer
# and the same softmax output.
shallow_model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(32, activation='relu'),
    Dense(len(classes), activation='softmax'),
])

In [None]:
# Show the layer-by-layer summary of the deep model
deep_model.summary()

In [None]:
# Show the layer-by-layer summary of the shallow model
shallow_model.summary()

Compile, fit and evaluate models

In [None]:
# Both networks use the same training configuration: the 'adam' optimizer,
# categorical cross-entropy loss and accuracy as the reported metric.
for network in (deep_model, shallow_model):
    network.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

In [None]:
# Train the deep model for 5 epochs, validating on the validation subset
deep_fit_options = {'epochs': 5, 'validation_data': valid_data}
deep_history = deep_model.fit(train_data, **deep_fit_options)

# Keep the second value returned by evaluate() as the validation accuracy
deep_metrics = deep_model.evaluate(valid_data)
deep_acc = deep_metrics[1]

In [None]:
# Train the shallow model for 5 epochs, validating on the validation subset
shallow_history = shallow_model.fit(
    train_data,
    epochs=5,
    validation_data=valid_data,
)

# Evaluate the shallow model itself; the original evaluated deep_model here,
# so the reported "shallow" accuracy was actually the deep model's.
shallow_acc = shallow_model.evaluate(valid_data)[1]

Print statistics

In [None]:
# print accuracy
print(f'Validation accuracy of deep model: {deep_acc}')
print(f'Validation accuracy of shallow model: {shallow_acc}')


# plot training and validation accuracy of both models against epochs
plt.title('Deep Model vs Shallow Model')
for prefix, history in (('deep', deep_history), ('shallow', shallow_history)):
    plt.plot(history.history['accuracy'], label=f'{prefix}_accuracy')
    plt.plot(history.history['val_accuracy'], label=f'{prefix}_val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# Accuracy lies in [0, 1]; an upper limit of 0.8 would cut off better runs
plt.ylim(0, 1)
plt.legend(loc='lower right')
plt.show()


Overall performance is not great. Validation accuracy being below training accuracy might indicate some overfitting.
# TO DO: discuss differences