# CLIP zero-shot visualization (Colab)

This notebook runs the CLIP zero-shot classification example, displays the image with a ranked confidence bar chart beneath it, and shows a 2D PCA projection of the image and label embeddings. It also applies a small, visualization-only "regularization" that exaggerates close/far relationships in the latent space; this step changes only the embedding plot, not the predicted confidences.

Notes:
- If you have a local `duck.jpg` file, upload it to the Colab runtime (click the Files tab -> Upload) so that it sits in the notebook's working directory; the notebook looks for `./duck.jpg`.
- The notebook will download a sample duck image automatically if no `duck.jpg` is found.

In [None]:
# Install required packages. On Colab many of these are already available, but we include these commands to be explicit.
# If you need GPU-enabled PyTorch, consider installing with the proper CUDA wheel for your runtime.
!pip install -q transformers matplotlib pillow numpy
# Optional: CPU-only PyTorch install (uncomment if needed)
# !pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

In [None]:
# Imports and helpers
import os
import warnings
# The filter is registered before `import torch` so that it is already active for
# anything raised from that point on.
# NOTE(review): presumably this message is a deprecation warning emitted by torch;
# confirm it is still needed with the installed version.
warnings.filterwarnings("ignore", message="TypedStorage is deprecated")
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import matplotlib.pyplot as plt

# Report the torch build and CUDA availability up front, before any model is loaded.
print('torch version:', torch.__version__, 'cuda available:', torch.cuda.is_available())

In [None]:
# Load the pretrained CLIP checkpoint and its matching processor
# (the weights are downloaded on the first run).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print('Using device:', device)

# Model and processor must come from the same checkpoint, so name it once.
checkpoint = 'openai/clip-vit-base-patch32'
model = CLIPModel.from_pretrained(checkpoint)
model = model.to(device)
processor = CLIPProcessor.from_pretrained(checkpoint)
print('Model and processor loaded')

In [None]:
# Image path and fallback download if not present
img_path = './duck.jpg'
if not os.path.exists(img_path):
    print('duck.jpg not found, downloading a sample image...')
    url = 'https://upload.wikimedia.org/wikipedia/commons/4/49/Male_mallard_duck_head.jpg'
    # Identify the client explicitly; Wikimedia's User-Agent policy asks for this and
    # may reject anonymous library defaults. The timeout stops a stalled connection
    # from hanging the cell forever.
    r = requests.get(
        url,
        headers={'User-Agent': 'clip-zero-shot-notebook/1.0 (python-requests)'},
        timeout=30,
    )
    # Fail loudly on an HTTP error. Otherwise the error page would be saved as
    # duck.jpg, every later run would skip the download, and opening the image
    # would fail with a confusing message.
    r.raise_for_status()
    # Context manager guarantees the file is flushed and closed.
    with open(img_path, 'wb') as f:
        f.write(r.content)
    print('Downloaded sample image to', img_path)

# Candidate text prompts the image is scored against (edit as desired).
# One prompt per line so entries are easy to add, remove, or reorder.
labels = [
    'a picture of a bull',
    'a picture of a racecar',
    'a picture of a woman',
    'a picture of a man',
    'a picture of a queen',
    'a picture of a king',
    'a photo of donald duck',
    'a photo of a duck',
    'a photo of a bird',
]
print('Labels:', labels)

In [None]:
# Score the image against every candidate label in one forward pass.
image = Image.open(img_path)
encoded = processor(text=labels, images=image, return_tensors='pt', padding=True)
inputs = encoded.to(device)

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    # Softmax over dim 1 (indexed by label below) turns the logits into confidences.
    probs = torch.softmax(logits_per_image, dim=1)

# Only one image was passed in, so row 0 holds its per-label confidences.
image_probs = probs[0]
best_idx = image_probs.argmax().item()
print(f'Predicted label: {labels[best_idx]} (confidence {image_probs[best_idx]:.2f})')
for idx, (candidate, confidence) in enumerate(zip(labels, image_probs)):
    if idx == best_idx:
        continue
    print(f'Other label: {candidate} (confidence {confidence:.2f})')

In [None]:
# Visualization: the image in the upper panel, ranked confidence bars in the lower panel.
# Re-open the file and force RGB before handing the pixel array to imshow.
image_rgb = Image.open(img_path).convert('RGB')
image_np = np.asarray(image_rgb)

# Rank the labels from most to least confident.
confidences = probs[0].cpu().numpy()
order = np.flip(np.argsort(confidences))
ordered_conf = confidences[order]
ordered_labels = [labels[k] for k in order]

fig = plt.figure(figsize=(10, 10))
fig.subplots_adjust(left=0.18, right=0.98, top=0.95, bottom=0.04)

# Upper panel: the picture itself, without axes decorations.
ax_img = fig.add_axes([0.05, 0.28, 0.9, 0.67])
ax_img.imshow(image_np)
ax_img.axis('off')

# Lower panel: horizontal bars, best label at the top.
ax_bar = fig.add_axes([0.18, 0.03, 0.78, 0.17])
ax_bar.patch.set_alpha(0.7)
y_pos = np.arange(len(ordered_labels))
ax_bar.barh(y_pos, ordered_conf, color='C0', height=0.6)
ax_bar.set_yticks(y_pos)
ax_bar.set_yticklabels(ordered_labels, fontsize=10)
ax_bar.invert_yaxis()
ax_bar.set(xlim=(0, 1), xlabel='Confidence')

# Print each value just right of its bar, clamped so it never starts past x = 0.99.
for rank, score in enumerate(ordered_conf):
    ax_bar.text(min(score + 0.02, 0.99), rank, f"{score:.2f}", va='center', fontsize=9)
plt.show()

In [None]:
# Embedding visualization with optional regularization
# Extract embeddings. Prefer the ones already produced by the forward pass; only
# recompute them when `outputs` is missing or does not carry embeddings.
try:
    image_embed = outputs.image_embeds[0]
    text_embeds = outputs.text_embeds
except (NameError, AttributeError, TypeError):
    # Validate both inputs before doing any work so the error names what is missing.
    if 'pixel_values' not in inputs:
        raise RuntimeError('No image tensor found for embedding computation')
    if 'input_ids' not in inputs:
        raise RuntimeError('No text tensor found for embedding computation')
    # The attention mask is optional; only move it when it actually exists.
    attention_mask = inputs.get('attention_mask', None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    with torch.no_grad():
        image_features = model.get_image_features(pixel_values=inputs['pixel_values'].to(device))
        text_features = model.get_text_features(
            input_ids=inputs['input_ids'].to(device),
            attention_mask=attention_mask,
        )
    # CLIPModel's forward pass returns L2-normalized `image_embeds`/`text_embeds`,
    # whereas `get_*_features` returns the unnormalized projections. Normalize here
    # so both paths yield embeddings on the same scale and distances stay comparable.
    image_embed = (image_features / image_features.norm(p=2, dim=-1, keepdim=True))[0]
    text_embeds = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

img_e = image_embed.detach().cpu().numpy()
txt_e = text_embeds.detach().cpu().numpy()

# Imported at this point (not at the top) because code further down calls `norm`.
from numpy.linalg import norm
# Euclidean distance from the image embedding to every label embedding, in one call.
distances = norm(txt_e - img_e, axis=1)

# Visualization-only regularization function
def regularize_text_embeddings(txt_embeddings, img_embedding, power=1.8, eps=1e-8):
    """Stretch or shrink each text embedding's offset from the image embedding.

    Every text embedding is moved along the straight line joining it to the image
    embedding. Its distance is multiplied by ``(distance / mean_distance) ** power``,
    so with a positive ``power`` points nearer than average move closer and points
    farther than average move farther away. ``power=0`` returns the input unchanged.
    Intended for plotting only; the result is not a valid CLIP embedding.

    Args:
        txt_embeddings: Array-like of shape (N, D), one row per text embedding.
        img_embedding: Array-like of shape (D,), the reference point.
        power: Exponent applied to the distance ratio.
        eps: Small constant added to the mean distance to avoid division by zero.

    Returns:
        A new array of shape (N, D) with the transformed text embeddings. The
        inputs are not modified. An empty input yields an empty copy.
    """
    # Accept lists/tuples as well as arrays; this is a no-op for ndarray inputs.
    txt = np.asarray(txt_embeddings)
    img = np.asarray(img_embedding)
    # With nothing to transform, the mean and max below are undefined, so return early.
    if txt.size == 0:
        return txt.copy()

    dirs = txt - img[np.newaxis, :]
    dists = np.linalg.norm(dirs, axis=1)
    mean_dist = dists.mean() + eps
    ratios = dists / mean_dist
    multipliers = ratios ** power
    # A zero distance with a negative power gives inf; replace non-finite values
    # with neutral or bounded multipliers so the plot stays drawable.
    multipliers = np.nan_to_num(multipliers, nan=1.0, posinf=ratios.max(), neginf=1.0)
    return img[np.newaxis, :] + dirs * multipliers[:, np.newaxis]

# Exponent of the visualization-only distance warp: change it to exaggerate
# or reduce the separation between near and far labels.
power = 1.8
txt_e_reg = regularize_text_embeddings(txt_e, img_e, power=power)
# Distance from the image to each warped label embedding.
reg_distances = np.array([np.linalg.norm(img_e - row) for row in txt_e_reg])

# Reduce to 2D with PCA, computed through an SVD of the mean-centered points.
# Row 0 is the image; the remaining rows are the warped label embeddings.
points = np.concatenate([img_e[np.newaxis, :], txt_e_reg], axis=0)
centered = points - np.mean(points, axis=0)
_, _, Vt = np.linalg.svd(centered, full_matrices=False)
# The first two right-singular vectors are the projection axes.
top_two_axes = Vt[:2].T
coords = np.dot(centered, top_two_axes)
img_xy, txt_xy = coords[0], coords[1:]

fig2 = plt.figure(figsize=(10, 5))
ax_emb = fig2.add_axes([0.06, 0.10, 0.88, 0.85])

# Color encodes closeness: distances are rescaled to [0, 1] and inverted, so the
# label nearest the image takes the top end of the viridis colormap.
span = np.ptp(reg_distances) + 1e-8
normed = (reg_distances - reg_distances.min()) / span
colors = plt.cm.viridis(1 - normed)

ax_emb.scatter(txt_xy[:, 0], txt_xy[:, 1], c=colors, s=90, edgecolor='k')
ax_emb.scatter(img_xy[0], img_xy[1], marker='*', s=220, c='red', edgecolor='k', label='image')

# The image position is the same for every connector, so unpack it once.
img_x, img_y = img_xy
for caption, (lbl_x, lbl_y), dist in zip(labels, txt_xy, reg_distances):
    # Dashed connector from the image to this label.
    ax_emb.plot([img_x, lbl_x], [img_y, lbl_y], color='gray', linewidth=0.8, linestyle='--')
    # Distance annotation at the midpoint of the connector.
    mid_x = (img_x + lbl_x) / 2
    mid_y = (img_y + lbl_y) / 2
    ax_emb.text(mid_x, mid_y, f"{dist:.2f}", fontsize=9, color='black', bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1))
    # Label text, nudged off the marker.
    ax_emb.text(lbl_x + 0.01, lbl_y + 0.01, caption, fontsize=9)

ax_emb.set_title(f'2D projection of CLIP embeddings (power={power:.2f} regularization)')
ax_emb.axis('equal')
ax_emb.grid(False)
plt.show()

## Notes & next steps
- Adjust `power` in the embedding cell to control how strongly labels that are nearer to the image than average are pulled closer, and farther ones pushed away, in the plot. Setting `power = 0` disables the effect.
- To compare the original and regularized spaces, add a side-by-side plot that projects both `txt_e` and `txt_e_reg`.
- To save figures automatically, you can add `fig.savefig('foo.png', dpi=200)` after the plotting cells.