In [1]:
# Install python-dotenv, which provides the `load_dotenv` helper imported in the next cell.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
%pip install python-dotenv

You should consider upgrading via the '/Users/kalyan/.pyenv/versions/3.8.12/bin/python3.8 -m pip install --upgrade pip' command.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from dotenv import load_dotenv
import os
import sys
from typing import Tuple, Optional, Dict, List
import time
import random

from tqdm.auto import tqdm
import pandas as pd
import numpy as np

# Add the parent directory to the import path so the project-local `app` package
# imported below can be resolved (the path is relative to the current working directory).
sys.path.append("..")

from app.index import OpenSearchIndex
from app.ml import SBERTEncoder

# Load variables from the `.env` file two directories up into the process environment;
# the OPENSEARCH_* settings read with `os.getenv` further down presumably come from this file.
load_dotenv(os.path.join("../..", '.env'))

# %env

  from .autonotebook import tqdm as notebook_tqdm


True

In [3]:
def _env_bool(name: str) -> Optional[bool]:
    """
    Read an environment variable as a boolean flag.

    `os.getenv` always returns a string, so a value such as "False" or "0" would be
    truthy if handed to the connector unchanged. This helper converts the usual
    spellings into real booleans.

    Args:
        name (str): name of the environment variable.

    Returns:
        Optional[bool]: `True`/`False` for a recognised value, or `None` when the
        variable is unset or empty (the same falsy value `os.getenv` would have given
        for an unset variable).

    Raises:
        ValueError: if the variable is set to something that is not a recognised boolean.
    """
    raw = os.getenv(name)
    if raw is None or raw.strip() == "":
        return None

    value = raw.strip().lower()
    if value in ("1", "true", "t", "yes", "y", "on"):
        return True
    if value in ("0", "false", "f", "no", "n", "off"):
        return False

    raise ValueError(f"Environment variable {name}={raw!r} is not a recognised boolean")


# Connection settings all come from the environment (see the `.env` file loaded above).
opensearch = OpenSearchIndex(
    url=os.getenv("OPENSEARCH_URL"),
    username=os.getenv("OPENSEARCH_USER"),
    password=os.getenv("OPENSEARCH_PASSWORD"),
    index_name=os.getenv("OPENSEARCH_INDEX"),
    # SSL flags must be real booleans: the raw env strings ("False", "0", ...) are all truthy.
    # NOTE(review): assumes OpenSearchIndex forwards these kwargs to the client unparsed - confirm.
    opensearch_connector_kwargs={
        "use_ssl": _env_bool("OPENSEARCH_USE_SSL"),
        "verify_certs": _env_bool("OPENSEARCH_VERIFY_CERTS"),
        "ssl_show_warn": _env_bool("OPENSEARCH_SSL_WARNINGS"),
    },
    embedding_dim=768,
)

print(opensearch.is_connected())

# Underlying client object, used directly by `run_query` below.
opns = opensearch.opns

True


In [4]:
# Instantiate the sentence encoder and embed a throwaway string as a sanity check.
# The displayed shape should agree with the `embedding_dim=768` configured for the index.
enc = SBERTEncoder(model_name="msmarco-distilbert-dot-v5")
sample_embedding = enc.encode("hello world")
sample_embedding.shape

(768,)

## 1. Load query test set

In [5]:
# Test queries: a headerless TSV with an index column and the query text.
query_df = pd.read_csv(
    "./data/test_queries.tsv",
    sep="\t",
    header=None,
    names=["idx", "text"],
)
queries = query_df["text"].tolist()

# Embeddings for the queries, stored as a NumPy array on disk.
# Rows are presumably in the same order as `queries` (both are indexed together later) - verify.
with open("./data/test_query_embs.npy", "rb") as f:
    embs = np.load(f)

embs.shape

(1000, 768)

## 2. Define and run the query

In [21]:
def _innerproduct_threshold_to_lucene_threshold(ip_thresh: float) -> float:
    """
    Opensearch documentation on mapping similarity functions to Lucene thresholds is here: https://github.com/opensearch-project/k-NN/blob/main/src/main/java/org/opensearch/knn/index/SpaceType.java#L33
    It defines 'inner product' as negative inner product i.e. a distance rather than similarity measure, so we reverse the signs of inner product here compared to the docs.
    """
    if ip_thresh > 0:
        return ip_thresh + 1
    else:
        return 1 / (1-ip_thresh)

def _year_range_filter(year_range: Tuple[Optional[int], Optional[int]]):
    """
    Get an Opensearch filter for year range. The filter returned is between the first term of
    `year_range` and the last term, and is inclusive. Either value can be set to None to only
    apply one year constraint.
    """

    start_date = f"01/01/{year_range[0]}" if year_range[0] is not None else None
    end_date = f"31/12/{year_range[1]}" if year_range[1] is not None else None

    policy_year_conditions = {}
    if start_date is not None:
        policy_year_conditions["gte"] = start_date
    if end_date is not None:
        policy_year_conditions["lte"] = end_date

    range_filter = {"range": {}}

    range_filter["range"]["action_date"] = policy_year_conditions

    return range_filter
    

def run_query(
    query: str,
    embedding: np.ndarray,
    max_passages_per_doc: int,
    keyword_filters: Optional[Dict[str, List[str]]] = None,
    year_range: Optional[Tuple[Optional[int], Optional[int]]] = None,
    sort_field: Optional[str] = None,
    sort_order: Optional[str] = None,
    innerproduct_threshold: float = 70,
    max_no_docs: int = 100,
    n_passages_to_sample_per_shard: int = 5000,
) -> dict:
    """
    Run an Opensearch query.

    The query combines full-text matches with KNN matches on passage text and on the
    action description, plus (heavily boosted) full-text matches on the action name.
    Passages are then grouped into documents by a terms aggregation on `action_name_and_id`.

    Args:
        query (str): query string
        embedding (np.ndarray): embedding of `query`, used as the vector for the KNN clauses
        on `text_embedding` and `action_description_embedding`.
        max_passages_per_doc (int): maximum number of passages to return per document
        keyword_filters (Optional[Dict[str, List[str]]]): filters on keyword values to apply.
        In the format `{"field_name": ["values", ...], ...}`. Defaults to None.
        year_range (Optional[Tuple[Optional[int], Optional[int]]]): filter on action year by (minimum, maximum).
        Either value can be set to `None` for a one-sided filter.
        sort_field (Optional[str]): field to sort. Only the values `action_date`, `action_name` or `None` are valid;
        any other value leaves the default ordering (best passage score per document) in place.
        sort_order (Optional[str]): order to sort in, applied if `sort_field` is not None. Can be either "asc" or "desc".
        innerproduct_threshold (float): threshold applied to KNN results
        max_no_docs (int, optional): maximum number of documents to return. Keep this high so pagination can happen on the entire response. Defaults to 100.
        n_passages_to_sample_per_shard (int, optional): in order to speed up aggregations only the top N passages are considered for aggregation per shard.
        Setting this value to lower will speed up searches at the cost of lowered recall. This value sets N. Defaults to 5000.

    Returns:
        dict: raw Opensearch result.
    """

    lucene_threshold = _innerproduct_threshold_to_lucene_threshold(innerproduct_threshold)

    opns_query = {
        "size": 0,  # only return aggregations
        "query": {
            "bool": {
                "should": [
                    # Passage text: full-text match OR KNN on the passage embedding.
                    {
                        "bool": {
                            "should": [
                                {
                                    "match": {
                                        "text": {
                                            "query": query,
                                        },
                                    }
                                },
                                {
                                    "function_score": {
                                        "query": {
                                            "knn": {
                                                "text_embedding": {
                                                    "vector": embedding,
                                                    "k": 1000,  # TODO: tune me
                                                },
                                            },
                                        },
                                        "min_score": lucene_threshold,
                                    }
                                },
                            ],
                            "minimum_should_match": 1,
                        }
                    },
                    # Action description: full-text match OR KNN on the description embedding.
                    {
                        "bool": {
                            "should": [
                                {
                                    "match": {
                                        "for_search_action_description": {
                                            "query": query,
                                            "boost": 3,
                                        }
                                    }
                                },
                                {
                                    "function_score": {
                                        "query": {
                                            "knn": {
                                                "action_description_embedding": {
                                                    "vector": embedding,
                                                    "k": 1000,  # TODO: tune me
                                                },
                                            },
                                        },
                                        "min_score": lucene_threshold,  # TODO: tune me separately for descriptions?
                                    }
                                },
                            ],
                            "minimum_should_match": 1,
                            "boost": 2,
                        },
                    },
                    # Action name: loose match plus a boosted exact-phrase match.
                    {
                        "bool": {
                            "should": [
                                {
                                    "match": {
                                        "for_search_action_name": {
                                            "query": query,
                                        }
                                    }
                                },
                                {
                                    "match_phrase": {
                                        "for_search_action_name": {
                                            "query": query,
                                            "boost": 2,
                                        }
                                    }
                                },
                            ],
                            "boost": 10,
                        }
                    },
                ],
                "minimum_should_match": 1,
            },
        },
        "aggs": {
            "sample": {
                "sampler": {"shard_size": n_passages_to_sample_per_shard},
                "aggs": {
                    "top_docs": {
                        "terms": {
                            "field": "action_name_and_id",
                            # Default ordering: documents ranked by their best passage score.
                            "order": {"top_hit": "desc"},
                            "size": max_no_docs,
                        },
                        "aggs": {
                            "top_passage_hits": {
                                "top_hits": {
                                    "_source": {"excludes": ["text_embedding", "action_description_embedding"]},
                                    "size": max_passages_per_doc,
                                }
                            },
                            "top_hit": {"max": {"script": {"source": "_score"}}},
                            "action_date": {
                                "stats": {
                                    "field": "action_date"
                                }
                            },
                        },
                    },
                    "bucketcount": {
                        "stats_bucket": {
                            "buckets_path": "top_docs._count"
                        }
                    },
                },
            }
        },
    }

    if keyword_filters:
        opns_query["query"]["bool"]["filter"] = [
            {"terms": {field: values}} for field, values in keyword_filters.items()
        ]

    if year_range:
        # Append to any keyword filters rather than overwriting them.
        opns_query["query"]["bool"].setdefault("filter", []).append(
            _year_range_filter(year_range)
        )

    # TODO: how does this work in a situation with more than 10,000 i.e. paginated results?
    if sort_field and sort_order:
        # The outer aggregation is named "sample" ("sampler" is only its type), so the
        # terms aggregation has to be reached through that key.
        top_docs_terms = opns_query["aggs"]["sample"]["aggs"]["top_docs"]["terms"]
        if sort_field == "action_date":
            # Order by the `avg` of the per-document `action_date` stats sub-aggregation.
            top_docs_terms["order"] = {f"{sort_field}.avg": sort_order}
        elif sort_field == "action_name":
            # Bucket keys come from `action_name_and_id`, so ordering by key sorts by name.
            top_docs_terms["order"] = {"_key": sort_order}

    # NOTE(review): the index name is hard-coded here while the client above is configured
    # from OPENSEARCH_INDEX - confirm whether these should be kept in sync.
    response = opns.search(
        body=opns_query,
        index="navigator",
        request_timeout=30,
        # preference="prototype_user", # TODO: document what this means
    )

    return response

# Run a single example query to inspect the raw response.
# The index is set explicitly so this cell does not depend on an `idx` left over from
# another cell (which fails on a fresh kernel).
EXAMPLE_QUERY_IDX = 0

run_query(
    queries[EXAMPLE_QUERY_IDX],
    embs[EXAMPLE_QUERY_IDX, :],
    max_passages_per_doc=10,
    # year_range=(2000, None),
    # keyword_filters={
    #     "action_country_code": ["CHE"]
    # },
    # sort_field = "action_name",
    # sort_order = "asc",
    innerproduct_threshold=70, # TODO: tune me
    n_passages_to_sample_per_shard=2000,
)

{'took': 44,
 'timed_out': False,
 '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0},
 'hits': {'total': {'value': 10000, 'relation': 'gte'},
  'max_score': None,
  'hits': []},
 'aggregations': {'sample': {'doc_count': 2000,
   'top_docs': {'doc_count_error_upper_bound': -1,
    'sum_other_doc_count': 704,
    'buckets': [{'key': 'philippine national redd-plus strategy 1942',
      'doc_count': 137,
      'action_date': {'count': 137,
       'min': 1287532800000.0,
       'max': 1287532800000.0,
       'avg': 1287532800000.0,
       'sum': 176391993600000.0,
       'min_as_string': '20/10/2010',
       'max_as_string': '20/10/2010',
       'avg_as_string': '20/10/2010',
       'sum_as_string': '23/08/7559'},
      'top_hit': {'value': 121.84829711914062},
      'top_passage_hits': {'hits': {'total': {'value': 137, 'relation': 'eq'},
        'max_score': 121.8483,
        'hits': [{'_index': 'navigator',
          '_type': '_doc',
          '_id': 'xdEBKIAB-j1vkLQeuBg

In [33]:
# Measure end-to-end latency of `run_query` over the first few test queries.
N_TIMING_QUERIES = 50  # how many of the test queries to time

times = []
_iterator = range(len(queries))
# _iterator = random.sample(_iterator, len(_iterator)) # shuffle

for idx in tqdm(_iterator[:N_TIMING_QUERIES]):
    # perf_counter is monotonic and high-resolution, unlike the wall clock from time.time().
    start = time.perf_counter()
    # The response is discarded; it is not bound to `_`, which IPython uses for the last output.
    run_query(
        queries[idx],
        embs[idx, :],
        max_passages_per_doc=10,
        # year_range=(2000, None),
        # keyword_filters={
        #     "action_country_code": ["CHE"]
        # },
        # sort_field = "action_name",
        # sort_order = "asc",
        innerproduct_threshold=70, # TODO: tune me
        n_passages_to_sample_per_shard=2000,
    )
    times.append(time.perf_counter() - start)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:38<00:00,  1.31it/s]


In [34]:
# Mean latency in seconds, followed by the full distribution of per-query timings.
mean_latency = np.mean(times)
print(mean_latency)

pd.Series(times).describe()

0.7604890203475952


count    50.000000
mean      0.760489
std       0.133360
min       0.530206
25%       0.663100
50%       0.729915
75%       0.857168
max       1.033069
dtype: float64