# Binary embeddings with [mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)

Our model was trained to have a non-'clunky' embedding space, which allows the embeddings to be quantized with little performance loss compared to techniques like Matryoshka. With binary embeddings, we can use the Hamming distance, which is very cheap to compute on CPUs (an XOR followed by a bit count).
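To illustrate why this is cheap: the Hamming distance between two packed binary codes is an XOR followed by counting the set bits. The NumPy sketch below uses a hypothetical `hamming_distance` helper, not FAISS's implementation. It uses `np.unpackbits(...).sum()` as a stand-in for the hardware popcount instruction that optimized libraries rely on.

```python
import numpy as np


def hamming_distance(a: np.ndarray, b: np.ndarray):
    """Hamming distance between packed uint8 bit vectors (hypothetical helper).

    `a` has shape (n_bytes,) or (n, n_bytes); `b` has shape (n_bytes,).
    XOR marks the differing bits; unpacking and summing counts them,
    standing in for the popcount instruction optimized libraries use.
    """
    diff = np.bitwise_xor(a, b)
    return np.unpackbits(diff, axis=-1).sum(axis=-1)


# Two 16-bit codes, packed into 2 bytes each
x = np.packbits(np.array([1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0], dtype=np.uint8))
y = np.packbits(np.array([1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1], dtype=np.uint8))
print(hamming_distance(x, y))  # 3
```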

The approach has two steps:

1. Retrieve candidates by Hamming distance between the binary query embedding and the binary document embeddings.
2. Rescore the candidates with the dot product between the binary document embeddings and the float query embedding.
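Here is a brute-force NumPy sketch of the two steps for a single query. It is not the FAISS-based code used below. `binarize` and `search_with_rescoring` are hypothetical helpers, and the candidate count of 100 is an assumption chosen to mirror the `max(k)` used later.

```python
import numpy as np


def binarize(emb: np.ndarray) -> np.ndarray:
    """Threshold at 0 (negative -> 0, otherwise 1) and pack 8 bits per byte."""
    return np.packbits((emb >= 0).astype(np.uint8), axis=-1)


def search_with_rescoring(query_float, doc_bits, k=10, num_candidates=100):
    """Brute-force two-step search for one query (hypothetical helper).

    query_float: float query embedding, shape (dim,), dim a multiple of 8.
    doc_bits:    packed binary document embeddings, shape (n_docs, dim // 8).
    """
    query_bits = binarize(query_float)
    # Step 1: Hamming distance (number of differing bits), smallest first
    hamming = np.unpackbits(np.bitwise_xor(doc_bits, query_bits), axis=-1).sum(axis=-1)
    candidates = np.argsort(hamming, kind="stable")[:num_candidates]
    # Step 2: dot product of the float query with the unpacked {0, 1} candidate bits
    cand_bits = np.unpackbits(doc_bits[candidates], axis=-1).astype(np.float32)
    scores = cand_bits @ query_float
    order = np.argsort(-scores, kind="stable")[:k]  # highest score first
    return candidates[order], scores[order]


# Usage on random unit vectors: querying with a document's own embedding
# should return that document first.
rng = np.random.default_rng(0)
docs = rng.standard_normal((1000, 64)).astype(np.float32)
docs /= np.linalg.norm(docs, axis=1, keepdims=True)
doc_bits = binarize(docs)
ids, scores = search_with_rescoring(docs[42], doc_bits, k=5)
print(ids, scores)
```

The cheap Hamming pass narrows the search to a short candidate list. The float query is then used only on those candidates, so the expensive step touches 100 documents instead of the whole corpus.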

We find that we can retain ~96-99% of the performance, achieve ~40x faster retrieval, and realize 32x storage savings.

In [None]:
!pip install sentence_transformers datasets beir faiss-cpu

In [None]:
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from beir.retrieval.evaluation import EvaluateRetrieval

import numpy as np
import faiss
import time
import os

Let's use the world's best model xD

In [3]:
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

TREC-COVID is a nice benchmark: not too large (171,332 documents), not too small, and fairly difficult.

In [4]:
task = "mteb/trec-covid"
dataset = load_dataset(task, "corpus")
docs_ids = dataset["corpus"]["_id"]
features = [d["title"] + " " + d["text"] for d in dataset["corpus"]]

Let's speed up encoding by casting the model to fp16.

In [5]:
_ = model.half()

On 4x A100 GPUs, encoding the corpus should take ~2 min.

In [6]:
pool = model.start_multi_process_pool()

# normalize_embeddings=True will normalize the embeddings to unit length before indexing so the dot product is equal to the cosine similarity
emb = model.encode_multi_process(features, pool, normalize_embeddings=True)
print("Embeddings computed. Shape:", emb.shape)

model.stop_multi_process_pool(pool)

Embeddings computed. Shape: (171332, 1024)
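You can check the comment about `normalize_embeddings=True` numerically. Once two vectors are scaled to unit length, their plain dot product equals their cosine similarity. That is why the inner-product index (`IndexFlatIP`) built next performs cosine-similarity search. The check below uses random vectors and is for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.standard_normal((2, 1024))

# Cosine similarity of the raw vectors
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# Dot product after normalizing each vector to unit length
a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)
dot_of_unit = a_unit @ b_unit

print(np.isclose(dot_of_unit, cosine))  # True
```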


FP32 Index

In [7]:
index = faiss.IndexFlatIP(emb.shape[1])
index.add(emb)
faiss.write_index(index, "index_fp32.faiss")

Binary index: convert the embeddings with simple thresholding (values < 0 become 0, all others 1) and pack 8 bits into each byte.

In [8]:
# Threshold at 0 (negative -> 0, otherwise 1), then pack 8 bits per byte: 1024 dims -> 128 bytes
bemb = np.where(emb < 0, 0, 1).astype(np.uint8)
bemb = np.packbits(bemb, axis=-1)
print("Binary embeddings computed. Shape:", bemb.shape)
num_dim = emb.shape[1]
bindex = faiss.IndexBinaryFlat(num_dim)
bindex.add(bemb)
faiss.write_index_binary(bindex, "index_binary.faiss")

Binary embeddings computed. Shape: (171332, 128)
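The toy example below makes the thresholding and packing concrete. It applies the same rule as the cell above to eight values. Variable names carry a `toy_` prefix so they don't overwrite the notebook's `emb` or `bemb`.

```python
import numpy as np

# Eight float values -> eight bits -> one packed byte (same rule as above: < 0 -> 0, else 1)
toy_emb = np.array([[0.3, -0.1, 0.0, -2.0, 0.5, 0.2, -0.7, 0.9]], dtype=np.float32)
toy_bits = np.where(toy_emb < 0, 0, 1).astype(np.uint8)
print(toy_bits)      # [[1 0 1 0 1 1 0 1]]  (note that 0.0 maps to 1)

toy_packed = np.packbits(toy_bits, axis=-1)
print(toy_packed)    # [[173]]  i.e. 0b10101101, the first bit is the most significant

# Unpacking recovers the bits, which the rescoring step relies on
print(np.unpackbits(toy_packed, axis=-1))  # [[1 0 1 0 1 1 0 1]]

# At full size, 1024 dimensions pack into 128 bytes per document
print(np.packbits(np.ones((1, 1024), dtype=np.uint8), axis=-1).shape)  # (1, 128)
```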


The compression ratio is ~32x, as expected: each 32-bit float dimension is stored as a single bit.

In [9]:
# check file size
fp32_index_size = os.path.getsize("index_fp32.faiss")
binary_index_size = os.path.getsize("index_binary.faiss")
print("File size of index_fp32.faiss:", fp32_index_size)
print("File size of index_binary.faiss:", binary_index_size)
print("Compression ratio:", fp32_index_size / binary_index_size)

File size of index_fp32.faiss: 701775917
File size of index_binary.faiss: 21930529
Compression ratio: 31.999953899880847
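The ~32x ratio follows directly from the vector payload sizes of the 171,332 x 1024 embeddings above. The raw payloads are exactly 32x apart. The measured files are slightly larger, by 45 bytes for FP32 and 33 bytes for binary. That overhead is presumably index header metadata, and it explains why the printed ratio is just under 32.

```python
num_docs, dim = 171_332, 1024

fp32_bytes = num_docs * dim * 4     # 4 bytes per float32 dimension
binary_bytes = num_docs * dim // 8  # 1 bit per dimension, 8 bits per byte

print(fp32_bytes)                   # 701775872
print(binary_bytes)                 # 21930496
print(fp32_bytes / binary_bytes)    # 32.0

# Difference to the measured file sizes printed above
print(701_775_917 - fp32_bytes, 21_930_529 - binary_bytes)  # 45 33
```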


Some BEIR stuff for the evaluation later: load the relevance judgments (qrels) and keep only positive scores.

In [10]:
qrels_df = load_dataset(task)["test"]
qrels = {}
for row in qrels_df:
    qid = row['query-id']
    cid = row['corpus-id']
    
    if row['score'] > 0:
        if qid not in qrels:
            qrels[qid] = {}
        qrels[qid][cid] = int(row['score'])

In [11]:
queries = load_dataset(task, "queries")
queries = queries.filter(lambda x: x['_id'] in qrels)

query_ids = queries["queries"]["_id"]
queries = ["Represent this sentence for searching relevant passages: " + d["text"] for d in queries["queries"]]

In [12]:
model.float()  # back to fp32 for encoding the (few) queries
query_emb = model.encode(queries, convert_to_numpy=True, normalize_embeddings=True)
query_bemb = np.where(query_emb < 0, 0, 1).astype(np.uint8)  # binarize
query_bemb = np.packbits(query_bemb, axis=-1)

In [13]:
def faiss_search(index, queries_emb, k=[10, 100], float_embed=None):
    # Retrieve max(k) candidates per query; the timer only covers index.search
    start_time = time.time()
    faiss_scores, faiss_doc_ids = index.search(queries_emb, max(k))
    print(f"Search took {(time.time()-start_time):.4f} sec")

    # Map FAISS row positions <-> dataset ids
    query2id = {idx: qid for idx, qid in enumerate(query_ids)}
    doc2id = {idx: cid for idx, cid in enumerate(docs_ids)}
    id2doc = {cid: idx for idx, cid in enumerate(docs_ids)}

    faiss_results = {}
    for idx in range(len(faiss_scores)):
        qid = query2id[idx]
        # Inner products for the FP32 index; Hamming distances (lower = closer) for the binary index.
        # BEIR ranks higher scores first, so raw Hamming distances are evaluated in reverse order.
        doc_scores = {doc2id[doc_id]: score.item() for doc_id, score in zip(faiss_doc_ids[idx], faiss_scores[idx])}

        # Rescore: dot product between the float query embedding and the unpacked {0, 1} document bits
        if float_embed is not None:
            bin_doc_emb = np.asarray([index.reconstruct(id2doc[doc_id]) for doc_id in doc_scores])
            bin_doc_emb_unpacked = np.unpackbits(bin_doc_emb, axis=-1).astype("int")

            scores_cont = float_embed[idx] @ bin_doc_emb_unpacked.T
            doc_scores = {doc_id: score_cont for doc_id, score_cont in zip(doc_scores, scores_cont)}

        faiss_results[qid] = doc_scores

    # NDCG/Recall@k and top-k accuracy via BEIR
    ndcg, map_score, recall, precision = EvaluateRetrieval.evaluate(qrels, faiss_results, k)
    acc = EvaluateRetrieval.evaluate_custom(qrels, faiss_results, [3, 5, 10], metric="acc")
    print(ndcg)
    print(recall)
    print(acc)

### Baseline: Exact FP32 search
We mostly care about NDCG@10 here.
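If the metrics are unfamiliar, the sketch below shows a common NDCG@k formulation: linear gains, a log2 rank discount, and normalization by the ideal ordering. It also shows top-k accuracy, read as "at least one relevant document in the top k". Both assume the retrieved list is ordered best first. The reported numbers come from BEIR's `EvaluateRetrieval` (which uses pytrec_eval), so details such as tie handling may differ. `ndcg_at_k` and `accuracy_at_k` are hypothetical helpers.

```python
import math


def ndcg_at_k(ranked_doc_ids, relevance, k=10):
    """NDCG@k with linear gains and a log2(rank + 1) discount (1-based ranks).

    ranked_doc_ids: retrieved doc ids, best first.
    relevance:      dict doc id -> graded relevance; missing ids count as 0.
    """
    dcg = sum(
        relevance.get(doc_id, 0) / math.log2(rank + 2)  # rank is 0-based here
        for rank, doc_id in enumerate(ranked_doc_ids[:k])
    )
    ideal = sorted(relevance.values(), reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


def accuracy_at_k(ranked_doc_ids, relevance, k=3):
    """1.0 if at least one relevant document is in the top k, else 0.0."""
    return float(any(relevance.get(doc_id, 0) > 0 for doc_id in ranked_doc_ids[:k]))


# Toy query with two relevant documents
rel = {"d1": 2, "d2": 1}
print(ndcg_at_k(["d1", "d2", "d3"], rel))      # 1.0 (ideal order)
print(ndcg_at_k(["d3", "d2", "d1"], rel))      # lower: relevant docs ranked late
print(accuracy_at_k(["d3", "d4", "d2"], rel))  # 1.0
```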

In [14]:
faiss_search(index, query_emb)

Search took 0.4663 sec
{'NDCG@10': 0.75558, 'NDCG@100': 0.56317}
{'Recall@10': 0.02136, 'Recall@100': 0.13842}
{'Accuracy@3': 0.98, 'Accuracy@5': 1.0, 'Accuracy@10': 1.0}


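### Without Rescoring

The measured NDCG@10 drops by ~53% (0.35133 vs. 0.75558), but the search is much faster: 0.0094 s vs. 0.4663 s, about 50x in this run. Caveat: without rescoring, `faiss_search` passes the raw Hamming distances to BEIR as scores. BEIR ranks higher scores first, so the 100 candidates are evaluated in reverse order (farthest first). These numbers therefore understate binary-only retrieval; negating the distances would restore the correct order.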

In [15]:
faiss_search(bindex, query_bemb)

Search took 0.0094 sec
{'NDCG@10': 0.35133, 'NDCG@100': 0.43985}
{'Recall@10': 0.00866, 'Recall@100': 0.1221}
{'Accuracy@3': 0.66, 'Accuracy@5': 0.76, 'Accuracy@10': 0.84}


### With Rescoring

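Still extremely fast, and we retain ~99.9% of the NDCG@10 (0.75496 vs. 0.75558). Note that the timer in `faiss_search` only covers `index.search`, not the rescoring loop.

NDCG@100 is lower than the baseline (0.51706 vs. 0.56317). This is mainly because rescoring can only reorder the 100 candidates returned by the binary search, so Recall@100 stays at the binary level (0.1221).

We verified similar behavior for SciFact and ArguAna. Accuracy also improved: Accuracy@3 reaches 1.0, compared with 0.98 for the exact baseline.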

In [16]:
faiss_search(bindex, query_bemb, float_embed=query_emb)

Search took 0.0153 sec
{'NDCG@10': 0.75496, 'NDCG@100': 0.51706}
{'Recall@10': 0.02128, 'Recall@100': 0.1221}
{'Accuracy@3': 1.0, 'Accuracy@5': 1.0, 'Accuracy@10': 1.0}


## Conclusion

Binary embeddings enable extremely fast retrieval and low storage usage at the expense of a slight performance loss. This loss can be mitigated by rescoring the candidates with the float query embedding or by using a reranker. This has cool applications for on-device usage, large-scale deployments, etc. We should also explore its potential for other tasks, such as clustering and deduplication at scale.