# 14. SQL RAG - Natural Language to SQL

**Complexity:** ⭐⭐⭐⭐

## Overview

**SQL RAG** enables natural language queries over structured databases by combining:
1. Schema retrieval (finding relevant tables)
2. Text-to-SQL generation (LLM generates SQL)
3. SQL execution (validated, read-only execution)
4. Results interpretation (natural language answer)

### The Problem

Standard vector RAG works well for unstructured text but struggles with:
- Structured databases (SQL, NoSQL)
- Aggregations and analytics (COUNT, SUM, AVG)
- Precise data lookups (exact matches, filters)
- Relational queries (JOINs across tables)

**Example queries that need SQL:**
- "How many customers do we have in France?"
- "What's the average order value last month?"
- "Show me the top 5 products by revenue"
- "Which employees have sold more than $50,000?"
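
To make this concrete, here is a minimal sketch using an in-memory SQLite database. The two tables are loosely modeled on Chinook's `Customer` and `Invoice`, and the rows are made up for illustration. Each question maps to a single filter or aggregate query, which chunk-based vector search cannot answer reliably.

```python
import sqlite3

# Toy in-memory data (hypothetical rows, not the real Chinook contents)
demo_conn = sqlite3.connect(":memory:")
demo_conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL);
INSERT INTO Customer VALUES (1, 'France'), (2, 'France'), (3, 'Germany');
INSERT INTO Invoice VALUES (1, 1, 10.0), (2, 2, 5.0), (3, 3, 20.0), (4, 1, 3.0);
""")

# "How many customers do we have in France?" -> COUNT with a filter
french = demo_conn.execute(
    "SELECT COUNT(*) FROM Customer WHERE Country = 'France'"
).fetchone()[0]

# "What's the average order value?" -> AVG over all invoices
avg_order = demo_conn.execute("SELECT AVG(Total) FROM Invoice").fetchone()[0]

# "Top customers by revenue" -> GROUP BY + ORDER BY + LIMIT
top = demo_conn.execute("""
    SELECT CustomerId, SUM(Total) AS revenue
    FROM Invoice
    GROUP BY CustomerId
    ORDER BY revenue DESC
    LIMIT 2
""").fetchall()

print(french, avg_order, top)  # 2 9.5 [(3, 20.0), (1, 13.0)]
```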

### The Solution

SQL RAG pipeline:

```
Question → Schema Retrieval → Text-to-SQL → SQL Validation
    → Execute Query → Results → Natural Language Answer
```
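
The flow above can be sketched as plain function composition. This is a minimal, self-contained sketch that illustrates only the control flow. `retrieve_schema`, `generate_sql` and `interpret` are hypothetical stubs standing in for the vector retriever and the two LLM calls, and the validation step is a toy prefix check.

```python
import sqlite3
from typing import Callable


def sql_rag_pipeline(
    question: str,
    conn: sqlite3.Connection,
    retrieve_schema: Callable[[str], str],
    generate_sql: Callable[[str, str], str],
    interpret: Callable[[str, list], str],
) -> str:
    """Question -> schema -> SQL -> validation -> rows -> answer."""
    schema = retrieve_schema(question)                 # Schema Retrieval
    sql = generate_sql(question, schema)               # Text-to-SQL
    if not sql.lstrip().upper().startswith("SELECT"):  # SQL Validation (toy check)
        return "Rejected: only SELECT queries are allowed"
    rows = conn.execute(sql).fetchall()                # Execute Query
    return interpret(question, rows)                   # Natural Language Answer


# Toy database plus hypothetical stubs for the retriever and both LLM calls
demo_conn = sqlite3.connect(":memory:")
demo_conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
demo_conn.executemany("INSERT INTO Artist (Name) VALUES (?)", [("AC/DC",), ("Accept",)])

answer = sql_rag_pipeline(
    "How many artists are there?",
    demo_conn,
    retrieve_schema=lambda q: "Table: Artist (ArtistId, Name)",
    generate_sql=lambda q, schema: "SELECT COUNT(*) FROM Artist",
    interpret=lambda q, rows: f"There are {rows[0][0]} artists.",
)
print(answer)  # There are 2 artists.
```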

### Architecture

1. **Schema Index**: Embed table/column descriptions
2. **Schema Retrieval**: Find relevant tables for query
3. **SQL Generation**: LLM generates SQL with schema context
4. **Safety Layer**: Validate SQL (read-only, prevent injection)
5. **Execution**: Run query in controlled environment
6. **Interpretation**: Convert results to natural language
7. **Error Handling**: Retry with corrections if SQL fails
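
Step 7 can be generalized into a bounded retry loop: run the SQL and, on failure, hand the database's error message back to the generator for a corrected query. The pipeline built later in this notebook makes a single correction attempt. The sketch below is a hypothetical generalization to `max_attempts`, with `fix_sql` as a stub standing in for the error-recovery LLM call.

```python
import sqlite3
from typing import Callable, Optional


def execute_with_retries(
    conn: sqlite3.Connection,
    sql: str,
    fix_sql: Callable[[str, str], str],
    max_attempts: int = 3,
) -> tuple[Optional[list], str, list[str]]:
    """Run `sql`; on error ask `fix_sql(sql, error)` for a corrected query.

    Returns (rows or None, last SQL executed, error messages seen).
    """
    errors: list[str] = []
    for attempt in range(max_attempts):
        try:
            return conn.execute(sql).fetchall(), sql, errors
        except sqlite3.Error as exc:
            errors.append(str(exc))
            if attempt + 1 < max_attempts:
                # In the real pipeline, the LLM would see this error text
                sql = fix_sql(sql, str(exc))
    return None, sql, errors


demo_conn = sqlite3.connect(":memory:")
demo_conn.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT)")

# Stub "LLM" that fixes a pluralized table name
rows, final_sql, errs = execute_with_retries(
    demo_conn,
    "SELECT COUNT(*) FROM Albums",
    fix_sql=lambda sql, err: sql.replace("Albums", "Album"),
)
print(rows, final_sql, errs)
```

Capping the attempts keeps the cost bounded: each retry is one more LLM call.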

### Example Database: Chinook

We'll use the **Chinook** database, a sample music store database with these tables:
- **Artist**: Band/musician information
- **Album**: Music albums
- **Track**: Individual songs
- **Genre** / **MediaType**: Lookup tables for tracks
- **Customer**: Customer records
- **Employee**: Staff information
- **Invoice**: Sales transactions
- **InvoiceLine**: Line items for each invoice
- **Playlist** / **PlaylistTrack**: Song collections and their membership
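
The tables are linked by foreign keys. For example, `Album.ArtistId` references `Artist.ArtistId`, and that link is what makes relational questions answerable with a JOIN. Here is a minimal sketch with a tiny in-memory copy of those two tables. The column names follow the Chinook schema, but the rows are made up for illustration.

```python
import sqlite3

demo_conn = sqlite3.connect(":memory:")
demo_conn.executescript("""
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
CREATE TABLE Album (
    AlbumId INTEGER PRIMARY KEY,
    Title TEXT,
    ArtistId INTEGER REFERENCES Artist(ArtistId)
);
INSERT INTO Artist VALUES (1, 'Artist A'), (2, 'Artist B');
INSERT INTO Album VALUES (1, 'First', 1), (2, 'Second', 1), (3, 'Third', 2);
""")

# "Which artist has released the most albums?" needs JOIN + GROUP BY
row = demo_conn.execute("""
    SELECT ar.Name, COUNT(al.AlbumId) AS album_count
    FROM Artist ar
    JOIN Album al ON al.ArtistId = ar.ArtistId
    GROUP BY ar.ArtistId
    ORDER BY album_count DESC
    LIMIT 1
""").fetchone()
print(row)  # ('Artist A', 2)
```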

### When to Use

✅ **Good for:**
- Analytics and aggregations
- Structured data queries
- Enterprise data (databases, data warehouses)
- Precise lookups and filters
- Time-series data

❌ **Not ideal for:**
- Unstructured text documents
- Semantic similarity search
- When questions don't map cleanly onto the schema
- When the required SQL is too complex for the LLM to generate reliably

### Trade-offs

**Pros:**
- ✅ Precise answers (numbers come from executed queries, not from the LLM's memory)
- ✅ Handles aggregations and math
- ✅ Works with existing databases
- ✅ Verifiable results

**Cons:**
- ❌ Requires good schema design
- ❌ LLM SQL errors are common
- ❌ Security considerations
- ❌ Limited to what the schema and SQL can express

---

## Implementation

## 1. Setup and Imports

In [None]:
import sys
import sqlite3
from pathlib import Path
from typing import List, Dict, Any
import json

# Add parent directory to path for imports
sys.path.append(str(Path("../..").resolve()))

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import FAISS
import pandas as pd

from shared.config import (
    verify_api_key,
    DEFAULT_MODEL,
    DEFAULT_TEMPERATURE,
    OPENAI_EMBEDDING_MODEL,
    VECTOR_STORE_DIR,
)
from shared.prompts import (
    SQL_SCHEMA_SUMMARY_PROMPT,
    TEXT_TO_SQL_PROMPT,
    SQL_RESULTS_INTERPRETATION_PROMPT,
    SQL_ERROR_RECOVERY_PROMPT,
)
from shared.utils import (
    print_section_header,
    load_vector_store,
    save_vector_store,
)

# Verify API key
verify_api_key()

print("✓ All imports successful")
print(f"✓ Using model: {DEFAULT_MODEL}")
print(f"✓ Using embeddings: {OPENAI_EMBEDDING_MODEL}")

## 2. Connect to Chinook Database

In [None]:
print_section_header("Connecting to Chinook Database")

# Path to Chinook database
DB_PATH = Path("../..") / "data" / "chinook.db"

if not DB_PATH.exists():
    raise FileNotFoundError(
        f"Chinook database not found at {DB_PATH}.\n"
        "Please download it from: https://github.com/lerocha/chinook-database"
    )

# Create read-only connection
conn = sqlite3.connect(f"file:{DB_PATH}?mode=ro", uri=True)
cursor = conn.cursor()

print(f"✓ Connected to database: {DB_PATH}")

# Get user table names (skip SQLite's internal sqlite_* tables)
cursor.execute(
    "SELECT name FROM sqlite_master "
    "WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
)
tables = [row[0] for row in cursor.fetchall()]

print(f"\n✓ Found {len(tables)} tables:")
for table in tables:
    # Quote the identifier in case a table name contains spaces or keywords
    cursor.execute(f'SELECT COUNT(*) FROM "{table}"')
    count = cursor.fetchone()[0]
    print(f"  • {table}: {count} rows")

## 3. Inspect Database Schema
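
`PRAGMA table_info`, used in the cell below, reports columns but not relationships. JOIN generation depends on knowing which columns link tables, so one option is to append each table's foreign keys, read with `PRAGMA foreign_key_list`, to its schema string. This is a minimal sketch. `get_foreign_keys` is a hypothetical helper and is not part of the notebook's `shared` package.

```python
import sqlite3


def get_foreign_keys(conn: sqlite3.Connection, table_name: str) -> list[str]:
    """Describe each foreign key of `table_name` as 'col -> OtherTable.col'."""
    rows = conn.execute(f"PRAGMA foreign_key_list('{table_name}')").fetchall()
    # Each row: (id, seq, table, from, to, on_update, on_delete, match)
    return [f"{row[3]} -> {row[2]}.{row[4]}" for row in rows]


demo_conn = sqlite3.connect(":memory:")
demo_conn.executescript("""
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
CREATE TABLE Album (
    AlbumId INTEGER PRIMARY KEY,
    Title TEXT,
    ArtistId INTEGER REFERENCES Artist(ArtistId)
);
""")
print(get_foreign_keys(demo_conn, "Album"))  # ['ArtistId -> Artist.ArtistId']
```

The resulting lines could be appended to the string returned by `get_table_schema` before summarization and embedding, so the LLM sees the join paths.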

In [None]:
def get_table_schema(conn: sqlite3.Connection, table_name: str) -> str:
    """
    Get the schema for a table.
    
    Args:
        conn: Database connection
        table_name: Name of table
    
    Returns:
        Schema description string
    """
    cursor = conn.cursor()
    cursor.execute(f"PRAGMA table_info({table_name})")
    columns = cursor.fetchall()
    
    schema_lines = [f"Table: {table_name}"]
    schema_lines.append("Columns:")
    
    for col in columns:
        col_id, name, col_type, not_null, default, pk = col
        constraints = []
        if pk:
            constraints.append("PRIMARY KEY")
        if not_null:
            constraints.append("NOT NULL")
        
        constraint_str = f" ({', '.join(constraints)})" if constraints else ""
        schema_lines.append(f"  - {name}: {col_type}{constraint_str}")
    
    return "\n".join(schema_lines)


print_section_header("Database Schema Inspection")

# Show schema for a few example tables
example_tables = ["Artist", "Album", "Track", "Customer", "Invoice"]

for table in example_tables:
    if table in tables:
        print(f"\n{table}:")
        print("-" * 80)
        print(get_table_schema(conn, table))

print("\n✓ Schema inspection complete")

## 4. Create Schema Embeddings

We'll embed table schemas to enable semantic retrieval of relevant tables.
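
Conceptually, the retriever scores each table description against the question and keeps the top `k`. The sketch below illustrates only that ranking step. It uses a bag-of-words cosine similarity as a hypothetical stand-in for `OpenAIEmbeddings` + FAISS, and the table descriptions are made up. Real embeddings also capture synonyms (e.g. "songs" vs. "tracks") that plain word overlap misses.

```python
import math
import re
from collections import Counter


def bow(text: str) -> Counter:
    """Bag-of-words vector: lowercase word counts."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


# Hypothetical one-line table descriptions (the role the LLM summaries play)
descriptions = {
    "Customer": "customer names countries and contact details",
    "Invoice": "invoice purchase totals and billing country for each customer",
    "Track": "track song names durations and prices",
    "Employee": "employee staff names titles and managers",
}


def top_k_tables(question: str, k: int = 2) -> list[str]:
    q = bow(question)
    ranked = sorted(descriptions, key=lambda t: cosine(q, bow(descriptions[t])), reverse=True)
    return ranked[:k]


print(top_k_tables("Which customer made the largest purchase?"))  # ['Invoice', 'Customer']
```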

In [None]:
print_section_header("Creating Schema Embeddings")

# Initialize LLM for schema summarization
llm = ChatOpenAI(
    model=DEFAULT_MODEL,
    temperature=DEFAULT_TEMPERATURE,
)

# Create vector store for schema
embeddings = OpenAIEmbeddings(model=OPENAI_EMBEDDING_MODEL)
schema_store_path = VECTOR_STORE_DIR / "sql_rag_schema"

# Reuse the cached index if it exists, so the per-table LLM summaries
# are only generated when the index is first built
schema_vectorstore = load_vector_store(schema_store_path, embeddings)

if schema_vectorstore is None:
    print("\nGenerating semantic descriptions for tables...")
    summary_chain = SQL_SCHEMA_SUMMARY_PROMPT | llm | StrOutputParser()

    # Create schema documents with summaries
    schema_docs = []
    for table in tables:
        schema = get_table_schema(conn, table)

        # Generate semantic summary (one LLM call per table)
        summary = summary_chain.invoke({
            "table_name": table,
            "schema": schema,
        })

        # Create document: both the summary and the raw schema are embedded
        doc = Document(
            page_content=f"{table}: {summary}\n\nSchema:\n{schema}",
            metadata={
                "table_name": table,
                "summary": summary,
                "schema": schema,
            },
        )
        schema_docs.append(doc)
        print(f"  ✓ {table}")

    print(f"\n✓ Created {len(schema_docs)} schema documents")

    print("\nCreating schema vector store...")
    schema_vectorstore = FAISS.from_documents(schema_docs, embeddings)
    save_vector_store(schema_vectorstore, schema_store_path)
    print("✓ Schema vector store created")
else:
    print("✓ Loaded existing schema vector store")

# Create retriever
schema_retriever = schema_vectorstore.as_retriever(search_kwargs={"k": 3})
print("✓ Schema retriever ready")

## 5. Test Schema Retrieval

In [None]:
print_section_header("Testing Schema Retrieval")

test_queries = [
    "Which tables contain information about songs and music?",
    "Where can I find customer purchase data?",
    "Show me tables related to employees and sales",
]

for query in test_queries:
    print(f"\nQuery: {query}")
    print("-" * 80)
    
    relevant_tables = schema_retriever.invoke(query)
    
    print("Relevant tables:")
    for doc in relevant_tables:
        table_name = doc.metadata["table_name"]
        summary = doc.metadata["summary"]
        print(f"  • {table_name}: {summary}")

## 6. Implement Safe SQL Execution
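
Keyword blocklists are easy to get wrong in both directions. A column named, say, `LastUpdated` contains `UPDATE`, while a statement type the list never anticipated slips through. A complementary option is to let SQLite enforce the policy itself through `sqlite3.Connection.set_authorizer`, which SQLite consults for every action while compiling a statement. This is a minimal sketch of an allowlist authorizer. The allowed set is an assumption; recursive CTEs, for example, would also need `sqlite3.SQLITE_RECURSIVE`.

```python
import sqlite3

# Actions a read-only analytics query needs; everything else is denied
READ_ONLY_ACTIONS = {sqlite3.SQLITE_SELECT, sqlite3.SQLITE_READ, sqlite3.SQLITE_FUNCTION}


def read_only_authorizer(action, arg1, arg2, db_name, trigger_name):
    """Called by SQLite for each action; return SQLITE_OK or SQLITE_DENY."""
    return sqlite3.SQLITE_OK if action in READ_ONLY_ACTIONS else sqlite3.SQLITE_DENY


demo_conn = sqlite3.connect(":memory:")
demo_conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
demo_conn.execute("INSERT INTO Artist (Name) VALUES ('AC/DC')")
demo_conn.commit()

# From here on, every statement on this connection is checked by SQLite itself
demo_conn.set_authorizer(read_only_authorizer)

print(demo_conn.execute("SELECT COUNT(*) FROM Artist").fetchone())  # (1,)
try:
    demo_conn.execute("DELETE FROM Artist")
except sqlite3.DatabaseError as exc:
    print("Blocked:", exc)
```

Because the check happens inside SQLite, it does not depend on how the generated SQL is spelled. It is still worth combining with a read-only connection.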

In [None]:
import re


def execute_sql_safely(
    conn: sqlite3.Connection,
    query: str,
    max_results: int = 100,
) -> tuple[bool, Any, str]:
    """
    Execute SQL query safely with validation.
    
    Args:
        conn: Database connection (should be read-only)
        query: SQL query to execute
        max_results: Maximum rows to return
    
    Returns:
        Tuple of (success, rows as a list of dicts or None, error message)
    """
    # Safety checks (defense in depth; the read-only connection is the real guard)
    query_upper = query.upper().strip()
    
    # Block dangerous operations. Match whole words only, so identifiers
    # such as "LastUpdated" or "IsDeleted" are not rejected by mistake.
    forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "CREATE", "TRUNCATE"]
    for keyword in forbidden:
        if re.search(rf"\b{keyword}\b", query_upper):
            return False, None, f"Forbidden keyword: {keyword}"
    
    # Must be a read query: plain SELECT or a CTE (WITH ... SELECT)
    if not query_upper.startswith(("SELECT", "WITH")):
        return False, None, "Only SELECT queries are allowed"
    
    try:
        cursor = conn.cursor()
        cursor.execute(query)
        
        # Fetch results with limit
        results = cursor.fetchmany(max_results)
        
        # Get column names
        columns = [desc[0] for desc in cursor.description] if cursor.description else []
        
        # Convert to list of dicts
        results_list = [
            dict(zip(columns, row))
            for row in results
        ]
        
        return True, results_list, ""
    
    except Exception as e:
        return False, None, str(e)


print("✓ Safe SQL execution function defined")

# Test
print("\nTest query: SELECT * FROM Artist LIMIT 3")
success, results, error = execute_sql_safely(conn, "SELECT * FROM Artist LIMIT 3")

if success:
    print("\n✓ Query successful!")
    df = pd.DataFrame(results)
    print(df)
else:
    print(f"\n❌ Query failed: {error}")

## 7. Build Text-to-SQL Pipeline
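
The pipeline below strips Markdown fences with chained `str.replace` calls. That approach leaves stray prose in the query if the model adds a sentence before or after it. A slightly more defensive sketch takes the contents of the first fenced block when one exists and falls back to the whole reply otherwise. The helper name `extract_sql` is hypothetical.

```python
import re

# First ``` fence, optionally tagged "sql"; DOTALL lets the body span lines
FENCE_RE = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def extract_sql(llm_output: str) -> str:
    """Return the SQL inside the first ``` fence, else the stripped text."""
    match = FENCE_RE.search(llm_output)
    sql = match.group(1) if match else llm_output
    return sql.strip()


reply = "Here is the query:\n```sql\nSELECT COUNT(*) FROM Customer;\n```\nIt counts customers."
print(extract_sql(reply))  # SELECT COUNT(*) FROM Customer;
```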

In [None]:
def text_to_sql_rag(
    question: str,
    conn: sqlite3.Connection,
    schema_retriever,
    llm,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    Complete Text-to-SQL RAG pipeline.
    
    Args:
        question: Natural language question
        conn: Database connection
        schema_retriever: Schema retriever
        llm: Language model
        verbose: Print debug info
    
    Returns:
        Dict with query, results, answer, etc.
    """
    if verbose:
        print(f"\n[SQL RAG] Question: {question}")
    
    # 1. Retrieve relevant schema
    relevant_schemas = schema_retriever.invoke(question)
    
    schema_context = "\n\n".join([
        doc.metadata["schema"]
        for doc in relevant_schemas
    ])
    
    if verbose:
        tables = [doc.metadata["table_name"] for doc in relevant_schemas]
        print(f"[SQL RAG] Relevant tables: {', '.join(tables)}")
    
    # 2. Generate SQL query
    sql_chain = TEXT_TO_SQL_PROMPT | llm | StrOutputParser()
    sql_query = sql_chain.invoke({
        "schema": schema_context,
        "question": question,
    }).strip()
    
    # Clean SQL (remove markdown formatting if present)
    sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
    
    if verbose:
        print(f"[SQL RAG] Generated SQL: {sql_query}")
    
    # 3. Execute SQL
    success, results, error = execute_sql_safely(conn, sql_query)
    
    if not success:
        if verbose:
            print(f"[SQL RAG] Query failed: {error}")
            print("[SQL RAG] Attempting error recovery...")
        
        # Try error recovery
        recovery_chain = SQL_ERROR_RECOVERY_PROMPT | llm | StrOutputParser()
        corrected_sql = recovery_chain.invoke({
            "question": question,
            "failed_query": sql_query,
            "error": error,
            "schema": schema_context,
        }).strip().replace("```sql", "").replace("```", "").strip()
        
        if verbose:
            print(f"[SQL RAG] Corrected SQL: {corrected_sql}")
        
        success, results, error = execute_sql_safely(conn, corrected_sql)
        
        if success:
            sql_query = corrected_sql
        else:
            return {
                "question": question,
                "sql_query": sql_query,
                "success": False,
                "error": error,
                "answer": f"I couldn't generate a valid SQL query. Error: {error}",
            }
    
    if verbose:
        print(f"[SQL RAG] Query successful! Rows returned: {len(results)}")
    
    # 4. Interpret results (default=str guards against non-JSON values such as BLOB bytes)
    results_str = json.dumps(results, indent=2, default=str) if results else "No results found"
    
    interpret_chain = SQL_RESULTS_INTERPRETATION_PROMPT | llm | StrOutputParser()
    answer = interpret_chain.invoke({
        "question": question,
        "sql_query": sql_query,
        "results": results_str,
    })
    
    return {
        "question": question,
        "sql_query": sql_query,
        "success": True,
        "results": results,
        "answer": answer,
    }


print("✓ Text-to-SQL RAG pipeline defined")

## 8. Test SQL RAG with Various Queries
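
Printing answers is useful for eyeballing, but an automated check catches regressions more reliably. A common metric for text-to-SQL is *execution accuracy*: run the generated query and a hand-written reference ("gold") query, then compare their results. This is a minimal sketch that assumes you write the gold queries yourself. It compares rows as multisets, which ignores row order, so it only suits questions that don't ask for a specific ordering.

```python
import sqlite3
from collections import Counter


def execution_match(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> bool:
    """True if both queries run and return the same rows (order-insensitive)."""
    try:
        predicted = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        return False  # a query that fails to run counts as a miss
    gold = conn.execute(gold_sql).fetchall()
    return Counter(predicted) == Counter(gold)


demo_conn = sqlite3.connect(":memory:")
demo_conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
INSERT INTO Customer VALUES (1, 'France'), (2, 'Brazil'), (3, 'France');
""")
gold = "SELECT COUNT(*) FROM Customer WHERE Country = 'France'"

print(execution_match(demo_conn, "SELECT COUNT(CustomerId) FROM Customer WHERE Country = 'France'", gold))  # True
print(execution_match(demo_conn, "SELECT COUNT(*) FROM Customer", gold))  # False
```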

In [None]:
print_section_header("Testing SQL RAG")

test_questions = [
    "How many customers are in the database?",
    "What are the top 5 longest songs?",
    "Show me the total sales by country, ordered by highest revenue",
    "Which artist has released the most albums?",
    "What's the average track length in minutes?",
]

for i, question in enumerate(test_questions, 1):
    print("\n" + "=" * 80)
    print(f"Question {i}: {question}")
    print("=" * 80)
    
    result = text_to_sql_rag(
        question=question,
        conn=conn,
        schema_retriever=schema_retriever,
        llm=llm,
        verbose=True,
    )
    
    print("\n" + "-" * 80)
    print("ANSWER:")
    print("-" * 80)
    print(result["answer"])
    
    # Show results as DataFrame if available
    if result["success"] and result["results"]:
        print("\n" + "-" * 80)
        print("DATA:")
        print("-" * 80)
        df = pd.DataFrame(result["results"]).head(10)
        print(df)

## 9. Complex Queries with JOINs

In [None]:
print_section_header("Testing Complex Queries with JOINs")

complex_questions = [
    "Show me the top 3 customers by total amount spent",
    "Which genres have the most tracks?",
    "List all employees and how many customers they support",
]

for i, question in enumerate(complex_questions, 1):
    print("\n" + "=" * 80)
    print(f"Complex Query {i}: {question}")
    print("=" * 80)
    
    result = text_to_sql_rag(
        question=question,
        conn=conn,
        schema_retriever=schema_retriever,
        llm=llm,
        verbose=True,
    )
    
    print("\n" + "-" * 80)
    print("ANSWER:")
    print("-" * 80)
    print(result["answer"])
    
    if result["success"] and result["results"]:
        print("\n" + "-" * 80)
        print("DATA:")
        print("-" * 80)
        df = pd.DataFrame(result["results"])
        print(df)

## 10. Performance Metrics
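
The cell below prints total time alongside rough per-component ranges. To measure each stage on your own setup, one option is a small context manager that records wall-clock time per stage. This is a minimal sketch. The `timed` helper is hypothetical, and in the notebook you would wrap the retriever call, each chain invocation, and `execute_sql_safely` separately.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(stage: str, timings: dict):
    """Record the elapsed wall-clock seconds of the `with` body under `stage`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[stage] = time.perf_counter() - start


timings: dict = {}
with timed("retrieval", timings):
    time.sleep(0.01)        # stand-in for schema_retriever.invoke(question)
with timed("execution", timings):
    sum(range(10_000))      # stand-in for execute_sql_safely(conn, sql)

for stage, seconds in timings.items():
    print(f"{stage}: {seconds:.3f}s")
```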

In [None]:
print_section_header("Performance Metrics")

import time

# Benchmark query
benchmark_query = "How many albums are there in total?"

print(f"Benchmark query: {benchmark_query}\n")

start = time.time()
result = text_to_sql_rag(
    question=benchmark_query,
    conn=conn,
    schema_retriever=schema_retriever,
    llm=llm,
    verbose=False,
)
elapsed = time.time() - start

print("=" * 80)
print("PERFORMANCE BREAKDOWN:")
print("=" * 80)
print(f"Total time: {elapsed:.2f}s")
print("\nTypical component ranges (illustrative estimates, not measured here):")
print("  • Schema retrieval: ~0.5-1.0s (query embedding API call + FAISS search)")
print("  • SQL generation: ~1.0-2.0s (LLM call)")
print("  • SQL execution: <0.1s (fast on a small database like Chinook)")
print("  • Results interpretation: ~1.0-2.0s (LLM call)")

print("\n" + "=" * 80)
print("COST ANALYSIS:")
print("=" * 80)
print("LLM calls:")
print("  • Schema summarization: 1 per table (at indexing time, not per query)")
print("  • SQL generation: 1 per query")
print("  • Error recovery: 0-1 per query (only if the first SQL fails)")
print("  • Results interpretation: 1 per query")
print("  • Total per query: 2-3 LLM calls")
print("\nEmbedding calls + vector searches: 1 per query (schema retrieval)")
print("Database queries: 1-2 per query (2 if error recovery runs)")

## 11. Hybrid Approach: SQL + Vector RAG

For maximum flexibility, combine SQL RAG with traditional vector RAG.
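
The implementation note at the end of this section calls for routing plus a fallback when the primary system fails. That control flow can be sketched independently of the classifier. Here `handlers` holds hypothetical stubs standing in for `text_to_sql_rag` and a vector RAG chain, and `classify` can be any function that returns one of the handler keys (`"SQL"` or `"VECTOR"`).

```python
from typing import Callable


def route_with_fallback(
    question: str,
    classify: Callable[[str], str],
    handlers: dict[str, Callable[[str], dict]],
) -> dict:
    """Try the handler chosen by `classify`; if it fails, try the others."""
    primary = classify(question)  # must be a key of `handlers`
    order = [primary] + [name for name in handlers if name != primary]
    result: dict = {"success": False, "answer": "No handler available."}
    for name in order:
        result = handlers[name](question)
        if result.get("success"):
            return {**result, "route": name}
    return result  # every route failed: return the last failure


# Hypothetical stubs: SQL "fails" for this question, vector RAG answers it
handlers = {
    "SQL": lambda q: {"success": False, "answer": "SQL generation failed"},
    "VECTOR": lambda q: {"success": True, "answer": "Answer from documents"},
}
result = route_with_fallback("Top artists per genre?", lambda q: "SQL", handlers)
print(result["route"], "->", result["answer"])  # VECTOR -> Answer from documents
```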

In [None]:
print_section_header("Hybrid SQL + Vector RAG")

print("\nThis approach routes queries to either:")
print("  • SQL RAG: For structured queries (counts, aggregations, filters)")
print("  • Vector RAG: For semantic queries (document search, similarity)")

import re


# Simple query classifier
def classify_query(question: str) -> str:
    """
    Classify if query needs SQL or vector RAG using keyword heuristics.
    """
    sql_keywords = [
        "how many", "count", "total", "sum", "average", "maximum", "minimum",
        "top", "bottom", "highest", "lowest", "most", "least",
        "by country", "by genre", "by artist", "per",
    ]
    
    question_lower = question.lower()
    
    for keyword in sql_keywords:
        # Whole-word match, so "per" does not fire on "performance"
        # and "top" does not fire on "topic"
        if re.search(rf"\b{re.escape(keyword)}\b", question_lower):
            return "SQL"
    
    return "VECTOR"


# Test classification
test_cases = [
    "How many customers are in France?",  # SQL
    "What is LCEL in LangChain?",  # VECTOR
    "Show me the top 5 albums",  # SQL
    "Explain the concept of retrieval",  # VECTOR
]

print("\n" + "=" * 80)
print("Query Classification Examples:")
print("=" * 80)

for query in test_cases:
    classification = classify_query(query)
    print(f"\n'{query}'")
    print(f"  → {classification} RAG")

print("\n" + "=" * 80)
print("Implementation Note:")
print("=" * 80)
print("A production system would:")
print("  1. Use LLM to classify query intent")
print("  2. Route to appropriate RAG system")
print("  3. Fall back to alternative if primary fails")
print("  4. Combine results if needed")

## 12. Key Takeaways

### Summary

**SQL RAG** enables natural language queries over structured databases:
- Schema retrieval finds relevant tables
- Text-to-SQL generates queries with LLM
- Safe execution prevents dangerous operations
- Results interpretation turns rows into natural-language answers
- Error recovery handles SQL generation failures

### Pipeline Components

1. **Schema Index** (one-time): Embed table descriptions
2. **Schema Retrieval**: Find relevant tables (vector search)
3. **SQL Generation**: LLM creates query with schema context
4. **Validation**: Safety checks (read-only, no injection)
5. **Execution**: Run query in controlled environment
6. **Interpretation**: Convert results to natural language
7. **Error Recovery**: Retry with corrections if needed

### Best Practices

**Schema Design:**
- ✅ Use descriptive table/column names
- ✅ Add comments and documentation
- ✅ Maintain referential integrity
- ✅ Create semantic summaries for tables

**Safety:**
- ✅ Always use read-only connections
- ✅ Whitelist allowed operations (SELECT only)
- ✅ Set query timeouts
- ✅ Limit result set sizes
- ✅ Validate generated SQL
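
Several of these practices map directly onto the standard `sqlite3` module. The read-only URI mode used when opening the connection blocks writes, and `Cursor.fetchmany` caps result size. For timeouts, `Connection.set_progress_handler` can abort a long-running query. This is a minimal sketch. The `run_with_timeout` name and the default budget are illustrative assumptions.

```python
import sqlite3
import time


def run_with_timeout(conn: sqlite3.Connection, sql: str,
                     timeout_s: float = 1.0, max_rows: int = 100) -> list:
    """Execute `sql`, aborting it if it runs longer than `timeout_s` seconds."""
    deadline = time.monotonic() + timeout_s
    # The handler runs every 10,000 SQLite VM instructions;
    # a truthy return value interrupts the running statement.
    conn.set_progress_handler(lambda: time.monotonic() > deadline, 10_000)
    try:
        return conn.execute(sql).fetchmany(max_rows)
    finally:
        conn.set_progress_handler(None, 0)  # remove the handler


demo_conn = sqlite3.connect(":memory:")
# A recursive CTE that never terminates on its own
slow_sql = """
WITH RECURSIVE counter(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM counter)
SELECT COUNT(*) FROM counter
"""
try:
    run_with_timeout(demo_conn, slow_sql, timeout_s=0.2)
except sqlite3.OperationalError as exc:
    print("Aborted:", exc)
```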

**Error Handling:**
- ✅ Implement retry logic
- ✅ Provide helpful error messages
- ✅ Log failed queries for analysis
- ✅ Fall back to alternative approaches

**Optimization:**
- ✅ Cache schema embeddings
- ✅ Index database properly
- ✅ Use query result caching
- ✅ Batch similar queries

### Limitations

❌ **SQL Generation Challenges:**
- Complex JOINs may fail
- Ambiguous questions lead to wrong queries
- LLM may hallucinate table/column names
- Window functions and CTEs are hard
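
Hallucinated table or column names can be caught before execution. SQLite compiles a statement when it runs `EXPLAIN QUERY PLAN`, so unknown identifiers raise an error without scanning any data. This is a minimal sketch, and the `validate_sql` helper is hypothetical. The returned error text is the kind of message the notebook feeds to its error-recovery prompt.

```python
import sqlite3
from typing import Optional


def validate_sql(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Return None if `sql` compiles against the schema, else the error text."""
    try:
        conn.execute(f"EXPLAIN QUERY PLAN {sql}").fetchall()
        return None
    except sqlite3.Error as exc:
        return str(exc)


demo_conn = sqlite3.connect(":memory:")
demo_conn.execute("CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT, Milliseconds INTEGER)")

print(validate_sql(demo_conn, "SELECT AVG(Milliseconds) FROM Track"))  # None
print(validate_sql(demo_conn, "SELECT AVG(Duration) FROM Track"))      # no such column: Duration
```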

❌ **Schema Dependency:**
- Requires well-designed schema
- Poor naming leads to poor queries
- Schema changes break queries

### When to Use

Choose **SQL RAG** when:
- ✅ Data is structured in databases
- ✅ Need aggregations and analytics
- ✅ Precision is critical
- ✅ Users ask "how many", "total", "average"
- ✅ Existing database infrastructure

Choose **Vector RAG** when:
- ✅ Unstructured text documents
- ✅ Semantic similarity matters
- ✅ Fuzzy matching needed
- ✅ No clear schema

Choose **Hybrid** when:
- ✅ Both structured and unstructured data
- ✅ Diverse query types
- ✅ Maximum flexibility needed

### Extensions

**Advanced Features:**
- Multi-database support (PostgreSQL, MySQL, etc.)
- Query optimization hints
- Cached query plans
- Natural language query suggestions
- Query history and learning

**Production Enhancements:**
- Query approval workflows
- Result explanation (EXPLAIN)
- Performance monitoring
- A/B testing of prompts
- User feedback loops

---

**Complexity Rating:** ⭐⭐⭐⭐ (High - requires database knowledge + LLM integration)

**Production Readiness:** ⭐⭐⭐ (Medium - needs extensive testing and safety measures)

Continue to **15_graphrag.ipynb** for Graph-based Knowledge Retrieval!

## Cleanup

In [None]:
# Close database connection
conn.close()
print("✓ Database connection closed")