# 2. 3D Preprocessing of AD and CN Data

This notebook takes the organized AD and CN neuroimages and performs the necessary 3D preprocessing steps. The pipeline for each image is as follows:

1.  **Patient Limiting**: Limits the dataset to a maximum of 400 AD subjects and 400 CN subjects (if available) to balance computational efficiency with dataset size.
2.  **Resizing**: Each 3D image is resized to a standard size of (100, 100, 90) using trilinear interpolation for computational efficiency.
3.  **Normalization**: The intensity values are normalized to a [0, 1] range.
4.  **Data Splitting**: The processed images are split into training, validation, and test sets (85/10/5 split) at the subject level to prevent data leakage.
5.  **Saving**: The final processed data (images and labels) for each set is saved as a pickled dictionary containing PyTorch tensors.

In [None]:
import torch
import numpy as np
import nibabel as nib
import os
from skimage.filters import threshold_otsu
import cc3d
from pathlib import Path
from tqdm.notebook import tqdm
import pickle
import random

### Define Paths and Parameters

In [None]:
# Root directories: inputs are read from base_path, results are written under output_path
base_path = Path("PATH_TO_DATA")
output_path = Path("PATH_TO_DATA")
output_path.mkdir(parents=True, exist_ok=True)

# Per-class input directories
ad_path = base_path / "ad_stripped"
cn_path = base_path / "cn_stripped"

# Report whether each input directory is present and how many NIfTI files it holds
ad_path_found = ad_path.exists()
print(f"AD path: {ad_path}")
print(f"AD path exists: {ad_path_found}")
if ad_path_found:
    ad_files_count = sum(1 for _ in ad_path.glob('*.nii*'))
    print(f"AD .nii files found: {ad_files_count}")

cn_path_found = cn_path.exists()
print(f"\nCN path: {cn_path}")
print(f"CN path exists: {cn_path_found}")
if cn_path_found:
    cn_files_count = sum(1 for _ in cn_path.glob('*.nii*'))
    print(f"CN .nii files found: {cn_files_count}")

# One output directory per dataset split
train_path, val_path, test_path = (
    output_path / split_name for split_name in ("train", "val", "test")
)
for split_dir in (train_path, val_path, test_path):
    split_dir.mkdir(exist_ok=True)

### Preprocessing Functions

In [None]:
import nibabel as nib
import numpy as np
import torch
import pickle
from scipy.ndimage import zoom


def resize_volume(volume, target_shape=(100, 100, 90)):
    """
    Resample a 3D volume onto a grid of the requested shape.

    Each axis is scaled independently by target size / current size, and the
    data is interpolated linearly (spline order 1, i.e. trilinear in 3D).

    Args:
        volume: 3D numpy array
        target_shape: tuple with the desired size of each of the three axes

    Returns:
        The resampled 3D numpy array
    """
    source_shape = volume.shape
    # One scale factor per axis; only the first three axes are considered
    scale = [target_shape[axis] / source_shape[axis] for axis in (0, 1, 2)]
    resized = zoom(volume, scale, order=1)
    return resized

def preprocess_image(image_path, target_shape=(100, 100, 90)):
    """
    Load a NIfTI image, resize it to `target_shape`, and min-max normalize it to [0, 1].

    Processing steps:
      1. Load the voxel data with nibabel.
      2. For 4D images keep only the first volume; reject anything that is
         neither 3D nor 4D.
      3. Replace non-finite voxels (NaN / +-inf) with 0 so that a few bad
         voxels do not poison the min/max used for normalization.
      4. Resize with `resize_volume` (linear interpolation).
      5. Rescale intensities to [0, 1]; a constant image becomes all zeros.

    Args:
        image_path: path to the NIfTI file (str or pathlib.Path)
        target_shape: output size of the three spatial axes; defaults to
            (100, 100, 90)

    Returns:
        float32 numpy array of shape `target_shape`, or None if the image
        could not be processed (the reason is printed).
    """
    try:
        image_path = Path(image_path)  # accept plain strings as well as Path objects

        # Load the NIfTI image
        input_img = nib.load(image_path)
        img_data = input_img.get_fdata()

        # Handle 4D images by taking the first volume
        if img_data.ndim == 4:
            print(f"Note: 4D image detected, using first volume: {image_path.name}")
            img_data = img_data[..., 0]
        elif img_data.ndim != 3:
            print(f"Error: Unsupported image dimensions {img_data.shape}: {image_path.name}")
            return None

        # Without this, a single NaN makes min()/max() NaN, the range check
        # below fails, and the whole scan would silently become zeros.
        finite_mask = np.isfinite(img_data)
        if not finite_mask.all():
            img_data = np.where(finite_mask, img_data, 0.0)

        # Resize to the target dimensions for computational efficiency
        img_data = resize_volume(img_data, target_shape=target_shape)

        # Intensity normalization to [0,1]
        img_min, img_max = img_data.min(), img_data.max()
        if img_max > img_min:  # avoid divide-by-zero
            img_data = (img_data - img_min) / (img_max - img_min)
        else:
            img_data = np.zeros_like(img_data)

        return img_data.astype(np.float32)

    except Exception as e:
        # Deliberate best-effort: one unreadable file must not abort the whole
        # batch, so the failure is reported and the caller skips this image.
        print(f"Error processing {image_path}: {e}")
        return None


def process_and_save(image_paths, labels, output_file):
    """
    Preprocess a list of images and save them as a pickled dictionary of PyTorch tensors.

    Every path is run through `preprocess_image`; images for which it returns
    None are skipped together with their label. Each remaining volume gets a
    leading channel axis so the saved batch has one channel per image.

    Args:
        image_paths: sequence of image paths to process
        labels: sequence of labels, indexed in parallel with `image_paths`
        output_file: destination of the pickle file (str or pathlib.Path)

    The pickle contains {"images": tensor, "labels": tensor}. If no image
    could be processed, nothing is written and a warning is printed.
    """
    output_file = Path(output_file)  # accept plain strings as well as Path objects
    processed_images = []
    valid_labels = []

    for i, image_path in enumerate(tqdm(image_paths, desc=f"Processing {output_file.name}")):
        processed = preprocess_image(image_path)
        if processed is not None:
            processed_images.append(np.expand_dims(processed, axis=0))  # add channel axis
            valid_labels.append(labels[i])

    if not processed_images:
        # Make the failure visible instead of silently leaving no output file
        print(f"Warning: no images could be processed; nothing saved to {output_file}")
        return

    # Stack once and share the buffer with torch: torch.tensor() would make
    # another full copy of the dataset.
    data = {
        "images": torch.from_numpy(np.stack(processed_images)),
        "labels": torch.tensor(valid_labels)
    }
    with open(output_file, 'wb') as f:
        pickle.dump(data, f)
    print(f"Saved {len(processed_images)} images to {output_file}")


### Data Splitting and Processing

In [None]:

# Collect the NIfTI files of both classes; they are grouped by subject further below
print("Loading and organizing files by subject...")

# '*.nii*' matches both .nii and .nii.gz; sorting keeps the order reproducible
ad_files = sorted(ad_path.glob('*.nii*'))
print(f"Found {len(ad_files)} AD files")

cn_files = sorted(cn_path.glob('*.nii*'))
print(f"Found {len(cn_files)} CN files")

total_files = len(ad_files) + len(cn_files)
if total_files:
    print(f"✓ Total files found: {total_files}")
else:
    # Neither directory contained any image
    print("❌ ERROR: No .nii files found! Make sure you've run the organization notebook first.")
    print("Expected files in:")
    print(f"  - {ad_path}")
    print(f"  - {cn_path}")
# Extract subject IDs and group files by subject to prevent data leakage
def extract_subject_id(filename):
    """
    Extract the subject ID from an image filename (format: SUBJECT_ID_*.nii[.gz]).

    The ID is the first three underscore-separated fields of the base name,
    e.g. "002_S_0295_scan1.nii.gz" -> "002_S_0295". Names with fewer than
    three fields are returned whole (without the extension).

    Args:
        filename: path of the image file (str or pathlib.Path)

    Returns:
        The subject ID as a string.
    """
    path = Path(filename)
    name = path.name
    # Path.stem only removes the last suffix, so "x.nii.gz" would keep ".nii"
    # and the same subject could end up with two different IDs. Strip the
    # full NIfTI extension explicitly.
    for extension in ('.nii.gz', '.nii'):
        if name.lower().endswith(extension):
            base = name[:-len(extension)]
            break
    else:
        base = path.stem

    parts = base.split('_')
    if len(parts) >= 3:
        return f"{parts[0]}_{parts[1]}_{parts[2]}"  # e.g., "002_S_0295"
    return base

# Group files by subject - handle subjects with multiple diagnoses.
# subject_files maps subject ID -> {'files': [...], 'labels': [...]}, where the
# two lists are parallel (one label per file).
subject_files = {}

# Process AD files
for file_path in ad_files:
    subject_id = extract_subject_id(file_path)
    if subject_id not in subject_files:
        subject_files[subject_id] = {'files': [], 'labels': []}
    subject_files[subject_id]['files'].append(file_path)
    subject_files[subject_id]['labels'].append(1)  # AD = 1

# Process CN files
for file_path in cn_files:
    subject_id = extract_subject_id(file_path)
    if subject_id not in subject_files:
        subject_files[subject_id] = {'files': [], 'labels': []}
    subject_files[subject_id]['files'].append(file_path)
    subject_files[subject_id]['labels'].append(0)  # CN = 0

print(f"\nFound {len(subject_files)} unique subjects")

# Cap on the number of subjects kept per class
MAX_AD_SUBJECTS = 400
MAX_CN_SUBJECTS = 400

# Separate subjects by diagnosis: a subject with at least one AD-labelled scan
# counts as AD (and keeps all of its scans, including CN-labelled ones); CN
# subjects are those with CN-labelled scans only, so the two groups are disjoint.
ad_subjects = {subj: data for subj, data in subject_files.items() if 1 in data['labels']}
cn_subjects = {subj: data for subj, data in subject_files.items() if 0 in data['labels'] and 1 not in data['labels']}

print(f"\nBefore limiting:")
print(f"  AD subjects: {len(ad_subjects)}")
print(f"  CN subjects: {len(cn_subjects)}")

# Randomly select subjects to keep the limit (fixed seed for reproducibility)
random.seed(42)

# A class is only subsampled when it exceeds its cap
if len(ad_subjects) > MAX_AD_SUBJECTS:
    ad_subject_ids = list(ad_subjects.keys())
    random.shuffle(ad_subject_ids)
    ad_subject_ids = ad_subject_ids[:MAX_AD_SUBJECTS]
    ad_subjects = {subj: ad_subjects[subj] for subj in ad_subject_ids}
    print(f"  → Limited AD subjects to {len(ad_subjects)}")

if len(cn_subjects) > MAX_CN_SUBJECTS:
    cn_subject_ids = list(cn_subjects.keys())
    random.shuffle(cn_subject_ids)
    cn_subject_ids = cn_subject_ids[:MAX_CN_SUBJECTS]
    cn_subjects = {subj: cn_subjects[subj] for subj in cn_subject_ids}
    print(f"  → Limited CN subjects to {len(cn_subjects)}")

# Combine the limited subject sets
subject_files = {**ad_subjects, **cn_subjects}
print(f"\nAfter limiting: {len(subject_files)} total subjects")

# Show subjects with multiple diagnoses (this is expected in longitudinal studies)
multi_diagnosis_subjects = {subj: data for subj, data in subject_files.items() 
                           if len(set(data['labels'])) > 1}
print(f"Subjects with multiple diagnoses: {len(multi_diagnosis_subjects)}")
if multi_diagnosis_subjects:
    print("This is expected in longitudinal studies (e.g., CN → AD progression)")
    print("Top subjects with multiple diagnoses:")
    # Subjects with the most scans first; only the top five are listed
    sorted_subjects = sorted(multi_diagnosis_subjects.items(), key=lambda x: len(x[1]['files']), reverse=True)
    for subj, data in sorted_subjects[:5]:
        unique_labels = set(data['labels'])
        label_names = ['CN' if l == 0 else 'AD' for l in unique_labels]
        print(f"  {subj}: {len(data['files'])} scans (diagnoses: {', '.join(label_names)})")

# Show subjects with more than one scan (counted regardless of diagnosis)
multi_scan_subjects = {subj: data for subj, data in subject_files.items() if len(data['files']) > 1}
print(f"Subjects with multiple scans: {len(multi_scan_subjects)}")

# Split subjects (not files) into train/val/test to prevent data leakage
subject_ids = list(subject_files.keys())
random.seed(42)
random.shuffle(subject_ids)

# Split by subjects: 85% train, 10% val, and the remainder (about 5%) test.
# NOTE(review): confirm that the 85/10/5 ratio is the intended one.
total_subjects = len(subject_ids)
train_subjects = int(0.85 * total_subjects)
val_subjects = int(0.1 * total_subjects)

# Consecutive slices of the shuffled list; test receives whatever is left over
train_subject_ids = subject_ids[:train_subjects]
val_subject_ids = subject_ids[train_subjects:train_subjects + val_subjects]
test_subject_ids = subject_ids[train_subjects + val_subjects:]

print(f"\nSubject split (ensuring no data leakage):")
print(f"  Train subjects: {len(train_subject_ids)}")
print(f"  Val subjects: {len(val_subject_ids)}")
print(f"  Test subjects: {len(test_subject_ids)}")

# Verify no overlap between splits
train_set = set(train_subject_ids)
val_set = set(val_subject_ids)
test_set = set(test_subject_ids)
# The `or` chain yields the first non-empty pairwise intersection (or an empty set)
overlaps = train_set.intersection(val_set) or train_set.intersection(test_set) or val_set.intersection(test_set)
if overlaps:
    print(f"❌ ERROR: Data leakage detected! Overlapping subjects: {overlaps}")
else:
    print("✓ No data leakage: All subject splits are mutually exclusive")
# Create file lists from subject splits
def create_file_lists(subject_ids):
    """
    Flatten the per-subject entries of `subject_files` into parallel lists.

    Args:
        subject_ids: iterable of subject IDs present in `subject_files`

    Returns:
        (files, labels): every file of the given subjects, in subject order,
        and the label recorded for each of those files.
    """
    files, labels = [], []
    for sid in subject_ids:
        entry = subject_files[sid]
        files += entry['files']
        labels += entry['labels']  # per-file labels, not one label per subject
    return files, labels

# Expand each subject split into its list of files and per-file labels
train_files, train_labels = create_file_lists(train_subject_ids)
val_files, val_labels = create_file_lists(val_subject_ids)
test_files, test_labels = create_file_lists(test_subject_ids)

# Labels are 1 for AD and 0 for CN, so their sum is the AD count
print(f"\nFile split:")
for split_name, split_files, split_labels in (
    ("Train", train_files, train_labels),
    ("Val", val_files, val_labels),
    ("Test", test_files, test_labels),
):
    ad_count = sum(split_labels)
    cn_count = len(split_labels) - ad_count
    print(f"  {split_name} files: {len(split_files)} (AD: {ad_count}, CN: {cn_count})")

# Preprocess and write one pickle per split; skipped when there is no training data
if not train_files:
    print("❌ No files to process. Please check your data organization.")
else:
    print("\nProcessing datasets...")
    process_and_save(train_files, train_labels, train_path / "ad_cn_train.pkl")
    process_and_save(val_files, val_labels, val_path / "ad_cn_val.pkl")
    process_and_save(test_files, test_labels, test_path / "ad_cn_test.pkl")
    print("✓ All datasets have been processed and saved with no data leakage!")

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def visualize_processed_image(image_tensor, title="Processed Image", slice_indices=None):
    """
    Visualize one slice along each axis of a processed 3D image.

    Args:
        image_tensor: numpy array or PyTorch tensor, either 3D
            (height, width, depth) or 4D with a leading channel axis
            (channels, height, width, depth); for 4D input only the first
            channel is shown
        title: Title for the plot
        slice_indices: Slice index to show along each of the three axes.
            If None, uses the middle slices.

    Raises:
        ValueError: if the data is not 3D after removing the channel axis.
    """
    # Convert to numpy if it's a tensor. detach()/cpu() make this work for
    # tensors that require grad or live on another device as well.
    if hasattr(image_tensor, 'detach'):
        image_data = image_tensor.detach().cpu().numpy()
    elif hasattr(image_tensor, 'numpy'):
        image_data = image_tensor.numpy()
    else:
        image_data = image_tensor

    # Remove channel dimension if present (should be shape (1, H, W, D))
    if image_data.ndim == 4:
        image_data = image_data[0]

    if image_data.ndim != 3:
        raise ValueError(f"Expected a 3D image, got shape {image_data.shape}")

    # Get dimensions
    h, w, d = image_data.shape

    # Set default slice indices if not provided
    if slice_indices is None:
        slice_indices = [h // 2, w // 2, d // 2]  # Middle slices

    # One entry per panel: (view name, slice index, 2D slice, x label, y label).
    # imshow draws array rows along the y-axis and columns along the x-axis,
    # so the y label names the slice's first axis and the x label its second.
    # NOTE(review): the anatomical view names assume a particular axis order
    # of the stored volumes - confirm against the image orientation.
    views = (
        ('Axial', slice_indices[0], image_data[slice_indices[0], :, :], 'Depth', 'Width'),
        ('Sagittal', slice_indices[1], image_data[:, slice_indices[1], :], 'Depth', 'Height'),
        ('Coronal', slice_indices[2], image_data[:, :, slice_indices[2]], 'Width', 'Height'),
    )

    # Create figure with one subplot per view
    fig = plt.figure(figsize=(15, 5))
    gs = gridspec.GridSpec(1, 3, figure=fig)

    for column, (view_name, index, plane, x_label, y_label) in enumerate(views):
        ax = fig.add_subplot(gs[column])
        im = ax.imshow(plane, cmap='gray', aspect='equal')
        ax.set_title(f'{view_name} View (slice {index})')
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.suptitle(f'{title}\nShape: {image_data.shape}, Range: [{image_data.min():.3f}, {image_data.max():.3f}]')
    plt.tight_layout()
    plt.show()

def load_and_visualize_dataset(dataset_path, num_samples=3):
    """
    Read a processed dataset pickle, print a summary, and plot the first samples.

    Args:
        dataset_path: Path to the pickle file holding a dict with the keys
            'images' and 'labels'
        num_samples: Upper bound on the number of leading samples to plot

    Returns:
        (images, labels) exactly as stored in the pickle.

    Note: pickle.load executes arbitrary code from the file, so only open
    datasets you created yourself.
    """
    print(f"Loading dataset from: {dataset_path}")

    with open(dataset_path, 'rb') as handle:
        payload = pickle.load(handle)
    images, labels = payload['images'], payload['labels']

    # Labels are 1 for AD and 0 for CN, so their sum is the AD count
    ad_count = torch.sum(labels).item()
    cn_count = len(labels) - ad_count
    low, high = images.min().item(), images.max().item()

    print(f"Dataset loaded:")
    print(f"  Images shape: {images.shape}")
    print(f"  Labels shape: {labels.shape}")
    print(f"  Label distribution: AD={ad_count}, CN={cn_count}")
    print(f"  Image data type: {images.dtype}")
    print(f"  Image value range: [{low:.3f}, {high:.3f}]")

    # Plot the leading samples, titled with their class
    shown = min(num_samples, len(images))
    for idx in range(shown):
        label_value = labels[idx].item()
        label_name = "CN" if label_value != 1 else "AD"
        visualize_processed_image(
            images[idx],
            title=f"Sample {idx+1} - {label_name} (label: {label_value})"
        )

    return images, labels


In [None]:
# Plot the first samples of the saved training set, if it has been generated
train_dataset_path = train_path / "ad_cn_train.pkl"

if not train_dataset_path.exists():
    print(f"Training dataset not found at: {train_dataset_path}")
    print("Please run the data processing cells first.")
else:
    print("Loading and visualizing training dataset...")
    images, labels = load_and_visualize_dataset(train_dataset_path, num_samples=10)


In [None]:
# Let's also visualize an original image before processing for comparison
def visualize_original_vs_processed(original_path, processed_tensor, title_prefix=""):
    """
    Plot the middle slices of an original image above those of its processed version.

    Args:
        original_path: path of the original NIfTI file
        processed_tensor: processed image as numpy array or PyTorch tensor,
            3D or 4D with a leading channel axis
        title_prefix: first line of the figure title
    """
    # Original volume; a 4D image is reduced to its first volume
    original_data = nib.load(original_path).get_fdata()
    if original_data.ndim == 4:
        original_data = original_data[:, :, :, 0]

    # Processed volume as numpy, without the channel axis
    processed_data = processed_tensor.numpy() if hasattr(processed_tensor, 'numpy') else processed_tensor
    if processed_data.ndim == 4:
        processed_data = processed_data[0]

    # Top row: original, bottom row: processed; one column per view
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    rows = (("Original", original_data), ("Processed", processed_data))
    for row, (tag, volume) in enumerate(rows):
        size_h, size_w, size_d = volume.shape
        mid_h, mid_w, mid_d = size_h // 2, size_w // 2, size_d // 2
        panels = (
            ("Axial", mid_h, volume[mid_h, :, :]),
            ("Sagittal", mid_w, volume[:, mid_w, :]),
            ("Coronal", mid_d, volume[:, :, mid_d]),
        )
        for col, (view, index, plane) in enumerate(panels):
            panel = axes[row, col]
            panel.imshow(plane, cmap='gray')
            panel.set_title(f'{tag} - {view} (slice {index})')
            panel.axis('off')

    # Figure title summarizing shape and value range before and after processing
    original_summary = f'{original_data.shape} [{original_data.min():.3f}, {original_data.max():.3f}]'
    processed_summary = f'{processed_data.shape} [{processed_data.min():.3f}, {processed_data.max():.3f}]'
    plt.suptitle(f'{title_prefix}\nOriginal: {original_summary} → Processed: {processed_summary}')
    plt.tight_layout()
    plt.show()

# Compare the first training file with the first stored training image.
# NOTE(review): the two only correspond if the first file was not skipped
# during preprocessing - confirm before relying on this plot.
if train_dataset_path.exists():
    print("\nComparing original vs processed image...")

    # Path of the first file assigned to the training split
    first_file = train_files[0]
    print(f"Original file: {first_file}")

    # First image of the saved training set
    with train_dataset_path.open('rb') as f:
        data = pickle.load(f)
    processed_image = data['images'][0]

    visualize_original_vs_processed(
        first_file,
        processed_image,
        "Original vs Processed Comparison",
    )
