`uv sync --package autogen_text_2_sql`
`uv add --editable text_2_sql_core`


# Evaluate AutoGenText2SQL

This notebook evaluates the AutoGenText2Sql class using the Spider test suite evaluation metric. 

The evaluation uses the official Spider evaluation approach, which requires:

1. A gold file with format: `SQL query \t database_id`
2. A predictions file with generated SQL queries
3. The Spider databases and schema information

### Required Data Downloads

Before running this notebook, you need to download and set up two required directories:

1. Spider Test Suite Evaluation Scripts:
   - Download from: https://github.com/taoyds/test-suite-sql-eval
   - Clone this repository into the `/text_2_sql/test-suite-sql-eval/` directory:
   ```bash
   cd text_2_sql
   git clone https://github.com/taoyds/test-suite-sql-eval
   ```

2. Spider Dataset:
   - Download from: https://drive.google.com/file/d/1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J/view
   - Extract the downloaded file into the `/text_2_sql/spider_data/` directory
   - The directory should contain:
     - `database/` directory with all the SQLite databases
     - `tables.json` with schema information
     - `dev.json` with development set queries

In [1]:
import sys
import os
import time
import json
import logging
import subprocess
import sqlite3
import dotenv
from pathlib import Path

# Resolve the directory the notebook is executed from (the current working directory)
notebook_dir = Path().absolute()
# Append the local "src" directory to sys.path; this must happen before the
# package imports below
sys.path.append(str(notebook_dir / "src"))

from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload
from autogen_text_2_sql.state_store import InMemoryStateStore
# NOTE(review): get_final_sql_query is not referenced by the cells visible here
from autogen_text_2_sql.evaluation_utils import get_final_sql_query

# Configure logging: INFO level on the root logger, plus a module-level logger
# that the helper functions in later cells use
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up paths
# TEST_SUITE_DIR is left relative, so it resolves against the working directory
# at the time it is used; SPIDER_DATA_DIR is made absolute immediately.
TEST_SUITE_DIR = Path("../test-suite-sql-eval")
SPIDER_DATA_DIR = Path("../spider_data").absolute()
DATABASE_DIR = SPIDER_DATA_DIR / "database"

# Set SPIDER_DATA_DIR in environment so SQLiteSqlConnector can find tables.json
os.environ["SPIDER_DATA_DIR"] = str(SPIDER_DATA_DIR)

# Load environment variables (python-dotenv looks for a .env file)
dotenv.load_dotenv()

# Initialize state store and AutoGenText2Sql instance with SQLite-specific rules.
# The rules text is passed verbatim as engine_specific_rules.
state_store = InMemoryStateStore()
sqlite_rules = """
1. Use SQLite syntax
2. Do not use Azure SQL specific functions
3. Use strftime for date/time operations
"""

# Single shared instance used by generate_sql() for every question
autogen_text2sql = AutoGenText2Sql(
    state_store=state_store,
    engine_specific_rules=sqlite_rules,
    use_case="Evaluating with Spider SQLite databases"
)

In [2]:
def combine_aggregation_queries(queries):
    """Combine multiple aggregation queries into a single query.

    The FROM clause (and everything after it) is taken from the first query;
    the select lists of all queries are joined with commas in front of it.

    Args:
        queries: Sequence of SQL strings of the form ``SELECT <aggs> FROM ...``.

    Returns:
        ``None`` when ``queries`` is empty; the first query unchanged when it
        has no FROM clause (nothing to combine against); otherwise the
        combined ``SELECT <agg1>, <agg2>, ... FROM ...`` string.
    """
    import re  # local import keeps this notebook cell self-contained

    if not queries:
        return None

    # Whole-word, case-insensitive keyword matches: a plain substring search
    # would also hit identifiers such as "from_date" or "platform".
    from_keyword = re.compile(r"\bfrom\b", re.IGNORECASE)
    leading_select = re.compile(r"^\s*select\b", re.IGNORECASE)

    base_query = queries[0]
    base_from = from_keyword.search(base_query)
    if base_from is None:
        return base_query  # Can't combine if no FROM clause

    table_and_condition = base_query[base_from.start():]

    # Extract all aggregations while preserving case and aliases
    aggs = []
    for query in queries:
        from_match = from_keyword.search(query)
        if from_match is None:
            # No FROM clause: there is no select list we can safely lift out.
            continue
        select_part = query[:from_match.start()]
        # Strip only the leading SELECT keyword, whatever its case.
        agg_part = leading_select.sub("", select_part, count=1).strip()
        if agg_part:
            aggs.append(agg_part)

    if not aggs:
        return base_query

    # Combine into a single query while preserving case
    return f"SELECT {', '.join(aggs)} {table_and_condition}"

def execute_query(query, db_path):
    """Run ``query`` on the SQLite database at ``db_path`` and type its first row.

    Only the first row of the result set is inspected. Each column value is
    normalised: floats are rounded to 2 decimals, ints pass through, other
    values are converted to int, then to a rounded float, and are otherwise
    kept as they are. NULLs stay ``None``.

    Returns:
        A dict mapping column name to the normalised value, or ``None`` when
        the query yields no rows or any exception occurs (the error is logged).
    """
    connection = None
    try:
        connection = sqlite3.connect(db_path)
        cur = connection.cursor()
        cur.execute(query)
        rows = cur.fetchall()

        if not rows:
            return None

        first_row = rows[0]
        typed_row = {}
        # cursor.description holds one entry per column; item 0 is its name
        for position, description in enumerate(cur.description):
            name = description[0]
            raw = first_row[position]

            if raw is None:
                typed_row[name] = None
                continue

            try:
                if isinstance(raw, float):
                    # Already numeric: floats are only rounded
                    typed_row[name] = round(float(raw), 2)
                elif isinstance(raw, int):
                    typed_row[name] = raw
                else:
                    # Text/blob values: int first, then float, else leave as-is
                    try:
                        typed_row[name] = int(raw)
                    except ValueError:
                        try:
                            typed_row[name] = round(float(raw), 2)
                        except ValueError:
                            typed_row[name] = raw
            except (ValueError, TypeError) as e:
                logger.warning(f"Error converting value for {name}: {raw}, Error: {str(e)}")
                typed_row[name] = raw

        return typed_row

    except Exception as e:
        logger.error(f"Error executing query: {e}")
        return None
    finally:
        # connection stays None when sqlite3.connect itself failed
        if connection is not None:
            connection.close()

# Function to generate SQL for a given question
async def generate_sql(question):
    """Ask AutoGenText2Sql for SQL answering ``question`` and pick one final query.

    Every ``answer_with_sources`` message produced for the question is scanned
    for SQL, both in ``body.results`` and in ``body.sources``. A single query
    is then selected:

    1. If the question mentions an aggregation keyword and more than one
       query was found, the queries that call an aggregate function are
       merged with ``combine_aggregation_queries``.
    2. Otherwise, if the question asks for distinct values, the most recent
       query containing DISTINCT wins.
    3. Otherwise the last query found is used, with DISTINCT inserted after
       the leading SELECT when the question asks for distinct values.

    Args:
        question: Natural-language question to translate.

    Returns:
        The selected SQL string, or ``"SELECT NULL -- No query found"`` when
        no usable query was produced.
    """
    import re  # local import keeps this notebook cell self-contained

    no_query_found = "SELECT NULL -- No query found"
    # Match real aggregate *calls*; a plain substring test would also fire on
    # identifiers such as COUNTRY, ADMIN or SUMMARY.
    aggregate_call = re.compile(r"\b(?:COUNT|SUM|AVG|MIN|MAX)\s*\(", re.IGNORECASE)

    logger.info(f"Processing question: {question}")
    logger.info("Chat history: None")

    # Track all SQL queries found
    all_queries = []
    final_query = None

    # Check if the question involves aggregation
    question_lower = question.lower()
    agg_keywords = ['average', 'avg', 'minimum', 'min', 'maximum', 'max', 'count', 'sum']
    is_aggregation = any(keyword in question_lower for keyword in agg_keywords)

    # Create a unique thread ID for each question
    thread_id = f"eval_{hash(question)}"
    message_payload = UserMessagePayload(user_message=question)

    async for message in autogen_text2sql.process_user_message(
        thread_id=thread_id,
        message_payload=message_payload
    ):
        if message.payload_type != "answer_with_sources":
            continue

        # Extract queries from results; a missing or None attribute means no queries
        results_by_question = getattr(message.body, 'results', None) or {}
        for q_results in results_by_question.values():
            for result in q_results:
                if isinstance(result, dict) and 'sql_query' in result:
                    sql_query = (result['sql_query'] or '').strip()
                    if sql_query and sql_query != no_query_found:
                        all_queries.append(sql_query)
                        logger.info(f"Found SQL query in results: {sql_query}")

        # Extract queries from sources
        for source in getattr(message.body, 'sources', None) or []:
            if hasattr(source, 'sql_query'):
                sql_query = (source.sql_query or '').strip()
                if sql_query and sql_query != no_query_found:
                    all_queries.append(sql_query)
                    logger.info(f"Found SQL query in sources: {sql_query}")

    # Process queries
    if all_queries:
        logger.info(f"All queries found: {all_queries}")

        if is_aggregation and len(all_queries) > 1:
            # For aggregation questions with multiple queries, try to combine them.
            # The same query can be reported by both results and sources, so
            # de-duplicate (order preserved) to avoid repeating its columns.
            agg_queries = list(dict.fromkeys(
                q for q in all_queries if aggregate_call.search(q)
            ))
            if agg_queries:
                final_query = combine_aggregation_queries(agg_queries)

        if not final_query:
            # If no aggregation combination or not needed, use standard selection
            needs_distinct = any(word in question_lower
                                 for word in ['different', 'distinct', 'unique', 'all'])

            if needs_distinct:
                # Prefer the most recent query that already uses DISTINCT
                for query in reversed(all_queries):
                    if 'DISTINCT' in query.upper():
                        final_query = query
                        break

            if not final_query:
                final_query = all_queries[-1]
                if needs_distinct and final_query[:7].upper() == 'SELECT ':
                    # Insert by position so a lowercase "select" is handled and
                    # a later uppercase SELECT (e.g. a subquery) is left alone.
                    final_query = final_query[:7] + 'DISTINCT ' + final_query[7:]

    # Log final query
    logger.info(f"Final SQL query: {final_query or no_query_found}")

    return final_query or no_query_found

In [3]:
# Function to read Spider dev set and generate predictions
async def generate_predictions(num_samples=None):
    """Generate SQL for Spider dev questions and write prediction/gold files.

    For each dev-set entry the database connection environment variables are
    pointed at that entry's SQLite file, SQL is generated with
    ``generate_sql``, and, when the SQL calls an aggregate function, the query
    is executed once so non-numeric results can be logged as warnings.

    Args:
        num_samples: If given, only the first ``num_samples`` dev entries are
            processed; ``None`` processes the whole dev set.

    Returns:
        ``(pred_file, gold_file)``: paths of the written files. Each line has
        the form ``<sql>\\t<db_id>``.
    """
    import re  # local import keeps this notebook cell self-contained

    # Match real aggregate calls, not identifiers such as COUNTRY or ADMIN
    aggregate_call = re.compile(r"\b(?:COUNT|SUM|AVG|MIN|MAX)\s*\(", re.IGNORECASE)
    # Both output files are line-oriented and tab-separated, so embedded
    # newlines/tabs in a query would misalign predictions with gold.
    to_single_line = str.maketrans({"\n": " ", "\r": " ", "\t": " "})

    # Read Spider dev set
    dev_file = SPIDER_DATA_DIR / "dev.json"
    pred_file = TEST_SUITE_DIR / "predictions.txt"
    gold_file = TEST_SUITE_DIR / "gold.txt"

    print(f"Reading dev queries from {dev_file}")
    with open(dev_file, encoding="utf-8") as f:
        dev_data = json.load(f)

    # Limit number of samples if specified
    if num_samples is not None:
        dev_data = dev_data[:num_samples]
        print(f"\nGenerating predictions for {num_samples} queries...")
    else:
        print(f"\nGenerating predictions for all {len(dev_data)} queries...")

    predictions = []
    gold = []

    for idx, item in enumerate(dev_data, 1):
        question = item['question']
        db_id = item['db_id']
        gold_query = item['query']

        print(f"\nProcessing query {idx}/{len(dev_data)} for database {db_id}")
        print(f"Question: {question}")

        # Update database connection string for current database
        db_path = DATABASE_DIR / db_id / f"{db_id}.sqlite"
        os.environ["Text2Sql__DatabaseConnectionString"] = str(db_path)
        os.environ["Text2Sql__DatabaseName"] = db_id

        sql = await generate_sql(question)

        # For aggregation queries, execute and validate the results
        if aggregate_call.search(sql):
            results = execute_query(sql, db_path)
            if results:
                logger.info(f"Query results: {results}")
                # Verify numeric results for aggregations
                for key, value in results.items():
                    if not isinstance(value, (int, float)):
                        logger.warning(f"Non-numeric aggregation result: {key}={value}")

        predictions.append(f"{sql.translate(to_single_line).strip()}\t{db_id}")
        gold.append(f"{gold_query.translate(to_single_line).strip()}\t{db_id}")
        print(f"Generated SQL: {sql}")

    print(f"\nSaving predictions to {pred_file}")
    with open(pred_file, 'w', encoding="utf-8") as f:
        f.write('\n'.join(predictions))

    print(f"Saving gold queries to {gold_file}")
    with open(gold_file, 'w', encoding="utf-8") as f:
        f.write('\n'.join(gold))

    return pred_file, gold_file

In [4]:
# Run evaluation using the test suite evaluation script
def run_evaluation():
    """Run the test-suite ``evaluation.py`` on the gold and predictions files.

    The script is launched as a subprocess with the same interpreter that is
    running this notebook, and its stdout is printed. If the script exits
    with a non-zero status, the exit code and captured stderr are printed as
    well so failures are not silent. Start time, end time and duration are
    reported. Returns ``None``.
    """
    # All paths derive from the module-level TEST_SUITE_DIR / SPIDER_DATA_DIR
    gold_file = TEST_SUITE_DIR / "gold.txt"
    pred_file = TEST_SUITE_DIR / "predictions.txt"
    table_file = SPIDER_DATA_DIR / "tables.json"  # Use Spider's schema file

    print(f"Starting evaluation at {time.strftime('%H:%M:%S')}")
    start_time = time.time()

    cmd = [
        # A bare "python" is looked up on PATH and may be a different
        # interpreter/environment than the one running the notebook.
        sys.executable,
        str(TEST_SUITE_DIR / "evaluation.py"),
        "--gold", str(gold_file),
        "--pred", str(pred_file),
        "--db", str(DATABASE_DIR),
        "--table", str(table_file),
        "--etype", "all",
        "--plug_value",
        "--progress_bar_for_each_datapoint"  # Show progress for each test input
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)

    duration = time.time() - start_time

    print("\nEvaluation Results:")
    print("==================")
    print(result.stdout)

    if result.returncode != 0:
        # stderr is only shown on failure
        print(f"\nEvaluation script exited with code {result.returncode}")
        print(result.stderr)

    print(f"\nEvaluation completed in {duration:.2f} seconds")
    print(f"End time: {time.strftime('%H:%M:%S')}")

In [None]:
# Generate predictions first. num_samples limits the run to the first N dev-set
# entries; omit it (default None) to process the whole dev set.
# Top-level `await` works in a notebook cell; a plain script would need
# asyncio.run(...) instead.
await generate_predictions(num_samples=5)  # Generate predictions for just 5 samples

In [None]:
# Run evaluation on the gold.txt / predictions.txt files in TEST_SUITE_DIR
# (generate_predictions in the previous cell writes them)
run_evaluation()