# Change baseline

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "3"

model_name = "Qwen/QwQ-32B"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,            # enable 4-bit (QLoRA-style) weights
    bnb_4bit_quant_type="nf4",    # NF4 gives the best accuracy for most LLMs
    bnb_4bit_use_double_quant=True, # optional: second quantisation pass to save ~0.4 bits/param
    bnb_4bit_compute_dtype=torch.bfloat16  # faster matmuls on recent GPUs; fall back to float16 if needed
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",      # let Accelerate split layers across all visible GPUs
    quantization_config=bnb_config,
    torch_dtype="auto",     # keeps non-linear layers in their original dtype
    trust_remote_code=True  # Qwen models need their custom code
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 14/14 [01:04<00:00,  4.60s/it]


In [None]:
prompt = "How many r's are in the word \"strawberry\""
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32768
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

# Reward Shaping

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import re
from typing import List, Optional
from tqdm import tqdm
import random

# Project‑level helpers 
from utils import _sanitize, _numeric_equiv, _strip_markup, _to_float, system_prompt
from ..prm_dataset.config import PRMConfig

class MCReward:
    STEP_PATTERN = re.compile(
    r"""^[\s>#*\-]*          # optional markdown/bullet symbols
        Step\s*              # word 'Step' (case-insensitive)
        (\d+)                # capture step number
        \s*[:.\-]            # separator (: . or -)
    """,
    re.IGNORECASE | re.VERBOSE,
    )
    ANSWER_PATTERN = re.compile(
        r"""^[\s>#*\-]*          # optional markdown/bullet symbols
            Answer               # word 'Answer'
            \s*[:.\-]\s*         # separator
            (.+?)\s*$            # capture everything after
        """,
        re.IGNORECASE | re.MULTILINE | re.VERBOSE,
    )
    ## Masked rewards ##
    OP_TOKENS = ["add", "plus", "sum", "subtract", "minus",
             "multiply", "times", "product", "divide", "quotient"]
    _MASK_PATTERN = re.compile(
        r"""
        (?:
        # {ops_pattern}|                # operator patterns
            \b\d+(?:\.\d+)?\b         # integers / decimals
          | \b\d+/\d+\b                 # simple fractions
        #   | \b[a-zA-Z]\b                 # single‑letter variables
        )
        """,
        re.VERBOSE,
    )

    def __init__(self, config: "PRMConfig", model, tokenizer):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    # Function to generate one or more step-by-step solutions for a given question.
    def gsm8k_solutions(self, question: str, gold_solution: str):
        # 1. Split lines *before* the final answer marker (#### …)
        lines: List[str] = []
        gold_answer: str = ""
        _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")

        for raw_ln in gold_solution.splitlines():
            ln = raw_ln.strip()
            if not ln:
                continue  # skip empty
            ans_match = _ANSWER_RE.match(ln)
            if ans_match:
                gold_answer = ans_match.group(1).strip()
                break  # everything after #### is ignored
            lines.append(ln)

        if not gold_answer:
            raise ValueError("Could not find final answer marker '#### <answer>' in gold_solution.")

        # 2. Prefix each explanatory line with "Step i:"
        solution_steps = [f"Step {i + 1}: {txt}" for i, txt in enumerate(lines)]
        return {
            "question": question,
            "solution": solution_steps,
            "gold_answer": gold_answer,
        }

    # Function to parse a solution text into steps and final answer.
    def _extract_answer(self, text: str) -> Optional[str]:
        """Try multiple heuristics / regexes to pull out an answer string."""
        # Primary regex (robust to Answer:, Answer ‑, etc.)
        match = self.ANSWER_PATTERN.search(text)
        if match:
            return _sanitize(match.group(1))
        
        # Fallback 1: last non‑empty line if it looks simple / numeric
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if lines:
            candidate = lines[-1]
            if re.search(r"\d", candidate):  # contains digit
                return _sanitize(candidate)

        # Fallback 2: look for last line that starts with 'Answer'
        for line in reversed(text.splitlines()):
            if line.strip().lower().startswith("answer"):
                return _sanitize(line.split("Answer", 1)[-1])
        
        return None

    def parse_solution(self, solution_text: str):
        """Split each step to start with 'Step X:' and the answer to start with 'Answer:'."""
        steps = []
        # Split by lines to identify steps and answer
        for line in solution_text.splitlines():
            line = line.strip()
            if not line:
                continue
            if self.STEP_PATTERN.match(line):
                cleaned = re.sub(r'^[\s>#*\-]+', '', line)
                steps.append(cleaned)
            answer = self._extract_answer(solution_text)
        return steps, answer
    
    # Function to estimate intermediate rewards for each step via rollouts.
    def compute_step_rewards(self, question, sys_prompt, steps, gold_answer):
        """
        For each prefix ending at a given step in 'steps', generate rollouts and compute the reward 
        (fraction of rollouts ending in the correct answer). Returns a list of reward values corresponding to each step.
        """
        rewards = []
        total_steps = len(steps)

        # Pre‑encode static prefix (sys_prompt + question) once for efficiency
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            prefix_tokens = self.tokenizer.encode("\n".join(steps[: i + 1]) + "\n", return_tensors="pt").to(self.device) # steps up to current step i (0-indexed)
            # Decide how to prompt the next part:
            if i < total_steps - 1:
                next_label = f"Step {i + 2}:"
            else:
                next_label = "Answer:"
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)
            # Build full prefix ids (avoid Python concat inefficiency by cat)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)
            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.8,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id
            )
            new_token_start = prefix_ids.shape[-1] 
            # Check each rollout's final answer against the gold answer
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                # print(f"[{i+1}-th Step, {idx}-th Original Rollout]", completion, "Pred Answer", pred_answer)
                if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
            reward = correct_count / float(self.config.num_rollouts)
            rewards.append(reward)
        return rewards
    
    # Masked solution paths
    def model_masking(self, text: str, *, max_new_tokens: int = 64) -> str:
        prompt = "In the sentence below, mask any word or expression that seems crucial for solving the math step. This may include key numbers, variables, or action words (like operations), but you should decide what matters. Replace each important item with '[MASKED]'. Keep everything else unchanged. Return ONE line.\n\nSentence: \"{sent}\"\nRewritten:".format(sent=text)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        out_ids   = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.2, top_p=0.2,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(out_ids[0][input_ids.shape[-1]:],
                                     skip_special_tokens=True).strip()

    def perturbed_step_rewards(self, question: str, sys_prompt: str, steps: List[str], gold_answer: str, use_llm: bool = True) -> List[float]:
        """Compute MC correctness rates *after masking* the current step.
        Each step `i` is replaced with a *perturbed* version where important tokens (numbers, fractions, single‑letter variables) are substituted by the literal string ``[MASKED]``. All preceding steps remain intact.
        """
        ptb_rewards: List[float] = []
        total_steps = len(steps)
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            # 1. Perturb *only* step i
            orig_step = steps[i] 
            step_match = re.match(r"^[\s>#*\-]*Step\s*\d+\s*[:.\-]\s*", orig_step, flags=re.I)
            prefix = step_match.group(0) if step_match else ""
            # ② 나머지 부분(body)만 마스킹
            body   = steps[i][len(prefix):]                       # 접두사 뒷부분
            if use_llm:
                masked_body = self.model_masking(body)
            else:
                masked_body = self._MASK_PATTERN.sub("[MASKED]", body)
            # ③ 접두사 + 마스킹된 body
            masked_step = prefix + masked_body    
            ptb_prefix_steps = steps[:i] + [masked_step]
            # print("perturbed step:", ptb_prefix_steps)

            prefix_tokens = self.tokenizer.encode("\n".join(ptb_prefix_steps) + "\n", return_tensors="pt").to(self.device)
            next_label = f"Step {i + 2}:" if i < total_steps - 1 else "Answer:"
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)

            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.8,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            new_token_start = prefix_ids.shape[-1]
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                # print(f"Masked [{i+1}-th Step, {idx}-th Rollout]", completion, "Pred Answer", pred_answer)
                if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
            ptb_rewards.append(correct_count / float(self.config.num_rollouts))
        return ptb_rewards

    def _generate_rollouts(self, prompt: str) -> list[str]:
        """
        vLLM 에서 동일 프롬프트를 n번(=num_rollouts) 샘플링해서 텍스트만 반환
        """
        outs = self.llm.generate([prompt], self.sparams)   # 배치 길이 1
        return [o.outputs[ri].text for o in outs for ri in range(len(o.outputs))]

    # Build datasets based on input datas
    def build_datasets_gsm8k(self, *, split: str = "train", start: int = 0, take: int | None):
        _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")

        rollout_pr = system_prompt("rollout")
        ds = load_dataset("openai/gsm8k", "main", split=split)
        if take is not None:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, start+take))
        else:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, len(ds)))

        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []
        
        for sample in tqdm(ds, desc="Building GSM-8K reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            q_txt   = sample["question"]
            g_sol   = sample["answer"]
            lines, gold_ans = [], None
            for ln in g_sol.splitlines():
                ln = ln.strip()
                if not ln:
                    continue
                m = _ANSWER_RE.match(ln)
                if m:
                    gold_ans = sanitize(m.group(1))
                    break
                lines.append(ln)
            if gold_ans is None:
                raise ValueError("gold answer not found for sample")
            steps = [f"Step {i+1}: {t}" for i, t in enumerate(lines)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(q_txt, rollout_pr, steps, gold_ans)
            ptb = psr(q_txt, rollout_pr, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            #  ── (3) Append entry ───────────────────────────────────────────
            entry = {
                    "question":      q_txt,
                    "completion":    steps,
                    "ori_rewards":   ori,
                    "ptb_rewards":   ptb,
                    "contributions": contrib,
                    "answer":        gold_ans,
                    "gold_answer":   gold_ans,
                }
            dataset.append(entry)
            # print(entry)
        return dataset

    def build_datasets_math(self, *, split: str = "train", start: int = 0, take: int | None):
        boxed_re   = re.compile(r'\\boxed\{(.+?)\}', re.S)
        sent_split = re.compile(r'\.(?!\d)(?=\s|$)')   # 소수점·수식 내부 마침표 무시

        rollout_prompt = system_prompt("rollout")
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)

        # shuffle & take
        if take is not None:
            ds = ds.select(range(start, start+take))
        else:
            ds = ds.select(range(start, len(ds)))

        # (alias) time optimize
        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []

        for sample in tqdm(ds, desc="Building MATH reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            full_sol   = sample["solution"]
            m          = boxed_re.search(full_sol)
            gold_ans   = sanitize(m.group(1)) if m else None
            sol_wo_box = boxed_re.sub("", full_sol)
            raw_steps  = [s.strip() for s in sent_split.split(sol_wo_box) if s.strip()]
            steps      = [f"Step {i+1}: {s}" for i, s in enumerate(raw_steps)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(sample["problem"], rollout_prompt, steps, gold_ans)
            ptb = psr(sample["problem"], rollout_prompt, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            # ── (3) Append entry ───────────────────────────────────────────
            entry = {
                "question":      sample["problem"],
                "completion":    steps,
                "ori_rewards":   ori,
                "ptb_rewards":   ptb,
                "contributions": contrib,
                "answer":        gold_ans,
                "gold_answer":   gold_ans,
                "level":         sample["level"],
                "type":          sample["type"],
            }
            dataset.append(entry)
            print(entry)
        return dataset

    def _sequence_nll(self, prompt: str, target: str) -> float:
        with torch.no_grad():
            full = prompt + target
            inputs = tokenizer(full, return_tensors="pt").to(device)
            prompt_len = len(tokenizer(prompt)["input_ids"])
            
            # Mask out the prompt tokens so loss is only on the target
            labels = inputs["input_ids"].clone()
            labels[:, :prompt_len] = -100          # ignore index
            logits = model(**inputs).logits
            loss   = F.cross_entropy(
                        logits.view(-1, logits.size(-1)),
                        labels.view(-1),
                        reduction="none"
                    )
            # keep only target positions
            target_loss = loss[labels.view(-1) != -100]
            return (target_loss.sum() / torch.log(torch.tensor(2.0))).item()  # bits
        
    def _info_gain(self, context, step, answer):
        no_step = self._sequence_nll(context, answer)
        with_step = self._sequence_nll(context + step, answer)
        return no_step - with_step

    def _step_entropy(self, context, step):
        """Cross-entropy of a step sequence (bits)."""
        return self._sequence_nll(context, step)

In [101]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import functional as F
import math

# MODEL_NAME = "Qwen/Qwen1.5-7B-Chat"      # or your qwen2.5-math-7b path
device = "cuda" if torch.cuda.is_available() else "cpu"

def sequence_nll_ver1(prompt: str, target: str) -> float:
    """Cross-entropy (in bits) of 'target' tokens given the 'prompt'."""
    with torch.no_grad():
        prefix, value = target.split("Answer:")
        ans_prefix, ans_value = "Answer: ", value.lstrip()
        prompt = prompt + ans_prefix
        full = prompt + ans_value

        # full = prompt + target
        inputs = tokenizer(full, return_tensors="pt").to(device)
        prompt_len = len(tokenizer(prompt)["input_ids"])
        
        # Mask out the prompt tokens so loss is only on the target
        labels = inputs["input_ids"].clone()
        labels[:, :prompt_len] = -100          # ignore index
        
        # 디버그: 정답 부분만 확인
        answer_part = labels[0, prompt_len:]
        valid_tokens = answer_part[answer_part != -100]
        print(f"Answer tokens: {tokenizer.decode(valid_tokens)}")
        print(f"Answer token count: {len(valid_tokens)}")
        
        logits = model(**inputs).logits
        loss   = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    reduction="none"
                 )
        # keep only target positions
        target_loss = loss[labels.view(-1) != -100]
        print(f"Per-token losses: {target_loss.tolist()}")

        for i, tok_id in enumerate(valid_tokens):
            tok = tokenizer.decode([tok_id])
            prob = torch.softmax(logits[0, -len(valid_tokens)+i], dim=-1)[tok_id].item()
            print(f"{tok!r}  p={prob:.3e},  -ln p={-math.log(prob):.3f}")

        return (target_loss.sum() / torch.log(torch.tensor(2.0))).item()  # bits

def sequence_nll(prompt: str, target: str):
    """NLL(bits) of `target` given `prompt`. The prompt **does not** contain the target portion."""
    with torch.no_grad():
        prefix, value = target.split("Answer:")
        ans_prefix, ans_value = "Answer: ", value.lstrip()
        prompt = prompt + ans_prefix
        full = prompt + ans_value
        # full = prompt + target
        full_ids    = tokenizer(full, return_tensors="pt").to(device)["input_ids"]
        prompt_len  = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])

        labels = full_ids.clone()
        labels[:, :prompt_len] = -100          # ignore prompt tokens

        logits = model(full_ids).logits
        loss   = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    reduction="sum"             # total bits, not mean
                 ) / math.log(2)                # nats → bits
        return loss.item()

def info_gain(context, step, answer):
    """I(S;A|c) = H(A|c) - H(A|c,S)."""
    no_step = sequence_nll(context, answer)
    with_step = sequence_nll(context + step, answer)
    return no_step - with_step

def step_entropy(context, step):
    """Cross-entropy of a step sequence (bits)."""
    return sequence_nll(context, step)

def show_topk(prompt, k=5):
    ids = tokenizer(prompt, return_tensors="pt").to(device)["input_ids"]
    with torch.no_grad():
        logits = model(ids).logits[0, -1]
    probs, idx = torch.topk(torch.softmax(logits, dim=-1), k)
    print([ (tokenizer.decode([i]), float(p)) for p,i in zip(probs, idx) ])


In [103]:
# context = """Solve the given problem with step by step reasoning and write final answer in the format of "Answer: <answer>". Problem: What is the sum of the digits of the number 84?\n"""
# step1   = "Step 1: The tens digit of 84 is 8.\n"
# step2   = "Step 2: The ones digit of 84 is 4.\n"
# step3   = "Step 3: Add the digits: 8 + 4 = 12.\n"
# answer  = "\nAnswer:\n12"

context = """Solve the given problem with step by step reasoning and write final answer in the format of "Answer: <answer>". Problem: What is (5 + 3) × 2 - 4?\n"""
step1   = "Step 1: Compute inside the parentheses first: 5 + 3 = 8.\n"
step2   = "Step 2: Multiply the result by 2: 8 × 2 = 16.\n"
step3   = "Step 3: Subtract 4: 16 - 4 = 12.\n"
# step3   = "Step 3: What should I do next?\n"
answer  = "Answer: 12"

# context = """Solve the given problem with step by step reasoning and write final answer in the format of "Answer: <answer>". Problem: What is 1/2 + 1/4 + 1/4?\n"""
# step1   = "Step 1: Add 1/4 and 1/4 first: 1/4 + 1/4 = 1/2.\n"
# step2   = "Step 2: Now add 1/2 + 1/2 = 1.\n"
# step3   = "Step 3: Final result is 1.\n"
# answer  = "Answer: 1"

# print("Answer tokenized:", tokenizer.tokenize(answer))
# print("Answer IDs:", tokenizer.encode(answer, add_special_tokens=False))
# show_topk(context)                  # before any step
# show_topk(context + step1)          # after Step 1
# show_topk(context + step1 + step2)  # etc.
# show_topk(context + step1 + step2 + step3) 

print("H(Answer|context) ≈", step_entropy(context, answer), "bits")
print("H(Answer|context+Step1) ≈", step_entropy(context + step1, answer), "bits")
print("H(Answer|context+Step1+Step2) ≈", step_entropy(context + step1 + step2, answer), "bits")
print("H(Answer|context+Step1+Step2+Step3) ≈", step_entropy(context + step1 + step2 + step3, answer), "bits")
print("Information gain of Step1:", info_gain(context, step1, answer), "bits")
print("Information gain of Step2:", info_gain(context + step1, step2, answer), "bits")
print("Information gain of Step3:", info_gain(context + step1 + step2, step3, answer), "bits")

# prompt_ids = tokenizer(context + step1 + step2, return_tensors="pt").to(device)["input_ids"]
# out = model.generate(
#         prompt_ids,
#         max_new_tokens=30,
#         return_dict_in_generate=True,
#         output_scores=True,
#         temperature=0.2,     # deterministic
#         do_sample=True
#      )

# gen_ids   = out.sequences[0, prompt_ids.size(-1):]  # 생성된 부분
# scores    = out.scores                              # 길이 = #gen_tokens

# # 토큰별 −log2 p 계산
# nll_bits = 0.0
# for t, (logits, tok_id) in enumerate(zip(scores, gen_ids)):
#     probs = logits.squeeze(0).softmax(dim=-1)
#     tok_id = tok_id.item()
#     p     = probs[tok_id].item()
#     nll_bits += -math.log2(p)
#     print(f"{t:02d}  {tokenizer.decode([tok_id])}  p={p:.3e}  −log2 p={-math.log2(p):.3f}")

# print(f"Total NLL(bits) for generated segment = {nll_bits:.3f}")

H(Answer|context) ≈ 25.375 bits
H(Answer|context+Step1) ≈ 28.5 bits
H(Answer|context+Step1+Step2) ≈ 30.875 bits
H(Answer|context+Step1+Step2+Step3) ≈ 25.25 bits
Information gain of Step1: -3.125 bits
Information gain of Step2: -2.375 bits
Information gain of Step3: 5.625 bits


In [None]:
def nll_bits(prompt: str, target: str) -> float:
    """Average NLL of 'target' given 'prompt', in **bits** per target token."""
    with torch.no_grad():
        full   = prompt + target
        inputs = tokenizer(full, return_tensors="pt").to(device)
        Lp     = len(tokenizer(prompt)["input_ids"])

        labels = inputs["input_ids"].clone()
        labels[:, :Lp] = -100          # ignore prompt tokens
        logits = model(**inputs).logits

        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            reduction="sum"            # total nll (nats)
        )
        ntoks = (labels != -100).sum()
        nll_nats = loss.item() / ntoks
        return nll_nats / math.log(2)  # nats → bits

# --------------------------------------------------------------------
context = "Solve the given problem with step by step reasoning and write final answer in the format of \"Answer: <answer>\". Problem: What is (5 + 3) × 2 - 4?\n"
step1   = "Step 1: Compute inside the parentheses first: 5 + 3 = 8.\n"
step2   = "Step 2: Multiply the result by 2: 8 × 2 = 10.\n"
step3   = "Step 3: Subtract 4: 10 - 4 = 6.\n"
answer  = "Answer: 12"

# 1) 엔트로피(=평균 NLL) 계산
nll0 = nll_bits(context, answer)
nll1 = nll_bits(context + step1,                 answer)
nll2 = nll_bits(context + step1 + step2,         answer)
nll3 = nll_bits(context + step1 + step2 + step3, answer)

# 2) mutual information
mi1 = nll0 - nll1
mi2 = nll1 - nll2
mi3 = nll2 - nll3

print(f"H(A|c)        = {nll0:.4f} bits")
print(f"H(A|c,S1)     = {nll1:.4f} bits   →  I(S1;A|c) = {mi1:.4f} bits")
print(f"H(A|c,S1,S2)  = {nll2:.4f} bits   →  I(S2;A|c,S1) = {mi2:.4f} bits (increment)")
print(f"H(A|c,S1,S2,S3)= {nll3:.4f} bits   →  I(S3;A|c,S1,S2)= {mi3:.4f} bits")

H(A|c)        = 16.8074 bits
H(A|c,S1)     = 20.7748 bits   →  I(S1;A|c) = -3.9674 bits
H(A|c,S1,S2)  = 16.3746 bits   →  I(S2;A|c,S1) = 4.4002 bits (increment)
H(A|c,S1,S2,S3)= 18.2501 bits   →  I(S3;A|c,S1,S2)= -1.8755 bits


In [None]:
import math, torch, random
from transformers import AutoTokenizer, AutoModelForCausalLM

LOG2E = 1 / math.log(2)

# ─────────────────────── 1. 기본 NLL 함수 ─────────────────────
def nll_bits(prompt: str, target: str, avg=True) -> float:
    """
    NLL(prompt→target) in bits.
    If avg=True, return *average* bits / target-token,
    else return *total* bits of the sequence.
    """
    full   = prompt + target
    inputs = tokenizer(full, return_tensors="pt").to(device)
    Lp     = len(tokenizer(prompt)["input_ids"])

    labels = inputs["input_ids"].clone()
    labels[:, :Lp] = -100
    with torch.no_grad():
        logits = model(**inputs).logits
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            reduction="sum"          # total nats
        )
    ntoks = (labels != -100).sum()
    bits  = loss.item() * LOG2E      # total bits
    return bits / ntoks if avg else bits

# ─────────────────────── 2-A. MC Sampling 방식 ─────────────────────
def entropy_bits_mc(prompt: str, k: int = 5, max_new: int = 4, temperature: float = 0.7, top_p: float = 0.9) -> float:
    """
    Monte-Carlo estimate of H(A|prompt) [bits per token].
    Generates k continuations, then 평균[-log₂ p(sample | prompt)].
    """
    input_ids = tokenizer(prompt, return_tensors="pt").to(device).input_ids
    BITS = []
    for _ in range(k):
        out_ids = model.generate(
            input_ids,
            max_new_tokens=max_new,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )[0][input_ids.size(1):]           # strip prompt

        sample = tokenizer.decode(out_ids, skip_special_tokens=True)
        BITS.append(entropy_bits_exact(prompt, sample))
    return sum(BITS) / k                  # bits/token

# ─────────────────────── 2-B. Exact 토큰-엔트로피 ─────────────────────
def entropy_bits_exact(prompt: str, target: str) -> float:
    """True H(A|prompt) in bits/token, by ∑_t H(p_t). Memory-intensive: stores full probs tensor."""
    full   = prompt + target
    inputs = tokenizer(full, return_tensors="pt").to(device)
    Lp     = len(tokenizer(prompt)["input_ids"])

    with torch.no_grad():
        logits = model(**inputs).logits.float()      # [1,L,V]

    probs = logits.softmax(-1)                      # [...,V]
    token_H = -(probs * probs.log()).sum(-1) * LOG2E  # bits/token

    mask = torch.zeros_like(inputs["input_ids"], dtype=torch.bool)
    mask[:, Lp:] = True                             # answer tokens
    return token_H[mask].sum().item() / mask.sum().item()

# ─────────────────────── 3. 프롬프트 정의 ─────────────────────
context = """Solve the given problem with step by step reasoning in the format of "Step k: <k-th rationale>" and write final answer in the format of "Answer: <answer>". Problem: What is the sum of the digits of the number 84?\n"""
step1   = "Step 1: The tens digit of 84 is 8.\n"
step2   = "Step 2: The ones digit of 84 is 4.\n"
step3   = "Step 3: Add the digits: 8 + 4 = 12.\n"
answer  = "Answer: 12"

# ─────────────────────── 4. 엔트로피 & MI 계산 ─────────────────────
def mi_report(entropy_fn, label: str):
    H0 = entropy_fn(context,                  answer)            # H(A|c)
    H1 = entropy_fn(context+step1,            answer)            # H(A|c,S1)
    H2 = entropy_fn(context+step1+step2,      answer)            # H(A|c,S1,S2)
    H3 = entropy_fn(context+step1+step2+step3,answer)            # ...
    print(f"\n=== {label} ===")
    print(f"H(A|c)          = {H0:.4f} bits")
    print(f"H(A|c,S1)       = {H1:.4f} bits   →  I(S1;A|c)          = {H0-H1:+.4f}")
    print(f"H(A|c,S1,S2)    = {H2:.4f} bits   →  ΔI(S2|prev)        = {H1-H2:+.4f}")
    print(f"H(A|c,S1,S2,S3) = {H3:.4f} bits   →  ΔI(S3|prev)        = {H2-H3:+.4f}")

# 4-A. MC 샘플링 (k 줄이려면 30~50 도 OK)
mi_report(lambda p,t: entropy_bits_mc(p), "Monte-Carlo (k=5)")

# 4-B. Exact  (target이 2-토큰이므로 메모리 부담 ↓)
mi_report(entropy_bits_exact, "Exact token-entropy")



=== Monte-Carlo (k=5) ===
H(A|c)          = 0.6414 bits
H(A|c,S1)       = 0.0133 bits   →  I(S1;A|c)          = +0.6282
H(A|c,S1,S2)    = 0.4116 bits   →  ΔI(S2|prev)        = -0.3983
H(A|c,S1,S2,S3) = 0.1660 bits   →  ΔI(S3|prev)        = +0.2455

=== Exact token-entropy ===
H(A|c)          = 1.2748 bits
H(A|c,S1)       = 0.7583 bits   →  I(S1;A|c)          = +0.5165
H(A|c,S1,S2)    = 0.7090 bits   →  ΔI(S2|prev)        = +0.0492
H(A|c,S1,S2,S3) = 0.5144 bits   →  ΔI(S3|prev)        = +0.1946


In [122]:
context = """Solve the given problem with step by step reasoning and write final answer in the format of "Answer: <answer>". Probelm: What is the greatest common factor of $20 !$ and $200,\\!000$?  (Reminder: If $n$ is a positive integer, then $n!$ stands for the product $1\\cdot 2\\cdot 3\\cdot \\cdots \\cdot (n-1)\\cdot n$.)"""
steps = [
      "I want to find the largest positive integer that divides both $20 !$ and $200,\\!000$ evenly.",
      "One way to do this is to factor both numbers into prime factors and look for the common ones.",
      "I know that $200,\\!000 = 2^5\\cdot 10^4 = 2^9\\cdot 5^4$.",
      "To find the prime factorization of $20 !$, I can use the fact that it is the product of all the positive integers from $1$ to $20$.",
      "For each prime number $p$ between $1$ and $20$, I can count how many multiples of $p$ are in that range.",
      "For example, there are $10$ multiples of $2$ between $1$ and $20$, namely $2, 4, 6, \\dots, 20$.",
      "But there are also $5$ multiples of $4$, which is $2^2$, and $2$ multiples of $8$, which is $2^3$, and $1$ multiple of $16$, which is $2^4$.",
      "So, the total power of $2$ in $20 !$ is $10 + 5 + 2 + 1 = 18$.",
      "Similarly, there are $4$ multiples of $5$, namely $5, 10, 15, 20$, so the power of $5$ in $20 !$ is $4$.",
      "There are $6$ multiples of $3$, namely $3, 6, 9, \\dots, 18$, but there are also $2$ multiples of $9$, which is $3^2$, so the power of $3$ in $20 !$ is $6 + 2 = 8$.",
      "There are $2$ multiples of $7$, namely $7$ and $14$, so the power of $7$ in $20 !$ is $2$.",
      "There are $1$ multiple of each of the other prime numbers $11, 13, 17$, and $19$, so the powers of those primes in $20 !$ are $1$ each.",
      "Therefore, the prime factorization of $20 !$ is $2^{18}\\cdot 3^8\\cdot 5^4\\cdot 7^2\\cdot 11\\cdot 13\\cdot 17\\cdot 19$.",
      "To find the greatest common factor of $20 !$ and $200,\\!000$, I need to take the lowest power of each common prime factor.",
      "The only common prime factors are $2$ and $5$, and the lowest powers are $9$ and $4$, respectively.",
      "So, the greatest common factor is $2^9\\cdot 5^4 = 512\\cdot 625 = 320,\\!000$.\n\n# Answer\n\n320,000"
    ]

num_steps = []
for i, step in enumerate(steps):
    numbering = f"Step {i+1}: " + step
    num_steps.append(numbering)

for idx in range(len(num_steps)):
  

Step 1: I want to find the largest positive integer that divides both $20 !$ and $200,\!000$ evenly.
Step 2: One way to do this is to factor both numbers into prime factors and look for the common ones.
Step 3: I know that $200,\!000 = 2^5\cdot 10^4 = 2^9\cdot 5^4$.
Step 4: To find the prime factorization of $20 !$, I can use the fact that it is the product of all the positive integers from $1$ to $20$.
Step 5: For each prime number $p$ between $1$ and $20$, I can count how many multiples of $p$ are in that range.
Step 6: For example, there are $10$ multiples of $2$ between $1$ and $20$, namely $2, 4, 6, \dots, 20$.
Step 7: But there are also $5$ multiples of $4$, which is $2^2$, and $2$ multiples of $8$, which is $2^3$, and $1$ multiple of $16$, which is $2^4$.
Step 8: So, the total power of $2$ in $20 !$ is $10 + 5 + 2 + 1 = 18$.
Step 9: Similarly, there are $4$ multiples of $5$, namely $5, 10, 15, 20$, so the power of $5$ in $20 !$ is $4$.
Step 10: There are $6$ multiples of $3$, na

In [77]:
_LOG2E = torch.log2(torch.tensor(2.718281828459045))

def next_token_probs(prompt: str, temperature: float = 1.0) -> torch.Tensor:
    """softmax(logits/τ) over the full vocab, shape (|V|,)"""
    ids = tokenizer(prompt, return_tensors="pt").to(device)["input_ids"]
    with torch.no_grad():
        logits = model(ids).logits[0, -1] / temperature
    return torch.softmax(logits.float(), dim=-1)          # full-vocab probs

def info_gain_kl(context: str, step: str, temperature: float = 1.0) -> float:
    """ I≈KL( P(·|c,step) || P(·|c) ) measured on the distribution of the *first* next token (Answer: 직후). 결과 단위: bits """
    p = next_token_probs(context, temperature)
    q = next_token_probs(context + step, temperature)

    kl_nat = torch.sum(q * (q.log() - p.log()))           # nats
    kl_bits = (kl_nat / _LOG2E).item()                    # convert nats→bits
    return kl_bits                                        # ≥ 0 by definition

# ---------------------------------------------------------------------
# 2) Monte-Carlo 샘플링 기반 정보 이득  (엔트로피 근사)
# ---------------------------------------------------------------------
def sequence_nll(prompt: str, target: str) -> float:
    """ Cross-entropy (bits) of `target` given `prompt`."""
    full = prompt + target
    inputs = tokenizer(full, return_tensors="pt").to(device)
    p_len  = len(tokenizer(prompt)["input_ids"])

    labels = inputs["input_ids"].clone()
    labels[:, :p_len] = -100                      # ignore prompt tokens

    with torch.no_grad():
        logits = model(**inputs).logits
        loss   = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), reduction="none")
    target_loss = loss[labels.view(-1) != -100]
    return (target_loss.sum() / torch.log(torch.tensor(2.0))).item()  # bits

def sample_answers(prompt: str,
                   k: int = 8,
                   max_new_tokens: int = 56,
                   temperature: float = 0.8) -> list[str]:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
                num_return_sequences=k,
                pad_token_id=tokenizer.eos_token_id
              )
    # 잘라서 답변 부분만 디코딩
    answer_tokens = outs[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(answer_tokens, skip_special_tokens=True)

def entropy_mc(prompt: str, k: int = 16, max_new_tokens: int = 32, temperature: float = 0.8) -> float:
    """Monte-Carlo estimate of H(A|prompt) in bits."""
    samples = sample_answers(prompt, k, max_new_tokens, temperature)
    nlls = [sequence_nll(prompt, ans) for ans in samples]
    return sum(nlls) / len(nlls)

def info_gain_mc(context: str, step: str, k: int = 16, max_new_tokens: int = 32, temperature: float = 0.8) -> float:
    """ I ≈ Ĥ(A|c) – Ĥ(A|c,step) via MC sampling. 단위 bits (양수가 정보 이득, 음수면 정보 손실) """
    h_no  = entropy_mc(context, k, max_new_tokens, temperature)
    h_yes = entropy_mc(context + step, k, max_new_tokens, temperature)
    return h_no - h_yes

def answer_token(prompt: str) -> int:
    """
    returns token id of the *first* token the model produces right after 'Answer:'.
    여기선 숫자 한 글자(예: '12' → 'Ġ12' 토큰)라고 가정.
    """
    # 'Answer:' 까지 디코딩 후 generate 1토큰
    ids = tokenizer(prompt, return_tensors="pt").to(device)["input_ids"]
    with torch.no_grad():
        out = model.generate(ids, max_new_tokens=1, do_sample=False)   # greedy
    return int(out[0, -1])

def info_gain_answer_kl(context: str,step: str, correct_answer: str, temperature: float = 1.0) -> tuple[float, float]:
    """
    (1) KL over full vocab  (≥0)
    (2) Δ log2-prob of *gold answer token*  (= contribution sign)
    """
    # 준비: Answer:  프롬프트
    base = context + "Answer:"
    base_with = context + step + "Answer:"

    # 전체 분포
    p = next_token_probs(base, temperature)      # (|V|)
    q = next_token_probs(base_with, temperature)

    # 1) KL  (nats→bits)
    kl_bits = torch.sum(q * (q.log() - p.log())) / _LOG2E

    # 2) 정답 토큰 확률 변화
    # gold_id = answer_token(base)                 # 모델이 '정답'이라고 보는 토큰
    # delta_log2 = (q[gold_id].log2() - p[gold_id].log2()).item()

    gold_id = tokenizer(correct_answer, add_special_tokens=False)["input_ids"][0]
    delta_log2 = (q[gold_id].log2() - p[gold_id].log2()).item()

    return kl_bits.item(), delta_log2


def kl_and_delta_nll(context: str, step: str, answer: str, temperature: float = 1.0) -> tuple[float, float]:
    """Returns (KL_bits_full_answer , Δ-NLL_bits_full_answer)
    • KL = Σ_t  KL( p_with(·|prefix_t) || p_no(·|prefix_t) ) ≥ 0
    • Δ-NLL =  NLL_no − NLL_with   (정답 확률 ↑ → 양수)"""
    # pre-encode once to speed up
    prompt_no   = context + "Answer:"
    prompt_with = context + step + "Answer:"
    ans_ids     = tokenizer(answer, return_tensors="pt").to(device)["input_ids"][0]     # (T,)

    # holders
    total_kl_nat   = 0.0
    total_dlog2    = 0.0            # Δ-log2P over the whole answer

    # build running prefixes for with / no prompts
    ids_no   = tokenizer(prompt_no,   return_tensors="pt").to(device)["input_ids"]
    ids_with = tokenizer(prompt_with, return_tensors="pt").to(device)["input_ids"]

    for next_id in ans_ids:          # iterate over answer tokens
        with torch.no_grad():
            # logits for next token distribution
            logit_no   = model(ids_no).logits[0, -1] / temperature
            logit_with = model(ids_with).logits[0, -1] / temperature

        p_with = torch.softmax(logit_with.float(), dim=-1)            # (|V|)
        p_no   = torch.softmax(logit_no.float(),   dim=-1)

        # ----- KL( p_with || p_no )  (using p_with as 'true' dist) -----
        kl_nat = torch.sum(p_with * (torch.log(p_with) - torch.log(p_no)))
        total_kl_nat += kl_nat.item()

        # ----- Δ log2 P(next_id)  (direction) -----
        log2_with = ( torch.log(p_with[next_id]).item() ) / _LOG2E
        log2_no   = ( torch.log(p_no  [next_id]).item() ) / _LOG2E
        total_dlog2 += (log2_with - log2_no)          # >0 ⇒ 정답 확률↑

        # teacher-force next_id into both prefixes
        ids_no   = torch.cat([ids_no,   next_id.view(1,1)], dim=1)
        ids_with = torch.cat([ids_with, next_id.view(1,1)], dim=1)

    kl_bits   = total_kl_nat / _LOG2E           # nats→bits
    delta_nll = -total_dlog2                  # NLL_no − NLL_with  (bits)

    return kl_bits, delta_nll                 # (≥0 , ±)
    

In [73]:
# context = "Problem: What is (5 + 3) × 2 - 4?\n\n"
# step1   = "Step 1: Compute inside the parentheses first: 5 + 3 = 8.\n"
# step2   = "Step 2: Multiply the result by 2: 8 × 2 = 16.\n"
# step3   = "Step 3: Subtract 4: 16 - 4 = 12.\n"
# answer  = "Answer: 12."

context = "Problem: What is 1/2 + 1/4 + 1/4?\n\n"
step1   = "Step 1: Add 1/4 and 1/4 first: 1/4 + 1/4 = 1/2.\n"
step2   = "Step 2: Now add 1/2 + 1/2 = 1.\n"
step3   = "Step 3: All sum of the proabability is 1.\n"
answer  = " 1."

# context = "Problem: What is the sum of the digits of the number 84?\n\n"
# step1   = "Step 1: The tens digit of 84 is 8.\n"
# step2   = "Step 2: The ones digit of 84 is 4.\n"
# step3   = "Step 3: The subtraction of 8-4 is 4.\n"
# step4   = "Step 4: Add the digits: 8 + 4 = 12.\n"
# answer  = " 12."


for i,(s,ctx) in enumerate([(step1, context),
                            (step2, context+step1),
                            (step3, context+step1+step2),
                            # (step4, context+step1+step2+step3)
                            ], 1):
    kl, dnll = kl_and_delta_nll(ctx, s, answer)
    sign = "↑" if dnll>0 else "↓"
    print(f"Step{i}: KL={kl:6.3f} bits,  Δ-NLL={dnll:+6.3f} bits ({sign}⇒{'help' if dnll>0 else 'hurt'})")

kl1, d1 = info_gain_answer_kl(context, step1)
kl2, d2 = info_gain_answer_kl(context+step1, step2)
kl3, d3 = info_gain_answer_kl(context+step1+step2, step3)
# kl4, d4 = info_gain_answer_kl(context+step1+step2+step3, step4)

print(f"Step1 KL={kl1:.3f}, Δlog₂P(ans)={d1:+.3f}")
print(f"Step2 KL={kl2:.3f}, Δlog₂P(ans)={d2:+.3f}")
print(f"Step3 KL={kl3:.3f}, Δlog₂P(ans)={d3:+.3f}")
# print(f"Step4 KL={kl4:.3f}, Δlog₂P(ans)={d4:+.3f}")

Step1: KL= 1.770 bits,  Δ-NLL=+3.046 bits (↑⇒help)
Step2: KL= 4.294 bits,  Δ-NLL=-3.089 bits (↓⇒hurt)
Step3: KL= 0.458 bits,  Δ-NLL=+1.046 bits (↑⇒help)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Step1 KL=0.483, Δlog₂P(ans)=+0.288
Step2 KL=0.221, Δlog₂P(ans)=+0.023
Step3 KL=0.373, Δlog₂P(ans)=-0.752


In [82]:
examples = [
    # EXAMPLE 1  ──────────────────────────────────────────────────────────
    # {
    #     "name": "Sum-of-digits (84)",
    #     "context": "Problem: What is the sum of the digits of the number 84?\n\n",
    #     "steps": [
    #         "Step 1: The tens digit of 84 is 8.\n",                       # ✅
    #         "Step 2: The ones digit of 84 is 4.\n",                      # ✅
    #         "Step 3: The subtraction of 8-4 is 4.\n",                    # ❌ (irrelevant)
    #         "Step 4: Add the digits: 8 + 4 = 12.\n"                      # ✅
    #     ],
    #     "answer": "Answer: 12."
    # },
    # EXAMPLE 2  ──────────────────────────────────────────────────────────
    {
        "name": "Rectangle area (5 cm × 3 cm)",
        "context": "Problem: A rectangle has length 5 cm and width 3 cm. What is its area?\n\n",
        "steps": [
            "Step 1: The area of a rectangle is length × width.\n",      # ✅
            "Step 2: Add the dimensions: 5 + 3 = 8.\n",                  # ❌ (wrong op)
            "Step 3: Multiply: 5 × 3 = 15 square centimetres.\n"         # ✅
        ],
        "answer": "Answer: 15 cm^2."
    },
    # EXAMPLE 3  ──────────────────────────────────────────────────────────
    {
        "name": "Solve for x (3x + 9 = 18)",
        "context": "Problem: Solve for x: 3x + 9 = 18.\n\n",
        "steps": [
            "Step 1: Subtract 9 from both sides: 3x = 9.\n",             # ✅
            "Step 2: Divide both sides by 3: x = 3.\n",                  # ✅
            "Step 3: Check: 3(3) + 9 = 18, so x = 3 is correct.\n",      # ✅
            "Step 4: Therefore, x = 6.\n"                                # ❌ (contradict)
        ],
        "answer": "Answer: 3."
    },
    # EXAMPLE 4  ──────────────────────────────────────────────────────────
    {
        "name": "Prime factorization (60)",
        "context": "Problem: What is the prime factorization of 60?\n\n",
        "steps": [
            "Step 1: 60 = 6 × 10.\n",                                    # ✅
            "Step 2: 6 = 2 × 3.\n",                                      # ✅
            "Step 3: 10 = 2 × 5.\n",                                     # ✅
            "Step 4: So 60 = 2 × 2 × 3 × 5.\n",                          # ✅
            "Step 5: Combine two 2's into 4, so 60 = 4 × 3 × 5.\n"       # ❌ (not prime factors)
        ],
        "answer": "Answer: 2^2 × 3 × 5."
    },
    # EXAMPLE 5  ──────────────────────────────────────────────────────────
    {
        "name": "Simplify fraction (24/36)",
        "context": "Problem: Simplify the fraction 24/36.\n\n",
        "steps": [
            "Step 1: The GCD of 24 and 36 is 12.\n",                     # ✅
            "Step 2: Divide numerator by 6: 24 ÷ 6 = 4.\n",              # ❌ (wrong divisor)
            "Step 3: Divide numerator and denominator by 12: 24/12 = 2, 36/12 = 3.\n",  # ✅
            "Step 4: The simplified fraction is 2/3.\n"                 # ✅
        ],
        "answer": "Answer: 2/3."
    }
]

for ex in examples:
    print(f"\n=== {ex['name']} ===")
    ctx = ex["context"]
    for idx, step in enumerate(ex["steps"], 1):
        kl, dnll = kl_and_delta_nll(ctx, step, ex["answer"], 0.8)
        label = "help ✅" if dnll > 0 else "hurt ❌"
        print(f"Step {idx:>2}: KL={kl:6.2f} bits,  Δ-NLL={dnll:+6.2f} bits   {label}")
        ctx += step  # 다음 스텝 컨텍스트에 누적

    print(f"→ {ex['answer']}")
    print("-" * 60)


=== Rectangle area (5 cm × 3 cm) ===
Step  1: KL=  2.52 bits,  Δ-NLL= -1.22 bits   hurt ❌
Step  2: KL=  6.95 bits,  Δ-NLL= +5.48 bits   help ✅
Step  3: KL=  3.25 bits,  Δ-NLL= -7.73 bits   hurt ❌
→ Answer: 15 cm^2.
------------------------------------------------------------

=== Solve for x (3x + 9 = 18) ===
Step  1: KL=  3.51 bits,  Δ-NLL= -1.52 bits   hurt ❌
Step  2: KL=  1.56 bits,  Δ-NLL= +3.23 bits   help ✅
Step  3: KL=  0.47 bits,  Δ-NLL= -3.28 bits   hurt ❌
Step  4: KL=  4.04 bits,  Δ-NLL= -0.38 bits   hurt ❌
→ Answer: 3.
------------------------------------------------------------

=== Prime factorization (60) ===
Step  1: KL=  1.95 bits,  Δ-NLL= -4.33 bits   hurt ❌
Step  2: KL=  0.82 bits,  Δ-NLL= +3.37 bits   help ✅
Step  3: KL=  0.33 bits,  Δ-NLL= +1.60 bits   help ✅
Step  4: KL=  0.86 bits,  Δ-NLL= -3.30 bits   hurt ❌
Step  5: KL=  3.05 bits,  Δ-NLL= +3.88 bits   help ✅
→ Answer: 2^2 × 3 × 5.
------------------------------------------------------------

=== Simplify fract

# Using vllm

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "3"

model_name = "Qwen/Qwen2.5-Math-7B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,            # enable 4-bit (QLoRA-style) weights
    bnb_4bit_quant_type="nf4",    # NF4 gives the best accuracy for most LLMs
    bnb_4bit_use_double_quant=True, # optional: second quantisation pass to save ~0.4 bits/param
    bnb_4bit_compute_dtype=torch.bfloat16  # faster matmuls on recent GPUs; fall back to float16 if needed
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",      # let Accelerate split layers across all visible GPUs
    quantization_config=bnb_config,
    torch_dtype="auto",     # keeps non-linear layers in their original dtype
    trust_remote_code=True  # Qwen models need their custom code
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:20<00:00,  5.20s/it]


In [2]:
class PRMConfig:
    # MC config
    model_name:             str = "Qwen/Qwen2.5-Math-7B-Instruct"    # "Qwen/Qwen2.5-Math-7B", "Qwen/Qwen2.5-Math-7B-Instruct" , "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "meta-llama/Llama-3.1-8B"
    max_new_tokens:         int = 512
    num_rollouts:           int = 8      
    use_llm:                bool = True  # Use llm for masking
    reward_type:            str = "contri"  # ori, contri, mi, naive, norm
    # PRM Model config 
    hidden_size:        int = 512      # 256-1024 범위에서 적절
    num_layers:         int = 3        # 2-4 범위에서 적절
    dropout:            float = 0.2    # 0.1-0.3 범위에서 적절
    # PRMTrainer config 
    batch_size:         int = 16       # 12 → 16으로 증가 (더 안정적)
    learning_rate:      float = 3e-4   # 5e-4 → 3e-4로 감소 (더 안정적)
    num_workers:        int = 4        # 적절
    weight_decay:       float = 1e-2   # 적절
    lr_scheduler:       str = "cosine" # 적절
    dataset_size:       int = 0
    warmup_steps:       int = 40       # 22 → 50으로 증가 (더 안정적)
    grad_clip:          float = 1.0    # 적절
    epochs:             int = 20       # 25 → 15로 감소 (early stopping 고려)
    # Misc config
    use_wandb:          bool = True
    wandb_project:      str = "mc_prm"
    run_name:           str = "test_400_0715"
    checkpoint_dir:     str = "./checkpoints/0715/contri"
    seed:               int = 42
    # Inference config
    num_candidates:     int = 4

In [None]:
from mc_shaped_reward import MCRewardShaped

cfg = PRMConfig()
mcrs = MCRewardShaped(config=cfg , model=model, tokenizer=tokenizer)
gsm8k_raw = mcrs.gsm8k_reward_dataset(split="train", start=5001, take=1)
gsm8k_raw

Detected Qwen model: qwen2forcausallm, excluding presence_penalty


Building GSM-8K reward-dataset:   0%|          | 0/1 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


[1-th Step, 0-th Original Rollout]  She charges $20.00 per drawing, so she made 40×20 = <<40×20=800 >>$800.
Answer: \boxed{800}
```

The final answer is \(\boxed{800}\). Pred Answer 800 Gold Answer 800
[1-th Step, 1-th Original Rollout]  She charges $20.00 per drawing, so she made 40×20 = <<40×20=800>>800 dollars
Answer: \boxed{800} To determine how much money Gretchen made over the weekend, we need to follow these steps:

1. Calculate the total number of drawings she sold.
2. Multiply the total number of drawings by her charge per drawing.

First, let's find the total number of drawings she sold. She sold 24 drawings on Saturday and 16 drawings on Sunday. So, we add these two numbers together:
\[ 24 + 16 = 40 \]

Next, we need to find out how much money she made from these 40 drawings. Since she charges $20.00 per drawing, we multiply the total number of drawings by her charge per drawing:
\[ 40 \times 20 = 800 \]

Therefore, the total amount of money Gretchen made over the weekend is

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[2-th Step, 0-th Original Rollout] 800

To determine how much money Gretchen made over the weekend, we need to follow these steps:

1. Calculate the total number of caricatures she drew over the two days.
2. Multiply the total number of caricatures by the price per drawing.

Step 1: She drew 24 caricatures on Saturday and 16 caricatures on Sunday. So, the total number of caricatures she drew is:
\[ 24 + 16 = 40 \]

Step 2: She charges $20.00 per drawing. Therefore, the total amount of money she made is:
\[ 20 \times 40 = 800 \]

Thus, the total amount of money Gretchen made over the weekend is:
\[
\boxed{800}
\] Pred Answer None Gold Answer 800
[2-th Step, 1-th Original Rollout] 800

To determine how much money Gretchen made, we need to follow these steps:

1. Calculate the total number of caricatures she drew over the weekend.
2. Multiply the total number of caricatures by the price per drawing.

Step 1: Calculate the total number of caricatures she drew.
Gretchen drew 24 caricatures 

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


perturbed step: ['Step 1: She drew 24 on Saturday and 16 on Sunday for a total of 24+16 = <<24+16=40>>40 drawings', 'Step 2: "She charges [MASKED] per drawing and she drew [MASKED] caricatures so she made [MASKED]*[MASKED] = [MASKED]"\n\nTo solve the problem, we need to identify the key components of the sentence that are necessary to determine the total amount of money she made. These']


Building GSM-8K reward-dataset: 100%|██████████| 1/1 [00:57<00:00, 57.12s/it]


[{'question': 'Gretchen draws caricatures in the park on the weekends.  She charges $20.00 per drawing.  If she sold 24 on Saturday and 16 on Sunday, how much money did she make?',
  'completion': ['Step 1: She drew 24 on Saturday and 16 on Sunday for a total of 24+16 = <<24+16=40>>40 drawings',
   'Step 2: She charges $20.00 per drawing and she drew 40 caricatures so she made $20*40 = $<<20*40=800>>800'],
  'ori_rewards': [0.875, 0.0],
  'ptb_rewards': [1.0, 1.0],
  'contributions': [-0.125, -1.0],
  'mi_rewards': [-4.25440533955892, 4.919126192728678],
  'naive_rewards': [0.875, 4.9191],
  'gold_answer': '800'}]

: 