# Improving BARC induction model with RL

## Goal

Write the code to fine-tune the BARC induction model with reinforcement learning (GRPO).

Once it works, it will be moved to a script.

## Server

Before running the notebook, launch a vLLM server on GPU 0:

```bash
export CUDA_VISIBLE_DEVICES=0; trl vllm-serve --max_model_len 12000 --model /home/gbarbadillo/models/Llama-3.1-ARC-Potpourri-Induction-8B
```
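
Loading an 8B model can take a while, so it can help to wait until the server answers before running the cells below. This is a minimal sketch that assumes the server listens on the default `http://0.0.0.0:8000` and exposes a `GET /health/` endpoint (as recent `trl vllm-serve` versions appear to do); `wait_for_server` is a hypothetical helper, not part of `trl`.

```python
import time
import urllib.request


def wait_for_server(base_url: str = "http://0.0.0.0:8000", timeout: float = 240.0,
                    poll_interval: float = 2.0) -> bool:
    """Poll the (assumed) /health/ endpoint until it returns HTTP 200 or the timeout expires."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(f"{base_url}/health/", timeout=poll_interval) as response:
                if response.status == 200:
                    return True
        except OSError:
            # connection refused, timeout or HTTP error: the server is not ready yet
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_interval)
```

In the notebook it could be used as `assert wait_for_server(), "vLLM server is not up"` before creating the trainer.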

## Imports

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # 0 is used by the vllm server

from unsloth import FastLanguageModel
from dataclasses import dataclass
from datasets import Dataset

from trl import GRPOConfig, GRPOTrainer

from arc25.encoders import create_grid_encoder
from arc25.utils import load_arc_dataset_with_solutions
from arc25.data_augmentation import apply_data_augmentation, get_random_data_augmentation_params
from arc25.prompting import create_prompt_from_task
# from arc25.collator import get_data_collator
from arc25.logging import configure_logging, logging
from arc25.parallel_code_execution import run_code_from_predictions

configure_logging()
logger = logging.getLogger(__name__)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 09-12 16:42:29 [__init__.py:241] Automatically detected platform cuda.
Unsloth: Your Flash Attention 2 installation seems to be broken?
A possible explanation is you have a new CUDA version which isn't
yet compatible with FA2? Please file a ticket to Unsloth or FA2.
We shall now use Xformers instead, which does not have any performance hits!
We found this negligible impact by benchmarking on 1x A100.
🦥 Unsloth Zoo will now patch everything to make training faster!


## First steps

In [2]:
@dataclass
class cfg:
    # base model
    model_path: str = "/home/gbarbadillo/models/Llama-3.1-ARC-Potpourri-Induction-8B"
    load_in_4bit: bool = True
    max_seq_length: int = 12000
    grid_encoder: str = 'ColorNameEncoder()'
    # LoRA
    lora_r: int = 16
    use_rslora: bool = True
    # dataset
    dataset_path: str = "/mnt/hdd0/Kaggle/arc25/data/arc-prize-2024/arc-agi_training_challenges.json"
    output_dir: str = "/mnt/hdd0/Kaggle/arc25/trainings/2025-09-12-debug-grpo-b/reward-v1"
    # training hyperparameters
    max_epochs: int = 1
    num_generations: int = 16
    training_batch_size: int = 1
    learning_rate: float = 1e-5

In [3]:
dataset = load_arc_dataset_with_solutions(cfg.dataset_path)
print(f"Loaded {len(dataset)} tasks from {cfg.dataset_path}")

Loaded 400 tasks from /mnt/hdd0/Kaggle/arc25/data/arc-prize-2024/arc-agi_training_challenges.json


In [4]:
llm, tokenizer = FastLanguageModel.from_pretrained(
    cfg.model_path, load_in_4bit=cfg.load_in_4bit, fast_inference=False, max_seq_length=cfg.max_seq_length)
grid_encoder = create_grid_encoder(cfg.grid_encoder)

==((====))==  Unsloth 2025.9.4: Fast Llama patching. Transformers: 4.56.1. vLLM: 0.10.1.1.
   \\   /|    NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.568 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


2025-09-12 16:42:49,454 - arc25.encoders - INFO - create_grid_encoder - Created `ColorNameEncoder()` as grid encoder


Let's create a small debug dataset with the first 10 tasks and no data augmentation.

In [5]:
grpo_dataset = []
for task_id in list(dataset.keys())[:10]: # debug with 10 tasks
    for _ in range(1):
        # params = get_random_data_augmentation_params()
        # task = apply_data_augmentation(dataset[task_id], **params)
        task = dataset[task_id] # debug without data augmentation
        prompt = create_prompt_from_task(
                task, grid_encoder=grid_encoder, tokenizer=tokenizer, shuffle_train_samples=True)
        grpo_dataset.append(dict(prompt=prompt, tasks=task))
grpo_dataset = Dataset.from_list(grpo_dataset)

In [6]:
def arc_reward(completions, tasks, completion_ids, **kwargs):
    """
    Reward function that rewards completions based on how many train and test grids their code solves.

    TRL seems to call it with `prompts`, `completions` and `completion_ids`, plus every extra
    dataset column (here `tasks`) as keyword arguments.
    """
    results = run_code_from_predictions(tasks, list(range(len(completions))), completions, [None]*len(completions), group_results_by_task=False)
    logger.info(f'Completions length: {[len(c) for c in completion_ids]}')
    # logger.info(f"Reward results: {results}")
    # logger.info(f"Task ids: {[result['task_id'] for result in results]}") # this verifies that results are in the same order as completions
    if 'code' in results[0]:
        logger.info(f"Example code:\n{results[0]['code']}")

    rewards = []
    for result in results:
        if 'code' not in result:
            rewards.append(-1.0)
        elif 'train_correct_grids' not in result:
            rewards.append(0.0)
        else:
            rewards.append(float(result['train_correct_grids']) + float(result.get('test_correct_grids', 0)))

    logger.info(f'Rewards: {rewards}')
    return rewards
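
The reward rules above can be isolated into a pure function, which makes them easy to unit-test without executing any generated code. This is a sketch that assumes each result is a dict shaped like the ones returned by `run_code_from_predictions` (keys `code`, `train_correct_grids`, `test_correct_grids`); `reward_from_result` is a hypothetical helper.

```python
def reward_from_result(result: dict) -> float:
    """Map one execution result to a scalar reward, mirroring the rules in `arc_reward`."""
    if 'code' not in result:
        # no code could be extracted from the completion
        return -1.0
    if 'train_correct_grids' not in result:
        # code was extracted but produced no train score (presumably it failed to run)
        return 0.0
    # correct train grids plus correct test grids (exact semantics depend on arc25)
    return float(result['train_correct_grids']) + float(result.get('test_correct_grids', 0))
```

Inside `arc_reward` the loop would then reduce to `rewards = [reward_from_result(result) for result in results]`.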

In [7]:
model = FastLanguageModel.get_peft_model(
    llm,
    r = cfg.lora_r, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = 64,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    use_rslora = cfg.use_rslora,
    # random_state = 3407,
)

Unsloth 2025.9.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
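
The trainer later reports 41,943,040 trainable parameters. That number follows from the LoRA rank and the layer shapes: each adapted linear layer of shape `d_in x d_out` adds `r * (d_in + d_out)` parameters (matrices A and B). The sketch assumes the standard Llama 3.1 8B dimensions (hidden size 4096, MLP size 14336, 8 key/value heads of size 128, 32 layers); `lora_trainable_params` is a hypothetical helper.

```python
def lora_trainable_params(r: int, hidden: int = 4096, intermediate: int = 14336,
                          kv_dim: int = 1024, num_layers: int = 32) -> int:
    """Count LoRA parameters (A: d_in x r, B: r x d_out) for the 7 projections of each Llama block."""
    shapes = [
        (hidden, hidden),        # q_proj
        (hidden, kv_dim),        # k_proj
        (hidden, kv_dim),        # v_proj
        (hidden, hidden),        # o_proj
        (hidden, intermediate),  # gate_proj
        (hidden, intermediate),  # up_proj
        (intermediate, hidden),  # down_proj
    ]
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes)
    return per_layer * num_layers


print(f"{lora_trainable_params(r=16):,}")  # 41,943,040
```

`use_rslora` only changes the adapter scaling (`lora_alpha / sqrt(r)` instead of `lora_alpha / r`), not the parameter count, while doubling `lora_r` to 32 would double it.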


In [8]:
# https://huggingface.co/docs/trl/main/en/grpo_trainer#trl.GRPOConfig
training_args = GRPOConfig(
    output_dir=cfg.output_dir,
    num_train_epochs=cfg.max_epochs,
    per_device_train_batch_size=cfg.training_batch_size,
    num_generations=cfg.num_generations,
    gradient_accumulation_steps=1, # number of micro-batches accumulated before each optimizer step
    learning_rate=cfg.learning_rate,
    # generation
    use_vllm=True,
    vllm_mode="server",
    max_completion_length=1024,
    max_prompt_length=None,
    temperature=1.0,
    top_p=0.95,
    # wandb
    report_to='wandb',
    run_name=os.path.basename(cfg.output_dir),
    # project=os.path.basename(os.path.dirname(cfg.output_dir)),
)
os.environ["WANDB_PROJECT"] = os.path.basename(os.path.dirname(cfg.output_dir))
# store the wandb run files inside the output dir as well
os.environ["WANDB_DIR"] = cfg.output_dir

print(f"Training arguments: {training_args}")
# Unsloth warns that `per_device_train_batch_size` must be a multiple of `num_generations` and overrides it. Why?
# I would prefer to use gradient accumulation to simulate larger batch sizes.

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 16
Training arguments: UnslothGRPOConfig(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
average_tokens_across_devices=False,
batch_eval_metrics=False,
beta=0.001,
bf16=False,
bf16_full_eval=False,
cache_implementation=None,
data_seed=3407,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
delta=None,
disable_dropout=

In [9]:
trainer = GRPOTrainer(
    model=model,
    reward_funcs=arc_reward, #reward_num_unique_letters,
    # data_collator=get_data_collator(tokenizer),
    args=training_args,
    train_dataset=grpo_dataset,
    completion_only_loss=True,
)
trainer.train()

2025-09-12 16:42:54,944 - trl.extras.vllm_client - INFO - check_server - Server is up!


INFO 09-12 16:42:55 [__init__.py:1418] Found nccl from library libnccl.so.2
INFO 09-12 16:42:55 [pynccl.py:70] vLLM is using nccl==2.26.2


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009}.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 10 | Num Epochs = 1 | Total steps = 10
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 1 x 1) = 16
 "-____-"     Trainable parameters = 41,943,040 of 8,072,204,288 (0.52% trained)
[34m[1mwandb[0m: Currently logged in as: [33mguillermobarbadillo[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


2025-09-12 16:43:07,133 - __main__ - INFO - arc_reward - Completions length: [240, 264, 244, 232, 269, 245, 241, 231, 271, 237, 240, 269, 240, 237, 239, 251]
2025-09-12 16:43:07,135 - __main__ - INFO - arc_reward - Example code:
from common import *

import numpy as np
from typing import *

# concepts:
# sorting, height mapping, color assignment

# description:
# In the input you will see a row of exactly 4 gray bars of different heights, each starting at the bottom of the canvas, and each separated by 1 pixel (so they are two pixels apart).
# Color the tallest one blue, the second tallest one red, the third tallest one green, and the shortest one yellow.

def transform(input_grid):
    # extract the bars, each of which is a connected component
    bars = find_connected_components(input_grid, background=Color.BLACK)

    # sort the bars by height
    bars = list(sorted(bars, key=lambda bar: np.sum(bar!= Color.BLACK), reverse=True))

    # color the bars
    output_grid = input_grid.cop

| Step | Training Loss | reward | reward_std | completions / mean_length | completions / min_length | completions / max_length | completions / clipped_ratio | completions / mean_terminated_length | completions / min_terminated_length | completions / max_terminated_length | sampling / sampling_logp_difference / mean | sampling / sampling_logp_difference / max | sampling / importance_sampling_ratio / min | sampling / importance_sampling_ratio / mean | sampling / importance_sampling_ratio / max | kl | rewards / arc_reward / mean | rewards / arc_reward / std |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.0 | 0.0 | 0.0 | 246.875 | 231.0 | 271.0 | 0.0 | 246.875 | 231.0 | 271.0 | 0 | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 0.0 |
| 2 | 0.0 | 0.0 | 0.0 | 211.6875 | 175.0 | 247.0 | 0.0 | 211.6875 | 175.0 | 247.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |
| 3 | 0.0 | 0.0 | 0.0 | 293.0625 | 222.0 | 457.0 | 0.0 | 293.0625 | 222.0 | 457.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |
| 4 | 0.0 | 0.0 | 0.0 | 402.0 | 333.0 | 575.0 | 0.0 | 402.0 | 333.0 | 575.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |
| 5 | 0.0 | 0.0 | 0.0 | 368.3125 | 178.0 | 474.0 | 0.0 | 368.3125 | 178.0 | 474.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |
| 6 | 0.0 | 0.0 | 0.0 | 323.875 | 251.0 | 385.0 | 0.0 | 323.875 | 251.0 | 385.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |
| 7 | 0.0 | 0.0 | 0.0 | 502.125 | 329.0 | 795.0 | 0.0 | 502.125 | 329.0 | 795.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |
| 8 | 0.0 | 0.0 | 0.0 | 389.4375 | 244.0 | 593.0 | 0.0 | 389.4375 | 244.0 | 593.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |
| 9 | 0.0 | 0.0 | 0.0 | 437.875 | 230.0 | 619.0 | 0.0 | 437.875 | 230.0 | 619.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |
| 10 | 0.0 | 0.0 | 0.0 | 283.3125 | 186.0 | 433.0 | 0.0 | 283.3125 | 186.0 | 433.0 | No Log | No Log | No Log | No Log | No Log | 0.0 | 0.0 | 0.0 |


2025-09-12 16:43:32,918 - __main__ - INFO - arc_reward - Completions length: [195, 218, 247, 202, 212, 206, 240, 240, 223, 226, 218, 205, 175, 186, 194, 200]
2025-09-12 16:43:32,920 - __main__ - INFO - arc_reward - Example code:
from common import *

import numpy as np
from typing import *

# concepts:
# topology, object boundary, region filling

# description:
# The input grid consists of a black background with some green pixels forming a closed shape. 
# To produce the output, you need to find the boundary of the closed shape and color the enclosed area yellow. 
# The output should retain the original boundary pixels but fill the interior with yellow.

def transform(input_grid):
    # Create an output grid initialized to black
    output_grid = np.copy(input_grid)

    # Find the boundary of the object
    boundary_mask = object_boundary(input_grid, background=Color.BLACK)

    # Create an interior mask for the region that is not on the boundary
    interior_mask = object_interior(i

Unsloth: Will smartly offload gradients to save VRAM!


2025-09-12 16:44:24,654 - __main__ - INFO - arc_reward - Completions length: [335, 224, 285, 293, 369, 327, 236, 294, 241, 457, 275, 270, 222, 290, 287, 284]
2025-09-12 16:44:24,656 - __main__ - INFO - arc_reward - Example code:
from common import *
import numpy as np
from typing import *

# concepts:
# bitmasks with separator, boolean logical operations

# description:
# Compute the AND operation of where the two grids are blue, turning the output red in those locations.
# In the input, you should see two 3x3 blue patterns on top and bottom separated by a horizontal gray line in the middle of the grid.
# To make the output, you have to overlap the two patterns. If both overlapping cells are blue, then the corresponding cell is colored red; 
# otherwise, if the overlapping cells are not both blue, then the corresponding cell is colored black.

def transform(input_grid):
    # Get the height and width of the input grid
    width, height = input_grid.shape

    # Find the gray horizontal

2025-09-12 16:44:53,336 - __main__ - INFO - arc_reward - Completions length: [386, 333, 371, 367, 336, 384, 340, 343, 406, 407, 514, 349, 534, 575, 422, 365]
2025-09-12 16:44:53,338 - __main__ - INFO - arc_reward - Example code:
from common import *

import numpy as np
from typing import *

# concepts:
# sliding objects, collision detection

# description:
# In the input you will see a 2x2 purple square and a set of red objects. 
# Slide each red object in any of the four directions until it just touches the purple square.

def transform(input_grid):
    # Get the purple object
    purple_object = np.zeros_like(input_grid)
    purple_object[input_grid == Color.PURPLE] = Color.PURPLE

    # Get the red objects
    red_objects = detect_objects(grid=input_grid, colors=[Color.RED], monochromatic=False, connectivity=4)
    
    # Start the output grid with just the purple object
    output_grid = np.copy(purple_object)

    # Slide each red object until it touches the purple square
    for 

2025-09-12 16:45:32,307 - __main__ - INFO - arc_reward - Completions length: [423, 470, 178, 424, 310, 474, 464, 361, 274, 187, 368, 474, 351, 275, 388, 472]
2025-09-12 16:45:32,309 - __main__ - INFO - arc_reward - Example code:
from common import *

import numpy as np
from typing import *

# concepts:
# color gradient, scaling

# description:
# In the input, you will see a small colored shape in the center of the grid. The shape has a color gradient from the center to the edges. 
# To create the output, scale the shape up by a factor of 2 while maintaining the color gradient, ensuring that the gradient smoothly transitions from the center outwards.

def transform(input_grid):
    # Detect the colored shape in the center of the grid
    objects = detect_objects(grid=input_grid, monochromatic=False, background=Color.BLACK, connectivity=8)

    # Initialize the output grid with the same size as the input grid
    output_grid = np.zeros_like(input_grid)

    for obj in objects:
        # 

2025-09-12 16:46:02,437 - __main__ - INFO - arc_reward - Completions length: [326, 335, 324, 358, 328, 374, 307, 294, 318, 251, 300, 309, 301, 331, 341, 385]
2025-09-12 16:46:02,439 - __main__ - INFO - arc_reward - Example code:
from common import *

import numpy as np
from typing import *

# concepts:
# reflection, color change, symmetry detection

# description:
# In the input, you will see a grid containing a pattern of blue pixels arranged in a way that exhibits translational symmetry.
# To create the output, reflect this pattern horizontally and change the color from blue to red. 
# The output grid should be large enough to accommodate the original pattern and its reflection.

def transform(input_grid):
    # Detect translational symmetries in the input grid
    symmetries = detect_translational_symmetry(input_grid, ignore_colors=[Color.BLACK])
    assert len(symmetries) > 0, "No translational symmetry found"

    # Create a new output grid with double the height to accommodate th

2025-09-12 16:46:39,318 - __main__ - INFO - arc_reward - Completions length: [368, 476, 481, 482, 662, 599, 495, 329, 795, 431, 682, 521, 448, 385, 409, 471]
2025-09-12 16:46:39,321 - __main__ - INFO - arc_reward - Example code:
from common import *
import numpy as np
from typing import *

# concepts:
# grid manipulation, regions, overlapping colors, color transformation

# description:
# In the input, you will see a grid divided by horizontal and vertical lines, forming rectangular regions. 
# Each region can contain colored pixels, and there are also special colored pixels (red and green).
# To make the output:
# 1. For each red pixel, fill the entire region it's in with the same color.
# 2. For each green pixel, fill the entire region it's in with the same color, but only if that region is not already filled with red.

def transform(input_grid: np.ndarray) -> np.ndarray:
    output_grid = np.copy(input_grid)

    # Find the divider color (assuming it's the most frequent non-backgrou

2025-09-12 16:48:19,489 - __main__ - INFO - arc_reward - Completions length: [418, 380, 294, 469, 254, 593, 443, 354, 244, 388, 260, 502, 396, 472, 441, 323]
2025-09-12 16:48:19,491 - __main__ - INFO - arc_reward - Example code:
from common import *

import numpy as np
from typing import *

# concepts:
# repetition, diagonal lines

# description:
# In the input you will see a 7x7 grid with three diagonal lines that stretch from one corner of the canvas to the other.
# Each line is a different color, and the colors are not black. The output should be the result of repeating every diagonal line
# on multiples of 2 offset from the original, which gives an interlacing pattern filling the output canvas.

def transform(input_grid: np.ndarray) -> np.ndarray:
    output_grid = np.zeros((7, 7), dtype=int)

    # Loop over the input grid to find diagonal lines
    for i in range(input_grid.shape[0]):
        for j in range(input_grid.shape[1]):
            c = input_grid[i][j]
            if c!=

2025-09-12 16:48:57,942 - __main__ - INFO - arc_reward - Completions length: [559, 421, 470, 230, 380, 619, 525, 571, 402, 428, 585, 442, 253, 362, 274, 485]
2025-09-12 16:48:57,944 - __main__ - INFO - arc_reward - Example code:
from common import *

import numpy as np
from typing import *

# concepts:
# translation, color swapping, pattern arrangement

# description:
# In the input, you will see a grid with a central square pattern and several colored rectangles surrounding it.
# To create the output, swap the colors of the central square with each surrounding rectangle, then move the rectangles outward in all four cardinal directions until they touch the edges of the grid.

def transform(input_grid: np.ndarray) -> np.ndarray:
    # Step 1: Find the central square and surrounding rectangles
    objects = find_connected_components(input_grid, monochromatic=False)
    
    # Identify the central square (assumed to be the largest connected component)
    central_square = max(objects, key

2025-09-12 16:50:17,671 - __main__ - INFO - arc_reward - Completions length: [329, 228, 365, 239, 326, 254, 277, 433, 217, 300, 228, 394, 302, 214, 186, 241]
2025-09-12 16:50:17,673 - __main__ - INFO - arc_reward - Example code:
from common import *

import numpy as np
from typing import *

# concepts:
# scaling, pattern replication, color mapping

# description:
# In the input you will see a 3x3 pattern of colored pixels.
# To create the output, scale the pattern to a 9x9 grid by replicating it three times in each direction,
# and then map the colors of the original pattern to the new scaled positions using a color mapping that 
# shifts the colors by one position in a cyclic manner.

def transform(input_grid):
    # Get the color mapping by extracting the unique colors in the 3x3 pattern
    unique_colors = np.unique(input_grid)
    color_mapping = {color: (unique_colors[(i + 1) % len(unique_colors)]) for i, color in enumerate(unique_colors)}

    # Create the output grid initialized

TrainOutput(global_step=10, training_loss=0.0, metrics={'train_runtime': 461.4252, 'train_samples_per_second': 0.022, 'train_steps_per_second': 0.022, 'total_flos': 0.0, 'train_loss': 0.0})

In [10]:
# use this to reset the vLLM server (it closes the weight-sync communicator)
# ! curl -X POST --location http://0.0.0.0:8000/close_communicator/

## Debug

There seem to be some compatibility problems:

```
When creating the training config:
Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.

When training:
AttributeError: 'UnslothGRPOConfig' object has no attribute 'delta'

Library versions:
unsloth                   2025.9.1                 pypi_0    pypi
unsloth-zoo               2025.9.2                 pypi_0    pypi
trl                       0.18.0.dev0              pypi_0    pypi

# pip index versions <package-name>
unsloth (2025.9.4)
Available versions: 2025.9.4, 2025.9.3, 2025.9.2, 2025.9.1,
trl (0.23.0)
Available versions: 0.23.0, 0.22.2, 0.22.1, 0.22.0, 0.21.0, 0.20.0, 0.19.1, 0.19.0, 0.18.2, 0.18.1, 0.18.0

I installed the latest versions of both libraries in the `arc25-unsloth` environment:
pip install unsloth==2025.9.4
pip install trl==0.23.0
pip install trl[vllm]

Then launching the server fails with this error:
NameError: name 'ParallelismConfig' is not defined. Did you mean: 'parallelism_config'?
Solved with: pip install --upgrade accelerate

I also had to remove the custom data collator.
```

This works, but it only ran one training step:

```
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 10 | Num Epochs = 1 | Total steps = 1
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 41,943,040 of 8,072,204,288 (0.52% trained)
```

If I reduce the gradient accumulation steps to 1 and increase the number of epochs to 3, it runs 30 steps.
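
These step counts can be reproduced with a back-of-the-envelope calculation: each optimizer step consumes `per_device_train_batch_size x gradient_accumulation_steps` completions, which is that number divided by `num_generations` prompts. The sketch assumes `num_generations=8` in the earlier runs (the value is not shown in the logs) and that an incomplete batch of prompts is dropped; both assumptions are consistent with the 1, 30 and 10 steps seen so far. `grpo_optimizer_steps` is a hypothetical helper, not a TRL function.

```python
def grpo_optimizer_steps(num_prompts: int, num_generations: int, per_device_batch_size: int,
                         grad_accum_steps: int, num_epochs: int = 1, num_gpus: int = 1) -> int:
    """Estimate optimizer steps for a GRPO run (assumes incomplete prompt batches are dropped)."""
    completions_per_step = per_device_batch_size * grad_accum_steps * num_gpus
    if completions_per_step % num_generations != 0:
        raise ValueError("completions per step must be a multiple of num_generations")
    prompts_per_step = completions_per_step // num_generations
    return (num_prompts // prompts_per_step) * num_epochs


print(grpo_optimizer_steps(10, num_generations=8, per_device_batch_size=8, grad_accum_steps=8))  # 1
print(grpo_optimizer_steps(10, num_generations=8, per_device_batch_size=8, grad_accum_steps=1, num_epochs=3))  # 30
print(grpo_optimizer_steps(10, num_generations=16, per_device_batch_size=16, grad_accum_steps=1))  # 10
```

Under these assumptions, with 64 completions per step a single update covers 8 of the 10 prompts and the remaining 2 are not enough for a second one.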

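Now the problem is that the model seems to predict only 256 output tokens. This matches the default value of `max_completion_length` in `GRPOConfig` (256); the config above sets it to 1024.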

Notice 

So far the model hasn't solved any of the training tasks, but it always seems to generate code in the expected format.
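
This also explains why the training loss in the table above is exactly 0.0: GRPO compares each completion with the other completions generated for the same prompt, so when all 16 completions get the same reward (0.0, with `reward_std` 0.0) every advantage is zero and there is no policy gradient to learn from. The sketch below shows the group-relative advantage under the usual `(r - mean) / (std + eps)` normalization; `group_advantages` is a hypothetical helper, and TRL's implementation may differ in details such as the epsilon or optional std scaling.

```python
import numpy as np


def group_advantages(rewards: list[float], num_generations: int, eps: float = 1e-4) -> np.ndarray:
    """Normalize each reward by the mean and sample std of the completions for the same prompt."""
    grouped = np.asarray(rewards, dtype=np.float64).reshape(-1, num_generations)
    mean = grouped.mean(axis=1, keepdims=True)
    std = grouped.std(axis=1, ddof=1, keepdims=True)
    return ((grouped - mean) / (std + eps)).reshape(-1)


# every completion gets the same reward -> all advantages are zero -> no learning signal
print(group_advantages([0.0] * 16, num_generations=16))
# one partially correct completion gets a positive advantage, the others a small negative one
print(group_advantages([0.0] * 15 + [1.0], num_generations=16).round(2))
```

Some reward variance within each group is therefore needed before training can make progress, for example from partial credit or easier tasks.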

## TODO

- [ ] Implement the reward function
- [ ] Check if memory is enough
- [ ] Can I reduce the back-and-forth of compute between the two GPUs (generation on GPU 0, training on GPU 1)?