In [52]:
!pip install snowflake-connector-python



In [53]:
!pip install langchain_core



In [54]:
!pip install langchain_groq



In [None]:
!pip install configparser

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
import snowflake.connector
import configparser

In [56]:
def connect_to_snowflake(user, password, account, warehouse, database, schema, role):
    """Open a Snowflake connection for the given credentials and session context.

    Args:
        user: Snowflake login name.
        password: password for ``user``.
        account: Snowflake account identifier.
        warehouse: warehouse the session should use.
        database: default database for the session.
        schema: default schema for the session.
        role: role the session should assume.

    Returns:
        The connection object produced by ``snowflake.connector.connect``.
        The caller owns it and must close it.
    """
    connection_kwargs = {
        "user": user,
        "password": password,
        "account": account,
        "warehouse": warehouse,
        "database": database,
        "schema": schema,
        "role": role,
    }
    return snowflake.connector.connect(**connection_kwargs)

In [57]:
# Function to fetch table and column information from the database schema
def get_db_schema(conn, table_schema, table_name):
    """Return the column names of ``table_schema.table_name`` in table order.

    The lookup goes through ``INFORMATION_SCHEMA.COLUMNS`` of the connection's
    current database. Names are compared exactly against the values stored
    there (Snowflake stores unquoted identifiers upper-cased).

    Args:
        conn: open DB-API connection; the query uses ``%s`` placeholders
            (Snowflake connector's default ``pyformat`` paramstyle).
        table_schema: schema that contains the table.
        table_name: table whose columns are wanted.

    Returns:
        list: column names ordered by ``ORDINAL_POSITION``; empty if nothing matches.
    """
    # Bind the identifiers as values instead of formatting them into the SQL:
    # avoids injection and breakage on names containing quotes.
    # ORDER BY makes the column order deterministic (and match the table layout).
    query = (
        "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS "
        "WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s "
        "ORDER BY ORDINAL_POSITION"
    )
    cursor = conn.cursor()
    try:
        cursor.execute(query, (table_schema, table_name))
        return [row[0] for row in cursor.fetchall()]
    finally:
        # Release the cursor even if execute/fetch raised.
        cursor.close()

In [58]:
# # Function to create SQL query generation chain
# def create_sql_chain(conn,table_schema, target_table, question, llm_model, api_key):
#     schema_info = get_db_schema(conn,table_schema, target_table)

#     template = f"""
#     You are provided with the schema of the table '{target_table}' in the database:
#     Schema: {schema_info}

#     Write a SQL query to answer the following question using the exact column names from the schema:

#     Question: {question}

#     SQL Query:
#     """
#     prompt = ChatPromptTemplate.from_template(template=template)
#     llm = ChatGroq(model=f"{llm_model}", temperature=0.2, groq_api_key=f"{api_key}")

#     return (
#         RunnablePassthrough(assignments={"schema": schema_info, "question": question})
#         | prompt
#         | llm
#         | StrOutputParser()
#     )

In [59]:
# user_query = What is the cost for customer ID 10000005 on their 5th shopping point?

In [60]:
def execute_query(conn, query: str, params=None):
    """Execute ``query`` on ``conn`` and return all result rows.

    Args:
        conn: open DB-API connection.
        query: SQL statement to run.
        params: optional bind parameters forwarded to ``cursor.execute``;
            when ``None`` the statement is executed without parameters.

    Returns:
        The rows returned by ``cursor.fetchall()``.

    The cursor is closed even when execution or fetching raises.
    """
    cursor = conn.cursor()
    try:
        if params is None:
            cursor.execute(query)
        else:
            cursor.execute(query, params)
        return cursor.fetchall()
    finally:
        cursor.close()

In [61]:
# Function to create SQL query generation chain
def create_sql_chain(conn, schema, target_table, question, llm_model, api_key):
    """Build a runnable that asks the LLM to write SQL answering ``question``.

    Args:
        conn: open Snowflake connection, used to look up the table's columns.
        schema: database schema (``TABLE_SCHEMA``) that contains ``target_table``.
        target_table: table the question is about.
        question: the user's natural-language question.
        llm_model: Groq model name passed to ``ChatGroq``.
        api_key: Groq API key passed to ``ChatGroq``.

    Returns:
        A runnable; ``chain.invoke({})`` returns the model's reply as a string.
        The prompt asks for a bare SQL query, but the reply is not validated here.
    """
    schema_info = get_db_schema(conn, schema, target_table)

    # Plain (non-f) string: the values are bound as prompt variables below, so
    # braces in the question or in column names cannot be misparsed as
    # template placeholders by ChatPromptTemplate.
    template = """
        Based on the table schema of table '{target_table}', write a SQL query to answer the question.
        Only provide the SQL query, without any additional text or characters. Use the exact column names from the schema.

        Table schema: {schema}
        Question: {question}

        SQL Query:
    """
    prompt = ChatPromptTemplate.from_template(template=template)
    llm = ChatGroq(model=llm_model, temperature=0.2, groq_api_key=api_key)

    # RunnablePassthrough.assign merges these keys into the input dict, so the
    # chain still works when invoked with an empty dict: chain.invoke({}).
    return (
        RunnablePassthrough.assign(
            target_table=lambda _: target_table,
            schema=lambda _: schema_info,
            question=lambda _: question,
        )
        | prompt
        | llm
        | StrOutputParser()
    )

In [62]:
# def create_sql_query(conn, table_schema, target_table, question, llm_model, api_key):
#     schema_info = get_db_schema(conn, table_schema, target_table)

#     template = f"""
#         Based on the table schema of table '{target_table}', write a SQL query to answer the question.
#         Only provide the SQL query, without any additional text or characters. Use the exact column names from the schema.

#         Table schema: {schema_info}
#         Question: {question}

#         SQL Query:
#     """

#     prompt = ChatPromptTemplate.from_template(template=template)
#     llm = ChatGroq(model=llm_model, temperature=0.2, groq_api_key=api_key)

#     # Generate SQL query
#     sql_query_with_description = (
#         RunnablePassthrough(assignments={"schema": schema_info, "question": question})
#         | prompt
#         | llm
#         | StrOutputParser()
#     ).invoke({})

#     # Extract SQL query from the generated output
#     sql_query_lines = sql_query_with_description.split("\n")
#     extracted_query = "\n".join(line.strip() for line in sql_query_lines if line.strip().startswith("SELECT"))

#     return extracted_query.strip()

In [63]:
def _resolve_groq_api_key(api_key=None, config_path="config.ini"):
    """Return a Groq API key, never from source code.

    Lookup order: the explicit ``api_key`` argument, then the ``GROQ_API_KEY``
    environment variable, then ``groq_api_key`` in the ``[api]`` section of
    ``config_path``.

    Raises:
        ValueError: if none of those sources provides a key.
    """
    import os

    if api_key:
        return api_key
    env_key = os.environ.get("GROQ_API_KEY")
    if env_key:
        return env_key
    config = configparser.ConfigParser()
    config.read(config_path)
    if config.has_option("api", "groq_api_key"):
        return config.get("api", "groq_api_key")
    raise ValueError(
        "No Groq API key found: pass api_key, set GROQ_API_KEY, "
        "or add groq_api_key under [api] in " + config_path
    )


# Function to create natural language response based on SQL query results
def create_nlp_answer(conn, sql_query, results, llm_model="llama3-8b-8192", api_key=None):
    """Build a runnable that has the LLM describe SQL results in plain language.

    Args:
        conn: not used by this function; kept so existing callers keep working.
        sql_query: the SQL text that produced ``results``.
        results: iterable of result rows; each row is rendered with ``str``.
        llm_model: Groq model name (default is the model previously hard-coded here).
        api_key: Groq API key; when omitted, ``_resolve_groq_api_key`` looks it
            up in the environment / config.ini.

    Returns:
        A runnable; ``chain.invoke({})`` returns the model's answer as a string.

    Raises:
        ValueError: if no Groq API key can be found.
    """
    # SECURITY: a Groq API key used to be hard-coded in this cell; it remains in
    # notebook history and should be revoked/rotated.
    results_str = "\n".join(str(row) for row in results)

    # Plain (non-f) string: SQL text and row values are bound as variables, so
    # braces in them (e.g. JSON columns) cannot break template parsing.
    template = """

        Based on the results of the SQL query '{sql_query}', write a natural language response.

        Query Results:

        {results}

    """
    prompt = ChatPromptTemplate.from_template(template=template)
    llm = ChatGroq(
        model=llm_model,
        temperature=0.2,
        groq_api_key=_resolve_groq_api_key(api_key),
    )

    # Fill the prompt variables on top of the (possibly empty) input dict.
    return (
        RunnablePassthrough.assign(
            sql_query=lambda _: sql_query,
            results=lambda _: results_str,
        )
        | prompt
        | llm
        | StrOutputParser()
    )

In [66]:
def _strip_sql_fences(text):
    """Return ``text`` stripped, unwrapping a surrounding Markdown code fence if present.

    The model may wrap its answer in a fence (an opening line such as ```sql and a
    closing ```), which would otherwise be sent to the database as invalid SQL.
    """
    import re

    cleaned = text.strip()
    fenced = re.fullmatch(r"```[A-Za-z]*\s*(.*?)\s*```", cleaned, flags=re.DOTALL)
    return fenced.group(1).strip() if fenced else cleaned


def _looks_read_only(sql):
    """Heuristic: True if ``sql`` starts with SELECT or WITH (case-insensitive).

    Not a security boundary by itself; the Snowflake role in config.ini should
    also be restricted to read access.
    """
    import re

    return re.match(r"\s*\(?\s*(SELECT|WITH)\b", sql, flags=re.IGNORECASE) is not None


def main():
    """Answer one natural-language question about the configured Snowflake table.

    Reads ``config.ini`` (sections ``[snowflake]`` and ``[api]``) and connects to
    Snowflake. It asks the user a question via ``input()`` and has the LLM write
    a SQL query, which is run only if it looks read-only. It then prints the raw
    rows followed by an LLM-written summary. The connection is always closed.
    """
    config = configparser.ConfigParser()
    config.read('config.ini')

    # Snowflake credentials, session context and the table to query
    snowflake_config = config['snowflake']
    user = snowflake_config['user']
    password = snowflake_config['password']
    account = snowflake_config['account']
    warehouse = snowflake_config['warehouse']
    database = snowflake_config['database']
    schema = snowflake_config['schema']
    role = snowflake_config['role']
    target_table = snowflake_config['target_table']

    # LLM settings; llm_model is optional in config and defaults to the original choice
    api_config = config['api']
    llm_model = api_config.get('llm_model', fallback='gemma2-9b-it')
    groq_api_key = api_config['groq_api_key']

    conn = connect_to_snowflake(user, password, account, warehouse, database, schema, role)
    print("Connected to the database successfully!")
    try:
        user_query = input("Ask your database a question about " + target_table + ": ")

        sql_chain = create_sql_chain(conn, schema, target_table, user_query, llm_model, groq_api_key)
        sql_query = _strip_sql_fences(sql_chain.invoke({}))
        print(f"Generated SQL Query:\n{sql_query}")

        # The SQL is written by an LLM from free-form user input: refuse to run
        # anything that is not obviously a query.
        if not _looks_read_only(sql_query):
            print("Refusing to run generated SQL: only SELECT/WITH queries are allowed.")
            return

        try:
            results = execute_query(conn, sql_query)
        except snowflake.connector.errors.ProgrammingError as exc:
            # e.g. generated SQL that references unknown columns or fails to compile
            print(f"Query failed: {exc}")
            return

        if not results:
            print("No results found.")
            return

        for row in results:
            print(row)

        # Generate natural language response
        nlp_chain = create_nlp_answer(conn, sql_query, results)
        nlp_response = nlp_chain.invoke({})
        print(f"Natural Language Response:\n{nlp_response}")
    finally:
        conn.close()


if __name__ == "__main__":
    main()

Connected to the database successfully!
Ask your database a question about INSURANCETABLE: What is the cost for customer ID 10000005 on their 5th shopping point?
Generated SQL Query:
SELECT COST 
FROM INSURANCETABLE 
WHERE CUSTOMER_ID = 10000005 AND SHOPPING_PT = 5;
(731,)
Natural Language Response:
According to the insurance table, the cost associated with customer ID 10000005 and shopping point 5 is $731.
