# Decomposer Test

This notebook assesses the performance of the decomposer alone.

## Generate QD

In [None]:
import copy
import itertools

import torch
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm

from src.prompters import QPLDecomposerCotPrompter
from src.utils.generation import to_model_prompt, generate_batch
import src.utils.paths as p

# Constants
BATCH_SIZE = 8
MAX_NEW_TOKENS = 8192
MAX_RETRIES = 3
# Generation below uses sampling parameters (temperature / top_p / top_k), so fix the
# RNG seed to make a Restart & Run All reproducible (up to GPU kernel nondeterminism).
SEED = 42

# Alternative: evaluate a locally trained checkpoint instead of the base model.
# MODEL_DIR = "qwen-3-qpl-cot"
# MODEL_CKPT = MODEL_DIR + "/checkpoint-2864"
# MODEL_PATH = p.TRAINED_MODELS_DIR / MODEL_CKPT
# DATASET_ID = "bgunlp/question_decomposer_ds"
MODEL_PATH = "Qwen/Qwen3-4B"
# NOTE(review): DATASET_ID is not passed to `prompter.load_dataset()` below — confirm which
# dataset the prompter actually loads.
DATASET_ID = "d4nieldev/qpl-decomposer-cot-ds"

torch.manual_seed(SEED)

# Load the validation split and render every example as a chat template
# (a dict with a "messages" list, as used by the cells below).
prompter = QPLDecomposerCotPrompter(with_assistant=True, schema_representation="m_schema")
test_dataset = list(prompter.load_dataset()['validation'])
chat_templates = [prompter.to_chat_template(example) for example in test_dataset]

In [None]:
# Load the fp16 causal LM (FlashAttention-2) on the GPU in eval mode, plus its tokenizer.
model = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
    )
    .cuda()
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

In [None]:
# Pick a single validation example to inspect by hand.
EXAMPLE_INDEX = 13
# Deep-copy so that editing this example's prompt by hand cannot leak back into
# `chat_templates`, which the batch generation and evaluation cells reuse.
chat_template = copy.deepcopy(chat_templates[EXAMPLE_INDEX])
print(chat_template['messages'][1]['content'])

In [None]:
# Overwrite messages[1] (presumably the user turn) of the selected example with a
# hand-written concert_singer prompt; the single-example generation below uses it.
# NOTE(review): the foreign keys mention tables/columns (concert, stadium,
# singer_in_concert, Singer_ID) that are not in the listed schema, and the question asks
# for a song's name although the `singer` columns shown have no song-name column —
# confirm this reduced schema is an intentional probe.
# NOTE(review): if `chat_template` is a reference into `chat_templates` rather than a copy,
# this assignment also changes the prompt later used for that example in the batch run.
chat_template['messages'][1]['content'] = """【DB_ID】concert_singer
【Schema】
# Table: singer (6 rows)
## Columns (name:type | pk? | description? | metadata? | fk? | examples?)
[
(Name: TEXT | This column has 0 null values and 6 distinct values. The values are between 9 and 12 characters long, with a mean of 10.50 and a standard deviation of 1.38. | Values Examples: [Timbaland, John Nizinik, Justin Brown, Joe Sharp, Rose White]),
(Song_release_year: TEXT | This column has 0 null values and 6 distinct values. All of them are numeric. The values are always 4 characters long. Common prefixes: 201 (3 occurrences), 200 (2 occurrences). | Values Examples: [2013, 2014, 2003, 2008, 1992]),
(Age: NUMBER | This column has 0 null values and 6 distinct values. All of them are numeric. The values are between 25.0 and 52.0, with a mean of 37.00 and a standard deviation of 10.10. | Values Examples: [29, 41, 32, 43, 52]),
]

【Foreign Keys】
concert.Stadium_ID=stadium.Stadium_ID
singer_in_concert.Singer_ID=singer.Singer_ID
singer_in_concert.concert_ID=concert.concert_ID


[Question]: Show the name and the release year of the song by the youngest singer.

First, determine the operator, then formulate the sub-questions (unless the operator is "Scan", in which case no sub-questions are needed and this step must be skipped), and finally justify the decomposition.
Provide your reasoning enclosed in <think> and </think> tags, and afterwards provide the final answer in the following format: the first line of the final answer should be the toplevel operator, the following lines should be the predicted sub-questions."""

In [None]:
# Generate a completion for the single selected example.
# Only the system/user turns are rendered, so the model has to produce the assistant turn.
prompt_text = tokenizer.apply_chat_template(
    [msg for msg in chat_template["messages"] if msg['role'] in ['system', 'user']],
    tokenize=False,
    add_generation_prompt=True,
    continue_final_message=False,
)
model_inputs = tokenizer(
    prompt_text,
    padding=True,
    padding_side="left",  # https://huggingface.co/docs/transformers/llm_tutorial?padding=right+pad#padding-side
    return_tensors="pt",
    add_special_tokens=False,  # don't add special tokens on top of the rendered chat template
).to('cuda')

# `labels` is deliberately not passed: `generate` returns token ids (no loss), so it cannot
# yield a perplexity, and forwarding labels to every decoding step is at best wasted work.
with torch.no_grad():
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=4096,  # NOTE(review): smaller than MAX_NEW_TOKENS used for the batch run — confirm
        do_sample=True,  # make explicit that temperature/top_p/top_k are in effect
        temperature=0.6,
        top_p=0.95,
        top_k=20,
    )

In [None]:
# Inspect the raw token ids returned by `model.generate` in the previous cell.
outputs

In [None]:
# Turn the first generated sequence back into text (special tokens dropped) and show it.
decoded_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded_text)

In [None]:
# Decompose every validation question in batches.
# Reuses the `model` / `tokenizer` loaded earlier instead of putting a second copy on the GPU.
prompts = [to_model_prompt(tokenizer, ct) for ct in chat_templates]

# Sampling values mirror the single-example generation cell.
# NOTE(review): assumes `generate_batch` forwards extra keyword arguments to
# `model.generate` — confirm.
generation_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20}

# The context manager closes the progress bar even if generation raises.
with tqdm(total=len(prompts), desc="Decomposing") as progress_bar:
    predictions = generate_batch(
        model=model,
        tokenizer=tokenizer,
        model_prompts=prompts,
        batch_size=BATCH_SIZE,
        max_new_tokens=MAX_NEW_TOKENS,
        progress_bar=progress_bar,
        max_retries=MAX_RETRIES,
        **generation_params,
    )

## Evaluate

In [None]:
import re
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sentence_transformers import SentenceTransformer
from sklearn.metrics import classification_report

# Running tallies for the evaluation loop.
op_correct = 0
sum_similarity = 0
sentences_count = 0

# Operator vocabulary (label ids follow list order); unrecognised operators map to "other".
OPERATORS = [
    'aggregate', 'except', 'filter', 'intersect', 'join',
    'scan', 'sort', 'topsort', 'union', 'other',
]
op_to_id = {op: idx for idx, op in enumerate(OPERATORS)}
id_to_op = {v: k for k, v in op_to_id.items()}

# Sentence-embedding model used to compare predicted and gold sub-questions.
emb_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

y_pred = []
y_true = []

# Rebuild the chat templates from `test_dataset` with the same schema representation used
# for generation, so the inputs saved alongside predictions match what the model saw.
prompter = QPLDecomposerCotPrompter(with_assistant=True, schema_representation="m_schema")
chat_templates = [prompter.to_chat_template(example) for example in test_dataset]

In [None]:
# Score each prediction against its gold decomposition: operator accuracy, plus the
# embedding similarity of the sub-questions whenever the operator is correct.

# Reset the tallies here so that re-running this cell does not double count.
op_correct = 0
sum_similarity = 0
sentences_count = 0
y_pred = []
y_true = []

# Optional "<think>...</think>" block followed by the final answer. Every part is optional,
# so `match` succeeds on any string; unparseable answers end up classified as "other".
output_pattern = re.compile(r"(?P<reasoning><think>.*?</think>)?\s*(?P<answer>.*)", re.DOTALL)


def _answer_lines(match: re.Match) -> list[str]:
    """Return the stripped, non-blank lines of the ``answer`` group.

    The first line is the operator and the remaining lines are the sub-questions.
    Dropping blank lines keeps a trailing newline from counting as an extra sub-question.
    """
    return [line.strip() for line in match.group("answer").splitlines() if line.strip()]


def _subquestion_similarity(model_sentences: list[str], label_sentences: list[str]) -> float:
    """Mean similarity under the best one-to-one pairing of predicted and gold sub-questions.

    Both lists must have the same length. Returns 0.0 when there is nothing to compare
    (a "Scan" operator has no sub-questions). Embeddings are L2-normalised, so the dot
    product is the cosine similarity.
    """
    n = len(model_sentences)
    if n == 0:
        return 0.0
    embeddings = emb_model.encode(
        model_sentences + label_sentences, show_progress_bar=False, normalize_embeddings=True
    )
    cross = embeddings[:n] @ embeddings[n:].T  # cross[i, j] = sim(predicted i, gold j)
    best_total = max(
        sum(cross[i, j] for i, j in enumerate(perm))
        for perm in itertools.permutations(range(n))
    )
    return float(best_total) / n


output_json = []
# `example_template` (not `chat_template`) so the single-example cells' variable is not clobbered.
for pred, example_template in tqdm(zip(predictions, chat_templates), desc="Evaluating", total=len(predictions)):
    if pred is None:
        print("No prediction for chat template:", example_template)
        continue

    gold = example_template['messages'][-1]['content']

    # save prediction vs. gold
    output_json.append({
        "input": [msg for msg in example_template['messages'] if msg['role'] in ['system', 'user']],
        "pred": pred,
        "gold": gold,
    })

    if not (gold_match := output_pattern.match(gold)):
        raise ValueError(f"Invalid gold output format: {gold}")
    if not (pred_match := output_pattern.match(pred)):
        print(f"Invalid prediction format:\n\n{pred}\n\n---------------------------")
        continue

    pred_lines = _answer_lines(pred_match)
    gold_lines = _answer_lines(gold_match)

    # operator classification: first answer line, case-insensitive; unknown/missing -> "other"
    pred_op_id = op_to_id.get(pred_lines[0].lower() if pred_lines else "", op_to_id["other"])
    gold_op_id = op_to_id.get(gold_lines[0].lower() if gold_lines else "", op_to_id["other"])
    y_pred.append(pred_op_id)
    y_true.append(gold_op_id)

    if pred_op_id != gold_op_id:
        continue
    op_correct += 1

    # sentence similarity (only scored when the operator is correct)
    model_sentences = pred_lines[1:]
    label_sentences = gold_lines[1:]
    sentences_count += len(label_sentences)

    if len(model_sentences) != len(label_sentences):
        # A wrong number of sub-questions contributes 0 similarity; show it for inspection.
        print("======================")
        print(pred)
        print("----")
        print(gold)
        print("======================")
        continue
    sum_similarity += _subquestion_similarity(model_sentences, label_sentences)

n_examples = len(test_dataset)
print(f"Operator Accuracy: {op_correct / n_examples if n_examples else float('nan')}")
print(f"Sentence Similarity (when operator is correct): {sum_similarity / op_correct if op_correct else float('nan')}")

if y_true:
    # Generate a classification report (zero_division=0 silences warnings for absent classes)
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

    # Print nicely, with operator names instead of numeric label ids
    df = pd.DataFrame(report).transpose()
    df.index = df.index.map(lambda x: id_to_op[int(x)] if x.isdigit() else x)
    df['support'] = df['support'].astype(int)
    print(df)

    cm = confusion_matrix(y_true, y_pred, labels=list(op_to_id.values()))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(op_to_id.keys()))
    fig, ax = plt.subplots()
    disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation=45)
    ax.set_title("Operator confusion matrix")
    plt.show()
else:
    print("No parseable predictions to report on.")

In [None]:
import json
from pathlib import Path

# Persist every (input, prediction, gold) record collected during evaluation.
PREDICTIONS_PATH = Path('output/qpl/decomposer_3_predictions.json')
# Create the output directory first so a long generation run is not lost to a missing folder.
PREDICTIONS_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(PREDICTIONS_PATH, 'w', encoding='utf-8') as f:
    json.dump(output_json, f, indent=2)