# Exploring the Spider Dataset

This notebook explores the Spider text-to-SQL dataset to understand:
1. Dataset structure (questions, SQL queries, schemas)
2. What data we'll use for retrieval (schemas, examples, similar questions)
3. How to format this information for our RAG system

In [None]:
import json
import os
from pathlib import Path
import pandas as pd
import sqlite3

# Set path to your Spider data
SPIDER_DATA_PATH = Path(r'F:\text2sql\spider_data')

print(f"Spider data location: {SPIDER_DATA_PATH}")
print(f"Exists: {SPIDER_DATA_PATH.exists()}")

## 1. Load Training Data

Load `train_spider.json` to see example questions and SQL queries

In [None]:
# Load training data (explicit UTF-8 so the platform default encoding is not used)
with open(SPIDER_DATA_PATH / 'train_spider.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

print(f"Number of training examples: {len(train_data)}")
print("\nFirst example:")
print(json.dumps(train_data[0], indent=2))

## 2. Examine Example Structure

Look at a few examples to understand the data format

In [None]:
# Look at first 5 examples
print("First 5 training examples:\n")
for i, example in enumerate(train_data[:5]):
    print(f"Example {i+1}:")
    print(f"  Database: {example['db_id']}")
    print(f"  Question: {example['question']}")
    print(f"  SQL: {example['query']}")
    print()

## 3. Load Schema Information

Load `tables.json` to see database schemas (tables, columns, types)

In [None]:
# Load schema information (explicit UTF-8 so the platform default encoding is not used)
with open(SPIDER_DATA_PATH / 'tables.json', 'r', encoding='utf-8') as f:
    tables_data = json.load(f)

print(f"Number of databases: {len(tables_data)}")
print("\nFirst database schema:")
print(json.dumps(tables_data[0], indent=2))

## 4. Examine a Specific Database Schema

Look at the schema structure in detail

In [None]:
# Pick a database and examine its schema
example_db = tables_data[0]

print(f"Database ID: {example_db['db_id']}")
print(f"\nTable names: {example_db['table_names_original']}")
print("\nColumn names (first 10):")
for table_idx, col_name in example_db['column_names_original'][:10]:
    if table_idx >= 0:
        table_name = example_db['table_names_original'][table_idx]
        print(f"  {table_name}.{col_name}")
    else:
        # A table index of -1 marks the "*" placeholder, which belongs to no table
        print(f"  {col_name}")

print(f"\nColumn types (first 10): {example_db['column_types'][:10]}")
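To feed a schema into a retrieval index or a prompt, the nested lists above need to be flattened into text. The following is a minimal sketch of one possible layout, not a fixed format: `format_schema` is a hypothetical helper, and `toy_db` is a made-up entry that only mimics the keys of `tables.json`.

```python
def format_schema(db):
    """Render one tables.json-style entry as compact text, one line per table."""
    columns_by_table = {i: [] for i in range(len(db["table_names_original"]))}
    for (table_idx, col_name), col_type in zip(
        db["column_names_original"], db["column_types"]
    ):
        if table_idx < 0:  # skip the "*" placeholder column
            continue
        columns_by_table[table_idx].append(f"{col_name} ({col_type})")

    lines = [f"Database: {db['db_id']}"]
    for table_idx, table_name in enumerate(db["table_names_original"]):
        lines.append(f"{table_name}: {', '.join(columns_by_table[table_idx])}")
    return "\n".join(lines)


# Hypothetical toy entry (assumption: same keys as a tables.json record).
toy_db = {
    "db_id": "toy",
    "table_names_original": ["singer", "song"],
    "column_names_original": [
        [-1, "*"], [0, "singer_id"], [0, "name"], [1, "song_id"], [1, "singer_id"],
    ],
    "column_types": ["text", "number", "text", "number", "number"],
}

schema_text = format_schema(toy_db)
print(schema_text)
```

On the real data, `format_schema(example_db)` would produce the same layout for the database examined above.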

## 5. Understand Foreign Keys

Foreign keys show relationships between tables

In [None]:
# Look at foreign keys
if 'foreign_keys' in example_db:
    print(f"Foreign keys in {example_db['db_id']}:")
    for fk in example_db['foreign_keys']:
        from_col_idx, to_col_idx = fk
        from_table_idx, from_col = example_db['column_names_original'][from_col_idx]
        to_table_idx, to_col = example_db['column_names_original'][to_col_idx]
        from_table = example_db['table_names_original'][from_table_idx]
        to_table = example_db['table_names_original'][to_table_idx]
        print(f"  {from_table}.{from_col} -> {to_table}.{to_col}")
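For retrieval it may be more convenient to have the foreign keys as data rather than printed lines. A small sketch under that assumption: `foreign_key_pairs` is a hypothetical helper that resolves the column-index pairs into `table.column` strings, shown here on a made-up entry with the same keys as `tables.json`.

```python
def foreign_key_pairs(db):
    """Resolve foreign-key index pairs into ("table.col", "table.col") tuples."""
    columns = db["column_names_original"]
    tables = db["table_names_original"]
    pairs = []
    for from_idx, to_idx in db.get("foreign_keys", []):
        from_table_idx, from_col = columns[from_idx]
        to_table_idx, to_col = columns[to_idx]
        pairs.append(
            (f"{tables[from_table_idx]}.{from_col}", f"{tables[to_table_idx]}.{to_col}")
        )
    return pairs


# Hypothetical toy entry (assumption: same keys as a tables.json record).
toy_db = {
    "db_id": "toy",
    "table_names_original": ["singer", "song"],
    "column_names_original": [
        [-1, "*"], [0, "singer_id"], [0, "name"], [1, "song_id"], [1, "singer_id"],
    ],
    "foreign_keys": [[4, 1]],
}

fk_pairs = foreign_key_pairs(toy_db)
# Turn each pair into a join condition that can be appended to the schema text.
join_hints = [f"{src} = {dst}" for src, dst in fk_pairs]
print(join_hints)
```

Calling `foreign_key_pairs(example_db)` on the real data returns an empty list for databases without foreign keys.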

## 6. Load Development Set

This is what we'll use for evaluation

In [None]:
# Load dev set (explicit UTF-8 so the platform default encoding is not used)
with open(SPIDER_DATA_PATH / 'dev.json', 'r', encoding='utf-8') as f:
    dev_data = json.load(f)

print(f"Number of dev examples: {len(dev_data)}")
print("\nFirst dev example:")
print(json.dumps(dev_data[0], indent=2))

## 7. Database Statistics

Understand the distribution of questions across databases

In [None]:
# Count questions per database in training set
from collections import Counter

train_db_counts = Counter([ex['db_id'] for ex in train_data])
dev_db_counts = Counter([ex['db_id'] for ex in dev_data])

print(f"Number of unique databases in training: {len(train_db_counts)}")
print(f"Number of unique databases in dev: {len(dev_db_counts)}")

print("\nTop 10 databases by training examples:")
for db_id, count in train_db_counts.most_common(10):
    print(f"  {db_id}: {count} examples")
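If retrieval is meant to draw on training questions from the same database as a dev question, it is worth checking whether any dev database appears in the training set at all. A minimal sketch: `database_overlap` is a hypothetical helper, and the toy lists stand in for `train_data` and `dev_data`. Spider is usually described as a cross-database benchmark, so the overlap on the real files may turn out to be empty.

```python
def database_overlap(train_examples, dev_examples):
    """Return the sorted db_ids that occur in both splits."""
    train_dbs = {ex["db_id"] for ex in train_examples}
    dev_dbs = {ex["db_id"] for ex in dev_examples}
    return sorted(train_dbs & dev_dbs)


# Hypothetical toy splits (assumption: each example carries a "db_id" key).
toy_train = [{"db_id": "concert"}, {"db_id": "school"}, {"db_id": "concert"}]
toy_dev = [{"db_id": "concert"}, {"db_id": "airline"}]

shared = database_overlap(toy_train, toy_dev)
print(f"Databases in both splits: {shared}")
```

On the real data, `database_overlap(train_data, dev_data)` shows whether same-database examples are available for dev questions or whether retrieval has to rely on examples from other databases.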

## 8. Test Database Connection

Try connecting to one of the SQLite databases

In [None]:
# Connect to a database and examine it
test_db_id = train_data[0]['db_id']
db_path = SPIDER_DATA_PATH / 'database' / test_db_id / f'{test_db_id}.sqlite'

print(f"Testing connection to: {db_path}")
print(f"Database exists: {db_path.exists()}")

if db_path.exists():
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get list of tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        print(f"\nTables in {test_db_id}:")
        for table in tables:
            print(f"  - {table[0]}")
    finally:
        # Close the connection even if the query fails
        conn.close()
else:
    print("Database file not found!")

## 9. Execute a Sample Query

Test executing a SQL query from the dataset

In [None]:
# Execute the first training example
example = train_data[0]
db_path = SPIDER_DATA_PATH / 'database' / example['db_id'] / f"{example['db_id']}.sqlite"

print(f"Question: {example['question']}")
print(f"SQL: {example['query']}")

if db_path.exists():
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(example['query'])
        results = cursor.fetchall()
        print(f"\nResults (first 5 rows): {results[:5]}")
        print(f"Total rows: {len(results)}")
    except sqlite3.Error as e:
        print(f"Error executing query: {e}")
    finally:
        # Close the connection whether or not the query succeeded
        conn.close()
else:
    print("Database file not found!")
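Executing queries also gives a way to compare a generated query against the reference one by their results. The sketch below is a simplified check and not the official Spider evaluation: `run_query` and `same_result` are hypothetical helpers, row order is ignored (so queries that differ only in `ORDER BY` count as matching), and the demo uses a throwaway SQLite file rather than a Spider database.

```python
import os
import sqlite3
import tempfile
from collections import Counter
from contextlib import closing


def run_query(db_path, sql):
    """Run one query and return all rows; the connection is always closed."""
    with closing(sqlite3.connect(db_path)) as conn:
        return conn.execute(sql).fetchall()


def same_result(db_path, predicted_sql, gold_sql):
    """True if both queries return the same multiset of rows (order ignored)."""
    try:
        predicted = run_query(db_path, predicted_sql)
    except sqlite3.Error:
        return False  # a predicted query that fails to run counts as a miss
    gold = run_query(db_path, gold_sql)
    return Counter(predicted) == Counter(gold)


# Demo on a temporary toy database (assumption: not part of Spider).
with tempfile.TemporaryDirectory() as tmp_dir:
    db_file = os.path.join(tmp_dir, "toy.sqlite")
    with closing(sqlite3.connect(db_file)) as conn:
        conn.execute("CREATE TABLE singer (singer_id INTEGER, name TEXT)")
        conn.executemany("INSERT INTO singer VALUES (?, ?)", [(1, "Ann"), (2, "Bo")])
        conn.commit()

    match = same_result(
        db_file, "SELECT name FROM singer", "SELECT name FROM singer ORDER BY name DESC"
    )
    mismatch = same_result(db_file, "SELECT count(*) FROM singer", "SELECT name FROM singer")
    broken = same_result(db_file, "SELECT nope FROM singer", "SELECT name FROM singer")

print(match, mismatch, broken)
```

With the notebook's variables, a call would look like `same_result(db_path, predicted_sql, example['query'])`, where `predicted_sql` is whatever the RAG system produced.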

## 10. Summary: What We'll Use for RAG

Based on this exploration, here's what we'll retrieve:

1. **Schema Information** (from `tables.json`):
   - Table names
   - Column names and types
   - Foreign key relationships

2. **Example Queries** (from `train_spider.json`):
   - Natural language questions
   - Corresponding SQL queries
   - Database context

3. **Similar Questions**:
   - Questions from the same database
   - Questions with similar SQL patterns

Next steps:
- Build embeddings for questions and schemas
- Create FAISS index
- Implement retrieval function
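Before embeddings and a FAISS index are in place, the shape of the retrieval function can be tried out with a dependency-free stand-in. The sketch below swaps in Jaccard token overlap for embedding similarity; it is not the planned method, `tokenize` and `retrieve_similar` are hypothetical names, and the toy examples only mimic the `question` and `query` keys of `train_spider.json`.

```python
import re


def tokenize(text):
    """Lowercase the text and return its set of alphanumeric tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def retrieve_similar(question, examples, k=3):
    """Return the k examples whose questions share the most tokens (Jaccard)."""
    query_tokens = tokenize(question)
    scored = []
    for ex in examples:
        ex_tokens = tokenize(ex["question"])
        union = query_tokens | ex_tokens
        score = len(query_tokens & ex_tokens) / len(union) if union else 0.0
        scored.append((score, ex))
    # Sort on the score only; ties keep their original order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [ex for _, ex in scored[:k]]


# Hypothetical toy examples (assumption: same keys as training records).
toy_examples = [
    {"question": "How many singers do we have?", "query": "SELECT count(*) FROM singer"},
    {"question": "List the names of all songs.", "query": "SELECT name FROM song"},
    {"question": "What is the average age of all singers?", "query": "SELECT avg(age) FROM singer"},
]

top = retrieve_similar("How many songs are there?", toy_examples, k=2)
for ex in top:
    print(ex["question"], "->", ex["query"])
```

Keeping the signature `retrieve_similar(question, examples, k)` means the scoring can later be replaced by an embedding lookup without changing the callers.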

In [None]:
# Print summary statistics
print("=" * 50)
print("SPIDER DATASET SUMMARY")
print("=" * 50)
print(f"Training examples: {len(train_data)}")
print(f"Dev examples: {len(dev_data)}")
print(f"Database schemas in tables.json: {len(tables_data)}")
print(f"Unique databases in training: {len(train_db_counts)}")
print(f"Unique databases in dev: {len(dev_db_counts)}")
print("=" * 50)