In [1]:
import os
import pandas as pd
import getpass
from sqlalchemy import create_engine, text, inspect

# LangChain Core
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# FIX: The ChatPromptTemplate is now in langchain_core.prompts
from langchain_core.prompts import ChatPromptTemplate 


# LangChain Google GenAI (Model & Embeddings)
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

# LangChain Community (Database & Tools)
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit

# LangChain Vector Stores for RAG
from langchain_community.vectorstores import Chroma

In [2]:
if not os.getenv("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter your Google API key: ")

In [3]:
# initialize llm
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash")

# embedding model 
embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-001")

print("Setup Complete: LLM and Embeddings Initialized.")

Setup Complete: LLM and Embeddings Initialized.


In [4]:
# We'll use an in-memory SQLite database for simplicity.
db_engine = create_engine("sqlite:///:memory:")

In [5]:
schema_sql = """
-- Table: Customer
CREATE TABLE Customer (
    CustomerID INTEGER PRIMARY KEY,
    FirstName TEXT,
    LastName TEXT,
    Email TEXT,
    City TEXT,
    JoinDate DATE
);

-- Table: Order
CREATE TABLE "Order" (
    OrderID INTEGER PRIMARY KEY,
    CustomerID INTEGER,
    OrderDate DATE,
    TotalAmount REAL,
    FOREIGN KEY (CustomerID) REFERENCES Customer(CustomerID)
);

-- Table: Product
CREATE TABLE Product (
    ProductID INTEGER PRIMARY KEY,
    Name TEXT,
    Category TEXT,
    Price REAL
);

-- Table: OrderItem
CREATE TABLE OrderItem (
    OrderItemID INTEGER PRIMARY KEY,
    OrderID INTEGER,
    ProductID INTEGER,
    Quantity INTEGER,
    Subtotal REAL,
    FOREIGN KEY (OrderID) REFERENCES "Order"(OrderID),
    FOREIGN KEY (ProductID) REFERENCES Product(ProductID)
);

-- Table: Warehouse
CREATE TABLE Warehouse (
    WarehouseID INTEGER PRIMARY KEY,
    City TEXT,
    Capacity INTEGER
);
"""

In [6]:
# Execute DDL
with db_engine.connect() as connection:
    for statement in schema_sql.split(';'):
        statement = statement.strip()
        if statement:
            connection.exec_driver_sql(statement)
    connection.commit()

In [7]:
# Insert Sample Data
data_inserts = [
    ("INSERT INTO Customer VALUES (1, 'Alice', 'Smith', 'alice@e.com', 'New York', '2023-01-15')",),
    ("INSERT INTO Customer VALUES (2, 'Bob', 'Johnson', 'bob@e.com', 'London', '2023-03-10')",),
    ("INSERT INTO \"Order\" VALUES (101, 1, '2024-05-01', 55.99)",),
    ("INSERT INTO \"Order\" VALUES (102, 2, '2024-05-05', 12.50)",),
    ("INSERT INTO Product VALUES (1, 'Laptop', 'Electronics', 1200.00)",),
    ("INSERT INTO Product VALUES (2, 'Mouse', 'Electronics', 25.00)",),
    ("INSERT INTO OrderItem VALUES (1, 101, 2, 2, 50.00)",),
    ("INSERT INTO OrderItem VALUES (2, 102, 2, 1, 25.00)",),
    ("INSERT INTO Warehouse VALUES (10, 'New York', 5000)",),
    ("INSERT INTO Warehouse VALUES (20, 'London', 8000)",)
]

with db_engine.connect() as connection:
    for sql_query in data_inserts:
        connection.exec_driver_sql(sql_query[0])
    connection.commit()

# Create LangChain SQLDatabase object
db = SQLDatabase(db_engine)
print(f"Database created with tables: {db.get_usable_table_names()}")

Database created with tables: ['Customer', 'Order', 'OrderItem', 'Product', 'Warehouse']


In [8]:
# 1. Extract Full Schema into Documents
# We chunk the schema by table definition. This is our document for RAG.
table_schemas = db.get_table_info() # Gets all table DDLs

In [9]:
# For RAG, we break the schema string into separate documents (one per table definition)
# For this simplified example, we'll manually split the DDLs:
tables_ddls = []
ddl_snippets = schema_sql.strip().split(");")
for snippet in ddl_snippets:
    if "CREATE TABLE" in snippet:
        # Clean up and add the DDL for the table
        tables_ddls.append(snippet.strip() + ");")

In [10]:
# 2. Create Vector Store with Schema Documents
# Using Chroma as the Vector Store
vectorstore = Chroma.from_texts(
    texts=tables_ddls,
    embedding=embeddings,
    collection_name="sql_schema_collection",
)
retriever = vectorstore.as_retriever()
print(f"Schema indexed into Vector Store with {len(tables_ddls)} documents.")

Schema indexed into Vector Store with 5 documents.


In [11]:
# 1. Custom Prompt Template (to inject RAG context)
# The prompt is key: it tells the LLM to use the *retrieved context* to generate SQL.
system_prompt = """
You are an expert SQL Agent that translates natural language questions into accurate SQL queries.
You must use the provided database schema context below to write syntactically correct queries.

Relevant Database Schema (RETRIEVED CONTEXT):
---------------------
{context}
---------------------

Based on the context and the user's question, generate the appropriate SQLite query.
- DO NOT use tables or columns not present in the RETRIEVED CONTEXT.
- If you cannot answer the question based on the provided schema, state that you need more information.
- Use Joins where necessary.
- Always limit your query to a maximum of 10 rows if a limit is not explicitly requested.

Question: {question}
SQL Query:
"""

rag_prompt = ChatPromptTemplate.from_template(system_prompt)

In [12]:
# 2. Define the Chain for Query Generation (RAG + LLM)
# This chain finds relevant schema, then passes it to the LLM to generate SQL.
def format_docs(docs):
    """Formats the retrieved schema documents into a single string."""
    return "\n\n".join([doc.page_content for doc in docs])

sql_generator_chain = (
    {
        "context": retriever | format_docs,  # RAG step: Retrieve relevant schema
        "question": RunnablePassthrough(),    # Pass the user question through
    }
    | rag_prompt
    | llm
    | StrOutputParser() # Output should be the SQL string
)

In [13]:
# 3. Define the Final Answer Generation Chain
# This chain takes the question, the SQL result, and generates the final natural language answer.
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant. Given the user's question and the SQL query result, provide a concise and clear natural language answer. Do not include the SQL query in the final response."),
        ("user", "Question: {question}\nSQL Result: {sql_result}"),
    ]
)

In [14]:
# 4. Define the Execution Tool
# A tool that the Agent can call to execute the generated SQL.
def execute_sql_query(query: str):
    """A tool to safely execute a SQL query against the database."""
    try:
        with db_engine.connect() as connection:
            result = connection.execute(text(query)).fetchall()
            return f"SQL execution successful. Result: {result}"
    except Exception as e:
        return f"SQL Error: {e}. Please correct the query and try again."

In [15]:
# 5. Create the Full RAG-Agent Workflow (The Agent acts as the orchestrator)
# Note: For simplicity in a single notebook, we'll connect the RAG generation to the final answer chain.
# A true LangChain Agent would use tools dynamically, but this runnable approach is cleaner for a quick MVP.

full_rag_sql_workflow = (
    RunnablePassthrough.assign(sql_query=sql_generator_chain)
    | RunnablePassthrough.assign(
        sql_result=lambda x: execute_sql_query(x["sql_query"])
    )
    | RunnablePassthrough.assign(
        answer=lambda x: final_answer_chain.invoke(
            {"question": x["question"], "sql_result": x["sql_result"]}
        )
    )
)
print("RAG-SQL Workflow Ready.")

RAG-SQL Workflow Ready.


In [16]:
def run_query(question):
    print(f"--- User Question: {question} ---")
    
    # FIX: Convert the string input into a dictionary with the expected key "question"
    input_dict = {"question": question}
    
    # Run the full workflow
    response = full_rag_sql_workflow.invoke(input_dict)

    print(f"\n[Generated SQL]: {response['sql_query'].strip()}")
    print(f"[SQL Result]:    {response['sql_result']}")
    print(f"\n[Final Answer]:  {response['answer']}")
    print("--------------------------------------------------\n")
    return response

In [17]:
# Test Query 1: Simple query requiring one table
run_query("List the first name and city of all customers.")

# Test Query 2: Query requiring join and aggregation
run_query("What is the total amount for the order placed by Alice Smith?")

# Test Query 3: Query involving a table with low relevance to the main question
run_query("Which warehouses have a capacity greater than 6000?")

--- User Question: List the first name and city of all customers. ---


GoogleGenerativeAIError: Error embedding content: 'ProtoType' object has no attribute 'DESCRIPTOR'