# GRPO exploration

## Goal

Can we solve novel tasks using GRPO?

## Imports

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tqdm.auto import tqdm
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer, AutoConfig
import matplotlib.pyplot as plt
import matplotlib as mpl
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import logging
from IPython.display import Markdown, display
import torch
import random


from arc25.training_tasks import *
from arc25.encoders import create_grid_encoder
from arc25.prompting import create_prompt_from_task, pretty_print_prompt
from arc25.plot import plot_task, plot_grids_with_shape, plot_grid
from arc25.code_execution import safe_code_execution, validate_code
from arc25.utils import set_random_seed, get_timestamp
from arc25.logging import configure_logging, log_execution_time

configure_logging()

import sys
sys.path.append(os.path.realpath("../scripts"))
from finetuning import get_data_collator


# Draw and immediately discard an empty plot (presumably to initialise the notebook's
# plotting backend) before applying the global plotting defaults.
plt.plot()
plt.close('all')
mpl.rcParams.update({
    'figure.figsize': (20, 5),
    'lines.linewidth': 3,
    'font.size': 12,
})

## Code

In [None]:
# Per-epoch summary: the top-ranked prediction of the epoch and the pixel accuracies of
# every prediction that ran successfully in that epoch.
EpochResults = namedtuple('EpochResults', 'best_prediction pixel_accuracies')

@log_execution_time
def hindsight_experience_replay(task, cfg):
    """
    Use hindsight experience replay to try to solve new tasks.

    Every epoch samples programs for ``task`` with ``inference``. Each program that runs
    becomes a new task whose outputs are whatever the program produced. The model is then
    fine-tuned on those relabelled tasks with ``finetuning``. The loop stops early once a
    prediction reaches a pixel accuracy of 1.

    Args:
        task: task to solve; its inputs/outputs are used both for prompting and for scoring.
        cfg: configuration exposing ``base_model_path``, ``lora_path``, ``max_epochs``,
            ``grid_encoder``, ``prompt_version``, ``n_predictions`` and
            ``use_accuracy_for_sorting``.

    Returns:
        list of EpochResults, one per epoch that produced at least one runnable prediction.
    """
    plot_task(task); plt.suptitle('Task to solve'); plt.tight_layout(); plt.show()
    model, tokenizer = load_model(cfg.base_model_path, cfg.lora_path)
    metrics = []
    for epoch in range(cfg.max_epochs):
        logging.info(f'Starting epoch {epoch}...')
        new_tasks, pixel_accuracies = inference(
            task, model, tokenizer, cfg.grid_encoder, cfg.prompt_version,
            n_predictions=cfg.n_predictions)
        if not new_tasks:
            logging.warning(f'No prediction could be executed at epoch {epoch}, skipping fine-tuning')
            continue
        # inference returns the tasks sorted by increasing accuracy, so the last one is the best
        metrics.append(EpochResults(best_prediction=new_tasks[-1], pixel_accuracies=pixel_accuracies))
        plot_metrics_evolution(metrics)
        if np.max(pixel_accuracies) == 1:
            logging.info(f'Found a perfect prediction at epoch {epoch}!')
            break
        if epoch == cfg.max_epochs - 1:
            # The model is not returned, so fine-tuning after the last evaluation would be wasted.
            break
        if not cfg.use_accuracy_for_sorting:
            logging.info('Shuffling the tasks, no information about the accuracy is used')
            random.shuffle(new_tasks)
        finetuning(new_tasks, model, tokenizer, cfg.grid_encoder, cfg.prompt_version)
    if metrics:
        display(Markdown(f'# Best prediction code\n\n```python\n{metrics[-1].best_prediction.code}\n```'))
    else:
        logging.warning('No epoch produced a runnable prediction')
    return metrics

@log_execution_time
def load_model(base_model_path, lora_path):
    """
    Load a causal language model and attach a trainable LoRA adapter to it.

    The CUDA cache is emptied first. The tokenizer is read from ``lora_path``, not from the
    base model directory.

    Returns:
        (model, tokenizer): the PeftModel wrapping the base model, and the tokenizer.
    """
    logging.info(f"Loading model from {base_model_path} and LoRA from {lora_path}")
    torch.cuda.empty_cache()
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype="auto", device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(lora_path)
    return PeftModel.from_pretrained(base_model, lora_path, is_trainable=True), tokenizer

def _pixel_accuracy(predicted_grid, target_grid):
    """Return the fraction of matching pixels, or 0.0 when the grids have different shapes.

    Comparing arrays of different shapes with ``==`` either broadcasts (giving a misleading
    score) or raises (NumPy >= 1.25), so a shape mismatch is scored explicitly.
    """
    predicted_grid, target_grid = np.asarray(predicted_grid), np.asarray(target_grid)
    if predicted_grid.shape != target_grid.shape:
        return 0.0
    return float(np.mean(predicted_grid == target_grid))


@log_execution_time
def inference(task, model, tokenizer, grid_encoder, prompt_version, n_predictions=256,
              max_new_tokens=1024, temperature=0.5, top_p=0.95):
    """
    Sample programs for ``task`` and turn every runnable one into a relabelled task.

    Args:
        task: task to solve; only its first output grid is used for scoring.
        model, tokenizer: model used for sampling and its tokenizer.
        grid_encoder, prompt_version: forwarded to ``create_prompt_from_task``.
        n_predictions: number of sampled completions.
        max_new_tokens, temperature, top_p: sampling parameters forwarded to ``model.generate``.

    Returns:
        (new_tasks, pixel_accuracies):
            ``new_tasks`` keeps one task per distinct first output grid, sorted by increasing
            pixel accuracy, so the best candidate comes last.
            ``pixel_accuracies`` holds the accuracy of every prediction that ran, duplicates
            included.
            Both lists are empty when no prediction could be executed.
    """
    prompt = create_prompt_from_task(
        task, prompt_version=prompt_version, grid_encoder=grid_encoder, tokenizer=tokenizer, is_train_prompt=False)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=n_predictions
    )
    # keep only the generated tokens, dropping the prompt
    generated_ids = generated_ids[:, len(model_inputs.input_ids[0]):]
    predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # remove the closing code fence emitted after the program
    predicted_codes = [prediction.replace('\n```', '') for prediction in predictions]

    new_tasks, pixel_accuracies = [], []
    for predicted_code in tqdm(predicted_codes):
        try:
            predicted_output = safe_code_execution(predicted_code, task.inputs)
            validated_code = validate_code(predicted_code, task.inputs)
            new_task = Task(inputs=task.inputs, outputs=predicted_output, code=validated_code, name=task.name)
            accuracy = _pixel_accuracy(new_task.outputs[0], task.outputs[0])
        except Exception as e:
            print(f'Error executing code: {predicted_code}')
            print(e)
            continue
        # append only after everything succeeded so both lists stay aligned
        new_tasks.append(new_task)
        pixel_accuracies.append(accuracy)

    if not new_tasks:
        logging.warning('None of the predicted programs could be executed')
        return [], []

    # keep the first task seen for every distinct first output grid
    unique_tasks, unique_accuracies = [], []
    for new_task, accuracy in zip(new_tasks, pixel_accuracies):
        if not any(np.array_equal(new_task.outputs[0], t.outputs[0]) for t in unique_tasks):
            unique_tasks.append(new_task)
            unique_accuracies.append(accuracy)
    logging.info(f'Number of unique outputs: {len(unique_tasks)}/{len(new_tasks)}')
    logging.info(f'Max pixel accuracy: {max(pixel_accuracies)}')
    # stable sort by increasing accuracy: the best candidate ends up last
    ranked = sorted(zip(unique_accuracies, unique_tasks), key=lambda pair: pair[0])
    return [ranked_task for _, ranked_task in ranked], pixel_accuracies

@log_execution_time
def finetuning(new_tasks, model, tokenizer, grid_encoder, prompt_version, learning_rate=1e-5):
    """
    Fine-tune ``model`` for one epoch of supervised training on ``new_tasks``.

    Each task is rendered with ``create_prompt_from_task(..., is_train_prompt=True)`` and
    trained on with TRL's ``SFTTrainer``. Nothing is saved to disk (``save_strategy='no'``).

    Args:
        new_tasks: tasks to train on, e.g. the relabelled tasks returned by ``inference``.
        model: trainable model, e.g. the PeftModel returned by ``load_model``.
        tokenizer, grid_encoder, prompt_version: forwarded to ``create_prompt_from_task``.
        learning_rate: learning rate of the constant-with-warmup schedule (10% warmup).
    """
    # The notebook only imports the GRPO classes from trl at the top, so import the SFT ones here.
    from trl import SFTConfig, SFTTrainer

    if not new_tasks:
        logging.warning('No tasks to fine-tune on, skipping fine-tuning')
        return

    prompts = [
        create_prompt_from_task(
            new_task, prompt_version=prompt_version, grid_encoder=grid_encoder, tokenizer=tokenizer,
            is_train_prompt=True)
        for new_task in new_tasks]
    train_dataset = Dataset.from_dict({'text': prompts})

    training_arguments = SFTConfig(
        output_dir=None,
        save_strategy='no',
        num_train_epochs=1,
        warmup_ratio=0.1,
        learning_rate=learning_rate,
        lr_scheduler_type='constant_with_warmup',
        gradient_checkpointing=False,
        optim="paged_adamw_8bit",
        max_grad_norm=1.0,

        dataset_text_field="text",
        max_seq_length=4096,

        # NOTE(review): no eval dataset is passed and eval_strategy is "no", so do_eval
        # presumably has no effect here.
        do_eval=True,
        eval_strategy="no",
        logging_steps=10,
        log_level="info",
        report_to='none',

        # only used with accelerate; True was reported to slow training down
        ddp_find_unused_parameters=False,
        # avoids a slow start when resuming from a checkpoint
        ignore_data_skip=True,

        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
    )

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        data_collator=get_data_collator(tokenizer),
        args=training_arguments,
    )
    trainer.train()


def plot_best_prediction(task, best_prediction, accuracy):
    """
    Plot the target outputs next to the outputs of ``best_prediction``, then show its code.

    ``accuracy`` is used only in the figure title, formatted as a percentage.
    """
    title = f'Best prediction accuracy: {accuracy:.1%}'
    plot_grids_with_shape(task.outputs + best_prediction.outputs, suptitle=title)
    code_markdown = f'```python\n{best_prediction.code}\n```'
    display(Markdown(code_markdown))


def plot_metrics_evolution(metrics):
    """
    Visualise how the predictions evolve over the epochs.

    First draws the stacked accuracy histograms via ``plot_score_histograms``. Then it draws
    one row with a panel per epoch, showing the first output grid of that epoch's best
    prediction, titled with the epoch's maximum pixel accuracy.

    Args:
        metrics: sequence of EpochResults, one per epoch so far.
    """
    plot_score_histograms(metrics)

    n_panels = len(metrics)
    for epoch, results in enumerate(metrics):
        plt.subplot(1, n_panels, epoch + 1)
        plot_grid(results.best_prediction.outputs[0])
        top_accuracy = max(results.pixel_accuracies)
        plt.title(f'Epoch {epoch} acc: {top_accuracy:.1%}')

    plt.suptitle('Evolution of best predictions')
    plt.tight_layout()
    plt.show()


def plot_score_histograms(metrics, offset_scale=1):
    """
    Plot one histogram of pixel accuracies per epoch, stacked vertically.

    Bin counts are compressed with ``log1p``. Epoch ``i`` is drawn on the baseline
    ``i * tallest * offset_scale``, where ``tallest`` is the largest compressed count over
    all epochs. Baselines are therefore evenly spaced and keep the epoch order. Colours run
    along the viridis colormap from the first to the last epoch.

    Args:
        metrics: sequence of EpochResults; only ``pixel_accuracies`` is used.
        offset_scale: spacing between consecutive baselines, as a multiple of the tallest
            histogram.
    """
    cmap = mpl.colormaps['viridis']
    norm = plt.Normalize(0, len(metrics) - 1)
    bins = np.linspace(0, 1, 100)
    bin_centers = 0.5 * (bins[1:] + bins[:-1])

    # log1p compresses large counts and maps empty bins to 0
    log_counts = [np.log1p(np.histogram(results.pixel_accuracies, bins=bins)[0]) for results in metrics]
    # A shared step (rather than each epoch's own maximum) keeps baselines evenly spaced;
    # per-epoch maxima could place a later epoch below an earlier one.
    step = max((np.max(counts) for counts in log_counts), default=0) * offset_scale

    plt.figure(figsize=(10, 6))
    for i, counts in enumerate(log_counts):
        offset = i * step
        plt.fill_between(bin_centers, offset, counts + offset, color=cmap(norm(i)), label=f'Epoch {i}', alpha=0.5)

    plt.xlabel("Pixel accuracy")
    plt.ylabel("Epoch ->")
    plt.title("Evolution of pixel accuracy")
    plt.yticks([])  # baselines are offsets, so the y values are not absolute counts
    plt.grid(axis='x')
    plt.tight_layout()
    plt.show()

## First experiments

In [None]:
from dataclasses import dataclass, field


@dataclass
class Config:
    """Settings for the experiments in this notebook."""
    base_model_path: str = '/home/gbarbadillo/models/Qwen2.5-Coder-0.5B-Instruct'
    lora_path: str = '/mnt/hdd0/Kaggle/arc25/trainings/20250430_first_trainings/steps_6400/checkpoint-6400'
    prompt_version: str = 'code-from-examples-v3'
    # Annotated and built with default_factory so that it is a real dataclass field: it can
    # be overridden in the constructor and shows up in repr(). Without an annotation it was
    # a plain class attribute shared by every instance.
    grid_encoder: object = field(
        default_factory=lambda: create_grid_encoder('GridShapeEncoder(RowNumberEncoder(MinimalGridEncoder()))'))
    max_epochs: int = 10
    n_predictions: int = 256  # 256 seems to be the best for my hardware
    use_accuracy_for_sorting: bool = True


cfg = Config()

In [None]:
# Toy task: a blank 9x9 input grid; the target draws a vertical line at every column x
# with colour x + 1.
input_img = create_img((9, 9), color=0)
output_img = input_img.copy()
n_columns = input_img.shape[1]
for column in range(n_columns):
    draw_vertical_line(output_img, column, color=column + 1)

task = Task(inputs=[input_img], outputs=[output_img], code='', name='manual')
plot_task(task)
plt.suptitle('Task to solve')
plt.tight_layout()
plt.show()

In [None]:
# Load the base model with its trainable LoRA adapter; the GRPO cell below trains this model.
model, tokenizer = load_model(cfg.base_model_path, cfg.lora_path)

In [None]:
# GRPO generates the completion itself, so the prompt is built like an inference prompt
# (is_train_prompt=False) rather than like the SFT prompts used in `finetuning`.
prompt = create_prompt_from_task(
    task,
    prompt_version=cfg.prompt_version,
    grid_encoder=cfg.grid_encoder,
    tokenizer=tokenizer,
    is_train_prompt=False,
)
# GRPOTrainer reads the prompts from the 'prompt' column; the dataset holds a single prompt.
train_dataset = Dataset.from_dict({'prompt': [prompt]})

In [None]:
# Every batch of completions passed to a reward function, and the per-completion accuracies
# returned by reward_accuracy. The analysis cells after training read both lists.
all_completions, accuracy_evolution = [], []


def _completion_pixel_accuracy(predicted_grid, target_grid):
    """Return the fraction of matching pixels, or 0.0 when the grids have different shapes."""
    predicted_grid, target_grid = np.asarray(predicted_grid), np.asarray(target_grid)
    if predicted_grid.shape != target_grid.shape:
        return 0.0
    return float(np.mean(predicted_grid == target_grid))


def reward_len(completions, **kwargs):
    """Dummy reward equal to the length of each completion, used for sanity-check runs."""
    all_completions.append(completions)
    return [len(completion) for completion in completions]


def reward_accuracy(completions, **kwargs):
    """
    Reward each completion with the pixel accuracy of its program on the global ``task``.

    The closing code fence is stripped and the program is run on ``task.inputs``. Only the
    first output grid is scored. Programs that fail to run get 0.0, and so do outputs whose
    shape differs from the target's. Without the shape check, ``==`` would either broadcast
    to a misleading score or raise inside numpy and be reported as an execution error.

    Args:
        completions: generated completions, one program each.
        **kwargs: extra columns passed by the trainer (e.g. ``prompts``); ignored.

    Returns:
        list of float rewards in [0, 1], one per completion.
    """
    all_completions.append(completions)
    accuracy = []
    for completion in completions:
        predicted_code = completion.replace('\n```', '')
        try:
            predicted_output = safe_code_execution(predicted_code, task.inputs)
            accuracy.append(_completion_pixel_accuracy(predicted_output[0], task.outputs[0]))
        except Exception as e:
            print(f'Error executing code: {predicted_code}')
            print(e)
            accuracy.append(0.0)
    accuracy_evolution.append(accuracy)
    return accuracy

# Learning rate for the GRPO run; also embedded in the run directory name (formatted with .0e).
learning_rate = 1e-5
# NOTE(review): with a single prompt, 8 generations and a per-device batch of 8, each epoch is
# presumably a single optimisation step (TRL counts the batch size in completions) — confirm.
training_args = GRPOConfig(
    output_dir=f"/mnt/hdd0/Kaggle/arc25/trainings/20250508_GRPO/{get_timestamp()}_lr{learning_rate:.0e}",
    num_train_epochs=100,
    logging_steps=1,
    num_generations=8, # default is 8
    per_device_train_batch_size=8, # default is 8
    temperature=0.5, # default is 0.9
    learning_rate=learning_rate, # default is 1e-6
)
# The reward is the pixel accuracy of each sampled program on the toy task.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_accuracy,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()

In [None]:
print(f'Number of completions: {len(all_completions)}')
# Each entry is one batch of completions, so this shows the distinct batch sizes.
np.unique(list(map(len, all_completions)))

In [None]:
# Mean length (in characters) of the completions in each reward call.
plt.plot([np.mean(list(map(len, batch))) for batch in all_completions])
plt.xlabel('Epoch')
plt.ylabel('Mean completion length')
plt.grid();

In [None]:
# Inspect the batch of completions from the most recent reward call.
all_completions[-1]

In [None]:
# Highest pixel accuracy given to any completion during training.
np.max(accuracy_evolution)

In [None]:
# Best, mean and worst reward of each reward call, labelled so the three curves can be told apart.
plt.plot([np.max(accuracy) for accuracy in accuracy_evolution], label='max')
plt.plot([np.mean(accuracy) for accuracy in accuracy_evolution], label='mean')
plt.plot([np.min(accuracy) for accuracy in accuracy_evolution], label='min')
plt.xlabel('Epoch')
plt.ylabel('Pixel accuracy reward')
plt.legend()
plt.grid()

In [None]:
# First reward call in which some completion reproduced the target exactly.
# np.argmax over the booleans returns 0 when the task is never solved, which looks exactly like
# "solved at step 0"; next(..., None) makes the unsolved case explicit.
first_solved_step = next(
    (step for step, accuracy in enumerate(accuracy_evolution) if np.max(accuracy) == 1), None)
print('The task was never solved' if first_solved_step is None
      else f'First perfect prediction at step {first_solved_step}')

## Learnings

Dummy reward experiments:

- Running a first dummy training with a reward based on output length for 100 epochs took around 21 minutes
- A second run, after fixing the initial prompt, took 17 minutes
- The effective batch size has to be greater than or equal to the number of generations.

True reward experiments:

- 100 epochs take 13 min
- 200 epochs take around 30 min, and after epoch 100 the model almost always generates the correct solution
- Increasing the learning rate to 2e-6 results in solving the task at epoch 20
- Increasing the learning rate to 4e-6 results in solving the task at epoch 11
- With a learning rate of 1e-5 the task is solved at epoch 3, but training collapses at epoch 20

Note: constant learning rate with warmup (`constant_with_warmup` schedule).

## TODO

- [x] Implement a reward function
- [ ] Set wandb project
- [x] Check if I can see the completions done during training
- [ ] Stop criteria
- [ ] Log more metrics such as max and min reward
- [ ] Try with more complex tasks