# Text-to-SQL Agent - Interactive Demo

This notebook demonstrates the key features of the Text-to-SQL Agent system:

1. **Schema Management** - Loading database schema from Excel
2. **Semantic Join Inference** - Finding joins without explicit foreign keys
3. **Session Management** - Tracking and persisting agent state
4. **Correction System** - Learning from user feedback
5. **BigQuery Integration** - Executing and validating queries
6. **Error Recovery** - Handling API failures gracefully

## Prerequisites

Before running this notebook:
1. Install dependencies: `pip install -r ../requirements.txt`
2. Configure the `.env` file with your credentials
3. Prepare your schema Excel file

## Setup and Imports

In [None]:
# Add parent directory to path
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

# Import main components
from src import (
    settings,
    schema_loader,
    bigquery_client,
    azure_client,
    session_manager,
    AgentState,
    JoinInference,
    CorrectionParser,
)

from src.utils import (
    RetryExhaustedError,
    FatalError,
    AmbiguityError,
    setup_logger,
)

# For nice display
import pandas as pd
from IPython.display import display, Markdown, HTML
import json

print("✅ Imports successful!")

## 1. Configuration Check

Let's verify that the system is properly configured.
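
The check below reads values with dotted keys such as `"bigquery.project_id"`. How `settings.get` resolves those keys is not shown in this notebook, so the following is only a minimal sketch of one common approach, walking a nested dictionary one segment at a time. The helper name `get_dotted` and the sample `config` dictionary are hypothetical and are not part of `src`.

```python
def get_dotted(config, dotted_key, default=None):
    """Look up a value in nested dicts using a dotted key like 'a.b.c'."""
    node = config
    for part in dotted_key.split("."):
        # Stop early if the path leaves the dict structure or the key is missing
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node


# Hypothetical configuration used only for illustration
config = {"bigquery": {"project_id": "demo-project", "dataset": "demo_dataset"}}

project = get_dotted(config, "bigquery.project_id")
endpoint = get_dotted(config, "azure_openai.endpoint")  # missing, so default is returned
```

If the real `settings.get` behaves like this sketch and returns a default instead of raising, a missing value would need an explicit `None` check rather than a `try`/`except`.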

In [None]:
# Check configuration
print("Configuration Status:")
print("=" * 50)

try:
    azure_endpoint = settings.get("azure_openai.endpoint")
    print(f"✅ Azure OpenAI Endpoint: {azure_endpoint[:50]}...")
except Exception:
    print("❌ Azure OpenAI not configured")

try:
    project_id = settings.get("bigquery.project_id")
    dataset = settings.get("bigquery.dataset")
    print(f"✅ BigQuery Project: {project_id}")
    print(f"✅ BigQuery Dataset: {dataset}")
except Exception:
    print("❌ BigQuery not configured")

try:
    schema_path = settings.get("schema.excel_path")
    print(f"✅ Schema Path: {schema_path}")
except Exception:
    print("⚠️  Schema path not set (will need to provide manually)")

print("=" * 50)

## 2. Schema Loading

Load the database schema from an Excel file containing table and column metadata.

In [None]:
# Option 1: Load from configured path
try:
    schema = schema_loader.load_from_excel()
    print(f"✅ Loaded schema from configured path")
except Exception as e:
    print(f"⚠️  Could not load from configured path: {e}")
    print("\nOption 2: Provide path manually:")
    print("schema = schema_loader.load_from_excel(excel_path='/path/to/your/schema.xlsx')")
    
    # For demo purposes, create a mock schema
    from src.schema import Schema, Table, Column, ColumnType
    
    schema = Schema(project_id="demo-project", dataset="demo_dataset")
    
    # Create Customers table
    customers = Table(name="Customers", description="Customer master data")
    customers.add_column(Column(name="customer_id", data_type=ColumnType.INTEGER, is_primary=True))
    customers.add_column(Column(name="customer_name", data_type=ColumnType.STRING))
    customers.add_column(Column(name="region", data_type=ColumnType.STRING, description="Geographic region"))
    customers.add_column(Column(name="account_status", data_type=ColumnType.STRING))
    schema.add_table(customers)
    
    # Create Orders table
    orders = Table(name="Orders", description="Customer orders")
    orders.add_column(Column(name="order_id", data_type=ColumnType.INTEGER, is_primary=True))
    orders.add_column(Column(name="customer_id", data_type=ColumnType.INTEGER, description="Reference to customer"))
    orders.add_column(Column(name="order_date", data_type=ColumnType.DATE))
    orders.add_column(Column(name="amount", data_type=ColumnType.FLOAT))
    schema.add_table(orders)
    
    # Create Products table
    products = Table(name="Products", description="Product catalog")
    products.add_column(Column(name="product_id", data_type=ColumnType.INTEGER, is_primary=True))
    products.add_column(Column(name="product_name", data_type=ColumnType.STRING))
    products.add_column(Column(name="category", data_type=ColumnType.STRING))
    products.add_column(Column(name="price", data_type=ColumnType.FLOAT))
    schema.add_table(products)
    
    print("\n📝 Created demo schema for illustration purposes")

In [None]:
# Display schema summary
print(f"\n📊 Schema Summary")
print("=" * 60)
print(f"Project: {schema.project_id}")
print(f"Dataset: {schema.dataset}")
print(f"Total Tables: {len(schema.tables)}")
print("\nTables:")

for table_name, table in schema.tables.items():
    print(f"\n  📋 {table_name}")
    if table.description:
        print(f"     {table.description}")
    print(f"     Columns: {len(table.columns)}")
    
    # Show first few columns
    for col in table.columns[:5]:
        indicators = []
        if col.is_primary:
            indicators.append("🔑 PK")
        if col.is_pii:
            indicators.append("🔒 PII")
        indicator_str = " ".join(indicators)
        print(f"       • {col.name} ({col.data_type.value}) {indicator_str}")
    
    if len(table.columns) > 5:
        print(f"       ... and {len(table.columns) - 5} more columns")

## 3. Semantic Join Inference

The system can automatically infer how to join tables even without explicit foreign keys, using:
- Column name similarity
- Business name matching
- Data type compatibility
- LLM semantic understanding
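
The first and third signals above (name similarity and type compatibility) can be illustrated without the `JoinInference` class. The sketch below is not the library's implementation. It is a hypothetical heuristic that scores column pairs with the standard library's `difflib.SequenceMatcher` and keeps only pairs whose data types match. The function names and the tuple-based column representation are assumptions made for this example.

```python
from difflib import SequenceMatcher


def name_similarity(a: str, b: str) -> float:
    """Return a 0..1 similarity ratio between two column names, ignoring case."""
    return SequenceMatcher(None, a.lower(), b.lower()).ratio()


def candidate_joins(left_cols, right_cols, threshold=0.70):
    """Return (left, right, score) tuples for same-typed, similarly named columns.

    Each column is a (name, data_type) tuple. Results are sorted best-first.
    """
    pairs = []
    for l_name, l_type in left_cols:
        for r_name, r_type in right_cols:
            if l_type != r_type:
                continue  # incompatible types cannot form a join key
            score = name_similarity(l_name, r_name)
            if score >= threshold:
                pairs.append((l_name, r_name, round(score, 2)))
    return sorted(pairs, key=lambda p: p[2], reverse=True)


# Columns mirroring the demo Customers and Orders tables
customers = [("customer_id", "INTEGER"), ("customer_name", "STRING")]
orders = [("order_id", "INTEGER"), ("customer_id", "INTEGER"), ("amount", "FLOAT")]

best = candidate_joins(customers, orders)
```

A pure string heuristic like this cannot tell that `cust_no` and `customer_id` refer to the same thing, which is presumably where business-name matching and the LLM signal come in.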

In [None]:
# Initialize join inference
join_inference = JoinInference(schema, confidence_threshold=0.70)

# Get table names
table_names = list(schema.tables.keys())
print(f"Available tables: {table_names}")

if len(table_names) >= 2:
    table1, table2 = table_names[0], table_names[1]
    
    print(f"\n🔍 Inferring joins between: {table1} ↔ {table2}")
    print("=" * 60)
    
    try:
        joins = join_inference.infer_joins(table1, table2)
        
        print(f"\n✅ Found {len(joins)} possible join(s):\n")

        for i, join in enumerate(joins, 1):
            confidence_emoji = "🟢" if join.confidence >= 0.9 else "🟡" if join.confidence >= 0.7 else "🟠"
            print(f"{confidence_emoji} Option {i}:")
            print(f"   SQL: {join.to_sql_condition()}")
            print(f"   Confidence: {join.confidence:.1%}")
            print(f"   Reasoning: {join.reasoning}")
            print()
            
    except AmbiguityError as e:
        print(f"\n⚠️  Ambiguity Detected!")
        print(f"\nMessage: {e}")
        print(f"\nOptions to choose from:")
        for i, opt in enumerate(e.options, 1):
            print(f"  {i}. {opt}")
        print("\n💡 User would be prompted to select the correct option")

    except Exception as e:
        print(f"❌ Error: {e}")
else:
    print("⚠️  Need at least 2 tables for join inference demo")

### Manual Join Inference

You can also manually specify which tables to analyze:

In [None]:
# Example: Infer joins between specific tables
if "Customers" in schema.tables and "Orders" in schema.tables:
    print("🔍 Analyzing: Customers ↔ Orders")
    print("=" * 60)
    
    try:
        joins = join_inference.infer_joins("Customers", "Orders")
        
        if joins:
            best_join = joins[0]
            print(f"\n✨ Best join found:")
            print(f"   {best_join.to_sql_condition()}")
            print(f"   Confidence: {best_join.confidence:.1%}")
            
            # Show as SQL
            sql_example = f"""
SELECT c.customer_name, COUNT(o.order_id) as total_orders
FROM Customers c
JOIN Orders o ON {best_join.to_sql_condition()}
GROUP BY c.customer_name
"""
            print(f"\n📝 Example SQL usage:")
            print(sql_example)
    except Exception as e:
        print(f"Note: {e}")
else:
    print("⚠️  Customers and Orders tables not found in schema")

## 4. Session Management

Sessions track the entire conversation and can be saved/resumed at any time.
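
How `session_manager` persists state is not shown here. As a rough sketch of the save/resume idea, assuming sessions are stored as one JSON file per session ID, the round trip could look like the following. `DemoSession`, `save_session`, and `load_session` are hypothetical stand-ins defined only for this example and do not reflect the real classes in `src`.

```python
import json
import tempfile
import uuid
from dataclasses import asdict, dataclass, field
from pathlib import Path


@dataclass
class DemoSession:
    """Hypothetical, simplified session record."""
    original_query: str
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    messages: list = field(default_factory=list)
    iteration_count: int = 0


def save_session(session: DemoSession, directory) -> Path:
    """Write the session to <directory>/<session_id>.json and return the path."""
    path = Path(directory) / f"{session.session_id}.json"
    path.write_text(json.dumps(asdict(session)), encoding="utf-8")
    return path


def load_session(session_id: str, directory) -> DemoSession:
    """Rebuild a session from its JSON file."""
    path = Path(directory) / f"{session_id}.json"
    return DemoSession(**json.loads(path.read_text(encoding="utf-8")))


# Round trip through a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    original = DemoSession("Show me the top 5 customers")
    original.messages.append({"role": "user", "content": original.original_query})
    save_session(original, tmp)
    restored = load_session(original.session_id, tmp)
```

Keeping the stored form plain JSON means a saved session can be inspected or repaired by hand, at the cost of having to serialize richer objects (such as a state machine) explicitly.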

In [None]:
# Create a new session
user_query = "Show me the top 5 customers by total order amount in Q4 2025"

print(f"💬 User Query: '{user_query}'")
print("=" * 60)

session = session_manager.create_session(user_query)

print(f"\n✅ Session created")
print(f"   Session ID: {session.session_id}")
print(f"   Status: {session.status}")
print(f"   Created: {session.created_at.strftime('%Y-%m-%d %H:%M:%S')}")

In [None]:
# Simulate agent workflow
print("🤖 Agent Workflow Simulation")
print("=" * 60)

# Step 1: Add user message
session.add_message("user", user_query)
print("1️⃣  Added user message")

# Step 2: Transition to query understanding
session.state_machine.transition_to(
    AgentState.QUERY_UNDERSTANDING,
    reason="Starting query analysis"
)
print(f"2️⃣  State: {session.state_machine.current_state.value}")

# Step 3: Identify tables
session.identified_tables = ["Customers", "Orders"]
session.add_intermediate_result(
    "identified_tables",
    {"tables": session.identified_tables, "confidence": 0.95}
)
print(f"3️⃣  Identified tables: {session.identified_tables}")

# Step 4: Join inference
session.state_machine.transition_to(
    AgentState.JOIN_INFERENCE,
    reason="Inferring table joins"
)

if len(session.identified_tables) >= 2:
    try:
        joins = join_inference.infer_joins(
            session.identified_tables[0],
            session.identified_tables[1]
        )
        session.inferred_joins = [j.to_dict() for j in joins]
        print(f"4️⃣  Inferred {len(joins)} join(s)")
    except Exception:
        print("4️⃣  Join inference skipped (demo mode)")

# Step 5: Increment iteration
session.increment_iteration()
print(f"5️⃣  Iteration: {session.iteration_count}")

# Step 6: Save session
session_manager.save_session(session)
print(f"6️⃣  Session saved to disk")

print(f"\n✅ Workflow complete!")

In [None]:
# View session details
print("📊 Session Details")
print("=" * 60)
print(f"Session ID: {session.session_id}")
print(f"Status: {session.status}")
print(f"Current State: {session.state_machine.current_state.value}")
print(f"Iterations: {session.iteration_count}")
print(f"Messages: {len(session.messages)}")
print(f"Identified Tables: {session.identified_tables}")
print(f"Inferred Joins: {len(session.inferred_joins)}")
print(f"\nState Transitions:")
for i, trans in enumerate(session.state_machine.get_transition_history(), 1):
    print(f"  {i}. {trans['from_state']} → {trans['to_state']}")
    if trans['reason']:
        print(f"     Reason: {trans['reason']}")

### List All Sessions

In [None]:
# List recent sessions
sessions = session_manager.list_sessions(limit=10)

print(f"📋 Recent Sessions ({len(sessions)} found)")
print("=" * 80)

if sessions:
    # Create DataFrame for nice display
    df = pd.DataFrame(sessions)
    df['session_id'] = df['session_id'].str[:8] + '...'  # Truncate for display
    df['query'] = df['query'].str[:50] + '...'  # Truncate long queries
    display(df)
else:
    print("No sessions found")

### Resume a Session

In [None]:
# Resume the session we just created
print(f"🔄 Resuming session: {session.session_id[:8]}...")
print("=" * 60)

resumed_session = session_manager.load_session(session.session_id)

print(f"✅ Session resumed successfully!")
print(f"   Query: {resumed_session.original_query}")
print(f"   State: {resumed_session.state_machine.current_state.value}")
print(f"   Iteration: {resumed_session.iteration_count}")
print(f"   Messages: {len(resumed_session.messages)}")
print(f"\n💡 The session can now continue from where it left off")

## 5. Correction System

Users can provide corrections to guide the agent when it makes mistakes or encounters ambiguity.
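
The internals of `CorrectionParser` are not shown in this notebook. To make the idea concrete, here is a minimal sketch of how the two short-hand forms used below ("join A.x with B.y" and "term means Table.column") could be recognized with regular expressions, with everything else kept as free text. The function `parse_correction`, the returned dictionary shape, and the type labels are assumptions for illustration only.

```python
import re

# Patterns for the two short-hand correction forms (illustrative only)
JOIN_RE = re.compile(r"^join\s+(\w+)\.(\w+)\s+with\s+(\w+)\.(\w+)$", re.IGNORECASE)
MAPPING_RE = re.compile(r"^(\w+)\s+means\s+(\w+)\.(\w+)$", re.IGNORECASE)


def parse_correction(text: str) -> dict:
    """Classify a user correction and turn it into a constraint string."""
    text = text.strip()

    match = JOIN_RE.match(text)
    if match:
        left_table, left_col, right_table, right_col = match.groups()
        return {
            "type": "join",
            "constraint": f"{left_table}.{left_col} = {right_table}.{right_col}",
        }

    match = MAPPING_RE.match(text)
    if match:
        term, table, column = match.groups()
        return {
            "type": "column_mapping",
            "constraint": f"'{term}' refers to {table}.{column}",
        }

    # Anything else is passed through unchanged as a natural-language hint
    return {"type": "free_text", "constraint": text}


join_fix = parse_correction("join Customers.customer_id with Orders.customer_id")
mapping_fix = parse_correction("region means Customers.geographic_area")
other_fix = parse_correction("Use the customer_id field from Orders table")
```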

In [None]:
print("🔧 User Correction Examples")
print("=" * 60)

# Create a session for corrections demo
correction_session = session_manager.create_session("Demo for corrections")

# Example 1: Join clarification
print("\n1️⃣  Join Clarification")
correction1 = CorrectionParser.parse("join Customers.customer_id with Orders.customer_id")
print(f"   Input: 'join Customers.customer_id with Orders.customer_id'")
print(f"   Type: {correction1.correction_type.value}")
print(f"   Content: {correction1.content}")
print(f"   Constraint: {correction1.to_constraint_string()}")

correction_session.add_correction(correction1)

# Example 2: Column mapping
print("\n2️⃣  Column Mapping")
correction2 = CorrectionParser.parse("region means Customers.geographic_area")
print(f"   Input: 'region means Customers.geographic_area'")
print(f"   Type: {correction2.correction_type.value}")
print(f"   Content: {correction2.content}")
print(f"   Constraint: {correction2.to_constraint_string()}")

correction_session.add_correction(correction2)

# Example 3: Natural language correction
print("\n3️⃣  Natural Language Correction")
correction3 = CorrectionParser.parse(
    "Use the customer_id field from Orders table, not the account_number field"
)
print(f"   Input: 'Use the customer_id field from Orders table...'")
print(f"   Type: {correction3.correction_type.value}")
print(f"   Content: {correction3.content}")
print(f"   Constraint: {correction3.to_constraint_string()}")

correction_session.add_correction(correction3)

In [None]:
# View all corrections in session
print("\n📝 Session Corrections Summary")
print("=" * 60)
print(f"Total corrections: {len(correction_session.corrections)}")
print(f"\nHard constraints (applied to LLM prompts):")
for i, constraint in enumerate(correction_session.hard_constraints, 1):
    print(f"  {i}. {constraint}")

print("\n💡 These constraints will be included in all future LLM prompts")
print("   to ensure the agent follows user's specifications.")

### Structured Correction Format

In [None]:
# Corrections can also be provided in structured format
print("🎯 Structured Correction Format")
print("=" * 60)

structured_correction = {
    "type": "join",
    "tables": ["Orders", "Products"],
    "join_condition": "Orders.product_id = Products.product_id",
    "description": "Correct join for order-product relationship"
}

correction = CorrectionParser.parse_dict(structured_correction)

print(f"Input (JSON):")
print(json.dumps(structured_correction, indent=2))
print(f"\nParsed correction:")
print(f"  Type: {correction.correction_type.value}")
print(f"  Constraint: {correction.to_constraint_string()}")

## 6. BigQuery Integration

Execute and validate SQL queries against BigQuery.

**Note**: These operations require valid BigQuery credentials.

In [None]:
# Example query (modify for your schema)
# Note: the __TABLES__ metadata table exposes table_id, not table_name
test_query = f"""
SELECT
    table_id,
    row_count
FROM `{settings.get('bigquery.project_id')}.{settings.get('bigquery.dataset')}.__TABLES__`
LIMIT 5
"""

print("🗄️  BigQuery Operations Demo")
print("=" * 60)
print(f"\nQuery:")
print(test_query)

### Step 1: Validate Query

In [None]:
try:
    print("1️⃣  Validating query...")
    validation = bigquery_client.validate_query(test_query)

    if validation["success"]:
        print(f"   ✅ Query is valid")
        print(f"   📊 Bytes to process: {validation.get('bytes_processed', 0):,}")
    else:
        print(f"   ❌ Validation failed: {validation.get('error')}")
except Exception as e:
    print(f"   ⚠️  Validation skipped: {e}")
    print("   (Make sure BigQuery credentials are configured)")

### Step 2: Estimate Cost
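
The cell in this step prints `estimated_cost_usd` and `readable_size`. How `bigquery_client` derives them is not shown, so the arithmetic below is only a sketch of one plausible approach: scale the bytes a query would scan by a per-TiB price. The function names are hypothetical, and the price is passed in as a parameter because it varies by region and pricing model. The value used in the example is purely illustrative.

```python
def readable_size(num_bytes: int) -> str:
    """Format a byte count using binary units (B, KiB, MiB, GiB, TiB)."""
    size = float(num_bytes)
    for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
        # Stop at the first unit that fits, or at TiB as the largest unit
        if size < 1024 or unit == "TiB":
            return f"{size:.2f} {unit}"
        size /= 1024


def estimate_cost_usd(num_bytes: int, price_per_tib_usd: float) -> float:
    """Estimate cost as (bytes scanned / bytes per TiB) * price per TiB."""
    return num_bytes / 2**40 * price_per_tib_usd


# Illustrative numbers only: 250 GiB scanned at an assumed 5.00 USD per TiB
scanned = 250 * 2**30
size_label = readable_size(scanned)
cost = estimate_cost_usd(scanned, price_per_tib_usd=5.0)
```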

In [None]:
try:
    print("2️⃣  Estimating query cost...")
    cost_info = bigquery_client.estimate_query_cost(test_query)

    if cost_info["success"]:
        print(f"   💰 Estimated cost: ${cost_info['estimated_cost_usd']:.6f}")
        print(f"   📦 Data size: {cost_info['readable_size']}")
    else:
        print(f"   ❌ Cost estimation failed: {cost_info.get('error')}")
except Exception as e:
    print(f"   ⚠️  Cost estimation skipped: {e}")

### Step 3: Execute Query

In [None]:
try:
    print("3️⃣  Executing query...")
    result = bigquery_client.execute_query(test_query, max_results=10)

    if result["success"]:
        print(f"   ✅ Query successful!")
        print(f"   📊 Rows returned: {result['row_count']}")
        print(f"   📦 Bytes processed: {result['bytes_processed']:,}")

        # Display results as DataFrame
        if result['rows']:
            print(f"\n   Results:")
            df_results = pd.DataFrame(result['rows'])
            display(df_results)
    else:
        print(f"   ❌ Query failed: {result['error']}")
        print(f"   Error type: {result.get('error_type')}")

except Exception as e:
    print(f"   ⚠️  Execution skipped: {e}")
    print(f"\n   💡 To enable BigQuery operations:")
    print(f"      1. Set up Google Cloud credentials")
    print(f"      2. Configure GOOGLE_APPLICATION_CREDENTIALS in .env")
    print(f"      3. Set GCP_PROJECT_ID and BIGQUERY_DATASET")

## 7. Azure OpenAI Integration with Retry

Make LLM calls with automatic retry on failures.
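
The retry policy inside `azure_client` is not shown in this notebook. A common pattern for this kind of recovery is exponential backoff: retry transient failures with a delay that doubles each time, then give up with a dedicated exception. The sketch below illustrates that pattern under stated assumptions. `call_with_retry` and `RetryExhausted` are hypothetical names (the latter is a stand-in for the notebook's `RetryExhaustedError`), and treating `ConnectionError` as the only transient failure is a simplification.

```python
import time


class RetryExhausted(Exception):
    """Hypothetical stand-in for the notebook's RetryExhaustedError."""


def call_with_retry(fn, max_attempts=3, base_delay=0.5, sleep=time.sleep):
    """Call fn(), retrying on ConnectionError with exponentially growing delays."""
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except ConnectionError as exc:  # assumed to be a transient failure
            last_error = exc
            if attempt < max_attempts:
                # Delays grow as base_delay, 2 * base_delay, 4 * base_delay, ...
                sleep(base_delay * 2 ** (attempt - 1))
    raise RetryExhausted(f"gave up after {max_attempts} attempts") from last_error


# Simulated flaky call: fails twice, then succeeds
calls = {"count": 0}


def flaky():
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("temporary outage")
    return "ok"


delays = []
result = call_with_retry(flaky, sleep=delays.append)  # record delays instead of sleeping
```

Passing `sleep` as a parameter keeps the example fast to run and easy to test. In real use the default `time.sleep` applies.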

In [None]:
print("🤖 Azure OpenAI Demo")
print("=" * 60)

# Create a session for this demo
llm_session = session_manager.create_session("Test LLM capabilities")

test_prompt = "Explain what a database foreign key is in one sentence."

print(f"\n📝 Prompt: '{test_prompt}'")
print(f"\n🔄 Making API call with automatic retry...\n")

try:
    response = azure_client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": "You are a helpful database assistant."
            },
            {
                "role": "user",
                "content": test_prompt
            }
        ],
        session=llm_session,
        temperature=0.0,
    )
    
    print("✅ Response received:")
    print("=" * 60)
    print(response)
    print("=" * 60)
    
    print(f"\n📊 Session updated:")
    print(f"   Messages: {len(llm_session.messages)}")
    
except RetryExhaustedError as e:
    print(f"❌ All retry attempts failed: {e}")
    print(f"\n💾 Session {llm_session.session_id[:8]}... has been saved")
    print(f"   You can resume it later when the service is available")
    print(f"\n   Command: session_manager.load_session('{llm_session.session_id}')")
    
except FatalError as e:
    print(f"❌ Non-recoverable error: {e}")
    print(f"\n💡 Check your Azure OpenAI configuration:")
    print(f"   - AZURE_OPENAI_ENDPOINT")
    print(f"   - AZURE_OPENAI_API_KEY")
    print(f"   - AZURE_OPENAI_DEPLOYMENT")

except Exception as e:
    print(f"⚠️  Could not make LLM call: {e}")
    print(f"\n💡 This is expected if Azure OpenAI is not configured.")
    print(f"   The retry mechanism would handle temporary failures automatically.")

## 8. End-to-End Example

Putting it all together: Complete workflow from query to SQL generation.

In [None]:
print("🎯 End-to-End Workflow")
print("=" * 80)

# User query
user_query = "What are the top 5 customers by total spending?"
print(f"\n👤 User Query: '{user_query}'")
print("\n" + "="*80)

# Step 1: Create session
print("\n1️⃣  Creating session...")
workflow_session = session_manager.create_session(user_query)
workflow_session.add_message("user", user_query)
print(f"   ✅ Session ID: {workflow_session.session_id[:8]}...")

# Step 2: Load schema
print("\n2️⃣  Loading schema...")
workflow_session.schema_snapshot = schema.to_dict()
workflow_session.state_machine.transition_to(AgentState.SCHEMA_LOADING)
print(f"   ✅ Loaded {len(schema.tables)} tables")

# Step 3: Identify relevant tables
print("\n3️⃣  Identifying relevant tables...")
workflow_session.state_machine.transition_to(AgentState.QUERY_UNDERSTANDING)
workflow_session.identified_tables = ["Customers", "Orders"]  # Would be done by LLM
print(f"   ✅ Identified: {workflow_session.identified_tables}")

# Step 4: Infer joins
print("\n4️⃣  Inferring table joins...")
workflow_session.state_machine.transition_to(AgentState.JOIN_INFERENCE)
try:
    if len(workflow_session.identified_tables) >= 2:
        joins = join_inference.infer_joins(
            workflow_session.identified_tables[0],
            workflow_session.identified_tables[1]
        )
        workflow_session.inferred_joins = [j.to_dict() for j in joins]
        print(f"   ✅ Found {len(joins)} join(s)")
        if joins:
            print(f"   📊 Best: {joins[0].to_sql_condition()} (confidence: {joins[0].confidence:.1%})")
except Exception as e:
    print(f"   ⚠️  Using demo joins: {e}")
    workflow_session.inferred_joins = [{
        "left_table": "Customers",
        "right_table": "Orders",
        "left_column": "customer_id",
        "right_column": "customer_id",
        "confidence": 0.95
    }]

# Step 5: Generate SQL (simulated)
print("\n5️⃣  Generating SQL query...")
workflow_session.state_machine.transition_to(AgentState.GENERATING_SQL)

# Simulated SQL generation (would normally use LLM)
if workflow_session.inferred_joins:
    join_info = workflow_session.inferred_joins[0]
    generated_sql = f"""
SELECT
    c.customer_name,
    SUM(o.amount) as total_spending
FROM {schema.dataset}.Customers c
JOIN {schema.dataset}.Orders o
    ON c.{join_info['left_column']} = o.{join_info['right_column']}
GROUP BY c.customer_name
ORDER BY total_spending DESC
LIMIT 5
"""
else:
    generated_sql = "-- SQL generation would happen here"

workflow_session.add_sql_attempt(generated_sql, success=True)
workflow_session.increment_iteration()

print(f"   ✅ SQL generated")

# Step 6: Complete workflow
print("\n6️⃣  Finalizing...")
workflow_session.state_machine.transition_to(AgentState.COMPLETED)
session_manager.save_session(workflow_session)
print(f"   ✅ Session saved with status: {workflow_session.status}")

print("\n" + "="*80)
print("✨ Workflow Complete!")
print("="*80)

In [None]:
# Display generated SQL
print("\n📝 Generated SQL Query:")
print("=" * 80)
print(generated_sql)
print("=" * 80)

In [None]:
# Display workflow summary
print("\n📊 Workflow Summary")
print("=" * 80)

summary_data = {
    "Session ID": workflow_session.session_id,
    "Original Query": workflow_session.original_query,
    "Final State": workflow_session.state_machine.current_state.value,
    "Status": workflow_session.status,
    "Iterations": workflow_session.iteration_count,
    "Tables Identified": len(workflow_session.identified_tables),
    "Joins Inferred": len(workflow_session.inferred_joins),
    "SQL Attempts": len(workflow_session.sql_attempts),
    "Corrections Applied": len(workflow_session.corrections),
}

for key, value in summary_data.items():
    print(f"  {key:.<40} {value}")

print("\n📈 State Transition History:")
for i, trans in enumerate(workflow_session.state_machine.get_transition_history(), 1):
    print(f"  {i}. {trans['from_state']:.<25} → {trans['to_state']}")

## Summary

This notebook demonstrated the key capabilities of the Text-to-SQL Agent:

✅ **Schema Management** - Load and explore database metadata from Excel  
✅ **Semantic Join Inference** - Automatically find table relationships  
✅ **Session Persistence** - Save and resume agent state  
✅ **User Corrections** - Learn from feedback and improve  
✅ **BigQuery Integration** - Validate and execute queries  
✅ **Error Recovery** - Handle API failures gracefully  
✅ **End-to-End Workflow** - Complete query processing pipeline  

## Next Steps

1. **Configure your environment** with real credentials
2. **Prepare your schema** Excel file
3. **Try real queries** against your BigQuery dataset
4. **Extend the system** with custom reasoning modules
5. **Build a UI** on top of this framework

---

**For more information, see:**
- [README.md](../README.md) - Setup and usage guide
- [ARCHITECTURE.md](../ARCHITECTURE.md) - System design documentation
- [example_usage.py](example_usage.py) - Python examples