## Dense Passage Retrieval: Retrieve Top K matching passages 

In [2]:
%%capture

!pip install cohere-sagemaker

#### Imports 

In [3]:
from cohere_sagemaker import Client
from requests.auth import HTTPBasicAuth
import requests
import logging 
import boto3
import yaml
import json



##### Setup logging

In [4]:
# Configure the 'sagemaker' logger to emit DEBUG-and-above records via a stream handler.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# logging.getLogger returns the same logger object on every call, so re-running
# this cell would otherwise attach one more StreamHandler each time and every
# message would be printed repeatedly.
if not any(isinstance(handler, logging.StreamHandler) for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies

In [5]:
# Record the versions of the key dependencies so a run can be reproduced later.
for package_name, module in (('requests', requests), ('pyyaml', yaml)):
    logger.info(f'Using {package_name}=={module.__version__}')

Using requests==2.31.0
Using pyyaml==6.0


#### Setup essentials

In [94]:
# Names of the deployed SageMaker endpoints used in this notebook.
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'huggingface-textembedding-gpt-j-6b-fp16-1691075868'
TEXT_GENERATION_MODEL_ENDPOINT_NAME_COHERE = 'cohere-medium-1680827379'
TEXT_GENERATION_MODEL_ENDPOINT_NAME = 'huggingface-text2text-flan-t5-xl-1691100607'

CHUNKS_DIR_PATH = './data/chunks'

# Runtime client used for the direct invoke_endpoint calls (embedding and
# text2text generation).
sagemaker_client = boto3.client('runtime.sagemaker')
# The Cohere SDK client has to target the Cohere endpoint; it was previously
# constructed with the FLAN-T5 text2text endpoint name.
cohere_client = Client(endpoint_name=TEXT_GENERATION_MODEL_ENDPOINT_NAME_COHERE)

In [12]:
# Read the search-domain credentials and location from the local YAML config
# so that no secrets are hard-coded in the notebook.
with open('config.yml', 'r') as config_file:
    config = yaml.safe_load(config_file)

credentials = config['credentials']
es_username = credentials['username']
es_password = credentials['password']

domain = config['domain']
domain_endpoint = domain['endpoint']
domain_index = domain['index']

In [13]:
# Build the _search URL for the configured index. A trailing slash on the
# configured endpoint is stripped so the resulting path never contains '//'.
URL = f"{domain_endpoint.rstrip('/')}/{domain_index}/_search"
logger.info(f'URL for OpenSearch index = {URL}')

URL for OpenSearch index = https://search-sematic-search-4vgtrb5lpgqsss26pxewnosnjy.eu-west-1.es.amazonaws.com/legal-passages/_search


Refer to the Amazon OpenSearch Service k-NN search guide (https://docs.aws.amazon.com/opensearch-service/latest/developerguide/knn.html) for more information.

#### Encode question using SageMaker JumpStart's text embedding model endpoint

In [29]:
# Question to answer: it is embedded for retrieval and later inserted into each generation prompt.
prompt = 'What is the definition of crime of battery?'

In [14]:
# Serialise the question into the JSON body expected by the embedding endpoint.
embedding_request = json.dumps({'text_inputs': [prompt]}).encode('utf-8')
embedding_response = sagemaker_client.invoke_endpoint(
    EndpointName=TEXT_EMBEDDING_MODEL_ENDPOINT_NAME,
    ContentType='application/json',
    Body=embedding_request,
)
# Only one text was submitted, so the first entry under 'embedding' is used.
embedding = json.loads(embedding_response['Body'].read())['embedding'][0]

#### Find the top K (K=3) passages that best match the encoded question

In [52]:
# Number of passages requested from the k-NN search (used as both 'size' and 'k' in the query).
K = 3

In [53]:
# k-NN clause: search the 'embedding' field for the K nearest neighbours of
# the question vector.
knn_clause = {'vector': embedding, 'k': K}

# Search body: return at most K hits from the k-NN query above.
query = {
    'size': K,
    'query': {'knn': {'embedding': knn_clause}},
}

In [54]:
# Run the k-NN search. The timeout keeps the cell from hanging indefinitely on
# an unreachable domain, and raise_for_status() surfaces HTTP errors (e.g. bad
# credentials or a missing index) directly instead of as a KeyError on 'hits'.
response = requests.post(
    URL,
    auth=HTTPBasicAuth(es_username, es_password),
    json=query,
    timeout=30,
)
response.raise_for_status()
response_json = response.json()
hits = response_json['hits']['hits']

#### Generate answers using SageMaker JumpStart's text generation model by leveraging the previously matched passages 

In [13]:
# Ask the Cohere client to answer the question once per retrieved passage,
# supplying that passage as the context of the prompt.
for hit in hits:
    source = hit['_source']
    score = hit['_score']
    passage = source['passage']
    doc_id = source['doc_id']
    passage_id = source['passage_id']
    qa_prompt = f'Context={passage}\nQuestion={prompt}\nAnswer='

    response = cohere_client.generate(prompt=qa_prompt,
                                      max_tokens=512,
                                      temperature=0.25,
                                      return_likelihoods='GENERATION')

    # Trim surrounding whitespace and drop newlines so the answer logs on one line.
    answer = response.generations[0].text.strip().replace('\n', '')
    logger.info(f'Answer:\n{answer}')
    logger.info(f'Reference:\nDocument = {doc_id} | Passage = {passage_id} | Score = {score}')

if not hits:
    # Logger.warn is a deprecated alias (removed in Python 3.13); use warning().
    logger.warning('No matching documents found!')

Answer:
The crime of battery is committed when a person intentionally touches someone else in a way that causes injury.
Reference:
Document = 005 | Passage = 45 | Score = 0.64828235
Answer:
The crime of battery is committed when a person intentionally touches another person in a harmful or offensive way.
Reference:
Document = 005 | Passage = 44 | Score = 0.6335997
Answer:
The crime of battery is defined as "an unlawful and intentional application of force upon another person."
Reference:
Document = 005 | Passage = 7 | Score = 0.6297361
Answer:
The crime of battery is committed when a person intentionally touches another person against their will.
Reference:
Document = 005 | Passage = 74 | Score = 0.6294636
Answer:
The crime of battery is defined as "willfully and maliciously causing harm to another person."
Reference:
Document = 007 | Passage = 12 | Score = 0.6268758


In [88]:
# Content type sent with every text-generation request.
CONTENT_TYPE = 'application/json'

# Output size: maximum generated length and number of sequences per request.
MAX_LENGTH = 1024
NUM_RETURN_SEQUENCES = 1

# Sampling controls for the text-generation model.
DO_SAMPLE = True
TOP_K = 100
TOP_P = 0.9


In [95]:
# Answer the question once per retrieved passage by invoking the text2text
# generation endpoint directly, with the passage as context.
logger.info(f'Question:\n{prompt}')
for hit in hits:
    source = hit['_source']
    score = hit['_score']
    passage = source['passage']
    doc_id = source['doc_id']
    passage_id = source['passage_id']
    qa_prompt = f'Answer based on context:\n{passage}\n{prompt}'

    # DO_SAMPLE was defined but never sent. Hugging Face generation only
    # applies top_k / top_p / temperature when sampling is enabled, so without
    # it these settings are presumably ignored -- NOTE(review): confirm the
    # endpoint's inference script honours 'do_sample'.
    generation_request = {
        'text_inputs': qa_prompt,
        'max_length': MAX_LENGTH,
        'num_return_sequences': NUM_RETURN_SEQUENCES,
        'do_sample': DO_SAMPLE,
        'top_k': TOP_K,
        'top_p': TOP_P,
        'temperature': 0.25,
    }

    response = sagemaker_client.invoke_endpoint(
        EndpointName=TEXT_GENERATION_MODEL_ENDPOINT_NAME,
        ContentType=CONTENT_TYPE,
        Body=json.dumps(generation_request).encode('utf-8'),
    )

    model_predictions = json.loads(response['Body'].read())
    # Only the first generated sequence is logged.
    generated_text = model_predictions['generated_texts'][0]
    logger.info(f'Answer:{generated_text}')
    logger.info(f'Reference:\nDocument = {doc_id} | Passage = {passage_id} | Score = {score}')

if not hits:
    # Logger.warn is a deprecated alias (removed in Python 3.13); use warning().
    logger.warning('No matching documents found!')

Questopn:
What is the definition of crime of battery?
Answer:rape
Reference:
Document = 005 | Passage = 45 | Score = 0.64828235
Answer:causing physical harm to another person
Reference:
Document = 005 | Passage = 44 | Score = 0.6335997
Answer:touching of another person
Reference:
Document = 005 | Passage = 7 | Score = 0.6297361
