In [None]:
!pip install faiss-cpu



In [None]:
import math
import random
import numpy as np
#import faiss-cpu
import faiss
import torch
import torch.nn.functional as F
from torch import nn
from datasets import load_dataset

# Reproducibility: seed every RNG source the notebook uses (Python, NumPy, PyTorch).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class BERTEmbedding(nn.Module):
  """Input embeddings for the BERT-style encoder.

  Each token's vector is token_embed[id] + position_embed[position] + segment_embed[type],
  followed by LayerNorm and dropout.

  Args:
    vocab_size: rows in the token-embedding table (must cover every input id).
    hidden_size: embedding width H.
    dropout: dropout probability applied after LayerNorm.
    max_position_embeddings: longest sequence the learned position table can index.
    type_vocab_size: number of segment ids ("type vocab"); 2 = sentence A / sentence B.
  """
  def __init__(self, vocab_size, hidden_size, dropout = 0.1,
               max_position_embeddings = 512, type_vocab_size = 2 ):
    super().__init__()
    # Three lookup tables, each vocab x H: tokens, absolute positions, and segments.
    self.token_embed = nn.Embedding(num_embeddings = vocab_size, embedding_dim = hidden_size)
    self.position_embed = nn.Embedding(max_position_embeddings, hidden_size)
    self.segment_embed = nn.Embedding(type_vocab_size, hidden_size)

    self.dropout = nn.Dropout(dropout)
    # Normalizes each token's H-dim vector independently (across the feature dimension).
    self.layernorm = nn.LayerNorm(hidden_size)

  def forward(self, input_ids, token_type_ids):
    """Embed a batch of token ids.

    Args:
      input_ids: [B, L] integer token ids.
      token_type_ids: [B, L] integer segment ids.

    Returns:
      [B, L, H] embeddings.

    Raises:
      ValueError: if L exceeds max_position_embeddings.
    """
    batch_size, sequence_length = input_ids.shape  # B x L
    max_len = self.position_embed.num_embeddings
    if sequence_length > max_len:
      # Fail with a readable message: an out-of-range nn.Embedding lookup is an IndexError
      # on CPU and a device-side assert on CUDA.
      raise ValueError(
          f"sequence length {sequence_length} exceeds max_position_embeddings={max_len}; "
          "truncate the inputs (e.g. tokenizer max_length) or enlarge the position table")

    token_embedding = self.token_embed(input_ids)  # [B, L, H]

    # Positions 0..L-1, shared by every row of the batch: [1, L] expanded to [B, L].
    position_ids = torch.arange(sequence_length, device = input_ids.device).unsqueeze(0).expand(batch_size, sequence_length)
    position_embedding = self.position_embed(position_ids)  # [B, L, H]

    segment_embedding = self.segment_embed(token_type_ids)  # [B, L, H]

    embedding = token_embedding + position_embedding + segment_embedding
    embedding = self.layernorm(embedding)
    embedding = self.dropout(embedding)

    return embedding  # [B, L, H]


In [None]:
class MultiHeadSelfAttention(nn.Module):
  """Scaled dot-product multi-head self-attention with an optional padding mask.

  Args:
    num_heads: number of heads; must divide hidden_size (each head gets hidden_size // num_heads dims).
    hidden_size: model width H.
    dropout: dropout probability applied to the attention weights.
  """
  def __init__(self, num_heads, hidden_size, dropout=0.1):
    super().__init__()
    # The hidden size is split EQUALLY among heads so each head attends in its own subspace.
    assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

    self.num_heads = num_heads
    self.hidden_size = hidden_size

    self.dropout = nn.Dropout(dropout)

    self.query = nn.Linear(hidden_size, hidden_size)
    self.key   = nn.Linear(hidden_size, hidden_size)
    self.value = nn.Linear(hidden_size, hidden_size)

    # Mixes the concatenated head outputs (each head only saw its own slice of H).
    self.out_proj = nn.Linear(hidden_size, hidden_size)

    self.head_dim = hidden_size // num_heads
    self.scale = math.sqrt(self.head_dim)

  def forward(self, x, mask = None):
    """Self-attention over x.

    Args:
      x: [B, L, H] input sequence.
      mask: optional [B, L] padding mask (nonzero/True = real token, 0/False = padding).

    Returns:
      (final_output [B, L, H], attention_weights [B, heads, L, L]). The weights are returned
      after dropout; with a mask they are exactly zero for padded keys and padded queries.
    """
    B, L, H = x.shape

    # Project, split H into (heads, head_dim), and move heads before L so every head is
    # processed in parallel: [B, L, H] -> [B, L, heads, head_dim] -> [B, heads, L, head_dim].
    Q = self.query(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    K = self.key(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    V = self.value(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

    # S[b, h, i, j] = (q_i . k_j) / sqrt(head_dim): raw score of query i against key j.
    raw_attention_logits = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [B, heads, L, L]

    if mask is not None:
      key_mask = mask.unsqueeze(1).unsqueeze(2)    # [B, 1, 1, L] -> masks columns (keys)
      query_mask = mask.unsqueeze(1).unsqueeze(3)  # [B, 1, L, 1] -> masks rows (queries)
      # A large finite negative rather than -inf: a fully masked row then softmaxes to a
      # uniform distribution instead of NaN.
      raw_attention_logits = raw_attention_logits.masked_fill(key_mask == 0, -1e4)
      raw_attention_logits = raw_attention_logits.masked_fill(query_mask == 0, -1e4)

    # Normalize over keys (last dim) so each query's weights sum to 1.
    attention_weights = F.softmax(raw_attention_logits, dim=-1)

    if mask is not None:
      # A padded query's row was all -1e4, so softmax makes it uniform, not zero; it must be
      # zeroed explicitly with the QUERY mask (multiplying by the key mask alone only zeroes
      # padded columns). Padded columns are zeroed too so those zeros are exact.
      keep = (key_mask != 0) & (query_mask != 0)  # [B, 1, L, L]
      attention_weights = attention_weights * keep.to(attention_weights.dtype)
    attention_weights = self.dropout(attention_weights)

    attention_output = torch.matmul(attention_weights, V)  # [B, heads, L, head_dim]

    # Concatenate heads: [B, heads, L, head_dim] -> [B, L, heads, head_dim] -> [B, L, H].
    attention_output = attention_output.transpose(1, 2).contiguous().view(B, L, H)

    # nn.Linear acts on the last dim only, so leading [B, L] dims pass through unchanged.
    final_output = self.out_proj(attention_output)  # [B, L, H]
    return final_output, attention_weights  # weights returned for visualisation


In [None]:
class EncoderLayer(nn.Module):
  """One post-LayerNorm Transformer encoder block.

  Self-attention, then a position-wise feed-forward network. Each sub-layer's output goes
  through dropout, a residual connection and LayerNorm. Shape-preserving: [B, L, H] -> [B, L, H].
  """
  def __init__(self, num_heads, hidden_size, intermediate_size, dropout):

    super().__init__()

    self.self_attention = MultiHeadSelfAttention(num_heads,hidden_size,dropout)

    self.attention_norm = nn.LayerNorm(hidden_size)

    # No trailing nn.Dropout here: forward() applies self.dropout to the FFN output, the
    # same way it does for the attention output. Having both applied dropout twice.
    # (Dropout has no parameters, so state_dict keys feed_forward.0 / feed_forward.2 are unchanged.)
    self.feed_forward = nn.Sequential(
        nn.Linear(hidden_size, intermediate_size),
        nn.GELU(),  # non-linearity between the two projections
        nn.Linear(intermediate_size,hidden_size),
    )

    self.feed_forward_norm = nn.LayerNorm(hidden_size)

    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask = None):
    """x: [B, L, H]; mask: optional [B, L] padding mask.

    Returns (x [B, L, H], attention weights from this layer's self-attention).
    """
    attention_output, attention_weights = self.self_attention(x,mask)

    x = self.attention_norm(x + self.dropout(attention_output))  # residual + post-LN

    feed_forward_output = self.feed_forward(x)

    x = self.feed_forward_norm(x + self.dropout(feed_forward_output))  # residual + post-LN

    return x, attention_weights

In [None]:
class BERTEncoder(nn.Module):
  """A stack of `num_layers` EncoderLayer blocks applied in sequence.

  forward(x, mask=None) returns (x [B, L, H], list of per-layer attention weights,
  first layer first). Each layer's output is the next layer's input.
  """
  def __init__(self,num_layers, num_heads, hidden_size, intermediate_size, dropout = 0.1):
    super().__init__()
    stack = [
        EncoderLayer(num_heads=num_heads, hidden_size=hidden_size,
                     intermediate_size=intermediate_size, dropout=dropout)
        for _ in range(num_layers)
    ]
    self.encoder_layers = nn.ModuleList(stack)

  def forward(self, x, mask = None):
    per_layer_weights = []
    hidden = x
    for block in self.encoder_layers:
      hidden, weights = block(hidden, mask)
      per_layer_weights.append(weights)
    return hidden, per_layer_weights


In [None]:
class TextTower(nn.Module):
    """
    Wraps BERTEmbedding + BERTEncoder and returns one vector per example.

    pool="cls":  hidden state of the first token (the [CLS] special token).
    pool="mean": mask-aware mean over non-pad tokens, so padding does not dilute it.
    Any other `pool` value raises ValueError.
    """
    def __init__(self,
                 vocab_size,
                 hidden_size=512,
                 num_heads=8,
                 num_layers=6,
                 intermediate_size=2048,
                 max_position_embeddings=77,   # CLIP-style short texts: smaller than BERT's usual 512
                 type_vocab_size=2,
                 dropout=0.1,
                 pool="cls"):
        super().__init__()
        if pool not in ("cls", "mean"):
            raise ValueError(f"pool must be 'cls' or 'mean', got {pool!r}")
        # token + position + segment embeddings -> [B, L, H]
        self.embedding = BERTEmbedding(vocab_size, hidden_size, dropout,
                                       max_position_embeddings, type_vocab_size)
        # self-attention / FFN / dropout / LayerNorm stack: [B, L, H] -> [B, L, H]
        self.encoder   = BERTEncoder(num_layers, num_heads, hidden_size,
                                     intermediate_size, dropout)
        self.pool = pool

    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        input_ids:        [B, L]
        token_type_ids:   [B, L] (can be zeros for single-sentence)
        attention_mask:   [B, L] (1=real, 0=pad), or None when nothing is padded
        returns:          [B, H]
        """
        x = self.embedding(input_ids, token_type_ids)      # [B,L,H]
        x, _ = self.encoder(x, attention_mask)             # [B,L,H]
        if self.pool == "cls":
            return x[:, 0]                                 # first token of every sequence, [B,H]
        if attention_mask is None:
            return x.mean(dim=1)                           # no padding: plain mean
        # Cast the (usually integer) mask to x's dtype so the token count and the
        # clamp are floating point.
        mask = attention_mask.unsqueeze(-1).to(x.dtype)    # [B,L,1]
        return (x * mask).sum(1) / mask.sum(1).clamp_min(1e-6)
'''
Pooling turns a sequence of token embeddings [B, L, H] into a single fixed-length vector [B, H] that represents the whole sentence (or prompt), so it can be compared with the other tower's vector (the image vector in CLIP; the passage vector in this DPR setup).

Why we need it in CLIP

Transformers keep one embedding per token.

CLIP’s contrastive loss needs one embedding per modality per example so it can compute a single similarity score between an image and a text.

Pooling condenses the whole sequence into one representative vector.
Why mask-aware mean

If we simply averaged over all positions, the padding tokens (which are meaningless) would dilute the representation.

Mask-aware mean:

Zeros out the padded positions.

Averages only over the “real” tokens.

The goal

Find a single point in the shared projection space that best captures the meaning of the entire text so that matched images and texts end up close together and mismatches far apart.

In practice:

CLS pooling: trust the model’s [CLS] token to gather global info via self-attention (what BERT is trained to do).

Mean pooling: average all non-pad token embeddings for a “global” semantic representation.

Which you choose can be a matter of taste; CLIP-style setups work with either.

'''

'\nThe pooling is just our way of turning a sequence of token embeddings [B, L, H] into a single fixed-length vector [B, H] that represents the whole sentence (or prompt) so we can compare it to the image vector.\n\nWhy we need it in CLIP\n\nTransformers keep one embedding per token.\n\nCLIP’s contrastive loss needs one embedding per modality per example so it can compute a single similarity score between an image and a text.\n\nPooling condenses the whole sequence into one representative vector.\nWhy mask-aware mean\n\nIf we simply averaged over all positions, the padding tokens (which are meaningless) would dilute the representation.\n\nMask-aware mean:\n\nZeros out the padded positions.\n\nAverages only over the “real” tokens.\n\nThe goal\n\nFind a single point in the shared projection space that best captures the meaning of the entire text so that matched images and texts end up close together and mismatches far apart.\n\nIn practice:\n\nCLS pooling: trust the model’s [CLS] token t

In [None]:
class ProjectionHead(nn.Module):
    """Bias-free linear map from a tower's pooled output (in_dim) into the shared space (out_dim).

    In DPR the question and passage towers each get their own head. This puts both
    embeddings in one common space where a dot product / cosine similarity is meaningful.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_features=in_dim, out_features=out_dim, bias=False)

    def forward(self, x):
        projected = self.proj(x)  # [..., in_dim] -> [..., out_dim]
        return projected

In [None]:
class DPRBiEncoder(nn.Module):
    """
    Two independent text towers (question and passage),
    each pooled to a single vector, projected to a shared space,
    L2-normalized, and compared with a scaled dot product.

    Args:
        q_tower, p_tower: modules called as tower(input_ids, token_type_ids, attention_mask)
            that return one pooled vector per example, [B, d_text].
        d_text: width of the towers' pooled output.
        proj_dim: width of the shared embedding space.
        init_tau: initial softmax temperature tau; the model learns log(1/tau).
        max_logit_scale: cap applied to exp(logit_scale) = 1/tau in forward() (CLIP uses
            100), so the temperature cannot collapse toward 0; None disables the cap.
    """
    def __init__(self, q_tower, p_tower, d_text=512, proj_dim=512, init_tau=0.07, max_logit_scale=100.0):
        super().__init__()
        self.q_tower = q_tower  # encoder for the questions
        self.p_tower = p_tower  # encoder for the passages
        self.q_proj  = ProjectionHead(d_text, proj_dim)  # d_text -> proj_dim
        self.p_proj  = ProjectionHead(d_text, proj_dim)  # d_text -> proj_dim
        # Learn log(1/tau); exponentiating keeps the scale positive.
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / init_tau), dtype=torch.float32))
        self.max_logit_scale = max_logit_scale

    def _norm(self, x):
        # Unit-length vectors make the dot product equal to cosine similarity.
        return F.normalize(x, dim=-1)

    @torch.no_grad()
    def encode_questions_eval(self, **q_inputs):
        """Gradient-free, dropout-free question encoding; restores the caller's train/eval mode."""
        was_training = self.training
        self.eval()
        try:
            return self.encode_questions(**q_inputs)
        finally:
            self.train(was_training)

    @torch.no_grad()
    def encode_passages_eval(self, **p_inputs):
        """Gradient-free, dropout-free passage encoding; restores the caller's train/eval mode."""
        was_training = self.training
        self.eval()
        try:
            return self.encode_passages(**p_inputs)
        finally:
            self.train(was_training)

    def encode_questions(self, input_ids, token_type_ids=None, attention_mask=None):
        """Question tower -> projection -> L2 normalization; returns [B, proj_dim] unit vectors."""
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)  # single-segment input
        h = self.q_tower(input_ids, token_type_ids, attention_mask)  # pooled, one vector per example
        return self._norm(self.q_proj(h))

    def encode_passages(self, input_ids, token_type_ids=None, attention_mask=None):
        """Passage tower -> projection -> L2 normalization; returns [B, proj_dim] unit vectors."""
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        h = self.p_tower(input_ids, token_type_ids, attention_mask)
        return self._norm(self.p_proj(h))

    def forward(self, q_batch, p_batch):
        """Score every question against every passage in the batch.

        Returns (logits_qp, logits_pq), both [B, B]: row i of logits_qp holds question i's
        scores over all passages; logits_pq is its transpose (passage -> question).
        """
        zq = self.encode_questions(**q_batch)  # [B,D]
        zp = self.encode_passages(**p_batch)   # [B,D]
        scale = self.logit_scale.exp()         # = 1 / tau
        if self.max_logit_scale is not None:
            scale = scale.clamp(max=self.max_logit_scale)
        logits_qp = scale * (zq @ zp.t())      # [B,B]
        logits_pq = logits_qp.t()              # [B,B]
        return logits_qp, logits_pq


def dpr_contrastive_loss(logits_qp, logits_pq):
    """Symmetric InfoNCE (like CLIP).

    Row i of each [B, B] logit matrix is a B-way classification whose correct class is i,
    because pairs are aligned by batch index. The question->passage and passage->question
    cross-entropies are averaged.
    """
    targets = torch.arange(logits_qp.shape[0], device=logits_qp.device)
    question_side = F.cross_entropy(logits_qp, targets)
    passage_side = F.cross_entropy(logits_pq, targets)
    return 0.5 * (question_side + passage_side)

'''Suppose your model outputs the logits below (y, introduced further down, just records the correct column index for each q→p row, or p→q row after transposing):
logits_qp =
[[10,  2,  1],   # q0 vs [p0, p1, p2]
 [ 0,  8, -1],   # q1 vs [p0, p1, p2]
 [ 1, -2,  6]]   # q2 vs [p0, p1, p2]


Targets: y = [0, 1, 2].

Compute softmax per row (intuitively):

Row 0: softmax(10,2,1) ≈ [0.9995, 0.0003, 0.0001] → CE₀ = −log(0.9995) ≈ 0.0005

Row 1: softmax(0,8,−1) ≈ [0.0003, 0.9995, 0.0001] → CE₁ = −log(0.9995) ≈ 0.0005

Row 2: softmax(1,−2,6) ≈ [0.0067, 0.0003, 0.9930] → CE₂ = −log(0.9930) ≈ 0.007

loss_q ≈ mean([0.0005, 0.0005, 0.007]) ≈ 0.003.

Do the same on logits_pq = logits_qp.T to get loss_p, then average:

loss = 0.5 * (loss_q + loss_p)

Why this works (and the role of y)

Every row poses a B-way classification: “Which of these B candidates is the correct match?”

Because pairs are aligned by index, the correct class for row i is i → y = arange(B).

In-batch negatives are all the off-diagonals in that row/column, no extra mining needed.'''

'Suppose your model outputs:\ny just tells us which is the correct index for each q,p or p,q pair\nlogits_qp =\n[[10,  2,  1],   # q0 vs [p0, p1, p2]\n [ 0,  8, -1],   # q1 vs [p0, p1, p2]\n [ 1, -2,  6]]   # q2 vs [p0, p1, p2]\n\n\nTargets: y = [0, 1, 2].\n\nCompute softmax per row (intuitively):\n\nRow 0: softmax(10,2,1) ≈ [~0.999, ~0.0003, ~0.0001] → CE₀ ≈ −log(0.999) ≈ 0.001\n\nRow 1: softmax(0,8,−1) ≈ [~0.0003, ~0.999, ~0.0001] → CE₁ ≈ −log(0.999) ≈ 0.001\n\nRow 2: softmax(1,−2,6) ≈ [~0.007, ~0.000, ~0.993] → CE₂ ≈ −log(0.993) ≈ 0.007\n\nloss_q ≈ mean([0.001, 0.001, 0.007]) ≈ 0.003.\n\nDo the same on logits_pq = logits_qp.T to get loss_p, then average:\n\nloss = 0.5 * (loss_q + loss_p)\n\nWhy this works (and the role of y)\n\nEvery row poses a B-way classification: “Which of these B candidates is the correct match?”\n\nBecause pairs are aligned by index, the correct class for row i is i → y = arange(B).\n\nIn-batch negatives are all the off-diagonals in that row/column, no extra m

In [None]:
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
# Pretrained DPR tokenizers. Only the tokenizers are reused: the towers in main() are
# freshly constructed TextTower modules trained from scratch. Loading the context tokenizer
# may print a warning that the checkpoint declares a question-encoder tokenizer class.
QUESTION_TOKENIZER_CKPT = "facebook/dpr-question_encoder-single-nq-base"
CONTEXT_TOKENIZER_CKPT = "facebook/dpr-ctx_encoder-single-nq-base"
q_tok = DPRQuestionEncoderTokenizerFast.from_pretrained(QUESTION_TOKENIZER_CKPT)
p_tok = DPRContextEncoderTokenizerFast.from_pretrained(CONTEXT_TOKENIZER_CKPT)

def batch_q(questions, max_len=64):
    """Tokenize questions with the DPR question tokenizer and move every tensor to `device`.

    Sequences are padded to the longest in the batch and truncated to `max_len` tokens. The
    returned dict is meant to be **-unpacked into DPRBiEncoder.encode_questions.
    """
    encoded = q_tok(
        questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    return dict((name, tensor.to(device)) for name, tensor in encoded.items())

def batch_p(titles, texts, max_len=256):
    """Tokenize (title, text) passage pairs with the DPR context tokenizer; tensors go to `device`.

    Titles are the first segment and passage texts the second (`text_pair`). The pair is
    padded to the longest in the batch and truncated to `max_len` tokens.
    """
    encoded = p_tok(
        text=titles,
        text_pair=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    moved = {}
    for name, tensor in encoded.items():
        moved[name] = tensor.to(device)
    return moved

'''**dict turns {"k": v} into k=v in the call.

Unexpected key → error (unless the function has **kwargs).

Duplicates → error.

Filter before unpacking if you’re unsure what the function accepts.
'''

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/493 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/492 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizerFast'.


'**dict turns {"k": v} into k=v in the call.\n\nUnexpected key → error (unless the function has **kwargs).\n\nDuplicates → error.\n\nFilter before unpacking if you’re unsure what the function accepts.\n'

In [None]:
def make_squad_pairs(split="train[:10%]"):
    """Load a SQuAD split as a list of (question, title, context) tuples; titles are left empty."""
    squad = load_dataset("squad", split=split)
    return [
        (question, "", context)
        for question, context in zip(squad["question"], squad["context"])
    ]


# ------------------------------------------------------------
# 6) Corpus for indexing: unique contexts → (id, title, text)
# ------------------------------------------------------------
def build_corpus_from_squad(split="train[:10%]"):
    """Deduplicate a SQuAD split's contexts into an indexable corpus.

    Returns:
        corpus: list of {"id", "title", "text"} dicts, ids assigned 0.. in first-seen order.
        text2id: context text -> its corpus id.
    """
    squad = load_dataset("squad", split=split)
    corpus, text2id = [], {}
    for context in squad["context"]:
        if context in text2id:
            continue  # keep only the first occurrence (stable order)
        doc_id = len(corpus)
        text2id[context] = doc_id
        corpus.append({"id": doc_id, "title": "", "text": context})
    return corpus, text2id


# ------------------------------------------------------------
# 7) Training utility
# ------------------------------------------------------------
def _unique_context_batches(pairs, batch_size, rng):
    """Yield training batches forever: reshuffled every epoch, no two pairs sharing a context.

    The in-batch loss labels row i with column i only, so a second copy of the same passage
    in a batch would be a true positive scored as a negative. Such a pair is skipped for the
    current epoch; a new shuffle next epoch gives it another chance.
    """
    order = list(pairs)
    while True:
        rng.shuffle(order)
        batch, seen = [], set()
        for pair in order:
            context = pair[2]
            if context in seen:
                continue
            batch.append(pair)
            seen.add(context)
            if len(batch) == batch_size:
                yield batch
                batch, seen = [], set()
        if batch:  # trailing partial batch; also guarantees progress when few contexts exist
            yield batch


def train_dpr(model, pairs, steps=400, batch_size=32, lr=5e-5, q_max=64, p_max=256, log_every=20, seed=0):
    """Train `model` with the symmetric in-batch contrastive loss.

    Args:
        model: DPRBiEncoder-like module returning (logits_qp, logits_pq) from (q_batch, p_batch).
        pairs: list of (question, title, context) tuples.
        steps: number of optimizer steps.
        batch_size: maximum pairs per batch (all with distinct contexts).
        lr: AdamW learning rate.
        q_max, p_max: token truncation lengths for questions / passages.
        log_every: print the loss and temperature every this many steps.
        seed: seed for the batch shuffling RNG (independent of the global RNGs).

    Raises:
        ValueError: if `pairs` is empty or `batch_size` < 1.
    """
    if not pairs:
        raise ValueError("train_dpr: `pairs` is empty")
    if batch_size < 1:
        raise ValueError(f"train_dpr: batch_size must be >= 1, got {batch_size}")

    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    batches = _unique_context_batches(pairs, batch_size, random.Random(seed))

    for step in range(steps):
        qs, titles, ctxs = zip(*next(batches))

        q_batch = batch_q(list(qs), q_max)
        p_batch = batch_p(list(titles), list(ctxs), p_max)

        logits_qp, logits_pq = model(q_batch, p_batch)
        loss = dpr_contrastive_loss(logits_qp, logits_pq)

        opt.zero_grad()
        loss.backward()
        opt.step()

        if (step + 1) % log_every == 0:
            # .item() reads the value without grad tracking; float() on the grad-requiring
            # tensor emitted a UserWarning.
            tau = math.exp(-model.logit_scale.item())
            print(f"[train] step {step+1:04d} | loss {loss.item():.4f} | tau ~ {tau:.4f}")


# ------------------------------------------------------------
# 8) Index building (FAISS) & retrieval
# ------------------------------------------------------------
def build_faiss_index(model, corpus, batch_size=64, p_max=256):
    """Encode every corpus passage and load the embeddings into an exact inner-product FAISS index.

    Puts `model` in eval mode and leaves it there. Passages are added in corpus order, so
    FAISS row r corresponds to corpus[r].

    Returns:
        (index, id_map, doc_embs): id_map[r] == corpus[r]["id"]; doc_embs is the [N, D]
        float32 matrix that was indexed.
    """
    model.eval()
    chunks = []
    for start in range(0, len(corpus), batch_size):
        docs = corpus[start:start + batch_size]
        tokenized = batch_p([d["title"] for d in docs], [d["text"] for d in docs], p_max)
        with torch.no_grad():
            encoded = model.encode_passages(**tokenized)  # [b, D], L2-normalized
        chunks.append(encoded.cpu().numpy().astype("float32"))
    doc_embs = np.vstack(chunks)  # [N, D]

    # The vectors are unit length, so inner product equals cosine similarity.
    index = faiss.IndexFlatIP(doc_embs.shape[1])
    index.add(doc_embs)
    id_map = np.array([d["id"] for d in corpus], dtype=np.int64)
    return index, id_map, doc_embs


@torch.no_grad()
def encode_question(model, question, q_max=64):
    """Encode one question string into a (1, D) float32 NumPy array for FAISS search.

    The model runs in eval mode (dropout off) so the query embedding is deterministic even
    when called mid-training. The model's previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        zq = model.encode_questions(**batch_q([question], q_max))  # [1,D]
    finally:
        model.train(was_training)
    return zq.cpu().numpy().astype("float32")


def retrieve(index, id_map, corpus, question, model, k=5, q_max=64, preview_chars=200):
    """Return the top-k corpus passages for `question`, best first.

    Each hit is a dict with rank (1-based), faiss_row, score, id, title and a text preview
    truncated to `preview_chars` characters. FAISS pads results with row -1 when the index
    holds fewer than k vectors. Those slots are dropped instead of being read as corpus[-1],
    so fewer than k hits may be returned.
    """
    q = encode_question(model, question, q_max)   # (1,D)
    scores, rows = index.search(q, k)             # (1,k) each
    hits = []
    for row, score in zip(rows[0], scores[0]):
        if row < 0:
            continue  # FAISS "no result" padding
        doc = corpus[int(row)]
        text = doc["text"]
        hits.append({
            "rank": len(hits) + 1,
            "faiss_row": int(row),
            "score": float(score),
            "id": int(doc["id"]),
            "title": doc["title"],
            "text": text[:preview_chars] + ("..." if len(text) > preview_chars else ""),
        })
    return hits


# ------------------------------------------------------------
# 9) Simple Recall@k evaluation (does gold context appear in top-k?)
#    For SQuAD, we map each question’s positive context to its doc id.
# ------------------------------------------------------------
def eval_recall_at_k(model, index, id_map, corpus_text2id, eval_pairs, ks=(5, 20, 100), q_max=64, batch_size=64):
    """Recall@k: fraction of questions whose gold context is among the top-k retrieved passages.

    Args:
        model: bi-encoder exposing encode_questions(**batch_q(...)); put in eval mode.
        index: FAISS index whose row r holds the passage with doc id id_map[r].
        id_map: array mapping FAISS row -> corpus doc id.
        corpus_text2id: gold context text -> doc id.
        eval_pairs: iterable of (question, title, gold_context).
        ks: cut-offs to report.
        q_max: question truncation length.
        batch_size: questions encoded / searched per call.

    Returns:
        {"R@k": recall} for each k. Questions whose gold context is not in the corpus cannot
        be hit; they are excluded from the denominator and reported with a warning.
    """
    import warnings

    model.eval()

    scored, skipped = [], 0
    for question, _title, pos_ctx in eval_pairs:
        pos_id = corpus_text2id.get(pos_ctx)
        if pos_id is None:
            skipped += 1
        else:
            scored.append((question, pos_id))
    if skipped:
        warnings.warn(
            f"eval_recall_at_k: {skipped} question(s) have no gold passage in the corpus "
            "and were excluded from the denominator")

    counts = {k: 0 for k in ks}
    depth = max(ks)
    for start in range(0, len(scored), batch_size):
        chunk = scored[start:start + batch_size]
        # no_grad is required: .numpy() raises on a tensor that requires grad.
        with torch.no_grad():
            zq = model.encode_questions(**batch_q([q for q, _ in chunk], q_max))
        _, rows = index.search(zq.cpu().numpy().astype("float32"), depth)
        for (_q, gold_id), row_list in zip(chunk, rows):
            # Map FAISS rows to doc ids; row -1 means FAISS had fewer than `depth` results.
            doc_ids = [int(id_map[r]) if r >= 0 else -1 for r in row_list]
            for k in ks:
                if gold_id in doc_ids[:k]:
                    counts[k] += 1

    total = max(len(scored), 1)
    return {f"R@{k}": counts[k] / total for k in ks}


# ------------------------------------------------------------
# 10) Main
# ------------------------------------------------------------
def main():
    """End-to-end demo: train the bi-encoder on a SQuAD slice, index passages with FAISS,
    run a few sanity queries, report Recall@k on held-out questions, and save the weights."""
    # ---- Build towers from your BERT stack ----
    # vocab_size must match the DPR tokenizers' vocab (they're BERT-based and equal).
    vocab_size = len(q_tok)  # same as len(p_tok)

    # max_position_embeddings matches the q_max / p_max truncation lengths used below.
    q_tower = TextTower(
        vocab_size=vocab_size,
        hidden_size=512, num_heads=8, num_layers=6, intermediate_size=2048,
        max_position_embeddings=64, type_vocab_size=2, dropout=0.1, pool="cls"
    )

    p_tower = TextTower(
        vocab_size=vocab_size,
        hidden_size=512, num_heads=8, num_layers=6, intermediate_size=2048,
        max_position_embeddings=256, type_vocab_size=2, dropout=0.1, pool="cls"
    )

    model = DPRBiEncoder(q_tower, p_tower, d_text=512, proj_dim=512, init_tau=0.07).to(device)

    # ---- Data ----
    train_split = "train[:10%]"
    eval_split = "validation[:10%]"
    # The index must contain the eval questions' gold contexts, or every Recall@k is 0.
    # SQuAD splits by Wikipedia article, so validation contexts are expected to be absent
    # from the train slice. Index both slices; train passages then act as distractors.
    corpus_split = f"{train_split}+{eval_split}"

    train_pairs = make_squad_pairs(train_split)   # (q, "", ctx)
    eval_pairs  = make_squad_pairs(eval_split)

    # ---- Train ----
    print("Training DPR (contrastive, in-batch negatives)...")
    train_dpr(model, train_pairs, steps=400, batch_size=32, lr=5e-5, q_max=64, p_max=256, log_every=25)

    # ---- Build corpus, index, and evaluate ----
    print("\nBuilding corpus and FAISS index...")
    corpus, text2id = build_corpus_from_squad(corpus_split)
    index, id_map, doc_embs = build_faiss_index(model, corpus, batch_size=64, p_max=256)

    print("\nQuick retrieval sanity check:")
    for q in [
        "Who wrote the book 'Pride and Prejudice'?",
        "What is the capital of France?",
        "Who discovered penicillin?"
    ]:
        hits = retrieve(index, id_map, corpus, q, model, k=3, q_max=64)
        print(f"\nQ: {q}")
        for h in hits:
            print(f"  {h['rank']}. score={h['score']:.3f}  id={h['id']}  {h['text']}")

    print("\nEvaluating Recall@k on a small split...")
    metrics = eval_recall_at_k(model, index, id_map, text2id, eval_pairs, ks=(5, 20, 100), q_max=64)
    print({k: f"{v:.3f}" for k, v in metrics.items()})

    # ---- Save (optional) ----
    torch.save(model.state_dict(), "dpr_biencoder.pt")
    print("\nSaved DPR weights to dpr_biencoder.pt")


In [None]:
# Inside a notebook kernel __name__ is "__main__", so running this cell executes the whole pipeline.
if __name__ == "__main__":
    main()

README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

plain_text/validation-00000-of-00001.par(…):   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Training DPR (contrastive, in-batch negatives)...


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  print(f"[train] step {step+1:04d} | loss {loss.item():.4f} | tau ~ {float(model.logit_scale.exp().reciprocal()):.4f}")


[train] step 0025 | loss 3.4779 | tau ~ 0.0701
[train] step 0050 | loss 3.4701 | tau ~ 0.0701
[train] step 0075 | loss 3.4729 | tau ~ 0.0702
[train] step 0100 | loss 3.4587 | tau ~ 0.0702
[train] step 0125 | loss 3.4661 | tau ~ 0.0702
[train] step 0150 | loss 3.4641 | tau ~ 0.0702
[train] step 0175 | loss 3.4744 | tau ~ 0.0702
[train] step 0200 | loss 3.4649 | tau ~ 0.0703
[train] step 0225 | loss 3.4734 | tau ~ 0.0703
[train] step 0250 | loss 3.4611 | tau ~ 0.0703
[train] step 0275 | loss 3.4611 | tau ~ 0.0703
[train] step 0300 | loss 3.4730 | tau ~ 0.0703
[train] step 0325 | loss 3.4660 | tau ~ 0.0704
[train] step 0350 | loss 3.4754 | tau ~ 0.0704
[train] step 0375 | loss 3.4688 | tau ~ 0.0704
[train] step 0400 | loss 3.4687 | tau ~ 0.0705

Building corpus and FAISS index...

Quick retrieval sanity check:

Q: Who wrote the book 'Pride and Prejudice'?
  1. score=0.109  id=492  To promote the film, production continued the trend established during Skyfall's production of releasing stil