In [None]:
# Kaggle-generated cell: fetches the Kaggle data source with kagglehub.
# Run it first; it can be deleted once the data is available locally.
# This notebook environment is not Kaggle's Python image, so some
# libraries used later in the notebook may be missing.
import kagglehub

imtkaggleteam_dental_radiography_path = kagglehub.dataset_download(
    'imtkaggleteam/dental-radiography'
)

print('Data source import complete.')


In [None]:
# -*- coding: utf-8 -*-
"""
Multi-Model Dental X-Ray Classification with Attention Mechanisms
EfficientNetB0 + SE/CBAM, MobileNetV2
Focal Loss + TTA + Comprehensive Evaluation
FINAL CORRECTED VERSION - IEEE ICME 2026
"""

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import cv2
import numpy as np
import pandas as pd
import warnings
import seaborn as sns
import matplotlib.pyplot as plt
import time
import json
warnings.filterwarnings('ignore')

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (confusion_matrix, classification_report, f1_score,
                            matthews_corrcoef, roc_auc_score, roc_curve, auc,
                            precision_recall_fscore_support)
from sklearn.utils.class_weight import compute_class_weight
from statsmodels.stats.contingency_tables import mcnemar
from scipy.stats import wilcoxon, norm
from collections import Counter

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.applications import (EfficientNetB0, MobileNetV2)
from tensorflow.keras.applications.efficientnet import preprocess_input as eff_preprocess
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mob_preprocess

print(f"TensorFlow Version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

# ---------- CONFIG ----------
IMG_SIZE: int = 224    # side length (px) every box crop is resized to
BATCH_SIZE: int = 32   # samples per training / validation batch
EPOCHS: int = 50       # maximum number of training epochs
SEED: int = 42         # global RNG seed
TTA_STEPS: int = 6     # default TTA passes, counting the un-augmented one

# Seed the NumPy and TensorFlow global RNGs to make runs more repeatable.
np.random.seed(SEED)
tf.random.set_seed(SEED)

# ---------- ATTENTION MODULES ----------
def se_block(input_tensor, reduction=16):
    """Squeeze-and-Excitation block.

    Global-average-pools the feature map to one value per channel, passes it
    through a two-layer bottleneck MLP (ReLU, then sigmoid) and rescales each
    input channel by the resulting gate in (0, 1).

    Args:
        input_tensor: 4-D feature map; the last axis is treated as channels.
        reduction: bottleneck ratio. The hidden layer has
            ``max(1, channels // reduction)`` units.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    channels = input_tensor.shape[-1]
    # Never build a zero-width Dense layer when channels < reduction.
    hidden_units = max(1, channels // reduction)
    se = layers.GlobalAveragePooling2D()(input_tensor)
    se = layers.Dense(hidden_units, activation='relu')(se)
    se = layers.Dense(channels, activation='sigmoid')(se)
    se = layers.Reshape((1, 1, channels))(se)
    return layers.Multiply()([input_tensor, se])

def cbam_block(input_tensor, reduction=16, spatial_kernel=7):
    """Convolutional Block Attention Module (channel attention, then spatial).

    Channel attention: average- and max-pooled channel descriptors go through
    a shared bottleneck MLP. The two logit vectors are summed and squashed by
    a single sigmoid, so each channel gate lies in (0, 1).

    Spatial attention: channel-wise mean and max maps are concatenated and
    convolved into a one-channel sigmoid gate applied at every position.

    Args:
        input_tensor: 4-D feature map; the last axis is treated as channels.
        reduction: bottleneck ratio for the shared MLP. The hidden layer has
            ``max(1, channels // reduction)`` units.
        spatial_kernel: kernel size of the spatial-attention convolution.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    channels = input_tensor.shape[-1]
    hidden_units = max(1, channels // reduction)

    # Channel attention
    avg_pool = layers.Reshape((1, 1, channels))(layers.GlobalAveragePooling2D()(input_tensor))
    max_pool = layers.Reshape((1, 1, channels))(layers.GlobalMaxPooling2D()(input_tensor))

    shared_dense1 = layers.Dense(hidden_units, activation='relu')
    # Linear output: the sigmoid is applied once, after the two branches are summed.
    shared_dense2 = layers.Dense(channels)

    avg_out = shared_dense2(shared_dense1(avg_pool))
    max_out = shared_dense2(shared_dense1(max_pool))

    channel_gate = layers.Activation('sigmoid')(layers.Add()([avg_out, max_out]))
    channel_refined = layers.Multiply()([input_tensor, channel_gate])

    # Spatial attention
    avg_pool_spatial = layers.Lambda(lambda x: tf.reduce_mean(x, axis=-1, keepdims=True))(channel_refined)
    max_pool_spatial = layers.Lambda(lambda x: tf.reduce_max(x, axis=-1, keepdims=True))(channel_refined)
    spatial_desc = layers.Concatenate(axis=-1)([avg_pool_spatial, max_pool_spatial])
    spatial_gate = layers.Conv2D(1, kernel_size=spatial_kernel, padding='same',
                                 activation='sigmoid')(spatial_desc)

    return layers.Multiply()([channel_refined, spatial_gate])

# ---------- MODEL BUILDERS ----------
def _classifier_head(features, n_classes):
    """Shared classification head used by every builder below.

    GAP -> BN -> Dropout(0.4) -> Dense(256, ReLU, L2=1e-4) -> BN
    -> Dropout(0.3) -> Dense(n_classes, softmax). Keeping it in one place
    guarantees that the backbone / attention choice is the only difference
    between the compared models.
    """
    x = layers.GlobalAveragePooling2D()(features)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(256, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    return layers.Dense(n_classes, activation='softmax')(x)

def _trainable_backbone(backbone_cls, img_size):
    """Instantiate an ImageNet-pretrained, top-less backbone with all layers trainable."""
    base = backbone_cls(include_top=False, weights='imagenet', input_shape=(img_size, img_size, 3))
    base.trainable = True
    return base

def build_efficientnet_vanilla(n_classes, img_size=IMG_SIZE):
    """EfficientNetB0 Baseline (No Attention)"""
    base = _trainable_backbone(EfficientNetB0, img_size)
    out = _classifier_head(base.output, n_classes)
    return models.Model(base.input, out, name='EfficientNetB0')

def build_efficientnet_se(n_classes, img_size=IMG_SIZE):
    """EfficientNetB0 + SE Attention on the final backbone feature map"""
    base = _trainable_backbone(EfficientNetB0, img_size)
    out = _classifier_head(se_block(base.output), n_classes)
    return models.Model(base.input, out, name='EfficientNetB0_SE')

def build_efficientnet_cbam(n_classes, img_size=IMG_SIZE):
    """EfficientNetB0 + CBAM Attention on the final backbone feature map"""
    base = _trainable_backbone(EfficientNetB0, img_size)
    out = _classifier_head(cbam_block(base.output), n_classes)
    return models.Model(base.input, out, name='EfficientNetB0_CBAM')

def build_mobilenet(n_classes, img_size=IMG_SIZE):
    """MobileNetV2 (Lightweight Baseline)"""
    base = _trainable_backbone(MobileNetV2, img_size)
    out = _classifier_head(base.output, n_classes)
    return models.Model(base.input, out, name='MobileNetV2')

# Model registry: display name -> (builder, matching Keras preprocess_input).
# Insertion order is the training order.
MODELS = dict(
    EfficientNetB0=(build_efficientnet_vanilla, eff_preprocess),
    EfficientNetB0_SE=(build_efficientnet_se, eff_preprocess),
    EfficientNetB0_CBAM=(build_efficientnet_cbam, eff_preprocess),
    MobileNetV2=(build_mobilenet, mob_preprocess),
)

# ---------- DATASET LOADER ----------
print("\nSetting up dataset...")

def _find_dataset_root(root):
    """Return the first directory at or below *root* (top-down walk) that
    contains ``train/_annotations.csv``, or None if there is none."""
    for dirpath, dirnames, _ in os.walk(root):
        dirnames.sort()  # deterministic traversal order
        if os.path.isfile(os.path.join(dirpath, 'train', '_annotations.csv')):
            return dirpath
    return None

possible_paths = [
    '/kaggle/input/dental-radiography',
    '/kaggle/input/dental-radiography/dataset',
    './dental-radiography'
]
dataset_path = None
for p in possible_paths:
    if os.path.exists(os.path.join(p, 'train')):
        dataset_path = p
        print(f"Found dataset at: {p}")
        break

if dataset_path is None:
    try:
        import kagglehub
        download_root = kagglehub.dataset_download('imtkaggleteam/dental-radiography')
        print(f"Downloaded dataset to: {download_root}")
    except Exception as e:
        print(f"Error downloading dataset: {e}")
        raise
    # The downloaded archive may nest the train/valid/test folders one or
    # more levels below the returned path; locate the actual split root.
    dataset_path = _find_dataset_root(download_root)
    if dataset_path is None:
        raise FileNotFoundError(
            f"No 'train/_annotations.csv' found under downloaded dataset at {download_root}"
        )
    print(f"Using dataset root: {dataset_path}")

# ---------- FOCAL LOSS ----------
class FocalLoss(tf.keras.losses.Loss):
    """Categorical focal loss for one-hot (or soft) targets.

    Per sample: ``-sum_c alpha * (1 - p_c)**gamma * y_c * log(p_c)``.

    Args:
        alpha: scalar weight applied to every class.
        gamma: focusing parameter; ``gamma=0`` gives alpha-scaled cross-entropy.
        name: Keras loss name.
        **kwargs: forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``),
            so configs produced by ``get_config`` round-trip.
    """

    def __init__(self, alpha=0.5, gamma=1.0, name='focal_loss', **kwargs):
        super().__init__(name=name, **kwargs)
        self.alpha, self.gamma = alpha, gamma

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Keep log() finite and the modulating base strictly positive.
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
        ce = -y_true * tf.math.log(y_pred)
        # y_true enters only through `ce`, so soft / smoothed labels are
        # weighted linearly rather than squared.
        modulating = self.alpha * tf.pow(1 - y_pred, self.gamma)
        return tf.reduce_sum(modulating * ce, axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({'alpha': self.alpha, 'gamma': self.gamma})
        return config

# ---------- DATA PIPELINE ----------
def load_ann(csv, iqr_factor=1.5):
    """Load box annotations and drop boxes with outlying areas (Tukey fences).

    Adds an ``area`` column, ``(xmax - xmin) * (ymax - ymin)``, and keeps rows
    whose area lies in ``[Q1 - k*IQR, Q3 + k*IQR]`` (inclusive), where
    ``k = iqr_factor`` and the quartiles are computed on this CSV alone.

    Args:
        csv: path or file-like object accepted by ``pandas.read_csv``; must
            have ``xmin``, ``ymin``, ``xmax``, ``ymax`` columns.
        iqr_factor: fence multiplier. 1.5 is the conventional outlier rule;
            0 keeps only the interquartile band (the previous behaviour).

    Returns:
        A filtered copy of the annotation DataFrame, including ``area``.
    """
    df = pd.read_csv(csv)
    df['area'] = (df['xmax'] - df['xmin']) * (df['ymax'] - df['ymin'])
    q25, q75 = df['area'].quantile([0.25, 0.75])
    margin = iqr_factor * (q75 - q25)
    keep = df['area'].between(q25 - margin, q75 + margin)  # inclusive both ends
    return df[keep].copy()

def load_imgs(df, folder):
    """Crop every annotated box and return model-ready arrays.

    For each annotation row, ``folder/filename`` is read, converted to
    grayscale, cropped to the (xmin, ymin, xmax, ymax) box (clipped to the
    image bounds), resized to IMG_SIZE x IMG_SIZE and replicated to three
    channels. Rows whose image is missing or unreadable, or whose box is
    empty after clipping, are skipped.

    Returns:
        (images, labels): a uint8 array of shape (N, IMG_SIZE, IMG_SIZE, 3)
        (also when N == 0) and an array with the N matching ``class`` values.
    """
    imgs, labs = [], []
    # Rows for the same image are usually adjacent; reuse the last decode.
    cached_path, cached_gray = None, None
    for _, r in df.iterrows():
        path = os.path.join(folder, r['filename'])
        if path != cached_path:
            img = cv2.imread(path) if os.path.exists(path) else None
            cached_path = path
            cached_gray = None if img is None else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        if cached_gray is None:
            continue
        height, width = cached_gray.shape[:2]
        # Clip to the image: a negative start index would otherwise wrap
        # around in numpy slicing and silently produce a wrong crop.
        x0, x1 = max(0, int(r['xmin'])), min(width, int(r['xmax']))
        y0, y1 = max(0, int(r['ymin'])), min(height, int(r['ymax']))
        if x1 <= x0 or y1 <= y0:
            continue
        crop = cached_gray[y0:y1, x0:x1]
        rgb = cv2.cvtColor(cv2.resize(crop, (IMG_SIZE, IMG_SIZE)), cv2.COLOR_GRAY2RGB)
        imgs.append(rgb)
        labs.append(r['class'])
    if not imgs:
        return np.empty((0, IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8), np.array(labs)
    return np.array(imgs), np.array(labs)

print("Loading data...")
train_df = load_ann(os.path.join(dataset_path, 'train/_annotations.csv'))
valid_df = load_ann(os.path.join(dataset_path, 'valid/_annotations.csv'))
test_df  = load_ann(os.path.join(dataset_path, 'test/_annotations.csv'))

X_train, y_train = load_imgs(train_df, os.path.join(dataset_path, 'train'))
X_valid, y_valid = load_imgs(valid_df, os.path.join(dataset_path, 'valid'))
X_test, y_test = load_imgs(test_df, os.path.join(dataset_path, 'test'))

print(f"Train: {len(X_train)} samples")
print(f"Valid: {len(X_valid)} samples")
print(f"Test:  {len(X_test)} samples")

if len(X_train) == 0:
    raise RuntimeError(
        f"No training crops could be loaded from {os.path.join(dataset_path, 'train')}"
    )

# Encode labels; the class set is learned from the training split only.
le = LabelEncoder()
y_train_enc = le.fit_transform(y_train)
y_valid_enc = le.transform(y_valid)
y_test_enc = le.transform(y_test)

# Fix the one-hot width to the number of classes, so a split that lacks the
# highest-index class still matches the model's softmax width.
n_label_classes = len(le.classes_)
y_train_cat = tf.keras.utils.to_categorical(y_train_enc, num_classes=n_label_classes)
y_valid_cat = tf.keras.utils.to_categorical(y_valid_enc, num_classes=n_label_classes)
y_test_cat = tf.keras.utils.to_categorical(y_test_enc, num_classes=n_label_classes)

print(f"Classes: {le.classes_}")

# CRITICAL: Compute class weights for imbalanced data ('balanced' = n / (k * count_c)).
train_class_ids = np.unique(y_train_enc)
class_weights = {
    int(c): float(w)
    for c, w in zip(train_class_ids,
                    compute_class_weight('balanced', classes=train_class_ids, y=y_train_enc))
}
print(f"Class weights: {class_weights}")

# ---------- TTA FUNCTION ----------
def tta_predict(model, X, steps=TTA_STEPS):
    """Average model outputs over one clean pass plus ``steps - 1`` jittered passes.

    The jitter is deliberately mild (rotation within +/-5 degrees, shifts up
    to 5% of width/height) and never flips horizontally, so the orientation
    of the dental anatomy is preserved.

    Args:
        model: Keras classifier.
        X: image batch; in this script callers pass it already preprocessed.
        steps: total number of forward passes, including the un-augmented one.

    Returns:
        (mean, std) over the passes, each shaped like ``model.predict(X)``.
    """
    passes = [model.predict(X, verbose=0)]
    jitter = ImageDataGenerator(
        rotation_range=5,
        width_shift_range=0.05,
        height_shift_range=0.05,
        horizontal_flip=False,
    )
    passes.extend(
        model.predict(next(jitter.flow(X, batch_size=len(X), shuffle=False)), verbose=0)
        for _ in range(steps - 1)
    )
    stacked = np.stack(passes)
    return stacked.mean(axis=0), stacked.std(axis=0)

# ---------- CREATE OUTPUT DIRECTORIES ----------
os.makedirs('figures', exist_ok=True)
os.makedirs('models', exist_ok=True)

# ---------- TRAIN & EVAL LOOP ----------
results = {}
histories = {}

n_classes = len(le.classes_)
class_ids = np.arange(n_classes)
# One-hot targets with an explicit width, so every split matches the
# n_classes-wide softmax even if a split lacks the highest-index class.
y_train_onehot = tf.keras.utils.to_categorical(y_train_enc, num_classes=n_classes)
y_valid_onehot = tf.keras.utils.to_categorical(y_valid_enc, num_classes=n_classes)
y_test_onehot = tf.keras.utils.to_categorical(y_test_enc, num_classes=n_classes)

for model_name, (build_fn, preprocess_fn) in MODELS.items():
    print(f"\n{'='*70}")
    print(f"Training {model_name}")
    print(f"{'='*70}")

    # Clear session to free memory
    tf.keras.backend.clear_session()

    # Build model
    model = build_fn(n_classes)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss=FocalLoss(alpha=0.5, gamma=1.0),
        metrics=['accuracy']
    )

    print(f"Model parameters: {model.count_params():,}")

    # Data generators (anatomy-aware augmentation - NO horizontal flip)
    train_datagen = ImageDataGenerator(
        rotation_range=15,       # Moderate rotation
        width_shift_range=0.1,   # Small shift
        height_shift_range=0.1,
        zoom_range=0.1,          # Small zoom
        horizontal_flip=False,   # CRITICAL: No flip for dental anatomy
        preprocessing_function=preprocess_fn
    )
    valid_datagen = ImageDataGenerator(preprocessing_function=preprocess_fn)

    train_gen = train_datagen.flow(
        X_train.astype('float32'), y_train_onehot,
        batch_size=BATCH_SIZE, shuffle=True, seed=SEED
    )
    valid_gen = valid_datagen.flow(
        X_valid.astype('float32'), y_valid_onehot,
        batch_size=BATCH_SIZE, shuffle=False
    )

    # Callbacks. The checkpoint uses the same criterion as EarlyStopping, so
    # the file reloaded later (Grad-CAM, TTA ablation) holds the very weights
    # evaluated below.
    ckpt_path = f'models/best_{model_name}.keras'
    cb = [
        EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-7, verbose=1),
        ModelCheckpoint(ckpt_path, monitor='val_loss', save_best_only=True, verbose=0)
    ]

    # Train with class weights (CRITICAL for imbalanced data!)
    start_time = time.time()
    history = model.fit(
        train_gen,
        validation_data=valid_gen,
        epochs=EPOCHS,
        class_weight=class_weights,  # CRITICAL: Handle class imbalance
        callbacks=cb,
        verbose=1,
        # len() of a flow iterator is ceil(samples / batch_size): every sample
        # is used each epoch and validation never ends up with 0 steps.
        steps_per_epoch=len(train_gen),
        validation_steps=len(valid_gen)
    )

    train_time = time.time() - start_time
    histories[model_name] = history.history

    # In some Keras versions restore_best_weights only takes effect when
    # EarlyStopping actually fires; reload the best checkpoint explicitly.
    if os.path.exists(ckpt_path):
        model.load_weights(ckpt_path)

    # Inference time measurement (batched, reported per image)
    X_test_prep = preprocess_fn(X_test.astype('float32'))
    n_timed = min(10, len(X_test_prep))
    timing_batch = X_test_prep[:n_timed]
    model.predict(timing_batch, verbose=0)  # warm-up: exclude first-call tracing/setup
    start_inf = time.perf_counter()
    model.predict(timing_batch, verbose=0)
    inf_time = (time.perf_counter() - start_inf) / max(n_timed, 1) * 1000  # ms per image

    params = model.count_params()

    # Standard predictions
    print("\nRunning standard inference...")
    y_pred_std = model.predict(X_test_prep, verbose=0)
    y_pred_std_cls = np.argmax(y_pred_std, axis=1)

    # TTA predictions (second output is the per-class std across TTA passes)
    print("Running TTA inference...")
    y_pred_tta, y_pred_tta_std = tta_predict(model, X_test_prep)
    y_pred_tta_cls = np.argmax(y_pred_tta, axis=1)

    # Compute metrics
    acc_std = np.mean(y_test_enc == y_pred_std_cls)
    f1_macro_std = f1_score(y_test_enc, y_pred_std_cls, average='macro')
    f1_weighted_std = f1_score(y_test_enc, y_pred_std_cls, average='weighted')
    mcc_std = matthews_corrcoef(y_test_enc, y_pred_std_cls)

    acc_tta = np.mean(y_test_enc == y_pred_tta_cls)
    f1_macro_tta = f1_score(y_test_enc, y_pred_tta_cls, average='macro')
    f1_weighted_tta = f1_score(y_test_enc, y_pred_tta_cls, average='weighted')
    mcc_tta = matthews_corrcoef(y_test_enc, y_pred_tta_cls)

    # AUC-ROC
    try:
        auc_std = roc_auc_score(y_test_onehot, y_pred_std, multi_class='ovr', average='macro')
        auc_tta = roc_auc_score(y_test_onehot, y_pred_tta, multi_class='ovr', average='macro')
    except Exception as e:
        print(f"AUC calculation failed: {e}")
        # NaN rather than 0.0, so an undefined AUC is not read as a worst-case score.
        auc_std = float('nan')
        auc_tta = float('nan')

    # Confusion matrices: always n_classes x n_classes, in le.classes_ order.
    cm_std = confusion_matrix(y_test_enc, y_pred_std_cls, labels=class_ids)
    cm_tta = confusion_matrix(y_test_enc, y_pred_tta_cls, labels=class_ids)

    # Store results
    results[model_name] = {
        'params': params,
        'train_time': train_time,
        'inf_time': inf_time,
        'std': {
            'acc': acc_std,
            'f1_macro': f1_macro_std,
            'f1_weighted': f1_weighted_std,
            'mcc': mcc_std,
            'auc': auc_std,
            'cm': cm_std
        },
        'tta': {
            'acc': acc_tta,
            'f1_macro': f1_macro_tta,
            'f1_weighted': f1_weighted_tta,
            'mcc': mcc_tta,
            'auc': auc_tta,
            'cm': cm_tta,
            'proba': y_pred_tta,
            'var': y_pred_tta_std  # key kept for compatibility; holds the std across passes
        },
        'predictions': {
            'std_cls': y_pred_std_cls,
            'tta_cls': y_pred_tta_cls
        }
    }

    # Print results
    print(f"\n{model_name} Results:")
    print(f"  Parameters:    {params:,}")
    print(f"  Train Time:    {train_time:.1f}s ({train_time/60:.1f} min)")
    print(f"  Inference:     {inf_time:.2f}ms/image")
    print(f"  Standard:      Acc={acc_std:.4f}, F1={f1_macro_std:.4f}, AUC={auc_std:.4f}")
    print(f"  TTA:           Acc={acc_tta:.4f}, F1={f1_macro_tta:.4f}, AUC={auc_tta:.4f}")
    print(f"  Improvement:   {(acc_tta-acc_std)*100:+.2f} pp accuracy")

# ---------- ANALYSIS & VISUALIZATION ----------
model_names = list(results.keys())
print(f"\n{'='*70}")
print("GENERATING VISUALIZATIONS AND ANALYSIS")
print(f"{'='*70}")

# 1. Confusion Matrices (two panels per row; the grid grows with the model count)
print("\n1. Plotting confusion matrices...")
n_grid_cols = 2
n_grid_rows = max(1, -(-len(model_names) // n_grid_cols))  # ceiling division
fig, axes = plt.subplots(n_grid_rows, n_grid_cols,
                         figsize=(7 * n_grid_cols, 6 * n_grid_rows), squeeze=False)
axes = axes.ravel()
for ax, model_name in zip(axes, model_names):
    cm = results[model_name]['tta']['cm']
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=le.classes_, yticklabels=le.classes_, ax=ax)
    ax.set_title(f'{model_name} (TTA)', fontsize=12, fontweight='bold')
    ax.set_xlabel('Predicted', fontsize=10)
    ax.set_ylabel('True', fontsize=10)
for ax in axes[len(model_names):]:
    ax.axis('off')  # hide unused panels when the model count is odd
plt.tight_layout()
plt.savefig('figures/confusion_matrices.png', dpi=300, bbox_inches='tight')
plt.savefig('figures/confusion_matrices.pdf', bbox_inches='tight')
plt.close()
print("   ✓ Saved: confusion_matrices.png/pdf")

# 2. Training Curves (two panels per row; the grid grows with the model count)
print("2. Plotting training curves...")
n_curve_cols = 2
n_curve_rows = max(1, -(-len(model_names) // n_curve_cols))  # ceiling division
fig = plt.figure(figsize=(7 * n_curve_cols, 5 * n_curve_rows))
for i, model_name in enumerate(model_names):
    plt.subplot(n_curve_rows, n_curve_cols, i + 1)
    h = histories[model_name]
    epochs_range = range(1, len(h['accuracy']) + 1)  # 1-based epoch numbers
    plt.plot(epochs_range, h['accuracy'], 'b-', label='Train Acc', linewidth=2)
    plt.plot(epochs_range, h['val_accuracy'], 'r-', label='Val Acc', linewidth=2)
    plt.title(f'{model_name}', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch', fontsize=10)
    plt.ylabel('Accuracy', fontsize=10)
    plt.legend(fontsize=9)
    plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('figures/training_curves.png', dpi=300, bbox_inches='tight')
plt.savefig('figures/training_curves.pdf', bbox_inches='tight')
plt.close()
print("   ✓ Saved: training_curves.png/pdf")

# 3. Model Comparison Bar Charts: one grouped chart per metric (Standard vs TTA)
print("3. Plotting model comparisons...")
bar_width = 0.35
bar_positions = np.arange(len(model_names))
bar_variants = (('std', 'Standard', -bar_width / 2), ('tta', 'TTA', bar_width / 2))
for metric in ('acc', 'f1_macro', 'auc'):
    fig, ax = plt.subplots(figsize=(10, 6))
    bar_groups = [
        ax.bar(bar_positions + offset,
               [results[m][mode][metric] for m in model_names],
               bar_width, label=label, alpha=0.8)
        for mode, label, offset in bar_variants
    ]

    metric_title = metric.upper()
    ax.set_xlabel('Models', fontsize=12, fontweight='bold')
    ax.set_ylabel(metric_title, fontsize=12, fontweight='bold')
    ax.set_title(f'{metric_title} Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(model_names, rotation=45, ha='right')
    ax.legend(fontsize=11)
    ax.grid(axis='y', alpha=0.3)

    # Print each bar's value on top of it
    for bar in (b for group in bar_groups for b in group):
        bar_height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., bar_height,
                f'{bar_height:.3f}', ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    for extension, extra in (('png', {'dpi': 300}), ('pdf', {})):
        plt.savefig(f'figures/comparison_{metric}.{extension}', bbox_inches='tight', **extra)
    plt.close()
print("   ✓ Saved: comparison_*.png/pdf")

# 4. Select Best Model: highest TTA test accuracy; ties go to the model trained first.
# NOTE(review): selecting on test-set accuracy biases the reported best score
# upward; consider selecting on validation accuracy instead.
tta_accuracy = {name: results[name]['tta']['acc'] for name in results}
best_model_name = max(tta_accuracy, key=tta_accuracy.get)
print(f"\n4. Best Model: {best_model_name} (TTA Acc={tta_accuracy[best_model_name]:.4f})")

# 5. Per-Class Performance (Best Model)
print("5. Plotting per-class performance...")
# Pin the label set so each array has exactly one entry per class, in
# le.classes_ order, even when a class is absent from the test labels and
# predictions. Undefined ratios (no predictions / no support) become 0.
precision, recall, f1_per_class, _ = precision_recall_fscore_support(
    y_test_enc, results[best_model_name]['predictions']['tta_cls'],
    labels=np.arange(len(le.classes_)), average=None, zero_division=0
)

fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(le.classes_))
width = 0.25

bars1 = ax.bar(x - width, precision, width, label='Precision', alpha=0.8)
bars2 = ax.bar(x, recall, width, label='Recall', alpha=0.8)
bars3 = ax.bar(x + width, f1_per_class, width, label='F1-Score', alpha=0.8)

ax.set_xlabel('Classes', fontsize=12, fontweight='bold')
ax.set_ylabel('Score', fontsize=12, fontweight='bold')
ax.set_title(f'Per-Class Performance - {best_model_name} (TTA)', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(le.classes_, rotation=45, ha='right')
ax.legend(fontsize=11)
ax.set_ylim([0, 1.1])
ax.grid(axis='y', alpha=0.3)

# Annotate every bar with its value
for bars in [bars1, bars2, bars3]:
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2f}', ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.savefig('figures/per_class_performance.png', dpi=300, bbox_inches='tight')
plt.savefig('figures/per_class_performance.pdf', bbox_inches='tight')
plt.close()
print("   ✓ Saved: per_class_performance.png/pdf")

# 6. ROC Curves (Best Model, one-vs-rest per class)
print("6. Plotting ROC curves...")
y_proba = results[best_model_name]['tta']['proba']
n_roc_classes = len(le.classes_)
# Indicator matrix built from the encoded labels: always one column per class.
y_true_onehot = np.eye(n_roc_classes, dtype=int)[y_test_enc]

fig, ax = plt.subplots(figsize=(10, 8))
colors = ['blue', 'red', 'green', 'orange']
fallback_cmap = plt.get_cmap('tab10')  # used only beyond the first four classes
for i, cls in enumerate(le.classes_):
    positives = y_true_onehot[:, i]
    if np.unique(positives).size < 2:
        # A ROC curve needs both positive and negative samples.
        print(f"   Skipping ROC for {cls}: test set lacks positives or negatives for this class")
        continue
    fpr, tpr, _ = roc_curve(positives, y_proba[:, i])
    roc_auc = auc(fpr, tpr)
    color = colors[i] if i < len(colors) else fallback_cmap(i % 10)
    ax.plot(fpr, tpr, color=color, lw=2,
            label=f'{cls} (AUC = {roc_auc:.3f})')

ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random (AUC = 0.5)')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='bold')
ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='bold')
ax.set_title(f'ROC Curves - {best_model_name} (TTA)', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('figures/roc_curves.png', dpi=300, bbox_inches='tight')
plt.savefig('figures/roc_curves.pdf', bbox_inches='tight')
plt.close()
print("   ✓ Saved: roc_curves.png/pdf")

# ---------- STATISTICAL TESTING ----------
print(f"\n{'='*70}")
print("STATISTICAL TESTING")
print(f"{'='*70}")

# McNemar's Test
print("\nMcNemar's Test (Pairwise Model Comparison):")
from itertools import combinations

# Below this many discordant pairs, the chi-square approximation is unreliable
# (common rule of thumb), so the exact binomial version is used instead. With
# zero discordant pairs, the chi-square statistic would be 0/0.
MCNEMAR_EXACT_BELOW = 25

mcnemar_results = {}
for m1, m2 in combinations(model_names, 2):
    y1_correct = results[m1]['predictions']['tta_cls'] == y_test_enc
    y2_correct = results[m2]['predictions']['tta_cls'] == y_test_enc

    n_11 = int(np.sum(y1_correct & y2_correct))     # both correct
    n_10 = int(np.sum(y1_correct & ~y2_correct))    # only m1 correct
    n_01 = int(np.sum(~y1_correct & y2_correct))    # only m2 correct
    n_00 = int(np.sum(~y1_correct & ~y2_correct))   # both wrong
    # Rows: m1 correct / wrong; columns: m2 correct / wrong.
    table = [[n_11, n_10], [n_01, n_00]]

    n_discordant = n_01 + n_10
    use_exact = n_discordant < MCNEMAR_EXACT_BELOW
    mc_test = mcnemar(table, exact=use_exact, correction=True)
    mcnemar_results[f"{m1} vs {m2}"] = {
        'statistic': mc_test.statistic,
        'p_value': mc_test.pvalue,
        'exact': use_exact,
        'discordant': n_discordant
    }
    sig = "✓ Significant" if mc_test.pvalue < 0.05 else "✗ Not significant"
    test_kind = "exact" if use_exact else "chi2"
    print(f"  {m1} vs {m2}: p={mc_test.pvalue:.4f} ({sig}; {test_kind}, discordant={n_discordant})")

# Wilcoxon Test (Standard vs TTA), run on paired 0/1 correctness vectors.
# NOTE(review): the signed-rank test drops zero differences (both right or
# both wrong), so on binary data it reduces to a sign test on the discordant
# samples; McNemar's test is the conventional choice for this comparison.
print("\nWilcoxon Test (Standard vs TTA):")
wilcoxon_results = {}
for model_name in model_names:
    pred_sets = results[model_name]['predictions']
    std_correct = (pred_sets['std_cls'] == y_test_enc).astype(int)
    tta_correct = (pred_sets['tta_cls'] == y_test_enc).astype(int)

    try:
        outcome = wilcoxon(std_correct, tta_correct)
    except Exception as e:
        print(f"  {model_name}: Test failed ({e})")
        continue

    wilcoxon_results[model_name] = {
        'statistic': outcome.statistic,
        'p_value': outcome.pvalue
    }
    verdict = "✓ Significant" if outcome.pvalue < 0.05 else "✗ Not significant"
    print(f"  {model_name}: p={outcome.pvalue:.4f} ({verdict})")

# 95% Confidence Intervals
print("\n95% Confidence Intervals:")
def confidence_interval(p, n, alpha=0.05, method='wilson'):
    """Two-sided (1 - alpha) confidence interval for a binomial proportion.

    Args:
        p: observed proportion (e.g. accuracy), in [0, 1].
        n: number of trials (test samples); must be positive.
        alpha: significance level; 0.05 gives a 95% interval.
        method: 'wilson' (default) for the Wilson score interval, or 'wald'
            for the normal approximation ``p +/- z * sqrt(p * (1 - p) / n)``.

    Returns:
        (lower, upper) as floats, clipped to [0, 1].

    Raises:
        ValueError: if ``n <= 0`` or ``method`` is not recognised.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    z = norm.ppf(1 - alpha / 2)
    if method == 'wald':
        se = np.sqrt(p * (1 - p) / n)
        lower, upper = p - z * se, p + z * se
    elif method == 'wilson':
        # Unlike Wald, Wilson stays inside [0, 1] and keeps a non-zero width
        # at p = 0 or 1, which matters for high accuracies on small test sets.
        denom = 1 + z ** 2 / n
        center = (p + z ** 2 / (2 * n)) / denom
        half_width = z * np.sqrt(p * (1 - p) / n + z ** 2 / (4 * n ** 2)) / denom
        lower, upper = center - half_width, center + half_width
    else:
        raise ValueError(f"unknown method {method!r}; expected 'wilson' or 'wald'")
    return float(max(0.0, lower)), float(min(1.0, upper))

n = len(y_test_enc)
ci_results = {}
for model_name in model_names:
    # Interval for the standard accuracy first, then for the TTA accuracy.
    intervals = {
        f'{mode}_ci': confidence_interval(results[model_name][mode]['acc'], n)
        for mode in ('std', 'tta')
    }
    ci_results[model_name] = intervals
    print(f"  {model_name}:")
    for caption, key in (("Standard:", 'std_ci'), ("TTA:     ", 'tta_ci')):
        low, high = intervals[key]
        print(f"    {caption} [{low:.4f}, {high:.4f}]")

# ---------- GRAD-CAM VISUALIZATION ----------
print(f"\n{'='*70}")
print("GRAD-CAM VISUALIZATION")
print(f"{'='*70}")

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for the first image in ``img_array``.

    Args:
        img_array: input batch in the form the model expects (this script
            passes a single preprocessed image with a batch axis).
        model: Keras functional model that contains ``last_conv_layer_name``.
        last_conv_layer_name: name of a layer with a 4-D (spatial) output.
        pred_index: class to explain; defaults to the top-scoring class.

    Returns:
        2-D numpy array in [0, 1] at the chosen layer's spatial resolution.

    Raises:
        ValueError: if the class score has no gradient w.r.t. that layer.
    """
    grad_model = tf.keras.models.Model(
        model.input, [model.get_layer(last_conv_layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    if grads is None:
        raise ValueError(
            f"Layer '{last_conv_layer_name}' does not influence the model output; "
            "choose another Grad-CAM layer."
        )
    # Channel importance = spatially averaged gradient.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    heatmap = last_conv_layer_output[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.maximum(tf.squeeze(heatmap, axis=-1), 0)  # keep positive evidence only
    heatmap = heatmap / tf.maximum(tf.reduce_max(heatmap), 1e-10)
    return heatmap.numpy()

def superimpose_heatmap(img, heatmap, alpha=0.4):
    """Overlay a heatmap on an image using the 'jet' colormap.

    Args:
        img: H x W x C image array (this script passes a uint8 RGB test crop).
        heatmap: 2-D array expected in [0, 1]; values outside are clipped.
        alpha: weight of the coloured heatmap before it is added to ``img``.

    Returns:
        PIL image built by ``array_to_img`` from the blended array.
    """
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    # plt.get_cmap works across Matplotlib versions; plt.cm.get_cmap was
    # removed in Matplotlib 3.9.
    jet_colors = plt.get_cmap("jet")(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]
    # Upsample the coarse map to the image's (width, height).
    jet_heatmap = tf.keras.preprocessing.image.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = tf.keras.preprocessing.image.img_to_array(jet_heatmap)
    superimposed_img = jet_heatmap * alpha + img
    return tf.keras.preprocessing.image.array_to_img(superimposed_img)

print(f"Generating Grad-CAM for {best_model_name}...")
try:
    # Rebuild the best architecture and load its saved checkpoint weights.
    best_model_fn, best_preprocess = MODELS[best_model_name]
    viz_model = best_model_fn(len(le.classes_))
    viz_model.load_weights(f'models/best_{best_model_name}.keras')

    # Grad-CAM needs a spatial feature map. Take the last layer whose output
    # is 4-D with a spatial extent larger than 1x1, i.e. the map right before
    # global pooling. A name search for 'conv' would instead hit the
    # 1-channel Conv2D inside the CBAM spatial-attention branch.
    target_layer = None
    for layer in reversed(viz_model.layers):
        try:
            out_shape = tuple(layer.output.shape)
        except (AttributeError, ValueError):
            continue  # layer without a single well-defined output tensor
        if len(out_shape) == 4 and (out_shape[1] != 1 or out_shape[2] != 1):
            target_layer = layer.name
            break

    if target_layer is None:
        print("  Warning: Could not find convolutional layer")
        target_layer = viz_model.layers[-10].name  # Fallback

    print(f"  Using layer: {target_layer}")

    # One panel per class, using the first test sample of that class.
    n_cls = len(le.classes_)
    tta_cls = results[best_model_name]['predictions']['tta_cls']
    fig, axes = plt.subplots(1, n_cls, figsize=(4 * n_cls, 4), squeeze=False)
    for i, (ax, cls) in enumerate(zip(axes[0], le.classes_)):
        ax.axis('off')
        class_indices = np.flatnonzero(y_test_enc == i)
        if class_indices.size == 0:
            ax.set_title(f'{cls}\n(No samples)', fontsize=10)
            continue
        idx = class_indices[0]
        img = X_test[idx]
        img_prep = best_preprocess(np.expand_dims(img.astype('float32'), axis=0))

        heatmap = make_gradcam_heatmap(img_prep, viz_model, target_layer)
        ax.imshow(superimpose_heatmap(img, heatmap))
        ax.set_title(f'True: {cls}\nPred: {le.classes_[tta_cls[idx]]}', fontsize=10)

    plt.tight_layout()
    plt.savefig('figures/gradcam.png', dpi=300, bbox_inches='tight')
    plt.savefig('figures/gradcam.pdf', bbox_inches='tight')
    plt.close()
    print("   ✓ Saved: gradcam.png/pdf")

except Exception as e:
    print(f"  Grad-CAM generation failed: {e}")
    plt.close('all')  # do not leak a half-drawn figure into later plots

# ---------- TTA ABLATION STUDY ----------
print(f"\n{'='*70}")
print("TTA ABLATION STUDY")
print(f"{'='*70}")

print(f"Running ablation on {best_model_name}...")
try:
    # Fresh copy of the best model, restored from its checkpoint.
    build_best, preprocess_best = MODELS[best_model_name]
    ablation_model = build_best(len(le.classes_))
    ablation_model.load_weights(f'models/best_{best_model_name}.keras')
    X_test_prep = preprocess_best(X_test.astype('float32'))

    steps_list = [3, 6, 10, 15]
    accs, f1s = [], []
    for steps in steps_list:
        print(f"  Testing {steps} TTA steps...")
        mean_proba, _ = tta_predict(ablation_model, X_test_prep, steps=steps)
        predicted = np.argmax(mean_proba, axis=1)
        acc = np.mean(y_test_enc == predicted)
        f1 = f1_score(y_test_enc, predicted, average='macro')
        accs.append(acc)
        f1s.append(f1)
        print(f"    Acc={acc:.4f}, F1={f1:.4f}")

    # Side-by-side line plots: accuracy (left) and macro-F1 (right) vs steps
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (ax1, accs, {'marker': 'o'}, 'Accuracy', 'TTA Steps vs Accuracy'),
        (ax2, f1s, {'marker': 's', 'color': 'red'}, 'Macro-F1', 'TTA Steps vs Macro-F1'),
    )
    for ax, values, line_style, y_label, title in panels:
        ax.plot(steps_list, values, linewidth=2, markersize=8, **line_style)
        ax.set_xlabel('TTA Steps', fontsize=12, fontweight='bold')
        ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
        ax.set_title(f'{title}\n({best_model_name})', fontsize=14, fontweight='bold')
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('figures/tta_ablation.png', dpi=300, bbox_inches='tight')
    plt.savefig('figures/tta_ablation.pdf', bbox_inches='tight')
    plt.close()
    print("   ✓ Saved: tta_ablation.png/pdf")

except Exception as e:
    print(f"  TTA ablation failed: {e}")

# ---------- FAILURE ANALYSIS ----------
print(f"\n{'='*70}")
print("FAILURE ANALYSIS")
print(f"{'='*70}")

print("\nMisclassification counts:")
for model_name in model_names:
    wrong_mask = results[model_name]['predictions']['tta_cls'] != y_test_enc
    mis_count = np.sum(wrong_mask)
    print(f"  {model_name}: {mis_count} errors ({mis_count/len(y_test_enc)*100:.1f}%)")

print("\nMost common errors (True → Pred):")
# Pool the (true, predicted) pairs of every misclassified sample across all models.
error_counts = Counter(
    (true_label, pred_label)
    for model_name in model_names
    for true_label, pred_label in zip(y_test_enc, results[model_name]['predictions']['tta_cls'])
    if pred_label != true_label
)
for (true_label, pred_label), count in error_counts.most_common(10):
    print(f"  {le.classes_[true_label]} → {le.classes_[pred_label]}: {count} times")

# ---------- LATEX TABLES ----------
print(f"\n{'='*70}")
print("LATEX TABLES FOR PAPER")
print(f"{'='*70}")

# Dataset Statistics.
# value_counts() is indexed by the raw string labels, so counts are looked up
# by class name. An integer key relies on pandas' deprecated positional
# fallback and misaligns rows when a split is missing a class.
train_counts = pd.Series(y_train).value_counts().sort_index()
val_counts = pd.Series(y_valid).value_counts().sort_index()
test_counts = pd.Series(y_test).value_counts().sort_index()

print("\n--- Dataset Statistics Table ---")
print("\\begin{table}[htbp]")
print("\\centering")
print("\\caption{Dataset Statistics}")
print("\\begin{tabular}{|l|r|r|r|r|}")
print("\\hline")
print("\\textbf{Class} & \\textbf{Train} & \\textbf{Val} & \\textbf{Test} & \\textbf{Total} \\\\ \\hline")
for cls in le.classes_:
    n_tr = int(train_counts.get(cls, 0))
    n_va = int(val_counts.get(cls, 0))
    n_te = int(test_counts.get(cls, 0))
    print(f"{cls} & {n_tr} & {n_va} & {n_te} & {n_tr + n_va + n_te} \\\\ \\hline")
print(f"\\textbf{{Total}} & {len(X_train)} & {len(X_valid)} & {len(X_test)} & {len(X_train)+len(X_valid)+len(X_test)} \\\\ \\hline")
print("\\end{tabular}")
print("\\end{table}")

# Model Comparison
print("\n--- Model Comparison Table ---")
print("\\begin{table*}[htbp]")
print("\\centering")
print("\\caption{Model Performance Comparison}")
print("\\begin{tabular}{|l|r|r|c|c|c|c|}")
print("\\hline")
print("\\textbf{Model} & \\textbf{Params} & \\textbf{Inf (ms)} & \\textbf{Std Acc} & \\textbf{TTA Acc} & \\textbf{TTA F1} & \\textbf{AUC} \\\\ \\hline")
for model_name in model_names:
    res = results[model_name]
    params_str = f"{res['params']:,}"
    print(f"{model_name.replace('_', ' ')} & {params_str} & {res['inf_time']:.1f} & {res['std']['acc']:.4f} & {res['tta']['acc']:.4f} & {res['tta']['f1_macro']:.4f} & {res['tta']['auc']:.4f} \\\\ \\hline")
print("\\end{tabular}")
print("\\end{table*}")

# Per-Class Metrics
# LaTeX table of precision / recall / F1 for the best model, one row per
# class in `le.classes_` order. `precision`, `recall` and `f1_per_class`
# come from the evaluation code above this chunk (presumably indexed by
# encoded class id — verify against where they are computed).
print("\n--- Per-Class Metrics Table ---", "\\begin{table}[htbp]", "\\centering", sep="\n")
print(
    f"\\caption{{Per-Class Performance ({best_model_name.replace('_', ' ')})}}",
    "\\begin{tabular}{|l|c|c|c|}",
    "\\hline",
    "\\textbf{Class} & \\textbf{Precision} & \\textbf{Recall} & \\textbf{F1-Score} \\\\ \\hline",
    sep="\n",
)
for i, cls in enumerate(le.classes_):
    cells = (f"{cls}", f"{precision[i]:.4f}", f"{recall[i]:.4f}", f"{f1_per_class[i]:.4f}")
    print(" & ".join(cells) + " \\\\ \\hline")
print("\\end{tabular}", "\\end{table}", sep="\n")

# Attention Ablation
# Compares the plain EfficientNetB0 backbone with its SE and CBAM variants
# on the TTA metrics; variants missing from `results` are skipped.
# Loop names are deliberately specific. A bare `model` would overwrite a
# module-level name that is presumably the Keras model from the training
# code (NOTE(review): confirm). A bare `res` would be rebound to the 'tta'
# sub-dict, whereas elsewhere `res` holds a model's full result entry.
attention_models = ['EfficientNetB0', 'EfficientNetB0_SE', 'EfficientNetB0_CBAM']
print("\n--- Attention Ablation Table ---")
print("\\begin{table}[htbp]")
print("\\centering")
print("\\caption{Attention Mechanism Ablation}")
print("\\begin{tabular}{|l|c|c|c|}")
print("\\hline")
print("\\textbf{Model} & \\textbf{TTA Acc} & \\textbf{Macro-F1} & \\textbf{AUC} \\\\ \\hline")
for attn_name in attention_models:
    if attn_name not in results:
        continue
    attn_tta = results[attn_name]['tta']
    print(f"{attn_name.replace('_', ' ')} & {attn_tta['acc']:.4f} & {attn_tta['f1_macro']:.4f} & {attn_tta['auc']:.4f} \\\\ \\hline")
print("\\end{tabular}")
print("\\end{table}")

# ---------- SAVE RESULTS ----------
print(f"\n{'='*70}")
print("SAVING RESULTS")
print(f"{'='*70}")


def _to_jsonable(value):
    """Recursively convert a metric value into a type ``json.dump`` accepts.

    numpy arrays become (nested) lists via ``tolist()``. Dicts, lists and
    tuples are converted element-wise, with dict keys stringified. Strings
    and ``None`` pass through unchanged. Any other scalar (Python or numpy
    number, bool) becomes a ``float``, exactly as the previous inline
    conversion did. The old code raised ``TypeError`` on lists, dicts and
    ``None``.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, dict):
        return {str(k): _to_jsonable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(v) for v in value]
    if value is None or isinstance(value, str):
        return value
    return float(value)


# Save JSON
results_for_json = {}
for model_name, res in results.items():
    results_for_json[model_name] = {
        'params': int(res['params']),
        'train_time': float(res['train_time']),
        'inf_time': float(res['inf_time']),
        'std': _to_jsonable(res['std']),
        'tta': _to_jsonable(res['tta']),
    }

with open('results.json', 'w', encoding='utf-8') as f:
    json.dump(results_for_json, f, indent=2)
print("   ✓ Saved: results.json")

# Summary


def _count_output_files(directory, suffixes, prefix=''):
    """Count files in *directory* whose names start with *prefix* and end with
    one of *suffixes*.

    Returns 0 when the directory does not exist, so the summary reports what
    was actually written to disk instead of a hard-coded figure.
    """
    if not os.path.isdir(directory):
        return 0
    return sum(
        1 for fname in os.listdir(directory)
        if fname.startswith(prefix) and fname.endswith(suffixes)
    )


print(f"\n{'='*70}")
print("EXPERIMENT COMPLETE!")
print(f"{'='*70}")
best_tta = results[best_model_name]['tta']
print(f"\nBest Model: {best_model_name}")
print(f"  TTA Accuracy: {best_tta['acc']:.4f}")
print(f"  TTA Macro-F1: {best_tta['f1_macro']:.4f}")
print(f"  TTA AUC:      {best_tta['auc']:.4f}")

# Paths are relative to the working directory, as the file list printed
# previously assumed.
n_model_files = _count_output_files('models', ('.keras',), prefix='best_')
n_figure_files = _count_output_files('figures', ('.png', '.pdf'))
print("\nGenerated Files:")
print(f"  Models:  models/best_*.keras ({n_model_files} files)")
print(f"  Figures: figures/*.png & *.pdf ({n_figure_files} files)")
print("  Data:    results.json")

print("\n✓ All analyses completed successfully!")

E0000 00:00:1766950385.675236      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766950385.730941      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1766950386.173260      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766950386.173309      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766950386.173312      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766950386.173314      55 computation_placer.cc:177] computation placer already registered. Please check linka

TensorFlow Version: 2.19.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

Setting up dataset...
Found dataset at: /kaggle/input/dental-radiography
Loading data...
Train: 4023 samples
Valid: 392 samples
Test:  237 samples
Classes: ['Cavity' 'Fillings' 'Impacted Tooth' 'Implant']
Class weights: {0: np.float64(5.184278350515464), 1: np.float64(0.3822690992018244), 2: np.float64(3.444349315068493), 3: np.float64(1.1100993377483444)}

Training EfficientNetB0


I0000 00:00:1766950417.455670      55 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
[1m16705208/16705208[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Model parameters: 4,384,679
Epoch 1/50


I0000 00:00:1766950454.080289     125 service.cc:152] XLA service 0x7c6068004230 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1766950454.080337     125 service.cc:160]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1766950459.956711     125 cuda_dnn.cc:529] Loaded cuDNN version 91002
I0000 00:00:1766950495.635020     125 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m159s[0m 671ms/step - accuracy: 0.3321 - loss: 0.7050 - val_accuracy: 0.1953 - val_loss: 0.9066 - learning_rate: 1.0000e-04
Epoch 2/50
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 11ms/step - accuracy: 0.4688 - loss: 0.4560 - val_accuracy: 0.1979 - val_loss: 0.9086 - learning_rate: 1.0000e-04
Epoch 3/50
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 298ms/step - accuracy: 0.5537 - loss: 0.3525 - val_accuracy: 0.2943 - val_loss: 0.6367 - learning_rate: 1.0000e-04
Epoch 4/50
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.7188 - loss: 0.2977 - val_accuracy: 0.2917 - val_loss: 0.6342 - learning_rate: 1.0000e-04
Epoch 5/50
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 301ms/step - accuracy: 0.6634 - loss: 0.2804 - val_accuracy: 0.8255 - val_loss: 0.2054 - learning_rate: 1.0000e-04
Epoch 6/50
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━