Train one MLP for each slice. Average the predictions of the 11 MLPs

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import torch.nn.functional as F
import os
import gc
from collections import defaultdict

# --- Configuration (Should match training script) ---
FEATURE_ROOT = Path("/data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_coronal_120_130")
CSV_PATH = Path("/data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv")
MODELS_OUTPUT_DIR = Path("./slice_specific_models") # Directory where best models are saved
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32 # Use the same batch size as validation/training if possible, or adjust
NUM_WORKERS = 4

# Define the target coronal slices (Must match training and preprocessing)
CORONAL_SLICE_CENTER = 125
CORONAL_SLICE_RANGE = 5
TARGET_CORONAL_INDICES = list(range(CORONAL_SLICE_CENTER - CORONAL_SLICE_RANGE,
                                    CORONAL_SLICE_CENTER + CORONAL_SLICE_RANGE + 1)) # e.g., 120 to 130

# --- Model Definition (Ensure this class definition is available) ---
class AgeMLPWithAttentionBN(nn.Module):
    # ... (Paste the full class definition here as in the training script) ...
    def __init__(self, input_dim=256, embed_dim=256, num_heads=8,
                 hidden_dim1=128, hidden_dim2=64, hidden_dim3=32, dropout_rate=0.3):
        super(AgeMLPWithAttentionBN, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=False)
        self.norm_attn = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1), nn.BatchNorm1d(hidden_dim1), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim1, hidden_dim2), nn.BatchNorm1d(hidden_dim2), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim2, hidden_dim3), nn.BatchNorm1d(hidden_dim3), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim3, 1)
        )

    def forward(self, x):
        x_attn = x.unsqueeze(0)
        attn_output, _ = self.attention(x_attn, x_attn, x_attn)
        x = self.norm_attn(x_attn + attn_output)
        x = x.squeeze(0)
        output = self.mlp(x)
        return output.squeeze(-1)


# --- Dataset Definition (Modified to return subject ID) ---
class SliceAgePredictionDataset(Dataset):
    # ... (Paste the full class definition here, but modify __getitem__) ...
    def __init__(self, feature_root, csv_path, split, slice_index):
        self.feature_root = Path(feature_root)
        self.csv_path = Path(csv_path)
        self.split = split
        self.slice_index = slice_index
        self.split_dir = self.feature_root / self.split

        if not self.split_dir.is_dir():
            raise FileNotFoundError(f"Split directory not found: {self.split_dir}")
        if not self.csv_path.is_file():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        try:
            df = pd.read_csv(self.csv_path)
            self.meta_dict = df.set_index('filename')['age'].to_dict()
            self.subject_id_to_filename = {
                fname.replace(".nii.gz", "").replace(".nii", ""): fname
                for fname in self.meta_dict.keys()
            }
        except Exception as e:
            raise ValueError(f"Error loading or processing CSV {self.csv_path}: {e}")

        all_subject_dirs = [d for d in self.split_dir.iterdir() if d.is_dir()]
        self.valid_slice_files = []
        self.file_to_subject_id = {} # Map file path back to subject_id

        for subject_dir in all_subject_dirs:
            subject_id = subject_dir.name
            slice_filename = f"slice_{self.slice_index}.npy"
            expected_slice_path = subject_dir / slice_filename
            original_filename = self.subject_id_to_filename.get(subject_id)

            if original_filename and original_filename in self.meta_dict:
                if expected_slice_path.is_file():
                    self.valid_slice_files.append(expected_slice_path)
                    self.file_to_subject_id[expected_slice_path] = subject_id # Store mapping

        if not self.valid_slice_files:
             raise RuntimeError(f"No valid slice files found for slice {self.slice_index} in {self.split_dir} with matching metadata.")
        print(f"Found {len(self.valid_slice_files)} valid files for slice {self.slice_index} in split {self.split}.")


    def __len__(self):
        return len(self.valid_slice_files)

    def __getitem__(self, idx):
        slice_path = self.valid_slice_files[idx]
        subject_id = self.file_to_subject_id[slice_path] # Retrieve subject_id
        original_filename = self.subject_id_to_filename.get(subject_id)

        if not original_filename:
             raise ValueError(f"Could not find original filename for subject ID {subject_id}")

        try:
            embedding = np.load(slice_path)
            embedding_tensor = torch.tensor(embedding, dtype=torch.float32)

            if embedding_tensor.shape == (1, 256, 64, 64):
                 pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor, (1, 1)).squeeze()
            elif embedding_tensor.shape == (256, 64, 64):
                 pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor.unsqueeze(0), (1, 1)).squeeze()
            elif embedding_tensor.shape == (256,):
                 pooled_embedding = embedding_tensor
            else:
                 raise ValueError(f"Unexpected embedding shape {embedding_tensor.shape} for {slice_path}")

            if pooled_embedding.shape[0] != 256:
                 raise ValueError(f"Pooled embedding shape is not 256 for {slice_path}: {pooled_embedding.shape}")

            age = self.meta_dict[original_filename]
            age_tensor = torch.tensor(age, dtype=torch.float32)

            # Return subject_id along with features and age
            return pooled_embedding, age_tensor, subject_id

        except Exception as e:
            print(f"Error loading or processing file {slice_path}: {e}")
            raise e

# --- Evaluation Function (Modified to return predictions and IDs) ---
@torch.no_grad()
def predict_on_test_set(model, loader, device):
    model.eval()
    all_predictions = []
    all_ages = []
    all_subject_ids = []

    for features, ages, subject_ids in loader:
        features = features.to(device)
        # ages = ages.to(device) # No need to move ages if only used on CPU later

        predictions = model(features)

        all_predictions.extend(predictions.cpu().numpy())
        all_ages.extend(ages.cpu().numpy())
        all_subject_ids.extend(subject_ids) # subject_ids are already strings

    return np.array(all_predictions), np.array(all_ages), all_subject_ids


# --- Main Testing Logic ---
print(f"Starting testing on device: {DEVICE}")
print(f"Loading models from: {MODELS_OUTPUT_DIR}")
print(f"Testing on slices: {TARGET_CORONAL_INDICES}")

individual_model_maes = {}
# Use defaultdict to easily append predictions per subject
# Stores {subject_id: [pred_slice_1, pred_slice_2, ...]}
all_slice_predictions = defaultdict(list)
# Stores {subject_id: ground_truth_age} - only need to store once
ground_truth_ages = {}

for slice_idx in TARGET_CORONAL_INDICES:
    print(f"\n--- Testing Slice: {slice_idx} ---")
    model_path = MODELS_OUTPUT_DIR / f"best_model_slice_{slice_idx}.pth"

    if not model_path.exists():
        print(f"Warning: Model file not found for slice {slice_idx} at {model_path}. Skipping.")
        individual_model_maes[slice_idx] = np.nan
        continue

    # --- Load Data for the current slice ---
    try:
        test_dataset = SliceAgePredictionDataset(FEATURE_ROOT, CSV_PATH, 'test', slice_idx)
        # Important: Use shuffle=False for testing to keep order consistent if needed,
        # and to correctly map predictions back to subjects.
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                 num_workers=NUM_WORKERS, pin_memory=True)
    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"Error initializing dataset/loader for test split, slice {slice_idx}: {e}")
        print(f"Skipping testing for slice {slice_idx}.")
        individual_model_maes[slice_idx] = np.nan
        continue

    # --- Load Model ---
    model = AgeMLPWithAttentionBN().to(DEVICE)
    try:
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        print(f"Loaded model weights from {model_path}")
    except Exception as e:
        print(f"Error loading model weights for slice {slice_idx}: {e}")
        individual_model_maes[slice_idx] = np.nan
        del model, test_dataset, test_loader
        gc.collect()
        if DEVICE.startswith('cuda'): torch.cuda.empty_cache()
        continue

    # --- Get Predictions ---
    predictions, ages, subject_ids = predict_on_test_set(model, test_loader, DEVICE)

    # --- Calculate and Store Individual MAE ---
    if len(predictions) > 0:
        mae = np.mean(np.abs(predictions - ages))
        individual_model_maes[slice_idx] = mae
        print(f"Slice {slice_idx} Test MAE: {mae:.3f}")

        # --- Store predictions and ground truth for averaging ---
        for i, subj_id in enumerate(subject_ids):
            all_slice_predictions[subj_id].append(predictions[i])
            # Store ground truth age only once per subject
            if subj_id not in ground_truth_ages:
                ground_truth_ages[subj_id] = ages[i]
    else:
        print(f"No predictions generated for slice {slice_idx}.")
        individual_model_maes[slice_idx] = np.nan


    # Clean up GPU memory
    del model, test_dataset, test_loader, predictions, ages, subject_ids
    gc.collect()
    if DEVICE.startswith('cuda'):
        torch.cuda.empty_cache()


# --- Calculate MAE for Averaged Predictions ---
print("\n--- Calculating MAE for Averaged Predictions ---")
average_predictions = []
corresponding_ground_truths = []
subjects_with_complete_preds = 0

# Check subjects that have predictions for *all* target slices
num_target_slices = len(TARGET_CORONAL_INDICES)
subject_ids_in_order = sorted(all_slice_predictions.keys()) # Ensure consistent order

for subj_id in subject_ids_in_order:
    preds = all_slice_predictions[subj_id]
    if len(preds) == num_target_slices: # Ensure we have a prediction from each slice model
        avg_pred = np.mean(preds)
        average_predictions.append(avg_pred)
        corresponding_ground_truths.append(ground_truth_ages[subj_id])
        subjects_with_complete_preds += 1
    else:
        print(f"Warning: Subject {subj_id} has {len(preds)} predictions, expected {num_target_slices}. Skipping for average MAE calculation.")

if subjects_with_complete_preds > 0:
    average_predictions = np.array(average_predictions)
    corresponding_ground_truths = np.array(corresponding_ground_truths)
    average_mae = np.mean(np.abs(average_predictions - corresponding_ground_truths))
    print(f"\nNumber of subjects with complete predictions across all {num_target_slices} slices: {subjects_with_complete_preds}")
    print(f"MAE using averaged predictions: {average_mae:.3f}")
else:
    print("\nCould not calculate average MAE: No subjects had predictions for all required slices.")


# --- Final Summary ---
print("\n--- Individual Slice Model Test MAEs ---")
for slice_idx in TARGET_CORONAL_INDICES:
    mae = individual_model_maes.get(slice_idx, 'N/A')
    if isinstance(mae, float) and not np.isnan(mae):
        print(f"Slice {slice_idx}: {mae:.3f}")
    else:
        print(f"Slice {slice_idx}: {mae}")


Starting testing on device: cuda:1
Loading models from: slice_specific_models
Testing on slices: [120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130]

--- Testing Slice: 120 ---
Found 296 valid files for slice 120 in split test.
Loaded model weights from slice_specific_models/best_model_slice_120.pth
Slice 120 Test MAE: 7.078

--- Testing Slice: 121 ---
Found 296 valid files for slice 121 in split test.
Loaded model weights from slice_specific_models/best_model_slice_121.pth
Slice 121 Test MAE: 7.898

--- Testing Slice: 122 ---
Found 296 valid files for slice 122 in split test.
Loaded model weights from slice_specific_models/best_model_slice_122.pth
Slice 122 Test MAE: 7.434

--- Testing Slice: 123 ---
Found 296 valid files for slice 123 in split test.
Loaded model weights from slice_specific_models/best_model_slice_123.pth
Slice 123 Test MAE: 7.943

--- Testing Slice: 124 ---
Found 296 valid files for slice 124 in split test.
Loaded model weights from slice_specific_models/best_mod

Early fusion with concatenation

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import torch.nn.functional as F
import os
import gc
from tqdm import tqdm
import random

# --- Configuration ---
FEATURE_ROOT = Path("/data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_coronal_120_130")
CSV_PATH = Path("/data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv")
MODEL_SAVE_PATH = Path("./concatenated_mlp_model.pth") # Path to save the best model
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"
SPLITS = ["train", "validation", "test"]

# Define the target coronal slices (Must match preprocessing)
CORONAL_SLICE_CENTER = 125
CORONAL_SLICE_RANGE = 5
TARGET_CORONAL_INDICES = list(range(CORONAL_SLICE_CENTER - CORONAL_SLICE_RANGE,
                                    CORONAL_SLICE_CENTER + CORONAL_SLICE_RANGE + 1)) # e.g., 120 to 130
NUM_SLICES = len(TARGET_CORONAL_INDICES)
FEATURE_DIM_PER_SLICE = 256 # After Global Average Pooling
CONCAT_FEATURE_DIM = NUM_SLICES * FEATURE_DIM_PER_SLICE # e.g., 11 * 256 = 2816

# --- Hyperparameters ---
EPOCHS = 500
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
WEIGHT_DECAY = 1e-5
DROPOUT_RATE = 0.4 # Might need adjustment for the larger input dim
SCHEDULER_PATIENCE_PERCENT = 0.10
EARLY_STOPPING_PATIENCE = 50
NUM_WORKERS = 4

# --- Dataset Definition (Loads and Concatenates Features) ---
class ConcatenatedSliceDataset(Dataset):
    def __init__(self, feature_root, csv_path, split, slice_indices):
        """
        Args:
            feature_root (Path): Base directory containing split folders.
            csv_path (Path): Path to the CSV file with metadata.
            split (str): The dataset split ('train', 'validation', or 'test').
            slice_indices (list[int]): List of slice indices to load and concatenate.
        """
        self.feature_root = Path(feature_root)
        self.csv_path = Path(csv_path)
        self.split = split
        self.slice_indices = sorted(slice_indices) # Ensure consistent order
        self.num_slices = len(self.slice_indices)
        self.split_dir = self.feature_root / self.split

        print(f"\n[Dataset Init] Split: {self.split}, Concatenating Slices: {self.slice_indices}")
        print(f"Loading features from base: {self.feature_root}")
        print(f"Loading metadata from: {self.csv_path}")

        if not self.split_dir.is_dir():
            raise FileNotFoundError(f"Split directory not found: {self.split_dir}")
        if not self.csv_path.is_file():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        try:
            df = pd.read_csv(self.csv_path)
            self.meta_dict = df.set_index('filename')['age'].to_dict()
            self.subject_id_to_filename = {
                fname.replace(".nii.gz", "").replace(".nii", ""): fname
                for fname in self.meta_dict.keys()
            }
            print(f"Loaded metadata for {len(self.meta_dict)} subjects from CSV.")
        except Exception as e:
            raise ValueError(f"Error loading or processing CSV {self.csv_path}: {e}")

        # Find subject directories that have metadata and *all* required slice files
        all_subject_dirs = [d for d in self.split_dir.iterdir() if d.is_dir()]
        self.valid_subject_dirs = []
        missing_meta_count = 0
        incomplete_slice_count = 0

        print(f"Scanning {len(all_subject_dirs)} potential subject directories...")

        for subject_dir in all_subject_dirs:
            subject_id = subject_dir.name
            original_filename = self.subject_id_to_filename.get(subject_id)

            if original_filename and original_filename in self.meta_dict:
                # Check if all required slice files exist for this subject
                all_slices_present = True
                for slice_idx in self.slice_indices:
                    slice_filename = f"slice_{slice_idx}.npy"
                    expected_slice_path = subject_dir / slice_filename
                    if not expected_slice_path.is_file():
                        all_slices_present = False
                        break # No need to check further slices for this subject

                if all_slices_present:
                    self.valid_subject_dirs.append(subject_dir)
                else:
                    incomplete_slice_count += 1
            else:
                missing_meta_count += 1

        if missing_meta_count > 0:
            print(f"Warning: {missing_meta_count} subject directories did not have corresponding metadata.")
        if incomplete_slice_count > 0:
            print(f"Warning: {incomplete_slice_count} subjects with metadata were missing one or more required slices.")

        if not self.valid_subject_dirs:
             raise RuntimeError(f"No subjects found in {self.split_dir} with metadata and all required slices {self.slice_indices}.")

        print(f"Found {len(self.valid_subject_dirs)} valid subjects for split {self.split}.")

    def __len__(self):
        return len(self.valid_subject_dirs)

    def __getitem__(self, idx):
        subject_dir = self.valid_subject_dirs[idx]
        subject_id = subject_dir.name
        original_filename = self.subject_id_to_filename.get(subject_id)

        if not original_filename:
             raise ValueError(f"Could not find original filename for subject ID {subject_id}")

        concatenated_features = []
        try:
            for slice_idx in self.slice_indices:
                slice_filename = f"slice_{slice_idx}.npy"
                slice_path = subject_dir / slice_filename
                embedding = np.load(slice_path)
                embedding_tensor = torch.tensor(embedding, dtype=torch.float32)

                # Apply Global Average Pooling (GAP)
                if embedding_tensor.shape == (1, 256, 64, 64):
                    pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor, (1, 1)).squeeze() # -> (256,)
                elif embedding_tensor.shape == (256, 64, 64):
                    pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor.unsqueeze(0), (1, 1)).squeeze()
                elif embedding_tensor.shape == (256,):
                    pooled_embedding = embedding_tensor
                else:
                    raise ValueError(f"Unexpected embedding shape {embedding_tensor.shape} for {slice_path}")

                if pooled_embedding.shape[0] != FEATURE_DIM_PER_SLICE:
                    raise ValueError(f"Pooled embedding shape is not {FEATURE_DIM_PER_SLICE} for {slice_path}: {pooled_embedding.shape}")

                concatenated_features.append(pooled_embedding)

            # Concatenate features from all slices for this subject
            final_feature_vector = torch.cat(concatenated_features, dim=0) # Shape: (num_slices * 256,)

            # Get the age
            age = self.meta_dict[original_filename]
            age_tensor = torch.tensor(age, dtype=torch.float32)

            return final_feature_vector, age_tensor

        except Exception as e:
            print(f"Error loading or processing slices for subject {subject_dir}: {e}")
            raise e

# --- Model Definition (Takes concatenated features as input) ---
class ConcatenatedAgeMLP(nn.Module):
    def __init__(self, input_dim=CONCAT_FEATURE_DIM, hidden_dim1=1024, hidden_dim2=512, hidden_dim3=256, dropout_rate=DROPOUT_RATE):
        """
        MLP for age prediction using concatenated features from multiple slices.
        Args:
            input_dim (int): Dimension of the concatenated input feature vector (num_slices * feature_dim_per_slice).
            hidden_dim1/2/3 (int): Sizes of the hidden layers.
            dropout_rate (float): Dropout probability.
        """
        super(ConcatenatedAgeMLP, self).__init__()
        self.mlp = nn.Sequential(
            # Layer 1
            nn.Linear(input_dim, hidden_dim1),
            nn.BatchNorm1d(hidden_dim1), # Batch Norm is often helpful
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Layer 2
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.BatchNorm1d(hidden_dim2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Layer 3
            nn.Linear(hidden_dim2, hidden_dim3),
            nn.BatchNorm1d(hidden_dim3),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Output Layer
            nn.Linear(hidden_dim3, 1) # Output is a single value (age)
        )

    def forward(self, x):
        # x shape: (batch_size, input_dim)
        output = self.mlp(x)
        return output.squeeze(-1) # (batch_size,)


# --- Training and Evaluation Functions (Unchanged from previous script) ---
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    num_samples = 0
    for features, ages in loader:
        features, ages = features.to(device), ages.to(device)
        optimizer.zero_grad()
        predictions = model(features)
        loss = criterion(predictions, ages)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * features.size(0)
        num_samples += features.size(0)
    return total_loss / num_samples

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_mae = 0.0
    num_samples = 0
    for features, ages in loader:
        features, ages = features.to(device), ages.to(device)
        predictions = model(features)
        loss = criterion(predictions, ages) # Use the same criterion (e.g., L1Loss)
        total_loss += loss.item() * features.size(0)
        mae = F.l1_loss(predictions, ages, reduction='sum') # Sum MAE for batch
        total_mae += mae.item()
        num_samples += features.size(0)

    avg_loss = total_loss / num_samples
    avg_mae = total_mae / num_samples
    return avg_loss, avg_mae

# --- Main Training Script ---
if __name__ == "__main__":
    print(f"Starting concatenated feature MLP training...")
    print(f"Target Slices: {TARGET_CORONAL_INDICES}")
    print(f"Concatenated Feature Dim: {CONCAT_FEATURE_DIM}")
    print(f"Using device: {DEVICE}")
    print(f"Feature Root: {FEATURE_ROOT}")
    print(f"CSV Path: {CSV_PATH}")
    print(f"Best model will be saved to: {MODEL_SAVE_PATH}")
    print(f"Hyperparameters: Epochs={EPOCHS}, LR={LEARNING_RATE}, Batch={BATCH_SIZE}, Dropout={DROPOUT_RATE}, ES_Patience={EARLY_STOPPING_PATIENCE}")

    # --- Setup Datasets and DataLoaders ---
    try:
        print("Setting up datasets...")
        train_dataset = ConcatenatedSliceDataset(FEATURE_ROOT, CSV_PATH, 'train', TARGET_CORONAL_INDICES)
        val_dataset = ConcatenatedSliceDataset(FEATURE_ROOT, CSV_PATH, 'validation', TARGET_CORONAL_INDICES)
        # Optional: test_dataset = ConcatenatedSliceDataset(FEATURE_ROOT, CSV_PATH, 'test', TARGET_CORONAL_INDICES)

        print("Setting up dataloaders...")
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
        # Optional: test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"Error initializing datasets/loaders: {e}")
        raise SystemExit("Dataset/Loader initialization failed.")

    # --- Setup Model, Loss, Optimizer, Scheduler ---
    print("Initializing model, optimizer, scheduler...")
    model = ConcatenatedAgeMLP(input_dim=CONCAT_FEATURE_DIM, dropout_rate=DROPOUT_RATE).to(DEVICE)
    criterion = nn.L1Loss() # MAE Loss
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    scheduler_patience_epochs = max(1, int(EPOCHS * SCHEDULER_PATIENCE_PERCENT))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=scheduler_patience_epochs, factor=0.5, verbose=True)

    # --- Training Loop ---
    best_val_mae = float('inf')
    epochs_no_improve = 0
    best_epoch = -1
    history = {'train_loss': [], 'val_loss': [], 'val_mae': []}

    print(f"\n--- Starting Training ---")
    for epoch in range(EPOCHS):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_mae = evaluate(model, val_loader, criterion, DEVICE)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_mae'].append(val_mae)

        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val MAE: {val_mae:.3f}")

        scheduler.step(val_loss) # Step scheduler based on validation loss

        # --- Early Stopping & Model Saving ---
        if val_mae < best_val_mae:
            best_val_mae = val_mae
            best_epoch = epoch + 1
            epochs_no_improve = 0
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"  -> New best Val MAE: {best_val_mae:.3f}. Saved model to {MODEL_SAVE_PATH}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping triggered after {EARLY_STOPPING_PATIENCE} epochs without improvement.")
            break

        # Optional: Clean up GPU memory periodically
        if (epoch + 1) % 20 == 0:
             gc.collect()
             if DEVICE.startswith('cuda'):
                 torch.cuda.empty_cache()

    print(f"\n--- Training Finished ---")
    if best_epoch != -1:
         print(f"Best Validation MAE: {best_val_mae:.3f} achieved at epoch {best_epoch}")
         print(f"Best model saved to {MODEL_SAVE_PATH}")
    else:
         print("No improvement found during training. Model not saved.")

    # --- Optional: Evaluate on Test Set ---
    # if 'test_loader' in locals() and MODEL_SAVE_PATH.exists():
    #     print(f"\n--- Evaluating on Test Set using Best Model ---")
    #     try:
    #         model.load_state_dict(torch.load(MODEL_SAVE_PATH))
    #         test_loss, test_mae = evaluate(model, test_loader, criterion, DEVICE)
    #         print(f"Test Loss: {test_loss:.4f} | Test MAE: {test_mae:.3f}")
    #     except Exception as e:
    #         print(f"Error during test set evaluation: {e}")
    # else:
    #     print("\nSkipping test set evaluation.")

    # You can plot history['train_loss'], history['val_loss'], history['val_mae'] here if needed
    # import matplotlib.pyplot as plt
    # plt.figure(...) etc.

    print("\nConcatenated MLP training script finished.")

Starting concatenated feature MLP training...
Target Slices: [120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130]
Concatenated Feature Dim: 2816
Using device: cuda:1
Feature Root: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_coronal_120_130
CSV Path: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Best model will be saved to: concatenated_mlp_model.pth
Hyperparameters: Epochs=500, LR=0.0001, Batch=32, Dropout=0.4, ES_Patience=50
Setting up datasets...

[Dataset Init] Split: train, Concatenating Slices: [120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130]
Loading features from base: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_coronal_120_130
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Loaded metadata for 2850 subjects from CSV.
Scanning 2275 potential subject directories...
Found 2274 valid subjects for split train.

[Dataset Init] Split: validation, Concatenating Slices: [120, 121, 1

Evaluating early fusion MLP

In [4]:
@torch.no_grad()
def test_model(model, test_loader, device):
    """
    Evaluate the trained model on the test dataset.

    Args:
        model (nn.Module): Trained model.
        test_loader (DataLoader): DataLoader for the test dataset.
        device (str): Device to run the evaluation on ('cpu' or 'cuda').

    Returns:
        float: Mean Absolute Error (MAE) on the test dataset.
    """
    model.eval()
    total_mae = 0.0
    num_samples = 0

    for features, ages in test_loader:
        features, ages = features.to(device), ages.to(device)
        predictions = model(features)
        mae = F.l1_loss(predictions, ages, reduction='sum')  # Sum MAE for the batch
        total_mae += mae.item()
        num_samples += features.size(0)

    avg_mae = total_mae / num_samples
    print(f"Test MAE: {avg_mae:.3f}")
    return avg_mae

# Example usage:
# Ensure the model is loaded with the best weights
model.load_state_dict(torch.load(MODEL_SAVE_PATH))
test_dataset = ConcatenatedSliceDataset(FEATURE_ROOT, CSV_PATH, 'test', TARGET_CORONAL_INDICES)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Evaluate the model on the test set
test_mae = test_model(model, test_loader, DEVICE)


[Dataset Init] Split: test, Concatenating Slices: [120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130]
Loading features from base: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_coronal_120_130
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Loaded metadata for 2850 subjects from CSV.
Scanning 296 potential subject directories...
Found 296 valid subjects for split test.
Test MAE: 6.348


Feature Fusion with Multi-head Attention

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import torch.nn.functional as F
import os
import gc
from tqdm import tqdm
import random

# --- Configuration ---
# IMPORTANT: Point to the directory containing the coronal 120-130 slice features
FEATURE_ROOT = Path("/data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_coronal_120_130")
CSV_PATH = Path("/data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv")
MODEL_SAVE_PATH = Path("./fused_mlp_model_avg_pool_coronal_120_130.pth") # Path to save the best model
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"
SPLITS = ["train", "validation", "test"]

# Define the target coronal slices to load and fuse
CORONAL_SLICE_CENTER = 125
CORONAL_SLICE_RANGE = 5
TARGET_CORONAL_INDICES = list(range(CORONAL_SLICE_CENTER - CORONAL_SLICE_RANGE,
                                    CORONAL_SLICE_CENTER + CORONAL_SLICE_RANGE + 1)) # 120 to 130
# Convert indices to strings to match the dataset logic expecting identifiers
TARGET_SLICE_IDENTIFIERS = [str(idx) for idx in TARGET_CORONAL_INDICES]

NUM_SLICES_TO_FUSE = len(TARGET_SLICE_IDENTIFIERS)
FEATURE_DIM_PER_SLICE = 256 # Dimension after Global Average Pooling in preprocessing

# --- Hyperparameters ---
EPOCHS = 500
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
WEIGHT_DECAY = 1e-5
DROPOUT_RATE = 0.3 # Start similar to single-slice models, may need tuning
SCHEDULER_PATIENCE_PERCENT = 0.10
EARLY_STOPPING_PATIENCE = 40
NUM_WORKERS = 4
FUSION_METHOD = 'average' # 'average' or 'max'

# --- Dataset Definition (Loads and Fuses Features) ---
class FusedSliceDataset(Dataset):
    def __init__(self, feature_root, csv_path, split, target_slice_identifiers):
        """
        Args:
            feature_root (Path): Base directory containing split folders with subject subdirs.
            csv_path (Path): Path to the CSV file with metadata.
            split (str): The dataset split ('train', 'validation', or 'test').
            target_slice_identifiers (list[str]): List of slice identifiers (e.g., '120', '121').
        """
        self.feature_root = Path(feature_root)
        self.csv_path = Path(csv_path)
        self.split = split
        self.target_slice_identifiers = sorted(target_slice_identifiers) # Ensure consistent order
        self.num_slices_to_fuse = len(self.target_slice_identifiers)
        self.split_dir = self.feature_root / self.split

        print(f"\n[Dataset Init] Split: {self.split}, Fusing Coronal Slices: {self.target_slice_identifiers}")
        print(f"Loading features from base: {self.feature_root}")
        print(f"Loading metadata from: {self.csv_path}")

        if not self.split_dir.is_dir():
            raise FileNotFoundError(f"Split directory not found: {self.split_dir}")
        if not self.csv_path.is_file():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        try:
            df = pd.read_csv(self.csv_path)
            self.meta_dict = df.set_index('filename')['age'].to_dict()
            self.subject_id_to_filename = {
                fname.replace(".nii.gz", "").replace(".nii", ""): fname
                for fname in self.meta_dict.keys()
            }
            print(f"Loaded metadata for {len(self.meta_dict)} subjects from CSV.")
        except Exception as e:
            raise ValueError(f"Error loading or processing CSV {self.csv_path}: {e}")

        all_subject_dirs = [d for d in self.split_dir.iterdir() if d.is_dir()]
        self.valid_subject_dirs = []
        missing_meta_count = 0
        incomplete_slice_count = 0

        print(f"Scanning {len(all_subject_dirs)} potential subject directories...")

        for subject_dir in all_subject_dirs:
            subject_id = subject_dir.name
            original_filename = self.subject_id_to_filename.get(subject_id)

            if original_filename and original_filename in self.meta_dict:
                all_slices_present = True
                for slice_id in self.target_slice_identifiers:
                    # Construct filename based on the coronal slice index
                    slice_filename = f"slice_{slice_id}.npy"
                    expected_slice_path = subject_dir / slice_filename
                    if not expected_slice_path.is_file():
                        all_slices_present = False
                        break

                if all_slices_present:
                    self.valid_subject_dirs.append(subject_dir)
                else:
                    incomplete_slice_count += 1
            else:
                missing_meta_count += 1

        if missing_meta_count > 0:
            print(f"Warning: {missing_meta_count} subject directories did not have corresponding metadata.")
        if incomplete_slice_count > 0:
            print(f"Warning: {incomplete_slice_count} subjects with metadata were missing one or more required slices.")

        if not self.valid_subject_dirs:
             raise RuntimeError(f"No subjects found in {self.split_dir} with metadata and all required slices {self.target_slice_identifiers}.")

        print(f"Found {len(self.valid_subject_dirs)} valid subjects for split {self.split}.")

    def __len__(self):
        return len(self.valid_subject_dirs)

    def __getitem__(self, idx):
        subject_dir = self.valid_subject_dirs[idx]
        subject_id = subject_dir.name
        original_filename = self.subject_id_to_filename.get(subject_id)

        if not original_filename:
             raise ValueError(f"Could not find original filename for subject ID {subject_id}")

        slice_features = []
        try:
            for slice_id in self.target_slice_identifiers:
                slice_filename = f"slice_{slice_id}.npy"
                slice_path = subject_dir / slice_filename
                embedding = np.load(slice_path)
                embedding_tensor = torch.tensor(embedding, dtype=torch.float32)

                # Apply Global Average Pooling (GAP) if not already done
                if embedding_tensor.ndim == 4 and embedding_tensor.shape[0] == 1 and embedding_tensor.shape[1] == FEATURE_DIM_PER_SLICE:
                    pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor, (1, 1)).squeeze()
                elif embedding_tensor.ndim == 3 and embedding_tensor.shape[0] == FEATURE_DIM_PER_SLICE:
                     pooled_embedding = F.adaptive_avg_pool2d(embedding_tensor.unsqueeze(0), (1, 1)).squeeze()
                elif embedding_tensor.ndim == 1 and embedding_tensor.shape[0] == FEATURE_DIM_PER_SLICE:
                     pooled_embedding = embedding_tensor
                else:
                    raise ValueError(f"Unexpected embedding shape {embedding_tensor.shape} for {slice_path}.")

                if pooled_embedding.shape[0] != FEATURE_DIM_PER_SLICE:
                    raise ValueError(f"Pooled embedding shape is not {FEATURE_DIM_PER_SLICE} for {slice_path}: {pooled_embedding.shape}")

                slice_features.append(pooled_embedding)

            stacked_features = torch.stack(slice_features, dim=0)

            if FUSION_METHOD == 'average':
                fused_feature_vector = torch.mean(stacked_features, dim=0)
            # elif FUSION_METHOD == 'max':
            #     fused_feature_vector = torch.max(stacked_features, dim=0)[0]
            else:
                raise ValueError(f"Unknown fusion method: {FUSION_METHOD}")

            age = self.meta_dict[original_filename]
            age_tensor = torch.tensor(age, dtype=torch.float32)

            return fused_feature_vector, age_tensor

        except Exception as e:
            print(f"Error loading or processing slices for subject {subject_dir}: {e}")
            raise e

# --- Model Definition (Takes FUSED features as input) ---
# Input dim is FEATURE_DIM_PER_SLICE (256 after fusion)
class FusedAgeMLP(nn.Module):
    def __init__(self, input_dim=FEATURE_DIM_PER_SLICE, hidden_dim1=128, hidden_dim2=64, hidden_dim3=32, dropout_rate=DROPOUT_RATE):
        super(FusedAgeMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1), nn.BatchNorm1d(hidden_dim1), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim1, hidden_dim2), nn.BatchNorm1d(hidden_dim2), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim2, hidden_dim3), nn.BatchNorm1d(hidden_dim3), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim3, 1)
        )

    def forward(self, x):
        return self.mlp(x).squeeze(-1)

# --- Training and Evaluation Functions (Identical) ---
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    num_samples = 0
    pbar = tqdm(loader, desc="Training", leave=False)
    for features, ages in pbar:
        features, ages = features.to(device), ages.to(device)
        optimizer.zero_grad()
        predictions = model(features)
        loss = criterion(predictions, ages)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * features.size(0)
        num_samples += features.size(0)
        pbar.set_postfix(loss=loss.item())
    return total_loss / num_samples

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_mae = 0.0
    num_samples = 0
    pbar = tqdm(loader, desc="Evaluating", leave=False)
    for features, ages in pbar:
        features, ages = features.to(device), ages.to(device)
        predictions = model(features)
        loss = criterion(predictions, ages)
        total_loss += loss.item() * features.size(0)
        mae = F.l1_loss(predictions, ages, reduction='sum')
        total_mae += mae.item()
        num_samples += features.size(0)
        pbar.set_postfix(loss=loss.item(), mae=(total_mae/num_samples if num_samples > 0 else 0))

    avg_loss = total_loss / num_samples if num_samples > 0 else 0
    avg_mae = total_mae / num_samples if num_samples > 0 else 0
    return avg_loss, avg_mae

# --- Main Training Script ---
if __name__ == "__main__":
    print(f"Starting feature fusion ({FUSION_METHOD} pooling) MLP training for Coronal slices {TARGET_CORONAL_INDICES}...")
    print(f"Target Slice Identifiers for Fusion: {TARGET_SLICE_IDENTIFIERS}")
    print(f"Fused Feature Dim: {FEATURE_DIM_PER_SLICE}")
    print(f"Using device: {DEVICE}")
    print(f"Feature Root: {FEATURE_ROOT}")
    print(f"CSV Path: {CSV_PATH}")
    print(f"Best model will be saved to: {MODEL_SAVE_PATH}")
    print(f"Hyperparameters: Epochs={EPOCHS}, LR={LEARNING_RATE}, Batch={BATCH_SIZE}, Dropout={DROPOUT_RATE}, ES_Patience={EARLY_STOPPING_PATIENCE}")

    # --- Setup Datasets and DataLoaders ---
    try:
        print("Setting up datasets...")
        train_dataset = FusedSliceDataset(FEATURE_ROOT, CSV_PATH, 'train', TARGET_SLICE_IDENTIFIERS)
        val_dataset = FusedSliceDataset(FEATURE_ROOT, CSV_PATH, 'validation', TARGET_SLICE_IDENTIFIERS)
        test_dataset = FusedSliceDataset(FEATURE_ROOT, CSV_PATH, 'test', TARGET_SLICE_IDENTIFIERS)

        print("Setting up dataloaders...")
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"Error initializing datasets/loaders: {e}")
        raise SystemExit("Dataset/Loader initialization failed.")

    # --- Setup Model, Loss, Optimizer, Scheduler ---
    print("Initializing model, optimizer, scheduler...")
    model = FusedAgeMLP(input_dim=FEATURE_DIM_PER_SLICE, dropout_rate=DROPOUT_RATE).to(DEVICE)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler_patience_epochs = max(1, int(EPOCHS * SCHEDULER_PATIENCE_PERCENT))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=scheduler_patience_epochs, factor=0.5, verbose=True)

    # --- Training Loop ---
    best_val_mae = float('inf')
    epochs_no_improve = 0
    best_epoch = -1
    history = {'train_loss': [], 'val_loss': [], 'val_mae': []}

    print(f"\n--- Starting Training ---")
    for epoch in range(EPOCHS):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_mae = evaluate(model, val_loader, criterion, DEVICE)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_mae'].append(val_mae)

        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val MAE: {val_mae:.3f}")

        scheduler.step(val_loss)

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            best_epoch = epoch + 1
            epochs_no_improve = 0
            try:
                torch.save(model.state_dict(), MODEL_SAVE_PATH)
                print(f"  -> New best Val MAE: {best_val_mae:.3f}. Saved model to {MODEL_SAVE_PATH}")
            except Exception as e:
                print(f"  -> Error saving model: {e}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping triggered after {EARLY_STOPPING_PATIENCE} epochs without improvement.")
            break

        if (epoch + 1) % 20 == 0:
             gc.collect()
             if DEVICE.startswith('cuda'): torch.cuda.empty_cache()

    print(f"\n--- Training Finished ---")
    if best_epoch != -1:
         print(f"Best Validation MAE: {best_val_mae:.3f} achieved at epoch {best_epoch}")
         print(f"Best model saved to {MODEL_SAVE_PATH}")
    else:
         print("No improvement found during training or model saving failed.")

    # --- Evaluate on Test Set ---
    if MODEL_SAVE_PATH.exists() and best_epoch != -1:
        print(f"\n--- Evaluating on Test Set using Best Model ---")
        try:
            model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
            print(f"Loaded best model weights from {MODEL_SAVE_PATH}")
            test_loss, test_mae = evaluate(model, test_loader, criterion, DEVICE)
            print(f"Test Loss: {test_loss:.4f} | Test MAE: {test_mae:.3f}")
        except Exception as e:
            print(f"Error during test set evaluation: {e}")
    elif not MODEL_SAVE_PATH.exists():
         print(f"\nSkipping test set evaluation: Best model file not found at {MODEL_SAVE_PATH}")
    else:
         print("\nSkipping test set evaluation: No best model was saved during training.")

    print(f"\nFeature fusion ({FUSION_METHOD} pooling) MLP training script for Coronal {TARGET_CORONAL_INDICES} finished.")

Starting feature fusion (average pooling) MLP training for Coronal slices [120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130]...
Target Slice Identifiers for Fusion: ['120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '130']
Fused Feature Dim: 256
Using device: cuda:1
Feature Root: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_coronal_120_130
CSV Path: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Best model will be saved to: fused_mlp_model_avg_pool_coronal_120_130.pth
Hyperparameters: Epochs=500, LR=0.0001, Batch=32, Dropout=0.3, ES_Patience=40
Setting up datasets...

[Dataset Init] Split: train, Fusing Coronal Slices: ['120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '130']
Loading features from base: /data/kuang/Projects/MedSAM/data/BrainAGE_preprocessed_coronal_120_130
Loading metadata from: /data/kuang/Projects/MedSAM/data/Subject_demographics_info_brain_age.csv
Loaded metadata for 2850 subjects from 

                                                                              

Epoch 1/500 | Train Loss: 54.2875 | Val Loss: 53.6119 | Val MAE: 53.612
  -> New best Val MAE: 53.612. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 2/500 | Train Loss: 54.1740 | Val Loss: 53.4688 | Val MAE: 53.469
  -> New best Val MAE: 53.469. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 3/500 | Train Loss: 54.0290 | Val Loss: 53.4572 | Val MAE: 53.457
  -> New best Val MAE: 53.457. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 4/500 | Train Loss: 53.9281 | Val Loss: 53.3494 | Val MAE: 53.349
  -> New best Val MAE: 53.349. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 5/500 | Train Loss: 53.8107 | Val Loss: 53.2279 | Val MAE: 53.228
  -> New best Val MAE: 53.228. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 6/500 | Train Loss: 53.6928 | Val Loss: 53.1822 | Val MAE: 53.182
  -> New best Val MAE: 53.182. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 7/500 | Train Loss: 53.5555 | Val Loss: 53.0516 | Val MAE: 53.052
  -> New best Val MAE: 53.052. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 8/500 | Train Loss: 53.4488 | Val Loss: 52.9622 | Val MAE: 52.962
  -> New best Val MAE: 52.962. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 9/500 | Train Loss: 53.3214 | Val Loss: 52.8712 | Val MAE: 52.871
  -> New best Val MAE: 52.871. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 10/500 | Train Loss: 53.1788 | Val Loss: 52.7831 | Val MAE: 52.783
  -> New best Val MAE: 52.783. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 11/500 | Train Loss: 53.0538 | Val Loss: 52.6677 | Val MAE: 52.668
  -> New best Val MAE: 52.668. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 12/500 | Train Loss: 52.9132 | Val Loss: 52.5162 | Val MAE: 52.516
  -> New best Val MAE: 52.516. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 13/500 | Train Loss: 52.7215 | Val Loss: 52.4475 | Val MAE: 52.448
  -> New best Val MAE: 52.448. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 14/500 | Train Loss: 52.6025 | Val Loss: 52.3154 | Val MAE: 52.315
  -> New best Val MAE: 52.315. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 15/500 | Train Loss: 52.4676 | Val Loss: 52.1085 | Val MAE: 52.109
  -> New best Val MAE: 52.109. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 16/500 | Train Loss: 52.2945 | Val Loss: 51.9721 | Val MAE: 51.972
  -> New best Val MAE: 51.972. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 17/500 | Train Loss: 52.1428 | Val Loss: 52.0065 | Val MAE: 52.006


                                                                              

Epoch 18/500 | Train Loss: 51.9902 | Val Loss: 51.8506 | Val MAE: 51.851
  -> New best Val MAE: 51.851. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 19/500 | Train Loss: 51.8015 | Val Loss: 51.6434 | Val MAE: 51.643
  -> New best Val MAE: 51.643. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 20/500 | Train Loss: 51.6541 | Val Loss: 51.4563 | Val MAE: 51.456
  -> New best Val MAE: 51.456. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 21/500 | Train Loss: 51.4812 | Val Loss: 51.2823 | Val MAE: 51.282
  -> New best Val MAE: 51.282. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 22/500 | Train Loss: 51.2659 | Val Loss: 51.1357 | Val MAE: 51.136
  -> New best Val MAE: 51.136. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 23/500 | Train Loss: 51.1043 | Val Loss: 51.0856 | Val MAE: 51.086
  -> New best Val MAE: 51.086. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 24/500 | Train Loss: 50.9023 | Val Loss: 50.8292 | Val MAE: 50.829
  -> New best Val MAE: 50.829. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 25/500 | Train Loss: 50.6954 | Val Loss: 50.6728 | Val MAE: 50.673
  -> New best Val MAE: 50.673. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 26/500 | Train Loss: 50.5354 | Val Loss: 50.2349 | Val MAE: 50.235
  -> New best Val MAE: 50.235. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 27/500 | Train Loss: 50.3050 | Val Loss: 50.2336 | Val MAE: 50.234
  -> New best Val MAE: 50.234. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 28/500 | Train Loss: 50.0587 | Val Loss: 50.0472 | Val MAE: 50.047
  -> New best Val MAE: 50.047. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 29/500 | Train Loss: 49.8743 | Val Loss: 49.8691 | Val MAE: 49.869
  -> New best Val MAE: 49.869. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 30/500 | Train Loss: 49.6627 | Val Loss: 49.6369 | Val MAE: 49.637
  -> New best Val MAE: 49.637. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 31/500 | Train Loss: 49.4367 | Val Loss: 49.1881 | Val MAE: 49.188
  -> New best Val MAE: 49.188. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 32/500 | Train Loss: 49.2005 | Val Loss: 49.0304 | Val MAE: 49.030
  -> New best Val MAE: 49.030. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 33/500 | Train Loss: 48.9465 | Val Loss: 48.9309 | Val MAE: 48.931
  -> New best Val MAE: 48.931. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 34/500 | Train Loss: 48.6744 | Val Loss: 48.6766 | Val MAE: 48.677
  -> New best Val MAE: 48.677. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 35/500 | Train Loss: 48.4317 | Val Loss: 48.5504 | Val MAE: 48.550
  -> New best Val MAE: 48.550. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 36/500 | Train Loss: 48.1727 | Val Loss: 48.3363 | Val MAE: 48.336
  -> New best Val MAE: 48.336. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 37/500 | Train Loss: 47.8285 | Val Loss: 47.9552 | Val MAE: 47.955
  -> New best Val MAE: 47.955. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 38/500 | Train Loss: 47.6634 | Val Loss: 47.5331 | Val MAE: 47.533
  -> New best Val MAE: 47.533. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 39/500 | Train Loss: 47.3074 | Val Loss: 47.7698 | Val MAE: 47.770


                                                                              

Epoch 40/500 | Train Loss: 47.0557 | Val Loss: 46.8265 | Val MAE: 46.826
  -> New best Val MAE: 46.826. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 41/500 | Train Loss: 46.7474 | Val Loss: 46.5191 | Val MAE: 46.519
  -> New best Val MAE: 46.519. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 42/500 | Train Loss: 46.4259 | Val Loss: 46.6477 | Val MAE: 46.648


                                                                              

Epoch 43/500 | Train Loss: 46.1667 | Val Loss: 46.0953 | Val MAE: 46.095
  -> New best Val MAE: 46.095. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 44/500 | Train Loss: 45.8700 | Val Loss: 46.1053 | Val MAE: 46.105


                                                                              

Epoch 45/500 | Train Loss: 45.5613 | Val Loss: 45.5622 | Val MAE: 45.562
  -> New best Val MAE: 45.562. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 46/500 | Train Loss: 45.1427 | Val Loss: 45.0556 | Val MAE: 45.056
  -> New best Val MAE: 45.056. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 47/500 | Train Loss: 44.8288 | Val Loss: 44.7948 | Val MAE: 44.795
  -> New best Val MAE: 44.795. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 48/500 | Train Loss: 44.5475 | Val Loss: 44.4662 | Val MAE: 44.466
  -> New best Val MAE: 44.466. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 49/500 | Train Loss: 44.1818 | Val Loss: 44.3608 | Val MAE: 44.361
  -> New best Val MAE: 44.361. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 50/500 | Train Loss: 43.8799 | Val Loss: 43.9405 | Val MAE: 43.940
  -> New best Val MAE: 43.940. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 51/500 | Train Loss: 43.4296 | Val Loss: 43.7139 | Val MAE: 43.714
  -> New best Val MAE: 43.714. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 52/500 | Train Loss: 43.0883 | Val Loss: 43.2210 | Val MAE: 43.221
  -> New best Val MAE: 43.221. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 53/500 | Train Loss: 42.7773 | Val Loss: 42.5773 | Val MAE: 42.577
  -> New best Val MAE: 42.577. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 54/500 | Train Loss: 42.4658 | Val Loss: 42.3799 | Val MAE: 42.380
  -> New best Val MAE: 42.380. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 55/500 | Train Loss: 42.0371 | Val Loss: 41.9236 | Val MAE: 41.924
  -> New best Val MAE: 41.924. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 56/500 | Train Loss: 41.6840 | Val Loss: 41.7393 | Val MAE: 41.739
  -> New best Val MAE: 41.739. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 57/500 | Train Loss: 41.2728 | Val Loss: 41.2178 | Val MAE: 41.218
  -> New best Val MAE: 41.218. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 58/500 | Train Loss: 40.8780 | Val Loss: 40.9335 | Val MAE: 40.934
  -> New best Val MAE: 40.934. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 59/500 | Train Loss: 40.5315 | Val Loss: 40.3447 | Val MAE: 40.345
  -> New best Val MAE: 40.345. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 60/500 | Train Loss: 40.0427 | Val Loss: 40.2722 | Val MAE: 40.272
  -> New best Val MAE: 40.272. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 61/500 | Train Loss: 39.6840 | Val Loss: 39.7917 | Val MAE: 39.792
  -> New best Val MAE: 39.792. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 62/500 | Train Loss: 39.3937 | Val Loss: 39.1347 | Val MAE: 39.135
  -> New best Val MAE: 39.135. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 63/500 | Train Loss: 38.9201 | Val Loss: 39.0194 | Val MAE: 39.019
  -> New best Val MAE: 39.019. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 64/500 | Train Loss: 38.4393 | Val Loss: 38.2090 | Val MAE: 38.209
  -> New best Val MAE: 38.209. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 65/500 | Train Loss: 38.0817 | Val Loss: 38.0172 | Val MAE: 38.017
  -> New best Val MAE: 38.017. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 66/500 | Train Loss: 37.5893 | Val Loss: 37.5810 | Val MAE: 37.581
  -> New best Val MAE: 37.581. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 67/500 | Train Loss: 37.1697 | Val Loss: 37.3500 | Val MAE: 37.350
  -> New best Val MAE: 37.350. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 68/500 | Train Loss: 36.8667 | Val Loss: 36.4018 | Val MAE: 36.402
  -> New best Val MAE: 36.402. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 69/500 | Train Loss: 36.3615 | Val Loss: 36.0007 | Val MAE: 36.001
  -> New best Val MAE: 36.001. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 70/500 | Train Loss: 36.0545 | Val Loss: 35.2553 | Val MAE: 35.255
  -> New best Val MAE: 35.255. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 71/500 | Train Loss: 35.6027 | Val Loss: 34.8591 | Val MAE: 34.859
  -> New best Val MAE: 34.859. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 72/500 | Train Loss: 35.1319 | Val Loss: 34.6616 | Val MAE: 34.662
  -> New best Val MAE: 34.662. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 73/500 | Train Loss: 34.7332 | Val Loss: 34.4560 | Val MAE: 34.456
  -> New best Val MAE: 34.456. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 74/500 | Train Loss: 34.2823 | Val Loss: 33.9753 | Val MAE: 33.975
  -> New best Val MAE: 33.975. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 75/500 | Train Loss: 33.7968 | Val Loss: 32.7612 | Val MAE: 32.761
  -> New best Val MAE: 32.761. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 76/500 | Train Loss: 33.3087 | Val Loss: 33.0136 | Val MAE: 33.014


                                                                              

Epoch 77/500 | Train Loss: 33.0148 | Val Loss: 32.4875 | Val MAE: 32.488
  -> New best Val MAE: 32.488. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 78/500 | Train Loss: 32.3593 | Val Loss: 32.1971 | Val MAE: 32.197
  -> New best Val MAE: 32.197. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 79/500 | Train Loss: 31.8568 | Val Loss: 30.8787 | Val MAE: 30.879
  -> New best Val MAE: 30.879. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 80/500 | Train Loss: 31.5493 | Val Loss: 31.2089 | Val MAE: 31.209


                                                                              

Epoch 81/500 | Train Loss: 30.9200 | Val Loss: 31.0007 | Val MAE: 31.001


                                                                              

Epoch 82/500 | Train Loss: 30.6956 | Val Loss: 29.9315 | Val MAE: 29.932
  -> New best Val MAE: 29.932. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 83/500 | Train Loss: 30.0242 | Val Loss: 29.1914 | Val MAE: 29.191
  -> New best Val MAE: 29.191. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 84/500 | Train Loss: 29.5182 | Val Loss: 29.5569 | Val MAE: 29.557


                                                                              

Epoch 85/500 | Train Loss: 29.1720 | Val Loss: 28.9956 | Val MAE: 28.996
  -> New best Val MAE: 28.996. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 86/500 | Train Loss: 28.6142 | Val Loss: 28.0000 | Val MAE: 28.000
  -> New best Val MAE: 28.000. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 87/500 | Train Loss: 28.0696 | Val Loss: 27.6230 | Val MAE: 27.623
  -> New best Val MAE: 27.623. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 88/500 | Train Loss: 27.7484 | Val Loss: 27.7554 | Val MAE: 27.755


                                                                              

Epoch 89/500 | Train Loss: 27.3016 | Val Loss: 26.6312 | Val MAE: 26.631
  -> New best Val MAE: 26.631. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 90/500 | Train Loss: 26.7378 | Val Loss: 24.9484 | Val MAE: 24.948
  -> New best Val MAE: 24.948. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 91/500 | Train Loss: 26.3313 | Val Loss: 23.6669 | Val MAE: 23.667
  -> New best Val MAE: 23.667. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 92/500 | Train Loss: 25.7375 | Val Loss: 25.3459 | Val MAE: 25.346


                                                                              

Epoch 93/500 | Train Loss: 25.3498 | Val Loss: 24.0378 | Val MAE: 24.038


                                                                              

Epoch 94/500 | Train Loss: 24.8870 | Val Loss: 23.6777 | Val MAE: 23.678


                                                                              

Epoch 95/500 | Train Loss: 24.2999 | Val Loss: 22.8534 | Val MAE: 22.853
  -> New best Val MAE: 22.853. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 96/500 | Train Loss: 24.0714 | Val Loss: 23.3730 | Val MAE: 23.373


                                                                              

Epoch 97/500 | Train Loss: 23.3747 | Val Loss: 21.7609 | Val MAE: 21.761
  -> New best Val MAE: 21.761. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 98/500 | Train Loss: 22.8993 | Val Loss: 21.5396 | Val MAE: 21.540
  -> New best Val MAE: 21.540. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 99/500 | Train Loss: 22.6974 | Val Loss: 20.6382 | Val MAE: 20.638
  -> New best Val MAE: 20.638. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 100/500 | Train Loss: 22.0128 | Val Loss: 21.3130 | Val MAE: 21.313


                                                                              

Epoch 101/500 | Train Loss: 21.3560 | Val Loss: 22.0422 | Val MAE: 22.042


                                                                              

Epoch 102/500 | Train Loss: 20.9322 | Val Loss: 19.7850 | Val MAE: 19.785
  -> New best Val MAE: 19.785. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 103/500 | Train Loss: 20.4687 | Val Loss: 20.5792 | Val MAE: 20.579


                                                                              

Epoch 104/500 | Train Loss: 19.8896 | Val Loss: 18.9932 | Val MAE: 18.993
  -> New best Val MAE: 18.993. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 105/500 | Train Loss: 19.5454 | Val Loss: 18.9610 | Val MAE: 18.961
  -> New best Val MAE: 18.961. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 106/500 | Train Loss: 18.8247 | Val Loss: 17.3572 | Val MAE: 17.357
  -> New best Val MAE: 17.357. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 107/500 | Train Loss: 18.5210 | Val Loss: 16.6927 | Val MAE: 16.693
  -> New best Val MAE: 16.693. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 108/500 | Train Loss: 17.9580 | Val Loss: 18.1808 | Val MAE: 18.181


                                                                              

Epoch 109/500 | Train Loss: 17.4119 | Val Loss: 16.0821 | Val MAE: 16.082
  -> New best Val MAE: 16.082. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 110/500 | Train Loss: 16.9178 | Val Loss: 16.9096 | Val MAE: 16.910


                                                                              

Epoch 111/500 | Train Loss: 16.5804 | Val Loss: 16.1407 | Val MAE: 16.141


                                                                              

Epoch 112/500 | Train Loss: 16.2098 | Val Loss: 16.0093 | Val MAE: 16.009
  -> New best Val MAE: 16.009. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 113/500 | Train Loss: 15.8606 | Val Loss: 16.2306 | Val MAE: 16.231


                                                                              

Epoch 114/500 | Train Loss: 15.4427 | Val Loss: 15.2010 | Val MAE: 15.201
  -> New best Val MAE: 15.201. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 115/500 | Train Loss: 15.0767 | Val Loss: 13.7068 | Val MAE: 13.707
  -> New best Val MAE: 13.707. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 116/500 | Train Loss: 14.9888 | Val Loss: 14.0311 | Val MAE: 14.031


                                                                              

Epoch 117/500 | Train Loss: 14.2975 | Val Loss: 13.2899 | Val MAE: 13.290
  -> New best Val MAE: 13.290. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 118/500 | Train Loss: 14.1800 | Val Loss: 14.0533 | Val MAE: 14.053


                                                                              

Epoch 119/500 | Train Loss: 13.9522 | Val Loss: 11.8511 | Val MAE: 11.851
  -> New best Val MAE: 11.851. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 120/500 | Train Loss: 13.3630 | Val Loss: 11.9772 | Val MAE: 11.977


                                                                              

Epoch 121/500 | Train Loss: 13.1851 | Val Loss: 12.5152 | Val MAE: 12.515


                                                                              

Epoch 122/500 | Train Loss: 12.7852 | Val Loss: 11.9503 | Val MAE: 11.950


                                                                              

Epoch 123/500 | Train Loss: 12.8144 | Val Loss: 12.1763 | Val MAE: 12.176


                                                                              

Epoch 124/500 | Train Loss: 12.2688 | Val Loss: 11.4118 | Val MAE: 11.412
  -> New best Val MAE: 11.412. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 125/500 | Train Loss: 11.9156 | Val Loss: 10.7049 | Val MAE: 10.705
  -> New best Val MAE: 10.705. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 126/500 | Train Loss: 12.1394 | Val Loss: 10.6405 | Val MAE: 10.640
  -> New best Val MAE: 10.640. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 127/500 | Train Loss: 11.7996 | Val Loss: 9.5095 | Val MAE: 9.509
  -> New best Val MAE: 9.509. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 128/500 | Train Loss: 11.4969 | Val Loss: 10.7671 | Val MAE: 10.767


                                                                              

Epoch 129/500 | Train Loss: 11.4571 | Val Loss: 9.5920 | Val MAE: 9.592


                                                                              

Epoch 130/500 | Train Loss: 11.1453 | Val Loss: 10.6038 | Val MAE: 10.604


                                                                              

Epoch 131/500 | Train Loss: 10.9424 | Val Loss: 9.3645 | Val MAE: 9.365
  -> New best Val MAE: 9.365. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 132/500 | Train Loss: 10.7760 | Val Loss: 9.8142 | Val MAE: 9.814


                                                                              

Epoch 133/500 | Train Loss: 11.0362 | Val Loss: 9.1374 | Val MAE: 9.137
  -> New best Val MAE: 9.137. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 134/500 | Train Loss: 10.6597 | Val Loss: 8.8777 | Val MAE: 8.878
  -> New best Val MAE: 8.878. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 135/500 | Train Loss: 10.4194 | Val Loss: 10.3979 | Val MAE: 10.398


                                                                              

Epoch 136/500 | Train Loss: 10.3621 | Val Loss: 8.5180 | Val MAE: 8.518
  -> New best Val MAE: 8.518. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 137/500 | Train Loss: 10.5719 | Val Loss: 8.3304 | Val MAE: 8.330
  -> New best Val MAE: 8.330. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 138/500 | Train Loss: 10.3023 | Val Loss: 7.8517 | Val MAE: 7.852
  -> New best Val MAE: 7.852. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 139/500 | Train Loss: 10.4329 | Val Loss: 7.9605 | Val MAE: 7.961


                                                                              

Epoch 140/500 | Train Loss: 10.2939 | Val Loss: 7.6513 | Val MAE: 7.651
  -> New best Val MAE: 7.651. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 141/500 | Train Loss: 10.0490 | Val Loss: 7.9325 | Val MAE: 7.932


                                                                              

Epoch 142/500 | Train Loss: 10.0486 | Val Loss: 8.2821 | Val MAE: 8.282


                                                                              

Epoch 143/500 | Train Loss: 9.9898 | Val Loss: 8.8429 | Val MAE: 8.843


                                                                              

Epoch 144/500 | Train Loss: 9.8826 | Val Loss: 7.7119 | Val MAE: 7.712


                                                                              

Epoch 145/500 | Train Loss: 9.7971 | Val Loss: 7.5083 | Val MAE: 7.508
  -> New best Val MAE: 7.508. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 146/500 | Train Loss: 9.8522 | Val Loss: 8.6738 | Val MAE: 8.674


                                                                              

Epoch 147/500 | Train Loss: 9.8264 | Val Loss: 7.7783 | Val MAE: 7.778


                                                                              

Epoch 148/500 | Train Loss: 9.6885 | Val Loss: 7.7470 | Val MAE: 7.747


                                                                              

Epoch 149/500 | Train Loss: 9.9866 | Val Loss: 7.4814 | Val MAE: 7.481
  -> New best Val MAE: 7.481. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 150/500 | Train Loss: 9.5816 | Val Loss: 7.3299 | Val MAE: 7.330
  -> New best Val MAE: 7.330. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 151/500 | Train Loss: 9.8610 | Val Loss: 7.6957 | Val MAE: 7.696


                                                                              

Epoch 152/500 | Train Loss: 9.5754 | Val Loss: 7.3769 | Val MAE: 7.377


                                                                              

Epoch 153/500 | Train Loss: 9.3983 | Val Loss: 7.6800 | Val MAE: 7.680


                                                                              

Epoch 154/500 | Train Loss: 9.8741 | Val Loss: 7.3386 | Val MAE: 7.339


                                                                              

Epoch 155/500 | Train Loss: 9.5094 | Val Loss: 7.4059 | Val MAE: 7.406


                                                                              

Epoch 156/500 | Train Loss: 9.5052 | Val Loss: 8.0821 | Val MAE: 8.082


                                                                              

Epoch 157/500 | Train Loss: 9.4526 | Val Loss: 7.6788 | Val MAE: 7.679


                                                                              

Epoch 158/500 | Train Loss: 9.9306 | Val Loss: 7.4401 | Val MAE: 7.440


                                                                              

Epoch 159/500 | Train Loss: 9.6209 | Val Loss: 7.3518 | Val MAE: 7.352


                                                                              

Epoch 160/500 | Train Loss: 9.8077 | Val Loss: 7.2820 | Val MAE: 7.282
  -> New best Val MAE: 7.282. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 161/500 | Train Loss: 9.3995 | Val Loss: 7.4187 | Val MAE: 7.419


                                                                              

Epoch 162/500 | Train Loss: 9.5777 | Val Loss: 7.4919 | Val MAE: 7.492


                                                                              

Epoch 163/500 | Train Loss: 9.3840 | Val Loss: 7.2244 | Val MAE: 7.224
  -> New best Val MAE: 7.224. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 164/500 | Train Loss: 9.9735 | Val Loss: 7.3761 | Val MAE: 7.376


                                                                              

Epoch 165/500 | Train Loss: 9.6386 | Val Loss: 7.3061 | Val MAE: 7.306


                                                                              

Epoch 166/500 | Train Loss: 10.0165 | Val Loss: 6.9860 | Val MAE: 6.986
  -> New best Val MAE: 6.986. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 167/500 | Train Loss: 9.3355 | Val Loss: 7.5021 | Val MAE: 7.502


                                                                              

Epoch 168/500 | Train Loss: 9.3350 | Val Loss: 7.2044 | Val MAE: 7.204


                                                                              

Epoch 169/500 | Train Loss: 9.5110 | Val Loss: 7.3677 | Val MAE: 7.368


                                                                              

Epoch 170/500 | Train Loss: 9.3733 | Val Loss: 7.5166 | Val MAE: 7.517


                                                                              

Epoch 171/500 | Train Loss: 9.6272 | Val Loss: 7.1386 | Val MAE: 7.139


                                                                              

Epoch 172/500 | Train Loss: 9.5017 | Val Loss: 7.4890 | Val MAE: 7.489


                                                                              

Epoch 173/500 | Train Loss: 9.4376 | Val Loss: 6.9857 | Val MAE: 6.986
  -> New best Val MAE: 6.986. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 174/500 | Train Loss: 9.3915 | Val Loss: 7.0737 | Val MAE: 7.074


                                                                              

Epoch 175/500 | Train Loss: 9.5902 | Val Loss: 7.1989 | Val MAE: 7.199


                                                                              

Epoch 176/500 | Train Loss: 9.1375 | Val Loss: 6.9915 | Val MAE: 6.992


                                                                              

Epoch 177/500 | Train Loss: 9.3452 | Val Loss: 6.9271 | Val MAE: 6.927
  -> New best Val MAE: 6.927. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 178/500 | Train Loss: 9.4251 | Val Loss: 6.9337 | Val MAE: 6.934


                                                                              

Epoch 179/500 | Train Loss: 9.5191 | Val Loss: 7.3349 | Val MAE: 7.335


                                                                              

Epoch 180/500 | Train Loss: 9.4230 | Val Loss: 6.8787 | Val MAE: 6.879
  -> New best Val MAE: 6.879. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 181/500 | Train Loss: 9.2206 | Val Loss: 6.9510 | Val MAE: 6.951


                                                                              

Epoch 182/500 | Train Loss: 9.3509 | Val Loss: 7.1393 | Val MAE: 7.139


                                                                              

Epoch 183/500 | Train Loss: 9.4592 | Val Loss: 7.2990 | Val MAE: 7.299


                                                                              

Epoch 184/500 | Train Loss: 9.6816 | Val Loss: 7.0267 | Val MAE: 7.027


                                                                              

Epoch 185/500 | Train Loss: 8.9257 | Val Loss: 7.0624 | Val MAE: 7.062


                                                                              

Epoch 186/500 | Train Loss: 9.4929 | Val Loss: 7.2032 | Val MAE: 7.203


                                                                              

Epoch 187/500 | Train Loss: 9.1691 | Val Loss: 7.0979 | Val MAE: 7.098


                                                                              

Epoch 188/500 | Train Loss: 9.3250 | Val Loss: 7.0467 | Val MAE: 7.047


                                                                              

Epoch 189/500 | Train Loss: 9.3788 | Val Loss: 6.9255 | Val MAE: 6.926


                                                                              

Epoch 190/500 | Train Loss: 9.4041 | Val Loss: 6.7827 | Val MAE: 6.783
  -> New best Val MAE: 6.783. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 191/500 | Train Loss: 9.2505 | Val Loss: 6.9986 | Val MAE: 6.999


                                                                              

Epoch 192/500 | Train Loss: 9.3126 | Val Loss: 6.6925 | Val MAE: 6.692
  -> New best Val MAE: 6.692. Saved model to fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                              

Epoch 193/500 | Train Loss: 9.2313 | Val Loss: 6.8866 | Val MAE: 6.887


                                                                              

Epoch 194/500 | Train Loss: 9.0571 | Val Loss: 6.9718 | Val MAE: 6.972


                                                                              

Epoch 195/500 | Train Loss: 9.4948 | Val Loss: 7.1458 | Val MAE: 7.146


                                                                              

Epoch 196/500 | Train Loss: 9.2252 | Val Loss: 6.8835 | Val MAE: 6.884


                                                                              

Epoch 197/500 | Train Loss: 9.0976 | Val Loss: 6.8495 | Val MAE: 6.850


                                                                              

Epoch 198/500 | Train Loss: 9.2518 | Val Loss: 6.9405 | Val MAE: 6.941


                                                                              

Epoch 199/500 | Train Loss: 9.2620 | Val Loss: 6.8251 | Val MAE: 6.825


                                                                              

Epoch 200/500 | Train Loss: 9.1508 | Val Loss: 7.0214 | Val MAE: 7.021


                                                                              

Epoch 201/500 | Train Loss: 9.0995 | Val Loss: 6.8459 | Val MAE: 6.846


                                                                              

Epoch 202/500 | Train Loss: 9.2542 | Val Loss: 6.9494 | Val MAE: 6.949


                                                                              

Epoch 203/500 | Train Loss: 9.4390 | Val Loss: 6.8511 | Val MAE: 6.851


                                                                              

Epoch 204/500 | Train Loss: 8.9835 | Val Loss: 6.7988 | Val MAE: 6.799


                                                                              

Epoch 205/500 | Train Loss: 9.4374 | Val Loss: 6.9409 | Val MAE: 6.941


                                                                              

Epoch 206/500 | Train Loss: 9.1314 | Val Loss: 6.8150 | Val MAE: 6.815


                                                                              

Epoch 207/500 | Train Loss: 9.1699 | Val Loss: 7.0878 | Val MAE: 7.088


                                                                              

Epoch 208/500 | Train Loss: 9.3498 | Val Loss: 6.8732 | Val MAE: 6.873


                                                                              

Epoch 209/500 | Train Loss: 9.1226 | Val Loss: 6.9563 | Val MAE: 6.956


                                                                              

Epoch 210/500 | Train Loss: 9.3128 | Val Loss: 7.0457 | Val MAE: 7.046


                                                                              

Epoch 211/500 | Train Loss: 9.0309 | Val Loss: 6.8429 | Val MAE: 6.843


                                                                              

Epoch 212/500 | Train Loss: 9.1723 | Val Loss: 7.0807 | Val MAE: 7.081


                                                                              

Epoch 213/500 | Train Loss: 9.0588 | Val Loss: 6.9640 | Val MAE: 6.964


                                                                              

Epoch 214/500 | Train Loss: 9.1037 | Val Loss: 7.0310 | Val MAE: 7.031


                                                                              

Epoch 215/500 | Train Loss: 9.2086 | Val Loss: 6.8893 | Val MAE: 6.889


                                                                              

Epoch 216/500 | Train Loss: 9.3317 | Val Loss: 6.8486 | Val MAE: 6.849


                                                                              

Epoch 217/500 | Train Loss: 9.0031 | Val Loss: 6.9251 | Val MAE: 6.925


                                                                              

Epoch 218/500 | Train Loss: 8.9669 | Val Loss: 6.9444 | Val MAE: 6.944


                                                                              

Epoch 219/500 | Train Loss: 9.0458 | Val Loss: 6.8939 | Val MAE: 6.894


                                                                              

Epoch 220/500 | Train Loss: 9.1502 | Val Loss: 6.9259 | Val MAE: 6.926


                                                                              

Epoch 221/500 | Train Loss: 9.1316 | Val Loss: 7.0416 | Val MAE: 7.042


                                                                              

Epoch 222/500 | Train Loss: 8.8570 | Val Loss: 6.7551 | Val MAE: 6.755


                                                                              

Epoch 223/500 | Train Loss: 9.1416 | Val Loss: 6.9232 | Val MAE: 6.923


                                                                              

Epoch 224/500 | Train Loss: 9.1907 | Val Loss: 7.0286 | Val MAE: 7.029


                                                                              

Epoch 225/500 | Train Loss: 9.1011 | Val Loss: 6.9183 | Val MAE: 6.918


                                                                              

Epoch 226/500 | Train Loss: 9.0350 | Val Loss: 6.9535 | Val MAE: 6.953


                                                                              

Epoch 227/500 | Train Loss: 9.1556 | Val Loss: 6.9345 | Val MAE: 6.935


                                                                              

Epoch 228/500 | Train Loss: 9.1376 | Val Loss: 6.8911 | Val MAE: 6.891


                                                                              

Epoch 229/500 | Train Loss: 8.8160 | Val Loss: 7.0914 | Val MAE: 7.091


                                                                              

Epoch 230/500 | Train Loss: 9.0500 | Val Loss: 7.3193 | Val MAE: 7.319


                                                                              

Epoch 231/500 | Train Loss: 9.1384 | Val Loss: 6.9956 | Val MAE: 6.996


                                                                              

Epoch 232/500 | Train Loss: 9.0564 | Val Loss: 7.0950 | Val MAE: 7.095

Early stopping triggered after 40 epochs without improvement.

--- Training Finished ---
Best Validation MAE: 6.692 achieved at epoch 192
Best model saved to fused_mlp_model_avg_pool_coronal_120_130.pth

--- Evaluating on Test Set using Best Model ---
Loaded best model weights from fused_mlp_model_avg_pool_coronal_120_130.pth


                                                                               

Test Loss: 6.9427 | Test MAE: 6.943

Feature fusion (average pooling) MLP training script for Coronal [120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130] finished.


