# Completer Test

This notebook assesses the performance of the QPL completer in isolation.

## Generate Completions

In [1]:
from transformers.models.auto.tokenization_auto import AutoTokenizer
from peft import AutoPeftModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm

from src.prompters import PrompterRegistry
from src.utils.generation import to_model_prompt, generate_batch
import src.utils.paths as p

# Generation settings
BATCH_SIZE = 6
MAX_NEW_TOKENS = 128

# Fine-tuned adapter location: the training run directory plus the checkpoint to evaluate
MODEL_CKPT = "output/models/855d8cb9_gemma-3-4b-it-qpl-composer-ds_train_batch_size=1_gradient_accumulation_steps=8_learning_rate=0.0002_num_train_epochs=4_gradient_checkpointing=True_logging_steps=0.00125_save_steps=0.0625_random_seed=1_lora=True_r=16_alpha=32_dropout=0.05/"
MODEL_PATH = f"{MODEL_CKPT}/checkpoint-5316"

# Dataset id; also the key used to look up the matching prompter in PrompterRegistry
DATASET_ID = "d4nieldev/qpl-completer-ds"

# Load the PEFT model on the GPU in evaluation mode, together with its tokenizer
model = (
    AutoPeftModelForCausalLM
    .from_pretrained(MODEL_PATH, attn_implementation='eager')
    .to('cuda')
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


In [1]:
# Connect to SQL Server
import getpass
import os

import pyodbc
from src.inference.qpl.qpl_to_cte import flat_qpl_to_cte
from src.inference.qpl.validate_qpl import execute_sql, same_rs

# Credentials must not live in the notebook: the password is taken from the
# QPL_SQL_PASSWORD environment variable, falling back to an interactive prompt.
# The user name can be overridden with QPL_SQL_USER.
sql_user = os.environ.get('QPL_SQL_USER', 'iloveqpl')
sql_password = os.environ.get('QPL_SQL_PASSWORD') or getpass.getpass(
    f"SQL Server password for {sql_user}: "
)

# ODBC connection strings treat characters such as ';' as delimiters. Wrapping the
# value in braces makes it literal; a '}' inside the value is escaped by doubling it.
escaped_password = sql_password.replace('}', '}}')

connection_string = (
    'Driver={ODBC Driver 18 for SQL Server};'
    'Server=tcp:spider-sql.database.windows.net,1433;'
    'Database=test;'
    f'Uid={sql_user};'
    f'Pwd={{{escaped_password}}};'
    'Encrypt=yes;'
    'TrustServerCertificate=no;'
    'Connection Timeout=30;'
)

# autocommit=True: every statement is committed immediately, no explicit transaction handling
conn = pyodbc.connect(connection_string, autocommit=True)
cursor = conn.cursor()

## Evaluate

### Line By Line

In [None]:
# Load the validation split and turn every example into a model-ready prompt
test_dataset = list(load_dataset(DATASET_ID, split="validation"))
prompter = PrompterRegistry.get(DATASET_ID)(with_assistant=False)
chat_templates = [prompter.to_chat_template(example) for example in test_dataset]
prompts = [to_model_prompt(tokenizer, chat_template) for chat_template in chat_templates]

In [None]:
# Complete QPL: run batched generation over all prompts with sampling disabled
completion_progress = tqdm(total=len(prompts), desc="Completing QPL")
outputs = generate_batch(
    model=model,
    tokenizer=tokenizer,
    model_prompts=prompts,
    batch_size=BATCH_SIZE,
    max_new_tokens=MAX_NEW_TOKENS,
    progress_bar=completion_progress,
    do_sample=False,
)

In [None]:
# optional - save outputs
import json

# Persist the raw completions so the evaluation can be re-run without regenerating
with open('line_by_line_outputs.json', 'w') as out_file:
    out_file.write(json.dumps(outputs, indent=2))

In [None]:
# optional - load outputs
# Imported here as well: this cell is meant to be usable on a fresh kernel
# without first running the optional "save outputs" cell.
import json

with open('line_by_line_outputs.json', 'r') as f:
    outputs = json.load(f)

In [None]:
# Calculate execution accuracy and error rate
#
# For every example the gold program and the predicted program are converted to a
# CTE query, both are executed, and the result sets are compared with same_rs.
# Each example ends in exactly one of: hit, miss, gold error, prediction error
# (a failed prediction conversion may additionally coincide with a gold error).
execution_accuracy = 0
gold_errs = 0
pred_errs = 0
for out, example in tqdm(zip(outputs, test_dataset), desc="Evaluating", total=len(outputs)):
    # Gold program: given prefix + gold line.
    # Predicted program: same prefix + the gold line's head up to and including
    # " = " (hence the +3) + the operator + the model completion.
    gold = example['prefix_qpl'] + "\n" + example['qpl_line']
    pred = example['prefix_qpl'] + "\n" + example['qpl_line'][:example['qpl_line'].index(' = ')+3] + example['op'] + ' ' + out

    # Cut every line at its first "--" (trailing comment)
    flat_gold = [line[:line.index('--')] if '--' in line else line for line in gold.split('\n')]
    flat_pred = [line[:line.index('--')] if '--' in line else line for line in pred.split('\n')]

    # Drop blank lines; predicted lines containing a backtick are dropped as well
    flat_gold = [l for l in flat_gold if l.strip()]
    flat_pred = [l for l in flat_pred if l.strip() and '`' not in l]

    gold_cte = flat_qpl_to_cte(flat_gold, example['db_id'])

    # Track conversion success explicitly so a failure can never fall through to
    # executing the previous example's pred_cte.
    pred_converted = False
    try:
        pred_cte = flat_qpl_to_cte(flat_pred, example['db_id'])
        pred_converted = True
    except Exception as e:
        print(f"Error converting prediction to CTE\n{pred}")
        print(f"Error: {e}")
        print('-'*20)
        pred_errs += 1

    # The gold query is executed even when the prediction failed to convert, so the
    # gold error rate still covers every example.
    try:
        grs = execute_sql(cursor, gold_cte)
    except Exception as e:
        print(f"Error executing gold QPL\n{gold}")
        print(f"Error: {e}")
        print('-'*20)
        gold_errs += 1
        continue

    if not pred_converted:
        # Already counted as a prediction error above
        continue

    try:
        prs = execute_sql(cursor, pred_cte)
        same = same_rs(grs, prs, flat_pred)
    except Exception as e:
        print(f"Error executing prediction QPL\n{pred}")
        print(f"Error: {e}")
        print('-'*20)
        pred_errs += 1
        # Skip the accuracy update: `same` would otherwise be stale or undefined
        continue

    if same:
        execution_accuracy += 1

# Guard the percentages against an empty outputs list (all counters are 0 in that case)
safe_total = max(len(outputs), 1)
print(f"Execution accuracy: {execution_accuracy}/{len(outputs)} ({execution_accuracy / safe_total * 100:.2f}%)")
print(f"Gold error rate: {gold_errs}/{len(outputs)} ({gold_errs / safe_total * 100:.2f}%)")
print(f"Prediction error rate: {pred_errs}/{len(outputs)} ({pred_errs / safe_total * 100:.2f}%)")

### Full Tree With Perfect Decomposition

In [2]:
# Full tree
from src.databuilders.completer.build import get_decomposer_roots
from src.utils.qpl.tree import PartialQDTree, QPLQDTree
from datasets import load_dataset
from src.inference.qpl.text_to_qpl import complete

# Load and process data
def partial_qd_to_qd(tree: PartialQDTree) -> QPLQDTree:
    """Build a QPLQDTree mirroring `tree`.

    The question, database id and operator of every node are copied. Children are
    converted recursively, stored as a tuple, and each gets its `parent` set to the
    newly created node. A node without children is returned as constructed.
    """
    converted = QPLQDTree(question=tree.question, db_id=tree.db_id, op=tree.op)
    if not tree.children:
        return converted

    converted.children = tuple(map(partial_qd_to_qd, tree.children))
    for sub_tree in converted.children:
        sub_tree.parent = converted
    return converted


# Validation splits: question decompositions and the NL→QPL examples (root questions)
decomposer_data = load_dataset("bgunlp/question_decomposer_ds", split="validation")
nl2qpl_data = load_dataset('d4nieldev/nl2qpl-ds', split='validation')
root_questions = {row['question'] for row in nl2qpl_data}

# Drop rows whose question is repeated verbatim as one of its own sub-questions
decomposer_data = [
    row for row in decomposer_data
    if row['question'] != row['sub_question_1'] and row['question'] != row['sub_question_2']
]

# Build the decomposition trees for the root questions and convert them to QPLQDTree
root_qd_trees = list(map(partial_qd_to_qd, get_decomposer_roots(decomposer_data, root_questions)))

def post_order_index_tree(tree: QPLQDTree, counter: int = 1) -> int:
    """Number the nodes of `tree` in post-order, starting from `counter`.

    Every child subtree is numbered before its parent; the number is stored in the
    node's `line_num` attribute. Returns the next unused number.
    """
    next_free = counter
    for sub_tree in tree.children:
        next_free = post_order_index_tree(sub_tree, next_free)
    tree.line_num = next_free
    return next_free + 1

# Give every node of every root tree its post-order line number
for tree in root_qd_trees:
    post_order_index_tree(tree)

# Complete QPL for each tree
# NOTE(review): the return value is discarded and later cells read the results from
# root_qd_trees, so complete() presumably fills the trees in place — confirm.
tree_prompter = PrompterRegistry.get(DATASET_ID)(with_assistant=False)
complete(
    trees=root_qd_trees,
    prompter=tree_prompter,
    model=model,
    tokenizer=tokenizer,
    batch_size=BATCH_SIZE,
    max_new_tokens=MAX_NEW_TOKENS,
)

Constructing QD trees: 100%|██████████| 3027/3027 [00:00<00:00, 158075.60it/s]
Completing QPL: 100%|██████████| 3274/3274 [35:12<00:00,  1.28node/s]

In [3]:
# optionally save outputs
import json

# Serialize every completed root tree to a JSON list of dicts
with open('full_tree_outputs_5316.json', 'w') as out_file:
    serialized_trees = [tree.to_dict() for tree in root_qd_trees]
    out_file.write(json.dumps(serialized_trees, indent=2))

In [4]:
# optional - load outputs
import json
from src.utils.qpl.tree import QPLQDTree
from datasets import load_dataset

nl2qpl_data = load_dataset('d4nieldev/nl2qpl-ds', split='validation')

# NOTE(review): the save cell writes 'full_tree_outputs_5316.json' while this cell
# reads 'full_tree_outputs_3996.json' — confirm which checkpoint is meant to be evaluated.
with open('full_tree_outputs_3996.json', 'r') as in_file:
    root_qd_trees = list(map(QPLQDTree.from_dict, json.loads(in_file.read())))

In [None]:
# Calculate execution accuracy and error rate
#
# Every predicted root tree is matched to its gold example by question text. The
# prediction is converted to a CTE query and executed, and its result set is
# compared with the result of the example's gold query.
from collections import defaultdict

from tqdm import tqdm

# Index the predicted trees by question once, instead of rescanning root_qd_trees
# for every example. The relative order of trees sharing a question is preserved.
trees_by_question = defaultdict(list)
for candidate in root_qd_trees:
    trees_by_question[candidate.question].append(candidate)

gold_errs = 0
pred_errs = 0
execution_accuracy = 0
total_pred = 0

for example in tqdm(nl2qpl_data, desc="Processing examples"):
    # .get() avoids inserting empty entries for questions without predictions
    trees = trees_by_question.get(example['question'], [])
    total_pred += len(trees)
    gold = example['query']
    for tree in trees:
        pred = tree.qpl

        # Cut every line at its first "--" (trailing comment), then drop blank lines
        # and lines containing a backtick
        flat_pred = [line[:line.index('--')] if '--' in line else line for line in pred.split('\n')]
        flat_pred = [l for l in flat_pred if l.strip() and '`' not in l]

        # Trees labelled 'car_11' are evaluated against 'car_1'
        # NOTE(review): presumably a db id typo in the source data — confirm
        db_id = tree.db_id if tree.db_id != 'car_11' else 'car_1'

        try:
            pred_cte = flat_qpl_to_cte(flat_pred, db_id)
        except Exception as e:
            print(f"Error converting prediction to CTE\n{pred}")
            print(f"Error: {e}")
            print('-'*20)
            pred_errs += 1
            continue

        try:
            grs = execute_sql(cursor, gold)
        except Exception as e:
            print(f"Error executing gold QPL\n{gold}")
            print(f"Error: {e}")
            print(f"Database ID: {db_id}")
            print('-'*20)
            gold_errs += 1
            continue

        try:
            prs = execute_sql(cursor, pred_cte)
            same = same_rs(grs, prs, flat_pred)
        except Exception as e:
            print(f"Error executing prediction QPL\n{pred}")
            print(f"Error: {e}")
            print('-'*20)
            pred_errs += 1
            continue

        if same:
            execution_accuracy += 1

# Guard the percentages against the case where no tree matched any example
# (all counters are 0 then, so every rate prints as 0.00%)
safe_total = max(total_pred, 1)
print(f"Execution accuracy: {execution_accuracy}/{total_pred} ({execution_accuracy / safe_total * 100:.2f}%)")
print(f"Gold errors: {gold_errs}/{total_pred} ({gold_errs / safe_total * 100:.2f}%)")
print(f"Prediction errors: {pred_errs}/{total_pred} ({pred_errs / safe_total * 100:.2f}%)")

Processing examples:   9%|▉         | 94/1034 [00:02<00:23, 40.37it/s]

Error executing prediction QPL
#1 = Scan Table [ countries ] Output [ Continent , Continent ] ; -- List the continent id of all countries.
#2 = Aggregate [ #1 ] GroupBy [ Continent ] Output [ countstar AS Count_Star , Continent ] ; -- How many countries does each continent have? List the continent id.
#3 = Scan Table [ continents ] Output [ ContId , Continent ] ; -- List the id and name of all continents.
#4 = Join [ #2 , #3 ] Predicate [ #2.Continent = #3.ContId ] Output [ #3.Continent , #3.ContId , #2.Count_Star ] ; -- How many countries does each continent have? List the continent id, continent name and the number of countries.
Error: ('42000', "[42000] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]The column 'Continent' was specified multiple times for 'Scan_1'. (8156) (SQLExecDirectW)")
--------------------
Error executing prediction QPL
#1 = Scan Table [ countries ] Output [ Continent , Continent ] ; -- List the continent id of all countries.
#2 = Aggregate [ #1 ] GroupBy

Processing examples:  14%|█▍        | 144/1034 [00:03<00:16, 55.51it/s]

Error executing prediction QPL
#1 = Scan Table [ cars_data ] Predicate [ Brand = 'Volvo' ] Output [ Id , Brand ] ; -- What is the id of the cars for all volvos?
#2 = Scan Table [ cars_data ] Output [ Edispl , Id ] ; -- What is the id and edispl of all cars?
#3 = Join [ #1 , #2 ] Predicate [ #1.Id = #2.Id ] Output [ #2.Edispl ] ; -- What is the edispl of the cars for all volvos?
#4 = Aggregate [ #3 ] Output [ AVG(Edispl) AS Avg_Edispl ] ; -- What is the average edispl for all volvos?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'Brand'. (207) (SQLExecDirectW)")
--------------------
Error executing prediction QPL
#1 = Scan Table [ cars_data ] Output [ Model ] ; -- Show the model of all cars.
#2 = Aggregate [ #1 ] GroupBy [ Model ] Output [ countstar AS Count_Star , Model ] ; -- Show the different models and the number of cars correspond to each.
#3 = TopSort [ #2 ] Rows [ 1 ] OrderBy [ Count_Star DESC ] Output [ Count_Star , Model ] 

Processing examples:  15%|█▌        | 156/1034 [00:03<00:17, 49.11it/s]

Error executing prediction QPL
#1 = Scan Table [ car_makers ] Output [ Id , FullName ] ; -- Find the ids and full names of all makers.
#2 = Scan Table [ model_list ] Output [ Maker , Model ] ; -- Find the names and maker ids of all models.
#3 = Scan Table [ model_list ] Output [ ModelId , Model ] ; -- Find the ids and models of all cars.
#4 = Scan Table [ cars_data ] Output [ Weight , Id ] ; -- Find the ids and weights of all cars.
#5 = Join [ #3 , #4 ] Predicate [ #3.ModelId = #4.Id ] Output [ #3.Model , #4.Weight ] ; -- Find the models and weights of all cars.
#6 = Join [ #2 , #5 ] Predicate [ #2.Maker = #5.ModelId ] Output [ #5.Weight , #2.Maker , #5.Model ] ; -- Find the models and weights and maker ids of all cars.
#7 = Join [ #1 , #6 ] Predicate [ #1.Id = #6.Maker ] Output [ #6.Model , #6.Weight , #1.FullName ] ; -- Find the models and weights and maker full names of all cars.
#8 = Filter [ #7 ] Predicate [ Weight > 3500.0 OR FullName = 'General Motors' ] Distinct [ true ] Output

Processing examples:  23%|██▎       | 238/1034 [00:05<00:25, 31.74it/s]

Error executing prediction QPL
#1 = Scan Table [ flights ] Predicate [ DestAirport = 'AHD' ] Output [ DestAirport , Airline ] ; -- What are airline ids that have flights arriving at airport 'AHD'?
#2 = Scan Table [ airlines ] Output [ Airline , uid ] ; -- What are the ids and names of all airlines?
#3 = Join [ #1 , #2 ] Predicate [ #1.Airline = #2.uid ] Output [ #2.Airline , #2.name ] ; -- What are airlines that have flights arriving at airport 'AHD'?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'name'. (207) (SQLExecDirectW)")
--------------------


Processing examples:  26%|██▌       | 264/1034 [00:06<00:24, 31.97it/s]

Error executing prediction QPL
#1 = Scan Table [ airports ] Output [ AirportCode , AirportName ] ; -- Find the codes and names of all airports.
#2 = Scan Table [ flights ] Output [ SourceAirport ] ; -- Which airports codes have departing flights?
#3 = Scan Table [ flights ] Output [ DestAirport ] ; -- Which airports codes have arriving flights?
#4 = Union [ #2 , #3 ] Output [ #2.SourceAirport ] ; -- Which airports codes have departing or arriving flights?
#5 = Except [ #1 , #4 ] Predicate [ #4.SourceAirport = #1.AirportCode ] Output [ #1.AirportName ] ; -- Which airports do not have departing or arriving flights?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'SourceAirport'. (207) (SQLExecDirectW)")
--------------------


Processing examples:  42%|████▏     | 435/1034 [00:11<00:18, 32.49it/s]

Error executing prediction QPL
#1 = Scan Table [ player ] Output [ 1 AS One ] ; -- List 1 for each player.
#2 = Aggregate [ #1 ] Output [ countstar AS Count_Star ] ; -- Find the total number of players.
Error: ('42S02', "[42S02] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid object name 'wta_1.player'. (208) (SQLExecDirectW)")
--------------------
Error executing prediction QPL
#1 = Scan Table [ player ] Output [ 1 AS One ] ; -- List 1 for each player.
#2 = Aggregate [ #1 ] Output [ countstar AS Count_Star ] ; -- How many players are there?
Error: ('42S02', "[42S02] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid object name 'wta_1.player'. (208) (SQLExecDirectW)")
--------------------


Processing examples:  43%|████▎     | 448/1034 [00:11<00:16, 35.03it/s]

Error executing prediction QPL
#1 = Scan Table [ country ] Distinct [ true ] Output [ country_code ] ; -- Find the distinct country codes of all players.
#2 = Aggregate [ #1 ] Output [ countstar AS Count_Star ] ; -- find the number of distinct country codes of all players.
Error: ('42S02', "[42S02] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid object name 'wta_1.country'. (208) (SQLExecDirectW)")
--------------------
Error executing prediction QPL
#1 = Scan Table [ country ] Distinct [ true ] Output [ country_code ] ; -- Find the distinct country codes the players come from.
#2 = Aggregate [ #1 ] Output [ countstar AS Count_Star ] ; -- How many distinct countries do players come from?
Error: ('42S02', "[42S02] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid object name 'wta_1.country'. (208) (SQLExecDirectW)")
--------------------


Processing examples:  44%|████▎     | 452/1034 [00:11<00:16, 35.41it/s]

Error executing prediction QPL
#1 = Scan Table [ match ] Predicate [ round = 'WTA Championships' ] Output [ round , winner_id ] ; -- What are the id of the players who won in tourney WTA Championships?
#2 = Scan Table [ players ] Output [ country_code , player_id , first_name ] ; -- What are the id, first name and country code of all players?
#3 = Join [ #1 , #2 ] Predicate [ #1.winner_id = #2.player_id ] Output [ #2.country_code , #2.first_name ] ; -- What are the country code and first name of the players who won in tourney WTA Championships?
#4 = Scan Table [ match ] Predicate [ tourney_name = 'Australian Open' ] Output [ tourney_name , winner_id ] ; -- What are the id of the players who won in tourney Australian Open?
#5 = Scan Table [ players ] Output [ country_code , player_id , first_name ] ; -- What are the id, first name and country code of all players?
#6 = Join [ #4 , #5 ] Predicate [ #4.winner_id = #5.player_id ] Output [ #5.country_code , #5.first_name ] ; -- What are the 

Processing examples:  45%|████▌     | 466/1034 [00:12<00:36, 15.71it/s]

Error executing prediction QPL
#1 = Scan Table [ match ] Output [ winner_name , winner_rank_points ] ; -- What is the name of the winner in all matches, and how many rank points does this player have.
#2 = Aggregate [ #1 ] GroupBy [ winner_rank_points , winner_name ] Output [ countstar AS Count_Star , winner_name , winner_rank_points ] ; -- What are the different names of winners and the corresponding number of wins in matches and and how many rank points does this player have?
#3 = TopSort [ #2 ] Rows [ 1 ] OrderBy [ Count_Star DESC ] Output [ Count_Star , winner_name , winner_rank_points ] ; -- What is the name of the winner who has won the most matches, and how many rank points does this player have?
Error: ('42S02', "[42S02] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid object name 'wta_1.match'. (208) (SQLExecDirectW)")
--------------------
Error executing prediction QPL
#1 = Scan Table [ match ] Predicate [ tourney_name = 'Australian Open' ] Output [ tourney_name 

Processing examples:  45%|████▌     | 470/1034 [00:13<01:02,  9.08it/s]

Error executing prediction QPL
#1 = Scan Table [ player ] Output [ last_name , player_id , first_name ] ; -- What are the ids and names of all players?
#2 = Scan Table [ rankings ] Output [ player_id , ranking ] ; -- What are the ids of all players, and their rankings?
#3 = Join [ #1 , #2 ] Predicate [ #1.player_id = #2.player_id ] Output [ #2.ranking , #1.first_name ] ; -- What are the first names of all players, and their rankings?
#4 = Aggregate [ #3 ] GroupBy [ first_name ] Output [ AVG(ranking) AS Avg_ranking , first_name ] ; -- What are the first names of all players, and their average rankings?
Error: ('42S02', "[42S02] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid object name 'wta_1.player'. (208) (SQLExecDirectW)")
--------------------


Processing examples:  46%|████▌     | 477/1034 [00:16<01:49,  5.10it/s]

Error executing prediction QPL
#1 = Scan Table [ country ] Output [ Country_code ] ; -- Find the codes of countries of all players.
#2 = Aggregate [ #1 ] GroupBy [ Country_code ] Output [ countstar AS Count_Star , Country_code ] ; -- Find the number of players for each country code.
#3 = Filter [ #2 ] Predicate [ Count_Star > 50 ] Output [ Country_code ] ; -- Find the codes of countries that have more than 50 players.
Error: ('42S02', "[42S02] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid object name 'wta_1.country'. (208) (SQLExecDirectW)")
--------------------


Processing examples:  48%|████▊     | 501/1034 [00:17<00:24, 21.46it/s]

Error executing prediction QPL
#1 = Scan Table [ death ] Output [ injured ] ; -- What is the number of injuries caused each time?
#2 = Aggregate [ #1 ] GroupBy [ time ] Output [ countstar AS Count_Star , AVG(injured) AS Avg_injured ] ; -- What is the average number of injuries caused each time?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'time'. (207) (SQLExecDirectW)")
--------------------
Error executing prediction QPL
#1 = Scan Table [ ship ] Output [ lost_in_battle , id ] ; -- What are the ids of all ships and the id of the battle in which they were lost?
#2 = Scan Table [ battle ] Output [ name , id ] ; -- What are the ids and names of all the battles?
#3 = Join [ #1 , #2 ] Predicate [ #1.lost_in_battle = #2.id ] Output [ #2.name , #2.id ] ; -- What are the ids and names of the battles where a ship was lost?
#4 = Scan Table [ death ] Output [ killed , caused_by_ship_id ] ; -- What is the number of people killed in all death e

Processing examples:  52%|█████▏    | 533/1034 [00:19<00:21, 23.23it/s]

Error executing prediction QPL
#1 = Scan Table [ Sections ] Output [ course_id ] ; -- What are the ids of courses of all sections?
#2 = Scan Table [ Courses ] Output [ course_name , course_id ] ; -- What are the ids and names of all courses?
#3 = Join [ #1 , #2 ] Predicate [ #1.course_id = #2.course_id ] Output [ #2.course_name , #1.course_id ] ; -- What are the ids of courses and names of courses of all sections?
#4 = Aggregate [ #3 ] GroupBy [ course_id ] Output [ countstar AS Count_Star , course_name ] ; -- What is the number of sections for each course?
#5 = Filter [ #4 ] Predicate [ Count_Star <= 2 ] Output [ course_name , course_id ] ; -- What are the names and id of courses having at most 2 sections?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'course_id'. (207) (SQLExecDirectW)")
--------------------
Error executing prediction QPL
#1 = Scan Table [ Sections ] Output [ course_id ] ; -- What are the ids of courses of all sec

Processing examples:  53%|█████▎    | 553/1034 [00:19<00:14, 32.96it/s]

Error executing prediction QPL
#1 = Scan Table [ Student_Enrolment ] Output [ course_id ] ; -- What's the id of the course of all enrollments?
#2 = Scan Table [ Courses ] Output [ course_name , course_id ] ; -- What's the ids and names of all courses?
#3 = Join [ #1 , #2 ] Predicate [ #1.course_id = #2.course_id ] Output [ #2.course_name ] ; -- What's the name of the course of all enrollments?
#4 = Aggregate [ #3 ] GroupBy [ course_name ] Output [ countstar AS Count_Star , course_name ] ; -- What's the name of the course and number of enrollments for each course?
#5 = TopSort [ #4 ] Rows [ 1 ] OrderBy [ Count_Star DESC ] Output [ course_name , Count_Star ] ; -- What's the name of the course with most number of enrollments?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'course_id'. (207) (SQLExecDirectW)")
--------------------
Error executing prediction QPL
#1 = Scan Table [ Student_Enrolment ] Output [ course_id ] ; -- What is the i

Processing examples:  54%|█████▍    | 561/1034 [00:19<00:15, 31.26it/s]

Error executing prediction QPL
#1 = Scan Table [ Students ] Output [ last_name , date_first_registered , middle_name , graduation_date , first_name ] ; -- What is the first, middle, last name and graduation date of all students?
#2 = TopSort [ #1 ] Rows [ 1 ] OrderBy [ graduation_date ASC ] Output [ last_name , date_first_registered , middle_name , graduation_date , first_name ] ; -- What is the first, middle, and last name of the earliest school graduate?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'graduation_date'. (207) (SQLExecDirectW)")
--------------------


Processing examples:  65%|██████▍   | 669/1034 [00:22<00:09, 37.31it/s]

Error executing prediction QPL
#1 = Scan Table [ poker_player ] Output [ Poker_Player_ID , Final_Table_Made ] ; -- List the ids and final tables made of all poker players.
#2 = Scan Table [ people ] Output [ People_ID , Name ] ; -- List the ids and names of all people.
#3 = Join [ #1 , #2 ] Predicate [ #1.People_ID = #2.People_ID ] Output [ #2.Name , #1.Final_Table_Made ] ; -- List the names and final tables made of all poker players.
#4 = Sort [ #3 ] OrderBy [ Final_Table_Made ASC ] Output [ Final_Table_Made , Name ] ; -- List the names of poker players ordered by the final tables made in ascending order.
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'People_ID'. (207) (SQLExecDirectW)")
--------------------


Processing examples:  68%|██████▊   | 701/1034 [00:23<00:10, 32.75it/s]

Error executing prediction QPL
#1 = Scan Table [ VOTES ] Output [ contestant_number ] ; -- What are the contestant numbers of all votes?
#2 = Aggregate [ #1 ] GroupBy [ contestant_number ] Output [ countstar AS Count_Star , contestant_number ] ; -- What is the number of votes for each contestant number?
#3 = Scan Table [ CONTESTANTS ] Output [ contestant_number , contestant_name ] ; -- What are the numbers and names of all contestants?
#4 = Join [ #2 , #3 ] Predicate [ #2.contestant_number = #3.contestant_number ] Output [ #3.contestant_name , #2.Count_Star ] ; -- What is the number of votes for each contestant?
#5 = Filter [ #4 ] Predicate [ Count_Star >= 2 ] Output [ contestant_number , contestant_name ] ; -- What are the contestant numbers and names of the contestants who had at least two votes?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'contestant_number'. (207) (SQLExecDirectW)")
--------------------


Processing examples:  72%|███████▏  | 745/1034 [00:25<00:08, 33.56it/s]

Error executing prediction QPL
#1 = Scan Table [ country ] Predicate [ LocalName = 'Afghanistan' ] Output [ Code , LocalName ] ; -- Find the code of Afghanistan country.
#2 = Scan Table [ countrylanguage ] Output [ IsOfficial , CountryCode ] ; -- Find the country code and the indication whether the language is an official language in all records of a country and language that is spoken in this country.
#3 = Filter [ #2 ] Predicate [ IsOfficial = 'Y' ] Output [ CountryCode ] ; -- Find the country code in all records of a country and language that is an official language in this country.
#4 = Join [ #1 , #3 ] Predicate [ #1.Code = #3.CountryCode ] Output [ #1.LocalName , #1.Code , #1.IsOfficial , #1.Percentage ] ; -- List 1 for each official language Afghanistan has.
#5 = Aggregate [ #4 ] Output [ countstar AS Count_Star ] ; -- How many official languages does Afghanistan have?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'IsOfficial

Processing examples:  76%|███████▌  | 785/1034 [00:26<00:08, 30.32it/s]

Error executing prediction QPL
#1 = Scan Table [ countrylanguage ] Output [ CountryCode , Language ] ; -- What are the country codes in all records of a country and the language spoken in this country?
#2 = Aggregate [ #1 ] GroupBy [ CountryCode ] Output [ CountryCode , MAX(Percentage) AS Max_Percentage , Language ] ; -- What are the different country codes in all records of a country and the language spoken in this country?
#3 = Scan Table [ countrylanguage ] Predicate [ Language = 'English' ] Output [ CountryCode , Language ] ; -- What are the country codes for countries that speak English?
#4 = Except [ #2 , #3 ] Predicate [ #3.CountryCode = #2.CountryCode ] Output [ #2.CountryCode ] ; -- What are the country codes for countries that do not speak English?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'Percentage'. (207) (SQLExecDirectW)")
--------------------
Error converting prediction to CTE
#1 = Scan Table [ countrylanguage ] 

Processing examples:  98%|█████████▊| 1012/1034 [00:33<00:00, 34.38it/s]

Error executing prediction QPL
#1 = Scan Table [ singer ] Predicate [ Citizenship <> 'France' ] Output [ Nationality , Name ] ; -- What are the names of the singers who are not French citizens?
Error: ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'Nationality'. (207) (SQLExecDirectW)")
--------------------


Processing examples: 100%|██████████| 1034/1034 [00:34<00:00, 30.26it/s]

Execution accuracy: 891/1031 (86.42%)
Gold errors: 0/1031 (0.00%)
Prediction errors: 34/1031 (3.30%)





### Rule Based (Deprecated)

In [None]:
import re
from typing import List

# Build chat templates that include the assistant turn (it holds the gold completion)
prompter = PrompterRegistry.get(DATASET_ID)(with_assistant=True)
chat_templates = [prompter.to_chat_template(example) for example in test_dataset]


def equivalent_bracketed_lines(
        true_line: str,
        generated_line: str,
        *,
        require_exact_names: bool = False,
        ignore_case: bool = True
    ) -> bool:
    """Rule-based comparison of a generated bracketed line against the reference line.

    The comparison runs in two stages:

    1. The text outside the ``[...]`` groups must be identical once whitespace is
       collapsed, numbers such as ``3500.0`` are written as ``3500`` and (optionally)
       case is folded.
    2. The ``[...]`` groups are compared position by position. Each group is split
       on commas into items; an item drops everything from `` AS `` onwards, and
       ``#N.col`` references that take part in a ``Predicate [ #a.x = #b.y ]``
       equality of either line are reduced to their bare column name.

    Args:
        true_line: the reference line.
        generated_line: the line to check.
        require_exact_names: when true, every group must contain exactly the same
            items; otherwise the reference items only need to be a subset of the
            generated ones.
        ignore_case: compare case-insensitively.

    Returns:
        True when the generated line is accepted as equivalent, False otherwise.
    """
    zero_fraction_re = re.compile(r'(?<![\d.])(\d+)\.0+\b')
    qualified_col_re = re.compile(r'#\d+\.\s*[\w$]+', flags=re.I)
    join_equality_re = re.compile(
        rf'({qualified_col_re.pattern})\s*=\s*({qualified_col_re.pattern})', flags=re.I
    )

    def fold_case(text: str) -> str:
        return text.upper() if ignore_case else text

    def strip_zero_fraction(text: str) -> str:
        # "12.0" / "12.000" -> "12"
        return zero_fraction_re.sub(r'\1', text)

    def outline(text: str) -> str:
        # The line with every bracket group emptied and whitespace collapsed
        emptied = re.sub(r'\[[^\]]*]', '[]', strip_zero_fraction(text))
        return fold_case(re.sub(r'\s+', ' ', emptied).strip())

    def bracket_groups(text: str) -> list:
        return re.findall(r'\[([^\]]*)]', strip_zero_fraction(text))

    def column_aliases(*lines: str) -> dict[str, str]:
        # Group the "#N.col" names connected through predicate equalities and
        # return the replacement text for each of them.
        leader = {}

        def root_of(name):
            leader.setdefault(name, name)
            while leader[name] != name:
                name = leader[name]
            return name

        def merge(first, second):
            first_root, second_root = root_of(first), root_of(second)
            if first_root != second_root:
                leader[second_root] = first_root

        for line in lines:
            predicates = re.findall(r'Predicate\s*\[([^\]]*)\]',
                                    strip_zero_fraction(line), flags=re.I)
            for predicate in predicates:
                for lhs, rhs in join_equality_re.findall(predicate):
                    merge(fold_case(lhs.replace(' ', '')), fold_case(rhs.replace(' ', '')))

        # NOTE(review): every name ends up mapped to its OWN column name, so the
        # grouping above only decides which names lose their "#N." prefix; columns
        # with different names are never treated as equal. Confirm this is intended.
        aliases = {}
        for name in leader:
            column = fold_case(name.split('.', 1)[1])
            aliases[name] = column
            aliases.setdefault(root_of(name), column)
        return aliases

    # Stage 1: everything outside the brackets must agree
    if outline(true_line) != outline(generated_line):
        return False

    true_groups = bracket_groups(true_line)
    generated_groups = bracket_groups(generated_line)
    if len(true_groups) != len(generated_groups):
        return False

    aliases = column_aliases(true_line, generated_line)

    def canonical(item: str) -> str:
        # Drop the alias part, normalise whitespace and numbers, resolve column references
        expression = re.split(r'\s+AS\s+', item, flags=re.I)[0]
        expression = strip_zero_fraction(re.sub(r'\s+', ' ', expression).strip())

        def resolve(match):
            lookup = fold_case(match.group(0).replace(' ', ''))
            return aliases.get(lookup, match.group(0))

        return fold_case(qualified_col_re.sub(resolve, expression))

    def items_of(group: str) -> set:
        return {canonical(item) for item in group.split(',') if item.strip()}

    # Stage 2: compare the bracket groups pairwise
    for true_group, generated_group in zip(true_groups, generated_groups):
        expected, produced = items_of(true_group), items_of(generated_group)
        accepted = expected == produced if require_exact_names else expected <= produced
        if not accepted:
            return False

    return True


# Count the completions the rule-based check accepts; print every rejected one
acc = 0
evaluation_rows = zip(outputs, chat_templates, test_dataset)
for out, chat_template, example in tqdm(evaluation_rows, desc="Evaluating", total=len(outputs)):
    # The last message of the chat template is the assistant turn, i.e. the gold completion
    label = chat_template['messages'][-1]['content']
    equivalent = equivalent_bracketed_lines(label, out)
    if equivalent:
        acc += 1
        continue

    print(
        "Question:",
        example['question'],
        '-'*40,
        "Predicted:",
        f"{example['op']} {out}",
        '-'*40,
        "True:",
        f"{example['op']} {label}",
        "=" * 80,
        sep="\n",
    )

print(f"Accuracy: {acc}/{len(outputs)} = {acc / len(outputs) * 100:.2f}%")