# Setting the semantic search threshold

Our OpenSearch-based search filters out passages with low semantic similarity to the search term. The filter is applied in `api.py` to the results that OpenSearch returns.

This notebook exists because the API and UI do not currently return semantic similarity scores (the inner product of the search term's embedding vector and each passage's embedding vector). Use it to run a query and experiment with different minimum inner product thresholds.
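As a reminder of what the score measures, here is a minimal sketch of the inner product between a query vector and two passage vectors. The function name `inner_product` and the three-dimensional toy vectors are assumptions made for illustration; real embeddings come from the encoder and have many more dimensions.

```python
import numpy as np


def inner_product(query_vec, passage_vec):
    """Return the inner (dot) product of two equal-length vectors as a float."""
    return float(np.dot(query_vec, passage_vec))


# Toy vectors, made up for illustration only
query_vec = [1.0, 2.0, 3.0]
passages = {
    "close": [2.0, 4.0, 6.0],        # points in the same direction as the query
    "orthogonal": [3.0, 0.0, -1.0],  # at right angles to the query
}

scores = {name: inner_product(query_vec, vec) for name, vec in passages.items()}
print(scores)
```

Unlike cosine similarity, the inner product is not bounded to [-1, 1], because it also grows with the lengths of the vectors. A threshold therefore only makes sense for one specific encoder.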



In [1]:
import sys
sys.path.append("..")

import json
import time

import numpy as np
import requests

from policy_search.pipeline.semantic_search import SBERTEncoder

In [2]:
OPENSEARCH_ENDPOINT = "https://localhost:9200"
OPENSEARCH_USER_PASSWORD = ("admin", "admin")

# Basic auth is attached to every request made through this session
session = requests.Session()
session.auth = OPENSEARCH_USER_PASSWORD

# Encoder used to embed the search term
query_encoder = SBERTEncoder("msmarco-distilbert-dot-v5")

# Credentials come from `session.auth`, so only the content type is set here
headers = {"Content-Type": "application/json"}


2021-10-29 10:23:36,211 - policy_search.pipeline.semantic_search - DEBUG - Downloading sentence-transformers model


## 1. Run query

In [3]:
# query parameters
query = "transport levies"
max_passages_per_page = 5
max_pages_per_doc = 10

# ------

query_embedding = [float(i) for i in list(query_encoder.encode(query))]

# NOTE: this is the query from the method `policy_search.pipeline.opensearch.OpenSearchIndex.search`
es_query = {
                "_source": {"excludes": ["text.embedding"]},
                "query": {
                    "bool": {
                        "must": [
                            {
                                "nested": {
                                    "path": "text",
                                    "score_mode": "max",
                                    "inner_hits": {
                                        "_source": ["text.text_id", "text.text", "text.embedding"],
                                        "size": max_passages_per_page,
                                    },
                                    "query": {
                                        "bool": { 
                                            "should": [
                                                {
                                                    "match": {
                                                        "text.text": {
                                                            "query": query,
                                                            "boost": 1,
                                                        },
                                                    }
                                                },
                                                { 
                                                    "knn": {
                                                        "text.embedding": {
                                                            "vector": query_embedding,
                                                            "k": max_passages_per_page,
                                                            "boost": 1,
                                                        },
                                                    } 
                                                }     
                                            ]
                                        }
                                    },
                                }
                            },
                        ],
                        "should": [
                            {
                                "match_phrase": {
                                    "policy_name": {
                                        "query": query,
                                        "boost": 1,
                                    }
                                }
                            }
                        ]
                    }
                },
                "aggs": {
                    "top_docs": {
                        "terms": {
                            "field": "policy_id",
                            "order": {"top_hit": "desc"},
                        },
                        "aggs": {
                            "top_passage_hits": {
                                "top_hits": {
                                    "_source": {"excludes": ["text.embedding"]},
                                    "size": max_pages_per_doc,
                                }
                            },
                            "top_hit": {"max": {"script": {"source": "_score"}}},
                        },
                    }
                },
            }

start = time.time()
# verify=False skips TLS certificate verification; only use this against a local cluster
search_result = session.post(
    url=f"{OPENSEARCH_ENDPOINT}/policies/_search",
    data=json.dumps(es_query),
    headers=headers,
    verify=False,
).json()
end = time.time()
print(f"completed in {round(end-start, 1)} seconds")




completed in 2.1 seconds


## 2. Experiment with different inner product thresholds

Filtered results are returned below: a passage is kept when its inner product with the query embedding is at least `INNERPRODUCT_THRESHOLD`. Passages excluded by the threshold are stored in the variable `removed_using_threshold`.
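To compare several candidate thresholds at once, a small helper along the following lines may be useful. This is a sketch and not part of `api.py`: the names `split_by_threshold` and `sweep_thresholds` are hypothetical, and the example scores are made up for illustration. It assumes passages are available as `(text, score)` tuples, the same shape the filtering cell uses.

```python
def split_by_threshold(scored_passages, threshold):
    """Split (text, score) tuples into those kept and those removed."""
    kept = [(text, score) for text, score in scored_passages if score >= threshold]
    removed = [(text, score) for text, score in scored_passages if score < threshold]
    return kept, removed


def sweep_thresholds(scored_passages, thresholds):
    """Map each candidate threshold to the number of passages it would keep."""
    return {
        threshold: len(split_by_threshold(scored_passages, threshold)[0])
        for threshold in thresholds
    }


# Made-up scores for illustration only
example_passages = [
    ("passage a", 77.0),
    ("passage b", 72.6),
    ("passage c", 69.4),
    ("passage d", 64.4),
]

counts = sweep_thresholds(example_passages, [65, 70, 75])
print(counts)
```

Reading the removed passages just below each candidate threshold is usually more informative than the counts alone, since it shows which borderline passages would be lost.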

In [4]:
INNERPRODUCT_THRESHOLD = 70

results_by_doc = search_result["aggregations"]["top_docs"]["buckets"]

query_results_by_doc = []

removed_using_threshold = []

# Iterate over each document returned from the query
for result in results_by_doc[0:20]:
    hits_by_page = result["top_passage_hits"]["hits"]["hits"]
    # num_pages_with_hit = result["doc_count"]
    policy_id = result["key"]
    policy_name = hits_by_page[0]['_source']['policy_name']

    document_response = []

    # Iterate over each page hit in each document
    for hit in hits_by_page:
        page_text_hits = []
        # Find the matching text passages and add to results
        for page_inner_hits in hit["inner_hits"]["text"]["hits"]["hits"]:
            score = np.dot(query_embedding, page_inner_hits['_source']['embedding'])
            if score >= INNERPRODUCT_THRESHOLD:
                page_text_hits.append((page_inner_hits["_source"]["text"], score))
            else:
                removed_using_threshold.append((page_inner_hits["_source"]["text"], score))

        # Add the page matches for this document
        if len(page_text_hits) > 0:
            document_response.append(
                {
                    "pageNumber": hit["_source"]["page_number"],
                    "text": page_text_hits,
                }
            )

    # Add the query matches for this document
    if len(document_response) > 0:
        query_results_by_doc.append(
            {
                "policyId": policy_id,
                "policyName": policy_name,
                "resultsByPage": document_response,
            }
        )

print(f"{len(removed_using_threshold)} passages removed with an inner product threshold of {INNERPRODUCT_THRESHOLD}")

query_results_by_doc

78 passages removed with an inner product threshold of 70


[{'policyId': 6,
  'policyName': 'Long-Term Climate Strategy, 2021',
  'resultsByPage': [{'pageNumber': 20,
    'text': [('Two new steering levies are provided for in aviation.',
      77.04585712082876),
     ('Airlines which achieve significant emissions reductions can benefit from a lower rate for both levies.',
      75.45201574327331),
     ('A levy of 500 to 3000 francs will be levied for business and private flights.',
      75.06579617323784),
     ('Support with innovative measures to reduce climate impact in aviation can also be provided by the Climate Fund which will receive less than half of the revenues from the flight levies.',
      72.64053568134089)]},
   {'pageNumber': 33,
    'text': [('The transport sector is made up of the transport (1A3) and military (1A5) emissions categories.',
      75.39057475867008),
     ('Domestic navigation Pipeline transport Other (military) Total road transportation Light duty trucks',
      74.60234589729858),
     ('2018 Road transport

In [5]:
removed_using_threshold

[('In public transport, disincentives to switch from diesel-fuelled buses to buses with lower greenhouse gas emissions will be eliminated as the partial mineral oil tax rebate for licensed transport companies gradually expires: initially in local transport from 2026, and also in regional passenger transport from 2030 unless topographical conditions prevent climate-friendly alternatives.',
  69.43866929669971),
 ('Federal Department of the Environment, Transport, Energy and Communications DETEC (2017): Switzerland Future Mobility DETEC Guiding Framework 2040.',
  69.81972320511444),
 ('Examples include the reduction of the emission of air pollutants by converting to renewable energies and lower noise pollution from transport by switching from combustion to electric engines.',
  64.42640013273795),
 ('2040 2045 2050 Retail Transport, communications Public administration Education Health Other services',
  68.03026567406205),
 ('The Climate Fund will receive a third of the revenues from t