# Cardiomegaly Detection with Deep Learning


## 1. Environment Setup

In [None]:
# Install required packages
!pip install -q pandas numpy opencv-python matplotlib scikit-learn tensorflow Pillow seaborn albumentations \
                tqdm grad-cam jupyterlab pydicom imgaug scikit-image keras-tuner \
                mlflow optuna shap eli5 rich plotly


# TensorFlow and Deep Learning
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import VGG16, ResNet50, InceptionV3
from tensorflow.keras.preprocessing.image import ImageDataGenerator

#System Utilities
import os
import sys
import time
import random
import shutil
import logging
from tqdm import tqdm
from datetime import datetime


# Data Analysis and Plotting
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from matplotlib.pyplot import imread, imshow, subplots, show
import plotly.express as px
import plotly.graph_objects as go
from rich import print as rprint

# Image Processing
import cv2
from PIL import Image
from skimage import exposure, filters, morphology, measure
import imgaug.augmenters as iaa
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

# Utility & Performance Tracking
import os
import random
from tqdm import tqdm
import logging
from datetime import datetime

# Model Evaluation and Metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, accuracy_score
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold

# Data Augmentation (Advanced)
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Explainability
#from gradcam import GradCAM, GradCAMPlusPlus
#from gradcam.utils import visualize_cam

# Hyper Parameter Tuning
import keras_tuner as kt
import optuna
import mlflow
mlflow.set_tracking_uri("file:///content/mlruns")
mlflow.set_experiment("Cardiomegaly_CELM")

# Reproducibility: seed Python hashing, `random`, NumPy and TensorFlow with
# one shared value.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# GPU configuration: turn on memory growth for every visible GPU so that
# TensorFlow allocates GPU memory on demand. set_memory_growth raises
# RuntimeError once the devices are initialised; that error is printed,
# not re-raised.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        print(f"{len(gpus)} GPU(s) detected and configured.")
    except RuntimeError as err:
        print(err)

# Logging configuration: one timestamped log file per run under ./logs.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
_run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
logging.basicConfig(
    filename=os.path.join(log_dir, f"run_{_run_stamp}.log"),
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

print("Environment successfully set up and ready.")

## 2. Load Combined Dataset

In [None]:
# -----------------------------
# Load and combine multiple CSV files
# -----------------------------
# NOTE(review): if 'combined_dataset.csv' was itself built from the four
# per-source CSVs below, concatenating all five duplicates those rows;
# confirm which files should be combined.
csv_files = [
    'combined_dataset.csv',
    'padchest_labels.csv',
    'nih_labels.csv',
    'vindr_labels.csv',
    'chexpert_labels.csv'
]

dataframes = []

for file in csv_files:
    if os.path.exists(file):
        df = pd.read_csv(file)
        print(f"{file} loaded with shape: {df.shape}")
        dataframes.append(df)
    else:
        print(f" Warning: File {file} not found.")

# Fail with an actionable message instead of pandas' generic
# "No objects to concatenate" when none of the CSVs is present.
if not dataframes:
    raise FileNotFoundError(f"None of the expected CSV files were found: {csv_files}")

# -----------------------------
#  Concatenate all dataframes
# -----------------------------
full_df = pd.concat(dataframes, ignore_index=True)
print(f"\n Full dataset shape after concatenation: {full_df.shape}")

# Report (but do not drop) images listed more than once across the sources.
if 'image_path' in full_df.columns:
    n_duplicate_paths = int(full_df['image_path'].duplicated().sum())
    if n_duplicate_paths:
        print(f"\n Warning: {n_duplicate_paths} duplicated image paths after concatenation.")

# -----------------------------
# Check for missing values
# -----------------------------
missing_info = full_df.isnull().sum()
missing_info = missing_info[missing_info > 0]
if not missing_info.empty:
    print("\n Missing values detected:")
    print(missing_info)
else:
    print("\n No missing values found.")

# -----------------------------
# Analyze class distribution
# -----------------------------
print("\n Class distribution:")
print(full_df['label'].value_counts())

plt.figure(figsize=(8, 5))
sns.countplot(data=full_df, x='label', order=full_df['label'].value_counts().index, palette='viridis')
plt.title('Class Distribution: Cardiomegaly vs Normal/Other Findings')
plt.xlabel('Class Label')
plt.ylabel('Frequency')
plt.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()

# Save the figure to disk before showing it (show() may clear the figure).
plt.savefig('class_distribution.png')
plt.show()

## 3. CLAHE and Preprocessing

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

# -----------------------------
# CLAHE function
# -----------------------------
def apply_clahe(image, clip_limit=7.0, tile_grid_size=(11, 11)):
    """Contrast-enhance an RGB image with CLAHE on its lightness channel.

    The image is converted to CIE LAB, CLAHE is applied to the L channel
    only (leaving the a/b colour channels untouched), and the result is
    converted back to RGB.

    Args:
        image: RGB image array. OpenCV's CLAHE only accepts 8/16-bit
            single-channel data, so pass a uint8 RGB image.
        clip_limit: CLAHE contrast-limiting threshold.
        tile_grid_size: Grid of tiles (rows, cols) used for local equalisation.

    Returns:
        The enhanced image, RGB channel order.
    """
    equaliser = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))
    enhanced_lab = cv2.merge((equaliser.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

# -----------------------------
# Preprocessing for 1 image
# -----------------------------
def preprocess_image(image_path, output_size=(224, 224)):
    """Read an image with OpenCV, apply CLAHE, and resize it for saving.

    NOTE: the "Create TensorFlow Dataset" cell later redefines
    `preprocess_image` with a different (tf.data) signature; running that
    cell replaces this function in the notebook namespace.

    Args:
        image_path: Path to the image file.
        output_size: Target size passed to ``cv2.resize`` as (width, height).

    Returns:
        uint8 RGB image of the requested size.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns None (missing file or a
            format OpenCV cannot decode).
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f" Image not found: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = apply_clahe(img)
    # cv2.resize keeps the uint8 dtype, so the result can go straight to
    # cv2.imwrite. The former `(img / 255.0 * 255).astype(np.uint8)` was a
    # float round-trip that could only lose information: astype truncates,
    # so a product landing just below an integer would drop one grey level.
    img = cv2.resize(img, output_size)
    return img

# -----------------------------
# Batch processing from CSV
# -----------------------------
def process_images_from_csv(csv_path, image_column, input_root="", output_root="processed_images",
                            img_size=(224, 224), preprocess_fn=preprocess_image):
    """Run CLAHE preprocessing on every image listed in a CSV column.

    Each processed image is written to ``output_root/<basename>``. Images
    that fail to load, process or save are reported and skipped.

    Args:
        csv_path: CSV file containing the image list.
        image_column: Column holding image paths (rows with NaN are ignored).
        input_root: Prefix joined in front of each path (``os.path.join``
            keeps absolute paths unchanged).
        output_root: Directory for the processed images (created if needed).
        img_size: Output size forwarded as ``output_size`` to `preprocess_fn`.
        preprocess_fn: Callable ``(path, output_size=...) -> RGB uint8 image``.
            The default is bound when this cell runs. That is the OpenCV
            `preprocess_image` defined just above, so the later tf.data
            `preprocess_image`, which has an incompatible signature, cannot
            hijack it.
    """
    df = pd.read_csv(csv_path)
    image_paths = df[image_column].dropna().tolist()

    os.makedirs(output_root, exist_ok=True)
    print(f"\n Processing {len(image_paths)} images...")

    # Outputs are named by basename only, so inputs that share a file name in
    # different folders overwrite one another; make that visible.
    basenames = [os.path.basename(p) for p in image_paths]
    n_collisions = len(basenames) - len(set(basenames))
    if n_collisions:
        print(f" Warning: {n_collisions} images share a file name with another image "
              f"and will overwrite each other in {output_root}.")

    n_failed = 0
    for rel_path in tqdm(image_paths):
        in_path = os.path.join(input_root, rel_path)
        out_path = os.path.join(output_root, os.path.basename(rel_path))

        try:
            processed = preprocess_fn(in_path, output_size=img_size)
            written = cv2.imwrite(out_path, cv2.cvtColor(processed, cv2.COLOR_RGB2BGR))
        except Exception as e:
            print(f" Skipping {rel_path}: {e}")
            n_failed += 1
            continue
        # cv2.imwrite reports failure (unknown extension, unwritable path) by
        # returning False rather than raising.
        if not written:
            print(f" Skipping {rel_path}: could not write {out_path}")
            n_failed += 1

    if n_failed:
        print(f"\n {n_failed} of {len(image_paths)} images were skipped.")
    print(f"\n Done. Processed images saved to: {output_root}")

# -----------------------------
# Run example
# -----------------------------
# Inside a notebook __name__ is "__main__", so executing this cell processes
# every image listed in the CSV.
if __name__ == "__main__":
    csv_file, image_col = "combined_dataset.csv", "image_path"            # CSV and its path column
    input_images_dir, output_images_dir = "raw_images", "processed_images"  # source prefix / destination
    image_size = (224, 224)                                                 # (width, height) for cv2.resize

    process_images_from_csv(
        csv_file,
        image_col,
        input_root=input_images_dir,
        output_root=output_images_dir,
        img_size=image_size,
    )

## 4. Data Augmentation

In [None]:
def get_keras_augmentation_pipeline(mode="medium", value_range=(0, 255)):
    """
    Returns a tf.keras.Sequential data augmentation pipeline.
    Args:
        mode (str): 'light', 'medium', or 'strong'
        value_range (tuple): Pixel range of the images fed to the pipeline,
            forwarded to RandomBrightness. Keras assumes 0-255 by default and
            clips its output to this range, so pass (0.0, 1.0) for images
            already scaled to [0, 1]. Otherwise the brightness shift (up to
            factor * 255) swamps the image.
    Raises:
        ValueError: if `mode` is not one of the three levels.
    """
    # NOTE(review): horizontal (and, in 'strong', vertical) flips place the
    # heart on the wrong side of a chest X-ray; confirm that is acceptable for
    # cardiomegaly detection.
    if mode == "light":
        return tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
        ])
    if mode == "medium":
        return tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.05),
            tf.keras.layers.RandomZoom(height_factor=0.1, width_factor=0.1),
            tf.keras.layers.RandomBrightness(factor=0.1, value_range=value_range),
            tf.keras.layers.RandomContrast(factor=0.1),
        ])
    if mode == "strong":
        return tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal_and_vertical"),
            tf.keras.layers.RandomRotation(0.15),
            tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2),
            tf.keras.layers.RandomBrightness(factor=0.2, value_range=value_range),
            tf.keras.layers.RandomContrast(factor=0.3),
            tf.keras.layers.RandomTranslation(0.2, 0.2),
        ])
    raise ValueError("Mode must be 'light', 'medium', or 'strong'.")

# Example usage
data_augmentation = get_keras_augmentation_pipeline("medium")

In [None]:
def visualize_augmentations(image_path, n_examples=6):
    """Display `n_examples` random augmentations of one image side by side.

    Each panel is an independent draw of `augment_image_np` (the
    albumentations pipeline defined in a later cell).

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read `image_path`.
            Previously a None image reached ``cv2.cvtColor`` and failed with an
            opaque OpenCV assertion.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f" Image not found: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(15, 5))
    for i in range(n_examples):
        augmented_img = augment_image_np(image)
        plt.subplot(1, n_examples, i + 1)
        plt.imshow(augmented_img)
        plt.title(f"Aug #{i+1}")
        plt.axis("off")
    plt.suptitle("Augmented Samples", fontsize=16)
    plt.tight_layout()
    plt.show()

In [None]:
def get_albumentations_pipeline():
    """Build the albumentations augmentation pipeline (output resized to 224x224).

    The line previously read `ef get_albumentations_pipeline():`, a syntax
    error that stopped the whole cell from running.
    """
    # NOTE(review): `alpha_affine` was deprecated in albumentations 1.4 and
    # dropped later; newer versions may warn about or reject it. Confirm
    # against the installed version.
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.Rotate(limit=15, p=0.5),
        A.RandomScale(scale_limit=0.15, p=0.5),
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
        A.GridDistortion(p=0.3),
        A.RandomGamma(p=0.3),
        A.GaussianBlur(blur_limit=(3,5), p=0.2),
        A.Resize(224, 224),  # RandomScale changes the size, so resize last
    ])

def augment_image_np(image_np):
    """Return one random augmentation of a NumPy image (HWC array)."""
    aug = get_albumentations_pipeline()
    augmented = aug(image=image_np)
    return augmented["image"]


## 5. Create TensorFlow Dataset

In [None]:
import tensorflow as tf
import numpy as np
import time

# -----------------------------
# CLAHE with OpenCV
# -----------------------------
import cv2

# tensorflow_addons is not used in this notebook, is not in the pip install
# list, and is end-of-life upstream. Keep the name when the package is
# present, but do not let a missing package abort the whole cell.
try:
    import tensorflow_addons as tfa  # noqa: F401
except ImportError:
    tfa = None

def apply_clahe_tf(image):
    """Apply CLAHE inside a tf.data pipeline and scale the result to [0, 1].

    Args:
        image: (H, W, C) image tensor with pixel values on the 0-255 scale
            (uint8, or float32 as produced by ``tf.image.resize``).

    Returns:
        float32 tensor with the same shape as `image`, values in [0, 1].
    """
    enhanced = tf.numpy_function(apply_clahe_np, [image], tf.uint8)
    # numpy_function drops static shape information; restore it so that
    # Keras augmentation layers and batching see a known rank/shape.
    enhanced.set_shape(image.shape)
    return tf.cast(enhanced, tf.float32) / 255.0

def apply_clahe_np(image_np, clip_limit=7.0, tile_grid_size=(11, 11)):
    """Apply CLAHE to a NumPy image and return a uint8 image of the same shape.

    RGB images (H, W, 3) are equalised on the L channel of CIE LAB.
    Single-channel images (H, W) or (H, W, 1) are equalised directly.

    Non-uint8 input, such as the float32 output of ``tf.image.resize``, is
    rounded and clipped to [0, 255] first. OpenCV's CLAHE rejects float data,
    and ``tf.numpy_function`` above declares a uint8 result; the previous
    version passed floats through and failed on both counts.
    """
    img = np.asarray(image_np)
    if img.dtype != np.uint8:
        img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    if img.ndim == 2 or img.shape[-1] == 1:
        plane = np.ascontiguousarray(img.reshape(img.shape[:2]))
        return clahe.apply(plane).reshape(img.shape)

    l, a, b = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2LAB))
    merged = cv2.merge((clahe.apply(l), a, b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)

# -----------------------------
# Data Augmentation Layer
# -----------------------------
def get_augmentation_pipeline(intensity='medium', value_range=(0, 255)):
    """Return a Keras augmentation pipeline for the tf.data input pipeline.

    Args:
        intensity: 'light', 'medium' or 'strong'.
        value_range: Pixel range of the incoming images, forwarded to
            RandomBrightness. The default matches Keras' own (0-255); pass
            (0.0, 1.0) for images already scaled to [0, 1], otherwise the
            brightness shift (up to factor * 255) swamps the image.

    Raises:
        ValueError: for any other `intensity`.
    """
    if intensity == 'light':
        return tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
        ])
    if intensity == 'medium':
        return tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.05),
            tf.keras.layers.RandomZoom(0.1),
            tf.keras.layers.RandomBrightness(0.1, value_range=value_range),
            tf.keras.layers.RandomContrast(0.1)
        ])
    if intensity == 'strong':
        return tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal_and_vertical"),
            tf.keras.layers.RandomRotation(0.15),
            tf.keras.layers.RandomZoom(0.2),
            tf.keras.layers.RandomTranslation(0.2, 0.2),
            tf.keras.layers.RandomContrast(0.3),
            tf.keras.layers.RandomBrightness(0.2, value_range=value_range)
        ])
    raise ValueError("Choose from: light | medium | strong")

# -----------------------------
# Preprocessing Function
# -----------------------------
def preprocess_image(img_path,
                     label,
                     img_size=(224, 224),
                     channels=3,
                     apply_clahe=False,
                     augment=False,
                     augment_level="medium",
                     normalize=True,
                     num_classes=1):
    """Decode, resize and optionally enhance/augment one image for tf.data.

    NOTE: this replaces the OpenCV `preprocess_image(image_path, output_size)`
    from the CLAHE cell in the notebook namespace.

    Args:
        img_path: Scalar string tensor with the image path.
        label: Class label. It is one-hot encoded when ``num_classes > 1``
            and must then be an integer class id.
        img_size: (height, width) passed to ``tf.image.resize``.
        channels: Number of colour channels to decode to.
        apply_clahe: Apply CLAHE (`apply_clahe_tf`). Its output is always
            scaled to [0, 1].
        augment: Apply `get_augmentation_pipeline(augment_level)`.
        augment_level: 'light', 'medium' or 'strong'.
        normalize: Scale to [0, 1] when CLAHE is off.
        num_classes: See `label`.

    Returns:
        (image, label). `image` is float32 (H, W, channels), in [0, 1] when
        `apply_clahe` or `normalize` is set, otherwise on the 0-255 scale.
    """
    raw = tf.io.read_file(img_path)

    # Decode by file content instead of extension. The old check was
    # case-sensitive (".PNG" went to the JPEG decoder) and only knew
    # PNG/JPEG. expand_animations=False guarantees the 3-D tensor that
    # tf.image.resize requires.
    image = tf.io.decode_image(raw, channels=channels, expand_animations=False)
    image = tf.image.resize(image, img_size)

    if apply_clahe:
        image = apply_clahe_tf(image)
        # CLAHE runs through tf.numpy_function, which loses the static shape.
        image.set_shape([img_size[0], img_size[1], channels])
        in_unit_range = True
    else:
        image = tf.cast(image, tf.float32)
        in_unit_range = bool(normalize)
        if in_unit_range:
            image = image / 255.0

    if augment:
        aug = get_augmentation_pipeline(augment_level)
        if in_unit_range:
            # Keras RandomBrightness (and RandomContrast on Keras 3) assume
            # 0-255 pixels by default and clip to that range. Applied to [0, 1]
            # images, a brightness shift of up to +/-0.1*255 wiped out the image,
            # so augment on the 0-255 scale and map back.
            image = aug(image * 255.0, training=True) / 255.0
        else:
            image = aug(image, training=True)

    if num_classes > 1:
        label = tf.one_hot(label, depth=num_classes)

    return image, label

# -----------------------------
# TF Dataset Builder
# -----------------------------
def get_tf_dataset(img_paths,
                   labels,
                   batch_size=32,
                   shuffle_buffer=1024,
                   img_size=(224, 224),
                   channels=3,
                   training=True,
                   apply_clahe=False,
                   augment_level="medium",
                   normalize=True,
                   num_classes=1):
    """Build a batched, prefetched tf.data pipeline over image files.

    Args:
        img_paths: Sequence of image file paths.
        labels: Sequence of labels aligned with `img_paths`.
        batch_size: Batch size.
        shuffle_buffer: Shuffle buffer size used when `training` is true.
            ``None`` shuffles the whole dataset.
        training: Shuffle and augment (augmentation uses `augment_level`).
        img_size, channels, apply_clahe, augment_level, normalize, num_classes:
            Forwarded to `preprocess_image`.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    start = time.time()

    dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels))

    if training:
        # Shuffle the lightweight (path, label) pairs *before* decoding. The
        # buffer then holds strings instead of decoded float images (about
        # 600 KB each at 224x224x3), so even a full-dataset buffer is cheap.
        buffer_size = len(img_paths) if shuffle_buffer is None else shuffle_buffer
        dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=True)

    dataset = dataset.map(
        lambda x, y: preprocess_image(
            x, y,
            img_size=img_size,
            channels=channels,
            apply_clahe=apply_clahe,
            augment=training,
            augment_level=augment_level,
            normalize=normalize,
            num_classes=num_classes
        ),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # tf.data is lazy: this only times pipeline construction, not image loading.
    duration = time.time() - start
    print(f" TF Dataset prepared in {duration:.2f} seconds.")

    return dataset

In [None]:
df = pd.read_csv("combined_dataset.csv")
# Rows without a path or a label cannot be used for supervised training.
df = df.dropna(subset=['image_path', 'label']).reset_index(drop=True)

image_paths = df['image_path'].astype(str).values

# Map labels to contiguous integer ids 0..K-1, as sparse_categorical_crossentropy
# requires. The sort makes the mapping deterministic whether the CSV stores
# 0/1 integers or class-name strings.
class_names = sorted(df['label'].unique().tolist())
label_to_id = {label: idx for idx, label in enumerate(class_names)}
labels = df['label'].map(label_to_id).values.astype('int64')
print(f" Label mapping: {label_to_id}")

# Disjoint, class-stratified 70/15/15 train/val/test split. Previously
# val_ds reused the training images, so validation metrics (and the
# checkpoint/early-stopping decisions based on them) measured memorisation.
# test_ds, which the training cells below use, was never defined.
# NOTE(review): this split is per image. If the CSV has a patient id column,
# a grouped split would avoid the same patient appearing in several sets.
train_paths, holdout_paths, train_labels, holdout_labels = train_test_split(
    image_paths, labels, test_size=0.30, stratify=labels, random_state=SEED
)
val_paths, test_paths, val_labels, test_labels = train_test_split(
    holdout_paths, holdout_labels, test_size=0.50, stratify=holdout_labels, random_state=SEED
)

# Keep labels as integer ids (num_classes=1 disables one-hot encoding in
# preprocess_image). Every model in this notebook is compiled with
# sparse_categorical_crossentropy; the former num_classes=2 produced one-hot
# targets that do not match that loss. Val/test get the same CLAHE as
# training so all splits share one preprocessing.
train_ds = get_tf_dataset(
    img_paths=train_paths,
    labels=train_labels,
    batch_size=32,
    training=True,
    apply_clahe=True,
    augment_level="strong",
    num_classes=1
)

val_ds = get_tf_dataset(
    img_paths=val_paths,
    labels=val_labels,
    batch_size=32,
    training=False,
    apply_clahe=True,
    num_classes=1
)

test_ds = get_tf_dataset(
    img_paths=test_paths,
    labels=test_labels,
    batch_size=32,
    training=False,
    apply_clahe=True,
    num_classes=1
)

## 6. Build Base Models: VGG16, ResNet50, InceptionV3, DenseNet121, DenseNet201, AlexNet (custom implementation), ViT-B/16 (Vision Transformer) and a Custom CNN (own architecture)

In [None]:
from tensorflow.keras.applications import (
    VGG16, ResNet50, InceptionV3, DenseNet121, DenseNet201
)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import (
    Input, Dense, Dropout, GlobalAveragePooling2D, Conv2D, MaxPooling2D, Flatten
)
from tensorflow.keras.optimizers import Adam
import tensorflow_hub as hub


def build_base_model(name, input_shape=(224, 224, 3)):
    """Return the model registered under `name`.

    'VGG16', 'ResNet50', 'InceptionV3', 'DenseNet121' and 'DenseNet201'
    are Keras applications built with ImageNet weights and
    ``include_top=False``. 'AlexNet', 'ViT-B-16' and 'CustomCNN' are
    delegated to `build_alexnet`, `build_vit_model` and `build_custom_cnn`,
    each called with `input_shape` only. Each factory runs only when its
    name matches.

    Raises:
        ValueError: if `name` is not one of the identifiers above.
    """
    factories = (
        ('VGG16', lambda: VGG16(weights='imagenet', include_top=False, input_shape=input_shape)),
        ('ResNet50', lambda: ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)),
        ('InceptionV3', lambda: InceptionV3(weights='imagenet', include_top=False, input_shape=input_shape)),
        ('DenseNet121', lambda: DenseNet121(weights='imagenet', include_top=False, input_shape=input_shape)),
        ('DenseNet201', lambda: DenseNet201(weights='imagenet', include_top=False, input_shape=input_shape)),
        ('AlexNet', lambda: build_alexnet(input_shape)),
        ('ViT-B-16', lambda: build_vit_model(input_shape)),
        ('CustomCNN', lambda: build_custom_cnn(input_shape)),
    )
    for key, factory in factories:
        if name == key:
            return factory()
    raise ValueError(f"Model '{name}' is not supported.")


def build_model_wrapper(model_name, input_shape=(224, 224, 3), num_classes=2, fine_tune=False):
    """Build and compile a classifier for `model_name`.

    ImageNet backbones get a GlobalAveragePooling2D -> Dense(256) ->
    Dropout(0.4) -> output head. The backbone is frozen unless `fine_tune`
    is true. With ``num_classes == 1`` the head is a sigmoid unit with
    binary cross-entropy; otherwise it is softmax with sparse categorical
    cross-entropy (integer labels). The model is compiled with Adam(1e-4).

    'AlexNet', 'CustomCNN' and 'ViT-B-16' builders already return complete,
    compiled softmax classifiers, so they are returned as built. Only their
    softmax form exists, so ``num_classes == 1`` keeps their 2-class
    default, as before.
    """
    # Self-contained models bypass the backbone path. Previously they went
    # through it: the from-scratch AlexNet/CustomCNN were frozen entirely
    # (nothing left to train), ViT got a second GlobalAveragePooling2D head
    # stacked on its 2-D softmax output (which cannot be built), and
    # num_classes was ignored for all three.
    head_classes = num_classes if num_classes > 1 else 2
    if model_name == 'AlexNet':
        return build_alexnet(input_shape, num_classes=head_classes)
    if model_name == 'CustomCNN':
        return build_custom_cnn(input_shape, num_classes=head_classes)
    if model_name == 'ViT-B-16':
        return build_vit_model(input_shape, num_classes=head_classes)

    base = build_base_model(model_name, input_shape)
    # Freeze the pretrained backbone for pure transfer learning.
    if not fine_tune:
        base.trainable = False

    x = base.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.4)(x)

    if num_classes == 1:
        out = Dense(1, activation='sigmoid')(x)
        loss = 'binary_crossentropy'
    else:
        out = Dense(num_classes, activation='softmax')(x)
        loss = 'sparse_categorical_crossentropy'

    model = Model(inputs=base.input, outputs=out, name=f"{model_name}_Model")
    model.compile(optimizer=Adam(1e-4), loss=loss, metrics=['accuracy'])

    return model

In [None]:
def build_vit_model(input_shape=(224, 224, 3), num_classes=2):
    """Build a compiled ViT-B/16 classifier on a frozen TF-Hub backbone.

    Inputs are resized to 224x224 and passed through the hub layer, then
    Dense(128, relu) -> Dropout(0.3) -> softmax over `num_classes`. The model
    is compiled with Adam and sparse categorical cross-entropy.

    NOTE(review): the hub handle, its output (features vs. logits) and the
    input scaling it expects are not verified here; confirm against the
    model card. hub.KerasLayer targets Keras 2 (tf.keras).
    """
    backbone = hub.KerasLayer("https://tfhub.dev/google/vit/base_patch16_224/1", trainable=False)

    inputs = Input(shape=input_shape)
    resized = tf.image.resize(inputs, [224, 224])
    embedded = backbone(resized)

    hidden = Dense(128, activation='relu')(embedded)
    hidden = Dropout(0.3)(hidden)
    probabilities = Dense(num_classes, activation='softmax')(hidden)

    vit = Model(inputs, probabilities, name="ViT_B16")
    vit.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return vit

In [None]:
def build_custom_cnn(input_shape=(224, 224, 3), num_classes=2):
    """Build and compile a small two-block CNN classifier trained from scratch.

    Architecture: two [Conv2D(3x3, relu) -> MaxPooling2D] blocks with 32 and
    64 filters, then Flatten -> Dense(128, relu) -> Dropout(0.3) -> softmax.
    The model is compiled with Adam and sparse categorical cross-entropy.
    """
    net = Sequential()
    net.add(Input(shape=input_shape))
    for filters in (32, 64):
        net.add(Conv2D(filters, (3, 3), activation='relu'))
        net.add(MaxPooling2D())
    net.add(Flatten())
    net.add(Dense(128, activation='relu'))
    net.add(Dropout(0.3))
    net.add(Dense(num_classes, activation='softmax'))

    net.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return net

In [None]:
def build_alexnet(input_shape=(224, 224, 3), num_classes=2):
    """Build and compile an AlexNet-style classifier trained from scratch.

    Five convolution layers (96/11x11 stride 4, 256/5x5, 384/3x3, 384/3x3,
    256/3x3) with 3x3 stride-2 max pooling after the 1st, 2nd and 5th. Then
    two Dense(4096) + Dropout(0.5) layers and a softmax over `num_classes`.
    This single-tower variant has no local response normalisation. The
    model is compiled with Adam and sparse categorical cross-entropy.
    """
    def pool():
        return MaxPooling2D(pool_size=3, strides=2)

    feature_layers = [
        Conv2D(96, kernel_size=11, strides=4, activation='relu'),
        pool(),
        Conv2D(256, kernel_size=5, padding='same', activation='relu'),
        pool(),
    ]
    feature_layers += [Conv2D(f, kernel_size=3, padding='same', activation='relu') for f in (384, 384, 256)]
    feature_layers.append(pool())

    classifier_layers = [Flatten()]
    for _ in range(2):
        classifier_layers += [Dense(4096, activation='relu'), Dropout(0.5)]
    classifier_layers.append(Dense(num_classes, activation='softmax'))

    alexnet = Sequential([Input(shape=input_shape), *feature_layers, *classifier_layers])
    alexnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return alexnet

## 7. Training and Evaluation

In [None]:

def train_and_evaluate(model, model_name, train_ds, val_ds, test_ds, epochs=10, output_dir="results"):
    """Train `model`, reload its best checkpoint, and report test metrics.

    The checkpoint with the best validation accuracy is written to
    ``<output_dir>/<model_name>_best_weights.h5`` and reloaded before
    evaluation. EarlyStopping on val_loss (patience 5) may end training
    early. Training curves are saved via `plot_training_history`.

    Args:
        model: Compiled Keras model with an 'accuracy' metric.
        model_name: Name used in messages and file names.
        train_ds, val_ds, test_ds: Datasets yielding (inputs, labels) batches.
            Inputs may be a tuple for multi-input models. Labels may be
            integer ids or one-hot vectors.
        epochs: Maximum number of epochs.
        output_dir: Directory for the checkpoint and plot (created if needed).

    Returns:
        dict with 'history', 'test_loss', 'test_accuracy', 'y_true',
        'y_pred' and 'confusion_matrix'.
    """
    os.makedirs(output_dir, exist_ok=True)
    weights_path = os.path.join(output_dir, f"{model_name}_best_weights.h5")

    # Callback to save best weights (by validation accuracy)
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        weights_path, monitor='val_accuracy', save_best_only=True, verbose=1
    )

    early_stopping_cb = tf.keras.callbacks.EarlyStopping(
        patience=5, restore_best_weights=True, monitor='val_loss'
    )

    print(f" Training model: {model_name}")
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[checkpoint_cb, early_stopping_cb]
    )

    # The best-val_accuracy checkpoint overrides EarlyStopping's restored
    # (best-val_loss) weights; the export cell reads this same file.
    model.load_weights(weights_path)

    print(f"\n Evaluating {model_name} on test data...")
    test_loss, test_acc = model.evaluate(test_ds)
    print(f" Test Accuracy: {test_acc:.4f}, Loss: {test_loss:.4f}")

    y_true, y_pred = [], []
    for x_batch, y_batch in test_ds:
        # Call the model directly. It accepts multi-input tuples (CELM)
        # unambiguously, whereas predict() on a tuple of tensors can be
        # unpacked as (x, y, sample_weight). It also avoids per-batch
        # predict() overhead and progress bars.
        preds = np.asarray(model(x_batch, training=False))
        batch_labels = np.asarray(y_batch)
        if batch_labels.ndim > 1 and batch_labels.shape[-1] > 1:
            batch_labels = batch_labels.argmax(axis=-1)  # one-hot -> class ids
        y_true.extend(batch_labels.reshape(-1).tolist())
        if preds.ndim > 1 and preds.shape[-1] > 1:
            y_pred.extend(preds.argmax(axis=-1).tolist())
        else:
            y_pred.extend((preds.reshape(-1) > 0.5).astype(int).tolist())

    print("\n Classification Report:")
    print(classification_report(y_true, y_pred))

    cm = confusion_matrix(y_true, y_pred)
    print(" Confusion Matrix:")
    print(cm)

    plot_training_history(history, model_name, output_dir)

    return {
        'history': history,
        'test_loss': test_loss,
        'test_accuracy': test_acc,
        'y_true': y_true,
        'y_pred': y_pred,
        'confusion_matrix': cm,
    }


def plot_training_history(history, model_name, save_dir):
    """Plot train/validation accuracy and loss side by side, then save and show them.

    `history` is the object returned by ``Model.fit``; its ``history`` dict
    must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'. The
    figure is saved as ``<save_dir>/<model_name>_training_plot.png``.
    """
    curves = history.history
    panels = (('accuracy', 'Accuracy'), ('loss', 'Loss'))

    plt.figure(figsize=(12, 5))
    for position, (metric, pretty) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves[metric], label='Train')
        plt.plot(curves[f'val_{metric}'], label='Val')
        plt.title(f'{model_name} {pretty}')
        plt.xlabel('Epoch')
        plt.ylabel(pretty)
        plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f"{model_name}_training_plot.png"))
    plt.show()

In [None]:

model_names = [
    "VGG16", "ResNet50", "InceptionV3", "DenseNet121",
    "DenseNet201", "AlexNet", "ViT-B-16", "CustomCNN",
]

# Build each architecture once and print its layer summary (no training here).
for name in model_names:
    print(f"🔧 Building {name}...")
    model = build_model_wrapper(name, input_shape=(224, 224, 3), num_classes=2)
    model.summary()


In [None]:


#  Define model names
model_names = [
    "VGG16",
    "ResNet50",
    "InceptionV3",
    "DenseNet121",
    "DenseNet201",
    "AlexNet",
    "ViT-B-16",
    "CustomCNN"
]


#  Loop through and train + evaluate each model
for name in model_names:
    print(f"\n Starting model: {name}")
    # Release Keras global state left by the previous model before building the next.
    tf.keras.backend.clear_session()
    try:
        model = build_model_wrapper(
            model_name=name,
            input_shape=(224, 224, 3),
            num_classes=2,
            fine_tune=False
        )

        train_and_evaluate(
            model=model,
            model_name=name,
            train_ds=train_ds,
            val_ds=val_ds,
            test_ds=test_ds,
            epochs=10,
            output_dir=f"results/{name.lower()}"
        )
    except Exception as e:
        # Continue with the remaining models, but keep the full traceback in
        # the run log instead of discarding it.
        print(f"Failed to train {name}: {e}")
        logging.exception("Failed to train %s", name)


## 8. Stacking Ensemble (CELM feature fusion of VGG16, ResNet50 and InceptionV3, plus a logistic-regression meta-learner)

In [None]:
def get_base_model(model_class, input_shape=(224, 224, 3), name="base"):
    """Wrap a frozen ImageNet backbone as a pooled feature extractor.

    Args:
        model_class: Keras application constructor (e.g. ``VGG16``) accepting
            ``weights``, ``include_top`` and ``input_shape``.
        input_shape: Input image shape.
        name: Name of the returned model. The pooling layer is named
            ``f"{name}_gap"``.

    Returns:
        A Model mapping images to globally average-pooled feature vectors,
        with all backbone weights frozen.
    """
    backbone = model_class(weights='imagenet', include_top=False, input_shape=input_shape)
    backbone.trainable = False
    pooled = GlobalAveragePooling2D(name=f"{name}_gap")(backbone.output)
    return Model(inputs=backbone.input, outputs=pooled, name=name)

In [None]:
def build_celm_model(input_shape=(224, 224, 3), num_classes=2):
    """Build the CELM feature-fusion ensemble of VGG16, ResNet50 and InceptionV3.

    Each frozen backbone (`get_base_model`) has its own input. Their pooled
    features are concatenated and classified by Dense(256) -> Dropout(0.4)
    -> output: sigmoid/binary cross-entropy for ``num_classes == 1``,
    otherwise softmax/sparse categorical cross-entropy. Feed it with
    `get_multi_input_dataset`, which repeats each image three times.

    NOTE(review): each backbone normally expects its own ImageNet
    ``preprocess_input`` scaling; the images here are passed through
    unchanged. Confirm whether per-backbone preprocessing should be added.
    """
    # Concatenate was never imported in this notebook (the layer imports only
    # bring in Input, Dense, Dropout, ...), so building the model raised NameError.
    from tensorflow.keras.layers import Concatenate

    # Instantiate base feature extractors
    vgg = get_base_model(VGG16, input_shape, name="vgg16")
    resnet = get_base_model(ResNet50, input_shape, name="resnet50")
    inception = get_base_model(InceptionV3, input_shape, name="inceptionv3")

    # Three separate inputs; in practice they all receive the same image.
    input_vgg = Input(shape=input_shape, name='input_vgg')
    input_res = Input(shape=input_shape, name='input_resnet')
    input_inc = Input(shape=input_shape, name='input_inception')

    feat_vgg = vgg(input_vgg)
    feat_res = resnet(input_res)
    feat_inc = inception(input_inc)

    concatenated = Concatenate(name="concat_features")([feat_vgg, feat_res, feat_inc])
    x = Dense(256, activation='relu', name='fc1')(concatenated)
    x = Dropout(0.4, name='dropout')(x)

    if num_classes == 1:
        output = Dense(1, activation='sigmoid', name='output')(x)
        loss_fn = 'binary_crossentropy'
    else:
        output = Dense(num_classes, activation='softmax', name='output')(x)
        loss_fn = 'sparse_categorical_crossentropy'

    model = Model(inputs=[input_vgg, input_res, input_inc], outputs=output, name='CELM')
    model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

    return model

In [None]:
def get_multi_input_dataset(single_input_dataset):
    """Adapt an ``(image, label)`` dataset to the three-input CELM model.

    Every element becomes ``((image, image, image), label)``; the mapping
    runs with ``tf.data.AUTOTUNE`` parallelism.
    """
    def _triplicate(image, label):
        return (image, image, image), label

    return single_input_dataset.map(_triplicate, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:

from sklearn.linear_model import LogisticRegression

# Fit a stacking meta-learner on collected base-model predictions.
def train_meta_model(base_preds, y_true):
    """Fit a default ``LogisticRegression`` on base-model outputs.

    Args:
        base_preds: Feature matrix built from the base models' predictions
            (one row per sample, as scikit-learn expects).
        y_true: Target labels aligned with the rows of `base_preds`.

    Returns:
        The fitted LogisticRegression instance.
    """
    return LogisticRegression().fit(base_preds, y_true)


## 9. CELM Ensemble Training and Evaluation

In [None]:
# Repeat each image three times so it feeds all CELM inputs.
train_ds_multi = get_multi_input_dataset(train_ds)
val_ds_multi = get_multi_input_dataset(val_ds)
test_ds_multi = get_multi_input_dataset(test_ds)

celm_model = build_celm_model(input_shape=(224, 224, 3), num_classes=2)

# train_and_evaluate is defined in the "Training and Evaluation" cell of this
# notebook. The former `from train_and_evaluate import train_and_evaluate`
# looked for a separate module of that name: it fails with
# ModuleNotFoundError unless such a file exists, and if one did, it would
# silently shadow the notebook's version.
train_and_evaluate(
    model=celm_model,
    model_name="CELM",
    train_ds=train_ds_multi,
    val_ds=val_ds_multi,
    test_ds=test_ds_multi,
    epochs=10,
    output_dir="results/celm"
)


## 10. Save Models and Export

In [None]:

# -----------------------------
# Export helper function
# -----------------------------
def export_model(model, model_name="model", export_dir="exports"):
    """Save `model` as HDF5, as a TensorFlow SavedModel, and as TFLite.

    Files written under `export_dir` (created if needed):
    ``<model_name>.h5``, ``<model_name>_savedmodel/`` and
    ``<model_name>.tflite`` (converted from the SavedModel).

    Keras 3 rejects ``save_format='tf'``. In that case the SavedModel is
    written with ``Model.export`` (an inference-only SavedModel, which is
    what the TFLite converter needs).
    """
    os.makedirs(export_dir, exist_ok=True)

    # HDF5 format
    h5_path = os.path.join(export_dir, f"{model_name}.h5")
    model.save(h5_path)
    print(f" {model_name} saved to HDF5: {h5_path}")

    # SavedModel format
    savedmodel_path = os.path.join(export_dir, f"{model_name}_savedmodel")
    try:
        model.save(savedmodel_path, save_format='tf')
    except (TypeError, ValueError):
        if not hasattr(model, "export"):
            raise
        model.export(savedmodel_path)
    print(f" {model_name} saved as SavedModel: {savedmodel_path}")

    # TFLite format
    converter = tf.lite.TFLiteConverter.from_saved_model(savedmodel_path)
    tflite_model = converter.convert()
    tflite_path = os.path.join(export_dir, f"{model_name}.tflite")
    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)
    print(f" {model_name} exported to TensorFlow Lite: {tflite_path}\n")


# -----------------------------
# Models to export and the shared input/output configuration
# -----------------------------
model_names = "VGG16 ResNet50 InceptionV3 DenseNet121 DenseNet201 AlexNet ViT-B-16 CustomCNN".split()

input_shape, num_classes = (224, 224, 3), 2




In [None]:
  # -----------------------------
#  Loop over each model
# -----------------------------
for name in model_names:
    print(f"\n Exporting model: {name}")
    weights_path = f"results/{name.lower()}/{name}_best_weights.h5"
    # Skip cheaply when training produced no checkpoint, instead of building
    # the model (and downloading ImageNet weights) only to fail in load_weights.
    if not os.path.exists(weights_path):
        print(f" Could not export {name}: checkpoint not found at {weights_path}")
        continue
    # Release Keras global state left by the previous model.
    tf.keras.backend.clear_session()
    try:
        model = build_model_wrapper(name, input_shape=input_shape, num_classes=num_classes)
        model.load_weights(weights_path)
        export_model(model, model_name=name, export_dir=f"exports/{name.lower()}")
    except Exception as e:
        print(f" Could not export {name}: {e}")
        logging.exception("Could not export %s", name)

# -----------------------------
#  Handle CELM (Stacking Ensemble)
# -----------------------------
print("\n Exporting CELM ensemble model...")
celm_weights_path = "results/celm/CELM_best_weights.h5"
if not os.path.exists(celm_weights_path):
    print(f" Could not export CELM: checkpoint not found at {celm_weights_path}")
else:
    tf.keras.backend.clear_session()
    try:
        celm_model = build_celm_model(input_shape=input_shape, num_classes=num_classes)
        celm_model.load_weights(celm_weights_path)
        export_model(celm_model, model_name="CELM", export_dir="exports/celm")
    except Exception as e:
        print(f" Could not export CELM: {e}")
        logging.exception("Could not export CELM")