In [1]:
%run ../utils/init_env.py

In [2]:
import json
import os
import sys
import config
from collections import defaultdict
from typing import List, Dict, Any

In [12]:
def load_spider_dataset(dataset_path: str, tables_path: str) -> List[Dict[str, Any]]:
    """
    Load Spider examples from a JSON file and attach each example's schema.

    :param dataset_path: Path to a Spider examples JSON file (e.g. ``train_spider.json``).
        Each entry must provide ``db_id``, ``question`` and ``query`` keys.
    :param tables_path: Path to the matching Spider schema JSON file (e.g. ``tables.json``).
        Each entry must provide a ``db_id`` key.
    :return: One dict per example, in file order, with keys ``db_id``, ``question``,
        ``query`` and ``schema``. ``schema`` is the entry from ``tables_path`` whose
        ``db_id`` matches, or ``{}`` when no such entry exists.
    """
    import warnings

    with open(dataset_path, 'r', encoding='utf-8') as f:
        examples = json.load(f)

    with open(tables_path, 'r', encoding='utf-8') as f:
        tables_data = json.load(f)

    # Index schemas by database id so each example's lookup is O(1).
    schema_map = {table["db_id"]: table for table in tables_data}

    parsed_data = [
        {
            "db_id": item["db_id"],
            "question": item["question"],
            "query": item["query"],
            # A missing schema degrades to {} (empty DDL downstream) instead of failing.
            "schema": schema_map.get(item["db_id"], {}),
        }
        for item in examples
    ]

    # Unmatched db_ids usually mean the wrong tables file was paired with this
    # split; surface that once rather than silently emitting empty schemas.
    missing = sorted({item["db_id"] for item in examples if item["db_id"] not in schema_map})
    if missing:
        warnings.warn(
            f"{len(missing)} db_id(s) in {dataset_path} have no schema in {tables_path}: "
            f"{', '.join(missing[:10])}{' ...' if len(missing) > 10 else ''}",
            stacklevel=2,
        )

    return parsed_data


def _group_columns_by_table(table_names: List[str], columns: List[Any]) -> List[List[str]]:
    """Return, per table index, that table's original column names in schema order.

    Columns whose table index is negative (Spider's ``[-1, "*"]`` entry) or out of
    range for ``table_names`` are skipped.
    """
    grouped: List[List[str]] = [[] for _ in table_names]
    for table_idx, col_name in columns:
        if 0 <= table_idx < len(grouped):
            grouped[table_idx].append(col_name)
    return grouped


def _foreign_key_clauses(foreign_keys: List[Any], table_names: List[str], columns: List[Any]) -> List[str]:
    """Render each ``[child_col_idx, parent_col_idx]`` pair as ``child(col) REFERENCES parent(col)``.

    Pairs that are not exactly two indices, or that point at a column with no
    owning table (e.g. ``*``), are skipped.
    """
    # Global column index -> (table name, column name); the "*" column (table -1) is excluded.
    column_idx_map = {
        col_idx: (table_names[table_idx], col_name)
        for col_idx, (table_idx, col_name) in enumerate(columns)
        if table_idx >= 0
    }

    clauses = []
    for fk in foreign_keys:
        if len(fk) == 2 and fk[0] in column_idx_map and fk[1] in column_idx_map:
            child_table, child_column = column_idx_map[fk[0]]
            parent_table, parent_column = column_idx_map[fk[1]]
            clauses.append(f"{child_table}({child_column}) REFERENCES {parent_table}({parent_column})")
    return clauses


def format_for_model(parsed_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert parsed Spider entries into the model input records.

    Schema fields read from ``entry["schema"]``: ``table_names_original``,
    ``column_names_original`` and ``foreign_keys``. Any of them may be absent,
    which yields empty DDL / foreign-key output.

    :param parsed_data: Entries with ``db_id``, ``question``, ``query`` and ``schema`` keys.
    :return: One dict per entry with keys:
        ``id`` (position in ``parsed_data``), ``db``, ``question``, ``gold_sql``,
        ``simplified_ddl`` (one ``table(col, ...)`` line per table),
        ``full_ddl`` (``CREATE TABLE`` statements separated by blank lines; every
        column is typed ``TEXT``, Spider's ``column_types`` are not used), and
        ``foreign_key`` (list of ``child(col) REFERENCES parent(col)`` strings).
    """
    formatted_data = []

    for example_id, entry in enumerate(parsed_data):
        schema = entry["schema"]
        table_names = schema.get("table_names_original", [])
        columns = schema.get("column_names_original", [])
        foreign_keys = schema.get("foreign_keys", [])

        # Group once and reuse for both DDL renderings.
        columns_by_table = _group_columns_by_table(table_names, columns)

        simplified_lines = []
        create_statements = []
        for table, table_columns in zip(table_names, columns_by_table):
            simplified_lines.append(f"{table}({', '.join(table_columns)})")
            col_definitions = ", ".join(f"{col} TEXT" for col in table_columns)
            create_statements.append(f"CREATE TABLE {table}({col_definitions});")

        formatted_data.append({
            "id": example_id,
            "db": entry["db_id"],
            "question": entry["question"],
            "gold_sql": entry["query"],
            "simplified_ddl": "\n".join(simplified_lines),
            "full_ddl": "\n\n".join(create_statements),
            "foreign_key": _foreign_key_clauses(foreign_keys, table_names, columns),
        })

    return formatted_data

In [15]:

# Convert each Spider split into model-ready records and write them to
# <config.DATA_DIR>/<partition>.json. Source files are read from
# config.SPIDER_DATA_DIR; train and dev share tables.json, test uses test_tables.json.
dataset_path = config.SPIDER_DATA_DIR
datasets = {
    'train': {
        'data': 'train_spider',
        'tables': 'tables'
    },
    'dev': {
        'data': 'dev',
        'tables': 'tables'
    },
    'test': {
        'data': 'test',
        'tables': 'test_tables'
    }
}

# Create the output directory up front; otherwise the first open(..., "w")
# raises FileNotFoundError when DATA_DIR does not exist yet.
os.makedirs(config.DATA_DIR, exist_ok=True)

for partition, dataset in datasets.items():
    parsed_data = load_spider_dataset(
        os.path.join(dataset_path, f"{dataset['data']}.json"),
        os.path.join(dataset_path, f"{dataset['tables']}.json"),
    )
    formatted_data = format_for_model(parsed_data)

    # Save output
    output_path = os.path.join(config.DATA_DIR, f"{partition}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(formatted_data, f, indent=4)
