In [None]:
### this will follow the example at https://python.langchain.com/docs/expression_language/cookbook/sql_db

In [None]:
%pip install langchain

### setup DB on the studio kernel instance

In [None]:
## setup DB
%pwd

In [None]:
%wget https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql

In [2]:
from langchain.prompts import ChatPromptTemplate

sql_template = """Human: Based on the table schema below, write a SQL query and just the SQL, nothing else, that would answer the user's question.:
{schema}


Question: {question}
SQL Query:
"""
sql_prompt = ChatPromptTemplate.from_template(sql_template)

### Athena Connection

In [3]:
from langchain.utilities import SQLDatabase

In [None]:
%pip install "sqlalchemy<2"

In [None]:
%pip install langchain_experimental

In [None]:
%pip install sqlalchemy-access

In [None]:
%pip install PyAthena

In [4]:
import boto3
from botocore.config import Config
from langchain import PromptTemplate,SagemakerEndpoint,SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chains import create_sql_query_chain
from sqlalchemy import create_engine

"""
Here we will build the required parameter to connect athena and query database.
1. Data is stored in S3 and metadata in Glue metastore.
2. Create a profille which will have access to the required service.
3. if the database exists and s3 buckets exists use them else create.

"""
region = 'us-east-1'
athena_url = f"athena.{region}.amazonaws.com" 
athena_port = '443' #Update, if port is different
athena_db = 'demo-emp-deb-2' #from user defined params
glue_databucket_name='athena-query-bucket-bharsrid'
s3stagingathena = f's3://{glue_databucket_name}/athenaresults/' 
athena_wkgrp = 'primary' 
athena_connection_string = f"awsathena+rest://@{athena_url}:{athena_port}/{athena_db}?s3_staging_dir={s3stagingathena}/&work_group={athena_wkgrp}"

"""
Under the hood, LangChain uses SQLAlchemy to connect to SQL databases. 
The SQLDatabaseChain can therefore be used with any SQL dialect 
supported by SQLAlchemy, such as MS SQL, MySQL, MariaDB, PostgreSQL, 
Oracle SQL, and SQLite. 
"""
print(athena_connection_string)
athena_engine = create_engine(athena_connection_string, echo=True)
athena_db_connection = SQLDatabase(athena_engine)


awsathena+rest://@athena.us-east-1.amazonaws.com:443/demo-emp-deb-2?s3_staging_dir=s3://athena-query-bucket-bharsrid/athenaresults//&work_group=primary


In [5]:
def get_schema(_):
    return athena_db_connection.get_table_info()

In [6]:
def run_query(query):
    return athena_db_connection.run(query)

In [7]:
inference_modifier = {
    "temperature": 1,
    "top_p": .999,
    "top_k": 250,
    "max_tokens_to_sample": 300,
    "stop_sequences": ["\n\nSQL Query:"]
}

In [8]:
from langchain.chat_models import BedrockChat
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

chat = BedrockChat(model_id="anthropic.claude-v2", model_kwargs=inference_modifier)

# model = ChatOpenAI()

sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | sql_prompt
    | chat.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [None]:
sql_response.invoke({"question": "How many employees are there?"})



In [9]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

In [10]:
full_chain = (
    RunnablePassthrough.assign(query=sql_response)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: athena_db_connection.run(x["query"]),
    )
    | prompt_response
    | chat
)

In [11]:
full_chain.invoke({"question": "How many employees are there?"})

2023-11-02 19:47:00,913 INFO sqlalchemy.engine.Engine SELECT details.employee_id, details."first name", details."last name" 
FROM details LIMIT %(param_1)s
2023-11-02 19:47:00,914 INFO sqlalchemy.engine.Engine [generated in 0.00146s] {'param_1': 3}
2023-11-02 19:47:02,245 INFO sqlalchemy.engine.Engine SELECT location.employee_id, location.location 
FROM location LIMIT %(param_1)s
2023-11-02 19:47:02,246 INFO sqlalchemy.engine.Engine [generated in 0.00107s] {'param_1': 3}



Human:' and '

Assistant:'. Received 

Human: 

Human: Based on the table schema below, write a SQL query and just the SQL, nothing else, that would answer the user's question.:

CREATE EXTERNAL TABLE details (
	employee_id INT,
	`first name` STRING,
	`last name` STRING
)
ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
WITH SERDEPROPERTIES (
	'field.delim' = ','
)
STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
LOCATION 's3://employee-db-genai-demo/employee/Details/'
TBLPROPERTIES (
	'CrawlerSchemaDeserializerVersion' = '1.0',
	'CrawlerSchemaSerializerVersion' = '1.0',
	'UPDATED_BY_CRAWLER' = 'employee-db-demo-crawler',
	'areColumnsQuoted' = 'false',
	'averageRecordSize' = '18',
	'classification' = 'csv',
	'columnsOrdered' = 'true',
	'compressionType' = 'none',
	'delimiter' = ',',
	'inputformat' = 'org.apache.hadoop.mapred.TextInputFormat',
	'location' = 's3://employee-db-g

2023-11-02 19:47:06,352 INFO sqlalchemy.engine.Engine SELECT details.employee_id, details."first name", details."last name" 
FROM details LIMIT %(param_1)s
2023-11-02 19:47:06,353 INFO sqlalchemy.engine.Engine [cached since 5.441s ago] {'param_1': 3}
2023-11-02 19:47:06,518 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-02 19:47:06,522 INFO sqlalchemy.engine.Engine  SELECT COUNT(*)
FROM details
2023-11-02 19:47:06,523 INFO sqlalchemy.engine.Engine [generated in 0.00194s] {}
2023-11-02 19:47:07,631 INFO sqlalchemy.engine.Engine SELECT location.employee_id, location.location 
FROM location LIMIT %(param_1)s
2023-11-02 19:47:07,632 INFO sqlalchemy.engine.Engine [cached since 5.387s ago] {'param_1': 3}
2023-11-02 19:47:07,845 INFO sqlalchemy.engine.Engine COMMIT


AIMessage(content=' Based on the provided table schemas, sample data, SQL query and response, there are 7 employees in the details table. The SQL query performs a COUNT(*) to return the total number of rows in the details table, which the response shows is 7. So there are 7 total employees.')