# Final Model Training for DenseNet121

This notebook takes the optimal hyperparameters found via the Bayesian Optimization study in `../1_bayesian_optimization/bo-densenet121.ipynb`, builds the final model, trains it, and evaluates its performance on the hold-out test set.

## 1. Setup and Imports

In [None]:
# Imports and Setup
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.utils import Sequence
from sklearn.model_selection import train_test_split
from glob import glob
import json
import math
import sys
import random

print("TensorFlow Version:", tf.__version__)

In [None]:
# Reproducibility: seed every random number generator this notebook relies on.
SEED = 42

# NOTE(review): setting PYTHONHASHSEED here presumably only affects processes
# started after this point, not the hashing of the running kernel - confirm.
os.environ.update(PYTHONHASHSEED=str(SEED))

# Python's `random`, NumPy's legacy global RNG and TensorFlow's global RNG.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)

In [None]:
# Parameters
# Dataset selector passed to load_and_prepare_data(); only 'tb' is handled there.
TARGET_MODALITY = 'tb'

# Which pre-trained weights build_model() starts from: 'imagenet', 'radimagenet',
# or anything else for a randomly initialised base.
PRETRAINED_WEIGHTS_NAME = 'radimagenet'

# Weight file loaded into the DenseNet121 base when PRETRAINED_WEIGHTS_NAME == 'radimagenet'.
RADIMAGENET_WEIGHT_PATH = '../../weights/RadImageNet-DenseNet121_notop.h5'

# Layer-name prefix: the first base-model layer whose name starts with this string,
# and every layer after it, is left trainable; earlier layers are frozen.
UNFREEZE_AT_BLOCK = 'conv4_block1_concat'
INITIAL_LEARNING_RATE = 1e-4 # LR used if base is frozen OR for top layers
FINE_TUNE_LR = 1e-5 # Use a smaller learning rate for fine-tuning if base is unfrozen

BATCH_SIZE = 32
# NOTE(review): MAX_EPOCHS (2) is smaller than both EARLY_STOPPING_PATIENCE (15) and
# REDUCE_LR_PATIENCE (5), so neither callback can trigger with these values.
# This looks like a smoke-test setting - confirm before a real training run.
MAX_EPOCHS = 2
EARLY_STOPPING_PATIENCE = 15
REDUCE_LR_PATIENCE = 5
# Both split sizes are fractions of the FULL dataset; load_and_prepare_data()
# rescales VALID_SPLIT after the test set has been removed.
TEST_SPLIT = 0.15
VALID_SPLIT = 0.15
# (height, width) the generators resize images to and the model input expects.
IMG_SIZE = (224, 224)

# Tuberculosis Dataset Paths
TB_DATASET_DIR_NAME = 'tuberculosis-tb-chest-xray-dataset'
TB_SUBDIR = 'TB_Chest_Radiography_Database'
TB_BASE_PATH = os.path.join('../../data/', TB_DATASET_DIR_NAME, TB_SUBDIR)
# Images under Normal/ get label 0, images under Tuberculosis/ get label 1.
TB_NORMAL_DIR = os.path.join(TB_BASE_PATH, 'Normal/')
TB_TUBERCULOSIS_DIR = os.path.join(TB_BASE_PATH, 'Tuberculosis/')

# JSON file with the hyperparameters for the custom top layers.
BEST_PARAMS_FILE = '../../results/best_params_tb_DenseNet121_radimagenet.json'
# Where ModelCheckpoint writes the best weights and where evaluation reads them back.
MODEL_SAVE_PATH = '../../models/densenet121_final_model.weights.h5'

## 2. Data Loading and Preparation

In [None]:
def load_and_prepare_data(target_modality):
    """Collect image paths and labels, then split them into train/val/test.

    Only ``target_modality == 'tb'`` is supported: ``*.png`` files under
    ``TB_NORMAL_DIR`` get label 0 and those under ``TB_TUBERCULOSIS_DIR`` get
    label 1. Both splits are stratified on the label and seeded with ``SEED``.

    Args:
        target_modality: Name of the dataset to load (currently only ``'tb'``).

    Returns:
        Tuple ``(train_df, val_df, test_df)`` of DataFrames with the columns
        ``filepath`` and ``label``. Labels are converted to strings because
        ``flow_from_dataframe(class_mode='binary')`` expects string classes.

    Raises:
        ValueError: If no data is found, only one class is present, or
            scikit-learn rejects the requested split.
    """
    df_target_modality = pd.DataFrame()

    # Load Tuberculosis Data if targeted
    if target_modality == 'tb':
        print(f"Loading Tuberculosis data...")
        try:
            # glob() returns files in filesystem order, which differs between
            # machines. Sorting makes the seeded split below reproducible.
            normal_files = sorted(glob(os.path.join(TB_NORMAL_DIR, '*.png')))
            tuberculosis_files = sorted(glob(os.path.join(TB_TUBERCULOSIS_DIR, '*.png')))

            tb_data_list = [{'filepath': f, 'label': 0} for f in normal_files]
            tb_data_list += [{'filepath': f, 'label': 1} for f in tuberculosis_files]

            if tb_data_list:
                df_target_modality = pd.DataFrame(tb_data_list)
                print(f"  Processed {len(df_target_modality)} TB images.")
            else:
                print("  WARNING: No TB images found.")
        except Exception as e:
            # Not fatal here: the empty-frame check below raises a ValueError.
            print(f"  ERROR loading TB data: {e}. Check paths.")

    # Validate Paths and Split
    if df_target_modality.empty:
        raise ValueError(f"CRITICAL: No data loaded for target modality '{target_modality}'.")

    # Check file existence. `.copy()` gives an independent frame so the column
    # assignment below does not write into a slice of the original.
    original_count = len(df_target_modality)
    exists_mask = df_target_modality['filepath'].apply(os.path.exists)
    df_target_modality = df_target_modality[exists_mask].copy()
    files_removed = original_count - len(df_target_modality)
    if files_removed > 0:
        print(f"WARNING: Removed {files_removed} non-existent file entries.")
    print(f"\nTotal valid entries for '{target_modality}': {len(df_target_modality)}")
    if len(df_target_modality) == 0:
        raise ValueError("CRITICAL: No valid image files found.")

    df_target_modality['label'] = df_target_modality['label'].astype(int)
    print(f"\nLabel distribution for '{target_modality}':\n", df_target_modality['label'].value_counts())
    if df_target_modality['label'].nunique() < 2:
        raise ValueError("CRITICAL: Only one class found in the target modality dataset.")

    # Stratified Split based on Label
    train_df, val_df, test_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    try:
        # Split off test set
        train_val_df, test_df = train_test_split(
            df_target_modality,
            test_size=TEST_SPLIT,
            stratify=df_target_modality['label'],
            random_state=SEED
        )
        # VALID_SPLIT is a fraction of the full dataset; rescale it to a
        # fraction of what is left after removing the test set.
        val_split_adjusted = VALID_SPLIT / (1 - TEST_SPLIT) if (1 - TEST_SPLIT) > 0 else 0

        # Split train_val into train and validation
        if val_split_adjusted > 0 and len(train_val_df) > 1 and train_val_df['label'].nunique() > 1:
            train_df, val_df = train_test_split(
                train_val_df,
                test_size=val_split_adjusted,
                stratify=train_val_df['label'],
                random_state=SEED
            )
        elif len(train_val_df) > 0:
            train_df = train_val_df
            val_df = pd.DataFrame(columns=train_val_df.columns)  # Empty val set if split fails
            print("Warning: Could not perform validation split properly, validation set might be empty or small.")
        else:  # train_val_df is empty
            train_df = pd.DataFrame(columns=df_target_modality.columns)
            val_df = pd.DataFrame(columns=df_target_modality.columns)

    except ValueError as e:
        print(f"\nCRITICAL ERROR during split: {e}. Check class distribution and split sizes.")
        raise  # Re-raise with the original traceback to stop execution
    except Exception as e_gen:
        print(f"\nUNEXPECTED ERROR during split: {e_gen}")
        raise

    print("\n--- Data Split Summary ---")
    print(f"Modality: {target_modality}")
    print(f"Train:      {len(train_df)}")
    print(f"Validation: {len(val_df)}")
    print(f"Test:       {len(test_df)}")
    print("--------------------------")

    if len(train_df) == 0 or len(val_df) == 0:
        print("\nWARNING: Training or validation set is empty after splitting.")

    # Convert labels to string for flow_from_dataframe binary mode.
    # `.assign()` returns new frames, so nothing is written into the slices
    # that train_test_split handed back.
    train_df = train_df.assign(label=train_df['label'].astype(str))
    val_df = val_df.assign(label=val_df['label'].astype(str))
    test_df = test_df.assign(label=test_df['label'].astype(str))

    return train_df, val_df, test_df

In [None]:
# Load Data Trigger
# Start from empty frames so the names exist even if this cell fails part-way.
train_df, val_df, test_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
try:
    train_df, val_df, test_df = load_and_prepare_data(TARGET_MODALITY)
except Exception as e:
    print(f"\nCRITICAL ERROR during data loading/processing: {e}. Stopping.")
    # Actually stop, as the message says: re-raise so "Run All" halts here
    # instead of continuing with empty frames.
    raise
else:
    print(f"\nData loading and splitting for {TARGET_MODALITY} completed.")

    if train_df.empty and val_df.empty and test_df.empty:
        print("\nWARNING: All data splits (train, val, test) are empty. This might be due to filtering or data issues.")
    elif train_df.empty or val_df.empty:
        print("\nWARNING: Training or validation set is empty after splitting. Training cannot proceed.")

## 3. Data Augmentation and Generators

In [None]:
# Data Generators

def _make_generator(datagen, dataframe, split_name, shuffle):
    """Create one ``flow_from_dataframe`` iterator, or return None on failure.

    Args:
        datagen: ``ImageDataGenerator`` supplying preprocessing/augmentation.
        dataframe: Frame with ``filepath`` and string ``label`` columns.
        split_name: Lower-case split name used in log messages
            (``'training'``, ``'validation'`` or ``'test'``).
        shuffle: Whether the iterator reshuffles the samples every epoch.

    Returns:
        The created iterator, or ``None`` if Keras raised while building it.
    """
    print(f"\nCreating {split_name} generator for '{TARGET_MODALITY}'...")
    try:
        generator = datagen.flow_from_dataframe(
            dataframe=dataframe,
            x_col='filepath',
            y_col='label',
            target_size=IMG_SIZE,
            batch_size=BATCH_SIZE,
            class_mode='binary',
            # Pin the label -> index mapping instead of inferring it per split,
            # so '1' is class index 1 in every generator.
            classes=['0', '1'],
            shuffle=shuffle,
            # Without a seed the shuffle order and random augmentations differ
            # on every run, defeating the global seeding done earlier.
            seed=SEED
        )
    except Exception as e:
        print(f"ERROR creating {split_name} generator: {e}")
        return None
    print(f"{split_name.capitalize()} generator created: {len(generator)} batches.")
    return generator


train_generator = None
validation_generator = None
test_generator = None

if not train_df.empty and not val_df.empty:

    # Augmentation settings for training
    train_datagen = ImageDataGenerator(
        preprocessing_function=densenet_preprocess_input,
        rotation_range=15,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    # No augmentation for validation/testing
    val_test_datagen = ImageDataGenerator(
        preprocessing_function=densenet_preprocess_input
    )

    train_generator = _make_generator(train_datagen, train_df, 'training', shuffle=True)
    # Validation and test stay in file order so predictions line up with labels.
    validation_generator = _make_generator(val_test_datagen, val_df, 'validation', shuffle=False)

    if not test_df.empty:
        test_generator = _make_generator(val_test_datagen, test_df, 'test', shuffle=False)
    else:
        print("\nWARNING: Test data DataFrame is empty. Cannot create test generator.")

else:
    print("\nWARNING: Training or validation DataFrame is empty. Cannot create generators.")


print("\n--- Generator Status ---")
print(f"Train Generator:      {'Created' if train_generator else 'Failed/Skipped'}")
print(f"Validation Generator: {'Created' if validation_generator else 'Failed/Skipped'}")
print(f"Test Generator:       {'Created' if test_generator else 'Failed/Skipped'}")
print("------------------------\n")

## 4. Load Optimal Hyperparameters

In [None]:
# Load Best Hyperparameters

def _parse_kernel_size(kernel_value):
    """Turn a kernel spec such as ``'3x3'`` (or ``[3, 3]``) into ``(3, 3)``.

    Raises:
        ValueError: If the value does not describe exactly two integers.
    """
    if isinstance(kernel_value, (list, tuple)):
        parts = list(kernel_value)
    else:
        parts = str(kernel_value).lower().split('x')
    if len(parts) != 2:
        raise ValueError(f"Kernel size must look like '3x3', got: {kernel_value!r}")
    return (int(parts[0]), int(parts[1]))


try:
    with open(BEST_PARAMS_FILE, 'r') as f:
        best_params = json.load(f)
    print("Successfully loaded best hyperparameters:")
    print(best_params)

    # Validate BEFORE extracting: a missing kernel key must be reported as a
    # missing key, not as an AttributeError from calling .split() on None.
    required_keys = ['filters1', 'kernel1', 'pool1', 'filters2', 'kernel2', 'pool2', 'filters3', 'kernel3', 'pool3']
    missing_keys = [key for key in required_keys if key not in best_params]
    if missing_keys:
        error_message = f"ERROR: Hyperparameter file is missing required keys: {missing_keys}"
        print(error_message)
        raise ValueError(error_message)  # Raise ValueError for missing keys

    # Extract hyperparameters (all keys are known to exist at this point)
    hp_filters1 = best_params['filters1']
    kernel1_str = best_params['kernel1']
    hp_kernel1 = _parse_kernel_size(kernel1_str)
    hp_pool1 = best_params['pool1']
    hp_filters2 = best_params['filters2']
    kernel2_str = best_params['kernel2']
    hp_kernel2 = _parse_kernel_size(kernel2_str)
    hp_pool2 = best_params['pool2']
    hp_filters3 = best_params['filters3']
    kernel3_str = best_params['kernel3']
    hp_kernel3 = _parse_kernel_size(kernel3_str)
    hp_pool3 = best_params['pool3']

except FileNotFoundError:
    error_message = f"CRITICAL ERROR: Best hyperparameters file not found at {BEST_PARAMS_FILE}. Please ensure the file exists and the path is correct. Stopping execution."
    print(error_message)
    raise SystemExit(error_message)

except json.JSONDecodeError as e:
    error_message = f"CRITICAL ERROR: Could not decode JSON from {BEST_PARAMS_FILE}. Error: {e}. Stopping execution."
    print(error_message)
    raise SystemExit(error_message)

except Exception as e:
    # Also reached for the missing-key and malformed-kernel ValueErrors above.
    error_message = f"CRITICAL ERROR: An unexpected error occurred while loading or processing the hyperparameters file: {e}. Stopping execution."
    print(error_message)
    raise SystemExit(error_message)

print("\n--- Hyperparameters loaded and Training Configuration set ---")
print(f"Filters: {hp_filters1}, {hp_filters2}, {hp_filters3}")
print(f"Kernels: {hp_kernel1}, {hp_kernel2}, {hp_kernel3}")
print(f"Pooling: {hp_pool1}, {hp_pool2}, {hp_pool3}")

## 5. Build Final Model

In [None]:
# Model Building Function (DenseNet121 Base)
def build_model(hp_filters1, hp_kernel1, hp_pool1,
                hp_filters2, hp_kernel2, hp_pool2,
                hp_filters3, hp_kernel3, hp_pool3,
                initial_learning_rate=1e-4,
                pretrained_weights_name=None,
                radimagenet_path=None,
                unfreeze_at_block=None,
                fine_tune_lr=1e-5):
    """Build and compile DenseNet121 plus three custom conv blocks and a binary head.

    Args:
        hp_filters1, hp_filters2, hp_filters3: Filter counts of the custom conv blocks.
        hp_kernel1, hp_kernel2, hp_kernel3: Kernel sizes (tuple, or list that
            is converted to a tuple).
        hp_pool1, hp_pool2, hp_pool3: ``'max'`` selects max pooling, any other
            value average pooling. Pooling is skipped once the feature map
            is 1 pixel high or wide.
        initial_learning_rate: Learning rate used when the base is frozen or
            trained from scratch.
        pretrained_weights_name: ``'imagenet'``, ``'radimagenet'``, or anything
            else for a randomly initialised base.
        radimagenet_path: Weight file used when ``pretrained_weights_name`` is
            ``'radimagenet'``.
        unfreeze_at_block: Layer-name prefix; the first matching base layer and
            all later ones are fine-tuned. Falsy keeps the pre-trained base frozen.
        fine_tune_lr: Learning rate used when part of a pre-trained base is unfrozen.

    Returns:
        A compiled ``tf.keras.Model`` with one sigmoid output.
    """
    # Base Model Instantiation & Weight Loading
    weights_to_load = None
    custom_path = None
    # True whenever the base ends up randomly initialised: it must then be
    # trained in full rather than frozen or partially fine-tuned.
    train_base_initially = False

    if pretrained_weights_name == 'imagenet':
        print("Loading base DenseNet121 model with ImageNet weights...")
        weights_to_load = 'imagenet'
    elif pretrained_weights_name == 'radimagenet':
        print("Preparing to load RadImageNet weights for DenseNet121...")
        custom_path = radimagenet_path
        if custom_path is None:
            # Without a path nothing is loaded below; do not silently freeze
            # a randomly initialised base.
            print("WARNING: No RadImageNet weight path given. Proceeding with randomly initialized base model weights.")
            train_base_initially = True
    else:
        print("WARNING: No valid pre-trained weights specified. Initializing DenseNet121 randomly.")
        train_base_initially = True

    # Load base DenseNet121 structure
    base_model = tf.keras.applications.DenseNet121(
        weights=weights_to_load,
        include_top=False,
        input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)
    )

    # Load custom weights if applicable
    if custom_path is not None:
        if os.path.exists(custom_path):
            print(f"Loading custom weights from: {custom_path}")
            try:
                # by_name + skip_mismatch never raises for unmatched layers,
                # so compare before/after to learn whether anything loaded.
                weights_before = base_model.get_weights()
                base_model.load_weights(custom_path, by_name=True, skip_mismatch=True)
                weights_after = base_model.get_weights()
                tensors_changed = sum(
                    not np.array_equal(before, after)
                    for before, after in zip(weights_before, weights_after)
                )
                if tensors_changed == 0:
                    print(f"WARNING: No weights in {custom_path} matched the base model (layer names or shapes differ).")
                    print("Proceeding with randomly initialized base model weights.")
                    train_base_initially = True
                else:
                    print("Custom weights loaded successfully.")
                    print(f"  {tensors_changed} of {len(weights_after)} weight tensors were updated.")
            except Exception as e:
                print(f"ERROR loading custom weights from {custom_path}: {e}")
                print("Proceeding with randomly initialized base model weights.")
                train_base_initially = True
        else:
            print(f"WARNING: Custom weight file not found at {custom_path}. Proceeding with randomly initialized base model weights.")
            train_base_initially = True

    # Layer Freezing/Unfreezing for Fine-Tuning
    fine_tuning_active = False
    if unfreeze_at_block and not train_base_initially:
        print(f"Fine-tuning enabled: Unfreezing layers from layer name starting with '{unfreeze_at_block}' onwards.")
        unfreeze_layer_index = next(
            (i for i, layer in enumerate(base_model.layers)
             if layer.name.startswith(unfreeze_at_block)),
            -1
        )

        if unfreeze_layer_index == -1:
            print(f"WARNING: Layer name prefix '{unfreeze_at_block}' not found. Keeping base frozen.")
            base_model.trainable = False
        else:
            base_model.trainable = True
            fine_tuning_active = True
            print(f"Freezing layers up to index {unfreeze_layer_index}.")
            for layer in base_model.layers[:unfreeze_layer_index]:
                layer.trainable = False
            print(f"Layers from index {unfreeze_layer_index} onwards are trainable.")
    else:
        print("Keeping the base model frozen (or training from scratch if randomly initialized).")
        base_model.trainable = train_base_initially

    # Build the full model
    inputs = tf.keras.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3), name='input_image')

    # A frozen base runs in inference mode. A trainable base must NOT get a
    # hard-coded training=True: that would also apply during evaluate/predict
    # and make trainable BatchNorm layers use batch statistics at inference.
    # None lets Keras pick the mode per phase.
    base_training_flag = None if base_model.trainable else False
    x = base_model(inputs, training=base_training_flag)

    # Custom top layers: conv -> optional 2x2 pooling, three times
    block_specs = (
        (1, hp_filters1, hp_kernel1, hp_pool1),
        (2, hp_filters2, hp_kernel2, hp_pool2),
        (3, hp_filters3, hp_kernel3, hp_pool3),
    )
    for block_idx, filters, kernel, pool_kind in block_specs:
        # Ensure kernel sizes are tuples
        kernel_size = tuple(kernel) if isinstance(kernel, list) else kernel
        pool_layer_cls = (tf.keras.layers.MaxPooling2D if pool_kind == 'max'
                          else tf.keras.layers.AveragePooling2D)
        x = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, padding='same',
                                   activation='relu', name=f'custom_block{block_idx}_conv')(x)
        # Pool only while the feature map is larger than 1 pixel in both directions.
        can_pool = (x.shape[1] is not None and x.shape[1] > 1
                    and x.shape[2] is not None and x.shape[2] > 1)
        if can_pool:
            x = pool_layer_cls(pool_size=(2, 2), strides=2, name=f'custom_block{block_idx}_pool')(x)
        else:
            print(f"Skipping Pooling Layer {block_idx} (Input shape: {x.shape})")

    # Classification Head
    x = tf.keras.layers.GlobalAveragePooling2D(name='custom_gap')(x)
    x = layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(256, activation='relu', name='custom_dense')(x)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid', name='custom_output')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    # Compile the model. The smaller fine-tuning rate only applies when
    # pre-trained base layers were actually unfrozen.
    lr_to_use = fine_tune_lr if fine_tuning_active else initial_learning_rate
    print(f"Compiling model with learning rate: {lr_to_use}")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_to_use),
        loss='binary_crossentropy',
        metrics=['accuracy',
                 tf.keras.metrics.Precision(name='precision'),
                 tf.keras.metrics.Recall(name='recall')]
    )
    return model

In [None]:
# Drop graph state left over from earlier model builds in this kernel.
tf.keras.backend.clear_session()

# Build the model using loaded hyperparameters and specified settings
model = build_model(
    hp_filters1, hp_kernel1, hp_pool1,
    hp_filters2, hp_kernel2, hp_pool2,
    hp_filters3, hp_kernel3, hp_pool3,
    initial_learning_rate=INITIAL_LEARNING_RATE,
    pretrained_weights_name=PRETRAINED_WEIGHTS_NAME,
    radimagenet_path=RADIMAGENET_WEIGHT_PATH,
    unfreeze_at_block=UNFREEZE_AT_BLOCK,
    fine_tune_lr=FINE_TUNE_LR
)

# Print model summary
model.summary()

# Print trainable status of layers for verification
print("\nTrainable status of base model layers (showing first 10 and last 10):")
# Find the nested DenseNet121 by type rather than by position: the base is
# not at index 2 (Input is 0, the base follows it), so a fixed index fails.
base_submodel = next(
    (layer for layer in model.layers if isinstance(layer, tf.keras.Model)),
    None
)
if base_submodel is not None:
    base_layers = base_submodel.layers
    for layer in base_layers[:10]:
        print(f"{layer.name}: {layer.trainable}")
    print("...")
    for layer in base_layers[-10:]:
        print(f"{layer.name}: {layer.trainable}")
else:
    print("Could not display base layer trainable status (model structure might differ).")

## 6. Define Training Callbacks

In [None]:
# Training callbacks. All three watch the validation loss and log verbosely.
_shared_callback_kwargs = dict(monitor='val_loss', verbose=1)

# Stop once val_loss stops improving and roll back to the best epoch's weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=EARLY_STOPPING_PATIENCE,
    restore_best_weights=True,
    **_shared_callback_kwargs
)

# Persist only the weights of the best epoch seen so far.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=MODEL_SAVE_PATH,
    save_best_only=True,
    save_weights_only=True,
    **_shared_callback_kwargs
)

# Multiply the learning rate by 0.2 on a plateau, never going below 1e-6.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.2,
    patience=REDUCE_LR_PATIENCE,
    min_lr=1e-6,
    **_shared_callback_kwargs
)

callbacks_list = [early_stopping, model_checkpoint, reduce_lr]

## 7. Train the Model

In [None]:
# A generator is usable only if it was created and yields at least one batch.
train_ready = bool(train_generator) and len(train_generator) > 0
val_ready = bool(validation_generator) and len(validation_generator) > 0

if not (train_ready and val_ready):
    print("ERROR: Training or Validation generator is invalid. Cannot train model.")
    if not train_ready:
        print("Reason: Training generator issue.")
    if not val_ready:
        print("Reason: Validation generator issue.")
else:
    print("\n--- Starting Model Training ---")

    history = model.fit(
        train_generator,
        epochs=MAX_EPOCHS,
        validation_data=validation_generator,
        callbacks=callbacks_list,
        verbose=1
    )
    print("\n--- Model Training Finished ---")

    import matplotlib.pyplot as plt

    def plot_history(history):
        """Draw accuracy (left) and loss (right) curves for train and validation."""
        plt.figure(figsize=(12, 5))
        panels = (('accuracy', 'Accuracy'), ('loss', 'Loss'))
        for position, (metric_key, metric_label) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.plot(history.history[metric_key])
            plt.plot(history.history['val_' + metric_key])
            plt.title('Model ' + metric_label)
            plt.ylabel(metric_label)
            plt.xlabel('Epoch')
            plt.legend(['Train', 'Validation'], loc='upper left')

        plt.tight_layout()
        plt.show()

    plot_history(history)

## 8. Final Evaluation on Test Set

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import math

# Evaluate the checkpointed best weights: rebuild the architecture, load the
# weights, then report Keras metrics, sklearn metrics and a confusion matrix.
model_to_load_weights_from = MODEL_SAVE_PATH

# Check if the weight file exists
if os.path.exists(model_to_load_weights_from):
    print(f"--- Found saved weights at: {model_to_load_weights_from} ---")

    print("Re-building model architecture...")
    model_for_evaluation = build_model(
        hp_filters1, hp_kernel1, hp_pool1,
        hp_filters2, hp_kernel2, hp_pool2,
        hp_filters3, hp_kernel3, hp_pool3,
        initial_learning_rate=INITIAL_LEARNING_RATE,
        pretrained_weights_name=PRETRAINED_WEIGHTS_NAME,
        radimagenet_path=RADIMAGENET_WEIGHT_PATH,
        unfreeze_at_block=UNFREEZE_AT_BLOCK,
        fine_tune_lr=FINE_TUNE_LR
    )
    print("Model architecture re-built.")

    print(f"\n--- Loading Weights into model from {model_to_load_weights_from} ---")
    try:
        model_for_evaluation.load_weights(model_to_load_weights_from)
        print("Weights loaded successfully into the model.")

        # Check if test_generator exists and is valid
        if 'test_generator' not in locals() or test_generator is None or len(test_generator) == 0:
            print("\nERROR: 'test_generator' object not found or is empty.")
            print("Please ensure data loading/preparation cells have been run and test data exists.")
            raise NameError("test_generator not defined or empty")

        print("\n--- Evaluating on Test Set (using model.evaluate) ---")
        # Order follows model.compile(): loss, accuracy, precision, recall.
        test_loss, test_acc, test_prec, test_rec = model_for_evaluation.evaluate(
            test_generator,
            verbose=1
        )
        test_f1 = 2 * (test_prec * test_rec) / (test_prec + test_rec) if (test_prec + test_rec) > 0 else 0

        print(f"\nTest Metrics (from evaluate):")
        print(f"  Loss:      {test_loss:.4f}")
        print(f"  Accuracy:  {test_acc:.4f}")
        print(f"  Precision: {test_prec:.4f}")
        print(f"  Recall:    {test_rec:.4f}")
        print(f"  F1-Score:  {test_f1:.4f}")

        # Sklearn Metrics evaluation
        print("\nCalculating predictions and true labels batch-by-batch for sklearn metrics...")
        y_true_list = []
        y_pred_proba_list = []
        num_batches = len(test_generator)
        failed_batches = 0
        # Restart from the first batch; evaluate() above already consumed the generator.
        test_generator.reset()
        for i in range(num_batches):
            try:
                x_batch, y_batch = next(test_generator)
                if x_batch.shape[0] == 0: continue
                y_batch_pred_proba = model_for_evaluation.predict_on_batch(x_batch)
                # Append both only after prediction succeeded so they stay aligned.
                y_true_list.append(y_batch)
                y_pred_proba_list.append(y_batch_pred_proba)
            except StopIteration: break
            except Exception as batch_err:
                print(f"Error processing batch {i+1}: {batch_err}")
                failed_batches += 1
                continue

        if failed_batches:
            # Make it explicit that the sklearn numbers cover only part of the test set.
            print(f"WARNING: {failed_batches} of {num_batches} test batches failed; "
                  "sklearn metrics below cover only part of the test set.")

        if not y_true_list:
            print("ERROR: No data collected from test generator for sklearn metrics.")
        else:
            y_true = np.concatenate(y_true_list)
            y_pred_proba = np.concatenate(y_pred_proba_list)
            min_len = min(len(y_true), len(y_pred_proba))
            y_true = y_true[:min_len].astype(int)
            y_pred_proba = y_pred_proba[:min_len]
            # Threshold the sigmoid output at 0.5 and flatten (N, 1) -> (N,).
            y_pred_classes = (y_pred_proba > 0.5).astype(int).flatten()

            sk_acc = accuracy_score(y_true, y_pred_classes)
            sk_prec = precision_score(y_true, y_pred_classes, zero_division=0)
            sk_rec = recall_score(y_true, y_pred_classes, zero_division=0)
            sk_f1 = f1_score(y_true, y_pred_classes, zero_division=0)
            print("\nSklearn Metrics:")
            print(f"  Accuracy:  {sk_acc:.4f}")
            print(f"  Precision: {sk_prec:.4f}")
            print(f"  Recall:    {sk_rec:.4f}")
            print(f"  F1-Score:  {sk_f1:.4f}")

            print("\nCalculating Confusion Matrix...")
            # Fix the label set so the matrix is always 2x2 and matches the
            # two tick labels, even if one class is absent from the data.
            cm = confusion_matrix(y_true, y_pred_classes, labels=[0, 1])
            plt.figure(figsize=(6, 5))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=['Normal (0)', 'Abnormal (1)'],
                        yticklabels=['Normal (0)', 'Abnormal (1)'])
            plt.title('Confusion Matrix on Test Set')
            plt.xlabel('Predicted Label')
            plt.ylabel('True Label')
            plt.show()

    except NameError as ne:
        print(f"\nError during evaluation setup (NameError): {ne}")
    except Exception as e:
        print(f"Error re-building model, loading weights, or evaluating: {e}")
        print("Ensure all necessary variables (hyperparameters, paths) are defined and weights file is compatible.")

else:
    print(f"--- ERROR: Saved weights not found at {model_to_load_weights_from} ---")
    print("--- Please train the model first by running the preceding cells, or check MODEL_SAVE_PATH. ---")