## Asymmetric Sparse Embedding Inferencing

The examples below show how to run inference with LightRetriever's asymmetric sparse embeddings.

### 1. Lightweight Query Encoder without using any LM parameters

In [None]:
from collections import Counter
from transformers import AutoTokenizer, PreTrainedTokenizerBase

# Checkpoint identifier shared by the query-side tokenizer and the document-side adapters
model_name_or_path = "lightretriever/lightretriever-qwen2.5-1.5b"

# The query side only needs the tokenizer -- no model weights are loaded here
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_name_or_path)

# Example queries
queries: list[str] = ["How tall is Mount Everest?", "Who invented the light bulb?"]

# Convert each query to raw token ids (no special tokens, no attention mask)
query_encoding = tokenizer(
    queries,
    truncation=True,
    max_length=512,
    add_special_tokens=False,
    return_attention_mask=False,
)
query_input_ids: list[list[int]] = query_encoding["input_ids"]

# A query embedding is simply {token_id: number of occurrences in the query}
query_embeddings: list[dict[int, int]] = [
    dict(Counter(token_ids)) for token_ids in query_input_ids
]

In [9]:
# Show the sparse query embeddings: one {token_id: count} dict per query
print(query_embeddings)

[{5158: 1, 16217: 1, 374: 1, 6470: 1, 3512: 1, 477: 1, 30: 1}, {14623: 1, 35492: 1, 279: 1, 3100: 1, 45812: 1, 30: 1}]


### 2. Full-sized LLM Document Encoder

#### Load & Merge LLM LoRA Weights

In [None]:
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, LoraModel

# Target device for the document encoder
device = torch.device("cuda:0")

# Read the adapter config first: it records which base checkpoint the LoRA weights belong to
config = LoraConfig.from_pretrained(model_name_or_path)

# Load the base causal LM in bfloat16 directly onto the target device
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    device_map=device,
    torch_dtype=torch.bfloat16,
)

# Attach the LoRA adapters, then fold them into the base weights so that
# `hf_model` ends up as a single plain HF model
hf_model = PeftModel.from_pretrained(
    base_model,
    model_name_or_path,
    config=config,
).merge_and_unload()

#### Inference Corpus Embedding

In [11]:
from transformers.modeling_outputs import CausalLMOutputWithPast

# Tokenize
corpus: list[str] = [
    "Mount Everest is the highest mountain in the world, about 8,848 meters tall.",
    "Thomas Edison invented the electric light bulb.",
    "Mount Fuji is the tallest mountain in Japan.",
]
corpus_tokenized = tokenizer(
    corpus, 
    max_length=512,
    padding="longest",
    truncation=True,
    add_special_tokens=True,
    return_attention_mask=True,
    return_tensors="pt",
)
# Move every tokenizer output tensor to the same device as the model
corpus_tokenized: dict[str, torch.Tensor] = {k: v.to(device) for k, v in corpus_tokenized.items()}

# Encode & Pooling
# Inference only: run without autograd so no graph is built and no activations
# are retained for the vocabulary-sized logits.
with torch.no_grad():
    lm_out: CausalLMOutputWithPast = hf_model(
        **corpus_tokenized,
        return_dict=True,
        use_cache=False,
        output_hidden_states=False
    )
    # Zero the logits at padded positions. Filling with 0 (rather than -inf) is
    # sufficient because the ReLU below clamps everything to >= 0 anyway, so a
    # padded position can never contribute a positive weight.
    lm_out.logits = lm_out.logits.masked_fill(
        ~corpus_tokenized["attention_mask"].bool().unsqueeze(-1), 
        0.
    )
    # Max-pool over dim=1 (the sequence axis of the causal-LM logits), then apply
    # ReLU + log1p to obtain one non-negative weight per vocabulary entry.
    aggregated_logits: torch.Tensor = torch.log1p(torch.relu(torch.amax(lm_out.logits, dim=1)))

# Convert each dense row into a sparse {token_id: weight} dict, keeping only non-zero entries
corpus_embeddings: list[dict[int, float]] = []
for logits in aggregated_logits:
    idx = torch.nonzero(logits, as_tuple=False).squeeze(1)
    vals = logits[idx]
    emb = dict(zip(idx.tolist(), vals.tolist()))
    corpus_embeddings.append(emb)

In [12]:
# Number of non-zero entries kept in each document's sparse embedding
print([len(emb) for emb in corpus_embeddings])

[2261, 2207, 2651]


### 3. Compute Similarity

In [13]:
def compute_similarity(
    q_rep: dict[int, int],
    p_rep: dict[int, float],
) -> float:
    """
    Compute the dot product of two sparse vectors represented as dictionaries.

    Both dictionaries are keyed by token_id. The query side maps each token_id
    to an integer count; the passage side maps each token_id to a float weight.
    Only token ids present in both dictionaries contribute:
        sum(q_rep[tok_id] * p_rep[tok_id] for tok_id in q_rep.keys() & p_rep.keys())

    Args:
        q_rep (dict[int, int]): Sparse representation of the query vector.
        p_rep (dict[int, float]): Sparse representation of the passage/document vector.

    Returns:
        float: The similarity score (dot product). It is 0.0 when the two
            vectors share no token id (including when either is empty).
    """
    shared_tok_ids = q_rep.keys() & p_rep.keys()
    # Start from 0.0 so the result is a float even when there is no overlap
    # (a bare sum() over an empty iterable would return the int 0).
    return sum(
        (q_rep[tok_id] * p_rep[tok_id] for tok_id in shared_tok_ids),
        0.0,
    )


# Score matrix: one row per query, one column per corpus document
scores = [
    [compute_similarity(query_vec, doc_vec) for doc_vec in corpus_embeddings]
    for query_vec in query_embeddings
]
print(scores)

[[12.94921875, 1.28125, 9.1875], [1.76953125, 11.74609375, 1.81640625]]
