# Completer Test

This notebook assesses the performance of the QPL completer in isolation.

## Generate Completions

In [1]:
from transformers.models.auto.tokenization_auto import AutoTokenizer
from peft import AutoPeftModelForCausalLM
from processors import ProcessorRegistry
from datasets import load_dataset
from tqdm import tqdm

from utils.generation import to_model_prompt, generate_batch
import utils.paths as p

# Generation settings and locations of the fine-tuned checkpoint / dataset
BATCH_SIZE = 6
MAX_NEW_TOKENS = 128
MODEL_CKPT = "855d8cb9_gemma-3-4b-it-qpl-composer-ds_train_batch_size=1_gradient_accumulation_steps=8_learning_rate=0.0002_num_train_epochs=4_gradient_checkpointing=True_logging_steps=0.00125_save_steps=0.0625_random_seed=1_lora=True_r=16_alpha=32_dropout=0.05/checkpoint-3996"
MODEL_PATH = p.TRAINED_MODELS_DIR / MODEL_CKPT
DATASET_ID = "d4nieldev/qpl-completer-ds"


# Put the PEFT checkpoint on the GPU in eval mode and load the tokenizer
# stored alongside it
model = (
    AutoPeftModelForCausalLM
    .from_pretrained(MODEL_PATH, attn_implementation='eager')
    .cuda()
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Turn every validation example into a prompt; the processor is built with
# with_assistant=False for generation
test_dataset = list(load_dataset(DATASET_ID, split="validation"))
processor = ProcessorRegistry.get(DATASET_ID)(with_assistant=False)
chat_templates = [processor.to_chat_template(example) for example in test_dataset]
prompts = [to_model_prompt(tokenizer, chat_template) for chat_template in chat_templates]

# Complete the QPL line for each prompt (do_sample=False)
outputs = generate_batch(
    model=model,
    tokenizer=tokenizer,
    model_prompts=prompts,
    batch_size=BATCH_SIZE,
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=False,
    progress_bar=tqdm(total=len(prompts), desc="Completing QPL"),
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Completing QPL: 100%|██████████| 3012/3012 [34:38<00:00,  1.45it/s]


## Evaluate

In [None]:
import re
from typing import List

# Rebuild the chat templates with with_assistant=True: the evaluation loop
# reads the reference completion from the last message of each template.
processor = ProcessorRegistry.get(DATASET_ID)(with_assistant=True)
chat_templates = [processor.to_chat_template(example) for example in test_dataset]


def equivalent_bracketed_lines(true_line: str,
                               generated_line: str,
                               *,
                               require_exact_names: bool = False,
                               ignore_case: bool = True) -> bool:
    """Decide whether a generated QPL line is equivalent to the reference line.

    The comparison proceeds in three steps:

    1. The text outside square brackets (the "skeleton": operator and clause
       names) must be identical after whitespace collapsing and optional
       case folding.
    2. Both lines must contain the same number of ``[...]`` blocks.
    3. Each block is split on commas into tokens, every token is canonicalised,
       and the two token sets are compared block by block.

    Token canonicalisation drops an ``AS <alias>`` suffix, rewrites numbers
    such as ``5.0`` to ``5``, and replaces every table-qualified column
    (``#N.col``) that takes part in an equality inside a ``Predicate [...]``
    block with a name shared by all columns equated with it. Columns equated
    by a join are therefore interchangeable, and ``#1.a = #2.b`` matches
    ``#2.b = #1.a``. When all equated columns carry the same name, the shared
    name is just that bare column name.

    The equivalence classes are built separately for each line, so an
    incorrect equality in the generated line cannot be merged into the
    reference's classes and mask the mismatch.

    Args:
        true_line: Reference (gold) line.
        generated_line: Line produced by the model.
        require_exact_names: If True, the token sets of corresponding blocks
            must be equal. If False (default), every reference token must be
            present in the generated block, but extra generated tokens are
            tolerated.
        ignore_case: If True (default), the comparison is case-insensitive.
            Note that this also folds the case of string literals.

    Returns:
        True if the lines are considered equivalent, False otherwise.
    """
    _ZERO_DECIMAL = re.compile(r'(?<![\d.])(\d+)\.0+\b')
    _TABLE_COL    = re.compile(r'#\d+\.\s*[\w$]+', flags=re.I)
    _EQUALITY     = re.compile(rf'({_TABLE_COL.pattern})\s*=\s*({_TABLE_COL.pattern})',
                               flags=re.I)
    _BLOCK        = re.compile(r'\[([^\]]*)]')

    def _norm_nums(txt: str) -> str:
        # "5.0" / "5.00" -> "5"
        return _ZERO_DECIMAL.sub(r'\1', txt)

    def _col_key(ref: str) -> str:
        # Lookup key for a "#N.col" reference: all whitespace removed
        # (the pattern tolerates any whitespace after the dot).
        key = re.sub(r'\s+', '', ref)
        return key.upper() if ignore_case else key

    # ---- build equivalence classes from the join predicates of ONE line
    def _build_equiv_map(line: str) -> dict[str, str]:
        parent: dict[str, str] = {}

        def find(x: str) -> str:
            parent.setdefault(x, x)
            if parent[x] != x:
                parent[x] = find(parent[x])  # path compression
            return parent[x]

        def union(a: str, b: str) -> None:
            ra, rb = find(a), find(b)
            if ra != rb:
                parent[rb] = ra

        for block in re.findall(r'Predicate\s*\[([^\]]*)\]',
                                _norm_nums(line), flags=re.I):
            for left, right in _EQUALITY.findall(block):
                union(_col_key(left), _col_key(right))

        # Name each class after the sorted set of bare column names it
        # contains. Using all members (not one representative) keeps classes
        # with different members distinguishable across the two lines.
        columns_by_root: dict[str, set] = {}
        for ref in list(parent):
            columns_by_root.setdefault(find(ref), set()).add(ref.split('.', 1)[1])
        return {ref: '|'.join(sorted(columns_by_root[find(ref)]))
                for ref in list(parent)}

    # ---- compare skeletons (outside brackets)
    def _skeleton(txt: str) -> str:
        txt = _norm_nums(txt)
        txt = _BLOCK.sub('[]', txt)
        txt = re.sub(r'\s+', ' ', txt).strip()
        return txt.upper() if ignore_case else txt

    if _skeleton(true_line) != _skeleton(generated_line):
        return False

    # ---- extract bracket contents
    t_blocks = _BLOCK.findall(_norm_nums(true_line))
    g_blocks = _BLOCK.findall(_norm_nums(generated_line))
    if len(t_blocks) != len(g_blocks):
        return False

    # ---- canonicalise individual tokens
    def _canon(tok: str, equiv: dict) -> str:
        tok = re.split(r'\s+AS\s+', tok, flags=re.I)[0]  # drop the alias
        tok = re.sub(r'\s+', ' ', tok).strip()
        tok = _norm_nums(tok)
        # Columns that are not part of any equality keep their "#N." prefix.
        tok = _TABLE_COL.sub(
            lambda m: equiv.get(_col_key(m.group(0)), m.group(0)), tok)
        return tok.upper() if ignore_case else tok

    def _token_set(block: str, equiv: dict) -> set:
        return {_canon(tok, equiv) for tok in block.split(',') if tok.strip()}

    true_equiv = _build_equiv_map(true_line)
    generated_equiv = _build_equiv_map(generated_line)

    # ---- compare each corresponding block
    for tb, gb in zip(t_blocks, g_blocks):
        t_set = _token_set(tb, true_equiv)
        g_set = _token_set(gb, generated_equiv)

        if require_exact_names:
            if t_set != g_set:
                return False
        elif not t_set.issubset(g_set):
            return False

    return True


# Each prediction is compared against the reference completion stored in the
# last message of its chat template. zip() would silently truncate to the
# shortest sequence while the accuracy below divides by len(outputs), so
# check that the three sequences have the same length before scoring.
assert len(outputs) == len(chat_templates) == len(test_dataset), (
    f"Misaligned evaluation inputs: {len(outputs)} outputs, "
    f"{len(chat_templates)} chat templates, {len(test_dataset)} examples"
)

acc = 0
for out, chat_template, example in tqdm(zip(outputs, chat_templates, test_dataset), desc="Evaluating", total=len(outputs)):
    label = chat_template['messages'][-1]['content']
    if equivalent_bracketed_lines(label, out):
        acc += 1
        continue

    # Print every mismatch for manual inspection
    print("Question:")
    print(example['question'])
    print('-'*40)
    print("Predicted:")
    print(example['op'], out)
    print('-'*40)
    print("True:")
    print(example['op'], label)
    print("=" * 80)

# Guard against an empty run so the summary never raises ZeroDivisionError
if len(outputs) > 0:
    print(f"Accuracy: {acc}/{len(outputs)} = {acc / len(outputs) * 100:.2f}%")
else:
    print("Accuracy: n/a (no outputs to evaluate)")

Evaluating:  30%|███       | 915/3012 [00:00<00:00, 4440.26it/s]

Question:
What is the maximum capacity and the average of all stadiums ?
----------------------------------------
Predicted:
Aggregate[ #1 ] Output [ AVG(Capacity) AS Avg_Capacity , MAX(Capacity) AS Max_Capacity ]
```
----------------------------------------
True:
Aggregate[ #1 ] GroupBy [ Average ] Output [ Average , MAX(Capacity) AS Max_Capacity ]
```
Question:
List the stadium name of all concerts.
----------------------------------------
Predicted:
Join[ #1 , #2 ] Predicate [ #1.Stadium_ID = #2.Stadium_ID ] Output [ #2.Name ]
```
----------------------------------------
True:
Join[ #1 , #2 ] Predicate [ #1.Stadium_ID = #2.Stadium_ID ] Output [ #1.Stadium_ID , #2.Name ]
```
Question:
Show the stadium names without any concert.
----------------------------------------
Predicted:
Except[ #1 , #2 ] Predicate [ #2.Stadium_ID = #1.Stadium_ID ] Output [ #1.Name ]
```
----------------------------------------
True:
Except[ #1 , #2 ] Predicate [ #2.Stadium_ID IS NULL OR #1.Stadium_ID = #2.St

Evaluating:  60%|█████▉    | 1803/3012 [00:00<00:00, 4340.21it/s]

Show paragraph details for paragraph with text 'Korea ' .
----------------------------------------
Predicted:
ScanTable [ Paragraphs ] Predicate [ Paragraph_Text like '%Korea%' ] Output [ Paragraph_Text , Other_Details ]
```
----------------------------------------
True:
ScanTable [ Paragraphs ] Predicate [ Paragraph_Text like 'korea' ] Output [ Paragraph_Text , Other_Details ]
```
Question:
What are the details for the paragraph that includes the text 'Korea ' ?
----------------------------------------
Predicted:
ScanTable [ Paragraphs ] Predicate [ Paragraph_Text like '%Korea%' ] Output [ Paragraph_Text , Other_Details ]
```
----------------------------------------
True:
ScanTable [ Paragraphs ] Predicate [ Paragraph_Text like 'korea' ] Output [ Paragraph_Text , Other_Details ]
```
Question:
Show all document ids and the number of paragraphs in each document. Order by document id.
----------------------------------------
Predicted:
Sort[ #2 ] OrderBy [ Document_ID ASC ] Output [ Docu

Evaluating:  89%|████████▉ | 2679/3012 [00:00<00:00, 4335.14it/s]

Question:
What is the number of votes for each area code?
----------------------------------------
Predicted:
Join[ #2 , #3 ] Predicate [ #2.state = #3.state ] Output [ #3.state , #2.Count_Star ]
```
----------------------------------------
True:
Join[ #2 , #3 ] Predicate [ #2.state = #3.state ] Output [ #3.area_code , #2.Count_Star ]
```
Question:
List the contestant number of 'Tabatha Gehling'.
----------------------------------------
Predicted:
ScanTable [ CONTESTANTS ] Predicate [ contestant_name = 'Tabatha Gehling' ] Output [ contestant_number , contestant_name ]
```
----------------------------------------
True:
ScanTable [ CONTESTANTS ] Predicate [ contestant_name = 'Kelly Clauss' ] Output [ contestant_name , contestant_number ]
```
Question:
Find the code of the Aruba country.
----------------------------------------
Predicted:
ScanTable [ country ] Predicate [ LocalName = 'Aruba' ] Output [ Code , LocalName ]
```
----------------------------------------
True:
ScanTable [ count

Evaluating: 100%|██████████| 3012/3012 [00:00<00:00, 4350.09it/s]

Question:
What are all the possible breed type and size type combinations?
----------------------------------------
Predicted:
ScanTable [ Sizes ] Output [ size_description , size_code ]
```
----------------------------------------
True:
ScanTable [ Dogs ] Distinct [ true ] Output [ breed_code , size_code ]
```
Question:
List the first name of all the professionals along with the description of the treatment they have done.
----------------------------------------
Predicted:
Join[ #1 , #4 ] Predicate [ #1.treatment_type_code = #4.treatment_type_code ] Output [ #4.first_name , #1.treatment_type_description ]
```
----------------------------------------
True:
Join[ #1 , #4 ] Predicate [ #1.treatment_type_code = #4.treatment_type_code ] Distinct [ true ] Output [ #1.treatment_type_description , #4.first_name ]
```
Question:
What are each professional's first name and description of the treatment they have performed?
----------------------------------------
Predicted:
Join[ #1 , #4 ] Predi


