## Medical Question Answering with the Retrieval Augmented Generation Design Pattern
Use the Python 3 (Data Science 3.0) kernel image and an `ml.m5.2xlarge` instance for this notebook.

The pattern works as follows: generate embeddings for all existing documents and index them in a vector store. Then, for every user query, generate an embedding of the query and search the index by embedding distance. The search results serve as context for the LLM to generate an answer.

Challenges:
* How to manage large documents that exceed the token limit
* How to find the documents relevant to the question being asked

## Key components

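LLM (Large Language Model): Falcon-40b-instruct, available through Amazon SageMaker. This model is used to understand the document chunks and provide an answer in a human-friendly manner. Note that the default text generation endpoint name set later in this notebook refers to a Falcon-7b-instruct deployment; adjust it to match the endpoint you actually deployed.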

Embeddings Model: GPT-J 6B available through Amazon SageMaker. This model will be used to generate a numerical representation of the textual documents.

Vector Store: FAISS, available through LangChain. In this notebook we use this in-memory vector store to hold both the embeddings and the documents. In an enterprise context it could be replaced with a persistent store such as Amazon OpenSearch Service, Amazon RDS for PostgreSQL with pgvector, ChromaDB, Pinecone, or Weaviate.

Index: VectorIndex. The index compares the input embedding with the document embeddings to find the relevant documents.
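
To make the role of the index concrete, here is a minimal sketch of a brute-force nearest-neighbor search using Euclidean (L2) distance. The three-dimensional vectors and the helper names `l2_distance` and `nearest` are hypothetical and exist only for illustration; real embeddings have far more dimensions, and FAISS performs this kind of lookup much more efficiently.

```python
import math
from typing import List, Sequence, Tuple


def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest(
    query: Sequence[float], vectors: List[Sequence[float]], k: int = 2
) -> List[Tuple[int, float]]:
    """Return (index, distance) for the k stored vectors closest to the query."""
    scored = [(i, l2_distance(query, v)) for i, v in enumerate(vectors)]
    return sorted(scored, key=lambda pair: pair[1])[:k]


# Hypothetical toy "document embeddings" and a toy "query embedding".
doc_vectors = [[0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.8, 0.2, 0.1]]
query_vector = [1.0, 0.0, 0.0]

top = nearest(query_vector, doc_vectors, k=2)
for index, distance in top:
    print(f"document {index}: distance {distance:.3f}")
```

The documents at the returned positions are the ones whose text would be passed to the LLM as context.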

### Dataset
To explain this architecture pattern we use documents from MedQA. These documents include medical textbooks on subjects such as pathology, anatomy, and pharmacology.

Download the textbooks that are part of the MedQA Q&A dataset, released with Jin, Di, et al. "What Disease does this Patient Have? A Large-scale Open Domain Question Answering Dataset from Medical Exams." arXiv preprint arXiv:2009.13081 (2020).

More details are available here https://github.com/jind11/MedQA

* Data source : @article{jin2020disease,
  title={What Disease does this Patient Have? A Large-scale Open Domain Question Answering Dataset from Medical Exams},
  author={Jin, Di and Pan, Eileen and Oufattole, Nassim and Weng, Wei-Hung and Fang, Hanyi and Szolovits, Peter},
  journal={arXiv preprint arXiv:2009.13081},
  year={2020} }
  
  

### Data preparation

> **NOTICE**: "This link leads to a Third-Party Dataset. AWS does not own, nor does it have any control over the Third-Party Dataset. You should perform your own independent assessment, and take measures to ensure that you comply with your own specific quality control practices and standards, and the local rules, laws, regulations, licenses and terms of use that apply to you, your content, and the Third-Party Dataset. AWS does not make any representations or warranties that the Third-Party Dataset is secure, virus-free, accurate, operational, or compatible with your own environment and standards. AWS does not make any representations, warranties or guarantees that any information in the Third-Party Dataset will result in a particular outcome or result."

1. The full dataset can be downloaded from https://drive.google.com/file/d/1ImYUSLk9JbgHXOemfvyiDiirluZHPeQw/view?usp=sharing. You can read more about the dataset at https://github.com/jind11/MedQA#data.
2. To speed up this lab, a smaller version of the dataset has already been downloaded: https://d2qrbbbqnxtln.cloudfront.net/Pathology_Robbins.txt

##### Prerequisites

In [None]:
%pip install faiss-cpu==1.7.4 --quiet

In [None]:
%pip install langchain==0.0.222 --quiet

In [None]:
%%capture 

!pip install PyYAML

#### Imports

In [None]:
import requests
import logging 
import boto3
import yaml
import json

##### Setup logging

In [None]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [None]:
logger.info(f'Using requests=={requests.__version__}')
logger.info(f'Using pyyaml=={yaml.__version__}')

#### Setup essentials

In [None]:
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-textembedding-gpt-j-6b-fp16' #INSERT EMBEDDING ENDPOINT NAME IF DIFFERENT
TEXT_GENERATION_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-llm-falcon-7b-instruct-bf16' #INSERT TEXT GENERATION ENDPOINT NAME IF DIFFERENT

REGION_NAME = boto3.session.Session().region_name

#### Encode passages (chunks) using JumpStart's GPT-J text embedding model

We use only 1 of the 20 textbooks in the dataset. It takes about 6 minutes to generate embeddings for one textbook (for example, Pathology). You can index more textbooks if you allow correspondingly more time for execution.

To implement the RAG approach, this notebook uses the LangChain framework, which integrates with a range of services and tools and makes it easier to build patterns such as RAG.
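
Before loading the data, it may help to see what splitting with overlap means. The following is a simplified sketch, not LangChain's implementation: the hypothetical helper `split_with_overlap` cuts purely by character position, whereas `RecursiveCharacterTextSplitter` also tries to break on separators such as paragraphs, lines, and spaces.

```python
from typing import List


def split_with_overlap(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Cut text into chunks of at most chunk_size characters.

    Consecutive chunks share chunk_overlap characters, so a sentence that
    straddles a boundary still appears intact in at least one chunk
    (as long as it is shorter than the overlap).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start : start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last chunk already reaches the end of the text
    return chunks


sample = "abcdefghijklmnopqrstuvwxyz"
chunks = split_with_overlap(sample, chunk_size=10, chunk_overlap=3)
print(chunks)
```

Each chunk in the output starts with the last three characters of the previous one, which is the role `chunk_overlap` plays in the splitter configured below.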

In [None]:
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader, TextLoader

loader = DirectoryLoader("./", glob="**/Pathology*.txt", loader_cls=TextLoader)

documents = loader.load()
# In our testing, a character-based split worked better with this data set.
text_splitter = RecursiveCharacterTextSplitter(
    # Chunks of up to 1000 characters, with 100 characters shared between neighbors.
    chunk_size=1000,
    chunk_overlap=100,
)
docs = text_splitter.split_documents(documents)

In [None]:
print(docs[0])

In [None]:
def avg_doc_length(doc_list):
    """Average page_content length, in characters, across the given documents."""
    return sum(len(doc.page_content) for doc in doc_list) // len(doc_list)

avg_char_count_pre = avg_doc_length(documents)
avg_char_count_post = avg_doc_length(docs)
print(f'Average length among {len(documents)} documents loaded is {avg_char_count_pre} characters.')
print(f'After the split we have {len(docs)} documents, compared with the original {len(documents)}.')
print(f'Average length among {len(docs)} documents (after split) is {avg_char_count_post} characters.')

In [None]:
from typing import Dict, List, Optional

from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler


class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: How many input texts are grouped together
                into a single request to the endpoint.

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        # Never let the batch size drop to 0 (an empty input would otherwise
        # make range() fail with a zero step).
        _chunk_size = max(1, min(chunk_size, len(texts)))

        for i in range(0, len(texts), _chunk_size):
            response = self._embedding_func(texts[i : i + _chunk_size])
            results.extend(response)
        return results


class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

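    def transform_input(self, prompt: List[str], model_kwargs: Optional[Dict] = None) -> bytes:
        # Serialize the texts under the "text_inputs" key, plus any extra model arguments.
        input_str = json.dumps({"text_inputs": prompt, **(model_kwargs or {})})
        return input_str.encode("utf-8")

    def transform_output(self, output) -> List[List[float]]:
        # `output` is the streaming response body returned by the endpoint.
        response_json = json.loads(output.read().decode("utf-8"))
        embeddings = response_json["embedding"]
        return embeddings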


content_handler = ContentHandler()

sagemakerEndpointEmbeddingsJumpStart = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name=TEXT_EMBEDDING_MODEL_ENDPOINT_NAME,
    region_name=REGION_NAME,
    content_handler=content_handler,
)

In [None]:
print(docs[0].page_content)

---
## Semantic Similarity with Amazon SageMaker JumpStart Embedding Models

Semantic search refers to searching for information based on the meaning and concepts of words and phrases, rather than just matching keywords. Embedding models such as the GPT-J 6B model used in this notebook enable semantic search by representing words and sentences as dense vectors that encode their semantic meaning.

Semantic matching is extremely helpful for RAG because it returns results that are conceptually related to the user's query, even if they don't contain the exact keywords. This leads to more relevant and useful search results which can be injected into our LLM's prompts.
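
As a rough illustration of why dense vectors capture relatedness, the sketch below scores two passages against a query with cosine similarity. The vectors are hypothetical, hand-picked values rather than real model output, and cosine similarity is only one common choice of measure.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Hypothetical toy embeddings (real embeddings come from the model endpoint).
query_vec = [0.9, 0.1, 0.2]      # stands in for a question about kidney injury
related_vec = [0.8, 0.2, 0.3]    # stands in for a passage on renal failure
unrelated_vec = [0.1, 0.9, 0.1]  # stands in for a passage on bone fractures

related_score = cosine_similarity(query_vec, related_vec)
unrelated_score = cosine_similarity(query_vec, unrelated_vec)
print(f"related: {related_score:.3f}, unrelated: {unrelated_score:.3f}")
```

The related passage scores higher even though the toy scenario assumes it shares no keywords with the query, which is the behavior semantic matching relies on.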

First, let's look at a sample embedding.

In [None]:
sample_embedding = np.array(sagemakerEndpointEmbeddingsJumpStart.embed_query(docs[0].page_content))
print("Sample embedding of a document chunk: ", sample_embedding)
print("Size of the embedding: ", sample_embedding.shape)

Now create embeddings for the entire document set. Note that this takes about 6 minutes for a single medical textbook.
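
One way to reduce the number of endpoint calls is to send several chunks per request, which is the purpose of the `chunk_size` argument of `embed_documents` defined earlier. The sketch below shows the batching logic in isolation; `fake_embed_batch` is a hypothetical stand-in for the endpoint call, and the helper names are illustrative only.

```python
from typing import Callable, Iterator, List


def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of items with at most batch_size elements."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


def embed_in_batches(
    texts: List[str],
    embed_batch: Callable[[List[str]], List[List[float]]],
    batch_size: int = 5,
) -> List[List[float]]:
    """Embed all texts, calling embed_batch once per batch and keeping order."""
    results: List[List[float]] = []
    for batch in batched(texts, batch_size):
        results.extend(embed_batch(batch))
    return results


def fake_embed_batch(batch: List[str]) -> List[List[float]]:
    # Hypothetical stand-in: a one-dimensional "embedding" per text.
    return [[float(len(text))] for text in batch]


sample_texts = [f"chunk {i}" for i in range(12)]
vectors = embed_in_batches(sample_texts, fake_embed_batch, batch_size=5)
num_requests = len(list(batched(sample_texts, 5)))
print(f"{len(vectors)} embeddings from {num_requests} requests")
```

With twelve texts and a batch size of five, the stand-in is called three times instead of twelve.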

In [None]:
from tqdm.contrib.concurrent import process_map
from multiprocessing import cpu_count

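def generate_embeddings(x):
    # Return the text together with its embedding, as a (text, vector) pair.
    return (x, sagemakerEndpointEmbeddingsJumpStart.embed_query(x))

# One worker process per CPU core.
workers = cpu_count()

texts = [i.page_content for i in docs]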

In [None]:
workers

In [None]:
data = process_map(generate_embeddings, texts, max_workers=workers, chunksize=100)

Next, we insert the embeddings into the FAISS vector store.

In [None]:
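from langchain.vectorstores import FAISS

# Initialize the index from the first two chunks, then add the precomputed
# (text, embedding) pairs for the remaining chunks so nothing is indexed twice.
faiss = FAISS.from_documents(docs[0:2], sagemakerEndpointEmbeddingsJumpStart)
faiss.add_embeddings(data[2:])
faiss.save_local("faiss_index")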

Next, we create a user query and use it to retrieve a response from the vector search and the LLM combined.

In [None]:
query = "What is acute kidney injury?"

In [None]:
query_embedding = faiss.embedding_function(query)
np.array(query_embedding)

In [None]:
relevant_documents = faiss.similarity_search_by_vector(query_embedding)
context = ""
print(f'{len(relevant_documents)} documents are fetched which are relevant to the query.')
print('----')
for i, rel_doc in enumerate(relevant_documents):
    print(f'## Document {i+1}: {rel_doc.page_content}.......')
    print('---')
    context += rel_doc.page_content
context = context.replace("\n", " ")

Now create a prompt template to call the model with the context returned by the vector search. We explicitly instruct the model to answer using only the context provided.
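
Because the retrieved chunks are concatenated into the prompt, the context can grow past what the model accepts. A minimal sketch of one way to guard against that is shown below: it adds chunks in ranked order until a character budget is reached. The helper `build_prompt`, the constant `SKETCH_TEMPLATE`, and the `max_context_chars` budget are assumptions for illustration; a token-based budget would be more precise.

```python
from typing import List

# Hypothetical template for this sketch, modeled on the notebook's prompt.
SKETCH_TEMPLATE = (
    "Please answer the following question using the context provided.\n"
    "If you don't know the answer, just say that you don't know.\n\n"
    "CONTEXT:\n{context}\n=========\nQUESTION: {question}\nANSWER: "
)


def build_prompt(chunks: List[str], question: str, max_context_chars: int = 2000) -> str:
    """Fill the template with as many leading chunks as fit in the budget."""
    selected = []
    used = 0
    for chunk in chunks:
        cleaned = " ".join(chunk.split())  # collapse newlines and repeated spaces
        if used + len(cleaned) > max_context_chars:
            break  # stop at the first chunk that would exceed the budget
        selected.append(cleaned)
        used += len(cleaned)
    return SKETCH_TEMPLATE.format(context="\n".join(selected), question=question)


# Placeholder chunks: the second one is deliberately too large for the budget.
example_chunks = ["First retrieved\npassage.", "x" * 3000, "Third passage."]
example_prompt = build_prompt(example_chunks, "What is acute kidney injury?", max_context_chars=200)
print(example_prompt)
```

Stopping at the first oversized chunk keeps the highest-ranked passages and preserves their order; skipping it and continuing with later chunks would be an equally reasonable policy.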

In [None]:
template = """
        You are a helpful, polite, fact-based agent.
        If you don't know the answer, just say that you don't know.
        Please answer the following question using the context provided. 

        CONTEXT: 
        {context}
        =========
        QUESTION: {question} 
        ANSWER: """


In [None]:
prompt = template.format(context=context, question=query)
print(prompt)

Invoke the endpoint to generate a response from the LLM
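
The request and response are plain JSON, so both sides can be sketched without calling the endpoint. The payload shapes below mirror the ones used in this notebook (`inputs` plus `parameters`, and a list of objects with `generated_text`); the helper names and the sample response body are hypothetical, and other models may use a different schema.

```python
import json


def build_request_body(prompt: str, max_new_tokens: int = 500) -> str:
    """Serialize the prompt and generation parameters as a JSON request body."""
    return json.dumps({"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}})


def extract_generated_text(raw_body: bytes) -> str:
    """Pull the generated text out of a response body, failing loudly if the shape differs."""
    payload = json.loads(raw_body)
    if not isinstance(payload, list) or not payload:
        raise ValueError("expected a non-empty JSON list in the response")
    first = payload[0]
    if not isinstance(first, dict) or "generated_text" not in first:
        raise ValueError("expected an object with a 'generated_text' field")
    return first["generated_text"].strip()


# Hypothetical response body standing in for response_model["Body"].read().
sample_body = json.dumps([{"generated_text": " I don't know. "}]).encode("utf-8")

request_body = build_request_body("QUESTION: example\nANSWER: ")
answer = extract_generated_text(sample_body)
print(answer)
```

Validating the shape before indexing into it turns an opaque `KeyError` or `TypeError` into a message that points at the schema mismatch.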

In [None]:
smr_client = boto3.client("sagemaker-runtime")

In [None]:
response_model = smr_client.invoke_endpoint(
    EndpointName=TEXT_GENERATION_MODEL_ENDPOINT_NAME,
    Body=json.dumps(
        {"inputs": prompt, "parameters": {"max_new_tokens": 500}}
    ),
    ContentType="application/json",
)
response = json.loads(response_model["Body"].read())


In [None]:
print(response[0]["generated_text"])