<a href="https://colab.research.google.com/github/ickma2311/mycolab/blob/main/LLM_GRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GRPO Explanation


[CN](https://photos.google.com/photo/AF1QipMyeNzqcwt2s9b-tpfxHpky1dTBSf7yfm1GOtGn)

## PPO’s Clipped Objective

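The Proximal Policy Optimization (PPO) clipped surrogate objective, which training maximizes (the loss minimized in code is its negative), is defined as:

$$
L^{\mathrm{CLIP}}(\theta) = \mathbb{E}\Bigl[\min\bigl(\rho(\theta)\,A,\;\mathrm{clip}(\rho(\theta),\,1-\epsilon,\,1+\epsilon)\,A\bigr)\Bigr].
$$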

### Components
* Probability ratio

$\rho(\theta) = \frac{\pi_\theta(a\mid s)}{\pi_{\theta_{\rm old}}(a\mid s)}$

measures how much more (or less) likely the new policy $\pi_\theta$ is to take action $a$ in state $s$ than the old policy $\pi_{\theta_{\rm old}}$ that collected the data.


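* Advantage $A$

$A = r + \gamma\,V(s') - V(s)$

is the one-step temporal-difference estimate, where $r$ is the immediate reward, $\gamma$ the discount factor, and $V$ a learned value function (other estimators can be used). It measures how much better taking action $a$ in state $s$ turned out than the baseline value $V(s)$.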

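* Clipping

  * Unclipped term: $\rho(\theta)\,A$
  * Clipped term: $\mathrm{clip}(\rho(\theta),\,1-\epsilon,\,1+\epsilon)\,A$
  * Taking the minimum of the two gives a pessimistic (lower) bound on the unclipped objective: once $\rho$ leaves $[1-\epsilon,1+\epsilon]$ in the direction that would increase the objective, the objective stops growing, so the update is limited.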

* Expectation
$\mathbb{E}[\cdot]$ indicates averaging over all sampled state–action pairs (e.g., a batch).
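The components above can be put together in a minimal pure-Python sketch. It assumes per-sample log-probabilities are available, so the ratio is computed as the exponential of a log-probability difference (the usual numerically stable form), and it uses the one-step TD estimate for $A$. The function names `ppo_clip_objective` and `td_advantage` are hypothetical, and `eps=0.2` is a common default rather than a value taken from this notebook.

```python
import math

def ppo_clip_objective(new_logps, old_logps, advantages, eps=0.2):
    """Batch-averaged PPO clipped surrogate objective (to be maximized).

    new_logps / old_logps: log pi(a|s) under the current and the old policy.
    advantages: one advantage estimate per sampled (s, a) pair.
    """
    terms = []
    for new_lp, old_lp, adv in zip(new_logps, old_logps, advantages):
        ratio = math.exp(new_lp - old_lp)            # rho = pi_new / pi_old
        clipped = min(max(ratio, 1 - eps), 1 + eps)  # clip(rho, 1-eps, 1+eps)
        terms.append(min(ratio * adv, clipped * adv))
    return sum(terms) / len(terms)                   # E[...] as a batch mean

def td_advantage(reward, value_s, value_next, gamma=0.99):
    """One-step TD estimate: A = r + gamma * V(s') - V(s)."""
    return reward + gamma * value_next - value_s

adv = td_advantage(reward=1.0, value_s=0.5, value_next=0.5, gamma=0.9)
obj = ppo_clip_objective([math.log(1.5), math.log(0.5)], [0.0, 0.0], [1.0, -1.0])
print(round(adv, 2), round(obj, 2))  # 0.95 0.2
```

In an autograd framework the same expression would be written with tensors and the optimizer would minimize `-obj`.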

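### Intuition
1. Small updates behave like the vanilla policy gradient.
When $\rho(\theta)$ stays inside $[1-\epsilon,1+\epsilon]$ (in particular at the start of each update, where $\rho=1$), both terms are equal and the objective is simply $\rho A$, whose gradient at $\theta=\theta_{\rm old}$ is the standard policy-gradient term $A\,\nabla_\theta\log\pi_\theta(a\mid s)$.
2. Preventing large updates.
If $\rho$ deviates from 1 by more than $\epsilon$, the clipped term caps it at $1\pm\epsilon$.
   * For $A>0$, once $\rho>1+\epsilon$ the minimum selects the clipped term $(1+\epsilon)A$, which is constant in $\theta$, so there is no further push to raise the action's probability.
   * For $A<0$, once $\rho<1-\epsilon$ the minimum again selects the clipped term $(1-\epsilon)A$, so there is no further push to lower the action's probability.
   * In the opposite cases ($A>0$ with $\rho<1-\epsilon$, or $A<0$ with $\rho>1+\epsilon$) the unclipped term is smaller and is kept, so a move in the wrong direction is still fully penalized.
3. Stable, efficient learning.
The ratio acts as an importance weight, so one batch collected by $\pi_{\theta_{\rm old}}$ can be reused for several gradient steps, while clipping keeps those weights from pushing the policy too far. The result is a trust-region-like constraint that needs only first-order optimization, without second-order methods.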
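The case analysis above can be checked numerically with a small sketch: it evaluates the per-sample term and estimates its slope with respect to $\rho$ by central finite differences. A slope of zero means the objective gives no incentive to move $\rho$ further. The helper names and `eps=0.2` are illustrative assumptions.

```python
def clipped_term(ratio, adv, eps=0.2):
    """Per-sample PPO term: min(rho * A, clip(rho, 1-eps, 1+eps) * A)."""
    clipped = min(max(ratio, 1 - eps), 1 + eps)
    return min(ratio * adv, clipped * adv)

def slope_wrt_ratio(ratio, adv, eps=0.2, h=1e-6):
    """Central finite-difference estimate of d(term)/d(rho)."""
    return (clipped_term(ratio + h, adv, eps) - clipped_term(ratio - h, adv, eps)) / (2 * h)

cases = [
    (1.0, 1.0),   # inside the range: ordinary policy-gradient term
    (1.5, 1.0),   # A > 0, rho above 1+eps: clipped, no push to grow rho further
    (0.5, -1.0),  # A < 0, rho below 1-eps: clipped, no push to shrink rho further
    (1.5, -1.0),  # A < 0, rho above 1+eps: unclipped, still penalized
    (0.5, 1.0),   # A > 0, rho below 1-eps: unclipped, still pushed back up
]
for ratio, adv in cases:
    print(f"rho={ratio}, A={adv:+}: term={clipped_term(ratio, adv):+.2f}, "
          f"slope={slope_wrt_ratio(ratio, adv):+.2f}")
```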


# Evaluate the Original Model

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt


In [None]:
%%capture

from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
max_seq_length = 4096 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.3, # Reduce if out of memory
)


## Evaluate


In [None]:
# %%capture
!pip install fsspec==2023.9.2

from datasets import load_dataset


ds = load_dataset("TIGER-Lab/MMLU-Pro")

In [None]:
test_ds=ds['test'].select(range(1000))

In [None]:
from string import ascii_uppercase
ascii_uppercase

In [None]:
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

"""
from tqdm import tqdm
from string import ascii_uppercase

text=[]
for item in tqdm(test_ds):
    t=f"The following are single-choice questions about {item['category']}. Answer with the letter of the correct option, e.g. A/B/C."
    t+=SYSTEM_PROMPT
    t+=f'\nQuestion: {item["question"]}'
    choices=item['options']
    for index,choice in zip(ascii_uppercase[:len(choices)],choices):
        t+=f'\n{index}) {choice}'
    t+='\n Answer:'

    text.append(t)

# SYSTEM_PROMPT is already embedded in each user message above, so no system role is used here.
text_with_template=[
    tokenizer.apply_chat_template(
        [{"role":"user", "content":t}],
        tokenize=False,
        add_generation_prompt=True,
    )
    for t in text
]

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.1,
    top_p = 0.95,
    max_tokens = 1024,
)

outputs=model.fast_generate(
    text_with_template,
    sampling_params = sampling_params,
    lora_request = None,
    use_tqdm=True

)




In [None]:

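def parse_answer(output_):
  # Return the text between <answer> and </answer>, or '' if the model omitted the tag
  # (otherwise the whole completion would be returned and could match by accident).
  if '<answer>' not in output_:
    return ''
  return output_.split('<answer>')[-1].split('</answer>')[0].strip()


def correctness(output_):
  # Accuracy on test_ds: compare the first character of the parsed answer
  # with the gold option letter (e.g. 'A').
  correct=0
  for o,a in zip(output_,test_ds['answer']):
    answer=parse_answer(o.outputs[0].text)
    if answer and answer[0]==a:
      correct+=1
  return correct/len(output_)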



In [None]:
correctness(outputs)

# GRPO Training

In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

In [None]:

def correctness_reward(completions,answer,**kwargs):
  return [5 if a==parse_answer(c) else 0 for (c,a) in zip(completions,answer)]

def len_reward(completions,**kwargs):
  return [len(c)//1000 if len(c)<3000 else 0 for c in completions]

def format_reward_cal(c,a):
  # 0.25 for each of the four required tags present in completion c.
  reward=0
  for tag in ('<reasoning>', '</reasoning>', '<answer>', '</answer>'):
    if tag in c:
      reward+=0.25
  # +1 if the text inside <answer> has the same length as the gold answer a
  # (a cheap proxy for "the answer field holds just the final number").
  if len(c.split('<answer>')[-1].split('</answer>')[0].strip())==len(a):
    reward+=1
  return reward

def format_reward(completions,answer,**kwargs):
  return [format_reward_cal(c,a) for c,a in zip(completions,answer)]

def extract_answer(a):
  # GSM8K solutions end with "#### <final answer>"; keep only the final answer.
  try:
    return a.split('####')[1].strip()
  except IndexError:
    return ''

train_dataset_=load_dataset('openai/gsm8k','main')
train_dataset=[]
for item in train_dataset_['train']:
  # Unlike the MMLU-Pro evaluation prompt, SYSTEM_PROMPT is passed in the system role here.
  prompt=tokenizer.apply_chat_template([
        {"role" : "system", "content" : SYSTEM_PROMPT},
        {"role" : "user", "content" : item['question']}],
      tokenize=False,
      add_generation_prompt=True)
  answer=extract_answer(item['answer'])
  train_dataset.append({'prompt':prompt,'answer':answer})
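As I understand TRL's `GRPOTrainer`, each function in `reward_funcs` is called with the completions plus the extra dataset columns (here `answer`) as keyword arguments, and the per-function scores are summed per completion (equal weights by default). The sketch below reproduces that summation on a few made-up completions, using compact copies of the notebook's reward functions; `total_rewards` is a hypothetical helper, and the `parse_answer` copy returns `''` when the `<answer>` tag is missing.

```python
def parse_answer(text):
    """Text between <answer> and </answer>, or '' if the tag is missing."""
    if '<answer>' not in text:
        return ''
    return text.split('<answer>')[-1].split('</answer>')[0].strip()

def correctness_reward(completions, answer, **kwargs):
    return [5 if a == parse_answer(c) else 0 for c, a in zip(completions, answer)]

def format_reward(completions, answer, **kwargs):
    rewards = []
    for c, a in zip(completions, answer):
        tags = ('<reasoning>', '</reasoning>', '<answer>', '</answer>')
        r = sum(0.25 for tag in tags if tag in c)  # 0.25 per tag present
        if len(c.split('<answer>')[-1].split('</answer>')[0].strip()) == len(a):
            r += 1                                   # answer length matches gold
        rewards.append(r)
    return rewards

def total_rewards(reward_funcs, completions, answer):
    """Sum each completion's rewards across functions (equal weights assumed)."""
    per_func = [f(completions=completions, answer=answer) for f in reward_funcs]
    return [sum(scores) for scores in zip(*per_func)]

completions = [
    "<reasoning>6 * 7 = 42</reasoning>\n<answer>42</answer>",  # correct, well formatted
    "<reasoning>6 * 7 = 48</reasoning>\n<answer>48</answer>",  # wrong, well formatted
    "The answer is 42.",                                        # no tags at all
]
print(total_rewards([correctness_reward, format_reward], completions, ["42"] * 3))
# [7.0, 2.0, 0]
```

With these two functions a completion can earn at most 7 (5 for correctness, 1 for the tags, 1 for the answer length), and a well-formatted wrong answer still earns 2, which separates it from an unformatted one.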




In [None]:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-5,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
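    per_device_train_batch_size = 64, # Counts completions: 64 / num_generations = 8 prompts per step
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 8, # Completions sampled per prompt; decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 1024,
    num_train_epochs = 1, # Ignored here because max_steps is set
    max_steps = 200,
    save_steps = 250, # Larger than max_steps, so no intermediate checkpoint is written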
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)
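GRPO's main departure from the PPO objective explained above is how $A$ is obtained: instead of a learned value function $V(s)$, each prompt is answered `num_generations` times and every completion's reward is compared with the other completions in its group. The sketch below assumes the mean/standard-deviation normalization described for GRPO; the sample standard deviation and the small epsilon (`1e-4`) are assumptions rather than values read from TRL's source, and `group_relative_advantages` is a hypothetical helper. It also assumes that, in the TRL version pinned here, `per_device_train_batch_size` counts completions rather than prompts.

```python
import statistics

def group_relative_advantages(rewards, num_generations, eps=1e-4):
    """Normalize each reward against the other completions of the same prompt.

    rewards: flat list ordered prompt by prompt, num_generations entries each.
    """
    if len(rewards) % num_generations != 0:
        raise ValueError("len(rewards) must be a multiple of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = statistics.mean(group)
        std = statistics.stdev(group)  # sample std of the group
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages

# With per_device_train_batch_size = 64 and num_generations = 8 (single GPU,
# no accumulation), each step scores 64 completions = 8 prompts x 8 samples.
batch_size, num_generations = 64, 8
print(batch_size // num_generations, "prompts per step")

# One hypothetical group of 8 rewards (correctness + format) for a single prompt.
rewards = [7.0, 7.0, 2.0, 2.0, 2.0, 0.0, 0.0, 2.0]
print([round(a, 2) for a in group_relative_advantages(rewards, num_generations)])
```

If every completion for a prompt receives the same reward (for example, all wrong and equally formatted), every advantage in that group is zero and the prompt contributes no policy-gradient signal; partial-credit rewards such as `format_reward` help by separating completions that would otherwise tie.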

In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        correctness_reward,
        format_reward,
        # len_reward,
    ],
    args = training_args,
    train_dataset = train_dataset,
)
trainer.train()

In [None]:
model.save_lora("grpo_saved_lora")

In [None]:
outputs=model.fast_generate(
    text_with_template,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
    use_tqdm=True
)


In [None]:
correctness(outputs)