In [3]:
#!/usr/bin/env python3
"""
Display all images from the positive and negative sides of the
first shared component (cvPC1) between ViT and mouse brain.

Automatically copies required images into a local 'images/' folder
so the HTML works offline.
"""

import os, shutil, pickle, numpy as np
from PIL import Image

# ---------------------------------------------------------------
# CONFIG - paths and the brain area to visualise
# ---------------------------------------------------------------
AREA_NAME = "VISam"  # visual area; also used to build input/output names

# Inputs: cvPCA results for this area, ViT logits, and the source stimuli.
CVPCA_PATH = f"vit_{AREA_NAME}_cvpca_results.npz"
VIT_PATH = "/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_logits.pkl"
SRC_IMGS = "/home/maria/MITNeuralComputation/vit_embeddings/images"

# Outputs: one self-contained folder holding the page and its image copies.
OUT_DIR = f"cvPC1_gallery_{AREA_NAME}"
LOCAL_IMGS = os.path.join(OUT_DIR, "images")
OUT_HTML = os.path.join(OUT_DIR, "index.html")

# makedirs creates every missing parent, so OUT_DIR is created here as well.
os.makedirs(name=LOCAL_IMGS, exist_ok=True)

# ---------------------------------------------------------------
# LOAD DATA
# ---------------------------------------------------------------
print(f"🔹 Loading cvPCA results for {AREA_NAME} ...")
# np.load on an .npz returns an NpzFile that keeps the archive open; using it
# as a context manager closes the file once the arrays have been read out.
with np.load(CVPCA_PATH, allow_pickle=True) as res:
    brain_scores = res["brain_scores"]     # 2-D (indexed [:, 0] below); rows = images
    shared_frac  = res["shared_fraction"]  # presumably one entry per shared component

n_images = brain_scores.shape[0]
scores = brain_scores[:, 0]  # first shared component
print(f"Loaded {n_images} image scores for shared cvPC1.")

# ---------------------------------------------------------------
# LOAD SCENE LABELS
# ---------------------------------------------------------------
print("🔹 Loading ViT labels ...")

# NOTE: pickle.load can execute arbitrary code - only point VIT_PATH at trusted files.
with open(VIT_PATH, 'rb') as f:
    vit_logits = pickle.load(f)['natural_scenes']

# One logit row per scene is required; fewer rows would otherwise surface much
# later as a bare IndexError while building the HTML.
n_logit_rows = np.shape(vit_logits)[0]
if n_logit_rows < n_images:
    raise ValueError(
        f"ViT logits cover only {n_logit_rows} scenes but the cvPCA results "
        f"cover {n_images}; both files must describe the same image set."
    )
if n_logit_rows != n_images:
    print(f"⚠️  ViT logits have {n_logit_rows} rows but cvPCA has {n_images} "
          f"images; only the first {n_images} rows are used for labels.")

try:
    # torchvision is needed only for human-readable ImageNet-1k class names;
    # assumes the logits use the same class order - TODO confirm.
    from torchvision.models import ViT_B_16_Weights
    class_names = ViT_B_16_Weights.IMAGENET1K_V1.meta["categories"]
except Exception as err:  # best effort: any import/lookup failure -> generic names
    print(f"⚠️  ImageNet class names unavailable ({err!r}); using generic labels.")
    class_names = [f"class_{i}" for i in range(np.shape(vit_logits)[1])]

top1_idx = np.argmax(vit_logits, axis=1)
scene_labels = [class_names[i] for i in top1_idx]
image_ids = [f"scene_{i:03d}" for i in range(n_images)]

# ---------------------------------------------------------------
# SPLIT IMAGES BY THE SIGN OF THEIR cvPC1 SCORE
# ---------------------------------------------------------------
# Split on the sign of the score rather than at the median, so the
# "positive" gallery only holds images that project positively on cvPC1
# (a score of exactly 0 counts as positive) and vice versa. Both sides are
# ordered strongest projection first, moving toward zero.
sorted_idx = np.argsort(scores)                 # ascending: most negative first
sorted_scores = scores[sorted_idx]
neg_idx = sorted_idx[sorted_scores < 0]         # most negative -> nearest zero
pos_idx = sorted_idx[sorted_scores >= 0][::-1]  # most positive -> nearest zero

# ---------------------------------------------------------------
# COPY USED IMAGES LOCALLY
# ---------------------------------------------------------------
print("🔹 Copying required images locally ...")
used_idx = np.concatenate([pos_idx, neg_idx])
n_copied, missing = 0, []
for idx in used_idx:
    fname = f"{image_ids[idx]}.png"
    src_path = os.path.join(SRC_IMGS, fname)
    dst_path = os.path.join(LOCAL_IMGS, fname)
    # Count only files that were actually copied, so the summary is truthful.
    if not os.path.exists(src_path):
        missing.append(fname)
        continue
    shutil.copy(src_path, dst_path)
    n_copied += 1

print(f"✅ Copied {n_copied} images into '{LOCAL_IMGS}/'")
if missing:
    print(f"⚠️  {len(missing)} source image(s) not found in '{SRC_IMGS}' "
          f"(e.g. {missing[0]}); they were not copied.")

# ---------------------------------------------------------------
# BUILD HTML
# ---------------------------------------------------------------
# Only `escape` is imported: the name `html` below holds the page lines.
from html import escape

print("🔹 Building HTML gallery ...")


def _image_card(idx):
    """Return the HTML tile for image ``idx``, or None if it has no local copy.

    The <img> uses a path relative to OUT_DIR ('images/...'), so the page
    works offline from the copied files.
    """
    fname = f"{image_ids[idx]}.png"
    if not os.path.exists(os.path.join(LOCAL_IMGS, fname)):
        return None
    # Escape class names before embedding them in markup and attributes.
    label = escape(str(scene_labels[idx]))
    return (f"<div class='imgbox'>"
            f"<img src='images/{fname}' alt='{label}'/>"
            f"<div class='label'>{label}<br>"
            f"({scores[idx]:.2f})</div></div>")


html = [
    "<html><head><meta charset='utf-8'/>",
    # HTML entities (&mdash;, &harr;) keep the page text independent of how
    # this source file happens to be encoded.
    f"<title>cvPC1 Image Gallery &mdash; ViT &harr; {AREA_NAME}</title>",
    "<style>",
    "body { font-family: sans-serif; background: #fafafa; margin: 40px; }",
    "h1, h2 { text-align: center; }",
    ".grid { display: flex; flex-wrap: wrap; justify-content: center; }",
    ".imgbox { margin: 8px; text-align: center; width: 180px; }",
    "img { width: 160px; height: 160px; object-fit: cover; border-radius: 8px;",
    " box-shadow: 0 2px 6px rgba(0,0,0,0.2); }",
    ".label { font-size: 10px; color: #333; }",
    "</style></head><body>",
    f"<h1>Shared cvPC1 &mdash; ViT &harr; {AREA_NAME}</h1>",
    f"<h3>({shared_frac[0]*100:.2f}% shared variance)</h3>",
]

# One titled grid per side of the component; images without a local copy are skipped.
sections = (
    ("Positive direction", "green", pos_idx),
    ("Negative direction", "red", neg_idx),
)
for title, color, indices in sections:
    html.append(f"<h2 style='color:{color};'>{title}</h2><div class='grid'>")
    html.extend(card for card in map(_image_card, indices) if card is not None)
    html.append("</div>")
html.append("</body></html>")

# Write UTF-8 explicitly so the bytes match the <meta charset='utf-8'>
# declaration regardless of the platform's default encoding.
with open(OUT_HTML, "w", encoding="utf-8") as f:
    f.write("\n".join(html))

print(f"✅ HTML gallery created: {OUT_HTML}")
print("🖼️  Open this file in your browser — all images should now display correctly.")


🔹 Loading cvPCA results for VISam ...
Loaded 118 image scores for shared cvPC1.
🔹 Loading ViT labels ...
🔹 Copying required images locally ...
✅ Copied 118 images into 'cvPC1_gallery_VISam/images/'
🔹 Building HTML gallery ...
✅ HTML gallery created: cvPC1_gallery_VISam/index.html
🖼️  Open this file in your browser — all images should now display correctly.
