In [None]:
import gzip
import json
import os
import pprint
import re
import subprocess
import time
from datetime import timedelta, datetime
from pathlib import Path
from dataclasses import dataclass, asdict
import multiprocessing

import numpy as np
import pandas as pd
import requests

In [None]:
@dataclass
class DataSetConfig:
    """Settings for one benchmark run against a Qdrant collection.

    Instances are built by ``get_dataset_config`` from a named preset. The
    ``qdrant_*`` fields (with defaults) describe the docker container the
    notebook starts and sends HTTP requests to.
    """

    # Directory holding the *.parquet passage files (globbed by insert_data).
    content_path: Path
    # Directory holding "<offset>.npz" embedding chunks with an "embs" array.
    embedding_path: Path
    # Total number of documents; search_with_knn_queries wraps its chunk
    # offset back to 0 once it passes this value.
    num_of_docs: int
    # Number of documents inserted into the collection.
    index_size: int
    # Points per upsert request during the initial load.
    bulk_size: int
    # Qdrant collection name.
    index_name: str
    # Qdrant distance name ("Dot", "Cosine", ...); "Dot" additionally makes
    # the notebook L2-normalise vectors before sending them.
    distance: str
    # Vector dimensionality.
    dimension: int
    # Passed as the search "exact" parameter (brute force instead of HNSW).
    exact: bool
    # HNSW graph parameters used when creating the collection.
    hnsw_m: int
    hnsw_ef_construction: int
    # HNSW "ef" used at query time.
    hnsw_ef: int
    # Batch size of the throttled background updater (one batch per ~1 s);
    # 0 disables background updates.
    update_docs_per_sec: int
    # "int8" enables scalar quantization in create_index; other values do not.
    quantization: str

    qdrant_name: str = "benchmark_qdrant"  # docker container name
    qdrant_host: str = "localhost"
    qdrant_port: int = 6344  # host port mapped to the container's port 6333
    qdrant_version: str = "1.13.4"  # docker image tag, without the leading "v"


def get_dataset_config(target_name):
    """Build the DataSetConfig for a named benchmark preset.

    All presets share the same dataset and differ only in ``index_size`` and
    ``hnsw_m``. When the ``SETTING_QUANTIZATION`` environment variable is set,
    it overrides the preset's ``quantization`` value.

    Args:
        target_name: Preset key, e.g. "100k-768-m32-efc200-ef100-ip".

    Returns:
        A DataSetConfig populated from the preset.

    Raises:
        ValueError: If ``target_name`` is not a known preset. Previously this
            surfaced as an opaque ``TypeError`` from indexing ``None``.
    """
    dataset_dir = Path("dataset/passages-c400-jawiki-20230403")
    # Settings common to every preset; each preset overrides only what differs.
    base = {
        "content_path": dataset_dir,
        "embedding_path": dataset_dir / "multilingual-e5-base-passage",
        "num_of_docs": 5555583,
        "bulk_size": 1000,
        "index_name": "contents",
        "distance": "Dot",  # "Cosine"
        "dimension": 768,
        "exact": False,
        "hnsw_ef_construction": 200,
        "hnsw_ef": 100,
        "update_docs_per_sec": 0,
        "quantization": "none",
    }
    presets = {
        "100k-768-m32-efc200-ef100-ip": {"index_size": 100000, "hnsw_m": 32},
        "1m-768-m48-efc200-ef100-ip": {"index_size": 1000000, "hnsw_m": 48},
        "5m-768-m48-efc200-ef100-ip": {"index_size": 5000000, "hnsw_m": 48},
    }
    overrides = presets.get(target_name)
    if overrides is None:
        raise ValueError(
            f"Unknown target config {target_name!r}; expected one of {sorted(presets)}"
        )
    target_setting = {**base, **overrides}
    target_setting["quantization"] = os.getenv("SETTING_QUANTIZATION", target_setting["quantization"])
    return DataSetConfig(**target_setting)


In [None]:
def run_qdrant(config):
    """Start the Qdrant docker container described by ``config`` in detached mode.

    Maps host port ``config.qdrant_port`` to the container's port 6333 and runs
    image ``qdrant/qdrant:v<qdrant_version>``. Prints [OK] on success. On
    failure it prints [FAIL] followed by docker's stdout and stderr. Returns None.
    """
    # Only referenced by the (disabled) storage volume mount below.
    volume_dir = os.getenv("VOLUME_DIR", "./data")
    print(f"Starting {config.qdrant_name}... ", end="")
    port_mapping = f"{config.qdrant_port}:6333"
    image = f"qdrant/qdrant:v{config.qdrant_version}"
    # Prefix the command with "sudo" if the docker daemon requires it.
    cmd = ["docker", "run", "-d", "--name", config.qdrant_name, "-p", port_mapping]
    # cmd += ["-v", f"{volume_dir}:/qdrant/storage"]  # persist storage (disabled)
    cmd.append(image)
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode == 0:
        print("[OK]")
        return
    print("[FAIL]")
    for label, stream in (("STDOUT:", proc.stdout), ("STDERR:", proc.stderr)):
        print(label)
        print(stream)


In [None]:
def stop_qdrant(config):
    """Run ``docker stop`` on the container named ``config.qdrant_name``.

    The container is stopped, not removed. Prints [OK] on success; otherwise
    prints [FAIL] followed by docker's stdout and stderr. Returns None.
    """
    print(f"Stopping {config.qdrant_name}... ", end="")
    # Prefix the command with "sudo" if the docker daemon requires it.
    proc = subprocess.run(["docker", "stop", config.qdrant_name], capture_output=True, text=True)
    if proc.returncode != 0:
        print("[FAIL]")
        print("STDOUT:")
        print(proc.stdout)
        print("STDERR:")
        print(proc.stderr)
    else:
        print("[OK]")


In [None]:
def prune_docker(config):
    """Run ``docker system prune -f`` and report [OK]/[FAIL].

    ``config`` is unused; it keeps the same signature as the other docker
    helpers. On failure docker's stdout and stderr are echoed. Returns None.
    """
    print("Cleaning up... ", end="")
    # Prefix the command with "sudo" if the docker daemon requires it.
    proc = subprocess.run(["docker", "system", "prune", "-f"], capture_output=True, text=True)
    if proc.returncode == 0:
        print("[OK]")
        return
    print("[FAIL]")
    for label, stream in (("STDOUT:", proc.stdout), ("STDERR:", proc.stderr)):
        print(label)
        print(stream)


In [None]:
def print_docker_system_df(config):
    """Print ``docker system df`` output, or docker's stderr if the command fails.

    ``config`` is unused; it keeps the same signature as the other docker helpers.
    """
    # Prefix the command with "sudo" if the docker daemon requires it.
    proc = subprocess.run(["docker", "system", "df"], capture_output=True, text=True)
    print(proc.stdout if proc.returncode == 0 else proc.stderr)


In [None]:
def print_docker_container_stats(config):
    """Print ``docker container stats --no-stream`` and parse it per container.

    Parsing relies on docker's default table layout, where whitespace-splitting
    a data row gives 14 tokens:
    ``ID NAME CPU% MEM_USAGE / MEM_LIMIT MEM% NET_IN / NET_OUT BLOCK_IN / BLOCK_OUT PIDS``.

    Args:
        config: Unused; kept for a uniform helper signature.

    Returns:
        Dict keyed by container name; each value maps column names to the raw
        string tokens. The dict is empty if the docker command fails. Blank
        rows and rows with fewer than 14 tokens are skipped rather than
        raising IndexError.
    """
    docker_cmd = [
        # "sudo",
        "docker", "container", "stats", "--no-stream"
    ]
    result = subprocess.run(docker_cmd, capture_output=True, text=True)
    containers = {}
    if result.returncode != 0:
        print(result.stderr)
        return containers

    print(result.stdout)
    for line in result.stdout.splitlines():
        if not line.strip() or line.startswith("CONTAINER"):
            continue
        values = line.split()
        if len(values) < 14:
            # Unexpected layout (e.g. a custom --format); skip instead of crashing.
            continue
        containers[values[1]] = {
            "container_id": values[0],
            "cpu": values[2],
            "mem": values[6],
            "mem_usage": values[3],
            "mem_limit": values[5],
            "net_in": values[7],
            "net_out": values[9],
            "block_in": values[10],
            "block_out": values[12],
            "pids": values[13],
        }
    return containers


In [None]:
def _create_payload_index(config, field_name, field_schema, label):
    """Create one payload index on ``field_name`` and print [OK]/[FAIL].

    ``label`` is only used in the progress message.
    """
    print(f"Creating Payload {label}:{config.index_name}/{field_name}... ", end="")
    response = requests.put(
        f"http://{config.qdrant_host}:{config.qdrant_port}/collections/{config.index_name}/index",
        headers={"Content-Type": "application/json"},
        json={
            "field_name": field_name,
            "field_schema": field_schema,
        },
    )
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")


def create_index(config):
    """Create the benchmark collection and its payload indexes.

    The collection is created with the configured vector size, distance and
    HNSW parameters. Scalar int8 quantization is added when
    ``config.quantization == "int8"``. Then the following payload indexes are
    created: integer on page_id/rev_id, keyword on section, and full-text on
    title/text. Each request prints [OK] or [FAIL] plus the response body.
    Failures are reported but do not stop later requests.
    """
    # Fixed: end="" (the original passed the literal text "config.quantization").
    print(f"Creating Collection {config.index_name} with {config.quantization}... ", end="")
    hnsw_config = {
        "m": config.hnsw_m,
        "ef_construct": config.hnsw_ef_construction,
    }
    schema = {
        "vectors": {
            "size": config.dimension,
            "distance": config.distance,
            "hnsw_config": dict(hnsw_config),
        },
        "hnsw_config": dict(hnsw_config),
    }
    if config.quantization == "int8":
        schema["quantization_config"] = {
            "scalar": {
                "type": "int8",
                "quantile": 0.99,
                "always_ram": True
            }
        }
    response = requests.put(f"http://{config.qdrant_host}:{config.qdrant_port}/collections/{config.index_name}",
                            headers={"Content-Type": "application/json"},
                            json=schema)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    for field_name in ["page_id", "rev_id"]:
        _create_payload_index(config, field_name, "integer", "integer")

    for field_name in ["section"]:
        _create_payload_index(config, field_name, "keyword", "keyword")

    # NOTE(review): min_token_len == max_token_len == 2 keeps only 2-character
    # tokens after the length filter; confirm this is intended. Also note that
    # title/text are currently commented out where point payloads are built.
    text_schema = {
        "type": "text",
        "tokenizer": "word",
        "min_token_len": 2,
        "max_token_len": 2,
        "lowercase": True
    }
    for field_name in ["title", "text"]:
        _create_payload_index(config, field_name, text_schema, "text")


In [None]:
def delete_index(config):
    """Delete the benchmark collection.

    Prints [OK] on HTTP 200; otherwise prints [FAIL] and the response body.
    """
    collection_url = f"http://{config.qdrant_host}:{config.qdrant_port}/collections/{config.index_name}"
    print(f"Deleting Collection {config.index_name}... ", end="")
    response = requests.delete(collection_url)
    succeeded = response.status_code == 200
    print("[OK]" if succeeded else f"[FAIL]\n{response.text}")


In [None]:
def print_indices(config):
    """Pretty-print the collection info returned by Qdrant and report its point count.

    Returns:
        ``{"num_of_docs": points_count}``. The value is None when the reply is
        not JSON or has no "result" (e.g. the collection does not exist);
        previously those cases raised an exception.
    """
    response = requests.get(f"http://{config.qdrant_host}:{config.qdrant_port}/collections/{config.index_name}")
    try:
        obj = response.json()
    except ValueError:
        print(f"[FAIL][{response.status_code}] {response.text}")
        return {"num_of_docs": None}
    pprint.pprint(obj)
    collection_info = obj.get("result") or {}
    return {
        "num_of_docs": collection_info.get("points_count"),
    }


In [None]:
def wait_for_qdrant(config, retry_count=60):
    """Poll Qdrant's ``/cluster`` endpoint until it answers with HTTP 200.

    Makes up to ``retry_count`` attempts, one second apart, and prints a dot per
    failed attempt.

    Returns:
        True once the server responds with 200, False after all retries fail.
    """
    print(f"Waiting for {config.qdrant_name}", end="")
    url = f"http://{config.qdrant_host}:{config.qdrant_port}/cluster"
    for _ in range(retry_count):
        try:
            # timeout keeps a half-open connection from hanging the wait forever.
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                print("[OK]")
                return True
        except requests.exceptions.RequestException:
            # Server not accepting connections yet; retry. A bare except here
            # would also swallow KeyboardInterrupt.
            pass
        print(".", end="")
        time.sleep(1)
    print("[FAIL]")
    return False


In [None]:
def get_embedding(config, embedding_index, embedding_data, id, chunk_size=100000):
    """Return the stored embedding for document ``id``, reusing the loaded chunk if possible.

    Embeddings live in ``config.embedding_path / "<start>.npz"`` files, each
    holding an "embs" array. ``<start>`` is a multiple of ``chunk_size``, and
    row ``id - start`` of that array is the embedding for ``id``.

    Args:
        config: Provides ``embedding_path`` and ``distance``.
        embedding_index: Start offset of the currently cached chunk (-1 if none).
        embedding_data: The cached chunk array, or None.
        id: Document id (non-negative integer).
        chunk_size: Rows per .npz chunk file.

    Returns:
        ``(chunk_start, chunk_data, embedding)``. Pass the first two back in on
        the next call so consecutive ids in one chunk reuse the loaded array.
        With ``distance == "Dot"`` the embedding is L2-normalised; a zero
        vector is returned unchanged instead of becoming NaN.
    """
    doc_id = int(id)
    # Integer floor division; the original float division + int() round-trip
    # was equivalent for small ids but not exact for very large ones.
    emb_index = (doc_id // chunk_size) * chunk_size
    if embedding_data is None or embedding_index != emb_index:
        with np.load(config.embedding_path / f"{emb_index}.npz") as data:
            embedding_data = data["embs"]
    embedding = embedding_data[doc_id - emb_index]
    if config.distance == "Dot":
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
    return emb_index, embedding_data, embedding


def get_section_values(config, df, min_count=10000):
    """Return the section names that appear in at least ``min_count`` rows of ``df``.

    Counts non-null ``id`` values per ``section`` (ordered by section, as
    ``groupby`` sorts). ``config`` is unused.
    """
    per_section = df.groupby("section")["id"].count()
    frequent = per_section[per_section >= min_count]
    return frequent.index.tolist()


def insert_data(config, max_size, bulk_size, controller=None, query_data=None, required_green_checks=30):
    """Upsert up to ``max_size`` documents into the collection in batches of ``bulk_size``.

    Rows come from ``config.content_path/*.parquet`` in sorted file order. Each
    point gets a sequential id (1, 2, ...), its stored embedding as the vector,
    and page_id/rev_id/section as payload. After the last batch, the function
    polls the collection until its status has been "green" for
    ``required_green_checks`` consecutive checks, one second apart.

    Args:
        config: DataSetConfig-like object.
        max_size: Maximum number of documents to send.
        bulk_size: Points per upsert request.
        controller: Optional object whose ``run()`` is called after every batch;
            a False return aborts the load (used by the throttled background updater).
        query_data: Optional dict whose "section_values" list is extended with
            the frequent sections of each parquet file read.
        required_green_checks: Consecutive "green" status checks to wait for.

    Returns:
        ``{"execution_time": wall-clock seconds including the green wait,
        "process_time": sum of the Qdrant-reported upsert times}``, or None when
        the controller stopped the load early.
    """
    start_time = time.time()
    collection_url = f"http://{config.qdrant_host}:{config.qdrant_port}/collections/{config.index_name}"

    ids = []
    vectors = []
    payloads = []

    def send_data(pos):
        # Upsert the current batch and return Qdrant's reported "time". Returns
        # 0.0 when the reply has no usable time (e.g. a non-JSON error page), so
        # the running total never hits None + float.
        print(f"Sending {len(ids)} docs ({pos}/{max_size})... ", end="")
        response = requests.put(f"{collection_url}/points",
                                headers={"Content-Type": "application/json"},
                                params={
                                    "wait": "true",
                                },
                                data=json.dumps({
                                    "batch": {
                                        "ids": ids,
                                        "vectors": vectors,
                                        "payloads": payloads,
                                    }
                                }))
        try:
            t = response.json().get("time")
        except (ValueError, AttributeError):
            t = None
        if response.status_code == 200:
            print(f"[OK] {t}")
        else:
            print(f"[FAIL] {t} {response.status_code} {response.text}")
        return t or 0.0

    total_time = 0
    count = 0
    embedding_index = -1
    embedding_data = None
    for content_file in sorted(config.content_path.glob("*.parquet")):
        if count >= max_size:
            break
        df = pd.read_parquet(content_file)
        if query_data is not None:
            query_data["section_values"].extend(get_section_values(config, df))
        for i, row in df.iterrows():
            if count >= max_size:
                break
            embedding_index, embedding_data, embedding = get_embedding(config, embedding_index, embedding_data, row.id)
            count += 1
            ids.append(count)
            vectors.append(embedding.tolist())
            payloads.append({
                "page_id": row.pageid,
                "rev_id": row.revid,
                # "title": row.title,
                "section": row.section,
                # "text": row.text,
            })
            if len(ids) >= bulk_size:
                total_time += send_data(count)
                ids = []
                vectors = []
                payloads = []
                if controller is not None and not controller.run():
                    return None

    if len(ids) > 0:
        total_time += send_data(count)

    # Wait until optimisation/indexing has settled: "green" for N checks in a row.
    green_count = 0
    while green_count < required_green_checks:
        response = requests.get(collection_url)
        try:
            obj = response.json()
            status = (obj.get("result") or {}).get("status")
        except (ValueError, AttributeError):
            status = None  # unparseable reply; treat as "not green yet"
        if status == "green":
            green_count += 1
        else:
            green_count = 0  # reset
        print(".", end="")
        time.sleep(1)
    print(".")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:02.2f} ({timedelta(seconds=total_time)})")
    return {
        "execution_time": execution_time,
        "process_time": total_time,
    }


In [None]:
class IndexingController:
    """Pace batch submissions to at most one per second until a stop event is set."""

    def __init__(self, stop_event):
        # Timestamp of the last permitted batch; starts at construction time.
        self._previous_time = time.time()
        self._stop_event = stop_event

    def run(self):
        """Return False if stopping; else sleep out the rest of the 1 s slot and return True."""
        if self._stop_event.is_set():
            return False

        elapsed = time.time() - self._previous_time
        remaining = 1 - elapsed
        if remaining > 0:
            time.sleep(remaining)
        self._previous_time = time.time()
        return True


def update_data(target_config, stop_event):
    """Background worker: re-upsert the preset's documents at a throttled rate.

    Uses ``update_docs_per_sec`` as the batch size, with an IndexingController
    limiting batches to roughly one per second until ``stop_event`` is set.
    Does nothing beyond printing a message when the rate is not positive.
    """
    dataset_config = get_dataset_config(target_config)
    docs_per_sec = dataset_config.update_docs_per_sec
    if not docs_per_sec > 0:
        print("No background updates")
        return
    print(f"Starting update for {target_config}")
    insert_data(dataset_config,
                max_size=dataset_config.index_size,
                bulk_size=docs_per_sec,
                controller=IndexingController(stop_event))
    print(f"Stopping update for {target_config}")


def start_update(target_config):
    """Launch ``update_data`` in a child process and return a function that stops it.

    Calling the returned function sets the shared stop event and then joins the
    process. NOTE(review): with the "spawn" start method (default on macOS and
    Windows), a target defined in a notebook cell may not be importable by the
    child process; confirm on the target platform.
    """
    stop_event = multiprocessing.Event()
    worker = multiprocessing.Process(target=update_data, args=(target_config, stop_event))
    worker.start()

    def stop_update():
        # Signal the worker's controller, then wait for the process to exit.
        stop_event.set()
        worker.join()

    return stop_update


In [None]:
def search(config, query):
    """Run one vector search against the collection.

    Args:
        config: Provides host, port and collection name.
        query: Qdrant search request body (vector, limit, params, optional filter).

    Returns:
        ``(took_ms, hit_count, ids, scores)``, where ``took_ms`` is Qdrant's
        reported time converted from seconds to milliseconds. On any failure it
        returns ``(-1, -1, [], [])``. Every path returns exactly four items;
        the "status != ok" branch previously returned five, which broke the
        caller's 4-way unpacking.
    """
    response = requests.post(f"http://{config.qdrant_host}:{config.qdrant_port}/collections/{config.index_name}/points/search",
                             headers={"Content-Type": "application/json"},
                             json=query)

    if response.status_code != 200:
        print(f"[FAIL][{response.status_code}] {response.text}")
        return -1, -1, [], []

    obj = response.json()
    if obj.get("status") != "ok":
        print(f"[FAIL] {response.text}")
        return -1, -1, [], []
    hits = obj.get("result") or []
    point_ids = [x.get("id") for x in hits]
    scores = [x.get("score") for x in hits]
    return obj.get("time") * 1000, len(hits), point_ids, scores


In [None]:
def search_with_knn_queries(config, output_path, pre_filter=None, max_size=10000, page_size=100, offset=0, max_error_count=100, exact=False, chunk_size=100000):
    """Send up to ``max_size`` kNN queries, using stored embeddings as query vectors.

    Query vectors are read chunk by chunk from ``config.embedding_path``,
    starting at file ``<offset>.npz``. ``offset`` should be a multiple of
    ``chunk_size``. Reading wraps back to chunk 0 after the last chunk. Each
    successful query is written as one JSON line
    (id, took, hits, ids, scores) to the gzip file ``output_path``, which is
    truncated first.

    Args:
        config: Provides embedding_path, distance, hnsw_ef, num_of_docs and
            the connection settings used by ``search``.
        output_path: Destination .jsonl.gz path.
        pre_filter: Optional iterator; ``next(pre_filter)`` supplies the Qdrant
            filter for each query.
        max_size: Number of successful queries to record.
        page_size: Search ``limit``.
        offset: Start offset of the first chunk to read.
        max_error_count: Stop after this many failed searches.
        exact: Passed as the search ``exact`` parameter.
        chunk_size: Rows per embedding chunk file.
    """
    print("Sending knn queries...")
    start_time = time.time()
    pos = offset
    doc_id = 0
    count = 0
    running = True
    error_count = 0
    with gzip.open(output_path, "wt", encoding="utf-8") as f:
        while running:
            with np.load(config.embedding_path / f"{pos}.npz") as data:
                embedding_data = data["embs"]
            for embedding in embedding_data:
                # doc_id counts query vectors in read order (starting at 1), not
                # dataset ids; result files are joined with ground truth on it.
                doc_id += 1
                if count >= max_size:
                    running = False
                    break
                if config.distance == "Dot":
                    embedding = embedding / np.linalg.norm(embedding)
                query = {
                    "vector": embedding.tolist(),
                    "limit": page_size,
                    # "with_payload": True,
                    "params": {
                        "hnsw_ef": config.hnsw_ef,
                        "exact": exact,
                    },
                }
                if pre_filter is not None:
                    query["filter"] = next(pre_filter)
                took, hits, ids, scores = search(config, query=query)
                if took == -1:
                    error_count += 1
                    if error_count >= max_error_count:
                        running = False
                        break
                    continue
                result = {
                    "id": doc_id,
                    "took": took,
                    "hits": hits,
                    "ids": ids,
                    "scores": scores,
                }
                f.write(json.dumps(result, ensure_ascii=False))
                f.write("\n")
                count += 1
                if count % 10000 == 0:
                    print(f"Sent {count}/{max_size} queries.")

            pos += chunk_size
            # Wrap once the next chunk would start at or past the last document.
            # The original ">" tried to load a non-existent file when
            # num_of_docs is an exact multiple of chunk_size.
            if pos >= config.num_of_docs:
                pos = 0

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:02.2f}")


In [None]:
def get_output_filename(version, name, exact):
    """Build the output path for a run, e.g. ``output/qdrant1_13_4_knn_10.jsonl.gz``.

    Dots in ``version`` become underscores, and an ``_exact`` suffix is added
    when ``exact`` is truthy.
    """
    version_tag = version.replace('.', '_')
    exact_suffix = "_exact" if exact else ""
    return f"output/qdrant{version_tag}_{name}{exact_suffix}.jsonl.gz"


In [None]:
def _summarize(series, quantiles):
    """Return mean/std/min, the given quantiles (keyed like "25%"), and max of ``series``."""
    stats = {
        "mean": series.mean(),
        "std": series.std(),
        "min": series.min(),
    }
    for q in quantiles:
        stats[f"{round(q * 100)}%"] = series.quantile(q)
    stats["max"] = series.max()
    return stats


def print_took_and_total_hits(k, filename, truth_filename):
    """Summarise a search-result file and, when ground truth exists, precision@k.

    Args:
        k: Cut-off used as the precision denominator (capped by the number of
            ground-truth ids).
        filename: JSON-lines result file written by ``search_with_knn_queries``.
        truth_filename: JSON-lines ground-truth file with "id" and "ids"; it is
            skipped when the file does not exist.

    Returns:
        Dict with "num_of_queries", "took" and "hits" stats, plus "precision"
        stats when ground truth was found. Queries with no ground-truth ids (or
        ``k == 0``) get a NaN precision and are excluded from the stats instead
        of raising ZeroDivisionError.
    """
    full_quantiles = (0.25, 0.5, 0.75, 0.9, 0.99)
    df = pd.read_json(filename, lines=True)
    result = {
        "num_of_queries": len(df),
        "took": _summarize(df["took"], full_quantiles),
        "hits": _summarize(df["hits"], (0.25, 0.5, 0.75)),
    }
    if os.path.exists(truth_filename):
        truth = pd.read_json(truth_filename, lines=True)[["id", "ids"]].rename(columns={"ids": "truth_ids"})
        df = pd.merge(df, truth, on="id", how="inner")

        def get_precision(ids, truth_ids):
            size = min(len(truth_ids), k)
            if size == 0:
                return float("nan")
            return len(set(ids).intersection(truth_ids)) / size

        df["precision"] = pd.Series(
            [get_precision(ids, truth_ids) for ids, truth_ids in zip(df["ids"], df["truth_ids"])],
            index=df.index,
            dtype=float,
        )
        result["precision"] = _summarize(df["precision"], full_quantiles)

    summary = df.describe()
    try:
        print(summary.to_markdown())
    except ImportError:
        # to_markdown needs the optional "tabulate" package.
        print(summary)
    return result


In [None]:
def _json_default(value):
    """json.dump fallback: convert numpy scalars/arrays and paths; anything else becomes None."""
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, np.bool_):
        return bool(value)
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, os.PathLike):
        # asdict(config) contains Path fields, which the old lambda wrote as null.
        return os.fspath(value)
    return None


def save_results(target_config, config, results, output_path="results.json"):
    """Write the run metadata, settings and collected results as JSON.

    Args:
        target_config: Preset name of the run.
        config: DataSetConfig dataclass instance (serialised via ``asdict``).
        results: Nested dict of metrics; numpy numbers and paths are converted.
        output_path: Destination file (defaults to "results.json").
    """
    with open(output_path, "wt", encoding="utf-8") as f:
        json.dump({
            "variant": os.getenv("PRODUCT_VARIANT", ""),
            "target": target_config,
            "version": config.qdrant_version,
            "settings": asdict(config),
            "results": results,
            "timestamp": datetime.now().isoformat(),
        }, f, ensure_ascii=False, default=_json_default)


In [None]:
# Run-wide state: query_data collects frequent "section" values during the
# initial load (used later for filtered queries); results accumulates the
# metrics that save_results writes out.
query_data = {"section_values": []}
results = {}
# Benchmark preset; override with the TARGET_CONFIG environment variable.
target_config = os.getenv("TARGET_CONFIG", "100k-768-m32-efc200-ef100-ip")
dataset_config = get_dataset_config(target_config)
pprint.pprint(dataset_config)

In [None]:
# Prune unused docker resources first. Presumably this clears a container left
# stopped (not removed) by a previous run, which would clash with --name.
# Then start a fresh Qdrant container and wait for its HTTP API.
prune_docker(dataset_config)
print(f"<<<Qdrant {dataset_config.qdrant_version}>>>")
run_qdrant(dataset_config)
wait_for_qdrant(dataset_config)

In [None]:
# Baseline resource usage before the collection exists.
print_docker_container_stats(dataset_config)
print_docker_system_df(dataset_config)

In [None]:
# Create the collection plus its payload indexes.
create_index(dataset_config)

In [None]:
# Resource usage and collection info right after creation (no points yet).
print_docker_container_stats(dataset_config)
print_indices(dataset_config)
print_docker_system_df(dataset_config)

In [None]:
# Initial bulk load of index_size documents. This also collects frequent
# section values into query_data for the filtered-query benchmark.
results["indexing"] = insert_data(dataset_config,
                                  max_size=dataset_config.index_size,
                                  bulk_size=dataset_config.bulk_size,
                                  query_data=query_data)

In [None]:
# Snapshot resource usage after indexing and store it with the indexing results.
results["indexing"]["container"] = print_docker_container_stats(dataset_config)
print_indices(dataset_config)
print_docker_system_df(dataset_config)

In [None]:
# Unfiltered kNN benchmark per page size, with the (optional) background
# updater running. The warm-up run writes to the same file, and the measured
# run then truncates and rewrites it.
for page_size in [10, 100]:
    print(f"page size: {page_size}")
    filename = get_output_filename(dataset_config.qdrant_version, f"knn_{page_size}", exact=dataset_config.exact)
    truth_filename = f"dataset/ground_truth/{re.sub(r'-m.*', '', target_config)}/knn_{page_size}.jsonl.gz"
    stop_update = start_update(target_config)
    try:
        search_with_knn_queries(dataset_config, filename, page_size=page_size, max_size=1000)  # warmup
        search_with_knn_queries(dataset_config, filename, page_size=page_size, offset=dataset_config.index_size, exact=dataset_config.exact)
    finally:
        # Always stop and join the updater process, even if a query run raises.
        stop_update()
    results[f"top_{page_size}"] = print_took_and_total_hits(page_size, filename, truth_filename)

In [None]:
def pre_filter_generator(section_values=None):
    """Yield Qdrant "section must match" filters, cycling over section values forever.

    Args:
        section_values: Sections to cycle through. Defaults to the notebook-level
            ``query_data["section_values"]`` collected during indexing.

    Raises:
        ValueError: On the first ``next()`` if there are no section values. The
            previous version silently yielded nothing, which surfaced as a bare
            StopIteration inside the query loop.
    """
    if section_values is None:
        section_values = query_data["section_values"]
    if len(section_values) == 0:
        raise ValueError("No section values available for filtered queries")
    while True:
        for s in section_values:
            yield {
                "must": [
                    {
                        "key": "section",
                        "match": {
                            "value": s
                        }
                    }
                ]
            }

# Filtered kNN benchmark: each query is restricted to one frequent "section"
# value, cycling through those collected during indexing.
results["num_of_filtered_words"] = len(query_data["section_values"])
for page_size in [10, 100]:
    print(f"page size: {page_size}")
    filename = get_output_filename(dataset_config.qdrant_version, f"knn_{page_size}_filtered", exact=dataset_config.exact)
    truth_filename = f"dataset/ground_truth/{re.sub(r'-m.*', '', target_config)}/knn_{page_size}_filtered.jsonl.gz"
    stop_update = start_update(target_config)
    try:
        search_with_knn_queries(dataset_config, filename, page_size=page_size, max_size=1000, pre_filter=pre_filter_generator())  # warmup
        search_with_knn_queries(dataset_config, filename, page_size=page_size, offset=dataset_config.index_size, pre_filter=pre_filter_generator(), exact=dataset_config.exact)
    finally:
        # Always stop and join the updater process, even if a query run raises.
        stop_update()
    results[f"top_{page_size}_filtered"] = print_took_and_total_hits(page_size, filename, truth_filename)

In [None]:
# Resource usage after the search benchmarks.
# NOTE(review): this assigns the same results["indexing"]["container"] key that
# was set right after indexing, so the post-indexing snapshot is overwritten
# before save_results runs. Confirm this is intended; a separate key would
# keep both snapshots.
results["indexing"]["container"] = print_docker_container_stats(dataset_config)
print_indices(dataset_config)
print_docker_system_df(dataset_config)

In [None]:
# Persist run metadata, settings and all collected metrics.
save_results(target_config, dataset_config, results)

In [None]:
# Teardown: drop the collection, then stop the container (docker stop only;
# the container itself is not removed here).
delete_index(dataset_config)
stop_qdrant(dataset_config)