In [None]:
# The RESTful API alone may not be sufficient, so pymilvus is installed as well.
! pip list | grep pymilvus || pip install pymilvus==2.4.0

In [None]:
import gzip
import json
import os
import pprint
import subprocess
import time
from datetime import timedelta
from pathlib import Path

import numpy as np
import pandas as pd
import requests
from pymilvus import MilvusClient, DataType

In [None]:
# Name prefix of the compose containers and name of the docker network.
milvus_name = "benchmark_milvus"
# Host and host-side port used by the HTTP and pymilvus clients in this notebook.
milvus_host = "localhost"
milvus_port = 19540  # published to port 19530 of the standalone container
# Image tags interpolated into the docker compose file.
milvus_version = "2.4.0-rc.1"
etcd_version = "3.5.5"
mineo_version = "RELEASE.2023-03-20T20-16-18Z"  # tag of the minio/minio image

In [None]:
def get_dataset_config(target_name):
    """Return the benchmark settings registered under ``target_name``.

    Every preset shares the same dataset paths and HNSW parameters; the
    presets differ only in ``index_size``.  ``None`` is returned when
    ``target_name`` is not a known preset.

    NOTE(review): the preset names contain "m49" while ``hnsw_m`` is 48 --
    confirm which of the two is intended.
    """
    def preset(index_size):
        # Insertion order is preserved when the dict is dumped to JSON.
        return {
            "content_path": "dataset/passages-c400-jawiki-20230403",
            "embedding_path": "dataset/passages-c400-jawiki-20230403/multilingual-e5-base-passage",
            "num_of_docs": 5555583,
            "index_size": index_size,
            "bulk_size": 1000,
            "index_name": "contents",
            "distance": "IP", # "COSINE"
            "dimension": 768,
            "hnsw_m": 48,
            "hnsw_ef_construction": 200,
            "hnsw_ef": 100,
        }

    index_sizes = {
        "100k-768-m49-ef100-ip": 100000,
        "1m-768-m49-ef100-ip": 1000000,
        "5m-768-m49-ef100-ip": 5000000,
    }
    size = index_sizes.get(target_name)
    if size is None:
        return None
    return preset(size)

# Host directory mounted into the containers; only used when VOLUME_DIR is set.
volume_dir = os.getenv("VOLUME_DIR", "./data")
# Prefix for the compose "volumes" lines: "#" turns them into YAML comments
# when VOLUME_DIR is not set, so no host directory is mounted.
use_volume = "#" if os.getenv("VOLUME_DIR") is None else ""

# Select the preset by name; fail early with a clear message for unknown names
# instead of hitting an AttributeError on None further down.
target_config = os.getenv("TARGET_CONFIG", "100k-768-m49-ef100-ip")
dataset_config = get_dataset_config(target_config)
if dataset_config is None:
    raise ValueError(f"Unknown TARGET_CONFIG: {target_config!r}")
pprint.pprint(dataset_config)

# Dataset location and indexing volume.
content_path = Path(dataset_config.get("content_path"))
embedding_path = Path(dataset_config.get("embedding_path"))
num_of_docs = int(dataset_config.get("num_of_docs"))
index_size = int(dataset_config.get("index_size"))
bulk_size = int(dataset_config.get("bulk_size"))

# Collection and HNSW index parameters.
index_name = dataset_config.get("index_name")
distance = dataset_config.get("distance")
dimension = int(dataset_config.get("dimension"))
hnsw_m = int(dataset_config.get("hnsw_m"))
hnsw_ef_construction = int(dataset_config.get("hnsw_ef_construction"))
hnsw_ef = int(dataset_config.get("hnsw_ef"))

# Accumulates the measurements that save_results() writes out.
results = {}

In [None]:
# Path of the docker compose file generated below; run_milvus() and
# stop_milvus() pass it to `docker compose -f`.
compose_yaml_path = Path("milvus-compose.yaml")
# Compose definition with three services: etcd, minio and the Milvus
# standalone server.  {use_volume} is either "" or "#"; when it is "#" the
# "volumes" entries become YAML comments and no host directory is mounted.
compose_yaml = f"""
services:
  etcd:
    container_name: {milvus_name}-etcd
    image: quay.io/coreos/etcd:v{etcd_version}
    environment:
      - ETCD_AUTO_COMPACTION_MODE=revision
      - ETCD_AUTO_COMPACTION_RETENTION=1000
      - ETCD_QUOTA_BACKEND_BYTES=4294967296
      - ETCD_SNAPSHOT_COUNT=50000
{use_volume}    volumes:
{use_volume}      - {volume_dir}/volumes/etcd:/etcd
    command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
    healthcheck:
      test: ["CMD", "etcdctl", "endpoint", "health"]
      interval: 30s
      timeout: 20s
      retries: 3

  minio:
    container_name: {milvus_name}-minio
    image: minio/minio:{mineo_version}
    environment:
      MINIO_ACCESS_KEY: minioadmin
      MINIO_SECRET_KEY: minioadmin
#    ports:
#      - "9001:9001"
#      - "9000:9000"
{use_volume}    volumes:
{use_volume}      - {volume_dir}/volumes/minio:/minio_data
    command: minio server /minio_data --console-address ":9001"
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
      interval: 30s
      timeout: 20s
      retries: 3

  standalone:
    container_name: {milvus_name}-standalone
    image: milvusdb/milvus:v{milvus_version}
    command: ["milvus", "run", "standalone"]
    security_opt:
    - seccomp:unconfined
    environment:
      ETCD_ENDPOINTS: etcd:2379
      MINIO_ADDRESS: minio:9000
{use_volume}    volumes:
{use_volume}      - {volume_dir}/volumes/milvus:/var/lib/milvus
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
      interval: 30s
      start_period: 90s
      timeout: 20s
      retries: 3
    ports:
      - "{milvus_port}:19530"
      - "9091:9091"
    depends_on:
      - "etcd"
      - "minio"

networks:
  default:
    name: {milvus_name}
"""

# Write the rendered compose file to disk, replacing any existing file.
with open(compose_yaml_path, "wt", encoding="utf-8") as f:
    f.write(compose_yaml)

In [None]:
def run_milvus():
    """Start the compose stack in detached mode and report the outcome.

    Prints ``[OK]`` when ``docker compose up -d`` exits with status 0;
    otherwise prints ``[FAIL]`` followed by the captured stdout and stderr.
    """
    print(f"Starting {milvus_name}... ", end="")
    # Prepend "sudo" to the command if docker needs elevated privileges.
    proc = subprocess.run(
        ["docker", "compose", "-f", compose_yaml_path, "up", "-d"],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        print("[FAIL]")
        for label, captured in (("STDOUT:", proc.stdout), ("STDERR:", proc.stderr)):
            print(label)
            print(captured)
        return
    print("[OK]")


In [None]:
def stop_milvus():
    """Tear the compose stack down and report the outcome.

    Prints ``[OK]`` when ``docker compose down`` exits with status 0;
    otherwise prints ``[FAIL]`` followed by the captured stdout and stderr.
    """
    print(f"Stopping {milvus_name}... ", end="")
    # Prepend "sudo" to the command if docker needs elevated privileges.
    proc = subprocess.run(
        ["docker", "compose", "-f", compose_yaml_path, "down"],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        print("[FAIL]")
        for label, captured in (("STDOUT:", proc.stdout), ("STDERR:", proc.stderr)):
            print(label)
            print(captured)
        return
    print("[OK]")


In [None]:
def prune_docker():
    """Run ``docker system prune -f`` and report the outcome.

    Prints ``[OK]`` when the command exits with status 0; otherwise prints
    ``[FAIL]`` followed by the captured stdout and stderr.
    """
    print("Cleaning up... ", end="")
    # Prepend "sudo" to the command if docker needs elevated privileges.
    proc = subprocess.run(
        ["docker", "system", "prune", "-f"],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        print("[FAIL]")
        for label, captured in (("STDOUT:", proc.stdout), ("STDERR:", proc.stderr)):
            print(label)
            print(captured)
        return
    print("[OK]")


In [None]:
def print_docker_system_df():
    """Print the output of ``docker system df``.

    The captured stdout is printed when the command exits with status 0,
    the captured stderr otherwise.
    """
    # Prepend "sudo" to the command if docker needs elevated privileges.
    proc = subprocess.run(
        ["docker", "system", "df"],
        capture_output=True,
        text=True,
    )
    print(proc.stdout if proc.returncode == 0 else proc.stderr)


In [None]:
def print_docker_container_stats():
    """Print a one-shot ``docker container stats`` table and return it parsed.

    Returns:
        A dict keyed by the second whitespace-separated token of each row
        (the container name), mapping to a dict of the raw string values
        for that row.  An empty dict is returned, and stderr is printed,
        when the command exits with a non-zero status.
    """
    # Prepend "sudo" to the command if docker needs elevated privileges.
    proc = subprocess.run(
        ["docker", "container", "stats", "--no-stream"],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        print(proc.stderr)
        return {}

    print(proc.stdout)
    # Token position of each value after splitting a row on whitespace.  The
    # unused positions (4, 8, 11) are presumably the "/" separators of the
    # "a / b" columns -- confirm against the docker output format.
    token_index = {
        "container_id": 0,
        "cpu": 2,
        "mem": 6,
        "mem_usage": 3,
        "mem_limit": 5,
        "net_in": 7,
        "net_out": 9,
        "block_in": 10,
        "block_out": 12,
        "pids": 13,
    }
    stats = {}
    for row in proc.stdout.split("\n"):
        # Skip the header row and empty lines.
        if not row or row.startswith("CONTAINER"):
            continue
        tokens = row.split()
        stats[tokens[1]] = {key: tokens[pos] for key, pos in token_index.items()}
    return stats


In [None]:
def create_index():
    """Create the benchmark collection with an HNSW index and report its load state.

    The collection is named ``index_name`` and has an INT64 primary key
    ``id``, a FLOAT_VECTOR field ``embedding`` of ``dimension`` dimensions,
    INT64 fields ``page_id`` / ``rev_id`` and a VARCHAR field ``section``.
    Prints ``[OK]`` when the reported load state equals 3 (Loaded),
    ``[FAIL]`` with the raw response otherwise.

    The client connection is closed even when one of the pymilvus calls
    raises.
    """
    print(f"Creating Collection {index_name}... ", end="")

    client = MilvusClient(
        uri=f"http://{milvus_host}:{milvus_port}"
    )
    try:
        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=False,
        )

        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
        schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=dimension)
        for field_name in ["page_id", "rev_id"]:
            schema.add_field(field_name=field_name, datatype=DataType.INT64)
        for field_name in ["section"]:
            schema.add_field(field_name=field_name, datatype=DataType.VARCHAR, max_length=128)
        # The "title" and "text" fields are intentionally not part of the schema.

        index_params = client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            index_type="HNSW",
            metric_type=distance,
            params={
                "M": hnsw_m,
                "efConstruction": hnsw_ef_construction,
            }
        )

        client.create_collection(
            collection_name=index_name,
            schema=schema,
            index_params=index_params
        )

        response = client.get_load_state(
            collection_name=index_name
        )
        if response.get("state") == 3: # Loaded
            print("[OK]")
        else:
            print(f"[FAIL] {response}")
    finally:
        # Release the connection even if schema or collection creation failed.
        client.close()


In [None]:
def delete_index():
    """Drop the collection ``index_name`` through the v1 REST API.

    Prints ``[OK]`` on HTTP status 200, otherwise ``[FAIL]`` followed by the
    response body on the next line.
    """
    print(F"Deleting Collection {index_name}... ", end="")
    response = requests.post(
        f"http://{milvus_host}:{milvus_port}/v1/vector/collections/drop",
        headers={
            "Accept": "application/json",
            "Content-Type": "application/json",
        },
        json={"collectionName": index_name},
    )
    if response.status_code != 200:
        print(f"[FAIL]\n{response.text}")
        return
    print("[OK]")


In [None]:
def print_indices():
    """Print and return the number of entities in the collection.

    Sends a ``count(*)`` query with the filter ``id > 0`` to the v1 REST
    API.  The count is taken from the first row of the response when its
    ``code`` is 200, and reported as 0 otherwise.

    Returns:
        ``{"num_of_docs": count}``.
    """
    payload = {
        "collectionName": index_name,
        "outputFields": ["count(*)"],
        "filter": "id > 0",
        "limit": 0,
    }
    response = requests.post(
        f"http://{milvus_host}:{milvus_port}/v1/vector/query",
        headers={
            "Accept": "application/json",
            "Content-Type": "application/json",
        },
        json=payload,
    )
    body = json.loads(response.text)
    if body.get("code") == 200:
        count = body.get("data")[0].get("count(*)")
    else:
        count = 0
    print(f"count: {count}")
    return {"num_of_docs": count}


In [None]:
def wait_for_milvus(retry_count=60):
    """Poll Milvus until it accepts requests, or give up after ``retry_count`` tries.

    Readiness is probed by creating a throwaway collection named
    ``healthcheck`` through the v1 REST API; once that succeeds the
    collection is dropped again and ``[OK]`` is printed.  Between attempts
    a ``.`` is printed and the function sleeps for one second.  ``[FAIL]``
    is printed when every attempt failed.

    Args:
        retry_count: maximum number of probe attempts.
    """
    print(f"Waiting for {milvus_name}", end="")
    json_headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    for _ in range(retry_count):
        try:
            # TODO replace with node check api?
            response = requests.post(f"http://{milvus_host}:{milvus_port}/v1/vector/collections/create",
                                    headers=json_headers,
                                    json={
                                        "collectionName": "healthcheck",
                                        "dimension": 256,
                                    })
            obj = json.loads(response.text)
            if response.status_code == 200 and obj.get("code") == 200:
                print("[OK]")
                requests.post(f"http://{milvus_host}:{milvus_port}/v1/vector/collections/drop",
                              headers=json_headers,
                              json={
                                  "collectionName": "healthcheck",
                              })
                return
        except Exception:
            # Connection errors and malformed responses are expected while the
            # server is starting; keep retrying.  KeyboardInterrupt and
            # SystemExit are deliberately not caught so the wait can be aborted.
            pass
        print(".", end="")
        time.sleep(1)
    print("[FAIL]")


In [None]:
def get_embedding(embedding_index, embedding_data, id):
    """Return the embedding for document ``id``, reusing the loaded chunk if possible.

    Embeddings are read from ``<embedding_path>/<chunk_start>.npz`` (array
    ``embs``), where ``chunk_start`` is ``id`` rounded down to a multiple of
    100000.  The file is only loaded when no chunk is cached yet or the
    cached chunk starts at a different offset.  When ``distance`` is
    ``"IP"`` the vector is divided by its L2 norm.

    Args:
        embedding_index: start offset of the chunk held in ``embedding_data``.
        embedding_data: embeddings of the cached chunk, or ``None``.
        id: document id whose embedding is requested.

    Returns:
        ``(chunk_start, chunk_embeddings, embedding)`` -- pass the first two
        back in on the next call to keep using the cache.
    """
    chunk_start = int(id / 100000) * 100000
    cache_is_stale = embedding_data is None or embedding_index != chunk_start
    if cache_is_stale:
        with np.load(embedding_path / f"{chunk_start}.npz") as archive:
            embedding_data = archive["embs"]
    vector = embedding_data[id - chunk_start]
    if distance == "IP":
        vector = vector / np.linalg.norm(vector)
    return chunk_start, embedding_data, vector


# Section names collected while inserting data; used to build search filters.
section_values = []

def get_section_values(df, min_count=1000):
    """Return the ``section`` values that have at least ``min_count`` rows in ``df``.

    Rows are counted per section via the non-null entries of the ``id``
    column; the result is a plain list ordered by section value.
    """
    rows_per_section = df.groupby("section")["id"].count()
    frequent = rows_per_section[rows_per_section >= min_count]
    return frequent.index.tolist()


def insert_data(bulk_size, max_size):
    """Insert up to ``max_size`` documents into the collection in batches.

    Every ``*.parquet`` file under ``content_path`` is read in sorted order.
    The frequent section names of each file are appended to the module-level
    ``section_values`` list, and rows are sent to the v1 REST insert
    endpoint in batches of ``bulk_size`` until ``max_size`` rows were sent.
    Documents get consecutive ids starting at 1.

    Args:
        bulk_size: number of documents per insert request.
        max_size: total number of documents to insert.

    Returns:
        A dict with ``execution_time`` (wall-clock seconds of the whole
        call) and ``process_time`` (seconds spent in insert requests).
    """
    started_at = time.time()

    def post_batch(batch, pos):
        # Sends one batch and returns the seconds the request took.
        print(F"Sending {int(len(batch))} docs ({pos}/{max_size})... ", end="")
        sent_at = time.time()
        response = requests.post(
            f"http://{milvus_host}:{milvus_port}/v1/vector/insert",
            headers={
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
            json={
                "collectionName": index_name,
                "data": batch,
            },
        )
        elapsed = time.time() - sent_at
        if json.loads(response.text).get("code") == 200:
            print(f"[OK] {elapsed}")
        else:
            print(f"[FAIL] {elapsed} {response.status_code} {response.text}")
        return elapsed

    request_seconds = 0
    sent = 0
    cached_index, cached_data = -1, None
    batch = []
    for parquet_file in sorted(content_path.glob("*.parquet")):
        frame = pd.read_parquet(parquet_file)
        section_values.extend(get_section_values(frame))
        for _, row in frame.iterrows():
            if sent >= max_size:
                break
            cached_index, cached_data, vector = get_embedding(cached_index, cached_data, row.id)
            sent += 1
            batch.append({
                "id": sent,
                "embedding": vector.tolist(),
                "page_id": row.pageid,
                "rev_id": row.revid,
                "section": row.section,
            })
            if len(batch) >= bulk_size:
                request_seconds += post_batch(batch, sent)
                batch = []

    # Flush the last, partially filled batch.
    if batch:
        request_seconds += post_batch(batch, sent)

    execution_time = time.time() - started_at
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:02.2f} ({timedelta(seconds=request_seconds)})")
    return {
        "execution_time": execution_time,
        "process_time": request_seconds,
    }


In [None]:
def search(query):
    """Send one search request to the v2 REST API and summarize the response.

    Args:
        query: JSON-serializable request body.

    Returns:
        ``(took_ms, hit_count, ids, scores)`` when the response ``code`` is
        200, where ``took_ms`` is the client-side round-trip time in
        milliseconds and ``scores`` holds each hit's ``distance``.
        Otherwise the failure is printed and ``(-1, -1, [], [])`` is returned.
    """
    sent_at = time.time()
    response = requests.post(
        f"http://{milvus_host}:{milvus_port}/v2/vectordb/entities/search",
        headers={
            "Accept": "application/json",
            "Content-Type": "application/json",
        },
        json=query,
    )
    took = (time.time() - sent_at) * 1000

    body = json.loads(response.text)
    if body.get("code") != 200:
        print(f"[FAIL][{response.status_code}] {response.text}")
        return -1, -1, [], []

    hits = body.get("data")
    ids = [hit.get("id") for hit in hits]
    distances = [hit.get("distance") for hit in hits]
    return took, len(hits), ids, distances


In [None]:
def search_with_knn_queries(output_path, pre_filter=None, max_size=10000, page_size=100, offset=0, max_error_count=100):
    """Run kNN queries built from stored embeddings and log one JSON line per query.

    Query vectors are read from ``<embedding_path>/<pos>.npz`` (array
    ``embs``), starting at ``pos = offset`` and advancing by 100000 per file;
    ``pos`` wraps to 0 once it exceeds ``num_of_docs``.  When ``distance``
    is ``"IP"`` each vector is divided by its L2 norm before it is sent.
    Each successful query is written to ``output_path`` (gzip-compressed
    JSON lines) with its ``id``, ``took``, ``hits``, ``ids`` and ``scores``.

    Args:
        output_path: destination ``.jsonl.gz`` file; it is overwritten and
            its parent directory is created if missing.
        pre_filter: optional iterator; its next value is used as the
            ``filter`` expression of each query.
        max_size: number of successful queries to record.
        page_size: ``limit`` of each query.
        offset: start offset of the first embedding file to read.
        max_error_count: number of failed queries after which the run is
            aborted.
    """
    print("Sending knn queries...")
    start_time = time.time()
    pos = offset
    count = 0
    running = True
    error_count = 0
    # gzip.open does not create missing directories.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with gzip.open(output_path, "wt", encoding="utf-8") as f:
        while running:
            with np.load(embedding_path / f"{pos}.npz") as data:
                embedding_data = data["embs"]
            for embedding in embedding_data:
                if count >= max_size:
                    running = False
                    break
                if distance == "IP":
                    embedding = embedding / np.linalg.norm(embedding)
                query = {
                    "collectionName": index_name,
                    "annsField": "embedding",
                    "data": [embedding.tolist()],
                    "limit": page_size,
                    "searchParams": {
                        "metricType": distance,
                    },
                }
                if pre_filter is not None:
                    query["filter"] = next(pre_filter)
                took, hits, ids, scores = search(query=query)
                if took == -1:
                    # Failed queries are not recorded; abort once too many failed.
                    error_count += 1
                    if error_count >= max_error_count:
                        running = False
                        break
                    continue
                result = {
                    "id": (count + 1),
                    "took": took,
                    "hits": hits,
                    "ids": ids,
                    "scores": scores,
                }
                f.write(json.dumps(result, ensure_ascii=False))
                f.write("\n")
                count += 1
                if count % 10000 == 0:
                    print(f"Sent {count}/{max_size} queries.")

            # Move on to the next embedding file, wrapping around at the end.
            pos += 100000
            if pos > num_of_docs:
                pos = 0

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:02.2f}")


In [None]:
def get_output_filename(milvus_version, name):
    """Build the result file path ``output/milvus<version>_<name>.jsonl.gz``.

    Dots in ``milvus_version`` are replaced with underscores.
    """
    version_tag = milvus_version.replace(".", "_")
    return f"output/milvus{version_tag}_{name}.jsonl.gz"


In [None]:
def print_took_and_total_hits(filename):
    """Summarize the ``took`` and ``hits`` values of a query result file.

    Reads the gzip-compressed JSON-lines file ``filename``, prints the
    ``describe()`` table of both columns as markdown and returns the
    statistics as a nested dict.

    Returns:
        A dict with ``num_of_queries`` and, for each of ``took`` and
        ``hits``, the mean, standard deviation, minimum, quartiles and
        maximum.  ``took`` additionally reports the 90% and 99% quantiles.
    """
    def summarize(series, tail_quantiles=()):
        # Statistics of one column; tail quantiles go between 75% and max.
        stats = {
            "mean": series.mean(),
            "std": series.std(),
            "min": series.min(),
            "25%": series.quantile(0.25),
            "50%": series.quantile(0.5),
            "75%": series.quantile(0.75),
        }
        for label, q in tail_quantiles:
            stats[label] = series.quantile(q)
        stats["max"] = series.max()
        return stats

    tooks = []
    hits = []
    with gzip.open(filename, "rt", encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            tooks.append(obj.get("took"))
            hits.append(obj.get("hits"))
    df = pd.DataFrame({"took": tooks, "hits": hits})
    print(df.describe().to_markdown())
    return {
        "num_of_queries": len(df),
        "took": summarize(df.took, tail_quantiles=(("90%", 0.9), ("99%", 0.99))),
        # The hit statistics must come from the hits column, not from took.
        "hits": summarize(df.hits),
    }


In [None]:
def save_results():
    """Write the Milvus version, dataset settings and results to ``results.json``.

    NumPy integers, floats, booleans and arrays are converted to their
    built-in Python equivalents.  Any other value the JSON encoder cannot
    handle is written as ``null``.
    """
    def to_builtin(value):
        # Called by json.dump only for values it cannot serialize natively.
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)
        if isinstance(value, np.bool_):
            return bool(value)
        if isinstance(value, np.ndarray):
            return value.tolist()
        return None

    with open("results.json", "wt", encoding="utf-8") as f:
        json.dump({
            "version": milvus_version,
            "settings": dataset_config,
            "results": results,
        }, f, ensure_ascii=False, default=to_builtin)


In [None]:
# Run `docker system prune -f`, start the compose stack and wait until the
# Milvus REST API accepts requests.
prune_docker()
print(f"<<<Milvus {milvus_version}>>>")
run_milvus()
wait_for_milvus()

In [None]:
# Container and disk usage right after start-up, before any collection exists.
print_docker_container_stats()
print_docker_system_df()

In [None]:
# Create the (still empty) collection together with its HNSW index.
create_index()

In [None]:
# Container usage, entity count and disk usage after creating the collection.
print_docker_container_stats()
print_indices()
print_docker_system_df()

In [None]:
# Insert `index_size` documents in batches of `bulk_size` and keep the timings.
results["indexing"] = insert_data(bulk_size=bulk_size, max_size=index_size)

In [None]:
# Store the container stats captured right after indexing with the indexing results.
results["indexing"]["container"] = print_docker_container_stats()
print_indices()
print_docker_system_df()

In [None]:
# Unfiltered kNN benchmark for several result sizes.  The warm-up run and the
# measured run write to the same file; the measured run overwrites the warm-up
# output, so only the measured queries are summarized.
for page_size in [10, 100, 400]:
    print(f"page size: {page_size}")
    filename = get_output_filename(milvus_version, f"knn_{page_size}")
    search_with_knn_queries(filename, page_size=page_size, max_size=1000) # warmup
    # The measured run reads its query vectors starting at offset index_size.
    search_with_knn_queries(filename, page_size=page_size, offset=index_size)
    results[f"top_{page_size}"] = print_took_and_total_hits(filename)

In [None]:
def pre_filter_generator():
    """Yield ``section == "<value>"`` filter expressions in an endless cycle.

    The values come from the module-level ``section_values`` list and are
    repeated in order forever.  The generator ends immediately, without
    yielding anything, when that list is empty.
    """
    if not section_values:
        return
    while True:
        yield from (f"section == \"{s}\"" for s in section_values)


# Filtered kNN benchmark: same procedure as the unfiltered run, but every query
# carries the next `section == "..."` expression from a fresh generator.  The
# warm-up output is overwritten by the measured run (same file name).
for page_size in [10, 100, 400]:
    print(f"page size: {page_size}")
    filename = get_output_filename(milvus_version, f"knn_{page_size}_filtered")
    search_with_knn_queries(filename, page_size=page_size, max_size=1000, pre_filter=pre_filter_generator()) # warmup
    search_with_knn_queries(filename, page_size=page_size, offset=index_size, pre_filter=pre_filter_generator())
    results[f"top_{page_size}_filtered"] = print_took_and_total_hits(filename)

In [None]:
# Container usage, entity count and disk usage after the search benchmarks.
print_docker_container_stats()
print_indices()
print_docker_system_df()

In [None]:
# Persist the version, settings and collected measurements to results.json.
save_results()

In [None]:
# Drop the benchmark collection, then shut the compose stack down.
delete_index()
stop_milvus()