In [4]:
import json
import os
import re
from collections import defaultdict
from typing import List, Dict, Any

In [5]:
def load_spider_dataset(
    dataset_path: str,
    data_file: str = "test.json",
    tables_file: str = "test_tables.json",
) -> List[Dict[str, Any]]:
    """
    Loads one split of the Spider dataset and pairs every example with its schema.

    :param dataset_path: Path to the Spider dataset folder.
    :param data_file: Name of the JSON file (inside ``dataset_path``) holding the
        examples. Defaults to ``'test.json'``.
    :param tables_file: Name of the JSON file (inside ``dataset_path``) holding the
        database schemas. Defaults to ``'test_tables.json'``.
    :return: A list with one dictionary per example, carrying the keys ``db_id``,
        ``question``, ``query`` and ``schema``. ``schema`` is the entry of
        ``tables_file`` whose ``db_id`` matches the example, or an empty dict
        when no such entry exists.
    :raises OSError: If either file cannot be opened.
    :raises json.JSONDecodeError: If either file is not valid JSON.
    :raises KeyError: If an example lacks ``db_id``, ``question`` or ``query``,
        or a schema entry lacks ``db_id``.
    """
    with open(os.path.join(dataset_path, data_file), 'r', encoding='utf-8') as f:
        examples = json.load(f)

    with open(os.path.join(dataset_path, tables_file), 'r', encoding='utf-8') as f:
        tables_data = json.load(f)

    # Index the schemas by database id so each example is a single lookup.
    # Examples that share a database also share the same schema dict object.
    schema_map = {table["db_id"]: table for table in tables_data}

    parsed_data = []

    for item in examples:
        db_id = item["db_id"]
        parsed_data.append({
            "db_id": db_id,
            "question": item["question"],
            "query": item["query"],
            # Unknown databases fall back to an empty schema instead of failing.
            "schema": schema_map.get(db_id, {})
        })

    return parsed_data


def _render_schema(schema: Dict[str, Any]) -> tuple:
    """
    Renders one schema dictionary into its textual descriptions.

    :param schema: Schema dictionary. The keys ``table_names_original``,
        ``column_names_original`` and ``foreign_keys`` are read; each one
        defaults to an empty list when absent.
    :return: A ``(simplified_ddl, full_ddl, foreign_key_constraints)`` tuple.
    """
    table_names = schema.get("table_names_original", [])
    columns = schema.get("column_names_original", [])
    foreign_keys = schema.get("foreign_keys", [])

    # Column position -> (table name, column name). Entries with a negative
    # table index are left out, so foreign keys pointing at them are ignored.
    column_idx_map = {
        col_idx: (table_names[table_idx], col_name)
        for col_idx, (table_idx, col_name) in enumerate(columns)
        if table_idx >= 0
    }

    # Bucket column names by table index in one pass instead of rescanning
    # the full column list for every table.
    columns_by_table = defaultdict(list)
    for table_idx, col_name in columns:
        columns_by_table[table_idx].append(col_name)

    simplified_lines = []
    full_ddl_statements = []
    for tid, table in enumerate(table_names):
        table_columns = columns_by_table.get(tid, [])
        simplified_lines.append(f"# {table}({', '.join(table_columns)})")
        # Every column is declared as TEXT in the generated DDL.
        col_definitions = ", ".join(f"{col} TEXT" for col in table_columns)
        full_ddl_statements.append(f"CREATE TABLE {table}({col_definitions});")

    # Each foreign key is a pair of column positions: [child, parent].
    # Malformed pairs and pairs referencing unmapped columns are skipped.
    foreign_key_constraints = []
    for fk in foreign_keys:
        if len(fk) == 2 and fk[0] in column_idx_map and fk[1] in column_idx_map:
            child_table, child_column = column_idx_map[fk[0]]
            parent_table, parent_column = column_idx_map[fk[1]]
            foreign_key_constraints.append(
                f"{child_table}({child_column}) REFERENCES {parent_table}({parent_column})"
            )

    return "\n".join(simplified_lines), "\n\n".join(full_ddl_statements), foreign_key_constraints


def format_for_model(parsed_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Formats parsed entries into flat records with schema descriptions attached.

    The schema rendering depends only on the schema, so it is computed once per
    distinct schema object and reused for every entry that carries that object.

    :param parsed_data: List of entries, each with the keys ``db_id``,
        ``question``, ``query`` and ``schema``.
    :return: A list of dictionaries with the keys ``id`` (position of the entry
        in ``parsed_data``), ``db``, ``question``, ``gold_sql``,
        ``simplified_ddl``, ``full_ddl`` and ``foreign_key``.
    :raises KeyError: If an entry lacks one of the required keys.
    """
    formatted_data = []

    # Memo keyed by object identity. The schema itself is stored next to the
    # rendered result so it stays alive and its id() cannot be recycled.
    rendered_by_schema = {}

    for idx, entry in enumerate(parsed_data):
        db_id = entry["db_id"]
        schema = entry["schema"]

        cached = rendered_by_schema.get(id(schema))
        if cached is None:
            cached = (schema, _render_schema(schema))
            rendered_by_schema[id(schema)] = cached
        simplified_ddl, full_ddl, foreign_key_constraints = cached[1]

        formatted_data.append({
            "id": idx,
            "db": db_id,
            "question": entry["question"],
            "gold_sql": entry["query"],
            "simplified_ddl": simplified_ddl,
            "full_ddl": full_ddl,
            # Copy so records never share one mutable list.
            "foreign_key": list(foreign_key_constraints)
        })

    return formatted_data

In [6]:

# Folder that holds the Spider JSON files; change this to your actual dataset location.
dataset_path = "./spider_data"

# Load the raw examples, then convert them into flat records.
parsed_data = load_spider_dataset(dataset_path)
formatted_data = format_for_model(parsed_data)

# Write every formatted record to disk as indented JSON.
with open("parsed_data.json", "w", encoding="utf-8") as f:
    json.dump(formatted_data, f, indent=4)

# Preview the first five records, each followed by an 80-dash rule.
for sample in formatted_data[:5]:
    print(json.dumps(sample, indent=4), "-" * 80, sep="\n")

{
    "id": 0,
    "db": "soccer_3",
    "question": "How many clubs are there?",
    "gold_sql": "SELECT count(*) FROM club",
    "simplified_ddl": "# club(Club_ID, Name, Manager, Captain, Manufacturer, Sponsor)\n# player(Player_ID, Name, Country, Earnings, Events_number, Wins_count, Club_ID)",
    "full_ddl": "CREATE TABLE club(Club_ID TEXT, Name TEXT, Manager TEXT, Captain TEXT, Manufacturer TEXT, Sponsor TEXT);\n\nCREATE TABLE player(Player_ID TEXT, Name TEXT, Country TEXT, Earnings TEXT, Events_number TEXT, Wins_count TEXT, Club_ID TEXT);",
    "foreign_key": [
        "player(Club_ID) REFERENCES club(Club_ID)"
    ]
}
--------------------------------------------------------------------------------
{
    "id": 1,
    "db": "soccer_3",
    "question": "Count the number of clubs.",
    "gold_sql": "SELECT count(*) FROM club",
    "simplified_ddl": "# club(Club_ID, Name, Manager, Captain, Manufacturer, Sponsor)\n# player(Player_ID, Name, Country, Earnings, Events_number, Wins_count, 