# Dataset Creation

In [None]:
!pip -q install datasets==2.20.0 orjson==3.10.7

from datasets import load_dataset
from datetime import datetime, timezone
import os, gzip, orjson

# Inclusive review-time window [START, END] in epoch milliseconds: calendar year 2022, UTC.
# END is the last millisecond before 2023-01-01, so reviews stamped inside the final second
# of 2022 (e.g. 23:59:59.500) are not dropped by the `START <= ts <= END` check.
START = int(datetime(2022, 1, 1, tzinfo=timezone.utc).timestamp() * 1000)
END = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp() * 1000) - 1

TARGET_DOCS = 2_000_000  # total reviews to keep across all categories
SHARD_SIZE = 100_000     # reviews per output .jsonl.gz shard
# Honour $OUT_DIR (the ingestion step reads the same variable); fall back to the Colab default.
OUT_DIR = os.getenv("OUT_DIR", "").strip() or "/content/amazon_2022"
os.makedirs(OUT_DIR, exist_ok=True)

# Amazon Reviews 2023 category configs to sample from. Order matters: the sampling loop
# visits them in sequence, and the first EXTRA categories absorb the division remainder.
CATEGORIES = [
    "Clothing_Shoes_and_Jewelry", "Home_and_Kitchen", "Electronics", "Beauty_and_Personal_Care",
    "Tools_and_Home_Improvement", "Sports_and_Outdoors", "Grocery_and_Gourmet_Food",
    "Toys_and_Games", "Cell_Phones_and_Accessories", "Health_and_Household", "Kindle_Store",
    "Movies_and_TV", "Books", "Automotive", "Video_Games", "Industrial_and_Scientific",
    "All_Beauty", "Amazon_Fashion", "Appliances", "Arts_Crafts_and_Sewing", "CDs_and_Vinyl",
    "Handmade_Products", "Health_and_Personal_Care", "Gift_Cards", "Magazine_Subscriptions",
    "Software", "Subscription_Boxes", "Unknown",
]

# Even per-category quota plus the leftover count to hand out one-by-one.
PER_CAT, EXTRA = divmod(TARGET_DOCS, len(CATEGORIES))

def mk_doc(cat, ex):
    """Convert one raw Amazon review record into the shard document format.

    Args:
        cat: category name stored in the metadata.
        ex: raw review mapping; ``ex["timestamp"]`` (epoch milliseconds) is required,
            all other fields are optional and may be missing or ``None``.

    Returns:
        dict with ``id`` ("asin::user_id::timestamp"), ``text`` (title and body joined
        by a newline) and a flat ``metadata`` dict used later as the Qdrant payload.
    """
    ts = int(ex["timestamp"])
    # Timezone-aware UTC conversion; datetime.utcfromtimestamp is deprecated since Python 3.12.
    dt = datetime.fromtimestamp(ts / 1000, tz=timezone.utc)
    # `or` also covers fields that are present but explicitly None.
    hv = int(ex.get("helpful_vote") or 0)
    rating = float(ex.get("rating") or 0.0)
    return {
        "id": f"{ex.get('asin','')}::{ex.get('user_id','')}::{ts}",
        "text": f"{(ex.get('title') or '').strip()}\n{(ex.get('text') or '').strip()}",
        "metadata": {
            "category": cat,
            "asin": ex.get("asin"),
            "parent_asin": ex.get("parent_asin"),
            "rating": rating,
            "helpful_vote": hv,
            "verified_purchase": bool(ex.get("verified_purchase", False)),
            "timestamp_ms": ts,
            "date": dt.strftime("%Y-%m-%d"),
            "year": dt.year,
            "month": dt.month,
            # integer part of the rating: 1..5 for rated reviews, 0 when the rating is missing
            "rating_bucket": int(rating),
            # coarse helpfulness bands: <5 -> 0, 5..19 -> 5, >=20 -> 20
            "helpful_bucket": 0 if hv < 5 else 5 if hv < 20 else 20,
        }
    }

def write_shard(shard_idx, buffer):
    """Write *buffer* (list of dicts) as gzip-compressed JSON Lines shard number *shard_idx*.

    The data is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. An interrupted run therefore never leaves a truncated file under the
    final ``reviews_shard_*.jsonl.gz`` name, which the ingestion glob would pick up. It
    may still leave a stray ``.tmp`` file behind.

    Returns:
        The final shard path inside OUT_DIR.
    """
    path = os.path.join(OUT_DIR, f"reviews_shard_{shard_idx:04d}.jsonl.gz")
    tmp_path = path + ".tmp"
    with gzip.open(tmp_path, "wb") as f:
        f.writelines(orjson.dumps(d) + b"\n" for d in buffer)
    os.replace(tmp_path, path)
    return path

# Stream each category, keep reviews inside [START, END] with non-empty text, and spill
# the buffer to a new gzip shard every SHARD_SIZE documents.
total = 0      # documents kept across all categories
shard_idx = 0  # number of the next shard file
buf = []       # documents not yet written

for cat_no, cat in enumerate(CATEGORIES):
    if total >= TARGET_DOCS:
        break
    # Per-category quota; the first EXTRA categories take one extra document.
    quota = PER_CAT + int(cat_no < EXTRA)
    taken = 0

    stream = load_dataset("McAuley-Lab/Amazon-Reviews-2023", f"raw_review_{cat}",
                          split="full", streaming=True, trust_remote_code=True)
    for row in stream:
        if not START <= int(row["timestamp"]) <= END:
            continue
        doc = mk_doc(cat, row)
        if not doc["text"].strip():
            continue

        buf.append(doc)
        taken += 1
        total += 1

        if len(buf) >= SHARD_SIZE:
            path = write_shard(shard_idx, buf)
            buf.clear()
            shard_idx += 1
            print(f"Wrote {path} (total={total})")

        if taken >= quota or total >= TARGET_DOCS:
            break

# Write whatever is still buffered as a final (smaller) shard.
if buf:
    path = write_shard(shard_idx, buf)
    print(f"Wrote {path} (total={total})")
print(f"Done. Wrote ~{total} docs into {OUT_DIR}")


# Ingestion

Installing

In [None]:
# fresh ORT GPU + FastEmbed GPU (works with CUDA 12.x on Colab)
!pip -q uninstall -y onnxruntime onnxruntime-gpu || true
!pip -q install onnxruntime-gpu==1.18.1
!pip -q install fastembed-gpu==0.7.1 qdrant-client>=1.14.2 orjson==3.*

In [None]:
# 1) Install: FastEmbed comes with the client extra
!pip -q install "qdrant-client[fastembed-gpu]>=1.14.2" orjson==3.*

import os, gzip, glob, orjson
from qdrant_client import QdrantClient, models
from qdrant_client.http import models as http_models
from fastembed import TextEmbedding, LateInteractionTextEmbedding, SparseTextEmbedding
from typing import Iterable, List

import onnxruntime as ort
from fastembed import TextEmbedding, SparseTextEmbedding
from qdrant_client import QdrantClient, models
from qdrant_client.http import models as http
# your env from before
COLL = "amazon_reviews_2022"
OUT_DIR = "/content/amazon_2022"  # directory holding the reviews_shard_*.jsonl.gz files

QDRANT_URL = os.getenv("QDRANT_URL", "").strip()
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "").strip()
# Explicit check rather than `assert`, which is skipped when Python runs with -O.
if not (QDRANT_URL and QDRANT_API_KEY):
    raise RuntimeError("Set QDRANT_URL and QDRANT_API_KEY env vars.")


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.3/337.3 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.9/100.9 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.1/103.1 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m283.2/283.2 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m324.8/324.8 kB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h

Embeddings

In [None]:
# ---- Data ----
# Shard directory written by the dataset step (expects reviews_shard_*.jsonl.gz);
# can be overridden through the OUT_DIR environment variable.
OUT_DIR = os.getenv("OUT_DIR", "/content/amazon_2022").strip()
COLL = "hybrid-search"

# 2) Execution providers for FastEmbed (ONNX Runtime): CUDA when available, else CPU.
_onnx_available = ort.get_available_providers()
if "CUDAExecutionProvider" in _onnx_available:
    providers = ["CUDAExecutionProvider"]
else:
    providers = ["CPUExecutionProvider"]
print("ONNX providers:", _onnx_available)

dense = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2", providers=providers, threads=None)
sparse = SparseTextEmbedding("Qdrant/bm25")  # BM25 sparse model; created without GPU providers
# Embed a throwaway string once to learn the dense vector size the collection needs.
_probe_vector = next(dense.embed(["hi"]))
DENSE_DIM = len(_probe_vector)
print("Dense dim:", DENSE_DIM)

# 3) Qdrant Cloud client (gRPC for faster bulk upserts)
client = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
    prefer_grpc=True,
)

# Drop and re-create the collection. HNSW is disabled (m=0) during ingest and enabled
# afterwards. `recreate_collection` is deprecated in qdrant-client, so delete + create
# are done explicitly.
if client.collection_exists(COLL):
    client.delete_collection(COLL)
client.create_collection(
    collection_name=COLL,
    vectors_config={"dense": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE)},
    sparse_vectors_config={"bm25": models.SparseVectorParams()},
    hnsw_config=models.HnswConfigDiff(m=0),   # no HNSW graph while loading
    shard_number=4,                           # number of shards the collection is split into
    optimizers_config=models.OptimizersConfigDiff(
        default_segment_number=1,
        memmap_threshold=1_000_000_000,
        indexing_threshold=50_000,
    ),
)

# 4) Helpers


# Deterministic, hashable integer point IDs
def make_u64_id(s: str) -> int:
    """Map a document key to a deterministic unsigned 64-bit integer point ID.

    The key is hashed with BLAKE2b (8-byte digest) and read as a big-endian unsigned
    integer, so the same key always yields the same ID and re-upserts overwrite.
    ``None`` and ``""`` map to the same ID.
    """
    # Imported here so the helper works even when hashlib was not imported at module level.
    import hashlib

    digest = hashlib.blake2b((s or "").encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big", signed=False)

def flatten_payload(ex: dict) -> dict:
    """Build the Qdrant payload for one shard record.

    Holds the review text under ``document``, the record key (stringified) under
    ``doc_key``, and every metadata field at top level. Metadata is merged last, so it
    wins on any key collision.
    """
    payload = dict(document=ex.get("text") or "", doc_key=str(ex.get("id")))
    payload.update(ex.get("metadata") or {})
    return payload

def read_shard(fp: str) -> Iterable[dict]:
    """Lazily yield decoded JSON records from one gzip shard, skipping blank-text records."""
    with gzip.open(fp, "rb") as fh:
        records = (orjson.loads(raw) for raw in fh)
        yield from (rec for rec in records if (rec.get("text") or "").strip())

def chunk(lst: List, n: int):
    """Yield consecutive slices of *lst* holding at most *n* items each (last may be shorter)."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

# 5) Discover the shards written by the dataset step
files = sorted(glob.glob(f"{OUT_DIR}/reviews_shard_*.jsonl.gz"))
assert files, f"No shards found in {OUT_DIR}"
print("Shards:", len(files))

# 6) Probe: push the first 64 records end-to-end to validate URL, API key and schema.
probe = []
for rec in read_shard(files[0]):
    probe.append(rec)
    if len(probe) == 64:
        break

pd_texts = [rec["text"].strip() for rec in probe]
pd_payloads = list(map(flatten_payload, probe))
pd_ids = [make_u64_id(str(rec.get("id"))) for rec in probe]

pd_dense = list(dense.embed(pd_texts, batch_size=min(4096, len(pd_texts))))
pd_sparse = [sv.as_object() for sv in sparse.embed(pd_texts)]

points = [
    http.PointStruct(id=pid, vector={"dense": dvec, "bm25": svec}, payload=pl)
    for pid, dvec, svec, pl in zip(pd_ids, pd_dense, pd_sparse, pd_payloads)
]
client.upsert(COLL, points=points, wait=True)
print("Post-probe count (cloud):", client.count(COLL).count)

# 7) Cap input length for the DENSE model only; BM25 still sees the full text.
MAX_TOKENS = 256  # 256-320 is a reasonable range; smaller values allow larger safe batches

def truncate_whitespace(text: str, max_tokens: int = MAX_TOKENS) -> str:
    """Keep at most *max_tokens* whitespace-separated words, re-joined by single spaces.

    Whitespace is normalized even when nothing is cut. Word count is only a rough proxy
    for the model's tokenizer.
    """
    return " ".join(text.split()[:max_tokens])

# 8) Full ingest: embed in large batches and upsert asynchronously (HNSW still deferred).
import time  # not imported elsewhere in this cell; needed for the elapsed-time report
from concurrent.futures import ThreadPoolExecutor

EMBED_BATCH = 5120
UPSERT_BATCH = int(os.getenv("UPSERT_BATCH", "5000"))
# Number of points *sent* by this run. Seeding it with client.count() double-counted the
# probe points (and any earlier run): every shard is re-sent from the start, and the
# deterministic IDs simply overwrite points that already exist.
total = 0
t0 = time.time()

# Overlap network writes with embedding so the GPU stays busy.
up_pool = ThreadPoolExecutor(max_workers=4)
_pending = []


def async_upsert(pts):
    """Submit a non-blocking upsert of *pts*; any request error surfaces in drain()."""
    _pending.append(up_pool.submit(client.upsert, COLL, pts, wait=False))


def drain():
    """Block until every submitted upsert has returned, re-raising the first failure."""
    for f in _pending:
        f.result()
    _pending.clear()


def _embed_and_send(texts_full, texts_trunc, payloads, ids):
    """Embed one batch (dense on truncated text, BM25 on full text), queue upserts, return #points."""
    dv = list(dense.embed(texts_trunc, batch_size=EMBED_BATCH))
    sv = [s.as_object() for s in sparse.embed(texts_full)]
    sent = 0
    for sub in chunk(list(range(len(dv))), UPSERT_BATCH):
        pts = [
            http.PointStruct(
                id=ids[i],
                vector={"dense": dv[i], "bm25": sv[i]},
                payload=payloads[i],
            )
            for i in sub
        ]
        async_upsert(pts)
        sent += len(pts)
    return sent


try:
    for fp in files:
        texts_full, texts_trunc, payloads, ids = [], [], [], []

        for ex in read_shard(fp):
            full = ex["text"].strip()
            if not full:
                continue
            texts_full.append(full)
            texts_trunc.append(truncate_whitespace(full))
            payloads.append(flatten_payload(ex))
            ids.append(make_u64_id(str(ex.get("id"))))

            if len(texts_full) >= EMBED_BATCH:
                total += _embed_and_send(texts_full, texts_trunc, payloads, ids)
                # Safe to clear: the queued PointStructs already hold their own references.
                texts_full.clear(); texts_trunc.clear(); payloads.clear(); ids.clear()

        # Tail of the shard (fewer than EMBED_BATCH records).
        if texts_full:
            total += _embed_and_send(texts_full, texts_trunc, payloads, ids)

        # Ensure all upserts for this shard have returned before moving on.
        drain()
        print(f"[{os.path.basename(fp)}] cloud upserts so far: {total:,}")
finally:
    # Release the worker threads even if embedding or an upsert failed.
    up_pool.shutdown(wait=True)

print(f"Cloud total sent: {total:,} | elapsed: {time.time()-t0:.1f}s")
# Upserts were sent with wait=False, so this count may briefly lag behind `total`.
print("Cloud count:", client.count(COLL).count)

# 9) Enable the HNSW graph now that ingest is done (it was disabled with m=0 while loading).
import time  # not imported elsewhere in this cell

client.update_collection(COLL, hnsw_config=models.HnswConfigDiff(m=16))
# update_collection only changes the config; the graph is (re)built by the optimizer in the
# background. Poll until the collection reports GREEN before announcing it is ready.
_hnsw_deadline = time.time() + 60 * 60  # stop waiting after one hour
while client.get_collection(COLL).status != models.CollectionStatus.GREEN:
    if time.time() > _hnsw_deadline:
        print("HNSW build still in progress on Cloud; check the collection status later.")
        break
    time.sleep(10)
else:
    print("HNSW built on Cloud. Ready.")

# 10) Build payload (facet) indexes after ingest, when the data is already in place.
for name, schema in [
    ("category", models.PayloadSchemaType.KEYWORD),
    ("asin", models.PayloadSchemaType.KEYWORD),
    ("year", models.PayloadSchemaType.INTEGER),
    ("month", models.PayloadSchemaType.INTEGER),
    ("rating", models.PayloadSchemaType.FLOAT),
    ("rating_bucket", models.PayloadSchemaType.INTEGER),
    ("verified_purchase", models.PayloadSchemaType.BOOL),
    ("helpful_vote", models.PayloadSchemaType.INTEGER),
    ("helpful_bucket", models.PayloadSchemaType.INTEGER),
    ("timestamp_ms", models.PayloadSchemaType.INTEGER),
]:
    try:
        client.create_payload_index(COLL, field_name=name, field_schema=schema)
    except Exception as err:  # best effort: report the failure and continue with the other fields
        print(f"Payload index for {name!r} failed: {err}")

print("Facet indexes built.")

# 11) Hybrid search (RRF / DBSF)
from qdrant_client import models as qm

def hybrid_search(query: str, topk=10, prefetch_k=64, fusion="RRF", filters: dict | None = None):
    """Run a hybrid BM25 + dense query in Qdrant with server-side fusion.

    Args:
        query: free-text query, embedded with both the sparse (BM25) and dense models.
        topk: number of fused results to return.
        prefetch_k: candidates fetched from each of the two indexes before fusion.
        fusion: "RRF" (case-insensitive) for reciprocal rank fusion; anything else selects DBSF.
        filters: optional ``{field: value}`` exact matches or
            ``{field: {"gt"|"gte"|"lt"|"lte": number}}`` ranges, combined with AND.

    Returns:
        The ``client.query_points`` response.
    """
    q_dense = next(dense.embed([query]))
    q_sparse = next(sparse.embed([query])).as_object()

    prefetch = [
        qm.Prefetch(query=q_sparse, using="bm25", limit=prefetch_k),
        qm.Prefetch(query=q_dense, using="dense", limit=prefetch_k),
    ]
    filt = None
    if filters:
        must = []
        for k, v in filters.items():
            if isinstance(v, dict):
                # qm.Range fields are lowercase (gt/gte/lt/lte). The previous upper-casing
                # produced keys that are not valid Range fields.
                rng = {key: v[key] for key in ("gt", "gte", "lt", "lte") if key in v}
                if rng:
                    must.append(qm.FieldCondition(key=k, range=qm.Range(**rng)))
                    continue
            must.append(qm.FieldCondition(key=k, match=qm.MatchValue(value=v)))
        filt = qm.Filter(must=must)

    fusion_enum = qm.Fusion.RRF if fusion.upper() == "RRF" else qm.Fusion.DBSF
    return client.query_points(
        collection_name=COLL,
        prefetch=prefetch,
        query=qm.FusionQuery(fusion=fusion_enum),
        limit=topk,
        query_filter=filt,
        with_payload=True
    )



# Searching

In [None]:
!pip -q install "qdrant-client[fastembed-gpu]>=1.14.2" orjson==3.*

# 1) Imports / Config
import os, glob, gzip, orjson, hashlib, time
from typing import Iterable, List
import numpy as np

import onnxruntime as ort
from fastembed import TextEmbedding, SparseTextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient, models
from qdrant_client.http import models as http

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/337.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.3/337.3 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/100.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.9/100.9 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.1/103.1 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m283.2/283.2 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
QDRANT_URL = os.getenv("QDRANT_URL", "").strip()
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "").strip()
# Explicit check rather than `assert`, which is skipped when Python runs with -O.
if not (QDRANT_URL and QDRANT_API_KEY):
    raise RuntimeError("Set QDRANT_URL and QDRANT_API_KEY env vars.")

In [None]:
from qdrant_client import models as qm

def _to_filter(filters: dict | None) -> qm.Filter | None:
    """Translate a simple ``{field: condition}`` dict into a Qdrant Filter (AND of all fields).

    Per-field condition forms:
      * dict with any of ``gt``/``gte``/``lt``/``lte`` -> range condition;
      * list / tuple / set -> the field must equal one of the listed values;
      * anything else -> exact-match condition.

    Returns None when *filters* is empty or None.
    """
    if not filters:
        return None
    must = []
    for k, v in filters.items():
        if isinstance(v, dict):
            # keep only supported lowercase range keys
            rng_kwargs = {key: v[key] for key in ("gt", "gte", "lt", "lte") if key in v}
            if rng_kwargs:
                must.append(qm.FieldCondition(key=k, range=qm.Range(**rng_kwargs)))
                continue
        if isinstance(v, (list, tuple, set)):
            # membership: MatchValue only accepts a single scalar, MatchAny takes a list
            must.append(qm.FieldCondition(key=k, match=qm.MatchAny(any=list(v))))
            continue
        # equality / exact match
        must.append(qm.FieldCondition(key=k, match=qm.MatchValue(value=v)))
    return qm.Filter(must=must) if must else None

COLL = "hybrid-search"

# 2) Execution providers for FastEmbed (ONNX Runtime): CUDA when available, else CPU.
_onnx_available = ort.get_available_providers()
if "CUDAExecutionProvider" in _onnx_available:
    providers = ["CUDAExecutionProvider"]
else:
    providers = ["CPUExecutionProvider"]
print("ONNX providers:", _onnx_available)

# Query-time encoders; they must match the models used at ingest time.
dense = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2", providers=providers, threads=None)
sparse = SparseTextEmbedding("Qdrant/bm25")

# REST (non-gRPC) client for interactive querying.
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, prefer_grpc=False)


ONNX providers: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

tokenizer.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

model.onnx:   0%|          | 0.00/90.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/650 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

Fetching 18 files:   0%|          | 0/18 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

arabic.txt: 0.00B [00:00, ?B/s]

french.txt:   0%|          | 0.00/813 [00:00<?, ?B/s]

finnish.txt: 0.00B [00:00, ?B/s]

german.txt: 0.00B [00:00, ?B/s]

english.txt:   0%|          | 0.00/936 [00:00<?, ?B/s]

dutch.txt:   0%|          | 0.00/453 [00:00<?, ?B/s]

danish.txt:   0%|          | 0.00/424 [00:00<?, ?B/s]

italian.txt: 0.00B [00:00, ?B/s]

hungarian.txt: 0.00B [00:00, ?B/s]

greek.txt: 0.00B [00:00, ?B/s]

portuguese.txt: 0.00B [00:00, ?B/s]

romanian.txt: 0.00B [00:00, ?B/s]

russian.txt: 0.00B [00:00, ?B/s]

spanish.txt: 0.00B [00:00, ?B/s]

norwegian.txt:   0%|          | 0.00/851 [00:00<?, ?B/s]

swedish.txt:   0%|          | 0.00/559 [00:00<?, ?B/s]

turkish.txt:   0%|          | 0.00/260 [00:00<?, ?B/s]

In [None]:
def hybrid_search(query: str, topk=10, prefetch_k=64, fusion="RRF", filters: dict | None=None):
    """Fuse BM25 and dense candidates for *query* using Qdrant's server-side fusion.

    Each index contributes up to *prefetch_k* candidates. The fused list is cut to *topk*.
    *fusion* "RRF" (any case) picks reciprocal rank fusion; every other value picks DBSF.
    *filters* is translated by _to_filter and passed as the query's payload filter.
    """
    dense_vec = next(dense.embed([query]))
    sparse_vec = next(sparse.embed([query])).as_object()

    candidates = [
        qm.Prefetch(query=sparse_vec, using="bm25", limit=prefetch_k),
        qm.Prefetch(query=dense_vec, using="dense", limit=prefetch_k),
    ]
    if fusion.upper() == "RRF":
        mode = qm.Fusion.RRF
    else:
        mode = qm.Fusion.DBSF

    return client.query_points(
        collection_name=COLL,
        prefetch=candidates,
        query=qm.FusionQuery(fusion=mode),
        limit=topk,
        query_filter=_to_filter(filters),
        with_payload=True,
    )


Late Re-ranking with ColBERT

In [None]:
# ColBERT v2 late-interaction encoder (per-token embeddings), on the same ONNX providers.
colbert = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0", providers=providers)

def _normalize_rows(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # L2 normalize rows for cosine sims
    n = np.linalg.norm(x, axis=1, keepdims=True) + eps
    return x / n

def _colbert_score(q_tok: np.ndarray, d_tok: np.ndarray) -> float:
    """
    q_tok: [Q, d] query token embeddings
    d_tok: [D, d] doc token embeddings
    ColBERT scoring: sum_i max_j cosine(q_i, d_j)
    """
    qn = _normalize_rows(q_tok)
    dn = _normalize_rows(d_tok)
    sims = qn @ dn.T          # [Q, D]
    return float(sims.max(axis=1).sum())

def rerank_colbert(query_text: str, qdrant_result, topk: int = 10, rerank_k: int = 100):
    """Reorder fused hybrid-search hits by ColBERT MaxSim score; the collection is untouched.

    Args:
        query_text: the user query.
        qdrant_result: response from hybrid_search(...); its ``.points`` are reranked.
        topk: number of reranked points to return.
        rerank_k: how many of the leading fused candidates are scored with ColBERT.

    Returns:
        Up to *topk* points, best ColBERT score first. Documents that produce no token
        embeddings get a score of -1e9, so they end up last.
    """
    candidates = qdrant_result.points[:rerank_k]
    if not candidates:
        return []

    texts = [c.payload.get("document", "") or "" for c in candidates]
    query_grid = next(colbert.query_embed(query_text))  # [Q, d]
    doc_grids = list(colbert.embed(texts))               # one [D_i, d] array per document

    scores = np.asarray([
        _colbert_score(query_grid, grid) if grid.size else -1e9
        for grid in doc_grids
    ])
    best_first = np.argsort(-scores)
    return [candidates[i] for i in best_first[:topk]]


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/405 [00:00<?, ?B/s]

model.onnx:   0%|          | 0.00/436M [00:00<?, ?B/s]

In [None]:
q = "skirt is not fitting"

# Stage 1: fused BM25 + dense retrieval (DBSF), no payload filter.
# Each prefetch branch is capped at prefetch_k hits, so presumably at most
# 2 * prefetch_k fused results come back even though topk is larger.
res = hybrid_search(q, topk=1000, prefetch_k=128, fusion="DBSF", filters={})
# Stage 2: rerank with ColBERT and return scores
def rerank_colbert(query_text: str, qdrant_result, topk: int = 10, rerank_k: int = 100, return_scores: bool = False):
    """Rerank the first *rerank_k* fused hits with ColBERT MaxSim scores.

    Args:
        query_text: the user query.
        qdrant_result: response from hybrid_search(...); its ``.points`` are reranked.
        topk: number of reranked results to return.
        rerank_k: how many leading fused candidates are scored.
        return_scores: when True, return ``(point, score)`` pairs instead of bare points.

    Returns:
        Up to *topk* results, best first; an empty list when there are no candidates.
        Documents without token embeddings score -1e9. Ties keep their fused order
        because the sort is stable.
    """
    points = qdrant_result.points[:rerank_k]
    if not points:
        return []
    docs = [p.payload.get("document", "") or "" for p in points]

    def _unit_rows(x):
        # L2-normalize each row so dot products become cosine similarities.
        return x / ((x ** 2).sum(axis=1, keepdims=True) ** 0.5 + 1e-12)

    # Normalize the query grid once instead of once per document.
    q_unit = _unit_rows(next(colbert.query_embed(query_text)))  # [Q, d]

    scores = []
    for d_tok in colbert.embed(docs):                             # each [D_i, d]
        if d_tok.size:
            scores.append(float((q_unit @ _unit_rows(d_tok).T).max(axis=1).sum()))
        else:
            scores.append(-1e9)

    order = np.argsort(-np.asarray(scores), kind="stable")[:topk]
    if return_scores:
        return [(points[i], float(scores[i])) for i in order]
    return [points[i] for i in order]

final = rerank_colbert(q, res, topk=10, rerank_k=100, return_scores=True)


# Old Values: top-10 fused hybrid hits before ColBERT reranking
for hit in res.points[:10]:
    preview = hit.payload.get("document", "")[:120].replace("\n", " ")
    print(f"{hit.score:.3f}", hit.payload.get('category'), hit.payload.get("asin"), preview)



1.928 Amazon_Fashion B08BNKYGG1 The skirt is terrible I read the instructions and measured to get the correct fit. The leggings fit as expected, but the
1.377 Amazon_Fashion B09H5C24FB Nice skirt wish it would fit as expected. The skirt is cute. The material is a bit cheap but for the price its not bad. 
1.376 Amazon_Fashion B09GP5VPCL Size is wrong Hi!! I am XL  size women and i have odered XL size and the skirt fits like the penal skirt.So,if you need 
1.228 Amazon_Fashion B09MCN2TJM Fits good Skirt rides up a bit but fits as expected
1.207 Amazon_Fashion B09J8SHGTK Hard to fit cute skirt Cute skirt but difficult to fit.  Large was too large but description advised ordering up a size.
1.157 Amazon_Fashion B09Q82N7P7 Good fit, thin fabric I am quite disappointed in this skirt. The sizing was good, and it fit well. I also liked the feel
1.128 Amazon_Fashion B09K6HNC6L Perfect Fit This skirt fit my Granddaughter perfectly! She loves it!!!
1.123 Clothing_Shoes_and_Jewelry B09MYY6G7Z Pret

In [None]:
# ColBERT reranking
for hit, rerank_score in final:
    preview = hit.payload.get("document", "")[:120].replace("\n", " ")
    print(f"{rerank_score:.3f}", hit.payload.get('category'), hit.payload.get("asin"), preview)

23.806 Amazon_Fashion B08KDHYTX8 Fit is off I usually wear a small/medium, and this skirt did not fit right.
23.617 Amazon_Fashion B08YD1H17Z Designed for actual children, A girls skirt not for Women I wish I would have listened to the reviews about the fit. You
22.601 Amazon_Fashion B08BNF3XVX Sent back Skirt too long, did not fit correctly.  Sent back
22.591 Clothing_Shoes_and_Jewelry B08SBFDF3P 2x Is more like a XL This skirt is a beautifully made I was very disappointed that I ordered my normal size 2X and it di
22.384 Amazon_Fashion B09RG54SD1 Nice fitting skirt However the pleats don't lay smoothly.  Sent back.
21.995 Clothing_Shoes_and_Jewelry B09X3CYBL1 Beautiful Skirt, Flattering Fit, YES It Is Rayon, But Polyester Lining I don't know about you, but when I get a piece of
21.955 Clothing_Shoes_and_Jewelry B07QJ8G7BV Lovely color and fit but material too thin for a fitted skirt The tan and olive brown skirts are lovely colors, and the 
21.928 Clothing_Shoes_and_Jewelry B09NBNGXR

Filter checking

In [None]:
q = "skirt is not fitting"

# Same query restricted to a single category via an exact-match payload filter.
res = hybrid_search(q, topk=1000, prefetch_k=128, fusion="DBSF",
                    filters={'category': 'Amazon_Fashion'})

for hit in res.points[:10]:
    preview = hit.payload.get("document", "")[:120].replace("\n", " ")
    print(f"{hit.score:.3f}", hit.payload.get('category'), hit.payload.get("asin"), preview)

1.976 Amazon_Fashion B08BNKYGG1 The skirt is terrible I read the instructions and measured to get the correct fit. The leggings fit as expected, but the
1.564 Amazon_Fashion B09H5C24FB Nice skirt wish it would fit as expected. The skirt is cute. The material is a bit cheap but for the price its not bad. 
1.563 Amazon_Fashion B09GP5VPCL Size is wrong Hi!! I am XL  size women and i have odered XL size and the skirt fits like the penal skirt.So,if you need 
1.453 Amazon_Fashion B09MCN2TJM Fits good Skirt rides up a bit but fits as expected
1.437 Amazon_Fashion B09J8SHGTK Hard to fit cute skirt Cute skirt but difficult to fit.  Large was too large but description advised ordering up a size.
1.399 Amazon_Fashion B09Q82N7P7 Good fit, thin fabric I am quite disappointed in this skirt. The sizing was good, and it fit well. I also liked the feel
1.378 Amazon_Fashion B09K6HNC6L Perfect Fit This skirt fit my Granddaughter perfectly! She loves it!!!
1.349 Amazon_Fashion B08YD1H17Z Designed for act

In [None]:
# --- deps
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from textwrap import shorten

# --- helpers
def _pid(p):
    # robust id extraction across client versions
    return getattr(p, "id", None) or getattr(p, "point_id", None) or p.payload.get("asin") or ""

def _snippet(p, n=120):
    return shorten((p.payload.get("document","") or "").replace("\n"," "), width=n, placeholder="…")

def build_frames(res, final, baseline_k=100, topk=10):
    """Join hybrid-search ranks with ColBERT rerank ranks into one comparison DataFrame.

    Args:
        res: hybrid_search response; its first *baseline_k* points give the baseline ranks.
        final: ``(point, score)`` pairs from the ColBERT rerank, best first.
        baseline_k: how many baseline hits to include.
        topk: size of the two top-k slices returned.

    Returns:
        ``(df, base_top, rer_top)``: the outer-joined frame with ``delta_rank``
        (rerank_rank - base_rank; negative means the item moved up) and min-max
        normalized ``*_score_norm`` columns, plus the top-*topk* rows by each ranking.
    """
    base = pd.DataFrame([
        {
            "id": _pid(hit), "asin": hit.payload.get("asin"), "category": hit.payload.get("category"),
            "title": _snippet(hit, 140),
            "base_rank": rank, "base_score": float(hit.score),
        }
        for rank, hit in enumerate(res.points[:baseline_k], start=1)
    ])
    rer = pd.DataFrame([
        {"id": _pid(hit), "rerank_rank": rank, "rerank_score": float(score)}
        for rank, (hit, score) in enumerate(final, start=1)
    ])

    df = base.merge(rer, on="id", how="outer", validate="one_to_one")
    df["short_id"] = df["asin"].fillna(df["id"]).astype(str).str[-8:]
    df["delta_rank"] = df["rerank_rank"] - df["base_rank"]
    for col in ("base_score", "rerank_score"):
        if col not in df:
            continue
        lo, hi = df[col].min(), df[col].max()
        # Constant or non-finite columns cannot be min-max scaled.
        df[col + "_norm"] = (df[col] - lo) / (hi - lo) if np.isfinite(lo) and hi > lo else np.nan

    return df, df.nsmallest(topk, "base_rank"), df.nsmallest(topk, "rerank_rank")

def kpis(base_top, rer_top):
    """Compare the hybrid and reranked top-k views by ``id``.

    Returns:
        ``(overlap, avg_lift, inversions)``:
        overlap - number of ids present in both views;
        avg_lift - mean of base_rank - rerank_rank over shared ids (positive means the
        item moved up after reranking; NaN when nothing is shared);
        inversions - shared pairs whose relative order (row position within each frame)
        flips between the two views, a light Kendall-tau style count.
    """
    before_ids = set(base_top["id"].dropna())
    after_ids = set(rer_top["id"].dropna())
    overlap = len(before_ids & after_ids)

    shared = base_top[["id", "base_rank"]].merge(rer_top[["id", "rerank_rank"]], on="id", how="inner")
    avg_lift = (shared["base_rank"] - shared["rerank_rank"]).mean()

    pos_before = {}
    for pos, item in enumerate(base_top["id"], start=1):
        if item in after_ids:
            pos_before[item] = pos
    pos_after = {}
    for pos, item in enumerate(rer_top["id"], start=1):
        if item in before_ids:
            pos_after[item] = pos

    common = list(pos_before.keys() & pos_after.keys())
    inversions = sum(
        1
        for idx, a in enumerate(common)
        for b in common[idx + 1:]
        if (pos_before[a] - pos_before[b]) * (pos_after[a] - pos_after[b]) < 0
    )
    return overlap, avg_lift, inversions

def styled_table(rer_top):
    """Render the reranked top-k as a pandas Styler comparing ColBERT and hybrid ranks.

    Rank columns are cast to nullable integers so they display as ``1`` rather than
    ``1.0``. The outer merge that builds the frame introduces NaN, which turns integer
    columns into floats. The Δ column shows rank movement (▲ up / ▼ down / — unchanged)
    and is shaded green/red; score columns get inline bars.
    """
    cols = ["rerank_rank","base_rank","delta_rank","rerank_score","base_score","category","short_id","title"]
    show = rer_top.sort_values("rerank_rank")[cols].copy()
    for rank_col in ("rerank_rank", "base_rank"):
        show[rank_col] = show[rank_col].round().astype("Int64")

    def arrow(v):
        # delta_rank < 0 means the item moved up after reranking
        if pd.isna(v): return ""
        return f"▲{int(-v)}" if v < 0 else (f"▼{int(v)}" if v > 0 else "—")
    show["Δ"] = show["delta_rank"].map(arrow)
    show = show.rename(columns={
        "rerank_rank":"Rank (ColBERT)",
        "base_rank":"Rank (Hybrid)",
        "rerank_score":"Score (ColBERT)",
        "base_score":"Score (Hybrid)"
    })[["Rank (ColBERT)","Rank (Hybrid)","Δ","Score (ColBERT)","Score (Hybrid)","category","short_id","title"]]

    sty = (show.style
        .bar(subset=["Score (ColBERT)","Score (Hybrid)"], align="zero", vmax=None)
        .apply(lambda s: ["background-color:#eaffea" if v.startswith("▲") else
                          ("background-color:#ffecec" if v.startswith("▼") else "")
                          for v in s], subset=["Δ"])
        .hide(axis="index"))
    return sty

def slopegraph(base_top, rer_top, save_path="rank_slope.png"):
    """Plot how each reranked item moved from its hybrid rank to its ColBERT rank.

    Draws one line per row of *rer_top* (left: hybrid rank, right: ColBERT rank), labels
    both ends with ``short_id``, puts rank 1 at the top, and saves the figure to
    *save_path*. Y ticks cover every rank that actually appears; hybrid ranks can be far
    beyond 10, so a fixed 1..10 tick range left most points unlabeled.
    *base_top* is currently unused and kept for interface compatibility.
    """
    R = rer_top.sort_values("rerank_rank")[["id", "short_id", "base_rank", "rerank_rank"]].copy()
    plt.figure(figsize=(6, 5), dpi=160)
    for _, r in R.iterrows():
        x = [0, 1]; y = [r["base_rank"], r["rerank_rank"]]
        plt.plot(x, y, marker="o", linewidth=2)
        plt.text(-0.02, y[0], r["short_id"], ha="right", va="center", fontsize=9)
        plt.text( 1.02, y[1], r["short_id"], ha="left",  va="center", fontsize=9)
    plt.gca().invert_yaxis()
    plt.xticks([0,1], ["Hybrid (DBSF)", "ColBERT rerank"])
    ranks = pd.concat([R["base_rank"], R["rerank_rank"]]).dropna().astype(int)
    plt.yticks(sorted(set(ranks)))
    plt.title(f"Rank movement (Top-{len(R)})")
    plt.grid(axis="y", linestyle=":", alpha=0.4)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    print(f"Saved slopegraph → {save_path}")

# --- run it
df, base_top, rer_top = build_frames(res, final, baseline_k=100, topk=10)
ovl, avg_lift, inv = kpis(base_top, rer_top)
kpi_line = f"KPI → Top-10 overlap: {ovl}/10 | Avg. lift: {avg_lift:+.2f} ranks | Pairwise inversions: {inv}"
print(kpi_line)

# Last expression in the cell, so the notebook renders the Styler as a table
styled_table(rer_top)



KPI → Top-10 overlap: 2/10 | Avg. lift: +5.00 ranks | Pairwise inversions: 0


Rank (ColBERT),Rank (Hybrid),Δ,Score (ColBERT),Score (Hybrid),category,short_id,title
1.0,14,▲13,23.805765,1.018802,Amazon_Fashion,8KDHYTX8,"Fit is off I usually wear a small/medium, and this skirt did not fit right."
2.0,9,▲7,23.61669,1.089162,Amazon_Fashion,8YD1H17Z,"Designed for actual children, A girls skirt not for Women I wish I would have listened to the reviews about the fit. You would need to…"
3.0,25,▲22,22.601387,0.88086,Amazon_Fashion,8BNF3XVX,"Sent back Skirt too long, did not fit correctly. Sent back"
4.0,84,▲80,22.590673,0.547897,Clothing_Shoes_and_Jewelry,8SBFDF3P,2x Is more like a XL This skirt is a beautifully made I was very disappointed that I ordered my normal size 2X and it did not fit at all it…
5.0,44,▲39,22.383745,0.729657,Amazon_Fashion,9RG54SD1,Nice fitting skirt However the pleats don't lay smoothly. Sent back.
6.0,33,▲27,21.995396,0.83292,Clothing_Shoes_and_Jewelry,9X3CYBL1,"Beautiful Skirt, Flattering Fit, YES It Is Rayon, But Polyester Lining I don't know about you, but when I get a piece of clothing that…"
7.0,10,▲3,21.955246,1.067348,Clothing_Shoes_and_Jewelry,7QJ8G7BV,"Lovely color and fit but material too thin for a fitted skirt The tan and olive brown skirts are lovely colors, and the cut is amazing. The…"
8.0,56,▲48,21.927738,0.659416,Clothing_Shoes_and_Jewelry,9NBNGXRK,"Loves it My wife loves this skirt! It's flattering, and a great length. It's a mixture of a flowy skirt that's fitted, without having a…"
9.0,99,▲90,21.927223,0.51037,Amazon_Fashion,84YXXDS8,"Get your fitting right seller! Cheap quality, paper thin material for top, sequins don't shine at all and the skirt this horribly wrong in…"
10.0,24,▲14,21.745113,0.903128,Clothing_Shoes_and_Jewelry,9PBRX2JF,"Super Cute, Runs Big Following the sizing chart, I ordered a 2x for my wife. The skirt fit perfectly, maybe a little big, but looked great.…"


# Time to quantize it

In [None]:
import os, time, statistics
from qdrant_client import QdrantClient, models
from qdrant_client.http import models as http

# Admin client for collection management (clone / quantize / index).
# Uses QDRANT_URL, the same endpoint the bake-off clients use, rather than an empty
# URL. Uses a long transport timeout because these calls can run for minutes.
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, prefer_grpc=False, timeout=600)

admin = client

# Source (unquantized) collection and the names of its quantized clones.
SRC  = "hybrid-search"
SCAL = "hybrid-scalar"
PROD = "hybrid-product"
BIN2 = "hybrid-binary2"

In [None]:
# Read the vector length of the 'dense' named vector from the source collection.
src_info = client.get_collection(collection_name=SRC)
DENSE_DIM = src_info.config.params.vectors["dense"].size
print(f"Dense dim: {DENSE_DIM}")

In [None]:
def wait_green(name, timeout=None, poll_s=2):
    """Block until collection *name* reports status GREEN.

    Polls ``admin.get_collection(name)`` every *poll_s* seconds.

    Args:
        name: collection to watch.
        timeout: optional upper bound in seconds; ``None`` waits indefinitely.
        poll_s: delay between status checks, in seconds.

    Raises:
        RuntimeError: if the collection reports RED. That is an optimizer error
            state, so waiting longer would never reach GREEN.
        TimeoutError: if *timeout* is set and GREEN is not reached in time.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        status = getattr(admin.get_collection(name), "status", None)
        if status == http.CollectionStatus.GREEN:
            return
        if status == http.CollectionStatus.RED:
            raise RuntimeError(f"collection {name!r} is RED (optimizer error)")
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"collection {name!r} not GREEN after {timeout}s (status={status})")
        time.sleep(poll_s)

def throttle_bg(name):
    """Damp background CPU work on collection *name* while heavy setup runs.

    Caps the optimizer at one thread and raises ``indexing_threshold`` to the
    int32 maximum. It also sets HNSW ``m=0``. These settings stay in effect until
    a later ``update_collection`` call changes them again.
    """
    quiet_optimizer = models.OptimizersConfigDiff(
        max_optimization_threads=1,
        indexing_threshold=2_147_483_647,
    )
    no_graph = models.HnswConfigDiff(m=0)
    admin.update_collection(
        name,
        optimizers_config=quiet_optimizer,
        hnsw_config=no_graph,
        timeout=300,
    )

def clone_only(dst):
    """(Re)create collection *dst* as a server-side copy of SRC, then throttle it.

    Any existing *dst* is dropped first, and points are copied via ``init_from``.
    The vector schema is read from SRC instead of being hard-coded, so the clone
    scores exactly like the baseline. A different distance metric, or a sparse
    vector missing SRC's modifier (e.g. IDF for BM25), would shift rankings, and
    the bake-off would wrongly attribute that shift to quantization.
    """
    if admin.collection_exists(dst):
        admin.delete_collection(dst, timeout=600)

    src_params = admin.get_collection(SRC).config.params
    src_dense = src_params.vectors["dense"]
    admin.create_collection(
        collection_name=dst,
        init_from=models.InitFrom(collection=SRC),
        vectors_config={
            "dense": models.VectorParams(size=src_dense.size, distance=src_dense.distance),
        },
        # Copy SRC's sparse params verbatim (index settings + modifier). Fall back
        # to a plain bm25 entry only if SRC reports no sparse vectors.
        sparse_vectors_config=src_params.sparse_vectors or {"bm25": models.SparseVectorParams()},
        shard_number=2,
        timeout=60,
    )
    wait_green(dst)
    throttle_bg(dst)

def apply_scalar(dst, always_ram=False):
    """Enable INT8 scalar quantization (0.99 quantile) on *dst*, then wait for GREEN."""
    int8_cfg = models.ScalarQuantizationConfig(
        type=models.ScalarType.INT8,
        quantile=0.99,
        always_ram=always_ram,
    )
    admin.update_collection(
        dst,
        quantization_config=models.ScalarQuantization(scalar=int8_cfg),
        timeout=900,
    )
    wait_green(dst)

def apply_product(dst, always_ram=False):
    """Enable product quantization (x16 compression) on *dst*, then wait for GREEN."""
    pq_cfg = models.ProductQuantizationConfig(
        compression=models.CompressionRatio.X16,
        always_ram=always_ram,
    )
    admin.update_collection(
        dst,
        quantization_config=models.ProductQuantization(product=pq_cfg),
        timeout=60,
    )
    wait_green(dst)

def apply_binary2(dst, always_ram=False):
    """Enable 2-bit binary quantization on *dst*, then wait for GREEN."""
    two_bit_cfg = models.BinaryQuantizationConfig(
        encoding=models.BinaryQuantizationEncoding.TWO_BITS,
        always_ram=always_ram,
    )
    admin.update_collection(
        dst,
        quantization_config=models.BinaryQuantization(binary=two_bit_cfg),
        timeout=60,
    )
    wait_green(dst)

def build_hnsw(dst, m=16, indexing_threshold=20_000):
    """Re-enable HNSW indexing on *dst* and wait until the collection is GREEN.

    ``throttle_bg`` sets ``indexing_threshold`` to 2_147_483_647 KB. At that value
    no segment ever qualifies for vector indexing, so raising ``m`` alone leaves
    every segment on the plain (brute-force) index. This function therefore also
    restores the threshold.

    Args:
        dst: collection name.
        m: HNSW graph degree to set.
        indexing_threshold: KB of vectors per segment above which an HNSW index is
            built. 20_000 is Qdrant's documented default.
    """
    admin.update_collection(
        dst,
        hnsw_config=models.HnswConfigDiff(m=m),
        optimizers_config=models.OptimizersConfigDiff(indexing_threshold=indexing_threshold),
        timeout=600,
    )
    wait_green(dst)

# === Run these steps one at a time; running them all in one go was crashing the cluster ===
#clone_only(SCAL)
#apply_scalar(SCAL, always_ram=False)
#build_hnsw(SCAL, m=16)

#clone_only(PROD)
#apply_product(PROD, always_ram=False)
#build_hnsw(PROD, m=16)

#clone_only(BIN2)
#apply_binary2(BIN2, always_ram=False)
#build_hnsw(BIN2, m=16)


In [None]:
# ===== Quantization Bake-off (fixed): Dense-only recall + Hybrid demo =====
import time
from statistics import median, quantiles
from qdrant_client import QdrantClient
from qdrant_client import models as qm

# ---------- knobs ----------
# BASE is searched without quantization and serves as the ground truth.
# Each TESTS entry maps a label to a quantized clone of BASE.
BASE = "hybrid-search"
SCALAR = "hybrid-scalar"
PRODUCT = "hybrid-product"
BINARY2 = "hybrid-binary2"
TESTS = dict(scalar=SCALAR, product=PRODUCT, binary2=BINARY2)

K = 10              # result-list length compared per query (top-k)
NQ = 100            # how many query texts to take from BASE
PREFETCH_K = 64     # candidates fetched by each hybrid prefetch branch
OVERSAMPLE = 1.5    # oversampling factor for the rescored search mode
RESCORE = True      # whether the rescored mode asks Qdrant to rescore

# ---------- clients ----------
# REST client with a long timeout for admin calls; gRPC client for timed queries.
admin = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, timeout=600, prefer_grpc=False)
search = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, timeout=60, prefer_grpc=True)

# ---------- embedders (reuse if already defined) ----------
# Keep embedders created by earlier cells. Load defaults only if either is missing.
if "dense" not in globals() or "sparse" not in globals():
    from fastembed import TextEmbedding, SparseTextEmbedding
    # NOTE(review): these defaults must match the models BASE was indexed with — confirm.
    dense = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
    sparse = SparseTextEmbedding(model_name="Qdrant/bm25")

def _p95(values):
    return quantiles(values, n=20)[-1] if len(values) >= 20 else max(values)

def _q(coll_name, **kwargs):
    """Call ``search.query_points`` on *coll_name*, retrying on any exception.

    Makes up to 3 attempts, sleeping 0.2 s and then 0.4 s between them. The last
    failure is re-raised. *kwargs* are forwarded unchanged.
    """
    max_attempts = 3
    attempt = 0
    while True:
        try:
            return search.query_points(collection_name=coll_name, timeout=45, **kwargs)
        except Exception:
            attempt += 1
            if attempt == max_attempts:
                raise
            time.sleep(0.2 * attempt)

# ---------- 0) verify quantization really exists ----------
for coll in TESTS.values():
    # One RPC per collection; both checks read from the same config object.
    cfg = admin.get_collection(coll).config
    print(f"[check] {coll}: collection-level quantization_config = {getattr(cfg, 'quantization_config', None)}")
    # A per-vector quantization_config, when set, overrides the collection-level one
    # for that named vector. None means 'dense' inherits the collection-level setting.
    vcfg = cfg.params.vectors["dense"].quantization_config
    print(f"[check] {coll}: vectors['dense'].quantization_config = {vcfg}")

# ---------- 1) sample queries (texts) ----------
# NOTE: scroll without order_by walks points in id order, so this takes a
# deterministic prefix of BASE rather than a random sample.
# Only the 'document' payload field is needed, so request just that field.
pts, _ = admin.scroll(BASE, limit=max(NQ, 150), with_vectors=False, with_payload=["document"])
text_queries = [
    txt for txt in ((p.payload or {}).get("document", "") for p in pts) if txt
][:NQ]
assert text_queries, "No payload texts found to build queries"

# Build *fresh* dense query vectors from text (DO NOT reuse stored vectors).
# One batched embed call instead of one model invocation per query.
dense_queries = list(dense.embed(text_queries))

# ---------- 2) exact KNN@K (dense-only) on baseline (ground truth) ----------
def exact_knn_dense(qv, k=K):
    """Return ids of the exact (non-approximate) top-*k* dense matches for *qv* in BASE."""
    exact_params = qm.SearchParams(exact=True)
    response = _q(
        BASE,
        using="dense",
        query=qv,
        limit=k,
        search_params=exact_params,
        with_vectors=False,
        with_payload=False,
    )
    return [hit.id for hit in response.points]

# Ground truth: one id-set per dense query vector.
gt_dense = [set(exact_knn_dense(vec, K)) for vec in dense_queries]

# ---------- 3) ANN (dense-only): quant-only + prod-ish rescored ----------
def ann_dense_ids(coll, qv, k=K, oversampling=1.0, rescore=False):
    """Return top-*k* ids from an ANN dense search on *coll* with quantization enabled.

    *oversampling* and *rescore* are passed through to Qdrant's quantization
    search params. ``ignore=False`` keeps the quantized vectors in use.
    """
    quant_params = qm.QuantizationSearchParams(
        ignore=False, oversampling=oversampling, rescore=rescore
    )
    response = _q(
        coll,
        using="dense",
        query=qv,
        limit=k,
        search_params=qm.SearchParams(quantization=quant_params),
        with_vectors=False,
        with_payload=False,
    )
    return [hit.id for hit in response.points]

def bench_dense(coll, k=K):
    """Benchmark dense-only ANN on *coll* against the exact ground truth ``gt_dense``.

    Runs two passes over ``dense_queries``:
      * quant-only: oversampling 1.0, no rescore;
      * rescored: ``OVERSAMPLE`` / ``RESCORE``.
    Each pass reports recall (mean overlap / k), p50/p95 latency (ms), total
    wall time (ms) and QPS. The returned dict uses *_no / *_rs suffixes.
    """
    def _measure(oversampling, rescore):
        # Warm-up on the first 10 queries; results are discarded.
        for warm_qv in dense_queries[:10]:
            ann_dense_ids(coll, warm_qv, k, oversampling, rescore)

        latencies = []
        overlap = 0
        started = time.perf_counter()
        for qv, truth in zip(dense_queries, gt_dense):
            tick = time.perf_counter_ns()
            ids = ann_dense_ids(coll, qv, k, oversampling, rescore)
            latencies.append((time.perf_counter_ns() - tick) / 1e6)
            overlap += len(truth.intersection(ids))
        elapsed_ms = (time.perf_counter() - started) * 1000.0

        n = len(dense_queries)
        return {
            "recall@10": round(overlap / (n * k), 4),
            "p50_ms": round(median(latencies), 1),
            "p95_ms": round(_p95(latencies), 1),
            "total_ms": round(elapsed_ms, 1),
            "qps": round(n / (elapsed_ms / 1000.0), 2),
        }

    quant_only = _measure(oversampling=1.0, rescore=False)
    rescored = _measure(oversampling=OVERSAMPLE, rescore=RESCORE)

    return {
        "recall@10_no_rescore": quant_only["recall@10"],
        "p50_ms_no": quant_only["p50_ms"], "p95_ms_no": quant_only["p95_ms"],
        "total_ms_no": quant_only["total_ms"], "qps_no": quant_only["qps"],
        "recall@10_rescore": rescored["recall@10"],
        "p50_ms_rs": rescored["p50_ms"], "p95_ms_rs": rescored["p95_ms"],
        "total_ms_rs": rescored["total_ms"], "qps_rs": rescored["qps"],
    }

# ---------- 4) HYBRID (BM25+dense) agreement vs baseline-hybrid ----------
def baseline_hybrid_ids(text, k=K, prefetch_k=PREFETCH_K, fusion="RRF"):
    """Return top-*k* ids for a BM25 + dense hybrid query on the unquantized BASE.

    Both branches fetch *prefetch_k* candidates, and the results are fused with RRF
    (when *fusion* is "RRF", case-insensitive) or DBSF otherwise.
    """
    dense_vec = next(dense.embed([text]))
    sparse_vec = next(sparse.embed([text])).as_object()
    branches = [
        qm.Prefetch(query=sparse_vec, using="bm25", limit=prefetch_k),
        qm.Prefetch(query=dense_vec, using="dense", limit=prefetch_k),
    ]
    fusion_kind = qm.Fusion.RRF if fusion.upper() == "RRF" else qm.Fusion.DBSF
    response = _q(
        BASE,
        prefetch=branches,
        query=qm.FusionQuery(fusion=fusion_kind),
        limit=k,
        with_vectors=False,
        with_payload=False,
    )
    return [hit.id for hit in response.points]

# Reference hybrid result sets, one per query text.
baseline_hybrid_topk = [set(baseline_hybrid_ids(txt, K)) for txt in text_queries]

def hybrid_ids(coll, text, k=K, prefetch_k=PREFETCH_K, fusion="RRF",
               oversampling=OVERSAMPLE, rescore=RESCORE):
    """Return top-*k* ids for a BM25 + dense hybrid query on quantized collection *coll*.

    The quantization params (*oversampling*, *rescore*) are attached to the dense
    Prefetch. In a fusion query, Qdrant runs each prefetch with that prefetch's own
    ``params``, and the top-level search params are not applied to the prefetch
    searches. Without this, the dense branch would ignore the settings under test.

    Fusion is RRF when *fusion* is "RRF" (case-insensitive), else DBSF.
    Note: the query is embedded locally on every call, so callers that time this
    function also measure embedding time.
    """
    qd = next(dense.embed([text]))
    qs = next(sparse.embed([text])).as_object()
    quant = qm.SearchParams(
        quantization=qm.QuantizationSearchParams(ignore=False, oversampling=oversampling, rescore=rescore)
    )
    prefetch = [
        qm.Prefetch(query=qs, using="bm25", limit=prefetch_k),
        qm.Prefetch(query=qd, using="dense", limit=prefetch_k, params=quant),
    ]
    fusion_kind = qm.Fusion.RRF if fusion.upper() == "RRF" else qm.Fusion.DBSF
    r = _q(coll, prefetch=prefetch,
           query=qm.FusionQuery(fusion=fusion_kind),
           limit=k, with_vectors=False, with_payload=False)
    return [p.id for p in r.points]

def bench_hybrid(coll, k=K):
    """Time hybrid queries on *coll* and measure agreement with ``baseline_hybrid_topk``.

    Returns agreement (mean overlap with the baseline hybrid top-k, divided by k),
    p50/p95 latency (ms), total wall time (ms) and QPS, all with the hyb_ prefix.
    """
    # Warm-up on the first 10 texts; results are discarded.
    for warm_text in text_queries[:10]:
        hybrid_ids(coll, warm_text, k)

    latencies = []
    shared = 0
    started = time.perf_counter()
    for text, reference in zip(text_queries, baseline_hybrid_topk):
        tick = time.perf_counter_ns()
        ids = hybrid_ids(coll, text, k)
        latencies.append((time.perf_counter_ns() - tick) / 1e6)
        shared += len(reference.intersection(ids))
    elapsed_ms = (time.perf_counter() - started) * 1000.0

    n = len(text_queries)
    return {
        "agree@10_vs_baselineHybrid": round(shared / (n * k), 4),
        "hyb_p50_ms": round(median(latencies), 1),
        "hyb_p95_ms": round(_p95(latencies), 1),
        "hyb_total_ms": round(elapsed_ms, 1),
        "hyb_qps": round(n / (elapsed_ms / 1000.0), 2),
    }

# ---------- 5) run all & print (safe + pretty) ----------
# One row per quantized collection: label, then dense stats, then hybrid stats.
rows = []
for name, coll in TESTS.items():
    dense_stats = bench_dense(coll)
    hybrid_stats = bench_hybrid(coll)
    row = {"collection": name}
    row.update(dense_stats)
    row.update(hybrid_stats)
    rows.append(row)

# sanity: ratio metrics are overlaps divided by k, so they must lie in [0, 1]
for row in rows:
    for metric in ("recall@10_no_rescore", "recall@10_rescore", "agree@10_vs_baselineHybrid"):
        value = row[metric]
        assert 0.0 <= value <= 1.0, f"{metric} out of range: {value}"

# pretty print
# Column order shared by the pandas table and the plain-text fallback.
header = [
    "collection",
    "recall@10_no_rescore","p50_ms_no","p95_ms_no","total_ms_no","qps_no",
    "recall@10_rescore","p50_ms_rs","p95_ms_rs","total_ms_rs","qps_rs",
    "agree@10_vs_baselineHybrid","hyb_p50_ms","hyb_p95_ms","hyb_total_ms","hyb_qps",
]
try:
    import pandas as pd
except ImportError:
    # fallback tab-print, used only when pandas is not installed. Other errors
    # propagate instead of being silently hidden behind the fallback.
    print("\t".join(header))
    for r in rows:
        print("\t".join(str(r[h]) for h in header))
else:
    df = pd.DataFrame(rows, columns=header)
    # Show ratio metrics as percentages. Latency/QPS values are already rounded
    # by bench_dense / bench_hybrid, so they need no further rounding.
    for col in ["recall@10_no_rescore", "recall@10_rescore", "agree@10_vs_baselineHybrid"]:
        df[col] = (df[col] * 100).round(2)
    print(df.to_string(index=False))

# ---------- 6) OPTIONAL: snapshot sizes (after timing) ----------
# NOTE: Qdrant keeps the original vectors next to the quantized ones (rescoring
# needs them), so this measures each collection's total footprint, not the size
# of the quantized codes alone.
try:
    sizes = {}
    for coll in TESTS.values():
        # QdrantClient.create_snapshot() takes no `timeout` kwarg and asserts that no
        # unknown kwargs are passed. The admin client's transport timeout covers slow
        # snapshot creation instead.
        snap = admin.create_snapshot(coll, wait=True)
        # Read the size from the returned description. This avoids sorting
        # list_snapshots() by an optional creation_time.
        sizes[coll] = round(snap.size / (1024 * 1024), 1) if snap is not None else float("nan")
        # The snapshot exists only to be measured; delete it so repeated runs don't
        # pile full collection copies onto the cluster's disk.
        if snap is not None:
            admin.delete_snapshot(coll, snap.name, wait=True)
    print("\nSnapshot size (MB):", {k: sizes[v] for k, v in TESTS.items()})
except Exception as e:
    # Best-effort diagnostic: report and continue rather than failing the notebook.
    print("Snapshot size check skipped:", e)


[check] hybrid-scalar: collection-level quantization_config = scalar=ScalarQuantizationConfig(type=<ScalarType.INT8: 'int8'>, quantile=0.99, always_ram=False)
[check] hybrid-scalar: vectors['dense'].quantization_config = None
[check] hybrid-product: collection-level quantization_config = product=ProductQuantizationConfig(compression=<CompressionRatio.X16: 'x16'>, always_ram=False)
[check] hybrid-product: vectors['dense'].quantization_config = None
[check] hybrid-binary2: collection-level quantization_config = binary=BinaryQuantizationConfig(always_ram=False, encoding=<BinaryQuantizationEncoding.TWO_BITS: 'two_bits'>, query_encoding=None)
[check] hybrid-binary2: vectors['dense'].quantization_config = None
collection  recall@10_no_rescore  p50_ms_no  p95_ms_no  total_ms_no  qps_no  recall@10_rescore  p50_ms_rs  p95_ms_rs  total_ms_rs  qps_rs  agree@10_vs_baselineHybrid  hyb_p50_ms  hyb_p95_ms  hyb_total_ms  hyb_qps
    scalar                 100.0      231.2      279.0      23557.4    4.