# Completer Test

This notebook assesses the performance of the QPL completer in isolation.

## Generate Completions

In [None]:
from transformers.models.auto.tokenization_auto import AutoTokenizer
from peft import AutoPeftModelForCausalLM
from src.prompters import PrompterRegistry
from datasets import load_dataset
from tqdm import tqdm

from src.utils.generation import to_model_prompt, generate_batch
import src.utils.paths as p

# --- Run configuration -------------------------------------------------------
BATCH_SIZE = 6
MAX_NEW_TOKENS = 128
MODEL_CKPT = "855d8cb9_gemma-3-4b-it-qpl-composer-ds_train_batch_size=1_gradient_accumulation_steps=8_learning_rate=0.0002_num_train_epochs=4_gradient_checkpointing=True_logging_steps=0.00125_save_steps=0.0625_random_seed=1_lora=True_r=16_alpha=32_dropout=0.05/checkpoint-3996"
DATASET_ID = "d4nieldev/qpl-completer-ds"
# Checkpoint directory, resolved relative to the project's trained-models dir.
MODEL_PATH = p.TRAINED_MODELS_DIR / MODEL_CKPT


# --- Model & tokenizer -------------------------------------------------------
# Load the PEFT checkpoint, move it to the GPU and switch it to eval mode.
model = AutoPeftModelForCausalLM.from_pretrained(
    MODEL_PATH,
    attn_implementation='eager',
)
model = model.cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# --- Data --------------------------------------------------------------------
# The "validation" split is what this notebook evaluates on.
test_dataset = list(load_dataset(DATASET_ID, split="validation"))
# Prompts are built with with_assistant=False; the model supplies the completion.
prompter = PrompterRegistry.get(DATASET_ID)(with_assistant=False)
chat_templates = [prompter.to_chat_template(row) for row in test_dataset]
prompts = [to_model_prompt(tokenizer, template) for template in chat_templates]

# --- Generate QPL completions ------------------------------------------------
# One output per prompt, generated in batches with sampling disabled.
outputs = generate_batch(
    model=model,
    tokenizer=tokenizer,
    model_prompts=prompts,
    batch_size=BATCH_SIZE,
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=False,
    progress_bar=tqdm(total=len(prompts), desc="Completing QPL"),
)

## Evaluate

In [None]:
import pyodbc
from src.inference.qpl.qpl_to_cte import flat_qpl_to_cte
from src.inference.qpl.validate_qpl import execute_sql, same_rs

# NOTE(review): the database credentials below are hard-coded in the notebook.
# They should be rotated and loaded from the environment or a secret store
# rather than committed with the source.
connection_string = (
    'Driver={ODBC Driver 18 for SQL Server};'
    'Server=tcp:spider-sql.database.windows.net,1433;'
    'Database=test;'
    'Uid=iloveqpl;'
    'Pwd=P4$$w0rd!;'
    'Encrypt=yes;'
    'TrustServerCertificate=no;'
    'Connection Timeout=30;'
)


def _strip_qpl_comments(qpl: str) -> list[str]:
    """Split *qpl* into lines, dropping everything from the first '--' on each line."""
    return [line.split('--', 1)[0] for line in qpl.split('\n')]


# Execution accuracy: a prediction counts as correct when the result set of the
# predicted QPL (prefix + op + generated text) matches that of the gold QPL
# (prefix + gold line), as decided by `same_rs`.
execution_accuracy = 0
conn = pyodbc.connect(connection_string, autocommit=True)
try:
    cursor = conn.cursor()
    for out, example in tqdm(zip(outputs, test_dataset), desc="Evaluating", total=len(outputs)):
        prefix = example['prefix_qpl'] + "\n"
        gold = prefix + example['qpl_line']
        # The model output does not include the operator, so prepend it here.
        pred = prefix + example['op'] + ' ' + out

        flat_gold = _strip_qpl_comments(gold)
        flat_pred = _strip_qpl_comments(pred)

        # Translate both QPL programs to SQL for the example's database.
        gold_cte = flat_qpl_to_cte(flat_gold, example['db_id'])
        pred_cte = flat_qpl_to_cte(flat_pred, example['db_id'])

        grs = execute_sql(cursor, gold_cte)
        prs = execute_sql(cursor, pred_cte)

        if same_rs(grs, prs, flat_pred):
            execution_accuracy += 1
finally:
    # Always release the server connection, even if an example fails mid-loop.
    conn.close()

# Guard against an empty evaluation set so the summary never divides by zero.
execution_accuracy_pct = execution_accuracy / len(outputs) * 100 if len(outputs) else 0.0
print(f"Execution accuracy: {execution_accuracy}/{len(outputs)} ({execution_accuracy_pct:.2f}%)")

In [None]:
import re
from typing import List

# Rebuild the chat templates with with_assistant=True so that each one also
# carries the assistant turn, which is read below as the reference label.
prompter = PrompterRegistry.get(DATASET_ID)(with_assistant=True)
chat_templates = [prompter.to_chat_template(row) for row in test_dataset]


# Patterns are compiled once at import time instead of on every call.
_ZERO_DECIMAL = re.compile(r'(?<![\d.])(\d+)\.0+\b')       # 3.0 / 3.00 -> 3
_TABLE_COL = re.compile(r'#\d+\.\s*[\w$]+', flags=re.I)    # "#1.col" / "#1. col"
_EQUALITY = re.compile(
    rf'({_TABLE_COL.pattern})\s*=\s*({_TABLE_COL.pattern})', flags=re.I
)
_PREDICATE_BLOCK = re.compile(r'Predicate\s*\[([^\]]*)\]', flags=re.I)
_BRACKET_BLOCK = re.compile(r'\[([^\]]*)]')
_WHITESPACE = re.compile(r'\s+')
_AS_ALIAS = re.compile(r'\s+AS\s+', flags=re.I)


def equivalent_bracketed_lines(
    true_line: str,
    generated_line: str,
    *,
    require_exact_names: bool = False,
    ignore_case: bool = True,
) -> bool:
    """Return whether *generated_line* is equivalent to *true_line*.

    The comparison is tolerant in the following ways:

    * Text outside ``[...]`` blocks must match after whitespace is collapsed
      and integral decimals (``3.0``) are rewritten as integers (``3``).
    * Each ``[...]`` block is split on commas and compared as a *set* of
      tokens, so token order and duplicates inside a block are ignored.
    * For each token, anything from the first `` AS `` onwards is dropped,
      whitespace is collapsed and integral decimals are normalised.
    * ``#N.col`` references have their inner whitespace removed.  When the
      reference takes part in a ``#N.a = #M.b`` equality inside a
      ``Predicate [...]`` block of either line, the ``#N.`` prefix is dropped
      too, so only the bare column name is compared.

    Args:
        true_line: The reference line.
        generated_line: The line to check against the reference.
        require_exact_names: When true, every block must contain exactly the
            same token set in both lines.  When false (default), the generated
            block may contain extra tokens as long as it has every reference
            token.
        ignore_case: Compare case-insensitively (default) or exactly.

    Returns:
        ``True`` if the lines are equivalent under the rules above.
    """

    def _norm_nums(txt: str) -> str:
        return _ZERO_DECIMAL.sub(r'\1', txt)

    def _col_key(ref: str) -> str:
        # Canonical spelling of a "#N.col" reference: no whitespace, and
        # upper-cased when the comparison is case-insensitive.
        key = _WHITESPACE.sub('', ref)
        return key.upper() if ignore_case else key

    # ---- columns that appear in join equalities of either line
    # Maps "#N.col" -> "col" so those columns compare by bare name.
    # NOTE(review): the two sides of an equality are NOT merged into a single
    # equivalence class; each side keeps its own column name, so
    # "#1.a = #2.b" does not make `a` and `b` interchangeable.  Confirm
    # whether merging was intended.
    join_cols: dict[str, str] = {}
    for line in (true_line, generated_line):
        for block in _PREDICATE_BLOCK.findall(_norm_nums(line)):
            for left, right in _EQUALITY.findall(block):
                for ref in (left, right):
                    key = _col_key(ref)
                    join_cols[key] = key.split('.', 1)[1]

    # ---- compare skeletons (everything outside the brackets)
    def _skeleton(txt: str) -> str:
        txt = _BRACKET_BLOCK.sub('[]', _norm_nums(txt))
        txt = _WHITESPACE.sub(' ', txt).strip()
        return txt.upper() if ignore_case else txt

    if _skeleton(true_line) != _skeleton(generated_line):
        return False

    # ---- extract bracket contents
    def _blocks(txt: str) -> list[str]:
        return _BRACKET_BLOCK.findall(_norm_nums(txt))

    t_blocks, g_blocks = _blocks(true_line), _blocks(generated_line)
    if len(t_blocks) != len(g_blocks):
        return False

    # ---- canonicalise individual tokens
    def _replace_ref(match: re.Match) -> str:
        key = _col_key(match.group(0))
        # Fall back to the whitespace-free key (not the raw match) so that
        # "#1. col" and "#1.col" compare equal outside join predicates too.
        return join_cols.get(key, key)

    def canon(tok: str) -> str:
        tok = _AS_ALIAS.split(tok)[0]                # drop "AS alias"
        tok = _WHITESPACE.sub(' ', tok).strip()
        tok = _norm_nums(tok)
        tok = _TABLE_COL.sub(_replace_ref, tok)
        return tok.upper() if ignore_case else tok

    def _token_set(block: str) -> set[str]:
        return {canon(tok) for tok in block.split(',') if tok.strip()}

    # ---- compare each corresponding block
    for tb, gb in zip(t_blocks, g_blocks):
        t_set, g_set = _token_set(tb), _token_set(gb)
        if require_exact_names:
            if t_set != g_set:
                return False
        elif not t_set <= g_set:
            return False

    return True


# String-level accuracy: a completion counts as correct when it is equivalent
# to the reference label under `equivalent_bracketed_lines`; every mismatch is
# printed for inspection.
acc = 0
for out, chat_template, example in tqdm(zip(outputs, chat_templates, test_dataset), desc="Evaluating", total=len(outputs)):
    # The last message of the with_assistant=True template is used as the label.
    label = chat_template['messages'][-1]['content']
    if equivalent_bracketed_lines(label, out):
        acc += 1
        continue

    # Mismatch: dump the question with the predicted and reference lines.
    print("Question:")
    print(example['question'])
    print('-'*40)
    print("Predicted:")
    print(example['op'], out)
    print('-'*40)
    print("True:")
    print(example['op'], label)
    print("=" * 80)

# Guard against an empty evaluation set so the summary never divides by zero.
accuracy_pct = acc / len(outputs) * 100 if len(outputs) else 0.0
print(f"Accuracy: {acc}/{len(outputs)} = {accuracy_pct:.2f}%")