In [None]:
import gzip
import json
import os
import pprint
import subprocess
import time
from datetime import timedelta
from pathlib import Path

import numpy as np
import pandas as pd
import requests

In [None]:
# Name given to the Docker container that runs Elasticsearch for this benchmark.
es_name = "benchmark_es"
# Host and host-side port used for every HTTP request to Elasticsearch
# (the port is mapped to the container's port 9200 when the container starts).
es_host = "localhost"
es_port = 9211
# Tag of the `elasticsearch` Docker image to run; also used in output file names.
es_version = "8.13.0"
# Heap size interpolated into the ES_JAVA_OPTS value passed to the container.
es_heap = "2g"

In [None]:
def get_dataset_config(target_name):
    """Look up the benchmark preset registered under ``target_name``.

    Two presets exist, "default" and "passages-c400-jawiki-20230403"; they
    differ only in ``index_size`` (number of documents to index).

    Args:
        target_name: Name of the preset to fetch.

    Returns:
        A dict of dataset/index settings, or None when the name is unknown.
    """
    dataset_dir = "dataset/passages-c400-jawiki-20230403"

    def build_preset(index_size):
        # Everything except index_size is shared by all presets.
        return {
            "content_path": dataset_dir,
            "embedding_path": f"{dataset_dir}/multilingual-e5-base-passage",
            "num_of_docs": 5555583,
            "index_size": index_size,
            "bulk_size": 1000,
            "index_name": "contents",
            "distance": "dot_product",  # alternative: "cosine"
            "dimension": 768,
            "hnsw_m": 48,
            "hnsw_ef_construction": 200,
            "hnsw_ef": 100,
        }

    presets = {
        "default": build_preset(100000),
        "passages-c400-jawiki-20230403": build_preset(5000000),
    }
    return presets.get(target_name)

# Pick the benchmark preset; the TARGET_CONFIG environment variable overrides
# the "default" preset.
target_config = os.getenv("TARGET_CONFIG", "default")
dataset_config = get_dataset_config(target_config)
if dataset_config is None:
    # get_dataset_config returns None for unknown names; stop here with a clear
    # message instead of failing later with an AttributeError on None.
    raise ValueError(f"Unknown TARGET_CONFIG: {target_config!r}")
pprint.pprint(dataset_config)

# Dataset location and sizing.
content_path = Path(dataset_config["content_path"])      # directory globbed for *.parquet files
embedding_path = Path(dataset_config["embedding_path"])  # directory holding the <offset>.npz archives
num_of_docs = int(dataset_config["num_of_docs"])
index_size = int(dataset_config["index_size"])           # number of documents to index
bulk_size = int(dataset_config["bulk_size"])             # documents per _bulk request

# Index and HNSW parameters.
index_name = dataset_config["index_name"]
distance = dataset_config["distance"]
dimension = int(dataset_config["dimension"])
hnsw_m = int(dataset_config["hnsw_m"])
hnsw_ef_construction = int(dataset_config["hnsw_ef_construction"])
hnsw_ef = int(dataset_config["hnsw_ef"])

In [None]:
def run_elasticsearch():
    """Start the Elasticsearch container in the background with ``docker run -d``.

    The container is named ``es_name``, publishes container port 9200 on host
    port ``es_port``, runs as a single node with security disabled, and uses
    image ``elasticsearch:{es_version}``. Prints "[OK]" when the docker command
    exits with 0; otherwise prints "[FAIL]" followed by the captured stdout
    and stderr.
    """
    print(f"Starting {es_name}... ", end="")
    env_settings = [
        "discovery.type=single-node",
        "bootstrap.memory_lock=true",
        "xpack.security.enabled=false",
        f"ES_JAVA_OPTS=-Xms{es_heap}",
    ]
    command = ["docker", "run", "-d", "--name", es_name, "-p", f"{es_port}:9200"]
    for setting in env_settings:
        command += ["-e", setting]
    command.append(f"elasticsearch:{es_version}")

    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode != 0:
        print("[FAIL]")
        print("STDOUT:", completed.stdout, "STDERR:", completed.stderr, sep="\n")
        return
    print("[OK]")


In [None]:
def stop_elasticsearch():
    """Stop the container named ``es_name`` with ``docker stop``.

    Prints "[OK]" when the docker command exits with 0; otherwise prints
    "[FAIL]" followed by the captured stdout and stderr.
    """
    print(f"Stopping {es_name}... ", end="")
    completed = subprocess.run(["docker", "stop", es_name], capture_output=True, text=True)
    if completed.returncode != 0:
        print("[FAIL]")
        print("STDOUT:", completed.stdout, "STDERR:", completed.stderr, sep="\n")
        return
    print("[OK]")


In [None]:
def prune_docker():
    """Run ``docker system prune -f`` and report the outcome.

    Prints "[OK]" when the docker command exits with 0; otherwise prints
    "[FAIL]" followed by the captured stdout and stderr.
    """
    print("Cleaning up... ", end="")
    completed = subprocess.run(["docker", "system", "prune", "-f"], capture_output=True, text=True)
    if completed.returncode != 0:
        print("[FAIL]")
        print("STDOUT:", completed.stdout, "STDERR:", completed.stderr, sep="\n")
        return
    print("[OK]")


In [None]:
def print_docker_system_df():
    """Print the output of ``docker system df``.

    Prints the command's stdout when it exits with 0, otherwise its stderr.
    """
    completed = subprocess.run(["docker", "system", "df"], capture_output=True, text=True)
    print(completed.stderr if completed.returncode != 0 else completed.stdout)


In [None]:
def print_docker_container_stats():
    """Print a one-shot snapshot from ``docker container stats --no-stream``.

    Prints the command's stdout when it exits with 0, otherwise its stderr.
    """
    completed = subprocess.run(
        ["docker", "container", "stats", "--no-stream"],
        capture_output=True,
        text=True,
    )
    print(completed.stderr if completed.returncode != 0 else completed.stdout)


In [None]:
def create_index(number_of_shards=1, number_of_replicas=0):
    """Create the benchmark index with an HNSW-indexed dense_vector field.

    The mapping declares the document fields plus an ``embedding`` field of
    ``dimension`` dims, using similarity ``distance`` and ``int8_hnsw`` index
    options (``hnsw_m``, ``hnsw_ef_construction``). The embedding is excluded
    from ``_source``.

    Args:
        number_of_shards: Value for the index's ``number_of_shards`` setting.
        number_of_replicas: Value for the index's ``number_of_replicas`` setting.

    Prints "[OK]" on HTTP 200, otherwise "[FAIL]" and the response body.
    """
    print(F"Creating {index_name}... ", end="")

    embedding_field = {
        "type": "dense_vector",
        "dims": dimension,
        "index": True,
        "similarity": distance,
        "index_options": {
            "type": "int8_hnsw",
            "m": hnsw_m,
            "ef_construction": hnsw_ef_construction,
        },
    }
    field_mappings = {
        "page_id": {"type": "integer"},
        "rev_id": {"type": "integer"},
        "title": {"type": "text"},
        "section": {"type": "keyword"},
        "text": {"type": "text"},
        "embedding": embedding_field,
    }
    index_body = {
        "mappings": {
            "_source": {"excludes": ["embedding"]},
            "properties": field_mappings,
        },
        "settings": {
            "index": {
                "number_of_shards": number_of_shards,
                "number_of_replicas": number_of_replicas,
            },
        },
    }

    reply = requests.put(
        f"http://{es_host}:{es_port}/{index_name}",
        headers={"Content-Type": "application/json"},
        json=index_body,
    )
    if reply.status_code != 200:
        print(f"[FAIL]\n{reply.text}")
        return
    print("[OK]")


In [None]:
def delete_index():
    """Delete the benchmark index.

    Prints "[OK]" on HTTP 200, otherwise "[FAIL]" and the response body.
    """
    print(F"Deleting {index_name}... ", end="")
    reply = requests.delete(f"http://{es_host}:{es_port}/{index_name}")
    outcome = "[OK]" if reply.status_code == 200 else f"[FAIL]\n{reply.text}"
    print(outcome)


In [None]:
def forcemerge_index():
    """Force-merge the benchmark index down to one segment and time the call.

    POSTs to ``_forcemerge?max_num_segments=1`` with a one-hour client timeout.
    Prints "[OK]" on HTTP 200, otherwise "[FAIL]" and the response body, then
    prints the elapsed wall-clock time.
    """
    print(F"Merging {index_name}... ", end="")
    began = time.time()
    merge_url = f"http://{es_host}:{es_port}/{index_name}/_forcemerge?max_num_segments=1"
    reply = requests.post(merge_url, timeout=60*60)
    if reply.status_code != 200:
        print(f"[FAIL]\n{reply.text}")
    else:
        print("[OK]")

    elapsed = time.time() - began
    hh, leftover = divmod(elapsed, 3600)
    mm, ss = divmod(leftover, 60)
    print(f"Execution Time: {int(hh):02d}:{int(mm):02d}:{ss:02.2f}")


In [None]:
def flush_index():
    """Flush the benchmark index and time the call.

    POSTs to ``_flush`` with a ten-minute client timeout. Prints "[OK]" on
    HTTP 200, otherwise "[FAIL]" and the response body, then prints the
    elapsed wall-clock time.
    """
    print(F"Flushing {index_name}... ", end="")
    began = time.time()
    flush_url = f"http://{es_host}:{es_port}/{index_name}/_flush"
    reply = requests.post(flush_url, timeout=10*60)
    if reply.status_code != 200:
        print(f"[FAIL]\n{reply.text}")
    else:
        print("[OK]")

    elapsed = time.time() - began
    hh, leftover = divmod(elapsed, 3600)
    mm, ss = divmod(leftover, 60)
    print(f"Execution Time: {int(hh):02d}:{int(mm):02d}:{ss:02.2f}")


In [None]:
def refresh_index():
    """Refresh the benchmark index and time the call.

    POSTs to ``_refresh`` with a ten-minute client timeout. Prints "[OK]" on
    HTTP 200, otherwise "[FAIL]" and the response body, then prints the
    elapsed wall-clock time.
    """
    print(F"Refreshing {index_name}... ", end="")
    began = time.time()
    refresh_url = f"http://{es_host}:{es_port}/{index_name}/_refresh"
    reply = requests.post(refresh_url, timeout=10*60)
    if reply.status_code != 200:
        print(f"[FAIL]\n{reply.text}")
    else:
        print("[OK]")

    elapsed = time.time() - began
    hh, leftover = divmod(elapsed, 3600)
    mm, ss = divmod(leftover, 60)
    print(f"Execution Time: {int(hh):02d}:{int(mm):02d}:{ss:02.2f}")


In [None]:
def open_index():
    """Open the benchmark index and time the call.

    POSTs to ``_open`` with a ten-minute client timeout. Prints "[OK]" on
    HTTP 200, otherwise "[FAIL]" and the response body, then prints the
    elapsed wall-clock time.
    """
    print(F"Opening {index_name}... ", end="")
    began = time.time()
    open_url = f"http://{es_host}:{es_port}/{index_name}/_open"
    reply = requests.post(open_url, timeout=10*60)
    if reply.status_code != 200:
        print(f"[FAIL]\n{reply.text}")
    else:
        print("[OK]")

    elapsed = time.time() - began
    hh, leftover = divmod(elapsed, 3600)
    mm, ss = divmod(leftover, 60)
    print(f"Execution Time: {int(hh):02d}:{int(mm):02d}:{ss:02.2f}")


In [None]:
def close_index():
    """Close the benchmark index and time the call.

    POSTs to ``_close`` with a ten-minute client timeout. Prints "[OK]" on
    HTTP 200, otherwise "[FAIL]" and the response body, then prints the
    elapsed wall-clock time.
    """
    print(F"Closing {index_name}... ", end="")
    began = time.time()
    close_url = f"http://{es_host}:{es_port}/{index_name}/_close"
    reply = requests.post(close_url, timeout=10*60)
    if reply.status_code != 200:
        print(f"[FAIL]\n{reply.text}")
    else:
        print("[OK]")

    elapsed = time.time() - began
    hh, leftover = divmod(elapsed, 3600)
    mm, ss = divmod(leftover, 60)
    print(f"Execution Time: {int(hh):02d}:{int(mm):02d}:{ss:02.2f}")


In [None]:
def print_indices():
    """Print the raw text returned by the ``_cat/indices`` endpoint."""
    cat_url = f"http://{es_host}:{es_port}/_cat/indices"
    print(requests.get(cat_url).text)


In [None]:
def wait_for_elasticsearch(retry_count=60):
    """Poll the Elasticsearch root endpoint until it answers with HTTP 200.

    Makes up to ``retry_count`` attempts, printing a dot and sleeping one
    second after each unsuccessful attempt. Prints "[OK]" and returns as soon
    as a 200 arrives, or prints "[FAIL]" once the attempts are exhausted.

    Args:
        retry_count: Maximum number of polling attempts.
    """
    print(f"Waiting for {es_name}", end="")
    root_url = f"http://{es_host}:{es_port}/"
    for _ in range(retry_count):
        try:
            # The timeout keeps one stalled connection from blocking far
            # longer than the retry budget suggests.
            response = requests.get(root_url, timeout=5)
            if response.status_code == 200:
                print("[OK]")
                return
        except requests.exceptions.RequestException:
            # Expected while the node is still booting (connection refused,
            # reset, timeout). Anything else, including KeyboardInterrupt,
            # propagates instead of being silently swallowed.
            pass
        print(".", end="")
        time.sleep(1)
    print("[FAIL]")


In [None]:
def get_embedding(embedding_index, embedding_data, id):
    """Return the embedding for document ``id``, reusing the cached archive when possible.

    Embeddings are read from ``{embedding_path}/{start}.npz`` under the key
    "embs", where ``start`` is ``id`` rounded down to a multiple of 100000.
    The archive is loaded only when no data is cached or the cached archive
    starts at a different offset.

    Args:
        embedding_index: Start offset of the archive held in ``embedding_data``.
        embedding_data: Cached array from the last loaded archive, or None.
        id: Document id whose embedding is wanted.

    Returns:
        Tuple ``(start offset, archive array, embedding)``; the first two are
        meant to be passed back in on the next call. When ``distance`` is
        "dot_product" the embedding is cast to float32 and scaled to unit norm.
    """
    emb_index = int(id / 100000) * 100000
    cache_hit = embedding_data is not None and embedding_index == emb_index
    if not cache_hit:
        with np.load(embedding_path / f"{emb_index}.npz") as archive:
            embedding_data = archive["embs"]

    vector = embedding_data[id - emb_index]
    if distance == "dot_product":
        vector = vector.astype(np.float32)
        vector = vector / np.linalg.norm(vector)
    return emb_index, embedding_data, vector


def insert_data(bulk_size, max_size):
    """Index up to ``max_size`` embeddings into ``index_name`` via the _bulk API.

    Parquet files under ``content_path`` are read in sorted order; for each row
    the embedding is looked up by the row's ``id`` and queued as an index
    action whose ``_id`` is the running document count. A request is sent every
    ``bulk_size`` documents, plus one final request for any remainder.

    Args:
        bulk_size: Number of documents per _bulk request.
        max_size: Maximum number of documents to index.

    Prints per-request status and, at the end, the wall-clock time together
    with the summed server-side "took" time.
    """
    start_time = time.time()

    def send_data(payload, pos):
        """POST one NDJSON batch; return the server-reported seconds (0 on HTTP failure)."""
        print(F"Sending {int(len(payload)/2)} docs ({pos}/{max_size})... ", end="")
        response = requests.post(f"http://{es_host}:{es_port}/_bulk",
                                 headers={"Content-Type": "application/x-ndjson"},
                                 data="\n".join(payload) + "\n")
        if response.status_code != 200:
            print(f"[FAIL] 0 {response.status_code} {response.text}")
            return 0
        body = json.loads(response.text)
        t = body.get("took") / 1000
        if body.get("errors"):
            # _bulk answers 200 even when individual items were rejected;
            # the top-level "errors" flag is the only indication.
            failed = sum(1 for item in body.get("items", [])
                         if item.get("index", {}).get("error"))
            print(f"[FAIL] {t} {failed} docs rejected")
        else:
            print(f"[OK] {t}")
        return t

    total_time = 0
    count = 0
    embedding_index = -1
    embedding_data = None
    bulk_data = []
    for content_file in sorted(content_path.glob("*.parquet")):
        if count >= max_size:
            # Enough documents queued: do not read the remaining files.
            break
        df = pd.read_parquet(content_file)
        # itertuples avoids building a Series per row while still exposing
        # columns as attributes (row.id, and row.title etc. if re-enabled).
        for row in df.itertuples(index=False):
            if count >= max_size:
                break
            embedding_index, embedding_data, embedding = get_embedding(embedding_index, embedding_data, row.id)
            count += 1
            bulk_data.append(json.dumps({
                "index": {
                    "_index": index_name,
                    "_id" : count
                }
            }))
            bulk_data.append(json.dumps({
                #"page_id": row.pageid,
                #"rev_id": row.revid,
                #"title": row.title,
                #"section": row.section,
                #"text": row.text,
                "embedding": embedding.tolist(),
            }))
            # Each document contributes two NDJSON lines (action + source).
            if len(bulk_data) >= bulk_size * 2:
                total_time += send_data(bulk_data, count)
                bulk_data = []

    if len(bulk_data) > 0:
        total_time += send_data(bulk_data, count)

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:02.2f} ({timedelta(seconds=total_time)})")


In [None]:
def search(query, offset=0, size=120, explain=False, track_total_hits=False):
    """Run one search against ``index_name`` with the request cache disabled.

    Args:
        query: Query clause placed under "query" in the request body.
        offset: Value for "from".
        size: Value for "size".
        explain: Whether to request per-hit explanations.
        track_total_hits: Sent as "track_total_hits" only when truthy.

    Returns:
        Tuple ``(took, number of hits, total hit value, ids, scores,
        explanations)``. On a non-200 response or a timed-out search a message
        is printed and ``(-1, -1, -1, [], [], [])`` is returned.
    """
    request_body = {
        "query": query,
        "size": size,
        "_source": False,
        "from": offset,
        "explain": explain,
        "sort": [{"_score": "desc"}],
    }
    if track_total_hits:
        request_body["track_total_hits"] = track_total_hits

    search_url = f"http://{es_host}:{es_port}/{index_name}/_search?request_cache=false"
    response = requests.post(search_url, json=request_body)
    if response.status_code != 200:
        print(f"[FAIL][{response.status_code}] {response.text}")
        return -1, -1, -1, [], [], []

    payload = json.loads(response.text)
    if payload.get("timed_out"):
        print(f"[TIMEOUT] {query}")
        return -1, -1, -1, [], [], []

    hit_list = payload.get("hits").get("hits")
    doc_ids = []
    doc_scores = []
    for hit in hit_list:
        doc_ids.append(hit.get("_id"))
        doc_scores.append(hit.get("_score"))
    hit_explanations = [hit.get("_explanation") for hit in hit_list] if explain else []
    total_value = payload.get("hits").get("total").get("value")
    return payload.get("took"), len(hit_list), total_value, doc_ids, doc_scores, hit_explanations


In [None]:
def search_with_knn_queries(output_path, explain=False, track_total_hits=False, max_size=10000, page_size=100, offset=0):
    """Send kNN queries built from stored embeddings and log each result as JSONL.

    Query vectors are read from ``{embedding_path}/{pos}.npz`` (key "embs"),
    starting at ``pos = offset`` and advancing by 100000 per archive, wrapping
    to 0 after the last one. Each successful search is written as one JSON
    line to the gzip file at ``output_path``; failed searches are skipped and
    do not count towards ``max_size``.

    Args:
        output_path: Destination ``.jsonl.gz`` file (overwritten).
        explain: Forwarded to ``search``.
        track_total_hits: Forwarded to ``search``.
        max_size: Number of successful queries to record.
        page_size: Number of hits requested per query.
        offset: Name (start offset) of the first embedding archive to read.
    """
    print("Sending knn queries...")
    start_time = time.time()
    # gzip.open does not create missing directories, so make sure the output
    # directory exists before opening the file.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    chunk_size = 100000  # step between consecutive archive names
    # num_candidates must be at least the requested page size; loop-invariant.
    num_candidates = max(hnsw_ef, page_size)
    pos = offset
    count = 0
    running = True
    with gzip.open(output_path, "wt", encoding="utf-8") as f:
        while running:
            with np.load(embedding_path / f"{pos}.npz") as data:
                embedding_data = data["embs"]
            for embedding in embedding_data:
                if count >= max_size:
                    running = False
                    break
                if distance == "dot_product":
                    embedding = embedding.astype(np.float32)
                    embedding = embedding / np.linalg.norm(embedding)
                query = {
                    "knn": {
                        "field": "embedding",
                        "query_vector": embedding.tolist(),
                        "num_candidates": num_candidates
                    }
                }
                took, hits, total_hits, ids, scores, explanations = search(query=query, size=page_size, explain=explain, track_total_hits=track_total_hits)
                # print(f"{took}, {total_hits}, {ids}, {scores}")
                if took == -1:
                    # NOTE(review): failures are skipped without counting, so a
                    # permanently failing server keeps this loop running.
                    print(f"norm: {np.linalg.norm(embedding)}")
                    continue
                result = {
                    "id": (count + 1),
                    "took": took,
                    "hits": hits,
                    "total_hits": total_hits,
                    "ids": ids,
                    "scores": scores,
                    "explanations": explanations,
                }
                f.write(json.dumps(result, ensure_ascii=False))
                f.write("\n")
                count += 1
                if count % 10000 == 0:
                    print(f"Sent {count}/{max_size} queries.")

            pos += chunk_size
            # ">=" rather than ">": when num_of_docs is an exact multiple of
            # chunk_size there is no archive named num_of_docs to load.
            if pos >= num_of_docs:
                pos = 0

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:02.2f}")


In [None]:
def get_output_filename(es_version, name, explain=False, track_total_hits=False):
    """Build the result file path for one benchmark run.

    Args:
        es_version: Elasticsearch version string; dots become underscores.
        name: Run label placed after the version.
        explain: Append "_explain" when true.
        track_total_hits: Append "_all" when true.

    Returns:
        A path of the form ``output/es<version>_<name>[_explain][_all].jsonl.gz``.
    """
    version_tag = es_version.replace('.', '_')
    suffixes = []
    if explain:
        suffixes.append("_explain")
    if track_total_hits:
        suffixes.append("_all")
    return "".join([f"output/es{version_tag}_{name}", *suffixes, ".jsonl.gz"])


In [None]:
def print_took_and_total_hits(filename):
    """Print summary statistics for a gzip-compressed JSONL result file.

    Reads the "took", "hits" and "total_hits" values from every record and
    prints ``DataFrame.describe()`` for those three columns as a markdown table.

    Args:
        filename: Path of the ``.jsonl.gz`` file to summarise.
    """
    tooks = []
    hits = []
    total_hits = []
    with gzip.open(filename, "rt", encoding="utf-8") as f:
        # Stream line by line instead of loading the whole file with readlines().
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            tooks.append(obj.get("took"))
            # Append per record: assigning here would replace the lists with
            # the last record's scalars and skew the statistics.
            hits.append(obj.get("hits"))
            total_hits.append(obj.get("total_hits"))
    df = pd.DataFrame({"took": tooks, "hits": hits, "total_hits": total_hits})
    print(df.describe().to_markdown())

In [None]:
# Prune unused Docker data, start the Elasticsearch container, and poll until
# it answers over HTTP.
prune_docker()
print(f"<<<Elasticsearch {es_version}>>>")
run_elasticsearch()
wait_for_elasticsearch()

In [None]:
# Docker disk usage and container stats right after startup, before any index exists.
print_docker_system_df()
print_docker_container_stats()

In [None]:
# Create the benchmark index with the function's defaults (1 shard, 0 replicas).
create_index()

In [None]:
# Docker disk usage, container stats and index listing for the still-empty index.
print_docker_system_df()
print_docker_container_stats()
print_indices()

In [None]:
# Index `index_size` documents in batches of `bulk_size`, flush the index,
# then report Docker disk usage.
insert_data(bulk_size=bulk_size, max_size=index_size)
flush_index()
print_docker_system_df()

In [None]:
# Close the index, wait 10 seconds, reopen it and refresh it before searching.
# NOTE(review): presumably the close/open cycle is meant to reset index state
# before measuring — confirm. The force-merge step is currently disabled.
# forcemerge_index()
close_index()
time.sleep(10)
open_index()
refresh_index()

In [None]:
# Docker disk usage, container stats and index listing after indexing.
print_docker_system_df()
print_docker_container_stats()
print_indices()

In [None]:
# Run the kNN benchmark once per page size and print summary statistics.
for page_size in [10, 100, 400]:
    print(f"page size: {page_size}")
    filename = get_output_filename(es_version, f"knn_{page_size}", explain=False, track_total_hits=False)
    # The warmup writes to the same file, which the measured run below
    # overwrites (the file is opened in "wt" mode each time).
    search_with_knn_queries(filename, page_size=page_size, max_size=1000) # warmup
    # Measured run: default max_size, reading query vectors from the archive
    # named by `index_size` onwards.
    search_with_knn_queries(filename, page_size=page_size, explain=False, track_total_hits=False, offset=index_size)
    print_took_and_total_hits(filename)

In [None]:
# Docker disk usage, container stats and index listing after the search benchmark.
print_docker_system_df()
print_docker_container_stats()
print_indices()

In [None]:
# Tear down: delete the benchmark index, then stop the Elasticsearch container.
delete_index()
stop_elasticsearch()