In [1]:
#!/usr/bin/env python3
"""
Cross-domain cvPCA visualization atlas (ViT ↔ Mouse Brain).

For each shared component:
 - Plotly bar chart of top ±10 images
 - Matplotlib mosaic of top/bottom 10 image thumbnails
 - Builds an index.html with all components

Author: Maria + Pläku 🐾
"""

import os, numpy as np, matplotlib.pyplot as plt, plotly.graph_objects as go
from PIL import Image

# ---------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------
# Brain area whose cvPCA results are visualized; it is embedded in the
# input file name, the output directory name and the figure titles.
AREA_NAME   = "VISam"
CVPCA_PATH  = f"vit_{AREA_NAME}_cvpca_results.npz"
IMGS_PATH   = "/home/maria/MITNeuralComputation/vit_embeddings/images"
OUT_DIR     = f"cvpca_atlas_{AREA_NAME}"
TOP_K       = 10  # number of images shown at each extreme of a component
os.makedirs(OUT_DIR, exist_ok=True)

# ---------------------------------------------------------------
# LOAD RESULTS
# ---------------------------------------------------------------
print(f"🔹 Loading cvPCA results for {AREA_NAME} ...")
# np.load on an .npz returns a lazy NpzFile that keeps the archive open;
# the context manager closes it once the arrays have been read into memory.
# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code -- only load result files from a trusted source.
with np.load(CVPCA_PATH, allow_pickle=True) as res:
    S            = res["singular_values"]
    shared_frac  = res["shared_fraction"]
    vit_scores   = res["vit_scores"]
    brain_scores = res["brain_scores"]

# brain_scores is indexed as [image, component] by the plotting functions.
n_images, n_comps = brain_scores.shape
print(f"{n_comps} shared components, {n_images} images")

# ---------------------------------------------------------------
# IMAGE IDS AND LABELS (use scene_000.png ...)
# ---------------------------------------------------------------
# assumes row i of the score matrices corresponds to scene_{i:03d}.png -- TODO confirm
image_ids = [f"scene_{i:03d}" for i in range(n_images)]
scene_labels = image_ids  # fallback if no semantic labels available

# ---------------------------------------------------------------
# FUNCTIONS
# ---------------------------------------------------------------
def plot_shared_bar(scores, pc_idx, labels, top_k=10):
    """Build a Plotly bar chart of the extreme image loadings on one component.

    Column ``pc_idx`` of ``scores`` is mean-centered; the ``top_k`` largest
    centered values are drawn as green horizontal bars and the ``top_k``
    smallest as red ones.

    Args:
        scores: 2-D array indexed as ``scores[image, component]``.
        pc_idx: Zero-based index of the component (column) to plot.
        labels: Sequence with one tick label per row of ``scores``.
        top_k: Number of bars per extreme. Values <= 0 draw no bars.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.

    Note:
        The title reads the module-level ``shared_frac`` array.
    """
    v = scores[:, pc_idx]
    v = v - np.mean(v)

    # Sort once. Slicing the reversed order with [:k] (instead of [-k:])
    # yields an empty selection for k == 0 rather than every row.
    k = max(int(top_k), 0)
    order = np.argsort(v)
    pos_idx = order[::-1][:k]
    neg_idx = order[:k]

    pos_names = [str(labels[i]) for i in pos_idx]
    neg_names = [str(labels[i]) for i in neg_idx]

    # Give every bar its own numeric row and attach the label as tick text.
    # On a categorical axis, identical labels share one row and hide each other.
    pos_rows = list(range(len(pos_names)))
    neg_rows = list(range(len(pos_names), len(pos_names) + len(neg_names)))

    fig = go.Figure()
    fig.add_trace(go.Bar(x=v[pos_idx],
                         y=pos_rows,
                         orientation="h",
                         marker_color="green",
                         hovertext=pos_names,
                         name="Positive"))
    fig.add_trace(go.Bar(x=v[neg_idx],
                         y=neg_rows,
                         orientation="h",
                         marker_color="red",
                         hovertext=neg_names,
                         name="Negative"))
    fig.update_yaxes(tickmode="array",
                     tickvals=pos_rows + neg_rows,
                     ticktext=pos_names + neg_names)

    # Symmetric x-range around zero; fall back to 1.0 for an empty or
    # all-zero column so the range is never degenerate.
    xmax = (float(np.max(np.abs(v))) if v.size else 0.0) or 1.0
    fig.update_xaxes(range=[-xmax, xmax])
    fig.update_layout(
        title=(f"Shared cvPC{pc_idx+1} "
               f"({shared_frac[pc_idx]*100:.2f}% shared variance)"),
        barmode="overlay",
        xaxis_title="Image loading (centered)",
        yaxis_title="Image",
        template="plotly_white",
        height=600
    )
    return fig

def save_image_mosaic(scores, pc_idx, image_ids, imgs_path, out_dir, top_n=10):
    """Save a 2 x ``top_n`` thumbnail mosaic for one shared component.

    The first row shows the images with the largest scores on component
    ``pc_idx`` (descending), the second row those with the smallest scores
    (ascending). The figure is written to
    ``<out_dir>/cvPC<pc_idx+1, zero-padded to 2 digits>_images.png``.

    Args:
        scores: 2-D array indexed as ``scores[image, component]``.
        pc_idx: Zero-based index of the component (column) to display.
        image_ids: Sequence of file stems; ``<imgs_path>/<id>.png`` is loaded.
        imgs_path: Directory containing the PNG thumbnails.
        out_dir: Directory the mosaic PNG is written to.
        top_n: Number of images per row; must be at least 1.

    Raises:
        ValueError: If ``top_n`` is smaller than 1.

    Note:
        Missing image files leave their panel blank (title only). The
        figure title reads the module-level ``shared_frac`` array.
    """
    if top_n < 1:
        raise ValueError("top_n must be a positive integer")

    v = scores[:, pc_idx]
    order = np.argsort(v)
    rows = (
        ("Top", order[::-1][:top_n]),
        ("Bottom", order[:top_n]),
    )

    # squeeze=False keeps `axes` two-dimensional even when top_n == 1.
    fig, axes = plt.subplots(2, top_n, figsize=(2.5 * top_n, 5), squeeze=False)
    try:
        for row, (tag, indices) in enumerate(rows):
            for col, idx in enumerate(indices):
                ax = axes[row, col]
                path = os.path.join(imgs_path, f"{image_ids[idx]}.png")
                if os.path.exists(path):
                    # imshow copies the pixel data, so the file can be
                    # closed as soon as the call returns.
                    with Image.open(path) as img:
                        ax.imshow(img)
                ax.set_title(f"{tag} {col+1}\n({v[idx]:.2f})", fontsize=8)

        # Hide the frames of all panels, including unused ones when fewer
        # than top_n images exist.
        for ax in axes.ravel():
            ax.axis("off")

        fig.suptitle(f"Shared cvPC{pc_idx+1} ({shared_frac[pc_idx]*100:.2f}% var)",
                     fontsize=12)
        fig.tight_layout()
        out_path = os.path.join(out_dir, f"cvPC{pc_idx+1:02d}_images.png")
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if loading or saving fails.
        plt.close(fig)

# ---------------------------------------------------------------
# MAIN LOOP
# ---------------------------------------------------------------
# Render only the leading components (at most 10) to keep the atlas short.
n_rendered = min(10, n_comps)
for pc_idx in range(n_rendered):
    print(f"Saving component {pc_idx+1}/{n_comps} ...")

    # Interactive bar chart of the extreme loadings -> standalone HTML file.
    bars_path = os.path.join(OUT_DIR, f"cvPC{pc_idx+1:02d}_bars.html")
    fig = plot_shared_bar(brain_scores, pc_idx, scene_labels, top_k=TOP_K)
    fig.write_html(bars_path)

    # Thumbnail mosaic of the same component -> PNG next to the HTML file.
    save_image_mosaic(
        brain_scores,
        pc_idx,
        image_ids,
        IMGS_PATH,
        OUT_DIR,
        top_n=TOP_K,
    )

# ---------------------------------------------------------------
# CUMULATIVE VARIANCE PLOT
# ---------------------------------------------------------------
# Running total of the per-component shared fractions, plotted in percent.
cum_shared = np.cumsum(shared_frac)

# Use an explicit figure/axes pair so the figure that gets saved and closed
# is exactly the one created here, regardless of pyplot's "current figure".
cum_fig, cum_ax = plt.subplots(figsize=(6, 4))
try:
    cum_ax.plot(np.arange(1, len(cum_shared) + 1), cum_shared * 100, "o-")
    cum_ax.set_xlabel("Shared component")
    cum_ax.set_ylabel("Cumulative shared variance (%)")
    cum_ax.set_title(f"ViT ↔ {AREA_NAME} cvPCA cumulative variance")
    cum_ax.grid(True)
    cum_fig.tight_layout()
    cum_fig.savefig(os.path.join(OUT_DIR, "cumulative_shared_variance.png"), dpi=150)
finally:
    # Free the figure even if plotting or saving raised.
    plt.close(cum_fig)

# ---------------------------------------------------------------
# BUILD HTML INDEX
# ---------------------------------------------------------------
print("🔹 Building index.html ...")
index_path = os.path.join(OUT_DIR, "index.html")

# Collect the bar-chart pages written above. Only names of the exact form
# cvPC<digits>_bars.html are accepted, so an unrelated *_bars.html file cannot
# crash the int() conversion below; sorting is numeric, not lexicographic.
# NOTE: the directory is scanned, so pages left over from earlier runs are
# listed as well.
bar_files = sorted(
    (
        name for name in os.listdir(OUT_DIR)
        if name.startswith("cvPC")
        and name.endswith("_bars.html")
        and name[len("cvPC"):-len("_bars.html")].isdecimal()
    ),
    key=lambda name: int(name[len("cvPC"):-len("_bars.html")]),
)
pcs = [f.split("_")[0] for f in bar_files]

# Page header: inline CSS, headline and the cumulative-variance figure.
html_parts = [
    "<html><head><meta charset='utf-8'/>",
    f"<title>ViT ↔ {AREA_NAME} cvPCA Atlas</title>",
    "<style>body{font-family:sans-serif;background:#f9f9f9;margin:40px;}"
    "iframe{border:none;width:100%;height:500px;}"
    "img{width:100%;border-radius:8px;box-shadow:0 2px 4px rgba(0,0,0,0.2);}"
    ".panel{background:white;padding:20px;border-radius:10px;margin-bottom:40px;"
    "box-shadow:0 0 10px rgba(0,0,0,0.05);}</style></head><body>",
    f"<h1>Shared Representational Atlas — ViT ↔ {AREA_NAME}</h1>",
    f"<p>First {len(pcs)} shared components visualized.</p>",
    "<img src='cumulative_shared_variance.png' style='max-width:600px;'/>"
]

# One panel per component: the bar chart in an iframe plus the image mosaic.
for pc_prefix in pcs:
    num = pc_prefix.replace("cvPC", "")
    bar_html = f"{pc_prefix}_bars.html"
    img_png = f"{pc_prefix}_images.png"
    html_parts.append(f"<div class='panel' id='{pc_prefix}'>"
                      f"<h2>Shared cvPC {int(num)}</h2>"
                      f"<iframe src='{bar_html}'></iframe>"
                      f"<img src='{img_png}' alt='cvPC{num} images'/>"
                      "</div>")

html_parts.append("</body></html>")
# The page declares charset utf-8 and contains non-ASCII characters, so it
# must be written as UTF-8 rather than in the platform default encoding.
with open(index_path, "w", encoding="utf-8") as f:
    f.write("\n".join(html_parts))

print(f"✅ HTML index created: {index_path}")


🔹 Loading cvPCA results for VISam ...
81 shared components, 118 images
Saving component 1/81 ...
Saving component 2/81 ...
Saving component 3/81 ...
Saving component 4/81 ...
Saving component 5/81 ...
Saving component 6/81 ...
Saving component 7/81 ...
Saving component 8/81 ...
Saving component 9/81 ...
Saving component 10/81 ...
🔹 Building index.html ...
✅ HTML index created: cvpca_atlas_VISam/index.html


In [2]:
#!/usr/bin/env python3
"""
Cross-domain cvPCA visualization atlas (ViT ↔ Mouse Brain) with scene labels.

For each shared component:
 - Plotly bar chart (top ±10 labeled scenes)
 - Matplotlib mosaic (top/bottom 10 images with labels)
 - HTML index with all components

Author: Maria + Pläku 🐾
"""

import os, pickle, numpy as np, matplotlib.pyplot as plt
from PIL import Image
import plotly.graph_objects as go

# ---------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------
# Brain area whose cvPCA results are visualized; it is embedded in the
# input file name, the output directory name and the figure titles.
AREA_NAME   = "VISam"
CVPCA_PATH  = f"vit_{AREA_NAME}_cvpca_results.npz"
VIT_PATH    = "/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_logits.pkl"
IMGS_PATH   = "/home/maria/MITNeuralComputation/vit_embeddings/images"
OUT_DIR     = f"cvpca_atlas_{AREA_NAME}"
TOP_K       = 10  # number of images shown at each extreme of a component
os.makedirs(OUT_DIR, exist_ok=True)

# ---------------------------------------------------------------
# LOAD cvPCA RESULTS
# ---------------------------------------------------------------
print(f"🔹 Loading cvPCA results for {AREA_NAME} ...")
# np.load on an .npz returns a lazy NpzFile that keeps the archive open;
# the context manager closes it once the arrays have been read into memory.
# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code -- only load result files from a trusted source.
with np.load(CVPCA_PATH, allow_pickle=True) as res:
    S            = res["singular_values"]
    shared_frac  = res["shared_fraction"]
    brain_scores = res["brain_scores"]
# brain_scores is indexed as [image, component] by the plotting functions.
n_images, n_comps = brain_scores.shape
print(f"{n_comps} shared components, {n_images} images")

# ---------------------------------------------------------------
# LOAD ViT LABELS (scene names)
# ---------------------------------------------------------------
print("🔹 Loading ViT class labels ...")
try:
    from torchvision.models import vit_b_16, ViT_B_16_Weights
    class_names = ViT_B_16_Weights.IMAGENET1K_V1.meta["categories"]
except Exception as err:
    # torchvision is optional here: fall back to generic names, but say so
    # instead of silently producing "class_<i>" labels.
    print(f"torchvision class names unavailable ({err!r}); using generic class_<i> labels.")
    class_names = [f"class_{i}" for i in range(1000)]

# NOTE(review): pickle.load can execute arbitrary code -- only open embedding
# files from a trusted source.
with open(VIT_PATH, 'rb') as f:
    vit_logits = pickle.load(f)['natural_scenes']

# Label every image with the name of its highest-scoring class.
# assumes the logit columns follow the same class order as the torchvision
# ImageNet-1k category list -- TODO confirm for the model that wrote VIT_PATH
top1_idx = np.argmax(vit_logits, axis=1)
scene_labels = [class_names[i] for i in top1_idx]
if len(scene_labels) < n_images:
    # The plotting code indexes scene_labels by image row; fail early with a
    # clear message instead of an IndexError deep inside a plot call.
    raise ValueError(
        f"Only {len(scene_labels)} scene labels for {n_images} images; "
        "the ViT logits and the cvPCA scores do not describe the same image set."
    )
# assumes row i of the score matrix corresponds to scene_{i:03d}.png -- TODO confirm
image_ids = [f"scene_{i:03d}" for i in range(n_images)]
print(f"✅ Loaded {len(scene_labels)} scene labels.")

# ---------------------------------------------------------------
# BAR CHART (Plotly)
# ---------------------------------------------------------------
def plot_shared_bar(scores, pc_idx, labels, top_k=10):
    """Build a Plotly bar chart of the extreme image loadings on one component.

    Column ``pc_idx`` of ``scores`` is mean-centered; the ``top_k`` largest
    centered values are drawn as green horizontal bars and the ``top_k``
    smallest as red ones, each labeled with the matching entry of ``labels``.

    Args:
        scores: 2-D array indexed as ``scores[image, component]``.
        pc_idx: Zero-based index of the component (column) to plot.
        labels: Sequence with one tick label per row of ``scores``. Labels
            may repeat (several images can share a class name).
        top_k: Number of bars per extreme. Values <= 0 draw no bars.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.

    Note:
        The title reads the module-level ``AREA_NAME``.
    """
    v = scores[:, pc_idx] - np.mean(scores[:, pc_idx])

    # Sort once. Slicing the reversed order with [:k] (instead of [-k:])
    # yields an empty selection for k == 0 rather than every row.
    k = max(int(top_k), 0)
    order = np.argsort(v)
    pos_idx = order[::-1][:k]
    neg_idx = order[:k]

    pos_names = [str(labels[i]) for i in pos_idx]
    neg_names = [str(labels[i]) for i in neg_idx]

    # Give every bar its own numeric row and attach the label as tick text.
    # On a categorical axis, images that share a class name would land on
    # the same row and hide each other.
    pos_rows = list(range(len(pos_names)))
    neg_rows = list(range(len(pos_names), len(pos_names) + len(neg_names)))

    fig = go.Figure()
    fig.add_trace(go.Bar(x=v[pos_idx], y=pos_rows, orientation="h",
                         marker_color="green", hovertext=pos_names,
                         name="Positive"))
    fig.add_trace(go.Bar(x=v[neg_idx], y=neg_rows, orientation="h",
                         marker_color="red", hovertext=neg_names,
                         name="Negative"))
    fig.update_yaxes(tickmode="array",
                     tickvals=pos_rows + neg_rows,
                     ticktext=pos_names + neg_names)

    # Symmetric x-range around zero; fall back to 1.0 for an empty or
    # all-zero column so the range is never degenerate.
    xmax = (float(np.max(np.abs(v))) if v.size else 0.0) or 1.0
    fig.update_xaxes(range=[-xmax, xmax])
    fig.update_layout(
        title=f"Shared cvPC{pc_idx+1} — top ±{top_k} labeled scenes ({AREA_NAME})",
        barmode="overlay",
        xaxis_title="Shared loading weight (centered)",
        yaxis_title="Scene label",
        template="plotly_white"
    )
    return fig

# ---------------------------------------------------------------
# IMAGE MOSAIC (Matplotlib)
# ---------------------------------------------------------------
def save_image_mosaic(scores, pc_idx, image_ids, imgs_path, labels, out_dir, top_n=10):
    """Save a labeled 2 x ``top_n`` thumbnail mosaic for one shared component.

    The first row shows the images with the largest scores on component
    ``pc_idx`` (descending), the second row those with the smallest scores
    (ascending). Each panel is titled with the image's label and score. The
    figure is written to
    ``<out_dir>/cvPC<pc_idx+1, zero-padded to 2 digits>_images.png``.

    Args:
        scores: 2-D array indexed as ``scores[image, component]``.
        pc_idx: Zero-based index of the component (column) to display.
        image_ids: Sequence of file stems; ``<imgs_path>/<id>.png`` is loaded.
        imgs_path: Directory containing the PNG thumbnails.
        labels: Sequence with one label per row of ``scores``.
        out_dir: Directory the mosaic PNG is written to.
        top_n: Number of images per row; must be at least 1.

    Raises:
        ValueError: If ``top_n`` is smaller than 1.

    Note:
        Missing image files leave their panel blank (title only). The
        figure title reads the module-level ``AREA_NAME`` and ``shared_frac``.
    """
    if top_n < 1:
        raise ValueError("top_n must be a positive integer")

    v = scores[:, pc_idx]
    order = np.argsort(v)
    rows = (order[::-1][:top_n], order[:top_n])  # (largest first, smallest first)

    # squeeze=False keeps `axes` two-dimensional even when top_n == 1.
    fig, axes = plt.subplots(2, top_n, figsize=(2.5 * top_n, 5), squeeze=False)
    try:
        for row, indices in enumerate(rows):
            for col, idx in enumerate(indices):
                ax = axes[row, col]
                path = os.path.join(imgs_path, f"{image_ids[idx]}.png")
                if os.path.exists(path):
                    # imshow copies the pixel data, so the file can be
                    # closed as soon as the call returns.
                    with Image.open(path) as img:
                        ax.imshow(img)
                label = labels[idx]
                ax.set_title(f"{label}\n({v[idx]:.2f})", fontsize=7)

        # Hide the frames of all panels, including unused ones when fewer
        # than top_n images exist.
        for ax in axes.ravel():
            ax.axis("off")

        fig.suptitle(f"Example images for shared cvPC{pc_idx+1} "
                     f"({AREA_NAME}, {shared_frac[pc_idx]*100:.2f}% shared var)",
                     fontsize=11)
        fig.tight_layout()
        out_path = os.path.join(out_dir, f"cvPC{pc_idx+1:02d}_images.png")
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if loading or saving fails.
        plt.close(fig)

# ---------------------------------------------------------------
# MAIN LOOP
# ---------------------------------------------------------------
# Render only the leading components (at most 10) to keep the atlas short.
n_rendered = min(10, n_comps)
for pc_idx in range(n_rendered):
    print(f"Saving cvPC{pc_idx+1}/{n_comps} ...")

    # Interactive bar chart of the extreme loadings -> standalone HTML file.
    bars_path = os.path.join(OUT_DIR, f"cvPC{pc_idx+1:02d}_bars.html")
    fig = plot_shared_bar(brain_scores, pc_idx, scene_labels, top_k=TOP_K)
    fig.write_html(bars_path)

    # Labeled thumbnail mosaic of the same component -> PNG next to it.
    save_image_mosaic(
        brain_scores,
        pc_idx,
        image_ids,
        IMGS_PATH,
        scene_labels,
        OUT_DIR,
        top_n=TOP_K,
    )

# ---------------------------------------------------------------
# CUMULATIVE SHARED VARIANCE
# ---------------------------------------------------------------
# Running total of the per-component shared fractions, plotted in percent.
cum_shared = np.cumsum(shared_frac)

# Use an explicit figure/axes pair so the figure that gets saved and closed
# is exactly the one created here, regardless of pyplot's "current figure".
cum_fig, cum_ax = plt.subplots(figsize=(6, 4))
try:
    cum_ax.plot(np.arange(1, len(cum_shared) + 1), cum_shared * 100, "o-")
    cum_ax.set_xlabel("Shared component")
    cum_ax.set_ylabel("Cumulative shared variance (%)")
    cum_ax.set_title(f"ViT ↔ {AREA_NAME} cvPCA cumulative variance")
    cum_ax.grid(True)
    cum_fig.tight_layout()
    cum_fig.savefig(os.path.join(OUT_DIR, "cumulative_shared_variance.png"), dpi=150)
finally:
    # Free the figure even if plotting or saving raised.
    plt.close(cum_fig)

# ---------------------------------------------------------------
# HTML INDEX
# ---------------------------------------------------------------
print("🔹 Building index.html ...")
index_path = os.path.join(OUT_DIR, "index.html")

# Collect the bar-chart pages written above. Only names of the exact form
# cvPC<digits>_bars.html are accepted, so an unrelated *_bars.html file cannot
# crash the int() conversion below; sorting is numeric, not lexicographic.
# NOTE: the directory is scanned, so pages left over from earlier runs are
# listed as well.
bar_files = sorted(
    (
        name for name in os.listdir(OUT_DIR)
        if name.startswith("cvPC")
        and name.endswith("_bars.html")
        and name[len("cvPC"):-len("_bars.html")].isdecimal()
    ),
    key=lambda name: int(name[len("cvPC"):-len("_bars.html")]),
)
pcs = [f.split("_")[0] for f in bar_files]

# Page header: inline CSS, headline and the cumulative-variance figure.
html_parts = [
    "<html><head><meta charset='utf-8'/>",
    f"<title>ViT ↔ {AREA_NAME} cvPCA Atlas</title>",
    "<style>body{font-family:sans-serif;background:#f9f9f9;margin:40px;}"
    "iframe{border:none;width:100%;height:500px;}"
    "img{width:100%;border-radius:8px;box-shadow:0 2px 4px rgba(0,0,0,0.2);}"
    ".panel{background:white;padding:20px;border-radius:10px;margin-bottom:40px;"
    "box-shadow:0 0 10px rgba(0,0,0,0.05);}</style></head><body>",
    f"<h1>Shared Representational Atlas — ViT ↔ {AREA_NAME}</h1>",
    f"<p>First {len(pcs)} shared components visualized.</p>",
    "<img src='cumulative_shared_variance.png' style='max-width:600px;'/>"
]

# One panel per component: the bar chart in an iframe plus the image mosaic.
for pc_prefix in pcs:
    num = pc_prefix.replace("cvPC", "")
    bar_html = f"{pc_prefix}_bars.html"
    img_png = f"{pc_prefix}_images.png"
    html_parts.append(f"<div class='panel' id='{pc_prefix}'>"
                      f"<h2>Shared cvPC {int(num)}</h2>"
                      f"<iframe src='{bar_html}'></iframe>"
                      f"<img src='{img_png}' alt='cvPC{num} images'/>"
                      "</div>")

html_parts.append("</body></html>")
# The page declares charset utf-8 and contains non-ASCII characters, so it
# must be written as UTF-8 rather than in the platform default encoding.
with open(index_path, "w", encoding="utf-8") as f:
    f.write("\n".join(html_parts))

print(f"✅ HTML index created: {index_path}")


🔹 Loading cvPCA results for VISam ...
81 shared components, 118 images
🔹 Loading ViT class labels ...
✅ Loaded 118 scene labels.
Saving cvPC1/81 ...
Saving cvPC2/81 ...
Saving cvPC3/81 ...
Saving cvPC4/81 ...
Saving cvPC5/81 ...
Saving cvPC6/81 ...
Saving cvPC7/81 ...
Saving cvPC8/81 ...
Saving cvPC9/81 ...
Saving cvPC10/81 ...
🔹 Building index.html ...
✅ HTML index created: cvpca_atlas_VISam/index.html
