# In [3]:
#!/usr/bin/env python3
"""
PCA visualization of ViT softmax+CLR embeddings:
 - U (image projections): histogram of image scores on a chosen PC
 - V (class loadings): the top_k most positive and most negative ImageNet classes per PC
 - Optional thumbnails of the highest- and lowest-scoring images per PC
"""

import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.special import softmax
from skbio.stats.composition import clr
from PIL import Image
import plotly.graph_objects as go

# ---------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------
imgs_path = '/home/maria/MITNeuralComputation/vit_embeddings/images'
path = '/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_logits.pkl'
top_k = 10          # classes shown at each end (most positive / most negative) per PC
n_components = 15   # number of principal components to fit

# ---------------------------------------------------------------
# LOAD DATA
# ---------------------------------------------------------------
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(path, 'rb') as f:
    vit_dict = pickle.load(f)['natural_scenes']

# The stored object may be either {image_id: logits_vector} or an already
# stacked (n_images, n_classes) array. `list(some_dict)` yields only the KEYS,
# so the two cases are handled separately to keep ids and logit rows aligned.
if isinstance(vit_dict, dict):
    image_ids = list(vit_dict)
    embeddings = np.stack([np.asarray(vit_dict[k]) for k in image_ids])
else:
    image_ids = None
    embeddings = np.asarray(vit_dict)

print("Embeddings shape:", embeddings.shape)
# Raise instead of assert: asserts are stripped under `python -O`.
if embeddings.ndim != 2:
    raise ValueError(
        f"Expected a 2-D (n_images, n_classes) logit matrix, got shape {embeddings.shape}"
    )

if image_ids is None:
    # No ids stored alongside the logits: fall back to row positions.
    # NOTE(review): assumes image files are named by row index (e.g. "0.png") — confirm.
    image_ids = list(range(embeddings.shape[0]))

# Softmax → CLR
X = softmax(embeddings, axis=1)
# CLR(softmax(z)) = log softmax(z) - mean(log softmax(z)) = z - mean(z) exactly,
# so compute it straight from the logits. This avoids the bias of an epsilon
# floor on tiny probabilities and any log(0). It also never squeezes
# a single-row input down to 1-D.
X_clr = embeddings - embeddings.mean(axis=1, keepdims=True)

# ---------------------------------------------------------------
# PCA
# ---------------------------------------------------------------
# Fit on the CLR-transformed matrix; random_state=0 makes randomized solvers reproducible.
pca = PCA(random_state=0, n_components=n_components)
U = pca.fit_transform(X_clr)                 # image scores: (n_samples, n_components)
V = pca.components_.transpose()              # shape: (n_classes, n_components)
expl_var = pca.explained_variance_ratio_     # fraction of total variance per component

print("PCA done. {} components explain {:.1f}% variance".format(n_components, 100 * expl_var.sum()))

# ---------------------------------------------------------------
# Helper to load ImageNet class names
# (You can replace this with your own list if needed)
# NOTE(review): assumes the ViT logits follow the standard ImageNet-1k index
# order that torchvision's category list uses — confirm for your checkpoint.
try:
    from torchvision.models import vit_b_16, ViT_B_16_Weights
    class_names = list(ViT_B_16_Weights.IMAGENET1K_V1.meta["categories"])
except Exception as err:  # best-effort: torchvision may be missing, too old, or broken
    print(f"Could not load ImageNet class names from torchvision ({err!r}); using placeholders.")
    class_names = [f"class_{i}" for i in range(V.shape[0])]

# Names are looked up by row index of V, so there must be exactly one per class.
# A mismatched list would silently mislabel bars or raise IndexError later.
if len(class_names) != V.shape[0]:
    print(f"Got {len(class_names)} class names for {V.shape[0]} classes; using placeholders.")
    class_names = [f"class_{i}" for i in range(V.shape[0])]

# ---------------------------------------------------------------
# Function to plot top ±10 labels for one PC
# ---------------------------------------------------------------
def plot_pc_semantics_plotly(V, class_names, pc_idx, top_k=10):
    v = V[:, pc_idx]
    pos_idx = np.argsort(v)[-top_k:][::-1]
    neg_idx = np.argsort(v)[:top_k]
    pos_names = [class_names[i] for i in pos_idx]
    neg_names = [class_names[i] for i in neg_idx]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=v[pos_idx],
        y=pos_names,
        orientation='h',
        marker_color='green',
        name='Positive'
    ))
    fig.add_trace(go.Bar(
        x=v[neg_idx],
        y=neg_names,
        orientation='h',
        marker_color='red',
        name='Negative'
    ))
    fig.update_layout(
        title=f"PC{pc_idx+1} — top ±{top_k} ImageNet categories",
        barmode='overlay',
        xaxis_title="Loading weight",
        yaxis_title="Class",
        template="plotly_white"
    )
    return fig

# ---------------------------------------------------------------
# Function to show histogram of U[:, pc_idx]
# ---------------------------------------------------------------
def plot_pc_distribution_plotly(U, pc_idx, nbins=25):
    """Histogram of every image's score on one principal component.

    Args:
        U: (n_images, n_components) matrix of PCA scores.
        pc_idx: 0-based index of the component (column of ``U``).
        nbins: Forwarded to plotly as ``nbinsx``. Plotly treats this as the
            maximum number of bins, so fewer may be drawn. The default of 25
            matches the previously hard-coded value.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()
    fig.add_trace(go.Histogram(
        x=U[:, pc_idx],
        nbinsx=nbins,
        marker_color='steelblue'
    ))
    fig.update_layout(
        title=f"Distribution of image projections on PC{pc_idx+1}",
        xaxis_title="PC score",
        yaxis_title="Count",
        template="plotly_white"
    )
    return fig

# ---------------------------------------------------------------
# Example: visualize one PC
# ---------------------------------------------------------------
# Component to inspect: 0-based, any value in 0..n_components-1.
pc_idx = 0
fig_labels = plot_pc_semantics_plotly(V, class_names, pc_idx, top_k=top_k)
fig_images = plot_pc_distribution_plotly(U, pc_idx=pc_idx)

# Render the class-loading bars first, then the score histogram.
fig_labels.show()
fig_images.show()

# ---------------------------------------------------------------
# Optional: save top / bottom scoring images per PC
# ---------------------------------------------------------------
def show_top_images(U, pc_idx, image_ids, imgs_path, top_n=3):
    scores = U[:, pc_idx]
    top_idx = np.argsort(scores)[-top_n:][::-1]
    bot_idx = np.argsort(scores)[:top_n]

    fig, axes = plt.subplots(2, top_n, figsize=(3*top_n, 6))
    for i, idx in enumerate(top_idx):
        img_path = os.path.join(imgs_path, f"{image_ids[idx]}.png")
        if os.path.exists(img_path):
            axes[0, i].imshow(Image.open(img_path))
        axes[0, i].set_title(f"Top {i+1} ({scores[idx]:.2f})")
        axes[0, i].axis("off")

    for i, idx in enumerate(bot_idx):
        img_path = os.path.join(imgs_path, f"{image_ids[idx]}.png")
        if os.path.exists(img_path):
            axes[1, i].imshow(Image.open(img_path))
        axes[1, i].set_title(f"Bottom {i+1} ({scores[idx]:.2f})")
        axes[1, i].axis("off")

    plt.suptitle(f"Example images for PC{pc_idx+1}")
    plt.tight_layout()
    plt.show()

# Uncomment if you have image files:
# show_top_images(U, pc_idx, image_ids, imgs_path)


# Output (captured notebook run):
# Embeddings shape: (118, 1000)
# PCA done. 15 components explain 72.0% variance
