# In [None]:
import os
import re
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---------- Config ----------
# Hugging Face checkpoint used to judge bluffs (another compatible model works too).
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# Use the GPU when torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Generation settings.
BATCH_SIZE = 8        # records per generate() call; halved by the main loop on OOM
MAX_NEW_TOKENS = 5    # passed to generate() as max_new_tokens
TEMPERATURE = 0.2     # small nonzero helps reduce "always No" bias

# Files written by the run.
OUT_FILE = "bluff_labeled.jsonl"
CHECKPOINT_FILE = "bluff_checkpoint.json"

# The main loop writes a checkpoint each time this many batches have completed.
SAVE_EVERY_BATCHES = 20

# Load your data here: one JSON object per line of llm_raise.jsonl.
# Blank lines (such as a trailing empty line at the end of the file) are
# skipped, because json.loads on an empty string raises and would abort the run.
with open("llm_raise.jsonl", "r", encoding="utf-8") as f:
    train_data = [json.loads(line) for line in f if line.strip()]

# ---------- Model Setup ----------
print(f"Loading model {MODEL_ID}...")
# Left padding keeps each prompt adjacent to the tokens generated after it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
# The main loop tokenizes whole batches with padding=True, which requires a
# pad token. Some causal-LM tokenizers ship without one; fall back to the EOS
# token in that case so batching works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)
# Inference only: switch off dropout and other training-time behaviour.
model.eval()
print("Model loaded.\n")

# ---------- Helper Functions ----------
def save_progress(index, path=None):
    """Persist the index of the next record to process.

    The checkpoint is the JSON object ``{"last_index": index}``. It is written
    to a temporary sibling file and then moved into place with ``os.replace``,
    so an interruption in the middle of the write cannot leave a truncated
    checkpoint behind.

    Args:
        index: Index to store under ``"last_index"``.
        path: Checkpoint file to write. Defaults to ``CHECKPOINT_FILE``.
    """
    target = CHECKPOINT_FILE if path is None else path
    tmp_path = f"{target}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump({"last_index": index}, f)
    # Atomic rename: readers see either the old or the new checkpoint.
    os.replace(tmp_path, target)

def load_progress(path=None):
    """Return the record index to resume from.

    Args:
        path: Checkpoint file to read. Defaults to ``CHECKPOINT_FILE``.

    Returns:
        The value stored under ``"last_index"``, or 0 when the checkpoint file
        does not exist or has no such key.

    Raises:
        ValueError: If the stored index is not a non-negative integer. Such a
            value would otherwise be used directly as a slice start by the
            caller and silently select the wrong records.
        json.JSONDecodeError: If the checkpoint file is not valid JSON.
    """
    target = CHECKPOINT_FILE if path is None else path
    try:
        with open(target, "r", encoding="utf-8") as f:
            state = json.load(f)
    except FileNotFoundError:
        # No checkpoint yet: start from the beginning.
        return 0

    index = state.get("last_index", 0)
    # bool is a subclass of int, so it is excluded explicitly.
    if isinstance(index, bool) or not isinstance(index, int) or index < 0:
        raise ValueError(f"Invalid last_index in checkpoint {target!r}: {index!r}")
    return index

def append_jsonl(outfile, records):
    """Append every record to *outfile* as one JSON document per line.

    The file is opened in append mode as UTF-8, and non-ASCII characters are
    written as-is rather than as ``\\u`` escapes.
    """
    with open(outfile, "a", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(entry, ensure_ascii=False) + "\n" for entry in records
        )

def clean_instruction(instruction: str) -> str:
    """
    Keep only the part of the instruction before "Now it is your turn".

    Everything from the first blank-line-separated "Now it is your turn"
    onward is dropped. If the original instruction contains
    "pot size is <number> chips", a line restating that pot size is appended
    to the kept text. The result is stripped of surrounding whitespace.
    """
    head = re.split(r"\n\nNow it is your turn.*", instruction, maxsplit=1)[0]

    # The pot size is searched in the full instruction, not just the kept
    # part, so it is still found when it appears in the dropped section.
    pot = re.search(r"pot size is ([\d.]+) chips", instruction)
    if pot is None:
        return head.strip()

    return f"{head}\nThe current pot size is {pot.group(1)} chips.".strip()

def build_bluff_prompt(instruction, llm_move):
    """
    Compose the Yes/No bluff-classification prompt for one record.

    The prompt consists of two worked examples, a one-sentence definition of
    a bluff, the cleaned game description, the player's action, and the
    question itself. The action is the first whitespace-separated token of
    *llm_move*, or "Raise" when *llm_move* is blank.
    """
    move_tokens = llm_move.split()
    action = move_tokens[0] if move_tokens else "Raise"
    scenario = clean_instruction(instruction)

    examples = (
        "Example 1:\n"
        "Pot is 50 chips. Player has a strong hand (Ace, King on Ace-high board).\n"
        "Player raises.\n"
        "Is this a bluff? No\n"
        "\n"
        "Example 2:\n"
        "Pot is 50 chips. Player has 7-2 offsuit on Ace-high board and raises after opponent bets.\n"
        "Is this a bluff? Yes\n"
    )
    definition = (
        "Definition: A bluff is when a player bets or raises with a weak hand "
        "in order to make the opponent fold a better hand."
    )
    question = (
        f"The player decided to: {action}.\n"
        "Is this a bluff? Reply with only 'Yes' or 'No'."
    )

    # Sections are separated by one blank line; the examples block already
    # ends with a newline, so it is followed by two blank lines.
    return f"{examples}\n\n{definition}\n\n{scenario}\n\n{question}"

def normalize_yes_no(text: str) -> str:
    """
    Normalize LLM output to just 'Yes' or 'No'.

    The reply counts as an answer when its first word is "yes" or "no"
    (case-insensitive), optionally preceded by quotes, punctuation or
    whitespace. Matching a whole word keeps replies such as "Not sure" or
    "Nothing" from being labeled "No", and "Yesterday" from being labeled
    "Yes".

    Returns:
        'Yes' or 'No' when the reply starts with that word; otherwise the
        stripped, lower-cased reply so unclear outputs stay visible.
    """
    cleaned = text.strip().lower()
    # [\W_]* skips leading quotes/markup such as "'yes'" or "**no**";
    # \b requires the word to end right after "yes"/"no".
    match = re.match(r"[\W_]*(yes|no)\b", cleaned)
    if match:
        return "Yes" if match.group(1) == "yes" else "No"
    return cleaned  # fallback if unclear

# ---------- Main Loop ----------
# Label the records batch by batch, starting at the checkpointed index.
# Results are appended to OUT_FILE after every batch; the checkpoint is written
# every SAVE_EVERY_BATCHES batches and whenever the loop exits, normally or not.
# NOTE(review): if the process is killed without running the handlers below,
# the checkpoint can lag behind OUT_FILE and a resume would re-append records.
start_idx = load_progress()
N = len(train_data)
print(f"Bluff labeling: {N} records, starting from index {start_idx}")

i = start_idx
batches_done = 0

try:
    while i < N:
        end = min(i + BATCH_SIZE, N)
        batch = train_data[i:end]

        prompts = [build_bluff_prompt(rec["instruction"], rec["llm_move"]) for rec in batch]

        # Debug preview
        if batches_done == 0:
            print("\n--- DEBUG EXAMPLE ---")
            print("Prompt:\n", prompts[0][:800])
            print("--------------------\n")

        try:
            # NOTE(review): prompts longer than max_length are truncated; confirm
            # which side is cut, since the question sits at the end of the prompt.
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            # Padded prompt length, shared by every row of the batch.
            prompt_len = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    # Use the pad token the tokenizer actually padded with.
                    pad_token_id=tokenizer.pad_token_id,
                    do_sample=TEMPERATURE > 0,
                    temperature=TEMPERATURE if TEMPERATURE > 0 else None,
                )

            # generate() returns prompt tokens followed by the new tokens. Slice
            # the prompt off by position instead of comparing decoded strings:
            # decoding does not always reproduce the prompt text exactly, and a
            # failed comparison would turn the whole prompt into the label.
            completions = tokenizer.batch_decode(
                outputs[:, prompt_len:], skip_special_tokens=True
            )

            replies = [completion.strip() for completion in completions]
            answers = [normalize_yes_no(reply) for reply in replies]
            records = [
                {**orig, "is_bluff": answer}
                for orig, answer in zip(batch, answers)
            ]

            append_jsonl(OUT_FILE, records)

            # Debug sample: all three values refer to the first record of the batch.
            if batches_done < 2:
                print("--- DEBUG OUTPUT ---")
                print(f"Raw Model Output:\n{completions[0]}\n")
                print(f"Cleaned Model Response: {replies[0]}")
                print(f"Final is_bluff: {answers[0]}")
                print("--------------------\n")

            i = end
            batches_done += 1

            if batches_done % SAVE_EVERY_BATCHES == 0:
                save_progress(i)
                print(f"Checkpoint saved at index {i}")

        except RuntimeError as e:
            err_str = str(e).lower()
            print(f"Runtime error at batch starting {i}: {e}")
            if "out of memory" in err_str or "cuda" in err_str:
                torch.cuda.empty_cache()
                # Already at the smallest batch: nothing left to shrink.
                if BATCH_SIZE <= 1:
                    raise
                old_bs = BATCH_SIZE
                BATCH_SIZE = max(1, BATCH_SIZE // 2)
                print(f"Reducing batch size from {old_bs} to {BATCH_SIZE} and retrying at index {i}")
                # i was not advanced, so the same records are retried.
                continue
            else:
                raise

except KeyboardInterrupt:
    print("Interrupted by user — saving progress.")
    save_progress(i)
    raise

except Exception as e:
    print(f"Unexpected error: {e} — saving progress at index {i}")
    save_progress(i)
    raise

else:
    print("Completed all prompts. Final checkpointing...")
    save_progress(i)