In [4]:
# Dataset link: https://www.kaggle.com/datasets/ninadaithal/imagesoasis

# Root folder of the dataset, relative to the notebook's working directory
dataset_path = './OASIS Data'

# Verify access to dataset by listing the entries directly under the root folder
print("Accessing dataset...")
print(os.listdir(dataset_path))

# Names of the four class sub-folders expected under dataset_path
classes = ['Non Demented', 'Mild Dementia', 'Moderate Dementia', 'Very mild Dementia']
# Accumulators filled by the sampling cell below (one label is appended per collected path)
image_paths = []
labels = []

# Sampling budget used by the sampling cell below: it is the cap for the
# 'Non Demented' class, and it is divided by the number of dementia classes
# (1464 // 3 = 488) to get the cap for each dementia class.
num_files = 1464  # 3 * 488; 488 was chosen because there are only 488 images for moderate dementia

# Function to load images
def load_images(paths, img_size=(224, 224)):
    """
    Read the given image files as grayscale arrays scaled into [0, 1].

    paths: iterable of image file paths.
    img_size: size passed to cv2.resize for every image.
    returns: np.array built from the successfully loaded images, in input order.

    Files that cannot be read are reported with print() and skipped, so the
    result can contain fewer entries than `paths`.
    NOTE(review): this notebook defines load_images a second time further down
    without the division by 255; whichever cell ran last is the live version.
    """
    loaded = []
    for image_file in paths:
        try:
            # Single-channel read keeps the arrays 2-D
            pixels = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
            if pixels is not None:
                resized = cv2.resize(pixels, img_size)
                loaded.append(resized / 255.0)  # scale to [0, 1]
            else:
                print(f"Failed to load image: {image_file}")
        except Exception as err:
            print(f"Error processing image {image_file}: {str(err)}")
    return np.array(loaded)

Accessing dataset...
['Mild Dementia', '.DS_Store', 'Very mild Dementia', 'Moderate Dementia', 'Non Demented']


In [None]:
try:
    # Reset BOTH accumulators together. `labels` is created in an earlier cell,
    # so clearing only `image_paths` would make the two lists drift apart every
    # time this cell is re-run.
    image_paths = []
    labels = []

    # --- Non Demented (label 0) ---
    non_demented_path = os.path.join(dataset_path, 'Non Demented')
    # Keep only regular, non-hidden files *before* sampling so that entries such
    # as '.DS_Store' or sub-folders cannot use up sample slots.
    non_demented_files = [
        name for name in os.listdir(non_demented_path)
        if not name.startswith('.') and os.path.isfile(os.path.join(non_demented_path, name))
    ]
    # Randomly select a subset of images (unseeded: the selection differs per run)
    non_demented_files = random.sample(non_demented_files, min(num_files, len(non_demented_files)))

    for image_filename in non_demented_files:
        image_paths.append(os.path.join(non_demented_path, image_filename))
        labels.append(0)  # Label for Non Demented

    print(f"Loaded {len(non_demented_files)} non-demented images")

    # --- Dementia classes (all merged into label 1) ---
    dementia_classes = ['Mild Dementia', 'Moderate Dementia', 'Very mild Dementia']
    # The num_files budget is shared equally between the dementia classes
    per_class_quota = num_files // len(dementia_classes)
    for category in dementia_classes:
        try:
            category_path = os.path.join(dataset_path, category)
            category_files = [
                name for name in os.listdir(category_path)
                if not name.startswith('.') and os.path.isfile(os.path.join(category_path, name))
            ]
            selected_files = random.sample(category_files, min(per_class_quota, len(category_files)))
            for image_filename in selected_files:
                image_paths.append(os.path.join(category_path, image_filename))
                labels.append(1)  # Label for Dementia (combined)

            print(f"Loaded {len(selected_files)} images from {category}")

        except Exception as e:
            # Best effort: a missing/unreadable class folder is reported and skipped
            print(f"Error processing category {category}: {str(e)}")

    print(f"Total images loaded: {len(image_paths)}")
except Exception as e:
    print(f"Error loading image directories: {str(e)}")

Loaded 1000 non-demented images
Loaded 488 images from Mild Dementia
Loaded 488 images from Moderate Dementia
Loaded 488 images from Very mild Dementia
Total images loaded: 2464


# Two-Stage Hierarchical Classifier for Alzheimer's Disease (AD)
This approach will:
1. First classify images as "Dementia" vs "Non-Dementia" (binary classification)
2. Then classify "Dementia" images into severity levels (multi-class classification)



In [24]:
def prepare_binary_data(dataset_path='./OASIS Data', samples_per_class=1000):
    """
    Collect image paths and labels for the binary (Non-Demented vs Dementia) stage.

    Up to `samples_per_class` Non-Demented images are drawn at random (unseeded),
    and up to `samples_per_class` dementia images are drawn with a stratified
    sample (random_state=42) so each dementia type keeps its proportion.

    dataset_path: folder containing the 'Non Demented', 'Mild Dementia',
        'Moderate Dementia' and 'Very mild Dementia' sub-folders.
    samples_per_class: maximum number of images kept for each binary class.

    returns: image_paths, binary_labels, detailed_labels, subjects
        image_paths     - list of file paths
        binary_labels   - np.array, 0 for Non-Demented, 1 for any type of Dementia
        detailed_labels - np.array, 0: Non-Demented, 1: Mild, 2: Moderate, 3: Very Mild
        subjects        - list of subject ids parsed from the file names

    Files whose name does not contain 'OAS1_' (for example '.DS_Store') are
    reported and skipped instead of raising IndexError.
    """

    def subject_id_of(image_filename):
        # The id is the text between 'OAS1_' and the following '_'; None if the marker is absent
        parts = image_filename.split('OAS1_')
        return parts[1].split('_')[0] if len(parts) > 1 else None

    def usable_files(folder_path, filenames):
        # Keep (filename, subject_id) for regular files with a parsable subject id
        kept = []
        for image_filename in filenames:
            if not os.path.isfile(os.path.join(folder_path, image_filename)):
                continue
            subject_id = subject_id_of(image_filename)
            if subject_id is None:
                print(f"Skipping file with invalid format: {image_filename}")
                continue
            kept.append((image_filename, subject_id))
        return kept

    image_paths = []
    binary_labels = []  # 0 for Non-Demented, 1 for any type of Dementia
    detailed_labels = []  # 0: Non-Demented, 1: Mild, 2: Moderate, 3: Very Mild
    subjects = []

    # ---- Non-Demented images ----
    print(f"Accessing dataset from: {dataset_path}")
    non_demented_path = os.path.join(dataset_path, 'Non Demented')
    print(f"Loading Non-Demented images from: {non_demented_path}")
    non_demented_files = os.listdir(non_demented_path)
    print(f"Found {len(non_demented_files)} Non-Demented files")

    # Filter first, then sample, so unusable entries cannot take sample slots
    non_demented_valid = usable_files(non_demented_path, non_demented_files)
    non_demented_valid = random.sample(non_demented_valid, min(samples_per_class, len(non_demented_valid)))
    print(f"Sampled {len(non_demented_valid)} Non-Demented files")

    for image_filename, subject_id in non_demented_valid:
        image_paths.append(os.path.join(non_demented_path, image_filename))
        binary_labels.append(0)  # Non-Demented
        detailed_labels.append(0)  # Non-Demented
        subjects.append(subject_id)  # kept for the subject-based train/val split

    print(f"Successfully processed {len(non_demented_valid)} Non-Demented images")

    # ---- Dementia images: collect everything first, sample afterwards ----
    dementia_classes = {
        'Mild Dementia': 1,
        'Moderate Dementia': 2,
        'Very mild Dementia': 3
    }

    all_dementia_paths = []
    all_dementia_detailed_labels = []
    all_dementia_subjects = []

    for category, label in dementia_classes.items():
        category_path = os.path.join(dataset_path, category)
        print(f"Loading {category} images from: {category_path}")
        category_files = os.listdir(category_path)
        print(f"Found {len(category_files)} {category} files")

        category_valid = usable_files(category_path, category_files)
        for image_filename, subject_id in category_valid:
            all_dementia_paths.append(os.path.join(category_path, image_filename))
            all_dementia_detailed_labels.append(label)
            all_dementia_subjects.append(subject_id)

        print(f"Successfully processed {len(category_valid)} {category} images")

    # ---- Match the dementia count to the Non-Demented cap ----
    if len(all_dementia_paths) > samples_per_class:
        indices = np.arange(len(all_dementia_paths))
        # An integer test_size is an absolute sample count, which avoids the
        # rounding a float fraction (samples_per_class / total) can introduce.
        _, sampled_indices = train_test_split(
            indices,
            test_size=samples_per_class,
            stratify=all_dementia_detailed_labels,
            random_state=42
        )
    else:
        # Fewer dementia images than the cap: use all of them
        sampled_indices = range(len(all_dementia_paths))

    for idx in sampled_indices:
        image_paths.append(all_dementia_paths[idx])
        binary_labels.append(1)  # Dementia
        detailed_labels.append(all_dementia_detailed_labels[idx])
        subjects.append(all_dementia_subjects[idx])

    # Explicit integer dtype keeps np.bincount working even when nothing was found
    binary_labels = np.array(binary_labels, dtype=int)
    detailed_labels = np.array(detailed_labels, dtype=int)

    # Print statistics
    print(f"Total images: {len(image_paths)}")
    print(f"Total subjects: {len(set(subjects))}")
    print(f"Binary label distribution: {np.bincount(binary_labels)}")

    return image_paths, binary_labels, detailed_labels, subjects

In [35]:
def split_by_subject(image_paths, binary_labels, detailed_labels, subjects,
                     test_size=0.2, random_state=42):
    """
    Subject-based train/validation split to prevent data leakage.

    Every image of a given subject ends up on the same side of the split.
    Subjects are split separately per binary class so both classes contribute
    subjects to the training and to the validation set.

    image_paths: unused here; kept so the call mirrors prepare_binary_data()'s return value.
    binary_labels: per-image binary label (0 / 1), indexable by image position.
    detailed_labels: unused here; kept for the same reason as image_paths.
    subjects: per-image subject id, same length and order as binary_labels.
    test_size, random_state: forwarded to train_test_split for both classes.

    returns: (train_indices, val_indices), two lists of image positions.
    raises: ValueError if `subjects` is empty.

    A subject's class is the label of its first image in `subjects`.
    NOTE(review): train_test_split is expected to reject a class that has too
    few subjects to split - confirm against the data being used.
    """
    if len(subjects) == 0:
        raise ValueError("split_by_subject needs at least one subject")

    # Get unique subjects
    subject_counter = Counter(subjects)
    print(f"Number of unique subjects: {len(subject_counter)}")
    print(f"Average images per subject: {len(subjects) / len(subject_counter):.1f}")
    print(f"Subject distribution: {sorted(subject_counter.items(), key=lambda x: x[1], reverse=True)[:10]}")

    # Map each subject to the binary class of its first image
    subject_to_binary_class = {}
    for i, subject in enumerate(subjects):
        if subject not in subject_to_binary_class:
            subject_to_binary_class[subject] = binary_labels[i]

    print(f'Subject to binary class: {subject_to_binary_class}')

    # Get subjects for each binary class
    class0_subjects = [s for s, c in subject_to_binary_class.items() if c == 0]
    class1_subjects = [s for s, c in subject_to_binary_class.items() if c == 1]
    print(f'Class 0 subjects: {class0_subjects}')
    print(f'Class 1 subjects: {class1_subjects}')

    # Split each class separately to maintain class distribution
    train_class0, val_class0 = train_test_split(class0_subjects, test_size=test_size, random_state=random_state)
    train_class1, val_class1 = train_test_split(class1_subjects, test_size=test_size, random_state=random_state)

    # Sets give O(1) membership tests; the original scanned a list for every image
    train_subjects = set(train_class0) | set(train_class1)
    val_subjects = set(val_class0) | set(val_class1)

    # Get indices for train and validation
    train_indices = [i for i, subject in enumerate(subjects) if subject in train_subjects]
    val_indices = [i for i, subject in enumerate(subjects) if subject in val_subjects]

    return train_indices, val_indices

In [37]:
def prepare_balanced_data(dataset_path='./OASIS Data'):
    """
    Prepares data with balanced subjects across classes.

    All images are grouped by subject id, each subject is assigned the binary
    class most of its images belong to, and the larger class is randomly
    (unseeded) cut down to the number of subjects in the smaller class. Every
    image of every selected subject is returned.

    dataset_path: folder containing the 'Non Demented', 'Mild Dementia',
        'Moderate Dementia' and 'Very mild Dementia' sub-folders.

    Returns: image_paths, binary_labels, detailed_labels, subjects
        image_paths     - list of file paths
        binary_labels   - np.array, 0 for Non-Demented, 1 for any type of Dementia
        detailed_labels - np.array, 0: Non-Demented, 1: Mild, 2: Moderate, 3: Very Mild
        subjects        - list with the subject id of every returned image

    Balancing is by number of subjects, not images, so the image counts of the
    two classes can still differ.
    """
    # {subject_id: [(image_path, binary_label, detailed_label), ...]}
    subject_images = {}

    def add_class_folder(folder_name, display_name, binary_label, detailed_label):
        # Register every parsable image file of one class folder under its subject id
        folder_path = os.path.join(dataset_path, folder_name)
        print(f"Loading {display_name} images from: {folder_path}")
        filenames = os.listdir(folder_path)
        print(f"Found {len(filenames)} {display_name} files")

        for image_filename in filenames:
            image_path = os.path.join(folder_path, image_filename)
            if not os.path.isfile(image_path):
                continue

            # The subject id is the text between 'OAS1_' and the next '_'.
            # An explicit check replaces the former bare `except:` so that
            # unrelated errors are no longer swallowed.
            parts = image_filename.split('OAS1_')
            if len(parts) < 2:
                print(f"Skipping file with invalid format: {image_filename}")
                continue
            subject_id = parts[1].split('_')[0]

            subject_images.setdefault(subject_id, []).append(
                (image_path, binary_label, detailed_label)
            )

    print(f"Accessing dataset from: {dataset_path}")

    # Process Non-Demented images
    add_class_folder('Non Demented', 'Non-Demented', 0, 0)

    # Process dementia classes
    dementia_classes = {
        'Mild Dementia': 1,
        'Moderate Dementia': 2,
        'Very mild Dementia': 3
    }
    for category, label in dementia_classes.items():
        add_class_folder(category, category, 1, label)

    # Assign every subject the binary class that most of its images carry
    subject_class = {}  # {subject_id: binary_class}
    for subject_id, images in subject_images.items():
        class_counts = Counter(img[1] for img in images)
        subject_class[subject_id] = class_counts.most_common(1)[0][0]

    # Count subjects per class
    class0_subjects = [s for s, c in subject_class.items() if c == 0]
    class1_subjects = [s for s, c in subject_class.items() if c == 1]

    print(f"Found {len(class0_subjects)} subjects in class 0 (Non-Demented)")
    print(f"Found {len(class1_subjects)} subjects in class 1 (Dementia)")

    # Balance number of subjects per class if needed
    min_subjects = min(len(class0_subjects), len(class1_subjects))
    if len(class0_subjects) > min_subjects:
        class0_subjects = random.sample(class0_subjects, min_subjects)
    if len(class1_subjects) > min_subjects:
        class1_subjects = random.sample(class1_subjects, min_subjects)

    print(f"Using {len(class0_subjects)} subjects from each class for balance")

    # Collect all images from selected subjects
    image_paths = []
    binary_labels = []
    detailed_labels = []
    subjects = []
    for subject_id in class0_subjects + class1_subjects:
        for img_path, binary_label, detailed_label in subject_images[subject_id]:
            image_paths.append(img_path)
            binary_labels.append(binary_label)
            detailed_labels.append(detailed_label)
            subjects.append(subject_id)

    # Explicit integer dtype keeps np.bincount working even when nothing was selected
    binary_labels = np.array(binary_labels, dtype=int)
    detailed_labels = np.array(detailed_labels, dtype=int)

    # Print statistics
    print(f"Total images: {len(image_paths)}")
    print(f"Total subjects: {len(set(subjects))}")
    print(f"Binary label distribution: {np.bincount(binary_labels)}")

    # Print detailed statistics (skipped when no subject was selected to avoid dividing by zero)
    subject_counter = Counter(subjects)
    print(f"Number of unique subjects: {len(subject_counter)}")
    if subject_counter:
        print(f"Average images per subject: {len(subjects) / len(subject_counter):.1f}")
        print(f"Top 10 subjects by image count: {sorted(subject_counter.items(), key=lambda x: x[1], reverse=True)[:10]}")

    return image_paths, binary_labels, detailed_labels, subjects


## Stage 1: Binary Classifier

In [30]:
def train_binary_classifier(X_train, y_train, X_val, y_val):
    """
    Train the stage-1 CNN that separates Non-Demented (0) from Dementia (1).

    X_train, y_train: training images and binary labels, fed through an
        augmenting ImageDataGenerator.
    X_val, y_val: validation images and binary labels, fed through a
        rescale-only ImageDataGenerator.
    returns: (binary_model, history) - the fitted model and the object returned by fit().

    Both generators multiply pixel values by 1/255.
    # assumes the inputs are shaped (n, 224, 224, 1) with raw 0-255 pixels - TODO confirm
    """
    # Random geometric augmentation is applied to the training images only
    augmentation = dict(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    train_datagen = ImageDataGenerator(rescale=1./255, **augmentation)
    val_datagen = ImageDataGenerator(rescale=1./255)

    train_batches = train_datagen.flow(X_train, y_train, batch_size=32)
    val_batches = val_datagen.flow(X_val, y_val, batch_size=32)

    # Three Conv2D + MaxPooling2D stages (32, 64, 128 filters); only the first
    # Conv2D declares the input shape
    layer_stack = []
    for position, filters in enumerate((32, 64, 128)):
        first_layer_args = {'input_shape': (224, 224, 1)} if position == 0 else {}
        layer_stack.append(
            Conv2D(filters, (3, 3), activation='relu', kernel_regularizer=l2(0.001), **first_layer_args)
        )
        layer_stack.append(MaxPooling2D(pool_size=(2, 2)))

    # Classification head ending in a single sigmoid unit
    layer_stack.append(Flatten())
    layer_stack.append(Dense(128, activation='relu', kernel_regularizer=l2(0.001)))
    layer_stack.append(Dropout(0.5))
    layer_stack.append(Dense(1, activation='sigmoid'))

    binary_model = Sequential(layer_stack)
    binary_model.compile(
        optimizer=Adam(learning_rate=0.0001),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    # Both callbacks watch val_loss: stop after 10 epochs without improvement
    # (restoring the best weights), cut the learning rate by 5x after 5
    training_callbacks = [
        EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001),
    ]

    history = binary_model.fit(
        train_batches,
        validation_data=val_batches,
        epochs=20,
        callbacks=training_callbacks,
    )

    return binary_model, history

## Stage 2: Multi-class Classification for Dementia Severity

In [31]:
def train_severity_classifier(X_train, y_train, X_val, y_val, epochs=50, batch_size=32):
    """
    Train the stage-2 CNN that grades dementia severity.

    Only samples whose detailed label is greater than 0 are used. Labels
    1, 2, 3 are shifted to 0, 1, 2 and one-hot encoded, so output index 0/1/2
    corresponds to detailed label 1/2/3.

    X_train, X_val: image arrays indexable by an integer index array.
    y_train, y_val: detailed labels (0 = Non-Demented, 1-3 = dementia classes).
    epochs: maximum number of training epochs (early stopping may end sooner).
    batch_size: batch size of both data generators.
    returns: (severity_model, history)
    raises: ValueError if the training or the validation split contains no
        dementia sample (label > 0).

    Both generators multiply pixel values by 1/255.
    # assumes the inputs are shaped (n, 224, 224, 1) with raw 0-255 pixels - TODO confirm
    """
    y_train = np.asarray(y_train)
    y_val = np.asarray(y_val)

    # Only use samples that have dementia (classes 1, 2, 3)
    train_dementia_indices = np.where(y_train > 0)[0]
    val_dementia_indices = np.where(y_val > 0)[0]

    # Fail early with a clear message instead of deep inside Keras
    if len(train_dementia_indices) == 0 or len(val_dementia_indices) == 0:
        raise ValueError(
            "Severity classifier needs dementia samples (label > 0) in both splits; "
            f"got {len(train_dementia_indices)} training and {len(val_dementia_indices)} validation samples"
        )

    X_train_dementia = X_train[train_dementia_indices]
    y_train_dementia = y_train[train_dementia_indices] - 1  # Adjust labels to be 0, 1, 2

    X_val_dementia = X_val[val_dementia_indices]
    y_val_dementia = y_val[val_dementia_indices] - 1  # Adjust labels to be 0, 1, 2

    # Convert to one-hot encoding
    y_train_dementia = to_categorical(y_train_dementia, num_classes=3)
    y_val_dementia = to_categorical(y_val_dementia, num_classes=3)

    # Data augmentation (training only)
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    val_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow(X_train_dementia, y_train_dementia, batch_size=batch_size)
    val_generator = val_datagen.flow(X_val_dementia, y_val_dementia, batch_size=batch_size)

    # Build model
    severity_model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 1), kernel_regularizer=l2(0.001)),
        MaxPooling2D(pool_size=(2, 2)),

        Conv2D(64, (3, 3), activation='relu', kernel_regularizer=l2(0.001)),
        MaxPooling2D(pool_size=(2, 2)),

        Conv2D(128, (3, 3), activation='relu', kernel_regularizer=l2(0.001)),
        MaxPooling2D(pool_size=(2, 2)),

        Flatten(),
        Dense(128, activation='relu', kernel_regularizer=l2(0.001)),
        Dropout(0.5),
        Dense(3, activation='softmax')  # 3 classes: Mild, Moderate, Very Mild
    ])

    # Compile model
    severity_model.compile(
        optimizer=Adam(learning_rate=0.0001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Callbacks
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001)

    # Train model
    history = severity_model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=epochs,
        callbacks=[early_stopping, reduce_lr]
    )

    return severity_model, history

In [33]:
def predict_hierarchical(binary_model, severity_model, image, threshold=0.5):
    """
    Classify one image with the two-stage pipeline.

    binary_model: stage-1 model; predict() must return a (1, 1) dementia score.
    severity_model: stage-2 model; predict() must return (1, 3) class scores.
    image: 2-D grayscale array, or a 3-D array with 1, 3 or 4 channels.
    threshold: stage-1 scores below this value are reported as "Non-Demented".

    returns: (label, binary_pred, severity_pred)
        label         - "Non-Demented" or one of the three severity names
        binary_pred   - the stage-1 score
        severity_pred - the stage-2 score vector, or None when stage 2 was skipped

    The image is resized to 224x224 and divided by 255 before prediction.
    """
    # The models take a single channel, so reduce multi-channel input first
    # (previously a colour image made the reshape below fail).
    if image.ndim == 3:
        if image.shape[2] == 1:
            image = image[:, :, 0]
        elif image.shape[2] == 4:
            # NOTE(review): assumes OpenCV's BGRA channel order - confirm with the caller
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        else:
            # NOTE(review): assumes OpenCV's BGR channel order (cv2.imread default) - confirm
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Preprocess image
    img = cv2.resize(image, (224, 224))
    img = img / 255.0
    img = img.reshape(1, 224, 224, 1)

    # Stage 1: Binary classification
    binary_pred = binary_model.predict(img)[0][0]

    if binary_pred < threshold:
        return "Non-Demented", binary_pred, None

    # Stage 2: Severity classification
    severity_pred = severity_model.predict(img)[0]
    severity_class = np.argmax(severity_pred)

    # Index order follows the detailed labels 1, 2, 3 shifted down by one
    severity_labels = ["Mild Dementia", "Moderate Dementia", "Very Mild Dementia"]
    return severity_labels[severity_class], binary_pred, severity_pred

In [38]:
def train_with_group_kfold(n_splits=5, epochs=20):
    """
    Train the model using GroupKFold cross-validation
    to properly handle subject-based splitting

    n_splits: number of folds; each subject appears in exactly one validation fold.
    epochs: maximum epochs per fold (early stopping may end a fold sooner).
    returns: list with the validation accuracy of every fold.
    raises: ValueError if no image was loaded or if the number of loaded images
        differs from the number of image paths.

    Side effects: prints progress and saves each fold's model to
    'alzheimer_model_fold_<k>.h5' in the working directory.
    """
    # Prepare data
    image_paths, binary_labels, detailed_labels, subjects = prepare_balanced_data()

    # Load images
    print("Loading images...")
    X = load_images(image_paths)
    y = binary_labels
    groups = np.array(subjects)

    # load_images skips unreadable files, which would silently shift every
    # following image onto the wrong label/subject. Refuse to train on that.
    if len(X) == 0 or len(X) != len(image_paths):
        raise ValueError(
            f"load_images returned {len(X)} images for {len(image_paths)} paths; "
            "images, labels and subjects would be misaligned"
        )

    # This notebook contains two load_images definitions: one returns raw 0-255
    # pixels, the other already divides by 255. Only rescale in the generators
    # when the loaded data is still in the 0-255 range, so pixels are never
    # normalised twice.
    pixel_rescale = 1./255 if X.max() > 1.0 else None

    # Setup GroupKFold
    gkf = GroupKFold(n_splits=n_splits)

    # Store results
    fold_results = []

    # Train on each fold
    for fold, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        print(f"\n===== Training fold {fold+1}/{n_splits} =====")

        # Split data
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # Print fold statistics
        print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")
        print(f"Training class distribution: {np.bincount(y_train)}")
        print(f"Validation class distribution: {np.bincount(y_val)}")

        # Count unique subjects in each split
        train_subjects = set(groups[train_idx])
        val_subjects = set(groups[val_idx])
        print(f"Training subjects: {len(train_subjects)}, Validation subjects: {len(val_subjects)}")

        # Verify no subject overlap
        assert len(train_subjects.intersection(val_subjects)) == 0, "Subject leakage detected!"

        # Reshape to include channel dimension
        X_train = X_train[..., np.newaxis]
        X_val = X_val[..., np.newaxis]

        # Create data generators with augmentation
        train_datagen = ImageDataGenerator(
            rescale=pixel_rescale,  # Normalize pixel values (skipped if already in [0, 1])
            rotation_range=15,  # Rotate images
            width_shift_range=0.1,  # Shift horizontally
            height_shift_range=0.1,  # Shift vertically
            shear_range=0.1,  # Shear
            zoom_range=0.1,  # Zoom
            horizontal_flip=True,  # Flip horizontally
            fill_mode='nearest'  # Fill strategy
        )

        val_datagen = ImageDataGenerator(rescale=pixel_rescale)  # Only normalize validation data

        # Create generators
        train_generator = train_datagen.flow(X_train, y_train, batch_size=32)
        val_generator = val_datagen.flow(X_val, y_val, batch_size=32)

        # Build model with regularization to prevent overfitting
        model = Sequential([
            # First convolutional block
            Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 1),
                   kernel_regularizer=l2(0.001), padding='same'),
            MaxPooling2D(pool_size=(2, 2)),

            # Second convolutional block
            Conv2D(64, (3, 3), activation='relu', kernel_regularizer=l2(0.001), padding='same'),
            MaxPooling2D(pool_size=(2, 2)),

            # Third convolutional block
            Conv2D(128, (3, 3), activation='relu', kernel_regularizer=l2(0.001), padding='same'),
            MaxPooling2D(pool_size=(2, 2)),

            # Flatten and dense layers
            Flatten(),
            Dense(128, activation='relu', kernel_regularizer=l2(0.001)),
            Dropout(0.5),  # Strong dropout to prevent overfitting
            Dense(1, activation='sigmoid')
        ])

        # Compile model
        model.compile(
            optimizer=Adam(learning_rate=0.0001),  # Low learning rate
            loss='binary_crossentropy',
            metrics=['accuracy']
        )

        # Print model summary
        model.summary()

        # Callbacks
        early_stopping = EarlyStopping(
            monitor='val_loss',
            patience=10,  # Wait for 10 epochs before stopping
            restore_best_weights=True,  # Restore weights from best epoch
            verbose=1
        )

        reduce_lr = ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,  # Reduce LR by 80% when plateauing
            patience=5,
            min_lr=0.00001,
            verbose=1
        )

        # Train model
        history = model.fit(
            train_generator,
            validation_data=val_generator,
            epochs=epochs,
            callbacks=[early_stopping, reduce_lr],
            verbose=1
        )

        # Evaluate model (through the generator, so the same rescaling as in training applies)
        val_loss, val_acc = model.evaluate(val_generator, verbose=0)
        fold_results.append(val_acc)

        print(f"Fold {fold+1} validation accuracy: {val_acc:.4f}")

        # Save model for this fold
        model.save(f'alzheimer_model_fold_{fold+1}.h5')

    # Print overall results
    print("\n===== Cross-validation Results =====")
    for i, acc in enumerate(fold_results):
        print(f"Fold {i+1}: {acc:.4f}")
    print(f"Average validation accuracy: {np.mean(fold_results):.4f}")
    print(f"Standard deviation: {np.std(fold_results):.4f}")

    return fold_results


In [None]:
# Import necessary libraries
import os
import random
import numpy as np
import cv2
from collections import Counter
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.regularizers import l2
from sklearn.model_selection import train_test_split

# Build the image list plus binary / detailed labels and subject ids
image_paths, binary_labels, detailed_labels, subjects = prepare_binary_data()

# Show the image count per binary class so the balance can be checked by eye
print(f"Binary labels: {np.bincount(binary_labels)}")

# Load images
def load_images(paths, img_size=(224, 224)):
    """
    Read the given image files as grayscale arrays with unscaled pixel values.

    paths: iterable of image file paths.
    img_size: size passed to cv2.resize for every image.
    returns: np.array built from the successfully loaded images, in input order.

    Files that cannot be read are reported with print() and skipped, so the
    result can contain fewer entries than `paths`.
    NOTE(review): an earlier cell of this notebook defines load_images with an
    extra division by 255; whichever cell ran last is the live version.
    """
    collected = []
    for file_path in paths:
        try:
            raw = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
            if raw is not None:
                collected.append(cv2.resize(raw, img_size))
            else:
                print(f"Failed to load image: {file_path}")
        except Exception as exc:
            print(f"Error processing image {file_path}: {str(exc)}")
    return np.array(collected)

X = load_images(image_paths)

# load_images skips unreadable files; if that happened, the index lists below
# would pair images with the wrong labels, so stop here instead.
if len(X) != len(image_paths):
    raise RuntimeError(
        f"Loaded {len(X)} images for {len(image_paths)} paths; images and labels are misaligned"
    )

# Split data by subject
train_indices, val_indices = split_by_subject(image_paths, binary_labels, detailed_labels, subjects)

# Create train/val datasets
X_train = X[train_indices]
y_train_binary = binary_labels[train_indices]
y_train_detailed = detailed_labels[train_indices]

X_val = X[val_indices]
y_val_binary = binary_labels[val_indices]
y_val_detailed = detailed_labels[val_indices]

# Reshape to include channel dimension
X_train = X_train[..., np.newaxis]
X_val = X_val[..., np.newaxis]

# Train binary classifier
binary_model, binary_history = train_binary_classifier(X_train, y_train_binary, X_val, y_val_binary)

# Train severity classifier
# severity_model, severity_history = train_severity_classifier(X_train, y_train_detailed, X_val, y_val_detailed)

# Save models
binary_model.save('binary_classifier.h5')
# severity_model.save('severity_classifier.h5')

# Evaluate models
# The training functions feed the networks through generators that multiply the
# pixels by 1/255, so raw 0-255 arrays must be scaled the same way before they
# are passed to evaluate(); otherwise the reported metrics are meaningless.
X_val_scaled = X_val.astype(np.float32) / 255.0
binary_eval = binary_model.evaluate(X_val_scaled, y_val_binary)
print(f"Binary classifier - Validation Loss: {binary_eval[0]:.4f}, Validation Accuracy: {binary_eval[1]:.4f}")

# Evaluate severity classifier on dementia samples only
# val_dementia_indices = np.where(y_val_detailed > 0)[0]
# X_val_dementia = X_val_scaled[val_dementia_indices]
# y_val_dementia = y_val_detailed[val_dementia_indices] - 1  # Adjust to 0-based
# y_val_dementia_cat = to_categorical(y_val_dementia, num_classes=3)

# severity_eval = severity_model.evaluate(X_val_dementia, y_val_dementia_cat)
# print(f"Severity classifier - Validation Loss: {severity_eval[0]:.4f}, Validation Accuracy: {severity_eval[1]:.4f}")

Accessing dataset from: ./OASIS Data
Loading Non-Demented images from: ./OASIS Data/Non Demented
Found 1000 Non-Demented files
Sampled 1000 Non-Demented files
Successfully processed 1000 Non-Demented images
Loading Mild Dementia images from: ./OASIS Data/Mild Dementia
Found 1000 Mild Dementia files
Successfully processed 1000 Mild Dementia images
Loading Moderate Dementia images from: ./OASIS Data/Moderate Dementia
Found 488 Moderate Dementia files
Successfully processed 488 Moderate Dementia images
Loading Very mild Dementia images from: ./OASIS Data/Very mild Dementia
Found 1000 Very mild Dementia files
Successfully processed 1000 Very mild Dementia images
Total images: 2000
Total subjects: 17
Binary label distribution: [1000 1000]
Binary labels: [1000 1000]
Unique subjects: ['0001', '0022', '0015', '0053', '0006', '0028', '0031', '0004', '0351', '0016', '0003', '0052', '0308', '0021', '0002', '0005', '0035']
Subject to binary class: {'0005': 0, '0002': 0, '0001': 0, '0004': 0, '0006

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


Epoch 1/20
[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 476ms/step - accuracy: 0.5375 - loss: 1.0291 - val_accuracy: 0.4590 - val_loss: 0.9228 - learning_rate: 1.0000e-04
Epoch 2/20
[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 556ms/step - accuracy: 0.6873 - loss: 0.8447 - val_accuracy: 0.4590 - val_loss: 0.8842 - learning_rate: 1.0000e-04
Epoch 3/20
[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 528ms/step - accuracy: 0.6963 - loss: 0.7581 - val_accuracy: 0.4590 - val_loss: 0.9902 - learning_rate: 1.0000e-04
Epoch 4/20
[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 858ms/step - accuracy: 0.7404 - loss: 0.7083 - val_accuracy: 0.4590 - val_loss: 1.1568 - learning_rate: 1.0000e-04
Epoch 5/20
[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 2s/step - accuracy: 0.7275 - loss: 0.6946 - val_accuracy: 0.4590 - val_loss: 1.0322 - learning_rate: 1.0000e-04
Epoch 6/20
[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━

KeyboardInterrupt: 

In [40]:
import os
import random
import numpy as np
import cv2
from collections import Counter
from sklearn.model_selection import GroupKFold, train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.regularizers import l2

# Entry point of this cell: launch the group-k-fold training experiment.
# NOTE(review): train_with_group_kfold is not defined in this cell, so it must
# already exist in the kernel from an earlier cell — confirm this survives
# "Restart Kernel -> Run All".
if __name__ == "__main__":
    # n_splits=5 / epochs=30 correspond to the "Training fold 1/5" and
    # "Epoch 1/30" lines in this cell's output.
    train_with_group_kfold(n_splits=5, epochs=30)

Accessing dataset from: ./OASIS Data
Loading Non-Demented images from: ./OASIS Data/Non Demented
Found 1000 Non-Demented files
Loading Mild Dementia images from: ./OASIS Data/Mild Dementia
Found 1000 Mild Dementia files
Loading Moderate Dementia images from: ./OASIS Data/Moderate Dementia
Found 488 Moderate Dementia files
Loading Very mild Dementia images from: ./OASIS Data/Very mild Dementia
Found 1000 Very mild Dementia files
Found 5 subjects in class 0 (Non-Demented)
Found 12 subjects in class 1 (Dementia)
Using 5 subjects from each class for balance
Total images: 2098
Total subjects: 10
Binary label distribution: [1000 1098]
Number of unique subjects: 10
Average images per subject: 209.8
Top 10 subjects by image count: [('0002', 244), ('0004', 244), ('0001', 244), ('0005', 244), ('0003', 244), ('0021', 244), ('0028', 244), ('0015', 183), ('0016', 183), ('0006', 24)]
Loading images...

===== Training fold 1/5 =====
Training samples: 1610, Validation samples: 488
Training class distr

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/30


  self._warn_if_super_not_called()


[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 608ms/step - accuracy: 0.5776 - loss: 1.0217 - val_accuracy: 0.3791 - val_loss: 0.9979 - learning_rate: 1.0000e-04
Epoch 2/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 723ms/step - accuracy: 0.7156 - loss: 0.7900 - val_accuracy: 0.5574 - val_loss: 0.9643 - learning_rate: 1.0000e-04
Epoch 3/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 576ms/step - accuracy: 0.8077 - loss: 0.6131 - val_accuracy: 0.5000 - val_loss: 1.5302 - learning_rate: 1.0000e-04
Epoch 4/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 807ms/step - accuracy: 0.8255 - loss: 0.5850 - val_accuracy: 0.4529 - val_loss: 1.1042 - learning_rate: 1.0000e-04
Epoch 5/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 1s/step - accuracy: 0.8636 - loss: 0.4970 - val_accuracy: 0.4877 - val_loss: 1.2187 - learning_rate: 1.0000e-04
Epoch 6/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m



Fold 1 validation accuracy: 0.5574

===== Training fold 2/5 =====
Training samples: 1610, Validation samples: 488
Training class distribution: [756 854]
Validation class distribution: [244 244]
Training subjects: 8, Validation subjects: 2


Epoch 1/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m56s[0m 1s/step - accuracy: 0.5180 - loss: 1.0270 - val_accuracy: 0.4365 - val_loss: 0.9149 - learning_rate: 1.0000e-04
Epoch 2/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m56s[0m 1s/step - accuracy: 0.7049 - loss: 0.8198 - val_accuracy: 0.5000 - val_loss: 0.9583 - learning_rate: 1.0000e-04
Epoch 3/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m53s[0m 1s/step - accuracy: 0.7148 - loss: 0.7370 - val_accuracy: 0.5471 - val_loss: 0.9080 - learning_rate: 1.0000e-04
Epoch 4/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m53s[0m 1s/step - accuracy: 0.7787 - loss: 0.6496 - val_accuracy: 0.5000 - val_loss: 1.1159 - learning_rate: 1.0000e-04
Epoch 5/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 821ms/step - accuracy: 0.7991 - loss: 0.6087 - val_accuracy: 0.8525 - val_loss: 0.5899 - learning_rate: 1.0000e-04
Epoch 6/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[3



Fold 2 validation accuracy: 0.8525

===== Training fold 3/5 =====
Training samples: 1671, Validation samples: 427
Training class distribution: [756 915]
Validation class distribution: [244 183]
Training subjects: 8, Validation subjects: 2


Epoch 1/30
[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m52s[0m 969ms/step - accuracy: 0.5027 - loss: 1.0259 - val_accuracy: 0.4286 - val_loss: 0.9102 - learning_rate: 1.0000e-04
Epoch 2/30
[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 897ms/step - accuracy: 0.5626 - loss: 0.8703 - val_accuracy: 0.6557 - val_loss: 0.8099 - learning_rate: 1.0000e-04
Epoch 3/30
[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m52s[0m 980ms/step - accuracy: 0.6123 - loss: 0.8073 - val_accuracy: 0.7775 - val_loss: 0.7386 - learning_rate: 1.0000e-04
Epoch 4/30
[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 893ms/step - accuracy: 0.6288 - loss: 0.7583 - val_accuracy: 0.7190 - val_loss: 0.6853 - learning_rate: 1.0000e-04
Epoch 5/30
[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 895ms/step - accuracy: 0.6734 - loss: 0.7006 - val_accuracy: 0.8712 - val_loss: 0.6191 - learning_rate: 1.0000e-04
Epoch 6/30
[1m53/53[0m [32m━━━━━━━━━━━━━━━



Fold 3 validation accuracy: 1.0000

===== Training fold 4/5 =====
Training samples: 1671, Validation samples: 427
Training class distribution: [756 915]
Validation class distribution: [244 183]
Training subjects: 8, Validation subjects: 2


Epoch 1/30
[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 1s/step - accuracy: 0.5195 - loss: 1.0377 - val_accuracy: 0.4286 - val_loss: 0.9368 - learning_rate: 1.0000e-04
Epoch 2/30
[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.6041 - loss: 0.8738

KeyboardInterrupt: 

Okay, this is encouraging because our model is at least learning. Since our dataset is already so small, I think the fact that the Moderate Dementia class has only two unique subjects heavily skews the learning process.

We have 5 unique non-demented subjects and 10 demented subjects (excluding moderate dementia), so that's 5 subjects for each class. How about we keep all 15 unique subjects and, for the demented patients, limit the number of images per subject so that we end up with roughly a 1:1 ratio of demented to non-demented images?

Even if this experiment fails, I still think that having only 2 subjects for moderate dementia distorts our data distribution and the model's learning process.

In [None]:
def _load_category_images(dataset_path, folder, display_name,
                          binary_label, detailed_label, subject_images):
    """
    Scans one class folder and files every image under its subject ID.

    Args:
        dataset_path: Root directory of the dataset.
        folder: Name of the class sub-directory inside dataset_path.
        display_name: Class name used in the progress messages.
        binary_label: 0 for Non-Demented, 1 for any type of Dementia.
        detailed_label: Fine-grained class label stored with each image.
        subject_images: Dict {subject_id: [(image_path, binary_label,
            detailed_label), ...]} that is updated in place.

    Raises:
        OSError: If the class folder cannot be listed (e.g. it does not exist).
    """
    category_path = os.path.join(dataset_path, folder)
    print(f"Loading {display_name} images from: {category_path}")
    # Sort the listing: os.listdir returns entries in arbitrary order, which
    # would make any later (seeded) sub-sampling irreproducible.
    filenames = sorted(os.listdir(category_path))
    print(f"Found {len(filenames)} {display_name} files")

    for image_filename in filenames:
        image_path = os.path.join(category_path, image_filename)
        if not os.path.isfile(image_path):
            continue

        # The subject ID is the token that follows 'OAS1_' in the filename.
        # Names without that marker (e.g. '.DS_Store') raise IndexError on [1].
        try:
            subject_id = image_filename.split('OAS1_')[1].split('_')[0]
        except IndexError:
            subject_id = ''
        if not subject_id:
            print(f"Skipping file with invalid format: {image_filename}")
            continue

        subject_images.setdefault(subject_id, []).append(
            (image_path, binary_label, detailed_label)
        )


def prepare_balanced_data_by_image_count(dataset_path='./OASIS Data', seed=None):
    """
    Prepares data with all unique subjects but balances the number of images
    per class by limiting images per subject in the larger class.

    Moderate Dementia is deliberately excluded. Every Non-Demented image is
    kept; each dementia subject is capped at
    ``non_demented_images // dementia_subjects`` images. Subjects that have
    fewer images than this cap contribute everything they have, so the
    dementia class can end up somewhat smaller than the Non-Demented class.

    A subject is assigned to the binary class that most of its images carry;
    each image nevertheless keeps its own labels in the returned arrays.

    Args:
        dataset_path: Root directory that contains the 'Non Demented',
            'Mild Dementia' and 'Very mild Dementia' folders.
        seed: Optional seed for the per-subject sub-sampling. With ``None``
            the global ``random`` state is used (the original behaviour).

    Returns:
        image_paths: list of image file paths.
        binary_labels: integer array, 0 for Non-Demented, 1 for Dementia.
        detailed_labels: integer array, 0: Non-Demented, 1: Mild, 3: Very Mild.
        subjects: list of subject IDs, aligned with image_paths.

    Raises:
        OSError: If one of the class folders cannot be listed.
    """
    # {subject_id: [(image_path, binary_label, detailed_label), ...]}
    subject_images = {}

    print(f"Accessing dataset from: {dataset_path}")

    # Non-Demented images, then the dementia classes (excluding Moderate Dementia)
    _load_category_images(dataset_path, 'Non Demented', 'Non-Demented',
                          0, 0, subject_images)
    dementia_classes = {
        'Mild Dementia': 1,
        'Very mild Dementia': 3
    }
    for category, label in dementia_classes.items():
        _load_category_images(dataset_path, category, category,
                              1, label, subject_images)

    # Assign every subject to the binary class most of its images belong to
    subject_class = {
        subject_id: Counter(img[1] for img in images).most_common(1)[0][0]
        for subject_id, images in subject_images.items()
    }
    class0_subjects = [s for s, c in subject_class.items() if c == 0]
    class1_subjects = [s for s, c in subject_class.items() if c == 1]

    print(f"Found {len(class0_subjects)} subjects in class 0 (Non-Demented)")
    print(f"Found {len(class1_subjects)} subjects in class 1 (Dementia)")

    # Count total images per class
    class0_images = sum(len(subject_images[s]) for s in class0_subjects)
    class1_images = sum(len(subject_images[s]) for s in class1_subjects)

    print(f"Found {class0_images} images in class 0 (Non-Demented)")
    print(f"Found {class1_images} images in class 1 (Dementia)")

    # Per-subject cap for the dementia class; guard against an empty class
    # instead of failing with ZeroDivisionError.
    if class1_subjects:
        target_images_per_dementia_subject = class0_images // len(class1_subjects)
        print(f"Limiting to {target_images_per_dementia_subject} images per dementia subject to balance classes")
    else:
        target_images_per_dementia_subject = 0
        print("No dementia subjects found; nothing to balance")

    # A private generator keeps seeded runs reproducible without touching the
    # global random state; without a seed the global state is used as before.
    sampler = random if seed is None else random.Random(seed)

    image_paths = []
    binary_labels = []
    detailed_labels = []
    subjects = []

    def _append(subject_id, entries):
        for img_path, binary_label, detailed_label in entries:
            image_paths.append(img_path)
            binary_labels.append(binary_label)
            detailed_labels.append(detailed_label)
            subjects.append(subject_id)

    # All Non-Demented images are kept
    for subject_id in class0_subjects:
        _append(subject_id, subject_images[subject_id])

    # Dementia subjects are randomly sub-sampled when they exceed the cap
    for subject_id in class1_subjects:
        subject_image_list = subject_images[subject_id]
        if len(subject_image_list) > target_images_per_dementia_subject:
            subject_image_list = sampler.sample(
                subject_image_list, target_images_per_dementia_subject
            )
        _append(subject_id, subject_image_list)

    # Explicit integer dtype: an empty list would otherwise become a float
    # array, which np.bincount rejects.
    binary_labels = np.array(binary_labels, dtype=int)
    detailed_labels = np.array(detailed_labels, dtype=int)

    # Print statistics
    print(f"Total images: {len(image_paths)}")
    print(f"Total subjects: {len(set(subjects))}")
    print(f"Binary label distribution: {np.bincount(binary_labels)}")

    # Print detailed statistics
    subject_counter = Counter(subjects)
    print(f"Number of unique subjects: {len(subject_counter)}")
    if subject_counter:
        print(f"Average images per subject: {len(subjects) / len(subject_counter):.1f}")
    print(f"Images per subject: {dict(subject_counter)}")

    return image_paths, binary_labels, detailed_labels, subjects