### Formatting Train Data

In [None]:
def get_table_contents(db_path, table_name, columns, *, limit=5):
    """Sample the first few values of each requested column of a SQLite table.

    Args:
        db_path: Path of the SQLite database file to open.
        table_name: Name of the table to read from.
        columns: Collection of column names to sample. It is iterated twice,
            so pass a list/tuple rather than a one-shot iterator.
        limit: Maximum number of rows sampled per column (default 5).

    Returns:
        A dict mapping each name in ``columns`` to a list with the ``str()``
        of the sampled values. A column whose query raises
        ``sqlite3.OperationalError`` is reported on stdout and mapped to an
        empty list.

    Raises:
        Any exception raised by ``sqlite3.connect`` propagates to the caller.
    """
    contents = {col: [] for col in columns}
    quoted_table_name = quote_identifier(table_name)

    conn = sqlite3.connect(db_path)
    # try/finally: the connection must be released even when an exception
    # other than the handled OperationalError escapes the loop.
    try:
        cursor = conn.cursor()
        for col in columns:
            quoted_col = quote_identifier(col)
            # Identifiers cannot be bound as parameters, so they are quoted
            # into the statement; the row limit is bound normally.
            query = f"SELECT {quoted_col} FROM {quoted_table_name} LIMIT ?"
            try:
                cursor.execute(query, (limit,))
                rows = cursor.fetchall()
            except sqlite3.OperationalError as e:
                print(f"{db_path} Error querying {col} in {table_name}: {e}")
                contents[col] = []  # Leave content blank for this column
            else:
                contents[col] = [str(row[0]) for row in rows]
    finally:
        conn.close()

    return contents

In [None]:
import json
import sqlite3
import os
import csv
from datasets import load_dataset
from tqdm import tqdm

def quote_identifier(identifier):
    """Return ``identifier`` wrapped in double quotes for use as a SQL name.

    Any double quote inside the identifier is doubled, which is the SQL way
    of escaping a quote within a quoted identifier. Without this, a name
    containing ``"`` would end the quoting early and yield broken SQL.

    Args:
        identifier: Table or column name; non-string values are converted
            with ``str()`` first.

    Returns:
        The quoted identifier as a string.
    """
    escaped = str(identifier).replace('"', '""')
    return f'"{escaped}"'
def get_table_info(db_id, tables_file):
    """Look up the table and column listings of one database in a JSON file.

    Args:
        db_id: Value to match against each entry's ``'db_id'`` key.
        tables_file: Path of a JSON file holding a list of entries, each with
            ``'db_id'``, ``'table_names_original'`` and
            ``'column_names_original'`` keys.

    Returns:
        A ``(table_names_original, column_names_original)`` tuple taken from
        the first entry whose ``'db_id'`` equals ``db_id``, or
        ``(None, None)`` when no entry matches.
    """
    # Read explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(tables_file, 'r', encoding='utf-8') as f:
        tables_data = json.load(f)

    for entry in tables_data:
        if entry['db_id'] == db_id:
            return entry['table_names_original'], entry['column_names_original']

    return None, None

def _build_table_info(db_id, tables_file):
    """Build the schema-plus-sample-values summary string for one database.

    Returns None when ``tables_file`` has no table names for ``db_id`` or when
    reading the database raises ``sqlite3.OperationalError`` (the latter is
    reported on stdout).
    """
    table_names, column_names = get_table_info(db_id, tables_file)
    if not table_names:
        return None

    # Database files are addressed relative to the working directory.
    db_path = f"spider/database/{db_id}/{db_id}.sqlite"

    table_info = []
    try:
        for table_index, table in enumerate(table_names):
            # Each column entry is (owning table index, column name).
            table_columns = [col[1] for col in column_names if col[0] == table_index]
            table_contents = get_table_contents(db_path, table, table_columns)

            column_data = [
                f"{col} ({', '.join(table_contents[col])})"
                for col in table_columns
            ]
            table_info.append(f"{table}: " + " | ".join(column_data))
    except sqlite3.OperationalError as e:
        print(f"Error opening database: {db_path}")
        print(f"Error message: {e}")
        return None

    return " | ".join(table_info)


def format_spider_data():
    """Write the "spider" train split to ``formatted_train_data.csv``.

    Every output row holds the question, the SQL query, the database id and a
    text summary of that database's tables, columns and sample values. Items
    whose database has no table information, or whose database cannot be
    read, are skipped.

    The summary depends only on the database id, so it is computed once per
    database and reused for all items that share it. As a consequence, error
    messages for a problematic database are printed once rather than once per
    item.
    """
    dataset = load_dataset("spider", split="train")
    tables_file = "spider/tables.json"

    # db_id -> summary string, or None when the database must be skipped.
    table_info_cache = {}

    with open("formatted_train_data.csv", "w", newline='', encoding='utf-8') as csv_file:
        csv_writer = csv.writer(csv_file, quoting=csv.QUOTE_ALL)
        csv_writer.writerow(["question", "query", "db_id", "table_info"])

        for item in tqdm(dataset, desc="Formatting SPIDER data", unit="item"):
            db_id = item['db_id']

            if db_id not in table_info_cache:
                table_info_cache[db_id] = _build_table_info(db_id, tables_file)

            table_info_str = table_info_cache[db_id]
            if table_info_str is None:
                continue

            csv_writer.writerow([item['question'], item['query'], db_id, table_info_str])

    print("Formatting complete. Results written to formatted_train_data.csv")

# Run the conversion when this cell executes; the output is written to
# formatted_train_data.csv in the current working directory.
format_spider_data()