Train one MLP for each selected slice

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import torch.nn.functional as F # For Global Average Pooling
import os
import gc
# Removed: from tqdm import tqdm # For progress bars

# --- Configuration ---
# Root of the multi-slice preprocessed features. Expected layout:
#   FEATURE_ROOT/<split>/<subject_dir>/slice_<orientation>_<index>.npy
FEATURE_ROOT = Path("/data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice")
# Metadata CSV; the dataset reads its 'filename' and 'age' columns.
CSV_PATH = Path("/data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv")
# Destination for the best checkpoint of each slice model.
MODELS_OUTPUT_DIR = Path("./specific_slice_models_multi_dir")
# Pins GPU index 1 when CUDA is available, otherwise falls back to the CPU.
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"
SPLITS = ["train", "validation", "test"]

# One model is trained per (orientation, index) pair below; the matching
# feature file inside each subject folder is 'slice_<orientation>_<index>.npy'
# (e.g. 'slice_sagittal_80.npy', 'slice_coronal_125.npy').
TARGET_SLICES = [
    {'orientation': orientation, 'index': index}
    for orientation, index in (
        ('sagittal', 80),
        ('sagittal', 125),
        ('coronal', 125),
        ('axial', 80),
    )
]

# --- Hyperparameters ---
EPOCHS = 500
LEARNING_RATE = 1e-4
BATCH_SIZE = 32  # lower this if GPU memory is tight
WEIGHT_DECAY = 1e-5
# LR-scheduler patience expressed as a fraction of EPOCHS (0.10 * 500 -> 50 epochs).
# NOTE(review): that equals EARLY_STOPPING_PATIENCE. ReduceLROnPlateau only lowers
# the LR after `patience + 1` non-improving epochs, so keep this well below the
# early-stopping window or the LR may never be reduced before training stops.
SCHEDULER_PATIENCE_PERCENT = 0.10
EARLY_STOPPING_PATIENCE = 50  # epochs without a val-MAE improvement before stopping
NUM_WORKERS = 4  # DataLoader worker processes

# --- Model Definition (Using AgeMLPWithAttentionBN) ---
class AgeMLPWithAttentionBN(nn.Module):
    """Regress a single age value from a pooled slice embedding.

    Pipeline: multi-head self-attention with a residual connection and
    LayerNorm, then a three-hidden-layer MLP (Linear -> BatchNorm1d -> ReLU ->
    Dropout at each hidden layer) ending in a single output unit.

    Note: ``forward`` presents the batch to attention as a sequence of length 1
    (``(1, batch, dim)`` with ``batch_first=False``), so every sample attends
    only to itself; no information is mixed across the batch.

    Args:
        input_dim: Width of the incoming feature vector. Must equal
            ``embed_dim`` because the input is fed straight into the attention
            layer and added back as a residual before reaching the MLP.
        embed_dim: Attention width.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        hidden_dim1, hidden_dim2, hidden_dim3: Widths of the MLP hidden layers.
        dropout_rate: Dropout probability after each hidden layer.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads`` or
            ``input_dim`` differs from ``embed_dim``.
    """

    def __init__(self, input_dim=256, embed_dim=256, num_heads=8,
                 hidden_dim1=128, hidden_dim2=64, hidden_dim3=32, dropout_rate=0.3):
        super(AgeMLPWithAttentionBN, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        # Fail at construction rather than with a shape error deep inside forward():
        # attention consumes the raw input, and its residual output feeds the MLP.
        if input_dim != embed_dim:
            raise ValueError(f"input_dim ({input_dim}) must equal embed_dim ({embed_dim})")
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=False)
        self.norm_attn = nn.LayerNorm(embed_dim)
        # BatchNorm1d needs more than one sample per batch in training mode.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1), nn.BatchNorm1d(hidden_dim1), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim1, hidden_dim2), nn.BatchNorm1d(hidden_dim2), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim2, hidden_dim3), nn.BatchNorm1d(hidden_dim3), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim3, 1)
        )

    def forward(self, x):
        """Predict one age per sample.

        Args:
            x: Tensor of shape ``(batch_size, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size,)``.
        """
        # (1, batch_size, input_dim): with batch_first=False this is a length-1
        # sequence per sample, i.e. each sample attends only to itself.
        x_attn = x.unsqueeze(0)
        attn_output, _ = self.attention(x_attn, x_attn, x_attn)
        x = self.norm_attn(x_attn + attn_output)  # residual + LayerNorm
        x = x.squeeze(0)  # back to (batch_size, input_dim)
        output = self.mlp(x)
        return output.squeeze(-1)  # (batch_size,)

# --- Dataset Definition (Modified for Multi-Slice Directory Structure) ---
class SliceAgePredictionDataset(Dataset):
    """(pooled slice embedding, age) pairs for one fixed slice of one split.

    Feature files are looked up as::

        feature_root/<split>/<subject_dir>/slice_<orientation>_<slice_index>.npy

    Each ``<subject_dir>`` is paired with a row of the metadata CSV (which must
    provide ``filename`` and ``age`` columns) by comparing the directory name
    with the CSV ``filename`` minus its ``.nii.gz``/``.nii`` extension; see
    ``_match_csv_filename`` for the exact rules. Subject directories are visited
    in sorted order, so the sample order is reproducible between runs.
    """

    def __init__(self, feature_root, csv_path, split, slice_index, orientation):
        """
        Args:
            feature_root (Path): Base directory containing split folders
                                 (e.g., .../BrainAGE_preprocessed_multi_slice).
            csv_path (Path): Path to the CSV file with metadata.
            split (str): The dataset split ('train', 'validation', or 'test').
            slice_index (int): The specific slice index to load features for.
            orientation (str): The orientation ('sagittal', 'coronal', 'axial').

        Raises:
            FileNotFoundError: If the feature root, split directory or CSV is missing.
            ValueError: If the CSV cannot be read/processed or the split
                directory cannot be listed.
            RuntimeError: If no subject has both metadata and the requested slice file.
        """
        self.feature_root = Path(feature_root)
        self.csv_path = Path(csv_path)
        self.split = split
        self.slice_index = slice_index
        self.orientation = orientation
        self.split_dir = self.feature_root / self.split

        print(f"\n[Dataset Init] Split: {self.split}, Orientation: {self.orientation}, Slice: {self.slice_index}")
        print(f"Scanning subjects in: {self.split_dir}")
        print(f"Loading metadata from: {self.csv_path}")

        # Check if directories/files exist
        if not self.feature_root.is_dir():
            raise FileNotFoundError(f"Feature root directory not found: {self.feature_root}")
        if not self.split_dir.is_dir():
            raise FileNotFoundError(f"Split directory not found: {self.split_dir}")
        if not self.csv_path.is_file():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        # Load the CSV, build the filename -> age lookup and map subject dirs to CSV rows.
        try:
            df = pd.read_csv(self.csv_path)
            self.meta_dict = df.set_index('filename')['age'].to_dict()
            # Listed once and sorted: reused for the slice scan below, and the
            # sort makes sample order (hence unshuffled loader order) deterministic.
            all_subject_dirs = sorted(d for d in self.split_dir.iterdir() if d.is_dir())
            self.subject_dir_to_filename = self._map_subject_dirs(d.name for d in all_subject_dirs)
        except Exception as e:
            raise ValueError(f"Error loading or processing CSV {self.csv_path} or mapping directories: {e}") from e

        mapped_count = len(self.subject_dir_to_filename)
        print(f"Loaded metadata for {len(self.meta_dict)} subjects from CSV.")
        print(f"Successfully mapped {mapped_count} subject directories in '{self.split}' to CSV entries.")
        if mapped_count < len(all_subject_dirs):
            print(f"Warning: Could not map {len(all_subject_dirs) - mapped_count} directories to CSV.")

        # Keep only subjects that have metadata AND the requested slice file.
        slice_filename = f"slice_{self.orientation}_{self.slice_index}.npy"
        self.valid_slice_files = []
        missing_meta_count = 0
        missing_slice_count = 0

        print(f"Scanning {len(all_subject_dirs)} potential subject directories for slice {self.orientation}_{self.slice_index}...")

        for subject_dir in all_subject_dirs:
            if subject_dir.name not in self.subject_dir_to_filename:
                missing_meta_count += 1
                continue
            expected_slice_path = subject_dir / slice_filename
            if expected_slice_path.is_file():
                self.valid_slice_files.append(expected_slice_path)
            else:
                missing_slice_count += 1

        if missing_meta_count > 0:
            print(f"Info: {missing_meta_count} subject directories were skipped as they couldn't be mapped to metadata.")
        if missing_slice_count > 0:
            print(f"Warning: {missing_slice_count} subjects with metadata were missing slice {self.orientation}_{self.slice_index}.")

        if not self.valid_slice_files:
            raise RuntimeError(f"No valid slice files found for {self.orientation} slice {self.slice_index} in {self.split_dir} with matching metadata.")

        print(f"Found {len(self.valid_slice_files)} valid files for {self.orientation} slice {self.slice_index} in split {self.split}.")

    @staticmethod
    def _strip_nifti_suffix(filename):
        """Return *filename* without a trailing '.nii.gz'/'.nii', or None if it has neither."""
        if not isinstance(filename, str):  # e.g. NaN from an empty CSV cell
            return None
        for suffix in (".nii.gz", ".nii"):
            if filename.endswith(suffix):
                return filename[:-len(suffix)]
        return None

    @staticmethod
    def _match_csv_filename(subj_dir_name, stem_to_filename, sorted_stems):
        """Return the CSV filename for one subject directory, or None if there is no confident match.

        Rules, in order:
          1. the directory name equals a CSV stem;
          2. the name truncated at '_mri_brainmask' equals a CSV stem;
          3. a CSV stem starts with that truncated name, preferring stems where
             the prefix ends at a non-alphanumeric separator (so 'sub-1' never
             picks 'sub-10...' while 'sub-1_...' exists); an unbounded prefix is
             accepted only when it is the sole candidate. Ties are broken by
             sorted order, so the result is deterministic.
        """
        if subj_dir_name in stem_to_filename:
            return stem_to_filename[subj_dir_name]
        base_subj_name = subj_dir_name.split('_mri_brainmask')[0]
        if not base_subj_name:  # an empty prefix would match every stem
            return None
        if base_subj_name in stem_to_filename:
            return stem_to_filename[base_subj_name]
        # Exact matches were handled above, so every candidate is strictly longer.
        candidates = [s for s in sorted_stems if s.startswith(base_subj_name)]
        at_boundary = [s for s in candidates if not s[len(base_subj_name)].isalnum()]
        if at_boundary:
            return stem_to_filename[at_boundary[0]]
        if len(candidates) == 1:
            return stem_to_filename[candidates[0]]
        return None

    def _map_subject_dirs(self, subject_dir_names):
        """Return ``{subject_dir_name: csv_filename}`` for every directory that can be matched.

        Only CSV filenames ending in '.nii.gz' or '.nii' are candidates;
        directories without a match are left out.
        """
        stem_to_filename = {}
        for filename in self.meta_dict:
            stem = self._strip_nifti_suffix(filename)
            if stem is None:
                continue
            # If both 'X.nii.gz' and 'X.nii' are listed, prefer the '.nii.gz' entry.
            if stem not in stem_to_filename or filename.endswith(".nii.gz"):
                stem_to_filename[stem] = filename
        sorted_stems = sorted(stem_to_filename)

        mapping = {}
        for subj_dir_name in subject_dir_names:
            matched = self._match_csv_filename(subj_dir_name, stem_to_filename, sorted_stems)
            if matched is not None:
                mapping[subj_dir_name] = matched
        return mapping

    def __len__(self):
        return len(self.valid_slice_files)

    def __getitem__(self, idx):
        """Return ``(pooled_embedding, age)`` for sample *idx*.

        The ``.npy`` array is reduced to a length-256 vector: arrays shaped
        ``(1, C, H, W)`` or ``(C, H, W)`` are global-average-pooled over H and W;
        1-D arrays are used as-is; any other shape raises ``ValueError``. Both
        returned tensors are float32, ``age`` being a 0-d tensor.
        """
        slice_path = self.valid_slice_files[idx]
        subject_dir_name = slice_path.parent.name
        original_filename = self.subject_dir_to_filename.get(subject_dir_name)
        if not original_filename:
            raise ValueError(f"Could not find original filename mapping for subject directory {subject_dir_name}")

        try:
            embedding_tensor = torch.tensor(np.load(slice_path), dtype=torch.float32)

            # Apply Global Average Pooling (GAP) over the spatial dimensions.
            ndim = embedding_tensor.dim()
            if ndim == 4 and embedding_tensor.shape[0] == 1:
                pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor, (1, 1)).squeeze()
            elif ndim == 3:
                pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor.unsqueeze(0), (1, 1)).squeeze()
            elif ndim == 1:
                pooled_embedding = embedding_tensor
            else:
                raise ValueError(f"Unexpected embedding shape {embedding_tensor.shape} for {slice_path}")

            # Check the rank too: a single-channel map squeezes to a 0-d tensor,
            # where shape[0] would raise an unhelpful IndexError.
            if pooled_embedding.dim() != 1 or pooled_embedding.shape[0] != 256:
                raise ValueError(f"Pooled embedding channel dimension is not 256 for {slice_path}: {pooled_embedding.shape}")

            age_tensor = torch.tensor(self.meta_dict[original_filename], dtype=torch.float32)
            return pooled_embedding, age_tensor

        except Exception as e:
            print(f"Error loading or processing file {slice_path}: {e}")
            raise

# --- Training and Evaluation Functions (Progress Bars Removed) ---
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over *loader* and return the mean training loss.

    Args:
        model: Module mapping a feature batch to predictions shaped like the targets.
        loader: Iterable of ``(features, ages)`` batches.
        criterion: Loss callable ``criterion(predictions, ages)`` returning a scalar.
            The per-batch loss is re-weighted by batch size, which assumes a
            mean-reduced loss (true for the ``nn.L1Loss()`` used in this notebook).
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        float: Sample-weighted mean of the per-batch losses, each computed
        before that batch's optimizer step.

    Raises:
        ValueError: If *loader* yields no samples (e.g. ``drop_last=True`` with
            fewer samples than one batch), instead of a bare ZeroDivisionError.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for features, ages in loader:
        # non_blocking lets pinned-memory host->GPU copies overlap with compute.
        features = features.to(device, non_blocking=True)
        ages = ages.to(device, non_blocking=True)
        optimizer.zero_grad()
        predictions = model(features)
        loss = criterion(predictions, ages)
        loss.backward()
        optimizer.step()
        batch_size = features.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples "
                         "(check dataset size against batch_size/drop_last).")
    return total_loss / num_samples

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without gradient tracking.

    Args:
        model: Module mapping a feature batch to predictions shaped like the targets.
        loader: Iterable of ``(features, ages)`` batches.
        criterion: Loss callable; its per-batch value is re-weighted by batch
            size, which assumes a mean-reduced loss.
        device: Device the batches are moved to.

    Returns:
        tuple[float, float]: ``(mean criterion loss, mean absolute error)``,
        both averaged over samples rather than batches.

    Raises:
        ValueError: If *loader* yields no samples, instead of a bare ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    total_mae = 0.0
    num_samples = 0
    for features, ages in loader:
        features = features.to(device, non_blocking=True)
        ages = ages.to(device, non_blocking=True)
        predictions = model(features)
        batch_size = features.size(0)
        total_loss += criterion(predictions, ages).item() * batch_size
        # Summed absolute error keeps the final MAE exact even when the last batch is short.
        total_mae += F.l1_loss(predictions, ages, reduction='sum').item()
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("evaluate: loader yielded no samples.")
    avg_loss = total_loss / num_samples
    avg_mae = total_mae / num_samples
    return avg_loss, avg_mae

# --- Main Training Loop for Each Specific Slice ---
if __name__ == "__main__":
    # Train one AgeMLPWithAttentionBN per entry of TARGET_SLICES, saving the
    # checkpoint with the lowest validation MAE for each slice.
    print("Starting specific slice MLP training from multi-slice directory...")
    print("Target Slices:")
    for slice_info in TARGET_SLICES:
        print(f"  - Orientation: {slice_info['orientation']}, Index: {slice_info['index']}")
    print(f"Using device: {DEVICE}")
    print(f"Feature Root: {FEATURE_ROOT}")
    print(f"CSV Path: {CSV_PATH}")
    print(f"Models will be saved to: {MODELS_OUTPUT_DIR}")
    print(f"Hyperparameters: Epochs={EPOCHS}, LR={LEARNING_RATE}, Batch={BATCH_SIZE}, ES_Patience={EARLY_STOPPING_PATIENCE}")

    MODELS_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # ReduceLROnPlateau lowers the LR only after `patience + 1` non-improving
    # epochs, while early stopping ends training after EARLY_STOPPING_PATIENCE
    # of them. If the scheduler's patience is not shorter, the LR can never be
    # reduced before training stops, so cap it at half the early-stopping window.
    scheduler_patience_epochs = max(1, int(EPOCHS * SCHEDULER_PATIENCE_PERCENT))
    if scheduler_patience_epochs >= EARLY_STOPPING_PATIENCE:
        capped_patience = max(1, EARLY_STOPPING_PATIENCE // 2)
        print(f"Note: LR scheduler patience ({scheduler_patience_epochs}) is not below "
              f"ES_Patience ({EARLY_STOPPING_PATIENCE}); using {capped_patience} so the LR can decay.")
        scheduler_patience_epochs = capped_patience

    results = {}

    for slice_info in TARGET_SLICES:
        orientation = slice_info['orientation']
        slice_idx = slice_info['index']
        slice_key = f"{orientation}_{slice_idx}"

        print(f"\n{'='*20} Training for Slice: {orientation} {slice_idx} {'='*20}")
        print(f"Using feature root: {FEATURE_ROOT}")

        model_save_path = MODELS_OUTPUT_DIR / f"best_model_{orientation}_slice_{slice_idx}.pth"

        # Reset per slice so a failure below never keeps earlier objects alive.
        train_dataset = val_dataset = train_loader = val_loader = None
        try:
            print("Setting up datasets...")
            train_dataset = SliceAgePredictionDataset(FEATURE_ROOT, CSV_PATH, 'train', slice_idx, orientation)
            val_dataset = SliceAgePredictionDataset(FEATURE_ROOT, CSV_PATH, 'validation', slice_idx, orientation)

            print("Setting up dataloaders...")
            # drop_last avoids a size-1 final batch, which BatchNorm1d rejects in training mode.
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

        except (FileNotFoundError, ValueError, RuntimeError) as e:
            print(f"Error initializing datasets/loaders for {orientation} slice {slice_idx}: {e}")
            print("Skipping training for this slice.")
            results[slice_key] = {'best_val_mae': float('inf'), 'best_epoch': -1, 'error': str(e)}
            # Drop any partially built objects before moving on.
            train_dataset = val_dataset = train_loader = val_loader = None
            gc.collect()
            if DEVICE.startswith('cuda'):
                torch.cuda.empty_cache()
            continue

        print("Initializing model, optimizer, scheduler...")
        model = AgeMLPWithAttentionBN().to(DEVICE)
        criterion = nn.L1Loss()
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        # `verbose` is omitted: False was already the default and the argument is deprecated.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=scheduler_patience_epochs, factor=0.5)

        best_val_mae = float('inf')
        epochs_no_improve = 0
        best_epoch = -1

        print(f"\n--- Starting Training for {orientation} Slice {slice_idx} ---")
        for epoch in range(EPOCHS):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
            val_loss, val_mae = evaluate(model, val_loader, criterion, DEVICE)

            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val MAE: {val_mae:.3f} | LR: {current_lr:.1e}")

            scheduler.step(val_loss)

            if val_mae < best_val_mae:
                best_val_mae = val_mae
                best_epoch = epoch + 1
                epochs_no_improve = 0
                torch.save(model.state_dict(), model_save_path)
                print(f"  -> New best Val MAE: {best_val_mae:.3f}. Saved model to {model_save_path}")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
                print(f"\nEarly stopping triggered for {orientation} slice {slice_idx} after {EARLY_STOPPING_PATIENCE} epochs without improvement.")
                break

            # Periodically release cached memory during long runs.
            if (epoch + 1) % 10 == 0:
                gc.collect()
                if DEVICE.startswith('cuda'):
                    torch.cuda.empty_cache()

        print(f"\n--- Training Finished for {orientation} Slice {slice_idx} ---")
        if best_epoch != -1:
            print(f"Best Validation MAE: {best_val_mae:.3f} achieved at epoch {best_epoch}")
            results[slice_key] = {'best_val_mae': best_val_mae, 'best_epoch': best_epoch}
        else:
            print("No improvement found during training.")
            results[slice_key] = {'best_val_mae': float('inf'), 'best_epoch': -1, 'error': 'No improvement'}

        # Optional test-set evaluation would go here.

        del model, optimizer, scheduler, train_dataset, val_dataset, train_loader, val_loader
        gc.collect()
        if DEVICE.startswith('cuda'):
            torch.cuda.empty_cache()

    # --- Final Summary ---
    print(f"\n\n{'='*20} Overall Training Summary {'='*20}")
    for slice_key, res in results.items():
        if res.get('error') is not None:
            print(f"Slice {slice_key}: Error - {res['error']}")
        elif res['best_epoch'] == -1:
            print(f"Slice {slice_key}: No improvement found.")
        else:
            test_mae_str = ""
            if 'test_mae' in res:
                test_mae_val = res['test_mae']
                test_mae_str = f" | Test MAE: {test_mae_val:.3f}" if isinstance(test_mae_val, (int, float)) else f" | Test MAE: {test_mae_val}"

            print(f"Slice {slice_key}: Best Val MAE = {res['best_val_mae']:.3f} (Epoch {res['best_epoch']}){test_mae_str}")

    print("\nSpecific slice model training finished.")
    print(f"Best models saved in: {MODELS_OUTPUT_DIR}")

Starting specific slice MLP training from multi-slice directory...
Target Slices:
  - Orientation: sagittal, Index: 80
  - Orientation: sagittal, Index: 125
  - Orientation: coronal, Index: 125
  - Orientation: axial, Index: 80
Using device: cuda:1
Feature Root: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice
CSV Path: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Models will be saved to: specific_slice_models_multi_dir
Hyperparameters: Epochs=500, LR=0.0001, Batch=32, ES_Patience=50

Using feature root: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice
Setting up datasets...

[Dataset Init] Split: train, Orientation: sagittal, Slice: 80
Scanning subjects in: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice/train
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Loaded metadata for 2850 subjects from CSV.
Successfully mapped 2274 subject directories in 'train' 

Average predictions from 4 MLPs

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import torch.nn.functional as F # For Global Average Pooling
import os
import gc

# --- Configuration ---
# Root of the multi-slice preprocessed features. Expected layout:
#   FEATURE_ROOT/<split>/<subject_dir>/slice_<orientation>_<index>.npy
FEATURE_ROOT = Path("/data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice")
# Metadata CSV; the dataset reads its 'filename' and 'age' columns.
CSV_PATH = Path("/data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv")
# Destination for the best checkpoint of each slice model.
MODELS_OUTPUT_DIR = Path("./specific_slice_models_multi_dir")
# Pins GPU index 1 when CUDA is available, otherwise falls back to the CPU.
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"
SPLITS = ["train", "validation", "test"]

# One model is trained per (orientation, index) pair below; the matching
# feature file inside each subject folder is 'slice_<orientation>_<index>.npy'
# (e.g. 'slice_sagittal_80.npy', 'slice_coronal_125.npy').
TARGET_SLICES = [
    {'orientation': orientation, 'index': index}
    for orientation, index in (
        ('sagittal', 80),
        ('sagittal', 125),
        ('coronal', 125),
        ('axial', 80),
    )
]

# --- Hyperparameters ---
EPOCHS = 500
LEARNING_RATE = 1e-4
BATCH_SIZE = 32  # lower this if GPU memory is tight
WEIGHT_DECAY = 1e-5
# LR-scheduler patience expressed as a fraction of EPOCHS (0.10 * 500 -> 50 epochs).
# NOTE(review): that equals EARLY_STOPPING_PATIENCE. ReduceLROnPlateau only lowers
# the LR after `patience + 1` non-improving epochs, so keep this well below the
# early-stopping window or the LR may never be reduced before training stops.
SCHEDULER_PATIENCE_PERCENT = 0.10
EARLY_STOPPING_PATIENCE = 50  # epochs without a val-MAE improvement before stopping
NUM_WORKERS = 4  # DataLoader worker processes

# --- Model Definition (Using AgeMLPWithAttentionBN) ---
class AgeMLPWithAttentionBN(nn.Module):
    """Regress a single age value from a pooled slice embedding.

    Pipeline: multi-head self-attention with a residual connection and
    LayerNorm, then a three-hidden-layer MLP (Linear -> BatchNorm1d -> ReLU ->
    Dropout at each hidden layer) ending in a single output unit.

    Note: ``forward`` presents the batch to attention as a sequence of length 1
    (``(1, batch, dim)`` with ``batch_first=False``), so every sample attends
    only to itself; no information is mixed across the batch.

    Args:
        input_dim: Width of the incoming feature vector. Must equal
            ``embed_dim`` because the input is fed straight into the attention
            layer and added back as a residual before reaching the MLP.
        embed_dim: Attention width.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        hidden_dim1, hidden_dim2, hidden_dim3: Widths of the MLP hidden layers.
        dropout_rate: Dropout probability after each hidden layer.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads`` or
            ``input_dim`` differs from ``embed_dim``.
    """

    def __init__(self, input_dim=256, embed_dim=256, num_heads=8,
                 hidden_dim1=128, hidden_dim2=64, hidden_dim3=32, dropout_rate=0.3):
        super(AgeMLPWithAttentionBN, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        # Fail at construction rather than with a shape error deep inside forward():
        # attention consumes the raw input, and its residual output feeds the MLP.
        if input_dim != embed_dim:
            raise ValueError(f"input_dim ({input_dim}) must equal embed_dim ({embed_dim})")
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=False)
        self.norm_attn = nn.LayerNorm(embed_dim)
        # BatchNorm1d needs more than one sample per batch in training mode.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1), nn.BatchNorm1d(hidden_dim1), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim1, hidden_dim2), nn.BatchNorm1d(hidden_dim2), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim2, hidden_dim3), nn.BatchNorm1d(hidden_dim3), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim3, 1)
        )

    def forward(self, x):
        """Predict one age per sample.

        Args:
            x: Tensor of shape ``(batch_size, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size,)``.
        """
        # (1, batch_size, input_dim): with batch_first=False this is a length-1
        # sequence per sample, i.e. each sample attends only to itself.
        x_attn = x.unsqueeze(0)
        attn_output, _ = self.attention(x_attn, x_attn, x_attn)
        x = self.norm_attn(x_attn + attn_output)  # residual + LayerNorm
        x = x.squeeze(0)  # back to (batch_size, input_dim)
        output = self.mlp(x)
        return output.squeeze(-1)  # (batch_size,)

# --- Dataset Definition (Modified for Multi-Slice Directory Structure) ---
class SliceAgePredictionDataset(Dataset):
    """(pooled slice embedding, age) pairs for one fixed slice of one split.

    Feature files are looked up as::

        feature_root/<split>/<subject_dir>/slice_<orientation>_<slice_index>.npy

    Each ``<subject_dir>`` is paired with a row of the metadata CSV (which must
    provide ``filename`` and ``age`` columns) by comparing the directory name
    with the CSV ``filename`` minus its ``.nii.gz``/``.nii`` extension; see
    ``_match_csv_filename`` for the exact rules. Subject directories are visited
    in sorted order, so the sample order is reproducible between runs.
    """

    def __init__(self, feature_root, csv_path, split, slice_index, orientation):
        """
        Args:
            feature_root (Path): Base directory containing split folders
                                 (e.g., .../BrainAGE_preprocessed_multi_slice).
            csv_path (Path): Path to the CSV file with metadata.
            split (str): The dataset split ('train', 'validation', or 'test').
            slice_index (int): The specific slice index to load features for.
            orientation (str): The orientation ('sagittal', 'coronal', 'axial').

        Raises:
            FileNotFoundError: If the feature root, split directory or CSV is missing.
            ValueError: If the CSV cannot be read/processed or the split
                directory cannot be listed.
            RuntimeError: If no subject has both metadata and the requested slice file.
        """
        self.feature_root = Path(feature_root)
        self.csv_path = Path(csv_path)
        self.split = split
        self.slice_index = slice_index
        self.orientation = orientation
        self.split_dir = self.feature_root / self.split

        print(f"\n[Dataset Init] Split: {self.split}, Orientation: {self.orientation}, Slice: {self.slice_index}")
        print(f"Scanning subjects in: {self.split_dir}")
        print(f"Loading metadata from: {self.csv_path}")

        # Check if directories/files exist
        if not self.feature_root.is_dir():
            raise FileNotFoundError(f"Feature root directory not found: {self.feature_root}")
        if not self.split_dir.is_dir():
            raise FileNotFoundError(f"Split directory not found: {self.split_dir}")
        if not self.csv_path.is_file():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        # Load the CSV, build the filename -> age lookup and map subject dirs to CSV rows.
        try:
            df = pd.read_csv(self.csv_path)
            self.meta_dict = df.set_index('filename')['age'].to_dict()
            # Listed once and sorted: reused for the slice scan below, and the
            # sort makes sample order (hence unshuffled loader order) deterministic.
            all_subject_dirs = sorted(d for d in self.split_dir.iterdir() if d.is_dir())
            self.subject_dir_to_filename = self._map_subject_dirs(d.name for d in all_subject_dirs)
        except Exception as e:
            raise ValueError(f"Error loading or processing CSV {self.csv_path} or mapping directories: {e}") from e

        mapped_count = len(self.subject_dir_to_filename)
        print(f"Loaded metadata for {len(self.meta_dict)} subjects from CSV.")
        print(f"Successfully mapped {mapped_count} subject directories in '{self.split}' to CSV entries.")
        if mapped_count < len(all_subject_dirs):
            print(f"Warning: Could not map {len(all_subject_dirs) - mapped_count} directories to CSV.")

        # Keep only subjects that have metadata AND the requested slice file.
        slice_filename = f"slice_{self.orientation}_{self.slice_index}.npy"
        self.valid_slice_files = []
        missing_meta_count = 0
        missing_slice_count = 0

        print(f"Scanning {len(all_subject_dirs)} potential subject directories for slice {self.orientation}_{self.slice_index}...")

        for subject_dir in all_subject_dirs:
            if subject_dir.name not in self.subject_dir_to_filename:
                missing_meta_count += 1
                continue
            expected_slice_path = subject_dir / slice_filename
            if expected_slice_path.is_file():
                self.valid_slice_files.append(expected_slice_path)
            else:
                missing_slice_count += 1

        if missing_meta_count > 0:
            print(f"Info: {missing_meta_count} subject directories were skipped as they couldn't be mapped to metadata.")
        if missing_slice_count > 0:
            print(f"Warning: {missing_slice_count} subjects with metadata were missing slice {self.orientation}_{self.slice_index}.")

        if not self.valid_slice_files:
            raise RuntimeError(f"No valid slice files found for {self.orientation} slice {self.slice_index} in {self.split_dir} with matching metadata.")

        print(f"Found {len(self.valid_slice_files)} valid files for {self.orientation} slice {self.slice_index} in split {self.split}.")

    @staticmethod
    def _strip_nifti_suffix(filename):
        """Return *filename* without a trailing '.nii.gz'/'.nii', or None if it has neither."""
        if not isinstance(filename, str):  # e.g. NaN from an empty CSV cell
            return None
        for suffix in (".nii.gz", ".nii"):
            if filename.endswith(suffix):
                return filename[:-len(suffix)]
        return None

    @staticmethod
    def _match_csv_filename(subj_dir_name, stem_to_filename, sorted_stems):
        """Return the CSV filename for one subject directory, or None if there is no confident match.

        Rules, in order:
          1. the directory name equals a CSV stem;
          2. the name truncated at '_mri_brainmask' equals a CSV stem;
          3. a CSV stem starts with that truncated name, preferring stems where
             the prefix ends at a non-alphanumeric separator (so 'sub-1' never
             picks 'sub-10...' while 'sub-1_...' exists); an unbounded prefix is
             accepted only when it is the sole candidate. Ties are broken by
             sorted order, so the result is deterministic.
        """
        if subj_dir_name in stem_to_filename:
            return stem_to_filename[subj_dir_name]
        base_subj_name = subj_dir_name.split('_mri_brainmask')[0]
        if not base_subj_name:  # an empty prefix would match every stem
            return None
        if base_subj_name in stem_to_filename:
            return stem_to_filename[base_subj_name]
        # Exact matches were handled above, so every candidate is strictly longer.
        candidates = [s for s in sorted_stems if s.startswith(base_subj_name)]
        at_boundary = [s for s in candidates if not s[len(base_subj_name)].isalnum()]
        if at_boundary:
            return stem_to_filename[at_boundary[0]]
        if len(candidates) == 1:
            return stem_to_filename[candidates[0]]
        return None

    def _map_subject_dirs(self, subject_dir_names):
        """Return ``{subject_dir_name: csv_filename}`` for every directory that can be matched.

        Only CSV filenames ending in '.nii.gz' or '.nii' are candidates;
        directories without a match are left out.
        """
        stem_to_filename = {}
        for filename in self.meta_dict:
            stem = self._strip_nifti_suffix(filename)
            if stem is None:
                continue
            # If both 'X.nii.gz' and 'X.nii' are listed, prefer the '.nii.gz' entry.
            if stem not in stem_to_filename or filename.endswith(".nii.gz"):
                stem_to_filename[stem] = filename
        sorted_stems = sorted(stem_to_filename)

        mapping = {}
        for subj_dir_name in subject_dir_names:
            matched = self._match_csv_filename(subj_dir_name, stem_to_filename, sorted_stems)
            if matched is not None:
                mapping[subj_dir_name] = matched
        return mapping

    def __len__(self):
        return len(self.valid_slice_files)

    def __getitem__(self, idx):
        """Return ``(pooled_embedding, age)`` for sample *idx*.

        The ``.npy`` array is reduced to a length-256 vector: arrays shaped
        ``(1, C, H, W)`` or ``(C, H, W)`` are global-average-pooled over H and W;
        1-D arrays are used as-is; any other shape raises ``ValueError``. Both
        returned tensors are float32, ``age`` being a 0-d tensor.
        """
        slice_path = self.valid_slice_files[idx]
        subject_dir_name = slice_path.parent.name
        original_filename = self.subject_dir_to_filename.get(subject_dir_name)
        if not original_filename:
            raise ValueError(f"Could not find original filename mapping for subject directory {subject_dir_name}")

        try:
            embedding_tensor = torch.tensor(np.load(slice_path), dtype=torch.float32)

            # Apply Global Average Pooling (GAP) over the spatial dimensions.
            ndim = embedding_tensor.dim()
            if ndim == 4 and embedding_tensor.shape[0] == 1:
                pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor, (1, 1)).squeeze()
            elif ndim == 3:
                pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor.unsqueeze(0), (1, 1)).squeeze()
            elif ndim == 1:
                pooled_embedding = embedding_tensor
            else:
                raise ValueError(f"Unexpected embedding shape {embedding_tensor.shape} for {slice_path}")

            # Check the rank too: a single-channel map squeezes to a 0-d tensor,
            # where shape[0] would raise an unhelpful IndexError.
            if pooled_embedding.dim() != 1 or pooled_embedding.shape[0] != 256:
                raise ValueError(f"Pooled embedding channel dimension is not 256 for {slice_path}: {pooled_embedding.shape}")

            age_tensor = torch.tensor(self.meta_dict[original_filename], dtype=torch.float32)
            return pooled_embedding, age_tensor

        except Exception as e:
            print(f"Error loading or processing file {slice_path}: {e}")
            raise

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import os
import gc
from collections import defaultdict

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# --- Test Set Evaluation ---
# Evaluates each per-slice model saved by the training cell on the 'test' split,
# then averages the models' predictions per subject (late fusion) and reports MAE.
# Requires DEVICE, MODELS_OUTPUT_DIR, TARGET_SLICES, FEATURE_ROOT, CSV_PATH,
# AgeMLPWithAttentionBN, SliceAgePredictionDataset, BATCH_SIZE and NUM_WORKERS
# from the previous cell.
print("\n--- Starting Test Set Evaluation ---")

models = {}
test_loaders = {}
test_results = {}
all_predictions = defaultdict(list)
all_ages_collected = defaultdict(list)  # True ages yielded by each model's loader, in loader order
subject_ids = {}  # slice_key -> subject directory names, in the order its loader yields samples

# 1. Load Models and Create Test Loaders
print("Loading models and creating test datasets/loaders...")
successful_loads = 0
for slice_info in TARGET_SLICES:
    orientation = slice_info['orientation']
    slice_idx = slice_info['index']
    slice_key = f"{orientation}_{slice_idx}"
    model_save_path = MODELS_OUTPUT_DIR / f"best_model_{orientation}_slice_{slice_idx}.pth"

    if not model_save_path.is_file():
        print(f"Model file not found for {slice_key} at {model_save_path}. Skipping.")
        continue

    print(f"Loading test data for {slice_key}...")
    try:
        test_dataset = SliceAgePredictionDataset(FEATURE_ROOT, CSV_PATH, 'test', slice_idx, orientation)
        # shuffle=False keeps loader order equal to dataset index order, which is
        # what lets each prediction be tagged with its subject ID below.
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

        print(f"Loading model for {slice_key} from {model_save_path}...")
        model = AgeMLPWithAttentionBN().to(DEVICE)
        # map_location lets GPU-saved weights load on whichever DEVICE is in use.
        model.load_state_dict(torch.load(model_save_path, map_location=DEVICE))
        model.eval()
    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"Error initializing test dataset/loader for {slice_key}: {e}")
        print("Skipping evaluation for this slice.")
        continue
    except Exception as e:
        print(f"An unexpected error occurred loading data/model for {slice_key}: {e}")
        print("Skipping evaluation for this slice.")
        continue

    # Register the slice only once both its data and its model are ready.
    test_loaders[slice_key] = test_loader
    models[slice_key] = model
    # SliceAgePredictionDataset.__getitem__ identifies a sample's subject by the
    # parent directory of its slice file, so use the same key for alignment.
    subject_ids[slice_key] = [path.parent.name for path in test_dataset.valid_slice_files]
    print(f"Successfully loaded model and data for {slice_key}.")
    successful_loads += 1

if not models:
    print("No models were loaded successfully. Exiting evaluation.")
elif successful_loads < len(TARGET_SLICES):
    print(f"Warning: Only loaded {successful_loads}/{len(TARGET_SLICES)} models. Averaging will be based on loaded models.")

if models:  # Proceed only if at least one model was loaded
    # Assume criterion = nn.L1Loss() was used for training/validation MAE reporting
    criterion = nn.L1Loss()

    # 2. Evaluate Individual Models and Collect Predictions
    print("\nEvaluating individual models on the test set...")
    with torch.no_grad():
        for slice_key, model in models.items():
            print(f"Evaluating {slice_key}...")
            total_loss = 0.0
            total_mae = 0.0
            num_samples = 0

            for features, ages in test_loaders[slice_key]:
                features, ages = features.to(DEVICE), ages.to(DEVICE)
                predictions = model(features)

                # Kept in loader order; paired with subject_ids[slice_key] later.
                all_predictions[slice_key].extend(predictions.reshape(-1).cpu().numpy())
                all_ages_collected[slice_key].extend(ages.cpu().numpy())

                batch_size = features.size(0)
                total_loss += criterion(predictions, ages).item() * batch_size
                total_mae += F.l1_loss(predictions, ages, reduction='sum').item()
                num_samples += batch_size

            if num_samples > 0:
                avg_loss = total_loss / num_samples
                avg_mae = total_mae / num_samples
                test_results[slice_key] = {'test_mae': avg_mae, 'test_loss': avg_loss, 'num_samples': num_samples}
                print(f"  {slice_key} - Test MAE: {avg_mae:.3f}, Test Loss: {avg_loss:.4f}, Samples: {num_samples}")
            else:
                print(f"  {slice_key} - No samples evaluated.")
                test_results[slice_key] = {'test_mae': float('inf'), 'test_loss': float('inf'), 'num_samples': 0}

    # 3. Align Predictions by Subject and Calculate Averaged MAE
    print("\nCalculating MAE for averaged predictions...")

    # Every model must have produced exactly one prediction and one age per listed subject;
    # otherwise the subject tagging below would be wrong.
    misaligned = [
        key for key in models
        if not (len(all_predictions.get(key, [])) == len(all_ages_collected.get(key, [])) == len(subject_ids[key]))
    ]
    if misaligned:
        print(f"Error: Prediction collection failed for models {misaligned}. Cannot compute average.")
    else:
        # Per-slice test sets can differ when some subjects lack a slice file, so
        # average only over the subjects that every loaded model evaluated.
        common_subjects = sorted(set.intersection(*(set(subject_ids[key]) for key in models)))
        subject_counts = {key: len(subject_ids[key]) for key in models}
        if len(set(subject_counts.values())) > 1:
            print(f"Warning: Test set sizes differ across models: {subject_counts}. "
                  f"Averaging over the {len(common_subjects)} subjects common to all models.")

        if not common_subjects:
            print("No test samples found to calculate average MAE.")
        else:
            print(f"Number of test subjects common to all models: {len(common_subjects)}")
            preds_by_subject = {key: dict(zip(subject_ids[key], all_predictions[key])) for key in models}
            ages_by_subject = {key: dict(zip(subject_ids[key], all_ages_collected[key])) for key in models}

            # Rows: models, columns: subjects in common_subjects order.
            pred_matrix = np.array([[preds_by_subject[key][s] for s in common_subjects] for key in models], dtype=np.float64)
            age_matrix = np.array([[ages_by_subject[key][s] for s in common_subjects] for key in models], dtype=np.float64)
            reference_ages = age_matrix[0]

            if not np.all(age_matrix == reference_ages):
                print("Cannot compute average MAE due to inconsistent ground truth ages across loaders.")
            else:
                print(f"Averaging predictions from {len(models)} models: {list(models)}")
                avg_preds = pred_matrix.mean(axis=0)
                average_mae = np.mean(np.abs(avg_preds - reference_ages))
                print(f"\n---> Average Prediction Test MAE: {average_mae:.3f} <---")


# 4. Clean up
print("\nCleaning up resources...")
del models, test_loaders, all_predictions, all_ages_collected, test_results, subject_ids
if 'model' in locals(): del model  # Ensure loop variables are cleared
if 'test_loader' in locals(): del test_loader
if 'test_dataset' in locals(): del test_dataset
gc.collect()
if DEVICE.startswith('cuda'):
    torch.cuda.empty_cache()

print("\nTest evaluation finished.")


--- Starting Test Set Evaluation ---
Loading models and creating test datasets/loaders...
Loading test data for sagittal_80...

[Dataset Init] Split: test, Orientation: sagittal, Slice: 80
Scanning subjects in: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice/test
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Loaded metadata for 2850 subjects from CSV.
Successfully mapped 296 subject directories in 'test' to CSV entries.
Scanning 296 potential subject directories for slice sagittal_80...
Found 296 valid files for sagittal slice 80 in split test.
Loading model for sagittal_80 from specific_slice_models_multi_dir/best_model_sagittal_slice_80.pth...
Successfully loaded model and data for sagittal_80.
Loading test data for sagittal_125...

[Dataset Init] Split: test, Orientation: sagittal, Slice: 125
Scanning subjects in: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice/test
Loading metadata from: /data/

Early fusion: concatenating the per-slice features into a single input vector

In [4]:
# In the third code cell
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import torch.nn.functional as F # For Global Average Pooling
import os
import gc
from collections import defaultdict # Keep for potential future use if needed

# --- Configuration: paths, device and the slices fused by concatenation ---
FEATURE_ROOT = Path("/data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice")
CSV_PATH = Path("/data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv")
MODELS_OUTPUT_DIR = Path("./combined_slice_model")  # Where the combined model checkpoint goes
if torch.cuda.is_available():
    DEVICE = "cuda:1"
else:
    DEVICE = "cpu"
SPLITS = ["train", "validation", "test"]

# (orientation, index) pairs; each one is read from slice_{orientation}_{index}.npy
# inside every subject directory.
TARGET_SLICES = [
    {'orientation': orientation, 'index': index}
    for orientation, index in (
        ('sagittal', 80),
        ('sagittal', 125),
        ('coronal', 125),
        ('axial', 80),
    )
]
# Every slice contributes one 256-channel pooled vector; concatenation stacks them.
CONCAT_DIM = len(TARGET_SLICES) * 256  # 4 * 256 = 1024

# --- Hyperparameters ---
EPOCHS = 500
LEARNING_RATE = 1e-4
BATCH_SIZE = 32  # Tune for GPU memory with the wider concatenated input
WEIGHT_DECAY = 1e-5
SCHEDULER_PATIENCE_PERCENT = 0.10  # LR-scheduler patience as a fraction of EPOCHS
EARLY_STOPPING_PATIENCE = 50  # Epochs without val-MAE improvement before stopping
NUM_WORKERS = 4  # DataLoader worker processes

# --- Model Definition (Adjusted Input Dimension) ---
class AgeMLPWithAttentionBN(nn.Module):
    """Self-attention block followed by a BatchNorm MLP that regresses one age per sample.

    Args:
        input_dim (int): Width of each input vector. Must equal ``embed_dim``, because
            the input is fed straight into the attention layer and added back onto
            its output as a residual.
        embed_dim (int): Attention embedding size.
        num_heads (int): Number of attention heads; must divide ``embed_dim``. For
            1024 valid choices include 1, 2, 4, 8, 16, 32, ...
        hidden_dim1, hidden_dim2, hidden_dim3 (int): Widths of the MLP hidden layers.
        dropout_rate (float): Dropout probability after each hidden layer.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads`` or if
            ``input_dim != embed_dim``.

    NOTE: ``forward`` passes each sample as a sequence of length 1
    (``x.unsqueeze(0)`` with ``batch_first=False``). Every sample therefore attends
    only to itself, the attention weight is exactly 1, and the attention block acts
    as the value and output linear projections plus the residual/LayerNorm.
    """

    def __init__(self, input_dim=1024, embed_dim=1024, num_heads=16,
                 hidden_dim1=128, hidden_dim2=64, hidden_dim3=32, dropout_rate=0.3):
        super(AgeMLPWithAttentionBN, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        if input_dim != embed_dim:
            # Fail at construction rather than with an opaque shape error inside forward().
            raise ValueError(
                f"input_dim ({input_dim}) must equal embed_dim ({embed_dim}): the input feeds "
                f"the attention layer directly and is added to its output as a residual"
            )
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=False)
        self.norm_attn = nn.LayerNorm(embed_dim)
        # Regression head applied to the attention-refined vector.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1), nn.BatchNorm1d(hidden_dim1), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim1, hidden_dim2), nn.BatchNorm1d(hidden_dim2), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim2, hidden_dim3), nn.BatchNorm1d(hidden_dim3), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim3, 1)
        )

    def forward(self, x):
        """Map a ``(batch_size, input_dim)`` tensor to ``(batch_size,)`` age predictions."""
        x_attn = x.unsqueeze(0)  # (1, batch_size, input_dim): sequence length 1
        attn_output, _ = self.attention(x_attn, x_attn, x_attn)
        x = self.norm_attn(x_attn + attn_output)  # Residual connection + LayerNorm
        x = x.squeeze(0)  # (batch_size, input_dim)
        output = self.mlp(x)
        return output.squeeze(-1)  # (batch_size,)

# --- Dataset Definition (Modified for Concatenating Slices) ---
class CombinedSliceDataset(Dataset):
    """Pairs each subject's concatenated, pooled slice embeddings with the subject's age.

    A subject directory under ``feature_root/split`` is kept only if it maps to a CSV
    row and contains ``slice_{orientation}_{index}.npy`` for *every* entry of
    ``target_slices``. Each item is ``(embedding, age)``. ``embedding`` is the
    per-slice global-average-pooled vectors concatenated in ``target_slices`` order,
    so it has ``EXPECTED_CHANNELS * len(target_slices)`` values.
    """

    # Channel count every pooled slice embedding must have.
    EXPECTED_CHANNELS = 256
    # Subject directory names are cut at this marker before matching CSV filenames.
    SUBJECT_DIR_MARKER = '_mri_brainmask'
    # CSV filename suffixes recognised as subjects; earlier ones win for a shared stem.
    NIFTI_SUFFIXES = ('.nii.gz', '.nii')

    def __init__(self, feature_root, csv_path, split, target_slices):
        """
        Args:
            feature_root (Path): Base directory containing split folders.
            csv_path (Path): Path to the CSV file with metadata ('filename' and 'age' columns).
            split (str): The dataset split ('train', 'validation', or 'test').
            target_slices (list): List of dicts, each specifying {'orientation': str, 'index': int}.

        Raises:
            FileNotFoundError: If the feature root, split directory or CSV is missing.
            ValueError: If the CSV cannot be read or mapped to subject directories.
            RuntimeError: If no subject has metadata and all target slices.
        """
        self.feature_root = Path(feature_root)
        self.csv_path = Path(csv_path)
        self.split = split
        self.target_slices = target_slices
        self.split_dir = self.feature_root / self.split

        print(f"\n[Dataset Init - Combined] Split: {self.split}")
        print(f"Target Slices: {self.target_slices}")
        print(f"Scanning subjects in: {self.split_dir}")
        print(f"Loading metadata from: {self.csv_path}")

        if not self.feature_root.is_dir():
            raise FileNotFoundError(f"Feature root not found: {self.feature_root}")
        if not self.split_dir.is_dir():
            raise FileNotFoundError(f"Split directory not found: {self.split_dir}")
        if not self.csv_path.is_file():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        # Sorted so item order (and therefore any shuffle=False loader) is reproducible;
        # plain iterdir() order is filesystem-dependent.
        subject_dirs = sorted(d for d in self.split_dir.iterdir() if d.is_dir())

        try:
            df = pd.read_csv(self.csv_path)
            self.meta_dict = df.set_index('filename')['age'].to_dict()
            self.subject_dir_to_filename = self._map_subject_dirs([d.name for d in subject_dirs])
        except Exception as e:
            raise ValueError(f"Error loading/processing CSV or mapping directories: {e}") from e

        mapped_count = len(self.subject_dir_to_filename)
        print(f"Loaded metadata for {len(self.meta_dict)} subjects from CSV.")
        print(f"Successfully mapped {mapped_count} subject directories in '{self.split}' to CSV entries.")
        if mapped_count < len(subject_dirs):
            print(f"Warning: Could not map {len(subject_dirs) - mapped_count} directories to CSV.")

        # Keep only subject directories that have metadata AND every required slice file.
        self.valid_subject_dirs = []
        subjects_missing_slices = 0
        subjects_without_meta = 0
        slice_filenames = [self._slice_filename(info) for info in self.target_slices]

        print(f"Scanning {len(subject_dirs)} potential subject directories for all target slices...")
        for subject_dir in subject_dirs:
            if subject_dir.name not in self.subject_dir_to_filename:
                subjects_without_meta += 1  # Directory could not be mapped to the CSV
            elif all((subject_dir / name).is_file() for name in slice_filenames):
                self.valid_subject_dirs.append(subject_dir)
            else:
                subjects_missing_slices += 1

        if subjects_without_meta > 0:
            print(f"Info: {subjects_without_meta} subject directories were skipped (no metadata mapping).")
        if subjects_missing_slices > 0:
            print(f"Warning: {subjects_missing_slices} subjects with metadata were missing at least one target slice.")

        if not self.valid_subject_dirs:
            raise RuntimeError(f"No subjects found with all required slices in {self.split_dir} with matching metadata.")

        print(f"Found {len(self.valid_subject_dirs)} valid subjects with all required slices in split {self.split}.")

    def _map_subject_dirs(self, dir_names):
        """Map subject directory names to CSV filenames, keeping only unambiguous matches.

        A directory name is cut at ``SUBJECT_DIR_MARKER``. The result is compared with
        the CSV filenames stripped of their NIfTI suffix:

        1. An exact stem match wins.
        2. Otherwise, keep stems that start with the cut name at a token boundary
           (the next character is not alphanumeric). So ``sub-1`` matches
           ``sub-1_mri_brainmask`` but not ``sub-10_mri_brainmask``.
        3. If no boundary match exists, fall back to plain prefix matches.

        A single remaining candidate is accepted. Ambiguous directories are reported
        and left unmapped instead of being assigned an arbitrary (possibly wrong) age.
        """
        stem_to_filename = {}
        for suffix in self.NIFTI_SUFFIXES:
            for filename in self.meta_dict:
                if isinstance(filename, str) and filename.endswith(suffix):
                    stem_to_filename.setdefault(filename[:-len(suffix)], filename)

        mapping = {}
        for dir_name in dir_names:
            base = dir_name.split(self.SUBJECT_DIR_MARKER)[0]
            if base in stem_to_filename:
                mapping[dir_name] = stem_to_filename[base]
                continue
            prefixed = [stem for stem in stem_to_filename if stem.startswith(base)]
            # stem != base here, so stem[len(base)] exists.
            bounded = [stem for stem in prefixed if not stem[len(base)].isalnum()]
            candidates = bounded or prefixed
            if len(candidates) == 1:
                mapping[dir_name] = stem_to_filename[candidates[0]]
            elif candidates:
                print(f"Warning: Ambiguous CSV match for directory {dir_name}: {sorted(candidates)[:5]}. Skipping.")
        return mapping

    @staticmethod
    def _slice_filename(slice_info):
        """Return the on-disk file name of one target slice."""
        return f"slice_{slice_info['orientation']}_{slice_info['index']}.npy"

    @classmethod
    def _global_average_pool(cls, embedding_tensor, slice_path):
        """Reduce one slice embedding to a 1-D channel vector by global average pooling.

        Accepts ``(1, C, H, W)``, ``(C, H, W)`` or an already pooled ``(C,)`` tensor.

        Raises:
            ValueError: On any other shape, or if ``C != EXPECTED_CHANNELS``.
        """
        ndim = embedding_tensor.dim()
        if ndim == 4 and embedding_tensor.shape[0] == 1:
            pooled = F.adaptive_avg_pool2d(embedding_tensor, (1, 1)).reshape(-1)
        elif ndim == 3:
            pooled = F.adaptive_avg_pool2d(embedding_tensor.unsqueeze(0), (1, 1)).reshape(-1)
        elif ndim == 1:
            pooled = embedding_tensor
        else:
            raise ValueError(f"Unexpected embedding shape {embedding_tensor.shape} for {slice_path}")

        # reshape(-1) rather than squeeze() keeps a 1-channel input 1-D, so it reaches
        # this check instead of failing with an IndexError on a 0-d tensor.
        if pooled.shape[0] != cls.EXPECTED_CHANNELS:
            raise ValueError(f"Pooled embedding channel dim is not {cls.EXPECTED_CHANNELS} for {slice_path}: {pooled.shape}")
        return pooled

    def __len__(self):
        return len(self.valid_subject_dirs)

    def __getitem__(self, idx):
        """Return ``(concatenated_embedding, age_tensor)`` for subject ``idx`` (both float32)."""
        subject_dir = self.valid_subject_dirs[idx]
        subject_dir_name = subject_dir.name
        original_filename = self.subject_dir_to_filename.get(subject_dir_name)
        if not original_filename:
            raise ValueError(f"Internal Error: Could not find original filename for valid subject dir {subject_dir_name}")

        try:
            pooled_embeddings = []
            for slice_info in self.target_slices:
                slice_path = subject_dir / self._slice_filename(slice_info)
                embedding_tensor = torch.tensor(np.load(slice_path), dtype=torch.float32)
                pooled_embeddings.append(self._global_average_pool(embedding_tensor, slice_path))

            # Concatenate in target_slices order, e.g. 4 x 256 -> (1024,)
            concatenated_embedding = torch.cat(pooled_embeddings, dim=0)
            age_tensor = torch.tensor(self.meta_dict[original_filename], dtype=torch.float32)
            return concatenated_embedding, age_tensor

        except Exception as e:
            print(f"Error loading/processing slices for subject {subject_dir_name}: {e}")
            raise

# --- Training and Evaluation Functions ---
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the sample-weighted mean loss.

    Args:
        model (nn.Module): Model to train; switched to ``train()`` mode.
        loader (Iterable): Yields ``(features, ages)`` tensor batches.
        criterion (callable): Loss ``criterion(predictions, ages)``.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        device (str | torch.device): Device the batches are moved to.

    Returns:
        float: Mean per-sample loss over the epoch. Returns 0.0 when ``loader``
        yields no batches (for example, a tiny train set with ``drop_last=True``),
        matching ``evaluate``'s empty-loader loss instead of raising ZeroDivisionError.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for features, ages in loader:
        features, ages = features.to(device), ages.to(device)
        optimizer.zero_grad()
        loss = criterion(model(features), ages)
        loss.backward()
        optimizer.step()
        batch_size = features.size(0)
        # Assumes criterion averages over the batch (e.g. L1Loss default), so re-weight by size.
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    return total_loss / num_samples if num_samples > 0 else 0.0

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` with gradients disabled.

    Args:
        model (nn.Module): Model to score; switched to ``eval()`` mode.
        loader (Iterable): Yields ``(features, ages)`` tensor batches.
        criterion (callable): Loss ``criterion(predictions, ages)``.
        device (str | torch.device): Device the batches are moved to.

    Returns:
        tuple: ``(mean loss, mean absolute error)`` per sample. For an empty loader
        the loss is ``0`` and the MAE is ``float('inf')``.
    """
    model.eval()
    loss_sum, abs_err_sum, seen = 0.0, 0.0, 0
    for batch_features, batch_ages in loader:
        batch_features = batch_features.to(device)
        batch_ages = batch_ages.to(device)
        preds = model(batch_features)
        n = batch_features.size(0)
        loss_sum += criterion(preds, batch_ages).item() * n
        abs_err_sum += F.l1_loss(preds, batch_ages, reduction='sum').item()
        seen += n
    if seen == 0:
        return 0, float('inf')
    return loss_sum / seen, abs_err_sum / seen


# --- Main Training Loop (Single Model) ---
if __name__ == "__main__":  # Or run directly in notebook cell
    print("Starting combined MLP training using concatenated slices...")
    print(f"Target Slices for Concatenation: {TARGET_SLICES}")
    print(f"Concatenated Input Dimension: {CONCAT_DIM}")
    print(f"Using device: {DEVICE}")
    print(f"Feature Root: {FEATURE_ROOT}")
    print(f"CSV Path: {CSV_PATH}")
    print(f"Model will be saved to: {MODELS_OUTPUT_DIR}")
    print(f"Hyperparameters: Epochs={EPOCHS}, LR={LEARNING_RATE}, Batch={BATCH_SIZE}, ES_Patience={EARLY_STOPPING_PATIENCE}")

    MODELS_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    model_save_path = MODELS_OUTPUT_DIR / "best_combined_model.pth"

    try:
        print("\nSetting up combined datasets...")
        train_dataset = CombinedSliceDataset(FEATURE_ROOT, CSV_PATH, 'train', TARGET_SLICES)
        val_dataset = CombinedSliceDataset(FEATURE_ROOT, CSV_PATH, 'validation', TARGET_SLICES)

        print("Setting up dataloaders...")
        # drop_last discards the incomplete final batch; a final batch of size 1 would
        # make BatchNorm1d raise in training mode.
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"\nError initializing combined datasets/loaders: {e}")
        print("Stopping training.")
        # Re-raise rather than exit(): exit() ends the interpreter (in Jupyter it may shut
        # down the kernel and lose notebook state), and continuing would hit undefined loaders.
        raise

    print("\nInitializing model, optimizer, scheduler...")
    # Instantiate model with the concatenated dimensions; num_heads must divide CONCAT_DIM.
    model = AgeMLPWithAttentionBN(input_dim=CONCAT_DIM, embed_dim=CONCAT_DIM, num_heads=16).to(DEVICE)
    criterion = nn.L1Loss()  # MAE loss, so val_loss equals val_mae up to rounding
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # ReduceLROnPlateau lowers the LR only after `patience` + 1 non-improving epochs. With
    # the defaults (10% of 500 = 50 = EARLY_STOPPING_PATIENCE), early stopping would almost
    # always fire first and the schedule would be dead. Cap patience at half the
    # early-stopping window so the LR reduction can actually take effect.
    scheduler_patience_epochs = max(1, min(int(EPOCHS * SCHEDULER_PATIENCE_PERCENT), EARLY_STOPPING_PATIENCE // 2))
    # `verbose` is deprecated in recent PyTorch; its default was already False.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=scheduler_patience_epochs, factor=0.5)
    print(f"LR scheduler patience: {scheduler_patience_epochs} epochs")

    best_val_mae = float('inf')
    epochs_no_improve = 0
    best_epoch = -1

    print("\n--- Starting Combined Model Training ---")
    for epoch in range(EPOCHS):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_mae = evaluate(model, val_loader, criterion, DEVICE)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val MAE: {val_mae:.3f} | LR: {current_lr:.1e}")

        scheduler.step(val_loss)

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            best_epoch = epoch + 1
            epochs_no_improve = 0
            torch.save(model.state_dict(), model_save_path)
            print(f"  -> New best Val MAE: {best_val_mae:.3f}. Saved model to {model_save_path}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping triggered after {EARLY_STOPPING_PATIENCE} epochs without improvement.")
            break

        # Periodically release cached host/GPU memory during long runs.
        if (epoch + 1) % 10 == 0:
            gc.collect()
            if DEVICE.startswith('cuda'):
                torch.cuda.empty_cache()

    print("\n--- Combined Model Training Finished ---")
    if best_epoch != -1:
        print(f"Best Validation MAE: {best_val_mae:.3f} achieved at epoch {best_epoch}")
        print(f"Best model saved to: {model_save_path}")
        # Kept at module level so a later evaluation cell can attach test metrics.
        final_results = {'best_val_mae': best_val_mae, 'best_epoch': best_epoch}
    else:
        print("No improvement found during training.")
        final_results = {'best_val_mae': float('inf'), 'best_epoch': -1, 'error': 'No improvement'}

    # Test-set evaluation is done separately by reloading the checkpoint at model_save_path.

    # Clean up
    print("\nCleaning up resources...")
    del model, optimizer, scheduler, train_dataset, val_dataset, train_loader, val_loader
    gc.collect()
    if DEVICE.startswith('cuda'):
        torch.cuda.empty_cache()

    print("\nCombined slice model training finished.")

Starting combined MLP training using concatenated slices...
Target Slices for Concatenation: [{'orientation': 'sagittal', 'index': 80}, {'orientation': 'sagittal', 'index': 125}, {'orientation': 'coronal', 'index': 125}, {'orientation': 'axial', 'index': 80}]
Concatenated Input Dimension: 1024
Using device: cuda:1
Feature Root: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice
CSV Path: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Model will be saved to: combined_slice_model
Hyperparameters: Epochs=500, LR=0.0001, Batch=32, ES_Patience=50

Setting up combined datasets...

[Dataset Init - Combined] Split: train
Target Slices: [{'orientation': 'sagittal', 'index': 80}, {'orientation': 'sagittal', 'index': 125}, {'orientation': 'coronal', 'index': 125}, {'orientation': 'axial', 'index': 80}]
Scanning subjects in: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice/train
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subj

Evaluating the early-fusion model on the test set

In [5]:
import torch
from torch.utils.data import DataLoader
from pathlib import Path
import gc

import torch.nn as nn

# Ensure necessary classes and functions from the previous cell are available
# (AgeMLPWithAttentionBN, CombinedSliceDataset, evaluate)

# --- Configuration (Ensure these are consistent with the training cell) ---
# MODEL_SAVE_PATH is derived from MODELS_OUTPUT_DIR used in the previous cell
MODEL_SAVE_PATH = MODELS_OUTPUT_DIR / "best_combined_model.pth"
TEST_SPLIT = 'test'  # Explicitly define the split for clarity

# --- Test Set Evaluation ---
print("\n--- Evaluating Combined Model on Test Set ---")
print(f"Loading model from: {MODEL_SAVE_PATH}")
print(f"Using test data from: {FEATURE_ROOT} (split: {TEST_SPLIT})")
print(f"CSV Path: {CSV_PATH}")
print(f"Device: {DEVICE}")

# Pre-bind every name the `finally` clause deletes. Otherwise an early failure (before
# `criterion` is assigned) would raise NameError during cleanup and mask the real error.
test_dataset = None
test_loader = None
model = None
criterion = None

try:
    # 1. Check that the trained weights exist before the (slow) dataset scan
    if not MODEL_SAVE_PATH.is_file():
        raise FileNotFoundError(f"Model file not found at {MODEL_SAVE_PATH}. Cannot evaluate.")

    # 2. Create Test Dataset and DataLoader
    print("Setting up test dataset...")
    test_dataset = CombinedSliceDataset(FEATURE_ROOT, CSV_PATH, TEST_SPLIT, TARGET_SLICES)
    print("Setting up test dataloader...")
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    # 3. Initialize Model with the configuration used for training (num_heads=16)
    print("Initializing model...")
    model = AgeMLPWithAttentionBN(input_dim=CONCAT_DIM, embed_dim=CONCAT_DIM, num_heads=16).to(DEVICE)

    # 4. Load Trained Weights (map_location handles CPU/GPU transfers)
    print(f"Loading model state dict from {MODEL_SAVE_PATH}...")
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    model.eval()
    print("Model loaded successfully.")

    # 5. Evaluate with the same loss (MAE) used during training
    criterion = nn.L1Loss()
    print("Starting evaluation on the test set...")
    test_loss, test_mae = evaluate(model, test_loader, criterion, DEVICE)

    # 6. Print Results
    print("\n--- Combined Model Test Set Results ---")
    print(f"Test Loss (MAE): {test_loss:.4f}")
    print(f"Test MAE: {test_mae:.3f}")

    # Attach results to the training cell's summary dict when it exists
    if 'final_results' in locals() and isinstance(final_results, dict):
        final_results['test_mae'] = test_mae
        final_results['test_loss'] = test_loss
        print("Test results added to 'final_results' dictionary.")

except FileNotFoundError as e:
    print(f"\nError: {e}")
    print("Skipping test set evaluation.")
except (ValueError, RuntimeError) as e:
    print(f"\nError during dataset/dataloader creation or evaluation: {e}")
    print("Skipping test set evaluation.")
except Exception as e:
    print(f"\nAn unexpected error occurred during test set evaluation: {e}")
    print("Skipping test set evaluation.")
finally:
    # 7. Clean up
    print("\nCleaning up test evaluation resources...")
    del model, test_loader, test_dataset, criterion
    if 'test_loss' in locals(): del test_loss
    if 'test_mae' in locals(): del test_mae
    gc.collect()
    if DEVICE.startswith('cuda'):
        torch.cuda.empty_cache()
    print("Test evaluation finished.")



--- Evaluating Combined Model on Test Set ---
Loading model from: combined_slice_model/best_combined_model.pth
Using test data from: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice (split: test)
CSV Path: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Device: cuda:1
Setting up test dataset...

[Dataset Init - Combined] Split: test
Target Slices: [{'orientation': 'sagittal', 'index': 80}, {'orientation': 'sagittal', 'index': 125}, {'orientation': 'coronal', 'index': 125}, {'orientation': 'axial', 'index': 80}]
Scanning subjects in: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice/test
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Loaded metadata for 2850 subjects from CSV.
Successfully mapped 296 subject directories in 'test' to CSV entries.
Scanning 296 potential subject directories for all target slices...
Found 296 valid subjects with all required slices in split test.
Se

Feature fusion using multi-head attention

In [10]:
# Define this new model class in the same cell as the dataset and training loop,
# or import it if defined elsewhere.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionFusionMLP(nn.Module):
    """Fuse per-slice embeddings with a learnable attention query, then regress age with an MLP.

    Each sample is the concatenation of ``num_slices`` embeddings of width
    ``embed_dim_per_slice``. One learnable query vector attends over those slice
    embeddings (they serve as both keys and values); the attended vector is
    layer-normalised and passed through three Linear -> BatchNorm1d -> ReLU -> Dropout
    stages and a final Linear that emits one value per sample.

    Args:
        num_slices (int): Number of slice embeddings per sample (e.g., 4).
        embed_dim_per_slice (int): Width of each slice embedding (e.g., 256).
        num_heads (int): Attention heads; must divide ``embed_dim_per_slice``.
        mlp_hidden_dim1, mlp_hidden_dim2, mlp_hidden_dim3 (int): Hidden widths of the MLP head.
        dropout_rate (float): Dropout probability after each hidden stage.

    Raises:
        ValueError: If ``embed_dim_per_slice`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_slices=4, embed_dim_per_slice=256, num_heads=4,
                 mlp_hidden_dim1=128, mlp_hidden_dim2=64, mlp_hidden_dim3=32, dropout_rate=0.3):
        super(AttentionFusionMLP, self).__init__()
        if embed_dim_per_slice % num_heads:
            raise ValueError(f"embed_dim_per_slice ({embed_dim_per_slice}) must be divisible by num_heads ({num_heads})")
        self.num_slices = num_slices
        self.embed_dim = embed_dim_per_slice

        # Creation/registration order below matters: it fixes the state_dict key layout of
        # saved checkpoints and the order in which random initialisers consume the RNG.
        # batch_first=True -> attention inputs are (batch, seq, feature).
        self.fusion_attention = nn.MultiheadAttention(embed_dim=self.embed_dim, num_heads=num_heads, batch_first=True)
        self.norm_attn = nn.LayerNorm(self.embed_dim)
        # Single learnable query, broadcast over the batch in forward().
        self.query_vector = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # MLP head: embed_dim -> h1 -> h2 -> h3 -> 1 (Sequential indices 0..12).
        # Note: BatchNorm1d needs more than one sample per batch in train mode.
        widths = (self.embed_dim, mlp_hidden_dim1, mlp_hidden_dim2, mlp_hidden_dim3)
        stages = []
        for in_width, out_width in zip(widths, widths[1:]):
            stages.extend((nn.Linear(in_width, out_width), nn.BatchNorm1d(out_width), nn.ReLU(), nn.Dropout(dropout_rate)))
        stages.append(nn.Linear(widths[-1], 1))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return one prediction per sample, shape ``(batch_size,)``.

        ``x`` must hold ``num_slices * embed_dim`` values per sample, e.g. ``(batch_size, 1024)``;
        it is viewed as ``(batch_size, num_slices, embed_dim)``.
        """
        n = x.size(0)
        slice_tokens = x.view(n, self.num_slices, self.embed_dim)  # (B, S, E)
        queries = self.query_vector.expand(n, -1, -1)  # (B, 1, E)
        attended, _ = self.fusion_attention(query=queries, key=slice_tokens, value=slice_tokens)  # (B, 1, E)
        fused = self.norm_attn(attended).squeeze(1)  # (B, E)
        return self.mlp(fused).squeeze(-1)  # (B,)


In [7]:
# In the third code cell (or whichever cell contains the main training loop)

# ... (Keep imports, configurations, Dataset definition, train/eval functions) ...
# ... (Define or import AttentionFusionMLP class here) ...

# --- Main Training Loop (Single Model - Now Fusion) ---
# Trains AttentionFusionMLP on the concatenated slice features, checkpoints the epoch with the
# lowest validation MAE, and stops after EARLY_STOPPING_PATIENCE epochs without improvement.
# Expects CombinedSliceDataset, train_one_epoch, evaluate and the config constants to be
# defined earlier in the notebook/file.
if __name__ == "__main__": # Or run directly in notebook cell
    # --- Set Random Seed ---
    SEED = 42
    import random
    import numpy as np
    import torch
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {SEED}")
    # --- End Set Random Seed ---

    # Output directory for the fusion model (rebinds the module-level MODELS_OUTPUT_DIR).
    MODELS_OUTPUT_DIR = Path("./attention_fusion_model")

    print("Starting Attention Fusion MLP training...")
    print(f"Target Slices for Fusion: {TARGET_SLICES}")
    # The dataset still yields concatenated features; the model splits them back into slices.
    print(f"Using device: {DEVICE}")
    print(f"Feature Root: {FEATURE_ROOT}")
    print(f"CSV Path: {CSV_PATH}")
    print(f"Model will be saved to: {MODELS_OUTPUT_DIR}")
    print(f"Hyperparameters: Epochs={EPOCHS}, LR={LEARNING_RATE}, Batch={BATCH_SIZE}, ES_Patience={EARLY_STOPPING_PATIENCE}")

    MODELS_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    model_save_path = MODELS_OUTPUT_DIR / "best_fusion_model.pth"

    try:
        print("\nSetting up combined datasets (still concatenates for loading)...")
        train_dataset = CombinedSliceDataset(FEATURE_ROOT, CSV_PATH, 'train', TARGET_SLICES)
        val_dataset = CombinedSliceDataset(FEATURE_ROOT, CSV_PATH, 'validation', TARGET_SLICES)

        print("Setting up dataloaders...")
        # drop_last=True avoids a final single-sample batch, which BatchNorm1d rejects in train mode.
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"\nError initializing combined datasets/loaders: {e}")
        print("Stopping training.")
        # A bare exit() reports success (status 0) and, in a notebook, may shut the kernel down;
        # SystemExit(1) signals failure and is handled gracefully by IPython.
        raise SystemExit(1) from e

    print("\nInitializing Attention Fusion model, optimizer, scheduler...")
    model = AttentionFusionMLP(
        num_slices=len(TARGET_SLICES),
        embed_dim_per_slice=256, # Assuming 256 from GAP
        num_heads=8, # Must divide embed_dim_per_slice (256 % 8 == 0)
        mlp_hidden_dim1=128,
        mlp_hidden_dim2=64,
        mlp_hidden_dim3=32,
        dropout_rate=0.3
    ).to(DEVICE)

    criterion = nn.L1Loss() # MAE Loss
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler_patience_epochs = max(1, int(EPOCHS * SCHEDULER_PATIENCE_PERCENT))
    # `verbose` is not passed: it defaulted to False anyway and is deprecated in recent PyTorch.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=scheduler_patience_epochs, factor=0.5)

    best_val_mae = float('inf')
    epochs_no_improve = 0
    best_epoch = -1

    print("\n--- Starting Attention Fusion Model Training ---")
    for epoch in range(EPOCHS):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_mae = evaluate(model, val_loader, criterion, DEVICE)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val MAE: {val_mae:.3f} | LR: {current_lr:.1e}")

        # With nn.L1Loss, val_loss presumably equals val_mae (if evaluate averages per sample),
        # so the scheduler and the checkpointing below track the same metric.
        scheduler.step(val_loss)

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            best_epoch = epoch + 1
            epochs_no_improve = 0
            torch.save(model.state_dict(), model_save_path)
            print(f"  -> New best Val MAE: {best_val_mae:.3f}. Saved model to {model_save_path}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping triggered after {EARLY_STOPPING_PATIENCE} epochs without improvement.")
            break

        # Periodically release cached host/GPU memory.
        if (epoch + 1) % 10 == 0:
            gc.collect()
            if DEVICE.startswith('cuda'):
                torch.cuda.empty_cache()

    print("\n--- Attention Fusion Model Training Finished ---")
    if best_epoch != -1:
        print(f"Best Validation MAE: {best_val_mae:.3f} achieved at epoch {best_epoch}")
        print(f"Best model saved to: {model_save_path}")
        final_results = {'best_val_mae': best_val_mae, 'best_epoch': best_epoch}
    else:
        print("No improvement found during training.")
        final_results = {'best_val_mae': float('inf'), 'best_epoch': -1, 'error': 'No improvement'}

    # --- Optional: Evaluate on Test Set (Ensure model instantiation matches) ---
    # if 'test_loader' in locals() and model_save_path.exists():
    #     print(f"\n--- Evaluating Fusion Model on Test Set ---")
    #     try:
    #         # Re-initialize the model structure before loading state dict
    #         model_test = AttentionFusionMLP(
    #             num_slices=len(TARGET_SLICES), embed_dim_per_slice=256, num_heads=8 # Use same params
    #         ).to(DEVICE)
    #         model_test.load_state_dict(torch.load(model_save_path, map_location=DEVICE))
    #         model_test.eval()
    #         test_loss, test_mae = evaluate(model_test, test_loader, criterion, DEVICE) # Use model_test
    #         print(f"Fusion Model Test Loss: {test_loss:.4f} | Test MAE: {test_mae:.3f}")
    #         final_results['test_mae'] = test_mae
    #         del model_test # Clean up test model
    #     except Exception as e:
    #         print(f"Error during test set evaluation: {e}")
    # else:
    #     print("\nSkipping test set evaluation.")

    # Clean up (every name below is bound at this point)
    print("\nCleaning up resources...")
    del model, optimizer, scheduler, train_dataset, val_dataset, train_loader, val_loader
    gc.collect()
    if DEVICE.startswith('cuda'):
        torch.cuda.empty_cache()

    print("\nAttention Fusion model training finished.")


Random seed set to 42
Starting Attention Fusion MLP training...
Target Slices for Fusion: [{'orientation': 'sagittal', 'index': 80}, {'orientation': 'sagittal', 'index': 125}, {'orientation': 'coronal', 'index': 125}, {'orientation': 'axial', 'index': 80}]
Using device: cuda:1
Feature Root: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice
CSV Path: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Model will be saved to: attention_fusion_model
Hyperparameters: Epochs=500, LR=0.0001, Batch=32, ES_Patience=50

Setting up combined datasets (still concatenates for loading)...

[Dataset Init - Combined] Split: train
Target Slices: [{'orientation': 'sagittal', 'index': 80}, {'orientation': 'sagittal', 'index': 125}, {'orientation': 'coronal', 'index': 125}, {'orientation': 'axial', 'index': 80}]
Scanning subjects in: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice/train
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subject

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import os
import gc
from collections import defaultdict # Keep for potential future use if needed

# --- Configuration ---
# Data locations (must be the same preprocessed features/CSV the model was trained on).
FEATURE_ROOT = Path("/data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice")
CSV_PATH = Path("/data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv")
# Pin a specific GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda:1"
SPLITS = ["train", "validation", "test"]  # Dataset split folder names

# Slices fused by the model; their embeddings are concatenated in this order.
TARGET_SLICES = [
    {'orientation': orientation, 'index': index}
    for orientation, index in (
        ('sagittal', 80),
        ('sagittal', 125),
        ('coronal', 125),
        ('axial', 80),
    )
]

# --- Model parameters (MUST match the trained checkpoint) ---
NUM_SLICES_TEST = len(TARGET_SLICES)
EMBED_DIM_PER_SLICE_TEST = 256  # Width of each pooled slice embedding
NUM_HEADS_TEST = 8  # Must equal the num_heads used during training
MLP_HIDDEN_DIM1_TEST, MLP_HIDDEN_DIM2_TEST, MLP_HIDDEN_DIM3_TEST = 128, 64, 32
DROPOUT_RATE_TEST = 0.3  # Inactive in eval mode; kept so the architecture spec matches training
CONCAT_DIM = EMBED_DIM_PER_SLICE_TEST * NUM_SLICES_TEST  # 256 * 4 = 1024

# --- DataLoader settings ---
BATCH_SIZE = 32  # Adjust based on GPU memory
NUM_WORKERS = 4

# --- Test evaluation target ---
TEST_SPLIT = 'test'
FUSION_MODEL_DIR = Path("./attention_fusion_model")  # Where the training cell saved the model
MODEL_SAVE_PATH = FUSION_MODEL_DIR / "best_fusion_model.pth"

# --- Dataset Definition (Modified for Concatenating Slices) ---
class CombinedSliceDataset(Dataset):
    """Per-subject concatenation of pooled slice embeddings, paired with the subject's age.

    A subject directory under ``feature_root / split`` is kept when (a) it can be matched
    to a ``filename`` entry of the metadata CSV and (b) it contains one
    ``slice_<orientation>_<index>.npy`` file for every entry of ``target_slices``.
    ``__getitem__`` returns ``(features, age)`` where ``features`` concatenates the
    globally-average-pooled embeddings in ``target_slices`` order.

    Args:
        feature_root (Path): Base directory containing split folders.
        csv_path (Path): CSV with at least ``filename`` and ``age`` columns.
        split (str): The dataset split ('train', 'validation', or 'test').
        target_slices (list): List of dicts, each specifying {'orientation': str, 'index': int}.

    Raises:
        FileNotFoundError: If the feature root, split directory, or CSV is missing.
        ValueError: If the split cannot be listed or the CSV cannot be read / lacks the columns.
        RuntimeError: If no subject has both metadata and every target slice.
    """

    # CSV filename suffixes, tried in this order so ".nii.gz" is never treated as ".nii".
    _NIFTI_SUFFIXES = (".nii.gz", ".nii")
    # The part of a directory name before this marker (if present) is used as the subject id.
    _SUBJECT_DIR_MARKER = '_mri_brainmask'
    # Required channel width of every pooled slice embedding.
    _EMBED_DIM = 256

    def __init__(self, feature_root, csv_path, split, target_slices):
        self.feature_root = Path(feature_root)
        self.csv_path = Path(csv_path)
        self.split = split
        self.target_slices = target_slices
        self.split_dir = self.feature_root / self.split

        print(f"\n[Dataset Init - Combined] Split: {self.split}")
        print(f"Target Slices: {self.target_slices}")
        print(f"Scanning subjects in: {self.split_dir}")
        print(f"Loading metadata from: {self.csv_path}")

        # Basic checks
        if not self.feature_root.is_dir(): raise FileNotFoundError(f"Feature root not found: {self.feature_root}")
        if not self.split_dir.is_dir(): raise FileNotFoundError(f"Split directory not found: {self.split_dir}")
        if not self.csv_path.is_file(): raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        # iterdir() order is arbitrary; sorting makes sample indices (and therefore seeded
        # shuffling) reproducible across machines and file systems.
        try:
            subject_dirs = sorted((d for d in self.split_dir.iterdir() if d.is_dir()), key=lambda d: d.name)
        except OSError as e:
            raise ValueError(f"Error loading/processing CSV or mapping directories: {e}") from e

        self._load_metadata(subject_dirs)
        self._collect_valid_subjects(subject_dirs)

    def _load_metadata(self, subject_dirs):
        """Populate ``meta_dict`` (CSV filename -> age) and ``subject_dir_to_filename``."""
        try:
            df = pd.read_csv(self.csv_path)
            self.meta_dict = df.set_index('filename')['age'].to_dict()
        except Exception as e:
            raise ValueError(f"Error loading/processing CSV or mapping directories: {e}") from e

        csv_entries = self._nifti_entries(self.meta_dict)
        self.subject_dir_to_filename = {}
        for subject_dir in subject_dirs:
            matched_key = self._match_csv_filename(subject_dir.name, csv_entries)
            if matched_key is not None:
                self.subject_dir_to_filename[subject_dir.name] = matched_key

        mapped_count = len(self.subject_dir_to_filename)
        print(f"Loaded metadata for {len(self.meta_dict)} subjects from CSV.")
        print(f"Successfully mapped {mapped_count} subject directories in '{self.split}' to CSV entries.")
        if mapped_count < len(subject_dirs):
            print(f"Warning: Could not map {len(subject_dirs) - mapped_count} directories to CSV.")

    @classmethod
    def _nifti_entries(cls, filenames):
        """Return sorted ``(stem, suffix_rank, filename)`` tuples for CSV names with a NIfTI suffix.

        ``stem`` is the filename with exactly one trailing suffix removed and ``suffix_rank``
        is its index in ``_NIFTI_SUFFIXES`` (0 for ".nii.gz"). Non-string names (e.g. NaN)
        and names without a NIfTI suffix are skipped.
        """
        entries = []
        for filename in filenames:
            if not isinstance(filename, str):
                continue
            for rank, suffix in enumerate(cls._NIFTI_SUFFIXES):
                if filename.endswith(suffix):
                    entries.append((filename[:-len(suffix)], rank, filename))
                    break
        entries.sort()
        return entries

    @classmethod
    def _match_csv_filename(cls, subj_dir_name, csv_entries):
        """Return the CSV filename for a subject directory, or None if nothing matches.

        Candidates are CSV stems that start with the subject id (the directory name up to
        ``_SUBJECT_DIR_MARKER``). Among several candidates the most specific one wins:
        a stem equal to the directory name, then equal to the subject id, then a stem where
        the id is followed by a non-alphanumeric separator (so "sub-1" prefers "sub-1_..."
        over "sub-10_..."), then any other prefix match. Remaining ties go to ".nii.gz" and
        then to the lexicographically smallest stem, so the result is deterministic.
        """
        subject_id = subj_dir_name.split(cls._SUBJECT_DIR_MARKER)[0]
        if not subject_id:
            return None  # An empty id would prefix-match every CSV row.

        def specificity(entry):
            stem, suffix_rank, _ = entry
            if stem == subj_dir_name:
                tier = 0
            elif stem == subject_id:
                tier = 1
            elif not stem[len(subject_id)].isalnum():  # stem is strictly longer here
                tier = 2
            else:
                tier = 3
            return (tier, suffix_rank, stem)

        candidates = [entry for entry in csv_entries if entry[0].startswith(subject_id)]
        if not candidates:
            return None
        return min(candidates, key=specificity)[2]

    @staticmethod
    def _slice_path(subject_dir, slice_info):
        """Path of the ``.npy`` embedding for one target slice of one subject."""
        return subject_dir / f"slice_{slice_info['orientation']}_{slice_info['index']}.npy"

    def _collect_valid_subjects(self, subject_dirs):
        """Fill ``valid_subject_dirs`` with mapped subjects that have every target slice file.

        Raises:
            RuntimeError: If no subject qualifies.
        """
        self.valid_subject_dirs = []
        subjects_missing_slices = 0
        subjects_without_meta = 0

        print(f"Scanning {len(subject_dirs)} potential subject directories for all target slices...")
        for subject_dir in subject_dirs:
            if subject_dir.name not in self.subject_dir_to_filename:
                subjects_without_meta += 1  # Directory could not be mapped to a CSV row
                continue
            if all(self._slice_path(subject_dir, s).is_file() for s in self.target_slices):
                self.valid_subject_dirs.append(subject_dir)
            else:
                subjects_missing_slices += 1

        if subjects_without_meta > 0:
            print(f"Info: {subjects_without_meta} subject directories were skipped (no metadata mapping).")
        if subjects_missing_slices > 0:
            print(f"Warning: {subjects_missing_slices} subjects with metadata were missing at least one target slice.")

        if not self.valid_subject_dirs:
            raise RuntimeError(f"No subjects found with all required slices in {self.split_dir} with matching metadata.")

        print(f"Found {len(self.valid_subject_dirs)} valid subjects with all required slices in split {self.split}.")

    def __len__(self):
        return len(self.valid_subject_dirs)

    @classmethod
    def _load_pooled_embedding(cls, slice_path):
        """Load one ``.npy`` slice embedding and global-average-pool it to shape ``(C,)``.

        Accepted layouts: ``(1, C, H, W)``, ``(C, H, W)``, or already-pooled ``(C,)``.

        Raises:
            ValueError: On any other layout or if ``C != _EMBED_DIM``.
        """
        embedding_tensor = torch.tensor(np.load(slice_path), dtype=torch.float32)
        if embedding_tensor.dim() == 4 and embedding_tensor.shape[0] == 1:
            pooled = F.adaptive_avg_pool2d(embedding_tensor, (1, 1)).reshape(-1)
        elif embedding_tensor.dim() == 3:  # Assume (C, H, W)
            pooled = F.adaptive_avg_pool2d(embedding_tensor.unsqueeze(0), (1, 1)).reshape(-1)
        elif embedding_tensor.dim() == 1:  # Assume already pooled (C,)
            pooled = embedding_tensor
        else:
            raise ValueError(f"Unexpected embedding shape {embedding_tensor.shape} for {slice_path}")

        # reshape(-1) (instead of squeeze()) keeps the result 1-D even when C == 1, so this
        # check raises the intended ValueError rather than an IndexError on a 0-d tensor.
        if pooled.shape[0] != cls._EMBED_DIM:
            raise ValueError(f"Pooled embedding channel dim is not {cls._EMBED_DIM} for {slice_path}: {pooled.shape}")
        return pooled

    def __getitem__(self, idx):
        """Return ``(features, age)``: concatenated pooled embeddings and a float32 scalar age."""
        subject_dir = self.valid_subject_dirs[idx]
        subject_dir_name = subject_dir.name
        original_filename = self.subject_dir_to_filename.get(subject_dir_name)
        if not original_filename:
            raise ValueError(f"Internal Error: Could not find original filename for valid subject dir {subject_dir_name}")

        try:
            pooled_embeddings = [
                self._load_pooled_embedding(self._slice_path(subject_dir, slice_info))
                for slice_info in self.target_slices
            ]
            # Shape: (len(target_slices) * _EMBED_DIM,), e.g. (1024,) for 4 slices.
            concatenated_embedding = torch.cat(pooled_embeddings, dim=0)
            age_tensor = torch.tensor(self.meta_dict[original_filename], dtype=torch.float32)
            return concatenated_embedding, age_tensor
        except Exception as e:
            print(f"Error loading/processing slices for subject {subject_dir_name}: {e}")
            raise  # Preserve the original traceback

# --- Model Definition (Attention Fusion MLP) ---
class AttentionFusionMLP(nn.Module):
    """Fuse per-slice embeddings with a learnable attention query, then regress age with an MLP.

    Each sample is the concatenation of ``num_slices`` embeddings of width
    ``embed_dim_per_slice``. One learnable query vector attends over those slice
    embeddings (they serve as both keys and values); the attended vector is
    layer-normalised and passed through three Linear -> BatchNorm1d -> ReLU -> Dropout
    stages and a final Linear that emits one value per sample.

    Args:
        num_slices (int): Number of slice embeddings per sample (e.g., 4).
        embed_dim_per_slice (int): Width of each slice embedding (e.g., 256).
        num_heads (int): Attention heads; must divide ``embed_dim_per_slice``.
        mlp_hidden_dim1, mlp_hidden_dim2, mlp_hidden_dim3 (int): Hidden widths of the MLP head.
        dropout_rate (float): Dropout probability after each hidden stage.

    Raises:
        ValueError: If ``embed_dim_per_slice`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_slices=4, embed_dim_per_slice=256, num_heads=4,
                 mlp_hidden_dim1=128, mlp_hidden_dim2=64, mlp_hidden_dim3=32, dropout_rate=0.3):
        super(AttentionFusionMLP, self).__init__()
        if embed_dim_per_slice % num_heads:
            raise ValueError(f"embed_dim_per_slice ({embed_dim_per_slice}) must be divisible by num_heads ({num_heads})")
        self.num_slices = num_slices
        self.embed_dim = embed_dim_per_slice

        # Creation/registration order below matters: it fixes the state_dict key layout of
        # saved checkpoints and the order in which random initialisers consume the RNG.
        # batch_first=True -> attention inputs are (batch, seq, feature).
        self.fusion_attention = nn.MultiheadAttention(embed_dim=self.embed_dim, num_heads=num_heads, batch_first=True)
        self.norm_attn = nn.LayerNorm(self.embed_dim)
        # Single learnable query, broadcast over the batch in forward().
        self.query_vector = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # MLP head: embed_dim -> h1 -> h2 -> h3 -> 1 (Sequential indices 0..12).
        # Note: BatchNorm1d needs more than one sample per batch in train mode.
        widths = (self.embed_dim, mlp_hidden_dim1, mlp_hidden_dim2, mlp_hidden_dim3)
        stages = []
        for in_width, out_width in zip(widths, widths[1:]):
            stages.extend((nn.Linear(in_width, out_width), nn.BatchNorm1d(out_width), nn.ReLU(), nn.Dropout(dropout_rate)))
        stages.append(nn.Linear(widths[-1], 1))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return one prediction per sample, shape ``(batch_size,)``.

        ``x`` must hold ``num_slices * embed_dim`` values per sample, e.g. ``(batch_size, 1024)``;
        it is viewed as ``(batch_size, num_slices, embed_dim)``.
        """
        n = x.size(0)
        slice_tokens = x.view(n, self.num_slices, self.embed_dim)  # (B, S, E)
        queries = self.query_vector.expand(n, -1, -1)  # (B, 1, E)
        attended, _ = self.fusion_attention(query=queries, key=slice_tokens, value=slice_tokens)  # (B, 1, E)
        fused = self.norm_attn(attended).squeeze(1)  # (B, E)
        return self.mlp(fused).squeeze(-1)  # (B,)

# --- Evaluation Function ---
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Compute per-sample average loss and MAE of ``model`` over ``loader`` without gradients.

    Switches ``model`` to eval mode (it is not switched back to train mode).

    Args:
        model (nn.Module): Maps a feature batch to predictions shaped like ``ages``.
        loader: Iterable of ``(features, ages)`` tensor pairs.
        criterion: Loss treated as a per-batch mean (e.g. ``nn.L1Loss()``); it is
            re-weighted by batch size so the result is a per-sample average.
        device: Device the batches are moved to.

    Returns:
        tuple[float, float]: ``(avg_loss, avg_mae)``. Both are ``float('inf')`` when the
        loader yields no samples, so an empty split never looks like a perfect score.
    """
    model.eval()
    total_loss = 0.0
    total_mae = 0.0
    num_samples = 0
    for features, ages in loader:
        features, ages = features.to(device), ages.to(device)
        predictions = model(features)
        batch_size = features.size(0)
        total_loss += criterion(predictions, ages).item() * batch_size
        # Sum of absolute errors for this batch; divided by the sample count at the end.
        total_mae += F.l1_loss(predictions, ages, reduction='sum').item()
        num_samples += batch_size

    if num_samples == 0:
        return float('inf'), float('inf')
    return total_loss / num_samples, total_mae / num_samples

# --- Main Test Set Evaluation Logic ---
# Rebuilds AttentionFusionMLP with the *_TEST hyperparameters, loads the saved checkpoint and
# reports L1 loss / MAE on the test split. Errors are printed and evaluation is skipped.
if __name__ == "__main__": # Ensures this runs when script is executed

    print("\n--- Evaluating Attention Fusion MLP on Test Set ---")
    print(f"Loading model from: {MODEL_SAVE_PATH}")
    print(f"Using test data from: {FEATURE_ROOT} (split: {TEST_SPLIT})")
    print(f"CSV Path: {CSV_PATH}")
    print(f"Device: {DEVICE}")

    # Pre-bound so the cleanup in `finally` can always delete them.
    test_dataset = None
    test_loader = None
    model = None
    criterion = None

    try:
        # 1. Create Test Dataset and DataLoader
        print("\nSetting up test dataset...")
        test_dataset = CombinedSliceDataset(FEATURE_ROOT, CSV_PATH, TEST_SPLIT, TARGET_SLICES)
        print("Setting up test dataloader...")
        test_loader = DataLoader(test_dataset,
                                 batch_size=BATCH_SIZE,
                                 shuffle=False, # Important for evaluation
                                 num_workers=NUM_WORKERS,
                                 pin_memory=True)

        # 2. Initialize Model with the EXACT same parameters as the trained model
        print("\nInitializing AttentionFusionMLP model structure...")
        model = AttentionFusionMLP(
            num_slices=NUM_SLICES_TEST,
            embed_dim_per_slice=EMBED_DIM_PER_SLICE_TEST,
            num_heads=NUM_HEADS_TEST,
            mlp_hidden_dim1=MLP_HIDDEN_DIM1_TEST,
            mlp_hidden_dim2=MLP_HIDDEN_DIM2_TEST,
            mlp_hidden_dim3=MLP_HIDDEN_DIM3_TEST,
            dropout_rate=DROPOUT_RATE_TEST
        ).to(DEVICE)

        # 3. Load Trained Weights (a plain state_dict saved by the training cell)
        if MODEL_SAVE_PATH.is_file():
            print(f"Loading model state dict from {MODEL_SAVE_PATH}...")
            model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
            model.eval() # Set model to evaluation mode
            print("Model loaded successfully.")
        else:
            raise FileNotFoundError(f"Model file not found at {MODEL_SAVE_PATH}. Cannot evaluate.")

        # 4. Define Criterion
        criterion = nn.L1Loss() # MAE Loss

        # 5. Evaluate
        print("\nStarting evaluation on the test set...")
        test_loss, test_mae = evaluate(model, test_loader, criterion, DEVICE)

        # 6. Print Results
        print("\n--- Attention Fusion Model Test Set Results ---")
        print(f"Test Loss (calculated using L1Loss): {test_loss:.4f}")
        print(f"Test MAE: {test_mae:.3f}")

    except FileNotFoundError as e:
        print(f"\nError: {e}")
        print("Skipping test set evaluation.")
    # FileNotFoundError is handled above, so it is not repeated here (it would be unreachable).
    except (ValueError, RuntimeError) as e:
        print(f"\nError during dataset/dataloader creation or evaluation: {e}")
        print("Skipping test set evaluation.")
    except Exception as e:
        print(f"\nAn unexpected error occurred during test set evaluation: {e}")
        print("Skipping test set evaluation.")
    finally:
        # 7. Clean up
        print("\nCleaning up test evaluation resources...")
        del model, test_loader, test_dataset, criterion  # always bound (possibly to None)
        # Only bound if evaluation completed.
        if 'test_loss' in locals(): del test_loss
        if 'test_mae' in locals(): del test_mae
        gc.collect()
        if DEVICE.startswith('cuda'):
            torch.cuda.empty_cache()
        print("Test evaluation finished.")


--- Evaluating Attention Fusion MLP on Test Set ---
Loading model from: attention_fusion_model/best_fusion_model.pth
Using test data from: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice (split: test)
CSV Path: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Device: cuda:1

Setting up test dataset...

[Dataset Init - Combined] Split: test
Target Slices: [{'orientation': 'sagittal', 'index': 80}, {'orientation': 'sagittal', 'index': 125}, {'orientation': 'coronal', 'index': 125}, {'orientation': 'axial', 'index': 80}]
Scanning subjects in: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_multi_slice/test
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Loaded metadata for 2850 subjects from CSV.
Successfully mapped 296 subject directories in 'test' to CSV entries.
Scanning 296 potential subject directories for all target slices...
Found 296 valid subjects with all required slices in split t