In [2]:
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel


#########################################
# 1. Helper: masked average pooling
#########################################

def masked_average_pool(last_hidden_states: Tensor,
                        pooling_mask: Tensor,
                        label: str = "") -> Tensor:
    """
    Average token embeddings over the positions selected by ``pooling_mask``.

    last_hidden_states: [batch, seq_len, hidden_dim]
    pooling_mask:       [batch, seq_len]; non-zero entries mark tokens to include,
                        zeros mark tokens to exclude (padding, context, ...)
    label:              optional label; when non-empty, shapes and per-example
                        token counts (first 3 examples) are printed

    Returns a [batch, hidden_dim] tensor. A row whose mask selects no tokens
    comes back as all zeros (the divisor is clamped to 1) instead of NaN.
    """
    batch_size, seq_len, _ = last_hidden_states.shape

    if label:
        print(f"\n[{label}] Pooling Operation:")
        print(f"  Input shape: {last_hidden_states.shape}")
        print(f"  Pooling mask shape: {pooling_mask.shape}")

    # Expand mask to [batch, seq_len, 1] so it broadcasts over hidden_dim
    mask = pooling_mask[..., None].bool()

    # Zero-out tokens we don't want to pool over, then sum over tokens
    summed = last_hidden_states.masked_fill(~mask, 0.0).sum(dim=1)

    # Count from the SAME boolean mask used for zeroing, so numerator and
    # denominator always agree even if the mask holds values other than 0/1.
    included = mask.sum(dim=1)          # [batch, 1], true token counts
    counts = included.clamp(min=1)      # avoid 0/0 for empty rows

    if label:
        for i in range(min(batch_size, 3)):  # Show first 3 examples
            num_tokens = int(included[i, 0].item())
            print(f"  Batch {i}: pooling over {num_tokens}/{seq_len} tokens")
            num_excluded = seq_len - num_tokens
            if num_excluded:
                # The pooling mask alone cannot distinguish padding from tokens
                # excluded on purpose (e.g. context), so report them together.
                print(f"    ({num_excluded} positions excluded: padding and/or masked-out tokens)")

    # Return average
    result = summed / counts

    if label:
        print(f"  Output shape: {result.shape}")

    return result


#########################################
# 2. Build inputs with explicit context
#########################################

def build_inputs_with_context(contexts, texts, tokenizer):
    """
    Encode each (context, text) pair as ``[CLS] context [SEP] text [SEP]``.

    The context is part of the model input, so attention can mix it into the
    text tokens. It is excluded from the returned pooling mask.

    contexts:  list[str]  -- extra context for each example
    texts:     list[str]  -- main text we want to embed ("query: ...", "passage: ...")
    tokenizer: object providing ``cls_token_id``, ``sep_token_id``,
               ``encode(text, add_special_tokens=False)`` and ``pad(...)``;
               ``padding_side`` is honoured when present (default "right")

    Returns:
      batch_dict      -- result of ``tokenizer.pad``: padded input_ids and attention_mask
      pooling_mask    -- LongTensor [batch, seq_len], 1 only on "text" tokens and
                         0 on [CLS], context, both [SEP]s and padding

    Raises:
      ValueError -- if ``contexts`` and ``texts`` differ in length, or the
                    tokenizer has no CLS or SEP token id.

    No truncation is applied. The caller must keep context + text within the
    model's maximum sequence length.
    """
    if len(contexts) != len(texts):
        raise ValueError(
            "contexts and texts must have the same length "
            f"(got {len(contexts)} and {len(texts)})"
        )

    cls_id = tokenizer.cls_token_id
    sep_id = tokenizer.sep_token_id
    if cls_id is None or sep_id is None:
        raise ValueError("tokenizer must define cls_token_id and sep_token_id")

    input_ids = []
    attention_masks = []
    pooling_masks = []

    for ctx, txt in zip(contexts, texts):
        # Tokenize WITHOUT special tokens so we control the layout ourselves
        ctx_ids = tokenizer.encode(ctx, add_special_tokens=False)
        txt_ids = tokenizer.encode(txt, add_special_tokens=False)

        # Layout: [CLS] context [SEP] text [SEP]
        ids = [cls_id] + ctx_ids + [sep_id] + txt_ids + [sep_id]

        input_ids.append(ids)
        # Attend to every real token; padding positions are added later by tokenizer.pad
        attention_masks.append([1] * len(ids))
        pooling_masks.append(
            [0] * (1 + len(ctx_ids) + 1)   # [CLS] + context + [SEP]
            + [1] * len(txt_ids)           # text tokens
            + [0]                          # final [SEP]
        )

    # Pad sequences to same length for batching
    batch = tokenizer.pad(
        {"input_ids": input_ids, "attention_mask": attention_masks},
        padding=True,
        return_tensors="pt",
    )

    # Pad the pooling masks explicitly with 0, to the batch's length and on the
    # same side. Routing them through tokenizer.pad as "input_ids" would fill
    # with pad_token_id. That value is not 0 for every tokenizer, and padding
    # positions would then be pooled.
    padded_len = batch["input_ids"].shape[1]
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"
    padded_pools = []
    for pool in pooling_masks:
        filler = [0] * (padded_len - len(pool))
        padded_pools.append(filler + pool if pad_left else pool + filler)

    return batch, torch.tensor(padded_pools, dtype=torch.long)


def compute_and_print_scores(query_embedding: Tensor,
                             passage_embeddings: Tensor,
                             approach_name: str,
                             top_k: int = 3) -> tuple[Tensor, Tensor]:
    """
    Score every passage against the queries and print the best matches for query 0.

    query_embedding:    [num_queries, dim]
    passage_embeddings: [num_passages, dim]
    approach_name:      label printed above the results
    top_k:              how many passages to list (fewer if there are fewer passages)

    Returns:
      scores      -- [num_queries, num_passages] dot products scaled by 100
                     (cosine similarity x 100 when both inputs are L2-normalized)
      top_indices -- 0-based indices of the top_k passages for query 0, best first
    """
    scores = (query_embedding @ passage_embeddings.T) * 100
    print(f"\n{approach_name} - Similarity scores (top {top_k}):")
    top_indices = scores[0].argsort(descending=True)[:top_k]
    for idx in top_indices:
        # Passages are shown 1-based to match the numbered listing
        print(f"  Query vs Passage {idx.item() + 1}: {scores[0, idx].item():.2f}")
    return scores, top_indices


#########################################
# 3. Example data - CONTRIVED SCENARIO
#########################################

# Contrived retrieval scenario: the query hints at a pharaoh (passage 2,
# Cleopatra) while the context below names Marie Curie (passage 5). The
# approaches can then be compared on which signal drives the ranking.
# NOTE(review): "pharoah" here is spelled differently from "pharaoh" in
# passage 2. Confirm whether that is intentional; changing it changes every score.
query = "query: I think she's most known for being a pharoah, but I'm not sure.  What is she known for?"

# Candidate passages. Each starts with "passage: ", while the query starts with
# "query: " (presumably the prefix convention of the E5 checkpoint; confirm).
passages = [
    "passage: Marie Antoinette is most known for being the last Queen of France before the French Revolution and her alleged quote 'Let them eat cake'.",
    "passage: Cleopatra is most known for being the last pharaoh of Ancient Egypt and her relationships with Julius Caesar and Mark Antony.",
    "passage: Joan of Arc is most known for leading the French army against English occupation during the Hundred Years' War.",
    "passage: Rosa Parks is most known for refusing to give up her bus seat to a white passenger, sparking the Montgomery Bus Boycott.",
    "passage: Marie Curie is most known for being the first woman to win a Nobel Prize and discovering the elements polonium and radium.",
    "passage: Frida Kahlo is most known for her self-portraits and paintings that explored identity, postcolonialism, and the female experience.",
    "passage: Virginia Woolf is most known for her modernist novels like Mrs. Dalloway and To the Lighthouse.",
    "passage: Eleanor Roosevelt is most known for her role in drafting the Universal Declaration of Human Rights and her advocacy for civil rights.",
    "passage: Amelia Earhart is most known for being the first female aviator to fly solo across the Atlantic Ocean.",
]

# Extra context consumed by Approach 1 (attention only) and Approach 3 (concatenated)
context = "My favorite woman in history is Marie Curie"

# Print the data we'll be working with (passages numbered from 1, matching
# the 1-based numbering used when scores are printed)
print("Context:", context)
print("Query:", query)
print("Passages:")
for i, passage in enumerate(passages):
    print(f"  {i+1}. {passage}")


#########################################
# 4. Load model & tokenizer
#########################################

# Single source of truth for the checkpoint, so tokenizer and model cannot drift apart.
MODEL_NAME = "intfloat/e5-large-v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)


#########################################
# 4.5. Compute passage embeddings ONCE (used by all approaches)
#########################################
print("\n\n### Computing passage embeddings (once) ###")

# Tokenize all passages as one padded batch. truncation/max_length clip any
# passage longer than 512 tokens instead of letting the encoder fail on it
# (512 assumed to be this checkpoint's limit, as in the E5 model-card usage).
passage_batch = tokenizer(
    passages,
    max_length=512,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    passage_outputs = model(**passage_batch)

# Standard average pooling for passages: mean over every non-padding token,
# then L2-normalise so dot products with normalised queries are cosine scores.
passage_embeddings = masked_average_pool(
    passage_outputs.last_hidden_state,
    passage_batch["attention_mask"],
    label="Passage embeddings",
)
passage_embeddings = F.normalize(passage_embeddings, p=2, dim=1)
print(f"Computed {len(passages)} passage embeddings with shape {passage_embeddings.shape}")


#########################################
# 5. APPROACH 1: Context in attention, NOT in pooling
#########################################
print("\n\n### APPROACH 1: Context influences via attention only ###")

# The context goes through the encoder alongside the query, so it can shape the
# query's token states, but only the query's own tokens are averaged.
batch, pooling_mask = build_inputs_with_context(
    contexts=[context], texts=[query], tokenizer=tokenizer
)

with torch.no_grad():
    # The attention mask covers context + text; only padding is hidden.
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )
last_hidden = outputs.last_hidden_state

# Pool over text tokens only (context and special tokens are masked out), then normalise.
query_embedding = F.normalize(
    masked_average_pool(last_hidden, pooling_mask, label="Approach 1 - Query"),
    p=2,
    dim=1,
)

scores, top_indices = compute_and_print_scores(
    query_embedding, passage_embeddings, "Approach 1"
)


#########################################
# 6. APPROACH 2: NO context at all
#########################################
print("\n\n### APPROACH 2: No context (baseline) ###")

# Baseline: the query is encoded on its own; the context appears nowhere in the input.
no_context_batch = tokenizer([query], padding=True, return_tensors="pt")
with torch.no_grad():
    no_context_outputs = model(**no_context_batch)

# Mean over every non-padding position the tokenizer produced, then normalise.
query_embedding_nc = F.normalize(
    masked_average_pool(
        no_context_outputs.last_hidden_state,
        no_context_batch["attention_mask"],
        label="Approach 2 - Query",
    ),
    p=2,
    dim=1,
)

scores_nc, top_indices_nc = compute_and_print_scores(
    query_embedding_nc, passage_embeddings, "Approach 2"
)


#########################################
# 7. APPROACH 3: Context included in embedding
#########################################
print("\n\n### APPROACH 3: Context included in the embedding ###")

# Naive variant: glue the context in front of the query text and pool over all of it.
# As in Approach 1, the context ends up ahead of the "query: " prefix.
query_with_context = " ".join([context, query])
embedded_batch = tokenizer([query_with_context], padding=True, return_tensors="pt")

with torch.no_grad():
    embedded_outputs = model(**embedded_batch)

# Every non-padding position (context and query alike) enters the mean.
query_embedding_emb = F.normalize(
    masked_average_pool(
        embedded_outputs.last_hidden_state,
        embedded_batch["attention_mask"],
        label="Approach 3 - Query",
    ),
    p=2,
    dim=1,
)

scores_emb, top_indices_emb = compute_and_print_scores(
    query_embedding_emb, passage_embeddings, "Approach 3"
)


#########################################
# 8. APPROACH 4: Standard E5 embedding (reference implementation)
#########################################
print("\n\n### APPROACH 4: Standard E5 embedding calculation ###")


# Reference pooling: mean over all non-padding tokens. On the same input it is
# expected to reproduce Approach 2 exactly.
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """
    Mean-pool token states over positions where ``attention_mask`` is non-zero.

    Standard average pooling used by E5 and similar models. The divisor is not
    clamped, so a row whose mask is all zeros yields NaN (0 / 0).
    """
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(keep.logical_not(), 0.0)
    token_totals = zeroed.sum(dim=1)
    token_counts = attention_mask.sum(dim=1, keepdim=True)
    return token_totals / token_counts

# Reference path: the same context-free query as Approach 2, pooled with average_pool.
batch_standard = tokenizer([query], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs_standard = model(**batch_standard)

query_embedding_standard = F.normalize(
    average_pool(outputs_standard.last_hidden_state, batch_standard["attention_mask"]),
    p=2,
    dim=1,
)

scores_standard, top_indices_standard = compute_and_print_scores(
    query_embedding_standard, passage_embeddings, "Approach 4"
)

# Cross-check against Approach 2: same tokenization, same pooled positions,
# so scores and embeddings should agree.
print("\nVerification: Approach 4 vs Approach 2")
print(f"  Max difference in scores: {(scores_standard - scores_nc).abs().max().item():.6f}")
print(f"  Embeddings identical: {torch.allclose(query_embedding_standard, query_embedding_nc, atol=1e-6)}")


#########################################
# 9. Summary
#########################################
def _print_top_match(top_idx: Tensor, score_matrix: Tensor) -> None:
    """Print the best-ranked passage (1-based) and its score for query 0."""
    best = top_idx[0]
    print(f"  Top match: Passage {best.item() + 1} (score: {score_matrix[0, best].item():.2f})")


print("\n\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)

# One row per approach: heading, qualitative bullets, and the ranking to report.
for _heading, _bullets, _top_idx, _score_matrix in (
    (
        "\nApproach 1 (Context in attention only):",
        (
            "  ✓ Query embedding represents ONLY the query text",
            "  ✓ Context influences interpretation via attention",
            "  ✓ Query embedding remains comparable to passages",
        ),
        top_indices,
        scores,
    ),
    (
        "\nApproach 2 (No context):",
        ("  ✗ No way to use context to influence query interpretation",),
        top_indices_nc,
        scores_nc,
    ),
    (
        "\nApproach 3 (Context in embedding):",
        (
            "  ✓ Can use context to influence query",
            "  ✗ Query embedding now represents 'context + query', not just query",
            "  ✗ Changes the semantic representation of the query itself",
        ),
        top_indices_emb,
        scores_emb,
    ),
    (
        "\nApproach 4 (Standard E5):",
        (
            "  • Reference implementation using standard average pooling",
            "  • Should match Approach 2 exactly",
        ),
        top_indices_standard,
        scores_standard,
    ),
):
    print(_heading)
    for _line in _bullets:
        print(_line)
    _print_top_match(_top_idx, _score_matrix)

print("=" * 60)


Context: My favorite woman in history is Marie Curie
Query: query: I think she's most known for being a pharoah, but I'm not sure.  What is she known for?
Passages:
  1. passage: Marie Antoinette is most known for being the last Queen of France before the French Revolution and her alleged quote 'Let them eat cake'.
  2. passage: Cleopatra is most known for being the last pharaoh of Ancient Egypt and her relationships with Julius Caesar and Mark Antony.
  3. passage: Joan of Arc is most known for leading the French army against English occupation during the Hundred Years' War.
  4. passage: Rosa Parks is most known for refusing to give up her bus seat to a white passenger, sparking the Montgomery Bus Boycott.
  5. passage: Marie Curie is most known for being the first woman to win a Nobel Prize and discovering the elements polonium and radium.
  6. passage: Frida Kahlo is most known for her self-portraits and paintings that explored identity, postcolonialism, and the female experience.


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



[Passage embeddings] Pooling Operation:
  Input shape: torch.Size([9, 33, 1024])
  Pooling mask shape: torch.Size([9, 33])
  Batch 0: pooling over 31/33 tokens
    (sequence padded from 31 to 33)
  Batch 1: pooling over 26/33 tokens
    (sequence padded from 26 to 33)
  Batch 2: pooling over 25/33 tokens
    (sequence padded from 25 to 33)
  Output shape: torch.Size([9, 1024])
Computed 9 passage embeddings with shape torch.Size([9, 1024])


### APPROACH 1: Context influences via attention only ###

[Approach 1 - Query] Pooling Operation:
  Input shape: torch.Size([1, 41, 1024])
  Pooling mask shape: torch.Size([1, 41])
  Batch 0: pooling over 29/41 tokens
    (sequence padded from 29 to 41)
  Output shape: torch.Size([1, 1024])

Approach 1 - Similarity scores (top 3):
  Query vs Passage 5: 85.40
  Query vs Passage 2: 81.55
  Query vs Passage 1: 79.33


### APPROACH 2: No context (baseline) ###

[Approach 2 - Query] Pooling Operation:
  Input shape: torch.Size([1, 31, 1024])
  Pooling 