In [1]:
import os

import httpx
from langchain_openai import ChatOpenAI

# Connection settings for the OpenAI-compatible endpoint serving the model.
# Each value can be overridden through an environment variable so the
# (frequently changing) tunnel URL and key don't have to be edited into the
# notebook; the defaults are the values this notebook was last run with.
MODEL = os.environ.get("LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct-AWQ")
BASE_URL = os.environ.get("LLM_BASE_URL", "https://1dd604bb7266.ngrok-free.app/v1")
API_KEY = os.environ.get("LLM_API_KEY", "dummy-key")

# TLS verification stays disabled by default to match the original setup
# (acceptable for local/dev only). Set LLM_VERIFY_SSL=1 (or "true"/"yes") to
# enable it, e.g. when the endpoint presents a publicly trusted certificate.
VERIFY_SSL = os.environ.get("LLM_VERIFY_SSL", "0").strip().lower() in ("1", "true", "yes")
http_client = httpx.Client(verify=VERIFY_SSL)

# temperature=0 asks for the most deterministic completions the server
# offers, which keeps the generated SQL stable across re-runs.
llm = ChatOpenAI(
    model=MODEL,
    temperature=0,
    api_key=API_KEY,
    base_url=BASE_URL,
    http_client=http_client,
)

In [9]:
from backend.prompt_templates import (
    RELEVANT_TABLES_PROMPT_TEMPLATE,
    SQL_QUERY_PROMPT_TEMPLATE,
    QUERY_REWRITE_PROMPT_TEMPLATE
)
import os, json

def read_file(path: str):
    """
    Load a text or JSON file, picking the parser from the file extension.

    Args:
        path: Path to a ``.txt`` or ``.json`` file; the extension is matched
            case-insensitively.

    Returns:
        The whole file as ``str`` for ``.txt``, or the parsed JSON value
        (e.g. ``dict``/``list``) for ``.json``. Files are read as UTF-8.

    Raises:
        ValueError: If the extension is neither ``.txt`` nor ``.json``
            (raised before the file is opened).
    """
    _, suffix = os.path.splitext(path)
    suffix = suffix.lower()
    if suffix not in (".txt", ".json"):
        raise ValueError(f"Unsupported file type: {suffix}")

    with open(path, "r", encoding="utf-8") as handle:
        if suffix == ".json":
            return json.load(handle)
        return handle.read()
    
def _strip_code_fences(text: str) -> str:
    text = text.strip()
    if text.startswith("```"):
        lines = text.splitlines()
        if len(lines) >= 2 and lines[-1].strip().startswith("```"):
            lines = lines[1:-1]  # remove both fences
        else:
            lines = lines[1:]
        text = "\n".join(lines).strip()
    return text

# Per-table schema descriptions (JSON) and the full database description
# (plain text) used to build the LLM prompts below.
db_tables = read_file("backend/db_schema.json")
all_tables_text = read_file("backend/db_description.txt")

# Later cells look tables up with db_tables.get(name), so the JSON must be an
# object keyed by table name; fail here rather than with an AttributeError later.
assert isinstance(db_tables, dict), (
    f"backend/db_schema.json must be a JSON object keyed by table name, "
    f"got {type(db_tables).__name__}"
)

In [10]:
# The natural-language question that drives every step of the pipeline below.
user_query = (
    "Show top 10 customers "
    "by total sales"
)

In [11]:
# Ask the LLM to rephrase the question using the database description as
# context, then show the rewritten version.
# NOTE(review): `modified_query` is not used by the later cells - the
# relevant-tables and SQL prompts both receive `user_query`. Either pass the
# rewritten query downstream or drop this step.
response = llm.invoke(
    QUERY_REWRITE_PROMPT_TEMPLATE.format(
        table_descriptions=all_tables_text,
        user_query=user_query,
    )
)
modified_query = response.content.strip()
print("Modified Query:", modified_query, sep="\n")

Modified Query:
Display the top 10 customers based on the total sales amount.


In [12]:
# Ask the LLM which tables are needed to answer the question.
response = llm.invoke(
    RELEVANT_TABLES_PROMPT_TEMPLATE.format(
        user_query=user_query,
        table_descriptions=all_tables_text,
    )
)

raw_response = _strip_code_fences(response.content)

# The answer is parsed as a JSON list of table names (presumably the format
# RELEVANT_TABLES_PROMPT_TEMPLATE requests - confirm). Surface the offending
# text on a parse error, and reject non-list JSON: iterating a string or dict
# below would silently produce single characters or keys as "table names".
try:
    tables = json.loads(raw_response)
except json.JSONDecodeError as exc:
    raise ValueError(
        f"Relevant-tables response is not valid JSON: {raw_response!r}"
    ) from exc
if not isinstance(tables, list):
    raise ValueError(f"Expected a JSON list of table names, got: {tables!r}")

# Normalize names and drop blanks/duplicates (keeping first-seen order) so
# each table's description is added to the SQL prompt at most once.
relevant_tables = list(
    dict.fromkeys(name for name in (str(t).strip() for t in tables) if name)
)

print(f"Relevant tables: {relevant_tables}")

Relevant tables: ['Customer', 'SalesOrderHeader']


In [15]:
# Collect the schema description of each table the LLM selected. Names with
# no (or an empty) entry in db_tables are reported instead of being silently
# dropped, so a hallucinated or mis-cased table name is visible before the
# SQL-generation step.
relevant_tables_text_parts = []
missing_tables = []
for t in relevant_tables:
    desc = db_tables.get(t)
    if desc:
        relevant_tables_text_parts.append(desc)
    else:
        missing_tables.append(t)

if missing_tables:
    print(f"WARNING: no schema description for tables: {missing_tables}")

relevant_tables_text = "\n\n".join(relevant_tables_text_parts)

print("Relevant Tables Description:")
print(relevant_tables_text)

Relevant Tables Description:
Table: Customer
Columns:
- CustomerID int (PK, NOT NULL)
- PersonID int
- StoreID int
- TerritoryID int
- AccountNumber varchar (NOT NULL)
- rowguid varchar (NOT NULL)
- ModifiedDate timestamp (NOT NULL, default current_timestamp())
Foreign keys:
- PersonID → Person.BusinessEntityID (FK_Customer_Person_PersonID)
- StoreID → Store.BusinessEntityID (FK_Customer_Store_StoreID)
- TerritoryID → SalesTerritory.TerritoryID (FK_Customer_SalesTerritory_TerritoryID)
References tables: Person, SalesTerritory, Store
Referenced by tables: SalesOrderHeader

Table: SalesOrderHeader
Columns:
- SalesOrderID int (PK, NOT NULL)
- RevisionNumber tinyint (NOT NULL, default 0)
- OrderDate timestamp (NOT NULL, default current_timestamp())
- DueDate datetime (NOT NULL)
- ShipDate datetime
- Status tinyint (NOT NULL, default 1)
- OnlineOrderFlag tinyint (NOT NULL, default 1)
- SalesOrderNumber varchar (NOT NULL)
- PurchaseOrderNumber varchar
- AccountNumber varchar
- CustomerID int

In [16]:
# Generate SQL for the question from the selected tables' schema text.
response = llm.invoke(
    SQL_QUERY_PROMPT_TEMPLATE.format(
        user_query=user_query,
        tables=relevant_tables_text,
    )
)

sql_text = _strip_code_fences(response.content)

# Sentinel the model returns when the selected tables cannot answer the
# question (presumably defined by SQL_QUERY_PROMPT_TEMPLATE - confirm).
# Stop here; otherwise the next cell would send this sentence to MySQL as SQL.
if sql_text.upper().startswith("NOT POSSIBLE WITH GIVEN TABLES"):
    raise RuntimeError("NOT POSSIBLE WITH GIVEN TABLES")

# The SQL is model-generated and executed verbatim by run_sql, so only let
# read-only statements through. This prefix check is a coarse guard; a
# database user with SELECT-only privileges is the real safeguard.
if not sql_text.upper().startswith(("SELECT", "WITH")):
    raise ValueError(f"Refusing to run non-SELECT SQL: {sql_text!r}")

print(f"Generated SQL: {sql_text}")

Generated SQL: SELECT 
    c.AccountNumber, 
    SUM(soh.TotalDue) AS TotalSales
FROM 
    Customer c
JOIN 
    SalesOrderHeader soh ON c.CustomerID = soh.CustomerID
GROUP BY 
    c.AccountNumber
ORDER BY 
    TotalSales DESC
LIMIT 10;


In [17]:
from backend.config import DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD
import pymysql
from pymysql.cursors import DictCursor

def get_connection():
    """
    Open a new PyMySQL connection using the settings from ``backend.config``.

    The connection is created with ``DictCursor``, so fetched rows are dicts
    keyed by column name.

    Returns:
        A new connection object; the caller is responsible for closing it.
    """
    settings = dict(
        host=DB_HOST,
        port=DB_PORT,
        user=DB_USER,
        password=DB_PASSWORD,
        database=DB_NAME,
        cursorclass=DictCursor,
    )
    return pymysql.connect(**settings)

def run_sql(query, max_rows=None):
    """
    Execute SQL on a fresh connection and return its result set.

    Args:
        query: SQL text, executed verbatim (no parameters are bound).
        max_rows: Optional positive cap on the number of rows fetched.
            ``None`` (the default) fetches ALL rows, which may be large.

    Returns:
        A ``(rows, columns)`` tuple. ``rows`` is what the cursor returned
        (dicts with the ``DictCursor`` configured in ``get_connection``).
        ``columns`` is the list of column names. It is filled even when the
        query returns no rows, and is ``[]`` only when the statement produces
        no result set.

    Raises:
        ValueError: If ``max_rows`` is given but is not a positive integer.
    """
    if max_rows is not None and max_rows < 1:
        raise ValueError(f"max_rows must be a positive integer or None, got {max_rows!r}")

    conn = get_connection()
    rows = []
    columns = []

    try:
        with conn.cursor() as cur:
            cur.execute(query)
            if max_rows is None:
                rows = cur.fetchall()
            else:
                rows = cur.fetchmany(max_rows)

            if rows:
                # Prefer the row keys so `columns` matches what
                # pd.DataFrame(rows) produces.
                columns = list(rows[0].keys())
            elif cur.description:
                # Empty result: recover the header from cursor metadata so
                # callers can still build a correctly labelled empty table.
                columns = [col[0] for col in cur.description]
    finally:
        conn.close()

    return rows, columns

In [19]:
import pandas as pd

# Execute the generated SQL and load the result into a DataFrame. If no rows
# come back, build an empty frame labelled with the column names run_sql
# reported, so the result still shows a header.
rows, columns = run_sql(sql_text)
if rows:
    df = pd.DataFrame(rows)
else:
    df = pd.DataFrame(columns=columns)

In [20]:
# Show the query result as this cell's output (rendered as a rich table).
df

Unnamed: 0,AccountNumber,TotalSales
0,AW00029818,989184.082
1,AW00029715,961675.8596
2,AW00029722,954021.9235
3,AW00030117,919801.8188
4,AW00029614,901346.856
5,AW00029639,887090.4106
6,AW00029701,841866.5522
7,AW00029617,834475.9271
8,AW00029994,824331.7682
9,AW00029646,820383.5466
