In [None]:
import gzip
import json
import os
import pprint
import subprocess
import time
from datetime import timedelta
from pathlib import Path

import numpy as np
import pandas as pd
import requests

In [None]:
# Name given to the Docker container (used by `docker run --name` and `docker stop`).
weaviate_name = "benchmark_weaviate"
# Host used when building the REST / GraphQL URLs.
weaviate_host = "localhost"
# Host port that is published to the container's port 8080.
weaviate_port = 8091
# Tag of the cr.weaviate.io/semitechnologies/weaviate image to run.
weaviate_version = "1.24.4"

In [None]:
def get_dataset_config(target_name):
    """Look up the benchmark settings registered under ``target_name``.

    Two profiles are known, "default" and "passages-c400-jawiki-20230403";
    they share every value except ``index_size``.

    Args:
        target_name: Name of the profile to return.

    Returns:
        A freshly built dict of settings, or ``None`` when the name is unknown.
    """
    dataset_dir = "dataset/passages-c400-jawiki-20230403"

    def build_profile(index_size):
        # Everything except the number of documents to index is common.
        return {
            "content_path": dataset_dir,
            "embedding_path": f"{dataset_dir}/multilingual-e5-base-passage",
            "num_of_docs": 5555583,
            "index_size": index_size,
            "bulk_size": 1000,
            "index_name": "Content",
            "distance": "dot",
            "dimension": 768,
            "hnsw_m": 48,
            "hnsw_ef_construction": 200,
            "hnsw_ef": 100,
        }

    profiles = {
        "default": build_profile(100000),
        "passages-c400-jawiki-20230403": build_profile(5000000),
    }
    return profiles.get(target_name)

# Select the benchmark profile named by the TARGET_CONFIG environment variable
# (falling back to "default") and unpack it into module-level settings.
target_config = os.getenv("TARGET_CONFIG", "default")
dataset_config = get_dataset_config(target_config)
if dataset_config is None:
    # get_dataset_config returns None for unknown names; stop with a clear
    # message instead of an AttributeError on the first lookup below.
    raise ValueError(f"Unknown TARGET_CONFIG: {target_config!r}")
pprint.pprint(dataset_config)

# Dataset location and sizing.
content_path = Path(dataset_config.get("content_path"))
embedding_path = Path(dataset_config.get("embedding_path"))
num_of_docs = int(dataset_config.get("num_of_docs"))
index_size = int(dataset_config.get("index_size"))
bulk_size = int(dataset_config.get("bulk_size"))

# Class name and HNSW vector index parameters.
index_name = dataset_config.get("index_name")
distance = dataset_config.get("distance")
dimension = int(dataset_config.get("dimension"))
hnsw_m = int(dataset_config.get("hnsw_m"))
hnsw_ef_construction = int(dataset_config.get("hnsw_ef_construction"))
hnsw_ef = int(dataset_config.get("hnsw_ef"))

In [None]:
def run_weaviate():
    """Start the Weaviate container in detached mode and print the outcome.

    The container is named ``weaviate_name``, publishes container port 8080 on
    host port ``weaviate_port`` and runs image tag ``weaviate_version``.
    Prints "[OK]" when ``docker run`` exits with 0; otherwise prints "[FAIL]"
    followed by the captured stdout and stderr.
    """
    print(f"Starting {weaviate_name}... ", end="")
    port_mapping = f"{weaviate_port}:8080"
    image = f"cr.weaviate.io/semitechnologies/weaviate:{weaviate_version}"
    # No "sudo" prefix and no data volume mount are passed to docker.
    completed = subprocess.run(
        ["docker", "run", "-d", "--name", weaviate_name, "-p", port_mapping, image],
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print("[FAIL]")
        print("STDOUT:")
        print(completed.stdout)
        print("STDERR:")
        print(completed.stderr)
        return
    print("[OK]")


In [None]:
def stop_weaviate():
    """Stop the container named ``weaviate_name`` and print the outcome.

    Prints "[OK]" when ``docker stop`` exits with 0; otherwise prints "[FAIL]"
    followed by the captured stdout and stderr.
    """
    print(f"Stopping {weaviate_name}... ", end="")
    completed = subprocess.run(
        ["docker", "stop", weaviate_name],
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print("[FAIL]")
        print("STDOUT:")
        print(completed.stdout)
        print("STDERR:")
        print(completed.stderr)
        return
    print("[OK]")


In [None]:
def prune_docker():
    """Run ``docker system prune -f`` and print the outcome.

    Prints "[OK]" when the command exits with 0; otherwise prints "[FAIL]"
    followed by the captured stdout and stderr.
    """
    print("Cleaning up... ", end="")
    completed = subprocess.run(
        ["docker", "system", "prune", "-f"],
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        print("[FAIL]")
        print("STDOUT:")
        print(completed.stdout)
        print("STDERR:")
        print(completed.stderr)
        return
    print("[OK]")


In [None]:
def print_docker_system_df():
    """Print the output of ``docker system df``.

    The captured stdout is printed when the command exits with 0, the captured
    stderr otherwise.
    """
    completed = subprocess.run(
        ["docker", "system", "df"],
        capture_output=True,
        text=True,
    )
    print(completed.stdout if completed.returncode == 0 else completed.stderr)


In [None]:
def create_index():
    """Create the ``index_name`` class through POST /v1/schema.

    The class uses an HNSW vector index configured from the module-level
    ``distance``, ``hnsw_m``, ``hnsw_ef`` and ``hnsw_ef_construction`` values,
    and declares three int properties and three textual properties.
    Prints "[OK]" on HTTP 200, otherwise "[FAIL]" and the response body.
    """
    print(F"Creating {index_name}... ", end="")

    int_fields = ["doc_id", "pageId", "revId"]
    textual_fields = [("section", "string"), ("text", "text"), ("title", "text")]
    properties = [{"name": field, "dataType": ["int"]} for field in int_fields]
    properties += [
        {"name": field, "dataType": [data_type], "indexInverted": True}
        for field, data_type in textual_fields
    ]

    class_definition = {
        "class": index_name,
        "vectorIndexType": "hnsw",
        "vectorIndexConfig": {
            "distance": distance,
            "maxConnections": hnsw_m,
            "ef": hnsw_ef,
            "efConstruction": hnsw_ef_construction,
        },
        "properties": properties,
    }
    response = requests.post(
        f"http://{weaviate_host}:{weaviate_port}/v1/schema",
        headers={"Content-Type": "application/json"},
        json=class_definition,
    )
    if response.status_code != 200:
        print(f"[FAIL]\n{response.text}")
        return
    print("[OK]")


In [None]:
def delete_index():
    """Delete the ``index_name`` class through DELETE /v1/schema/<class>.

    Prints "[OK]" on HTTP 200, otherwise "[FAIL]" and the response body.
    """
    print(F"Deleting {index_name}... ", end="")
    endpoint = f"http://{weaviate_host}:{weaviate_port}/v1/schema/{index_name}"
    response = requests.delete(endpoint)
    if response.status_code != 200:
        print(f"[FAIL]\n{response.text}")
        return
    print("[OK]")


In [None]:
def print_indices():
    """Print the number of objects stored in the ``index_name`` class.

    Sends a GraphQL ``Aggregate`` query and prints ``count: <n>``. Prints
    ``count: FAILED`` when the HTTP status is not 200 or when the response
    carries no aggregate data for the class (for example a GraphQL error
    payload).
    """
    response = requests.post(
        f"http://{weaviate_host}:{weaviate_port}/v1/graphql",
        headers={"Content-Type": "application/json"},
        json={
            # Query the configured class rather than a hard-coded name so the
            # request always matches the key read from the response below.
            "query": f"{{ Aggregate {{ {index_name} {{ meta {{ count }} }} }} }}"
        },
    )
    if response.status_code == 200:
        obj = json.loads(response.text)
        # "data" / "Aggregate" may be missing or null when the query failed.
        aggregate = (obj.get("data") or {}).get("Aggregate") or {}
        buckets = aggregate.get(index_name)
        if buckets:
            count = (buckets[0].get("meta") or {}).get("count")
            print(f"count: {count}")
            return
    print("count: FAILED")


In [None]:
def wait_for_weaviate(retry_count=60):
    """Poll GET /v1/nodes until the first node reports status "HEALTHY".

    Prints a dot per unsuccessful attempt and sleeps one second between
    attempts. Prints " [OK]" and returns as soon as the node is healthy, or
    prints " [FAIL]" after ``retry_count`` attempts.

    Args:
        retry_count: Maximum number of polling attempts.
    """
    print(f"Waiting for {weaviate_name}", end="")
    for _ in range(retry_count):
        try:
            # A timeout keeps a single hung connection from blocking the
            # whole retry loop.
            response = requests.get(
                f"http://{weaviate_host}:{weaviate_port}/v1/nodes",
                timeout=5,
            )
            if response.status_code == 200:
                nodes = json.loads(response.text).get("nodes") or []
                if nodes and nodes[0].get("status") == "HEALTHY":
                    print(" [OK]")
                    return
        except Exception:
            # Server not reachable or not ready yet: retry. Catching Exception
            # (not a bare except) lets KeyboardInterrupt abort the wait.
            pass
        print(".", end="")
        time.sleep(1)
    print(" [FAIL]")


In [None]:
def get_embedding(embedding_index, embedding_data, id):
    """Return the embedding for document ``id``, reusing the loaded chunk if possible.

    Embeddings are read from ``<embedding_path>/<chunk_start>.npz`` (array
    "embs"), where ``chunk_start`` is ``id`` rounded down to a multiple of
    100000. The file is only loaded when ``embedding_data`` is ``None`` or
    ``embedding_index`` differs from ``chunk_start``. When ``distance`` is
    "dot" the vector is divided by its norm.

    Args:
        embedding_index: Chunk start of the currently loaded ``embedding_data``.
        embedding_data: Currently loaded chunk array, or ``None``.
        id: Document id to look up.

    Returns:
        Tuple ``(chunk_start, embedding_data, embedding)`` to pass back in on
        the next call.
    """
    chunk_size = 100000
    chunk_start = int(id / chunk_size) * chunk_size
    needs_load = embedding_data is None or embedding_index != chunk_start
    if needs_load:
        with np.load(embedding_path / f"{chunk_start}.npz") as archive:
            embedding_data = archive["embs"]
    vector = embedding_data[id - chunk_start]
    if distance == "dot":
        vector = vector / np.linalg.norm(vector)
    return chunk_start, embedding_data, vector


def insert_data(bulk_size, max_size):
    """Bulk-load up to ``max_size`` vectors into the ``index_name`` class.

    Rows are read from the ``*.parquet`` files under ``content_path`` in
    sorted order; each row's embedding is fetched with ``get_embedding`` and
    sent to POST /v1/batch/objects in batches of ``bulk_size``. Finally the
    wall-clock time and the summed request time of successful batches are
    printed.

    Args:
        bulk_size: Number of objects per batch request.
        max_size: Maximum number of objects to send in total.
    """
    start_time = time.time()

    docs = []

    def send_data(pos):
        """POST the pending batch; return the request time in seconds (0 on failure)."""
        print(f"Sending {len(docs)} docs ({pos}/{max_size})... ", end="")
        now = time.time()
        response = requests.post(
            f"http://{weaviate_host}:{weaviate_port}/v1/batch/objects",
            headers={"Content-Type": "application/json"},
            json={
                "objects": docs,
            },
        )
        if response.status_code == 200:
            t = time.time() - now
            print(f"[OK] {t}")
            return t
        print(f"[FAIL] 0 {response.status_code} {response.text}")
        return 0

    total_time = 0
    count = 0
    embedding_index = -1
    embedding_data = None
    for content_file in sorted(content_path.glob("*.parquet")):
        if count >= max_size:
            # Limit reached: do not read the remaining parquet files.
            break
        df = pd.read_parquet(content_file)
        # itertuples avoids building a Series per row as iterrows does.
        for row in df.itertuples(index=False):
            if count >= max_size:
                break
            embedding_index, embedding_data, embedding = get_embedding(
                embedding_index, embedding_data, row.id
            )
            count += 1
            docs.append({
                "class": index_name,
                "properties": {
                    # doc_id is the 1-based insertion counter, not row.id.
                    "doc_id": count,
                    # The remaining schema properties are not sent:
                    # "pageId": row.pageid,
                    # "revId": row.revid,
                    # "title": row.title,
                    # "section": row.section,
                    # "text": row.text,
                },
                "vector": embedding.tolist(),
            })
            if len(docs) >= bulk_size:
                total_time += send_data(count)
                # Rebind (not clear) so send_data sees a new empty batch.
                docs = []

    if len(docs) > 0:
        total_time += send_data(count)

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 zero-pads "SS.ss"; the former width of 2 never padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f} ({timedelta(seconds=total_time)})")


In [None]:
def search(query):
    """Send one GraphQL query and summarize the hits for ``index_name``.

    Args:
        query: GraphQL query string, e.g. as built by ``create_query``.

    Returns:
        Tuple ``(took_ms, hit_count, doc_ids, distances)`` where ``took_ms``
        is the round-trip time of the HTTP request in milliseconds. On a
        non-200 status, or when the response holds no result list for the
        class (e.g. a GraphQL error payload), the failure is printed and
        ``(-1, -1, [], [])`` is returned.
    """
    now = time.time()
    response = requests.post(
        f"http://{weaviate_host}:{weaviate_port}/v1/graphql",
        json={
            "query": query
        },
    )
    took = time.time() - now

    if response.status_code == 200:
        obj = json.loads(response.text)
        # GraphQL errors still come back as HTTP 200 with "data"/"Get"
        # missing or null, so guard every level of the lookup.
        results = ((obj.get("data") or {}).get("Get") or {}).get(index_name)
        if results is not None:
            doc_ids = [x.get("doc_id") for x in results]
            scores = [x.get("_additional").get("distance") for x in results]
            return took * 1000, len(results), doc_ids, scores
    print(f"[FAIL][{response.status_code}] {response.text}")
    return -1, -1, [], []


In [None]:
def create_query(embedding, page_size):
    """Build the GraphQL ``Get`` query for a nearest-vector search.

    Args:
        embedding: Query vector as a JSON-serializable list of numbers.
        page_size: Value for the query's ``limit``.

    Returns:
        The query string; it requests ``doc_id`` and ``_additional.distance``
        for the ``index_name`` class.
    """
    query_lines = [
        "{",
        "  Get {",
        f"    {index_name} (",
        f"      limit: {page_size}",
        "      nearVector: {",
        f"        vector: {json.dumps(embedding)}",
        "      }",
        "    ) {",
        "      doc_id",
        "      _additional {",
        "        distance",
        "      }",
        "    }",
        "  }",
        "}",
    ]
    return "\n".join(query_lines)


def search_with_knn_queries(output_path, max_size=10000, page_size=100, offset=0):
    """Run kNN queries and write one JSON line per successful query.

    Query vectors are read from ``<embedding_path>/<pos>.npz`` (array "embs"),
    starting at ``pos = offset`` and advancing by 100000 per file, wrapping to
    0 once ``pos`` reaches ``num_of_docs``. Vectors are normalized when
    ``distance`` is "dot". Failed searches are skipped and do not count
    towards ``max_size``.

    Args:
        output_path: Destination of the gzip-compressed JSONL results.
        max_size: Number of successful queries to record.
        page_size: ``limit`` used for each query.
        offset: Name (start position) of the first embedding file to read.
    """
    print("Sending knn queries...")
    start_time = time.time()
    pos = offset
    count = 0
    if isinstance(output_path, (str, os.PathLike)):
        # gzip.open does not create missing parent directories.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with gzip.open(output_path, "wt", encoding="utf-8") as f:
        # Checking the limit here avoids loading one more embedding file when
        # max_size is reached exactly at the end of a file.
        while count < max_size:
            with np.load(embedding_path / f"{pos}.npz") as data:
                embedding_data = data["embs"]
            for embedding in embedding_data:
                if count >= max_size:
                    break
                if distance == "dot":
                    embedding = embedding / np.linalg.norm(embedding)
                query = create_query(embedding.tolist(), page_size)
                took, hits, ids, scores = search(query)
                if took == -1:
                    # search() already reported the failure.
                    continue
                result = {
                    "id": (count + 1),
                    "took": took,
                    "hits": hits,
                    "ids": ids,
                    "scores": scores,
                }
                f.write(json.dumps(result, ensure_ascii=False))
                f.write("\n")
                count += 1
                if count % 10000 == 0:
                    print(f"Sent {count}/{max_size} queries.")

            pos += 100000
            # ">=": a file starting exactly at num_of_docs would hold no documents.
            if pos >= num_of_docs:
                pos = 0

    execution_time = time.time() - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 5 zero-pads "SS.ss"; the former width of 2 never padded.
    print(f"Execution Time: {int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}")


In [None]:
def get_output_filename(weaviate_version, name, explain=False, track_total_hits=False):
    """Build the result file path for one benchmark run.

    Args:
        weaviate_version: Version string; its dots are replaced by underscores.
        name: Suffix identifying the run.
        explain: Accepted but not used.
        track_total_hits: Accepted but not used.

    Returns:
        ``output/weaviate<version>_<name>.jsonl.gz``
    """
    version_tag = weaviate_version.replace(".", "_")
    return f"output/weaviate{version_tag}_{name}.jsonl.gz"


In [None]:
def print_took_and_total_hits(filename):
    """Print summary statistics of "took" and "hits" as a markdown table.

    Args:
        filename: Gzip-compressed JSONL file with one result object per line,
            as written by ``search_with_knn_queries``.
    """
    tooks = []
    hits = []
    with gzip.open(filename, "rt", encoding="utf-8") as f:
        # Stream line by line instead of materializing the file.
        for line in f:
            obj = json.loads(line)
            tooks.append(obj.get("took"))
            # Append: assigning here would replace the list with the last
            # value and broadcast it to every row of the frame.
            hits.append(obj.get("hits"))
    df = pd.DataFrame({"took": tooks, "hits": hits})
    print(df.describe().to_markdown())


In [None]:
# Prune unused Docker resources, then start a fresh container and wait until
# its first node reports HEALTHY.
prune_docker()
print(f"<<<Weaviate {weaviate_version}>>>")
run_weaviate()
wait_for_weaviate()

In [None]:
# Create the class, then print Docker disk usage and the current object count.
create_index()
print_docker_system_df()
print_indices()

In [None]:
# Load up to index_size vectors in batches of bulk_size, then print Docker
# disk usage and the object count again.
insert_data(bulk_size=bulk_size, max_size=index_size)
print_docker_system_df()
print_indices()

In [None]:
# Benchmark kNN search for several result sizes.
for page_size in [10, 100, 400]:
    print(f"page size: {page_size}")
    filename = get_output_filename(weaviate_version, f"knn_{page_size}")
    # The warmup writes to the same file; the measured run below reopens it in
    # "wt" mode, so only the measured results remain.
    search_with_knn_queries(filename, page_size=page_size, max_size=1000) # warmup
    # Measured run: default max_size, query vectors read starting at the
    # embedding file named f"{index_size}.npz".
    search_with_knn_queries(filename, page_size=page_size, offset=index_size)
    print_took_and_total_hits(filename)

In [None]:
# Tear down: delete the class, then stop the container.
delete_index()
stop_weaviate()