The only new component in this paper is the Vision Transformer, i.e. how to convert an image batch of shape [B, 3, 224, 224] into a token sequence of shape [B, 197, H].

H = Embedding dimension (the hidden size, written E in the code below).

3 = Number of channels [RGB]

224 × 224 = image height × width in pixels.

B = Number of Images.

224 × 224 with patch size 16 → the image is cut into a 14 × 14 grid of non-overlapping "square blocks" (224 ÷ 16 = 14 per side). Each patch covers 16 × 16 pixels of the original image.

This square grid is then flattened into a list: 14 × 14 = 196 patches.

1 CLS token is prepended → 197 tokens.

Before patching:  [B,   3, 224, 224]   ← 3 = RGB channels

After conv:       [B, 512,  14,  14]   ← 512 = embedding channels, not colors

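The 3 colour channels are mixed into the embedding dimension by the convolution: each output channel is a weighted sum over all RGB values in the patch. Each of the 197 tokens gets a vector of size H, so each image becomes [197, H] and the batch becomes [B, 197, H].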

To this we also add positional embeddings, which give the model a sense of where each patch is located in the image.
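To make "channels get mixed into the embedding dimension" concrete, here is a small sketch (not part of the original notebook) showing that a Conv2d with kernel_size = stride = patch does the same thing as cutting the image into 16 × 16 patches, flattening each patch's 3 × 16 × 16 = 768 values, and applying one shared linear map. The sizes B = 2 and E = 512 are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, C, S, P, E = 2, 3, 224, 16, 512          # batch, channels, image side, patch, embed dim
images = torch.randn(B, C, S, S)

# Patchify + project in one step, then turn the 14x14 grid into a sequence
conv = nn.Conv2d(C, E, kernel_size=P, stride=P)
out_conv = conv(images).flatten(2).transpose(1, 2)      # [B, 196, E]

# Same computation by hand: cut into 16x16 tiles and flatten each tile
patches = images.unfold(2, P, P).unfold(3, P, P)        # [B, C, 14, 14, P, P]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)  # [B, 196, 768]

# The conv kernel viewed as a [E, C*P*P] matrix (same channel, row, column order)
weight = conv.weight.reshape(E, -1)
out_linear = patches @ weight.t() + conv.bias           # [B, 196, E]

seq_len = out_conv.shape[1] + 1                         # + 1 CLS token
print(out_conv.shape, seq_len)
```

Every output channel mixes all three colour channels of its patch, which is why the "3" disappears into E.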

In [None]:
'''
Example setup

Batch size: B

Image: [B, 3, 224, 224]

Patch size: 16

Embed dim (hidden size): E = 512

Formulas:

patches per side = 224 ÷ 16 = 14

total patches = 14 × 14 = 196

sequence length after adding CLS = 196 + 1 = 197

Shape flow

Input image

[B, 3, 224, 224]

Conv2d patchifier + projection
Conv2d(in_ch=3, out_ch=E, kernel_size=16, stride=16)
This both splits into 16×16 tiles and projects each tile to E dims.

Output: [B, E, 14, 14]

Flatten spatial grid to a sequence
Flatten last two dims (14×14) and move channels to last:

After flatten(2) → [B, E, 196]

After transpose(1, 2) → [B, 196, E]

Prepend [CLS] token
A learnable vector expanded for the batch: [1, 1, E] → [B, 1, E] and concatenated in front.

After concat: [B, 197, E]

Add positional embeddings
A learnable table for every position (including CLS): [1, 197, E], broadcast and added.

After adding: [B, 197, E] (same shape, now position‑aware)

Feed to the Transformer encoder
The VisionTower (ViT) takes [B, 197, E] and returns [B, 197, E].
You then take the CLS vector as the image representation:

CLS pooled output: [B, E]

Quick alternate example (to cement it)

Image [B, 3, 256, 256], patch 32, E = 768

patches per side: 256 ÷ 32 = 8

total patches: 8 × 8 = 64

after conv: [B, 768, 8, 8]

flatten+transpose: [B, 64, 768]

add CLS → [B, 65, 768]

add pos → [B, 65, 768]

take CLS → [B, 768]

One‑screen mental model
Images:     [B, 3, H, W]
↓ Conv(stride=patch) splits & projects
Patch grid: [B, E, H/patch, W/patch]
↓ Flatten & transpose
Tokens:     [B, P, E]        (P = (H/patch)*(W/patch))
↓ Prepend CLS, add pos
Sequence:   [B, P+1, E]
↓ Transformer encoder
CLS vec:    [B, E]           (global image embedding)
'''
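The shape flow above can be checked directly. A minimal sketch, assuming zero tensors stand in for the learnable CLS token and positional table (in the real model these are nn.Parameters); trace_vit_shapes is a hypothetical helper name:

```python
import torch
from torch import nn

@torch.no_grad()
def trace_vit_shapes(img_size, patch, embed_dim, batch=2):
    """Run patchify -> flatten -> CLS -> pos on random input and record each shape."""
    shapes = {}
    x = torch.randn(batch, 3, img_size, img_size)
    shapes["input"] = tuple(x.shape)

    x = nn.Conv2d(3, embed_dim, kernel_size=patch, stride=patch)(x)
    shapes["conv"] = tuple(x.shape)                 # [B, E, H/patch, W/patch]

    x = x.flatten(2).transpose(1, 2)
    shapes["tokens"] = tuple(x.shape)               # [B, P, E]

    cls = torch.zeros(1, 1, embed_dim).expand(batch, -1, -1)   # stand-in for the CLS parameter
    x = torch.cat([cls, x], dim=1)
    shapes["with_cls"] = tuple(x.shape)             # [B, P+1, E]

    pos = torch.zeros(1, x.size(1), embed_dim)      # stand-in for the positional table
    x = x + pos
    shapes["with_pos"] = tuple(x.shape)             # unchanged by the addition

    shapes["cls_vec"] = tuple(x[:, 0].shape)        # [B, E]
    return shapes

print(trace_vit_shapes(224, 16, 512))
print(trace_vit_shapes(256, 32, 768))
```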

The rest of the code is straightforward and can be taken directly from the BERT encoder. [Important: the colour channels get mixed into the embeddings, so 3 → E.]


In [None]:
import math
import torch
import torch.nn.functional as F
from torch import nn

In [None]:
class BERTEmbedding(nn.Module):
  def __init__(self, vocab_size, hidden_size, dropout = 0.1,
               max_position_embeddings = 512, type_vocab_size = 2 ):
    # "type vocab" is the vocabulary of segment types; there are only 2: sentence 0 and sentence 1
    super().__init__()
    # Lookup table of shape vocab_size x hidden_size (e.g. 30522 x 768 for bert-base-uncased)
    self.token_embed = nn.Embedding(num_embeddings = vocab_size, embedding_dim = hidden_size)

    self.position_embed = nn.Embedding(max_position_embeddings, hidden_size)

    self.segment_embed = nn.Embedding(type_vocab_size,hidden_size)

    self.dropout = nn.Dropout(dropout)
    self.layernorm = nn.LayerNorm(hidden_size) # = 768
    # normalizes the inputs to a layer across the feature dimensions for each individual sample in a batch

  def forward(self, input_ids, token_type_ids):
    batch_size, sequence_length = input_ids.shape # B x L

    # One hidden_size vector per token: [B, L] -> [B, L, H]
    token_embedding = self.token_embed(input_ids)

    position_ids = torch.arange(sequence_length, device = input_ids.device).unsqueeze(0).expand(batch_size, sequence_length)
    # position_ids is [B, L]: the indices 0..L-1 repeated for every sequence in the batch
    position_embedding = self.position_embed(position_ids)

    segment_embedding = self.segment_embed(token_type_ids)

    embedding = token_embedding + position_embedding + segment_embedding

    embedding = self.layernorm(embedding)
    embedding = self.dropout(embedding)

    return embedding #[B, L, H] [32,77,768]
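The position_ids line is easy to misread: it is [B, L], not three-dimensional. A small sketch with toy sizes (assumed: B = 2, L = 5, H = 8, a vocabulary of 100) shows how the token and position lookups line up before they are summed:

```python
import torch
from torch import nn

B, L, H, vocab = 2, 5, 8, 100
input_ids = torch.randint(0, vocab, (B, L))

# One position index per token, shared across the batch: [L] -> [1, L] -> [B, L]
position_ids = torch.arange(L).unsqueeze(0).expand(B, L)

token_embed = nn.Embedding(vocab, H)
position_embed = nn.Embedding(77, H)     # 77 = max_position_embeddings used by TextTower

# Each lookup maps [B, L] integer ids to [B, L, H] vectors, so they can be added
emb = token_embed(input_ids) + position_embed(position_ids)
print(position_ids)
print(emb.shape)
```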


In [None]:
class MultiHeadSelfAttention(nn.Module):
  def __init__(self, num_heads, hidden_size, dropout=0.1):
    super().__init__()
    assert hidden_size % num_heads == 0

    self.num_heads = num_heads

    self.hidden_size = hidden_size #768

    self.dropout = nn.Dropout(dropout)

    self.query = nn.Linear(hidden_size, hidden_size)
    self.key   = nn.Linear(hidden_size, hidden_size)
    self.value = nn.Linear(hidden_size, hidden_size)

    self.out_proj = nn.Linear(hidden_size, hidden_size)

    self.head_dim = hidden_size // num_heads
    self.scale = math.sqrt(self.head_dim)

  def forward(self, x, mask = None):
    B, L, H = x.shape
    Q = self.query(x)
    K = self.key(x)
    V = self.value(x)

    Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

    raw_attention_logits = torch.matmul(Q, K.transpose(-2,-1)) / self.scale # B x Heads x L x L


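    if mask is not None:
      # --- mask columns (keys): no query may attend to a padded token ---
      key_mask   = mask.unsqueeze(1).unsqueeze(2)      # B × 1 × 1 × L
      # -1e4 rather than -inf: the row fill below masks whole rows, and softmax over all -inf gives NaN
      raw_attention_logits = raw_attention_logits.masked_fill(key_mask == 0, -1e4)

      # --- mask rows (queries): padded queries get a harmless uniform row; their outputs are ignored downstream ---
      query_mask = mask.unsqueeze(1).unsqueeze(3)      # B × 1 × L × 1
      raw_attention_logits = raw_attention_logits.masked_fill(query_mask == 0, -1e4)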

    attention_weights = F.softmax(raw_attention_logits, dim=-1)

    if mask is not None:
        attention_weights = attention_weights * key_mask.float()
    attention_weights = self.dropout(attention_weights)

    attention_output = torch.matmul(attention_weights, V) #B x Heads x L x Head_dim


    attention_output = attention_output.transpose(1,2).contiguous().view(B,L,H)


    final_output = self.out_proj(attention_output)  # [B, L, H] only the last dimension must match the input size

    return final_output, attention_weights #,attention_weights for visualisation.
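A stripped-down sketch of just the key-masking step (toy sizes assumed: 1 sequence, 2 heads, 4 tokens, head_dim 8, last token padded) shows the effect: after softmax, no query puts weight on the padded key, and each row still sums to 1.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, heads, L, d = 1, 2, 4, 8
Q = torch.randn(B, heads, L, d)
K = torch.randn(B, heads, L, d)
mask = torch.tensor([[1, 1, 1, 0]])                   # 1 = real, 0 = pad (last token)

logits = Q @ K.transpose(-2, -1) / math.sqrt(d)       # [B, heads, L, L]
key_mask = mask[:, None, None, :]                     # [B, 1, 1, L], broadcasts over heads and queries
logits = logits.masked_fill(key_mask == 0, -1e4)      # padded key column pushed far down
weights = F.softmax(logits, dim=-1)

print(weights[0, 0])                                  # last column is ~0 in every row
```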


In [None]:
class EncoderLayer(nn.Module):
  def __init__(self, num_heads, hidden_size, intermediate_size, dropout):

    super().__init__()

    self.self_attention = MultiHeadSelfAttention(num_heads,hidden_size,dropout)

    self.attention_norm = nn.LayerNorm(hidden_size)

    self.feed_forward = nn.Sequential(
        nn.Linear(hidden_size, intermediate_size),
        nn.GELU(), # non-linearity; without it the two Linear layers would collapse into one linear map
        nn.Linear(intermediate_size,hidden_size),
        nn.Dropout(dropout),
    )

    self.feed_forward_norm = nn.LayerNorm(hidden_size)

    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask = None): # x: one batch of shape B x L x H
    # This layer doesn't change the shape: the output is also B x L x H

    attention_output, attention_weights = self.self_attention(x,mask)

    x = self.attention_norm(x + self.dropout(attention_output))

    feed_forward_output = self.feed_forward(x)

    # feed_forward already ends with Dropout, so it is not applied a second time here
    x = self.feed_forward_norm(x + feed_forward_output)

    return x, attention_weights
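For comparison, PyTorch ships a built-in block with the same post-LN layout (attention → add & norm → feed-forward → add & norm) when norm_first=False. Dropout placement differs in details, so treat it as a rough equivalent of EncoderLayer rather than a drop-in replacement. Note that its padding mask uses the opposite convention: True means "ignore this position". The sizes below are assumptions matching the towers in this notebook.

```python
import torch
from torch import nn

layer = nn.TransformerEncoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1,
    activation="gelu", batch_first=True, norm_first=False,   # post-LN, like EncoderLayer above
)

x = torch.randn(2, 77, 512)                  # [B, L, H]
pad = torch.zeros(2, 77, dtype=torch.bool)
pad[:, 60:] = True                           # True marks padding (inverse of attention_mask)

y = layer(x, src_key_padding_mask=pad)
print(y.shape)                               # shape is preserved: [B, L, H]
```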

In [None]:
class BERTEncoder(nn.Module):
  def __init__(self,num_layers, num_heads, hidden_size, intermediate_size, dropout = 0.1):
    super().__init__()
    #we will stack multiple encoder layers.
    self.encoder_layers = nn.ModuleList([
        EncoderLayer(num_heads = num_heads,
                     hidden_size = hidden_size,
                     intermediate_size = intermediate_size,
                     dropout = dropout
        )
        for _ in range(num_layers)
    ])
  def forward(self, x, mask = None):
    all_attention_weights = []
    # x is one batch of B x L x H, passed through all num_layers layers in order;
    # the input of layer 2 is the output of layer 1, and so on
    for layer in self.encoder_layers:
      x, attention_weights = layer(x,mask)
      all_attention_weights.append(attention_weights)
    return x, all_attention_weights # Final shape is also [B, L, H] as Each layer preserves the shape.


In [None]:
class TextTower(nn.Module):

    def __init__(self,
                 vocab_size,
                 hidden_size=512,
                 num_heads=8,
                 num_layers=6,
                 intermediate_size=2048,
                 max_position_embeddings=77,   # CLIP uses short texts
                 type_vocab_size=2,
                 dropout=0.1,
                 pool="cls"):                   # "cls" or "mean"
        super().__init__()
        self.embedding = BERTEmbedding(vocab_size, hidden_size, dropout,
                                       max_position_embeddings, type_vocab_size)
        self.encoder   = BERTEncoder(num_layers, num_heads, hidden_size,
                                     intermediate_size, dropout)
        self.pool = pool

    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        input_ids:        [B, L]
        token_type_ids:   [B, L] (can be zeros for single-sentence)
        attention_mask:   [B, L] (1=real, 0=pad)
        returns:          [B, H]
        """
        x = self.embedding(input_ids, token_type_ids)      # [B,L,H]
        x, _ = self.encoder(x, attention_mask)             # [B,L,H]
        if self.pool == "cls":
            return x[:, 0]
        # mask-aware mean pool
        mask = attention_mask.unsqueeze(-1).to(x.dtype)    # [B,L,1], cast so sum/clamp stay floating-point
        return (x * mask).sum(1) / mask.sum(1).clamp_min(1e-6)
'''
The pooling is just our way of turning a sequence of token embeddings [B, L, H] into a single fixed-length vector [B, H] that represents the whole sentence (or prompt) so we can compare it to the image vector.

Why we need it in CLIP

Transformers keep one embedding per token.

CLIP’s contrastive loss needs one embedding per modality per example so it can compute a single similarity score between an image and a text.

Pooling condenses the whole sequence into one representative vector.
Why mask-aware mean

If we simply averaged over all positions, the padding tokens (which are meaningless) would dilute the representation.

Mask-aware mean:

Zeros out the padded positions.

Averages only over the “real” tokens.

The goal

Find a single point in the shared projection space that best captures the meaning of the entire text so that matched images and texts end up close together and mismatches far apart.

In practice:

CLS pooling: trust the model’s [CLS] token to gather global info via self-attention (what BERT is trained to do).

Mean pooling: average all non-pad token embeddings for a “global” semantic representation.

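Which one you choose is largely a design choice; CLIP-style setups work with either. (OpenAI's original CLIP text encoder uses neither exactly: it takes the final-layer activations at the [EOS] token.)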

'''
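A worked example with one padded position (toy numbers, assumed) shows why the mask matters: the naive mean is dragged toward the padding vector, while the mask-aware mean only averages the two real tokens.

```python
import torch

x = torch.tensor([[[1.0, 1.0],
                   [3.0, 3.0],
                   [100.0, 100.0]]])               # [B=1, L=3, H=2]; last position is padding
attention_mask = torch.tensor([[1, 1, 0]])          # 1 = real, 0 = pad

naive_mean = x.mean(dim=1)                          # (1 + 3 + 100) / 3: padding leaks in

mask = attention_mask.unsqueeze(-1).to(x.dtype)     # [B, L, 1]
masked_mean = (x * mask).sum(1) / mask.sum(1).clamp_min(1e-6)   # (1 + 3) / 2

print(naive_mean, masked_mean)
```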


In [None]:
class PatchEmbed(nn.Module):
    """
    Splits image into non-overlapping patches and projects to embed_dim.
    Uses Conv2d(kernel=stride=patch) → [B, E, H/patch, W/patch] → flatten → [B, P, E]
    Also prepends a learnable CLS token and adds learnable positional embeddings.
    """
    def __init__(self, img_size=224, patch=16, in_ch=3, embed_dim=512):
        super().__init__()
        assert img_size % patch == 0, "img_size must be divisible by patch size"
        self.patch = patch
        self.grid = img_size // patch          # patches per side
        self.num_patches = self.grid * self.grid

        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch, stride=patch)
        self.cls  = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos  = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))  # includes CLS

        # (optional) tiny init helps a bit
        nn.init.normal_(self.cls, std=1e-6)
        nn.init.normal_(self.pos, std=1e-6)

    def forward(self, x):
        """
        x: [B, 3, H, W] with H=W=img_size
        returns: [B, 1+P, E]
        """
        B, C, H, W = x.shape
        assert H % self.patch == 0 and W % self.patch == 0
        x = self.proj(x)                       # [B,E,H/patch,W/patch] = [B,E,g,g]
        x = x.flatten(2).transpose(1, 2)       # [B,P,E], P=g*g
        cls = self.cls.expand(B, -1, -1)       # [B,1,E]
        x = torch.cat([cls, x], dim=1)         # [B,1+P,E]
        x = x + self.pos[:, :x.size(1)]
        return x


In [None]:
class VisionTowerViT(nn.Module):
    """
    ViT-style encoder built from your EncoderLayer/BERTEncoder.
    """
    def __init__(self,
                 img_size=224,
                 patch=16,
                 embed_dim=512,
                 num_layers=6,
                 num_heads=8,
                 mlp_dim=2048,
                 dropout=0.1):
        super().__init__()
        self.patchify = PatchEmbed(img_size, patch, 3, embed_dim)
        self.encoder  = BERTEncoder(num_layers, num_heads, embed_dim, mlp_dim, dropout)

    def forward(self, images):
        """
        images: [B,3,H,W]
        returns: [B, E]  (CLS pooled)
        """
        x = self.patchify(images)                         # [B,1+P,E]
        B, L, _ = x.shape
        mask = torch.ones(B, L, dtype=torch.long, device=x.device)  # no padding in vision
        x, _ = self.encoder(x, mask)                      # [B,1+P,E]
        return x[:, 0]


In [None]:
class ProjectionHead(nn.Module):
    """
    Simple linear (no bias). Norm happens outside.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        return self.proj(x)
'''
What is the shared projection space? Is it cross‑attention?

Not cross‑attention. It’s metric learning.

Each tower produces a vector:
h_text ∈ ℝ^{H_t}, h_img ∈ ℝ^{H_v}.

We map them to a common dimension D using separate linear heads:

z_t = text_proj(h_text)   # [B,D]
z_i = image_proj(h_img)   # [B,D]
z_t = normalize(z_t); z_i = normalize(z_i)  # L2 to unit sphere


In this shared space, we compute cosine similarities and train with the contrastive (InfoNCE) loss so matched pairs have high similarity, mismatched pairs low.

Why do this?

Lets each tower keep its own internal width (H_t, H_v) and still compare in a single space D (e.g., 512).

The projection heads act like adapters that learn the right metric geometry for alignment.

Linear with no bias is typical; you can use a small MLP, but CLIP itself uses only a linear projection.

How it differs from cross‑attention

Cross‑attention fuses modalities token‑by‑token (image attends to text tokens and vice versa).

CLIP’s bi‑encoder doesn’t fuse; each side encodes independently → project → compare with a dot product.
This makes CLIP fast for retrieval (precompute embeddings; no cross‑att at query time).

Timeline of ops

Text tokens → self‑attention stack → pooled h_text.

Image patches → self‑attention stack (ViT) → pooled h_img.

Each goes through its own projection head to D, then L2 norm.

Compute similarity matrix S = (z_t)(z_i)^T, scale by 1/τ, apply contrastive loss.

'''
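A minimal sketch of the steps above, with random features standing in for the tower outputs (the widths H_t = 512, H_v = 768 and D = 256 are illustrative assumptions). It shows the two towers keeping different widths, meeting in a shared D-dimensional space, and that after L2 normalisation the dot product is exactly the cosine similarity.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
B, H_t, H_v, D = 4, 512, 768, 256

h_text = torch.randn(B, H_t)                     # pooled text-tower output
h_img  = torch.randn(B, H_v)                     # pooled vision-tower output

text_proj  = nn.Linear(H_t, D, bias=False)       # separate heads, one per modality
image_proj = nn.Linear(H_v, D, bias=False)

z_t = F.normalize(text_proj(h_text), dim=-1)     # [B, D], unit length
z_i = F.normalize(image_proj(h_img), dim=-1)     # [B, D], unit length

S = z_t @ z_i.t()                                # [B, B] cosine similarities in [-1, 1]
tau = 0.07
logits = S / tau                                 # what the contrastive loss consumes
print(S)
```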


In [None]:
class CLIPModel(nn.Module):
    """
    Holds text & vision towers, projects to shared dim, normalizes, and
    returns the scaled similarity matrices for contrastive loss.
    """
    def __init__(self,
                 text_tower: TextTower,
                 vision_tower: VisionTowerViT,
                 d_text: int,
                 d_vision: int,
                 proj_dim: int = 512,
                 init_tau: float = 0.07):
        super().__init__()
        self.text_tower   = text_tower
        self.vision_tower = vision_tower
        self.text_proj    = ProjectionHead(d_text, proj_dim)
        self.vision_proj  = ProjectionHead(d_vision, proj_dim)
        # Learn log(1/τ) → exponentiate at forward time
        self.logit_scale  = nn.Parameter(torch.tensor(math.log(1.0 / init_tau), dtype=torch.float32))

    def _l2norm(self, x):
        return F.normalize(x, dim=-1)

    def encode_text(self, input_ids, token_type_ids, attention_mask):
        h = self.text_tower(input_ids, token_type_ids, attention_mask)  # [B,H_t]
        z = self.text_proj(h)                                           # [B,D]
        return self._l2norm(z)

    def encode_image(self, images):
        h = self.vision_tower(images)                                   # [B,H_v]
        z = self.vision_proj(h)                                         # [B,D]
        return self._l2norm(z)

    def forward(self, input_ids, token_type_ids, attention_mask, images):
        z_t = self.encode_text(input_ids, token_type_ids, attention_mask)   # [B,D]
        z_i = self.encode_image(images)                                      # [B,D]
        scale = self.logit_scale.exp()                                       # scalar 1/τ
        logits_per_text  = scale * (z_t @ z_i.t())                           # [B,B]
        logits_per_image = logits_per_text.t()                               # [B,B]
        return logits_per_text, logits_per_image


In [None]:
def clip_contrastive_loss(logits_per_text, logits_per_image):
    """
    logits_per_text:  [B,B]  (rows = each text against all images)
    logits_per_image: [B,B]  (rows = each image against all texts)
    """
    B = logits_per_text.size(0)
    target = torch.arange(B, device=logits_per_text.device)
    loss_t = F.cross_entropy(logits_per_text, target)
    loss_i = F.cross_entropy(logits_per_image, target)
    return 0.5 * (loss_t + loss_i)
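Two reference points help when reading this loss: if all logits are equal, each row is a uniform guess over B candidates and the loss is log B; if matched pairs dominate, it approaches 0. For B = 4, log 4 ≈ 1.386, so a randomly initialised model should land roughly in that neighbourhood, as the sanity check below does. The helper name symmetric_info_nce is hypothetical; it repeats the same computation for a single [B, B] matrix.

```python
import math
import torch
import torch.nn.functional as F

def symmetric_info_nce(logits_per_text):
    """Average of text->image and image->text cross-entropy; targets are the diagonal."""
    target = torch.arange(logits_per_text.size(0))
    return 0.5 * (F.cross_entropy(logits_per_text, target)
                  + F.cross_entropy(logits_per_text.t(), target))

B = 4
uninformative = torch.zeros(B, B)          # every pair equally similar
confident = torch.eye(B) * 100.0           # matched pairs far ahead of the rest

print(symmetric_info_nce(uninformative).item(), math.log(B))
print(symmetric_info_nce(confident).item())
```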


In [None]:
from torch.utils.data import Dataset
from torchvision import transforms as T

def get_train_transforms(img_size=224):
    return T.Compose([
        T.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std =[0.229, 0.224, 0.225]),
    ])

def get_eval_transforms(img_size=224):
    return T.Compose([
        T.Resize(int(img_size * 1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std =[0.229, 0.224, 0.225]),
    ])

class CLIPPairsDataset(Dataset):
    """
    Expects a list of (PIL.Image, caption_str) pairs.
    Tokenizes text to fixed max_len (e.g., 77).
    """
    def __init__(self, pairs, tokenizer, max_len=77, tfms=None):
        self.pairs = pairs
        self.tk = tokenizer
        self.max_len = max_len
        self.tfms = tfms if tfms is not None else get_train_transforms()

    def __len__(self): return len(self.pairs)

    def __getitem__(self, i):
        img, txt = self.pairs[i]
        img = self.tfms(img)                                          # [3,H,W]
        enc = self.tk(txt,
                      padding="max_length",
                      truncation=True,
                      max_length=self.max_len,
                      return_tensors="pt",
                      return_token_type_ids=True,
                      return_attention_mask=True)
        return {
            "image": img,
            "input_ids": enc["input_ids"].squeeze(0),                 # [L]
            "token_type_ids": enc["token_type_ids"].squeeze(0),       # [L]
            "attention_mask": enc["attention_mask"].squeeze(0),       # [L]
        }


In [None]:
import transformers
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)


In [None]:
#Sanity Check.
# Build towers
text_tower = TextTower(vocab_size=len(tokenizer), hidden_size=512, num_heads=8,
                       num_layers=6, intermediate_size=2048, max_position_embeddings=77)
vision_tower = VisionTowerViT(img_size=224, patch=16, embed_dim=512,
                              num_layers=6, num_heads=8, mlp_dim=2048)

model = CLIPModel(text_tower, vision_tower, d_text=512, d_vision=512, proj_dim=512)

# Fake batch
B, L = 4, 77
images = torch.randn(B, 3, 224, 224)
input_ids = torch.randint(0, len(tokenizer), (B, L))
token_type_ids = torch.zeros(B, L, dtype=torch.long)
attention_mask = torch.ones(B, L, dtype=torch.long)

with torch.no_grad():
    lt, li = model(input_ids, token_type_ids, attention_mask, images)
print(lt.shape, li.shape)  # both [B,B]
loss = clip_contrastive_loss(lt, li)
print("loss:", float(loss))


torch.Size([4, 4]) torch.Size([4, 4])
loss: 1.2473889589309692


In [None]:
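def get_param_groups(model, weight_decay, no_decay=('bias', 'norm')):
    # No decay for 0-/1-D params (biases, LayerNorm gains, logit_scale) or names matching no_decay.
    # Module names here are lowercase (layernorm, attention_norm, ...), so match 'norm', not 'LayerNorm.weight'.
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim < 2 or any(nd in name for nd in no_decay):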
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
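A quick way to check which parameters a grouping rule exempts from weight decay, using a hypothetical tiny module whose names follow this notebook's lowercase style. A dimensionality test (ndim < 2) catches biases, LayerNorm gains and logit_scale regardless of naming, while an HF-style substring like 'LayerNorm.weight' matches nothing here:

```python
import torch
from torch import nn

class Tiny(nn.Module):
    """Hypothetical stand-in with the same naming style as the CLIP model above."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.layernorm = nn.LayerNorm(8)
        self.logit_scale = nn.Parameter(torch.tensor(2.0))

def split_by_ndim(model):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # weight matrices (ndim >= 2) are decayed; vectors and scalars are not
        (no_decay if p.ndim < 2 else decay).append(name)
    return decay, no_decay

decay, no_decay = split_by_ndim(Tiny())
hf_style_matches = [n for n, _ in Tiny().named_parameters() if "LayerNorm.weight" in n]
print(decay, no_decay, hf_style_matches)
```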


In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model is the CLIPModel built in the sanity check (encode_image, encode_text, logit_scale)
model = model.to(device)

# Optimizer
optimizer = AdamW(
    get_param_groups(model, weight_decay=0.2),
    lr=5e-4, betas=(0.9, 0.98), eps=1e-6
)

# Scheduler: warmup steps then cosine
warmup_steps = 500
total_steps = 10000
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)

def warmup_lr_lambda(step):
    return min((step + 1) / warmup_steps, 1.0)

warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_lr_lambda)

# Training loop
# Assumes train_loader is a DataLoader over CLIPPairsDataset (not defined in this notebook),
# so each batch is a dict with "image", "input_ids", "token_type_ids", "attention_mask".
for step, batch in enumerate(train_loader):
    if step >= total_steps:
        break
    batch = {k: v.to(device) for k, v in batch.items()}

    # Forward: CLIPModel returns the two scaled [B, B] similarity matrices
    logits_per_text, logits_per_image = model(batch["input_ids"],
                                              batch["token_type_ids"],
                                              batch["attention_mask"],
                                              batch["image"])
    loss = clip_contrastive_loss(logits_per_text, logits_per_image)

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Clip the learnable temperature so logits are never scaled by more than 100 (as in the CLIP paper)
    with torch.no_grad():
        model.logit_scale.clamp_(0, math.log(100))

    # Scheduler update
    if step < warmup_steps:
        warmup_scheduler.step()
    else:
        scheduler.step()

    if step % 100 == 0:
        print(f"Step {step} | Loss: {loss.item():.4f}")
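The two schedulers together produce a linear warmup followed by a cosine decay. As a sketch of an alternative (not what the notebook runs), the same schedule can be written as one multiplier function for a single LambdaLR; the numbers mirror warmup_steps = 500 and total_steps = 10000, and the dummy parameter exists only to build an optimizer.

```python
import math
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

warmup_steps, total_steps = 500, 10000

def lr_multiplier(step):
    """Linear warmup for warmup_steps, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

param = torch.nn.Parameter(torch.zeros(1))
optimizer = AdamW([param], lr=5e-4)
scheduler = LambdaLR(optimizer, lr_multiplier)   # on construction: lr = 5e-4 * lr_multiplier(0)

for s in (0, 499, 500, 5250, 10000):
    print(s, lr_multiplier(s))
```

With this form the loop calls scheduler.step() once per step, with no branch on warmup_steps.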


In [None]:
# Zero-shot retrieval helper for inference.
# After training, we can retrieve images from text queries (or vice versa) without further tuning:
@torch.no_grad()
def zero_shot_retrieval(model, query_texts, candidate_images, tokenizer):
    model.eval()
    device = next(model.parameters()).device

    # Encode
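    # max_length=77 matches the text tower's max_position_embeddings; longer queries would index past the table
    text_tokens = tokenizer(query_texts, padding=True, truncation=True, max_length=77, return_tensors='pt').to(device)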
    z_t = model.encode_text(**text_tokens)

    z_i = model.encode_image(candidate_images.to(device))

    # Similarity
    sims = z_t @ z_i.t() * model.logit_scale.exp()  # [num_queries, num_images]

    # Rankings
    rankings = sims.argsort(dim=1, descending=True)
    return rankings
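A toy similarity matrix (values made up) shows how argsort turns scores into rankings, and why multiplying by logit_scale.exp() is optional for retrieval: a positive scale changes the logits but not their order.

```python
import torch

sims = torch.tensor([[0.10, 0.90, 0.30],
                     [0.80, 0.20, 0.50]])        # [num_queries=2, num_images=3]

rankings = sims.argsort(dim=1, descending=True)                    # best image first
scaled_rankings = (sims * 100.0).argsort(dim=1, descending=True)   # same order after scaling

print(rankings.tolist())          # [[1, 2, 0], [0, 2, 1]]
print(rankings[:, 0].tolist())    # top-1 image per query: [1, 0]
```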


In [None]:
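#Example
# image_batch: a [N, 3, 224, 224] tensor of candidate images, preprocessed with get_eval_transforms()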
rankings = zero_shot_retrieval(model, ["a photo of a dog"], image_batch, tokenizer)
print("Most likely image index for query:", rankings[0, 0].item())



---

Gradio trained colab version 2


In [None]:
!pip -q install -U transformers datasets faiss-cpu accelerate

import math, random, os
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from datasets import load_dataset
from transformers import (
    DPRQuestionEncoder, DPRContextEncoder,
    DPRQuestionEncoderTokenizerFast, DPRContextEncoderTokenizerFast,
    get_linear_schedule_with_warmup
)

# Repro
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)



In [None]:
q_tok = DPRQuestionEncoderTokenizerFast.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
p_tok = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
vocab_size_match = (len(q_tok), len(p_tok))
vocab_size_match



(30522, 30522)

In [None]:
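def make_squad_pairs(split="train[:10%]"):
    """
    Returns list of (question, title, context).
    We keep title empty here, consistent with build_corpus_from_squad. (SQuAD does include a
    per-example `title` field, the Wikipedia article name, which could be used instead.)
    """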
    ds = load_dataset("squad", split=split)
    questions = ds["question"]
    contexts  = ds["context"]
    titles    = [""] * len(contexts)
    return list(zip(questions, titles, contexts))

def build_corpus_from_squad(split="train[:10%]"):
    """
    Stable-unique the contexts and return:
      corpus: [{"id": i, "title": "", "text": ctx}, ...]
      text2id: {ctx_string: id}
    """
    ds = load_dataset("squad", split=split)
    uniq = list(dict.fromkeys(ds["context"]))
    corpus = [{"id": i, "title": "", "text": ctx} for i, ctx in enumerate(uniq)]
    text2id = {c["text"]: c["id"] for c in corpus}
    return corpus, text2id
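SQuAD attaches several questions to the same paragraph, so ds["context"] contains repeats. dict.fromkeys keeps the first occurrence of each key in insertion order, which gives a stable de-duplication; text2id then recovers each question's gold passage id for evaluation. A toy sketch with made-up strings:

```python
contexts = ["ctx A", "ctx B", "ctx A", "ctx C", "ctx B"]   # one entry per question

uniq = list(dict.fromkeys(contexts))        # first-seen order kept
text2id = {ctx: i for i, ctx in enumerate(uniq)}

gold_ids = [text2id[c] for c in contexts]   # gold passage id for each question
print(uniq, gold_ids)
```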


In [None]:
def batch_q(questions, max_len=64):
    enc = q_tok(questions, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
    return {k: v.to(device) for k, v in enc.items()}

def batch_p(titles, texts, max_len=256):
    # DPR context encoder expects pair (title, text)
    enc = p_tok(text=titles, text_pair=texts, return_tensors="pt",
                padding=True, truncation=True, max_length=max_len)
    return {k: v.to(device) for k, v in enc.items()}


In [None]:
class DPRBiEncoderHF(nn.Module):
    """
    Wraps Hugging Face DPR encoders, projects to a shared space (optional),
    L2-normalizes outputs, and maintains a learnable logit_scale (temperature).
    """
    def __init__(self, proj_dim=768, init_tau=0.07, freeze_backbones=False, use_projection=True):
        super().__init__()
        # Pretrained towers
        self.q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        self.p_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

        if freeze_backbones:
            for p in self.q_encoder.parameters(): p.requires_grad = False
            for p in self.p_encoder.parameters(): p.requires_grad = False

        self.use_projection = use_projection
        hidden = self.q_encoder.config.hidden_size  # 768
        if use_projection:
            self.q_proj = nn.Linear(hidden, proj_dim, bias=False)
            self.p_proj = nn.Linear(hidden, proj_dim, bias=False)
            out_dim = proj_dim
        else:
            self.q_proj = nn.Identity()
            self.p_proj = nn.Identity()
            out_dim = hidden

        # Learn log(1/tau)
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / init_tau), dtype=torch.float32))
        self.out_dim = out_dim

    def _norm(self, x):  # L2-normalize
        return F.normalize(x, dim=-1)

    @torch.no_grad()
    def encode_questions_eval(self, **q_inputs):
        self.eval()
        zq = self.encode_questions(**q_inputs)
        return zq

    @torch.no_grad()
    def encode_passages_eval(self, **p_inputs):
        self.eval()
        zp = self.encode_passages(**p_inputs)
        return zp

    def encode_questions(self, **q_inputs):
        out = self.q_encoder(**q_inputs, return_dict=True)
        h   = out.pooler_output  # [B, 768]
        z   = self.q_proj(h)
        return self._norm(z)

    def encode_passages(self, **p_inputs):
        out = self.p_encoder(**p_inputs, return_dict=True)
        h   = out.pooler_output  # [B, 768]
        z   = self.p_proj(h)
        return self._norm(z)

    def forward(self, q_batch, p_batch):
        zq = self.encode_questions(**q_batch)  # [B, D]
        zp = self.encode_passages(**p_batch)   # [B, D]
        scale = self.logit_scale.exp()
        logits_qp = scale * (zq @ zp.t())      # [B, B]
        logits_pq = logits_qp.t()              # [B, B]
        return logits_qp, logits_pq
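`logit_scale` stores log(1/τ), so `logit_scale.exp()` in `forward` is 1/τ, the factor applied to every cosine similarity before the softmax. A minimal stdlib sketch of that parameterisation, assuming the default `init_tau=0.07` (the helper names are illustrative, not part of the notebook):

```python
import math

def logit_scale_from_tau(tau: float) -> float:
    """Initial value of the learnable parameter: log(1 / tau)."""
    return math.log(1.0 / tau)

def tau_from_logit_scale(logit_scale: float) -> float:
    """Recover the temperature the way the training log does: tau = 1 / exp(logit_scale)."""
    return 1.0 / math.exp(logit_scale)

def scaled_logits(cos_sims, logit_scale):
    """Multiply a matrix of cosine similarities by exp(logit_scale) = 1 / tau."""
    scale = math.exp(logit_scale)
    return [[scale * s for s in row] for row in cos_sims]

s0 = logit_scale_from_tau(0.07)   # about 2.659
print(math.exp(s0))               # about 14.29, the multiplier on each cosine
print(tau_from_logit_scale(s0))   # back to about 0.07
```

At τ = 0.07 a cosine gap of 0.1 between the positive and a negative becomes a logit gap of about 1.43, which is how a small temperature sharpens the softmax. Learning log(1/τ) instead of τ directly keeps the temperature positive; the training log shows it drifting only slightly, from 0.0700 to 0.0684.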


In [None]:
def dpr_contrastive_loss(logits_qp, logits_pq):
    B = logits_qp.size(0)
    y = torch.arange(B, device=logits_qp.device)
    loss_q = F.cross_entropy(logits_qp, y)
    loss_p = F.cross_entropy(logits_pq, y)
    return 0.5 * (loss_q + loss_p)
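`dpr_contrastive_loss` is a symmetric in-batch softmax loss: row i of `logits_qp` should place its mass on column i (question i's own passage), and the same holds for the transposed passage-to-question matrix. A dependency-free sketch of the same computation on plain lists, useful for checking intuitions (function names are illustrative):

```python
import math

def cross_entropy_diag(logits):
    """Mean cross-entropy where the correct class for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        total += lse - row[i]
    return total / len(logits)

def transpose(matrix):
    return [list(col) for col in zip(*matrix)]

def symmetric_contrastive_loss(logits_qp):
    """Average of the question->passage and passage->question losses."""
    return 0.5 * (cross_entropy_diag(logits_qp) + cross_entropy_diag(transpose(logits_qp)))
```

With uninformative (all-equal) logits the loss is ln B. For the batch size of 32 used here that is ln 32 ≈ 3.47, close to the 3.37 logged at step 100, so the early loss values are roughly at chance level.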


In [None]:
import faiss
import numpy as np

def build_faiss_index(model, corpus, batch_size=64, p_max=256):
    model.eval()
    all_vecs = []
    for i in range(0, len(corpus), batch_size):
        titles = [c["title"] for c in corpus[i:i+batch_size]]
        texts  = [c["text"]  for c in corpus[i:i+batch_size]]
        with torch.no_grad():
            zp = model.encode_passages(**batch_p(titles, texts, p_max))  # [b, D], normalized
        all_vecs.append(zp.cpu().numpy().astype("float32"))
    doc_embs = np.vstack(all_vecs)  # [N, D]

    dim = doc_embs.shape[1]
    index = faiss.IndexFlatIP(dim)       # inner product on unit vectors = cosine
    index.add(doc_embs)                  # order == corpus order
    id_map = np.array([c["id"] for c in corpus], dtype=np.int64)
    return index, id_map, doc_embs

@torch.no_grad()
def encode_question(model, question, q_max=64):
    zq = model.encode_questions(**batch_q([question], q_max))  # [1, D]
    return zq.cpu().numpy().astype("float32")

def retrieve(index, id_map, corpus, question, model, k=5, q_max=64):
    q = encode_question(model, question, q_max)   # (1, D)
    scores, idxs = index.search(q, k)             # (1, k)
    idxs = idxs[0]; scores = scores[0]
    hits = []
    for rank, (row, score) in enumerate(zip(idxs, scores), 1):
        doc = corpus[int(row)]
        hits.append({
            "rank": rank,
            "faiss_row": int(row),
            "score": float(score),
            "id": int(doc["id"]),
            "title": doc["title"],
            "text": doc["text"][:200] + ("..." if len(doc["text"]) > 200 else "")
        })
    return hits
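`IndexFlatIP` performs an exact, brute-force inner-product search, and because `encode_passages` and `encode_questions` return L2-normalized vectors, that inner product equals cosine similarity. A plain-Python sketch of the same search, for illustration only (it is not FAISS and does not scale; names are hypothetical):

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length so inner product == cosine similarity."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def flat_ip_search(doc_vecs, query, k):
    """Exhaustive inner-product search over every stored row.

    Returns (scores, rows), best first; rows are positions in doc_vecs,
    mirroring how FAISS reports row positions rather than document ids.
    """
    scored = [(sum(a * b for a, b in zip(d, query)), row) for row, d in enumerate(doc_vecs)]
    scored.sort(key=lambda t: t[0], reverse=True)
    top = scored[:k]
    return [s for s, _ in top], [r for _, r in top]
```

One difference worth knowing: FAISS pads `idxs` with -1 when k exceeds the number of stored vectors, whereas this sketch simply returns fewer results.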


In [None]:
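def eval_recall_at_k(model, index, corpus_text2id, eval_pairs, ks=(5, 20, 100), q_max=64, id_map=None):
    """Recall@k: fraction of answerable questions whose gold passage is in the top-k hits.

    `index.search` returns FAISS row positions. Pass `id_map` (row -> corpus id, as returned
    by build_faiss_index) to compare against corpus ids; without it, this assumes corpus ids
    equal row positions.
    """
    model.eval()
    counts = {k: 0 for k in ks}
    total  = 0

    for q, _title, pos_ctx in eval_pairs:
        pos_id = corpus_text2id.get(pos_ctx, None)
        if pos_id is None:
            continue  # gold passage is not in the indexed corpus; skip this question
        total += 1

        with torch.no_grad():
            zq = model.encode_questions(**batch_q([q], q_max)).cpu().numpy().astype("float32")

        _scores, idxs = index.search(zq, max(ks))
        rows = [int(r) for r in idxs[0] if r >= 0]  # FAISS pads with -1 when k > index size
        if id_map is not None:
            rows = [int(id_map[r]) for r in rows]

        for k in ks:
            if pos_id in rows[:k]:
                counts[k] += 1

    return {f"R@{k}": counts[k] / max(total, 1) for k in ks}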
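Recall@k only needs the ranked result lists and each question's gold passage id. A pure-Python sketch that separates this bookkeeping from the model and the index (names are hypothetical), including the row-to-id translation through an `id_map` like the one `build_faiss_index` returns:

```python
def recall_at_k(ranked_rows, positive_ids, ks, id_map=None):
    """Fraction of queries whose positive id appears among the top-k results.

    ranked_rows: one list of retrieved row positions per query, best first.
    id_map: optional row -> corpus-id lookup; without it, rows are assumed
    to already be corpus ids.
    """
    hits = {k: 0 for k in ks}
    for rows, pos_id in zip(ranked_rows, positive_ids):
        rows = [r for r in rows if r >= 0]  # drop FAISS-style -1 padding
        ids = [id_map[r] for r in rows] if id_map is not None else rows
        for k in ks:
            if pos_id in ids[:k]:
                hits[k] += 1
    n = max(len(positive_ids), 1)
    return {f"R@{k}": hits[k] / n for k in ks}
```

For example, `recall_at_k([[2, 0, 1]], [100], (1, 2), id_map=[100, 200, 300])` maps rows to ids `[300, 100, 200]`, so the positive is found at rank 2: R@1 is 0.0 and R@2 is 1.0.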



In [None]:
def train_dpr(model, pairs,
              steps=3000, batch_size=32, lr=5e-5, wd=0.01,
              warmup_ratio=0.1, q_max=64, p_max=256, log_every=50):
    model.train()
    model.to(device)

    # Train everything that requires grad
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=lr, weight_decay=wd)

    # Scheduler (linear w/ warmup by fraction of total steps)
    warmup = int(warmup_ratio * steps)
    sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=warmup, num_training_steps=steps)

    n = len(pairs)
    for step in range(steps):
        # cyclic batch sampling
        start = (step * batch_size) % n
        end   = start + batch_size
        if end <= n:
            batch = pairs[start:end]
        else:
            batch = pairs[start:] + pairs[:(end % n)]
        qs, titles, ctxs = zip(*batch)

        q_batch = batch_q(qs, q_max)
        p_batch = batch_p(titles, ctxs, p_max)

        logits_qp, logits_pq = model(q_batch, p_batch)
        loss = dpr_contrastive_loss(logits_qp, logits_pq)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1.0)
        opt.step()
        sched.step()

        if (step + 1) % log_every == 0:
            with torch.no_grad():
                tau = (model.logit_scale.detach().exp().reciprocal()).item()
            print(f"[train] step {step+1:04d} | loss {loss.item():.4f} | tau ~ {tau:.4f}")
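Two scheduling pieces of `train_dpr` can be checked in isolation: the wrap-around batch slicing, and the learning-rate multiplier of a linear warmup followed by linear decay to zero (the shape that, to our understanding, `transformers.get_linear_schedule_with_warmup` applies). A stdlib-only sketch with illustrative names, assuming `batch_size <= len(pairs)`:

```python
def cyclic_batch(pairs, step, batch_size):
    """Same wrap-around slicing as train_dpr: no shuffling, wraps to the start at the end."""
    n = len(pairs)
    start = (step * batch_size) % n
    end = start + batch_size
    if end <= n:
        return pairs[start:end]
    return pairs[start:] + pairs[:end % n]

def linear_warmup_factor(step, warmup, total):
    """LR multiplier: ramps 0 -> 1 over `warmup` steps, then decays linearly to 0 at `total`."""
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total - step) / max(1, total - warmup))
```

With the defaults (`steps=3000`, `warmup_ratio=0.1`) warmup lasts 300 steps, and at batch size 32 the 3000 steps read 96,000 pairs, about 11 passes over the 8760 training pairs reported in the log. Because the slicing never shuffles, consecutive pairs tend to share a batch on every pass.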


In [None]:
def main():
    # ------- Config -------
    train_split = "train[:10%]"        # tweak: 5% / 10% / 20% for visible gains
    eval_split  = "validation[:10%]"
    steps       = 3000                 # ~45–60 min on T4 for 10%
    batch_size  = 32
    lr          = 5e-5
    wd          = 0.01
    q_max       = 64
    p_max       = 256
    freeze_backbones = False           # True = faster/cheaper; False = better quality
    use_projection   = True            # add a bias-free linear projection head to each tower
    proj_dim    = 768                  # keep same as hidden for simplicity

    # ------- Model -------
    model = DPRBiEncoderHF(
        proj_dim=proj_dim,
        init_tau=0.07,
        freeze_backbones=freeze_backbones,
        use_projection=use_projection
    ).to(device)

    # ------- Data -------
    print("Preparing data...")
    train_pairs = make_squad_pairs(train_split)      # (q, "", ctx)
    eval_pairs  = make_squad_pairs(eval_split)

    # ------- Train -------
    print(f"Training on {len(train_pairs)} pairs for {steps} steps...")
    train_dpr(model, train_pairs, steps=steps, batch_size=batch_size,
              lr=lr, wd=wd, warmup_ratio=0.1, q_max=q_max, p_max=p_max, log_every=100)

    # ------- Build corpus & FAISS -------
    print("\nBuilding corpus and FAISS index...")
    corpus, text2id = build_corpus_from_squad(train_split)  # index on train contexts
    index, id_map, doc_embs = build_faiss_index(model, corpus, batch_size=64, p_max=p_max)

    # ------- Sanity retrieval -------
    print("\nQuick retrieval sanity check:")
    for q in [
        "Who wrote the book 'Pride and Prejudice'?",
        "What is the capital of France?",
        "Who discovered penicillin?"
    ]:
        hits = retrieve(index, id_map, corpus, q, model, k=3, q_max=q_max)
        print(f"\nQ: {q}")
        for h in hits:
            print(f"  {h['rank']}. score={h['score']:.3f}  id={h['id']}  {h['text']}")

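    # ------- Eval Recall@k (on eval pairs, searching an index of eval contexts) -------
    # SQuAD's train and validation splits come from different articles, so validation
    # gold passages never appear in the train index; build a separate index for eval.
    print("\nEvaluating Recall@k on validation subset (searching validation index)...")
    eval_corpus, eval_text2id = build_corpus_from_squad(eval_split)
    eval_index, _, _ = build_faiss_index(model, eval_corpus, batch_size=64, p_max=p_max)
    metrics = eval_recall_at_k(model, eval_index, eval_text2id, eval_pairs, ks=(5, 20, 100), q_max=q_max)
    nice = {k: f"{v:.3f}" for k, v in metrics.items()}
    print(nice)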

    # ------- Save -------
    torch.save(model.state_dict(), "dpr_biencoder_hf.pt")
    print("\nSaved DPR weights to dpr_biencoder_hf.pt")
    return model  # hand the trained model back so later cells can use it

if __name__ == "__main__":
    model = main()


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Preparing data...


Training on 8760 pairs for 3000 steps...
[train] step 0100 | loss 3.3737 | tau ~ 0.0700
[train] step 0200 | loss 3.2287 | tau ~ 0.0698
[train] step 0300 | loss 3.1540 | tau ~ 0.0697
[train] step 0400 | loss 2.4535 | tau ~ 0.0696
[train] step 0500 | loss 2.4723 | tau ~ 0.0693
[train] step 0600 | loss 2.9703 | tau ~ 0.0692
[train] step 0700 | loss 2.2328 | tau ~ 0.0691
[train] step 0800 | loss 2.3042 | tau ~ 0.0689
[train] step 0900 | loss 2.2487 | tau ~ 0.0689
[train] step 1000 | loss 1.7898 | tau ~ 0.0688
[train] step 1100 | loss 2.0015 | tau ~ 0.0687
[train] step 1200 | loss 2.3125 | tau ~ 0.0686
[train] step 1300 | loss 2.4241 | tau ~ 0.0685
[train] step 1400 | loss 2.3568 | tau ~ 0.0685
[train] step 1500 | loss 2.0824 | tau ~ 0.0685
[train] step 1600 | loss 2.0093 | tau ~ 0.0685
[train] step 1700 | loss 2.7099 | tau ~ 0.0685
[train] step 1800 | loss 1.6970 | tau ~ 0.0684
[train] step 1900 | loss 2.1183 | tau ~ 0.0684
[train] step 2000 | loss 1.7285 | tau ~ 0.0684
[train] step 2100 |

NameError: name 'model' is not defined
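This `NameError` is consistent with `model` being a local variable of `main()`: the name disappears when the function returns, which is why the next cell rebuilds the model and reloads the saved weights. A minimal illustration of the scoping rule with hypothetical stand-in functions (a dict plays the role of the model):

```python
def train_and_discard():
    model = {"weights": [0.1, 0.2]}  # local name: gone once the function returns

def train_and_return():
    model = {"weights": [0.1, 0.2]}
    return model  # hand the object back to the caller

train_and_discard()
try:
    _ = model  # no module-level `model` exists yet
    leaked = True
except NameError:
    leaked = False

model = train_and_return()  # now bound at module level
```

Returning the model from `main()` and binding it at the call site avoids having to reload the checkpoint from disk.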

In [None]:
model = DPRBiEncoderHF(
    proj_dim=768,
    init_tau=0.07,
    freeze_backbones=False,
    use_projection=True
).to(device)

# Load your trained weights
model.load_state_dict(torch.load("dpr_biencoder_hf.pt", map_location=device))
model.eval()

DPRBiEncoderHF(
  (q_encoder): DPRQuestionEncoder(
    (question_encoder): DPREncoder(
      (bert_model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertS

In [None]:
# Build index on validation contexts
corpus, text2id = build_corpus_from_squad("validation[:10%]")
index, id_map, doc_embs = build_faiss_index(model, corpus)

# Make eval pairs from same split
eval_pairs = make_squad_pairs("validation[:10%]")

# Compute Recall@k
metrics = eval_recall_at_k(model, index, text2id, eval_pairs, ks=(5,20,100))
print(metrics)


{'R@5': 0.7909176915799432, 'R@20': 0.9649952696310312, 'R@100': 1.0}
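Recall@k is easier to interpret next to the size N of the searched corpus: a retriever that ranks passages at random finds the single positive in its top k with probability min(k, N)/N. A small sketch, where N would be `index.ntotal` (or `len(corpus)`) for the validation index built in the cell above; the helper name is illustrative:

```python
def random_recall_at_k(k: int, n_docs: int) -> float:
    """Expected Recall@k of a uniformly random ranking over n_docs passages (one positive each)."""
    if n_docs <= 0:
        raise ValueError("n_docs must be positive")
    return min(k, n_docs) / n_docs
```

For instance, `{f"R@{k}": random_recall_at_k(k, index.ntotal) for k in (5, 20, 100)}` gives the chance-level baseline for the same cutoffs. Whenever N ≤ 100, R@100 is 1.0 for any retriever, so R@5 and R@20 are the more informative numbers on a small validation slice.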
