# 5. Create k-Fold Datasets for MCI Progression

This notebook takes the separated pMCI and sMCI datasets and creates 5-fold cross-validation splits. This is a crucial step for training a robust model for predicting MCI to AD conversion.

The process is as follows:

1.  **File Shuffling**: The combined pMCI and sMCI file list is shuffled by the stratified splitter (`shuffle=True` with a fixed random seed). This keeps the folds unbiased by the original file order while making them reproducible.
2.  **Fold Creation**: The shuffled data are divided into 5 stratified folds of approximately equal size, each preserving the overall pMCI/sMCI ratio. For each of the 5 iterations (folds):
    *   One part is designated as the **validation set**.
    *   The remaining four parts are combined to form the **training set**.
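3.  **Preprocessing**: Each image in the training and validation sets is resized to 100×100×90 voxels with trilinear interpolation and min-max normalized to [0, 1]. The inputs are read from the `*_stripped` directories and are expected to be skull-stripped already; no additional skull stripping or brain masking is applied in this notebook.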
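4.  **Saving**: The processed images, their labels, and their subject IDs for each training and validation set are saved as individual pickled dictionary files. This results in 10 files (5 for training, 5 for validation). Images and labels are stored as PyTorch tensors and subject IDs as a NumPy array. A `kfold_info.pkl` file describing the split is saved alongside them.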

In [None]:
import torch
import numpy as np
import os
import random
from pathlib import Path
import pickle
from tqdm.notebook import tqdm
import nibabel as nib
from sklearn.model_selection import StratifiedKFold
from scipy.ndimage import zoom

### Define Paths and Parameters

In [None]:
# Root of the MCI data that has already been separated by diagnosis group
mci_split_path = Path("PATH_TO_DATA")
# Skull-stripped scans of progressive MCI (converters) and stable MCI subjects
pmci_path = mci_split_path / "pMCI_stripped"
smci_path = mci_split_path / "sMCI_stripped"

# Destination for the per-fold pickles; created (with parents) when missing
kfold_output_path = Path("./data/processed/kfold_fdg_final_run/")
kfold_output_path.mkdir(exist_ok=True, parents=True)

# Cross-validation configuration
NUM_FOLDS = 5
RANDOM_SEED = 42

# Seed every random number generator used in this notebook so that the
# fold assignment and the random sample picked for visualization repeat.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

### Preprocessing Functions

In [None]:

import cc3d 
from skimage.filters import threshold_otsu
from scipy.ndimage import binary_erosion, binary_dilation, binary_fill_holes, zoom


def resize_volume(volume, target_shape=(100, 100, 90)):
    """
    Resample a 3D volume onto a grid of ``target_shape`` voxels.

    Uses ``scipy.ndimage.zoom`` with an order-1 spline, i.e. trilinear
    interpolation. Only the first three axes are rescaled.

    Args:
        volume: 3D numpy array to resample.
        target_shape: desired output dimensions as (height, width, depth).

    Returns:
        The resampled 3D numpy array.
    """
    # Independent scale factor for each spatial axis
    height_scale = target_shape[0] / volume.shape[0]
    width_scale = target_shape[1] / volume.shape[1]
    depth_scale = target_shape[2] / volume.shape[2]
    return zoom(volume, (height_scale, width_scale, depth_scale), order=1)


def preprocess_image(image_path, use_brain_mask=True, target_shape=(100, 100, 90)):
    """
    Load a NIfTI image, resize it to ``target_shape`` and min-max normalize it to [0, 1].

    No brain masking or skull stripping is performed here. In this notebook the
    inputs come from the ``*_stripped`` directories and are presumably already
    skull-stripped (TODO confirm upstream). ``use_brain_mask`` is accepted only
    for backward compatibility and is currently ignored.

    Args:
        image_path: Path (or path string) to a NIfTI file.
        use_brain_mask: Unused; kept so existing callers keep working.
        target_shape: Output volume dimensions (height, width, depth).
            Defaults to (100, 100, 90).

    Returns:
        A float32 numpy array of shape ``target_shape`` with values in [0, 1]
        (all zeros for a constant image). Returns None if the file cannot be
        loaded or has unsupported dimensions; the reason is printed.
    """
    try:
        # Accept plain strings too: `.name` is used in the messages below, and
        # an AttributeError there would wrongly discard a valid image.
        image_path = Path(image_path)

        # Load the NIfTI image
        input_img = nib.load(image_path)
        img_data = input_img.get_fdata()

        # Handle 4D images by taking the first volume
        if img_data.ndim == 4:
            print(f"Note: 4D image detected, using first volume: {image_path.name}")
            img_data = img_data[..., 0]
        elif img_data.ndim != 3:
            print(f"Error: Unsupported image dimensions {img_data.shape}: {image_path.name}")
            return None

        # Non-finite voxels, if present, would spread through interpolation and
        # make min/max NaN, turning the whole normalized volume into NaN.
        img_data = np.nan_to_num(img_data, nan=0.0, posinf=0.0, neginf=0.0)

        # STEP 1: Resize to the target dimensions (trilinear interpolation)
        img_data = resize_volume(img_data, target_shape=target_shape)

        # STEP 2: Intensity normalization to [0,1]
        img_min, img_max = img_data.min(), img_data.max()
        if img_max > img_min:  # avoid divide-by-zero
            img_data = (img_data - img_min) / (img_max - img_min)
        else:
            img_data = np.zeros_like(img_data)

        return img_data.astype(np.float32)

    except Exception as e:
        # Deliberate best-effort: report and skip this file so the batch continues
        print(f"Error processing {image_path}: {e}")
        return None

### Load and Organize Files

In [None]:
# Gather the compressed NIfTI scans of each group, sorted so the order is deterministic
def _list_nifti_files(directory):
    """Return the sorted paths of all '.nii.gz' entries directly inside *directory*."""
    return sorted(directory / entry for entry in os.listdir(directory) if entry.endswith('.nii.gz'))

pmci_files = _list_nifti_files(pmci_path)
smci_files = _list_nifti_files(smci_path)

print(f"Found {len(pmci_files)} pMCI files")
print(f"Found {len(smci_files)} sMCI files")

# Class labels: 1 marks pMCI (converter), 0 marks sMCI (stable)
pmci_labels, smci_labels = [1] * len(pmci_files), [0] * len(smci_files)

# Extract subject IDs
def extract_subject_id(filepath):
    """
    Return the subject ID (e.g. '002_S_0729') encoded at the start of a scan's file name.

    The ID is the first three underscore-separated fields of the file name
    after every extension ('.nii', '.nii.gz', ...) has been removed.

    Args:
        filepath: Path or path string of the scan.

    Returns:
        The subject ID string.

    Raises:
        ValueError: if the base name has fewer than three underscore-separated fields.
    """
    file_name = Path(filepath).name  # accept str paths as well as Path objects
    parts = file_name.split('.')[0].split('_')  # drop all extensions, then split fields
    if len(parts) < 3:
        raise ValueError(f"Cannot extract subject ID from file name: {file_name!r}")
    return '_'.join(parts[:3])

pmci_subjects = list(map(extract_subject_id, pmci_files))
smci_subjects = list(map(extract_subject_id, smci_files))

# Concatenate the groups (pMCI first, then sMCI) into index-aligned numpy
# arrays, so a single array of fold indices selects matching files, labels
# and subject IDs.
all_files = np.array([*pmci_files, *smci_files])
all_labels = np.array([*pmci_labels, *smci_labels])
all_subjects = np.array([*pmci_subjects, *smci_subjects])

print(f"Total files: {len(all_files)}")
print(f"pMCI/sMCI ratio: {len(pmci_files)}/{len(smci_files)}")

### Create K-Fold Splits

In [None]:
# Create stratified k-fold splits to maintain class balance.
#
# Leakage guard: all_subjects holds one entry per scan, so a subject with
# several scans appears several times. A plain StratifiedKFold splits by scan,
# which could place scans of one subject in both the training and validation
# set of a fold and inflate validation scores. When repeats exist, split with
# StratifiedGroupKFold grouped by subject ID. Otherwise keep StratifiedKFold so
# the folds are identical to those produced by earlier runs.
_, subject_counts = np.unique(all_subjects, return_counts=True)
num_repeated_subjects = int(np.sum(subject_counts > 1))

if num_repeated_subjects:
    # Imported only when needed so the common path still works on scikit-learn < 1.0
    from sklearn.model_selection import StratifiedGroupKFold

    print(f"{num_repeated_subjects} subject(s) have more than one scan; "
          "using subject-grouped stratified folds to prevent leakage.")
    kfold = StratifiedGroupKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=RANDOM_SEED)
    fold_splits = list(kfold.split(all_files, all_labels, groups=all_subjects))
else:
    kfold = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=RANDOM_SEED)
    fold_splits = list(kfold.split(all_files, all_labels))


def process_and_save_fold(files, labels, subjects, fold_type, fold_num, images=None):
    """
    Stack one split's images and pickle them together with their labels and subject IDs.

    Writes ``kfold_output_path / f"{fold_type}_fold_{fold_num}.pkl"``. The file
    holds a dict with "images" (float32 tensor with a channel axis inserted at
    position 1), "labels" (tensor) and "subject_ids" (numpy array). Entries whose
    image is None (failed preprocessing) are skipped. If no entries remain,
    nothing is written and a warning is printed.

    Args:
        files: File paths of this split; only read when ``images`` is None.
        labels: Labels aligned with ``files``.
        subjects: Subject IDs aligned with ``files``.
        fold_type: "train" or "val"; used in the output file name.
        fold_num: 1-based fold number; used in the output file name.
        images: Optional preprocessed images aligned with ``files`` (None marks a
            failure). When omitted, each file is preprocessed here.
    """
    print(f"Processing {fold_type} data...")
    if images is None:
        images = [preprocess_image(file_path)
                  for file_path in tqdm(files, desc=f"{fold_type.capitalize()} images")]

    keep = [i for i, image in enumerate(images) if image is not None]
    if not keep:
        print(f"Warning: no valid {fold_type} images for fold {fold_num}; nothing saved.")
        return

    # Stack to (N, *spatial) and add the channel axis expected by 3D CNNs: (N, 1, *spatial).
    # torch.from_numpy avoids the extra full copy that torch.tensor would make.
    stacked = np.expand_dims(np.stack([images[i] for i in keep]), axis=1)
    data_dict = {
        "images": torch.from_numpy(stacked),
        "labels": torch.tensor([labels[i] for i in keep]),
        "subject_ids": np.array([subjects[i] for i in keep]),
    }
    file_name = kfold_output_path / f"{fold_type}_fold_{fold_num}.pkl"
    with open(file_name, 'wb') as f:
        pickle.dump(data_dict, f)
    print(f"Saved {len(keep)} {fold_type} images to {file_name}")


# Preprocess every scan once. Each scan is used in NUM_FOLDS folds (training set
# of NUM_FOLDS - 1 folds, validation set of one), so preprocessing inside the
# fold loop would repeat identical work. Failed files are kept as None.
print("Preprocessing all images...")
preprocessed_images = [preprocess_image(file_path) for file_path in tqdm(all_files, desc="All images")]

for fold_idx, (train_indices, val_indices) in enumerate(fold_splits):
    fold_num = fold_idx + 1
    print(f"\n--- Processing Fold {fold_num} ---")

    # Get train and validation sets for this fold
    train_files, train_labels, train_subjects = all_files[train_indices], all_labels[train_indices], all_subjects[train_indices]
    val_files, val_labels, val_subjects = all_files[val_indices], all_labels[val_indices], all_subjects[val_indices]

    print(f"Train: {len(train_files)} files (pMCI: {np.sum(train_labels)}, sMCI: {len(train_labels) - np.sum(train_labels)})")
    print(f"Val: {len(val_files)} files (pMCI: {np.sum(val_labels)}, sMCI: {len(val_labels) - np.sum(val_labels)})")

    # Save the current fold, reusing the cached preprocessing results
    process_and_save_fold(train_files, train_labels, train_subjects, "train", fold_num,
                          images=[preprocessed_images[i] for i in train_indices])
    process_and_save_fold(val_files, val_labels, val_subjects, "val", fold_num,
                          images=[preprocessed_images[i] for i in val_indices])

print("\nK-fold dataset creation completed!")

### Save K-Fold Information

In [None]:
# Save metadata about the k-fold splits
kfold_info = {
    "num_folds": NUM_FOLDS,
    "random_seed": RANDOM_SEED,
    "total_pmci_files": len(pmci_files),
    "total_smci_files": len(smci_files),
    "fold_files": {
        f"fold_{i+1}": {
            "train_file": f"train_fold_{i+1}.pkl",
            "val_file": f"val_fold_{i+1}.pkl"
        } for i in range(NUM_FOLDS)
    }
}

info_file = kfold_output_path / "kfold_info.pkl"
with open(info_file, 'wb') as f:
    pickle.dump(kfold_info, f)

print(f"K-fold information saved to {info_file}")

In [None]:
# Visualize original vs processed image comparison.
# Sanity check: a stored training image next to the same subject's scan read
# straight from disk and rescaled the same way for display.
import matplotlib.pyplot as plt

# Load a random fold (np.random was seeded earlier, so the pick is reproducible)
random_fold = np.random.randint(1, NUM_FOLDS + 1)
train_file = kfold_output_path / f"train_fold_{random_fold}.pkl"

# Trusted input: this pickle was written by this notebook
with open(train_file, 'rb') as f:
    train_data = pickle.load(f)

# Select a random image
random_idx = np.random.randint(0, len(train_data['images']))
processed_image = train_data['images'][random_idx].numpy()
random_label = train_data['labels'][random_idx].item()
random_subject = train_data['subject_ids'][random_idx]

# Drop the leading channel axis added for 3D CNNs
if processed_image.ndim == 4:
    if processed_image.shape[0] != 1:
        print(f"Warning: Unexpected 4D shape: {processed_image.shape}")
    processed_image = processed_image[0]

# Load the ORIGINAL unprocessed image
label_name = "pMCI" if random_label == 1 else "sMCI"
original_dir = pmci_path if random_label == 1 else smci_path

# Find the original file for this subject. Sorting makes the pick deterministic.
# The pickles store no file paths, so a subject with several scans may match a
# different scan than the one sampled; warn in that case.
candidates = sorted(f for f in original_dir.glob('*.nii.gz') if random_subject in f.name)
original_file = candidates[0] if candidates else None
if len(candidates) > 1:
    print(f"Warning: {len(candidates)} scans match subject {random_subject}; showing {candidates[0].name}")

if original_file is None:
    print(f"ERROR: Could not find original file for subject {random_subject}")
else:
    # Load original image
    original_data = nib.load(original_file).get_fdata()
    if original_data.ndim == 4:
        original_data = original_data[..., 0]
    # Avoid NaN/inf voxels turning the whole display image into NaN
    original_data = np.nan_to_num(original_data, nan=0.0, posinf=0.0, neginf=0.0)

    # Resize the original onto the processed grid (same trilinear resampling as preprocessing)
    original_resized = resize_volume(original_data, target_shape=processed_image.shape)

    # Normalize original to [0,1] for display
    orig_min, orig_max = original_resized.min(), original_resized.max()
    if orig_max > orig_min:
        original_resized = (original_resized - orig_min) / (orig_max - orig_min)

    # Middle index along the third array axis.
    # NOTE(review): whether an axis-2 slice is sagittal depends on the image
    # orientation (for RAS-oriented data it would be axial); confirm.
    width_slice = processed_image.shape[2] // 2

    # Create side-by-side comparison
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))

    # Original, resized and rescaled only for display
    axes[0].imshow(original_resized[:, :, width_slice], cmap='gray')
    axes[0].set_title('ORIGINAL (Unprocessed)\nResized for display', fontsize=12, fontweight='bold', color='red')
    axes[0].axis('off')

    # Processed image as stored in the fold (resized + min-max normalized; no masking is applied)
    axes[1].imshow(processed_image[:, :, width_slice], cmap='gray')
    axes[1].set_title('PROCESSED (Stored in fold)\nResized + Normalized', fontsize=12, fontweight='bold', color='green')
    axes[1].axis('off')

    fig.suptitle(f'Sagittal View Comparison - Subject: {random_subject} ({label_name}) - Fold {random_fold}',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()

    # Save comparison
    sagittal_output_path = kfold_output_path / "visualizations"
    sagittal_output_path.mkdir(exist_ok=True)
    comparison_filename = sagittal_output_path / f"comparison_fold{random_fold}_{random_subject}_{label_name}.png"
    plt.savefig(comparison_filename, dpi=150, bbox_inches='tight')
    print(f"Comparison saved to {comparison_filename}")

    plt.show()

    # Print statistics
    orig_nonzero = int((original_resized > 0).sum())
    proc_nonzero = int((processed_image > 0).sum())
    print(f"\nSubject ID: {random_subject}")
    print(f"Label: {label_name} ({random_label})")
    print(f"Original file: {original_file.name}")
    print(f"Original shape: {original_data.shape}")
    print(f"Processed shape: {processed_image.shape}")
    print(f"\nOriginal non-zero voxels: {orig_nonzero}")
    print(f"Processed non-zero voxels: {proc_nonzero}")
    print(f"Non-zero voxel difference (original - processed): {orig_nonzero - proc_nonzero}")
    if orig_nonzero > 0:  # guard against division by zero for an all-zero scan
        print(f"Percentage difference: {100 * (1 - proc_nonzero / orig_nonzero):.2f}%")
