# Qwen 0.5b on GRPO

---
This notebook is a modified version of [Will Brown](https://x.com/willccbb)'s [GRPO demo](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb), which trains llama-1b on the gsm8k math dataset.

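We made the following changes so that it works better on Colab:
* Replaced llama-1b with Qwen-0.5b
* Switched generation to vllm (a significant speedup). Because Qwen is small, vllm can run on the same GPU that is used for GRPO training
* Removed flash-attn (it repeatedly caused bugs in the Qwen modeling code; the cause is unclear)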

---
Ref:
- [Original Notebook](https://colab.research.google.com/drive/1bfhs1FMLW3FGa8ydvkOZyBNxLYOu0Hev?usp=sharing#scrollTo=Q7qTZbUcg5VD)
- [Run your local code as a SageMaker training job](https://docs.aws.amazon.com/sagemaker/latest/dg/train-remote-decorator.html)
    - [Quick Start - Run local code as SageMaker training job](https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker-remote-function/quick_start/quick_start.ipynb)
    - [Train a Pre-trained Huggingface Model](https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker-remote-function/huggingface_text_classification/huggingface.ipynb)
---


## 1. Environment Setup
- To get started, follow this guide: [Setup Guide](../setup/README.md)
- Then run the cell below to confirm that the required packages are installed.

In [10]:
! pip list | grep -E "sagemaker|vllm|trl|datasets"

datasets                          3.2.0
sagemaker                         2.239.0
sagemaker-core                    1.0.21
trl                               0.14.0
vllm                              0.7.2
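The same check can be done from Python without `grep` (for example, on Windows). This is a minimal sketch using the standard-library `importlib.metadata`; `installed_versions` is a hypothetical helper name, not part of any library used here.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Packages this notebook relies on (the same set as the grep filter above).
REQUIRED = ["sagemaker", "vllm", "trl", "datasets"]


def installed_versions(packages: list[str]) -> dict[str, str | None]:
    """Return {package: version string}, or None for packages that are not installed."""
    result: dict[str, str | None] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(REQUIRED).items():
        print(f"{name:10s} {ver or 'NOT INSTALLED'}")
```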


## Define the System Prompt and CoT Format

In [11]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

# Output format the model is asked to follow (the dataset is loaded in the next cell)

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

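Filling in `XML_COT_FORMAT` shows the exact completion layout that the format rewards look for. This minimal sketch re-declares the template and the same strict regex that `strict_format_reward_func` uses, so it runs standalone; the sample reasoning text is made up. It also shows why `re.DOTALL` matters: without it, `.` does not match newlines, so multi-line reasoning fails the strict check.

```python
import re

# Same template as XML_COT_FORMAT above.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# Same pattern as strict_format_reward_func: one tag per line, trailing newline.
STRICT_PATTERN = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

# Single-line reasoning matches with or without re.DOTALL.
example = XML_COT_FORMAT.format(reasoning="3 apples plus 4 apples is 7 apples.", answer="7")
print(example)
print(bool(re.match(STRICT_PATTERN, example, flags=re.DOTALL)))

# Multi-line reasoning only matches when '.' is allowed to cross newlines.
multi = XML_COT_FORMAT.format(reasoning="Step 1: 3 + 4 = 7.\nStep 2: done.", answer="7")
print(bool(re.match(STRICT_PATTERN, multi)))                   # '.' stops at "\n"
print(bool(re.match(STRICT_PATTERN, multi, flags=re.DOTALL)))  # '.' matches "\n"
```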
## Prepare the Dataset

In [12]:
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Build chat-style prompts (system + user) and keep only the final answer after '####'
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()
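To check the mapping without downloading GSM8K, the sketch below applies the same transformation as the `data.map` lambda to one hand-written row. The row is invented but follows the GSM8K convention of ending the solution with `#### <final answer>`; `to_grpo_record` is a hypothetical helper and `SYSTEM_PROMPT` here is an abbreviated stand-in for the real prompt.

```python
from typing import Optional

SYSTEM_PROMPT = "Respond in the following format: ..."  # abbreviated stand-in


def extract_hash_answer(text: str) -> Optional[str]:
    # Same logic as above: the gold answer follows the '####' marker.
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def to_grpo_record(x: dict) -> dict:
    """Same mapping as the lambda in get_gsm8k_questions, applied to a single row."""
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    }


row = {
    "question": "Tom has 3 apples and buys 4 more. How many apples does he have?",
    "answer": "Tom has 3 + 4 = <<3+4=7>>7 apples.\n#### 7",
}
rec = to_grpo_record(row)
print(rec["prompt"][1]["content"])
print(rec["answer"])
```

`Dataset.map` merges the returned keys into each row, so in the real dataset the original `answer` column is overwritten by the extracted answer while `question` is kept alongside the new `prompt` column.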

## Define the Reward Functions

In [13]:
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """0.5 if the completion matches the exact layout: each tag on its own line, ending with a newline."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets '.' match newlines, so multi-line reasoning is accepted.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """0.5 if the completion starts with <reasoning>...</reasoning> followed by <answer>...</answer>."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
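To make the reward scale concrete, the sketch below scores a single completion string with local copies of `extract_xml_answer` and `count_xml`, plus the same two regex checks. It works on plain strings rather than TRL's list-of-messages `completions` structure, `reward_breakdown` is a hypothetical helper, and the regexes are applied with `re.DOTALL` so that multi-line reasoning is accepted.

```python
from __future__ import annotations

import re


def extract_xml_answer(text: str) -> str:
    return text.split("<answer>")[-1].split("</answer>")[0].strip()


def count_xml(text: str) -> float:
    # Copy of count_xml above: +0.125 per well-placed tag, minus a small
    # penalty for every character that follows the closing </answer>.
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


STRICT = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
SOFT = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"


def reward_breakdown(completion: str, gold: str) -> dict[str, float]:
    """Score one completion the way the five reward functions above do."""
    extracted = extract_xml_answer(completion)
    return {
        "xmlcount": count_xml(completion),
        "soft_format": 0.5 if re.match(SOFT, completion, flags=re.DOTALL) else 0.0,
        "strict_format": 0.5 if re.match(STRICT, completion, flags=re.DOTALL) else 0.0,
        "int": 0.5 if extracted.isdigit() else 0.0,
        "correctness": 2.0 if extracted == gold else 0.0,
    }


perfect = "<reasoning>\n3 + 4 = 7\n</reasoning>\n<answer>\n7\n</answer>\n"
scores = reward_breakdown(perfect, gold="7")
print(scores, sum(scores.values()))  # total 4.0
```

A perfectly formatted, correct answer earns the maximum of 4.0: 0.5 from each of the four format and integer rewards plus 2.0 for correctness. Appending text after `</answer>` fails the strict check, and each trailing character is penalized by both `</answer>` checks in `count_xml` (about 0.002 per character).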

## Define Training Parameters
- Model definition
- Training parameters

In [14]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

output_dir="outputs/Qwen-0.5B-GRPO"
run_name="Qwen-0.5B-GRPO-gsm8k"

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=200,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=True,
    vllm_gpu_memory_utilization=.3,
    vllm_device="cuda:0",
    report_to="none"  # Disable Weights & Biases (wandb) logging
)
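With `num_generations=16`, each prompt is sampled 16 times, and GRPO judges each completion relative to its siblings instead of using a learned value model. The sketch below shows that group-relative normalization on plain Python lists: each completion's summed reward is centered on its group's mean and divided by the group's standard deviation. It approximates what TRL's `GRPOTrainer` computes; the `eps` value and the use of the sample standard deviation are assumptions that may differ from the installed TRL version, and `group_relative_advantages` is a hypothetical helper, not a TRL API.

```python
from __future__ import annotations

import statistics


def group_relative_advantages(rewards: list[float], group_size: int, eps: float = 1e-4) -> list[float]:
    """Normalize each reward against the other completions sampled for the same prompt.

    `rewards` is laid out prompt-major: the first `group_size` entries belong to
    prompt 0, the next `group_size` to prompt 1, and so on.
    """
    if group_size < 2:
        raise ValueError("group_size must be at least 2 to compute a standard deviation")
    if len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a multiple of group_size")
    advantages: list[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.fmean(group)
        std = statistics.stdev(group)  # sample (n - 1) standard deviation
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# One prompt with 4 generations (the notebook uses 16): two correct, two wrong.
print(group_relative_advantages([4.0, 4.0, 0.5, 0.5], group_size=4))
```

Correct completions get positive advantages and wrong ones negative. If every completion for a prompt earns the same reward (for example, all wrong), all of its advantages are 0 and the reward term contributes nothing for that prompt; the partial-credit format rewards plausibly help avoid such flat groups early in training.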



## Load the Model and Tokenizer

In [15]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

## Define the S3 Root Folder

In [16]:
import sagemaker

sm_session = sagemaker.Session()

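# Note: this value is not passed to @remote below, so the job uses the session's default bucket (see the logs).
s3_root_folder = f"s3://{sm_session.default_bucket()}/deepseek/grpo"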


## Run as a SageMaker Training Job with `@remote`

In [17]:
from sagemaker.remote_function import remote

@remote(instance_type="ml.p4d.24xlarge", dependencies='./requirements.txt', keep_alive_period_in_seconds=3600)
def wrap_hf_trainer():
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func],
        args=training_args,
        train_dataset=dataset,
        #peft_config=peft_config
    )
    trainer.train()

In [18]:
wrap_hf_trainer()

2025-02-09 02:33:05,846 sagemaker.remote_function INFO     Serializing function code to s3://sagemaker-us-east-1-057716757052/wrap-hf-trainer-2025-02-09-02-33-05-846/function
2025-02-09 02:33:13,324 sagemaker.remote_function INFO     Serializing function arguments to s3://sagemaker-us-east-1-057716757052/wrap-hf-trainer-2025-02-09-02-33-05-846/arguments
2025-02-09 02:33:13,494 sagemaker.remote_function INFO     Copied dependencies file at './requirements.txt' to '/tmp/tmpd2uad8ai/temp_workspace/sagemaker_remote_function_workspace/requirements.txt'
2025-02-09 02:33:13,495 sagemaker.remote_function INFO     Successfully created workdir archive at '/tmp/tmpd2uad8ai/workspace.zip'
2025-02-09 02:33:13,543 sagemaker.remote_function INFO     Successfully uploaded workdir to 's3://sagemaker-us-east-1-057716757052/wrap-hf-trainer-2025-02-09-02-33-05-846/sm_rf_user_ws/workspace.zip'
2025-02-09 02:33:13,544 sagemaker.remote_function INFO     Creating job: wrap-hf-trainer-2025-02-09-02-33-05-846


2025-02-09 02:33:14 Starting - Starting the training job
2025-02-09 02:33:14 Pending - Training job waiting for capacity......
2025-02-09 02:34:16 Pending - Preparing the instances for training...........................
2025-02-09 02:38:39 Downloading - Downloading the training image.................................
2025-02-09 02:44:08 Training - Training image download completed. Training in progress......INFO: CONDA_PKGS_DIRS is set to '/opt/ml/sagemaker/warmpoolcache/sm_remotefunction_user_dependencies_cache/conda/pkgs'
INFO: PIP_CACHE_DIR is set to '/opt/ml/sagemaker/warmpoolcache/sm_remotefunction_user_dependencies_cache/pip'
INFO: /opt/ml/input/config/resourceconfig.json:
{"current_host":"algo-1","current_instance_type":"ml.p4d.24xlarge","current_group_name":"homogeneousCluster","hosts":["algo-1"],"instance_groups":[{"instance_group_name":"homogeneousCluster","instance_type":"ml.p4d.24xlarge","hosts":["algo-1"]}],"network_interface_name":"eth0"}INFO: Bootstraping runtime environ