In [None]:
# --- 1. CONFIGURATION ET IMPORTS ---
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import roc_curve, auc
import sys
import os

# Ajouter le dossier parent pour accéder aux modules models/ et data/
sys.path.append('..')

from models.had_net import HybridAnomalyNet
from models.baseline_resnet import DeepOnlyAnomalyNet
from data.dataset import HybridAnomalyDataset
from utils.trainer import init_center

# Pick the first CUDA GPU visible to PyTorch when there is one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- 2. EVALUATION FUNCTION ---
def get_anomaly_scores(model, loader, center, device, is_hybrid=True):
    """Compute the anomaly score of every sample yielded by ``loader``.

    The score of a sample is the squared Euclidean distance between the model
    output and ``center``, summed over dimension 1 of the output (presumably
    outputs are ``(batch, dim)`` — confirm against the model definitions).

    Args:
        model: network to evaluate; it is switched to eval mode. It is called
            as ``model(images, features)`` when ``is_hybrid`` is true, and as
            ``model(images)`` otherwise.
        loader: iterable of ``(images, features, labels)`` batches.
        center: reference point in the model's output space (tensor, NumPy
            array or sequence). It is moved to ``device`` before use so it can
            be compared with the outputs.
        device: device on which inference runs.
        is_hybrid: whether the model also consumes the handcrafted features.

    Returns:
        ``(scores, labels)`` as NumPy arrays, concatenated batch after batch in
        loader order. Both are empty arrays when ``loader`` yields no batch.
    """
    model.eval()
    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        # The center may have been saved/computed on another device (or be a
        # NumPy array / list); align it with the outputs to avoid a mismatch.
        center = torch.as_tensor(center, device=device)

        for images, features, lbl in loader:
            images = images.to(device)

            # Forward pass depending on the model type.
            if is_hybrid:
                outputs = model(images, features.to(device))
            else:
                outputs = model(images)

            # Distance to the center = anomaly score.
            dists = torch.sum((outputs - center) ** 2, dim=1)

            # Keep one array per batch; concatenated once at the end.
            score_chunks.append(dists.cpu().numpy())
            label_chunks.append(np.asarray(lbl))

    if not score_chunks:
        return np.array([]), np.array([])
    return np.concatenate(score_chunks), np.concatenate(label_chunks)

# --- 3. CHARGEMENT DES DONNÉES DE TEST ---
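# (Assumes anomaly_test_df, test_feats and data_transforms are already loaded in memory; uncomment the two lines below if test_loader has not been built elsewhere — section 5 requires it.)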
# test_dataset = HybridAnomalyDataset(anomaly_test_df, test_feats, transform=data_transforms)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# --- 4. LOAD THE MODELS ---
# Both networks are re-instantiated below, so their trained weights and their
# reference centers must be restored from disk. Otherwise section 5 would score
# randomly-initialised networks, and center_h / center_b would be undefined.
# map_location=device lets checkpoints saved on a GPU load on a CPU-only machine.
WEIGHTS_DIR = os.path.join("..", "weights")

# A. Hybrid model (HAD-Net)
hybrid_model = HybridAnomalyNet().to(device)
hybrid_model.load_state_dict(
    torch.load(os.path.join(WEIGHTS_DIR, "hybrid_model.pth"), map_location=device)
)
# NOTE(review): presumably a tensor saved with torch.save — confirm.
center_h = torch.load(os.path.join(WEIGHTS_DIR, "center_hybrid.pth"), map_location=device)

# B. Baseline model (Deep-Only)
baseline_model = DeepOnlyAnomalyNet().to(device)
baseline_model.load_state_dict(
    torch.load(os.path.join(WEIGHTS_DIR, "baseline_model.pth"), map_location=device)
)
center_b = torch.load(os.path.join(WEIGHTS_DIR, "center_baseline.pth"), map_location=device)

# --- 5. ANOMALY SCORES AND ROC CURVES ---

# Both models are scored on the same test loader. The hybrid model also
# receives the handcrafted features, while the baseline only sees the images.
print("Calcul des scores pour le modèle Hybride...")
scores_h, labels_h = get_anomaly_scores(
    hybrid_model, test_loader, center_h, device, is_hybrid=True
)

print("Calcul des scores pour le modèle Baseline...")
scores_b, labels_b = get_anomaly_scores(
    baseline_model, test_loader, center_b, device, is_hybrid=False
)

# ROC curve of each model (thresholds are discarded), then the area under it.
fpr_h, tpr_h, _ = roc_curve(labels_h, scores_h)
fpr_b, tpr_b, _ = roc_curve(labels_b, scores_b)
auc_h, auc_b = auc(fpr_h, tpr_h), auc(fpr_b, tpr_b)

# --- 6. VISUALIZATION (THE "MONEY SHOT") ---

# Overlay both ROC curves on one figure, with the chance diagonal as reference.
plt.figure(figsize=(10, 7))
plt.plot(fpr_h, tpr_h, color='green', lw=3, label=f'Proposed HAD-Net (AUC = {auc_h:.4f})')
plt.plot(fpr_b, tpr_b, color='red', lw=2, linestyle='--', label=f'ResNet18 Baseline (AUC = {auc_b:.4f})')
plt.plot([0, 1], [0, 1], color='gray', linestyle=':')

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.title('Ablation Study: Impact of Handcrafted Texture Features')
plt.legend(loc="lower right")
plt.grid(alpha=0.3)
plt.show()

# Relative AUC change of HAD-Net w.r.t. the baseline. The sign comes from the
# "+" format spec, so a regression prints "-x.xx%" rather than "+-x.xx%".
if auc_b > 0:
    print(f"Improvement: {(auc_h - auc_b) / auc_b * 100:+.2f}%")
else:
    # A zero (or NaN) baseline AUC makes the relative change undefined.
    print(f"Improvement: undefined (baseline AUC = {auc_b:.4f})")