# 🗣️ Vision-Language Models

**Topics:** CLIP, Image-Text Similarity, Zero-Shot Classification

In [None]:
# Setup
# NOTE: `!pip ...` is an IPython/Jupyter shell escape, not Python syntax;
# this cell only runs inside a notebook kernel.
# NOTE(review): numpy and matplotlib are imported but not in the pip list —
# presumably preinstalled in the runtime; confirm.
!pip install torch torchvision transformers pillow -q
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
print('✅ Setup complete!')

In [None]:
# CLIP-style Contrastive Loss
def clip_loss(image_embeds, text_embeds, temperature=0.07):
    """Compute the symmetric CLIP contrastive (InfoNCE) loss.

    Row ``i`` of ``image_embeds`` is the positive match for row ``i`` of
    ``text_embeds``. Every other row in the batch acts as a negative.

    Args:
        image_embeds: 2-D tensor ``(batch, dim)`` of image embeddings.
        text_embeds: 2-D tensor ``(batch, dim)`` of text embeddings.
        temperature: Divisor applied to the cosine similarities before the
            softmax. Smaller values give sharper distributions.

    Returns:
        Scalar tensor: the mean of the image->text and text->image
        cross-entropy losses.

    Raises:
        ValueError: If the two inputs have different batch sizes (the
            diagonal pairing is then undefined).
    """
    if image_embeds.shape[0] != text_embeds.shape[0]:
        raise ValueError(
            f'image_embeds and text_embeds must have the same batch size, '
            f'got {image_embeds.shape[0]} and {text_embeds.shape[0]}'
        )

    # L2-normalize so the dot product below is cosine similarity.
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)

    # logits[i, j] = cos(image_i, text_j) / temperature
    logits = image_embeds @ text_embeds.T / temperature

    # Matching pairs lie on the diagonal. Create the targets on the same
    # device as the logits; a CPU-only arange breaks cross_entropy on GPU.
    labels = torch.arange(logits.shape[0], device=logits.device)

    # Rows classify images over texts; columns (logits.T) classify texts
    # over images. Average the two directions.
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2

# Demo: the loss on random, unrelated image/text embeddings.
# Seed the RNG so the printed loss is reproducible across runs.
torch.manual_seed(0)
batch_size, embed_dim = 4, 512
img_emb = torch.randn(batch_size, embed_dim)
txt_emb = torch.randn(batch_size, embed_dim)

loss = clip_loss(img_emb, txt_emb)
print(f'CLIP Loss: {loss.item():.4f}')
# Reference point: a model that predicts uniformly over the batch scores
# exactly log(batch_size) in each direction.
print(f'Uniform-prediction baseline log({batch_size}): {np.log(batch_size):.4f}')

In [None]:
# Cosine Similarity Visualization
# Simulate paired image and text embeddings and plot their similarity matrix.
np.random.seed(42)
images = ['cat', 'dog', 'car', 'tree']
texts = ['a photo of a cat', 'a photo of a dog', 'a photo of a car', 'a photo of a tree']
if len(images) != len(texts):
    raise ValueError('images and texts must be paired one-to-one')
n_pairs = len(images)
mock_dim = 128  # separate name so the earlier `embed_dim` is not overwritten

# Mock embeddings (in real CLIP, these come from the image/text encoders).
# Each text embedding is its paired image embedding plus noise, so the
# diagonal of the similarity matrix should dominate.
img_emb = np.random.randn(n_pairs, mock_dim)
txt_emb = img_emb + np.random.randn(n_pairs, mock_dim) * 0.3  # Similar but noisy

# L2-normalize rows so the dot product is cosine similarity
img_emb = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
txt_emb = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)

# similarity[i, j] = cos(image_i, text_j)
similarity = img_emb @ txt_emb.T

plt.figure(figsize=(6, 5))
plt.imshow(similarity, cmap='Blues')
plt.xticks(range(n_pairs), texts, rotation=45, ha='right')
plt.yticks(range(n_pairs), images)
plt.xlabel('Text')
plt.ylabel('Image')
plt.colorbar(label='Cosine Similarity')
plt.title('Image-Text Similarity Matrix')

# 'Blues' maps higher values to darker cells. Switch to white text above
# the midpoint of the data range so the annotations stay legible.
midpoint = (similarity.min() + similarity.max()) / 2
for i in range(n_pairs):
    for j in range(n_pairs):
        value = similarity[i, j]
        plt.text(j, i, f'{value:.2f}', ha='center', va='center',
                 color='white' if value > midpoint else 'black')
plt.tight_layout()
plt.show()