In [None]:
# CELL 1: Setup, Imports, and Configuration
print("--- Cell 1: Setup, Imports, and Configuration ---")

# Core Libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import random
import glob
from math import ceil
import gc
import inspect
from functools import partial
import json
import traceback
import time  # For timing experiments

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score, roc_curve,
    accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve
)
from sklearn.utils import class_weight

# TensorFlow / Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, applications
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.layers import (
    Dense, Dropout, BatchNormalization, GlobalAveragePooling2D, GlobalMaxPooling2D,
    Input, Conv2D, Add, Multiply, Activation, Concatenate, Reshape, Layer, Softmax
)
from tensorflow.keras.layers.experimental import preprocessing as keras_preprocessing
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.mixed_precision import Policy, set_global_policy
import tensorflow_addons as tfa
# print(f"TensorFlow Version: {tf.__version__}")
# print(f"Keras Version: {keras.__version__}")

# --- Configuration ---
# Reproducibility and input geometry: square RGB inputs for DenseNet121.
SEED = 42
IMG_SIZE = 224
IMG_SHAPE = (IMG_SIZE,) * 2 + (3,)

# Batching: the global batch size is derived from this per-replica value
# once the distribution strategy is known.
BATCH_SIZE_PER_REPLICA = 32  # Adjust based on GPU memory

# Two-phase schedule: head-only training first, then fine-tuning.
EPOCHS_HEAD = 30
EPOCHS_FINETUNE = 40
LEARNING_RATE = 1e-4
LEARNING_RATE_FINETUNE = 1e-5

# Regularisation and callback settings.
DROPOUT_RATE = 0.3
PATIENCE_EARLY_STOPPING = 8
PATIENCE_REDUCE_LR = 4
MIN_LR = 1e-7

# Name of the metric used to compare runs (presumably F1 at an optimised
# threshold -- confirm where it is consumed).
TARGET_METRIC = 'f1_opt'

# Experiment tracking: results accumulate in a plain dict.
results = {}

# Output locations under the Kaggle working area.
CHECKPOINT_DIR = "/kaggle/working/checkpoints"  # model checkpoints
GRADCAM_DIR = "/kaggle/working/gradcam_outputs"
METRICS_DIR = "/kaggle/working/saved"
PLOTS_DIR = "/kaggle/working/plots"

# exist_ok=True keeps the cell safe to re-run.
for _output_dir in (CHECKPOINT_DIR, GRADCAM_DIR, METRICS_DIR, PLOTS_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

print(f"Checkpoint directory: {CHECKPOINT_DIR}")
print(f"Grad-CAM output directory: {GRADCAM_DIR}")

# --- Hardware Setup ---
# GPUs are detected first: memory growth must be configured before any GPU is
# initialised, and the precision policy below depends on whether a GPU exists.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Found {len(gpus)} GPUs. Memory growth enabled.")
        # One replica per GPU when several are available.
        if len(gpus) > 1:
            strategy = tf.distribute.MirroredStrategy()
            print(f"Using MirroredStrategy with {strategy.num_replicas_in_sync} devices.")
        else:
            strategy = tf.distribute.get_strategy() # Default strategy for single GPU
            print("Using default strategy for single GPU.")
    except RuntimeError as e:
        print(f"GPU setup error: {e}. Falling back to default strategy.")
        strategy = tf.distribute.get_strategy()
else:
    print("No GPUs found. Using default strategy (CPU).")
    strategy = tf.distribute.get_strategy()

# Mixed precision is only enabled when a GPU is present. The TensorFlow
# mixed-precision guide notes float16 compute gives no speed-up on CPU (it runs
# markedly slower there), so CPU runs explicitly use a float32 policy.
try:
    policy = Policy('mixed_float16' if gpus else 'float32')
    set_global_policy(policy)
    print('Mixed precision %s: Compute dtype=%s, Variable dtype=%s' % (
          'enabled' if gpus else 'disabled (no GPU found)',
          policy.compute_dtype, policy.variable_dtype))
except Exception as e:
    print(f"Could not enable mixed precision: {e}")

# Calculate Global Batch Size (per-replica size times number of replicas)
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
print(f"Global Batch Size: {GLOBAL_BATCH_SIZE}")

# AUTOTUNE for tf.data pipelines
AUTOTUNE = tf.data.AUTOTUNE

# --- Reproducibility ---
# PYTHONHASHSEED is only read when a Python interpreter starts, so setting it
# here does NOT change hash randomisation in this already-running kernel; it is
# only inherited by child processes started afterwards. For full effect, export
# it before launching Jupyter.
os.environ['PYTHONHASHSEED'] = str(SEED)

# Deterministic TF ops (TF_DETERMINISTIC_OPS / TF_CUDNN_DETERMINISTIC) are left
# unset: enabling them caused an UnimplementedError on GPU.
print("Note: TF_DETERMINISTIC_OPS disabled for GPU compatibility. Minor non-determinism may occur.")

# Seed the Python, NumPy and TensorFlow global RNGs.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

print("--- Setup Complete ---")




In [None]:
# CELL 2: Data Loading and Path Definition
print("\n--- Cell 2: Data Loading and Path Definition ---")

# --- Dataset Paths ---
# Resolution order: the Kaggle-mounted dataset first, then a local
# ./chest_xray/ folder (edit `local_path` for other layouts).
try:
    if os.path.exists("/kaggle/input/chest-xray-pneumonia/chest_xray/"):
        BASE_PATH = "/kaggle/input/chest-xray-pneumonia/chest_xray/"
        print("Using Kaggle dataset path.")
    else:
        local_path = "./chest_xray/"  # Local layout - MODIFY AS NEEDED
        if not os.path.exists(local_path):
            raise FileNotFoundError("Dataset base path not found locally or on Kaggle.")
        BASE_PATH = local_path
        print(f"Using local dataset path: {BASE_PATH}")

    # The three split folders expected under BASE_PATH.
    TRAIN_PATH, VAL_PATH, TEST_PATH = (
        os.path.join(BASE_PATH, split_name) for split_name in ("train", "val", "test")
    )

    # Sanity check: warn (without stopping) about missing or empty split folders.
    for p in (TRAIN_PATH, VAL_PATH, TEST_PATH):
        if not os.path.exists(p):
            print(f"WARNING: Dataset directory does not exist: {p}")
            continue
        if not os.listdir(p):
            print(f"WARNING: Dataset directory is empty: {p}")

except FileNotFoundError as e:
    # Deliberately non-fatal: the loading step below reports its own failure.
    print(f"ERROR: {e}")
    print("Please ensure the dataset is available and BASE_PATH is set correctly.")
    # Uncomment to halt execution here instead:
    # raise

# --- Load All Image Paths and Labels ---
def load_image_paths_and_labels(base_dir, label_map=None, extensions=('.jpeg', '.jpg', '.png')):
    """Collect image file paths and integer labels from per-class subdirectories.

    Each immediate subdirectory of ``base_dir`` is one class. Files are matched by
    extension case-insensitively (hidden files are skipped, as ``glob`` does) and
    returned in sorted order. ``glob``/``os.listdir`` order is filesystem-dependent,
    so sorting is what makes the seeded shuffles/splits downstream reproducible.

    Args:
        base_dir (str): Directory containing one subdirectory per class.
        label_map (dict | None): Mapping ``CATEGORY_NAME.upper() -> int``. If None,
            a new map is built from the sorted subdirectory names (0, 1, ...).
            Categories missing from a provided map are skipped with a warning.
        extensions (Iterable[str]): Accepted file extensions, with leading dot.

    Returns:
        tuple: ``(paths, labels, label_map)``. ``label_map`` is the newly built
        mapping when ``label_map`` was None, otherwise ``{}``. Errors while reading
        ``base_dir`` are printed, and whatever was collected so far is returned.
    """
    paths = []
    labels = []
    is_new_map = label_map is None
    if is_new_map:
        label_map = {}
    accepted_exts = {ext.lower() for ext in extensions}

    print(f"Loading data from: {base_dir}")
    try:
        categories = sorted(d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)))
        if not categories:
            print(f"ERROR: No category subdirectories found in {base_dir}")
            return [], [], {}

        for i, category in enumerate(categories):
            key = category.upper()  # Use uppercase for consistency
            if is_new_map:
                label_map[key] = i
            label = label_map.get(key)
            if label is None:
                print(f"Warning: Category '{category}' not found in provided label map. Skipping.")
                continue

            category_path = os.path.join(base_dir, category)
            try:
                entries = os.listdir(category_path)
            except OSError as list_err:
                # Like glob, an unreadable class folder contributes no files.
                print(f"  Warning: could not list '{category_path}': {list_err}")
                entries = []

            image_files = sorted(
                os.path.join(category_path, name)
                for name in entries
                if not name.startswith('.')
                and os.path.splitext(name)[1].lower() in accepted_exts
                and os.path.isfile(os.path.join(category_path, name))
            )

            print(f"  Found {len(image_files)} images for '{category}' (Label: {label})")
            paths.extend(image_files)
            labels.extend([label] * len(image_files))

    except FileNotFoundError:
        print(f"ERROR: Directory not found: {base_dir}")
    except Exception as e:
        print(f"ERROR loading from {base_dir}: {e}")

    if not paths:
        print(f"WARNING: No images found in {base_dir}")

    return paths, labels, (label_map if is_new_map else {})


# Build the label map from TRAIN, then reuse it for VAL and TEST so all three
# folders share the same class -> index assignment.
print("Loading initial paths from TRAIN directory to define labels...")
try:
    train_paths_orig, train_labels_orig, label_dict = load_image_paths_and_labels(TRAIN_PATH)
    if not label_dict:
        raise ValueError("Could not determine label mapping from training data.")

    print("\nLoading paths from VAL directory...")
    val_paths_orig, val_labels_orig, _ = load_image_paths_and_labels(VAL_PATH, label_dict)
    print("\nLoading paths from TEST directory...")
    test_paths_orig, test_labels_orig, _ = load_image_paths_and_labels(TEST_PATH, label_dict)

    # Inverse mapping (index -> class name) for display purposes.
    inv_label_dict = dict(zip(label_dict.values(), label_dict.keys()))
    print(f"\nLabel Mapping: {label_dict}")
    print(f"Inverse Label Mapping: {inv_label_dict}")

    # Pool every folder; the next cell re-splits the combined set.
    all_paths = [*train_paths_orig, *val_paths_orig, *test_paths_orig]
    all_labels = [*train_labels_orig, *val_labels_orig, *test_labels_orig]
    print(f"\nTotal images found across all sets: {len(all_paths)}")

    if not all_paths:
        raise ValueError("No images loaded. Check dataset paths and structure.")
    if len(all_labels) != len(all_paths):
        raise ValueError("Mismatch between loaded image paths and labels count.")

    # The per-folder lists are redundant once pooled.
    del (train_paths_orig, train_labels_orig, val_paths_orig, val_labels_orig,
         test_paths_orig, test_labels_orig)
    gc.collect()

except Exception as e:
    # Non-fatal by design; the next cell checks for `all_paths` before using it.
    print(f"ERROR during data loading: {e}")
    # raise  # uncomment to halt on loading failure



In [None]:
# CELL 3: Stratified Splitting, Class Weights, and Visualization
print("\n--- Cell 3: Stratified Splitting, Class Weights, and Visualization ---")

# Guard: the loading cell reports failures without raising.
if 'all_paths' not in globals() or not all_paths:
    print("ERROR: 'all_paths' not available. Cannot perform splitting. Check Cell 2.")
else:
    print(f"Performing 75/5/20 stratified split on {len(all_paths)} images...")

    # --- Perform Stratified Split ---
    # Stratification needs an array of labels; all fractions refer to the pooled total.
    all_labels_np = np.array(all_labels)
    original_total = len(all_paths)
    val_target_size = 0.05

    try:
        # Stage 1: hold out 20% (stratified by class) as the test set.
        train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
            all_paths, all_labels_np,
            test_size=0.20, random_state=SEED, stratify=all_labels_np, shuffle=True,
        )

        # Stage 2: take validation from what is left, rescaling the fraction so it
        # equals val_target_size of the ORIGINAL total:
        #   val_relative = val_target / (1 - test_actual)
        test_actual_size = len(test_paths) / original_total
        val_relative_size = val_target_size / (1.0 - test_actual_size)

        if val_relative_size >= 1.0 or val_relative_size <= 0:
            print(f"Warning: Calculated relative validation size ({val_relative_size:.3f}) is invalid. Setting validation set to empty.")
            train_paths, val_paths, train_labels, val_labels = train_val_paths, [], train_val_labels, []
        else:
            train_paths, val_paths, train_labels, val_labels = train_test_split(
                train_val_paths, train_val_labels,
                test_size=val_relative_size, random_state=SEED, stratify=train_val_labels, shuffle=True,
            )

        print("\nSplit complete:")
        print(f"  Train Set:      {len(train_paths):>6} images ({len(train_paths)/original_total:7.1%})")
        print(f"  Validation Set: {len(val_paths):>6} images ({len(val_paths)/original_total:7.1%})")
        print(f"  Test Set:       {len(test_paths):>6} images ({len(test_paths)/original_total:7.1%})")
        print(f"  Total Verified: {len(train_paths) + len(val_paths) + len(test_paths):>6} images")

        # Back to plain lists, matching the list inputs from the loading cell.
        train_labels, val_labels, test_labels = list(train_labels), list(val_labels), list(test_labels)

    except ValueError as e:
        print(f"ERROR during train/test split: {e}")
        print("This might happen if a class has only 1 sample. Check dataset balance.")
        raise
    except Exception as e:
        print(f"An unexpected error occurred during splitting: {e}")
        raise

    # Free the pooled inputs and the intermediate train+val pool.
    del all_paths, all_labels, all_labels_np, train_val_paths, train_val_labels
    gc.collect()

    # --- Calculate Class Weights (using the new train_labels) ---
    # "Balanced" weighting, w_c = n_samples / (n_classes * n_c): the same formula
    # as sklearn's compute_class_weight('balanced', ...).
    print("\nCalculating Class Weights for Training Set...")
    unique_classes, class_counts = np.unique(train_labels, return_counts=True)

    if len(unique_classes) < 2:
        print("WARNING: Only one class found in the training set. Class weights set to None.")
        class_weights_dict = None
    else:
        total_samples = len(train_labels)
        num_classes = len(unique_classes)
        weights = total_samples / (num_classes * class_counts)
        # Built-in int/float rather than numpy scalars: json.dump rejects
        # np.int64 keys, and the plain values print cleanly.
        class_weights_dict = {int(cls): float(w) for cls, w in zip(unique_classes, weights)}
        print("Calculated Class Weights:")
        for cls, weight in class_weights_dict.items():
            print(f"  Class {cls} ({inv_label_dict[cls]}): {weight:.4f}")

    # --- Visualize Split Distributions ---
    def plot_split_distributions(split_data, label_map_inv):
        """Plot, for each split, the percentage of its images that fall in each class.

        Args:
            split_data (dict): Split name -> ``(paths, labels)``. Only ``labels`` is
                used; label values are looked up in ``label_map_inv``.
            label_map_inv (dict): Label index -> class name.

        Draws one bar chart per split, side by side, and shows the figure.
        Returns None and draws nothing when ``split_data`` is empty.
        """
        num_splits = len(split_data)
        if num_splits == 0:
            return

        fig, axes = plt.subplots(1, num_splits, figsize=(6 * num_splits, 5), sharey=False)
        if num_splits == 1:
            axes = [axes]  # plt.subplots returns a bare Axes for a single panel

        fig.suptitle('Dataset Split Class Distribution', fontsize=16, y=1.03)
        class_names = sorted(label_map_inv.values())
        palette = sns.color_palette('viridis', n_colors=len(class_names))

        for i, (name, (_paths, labels)) in enumerate(split_data.items()):
            ax = axes[i]
            count = len(labels)
            ax.set_title(f"{name} Set ({count} images)")

            if count > 0:
                counts_series = (
                    pd.Series(labels).map(label_map_inv).value_counts()
                    .reindex(class_names, fill_value=0)
                )
                percentages = (counts_series / count) * 100

                # Plain matplotlib bars, one colour per class. seaborn >= 0.13
                # deprecates barplot(palette=...) without `hue`.
                bars = ax.bar(class_names, percentages.values, color=palette)
                ax.set_ylabel("Percentage (%)" if i == 0 else "")
                ax.set_ylim(0, 105)
                ax.tick_params(axis='x', rotation=0)
                ax.bar_label(bars, fmt='%.1f%%', padding=3, fontsize=9)
            else:
                ax.text(0.5, 0.5, 'No Data', ha='center', va='center', transform=ax.transAxes)
                ax.set_ylim(0, 105)
            ax.grid(axis='y', linestyle='--', alpha=0.7)

        plt.tight_layout(rect=[0, 0, 1, 0.97])
        plt.show()

    # Splits to visualise, in display order (dict insertion order is preserved).
    split_summary_data = dict(
        Train=(train_paths, train_labels),
        Validation=(val_paths, val_labels),
        Test=(test_paths, test_labels),
    )
    plot_split_distributions(split_summary_data, inv_label_dict)




In [None]:
# CELL 4: Augmentation Layer and Preprocessing Functions
print("\n--- Cell 4: Augmentation Layer and Preprocessing Functions ---")

# --- Standard Augmentation Layer ---
# Seeded geometric augmentation from Keras preprocessing layers, so it can run
# on the accelerator. Note that RandomRotation's factor is a fraction of 2*pi:
# 0.1 allows up to about +/-36 degrees.
_augmentation_layers = [
    layers.Input(shape=IMG_SHAPE),
    keras_preprocessing.RandomFlip("horizontal", seed=SEED),
    keras_preprocessing.RandomRotation(0.1, seed=SEED),
    keras_preprocessing.RandomZoom(height_factor=0.1, width_factor=0.1, seed=SEED),  # up to +/-10% zoom
    # keras_preprocessing.RandomContrast(0.1, seed=SEED),   # Optional: slight contrast change
    # keras_preprocessing.RandomBrightness(0.1, seed=SEED), # Optional: slight brightness change
]
standard_augmentation = tf.keras.Sequential(_augmentation_layers, name='standard_augmentation')
del _augmentation_layers

# --- Image Decoding ---
@tf.function
def decode_image(image_bytes):
    """Decode encoded image bytes into a single-channel float32 tensor.

    ``tf.io.decode_image`` detects the format (e.g. JPEG, PNG) from the bytes
    themselves. ``channels=1`` forces grayscale, and ``expand_animations=False``
    guarantees a 3-D ``[H, W, 1]`` result. Values are cast, not rescaled, so they
    stay in the decoder's 0-255 range.
    """
    decoded = tf.io.decode_image(image_bytes, channels=1, expand_animations=False)
    return tf.cast(decoded, tf.float32)

# --- CLAHE Application (using tf.py_function) ---
def apply_cv_clahe_np(image_np, clip_limit, grid_size): # grid_size might arrive as tensor tuple
    """Apply OpenCV CLAHE to one image and return it as float32 ``[H, W, 1]``.

    Meant to be called through ``tf.py_function``, where every argument arrives as
    an eager tensor rather than a NumPy/Python value. All inputs are therefore
    converted explicitly before use (eager tensors have no ``.astype``).

    Args:
        image_np: Image shaped ``[H, W, 1]``, ``[H, W]`` or ``[H, W, 3]`` (RGB),
            nominally in 0-255. Values are clipped to 0-255 before the uint8
            conversion OpenCV requires, so out-of-range values do not wrap around.
        clip_limit: CLAHE contrast clip limit (scalar).
        grid_size: ``(tile_rows, tile_cols)``. Falls back to ``(8, 8)`` if it
            cannot be parsed into two ints.

    Returns:
        np.ndarray: float32 array of shape ``[H, W, 1]``. If CLAHE fails, the
        clipped grayscale input is returned instead. Inputs with an unsupported
        channel count are returned unprocessed as float32.
    """
    image = np.asarray(image_np)  # EagerTensor -> ndarray

    if image.ndim == 2 or (image.ndim == 3 and image.shape[-1] == 1):
        # reshape (not squeeze) so a 1-pixel-high/wide image keeps both axes.
        image_uint8 = np.clip(image.reshape(image.shape[:2]), 0, 255).astype(np.uint8)
    elif image.ndim == 3 and image.shape[-1] == 3:
        print("Warning: Image for CLAHE is not single channel, attempting grayscale conversion.")
        image_uint8 = cv2.cvtColor(np.clip(image, 0, 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
    else:
        print("Warning: Image for CLAHE is not single channel, attempting grayscale conversion.")
        print("Error: Cannot convert image to grayscale for CLAHE.")
        return image.astype(np.float32) # Return original as float

    # OpenCV needs a tuple of Python ints; grid_size may be a tuple, an array or
    # a 1-D tensor.
    try:
        tile_h, tile_w = (int(v) for v in np.asarray(grid_size).reshape(-1))
        cv2_grid_size = (tile_h, tile_w)
    except Exception as e:
        print(f"Warning: Could not parse grid_size ({grid_size}). Using default (8, 8). Error: {e}")
        cv2_grid_size = (8, 8)

    try:
        clahe = cv2.createCLAHE(clipLimit=float(clip_limit), tileGridSize=cv2_grid_size)
        clahe_img = clahe.apply(image_uint8)
    except Exception as cv_e:
        print(f"Error during cv2.createCLAHE or apply: {cv_e}")
        clahe_img = image_uint8  # fall back to the unprocessed grayscale image

    # Restore the channel axis and hand float32 back to TensorFlow.
    return clahe_img[..., np.newaxis].astype(np.float32)
def tf_apply_clahe(image, clip_limit, grid_size=(8, 8)):
    """Graph-compatible CLAHE: run ``apply_cv_clahe_np`` via ``tf.py_function``.

    Args:
        image: float32 image tensor, expected ``[H, W, 1]``.
        clip_limit: CLAHE clip limit, forwarded to the NumPy/OpenCV function.
        grid_size: CLAHE tile grid, forwarded (as a tensor) to the same function.

    Returns:
        float32 tensor whose static shape is set to ``[None, None, 1]``, because
        ``tf.py_function`` output carries no shape information.
    """
    result = tf.py_function(apply_cv_clahe_np, [image, clip_limit, grid_size], tf.float32)
    result.set_shape([None, None, 1])
    return result

# --- Unified Preprocessing Function ---
@tf.function
def preprocess_image(image_path, label, img_size=IMG_SIZE, apply_augment=False, augment_layer=None,
                     apply_clahe=False, clahe_clip_limit=2.0, clahe_grid_size=(8, 8)):
    """Turn one image file into a DenseNet121-ready RGB tensor.

    Pipeline: read -> decode to grayscale float32 -> optional CLAHE (at native
    resolution) -> bilinear resize -> replicate to 3 channels -> optional
    augmentation -> ``densenet.preprocess_input``.

    Args:
        image_path: Scalar string tensor holding the file path.
        label: Returned unchanged.
        img_size: Output height and width.
        apply_augment: Run ``augment_layer`` (only when it is not None).
        augment_layer: Keras layer/model, called on a batch of one with ``training=True``.
        apply_clahe: Apply CLAHE before resizing.
        clahe_clip_limit, clahe_grid_size: CLAHE parameters.

    Returns:
        ``(image, label)``, where ``image`` is ``[img_size, img_size, 3]``.
    """
    grayscale = decode_image(tf.io.read_file(image_path))  # [H, W, 1] float32

    # CLAHE runs before resizing so it sees full-resolution detail.
    if apply_clahe:
        grayscale = tf_apply_clahe(grayscale, clip_limit=clahe_clip_limit, grid_size=clahe_grid_size)

    resized = tf.image.resize(grayscale, [img_size, img_size], method=tf.image.ResizeMethod.BILINEAR)
    rgb = tf.image.grayscale_to_rgb(resized)  # DenseNet expects 3 channels

    if apply_augment and augment_layer is not None:
        # Keras preprocessing layers expect a batch axis: use a batch of one.
        augmented = augment_layer(rgb[tf.newaxis, ...], training=True)
        # Cast back, since the layers may emit float16 under a mixed-precision policy.
        rgb = tf.cast(augmented[0], tf.float32)

    return tf.keras.applications.densenet.preprocess_input(rgb), label

# --- Dataset Building Function ---
def build_dataset(paths, labels, preprocess_fn_base, preprocess_args,
                  batch_size, dataset_name="Dataset",
                  shuffle=False, augment_in_map=False, oversample=False,
                  cache=True):
    """
    Builds a tf.data.Dataset with preprocessing, optional caching, optional
    class-balanced oversampling, optional shuffling, batching and prefetching.

    Args:
        paths (list): List of image file paths.
        labels (list): List of corresponding integer labels.
        preprocess_fn_base (function): Per-example map function called as
            ``fn(path, label, **preprocess_args)`` and returning ``(image, label)``.
        preprocess_args (dict): Keyword arguments bound to ``preprocess_fn_base``.
            A truthy ``'apply_augment'`` entry marks the map as random.
        batch_size (int): Global batch size.
        dataset_name (str): Name used in log messages and in the disk-cache filename.
        shuffle (bool): Whether to shuffle the dataset (typically True for train).
        augment_in_map (bool): Declares that the map applies random augmentation.
            It has the same effect as a truthy ``preprocess_args['apply_augment']``.
        oversample (bool): Balance classes by resampling. Only applied when
            ``shuffle`` is also True and there is more than one class.
        cache (bool or str): True caches in memory. Any string caches to a file in
            CHECKPOINT_DIR named after ``dataset_name``; the string itself is not
            used as a path.

    Returns:
        tf.data.Dataset or None: None when ``paths`` is empty or the lengths differ.

    Notes:
        * Caching is skipped when the map is random. ``Dataset.cache()`` after a
          random map would store the first epoch's augmented images and replay
          them unchanged in every later epoch.
        * With oversampling, each class stream repeats and is sampled with equal
          weight. One pass then yields ``num_classes * largest_class_count``
          examples, more than ``len(paths)``; size ``steps_per_epoch`` accordingly.
    """
    if not paths:
        print(f"WARNING [{dataset_name}]: Empty paths list provided. Returning None.")
        return None
    if len(paths) != len(labels):
        print(f"ERROR [{dataset_name}]: Mismatch paths ({len(paths)}) vs labels ({len(labels)}).")
        return None

    AUTO = tf.data.AUTOTUNE
    unique_cls, class_counts = np.unique(labels, return_counts=True)
    num_classes = len(unique_cls)

    # Bind every preprocessing option up front; the map then receives (path, label).
    map_fn = partial(preprocess_fn_base, **preprocess_args)

    random_map = bool(augment_in_map) or bool(preprocess_args.get('apply_augment', False))
    use_cache = bool(cache) and not random_map
    if cache and random_map:
        print(f"[{dataset_name}] Caching skipped: the map applies random augmentation, "
              f"which caching would freeze after the first epoch.")

    def _cached(stream, tag):
        """Apply the requested cache: memory, or a disk file named after `tag`."""
        if isinstance(cache, str):  # Disk caching
            safe_suffix = "".join(c if c.isalnum() else "_" for c in tag)
            cache_file = os.path.join(CHECKPOINT_DIR, f"tf_cache_{safe_suffix}")  # Use checkpoint dir
            print(f"[{dataset_name}] Caching to disk: {cache_file}")
            return stream.cache(cache_file)
        print(f"[{dataset_name}] Caching to memory.")
        return stream.cache()

    def _has_label(target):
        """Build a filter predicate on (x, label) elements, binding `target` now."""
        return lambda x, label: tf.equal(label, tf.cast(target, label.dtype))

    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    do_oversample = oversample and num_classes > 1 and shuffle  # Only makes sense for training

    if do_oversample and not use_cache:
        # Filter the raw (path, label) pairs before the expensive map, so each
        # image is decoded only by its own class stream.
        class_streams = [
            ds.filter(_has_label(cls)).map(map_fn, num_parallel_calls=AUTO)
            for cls in unique_cls
        ]
    else:
        ds = ds.map(map_fn, num_parallel_calls=AUTO)
        if use_cache:
            ds = _cached(ds, dataset_name)
        class_streams = [ds.filter(_has_label(cls)) for cls in unique_cls] if do_oversample else None

    # --- Oversampling (if enabled, typically only for training set) ---
    if do_oversample:
        print(f"[{dataset_name}] Applying oversampling...")
        # Each class stream must repeat. With finite streams the minority class
        # runs dry first, and sample_from_datasets then just drains the majority
        # stream, so no oversampling actually happens.
        target_dist = [1.0 / num_classes] * num_classes
        ds = tf.data.experimental.sample_from_datasets(
            [stream.repeat() for stream in class_streams], weights=target_dist, seed=SEED
        )
        # Bound the now-infinite stream to one balanced epoch.
        ds = ds.take(int(num_classes * class_counts.max()))
        print(f"[{dataset_name}] Oversampling applied.")

    # Apply shuffling (if enabled, typically for training)
    if shuffle:
        buffer_size = min(len(paths), 5000) # Adjust buffer size based on dataset size/memory
        ds = ds.shuffle(buffer_size=buffer_size, seed=SEED, reshuffle_each_iteration=True)
        print(f"[{dataset_name}] Shuffling applied (buffer={buffer_size}).")

    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=AUTO)

    # Shard by data (not by file) when distributed.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    ds = ds.with_options(options)

    args_str = ", ".join(f"{k}={v}" for k,v in preprocess_args.items())
    print(f"-> [{dataset_name}] Built: items={len(paths)}, shuffle={shuffle}, oversample={oversample}, map_args=({args_str})")
    return ds




In [None]:
# CELL 5: Model Building Functions (DenseNet121, Attention Layers)
print("\n--- Cell 5: Model Building Functions ---")

# --- Attention Layer Implementations ---
# Custom attention blocks, implemented as Keras Layers.
class SelfAttention(Layer):
    """Single-head scaled dot-product self-attention over spatial positions, plus a residual add.

    Every (h, w) location of a ``(batch, H, W, C)`` feature map is treated as a
    token. Queries and keys are projected to ``units`` dims. Values keep ``C``
    channels, so the attended map can be added back onto the input.
    """

    def __init__(self, units=64, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.query_layer = Dense(units, name='query')
        self.key_layer = Dense(units, name='key')
        # The value projection needs the input channel count, known only in
        # build(). Keras 2 rejects Dense(None) outright (it calls int(units)).
        self.value_layer = None
        self.softmax = Softmax(axis=-1)
        self.add = Add()

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is None:
            raise ValueError("Channel dimension must be known for SelfAttention.")
        self.value_layer = Dense(int(channels), name='value')
        super().build(input_shape)

    def call(self, inputs):
        # B = batch, H = height, W = width (dynamic); C = channels (static).
        dynamic_shape = tf.shape(inputs)
        B, H, W = dynamic_shape[0], dynamic_shape[1], dynamic_shape[2]
        C = tf.compat.dimension_value(inputs.shape[-1])

        # Use tf.reshape for dynamic sizes; a Keras Reshape layer expects static ints.
        flattened = tf.reshape(inputs, [B, H * W, C])  # (B, H*W, C)

        q = self.query_layer(flattened)  # (B, H*W, units)
        k = self.key_layer(flattened)    # (B, H*W, units)
        v = self.value_layer(flattened)  # (B, H*W, C)

        scores = tf.matmul(q, k, transpose_b=True)  # (B, H*W, H*W)
        # Scale in the scores' own dtype, so this also works when they are float16
        # under the mixed_float16 policy.
        dk = tf.cast(tf.shape(k)[-1], scores.dtype)
        weights = self.softmax(scores / tf.math.sqrt(dk))  # (B, H*W, H*W)

        attention_output = tf.matmul(weights, v)  # (B, H*W, C)
        output_reshaped = tf.reshape(attention_output, tf.shape(inputs))  # (B, H, W, C)
        return self.add([inputs, output_reshaped])

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config

class ChannelAttention(Layer):
    """Channel attention (as used by CBAM): reweights channels using a shared two-layer MLP.

    The MLP is applied to both the global-average- and global-max-pooled
    descriptors. Their sum goes through a sigmoid and scales ``inputs`` per channel.
    """

    def __init__(self, ratio=8, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio
        self.avg_pool = GlobalAveragePooling2D(keepdims=True)
        self.max_pool = GlobalMaxPooling2D(keepdims=True)
        # Dense layers are created in build(), once the channel count is known.

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is None: raise ValueError("Channel dimension required.")
        channels = int(channels)
        # Keep at least one hidden unit when channels < ratio.
        hidden_units = max(1, channels // self.ratio)
        self.shared_dense_1 = Dense(hidden_units, activation='relu', kernel_initializer='he_normal', use_bias=True, name='ca_dense_1')
        self.shared_dense_2 = Dense(channels, kernel_initializer='he_normal', use_bias=True, name='ca_dense_2')
        super().build(input_shape)

    def call(self, inputs):
        avg_out = self.shared_dense_2(self.shared_dense_1(self.avg_pool(inputs)))  # (B, 1, 1, C)
        max_out = self.shared_dense_2(self.shared_dense_1(self.max_pool(inputs)))  # (B, 1, 1, C)
        # Plain ops instead of creating Activation/Add/Multiply layers on every call.
        attention = tf.sigmoid(avg_out + max_out)
        return inputs * attention

    def get_config(self):
        config = super().get_config()
        config.update({"ratio": self.ratio})
        return config

class SpatialAttention(Layer):
    """Spatial attention: a sigmoid map computed from the channel-wise mean and max, applied multiplicatively."""

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.concat = Concatenate(axis=-1)

    def build(self, input_shape):
        # A single-filter conv turns the 2-channel descriptor into a [0, 1] mask.
        self.conv2d = Conv2D(
            1, kernel_size=self.kernel_size, padding='same', activation='sigmoid',
            kernel_initializer='he_normal', use_bias=False, name='sa_conv',
        )
        super().build(input_shape)

    def call(self, inputs):
        descriptor = self.concat([
            tf.reduce_mean(inputs, axis=-1, keepdims=True),  # (B, H, W, 1) channel mean
            tf.reduce_max(inputs, axis=-1, keepdims=True),   # (B, H, W, 1) channel max
        ])                                                   # (B, H, W, 2)
        attention = self.conv2d(descriptor)                  # (B, H, W, 1)
        return Multiply()([inputs, attention])

    def get_config(self):
        config = super().get_config()
        config.update({"kernel_size": self.kernel_size})
        return config

class CBAM(Layer):
    """Convolutional Block Attention Module: channel attention followed by spatial attention.

    Args:
        ratio: Bottleneck reduction ratio for the channel-attention MLP.
        kernel_size: Kernel size of the spatial-attention convolution.
    """

    def __init__(self, ratio=8, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.ratio, self.kernel_size = ratio, kernel_size
        self.channel_attn = ChannelAttention(ratio=ratio, name='cbam_channel')
        self.spatial_attn = SpatialAttention(kernel_size=kernel_size, name='cbam_spatial')

    def call(self, inputs):
        # Channels are reweighted first; the spatial mask is computed on the result.
        return self.spatial_attn(self.channel_attn(inputs))

    def get_config(self):
        return {**super().get_config(), "ratio": self.ratio, "kernel_size": self.kernel_size}

# Registry of the custom layers above, keyed by class name, for use when
# reloading saved models (e.g. as `custom_objects`).
custom_objects_map = {
    layer_cls.__name__: layer_cls
    for layer_cls in (SelfAttention, ChannelAttention, SpatialAttention, CBAM)
}


# --- Model Creation Functions ---
def create_base_model(input_shape, trainable=False):
    """Return an ImageNet-pretrained DenseNet121 feature extractor without its classifier top.

    Args:
        input_shape: Input shape, e.g. ``(224, 224, 3)``.
        trainable: Initial value of the backbone's ``trainable`` flag.
    """
    backbone = applications.DenseNet121(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
    )
    backbone.trainable = trainable
    return backbone

def add_classification_head(inputs, num_classes, pooling_type='avg', attention_type=None, dropout_rate=0.3):
    """Attach optional attention, global pooling and a dense classifier to a feature map.

    Args:
        inputs: 4-D backbone feature-map tensor.
        num_classes: 1 gives a single sigmoid unit (binary); otherwise a softmax over ``num_classes``.
        pooling_type: 'avg', 'max', or 'hybrid' (concatenated avg + max).
        attention_type: None, 'self', 'channel', 'spatial' or 'cbam'.
        dropout_rate: Rate for both dropout layers in the head.

    Returns:
        Output probability tensor. Under a float16 compute policy the final
        activation runs in a float32 layer, so the output is float32.

    Raises:
        ValueError: Unknown ``attention_type`` or ``pooling_type``.
    """
    x = inputs

    # 1. Attention (Optional)
    if attention_type:
        attn_name = f'attn_{attention_type}'
        if attention_type == 'self':
            x = SelfAttention(units=64, name=attn_name)(x)
        elif attention_type == 'channel':
            x = ChannelAttention(ratio=8, name=attn_name)(x)
        elif attention_type == 'spatial':
            x = SpatialAttention(kernel_size=7, name=attn_name)(x)
        elif attention_type == 'cbam':
            x = CBAM(ratio=8, kernel_size=7, name=attn_name)(x)
        else:
            raise ValueError(f"Unknown attention type: {attention_type}")

    # 2. Pooling
    pool_name = f'pool_{pooling_type}'
    if pooling_type == 'avg':
        x = GlobalAveragePooling2D(name=pool_name)(x)
    elif pooling_type == 'max':
        x = GlobalMaxPooling2D(name=pool_name)(x)
    elif pooling_type == 'hybrid':
        avg_pool = GlobalAveragePooling2D(name='pool_hybrid_avg')(x)
        max_pool = GlobalMaxPooling2D(name='pool_hybrid_max')(x)
        x = Concatenate(name='pool_hybrid_concat')([avg_pool, max_pool])
    else:
        raise ValueError(f"Unknown pooling type: {pooling_type}")

    # 3. Classification head: BN -> Dropout -> Dense(128, relu) -> BN -> Dropout
    x = BatchNormalization(name='head_bn_1')(x)
    x = Dropout(dropout_rate, seed=SEED, name='head_dropout_1')(x)
    x = Dense(128, activation='relu', name='head_dense_1', kernel_initializer='he_normal')(x)
    x = BatchNormalization(name='head_bn_2')(x)
    x = Dropout(dropout_rate, seed=SEED, name='head_dropout_2')(x)

    # 4. Output layer
    if num_classes == 1:  # Binary classification
        activation, units = 'sigmoid', 1
    else:  # Multi-class
        activation, units = 'softmax', num_classes

    if tf.keras.mixed_precision.global_policy().compute_dtype == 'float16':
        # Mixed precision: compute logits in the (float16) Dense layer, then apply
        # sigmoid/softmax in a float32 layer, as the TF mixed-precision guide
        # recommends. Squashing in float16 first quantises probabilities near
        # 0/1 (float16 spacing just below 1.0 is ~4.9e-4), which can coarsen
        # ROC/threshold analysis.
        logits = Dense(units, name='classifier_output')(x)
        outputs = Activation(activation, dtype='float32', name='classifier_activation')(logits)
    else:
        outputs = Dense(units, activation=activation, name='classifier_output')(x)

    return outputs

def build_full_model(input_shape, num_classes, pooling='avg', attention=None, dropout=DROPOUT_RATE, base_trainable=False):
    """Assemble the complete classifier: image input -> base network -> classification head.

    Args:
        input_shape: Shape of a single input image (e.g. ``IMG_SHAPE``).
        num_classes: Passed to ``add_classification_head``; 1 gives a single sigmoid
            output, anything larger a softmax output.
        pooling: Head pooling type: 'avg', 'max' or 'hybrid'.
        attention: Optional head attention type ('self', 'channel', 'spatial', 'cbam')
            or None for no attention.
        dropout: Dropout rate used in the head.
        base_trainable: Forwarded to ``create_base_model`` as ``trainable``.

    Returns:
        A ``keras.Model`` whose name encodes the pooling/attention configuration.
    """
    inputs = Input(shape=input_shape, name='input_image')
    base_model = create_base_model(input_shape, trainable=base_trainable)
    # A Python bool given as `training=` here is recorded in the functional graph and
    # overrides the mode Keras would otherwise pick, including in predict()/evaluate().
    # Frozen base -> always inference mode (BatchNorm uses its moving statistics).
    # Trainable base -> None, so fit() trains BatchNorm and predict()/evaluate() run it
    # in inference mode instead of on per-batch statistics.
    base_output = base_model(inputs, training=None if base_trainable else False)

    outputs = add_classification_head(
        base_output,
        num_classes=num_classes,
        pooling_type=pooling,
        attention_type=attention,
        dropout_rate=dropout
    )

    model_name = f'DenseNet121_P-{pooling or "none"}_A-{attention or "none"}'
    model = keras.Model(inputs=inputs, outputs=outputs, name=model_name)
    return model

print("Model building functions defined.")




In [None]:
# CELL 6: Core Training, Evaluation, and Enhanced Visualization Utilities
# Re-imports everything this cell's utilities use so the cell can be re-run on its own.
print("\n--- Cell 6: Core Training, Evaluation, and Enhanced Visualization Utilities ---")

# Standard library
import os
import time
import traceback  # kept for ad-hoc debugging of the evaluation helpers
from math import ceil

# Third-party
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,  # thresholded metrics
    average_precision_score, roc_auc_score,                    # score-based summaries
    classification_report, confusion_matrix,                   # reports / matrices
    precision_recall_curve, roc_curve,                         # curve plotting
)

# --- Matplotlib publication-quality defaults ---
plt.rcParams.update({
    # Typography
    'font.family': 'serif',
    'font.size': 11,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 11,
    'figure.titlesize': 16,
    # Line and border weights
    'axes.linewidth': 1.2,
    'grid.linewidth': 0.8,
    'lines.linewidth': 2.0,
    'patch.linewidth': 0.5,
    # Export settings (the file format is chosen in each savefig call)
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
})

# --- Training Function ---
def train_model(model, train_ds, val_ds, epochs, class_weights, strategy, learning_rate,
                initial_epoch=0, callbacks=None, stage_name="Training"):
    """Compile ``model`` inside ``strategy.scope()`` and fit it on ``train_ds``.

    The model is (re)compiled on every call with Adam, binary cross-entropy and the
    metrics accuracy, precision, recall and F1 (threshold 0.5). Each training stage
    (head-only, fine-tuning) therefore starts from a fresh optimizer at ``learning_rate``.

    Args:
        model: Keras model to compile and train.
        train_ds: Training data for ``model.fit``. If None, nothing is trained.
        val_ds: Passed to ``model.fit`` as ``validation_data`` (may be None).
        epochs: Keras ``fit`` semantics: the index of the final epoch, so training
            runs from ``initial_epoch`` up to ``epochs``.
        class_weights: Optional ``class_weight`` dict passed to ``model.fit``.
        strategy: ``tf.distribute`` strategy whose scope the compile step runs in.
        learning_rate: Adam learning rate.
        initial_epoch: Epoch at which to resume (for continuing training).
        callbacks: Optional list of Keras callbacks.
        stage_name: Label used in log messages.

    Returns:
        ``(history, training_duration_seconds)``, or ``(None, None)`` when
        ``train_ds`` is None.
    """
    if train_ds is None:
        print(f"ERROR [{stage_name}]: Training dataset is None. Cannot train.")
        return None, None

    with strategy.scope():
        # Metric objects are created inside the strategy scope, together with the optimizer.
        metrics_list = [
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            # F1 at the same 0.5 cut-off so it is recorded in the History for plotting
            # (needs `import tensorflow_addons as tfa`, done in Cell 1).
            # tf.keras.metrics.AUC(name='auc') may be added; the history plot uses it when F1 is absent.
            tfa.metrics.F1Score(num_classes=1, threshold=0.5, name='f1_score', average='micro'),
        ]
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss='binary_crossentropy',  # single-output (sigmoid) binary classifier
            metrics=metrics_list
        )

    print(f"\n--- Starting {stage_name} ---")
    print(f"Epochs: {epochs}, Initial Epoch: {initial_epoch}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Class Weights: {'Applied' if class_weights else 'None'}")

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments during a long training run.
    start_time = time.perf_counter()
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        initial_epoch=initial_epoch,
        class_weight=class_weights,
        callbacks=callbacks or [],
        verbose=1
    )
    training_duration = time.perf_counter() - start_time
    print(f"--- {stage_name} Finished (Duration: {training_duration:.2f} seconds) ---")
    return history, training_duration

# --- Threshold Optimization ---
def find_optimal_threshold(y_true, y_pred_proba, target_metric='f1', thresholds=None):
    """Return the decision threshold that maximises ``target_metric`` on held-out predictions.

    Args:
        y_true: Binary ground-truth labels (any array-like; flattened).
        y_pred_proba: Predicted positive-class probabilities, same number of
            elements as ``y_true`` (e.g. shape (n,) or (n, 1); flattened).
        target_metric: One of 'f1', 'accuracy', 'precision', 'recall'. Any other value
            falls back to F1 with a single warning.
        thresholds: Optional candidate thresholds. Defaults to 0.01, 0.02, ..., 0.99.

    Returns:
        The candidate threshold with the highest score. The first one wins on ties.
        Returns 0.5 when the inputs are missing, empty or of mismatched size, or when
        no candidate thresholds are given.
    """
    default_threshold = 0.5

    if y_true is None or y_pred_proba is None or len(y_true) == 0 or len(y_true) != len(y_pred_proba):
        print("Warning: Invalid inputs for threshold optimization. Returning default 0.5.")
        return default_threshold

    # Flatten so (n,) labels and (n, 1) sigmoid outputs compare element-wise.
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred_proba = np.asarray(y_pred_proba).ravel()
    if y_true.size != y_pred_proba.size:
        print("Warning: Invalid inputs for threshold optimization. Returning default 0.5.")
        return default_threshold

    if thresholds is None:
        thresholds = np.arange(0.01, 1.0, 0.01)
    thresholds = np.asarray(thresholds, dtype=float).ravel()
    if thresholds.size == 0:
        print("Warning: Could not compute scores for threshold optimization. Using default 0.5.")
        return default_threshold

    scorers = {
        'f1': lambda yt, yp: f1_score(yt, yp, zero_division=0),
        'accuracy': lambda yt, yp: accuracy_score(yt, yp),
        'precision': lambda yt, yp: precision_score(yt, yp, zero_division=0),
        'recall': lambda yt, yp: recall_score(yt, yp, zero_division=0),
    }
    scorer = scorers.get(target_metric)
    if scorer is None:
        # Validated once up front rather than warning for every candidate threshold.
        print(f"Warning: Unknown target_metric '{target_metric}' for threshold optimization. Defaulting to F1.")
        scorer = scorers['f1']

    scores = [scorer(y_true, (y_pred_proba >= thresh).astype(int)) for thresh in thresholds]
    return thresholds[int(np.argmax(scores))]

# --- Enhanced Plotting Functions ---
def _loss_axis_limits(values, min_padding=0.05, fallback_padding=0.1):
    """Return a (ymin, ymax) pair framing the finite entries of ``values`` for a loss panel.

    NaN/inf entries (e.g. from a diverged epoch) are ignored so ``Axes.set_ylim`` never
    receives a non-finite limit. The lower limit is clamped at 0. With no finite data
    the range (0, 1) is framed.
    """
    arr = np.asarray(values, dtype=float).ravel()
    finite = arr[np.isfinite(arr)]
    low, high = (float(finite.min()), float(finite.max())) if finite.size else (0.0, 1.0)
    span = high - low
    padding = max(0.1 * span, min_padding) if span > 1e-5 else fallback_padding
    y_min = max(0.0, low - padding)
    y_max = high + padding
    if y_max <= y_min:  # e.g. all-negative values clamped up to 0
        y_max = y_min + fallback_padding
    return y_min, y_max


def plot_training_history_enhanced(history, title_suffix="", save_dir=None, config_name="model"):
    """Plot training/validation curves for the metrics recorded in a Keras ``History``.

    Up to five panels are drawn:
      - loss, accuracy, precision and recall (whichever are present);
      - one summary metric: F1 (key ``f1_score`` or ``f1``) or, failing that, ``auc``.
    Validation curves come from the matching ``val_<metric>`` columns when present.
    Non-loss panels use a fixed [-0.05, 1.05] y-range. Loss panels are scaled to the
    finite loss values.

    Args:
        history: Keras ``History`` object (uses ``history.history``), or None.
        title_suffix: Text appended to the figure title.
        save_dir: If given, the figure is also saved there as a PDF.
        config_name: Used in the saved file name ``training_history_{config_name}.pdf``.
    """
    if history is None or not history.history:
        print("No history data found to plot.")
        return
    history_df = pd.DataFrame(history.history)
    history_df['epoch'] = np.arange(1, len(history_df) + 1)
    available = set(history_df.columns)

    display_names = {
        'loss': 'Loss', 'accuracy': 'Accuracy', 'precision': 'Precision', 'recall': 'Recall',
        'f1_score': 'F1-score', 'f1': 'F1-score', 'auc': 'AUC',
    }
    metric_keys = [key for key in ('loss', 'accuracy', 'precision', 'recall') if key in available]
    # F1 is preferred as the fifth panel; AUC is only used when no F1 key exists.
    summary_key = next((key for key in ('f1_score', 'f1', 'auc') if key in available), None)
    if summary_key is not None:
        metric_keys.append(summary_key)

    num_plots = len(metric_keys)
    if num_plots == 0:
        print("No plottable metrics found in history.")
        return

    fig = plt.figure()  # size depends on the panel layout chosen below
    if num_plots == 5:
        fig.set_size_inches(18, 10)
        grid = fig.add_gridspec(2, 6, hspace=0.45, wspace=0.5)
        # Three panels on the top row, two centred below; each spans two grid columns.
        slots = [grid[0, 0:2], grid[0, 2:4], grid[0, 4:6], grid[1, 1:3], grid[1, 3:5]]
    elif num_plots == 4:
        fig.set_size_inches(12, 10)
        grid = fig.add_gridspec(2, 2, hspace=0.35, wspace=0.3)
        slots = [grid[0, 0], grid[0, 1], grid[1, 0], grid[1, 1]]
    else:  # 1-3 panels in a single row
        fig.set_size_inches(6 * num_plots, 5.5)
        grid = fig.add_gridspec(1, num_plots, hspace=0.3, wspace=0.25 if num_plots > 1 else 0)
        slots = [grid[0, i] for i in range(num_plots)]

    colors = {'train': '#2E86AB', 'val': '#A23B72'}
    for slot, metric_key in zip(slots, metric_keys):
        val_key = f'val_{metric_key}'
        has_val = val_key in available
        ax = fig.add_subplot(slot)

        ax.plot(history_df['epoch'], history_df[metric_key],
                color=colors['train'], marker='o', markersize=4,
                label='Training', linewidth=2)
        if has_val:
            ax.plot(history_df['epoch'], history_df[val_key],
                    color=colors['val'], marker='s', markersize=4,
                    label='Validation', linewidth=2, linestyle='--')

        ax.set_title(display_names.get(metric_key, metric_key.title()), fontweight='bold', pad=10)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Value')
        ax.grid(True, alpha=0.3)
        ax.legend(frameon=True, fancybox=True, shadow=True)

        if metric_key == 'loss':
            loss_values = [history_df[metric_key].to_numpy()]
            if has_val:
                loss_values.append(history_df[val_key].to_numpy())
            ax.set_ylim(*_loss_axis_limits(np.concatenate(loss_values)))
        else:
            ax.set_ylim(-0.05, 1.05)  # slight margin around the [0, 1] metric range

    fig.suptitle(f'Training History {title_suffix}', fontsize=16, fontweight='bold')
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, f"training_history_{config_name}.pdf")
        fig.savefig(filepath, dpi=300, bbox_inches='tight', format='pdf')
        print(f"Training history saved: {filepath}")
    plt.show()

def plot_dual_confusion_matrices(y_true, y_pred_proba, optimal_threshold, inv_label_map,
                                 title_suffix="", save_dir=None, config_name="model"):
    """Draw side-by-side confusion matrices at threshold 0.5 and at ``optimal_threshold``.

    Both matrices use the same explicit label set: every label present in ``y_true``
    plus the binary prediction labels 0 and 1. The two panels therefore always have
    the same shape and the tick labels always match the matrix size. This holds even
    when ``y_true`` has a single class or one threshold predicts only one class.

    Args:
        y_true: Integer ground-truth labels (flattened).
        y_pred_proba: Positive-class probabilities (flattened).
        optimal_threshold: Decision threshold for the right-hand panel.
        inv_label_map: Mapping from integer label to display name, or None. Labels
            missing from the map are shown as "Class <label>".
        title_suffix: Text appended to the figure title.
        save_dir: If given, the figure is also saved there as a PDF.
        config_name: Used in the saved file name ``confusion_matrices_{config_name}.pdf``.
    """
    if y_true is None or y_pred_proba is None:
        print("Cannot plot confusion matrices: Missing data")
        return
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred_proba = np.asarray(y_pred_proba).ravel()

    labels = np.union1d(np.unique(y_true), [0, 1])
    label_map = inv_label_map or {}
    class_names = [label_map.get(lbl, f"Class {lbl}") for lbl in labels]

    panels = (
        (0.5, 'Blues', 'Confusion Matrix\n(Threshold = 0.5)'),
        (optimal_threshold, 'Greens', f'Confusion Matrix\n(Optimal Threshold = {optimal_threshold:.3f})'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(14, 6.5))
    for ax, (threshold, cmap, panel_title) in zip(axes, panels):
        y_pred = (y_pred_proba >= threshold).astype(int)
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,
                    xticklabels=class_names, yticklabels=class_names,
                    ax=ax, cbar_kws={'shrink': 0.8}, square=True,
                    annot_kws={'size': 14, 'weight': 'bold'})
        ax.set_title(panel_title, fontweight='bold', pad=15)
        ax.set_xlabel('Predicted Label', fontweight='bold')
        ax.set_ylabel('True Label', fontweight='bold')

    fig.suptitle(f'Confusion Matrix Comparison {title_suffix}',
                 fontsize=16, fontweight='bold', y=1.0)
    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, f"confusion_matrices_{config_name}.pdf")
        fig.savefig(filepath, dpi=300, bbox_inches='tight', format='pdf')
        print(f"Confusion matrices saved: {filepath}")
    plt.show()

def _threshold_operating_point(y_true, y_pred_proba, threshold, positive_label=1):
    """Return (fpr, tpr, precision, recall) of the decision rule ``y_pred_proba >= threshold``.

    Precision is reported as 1.0 when nothing is predicted positive. This matches the
    (recall=0, precision=1) end point that sklearn appends to its precision-recall curve.
    """
    predicted_pos = y_pred_proba >= threshold
    actual_pos = y_true == positive_label
    tp = np.count_nonzero(predicted_pos & actual_pos)
    fp = np.count_nonzero(predicted_pos & ~actual_pos)
    n_pos = np.count_nonzero(actual_pos)
    n_neg = actual_pos.size - n_pos
    tpr = tp / n_pos if n_pos else 0.0
    fpr = fp / n_neg if n_neg else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    return fpr, tpr, precision, tpr  # recall == TPR


def plot_performance_curves(y_true, y_pred_proba, optimal_threshold, inv_label_map,
                            title_suffix="", save_dir=None, config_name="model"):
    """Plot ROC and precision-recall curves and mark the operating point at ``optimal_threshold``.

    The red marker is computed directly from ``y_pred_proba >= optimal_threshold``. It
    is therefore the exact (FPR, TPR) and (recall, precision) of that classifier, not
    the curve vertex with the nearest threshold. Label 1 is treated as the positive class.

    Args:
        y_true: Binary ground-truth labels (flattened).
        y_pred_proba: Positive-class probabilities (flattened).
        optimal_threshold: Threshold whose operating point is marked.
        inv_label_map: Dict from label to name, used to title the plots; may be None.
        title_suffix: Text appended to the figure title.
        save_dir: If given, the figure is also saved there as a PDF.
        config_name: Used in the saved file name ``performance_curves_{config_name}.pdf``.
    """
    if y_true is None or y_pred_proba is None:
        print("Cannot plot performance curves: Missing data")
        return
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred_proba = np.asarray(y_pred_proba).ravel()

    if np.unique(y_true).size < 2:
        print("Cannot plot ROC/PR curves: Only one class present in true labels.")
        return

    positive_class_label = 1
    positive_class_name = (inv_label_map.get(positive_class_label, "Positive Class")
                           if isinstance(inv_label_map, dict) else "Positive Class")

    fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
    roc_auc = roc_auc_score(y_true, y_pred_proba)
    precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_pred_proba)
    avg_precision = average_precision_score(y_true, y_pred_proba)
    op_fpr, op_tpr, op_precision, op_recall = _threshold_operating_point(
        y_true, y_pred_proba, optimal_threshold, positive_class_label)
    threshold_label = f'Optimal Threshold ({optimal_threshold:.3f})'

    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(14, 6.5))

    # ROC curve
    ax_roc.plot(fpr, tpr, color='#2E86AB', linewidth=3, label=f'ROC Curve (AUC = {roc_auc:.4f})')
    ax_roc.plot([0, 1], [0, 1], color='gray', linewidth=2, linestyle='--', alpha=0.7, label='Random Classifier')
    ax_roc.scatter(op_fpr, op_tpr, color='red', s=100, zorder=5, label=threshold_label)
    ax_roc.set_xlim([-0.02, 1.0])  # start slightly before 0 so points on the axis stay visible
    ax_roc.set_ylim([0.0, 1.05])
    ax_roc.set_xlabel('False Positive Rate', fontweight='bold')
    ax_roc.set_ylabel('True Positive Rate', fontweight='bold')
    ax_roc.set_title(f'ROC Curve\n({positive_class_name} Detection)', fontweight='bold', pad=15)
    ax_roc.legend(loc="lower right", frameon=True, fancybox=True, shadow=True)
    ax_roc.grid(True, alpha=0.3)

    # Precision-recall curve; a no-skill classifier's precision equals the positive prevalence.
    ax_pr.plot(recall_vals, precision_vals, color='#A23B72', linewidth=3, label=f'PR Curve (AP = {avg_precision:.4f})')
    no_skill = float(np.mean(y_true == positive_class_label))
    ax_pr.axhline(y=no_skill, color='gray', linewidth=2, linestyle='--', alpha=0.7,
                  label=f'No-Skill Classifier (AP = {no_skill:.4f})')
    ax_pr.scatter(op_recall, op_precision, color='red', s=100, zorder=5, label=threshold_label)
    ax_pr.set_xlim([0.0, 1.02])  # end slightly after 1 so points on the edge stay visible
    ax_pr.set_ylim([0.0, 1.05])
    ax_pr.set_xlabel('Recall', fontweight='bold')
    ax_pr.set_ylabel('Precision', fontweight='bold')
    ax_pr.set_title(f'Precision-Recall Curve\n({positive_class_name} Detection)', fontweight='bold', pad=15)
    ax_pr.legend(loc="lower left", frameon=True, fancybox=True, shadow=True)
    ax_pr.grid(True, alpha=0.3)

    fig.suptitle(f'Performance Curves {title_suffix}', fontsize=16, fontweight='bold', y=1.0)
    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, f"performance_curves_{config_name}.pdf")
        fig.savefig(filepath, dpi=300, bbox_inches='tight', format='pdf')
        print(f"Performance curves saved: {filepath}")
    plt.show()

# --- Evaluation Function ---
# Relies on numpy and the sklearn.metrics functions imported at the top of this cell
# (including average_precision_score, used for PR AUC).

def _collect_labels_and_probabilities(model, dataset):
    """Return flattened (labels, probabilities) over every (images, labels) batch of ``dataset``.

    Labels and predictions are gathered in the same single pass, so they stay aligned
    even if a second pass over the dataset would be ordered differently. Returns
    ``(None, None)`` when the dataset yields no batches.
    """
    label_batches, proba_batches = [], []
    for images_batch, labels_batch in dataset.as_numpy_iterator():
        label_batches.append(np.asarray(labels_batch).ravel())
        # One forward pass per batch. Keras advises against calling predict() inside
        # a per-batch loop because of its per-call overhead.
        proba_batches.append(np.asarray(model.predict_on_batch(images_batch)).ravel())
    if not label_batches:
        return None, None
    return np.concatenate(label_batches), np.concatenate(proba_batches)


def _add_metrics_at_threshold(results_eval, y_true, y_pred_proba, threshold, label_map, dataset_name):
    """Store and print threshold-based metrics plus ROC AUC / PR AUC for one dataset.

    Adds 'accuracy_opt', 'precision_opt', 'recall_opt', 'f1_opt', 'roc_auc_proba' and
    'pr_auc' to ``results_eval``. The AUCs are NaN when only one class is present. A
    classification report is then printed. Errors are printed rather than raised so
    the caller can still draw its plots.
    """
    y_pred = (y_pred_proba >= threshold).astype(int)
    print(f"\nDetailed Metrics ({dataset_name} Set @ Optimal Threshold {threshold:.3f}):")
    try:
        results_eval['accuracy_opt'] = accuracy_score(y_true, y_pred)
        results_eval['precision_opt'] = precision_score(y_true, y_pred, zero_division=0)
        results_eval['recall_opt'] = recall_score(y_true, y_pred, zero_division=0)
        results_eval['f1_opt'] = f1_score(y_true, y_pred, zero_division=0)

        present_labels = np.unique(y_true)
        if present_labels.size > 1:  # ROC AUC and PR AUC need at least two classes
            try:
                results_eval['roc_auc_proba'] = roc_auc_score(y_true, y_pred_proba)
                results_eval['pr_auc'] = average_precision_score(y_true, y_pred_proba)
            except ValueError as auc_e:
                print(f"  Warning: Could not calculate ROC AUC or PR AUC scores: {auc_e}")
                results_eval['roc_auc_proba'] = np.nan
                results_eval['pr_auc'] = np.nan
        else:
            results_eval['roc_auc_proba'] = np.nan
            results_eval['pr_auc'] = np.nan
            print("  Warning: Only one class present in test labels, ROC AUC and PR AUC are undefined.")

        print(f"  Accuracy (opt):  {results_eval['accuracy_opt']:.4f}")
        print(f"  Precision (opt): {results_eval['precision_opt']:.4f}")
        print(f"  Recall (opt):    {results_eval['recall_opt']:.4f}")
        print(f"  F1 Score (opt):  {results_eval['f1_opt']:.4f}")
        print(f"  ROC AUC (proba): {results_eval['roc_auc_proba']:.4f}")
        print(f"  PR AUC (AvgPrec):{results_eval['pr_auc']:.4f}")

        print(f"\nClassification Report ({dataset_name} Set @ Optimal Threshold {threshold:.3f}):")
        # Name exactly the labels present (sorted), so names line up with the report rows.
        target_names = [label_map.get(label, f"Class {label}") for label in present_labels]
        print(classification_report(y_true, y_pred, target_names=target_names,
                                    labels=present_labels, zero_division=0))
    except Exception as metric_e:
        print(f"  Error calculating detailed sklearn metrics: {metric_e}")


def evaluate_model_optimized_with_viz(model, val_ds, test_ds, strategy, inv_label_map,
                                      target_metric='f1', dataset_name="Test",
                                      save_dir=None, config_name="model"):
    """Evaluate a binary classifier, tune its decision threshold, and draw diagnostic plots.

    Steps:
      1. ``model.evaluate`` on ``test_ds``: compiled metrics at the 0.5 threshold.
      2. Predict on ``val_ds`` and choose the threshold maximising ``target_metric``
         via ``find_optimal_threshold``. Falls back to 0.5 when unavailable.
      3. Predict on ``test_ds``. Compute accuracy/precision/recall/F1 at that threshold,
         plus ROC AUC and PR AUC (average precision) from the raw scores.
      4. Plot confusion matrices and ROC/PR curves for ``test_ds``.

    Args:
        model: Trained Keras model.
        val_ds: Dataset of (images, labels) batches used only for threshold selection; may be None.
        test_ds: Dataset of (images, labels) batches being evaluated; may be None.
        strategy: Not used by this function; kept for interface compatibility.
        inv_label_map: Mapping from integer label to class name, or None.
        target_metric: Metric name passed to ``find_optimal_threshold``.
        dataset_name: Name used in log messages and plot titles.
        save_dir: Directory for the plot PDFs, or None to skip saving.
        config_name: Identifier used in plot titles and file names.

    Returns:
        ``(results_eval, y_true_eval, y_pred_proba_eval, None)``. The arrays are None
        when no evaluation predictions were produced. The last slot is always None and
        is kept for callers that unpack four values.
    """
    results_eval = {}
    y_true_eval_np, y_pred_proba_eval_np = None, None
    optimal_threshold = 0.5
    label_map = inv_label_map or {}

    print(f"\n--- Evaluating Model on {dataset_name} Set ---")
    print(f"  Using target metric '{target_metric}' for threshold optimization on validation set.")

    # 1. Compiled metrics on the evaluation set (model.evaluate @ 0.5 threshold)
    print(f"Running model.evaluate() on {dataset_name} set...")
    if test_ds is not None:
        try:
            eval_results_tf = model.evaluate(test_ds, verbose=0, return_dict=True)
            print("  Compiled metrics (@ 0.5 Threshold from model.evaluate):")
            for metric_name, raw_value in eval_results_tf.items():
                # Scalars become plain floats. Vector-valued metrics are kept as lists
                # instead of breaking the ':.4f' format.
                value = np.asarray(raw_value, dtype=float)
                metric_val = value.item() if value.size == 1 else value.tolist()
                shown = f"{metric_val:.4f}" if isinstance(metric_val, float) else str(metric_val)
                print(f"    {metric_name}: {shown}")
                results_eval[metric_name] = metric_val
        except Exception as e:
            print(f"  Error during model.evaluate() on {dataset_name} set: {e}")
            print("  Skipping compiled metrics evaluation from model.evaluate.")
    else:
        print(f"  {dataset_name} dataset not provided. Skipping compiled metrics evaluation from model.evaluate.")

    # 2. Validation predictions -> decision threshold
    print(f"\nRunning model.predict() on validation set for threshold optimization...")
    if val_ds is not None:
        try:
            print("  Extracting true labels and making predictions on validation set...")
            y_true_val_np, y_pred_proba_val_np = _collect_labels_and_probabilities(model, val_ds)
            if y_true_val_np is None:
                print("  Warning: No labels/data extracted from validation dataset. Using default threshold 0.5.")
            else:
                print(f"  Extracted {len(y_true_val_np)} validation labels and made {len(y_pred_proba_val_np)} predictions.")
                if len(y_true_val_np) == len(y_pred_proba_val_np) and len(y_true_val_np) > 0:
                    optimal_threshold = find_optimal_threshold(y_true_val_np, y_pred_proba_val_np, target_metric)
                    print(f"  Optimal threshold determined from validation set: {optimal_threshold:.3f}")
                else:
                    print(f"  Warning: Mismatch or empty validation labels/predictions ({len(y_true_val_np)} vs {len(y_pred_proba_val_np)}). Using default threshold 0.5.")
        except Exception as e:
            print(f"  Error during validation prediction/threshold optimization: {e}")
            print("  Using default threshold 0.5.")
            optimal_threshold = 0.5
    else:
        print("  Validation dataset not provided. Using default threshold 0.5 for evaluation set metrics.")
    results_eval['optimal_threshold'] = optimal_threshold
    results_eval['threshold_target_metric'] = target_metric if val_ds is not None else 'N/A'

    # 3. Evaluation-set predictions -> detailed metrics at the chosen threshold
    print(f"\nRunning model.predict() on {dataset_name} set for detailed metrics...")
    if test_ds is not None:
        try:
            print(f"  Extracting true labels and making predictions on {dataset_name} set...")
            y_true_eval_np, y_pred_proba_eval_np = _collect_labels_and_probabilities(model, test_ds)
            if y_true_eval_np is None:
                print(f"  Warning: No labels/data extracted from {dataset_name} dataset. Cannot calculate detailed sklearn metrics.")
            else:
                print(f"  Extracted {len(y_true_eval_np)} {dataset_name} labels and made {len(y_pred_proba_eval_np)} predictions.")
                if len(y_true_eval_np) == len(y_pred_proba_eval_np) and len(y_true_eval_np) > 0:
                    _add_metrics_at_threshold(results_eval, y_true_eval_np, y_pred_proba_eval_np,
                                              optimal_threshold, label_map, dataset_name)
                else:
                    print(f"  Warning: Mismatch or empty {dataset_name} labels/predictions. Cannot calculate detailed sklearn metrics.")
        except Exception as e:
            print(f"  Error during {dataset_name} prediction or detailed metrics calculation: {e}")
    else:
        print(f"  {dataset_name} dataset not provided. Cannot calculate detailed sklearn metrics.")

    # 4. Visualizations (these compute their own ROC AUC / average precision for the plots)
    if y_true_eval_np is not None and y_pred_proba_eval_np is not None:
        print("\n--- Generating Enhanced Visualizations ---")
        plot_dual_confusion_matrices(
            y_true_eval_np, y_pred_proba_eval_np, optimal_threshold, inv_label_map,
            title_suffix=f"({dataset_name} Set, Config: {config_name})", save_dir=save_dir, config_name=config_name
        )
        plot_performance_curves(
            y_true_eval_np, y_pred_proba_eval_np, optimal_threshold, inv_label_map,
            title_suffix=f"({dataset_name} Set, Config: {config_name})", save_dir=save_dir, config_name=config_name
        )
    else:
        print("\n--- Skipping Enhanced Visualizations due to missing evaluation data ---")

    print("--- Evaluation Complete ---")
    return results_eval, y_true_eval_np, y_pred_proba_eval_np, None

# --- Model Comparison Plot ---
def plot_comparison_bars_enhanced(config_keys_to_plot, metrics_dir, title, save_dir=None,
                                  metrics_to_display=['f1_opt', 'accuracy_opt', 'roc_auc_proba', 'pr_auc']):
    """
    Creates publication-quality comparison bar charts by reading metrics from saved JSON files.
    Removes top and right spines from subplots for a cleaner look.

    Args:
        config_keys_to_plot (list): List of configuration keys (strings) to load and plot.
        metrics_dir (str): Path to the directory containing the "evaluation_metrics_{key}.json" files.
        title (str): Main title for the plot.
        save_dir (str, optional): Directory to save the plot PDF. Defaults to None (no save).
        metrics_to_display (list, optional): List of metric keys (strings) from the JSON files
                                             to extract and plot.
    Returns:
        pandas.DataFrame: DataFrame containing the plotted data, or None if plotting failed.
    """
    if not config_keys_to_plot:
        print(f"No configuration keys provided for comparison plot: {title}")
        return None
    if not os.path.isdir(metrics_dir):
        print(f"Metrics directory not found: {metrics_dir}")
        return None

    data_for_plot = {}
    for config_key in config_keys_to_plot:
        json_filename = f"evaluation_metrics_{config_key}.json"
        json_path = os.path.join(metrics_dir, json_filename)
        if os.path.exists(json_path):
            try:
                with open(json_path, 'r') as f:
                    metrics_from_json = json.load(f)
                temp_metrics_for_this_config = {}
                has_at_least_one_valid_value = False
                for metric_name in metrics_to_display:
                    value = metrics_from_json.get(metric_name, np.nan)
                    temp_metrics_for_this_config[metric_name] = value
                    if not pd.isna(value):
                        has_at_least_one_valid_value = True
                if has_at_least_one_valid_value:
                    data_for_plot[config_key] = temp_metrics_for_this_config
            except Exception as e:
                print(f"Error loading or parsing JSON for config '{config_key}' from {json_path}: {e}")

    if not data_for_plot:
        print(f"No valid data could be loaded from JSON files for the specified configurations and metrics ({metrics_to_display}) for plot: {title}")
        return None

    df = pd.DataFrame(data_for_plot).T.reset_index().rename(columns={'index': 'Configuration'})
    if df.empty:
        print(f"DataFrame is empty after processing JSON files for plot: {title}")
        return None

    plottable_metric_keys = [mk for mk in metrics_to_display if mk in df.columns and not df[mk].isnull().all()]
    if not plottable_metric_keys:
        print(f"None of the specified metrics_to_display ({metrics_to_display}) have any valid data in the loaded JSONs for plotting: {title}")
        return df

    sort_metric = None
    if 'f1_opt' in plottable_metric_keys: sort_metric = 'f1_opt'
    elif plottable_metric_keys: sort_metric = plottable_metric_keys[0]
    if sort_metric: df = df.sort_values(by=sort_metric, ascending=False, na_position='last')
    else: df = df.sort_values(by='Configuration', ascending=True)

    num_metrics_to_plot = len(plottable_metric_keys)
    ncols = min(num_metrics_to_plot, 3)
    nrows = ceil(num_metrics_to_plot / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, max(4, 0.6 * len(df) + 1.5)), squeeze=False)
    axes = axes.flatten()
    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#593E2A', '#3A7D44', '#B565A7']
    
    plot_idx = 0
    for metric in plottable_metric_keys:
        ax = axes[plot_idx]
        
        y_pos = np.arange(len(df))
        bar_values = df[metric].fillna(0) 
        bars = ax.barh(y_pos, bar_values, color=colors[plot_idx % len(colors)],
                       alpha=0.85, edgecolor='black', linewidth=0.7)
        
        if metric == 'pr_auc': metric_name_display = "PR AUC (AvgPrec)"
        elif metric == 'roc_auc_proba': metric_name_display = "ROC AUC"
        else: metric_name_display = (metric.replace('_', ' ').replace(' opt', ' (Opt)').title())
            
        ax.set_title(metric_name_display, fontweight='bold', pad=12, fontsize=12)
        ax.set_xlabel('Score', fontweight='bold', fontsize=10)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(df['Configuration'], fontsize=9)
        
        max_val_metric = df[metric].max()
        if pd.isna(max_val_metric) or max_val_metric == 0 : upper_limit = 0.1 
        elif max_val_metric <= 1.0: upper_limit = 1.05
        else: upper_limit = max_val_metric * 1.15
        ax.set_xlim(0, upper_limit)
        ax.tick_params(axis='x', labelsize=8)
        ax.grid(axis='x', linestyle=':', alpha=0.6)

        # --- MODIFICATION: Remove top and right spines ---
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # You could also use sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
        # If you also want to remove left/bottom, set them to True in despine or use ax.spines['left'].set_visible(False) etc.
        # -------------------------------------------------

        for bar_idx, bar_obj in enumerate(bars):
            original_value = df[metric].iloc[bar_idx]
            width_for_text = bar_obj.get_width()
            text_label = f'{original_value:.3f}' if not pd.isna(original_value) else 'N/A'
            
            # Adjust text position slightly if needed, especially after removing spines
            padding_from_bar = ax.get_xlim()[1] * 0.015 # Slightly increased padding
            ax.text(width_for_text + padding_from_bar, 
                    bar_obj.get_y() + bar_obj.get_height() / 2,
                    text_label, ha='left', va='center',
                    fontweight='normal', fontsize=8.5, color='dimgray')
        plot_idx += 1

    for k_ax in range(plot_idx, len(axes)): axes[k_ax].axis('off')
    fig.suptitle(title, fontsize=16, fontweight='bold', y=0.99 if nrows > 1 else 1.02)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95 if nrows > 1 else 0.92])

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        safe_title = "".join(c if c.isalnum() else "_" for c in title.lower())
        filename = f"comparison_bars_{safe_title}.pdf"
        filepath = os.path.join(save_dir, filename)
        try:
            plt.savefig(filepath, dpi=300, bbox_inches='tight', format='pdf')
            print(f"Comparison bar plot saved: {filepath}")
        except Exception as e:
            print(f"Error saving comparison bar plot to {filepath}: {e}")
    
    plt.show()
    return df
    
print("Cell 6: All utility functions (training, evaluation, enhanced plotting) are defined.")

In [None]:

# CELL 7: Experiment Runner Function
print("\n--- Cell 7: Experiment Runner Function ---")


def _to_json_safe(obj):
    """Recursively convert NumPy scalars/arrays (and tuples) inside obj into plain JSON-serializable Python values."""
    if isinstance(obj, dict):
        return {(k.item() if isinstance(k, np.generic) else k): _to_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_safe(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


def _save_metrics_json(metrics, json_path):
    """Write metrics to json_path; return the path, or None (after printing the error) on failure.

    The payload is serialized to a string BEFORE the file is opened, so a value that cannot be
    serialized never leaves a truncated JSON file on disk for the comparison cells to trip over.
    """
    try:
        payload = json.dumps(_to_json_safe(metrics), indent=4)
        target_dir = os.path.dirname(json_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(json_path, 'w') as f:
            f.write(payload)
    except Exception as e:
        print(f"Error saving evaluation metrics to {json_path}: {e}")
        return None
    print(f"Evaluation metrics saved to {json_path}")
    return json_path


def _print_config_summary(config):
    """Print each config section; non-primitive values (e.g. Keras layers) are shown by type name only."""
    for section, value in config.items():
        if isinstance(value, dict):
            print(f"  {section}:")
            for sub_key, sub_value in value.items():
                if isinstance(sub_value, (str, int, float, bool, type(None))):
                    shown = sub_value
                else:
                    shown = f"<{type(sub_value).__name__}>"
                print(f"    {sub_key}: {shown}")
        else:
            print(f"  {section}: {value}")


def run_experiment(config):
    """
    Run one complete experiment stage: build datasets, build the model, train, evaluate,
    persist metrics/results, plot the training history, and release model memory.

    Args:
        config (dict): Experiment definition with keys 'key' (unique name), 'parse_args',
            'model_args', 'train_args' and optionally 'oversample_train'.
            NOTE: config['parse_args']['img_size'] and config['model_args']['num_classes'] are
            filled in place when missing, so the stored config records what was actually run.

    Returns:
        dict | None: The evaluation-metrics dict (empty if evaluation produced nothing), or None
        if any step raised. In every case the global `results[config['key']]` is populated.

    Side effects:
        Writes the best checkpoint under CHECKPOINT_DIR, a metrics JSON under METRICS_DIR and
        plots under PLOTS_DIR, and clears the Keras session when the run finishes.
    """
    key = config['key']
    print("\n" + "="*70)
    print(f" Starting Experiment: {key} ")
    print("="*70)
    _print_config_summary(config)
    print("-"*70)

    start_time_total = time.time()
    model = None
    model_to_eval = None
    history = None
    eval_metrics = None  # dictionary of evaluation metrics returned by the evaluation helper
    training_duration = 0
    train_ds = val_ds = test_ds = None
    checkpoint_filepath = os.path.join(CHECKPOINT_DIR, f"{key}_best.keras")

    try:
        # --- 1. Build Datasets ---
        print("\n[1. Building Datasets...]")
        if 'img_size' not in config['parse_args']:
            config['parse_args']['img_size'] = IMG_SIZE
        train_ds = build_dataset(
            train_paths, train_labels, preprocess_image, config['parse_args'],
            GLOBAL_BATCH_SIZE, f"Train ({key})", shuffle=True,
            augment_in_map=config['parse_args'].get('apply_augment', False),
            oversample=config.get('oversample_train', False), cache=True
        )
        # Validation/test never augment.
        val_parse_args = {**config['parse_args'], 'apply_augment': False}
        val_ds = build_dataset(
            val_paths, val_labels, preprocess_image, val_parse_args,
            GLOBAL_BATCH_SIZE, f"Validation ({key})", shuffle=False, cache=True
        )
        test_parse_args = {**config['parse_args'], 'apply_augment': False}
        test_ds = build_dataset(
            test_paths, test_labels, preprocess_image, test_parse_args,
            GLOBAL_BATCH_SIZE, f"Test ({key})", shuffle=False, cache=True
        )
        if not all([train_ds, val_ds, test_ds]): raise RuntimeError("Dataset build failed.")
        print("Datasets built successfully.")

        # --- 2. Build Model ---
        print("\n[2. Building Model...]")
        with strategy.scope():
            if 'num_classes' not in config['model_args']: config['model_args']['num_classes'] = 1
            model = build_full_model(IMG_SHAPE, **config['model_args'])
        model.summary(line_length=100)
        print(f"Model '{model.name}' built.")

        # --- 3. Setup Callbacks ---
        print("\n[3. Setting up Callbacks...]")
        # With save_best_only=True, ModelCheckpoint only writes when val_loss improves. If no
        # epoch improves (e.g. val_loss is NaN throughout), a file left by an earlier run of the
        # same key would otherwise be loaded in step 6 and evaluated as if it were this run.
        if os.path.exists(checkpoint_filepath):
            os.remove(checkpoint_filepath)
            print(f"Removed stale checkpoint from a previous run: {checkpoint_filepath}")
        callbacks_list = [
            EarlyStopping(monitor='val_loss', patience=PATIENCE_EARLY_STOPPING, verbose=1, restore_best_weights=False),
            ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=PATIENCE_REDUCE_LR, min_lr=MIN_LR, verbose=1),
            ModelCheckpoint(filepath=checkpoint_filepath, monitor='val_loss', save_best_only=True, save_weights_only=False, verbose=1)
        ]
        print(f"Model checkpoint path: {checkpoint_filepath}")

        # --- 4. Determine Class Weights ---
        train_class_weights = None
        if config['train_args'].get('class_weights_setting') == 'balanced':
            train_class_weights = class_weights_dict
            print("Using 'balanced' class weights for training.")
        else: print("No class weights applied for training.")

        # --- 5. Train Model ---
        print("\n[5. Training Model...]")
        history, training_duration = train_model(
            model, train_ds, val_ds,
            epochs=config['train_args']['epochs'], class_weights=train_class_weights,
            strategy=strategy, learning_rate=config['train_args']['learning_rate'],
            callbacks=callbacks_list, stage_name=f"Training ({key})"
        )
        if history is None: raise RuntimeError("Model training failed.")

        # --- 6. Load Best Weights ---
        print("\n[6. Loading Best Weights from Checkpoint...]")
        if os.path.exists(checkpoint_filepath):
            with strategy.scope():
                model_to_eval = keras.models.load_model(checkpoint_filepath, custom_objects=custom_objects_map)
            print(f"Successfully loaded best model weights from {checkpoint_filepath}")
        else:
            print(f"WARNING: Checkpoint file not found at {checkpoint_filepath}. Evaluating with the last epoch's weights.")
            model_to_eval = model

        # --- 7. Evaluate Model ---
        print("\n[7. Evaluating Model...]")
        eval_metrics, _, _, _ = evaluate_model_optimized_with_viz(
            model=model_to_eval, val_ds=val_ds, test_ds=test_ds, strategy=strategy,
            inv_label_map=inv_label_dict, target_metric=TARGET_METRIC,
            dataset_name=f"Test ({key})",
            save_dir=PLOTS_DIR,
            config_name=key
        )
        if not eval_metrics: print("Warning: Evaluation failed or returned no metrics."); eval_metrics = {}

        # --- 9. Store Results --- (step 8, Grad-CAM, was removed; numbering kept for log continuity)
        print("\n[9. Storing Results...]")
        total_duration = time.time() - start_time_total

        metrics_json_save_path = None
        if eval_metrics:
            metrics_json_save_path = _save_metrics_json(
                eval_metrics, os.path.join(METRICS_DIR, f"evaluation_metrics_{key}.json"))

        results[key] = {
            'config': config,
            'metrics': eval_metrics,
            'training_duration_sec': training_duration,
            'total_duration_sec': total_duration,
            'checkpoint_path': checkpoint_filepath if os.path.exists(checkpoint_filepath) else None,
            'metrics_json_path': metrics_json_save_path
        }
        print(f"Results for '{key}' stored (including path to JSON metrics).")

        # --- 10. Plotting ---
        print("\n[10. Plotting Results...]")
        if history:
            plot_training_history_enhanced(
                history,
                title_suffix=f" ({key})",
                save_dir=PLOTS_DIR,
                config_name=key
            )

        print(f"--- Experiment {key} Complete ---")
        return eval_metrics

    except Exception as e:
        print(f"\n\n ****** ERROR during experiment {key} ****** ")
        print(f"Error Type: {type(e).__name__}")
        print(f"Error Details: {e}")
        print("Traceback:")
        traceback.print_exc()
        results[key] = {'status': 'failed', 'error': str(e), 'config': config}
        return None

    finally:
        # Drop this run's model/history/dataset references BEFORE collecting, then reset Keras
        # global state so memory does not accumulate across the experiments run in this notebook.
        model = model_to_eval = history = None
        train_ds = val_ds = test_ds = None
        try:
            active_policy = keras.mixed_precision.global_policy()
            keras.backend.clear_session()
            # Re-apply the mixed-precision policy in case clearing the session reset it.
            set_global_policy(active_policy)
        except Exception as cleanup_error:
            print(f"Warning: Keras session cleanup failed for '{key}': {cleanup_error}")
        gc.collect()

# --- End of run_experiment definition ---

In [None]:
# CELL 8: Stage 1 - Run Baseline (Transfer Learning + Standard Augmentation)
print("\n--- Cell 8: Stage 1 - Baseline (Transfer Learning + Standard Augmentation) ---")

# Stage 1 reference point: frozen backbone (head-only training), average pooling, no attention,
# standard augmentation, no CLAHE and no class-imbalance handling.
config_baseline = dict(
    key="Baseline_StdAug",
    parse_args=dict(
        apply_augment=True,
        augment_layer=standard_augmentation,  # augmentation pipeline defined in Cell 4
        apply_clahe=False,
        clahe_clip_limit=2.0,                 # ignored while apply_clahe is False
        img_size=IMG_SIZE,
    ),
    model_args=dict(
        pooling='avg',
        attention=None,
        dropout=DROPOUT_RATE,
        base_trainable=False,                 # backbone frozen for the initial (head) training
        num_classes=1,                        # binary task, single output
    ),
    train_args=dict(
        epochs=EPOCHS_HEAD,
        learning_rate=LEARNING_RATE,
        class_weights_setting=None,           # no class weighting at this stage
    ),
    oversample_train=False,                   # no oversampling at this stage
)

baseline_metrics = run_experiment(config_baseline)

# Seed the "best configuration so far" tracker that later stages build on.
if baseline_metrics:
    current_best_config_key = config_baseline['key']
    print(f"\nBaseline run complete. Current best configuration key: '{current_best_config_key}'")
else:
    current_best_config_key = None
    print("\nERROR: Baseline run failed. Cannot proceed with subsequent experiments.")




In [None]:
# CELL 9: Stage 2 - CLAHE Experiments

print("\n--- Cell 9: Stage 2 - CLAHE Experiments ---")

if current_best_config_key is not None:
    clahe_clip_limits_to_test = [1.0, 2.0, 3.0]
    # Comparison list for this stage, seeded with the configuration CLAHE is layered onto.
    clahe_stage_keys = [current_best_config_key]
    base_config_for_clahe = results[current_best_config_key]['config']

    for clip_limit in clahe_clip_limits_to_test:
        clahe_key = f"{current_best_config_key}_CLAHE{clip_limit:.1f}"

        # Shallow-copy every argument section so this variant never mutates the base config;
        # only the CLAHE switches in parse_args differ from the base.
        config_clahe = {
            'key': clahe_key,
            **{section: base_config_for_clahe[section].copy()
               for section in ('parse_args', 'model_args', 'train_args')},
            'oversample_train': base_config_for_clahe['oversample_train'],
        }
        config_clahe['parse_args'].update(apply_clahe=True, clahe_clip_limit=clip_limit)

        clahe_metrics = run_experiment(config_clahe)
        if clahe_metrics:  # only successful runs take part in the comparison
            clahe_stage_keys.append(clahe_key)

    print("\nCompleted CLAHE experiments.")
else:
    print("Skipping CLAHE experiments because baseline failed.")




In [None]:
# CELL 10: Stage 2 - CLAHE Comparison & Selection
print("\n--- Cell 10: Stage 2 - CLAHE Comparison & Selection ---")

# Globals expected from earlier cells (os, json and np are imported in Cell 1):
#   METRICS_DIR (str): directory holding the evaluation_metrics_{key}.json files.
#   PLOTS_DIR (str): directory where plots are saved.
#   TARGET_METRIC (str): metric used to pick the winner; higher is assumed better.
#   results (dict): populated by run_experiment (configs, paths, metrics).
#   current_best_config_key (str): winner of the PREVIOUS stage.

# Exact configuration keys compared in this stage: the Cell 8 baseline and the CLAHE variants
# Cell 9 derived from it. They MUST match the {key} part of the
# "evaluation_metrics_{key}.json" filenames written by run_experiment.
keys_for_clahe_stage_evaluation = [
    "Baseline_StdAug",
    "Baseline_StdAug_CLAHE1.0",
    "Baseline_StdAug_CLAHE2.0",
    "Baseline_StdAug_CLAHE3.0",
    # Add any other CLAHE variant keys derived from "Baseline_StdAug" here.
]

# The configuration the CLAHE variants were derived from. Used in the plot title and as the
# reference score; it should be one of keys_for_clahe_stage_evaluation.
reference_baseline_key_in_this_stage = "Baseline_StdAug"


def _clahe_stage_metrics_path(config_key):
    """Return the path of the metrics JSON written by run_experiment for config_key."""
    return os.path.join(METRICS_DIR, f"evaluation_metrics_{config_key}.json")


def _as_valid_score(value):
    """Return value as a finite float, or None if it is missing, null, boolean, NaN/inf or non-numeric."""
    if isinstance(value, bool) or not isinstance(value, (int, float, np.number)):
        return None
    value = float(value)
    return value if np.isfinite(value) else None


print(f"DEBUG: Initial current_best_config_key (from previous stage): '{current_best_config_key}'")
print(f"DEBUG: Reference baseline for this CLAHE stage comparison: '{reference_baseline_key_in_this_stage}'")
print(f"DEBUG: TARGET_METRIC for selection is: '{TARGET_METRIC}'")
print(f"DEBUG: METRICS_DIR is: '{METRICS_DIR}'")

# Only compare configurations whose metrics JSON actually exists.
valid_keys_for_comparison = [
    key for key in keys_for_clahe_stage_evaluation
    if os.path.exists(_clahe_stage_metrics_path(key))
]

print(f"Found metric files for and will compare: {valid_keys_for_comparison}")

if not valid_keys_for_comparison:
    print("ERROR: No metric files found for any of the specified CLAHE stage configurations. Cannot proceed with comparison.")
    # current_best_config_key remains unchanged
elif len(valid_keys_for_comparison) == 1 and reference_baseline_key_in_this_stage in valid_keys_for_comparison:
    print(f"Warning: Only the reference baseline '{reference_baseline_key_in_this_stage}' metric file was found. "
          "No other CLAHE variants to compare against in the provided list for this stage.")
    # With nothing to compare against, the reference baseline becomes (or stays) the best key.
    if current_best_config_key != reference_baseline_key_in_this_stage:
        print(f"Setting current_best_config_key to '{reference_baseline_key_in_this_stage}' as it's the only valid one found for this stage.")
        current_best_config_key = reference_baseline_key_in_this_stage
else:
    # --- Plotting: TARGET_METRIC first, duplicates removed, remaining order kept ---
    metrics_to_request_for_plot = list(dict.fromkeys([
        TARGET_METRIC,
        'f1_opt',
        'accuracy_opt',
        'precision_opt',
        'recall_opt',
        'roc_auc_proba',
        'pr_auc',
    ]))

    df_clahe_comparison = plot_comparison_bars_enhanced(
        config_keys_to_plot=valid_keys_for_comparison,
        metrics_dir=METRICS_DIR,
        title=f"CLAHE Stage Comparison (Ref: {reference_baseline_key_in_this_stage})",
        save_dir=PLOTS_DIR,
        metrics_to_display=metrics_to_request_for_plot
    )

    # --- Selection (re-reads the JSON files, so it works after a kernel restart too) ---
    loaded_metrics = {}
    load_errors = {}
    for config_key in valid_keys_for_comparison:
        try:
            with open(_clahe_stage_metrics_path(config_key), 'r') as f:
                loaded_metrics[config_key] = json.load(f)
        except Exception as e:
            load_errors[config_key] = e

    # The reference score is informational only; a missing/invalid value is shown as 'N/A'
    # instead of crashing the ':.4f' format (e.g. when the JSON holds null for the metric).
    reference_baseline_score = None
    reference_metrics = loaded_metrics.get(reference_baseline_key_in_this_stage)
    if isinstance(reference_metrics, dict):
        reference_baseline_score = _as_valid_score(reference_metrics.get(TARGET_METRIC))
    elif reference_baseline_key_in_this_stage in load_errors:
        print(f"Warning: Error loading metrics for reference baseline '{reference_baseline_key_in_this_stage}': "
              f"{load_errors[reference_baseline_key_in_this_stage]}")
    reference_score_display = "N/A" if reference_baseline_score is None else f"{reference_baseline_score:.4f}"

    print(f"\nSelecting best configuration from this CLAHE stage using '{TARGET_METRIC}' "
          f"(Score of '{reference_baseline_key_in_this_stage}' for reference: {reference_score_display}):")

    best_score_this_stage = -1.0  # lower than any valid score of a [0, 1] metric
    winner_key_this_stage = None
    winner_metrics_this_stage = None  # full metrics dict of the winner

    for config_key in valid_keys_for_comparison:
        current_metrics_from_json = loaded_metrics.get(config_key)
        if not isinstance(current_metrics_from_json, dict):
            reason = load_errors.get(config_key, "file does not contain a JSON object")
            print(f"  - Config '{config_key}': Error loading/processing metrics from JSON - {reason}. Skipping for winner selection.")
            continue

        current_score = _as_valid_score(current_metrics_from_json.get(TARGET_METRIC))
        metric_display_value = "Not Found" if current_score is None else f"{current_score:.4f}"
        print(f"  - Config '{config_key}': {TARGET_METRIC} = {metric_display_value}")

        # Higher is better; strict '>' keeps the earlier key (the reference baseline) on ties.
        if current_score is not None and current_score > best_score_this_stage:
            best_score_this_stage = current_score
            winner_key_this_stage = config_key
            winner_metrics_this_stage = current_metrics_from_json

    # --- Announce the stage winner and update the GLOBAL current_best_config_key ---
    if winner_key_this_stage is not None:
        print(f"\n🏆 Winner of CLAHE Stage: '{winner_key_this_stage}' ({TARGET_METRIC}: {best_score_this_stage:.4f})")
        print("This configuration will be used as the new 'current_best_config_key' for subsequent stages.")
        current_best_config_key = winner_key_this_stage

        if winner_key_this_stage in results and 'config' in results[winner_key_this_stage]:
            print("\nWinning Configuration Details (from in-memory 'results'):")
            winner_config_dict = results[winner_key_this_stage]['config']
            for detail_key, detail_value in winner_config_dict.items():
                if isinstance(detail_value, dict):
                    print(f"  {detail_key}:")
                    for sub_key, sub_value in detail_value.items():
                        if sub_key == 'augment_layer': print(f"    {sub_key}: <Keras Layer Object>")
                        else: print(f"    {sub_key}: {sub_value}")
                else: print(f"  {detail_key}: {detail_value}")
        else:
            print(f"Full configuration details for winner '{winner_key_this_stage}' not found in in-memory 'results'.")
            # Cell 11 derives its experiments from results[current_best_config_key]['config'].
            print("Note: later stages need this configuration in 'results'; re-run its experiment in this session first.")

        if winner_metrics_this_stage:
            print("\nWinning Metrics (from JSON):")
            for m_key, m_val in winner_metrics_this_stage.items():
                if isinstance(m_val, (float, np.floating)): print(f"  {m_key}: {m_val:.4f}")
                else: print(f"  {m_key}: {m_val}")
    else:
        print(f"\nCould not determine a winner for the CLAHE stage (no configuration had a valid '{TARGET_METRIC}' score).")
        print(f"The overall best configuration '{current_best_config_key}' (from before this stage) remains unchanged.")

print("-" * 70)

In [None]:
# LEARNING_RATE=.0001


In [None]:
# CELL 11: Stage 3 - Imbalance Handling Experiments
print("\n--- Cell 11: Stage 3 - Imbalance Handling Experiments ---")

# This stage starts from the best configuration selected in Cell 10 and tries two alternative
# imbalance strategies on top of it: balanced class weights, and oversampling of the train set.
imbalance_baseline_key = current_best_config_key
imbalance_stage_keys = []  # baseline + every variant that ran successfully, in run order

if imbalance_baseline_key is not None and imbalance_baseline_key in results:
    print(f"Using configuration '{imbalance_baseline_key}' as baseline for Imbalance Handling stage.")
    imbalance_stage_keys.append(imbalance_baseline_key)
    baseline_config_imbalance = results[imbalance_baseline_key]['config']

    # --- Experiment 3a: balanced class weights (oversampling off) ---
    print("\n--- Running Imbalance Experiment: Class Weights ---")
    config_weights_key = f"{imbalance_baseline_key}_ClassWeights"
    config_weights = {
        'key': config_weights_key,
        **{section: baseline_config_imbalance[section].copy()
           for section in ('parse_args', 'model_args', 'train_args')},
        'oversample_train': False,
    }
    config_weights['train_args']['class_weights_setting'] = 'balanced'

    weights_metrics = run_experiment(config_weights)
    if weights_metrics:
        imbalance_stage_keys.append(config_weights_key)

    # --- Experiment 3b: oversampling (class weights explicitly off) ---
    print("\n--- Running Imbalance Experiment: Oversampling ---")
    config_oversample_key = f"{imbalance_baseline_key}_Oversample"
    config_oversample = {
        'key': config_oversample_key,
        **{section: baseline_config_imbalance[section].copy()
           for section in ('parse_args', 'model_args', 'train_args')},
        'oversample_train': True,  # build_dataset oversamples the training split
    }
    config_oversample['train_args']['class_weights_setting'] = None

    oversample_metrics = run_experiment(config_oversample)
    if oversample_metrics:
        imbalance_stage_keys.append(config_oversample_key)

    print("\nCompleted Imbalance Handling experiments.")
else:
    print("ERROR: Cannot proceed with Imbalance Handling stage. Baseline configuration key is missing or invalid.")





In [None]:
# CELL 12: Stage 3 - Imbalance Handling Comparison & Selection
print("\n--- Cell 12: Stage 3 - Imbalance Handling Comparison & Selection ---")

import os
import json
import numpy as np  # np.isfinite (score sanitising) and np.floating (metric printing)

# --- Configuration for this Stage ---
# These global variables should be defined in previous cells:
# METRICS_DIR (str): Path to evaluation_metrics_{key}.json files.
# PLOTS_DIR (str): Path to save plots.
# TARGET_METRIC (str): Primary metric for selection (e.g., 'f1_opt'); higher is treated as better.
# results (dict): Global dictionary with all experiment results (configs, paths, etc.).
# current_best_config_key (str): Key of the best model from the PREVIOUS stage (e.g., winner of CLAHE stage).
# plot_comparison_bars_enhanced (callable): Bar-chart comparison helper from an earlier cell.
#
# Outcome: when a winner is found, the global `current_best_config_key` is reassigned to the
# config with the highest TARGET_METRIC in its JSON file. Ties keep the key listed first in
# `keys_for_imbalance_stage` (the baseline). If the comparison is skipped, the key is unchanged.

# The baseline for THIS Imbalance Handling stage is the winner from the previous stage.
imbalance_stage_baseline_key = current_best_config_key

# IMPORTANT: Define the list of exact keys for this Imbalance Handling stage:
# the baseline for this stage plus all of its imbalance variants (built in Cell 11).
# These names MUST match the {key} part of your "evaluation_metrics_{key}.json" filenames.
keys_for_imbalance_stage = [
    imbalance_stage_baseline_key,  # e.g., "Baseline_StdAug_CLAHE1.0"
    f"{imbalance_stage_baseline_key}_ClassWeights",
    f"{imbalance_stage_baseline_key}_Oversample"
    # Add other specific imbalance handling experiment keys if you ran more based on this baseline
]

# --- Pre-computation & Sanity Checks ---
print(f"DEBUG: Baseline for this Imbalance Stage (current_best_config_key entering Cell 12): '{imbalance_stage_baseline_key}'")
print(f"DEBUG: TARGET_METRIC for selection is: '{TARGET_METRIC}'")
print(f"DEBUG: METRICS_DIR is: '{METRICS_DIR}'")
print(f"DEBUG: Intended keys for Imbalance Stage (before file check): {keys_for_imbalance_stage}")

# Only keys whose metrics JSON file exists on disk take part in the comparison.
valid_keys_for_imbalance_comparison = [
    key for key in keys_for_imbalance_stage
    if os.path.exists(os.path.join(METRICS_DIR, f"evaluation_metrics_{key}.json"))
]
print(f"Found metric files for and will compare these Imbalance Stage configurations: {valid_keys_for_imbalance_comparison}")


if not imbalance_stage_baseline_key or imbalance_stage_baseline_key not in valid_keys_for_imbalance_comparison:
    print(f"ERROR: Baseline for Imbalance Stage ('{imbalance_stage_baseline_key}') metric file not found or key is None. "
          "Skipping Imbalance Handling comparison.")
    # current_best_config_key remains unchanged from before this cell
elif len(valid_keys_for_imbalance_comparison) <= 1:
    # Needs at least baseline + 1 variant for a meaningful comparison.
    print("Not enough successful Imbalance Handling runs (including this stage's baseline) to perform a meaningful comparison.")
    print(f"Keeping previous best configuration: '{current_best_config_key}'")
    # current_best_config_key remains unchanged
else:
    # --- Plotting ---
    metrics_to_request_for_plot = [
        'f1_opt',
        'accuracy_opt',
        'precision_opt',
        'recall_opt',
        'roc_auc_proba',
        'pr_auc'  # Must exist in the JSON files to be plotted
    ]
    if TARGET_METRIC not in metrics_to_request_for_plot:
        metrics_to_request_for_plot.insert(0, TARGET_METRIC)
    metrics_to_request_for_plot = list(dict.fromkeys(metrics_to_request_for_plot))  # De-duplicate, keep order

    df_imbalance_comparison = plot_comparison_bars_enhanced(
        config_keys_to_plot=valid_keys_for_imbalance_comparison,
        metrics_dir=METRICS_DIR,
        title=f"Imbalance Handling Stage Comparison (Base: {imbalance_stage_baseline_key})",
        save_dir=PLOTS_DIR,
        metrics_to_display=metrics_to_request_for_plot
    )

    # --- Selection Logic (based on JSON files) ---
    best_score_this_stage = -1.0  # -1.0 doubles as the "no usable score" sentinel
    winner_key_this_stage = None
    winner_metrics_this_stage = None

    # Reference score of this stage's baseline. It is printed with ':.4f' and compared with
    # '>=' outside any try block, so it is converted to a finite float here. A null, string
    # or NaN/inf value in the JSON then falls back to -1.0 instead of crashing the cell.
    baseline_score_for_this_stage_ref = -1.0
    baseline_ref_json_path = os.path.join(METRICS_DIR, f"evaluation_metrics_{imbalance_stage_baseline_key}.json")
    try:
        with open(baseline_ref_json_path, 'r') as f:
            baseline_ref_metrics_data = json.load(f)
        baseline_ref_value = float(baseline_ref_metrics_data.get(TARGET_METRIC, -1.0))
        if np.isfinite(baseline_ref_value):
            baseline_score_for_this_stage_ref = baseline_ref_value
        else:
            print(f"Warning: '{TARGET_METRIC}' for this stage's baseline '{imbalance_stage_baseline_key}' "
                  f"is not finite ({baseline_ref_value}); using -1.0 as the reference score.")
    except Exception as e:
        print(f"Warning: Error loading metrics for this stage's baseline '{imbalance_stage_baseline_key}': {e}")

    print(f"\nSelecting best configuration from Imbalance Handling stage using '{TARGET_METRIC}' "
          f"(Score of '{imbalance_stage_baseline_key}' for reference: {baseline_score_for_this_stage_ref:.4f}):")

    for config_key in valid_keys_for_imbalance_comparison:
        json_path = os.path.join(METRICS_DIR, f"evaluation_metrics_{config_key}.json")
        current_score_from_json = -1.0
        current_metrics_from_json = None

        try:
            with open(json_path, 'r') as f:
                current_metrics_from_json = json.load(f)
            current_score_from_json = current_metrics_from_json.get(TARGET_METRIC, -1.0)

            # A non-numeric value raises in the format/comparison below and is reported by the except.
            metric_display_value = (
                f"{current_score_from_json:.4f}"
                if current_score_from_json != -1.0 or TARGET_METRIC in current_metrics_from_json
                else "Not Found"
            )
            print(f"  - Config '{config_key}': {TARGET_METRIC} = {metric_display_value}")

            # Strict '>' so the earlier key (the baseline is first) wins ties; NaN never wins.
            if current_score_from_json > best_score_this_stage:  # Assumes higher is better
                best_score_this_stage = current_score_from_json
                winner_key_this_stage = config_key
                winner_metrics_this_stage = current_metrics_from_json
        except Exception as e:
            print(f"  - Config '{config_key}': Error loading/processing metrics from JSON - {e}. Skipping for winner selection.")
            continue

    # --- Announce Winner of this Stage and Update the GLOBAL current_best_config_key ---
    if winner_key_this_stage and best_score_this_stage > -1.0:
        print(f"\n🏆 Winner of Imbalance Handling Stage: '{winner_key_this_stage}' ({TARGET_METRIC}: {best_score_this_stage:.4f})")
        print("This configuration will be updated as the new 'current_best_config_key'.")
        current_best_config_key = winner_key_this_stage  # Update the global best key

        if winner_key_this_stage in results and 'config' in results[winner_key_this_stage]:
            print("\nWinning Configuration Details (from in-memory 'results'):")
            winner_config_dict = results[winner_key_this_stage]['config']
            for detail_key, detail_value in winner_config_dict.items():
                if isinstance(detail_value, dict):
                    print(f"  {detail_key}:")
                    for sub_key, sub_value in detail_value.items():
                        if sub_key == 'augment_layer':
                            print(f"    {sub_key}: <Keras Layer Object>")
                        else:
                            print(f"    {sub_key}: {sub_value}")
                else:
                    print(f"  {detail_key}: {detail_value}")
        else:
            print(f"Full configuration details for winner '{winner_key_this_stage}' not found in in-memory 'results'.")

        if winner_metrics_this_stage:
            print("\nWinning Metrics (from JSON):")
            for m_key, m_val in winner_metrics_this_stage.items():
                if isinstance(m_val, (float, np.floating)):
                    print(f"  {m_key}: {m_val:.4f}")
                else:
                    print(f"  {m_key}: {m_val}")
    else:
        print("\nCould not determine a new winner for the Imbalance Handling stage (e.g., all scores were -1.0 or no improvement).")
        # Defensive fallback. The baseline is one of the compared keys, so this branch matters only
        # when its reference score is usable even though the loop above could not score it
        # (e.g. a numeric string value that float() accepts but ':.4f' rejects).
        if imbalance_stage_baseline_key in valid_keys_for_imbalance_comparison and \
           baseline_score_for_this_stage_ref >= best_score_this_stage and \
           baseline_score_for_this_stage_ref > -1.0:  # Baseline must have a usable (> -1.0) score
            print(f"The configuration '{imbalance_stage_baseline_key}' (Score: {baseline_score_for_this_stage_ref:.4f}) "
                  "remains the best among those evaluated in this stage.")
            if current_best_config_key != imbalance_stage_baseline_key:  # If current best was somehow different
                current_best_config_key = imbalance_stage_baseline_key
                print(f"Updating overall best configuration to '{current_best_config_key}'.")
            else:
                print(f"Overall best configuration remains '{current_best_config_key}'.")
        else:
            print(f"The overall best configuration '{current_best_config_key}' (from before this stage started, or due to errors/no improvement) remains unchanged.")

print("-" * 70)

In [None]:
# print(LEARNING_RATE)
# DROPOUT_RATE = 0.25
# print(DROPOUT_RATE)

In [None]:
# CELL 13: Stage 4 - Pooling Experiments
print("\n--- Cell 13: Stage 4 - Pooling Experiments ---")

# Trains one variant of the current best configuration per global-pooling type.
# Each variant copies the baseline's parse/model/train args and changes only
# model_args['pooling'], keeping model_args['attention'] = None.
# Each variant is passed to run_experiment under the key f"{pooling_baseline_key}_Pool{TYPE}"
# (TYPE upper-cased), the naming Cell 14 looks for.
# Keys of successful runs are collected in `pooling_stage_keys`, baseline first.
#
# Globals expected from previous cells: current_best_config_key, results, run_experiment.

# The baseline for this stage is the best configuration identified after Cell 12
pooling_baseline_key = current_best_config_key
pooling_stage_keys = []  # Keep track of keys for comparison in this stage

if pooling_baseline_key is None or pooling_baseline_key not in results:
    print("ERROR: Cannot proceed with Pooling stage. Baseline configuration key is missing or invalid.")
    # Optionally raise error: raise ValueError("Baseline configuration for Pooling stage is missing.")
else:
    print(f"Using configuration '{pooling_baseline_key}' as baseline for Pooling stage.")
    # Add baseline key to the list for comparison
    pooling_stage_keys.append(pooling_baseline_key)

    # Retrieve the configuration dictionary of the baseline
    baseline_config_pooling = results[pooling_baseline_key]['config']
    # A missing 'pooling' entry is treated as 'avg'.
    baseline_pooling_type = baseline_config_pooling['model_args'].get('pooling', 'avg')
    print(f"Baseline pooling type for this stage: '{baseline_pooling_type}'")

    # Cell 11 sets 'oversample_train' explicitly on the variants it builds, but configs
    # coming from earlier stages are not guaranteed to carry the key. Default to False
    # (oversampling off) instead of failing with a KeyError.
    pooling_baseline_oversample = baseline_config_pooling.get('oversample_train', False)

    # 'avg' is listed so every pooling type is compared even when the baseline itself is
    # not 'avg' (Cell 14 already looks for a '_PoolAVG' key). The type that matches the
    # baseline is skipped below, so an 'avg' baseline runs only 'max' and 'hybrid'.
    pooling_types_to_test = ['avg', 'max', 'hybrid']

    for test_pooling_type in pooling_types_to_test:
        # Skip if the type to test is the same as the baseline's pooling type
        if test_pooling_type == baseline_pooling_type:
            print(f"Skipping pooling type '{test_pooling_type}' as it matches the baseline.")
            continue

        print(f"\n--- Running Pooling Experiment: {test_pooling_type.upper()} ---")
        # Construct key by appending pooling type to the baseline key for this stage
        config_pool_key = f"{pooling_baseline_key}_Pool{test_pooling_type.upper()}"

        # .copy() is shallow: top-level entries can be changed per variant, while nested
        # objects (presumably including the 'augment_layer' Keras layer) stay shared.
        config_pool = {
            'key': config_pool_key,
            'parse_args': baseline_config_pooling['parse_args'].copy(),
            'model_args': baseline_config_pooling['model_args'].copy(),
            'train_args': baseline_config_pooling['train_args'].copy(),
            'oversample_train': pooling_baseline_oversample
        }

        # Modify only the pooling setting in model_args
        config_pool['model_args']['pooling'] = test_pooling_type
        # Ensure attention is still off for this stage
        config_pool['model_args']['attention'] = None

        # Run the experiment
        pool_metrics = run_experiment(config_pool)
        if pool_metrics:  # Add key only if run succeeded
            pooling_stage_keys.append(config_pool_key)

    print("\nCompleted Pooling experiments.")





In [None]:
# CELL 14: Stage 4 - Pooling Comparison & Selection
print("\n--- Cell 14: Stage 4 - Pooling Comparison & Selection ---")

import os
import json
import numpy as np  # np.isfinite (score sanitising) and np.floating (metric printing)

# --- Configuration for this Stage ---
# These global variables should be defined from previous cells:
# METRICS_DIR (str): Path to evaluation_metrics_{key}.json files.
# PLOTS_DIR (str): Path to save plots.
# TARGET_METRIC (str): Primary metric for selection (e.g., 'f1_opt'); higher is treated as better.
# results (dict): Global dictionary with all experiment results (configs, paths, etc.).
# current_best_config_key (str): Key of the best model from the PREVIOUS stage (e.g., winner of Imbalance Handling).
# plot_comparison_bars_enhanced (callable): Bar-chart comparison helper from an earlier cell.
#
# Outcome: when a winner is found, the global `current_best_config_key` is reassigned to the
# config with the highest TARGET_METRIC in its JSON file. Ties keep the key listed first in
# `keys_for_pooling_stage` (the baseline).

# The baseline for THIS Pooling stage is the winner from the previous stage.
pooling_stage_baseline_key = current_best_config_key

# IMPORTANT: Define the list of exact keys for THIS Pooling stage: the baseline
# plus every pooling variant built from it (Cell 13 names them f"{baseline}_Pool{TYPE}").
# These names MUST match the {key} part of your "evaluation_metrics_{key}.json" filenames.
# Example: if pooling_stage_baseline_key = "Baseline_StdAug_CLAHE1.0_OverSample"
keys_for_pooling_stage = [
    pooling_stage_baseline_key,  # e.g., "Baseline_StdAug_CLAHE1.0_OverSample" (represents default/current pooling)
    f"{pooling_stage_baseline_key}_PoolMAX",
    f"{pooling_stage_baseline_key}_PoolAVG",  # Present only if an explicit 'PoolAVG' variant was run
    f"{pooling_stage_baseline_key}_PoolHYBRID"
    # Add other pooling experiment keys run from this baseline (e.g. combined
    # pooling + attention keys such as f"{pooling_stage_baseline_key}_PoolHYBRID_AttnCBAM").
]
# De-duplicate while preserving order (first occurrence wins ties in selection).
keys_for_pooling_stage = list(dict.fromkeys(keys_for_pooling_stage))


# --- Pre-computation & Sanity Checks ---
print(f"DEBUG: Baseline for this Pooling Stage (current_best_config_key entering Cell 14): '{pooling_stage_baseline_key}'")
print(f"DEBUG: TARGET_METRIC for selection is: '{TARGET_METRIC}'")
print(f"DEBUG: METRICS_DIR is: '{METRICS_DIR}'")
print(f"DEBUG: Intended keys for Pooling Stage (before file check): {keys_for_pooling_stage}")

# Only keys whose metrics JSON file exists on disk take part in the comparison.
valid_keys_for_pooling_comparison = [
    key for key in keys_for_pooling_stage
    if os.path.exists(os.path.join(METRICS_DIR, f"evaluation_metrics_{key}.json"))
]
print(f"Found metric files for and will compare these Pooling Stage configurations: {valid_keys_for_pooling_comparison}")


if not pooling_stage_baseline_key or pooling_stage_baseline_key not in valid_keys_for_pooling_comparison:
    print(f"ERROR: Baseline for Pooling Stage ('{pooling_stage_baseline_key}') metric file not found or key is None. "
          "Skipping Pooling comparison.")
elif len(valid_keys_for_pooling_comparison) <= 1:
    # Needs at least baseline + 1 variant for a meaningful comparison.
    print("Not enough successful Pooling runs (including this stage's baseline) to perform a meaningful comparison.")
    print(f"Keeping previous best configuration: '{current_best_config_key}'")
else:
    # --- Plotting ---
    metrics_to_request_for_plot = [
        'f1_opt', 'accuracy_opt', 'precision_opt', 'recall_opt', 'roc_auc_proba', 'pr_auc'
    ]
    if TARGET_METRIC not in metrics_to_request_for_plot:
        metrics_to_request_for_plot.insert(0, TARGET_METRIC)
    metrics_to_request_for_plot = list(dict.fromkeys(metrics_to_request_for_plot))

    df_pooling_comparison = plot_comparison_bars_enhanced(
        config_keys_to_plot=valid_keys_for_pooling_comparison,
        metrics_dir=METRICS_DIR,
        title=f"Pooling Strategy Stage Comparison (Base: {pooling_stage_baseline_key})",
        save_dir=PLOTS_DIR,
        metrics_to_display=metrics_to_request_for_plot
    )

    # --- Selection Logic (based on JSON files) ---
    best_score_this_stage = -1.0  # -1.0 doubles as the "no usable score" sentinel
    winner_key_this_stage = None
    winner_metrics_this_stage = None

    # Reference score of this stage's baseline. It is printed with ':.4f' and compared with
    # '>=' outside any try block, so it is converted to a finite float here. A null, string
    # or NaN/inf value in the JSON then falls back to -1.0 instead of crashing the cell.
    baseline_score_for_this_stage_ref = -1.0
    baseline_ref_json_path = os.path.join(METRICS_DIR, f"evaluation_metrics_{pooling_stage_baseline_key}.json")
    try:
        with open(baseline_ref_json_path, 'r') as f:
            baseline_ref_metrics_data = json.load(f)
        baseline_ref_value = float(baseline_ref_metrics_data.get(TARGET_METRIC, -1.0))
        if np.isfinite(baseline_ref_value):
            baseline_score_for_this_stage_ref = baseline_ref_value
        else:
            print(f"Warning: '{TARGET_METRIC}' for this stage's baseline '{pooling_stage_baseline_key}' "
                  f"is not finite ({baseline_ref_value}); using -1.0 as the reference score.")
    except Exception as e:
        print(f"Warning: Error loading metrics for this stage's baseline '{pooling_stage_baseline_key}': {e}")

    print(f"\nSelecting best configuration from Pooling stage using '{TARGET_METRIC}' "
          f"(Score of '{pooling_stage_baseline_key}' for reference: {baseline_score_for_this_stage_ref:.4f}):")

    for config_key in valid_keys_for_pooling_comparison:
        json_path = os.path.join(METRICS_DIR, f"evaluation_metrics_{config_key}.json")
        current_score_from_json = -1.0
        current_metrics_from_json = None
        try:
            with open(json_path, 'r') as f:
                current_metrics_from_json = json.load(f)
            current_score_from_json = current_metrics_from_json.get(TARGET_METRIC, -1.0)
            # A non-numeric value raises in the format/comparison below and is reported by the except.
            metric_display_value = (
                f"{current_score_from_json:.4f}"
                if current_score_from_json != -1.0 or TARGET_METRIC in current_metrics_from_json
                else "Not Found"
            )
            print(f"  - Config '{config_key}': {TARGET_METRIC} = {metric_display_value}")
            # Strict '>' so the earlier key (the baseline is first) wins ties; NaN never wins.
            if current_score_from_json > best_score_this_stage:  # Assumes higher is better
                best_score_this_stage = current_score_from_json
                winner_key_this_stage = config_key
                winner_metrics_this_stage = current_metrics_from_json
        except Exception as e:
            print(f"  - Config '{config_key}': Error loading/processing metrics from JSON - {e}. Skipping.")
            continue

    # --- Announce Winner of this Stage and Update the GLOBAL current_best_config_key ---
    if winner_key_this_stage and best_score_this_stage > -1.0:
        print(f"\n🏆 Winner of Pooling Stage: '{winner_key_this_stage}' ({TARGET_METRIC}: {best_score_this_stage:.4f})")
        print("This configuration will be updated as the new 'current_best_config_key'.")
        current_best_config_key = winner_key_this_stage

        if winner_key_this_stage in results and 'config' in results[winner_key_this_stage]:
            print("\nWinning Configuration Details (from in-memory 'results'):")
            winner_config_dict = results[winner_key_this_stage]['config']
            winning_pooling_type = winner_config_dict.get('model_args', {}).get('pooling', 'N/A')
            print(f"  (Winning Pooling Type from config: {winning_pooling_type})")
            for detail_key, detail_value in winner_config_dict.items():
                if isinstance(detail_value, dict):
                    print(f"  {detail_key}:")
                    for sub_key, sub_value in detail_value.items():
                        if sub_key == 'augment_layer':
                            print(f"    {sub_key}: <Keras Layer Object>")
                        else:
                            print(f"    {sub_key}: {sub_value}")
                else:
                    print(f"  {detail_key}: {detail_value}")
        else:
            print(f"Full configuration details for winner '{winner_key_this_stage}' not found in 'results'.")

        if winner_metrics_this_stage:
            print("\nWinning Metrics (from JSON):")
            for m_key, m_val in winner_metrics_this_stage.items():
                if isinstance(m_val, (float, np.floating)):
                    print(f"  {m_key}: {m_val:.4f}")
                else:
                    print(f"  {m_key}: {m_val}")
    else:
        print("\nCould not determine a new winner for the Pooling stage.")
        # Defensive fallback. The baseline is one of the compared keys, so this branch matters only
        # when its reference score is usable even though the loop above could not score it.
        if pooling_stage_baseline_key in valid_keys_for_pooling_comparison and \
           baseline_score_for_this_stage_ref >= best_score_this_stage and \
           baseline_score_for_this_stage_ref > -1.0:
            print(f"The configuration '{pooling_stage_baseline_key}' (Score: {baseline_score_for_this_stage_ref:.4f}) remains the best.")
            current_best_config_key = pooling_stage_baseline_key
        else:
            print(f"The overall best configuration '{current_best_config_key}' remains unchanged.")
print("-" * 70)

In [None]:
# LEARNING_RATE=0.0002
# print(LEARNING_RATE)

In [None]:
# CELL 15: Stage 5 - Attention Mechanism Experiments
print("\n--- Cell 15: Stage 5 - Attention Mechanism Experiments ---")

# Trains one variant of the current best configuration per attention type.
# Each variant copies the baseline's parse/model/train args and changes only
# model_args['attention'].
# Each variant is passed to run_experiment under the key f"{attention_baseline_key}_Attn{TYPE}"
# (TYPE upper-cased), the naming Cell 16 looks for.
# Keys of successful runs are collected in `attention_stage_keys`, baseline (no attention) first.
#
# Globals expected from previous cells: current_best_config_key, results, run_experiment.

# The baseline for this stage is the best configuration identified after Cell 14
attention_baseline_key = current_best_config_key
attention_stage_keys = []  # Keep track of keys for comparison in this stage

if attention_baseline_key is None or attention_baseline_key not in results:
    print("ERROR: Cannot proceed with Attention stage. Baseline configuration key is missing or invalid.")
    # Optionally raise error: raise ValueError("Baseline configuration for Attention stage is missing.")
else:
    print(f"Using configuration '{attention_baseline_key}' as baseline for Attention stage (No Attention).")
    # Add baseline key (representing No Attention) to the list for comparison
    attention_stage_keys.append(attention_baseline_key)

    # Retrieve the configuration dictionary of the baseline
    baseline_config_attention = results[attention_baseline_key]['config']
    # Verify baseline has no attention
    baseline_attention_type = baseline_config_attention['model_args'].get('attention', None)
    if baseline_attention_type is not None:
        print(f"WARNING: Baseline config '{attention_baseline_key}' for Attention stage unexpectedly has attention type '{baseline_attention_type}'.")

    # Configs coming from earlier stages are not guaranteed to carry 'oversample_train'
    # (Cell 11 sets it explicitly only on its own variants). Default to False
    # (oversampling off) instead of failing with a KeyError.
    attention_baseline_oversample = baseline_config_attention.get('oversample_train', False)

    attention_types_to_test = ['self', 'channel', 'spatial', 'cbam']

    for test_attention_type in attention_types_to_test:
        # As in the pooling stage, do not re-train a variant identical to the baseline
        # (only possible when the baseline unexpectedly already uses attention).
        if test_attention_type == baseline_attention_type:
            print(f"Skipping attention type '{test_attention_type}' as it matches the baseline.")
            continue

        print(f"\n--- Running Attention Experiment: {test_attention_type.upper()} ---")
        # Construct key by appending attention type to the baseline key for this stage
        config_attn_key = f"{attention_baseline_key}_Attn{test_attention_type.upper()}"

        # .copy() is shallow: top-level entries can be changed per variant, while nested
        # objects (presumably including the 'augment_layer' Keras layer) stay shared.
        config_attn = {
            'key': config_attn_key,
            'parse_args': baseline_config_attention['parse_args'].copy(),
            'model_args': baseline_config_attention['model_args'].copy(),
            'train_args': baseline_config_attention['train_args'].copy(),
            'oversample_train': attention_baseline_oversample
        }

        # Modify only the attention setting in model_args
        config_attn['model_args']['attention'] = test_attention_type

        # Run the experiment
        attn_metrics = run_experiment(config_attn)
        if attn_metrics:  # Add key only if run succeeded
            attention_stage_keys.append(config_attn_key)

    print("\nCompleted Attention experiments.")




In [None]:
# CELL 16: Stage 5 - Attention Comparison & Selection (Best Overall Pre-Finetuning)
print("\n--- Cell 16: Stage 5 - Attention Comparison & Selection (Best Overall Pre-Finetuning) ---")

import os
import json
import numpy as np  # np.isfinite (score sanitising) and np.floating (metric printing)

# --- Configuration for this Stage ---
# These global variables should be defined from previous cells:
# METRICS_DIR, PLOTS_DIR, TARGET_METRIC, results, current_best_config_key,
# plot_comparison_bars_enhanced
#
# Outcome:
# - Sets `best_overall_pre_finetune_key`, the configuration handed to the fine-tuning cell.
# - When a winner is found, also reassigns the global `current_best_config_key` to it.
# - Selection uses TARGET_METRIC from the JSON files (higher is better). Ties keep the key
#   listed first in `keys_for_attention_stage` (the baseline).

# The baseline for THIS Attention stage is the winner from the previous (Pooling) stage.
attention_stage_baseline_key = current_best_config_key

# IMPORTANT: exact keys for THIS Attention stage: the baseline (no attention) plus one key
# per attention type run in Cell 15, named f"{baseline}_Attn{TYPE.upper()}".
# Cell 15 tests ['self', 'channel', 'spatial', 'cbam'], so '_AttnSELF' must be listed as well;
# otherwise the self-attention run is silently left out of the comparison.
# These names MUST match your "evaluation_metrics_{key}.json" filenames.
# Example: if attention_stage_baseline_key = "Baseline_StdAug_CLAHE1.0_PoolHYBRID", the keys are
#   "Baseline_StdAug_CLAHE1.0_PoolHYBRID" (no additional attention),
#   "Baseline_StdAug_CLAHE1.0_PoolHYBRID_AttnCBAM", "..._AttnCHANNEL", "..._AttnSPATIAL", "..._AttnSELF".
keys_for_attention_stage = [
    attention_stage_baseline_key,  # Model configuration before adding attention in this stage
    f"{attention_stage_baseline_key}_AttnCBAM",
    f"{attention_stage_baseline_key}_AttnCHANNEL",
    f"{attention_stage_baseline_key}_AttnSPATIAL",
    f"{attention_stage_baseline_key}_AttnSELF",
    # Add other attention experiment keys run from this baseline.
]
# De-duplicate while preserving order (first occurrence wins ties in selection).
keys_for_attention_stage = list(dict.fromkeys(keys_for_attention_stage))

# This variable will store the ultimate winner before any fine-tuning.
best_overall_pre_finetune_key = None  # Initialize

# --- Pre-computation & Sanity Checks ---
print(f"DEBUG: Baseline for this Attention Stage (current_best_config_key entering Cell 16): '{attention_stage_baseline_key}'")
print(f"DEBUG: TARGET_METRIC for selection is: '{TARGET_METRIC}'")
print(f"DEBUG: METRICS_DIR is: '{METRICS_DIR}'")
print(f"DEBUG: Intended keys for Attention Stage (before file check): {keys_for_attention_stage}")

# Only keys whose metrics JSON file exists on disk take part in the comparison.
valid_keys_for_attention_comparison = [
    key for key in keys_for_attention_stage
    if os.path.exists(os.path.join(METRICS_DIR, f"evaluation_metrics_{key}.json"))
]
print(f"Found metric files for and will compare these Attention Stage configurations: {valid_keys_for_attention_comparison}")


if not attention_stage_baseline_key or attention_stage_baseline_key not in valid_keys_for_attention_comparison:
    print(f"ERROR: Baseline for Attention Stage ('{attention_stage_baseline_key}') metric file not found or key is None. "
          "Skipping Attention comparison.")
    best_overall_pre_finetune_key = current_best_config_key  # Fallback to previous stage's winner
    print(f"Best overall pre-finetune key defaults to: '{best_overall_pre_finetune_key}'")
elif len(valid_keys_for_attention_comparison) <= 1:
    print("Not enough successful Attention runs (including this stage's baseline/No Attention) to perform a meaningful comparison.")
    best_overall_pre_finetune_key = current_best_config_key  # The baseline for this stage is effectively the winner
    print(f"Best overall pre-finetune key set to: '{best_overall_pre_finetune_key}' (baseline of this stage).")
else:
    # --- Plotting ---
    metrics_to_request_for_plot = [
        'f1_opt', 'accuracy_opt', 'precision_opt', 'recall_opt', 'roc_auc_proba', 'pr_auc'
    ]
    if TARGET_METRIC not in metrics_to_request_for_plot:
        metrics_to_request_for_plot.insert(0, TARGET_METRIC)
    metrics_to_request_for_plot = list(dict.fromkeys(metrics_to_request_for_plot))

    df_attention_comparison = plot_comparison_bars_enhanced(
        config_keys_to_plot=valid_keys_for_attention_comparison,
        metrics_dir=METRICS_DIR,
        title=f"Attention Mechanism Stage Comparison (Base: {attention_stage_baseline_key})",
        save_dir=PLOTS_DIR,
        metrics_to_display=metrics_to_request_for_plot
    )

    # --- Selection Logic (based on JSON files) ---
    best_score_this_stage = -1.0  # -1.0 doubles as the "no usable score" sentinel
    winner_key_this_stage = None
    winner_metrics_this_stage = None

    # Reference score of this stage's baseline. It is printed with ':.4f' outside any try
    # block, so it is converted to a finite float here. A null, string or NaN/inf value
    # in the JSON then falls back to -1.0 instead of crashing the cell.
    baseline_score_for_this_stage_ref = -1.0
    baseline_ref_json_path = os.path.join(METRICS_DIR, f"evaluation_metrics_{attention_stage_baseline_key}.json")
    try:
        with open(baseline_ref_json_path, 'r') as f:
            baseline_ref_metrics_data = json.load(f)
        baseline_ref_value = float(baseline_ref_metrics_data.get(TARGET_METRIC, -1.0))
        if np.isfinite(baseline_ref_value):
            baseline_score_for_this_stage_ref = baseline_ref_value
        else:
            print(f"Warning: '{TARGET_METRIC}' for this stage's baseline '{attention_stage_baseline_key}' "
                  f"is not finite ({baseline_ref_value}); using -1.0 as the reference score.")
    except Exception as e:
        print(f"Warning: Error loading metrics for this stage's baseline '{attention_stage_baseline_key}': {e}")

    print(f"\nSelecting best configuration from Attention stage using '{TARGET_METRIC}' "
          f"(Score of '{attention_stage_baseline_key}' (No new Attention) for reference: {baseline_score_for_this_stage_ref:.4f}):")

    for config_key in valid_keys_for_attention_comparison:
        json_path = os.path.join(METRICS_DIR, f"evaluation_metrics_{config_key}.json")
        current_score_from_json = -1.0
        current_metrics_from_json = None
        try:
            with open(json_path, 'r') as f:
                current_metrics_from_json = json.load(f)
            current_score_from_json = current_metrics_from_json.get(TARGET_METRIC, -1.0)
            # A non-numeric value raises in the format/comparison below and is reported by the except.
            metric_display_value = (
                f"{current_score_from_json:.4f}"
                if current_score_from_json != -1.0 or TARGET_METRIC in current_metrics_from_json
                else "Not Found"
            )
            print(f"  - Config '{config_key}': {TARGET_METRIC} = {metric_display_value}")
            # Strict '>' so the earlier key (the baseline is first) wins ties; NaN never wins.
            if current_score_from_json > best_score_this_stage:  # Assumes higher is better
                best_score_this_stage = current_score_from_json
                winner_key_this_stage = config_key
                winner_metrics_this_stage = current_metrics_from_json
        except Exception as e:
            print(f"  - Config '{config_key}': Error loading/processing metrics from JSON - {e}. Skipping.")
            continue

    # --- Announce Winner of this Stage and Update GLOBAL current_best_config_key ---
    # This winner is also the best_overall_pre_finetune_key
    if winner_key_this_stage and best_score_this_stage > -1.0:
        print(f"\n🏆 Winner of Attention Stage (Best Overall Pre-Finetune): '{winner_key_this_stage}' ({TARGET_METRIC}: {best_score_this_stage:.4f})")
        current_best_config_key = winner_key_this_stage
        best_overall_pre_finetune_key = current_best_config_key  # Store specifically
        print("This configuration will be used for the final Fine-tuning stage.")


        if winner_key_this_stage in results and 'config' in results[winner_key_this_stage]:
            print("\nBest Overall (Pre-Finetuning) Configuration Details (from in-memory 'results'):")
            winner_config_dict = results[winner_key_this_stage]['config']
            winning_attention_type = winner_config_dict.get('model_args', {}).get('attention', 'N/A (or baseline)')
            print(f"  (Winning Attention Type from config: {winning_attention_type})")
            for detail_key, detail_value in winner_config_dict.items():
                if isinstance(detail_value, dict):
                    # Nested dicts: show model_args in full, parse_args in full only when it holds
                    # an augment_layer (printed as a placeholder), everything else as "{...}".
                    if detail_key == 'model_args':
                        print(f"  {detail_key}:")
                        for sub_k, sub_v in detail_value.items():
                            print(f"    {sub_k}: {sub_v}")
                    elif detail_key == 'parse_args' and 'augment_layer' in detail_value:
                        print(f"  {detail_key}:")
                        for pa_key, pa_value in detail_value.items():
                            if pa_key == 'augment_layer':
                                print(f"    {pa_key}: <Keras Layer Object>")
                            else:
                                print(f"    {pa_key}: {pa_value}")
                    else:
                        print(f"  {detail_key}: {{...}}")  # Default concise print for other dicts
                else:
                    print(f"  {detail_key}: {detail_value}")
        else:
            print(f"Full configuration details for winner '{winner_key_this_stage}' not found in 'results'.")

        if winner_metrics_this_stage:
            print("\nBest Overall (Pre-Finetuning) Metrics (from JSON):")
            for m_key, m_val in winner_metrics_this_stage.items():
                if isinstance(m_val, (float, np.floating)):
                    print(f"  {m_key}: {m_val:.4f}")
                else:
                    print(f"  {m_key}: {m_val}")
    else:
        print("\nCould not determine a new winner for the Attention stage.")
        # If no new winner, the baseline for this stage (winner of previous stage) remains the best pre-finetune
        best_overall_pre_finetune_key = attention_stage_baseline_key
        # current_best_config_key also remains attention_stage_baseline_key
        print(f"The configuration '{attention_stage_baseline_key}' (Score: {baseline_score_for_this_stage_ref:.4f}) remains the best overall pre-finetune.")

# --- Final Sanity Check before potentially moving to a fine-tuning cell ---
if best_overall_pre_finetune_key is None or \
   not os.path.exists(os.path.join(METRICS_DIR, f"evaluation_metrics_{best_overall_pre_finetune_key}.json")):
    print("\nCRITICAL WARNING: Could not determine a valid best overall configuration before fine-tuning, or its metrics file is missing. "
          "The next Fine-tuning stage might not proceed correctly.")
    # Consider raising an error: raise RuntimeError("Failed to determine best configuration for fine-tuning.")
else:
    print(f"\nReady for Fine-Tuning using configuration: '{best_overall_pre_finetune_key}'")

print("-" * 70)

In [None]:
# CELL 17: Stage 6 - Fine-Tuning Experiment
print("\n--- Cell 17: Stage 6 - Fine-Tuning Experiment ---")

import os
import gc
import json
import time
import traceback
import numpy as np
# tensorflow / keras and the Keras callbacks used below are imported in Cell 1.

# Globals expected from earlier cells:
#   best_overall_pre_finetune_key (Cell 16), results (global experiment dictionary)
#   CHECKPOINT_DIR, PLOTS_DIR, METRICS_DIR
#   train_paths, train_labels, val_paths, val_labels, test_paths, test_labels
#   build_dataset, preprocess_image, IMG_SIZE, GLOBAL_BATCH_SIZE
#   strategy, custom_objects_map, inv_label_dict, label_dict
#   LEARNING_RATE_FINETUNE, EPOCHS_FINETUNE
#   PATIENCE_EARLY_STOPPING, PATIENCE_REDUCE_LR, MIN_LR, TARGET_METRIC (e.g. 'f1_opt')
#   train_model, evaluate_model_optimized_with_viz, plot_training_history_enhanced (Cell 6)
#   class_weights_dict (only read when the winning config used 'balanced' class weights)

# Both keys are reset so that re-running this cell never reports a previous run's outcome.
fine_tuned_run_key = None  # key of a SUCCESSFUL fine-tuning run; Cell 18 reads this
attempted_ft_key = None    # key of the run attempted here; kept even when the run fails

if best_overall_pre_finetune_key is None or best_overall_pre_finetune_key not in results:
    print("CRITICAL ERROR: Cannot proceed with Fine-tuning. Best pre-finetune configuration key is missing or invalid.")
else:
    print(f"Starting Fine-tuning based on configuration: '{best_overall_pre_finetune_key}'")

    pre_ft_config_data = results[best_overall_pre_finetune_key].get('config')
    pre_ft_checkpoint_path = results[best_overall_pre_finetune_key].get('checkpoint_path')

    if not pre_ft_config_data or not pre_ft_checkpoint_path or not os.path.exists(pre_ft_checkpoint_path):
        print(f"ERROR: Cannot fine-tune. Missing config or checkpoint file ({pre_ft_checkpoint_path}) for the best pre-finetune model '{best_overall_pre_finetune_key}'.")
    else:
        fine_tuned_run_key = f"{best_overall_pre_finetune_key}_FineTuned" # Construct a unique key
        attempted_ft_key = fine_tuned_run_key

        # Fine-tuning specific parameters
        unfreeze_from_block = 'conv5_block' # Layer-name prefix; e.g. DenseNet's last dense block
        fine_tuning_lr = LEARNING_RATE_FINETUNE # Should be a very small LR
        fine_tuning_epochs = EPOCHS_FINETUNE
        fine_tune_checkpoint_filepath = os.path.join(CHECKPOINT_DIR, f"{fine_tuned_run_key}_best.keras")

        print(f"Fine-tuning Key: {fine_tuned_run_key}")
        print(f"Loading model from: {pre_ft_checkpoint_path}")
        print(f"Unfreezing from: '{unfreeze_from_block}' (or all if not found)")
        print(f"Fine-tuning LR: {fine_tuning_lr}, Max Epochs: {fine_tuning_epochs}")
        print(f"Best fine-tuned model will be saved to: {fine_tune_checkpoint_filepath}")

        model_ft = None
        history_ft = None
        eval_metrics_ft = None
        training_duration_ft = 0
        start_time_total_ft = time.time()

        try:
            # --- 1. Build Datasets (same preprocessing config as the best pre-FT run) ---
            print("\n[FT-1. Building Datasets for Fine-Tuning...]")
            ft_parse_args = pre_ft_config_data['parse_args'].copy()
            ft_oversample = pre_ft_config_data.get('oversample_train', False)

            train_ds_ft = build_dataset(
                train_paths, train_labels, preprocess_image, ft_parse_args,
                GLOBAL_BATCH_SIZE, f"Train ({fine_tuned_run_key})", shuffle=True,
                augment_in_map=ft_parse_args.get('apply_augment', False),
                oversample=ft_oversample, cache=True
            )
            # Validation/test data must never be augmented.
            val_ft_parse_args = ft_parse_args.copy(); val_ft_parse_args['apply_augment'] = False
            test_ft_parse_args = ft_parse_args.copy(); test_ft_parse_args['apply_augment'] = False
            val_ds_ft = build_dataset(val_paths, val_labels, preprocess_image, val_ft_parse_args, GLOBAL_BATCH_SIZE, f"Val ({fine_tuned_run_key})", shuffle=False, cache=True)
            test_ds_ft = build_dataset(test_paths, test_labels, preprocess_image, test_ft_parse_args, GLOBAL_BATCH_SIZE, f"Test ({fine_tuned_run_key})", shuffle=False, cache=True)

            if not all([train_ds_ft, val_ds_ft, test_ds_ft]):
                raise RuntimeError("Failed to build fine-tuning datasets.")
            print("Fine-tuning datasets built successfully.")

            # --- 2. Load Best Pre-FT Model ---
            print("\n[FT-2. Loading Best Pre-Finetune Model...]")
            with strategy.scope():
                model_ft = keras.models.load_model(pre_ft_checkpoint_path, custom_objects=custom_objects_map)
            print(f"Model '{model_ft.name}' loaded successfully.")

            # --- 3. Unfreeze Layers ---
            print(f"\n[FT-3. Unfreezing Base Model Layers from '{unfreeze_from_block}'...]")
            base_model_to_unfreeze = None
            for layer in model_ft.layers:
                # The backbone is a nested keras.Model; identify it by well-known family names.
                if isinstance(layer, keras.Model) and ('densenet' in layer.name or 'efficientnet' in layer.name or 'resnet' in layer.name):
                    base_model_to_unfreeze = layer
                    break
            if not base_model_to_unfreeze and len(model_ft.layers) > 1 and isinstance(model_ft.layers[1], keras.Model): # Fallback: layer right after Input
                base_model_to_unfreeze = model_ft.layers[1]
                print(f"Warning: Assuming base model is layer '{base_model_to_unfreeze.name}' based on position.")

            if not base_model_to_unfreeze:
                raise ValueError("Could not identify base model layer for unfreezing.")
            print(f"Identified base model for unfreezing: '{base_model_to_unfreeze.name}'")

            # Make the whole backbone trainable, then re-freeze everything before the chosen block.
            # NOTE(review): BatchNormalization layers in the unfrozen range become trainable too.
            # The Keras fine-tuning guide recommends keeping BN in inference mode (base model called
            # with training=False); confirm how the backbone was wired when the model was built.
            base_model_to_unfreeze.trainable = True
            unfreeze_from_index = -1
            if unfreeze_from_block: # Only try to find specific block if name is given
                for i, layer in enumerate(base_model_to_unfreeze.layers):
                    if layer.name.startswith(unfreeze_from_block):
                        unfreeze_from_index = i
                        break
                if unfreeze_from_index == -1:
                    print(f"Warning: Layer prefix '{unfreeze_from_block}' not found in base model '{base_model_to_unfreeze.name}'. Unfreezing ALL base layers.")
                    unfreeze_from_index = 0
            else: # If unfreeze_from_block is None or empty, unfreeze all
                print("Unfreezing ALL layers in the base model.")
                unfreeze_from_index = 0

            num_frozen_in_base = 0
            if unfreeze_from_index > 0:
                print(f"Freezing layers in '{base_model_to_unfreeze.name}' before index {unfreeze_from_index} ('{base_model_to_unfreeze.layers[unfreeze_from_index].name}')")
                for layer in base_model_to_unfreeze.layers[:unfreeze_from_index]:
                    if layer.trainable:
                        layer.trainable = False
                        num_frozen_in_base += 1
                print(f"Froze {num_frozen_in_base} layers in the base model.")
            else:
                print(f"All {len(base_model_to_unfreeze.layers)} layers in '{base_model_to_unfreeze.name}' will be trainable (or maintain their current trainable status if set individually).")

            trainable_count = np.sum([np.prod(w.shape) for w in model_ft.trainable_weights])
            non_trainable_count = np.sum([np.prod(w.shape) for w in model_ft.non_trainable_weights])
            print(f"Total Trainable weights in full model: {trainable_count:,}")
            print(f"Total Non-trainable weights in full model: {non_trainable_count:,}")

            # --- 4. Re-compile Model (required for the trainable changes to take effect) ---
            print("\n[FT-4. Re-compiling Model for Fine-tuning...]")
            with strategy.scope():
                model_ft.compile(
                    optimizer=keras.optimizers.Adam(learning_rate=fine_tuning_lr),
                    loss='binary_crossentropy',
                    metrics=[ 'accuracy', tf.keras.metrics.Precision(name='precision'),
                              tf.keras.metrics.Recall(name='recall'), tf.keras.metrics.AUC(name='auc')]
                )
            print(f"Model re-compiled with LR={fine_tuning_lr}.")

            # --- 5. Define Fine-Tuning Callbacks ---
            print("\n[FT-5. Setting up Fine-Tuning Callbacks...]")
            # Remove a checkpoint left over from an earlier run of this cell. If this run never
            # improves val_loss (e.g. NaN losses), ModelCheckpoint(save_best_only=True) writes nothing,
            # and FT-7 would otherwise evaluate the stale model as if it were this run's result.
            if os.path.isfile(fine_tune_checkpoint_filepath):
                os.remove(fine_tune_checkpoint_filepath)
                print(f"Removed stale fine-tuning checkpoint from a previous run: {fine_tune_checkpoint_filepath}")
            ft_callbacks_list = [
                # Best weights are restored by reloading the checkpoint in FT-7, not by EarlyStopping.
                EarlyStopping(monitor='val_loss', patience=PATIENCE_EARLY_STOPPING + 2, verbose=1, restore_best_weights=False),
                ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=PATIENCE_REDUCE_LR, min_lr=MIN_LR / 5, verbose=1), # Lower floor than the head-training stage
                ModelCheckpoint(filepath=fine_tune_checkpoint_filepath, monitor='val_loss', save_best_only=True, save_weights_only=False, verbose=1)
            ]
            print(f"Best fine-tuned model checkpoint path: {fine_tune_checkpoint_filepath}")

            # --- 6. Train (Fine-tune) ---
            print("\n[FT-6. Starting Fine-tuning Training Phase...]")
            # Reuse the class-weight decision of the winning configuration; tolerate configs without 'train_args'.
            ft_class_weights_setting = (pre_ft_config_data.get('train_args') or {}).get('class_weights_setting')
            ft_train_class_weights = class_weights_dict if ft_class_weights_setting == 'balanced' else None

            history_ft, training_duration_ft = train_model(
                model_ft, train_ds_ft, val_ds_ft,
                epochs=fine_tuning_epochs, initial_epoch=0,
                class_weights=ft_train_class_weights, strategy=strategy,
                learning_rate=fine_tuning_lr, # LR is set during compile; passed for train_model's own use/logging
                callbacks=ft_callbacks_list,
                stage_name=f"Fine-Tuning ({fine_tuned_run_key})"
            )
            if history_ft is None:
                raise RuntimeError("Model fine-tuning training failed.")

            # --- 7. Load Best Fine-tuned Weights ---
            print("\n[FT-7. Loading Best Weights from Fine-Tuning Checkpoint...]")
            model_to_eval_ft = model_ft # Default to last epoch if checkpoint missing
            if os.path.exists(fine_tune_checkpoint_filepath):
                with strategy.scope():
                    model_to_eval_ft = keras.models.load_model(fine_tune_checkpoint_filepath, custom_objects=custom_objects_map)
                print(f"Successfully loaded best fine-tuned model from {fine_tune_checkpoint_filepath}")
            else:
                print(f"WARNING: Fine-tuning checkpoint not found at {fine_tune_checkpoint_filepath}. Evaluating with last FT epoch's weights.")

            # --- 8. Evaluate Final Fine-tuned Model (also produces CM / ROC / PR plots) ---
            print("\n[FT-8. Evaluating Final Fine-tuned Model...]")
            eval_metrics_ft, y_true_test_ft, y_pred_proba_test_ft, _ = evaluate_model_optimized_with_viz(
                model=model_to_eval_ft,
                val_ds=val_ds_ft,
                test_ds=test_ds_ft,
                strategy=strategy,
                inv_label_map=inv_label_dict,
                target_metric=TARGET_METRIC,
                dataset_name=f"Test ({fine_tuned_run_key})",
                save_dir=PLOTS_DIR,
                config_name=fine_tuned_run_key
            )
            if not eval_metrics_ft:
                print("Warning: Fine-tuned evaluation failed or returned no metrics.")
                eval_metrics_ft = {}

            # (Grad-CAM for the fine-tuned model is intentionally disabled.)

            # --- 10. Store Fine-Tuned Results ---
            print("\n[FT-10. Storing Fine-tuned Results...]")
            total_duration_ft = time.time() - start_time_total_ft

            # Save metrics to JSON (numpy scalars/arrays converted to plain Python types).
            metrics_json_save_path_ft = None
            if eval_metrics_ft:
                metrics_json_save_path_ft = os.path.join(METRICS_DIR, f"evaluation_metrics_{fine_tuned_run_key}.json")
                try:
                    serializable_metrics_ft = {}
                    for m_key, m_value in eval_metrics_ft.items():
                        if isinstance(m_value, np.generic): serializable_metrics_ft[m_key] = m_value.item()
                        elif isinstance(m_value, np.ndarray): serializable_metrics_ft[m_key] = m_value.tolist()
                        else: serializable_metrics_ft[m_key] = m_value
                    with open(metrics_json_save_path_ft, 'w') as f:
                        json.dump(serializable_metrics_ft, f, indent=4)
                    print(f"Fine-tuned evaluation metrics saved to {metrics_json_save_path_ft}")
                except Exception as e:
                    print(f"Error saving fine-tuned evaluation metrics to JSON: {e}")
                    metrics_json_save_path_ft = None

            results[fine_tuned_run_key] = {
                'config': pre_ft_config_data,
                'fine_tune_params': {
                    'unfreeze_from_block': unfreeze_from_block, 'lr': fine_tuning_lr,
                    'epochs_run': len(history_ft.epoch) if history_ft and hasattr(history_ft, 'epoch') else 0,
                },
                'metrics': eval_metrics_ft,
                'training_duration_sec': training_duration_ft,
                'total_duration_sec': total_duration_ft,
                'checkpoint_path': fine_tune_checkpoint_filepath if os.path.exists(fine_tune_checkpoint_filepath) else None,
                'metrics_json_path': metrics_json_save_path_ft
            }
            print(f"Results for fine-tuned model '{fine_tuned_run_key}' stored.")

            # --- 11. Plotting Fine-Tuned Results ---
            print("\n[FT-11. Plotting Fine-tuned Results...]")
            if history_ft:
                plot_training_history_enhanced(
                    history_ft,
                    title_suffix=f"({fine_tuned_run_key} - FT Phase)",
                    save_dir=PLOTS_DIR,
                    config_name=fine_tuned_run_key
                )

            print(f"--- Fine-Tuning Experiment {fine_tuned_run_key} Complete ---")
            # Deciding whether the fine-tuned model beats best_overall_pre_finetune_key happens in Cell 18.

        except Exception as e:
            print(f"\n\n ****** ERROR during FINE-TUNING experiment {fine_tuned_run_key} ****** ")
            print(f"Error Type: {type(e).__name__}")
            print(f"Error Details: {e}")
            traceback.print_exc()
            # Record the failure under the attempted key (so the status summary below can report it),
            # then clear fine_tuned_run_key so Cell 18 does not treat the failed run as a milestone.
            results[attempted_ft_key] = {'status': 'failed', 'error': str(e), 'config': pre_ft_config_data}
            fine_tuned_run_key = None

        # Cleanup (del model_ft / datasets, keras.backend.clear_session(), gc.collect()) is intentionally
        # not performed, so the fine-tuned objects stay available for inspection. Run it manually if
        # accelerator memory is needed for later cells.

# current_best_config_key is NOT updated here: Cell 18 compares fine_tuned_run_key with
# best_overall_pre_finetune_key, which still holds the best pre-finetuning configuration.

print("-" * 70)
if fine_tuned_run_key and fine_tuned_run_key in results and results[fine_tuned_run_key].get('status') != 'failed':
    print(f"Fine-tuning experiment '{fine_tuned_run_key}' completed and results stored.")
    print(f"The best model before this fine-tuning was: '{best_overall_pre_finetune_key}'")
    print(f"Compare metrics of '{fine_tuned_run_key}' with '{best_overall_pre_finetune_key}' in the next cell (Cell 18) to determine the ultimate winner.")
elif attempted_ft_key and results.get(attempted_ft_key, {}).get('status') == 'failed':
    print(f"Fine-tuning experiment '{attempted_ft_key}' FAILED.")
else:
    print("Fine-tuning was not performed or key was not generated due to earlier errors.")

In [None]:
# CELL 18: Stage 7 - Final Summary Comparison
print("\n--- Cell 18: Stage 7 - Final Summary Comparison ---")

import os
import json
import numpy as np # For isinstance checks in JSON serialization, and pd.isna in plot func
import pandas as pd # For pd.isna in plot func

# Globals expected from previous cells:
#   METRICS_DIR (evaluation_metrics_{key}.json files), PLOTS_DIR, results, TARGET_METRIC,
#   plot_comparison_bars_enhanced (Cell 6; reads metrics from the JSON files)
# Stage-winner keys are looked up by NAME via _resolve_stage_key. Earlier cells call them
# '<stage>_stage_baseline_key' (Cell 16 uses attention_stage_baseline_key), while this cell
# historically used '<stage>_baseline_key'. Either name is accepted, and an undefined one yields
# a milestone warning instead of a NameError that would abort the whole summary.


def _resolve_stage_key(*candidate_names):
    """Return the value of the first of `candidate_names` defined as a notebook global.

    Returns None when none of the names is defined (e.g. that stage's cell was skipped).
    """
    notebook_globals = globals()
    for name in candidate_names:
        if name in notebook_globals:
            return notebook_globals[name]
    return None


def _metrics_json_path(config_key):
    """Path of the evaluation-metrics JSON written for `config_key`."""
    return os.path.join(METRICS_DIR, f"evaluation_metrics_{config_key}.json")


def _label_with_suffix(base_label, config_key, marker_suffixes):
    """Append the suffix of the first (marker, suffix) pair whose marker occurs in `config_key`."""
    if config_key:
        for marker, suffix in marker_suffixes:
            if marker in config_key:
                return base_label + suffix
    return base_label


def _register_milestone(number, config_key, label, warning_note=""):
    """Record `config_key` under `label` if its metrics JSON exists; otherwise print a warning.

    Updates the module-level milestone_original_keys / milestone_label_to_original_key.
    A key is added to milestone_original_keys only once, because a stage may keep the
    previous stage's winner.
    """
    if config_key and os.path.exists(_metrics_json_path(config_key)):
        if config_key not in milestone_original_keys:
            milestone_original_keys.append(config_key)
        milestone_label_to_original_key[label] = config_key
        print(f"- Found Milestone {number}: '{config_key}' -> '{label}'")
    else:
        print(f"- WARNING: Milestone {number} key '{config_key}'{warning_note} or its metrics JSON not found.")


print("Gathering results for final milestone comparison...")
milestone_original_keys = []          # unique config keys, in milestone order
milestone_label_to_original_key = {}  # descriptive label -> config key; insertion order = milestone order

# 1. Initial baseline (standard augmentation only)
key1 = "Baseline_StdAug" # Key of the very first baseline experiment
_register_milestone(1, key1, "1. Baseline (Aug)")

# 2. Winner of the CLAHE stage (= baseline of the Imbalance stage)
key2 = _resolve_stage_key('imbalance_baseline_key', 'imbalance_stage_baseline_key')
_register_milestone(2, key2, "2. +CLAHE" if key2 and "_CLAHE" in key2 else "2. Baseline (No CLAHE)")

# 3. Winner of the Imbalance stage (= baseline of the Pooling stage)
key3 = _resolve_stage_key('pooling_baseline_key', 'pooling_stage_baseline_key')
_register_milestone(3, key3, _label_with_suffix(
    "3. +Imbalance", key3, [("_ClassWeights", " (Weights)"), ("_Oversample", " (Oversample)")]))

# 4. Winner of the Pooling stage (= baseline of the Attention stage)
key4 = _resolve_stage_key('attention_baseline_key', 'attention_stage_baseline_key')
_register_milestone(4, key4, _label_with_suffix(
    "4. +Pooling", key4, [("_PoolMAX", " (Max)"), ("_PoolHYBRID", " (Hybrid)"), ("_PoolAVG", " (Avg)")]))

# 5. Winner of the Attention stage = best overall configuration before fine-tuning
key5 = _resolve_stage_key('best_overall_pre_finetune_key')
_register_milestone(5, key5, _label_with_suffix(
    "5. +Attention", key5, [("_AttnCBAM", " (CBAM)"), ("_AttnCHANNEL", " (Channel)"), ("_AttnSPATIAL", " (Spatial)")]))

# 6. Fine-tuned model (Cell 17 sets fine_tuned_run_key to None when fine-tuning failed)
key6 = _resolve_stage_key('fine_tuned_run_key')
_register_milestone(6, key6, "6. Fine-Tuned", warning_note=" (Fine-Tuned)")

if milestone_original_keys:
    print("\nPlotting Final Milestone Comparison...")

    # Metrics read from the JSON files; the target metric is always included.
    metrics_to_plot_final = ['f1_opt', 'accuracy_opt', 'precision_opt', 'recall_opt', 'roc_auc_proba', 'pr_auc']
    if TARGET_METRIC not in metrics_to_plot_final:
        metrics_to_plot_final.insert(0, TARGET_METRIC)

    # Bars are labelled with the original config keys; order them by milestone (registration order),
    # dropping repeats when a stage kept the previous winner.
    ordered_original_keys_for_plot = list(dict.fromkeys(milestone_label_to_original_key.values()))

    if ordered_original_keys_for_plot:
        plot_comparison_bars_enhanced(
            config_keys_to_plot=ordered_original_keys_for_plot,
            metrics_dir=METRICS_DIR,
            title="Final Model Performance Milestones",
            save_dir=PLOTS_DIR,
            metrics_to_display=metrics_to_plot_final
        )
    else:
        print("\nNo valid, ordered keys found to generate the final comparison plot from JSONs.")

else:
    print("\nNot enough valid milestone results found to generate the final comparison plot.")

print("\n" + "="*70)
print("                      End of Experiment Pipeline                      ")
print("="*70)

# --- Save a milestone summary (metrics + stringified config) to JSON ---
# Only JSON-safe data is exported; Keras objects and callables in configs are skipped.
results_summary_json_path = os.path.join(PLOTS_DIR, "pipeline_summary_results.json") # Save in PLOTS_DIR for outputs
try:
    serializable_summary = {}
    print(f"\nAttempting to create a serializable summary for JSON export to {results_summary_json_path}...")
    for milestone_label, original_key in milestone_label_to_original_key.items():
        metrics_path = _metrics_json_path(original_key)
        if not (original_key and os.path.exists(metrics_path)):
            continue
        with open(metrics_path, 'r') as f:
            metrics = json.load(f)
        serializable_summary[milestone_label] = {
            'original_key': original_key,
            'metrics': metrics
        }
        # A missing or None config must not abort the export of the other milestones.
        run_config = results.get(original_key, {}).get('config')
        if isinstance(run_config, dict):
            serializable_summary[milestone_label]['config_summary'] = {
                k: str(v) for k, v in run_config.items()
                if not callable(v) and not isinstance(v, keras.Model) and not isinstance(v, keras.layers.Layer)
            }

    if serializable_summary:
        with open(results_summary_json_path, 'w') as f:
            json.dump(serializable_summary, f, indent=4)
        print(f"Final milestone metrics summary saved to: {results_summary_json_path}")
    else:
        print("No milestone data to save in summary JSON.")

except Exception as json_e:
    print(f"\nWarning: Could not save final milestone summary to JSON. Error: {json_e}")