# Embedding Quantization

In this notebook, we show how to try out different embedding quantization methods and compare their retrieval quality.

We use our inference engine [Ofen](https://github.com/mixedbread-ai/ofen) to generate the embeddings.
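As background for the formats compared below, here is a minimal NumPy sketch of binary quantization. Each float32 dimension becomes a single bit (1 if the value is positive), and the bits are packed eight per `uint8` byte. This is the thresholding scheme sentence-transformers uses for its `ubinary` precision; we assume Ofen's `EncodingFormat.UBINARY` behaves similarly but have not verified it. Retrieval over packed bits uses the Hamming distance (the number of differing bits), which is why the `ubinary` index in the evaluation is built with `metric="hamming"`. The random vectors are stand-ins for real embeddings, and `quantize_ubinary` and `hamming_distances` are illustrative helpers, not library functions.

```python
import numpy as np


def quantize_ubinary(embeddings: np.ndarray) -> np.ndarray:
    """Map each float dimension to one bit (1 if > 0) and pack 8 bits per uint8."""
    bits = embeddings > 0
    return np.packbits(bits, axis=-1)  # shape (..., dim // 8), dtype uint8


def hamming_distances(query: np.ndarray, docs: np.ndarray) -> np.ndarray:
    """Count the differing bits between one packed query and each packed document."""
    xor = np.bitwise_xor(docs, query)  # broadcasts (n, dim // 8) with (dim // 8,)
    return np.unpackbits(xor, axis=-1).sum(axis=-1)


rng = np.random.default_rng(0)
docs_f32 = rng.standard_normal((1000, 1024)).astype(np.float32)
# The query is a slightly perturbed copy of document 42.
query_f32 = docs_f32[42] + 0.1 * rng.standard_normal(1024).astype(np.float32)

docs_bin = quantize_ubinary(docs_f32)
query_bin = quantize_ubinary(query_f32)

print(docs_f32.nbytes, docs_bin.nbytes)  # 4096000 128000
top5 = np.argsort(hamming_distances(query_bin, docs_bin))[:5]
print(top5[0])  # 42
```

Packing gives a 32x memory reduction over float32, and the nearest neighbor survives quantization because only the few dimensions whose sign the noise flips contribute to the distance.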

In [None]:
%pip install "ofen[torch]==0.0.1" baguetter

## 1. Import Libraries

In [1]:
from functools import partial

import numpy as np

from ofen.models import TextEncoder
from ofen.enums import EncodingFormat
from baguetter.indices import USearchDenseIndex
from baguetter.evaluation import evaluate_retrievers, HFDataset

# model_helpers also provides a stable implementation 
# of create_embed_fn using sentence-transformers
from baguetter.utils.model_helpers import create_embed_fn_ofen

## 2. Load Model

In [2]:
model = TextEncoder.from_pretrained("mixedbread-ai/mxbai-embed-large-v1")
# Convert the model to half precision (FP16) to reduce memory use and speed up inference
model.half()

# Define the embedding function expected by USearchDenseIndex.
# Alternatively, you can compute the embeddings yourself and add them to the index.
# This function caches the float32 embeddings so they can be reused across indices.
embed_fn = create_embed_fn_ofen(
    model,
    query_prompt="Represent this sentence for searching relevant passages: ",
    batch_size=256,
)
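The embedding function caches float32 embeddings so that every index can reuse them. The following is a minimal sketch of that idea, assuming nothing about the actual internals of `create_embed_fn_ofen`: the hypothetical `make_cached_embed_fn` encodes each text once at full precision and applies the requested quantizer on every call. `encode`, `quantizers`, `is_query`, and the string format names are illustrative stand-ins, not Ofen or baguetter APIs.

```python
from collections.abc import Callable, Sequence

import numpy as np


def make_cached_embed_fn(
    encode: Callable[[list[str]], np.ndarray],
    quantizers: dict[str, Callable[[np.ndarray], np.ndarray]],
    query_prompt: str = "",
) -> Callable[..., np.ndarray]:
    """Hypothetical helper: encode each text once in float32, quantize per call."""
    cache: dict[tuple[bool, str], np.ndarray] = {}

    def embed_fn(
        texts: Sequence[str], *, is_query: bool = False, encoding_format: str = "float"
    ) -> np.ndarray:
        keys = [(is_query, text) for text in texts]
        # Encode only keys not seen before (deduplicated, order preserved).
        missing = [key for key in dict.fromkeys(keys) if key not in cache]
        if missing:
            # Queries get the instruction prompt; documents are encoded as-is.
            inputs = [query_prompt + text if q else text for q, text in missing]
            for key, vector in zip(missing, encode(inputs)):
                cache[key] = np.asarray(vector, dtype=np.float32)
        float32 = np.stack([cache[key] for key in keys])
        return quantizers[encoding_format](float32)

    return embed_fn


# Usage with a fake encoder that records every call it receives.
calls: list[list[str]] = []


def fake_encode(texts: list[str]) -> np.ndarray:
    calls.append(list(texts))
    rng = np.random.default_rng(len(calls))
    return rng.standard_normal((len(texts), 16)).astype(np.float32)


quantizers = {
    "float": lambda x: x,
    "ubinary": lambda x: np.packbits(x > 0, axis=-1),
}
embed = make_cached_embed_fn(
    fake_encode,
    quantizers,
    query_prompt="Represent this sentence for searching relevant passages: ",
)

docs = ["first doc", "second doc", "third doc"]
doc_float = embed(docs, encoding_format="float")
doc_bits = embed(docs, encoding_format="ubinary")  # served from the cache
query_bits = embed(["first doc"], is_query=True, encoding_format="ubinary")
```

Keying the cache on `(is_query, text)` keeps a query from reusing the embedding of an identical document text, since queries are encoded with the prompt prepended.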

## 3. Create Different Embedding Functions

In [3]:
ubinary_embed_fn = partial(embed_fn, encoding_format=EncodingFormat.UBINARY)
# Not supported at the moment:
# binary_embed_fn = partial(embed_fn, encoding_format=EncodingFormat.BINARY)
# Not supported at the moment:
# uint8_embed_fn = partial(embed_fn, encoding_format=EncodingFormat.UINT8)
int8_embed_fn = partial(embed_fn, encoding_format=EncodingFormat.INT8)
float32_embed_fn = partial(embed_fn, encoding_format=EncodingFormat.FLOAT)
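For the `int8` format, a common approach is calibration-based scalar quantization: estimate each dimension's range from a set of calibration embeddings, split it into 256 buckets, and store the bucket index shifted into the int8 range. This is roughly what sentence-transformers' `quantize_embeddings` does; we have not confirmed that Ofen's `EncodingFormat.INT8` computes its ranges the same way. The helpers below (`fit_int8_ranges`, `quantize_int8`, `dequantize_int8`) are illustrative, and the clipping of out-of-range values is our own addition.

```python
import numpy as np


def fit_int8_ranges(calibration: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Per-dimension start value and bucket width from calibration embeddings."""
    starts = calibration.min(axis=0)
    steps = (calibration.max(axis=0) - starts) / 255.0
    # Guard against constant dimensions, which would give a zero step.
    steps = np.where(steps > 0, steps, 1.0)
    return starts, steps


def quantize_int8(emb: np.ndarray, starts: np.ndarray, steps: np.ndarray) -> np.ndarray:
    """Map each dimension onto buckets 0..255, then shift them to -128..127."""
    buckets = np.clip(np.floor((emb - starts) / steps), 0, 255)
    return (buckets - 128).astype(np.int8)


def dequantize_int8(q: np.ndarray, starts: np.ndarray, steps: np.ndarray) -> np.ndarray:
    """Approximate reconstruction: the center of each bucket."""
    return (q.astype(np.float32) + 128 + 0.5) * steps + starts


rng = np.random.default_rng(0)
calibration = rng.standard_normal((1000, 1024)).astype(np.float32)
docs = rng.standard_normal((100, 1024)).astype(np.float32)

starts, steps = fit_int8_ranges(calibration)
docs_int8 = quantize_int8(docs, starts, steps)
recon = dequantize_int8(docs_int8, starts, steps)

# Cosine similarity between each original vector and its reconstruction.
cos = np.sum(docs * recon, axis=1) / (
    np.linalg.norm(docs, axis=1) * np.linalg.norm(recon, axis=1)
)
print(docs.nbytes // docs_int8.nbytes)  # 4
print(cos.min())

# Dot products between int8 vectors must be accumulated in a wider integer type.
scores = docs_int8[:1].astype(np.int32) @ docs_int8.astype(np.int32).T
```

The reconstructions stay very close to the originals at a quarter of the memory, which is consistent with `int8` tracking `float32` closely in the report below.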

## 4. Evaluate

In [4]:
# Dataset(s) to evaluate on
datasets = [HFDataset("mteb/scidocs")]

# Build one dense retriever per encoding format and evaluate each
result = evaluate_retrievers(
    datasets=datasets,
    retriever_factories={
        "ubinary": lambda: USearchDenseIndex(
            embedding_dim=model.embedding_dim,
            embed_fn=ubinary_embed_fn,
            metric="hamming",
        ),
        "int8": lambda: USearchDenseIndex(
            embedding_dim=model.embedding_dim,
            embed_fn=int8_embed_fn,
            dtype=np.int8
        ),
        "float32": lambda: USearchDenseIndex(
            embedding_dim=model.embedding_dim,
            embed_fn=float32_embed_fn,
        )
    }
)
result.save("eval_results")

Evaluating  3 retrievers...
---------------------------------------------------------------
Datasets:  ['mteb/scidocs']
Top K:  100
Metrics:  ['ndcg@1', 'ndcg@5', 'ndcg@10', 'precision@1', 'precision@5', 'precision@10', 'mrr@1', 'mrr@5', 'mrr@10']
Ignore identical IDs:  True

Evaluating Dataset: mteb/scidocs
---------------------------------------------------------------

Starting Adding 25657 documents to ubinary...
Add: 100%|██████████| 25657/25657 [00:00<00:00, 50087.37vector/s]
Adding 25657 documents to ubinary took 1.10 seconds
Starting Searching 1000 queries with ubinary...
Search: 100%|██████████| 1000/1000 [00:00<00:00, 28258.17vector/s]
Searching 1000 queries with ubinary took 0.22 seconds
Starting Adding 25657 documents to int8...
Add: 100%|██████████| 25657/25657 [00:00<00:00, 41737.65vector/s]
Adding 25657 documents to int8 took 1.21 seconds
Starting Searching 1000 queries with int8...
Search: 100%|██████████| 1000/1000 [00:00<00:00, 11246.50vector/s]
Searching 1000 queries with int8 took 0.26 seconds
Starting Adding 25657 documents to float32...
Add: 100%|██████████| 25657/25657 [00:00<00:00, 52296.59vector/s]
Adding 25657 documents to float32 took 1.08 seconds
Starting Searching 1000 queries with float32...
Search: 100%|██████████| 1000/1000 [00:00<00:00, 4579.06vector/s]
Searching 1000 queries with float32 took 0.40 seconds

Report (rounded):
---------------------------------------------------------------
#    Model      NDCG@1  NDCG@5    NDCG@10      P@1  P@5     P@10      MRR@1  MRR@5    MRR@10
---  -------  --------  --------  ---------  -----  ------  ------  -------  -------  --------
a    ubinary     0.251  0.175     0.211      0.251  0.153   0.110     0.251  0.343    0.359
b    int8        0.236  0.187ᵃ    0.228ᵃ     0.236  0.170ᵃ  0.122ᵃ    0.236  0.351    0.365
c    float32     0.253  0.193ᵃ    0.231ᵃ     0.253  0.174ᵃ  0.121ᵃ    0.253  0.364ᵃ   0.377ᵃ
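To put the report in context, a back-of-the-envelope calculation of raw vector memory per format is useful. This sketch assumes 1024-dimensional embeddings (the output size of mxbai-embed-large-v1) and takes the document count from the log and the NDCG@10 values from the table above. It ignores any index structure overhead, so actual memory use of the USearch indices will be higher.

```python
# Assumption: 1024-dimensional embeddings (the output size of mxbai-embed-large-v1).
DIM = 1024
NUM_DOCS = 25_657  # documents indexed for mteb/scidocs, per the log above

bytes_per_vector = {
    "float32": DIM * 4,   # 4 bytes per dimension
    "int8": DIM,          # 1 byte per dimension
    "ubinary": DIM // 8,  # 1 bit per dimension, packed 8 per byte
}
# NDCG@10 values copied from the report table above (rounded).
ndcg_at_10 = {"float32": 0.231, "int8": 0.228, "ubinary": 0.211}

for name, nbytes in bytes_per_vector.items():
    total_mb = NUM_DOCS * nbytes / 1e6
    compression = bytes_per_vector["float32"] / nbytes
    retained = ndcg_at_10[name] / ndcg_at_10["float32"]
    print(
        f"{name:<8} {total_mb:7.2f} MB  {compression:4.0f}x smaller  "
        f"{retained:.1%} of float32 NDCG@10"
    )
```

With these numbers, `ubinary` stores the corpus in 1/32 of the float32 memory (about 3.3 MB instead of 105 MB) while keeping roughly 91% of its NDCG@10 on SciDocs, and `int8` keeps about 99% at a quarter of the memory.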
