In [None]:
#| hide

%load_ext autoreload
%autoreload 2

# Preprocess Bird Benchmark

> Here, we load, process, and transform the bird benchmark.

In [None]:
#| default_exp preprocess_bird

In [None]:
#| export 

import json
from typing import Optional

from claimdb.configuration import *
from claimdb.dbops import *

## Convert all CSV schemas to UTF-8

In [None]:
import chardet
from pathlib import Path
import pandas as pd

In [None]:
#| notest

def convert_csv_to_utf8(csv_path: Path):
    """Re-encode a CSV file as UTF-8, in place.

    The current encoding is guessed with ``chardet.detect`` from the file's raw
    bytes. Those bytes are decoded with the guessed encoding, a leading BOM
    character is dropped, and the text is written back to the same path as
    UTF-8.

    Args:
        csv_path: Path of the CSV file to rewrite.

    Returns:
        ``(True, detected_encoding)`` when the file was rewritten, or
        ``(False, message)`` when no encoding could be detected or the
        decode/write step raised.
    """
    # Read the bytes once; they feed both detection and decoding.
    with open(csv_path, 'rb') as f:
        raw_data = f.read()
    detected_encoding = chardet.detect(raw_data)['encoding']

    # chardet reports no encoding when it has no guess (an empty file, for
    # example). Report that through the same failure tuple as any other error
    # so the caller's loop keeps going and lists the file as failed.
    if not detected_encoding:
        return False, "could not detect encoding"

    try:
        content = raw_data.decode(detected_encoding)
        # Remove BOM if present
        if content.startswith('\ufeff'):
            content = content[1:]

        # Encode before opening the file for writing: a failure here leaves the
        # original file untouched, and writing bytes keeps the existing line
        # endings exactly as they were (no platform newline translation).
        encoded = content.encode('utf-8')
        with open(csv_path, 'wb') as f:
            f.write(encoded)

        return True, detected_encoding
    except Exception as e:
        return False, str(e)

# Re-encode every description CSV of every BIRD database, keeping a tally
converted, failed = [], []

for db_dir in config.bird_databases_dir.iterdir():
    desc_dir = db_dir / 'database_description'
    if not desc_dir.exists():
        continue  # nothing to convert for this entry
    for csv_file in desc_dir.glob('*.csv'):
        success, info = convert_csv_to_utf8(csv_file)
        # info is the detected encoding on success, the error text otherwise
        outcome = converted if success else failed
        outcome.append((csv_file.name, info))

print(f"Converted {len(converted)} files to UTF-8")
print(f"Failed: {len(failed)} files")

if len(failed) > 0:
    print("\nFailed files:")
    for name, error in failed:
        print(f"  {name}: {error}")

In [None]:
#| notest

# Scan the description CSVs again and report every file whose detected
# encoding is something other than ASCII / UTF-8.
problematic = []
for db_dir in config.bird_databases_dir.iterdir():
    desc_dir = db_dir / 'database_description'
    if not desc_dir.exists():
        continue
    for csv_file in desc_dir.glob('*.csv'):
        with open(csv_file, 'rb') as f:
            raw_bytes = f.read()
        result = chardet.detect(raw_bytes)
        # UTF-8-SIG is just UTF-8 with BOM - pandas handles it fine
        if result['encoding'] in ['ascii', 'utf-8', 'UTF-8-SIG']:
            continue
        problematic.append((csv_file, result))
        print(f"{csv_file.name}: {result['encoding']} ({result['confidence']:.2f})")

print(f"\nFound {len(problematic)} problematic files")

## Prepare Data

Here, we load, filter, and merge the bird benchmark dev and train sets.

### Read & Merge Dev/Train sets

In [None]:
# Show the configured BIRD directory that the cells below read from and write to
config.bird_dir

Load **dev** and **train** data

In [None]:
# Load dev and train data
# The encoding is passed explicitly so that decoding does not depend on the
# platform's locale default (JSON interchange text is UTF-8).
with open(config.bird_dir / 'dev.json', 'r', encoding='utf-8') as f:
    dev_data = json.load(f)

with open(config.bird_dir / 'train.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

# Peek at the first dev entry
dev_data[0]

Now we will **annotate and save the combined data** in `jsonl` format:

1. Add split label to each entry
2. Merge the two sets into one
3. Add a unique ID to each entry (of the combined set)

In [None]:
#| notest

# Tag every entry (in place) with the split it came from
for split_name, entries in (('dev', dev_data), ('train', train_data)):
    for item in entries:
        item['split'] = split_name

# Concatenate: dev entries first, then train entries
data = dev_data + train_data

# bird_id is the entry's position in the merged list
for i, item in enumerate(data):
    item['bird_id'] = i

# Persist the merged set as JSON Lines (one object per line)
with open(config.bird_dir / 'train_dev.jsonl', 'w') as f:
    f.writelines(json.dumps(entry) + '\n' for entry in data)

In [None]:
# Reload the merged set from disk: one JSON object per line
with open(config.bird_dir / 'train_dev.jsonl', 'r') as f:
    data = list(map(json.loads, f))

# Peek at the first record
data[0]

## Schema Parsing

In [None]:
#| export

import pandas as pd
from pathlib import Path

In [None]:
#| export

def load_bird_table_description(csv_path: Path) -> str:
    """Render one BIRD table-description CSV as a compact text listing.

    The table name is the CSV file name without its extension. Each CSV row
    contributes one bullet holding its ``column_name``, or its
    ``original_column_name`` when ``column_name`` is missing.
    """
    description = pd.read_csv(csv_path)

    # Two header lines, e.g. "Table: Air Carriers" followed by "Columns:"
    header = [f"Table: {csv_path.stem}", "Columns:"]

    column_lines = []
    for _, record in description.iterrows():
        display_name = record['column_name']
        if pd.isna(display_name):
            # No readable name was given; fall back to the raw column name
            display_name = record['original_column_name']
        column_lines.append(f"  - {display_name}")

    return '\n'.join(header + column_lines)

In [None]:
#| export

def load_bird_database_schema(db_path: Path) -> Optional[str]:
    """Load all table descriptions for a BIRD database as compact text.

    Args:
        db_path: Database directory expected to contain a
            ``database_description`` sub-directory with one CSV per table.

    Returns:
        The per-table descriptions separated by blank lines, ordered by CSV
        path, or ``None`` when the ``database_description`` directory does not
        exist.
    """
    desc_dir = db_path / 'database_description'

    if not desc_dir.exists():
        return None

    # glob() gives no ordering guarantee; sort so the schema text is the same
    # on every machine and every run.
    csv_files = sorted(desc_dir.glob('*.csv'))

    return '\n\n'.join(load_bird_table_description(csv_file) for csv_file in csv_files)

In [None]:
#| notest

# Load schema for a single database ('authors' serves as the sample here).
# load_bird_database_schema returns None when the directory has no
# database_description folder, in which case this cell prints "None".
db_path = config.bird_databases_dir / 'authors'
schema = load_bird_database_schema(db_path)

print(schema)

## Filter out Low-Information Data

Now, we filter the data, keeping only the examples whose query is **high-information** and whose answer has at most 10 rows (small enough to be parsed by LLMs).

In [None]:
# Reload the merged set from disk: one JSON object per line
with open(config.bird_dir / 'train_dev.jsonl', 'r') as f:
    data = list(map(json.loads, f))

# Peek at the first record
data[0]

In [None]:
#| notest
import tqdm

In [None]:
#| notest

filtered = []

# db_id -> schema text. Many items share a database, so each database's
# description CSVs are parsed once instead of once per item.
schema_cache = {}

for item in tqdm.tqdm(data):
    query = item['SQL']

    # Keep only queries the project helper classifies as high-information
    if not is_query_high_information(query):
        continue

    dbdir = config.bird_databases_dir / item['db_id']
    dbpath = dbdir / f"{item['db_id']}.sqlite"

    result = sqlite_execute_with_timeout(dbpath, query)

    # Drop answers with more than 10 rows.
    # NOTE(review): per the "Filter Out Timeouts + Errors" section, result can
    # also be the string "error" or "timeout"; both are shorter than 10
    # characters, so such items pass this check and are removed later.
    if len(result) > 10:
        continue

    item['result'] = result

    if item['db_id'] not in schema_cache:
        schema_cache[item['db_id']] = load_bird_database_schema(dbdir)
    item['db-schema'] = schema_cache[item['db_id']]

    filtered.append(item)

In [None]:
#| notest
# Number of examples that survived the high-information / row-count filter
len(filtered)

In [None]:
#| notest

# Write the filtered examples as JSON Lines (one object per line)
with open(config.bird_dir / 'train_dev_filtered.jsonl', 'w') as f:
    f.writelines(json.dumps(example) + '\n' for example in filtered)

In [None]:
# Load the filtered examples back from disk: one JSON object per line
with open(config.bird_dir / 'train_dev_filtered.jsonl', 'r') as f:
    filtered = list(map(json.loads, f))

In [None]:
# Spot-check one record; raises IndexError if fewer than 2001 examples are loaded
filtered[2000]

## LLM-based Formats

In [None]:
#| export

def format_for_llm(d):
    """Serialise `d` to a pretty-printed JSON string for LLMs that prefer structured input."""
    # Two-space indentation; non-ASCII characters are kept as-is, not \u-escaped
    dump_options = {'indent': 2, 'ensure_ascii': False}
    return json.dumps(d, **dump_options)

In [None]:
#| export

def convert_db_name(db_id): # The database ID (e.g., 'world_1')
    """ Turn a database ID into a human-readable, title-cased name. """
    # Underscores become single spaces, then each word is capitalised
    words = db_id.split('_')
    return ' '.join(words).title()

In [None]:
#| export

def prepare_bird_example(example: dict, with_schema: bool = False):
    """Project a filtered BIRD example onto the fields presented to an LLM.

    The returned dict holds ``question``, ``answer`` (the example's
    ``result``), ``domain`` (the readable form of ``db_id``) and
    ``external-knowledge`` (the example's ``evidence``). When `with_schema` is
    true, the example's ``db-schema`` is included as well.
    """
    prepared = dict(
        question=example['question'],
        answer=example['result'],
        domain=convert_db_name(example['db_id']),
    )
    # Hyphenated keys cannot be passed as keyword arguments
    prepared['external-knowledge'] = example['evidence']

    if not with_schema:
        return prepared

    prepared['db-schema'] = example['db-schema']
    return prepared

In [None]:
# First filtered example in its default form (schema omitted)
print(prepare_bird_example(filtered[0]))

In [None]:
# The same example serialised as pretty-printed JSON
first_prepared = prepare_bird_example(filtered[0])
print(format_for_llm(first_prepared))

In [None]:
# First filtered example with its database schema attached
print(prepare_bird_example(filtered[0], with_schema=True))

In [None]:
# Serialise the first filtered example, schema included, to a JSON string
formatted = format_for_llm(prepare_bird_example(filtered[0], with_schema=True))

In [None]:
# Parse the JSON string back into Python objects (round-trip check)
decoded = json.loads(formatted)

In [None]:
# Show the object recovered from the JSON string
print(decoded)

## Filter Out Timeouts + Errors

The stored SQL result can be the string `"error"` or `"timeout"` instead of rows, so we filter those examples out here.

In [None]:
#| notest
from claimdb.configuration import *
import json

In [None]:
#| notest
# Load the filtered examples: one JSON object per line
with open(config.bird_dir / 'train_dev_filtered.jsonl', 'r') as f:
    bird_data = list(map(json.loads, f))

In [None]:
#| notest
# Sentinel strings stored in place of rows when a query failed or timed out
failure_markers = ('error', 'timeout')

problematic_bird_ids = [
    ex['bird_id'] for ex in bird_data if ex['result'] in failure_markers
]

In [None]:
#| notest
# Number of examples whose stored result is "error" or "timeout"
len(problematic_bird_ids)

In [None]:
#| notest
with open(config.bird_dir / 'train_dev_filtered.jsonl', 'r') as f:
    bird_data = [
        json.loads(line) for line in f
    ]

# A set makes each membership check constant-time instead of a scan of the
# whole id list for every example.
problematic_bird_id_set = set(problematic_bird_ids)

# Keep every example whose bird_id was not flagged as error/timeout
corrected_bird_data = [
    ex for ex in bird_data if ex['bird_id'] not in problematic_bird_id_set
]

In [None]:
#| notest
# Sanity check: exactly one example was removed per flagged id
# (this relies on bird_id values being unique within bird_data)
assert len(corrected_bird_data) == len(bird_data) - len(problematic_bird_ids)

In [None]:
#| notest
# Overwrite the filtered file with the error/timeout examples removed
with open(config.bird_dir / 'train_dev_filtered.jsonl', 'w') as f:
    f.writelines(json.dumps(example) + '\n' for example in corrected_bird_data)

## Filter Out Semantic Mistake Dev Data

Following the [paper](https://dl.acm.org/doi/10.1145/3711896.3737427) we filter out 106 question IDs since they are semantically incorrect.


In [None]:
#| notest
import json
from claimdb.configuration import *
from claimdb.dbops import *

In [None]:
#| notest
# The 106 question_id values to exclude; they are matched against
# dev-split entries only (see the cell that builds bad_bird_ids).
bad_question_ids = [1027, 1029, 519, 523, 530, 23, 70, 72, 584, 1107, 600, 602, 603, 94, 1119, 1120, 1121, 631, 632, 635, 125, 639, 640, 129, 642, 646, 649, 144, 145, 656, 1170, 667, 679, 682, 1197, 686, 687, 1199, 1204, 693, 182, 186, 194, 1219, 709, 710, 1225, 1233, 1243, 221, 1247, 1248, 1256, 1265, 1269, 247, 1273, 1274, 252, 254, 1279, 1284, 271, 1300, 281, 1308, 296, 1322, 812, 309, 341, 342, 343, 855, 349, 360, 1388, 386, 387, 388, 389, 398, 406, 1450, 1454, 431, 1458, 441, 442, 443, 446, 447, 966, 458, 970, 1482, 973, 978, 1491, 986, 993, 484, 1000, 1004, 1530, 1531]

Convert them to our `bird_id` unique identifiers.

In [None]:
#| notest
# Hard-coded bird_id values.
# NOTE(review): the next cell rebinds bad_bird_ids to [] and recomputes it from
# train_dev_filtered.jsonl, so this literal is overwritten before it is used.
bad_bird_ids = [70, 94, 129, 144, 182, 186, 254, 309, 342, 349, 386, 446, 458, 523, 603, 631, 632, 639, 640, 642, 656, 667, 679, 682, 687, 693, 709, 710, 966, 970, 973, 986, 1000, 1004, 1027, 1029, 1107, 1119, 1120, 1121, 1170, 1199, 1219, 1243, 1247, 1248, 1256, 1265, 1273, 1274, 1279, 1284, 1300, 1308, 1322, 1388, 1454, 1458, 1482, 1491, 1530, 1531]

In [None]:
#| notest
bad_bird_ids = []

with open(config.bird_dir / 'train_dev_filtered.jsonl', 'r') as f:
    records = map(json.loads, f)
    # The split is tested first, so question_id is only read on dev entries
    bad_bird_ids.extend(
        record['bird_id']
        for record in records
        if record['split'] == 'dev' and record['question_id'] in bad_question_ids
    )

In [None]:
#| notest
# Number of flagged dev questions still present in the filtered file
len(bad_bird_ids)

In [None]:
#| notest
# A set makes each membership check constant-time instead of a scan of the
# whole id list for every line of the file.
bad_bird_id_set = set(bad_bird_ids)

good_bird_data = []

with open(config.bird_dir / 'train_dev_filtered.jsonl', 'r') as f:
    for line in f:
        item = json.loads(line)
        # Keep every example that was not flagged as semantically incorrect
        if item['bird_id'] not in bad_bird_id_set:
            good_bird_data.append(item)

In [None]:
#| notest
# Overwrite the filtered file with the flagged dev examples removed
with open(config.bird_dir / 'train_dev_filtered.jsonl', 'w') as f:
    f.writelines(json.dumps(example) + '\n' for example in good_bird_data)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()