1. Connect to MySQL

Uses SQLDatabase to connect to the database.

Fetches table + column information so the model knows the schema.

2. Build a Prompt

Creates a template that tells GPT-4o:

Given this schema and a user question, return only a single-line SQL query.

3. Generate SQL

Takes the user’s question + schema.
Sends them to GPT-4o.
GPT-4o returns an SQL query string (like SELECT SUM(...) ...).

4. Clean & Run

Cleans up the SQL (removes ```sql fences).
Executes it with db.run(sql_query).
Prints the results from MySQL.

In [14]:
from langchain.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from dotenv import load_dotenv

In [15]:
load_dotenv()

True

In [2]:
# Connect MySQL database
host = 'localhost'
port = '3306'
username = 'root'
password = 'root'
database_schema = 'poc_test_to_sql'

mysql_uri = f"mysql+pymysql://{username}:{password}@{host}:{port}/{database_schema}"


In [3]:
# Initialize SQLDatabase
db = SQLDatabase.from_uri(mysql_uri, sample_rows_in_table_info=1)


In [4]:
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x2302ff278d0>

In [5]:
# Get table information
db.get_table_info()

'\nCREATE TABLE `2017_budgets` (\n\t`Product_Name` TEXT, \n\t`2017_Budgets` DOUBLE\n)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from 2017_budgets table:\nProduct_Name\t2017_Budgets\nProduct 1\t3016489.2089999998\n*/\n\n\nCREATE TABLE customers (\n\t`Customer_Index` INTEGER, \n\t`Customer_Names` TEXT\n)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from customers table:\nCustomer_Index\tCustomer_Names\n1\tGeiss Company\n*/\n\n\nCREATE TABLE products (\n\t`Index` INTEGER, \n\t`Product_Name` TEXT\n)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4\n\n/*\n1 rows from products table:\nIndex\tProduct_Name\n1\tProduct 1\n*/\n\n\nCREATE TABLE regions (\n\tid INTEGER, \n\tname TEXT, \n\tcounty TEXT, \n\tstate_code TEXT, \n\tstate TEXT, \n\ttype TEXT, \n\tlatitude DOUBLE, \n\tlongitude DOUBLE, \n\tarea_code INTEGER, \n\tpopulation INTEGER, \n\thouseholds INTEGER, \n\tmedian_income INTEGER, \n\tland_area INTEGER, \

In [6]:
# Create the LLM Prompt Template
template = """
Based on the table schema below, write a SQL query that would answer the user's question:
Remember: Only provide me the SQL query, don't include anything else. 
Provide me SQL query in a single line, don't add line breaks.

Table Schema: {schema}
Question: {question}
SQL Query:
"""

prompt = ChatPromptTemplate.from_template(template)

In [7]:
prompt

ChatPromptTemplate(input_variables=['question', 'schema'], input_types={}, partial_variables={}, messages=[HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question', 'schema'], input_types={}, partial_variables={}, template="\nBased on the table schema below, write a SQL query that would answer the user's question:\nRemember: Only provide me the SQL query, don't include anything else. \nProvide me SQL query in a single line, don't add line breaks.\n\nTable Schema: {schema}\nQuestion: {question}\nSQL Query:\n"), additional_kwargs={})])

In [13]:
# Get the schema of the database
def get_schema(db):
    schema = db.get_table_info()
    return schema


In [16]:
# Initialize LLM (GPT-4.0)
llm = ChatOpenAI(
    model="gpt-4o",     
    temperature=0     
)

In [25]:
# Create the SQL query chain using the LLM and the prompt template
sql_chain = (
    RunnablePassthrough.assign(schema=lambda _: get_schema(db))
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [26]:
# Test the SQL query chain with a sample question
resp = sql_chain.invoke({
    "question": "What is the total 'Line Total' for Zava Group"
})

print(resp)

```sql
SELECT SUM(Line_Total) AS Total_Line_Total FROM sales_order JOIN customers ON sales_order.Customer_Name_Index = customers.Customer_Index WHERE customers.Customer_Names = 'Zava Group';
```


In [27]:
# Clean the SQL query (remove ```sql fences if present)
sql_query = resp.strip()
if sql_query.startswith("```"):
    sql_query = sql_query.strip("`").replace("sql", "", 1).strip()
if sql_query.endswith("```"):
    sql_query = sql_query[:-3].strip()

print("SQL to run:\n", sql_query)

# Run the query on MySQL
result = db.run(sql_query)
print("Query Result:\n", result)


SQL to run:
 SELECT SUM(Line_Total) AS Total_Line_Total FROM sales_order JOIN customers ON sales_order.Customer_Name_Index = customers.Customer_Index WHERE customers.Customer_Names = 'Zava Group';
Query Result:
 [(6979443.600000001,)]
