In [None]:
# --- 1. CONFIGURATION AND IMPORTS ---
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import roc_curve, auc
import sys
import os

# Add parent directory to access models/ and data/ modules
sys.path.append('..')

from models.had_net import HybridAnomalyNet
from models.baseline_resnet import DeepOnlyAnomalyNet
from data.dataset import HybridAnomalyDataset
from utils.training import init_center # Corrected filename import

# Device configuration: use the default CUDA device when PyTorch reports one is
# available, otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- 2. EVALUATION FUNCTION ---
def get_anomaly_scores(model, loader, center, device, is_hybrid=True):
    """
    Compute an anomaly score for every sample produced by ``loader``.

    The score of a sample is the squared Euclidean distance between the
    model's output embedding and ``center``.

    Args:
        model: network to evaluate; it is switched to eval mode by this call.
        loader: iterable yielding ``(images, features, labels)`` batches.
        center: reference point in embedding space (tensor or array-like).
            It is moved to ``device`` so that a center loaded on a different
            device than the model can still be compared with the outputs.
        device: device that images, features and ``center`` are moved to.
        is_hybrid: if True the model is called as ``model(images, features)``;
            otherwise as ``model(images)`` and ``features`` is ignored.

    Returns:
        Tuple ``(scores, labels)`` of NumPy arrays in loader order. Both are
        empty when ``loader`` yields no batches.
    """
    model.eval()
    # Move the center once, up front, so it always lives on the same device as
    # the model outputs (e.g. a center loaded on CPU vs. a model on CUDA).
    center = torch.as_tensor(center, device=device)

    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        for images, features, lbl in loader:
            images = images.to(device)

            # Forward pass based on model type
            if is_hybrid:
                outputs = model(images, features.to(device))
            else:
                outputs = model(images)

            # Distance to center (anomaly score). Summing over dim=1 assumes
            # outputs are (batch, embedding_dim) — TODO confirm for new models.
            dists = torch.sum((outputs - center) ** 2, dim=1)

            score_chunks.append(dists.cpu().numpy())
            label_chunks.append(torch.as_tensor(lbl).cpu().numpy())

    # np.concatenate rejects an empty list, so an empty loader is handled here.
    if not score_chunks:
        return np.array([]), np.array([])
    return np.concatenate(score_chunks), np.concatenate(label_chunks)

# --- 3. TEST DATA LOADING ---
# NOTE: anomaly_test_df, test_feats and data_transforms are not defined in this
# cell; they must already exist in the notebook session before it runs.
# shuffle=False keeps a fixed sample order so both models are scored on the
# same sequence and their label arrays line up.
test_dataset = HybridAnomalyDataset(anomaly_test_df, test_feats, transform=data_transforms)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# --- 4. MODEL LOADING ---
# Load trained weights and centers for both models from the '../weights' folder.
# map_location=device lets checkpoints saved on a GPU load on a CPU-only machine
# and places the centers on the same device as the model outputs.
# torch.load unpickles its input: only load checkpoints from a trusted source.

# A. Proposed Hybrid Model (HAD-Net)
hybrid_model = HybridAnomalyNet().to(device)
hybrid_model.load_state_dict(torch.load("../weights/hybrid_hadnet.pth", map_location=device))
center_h = torch.load("../weights/center_hybrid.pth", map_location=device)

# B. Baseline Model (Deep-Only ResNet)
baseline_model = DeepOnlyAnomalyNet().to(device)
baseline_model.load_state_dict(torch.load("../weights/baseline_resnet.pth", map_location=device))
center_b = torch.load("../weights/center_baseline.pth", map_location=device)

# --- 5. PERFORMANCE METRICS CALCULATION ---
# Both models are scored on the same test_loader; the hybrid model additionally
# receives the handcrafted feature tensor.
print("Computing scores for Hybrid HAD-Net...")
scores_h, labels_h = get_anomaly_scores(
    hybrid_model, test_loader, center_h, device, is_hybrid=True
)

print("Computing scores for Baseline ResNet...")
scores_b, labels_b = get_anomaly_scores(
    baseline_model, test_loader, center_b, device, is_hybrid=False
)

# ROC curve + area under it for each model. roc_curve treats larger scores as
# evidence for the positive label (presumably 1 = anomaly — confirm with the
# dataset's label encoding).
fpr_h, tpr_h, _ = roc_curve(labels_h, scores_h)
fpr_b, tpr_b, _ = roc_curve(labels_b, scores_b)
auc_h, auc_b = auc(fpr_h, tpr_h), auc(fpr_b, tpr_b)

# --- 6. VISUALIZATION: ROC CURVES OF BOTH MODELS ---
# Solid green = hybrid model, dashed red = baseline, dotted gray = chance diagonal.
roc_fig, roc_ax = plt.subplots(figsize=(10, 7))
roc_ax.plot(fpr_h, tpr_h, color='green', lw=3, label=f'Proposed HAD-Net (AUC = {auc_h:.4f})')
roc_ax.plot(fpr_b, tpr_b, color='red', lw=2, linestyle='--', label=f'Baseline ResNet18 (AUC = {auc_b:.4f})')
roc_ax.plot([0, 1], [0, 1], color='gray', linestyle=':')

# Axis ranges, labels and decorations.
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate (1 - Specificity)')
roc_ax.set_ylabel('True Positive Rate (Sensitivity)')
roc_ax.set_title('Ablation Study: Impact of Handcrafted Texture Features')
roc_ax.legend(loc="lower right")
roc_ax.grid(alpha=0.3)
plt.show()

# Relative AUC change of the hybrid model over the baseline, in percent.
# The '+' format flag prints the correct sign, so a regression no longer shows
# up as "+-x.xx%". The relative change is undefined unless the baseline AUC is
# a positive number (this also catches NaN from a single-class test set).
if auc_b > 0:
    improvement = ((auc_h - auc_b) / auc_b) * 100
    print(f"Final Improvement: {improvement:+.2f}%")
else:
    improvement = float("nan")
    print(f"Final Improvement: undefined (baseline AUC = {auc_b:.4f})")