In [None]:
!pip install --quiet flax optax einops pillow transformers matplotlib tqdm scikit-learn seaborn

In [None]:

# Copyright 2025 The Bonsai AI Authors.
# Licensed under the Apache License, Version 2.0 (the "License");

import inspect
import os, pickle, random
import numpy as np
import jax, jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from sklearn.manifold import TSNE
import seaborn as sns
from PIL import Image

from bonsai.models.clip_jax.modeling import CLIPModel
from bonsai.models.clip_jax.config import CLIPConfig
from bonsai.models.clip_jax.utils.preprocess import preprocess_image, tokenize_text

# Paths — prefer an explicit override (ADE20K_PATH), then a local copy, then the Kaggle mount.
_ade_candidates = [
    p
    for p in (
        os.environ.get("ADE20K_PATH"),
        "datasets/ADEChallengeData2016",
        "/kaggle/input/ade20k-dataset/ADEChallengeData2016",
    )
    if p
]
ADE_PATH = next((p for p in _ade_candidates if os.path.exists(p)), None)
if ADE_PATH is None:
    # Fail here with a clear message instead of several cells later.
    raise FileNotFoundError(f"ADE20K dataset not found; looked in: {_ade_candidates}")

# Trained parameters; override with CLIP_CKPT_PATH.
CKPT_PATH = os.environ.get("CLIP_CKPT_PATH", "clip_jax/ckpts/clip_ade20k_epoch1.pkl")

cfg = CLIPConfig()
model = CLIPModel(cfg)

# Load model parameters
if not os.path.exists(CKPT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CKPT_PATH}")

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load checkpoints you produced yourself or obtained from a trusted source.
with open(CKPT_PATH, "rb") as f:
    params = pickle.load(f)

print("✅ Loaded CLIP-JAX checkpoint successfully!")

In [None]:
# Load the ADE20K scene label mapping: each usable line is "<image_id> <scene_class>"
# (extra tokens are ignored, lines with fewer than two tokens are skipped).
scene_txt = os.path.join(ADE_PATH, "sceneCategories.txt")
if not os.path.exists(scene_txt):
    raise FileNotFoundError(f"sceneCategories.txt not found in ADE20K dataset ({scene_txt}).")

scene_labels = {}
with open(scene_txt, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.split()  # split() with no args also strips surrounding whitespace
        if len(parts) >= 2:
            scene_labels[parts[0]] = parts[1]

if not scene_labels:
    # An empty mapping would otherwise only surface later as an opaque sampling error.
    raise ValueError(f"No '<image_id> <scene_class>' entries parsed from {scene_txt}")

all_classes = sorted(set(scene_labels.values()))
print(f"🎯 Found {len(all_classes)} unique ADE20K scene classes")

In [None]:

# Randomly choose scene classes for visualization. The choice is seeded so re-runs
# pick the same classes, and it draws only from classes that actually have
# validation images on disk (otherwise a class would get a text anchor but no images).
SEED = 42
N_CLASSES = 5        # number of scene classes to visualize
IMGS_PER_CLASS = 10  # max validation images kept per class

val_dir = os.path.join(ADE_PATH, "images", "validation")
if not os.path.isdir(val_dir):
    raise FileNotFoundError(f"ADE20K validation image directory not found: {val_dir}")

# One directory listing instead of an os.path.exists() stat for every label entry.
val_files = set(os.listdir(val_dir))

# Group existing validation images by scene class, preserving sceneCategories.txt order.
val_paths_by_class = {}
for img_id, cls in scene_labels.items():
    fname = f"{img_id}.jpg"
    if fname in val_files:
        val_paths_by_class.setdefault(cls, []).append(os.path.join(val_dir, fname))

candidate_classes = sorted(val_paths_by_class)
if not candidate_classes:
    raise ValueError(f"No validation images matching sceneCategories.txt found in {val_dir}")

rng = random.Random(SEED)  # local RNG: does not disturb the global `random` state
selected_classes = rng.sample(candidate_classes, min(N_CLASSES, len(candidate_classes)))
print(f"🎨 Selected classes: {selected_classes}")

# Collect up to IMGS_PER_CLASS validation images per selected class,
# with txt_labels[i] giving the scene class of img_paths[i].
img_paths, txt_labels = [], []
for cls in selected_classes:
    matched = val_paths_by_class[cls][:IMGS_PER_CLASS]
    img_paths.extend(matched)
    txt_labels.extend([cls] * len(matched))

print(f"📸 Collected {len(img_paths)} images from {len(selected_classes)} classes")

In [None]:

# Compute one image embedding per collected validation image.
# Images that fail to load are skipped. img_paths / txt_labels are then rebuilt from
# the kept images only, so later cells that zip embeddings, thumbnails and labels
# stay aligned (previously a skipped image shifted every following label).
kept_paths, kept_labels = [], []
img_thumbs, img_embs = [], []

for path, lbl in zip(img_paths, txt_labels):
    try:
        # `with` closes the underlying file handle once the pixels are converted.
        with Image.open(path) as src:
            img = src.convert("RGB").resize((cfg.image_size, cfg.image_size))
    except Exception as e:  # best-effort: one unreadable file should not abort the run
        print(f"⚠️ Skipping image {path}: {e}")
        continue

    img_thumbs.append(np.array(img))
    img_jax = preprocess_image(img, cfg.image_size)[None, ...]  # add a leading batch axis
    # model.apply is called with an image batch and a token batch together, so a
    # matching caption is tokenized even though only the image output is kept.
    tok = tokenize_text([f"a photo of a {lbl}"], cfg.text_max_len)

    # NOTE(review): this cell and the text-embedding cell both unpack the *second*
    # output of model.apply as their embedding; both cannot be right. Confirm
    # CLIPModel's return order before trusting the plot.
    _, img_e, _ = model.apply(params, img_jax, tok, train=False)
    img_embs.append(np.array(img_e[0]))
    kept_paths.append(path)
    kept_labels.append(lbl)

if not img_embs:
    raise ValueError("No images could be embedded; check img_paths and the ADE20K validation directory.")

img_embs = np.stack(img_embs)
img_paths, txt_labels = kept_paths, kept_labels
print(f"✅ Computed embeddings for {len(img_embs)} images")

In [None]:
# Compute one text embedding per selected class from its prompt.
# model.apply is called with an image batch alongside the tokens, so a zero-filled
# placeholder batch with one entry per prompt is passed in.
txt_prompts = [f"a photo of a {c}" for c in selected_classes]
tok = tokenize_text(txt_prompts, cfg.text_max_len)

# NOTE(review): assumes channels-last (N, H, W, 3) images; TODO confirm against preprocess_image.
dummy_imgs = jnp.zeros((len(txt_prompts), cfg.image_size, cfg.image_size, 3))

# NOTE(review): the image-embedding cell also unpacks the *second* output of
# model.apply; both cannot be right. Confirm CLIPModel's return order.
_, txt_embs, _ = model.apply(
    params,
    dummy_imgs,
    tok,
    train=False,
)
txt_embs = np.array(txt_embs)

print(f"✅ Computed text embeddings for {len(selected_classes)} prompts")


In [None]:
# Project image and text embeddings jointly to 2-D with t-SNE.
# Rows [0, num_img) are images; the remaining rows are the class text prompts.
emb = np.concatenate([img_embs, txt_embs], axis=0)
n_samples = emb.shape[0]
if n_samples < 2:
    raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")

# Perplexity ~ n/3, clamped to [5, 30], then kept strictly below n_samples
# because scikit-learn rejects perplexity >= n_samples.
perplexity = max(5, min(30, n_samples // 3))
perplexity = min(perplexity, n_samples - 1)

print(f"🧩 Running t-SNE on {n_samples} samples (perplexity={perplexity})...")

# scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` and later removed the old
# name; use whichever keyword the installed version accepts.
iter_kw = "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"

tsne = TSNE(
    n_components=2,
    perplexity=perplexity,
    init="pca",
    learning_rate="auto",
    random_state=42,
    **{iter_kw: 1500},
)
pts = tsne.fit_transform(emb)

num_img = len(img_embs)
img_pts, txt_pts = pts[:num_img], pts[num_img:]
print("✅ t-SNE embedding complete!")

In [None]:
# Visualization: text prompts as large labelled anchors, images as thumbnails,
# both coloured by scene class.
palette = sns.color_palette("tab10", n_colors=len(selected_classes))
cls_to_color = dict(zip(selected_classes, palette))

# Points, thumbnails and labels are zipped below; a length mismatch would silently
# attach labels/colours to the wrong images.
if not (len(img_pts) == len(img_thumbs) == len(txt_labels)):
    raise ValueError(
        f"Misaligned inputs: {len(img_pts)} points, {len(img_thumbs)} thumbnails, "
        f"{len(txt_labels)} labels"
    )

fig, ax = plt.subplots(figsize=(22, 18))
ax.set_facecolor("white")

# Plot text anchors: one large marker per class with its name just above it.
for (x, y), cls in zip(txt_pts, selected_classes):
    ax.scatter(x, y, s=700, facecolor=cls_to_color[cls],
               edgecolors="black", linewidth=1.5, zorder=3)
    # Offset the label in screen points, not data units: t-SNE coordinates have no
    # fixed scale, so a constant data-space offset can be far too large or too small.
    ax.annotate(
        cls.replace("_", " "),
        xy=(x, y), xytext=(0, 18), textcoords="offset points",
        fontsize=20, fontweight="bold",
        color="black", ha="center", va="bottom",
        bbox=dict(facecolor="white", alpha=0.9, edgecolor="none", boxstyle="round,pad=0.5"),
        zorder=4,
    )

# Plot image thumbnails, each with a faint class-coloured dot underneath.
for (x, y), thumb, cls in zip(img_pts, img_thumbs, txt_labels):
    thumb_box = AnnotationBbox(OffsetImage(thumb, zoom=1.6, resample=True), (x, y),
                               frameon=False, pad=0.15)
    ax.add_artist(thumb_box)
    # The dot also makes the axes autoscale to the image points; artists added via
    # add_artist do not update the data limits.
    ax.plot(x, y, "o", color=cls_to_color[cls], markersize=10, alpha=0.4, zorder=2)

# No emoji in the title: matplotlib's default font has no emoji glyphs.
ax.set_title(
    "CLIP-JAX ADE20K t-SNE Visualization — Images and Text Prompts",
    fontsize=26, pad=40, weight="bold",
)
ax.axis("off")
fig.tight_layout(pad=3)
plt.show()

In [None]:
# Save the 2-D t-SNE coordinates and plot metadata so the figure can be redrawn
# without recomputing embeddings. Pickled tuple layout:
#   (img_pts, txt_pts, selected_classes, txt_labels, cls_to_color)
cache_dir = os.path.join("clip_jax", "ckpts")
os.makedirs(cache_dir, exist_ok=True)
cache_path = os.path.join(cache_dir, "tsne_cache.pkl")

with open(cache_path, "wb") as f:
    pickle.dump((img_pts, txt_pts, selected_classes, txt_labels, cls_to_color), f)

print(f"💾 Saved cached t-SNE embeddings → {cache_path}")


In [None]:
# Summary of the run. The cache location comes from `cache_path` so it cannot
# drift from where the cache cell actually wrote the file.
print(f"""
✅ Zero-shot t-SNE visualization complete!
📦 Model: CLIP-JAX (trained on ADE20K)
🎯 Classes visualized: {', '.join(selected_classes)}
💾 Cache saved: {cache_path}
""")