# Evaluate AutoGenText2Sql

This notebook evaluates the AutoGenText2Sql class using the Spider test suite evaluation metric. 

The evaluation uses the official Spider evaluation approach, which requires:

1. A gold file with format: `SQL query \t database_id`
2. A predictions file with the generated SQL queries, one per line, in the same order as the gold file
3. The Spider databases and schema information

### Required Data Downloads

Before running this notebook, you need to download and set up two required directories:

1. Spider Test Suite Evaluation Scripts:
   - Download from: https://github.com/taoyds/test-suite-sql-eval
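   - Clone this repository into the `text_2_sql/test-suite-sql-eval/` directory (the notebook looks for it at `../test-suite-sql-eval`, relative to the directory it is run from):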
   ```bash
   cd text_2_sql
   git clone https://github.com/taoyds/test-suite-sql-eval
   ```

2. Spider Dataset:
   - Download from: https://drive.google.com/file/d/1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J/view
   - Extract the downloaded file into the `text_2_sql/spider_data/` directory (the notebook looks for it at `../spider_data`, relative to the directory it is run from)
   - The directory should contain:
     - `database/` directory with all the SQLite databases
     - `tables.json` with schema information
     - `dev.json` with development set queries

### Dependencies

To install dependencies for this evaluation:

`uv sync --package autogen_text_2_sql`

`uv add --editable text_2_sql_core`

In [None]:
import sys
import os
import time
import json
import logging
import subprocess
import dotenv
from pathlib import Path

# Directory the notebook is being run from
notebook_dir = Path().absolute()
# Make the local "src" directory importable; the project imports below rely on it
sys.path.append(str(notebook_dir / "src"))

from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload
from autogen_text_2_sql.evaluation_utils import get_final_sql_query

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Set up paths. Both roots are made absolute up front so that the files written
# and the evaluation subprocess do not depend on the working directory staying
# the same for the rest of the session.
TEST_SUITE_DIR = Path("../test-suite-sql-eval").absolute()
SPIDER_DATA_DIR = Path("../spider_data").absolute()
DATABASE_DIR = SPIDER_DATA_DIR / "database"

# Set SPIDER_DATA_DIR in environment so SQLiteSqlConnector can find tables.json
os.environ["SPIDER_DATA_DIR"] = str(SPIDER_DATA_DIR)

# Load environment variables
dotenv.load_dotenv()

In [2]:
# SQLite-specific guidance passed to AutoGenText2Sql as its engine rules
sqlite_rules = """
1. Use SQLite syntax
2. Do not use Azure SQL specific functions
3. Use strftime for date/time operations
"""

# Shared AutoGenText2Sql instance that generate_sql sends questions to
autogen_text2sql = AutoGenText2Sql(
    use_case="Evaluating with Spider SQLite databases",
    engine_specific_rules=sqlite_rules,
)

In [3]:
# Placeholder returned (and ignored as a candidate) when no SQL query is available
_NO_QUERY_SQL = "SELECT NULL -- No query found"


# Function to generate SQL for a given question
async def generate_sql(question):
    """Generate a single SQL query for a natural-language question.

    Streams the messages yielded by ``autogen_text2sql.process_user_message``
    and, for every ``answer_with_sources`` payload, collects the SQL strings
    found in ``body.results`` and then in ``body.sources``. One of the
    collected queries is then chosen by ``_select_final_query``.

    Args:
        question: The natural-language question to translate.

    Returns:
        The selected SQL string, or ``"SELECT NULL -- No query found"`` when
        no usable query was produced.
    """
    logger.info(f"Processing question: {question}")
    logger.info("Chat history: None")

    # Every usable SQL string seen, in the order it was reported
    all_queries = []

    payload = UserMessagePayload(user_message=question)
    async for message in autogen_text2sql.process_user_message(payload):
        if message.payload_type != "answer_with_sources":
            continue
        body = message.body

        # Extract from results
        if hasattr(body, "results"):
            for q_results in body.results.values():
                for result in q_results:
                    if not (isinstance(result, dict) and "sql_query" in result):
                        continue
                    sql_query = _usable_sql(result["sql_query"])
                    if sql_query is not None:
                        all_queries.append(sql_query)
                        logger.info(f"Found SQL query in results: {sql_query}")

        # Extract from sources
        if hasattr(body, "sources"):
            for source in body.sources:
                if not hasattr(source, "sql_query"):
                    continue
                sql_query = _usable_sql(source.sql_query)
                if sql_query is not None:
                    all_queries.append(sql_query)
                    logger.info(f"Found SQL query in sources: {sql_query}")

    if all_queries:
        logger.info(f"All queries found: {all_queries}")

    final_query = _select_final_query(question, all_queries) or _NO_QUERY_SQL
    logger.info(f"Final SQL query: {final_query}")
    return final_query


def _usable_sql(candidate):
    """Return ``candidate`` stripped, or None if it is not a real SQL string.

    Non-string values, blank strings and the "no query found" placeholder are
    all rejected.
    """
    if not isinstance(candidate, str):
        return None
    cleaned = candidate.strip()
    if not cleaned or cleaned == _NO_QUERY_SQL:
        return None
    return cleaned


def _select_final_query(question, candidate_queries):
    """Choose the query to report from ``candidate_queries`` (None if empty).

    The last candidate wins by default. When the question contains one of the
    words "different", "distinct", "unique" or "all" (matched as whole words,
    so "small" does not count), the most recent candidate that already
    contains DISTINCT is preferred; failing that, DISTINCT is inserted after
    the leading SELECT of the last candidate.
    """
    import re

    if not candidate_queries:
        return None

    needs_distinct = (
        re.search(r"\b(?:different|distinct|unique|all)\b", question.lower())
        is not None
    )
    if not needs_distinct:
        return candidate_queries[-1]

    # Prefer the latest query that already asks for distinct rows
    for query in reversed(candidate_queries):
        if "DISTINCT" in query.upper():
            return query

    final_query = candidate_queries[-1]
    if final_query.upper().startswith("SELECT "):
        # Splice by position rather than str.replace so the leading keyword is
        # the one modified whatever its casing, and never a later "SELECT ".
        head_len = len("SELECT ")
        final_query = f"{final_query[:head_len]}DISTINCT {final_query[head_len:]}"
    return final_query

In [4]:
# Function to read Spider dev set and generate predictions
async def generate_predictions(num_samples=None):
    """Generate SQL for Spider dev questions and write the evaluation inputs.

    Reads ``dev.json`` from ``SPIDER_DATA_DIR``, calls ``generate_sql`` for
    each question, and writes ``predictions.txt`` and ``gold.txt`` into
    ``TEST_SUITE_DIR``. Each output line has the form ``SQL<TAB>db_id`` and
    the two files list the questions in the same order.

    Before each question the ``Text2Sql__DatabaseConnectionString`` and
    ``Text2Sql__DatabaseName`` environment variables are pointed at the
    question's database.

    Args:
        num_samples: If given, only the first ``num_samples`` dev entries are
            processed; ``None`` processes the whole dev set.

    Returns:
        A ``(pred_file, gold_file)`` tuple with the paths that were written.
    """
    # Read Spider dev set
    dev_file = SPIDER_DATA_DIR / "dev.json"
    pred_file = TEST_SUITE_DIR / "predictions.txt"
    gold_file = TEST_SUITE_DIR / "gold.txt"

    print(f"Reading dev queries from {dev_file}")
    with open(dev_file, encoding="utf-8") as f:
        dev_data = json.load(f)

    # Limit number of samples if specified
    if num_samples is not None:
        dev_data = dev_data[:num_samples]
        # Report the real count: the dev set may hold fewer than num_samples
        print(f"\nGenerating predictions for {len(dev_data)} queries...")
    else:
        print(f"\nGenerating predictions for all {len(dev_data)} queries...")

    predictions = []
    gold = []

    for idx, item in enumerate(dev_data, 1):
        question = item['question']
        db_id = item['db_id']
        gold_query = item['query']

        print(f"\nProcessing query {idx}/{len(dev_data)} for database {db_id}")
        print(f"Question: {question}")

        # Update database connection string for current database
        db_path = DATABASE_DIR / db_id / f"{db_id}.sqlite"
        os.environ["Text2Sql__DatabaseConnectionString"] = str(db_path)
        os.environ["Text2Sql__DatabaseName"] = db_id

        sql = await generate_sql(question)
        # Records are one per line and tab-separated, so a query containing a
        # newline or tab would shift every following record out of alignment.
        predictions.append(f"{_single_line_sql(sql)}\t{db_id}")
        gold.append(f"{_single_line_sql(gold_query)}\t{db_id}")
        print(f"Generated SQL: {sql}")

    print(f"\nSaving predictions to {pred_file}")
    with open(pred_file, 'w', encoding="utf-8") as f:
        f.write('\n'.join(predictions))

    print(f"Saving gold queries to {gold_file}")
    with open(gold_file, 'w', encoding="utf-8") as f:
        f.write('\n'.join(gold))

    return pred_file, gold_file


def _single_line_sql(sql):
    """Return ``sql`` on one line with tabs and line breaks turned into spaces.

    NOTE(review): a ``--`` line comment inside a multi-line query would end up
    commenting out the rest of the flattened query; not handled here.
    """
    pieces = (piece.strip() for piece in sql.replace("\t", " ").splitlines())
    return " ".join(piece for piece in pieces if piece)

In [5]:
# Run evaluation using the test suite evaluation script
def run_evaluation():
    """Run the test-suite ``evaluation.py`` script and print its report.

    The script is given ``gold.txt`` and ``predictions.txt`` from
    ``TEST_SUITE_DIR``, the databases in ``DATABASE_DIR`` and ``tables.json``
    from ``SPIDER_DATA_DIR``. Its standard output is printed together with the
    wall-clock duration. If the script exits with a non-zero status, the exit
    code and its standard error are printed as well.

    Returns:
        None. All results are reported via ``print``.
    """
    gold_file = TEST_SUITE_DIR / "gold.txt"
    pred_file = TEST_SUITE_DIR / "predictions.txt"
    table_file = SPIDER_DATA_DIR / "tables.json"  # Use Spider's schema file

    print(f"Starting evaluation at {time.strftime('%H:%M:%S')}")
    start_time = time.time()

    cmd = [
        # Run under the interpreter executing this notebook so the script sees
        # the same environment; fall back to PATH lookup if it is unknown.
        sys.executable or "python",
        str(TEST_SUITE_DIR / "evaluation.py"),
        "--gold", str(gold_file),
        "--pred", str(pred_file),
        "--db", str(DATABASE_DIR),
        "--table", str(table_file),
        "--etype", "all",
        "--plug_value",
        "--progress_bar_for_each_datapoint"  # Show progress for each test input
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)

    end_time = time.time()
    duration = end_time - start_time

    print("\nEvaluation Results:")
    print("==================")
    print(result.stdout)

    # Without this a crashed script would just show an empty report above.
    # stderr is shown only on failure, to avoid printing it on every run.
    if result.returncode != 0:
        print(f"Evaluation script exited with code {result.returncode}")
        if result.stderr:
            print("Error output:")
            print(result.stderr)

    print(f"\nEvaluation completed in {duration:.2f} seconds")
    print(f"End time: {time.strftime('%H:%M:%S')}")

In [None]:
# Generate predictions first. num_samples limits the run to the first N dev-set
# questions; pass num_samples=None to process the whole dev set.
await generate_predictions(num_samples=20)  # Generate predictions for just 20 samples (takes about 4 minutes)

In [None]:
# Run evaluation on the gold.txt / predictions.txt files written by generate_predictions
run_evaluation()