# Intelligent document processing with Gen AI, Amazon Textract and FlanT5 on SageMaker Jumpstart
____


In this notebook, we first use Amazon Textract to extract structured data (tables and key-value pairs) from a document. We then walk through the steps required to perform Q&A on a document: extract its text with Amazon Textract, split the text into chunks, store the chunks in a vector DB, and query a FLAN-T5 model deployed to a SageMaker endpoint to get answers grounded in the document.

# Setup notebook <a id="step1"></a>


In [13]:
!pip install -U langchain 
!pip install pdfplumber
!pip install unstructured
!pip install chromadb
!pip install -U sentence-transformers
!pip install pydantic==1.10.11 #use 1.10.11 version due to stability
#textractor libraries
!python -m pip install -q amazon-textract-caller --upgrade
!python -m pip install -q amazon-textract-prettyprinter --upgrade
!python -m pip install -q amazon-textract-response-parser --upgrade

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
amazon-textract-prettyprinter 0.1.8 requires amazon-textract-response-parser<0.2,>=0.1, but you have amazon-textract-response-parser 1.0.2 which is incompatible.

# Module 1 - Document Extraction 

In [14]:
import boto3
import botocore
import sagemaker
from sagemaker.session import Session
from IPython.display import Image, display, JSON
from textractcaller.t_call import call_textract, Textract_Features, call_textract_expense
from textractprettyprinter.t_pretty_print import convert_table_to_list
from trp import Document
import os
import pandas as pd

# variables
sagemaker_session = Session()
data_bucket = sagemaker.Session().default_bucket()
region = boto3.session.Session().region_name
aws_role = sagemaker_session.get_caller_identity_arn()

# boto3 clients
s3=boto3.client('s3')
textract = boto3.client('textract', region_name=region)

print(f"Region is {region}, IAM Role: {aws_role}, S3 Bucket: {data_bucket}")

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
Region is us-east-1, IAM Role: arn:aws:iam::768517646540:role/idp-SageMakerExecutionRole-Ir1Yes3jcOAq, S3 Bucket: sagemaker-us-east-1-768517646540


## Upload sample data to S3 bucket


The sample documents are in the `samples` directory. The next cell uploads them to the S3 bucket under the `idp/genai` prefix.

In [15]:
# Upload images to S3 bucket:

!aws s3 cp samples s3://{data_bucket}/idp/genai --recursive --only-show-errors

---
# Extract structured data such as tables and key-value pairs using Amazon Textract


### Extracting Tables


In [16]:
prefix = "idp/genai"
file_key = "health_plan.pdf"
resp = call_textract(input_document=f's3://{data_bucket}/{prefix}/{file_key}', features=[Textract_Features.TABLES])
tdoc = Document(resp)
dfs = list()

In [17]:
for page in tdoc.pages:
    for table in page.tables:
        tab_list = convert_table_to_list(trp_table=table)
        print(tab_list)
        dfs.append(pd.DataFrame(tab_list))
df1 = dfs[0]
df2 = dfs[1]

[['PLAN INFORMATION Plan Name ', 'ANYCOMPANY INC '], ['Name And Address Of Employer ', 'ANYCOMPANY, INC 123 ANY STREET ANY CITY CA, 92127 '], ['Name, Address, And Phone Number Of Plan Administrator ', 'ANYCOMPANY, INC 123 ANY STREET ANY CITY CA, 92127 '], ['866-648-0044 ', ''], ['Named Fiduciary ', 'ANYCOMPANY, INC. '], ['Claims Appeal Fiduciary For Medical Claims ', 'UMR '], ['Employer Identification Number Assigned By The IRS ', '12-34567890 '], ['Plan Number Assigned By The Plan ', '511 '], ['Type Of Benefit Plan Provided ', 'Self-funded Health and Welfare Plan providing group health benefits. ']]
[['Type Of Administration ', 'The administration of the Plan is under the supervision of the Plan Administrator. The Plan is not financed by an insurance company and benefits are not guaranteed by a contract of insurance. '], ['Name And Address Of Agent For Service Of Legal Process ', 'ANYCOMPANY, INC 123 ANY STREET ANY CITY CA, 92127 '], ['Funding Of The Plan ', 'Employer and Employee Con

In [18]:
df1

|   | 0 | 1 |
|---|---|---|
| 0 | PLAN INFORMATION Plan Name | ANYCOMPANY INC |
| 1 | Name And Address Of Employer | ANYCOMPANY, INC 123 ANY STREET ANY CITY CA, 92... |
| 2 | Name, Address, And Phone Number Of Plan Admini... | ANYCOMPANY, INC 123 ANY STREET ANY CITY CA, 92... |
| 3 | 866-648-0044 | |
| 4 | Named Fiduciary | ANYCOMPANY, INC. |
| 5 | Claims Appeal Fiduciary For Medical Claims | UMR |
| 6 | Employer Identification Number Assigned By The... | 12-34567890 |
| 7 | Plan Number Assigned By The Plan | 511 |
| 8 | Type Of Benefit Plan Provided | Self-funded Health and Welfare Plan providing ... |


In [19]:
df2

|   | 0 | 1 |
|---|---|---|
| 0 | Type Of Administration | The administration of the Plan is under the su... |
| 1 | Name And Address Of Agent For Service Of Legal... | ANYCOMPANY, INC 123 ANY STREET ANY CITY CA, 92... |
| 2 | Funding Of The Plan | Employer and Employee Contributions Benefits a... |
| 3 | Benefit Plan Year | Benefits begin on January 1 and end on the fol... |
| 4 | Plan Year | January 1 through December 31 |
| 5 | And Other Federal Compliance | It is intended that this Plan comply with all ... |


### Extracting Forms (key-value pairs) data


In [20]:
from textractcaller.t_call import call_textract, Textract_Features
from textractprettyprinter.t_pretty_print import Pretty_Print_Table_Format, Textract_Pretty_Print, get_string


# Call Amazon Textract
response = call_textract(input_document=f's3://{data_bucket}/{prefix}/{file_key}', features=[Textract_Features.FORMS])


print(get_string(textract_json=response,
               table_format=Pretty_Print_Table_Format.csv,
               output_type=[Textract_Pretty_Print.FORMS]))

Key,Value
3.,"Improve population health. Invest in public health programs and prevention to promote healthy lifestyles, reduce health risks, and improve health outcomes."
Policy Reforms:,3. Require insurance companies to cover pre-existing premiums due to health status. 4. Increase healthcare subsidies and tax credits. Make income individuals and families. 5. Invest in preventive care and public health. Increase
Goals:,
Revised,01-01-2022
Mission:,"To implement policy reforms and programs that expand access to healthcare, reduce costs, and improve health outcomes."
Vision:,"To provide high quality, affordable healthcare for all citizens."
Key,Value
Plan Name,ANYCOMPANY INC
Claims Appeal Fiduciary For Medical Claims,UMR
Name And Address Of Employer,"ANYCOMPANY, INC 123 ANY STREET ANY CITY CA, 92127"
Type Of Benefit Plan Provided,Self-funded Health and Welfare Plan providing group health benefits.
Plan Number Assigned By The Plan,511
Employer Identification Number Assigned By The IRS,12-

# Module 2 - Enhancing IDP with Foundation Models

## Select a pre-trained model


In [21]:
# "huggingface-text2text-flan-t5-xl",
# "huggingface-text2text-flan-t5-large",

model_id, model_version = (
    "huggingface-text2text-flan-t5-xl",
    "2.0.0",
)

## Retrieve Artifacts & Deploy a HuggingFace FLAN-T5 Endpoint

---

Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset.


In [22]:
def get_sagemaker_session(local_download_dir) -> sagemaker.Session:
    """Return the SageMaker session."""

    sagemaker_client = boto3.client(
        service_name="sagemaker", region_name=boto3.Session().region_name
    )

    session_settings = sagemaker.session_settings.SessionSettings(
        local_download_dir=local_download_dir
    )

    # the unit test will ensure you do not commit this change
    session = sagemaker.session.Session(
        sagemaker_client=sagemaker_client, settings=session_settings
    )

    return session

Deploy the model as an inference endpoint.

In [23]:
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
from sagemaker.utils import name_from_base
import config

# Hub model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID': 'google/flan-t5-xl',
    'SM_NUM_GPUS': json.dumps(1)
}

endpoint_name = name_from_base(f"{config.SOLUTION_PREFIX}-{model_id}")


# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.1.0"),
    env=hub,
    role=aws_role,
)

# deploy model to SageMaker Inference
# deploy() takes the endpoint name through the `endpoint_name` argument
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    container_startup_health_check_timeout=300,
    endpoint_name=endpoint_name,
)
  


# from sagemaker import image_uris, model_uris, script_uris, hyperparameters
# from sagemaker.model import Model
# from sagemaker.predictor import Predictor
# from sagemaker.utils import name_from_base
# import config


# endpoint_name = name_from_base(f"{config.SOLUTION_PREFIX}-{model_id}")

# inference_instance_type = "ml.g5.2xlarge"

# # Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.
# deploy_image_uri = image_uris.retrieve(
#     region=None,
#     framework=None,  # automatically inferred from model_id
#     image_scope="inference",
#     model_id=model_id,
#     model_version=model_version,
#     instance_type=inference_instance_type,
# )

# # Retrieve the inference script uri. This includes all dependencies and scripts for model loading, inference handling etc.
# deploy_source_uri = script_uris.retrieve(
#     model_id=model_id, model_version=model_version, script_scope="inference"
# )


# # Retrieve the model uri.
# model_uri = model_uris.retrieve(
#     model_id=model_id, model_version=model_version, model_scope="inference"
# )

# #Create model
# model = Model(
#     image_uri=deploy_image_uri,
#     model_data=model_uri,
#     role=aws_role,
#     predictor_cls=Predictor,
#     name=endpoint_name,
# )

# # deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,
# # for being able to run inference through the sagemaker API.
# model_predictor = model.deploy(
#     initial_instance_count=1,
#     instance_type=inference_instance_type,
#     predictor_cls=Predictor,
#     endpoint_name=endpoint_name,
#     # volume_size=30,
# )

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
-----------!

---
# Perform Common sense reasoning and QA on a document

In this section, we will perform common sense reasoning and Q&A on a document. This section does the following:

- Extract text from the document and store it in S3 in plaintext format
- Generate embeddings from the text
- Store the embeddings in an in-memory vector database
- Perform a similarity search on the in-memory vector DB to find the pieces of text most relevant to the user's question
- Generate the context for the LLM using the search results
- Give the model the context and the original question
- Get the answer back from the LLM
- Profit
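
The retrieval steps above can be sketched without any AWS or LangChain dependency. The example below is a minimal illustration, assuming a toy bag-of-words embedding in place of a real embedding model and a plain Python list in place of the vector database. The helpers `embed`, `cosine`, and `top_k_chunks` are hypothetical names used only for this sketch.

```python
import math
import re
from collections import Counter

def embed(text):
    """Toy embedding: lowercase word counts (a stand-in for a real embedding model)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a if w in b)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def top_k_chunks(question, chunks, k=3):
    """Return the k chunks most similar to the question, best first."""
    q_vec = embed(question)
    ranked = sorted(chunks, key=lambda c: cosine(q_vec, embed(c)), reverse=True)
    return ranked[:k]

# Sample chunks, loosely modeled on lines from the discharge summary
chunks = [
    "Patient: John Doe",
    "Admitted: 07-Sep-2020",
    "Meds: Motrin once/week.",
    "Discharged to: Home with support services",
]
question = "Who is the patient?"

# Join the best-matching chunks into the context handed to the model
context = "\n\n".join(top_k_chunks(question, chunks, k=2))
prompt = f"Document: {context}\nQuestion: {question}\n"
```

A real embedding model captures meaning rather than exact word overlap, so it can match a question to a chunk even when they share no words. The ranking and context-building logic stays the same.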


In [56]:
data_bucket

'sagemaker-us-east-1-768517646540'

In [57]:
from textractcaller.t_call import call_textract, Textract_Features
from trp.trp2 import TDocument, TDocumentSchema
from trp.t_pipeline import order_blocks_by_geo
import boto3
import sagemaker
import pdfplumber
import mimetypes
import trp
import json
import uuid



data_bucket = sagemaker.Session().default_bucket()
prefix="idp/genai"
file_key='discharge-summary.png'
doc_path = f's3://{data_bucket}/{prefix}/{file_key}'

s3=boto3.client('s3')
doc_text=list()
page_num=1
prefix=str(uuid.uuid4())

print(f"Bucket is {data_bucket}")

if not doc_text:
    # CAREFUL: this only works with Single pages of scanned PDF documents
    # typically we will have OCR done on the page in advance of the lang chain initiation
    j = call_textract(input_document=doc_path) 

    t_doc = TDocumentSchema().load(j)
    ordered_doc = order_blocks_by_geo(t_doc) #sort by reading order
    trp_doc = trp.Document(TDocumentSchema().dump(ordered_doc))

    # Iterate over the pages in the document
    for page in trp_doc.pages:
        # Collect the text of every line on this page, in reading order
        doc_content = "\n".join(line.text for line in page.lines)

        content_res = bytes(doc_content, 'utf-8')
        s3.put_object(Bucket=data_bucket,
                      Key=f"llm/sample/page-{page_num}.txt",
                      Body=content_res)
        print(f"Page text written into llm/sample/page-{page_num}.txt")
        page_num = page_num + 1

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
Bucket is sagemaker-us-east-1-768517646540
Page text written into llm/sample/page-1.txt


In [58]:
from langchain.document_loaders import S3DirectoryLoader
from langchain.vectorstores import Chroma
from langchain.text_splitter import NLTKTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
import sagemaker

data_bucket = sagemaker.Session().default_bucket()
prefix='llm/sample'

embeddings = HuggingFaceEmbeddings()
loader = S3DirectoryLoader(data_bucket, prefix=prefix)
docs = loader.load()
text_splitter = NLTKTextSplitter(chunk_size=550)
texts = text_splitter.split_documents(docs)
vectordb = Chroma.from_documents(texts, embeddings)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


In [59]:
docs

[Document(page_content="Not a Memorial Hospital\n\nOf Collier\n\nReg: PN/S/11011, Non-Profit\n\nContact: (999)-(888)-(1234)\n\nPhysician Hospital Discharge Summary\n\nProvider: Mateo Jackson, Phd\n\nProvider's Pt ID: 00988277891\n\nPatient Gender: Male\n\nPatient: John Doe\n\nAttachment Control Number: XA/7B/00338763\n\nVisit (Encounter)\n\nAdmitted: 07-Sep-2020\n\nDischarged: 08-Sep-2020\n\nDischarged to: Home with support services\n\nAssessment\n\n35 yo M c/o stomach problems since 2 montsh ago. Patient\n\nReported Symptoms / History\n\nreports epigastric abdominal pain non-radiating. Pain is\n\nof present illness:\n\ndescribed as gnawing and burning, intermitent lasting 1-2\n\nhours, and gotten progressively worse. Antacids used to\n\nalleviate pain but not anymore; nothing exhacerbates pain.\n\nPain unrelated to daytime or to meals. Patient denies\n\nconstipation or diarrhea. Patient denies blood in stool but\n\nhave noticed them darker. Patient also reports nausea.\n\nDenies recen

## Using HuggingFace FLAN-T5 XL SageMaker endpoint

Our vector DB is now loaded with the chunks of the document. All that is left is to take a question from the user, perform a similarity search on the vector DB, and pass the retrieved context and the question to the model. Before that, let's define a custom QA chain around the SageMaker endpoint, with a prompt template that instructs the model to answer a question from the given text. We keep the prompt simple here, although more detailed prompt engineering could yield a more robust QA prompt. We use LangChain's `PromptTemplate` to craft the prompt.

Let's first set the payload parameters of (output) text generation. When invoking the endpoint, our JSON payload can include any desired inference parameters that help control the length, sampling strategy, and output token sequence restrictions. 

You may refer to this [documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) by HuggingFace for detailed explanation on generation parameters. 
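
As a rough illustration of what travels over the wire, the sketch below builds a request body with the prompt under `inputs` and the generation settings under `parameters`, then parses a response shaped as a list of objects with a `generated_text` field. This mirrors the content handler defined in this notebook. The helper names `build_payload` and `parse_response` and the sample response bytes are made up for the example.

```python
import io
import json

def build_payload(prompt, parameters):
    """Serialize the prompt and generation parameters into a JSON request body."""
    return json.dumps({"inputs": prompt, "parameters": parameters}).encode("utf-8")

def parse_response(body):
    """Read a file-like response body and return the first generated text."""
    response_json = json.loads(body.read().decode("utf-8"))
    return response_json[0]["generated_text"]

params = {"temperature": 0.97, "max_length": 100, "do_sample": True}
payload = build_payload("Question: What is the patient's name?", params)

# A fabricated response body, standing in for what the endpoint would return
fake_body = io.BytesIO(b'[{"generated_text": "John Doe"}]')
answer = parse_response(fake_body)
```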

In [26]:
FLAN_T5_PARAMETERS = {
    "temperature": 0.97,           # the value used to modulate the next token probabilities.
    "max_length": 100,             # restrict the length of the generated text.
    "num_return_sequences": 5,     # number of output sequences returned.
    "top_k": 50,                   # in each step of text generation, sample from only the top_k most likely words.
    "top_p": 0.95,                 # in each step of text generation, sample from the smallest possible set of words with cumulative probability top_p.
    "do_sample": True              # whether or not to use sampling; use greedy decoding otherwise.
}
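
To make `top_k` and `top_p` concrete, here is a small sketch of how the two filters narrow a next-token distribution before sampling. It is a simplified illustration on a hand-written toy distribution, not the model server's actual implementation, and `filter_top_k_top_p` is a hypothetical helper.

```python
def filter_top_k_top_p(probs, top_k, top_p):
    """Keep the top_k most likely tokens, then the shortest prefix whose
    cumulative probability reaches top_p, and renormalize what is left."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    for token, p in ranked:
        kept.append((token, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {token: p / total for token, p in kept}

# Toy next-token probabilities (values chosen for the example)
toy_probs = {"John": 0.5, "Jane": 0.25, "the": 0.125, "a": 0.0625, "zebra": 0.0625}
candidates = filter_top_k_top_p(toy_probs, top_k=4, top_p=0.8)
print(candidates)
```

With `top_k=4` the least likely token is dropped first, and `top_p=0.8` then stops after the third token, since the first three already cover more than 80% of the probability mass. Sampling would then draw from the three remaining tokens.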

In [60]:
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import json
from typing import Dict

# Define a handler class to transform input from LLM to a format that SageMaker endpoint expects.
class QAContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    
    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]
    


qa_content_handler = QAContentHandler()

prompt_template="""Given the following text from a document, answer the question to the best of your abilities. Answer only from the provided document; if you do not know the answer, 
just say you don't know. DO NOT make up an answer.

Document: {context}
Question: {input_query}
"""

prompt = PromptTemplate(input_variables=["input_query", "context"],
                        template=prompt_template)
llm = SagemakerEndpoint(
        endpoint_name=predictor.endpoint_name, # endpoint deployed above; replace with your endpoint name if needed
        region_name=region,
        model_kwargs=FLAN_T5_PARAMETERS,
        content_handler=qa_content_handler)

qa_chain = LLMChain(
    llm=llm,
    prompt=prompt
)



In [61]:
question="What is the patient's name?"

## Common sense reasoning / natural language inference

Perform a similarity search on the document with `k=3` which means it will return the top-3 chunks of text that are relevant to the question asked.

In [62]:
similar_docs = vectordb.similarity_search(question, k=3) #see also: vectordb.max_marginal_relevance_search(question, k=3)
context_list = [a.page_content for a in similar_docs]
metadata_list = [a.metadata.get('source') for a in similar_docs]
context = "\n\n".join(context_list)
context

"Not a Memorial Hospital\n\nOf Collier\n\nReg: PN/S/11011, Non-Profit\n\nContact: (999)-(888)-(1234)\n\nPhysician Hospital Discharge Summary\n\nProvider: Mateo Jackson, Phd\n\nProvider's Pt ID: 00988277891\n\nPatient Gender: Male\n\nPatient: John Doe\n\nAttachment Control Number: XA/7B/00338763\n\nVisit (Encounter)\n\nAdmitted: 07-Sep-2020\n\nDischarged: 08-Sep-2020\n\nDischarged to: Home with support services\n\nAssessment\n\n35 yo M c/o stomach problems since 2 montsh ago.\n\nPatient\n\nReported Symptoms / History\n\nreports epigastric abdominal pain non-radiating.\n\nPatient denies\n\nconstipation or diarrhea.\n\nPatient denies blood in stool but\n\nhave noticed them darker.\n\nPatient also reports nausea.\n\nDenies recent illness or fever.\n\nHe also reports fatigue\n\nsince 2 weeks ago and bloating after eating.\n\nPatient ID: NARH-36640\n\nROS: Negative except for above findings\n\nMeds: Motrin once/week.\n\nTums previously.\n\nPMHX: Back pain and muscle spasms.\n\nNo HX of surge

## Question answering

We can now use the custom QA chain with the SageMaker endpoint to provide an answer to our question, based on the content of the documents as shown below.
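
Before the chain calls the endpoint, the template placeholders are filled with the retrieved context and the question. Below is a minimal sketch of that substitution using plain `str.format`, which behaves much like LangChain's default template format. The shortened template and the sample values are chosen purely for illustration.

```python
# A shortened version of the QA prompt, with the same two placeholders
template = (
    "Answer only from the provided document. "
    "If you do not know the answer, say you don't know.\n\n"
    "Document: {context}\n"
    "Question: {input_query}\n"
)

# Fill the placeholders the way the chain does before invoking the model
filled = template.format(
    context="Patient: John Doe\n\nAdmitted: 07-Sep-2020",
    input_query="What is the patient's name?",
)
print(filled)
```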

In [63]:
qa_chain.run({
    'input_query': question,
    'context': context
    })

'John Doe'

# Gradio

## Automated chatbot

In [31]:
similar_docs = vectordb.similarity_search(question, k=3)
print(similar_docs)

[Document(page_content='ANYCOMPANY INC\n\nPLAN INFORMATION Plan Name\n\nANYCOMPANY, INC\n\nName And Address Of Employer\n\n123 ANY STREET ANY CITY CA, 92127\n\nANYCOMPANY, INC\n\nName, Address, And Phone Number\n\n123 ANY STREET ANY CITY CA, 92127\n\nOf Plan Administrator\n\n866-648-0044\n\nANYCOMPANY, INC.\n\nNamed Fiduciary\n\nClaims Appeal Fiduciary For Medical Claims\n\nUMR\n\n12-34567890\n\nEmployer Identification Number Assigned By\n\nThe IRS\n\nPlan Number Assigned By The Plan\n\n511\n\nSelf-funded Health and Welfare Plan providing\n\nType Of Benefit Plan Provided\n\ngroup health benefits.', metadata={'source': 's3://sagemaker-us-east-1-768517646540/llm/sample/page-1.txt'}), Document(page_content='You are a valued Employee of ANYCOMPANY, INC., and Your employer is pleased to sponsor this Plan\n\nto provide benefits that can help meet Your health care needs.\n\nPlease read this document carefully and\n\ncontact Your Human Resources or Personnel office if You have questions or if 

In [32]:
# let's delete the index, we will create it again
ids_to_delete = vectordb.get(where={"source": f"s3://{data_bucket}/llm/sample/page-1.txt"})['ids']
vectordb.delete(ids=ids_to_delete)

In [33]:
from langchain.document_loaders import S3DirectoryLoader
from langchain.vectorstores import Chroma
from langchain.text_splitter import NLTKTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
import sagemaker

data_bucket = sagemaker.Session().default_bucket()
prefix='llm/sample'

embeddings = HuggingFaceEmbeddings()
loader = S3DirectoryLoader(data_bucket, prefix=prefix)
docs = loader.load()
text_splitter = NLTKTextSplitter(chunk_size=550)
texts = text_splitter.split_documents(docs)
vectordb = Chroma.from_documents(texts, embeddings)

In [80]:
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import json
from typing import Dict



FLAN_T5_PARAMETERS = {
    "temperature": 0.97,           # the value used to modulate the next token probabilities.
    "max_length": 100,             # restrict the length of the generated text.
    "num_return_sequences": 5,     # number of output sequences returned.
    "top_k": 50,                   # in each step of text generation, sample from only the top_k most likely words.
    "top_p": 0.95,                 # in each step of text generation, sample from the smallest possible set of words with cumulative probability top_p.
    "do_sample": True              # whether or not to use sampling; use greedy decoding otherwise.
}


# Define a handler class to transform input from LLM to a format that SageMaker endpoint expects.
class QAContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    
    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]
    


qa_content_handler = QAContentHandler()



def create_prompt_template():
    _template = """
    
Given the following chat history and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Skip the preamble and just get to the question.

<chat_history>    
{chat_history}
</chat_history>    
<follow_up_question>
{question}
</follow_up_question>
"""
    conversation_prompt = PromptTemplate.from_template(_template)
    return conversation_prompt

template = """

Answer the question as truthfully as possible strictly using only the provided text, and if the answer is not contained within the text, say "I don't know". Skip any preamble text and reasoning and give just the answer. If the user greets you, just greet them back.

<text>
{context}
{chat_history}
</text>

<question>
{input_query}
</question>

<answer>
"""



prompt = PromptTemplate(input_variables=["input_query", "context", "chat_history"],
                        template=template)
llm = SagemakerEndpoint(
        endpoint_name=predictor.endpoint_name, # endpoint deployed above; replace with your endpoint name if needed
        region_name=region,
        model_kwargs=FLAN_T5_PARAMETERS,
        content_handler=qa_content_handler)

qa_chain = LLMChain(
    llm=llm,
    prompt=prompt
)

questions = [     
    "Hi AI, I am Bob Doe. How are you?",
    "Who is the patient?",
    "Why was John admitted to the hospital?",
    "Do you remember my name?",
    "What past health issues does John have?",
]


chat_history = []

for question in questions:
    similar_docs = vectordb.similarity_search(question, k=3) #see also: vectordb.max_marginal_relevance_search(question, k=3)
    context_list = [a.page_content for a in similar_docs]
    metadata_list = [a.metadata.get('source') for a in similar_docs]
    context = "\n\n".join(context_list)
    result = qa_chain.run({
        'input_query': question,
        'context': context,
        'chat_history': chat_history
    })
    # qa_chain.run returns a string; store the whole answer (result[0] would keep only its first character)
    chat_history.append((question, result))
    print(f"-> **Question**: {question} \n")
    print(f"**Answer**: {result} \n")

-> **Question**: Hi AI, I am Bob Doe. How are you? 

**Answer**: I don't know. 

-> **Question**: Who is the patient? 

**Answer**: John Doe 

-> **Question**: Why was John admitted to the hospital? 

**Answer**: He has been experiencing stomach problems. 

-> **Question**: Do you remember my name? 

**Answer**: I don't know. 

-> **Question**: What past health issues does John have? 

**Answer**: Back pain and muscle spasms. No HX of surgery. NKDA. F 



In [74]:
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import json
from typing import Dict
import random
import gradio as gr
from langchain.chains import ConversationalRetrievalChain

FLAN_T5_PARAMETERS = {
    "temperature": 0.95,           # the value used to modulate the next token probabilities.
    "max_length": 150,             # restrict the length of the generated text.
    "num_return_sequences": 5,     # number of output sequences returned.
    "top_k": 50,                   # in each step of text generation, sample from only the top_k most likely words.
    "top_p": 0.95,                 # in each step of text generation, sample from the smallest possible set of words with cumulative probability top_p.
    "do_sample": True              # whether or not to use sampling; use greedy decoding otherwise.
}


# Define a handler class to transform input from LLM to a format that SageMaker endpoint expects.
class QAContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    
    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]
    


qa_content_handler = QAContentHandler()



def create_prompt_template():
    _template = """
    
Given the following chat history and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Skip the preamble and just get to the question.

<chat_history>    
{chat_history}
</chat_history>    
<follow_up_question>
{question}
</follow_up_question>
"""
    conversation_prompt = PromptTemplate.from_template(_template)
    return conversation_prompt

template = """

Answer the question as truthfully as possible strictly using only the provided text, and if the answer is not contained within the text, say "I don't know". Skip any preamble text and reasoning and give just the answer. If the user greets you, just greet them back.

<text>
{context}
{chat_history}
</text>

<question>
{input_query}
</question>

<answer>
"""



prompt = PromptTemplate(input_variables=["input_query", "context", "chat_history"],
                        template=template)

llm = SagemakerEndpoint(
        endpoint_name=predictor.endpoint_name, # endpoint deployed above; replace with your endpoint name if needed
        region_name=region,
        model_kwargs=FLAN_T5_PARAMETERS,
        content_handler=qa_content_handler)

qa_chain = LLMChain(
    llm=llm,
    prompt=prompt
)

questions = [     
    "Hi AI, I am Bob Doe. How are you?",
    "Who is the employer?",
    "What is the address of the employer?",
    "Do you remember my name?"
]


chat_history = []


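def qa_chain_fn(message, history):
    # Retrieve context for the incoming message instead of reusing a stale global
    similar_docs = vectordb.similarity_search(message, k=3)
    context = "\n\n".join(doc.page_content for doc in similar_docs)
    # LLMChain.run returns the generated text as a plain string
    result = qa_chain.run({
        'input_query': message,
        'context': context,
        'chat_history': chat_history
    })
    answer = result.strip()
    chat_history.append((message, answer))
    return answer

gr.ChatInterface(qa_chain_fn).launch()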

ImportError: cannot import name 'RootModel' from 'pydantic' (/opt/conda/lib/python3.10/site-packages/pydantic/__init__.cpython-310-x86_64-linux-gnu.so)

## Use LangChain to create LLM class for Text extraction and SageMaker endpoint calls

---
Now that our endpoint is deployed, we can use it to summarize our document. We will use LangChain to perform inference, which first requires creating two LLM classes based on LangChain's base LLM class. Read more about the LangChain LLM class [here](https://python.langchain.com/en/latest/modules/models/llms.html). Specifically, we will create two custom LLM classes:

1. An LLM class to extract text from our document using Amazon Textract
2. An LLM class to be able to make calls to the SageMaker endpoint where our FlanT5 model is deployed

The purpose of building these custom LLM classes is to be able to easily use these constructs with LangChain's pre-built or custom chains. Read more about LangChain chains [here](https://python.langchain.com/en/latest/modules/chains.html)
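
The idea of wrapping each step behind a common text-in, text-out interface can be illustrated without LangChain. The sketch below is a plain-Python analogy, not LangChain's API: `TextStep`, `FakeOcr`, `FakeSummarizer`, and `run_pipeline` are hypothetical names, and both steps return canned strings instead of calling Amazon Textract or a SageMaker endpoint.

```python
from typing import Dict, List, Protocol

class TextStep(Protocol):
    """Anything that maps an input string to an output string."""
    def __call__(self, text: str) -> str: ...

class FakeOcr:
    """Stand-in for the OCR step: maps a document path to its text."""
    def __init__(self, pages: Dict[str, List[str]]):
        self.pages = pages

    def __call__(self, text: str) -> str:
        return "\n".join(self.pages[text])

class FakeSummarizer:
    """Stand-in for the model call: returns the first line as a 'summary'."""
    def __call__(self, text: str) -> str:
        return text.splitlines()[0]

def run_pipeline(steps: List[TextStep], initial: str) -> str:
    """Feed the output of each step into the next one."""
    value = initial
    for step in steps:
        value = step(value)
    return value

# Hypothetical document path and canned OCR lines for the example
ocr = FakeOcr({
    "s3://bucket/doc.png": ["Physician Hospital Discharge Summary", "Patient: John Doe"],
})
summary = run_pipeline([ocr, FakeSummarizer()], "s3://bucket/doc.png")
```

Because both steps share one interface, they can be composed in any order a chain requires. That is the property the custom LLM classes below aim to provide within LangChain.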

In [None]:
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from typing import Optional, List
from textractcaller.t_call import call_textract, Textract_Features
from trp.trp2 import TDocumentSchema
from trp.t_pipeline import order_blocks_by_geo_x_y
import trp
import json

class OcrLLM(LLM):    
    @property
    def _llm_type(self) -> str:
        return "custom"
    
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # prompt is the document path
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        j = call_textract(input_document=prompt)
        t_doc = TDocumentSchema().load(j)
        ordered_doc = order_blocks_by_geo_x_y(t_doc)
        trp_doc = trp.Document(TDocumentSchema().dump(ordered_doc))
        document = str()
        for page in trp_doc.pages:
            for line in page.lines:
                document = document + "\n" + line.text
        return document

ocrllm = OcrLLM()
ocr_prompt = PromptTemplate(
    input_variables=["doc_path"],
    template="{doc_path}",
)
ocr_chain = LLMChain(llm=ocrllm, prompt=ocr_prompt)

In [None]:
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import LLMChain
from langchain.prompts import load_prompt, PromptTemplate
import json

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

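    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        # The Hugging Face LLM container deployed above expects {"inputs": ..., "parameters": {...}}
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode('utf-8')
    
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        # The response is a list of objects with a "generated_text" field
        return response_json[0]["generated_text"]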

content_handler = ContentHandler()
prompt_template = """Write a short summary for this text using your own words without quoting text directly from the provided text. Make sure to only include full and complete sentences: 
{document}"""
prompt = PromptTemplate.from_template(prompt_template)

llm_chain = LLMChain(
    llm=SagemakerEndpoint(
        endpoint_name=predictor.endpoint_name, # endpoint deployed above; replace with your endpoint name if needed
        region_name=region,
        model_kwargs={"temperature":0.97,
                      "max_length": 150,
                      "num_return_sequences": 3,
                      "top_k": 50,
                      "top_p": 0.95,
                      "do_sample": True},
        content_handler=content_handler
    ),
    prompt=prompt
)