In [None]:
import torch
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing applied to every image loaded from disk: resize to 224x224,
# convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std values are the ImageNet channel statistics that
# torchvision documents for its pretrained models; presumably they match the
# preprocessing used when the classifier was trained -- confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
    

In [None]:
def join(path):
    """Return ``path`` joined onto the current working directory.

    Args:
        path: A path fragment. An absolute path is returned unchanged, as per
            ``os.path.join`` semantics.

    Returns:
        ``os.path.join(os.getcwd(), path)``; the working directory is read at
        call time.
    """
    return os.path.join(os.getcwd(), path)
def evaluate_attack(weights_path=join("outputs/model_best.pth"), test_dir=join("dataset/test")):
    """Run the trained classifier over ``test_dir`` and plot its confidence histogram.

    A ResNet-18 with a 2-output final layer is built, the weights at
    ``weights_path`` are loaded into it, and every image found by
    ``ImageFolder`` under ``test_dir`` is scored. For each image the largest
    softmax probability is collected, and the resulting list is handed to
    ``plot_histogram``. The folder labels are ignored.

    Args:
        weights_path: Path to a saved ``state_dict`` for the model. The default
            is resolved against the working directory when this function is
            defined, not when it is called.
        test_dir: Root directory in the layout ``ImageFolder`` expects. The
            default is resolved the same way as ``weights_path``.

    Returns:
        None. If either path does not exist, an error message is printed and
        the function returns early without plotting.
    """
    # Validate both inputs up front so a missing test directory is reported
    # before any time is spent building the model and loading weights.
    if not os.path.exists(weights_path):
        print("ERROR: Model weights not found. Train first.")
        return

    if not os.path.exists(test_dir):
        print(f"ERROR: Test directory not found: {test_dir}")
        return

    # Called without arguments so the network is randomly initialised on both
    # old (`pretrained=False` default) and new (`weights=None` default)
    # torchvision releases, avoiding the deprecated `pretrained` keyword.
    model = models.resnet18()
    model.fc = torch.nn.Linear(model.fc.in_features, 2)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(weights_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()

    dataset = datasets.ImageFolder(test_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    print(f"Attack images found: {len(dataset)}")

    confidences = []
    print("Performing atack with diffusion images...")
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(DEVICE)
            logits = model(inputs)
            probs = F.softmax(logits, dim=1)

            # Keep only the winning-class probability for each sample.
            max_conf, _ = torch.max(probs, dim=1)
            confidences.extend(max_conf.cpu().numpy())

    plot_histogram(confidences)

def plot_histogram(confidences, output_path='./outputs/histogram_failure.png'):
    """Draw a histogram of confidence scores and save it as an image.

    The histogram uses 20 bins over the range [0, 1], a dashed vertical line
    at 0.9 labelled as the high-confidence threshold, and two fixed text
    annotations. The figure is written to ``output_path`` and the path is
    printed.

    Args:
        confidences: Sequence of confidence values to histogram.
        output_path: Destination image file. Its parent directory is created
            if it does not exist yet.

    Returns:
        None.
    """
    # savefig raises FileNotFoundError when the target directory is missing,
    # so make sure it exists before plotting.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    plt.figure(figsize=(10, 6))
    plt.hist(confidences, bins=20, range=(0, 1), color='red', alpha=0.7, edgecolor='black')
    plt.title("Generalization Failure: Confiance in Unknown Samples (Diffusion)", fontsize=14)

    plt.xlabel("Network Confidence (Softmax Score)", fontsize=12)
    plt.ylabel("Sample Count", fontsize=12)
    plt.axvline(x=0.9, color='blue', linestyle='--', label='High Confidence Threshold')
    plt.legend()
    plt.grid(axis='y', alpha=0.5)

    # Annotation positions are in data coordinates: x is a confidence value,
    # y is a sample count.
    plt.text(0.1, 5, "Ideally, the samples\nshould be here\n(Low Confidence)",
             bbox=dict(facecolor='green', alpha=0.2))
    plt.text(0.75, 5, "Reality:\nOverconfidence",
             bbox=dict(facecolor='red', alpha=0.2))

    plt.savefig(output_path)
    print(f"Generated Graph: {output_path}")


In [None]:
# Run the evaluation with the default weights and test-directory paths.
evaluate_attack()