In [13]:
import torch
import torch.nn.functional as F
from torch import Tensor
from sentence_transformers import SentenceTransformer


#########################################
# 1. Helper: masked average pooling
#########################################

def masked_average_pool(last_hidden_states: Tensor,
                        pooling_mask: Tensor,
                        label: str = "") -> Tensor:
    """
    Average the hidden states of the tokens selected by ``pooling_mask``.

    Any non-zero mask entry selects its token, and the divisor is the number
    of selected tokens, so a mask that holds values other than 0/1 still
    produces a true mean. An example that selects no token yields an
    all-zero vector (the divisor is clamped to 1).

    last_hidden_states: [batch, seq_len, hidden_dim]
    pooling_mask:       [batch, seq_len], non-zero for tokens to include, 0 otherwise
    label:              optional label; when non-empty, diagnostics are printed

    Returns a [batch, hidden_dim] tensor with one averaged vector per example.
    """
    batch_size, seq_len, _ = last_hidden_states.shape

    if label:
        print(f"\n[{label}] Pooling Operation:")
        print(f"  Input shape: {last_hidden_states.shape}")
        print(f"  Pooling mask shape: {pooling_mask.shape}")

    # A single boolean view drives both the zeroing and the token count, so
    # the numerator and the divisor can never disagree.
    selected = pooling_mask.bool()

    # Expand mask to [batch, seq_len, 1] and zero-out tokens we don't pool over
    mask = selected[..., None]
    masked_hidden = last_hidden_states.masked_fill(~mask, 0.0)

    # Sum over tokens
    summed = masked_hidden.sum(dim=1)

    # Count how many tokens are actually included per example; clamp so an
    # all-zero mask row divides by 1 instead of 0
    counts = selected.sum(dim=1)[..., None].clamp(min=1)

    if label:
        for i in range(min(batch_size, 3)):  # Show first 3 examples
            num_tokens = int(selected[i].sum().item())
            print(f"  Batch {i}: pooling over {num_tokens}/{seq_len} tokens")
            if num_tokens < seq_len:
                # The mask alone cannot tell padding apart from context or
                # special tokens, so only report how many were left out.
                print(f"    ({seq_len - num_tokens} tokens excluded from pooling)")

    # Return average
    result = summed / counts

    if label:
        print(f"  Output shape: {result.shape}")

    return result


#########################################
# 2. Build inputs with explicit context for custom approach
#########################################

def build_inputs_with_context_custom(contexts, texts, tokenizer):
    """
    Build custom inputs where context is in attention but not in pooling.

    Every example is laid out as ``[BOS] context [EOS] text [EOS]``. The
    attention mask covers the whole layout; the pooling mask is 1 only on
    the ``text`` tokens.

    contexts:  list[str]  -- extra context for each example
    texts:     list[str]  -- main text we want to embed
    tokenizer: object exposing ``bos_token_id``, ``eos_token_id``,
               ``encode(text, add_special_tokens=False)`` and ``pad(...)``

    Returns:
      batch_dict      -- what ``tokenizer.pad`` returns: input_ids and
                         attention_mask for the model
      pooling_mask    -- int64 tensor [batch, seq_len], 1 only on "text" tokens

    Raises:
      ValueError -- if ``contexts`` and ``texts`` are empty or differ in
                    length, or the tokenizer has no BOS/EOS token id
    """
    contexts = list(contexts)
    texts = list(texts)
    if len(contexts) != len(texts):
        # zip() would silently drop the unmatched tail
        raise ValueError(
            f"contexts and texts must have the same length "
            f"(got {len(contexts)} and {len(texts)})"
        )
    if not texts:
        raise ValueError("At least one (context, text) pair is required")

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    if bos_id is None or eos_id is None:
        raise ValueError("Tokenizer must define both bos_token_id and eos_token_id")

    input_ids = []
    attention_masks = []
    pooling_rows = []

    for ctx, txt in zip(contexts, texts):
        # Tokenize WITHOUT special tokens so we can control layout
        ctx_ids = tokenizer.encode(ctx, add_special_tokens=False)
        txt_ids = tokenizer.encode(txt, add_special_tokens=False)

        # Layout: [BOS] context [EOS] text [EOS]
        ids = [bos_id] + ctx_ids + [eos_id] + txt_ids + [eos_id]

        # Pooling mask:
        #   0 for [BOS], context, and first [EOS]
        #   1 for text tokens
        #   0 for final [EOS]
        pool = [0] * (1 + len(ctx_ids) + 1) + [1] * len(txt_ids) + [0]

        input_ids.append(ids)
        # Model attention mask: attend to everything except padding
        attention_masks.append([1] * len(ids))
        pooling_rows.append(pool)

    # Pad sequences to same length for batching
    batch = tokenizer.pad(
        {"input_ids": input_ids, "attention_mask": attention_masks},
        padding=True,
        return_tensors="pt",
    )

    # Pad the pooling masks with literal zeros. Routing them through
    # tokenizer.pad as fake input_ids would fill them with the tokenizer's
    # pad token id, which is only harmless when that id happens to be 0.
    padded_len = batch["input_ids"].shape[1]
    pad_on_left = getattr(tokenizer, "padding_side", "right") == "left"
    pooling_masks = torch.zeros((len(pooling_rows), padded_len), dtype=torch.long)
    for row, pool in enumerate(pooling_rows):
        # Keep each row aligned with where tokenizer.pad placed the real tokens
        start = padded_len - len(pool) if pad_on_left else 0
        pooling_masks[row, start:start + len(pool)] = torch.tensor(pool, dtype=torch.long)

    return batch, pooling_masks


def apply_dense_layers(embeddings: Tensor, dense_layers: list) -> Tensor:
    """
    Run pooled embeddings through the given projection layers, in order.

    Each layer is called with a features dict and must return a features
    dict; the tensor travels under the "sentence_embedding" key.
    EmbeddingGemma has 2 Dense layers after pooling.

    Returns a fresh copy of the final "sentence_embedding" tensor, computed
    with gradient tracking disabled.
    """
    state = {"sentence_embedding": embeddings}
    with torch.no_grad():
        for layer in dense_layers:
            state = layer(state)
        projected = state["sentence_embedding"]
        # Hand back a copy to avoid inference mode issues
        return projected.clone()


def compute_and_print_scores(query_embedding: Tensor,
                              passage_embeddings: Tensor,
                              approach_name: str) -> tuple[Tensor, Tensor]:
    """
    Compute similarity scores and print the top 3 results for the first query.

    query_embedding:    [hidden_dim] or [num_queries, hidden_dim]
    passage_embeddings: [num_passages, hidden_dim]
    approach_name:      label used as the prefix of the printed header

    Scores are the dot products scaled by 100. Only the first query row is
    ranked and printed; passages are reported with 1-based numbers.

    Returns:
      scores       -- [num_queries, num_passages] tensor of scaled scores
      top_indices  -- indices (0-based) of up to 3 best passages for query 0
    """
    # Ensure query_embedding is 2D
    if query_embedding.dim() == 1:
        query_embedding = query_embedding.unsqueeze(0)

    scores = (query_embedding @ passage_embeddings.T) * 100
    print(f"\n{approach_name} - Similarity scores (top 3):")
    # Slicing keeps this safe when fewer than 3 passages are supplied
    top_indices = scores[0].argsort(descending=True)[:3]
    for idx in top_indices:
        print(f"  Query vs Passage {idx.item() + 1}: {scores[0, idx].item():.2f}")
    return scores, top_indices


#########################################
# 3. Example data - CONTRIVED SCENARIO
#########################################

# The query never names its subject ("she"); the one strong lexical hint it
# carries points at the Cleopatra passage.
query = "I think she's most known for being a pharoah, but I'm not sure.  What is she known for?"

# Side information that names a different woman than the query hints at.
context = "My favorite woman in history is Marie Curie"

# Candidate passages, all phrased with the same "is most known for" template.
passages = [
    "Marie Antoinette is most known for being the last Queen of France before the French Revolution and her alleged quote 'Let them eat cake'.",
    "Cleopatra is most known for being the last pharaoh of Ancient Egypt and her relationships with Julius Caesar and Mark Antony.",
    "Joan of Arc is most known for leading the French army against English occupation during the Hundred Years' War.",
    "Rosa Parks is most known for refusing to give up her bus seat to a white passenger, sparking the Montgomery Bus Boycott.",
    "Marie Curie is most known for being the first woman to win a Nobel Prize and discovering the elements polonium and radium.",
    "Frida Kahlo is most known for her self-portraits and paintings that explored identity, postcolonialism, and the female experience.",
    "Virginia Woolf is most known for her modernist novels like Mrs. Dalloway and To the Lighthouse.",
    "Eleanor Roosevelt is most known for her role in drafting the Universal Declaration of Human Rights and her advocacy for civil rights.",
    "Amelia Earhart is most known for being the first female aviator to fly solo across the Atlantic Ocean.",
]

# Echo the inputs, numbering passages from 1 to match the score printouts
print(f"Context: {context}")
print(f"Query: {query}")
print("Passages:")
for number, text in enumerate(passages, start=1):
    print(f"  {number}. {text}")


#########################################
# 4. Load model
#########################################

print("\n\n### Loading EmbeddingGemma model ###")
model = SentenceTransformer("google/embeddinggemma-300m")

# Walk the pipeline once and remember the pieces the manual approaches need:
# the tokenizer, the bare transformer, and every Dense projection stage
# (there are 2 in EmbeddingGemma).
tokenizer, base_transformer = None, None
dense_layers = []

for position, component in enumerate(model):
    component_kind = type(component).__name__
    print(f"  Layer {position}: {component_kind}")
    if hasattr(component, "tokenizer"):
        tokenizer = component.tokenizer
    if hasattr(component, "auto_model"):
        base_transformer = component.auto_model
    if component_kind == "Dense":
        dense_layers.append(component)

# Fail early if either required piece was not found in the pipeline
for found, description in ((tokenizer, "tokenizer"), (base_transformer, "base transformer")):
    if found is None:
        raise ValueError(f"Could not find {description} in model modules")

print(f"\nLoaded model with tokenizer: {type(tokenizer).__name__}")
print(f"Base transformer: {type(base_transformer).__name__}")
print(f"Found {len(dense_layers)} Dense projection layers")


#########################################
# 4.5. Compute passage embeddings ONCE (used by all approaches)
#########################################
print("\n\n### Computing passage embeddings (once) ###")

# Passages go through the model's own encode_document path (its output is
# already normalized); every approach below scores against this one tensor.
passage_embeddings = model.encode_document(passages, convert_to_tensor=True)
num_passages = len(passages)
embedding_shape = passage_embeddings.shape
print(f"Computed {num_passages} passage embeddings with shape {embedding_shape}")


#########################################
# 5. APPROACH 1: Context in attention, NOT in pooling (CORE EXPERIMENT)
#########################################
print("\n\n### APPROACH 1: Context in attention only (custom masked pooling) ###")
print("This is the CORE approach: context influences via attention, but we pool only over query tokens")

# Manual forward passes run on whatever device the base transformer is on
device = base_transformer.device

# Only the query text carries the standard EmbeddingGemma query prompt;
# the context is passed in unprompted
query_prompt = "task: search result | query: "
query_with_prompt = query_prompt + query

# Context sits in the attended sequence but is masked out of pooling
batch, pooling_mask = build_inputs_with_context_custom(
    [context],
    [query_with_prompt],
    tokenizer,
)

# Move to device
batch = {name: tensor.to(device) for name, tensor in batch.items()}
pooling_mask = pooling_mask.to(device)

# Call the base transformer directly so we control pooling ourselves
with torch.no_grad():
    outputs = base_transformer(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )
last_hidden = outputs.last_hidden_state

# Pool over the "main text" tokens only, then project and L2-normalize
# (the same post-pooling steps as the standard pipeline)
pooled_query = masked_average_pool(last_hidden, pooling_mask, label="Approach 1 - Query")
projected_query = apply_dense_layers(pooled_query, dense_layers)
query_embedding = F.normalize(projected_query, p=2, dim=1)

# Calculate and print scores
scores, top_indices = compute_and_print_scores(query_embedding, passage_embeddings, "Approach 1")


#########################################
# 6. APPROACH 2: NO context at all (baseline)
#########################################
print("\n\n### APPROACH 2: No context (baseline) ###")

# Same standard EmbeddingGemma query prompt as before, no context anywhere
query_prompt = "task: search result | query: "
query_with_prompt = query_prompt + query

# Let the tokenizer build the input (special tokens included), then move it
encoded_query = tokenizer([query_with_prompt], padding=True, return_tensors="pt")
query_no_context_batch = {name: tensor.to(device) for name, tensor in encoded_query.items()}

with torch.no_grad():
    query_no_context_outputs = base_transformer(**query_no_context_batch)

# Standard average pooling: the attention mask doubles as the pooling mask,
# so every non-padding token contributes
pooled_nc = masked_average_pool(
    query_no_context_outputs.last_hidden_state,
    query_no_context_batch["attention_mask"],
    label="Approach 2 - Query",
)

# Dense projection layers followed by L2 normalization
projected_nc = apply_dense_layers(pooled_nc, dense_layers)
query_embedding_nc = F.normalize(projected_nc, p=2, dim=1)

# Calculate and print scores
scores_nc, top_indices_nc = compute_and_print_scores(query_embedding_nc, passage_embeddings, "Approach 2")


#########################################
# 7. APPROACH 3: Context included in embedding
#########################################
print("\n\n### APPROACH 3: Context included in the embedding ###")

# The standard EmbeddingGemma query prompt
query_prompt = "task: search result | query: "

# One flat string: context, a space, then the prompted query
query_with_context = context + " " + query_prompt + query
encoded_with_context = tokenizer([query_with_context], padding=True, return_tensors="pt")
query_context_batch = {name: tensor.to(device) for name, tensor in encoded_with_context.items()}

with torch.no_grad():
    query_context_outputs = base_transformer(**query_context_batch)

# Pool over everything, so context tokens end up inside the embedding too
pooled_emb = masked_average_pool(
    query_context_outputs.last_hidden_state,
    query_context_batch["attention_mask"],
    label="Approach 3 - Query",
)

# Dense projection layers followed by L2 normalization
projected_emb = apply_dense_layers(pooled_emb, dense_layers)
query_embedding_emb = F.normalize(projected_emb, p=2, dim=1)

# Calculate and print scores
scores_emb, top_indices_emb = compute_and_print_scores(query_embedding_emb, passage_embeddings, "Approach 3")


#########################################
# 8. APPROACH 4: Standard EmbeddingGemma (for comparison)
#########################################
print("\n\n### APPROACH 4: Standard EmbeddingGemma encode_query method ###")

# The built-in encode_query path (includes prompting), scored with the
# model's own similarity function
query_embedding_standard = model.encode_query(query, convert_to_tensor=True)
similarities = model.similarity(query_embedding_standard, passage_embeddings) * 100

print("\nApproach 4 - Similarity scores (top 3):")
standard_row = similarities[0]
top_indices_standard = standard_row.argsort(descending=True)[:3]
for passage_idx in top_indices_standard:
    print(f"  Query vs Passage {passage_idx.item() + 1}: {standard_row[passage_idx].item():.2f}")

# Verification: the hand-built Approach 2 embedding is compared against the
# library's result, both by score and by raw embedding values
print("\nVerification: Approach 4 vs Approach 2")
manual_sim = (query_embedding_nc @ passage_embeddings.T) * 100
max_score_gap = (similarities - manual_sim).abs().max().item()
embeddings_match = torch.allclose(query_embedding_standard, query_embedding_nc, atol=1e-6)
print(f"  Max difference in scores: {max_score_gap:.6f}")
print(f"  Embeddings identical: {embeddings_match}")


#########################################
# 9. Summary
#########################################
# Final report: one paragraph per approach with its best-scoring passage
# (passage numbers are 1-based to match the listing printed earlier).
print("\n\n" + "=" * 60)
print("SUMMARY - EmbeddingGemma with Custom Masked Pooling")
print("=" * 60)

print("\nApproach 1 (Context in attention only - CORE EXPERIMENT):")
print("  ✓ Query embedding represents ONLY the query text")
print("  ✓ Context influences interpretation via attention")
print("  ✓ Query embedding remains comparable to passages")
print(f"  Top match: Passage {top_indices[0].item() + 1} (score: {scores[0, top_indices[0]].item():.2f})")

print("\nApproach 2 (No context):")
print("  ✗ No way to use context to influence query interpretation")
print(f"  Top match: Passage {top_indices_nc[0].item() + 1} (score: {scores_nc[0, top_indices_nc[0]].item():.2f})")

print("\nApproach 3 (Context in embedding):")
print("  ✓ Can use context to influence query")
print("  ✗ Query embedding now represents 'context + query', not just query")
print("  ✗ Changes the semantic representation of the query itself")
print(f"  Top match: Passage {top_indices_emb[0].item() + 1} (score: {scores_emb[0, top_indices_emb[0]].item():.2f})")

print("\nApproach 4 (Standard EmbeddingGemma with prompts):")
print("  • Uses built-in encode_query with task-specific prompts")
# Approach 2 prepends the same query prompt by hand, so the two are expected
# to agree; the verification printed under Approach 4 reports whether they do.
print("  • Should match Approach 2, which applies the same query prompt manually (see verification above)")
print(f"  Top match: Passage {top_indices_standard[0].item() + 1} (score: {similarities[0, top_indices_standard[0]].item():.2f})")

print("\n" + "=" * 60)
print("KEY INSIGHT: Approach 1 demonstrates masked pooling where")
print("context is used during attention (influencing hidden states)")
print("but excluded from the final embedding via selective pooling.")
print("This allows context to guide interpretation without changing")
print("the semantic space of the query embedding.")
print("=" * 60)


Context: My favorite woman in history is Marie Curie
Query: I think she's most known for being a pharoah, but I'm not sure.  What is she known for?
Passages:
  1. Marie Antoinette is most known for being the last Queen of France before the French Revolution and her alleged quote 'Let them eat cake'.
  2. Cleopatra is most known for being the last pharaoh of Ancient Egypt and her relationships with Julius Caesar and Mark Antony.
  3. Joan of Arc is most known for leading the French army against English occupation during the Hundred Years' War.
  4. Rosa Parks is most known for refusing to give up her bus seat to a white passenger, sparking the Montgomery Bus Boycott.
  5. Marie Curie is most known for being the first woman to win a Nobel Prize and discovering the elements polonium and radium.
  6. Frida Kahlo is most known for her self-portraits and paintings that explored identity, postcolonialism, and the female experience.
  7. Virginia Woolf is most known for her modernist novels li

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  Layer 0: Transformer
  Layer 1: Pooling
  Layer 2: Dense
  Layer 3: Dense
  Layer 4: Normalize

Loaded model with tokenizer: GemmaTokenizerFast
Base transformer: Gemma3TextModel
Found 2 Dense projection layers


### Computing passage embeddings (once) ###
Computed 9 passage embeddings with shape torch.Size([9, 768])


### APPROACH 1: Context in attention only (custom masked pooling) ###
This is the CORE approach: context influences via attention, but we pool only over query tokens

[Approach 1 - Query] Pooling Operation:
  Input shape: torch.Size([1, 46, 768])
  Pooling mask shape: torch.Size([1, 46])
  Batch 0: pooling over 35/46 tokens
    (sequence padded from 35 to 46)
  Output shape: torch.Size([1, 768])

Approach 1 - Similarity scores (top 3):
  Query vs Passage 2: 46.68
  Query vs Passage 5: 43.53
  Query vs Passage 8: 31.41


### APPROACH 2: No context (baseline) ###

[Approach 2 - Query] Pooling Operation:
  Input shape: torch.Size([1, 37, 768])
  Pooling mask shape: torch.S