# Advanced Search with SageMaker Endpoints and LangChain

Relevant docs used to create this notebook:
* https://medium.com/mlearning-ai/supercharging-large-language-models-with-langchain-1cac3c103b52
* https://github.com/marshmellow77/langchain-blog/blob/main/mrkl_app.py
* https://langchain.readthedocs.io/en/latest/modules/chains/examples/sqlite.html
* https://langchain.readthedocs.io/en/latest/modules/llms/examples/custom_llm.html
* https://langchain.readthedocs.io/en/latest/modules/llms/integrations/sagemaker.html

In [None]:
# This notebook requires the `Data Science 3.0` kernel due to Python 3.8+ dependency
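If the notebook runs on the wrong kernel, the pinned packages below can fail in confusing ways. A minimal fail-fast check for the Python 3.8+ requirement stated above; the helper name `check_python` is hypothetical:

```python
import sys


def check_python(min_version=(3, 8)):
    # Raise early if the running interpreter is older than min_version.
    if sys.version_info < min_version:
        raise RuntimeError(
            f"Python {min_version[0]}.{min_version[1]}+ required, "
            f"found {sys.version_info.major}.{sys.version_info.minor}"
        )
    return True


check_python()
```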

# Set Up the Environment

Install pinned versions of PyTorch, Hugging Face Transformers, and LangChain.

In [5]:
!pip install -Uq \
    torch==1.13.1 \
    transformers==4.27.2 \
    langchain==0.0.143


In [6]:
import os
import json
from typing import Optional, List, Mapping, Any, Dict
import boto3

from langchain.docstore.document import Document
from langchain.document_loaders import S3DirectoryLoader, WebBaseLoader
from langchain.llms.base import LLM
from langchain.chains.question_answering import load_qa_chain
from langchain import PromptTemplate
from langchain.agents import initialize_agent, Tool
from langchain.tools import BaseTool
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5EncoderModel

# Create a Custom LLM Handler for the HuggingFace Flan T5 Model

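We can now subclass LangChain's `LLM` base class to wrap the model. The wrapper needs an `_llm_type` property and a `_call` method that takes a prompt (plus optional stop sequences) and returns the generated text.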

In [43]:
from functools import lru_cache

from langchain.llms.utils import enforce_stop_tokens

MODEL_ID = 'google/flan-t5-xl'


@lru_cache(maxsize=1)
def load_flan_t5():
    # Load the tokenizer and model once and reuse them on every call.
    # (LangChain's LLM is a pydantic model, so they are kept outside the class.)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return tokenizer, model


class FlanT5LLM(LLM):
    """LangChain wrapper that runs FLAN-T5 XL locally with Hugging Face Transformers."""

    max_new_tokens: int = 200

    @property
    def _llm_type(self) -> str:
        return MODEL_ID

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        tokenizer, model = load_flan_t5()
        inputs = tokenizer(prompt, return_tensors='pt')
        output_ids = model.generate(
            input_ids=inputs['input_ids'],
            max_new_tokens=self.max_new_tokens,
        )
        output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Cut the text at the first stop sequence if the calling chain passed any.
        if stop:
            output = enforce_stop_tokens(output, stop)
        return output
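LangChain chains can pass an optional `stop` list into `_call`, and a well-behaved wrapper cuts the generated text at the first stop sequence. A stdlib-only sketch of that truncation, similar in spirit to LangChain's `enforce_stop_tokens` utility; the helper name `truncate_at_stop` is hypothetical:

```python
import re
from typing import List, Optional


def truncate_at_stop(text: str, stop: Optional[List[str]] = None) -> str:
    """Return text up to (not including) the first occurrence of any stop sequence."""
    if not stop:
        return text
    # Escape each stop sequence so characters such as '(' or '?' match literally.
    pattern = "|".join(re.escape(s) for s in stop)
    return re.split(pattern, text, maxsplit=1)[0]


print(truncate_at_stop("SELECT count(*) FROM Employee\nSQLResult: [(8,)]", ["\nSQLResult:"]))
```

Escaping matters because stop sequences are plain text, not regular expressions.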

In [44]:
llm = FlanT5LLM()
llm("How many US states are there?\n")


'50'

# Answer a Document Question via a `chain`

Now an arbitrary question can be asked about a document.

### Make an Example Document

This example wraps a short text in a LangChain `Document` so it can be passed to the model as context for question answering.

In [45]:
example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay beside her until she gets well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""

docs = [
    Document(
        page_content=example_doc_1,
    )
]

In [46]:
question = """Who went to the hospital?
"""

prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer: """
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

chain = load_qa_chain(
    llm=FlanT5LLM(),
    prompt=PROMPT
)

chain({"input_documents": [docs[0]], "question": question}, return_only_outputs=True)


{'output_text': 'Elizabeth'}
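`load_qa_chain` defaults to the "stuff" strategy, which places every document's text into the `{context}` slot of a single prompt. A stdlib-only sketch of that assembly, assuming a blank-line separator between documents (the helper `stuff_prompt` is hypothetical), makes it easy to inspect exactly what FLAN-T5 receives:

```python
PROMPT_TEMPLATE = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer: """


def stuff_prompt(page_contents, question, separator="\n\n"):
    # Join every document's text into one context string, then fill the template.
    context = separator.join(page_contents)
    return PROMPT_TEMPLATE.format(context=context, question=question)


print(stuff_prompt(
    ["Peter and Elizabeth took a taxi.", "Elizabeth collapsed."],
    "Who went to the hospital?",
))
```

Because all documents share one prompt, many or long documents can exceed the tokenizer's 512-token maximum, the same warning that shows up in the SQL example.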

# Use natural language to execute a SQL query

LangChain can also connect to a database and have the LLM write the SQL query that answers a natural-language question. Here we use the Chinook SQLite database (`Chinook.db`) to ask how many employees are in the `Employee` table.
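The chain's verbose log further down shows the flow: the LLM writes a query after `SQLQuery:`, the query is executed, and its result is appended after `SQLResult:` before the LLM is asked again for the `Answer:`. A minimal, stdlib-only sketch of that two-pass loop, using an in-memory SQLite table and a hypothetical `fake_llm` stub in place of FLAN-T5 (this is not LangChain's implementation):

```python
import sqlite3


def fake_llm(prompt: str) -> str:
    # Hypothetical stand-in for FlanT5LLM that returns canned text so the flow runs offline.
    if prompt.rstrip().endswith("SQLQuery:"):
        return "SELECT count(*) FROM Employee"
    return "There are 2 employees."


def answer_with_sql(conn, question, llm=fake_llm):
    # Pass 1: ask the LLM to write a SQL query for the question.
    transcript = f"Question: {question}\nSQLQuery:"
    sql = llm(transcript).strip()
    # Execute the generated SQL against the database.
    result = conn.execute(sql).fetchall()
    # Pass 2: append the query and its result, then ask for the final answer.
    transcript += f" {sql}\nSQLResult: {result}\nAnswer:"
    answer = llm(transcript).strip()
    return {"sql": sql, "sql_result": result, "answer": answer}


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany("INSERT INTO Employee (LastName) VALUES (?)", [("Adams",), ("Edwards",)])
print(answer_with_sql(conn, "How many employees are in the Employee table?"))
```

Returning the intermediate query and result alongside the answer, as this sketch does, makes it easier to tell whether a bad answer came from the SQL step or from the final generation step.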

In [47]:
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}

Question: {input}"""
PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
)

### TODO: Query Chinook.db with plain SQL to show the table contents, then show the natural-language query.
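For the TODO above, one option is to query the database directly with Python's built-in `sqlite3` module before handing it to the LLM. A minimal sketch; the in-memory table stands in for `Chinook.db` so the cell runs anywhere, and `preview_table` is a hypothetical helper:

```python
import sqlite3


def preview_table(conn, table, limit=5):
    # Table names cannot be bound as parameters, so quote the identifier by hand.
    quoted = '"' + table.replace('"', '""') + '"'
    cursor = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (limit,))
    columns = [col[0] for col in cursor.description]
    return columns, cursor.fetchall()


# Against the notebook's database this would be: conn = sqlite3.connect("Chinook.db")
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER, LastName TEXT)")
conn.executemany("INSERT INTO Employee VALUES (?, ?)", [(1, "Adams"), (2, "Edwards")])
columns, rows = preview_table(conn, "Employee")
print(columns, rows)
print(conn.execute("SELECT count(*) FROM Employee").fetchone()[0])
```

Running `SELECT count(*) FROM Employee` on the real `Chinook.db` gives a ground truth to compare with the chain's answer; the chain's own log shows `[(8,)]`.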

### Set Up the SQL Database for Natural-Language Queries with the LLM

In [48]:
from langchain import SQLDatabase, SQLDatabaseChain

db = SQLDatabase.from_uri(
    "sqlite:///Chinook.db",
    include_tables=['Employee'],
)
db_chain = SQLDatabaseChain(
    llm=FlanT5LLM(),
    database=db,
    verbose=True,
    prompt=PROMPT,
)
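`include_tables=['Employee']` limits which tables LangChain describes in the prompt's `{table_info}` slot, which matters here because the log below warns that the prompt already exceeds the tokenizer's 512-token maximum. LangChain fills this slot from the database schema (and, depending on version, sample rows). A rough, stdlib-only approximation of that idea; the helper name and output layout are assumptions, not LangChain's exact format:

```python
import sqlite3


def build_table_info(conn, include_tables, sample_rows=3):
    # Rough approximation of {table_info}: each included table's CREATE TABLE
    # statement followed by a few sample rows.
    parts = []
    for table in include_tables:
        row = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
        ).fetchone()
        if row is None:
            raise ValueError(f"table {table!r} not found")
        quoted = '"' + table.replace('"', '""') + '"'
        samples = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,)).fetchall()
        rows_text = "\n".join("\t".join(str(value) for value in r) for r in samples)
        parts.append(f"{row[0]}\n/* sample rows:\n{rows_text}\n*/")
    return "\n\n".join(parts)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER, LastName TEXT)")
conn.execute("CREATE TABLE Album (AlbumId INTEGER, Title TEXT)")
conn.executemany("INSERT INTO Employee VALUES (?, ?)", [(1, "Adams"), (2, "Edwards")])
info = build_table_info(conn, ["Employee"])
print(info)
```

Every extra table or sample row lengthens the prompt, so keeping `include_tables` narrow is one lever for staying under the tokenizer's limit.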

In [49]:
result = db_chain.run("How many employees in employee table??")



> Entering new SQLDatabaseChain chain...
How many employees in employee table??
SQLQuery:


Token indices sequence length is longer than the specified maximum sequence length for this model (595 > 512). Running this sequence through the model will result in indexing errors


[32;1m[1;3mSELECT count(*) FROM Employee[0m
SQLResult: [33;1m[1;3m[(8,)][0m
Answer:


Token indices sequence length is longer than the specified maximum sequence length for this model (613 > 512). Running this sequence through the model will result in indexing errors


[32;1m[1;3m"[0m
[1m> Finished chain.[0m


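### TODO: Figure out why `result` does not contain the result of the SQL query

`db_chain.run` returns only the LLM's final answer, so `result` holds whatever the model generated after `Answer:` (here, a lone `"`), not the raw `SQLResult`.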

In [50]:
print(result)

"
