# In [1]:

import os
import numpy as np
import matplotlib.pyplot as plt
import cv2

# ==== Paths ====
# Inputs are absolute, machine-specific locations (under /Users/muhajav);
# adjust them before running this cell on another machine.
# Source photo from the aligned CelebA image folder.
image_path = "/Users/muhajav/Documents/TPSC/Gender Classification/data/celeba/img_align_celeba/000002.jpg"
# Explanation maps stored as .npy files.
saliency_path = "/Users/muhajav/Documents/TPSC/Gender Classification/Model Applied File/saliency_image01.npy"
lime_path = "/Users/muhajav/Documents/TPSC/Gender Classification/Model Applied File/lime_image01.npy"
# Figures are written to this directory, relative to the current working
# directory; exist_ok=True makes re-running the cell safe.
output_dir = "results"
os.makedirs(output_dir, exist_ok=True)

# ==== Load Data ====
# Load original image. cv2.imread reports failure by returning None instead
# of raising, so the result has to be checked explicitly.
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Could not load image at {image_path}")
# OpenCV decodes to BGR channel order; convert to RGB for matplotlib.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Load heatmaps
saliency_map = np.load(saliency_path)

# Load LIME (stored as dict with key 0).
# NOTE(security): allow_pickle=True unpickles arbitrary Python objects; only
# load files that come from a trusted source.
lime_loaded = np.load(lime_path, allow_pickle=True)
# Unwrap only single-element arrays. Calling .item() on anything larger
# raises a generic numpy ValueError before the structure check below gets a
# chance to report which file is malformed and what it contained.
lime_raw = lime_loaded.item() if lime_loaded.size == 1 else lime_loaded
if isinstance(lime_raw, dict) and 0 in lime_raw:
    lime_map = lime_raw[0]
else:
    raise ValueError(f"Unexpected structure in {lime_path}: {type(lime_raw)}")

# ==== Normalize heatmaps ====
def normalize_heatmap(hmap):
    """Rescale a heatmap linearly so its values span [0, 1].

    NaN entries are replaced through ``np.nan_to_num`` before scaling. The
    caller's array is never modified; a new array is returned.

    Args:
        hmap: Array-like of real numbers. Integer and boolean inputs are
            promoted to ``float64``; floating inputs keep their dtype.

    Returns:
        Array of the same shape whose minimum is 0 and whose maximum is 1.
        A constant map comes back as all zeros, and an empty map is returned
        as-is (no scaling is possible).
    """
    hmap = np.nan_to_num(hmap)  # copy; NaN -> 0, +/-inf -> large finite values
    # Integer/bool arrays cannot store the fractional result, and the in-place
    # subtraction/division used previously raised on them. Promote to float.
    if not np.issubdtype(hmap.dtype, np.inexact):
        hmap = hmap.astype(np.float64)
    if hmap.size == 0:
        return hmap  # min()/max() would raise on an empty array
    hmap = hmap - hmap.min()
    peak = hmap.max()
    if peak > 0:  # skip the division for constant maps (avoids 0/0)
        hmap = hmap / peak
    return hmap

# Bring both explanation maps onto a common [0, 1] scale before plotting.
saliency_map = normalize_heatmap(saliency_map)
lime_map = normalize_heatmap(lime_map)

# ==== Visualization ====
# 1. Side-by-side comparison: one row, three panels.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# (data, title, colormap) per panel; the photo is drawn without a colormap.
side_by_side_panels = (
    (img, "Original Image", None),
    (saliency_map, "Saliency Heatmap (Yudi)", "jet"),
    (lime_map, "LIME Explanation (Zidni)", "jet"),
)
for panel_ax, (panel_data, panel_title, panel_cmap) in zip(axes, side_by_side_panels):
    panel_ax.imshow(panel_data, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis("off")

fig.tight_layout()
side_by_side_png = os.path.join(output_dir, "comparison_side_by_side2.png")
fig.savefig(side_by_side_png, dpi=300)
plt.close(fig)

# 2. Overlay heatmaps on original
def overlay_heatmap(image, heatmap, alpha=0.5, cmap="jet"):
    """Alpha-blend a colormapped heatmap over an image.

    Args:
        image: Image array; it is divided by 255 during blending, so values
            are presumably in the 0-255 range (TODO confirm with callers).
        heatmap: 2-D map, resized here to the image's height and width before
            being passed through the colormap.
        alpha: Weight of the heatmap colors; the image gets ``1 - alpha``.
        cmap: Name of the matplotlib colormap applied to the heatmap.

    Returns:
        The blended array, clipped to [0, 1].
    """
    # cv2.resize takes the target size as (width, height).
    target_size = (image.shape[1], image.shape[0])
    resized = cv2.resize(heatmap, target_size)

    # The colormap returns RGBA; keep only the RGB channels.
    rgba = plt.get_cmap(cmap)(resized)
    rgb = rgba[:, :, :3]

    blended = (1 - alpha) * image / 255.0 + alpha * rgb
    return np.clip(blended, 0, 1)

# Blend each normalized explanation map onto the RGB photo.
saliency_overlay = overlay_heatmap(img, saliency_map)
lime_overlay = overlay_heatmap(img, lime_map)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# (image, title) per panel, left to right.
overlay_panels = (
    (img, "Original Image"),
    (saliency_overlay, "Overlay: Saliency (Yudi)"),
    (lime_overlay, "Overlay: LIME (Zidni)"),
)
for panel_ax, (panel_img, panel_title) in zip(axes, overlay_panels):
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title)
    panel_ax.axis("off")

fig.tight_layout()
overlay_png = os.path.join(output_dir, "comparison_overlay2.png")
fig.savefig(overlay_png, dpi=300)
plt.close(fig)

print(f"Visualizations saved in '{output_dir}/'")



# Output: Visualizations saved in 'results/'
