#  Text-to-SQL demo using Langchain and Anthropic ClaudeV2 Model in AWS Bedrock

In [9]:
# !pip install langchain==0.0.265
# !pip install openai==0.27.8


### Import Libraries and Set environment variables

In [28]:
import sagemaker, boto3, json
from sagemaker import get_execution_role
from sagemaker.session import Session

from sagemaker.jumpstart.estimator import JumpStartEstimator
from importlib.metadata import version
from langchain.agents import AgentExecutor, create_sql_agent 
from langchain.agents.agent_toolkits import SQLDatabaseToolkit 
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase
from langchain.llms import OpenAI
from sqlalchemy import create_engine
import sys
sys.path.append('config')  # Add the directory to sys.path
from config import config  # Import the config module
# print(dir(config))  
# print(config.db_user)

In [5]:
# sagemaker_session = Session()
# aws_role = sagemaker_session.get_caller_identity_arn()
# aws_region = boto3.Session().region_name
# sess = sagemaker.Session()
!pip install pymysql==1.0.3

Collecting pymysql==1.0.3
  Downloading PyMySQL-1.0.3-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pymysql
Successfully installed pymysql-1.0.3


In [15]:
# model_id, model_version, = (
#     "huggingface-textgeneration1-gpt-j-6b",
# # huggingface-textgeneration1-bloomz-7b1-fp16
#     "1.*",
# )

In [16]:
# from ipywidgets import Dropdown
# from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# # Retrieves all Text Generation models available by SageMaker Built-In Algorithms.
# filter_value = "task == textgeneration1"
# text_generation_models = list_jumpstart_models(filter=filter_value)

# # display the model-ids in a dropdown to select a model for inference.
# model_dropdown = Dropdown(
#     options=text_generation_models,
#     value=model_id,
#     description="Select a model",
#     style={"description_width": "initial"},
#     layout={"width": "max-content"},
# )


In [21]:
# display(model_dropdown)


In [19]:

# # model_version="*" fetches the latest version of the model
# model_id, model_version = model_dropdown.value, "1.*"

In [18]:
# model_id


## Define connection to your database.

In [31]:
import pymysql

def getconn():
    
    print(config.db_host)
    # RDS connection info
    username = config.db_user
    password = config.db_password
    host = config.db_host
    port = config.db_port
    database = config.db_name

    # create the connection
    # conn = pymysql.connect(host=host, user=username, password=password, db=database, port=port)
    conn = pymysql.connect(host="database-1.cg8brvavjrza.eu-central-1.rds.amazonaws.com",
                           user="admin", password="", db="classicmodels", port=3306)
    # create a cursor object
    cursor = conn.cursor()

    # execute the SHOW TABLES command
    cursor.execute("SHOW TABLES;")

    # fetch all the results
    tables = cursor.fetchall()
    # print the tables
    for table in tables:
        print(table)
  
    return conn

def get_db():
    # create the SQLDatabase object
    db = SQLDatabase.from_uri(
        "mysql+pymysql://",
        engine_args = {'creator': getconn},
    )
    return db

db = get_db()

database-1.cg8brvavjrza.eu-central-1.rds.amazonaws.com
('customers',)
('employees',)
('offices',)
('orderdetails',)
('orders',)
('payments',)
('productlines',)
('products',)


### Step 2: Create helper functions for Bedrock API and Langchain


In [32]:
def test_sql_agent(db, llm, question):
    
    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    agent_executor = create_sql_agent(
        llm=llm,
        toolkit=toolkit,
        verbose=True,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        return_intermediate_steps=True,
        use_query_checker=True,
        max_iterations=5
    )
    agent_executor.run(question)

In [33]:
bedrock_runtime = boto3.client(
    service_name = "bedrock-runtime",
    region_name = "us-east-1"
)

In [34]:
modelId = 'anthropic.claude-v2'
accept = 'application/json'
contentType = 'application/json'


In [13]:
# %pip install --upgrade boto3 botocore langchain

In [35]:
from langchain.llms.bedrock import Bedrock

inference_modifier = {
    "max_tokens_to_sample": 4096,
    "temperature": 0,
    # "top_k": 250,
    # "top_p": 1,
    # "stop_sequences": ["\n\nHuman"],
}

llm = Bedrock(
    model_id=modelId,
    client=bedrock_runtime,
    model_kwargs=inference_modifier
)


## Step 3: Invoking model with the prompt

In [36]:
_question = """What is the average difference between order date and ship date grouped by the customer's country and ordered by the average difference?"""

test_sql_agent(
    db = get_db(),
    llm = llm,
    # llm = OpenAI(model_name="text-davinci-003", temperature=0, verbose=True, openai_api_key=config.openai_api_key),
    question = _question
)

database-1.cg8brvavjrza.eu-central-1.rds.amazonaws.com
('customers',)
('employees',)
('offices',)
('orderdetails',)
('orders',)
('payments',)
('productlines',)
('products',)


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m Here is one way to approach this question:

Thought: I need to first check what tables are available in the database.

Action: sql_db_list_tables
Action Input: 
[0m
Observation: [38;5;200m[1;3mcustomers, employees, offices, orderdetails, orders, payments, productlines, products[0m
Thought:[32;1m[1;3m Here is one way to approach answering the question "What is the average difference between order date and ship date grouped by the customer's country and ordered by the average difference?":

Thought: The orders table likely contains the order date and ship date information I need. I should check the schema for this table.

Action: sql_db_schema  
Action Input: orders
[0m
Observation: [33;1m[1;3mError: table_names {'orders\n'} not found in database

In [27]:
from langchain.prompts.prompt import PromptTemplate

TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Be sure to put a semi-colon after query.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Some examples of SQL queries that corrsespond to questions are:

{few_shot_examples}

Question: {input}"""
CUSTOM_PROMPT = PromptTemplate(
    input_variables=["input", "few_shot_examples", "table_info", "dialect"], template=TEMPLATE
)


In [30]:
few_shot_examples = """
Here is an example: 
The question is: How many unique customers placed an order between 1/1/2004 and 1/31/2004? 
The corresponding SQL is: SELECT count(distinct t1.customerNumber) as distinct_customers FROM orders t5 INNER JOIN customers t1 ON t5.customerNumber = t1.customerNumber WHERE t5.orderDate BETWEEN '2004-01-01' AND '2004-01-31';
Here is an example: 
The question is: What is the product name and price of the highest priced item that was ever purchased? 
The corresponding SQL is: SELECT t8.productName, t4.priceEach FROM orders t5 INNER JOIN orderdetails t4 ON t5.orderNumber = t4.orderNumber LEFT JOIN products t8 ON t4.productCode = t8.productCode ORDER BY priceEach DESC LIMIT 1;
Here is an example: 
The question is: What is the name of the customer with the highest order amount on 12/1/2003? 
The corresponding SQL is: SELECT orders.customerName, orders.total_orderAmount FROM (SELECT t5.orderNumber, t1.customerNumber, t1.customerName, SUM(t4.quantityOrdered*t4.priceEach) as total_orderAmount FROM customers t1 INNER JOIN orders t5 ON t1.customerNumber = t5.customerNumber INNER JOIN orderdetails t4 ON t5.orderNumber = t4.orderNumber WHERE t5.orderDate = '2003-12-01' GROUP BY 1,2,3) orders ORDER BY total_orderAmount DESC LIMIT 1;
Here is an example: 
The question is: How many unique orders were placed at the Paris office during the year 2004?
The corresponding SQL is: SELECT count(distinct t5.orderNumber) as orders FROM orders t5 INNER JOIN customers t1 ON t5.customerNumber = t1.customerNumber LEFT JOIN employees t2 ON t1.salesRepEmployeeNumber = t2.employeeNumber LEFT JOIN offices t3 ON t2.officeCode = t3.officeCode WHERE YEAR(t5.orderDate) = 2004 AND t3.city = 'Paris';
"""

question = """ 
The phone number is in the customers table. 
Extract the first three digits of the phone number for the area code.
Revenue is in the payments table. 
What are the top 10 area codes that generated the most revenue filtered by customers in the USA?
"""

table_info = """
payments, customers
"""

In [31]:
db = get_db()
toolkit = SQLDatabaseToolkit(db=db, llm=llm, prompt=CUSTOM_PROMPT)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    # agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    return_intermediate_steps=True,
    use_query_checker=True,
    max_iterations=5
)
agent_executor.run(
    dict(input=question, 
         few_shot_examples=few_shot_examples,
         table_info=db.get_table_info(), 
         dialect=db.dialect
        )
)

database-1.cg8brvavjrza.eu-central-1.rds.amazonaws.com
(('customers',), ('employees',), ('offices',), ('orderdetails',), ('orders',), ('payments',), ('productlines',), ('products',))
('customers',)
('employees',)
('offices',)
('orderdetails',)
('orders',)
('payments',)
('productlines',)
('products',)


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m Here is my thought process and actions to answer the question:

Question: The phone number is in the customers table. Extract the first three digits of the phone number for the area code. Revenue is in the payments table. What are the top 10 area codes that generated the most revenue filtered by customers in the USA?

Thought: I need to first get a list of tables in the database.

Action: sql_db_list_tables  
Action Input: 
[0m
Observation: [38;5;200m[1;3mcustomers, employees, offices, orderdetails, orders, payments, productlines, products[0m
Thought:[32;1m[1;3m Here is the full response:

Question: The phone number is in 

'Agent stopped due to iteration limit or time limit.'