# Binary embeddings with [mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)

Our model was trained to have a non-'clunky' embedding space. This allows the embeddings to be quantized with little performance loss compared to techniques like Matryoshka truncation. With binary embeddings, we can use the Hamming distance, which is very cheap to compute on CPUs (an XOR followed by a popcount).

In general, the approach is divided into 2 steps:

1. Retrieve candidates based on Hamming distance.
2. Rescore the candidates using the dot product between the float (non-quantized) query embedding and each candidate's binary document embedding.

We find that we can retain ~96-99% of the performance, achieve ~40x faster retrieval, and realize 32x storage savings.

In [None]:
from datasets import load_dataset
import numpy as np
import faiss
import time
import os

Let's use the world's best model xD. Its embeddings are precomputed, so we only load them from the cache.

In [None]:
# Embeddings are precomputed and cached on disk, so no model is loaded here.
from pathlib import Path

# Cache files are looked up as "<model_short>_<dataset>_corpus.npy" / "..._queries.npy".
model_short = "mxbai-embed-large-v1"
CACHE_DIR = Path("cache/embeddings")

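TREC-COVID is a nice benchmark: not too large, not too small, and pretty difficult. This notebook evaluates on FiQA by default; set `dataset_name` below to switch to SciFact or NFCorpus.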

In [None]:
# Use data_loader to ensure IDs match cached embedding order
from data_loader import MTEBDataLoader

# Switch dataset_name between "scifact", "nfcorpus" and "fiqa".
dataset_name = "fiqa"
# Cached embedding files use dataset-specific casing. Look it up from the
# dataset name so the two settings cannot silently drift apart; unknown names
# fall back to the dataset name itself.
_CACHE_NAME_BY_DATASET = {"scifact": "SciFact", "nfcorpus": "NFCorpus", "fiqa": "fiqa"}
dataset_cache_name = _CACHE_NAME_BY_DATASET.get(dataset_name, dataset_name)  # must match cache filename casing

# Number of leading embedding dimensions kept before binarization.
truncate_dim = 128

data_loader = MTEBDataLoader(Path("cache/datasets"))
corpus, queries_dict, qrels = data_loader.load_dataset(dataset_name)
docs_ids, doc_texts, all_query_ids, query_texts = data_loader.get_texts_for_embedding(corpus, queries_dict)

# Load cached embeddings (same order as docs_ids / all_query_ids).
# Raise instead of assert so the check is not stripped under `python -O`.
emb = np.load(CACHE_DIR / f"{model_short}_{dataset_cache_name}_corpus.npy")
if emb.shape[0] != len(docs_ids):
    raise ValueError(f"Corpus mismatch: {emb.shape[0]} embeddings vs {len(docs_ids)} doc IDs")
print(f"Loaded {dataset_name}: {len(docs_ids)} docs ({emb.shape}), {len(all_query_ids)} queries")

In [None]:
# Exact (brute-force) inner-product index over the full corpus embeddings.
index = faiss.IndexFlatIP(emb.shape[1])
index.add(emb)
# Serialized so its on-disk size can be compared with the binary index below.
faiss.write_index(index, "index_fp32.faiss")

In [None]:
# Keep the first `truncate_dim` dimensions, then binarize by sign (> 0 -> 1).
emb_truncated = emb[:, :truncate_dim]

# Pack bits per row (axis=-1): each document gets ceil(truncate_dim / 8) bytes,
# zero-padded at the end when truncate_dim is not a multiple of 8. Packing the
# flattened array instead would let one row's bits spill into the next row's
# bytes whenever the width is not a multiple of 8.
bemb = np.packbits(emb_truncated > 0, axis=-1)
print("Binary embeddings computed. Shape:", bemb.shape)

# IndexBinaryFlat takes the code width in bits, which must be a multiple of 8;
# derive it from the packed width rather than from the float width.
num_dim = bemb.shape[1] * 8
bindex = faiss.IndexBinaryFlat(num_dim)
bindex.add(bemb)
faiss.write_index_binary(bindex, "index_binary.faiss")

In [None]:
# Compare the on-disk size of the float index with the binary index.
fp32_index_size, binary_index_size = (
    os.path.getsize(path) for path in ("index_fp32.faiss", "index_binary.faiss")
)
print("File size of index_fp32.faiss:", fp32_index_size)
print("File size of index_binary.faiss:", binary_index_size)
print("Compression ratio:", fp32_index_size / binary_index_size)

In [None]:
# Sanity check: print the positive judgments of the first query in qrels.
if qrels:
    qid, rels = next(iter(qrels.items()))
    pos = dict(filter(lambda item: item[1] > 0, rels.items()))
    print(pos)

Prepare the BEIR-style relevance judgments (qrels) used for the evaluation below.

In [None]:
# qrels was already loaded by data_loader above. Some datasets also include
# 0-relevance entries, so keep only positive judgments and drop any query
# that is left with none.
qrels_pos = {
    qid: pos
    for qid, pos in (
        (qid, {did: score for did, score in rels.items() if score > 0})
        for qid, rels in qrels.items()
    )
    if pos
}
qrels = qrels_pos
print(f"Qrels: {len(qrels)} queries with positive relevance judgments")

In [None]:
# Keep only queries with at least one positive judgment. keep_mask stays
# aligned with all_query_ids so it can also slice the query-embedding matrix.
keep_mask = [query_id in qrels for query_id in all_query_ids]
query_ids = [query_id for keep, query_id in zip(keep_mask, all_query_ids) if keep]
print(f"Queries: {len(all_query_ids)} total → {len(query_ids)} with positive qrels")

In [None]:
# Load cached query embeddings (same order as all_query_ids) and keep only
# the queries that have positive qrels.
all_query_emb = np.load(CACHE_DIR / f"{model_short}_{dataset_cache_name}_queries.npy")
# Raise instead of assert so the check is not stripped under `python -O`.
if all_query_emb.shape[0] != len(all_query_ids):
    raise ValueError(f"Query mismatch: {all_query_emb.shape[0]} vs {len(all_query_ids)}")
# Explicit bool dtype so the mask is always interpreted as a row mask.
query_emb = all_query_emb[np.asarray(keep_mask, dtype=bool)]

query_truncated = query_emb[:, :truncate_dim]

# Pack bits per row (axis=-1) so each query gets its own ceil(truncate_dim / 8)
# bytes, zero-padded at the end; flattened packing would mix bits across rows
# whenever truncate_dim is not a multiple of 8.
query_bemb = np.packbits(query_truncated > 0, axis=-1)
print(f"Query embeddings: {all_query_emb.shape} → {query_emb.shape} (with qrels), binary: {query_bemb.shape}")



In [None]:
from metrics import RetrievalMetrics

def _rescore_candidates(index, cand_ids, q_vec, float_embed, final_k):
    """Reorder one query's candidate ids by dot product with ``q_vec``.

    Candidate vectors are rows of ``float_embed`` when it is given. Otherwise
    the candidates' binary codes are reconstructed from ``index`` and unpacked
    to 0/1 floats. Using 0/1 instead of +/-1 does not change the ranking: the
    two scores differ by a factor of 2 and the per-query constant -sum(q_vec).

    Ids of -1 (the padding faiss returns when fewer results exist than were
    requested) get a score of -inf, so they can never outrank a real document.

    Returns the first ``final_k`` ids of ``cand_ids`` in descending score order.
    """
    scores = np.full(len(cand_ids), -np.inf)
    valid_pos = np.flatnonzero(cand_ids >= 0)
    if valid_pos.size:
        valid_ids = cand_ids[valid_pos]
        if float_embed is not None:
            doc_vecs = float_embed[valid_ids]
        else:
            packed = np.asarray([index.reconstruct(int(did)) for did in valid_ids])
            doc_vecs = np.unpackbits(packed, axis=-1).astype(np.float32)
        scores[valid_pos] = q_vec @ doc_vecs.T
    # Stable sort: tied scores keep the first-stage (faiss) order.
    order = np.argsort(-scores, kind="stable")[:final_k]
    return cand_ids[order]


def faiss_search(index, queries_emb, k=None, float_embed=None, float_q_embed=None, oversample=1):
    """Search ``index``, optionally rescore candidates, then print and return metrics.

    Args:
        index: faiss index to search (float or binary).
        queries_emb: query vectors in the format ``index.search`` expects.
        k: metric cutoffs; defaults to ``[10, 100]``. ``max(k)`` results are
            kept per query.
        float_embed: optional document matrix used for rescoring; assumed to
            be row-aligned with the index ids. When None, rescoring uses the
            binary codes reconstructed from ``index``.
        float_q_embed: query vectors for rescoring. Rescoring happens only
            when this is given, and its width must match the rescoring
            document vectors.
        oversample: retrieve ``max(k) * oversample`` candidates per query
            before rescoring.

    Returns:
        The dict returned by ``RetrievalMetrics.compute_all_metrics``.

    Reads the notebook-level globals ``query_ids``, ``docs_ids`` and ``qrels``.
    """
    # None default avoids a shared mutable list; list() also accepts tuples.
    k = [10, 100] if k is None else list(k)
    final_k = max(k)

    start_time = time.perf_counter()
    _, faiss_doc_ids = index.search(queries_emb, final_k * oversample)
    print(f"Search took {(time.perf_counter() - start_time):.4f} sec")

    if float_q_embed is not None:
        # Rerank each query's oversampled candidates and keep the top final_k.
        start_time = time.perf_counter()
        final_indices = np.zeros((len(faiss_doc_ids), final_k), dtype=np.int64)
        for qi, cand_ids in enumerate(faiss_doc_ids):
            final_indices[qi] = _rescore_candidates(
                index, cand_ids, float_q_embed[qi], float_embed, final_k
            )
        print(f"Rescoring took {(time.perf_counter() - start_time):.4f} sec")
    else:
        # NOTE(review): faiss -1 padding ids (only possible when the corpus is
        # smaller than the request) are passed to the metrics unchanged here.
        final_indices = faiss_doc_ids[:, :final_k]

    # Compute metrics using our own module
    doc_id_to_idx = {cid: idx for idx, cid in enumerate(docs_ids)}
    metrics = RetrievalMetrics.compute_all_metrics(final_indices, qrels, doc_id_to_idx, query_ids, k)
    for name, val in sorted(metrics.items()):
        print(f"  {name}: {val:.4f}")
    return metrics


In [None]:
# Baseline: exact inner-product search with the full float embeddings.
faiss_search(index=index, queries_emb=query_emb)

### W/O Rescoring

Without rescoring we lose around 53% of the performance, but search is pretty fast (~30-40x faster).

In [None]:
# Binary only: candidates ranked by Hamming distance, no rescoring.
faiss_search(index=bindex, queries_emb=query_bemb, oversample=1)

In [None]:
# Binary retrieval of max(k) * 10 candidates, rescored with the float document
# embeddings (`emb`) rather than the binary codes, so this variant needs the
# float corpus in memory.
faiss_search(
    index=bindex,
    queries_emb=query_bemb,
    float_embed=emb,
    float_q_embed=query_emb,
    oversample=10,
)

In [None]:
# Run the packaged experiment. Import only the two names used, so a star
# import cannot silently overwrite notebook globals.
from binary_quant_rerank import ExperimentConfig, run

cfg = ExperimentConfig()
run(cfg)

In [None]:
# Report the version of the interpreter running this notebook's kernel.
# `!python3 --version` would instead report whichever python3 is first on PATH.
import platform
print(f"Python {platform.python_version()}")

## Conclusion

Binary embeddings enable extremely fast retrieval and low storage usage at the expense of a slight performance loss, which can be mitigated by rescoring the candidates or using a reranker. This has cool applications for on-device usage, large-scale deployments, etc. We should also explore their potential for other tasks, such as clustering and deduplication at scale.