In [None]:
# train_grpo_improved.py
#
# See https://github.com/willccbb/verifiers for ongoing developments

import os
import re
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Define cache directory
CACHE_DIR = "cai6307-henrykobs/cache"

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format, you have to adhere to the format, only output the final answer without **ANY** additional information in the "answer" box.

<think>
...
</think>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<think>
{think}
</think>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str):
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str):
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

def get_gsm8k_questions(split="train"):
    data = load_dataset('openai/gsm8k', 'main', cache_dir=CACHE_DIR)[split]  # specify cache directory
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })
    return data

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs):
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs):
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs):
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs):
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<think>\n") == 1:
        count += 0.125
    if text.count("\n</think>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

# Choose model and set up output parameters
# model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
# model_name = "Qwen/Qwen2.5-14B-Instruct-1M"
model_name = "Qwen/Qwen2-7B"

if "Llama" in model_name:
    output_dir = "cai6307-henrykobs/model/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir = "cai6307-henrykobs/model/Qwen-7B-GRPO-2.0"
    run_name = "Qwen-7B-GRPO-gsm8k-2.0"

# Training configuration optimized for an A100 80GB
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=1e-5,  # Slightly increased learning rate (tune as needed)
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=4,      # Increased batch size for ample GPU memory
    gradient_accumulation_steps=4,
    num_generations=2,                   # Single generation per prompt
    max_prompt_length=256,
    max_completion_length=512,
    num_train_epochs=1,
    save_steps=20,                      # Save checkpoint every 100 steps
    max_grad_norm=0.1,
    report_to="none",                    # Change to "tensorboard" if you want native TensorBoard logging
    log_on_each_node=False,
    scale_rewards=False
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

# Initialize the model loaded entirely on GPU (remove offloading)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    cache_dir=CACHE_DIR,  # specify cache directory
    device_map="auto"
)

# Optionally, if you are using PyTorch 2.0+, you can compile the model for a speed boost:
# model = torch.compile(model)

# Gradient checkpointing is not required with 80GB, so it's been disabled for potentially faster execution
# model.gradient_checkpointing_enable()

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
tokenizer.pad_token = tokenizer.eos_token

# --- Custom Callback for Plotting Metrics Separately ---
class PlottingCallback(TrainerCallback):
    def __init__(self, model, tokenizer, sample_prompts, eval_steps=10, eval_max_length=50):
        """
        :param model: The current model.
        :param tokenizer: The tokenizer for processing prompts.
        :param sample_prompts: A list of sample prompts to evaluate response lengths.
                               Each prompt should be a list of messages (dict with "content").
        :param eval_steps: Frequency (in global steps) to evaluate and log response length.
        :param eval_max_length: Maximum additional tokens to generate for evaluation.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.sample_prompts = sample_prompts
        self.eval_steps = eval_steps
        self.eval_max_length = eval_max_length

        self.global_steps = []
        self.losses = []
        self.rewards = []
        self.avg_resp_lengths = []
        self.resp_plot_steps = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return

        # Collect metrics if present.
        if "loss" in logs:
            self.global_steps.append(state.global_step)
            self.losses.append(logs["loss"])
        if "reward" in logs:
            self.rewards.append(logs["reward"])

        # Every eval_steps, compute average response length.
        if state.global_step % self.eval_steps == 0:
            avg_length = self.evaluate_response_length()
            self.avg_resp_lengths.append(avg_length)
            self.resp_plot_steps.append(state.global_step)
            self.plot_all_metrics()

    def evaluate_response_length(self):
        lengths = []
        self.model.eval()
        with torch.no_grad():
            for prompt in self.sample_prompts:
                # Use the last message's content as the prompt text.
                prompt_text = prompt[-1]["content"]
                encoded = self.tokenizer(prompt_text, return_tensors="pt").to(self.model.device)
                input_length = encoded.input_ids.shape[1]
                outputs = self.model.generate(
                    **encoded,
                    max_length=input_length + self.eval_max_length,
                    do_sample=True
                )
                # Compute number of generated tokens (excluding the prompt)
                gen_length = outputs.shape[1] - input_length
                lengths.append(gen_length)
        self.model.train()
        avg_length = sum(lengths) / len(lengths) if lengths else 0
        print(f"Evaluated average response length: {avg_length:.2f} tokens")
        return avg_length

    def plot_all_metrics(self):
        # Plot Loss
        plt.figure()
        plt.plot(self.global_steps, self.losses, label="Loss")
        plt.xlabel("Global Steps")
        plt.ylabel("Loss")
        plt.title("Loss over Time")
        plt.legend()
        plt.savefig("run3/loss_metrics.png")
        plt.close()
        print(f"Saved loss plot at step {self.global_steps[-1]}")

        # Plot Reward (if available)
        if self.rewards:
            plt.figure()
            # Make sure we plot rewards vs. steps corresponding to rewards.
            plt.plot(self.global_steps[:len(self.rewards)], self.rewards, label="Reward")
            plt.xlabel("Global Steps")
            plt.ylabel("Reward")
            plt.title("Reward over Time")
            plt.legend()
            plt.savefig("run3/reward_metrics.png")
            plt.close()
            print(f"Saved reward plot at step {self.global_steps[-1]}")

        # Plot Average Response Length
        plt.figure()
        plt.plot(self.resp_plot_steps, self.avg_resp_lengths, label="Avg Response Length (tokens)")
        plt.xlabel("Global Steps")
        plt.ylabel("Tokens")
        plt.title("Average Response Length over Time")
        plt.legend()
        plt.savefig("run3/response_length.png")
        plt.close()
        print(f"Saved response length plot at step {self.global_steps[-1]}")

# Select a few sample prompts from the dataset for evaluation.
# Here we take the first 5 samples and use their "prompt" field.
sample_prompts = [item["prompt"] for item in dataset.select(range(5))]

# Initialize the GRPOTrainer with the custom callback.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        correctness_reward_func,
        soft_format_reward_func,
        int_reward_func
    ],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
)

trainer.add_callback(PlottingCallback(model, tokenizer, sample_prompts, eval_steps=10, eval_max_length=50))

# --- Checkpoint Recovery ---
# If the training was interrupted, resume from the latest checkpoint if available.
resume_checkpoint = None
if os.path.exists(output_dir):
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    if checkpoints:
        # Sort checkpoints by the numeric step value appended at the end of the folder name.
        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))
        resume_checkpoint = os.path.join(output_dir, checkpoints[-1])
        print(f"Resuming training from checkpoint: {resume_checkpoint}")

# --- Start Training ---
if __name__ == "__main__":
    try:
        trainer.train(resume_from_checkpoint=resume_checkpoint)
    except Exception as e:
        # If something goes wrong, save the current state so we can resume later.
        print("An error occurred during training:", e)
        trainer.save_state()
        raise


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Resuming training from checkpoint: cai6307-henrykobs/model/Qwen-7B-GRPO-2.0/checkpoint-240


Step,Training Loss
241,0.1794
242,0.0975
243,0.0793
244,0.1478
245,-0.0233
246,0.0287
247,0.1525
248,0.0962
249,-0.0051
250,0.0647


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=89) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=81) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=111) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos

Evaluated average response length: 499.00 tokens
Saved loss plot at step 250
Saved reward plot at step 250
Saved response length plot at step 250


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=89) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=81) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=111) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos

Evaluated average response length: 316.20 tokens
Saved loss plot at step 260
Saved reward plot at step 260
Saved response length plot at step 260


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=89) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=81) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=111) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos

Evaluated average response length: 321.80 tokens
Saved loss plot at step 270
Saved reward plot at step 270
Saved response length plot at step 270


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=89) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=81) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=111) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos

Evaluated average response length: 126.20 tokens
Saved loss plot at step 280
Saved reward plot at step 280
Saved response length plot at step 280


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=89) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=81) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=111) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos

Evaluated average response length: 258.60 tokens
Saved loss plot at step 290
Saved reward plot at step 290
Saved response length plot at step 290


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=89) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=81) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Both `max_new_tokens` (=2048) and `max_length`(=111) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos

Evaluated average response length: 277.80 tokens
Saved loss plot at step 300
Saved reward plot at step 300
Saved response length plot at step 300


In [None]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
CACHE_DIR = "cai6307-henrykobs/cache"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", cache_DIR=CACHE_DIR)
print("chegou")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct",cache_dir=CACHE_DIR, device_map="auto")


In [5]:
# -*- coding: utf-8 -*-
import os
import torch
import logging
import time
import xml.etree.ElementTree as ET
from types import SimpleNamespace
from datasets import load_dataset, Dataset
from transformers import pipeline, AutoTokenizer
from tqdm.auto import tqdm

# --- Flash Attention Check ---
try:
    import flash_attn
    _flash_attn_available = True
    logging.info("Flash Attention 2 available. Will use if applicable.")
except ImportError:
    _flash_attn_available = False
    logging.warning("Flash Attention 2 not found. Install with `pip install flash-attn --no-build-isolation` for potential speedup.")

# --- Constants ---
SYSTEM_PROMPT = """
Respond in the following format, you have to adhere to the format, only output the final answer without **ANY** additional information in the "answer" box.

<think>
...
</think>
<answer>
...
</answer>
"""
DEFAULT_CACHE_DIR = "cai6307-henrykobs/cache"
DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct-1M"
DEFAULT_OUTPUT_DIR = "cai6307-henrykobs/model/Qwen-7B-GRPO-2nd"
DEFAULT_EVAL_FRACTION = 0.2
DEFAULT_BATCH_SIZE = 32 # Reduced default for demonstration if needed, adjust per GPU
SEED = 42
TORCH_DTYPE = torch.bfloat16
ATTN_IMPLEMENTATION = "flash_attention_2" if _flash_attn_available else "eager"
DEFAULT_DEBUG_SAMPLE_OUTPUT = True # Enable debug logging by default
MAX_DEBUG_SAMPLES = 5 # Max *initial* raw samples to log
GENERATION_MAX_NEW_TOKENS = 1024
MAX_MISMATCH_LOG = 5

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Extraction Functions ---
def extract_xml_answer(text: str):
    """Extracts content within the first <answer>...</answer> tag, with fallbacks."""
    try:
        start_tag = "<answer>"
        end_tag = "</answer>"
        start_idx = text.find(start_tag)
        end_idx = text.find(end_tag, start_idx + len(start_tag)) # Search after start tag

        if start_idx != -1 and end_idx != -1:
             return text[start_idx + len(start_tag):end_idx].strip().lower()

        # Fallback 1: Robust XML parsing
        try:
            xml_content = f"<root>{text}</root>" # Dummy root
            root = ET.fromstring(xml_content)
            answer_element = root.find('.//answer')
            if answer_element is not None and answer_element.text is not None:
                return answer_element.text.strip().lower()
        except ET.ParseError:
             pass # Malformed XML

        # Fallback 2: Simple split
        if start_tag in text and end_tag in text:
             try:
                 potential_answer = text.split(start_tag, 1)[1].split(end_tag, 1)[0]
                 return potential_answer.strip().lower()
             except IndexError:
                 pass # Split failed

        return None

    except Exception as e:
         logger.error(f"Unexpected error during XML extraction: {e} for text: {text[:150]}...", exc_info=True)
         return None

def extract_hash_answer(text: str):
    """Extracts the answer from GSM8K format '... #### <answer>'."""
    if "####" not in text:
        return None
    try:
        answer = text.split("####")[1].strip().replace(",", "").replace("$", "").lower()
        return answer
    except IndexError:
         logger.warning(f"Could not split gold answer on '####': {text}")
         return None

# --- Dataset Preparation ---
def get_gsm8k_dataset(cache_dir: str, split="test"):
    """Loads GSM8K, extracts gold answers, and filters invalid samples."""
    logger.info(f"Loading GSM8K dataset ({split} split)...")
    try:
        dataset = load_dataset("openai/gsm8k", "main", cache_dir=cache_dir, split=split, trust_remote_code=True)
        dataset = dataset.map(
            lambda x: {"gold_answer": extract_hash_answer(x["answer"]), "question": x["question"]},
            remove_columns=["answer"],
            desc="Extracting gold answers"
        )
        initial_count = len(dataset)
        dataset = dataset.filter(lambda x: x["gold_answer"] is not None)
        filtered_count = len(dataset)
        if initial_count != filtered_count:
             logger.warning(f"Filtered {initial_count - filtered_count} samples with invalid gold answers.")
        logger.info(f"Dataset loaded with {filtered_count} samples.")
        return dataset
    except Exception as e:
        logger.error(f"Failed to load or process dataset: {e}", exc_info=True)
        raise

def prepare_prompts(dataset: Dataset, tokenizer):
    """Applies the chat template to each question."""
    logger.info(f"Applying chat template using {tokenizer.name_or_path}...")
    prompts = []
    skipped_count = 0
    for sample in tqdm(dataset, desc="Applying template"):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": sample["question"]},
        ]
        try:
            # Important: Ensure the tokenizer is compatible with apply_chat_template
            prompt_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            prompts.append(prompt_text)
        except Exception as e:
            logger.warning(f"Failed template application for question (truncated): {sample['question'][:50]}... Error: {e}. Skipping.")
            prompts.append(None) # Placeholder for filtering
            skipped_count += 1

    # Add prompts and filter failures
    dataset = dataset.add_column("prompt_text", prompts)
    original_len = len(prompts) # Count before filtering Nones
    dataset = dataset.filter(lambda x: x["prompt_text"] is not None)
    if skipped_count > 0 or len(dataset) != original_len: # Check both ways
        actual_filtered = original_len - len(dataset)
        logger.warning(f"Filtered out {actual_filtered} samples due to chat template errors.")

    if not dataset:
         logger.error("Dataset is empty after applying chat template. Cannot proceed.")
         # Consider raising an error or handling this state appropriately upstream
         return dataset # Return empty dataset

    logger.info(f"Chat template applied. {len(dataset)} prompts ready.")
    return dataset


# --- Evaluation Function ---
@torch.no_grad()
def evaluate_model(model_id: str, dataset: Dataset, batch_size: int, tokenizer, debug_sample_output: bool):
    """Evaluates a model using a pipeline, calculates accuracy, logs samples per batch if debug enabled."""
    logger.info(f"Starting evaluation for: {model_id}")
    logger.info(f"Using dtype={TORCH_DTYPE}, attn='{ATTN_IMPLEMENTATION}', max_new_tokens={GENERATION_MAX_NEW_TOKENS}")
    start_load_time = time.time()
    pipe = None

    # Check if dataset is empty before proceeding
    if not dataset or len(dataset) == 0:
        logger.error(f"Dataset provided to evaluate_model for {model_id} is empty. Skipping evaluation.")
        return 0.0

    try:
        # Load pipeline
        pipe = pipeline(
            "text-generation",
            model=model_id,
            tokenizer=tokenizer, # Use pre-loaded tokenizer
            device_map="auto",
            torch_dtype=TORCH_DTYPE,
            trust_remote_code=True,
            model_kwargs={"attn_implementation": ATTN_IMPLEMENTATION} if ATTN_IMPLEMENTATION == "flash_attention_2" else {} # Only pass if using flash
        )

        # Ensure padding token setup for batching
        if pipe.tokenizer.pad_token is None or pipe.tokenizer.pad_token_id is None:
            logger.warning(f"Tokenizer for {model_id} lacks pad token/ID. Setting pad_token = eos_token.")
            pipe.tokenizer.pad_token = pipe.tokenizer.eos_token
            pipe.tokenizer.pad_token_id = pipe.tokenizer.eos_token_id # Make sure ID is set
        if pipe.model.config.pad_token_id is None:
             logger.warning(f"Model config for {model_id} lacks pad_token_id. Setting from tokenizer: {pipe.tokenizer.pad_token_id}")
             pipe.model.config.pad_token_id = pipe.tokenizer.pad_token_id

        # Use left-padding for decoder-only models during generation
        pipe.tokenizer.padding_side = "left"
        logger.info(f"Tokenizer padding side: '{pipe.tokenizer.padding_side}', Pad token: '{pipe.tokenizer.pad_token}', ID: {pipe.tokenizer.pad_token_id}")

    except Exception as e:
        logger.error(f"FATAL: Failed to load pipeline for {model_id}: {e}", exc_info=True)
        if pipe is not None: del pipe
        torch.cuda.empty_cache()
        return 0.0 # Indicate failure

    load_time = time.time() - start_load_time
    logger.info(f"Pipeline loaded in {load_time:.2f}s.")

    logger.info(f"Generating predictions for {len(dataset)} samples (batch size: {batch_size})...")
    predictions = []
    generation_start_time = time.time()
    debug_samples_logged = 0 # Counter for the *initial* debug samples

    # Process dataset in batches using the pipeline
    # The pipeline handles batching internally based on `batch_size`
    # The loop iterates through results sample by sample, `i` is the dataset index
    for i, output in enumerate(tqdm(pipe(dataset["prompt_text"],
                                           max_new_tokens=GENERATION_MAX_NEW_TOKENS,
                                           do_sample=False, # Greedy decoding for consistency
                                           batch_size=batch_size,
                                           return_full_text=False, # Only generated part
                                           pad_token_id=pipe.tokenizer.pad_token_id), # Crucial for batching
                                      total=len(dataset), desc=f"Generating {os.path.basename(model_id)}")):
        pred = None # Default prediction if extraction fails
        generated_text = "Error: Could not get generated text." # Default message

        try:
            # Standard output structure: [{'generated_text': '...'}]
            generated_text = output[0]['generated_text']
            pred = extract_xml_answer(generated_text) # Extract final answer

            # --- Combined Debug Logging ---
            if debug_sample_output:
                # 1. Log first N raw outputs (original behavior)
                if debug_samples_logged < MAX_DEBUG_SAMPLES:
                    generated_length = len(tokenizer.encode(generated_text))
                    logger.info(f"\n[DEBUG RAW OUTPUT {debug_samples_logged+1}/{MAX_DEBUG_SAMPLES} - {model_id}]"
                                f"\nRaw Output (Tokens: ~{generated_length}/{GENERATION_MAX_NEW_TOKENS}):\n---\n{generated_text}\n---"
                                f"\nExtracted: {pred}")
                    if generated_length >= GENERATION_MAX_NEW_TOKENS - 5: # Small buffer
                         logger.warning(f"[DEBUG] Raw output length close to max_new_tokens. May be truncated.")
                    debug_samples_logged += 1

                # 2. Log Question/Full Output/Gold for first sample of each batch
                # Check if 'i' is the start of a batch (index 0, batch_size, 2*batch_size, etc.)
                if i % batch_size == 0:
                    try:
                        question = dataset[i]["question"]
                        gold_answer = dataset[i]["gold_answer"]
                        batch_num = i // batch_size + 1
                        total_batches = (len(dataset) + batch_size - 1) // batch_size
                        logger.info(f"\n--- Sample from Batch {batch_num}/{total_batches} (Dataset Index {i}) ---"
                                    f"\nModel: {model_id}"
                                    f"\nQuestion: {question}"
                                    f"\nFull Model Output:\n{generated_text}"
                                    f"\nGold Answer: {gold_answer}"
                                    f"\nExtracted Answer: {pred}\n"
                                    f"------------------------------------")
                    except IndexError:
                         logger.warning(f"Could not retrieve data for batch sample log at index {i}.")
                    except Exception as log_e:
                         logger.warning(f"Error during batch sample logging at index {i}: {log_e}")
            # --- End Combined Debug Logging ---

        except (IndexError, KeyError, TypeError) as e:
            # Handle pipeline output parsing issues
            logger.warning(f"Pipeline output parsing error at index {i}: {e}. Raw output object: {output}. Prediction set to None.")
            # Log the raw text if possible, even on error
            raw_text_on_error = "N/A"
            try: raw_text_on_error = output[0]['generated_text']
            except Exception: pass
            if debug_sample_output: logger.warning(f"[DEBUG RAW ON PARSE ERROR]\n---\n{raw_text_on_error}\n---")

        except Exception as e:
            # Catch any other unexpected errors during processing a single sample
            logger.error(f"Unexpected error processing sample {i} for {model_id}: {e}", exc_info=True)
            # generated_text is already set to an error message
        finally:
            # Always append a prediction (None if extraction failed or error occurred)
            predictions.append(pred)

    generation_time = time.time() - generation_start_time
    samples_per_sec = len(dataset) / generation_time if generation_time > 0 else 0
    logger.info(f"Generation done in {generation_time:.2f}s ({samples_per_sec:.2f} samples/sec).")

    # Calculate Accuracy
    correct = 0
    total = 0
    mismatched_samples = []
    logger.info("Calculating accuracy...")
    for i, (pred, gold) in enumerate(zip(predictions, dataset["gold_answer"])):
        if gold is None: continue # Should be filtered, but check
        total += 1
        if pred is not None and pred == gold:
            correct += 1
        elif len(mismatched_samples) < MAX_MISMATCH_LOG:
             # Log mismatch details (question truncated)
             mismatched_samples.append({
                 "index": i, "question": dataset[i]["question"][:100]+"...",
                 "prediction": pred, "gold": gold
             })

    accuracy = (correct / total) * 100 if total > 0 else 0.0
    # Ensure total isn't zero before logging division
    correct_str = f"{correct}/{total}" if total > 0 else "0/0"
    logger.info(f"Model {model_id} -- Accuracy: {accuracy:.2f}% ({correct_str})")


    if mismatched_samples:
        logger.warning(f"--- First {len(mismatched_samples)} Mismatched/Failed Samples ({model_id}) ---")
        for sample in mismatched_samples:
            logger.warning(f"Idx: {sample['index']}, Q: '{sample['question']}', Pred: '{sample['prediction']}', Gold: '{sample['gold']}'")
        logger.warning("--- End Mismatched Samples ---")

    # Cleanup
    logger.info(f"Cleaning up resources for {model_id}...")
    del pipe
    # Explicitly clear cache if needed, might help prevent OOM in loops
    torch.cuda.empty_cache()
    logger.info(f"Finished cleanup for {model_id}.")

    return accuracy

# --- Tokenizer Cache Helper ---
tokenizer_cache = {}
def get_tokenizer(model_id_or_path: str, cache_dir: str):
    """Loads and caches tokenizer, ensuring pad token exists."""
    # Normalize path for consistent caching key
    normalized_path = os.path.abspath(model_id_or_path) if os.path.exists(model_id_or_path) else model_id_or_path

    if normalized_path in tokenizer_cache:
        logger.info(f"Using cached tokenizer for {model_id_or_path}")
        return tokenizer_cache[normalized_path]
    else:
        logger.info(f"Loading tokenizer for {model_id_or_path}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id_or_path,
                trust_remote_code=True,
                cache_dir=cache_dir,
                padding_side='left' # Set padding side during load
            )
            # Critical: Ensure pad token and ID are set for batching
            if tokenizer.pad_token is None or tokenizer.pad_token_id is None:
                if tokenizer.eos_token is not None and tokenizer.eos_token_id is not None:
                    logger.warning(f"Tokenizer for {model_id_or_path} lacks pad token/ID. Setting to EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id}).")
                    tokenizer.pad_token = tokenizer.eos_token
                    tokenizer.pad_token_id = tokenizer.eos_token_id
                else:
                    # This is a problem - no EOS token either! Add a default?
                    logger.error(f"CRITICAL: Tokenizer for {model_id_or_path} lacks BOTH pad and EOS tokens. Cannot set default padding. Batching will likely fail.")
                    # You might need to manually add a pad token here if this occurs:
                    # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                    # model.resize_token_embeddings(len(tokenizer)) # If model is loaded separately
                    # For pipeline, this is harder. Best to fix the tokenizer source.
                    return None # Indicate failure

            logger.info(f"Tokenizer loaded. Pad token: '{tokenizer.pad_token}', ID: {tokenizer.pad_token_id}, Padding side: {tokenizer.padding_side}")
            tokenizer_cache[normalized_path] = tokenizer # Cache it using normalized path
            return tokenizer
        except Exception as e:
            logger.error(f"Failed to load tokenizer {model_id_or_path}: {e}", exc_info=True)
            return None

# --- Main Execution Logic ---
def main(args: SimpleNamespace):
    """Orchestrates dataset loading, model evaluation, and result reporting."""
    overall_start_time = time.time()
    results = {}

    # Load and prepare the dataset once
    try:
        dataset_full = get_gsm8k_dataset(cache_dir=args.cache_dir, split="test")
    except Exception:
        logger.error("Failed to load the dataset. Exiting.")
        return

    # Subset the dataset if requested
    total_samples = len(dataset_full)
    eval_size = total_samples
    if 0.0 < args.eval_fraction < 1.0:
        eval_size = max(1, int(total_samples * args.eval_fraction)) # Ensure at least 1
        logger.info(f"Selecting {eval_size} samples ({args.eval_fraction*100:.1f}%) from {total_samples} using seed={SEED}...")
        dataset = dataset_full.shuffle(seed=SEED).select(range(eval_size))
    elif args.eval_fraction == 1.0:
         logger.info(f"Evaluating on the full test set ({total_samples} samples).")
         dataset = dataset_full
    else:
        # This validation is also done at startup, but good to have defense in depth
        logger.error(f"Invalid eval_fraction: {args.eval_fraction}. Must be > 0.0 and <= 1.0. Exiting.")
        return

    if len(dataset) == 0:
        logger.error("Dataset is empty after potential filtering/subsetting. Exiting.")
        return

    # --- Evaluate Base Model ---
    logger.info(f"\n--- Evaluating Base Model: {args.base_model} ---")
    base_tokenizer = get_tokenizer(args.base_model, args.cache_dir)
    if base_tokenizer:
        # Prepare prompts using the *original (potentially subsetted) dataset* and base tokenizer
        base_dataset_prepared = prepare_prompts(dataset, base_tokenizer)
        if len(base_dataset_prepared) > 0:
             base_acc = evaluate_model(
                 args.base_model, base_dataset_prepared, args.batch_size,
                 base_tokenizer, args.debug_sample_output # Pass debug flag
             )
             results["base_model"] = base_acc
        else:
             logger.error("Base dataset preparation resulted in 0 samples. Skipping base model eval.")
             results["base_model"] = "Eval Skipped (0 samples)"
    else:
         logger.error("Skipping base model evaluation due to tokenizer load failure.")
         results["base_model"] = "Eval Failed (Tokenizer Load)"
    # ---------------------------

    # --- Evaluate Checkpoints ---
    logger.info(f"\n--- Evaluating Checkpoints in: {args.output_dir} ---")
    if not os.path.isdir(args.output_dir):
        logger.warning(f"Checkpoint directory not found: {args.output_dir}. Skipping checkpoint evaluation.")
    else:
        checkpoints = []
        try:
            ckpt_dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(args.output_dir, d))]
            # Extract step number, handling potential errors, and sort
            valid_checkpoints = []
            for d in ckpt_dirs:
                try:
                    step = int(d.split("-")[-1])
                    valid_checkpoints.append((step, d))
                except (ValueError, IndexError):
                    logger.warning(f"Could not parse step number from directory name: {d}")
            checkpoints = sorted(valid_checkpoints, key=lambda x: x[0]) # Sort by step number
            logger.info(f"Found {len(checkpoints)} valid checkpoint directories.")
        except OSError as e:
            logger.error(f"Cannot access checkpoint directory {args.output_dir}: {e}", exc_info=True)
            checkpoints = [] # Ensure checkpoints is empty list on error

        if not checkpoints:
             logger.warning(f"No valid 'checkpoint-<number>' directories found in {args.output_dir}.")

        # Simplified Checkpoint Loop
        for step, ckpt_name in checkpoints:
            ckpt_path = os.path.join(args.output_dir, ckpt_name)
            logger.info(f"\n--- Evaluating Checkpoint: {ckpt_name} (Step: {step}) ---")

            # Determine tokenizer path: checkpoint's own or fallback to base
            # Check for tokenizer files within the specific checkpoint directory
            tokenizer_config_path = os.path.join(ckpt_path, "tokenizer_config.json")
            if os.path.exists(tokenizer_config_path):
                tokenizer_path = ckpt_path # Use checkpoint's tokenizer
                logger.info(f"Found tokenizer config in {ckpt_path}. Using checkpoint-specific tokenizer.")
            else:
                tokenizer_path = args.base_model # Fallback to base model tokenizer
                logger.info(f"No tokenizer config in {ckpt_path}. Using base model tokenizer ({args.base_model}).")

            # Load tokenizer (uses cache if path hasn't changed effectively)
            ckpt_tokenizer = get_tokenizer(tokenizer_path, args.cache_dir)
            if not ckpt_tokenizer:
                logger.error(f"Failed to load tokenizer from {tokenizer_path}. Skipping checkpoint {ckpt_name}.")
                results[ckpt_name] = "Eval Failed (Tokenizer Load)"
                continue # Skip this checkpoint

            # Prepare prompts using the *original (potentially subsetted) dataset* and the determined tokenizer
            # This ensures prompts are always correctly formatted for the specific model/tokenizer being evaluated
            logger.info(f"Preparing prompts for {ckpt_name} using tokenizer: {tokenizer_path}")
            ckpt_dataset_prepared = prepare_prompts(dataset, ckpt_tokenizer)

            if len(ckpt_dataset_prepared) == 0:
                logger.error(f"Dataset preparation for {ckpt_name} resulted in 0 samples (using tokenizer {tokenizer_path}). Skipping eval.")
                results[ckpt_name] = "Eval Skipped (0 samples)"
                continue # Skip this checkpoint

            # Evaluate the checkpoint using its specific path and the correctly prepared dataset/tokenizer
            acc = evaluate_model(
                ckpt_path,              # Model path is the checkpoint directory
                ckpt_dataset_prepared,  # Dataset prepared with the right tokenizer
                args.batch_size,
                ckpt_tokenizer,         # The tokenizer object itself
                args.debug_sample_output # Pass debug flag
            )
            results[ckpt_name] = acc
    # --------------------------

    # --- Print Final Summary ---
    print("\n" + "="*60)
    print(" " * 15 + "Overall Evaluation Results Summary")
    print("="*60)
    total_eval_time = time.time() - overall_start_time

    # Print Base Model Result First
    if "base_model" in results:
         base_res = results.pop("base_model") # Remove from dict to handle checkpoints separately
         model_name_disp = f"BASE: {args.base_model}"
         res_str = f"{base_res:.2f}% accuracy" if isinstance(base_res, float) else base_res
         print(f"{model_name_disp:<45}: {res_str}")

    # Print Checkpoint Results, sorted by step number (derived from the key)
    sorted_ckpt_results = sorted(
        [(k, v) for k, v in results.items() if k.startswith("checkpoint-")],
        key=lambda item: int(item[0].split('-')[-1]) # Sort by step number in the key string
    )
    for model_id, acc in sorted_ckpt_results:
        res_str = f"{acc:.2f}% accuracy" if isinstance(acc, float) else acc
        print(f"{model_id:<45}: {res_str}") # model_id is like "checkpoint-1000"

    print("-" * 60)
    logger.info(f"Total evaluation runtime: {total_eval_time:.2f} seconds ({total_eval_time/60:.2f} minutes).")
    print("="*60)

# --- Script Entry Point ---
if __name__ == "__main__":
    # Configuration Setup using SimpleNamespace and Environment Variables
    args = SimpleNamespace()
    args.base_model = os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL)
    args.output_dir = os.environ.get("OUTPUT_DIR", DEFAULT_OUTPUT_DIR)
    args.cache_dir = os.environ.get("CACHE_DIR", DEFAULT_CACHE_DIR)
    try:
        args.eval_fraction = float(os.environ.get("EVAL_FRACTION", DEFAULT_EVAL_FRACTION))
    except ValueError:
        logger.warning(f"Invalid EVAL_FRACTION env var. Using default: {DEFAULT_EVAL_FRACTION}")
        args.eval_fraction = DEFAULT_EVAL_FRACTION
    try:
        args.batch_size = int(os.environ.get("BATCH_SIZE", DEFAULT_BATCH_SIZE))
    except ValueError:
        logger.warning(f"Invalid BATCH_SIZE env var. Using default: {DEFAULT_BATCH_SIZE}")
        args.batch_size = DEFAULT_BATCH_SIZE
    # Read debug flag from environment, converting common "true" values
    debug_env = os.environ.get("DEBUG_SAMPLE_OUTPUT", str(DEFAULT_DEBUG_SAMPLE_OUTPUT)).lower()
    args.debug_sample_output = debug_env in ['true', '1', 'yes', 'on']

    # Validation
    if not 0.0 < args.eval_fraction <= 1.0:
         logger.error(f"Configuration Error: eval_fraction ({args.eval_fraction}) must be > 0.0 and <= 1.0. Adjust script or EVAL_FRACTION env var.")
         exit(1) # Exit on invalid config
    if args.batch_size <= 0:
        logger.error(f"Configuration Error: batch_size ({args.batch_size}) must be positive. Adjust script or BATCH_SIZE env var.")
        exit(1) # Exit on invalid config

    # Log Final Configuration
    logger.info("--- Final Configuration ---")
    logger.info(f"Base Model         : {args.base_model}")
    logger.info(f"Checkpoints Dir    : {args.output_dir}")
    logger.info(f"Cache Dir          : {args.cache_dir}")
    logger.info(f"Eval Fraction      : {args.eval_fraction}")
    logger.info(f"Batch Size         : {args.batch_size}")
    logger.info(f"Debug Sample Output: {args.debug_sample_output} (Logs 1st sample/batch if True)")
    logger.info(f"Torch Dtype        : {TORCH_DTYPE}")
    logger.info(f"Attention Impl.    : {ATTN_IMPLEMENTATION}")
    logger.info(f"Max New Tokens     : {GENERATION_MAX_NEW_TOKENS}")
    logger.info("-------------------------")

    main(args)

Applying template: 100%|██████████| 263/263 [00:00<00:00, 6898.22it/s]
Loading checkpoint shards:  75%|███████▌  | 3/4 [00:15<00:05,  5.01s/it]


KeyboardInterrupt: 

In [None]:
!pip install flash-attn --no-build-isolation

In [None]:
# Use a pipeline as a high-level helper
from transformers import pipeline

SYSTEM_PROMPT = """
Respond in the following format, you have to adhere to the format, only output the final answer without **ANY** additional information in the "answer" box.

<think>
...
</think>
<answer>
...
</answer>
"""

messages = [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {"role": "user", "content": "what is 2+2?"},
]
pipe = pipeline("text-generation", model="Qwen/Qwen2.5-7B-Instruct-1M", device_map='auto')
pipe(messages)

In [None]:
!pip install flash-attn

In [None]:
!module load cuda/12.4.1

In [None]:
!which nvcc

In [None]:
# folio_evaluation_revised.py

import os
import json
import numpy as np
import torch
import logging  # Use logging for better feedback
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer  # No PeftModel needed here
import re
import glob
import time  # Import time for potential timing if needed

# --- Flash Attention Check ---
try:
    import flash_attn
    _flash_attn_available = True
    logging.info("Flash Attention 2 available. Will use if applicable.")
except ImportError:
    _flash_attn_available = False
    logging.warning("Flash Attention 2 not found. Install with `pip install flash-attn --no-build-isolation` for potential speedup.")

# --- Configuration Constants ---
BASE_MODEL_PATH = "Qwen/Qwen2.5-7B-Instruct-1M"
TRAINING_OUTPUT_DIR = "cai6307-henrykobs/model/Qwen-7B-GRPO"
CACHE_DIR = "cai6307-henrykobs/cache"
EVALUATION_ROOT_DIR = "./folio_checkpoint_evaluation_revised_changed_prompt"  # Updated output dir name
BATCH_SIZE = 8  # Adjust based on memory
MAX_LENGTH = 1536  # Input context window
MAX_NEW_TOKENS = 1024  # Max output tokens
TORCH_DTYPE = torch.bfloat16  # Consistent dtype
ATTN_IMPLEMENTATION = "flash_attention_2" if _flash_attn_available else "eager"

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Hugging Face Login ---
try:
    os.makedirs(CACHE_DIR, exist_ok=True)
    login(token="")
    logger.info("Hugging Face login successful.")
except Exception as e:
    logger.warning(f"Hugging Face login failed: {e}. Ensure HF_TOKEN is set. Public models might still work.")

# --- System Prompt (FOLIO Task) ---
SYSTEM_PROMPT = """Respond in the following format, you have to adhere to the format, only output the final answer without **ANY** additional information in the "answer" box.

<think>
...
</think>
<answer>
...
</answer>
"""

# --- Helper Functions (Dataset Loading, Answer Extraction, Prompt Prep) ---

def load_folio_dataset():
    """Load the FOLIO dataset validation split."""
    try:
        logger.info("Loading FOLIO dataset (validation split)...")
        dataset = load_dataset("yale-nlp/FOLIO", split="validation", cache_dir=CACHE_DIR)
        logger.info(f"Loaded {len(dataset)} validation examples from FOLIO dataset.")
        return dataset
    except Exception as e:
        logger.error(f"Error loading FOLIO dataset: {e}", exc_info=True)
        return None

def extract_answer_with_status(response):
    """Parse model responses for FOLIO (True/False/Uncertain)."""
    response = response.strip()
    answer_pattern = re.compile(r'<answer>(.*?)</answer>', re.DOTALL | re.IGNORECASE)
    match = answer_pattern.search(response)

    if match:
        answer = match.group(1).strip().lower()
        if answer == "true":
            return "True", "TAG_SUCCESS"
        if answer == "false":
            return "False", "TAG_SUCCESS"
        if answer == "uncertain":
            return "Uncertain", "TAG_SUCCESS"
        words = answer.split()
        if "true" in words:
            return "True", "TAG_FALLBACK"
        if "false" in words:
            return "False", "TAG_FALLBACK"
        if "uncertain" in words:
            return "Uncertain", "TAG_FALLBACK"

    response_lower = response.lower()
    if re.search(r'[.>:\s]\s*(true)\s*$', response_lower):
        return "True", "REGEX_FALLBACK"
    if re.search(r'[.>:\s]\s*(false)\s*$', response_lower):
        return "False", "REGEX_FALLBACK"
    if re.search(r'[.>:\s]\s*(uncertain)\s*$', response_lower):
        return "Uncertain", "REGEX_FALLBACK"

    words = response_lower.split()
    if "true" in words:
        return "True", "KEYWORD_FALLBACK"
    if "false" in words:
        return "False", "KEYWORD_FALLBACK"
    if "uncertain" in words:
        return "Uncertain", "KEYWORD_FALLBACK"

    return "Uncertain", "DEFAULT_UNCERTAIN"

def prepare_prompt(example, tokenizer):
    """Format FOLIO example into chat prompt."""
    premises = example["premises"]
    conclusion = example["conclusion"]
    user_message = f"""Given the following premises, determine if the conclusion is True, False, or Uncertain in the answer box.

Premises:
{premises}

Conclusion:
{conclusion}"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message}
    ]
    try:
        if not tokenizer.chat_template and tokenizer.default_chat_template:
            tokenizer.chat_template = tokenizer.default_chat_template
            # logger.debug("Applied default chat template to tokenizer.")  # Use debug if too verbose
        elif not tokenizer.chat_template:
            logger.warning("Tokenizer does not have a chat template. Prompt formatting might be incorrect.")
            prompt_string = SYSTEM_PROMPT + "\n" + user_message + "\nAssistant:"
            return prompt_string

        prompt_string = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return prompt_string
    except Exception as e:
        logger.error(f"Error applying chat template: {e}", exc_info=True)
        try:
            prompt_string = SYSTEM_PROMPT + "\n" + user_message + "\nAssistant:"
            logger.warning("Falling back to basic prompt concatenation due to template error.")
            return prompt_string
        except Exception as e_fallback:
            logger.error(f"Fallback prompt formatting also failed: {e_fallback}")
            return "Error: Prompt formatting failed"

# --- Tokenizer Loading Helper (Adopted from Script 2) ---
tokenizer_cache = {}

def get_tokenizer(model_id_or_path: str, base_model_path_fallback: str, cache_dir: str):
    """Loads and caches tokenizer, checking checkpoint dir first, ensures pad token."""
    tokenizer_path_to_load = base_model_path_fallback  # Default to base
    load_source = "base model fallback"

    # Check if model_id_or_path is a directory (i.e., a checkpoint)
    if os.path.isdir(model_id_or_path):
        # Check if tokenizer files exist within the checkpoint directory
        tokenizer_config_path = os.path.join(model_id_or_path, "tokenizer_config.json")
        if os.path.exists(tokenizer_config_path):
            tokenizer_path_to_load = model_id_or_path  # Use checkpoint's tokenizer path
            load_source = "checkpoint directory"
            logger.info(f"Found tokenizer config in {model_id_or_path}. Will load tokenizer from checkpoint.")
        else:
            logger.info(f"No tokenizer config in {model_id_or_path}. Will load tokenizer from base model: {base_model_path_fallback}.")

    # Use the absolute path (if local) or ID as the cache key
    cache_key = os.path.abspath(tokenizer_path_to_load) if os.path.exists(tokenizer_path_to_load) else tokenizer_path_to_load

    if cache_key in tokenizer_cache:
        logger.info(f"Using cached tokenizer (originally from {load_source}): {tokenizer_path_to_load}")
        return tokenizer_cache[cache_key]
    else:
        logger.info(f"Loading tokenizer from {load_source}: {tokenizer_path_to_load}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path_to_load,
                trust_remote_code=True,
                cache_dir=cache_dir,
                padding_side='left'  # Set padding side during load
            )
            # Critical: Ensure pad token and ID are set for batching
            if tokenizer.pad_token_id is None:
                if tokenizer.eos_token_id is not None:
                    logger.warning(f"Tokenizer lacks pad_token_id. Setting to eos_token_id ({tokenizer.eos_token_id}).")
                    tokenizer.pad_token_id = tokenizer.eos_token_id
                    tokenizer.pad_token = tokenizer.eos_token  # Also set the token string
                else:
                    # Add a pad token if EOS is also missing (less common but possible)
                    logger.error("CRITICAL: Tokenizer lacks BOTH pad and EOS tokens. Adding '[PAD]'. Model resizing needed.")
                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                    # The model resize will happen after model loading

            logger.info(f"Tokenizer loaded. Pad token: '{tokenizer.pad_token}', ID: {tokenizer.pad_token_id}, Padding side: {tokenizer.padding_side}")
            tokenizer_cache[cache_key] = tokenizer  # Cache it
            return tokenizer
        except Exception as e:
            logger.error(f"Failed to load tokenizer {tokenizer_path_to_load}: {e}", exc_info=True)
            return None  # Indicate failure

# --- Generation Function (REVISED loading logic) ---

def generate_predictions_for_checkpoint(
    model_path,  # Path to checkpoint dir OR base model ID
    dataset,
    batch_size,
    max_length,
    max_new_tokens,
    checkpoint_output_dir,
    cache_dir
):
    """
    Loads model (base or checkpoint+adapter) automatically using from_pretrained
    and generates predictions.
    """
    model = None
    tokenizer = None
    all_completions = []

    logger.info(f"\n--- Evaluating: {model_path} ---")
    logger.info(f"Using dtype={TORCH_DTYPE}, attn='{ATTN_IMPLEMENTATION}'")

    try:
        # --- Load Tokenizer (using helper function) ---
        tokenizer = get_tokenizer(model_path, BASE_MODEL_PATH, cache_dir)
        if not tokenizer:
            raise ValueError(f"Failed to load tokenizer for {model_path}. Cannot proceed.")

        # --- Load Model (Simplified - AutoModel handles adapters) ---
        logger.info(f"Loading model using AutoModelForCausalLM from: {model_path}")
        # This will load the base model + adapter if model_path is a checkpoint dir with adapter_config.json
        # OR it will load just the base model if model_path is the base model ID.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",  # Let HF handle device placement
            # low_cpu_mem_usage=True,  # Generally good for large models
            cache_dir=cache_dir,
            trust_remote_code=True,  # Often needed for custom model code (like Qwen)
            attn_implementation=ATTN_IMPLEMENTATION  # Use flash attention if available
        )

        # --- Resize embeddings AFTER model load ---
        # Important if tokenizer added tokens (e.g., pad token)
        # Compare model vocab size with tokenizer vocab size
        if model.config.vocab_size != len(tokenizer):
            logger.warning(f"Model vocab size ({model.config.vocab_size}) != Tokenizer vocab size ({len(tokenizer)}). Resizing model embeddings.")
            model.resize_token_embeddings(len(tokenizer))
            # Check if vocab size matches after resize
            if model.config.vocab_size != len(tokenizer):
                logger.error("Resizing token embeddings failed! Vocab sizes still mismatch.")
        else:
            logger.info("Model and tokenizer vocab sizes match.")

        # --- Set Pad Token ID in Model Config ---
        # Crucial for generation if tokenizer's pad_token_id was initially None
        if model.config.pad_token_id is None:
            # Check if tokenizer has a valid pad_token_id first
            if tokenizer.pad_token_id is not None:
                logger.warning(f"Model config lacks pad_token_id. Setting from tokenizer: {tokenizer.pad_token_id}")
                model.config.pad_token_id = tokenizer.pad_token_id
            else:
                # This case should be handled by get_tokenizer, but double-check
                logger.error("CRITICAL: Model config and tokenizer both lack pad_token_id even after tokenizer loading. Batch generation may fail.")

        model.eval()
        # *** THIS IS THE CORRECTED LINE ***
        logger.info(f"Model loaded onto device(s): {model.device if hasattr(model, 'device') else 'Distributed (device_map=auto)'}")

        # --- Prepare Prompts ---
        logger.info("Preparing prompts for FOLIO task...")
        prompts = []
        prompt_errors = 0
        for i, example in enumerate(dataset):
            prompt_str = prepare_prompt(example, tokenizer)
            if prompt_str == "Error: Prompt formatting failed":
                logger.warning(f"Skipping example {i} due to prompt formatting error.")
                prompts.append(None)  # Placeholder for filtering
                prompt_errors += 1
            else:
                prompts.append({"prompt": prompt_str})

        # Filter out errored prompts if any
        valid_prompts = [p["prompt"] for p in prompts if p is not None]
        original_indices = [i for i, p in enumerate(prompts) if p is not None]  # Keep track if needed, not strictly used below currently
        logger.info(f"Prepared {len(valid_prompts)} valid prompts ({prompt_errors} errors).")

        if not valid_prompts:
            raise ValueError("No valid prompts could be generated. Cannot proceed.")

        # --- Batch Generation ---
        logger.info(f"Generating completions in batches of {batch_size}...")
        generated_outputs = ["Error: Generation not run"] * len(valid_prompts)  # Initialize placeholders based on valid prompts count

        progress_bar = tqdm(range(0, len(valid_prompts), batch_size), desc=f"Generating {os.path.basename(model_path)}")
        for i in progress_bar:
            batch_prompts = valid_prompts[i: i + batch_size]

            try:
                inputs = tokenizer(
                    batch_prompts,
                    return_tensors="pt",
                    padding=True,  # Pad batch to the longest sequence
                    truncation=True,
                    max_length=max_length  # Max input length
                ).to(model.device)  # Ensure batch is on the same device as model

                # Double-check pad_token_id just before generation
                gen_pad_token_id = model.config.pad_token_id if model.config.pad_token_id is not None else tokenizer.pad_token_id
                if gen_pad_token_id is None:
                    logger.error("Cannot determine pad_token_id for generation. Batching will likely fail.")
                    # Handle error appropriately, maybe skip batch

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,  # Max length for OUTPUT
                        do_sample=False,  # Greedy for evaluation
                        pad_token_id=gen_pad_token_id  # CRUCIAL for batch generation
                    )

                # Decode only generated tokens
                input_length = inputs.input_ids.shape[1]
                decoded_batch_outputs = tokenizer.batch_decode(
                    outputs[:, input_length:],
                    skip_special_tokens=True
                )

                # Store results for this batch into the correct slice of generated_outputs
                for j, completion in enumerate(decoded_batch_outputs):
                    output_index = i + j
                    if output_index < len(generated_outputs):  # Boundary check
                        generated_outputs[output_index] = completion.strip()
                    else:
                        logger.error(f"Index out of bounds ({output_index}) when storing batch generation results. Batch size: {len(decoded_batch_outputs)}, Total valid: {len(generated_outputs)}")

            except Exception as e:
                logger.error(f"\nError during generation for batch starting at index {i}: {e}", exc_info=True)
                # Mark outputs in this batch as failed
                for j in range(len(batch_prompts)):
                    output_index = i + j
                    if output_index < len(generated_outputs):
                        generated_outputs[output_index] = "Error: Generation failed in batch"

        # --- Reconstruct full completions list including errors ---
        # This ensures the final list aligns with the original dataset length
        final_completions_idx = 0
        for i in range(len(prompts)):  # Iterate through original prompt list length (including None placeholders)
            if prompts[i] is None:  # Was a formatting error
                all_completions.append("Error: Prompt formatting failed")
            else:
                # Check if index is valid for generated_outputs
                if final_completions_idx < len(generated_outputs):
                    all_completions.append(generated_outputs[final_completions_idx])
                    final_completions_idx += 1  # Increment only when using a valid output slot
                else:
                    # This indicates a logic error - more prompts were valid than outputs generated/stored
                    logger.error(f"Mismatch during result reconstruction at original index {i}. Expected valid output but none found (valid index {final_completions_idx} >= generated {len(generated_outputs)}).")
                    all_completions.append("Error: Mismatch during result reconstruction")

        # Final length check
        if len(all_completions) != len(dataset):
            logger.error(f"CRITICAL: Final completions length ({len(all_completions)}) does not match dataset length ({len(dataset)}).")

        logger.info(f"Finished generation for {model_path}.")

    except Exception as e:
        logger.error(f"FATAL ERROR during setup or generation for {model_path}: {e}", exc_info=True)
        # Fill remaining completions if needed
        num_expected = len(dataset)
        num_generated = len(all_completions)
        if num_generated < num_expected:
            logger.error(f"Appending {num_expected - num_generated} error messages due to failure.")
            all_completions.extend(["Error: Checkpoint evaluation failed"] * (num_expected - num_generated))
        all_completions = all_completions[:num_expected]  # Ensure correct length

    finally:
        # --- Resource Cleanup ---
        logger.info(f"Cleaning up resources for {model_path}...")
        del model
        del tokenizer
        # Clearing the global tokenizer cache might cause reloads if base tokenizer is used again.
        # Consider if tokenizer_cache needs clearing or if memory usage is acceptable.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("CUDA cache cleared.")

    return all_completions

# --- Analysis and Reporting Functions (Mostly Unchanged) ---

def analyze_extraction_methods(statuses):
    """Analyzes distribution of how answers were extracted."""
    status_counts = {}
    for status in statuses:
        status_counts[status] = status_counts.get(status, 0) + 1
    logger.info("Answer Extraction Method Distribution:")
    total = len(statuses)
    if total == 0:
        logger.info("  No responses to analyze.")
    else:
        for status, count in sorted(status_counts.items()):
            logger.info(f"  {status}: {count} ({count/total*100:.2f}%)")

def analyze_response_distribution(predictions):
    """Analyzes distribution of predicted labels (True/False/Uncertain)."""
    prediction_counts = {}
    for pred in predictions:
        prediction_counts[pred] = prediction_counts.get(pred, 0) + 1
    logger.info("Prediction distribution:")
    total = len(predictions)
    if total == 0:
        logger.info("  No predictions to analyze.")
    else:
        for pred, count in sorted(prediction_counts.items()):
            logger.info(f"  {pred}: {count} ({count/total*100:.2f}%)")

def plot_confusion_matrix(cm, classes, output_path):
    """Generate visualization of confusion matrix."""
    try:
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes, annot_kws={"size": 12})
        plt.xlabel('Predicted Label', fontsize=12)
        plt.ylabel('True Label', fontsize=12)
        plt.title('Confusion Matrix', fontsize=14)
        plt.xticks(fontsize=10)
        plt.yticks(fontsize=10)
        plt.tight_layout()
        plt.savefig(output_path, dpi=150)
        plt.close()
        logger.info(f"Confusion matrix saved to {output_path}")
    except Exception as e:
        logger.error(f"Error plotting confusion matrix: {e}", exc_info=True)

# --- Main Evaluation Orchestration (Simplified Call) ---

def run_evaluation_for_checkpoint(
    model_path,  # Path to checkpoint dir or base model ID
    dataset,
    checkpoint_output_dir,
    cache_dir
):
    """Runs the full evaluation pipeline for a single model/checkpoint."""
    os.makedirs(checkpoint_output_dir, exist_ok=True)
    logger.info(f"Output directory for this run: {checkpoint_output_dir}")

    # 1. Generate Predictions (Simplified call)
    responses = generate_predictions_for_checkpoint(
        model_path=model_path,
        dataset=dataset,
        batch_size=BATCH_SIZE,
        max_length=MAX_LENGTH,
        max_new_tokens=MAX_NEW_TOKENS,
        checkpoint_output_dir=checkpoint_output_dir,
        cache_dir=cache_dir,
    )

    # Save raw responses
    raw_responses_path = os.path.join(checkpoint_output_dir, "raw_responses.json")
    try:
        with open(raw_responses_path, "w") as f:
            json.dump(responses, f, indent=2)
        logger.info(f"Raw responses saved to {raw_responses_path}")
    except Exception as e:
        logger.error(f"Error saving raw responses: {e}", exc_info=True)

    # 2. Parse Predictions and Analyze Extraction
    predictions = []
    extraction_statuses = []
    generation_errors = 0
    logger.info("Parsing predictions...")
    for response in responses:
        if isinstance(response, str) and response.startswith("Error:"):
            pred, status = "Uncertain", "GENERATION_ERROR"
            generation_errors += 1
        else:
            pred, status = extract_answer_with_status(response)
        predictions.append(pred)
        extraction_statuses.append(status)
    if generation_errors > 0:
        logger.warning(f"Encountered {generation_errors} generation/prompt errors during processing.")
    analyze_extraction_methods(extraction_statuses)
    analyze_response_distribution(predictions)

    # 3. Get Ground Truth Labels
    labels = [example["label"] for example in dataset]

    # Length Check
    if len(predictions) != len(labels):
        logger.error(f"FATAL: Mismatch between predictions ({len(predictions)}) and labels ({len(labels)}). Evaluation metrics skipped.")
        results = {
            "error": "Prediction/Label length mismatch",
            "model_path": model_path,
            "num_predictions": len(predictions),
            "num_labels": len(labels)
        }
        results_path = os.path.join(checkpoint_output_dir, "evaluation_results_ERROR.json")
        try:
            with open(results_path, "w") as f:
                json.dump(results, f, indent=2)
            logger.error(f"Error details saved to {results_path}")
        except Exception as e:
            logger.error(f"Failed to save error details: {e}")
        return  # Exit evaluation for this checkpoint

    # 4. Calculate Metrics
    logger.info("Calculating metrics...")
    accuracy = accuracy_score(labels, predictions)
    report = classification_report(labels, predictions, labels=["True", "False", "Uncertain"], output_dict=True, zero_division=0)
    cm = confusion_matrix(labels, predictions, labels=["True", "False", "Uncertain"])

    # Log metrics
    logger.info(f"\nAccuracy: {accuracy:.4f}")
    logger.info("Classification Report (dict):")
    try:
        logger.info(json.dumps(report, indent=2))
    except Exception:
        logger.info(str(report))

    # 5. Save Results
    results = {
        "model_path": model_path,
        "is_checkpoint": model_path != BASE_MODEL_PATH,
        "accuracy": accuracy,
        "classification_report": report,
        "confusion_matrix": cm.tolist(),
        "extraction_method_summary": {status: extraction_statuses.count(status) for status in sorted(set(extraction_statuses))},
        "generation_prompt_errors": generation_errors
    }
    results_path = os.path.join(checkpoint_output_dir, "evaluation_results.json")
    try:
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Evaluation metrics saved to {results_path}")
    except Exception as e:
        logger.error(f"Error saving evaluation metrics: {e}", exc_info=True)

    # 6. Plot Confusion Matrix
    cm_path = os.path.join(checkpoint_output_dir, "confusion_matrix.png")
    plot_confusion_matrix(cm, classes=["True", "False", "Uncertain"], output_path=cm_path)

    # 7. Save Detailed Predictions (Optional but useful)
    logger.info("Saving detailed predictions...")
    prediction_details = []
    for i, (example, pred, label, response, status) in enumerate(zip(dataset, predictions, labels, responses, extraction_statuses)):
        prediction_details.append({
            "example_id": i,  # This is the index within the evaluated subset
            "premises": example["premises"],
            "conclusion": example["conclusion"],
            "true_label": label,
            "predicted_label": pred,
            "extraction_status": status,
            "is_correct": pred == label if status != "GENERATION_ERROR" and label in ["True", "False", "Uncertain"] else None,
            "model_response": response
        })
    details_path = os.path.join(checkpoint_output_dir, "prediction_details.json")
    try:
        with open(details_path, "w") as f:
            json.dump(prediction_details, f, indent=2)
        logger.info(f"Detailed predictions saved to {details_path}")
    except Exception as e:
        logger.error(f"Error saving detailed predictions: {e}", exc_info=True)

    logger.info(f"\n--- Finished evaluation for: {model_path} ---")

# --- Main Execution Logic ---
if __name__ == "__main__":
    logger.info("Starting FOLIO Evaluation Script (Revised Loading)")
    logger.info(f"Base Model ID: {BASE_MODEL_PATH}")
    logger.info(f"Training Checkpoint Dir: {TRAINING_OUTPUT_DIR}")
    logger.info(f"Cache Dir: {CACHE_DIR}")
    logger.info(f"Evaluation Output Root: {EVALUATION_ROOT_DIR}")
    logger.info(f"Batch Size: {BATCH_SIZE}")
    logger.info(f"Torch Dtype: {TORCH_DTYPE}")
    logger.info(f"Attention Impl.: {ATTN_IMPLEMENTATION}")

    os.makedirs(EVALUATION_ROOT_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)

    # --- Load Dataset Once ---
    folio_val_dataset = load_folio_dataset()
    if folio_val_dataset is None:
        logger.error("Exiting due to dataset loading failure.")
        exit(1)

    # --- Identify Models/Checkpoints ---
    models_to_evaluate = []  # Stores paths to load (base ID or checkpoint dir)

    # 1. Add the base model
    if BASE_MODEL_PATH:
        logger.info(f"Adding base model for evaluation: {BASE_MODEL_PATH}")
        models_to_evaluate.append(BASE_MODEL_PATH)
    else:
        logger.warning("No BASE_MODEL_PATH specified, skipping base model evaluation.")

    # 2. Find checkpoint directories
    if os.path.isdir(TRAINING_OUTPUT_DIR):
        checkpoint_pattern = os.path.join(TRAINING_OUTPUT_DIR, "checkpoint-*")
        checkpoint_dirs = glob.glob(checkpoint_pattern)
        valid_checkpoints = []
        for d in checkpoint_dirs:
            if os.path.isdir(d):
                # Basic check for adapter config - essential for auto-loading
                if os.path.exists(os.path.join(d, "adapter_config.json")):
                    try:
                        step = int(d.split('-')[-1])
                        valid_checkpoints.append((step, d))
                    except (ValueError, IndexError):
                        logger.warning(f"Skipping directory - cannot parse step number: {os.path.basename(d)}")
                else:
                    logger.warning(f"Skipping directory - missing 'adapter_config.json': {os.path.basename(d)}")

        # Sort by step number and add paths to list
        valid_checkpoints.sort(key=lambda x: x[0])
        if valid_checkpoints:
            logger.info(f"Found {len(valid_checkpoints)} valid checkpoints in {TRAINING_OUTPUT_DIR}:")
            for step, path in valid_checkpoints:
                logger.info(f"  - Adding checkpoint: {os.path.basename(path)} (Step: {step})")
                models_to_evaluate.append(path)
        else:
            logger.warning(f"No valid checkpoint directories with 'adapter_config.json' found in {TRAINING_OUTPUT_DIR}.")
    else:
        logger.warning(f"Training output directory '{TRAINING_OUTPUT_DIR}' not found. Skipping checkpoint evaluation.")

    # --- Run Evaluation Loop ---
    if not models_to_evaluate:
        logger.error("No models or checkpoints found to evaluate. Exiting.")
        exit(1)

    logger.info(f"\nStarting evaluation for {len(models_to_evaluate)} model(s)/checkpoint(s)...")

    for model_path in models_to_evaluate:
        # Determine output directory name
        if model_path == BASE_MODEL_PATH:
            model_name = BASE_MODEL_PATH.replace("/", "__") + "_BASE"
        else:
            model_name = os.path.basename(model_path)  # e.g., "checkpoint-1000"
        checkpoint_output_dir = os.path.join(EVALUATION_ROOT_DIR, model_name)

        # Run evaluation
        run_evaluation_for_checkpoint(
            model_path=model_path,
            dataset=folio_val_dataset,
            checkpoint_output_dir=checkpoint_output_dir,
            cache_dir=CACHE_DIR,
        )
        logger.info("-" * 70)  # Separator

    logger.info("\nAll evaluations complete!")
    logger.info(f"Results saved in subdirectories under: {EVALUATION_ROOT_DIR}")


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.21s/it]
Generating Qwen2.5-7B-Instruct-1M: 100%|██████████| 26/26 [04:49<00:00, 11.14s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.15s/it]
Generating checkpoint-100: 100%|██████████| 26/26 [09:35<00:00, 22.12s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.75s/it]
Generating checkpoint-200:  69%|██████▉   | 18/26 [07:44<03:57, 29.73s/it]

In [12]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import os
import warnings
import logging

# --- Configuration ---
# <<<--- REPLACE THESE --->>>
CHECKPOINT_PATH = "cai6307-henrykobs/model/Qwen-7B-GRPO/checkpoint-400" # Path to your specific checkpoint directory OR a Hub model ID
BASE_MODEL_FALLBACK = "Qwen/Qwen2.5-7B-Instruct-1M" # Base model used for training (important fallback for tokenizer)
CACHE_DIR = "cai6307-henrykobs/cache"          # Optional: Directory to cache downloaded models/tokenizers
# <<<-------------------->>>

# --- Advanced Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Use bf16 for Ampere GPUs (>=8.0), fp16 for older GPUs, or fp32 for CPU
if torch.cuda.is_available():
    if torch.cuda.get_device_capability()[0] >= 8:
        TORCH_DTYPE = torch.bfloat16
        print("Using torch.bfloat16 (Ampere+ GPU detected)")
    else:
        TORCH_DTYPE = torch.float16
        print("Using torch.float16 (Older GPU detected)")
else:
    TORCH_DTYPE = torch.float32
    print("Using torch.float32 (CPU detected)")

# Optional: Quantization (reduces memory usage, might affect performance slightly)
# Set USE_QUANTIZATION = True to enable 4-bit quantization
USE_QUANTIZATION = False # Set to True to try quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=TORCH_DTYPE, # Compute dtype should match TORCH_DTYPE
    bnb_4bit_use_double_quant=True,
) if USE_QUANTIZATION else None

# Optional: Flash Attention 2 (requires compatible GPU and installation)
try:
    import flash_attn
    _flash_attn_available = True
    ATTN_IMPLEMENTATION = "flash_attention_2"
    print("Flash Attention 2 available. Will use.")
except ImportError:
    _flash_attn_available = False
    ATTN_IMPLEMENTATION = "eager" # Default attention mechanism
    print("Flash Attention 2 not found. Using 'eager' attention.")

# Generation parameters
MAX_NEW_TOKENS = 512 # Max tokens the model should generate in one turn
TEMPERATURE = 0.6    # Controls randomness (higher = more random)
TOP_P = 0.9          # Nucleus sampling probability
DO_SAMPLE = True     # Whether to use sampling; False uses greedy decoding

# --- Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Hugging Face Login (Optional) ---
# from huggingface_hub import login
# try:
#     login(token="hf_YOUR_TOKEN_HERE") # Replace or use CLI login
#     logger.info("Hugging Face login successful.")
# except Exception as e:
#     logger.warning(f"Hugging Face login failed or token not provided: {e}. Public models might still work.")

# --- Helper Function: Load Tokenizer ---
# (tokenizer loading function remains the same as before)
def get_chat_tokenizer(model_id_or_path, base_model_fallback, cache_dir):
    """Loads tokenizer, prioritizing the checkpoint dir, ensures pad token."""
    tokenizer_path_to_load = base_model_fallback
    load_source = "base model fallback"

    # Check if model_id_or_path is a directory with tokenizer files
    if os.path.isdir(model_id_or_path):
        if os.path.exists(os.path.join(model_id_or_path, "tokenizer_config.json")):
            tokenizer_path_to_load = model_id_or_path
            load_source = "checkpoint directory"
            logger.info(f"Found tokenizer config in {model_id_or_path}. Loading tokenizer from checkpoint.")
        else:
            logger.info(f"No tokenizer config in {model_id_or_path}. Attempting load from base model: {base_model_fallback}.")
    else:
        # It's likely a Hub ID, attempt to load directly
        tokenizer_path_to_load = model_id_or_path
        load_source = "provided ID/path"
        logger.info(f"Attempting to load tokenizer directly from: {model_id_or_path}")


    logger.info(f"Loading tokenizer from {load_source}: {tokenizer_path_to_load}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path_to_load,
            trust_remote_code=True,
            cache_dir=cache_dir,
            padding_side='left' # Important for generation
        )
        # Ensure pad token is set
        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token_id is not None:
                logger.warning(f"Tokenizer lacks pad_token_id. Setting to eos_token_id ({tokenizer.eos_token_id}).")
                tokenizer.pad_token_id = tokenizer.eos_token_id
                tokenizer.pad_token = tokenizer.eos_token
            else:
                # Add a pad token if EOS is also missing
                logger.error("CRITICAL: Tokenizer lacks BOTH pad and EOS tokens. Adding '[PAD]'. This might require model resizing if not done during training.")
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                # Note: Model resizing would need to happen after model load if this occurs

        logger.info(f"Tokenizer loaded successfully. Pad token ID: {tokenizer.pad_token_id}")

        # Check for chat template
        if not tokenizer.chat_template and not tokenizer.default_chat_template:
            logger.warning("Tokenizer does not have a chat_template or default_chat_template defined. Formatting might be incorrect.")
        elif not tokenizer.chat_template:
             tokenizer.chat_template = tokenizer.default_chat_template
             logger.info("Using default_chat_template for formatting.")

        return tokenizer
    except Exception as e:
        logger.error(f"Failed to load tokenizer from {tokenizer_path_to_load}: {e}", exc_info=True)
        if load_source != "base model fallback" and base_model_fallback:
            logger.warning(f"Falling back to loading tokenizer from base model: {base_model_fallback}")
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    base_model_fallback,
                    trust_remote_code=True,
                    cache_dir=cache_dir,
                    padding_side='left'
                )
                # Repeat pad token check for fallback
                if tokenizer.pad_token_id is None:
                    if tokenizer.eos_token_id is not None:
                        tokenizer.pad_token_id = tokenizer.eos_token_id
                        tokenizer.pad_token = tokenizer.eos_token
                    else:
                         tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                logger.info(f"Fallback tokenizer loaded. Pad token ID: {tokenizer.pad_token_id}")
                if not tokenizer.chat_template and tokenizer.default_chat_template:
                     tokenizer.chat_template = tokenizer.default_chat_template
                return tokenizer
            except Exception as e_fallback:
                 logger.error(f"Fallback tokenizer loading also failed: {e_fallback}", exc_info=True)
                 return None
        else:
             return None


# --- Main Chat Logic ---
if __name__ == "__main__":
    logger.info(f"Starting interactive chat session.")
    logger.info(f"Loading checkpoint/model from: {CHECKPOINT_PATH}")
    logger.info(f"Using device: {DEVICE}")
    if USE_QUANTIZATION:
        logger.info("Quantization enabled (4-bit).")

    # 1. Load Tokenizer
    tokenizer = get_chat_tokenizer(CHECKPOINT_PATH, BASE_MODEL_FALLBACK, CACHE_DIR)
    if not tokenizer:
        logger.error("Could not load tokenizer. Exiting.")
        exit(1)

    # Define stop tokens based on the tokenizer
    stop_token_ids = []
    if tokenizer.eos_token_id is not None:
        stop_token_ids.append(tokenizer.eos_token_id)

    # Add other common chat stop tokens if they exist and are different
    im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in stop_token_ids:
         stop_token_ids.append(im_end_token_id)
         logger.info(f"Adding '<|im_end|>' (ID: {im_end_token_id}) to stop tokens.")

    logger.info(f"Stop token IDs for generation: {stop_token_ids}")


    # 2. Load Model
    logger.info("Loading model...")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT_PATH,
            torch_dtype=TORCH_DTYPE,
            device_map=DEVICE if not USE_QUANTIZATION else None, # device_map not recommended with quantization directly here
            quantization_config=quantization_config, # Apply quantization if enabled
            attn_implementation=ATTN_IMPLEMENTATION,
            trust_remote_code=True,
            cache_dir=CACHE_DIR,
        )
        if DEVICE == "cpu" and USE_QUANTIZATION:
             logger.warning("Quantization is primarily for GPU acceleration and memory saving. Performance on CPU might be suboptimal.")
        elif USE_QUANTIZATION and DEVICE =="cuda":
             logger.info("Model loaded with 4-bit quantization on GPU.")
        elif DEVICE == "cuda":
            # If not using quantization and on GPU, ensure model is on the correct device
            if not hasattr(model, 'hf_device_map'): # Check if device_map placed it
                 model.to(DEVICE)
                 logger.info(f"Model loaded to {DEVICE} with dtype {TORCH_DTYPE}.")
        else:
             logger.info(f"Model loaded to {DEVICE} with dtype {TORCH_DTYPE}.")


        # Ensure model pad token ID is set
        if model.config.pad_token_id is None:
            if tokenizer.pad_token_id is not None:
                logger.warning(f"Model config lacks pad_token_id. Setting from tokenizer: {tokenizer.pad_token_id}")
                model.config.pad_token_id = tokenizer.pad_token_id
            else:
                 logger.error("Model config and tokenizer both lack pad_token_id. Generation might fail.")


        # Check for vocab resize necessity
        if model.config.vocab_size != len(tokenizer):
            logger.warning(f"Model vocab size ({model.config.vocab_size}) != Tokenizer vocab size ({len(tokenizer)}). Resizing model embeddings.")
            model.resize_token_embeddings(len(tokenizer))
            if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
                 model.config.pad_token_id = tokenizer.pad_token_id


        model.eval() # Set model to evaluation mode

    except Exception as e:
        logger.error(f"Failed to load the model: {e}", exc_info=True)
        exit(1)

    # 3. Initialize Conversation History
    messages = []
    # <<< CHANGE START >>>
    # Set the specific system prompt using a multi-line string
    system_prompt = """Respond in the following format, you have to adhere to the format, only output the final answer without **ANY** additional information in the "answer" box.

<think>
...
</think>
<answer>
...
</answer>"""
    if system_prompt: # Only add if system_prompt is not empty
        messages.append({"role": "system", "content": system_prompt})
        # Log only the first line for brevity if the prompt is long
        logger.info(f"System prompt set (first line): '{system_prompt.splitlines()[0]}...'")
    # <<< CHANGE END >>>


    logger.info("\nModel loaded. Type your message or 'quit' to exit.")
    print("-" * 30)

    # 4. Interaction Loop
    while True:
        try:
            # a. Get user input
            user_input = input("You: ")

            # b. Check for quit command
            if user_input.strip().lower() == 'quit':
                print("Exiting chat.")
                break

            # c. Add user input to history
            messages.append({"role": "user", "content": user_input})

            # d. Apply chat template
            try:
                prompt_string = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True # Crucial! Signals model's turn
                )
            except Exception as e:
                logger.error(f"Error applying chat template: {e}. Using basic concatenation.")
                # Basic fallback (might not work well with chat models)
                prompt_string = ""
                # Manually add system prompt if fallback used and it exists
                if messages and messages[0]['role'] == 'system':
                    prompt_string += f"system: {messages[0]['content']}\n"
                    # Add remaining messages
                    prompt_string += "\n".join([f"{m['role']}: {m['content']}" for m in messages[1:]])
                else:
                    prompt_string += "\n".join([f"{m['role']}: {m['content']}" for m in messages])
                prompt_string += "\nassistant:" # Add the prompt for the assistant's turn


            # e. Tokenize the prompt
            inputs = tokenizer(prompt_string, return_tensors="pt", padding=True).to(DEVICE)

            # f. Generate response
            logger.info("Generating response...")
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    eos_token_id=stop_token_ids,
                    pad_token_id=tokenizer.pad_token_id, # Use tokenizer's pad_token_id
                    do_sample=DO_SAMPLE,
                    temperature=TEMPERATURE if DO_SAMPLE else None, # Only use temp if sampling
                    top_p=TOP_P if DO_SAMPLE else None,           # Only use top_p if sampling
                )

            # g. Decode the generated part of the output
            output_tokens = outputs[0, inputs.input_ids.shape[1]:]
            model_response = tokenizer.decode(output_tokens, skip_special_tokens=True)

            # h. Print the response
            print("-" * 30)
            print(f"Assistant: {model_response}")
            print("-" * 30)


            # i. Add model response to history
            messages.append({"role": "assistant", "content": model_response})

            # Optional: Limit history length
            max_history = 10 # Keep last 10 turns (5 user, 5 assistant) + system prompt if used
            if len(messages) > max_history + (1 if messages and messages[0]['role']=='system' else 0):
                logger.debug("Trimming conversation history.")
                offset = 1 if messages and messages[0]['role']=='system' else 0
                # Keep system prompt + last max_history items
                messages = messages[0:offset] + messages[-(max_history):]


        except KeyboardInterrupt:
            print("\nExiting chat.")
            break
        except Exception as e:
            logger.error(f"An error occurred: {e}", exc_info=True)
            # break # Uncomment to exit on error

Using torch.bfloat16 (Ampere+ GPU detected)
Flash Attention 2 not found. Using 'eager' attention.


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.25s/it]


------------------------------


You:  Given the following premises, determine if the conclusion is True, False, or Uncertain in the answer box.  Premises:When the Monkeypox virus occurs in a being, it may get Monkeypox. Monkeypox virus can occur in certain animals. Humans are mammals. Mammals are animals. Symptoms of Monkeypox include fever, headache, muscle pains, and tiredness. People feel tired when they get the flu.   Conclusion:No one gets the flu. 


------------------------------
Assistant: <think>
Premises provided:
1. If the Monkeypox virus occurs in a being, that may get Monkeypox.
2. Monkeypox virus can occur in certain animals.
3. Humans are mammals.
4. Mammals are animals.
5. Symptoms of Monkeypox include fever, headache, muscle pains, and tiredness.
6. People feel tired when they get the flu.
Conclusion: No one gets the flu.

To analyze this, we need to check if there's any logical connection between getting Monkeypox and not getting the flu based on the given premises. 

Premise 5 states that people feel tired when they get the flu, while premise  tells us that symptoms of Monkeypox include tiredness. However, tiredness alone does not necessarily imply that someone has the flu or Monkeypox; it could be due to other reasons. The premises do not establish a direct relationship or exclusivity between feeling tired and having the flu or Monkeypox.

Therefore, the premises do not provide enough information to conclude that no o

In [None]:
Given the following premises, determine if the conclusion is True, False, or Uncertain. Output only one of the 3 options.  Premises:When the Monkeypox virus occurs in a being, it may get Monkeypox. Monkeypox virus can occur in certain animals. Humans are mammals. Mammals are animals. Symptoms of Monkeypox include fever, headache, muscle pains, and tiredness. People feel tired when they get the flu.   Conclusion:No one gets the flu. 