# Improving BARC induction model with RL

## Goal

Create the code needed to do reinforcement learning (RL) with the BARC induction model.

Once it works it will be moved to a script.

## Server

Before running the notebook, launch a vLLM server on GPU 0 (the notebook itself uses GPU 1):

```bash
export CUDA_VISIBLE_DEVICES=0; trl vllm-serve --max_model_len 12000 --model /home/gbarbadillo/models/Llama-3.1-ARC-Potpourri-Induction-8B
```

## Imports

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # 0 is used by the vllm server

from unsloth import FastLanguageModel
from dataclasses import dataclass
from datasets import Dataset

from trl import GRPOConfig, GRPOTrainer

from arc25.encoders import create_grid_encoder
from arc25.utils import load_arc_dataset_with_solutions
from arc25.data_augmentation import apply_data_augmentation, get_random_data_augmentation_params
from arc25.prompting import create_prompt_from_task
from arc25.collator import get_data_collator
from arc25.logging import configure_logging, logging

configure_logging()
logger = logging.getLogger(__name__)

## First steps

In [None]:
@dataclass
class cfg:
    """Settings for the GRPO debugging run.

    The notebook reads these values directly as class attributes
    (e.g. ``cfg.model_path``); the class is never instantiated in the
    visible cells.
    """
    # base model
    # Path handed to FastLanguageModel.from_pretrained.
    model_path: str = "/home/gbarbadillo/models/Llama-3.1-ARC-Potpourri-Induction-8B"
    # Forwarded as-is to FastLanguageModel.from_pretrained.
    load_in_4bit: bool = False
    # Same value as the --max_model_len used to launch the vLLM server.
    max_seq_length: int = 12000
    # String specification that create_grid_encoder turns into an encoder.
    grid_encoder: str = 'ColorNameEncoder()'
    # LoRA
    # Passed as `r` to FastLanguageModel.get_peft_model.
    lora_r: int = 16
    # Passed as `use_rslora` to FastLanguageModel.get_peft_model.
    use_rslora: bool = True
    # dataset
    # File read by load_arc_dataset_with_solutions.
    dataset_path: str = "/mnt/hdd0/Kaggle/arc25/data/arc-prize-2024/arc-agi_training_challenges.json"
    # Passed as `output_dir` to GRPOConfig.
    output_dir: str = "/mnt/hdd0/Kaggle/arc25/trainings/2025-09-12-debug-grpo/first-steps"
    # training hyperparameters
    # Passed as `num_train_epochs` to GRPOConfig.
    max_epochs: int = 1
    # Passed as `num_generations` to GRPOConfig; also used to derive
    # gradient_accumulation_steps.
    num_generations: int = 8
    # Passed as `per_device_train_batch_size` to GRPOConfig.
    training_batch_size: int = 1
    # Passed as `learning_rate` to GRPOConfig.
    learning_rate: float = 1e-5

In [None]:
# Read the tasks from the file configured in cfg.dataset_path and report
# how many were loaded.
dataset = load_arc_dataset_with_solutions(cfg.dataset_path)
num_tasks = len(dataset)
print(f"Loaded {num_tasks} tasks from {cfg.dataset_path}")

In [None]:
# Load the base model and its tokenizer with unsloth.
# max_seq_length is passed explicitly so that the context length configured
# in cfg.max_seq_length is actually applied instead of the library default.
# fast_inference is disabled here; presumably because generation is
# delegated to the external vLLM server (see use_vllm / vllm_mode in the
# training arguments) — confirm.
llm, tokenizer = FastLanguageModel.from_pretrained(
    cfg.model_path,
    max_seq_length=cfg.max_seq_length,
    load_in_4bit=cfg.load_in_4bit,
    fast_inference=False,
)
# Build the grid encoder from its string specification in the config.
grid_encoder = create_grid_encoder(cfg.grid_encoder)

Let's create a small dataset: 10 randomly augmented versions of the first task.

In [None]:
# Build a tiny debugging dataset from the first task only: 10 samples, each
# created with freshly drawn data augmentation parameters.
task_id = list(dataset.keys())[0]
grpo_samples = []
for _ in range(10):
    params = get_random_data_augmentation_params()
    task = apply_data_augmentation(dataset[task_id], **params)
    prompt = create_prompt_from_task(
        task,
        grid_encoder=grid_encoder,
        tokenizer=tokenizer,
        shuffle_train_samples=True,
    )
    # Ground truth = outputs of the train samples followed by the outputs of
    # the test samples of the augmented task.
    ground_truth = [
        sample['output']
        for split in ('train', 'test')
        for sample in task[split]
    ]
    grpo_samples.append({'prompt': prompt, 'ground_truth': ground_truth})
grpo_dataset = Dataset.from_list(grpo_samples)

In [None]:
def reward_num_unique_letters(completions, **kwargs):
    """Reward each completion with its number of distinct characters.

    Toy reward: the score is ``len(set(text))``, so every distinct character
    (not only letters) counts once.

    Args:
        completions: Generated completions. Each item is either a plain
            string, or a list of chat messages whose first message holds the
            generated text under the ``"content"`` key.
        **kwargs: Any extra keyword arguments passed by the caller; only
            their names are logged.

    Returns:
        A list with one float per completion, in the same order.
    """
    # Lazy %-style arguments: the (potentially large) completions are only
    # formatted when the INFO level is actually enabled.
    logger.info("Computing reward for %d completions", len(completions))
    logger.info("Completions: %s", completions)
    logger.info("This are the kwargs: %s", list(kwargs.keys()))
    # Accept both plain-string completions and chat-message completions;
    # indexing a string with ["content"] would raise TypeError.
    completion_contents = [
        completion if isinstance(completion, str) else completion[0]["content"]
        for completion in completions
    ]
    return [float(len(set(content))) for content in completion_contents]

In [None]:
# Modules that receive LoRA adapters.
# Hint kept from the original: drop the q/k/v/o entries if out of memory.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Wrap the base model with LoRA adapters.
# random_state is not passed (the original had `random_state = 3407`
# commented out).
model = FastLanguageModel.get_peft_model(
    llm,
    r=cfg.lora_r,  # any value > 0; suggested 8, 16, 32, 64, 128
    target_modules=lora_target_modules,
    lora_alpha=64,
    use_gradient_checkpointing="unsloth",  # enables long context finetuning
    use_rslora=cfg.use_rslora,
)

In [None]:
# Accumulate gradients over num_generations // training_batch_size steps.
grad_accum_steps = cfg.num_generations // cfg.training_batch_size

# GRPO settings; completions are generated through vLLM in "server" mode.
# NOTE(review): no prompt/completion length limits are set here, so the
# GRPOConfig defaults apply — confirm they are large enough for ARC prompts.
training_args = GRPOConfig(
    output_dir=cfg.output_dir,
    learning_rate=cfg.learning_rate,
    num_train_epochs=cfg.max_epochs,
    per_device_train_batch_size=cfg.training_batch_size,
    gradient_accumulation_steps=grad_accum_steps,
    num_generations=cfg.num_generations,
    use_vllm=True,
    vllm_mode="server",
)
print(f"Training arguments: {training_args}")

# Train the LoRA-wrapped model on the small dataset with the toy reward.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=grpo_dataset,
    reward_funcs=reward_num_unique_letters,
    data_collator=get_data_collator(tokenizer),
)
trainer.train()

## Debug

## TODO