In [1]:
import os
import re
from urllib.parse import quote_plus

from dotenv import load_dotenv
from langchain_classic.chains import create_sql_query_chain
from langchain_classic.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq

# Load variables from a local .env file into os.environ; the next cell reads
# GROQ_API_KEY and DB_PASSWORD from there (the displayed True is the return value).
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# Pull both secrets from the environment (populated by load_dotenv above);
# each is None when the corresponding variable is not set.
Groq_api_key, db_password = (os.getenv(var_name) for var_name in ('GROQ_API_KEY', 'DB_PASSWORD'))

In [3]:
#Connecting with MySQL DataBase
host='localhost'
port=3306
# NOTE(review): the SQL chain's generated queries are executed with this account;
# a dedicated read-only MySQL user would be much safer than root.
username='root'
password=db_password
database_schema='MySQL_DATABASE'
# os.getenv returns None for an unset variable, and the f-string below would
# silently turn that into the literal password "None" -- fail loudly instead.
if password is None:
    raise RuntimeError("DB_PASSWORD is not set; add it to your .env file")
# Percent-encode the credentials: characters such as '@', ':', '/' or '#' in a
# password would otherwise be parsed as URI delimiters and corrupt the URL.
mysql_uri=f"mysql+pymysql://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/{database_schema}"
# sample_rows_in_table_info=2 makes get_table_info() append two example rows per table.
database=SQLDatabase.from_uri(mysql_uri,sample_rows_in_table_info=2)

In [4]:
# Snapshot of the schema text (CREATE TABLE statements plus sample rows) for
# inspection; table_names=None means "all tables". The SQL chain below
# re-fetches the schema itself on every call, so this is only a preview.
context=database.get_table_info(table_names=None)

In [5]:
# print() renders the schema's newlines and tabs; a bare `context` expression
# would display an escaped single-line string repr that is hard to read.
print(context)

'\nCREATE TABLE `STUDENT` (\n\troll_no INTEGER NOT NULL, \n\tname VARCHAR(50) NOT NULL, \n\tclass VARCHAR(10) NOT NULL, \n\t`GRADE` VARCHAR(10) NOT NULL, \n\tPRIMARY KEY (roll_no)\n)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4\n\n/*\n2 rows from STUDENT table:\nroll_no\tname\tclass\tGRADE\n\n*/\n\n\nCREATE TABLE customers (\n\t`Customer Index` INTEGER, \n\t`Customer Names` TEXT\n)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4\n\n/*\n2 rows from customers table:\nCustomer Index\tCustomer Names\n1\tGeiss Company\n2\tJaxbean Group\n*/\n\n\nCREATE TABLE products (\n\t`Index` INTEGER, \n\t`Product Name` TEXT\n)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4\n\n/*\n2 rows from products table:\nIndex\tProduct Name\n1\tProduct 1\n2\tProduct 2\n*/\n\n\nCREATE TABLE regions (\n\tid INTEGER, \n\tname TEXT, \n\tcounty TEXT, \n\tstate_code TEXT, \n\tstate TEXT, \n\ttype TEXT, \n\tlatitude DOUBLE, \n\tlongitude DOUBLE, \n\tarea_code INTEGER, \n\

In [6]:
# Text-to-SQL prompt. {schema} receives the database's CREATE TABLE statements
# (with sample rows) and {question} the user's natural-language question.
# The model's reply is later executed verbatim with database.run(), so the
# rules insist on bare SQL: no prose, no comments, no markdown fences.
prompt_template="""
Based on the table schema below, write a valid MySQL SQL query that answers the user's question.
Rules:
- Only output the SQL query.
- Do NOT include explanations or comments.
- Do NOT wrap the query in markdown code fences.
- Output the SQL query in a single line with no line breaks.
- Use only the tables and columns shown in the schema.
- Quote table and column names that contain spaces with backticks (e.g. `Product Name`).
Table Schema:
{schema}
Question: {question}
SQL Query:
"""

prompt=ChatPromptTemplate.from_template(prompt_template)

In [7]:
def Get_DataBase_Schema(db):
    """Return the schema description produced by ``db.get_table_info()``.

    For the SQLDatabase used in this notebook this is the CREATE TABLE
    statement of every table followed by sample rows, which is what the
    prompt's ``{schema}`` slot is filled with.

    Args:
        db: any object exposing a zero-argument ``get_table_info()`` method.

    Returns:
        Whatever ``db.get_table_info()`` returns (a string for SQLDatabase).
    """
    schema_text = db.get_table_info()
    return schema_text

In [8]:
# Groq-hosted chat model that writes the SQL. temperature=0 minimises sampling
# randomness so re-running the notebook tends to reproduce the same queries.
llm_model=ChatGroq(groq_api_key=Groq_api_key,model="openai/gpt-oss-120b",temperature=0)

In [9]:
def _clean_sql(raw_output: str) -> str:
    """Turn the model's raw reply into a single read-only SQL statement.

    Strips surrounding whitespace and an optional ```sql ... ``` markdown
    fence, then rejects anything other than one SELECT/WITH statement.
    The result is executed verbatim with database.run() over a connection
    opened as 'root', so a hallucinated DROP/DELETE must not get through.
    This is a coarse guard only: a ';' inside a string literal is rejected,
    and MySQL allows WITH ... UPDATE/DELETE, so a read-only database user
    remains the real protection.

    Raises:
        ValueError: if the reply is empty, contains several statements,
            or does not start with SELECT/WITH.
    """
    sql = raw_output.strip()
    fenced = re.fullmatch(r"```(?:sql)?\s*(.*?)\s*```", sql, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        sql = fenced.group(1)
    # A single trailing ';' is fine; any other ';' means stacked statements.
    if ";" in sql.rstrip().rstrip(";"):
        raise ValueError(f"Refusing to run multiple SQL statements: {sql!r}")
    keyword = sql.split(maxsplit=1)[0].upper() if sql else ""
    if keyword not in ("SELECT", "WITH"):
        raise ValueError(f"Refusing to run non-SELECT SQL generated by the model: {sql!r}")
    return sql


# Text-to-SQL pipeline:
#   {'question': ...} -> add 'schema' (fetched fresh from the database on each call)
#   -> fill the prompt -> LLM -> plain string -> cleaned, read-only SQL string.
MySQL_chain=(
    RunnablePassthrough.assign(schema=lambda _inputs: Get_DataBase_Schema(database))
    | prompt
    # Stop sequence kept from LangChain's default SQL prompt; harmless, since
    # this prompt never asks the model for an 'SQLResult:' section.
    | llm_model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
    | _clean_sql  # plain callables are coerced to a RunnableLambda by `|`
)

In [10]:
# Translate the question into SQL; the bare name on the last line displays the generated query.
response=MySQL_chain.invoke(dict(question='How many products are there.'))
response

'SELECT COUNT(*) FROM products;'

In [11]:
# Execute the generated query; SQLDatabase.run returns all rows rendered as a string.
database.run(command=response, fetch="all")

'[(30,)]'

In [12]:
# Translate the question into SQL; the bare name on the last line displays the generated query.
response=MySQL_chain.invoke(dict(question='What is the total revenue'))
response

'SELECT SUM(`Line Total`) AS total_revenue FROM sales_order;'

In [13]:
# Execute the generated query; SQLDatabase.run returns all rows rendered as a string.
database.run(command=response, fetch="all")

'[(1235968899.0000184,)]'

In [14]:
# Translate the question into SQL; the bare name on the last line displays the generated query.
response=MySQL_chain.invoke(dict(question='Give the product with most revenue, return the product name and its total revenue'))
response

'SELECT p.`Product Name` AS product_name, SUM(s.`Line Total`) AS total_revenue FROM products p JOIN sales_order s ON p.`Index` = s.`Product Description Index` GROUP BY p.`Product Name` ORDER BY total_revenue DESC LIMIT 1'

In [15]:
# Execute the generated query; SQLDatabase.run returns all rows rendered as a string.
database.run(command=response, fetch="all")

"[('Product 26', 117291821.40000035)]"