# SQL Evaluation Code - Complete Function Reference

This notebook provides comprehensive documentation for the SQL evaluation system used in text-to-SQL benchmarks like Spider.

## Table of Contents
1. [Constants & Configuration](#constants)
2. [Utility Helper Functions](#utility)
3. [Component Evaluation Functions](#component-eval)
4. [Hardness & Complexity Functions](#hardness)
5. [SQL Rebuilding Functions](#rebuilding)
6. [Foreign Key Mapping Functions](#foreign-keys)
7. [Main Evaluation Functions](#main-eval)
8. [Scoring & Display Functions](#scoring)

<a id='constants'></a>
## 1. Constants & Configuration

### Global Flags

In [1]:
# Flag to disable value evaluation
DISABLE_VALUE = True

# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True

**Purpose**: Control whether literal values and DISTINCT keywords are compared during evaluation.

- `DISABLE_VALUE = True`: Ignores literal values (e.g., 18, 'NYC') in comparisons, focusing only on query structure
- `DISABLE_DISTINCT = True`: Ignores DISTINCT keywords in SELECT clauses

### SQL Component Keywords

In [2]:
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 
                   'intersect', 'union', 'except')

JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 
             'in', 'like', 'is', 'exists')

UNIT_OPS = ('none', '-', '+', "*", '/')

AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')

COND_OPS = ('and', 'or')

SQL_OPS = ('intersect', 'union', 'except')

ORDER_OPS = ('desc', 'asc')

**Purpose**: Define all valid SQL operators and keywords used in query parsing and evaluation.

These tuples define the vocabulary of SQL operations. The index in each tuple is used as an ID:
- `WHERE_OPS[2]` = '=' (equals operator)
- `AGG_OPS[1]` = 'max' (maximum aggregation)
- `UNIT_OPS[0]` = 'none' (no arithmetic operation)
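
As a quick illustration of the index-as-ID convention, this sketch copies the tuples from the cell above and converts between keywords and IDs in both directions:

```python
# Vocabularies copied from the cell above: a token's position is its integer ID.
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=',
             'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')

# Keyword -> ID, as stored in parsed condition and column units.
like_id = WHERE_OPS.index('like')
count_id = AGG_OPS.index('count')

# ID -> keyword, handy when printing a parsed unit.
print(like_id, WHERE_OPS[like_id])   # 9 like
print(count_id, AGG_OPS[count_id])   # 3 count
print(UNIT_OPS[2])                   # +
```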

### Table Types and Hardness Components

In [3]:
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}

**Purpose**: 
- `TABLE_TYPE`: Distinguish between regular table references and subqueries in the FROM clause
- `HARDNESS`: Define which SQL features count toward query complexity scoring
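
Every function below indexes into a parsed-SQL dictionary, not raw SQL text. The sketch below reconstructs that layout from the keys and tuple positions the functions in this notebook actually read. `make_sql` is a hypothetical helper added for illustration and is not part of the evaluation code.

```python
def make_sql(overrides=None):
    """Hypothetical helper: a parsed-SQL dict with every clause empty."""
    sql = {
        'select': (False, []),          # (is_distinct, [(agg_id, val_unit), ...])
        'from': {'table_units': [],     # [(TABLE_TYPE key, '__table__' or nested sql), ...]
                 'conds': []},          # join conditions, same shape as 'where'
        'where': [],                    # [cond_unit, 'and'/'or', cond_unit, ...]
        'groupBy': [],                  # [col_unit, ...]
        'having': [],                   # same shape as 'where'
        'orderBy': [],                  # ('asc'/'desc', [val_unit, ...]) when present
        'limit': None,                  # int, or None when absent
        'intersect': None,              # nested SQL dict, or None
        'except': None,
        'union': None,
    }
    sql.update(overrides or {})
    return sql

# col_unit  = (agg_id, col_id, is_distinct)
# val_unit  = (unit_op, col_unit1, col_unit2)   col_unit2 is None without arithmetic
# cond_unit = (not_op, op_id, val_unit, val1, val2)
age = (0, '__students.age__', False)

# SELECT age FROM students WHERE age > 18
query = make_sql({
    'select': (False, [(0, (0, age, None))]),
    'from': {'table_units': [('table_unit', '__students__')], 'conds': []},
    'where': [(False, 3, (0, age, None), 18, None)],   # op_id 3 is '>' in WHERE_OPS
})
print(query['where'][0][1], query['limit'])   # 3 None
```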

<a id='utility'></a>
## 2. Utility Helper Functions

### Condition Analysis Functions

In [4]:
def condition_has_or(conds):
    """
    Check if a condition list contains any OR operators.
    
    Args:
        conds (list): Condition list in format [cond_unit1, 'and'/'or', cond_unit2, ...]
    
    Returns:
        bool: True if any OR operator exists
    
    Logic:
        Checks the odd indices, which hold the 'and'/'or' connectors
    
    Example:
        conds = [cond1, 'or', cond2, 'and', cond3]
        Returns: True (because 'or' is at index 1)
    """
    return 'or' in conds[1::2]

In [5]:
def condition_has_like(conds):
    """
    Check if any condition uses the LIKE operator.
    
    Args:
        conds (list): Condition list
    
    Returns:
        bool: True if LIKE operator is used
    
    Logic:
        Checks the operator ID (index 1) of each condition unit (even indices)
        LIKE has index 9 in the WHERE_OPS tuple (looked up via WHERE_OPS.index('like'))
    """
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]

In [6]:
def condition_has_sql(conds):
    """
    Check if conditions contain nested SQL queries (subqueries).
    
    Args:
        conds (list): Condition list
    
    Returns:
        bool: True if any nested SQL exists
    
    Logic:
        Examines values in condition units (indices 3 and 4)
        If value is a dict, it represents a subquery
    
    Example:
        WHERE student_id IN (SELECT id FROM honors)  # val1 is dict = True
        WHERE age > 18  # val1 is int = False
    """
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False
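
The three condition helpers can be exercised on a hand-built condition list. In this sketch the column IDs and the `subquery` stand-in dict are hypothetical, and `condition_has_sql` is rewritten compactly with `any()` but performs the same check:

```python
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=',
             'in', 'like', 'is', 'exists')

def condition_has_or(conds):
    return 'or' in conds[1::2]

def condition_has_like(conds):
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]

def condition_has_sql(conds):
    # Same check as the cell above, written with any()
    return any(type(u[3]) is dict or type(u[4]) is dict for u in conds[::2])

name = (0, (0, '__students.name__', False), None)   # val_unit for students.name
sid = (0, (0, '__students.id__', False), None)      # val_unit for students.id
subquery = {'select': (False, [])}                   # stand-in for a nested SQL dict

# WHERE name LIKE '%a%' OR id IN (SELECT ...)
conds = [(False, WHERE_OPS.index('like'), name, "'%a%'", None),
         'or',
         (False, WHERE_OPS.index('in'), sid, subquery, None)]

print(condition_has_or(conds), condition_has_like(conds), condition_has_sql(conds))
# True True True
```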

### Unit Analysis Functions

In [7]:
def val_has_op(val_unit):
    """
    Check if a value unit has an arithmetic operation (+, -, *, /).
    
    Args:
        val_unit (tuple): (unit_op, col_unit1, col_unit2)
    
    Returns:
        bool: True if unit_op is not 'none'
    
    Example:
        val_unit = (2, col1, col2)  # UNIT_OPS[2] = '+'
        Returns: True (has addition operation)
    """
    return val_unit[0] != UNIT_OPS.index('none')

In [8]:
def has_agg(unit):
    """
    Check if a column unit has an aggregation function.
    
    Args:
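        unit (tuple): Any unit whose first element is an aggregation ID, e.g. a
            col_unit (agg_id, col_id, isDistinct) or a SELECT item (agg_id, val_unit)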
    
    Returns:
        bool: True if aggregation exists (not 'none')
    
    Example:
        unit = (3, 'age', False)  # AGG_OPS[3] = 'count'
        Returns: True (COUNT aggregation)
    """
    return unit[0] != AGG_OPS.index('none')
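
To see both checks in action, the sketch below builds a `salary + bonus` value unit and a `COUNT(*)` column unit. The `__emp...__` column IDs are hypothetical. `'__all__'` is the ID that `build_foreign_key_map` assigns to the `*` column.

```python
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')

def val_has_op(val_unit):
    return val_unit[0] != UNIT_OPS.index('none')

def has_agg(unit):
    return unit[0] != AGG_OPS.index('none')

salary = (0, '__emp.salary__', False)                    # col_unit, no aggregation
bonus = (0, '__emp.bonus__', False)
total_pay = (UNIT_OPS.index('+'), salary, bonus)         # val_unit for salary + bonus
count_all = (AGG_OPS.index('count'), '__all__', False)   # col_unit for COUNT(*)

print(val_has_op(total_pay), val_has_op((0, salary, None)))   # True False
print(has_agg(count_all), has_agg(salary))                    # True False
```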

### Scoring Functions

In [9]:
def accuracy(count, total):
    """
    Binary accuracy - returns 1 only if all components match.
    
    Args:
        count (int): Number of correct predictions
        total (int): Total number of items
    
    Returns:
        int: 1 if perfect match, 0 otherwise
    """
    if count == total:
        return 1
    return 0

In [10]:
def recall(count, total):
    """
    Binary recall - returns 1 only if all gold components were found.
    
    Args:
        count (int): Number of found items
        total (int): Total number of gold items
    
    Returns:
        int: 1 if perfect recall, 0 otherwise
    """
    if count == total:
        return 1
    return 0

In [11]:
def F1(acc, rec):
    """
    Calculate the F1 score: the harmonic mean of accuracy and recall.
    Here `acc` plays the role of precision.
    
    Args:
        acc (float): Accuracy (precision) value
        rec (float): Recall value
    
    Returns:
        float: F1 score (0.0 to 1.0); 0 when acc + rec == 0
    
    Formula:
        F1 = 2 * acc * rec / (acc + rec)
    """
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)

In [12]:
def get_scores(count, pred_total, label_total):
    """
    Calculate accuracy, recall, and F1 all at once (binary version).
    
    Args:
        count (int): Number of matching items
        pred_total (int): Total predicted items
        label_total (int): Total gold items
    
    Returns:
        tuple: (accuracy, recall, f1) - all either 0 or 1
    
    Logic:
        - If totals don't match → (0, 0, 0)
        - If all items match (including when both totals are 0) → (1, 1, 1)
        - Otherwise → (0, 0, 0)
    """
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0
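
The scoring is strictly all-or-nothing. One easily missed consequence is that a clause absent from both queries (both totals 0) scores 1. This sketch copies `F1` and `get_scores` from the cells above:

```python
def F1(acc, rec):
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)

def get_scores(count, pred_total, label_total):
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0

print(get_scores(2, 2, 2))  # (1, 1, 1): every item matched
print(get_scores(1, 2, 2))  # (0, 0, 0): one item missed
print(get_scores(1, 1, 2))  # (0, 0, 0): item counts differ
print(get_scores(0, 0, 0))  # (1, 1, 1): clause absent from both queries
print(F1(0.5, 1.0))         # about 0.667
```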

<a id='component-eval'></a>
## 3. Component Evaluation Functions

### SELECT Clause Evaluation

In [13]:
def eval_sel(pred, label):
    """
    Evaluate SELECT clause matching.
    
    Args:
        pred (dict): Predicted SQL structure
        label (dict): Gold SQL structure
    
    Returns:
        tuple: (label_total, pred_total, cnt, cnt_wo_agg)
            - label_total: Number of columns in gold SELECT
            - pred_total: Number of columns in predicted SELECT
            - cnt: Number of exact matches (with aggregation)
            - cnt_wo_agg: Number of matches ignoring aggregation
    
    Example:
        Gold: SELECT COUNT(name), age FROM students
        Pred: SELECT COUNT(name), city FROM students
        
        label_total = 2
        pred_total = 2
        cnt = 1 (COUNT(name) matches)
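        cnt_wo_agg = 1 (the name column matches; it would still count if the aggregations differed)
    """
    pred_sel = pred['select'][1]
    label_sel = list(label['select'][1])  # copy: matched items are removed below, so don't mutate the gold SQL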
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg
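
Here is a run of the SELECT comparison where the aggregations differ, so the exact count and the aggregation-free count diverge. This version copies the gold list so that removing matched items leaves the caller's gold structure intact. `col` is a hypothetical helper that builds a plain column value unit.

```python
def eval_sel(pred, label):
    pred_sel = pred['select'][1]
    label_sel = list(label['select'][1])   # copy so the gold SQL is not mutated
    label_wo_agg = [unit[1] for unit in label_sel]
    cnt = cnt_wo_agg = 0
    for unit in pred_sel:
        if unit in label_sel:              # same aggregation and same val_unit
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:        # same val_unit, aggregation ignored
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])
    return len(label['select'][1]), len(pred_sel), cnt, cnt_wo_agg

def col(name):
    # Hypothetical helper: val_unit for a plain students.<name> column
    return (0, (0, '__students.' + name + '__', None), None)

gold = {'select': (False, [(3, col('name')), (0, col('age'))])}   # SELECT COUNT(name), age
pred = {'select': (False, [(1, col('name')), (0, col('city'))])}  # SELECT MAX(name), city
print(eval_sel(pred, gold))   # (2, 2, 0, 1)
```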

### WHERE Clause Evaluation

In [14]:
def eval_where(pred, label):
    """
    Evaluate WHERE clause conditions.
    
    Args:
        pred (dict): Predicted SQL
        label (dict): Gold SQL
    
    Returns:
        tuple: (label_total, pred_total, cnt, cnt_wo_agg)
            - cnt: Exact condition matches (with operators)
            - cnt_wo_agg: Matches ignoring operators
    
    Example:
        Gold: WHERE age > 18 AND city = 'NYC'
        Pred: WHERE age >= 18 AND city = 'NYC'
        
        cnt = 1 (only city condition matches exactly)
        cnt_wo_agg = 2 (both conditions use correct columns)
    """
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg
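
The docstring example can be reproduced directly. With `DISABLE_VALUE = True` the literals have already been replaced by `None`, so only the operator distinguishes `age > 18` from `age >= 18`. The column IDs are hypothetical.

```python
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=',
             'in', 'like', 'is', 'exists')

def eval_where(pred, label):
    pred_conds = pred['where'][::2]
    label_conds = label['where'][::2]           # slicing already makes a copy
    label_wo_op = [unit[2] for unit in label_conds]
    cnt = cnt_wo_op = 0
    for unit in pred_conds:
        if unit in label_conds:                 # operator and val_unit both match
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_op:              # val_unit matches, operator ignored
            cnt_wo_op += 1
            label_wo_op.remove(unit[2])
    return len(label['where'][::2]), len(pred_conds), cnt, cnt_wo_op

age = (0, (0, '__students.age__', None), None)
city = (0, (0, '__students.city__', None), None)
op = WHERE_OPS.index

gold = {'where': [(False, op('>'), age, None, None), 'and',
                  (False, op('='), city, None, None)]}
pred = {'where': [(False, op('>='), age, None, None), 'and',
                  (False, op('='), city, None, None)]}
print(eval_where(pred, gold))   # (2, 2, 1, 2)
```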

### GROUP BY Evaluation

In [15]:
def eval_group(pred, label):
    """
    Evaluate GROUP BY clause (without HAVING).
    
    Args:
        pred (dict): Predicted SQL
        label (dict): Gold SQL
    
    Returns:
        tuple: (label_total, pred_total, cnt)
    
    Logic:
        - Extracts column IDs from groupBy lists
        - Removes table prefixes (e.g., '__students.age__' → 'age__'), so same-named
          columns from different tables are treated as equal
        - Counts matching columns
    """
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    pred_cols = [col.split(".")[1] if "." in col else col for col in pred_cols]
    label_cols = [col.split(".")[1] if "." in col else col for col in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt
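
Because table prefixes are stripped, a GROUP BY on a same-named column from a different table still counts as a match. This sketch reproduces that leniency. `strip_table` is a hypothetical name for the inline prefix-stripping expression.

```python
def strip_table(col_id):
    # '__students.dept__' -> 'dept__'; IDs without a '.' (e.g. '__all__') are kept
    return col_id.split(".")[1] if "." in col_id else col_id

def eval_group(pred, label):
    pred_cols = [strip_table(unit[1]) for unit in pred['groupBy']]
    label_cols = [strip_table(unit[1]) for unit in label['groupBy']]
    cnt = 0
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return len(label['groupBy']), len(pred['groupBy']), cnt

gold = {'groupBy': [(0, '__students.dept__', None)]}
pred = {'groupBy': [(0, '__teachers.dept__', None)]}
print(eval_group(pred, gold))   # (1, 1, 1)
```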

### HAVING Clause Evaluation

In [16]:
def eval_having(pred, label):
    """
    Evaluate GROUP BY with HAVING clause.
    
    Args:
        pred (dict): Predicted SQL
        label (dict): Gold SQL
    
    Returns:
        tuple: (label_total, pred_total, cnt)
    
    Logic:
        - Checks if both queries have GROUP BY
        - If yes, checks if grouping columns AND having conditions match exactly
        - Returns binary (0 or 1)
    """
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt

### ORDER BY Evaluation

In [17]:
def eval_order(pred, label):
    """
    Evaluate ORDER BY clause (also checks LIMIT consistency).
    
    Args:
        pred (dict): Predicted SQL
        label (dict): Gold SQL
    
    Returns:
        tuple: (label_total, pred_total, cnt)
    
    Logic:
        - Checks exact match of orderBy structure
        - Also validates LIMIT is present/absent in both queries
        - Returns binary (0 or 1)
    """
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or 
             (pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt
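
The LIMIT check only tests whether LIMIT is present. The sketch below restates the condition more compactly, with the same behaviour: a different LIMIT value still matches, while a missing LIMIT or a flipped sort direction does not.

```python
def eval_order(pred, label):
    pred_total = 1 if len(pred['orderBy']) > 0 else 0
    label_total = 1 if len(label['orderBy']) > 0 else 0
    # LIMIT must be present in both or absent in both; its value is not compared here.
    same_limit_presence = (pred['limit'] is None) == (label['limit'] is None)
    cnt = int(label_total == 1 and pred['orderBy'] == label['orderBy'] and same_limit_presence)
    return label_total, pred_total, cnt

age = (0, (0, '__students.age__', False), None)
gold = {'orderBy': ('desc', [age]), 'limit': 1}
print(eval_order({'orderBy': ('desc', [age]), 'limit': 5}, gold))     # (1, 1, 1)
print(eval_order({'orderBy': ('desc', [age]), 'limit': None}, gold))  # (1, 1, 0)
print(eval_order({'orderBy': ('asc', [age]), 'limit': 1}, gold))      # (1, 1, 0)
```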

### AND/OR Evaluation

In [18]:
def eval_and_or(pred, label):
    """
    Evaluate AND/OR logical operators in WHERE clause.
    
    Args:
        pred (dict): Predicted SQL
        label (dict): Gold SQL
    
    Returns:
        tuple: (label_total, pred_total, cnt)
    
    Logic:
        - Extracts all 'and'/'or' operators from conditions
        - Converts to sets and compares
        - If sets match exactly → (1, 1, 1)
        - Otherwise → (count_label, count_pred, 0)
    """
    pred_ao = pred['where'][1::2]
    label_ao = label['where'][1::2]
    pred_ao = set(pred_ao)
    label_ao = set(label_ao)

    if pred_ao == label_ao:
        return 1, 1, 1
    return len(label_ao), len(pred_ao), 0  # (label_total, pred_total, cnt), as documented
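
Because the connectors are compared as sets, their count is ignored. Three AND-joined conditions and two AND-joined conditions both reduce to `{'and'}`. This sketch returns totals in the documented `(label_total, pred_total, cnt)` order. The condition unit `c` is a placeholder.

```python
def eval_and_or(pred, label):
    pred_ao = set(pred['where'][1::2])
    label_ao = set(label['where'][1::2])
    if pred_ao == label_ao:
        return 1, 1, 1
    return len(label_ao), len(pred_ao), 0

c = (False, 2, (0, (0, '__t.a__', None), None), None, None)   # placeholder cond_unit
gold = {'where': [c, 'and', c, 'and', c]}
print(eval_and_or({'where': [c, 'and', c]}, gold))   # (1, 1, 1): both sets are {'and'}
print(eval_and_or({'where': [c, 'or', c]}, gold))    # (1, 1, 0)
```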

### Nested Query Evaluation

In [19]:
def eval_nested(pred, label):
    """
    Evaluate a single nested SQL query recursively.
    
    Args:
        pred (dict or None): Predicted nested SQL
        label (dict or None): Gold nested SQL
    
    Returns:
        tuple: (label_total, pred_total, cnt)
    
    Logic:
        - If both are present, recursively checks for an exact match via Evaluator.eval_exact_match
        - cnt is 1 only if the nested queries match exactly
    """
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt

In [20]:
def eval_IUEN(pred, label):
    """
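    Evaluate the nested queries attached by INTERSECT, UNION, and EXCEPT.
    
    Args:
        pred (dict): Predicted SQL
        label (dict): Gold SQL
    
    Returns:
        tuple: (label_total, pred_total, cnt)
    
    Logic:
        Runs eval_nested on each of the three set-operation branches and sums the results.
        Subqueries inside WHERE/HAVING conditions are not scored here; they are compared
        as part of the condition units (e.g., in eval_where and eval_having).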
    """
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt

### Keywords Evaluation

In [21]:
def eval_keywords(pred, label):
    """
    Evaluate which SQL keywords are used correctly.
    
    Args:
        pred (dict): Predicted SQL
        label (dict): Gold SQL
    
    Returns:
        tuple: (label_total, pred_total, cnt)
    
    Logic:
        Compares sets of keywords extracted from both queries
    """
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt

<a id='hardness'></a>
## 4. Hardness & Complexity Functions

In [22]:
def get_nestedSQL(sql):
    """
    Extract all nested/sub-queries from a SQL structure.
    
    Args:
        sql (dict): SQL structure
    
    Returns:
        list: List of nested SQL dictionaries
    
    Logic:
        - Checks condition values in FROM, WHERE, HAVING for dict type (subqueries)
        - Adds any INTERSECT/UNION/EXCEPT queries
        - Returns list of all nested queries found
    """
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested
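
This sketch shows `get_nestedSQL` on a query with one `IN (...)` subquery and one UNION branch. It returns both nested dicts, condition subqueries first. `empty_sql` is a hypothetical helper that creates only the keys this function reads.

```python
def get_nestedSQL(sql):
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    for key in ('intersect', 'except', 'union'):   # same order as the cell above
        if sql[key] is not None:
            nested.append(sql[key])
    return nested

def empty_sql():
    # Hypothetical helper: a parsed-SQL dict with every clause empty
    return {'from': {'table_units': [], 'conds': []}, 'where': [], 'having': [],
            'intersect': None, 'except': None, 'union': None}

inner = empty_sql()   # e.g. SELECT id FROM honors (details omitted)
other = empty_sql()   # right-hand side of a UNION
outer = empty_sql()
outer['where'] = [(False, 8, (0, (0, '__students.id__', None), None), inner, None)]  # op 8 is 'in'
outer['union'] = other

print(len(get_nestedSQL(outer)))   # 2
```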

In [23]:
def get_keywords(sql):
    """
    Extract all SQL keywords used in a query.
    
    Args:
        sql (dict): SQL structure
    
    Returns:
        set: Set of keyword strings
    
    Keywords detected:
        - where, group, having, order, limit
        - except, union, intersect
        - or, not, in, like
        - asc/desc (from ORDER BY)
    """
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res

In [24]:
def count_agg(units):
    """
    Count aggregation functions in a list of units.
    
    Args:
        units (list): List of column units
    
    Returns:
        int: Count of aggregations
    """
    return len([unit for unit in units if has_agg(unit)])

In [25]:
def count_component1(sql):
    """
    Count basic SQL components (WHERE, GROUP BY, ORDER BY, LIMIT, JOIN, OR, LIKE).
    
    Args:
        sql (dict): SQL structure
    
    Returns:
        int: Component count
    
    What it counts:
        - WHERE: +1
        - GROUP BY: +1
        - ORDER BY: +1
        - LIMIT: +1
        - JOINs: +(num_tables - 1)
        - OR operators: +1 each
        - LIKE operators: +1 each
    """
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:  # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count
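
Here is a worked count for `SELECT ... FROM students JOIN dept ON students.dept_id = dept.id WHERE name LIKE ... OR city = ... ORDER BY age LIMIT 3`. The tally is WHERE, ORDER BY, LIMIT, one JOIN, one OR and one LIKE, giving 6. The function is restated compactly with the same counting rules. The table and column IDs and the helper `v` are hypothetical.

```python
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=',
             'in', 'like', 'is', 'exists')

def count_component1(sql):
    count = 0
    for key in ('where', 'groupBy', 'orderBy'):
        if len(sql[key]) > 0:
            count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:   # each extra table implies a JOIN
        count += len(sql['from']['table_units']) - 1
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += ao.count('or')
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += sum(1 for cu in cond_units if cu[1] == WHERE_OPS.index('like'))
    return count

def v(col_id):
    return (0, (0, col_id, False), None)   # val_unit for a plain column

sql = {
    'from': {'table_units': [('table_unit', '__students__'), ('table_unit', '__dept__')],
             'conds': [(False, 2, v('__students.dept_id__'), (0, '__dept.id__', False), None)]},
    'where': [(False, 9, v('__students.name__'), None, None), 'or',
              (False, 2, v('__students.city__'), None, None)],
    'groupBy': [], 'having': [],
    'orderBy': ('asc', [v('__students.age__')]),
    'limit': 3,
}
print(count_component1(sql))   # 6
```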

In [26]:
def count_component2(sql):
    """
    Count nested queries (INTERSECT, UNION, EXCEPT, subqueries).
    
    Args:
        sql (dict): SQL structure
    
    Returns:
        int: Count of nested queries
    """
    nested = get_nestedSQL(sql)
    return len(nested)

In [27]:
def count_others(sql):
    """
    Count advanced features (multiple aggregations, columns, conditions).
    
    Args:
        sql (dict): SQL structure
    
    Returns:
        int: Count of advanced features
    
    What it counts:
        - Multiple aggregations (>1): +1
        - Multiple SELECT columns (>1): +1
        - Multiple WHERE conditions (>1): +1
        - Multiple GROUP BY columns (>1): +1
    """
    count = 0
    # number of aggregation
    agg_count = count_agg(sql['select'][1])
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                            [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count

### Evaluator Class

In [28]:
class Evaluator:
    """
    Main evaluation class that orchestrates SQL query comparison.
    """
    
    def __init__(self):
        self.partial_scores = None

    def eval_hardness(self, sql):
        """
        Classify SQL query into difficulty levels: easy, medium, hard, or extra.
        
        Args:
            sql (dict): SQL structure
        
        Returns:
            str: 'easy', 'medium', 'hard', or 'extra'
        
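        Classification Logic (comp1 = count_component1, comp2 = count_component2,
        others = count_others):
            EASY:   comp1 ≤ 1 and others = 0 and comp2 = 0
            MEDIUM: comp2 = 0 and (comp1 ≤ 1 and others ≤ 2, or comp1 ≤ 2 and others < 2)
            HARD:   comp2 = 0 and (comp1 ≤ 2 and others > 2, or 2 < comp1 ≤ 3 and others ≤ 2),
                    or comp1 ≤ 1 and others = 0 and comp2 ≤ 1
            EXTRA:  everything else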
        
        Examples:
            EASY: SELECT * FROM students WHERE age > 18
            MEDIUM: SELECT name, age FROM students WHERE city = 'NYC' ORDER BY age
            HARD: SELECT dept, COUNT(*) FROM employees WHERE age > 25 GROUP BY dept ORDER BY COUNT(*)
            EXTRA: Complex queries with multiple nested queries or set operations
        """
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    def eval_exact_match(self, pred, label):
        """
        Check if predicted and gold SQL queries match exactly.
        
        Args:
            pred (dict): Predicted SQL structure
            label (dict): Gold SQL structure
        
        Returns:
            int: 1 if exact match, 0 otherwise
        
        Logic:
            - Evaluates partial match for all components
            - Returns 1 only if ALL component F1 scores = 1
            - Also checks that FROM tables match exactly
        """
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores

        for _, score in partial_scores.items():
            if score['f1'] != 1:
                return 0
        if len(label['from']['table_units']) > 0:
            label_tables = sorted(label['from']['table_units'])
            pred_tables = sorted(pred['from']['table_units'])
            return int(label_tables == pred_tables)
        return 1

    def eval_partial_match(self, pred, label):
        """
        Evaluate partial matching for all SQL components.
        
        Args:
            pred (dict): Predicted SQL
            label (dict): Gold SQL
        
        Returns:
            dict: Scores for each component type
        
        Component Types:
            - select, select(no AGG)
            - where, where(no OP)
            - group(no Having), group
            - order
            - and/or
            - IUEN (intersect/union/except/nested)
            - keywords
        
        Each component returns:
            {'acc': accuracy, 'rec': recall, 'f1': f1_score, 
             'label_total': gold_count, 'pred_total': pred_count}
        """
        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        return res
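
The difficulty thresholds are easier to probe when separated from the counting. `classify_hardness` is a hypothetical standalone copy of the branches in `Evaluator.eval_hardness` that takes the three counts directly. Some combinations fall through in a non-obvious way: for example, comp1 = 2 with others = 2 lands in "extra".

```python
def classify_hardness(comp1, comp2, others):
    """Hypothetical standalone copy of the thresholds in Evaluator.eval_hardness."""
    if comp1 <= 1 and others == 0 and comp2 == 0:
        return "easy"
    elif (others <= 2 and comp1 <= 1 and comp2 == 0) or \
            (comp1 <= 2 and others < 2 and comp2 == 0):
        return "medium"
    elif (others > 2 and comp1 <= 2 and comp2 == 0) or \
            (2 < comp1 <= 3 and others <= 2 and comp2 == 0) or \
            (comp1 <= 1 and others == 0 and comp2 <= 1):
        return "hard"
    return "extra"

print(classify_hardness(1, 0, 0))   # easy
print(classify_hardness(2, 0, 1))   # medium
print(classify_hardness(3, 0, 2))   # hard
print(classify_hardness(0, 1, 0))   # hard (a single nested query, nothing else)
print(classify_hardness(2, 0, 2))   # extra
print(classify_hardness(1, 1, 1))   # extra
```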

<a id='rebuilding'></a>
## 5. SQL Rebuilding Functions

These functions normalize SQL structures for fair comparison by:
1. Removing literal values (when DISABLE_VALUE = True)
2. Normalizing column references using foreign key mappings

### Value Rebuilding (Remove Literals)

In [29]:
def rebuild_cond_unit_val(cond_unit):
    """
    Remove literal values from condition units (keep only subqueries).
    
    Args:
        cond_unit (tuple): (not_op, op_id, val_unit, val1, val2)
    
    Returns:
        tuple: Condition unit with values set to None (unless they're subqueries)
    
    Logic:
        If DISABLE_VALUE is True, replace literal values with None 
        while preserving subquery dictionaries
    """
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2

In [30]:
def rebuild_condition_val(condition):
    """
    Rebuild all condition units in a condition list.
    
    Args:
        condition (list): List alternating between condition units and 'and'/'or'
    
    Returns:
        list: Rebuilt condition list
    """
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res

In [31]:
def rebuild_sql_val(sql):
    """
    Recursively remove all literal values from entire SQL structure.
    
    Args:
        sql (dict): SQL structure
    
    Returns:
        dict: SQL structure with values removed
    
    Logic:
        Applies value rebuilding to all clauses and recursively to nested queries
    """
    if sql is None or not DISABLE_VALUE:
        return sql

    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    sql['intersect'] = rebuild_sql_val(sql['intersect'])
    sql['except'] = rebuild_sql_val(sql['except'])
    sql['union'] = rebuild_sql_val(sql['union'])

    return sql
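
This sketch shows the value-stripping pass on `WHERE age BETWEEN 18 AND 25 AND id IN (SELECT id FROM honors WHERE gpa > 3.5)`. The outer literals become `None`. The subquery is kept, and its own literal is stripped recursively. The functions are restated compactly with the same behaviour. `empty_sql`, `v` and the column IDs are hypothetical.

```python
DISABLE_VALUE = True

def rebuild_cond_unit_val(cond_unit):
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit
    not_op, op_id, val_unit, val1, val2 = cond_unit
    # Literals become None; nested SQL dicts are kept and cleaned recursively.
    val1 = rebuild_sql_val(val1) if type(val1) is dict else None
    val2 = rebuild_sql_val(val2) if type(val2) is dict else None
    return not_op, op_id, val_unit, val1, val2

def rebuild_condition_val(condition):
    if condition is None or not DISABLE_VALUE:
        return condition
    # Even positions are cond_units, odd positions are 'and'/'or'.
    return [rebuild_cond_unit_val(it) if idx % 2 == 0 else it
            for idx, it in enumerate(condition)]

def rebuild_sql_val(sql):
    if sql is None or not DISABLE_VALUE:
        return sql
    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    for key in ('intersect', 'except', 'union'):
        sql[key] = rebuild_sql_val(sql[key])
    return sql

def empty_sql():
    # Hypothetical helper: only the keys rebuild_sql_val touches
    return {'from': {'table_units': [], 'conds': []}, 'where': [], 'having': [],
            'intersect': None, 'except': None, 'union': None}

def v(col_id):
    return (0, (0, col_id, False), None)   # val_unit for a plain column

inner = empty_sql()
inner['where'] = [(False, 3, v('__honors.gpa__'), 3.5, None)]            # gpa > 3.5
outer = empty_sql()
outer['where'] = [(False, 1, v('__students.age__'), 18, 25), 'and',      # BETWEEN 18 AND 25
                  (False, 8, v('__students.id__'), inner, None)]         # id IN (...)

rebuild_sql_val(outer)
print(outer['where'][0][3:])           # (None, None)
print(outer['where'][2][3] is inner)   # True: subquery kept
print(inner['where'][0][3])            # None: its literal was removed too
```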

### Column Rebuilding (Foreign Key Normalization)

In [32]:
def build_valid_col_units(table_units, schema):
    """
    Get list of valid column references based on tables in FROM clause.
    
    Args:
        table_units (list): List of table units from FROM clause
        schema (Schema): Database schema object
    
    Returns:
        list: Valid column ID strings (e.g., '__students.name__')
    """
    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
    # '__students__' -> '__students', the prefix shared by '__students.<column>__' IDs
    prefixes = [col_id[:-2] for col_id in col_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if '.' in value and value[:value.index('.')] in prefixes:
            valid_col_units.append(value)
    return valid_col_units
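
The schema object is only touched through its `idMap` values, so a small stub is enough to show which columns survive the FROM-clause filter. `StubSchema` and the ID map contents are hypothetical. The real `Schema` class is not shown in this section.

```python
TABLE_TYPE = {'sql': "sql", 'table_unit': "table_unit"}

class StubSchema:
    """Hypothetical stand-in exposing only the idMap attribute read below."""
    def __init__(self, id_map):
        self.idMap = id_map

def build_valid_col_units(table_units, schema):
    col_ids = [tu[1] for tu in table_units if tu[0] == TABLE_TYPE['table_unit']]
    prefixes = [col_id[:-2] for col_id in col_ids]   # '__students__' -> '__students'
    return [value for value in schema.idMap.values()
            if '.' in value and value[:value.index('.')] in prefixes]

schema = StubSchema({'*': '__all__',
                     'students': '__students__',
                     'students.name': '__students.name__',
                     'students.dept_id': '__students.dept_id__',
                     'departments.id': '__departments.id__'})
print(build_valid_col_units([('table_unit', '__students__')], schema))
# ['__students.name__', '__students.dept_id__']
```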

In [33]:
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    """
    Normalize column references using foreign key mappings.
    
    Args:
        valid_col_units (list): Valid columns for this query
        col_unit (tuple): (agg_id, col_id, distinct)
        kmap (dict): Foreign key mapping dictionary
    
    Returns:
        tuple: Normalized column unit
    
    Logic:
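        If the column appears in the foreign key map and belongs to a table in the
        FROM clause, replace it with its canonical column ID. DISTINCT is dropped
        when DISABLE_DISTINCT is True.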
    """
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct
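
This sketch shows both halves of the condition. A foreign-key column is rewritten only when it is also valid for the current FROM clause. DISTINCT is dropped either way. The `kmap` here is hand-written and hypothetical, following the lowest-index rule of `build_foreign_key_map`.

```python
DISABLE_DISTINCT = True

def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    if col_unit is None:
        return col_unit
    agg_id, col_id, distinct = col_unit
    # Rewrite only columns that are foreign keys AND belong to a table in FROM.
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct

# Hypothetical map: departments.id is mapped to students.dept_id
kmap = {'__students.dept_id__': '__students.dept_id__',
        '__departments.id__': '__students.dept_id__'}

from_departments = ['__departments.id__', '__departments.name__']
print(rebuild_col_unit_col(from_departments, (0, '__departments.id__', True), kmap))
# (0, '__students.dept_id__', None)
print(rebuild_col_unit_col(['__departments.name__'], (0, '__departments.id__', False), kmap))
# (0, '__departments.id__', None): not valid here, so the column is left unchanged
```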

In [34]:
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    """
    Normalize column references in value units.
    Value units may have operations like col1 + col2.
    """
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2

In [35]:
def rebuild_sql_col(valid_col_units, sql, kmap):
    """
    Recursively normalize all column references in entire SQL structure.
    
    Args:
        valid_col_units (list): Valid columns for this query
        sql (dict): SQL structure
        kmap (dict): Foreign key mapping
    
    Returns:
        dict: Normalized SQL structure
    """
    if sql is None:
        return sql

    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
    sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)

    return sql

<a id='foreign-keys'></a>
## 6. Foreign Key Mapping Functions

In [36]:
def build_foreign_key_map(entry):
    """
    Build foreign key mapping for a database schema.
    
    Args:
        entry (dict): Schema entry with tables, columns, and foreign keys
    
    Returns:
        dict: Foreign key mapping {column_name: canonical_column_name}
    
    Logic:
        Groups foreign key columns together and maps them all to the 
        canonical (lowest index) column in the group.
    
    Example:
        If students.dept_id references departments.id:
        - Both columns get mapped to whichever has lower index
        - This ensures queries using either column are treated equivalently
    """
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map
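`keyset_in_list` returns the first group that contains either key, and it never merges groups. As a result, a chain of foreign keys can end up split. Take pairs `(1, 2)`, `(3, 4)`, then `(2, 3)`. Column 3 lands in the first group but also stays in the second. When `foreign_key_map` is filled set by set, the second group overwrites column 3's entry, so columns 2 and 3 get different canonical columns even though they are linked.

The sketch below reproduces that grouping next to a union-find alternative that merges chains transitively. Union-find is a swapped-in technique, not the code above, and adopting it would change results relative to this evaluator. Both function names are hypothetical.

```python
def group_first_match(foreign_keys):
    """Same grouping rule as keyset_in_list: add each pair to the first set containing either key."""
    groups = []
    for k1, k2 in foreign_keys:
        target = next((s for s in groups if k1 in s or k2 in s), None)
        if target is None:
            target = set()
            groups.append(target)
        target.update((k1, k2))
    return groups


def group_union_find(foreign_keys):
    """Alternative (not the Spider code): union-find merges every chain of linked columns."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for k1, k2 in foreign_keys:
        r1, r2 = find(k1), find(k2)
        if r1 != r2:
            # Keep the lower column index as the root ("lowest index is canonical").
            parent[max(r1, r2)] = min(r1, r2)

    groups = {}
    for k in list(parent):
        groups.setdefault(find(k), set()).add(k)
    return list(groups.values())


# Hypothetical foreign keys as column-index pairs: 1-2, 3-4, then 2-3 links the two chains.
fks = [(1, 2), (3, 4), (2, 3)]
print([sorted(g) for g in group_first_match(fks)])  # [[1, 2, 3], [3, 4]]
print([sorted(g) for g in group_union_find(fks)])   # [[1, 2, 3, 4]]
```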

In [37]:
def build_foreign_key_map_from_json(table):
    """
    Build foreign key mappings for all databases from schema JSON file.
    
    Args:
        table (str): Path to tables.json schema file
    
    Returns:
        dict: {db_id: foreign_key_map}
    
    Usage:
        kmaps = build_foreign_key_map_from_json('tables.json')
        db_kmap = kmaps['student_db']
    """
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry['db_id']] = build_foreign_key_map(entry)
    return tables
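`build_foreign_key_map` reads four fields from each schema entry:

- `db_id`
- `table_names_original`
- `column_names_original`, a list of `[table_index, column_name]` pairs where a table index of `-1` marks the `*` column
- `foreign_keys`, a list of pairs of indices into `column_names_original`

The sketch below writes a hypothetical single-database `tables.json`, reads it back the way `build_foreign_key_map_from_json` does, and prints the `__table.column__` names that the foreign-key indices refer to. The database, tables, and columns are invented for illustration.

```python
import json
import os
import tempfile

# Hypothetical minimal tables.json entry containing the fields build_foreign_key_map reads.
entry = {
    "db_id": "toy_school",
    "table_names_original": ["Departments", "Students"],
    "column_names_original": [
        [-1, "*"],       # index 0: the special "*" column (no table)
        [0, "Id"],       # index 1: Departments.Id
        [1, "Id"],       # index 2: Students.Id
        [1, "Dept_Id"],  # index 3: Students.Dept_Id
    ],
    "foreign_keys": [[3, 1]],  # Students.Dept_Id -> Departments.Id, as column indices
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tables.json")
    with open(path, "w") as f:
        json.dump([entry], f)  # the file holds a JSON list of schema entries
    with open(path) as f:
        loaded = {e["db_id"]: e for e in json.load(f)}

e = loaded["toy_school"]
# Same naming rule as build_foreign_key_map: lower-cased "__table.column__", or "__all__" for "*".
names = ["__all__" if t < 0 else f"__{e['table_names_original'][t].lower()}.{c.lower()}__"
         for t, c in e["column_names_original"]]
for a, b in e["foreign_keys"]:
    print(names[a], "->", names[b])  # __students.dept_id__ -> __departments.id__
```

For this entry, the group `{1, 3}` has lowest index 1. `build_foreign_key_map` would therefore map both `__students.dept_id__` and `__departments.id__` to `__departments.id__`.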

<a id='main-eval'></a>
## 7. Main Evaluation Functions

In [38]:
def eval_exec_match(db, p_str, g_str, pred, gold):
    """
    Execute both queries and compare results.
    
    Args:
        db (str): Path to SQLite database file
        p_str (str): Predicted SQL string
        g_str (str): Gold SQL string
        pred (dict): Predicted SQL structure
        gold (dict): Gold SQL structure
    
    Returns:
        bool: True if results match, False otherwise
    
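    Logic:
        - Executes both queries against the database
        - Keys each result column by its SELECT expression (val_unit), so
          listing the same columns in a different order still matches
        - Compares the per-column value lists directly, so row order matters:
          the same rows returned in a different order do not match
    """
    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(p_str)
            p_res = cursor.fetchall()
        except Exception:
            # A prediction that fails to execute counts as a mismatch
            return False

        cursor.execute(g_str)
        q_res = cursor.fetchall()
    finally:
        # Always release the connection, even on early return or error
        conn.close()

    def res_map(res, val_units):
        # Map each val_unit to the values in its result column. Plain columns are
        # keyed by their column unit; arithmetic expressions by
        # (unit_op, col_unit1, col_unit2).
        rmap = {}
        for idx, val_unit in enumerate(val_units):
            if not val_unit[2]:
                key = tuple(val_unit[1])
            else:
                key = (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
            rmap[key] = [r[idx] for r in res]
        return rmap

    # Each SELECT item is (agg_id, val_unit); keep only the val_units
    p_val_units = [unit[1] for unit in pred['select'][1]]
    q_val_units = [unit[1] for unit in gold['select'][1]]
    return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)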
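`res_map` keys each result column by its parsed `val_unit`. A prediction that lists the same SELECT expressions in a different order therefore still matches. The per-column lists, however, are compared element by element, so row order matters.

The sketch below reproduces both effects with an in-memory SQLite table. Keying columns by `cursor.description` names is a hypothetical simplification; the real code keys by parsed `val_unit`s, not by names.

```python
import sqlite3


def column_map(cursor, sql):
    """Simplified stand-in for res_map: key each result column by its name
    (the real code keys by the parsed val_unit instead)."""
    cursor.execute(sql)
    rows = cursor.fetchall()
    names = [d[0] for d in cursor.description]
    return {name: [row[i] for row in rows] for i, name in enumerate(names)}


conn = sqlite3.connect(":memory:")
try:
    cur = conn.cursor()
    cur.execute("CREATE TABLE t (a INTEGER, b TEXT)")
    cur.executemany("INSERT INTO t VALUES (?, ?)", [(1, "x"), (2, "y")])

    gold = column_map(cur, "SELECT a, b FROM t ORDER BY a")
    swapped_cols = column_map(cur, "SELECT b, a FROM t ORDER BY a")
    reversed_rows = column_map(cur, "SELECT a, b FROM t ORDER BY a DESC")

    print(gold == swapped_cols)   # True: column order is ignored
    print(gold == reversed_rows)  # False: row order still matters
finally:
    conn.close()
```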

In [39]:
def evaluate(gold, predict, db_dir, etype, kmaps):
    """
    Main evaluation function that processes all query pairs.
    
    Args:
        gold (str): Path to gold SQL file (tab-separated: sql\tdb_id)
        predict (str): Path to predicted SQL file (one query per line, paired with the gold file by line position)
        db_dir (str): Directory containing database files
        etype (str): Evaluation type - 'all', 'exec', or 'match'
        kmaps (dict): Foreign key mappings for all databases
    
    Process:
        1. Load gold and predicted queries
        2. For each query pair:
           - Parse both into SQL structures
           - Determine hardness level from the gold query
           - Normalize using rebuilding functions
           - Evaluate exact match and/or execution accuracy
           - Accumulate partial scores
        3. Aggregate scores across difficulty levels
        4. Print comprehensive results
    
    Output:
        Prints detailed evaluation tables to console
    """
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    
    evaluator = Evaluator()

    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    entries = []
    scores = {}

    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}

    eval_err_num = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        try:
            p_sql = get_sql(schema, p_str)
        except Exception:
            # Unparseable prediction: substitute an empty SQL structure so the
            # structural comparison can still run
            p_sql = {
                "except": None,
                "from": {"conds": [], "table_units": []},
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [False, []],
                "union": None,
                "where": []
            }
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))

        # Rebuild SQL for value and column evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1.0
                scores['all']['exec'] += 1.0

        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                print("{} pred: {}".format(hardness,p_str))
                print("{} gold: {}".format(hardness,g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                ps = partial_scores[type_]
                # Accumulate into both the query's difficulty level and the 'all' bucket
                for level in (hardness, 'all'):
                    bucket = scores[level]['partial'][type_]
                    if ps['pred_total'] > 0:
                        bucket['acc'] += ps['acc']
                        bucket['acc_count'] += 1
                    if ps['label_total'] > 0:
                        bucket['rec'] += ps['rec']
                        bucket['rec_count'] += 1
                    bucket['f1'] += ps['f1']

            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })

    # Calculate averages
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                bucket = scores[level]['partial'][type_]
                # acc is averaged over queries whose prediction has this component,
                # rec over queries whose gold SQL has it
                bucket['acc'] = bucket['acc'] / bucket['acc_count'] if bucket['acc_count'] else 0
                bucket['rec'] = bucket['rec'] / bucket['rec_count'] if bucket['rec_count'] else 0
                # F1 is recomputed from the averaged acc and rec, replacing the summed
                # per-query f1 values. When both are 0 it is reported as 1; this
                # convention is kept unchanged because altering it changes reported F1.
                if bucket['acc'] == 0 and bucket['rec'] == 0:
                    bucket['f1'] = 1
                else:
                    bucket['f1'] = 2.0 * bucket['acc'] * bucket['rec'] / (bucket['acc'] + bucket['rec'])

    print_scores(scores, etype)
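The final averaging step handles the three partial metrics differently:

- `acc` is averaged only over queries whose prediction contains the component (`pred_total > 0`).
- `rec` is averaged only over queries whose gold SQL contains it (`label_total > 0`).
- F1 is the harmonic mean of those two averages, not the mean of per-query F1 values.

The sketch below, `aggregate_component`, is a hypothetical helper that mirrors this logic. It runs on made-up per-query partial scores for a single component, including the "both zero means F1 = 1" convention.

```python
def aggregate_component(examples):
    """examples: per-query dicts with 'acc', 'rec', 'pred_total', 'label_total'.

    Mirrors the averaging in evaluate(): acc averages only over queries whose
    prediction contains the component, rec only over queries whose gold does.
    """
    acc_vals = [e["acc"] for e in examples if e["pred_total"] > 0]
    rec_vals = [e["rec"] for e in examples if e["label_total"] > 0]
    acc = sum(acc_vals) / len(acc_vals) if acc_vals else 0
    rec = sum(rec_vals) / len(rec_vals) if rec_vals else 0
    # Same convention as evaluate(): both averages zero -> F1 reported as 1.
    f1 = 1 if acc == 0 and rec == 0 else 2.0 * acc * rec / (acc + rec)
    return acc, rec, f1


# Hypothetical per-query partial scores for one component (e.g. 'where').
examples = [
    {"acc": 1.0, "rec": 1.0, "pred_total": 1, "label_total": 1},  # correct WHERE
    {"acc": 0.0, "rec": 0.0, "pred_total": 1, "label_total": 0},  # spurious WHERE in prediction
    {"acc": 0.0, "rec": 0.0, "pred_total": 0, "label_total": 1},  # gold WHERE missed
]
print(aggregate_component(examples))  # (0.5, 0.5, 0.5)
```

The spurious WHERE lowers only `acc`, and the missed WHERE lowers only `rec`. This is why the two metrics can diverge for the same component.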

<a id='scoring'></a>
## 8. Scoring & Display Functions

In [None]:
def print_scores(scores, etype):
    """
    Print formatted evaluation results.
    
    Args:
        scores (dict): Evaluation scores by difficulty level
        etype (str): Evaluation type - 'all', 'exec', or 'match'
    
    Output Format:
        - Query counts by difficulty
        - Execution accuracy (if etype includes 'exec')
        - Exact match accuracy (if etype includes 'match')
        - Partial matching accuracy/recall/F1 for all components
    
    All metrics shown for: easy, medium, hard, extra, and all queries
    """
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']

    print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
    counts = [scores[level]['count'] for level in levels]
    print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

    if etype in ["all", "exec"]:
        print('=====================   EXECUTION ACCURACY     =====================')
        this_scores = [scores[level]['exec'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))

    if etype in ["all", "match"]:
        print('\n====================== EXACT MATCHING ACCURACY =====================')
        exact_scores = [scores[level]['exact'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
        print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING RECALL ----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING F1 --------------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
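Every row in `print_scores` uses fixed 20-character columns. `{:20}` pads the label, since strings are left-aligned by default, and `{:<20.3f}` left-aligns each score with three decimals. The hypothetical helper `format_row` below builds the same row for any number of levels. The values are placeholders, not evaluation results.

```python
LEVELS = ['easy', 'medium', 'hard', 'extra', 'all']


def format_row(label, values):
    """Build one table row like print_scores: a 20-char label column, then one
    left-aligned, 3-decimal, 20-char column per difficulty level."""
    return "{:20} ".format(label) + " ".join("{:<20.3f}".format(v) for v in values)


header = "{:20} ".format("") + " ".join("{:20}".format(level) for level in LEVELS)
row = format_row("exact match", [0.5, 0.25, 0.0, 1.0, 0.4375])  # placeholder values
print(header)
print(row)
```

Building the row from a list avoids hard-coding five `{:<20.3f}` fields. The table would then still line up if the list of levels changed.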

## Summary

This evaluation system provides:

1. **Comprehensive Comparison**: Evaluates predicted vs. gold SQL on both structure (exact match) and execution results
2. **Granular Metrics**: Breaks down performance by SQL component and difficulty level
3. **Normalization**: Optionally ignores literal values and DISTINCT, and treats foreign-key-linked columns as equivalent
4. **Difficulty Classification**: Assigns each query a level (easy, medium, hard, extra) based on its gold SQL
5. **Detailed Output**: Reports accuracy, recall, and F1 for 10 SQL components. Accuracy is averaged over queries whose prediction contains the component, and recall over queries whose gold SQL contains it

Together, these metrics show which query types and complexity levels a text-to-SQL model handles well and where it fails.