# GRPO exploration

## Goal

Can we solve novel tasks using GRPO?

## Imports

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tqdm.auto import tqdm
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer, AutoConfig
import matplotlib.pyplot as plt
import matplotlib as mpl
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import logging
from IPython.display import Markdown, display
import torch
import random
import wandb
from transformers import TrainerCallback, TrainerState, TrainerControl

from arc25.training_tasks import *
from arc25.encoders import create_grid_encoder
from arc25.prompting import create_prompt_from_task, pretty_print_prompt
from arc25.plot import plot_task, plot_grids_with_shape, plot_grid
from arc25.code_execution import safe_code_execution, validate_code
from arc25.utils import set_random_seed, get_timestamp
from arc25.logging import configure_logging, log_execution_time

configure_logging()

import sys
sys.path.append(os.path.realpath("../scripts"))
from finetuning import get_data_collator


# Create and immediately discard an empty figure before overriding the
# plotting defaults (presumably so the notebook backend is initialised
# before the rcParams are changed — TODO confirm).
plt.plot()
plt.close('all')
# Wide figures, thicker lines and a larger font for every plot in the notebook.
plt.rcParams.update({
    "figure.figsize": (20, 5),
    'lines.linewidth': 3,
    'font.size': 12,
})

## Code

### Tasks definition

In [None]:
def _make_vertical_lines_task():
    """Build the '9-vertical-lines' task on a blank 9x9 grid.

    The output is a copy of the input on which `draw_vertical_line` is called
    once per column index `x`, using color `x + 1`.
    """
    input_img = create_img((9, 9), color=0)
    output_img = input_img.copy()
    for x in range(input_img.shape[1]):
        draw_vertical_line(output_img, x, color=x + 1)
    return Task(inputs=[input_img], outputs=[output_img], code='', name='9-vertical-lines')


def _make_squares_task(shape, name):
    """Build a task whose output tiles a blank grid of `shape` with 2x2 rectangles.

    Rectangles are drawn column block by column block (outer loop over x,
    inner loop over y). The color cycles through 1..9 and never takes the
    value 0, which is the color the input grid is created with.
    """
    input_img = create_img(shape, color=0)
    output_img = input_img.copy()
    color = 0
    for x in range(0, input_img.shape[1], 2):
        for y in range(0, input_img.shape[0], 2):
            # 1, 2, ..., 9, 1, 2, ... (same sequence as "+1 mod 10, skipping 0")
            color = color % 9 + 1
            draw_rectangle(output_img, (y, x), (y + 1, x + 1), color=color)
    return Task(inputs=[input_img], outputs=[output_img], code='', name=name)


def get_task(task_name):
    """Return the toy task registered under `task_name`.

    Only the requested task is built; the other definitions are not evaluated.

    Args:
        task_name: one of '9-vertical-lines', '20-squares', '12-squares'
            or '16-squares'.

    Returns:
        A `Task` with a single input/output pair and an empty `code` field.

    Raises:
        ValueError: if `task_name` is not one of the available tasks.
    """
    builders = {
        '9-vertical-lines': _make_vertical_lines_task,
        '20-squares': lambda: _make_squares_task((10, 8), '20-squares'),
        '12-squares': lambda: _make_squares_task((6, 8), '12-squares'),
        '16-squares': lambda: _make_squares_task((8, 8), '16-squares'),
    }
    try:
        builder = builders[task_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names; both mean "no such task".
        raise ValueError(f"Task {task_name} not found. Available tasks: {list(builders)}") from None
    return builder()

### Load model

In [None]:
@log_execution_time
def load_model(base_model_path, lora_path):
    """Load the base causal LM and wrap it with a trainable LoRA adapter.

    Args:
        base_model_path: path passed to `AutoModelForCausalLM.from_pretrained`.
        lora_path: path of the LoRA checkpoint; the tokenizer is also loaded
            from this location.

    Returns:
        A `(model, tokenizer)` tuple where `model` is the `PeftModel` created
        with `is_trainable=True`.
    """
    logging.info(f"Loading model from {base_model_path} and LoRA from {lora_path}")
    # Release cached GPU memory before loading the weights.
    torch.cuda.empty_cache()
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype="auto",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(lora_path)
    peft_model = PeftModel.from_pretrained(base_model, lora_path, is_trainable=True)
    return peft_model, tokenizer

### Reward definition

In [None]:
# Module-level history shared by the reward functions, the callbacks and the
# analysis cells: one entry is appended per reward-function call.
all_completions = []
all_rewards = []


def reward_len(completions, **kwargs):
    """Dummy reward: score every completion by its length.

    The batch of completions is also recorded in `all_completions`.
    Extra keyword arguments are accepted and ignored.
    """
    all_completions.append(completions)
    return list(map(len, completions))


def reward_accuracy(completions, **kwargs):
    """Reward each completion with the pixel accuracy of the grid it produces.

    Every completion is executed with `safe_code_execution` on the inputs of
    the module-level `task`, and the first predicted output is compared with
    the first expected output.

    Args:
        completions: list of generated code strings.
        **kwargs: ignored; accepted so extra arguments from the trainer do not
            break the call.

    Returns:
        A list with one float in [0, 1] per completion: the fraction of
        matching cells, or 0.0 when execution fails or the predicted grid does
        not have the expected shape.

    Side effects:
        Appends `completions` to `all_completions` and the rewards to
        `all_rewards`.

    Raises:
        NameError: if the global `task` has not been set yet. This is looked
            up outside the try block on purpose, so a setup mistake is not
            silently turned into zero rewards.
    """
    all_completions.append(completions)
    expected = np.asarray(task.outputs[0])
    accuracy = []
    for completion in completions:
        # Strip the closing markdown fence that may follow the code
        predicted_code = completion.replace('\n```', '')
        try:
            predicted_output = safe_code_execution(predicted_code, task.inputs)
            predicted = np.asarray(predicted_output[0])
            if predicted.shape != expected.shape:
                # Without this check numpy would broadcast compatible shapes
                # (e.g. a single row or a scalar) and grant undeserved credit.
                accuracy.append(0.0)
            else:
                accuracy.append(float(np.mean(predicted == expected)))
        except Exception as e:
            # Generated code can fail in arbitrary ways; any failure scores zero.
            print(f'Error executing code: {predicted_code}')
            print(e)
            accuracy.append(0.0)
    all_rewards.append(accuracy)
    return accuracy

### Callbacks

In [None]:
class RewardExtremaCallback(TrainerCallback):
    """Report the max and min reward of the most recent reward batch on every log event."""

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Add the reward extrema to `logs` and send them to W&B.

        Does nothing when there are no logs or no rewards have been recorded.
        Always returns `control` unchanged.
        """
        if logs is not None and all_rewards:
            # Most recent raw batch of rewards
            latest_rewards = all_rewards[-1]
            highest = float(np.max(latest_rewards))
            lowest = float(np.min(latest_rewards))
            extrema = {
                "train/reward_accuracy/max": highest,
                "train/reward_accuracy/min": lowest,
            }

            # 1) Inject into the HF logs dict so HF will record them
            logs.update(extrema)
            logging.debug(f"Step {state.global_step}: Max reward: {highest:.2f}, Min reward: {lowest:.2f}")

            # 2) Immediately push to W&B for real-time curves
            wandb.log(dict(extrema))

        return control

class StopOnRewardCallback(TrainerCallback):
    """Stop training once the best reward of the latest batch reaches a threshold.

    Args:
        threshold: training stops when the maximum reward of the most recent
            entry of `all_rewards` is greater than or equal to this value.
    """

    def __init__(self, threshold=1.0):
        super().__init__()
        self.threshold = threshold

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Request a stop (and a checkpoint save) if the threshold was reached.

        Returns `control`, with `should_training_stop` and `should_save` set to
        True when the stop condition is met.
        """
        if not all_rewards:
            # No reward batch recorded yet, nothing to evaluate.
            return control
        best_reward = np.max(all_rewards[-1])
        if best_reward >= self.threshold:
            # Use the `logging` module like the rest of the file; no `logger`
            # object is defined here.
            logging.info(f"Stopping training at step {state.global_step} as max reward {best_reward:.2f} is above threshold {self.threshold:.2f}")
            control.should_training_stop = True
            control.should_save = True
        return control

### Training

In [None]:
@log_execution_time
def grpo(cfg):
    """Run GRPO training on the single task selected by `cfg`.

    The task is stored in the module-level `task` variable because
    `reward_accuracy` reads it from there. The training dataset contains
    exactly one prompt, built from that task.

    Args:
        cfg: configuration object providing `task_name`, `base_model_path`,
            `lora_path`, `prompt_version`, `grid_encoder`, `learning_rate`,
            `num_generations` and `num_train_epochs`.
    """
    global task
    task = get_task(cfg.task_name)

    plot_task(task)
    plt.suptitle('Task to solve')
    plt.tight_layout()
    plt.show()

    model, tokenizer = load_model(cfg.base_model_path, cfg.lora_path)

    # It's a little bit tricky, but since we are not using SFT is_train_prompt must be False
    prompt = create_prompt_from_task(
        task,
        prompt_version=cfg.prompt_version,
        grid_encoder=cfg.grid_encoder,
        tokenizer=tokenizer,
        is_train_prompt=False,
    )
    train_dataset = Dataset.from_dict({'prompt': [prompt]})

    batch_size_per_device = 4  # initially 8, but had to lower it due to OOM
    run_name = f"{task.name}_lr{cfg.learning_rate:.0e}_n{cfg.num_generations}"
    output_dir = f"/mnt/hdd0/Kaggle/arc25/trainings/20250508_GRPO/{get_timestamp()}_{run_name}"
    training_args = GRPOConfig(
        output_dir=output_dir,
        run_name=run_name,
        num_train_epochs=cfg.num_train_epochs,
        logging_steps=1,
        num_generations=cfg.num_generations,  # default is 8
        # Accumulate enough mini-batches to cover all the generations
        gradient_accumulation_steps=cfg.num_generations // batch_size_per_device,
        per_device_train_batch_size=batch_size_per_device,  # default is 8
        generation_batch_size=cfg.num_generations,  # needs the latest version of trl
        temperature=0.5,  # default is 0.9
        top_p=0.95,  # default is 1.0
        max_completion_length=768,  # default is 256
        max_prompt_length=8192,  # default is 512
        learning_rate=cfg.learning_rate,  # default is 1e-6
        lr_scheduler_type='constant',  # default is 'linear'
        log_completions=False,
    )

    # Weights & Biases settings: project name and directory for the run files
    os.environ['WANDB_PROJECT'] = "20250508_GRPO_v2"
    os.environ['WANDB_DIR'] = output_dir

    stop_callback = StopOnRewardCallback(threshold=1.0)
    extrema_callback = RewardExtremaCallback()
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_accuracy,
        args=training_args,
        train_dataset=train_dataset,
        processing_class=tokenizer,
        callbacks=[stop_callback, extrema_callback],
    )
    trainer.train()

## First experiments

In [None]:
@dataclass
class Config:
    """Settings for one GRPO run, consumed by `grpo`."""
    # Task and optimisation settings
    task_name: str = '12-squares' # '9-vertical-lines', '20-squares', '12-squares' or '16-squares'
    learning_rate: float = 4e-6
    num_generations: int = 16
    num_train_epochs: int = 400

    # Model, LoRA checkpoint and prompt settings
    base_model_path: str = '/home/gbarbadillo/models/Qwen2.5-Coder-0.5B-Instruct'
    lora_path: str = '/mnt/hdd0/Kaggle/arc25/trainings/20250430_first_trainings/steps_6400/checkpoint-6400'
    prompt_version: str = 'code-from-examples-v3'
    # NOTE(review): no type annotation here, so this is a plain class attribute
    # created once when the class is defined, not a constructor argument.
    grid_encoder = create_grid_encoder('GridShapeEncoder(RowNumberEncoder(MinimalGridEncoder()))')

# Launch a training run with the default configuration.
cfg = Config()
grpo(cfg)

In [None]:
# `all_completions` gets one entry per reward-function call, so this counts
# calls (batches), not individual completions.
print(f'Number of completions: {len(all_completions)}')

In [None]:
# Length (in characters) of every completion, grouped by reward-function call.
# Computed once and reused for the three curves.
lengths_per_call = [[len(completion) for completion in completions] for completions in all_completions]
for label, reducer in (('mean', np.mean), ('max', np.max), ('min', np.min)):
    plt.plot([reducer(lengths) for lengths in lengths_per_call], label=label)
plt.xlabel('Epoch')
# The plot shows mean, max and min, so the axis is not labelled "Mean".
plt.ylabel('Completion length')
plt.legend()
plt.grid();

In [None]:
# Display the completions passed to the most recent reward-function call.
all_completions[-1]

In [None]:
# Summarise the best reward seen and, if the task was solved, when it happened.
if all_rewards:
    # Best reward of each reward-function call
    best_per_call = [np.max(accuracy) for accuracy in all_rewards]
    best_reward = np.max(best_per_call)
    print(f'Max reward obtained: {best_reward}')
    if best_reward >= 1.0:
        # argmax over booleans returns the index of the first solved call;
        # it uses the same >= 1.0 condition as the check above.
        print(f'Task was solved at epoch {np.argmax([best >= 1.0 for best in best_per_call])}')
else:
    # np.max would raise on an empty history (e.g. cell run before training).
    print('No rewards have been recorded yet')

In [None]:
# One curve per statistic (max, mean, min) of the rewards of each call.
for reducer in (np.max, np.mean, np.min):
    plt.plot([reducer(batch_rewards) for batch_rewards in all_rewards])
plt.title('Reward accuracy')
plt.grid()

## Learnings

Dummy reward experiments:

- Running a first dummy train with a reward based on output length for 100 epochs took around 21 minutes
- A second run after fixing the initial prompt took 17 minutes
- The effective batch size has to be greater than or equal to the number of generations.

True reward experiments:

- 100 epochs takes 13 min
- 200 epochs takes around 30 min, and after epoch 100 almost always generates the correct solution
- Increasing the learning rate to 2e-6 results in solving the task at epoch 20
- Increasing the learning rate to 4e-6 results in solving the task at epoch 11
- With a learning rate of 1e-5 it is solved at epoch 3, but collapses at epoch 20
- when trying the 20 squares task I get OOM
- I have increased the number of generations to 128, but it diverges on the 16 squares task. It has run
  for 46 epochs, for 48 minutes.
- With lr=1e-6, 128 generations, it takes 1h35 to do 100 epochs and diverges. Let's go back to 8 generations and use more epochs.
- I'm not sure why I see timeout errors, but I increased the maximum number of generated tokens from 1024 to 2048 because it might be the root of the problem. However this causes OOM errors because it predicts very long functions.
- If I try with 12 squares it is able to solve it at step 121 with lr=1e-6
- When trying with lr=2e-6 I get another OOM error, sometimes it makes a very long prediction. Same with lr=4e-6
- I decrease the batch size per device and then I get timeout error because it is generating a 2k tokens function WTF. I have finally solved the problem with the timeouts. Now I could set a maximum number of tokens, and give bad reward when it is exceeded.

After solving the problem with timeouts I'm going to set the maximum number of tokens to 768; solving the task with 25 squares needed fewer than 600 tokens. That will force the model to keep the functions short.

[wandb](https://wandb.ai/guillermobarbadillo/20250508_GRPO_v2?nw=nwuserguillermobarbadillo)

9-vertical-lines:

- Solved on 61 steps with lr=1e-6 and num_generations=8
- With lr=2e-6 it is solved at step 21, 3m39s
- With lr=4e-6 it is solved in 48s, 5 steps

12-squares:

- with lr=4e-6 solves the task in 1194s, 73 steps.
- If I increase the number of generations to 16, it takes 150 steps and almost 100 minutes
- with lr=2e-5 and 16 generations it does not converge despite spending 140 minutes

16-squares:

- with lr=4e-6 the training diverges
- with lr=2e-6 I have stopped the training at epoch 123 and 70 minutes. It might seem that the training was going to diverge. At least I need something much faster.

So far it seems that GRPO is much less efficient than HER. Or maybe I haven't found the right configuration.

## TODO

- [x] Implement a reward function
- [x] Set wandb project
- [x] wandb save dir
- [x] Check if I can see the completions done during training
- [x] Stop criteria
- [x] Log more metrics such as max and min reward
- [x] Try with more complex tasks
- [ ] Save the outputs to view the evolution
- [ ] Better reward and output tokens distribution evolution