# GRPO exploration

## Goal

Can we solve novel tasks using GRPO?

## Imports

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tqdm.auto import tqdm
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer, AutoConfig
import matplotlib.pyplot as plt
import matplotlib as mpl
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import logging
from IPython.display import Markdown, display
import torch
import random


from arc25.training_tasks import *
from arc25.encoders import create_grid_encoder
from arc25.prompting import create_prompt_from_task, pretty_print_prompt
from arc25.plot import plot_task, plot_grids_with_shape, plot_grid
from arc25.code_execution import safe_code_execution, validate_code
from arc25.utils import set_random_seed, get_timestamp
from arc25.logging import configure_logging, log_execution_time

configure_logging()

import sys
sys.path.append(os.path.realpath("../scripts"))
from finetuning import get_data_collator


# Create an empty figure and close it right away before changing rcParams.
# NOTE(review): presumably this forces the notebook plotting backend to
# initialize so the settings below are not overridden later — confirm.
plt.plot()
plt.close('all')
# Notebook-wide plot defaults: wide figures, thick lines, larger font.
plt.rcParams["figure.figsize"] = (20, 5)
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 12

In [None]:
# Sanity check: run a task whose body sleeps for 2 seconds through
# safe_code_execution and print any exception that it raises.
code = """def task(img):
    import time
    time.sleep(2)
"""
try:
    safe_code_execution(code, [create_img((3, 3))])
except Exception as e:
    print(f"Error: {e}")
    print("Code execution failed.")
# This works here; I don't understand why it does not behave the same way
# during training.

## Code

In [None]:
@log_execution_time
def load_model(base_model_path, lora_path):
    """Load the base causal LM and wrap it with a trainable LoRA adapter.

    Args:
        base_model_path: Location given to
            ``AutoModelForCausalLM.from_pretrained``.
        lora_path: Location of the LoRA checkpoint. The tokenizer is loaded
            from this same location.

    Returns:
        A ``(model, tokenizer)`` tuple where ``model`` is the ``PeftModel``
        built on top of the base model with ``is_trainable=True``.
    """
    logging.info(f"Loading model from {base_model_path} and LoRA from {lora_path}")
    # Release cached GPU memory before loading a new model.
    torch.cuda.empty_cache()
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype="auto",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(lora_path)
    peft_model = PeftModel.from_pretrained(
        base_model,
        lora_path,
        is_trainable=True,
    )
    return peft_model, tokenizer

## First experiments

In [None]:
@dataclass
class Config:
    """Paths and settings used by the cells of this notebook."""
    # Base model checkpoint that the LoRA adapter is applied to.
    base_model_path: str = '/home/gbarbadillo/models/Qwen2.5-Coder-0.5B-Instruct'
    # LoRA checkpoint; load_model also reads the tokenizer from this path.
    lora_path: str = '/mnt/hdd0/Kaggle/arc25/trainings/20250430_first_trainings/steps_6400/checkpoint-6400'
    # Prompt template version passed to create_prompt_from_task.
    prompt_version: str = 'code-from-examples-v3'
    # No type annotation, so this is a plain class attribute shared by every
    # instance rather than a dataclass field (it is not an __init__ argument).
    grid_encoder = create_grid_encoder('GridShapeEncoder(RowNumberEncoder(MinimalGridEncoder()))')
    # NOTE(review): the three settings below are not referenced by the cells
    # visible in this notebook — confirm whether they are still needed.
    max_epochs: int = 10
    n_predictions: int = 256 # 256 seems to be the best for my hardware
    use_accuracy_for_sorting: bool = True

# Single configuration instance used by the rest of the notebook.
cfg = Config()

In [None]:
# Build the '9-vertical-lines' task: a (9, 9) image of color 0 as input and,
# as output, a copy on which one vertical line is drawn per x position with
# color x + 1 (colors 1..9).
input_img = create_img((9, 9), color=0)
output_img = input_img.copy()
for x in range(0, input_img.shape[1], 1):
    draw_vertical_line(output_img, x, color=x+1)

# The task holds a single input/output pair and no reference code.
task = Task(inputs=[input_img], outputs=[output_img], code='', name='9-vertical-lines')
plot_task(task); plt.suptitle('Task to solve'); plt.tight_layout(); plt.show()

In [None]:
# Build the '20-squares' task: a (10, 8) image of color 0 as input and, as
# output, a copy covered with rectangles whose corners are (y, x) and
# (y+1, x+1), stepping x and y by 2.
input_img = create_img((10, 8), color=0)
output_img = input_img.copy()
color = 0
for x in range(0, input_img.shape[1], 2):
    for y in range(0, input_img.shape[0], 2):
        # Cycle through colors 1..9; color 0 is skipped because it is the
        # color the input image was created with.
        color = (color + 1) % 10
        if color == 0: color = 1
        draw_rectangle(output_img, (y, x), (y+1, x+1), color=color)
# Rebinds `task`, so when this cell runs after the previous one the later
# cells use this task.
task = Task(inputs=[input_img], outputs=[output_img], code='', name='20-squares')
plot_task(task); plt.suptitle('Task to solve'); plt.tight_layout(); plt.show()

In [None]:
# Load the base model with its trainable LoRA adapter, plus the tokenizer.
model, tokenizer = load_model(cfg.base_model_path, cfg.lora_path)

In [None]:
# This is a bit subtle: because we are not doing SFT, the prompt must be
# built with is_train_prompt=False.
prompt = create_prompt_from_task(
    task, prompt_version=cfg.prompt_version, grid_encoder=cfg.grid_encoder, tokenizer=tokenizer, is_train_prompt=False)
# The training dataset contains this single prompt only.
train_dataset = Dataset.from_dict({'prompt': [prompt]})

In [None]:
# Histories filled in by the reward functions: one entry per call.
all_completions = []
all_rewards = []


def reward_len(completions, **kwargs):
    """Dummy reward that scores every completion by its length.

    The received batch is also recorded in ``all_completions``.

    Args:
        completions: Batch of completions produced by the trainer.
        **kwargs: Extra arguments supplied by the trainer; ignored.

    Returns:
        A list with ``len(completion)`` for each completion, in order.
    """
    all_completions.append(completions)
    return list(map(len, completions))


def reward_accuracy(completions, **kwargs):
    """Reward each completion with the pixel accuracy of the code it contains.

    Every completion is executed with ``safe_code_execution`` on the inputs of
    the global ``task`` and its first output is compared with the first
    expected output. The batch of completions is recorded in
    ``all_completions`` and the batch of rewards in ``all_rewards``.

    Args:
        completions: Batch of completions produced by the trainer.
        **kwargs: Extra arguments supplied by the trainer; ignored.

    Returns:
        A list of floats in [0, 1], one per completion. The reward is 0.0 when
        the code fails to execute or when the predicted output does not have
        the same shape as the expected output.
    """
    all_completions.append(completions)
    accuracy = []
    expected_output = np.asarray(task.outputs[0])
    for completion in completions:
        # Drop the closing markdown code fence that follows the generated code.
        predicted_code = completion.replace('\n```', '')
        try:
            predicted_output = np.asarray(
                safe_code_execution(predicted_code, task.inputs)[0])
            if predicted_output.shape != expected_output.shape:
                # Without this check `==` would broadcast, so a wrongly shaped
                # output (e.g. a single row) could get a high or even perfect
                # score and stop the training early.
                accuracy.append(0.0)
            else:
                accuracy.append(float(np.mean(predicted_output == expected_output)))
        except Exception as e:
            # Best effort: any failure of the generated code is a zero reward.
            print(f'Error executing code: {predicted_code}')
            print(e)
            accuracy.append(0.0)
    all_rewards.append(accuracy)
    return accuracy

# Hyperparameters of the GRPO run. The output directory name encodes the
# timestamp, the task name, the learning rate and the number of generations.
learning_rate = 1e-6
num_generations = 8
num_train_epochs = 400
per_device_train_batch_size = 4
output_dir = f"/mnt/hdd0/Kaggle/arc25/trainings/20250508_GRPO/{get_timestamp()}_{task.name}_lr{learning_rate:.0e}_n{num_generations}"
training_args = GRPOConfig(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    logging_steps=1,
    num_generations=num_generations, # default is 8
    # Accumulate gradients so that the effective batch size equals the number
    # of generations.
    gradient_accumulation_steps=num_generations//per_device_train_batch_size,
    per_device_train_batch_size=per_device_train_batch_size, # default is 8
    generation_batch_size=num_generations, # needs the latest version of trl
    temperature=0.5, # default is 0.9
    top_p=0.95, # default is 1.0
    max_completion_length=2048, # default is 256
    max_prompt_length=8192, # default is 512
    learning_rate=learning_rate, # default is 1e-6
    lr_scheduler_type='constant', # default is 'linear'
    log_completions=True,
)

# Configure Weights & Biases through environment variables: a fixed project
# name, and a run name, job name and directory that all come from output_dir.
os.environ['WANDB_PROJECT'] = "20250508_GRPO"
os.environ['WANDB_NAME'] = os.path.basename(output_dir)
os.environ['WANDB_JOB_NAME'] = os.path.basename(output_dir)
os.environ['WANDB_DIR'] = output_dir

import wandb
from transformers import TrainerCallback, TrainerState, TrainerControl

class RewardExtremaCallback(TrainerCallback):
    """Report the max and min reward of the latest reward batch on every log.

    The values are read from the last entry of the global ``all_rewards``.
    """

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Add the reward extrema to ``logs`` and push them to W&B.

        Nothing is reported when ``logs`` is None or no rewards have been
        recorded yet. ``control`` is always returned unchanged.
        """
        if logs is not None and all_rewards:
            latest_rewards = all_rewards[-1]
            highest = float(np.max(latest_rewards))
            lowest = float(np.min(latest_rewards))

            # Put the values into the logs dict handed over by the trainer.
            logs["train/reward_accuracy/max"] = highest
            logs["train/reward_accuracy/min"] = lowest
            logging.info(f"Step {state.global_step}: Max reward: {highest:.2f}, Min reward: {lowest:.2f}")

            # Also send them straight to W&B.
            extrema = {
                "train/reward_accuracy/max": highest,
                "train/reward_accuracy/min": lowest,
            }
            wandb.log(extrema)

        return control

class StopOnRewardCallback(TrainerCallback):
    """Stop training once a reward of the latest batch reaches a threshold.

    The rewards are read from the last entry of the global ``all_rewards``.

    Args:
        threshold: Training stops when the maximum reward of the latest batch
            is greater than or equal to this value.
    """

    def __init__(self, threshold=1.0):
        super().__init__()
        self.threshold = threshold

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Request stop and save when the latest max reward meets the threshold.

        Returns ``control`` untouched when no rewards have been recorded yet.
        """
        if not all_rewards:
            # Same guard as RewardExtremaCallback: nothing to inspect yet.
            return control
        best_reward = float(np.max(all_rewards[-1]))
        if best_reward >= self.threshold:
            # Use the `logging` module like the rest of the notebook; no
            # `logger` object is defined in this file.
            logging.info(f"Stopping training at step {state.global_step} as max reward {best_reward:.2f} is above threshold {self.threshold:.2f}")
            control.should_training_stop = True
            control.should_save = True
        return control

# Train with reward_accuracy as the only reward function. The callbacks stop
# the run once a completion reaches a reward of 1.0 and report the max/min
# reward of the latest batch on every log event.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_accuracy,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    callbacks=[StopOnRewardCallback(threshold=1.0), RewardExtremaCallback()],
)
trainer.train()

In [None]:
# Print the GRPOConfig documentation to look up the available arguments.
help(GRPOConfig)

In [None]:
# all_completions holds one batch per reward-function call, so the printed
# number is the number of recorded batches and np.unique lists the distinct
# batch sizes.
print(f'Number of completions: {len(all_completions)}')
np.unique([len(completion) for completion in all_completions])

In [None]:
# Mean completion length (in characters) of each recorded batch; there is one
# point per reward-function call.
plt.plot([np.mean([len(completion) for completion in completions]) for completions in all_completions])
plt.xlabel('Epoch')
plt.ylabel('Mean completion length')
plt.grid();

In [None]:
# Show the most recently recorded batch of completions.
all_completions[-1]

In [None]:
# Highest reward seen over all recorded batches.
# NOTE(review): assumes every batch in all_rewards has the same length.
np.max(all_rewards)

In [None]:
# Max, mean and min reward of each recorded batch (one point per call to
# reward_accuracy).
plt.plot([np.max(accuracy) for accuracy in all_rewards])
plt.plot([np.mean(accuracy) for accuracy in all_rewards])
plt.plot([np.min(accuracy) for accuracy in all_rewards])
plt.grid()

In [None]:
# Index of the first batch whose best reward is exactly 1.
# NOTE: np.argmax also returns 0 when no batch reached 1, so check
# np.max(all_rewards) before trusting a result of 0.
np.argmax([np.max(accuracy) == 1 for accuracy in all_rewards])

## Learnings

Dummy reward experiments:

- Running a first dummy train with a reward based on output length for 100 epochs took around 21 minutes
- A second run after fixing the initial prompt took 17 minutes
- The effective batch size has to be greater than or equal to the number of generations.

True reward experiments:

- 100 epochs take 13 min
- 200 epochs take around 30 min, and after epoch 100 the model almost always generates the correct solution
- Increasing the learning rate to 2e-6 results in solving the task at epoch 20
- Increasing the learning rate to 4e-6 results in solving the task at epoch 11
- With a learning rate of 1e-5 the task is solved at epoch 3, but training collapses at epoch 20
- When trying the 16 squares task I get an OOM (out of memory) error
- I have increased the number of generations to 128, but it diverges on the 16 squares task. It ran
  for 46 epochs, which took 48 minutes.
- With lr=1e-6 and 128 generations, it takes 1h35 to do 100 epochs and it diverges. Let's go back to 8 generations and use more epochs.
- I'm not sure why I see timeout errors, but I increased the maximum number of generated tokens from 1024 to 2048 because it might be the root of the problem.

Constant (cst) learning rate with warmup

So far it seems that GRPO is much less efficient than HER. Or maybe I haven't found the right configuration.

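Sometimes I get a `TimeoutException` during training. In the traceback below, the exception is raised by `timeout_handler` in `arc25/code_execution.py` while the trainer is executing `_compute_loss`, that is, outside the `try` block of the reward function, which is why it is not caught there: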

```
---------------------------------------------------------------------------
TimeoutException                          Traceback (most recent call last)
Cell In[8], line 96
     86         return control
     88 trainer = GRPOTrainer(
     89     model=model,
     90     reward_funcs=reward_accuracy,
   (...)
     94     callbacks=[StopOnRewardCallback(threshold=1.0), RewardExtremaCallback()],
     95 )
---> 96 trainer.train()

File ~/miniconda3/envs/arc25/lib/python3.10/site-packages/transformers/trainer.py:2245, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2243         hf_hub_utils.enable_progress_bars()
   2244 else:
-> 2245     return inner_training_loop(
   2246         args=args,
   2247         resume_from_checkpoint=resume_from_checkpoint,
   2248         trial=trial,
   2249         ignore_keys_for_eval=ignore_keys_for_eval,
   2250     )

File ~/miniconda3/envs/arc25/lib/python3.10/site-packages/transformers/trainer.py:2560, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2553 context = (
   2554     functools.partial(self.accelerator.no_sync, model=model)
   2555     if i != len(batch_samples) - 1
   2556     and self.accelerator.distributed_type != DistributedType.DEEPSPEED
   2557     else contextlib.nullcontext
   2558 )
   2559 with context():
-> 2560     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2562 if (
   2563     args.logging_nan_inf_filter
   2564     and not is_torch_xla_available()
   2565     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2566 ):
   2567     # if loss is nan or inf simply add the average of previous logged losses
   2568     tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/miniconda3/envs/arc25/lib/python3.10/site-packages/transformers/trainer.py:3736, in Trainer.training_step(self, model, inputs, num_items_in_batch)
   3733     return loss_mb.reduce_mean().detach().to(self.args.device)
   3735 with self.compute_loss_context_manager():
-> 3736     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3738 del inputs
   3739 if (
   3740     self.args.torch_empty_cache_steps is not None
   3741     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3742 ):

File ~/miniconda3/envs/arc25/lib/python3.10/site-packages/trl/extras/profiling.py:96, in profiling_decorator.<locals>.wrapper(self, *args, **kwargs)
     93 @functools.wraps(func)
     94 def wrapper(self, *args, **kwargs):
     95     with profiling_context(self, func.__name__):
---> 96         return func(self, *args, **kwargs)

File ~/miniconda3/envs/arc25/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py:1312, in GRPOTrainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   1310     return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
   1311 else:
-> 1312     return self._compute_loss(model, inputs)

File ~/miniconda3/envs/arc25/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py:1370, in GRPOTrainer._compute_loss(self, model, inputs)
   1368 if self.beta != 0.0:
   1369     mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
-> 1370     self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())
   1372 # Compute the clipped probability ratios
   1373 is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)

File /mnt/hdd0/MEGA/AI/22_Kaggle/arc25/arc25/code_execution.py:85, in timeout_handler(signum, frame)
     84 def timeout_handler(signum, frame):
---> 85     raise TimeoutException("Code execution exceeded time limit!")

TimeoutException: Code execution exceeded time limit!
```

## TODO

- [x] Implement a reward function
- [x] Set wandb project
- [x] wandb save dir
- [x] Check if I can see the completions done during training
- [x] Stop criteria
- [x] Log more metrics such as max and min reward
- [ ] Try with more complex tasks