In [None]:
from pathlib import Path
import random
import subprocess
import os
import sys

import torch
import pyarrow.parquet as pq
from transformers import AutoModelForCausalLM, AutoTokenizer
from IPython.display import Markdown, display

def find_repo_root(start: Path) -> Path:
    cur = start.resolve()
    for _ in range(6):
        if (cur / "pyproject.toml").exists() and (cur / "verl").is_dir():
            return cur
        if cur.parent == cur:
            break
        cur = cur.parent
    return start.resolve()

REPO_ROOT = find_repo_root(Path.cwd())

BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
CKPT_DIR = REPO_ROOT / "checkpoints/verl_grpo_critique/qwen2.5_7b_instruct_critique_llama3b_4epoch/global_step_180/actor"
MERGED_DIR = CKPT_DIR / "merged_hf"
TRAIN_PATH = REPO_ROOT / "data/train_critique_3.2_4.parquet"

NUM_EXAMPLES = 4
SEED = 42
MAX_NEW_TOKENS = 2048
DO_SAMPLE = True
TEMPERATURE = 0.6
TOP_P = 0.9
LOCAL_ONLY = False

torch.set_grad_enabled(False)


torch.autograd.grad_mode.set_grad_enabled(mode=False)

In [6]:
def has_merged_weights(path: Path) -> bool:
    if (path / "model.safetensors").exists():
        return True
    if (path / "pytorch_model.bin").exists():
        return True
    if (path / "model.safetensors.index.json").exists():
        return True
    if any(path.glob("model-*-of-*.safetensors")):
        return True
    if any(path.glob("pytorch_model-*-of-*.bin")):
        return True
    return False


def run_merge(cmd, cwd: Path):
    env = dict(os.environ)
    env["PYTHONPATH"] = str(cwd) + os.pathsep + env.get("PYTHONPATH", "")
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        cwd=str(cwd),
        env=env,
    )
    if result.returncode != 0:
        print("Merge failed.")
        if result.stdout:
            print("stdout:\n" + result.stdout)
        if result.stderr:
            print("stderr:\n" + result.stderr)
        raise RuntimeError(f"Merge failed with exit code {result.returncode}")
    if result.stdout:
        print(result.stdout)

if not has_merged_weights(MERGED_DIR):
    MERGED_DIR.mkdir(parents=True, exist_ok=True)
    cmd = [
        sys.executable, "-m", "verl.model_merger", "merge",
        "--backend", "fsdp",
        "--local_dir", str(CKPT_DIR),
        "--target_dir", str(MERGED_DIR),
    ]
    print("Running:", " ".join(cmd))
    run_merge(cmd, REPO_ROOT)
else:
    print(f"Found merged HF model at {MERGED_DIR}")


Running: /data1/home/yunhochoi/miniconda3/envs/verl/bin/python -m verl.model_merger merge --backend fsdp --local_dir /data1/home/yunhochoi/verl/checkpoints/verl_grpo_critique/qwen2.5_7b_instruct_critique_llama3b_4epoch/global_step_180/actor --target_dir /data1/home/yunhochoi/verl/checkpoints/verl_grpo_critique/qwen2.5_7b_instruct_critique_llama3b_4epoch/global_step_180/actor/merged_hf
config: ModelMergerConfig(operation='merge', backend='fsdp', target_dir='/data1/home/yunhochoi/verl/checkpoints/verl_grpo_critique/qwen2.5_7b_instruct_critique_llama3b_4epoch/global_step_180/actor/merged_hf', hf_upload_path=None, private=False, test_hf_dir=None, tie_word_embedding=False, trust_remote_code=False, is_value_model=False, local_dir='/data1/home/yunhochoi/verl/checkpoints/verl_grpo_critique/qwen2.5_7b_instruct_critique_llama3b_4epoch/global_step_180/actor', hf_model_config_path='/data1/home/yunhochoi/verl/checkpoints/verl_grpo_critique/qwen2.5_7b_instruct_critique_llama3b_4epoch/global_step_180

In [6]:
def pick_dtype() -> torch.dtype:
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float16
    return torch.float32

dtype = pick_dtype()
device_map = "auto" if torch.cuda.is_available() else None

tokenizer_base = AutoTokenizer.from_pretrained(
    BASE_MODEL, trust_remote_code=True, local_files_only=LOCAL_ONLY
)

def load_tokenizer(path: str, fallback=None):
    try:
        return AutoTokenizer.from_pretrained(
            path, trust_remote_code=True, local_files_only=True, use_fast=True
        )
    except Exception as exc:
        print(f"Tokenizer load failed for {path}: {exc}")
        if fallback is not None:
            print("Falling back to base tokenizer.")
            return fallback
        raise

tokenizer_ft = load_tokenizer(str(MERGED_DIR), fallback=tokenizer_base)


if tokenizer_base.pad_token is None:
    tokenizer_base.pad_token = tokenizer_base.eos_token
if tokenizer_ft.pad_token is None:
    tokenizer_ft.pad_token = tokenizer_ft.eos_token

model_base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
    device_map=device_map,
    trust_remote_code=True,
    local_files_only=LOCAL_ONLY,
)
model_ft = AutoModelForCausalLM.from_pretrained(
    str(MERGED_DIR),
    torch_dtype=dtype,
    device_map=device_map,
    trust_remote_code=True,
)

if device_map is None:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_base.to(device)
    model_ft.to(device)

model_base.eval()
model_ft.eval()


The tokenizer you are loading from '/data1/home/yunhochoi/verl/checkpoints/verl_grpo_critique/qwen2.5_7b_instruct_critique_llama3b_4epoch/global_step_180/actor/merged_hf' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.
Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.84s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.99s/it]


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584, padding_idx=151643)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,)

In [7]:
CUSTOM_PROMPTS = []  # Optional: override with your own list of message dicts.

if CUSTOM_PROMPTS:
    examples = CUSTOM_PROMPTS
else:
    table = pq.read_table(TRAIN_PATH, columns=["prompt"])
    prompts = table.column("prompt").to_pylist()
    random.seed(SEED)
    indices = random.sample(range(len(prompts)), k=NUM_EXAMPLES)
    examples = [prompts[i] for i in indices]

print(f"Loaded {len(examples)} critique prompts")


Loaded 4 critique prompts


In [8]:
def first_device(model) -> torch.device:
    if hasattr(model, "hf_device_map") and model.hf_device_map:
        for dev in model.hf_device_map.values():
            if dev not in ("cpu", "disk", "meta"):
                return torch.device(dev)
        first = next(iter(model.hf_device_map.values()))
        return torch.device(first)
    return model.device

def build_input(tokenizer, messages, model):
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    return input_ids.to(first_device(model))

@torch.no_grad()
def generate(model, tokenizer, messages):
    input_ids = build_input(tokenizer, messages, model)
    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": DO_SAMPLE,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if DO_SAMPLE:
        gen_kwargs.update({"temperature": TEMPERATURE, "top_p": TOP_P})

    output_ids = model.generate(input_ids, **gen_kwargs)
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def format_block(text: str, max_chars: int = 1200) -> str:
    if len(text) <= max_chars:
        return text
    return text[:max_chars] + "\n...\n[truncated]"


In [9]:
for idx, messages in enumerate(examples, start=1):
    prompt_text = messages[0]["content"]
    base_out = generate(model_base, tokenizer_base, messages)
    ft_out = generate(model_ft, tokenizer_ft, messages)

    display(Markdown(f"## Example {idx}"))
    display(Markdown("**Prompt**\n\n" + format_block(prompt_text).replace("\n", "  \n")))
    display(Markdown("**Base model**\n\n" + format_block(base_out).replace("\n", "  \n")))
    display(Markdown("**Fine-tuned**\n\n" + format_block(ft_out).replace("\n", "  \n")))


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


## Example 1

**Prompt**

User Question: Let $ABCD$ be a parallelogram. We have that $M$ is the midpoint of $AB$ and $N$ is the midpoint of $BC.$ The segments $DM$ and $DN$ intersect $AC$ at $P$ and $Q$, respectively. If $AC = 7917,$ what is $QA$? [asy]  
pair A, B, C, D, pM, pN, O, P, Q;  
A = (25, 0) * dir(-20);  
B = (15, 0) * dir(60);  
C = (25, 0) * dir(160);  
D = (15, 0) * dir(-120);  
pM = 0.5 * A + 0.5 * B;  
pN = 0.5 * B + 0.5 * C;  
O = 0.25 * A + 0.25 * B + 0.25 * C + 0.25 * D;  
P = 0.33 * C + 0.67 * A;  
Q = 0.67 * C + 0.33 * A;  
draw(A--B--C--D--cycle);  
draw(A--C);  
draw(pM--D);  
draw(pN--D);  
label("$A$", A, S);  
label("$B$", B, NE);  
label("$C$", C, NW);  
label("$D$", D, S);  
label("$M$", pM, NE);  
label("$N$", pN, NW);  
label("$P$", P, N);  
label("$Q$", Q, NE);  
[/asy]  
Correct Answer: 5278  
  
Model Solution Trace:  
## Step 1:  To solve this problem, we first need to understand the given information and what is being asked. We have a parallelogram $ABCD$ with midpoints $M$ on $AB$ and $N$ on $BC$. The segments $DM$ and $DN$ intersect $AC$ at $P$ and $Q$, respectively. We are given that $AC = 7917$ and we need to find the length of $QA$.  
## Step 2:  Since $M$ and $N$ are midpoints of $AB$ and $BC$, respectively, we know that   
...  
[truncated]

**Base model**

The solution provided contains several logical errors and misinterpretations. Let's carefully analyze the steps and identify the issues:  
  
## Step 1: Understanding the Problem  
The problem correctly states that \( M \) is the midpoint of \( AB \) and \( N \) is the midpoint of \( BC \). It also mentions that \( DM \) and \( DN \) intersect \( AC \) at points \( P \) and \( Q \), respectively, and provides the length of \( AC \) as 7917. The goal is to find the length of \( QA \).  
  
## Step 2: Analyzing the Midpoints and Diagonals  
While it is true that \( M \) and \( N \) are midpoints, the statement that \( AC = AD \) is incorrect. In a parallelogram, the diagonals bisect each other but do not necessarily equal the sides unless the parallelogram is a rhombus. However, this detail is not relevant to solving the problem directly.  
  
## Step 3: Identifying the Midpoints on \( AC \)  
The solution correctly identifies that \( P \) and \( Q \) are midpoints of \( AC \) due to the properties of medians in triangles formed by the diagonals of the parallelogram. This means \( AP = PC = BQ = QC = \frac{1}{2}AC = \frac{1}{2}(7917) = 3958.5 \).  
  
## Step 4: Calculating \( QA \)  
Here lies the main iss  
...  
[truncated]

**Fine-tuned**

The solution provided contains several logical errors and misinterpretations. Let's carefully analyze the problem and provide a correct solution.  
  
### Correct Solution:  
  
Given:  
- \(ABCD\) is a parallelogram.  
- \(M\) is the midpoint of \(AB\).  
- \(N\) is the midpoint of \(BC\).  
- Segments \(DM\) and \(DN\) intersect \(AC\) at \(P\) and \(Q\), respectively.  
- \(AC = 7917\).  
  
We need to find the length of \(QA\).  
  
#### Step-by-Step Solution:  
  
1. **Identify Key Properties**:  
   - In a parallelogram, the diagonals bisect each other. So, \(AC\) and \(BD\) bisect each other at point \(O\).  
   - Since \(M\) and \(N\) are midpoints, \(AM = MB\) and \(BN = NC\).  
  
2. **Use Similar Triangles**:  
   - Consider triangles \(AMD\) and \(CMD\). Since \(M\) is the midpoint of \(AB\), \(AM = MB\).  
   - Similarly, consider triangles \(BND\) and \(CND\). Since \(N\) is the midpoint of \(BC\), \(BN = NC\).  
  
3. **Intersection Points**:  
   - \(P\) and \(Q\) are points where \(DM\) and \(DN\) intersect \(AC\), respectively.  
   - By properties of similar triangles and the fact that \(M\) and \(N\) are midpoints, we can use the concept of mass points or Ceva's theorem to determine the ratios.  
  
4. **Apply Ceva  
...  
[truncated]

: 