# Alzheimer's Disease Detection with CNN

This notebook demonstrates how to use the prepared ADNI dataset to train a convolutional neural network (CNN) for Alzheimer's disease detection, following the same workflow as the MNIST example.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from prepare_data import load_adni_data
import os

## 1. Load the prepared data

First, we load the prepared data from `adni_data.pkl`. If that file does not exist yet, we run the data preparation script to create it.

In [6]:
# Load the ADNI train/test split. Reuse the prepared file when it exists;
# otherwise build it from the raw CSV and image folder.
# NOTE(review): the prepared file is a .pkl; if load_adni_data unpickles it,
# only load files from a trusted source (unpickling can execute code).
data_file = 'adni_data.pkl'
if os.path.exists(data_file):
    print("Loading prepared data...")
    (x_train, y_train), (x_test, y_test) = load_adni_data(data_file)
else:
    print("Data file not found. Running data preparation script...")
    # Imported only on this path, so the preparation helper is touched only
    # when the prepared file is missing.
    from prepare_data import load_and_prepare_data
    # NOTE(review): data_file is passed as the third argument; presumably the
    # prepared data is written there for the next run — confirm in prepare_data.
    (x_train, y_train), (x_test, y_test) = load_and_prepare_data(
        'binary_data.csv', 'ADNI_IMAGES', data_file
    )

Loading prepared data...


## 2. Explore the data

Let's look at the shape of our data and visualize some samples.

In [7]:
# Report the array shape of each split (x_* are the images, y_* the labels).
for split_name, split_array in (
    ("x_train", x_train),
    ("y_train", y_train),
    ("x_test", x_test),
    ("y_test", y_test),
):
    print(f"{split_name} shape: {split_array.shape}")

x_train shape: (4891, 128, 128, 1)
y_train shape: (4891,)
x_test shape: (3261, 128, 128, 1)
y_test shape: (3261,)


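**TODO:** the sample-visualization cell for this section is missing. Its saved output printed "Training samples:" and then failed with `NameError: name 'visualize_samples' is not defined`. Define `visualize_samples` (or drop this step) so the notebook runs cleanly under Restart & Run All.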

## 3. Convert labels to one-hot encoding

The original labels use a 0/4 coding, so we first remap them to 0/1 and then convert them to one-hot encoding, as in the MNIST example.

In [11]:
# Get the number of classes
num_classes = len(np.unique(y_train))
print(f"Number of classes: {num_classes}")

# Convert into a 0/1 labels. We currently are using the original 0/4
y_train_updated_labels = np.where(y_train == 4, 0, 1)
y_test_updated_labels = np.where(y_test == 4, 0, 1)

print(np.unique(y_train_updated_labels))

# Convert to one-hot encoding
y_train_one_hot = keras.utils.to_categorical(y_train_updated_labels, num_classes)
y_test_one_hot = keras.utils.to_categorical(y_test_updated_labels, num_classes)

print(f"y_train_one_hot shape: {y_train_one_hot.shape}")
print(f"y_test_one_hot shape: {y_test_one_hot.shape}")

Number of classes: 2
[0 1]
y_train_one_hot shape: (4891, 2)
y_test_one_hot shape: (3261, 2)


## 4. Build a CNN model

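Now we'll build a small CNN similar to one you might use for MNIST: three convolution + max-pooling stages, followed by a dropout-regularized dense classifier with a softmax output.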

In [12]:
# Input shape = per-sample shape of the training images (drop the sample axis)
input_shape = x_train.shape[1:]
print(f"Input shape: {input_shape}")

# Build the model: three conv/pool stages, then a dropout-regularized dense
# classifier with one softmax output per class.
# The input is declared with keras.Input rather than passing `input_shape=` to
# the first Conv2D; recent Keras versions warn about the latter in Sequential.
# NOTE(review): no pixel rescaling happens here — confirm prepare_data already
# normalizes the images (e.g. to [0, 1]).
model = keras.Sequential([
    keras.Input(shape=input_shape),
    layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.3),
    layers.Dense(num_classes, activation="softmax"),
])

# Categorical cross-entropy matches the one-hot targets passed to fit()
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# Print the model summary
model.summary()

Input shape: (128, 128, 1)


## 5. Train the model

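Now we'll train the model on the prepared data. We hold out 10% of the training set for validation and use early stopping and learning-rate reduction to control training.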

In [None]:
# Training hyperparameters
batch_size = 32
epochs = 20

# Stop when the monitored metric (val_loss by default) has not improved for
# 5 epochs, rolling back to the best weights; cut the learning rate 10x after
# 3 stagnant epochs.
callbacks = [
    keras.callbacks.EarlyStopping(restore_best_weights=True, patience=5),
    keras.callbacks.ReduceLROnPlateau(patience=3, factor=0.1),
]

# Train the model.
# NOTE(review): Keras takes the validation split from the *last* 10% of
# x_train, before any shuffling. If prepare_data returns samples grouped by
# class or subject, this validation set will be unrepresentative — confirm.
history = model.fit(
    x_train,
    y_train_one_hot,
    validation_split=0.1,
    epochs=epochs,
    batch_size=batch_size,
    callbacks=callbacks,
)

## 6. Evaluate the model

Let's evaluate our model on the test set.

In [None]:
# Score the trained model on the held-out test set. With metrics=["accuracy"]
# in compile(), evaluate() returns [loss, accuracy] in that order.
test_loss, test_acc = model.evaluate(x=x_test, y=y_test_one_hot)
print("Test accuracy: {:.4f}".format(test_acc))

## 7. Plot training history

Let's visualize the training process.

In [None]:
# Side-by-side training curves: accuracy on the left, loss on the right,
# each comparing the training and validation series per epoch.
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

panels = [
    (acc_ax, 'accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy', 'lower right'),
    (loss_ax, 'loss', 'val_loss', 'Model loss', 'Loss', 'upper right'),
]
for ax, train_key, val_key, title, ylabel, legend_loc in panels:
    ax.plot(history.history[train_key])
    ax.plot(history.history[val_key])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)

fig.tight_layout()
plt.show()

## 8. Make predictions

Let's make predictions on some test samples and visualize the results.

In [None]:
# Predict class probabilities for every test image; the predicted class is the
# index of the largest softmax output.
predictions = model.predict(x_test)
predicted_classes = np.argmax(predictions, axis=1)

# Plot the first few test samples with their predictions
# (clamped so a tiny test set does not index past its end)
num_samples = min(5, len(x_test))
plt.figure(figsize=(15, 3))

for i in range(num_samples):
    plt.subplot(1, num_samples, i + 1)

    # Drop the trailing channel axis so imshow gets a 2-D grayscale image
    plt.imshow(x_test[i].reshape(x_test.shape[1], x_test.shape[2]), cmap='gray')

    # Compare against the remapped 0/1 labels the model was trained on. The
    # raw y_test still uses the original 0/4 coding, so comparing it with the
    # argmax index would never mark a correct prediction as correct.
    true_class = y_test_updated_labels[i]
    pred_class = predicted_classes[i]
    pred_prob = predictions[i, pred_class]

    # Green title for a correct prediction, red for a wrong one
    title_color = 'green' if true_class == pred_class else 'red'

    plt.title(f"True: {true_class}\nPred: {pred_class} ({pred_prob:.2f})", color=title_color)
    plt.axis('off')

plt.tight_layout()
plt.show()

## 9. Save the model

Finally, let's save our trained model.

In [None]:
# Persist the trained model; Keras picks the HDF5 format from the .h5 suffix.
# NOTE(review): recent Keras versions treat HDF5 as a legacy format and
# recommend the native `.keras` format — switch only if downstream loaders
# can handle it.
model_path = 'ad_detection_model.h5'
model.save(model_path)
print(f"Model saved as '{model_path}'")