In [None]:
# Re-run after kernel reset to regenerate the average Grad-CAM script output

import os
import cv2
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt


In [None]:
# Configuration: where the per-split prediction logs and Grad-CAM images live,
# and where the averaged Grad-CAM images are written.
base_run_dir = "runs/classify-bestmodel"
model_base = "sgkf05-yolo11s"
gradcam_variant = "EigenGradCAM"
output_dir = os.path.join(base_run_dir, model_base, gradcam_variant, "avg")

# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

# Outcome classes (confusion-matrix cells)
outcome_categories = ["TP", "FP", "FN", "TN"]
image_groups = {key: [] for key in outcome_categories}

# Find all prediction logs. glob() ordering depends on the filesystem, so sort
# to get a reproducible processing order across machines and re-runs.
csv_paths = sorted(glob(os.path.join(base_run_dir, model_base, "test-*", "image_predictions.csv")))
if not csv_paths:
    print(f"⚠️ No image_predictions.csv found under {os.path.join(base_run_dir, model_base, 'test-*')}")


In [None]:
# Collect Grad-CAM image paths per outcome class.
# Rebuild the groups here so that re-running this cell does not append duplicate
# paths onto lists left over from a previous execution.
image_groups = {key: [] for key in outcome_categories}

# (true_label, predicted_label) -> outcome, with "PSS" as the positive class.
outcome_by_labels = {
    ("PSS", "PSS"): "TP",
    ("NRM", "PSS"): "FP",
    ("PSS", "NRM"): "FN",
    ("NRM", "NRM"): "TN",
}
n_missing_gradcam = 0  # rows whose Grad-CAM image is not on disk
n_unmatched = 0        # rows with unknown labels, or a `correct` flag that disagrees with them

for csv_path in tqdm(csv_paths, desc="Collecting image paths"):
    df = pd.read_csv(csv_path)
    test_dir = os.path.dirname(csv_path)
    gradcam_image_dir = os.path.join(test_dir, gradcam_variant, "images")

    rows = df[["true_label", "predicted_label", "correct", "image_path"]].itertuples(index=False, name=None)
    for true_label, pred_label, correct, img_path in rows:
        outcome = outcome_by_labels.get((true_label, pred_label))
        # Same consistency rule as before: `correct` must be 1 for TP/TN and 0 for FP/FN.
        if outcome is None or correct != (true_label == pred_label):
            n_unmatched += 1
            continue

        base_filename = os.path.splitext(os.path.basename(img_path))[0]
        gradcam_filename = f"{base_filename}_{pred_label}_gradcam.jpg"
        gradcam_path = os.path.join(gradcam_image_dir, gradcam_filename)
        if not os.path.exists(gradcam_path):
            n_missing_gradcam += 1
            continue

        image_groups[outcome].append(gradcam_path)

# Report rows that were dropped instead of skipping them silently.
if n_unmatched:
    print(f"⚠️ Skipped {n_unmatched} rows with unrecognised or inconsistent labels")
if n_missing_gradcam:
    print(f"⚠️ Skipped {n_missing_gradcam} rows with no Grad-CAM image on disk")
{outcome: len(paths) for outcome, paths in image_groups.items()}

In [None]:
# Averaging function
def average_images(image_paths):
    """Pixel-wise mean of the readable images in ``image_paths``, as RGB uint8.

    Unreadable paths (``cv2.imread`` returns None) are skipped. Images whose
    height/width differ from the first readable image are resized to that size,
    so a mix of resolutions does not make the mean fail.

    Args:
        image_paths: iterable of image file paths.

    Returns:
        An ``(H, W, 3)`` uint8 RGB array rounded to the nearest integer, or
        ``None`` if no image could be read.
    """
    total = None        # running float64 sum: avoids keeping every image in memory
    count = 0
    target_hw = None    # (height, width) of the first readable image
    for path in image_paths:
        bgr = cv2.imread(path)
        if bgr is None:
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if target_hw is None:
            target_hw = rgb.shape[:2]
            total = np.zeros(rgb.shape, dtype=np.float64)
        elif rgb.shape[:2] != target_hw:
            # cv2.resize takes dsize as (width, height)
            rgb = cv2.resize(rgb, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_LINEAR)
        total += rgb
        count += 1
    if count == 0:
        return None
    # Round rather than truncate, so the mean is not biased downward by ~0.5.
    return np.clip(np.rint(total / count), 0, 255).astype(np.uint8)

In [None]:
# Process and save average images: one averaged Grad-CAM per outcome class,
# shown inline and written to output_dir as a JPEG.
for outcome, paths in image_groups.items():
    print(f"🖼️ Averaging {len(paths)} images for: {outcome}")
    avg_image = average_images(paths)
    if avg_image is None:
        print(f"⚠️ No valid images found for {outcome}")
        continue

    # Show the image. Use an explicit figure and close it afterwards so figures
    # do not accumulate across loop iterations.
    fig, ax = plt.subplots()
    ax.imshow(avg_image)
    ax.set_title(f"Average Grad-CAM ({outcome})")
    ax.axis("off")
    plt.show()
    plt.close(fig)

    # Save the image. OpenCV writes BGR, and imwrite reports failure through its
    # return value, so check it before claiming success.
    out_path = os.path.join(output_dir, f"avg_gradcam_{outcome}.jpg")
    if cv2.imwrite(out_path, cv2.cvtColor(avg_image, cv2.COLOR_RGB2BGR)):
        print(f"✅ Saved: {out_path}")
    else:
        print(f"⚠️ Failed to save: {out_path}")