In [None]:
import subprocess
import time
import gzip
import pandas as pd
import json
import requests

In [None]:
# Name passed to `docker run --name` and `docker stop` for the benchmark container.
opensearch_name = "benchmark_opensearch"
# Host and host-side port used for every HTTP request in this notebook;
# the port is published onto the container's port 9200 by run_opensearch().
opensearch_host = "localhost"
opensearch_port = 9211
# Interpolated into the JVM `-Xmx` option via OPENSEARCH_JAVA_OPTS.
opensearch_heap = "4g"

# Parquet inputs: insert_data() indexes the products file,
# search_with_match_queries() takes its queries from the examples file.
product_path = "dataset/shopping_queries_dataset_products.parquet"
query_path = "dataset/shopping_queries_dataset_examples.parquet"

# Index that all index/search helpers in this notebook operate on.
index_name = "products"

In [None]:
def run_opensearch(version=None):
    """Start a single-node OpenSearch container in the background (``docker run -d``).

    Args:
        version: Tag of the ``opensearchproject/opensearch`` image to start.
            When omitted, the module-level ``opensearch_version`` is used so
            existing zero-argument calls keep working; a ``NameError`` is
            raised if that global has not been bound yet.

    Prints ``[OK]`` when ``docker run`` exits with status 0, otherwise ``[FAIL]``
    followed by the captured stdout and stderr of the docker command.
    """
    if version is None:
        # Backward-compatible fallback: the benchmark loop binds this global.
        version = opensearch_version

    print(f"Starting {opensearch_name}... ", end="")
    docker_cmd = [
        "docker", "run", "-d",
        "--name", opensearch_name,
        # Publish the container's REST port (9200) on the configured host port.
        "-p", f"{opensearch_port}:9200",
        "-e", "discovery.type=single-node",
        "-e", "bootstrap.memory_lock=true",
        "-e", "plugins.security.disabled=true",
        "-e", f"OPENSEARCH_JAVA_OPTS=-Xmx{opensearch_heap}",
        f"opensearchproject/opensearch:{version}"
    ]
    result = subprocess.run(docker_cmd, capture_output=True, text=True)
    if result.returncode == 0:
        print("[OK]")
    else:
        print("[FAIL]")
        print("STDOUT:")
        print(result.stdout)
        print("STDERR:")
        print(result.stderr)

In [None]:
def stop_opensearch():
    """Stop the benchmark container with ``docker stop`` and report the outcome.

    Prints ``[OK]`` on exit status 0; otherwise prints ``[FAIL]`` and the
    captured stdout/stderr of the docker command.
    """
    print(f"Stopping {opensearch_name}... ", end="")
    completed = subprocess.run(
        ["docker", "stop", opensearch_name],
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print("[FAIL]")
        # Dump both streams, each under its own heading.
        for heading, stream in (("STDOUT:", completed.stdout), ("STDERR:", completed.stderr)):
            print(heading)
            print(stream)
        return
    print("[OK]")

In [None]:
def prune_docker():
    """Run ``docker system prune -f`` and report the outcome.

    Note that this is a host-wide Docker command, not one scoped to the
    benchmark container. Prints ``[OK]`` on exit status 0; otherwise prints
    ``[FAIL]`` and the captured stdout/stderr of the docker command.
    """
    print("Cleaning up... ", end="")
    completed = subprocess.run(
        ["docker", "system", "prune", "-f"],
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print("[FAIL]")
        # Dump both streams, each under its own heading.
        for heading, stream in (("STDOUT:", completed.stdout), ("STDERR:", completed.stderr)):
            print(heading)
            print(stream)
        return
    print("[OK]")

In [None]:
def create_index(number_of_shards=1, number_of_replicas=0):
    """Create the benchmark index with explicit field mappings.

    Args:
        number_of_shards: Value sent as ``settings.index.number_of_shards``.
        number_of_replicas: Value sent as ``settings.index.number_of_replicas``.

    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body.
    """
    print(F"Creating {index_name}... ", end="")

    # (field name, mapping type) in the order they are sent to the server.
    field_types = [
        ("product_brand", "keyword"),
        ("product_bullet_point", "text"),
        ("product_color", "keyword"),
        ("product_description", "text"),
        ("product_id", "keyword"),
        ("product_locale", "keyword"),
        ("product_title", "text"),
    ]

    # Disabled experiment kept for reference: product_description and
    # product_title could use a custom analyzer declared under settings.analysis:
    #   "analyzer": {"whitespace_analyzer": {"type": "custom",
    #                                        "tokenizer": "whitespace",
    #                                        "filter": ["lowercase", "symbol_filter"]}}
    #   "filter": {"symbol_filter": {"type": "pattern_replace",
    #                                "pattern": "[^\\w\\s]",
    #                                "replacement": ""}}
    # and referenced from the field mapping with "analyzer": "whitespace_analyzer".
    index_body = {
        "mappings": {
            "properties": {field: {"type": kind} for field, kind in field_types},
        },
        "settings": {
            "index": {
                "number_of_shards": number_of_shards,
                "number_of_replicas": number_of_replicas,
            },
        },
    }

    index_url = f"http://{opensearch_host}:{opensearch_port}/{index_name}"
    response = requests.put(
        index_url,
        headers={"Content-Type": "application/json"},
        json=index_body,
    )
    if response.status_code != 200:
        print(f"[FAIL]\n{response.text}")
        return
    print("[OK]")


In [None]:
def delete_index():
    """Delete the benchmark index and report the outcome.

    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body.
    """
    print(F"Deleting {index_name}... ", end="")
    index_url = f"http://{opensearch_host}:{opensearch_port}/{index_name}"
    response = requests.delete(index_url)
    outcome = "[OK]" if response.status_code == 200 else f"[FAIL]\n{response.text}"
    print(outcome)


In [None]:
def forcemerge_index():
    """Force-merge the index down to one segment and print the elapsed time.

    Sends ``POST /{index_name}/_forcemerge?max_num_segments=1`` with a
    10-minute client timeout. Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]``
    and the response body, then the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Merging {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_forcemerge?max_num_segments=1",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 covers "SS.ss"; the previous width of 2 was shorter than the
    # formatted value, so seconds below 10 were never zero-padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def flush_index():
    """Flush the index and print the elapsed time.

    Sends ``POST /{index_name}/_flush`` with a 10-minute client timeout.
    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body,
    then the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Flushing {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_flush",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 covers "SS.ss"; the previous width of 2 was shorter than the
    # formatted value, so seconds below 10 were never zero-padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def refresh_index():
    """Refresh the index and print the elapsed time.

    Sends ``POST /{index_name}/_refresh`` with a 10-minute client timeout.
    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body,
    then the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Refreshing {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_refresh",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 covers "SS.ss"; the previous width of 2 was shorter than the
    # formatted value, so seconds below 10 were never zero-padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def open_index():
    """Open the index and print the elapsed time.

    Sends ``POST /{index_name}/_open`` with a 10-minute client timeout.
    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body,
    then the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Opening {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_open",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 covers "SS.ss"; the previous width of 2 was shorter than the
    # formatted value, so seconds below 10 were never zero-padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def close_index():
    """Close the index and print the elapsed time.

    Sends ``POST /{index_name}/_close`` with a 10-minute client timeout.
    Prints ``[OK]`` on HTTP 200, otherwise ``[FAIL]`` and the response body,
    then the wall-clock duration as ``HH:MM:SS.ss``.
    """
    print(F"Closing {index_name}... ", end="")
    start_time = time.time()
    response = requests.post(f"http://{opensearch_host}:{opensearch_port}/{index_name}/_close",
                            timeout=10*60)
    if response.status_code == 200:
        print("[OK]")
    else:
        print(f"[FAIL]\n{response.text}")

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 covers "SS.ss"; the previous width of 2 was shorter than the
    # formatted value, so seconds below 10 were never zero-padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def print_indices():
    """Print the raw text returned by the ``_cat/indices`` endpoint."""
    cat_url = f"http://{opensearch_host}:{opensearch_port}/_cat/indices"
    print(requests.get(cat_url).text)


In [None]:
def wait_for_opensearch(retry_count=60):
    """Poll the OpenSearch root endpoint until it answers with HTTP 200.

    Args:
        retry_count: Maximum number of attempts. Each unsuccessful attempt
            prints a dot and is followed by a one-second sleep.

    Returns:
        True as soon as the node answers with status 200 (``[OK]`` is
        printed), False when every attempt failed (``[FAIL]`` is printed).
    """
    print(f"Waiting for {opensearch_name}", end="")
    root_url = f"http://{opensearch_host}:{opensearch_port}/"
    for _ in range(retry_count):
        try:
            # A per-request timeout keeps one hung connection from stalling
            # the whole polling loop.
            response = requests.get(root_url, timeout=5)
            if response.status_code == 200:
                print("[OK]")
                return True
        except requests.exceptions.RequestException:
            # Connection errors and timeouts mean "not ready yet"; anything
            # else (e.g. KeyboardInterrupt) is allowed to propagate.
            pass
        print(".", end="")
        time.sleep(1)
    print("[FAIL]")
    return False


In [None]:
def insert_data(bulk_size=10000):
    """Index every row of the products parquet file through the ``_bulk`` API.

    Args:
        bulk_size: Number of documents sent per bulk request.

    Each document is indexed with ``product_id`` as its ``_id``. For every
    batch a progress line is printed, ending in ``[OK]`` or ``[FAIL]``; the
    total wall-clock duration is printed at the end as ``HH:MM:SS.ss``.
    """
    df = pd.read_parquet(product_path)
    total_docs = len(df)
    bulk_url = f"http://{opensearch_host}:{opensearch_port}/_bulk"

    def send_data(lines, pos):
        """Send one NDJSON batch (action line + source line per document)."""
        print(F"Sending {len(lines) // 2} docs ({pos}/{total_docs})... ", end="")
        response = requests.post(bulk_url,
                                 headers={"Content-Type": "application/json"},
                                 data="\n".join(lines) + "\n")
        if response.status_code != 200:
            print(f"[FAIL]\n{response.text}")
            return
        # _bulk answers HTTP 200 even when individual actions are rejected;
        # the top-level "errors" flag is what reports per-item failures.
        body = response.json()
        if body.get("errors"):
            failed = [item for item in body.get("items", [])
                      if item.get("index", {}).get("error")]
            print(f"[FAIL] {len(failed)} docs rejected")
            if failed:
                # One sample is enough to diagnose; avoid flooding the output.
                print(json.dumps(failed[0]))
        else:
            print("[OK]")

    start_time = time.time()
    bulk_data = []
    pos = 0
    # itertuples avoids building a Series per row, and enumerate makes the
    # progress counter independent of the DataFrame's index labels.
    for pos, row in enumerate(df.itertuples(index=False), 1):
        bulk_data.append(json.dumps({
            "index": {
                "_index": index_name,
                "_id": row.product_id
            }
        }))
        bulk_data.append(json.dumps({
            "product_brand": row.product_brand,
            "product_bullet_point": row.product_bullet_point,
            "product_color": row.product_color,
            "product_description": row.product_description,
            "product_id": row.product_id,
            "product_locale": row.product_locale,
            "product_title": row.product_title,
        }))
        # Two lines per document, hence the factor of 2.
        if len(bulk_data) >= bulk_size * 2:
            send_data(bulk_data, pos)
            bulk_data = []

    # Flush the final, partially filled batch.
    if bulk_data:
        send_data(bulk_data, pos)

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 covers "SS.ss"; the previous width of 2 was shorter than the
    # formatted value, so seconds below 10 were never zero-padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def search(query, offset=0, size=120, explain=False, track_total_hits=False):
    """Run one search against the benchmark index with the request cache disabled.

    Args:
        query: Query DSL object placed under the ``query`` key.
        offset: Value sent as ``from``.
        size: Value sent as ``size``.
        explain: Value sent as ``explain``; also controls whether per-hit
            explanations are collected from the response.
        track_total_hits: Added to the request only when truthy.

    Returns:
        A tuple ``(took, total_hits, product_ids, scores, explanations)``.
        On a non-200 response or when the response reports ``timed_out``,
        a message is printed and ``(-1, -1, [], [], [])`` is returned.
    """
    request_body = {
        "query": query,
        "size": size,
        "_source": False,
        "from": offset,
        "explain": explain,
        # Score first, product_id as a tie-breaker.
        "sort": [
            {"_score": "desc"},
            {"product_id": "asc"},
        ],
    }
    if track_total_hits:
        request_body["track_total_hits"] = track_total_hits

    search_url = f"http://{opensearch_host}:{opensearch_port}/{index_name}/_search?request_cache=false"
    response = requests.post(search_url, json=request_body)

    if response.status_code != 200:
        print(f"[FAIL][{response.status_code}] {query}")
        return -1, -1, [], [], []

    payload = json.loads(response.text)
    if payload.get("timed_out"):
        print(f"[TIMEOUT] {query}")
        return -1, -1, [], [], []

    hit_list = payload.get("hits").get("hits")
    product_ids = [hit.get("_id") for hit in hit_list]
    scores = [hit.get("_score") for hit in hit_list]
    if explain:
        explanations = [hit.get("_explanation") for hit in hit_list]
    else:
        explanations = []
    total_hits = payload.get("hits").get("total").get("value")
    return payload.get("took"), total_hits, product_ids, scores, explanations


In [None]:
def search_with_match_queries(output_path, explain=False, track_total_hits=False, max_size=-1):
    """Run a match query for every unique benchmark query and log the results.

    Each whitespace-separated word of a query becomes two ``match`` clauses
    (on ``product_title`` and ``product_description``) combined in a ``bool``
    query with ``minimum_should_match`` of 1. One JSON line per query is
    written to a gzip-compressed text file.

    Args:
        output_path: Destination ``.jsonl.gz`` file; it is opened in write
            mode, so an existing file is overwritten.
        explain: Forwarded to ``search``.
        track_total_hits: Forwarded to ``search``.
        max_size: Stop after this many queries; a negative value means all
            unique queries.
    """
    print("Sending match queries...")
    df = pd.read_parquet(query_path)
    start_time = time.time()
    with gzip.open(output_path, "wt", encoding="utf-8") as f:
        queries = df["query"].unique()
        if max_size < 0:
            max_size = len(queries)
        for i, q in enumerate(queries, 1):
            # str.split() with no argument never yields empty strings, so no
            # extra length check is needed.
            match_queries = [
                {"match": {field: {"query": word}}}
                for word in q.split()
                for field in ("product_title", "product_description")
            ]
            query = {
                "bool": {
                    "minimum_should_match": 1,
                    "should": match_queries,
                }
            }
            took, total_hits, ids, scores, explanations = search(query=query, explain=explain, track_total_hits=track_total_hits)
            result = {
                "query": q,
                "took": took,
                "total_hits": total_hits,
                "ids": ids,
                "scores": scores,
                "explanations": explanations,
            }
            f.write(json.dumps(result, ensure_ascii=False))
            f.write("\n")
            if i % 10000 == 0:
                print(f"Sent {i}/{max_size} queries.")
            if i >= max_size:
                break

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 covers "SS.ss"; the previous width of 2 was shorter than the
    # formatted value, so seconds below 10 were never zero-padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def get_output_filename(opensearch_version, name, explain=False, track_total_hits=False):
    """Build the result-file path for one benchmark run.

    Args:
        opensearch_version: Version string; dots are replaced by underscores.
        name: Label of the query set, appended after the version.
        explain: Adds the ``_explain`` suffix when truthy.
        track_total_hits: Adds the ``_all`` suffix when truthy.

    Returns:
        A path of the form ``output/opensearch<version>_<name>[...].jsonl.gz``.
    """
    version_tag = opensearch_version.replace('.', '_')
    parts = [f"output/opensearch{version_tag}_{name}"]
    if explain:
        parts.append("_explain")
    if track_total_hits:
        parts.append("_all")
    parts.append(".jsonl.gz")
    return "".join(parts)


In [None]:
# Image tags of opensearchproject/opensearch to benchmark, in run order.
# The comparison cell at the end of the notebook compares adjacent entries.
opensearch_versions = [
    "1.3.13",
    "2.0.1",
    "2.1.0",
    "2.2.1",
    "2.3.0",
    "2.4.1",
    "2.5.0",
    "2.6.0",
    "2.7.0",
    "2.8.0",
    "2.9.0",
    "2.10.0",
    "2.11.1",
]
# Flags forwarded to the search helpers and encoded into the output filename
# by get_output_filename().
explain = False
track_total_hits = False

In [None]:
# Benchmark driver: for each version, start a fresh container, index the
# products, run the queries, then tear everything down.
# NOTE: the loop variable `opensearch_version` is also read as a global by
# run_opensearch() to pick the image tag.
for opensearch_version in opensearch_versions:
    # Presumably clears the previous (stopped) container so the fixed
    # --name can be reused — TODO confirm.
    prune_docker()
    print(f"<<<OpenSearch {opensearch_version}>>>")
    run_opensearch()
    wait_for_opensearch()
    create_index()
    print_indices()
    insert_data()
    # close_index()
    # time.sleep(10)
    # open_index()
    refresh_index()
    print_indices()
    # NOTE(review): the output path starts with "output/"; that directory
    # presumably has to exist before this runs — confirm.
    filename = get_output_filename(opensearch_version, "match", explain=explain, track_total_hits=track_total_hits)
    # The warmup pass writes to the same file; the full pass below reopens it
    # in write mode and overwrites the warmup results.
    search_with_match_queries(filename, max_size=1000) # warmup
    search_with_match_queries(filename, explain=explain, track_total_hits=track_total_hits)
    delete_index()
    stop_opensearch()
    # Pause before the next iteration starts with prune_docker().
    time.sleep(10)

In [None]:
def load_output(opensearch_version, name, explain, track_total_hits, max_results):
    """Load one benchmark result file into a dict keyed by query string.

    Args:
        opensearch_version: Version whose result file is loaded.
        name: Query-set label used in the filename.
        explain: Selects the ``_explain`` variant of the file.
        track_total_hits: Selects the ``_all`` variant of the file.
        max_results: Stop after reading this many lines.

    Returns:
        ``{query: {"took": ..., "total_hits": ..., "ids": [...]}}``. When a
        query appears on more than one line, the last one read wins.
    """
    output_dict = {}
    filename = get_output_filename(opensearch_version, name, explain=explain, track_total_hits=track_total_hits)
    print(f"Loading the result for {filename} ", end="")
    count = 0
    start_time = time.time()
    with gzip.open(filename, "rt", encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            output_dict[obj.get("query")] = {
                "took": obj.get("took"),
                "total_hits": obj.get("total_hits"),
                "ids": obj.get("ids"),
            }
            count += 1
            if count >= max_results:
                break
    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 covers "SS.ss"; the previous width of 2 was shorter than the
    # formatted value, so seconds below 10 were never zero-padded.
    print(f"{int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")
    return output_dict

def compare_ids(opensearch_version1, opensearch_version2, name, explain=False, track_total_hits=False, max_results=130193):
    """Compare the ranked result ids of two versions, query by query.

    For every query in the first version's output, the id lists are compared
    position by position. A query counts as different when any position
    differs, when the two lists have different lengths, or when the query is
    missing from the second version's output.

    Args:
        opensearch_version1: Baseline version.
        opensearch_version2: Version compared against the baseline.
        name: Query-set label used in the filenames.
        explain: Selects the ``_explain`` variant of the files.
        track_total_hits: Selects the ``_all`` variant of the files.
        max_results: Maximum number of lines read from each file.

    Prints one ``[DIFF]`` line per differing query (``.`` = same id at that
    rank, ``X`` = different or missing), one ``[MISSING]`` line per query
    absent from the second output, and a final summary.
    """
    output1 = load_output(opensearch_version1, name, explain=explain, track_total_hits=track_total_hits, max_results=max_results)
    output2 = load_output(opensearch_version2, name, explain=explain, track_total_hits=track_total_hits, max_results=max_results)
    total_count = 0
    error_count = 0
    for q, data1 in output1.items():
        total_count += 1
        data2 = output2.get(q)
        if data2 is None:
            # Previously this crashed with AttributeError on None.get().
            print(f"[MISSING] {q}")
            error_count += 1
            continue
        ids1 = data1.get("ids") or []
        ids2 = data2.get("ids") or []
        checks = [1 if id1 == id2 else 0 for id1, id2 in zip(ids1, ids2)]
        # zip() stops at the shorter list, so pad with mismatches; otherwise a
        # failed search (empty id list) would be counted as identical.
        checks.extend([0] * abs(len(ids1) - len(ids2)))
        if 0 in checks:
            print("[DIFF] " + "".join("." if x == 1 else "X" for x in checks) + f" {q}")
            error_count += 1

    if total_count == 0:
        # Nothing was loaded; avoid dividing by zero in the percentage.
        print("0/0 results are different.")
        return
    print(f"{error_count}/{total_count} results are different. ({100*error_count/total_count:3.2f}%)")

In [None]:
# Compare every version against its immediate successor in the list.
for older_version, newer_version in zip(opensearch_versions, opensearch_versions[1:]):
    print(f"<<<{older_version} vs {newer_version}>>>")
    compare_ids(older_version, newer_version, "match", explain=explain, track_total_hits=track_total_hits)