# Complete SQL Evaluation with Full Spider Metrics

**Comprehensive evaluation including:**
1. String-based metrics (exact match, keyword F1)
2. Execution accuracy (database result matching)
3. **Full Spider component-wise evaluation** (SELECT, WHERE, GROUP BY, ORDER BY, HAVING, AND/OR, IUEN, keywords)
4. Accuracy, Recall, and F1 scores for each SQL component

## Configuration

In [164]:
import json
import os
import sqlite3
import re
from statistics import mean
from collections import defaultdict
from pathlib import Path

In [166]:
NOTEBOOK_DIR = Path.cwd()  # Current notebook location
PROJECT_ROOT = NOTEBOOK_DIR.parent if NOTEBOOK_DIR.name == 'notebooks' else NOTEBOOK_DIR

PROJECT_ROOT

WindowsPath('c:/Users/aswat/OneDrive - UWA/Desktop/sem4/CITS5553/explainable-nl-query-db-agents')

In [167]:
# === CONFIGURATION ===
FILE_PATH = PROJECT_ROOT/"SQL_Prediction_Result.txt"
DB_DIR = PROJECT_ROOT/"notebooks"/"database"

USE_EXEC = True
SHOW_FIRST_N_MISMATCHES = 10
FILE_PATH

WindowsPath('c:/Users/aswat/OneDrive - UWA/Desktop/sem4/CITS5553/explainable-nl-query-db-agents/SQL_Prediction_Result.txt')

## Imports and Setup

In [None]:


# Try to import nltk
try:
    from nltk import word_tokenize
    print("✓ NLTK available")
except ImportError:
    print("⚠ NLTK not available, using simple tokenization")
    def word_tokenize(s):
        return s.replace('(', ' ( ').replace(')', ' ) ').replace(',', ' , ').split()

# SQL Constants from process_sql.py
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {'sql': "sql", 'table_unit': "table_unit"}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')

SQL_KEYWORDS = {
    'select', 'from', 'where', 'group', 'order', 'by', 'having', 'limit',
    'join', 'inner', 'left', 'right', 'full', 'outer', 'on', 'as',
    'union', 'intersect', 'except', 'and', 'or', 'not', 'in', 'like',
    'exists', 'between', 'asc', 'desc', 'distinct', 'count', 'sum',
    'avg', 'max', 'min'
}

print("✓ Imports and constants loaded")

✓ NLTK available
✓ Imports and constants loaded


## Process SQL - Full Parser Implementation

In [141]:
class Schema:
    """Simple schema which maps table&column to a unique identifier"""
    def __init__(self, schema):
        self._schema = schema
        self._idMap = self._map(self._schema)

    @property
    def schema(self):
        return self._schema

    @property
    def idMap(self):
        return self._idMap

    def _map(self, schema):
        idMap = {'*': "__all__"}
        id = 1
        for key, vals in schema.items():
            for val in vals:
                idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
                id += 1
        for key in schema:
            idMap[key.lower()] = "__" + key.lower() + "__"
            id += 1
        return idMap


def get_schema(db):
    """Get database's schema"""
    schema = {}
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [str(table[0].lower()) for table in cursor.fetchall()]
    for table in tables:
        cursor.execute("PRAGMA table_info({})".format(table))
        schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
    conn.close()
    return schema


def tokenize(string):
    """Robust SQL tokenizer that doesn't rely on NLTK"""
    string = str(string)
    string = string.replace("'", '"')  # Normalize quotes
    
    # Extract quoted strings first
    quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
    
    vals = {}
    # Only process if we have matching pairs
    if len(quote_idxs) % 2 == 0:
        for i in range(len(quote_idxs)-1, -1, -2):
            qidx1 = quote_idxs[i-1]
            qidx2 = quote_idxs[i]
            val = string[qidx1: qidx2+1]
            key = "__val_{}_{}__".format(qidx1, qidx2)
            string = string[:qidx1] + key + string[qidx2+1:]
            vals[key] = val
    
    # Simple but effective tokenization
    # Add spaces around operators and parentheses
    string = string.replace('(', ' ( ')
    string = string.replace(')', ' ) ')
    string = string.replace(',', ' , ')
    string = string.replace(';', ' ; ')
    string = string.replace('=', ' = ')
    string = string.replace('>', ' > ')
    string = string.replace('<', ' < ')
    string = string.replace('!', ' ! ')
    string = string.replace('+', ' + ')
    string = string.replace('-', ' - ')
    string = string.replace('*', ' * ')
    string = string.replace('/', ' / ')
    
    # Tokenize by whitespace
    toks = [word.lower() for word in string.split() if word.strip()]
    
    # Replace placeholders with original quoted values
    for i in range(len(toks)):
        if toks[i] in vals:
            toks[i] = vals[toks[i]]
    
    # Fix !=, >=, <= that got split
    fixed_toks = []
    i = 0
    while i < len(toks):
        if i + 1 < len(toks):
            if toks[i] == '!' and toks[i+1] == '=':
                fixed_toks.append('!=')
                i += 2
                continue
            elif toks[i] == '>' and toks[i+1] == '=':
                fixed_toks.append('>=')
                i += 2
                continue
            elif toks[i] == '<' and toks[i+1] == '=':
                fixed_toks.append('<=')
                i += 2
                continue
        fixed_toks.append(toks[i])
        i += 1
    
    return fixed_toks


def scan_alias(toks):
    as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
    alias = {}
    for idx in as_idxs:
        if idx + 1 < len(toks):
            alias[toks[idx+1]] = toks[idx-1]
    return alias


def get_tables_with_alias(schema, toks):
    tables = scan_alias(toks)
    for key in schema:
        tables[key] = key
    return tables


def skip_semicolon(toks, start_idx):
    idx = start_idx
    while idx < len(toks) and toks[idx] == ";":
        idx += 1
    return idx

print("✓ Schema and tokenization functions loaded")

✓ Schema and tokenization functions loaded


In [142]:
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
    tok = toks[start_idx]
    if tok == "*":
        return start_idx + 1, schema.idMap[tok]

    if '.' in tok:
        alias, col = tok.split('.')
        key = tables_with_alias[alias] + "." + col
        return start_idx+1, schema.idMap[key]

    if default_tables is None or len(default_tables) == 0:
        return start_idx+1, tok

    for alias in default_tables:
        table = tables_with_alias[alias]
        if tok in schema.schema[table]:
            key = table + "." + tok
            return start_idx+1, schema.idMap[key]

    return start_idx+1, tok


def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    isDistinct = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] in AGG_OPS:
        agg_id = AGG_OPS.index(toks[idx])
        idx += 1
        if idx < len_ and toks[idx] == '(':
            idx += 1
        if idx < len_ and toks[idx] == "distinct":
            idx += 1
            isDistinct = True
        idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
        if idx < len_ and toks[idx] == ')':
            idx += 1
        return idx, (agg_id, col_id, isDistinct)

    if toks[idx] == "distinct":
        idx += 1
        isDistinct = True
    agg_id = AGG_OPS.index("none")
    idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
        if idx < len_ and toks[idx] == ')':
            idx += 1

    return idx, (agg_id, col_id, isDistinct)


def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    col_unit1 = None
    col_unit2 = None
    unit_op = UNIT_OPS.index('none')

    idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
    if idx < len_ and toks[idx] in UNIT_OPS:
        unit_op = UNIT_OPS.index(toks[idx])
        idx += 1
        idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
        if idx < len_ and toks[idx] == ')':
            idx += 1

    return idx, (unit_op, col_unit1, col_unit2)


def parse_table_unit(toks, start_idx, tables_with_alias, schema):
    idx = start_idx
    len_ = len(toks)
    key = tables_with_alias.get(toks[idx], toks[idx])

    if idx + 1 < len_ and toks[idx+1] == "as":
        idx += 3
    else:
        idx += 1

    return idx, schema.idMap.get(key, key), key


def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)

    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if idx < len_ and toks[idx] == 'select':
        idx, val = parse_sql(toks, idx, tables_with_alias, schema)
    elif idx < len_ and "\"" in toks[idx]:
        val = toks[idx]
        idx += 1
    else:
        try:
            val = float(toks[idx])
            idx += 1
        except:
            end_idx = idx
            while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')' \
                and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
                    end_idx += 1

            idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
            idx = end_idx

    if isBlock:
        if idx < len_ and toks[idx] == ')':
            idx += 1

    return idx, val

print("✓ Column and value parsing functions loaded")

✓ Column and value parsing functions loaded


In [143]:
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    conds = []

    while idx < len_:
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        not_op = False
        if idx < len_ and toks[idx] == 'not':
            not_op = True
            idx += 1

        if idx >= len_ or toks[idx] not in WHERE_OPS:
            break
            
        op_id = WHERE_OPS.index(toks[idx])
        idx += 1
        val1 = val2 = None
        if op_id == WHERE_OPS.index('between'):
            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            if idx < len_ and toks[idx] == 'and':
                idx += 1
            idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        else:
            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            val2 = None

        conds.append((not_op, op_id, val_unit, val1, val2))

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
            break

        if idx < len_ and toks[idx] in COND_OPS:
            conds.append(toks[idx])
            idx += 1

    return idx, conds


def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'select':
        return idx, (False, [])
        
    idx += 1
    isDistinct = False
    if idx < len_ and toks[idx] == 'distinct':
        idx += 1
        isDistinct = True
    val_units = []

    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
        agg_id = AGG_OPS.index("none")
        if toks[idx] in AGG_OPS:
            agg_id = AGG_OPS.index(toks[idx])
            idx += 1
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append((agg_id, val_unit))
        if idx < len_ and toks[idx] == ',':
            idx += 1

    return idx, (isDistinct, val_units)


def parse_from(toks, start_idx, tables_with_alias, schema):
    len_ = len(toks)
    
    if 'from' not in toks[start_idx:]:
        return start_idx, [], [], []
    
    idx = toks.index('from', start_idx) + 1
    default_tables = []
    table_units = []
    conds = []

    while idx < len_:
        isBlock = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1

        if idx < len_ and toks[idx] == 'select':
            idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['sql'], sql))
        else:
            if idx < len_ and toks[idx] == 'join':
                idx += 1
            idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['table_unit'],table_unit))
            default_tables.append(table_name)
        if idx < len_ and toks[idx] == "on":
            idx += 1
            idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
            if len(conds) > 0:
                conds.append('and')
            conds.extend(this_conds)

        if isBlock:
            if idx < len_ and toks[idx] == ')':
                idx += 1
        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            break

    return idx, table_units, conds, default_tables

print("✓ Condition and FROM clause parsing loaded")

✓ Condition and FROM clause parsing loaded


In [144]:
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    if idx >= len(toks) or toks[idx] != 'where':
        return idx, []
    idx += 1
    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx, conds


def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    col_units = []
    if idx >= len(toks) or toks[idx] != 'group':
        return idx, col_units
    idx += 1
    if idx < len(toks) and toks[idx] == 'by':
        idx += 1
    while idx < len(toks) and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        col_units.append(col_unit)
        if idx < len(toks) and toks[idx] == ',':
            idx += 1
        else:
            break
    return idx, col_units


def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    val_units = []
    order_type = 'asc'
    if idx >= len(toks) or toks[idx] != 'order':
        return idx, val_units
    idx += 1
    if idx < len(toks) and toks[idx] == 'by':
        idx += 1
    while idx < len(toks) and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append(val_unit)
        if idx < len(toks) and toks[idx] in ORDER_OPS:
            order_type = toks[idx]
            idx += 1
        if idx < len(toks) and toks[idx] == ',':
            idx += 1
        else:
            break
    return idx, (order_type, val_units)


def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    if idx >= len(toks) or toks[idx] != 'having':
        return idx, []
    idx += 1
    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx, conds


def parse_limit(toks, start_idx):
    idx = start_idx
    if idx < len(toks) and toks[idx] == 'limit':
        idx += 1
        if idx < len(toks):
            try:
                return idx + 1, int(toks[idx])
            except:
                return idx, None
    return idx, None

print("✓ WHERE, GROUP BY, ORDER BY, HAVING, LIMIT parsing loaded")

✓ WHERE, GROUP BY, ORDER BY, HAVING, LIMIT parsing loaded


In [145]:
def parse_sql(toks, start_idx, tables_with_alias, schema):
    """Main SQL parsing function"""
    isBlock = False
    idx = start_idx

    sql = {}
    if idx < len(toks) and toks[idx] == '(':
        isBlock = True
        idx += 1

    # Parse FROM to get default tables
    from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
    sql['from'] = {'table_units': table_units, 'conds': conds}
    
    # Parse SELECT
    _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
    idx = from_end_idx
    sql['select'] = select_col_units
    
    # Parse WHERE
    idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
    sql['where'] = where_conds
    
    # Parse GROUP BY
    idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
    sql['groupBy'] = group_col_units
    
    # Parse HAVING
    idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
    sql['having'] = having_conds
    
    # Parse ORDER BY
    idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
    sql['orderBy'] = order_col_units
    
    # Parse LIMIT
    idx, limit_val = parse_limit(toks, idx)
    sql['limit'] = limit_val

    idx = skip_semicolon(toks, idx)
    if isBlock:
        if idx < len(toks) and toks[idx] == ')':
            idx += 1
    idx = skip_semicolon(toks, idx)

    # Parse INTERSECT/UNION/EXCEPT
    for op in SQL_OPS:
        sql[op] = None
    if idx < len(toks) and toks[idx] in SQL_OPS:
        sql_op = toks[idx]
        idx += 1
        idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
        sql[sql_op] = IUE_sql
    return idx, sql


def get_sql(schema, query):
    """Parse SQL query into structured representation"""
    try:
        toks = tokenize(query)
        if not toks:
            return None
        tables_with_alias = get_tables_with_alias(schema.schema, toks)
        _, sql = parse_sql(toks, 0, tables_with_alias, schema)
        return sql
    except Exception as e:
        return None

print("✓ Main SQL parsing function loaded")

✓ Main SQL parsing function loaded


## Spider Evaluation Functions

In [146]:
def get_scores(count, pred_total, label_total):
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0


def eval_sel(pred, label):
    pred_sel = pred['select'][1]
    label_sel = label['select'][1][:]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = cnt_wo_agg = 0
    
    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])
    
    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = cnt_wo_agg = 0
    
    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if len(unit) > 2 and unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])
    
    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    pred_cols = [str(unit[1]) for unit in pred['groupBy']]
    label_cols = [str(unit[1]) for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    
    return label_total, pred_total, cnt


def eval_having(pred, label):
    pred_total = 1 if len(pred['having']) > 0 else 0
    label_total = 1 if len(label['having']) > 0 else 0
    cnt = 1 if pred['having'] == label['having'] else 0
    return label_total, pred_total, cnt


def eval_order(pred, label):
    pred_total = 1 if len(pred['orderBy']) > 0 else 0
    label_total = 1 if len(label['orderBy']) > 0 else 0
    cnt = 1 if pred['orderBy'] == label['orderBy'] else 0
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    pred_ao = set(pred['where'][1::2]) if len(pred['where']) > 1 else set()
    label_ao = set(label['where'][1::2]) if len(label['where']) > 1 else set()
    if pred_ao == label_ao:
        return 1, 1, 1
    return len(label_ao), len(pred_ao), 0


def eval_IUEN(pred, label):
    lt = pt = cnt = 0
    for op in SQL_OPS:
        if pred[op] is not None:
            pt += 1
        if label[op] is not None:
            lt += 1
        if pred[op] is not None and label[op] is not None:
            cnt += 1
    return lt, pt, cnt


def get_keywords(sql):
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    for op in SQL_OPS:
        if sql[op] is not None:
            res.add(op)
    return res


def eval_keywords(pred, label):
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = len(pred_keywords & label_keywords)
    return label_total, pred_total, cnt

print("✓ Spider evaluation functions loaded")

✓ Spider evaluation functions loaded


In [147]:
class Evaluator:
    def __init__(self):
        self.partial_scores = None
    
    def eval_exact_match(self, pred, label):
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores
        
        for _, score in partial_scores.items():
            if score['f1'] != 1:
                return 0
        
        if len(label['from']['table_units']) > 0:
            label_tables = sorted([str(t) for t in label['from']['table_units']])
            pred_tables = sorted([str(t) for t in pred['from']['table_units']])
            return 1 if label_tables == pred_tables else 0
        return 1
    
    def eval_partial_match(self, pred, label):
        res = {}
        
        # SELECT
        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        
        # WHERE
        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        
        # GROUP BY
        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        
        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        
        # ORDER BY
        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        
        # AND/OR
        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        
        # IUEN
        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        
        # Keywords
        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        
        return res

print("✓ Evaluator class loaded")

✓ Evaluator class loaded


## Tolerant File Parser

In [148]:
def _strip_outer_quotes(s: str):
    s = s.strip()
    if s.endswith(","):
        s = s[:-1].rstrip()
    if len(s) >= 2 and s[0] == '"' and s[-1] == '"':
        return s[1:-1]
    return s


def parse_file(filepath):
    """Tolerant parser for SQL prediction results"""
    with open(filepath, "r", encoding="utf-8") as f:
        text = f.read()
    
    # Split sections
    m = re.search(r'^\s*DATABASE\s*:', text, flags=re.I|re.M)
    if not m:
        raise ValueError("Could not find 'DATABASE :' section")
    
    rows_text = text[:m.start()].strip()
    db_text = text[m.end():].strip()
    
    # Parse records line-by-line
    records = []
    current = None
    
    for line in rows_text.splitlines():
        t = line.strip()
        if t.startswith("{"):
            current = {"query": "", "sql_truth": "", "sql_pred": ""}
        elif t.startswith("}"):
            if current and all(current[k] for k in ["query", "sql_truth", "sql_pred"]):
                records.append(current)
            current = None
        elif current is not None:
            m = re.match(r'^\s*"(?P<key>query|sql_truth|sql_pred)"\s*:\s*(?P<val>.+)$', line)
            if m:
                current[m.group("key")] = _strip_outer_quotes(m.group("val"))
    
    # Parse database IDs
    start = db_text.find('[')
    end = db_text.rfind(']')
    if start == -1 or end == -1:
        raise ValueError("Could not find database list")
    
    db_parts = [p.strip() for p in db_text[start+1:end].split(',')]
    db_ids = []
    for p in db_parts:
        if p and len(p) >= 2 and p[0] in "'\"" and p[-1] in "'\"":
            db_ids.append(p[1:-1])
    
    # Combine
    n = min(len(records), len(db_ids))
    for i in range(n):
        records[i]['db_id'] = db_ids[i]
    
    return records[:n]

print("✓ Tolerant file parser loaded")

✓ Tolerant file parser loaded


## Helper Functions

In [149]:
def normalize_sql(sql: str) -> str:
    sql = " ".join((sql or "").split()).strip()
    if sql.endswith(";"):
        sql = sql[:-1]
    return sql.lower()


def extract_keywords(sql: str) -> set:
    tokens = re.findall(r"[A-Za-z_]+", (sql or "").lower())
    return set(t for t in tokens if t in SQL_KEYWORDS)


def find_db_path(db_id: str):
    p1 = os.path.join(DB_DIR, db_id, f"{db_id}.sqlite")
    p2 = os.path.join(DB_DIR, f"{db_id}.sqlite")
    return p1 if os.path.exists(p1) else (p2 if os.path.exists(p2) else None)


def exec_sql(db_path: str, sql: str):
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        conn.close()
        return True, rows
    except Exception as e:
        return False, str(e)


def results_equal(a, b):
    try:
        return sorted(tuple(r) for r in a) == sorted(tuple(r) for r in b)
    except:
        return a == b

print("✓ Helper functions loaded")

✓ Helper functions loaded


## Load Data

In [169]:
# Load data using tolerant parser
data = parse_file(FILE_PATH)
print(f"\n✓ Loaded {len(data)} query pairs")
print(f"  Unique databases: {len(set(r['db_id'] for r in data))}")
print(f"\n  Sample query: {data[0]['query'][:60]}...")
print(f"  Sample DB: {data[0]['db_id']}")


✓ Loaded 255 query pairs
  Unique databases: 126

  Sample query: How many heads of the departments are older than 56 ?...
  Sample DB: department_management


## Run Complete Evaluation

In [170]:
print("\n" + "=" * 80)
print("RUNNING COMPLETE EVALUATION")
print("=" * 80)

total = len(data)

# String-based metrics
exact_match = 0
kw_prec, kw_rec, kw_f1 = [], [], []

# Execution metrics
exec_match = 0
exec_attempted = 0
exec_pred_err = 0
exec_gold_err = 0

# Spider structural metrics
evaluator = Evaluator()
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 
                 'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords']
scores = {'count': 0, 'exact': 0, 'partial': {}}
for type_ in partial_types:
    scores['partial'][type_] = {'acc': 0, 'rec': 0, 'f1': 0, 'acc_count': 0, 'rec_count': 0}

parse_errors = 0
mismatches = []

for i, r in enumerate(data):
    gold_str = normalize_sql(r.get("sql_truth", ""))
    pred_str = normalize_sql(r.get("sql_pred", ""))
    db_id = r.get("db_id", "")
    
    # 1. Exact Match
    if pred_str == gold_str:
        exact_match += 1
    else:
        if len(mismatches) < SHOW_FIRST_N_MISMATCHES:
            mismatches.append((i, db_id, r.get("query"), gold_str, pred_str))
    
    # 2. Keyword Metrics
    kg = extract_keywords(gold_str)
    kp = extract_keywords(pred_str)
    inter = len(kg & kp)
    prec = inter / (len(kp) or 1)
    rec = inter / (len(kg) or 1)
    f1 = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    kw_prec.append(prec)
    kw_rec.append(rec)
    kw_f1.append(f1)
    
    # 3. Execution Match
    if USE_EXEC:
        path = find_db_path(db_id)
        if path:
            exec_attempted += 1
            ok_g, res_g = exec_sql(path, r["sql_truth"])
            ok_p, res_p = exec_sql(path, r["sql_pred"])
            
            if not ok_g:
                exec_gold_err += 1
            if not ok_p:
                exec_pred_err += 1
            
            if ok_g and ok_p and results_equal(res_g, res_p):
                exec_match += 1
    
    # 4. Spider Structural Evaluation
    path = find_db_path(db_id)
    if path:
        try:
            schema = Schema(get_schema(path))
            g_sql = get_sql(schema, r["sql_truth"])
            p_sql = get_sql(schema, r["sql_pred"])

            if g_sql and p_sql:
                scores['count'] += 1
                
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                
                scores['exact'] += exact_score
                
                for type_ in partial_types:
                    if partial_scores[type_]['pred_total'] > 0:
                        scores['partial'][type_]['acc'] += partial_scores[type_]['acc']
                        scores['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores['partial'][type_]['rec'] += partial_scores[type_]['rec']
                        scores['partial'][type_]['rec_count'] += 1
                    scores['partial'][type_]['f1'] += partial_scores[type_]['f1']
            else:
                parse_errors += 1
        except Exception as e:
            parse_errors += 1
    
    if (i + 1) % 50 == 0:
        print(f"Processed {i + 1}/{total}...")

# Calculate averages for Spider metrics
if scores['count'] > 0:
    scores['exact'] /= scores['count']
    for type_ in partial_types:
        if scores['partial'][type_]['acc_count'] > 0:
            scores['partial'][type_]['acc'] /= scores['partial'][type_]['acc_count']
        if scores['partial'][type_]['rec_count'] > 0:
            scores['partial'][type_]['rec'] /= scores['partial'][type_]['rec_count']
        scores['partial'][type_]['f1'] /= scores['count']

print(f"\n✓ Evaluation complete!")


RUNNING COMPLETE EVALUATION
Processed 50/255...
Processed 100/255...
Processed 150/255...
Processed 200/255...
Processed 250/255...

✓ Evaluation complete!


## Results: Part 1 - String Metrics

In [171]:
print("\n" + "=" * 80)
print("PART 1: STRING-BASED METRICS")
print("=" * 80)

em_pct = (exact_match / total * 100) if total > 0 else 0
avg_prec = mean(kw_prec) if kw_prec else 0
avg_rec = mean(kw_rec) if kw_rec else 0
avg_f1 = mean(kw_f1) if kw_f1 else 0

print(f"\n{'Metric':<40} {'Score':<15} {'Percentage'}")
print("-" * 80)
print(f"{'Exact String Match':<40} {exact_match:<15} {em_pct:.2f}%")
print(f"{'Keyword Precision (avg)':<40} {avg_prec:.4f}")
print(f"{'Keyword Recall (avg)':<40} {avg_rec:.4f}")
print(f"{'Keyword F1 (avg)':<40} {avg_f1:.4f}")


PART 1: STRING-BASED METRICS

Metric                                   Score           Percentage
--------------------------------------------------------------------------------
Exact String Match                       51              20.00%
Keyword Precision (avg)                  0.8696
Keyword Recall (avg)                     0.8859
Keyword F1 (avg)                         0.8693


## Results: Part 2 - Execution Accuracy

In [172]:
if USE_EXEC:
    print("\n" + "=" * 80)
    print("PART 2: EXECUTION ACCURACY")
    print("=" * 80)
    
    exec_pct = (exec_match / exec_attempted * 100) if exec_attempted > 0 else 0
    
    print(f"\nQueries Executed: {exec_attempted}")
    print(f"Execution Match:  {exec_match} ({exec_pct:.2f}%)")
    print(f"Prediction Error: {exec_pred_err} ({exec_pred_err/exec_attempted*100:.2f}%)")
    print(f"Gold Error:       {exec_gold_err} ({exec_gold_err/exec_attempted*100:.2f}%)")
    
    valid = exec_attempted - exec_pred_err - exec_gold_err
    if valid > 0:
        eff_acc = (exec_match / valid * 100)
        print(f"\nEffective Accuracy (valid executions only): {exec_match}/{valid} ({eff_acc:.2f}%)")


PART 2: EXECUTION ACCURACY

Queries Executed: 255
Execution Match:  189 (74.12%)
Prediction Error: 1 (0.39%)
Gold Error:       1 (0.39%)

Effective Accuracy (valid executions only): 189/253 (74.70%)


## Results: Part 3 - Spider Component Matching

In [173]:
if scores['count'] > 0:
    print("\n" + "=" * 80)
    print("PART 3: SPIDER-STYLE STRUCTURAL COMPONENT MATCHING")
    print("=" * 80)
    
    print(f"\nSuccessfully Parsed: {scores['count']}/{total}")
    print(f"Parsing Errors: {parse_errors}")
    
    print('\n====================== EXACT MATCHING ACCURACY =====================')
    print(f"{'exact match':<20} {scores['exact']:<20.3f}")
    
    print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
    for type_ in partial_types:
        print(f"{type_:<20} {scores['partial'][type_]['acc']:<20.3f}")
    
    print('\n---------------------- PARTIAL MATCHING RECALL ----------------------')
    for type_ in partial_types:
        print(f"{type_:<20} {scores['partial'][type_]['rec']:<20.3f}")
    
    print('\n---------------------- PARTIAL MATCHING F1 --------------------------')
    for type_ in partial_types:
        print(f"{type_:<20} {scores['partial'][type_]['f1']:<20.3f}")
else:
    print("\n⚠ Spider evaluation skipped (parsing errors)")


PART 3: SPIDER-STYLE STRUCTURAL COMPONENT MATCHING

Successfully Parsed: 151/255
Parsing Errors: 104

exact match          0.000               

---------------------PARTIAL MATCHING ACCURACY----------------------
select               0.709               
select(no AGG)       0.709               
where                0.522               
where(no OP)         0.687               
group(no Having)     0.880               
group                0.167               
order                0.800               
and/or               1.000               
IUEN                 1.000               
keywords             0.889               

---------------------- PARTIAL MATCHING RECALL ----------------------
select               0.709               
select(no AGG)       0.709               
where                0.636               
where(no OP)         0.836               
group(no Having)     0.917               
group                0.167               
order                0.667               


### Error diagonostic errors

In [174]:
import pandas as pd
from collections import defaultdict

print("=" * 80)
print("PARSING ERROR ANALYSIS")
print("=" * 80)

# Track all error types
error_data = {
    'Database Not Found': [],
    'Tokenize Gold Failed (empty tokens)': [],
    'Tokenize Pred Failed (empty tokens)': [],
    'Parse Gold Failed (returned None)': [],
    'Parse Pred Failed (returned None)': [],
    'Schema Error': [],
    'Assertion Error': [],
    'Key Error': [],
    'Other Errors': [],
    'Successfully Parsed': []
}

for i, r in enumerate(data):
    db_id = r.get("db_id", "")
    path = find_db_path(db_id)
    
    if not path:
        error_data['Database Not Found'].append({
            'index': i,
            'db_id': db_id,
            'query': r['query'][:60],
            'error': 'DB file not found'
        })
        continue
    
    try:
        # Tokenization
        gold_toks = tokenize(r["sql_truth"])
        pred_toks = tokenize(r["sql_pred"])
        
        if not gold_toks:
            error_data['Tokenize Gold Failed (empty tokens)'].append({
                'index': i,
                'db_id': db_id,
                'query': r['query'][:60],
                'sql': r["sql_truth"][:80]
            })
            continue
        
        if not pred_toks:
            error_data['Tokenize Pred Failed (empty tokens)'].append({
                'index': i,
                'db_id': db_id,
                'query': r['query'][:60],
                'sql': r["sql_pred"][:80]
            })
            continue
        
        # Parsing
        schema = Schema(get_schema(path))
        g_sql = get_sql(schema, r["sql_truth"])
        p_sql = get_sql(schema, r["sql_pred"])
        
        if g_sql is None:
            error_data['Parse Gold Failed (returned None)'].append({
                'index': i,
                'db_id': db_id,
                'query': r['query'][:60],
                'sql': r["sql_truth"][:80]
            })
            continue
        
        if p_sql is None:
            error_data['Parse Pred Failed (returned None)'].append({
                'index': i,
                'db_id': db_id,
                'query': r['query'][:60],
                'sql': r["sql_pred"][:80]
            })
            continue
        
        # Success
        error_data['Successfully Parsed'].append({
            'index': i,
            'db_id': db_id,
            'query': r['query'][:60]
        })
        
    except AssertionError as e:
        error_data['Assertion Error'].append({
            'index': i,
            'db_id': db_id,
            'query': r['query'][:60],
            'error': str(e)[:60]
        })
    except KeyError as e:
        error_data['Key Error'].append({
            'index': i,
            'db_id': db_id,
            'query': r['query'][:60],
            'error': f"Missing: {e}"
        })
    except Exception as e:
        error_data['Other Errors'].append({
            'index': i,
            'db_id': db_id,
            'query': r['query'][:60],
            'error': f"{type(e).__name__}: {str(e)[:40]}"
        })

# Create summary table
summary_data = []
for category, errors in error_data.items():
    summary_data.append({
        'Error Category': category,
        'Count': len(errors),
        'Percentage': f"{len(errors)/len(data)*100:.2f}%"
    })

summary_df = pd.DataFrame(summary_data)
summary_df = summary_df.sort_values('Count', ascending=False)

print("\n" + "=" * 80)
print("SUMMARY TABLE")
print("=" * 80)
print(summary_df.to_string(index=False))

# Detailed breakdown of each error type
print("\n" + "=" * 80)
print("DETAILED BREAKDOWN (First 3 examples of each error type)")
print("=" * 80)

for category, errors in error_data.items():
    if errors and category != 'Successfully Parsed':
        print(f"\n{'=' * 80}")
        print(f"{category.upper()} ({len(errors)} total)")
        print(f"{'=' * 80}")
        
        # Show first 3 examples
        for example in errors[:3]:
            print(f"\n  Index: {example['index']}")
            print(f"  DB: {example['db_id']}")
            print(f"  Query: {example['query']}...")
            if 'sql' in example:
                print(f"  SQL: {example['sql']}...")
            if 'error' in example:
                print(f"  Error: {example['error']}")

# Export to CSV (optional)
print("\n" + "=" * 80)
print("EXPORTING DETAILED RESULTS")
print("=" * 80)

all_errors = []
for category, errors in error_data.items():
    for error in errors:
        error['Category'] = category
        all_errors.append(error)

if all_errors:
    error_df = pd.DataFrame(all_errors)
    error_df = error_df[['Category', 'index', 'db_id', 'query'] + 
                        [col for col in error_df.columns if col not in ['Category', 'index', 'db_id', 'query']]]
    error_df.to_csv('parsing_errors_detailed.csv', index=False)
    print(f"Exported {len(all_errors)} records to 'parsing_errors_detailed.csv'")
    
    # Show preview
    print("\nPreview of export:")
    print(error_df.head(10).to_string(index=False))

PARSING ERROR ANALYSIS

SUMMARY TABLE
                     Error Category  Count Percentage
                Successfully Parsed    151     59.22%
  Parse Pred Failed (returned None)    104     40.78%
Tokenize Gold Failed (empty tokens)      0      0.00%
                 Database Not Found      0      0.00%
  Parse Gold Failed (returned None)      0      0.00%
Tokenize Pred Failed (empty tokens)      0      0.00%
                       Schema Error      0      0.00%
                    Assertion Error      0      0.00%
                          Key Error      0      0.00%
                       Other Errors      0      0.00%

DETAILED BREAKDOWN (First 3 examples of each error type)

PARSE PRED FAILED (RETURNED NONE) (104 total)

  Index: 1
  DB: farm
  Query: What are the hosts of competitions whose theme is not "Alien...
  SQL: SELECT DISTINCT c.official_name FROM farm_competition fc JOIN city c ON fc.host_...

  Index: 2
  DB: farm
  Query: Please show the themes of competitions with 

In [175]:
# Diagnostic: Compare gold vs pred for failed parses
print("=" * 80)
print("ANALYZING FAILED PREDICTIONS")
print("=" * 80)

failed_examples = []

for i, r in enumerate(data):
    db_id = r.get("db_id", "")
    path = find_db_path(db_id)
    
    if not path:
        continue
    
    try:
        schema = Schema(get_schema(path))
        g_sql = get_sql(schema, r["sql_truth"])
        p_sql = get_sql(schema, r["sql_pred"])
        
        # Gold parses but pred doesn't
        if g_sql is not None and p_sql is None:
            failed_examples.append({
                'index': i,
                'gold': r["sql_truth"],
                'pred': r["sql_pred"],
                'gold_tokens': len(tokenize(r["sql_truth"])),
                'pred_tokens': len(tokenize(r["sql_pred"]))
            })
    except:
        pass

print(f"\nFound {len(failed_examples)} cases where gold parses but pred fails\n")

# Show first 10 examples
for idx, example in enumerate(failed_examples[:10]):
    print(f"\n{'='*80}")
    print(f"Example {idx+1} (Query #{example['index']})")
    print(f"{'='*80}")
    print(f"Gold SQL ({example['gold_tokens']} tokens):")
    print(f"  {example['gold']}")
    print(f"\nPred SQL ({example['pred_tokens']} tokens):")
    print(f"  {example['pred']}")
    
    # Check for common issues
    issues = []
    if len(example['pred']) > len(example['gold']) + 50:
        issues.append("Pred much longer than gold")
    if example['pred'].count('(') != example['pred'].count(')'):
        issues.append("Unbalanced parentheses")
    if '...' in example['pred'] or example['pred'].endswith('...'):
        issues.append("Truncated SQL (contains ...)")
    if '\n' in example['pred']:
        issues.append("Contains newlines")
    
    if issues:
        print(f"\nPotential issues: {', '.join(issues)}")

ANALYZING FAILED PREDICTIONS

Found 104 cases where gold parses but pred fails


Example 1 (Query #1)
Gold SQL (8 tokens):
  SELECT Hosts FROM farm_competition WHERE Theme !=  'Aliens'

Pred SQL (18 tokens):
  SELECT DISTINCT c.official_name FROM farm_competition fc JOIN city c ON fc.host_city_id = c.city_id WHERE fc.theme <> 'Aliens'

Potential issues: Pred much longer than gold

Example 2 (Query #2)
Gold SQL (18 tokens):
  SELECT T2.Theme FROM city AS T1 JOIN farm_competition AS T2 ON T1.City_ID  =  T2.Host_city_ID WHERE T1.Population  >  1000

Pred SQL (16 tokens):
  SELECT fc.theme FROM farm_competition fc JOIN city c ON fc.host_city_id = c.city_id WHERE c.population > 1000

Example 3 (Query #5)
Gold SQL (23 tokens):
  SELECT T1.lat ,  T1.long ,  T1.city FROM station AS T1 JOIN trip AS T2 ON T1.id  =  T2.start_station_id ORDER BY T2.duration LIMIT 1

Pred SQL (22 tokens):
  SELECT s.lat, s.long, s.city FROM station s JOIN trip t ON s.id = t.start_station_id ORDER BY t.duration ASC 

In [176]:
# Update your summary to emphasize this:

print("=" * 80)
print("EVALUATION INTERPRETATION")
print("=" * 80)

print(f"\n1. EXECUTION ACCURACY (Most Important - Does it work?)")
print(f"   → {exec_pct:.2f}% of queries return correct results")
print(f"   This is your PRIMARY metric - queries that work correctly.\n")

print(f"2. STRING MATCHING (Least Important - Exact text match)")
print(f"   → {em_pct:.2f}% exact string matches")
print(f"   Low scores expected due to stylistic differences (aliases, spacing, etc.)\n")

print(f"3. SPIDER STRUCTURAL (Medium - Component matching)")
print(f"   → Evaluated on {scores['count']}/{total} queries ({scores['count']/total*100:.1f}%)")
print(f"   → {scores['exact']*100:.2f}% exact structural match")
print(f"   Parser is strict about syntax style, so some valid queries fail to parse.")
print(f"   This metric works best when gold and pred use similar SQL dialects.\n")

print("RECOMMENDATION:")
print("Use EXECUTION ACCURACY as your primary evaluation metric.")
print("Spider metrics are supplementary and only apply to parseable queries.")

EVALUATION INTERPRETATION

1. EXECUTION ACCURACY (Most Important - Does it work?)
   → 74.12% of queries return correct results
   This is your PRIMARY metric - queries that work correctly.

2. STRING MATCHING (Least Important - Exact text match)
   → 20.00% exact string matches
   Low scores expected due to stylistic differences (aliases, spacing, etc.)

3. SPIDER STRUCTURAL (Medium - Component matching)
   → Evaluated on 151/255 queries (59.2%)
   → 0.00% exact structural match
   Parser is strict about syntax style, so some valid queries fail to parse.
   This metric works best when gold and pred use similar SQL dialects.

RECOMMENDATION:
Use EXECUTION ACCURACY as your primary evaluation metric.
Spider metrics are supplementary and only apply to parseable queries.


In [177]:
if scores['count'] > 0:
    print("\n" + "=" * 80)
    print("SPIDER PARTIAL COMPONENT SCORES")
    print("=" * 80)
    
    for type_ in partial_types:
        f1 = scores['partial'][type_]['f1']
        acc = scores['partial'][type_]['acc']
        rec = scores['partial'][type_]['rec']
        print(f"{type_:<25} F1: {f1:.3f}  |  Acc: {acc:.3f}  |  Rec: {rec:.3f}")


SPIDER PARTIAL COMPONENT SCORES
select                    F1: 0.709  |  Acc: 0.709  |  Rec: 0.709
select(no AGG)            F1: 0.709  |  Acc: 0.709  |  Rec: 0.709
where                     F1: 0.788  |  Acc: 0.522  |  Rec: 0.636
where(no OP)              F1: 0.861  |  Acc: 0.687  |  Rec: 0.836
group(no Having)          F1: 0.980  |  Acc: 0.880  |  Rec: 0.917
group                     F1: 0.007  |  Acc: 0.167  |  Rec: 0.167
order                     F1: 0.212  |  Acc: 0.800  |  Rec: 0.667
and/or                    F1: 0.974  |  Acc: 1.000  |  Rec: 0.974
IUEN                      F1: 0.987  |  Acc: 1.000  |  Rec: 0.333
keywords                  F1: 0.907  |  Acc: 0.889  |  Rec: 0.897


## Score Analysis

### Strong Performance `(F1 > 0.85)`:

- **GROUP BY columns**: 0.980 - Nearly perfect at identifying what to group by
- **IUEN (UNION/INTERSECT/EXCEPT)**: 0.987 - Excellent at set operations
- **AND/OR logic**: 0.974 - Great at logical operators
- **Keywords**: 0.907 - Uses correct SQL clauses
- **WHERE (no operators)**: 0.861 - Good at identifying which columns to filter

### Critical Weaknesses:

#### `HAVING` clauses: `0.007 F1` - Almost complete failure

The model groups correctly (0.980) but can't filter grouped results
This is a significant gap in SQL generation capability


#### `ORDER BY`: `0.212 F1` - Poor performance

Low accuracy (0.800) but terrible recall (0.667)
Model often omits ORDER BY when needed


#### `WHERE` operators: `0.788 F1` - Decent but room for improvement

Identifies the right columns (`0.861`) but struggles with exact operator matching



### Actionable Insights:
The `74%` execution accuracy is solid, but these component scores show where to improve:

- **Priority fix**: Train on more HAVING clause examples
- **Secondary fix**: Improve ORDER BY generation (especially knowing when it's required)
- **Fine-tuning**: Better operator selection in WHERE clauses (=, !=, <>, >, <, etc.)

The model understands SQL structure well (high keyword/logic scores) but needs refinement on advanced filtering and sorting.

## Sample Mismatches

In [178]:
if mismatches:
    print("\n" + "=" * 80)
    print(f"SAMPLE MISMATCHES (First {len(mismatches)})")
    print("=" * 80)
    
    for idx, (i, db_id, query, gold, pred) in enumerate(mismatches, 1):
        print(f"\n[{idx}] Query #{i} | DB: {db_id}")
        print(f"Q: {query[:70]}...")
        print(f"G: {gold[:90]}...")
        print(f"P: {pred[:90]}...")
        
        kg = extract_keywords(gold)
        kp = extract_keywords(pred)
        if kg - kp:
            print(f"Missing keywords: {kg - kp}")
        if kp - kg:
            print(f"Extra keywords: {kp - kg}")
        print("-" * 80)
else:
    print("\n✓ All queries matched exactly!")


SAMPLE MISMATCHES (First 10)

[1] Query #1 | DB: farm
Q: What are the hosts of competitions whose theme is not "Aliens"?...
G: select hosts from farm_competition where theme != 'aliens'...
P: select distinct c.official_name from farm_competition fc join city c on fc.host_city_id = ...
Extra keywords: {'distinct', 'on', 'join'}
--------------------------------------------------------------------------------

[2] Query #2 | DB: farm
Q: Please show the themes of competitions with host cities having populat...
G: select t2.theme from city as t1 join farm_competition as t2 on t1.city_id = t2.host_city_i...
P: select fc.theme from farm_competition fc join city c on fc.host_city_id = c.city_id where ...
Missing keywords: {'as'}
--------------------------------------------------------------------------------

[3] Query #3 | DB: student_assessment
Q: What are the ids of the students who either registered or attended a c...
G: select student_id from student_course_registrations union select stu

## Final Summary

In [179]:
print("\n" + "=" * 80)
print("COMPREHENSIVE EVALUATION SUMMARY")
print("=" * 80)

print(f"\nDataset: {total} queries across {len(set(r['db_id'] for r in data))} databases")

print(f"\n1. STRING MATCHING:")
print(f"   Exact Match:     {em_pct:>6.2f}%")
print(f"   Keyword F1:      {avg_f1:>6.4f}")

if USE_EXEC:
    print(f"\n2. EXECUTION:")
    print(f"   Execution Match: {exec_pct:>6.2f}%")
    if valid > 0:
        print(f"   Effective (valid only): {eff_acc:>6.2f}%")

if scores['count'] > 0:
    print(f"\n3. SPIDER STRUCTURAL:")
    print(f"   Exact Match:     {scores['exact']*100:>6.2f}%")
    print(f"   SELECT F1:       {scores['partial']['select']['f1']*100:>6.2f}%")
    print(f"   WHERE F1:        {scores['partial']['where']['f1']*100:>6.2f}%")
    print(f"   GROUP BY F1:     {scores['partial']['group']['f1']*100:>6.2f}%")
    print(f"   ORDER BY F1:     {scores['partial']['order']['f1']*100:>6.2f}%")

print("\n" + "=" * 80)
print("✓ EVALUATION COMPLETE")
print("=" * 80)


COMPREHENSIVE EVALUATION SUMMARY

Dataset: 255 queries across 126 databases

1. STRING MATCHING:
   Exact Match:      20.00%
   Keyword F1:      0.8693

2. EXECUTION:
   Execution Match:  74.12%
   Effective (valid only):  74.70%

3. SPIDER STRUCTURAL:
   Exact Match:       0.00%
   SELECT F1:        70.86%
   WHERE F1:         78.81%
   GROUP BY F1:       0.66%
   ORDER BY F1:      21.19%

✓ EVALUATION COMPLETE
