# Opensearch exact query

Requirements:
- [x] searches for exact phrases in titles, summaries and full text
- [x] fields are given priority in the following order: title > summary > text
- [x] non-ascii characters are normalised. E.g. 'El Niño' == 'El Nino'
- [x] search is case-insensitive
- [x] out-of-word punctuation is ignored. E.g. 'electricity!' == 'electricity'.
- [x] number of matching passages and documents is returned. 
- [x] a user can sort by date and title, ascending or descending.

In [1]:
%load_ext autoreload
%autoreload 2

import time
import numpy as np
from typing import Optional, List, Dict, Tuple

from app.index import OpenSearchIndex
from app.ml import SBERTEncoder

  from .autonotebook import tqdm as notebook_tqdm


## 1. Setup

### 1a. Connect to Opensearch
As we're outside of docker-compose we'll connect to Opensearch via localhost.

In [2]:
# Connect to the Opensearch instance exposed on localhost. SSL and certificate verification
# are switched off through the connector kwargs below.
opensearch = OpenSearchIndex(
    url="http://localhost:9200",
    username="admin",
    password="admin",
    index_name="navigator",
    # TODO: convert to env variables?
    opensearch_connector_kwargs={
        "use_ssl": False,
        "verify_certs": False,
        "ssl_show_warn": False,
    },
    # NOTE(review): presumably has to match the output size of the sentence encoder used for
    # indexing and querying -- confirm.
    embedding_dim=768,
)

# Print the connection status before running any queries.
print(opensearch.is_connected())

# `run_query` sends its searches through this object (presumably the underlying Opensearch
# client held by the wrapper -- confirm against `OpenSearchIndex`).
opns = opensearch.opns

True


In [3]:
# TODO: this needs to be the same model as used for indexing. At a later stage when we start updating 
# models we may want a way of ensuring both models are the same.
enc = SBERTEncoder(model_name="msmarco-distilbert-dot-v5")
# Display the shape of a single encoded string as a quick check of the embedding size.
enc.encode("hello world").shape

(768,)

In [4]:
# Encode three short phrases that share the word "race" so their embeddings can be compared.
emba = enc.encode("bicycle race")
embb = enc.encode("car race")
embc = enc.encode("tortoise race")

# Display the dot products of "bicycle race" with each of the other two phrases.
np.dot(emba, embb), np.dot(emba, embc)

(80.994026, 76.67018)

## 2. Run search

The `run_query` function does all of the heavy lifting here.

In [10]:
def _year_range_filter(year_range: Tuple[Optional[int], Optional[int]]):
    """
    Get an Opensearch filter for year range. The filter returned is between the first term of
    `year_range` and the last term, and is inclusive. Either value can be set to None to only
    apply one year constraint.
    """

    start_date = f"01/01/{year_range[0]}" if year_range[0] is not None else None
    end_date = f"31/12/{year_range[1]}" if year_range[1] is not None else None

    policy_year_conditions = {}
    if start_date is not None:
        policy_year_conditions["gte"] = start_date
    if end_date is not None:
        policy_year_conditions["lte"] = end_date

    range_filter = {"range": {}}

    range_filter["range"]["action_date"] = policy_year_conditions

    return range_filter
    

def _build_exact_match_query(
    query: str,
    max_passages_per_doc: int,
    keyword_filters: Optional[Dict[str, List[str]]],
    year_range: Optional[Tuple[Optional[int], Optional[int]]],
    sort_field: Optional[str],
    sort_order: Optional[str],
    max_no_docs: int,
    n_passages_to_sample_per_shard: int,
) -> dict:
    """Build the Opensearch request body sent by `run_query`; performs no network access."""

    # One exact-phrase clause per searchable field. The boosts rank a match in the title
    # (`for_search_name`, 3) above the description (`for_search_description`, 2) above the
    # passage text (`text`, 1).
    phrase_clauses = [
        {"match_phrase": {field: {"query": query, "boost": boost}}}
        for field, boost in (
            ("text", 1),
            ("for_search_name", 3),
            ("for_search_description", 2),
        )
    ]
    bool_query = {"should": phrase_clauses, "minimum_should_match": 1}

    # Optional filters: one `terms` clause per keyword field, plus the year range.
    filters = []
    if keyword_filters:
        filters.extend(
            {"terms": {field: values}} for field, values in keyword_filters.items()
        )
    if year_range:
        filters.append(_year_range_filter(year_range))
    if filters:
        bool_query["filter"] = filters

    # Passages are grouped into one bucket per `action_name_and_id` value.
    top_docs = {
        "terms": {
            "field": "action_name_and_id",
            # Default ordering: buckets with the highest-scoring passage first.
            "order": {"top_hit": "desc"},
            "size": max_no_docs,
        },
        "aggs": {
            "top_passage_hits": {
                "top_hits": {
                    "_source": {"excludes": ["text_embedding"]},
                    "size": max_passages_per_doc,
                }
            },
            "top_hit": {"max": {"script": {"source": "_score"}}},
            "action_date": {"stats": {"field": "action_date"}},
        },
    }

    # TODO: how does this work in a situation with more than 10,000 i.e. paginated results?
    # The override is applied to `top_docs` itself, which sits inside the `sample`
    # aggregation; there is no `top_docs` key directly under the request's top-level `aggs`.
    if sort_field and sort_order:
        if sort_field == "action_date":
            top_docs["terms"]["order"] = {f"{sort_field}.avg": sort_order}
        elif sort_field == "action_name":
            top_docs["terms"]["order"] = {"_key": sort_order}

    return {
        "size": 0,  # only return aggregations
        "query": {"bool": bool_query},
        "aggs": {
            "sample": {
                "sampler": {"shard_size": n_passages_to_sample_per_shard},
                "aggs": {
                    "top_docs": top_docs,
                    "bucketcount": {
                        "stats_bucket": {"buckets_path": "top_docs._count"}
                    },
                },
            }
        },
    }


def run_query(
    query: str, 
    max_passages_per_doc: int, 
    keyword_filters: Optional[Dict[str, List[str]]] = None, 
    year_range: Optional[Tuple[Optional[int], Optional[int]]] = None,
    sort_field: Optional[str] = None,
    sort_order: Optional[str] = None,
    max_no_docs: int = 100,
    n_passages_to_sample_per_shard: int = 5000,
) -> dict:
    """
    Run an exact-phrase Opensearch query and print a short summary of the result.

    The phrase is matched against passage text, title and description. Matching passages are
    grouped per document, and the timing plus passage/document counts are printed.
    
    Args:
        query (str): query string, matched as an exact phrase.
        max_passages_per_doc (int): maximum number of passages to return per document
        keyword_filters (Optional[Dict[str, List[str]]]): filters on keyword values to apply.
        In the format `{"field_name": ["values", ...], ...}`. Defaults to None.
        year_range (Optional[Tuple[Optional[int], Optional[int]]]): filter on action year by (minimum, maximum). 
        Either value can be set to `None` for a one-sided filter.
        sort_field (Optional[str]): field to sort. Only the values `action_date`, `action_name` or `None` are valid;
        any other value is ignored and the default ordering (best passage score, descending) is kept.
        sort_order (Optional[str]): order to sort in, applied only if `sort_field` is also set. Can be either "asc" or "desc".
        max_no_docs (int, optional): maximum number of documents to return. Keep this high so pagination can happen on the entire response. Defaults to 100.
        n_passages_to_sample_per_shard (int, optional): in order to speed up aggregations only the top N passages are considered for aggregation per shard. 
        Setting this value to lower will speed up searches at the cost of lowered recall. This value sets N. Defaults to 5000.
    
    Returns:
        dict: raw Opensearch result.
    """

    opns_query = _build_exact_match_query(
        query=query,
        max_passages_per_doc=max_passages_per_doc,
        keyword_filters=keyword_filters,
        year_range=year_range,
        sort_field=sort_field,
        sort_order=sort_order,
        max_no_docs=max_no_docs,
        n_passages_to_sample_per_shard=n_passages_to_sample_per_shard,
    )

    start = time.time()
    response = opns.search(
        body=opns_query,
        index="navigator",
        request_timeout=30,
        preference="prototype_user", # TODO: document what this means
        explain=True,
    )
    end = time.time()

    total_hits = response['hits']['total']
    passage_hit_count = total_hits['value']
    # note: 'gte' values are returned when there are more than 10,000 results by default
    # Any other relation is printed as-is, so the summary line below can never hit an
    # unbound variable.
    relation = total_hits['relation']
    passage_hit_qualifier = {"eq": "exactly", "gte": "at least"}.get(relation, relation)

    doc_hit_count = response['aggregations']['sample']['bucketcount']['count']

    print(f"query execution time: {round(end-start, 2)}s")
    print(f"returned {passage_hit_qualifier} {passage_hit_count} passage(s) in {doc_hit_count} document(s)")

    return response

# TODO: we should experimentally adjust this threshold 
# NOTE(review): `run_query` takes no threshold argument, so the TODO above looks left over
# from an earlier version of this cell -- confirm and remove.
# Example call; uncomment the lines below to try the year filter, keyword filter or sorting.
response = run_query(
    "scheme", 
    max_passages_per_doc=10,
    # year_range=(2000, None),
    # keyword_filters={
    #     "action_country_code": ["CHE"]
    # },
    # sort_field = "action_name",
    # sort_order = "asc",
)

# Display the raw response as the cell output.
response

query execution time: 0.02s
returned exactly 164 passage(s) in 27 document(s)


{'took': 13,
 'timed_out': False,
 '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0},
 'hits': {'total': {'value': 164, 'relation': 'eq'},
  'max_score': None,
  'hits': []},
 'aggregations': {'sample': {'doc_count': 164,
   'top_docs': {'doc_count_error_upper_bound': 0,
    'sum_other_doc_count': 0,
    'buckets': [{'key': 'forestry act and national strategy for the development of the forest sector 2013-2020 293',
      'doc_count': 3,
      'action_date': {'count': 3,
       'min': 1312329600000.0,
       'max': 1312329600000.0,
       'avg': 1312329600000.0,
       'sum': 3936988800000.0,
       'min_as_string': '03/08/2011',
       'max_as_string': '03/08/2011',
       'avg_as_string': '03/08/2011',
       'sum_as_string': '04/10/2094'},
      'top_hit': {'value': 10.047385215759277},
      'top_passage_hits': {'hits': {'total': {'value': 3, 'relation': 'eq'},
        'max_score': 10.047385,
        'hits': [{'_index': 'navigator',
          '_type': '_doc',
     