In [2]:
import logging
import os

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face access token.
# Read from the HF_TOKEN environment variable so that no credential has to be
# typed into (and saved with) the notebook. An unset or empty variable yields
# None, i.e. "no explicit token", instead of an empty string.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Model paths and settings
student_model_path = "KAERI-MLP/gemma2-Korean-AtomicGPT-9B"
judge_model_path = "meta-llama/Llama-3.1-8B-Instruct"
max_seq_len = 1024
lora_rank = 16
lora_alpha = 32

logging.basicConfig(level=logging.INFO)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

DEVICE = "cuda:0"   # single H100

# Quantization settings for the student: 4-bit NF4 weights with double
# quantization, computing in bfloat16.
bnb_config_4bit = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# LoRA adapter settings; the adapter is attached to the attention projections
# (q/k/v/o) and the MLP projections (gate/up/down).
lora_rank = 16
lora_alpha = 32
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    bias="none",
    r=lora_rank,
    lora_alpha=lora_alpha,
    target_modules="q_proj k_proj v_proj o_proj gate_proj up_proj down_proj".split(),
)

# Student model: loaded quantized on GPU 0, then wrapped with the LoRA adapter.
student_tokenizer = AutoTokenizer.from_pretrained(student_model_path, token=HF_TOKEN)
student_model = get_peft_model(
    AutoModelForCausalLM.from_pretrained(
        student_model_path,
        token=HF_TOKEN,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config_4bit,
        attn_implementation="eager",
        device_map={"": 0},  # alternatively None plus the .to(DEVICE) below
    ),
    peft_config,
)
student_model.to(DEVICE)

# Judge (evaluation) model: loaded in bfloat16 without quantization.
judge_tokenizer = AutoTokenizer.from_pretrained(judge_model_path, token=HF_TOKEN)
judge_model = AutoModelForCausalLM.from_pretrained(
    judge_model_path,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    device_map={"": 0},  # shares the same GPU / VRAM as the student!
)
judge_model.to(DEVICE)

In [None]:
# ========= 데이터셋 준비 =========
from datasets import load_dataset, Dataset
import re

# System message asking for the answer to be laid out as a <reasoning> section
# followed by an <answer> section. Written line by line; the resulting string
# starts and ends with a newline.
SYSTEM_PROMPT = (
    "\n"
    "Respond in the following format:\n"
    "<reasoning>\n"
    "...\n"
    "</reasoning>\n"
    "<answer>\n"
    "...\n"
    "</answer>\n"
)

def extract_hash_answer(text: str):
    """Return the text that follows the first "####" marker, or None.

    Only the segment between the first marker and the next marker (if any)
    is returned, with surrounding whitespace removed. When `text` contains
    no "####" at all the result is None.
    """
    _, marker, tail = text.partition("####")
    if not marker:
        return None
    return tail.partition("####")[0].strip()

def extract_xml_answer(text: str) -> str:
    """Pull the content of the last <answer>...</answer> section out of `text`.

    Takes everything after the last "<answer>" and cuts it at the first
    following "</answer>"; missing tags simply leave that side uncut. The
    result is whitespace-stripped. Any error while processing (for example
    a non-string input) yields an empty string.
    """
    try:
        after_open = text.rpartition("<answer>")[2]
        inner = after_open.partition("</answer>")[0]
        return inner.strip()
    except Exception:
        return ""

def get_gsm8k_questions(split="train"):
    """Load a split of `openai/gsm8k` ("main" config) and add training fields.

    Every example receives:
      * `prompt` - a two-message chat: SYSTEM_PROMPT as the system message
        and the example's `question` as the user message;
      * `answer` - the gold answer, i.e. the example's original `answer`
        text passed through `extract_hash_answer`.

    Returns the mapped dataset.
    """
    def _to_training_example(row):
        chat = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['question']},
        ]
        return {'prompt': chat, 'answer': extract_hash_answer(row['answer'])}

    return load_dataset('openai/gsm8k', 'main', split=split).map(_to_training_example)

# Training data: the GSM8K "train" split (the function's default), spelled out explicitly.
dataset = get_gsm8k_questions(split="train")

In [None]:
def extract_xml_answer(text: str) -> str:
    """Return the stripped content of the last <answer>...</answer> section.

    Everything after the last "<answer>" is taken and cut at the first
    following "</answer>"; a missing tag leaves that side uncut.

    This definition replaces any earlier `extract_xml_answer` in the
    notebook once its cell runs, so it keeps the same contract for bad
    input: a non-string `text` (e.g. None) yields "" instead of raising.
    """
    if not isinstance(text, str):
        return ""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

import json
import re
import torch

def _parse_judge_score(raw_text, min_score=0.0, max_score=30.0):
    """Turn the judge's raw reply into a score clamped to [min_score, max_score].

    The first (optionally negative) integer found anywhere in the reply is
    used, so both "24" and "Score: 24" parse to 24. A reply without any
    number scores 0.0.
    """
    match = re.search(r'-?\d+', raw_text.strip())
    score = float(match.group(0)) if match else 0.0
    if score > max_score:
        print(f"[INFO] Score exceeds {max_score:g}, capping at maximum.")
        score = max_score
    elif score < min_score:
        print("[INFO] Score is negative, setting to 0.")
        score = min_score
    return score


def agent_reward_func(prompts, completions, answer, judge_json=None, **kwargs):
    """Score each completion with the judge model and return normalized rewards.

    Args:
        prompts: One chat (list of message dicts) per sample. The content of
            the last message is used as the question when its role is 'user';
            otherwise the question is left empty.
        completions: One list per sample; the 'content' of its first message
            is the student's answer.
        answer: Gold answers, aligned with `prompts` / `completions`.
        judge_json: Optional path. When given, one JSON line holding the judge
            input and the raw judge output is appended per sample.
        **kwargs: Ignored.

    Returns:
        A list with one float in [0, 1] per sample: the judge's 0-30 score
        after min-max normalization. Empty answers and samples whose judge
        call raised an exception receive 0.0.

    Relies on the module-level `judge_tokenizer` and `judge_model`.
    """
    results = []
    min_score = 0.0
    max_score = 30.0
    for p, c, a in zip(prompts, completions, answer):
        student_output = c[0]['content']
        user_prompt = p[-1]['content'] if p and isinstance(p, list) and p[-1]['role'] == 'user' else ''
        gold_answer = a

        # Reset the per-sample state so that nothing leaks in from the
        # previous iteration (in particular into the log entry below).
        judge_input = None
        score_norm = 0.0

        if not student_output or student_output.strip() == "":
            # Empty answer: score 0 without calling the judge.
            eval_result = "0"
        else:
            judge_input = f"""Question: {user_prompt}
            Correct answer: {gold_answer}
            Student's answer: {student_output}

            Evaluate the student's answer based on the following criteria:
            - Reasoning process (50%)
            - Correctness of final answer (50%)

            Return ONLY a single integer score ranging from 0 to 30 (including all integers in between: 0, 1, 2, 3, 4, 5..., 29, 30) based on these criteria.
            (0 means completely incorrect, 30 means perfectly correct)
            DO NOT write any explanation, ONLY the score.
            DO NOT include any prefix like "Score:" — output ONLY the number.
            Example: 24
            Example: 17
            """

            try:
                input_ids = judge_tokenizer(
                    judge_input, return_tensors="pt", truncation=True, max_length=1024
                )
                input_ids = {k: v.to(judge_model.device) for k, v in input_ids.items()}
                with torch.no_grad():
                    # NOTE(review): `temperature` only matters when sampling is
                    # enabled (e.g. through the model's generation config) -
                    # confirm the intended decoding mode for the judge.
                    outputs = judge_model.generate(
                        **input_ids,
                        max_new_tokens=5,  # Limit to just enough tokens for a number
                        temperature=0.1,
                        pad_token_id=judge_tokenizer.eos_token_id,
                    )
                # Decode only the newly generated tokens (everything after the prompt).
                eval_result = judge_tokenizer.decode(
                    outputs[0, input_ids['input_ids'].shape[1]:], skip_special_tokens=True
                ).strip()
                print("[Raw Judge Output]:", repr(eval_result))

                score = _parse_judge_score(eval_result, min_score, max_score)

                # Min-Max normalization
                score_norm = (score - min_score) / (max_score - min_score)
                print(f"Score (raw reward): {score}")
                print(f"Score (normalized reward): {score_norm}")

            except Exception as e:
                # Best effort: a failing judge call must not abort training.
                print("[ERROR] Judge evaluation error:", e)
                score_norm = 0.0
                eval_result = ""

        # --- Log judge_input & eval_result ---
        if judge_json is not None:
            log_data = {
                "judge_input": judge_input if judge_input is not None else "Empty student response - automatic 0 score",
                "judge_output": eval_result,
            }
            with open(judge_json, "a", encoding="utf-8") as f:
                f.write(json.dumps(log_data, ensure_ascii=False) + "\n")
        results.append(score_norm)

    return results

In [6]:
def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.25 to every completion whose text contains at least one digit.

    The text inspected is the 'content' of the first message of each
    completion; completions without any digit get 0.0.
    """
    digit_run = re.compile(r'\d+')
    texts = [completion[0]['content'] for completion in completions]
    return [0.25 if digit_run.search(text) else 0.0 for text in texts]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.25 to completions that follow the exact XML layout.

    The whole text must be a <reasoning> block followed by an <answer>
    block, with every tag on its own line and a newline after the closing
    </answer>. Anything else gets 0.0.
    """
    strict_layout = re.compile(
        r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$",
        re.DOTALL,
    )
    texts = [completion[0]["content"] for completion in completions]

    def _score(text):
        return 0.25 if strict_layout.search(text) else 0.0

    return list(map(_score, texts))

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.125 to completions that loosely follow the XML layout.

    It is enough that the text contains, somewhere, a <reasoning>...</reasoning>
    section followed (after optional whitespace) by an <answer>...</answer>
    section. Completions without that pair get 0.0.
    """
    loose_layout = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)

    rewards = []
    for completion in completions:
        matched = loose_layout.search(completion[0]["content"]) is not None
        rewards.append(0.125 if matched else 0.0)
    return rewards

def count_xml(text) -> float:
    """Return 0.0625 for each of the four format tags that occurs in `text`.

    The tags checked are <reasoning>, </reasoning>, <answer> and </answer>,
    so the result lies between 0.0 and 0.25. Empty or missing text gives 0.0.
    """
    if not text:
        return 0.0

    expected_tags = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
    # Start the sum at 0.0 so the result is a float even when no tag is present.
    return sum((0.0625 for tag in expected_tags if tag in text), 0.0)

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Reward each completion by how many of the format tags it contains.

    Applies `count_xml` to the 'content' of the first message of every
    completion and returns the scores in the same order.
    """
    rewards = []
    for completion in completions:
        rewards.append(count_xml(completion[0]["content"]))
    return rewards


In [None]:
from trl import GRPOConfig, GRPOTrainer

# GRPO training configuration, grouped by concern.
training_args = GRPOConfig(
    # Run control and outputs
    output_dir="outputs",
    do_train=True,
    max_steps=1000,
    save_steps=1000,
    logging_steps=1,
    report_to="none",
    # Optimizer and learning-rate schedule
    optim="adamw_bnb_8bit",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    max_grad_norm=0.1,
    # Precision
    bf16=True,
    fp16=False,
    # Batching and generation
    use_vllm=False,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_generations=4,
    max_prompt_length=256,
    max_completion_length=250,
)

In [None]:
import os

# Location of the JSONL file that receives the judge's inputs and outputs.
judge_dir = "judge"
judge_json_path = os.path.join(judge_dir, "judge_evaluations_metric(0~30).jsonl")

# Create the judge folder if it does not exist yet.
os.makedirs(judge_dir, exist_ok=True)

In [None]:
from transformers import TrainerCallback
import os

class BestModelSaver(TrainerCallback):
    """Save the adapter each time the running mean reward reaches a new high.

    The running mean is the sum of every logged 'reward' value divided by
    the number of such log entries seen so far (not a per-step reward).

    Args:
        model: Optional model to save. When omitted, the `model` passed to
            the callback by the trainer is used, falling back to the
            module-level `student_model`.
        save_dir: Folder that receives the best adapter.

    Attributes:
        best_mean_reward: Highest running mean seen so far.
        best_step: Global step at which it was reached (-1 if none yet).
        cumulative_reward: Sum of all logged rewards.
        steps: Global steps of the log entries that carried a reward.
        mean_rewards: Running mean after each of those entries.
    """

    def __init__(self, model=None, save_dir="./best_adapter(metric1)"):
        self.model = model
        self.save_dir = save_dir
        self.best_mean_reward = float("-inf")
        self.best_step = -1
        self.cumulative_reward = 0.0
        self.steps = []
        self.mean_rewards = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Update the running mean reward and save the model on a new best."""
        # `logs` defaults to None and not every log entry carries a reward.
        if not logs or 'reward' not in logs:
            return

        step = state.global_step
        self.cumulative_reward += logs['reward']
        self.steps.append(step)

        # (cumulative reward) / (number of reward logs so far)
        mean_reward = self.cumulative_reward / len(self.steps)
        self.mean_rewards.append(mean_reward)

        if mean_reward <= self.best_mean_reward:
            return

        self.best_mean_reward = mean_reward
        self.best_step = step

        # Prefer an explicitly supplied model, then the one handed to the
        # callback, and only then the notebook-level global.
        model = self.model if self.model is not None else kwargs.get("model")
        if model is None:
            model = student_model

        os.makedirs(self.save_dir, exist_ok=True)

        # Remove the files of the previous best model. Sub-directories are
        # left alone: os.remove cannot delete them and would raise.
        for entry in os.listdir(self.save_dir):
            path = os.path.join(self.save_dir, entry)
            if os.path.isfile(path) or os.path.islink(path):
                os.remove(path)

        # Save the new best model.
        model.save_pretrained(self.save_dir)
        print(f"Best model saved at step {step} with mean reward {mean_reward:.4f}")

    def print_best_result(self):
        """Print the best running mean reward and its step, if any was recorded."""
        if self.best_step != -1:
            print(f"Best mean reward of {self.best_mean_reward:.4f} was achieved at step {self.best_step}.")

def judge_reward_func(*args, **kwargs):
    """Judge-based reward: `agent_reward_func` with logging to `judge_json_path`.

    Defined as a named function rather than a lambda so that anything which
    reports the callable's `__name__` shows a meaningful label instead of
    "<lambda>".
    """
    return agent_reward_func(*args, **kwargs, judge_json=judge_json_path)


# Callback that keeps a copy of the best adapter seen during training.
best_model_saver = BestModelSaver()

trainer = GRPOTrainer(
    model=student_model,
    processing_class=student_tokenizer,
    reward_funcs=[
        judge_reward_func,
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    callbacks=[best_model_saver],  # register the callback
)

trainer.train()

best_model_saver.print_best_result()  # report the step with the best mean reward
