In [None]:
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc 
import json
from src.utils import extract_sql, chdiff 

import logging

# Suppress every logging call of severity CRITICAL and below for the rest of
# the session (presumably to keep library log output out of the transcript
# printed further down).
logging.disable(logging.CRITICAL)

# Load the test examples. The encoding is given explicitly because JSON text
# is UTF-8, whereas open() would otherwise fall back to the platform's locale
# encoding and could mis-decode non-ASCII characters.
with open("data/dataset_test.json", "r", encoding="utf-8") as f:
    test_ds = json.load(f)


# Base model identifier, passed to ChatModel as ``model_name_or_path``.
BASE_MODEL_PATH = "Qwen/Qwen3-0.6B"
# Directory of the LoRA checkpoint, passed as ``adapter_name_or_path``.
ADAPTER_PATH = "out/checkpoint-250"

# Settings handed to ChatModel as a single dictionary.
args = {
    "model_name_or_path": BASE_MODEL_PATH,
    "adapter_name_or_path": ADAPTER_PATH,
    "template": "qwen3",
    "finetuning_type": "lora",
}
chat_model = ChatModel(args)


def print_red(text, *args, **kwargs):
    """Print ``text`` wrapped in the ANSI colour code 91 (bright red).

    The colour is reset immediately after ``text``. Any further positional
    and keyword arguments are forwarded to ``print`` unchanged, so ``end``,
    ``sep``, ``file`` and ``flush`` behave as usual; extra positional values
    are printed uncoloured.
    """
    colored = "\033[91m{}\033[0m".format(text)
    print(colored, *args, **kwargs)
def print_green(text, *args, **kwargs):
    """Print ``text`` wrapped in the ANSI colour code 92 (bright green).

    The colour is reset immediately after ``text``. Any further positional
    and keyword arguments are forwarded to ``print`` unchanged, so ``end``,
    ``sep``, ``file`` and ``flush`` behave as usual; extra positional values
    are printed uncoloured.
    """
    colored = "\033[92m{}\033[0m".format(text)
    print(colored, *args, **kwargs)
def print_yellow(text, *args, **kwargs):
    """Print ``text`` wrapped in the ANSI colour code 93 (bright yellow).

    The colour is reset immediately after ``text``. Any further positional
    and keyword arguments are forwarded to ``print`` unchanged, so ``end``,
    ``sep``, ``file`` and ``flush`` behave as usual; extra positional values
    are printed uncoloured.
    """
    colored = "\033[93m{}\033[0m".format(text)
    print(colored, *args, **kwargs)


# Interactive walk through the test set: for each example, stream the model's
# answer, show a diff between the reference SQL and the predicted SQL, then
# wait for the operator before moving on.

# The system prompt is read from the first test entry and reused for every
# example below.
# NOTE(review): only test_ds[0]["system"] is ever read; if other entries carry
# a different "system" value it is ignored — confirm that is intended.
system_prompt = test_ds[0]["system"]
print_red(f"System:\n{system_prompt}\n", end="", flush=True)

for entry in test_ds:
    # Turn 0 is used as the user question, turn 1 as the reference answer.
    conversation = entry["conversations"]
    user_prompt = conversation[0]["value"]
    true_sql = extract_sql(conversation[1]["value"])

    # Each example starts a fresh single-turn conversation, so no history
    # from the previous example is sent to the model.
    messages = [{"role": "user", "content": user_prompt}]

    print_green(f"User:\n{user_prompt}\n", end="", flush=True)
    print_yellow("\nAssistant: ", end="", flush=True)

    # Echo each chunk as soon as it arrives while collecting the pieces;
    # joining once at the end avoids repeated string concatenation.
    chunks = []
    for new_text in chat_model.stream_chat(messages, system=system_prompt):
        print_yellow(new_text, end="", flush=True)
        chunks.append(new_text)
    response = "".join(chunks)

    # `messages` is rebuilt at the top of the next iteration, so this only
    # determines what is left in `messages` once the loop has finished.
    messages.append({"role": "assistant", "content": response})

    predicted_sql = extract_sql(response)

    print("\nDiff:\n", end="", flush=True)
    print(chdiff(true_sql, predicted_sql))

    print()
    torch_gc()

    # Pause between examples. The typed text is never sent to the model; it
    # is only checked for the stop word, so the prompt says so explicitly.
    input_text = input("Press Enter for the next example, or type 'exit' to stop: ")
    if input_text.strip().lower() == "exit":
        break

  from .autonotebook import tqdm as notebook_tqdm
