In [None]:
import sys
from pathlib import Path

import torch

# Add project root to path. Assumes the notebook is launched from a direct
# sub-directory of the project (e.g. notebooks/) — TODO confirm; otherwise the
# `src.*` imports in the next cell will not resolve.
PROJECT_ROOT = Path.cwd().parent
# Guarded so that re-running this cell does not keep prepending duplicate entries.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"‚úÖ Project root: {PROJECT_ROOT}")
print(f"‚úÖ Python version: {sys.version}")
print(f"‚úÖ PyTorch version: {torch.__version__}")
print(f"‚úÖ CUDA available: {torch.cuda.is_available()}")


In [None]:
# Now import torchvision and model.
# The src.* imports resolve because the first cell puts PROJECT_ROOT on sys.path.
import torchvision
from src.models.hybrid import HybridDetector
from src.data.dataset import create_dataloaders

print(f"‚úÖ Torchvision version: {torchvision.__version__}")
print("‚úÖ Imports successful!")  # plain string: nothing to interpolate


In [None]:
# Load best checkpoint
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
CHECKPOINT = PROJECT_ROOT.joinpath("models", "checkpoints", "hybrid_imaginet_best.pth")

print(f"üìÇ Loading checkpoint: {CHECKPOINT.name}")
# map_location remaps stored tensors onto DEVICE, so a GPU-saved checkpoint still
# loads on a CPU-only machine.
# NOTE(review): torch.load is pickle-based; unless weights_only=True is in effect
# (the default only in newer PyTorch releases) it can execute arbitrary code —
# only load checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT, map_location=DEVICE)

# The +1 suggests 'epoch' is stored 0-based — TODO confirm against the training loop.
epoch_line = f"   Epoch: {checkpoint['epoch'] + 1}"
print(epoch_line)
print(f"   Val Acc: {checkpoint['val_acc']:.2f}%")
print(f"   Val Loss: {checkpoint['val_loss']:.4f}")


In [None]:
# Load model
# The architecture arguments must match the checkpoint: load_state_dict is strict
# by default and raises on missing/unexpected parameter keys.
# pretrained=False presumably skips fetching backbone weights, which would be
# overwritten by the checkpoint anyway — TODO confirm in src.models.hybrid.
model = HybridDetector(
    num_classes=2,
    pretrained=False,
    model_name='mobilenet_v3_small',
).to(DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

# Switch layers such as dropout / batch-norm to inference behaviour.
model.eval()

print("‚úÖ Model loaded successfully!")


In [None]:
# Load validation data
DATA_ROOT = PROJECT_ROOT.joinpath("data", "raw", "imaginet", "subset")
DCT_DIR = PROJECT_ROOT.joinpath("data", "processed", "imaginet", "dct_features")

# Only the validation half of the split is needed here.
# NOTE(review): assumes training used the same train_ratio and seed, so this
# validation split shares no images with the training set — TODO confirm.
_, val_loader = create_dataloaders(
    root_dir=DATA_ROOT,
    # Pass None when precomputed DCT features are absent; how the dataset then
    # obtains DCT features is defined in src.data.dataset — TODO confirm.
    dct_dir=None if not DCT_DIR.exists() else DCT_DIR,
    batch_size=128,
    num_workers=0,
    train_ratio=0.8,
    seed=42,
)

print(f"‚úÖ Validation loader ready: {len(val_loader)} batches")


In [None]:
# Run evaluation
import numpy as np
# tqdm.auto renders a widget bar inside Jupyter and falls back to a text bar
# elsewhere (nbconvert, CI), where tqdm.notebook would fail or degrade.
from tqdm.auto import tqdm


def collect_predictions(model, loader, device, positive_class=1):
    """Run ``model`` over every batch of ``loader`` and gather its predictions.

    Gradients are disabled for the whole pass. The model is expected to already be
    in eval mode.

    Args:
        model: callable as ``model(img_masked, dct_feat)``, returning a 2-D tensor of
            unnormalised class scores (softmax and argmax are taken over dim 1).
        loader: iterable of ``(img_masked, dct_feat, labels)`` batches.
        device: device the two input tensors are moved to before the forward pass.
        positive_class: column of the softmax whose probability is returned as the
            per-sample score (1 = "Fake" in this notebook).

    Returns:
        ``(preds, labels, probs)`` as NumPy arrays, concatenated across batches.
        All three are empty when ``loader`` yields no batches.
    """
    pred_chunks, label_chunks, prob_chunks = [], [], []
    with torch.no_grad():
        for img_masked, dct_feat, labels in tqdm(loader, desc="Evaluating"):
            outputs = model(img_masked.to(device), dct_feat.to(device))
            probs = torch.softmax(outputs, dim=1)
            # Collect one array per batch and concatenate once at the end, instead
            # of extending a Python list element by element.
            pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
            prob_chunks.append(probs[:, positive_class].cpu().numpy())

    if not label_chunks:  # np.concatenate([]) would raise on an empty loader
        return (np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.float32))
    return (np.concatenate(pred_chunks),
            np.concatenate(label_chunks),
            np.concatenate(prob_chunks))


print("üîç Running inference...")

all_preds, all_labels, all_probs = collect_predictions(model, val_loader, DEVICE)

# Fail loudly instead of reporting a NaN accuracy (mean of an empty array) and
# letting the metric cells below crash with less obvious errors.
if all_labels.size == 0:
    raise ValueError("Validation loader yielded no samples; check DATA_ROOT and the train/val split.")

accuracy = (all_preds == all_labels).mean() * 100
print("\n‚úÖ Evaluation complete!")
print(f"   Accuracy: {accuracy:.2f}%")
print(f"   Total samples: {len(all_labels)}")


In [None]:
# Confusion Matrix & Metrics
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

# Label encoding used by this notebook: 0 = Real, 1 = Fake.
CLASS_IDS = [0, 1]
CLASS_NAMES = ['Real', 'Fake']

# Classification Report
print("="*60)
print("üìä CLASSIFICATION REPORT")
print("="*60)
# Passing labels= explicitly keeps target_names valid even when one class is absent
# from both y_true and y_pred; without it sklearn raises a class-count mismatch error.
report = classification_report(
    all_labels,
    all_preds,
    labels=CLASS_IDS,
    target_names=CLASS_NAMES,
    digits=4,
)
print(report)

# Confusion Matrix: rows = true label, columns = predicted label. labels= pins the
# 2x2 shape so the cm[i, j] lookups below can never go out of bounds.
cm = confusion_matrix(all_labels, all_preds, labels=CLASS_IDS)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES,
            ax=ax)
ax.set_title('Confusion Matrix')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
fig.tight_layout()
plt.show()

print(f"\nTrue Negatives (Real ‚Üí Real):  {cm[0, 0]}")
print(f"False Positives (Real ‚Üí Fake): {cm[0, 1]}")
print(f"False Negatives (Fake ‚Üí Real): {cm[1, 0]}")
print(f"True Positives (Fake ‚Üí Fake):  {cm[1, 1]}")


In [None]:
# ROC Curve
from sklearn.metrics import roc_auc_score, roc_curve

# all_probs holds the softmax probability of class 1 ("Fake"), so Fake is the
# positive class for both the AUC and the curve.
auc_score = roc_auc_score(all_labels, all_probs)
fpr, tpr, thresholds = roc_curve(all_labels, all_probs)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, linewidth=2, label=f'ROC Curve (AUC = {auc_score:.4f})')
ax.plot([0, 1], [0, 1], 'k--', label='Random Classifier')  # chance-level diagonal
ax.set(
    xlabel='False Positive Rate',
    ylabel='True Positive Rate',
    title='ROC Curve - AI-Generated Image Detection',
)
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

print(f"üìä AUC Score: {auc_score:.4f}")
