In [None]:
import model_fit_evaluate  # Assuming this contains create_data_generators, etc.
from tensorflow.keras.applications import (
    ResNet50V2,
    InceptionV3,
    EfficientNetB2,
    VGG16,
    MobileNetV2,
    Xception,
)
from tensorflow.keras.applications.resnet_v2 import preprocess_input as preprocess_input_resnet50v2
from tensorflow.keras.applications.inception_v3 import preprocess_input as preprocess_input_inceptionv3
from tensorflow.keras.applications.efficientnet import preprocess_input as preprocess_input_efficientnetb2
from tensorflow.keras.applications.vgg16 import preprocess_input as preprocess_input_vgg16
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as preprocess_input_mobilenetv2
from tensorflow.keras.applications.xception import preprocess_input as preprocess_input_xception
import tensorflow as tf
import os  # Import os to check if the test directory exists.

# Per-backbone settings, keyed by display name. Each entry holds:
#   "base_model"       - the tf.keras.applications constructor for the backbone
#   "preprocess_input" - the matching preprocessing function for that backbone
#   "image_size"       - (height, width) fed to both the data generators and the model
# NOTE(review): every backbone is fed 224x224 here. Per the Keras Applications
# docs, InceptionV3/Xception were pretrained at 299x299 and EfficientNetB2 at
# 260x260. 224 only works if create_custom_model builds the base without its
# classification top - confirm there, and consider native sizes.
model_configs = {
    name: {
        "base_model": backbone,
        "preprocess_input": preprocess,
        "image_size": (224, 224),
    }
    for name, backbone, preprocess in (
        ("ResNet50V2", ResNet50V2, preprocess_input_resnet50v2),
        ("InceptionV3", InceptionV3, preprocess_input_inceptionv3),
        ("EfficientNetB2", EfficientNetB2, preprocess_input_efficientnetb2),
        ("VGG16", VGG16, preprocess_input_vgg16),
        ("MobileNetV2", MobileNetV2, preprocess_input_mobilenetv2),
        ("Xception", Xception, preprocess_input_xception),
    )
}

def _train_and_evaluate(
    model_name,
    config,
    *,
    data_dir,
    test_data_dir,
    batch_size,
    epochs,
    test_batch_size,
):
    """Train one backbone described by *config* and report its metrics.

    Builds the data generators and model, compiles and trains it, evaluates it
    on the validation split, plots its training history and, when
    *test_data_dir* is an existing directory, evaluates it on the test data.

    Args:
        model_name: Display name used in progress and error messages.
        config: One value of ``model_configs`` (keys ``base_model``,
            ``preprocess_input`` and ``image_size``).
        data_dir: Training data directory for the data generators.
        test_data_dir: Directory holding the held-out test data.
        batch_size: Batch size for the training/validation generators.
        epochs: Number of training epochs passed to ``train_model``.
        test_batch_size: Batch size used for test-set evaluation.

    Errors raised by training or by test-set evaluation are printed and end
    processing of this model early. Any other error propagates to the caller.
    """
    image_size = config["image_size"]
    preprocess = config["preprocess_input"]

    train_generator, validation_generator = model_fit_evaluate.create_data_generators(
        data_dir,
        image_size,
        batch_size,
        preprocessing_function=preprocess,
    )

    # NOTE(review): the preprocess function goes both to the generators (above)
    # and to the model (below). If create_custom_model also applies it inside
    # the graph, inputs are preprocessed twice - confirm in model_fit_evaluate.
    model = model_fit_evaluate.create_custom_model(
        base_model=config["base_model"],
        weights="imagenet",
        input_shape=(image_size[0], image_size[1], 3),  # RGB input at the configured size
        num_classes=1,
        trainable_base=True,
        preprocess_input=preprocess,
        dropout_rate=0.5,
    )

    # Build a fresh, explicitly named AUC metric for each model. Metric objects
    # are stateful and must not be shared between models. A fixed name keeps
    # the history key "auc" identical for every backbone, instead of Keras
    # auto-uniquified names such as "auc_1".
    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
    )

    try:
        history = model_fit_evaluate.train_model(
            model,
            train_generator,
            validation_generator,
            epochs,
            batch_size,
            early_stopping_patience=3,
        )
    except Exception as e:
        print(f"Error training {model_name}: {e}")
        return  # Skip evaluation; continue with the next model.

    model_fit_evaluate.evaluate_model(model, validation_generator)
    model_fit_evaluate.plot_training_history(history)

    # isdir rather than exists: a plain file with this name is not usable test data.
    if not os.path.isdir(test_data_dir):
        print(
            f"Warning: Test directory '{test_data_dir}' not found. "
            f"Skipping test evaluation for {model_name}."
        )
        return

    try:
        metrics = model_fit_evaluate.evaluate_model_on_test_data(
            model,
            test_data_dir,
            preprocess,  # Same preprocess function the model was trained with
            image_size=image_size,
            batch_size=test_batch_size,
        )
    except Exception as e:
        print(f"Error evaluating {model_name} on test data: {e}")
        return

    print(f"{model_name} Evaluation Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value}")


if __name__ == "__main__":
    # 0. Config (shared across models)
    data_dir = "train"
    test_data_dir = "test"  # Replace with your test data directory
    batch_size = 4  # Adjust as needed
    test_batch_size = 16  # Batch size for test-set evaluation
    # Reduced epochs for faster testing. Increase for actual training.
    # NOTE(review): with epochs=1, the early_stopping_patience=3 passed to
    # train_model presumably can never trigger.
    epochs = 1

    for model_name, config in model_configs.items():
        print(f"--- Training and Evaluating {model_name} ---")
        try:
            _train_and_evaluate(
                model_name,
                config,
                data_dir=data_dir,
                test_data_dir=test_data_dir,
                batch_size=batch_size,
                epochs=epochs,
                test_batch_size=test_batch_size,
            )
        except Exception as e:
            # One backbone failing (building, validation evaluation, plotting)
            # should not abort the comparison of the remaining ones.
            print(f"Error processing {model_name}: {e}")
        finally:
            # Release the previous model's Keras global state and memory
            # before building the next backbone in this same process.
            tf.keras.backend.clear_session()