In [None]:
import gzip
import json
import os
import pprint
import subprocess
import time
from datetime import timedelta
from pathlib import Path

import numpy as np
import pandas as pd
import requests

In [None]:
# Settings for the OpenSearch node under test.
# opensearch_name is the Docker container name; opensearch_host/opensearch_port
# form the HTTP endpoint used by every request in this notebook (the port is
# published to the container's 9200 in run_opensearch()).
opensearch_name = "benchmark_opensearch"
opensearch_host = "localhost"
opensearch_port = 9212
# Docker image tag; also embedded in the result file names.
opensearch_version = "2.13.0"
# JVM heap size, used for both -Xms and -Xmx. Alternative value tried: "4g".
opensearch_heap = "2g" # "4g"

In [None]:
def get_dataset_config(target_name):
    """Return the benchmark settings for the preset called ``target_name``.

    All presets share the same dataset paths and HNSW parameters and differ
    only in ``index_size`` (the number of documents to index).

    NOTE(review): the preset names contain "m49" and "ip", but the values are
    ``hnsw_m = 48`` and ``distance = "cosinesimil"`` -- confirm which is intended.

    Args:
        target_name: preset key, e.g. ``"100k-768-m49-ef100-ip"``.

    Returns:
        A new dict of settings, or ``None`` when ``target_name`` is unknown.
    """
    dataset_root = "dataset/passages-c400-jawiki-20230403"
    index_size_by_target = {
        "100k-768-m49-ef100-ip": 100000,
        "1m-768-m49-ef100-ip": 1000000,
        "5m-768-m49-ef100-ip": 5000000,
    }
    if target_name not in index_size_by_target:
        return None
    return {
        "content_path": dataset_root,
        "embedding_path": dataset_root + "/multilingual-e5-base-passage",
        "num_of_docs": 5555583,
        "index_size": index_size_by_target[target_name],
        "bulk_size": 1000,
        "index_name": "contents",
        "engine": "lucene",  # alternative: "faiss"
        "distance": "cosinesimil",  # alternative: "innerproduct"
        "dimension": 768,
        "hnsw_m": 48,
        "hnsw_ef_construction": 200,
        "hnsw_ef": 100,
    }

# Host directory for OpenSearch data; only referenced by the commented-out
# "-v" mount in run_opensearch().
volume_dir = os.getenv("VOLUME_DIR", "./data")

# Select the preset via the TARGET_CONFIG environment variable.
# NOTE: get_dataset_config() returns None for an unknown name, in which case
# the .get() calls below raise AttributeError.
dataset_config = get_dataset_config(os.getenv("TARGET_CONFIG", "100k-768-m49-ef100-ip"))
pprint.pprint(dataset_config)

# Dataset location and sizing.
content_path = Path(dataset_config.get("content_path"))
embedding_path = Path(dataset_config.get("embedding_path"))
num_of_docs = int(dataset_config.get("num_of_docs"))
index_size = int(dataset_config.get("index_size"))
bulk_size = int(dataset_config.get("bulk_size"))

# Index name and k-NN (HNSW) parameters used by create_index().
index_name = dataset_config.get("index_name")
distance = dataset_config.get("distance")
dimension = int(dataset_config.get("dimension"))
hnsw_m = int(dataset_config.get("hnsw_m"))
hnsw_ef_construction = int(dataset_config.get("hnsw_ef_construction"))
hnsw_ef = int(dataset_config.get("hnsw_ef"))

engine = dataset_config.get("engine")

# Accumulates the measurements that save_results() writes to results.json.
results = {}

In [None]:
def run_opensearch():
    """Start the OpenSearch container in detached mode and print the outcome.

    Runs ``docker run -d`` for ``opensearchproject/opensearch:{opensearch_version}``
    as a single-node cluster with the security plugin disabled, memory locking
    enabled, and host port ``opensearch_port`` published to container port 9200.
    Prints ``[OK]`` when docker exits with 0, otherwise ``[FAIL]`` followed by
    the captured stdout and stderr. Returns nothing.
    """
    print(f"Starting {opensearch_name}... ", end="")
    environment = [
        "discovery.type=single-node",
        "bootstrap.memory_lock=true",
        "plugins.security.disabled=true",
        f"OPENSEARCH_JAVA_OPTS=-Xms{opensearch_heap} -Xmx{opensearch_heap}",
        "OPENSEARCH_INITIAL_ADMIN_PASSWORD=0LX4wquYDZu6jsve",
    ]
    # Prepend "sudo" here if docker requires elevated privileges.
    command = ["docker", "run", "-d", "--name", opensearch_name]
    command += ["--ulimit", "memlock=-1:-1", "--ulimit", "nofile=65535:65535"]
    command += ["-p", f"{opensearch_port}:9200"]
    for variable in environment:
        command += ["-e", variable]
    # Optional data volume (disabled): "-v", f"{volume_dir}:/usr/share/opensearch/data"
    command.append(f"opensearchproject/opensearch:{opensearch_version}")

    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode == 0:
        print("[OK]")
        return
    print("[FAIL]", "STDOUT:", completed.stdout, "STDERR:", completed.stderr, sep="\n")


In [None]:
def stop_opensearch():
    """Stop the OpenSearch container with ``docker stop`` and print the outcome.

    Prints ``[OK]`` when docker exits with 0, otherwise ``[FAIL]`` followed by
    the captured stdout and stderr. Returns nothing.
    """
    print(f"Stopping {opensearch_name}... ", end="")
    # Prepend "sudo" here if docker requires elevated privileges.
    completed = subprocess.run(["docker", "stop", opensearch_name],
                               capture_output=True, text=True)
    if completed.returncode == 0:
        print("[OK]")
        return
    print("[FAIL]", "STDOUT:", completed.stdout, "STDERR:", completed.stderr, sep="\n")


In [None]:
def prune_docker():
    """Run ``docker system prune -f`` and print the outcome.

    Prints ``[OK]`` when docker exits with 0, otherwise ``[FAIL]`` followed by
    the captured stdout and stderr. Returns nothing.
    """
    print("Cleaning up... ", end="")
    # Prepend "sudo" here if docker requires elevated privileges.
    completed = subprocess.run(["docker", "system", "prune", "-f"],
                               capture_output=True, text=True)
    if completed.returncode == 0:
        print("[OK]")
        return
    print("[FAIL]", "STDOUT:", completed.stdout, "STDERR:", completed.stderr, sep="\n")


In [None]:
def print_docker_system_df():
    """Print the output of ``docker system df``.

    Prints the command's stdout when it exits with 0 and its stderr otherwise.
    Returns nothing.
    """
    # Prepend "sudo" here if docker requires elevated privileges.
    completed = subprocess.run(["docker", "system", "df"],
                               capture_output=True, text=True)
    succeeded = completed.returncode == 0
    print(completed.stdout if succeeded else completed.stderr)


In [None]:
def print_docker_container_stats():
    """Print ``docker container stats --no-stream`` and parse it per container.

    Each non-header, non-empty output line is split on whitespace and the
    tokens are picked by position (the "/" separators inside the
    "MEM USAGE / LIMIT", "NET I/O" and "BLOCK I/O" columns are separate tokens
    and are skipped).

    Returns:
        dict mapping container name -> dict of raw string values with keys
        ``container_id``, ``cpu``, ``mem``, ``mem_usage``, ``mem_limit``,
        ``net_in``, ``net_out``, ``block_in``, ``block_out``, ``pids``.
        Empty when the docker command fails (its stderr is printed instead).
    """
    # Prepend "sudo" here if docker requires elevated privileges.
    completed = subprocess.run(["docker", "container", "stats", "--no-stream"],
                               capture_output=True, text=True)
    if completed.returncode != 0:
        print(completed.stderr)
        return {}

    print(completed.stdout)
    # Output key -> index of the whitespace-separated token on a stats line.
    token_positions = {
        "container_id": 0,
        "cpu": 2,
        "mem": 6,
        "mem_usage": 3,
        "mem_limit": 5,
        "net_in": 7,
        "net_out": 9,
        "block_in": 10,
        "block_out": 12,
        "pids": 13,
    }
    stats_by_name = {}
    for row in completed.stdout.split("\n"):
        if not row or row.startswith("CONTAINER"):
            continue
        tokens = row.split()
        stats_by_name[tokens[1]] = {key: tokens[pos] for key, pos in token_positions.items()}
    return stats_by_name


In [None]:
def create_index(number_of_shards=1, number_of_replicas=0):
    """Create the benchmark index with a k-NN (HNSW) ``embedding`` field.

    The mapping excludes ``embedding`` from ``_source`` and takes the vector
    dimension, space type, engine and HNSW parameters from the module-level
    configuration. Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the
    response body.

    Args:
        number_of_shards: value for ``index.number_of_shards``.
        number_of_replicas: value for ``index.number_of_replicas``.
    """
    print(f"Creating {index_name}... ", end="")

    embedding_field = {
        "type": "knn_vector",
        "dimension": dimension,
        "method": {
            "name": "hnsw",
            "space_type": distance,
            "engine": engine,
            "parameters": {
                "ef_construction": hnsw_ef_construction,
                "m": hnsw_m,
            },
        },
    }
    field_mappings = {
        "page_id": {"type": "integer"},
        "rev_id": {"type": "integer"},
        "title": {"type": "text"},
        "section": {"type": "keyword"},
        "text": {"type": "text"},
        "embedding": embedding_field,
    }
    index_body = {
        "mappings": {
            "_source": {"excludes": ["embedding"]},
            "properties": field_mappings,
        },
        "settings": {
            "index": {
                "number_of_shards": number_of_shards,
                "number_of_replicas": number_of_replicas,
                "knn": True,
                "knn.algo_param.ef_search": hnsw_ef,
            },
        },
    }

    response = requests.put(
        f"http://{opensearch_host}:{opensearch_port}/{index_name}",
        headers={"Content-Type": "application/json"},
        json=index_body,
    )
    if response.status_code != 200:
        print(f"[FAIL]\n{response.text}")
        return
    print("[OK]")


In [None]:
def delete_index():
    """Delete the benchmark index.

    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body.
    """
    print(f"Deleting {index_name}... ", end="")
    index_url = f"http://{opensearch_host}:{opensearch_port}/{index_name}"
    response = requests.delete(index_url)
    if response.status_code != 200:
        print(f"[FAIL]\n{response.text}")
        return
    print("[OK]")


In [None]:
def forcemerge_index():
    """Force-merge the benchmark index to one segment and report the elapsed time.

    Sends ``POST /{index_name}/_forcemerge?max_num_segments=1`` with a one-hour
    client timeout. Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the
    response body, then prints the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Merging {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_forcemerge?max_num_segments=1",
                            timeout=60*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 ("SS.ss") is needed for zero padding; a width of 2 is always
    # exceeded by the two decimals and would print e.g. "3.50".
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def flush_index():
    """Flush the benchmark index and report the elapsed time.

    Sends ``POST /{index_name}/_flush`` with a ten-minute client timeout.
    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body,
    then prints the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Flushing {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_flush",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 ("SS.ss") is needed for zero padding; a width of 2 is always
    # exceeded by the two decimals and would print e.g. "3.50".
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def refresh_index():
    """Refresh the benchmark index and report the elapsed time.

    Sends ``POST /{index_name}/_refresh`` with a ten-minute client timeout.
    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body,
    then prints the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Refreshing {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_refresh",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 ("SS.ss") is needed for zero padding; a width of 2 is always
    # exceeded by the two decimals and would print e.g. "3.50".
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def open_index():
    """Open the benchmark index and report the elapsed time.

    Sends ``POST /{index_name}/_open`` with a ten-minute client timeout.
    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body,
    then prints the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Opening {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_open",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 ("SS.ss") is needed for zero padding; a width of 2 is always
    # exceeded by the two decimals and would print e.g. "3.50".
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def close_index():
    """Close the benchmark index and report the elapsed time.

    Sends ``POST /{index_name}/_close`` with a ten-minute client timeout.
    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body,
    then prints the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Closing {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_close",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 ("SS.ss") is needed for zero padding; a width of 2 is always
    # exceeded by the two decimals and would print e.g. "3.50".
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def print_indices():
    """Print ``_cat/indices`` and extract the row for the benchmark index.

    Each line of the response is split on whitespace; the row whose third
    token equals ``index_name`` supplies the result.

    Returns:
        ``{"num_of_docs": <7th token>, "index_size": <9th token>}`` as raw
        strings for the matching row, or ``{}`` when no row matches.
    """
    listing = requests.get(f"http://{opensearch_host}:{opensearch_port}/_cat/indices").text
    print(listing)
    for columns in map(str.split, listing.split("\n")):
        if len(columns) >= 3 and columns[2] == index_name:
            return {
                "num_of_docs": columns[6],
                "index_size": columns[8],
            }
    return {}


In [None]:
def wait_for_opensearch(retry_count=60):
    """Poll the OpenSearch root endpoint until it answers with HTTP 200.

    Makes up to ``retry_count`` attempts, sleeping one second after each
    failed attempt and printing a ``.`` as progress. Prints ``[OK]`` and
    returns as soon as a 200 is received; prints ``[FAIL]`` when all attempts
    are used up.

    Args:
        retry_count: maximum number of polling attempts.
    """
    print(f"Waiting for {opensearch_name}", end="")
    for _ in range(retry_count):
        try:
            # A per-attempt timeout keeps one stalled connection from blocking
            # the loop indefinitely and defeating retry_count.
            response = requests.get(f"http://{opensearch_host}:{opensearch_port}/", timeout=5)
            if response.status_code == 200:
                print("[OK]")
                return
        except requests.exceptions.RequestException:
            # Connection errors are expected while the node is still booting.
            # Only request failures are swallowed, so KeyboardInterrupt can
            # still abort the wait.
            pass
        print(".", end="")
        time.sleep(1)
    print("[FAIL]")


In [None]:
def get_embedding(embedding_index, embedding_data, id):
    """Return the embedding for document ``id``, reusing the loaded file if possible.

    Embeddings are read from ``{embedding_path}/{start}.npz`` (array ``embs``),
    where ``start`` is ``id`` rounded down to a multiple of 100000. The file is
    loaded only when ``embedding_data`` is ``None`` or ``embedding_index``
    differs from ``start``. When ``distance`` is ``"innerproduct"`` the vector
    is divided by its L2 norm.

    Args:
        embedding_index: start offset of the file currently held in ``embedding_data``.
        embedding_data: array loaded by a previous call, or ``None``.
        id: document id whose embedding is wanted.

    Returns:
        ``(start, embedding_data, embedding)`` -- pass the first two back in on
        the next call to keep reusing the loaded file.
    """
    file_start = int(id / 100000) * 100000
    already_loaded = embedding_data is not None and embedding_index == file_start
    if not already_loaded:
        with np.load(embedding_path / f"{file_start}.npz") as archive:
            embedding_data = archive["embs"]
    vector = embedding_data[id - file_start]
    if distance == "innerproduct":
        vector = vector / np.linalg.norm(vector)
    return file_start, embedding_data, vector


# Section names gathered by insert_data(); later used to build filter clauses.
section_values = []

def get_section_values(df, min_count=10000):
    """Return the ``section`` values that occur on at least ``min_count`` rows.

    Args:
        df: DataFrame with at least the columns ``id`` and ``section``.
        min_count: minimum number of rows (with a non-null ``id``) per section.

    Returns:
        List of qualifying section values, in the sorted order produced by groupby.
    """
    rows_per_section = df.groupby("section")["id"].count()
    frequent = rows_per_section[rows_per_section >= min_count]
    return frequent.index.tolist()


def insert_data(bulk_size, max_size):
    """Bulk-index up to ``max_size`` documents with their embeddings.

    Reads every ``*.parquet`` file under ``content_path`` in sorted order,
    appends each file's frequent sections to the module-level
    ``section_values``, and sends documents to ``_bulk`` in batches of
    ``bulk_size``. Document ids are the running count starting at 1.

    NOTE(review): once ``max_size`` is reached only the per-row loop stops;
    the remaining parquet files are still read and still contribute to
    ``section_values``. ``section_values`` also keeps growing if this function
    is called again. Confirm whether both are intended before changing them.

    Args:
        bulk_size: number of documents per ``_bulk`` request.
        max_size: maximum number of documents to index.

    Returns:
        dict with ``execution_time`` (wall-clock seconds for the whole call)
        and ``process_time`` (sum of the server-reported ``took`` values, in
        seconds).
    """
    start_time = time.time()

    def send_data(lines, pos):
        """POST one NDJSON batch; return the server's ``took`` in seconds (0 on HTTP failure)."""
        print(F"Sending {int(len(lines)/2)} docs ({pos}/{max_size})... ", end="")
        response = requests.post(f"http://{opensearch_host}:{opensearch_port}/_bulk",
                                 headers={"Content-Type": "application/x-ndjson"},
                                 data="\n".join(lines) + "\n")
        if response.status_code != 200:
            print(f"[FAIL] 0 {response.status_code} {response.text}")
            return 0
        body = json.loads(response.text)
        t = body.get("took") / 1000
        print(f"[OK] {t}")
        # _bulk answers 200 even when individual actions were rejected; the
        # top-level "errors" flag is the only sign that documents are missing.
        if body.get("errors"):
            print("[WARN] bulk response reported item-level errors")
        return t

    bulk_data = []
    total_time = 0
    count = 0
    embedding_index = -1
    embedding_data = None
    for content_file in sorted(content_path.glob("*.parquet")):
        df = pd.read_parquet(content_file)
        section_values.extend(get_section_values(df))
        for i,row in df.iterrows():
            if count >= max_size:
                break
            embedding_index, embedding_data, embedding = get_embedding(embedding_index, embedding_data, row.id)
            count += 1
            # Each document contributes two NDJSON lines: the action and the source.
            bulk_data.append(json.dumps({
                "index": {
                    "_index": index_name,
                    "_id" : count
                }
            }))
            bulk_data.append(json.dumps({
                "page_id": row.pageid,
                "rev_id": row.revid,
                #"title": row.title,
                "section": row.section,
                #"text": row.text,
                "embedding": embedding.tolist(),
            }))
            if len(bulk_data) >= bulk_size * 2:
                total_time += send_data(bulk_data, count)
                bulk_data = []

    # Send the final, partially filled batch.
    if len(bulk_data) > 0:
        total_time += send_data(bulk_data, count)

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 ("SS.ss") is needed for zero padding; a width of 2 is always
    # exceeded by the two decimals and would print e.g. "3.50".
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f} ({timedelta(seconds=total_time)})")
    return {
        "execution_time": execution_time,
        "process_time": total_time,
    }


In [None]:
def search(query, offset=0, size=120, explain=False, track_total_hits=False):
    """Run one search against the benchmark index, sorted by descending score.

    The request disables the request cache and does not return ``_source``.

    Args:
        query: query clause placed under ``"query"`` in the request body.
        offset: value for ``from``.
        size: maximum number of hits to return.
        explain: when true, request and return per-hit explanations.
        track_total_hits: added to the request body only when truthy.

    Returns:
        ``(took, num_hits, total_hits, ids, scores, explanations)``. On a
        non-200 response or a timed-out search a message is printed and
        ``(-1, -1, -1, [], [], [])`` is returned.
    """
    request_body = {
        "query": query,
        "size": size,
        "_source": False,
        "from": offset,
        "explain": explain,
        "sort": [
            {"_score": "desc"},
        ],
    }
    if track_total_hits:
        request_body["track_total_hits"] = track_total_hits
    search_url = f"http://{opensearch_host}:{opensearch_port}/{index_name}/_search?request_cache=false"
    response = requests.post(search_url, json=request_body)

    if response.status_code != 200:
        print(f"[FAIL][{response.status_code}] {response.text}")
        return -1, -1, -1, [], [], []

    payload = json.loads(response.text)
    if payload.get("timed_out"):
        print(f"[TIMEOUT] {query}")
        return -1, -1, -1, [], [], []

    hit_list = payload.get("hits").get("hits")
    doc_ids = [hit.get("_id") for hit in hit_list]
    hit_scores = [hit.get("_score") for hit in hit_list]
    if explain:
        hit_explanations = [hit.get("_explanation") for hit in hit_list]
    else:
        hit_explanations = []
    total = payload.get("hits").get("total").get("value")
    return payload.get("took"), len(hit_list), total, doc_ids, hit_scores, hit_explanations


In [None]:
def search_with_knn_queries(output_path, pre_filter=None, explain=False, track_total_hits=False, max_size=10000, page_size=100, offset=0, max_error_count=100):
    """Run k-NN queries built from stored embeddings and log each result as gzip JSONL.

    Query vectors come from ``{embedding_path}/{pos}.npz`` (array ``embs``),
    starting at ``pos = offset`` and advancing by 100000 per file, wrapping to
    0 once ``pos`` exceeds ``num_of_docs``. When ``distance`` is
    ``"innerproduct"`` each vector is divided by its L2 norm first. One JSON
    line is written per successful query; failed queries are counted but not
    written.

    Args:
        output_path: destination ``.jsonl.gz`` file; overwritten if it exists.
            Missing parent directories are created.
        pre_filter: optional iterator; ``next(pre_filter)`` supplies the
            ``filter`` clause of each query.
        explain: forwarded to ``search()``.
        track_total_hits: forwarded to ``search()``.
        max_size: number of successful queries to record.
        page_size: used both as the k-NN ``k`` and as the search ``size``.
        offset: start offset (file name) of the first embedding file to read.
        max_error_count: stop after this many failed searches.
    """
    print("Sending knn queries...")
    start_time = time.time()
    # gzip.open() does not create directories, so make sure e.g. "output/" exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    pos = offset
    count = 0
    running = True
    error_count = 0
    with gzip.open(output_path, "wt", encoding="utf-8") as f:
        while running:
            with np.load(embedding_path / f"{pos}.npz") as data:
                embedding_data = data["embs"]
            for embedding in embedding_data:
                if count >= max_size:
                    running = False
                    break
                if distance == "innerproduct":
                    embedding = embedding / np.linalg.norm(embedding)
                query = {
                    "knn": {
                        "embedding": {
                            "vector": embedding.tolist(),
                            "k": page_size,
                        }
                    }
                }
                if pre_filter is not None:
                    # NOTE: an exhausted iterator raises StopIteration here.
                    query["knn"]["embedding"]["filter"] = next(pre_filter)
                took, hits, total_hits, ids, scores, explanations = search(query=query, size=page_size, explain=explain, track_total_hits=track_total_hits)
                if took == -1:
                    # search() signals failure/timeout with took == -1.
                    error_count += 1
                    if error_count >= max_error_count:
                        running = False
                        break
                    continue
                result = {
                    "id": (count + 1),
                    "took": took,
                    "hits": hits,
                    "total_hits": total_hits,
                    "ids": ids,
                    "scores": scores,
                    "explanations": explanations,
                }
                f.write(json.dumps(result, ensure_ascii=False))
                f.write("\n")
                count += 1
                if count % 10000 == 0:
                    print(f"Sent {count}/{max_size} queries.")

            # Move on to the next embedding file, wrapping around at the end of the dataset.
            pos += 100000
            if pos > num_of_docs:
                pos = 0

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 ("SS.ss") is needed for zero padding; a width of 2 is always
    # exceeded by the two decimals and would print e.g. "3.50".
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def get_output_filename(opensearch_version, name, explain=False, track_total_hits=False):
    """Build the path of a result file under ``output/``.

    Args:
        opensearch_version: version string; dots are replaced by underscores.
        name: label of the run.
        explain: append ``_explain`` when true.
        track_total_hits: append ``_all`` when true.

    Returns:
        ``output/opensearch<version>_<name>[_explain][_all].jsonl.gz``
    """
    version_tag = opensearch_version.replace(".", "_")
    parts = [f"output/opensearch{version_tag}_{name}"]
    if explain:
        parts.append("_explain")
    if track_total_hits:
        parts.append("_all")
    parts.append(".jsonl.gz")
    return "".join(parts)


In [None]:
def print_took_and_total_hits(filename):
    """Summarise a gzip JSONL result file written by the query runs.

    Reads the ``took``, ``hits`` and ``total_hits`` fields of every line,
    prints ``describe()`` of the three columns as a markdown table, and
    returns summary statistics.

    Args:
        filename: path of the ``.jsonl.gz`` file to read.

    Returns:
        dict with ``num_of_queries`` plus ``took`` (mean, std, min, 25%, 50%,
        75%, 90%, 99%, max) and ``hits`` (mean, std, min, 25%, 50%, 75%, max).
    """
    with gzip.open(filename, "rt", encoding="utf-8") as f:
        records = [json.loads(line) for line in f]
    df = pd.DataFrame({
        field: [record.get(field) for record in records]
        for field in ("took", "hits", "total_hits")
    })
    print(df.describe().to_markdown())

    def summarize(column, percentiles):
        """Collect mean/std/min, the requested percentiles, then max, in that key order."""
        stats = {"mean": column.mean(), "std": column.std(), "min": column.min()}
        for label, fraction in percentiles:
            stats[label] = column.quantile(fraction)
        stats["max"] = column.max()
        return stats

    quartiles = [("25%", 0.25), ("50%", 0.5), ("75%", 0.75)]
    return {
        "num_of_queries": len(df),
        "took": summarize(df.took, quartiles + [("90%", 0.9), ("99%", 0.99)]),
        "hits": summarize(df.hits, quartiles),
    }

In [None]:
def save_results():
    """Write the benchmark configuration and collected results to ``results.json``.

    ``numpy.int64`` values are converted to ``int``; any other object the JSON
    encoder cannot handle is written as ``null``.
    """
    def to_jsonable(value):
        """Fallback for json.dump: int for numpy.int64, otherwise None."""
        if isinstance(value, np.int64):
            return int(value)
        return None

    with open("results.json", "wt", encoding="utf-8") as f:
        report = {
            "version": opensearch_version,
            "java_heap": opensearch_heap,
            "settings": dataset_config,
            "results": results,
        }
        json.dump(report, f, ensure_ascii=False, default=to_jsonable)


In [None]:
# Run `docker system prune -f`, start the OpenSearch container and poll its
# HTTP endpoint until it answers (or the retries run out).
prune_docker()
print(f"<<<OpenSearch {opensearch_version}>>>")
run_opensearch()
wait_for_opensearch()

In [None]:
# Docker container stats and disk usage right after startup, before the index exists.
print_docker_container_stats()
print_docker_system_df()

In [None]:
# Create the k-NN index with the default 1 shard and 0 replicas.
create_index()

In [None]:
# Container stats, index listing and disk usage for the still-empty index.
print_docker_container_stats()
print_indices()
print_docker_system_df()

In [None]:
# Bulk-index up to `index_size` documents, keep the timings under
# results["indexing"], then flush the index.
results["indexing"] = insert_data(bulk_size=bulk_size, max_size=index_size)
flush_index()

In [None]:
# Disk usage, container stats and index listing after indexing.
print_docker_system_df()
print_docker_container_stats()
print_indices()

In [None]:
# Close the index, wait 10 seconds, reopen it and refresh before measuring.
# Force-merging to a single segment is currently disabled:
# forcemerge_index()
close_index()
time.sleep(10)
open_index()
refresh_index()

In [None]:
# Record per-container resource usage after indexing.
# print_docker_system_df() only prints and returns None, so the stats dict
# stored in the results has to come from print_docker_container_stats().
print_docker_system_df()
results["indexing"]["container"] = print_docker_container_stats()
print_indices()

In [None]:
# Unfiltered k-NN benchmark for each page size.
# The first call is a 1000-query warmup reading embeddings from file offset 0.
# The second is the measured run (default max_size) reading from
# offset=index_size. Both write to the same file, so the measured run
# overwrites the warmup output before it is summarised.
for page_size in [10, 100, 400]:
    print(f"page size: {page_size}")
    filename = get_output_filename(opensearch_version, f"knn_{page_size}", explain=False, track_total_hits=False)
    search_with_knn_queries(filename, page_size=page_size, max_size=1000) # warmup
    search_with_knn_queries(filename, page_size=page_size, explain=False, track_total_hits=False, offset=index_size)
    results[f"top_{page_size}"] = print_took_and_total_hits(filename)

In [None]:
def pre_filter_generator():
    """Yield ``term`` filters on ``section``, cycling through ``section_values`` forever.

    Yields nothing at all when ``section_values`` is empty at the first
    ``next()`` call. ``section_values`` is re-read at the start of every cycle.
    """
    if not section_values:
        return
    while True:
        yield from ({"term": {"section": section}} for section in section_values)

# Filtered k-NN benchmark: same procedure as the unfiltered run, but every
# query carries a `section` term filter taken from a fresh generator.
# NOTE: if section_values is empty the generator yields nothing, and the
# next() call inside search_with_knn_queries() raises StopIteration.
results["num_of_filtered_words"] = len(section_values)
for page_size in [10, 100, 400]:
    print(f"page size: {page_size}")
    filename = get_output_filename(opensearch_version, f"knn_{page_size}_filtered", explain=False, track_total_hits=False)
    search_with_knn_queries(filename, page_size=page_size, max_size=1000, pre_filter=pre_filter_generator()) # warmup
    search_with_knn_queries(filename, page_size=page_size, explain=False, track_total_hits=False, offset=index_size, pre_filter=pre_filter_generator())
    results[f"top_{page_size}_filtered"] = print_took_and_total_hits(filename)

In [None]:
# Disk usage, container stats and index listing after the query runs.
print_docker_system_df()
print_docker_container_stats()
print_indices()

In [None]:
# Write the configuration and all collected measurements to results.json.
save_results()

In [None]:
# Tear down: delete the index, then `docker stop` the container.
delete_index()
stop_opensearch()