# Import Classes

In [2]:
import argparse
import json
import re
from pathlib import Path
from typing import List, Tuple
import os
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    BitsAndBytesConfig,
    PreTrainedTokenizer,
)

# Select the GPU this notebook may use, then pick the torch device.
# NOTE(review): these variables are set after `import torch`; presumably this
# relies on CUDA not having been initialised yet — confirm.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "3"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import json
from pathlib import Path
from typing import Dict, List, Optional
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import logging
import math
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset
from tqdm import tqdm

class PRMConfig:
    """Hyperparameters and settings for reward collection and PRM training.

    Defaults are declared as annotated class attributes.  On construction every
    declared field is copied onto the instance, so ``vars(cfg)`` returns the
    complete configuration (needed for experiment logging and checkpointing)
    and a saved configuration can be restored with ``PRMConfig(**saved)``.

    Args:
        **overrides: Values replacing the class-level defaults, keyed by field
            name.

    Raises:
        TypeError: If an override names a field that is not declared here.
    """
    # MC config
    model_name:             str = "Qwen/Qwen2.5-Math-7B"    # "Qwen/Qwen2.5-Math-7B", "Qwen/Qwen2.5-Math-7B-Instruct" , "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "meta-llama/Llama-3.1-8B"
    max_new_tokens:         int = 512
    num_rollouts:           int = 8
    samples_per_question:   int = 1
    use_llm:                bool = True  # Use llm for masking
    reward_type:            str = "contri"  # ori, contri, mi, naive, norm
    # PRM Model config
    hidden_size:        int = 512      # reasonable range: 256-1024
    num_layers:         int = 3        # reasonable range: 2-4
    dropout:            float = 0.2    # reasonable range: 0.1-0.3
    # PRMTrainer config
    batch_size:         int = 16       # raised from 12 (more stable)
    learning_rate:      float = 3e-4   # lowered from 5e-4 (more stable)
    num_workers:        int = 4
    weight_decay:       float = 1e-2
    lr_scheduler:       str = "cosine"
    dataset_size:       int = 0        # filled in by the trainer once the data is known
    warmup_steps:       int = 40       # raised from 22 (more stable)
    grad_clip:          float = 1.0
    epochs:             int = 20       # lowered from 25 (early stopping usually ends sooner)
    # Misc config
    use_wandb:          bool = True
    wandb_project:      str = "mc_prm"
    run_name:           str = "test_400_0715"
    checkpoint_dir:     str = "./checkpoints/0715/contri"
    seed:               int = 42

    def __init__(self, **overrides):
        cls = type(self)
        # Collect every annotated field along the MRO (base classes first) so
        # subclasses that add fields are handled as well.
        fields = {}
        for klass in reversed(cls.__mro__):
            for name in getattr(klass, "__annotations__", {}):
                fields[name] = getattr(cls, name)

        unknown = sorted(set(overrides) - set(fields))
        if unknown:
            raise TypeError(
                f"{cls.__name__} got unexpected field(s): {', '.join(unknown)}"
            )

        fields.update(overrides)
        # Store on the instance so vars(cfg) reflects the full configuration.
        for name, value in fields.items():
            setattr(self, name, value)

    def __repr__(self) -> str:
        body = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
        return f"{type(self).__name__}({body})"

class ProcessRewardModel(nn.Module):
    """Residual MLP head mapping a backbone feature vector to a score in (0, 1).

    Layout: input projection -> (num_layers - 1) pre-LayerNorm residual MLP
    branches -> LayerNorm / Linear / Sigmoid output head.
    """

    def __init__(self, input_size: int, cfg: "PRMConfig"):
        """Build the head.

        Args:
            input_size: Dimension of the feature vector fed to the head.
            cfg: Object providing ``hidden_size``, ``num_layers`` and ``dropout``.
        """
        super().__init__()

        self.input_size = input_size
        width = cfg.hidden_size
        drop_prob = cfg.dropout
        # A single activation module is shared by every sub-network.
        gelu = nn.GELU()

        # Input projection: feature dim -> hidden width.
        self.in_proj = nn.Sequential(
            nn.Linear(input_size, width),
            nn.LayerNorm(width),
            gelu,
            nn.Dropout(drop_prob),
        )

        def make_branch() -> nn.Sequential:
            # Pre-LN two-layer MLP; the skip connection is added in forward().
            return nn.Sequential(
                nn.LayerNorm(width),
                nn.Linear(width, width),
                gelu,
                nn.Dropout(drop_prob),
                nn.Linear(width, width),
                nn.Dropout(drop_prob),
            )

        # Residual stack: one branch fewer than num_layers.
        self.blocks = nn.ModuleList(
            [make_branch() for _ in range(cfg.num_layers - 1)]
        )

        # Output head: one sigmoid-squashed scalar per input row.
        self.out_proj = nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scores for ``x`` with the trailing size-1 dim squeezed."""
        hidden = self.in_proj(x)
        for branch in self.blocks:
            hidden = hidden + branch(hidden)  # skip connection
        scores = self.out_proj(hidden)
        return scores.squeeze(-1)

    def get_complexity(self) -> int:
        """Return the total number of parameter elements in the head."""
        count = 0
        for param in self.parameters():
            count += param.numel()
        return count
    
class StepwisePRMDataset(Dataset):
    """Expand reward entries into ``(input_ids, scalar_reward)`` samples.

    Each entry is a dict with a ``question``, a list of ``completion`` steps and
    one reward list per reward type.  It is unrolled into one sample per step
    prefix: (Problem + Step1, r1), (Problem + Step1 + Step2, r2), ...
    """

    # reward_type -> entry key; any other reward_type uses "ori_rewards".
    _REWARD_KEYS = {
        "contri": "contributions",
        "mi": "mi_rewards",
        "naive": "naive_rewards",
    }

    def __init__(
        self,
        entries: List[dict],
        tokenizer: "PreTrainedTokenizer",
        max_length: int = 512,
        reward_type: str = "naive",
        *,
        cache_encodings: bool = True,
        preprocess: bool = True,
    ):
        """Build the flat sample list.

        Args:
            entries: Entry dicts with ``question``, ``completion`` and the
                reward list selected by ``reward_type``.
            tokenizer: Callable tokenizer whose result exposes ``input_ids``.
            max_length: Length every encoded sample is padded/truncated to.
            reward_type: "contri", "mi" or "naive"; anything else selects
                ``ori_rewards``.
            cache_encodings: Keep encoded tensors in memory, keyed by text.
            preprocess: Apply whitespace normalisation to each sample text.

        Raises:
            KeyError: If an entry lacks a required key.
            ValueError: If an entry has a different number of steps and rewards.
        """
        self.tokenizer   = tokenizer
        self.max_length  = max_length
        self.reward_type = reward_type
        self.cache       = {} if cache_encodings else None
        self.samples: List[Tuple[str, float]] = []

        reward_key = self._REWARD_KEYS.get(reward_type, "ori_rewards")

        for position, entry in enumerate(entries):
            steps = entry["completion"]
            rewards = entry[reward_key]
            # Validate against the rewards that are actually paired with the
            # steps; zip() would otherwise silently drop the unmatched tail.
            if len(steps) != len(rewards):
                raise ValueError(
                    f"entry {position}: {len(steps)} steps but "
                    f"{len(rewards)} '{reward_key}' values"
                )

            prefix_lines = [f"Problem: {entry['question']}"]
            for step_txt, reward in zip(steps, rewards):
                prefix_lines.append(step_txt)
                full_txt = "\n".join(prefix_lines)
                if preprocess:
                    full_txt = self._clean(full_txt)
                self.samples.append((full_txt, float(reward)))   # (text, reward)

    # --------------------------------------------------------------------- utils
    @staticmethod
    def _clean(txt: str) -> str:
        """Collapse every whitespace run to one space and trim both ends."""
        return re.sub(r"\s+", " ", txt).strip()

    # --------------------------------------------------------------------- dunder
    def __len__(self):
        """Return the number of (prefix text, reward) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(input_ids, reward)`` for sample ``idx``.

        ``input_ids`` is the tokenizer output with its leading dimension
        squeezed away; ``reward`` is a float32 scalar tensor.
        """
        text, reward = self.samples[idx]

        ids = self.cache.get(text) if self.cache is not None else None
        if ids is None:
            ids = self.tokenizer(
                text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.squeeze(0)
            if self.cache is not None:
                self.cache[text] = ids

        return ids, torch.tensor(reward, dtype=torch.float32)

class PRMTrainer:
    """Train a ProcessRewardModel head on features from a frozen LLM backbone.

    (1) Data loaders yield ``(input_ids, reward)`` batches.
    (2) The backbone turns each sequence into one feature vector.
    (3) Only the PRM head is optimised (MSE loss, AdamW).
    """

    def __init__(self, cfg: "PRMConfig", model, tokenizer):
        """Freeze the backbone, build the PRM head, optimiser and scheduler.

        Args:
            cfg: Configuration object (see PRMConfig for the fields read here).
            model: Backbone LLM; must expose ``config.hidden_size`` and return
                ``hidden_states`` when asked for them.
            tokenizer: Tokenizer matching ``model``; its ``pad_token_id`` (if
                any) is used to locate the last real token of each sequence.
        """
        self.cfg = cfg
        torch.manual_seed(cfg.seed)

        # ----------------------------- Backbone LLM, used as a frozen feature extractor
        self.tokenizer = tokenizer
        self.model  = model
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

        feat_dim = self.model.config.hidden_size
        self.prm = ProcessRewardModel(feat_dim, cfg=cfg)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.prm.to(self.device)

        self.opt  = optim.AdamW(self.prm.parameters(), lr=cfg.learning_rate, weight_decay = cfg.weight_decay)
        self.crit = nn.MSELoss()

        # Provisional estimate; fit() replaces it with the exact number of
        # optimiser updates once the training loader is known.  The LR lambdas
        # read self.total_steps lazily, so the later value is what they use.
        self.total_steps = math.ceil(cfg.epochs * cfg.dataset_size / max(1, cfg.batch_size))
        # True when the scheduler must advance once per epoch, not per update.
        self._epoch_level_scheduler = False
        self.scheduler = self._build_scheduler()

        self.ckpt_dir = Path(cfg.checkpoint_dir)
        self.ckpt_dir.mkdir(exist_ok=True, parents=True)

        self.wandb_run = None
        if cfg.use_wandb:
            self.wandb_run = wandb.init(
                project=cfg.wandb_project,
                name=cfg.run_name,
                config=self._cfg_to_dict(cfg),   # log every hyperparameter
            )

    # ----------------------------------------------------------------- helpers
    @staticmethod
    def _cfg_to_dict(cfg) -> dict:
        """Return all config fields: class-level defaults plus instance values.

        ``vars(cfg)`` alone misses defaults that live only on the class.
        """
        merged = {}
        for klass in reversed(type(cfg).__mro__):
            for name, value in vars(klass).items():
                if name.startswith("_") or callable(value):
                    continue
                if isinstance(value, (property, staticmethod, classmethod)):
                    continue
                merged[name] = value
        merged.update(vars(cfg))
        return merged

    def _grad_accum_steps(self) -> int:
        """Return the optional ``cfg.grad_accum_steps`` (at least 1)."""
        return max(1, int(getattr(self.cfg, "grad_accum_steps", 1) or 1))

    def _build_scheduler(self):
        """Create the LR scheduler named by ``cfg.lr_scheduler`` (or None)."""
        cfg = self.cfg
        warmup = cfg.warmup_steps

        if cfg.lr_scheduler == "cosine":
            # Linear warmup followed by cosine decay to zero.
            def lr_lambda(step):
                if step < warmup:
                    return step / max(1, warmup)
                progress = (step - warmup) / max(1, self.total_steps - warmup)
                # Clamp so the LR stays at zero past total_steps instead of
                # following the cosine back up.
                return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
            return LambdaLR(self.opt, lr_lambda)

        if cfg.lr_scheduler == "linear":
            # Linear warmup followed by linear decay to zero.
            def lr_lambda(step):
                if step < warmup:
                    return step / max(1, warmup)
                return max(0.0, (self.total_steps - step) / max(1, self.total_steps - warmup))
            return LambdaLR(self.opt, lr_lambda)

        if cfg.lr_scheduler == "step":
            # Halve the LR every 5 epochs; advanced from fit(), once per epoch.
            self._epoch_level_scheduler = True
            return optim.lr_scheduler.StepLR(self.opt, step_size=5, gamma=0.5)

        return None

    # ----------------------------------------------------------------- features
    @torch.no_grad()
    def _encode(self, ids: torch.Tensor) -> torch.Tensor:
        """Map input_ids [B, T] to [B, feat_dim] features.

        The feature is the final-layer hidden state at the last non-padding
        token of each row.  With a causal LM this is the position that has
        attended to the whole sequence, and it matches the last-token pooling
        used when scoring steps at inference time.  Padding is detected by
        comparing against ``tokenizer.pad_token_id``; rows may be padded on
        either side.
        """
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            mask = torch.ones_like(ids)
        else:
            mask = (ids != pad_id).long()

        out = self.model(
            input_ids=ids,
            attention_mask=mask,
            return_dict=True,
            output_hidden_states=True,
        )
        hidden = out.hidden_states[-1]

        # Index of the last real token per row (0 when the row is all padding).
        positions = torch.arange(ids.size(1), device=ids.device)
        last_idx = (mask * positions).argmax(dim=1).to(hidden.device)
        rows = torch.arange(hidden.size(0), device=hidden.device)
        return hidden[rows, last_idx].float()

    # ----------------------------------------------------------------- loop util
    def _run_epoch(self, loader: DataLoader, train: bool, epoch_idx: int) -> float:
        """Run one pass over ``loader`` and return the mean batch loss.

        Args:
            loader: Yields ``(input_ids, reward)`` batches.
            train: Update the PRM head when True; evaluate only when False.
            epoch_idx: Zero-based epoch number, used for logging.
        """
        self.prm.train(train)
        total = 0.0
        num_batches = len(loader)
        accum = self._grad_accum_steps()

        if train:
            self.opt.zero_grad()

        for step, (ids, reward) in enumerate(loader):
            ids, reward = ids.to(self.device), reward.to(self.device)

            with torch.set_grad_enabled(train):
                feats = self._encode(ids)
                # reshape(-1) keeps a 1-D prediction even for a batch of one.
                pred = self.prm(feats).reshape(-1)
                loss = self.crit(pred, reward)

            grad_norm = None
            if train:
                # Scale so the accumulated gradient is the micro-batch mean.
                (loss / accum).backward()
                if (step + 1) % accum == 0 or (step + 1) == num_batches:
                    grad_norm = float(
                        torch.nn.utils.clip_grad_norm_(self.prm.parameters(), self.cfg.grad_clip)
                    )
                    self.opt.step()
                    if self.scheduler is not None and not self._epoch_level_scheduler:
                        self.scheduler.step()
                    # Clear only after an update so gradients can accumulate.
                    self.opt.zero_grad()

            total += loss.item()

            # -------- minibatch logging --------
            if self.wandb_run and train:
                metrics = {
                    "batch_loss": loss.item(),
                    "epoch": epoch_idx + step / num_batches,
                    "lr": self.opt.param_groups[0]["lr"],
                    "pred_mean": pred.mean().item(),
                    "pred_std": pred.std().item(),
                    "reward_mean": reward.mean().item(),
                    "reward_std": reward.std().item(),
                }
                if grad_norm is not None:
                    # Total gradient norm as measured before clipping.
                    metrics["grad_norm"] = grad_norm
                wandb.log(metrics)

        return total / max(1, num_batches)

    # ----------------------------------------------------------------- public
    def fit(self, train_loader, val_loader) -> Dict[str, List[float]]:
        """Train with early stopping and return the per-epoch loss history.

        Saves ``best_prm.pt`` whenever the validation loss improves and
        ``last_prm.pt`` when training ends.

        Returns:
            ``{"train": [...], "val": [...]}`` with one mean loss per epoch run.
        """
        accum = self._grad_accum_steps()
        dataset = getattr(train_loader, "dataset", None)
        if hasattr(dataset, "__len__"):
            self.cfg.dataset_size = len(dataset)
        else:
            self.cfg.dataset_size = len(train_loader) * self.cfg.batch_size
        # Exact number of optimiser updates; drives the warmup/decay lambdas.
        self.total_steps = self.cfg.epochs * math.ceil(len(train_loader) / accum)

        history = {"train": [], "val": []}
        best_val, bad_epochs, patience = float("inf"), 0, 8
        last_epoch, vl_loss = -1, float("nan")

        for ep in range(self.cfg.epochs):
            tr_loss = self._run_epoch(train_loader, train=True,  epoch_idx=ep)
            vl_loss = self._run_epoch(val_loader,   train=False, epoch_idx=ep)
            last_epoch = ep

            if self.scheduler is not None and self._epoch_level_scheduler:
                self.scheduler.step()

            history["train"].append(tr_loss)
            history["val"].append(vl_loss)
            print(f"[Epoch {ep+1}/{self.cfg.epochs}] train={tr_loss:.4f}  val={vl_loss:.4f}")

            # -------- epoch logging --------
            if self.wandb_run:
                wandb.log({
                    "train_loss": tr_loss,
                    "val_loss": vl_loss,
                    "epoch": ep,
                    "lr": self.opt.param_groups[0]["lr"],
                })

            # Checkpoint on improvement, otherwise count towards early stopping.
            if vl_loss < best_val:
                best_val = vl_loss
                bad_epochs = 0
                self._save_checkpoint("best_prm.pt", epoch=ep, val_loss=vl_loss)
                print(f"[Best] New best validation loss: {vl_loss:.4f}")
            else:
                bad_epochs += 1
                print(f"[Early-Stopping] No improvement for {bad_epochs}/{patience} epochs")
                if bad_epochs >= patience:
                    print(f"[Early-Stopping] Stopping training after {patience} epochs without improvement")
                    break

        # Record the epoch that actually ran last (differs from epochs - 1
        # after early stopping).
        self._save_checkpoint("last_prm.pt", epoch=last_epoch, val_loss=vl_loss)
        return history

    # ------------------------------------------------------------------
    # Checkpoint helpers
    def _save_checkpoint(self, filename: str, *, epoch: int, val_loss: float) -> None:
        """Write PRM, optimiser and scheduler state plus metadata to ``ckpt_dir``."""
        path = self.ckpt_dir / filename
        save_dict = {
            "epoch": epoch,
            "val_loss": val_loss,
            "prm_state": self.prm.state_dict(),
            "scheduler_state": (self.scheduler.state_dict() if self.scheduler else None),
            "optimizer_state": self.opt.state_dict(),
            "config": self._cfg_to_dict(self.cfg),   # hyper-params for reproducibility
            "model_name_or_path": getattr(self.model, "name_or_path", None),
            "tokenizer_config": getattr(self.tokenizer, "init_kwargs", {}),
        }
        torch.save(save_dict, path)
        print(f"[CKPT] Saved ⇒ {path}")

    # ------------------------------------------------------------------
    # Simple inference helper
    @torch.no_grad()
    def predict_reward(self, text: str) -> float:
        """Return the PRM score for ``text``.

        The PRM output is returned as-is: the head already ends in a sigmoid,
        so no further squashing is applied.  The head is switched to eval mode
        for the call and its previous mode is restored afterwards.
        """
        ids = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
        was_training = self.prm.training
        self.prm.eval()
        try:
            feat = self._encode(ids)
            return float(self.prm(feat).reshape(-1)[0].item())
        finally:
            self.prm.train(was_training)


# Utils for Inference

In [32]:
# Matches a final-answer line, case-insensitively and per line (MULTILINE):
# optional leading whitespace / markdown symbols, the word "Answer", one of the
# separators ":", "." or "-", then captures the rest of that line without
# trailing whitespace.
ANSWER_PATTERN = re.compile(
    r"""^[\s>#*\-]*          # optional markdown/bullet symbols
        Answer               # word 'Answer'
        \s*[:.\-]\s*         # separator
        (.+?)\s*$            # capture everything after
    """,
    re.IGNORECASE | re.MULTILINE | re.VERBOSE,
)
# Captures the body of each "Step N:" segment.  With re.S the body may span
# lines; it ends right before the next "\nStep", before "\nAnswer", or at the
# end of the string.  This pattern is case-sensitive.
STEP_PATTERN = re.compile(r"Step\s*\d+\s*:\s*(.*?)(?=\nStep|\nAnswer|$)", re.S)
# Earlier, simpler answer pattern kept for reference:
# ANSWER_PATTERN = re.compile(r"Answer\s*:\s*(.+?)\s*$", re.S)

def build_prompt(question: str) -> str:
    """Wrap *question* in the chat template with the step-by-step system prompt."""
    system_turn = (
        "<|im_start|>system\n"
        "You are a helpful math tutor. You must solve problems step-by-step using the exact format:\n"
        "Step 1: [first step]\n"
        "Step 2: [second step]\n"
        "...\n"
        "Answer: [final answer]\n"
        "\n"
        "Example:\n"
        "Problem: What is 5 + 3?\n"
        "Step 1: Add 5 and 3\n"
        "Step 2: 5 + 3 = 8\n"
        "Answer: 8\n"
        "\n"
        "Now solve the given problem using the same format.\n"
        "<|im_end|>\n"
    )
    user_turn = "<|im_start|>user\n" + f"{question}" + "\n<|im_end|>\n"
    # The assistant turn is left open so the model continues from here.
    assistant_open = "<|im_start|>assistant\n"
    return system_turn + user_turn + assistant_open

def post_process_response(text: str) -> str:
    """Rebuild *text* as clean "Step N: ..." lines followed by one "Answer: ..." line.

    The input is returned unchanged unless at least one step line and an
    answer are both found.
    """
    # Every "Step N: ..." occurrence, each running to the end of its line.
    step_lines = re.findall(r'Step\s*\d+:\s*[^\n]*', text, re.IGNORECASE)
    # First "Answer: ..." occurrence; group 1 is the rest of that line.
    answer_hit = re.search(r'Answer:\s*([^\n]*)', text, re.IGNORECASE)

    if not step_lines or answer_hit is None:
        return text

    answer_line = f"Answer: {answer_hit.group(1).strip()}"
    return "\n".join(step_lines + [answer_line])

def parse_steps_and_answer(text: str) -> Tuple[List[str], str]:
    """Split a generated trajectory into its step bodies and final answer.

    Returns:
        ``(steps, answer)``: the stripped body of every STEP_PATTERN match, and
        the stripped capture of the first ANSWER_PATTERN match ("" if none).
    """
    step_bodies = []
    for hit in STEP_PATTERN.finditer(text):
        step_bodies.append(hit.group(1).strip())

    answer_hit = ANSWER_PATTERN.search(text)
    if answer_hit is None:
        return step_bodies, ""
    return step_bodies, answer_hit.group(1).strip()

def generate_candidates(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizerBase",
    prompt: str,
    num_candidates: int,
    gen_cfg: "GenerationConfig",
    device: torch.device,
) -> List[str]:
    """Generate *num_candidates* reasoning trajectories for the prompt.

    Args:
        model: Causal LM whose ``generate`` returns prompt + continuation ids.
        tokenizer: Tokenizer matching ``model``.
        prompt: Full prompt text (may contain special tokens).
        num_candidates: Number of samples to draw for this prompt.
        gen_cfg: Generation settings, forwarded as keyword arguments.
        device: Device the tokenized inputs are moved to.

    Returns:
        One decoded continuation per candidate, without the prompt.
    """
    if num_candidates <= 0:
        return []

    inputs = tokenizer([prompt] * num_candidates, return_tensors="pt", padding=True).to(device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_cfg.to_dict())

    # Drop the prompt at token level.  Slicing the decoded string by
    # len(prompt) is wrong here: decoding with skip_special_tokens=True removes
    # the prompt's special tokens, so the decoded prompt is shorter than
    # `prompt` and part of the continuation would be cut off.
    continuations = outputs[:, prompt_len:]
    return tokenizer.batch_decode(continuations, skip_special_tokens=True)

def compute_step_rewards(
    baseline: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    prm: ProcessRewardModel,
    prm_device: torch.device,
    prompt: str,
    steps: List[str],
) -> List[float]:
    """Score every step prefix with the PRM and return one float per step.

    For step ``i`` the text ``prompt + "Step 1: ...\\n" + ... + "Step i: ...\\n"``
    is run through ``baseline``; the final-layer hidden state of its last token
    is the feature passed to ``prm``.
    """
    scores: List[float] = []
    running_text = prompt

    for number, body in enumerate(steps, start=1):
        # Grow the context by exactly one completed step.
        running_text = running_text + f"Step {number}: {body}\n"
        encoded = tokenizer(running_text, return_tensors="pt").to(prm_device)

        with torch.no_grad():
            lm_out = baseline(**encoded, output_hidden_states=True)

        # Last layer, first (only) batch row, last token -> (hidden_dim,)
        final_state = lm_out.hidden_states[-1][0, -1, :].float()
        score = prm(final_state.unsqueeze(0)).item()  # type: ignore
        scores.append(score)

    return scores

# Main

In [None]:
# ------------------- Load baseline LM -------------------
# The backbone is loaded 4-bit quantized (NF4, double quantization, bf16 compute).
model_name = "Qwen/Qwen2.5-Math-7B-Instruct" 
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
baseline = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype="auto",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# ------------------- Load PRM ---------------------------
prm_ckpt_path = "/home/leena/ccc_eval/mcts_prm/prm_dataset2/checkpoints/0715/contri/best_prm.pt"
# NOTE(security): weights_only=False unpickles arbitrary Python objects.
# Only load checkpoint files from a trusted source.
prm_ckpt = torch.load(prm_ckpt_path, map_location="cpu", weights_only=False)

# PRMTrainer._save_checkpoint stores the hyperparameters under "config"; the
# old lookup of "cfg" never matched, so the head was always built with default
# sizes.  "cfg" is still accepted as a fallback.
saved_cfg = prm_ckpt.get("config") or prm_ckpt.get("cfg") or {}
prm_cfg = PRMConfig()
for _field, _value in saved_cfg.items():
    # Apply saved values on top of the defaults so the head matches the
    # architecture the state dict was trained with.
    setattr(prm_cfg, _field, _value)

prm = ProcessRewardModel(baseline.config.hidden_size, cfg=prm_cfg)
prm.load_state_dict(prm_ckpt["prm_state"])
prm = prm.float()  # keep the head in float32 explicitly
prm = prm.to(device).eval()
print("Finish Loading Baseline and PRM!")

Loading checkpoint shards: 100%|██████████| 4/4 [00:19<00:00,  4.76s/it]


Finish Loading Baseline andPRM!


In [None]:
# ------------------- Dataset ---------------------------
# Best-of-N evaluation on GSM8K: sample several chains per question, score each
# chain with the PRM and keep the highest-scoring one.
ds = load_dataset("openai/gsm8k", "main", split="test")
max_samples = 2
if max_samples:
    # Clamp so a limit larger than the split does not raise.
    ds = ds.select(range(min(max_samples, len(ds))))
loader = DataLoader(ds, batch_size=1, shuffle=False)
print("Finish Loading Dataset!")

config = PRMConfig()
gen_cfg = GenerationConfig(
    temperature=0.3,
    top_p=0.8,
    max_new_tokens=config.max_new_tokens,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id= tokenizer.eos_token_id,
    do_sample=True,
)
num_candidates = 6
results = []
# Wrap the loader (not enumerate) so tqdm knows the total and shows progress.
for idx, sample in enumerate(tqdm(loader)):
    question = sample["question"][0]
    gold = sample.get("answer", [""])[0]
    prompt = build_prompt(question)

    # 1) Generate candidate CoTs
    cand_texts = generate_candidates(
        baseline,
        tokenizer,
        prompt,
        num_candidates,
        gen_cfg,
        device,
    )
    if not cand_texts:
        # Nothing to rank for this question.
        continue

    # 2) Score each candidate via PRM
    cand_scores: List[float] = []
    cand_answers: List[str] = []
    for text in cand_texts:
        steps, answer = parse_steps_and_answer(text)
        print("Step/Answer:",steps, answer)
        step_rewards = compute_step_rewards(baseline, tokenizer, prm, device, prompt, steps)
        print("step rewards:",step_rewards)
        # Chain score = sum of its step rewards (0 when no step was parsed).
        cand_scores.append(sum(step_rewards))
        cand_answers.append(answer)

    # Pick the first maximum once, and read answer, chain and score from that
    # same index so they can never disagree on ties.
    best_idx = max(range(len(cand_scores)), key=cand_scores.__getitem__)
    best_answer = cand_answers[best_idx]
    best_chain = cand_texts[best_idx]
    best_score = cand_scores[best_idx]

    # 3) Save result
    sample_id = sample["id"][0] if isinstance(sample.get("id"), list) else idx
    results.append(
        {
            "id": sample_id,
            "question": question,
            "gold": gold,
            "pred": best_answer,
            "chain": best_chain,
            "score": best_score,
        }
    )

    if (idx + 1) % 20 == 0:
        print(f"Processed {idx + 1}/{len(loader)} samples…")

results

# Merge Dataset

In [3]:
import json

def _load_json(path):
    """Read one JSON file and return the parsed content."""
    with open(path, "r") as handle:
        return json.load(handle)


# Load the two shards of contribution samples.
f1 = _load_json("/home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_0_960.json")
f2 = _load_json("/home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_960_3153.json")

print(len(f1))
print(len(f2))    

# Concatenate in shard order and report the combined size.
merged_data = f1 + f2
print(len(merged_data))

with open("/home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_0_3153.json", "w") as sink:
    json.dump(merged_data, sink, indent=2)

# Read the merged file back as a sanity check on the written entry count.
saved = _load_json("/home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_0_3153.json")

print(len(saved))

960
3152
4112
4112


In [13]:
from datasets import load_dataset

# Load the "all" configuration of the MATH train split and display the record
# at index 4112.
# NOTE(review): 4112 equals the merged sample count printed above; presumably
# this checks which problem comes next — confirm.
ds = load_dataset("HuggingFaceTB/MATH", "all", split="train")
ds[4112]

{'problem': 'What is the domain of the function $f(x)=\\log_2(\\log_3(\\log_4(\\log_5x)))$?',
 'level': 'Level 4',
 'type': 'Intermediate Algebra',
 'solution': 'In order for the given function to have a real value, $\\log_3(\\log_4(\\log_5x))>0$ (since the logarithm of only any positive number is real). In order for the last inequality to be true, $\\log_4(\\log_5x)>1$ (since the logarithm of only any number greater than 1 is greater than 0). The last inequality is true only if $\\log_5x>4^1=4$, so $x>5^4\\Rightarrow x>625,$ or in interval notation, $x \\in \\boxed{(625, \\infty)}.$'}

In [1]:
import json

def read_jsonl(file_path):
    """Parse a JSON Lines file into a list, skipping whitespace-only lines."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]

def jsonl_to_json(jsonl_path, json_path):
    """Convert a JSON Lines file into one indented JSON array file.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    records = read_jsonl(jsonl_path)
    with open(json_path, 'w', encoding='utf-8') as sink:
        json.dump(records, sink, ensure_ascii=False, indent=2)
    print(f"Converted {jsonl_path} to {json_path}")

# Convert the 4112-4681 JSONL shard into a JSON array file stored next to it.
jsonl_file = "/home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_4112_4681.jsonl"
json_file = "/home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_4112_4681.json"
jsonl_to_json(jsonl_file, json_file)

Converted /home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_4112_4681.jsonl to /home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_4112_4681.json


In [None]:
# Input and output paths for the contribution / mutual-information merge
contri_file = "/home/leena/ccc_eval/mcts_prm/cmi_samples/math_contri_mistral_total.json"
mi_file = "/home/leena/ccc_eval/mcts_prm/cmi_samples/math_mi_mistral_total.json"
output_file = "/home/leena/ccc_eval/mcts_prm/cmi_samples/total_math_merge_mistral.json"

# Load the contribution entries first, then the mutual-information entries.
with open(contri_file, "r") as contri_handle:
    contri_data = json.load(contri_handle)

with open(mi_file, "r") as mi_handle:
    mi_data = json.load(mi_handle)


def make_entry_key(entry: dict) -> tuple:
    """Return the hashable identity ``(question, steps, gold_answer)`` of an entry."""
    question = entry["question"].strip()
    steps = tuple(entry["completion"])  # list -> tuple so the key is hashable
    gold = entry["gold_answer"].strip()
    return question, steps, gold

# Merge MI rewards into the contribution entries.  Entries are matched on
# (question, completion steps, gold answer).
merged_dict = {}
merged_count = 0
new_count = 0
error_count = 0

print("Processing contribution data...")
for entry in contri_data:
    try:
        key = make_entry_key(entry)
    except KeyError as e:
        print(f"Warning: Entry missing required field: {e}")
        error_count += 1   # count skipped contribution entries as well
        continue
    merged_dict[key] = entry.copy()

print("Merging mutual information data...")
for entry in mi_data:
    try:
        key = make_entry_key(entry)
    except KeyError as e:
        print(f"Warning: Entry missing required field: {e}")
        error_count += 1
        continue

    if key in merged_dict:
        # Known entry: attach only the MI-specific fields.
        merged_dict[key].update({
            "mi_rewards": entry.get("mi_rewards"),
            "mi_filtered": entry.get("mi_filtered")
        })
        merged_count += 1
    else:
        # Entry present only in the MI file: keep it as-is.
        merged_dict[key] = entry.copy()
        new_count += 1

merged_data = list(merged_dict.values())

print(f"\nMerge Summary:")
print(f"  Contribution entries: {len(contri_data)}")
print(f"  MI entries: {len(mi_data)}")
print(f"  Merged entries: {len(merged_data)}")
print(f"  Successfully merged: {merged_count}")
print(f"  New entries added: {new_count}")
print(f"  Errors encountered: {error_count}")

# Write the merged list.  The previous code dumped `data`, a name that is
# never defined in this cell.
with open(output_file, 'w') as f:
    json.dump(merged_data, f, indent=2)
print("✅ Merge completed successfully!")

In [None]:
import json
from pathlib import Path

def load(path: str):
    """Read the UTF-8 JSON file at *path* and return the parsed content."""
    # The encoding name must use an ASCII hyphen; the previous literal
    # contained U+2011 (non-breaking hyphen), which is not a codec name.
    return json.loads(Path(path).read_text(encoding="utf-8"))

def save(obj, path: str):
    """Write *obj* to *path* as indented UTF-8 JSON, keeping non-ASCII text as-is."""
    # The encoding name must use an ASCII hyphen; the previous literal
    # contained U+2011 (non-breaking hyphen), which is not a codec name.
    Path(path).write_text(
        json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8"
    )

def make_key(entry: dict) -> tuple:
    """Build the key used to decide whether two entries are duplicates.

    - ``question`` is stripped so surrounding whitespace does not matter
    - ``completion`` becomes a tuple, preserving step order
    - ``gold_answer`` is stripped as well
    """
    question = entry["question"].strip()
    steps = tuple(entry["completion"])
    gold = entry["gold_answer"].strip()
    return (question, steps, gold)

def merge_files(file1: str, file2: str, out: str = "merged.json"):
    """Merge two entry files on their duplicate key and write the result to *out*.

    Entries of *file1* form the base.  For a matching entry of *file2* only the
    reward fields are copied over; entries found only in *file2* are appended.
    """
    data1, data2 = load(file1), load(file2)

    # 1) Seed the mapping with copies of every file1 entry (mi_* fields included)
    merged = {}
    for base_entry in data1:
        merged[make_key(base_entry)] = base_entry.copy()

    # 2) Fold in file2: add unseen entries, otherwise overwrite the reward fields
    reward_fields = ("ori_rewards", "ptb_rewards", "contributions")
    for incoming in data2:
        key = make_key(incoming)
        if key in merged:
            merged[key].update({name: incoming.get(name) for name in reward_fields})
        else:
            merged[key] = incoming.copy()

    # 3) Persist the values as a list
    save(list(merged.values()), out)
    print(f"✅  merged {len(merged)} entries → {out}")

# Script entry point: merge the two input files into the default "merged.json".
# NOTE(review): "file1.json" / "file2.json" look like placeholder paths — confirm.
if __name__ == "__main__":
    merge_files("file1.json", "file2.json")
