In [None]:
import os
import torch
import numpy as np
import torchvision.transforms as T
from torchvision.io import read_image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.transform import resize
from PIL import Image
import cv2
import random
from tqdm import tqdm

from ppnet import model

In [None]:
MODEL_PATH = "./checkpoints/300push74.87.pth"
DATASET_PATH = "./datasets/cub200/test/"
IMG_PATH = os.path.join(DATASET_PATH, "001.Black_footed_Albatross/1.jpg")
DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook draws from: `random` (image sampling) and torch
# (the noise / random-shift perturbations). This makes Restart & Run All
# reproduce the reported scores.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Using device:", DEVICE)

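## Consistency Evaluation

Measures how stable a prototype's activation heatmap is under small input perturbations (rotation, additive Gaussian noise, random translation). Each pair is scored as the IoU of the thresholded heatmaps before and after the perturbation.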

In [None]:
def load_image(path, img_size=224):
    """Read an image file and return it as a batch of one float tensor.

    The file is converted to RGB and resized to ``img_size`` x ``img_size``.
    ToTensor then scales pixels to [0, 1], and a leading batch dimension is added.
    """
    preprocess = T.Compose([T.Resize((img_size, img_size)), T.ToTensor()])
    rgb_image = Image.open(path).convert('RGB')
    single = preprocess(rgb_image)
    return single.unsqueeze(0)

def perturb_image(image, mode='rotate', degree=5, noise_std=0.05, translate=0.05):
    """Return a perturbed copy of an image tensor.

    Args:
        image: float image tensor with values in [0, 1] and spatial dims last,
            e.g. the (1, C, H, W) batch returned by ``load_image``. The leading
            dimensions are preserved.
        mode: 'rotate', 'noise' or 'shift'.
        degree: rotation angle in degrees for 'rotate'. The rotation is
            deterministic, because ``degrees=(degree, degree)`` is used.
        noise_std: standard deviation of the additive Gaussian noise for 'noise'.
            The result is clamped back to [0, 1].
        translate: maximum translation for 'shift', as a fraction of width and of
            height. The actual offset is drawn at random.

    Returns:
        Tensor with the same shape as ``image``.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode == 'rotate':
        # Tensor-native transform. This avoids the ToPILImage/ToTensor round trip,
        # which truncates floats to uint8 and squeezes away singleton dims, and it
        # keeps the batch dimension as is.
        return T.RandomRotation(degrees=(degree, degree))(image)
    if mode == 'noise':
        noise = torch.randn_like(image) * noise_std
        return torch.clamp(image + noise, 0, 1)
    if mode == 'shift':
        return T.RandomAffine(degrees=0, translate=(translate, translate))(image)
    raise ValueError(f"Unsupported perturbation type: {mode!r}")


In [None]:
# The checkpoint holds the whole pickled model object, not a state_dict: it is
# used directly as a module below. Unpickling it therefore needs
# weights_only=False, because PyTorch >= 2.6 defaults to weights-only loading.
# Unpickling can execute arbitrary code, so only load checkpoints you trust.
# The `from ppnet import model` import above presumably makes the pickled classes
# resolvable.
ppnet = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
ppnet.eval()
ppnet = ppnet.to(DEVICE)

In [None]:
def get_activation_heatmap(ppnet, image_tensor, prototype_idx=0):
    """Activation map of one prototype on the first image of a batch, scaled to [0, 1].

    The prototype's distance map from ``ppnet.push_forward`` is processed in
    three steps so it can be thresholded with ``compute_iou``:
      - negated, because a smaller distance means a stronger match;
      - min-max normalised;
      - resized to the input's spatial size.

    Args:
        ppnet: model exposing ``push_forward(x) -> (features, distances)``, where
            ``distances`` is indexed as [batch, prototype, h, w]. This layout is
            assumed from the indexing below; confirm against ppnet.
        image_tensor: image batch whose last two dims are (H, W). Only index 0 is used.
        prototype_idx: which prototype's map to return.

    Returns:
        float32 numpy array of shape (H, W) with values in [0, 1]. It is all zeros
        when the distance map is spatially constant.
    """
    image_tensor = image_tensor.to(DEVICE)
    with torch.no_grad():
        _, distances = ppnet.push_forward(image_tensor)
    # Negate so that larger = stronger match. Cast to float32 because
    # cv2.resize does not accept float16 maps.
    act = -distances[0, prototype_idx].float().cpu().numpy()
    act = act - act.min()
    peak = act.max()
    # A constant map has zero range; dividing would give 0/0 = NaN everywhere.
    if peak > 0:
        act = act / peak
    height, width = image_tensor.shape[-2:]
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(act, (width, height))

In [None]:
def compute_iou(map1, map2, threshold=0.5):
    """Intersection-over-union of two heatmaps after binarising them at ``threshold``.

    Args:
        map1, map2: array-likes of identical shape, e.g. outputs of
            ``get_activation_heatmap``. Pixels strictly greater than
            ``threshold`` count as foreground.
        threshold: binarisation cut-off.

    Returns:
        IoU as a Python float in [0, 1]. It is 0.0 when neither mask has any
        foreground pixel, which includes maps containing only NaNs.

    Raises:
        ValueError: if the maps have different shapes. Otherwise NumPy would
            broadcast them silently and produce a meaningless score.
    """
    mask1 = np.asarray(map1) > threshold
    mask2 = np.asarray(map2) > threshold
    if mask1.shape != mask2.shape:
        raise ValueError(f"Heatmap shapes differ: {mask1.shape} vs {mask2.shape}")
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0.0
    intersection = np.logical_and(mask1, mask2).sum()
    return float(intersection / union)


In [None]:
# Load an example image and a copy rotated by a fixed 5 degrees.
orig_img = load_image(IMG_PATH)
pert_img = perturb_image(orig_img, mode='rotate', degree=5)

# Use the prototype most strongly activated by this image. Activation is the
# negated distance (as in get_activation_heatmap), so this is the prototype
# whose distance map has the smallest minimum.
with torch.no_grad():
    _, orig_distances = ppnet.push_forward(orig_img.to(DEVICE))
top_proto_idx = int(orig_distances[0].flatten(start_dim=1).min(dim=1).values.argmin())
print("Top-activated prototype:", top_proto_idx)

heatmap_orig = get_activation_heatmap(ppnet, orig_img, prototype_idx=top_proto_idx)
heatmap_pert = get_activation_heatmap(ppnet, pert_img, prototype_idx=top_proto_idx)

iou_score = compute_iou(heatmap_orig, heatmap_pert)
print("IoU between original and rotated heatmaps:", iou_score)

In [None]:
def _collect_image_paths(dataset_dir, extensions=('.jpg', '.jpeg', '.png')):
    """Return every image file under ``dataset_dir`` (searched recursively), sorted.

    Sorting removes the dependence on os.walk's filesystem-specific order, so a
    seeded shuffle picks the same images on every machine.
    """
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(dataset_dir)
        for name in files
        if name.lower().endswith(extensions)
    )


def _summarize_scores(scores):
    """Mean, standard deviation and count of a list of IoU scores.

    An empty list yields NaN statistics directly instead of letting np.mean warn.
    """
    if not scores:
        return {"mean": float("nan"), "std": float("nan"), "n": 0}
    return {"mean": float(np.mean(scores)), "std": float(np.std(scores)), "n": len(scores)}


def evaluate_consistency(ppnet, dataset_dir, n_images=100, perturbations=None, prototype_idx=0, img_size=224, seed=None):
    """
    Compute average IoU consistency score for ProtoPNet explanations.

    For each sampled image, the heatmap of prototype ``prototype_idx`` on the
    original image is compared, via ``compute_iou``, with its heatmap on each
    perturbed copy.

    Args:
        ppnet: loaded model, passed through to ``get_activation_heatmap``.
        dataset_dir: root directory searched recursively for .jpg/.jpeg/.png files.
        n_images: maximum number of images to sample. None uses all of them.
        perturbations: list of (mode, kwargs) pairs forwarded to ``perturb_image``.
            Defaults to a 5-degree rotation, Gaussian noise with std 0.05, and a
            random shift.
        prototype_idx: index of the prototype whose heatmaps are compared.
        img_size: side length images are resized to.
        seed: optional int. When given, image sampling uses a private RNG seeded
            with it, and torch's global RNG is seeded too. Torch's RNG is what the
            noise perturbation (and torchvision's random transforms) draw from.
            When None, the module-level RNGs are used unchanged.

    Returns:
        dict mapping each mode to {"mean", "std", "n"}, where "n" is the number of
        images that contributed. mean and std are NaN when n == 0.
    """
    if perturbations is None:
        perturbations = [
            ("rotate", {"degree": 5}),
            ("noise", {"noise_std": 0.05}),
            ("shift", {}),
        ]

    image_paths = _collect_image_paths(dataset_dir)

    sampler = random if seed is None else random.Random(seed)
    if seed is not None:
        torch.manual_seed(seed)
    sampler.shuffle(image_paths)
    image_subset = image_paths[:n_images]

    results = {mode: [] for mode, _ in perturbations}

    for img_path in tqdm(image_subset, desc="Evaluating consistency"):
        try:
            img = load_image(img_path, img_size=img_size)
            heatmap_orig = get_activation_heatmap(ppnet, img, prototype_idx=prototype_idx)
            # Score every perturbation before recording any of them. Otherwise an
            # error partway through one image would leave the modes averaged over
            # different image sets.
            image_scores = []
            for mode, kwargs in perturbations:
                pert_img = perturb_image(img, mode=mode, **kwargs)
                heatmap_pert = get_activation_heatmap(ppnet, pert_img, prototype_idx=prototype_idx)
                image_scores.append((mode, compute_iou(heatmap_orig, heatmap_pert)))
        except Exception as e:
            # Best-effort: a failing image is reported and skipped entirely.
            print(f"⚠️ Skipping {img_path} due to error: {e}")
            continue
        for mode, iou in image_scores:
            results[mode].append(iou)

    return {mode: _summarize_scores(scores) for mode, scores in results.items()}


In [None]:
# Score 100 random test images on prototype 0 and print one line per perturbation.
consistency_scores = evaluate_consistency(ppnet, DATASET_PATH, n_images=100, prototype_idx=0)
for mode_name, mode_stats in consistency_scores.items():
    label = mode_name.capitalize()
    mean_iou, std_iou = mode_stats['mean'], mode_stats['std']
    print(f"{label:<10} — Mean IoU: {mean_iou:.4f}, Std: {std_iou:.4f}")

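## Plausibility Evaluation

Visual inspection of randomly sampled prototypes. Each row shows one prototype's nearest image patch(es), with the activation heatmap overlaid when that rendering is available.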

In [None]:
import os
import random
import matplotlib.pyplot as plt
from PIL import Image


def show_random_prototypes(base_path, num_prototypes=10, top_k=1, seed=42, save_path=None):
    """Plot the nearest patches of randomly sampled prototypes as a grid, one row per prototype.

    ``base_path`` is expected to contain one digit-named sub-directory per
    prototype, holding ``nearest-{k}_original_with_heatmap.png`` images. When that
    file is missing, ``nearest-{k}_high_act_patch_in_original_img.png`` is used
    instead. If neither exists, the panel is left blank.

    Args:
        base_path: directory holding the per-prototype folders.
        num_prototypes: how many prototypes to sample, capped at the number
            available. Must be >= 1.
        top_k: nearest patches shown per prototype (grid columns). Must be >= 1.
        seed: seed of the private RNG used for sampling.
        save_path: optional file path. When given, the figure is also saved at 200 dpi.

    Returns:
        The sampled prototype folder names (strings), in plotting order.

    Raises:
        ValueError: if ``top_k`` or ``num_prototypes`` is < 1, or no prototype
            folders are found.
    """
    if top_k < 1 or num_prototypes < 1:
        raise ValueError(f"top_k and num_prototypes must be >= 1, got {top_k} and {num_prototypes}")

    # Only digit-named directories are prototype folders; stray files are ignored.
    # The lexicographic sort keeps the sampling order identical to earlier runs.
    prototype_ids = sorted(
        name for name in os.listdir(base_path)
        if name.isdigit() and os.path.isdir(os.path.join(base_path, name))
    )
    if not prototype_ids:
        raise ValueError(f"No prototype folders found in {base_path}")

    # A private Random(seed) draws the same sample as random.seed(seed) followed by
    # random.sample, without resetting the notebook-wide `random` state.
    sampled = random.Random(seed).sample(prototype_ids, min(num_prototypes, len(prototype_ids)))

    rows, cols = len(sampled), top_k
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3 * rows), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # panels with no image on disk stay blank

    for row_idx, proto_id in enumerate(sampled):
        proto_dir = os.path.join(base_path, proto_id)
        for k in range(1, top_k + 1):
            candidates = (
                os.path.join(proto_dir, f"nearest-{k}_original_with_heatmap.png"),
                os.path.join(proto_dir, f"nearest-{k}_high_act_patch_in_original_img.png"),
            )
            img_path = next((path for path in candidates if os.path.exists(path)), None)
            if img_path is None:
                continue
            ax = axes[row_idx, k - 1]
            # imshow converts the PIL image to an array immediately, so the file
            # handle can be closed right after.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.set_title(f"Proto {proto_id} - Top {k}", fontsize=10)

    fig.suptitle(f"{len(sampled)} Random Prototypes (Top-{top_k} patches each)", fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    if save_path:
        fig.savefig(save_path, dpi=200)
        print(f"Saved plausibility figure to {save_path}")

    plt.show()

    return sampled


In [None]:
# Directory of per-prototype nearest-patch renderings for the test split.
base_path = (
    "analysis/Assignment3/understandable-ProtoPNet/300push74.87.pth/global/nearest_prototypes/test"
)
sampled_protos = show_random_prototypes(
    base_path,
    num_prototypes=10,
    top_k=1,
    save_path="plausibility_examples.png",
)

## Usefulness Evaluation