# 📊 Text-to-SQL: Generating SQL Queries from Natural Language

Objective: implement robust NL2SQL (Natural Language to SQL) systems using LLMs, with validation, security, and optimization.

- Duration: 90-120 min
- Difficulty: Medium/High
- Prerequisites: GenAI 01, intermediate SQL, database schemas

## 1. Base case: simple prompt

### 🏗️ **NL2SQL Architecture: From Research to Production**

**What is NL2SQL?**

Natural Language to SQL (NL2SQL) refers to systems that translate natural-language questions into executable SQL queries. It is one of the most valuable LLM use cases for democratizing access to data.

**Evolution of NL2SQL:**

```
2015-2018: Rule-Based Systems
├── Regex patterns + templates
├── Limited vocabulary
└── Accuracy: ~30-40%

2018-2020: Seq2Seq Models
├── LSTM/GRU encoders
├── Better generalization
└── Accuracy: ~50-60%

2020-2023: Pre-trained LLMs
├── BERT → T5 → GPT-3
├── Spider benchmark: 70%+
└── Fine-tuned models (SQLCoder)

2023-Present: GPT-4 + RAG
├── Zero-shot: 80-85%
├── Few-shot: 85-90%
└── Production-ready with validation
```

**Architecture Components:**

```
┌─────────────────────────────────────────────────────────────┐
│                    NL2SQL System                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Schema Retrieval (RAG)                                  │
│     ┌───────────────────────────────────────┐              │
│     │ Question → Embed → Search Vector DB   │              │
│     │ Return: Relevant tables + columns     │              │
│     └───────────────────────────────────────┘              │
│                        ↓                                     │
│  2. Prompt Construction                                     │
│     ┌───────────────────────────────────────┐              │
│     │ • Schema context (DDL + samples)      │              │
│     │ • Few-shot examples                   │              │
│     │ • Dialect-specific instructions       │              │
│     │ • Query constraints (LIMIT, timeout)  │              │
│     └───────────────────────────────────────┘              │
│                        ↓                                     │
│  3. LLM Generation                                          │
│     ┌───────────────────────────────────────┐              │
│     │ GPT-4o / Claude 3.5 / SQLCoder        │              │
│     │ Temperature: 0 (deterministic)        │              │
│     └───────────────────────────────────────┘              │
│                        ↓                                     │
│  4. Validation & Safety                                     │
│     ┌───────────────────────────────────────┐              │
│     │ • SQL parsing (sqlparse)              │              │
│     │ • Read-only check                     │              │
│     │ • Injection prevention                │              │
│     │ • EXPLAIN plan analysis               │              │
│     └───────────────────────────────────────┘              │
│                        ↓                                     │
│  5. Execution Engine                                        │
│     ┌───────────────────────────────────────┐              │
│     │ • Dry run mode (EXPLAIN only)         │              │
│     │ • Timeout enforcement                 │              │
│     │ • Result caching                      │              │
│     │ • Error handling                      │              │
│     └───────────────────────────────────────┘              │
│                        ↓                                     │
│  6. Post-Processing                                         │
│     ┌───────────────────────────────────────┐              │
│     │ • Result formatting                   │              │
│     │ • Aggregation                         │              │
│     │ • Visualization hints                 │              │
│     └───────────────────────────────────────┘              │
└─────────────────────────────────────────────────────────────┘
```

**Schema Representation Formats:**

**1. DDL (Data Definition Language):**
```sql
CREATE TABLE sales (
    sale_id INT PRIMARY KEY,
    customer_id INT NOT NULL REFERENCES customers(customer_id),
    product_id INT NOT NULL REFERENCES products(product_id),
    sale_date DATE NOT NULL,
    quantity INT CHECK (quantity > 0),
    total_amount DECIMAL(10,2),
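    status VARCHAR(20) DEFAULT 'pending'
);

-- PostgreSQL does not accept inline INDEX clauses; create indexes separately
CREATE INDEX idx_customer ON sales (customer_id);
CREATE INDEX idx_date ON sales (sale_date);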
```

**2. JSON Schema:**
```json
{
    "tables": {
        "sales": {
            "columns": {
                "sale_id": {"type": "INT", "primary_key": true},
                "customer_id": {"type": "INT", "foreign_key": "customers.customer_id"},
                "sale_date": {"type": "DATE", "nullable": false},
                "total_amount": {"type": "DECIMAL(10,2)"}
            },
            "indexes": ["idx_customer", "idx_date"],
            "row_count": 1500000,
            "description": "Stores all sales transactions"
        }
    }
}
```

**3. Natural Language Description:**
```python
schema_description = """
Database: E-commerce Sales System

Tables:
1. sales (1.5M rows)
   - sale_id: unique identifier for each sale
   - customer_id: links to customers table
   - product_id: links to products table
   - sale_date: when the sale occurred
   - quantity: number of items sold
   - total_amount: total price in USD
   - status: pending, completed, or cancelled

2. customers (50K rows)
   - customer_id: unique customer identifier
   - name: customer full name
   - email: contact email
   - country: customer location
   - signup_date: when they registered

3. products (5K rows)
   - product_id: unique product identifier
   - name: product name
   - category: electronics, clothing, books, etc.
   - price: unit price in USD
   - stock_quantity: current inventory
"""
```
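The three formats carry the same information, so one can be derived from another instead of being maintained by hand. The following is a minimal sketch, assuming a schema dictionary shaped like the JSON example above; the function name `describe_schema` and the exact output wording are illustrative choices, not part of any library.

```python
def describe_schema(schema: dict) -> str:
    """Render a JSON-style schema dict as a plain-text description for a prompt."""
    lines = []
    for i, (table, info) in enumerate(schema["tables"].items(), start=1):
        header = f"{i}. {table}"
        if "row_count" in info:
            header += f" ({info['row_count']:,} rows)"
        if info.get("description"):
            header += f": {info['description']}"
        lines.append(header)
        for column, spec in info["columns"].items():
            notes = [spec["type"]]
            if spec.get("primary_key"):
                notes.append("primary key")
            if "foreign_key" in spec:
                notes.append(f"references {spec['foreign_key']}")
            if spec.get("nullable") is False:
                notes.append("not null")
            lines.append(f"   - {column}: {', '.join(notes)}")
    return "\n".join(lines)


# Hypothetical input, trimmed from the JSON schema example
example_schema = {
    "tables": {
        "sales": {
            "columns": {
                "sale_id": {"type": "INT", "primary_key": True},
                "customer_id": {"type": "INT", "foreign_key": "customers.customer_id"},
                "sale_date": {"type": "DATE", "nullable": False},
            },
            "row_count": 1500000,
            "description": "Stores all sales transactions",
        }
    }
}

schema_text = describe_schema(example_schema)
print(schema_text)
```

Generating the description from one source of truth keeps the prompt text from drifting away from the real schema as tables change.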

**Schema Retrieval with RAG:**

```python
import chromadb

# ChromaDB for schema storage. No embedding function is passed here,
# so the collection embeds documents with Chroma's built-in default model.
chroma_client = chromadb.Client()
schema_collection = chroma_client.create_collection("database_schema")

# Index tables with descriptions
tables_data = [
    {
        "id": "sales",
        "text": "sales table stores transactions with customer_id, product_id, date, quantity, amount",
        "metadata": {"type": "table", "row_count": 1500000}
    },
    {
        "id": "customers",
        "text": "customers table has customer info: name, email, country, signup_date",
        "metadata": {"type": "table", "row_count": 50000}
    }
]

for table in tables_data:
    schema_collection.add(
        ids=[table["id"]],
        documents=[table["text"]],
        metadatas=[table["metadata"]]
    )

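# Query relevant schema
def get_relevant_schema(question: str, top_k: int = 3):
    """Retrieve the tables most relevant to a question."""
    # Never request more results than there are indexed tables
    n_results = min(top_k, schema_collection.count())
    results = schema_collection.query(
        query_texts=[question],
        n_results=n_results
    )

    # 'ids' holds one list per query text; a single query was sent
    return results['ids'][0]

# Example
question = "What are the top selling products last month?"
relevant = get_relevant_schema(question)
print(f"Relevant tables: {relevant}")
# Only 'sales' and 'customers' are indexed above, so both come back,
# ranked by similarity to the question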
```

**Model Comparison for NL2SQL:**

| Model | Accuracy (Spider) | Context | Cost | Best For |
|-------|------------------|---------|------|----------|
| GPT-4o | 85-90% | 128K | $$ | Complex queries, multi-table joins |
| GPT-4o-mini | 75-80% | 128K | $ | Simple queries, high volume |
| Claude 3.5 Sonnet | 88-92% | 200K | $$$ | Best accuracy, long schemas |
| Gemini 1.5 Flash | 80-85% | 1M | $ | Huge schemas, fast |
| SQLCoder 70B | 82-85% | 8K | Free | Self-hosted, fine-tuned |
| Llama 3 70B + LoRA | 75-80% | 8K | Free | Budget option, trainable |

**SQLCoder (Fine-tuned Open Source):**

```python
# Defog SQLCoder - specialized for SQL generation
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("defog/sqlcoder-70b-alpha")
model = AutoModelForCausalLM.from_pretrained(
    "defog/sqlcoder-70b-alpha",
    device_map="auto",
    load_in_8bit=True  # Quantization to reduce VRAM usage
)

prompt = f"""### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Database Schema
{schema}

### SQL
SELECT"""

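# With device_map="auto", send the inputs to the device that holds the model's first layers
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=300, do_sample=False)
# Decode only the newly generated tokens; the prompt already ends with "SELECT"
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
sql = "SELECT" + tokenizer.decode(new_tokens, skip_special_tokens=True)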
```

**Advantages of SQLCoder:**
- ✅ No API costs (run locally)
- ✅ Data privacy (nothing sent externally)
- ✅ Fine-tunable on your schema
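- ❌ Requires substantial GPU memory (70B weights take roughly 140 GB in 16-bit, about 70 GB in 8-bit, and about 35 GB in 4-bit, before activations and KV cache)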
- ❌ Slower inference than API calls

**Prompt Engineering for NL2SQL:**

```python
from string import Template

NL2SQL_PROMPT = Template("""
You are an expert $dialect SQL developer. Generate a query to answer the question.

### Database Schema
$schema

### Instructions
- Use table aliases for readability
- Add appropriate JOINs based on foreign keys
- Include LIMIT clause (max $max_rows rows)
- Use aggregate functions when appropriate
- Handle NULL values properly
- Return ONLY the SQL query, no explanations

### Examples
$examples

### Question
$question

### SQL Query
""")

# Usage
prompt = NL2SQL_PROMPT.substitute(
    dialect="PostgreSQL",
    schema=schema_ddl,
    max_rows=1000,
    examples=few_shot_examples,
    question="Show revenue by product category for Q4 2024"
)
```

**Handling Complex Queries:**

```python
# Multi-step query decomposition
# Assumes ask_llm(), nl_to_sql() and extract_sub_questions() are defined elsewhere
def complex_nl2sql(question: str, schema: str) -> str:
    """For complex questions, break into sub-queries"""
    
    decomposition_prompt = f"""
    Question: {question}
    
    Is this a complex question requiring multiple steps?
    If yes, break it down into sub-questions.
    If no, answer "simple".
    
    Format:
    Type: simple | complex
    Sub-questions (if complex):
    1. ...
    2. ...
    """
    
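    response = ask_llm(decomposition_prompt, temperature=0)

    # Read the "Type:" line rather than searching the whole response,
    # since the word "simple" can also appear inside a sub-question
    type_line = next(
        (ln for ln in response.lower().splitlines() if ln.strip().startswith("type:")),
        ""
    )
    if "simple" in type_line and "complex" not in type_line:
        return nl_to_sql(question, schema)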
    
    # Generate SQL for each sub-question
    sub_questions = extract_sub_questions(response)
    sub_sqls = [nl_to_sql(sq, schema) for sq in sub_questions]
    
    # Combine with CTEs
    combined_prompt = f"""
    Combine these sub-queries into a single query using CTEs:
    
    {chr(10).join(f"-- Step {i+1}: {sq}" for i, sq in enumerate(sub_questions))}
    {chr(10).join(sub_sqls)}
    
    Generate final query with CTEs:
    """
    
    return ask_llm(combined_prompt, temperature=0)

# Example
question = "Compare average order value between new and returning customers for each product category"
# → Breaks into: 1) Identify new vs returning, 2) Calculate AOV, 3) Group by category
sql = complex_nl2sql(question, schema)
```
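The decomposition code calls `extract_sub_questions` without showing it. One possible implementation, sketched here under the assumption that the model follows the numbered format requested in the prompt, is a small line parser. The sample response is hand-written for illustration and is not real model output.

```python
import re


def extract_sub_questions(response: str) -> list[str]:
    """Pull the numbered sub-questions out of a decomposition response."""
    sub_questions = []
    for line in response.splitlines():
        # Accept "1. text" as well as "1) text", with optional indentation
        match = re.match(r"\s*\d+[.)]\s+(.+)", line)
        if match:
            sub_questions.append(match.group(1).strip())
    return sub_questions


# Hypothetical response in the format the prompt asks for
sample_response = """Type: complex
Sub-questions (if complex):
1. Which customers are new and which are returning?
2. What is the average order value per customer group?
3. How does that value break down by product category?
"""

parsed = extract_sub_questions(sample_response)
for i, sub_question in enumerate(parsed, start=1):
    print(i, sub_question)
```

If the model drifts from the requested format (for example, bullet points instead of numbers), this parser returns an empty list, so callers may want to fall back to treating the question as simple in that case.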

---
**Author:** Luis J. Raigoso V. (LJRV)

### 🛡️ **Security & Validation: Preventing SQL Injection**

**Security Threats in NL2SQL:**

1. **SQL Injection via Prompt**: The user manipulates the question to inject code
2. **Data Exfiltration**: The query extracts sensitive data the user is not authorized to see
3. **Denial of Service**: Expensive queries that exhaust resources
4. **Schema Exposure**: Revealing the structure of the database

**Defense Layers:**

```
┌────────────────────────────────────────────┐
│         Security Architecture              │
├────────────────────────────────────────────┤
│                                            │
│  Layer 1: Input Sanitization              │
│  ├─ Remove SQL keywords from question     │
│  ├─ Escape special characters             │
│  └─ Length limits                          │
│                                            │
│  Layer 2: LLM Guardrails                  │
│  ├─ System prompt constraints             │
│  ├─ Read-only emphasis                    │
│  └─ Whitelist operations                   │
│                                            │
│  Layer 3: SQL Parsing & Analysis          │
│  ├─ sqlparse validation                   │
│  ├─ AST inspection                        │
│  └─ Keyword blacklist                      │
│                                            │
│  Layer 4: Query Plan Analysis             │
│  ├─ EXPLAIN cost estimation               │
│  ├─ Timeout prediction                    │
│  └─ Resource limits                        │
│                                            │
│  Layer 5: Execution Sandbox               │
│  ├─ Read-only user                        │
│  ├─ Row limit enforcement                 │
│  ├─ Timeout killer                        │
│  └─ Query logging                          │
└────────────────────────────────────────────┘
```

**Implementation:**

**1. Input Sanitization:**

```python
import re
from typing import Tuple

def sanitize_question(question: str) -> Tuple[str, bool]:
    """Remove potentially dangerous patterns from user input"""
    
    # Check length
    if len(question) > 500:
        return "", False
    
    # Suspicious patterns
    dangerous_patterns = [
        r";\s*(drop|delete|update|insert|alter|create|truncate)",  # Multiple statements
        r"--",  # SQL comments
        r"/\*.*\*/",  # Block comments
        r"union\s+select",  # Union injection
        r"exec\s*\(",  # Stored proc execution
        r"xp_cmdshell",  # System commands
    ]
    
    question_lower = question.lower()
    for pattern in dangerous_patterns:
        if re.search(pattern, question_lower):
            return "", False
    
    # Remove excessive whitespace
    cleaned = re.sub(r'\s+', ' ', question).strip()
    
    return cleaned, True

# Example
test_questions = [
    "Show top 10 sales",  # Safe
    "Show sales; DROP TABLE customers; --",  # SQL injection
    "SELECT * FROM users UNION SELECT password FROM admin",  # Union injection
]

for q in test_questions:
    cleaned, safe = sanitize_question(q)
    print(f"{'✅' if safe else '❌'} {q[:50]}")
```

**2. SQL Validation with sqlparse:**

```python
import re
from typing import Tuple

import sqlparse

def _has_word(word: str, text: str) -> bool:
    """True if `word` occurs as a whole word (so 'top' does not match 'laptop')."""
    return re.search(rf'\b{re.escape(word)}\b', text) is not None

def validate_sql_safety(sql: str) -> Tuple[bool, str]:
    """Comprehensive SQL safety validation"""

    # Parse SQL, ignoring empty fragments left by a trailing semicolon
    try:
        parsed = [s for s in sqlparse.parse(sql) if str(s).strip()]
    except Exception as e:
        return False, f"Parsing error: {e}"
    if not parsed:
        return False, "Invalid SQL syntax"

    # 1. Check for multiple statements first (stacked-query injection)
    if len(parsed) > 1:
        return False, "Multiple statements not allowed"

    statement = parsed[0]

    # 2. Check if read-only (SELECT only)
    if statement.get_type() != 'SELECT':
        return False, f"Only SELECT queries allowed, got: {statement.get_type()}"

    # 3. Check for nested dangerous operations (whole words only)
    sql_lower = sql.lower()
    write_operations = ['insert', 'update', 'delete', 'drop', 'create', 'alter', 'truncate', 'replace']

    for op in write_operations:
        if _has_word(op, sql_lower):
            return False, f"Forbidden operation: {op.upper()}"

    # 4. Check for system functions/procedures
    dangerous_functions = [
        'exec', 'execute', 'xp_cmdshell', 'sp_executesql',
        'dbms_output', 'utl_file', 'load_file', 'outfile'
    ]

    for func in dangerous_functions:
        if _has_word(func, sql_lower):
            return False, f"Forbidden function: {func}"

    # 5. Verify LIMIT clause exists (prevent large result sets)
    has_limit = (
        _has_word('limit', sql_lower)
        or re.search(r'\bfetch\s+first\b', sql_lower) is not None
        or re.search(r'\bselect\s+top\b', sql_lower) is not None
    )
    if not has_limit:
        return False, "Query must include LIMIT clause"

    # 6. Check for UNION (common injection vector)
    if _has_word('union', sql_lower):
        # UNION is valid but requires manual review
        return False, "UNION queries require manual approval"

    return True, "Query validated"

# Test cases
test_sqls = [
    "SELECT * FROM sales LIMIT 100",  # Safe
    "SELECT * FROM sales; DROP TABLE customers;",  # Multiple statements
    "SELECT * FROM sales WHERE id = (SELECT password FROM users)",  # Subquery; rejected here for the missing LIMIT
    "DELETE FROM sales WHERE id = 1",  # Write operation
    "SELECT * FROM sales",  # No LIMIT
]

for sql in test_sqls:
    safe, message = validate_sql_safety(sql)
    print(f"{'✅' if safe else '❌'} {message}")
    print(f"   {sql[:60]}\n")
```

**3. Query Cost Estimation (EXPLAIN):**

```python
from typing import Tuple

import psycopg2

def estimate_query_cost(sql: str, conn) -> Tuple[bool, dict]:
    """Use EXPLAIN to estimate query cost before execution"""
    
    try:
        # Get query plan
        explain_sql = f"EXPLAIN (FORMAT JSON, ANALYZE false) {sql}"
        
        with conn.cursor() as cur:
            cur.execute(explain_sql)
            plan = cur.fetchone()[0][0]
        
        # Extract cost metrics
        total_cost = plan['Plan']['Total Cost']
        estimated_rows = plan['Plan']['Plan Rows']
        
        # Define thresholds
        MAX_COST = 10000  # Arbitrary units
        MAX_ROWS = 100000
        
        if total_cost > MAX_COST:
            return False, {
                "reason": "Query cost too high",
                "cost": total_cost,
                "threshold": MAX_COST
            }
        
        if estimated_rows > MAX_ROWS:
            return False, {
                "reason": "Too many rows",
                "rows": estimated_rows,
                "threshold": MAX_ROWS
            }
        
        return True, {
            "cost": total_cost,
            "rows": estimated_rows,
            "approved": True
        }
        
    except Exception as e:
        # A failed statement leaves the psycopg2 transaction aborted; reset it
        conn.rollback()
        return False, {"error": str(e)}

# Example usage
sql = "SELECT * FROM large_table WHERE date > '2020-01-01' LIMIT 1000"
conn = psycopg2.connect("dbname=mydb")

approved, metrics = estimate_query_cost(sql, conn)
if approved:
    print(f"✅ Query approved: {metrics}")
else:
    print(f"❌ Query rejected: {metrics}")
```
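The top-level `Total Cost` does not say which part of the plan is expensive. PostgreSQL's JSON plan is a tree in which every node carries a `Node Type` and lists its children under `Plans`, so the plan can be walked to flag large sequential scans. The sketch below assumes that structure; the function name `find_seq_scans`, the row threshold, and the sample plan are illustrative and hand-written rather than captured from a real database.

```python
def find_seq_scans(node: dict, min_rows: int = 100_000) -> list:
    """Recursively collect sequential scans whose estimated row count is large."""
    hits = []
    if node.get("Node Type") == "Seq Scan" and node.get("Plan Rows", 0) >= min_rows:
        hits.append((node.get("Relation Name"), node["Plan Rows"]))
    # Child nodes, when present, live under the "Plans" key
    for child in node.get("Plans", []):
        hits.extend(find_seq_scans(child, min_rows))
    return hits


# Hand-written plan fragment in the shape returned by EXPLAIN (FORMAT JSON)
sample_plan = {
    "Node Type": "Limit",
    "Total Cost": 25000.0,
    "Plan Rows": 1000,
    "Plans": [
        {
            "Node Type": "Hash Join",
            "Total Cost": 24000.0,
            "Plan Rows": 1500000,
            "Plans": [
                {"Node Type": "Seq Scan", "Relation Name": "sales",
                 "Total Cost": 20000.0, "Plan Rows": 1500000},
                {"Node Type": "Hash", "Total Cost": 100.0, "Plan Rows": 5000,
                 "Plans": [
                     {"Node Type": "Seq Scan", "Relation Name": "products",
                      "Total Cost": 90.0, "Plan Rows": 5000},
                 ]},
            ],
        }
    ],
}

large_scans = find_seq_scans(sample_plan)
print(large_scans)  # [('sales', 1500000)]
```

In the earlier function this would be called as `find_seq_scans(plan['Plan'])`. A flagged scan is a hint rather than a verdict: a sequential scan over a small table is often the cheapest option.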

**4. Read-Only Database User:**

```sql
-- Create read-only role for NL2SQL execution
CREATE ROLE nl2sql_readonly;

-- Grant SELECT only on specific schemas
GRANT USAGE ON SCHEMA public TO nl2sql_readonly;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO nl2sql_readonly;

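-- Defensive step: PostgreSQL has no DENY, so this only removes write privileges
-- that were granted directly to this role at some point
REVOKE INSERT, UPDATE, DELETE, TRUNCATE ON ALL TABLES IN SCHEMA public FROM nl2sql_readonly;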

-- Create user
CREATE USER nl2sql_user WITH PASSWORD 'secure_password';
GRANT nl2sql_readonly TO nl2sql_user;

-- Set session limits
ALTER ROLE nl2sql_user SET statement_timeout = '10s';
ALTER ROLE nl2sql_user SET lock_timeout = '5s';
```

**5. Execution with Timeout & Limits:**

```python
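import signal
from contextlib import contextmanager

import pandas as pd

class QueryTimeout(Exception):
    pass

@contextmanager
def query_timeout(seconds: int):
    """Context manager to enforce query timeout.

    Uses SIGALRM, so it works only on Unix and only in the main thread.
    Keep the database-side statement_timeout in place as well.
    """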
    def timeout_handler(signum, frame):
        raise QueryTimeout(f"Query exceeded {seconds}s timeout")
    
    # Set alarm
    old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)
    
    try:
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old_handler)

def execute_safe_query(sql: str, conn, max_rows: int = 1000, timeout_sec: int = 10):
    """Execute query with safety constraints"""

    # Force LIMIT if not present. This must run before validation,
    # because validate_sql_safety() rejects queries without a LIMIT.
    sql = sql.strip().rstrip(';')
    if 'limit' not in sql.lower().split():
        sql += f" LIMIT {max_rows}"

    # Validate SQL
    safe, message = validate_sql_safety(sql)
    if not safe:
        raise ValueError(f"Unsafe query: {message}")

    try:
        with query_timeout(timeout_sec):
            df = pd.read_sql_query(sql, conn)
        
        # Additional row limit check
        if len(df) > max_rows:
            df = df.head(max_rows)
            print(f"⚠️ Results truncated to {max_rows} rows")
        
        return df
        
    except QueryTimeout as e:
        raise TimeoutError(f"Query timeout: {e}")
    except Exception as e:
        raise RuntimeError(f"Query execution failed: {e}")

# Example
sql = "SELECT * FROM sales WHERE sale_date >= '2024-01-01'"
df = execute_safe_query(sql, conn, max_rows=500, timeout_sec=5)
```
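Another way to cap result size is to wrap the generated query in an outer `SELECT ... LIMIT`, which avoids editing the generated SQL itself. The sketch below is an illustration under assumptions: `wrap_with_limit` and the alias `nl2sql_q` are hypothetical names, and an in-memory SQLite database stands in for the production database so the example runs without a server.

```python
import sqlite3


def wrap_with_limit(sql: str, max_rows: int) -> str:
    """Cap the result size by wrapping the query in an outer SELECT ... LIMIT."""
    inner = sql.strip().rstrip(";").strip()
    # int() guards against a non-numeric limit being interpolated
    return f"SELECT * FROM ({inner}) AS nl2sql_q LIMIT {int(max_rows)}"


# Demo database with 50 rows (stand-in for the real sales table)
conn_demo = sqlite3.connect(":memory:")
conn_demo.execute("CREATE TABLE sales (sale_id INTEGER PRIMARY KEY, total_amount REAL)")
conn_demo.executemany(
    "INSERT INTO sales (total_amount) VALUES (?)",
    [(float(i),) for i in range(50)],
)

capped = wrap_with_limit("SELECT sale_id, total_amount FROM sales;", max_rows=10)
rows = conn_demo.execute(capped).fetchall()
print(len(rows))  # 10

# An inner LIMIT smaller than the cap still applies
small = conn_demo.execute(wrap_with_limit("SELECT * FROM sales LIMIT 3", 10)).fetchall()
print(len(small))  # 3
```

An `ORDER BY` inside the wrapped query is not formally guaranteed to survive the outer `SELECT`, so this is best treated as a safety cap rather than a replacement for a `LIMIT` in the generated SQL.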

**6. Audit Logging:**

```python
import hashlib
import json
import logging
import os
from datetime import datetime, timezone
from typing import Optional

# Configure structured logging
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s'
)
logger = logging.getLogger('nl2sql')

def log_nl2sql_query(
    user_id: str,
    question: str,
    sql_generated: str,
    execution_status: str,
    rows_returned: int,
    execution_time_ms: float,
    error: Optional[str] = None
):
    """Log all NL2SQL interactions for audit"""

    log_entry = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "user_id": user_id,
        "question_hash": hashlib.sha256(question.encode()).hexdigest()[:16],
        "question_length": len(question),
        "sql_hash": hashlib.sha256(sql_generated.encode()).hexdigest()[:16],
        "execution_status": execution_status,  # success | validation_failed | timeout | error
        "rows_returned": rows_returned,
        "execution_time_ms": execution_time_ms,
        "error": error
    }
    
    # Only log full text in dev environment
    if os.getenv("ENV") == "dev":
        log_entry["question"] = question[:200]
        log_entry["sql"] = sql_generated[:500]
    
    logger.info(json.dumps(log_entry))

# Usage
log_nl2sql_query(
    user_id="user_123",
    question="Show top 10 customers by revenue",
    sql_generated="SELECT customer_id, SUM(total_amount) FROM sales GROUP BY customer_id ORDER BY 2 DESC LIMIT 10",
    execution_status="success",
    rows_returned=10,
    execution_time_ms=45.2
)
```

**7. Human-in-the-Loop for High-Risk Queries:**

```python
import re
from enum import Enum

class RiskLevel(Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"

def assess_query_risk(sql: str) -> RiskLevel:
    """Determine if query needs human approval"""

    sql_lower = sql.lower()

    # High risk indicators (regex patterns matched against the SQL text)
    high_risk = [
        r'\bunion\b',             # Combined result sets
        r'\bcross\s+join\b',      # Cartesian products
        r'\bwith\s+recursive\b',  # Recursive CTEs
        r'\bover\s*\(',           # Window functions
    ]

    # Medium risk indicators
    medium_risk = [
        r'\(\s*select\b',    # Subqueries
        r'\bhaving\b',
        r'\bcase\s+when\b',
    ]

    if any(re.search(pattern, sql_lower) for pattern in high_risk):
        return RiskLevel.HIGH

    # More than two JOINs also counts as medium risk
    join_count = len(re.findall(r'\bjoin\b', sql_lower))
    if join_count > 2 or any(re.search(pattern, sql_lower) for pattern in medium_risk):
        return RiskLevel.MEDIUM

    return RiskLevel.LOW

def execute_with_approval(question: str, sql: str, conn):
    """Execute query with risk-based approval"""
    
    risk = assess_query_risk(sql)
    
    if risk == RiskLevel.HIGH:
        print(f"⚠️ HIGH RISK query requires manual approval:")
        print(f"Question: {question}")
        print(f"SQL: {sql}\n")
        
        approval = input("Approve execution? (yes/no): ")
        if approval.lower() != "yes":
            return None
    
    elif risk == RiskLevel.MEDIUM:
        print(f"⚠️ MEDIUM RISK query - review recommended")
        print(f"SQL: {sql[:100]}...\n")
    
    # Execute approved query
    return execute_safe_query(sql, conn)
```

**Security Checklist:**

- ✅ Input sanitization (length, dangerous patterns)
- ✅ SQL parsing validation (sqlparse)
- ✅ Read-only database user
- ✅ LIMIT clause enforcement
- ✅ Query timeout (10s default)
- ✅ EXPLAIN cost estimation
- ✅ Audit logging (all queries)
- ✅ Human approval for high-risk queries
- ✅ Rate limiting per user
- ✅ Row count limits (1000 default)
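
The checklist includes per-user rate limiting, which none of the snippets in this section implement. One way to approach it, sketched here as an in-process sliding window, is shown below. The class name, default limits, and injectable clock are assumptions made for illustration; a multi-process deployment would need a shared store instead of a local dictionary.

```python
import time
from collections import defaultdict, deque


class SlidingWindowRateLimiter:
    """Allow at most `max_requests` per user within a rolling `window_sec` window."""

    def __init__(self, max_requests: int = 10, window_sec: float = 60.0, clock=time.monotonic):
        self.max_requests = max_requests
        self.window_sec = window_sec
        self.clock = clock  # injectable so the demo below can control time
        self._events = defaultdict(deque)

    def allow(self, user_id: str) -> bool:
        now = self.clock()
        events = self._events[user_id]
        # Drop timestamps that have fallen out of the window
        while events and now - events[0] >= self.window_sec:
            events.popleft()
        if len(events) >= self.max_requests:
            return False
        events.append(now)
        return True


# Demo with a fake clock: 2 requests allowed per 60 seconds
fake_now = [0.0]
limiter = SlidingWindowRateLimiter(max_requests=2, window_sec=60, clock=lambda: fake_now[0])

decisions = [limiter.allow("user_123") for _ in range(3)]
fake_now[0] = 61.0  # the window has passed, so the user is allowed again
decisions.append(limiter.allow("user_123"))
other_user = limiter.allow("user_456")  # limits are tracked per user

print(decisions, other_user)
```

The check would sit in front of the LLM call, since generation is usually the most expensive step to protect.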


### 🎯 **Advanced Techniques: Few-Shot, Chain-of-Thought & Self-Correction**

**1. Few-Shot Learning for NL2SQL**

Few-shot examples typically improve accuracy by showing the LLM the desired output format and how to handle edge cases.

**Example Selection Strategies:**

```python
from sentence_transformers import SentenceTransformer
import numpy as np

class FewShotSelector:
    """Dynamically select most relevant examples for each question"""
    
    def __init__(self, example_pool: list):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.examples = example_pool
        
        # Pre-compute unit-length embeddings for all examples
        self.embeddings = self.model.encode(
            [ex['question'] for ex in self.examples],
            normalize_embeddings=True
        )
    
    def select_examples(self, question: str, k: int = 3) -> list:
        """Select k most similar examples"""
        
        # Embed query (also unit length)
        query_embedding = self.model.encode([question], normalize_embeddings=True)
        
        # The dot product of unit vectors equals cosine similarity
        similarities = np.dot(self.embeddings, query_embedding.T).flatten()
        
        # Get top k indices
        top_indices = np.argsort(similarities)[-k:][::-1]
        
        return [self.examples[i] for i in top_indices]

# Example pool with diverse query types
example_pool = [
    {
        "question": "How many orders were placed yesterday?",
        "sql": "SELECT COUNT(*) FROM orders WHERE DATE(created_at) = CURRENT_DATE - INTERVAL '1 day';",
        "explanation": "Count aggregation with date filter"
    },
    {
        "question": "Top 5 customers by total spend",
        "sql": """
SELECT 
    c.customer_id,
    c.name,
    SUM(o.total_amount) as total_spent
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.name
ORDER BY total_spent DESC
LIMIT 5;
""",
        "explanation": "Aggregation with JOIN and ORDER BY"
    },
    {
        "question": "Products without any sales this month",
        "sql": """
SELECT p.product_id, p.name
FROM products p
LEFT JOIN orders o ON p.product_id = o.product_id 
    AND DATE_TRUNC('month', o.created_at) = DATE_TRUNC('month', CURRENT_DATE)
WHERE o.order_id IS NULL
LIMIT 100;
""",
        "explanation": "LEFT JOIN with NULL check for missing records"
    },
    {
        "question": "Running total of sales by date",
        "sql": """
SELECT 
    DATE(created_at) as sale_date,
    SUM(total_amount) as daily_total,
    SUM(SUM(total_amount)) OVER (ORDER BY DATE(created_at)) as running_total
FROM orders
GROUP BY DATE(created_at)
ORDER BY sale_date
LIMIT 365;
""",
        "explanation": "Window function for running totals"
    },
    {
        "question": "Customers who haven't ordered in 90 days",
        "sql": """
SELECT 
    c.customer_id,
    c.name,
    MAX(o.created_at) as last_order_date,
    CURRENT_DATE - MAX(o.created_at)::date as days_since_order
FROM customers c
LEFT JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.name
HAVING MAX(o.created_at) < CURRENT_DATE - INTERVAL '90 days'
    OR MAX(o.created_at) IS NULL
ORDER BY last_order_date NULLS FIRST
LIMIT 100;
""",
        "explanation": "HAVING clause with date comparison"
    }
]

# Usage
selector = FewShotSelector(example_pool)
question = "Show customers with no recent purchases"
relevant_examples = selector.select_examples(question, k=2)

print("Selected examples:")
for ex in relevant_examples:
    print(f"\nQ: {ex['question']}")
    print(f"SQL: {ex['sql'][:100]}...")
```
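Ranking by dot product matches ranking by cosine similarity only when the embeddings have unit length. The toy example below, written in plain Python with made-up three-dimensional vectors, shows how the two can disagree when one vector is simply longer. The function names and vectors are illustrative assumptions, not output from an embedding model.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query, candidates, k=2):
    """Return the names of the k candidates most similar to the query."""
    scored = sorted(
        ((cosine_similarity(query, vec), name) for name, vec in candidates.items()),
        reverse=True,
    )
    return [name for _, name in scored[:k]]


# Made-up embeddings; the third vector is long but points less directly at the query
toy_embeddings = {
    "count_orders": [1.0, 0.0, 0.0],
    "top_customers": [0.6, 0.8, 0.0],
    "inactive_customers": [5.0, 5.0, 5.0],
}
query_vec = [0.0, 1.0, 0.0]

raw_dot = {
    name: sum(q * v for q, v in zip(query_vec, vec))
    for name, vec in toy_embeddings.items()
}
best_by_dot = max(raw_dot, key=raw_dot.get)
best_by_cosine = top_k(query_vec, toy_embeddings, k=1)[0]

print(best_by_dot)     # inactive_customers
print(best_by_cosine)  # top_customers
```

The raw dot product rewards vector length, while cosine similarity rewards direction only, which is usually what example selection needs.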

**Few-Shot Prompt Construction:**

```python
def build_fewshot_prompt(question: str, schema: str, examples: list, dialect: str = "PostgreSQL") -> str:
    """Construct optimized few-shot prompt"""
    
    examples_text = "\n\n".join([
        f"Example {i+1}:\n"
        f"Question: {ex['question']}\n"
        f"SQL:\n{ex['sql'].strip()}\n"
        f"Why: {ex['explanation']}"
        for i, ex in enumerate(examples)
    ])
    
    prompt = f"""You are an expert {dialect} developer. Generate SQL to answer questions.

Database Schema:
{schema}

{examples_text}

Now generate SQL for this question:
Question: {question}

Requirements:
- Use table aliases (e.g., FROM orders o)
- Include LIMIT clause (max 1000 rows)
- Handle NULL values appropriately
- Use appropriate indexes mentioned in schema
- Return ONLY the SQL query

SQL:"""
    
    return prompt

# Example
question = "Which product categories have declining sales this quarter vs last quarter?"
examples = selector.select_examples(question, k=2)
prompt = build_fewshot_prompt(question, schema_ddl, examples)

sql = ask_llm(prompt, temperature=0)
```
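Even when the prompt asks for only the SQL query, models often wrap the answer in a Markdown code fence or add a sentence around it. A small post-processing step, sketched below under that assumption, keeps such replies from failing validation. The helper name `extract_sql` and the sample replies are hypothetical.

```python
import re

FENCE = "`" * 3  # Markdown code fence marker, built here to keep this example readable
FENCE_RE = re.compile(FENCE + r"(?:sql)?\s*(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)


def extract_sql(response: str) -> str:
    """Return the SQL statement from an LLM reply, dropping fences and chatter."""
    match = FENCE_RE.search(response)
    text = match.group(1) if match else response
    # Keep everything from the first SELECT or WITH keyword onward
    start = re.search(r"\b(select|with)\b", text, re.IGNORECASE)
    if start:
        text = text[start.start():]
    return text.strip()


# Hand-written replies in two common shapes
reply_fenced = f"Here is the query:\n{FENCE}sql\nSELECT 1 AS x LIMIT 1;\n{FENCE}\nHope this helps."
reply_plain = "SQL: SELECT name FROM customers LIMIT 5"

print(extract_sql(reply_fenced))  # SELECT 1 AS x LIMIT 1;
print(extract_sql(reply_plain))   # SELECT name FROM customers LIMIT 5
```

This is a heuristic: in an unfenced reply, an ordinary English "with" before the query would be mistaken for the start of a CTE, so the extracted text should still go through the validation layer.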

---

**2. Chain-of-Thought for Complex Queries**

Break complex questions into steps, generating SQL iteratively.

```python
def nl2sql_with_cot(question: str, schema: str) -> dict:
    """Generate SQL with chain-of-thought reasoning"""
    
    # Step 1: Decompose question
    decomposition_prompt = f"""
Analyze this database question step-by-step:

Question: {question}

Schema: {schema}

Break down the question into logical steps:
1. What tables are needed?
2. What joins are required?
3. What filters should be applied?
4. What aggregations are needed?
5. What is the final output format?

Provide step-by-step reasoning:
"""
    
    reasoning = ask_llm(decomposition_prompt, temperature=0)
    
    # Step 2: Generate SQL with reasoning
    sql_prompt = f"""
Based on this reasoning:

{reasoning}

Generate the final SQL query for PostgreSQL:

Requirements:
- Include comments for each major section
- Use CTEs for complex logic
- Add LIMIT clause

SQL:
"""
    
    sql = ask_llm(sql_prompt, temperature=0)
    
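    # Heuristic: CTEs are introduced by WITH, window functions use OVER (...)
    sql_lower = sql.lower()
    uses_cte = "with" in sql_lower.split()
    uses_window = " over (" in sql_lower or " over(" in sql_lower

    return {
        "sql": sql,
        "reasoning": reasoning,
        "complexity": "high" if uses_cte or uses_window else "medium"
    }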

# Example
question = """
Compare the average order value and customer retention rate 
between customers acquired via different marketing channels, 
but only for customers who have made at least 3 purchases
"""

result = nl2sql_with_cot(question, schema_ddl)
print("Reasoning:")
print(result["reasoning"])
print("\nGenerated SQL:")
print(result["sql"])
```

---

**3. Self-Correction with Execution Feedback**

Let the LLM fix its own errors by providing execution feedback.

```python
def nl2sql_with_self_correction(question: str, schema: str, conn, max_attempts: int = 3) -> dict:
    """Generate SQL with self-correction loop"""
    
    attempts = []
    
    for attempt in range(max_attempts):
        # Generate SQL
        if attempt == 0:
            # First attempt: standard generation
            sql = nl_to_sql(question, schema)
        else:
            # Subsequent attempts: include error feedback
            correction_prompt = f"""
Previous attempt failed with error:

Question: {question}
Schema: {schema}

Previous SQL (FAILED):
{attempts[-1]['sql']}

Error:
{attempts[-1]['error']}

Generate a corrected SQL query that fixes this error.
Analyze the error message and adjust the query accordingly.

Corrected SQL:
"""
            sql = ask_llm(correction_prompt, temperature=0.1)
        
        # Validate and execute
        safe, validation_msg = validate_sql_safety(sql)
        
        if not safe:
            attempts.append({
                "attempt": attempt + 1,
                "sql": sql,
                "error": f"Validation failed: {validation_msg}",
                "success": False
            })
            continue
        
        # Try execution
        try:
            df = pd.read_sql_query(sql, conn)
            
            # Success!
            return {
                "sql": sql,
                "result": df,
                "attempts": attempt + 1,
                "success": True,
                "history": attempts
            }
            
        except Exception as e:
            # Reset the failed transaction so the next attempt can run
            conn.rollback()
            error_msg = str(e)
            attempts.append({
                "attempt": attempt + 1,
                "sql": sql,
                "error": error_msg,
                "success": False
            })
    
    # All attempts failed
    return {
        "sql": None,
        "result": None,
        "attempts": max_attempts,
        "success": False,
        "history": attempts,
        "final_error": "Max retry attempts reached"
    }

# Example
question = "Show monthly revenue trend for 2024"
result = nl2sql_with_self_correction(question, schema_ddl, conn)

if result["success"]:
    print(f"✅ Success after {result['attempts']} attempt(s)")
    print(result["result"].head())
else:
    print(f"❌ Failed after {result['attempts']} attempts")
    for hist in result["history"]:
        print(f"\nAttempt {hist['attempt']}: {hist['error']}")
```
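
To see the loop's control flow without calling a real model, the sketch below swaps the LLM for a hypothetical stub (`fake_llm`) that returns a query with a misspelled column and only corrects it once the database error appears in the prompt. The stub, the `self_correct` helper, and the demo table are illustrative assumptions; only the feedback pattern mirrors the function above.

```python
import sqlite3

def fake_llm(prompt: str) -> str:
    """Hypothetical stand-in for an LLM: it only fixes the query after seeing the error."""
    if "no such column" in prompt:
        return "SELECT SUM(total) AS revenue FROM ventas"
    return "SELECT SUM(totl) AS revenue FROM ventas"  # deliberate typo

def self_correct(question: str, conn: sqlite3.Connection, max_attempts: int = 3) -> dict:
    """Retry generation, feeding the failed SQL and the database error back into the prompt."""
    history = []
    prompt = f"Question: {question}\nSQL:"
    for attempt in range(1, max_attempts + 1):
        sql = fake_llm(prompt)
        try:
            rows = conn.execute(sql).fetchall()
            return {"success": True, "sql": sql, "rows": rows,
                    "attempts": attempt, "history": history}
        except sqlite3.Error as exc:
            history.append({"attempt": attempt, "sql": sql, "error": str(exc)})
            # Next prompt carries the failed query and the error message
            prompt = (
                f"Question: {question}\n"
                f"Previous SQL (failed): {sql}\n"
                f"Error: {exc}\n"
                "Corrected SQL:"
            )
    return {"success": False, "sql": None, "rows": None,
            "attempts": max_attempts, "history": history}

demo_conn = sqlite3.connect(":memory:")
demo_conn.execute("CREATE TABLE ventas (venta_id INT, total REAL)")
demo_conn.executemany("INSERT INTO ventas VALUES (?, ?)", [(1, 200.0), (2, 50.0)])

outcome = self_correct("Total revenue", demo_conn)
print(outcome["attempts"], outcome["rows"])
```

The history list keeps every failed attempt, which is useful both for debugging and as training data for later prompt tuning.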

---

**4. SQL Optimization with LLM**

```python
def optimize_generated_sql(sql: str, schema: str, conn) -> dict:
    """Analyze and optimize generated SQL"""
    
    # Get EXPLAIN plan
    try:
        explain_query = f"EXPLAIN (FORMAT JSON, ANALYZE false) {sql}"
        with conn.cursor() as cur:
            cur.execute(explain_query)
            plan = cur.fetchone()[0][0]
        
        cost = plan['Plan']['Total Cost']
        rows = plan['Plan']['Plan Rows']
        
        # If cost is high, ask LLM to optimize
        if cost > 1000:  # Threshold
            optimization_prompt = f"""
This SQL query has high cost ({cost:.0f}) and may be slow:

{sql}

Schema with indexes:
{schema}

Query execution plan shows:
- Total Cost: {cost}
- Estimated Rows: {rows}

Optimize this query:
1. Rewrite predicates so that the existing indexes can be used
2. Rewrite inefficient subqueries
3. Consider materialized CTEs
4. Optimize JOIN order

Provide optimized SQL:
"""
            
            optimized_sql = ask_llm(optimization_prompt, temperature=0)
            
            # Compare costs
            explain_optimized = f"EXPLAIN (FORMAT JSON, ANALYZE false) {optimized_sql}"
            with conn.cursor() as cur:
                cur.execute(explain_optimized)
                plan_opt = cur.fetchone()[0][0]
            
            cost_opt = plan_opt['Plan']['Total Cost']
            
            return {
                "original_sql": sql,
                "original_cost": cost,
                "optimized_sql": optimized_sql,
                "optimized_cost": cost_opt,
                "improvement": f"{((cost - cost_opt) / cost * 100):.1f}%"
            }
        
        return {
            "sql": sql,
            "cost": cost,
            "optimization": "Not needed (cost acceptable)"
        }
        
    except Exception as e:
        return {"error": str(e)}

# Example
sql = "SELECT * FROM orders o JOIN customers c ON o.customer_id = c.customer_id WHERE o.date > '2024-01-01'"
result = optimize_generated_sql(sql, schema_ddl, conn)

if "improvement" in result:
    print(f"Optimization improved cost by {result['improvement']}")
```
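
The function above depends on PostgreSQL's `EXPLAIN (FORMAT JSON)`. SQLite exposes no cost estimate, but `EXPLAIN QUERY PLAN` still reveals full-table scans, which is often enough to flag a query for review. A minimal sketch under that assumption; the helper names, table, and index are illustrative.

```python
import sqlite3

def plan_details(conn: sqlite3.Connection, sql: str) -> list[str]:
    """Return the human-readable 'detail' text of each EXPLAIN QUERY PLAN row."""
    rows = conn.execute(f"EXPLAIN QUERY PLAN {sql}").fetchall()
    return [row[-1] for row in rows]  # 'detail' is the last column

def has_full_scan(conn: sqlite3.Connection, sql: str) -> bool:
    """True when any plan step scans a whole table instead of searching an index."""
    return any(detail.startswith("SCAN") for detail in plan_details(conn, sql))

plan_conn = sqlite3.connect(":memory:")
plan_conn.execute("CREATE TABLE ventas (venta_id INT, cliente_id INT, total REAL)")
query = "SELECT total FROM ventas WHERE cliente_id = 10"

scan_before = has_full_scan(plan_conn, query)
plan_conn.execute("CREATE INDEX idx_ventas_cliente ON ventas (cliente_id)")
scan_after = has_full_scan(plan_conn, query)

print(scan_before, scan_after)
```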

---

**5. Multi-Turn Conversation (Refinement)**

```python
class NL2SQLConversation:
    """Multi-turn conversation for query refinement"""
    
    def __init__(self, schema: str, conn):
        self.schema = schema
        self.conn = conn
        self.history = []
    
    def query(self, question: str) -> dict:
        """Process question with conversation context"""
        
        # Build context from history
        context = self._build_context()
        
        # Generate SQL
        prompt = f"""
{context}

Schema: {self.schema}

User Question: {question}

Generate SQL considering previous conversation context.

SQL:
"""
        
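        sql = ask_llm(prompt, temperature=0)

        # Validate before executing: generated SQL is untrusted input
        safe, validation_msg = validate_sql_safety(sql)
        if not safe:
            return {"error": f"Validation failed: {validation_msg}", "success": False}

        # Execute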
        try:
            df = pd.read_sql_query(sql, self.conn)
            
            # Store in history
            self.history.append({
                "question": question,
                "sql": sql,
                "result_preview": df.head(3).to_dict(),
                "row_count": len(df)
            })
            
            return {"sql": sql, "result": df, "success": True}
            
        except Exception as e:
            return {"error": str(e), "success": False}
    
    def _build_context(self) -> str:
        """Build conversation context"""
        if not self.history:
            return ""
        
        context = "Previous conversation:\n"
        for i, turn in enumerate(self.history[-3:], 1):  # Last 3 turns
            context += f"\n{i}. Q: {turn['question']}"
            context += f"\n   SQL: {turn['sql'][:100]}..."
            context += f"\n   Returned {turn['row_count']} rows\n"
        
        return context

# Example conversation
conv = NL2SQLConversation(schema_ddl, conn)

# Turn 1
result1 = conv.query("Show top 10 products by sales")
print(result1["result"])

# Turn 2 (refinement with context)
result2 = conv.query("Now filter to only electronics category")
print(result2["result"])

# Turn 3 (further refinement)
result3 = conv.query("And sort by profit margin instead")
print(result3["result"])
```

---

**6. Ensemble Methods (Multiple LLMs)**

```python
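import hashlib
from collections import Counter
from typing import Optional

def nl2sql_ensemble(question: str, schema: str, models: Optional[list] = None) -> dict:
    """Generate SQL with multiple models and vote"""

    # Avoid a mutable default argument; fall back to the default model list
    if models is None:
        models = ["gpt-4o", "claude-3-5-sonnet", "gemini-1.5-pro"]

    results = []

    for model in models:
        try:
            sql = nl_to_sql(question, schema, model=model)
            # Collapse whitespace and drop a trailing semicolon so trivial formatting
            # differences do not break the vote
            normalized = " ".join(sql.split()).rstrip("; ")
            results.append({
                "model": model,
                "sql": sql,
                "sql_hash": hashlib.sha256(normalized.encode()).hexdigest()
            })
        except Exception as e:
            print(f"Model {model} failed: {e}")

    # Every model failed: there is nothing to vote on
    if not results:
        return {"consensus_sql": None, "agreement": f"0/{len(models)}", "all_results": []}

    # Count identical queries (consensus)
    sql_hashes = [r["sql_hash"] for r in results]
    consensus_hash, consensus_count = Counter(sql_hashes).most_common(1)[0]

    # Get SQL for consensus
    consensus_sql = next(r["sql"] for r in results if r["sql_hash"] == consensus_hash)

    return {
        "consensus_sql": consensus_sql,
        "agreement": f"{consensus_count}/{len(models)}",
        "all_results": results
    }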

# Example
question = "Monthly active users in Q4 2024"
result = nl2sql_ensemble(question, schema_ddl)

print(f"Consensus ({result['agreement']} models agree):")
print(result["consensus_sql"])
```
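
Voting on SQL text only counts queries that are textually identical, so models that differ in aliases or keyword case rarely agree. One alternative, sketched here as an assumption rather than a replacement for the function above, is to vote on execution results instead. The `candidates` dictionary stands in for real model outputs, and all names are hypothetical.

```python
import sqlite3
from collections import defaultdict

def result_fingerprint(conn: sqlite3.Connection, sql: str):
    """Order-insensitive fingerprint of a result set; None if the query fails."""
    try:
        rows = conn.execute(sql).fetchall()
    except sqlite3.Error:
        return None
    return tuple(sorted(rows, key=repr))

def vote_by_result(conn: sqlite3.Connection, candidates: dict) -> dict:
    """Group candidate queries by result fingerprint and return the largest group."""
    groups = defaultdict(list)
    for model, sql in candidates.items():
        fingerprint = result_fingerprint(conn, sql)
        if fingerprint is not None:
            groups[fingerprint].append(model)
    if not groups:
        return {"consensus_sql": None, "agreement": f"0/{len(candidates)}", "models": []}
    winners = max(groups.values(), key=len)
    return {
        "consensus_sql": candidates[winners[0]],
        "agreement": f"{len(winners)}/{len(candidates)}",
        "models": winners,
    }

vote_conn = sqlite3.connect(":memory:")
vote_conn.execute("CREATE TABLE ventas (producto_id INT, total REAL)")
vote_conn.executemany(
    "INSERT INTO ventas VALUES (?, ?)", [(101, 200.0), (102, 50.0), (101, 100.0)]
)

candidates = {
    "model_a": "SELECT producto_id, SUM(total) FROM ventas GROUP BY producto_id",
    "model_b": "select v.producto_id, sum(v.total) as t from ventas v group by v.producto_id",
    "model_c": "SELECT producto_id, COUNT(*) FROM ventas GROUP BY producto_id",
}
vote = vote_by_result(vote_conn, candidates)
print(vote["agreement"], vote["models"])
```

Result-based voting costs one execution per candidate and can be fooled by queries that coincide on the current data, so it works best against a read-only replica with representative rows.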

---
**Autor:** Luis J. Raigoso V. (LJRV)

### 📊 **Evaluation & Production Metrics: Measuring NL2SQL Quality**

**Evaluation Challenges:**

NL2SQL evaluation is hard because:
1. **Multiple valid SQL queries**: the same question can map to several correct queries
2. **Semantic equivalence**: queries can differ syntactically yet be semantically identical
3. **Partial correctness**: a query can be partially correct
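
The first two points can be illustrated with a small sketch on an in-memory SQLite database: two queries that differ only in aliasing fail a string comparison but return the same rows. The table and queries are illustrative assumptions.

```python
import sqlite3

eval_conn = sqlite3.connect(":memory:")
eval_conn.execute("CREATE TABLE ventas (venta_id INT, cliente_id INT, total REAL)")
eval_conn.executemany(
    "INSERT INTO ventas VALUES (?, ?, ?)",
    [(1, 10, 200.0), (2, 11, 50.0), (3, 10, 100.0)],
)

gold_sql = "SELECT cliente_id, SUM(total) AS gasto FROM ventas GROUP BY cliente_id"
generated_sql = (
    "SELECT v.cliente_id, SUM(v.total) AS gasto FROM ventas AS v GROUP BY v.cliente_id"
)

# String comparison: fails although both queries answer the same question
exact_match = gold_sql.strip().lower() == generated_sql.strip().lower()

# Result comparison: sort rows so that row order does not matter
gold_rows = sorted(eval_conn.execute(gold_sql).fetchall())
generated_rows = sorted(eval_conn.execute(generated_sql).fetchall())
execution_match = gold_rows == generated_rows

print(exact_match, execution_match)
```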

**Evaluation Metrics:**

```
┌─────────────────────────────────────────────────┐
│          NL2SQL Evaluation Metrics              │
├─────────────────────────────────────────────────┤
│                                                 │
│  1. Exact Match (EM)                            │
│     Generated SQL == expected SQL               │
│     ❌ Too strict                                │
│                                                 │
│  2. Execution Accuracy (EX)                     │
│     Result(SQL_gen) == Result(SQL_gold)         │
│     ✅ Best practical metric                     │
│                                                 │
│  3. Component Match                             │
│     SELECT clause: 90%                          │
│     FROM clause: 95%                            │
│     WHERE clause: 85%                           │
│     GROUP BY: 80%                               │
│     ORDER BY: 75%                               │
│                                                 │
│  4. Test Suite Execution (TSE)                  │
│     % of test cases passed                      │
│     ✅ Robust to syntactic variations            │
└─────────────────────────────────────────────────┘
```

**1. Execution Accuracy (Gold Standard):**

```python
import pandas as pd
from typing import Tuple

def execution_accuracy(
    generated_sql: str,
    gold_sql: str,
    conn,
    timeout: int = 10
) -> Tuple[bool, dict]:
    """Compare results of generated vs gold SQL.

    Note: `timeout` is accepted but not enforced in this function; enforce it
    at the database level (for example, with a statement timeout).
    """

    try:
        # Execute both queries
        df_gen = pd.read_sql_query(generated_sql, conn)
        df_gold = pd.read_sql_query(gold_sql, conn)

        # Check column differences first: the comparisons below need identical labels
        if set(df_gen.columns) != set(df_gold.columns):
            return False, {
                "match": "column_mismatch",
                "generated_cols": list(df_gen.columns),
                "gold_cols": list(df_gold.columns)
            }

        # Check if row counts match
        if len(df_gen) != len(df_gold):
            return False, {
                "match": "row_count_mismatch",
                "generated_rows": len(df_gen),
                "gold_rows": len(df_gold)
            }

        # Align column order, then sort rows so ordering does not affect the comparison
        cols = sorted(df_gold.columns)
        df_gen_sorted = df_gen[cols].sort_values(by=cols).reset_index(drop=True)
        df_gold_sorted = df_gold[cols].sort_values(by=cols).reset_index(drop=True)

        # Compare
        if df_gen_sorted.equals(df_gold_sorted):
            return True, {"match": "exact", "rows": len(df_gen)}

        # Partial match (values differ): count rows with at least one differing cell
        return False, {
            "match": "value_mismatch",
            "diff_rows": int((df_gen_sorted != df_gold_sorted).any(axis=1).sum())
        }
        
    except Exception as e:
        return False, {"error": str(e)}

# Example
generated = "SELECT customer_id, SUM(total) as revenue FROM sales GROUP BY customer_id ORDER BY revenue DESC LIMIT 10"
gold = "SELECT customer_id, SUM(total_amount) as revenue FROM sales GROUP BY customer_id ORDER BY 2 DESC LIMIT 10"

match, details = execution_accuracy(generated, gold, conn)
print(f"{'✅' if match else '❌'} {details}")
```

**2. Component-Level Evaluation:**

```python
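import re
import sqlparse

def parse_sql_components(sql: str) -> dict:
    """Extract clause-level components from a single-level SELECT.

    Simplified, regex-based extraction: it does not handle subqueries, CTEs,
    keywords inside string literals, or commas inside function calls. A full
    implementation would traverse a parsed AST instead.
    """

    components = {
        "select": [],
        "from": [],
        "where": [],
        "group_by": [],
        "order_by": [],
        "limit": None
    }

    # Strip comments, lowercase, and collapse whitespace so formatting does not matter
    text = " ".join(sqlparse.format(sql, strip_comments=True).lower().split()).rstrip("; ")

    # Locate clause keywords; each clause body runs up to the next keyword
    matches = list(re.finditer(r"\b(select|from|where|group by|order by|limit)\b", text))
    for i, match in enumerate(matches):
        end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        body = text[match.end():end].strip()
        key = match.group(1).replace(" ", "_")

        if key == "limit":
            components["limit"] = body
        elif key == "from":
            # One entry per joined table, without the ON condition
            tables = re.split(r"\b(?:inner|left|right|full|cross)?\s*(?:outer)?\s*join\b", body)
            components["from"] = [t.split(" on ")[0].strip() for t in tables]
        elif key == "where":
            components["where"] = [c.strip() for c in body.split(" and ")]
        else:
            components[key] = [c.strip() for c in body.split(",")]

    return components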

def component_accuracy(generated_sql: str, gold_sql: str) -> dict:
    """Compare SQL components"""
    
    gen_comp = parse_sql_components(generated_sql)
    gold_comp = parse_sql_components(gold_sql)
    
    scores = {}
    
    for component in ["select", "from", "where", "group_by", "order_by"]:
        gen_set = set(gen_comp.get(component, []))
        gold_set = set(gold_comp.get(component, []))
        
        if not gold_set:
            scores[component] = 1.0 if not gen_set else 0.0
        else:
            # Jaccard similarity
            intersection = len(gen_set & gold_set)
            union = len(gen_set | gold_set)
            scores[component] = intersection / union if union > 0 else 0.0
    
    # Overall score (weighted average)
    weights = {"select": 0.3, "from": 0.25, "where": 0.25, "group_by": 0.1, "order_by": 0.1}
    overall = sum(scores[k] * weights[k] for k in weights)
    
    return {
        "component_scores": scores,
        "overall_score": overall
    }

# Example
scores = component_accuracy(generated, gold)
print(f"Overall Score: {scores['overall_score']:.2%}")
for comp, score in scores['component_scores'].items():
    print(f"  {comp}: {score:.2%}")
```

**3. Spider Benchmark Evaluation:**

```python
# Spider dataset: https://yale-lily.github.io/spider
# 10,181 questions + SQL pairs across 200 databases

def evaluate_on_spider_subset(model_fn, test_cases: list) -> dict:
    """Evaluate model on Spider benchmark subset"""
    
    results = {
        "total": len(test_cases),
        "exact_match": 0,
        "execution_match": 0,
        "failed": 0,
        "details": []
    }
    
    for i, case in enumerate(test_cases):
        question = case["question"]
        gold_sql = case["sql"]
        db_id = case["db_id"]
        
        try:
            # Generate SQL
            generated_sql = model_fn(question, case["schema"])
            
            # Check exact match
            if generated_sql.strip().lower() == gold_sql.strip().lower():
                results["exact_match"] += 1
            
            # Check execution match
            # (requires database instances from Spider)
            conn = get_spider_db_connection(db_id)
            match, _ = execution_accuracy(generated_sql, gold_sql, conn)
            
            if match:
                results["execution_match"] += 1
            
            results["details"].append({
                "case_id": i,
                "exact_match": generated_sql.strip().lower() == gold_sql.strip().lower(),
                "execution_match": match
            })
            
        except Exception as e:
            results["failed"] += 1
            results["details"].append({
                "case_id": i,
                "error": str(e)
            })
    
    # Calculate percentages (guard against an empty test set)
    total = max(results["total"], 1)
    results["exact_match_pct"] = results["exact_match"] / total * 100
    results["execution_match_pct"] = results["execution_match"] / total * 100
    
    return results

# Example usage
test_cases = load_spider_subset(100)  # 100 test cases
results = evaluate_on_spider_subset(nl_to_sql, test_cases)

print(f"Exact Match: {results['exact_match_pct']:.1f}%")
print(f"Execution Accuracy: {results['execution_match_pct']:.1f}%")
```

**4. Production Monitoring Metrics:**

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
import numpy as np

@dataclass
class NL2SQLMetrics:
    """Production metrics for NL2SQL system"""
    
    timestamp: datetime
    
    # Volume metrics
    total_queries: int
    queries_per_minute: float
    
    # Success metrics
    validation_pass_rate: float  # % passed SQL validation
    execution_success_rate: float  # % executed without error
    result_quality_score: float  # % with non-empty results
    
    # Performance metrics
    avg_generation_time_ms: float
    p95_generation_time_ms: float
    avg_execution_time_ms: float
    p95_execution_time_ms: float
    
    # Cost metrics
    total_tokens_consumed: int
    total_cost_usd: float
    cost_per_query: float
    
    # User satisfaction
    thumbs_up_rate: float  # % of queries with positive feedback
    retry_rate: float  # % of users who retry same question
    
    # Error breakdown
    validation_errors: int
    execution_errors: int
    timeout_errors: int

class MetricsCollector:
    """Collect and aggregate NL2SQL metrics"""
    
    def __init__(self):
        self.queries = []
    
    def log_query(
        self,
        question: str,
        sql: str,
        generation_time_ms: float,
        execution_time_ms: float,
        validation_passed: bool,
        execution_success: bool,
        result_rows: int,
        tokens_used: int,
        cost_usd: float,
        user_feedback: str = None  # 'thumbs_up', 'thumbs_down', None
    ):
        self.queries.append({
            "timestamp": datetime.utcnow(),
            "question": question,
            "sql": sql,
            "generation_time_ms": generation_time_ms,
            "execution_time_ms": execution_time_ms,
            "validation_passed": validation_passed,
            "execution_success": execution_success,
            "result_rows": result_rows,
            "tokens_used": tokens_used,
            "cost_usd": cost_usd,
            "user_feedback": user_feedback
        })
    
    def get_metrics(self, time_window_minutes: int = 60) -> NL2SQLMetrics:
        """Calculate metrics for recent time window"""
        
        cutoff = datetime.utcnow() - timedelta(minutes=time_window_minutes)
        recent = [q for q in self.queries if q["timestamp"] >= cutoff]
        
        if not recent:
            return None
        
        # Calculate metrics
        total = len(recent)
        
        validation_passed = sum(1 for q in recent if q["validation_passed"])
        execution_success = sum(1 for q in recent if q["execution_success"])
        non_empty_results = sum(1 for q in recent if q["result_rows"] > 0)
        
        generation_times = [q["generation_time_ms"] for q in recent]
        execution_times = [q["execution_time_ms"] for q in recent if q["execution_time_ms"]]
        
        thumbs_up = sum(1 for q in recent if q["user_feedback"] == "thumbs_up")
        total_feedback = sum(1 for q in recent if q["user_feedback"])
        
        return NL2SQLMetrics(
            timestamp=datetime.utcnow(),
            total_queries=total,
            queries_per_minute=total / time_window_minutes,
            validation_pass_rate=validation_passed / total,
            execution_success_rate=execution_success / total,
            result_quality_score=non_empty_results / total,
            avg_generation_time_ms=np.mean(generation_times),
            p95_generation_time_ms=np.percentile(generation_times, 95),
            avg_execution_time_ms=np.mean(execution_times) if execution_times else 0,
            p95_execution_time_ms=np.percentile(execution_times, 95) if execution_times else 0,
            total_tokens_consumed=sum(q["tokens_used"] for q in recent),
            total_cost_usd=sum(q["cost_usd"] for q in recent),
            cost_per_query=sum(q["cost_usd"] for q in recent) / total,
            thumbs_up_rate=thumbs_up / total_feedback if total_feedback > 0 else 0,
            retry_rate=0,  # Calculate based on user session analysis
            validation_errors=total - validation_passed,
            execution_errors=validation_passed - execution_success,
            timeout_errors=0  # Track separately
        )

# Usage
collector = MetricsCollector()

# Log queries as they happen
collector.log_query(
    question="Top 10 products",
    sql="SELECT * FROM products ORDER BY sales DESC LIMIT 10",
    generation_time_ms=450,
    execution_time_ms=120,
    validation_passed=True,
    execution_success=True,
    result_rows=10,
    tokens_used=250,
    cost_usd=0.001,
    user_feedback="thumbs_up"
)

# Get hourly metrics
metrics = collector.get_metrics(time_window_minutes=60)
print(f"Queries/min: {metrics.queries_per_minute:.1f}")
print(f"Success Rate: {metrics.execution_success_rate:.1%}")
print(f"Avg Latency: {metrics.avg_generation_time_ms:.0f}ms")
print(f"Cost/Query: ${metrics.cost_per_query:.4f}")
```

**5. Grafana Dashboard Metrics:**

```yaml
# Prometheus metrics for NL2SQL
nl2sql_queries_total{status="success|validation_failed|execution_failed"}
nl2sql_generation_duration_seconds{model="gpt-4o"}
nl2sql_execution_duration_seconds
nl2sql_tokens_consumed_total{model="gpt-4o",type="input|output"}
nl2sql_cost_usd_total{model="gpt-4o"}
nl2sql_user_satisfaction{feedback="thumbs_up|thumbs_down"}
nl2sql_result_rows{quantile="0.5|0.95|0.99"}
```

**Grafana Panels:**
1. **Request Rate**: `sum(rate(nl2sql_queries_total[5m]))`
2. **Success Rate**: `sum(rate(nl2sql_queries_total{status="success"}[5m])) / sum(rate(nl2sql_queries_total[5m]))`
3. **Latency p95**: `histogram_quantile(0.95, sum by (le) (rate(nl2sql_generation_duration_seconds_bucket[5m])))`
4. **Cost Burn**: `sum(increase(nl2sql_cost_usd_total[1h])) * 24 * 30` ($/month, extrapolated from the last hour)
5. **User Satisfaction**: `sum(rate(nl2sql_user_satisfaction{feedback="thumbs_up"}[1h])) / sum(rate(nl2sql_user_satisfaction[1h]))`

**6. A/B Testing Framework:**

```python
import time

import numpy as np

class ABTestFramework:
    """A/B test different NL2SQL strategies"""
    
    def __init__(self, variants: dict):
        self.variants = variants  # {variant_name: generator_function}
        self.results = {name: [] for name in variants.keys()}
    
    def generate(self, question: str, schema: str, user_id: str) -> dict:
        """Route user to variant and track results"""
        
        # Consistent variant assignment per user
        variant_name = self._assign_variant(user_id)
        generator = self.variants[variant_name]
        
        # Generate SQL
        start = time.time()
        sql = generator(question, schema)
        generation_time = (time.time() - start) * 1000
        
        return {
            "variant": variant_name,
            "sql": sql,
            "generation_time_ms": generation_time
        }
    
    def log_result(
        self,
        variant: str,
        success: bool,
        generation_time_ms: float,
        user_feedback: str = None
    ):
        """Log result for variant"""
        self.results[variant].append({
            "success": success,
            "generation_time_ms": generation_time_ms,
            "user_feedback": user_feedback
        })
    
    def analyze(self) -> dict:
        """Statistical analysis of variants"""
        
        analysis = {}
        
        for variant, results in self.results.items():
            if not results:
                continue
            
            analysis[variant] = {
                "sample_size": len(results),
                "success_rate": sum(1 for r in results if r["success"]) / len(results),
                "avg_latency_ms": np.mean([r["generation_time_ms"] for r in results]),
                "thumbs_up_rate": sum(1 for r in results if r["user_feedback"] == "thumbs_up") / len(results)
            }
        
        return analysis
    
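    def _assign_variant(self, user_id: str) -> str:
        """Consistent hash-based assignment"""
        # Python's built-in hash() is salted per process for strings, so it is not
        # stable across restarts; a SHA-256 digest gives a deterministic bucket.
        import hashlib  # local import keeps this method self-contained
        digest = hashlib.sha256(user_id.encode("utf-8")).hexdigest()
        hash_val = int(digest, 16) % 100

        # 50/50 split (adjust as needed); assumes exactly two variants
        threshold = 50
        variants_list = list(self.variants.keys())

        return variants_list[0] if hash_val < threshold else variants_list[1]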

# Example
variants = {
    "baseline": lambda q, s: nl_to_sql(q, s, model="gpt-4o-mini"),
    "optimized": lambda q, s: nl_to_sql_fewshot(q, s, model="gpt-4o")
}

ab_test = ABTestFramework(variants)

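# Run experiment (`users`, `question`, `schema`, `execute_and_validate`, and
# `get_user_feedback` are placeholders supplied by the application)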
for user_id in users:
    result = ab_test.generate(question, schema, user_id)
    
    # Execute and collect feedback
    success = execute_and_validate(result["sql"])
    feedback = get_user_feedback(user_id)
    
    ab_test.log_result(
        variant=result["variant"],
        success=success,
        generation_time_ms=result["generation_time_ms"],
        user_feedback=feedback
    )

# Analyze results
analysis = ab_test.analyze()
print("A/B Test Results:")
for variant, metrics in analysis.items():
    print(f"\n{variant}:")
    print(f"  Success Rate: {metrics['success_rate']:.1%}")
    print(f"  Avg Latency: {metrics['avg_latency_ms']:.0f}ms")
    print(f"  Satisfaction: {metrics['thumbs_up_rate']:.1%}")
```
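
`analyze()` reports per-variant averages but does not say whether a difference is statistically meaningful. A minimal sketch of one common check, a two-sided two-proportion z-test on success rates, using only the standard library. The function name and the sample counts are hypothetical and not taken from any real experiment.

```python
import math

def two_proportion_z_test(successes_a: int, n_a: int, successes_b: int, n_b: int):
    """Two-sided z-test for the difference between two success rates."""
    p_a = successes_a / n_a
    p_b = successes_b / n_b
    pooled = (successes_a + successes_b) / (n_a + n_b)
    std_err = math.sqrt(pooled * (1 - pooled) * (1 / n_a + 1 / n_b))
    if std_err == 0:
        return 0.0, 1.0  # identical, degenerate samples: no evidence of a difference
    z = (p_b - p_a) / std_err
    # Two-sided p-value from the standard normal CDF, computed with erf
    normal_cdf = 0.5 * (1 + math.erf(abs(z) / math.sqrt(2)))
    p_value = 2 * (1 - normal_cdf)
    return z, p_value

# Hypothetical counts: baseline 410/500 successes, optimized 445/500
z_score, p_value = two_proportion_z_test(410, 500, 445, 500)
print(f"z = {z_score:.2f}, p = {p_value:.4f}")
```

A small p-value suggests the gap in success rate is unlikely to be noise; with few samples per variant, keep the experiment running before drawing conclusions.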

**Production Readiness Checklist:**

- ✅ Evaluation metrics tracked (Execution Accuracy, Component Match)
- ✅ Continuous monitoring (Prometheus + Grafana)
- ✅ A/B testing framework (compare strategies)
- ✅ User feedback loop (thumbs up/down)
- ✅ Error tracking & categorization
- ✅ Cost monitoring & alerts
- ✅ Latency SLOs (p95 < 2s)
- ✅ Quality SLOs (success rate > 90%)
- ✅ Regression testing suite (100+ test cases)
- ✅ Model version tracking (experiments)


In [None]:
import os
from openai import OpenAI
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

def nl_to_sql_simple(question: str) -> str:
    """Baseline: ask the model for SQL without any schema context."""
    prompt = f'''
Convert this question into SQL:

Question: {question}

SQL:
'''
    resp = client.chat.completions.create(
        model='gpt-3.5-turbo',
        messages=[{'role':'user','content':prompt}],
        temperature=0
    )
    return resp.choices[0].message.content

print(nl_to_sql_simple('Show the 10 most recent sales'))

## 2. Improved: with schema context

In [None]:
schema_context = '''
Available tables:

ventas:
  - venta_id (int, PK)
  - fecha (date)
  - cliente_id (int, FK)
  - producto_id (int, FK)
  - cantidad (int)
  - total (decimal)

clientes:
  - cliente_id (int, PK)
  - nombre (varchar)
  - email (varchar)
  - pais (varchar)

productos:
  - producto_id (int, PK)
  - nombre (varchar)
  - categoria (varchar)
  - precio (decimal)
'''

def nl_to_sql(question: str, schema: str = schema_context, dialect: str = 'postgresql', model: str = 'gpt-4') -> str:
    prompt = f'''
You are a SQL expert. Generate a valid {dialect} query that answers the question.

Database schema:
{schema}

Question: {question}

Return ONLY the SQL query, with no additional explanation.
'''
    resp = client.chat.completions.create(
        model=model,
        messages=[{'role':'user','content':prompt}],
        temperature=0
    )
    # Remove Markdown code fences the model may wrap around the query, then trim
    return resp.choices[0].message.content.replace('```sql','').replace('```','').strip()

query = nl_to_sql('What are the top 5 best-selling products per category?')
print(query)
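
A hand-written schema string can drift from the real database. As a minimal sketch, assuming a SQLite connection, a comparable description can be generated from the database's own metadata. The function name `schema_from_sqlite` and the demo tables are illustrative, and foreign keys are left out for brevity.

```python
import sqlite3

def schema_from_sqlite(conn: sqlite3.Connection) -> str:
    """Build a prompt-ready schema description from SQLite metadata."""
    lines = ["Available tables:", ""]
    tables = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    for (table,) in tables:
        lines.append(f"{table}:")
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        columns = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
        for _cid, name, col_type, _notnull, _default, pk in columns:
            suffix = ", PK" if pk else ""
            lines.append(f"  - {name} ({col_type.lower()}{suffix})")
        lines.append("")
    return "\n".join(lines)

schema_conn = sqlite3.connect(":memory:")
schema_conn.executescript("""
CREATE TABLE clientes (cliente_id INTEGER PRIMARY KEY, nombre TEXT);
CREATE TABLE ventas (venta_id INTEGER PRIMARY KEY, cliente_id INTEGER, total REAL);
""")

generated_schema = schema_from_sqlite(schema_conn)
print(generated_schema)
```

The resulting string could be passed as the `schema` argument of `nl_to_sql` in place of the static `schema_context`.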

## 3. Validation and security

In [None]:
import re
import sqlparse

def is_safe_query(sql: str) -> tuple[bool, str]:
    """Validate that the query is a single, read-only statement."""
    # Reject stacked statements such as "SELECT 1; GRANT ..."
    statements = [s for s in sqlparse.split(sql) if s.strip()]
    if len(statements) != 1:
        return False, 'Exactly one SQL statement is allowed'

    sql_lower = sql.lower()
    # Block write operations
    dangerous = ['insert', 'update', 'delete', 'drop', 'create', 'alter', 'truncate', 'exec', 'execute']
    for kw in dangerous:
        if re.search(rf'\b{kw}\b', sql_lower):
            return False, f'Blocked operation: {kw}'

    # Only allow SELECT
    if not sql_lower.strip().startswith('select'):
        return False, 'Only SELECT queries are allowed'

    return True, 'OK'

test_queries = [
    'SELECT * FROM ventas LIMIT 10',
    'DELETE FROM ventas WHERE fecha < \'2020-01-01\'',
    'SELECT COUNT(*) FROM clientes'
]

for q in test_queries:
    safe, msg = is_safe_query(q)
    print(f'{"✅" if safe else "❌"} {msg}: {q[:50]}')
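
Keyword filtering such as `is_safe_query` is a heuristic, so it helps to add a second layer at the database itself. In SQLite, `PRAGMA query_only = ON` makes a connection reject writes; on a server database the usual equivalent is a role with read-only privileges. A minimal sketch, with an illustrative table and a hypothetical `try_execute` helper:

```python
import sqlite3

ro_conn = sqlite3.connect(":memory:")
ro_conn.execute("CREATE TABLE ventas (venta_id INT, total REAL)")
ro_conn.execute("INSERT INTO ventas VALUES (1, 200.0)")
ro_conn.commit()

# From here on, this connection refuses statements that would modify the database
ro_conn.execute("PRAGMA query_only = ON")

def try_execute(conn: sqlite3.Connection, sql: str):
    """Run a statement and report (ok, rows_or_error_message)."""
    try:
        return True, conn.execute(sql).fetchall()
    except sqlite3.Error as exc:
        return False, str(exc)

read_ok, read_rows = try_execute(ro_conn, "SELECT COUNT(*) FROM ventas")
write_ok, write_error = try_execute(ro_conn, "DELETE FROM ventas")

print(read_ok, read_rows)
print(write_ok, write_error)
```

With both layers in place, a query that slips past the keyword check still cannot change data.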

## 4. Safe execution with a SQLite demo

In [None]:
import sqlite3
import pandas as pd

# Setup demo DB
conn = sqlite3.connect(':memory:')
conn.executescript('''
CREATE TABLE ventas (venta_id INT, fecha TEXT, cliente_id INT, producto_id INT, cantidad INT, total REAL);
INSERT INTO ventas VALUES (1,'2025-10-01',10,101,2,200.0),(2,'2025-10-02',11,102,1,50.0),(3,'2025-10-03',10,101,1,100.0);
CREATE TABLE productos (producto_id INT, nombre TEXT, categoria TEXT);
INSERT INTO productos VALUES (101,'Laptop','Electronics'),(102,'Mouse','Electronics');
''')

def execute_nl_query(question: str):
    # The demo database is SQLite, so request that dialect explicitly
    sql = nl_to_sql(question, dialect='sqlite')
    print(f'Generated SQL: {sql}\n')
    safe, msg = is_safe_query(sql)
    if not safe:
        return f'❌ Query blocked: {msg}'
    try:
        df = pd.read_sql_query(sql, conn)
        return df
    except Exception as e:
        return f'Error executing SQL: {e}'

result = execute_nl_query('Show total sales per product')
print(result)

## 5. Few-shot prompting with examples

In [None]:
few_shot_examples = '''
Examples:

Q: How many sales were there yesterday?
SQL: SELECT COUNT(*) FROM ventas WHERE fecha = CURRENT_DATE - INTERVAL '1 day';

Q: Top 3 customers by total spend
SQL: SELECT c.nombre, SUM(v.total) as gasto FROM ventas v JOIN clientes c ON v.cliente_id=c.cliente_id GROUP BY c.nombre ORDER BY gasto DESC LIMIT 3;

Q: Products with no sales in October
SQL: SELECT p.nombre FROM productos p LEFT JOIN ventas v ON p.producto_id=v.producto_id AND v.fecha >= '2025-10-01' AND v.fecha < '2025-11-01' WHERE v.venta_id IS NULL;
'''

def nl_to_sql_fewshot(question: str, schema: str = schema_context, model: str = 'gpt-4') -> str:
    prompt = f'''
Generate SQL to answer questions about sales.

{schema}

{few_shot_examples}

Now generate SQL for:
Q: {question}
SQL:
'''
    resp = client.chat.completions.create(
        model=model,
        messages=[{'role':'user','content':prompt}],
        temperature=0
    )
    # Remove Markdown code fences the model may wrap around the query, then trim
    return resp.choices[0].message.content.replace('```sql','').replace('```','').strip()

query = nl_to_sql_fewshot('Sales by category over the last 7 days')
print(query)

## 6. Optimization and caching

In [None]:
from functools import lru_cache
import hashlib

@lru_cache(maxsize=128)
def cached_nl_to_sql(question: str) -> str:
    return nl_to_sql(question)

# Repeated calls are served from the cache
q = 'What is the total sales amount for the current month?'
print('First call:')
print(cached_nl_to_sql(q))
print('\nSecond call (from cache):')
print(cached_nl_to_sql(q))
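
`lru_cache` keys on the exact question string, so a change in casing or spacing causes a miss, and a schema change does not invalidate old entries. One possible refinement, sketched under the assumption that the generator is any callable taking `(question, schema)`, is to key on a hash of the normalized question plus the schema text. `SQLCache` and `fake_generator` are hypothetical names.

```python
import hashlib
import re

class SQLCache:
    """Tiny in-memory cache keyed on the normalized question plus the schema text."""

    def __init__(self, generator):
        self.generator = generator  # any callable: (question, schema) -> SQL string
        self.store = {}
        self.hits = 0

    @staticmethod
    def make_key(question: str, schema: str) -> str:
        # Lowercase and collapse whitespace so trivial variations share a key
        normalized = re.sub(r"\s+", " ", question.strip().lower())
        payload = f"{normalized}\n{schema}".encode("utf-8")
        return hashlib.sha256(payload).hexdigest()

    def get(self, question: str, schema: str) -> str:
        key = self.make_key(question, schema)
        if key in self.store:
            self.hits += 1
        else:
            self.store[key] = self.generator(question, schema)
        return self.store[key]

calls = []

def fake_generator(question: str, schema: str) -> str:
    """Stand-in for an LLM call; records each invocation."""
    calls.append(question)
    return "SELECT SUM(total) FROM ventas"

cache = SQLCache(fake_generator)
cache.get("Total sales this month?", "ventas(total)")
cache.get("  total   SALES this month? ", "ventas(total)")    # same key: cache hit
cache.get("Total sales this month?", "ventas(total, fecha)")  # schema changed: miss

print(len(calls), cache.hits)
```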

## 7. Best practices

- **Complete schema**: include types, PKs, FKs, and indexes.
- **Representative examples**: few-shot examples that cover edge cases.
- **Strict validation**: whitelist the allowed operations.
- **Timeouts and limits**: prevent expensive queries.
- **Logging**: record the question, the generated SQL, and the result.
- **Feedback loop**: store human corrections for fine-tuning.
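
The timeout-and-limits practice above can be approximated in SQLite with a progress handler that aborts long-running statements and a cap on fetched rows. This is a sketch under those assumptions; `run_with_limits` is a hypothetical helper, and server databases usually offer a native statement timeout instead.

```python
import sqlite3
import time

def run_with_limits(conn: sqlite3.Connection, sql: str,
                    timeout_s: float = 1.0, max_rows: int = 1000):
    """Execute a query, aborting after timeout_s and returning at most max_rows rows."""
    deadline = time.monotonic() + timeout_s
    # The handler runs every 1000 SQLite VM instructions; a non-zero return aborts the query
    conn.set_progress_handler(lambda: 1 if time.monotonic() > deadline else 0, 1000)
    try:
        cursor = conn.execute(sql)
        return cursor.fetchmany(max_rows)
    finally:
        conn.set_progress_handler(None, 0)  # always remove the handler afterwards

limit_conn = sqlite3.connect(":memory:")

# Deliberately expensive query: counts through a long recursive sequence
slow_sql = """
WITH RECURSIVE counter(x) AS (
    SELECT 1 UNION ALL SELECT x + 1 FROM counter WHERE x < 50000000
)
SELECT COUNT(*) FROM counter
"""

try:
    run_with_limits(limit_conn, slow_sql, timeout_s=0.05)
    timed_out = False
except sqlite3.OperationalError as exc:
    timed_out = True
    print(f"Query aborted: {exc}")

fast_rows = run_with_limits(
    limit_conn, "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3", max_rows=2
)
print(fast_rows)
```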

## 8. Exercises

1. Add multi-language support (English/Spanish).
2. Implement a human approval workflow for complex SQL.
3. Build a Streamlit dashboard where users type questions and see results.
4. Add a natural-language explanation of the generated SQL.