In [1]:
%pip install python-dotenv

You should consider upgrading via the '/Users/kalyan/.pyenv/versions/3.8.12/bin/python3.8 -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
from dotenv import load_dotenv
import os
import sys
from typing import Tuple, Optional, Dict, List
import time
import random

from tqdm.auto import tqdm
import pandas as pd
import numpy as np

sys.path.append("..")

from app.index import OpenSearchIndex
from app.ml import SBERTEncoder

load_dotenv(os.path.join("../..", '.env'))

# %env

  from .autonotebook import tqdm as notebook_tqdm


True

In [3]:
def _env_bool(name: str) -> Optional[bool]:
    """
    Read environment variable `name` as a boolean.

    Returns None when the variable is unset, which is the same value `os.getenv` passed
    through before. Otherwise the value is parsed case-insensitively: "1", "true", "yes",
    "y" and "on" give True, and anything else (e.g. "false", "0") gives False.
    """
    value = os.getenv(name)
    if value is None:
        return None
    return value.strip().lower() in {"1", "true", "yes", "y", "on"}


opensearch = OpenSearchIndex(
    url=os.getenv("OPENSEARCH_URL"),
    username=os.getenv("OPENSEARCH_USER"),
    password=os.getenv("OPENSEARCH_PASSWORD"),
    index_name=os.getenv("OPENSEARCH_INDEX"),
    # Environment values are strings, and any non-empty string (including "False") is
    # truthy. Convert them to real booleans before handing them to the connector.
    # NOTE(review): presumably these kwargs are forwarded to the opensearch-py client; confirm.
    opensearch_connector_kwargs={
        "use_ssl": _env_bool("OPENSEARCH_USE_SSL"),
        "verify_certs": _env_bool("OPENSEARCH_VERIFY_CERTS"),
        "ssl_show_warn": _env_bool("OPENSEARCH_SSL_WARNINGS"),
    },
    embedding_dim=768,
)

print(opensearch.is_connected())

# Low-level client used directly by `run_query` below.
opns = opensearch.opns

True


In [4]:
# Query encoder. The output dimension displayed below should match the index's `embedding_dim`.
enc = SBERTEncoder(model_name="msmarco-distilbert-dot-v5")
sample_embedding = enc.encode("hello world")
sample_embedding.shape

(768,)

## 1. Load query test set

In [5]:
# Test queries (one per row, no header) and their precomputed embeddings.
query_df = pd.read_csv("./data/test_queries.tsv", sep="\t", header=None, names=["idx", "text"])
queries = query_df['text'].tolist()

with open("./data/test_query_embs.npy", "rb") as f:
    embs = np.load(f)

# Later cells look up `embs[idx, :]` using positions into `queries`, so the two must be
# aligned row-for-row. 768 is the `embedding_dim` the index is configured with.
assert embs.ndim == 2 and embs.shape[0] == len(queries), (embs.shape, len(queries))
assert embs.shape[1] == 768, embs.shape

embs.shape

(1000, 768)

## Define the query function

In [473]:
def _innerproduct_threshold_to_lucene_threshold(ip_thresh: float) -> float:
    """
    Opensearch documentation on mapping similarity functions to Lucene thresholds is here: https://github.com/opensearch-project/k-NN/blob/main/src/main/java/org/opensearch/knn/index/SpaceType.java#L33
    It defines 'inner product' as negative inner product i.e. a distance rather than similarity measure, so we reverse the signs of inner product here compared to the docs.
    """
    if ip_thresh > 0:
        return ip_thresh + 1
    else:
        return 1 / (1-ip_thresh)

def _year_range_filter(year_range: Tuple[Optional[int], Optional[int]]):
    """
    Get an Opensearch filter for year range. The filter returned is between the first term of
    `year_range` and the last term, and is inclusive. Either value can be set to None to only
    apply one year constraint.
    """

    start_date = f"01/01/{year_range[0]}" if year_range[0] is not None else None
    end_date = f"31/12/{year_range[1]}" if year_range[1] is not None else None

    policy_year_conditions = {}
    if start_date is not None:
        policy_year_conditions["gte"] = start_date
    if end_date is not None:
        policy_year_conditions["lte"] = end_date

    range_filter = {"range": {}}

    range_filter["range"]["action_date"] = policy_year_conditions

    return range_filter
    

def run_query(
    query: str,
    embedding: np.ndarray,
    max_passages_per_doc: int,
    keyword_filters: Optional[Dict[str, List[str]]] = None,
    year_range: Optional[Tuple[Optional[int], Optional[int]]] = None,
    sort_field: Optional[str] = None,
    sort_order: Optional[str] = None,
    innerproduct_threshold: float = 70,
    max_no_docs: int = 100,
    n_passages_to_sample_per_shard: int = 5000,
    profile: bool = False,
    preference: Optional[str] = None,
) -> Tuple[dict, int]:
    """
    Run a hybrid (keyword + KNN) Opensearch query and aggregate passage hits into documents.

    A passage matches if any of three clause groups matches: passage text (keyword or KNN on
    `text_embedding`), action description (keyword or KNN on `action_description_embedding`),
    or action name (keyword or phrase). Only aggregations are returned (`size` is 0). Within a
    per-shard sample of top passages, passages are bucketed into documents by
    `action_name_and_id`.

    Args:
        query (str): query string used for the keyword (`match` / `match_phrase`) clauses.
        embedding (np.ndarray): embedding of `query`, used as the KNN query vector for both
            `text_embedding` and `action_description_embedding`.
        max_passages_per_doc (int): maximum number of passages to return per document.
        keyword_filters (Optional[Dict[str, List[str]]]): filters on keyword values to apply,
            in the format `{"field_name": ["values", ...], ...}`. Defaults to None.
        year_range (Optional[Tuple[Optional[int], Optional[int]]]): filter on action year by
            (minimum, maximum), inclusive. Either value can be `None` for a one-sided filter.
        sort_field (Optional[str]): field to sort documents by. Only `"action_date"` (the mean
            action date of the document's sampled passages) and `"action_name"` (the bucket
            key) are applied. Any other value is ignored and documents stay in score order.
        sort_order (Optional[str]): "asc" or "desc". Applied only if `sort_field` is also set.
        innerproduct_threshold (float): inner-product threshold applied to KNN results,
            converted to a Lucene score with `_innerproduct_threshold_to_lucene_threshold`.
        max_no_docs (int, optional): maximum number of documents to return. Keep this high so
            pagination can happen on the entire response. Defaults to 100.
        n_passages_to_sample_per_shard (int, optional): to speed up aggregations, only the top
            N passages per shard are considered. Lower values are faster at the cost of lower
            recall. This value sets N. Defaults to 5000.
        profile (bool): if True, request Opensearch query profiling in the response.
        preference (Optional[str]): passed through as the search `preference` parameter.

    Returns:
        Tuple[dict, int]: the raw Opensearch response, and the (approximate, cardinality-based)
        number of unique `action_name_and_id` values across all matching passages, not only
        the sample.
    """
    lucene_threshold = _innerproduct_threshold_to_lucene_threshold(innerproduct_threshold)

    opns_query = {
        "size": 0,  # only return aggregations
        "query": {
            "bool": {
                "should": [
                    # 1. Passage text: keyword match OR embedding similarity above threshold.
                    {
                        "bool": {
                            "should": [
                                {"match": {"text": {"query": query}}},
                                {
                                    "function_score": {
                                        "query": {
                                            "knn": {
                                                "text_embedding": {
                                                    "vector": embedding,
                                                    "k": 1000,  # TODO: tune me
                                                },
                                            },
                                        },
                                        "min_score": lucene_threshold,
                                    }
                                },
                            ],
                            "minimum_should_match": 1,
                        }
                    },
                    # 2. Action description: keyword match OR embedding similarity.
                    {
                        "bool": {
                            "should": [
                                {
                                    "match": {
                                        "for_search_action_description": {
                                            "query": query,
                                            "boost": 3,
                                        }
                                    }
                                },
                                {
                                    "function_score": {
                                        "query": {
                                            "knn": {
                                                "action_description_embedding": {
                                                    "vector": embedding,
                                                    "k": 1000,  # TODO: tune me
                                                },
                                            },
                                        },
                                        # TODO: tune me separately for descriptions?
                                        "min_score": lucene_threshold,
                                    }
                                },
                            ],
                            "minimum_should_match": 1,
                            "boost": 2,
                        },
                    },
                    # 3. Action name: keyword match, with an extra boost for phrase matches.
                    {
                        "bool": {
                            "should": [
                                {"match": {"for_search_action_name": {"query": query}}},
                                {
                                    "match_phrase": {
                                        "for_search_action_name": {
                                            "query": query,
                                            "boost": 2,
                                        }
                                    }
                                },
                            ],
                            "boost": 10,
                        }
                    },
                ],
                "minimum_should_match": 1,
            },
        },
        "aggs": {
            "sample": {
                "sampler": {"shard_size": n_passages_to_sample_per_shard},
                "aggs": {
                    "top_docs": {
                        "terms": {
                            "field": "action_name_and_id",
                            "order": {"top_hit": "desc"},
                            "size": max_no_docs,
                        },
                        "aggs": {
                            "top_passage_hits": {
                                "top_hits": {
                                    "_source": {"excludes": ["text_embedding", "action_description_embedding"]},
                                    "size": max_passages_per_doc,
                                }
                            },
                            # Best passage score in the document; used for the default ordering.
                            "top_hit": {"max": {"script": {"source": "_score"}}},
                            "action_date": {"stats": {"field": "action_date"}},
                        },
                    },
                },
            },
            # Counted outside the sampler, so it covers every matching passage.
            "no_unique_docs": {"cardinality": {"field": "action_name_and_id"}},
        },
    }

    filters = []
    if keyword_filters:
        filters.extend({"terms": {field: values}} for field, values in keyword_filters.items())
    if year_range:
        filters.append(_year_range_filter(year_range))
    if filters:
        opns_query["query"]["bool"]["filter"] = filters

    # TODO: how does this work in a situation with more than 10,000 i.e. paginated results?
    if sort_field and sort_order:
        # The terms aggregation lives under the "sample" sampler aggregation.
        top_docs_terms = opns_query["aggs"]["sample"]["aggs"]["top_docs"]["terms"]
        if sort_field == "action_date":
            # Order buckets by the `avg` of the `action_date` stats sub-aggregation.
            top_docs_terms["order"] = {"action_date.avg": sort_order}
        elif sort_field == "action_name":
            top_docs_terms["order"] = {"_key": sort_order}

    if profile:
        opns_query["profile"] = True

    response = opns.search(
        body=opns_query,
        # NOTE(review): hard-coded index name. The client was configured from OPENSEARCH_INDEX; confirm they match.
        index="navigator",
        request_timeout=30,
        preference=preference,  # TODO: document what this means
    )

    # Passage-level counts stay available in response['hits']['total'] (relation 'gte' past 10,000 by default).
    doc_hit_count = response["aggregations"]["no_unique_docs"]["value"]

    return response, doc_hit_count

# Example query used to smoke-test `run_query`. `idx` must be set here because it is
# otherwise only defined by later cells, so Restart & Run All would raise NameError.
# 981 is also the query used for the ground-truth comparison in section 3.
idx = 981

run_query(
    queries[idx],
    embs[idx, :],
    max_passages_per_doc=10,
    # Optional arguments, e.g.:
    # year_range=(2000, None),
    # keyword_filters={"action_country_code": ["CHE"]},
    # sort_field="action_name",
    # sort_order="asc",
    innerproduct_threshold=70,  # TODO: tune me
    n_passages_to_sample_per_shard=2000,
)


({'took': 9,
  'timed_out': False,
  '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0},
  'hits': {'total': {'value': 10000, 'relation': 'gte'},
   'max_score': None,
   'hits': []},
  'aggregations': {'no_unique_docs': {'value': 909},
   'sample': {'doc_count': 2000,
    'top_docs': {'doc_count_error_upper_bound': -1,
     'sum_other_doc_count': 787,
     'buckets': [{'key': 'taxation laws amendment act, 2009 - sections 12k and 12l inserted in act 58 2249',
       'doc_count': 38,
       'action_date': {'count': 38,
        'min': 1231632000000.0,
        'max': 1231632000000.0,
        'avg': 1231632000000.0,
        'sum': 46802016000000.0,
        'min_as_string': '11/01/2009',
        'max_as_string': '11/01/2009',
        'avg_as_string': '11/01/2009',
        'sum_as_string': '05/02/3453'},
       'top_hit': {'value': 174.41909790039062},
       'top_passage_hits': {'hits': {'total': {'value': 38, 'relation': 'eq'},
         'max_score': 174.4191,
         'hit

## 2. Run the query `n` times on a single thread

In [478]:
# Time 50 randomly chosen queries one after another. Latency comes from the response's
# server-side 'took' field (milliseconds), so network overhead is excluded.
times = []
query_order = random.sample(range(len(queries)), len(queries))  # shuffled query indices

for idx in tqdm(query_order[:50]):
    response, _ = run_query(
        queries[idx],
        embs[idx, :],
        max_passages_per_doc=5,
        max_no_docs=1000,
        innerproduct_threshold=70,  # TODO: tune me
        n_passages_to_sample_per_shard=5000,
    )
    times.append(response['took'] / 1000)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:47<00:00,  2.15s/it]


In [479]:
# Mean server-side latency (s) of the sequential runs, followed by its full distribution.
mean_latency = np.mean(times)
print(mean_latency)
pd.Series(times).describe()

1.5318200000000002


count    50.000000
mean      1.531820
std       0.422855
min       0.035000
25%       1.370500
50%       1.596000
75%       1.742750
max       2.436000
dtype: float64

In [447]:
# NOTE(review): this cell duplicates the latency summary cell directly above it. Its saved
# output comes from an earlier run of the timing loop (execution count 447 < 478), so on
# Restart & Run All it will just repeat the previous cell's output.
print(np.mean(times))
pd.Series(times).describe()

1.4184


count    50.000000
mean      1.418400
std       0.322782
min       0.043000
25%       1.374250
50%       1.469500
75%       1.552750
max       2.043000
dtype: float64

### 2.1 Profile query

See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-profile.html

In [316]:
idx = 301

# Run one query with Opensearch profiling enabled. `run_query` returns a tuple of
# (raw_response, unique_doc_count), so unpack it. The profile lives in the raw response,
# e.g. `res['profile']['shards'][0]['aggregations']`.
res, _ = run_query(
    queries[idx],
    embs[idx, :],
    max_passages_per_doc=5,
    max_no_docs=500,
    innerproduct_threshold=70,  # TODO: tune me
    n_passages_to_sample_per_shard=2000,
    profile=True,
)

In [317]:
# res['profile']['shards'][0]['aggregations']

## 3. Assess the accuracy impact of parameters which affect query speed

Parameters: `max_no_docs` (size of the terms aggregation), `n_passages_to_sample_per_shard` (sampler shard size) and `max_passages_per_doc` (number of top passage hits returned per document).

In [356]:
import itertools

In [480]:
# Ground truth: one query run with very generous limits (large sample, many documents).
# The sweep below is compared against its top-k documents.
idx = 981

gt_max_passages_per_doc = 5
gt_max_no_docs = 2000
gt_n_passages_to_sample_per_shard = 800_000

gt_results, no_docs = run_query(
    queries[idx],
    embs[idx, :],
    max_passages_per_doc=gt_max_passages_per_doc,
    max_no_docs=gt_max_no_docs,
    innerproduct_threshold=70,  # TODO: tune me
    n_passages_to_sample_per_shard=gt_n_passages_to_sample_per_shard,
)

gt_no_docs = no_docs
gt_docs = [doc_bucket['key'] for doc_bucket in gt_results["aggregations"]["sample"]["top_docs"]["buckets"]]
gt_top_5, gt_top_10, gt_top_50, gt_top_200 = (gt_docs[:k] for k in (5, 10, 50, 200))

In [481]:
# Grid of speed-related parameters to compare against the ground truth.
# (An earlier sweep used sample sizes [2000, 5000, 10000, 50000, 800000].)
search_max_passages_per_doc = [5]
search_max_no_docs = [20, 1000]
search_n_passages_to_sample_per_shard = [5000, 50000, 100000, 250000, 800000]

search_combinations = list(
    itertools.product(
        search_max_passages_per_doc,
        search_max_no_docs,
        search_n_passages_to_sample_per_shard,
    )
)

search_times, hits = [], []
search_top_200, search_top_50, search_top_10, search_top_5 = [], [], [], []

for max_passages_per_doc, max_no_docs, n_passages_to_sample_per_shard in tqdm(search_combinations):
    results, no_hits = run_query(
        queries[idx],
        embs[idx, :],
        max_passages_per_doc=max_passages_per_doc,
        max_no_docs=max_no_docs,
        innerproduct_threshold=70,  # TODO: tune me
        n_passages_to_sample_per_shard=n_passages_to_sample_per_shard,
    )

    search_times.append(results['took'] / 1000)  # server-side time in seconds

    all_docs = [doc_bucket['key'] for doc_bucket in results["aggregations"]["sample"]["top_docs"]["buckets"]]
    for cutoff, top_k_runs in ((200, search_top_200), (50, search_top_50), (10, search_top_10), (5, search_top_5)):
        top_k_runs.append(all_docs[:cutoff])
    hits.append(no_hits)
    


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.45s/it]


In [482]:
def precision(gt_list, pred_list):
    """
    Unordered overlap between a ground-truth list and a predicted list.

    Returns |set(gt_list) & set(pred_list)| / len(gt_list): the fraction of ground-truth items
    that appear anywhere in `pred_list`. Because it divides by the ground-truth length, this is
    recall@k rather than precision@k when `pred_list` is shorter than `gt_list`. When both lists
    have the same length the two coincide.

    Returns NaN when `gt_list` is empty (the ratio is undefined) instead of raising ZeroDivisionError.
    """
    if not gt_list:
        return float("nan")

    return len(set(gt_list) & set(pred_list)) / len(gt_list)

# One row per parameter combination: server time plus overlap with the ground-truth top-k.
results_formatted = [
    {
        "max_passages_per_doc": combo[0],
        "max_no_docs": combo[1],
        "n_passages_to_sample_per_shard": combo[2],
        "time": search_times[run_no],
        "precision_5": precision(gt_top_5, search_top_5[run_no]),
        "precision_10": precision(gt_top_10, search_top_10[run_no]),
        "precision_50": precision(gt_top_50, search_top_50[run_no]),
        "precision_200": precision(gt_top_200, search_top_200[run_no]),
        "no_hits": hits[run_no],
        "actual_hits": gt_no_docs,
    }
    for run_no, combo in enumerate(search_combinations)
]

res_df = pd.DataFrame(results_formatted)
res_df.drop(columns=['precision_50', 'precision_200'])  # displayed only; res_df keeps all columns

Unnamed: 0,max_passages_per_doc,max_no_docs,n_passages_to_sample_per_shard,time,precision_5,precision_10,no_hits,actual_hits
0,5,20,5000,0.308,1.0,1.0,909,909
1,5,20,50000,0.431,1.0,1.0,909,909
2,5,20,100000,0.763,1.0,1.0,909,909
3,5,20,250000,0.84,1.0,1.0,909,909
4,5,20,800000,1.107,1.0,1.0,909,909
5,5,1000,5000,1.079,1.0,1.0,909,909
6,5,1000,50000,3.754,1.0,1.0,909,909
7,5,1000,100000,3.928,1.0,1.0,909,909
8,5,1000,250000,3.389,1.0,1.0,909,909
9,5,1000,800000,3.492,1.0,1.0,909,909


In [450]:
def precision(gt_list, pred_list):
    """
    Unordered overlap between a ground-truth list and a predicted list.

    Returns |set(gt_list) & set(pred_list)| / len(gt_list): the fraction of ground-truth items
    that appear anywhere in `pred_list`. Because it divides by the ground-truth length, this is
    recall@k rather than precision@k when `pred_list` is shorter than `gt_list`. When both lists
    have the same length the two coincide.

    Returns NaN when `gt_list` is empty (the ratio is undefined) instead of raising ZeroDivisionError.
    """
    if not gt_list:
        return float("nan")

    return len(set(gt_list) & set(pred_list)) / len(gt_list)

# NOTE(review): this cell repeats the results-table code of the previous cell verbatim, as does
# the `precision` definition just above it. Its saved output comes from an earlier execution
# (In [450] < In [482]), so after Restart & Run All it will simply reproduce the previous table.
results_formatted = []

# One row per (max_passages_per_doc, max_no_docs, n_passages_to_sample_per_shard) combination,
# in the same order as `search_combinations`.
for _idx, params in enumerate(search_combinations):
    
    results_formatted.append(
        {
            "max_passages_per_doc": params[0],
            "max_no_docs": params[1],
            "n_passages_to_sample_per_shard": params[2],
            "time": search_times[_idx],
            # overlap of this run's top-k documents with the ground-truth top-k, via `precision`
            "precision_5": precision(gt_top_5, search_top_5[_idx]),
            "precision_10": precision(gt_top_10, search_top_10[_idx]),
            "precision_50": precision(gt_top_50, search_top_50[_idx]),
            "precision_200": precision(gt_top_200, search_top_200[_idx]),
            "no_hits": hits[_idx],
            "actual_hits": gt_no_docs,
        }
    )
    
# The drop only affects what is displayed; res_df keeps the precision_50/200 columns.
res_df = pd.DataFrame(results_formatted)
res_df.drop(columns=['precision_50', 'precision_200'])

Unnamed: 0,max_passages_per_doc,max_no_docs,n_passages_to_sample_per_shard,time,precision_5,precision_10,no_hits,actual_hits
0,5,20,5000,0.372,1.0,1.0,909,909
1,5,20,50000,0.518,1.0,1.0,909,909
2,5,20,100000,0.643,1.0,1.0,909,909
3,5,20,250000,0.927,1.0,1.0,909,909
4,5,20,800000,1.217,1.0,1.0,909,909
5,5,1000,5000,1.051,1.0,1.0,909,909
6,5,1000,50000,3.436,1.0,1.0,909,909
7,5,1000,100000,2.88,1.0,1.0,909,909
8,5,1000,250000,3.174,1.0,1.0,909,909
9,5,1000,800000,3.583,1.0,1.0,909,909


### 3.1 Query time simulator

In [451]:
# Stand-in for one search round-trip: block for a fixed, assumed latency, then "return".
query_time = 660 / 1000  # seconds

time.sleep(query_time)
print("search result")

search result


## 4. Parallel requests

In [460]:
from concurrent.futures import ThreadPoolExecutor
import requests


In [489]:
# sensible defaults based on section 3
MAX_PASSAGES_PER_DOC = 5
MAX_NO_DOCS = 1000
N_PASSAGES_TO_SAMPLE_PER_SHARD = 100000

# simulate a number of users picked at random to see the effect of caching
n_users = 10
users = [str(i) for i in range(n_users)]

def fetch(url, timeout=30):
    """
    GET `url` and return the response body as text.

    Args:
        url: URL to request.
        timeout: seconds to wait for the connection / response before `requests` raises a
            Timeout. Without a timeout, `requests.get` can block indefinitely.

    Returns:
        str: the decoded response body. It is returned whatever the HTTP status.
    """
    # TODO: surface HTTP errors (e.g. `page.raise_for_status()`) if callers need them.
    page = requests.get(url, timeout=timeout)
    return page.text
    
def make_request(idx):
    """
    Simulate one user search and return its server-side latency in seconds.

    Runs the "small" query (top 20 documents) for test query `idx`, then sleeps 2-3 s to mimic
    user think time. A simulated user id is drawn from `users` on every call but is not currently
    passed to `run_query` (`preference=None`). To study per-user caching, pass `preference=user`.
    For the "big" variant, also use `max_no_docs=MAX_NO_DOCS`.
    """
    user = random.choice(users)  # currently unused; see docstring

    response, _ = run_query(
        queries[idx],
        embs[idx, :],
        max_passages_per_doc=MAX_PASSAGES_PER_DOC,
        max_no_docs=20,
        innerproduct_threshold=70,  # TODO: tune me
        n_passages_to_sample_per_shard=N_PASSAGES_TO_SAMPLE_PER_SHARD,
        preference=None,
    )

    think_time_s = random.randint(2, 3)
    time.sleep(think_time_s)

    return response['took'] / 1000


N_REQUESTS = 1000
N_WORKERS = 100

# Shuffle all query indices (unseeded) and keep the first N_REQUESTS.
_iterator = range(len(queries))
idxs = random.sample(_iterator, len(_iterator))[:N_REQUESTS]

times = []

start = time.perf_counter()
# The context manager waits for outstanding work and then shuts the worker threads down.
# The previous module-level pool was never shut down, so its idle threads stayed alive.
with ThreadPoolExecutor(max_workers=N_WORKERS) as pool:
    # `map` yields results in input order; each value is one request's server-side latency (s).
    for _time in pool.map(make_request, idxs):
        times.append(_time)
end = time.perf_counter()

print(f"{N_REQUESTS} requests across {N_WORKERS} workers took {int(end-start)} seconds")

1000 requests across 100 workers took 48 seconds


In [490]:
# Per-request server-side latency (s) under concurrent load: mean first, then the distribution.
latency_summary = pd.Series(times).describe()
print(np.mean(times))
latency_summary

1.757523


count    1000.000000
mean        1.757523
std         0.900928
min         0.002000
25%         1.103750
50%         1.766500
75%         2.360250
max         4.444000
dtype: float64