## Asymmetric Dense Embedding Inferencing

The examples below show how to run inference with LightRetriever's asymmetric dense embeddings.

### 1. Lightweight Query Encoder with a Dense EmbeddingBag
Please first cache and save an EmbeddingBag by following `scripts/cache_emb_bag.ipynb`.

#### Load Tokenizer & EmbeddingBag

In [2]:
import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase

# Model identifier handed to `AutoTokenizer.from_pretrained` below.
model_name_or_path = "lightretriever/lightretriever-qwen2.5-1.5b"
# File read by `torch.load` below; it has to exist before this cell is run.
emb_bag_path = "web_search_en.emb_bag.pt"
# Target device the EmbeddingBag is moved to after it has been built.
device = torch.device("cuda:0")

# Load Tokenizer
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_name_or_path)

# Load EmbeddingBag
# The saved object is first materialized on the CPU (`map_location="cpu"`).
# NOTE(review): depending on the installed PyTorch version, `torch.load` may fall back to
# full pickle deserialization - only load files you created yourself or otherwise trust.
emb_bag_weight = torch.load(emb_bag_path, map_location="cpu")
# Wrap the loaded weight in an EmbeddingBag; the tokenizer's pad token id is used as `padding_idx`.
# NOTE(review): no `mode` is passed, so the EmbeddingBag default reduction is used - confirm this
# matches how the weights were cached.
emb_bag = torch.nn.EmbeddingBag.from_pretrained(emb_bag_weight, padding_idx=tokenizer.pad_token_id)
# Move the module to `device` and cast it to bfloat16.
emb_bag = emb_bag.to(device=device, dtype=torch.bfloat16)

#### Inference Query Embedding

In [5]:
import numpy as np

def tokenize_nonctx_qry_emb_bag(
    queries: list[str],
    tokenizer: PreTrainedTokenizerBase,
    max_len: int = 512,
) -> dict[str, torch.Tensor]:
    """ Tokenize queries into the flat `input` / `offsets` layout consumed by an EmbeddingBag.

        Args:
            queries (list[str]): List of string queries.
            tokenizer (PreTrainedTokenizerBase): HF Tokenizer.
            max_len (int): Max sequence length of each `query`; it does not include the prompt area.

        Returns:
            dict[str, torch.Tensor]: `input` is a 1-D int64 tensor holding the token ids of all
            queries concatenated in order. `offsets` is a 1-D int64 tensor holding, for each
            query, the index in `input` where its tokens start. Both tensors are empty when
            `queries` is empty.
    """
    if not queries:
        # Zero queries means zero bags: return empty index tensors instead of calling the
        # tokenizer (and instead of producing a spurious single offset of 0).
        return {
            "input": torch.empty(0, dtype=torch.long),
            "offsets": torch.empty(0, dtype=torch.long),
        }

    # No special tokens and no attention mask: only the raw token ids of each query are needed.
    encodings_ids: list[list[int]] = tokenizer(
        queries,
        max_length=max_len,
        truncation=True,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"]

    # Both index tensors are built explicitly as int64 so their dtypes always match,
    # independent of the platform's default integer width.
    lengths = torch.tensor([len(token_ids) for token_ids in encodings_ids], dtype=torch.long)
    # Exclusive prefix sum: each bag starts where the previous ones end.
    offsets = torch.cumsum(lengths, dim=0) - lengths
    input_ids = torch.tensor(
        [token_id for token_ids in encodings_ids for token_id in token_ids],
        dtype=torch.long,
    )
    return {"input": input_ids, "offsets": offsets}

# Tokenize
queries: list[str] = [
    "How tall is Mount Everest?",
    "Who invented the light bulb?"
]
queries_tokenized: dict[str, torch.Tensor] = tokenize_nonctx_qry_emb_bag(queries, tokenizer=tokenizer)
# Move every returned tensor to `device` before it is fed to the EmbeddingBag.
queries_tokenized = {k: v.to(device) for k, v in queries_tokenized.items()}

# Encode & Pooling
# The dict entries are unpacked as keyword arguments of the EmbeddingBag call.
query_embeddings: torch.Tensor = emb_bag(**queries_tokenized)
# L2-normalize along the last dimension.
query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=-1)

### 2. Full-sized LLM Document Encoder

#### Load & Merge LLM LoRA Weights

In [6]:
from transformers import AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutput
from peft import LoraConfig, PeftModel, LoraModel

# Load Base HF Model & Peft Adapters
# The adapter config is read from `model_name_or_path`; its `base_model_name_or_path`
# field selects which base checkpoint is loaded next.
config = LoraConfig.from_pretrained(model_name_or_path)
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path, 
    torch_dtype=torch.bfloat16, 
    # NOTE(review): presumably requires the flash-attn package to be installed - confirm.
    attn_implementation="flash_attention_2",
    device_map=device,
)
# Attach the adapter stored at `model_name_or_path` to the base model.
hf_model: LoraModel = PeftModel.from_pretrained(base_model, model_name_or_path, config=config)
hf_model = hf_model.merge_and_unload()  # Merge to single HF Model

#### Inference Corpus Embedding

In [7]:
def lasttoken_pooling(
    last_hidden: torch.Tensor,
    attention_mask: torch.Tensor,
):
    """ Pool each sequence by picking the hidden state of its last attended token.

        Args:
            last_hidden (torch.Tensor): Last layer hidden states, indexed as (batch, position, ...).
            attention_mask (torch.Tensor): Attention mask, indexed as (batch, position).

        Returns:
            torch.Tensor: One hidden state per batch row.
    """
    batch_size = attention_mask.shape[0]

    # If the final position is attended in every row, the batch is treated as
    # left-padded and the last position can be taken directly.
    if attention_mask[:, -1].sum() == batch_size:
        return last_hidden[:, -1]

    # Otherwise locate the last attended position of each row: (number of attended tokens) - 1.
    row_ids = torch.arange(last_hidden.shape[0], device=last_hidden.device)
    final_positions = attention_mask.sum(dim=1) - 1
    return last_hidden[row_ids, final_positions]

# Tokenize
corpus: list[str] = [
    "Mount Everest is the highest mountain in the world, about 8,848 meters tall.",
    "Thomas Edison invented the electric light bulb.",
    "Mount Fuji is the tallest mountain in Japan.",
]
# Unlike the query path, documents are tokenized with special tokens, padded to the longest
# item of the batch, truncated at 512 tokens and returned as PyTorch tensors with an attention mask.
corpus_tokenized = tokenizer(
    corpus, 
    max_length=512,
    padding="longest",
    truncation=True,
    add_special_tokens=True,
    return_attention_mask=True,
    return_tensors="pt",
)
# Move every tokenizer output tensor to `device`.
corpus_tokenized: dict[str, torch.Tensor] = {k: v.to(device) for k, v in corpus_tokenized.items()}

# Encode & Pooling
# The forward pass goes through the `.model` submodule of the merged model and reads
# `last_hidden_state` from its output.
# NOTE(review): the call is not wrapped in `torch.no_grad()` / `torch.inference_mode()`;
# consider adding one if autograd state should not be tracked here.
lm_out: BaseModelOutput = hf_model.model(
    **corpus_tokenized,
    return_dict=True,
    use_cache=False,
    output_hidden_states=False
)
# Pick one hidden state per document, then L2-normalize along the last dimension.
corpus_embedding = lasttoken_pooling(lm_out.last_hidden_state, corpus_tokenized["attention_mask"])
corpus_embedding = torch.nn.functional.normalize(corpus_embedding, p=2, dim=-1)

### 3. Compute Similarity

In [8]:
# Both embedding matrices were L2-normalized above, so this dot product is the cosine similarity.
# Row i of `scores` belongs to query i, column j to corpus document j.
scores = query_embeddings @ corpus_embedding.T
print(scores)

tensor([[0.3945, 0.0483, 0.3105],
        [0.0076, 0.3945, 0.0148]], device='cuda:0', dtype=torch.bfloat16)
