# Importing supporting libraries

In [1]:
%%time
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from sqlalchemy import create_engine, text
import os
from dotenv import load_dotenv

import pandas as pd
import pprint

CPU times: total: 3.61 s
Wall time: 3.68 s


# Connecting Local SQL Server

In [5]:
engine = create_engine(
        "mssql+pyodbc://KARTHIK-ASUS\\SQLEXPRESS/AdventureWorks2022?"
        "driver=ODBC+Driver+17+for+SQL+Server"
        "&autocommit=true"
        "&trusted_connection=yes"
        "&echo=False"
        "&fast_executemany=True"
    )

# Check if the connection is successful
if engine.connect():
    print("Connection successful!")
else:
    print("Connection failed.")

Connection successful!


## Get database schema

In [16]:
schema_query = """ SELECT 
    s.name AS schema_name,
    t.name AS table_name,
    c.name AS column_name,
    ty.name AS data_type
FROM sys.tables t
INNER JOIN sys.columns c ON t.object_id = c.object_id
INNER JOIN sys.types ty ON c.user_type_id = ty.user_type_id
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
WHERE t.name IN ('SalesOrderHeader', 'Customer')
ORDER BY s.name, t.name, c.column_id; """

with engine.connect() as conn:
    result = conn.execute(text(schema_query))  # Execute query
    schema_info = result.fetchall()  # Fetch all rows

print(schema_info)

[('Sales', 'Customer', 'CustomerID', 'int'), ('Sales', 'Customer', 'PersonID', 'int'), ('Sales', 'Customer', 'StoreID', 'int'), ('Sales', 'Customer', 'TerritoryID', 'int'), ('Sales', 'Customer', 'AccountNumber', 'varchar'), ('Sales', 'Customer', 'rowguid', 'uniqueidentifier'), ('Sales', 'Customer', 'ModifiedDate', 'datetime'), ('Sales', 'SalesOrderHeader', 'SalesOrderID', 'int'), ('Sales', 'SalesOrderHeader', 'RevisionNumber', 'tinyint'), ('Sales', 'SalesOrderHeader', 'OrderDate', 'datetime'), ('Sales', 'SalesOrderHeader', 'DueDate', 'datetime'), ('Sales', 'SalesOrderHeader', 'ShipDate', 'datetime'), ('Sales', 'SalesOrderHeader', 'Status', 'tinyint'), ('Sales', 'SalesOrderHeader', 'OnlineOrderFlag', 'Flag'), ('Sales', 'SalesOrderHeader', 'SalesOrderNumber', 'nvarchar'), ('Sales', 'SalesOrderHeader', 'PurchaseOrderNumber', 'OrderNumber'), ('Sales', 'SalesOrderHeader', 'AccountNumber', 'AccountNumber'), ('Sales', 'SalesOrderHeader', 'CustomerID', 'int'), ('Sales', 'SalesOrderHeader', 'Sa

In [None]:
def get_schema(table_list: list):
       
    # Query to get table and column information
    table_name = ', '.join(f"'{name}'" for name in table_list)

    schema_query = f"""
        SELECT 
            s.name AS schema_name,
            t.name AS table_name,
            c.name AS column_name,
            ty.name AS data_type
        FROM sys.tables t
        INNER JOIN sys.columns c ON t.object_id = c.object_id
        INNER JOIN sys.types ty ON c.user_type_id = ty.user_type_id
        INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
        where t.name in ({table_name})
        ORDER BY s.name, t.name, c.column_id;
    """
    
    with engine.connect() as conn:
        result = conn.execute(text(schema_query))  # Execute query
        schema_info = result.fetchall()  # Fetch all rows
    
    # Format schema information
    schema_text = "Database Schema:\n"
    current_table = ""
    current_schema = ""
    
    for schema_name, table, column, data_type in schema_info:
        if table != current_table:
            schema_text += f"\nSchema: {schema_name}"
            schema_text += f"\nTable: {table}\n"
            current_table = table
            schema_name = current_schema
        schema_text += f"- {column} ({data_type})\n"
    
    conn.close()
    
    return schema_text

schema_text = get_schema(table_list = ['SalesOrderHeader','Customer'])
print(schema_text)

Database Schema:

Schema: Sales
Table: Customer
- CustomerID (int)
- PersonID (int)
- StoreID (int)
- TerritoryID (int)
- AccountNumber (varchar)
- rowguid (uniqueidentifier)
- ModifiedDate (datetime)

Schema: Sales
Table: SalesOrderHeader
- SalesOrderID (int)
- RevisionNumber (tinyint)
- OrderDate (datetime)
- DueDate (datetime)
- ShipDate (datetime)
- Status (tinyint)
- OnlineOrderFlag (Flag)
- SalesOrderNumber (nvarchar)
- PurchaseOrderNumber (OrderNumber)
- AccountNumber (AccountNumber)
- CustomerID (int)
- SalesPersonID (int)
- TerritoryID (int)
- BillToAddressID (int)
- ShipToAddressID (int)
- ShipMethodID (int)
- CreditCardID (int)
- CreditCardApprovalCode (varchar)
- CurrencyRateID (int)
- SubTotal (money)
- TaxAmt (money)
- Freight (money)
- TotalDue (money)
- Comment (nvarchar)
- rowguid (uniqueidentifier)
- ModifiedDate (datetime)



# Connecting with GROQ API

In [3]:
llm = ChatGroq(
    groq_api_key = os.getenv('groq_api_key'),
    model_name="mixtral-8x7b-32768" , # You can also use "llama2-70b-4096"
)

# Create prompt template

In [32]:
user_query = """How many customers have ordered between '2013-06-30' to '2014-06-30'? 
Give the results as 'count_of_customers' """

prompt = PromptTemplate.from_template(f"""
You are a SQL expert. Based on the following database schema and natural language query,
generate a SQL query that answers the question.

{schema_text}

Natural Language Query: {user_query}

Generate only the SQL query without any explanation or additional text.
The query should be compatible with Microsoft SQL Server. 
While generating the sql query give the table name with the schema name 
and give proper alias name in the final result.
""")

print(prompt.template)



You are a SQL expert. Based on the following database schema and natural language query,
generate a SQL query that answers the question.

Database Schema:

Schema: Sales
Table: Customer
- CustomerID (int)
- PersonID (int)
- StoreID (int)
- TerritoryID (int)
- AccountNumber (varchar)
- rowguid (uniqueidentifier)
- ModifiedDate (datetime)

Schema: Sales
Table: SalesOrderHeader
- SalesOrderID (int)
- RevisionNumber (tinyint)
- OrderDate (datetime)
- DueDate (datetime)
- ShipDate (datetime)
- Status (tinyint)
- OnlineOrderFlag (Flag)
- SalesOrderNumber (nvarchar)
- PurchaseOrderNumber (OrderNumber)
- AccountNumber (AccountNumber)
- CustomerID (int)
- SalesPersonID (int)
- TerritoryID (int)
- BillToAddressID (int)
- ShipToAddressID (int)
- ShipMethodID (int)
- CreditCardID (int)
- CreditCardApprovalCode (varchar)
- CurrencyRateID (int)
- SubTotal (money)
- TaxAmt (money)
- Freight (money)
- TotalDue (money)
- Comment (nvarchar)
- rowguid (uniqueidentifier)
- ModifiedDate (datetime)


Natur

In [27]:
def process_natural_language_query(nl_query):
    try:
        # Get database schema
        schema = get_schema(table_list = ['Employee'])
        
        # Generate SQL query using the new chain syntax
        sql_query = chain.invoke({
            "schema": schema,
            "query": nl_query
        })
        
        print("Generated SQL Query:")
        print(sql_query)
        
        # Execute the query
        results = pd.read_sql(sql_query,con=engine)
        
        return {
            "sql_query": sql_query,
            "results": results
        }
        
    except Exception as e:
        import traceback
        print(f"Full error traceback:\n{traceback.format_exc()}")
        return f"Error processing query: {str(e)}"

In [34]:
# Process the query
result = process_natural_language_query(user_query)
print("\nResults:")
print(result)

Generated SQL Query:
SELECT COUNT(DISTINCT c.CustomerID)
FROM Sales.Customer c
JOIN Sales.SalesOrderHeader soh ON c.CustomerID = soh.CustomerID
WHERE soh.ShipDate BETWEEN '2013-06-30' AND '2014-06-30';

Results:
{'sql_query': "SELECT COUNT(DISTINCT c.CustomerID)\nFROM Sales.Customer c\nJOIN Sales.SalesOrderHeader soh ON c.CustomerID = soh.CustomerID\nWHERE soh.ShipDate BETWEEN '2013-06-30' AND '2014-06-30';", 'results':         
0  18051}
