# Bayesian Optimization for ResNet50 Architecture

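This notebook uses Optuna to search for the best hyperparameters (filter count, kernel size, and pooling type) of three custom convolutional blocks appended to a pre-trained base model. The base model is configurable (`ResNet50` by default; `InceptionV3` and `DenseNet121` are also supported), as is the source of its pre-trained weights. The objective is to maximize the F1-score on the validation set for tuberculosis chest X-ray classification.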

## 1. Setup and Imports

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50, InceptionV3, DenseNet121
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
from tensorflow.keras.applications.inception_v3 import preprocess_input as inception_preprocess_input
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
import optuna
from optuna.integration import KerasPruningCallback
from glob import glob
import json
import shutil
from sklearn.utils import class_weight

# Log the library versions so results can be tied to a specific environment.
for _lib_label, _lib_module in (("TensorFlow", tf), ("Optuna", optuna)):
    print(f"{_lib_label} Version:", _lib_module.__version__)

In [None]:
# --- Backbone -------------------------------------------------------------
# Supported values: 'ResNet50', 'InceptionV3', 'DenseNet121'
MODEL_NAME = 'ResNet50'

# --- Pretrained weights ---------------------------------------------------
PRETRAINED_WEIGHTS_TYPE = 'radimagenet'

# RadImageNet "no-top" weight files, one per supported backbone.
RADIMAGENET_WEIGHTS_PATH_RESNET50 = '../../weights/RadImageNet-ResNet50_notop.h5'
RADIMAGENET_WEIGHTS_PATH_INCEPTIONV3 = '../../weights/RadImageNet-InceptionV3_notop.h5'
RADIMAGENET_WEIGHTS_PATH_DENSENET121 = '../../weights/RadImageNet-DenseNet121_notop.h5'

# --- Fine-tuning ----------------------------------------------------------
# Backbone layer name from which layers are unfrozen.
# ResNet50: 'conv4_block1_out', DenseNet121: 'conv4_block1_concat', InceptionV3: 'mixed9'
UNFREEZE_AT_BLOCK = 'conv4_block1_out'

# --- Optuna ---------------------------------------------------------------
# The study name, results file and storage URL all share the same run tag.
_run_tag = f"{MODEL_NAME}_{PRETRAINED_WEIGHTS_TYPE or 'scratch'}"
OPTUNA_STUDY_NAME = f"tb_bo_{_run_tag}"
BEST_PARAMS_FILE = f"../../results/best_params_tb_{_run_tag}.json"
OPTUNA_DB_PATH = f"sqlite:///../../results/optuna_study_tb_{_run_tag}.db"
OPTUNA_TRIALS = 1

# --- Training -------------------------------------------------------------
# InceptionV3 is fed 299x299 images; the other backbones use 224x224.
IMG_SIZE = (299, 299) if MODEL_NAME == 'InceptionV3' else (224, 224)
BATCH_SIZE = 32
LEARNING_RATE_BO = 1e-5
MAX_EPOCHS_BO = 1
EARLY_STOPPING_PATIENCE_BO = 5

# --- Data -----------------------------------------------------------------
TB_DATASET_DIR_NAME = 'tuberculosis-tb-chest-xray-dataset'
TB_SUBDIR = 'TB_Chest_Radiography_Database'
TB_BASE_PATH = os.path.join('../../data/', TB_DATASET_DIR_NAME, TB_SUBDIR)
TB_NORMAL_DIR = os.path.join(TB_BASE_PATH, 'Normal')
TB_TUBERCULOSIS_DIR = os.path.join(TB_BASE_PATH, 'Tuberculosis')

# Both ratios are fractions of the full dataset.
TEST_SPLIT_RATIO = 0.15
VALID_SPLIT_RATIO = 0.15

# --- Echo the effective configuration -------------------------------------
print("--- Configuration ---")
print(f"Model Name: {MODEL_NAME}")
print(f"Pretrained Weights Type: {PRETRAINED_WEIGHTS_TYPE}")
if PRETRAINED_WEIGHTS_TYPE == 'radimagenet':
    _radimagenet_paths = {
        'ResNet50': RADIMAGENET_WEIGHTS_PATH_RESNET50,
        'InceptionV3': RADIMAGENET_WEIGHTS_PATH_INCEPTIONV3,
        'DenseNet121': RADIMAGENET_WEIGHTS_PATH_DENSENET121,
    }
    if MODEL_NAME in _radimagenet_paths:
        print(f"RadImageNet {MODEL_NAME} Path: {_radimagenet_paths[MODEL_NAME]}")

# Describe the freezing strategy that results from the two settings above.
if UNFREEZE_AT_BLOCK:
    _unfreeze_label = UNFREEZE_AT_BLOCK
elif PRETRAINED_WEIGHTS_TYPE:
    _unfreeze_label = 'Base Frozen'
else:
    _unfreeze_label = 'Fully Trainable (Scratch)'
print(f"Unfreeze at block: {_unfreeze_label}")

print(f"Image Size: {IMG_SIZE}")
print(f"Optuna Study Name: {OPTUNA_STUDY_NAME}")
print(f"Optuna DB Path: {OPTUNA_DB_PATH}")
print(f"Best Params File: {BEST_PARAMS_FILE}")
print(f"TB Dataset Path: {TB_BASE_PATH}")
print("--------------------")

## 2. Data Loading and Preparation

In [None]:
def load_tb_data():
    """Collect the TB dataset's image paths and labels into a DataFrame.

    Scans ``TB_NORMAL_DIR`` and ``TB_TUBERCULOSIS_DIR`` for ``*.png`` files and
    labels them 0 (Normal) and 1 (Tuberculosis) respectively.

    Returns:
        A ``pandas.DataFrame`` with the columns ``filepath`` and ``label``.
        A column-less empty DataFrame is returned when no image is found or
        when an exception is raised while loading.
    """
    tb_data_list = []
    print(f"Loading Tuberculosis data from: {TB_BASE_PATH}")
    try:
        # glob() returns entries in arbitrary, filesystem-dependent order.
        # Sorting keeps the row order stable so that the seeded
        # train/validation/test split is reproducible across machines and runs.
        normal_files = sorted(glob(os.path.join(TB_NORMAL_DIR, '*.png')))
        tb_data_list.extend({'filepath': f, 'label': 0} for f in normal_files)  # 0 for Normal
        print(f"  Found {len(normal_files)} Normal images.")

        tuberculosis_files = sorted(glob(os.path.join(TB_TUBERCULOSIS_DIR, '*.png')))
        tb_data_list.extend({'filepath': f, 'label': 1} for f in tuberculosis_files)  # 1 for Tuberculosis
        print(f"  Found {len(tuberculosis_files)} Tuberculosis images.")

        if not tb_data_list:
            print("WARNING: No TB images found. Check dataset paths.")
            return pd.DataFrame()

        df_tb = pd.DataFrame(tb_data_list)
        print(f"  Successfully processed {len(df_tb)} TB images in total.")

        # Drop entries whose path does not resolve to an existing file.
        original_count = len(df_tb)
        df_tb = df_tb[df_tb['filepath'].apply(os.path.exists)]
        files_removed = original_count - len(df_tb)
        if files_removed > 0:
            print(f"WARNING: Removed {files_removed} non-existent file entries from TB data.")

        if df_tb.empty:
            print("ERROR: No valid TB image files found after checking paths.")
        return df_tb

    except Exception as e:
        # Best-effort loader: report the problem and let the caller handle the
        # empty result instead of aborting the notebook.
        print(f"  ERROR loading TB data: {e}. Check paths and file structure.")
        return pd.DataFrame()

# Load the image index, then derive stratified train / validation / test subsets.
df_tb_all = load_tb_data()

if df_tb_all.empty:
    print("ERROR: TB DataFrame is empty. Cannot proceed with data splitting.")
    train_df, val_df, test_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
else:
    print(f"\nTotal TB images loaded: {len(df_tb_all)}")
    print("TB Label distribution:\n", df_tb_all['label'].value_counts())

    # A stratified split needs more than one sample and more than one class.
    if len(df_tb_all) <= 1 or df_tb_all['label'].nunique() <= 1:
        print("ERROR: Not enough data or classes to perform stratified split for TB dataset.")
        train_df, val_df, test_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    else:
        # Step 1: carve off the held-out test set.
        train_val_df, test_df = train_test_split(
            df_tb_all,
            test_size=TEST_SPLIT_RATIO,
            stratify=df_tb_all['label'],
            random_state=42,
        )

        # Step 2: VALID_SPLIT_RATIO refers to the full dataset, so rescale it
        # to the train+validation portion that is left after step 1.
        if (1 - TEST_SPLIT_RATIO) > 0:
            adjusted_valid_split_ratio = VALID_SPLIT_RATIO / (1 - TEST_SPLIT_RATIO)
        else:
            adjusted_valid_split_ratio = 0

        _can_split_validation = (
            len(train_val_df) > 1
            and train_val_df['label'].nunique() > 1
            and adjusted_valid_split_ratio > 0
        )
        if not _can_split_validation:
            # Fall back to using everything for training with an empty validation frame.
            train_df = train_val_df
            val_df = pd.DataFrame(columns=train_val_df.columns)
            if adjusted_valid_split_ratio > 0:
                print("Warning: Could not perform validation split properly due to insufficient data or classes after initial split. Validation set might be small or empty.")
        else:
            train_df, val_df = train_test_split(
                train_val_df,
                test_size=adjusted_valid_split_ratio,
                stratify=train_val_df['label'],
                random_state=42,
            )

        print("\nData Split:")
        print(f"  Training samples: {len(train_df)}")
        print(f"  Validation samples: {len(val_df)}")
        print(f"  Test samples: {len(test_df)}")

        # Show the normalized class balance of every non-empty subset.
        for _split_label, _split_df in (("Train", train_df), ("Validation", val_df), ("Test", test_df)):
            if not _split_df.empty:
                print(f"{_split_label} label distribution:\n", _split_df['label'].value_counts(normalize=True))

        if train_df.empty or val_df.empty:
            print("ERROR: Training or Validation DataFrame is empty after splitting. BO cannot proceed.")

# Derive 'balanced' class weights from the training labels.
# class_weights_dict stays None whenever weights cannot be computed.
class_weights_dict = None
if train_df.empty:
    print("\nWarning: train_df is empty. Skipping class weight calculation.")
else:
    try:
        y_train_labels = train_df['label'].astype(int).values
        unique_classes = np.unique(y_train_labels)

        if len(unique_classes) <= 1:
            print("\nWarning: Only one class found in training data. Cannot calculate class weights.")
        else:
            weights = class_weight.compute_class_weight(
                class_weight='balanced',
                classes=unique_classes,
                y=y_train_labels,
            )
            # Map each class label to its weight.
            class_weights_dict = dict(zip(unique_classes, weights))
            print("\n--- Class Weights Calculated ---")
            print(f"Class weights to be used: {class_weights_dict}")
            print(f"Applied to classes: {unique_classes}")
            print("------------------------------")

    except Exception as e:
        print(f"\nError calculating class weights: {e}")
        # Discard any partially computed result.
        class_weights_dict = None

## 3. Data Augmentation and Generators

In [None]:
# The generators stay None unless both the training and validation frames hold data.
train_generator = None
validation_generator = None
test_generator = None

# Pick the preprocessing function that matches the configured backbone.
preprocess_input_func = {
    'ResNet50': resnet_preprocess_input,
    'InceptionV3': inception_preprocess_input,
    'DenseNet121': densenet_preprocess_input,
}.get(MODEL_NAME)
if preprocess_input_func is None:
    print(f"WARNING: Preprocessing function not explicitly set for {MODEL_NAME}. Using generic rescaling.")
    preprocess_input_func = lambda x: x / 255.0

if train_df.empty or val_df.empty:
    print("Skipping generator creation as train_df or val_df is empty.")
else:
    # Arguments shared by every flow_from_dataframe call below.
    _flow_kwargs = dict(
        x_col='filepath',
        y_col='label',
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='raw',
    )

    # Training: model-specific preprocessing plus augmentation, shuffled.
    train_datagen = ImageDataGenerator(
        preprocessing_function=preprocess_input_func,
        rotation_range=15,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    train_generator = train_datagen.flow_from_dataframe(
        dataframe=train_df, shuffle=True, **_flow_kwargs
    )
    if not hasattr(train_generator, 'samples'):
        print("Train generator creation failed or train_df is empty.")
    else:
        num_train_classes = train_df['label'].nunique()
        print(f"Train generator created. Found {train_generator.samples} images belonging to {num_train_classes} classes.")

    # Validation: preprocessing only, fixed order.
    val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input_func)
    validation_generator = val_datagen.flow_from_dataframe(
        dataframe=val_df, shuffle=False, **_flow_kwargs
    )
    if not hasattr(validation_generator, 'samples'):
        print("Validation generator creation failed or val_df is empty.")
    else:
        num_val_classes = val_df['label'].nunique()
        print(f"Validation generator created. Found {validation_generator.samples} images belonging to {num_val_classes} classes.")

    # Test: only built when a test frame exists; preprocessing only, fixed order.
    if not test_df.empty:
        test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input_func)
        test_generator = test_datagen.flow_from_dataframe(
            dataframe=test_df, shuffle=False, **_flow_kwargs
        )
        if not hasattr(test_generator, 'samples'):
            print("Test generator creation failed or test_df is empty.")
        else:
            num_test_classes = test_df['label'].nunique()
            print(f"Test generator created. Found {test_generator.samples} images belonging to {num_test_classes} classes.")

## 4. Model Architecture Definition

In [None]:
def build_model_for_bo(hp_filters1, hp_kernel1_str, hp_pool1,
                       hp_filters2, hp_kernel2_str, hp_pool2,
                       hp_filters3, hp_kernel3_str, hp_pool3,
                       model_name_local, pretrained_weights_type_local, unfreeze_at_local):
    """Build and compile a backbone followed by three tunable conv/pool blocks.

    The model is: input -> backbone (no top) -> three custom Conv2D(+pool)
    blocks -> global average pooling -> dropout -> Dense(256) -> Dense(1, sigmoid).
    It is compiled with Adam (``LEARNING_RATE_BO``), binary cross-entropy and
    the metrics accuracy, precision and recall.

    Args:
        hp_filters1, hp_filters2, hp_filters3: Number of filters of each custom block.
        hp_kernel1_str, hp_kernel2_str, hp_kernel3_str: Kernel sizes written
            as ``'<h>x<w>'`` (for example ``'3x3'``).
        hp_pool1, hp_pool2, hp_pool3: ``'max'`` selects max pooling; any other
            value selects average pooling.
        model_name_local: ``'ResNet50'``, ``'InceptionV3'`` or ``'DenseNet121'``.
        pretrained_weights_type_local: ``'imagenet'``, ``'radimagenet'`` or
            ``None`` (train the backbone from scratch).
        unfreeze_at_local: ``'full_unfreeze'`` (case-insensitive) to unfreeze
            the whole backbone, a substring of a layer name from which layers
            become trainable, or a falsy value to keep the backbone frozen.
            Ignored when ``pretrained_weights_type_local`` is ``None``.

    Returns:
        The compiled ``keras.Model``.

    Raises:
        ValueError: If the weights type or the model name is not supported.
    """
    def parse_kernel_size(k_str):
        # '3x3' -> (3, 3)
        return tuple(map(int, k_str.split('x')))

    kernel1_tuple = parse_kernel_size(hp_kernel1_str)
    kernel2_tuple = parse_kernel_size(hp_kernel2_str)
    kernel3_tuple = parse_kernel_size(hp_kernel3_str)

    inputs = keras.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3), name="input_image")
    keras_weights_arg = None
    custom_weights_path = None

    # Decide how the backbone gets its initial weights.
    if pretrained_weights_type_local == 'imagenet':
        keras_weights_arg = 'imagenet'
    elif pretrained_weights_type_local == 'radimagenet':
        # RadImageNet weights are loaded from a local file after construction.
        custom_weights_path = {
            'ResNet50': RADIMAGENET_WEIGHTS_PATH_RESNET50,
            'InceptionV3': RADIMAGENET_WEIGHTS_PATH_INCEPTIONV3,
            'DenseNet121': RADIMAGENET_WEIGHTS_PATH_DENSENET121,
        }.get(model_name_local)
        if custom_weights_path is None:
            print(f"Warning: RadImageNet path not specified for {model_name_local}")
    elif pretrained_weights_type_local is not None:
        raise ValueError(f"Unsupported pretrained_weights_type: {pretrained_weights_type_local}")

    backbone_constructors = {
        'ResNet50': ResNet50,
        'InceptionV3': InceptionV3,
        'DenseNet121': DenseNet121,
    }
    if model_name_local not in backbone_constructors:
        raise ValueError(f"Unsupported model_name: {model_name_local}")
    base_model_func = backbone_constructors[model_name_local]

    print(f"Loading base model: {model_name_local} with Keras weights arg: {keras_weights_arg}")
    base_model_instance = base_model_func(
        include_top=False, weights=keras_weights_arg,
        input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)
    )

    if pretrained_weights_type_local == 'radimagenet' and custom_weights_path:
        if os.path.exists(custom_weights_path):
            print(f"Loading custom RadImageNet weights for {model_name_local} from: {custom_weights_path}")
            try:
                # NOTE(review): by_name + skip_mismatch silently skips layers that do
                # not match, so the success message below does not prove that every
                # layer received weights.
                base_model_instance.load_weights(custom_weights_path, by_name=True, skip_mismatch=True)
                print("Custom RadImageNet weights loaded successfully.")
            except Exception as e:
                print(f"ERROR loading custom RadImageNet weights: {e}. Model will use initial weights (random if Keras weights_arg was None).")
        else:
            print(f"WARNING: RadImageNet weight file not found at {custom_weights_path}. Model will use initial weights.")
    elif pretrained_weights_type_local == 'radimagenet' and not custom_weights_path:
        print(f"WARNING: PRETRAINED_WEIGHTS_TYPE is 'radimagenet' but no path specified for {model_name_local}. Model will use initial weights (random if Keras weights_arg was None).")

    # Fine-tuning / Freezing
    if pretrained_weights_type_local is None:
        print(f"Base model {model_name_local} will be trained from scratch.")
        base_model_instance.trainable = True
    elif unfreeze_at_local and unfreeze_at_local.lower() == 'full_unfreeze':
        print(f"Fine-tuning enabled: Unfreezing all layers of the base model {model_name_local}.")
        base_model_instance.trainable = True
    elif unfreeze_at_local:
        print(f"Fine-tuning enabled: Attempting to unfreeze layers from '{unfreeze_at_local}' onwards for {model_name_local}.")
        base_model_instance.trainable = True
        # Layers stay frozen until the first layer whose name contains the
        # unfreeze marker; that layer and all later ones become trainable.
        set_trainable = False
        for layer in base_model_instance.layers:
            if unfreeze_at_local in layer.name:
                set_trainable = True
            layer.trainable = set_trainable
        if not set_trainable:
            # The marker never matched, so every backbone layer was left frozen.
            print(f"WARNING: Unfreeze point '{unfreeze_at_local}' not found or did not result in trainable layers as expected. Check layer names. Base model might remain frozen or fully trainable depending on initial state.")
    else:  # Default: freeze the base model if pretrained weights are used and no unfreezing specified
        print(f"Keeping the base model {model_name_local} frozen.")
        base_model_instance.trainable = False

    if base_model_instance.trainable:
        # Do not pass training=True here: an explicit value is baked into the
        # graph and would also apply during evaluate()/predict(), so trainable
        # BatchNormalization layers would use batch statistics at inference.
        # Leaving the argument out lets Keras select the phase automatically.
        # NOTE(review): the Keras fine-tuning guide suggests training=False here
        # if the BatchNormalization statistics should stay fixed while fine-tuning.
        x = base_model_instance(inputs)
    else:
        # Frozen backbone: always run it in inference mode.
        x = base_model_instance(inputs, training=False)

    def add_conv_pool_block(x_input, filters, kernel_size, pool_type, block_name):
        """Append Conv2D (same padding, ReLU) and, when possible, a 2x2 pooling layer."""
        pool_func = layers.MaxPooling2D if pool_type == 'max' else layers.AveragePooling2D
        x_out = layers.Conv2D(filters=filters, kernel_size=kernel_size, padding='same', activation='relu', name=f'custom_{block_name}_conv')(x_input)
        # Pool only while both spatial dimensions are larger than 1.
        if x_out.shape[1] > 1 and x_out.shape[2] > 1:
            x_out = pool_func(pool_size=(2, 2), strides=2, name=f'custom_{block_name}_pool')(x_out)
        else:
            print(f"Skipping Pooling for {block_name} (Input shape to pool: {x_out.shape})")
        return x_out

    x = add_conv_pool_block(x, hp_filters1, kernel1_tuple, hp_pool1, "block1")
    x = add_conv_pool_block(x, hp_filters2, kernel2_tuple, hp_pool2, "block2")
    x = add_conv_pool_block(x, hp_filters3, kernel3_tuple, hp_pool3, "block3")

    # Classification head.
    x = layers.GlobalAveragePooling2D(name='custom_gap')(x)
    x = layers.Dropout(0.5, name='custom_dropout')(x)
    x = layers.Dense(256, activation='relu', name='custom_dense_256')(x)
    outputs = layers.Dense(1, activation='sigmoid', name='output_sigmoid')(x)
    model = Model(inputs=inputs, outputs=outputs)

    print(f"Compiling model with learning rate: {LEARNING_RATE_BO}")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE_BO),
        loss='binary_crossentropy',
        metrics=['accuracy', keras.metrics.Precision(name='precision'), keras.metrics.Recall(name='recall')]
    )
    return model

In [None]:
# Sanity check: build one throwaway model and print its layer summary.

print(f"--- Building a temporary model for {MODEL_NAME} to display summary ---")
print(f"Using PRETRAINED_WEIGHTS_TYPE: '{PRETRAINED_WEIGHTS_TYPE}', UNFREEZE_AT_BLOCK: '{UNFREEZE_AT_BLOCK}'")

# Placeholder hyperparameters for the three custom blocks:
# 32/64/128 filters, 3x3 kernels, max pooling.
dummy_hp = {}
for _block_idx, _n_filters in enumerate((32, 64, 128), start=1):
    dummy_hp.update({
        f'hp_filters{_block_idx}': _n_filters,
        f'hp_kernel{_block_idx}_str': '3x3',
        f'hp_pool{_block_idx}': 'max',
    })

# Build the temporary model with the notebook-level configuration.
temp_model = build_model_for_bo(
    model_name_local=MODEL_NAME,
    pretrained_weights_type_local=PRETRAINED_WEIGHTS_TYPE,
    unfreeze_at_local=UNFREEZE_AT_BLOCK,
    **dummy_hp,
)

print("\nMODEL SUMMARY")
temp_model.summary()

# Drop the temporary model and reset Keras' global state.
del temp_model
tf.keras.backend.clear_session()

print("\n--- End of Summary Cell ---")

## 5. Bayesian Optimization Objective Function

In [None]:
def objective(trial):
    """Optuna objective: train one candidate head and return its validation F1.

    Samples the filter count, kernel size and pooling type of the three custom
    blocks, builds the model with ``build_model_for_bo``, trains it on
    ``train_generator`` and scores it on ``validation_generator``.

    Args:
        trial: The Optuna trial used to sample hyperparameters and to report
            intermediate values.

    Returns:
        The validation F1-score computed from the evaluated precision and
        recall. ``0.0`` is returned when model building, training or
        evaluation raises an unexpected exception.

    Raises:
        optuna.TrialPruned: If the data generators are missing or empty, if the
            pruner stops the trial, or if training raises
            ``tf.errors.ResourceExhaustedError``.
    """
    if not train_generator or not validation_generator or len(train_generator) == 0 or len(validation_generator) == 0:
        print(f"TRIAL {trial.number} SKIPPED: Invalid data generator(s).")
        raise optuna.TrialPruned("Data generators not available or empty.")

    # Hyperparameter suggestions
    filters1 = trial.suggest_int('filters1', 32, 128, step=16)
    kernel1_str = trial.suggest_categorical('kernel1', ['3x3', '5x5', '7x7'])
    pool1 = trial.suggest_categorical('pool1', ['max', 'average'])
    filters2 = trial.suggest_int('filters2', 64, 256, step=32)
    kernel2_str = trial.suggest_categorical('kernel2', ['3x3', '5x5', '7x7'])
    pool2 = trial.suggest_categorical('pool2', ['max', 'average'])
    filters3 = trial.suggest_int('filters3', 128, 512, step=64)
    kernel3_str = trial.suggest_categorical('kernel3', ['3x3', '5x5', '7x7'])
    pool3 = trial.suggest_categorical('pool3', ['max', 'average'])

    # Start every trial from a clean Keras state.
    tf.keras.backend.clear_session()

    # Build the model
    try:
        model = build_model_for_bo(
            filters1, kernel1_str, pool1,
            filters2, kernel2_str, pool2,
            filters3, kernel3_str, pool3,
            MODEL_NAME, PRETRAINED_WEIGHTS_TYPE, UNFREEZE_AT_BLOCK
        )
    except Exception as e:
        print(f"ERROR building model in Trial {trial.number} for {MODEL_NAME} ({PRETRAINED_WEIGHTS_TYPE}): {e}")
        return 0.0

    class _ValF1PruningCallback(keras.callbacks.Callback):
        """Report the per-epoch validation F1 to Optuna and prune on request.

        The study maximizes F1, so the intermediate value handed to the pruner
        must also be "higher is better". Reporting ``val_loss`` would make the
        pruner favor trials with a higher loss.
        """

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            # Keys follow the metric names given at compile time
            # ('precision' / 'recall') with Keras' 'val_' prefix.
            epoch_prec = logs.get('val_precision')
            epoch_rec = logs.get('val_recall')
            if epoch_prec is None or epoch_rec is None:
                return
            denom = epoch_prec + epoch_rec
            epoch_f1 = 2 * epoch_prec * epoch_rec / denom if denom > 0 else 0.0
            trial.report(float(epoch_f1), step=epoch)
            if trial.should_prune():
                raise optuna.TrialPruned(f"Trial was pruned at epoch {epoch}.")

    # Define callbacks
    early_stopping = keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=EARLY_STOPPING_PATIENCE_BO,
        restore_best_weights=True,
        verbose=1
    )
    pruning_callback = _ValF1PruningCallback()

    print(f"\n--- Optuna Trial {trial.number} for {MODEL_NAME} ({PRETRAINED_WEIGHTS_TYPE}) ---")
    print(f"Params: F1={filters1}, K1={kernel1_str}, P1={pool1} | F2={filters2}, K2={kernel2_str}, P2={pool2} | F3={filters3}, K3={kernel3_str}, P3={pool3}")

    current_class_weight = class_weights_dict
    if current_class_weight is not None:
        print(f"Using class weights: {current_class_weight}")
    else:
        print("Warning: class_weights_dict is None. Training without class weights.")

    # Train the model
    try:
        model.fit(
            train_generator,
            epochs=MAX_EPOCHS_BO,
            validation_data=validation_generator,
            callbacks=[early_stopping, pruning_callback],
            class_weight=current_class_weight,  # Apply class weighting
            verbose=1
        )
    except optuna.TrialPruned as e:
        print(f"Trial {trial.number} pruned by Optuna: {e}")
        raise  # bare raise keeps the original traceback
    except tf.errors.ResourceExhaustedError as e:
        print(f"ERROR in Trial {trial.number} - ResourceExhaustedError: {e}. Pruning.")
        raise optuna.TrialPruned("ResourceExhaustedError during training.") from e
    except Exception as e:
        print(f"ERROR during model.fit in Trial {trial.number}: {e}")
        return 0.0

    # Evaluate the model
    print(f"Evaluating Trial {trial.number} on validation set...")
    try:
        # Order matches compile(): loss, accuracy, precision, recall.
        val_loss, val_acc, val_prec, val_rec = model.evaluate(validation_generator, verbose=0)
    except Exception as e:
        print(f"ERROR during model.evaluate in Trial {trial.number}: {e}")
        return 0.0  # Return low F1-score for evaluation errors

    # Calculate F1-score
    val_f1 = 0.0
    if (val_prec + val_rec) > 0:
        val_f1 = 2 * (val_prec * val_rec) / (val_prec + val_rec)
    else:
        print(f"Warning: Precision + Recall = 0 for Trial {trial.number}. F1-score will be 0.")

    print(f"Trial {trial.number} ({MODEL_NAME}, {PRETRAINED_WEIGHTS_TYPE}) Results: Val Loss={val_loss:.4f}, Acc={val_acc:.4f}, Prec={val_prec:.4f}, Rec={val_rec:.4f}, F1={val_f1:.4f}")

    return val_f1

## 6. Execute Optuna Study

In [None]:
# Create (or resume) the Optuna study, run the remaining trials, and persist
# the best hyperparameters. `study` and `best_trial_info` stay None when the
# study is skipped or fails before producing a result.
study = None
best_trial_info = None

if train_generator and validation_generator:
    print(f"\n--- Starting Optuna Study for {MODEL_NAME} with {PRETRAINED_WEIGHTS_TYPE if PRETRAINED_WEIGHTS_TYPE else 'scratch'} weights ---")
    try:
        # Neither SQLite nor open() creates missing parent directories, so make
        # sure the output locations exist before the study touches them.
        _output_dirs = {os.path.dirname(BEST_PARAMS_FILE)}
        _sqlite_prefix = "sqlite:///"
        if OPTUNA_DB_PATH.startswith(_sqlite_prefix):
            _output_dirs.add(os.path.dirname(OPTUNA_DB_PATH[len(_sqlite_prefix):]))
        for _output_dir in _output_dirs:
            if _output_dir:
                os.makedirs(_output_dir, exist_ok=True)

        study = optuna.create_study(
            study_name=OPTUNA_STUDY_NAME,
            direction='maximize',
            storage=OPTUNA_DB_PATH,
            load_if_exists=True,
            pruner=optuna.pruners.MedianPruner(
                n_startup_trials=5, # Allow first few trials to complete regardless of intermediate values
                n_warmup_steps=EARLY_STOPPING_PATIENCE_BO + 2, # Don't prune before this many epochs
                interval_steps=1 # Check for pruning every epoch after warmup
            )
        )

        # Number of trials stored in the study, regardless of their state.
        n_completed_trials_total = len(study.trials)
        # Trials that reached a terminal state count towards the target, so a
        # resumed study only runs what is still missing.
        _terminal_states = (
            optuna.trial.TrialState.COMPLETE,
            optuna.trial.TrialState.PRUNED,
            optuna.trial.TrialState.FAIL,
        )
        n_finished_trials = sum(1 for t in study.trials if t.state in _terminal_states)

        n_trials_to_run = OPTUNA_TRIALS - n_finished_trials

        if n_trials_to_run > 0:
            print(f"Study '{study.study_name}' has {n_finished_trials} finished (Complete/Pruned/Fail) trials.")
            print(f"Running {n_trials_to_run} more trials (target total: {OPTUNA_TRIALS}).")

            timeout_seconds = 11 * 3600 # Time limit 11 hours because Kaggle session limit is 12 hours
            print(f"Optuna study timeout set to: {timeout_seconds} seconds ({timeout_seconds/3600:.1f} hours)")

            study.optimize(
                objective,
                n_trials=n_trials_to_run,
                timeout=timeout_seconds,
                # Reset Keras' global state after every trial.
                callbacks=[lambda current_study, current_trial: tf.keras.backend.clear_session()]
            )
        else:
            print(f"Study '{study.study_name}' already has {n_finished_trials} finished trials (target was {OPTUNA_TRIALS}). No new trials will be run.")

        print("\n--- Optuna Study Finished ---")
        print(f"Study Name: {study.study_name}")
        print(f"Total trials in study database: {len(study.trials)}")
        completed_trials_list = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
        print(f"  Completed: {len(completed_trials_list)}")
        print(f"  Pruned: {len([t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED])}")
        print(f"  Failed: {len([t for t in study.trials if t.state == optuna.trial.TrialState.FAIL])}")
        print(f"  Running: {len([t for t in study.trials if t.state == optuna.trial.TrialState.RUNNING])}")

        if completed_trials_list:
            best_trial_info = study.best_trial
            print(f"\nBest completed trial for {MODEL_NAME} ({PRETRAINED_WEIGHTS_TYPE if PRETRAINED_WEIGHTS_TYPE else 'scratch'}):")
            print(f"  Number: {best_trial_info.number}")
            print(f"  Best F1-score (validation): {best_trial_info.value:.4f}")
            print(f"  Best hyperparameters: {best_trial_info.params}")

            # Save Best Hyperparameters
            print(f"\nSaving best hyperparameters for {MODEL_NAME} to {BEST_PARAMS_FILE}")
            try:
                # Tuples are written as lists in the JSON file.
                params_to_save = {k: (list(v) if isinstance(v, tuple) else v) for k, v in best_trial_info.params.items()}
                with open(BEST_PARAMS_FILE, 'w', encoding='utf-8') as f:
                    json.dump(params_to_save, f, indent=4)
                print("Best hyperparameters saved successfully.")
            except Exception as e:
                print(f"ERROR saving hyperparameters: {e}")
        else:
            print(f"No trials completed successfully for {MODEL_NAME} ({PRETRAINED_WEIGHTS_TYPE if PRETRAINED_WEIGHTS_TYPE else 'scratch'}). Cannot determine best trial.")

    except Exception as e:
        print(f"\nAn error occurred during the Optuna study for {MODEL_NAME}: {e}")
else:
    print("\n--- Optuna Study Skipped ---")
    if not train_generator: print("Reason: Training data generator was not created or is empty.")
    if not validation_generator: print("Reason: Validation data generator was not created or is empty.")
    print(f"Cannot run Optuna study for {MODEL_NAME}.")