# Natural Language to SQL Notebook
For this project, the AI21 Grande Instruct model seems like the most appropriate choice.
We followed AI21's SageMaker example notebook to create the endpoint:
https://github.com/AI21Labs/SageMaker/blob/main/J2_GrandeInstruct_example_model_use.ipynb


## Prerequisites
- Make sure boto3 is installed and your credentials are stored in ~/.aws/credentials (run the "aws configure" command to set them up). Also set your default region to one where the model package is available, e.g. us-east-1.
- If you are not running in SageMaker Studio, you need to create a SageMaker execution role.
- Your account needs access to "ml.g5.12xlarge" instances.


In [None]:
model_package_map = {
    "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c",
    "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-grande-instruct-v1-1-43-b1704f916990312a8e21b249a0bd479c"
}

In [None]:
#%pip install -qU "sagemaker"
from sagemaker import ModelPackage
from sagemaker import get_execution_role
import sagemaker as sage
import boto3

In [None]:
boto3.__version__

In [None]:
# ! pip install -U "ai21[SM]"
# ! pip install langchain_experimental langchain
import ai21

In [None]:
region = boto3.Session().region_name
if region not in model_package_map:
    # Raising a plain string is a TypeError in Python 3; raise a real exception instead.
    raise ValueError(f"Unsupported region: {region}")

model_package_arn = model_package_map[region]
region

In [None]:
# Create a role and give it full SageMaker access (https://stackoverflow.com/questions/47710558/the-current-aws-identity-is-not-a-role-for-sagemaker).
# This is only needed if you are running this notebook outside of SageMaker Studio.
SAGEMAKER_ROLE = 'sagemaker-role' # TODO replace the role name

In [None]:
try: 
    role = get_execution_role()
except ValueError: # workaround if you are running this notebook locally
    iam = boto3.client('iam')
    role = iam.get_role(RoleName=SAGEMAKER_ROLE)['Role']['Arn']
sagemaker_session = sage.Session()

runtime_sm_client = boto3.client("runtime.sagemaker")
print(f"Using role: {role}")


In [None]:
endpoint_name = "j2-grande-instruct-g5-12"

content_type = "application/json"

real_time_inference_instance_type = (
    "ml.g5.12xlarge" # Optimal cost-latency tradeoff
)

## Deploy the model

In [None]:
# create a deployable model from the model package.
model = ModelPackage(
    role=role, model_package_arn=model_package_arn, sagemaker_session=sagemaker_session
)

# Deploy the model
predictor = model.deploy(1, real_time_inference_instance_type, endpoint_name=endpoint_name, 
                         model_data_download_timeout=3600,
                         container_startup_health_check_timeout=600,
                        )

## Use this cell to recreate the endpoint after it has been deleted

In [None]:
sagemaker_client = boto3.client('sagemaker')

# The name of the endpoint configuration associated with this endpoint.
endpoint_config_name='j2-grande-instruct-g5-12'


create_endpoint_response = sagemaker_client.create_endpoint(
                                            EndpointName=endpoint_name, 
                                            EndpointConfigName=endpoint_config_name) 
create_endpoint_response
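
`create_endpoint` returns immediately while the endpoint is still in the `Creating` state, so invoking it right away fails. A minimal polling sketch, assuming a client that exposes boto3's `describe_endpoint` (whose response carries an `EndpointStatus` field); the helper name `wait_for_endpoint` and the timeout defaults are our own choices. boto3 also ships a built-in `endpoint_in_service` waiter (`sagemaker_client.get_waiter("endpoint_in_service")`) that covers the same need.

```python
import time
from typing import Any, Callable


def wait_for_endpoint(client: Any, endpoint_name: str,
                      poll_seconds: float = 30.0, timeout_seconds: float = 3600.0,
                      sleep: Callable[[float], None] = time.sleep) -> str:
    """Poll describe_endpoint until the endpoint is InService.

    Raises RuntimeError if the endpoint reports Failed, and TimeoutError if it
    is still in another state after `timeout_seconds`.
    """
    waited = 0.0
    while True:
        status = client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        if waited >= timeout_seconds:
            raise TimeoutError(f"Endpoint {endpoint_name} still '{status}' after {waited:.0f}s")
        sleep(poll_seconds)  # injectable so the loop can be tested without waiting
        waited += poll_seconds
```

Usage in this notebook would be `wait_for_endpoint(sagemaker_client, endpoint_name)` right after the `create_endpoint` call.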


## Experiments

### Play around with the model

In [None]:
instruction = """
Create an executable SQL statement from instruction:

Instruction:
What were the average monthly $ sales for product 03821 in EMEA last year?

SQL Query:
"""

response = ai21.Completion.execute(sm_endpoint=endpoint_name,
                                   prompt=instruction,
                                   maxTokens=100,
                                   temperature=0,
                                   numResults=1)

print(response['completions'][0]['data']['text'])

In [None]:
instruction = """Write an engaging product description for clothing eCommerce site.
Product: Humor Men's Graphic T-Shirt.
Description:

"""

response = ai21.Completion.execute(sm_endpoint=endpoint_name,
                                   prompt=instruction,
                                   maxTokens=100,
                                   temperature=0,
                                   numResults=1)

print(response['completions'][0]['data']['text'])

### Integrate LangChain into the workflow

In [None]:
from langchain.prompts import PromptTemplate

prompt = PromptTemplate.from_template("What is a good name for a company that makes {product}?")
prompt.format(product="colorful socks")

In [None]:
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from typing import Dict
import json 

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"prompt": prompt, **model_kwargs})
        # print(input_str.encode('utf-8'))
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        print(response_json["completions"][0]["data"]["text"])

        return response_json["completions"][0]["data"]["text"]
    

content_handler = ContentHandler()
parameters = {"maxTokens": 80, "temperature": 0, "numResults": 1}

llm_ai21 = SagemakerEndpoint(
    endpoint_name=endpoint_name,
    region_name=region,
    model_kwargs=parameters,
    content_handler=content_handler,
)
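
The handler above sends the prompt and the model parameters as one flat JSON object and, on the way back, reads the raw response stream (hence `output.read()`) before picking `completions[0].data.text`. The same round trip can be checked offline without SageMaker or LangChain. In this sketch the function names are hypothetical, `io.BytesIO` stands in for the streaming body the endpoint returns, and the response shape is the one this notebook already indexes into.

```python
import io
import json
from typing import Any, BinaryIO, Dict


def encode_request(prompt: str, model_kwargs: Dict[str, Any]) -> bytes:
    # Same payload layout as ContentHandler.transform_input: prompt plus flat model parameters.
    return json.dumps({"prompt": prompt, **model_kwargs}).encode("utf-8")


def decode_response(stream: BinaryIO) -> str:
    # Same extraction as ContentHandler.transform_output: text of the first completion.
    body = json.loads(stream.read().decode("utf-8"))
    return body["completions"][0]["data"]["text"]


# Offline check with a fake response body shaped like the endpoint's output.
payload = encode_request("List all software engineers",
                         {"maxTokens": 80, "temperature": 0, "numResults": 1})
fake_body = io.BytesIO(json.dumps(
    {"completions": [{"data": {"text": "SELECT * FROM employees;"}}]}).encode("utf-8"))
print(json.loads(payload))
print(decode_response(fake_body))
```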

In [None]:
from langchain import SQLDatabase
from langchain_experimental.sql.base import SQLDatabaseSequentialChain

# Reference: https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#sqldatabasesequentialchain
from sqlalchemy.exc import ProgrammingError

In [None]:
RDS_PORT="5432"
RDS_USERNAME="mihirma"
RDS_PASSWORD=""
RDS_DB_NAME = "postgres" 
RDS_ENDPOINT = "localhost"
RDS_URI = f"postgresql+psycopg2://{RDS_USERNAME}:{RDS_PASSWORD}@{RDS_ENDPOINT}:{RDS_PORT}/{RDS_DB_NAME}"

db = SQLDatabase.from_uri(RDS_URI,
                           include_tables=["employees", "projects", "timelog"],
                           sample_rows_in_table_info=4)
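
Building `RDS_URI` by plain string interpolation breaks as soon as the password contains characters such as `@`, `:` or `/`, and it keeps credentials in the notebook. A minimal sketch that reads the settings from environment variables and percent-encodes the user name and password with `urllib.parse.quote_plus` (the approach SQLAlchemy's documentation suggests for special characters). The environment variable names and the helper name are our own assumptions.

```python
import os
from urllib.parse import quote_plus


def build_postgres_uri(user: str, password: str, host: str, port: str, db_name: str) -> str:
    """Return a SQLAlchemy psycopg2 URI with user and password percent-encoded."""
    # quote_plus escapes '@', ':' and '/' so they are not parsed as URI delimiters.
    return (f"postgresql+psycopg2://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}:{port}/{db_name}")


# Hypothetical environment variable names; export them before starting the notebook.
uri = build_postgres_uri(
    user=os.environ.get("RDS_USERNAME", ""),
    password=os.environ.get("RDS_PASSWORD", ""),
    host=os.environ.get("RDS_ENDPOINT", "localhost"),
    port=os.environ.get("RDS_PORT", "5432"),
    db_name=os.environ.get("RDS_DB_NAME", "postgres"),
)
```

The result can be passed to `SQLDatabase.from_uri(uri, ...)` in place of `RDS_URI`.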

### SQL Database sequential chain
Performance: poor, and much worse than expected. Asked the same question 10 times, it answered correctly at most 5 times (see the temperature results below).

In [None]:
EXAMPLE_PROMPTS = [
    "What is Velma's employee id?",
    "What is the email address of the Chief Technology Officer?",
    "How many hours did Peter work in 2022?",
    "How many Software Engineers does the company have?",
    "Who are the Software Engineers of the company?",
    "Who are the Employees of the company?",
    "List all Software Engineers who have Peter as their manager",
    "Who are the Software Engineers working on the 'Restaurant Management App' project?"
]
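results = []
# The chain keeps no state between calls, so build it once and ask the
# same question 10 times to estimate how often it answers correctly.
db_chain = SQLDatabaseSequentialChain.from_llm(
    llm_ai21,
    db,
    verbose=True,
    use_query_checker=False,
    return_intermediate_steps=True,
)
for _ in range(10):
    try:
        result = db_chain(EXAMPLE_PROMPTS[4])
        results.append(result["result"])
    except ProgrammingError as exc:
        # Invalid SQL from the model surfaces as a database error; record it as a failure.
        print(f"\n\n{exc}")
        results.append(None)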
    

In [None]:
#result["result"]
# Success rate over the 10 runs above, by temperature
# (set through `parameters` before creating llm_ai21):
# temperature 0    -> 0/10
# temperature 0.5  -> 4/10
# temperature 0.75 -> 5/10
# temperature 1    -> 3/10

results
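
The success counts above were scored by hand. A minimal sketch for tallying the `results` list automatically, assuming you supply a predicate that decides whether an answer is correct; the check below is hypothetical and simply requires every expected name to appear (the names are the ones used in the explanation prompt's example). Runs that raised `ProgrammingError` are stored as `None` and count as failures.

```python
from typing import Callable, Iterable, Optional


def success_rate(results: Iterable[Optional[str]], is_correct: Callable[[str], bool]) -> float:
    """Fraction of runs whose answer is not None and passes `is_correct`."""
    results = list(results)
    if not results:
        return 0.0
    successes = sum(1 for r in results if r is not None and is_correct(r))
    return successes / len(results)


# Hypothetical correctness check: every expected engineer must be mentioned.
expected_names = ["Peter", "Max", "Fidel"]


def mentions_all_engineers(answer: str) -> bool:
    return all(name in answer for name in expected_names)


# Usage with the `results` list collected above:
# print(f"success: {success_rate(results, mentions_all_engineers):.0%}")
```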

### Result explanation

In [None]:
import pandas as pd

def explain_result(question, result):
    """Ask the model to phrase a raw SQL result as a natural-language answer (one-shot prompt)."""
    # The first Query/Response/Explanation triple is a fixed example; the second
    # is filled with the caller's question and the raw database result.
    instruction = f"""
    I am building a text2sql project. Please formulate an answer to my question in natural language in a human readable format.

    Query: 
    List all the software engineers. 
    Response: 
    [('Peter', 'Kabel', 'Software Engineer'), ('Max', 'Mustermann', 'Software Engineer'), ('Fidel', 'Wind', 'Software Engineer')]
    Explanation:
    The Software Engineers are Peter Kabel, Max Mustermann and Fidel Wind.

    Query:
    {question}
    Response: 
    {result}
    Explanation:
    """
    response = ai21.Completion.execute(sm_endpoint=endpoint_name,
                                    prompt=instruction,
                                    maxTokens=80,
                                    temperature=0,
                                    numResults=1)

    return response['completions'][0]['data']['text']
explain_result("How many software engineers does the company have?", 3)

### SQL Database Chain (using this right now)
Performs the best of all the options we tried. It currently uses zero-shot prompting and works well for the basic cases.

In [None]:
from langchain_experimental.sql import SQLDatabaseChain


db_chain = SQLDatabaseChain.from_llm(llm_ai21, db, verbose=True, return_intermediate_steps=True)
result = db_chain("List all the software engineers.")
#pd.DataFrame(result)
print(result["intermediate_steps"][1])  # the SQL query generated by the model
print(result)

### Integrate few-shot prompting

In [None]:
from langchain.prompts import PromptTemplate


TEMPLATE = """Given an input question, create a syntactically correct {dialect} query to run.
Use the following format:

Question: "Question here"
SQL Query:
"SQL Query to run"

Only use the following tables:

{table_info}.

Some examples of SQL queries that correspond to questions are:

{few_shot_examples}

Question: {input}"""




CUSTOM_PROMPT = PromptTemplate(
    input_variables=["input", "few_shot_examples", "table_info", "dialect"], template=TEMPLATE
)

FEW_SHOT_EXAMPLES = """

Question: Who worked the most hours in 2022?
SQL Query:
SELECT e.first_name, e.last_name, SUM(t.entered_hours) AS total_hours_worked
FROM employees e
JOIN timelog t ON e.employee_id = t.employee
WHERE EXTRACT(YEAR FROM t.working_day) = 2022
GROUP BY e.employee_id, e.first_name, e.last_name
ORDER BY total_hours_worked DESC
LIMIT 1;

##

Question: How many Software Engineers does the company have?
SQL Query:
SELECT COUNT(*) from employees
WHERE designation='Software Engineer';

##

Question: How many hours did Velma work in July 2022?
SQL Query:
SELECT SUM(t.entered_hours) AS total_hours_worked
FROM employees e
JOIN timelog t ON e.employee_id = t.employee
WHERE e.first_name = 'Velma'
  AND EXTRACT(YEAR FROM t.working_day) = 2022
  AND EXTRACT(MONTH FROM t.working_day) = 7;

##

Question: Who is working on the Music generator project?
SQL Query:
SELECT * FROM employees
WHERE project_id=(
SELECT project_id FROM projects
WHERE project_name = 'Music generator'
);

##

Question: Who works under Max?
SQL Query:
SELECT * FROM employees
WHERE manager_id=(
SELECT employee_id FROM employees
WHERE first_name = 'Max');

##

Question: Who worked the least hours in April 2022?
SQL Query:
SELECT e.first_name, e.last_name, SUM(t.entered_hours) AS total_hours_worked
FROM employees e
JOIN timelog t ON e.employee_id = t.employee
WHERE EXTRACT(YEAR FROM t.working_day) = 2022
  AND EXTRACT(MONTH FROM t.working_day) = 4
GROUP BY e.employee_id, e.first_name, e.last_name
ORDER BY total_hours_worked
LIMIT 1;

##

"""



In [None]:
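from langchain.chains import create_sql_query_chain


question = "List all software engineers"
# Fill the few-shot template ourselves; create_sql_query_chain then wraps this
# text in its own default prompt as the question.
prompt = CUSTOM_PROMPT.format(
    input=question,
    table_info=db.table_info,
    dialect="PostgreSQL",
    few_shot_examples=FEW_SHOT_EXAMPLES
)

chain = create_sql_query_chain(llm_ai21, db)
response = chain.invoke({"question": prompt})
print("raw response:")
print(response)
# The model tends to continue with further '##'-separated examples; keep only the first query.
response = response.split("##")[0]
print("SQL query:")
print(response)

print("query results:")
db.run(response)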
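
The model tends to keep generating further `##`-separated examples after the query, which is why the cell above cuts the completion at the first `##`. A slightly more defensive sketch of that post-processing also drops a leading `SQL Query:` or `SQLQuery:` label and surrounding whitespace before the string reaches `db.run`. The helper name and the label handling are our own assumptions about typical output, not LangChain behaviour.

```python
import re


def extract_sql(completion: str, separator: str = "##") -> str:
    """Return the first SQL statement from a raw model completion."""
    # Keep only the text before the next few-shot separator the model may emit.
    sql = completion.split(separator, 1)[0]
    # Drop an echoed label such as 'SQL Query:' or 'SQLQuery:' at the start.
    sql = re.sub(r"^\s*SQL\s?Query:\s*", "", sql, flags=re.IGNORECASE)
    return sql.strip()
```

Usage: `db.run(extract_sql(response))`.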

### SQL Agents
SQL agents don't perform well with our model; a more capable model might do better.

In [None]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType

toolkit = SQLDatabaseToolkit(db=db, llm=llm_ai21)
agent_executor = create_sql_agent(
    llm=llm_ai21,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)

agent_executor.run("Describe the employees table")

### Miscellaneous

In [None]:
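# Re-create the company-name template: `prompt` has been overwritten by later cells.
company_prompt = PromptTemplate.from_template("What is a good name for a company that makes {product}?")
response = ai21.Completion.execute(sm_endpoint=endpoint_name,
                                   prompt=company_prompt.format(product="colorful socks"),
                                   maxTokens=100,
                                   temperature=0,
                                   numResults=1)
print(response['completions'][0]['data']['text'])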

In [None]:
promttext = """ 
Create SQL statement from instruction.

Database: 
Employees: employees(employee_id, first_name, last_name, designation, project_id, email, manager_id)
Projects: projects(project_id, project_name, customer)
Timelog: timelog(entry_id, employee, working_day, entered_hours)

Request: Find Peter's email address.
SQL statement:
SELECT email FROM employees WHERE first_name='Peter';

##

Create SQL statement from instruction.

Database: 
Employees: employees(employee_id, first_name, last_name, designation, project_id, email, manager_id)
Projects: projects(project_id, project_name, customer)
Timelog: timelog(entry_id, employee, working_day, entered_hours)

Request: How many Software Engineers does the company have?
SQL statement:
SELECT COUNT(*) from employees
WHERE designation='Software Engineer';

##

Create SQL statement from instruction.

Database: 
Employees: employees(employee_id, first_name, last_name, designation, project_id, email, manager_id)
Projects: projects(project_id, project_name, customer)
Timelog: timelog(entry_id, employee, working_day, entered_hours)

Request: How many hours did Velma work in 2022?
SQL statement:
SELECT SUM(entered_hours) from timelog 
WHERE employee=(
SELECT employee_id FROM employees
WHERE first_name = 'Velma'
)
AND EXTRACT(YEAR FROM working_day) = 2022;

##

Create SQL statement from instruction.

Database: 
Employees: employees(employee_id, first_name, last_name, designation, project_id, email, manager_id)
Projects: projects(project_id, project_name, customer)
Timelog: timelog(entry_id, employee, working_day, entered_hours)

Request: {query}
SQL statement:
"""

query = "What is Velma's manager's employee ID?"


prompt = PromptTemplate.from_template(promttext)
prompt.format(query=query)
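
Every example in `promttext` repeats the same instruction and schema block, so adding an example means copying it by hand. A sketch that assembles the same layout from a schema string and a list of (request, SQL) pairs; the function name and argument structure are hypothetical. The final block leaves a literal `{query}` placeholder for `PromptTemplate.from_template`, so the example SQL itself must not contain curly braces.

```python
from typing import List, Tuple

SCHEMA = """Employees: employees(employee_id, first_name, last_name, designation, project_id, email, manager_id)
Projects: projects(project_id, project_name, customer)
Timelog: timelog(entry_id, employee, working_day, entered_hours)"""


def build_few_shot_prompt(schema: str, examples: List[Tuple[str, str]]) -> str:
    """Repeat the instruction and schema header for each example, separated by '##'."""
    header = "Create SQL statement from instruction.\n\nDatabase:\n" + schema + "\n\n"
    blocks = [f"{header}Request: {request}\nSQL statement:\n{sql}\n" for request, sql in examples]
    # The last block is left open with a {query} variable for the new request.
    blocks.append(header + "Request: {query}\nSQL statement:\n")
    return "\n##\n\n".join(blocks)
```

Usage: `prompt = PromptTemplate.from_template(build_few_shot_prompt(SCHEMA, examples))`, where `examples` holds the request/SQL pairs shown above.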

In [None]:
response = ai21.Completion.execute(sm_endpoint=endpoint_name,
                                   prompt=prompt.format(query=query),
                                   maxTokens=80,
                                   temperature=0,
                                   numResults=1)
print(response['completions'][0]['data']['text'])

## Teardown

In [None]:
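# Delete the endpoint to stop incurring instance charges. The endpoint configuration
# is kept, so the endpoint can be recreated later with the recreate cell above.
sagemaker_client = boto3.client('sagemaker')
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)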
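
Deleting only the endpoint leaves the endpoint configuration and the model registered, which is what allows the endpoint to be recreated later. When you are done for good, a fuller cleanup sketch looks up the model names in the endpoint configuration and deletes all three resources, using the boto3 SageMaker calls `describe_endpoint_config`, `delete_endpoint`, `delete_endpoint_config` and `delete_model`. The helper name is hypothetical; run it instead of the cell above, since deleting an endpoint that no longer exists raises an error.

```python
from typing import Any, List


def delete_all(client: Any, endpoint_name: str, endpoint_config_name: str) -> List[str]:
    """Delete the endpoint, its configuration and the models it references.

    Returns the names of the deleted models.
    """
    config = client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    model_names = [variant["ModelName"] for variant in config["ProductionVariants"]]
    client.delete_endpoint(EndpointName=endpoint_name)
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    for name in model_names:
        client.delete_model(ModelName=name)
    return model_names
```

Usage: `delete_all(boto3.client('sagemaker'), endpoint_name, endpoint_config_name)`.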


# Connecting to an RDS instance