In [None]:
import os
from urllib.parse import quote

import dotenv
from langchain import FewShotPromptTemplate
from langchain.chains import create_sql_query_chain
from langchain.prompts.example_selector import LengthBasedExampleSelector
from langchain.prompts.prompt import PromptTemplate
from langchain_community.llms import OpenAI
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from snowflake.snowpark import Session
# Load key/value pairs from a .env file (when one is found) into the process
# environment; the Snowflake settings in the next cell are read from os.environ.
dotenv.load_dotenv()

In [None]:
# Snowflake connection settings, all mandatory: a missing variable raises
# KeyError here rather than surfacing later as a malformed connection URL.
account_identifier, user, password, database, schema, warehouse, role = map(
    os.environ.__getitem__,
    (
        'SF_ACCOUNT_IDENTIFIER',
        'SF_USER',
        'SF_PASSWORD',
        'SF_DATABASE',
        'SF_SCHEMA',
        'SF_WAREHOUSE',
        'SF_ROLE',
    ),
)

In [None]:
# Few-shot examples shown to the LLM. Every answer is a SQL query because the
# chain is configured to return SQL, and every query follows the rules stated
# in the prompt prefix: column names in double quotes, table name in upper-case.
# Reminder on quoting: in SQL, double quotes denote an identifier (a column),
# single quotes denote a string literal.
examples = [
    {
        "question": "what is the total number of employees?",
        # A query rather than a hard-coded sentence, so the example matches
        # the answer format of the other examples and cannot go stale.
        "answer": """SELECT
                        COUNT(*) AS "Total Employees"
                    FROM
                        ATTRITION;"""
    }, {
        "question": "How does the Attrition rate differ with EducationField for employees in each department?",
        # "Attrition" must be double-quoted: with single quotes it is the
        # string literal 'Attrition', which never equals 'Yes', so the
        # computed rate would always be 0.
        "answer": """SELECT
                        "Department",
                        "EducationField",
                        AVG(
                            CASE WHEN "Attrition" = 'Yes' THEN 1 ELSE 0 END
                        ) AS "Avg Attrition Rate"
                    FROM
                        ATTRITION
                    GROUP BY
                        "Department",
                        "EducationField"
                    ORDER BY
                        "Avg Attrition Rate" DESC
                    LIMIT
                        5;"""
    }, {
        "question": "Find the count of employees in each department who have a DailyRate less than the department's average DailyRate.",
        # Joins each employee row to its department's average DailyRate and
        # counts the rows that fall below it.
        "answer": """SELECT
                        a."Department",
                        COUNT(*) AS "Count of Employees"
                    FROM
                        ATTRITION a
                    JOIN (
                        SELECT
                            "Department",
                            AVG("DailyRate") AS "AvgDailyRate"
                        FROM
                            ATTRITION
                        GROUP BY
                            "Department"
                    ) b ON a."Department" = b."Department"
                    WHERE
                        a."DailyRate" < b."AvgDailyRate"
                    GROUP BY
                        a."Department"
                    ORDER BY
                        "Count of Employees" DESC;"""
    }
]

# Layout of a single worked example inside the few-shot prompt.
example_template = "\nQuestion: {question}\nAnswer: {answer}\n"

# Prompt object that renders one example dict through the layout above,
# filling the placeholders from its "question" and "answer" keys.
example_prompt = PromptTemplate(
    template=example_template,
    input_variables=["question", "answer"],
)

# The prompt is assembled from two fixed parts placed around the examples.
# `prefix` carries the instructions shown before the examples.
prefix = (
    "You are a database expert, answer the following question in a descriptive manner.\n"
    "Keep column names in double-quotes and table name in upper-case, while generating queries.\n"
    "Here are some examples: \n"
)
# `suffix` carries the user's question and the cue after which the model
# is expected to write its answer.
suffix = "\nQuestion: {question}\nAnswer: "

# Assemble the final few-shot prompt: instructions first, then every example
# rendered through `example_prompt`, then the slot for the user's question.
few_shot_prompt_template = FewShotPromptTemplate(
    prefix=prefix,
    examples=examples,
    example_prompt=example_prompt,
    suffix=suffix,
    # The only placeholder left to fill at format time.
    input_variables=["question"],
)

In [None]:
# Sanity check: render the complete prompt for a sample question so the text
# that will be sent to the model can be inspected by eye.
print(
    few_shot_prompt_template.format(
        question="What is the total number of employees ?",
    )
)

In [None]:
# Build the SQLAlchemy connection URL for Snowflake.
# The user name and password are percent-encoded before being embedded:
# characters such as "@", ":", "/" or "%" would otherwise be parsed as URL
# delimiters and corrupt the connection string. safe="" makes sure "/" is
# encoded as well. The environment must therefore hold the raw, un-encoded
# credentials.
snowflake_url = (
    f"snowflake://{quote(user, safe='')}:{quote(password, safe='')}"
    f"@{account_identifier}/{database}/{schema}"
    f"?warehouse={warehouse}&role={role}"
)

# Only the `attrition` table is exposed to the LLM, with a single sample row
# included in the table description to keep the prompt small.
db = SQLDatabase.from_uri(snowflake_url, sample_rows_in_table_info=1, include_tables=['attrition'])

# we can see what information is passed to the LLM regarding the database
# print(db.table_info)

In [None]:
# temperature=0 is passed so that sampling randomness is minimised.
llm = OpenAI(temperature=0)

# return_sql=True is set so the chain hands back the generated SQL rather than
# running it and summarising the rows; verbose=True echoes intermediate steps.
db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    verbose=True,
    return_sql=True,
)

# Smoke test with a single question before processing whole question files.
question = "Find the count of employees in each department who have a DailyRate less than the department's average DailyRate."

result = db_chain.run(
    few_shot_prompt_template.format(question=question)
)

In [None]:
# Show what the chain returned for the sample question above
# (presumably the generated SQL text, since the chain was built with
# return_sql=True — confirm against the LangChain version in use).
print(result)

In [None]:
# Generate an answer for every question in the test set.
# q_a collects (question, answer) pairs in file order.
q_a = []

# Read all questions up front so the file handle is not held open while the
# slow LLM calls run. Surrounding whitespace (including the trailing newline
# of each line) is stripped so it does not leak into the prompt, and blank
# lines are skipped instead of being sent as empty questions.
with open('questions_test.txt', 'r') as f:
    questions = [line.strip() for line in f if line.strip()]

for question in questions:
    # Placeholder that survives only when generation fails; the next cell
    # lists the questions still carrying it.
    result = 'Not calculated'
    try:
        result = db_chain.run(few_shot_prompt_template.format(question=question))
    except Exception as exc:
        # Best-effort batch: one failing question must not discard the
        # answers already produced, but the failure is reported, not hidden.
        print(f"Failed to generate a query for {question!r}: {exc}")
    q_a.append((question, result))

In [None]:
# Print each question whose answer is still the 'Not calculated' placeholder.
for asked, answer in q_a:
    if answer != 'Not calculated':
        continue
    print(asked)

In [None]:
# Interleave every answer with a newline so each entry is newline-terminated
# in the output file; the questions themselves are not written.
lines = [piece for _, answer in q_a for piece in (answer, '\n')]

# Write (or overwrite) the results file for the test set.
with open('question_test_queries', 'w') as file:
    file.writelines(lines)

In [None]:
# Generate an answer for every question in the validation set.
# q_a is reset here and collects (question, answer) pairs in file order.
q_a = []

# Read all questions up front so the file handle is not held open while the
# slow LLM calls run. Surrounding whitespace (including the trailing newline
# of each line) is stripped so it does not leak into the prompt, and blank
# lines are skipped instead of being sent as empty questions.
with open('questions_valid.txt', 'r') as f:
    questions = [line.strip() for line in f if line.strip()]

for question in questions:
    # Placeholder that survives only when generation fails; the next cell
    # lists the questions still carrying it.
    result = 'Not calculated'
    try:
        result = db_chain.run(few_shot_prompt_template.format(question=question))
    except Exception as exc:
        # Best-effort batch: keep going after a failure, but report which
        # question failed and why instead of swallowing the error silently.
        print(f"Failed to generate a query for {question!r}: {exc}")
    q_a.append((question, result))

In [None]:
# Print each question whose answer is still the 'Not calculated' placeholder.
for asked, answer in q_a:
    if answer != 'Not calculated':
        continue
    print(asked)

In [None]:
# Interleave every answer with a newline so each entry is newline-terminated
# in the output file; the questions themselves are not written.
lines = [piece for _, answer in q_a for piece in (answer, '\n')]

# Write (or overwrite) the results file for the validation set.
with open('question_valid_queries', 'w') as file:
    file.writelines(lines)