# OpenEnv Sudoku with GRPO using TRL

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_sudoku_grpo.ipynb)

![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)

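## What are GRPO and OpenEnv

**GRPO** (Group Relative Policy Optimization) is a reinforcement learning method for language models. For each prompt, it samples a group of completions and scores each one with one or more reward functions. It then uses each completion's reward relative to the group average as its advantage, so no separate value model is needed. TRL implements it in `GRPOTrainer`.

**OpenEnv** is an open-source framework (installed below from the `meta-pytorch/OpenEnv` repository) for building and serving RL environments behind a Gym-style `reset()`/`step()` API. In this notebook, the model plays Sudoku from the TextArena collection of text games through an OpenEnv client, and the game outcomes are turned into GRPO rewards.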

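GRPO's group-relative advantage can be sketched in a few lines of plain Python. This is a simplified illustration, not TRL's implementation. It assumes the per-function rewards are combined by an optionally weighted sum (TRL exposes weights through its `reward_weights` option), and it normalizes by the group's sample standard deviation plus a small epsilon. The function name `group_relative_advantages` is hypothetical.

```python
import statistics
from typing import Optional


def group_relative_advantages(
    rewards_per_func: list[list[float]],
    weights: Optional[list[float]] = None,
    eps: float = 1e-4,
) -> list[float]:
    """Turn per-function rewards for one GRPO group into advantages.

    rewards_per_func[i][j] is the score that reward function j gave completion i.
    weights[j] scales reward function j (all 1.0 when omitted).
    """
    num_funcs = len(rewards_per_func[0])
    weights = weights if weights is not None else [1.0] * num_funcs
    # Weighted sum over reward functions -> one scalar reward per completion
    totals = [sum(w * r for w, r in zip(weights, scores)) for scores in rewards_per_func]
    mean = statistics.fmean(totals)
    std = statistics.stdev(totals)  # Sample standard deviation across the group
    # Advantage: how much better (or worse) each completion did than its group's average
    return [(t - mean) / (std + eps) for t in totals]


# Four rollouts of the same prompt, each scored by two reward functions
print(group_relative_advantages([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [-1.0, 0.0]]))
```

If every rollout in a group earns the same total reward, all advantages are zero and that group provides no learning signal. This is one reason dense shaping rewards, such as the progress and valid-move scores used later, are useful alongside a sparse win/loss reward.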
## Setup dependencies for training

Install the required libraries: TRL (with vLLM) for fine-tuning, OpenEnv for the reinforcement learning environment, and Trackio for experiment tracking.

In [None]:
# Versions pinned for reproducibility: OpenEnv at a specific commit and openenv_core 0.1.1
!pip install -Uq trl[vllm] trackio git+https://github.com/meta-pytorch/OpenEnv.git@bf5e968286e0d49cdc03fd904d48faff4b15a437 openenv_core==0.1.1

A valid Hugging Face token with write permissions is required to upload the fine-tuned model to the Hugging Face Hub. In Google Colab, the token can be read securely from Colab secrets. Elsewhere, it can be passed directly to `login()`.

In [1]:
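from huggingface_hub import login

# Log in to the Hugging Face Hub (the token needs write access to push the model)
try:
    from google.colab import userdata  # Only available inside Google Colab

    hf_token = userdata.get("HF_TOKEN")
except ImportError:
    hf_token = None  # Outside Colab, login() prompts for the token interactively

login(hf_token)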

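## Initialize OpenEnv's TextArena environment

The client below connects to a TextArena environment server hosted on a Hugging Face Space (`space_url`). It exposes `reset()` to start a new game and `step()` to submit a move. `rollout_once` uses both later.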

In [2]:
from envs.textarena_env import TextArenaEnv

space_url = "https://sergiopaniego-textarena.hf.space"

client = TextArenaEnv(base_url=space_url)

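## Create the rollout function and helpers

`rollout_func` replaces TRL's default generation step. For each prompt in the batch, it plays one full Sudoku episode with `rollout_once`. It returns the token IDs and log-probabilities for the policy update, along with per-episode reward signals. TRL forwards the extra keys, such as `progress_reward`, to the reward functions as keyword arguments. The system prompt below defines the game rules and the expected move format.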

In [3]:
# @title System prompt (click to expand)
SYSTEM_PROMPT = """You are an expert Sudoku player with deep knowledge of logical deduction strategies and number placement techniques.

## GAME RULES

1. The puzzle is a 9x9 grid divided into nine 3x3 subgrids (boxes)
2. Some cells are pre-filled with numbers 1-9
3. You must fill in the empty cells (shown as '.') with numbers 1-9
4. Each row must contain numbers 1-9 without repetition
5. Each column must contain numbers 1-9 without repetition
6. Each 3x3 subgrid must contain numbers 1-9 without repetition
7. You cannot overwrite pre-filled cells
8. Invalid moves result in penalties (-1 reward)

## RESPONSE FORMAT

**CRITICAL: Output ONLY the move, nothing else. No text, no explanation.**

Format: [row col number]

Examples:
- [5 3 7] → places 7 in row 5, column 3
- [1 2 4] → places 4 in row 1, column 2

## STRATEGIC APPROACH

Do not repeat the same move twice.

### Basic Strategies
- **Naked Singles**: If a cell has only one possible candidate, fill it in immediately.
- **Hidden Singles**: If a number can only go in one cell within a row, column, or box, place it there.
- **Scanning**: Look at each row, column, and box to find where specific numbers can go.

### Intermediate Strategies
- **Naked Pairs/Triples**: When two/three cells in a unit contain only the same candidates, eliminate those from other cells.
- **Hidden Pairs/Triples**: When numbers only appear in specific cells within a unit, those cells can only contain those numbers.
- **Pointing Pairs**: When a candidate in a box is restricted to a single row/column, eliminate it elsewhere.

### Solving Process
1. Start by scanning the entire grid to identify easy fills (cells with few candidates)
2. Look for rows, columns, or boxes with many numbers already placed
3. Fill all naked singles first
4. Then look for hidden singles in each row, column, and box
5. Apply more advanced techniques as needed

### Common Pitfalls to Avoid
- Don't guess randomly - Sudoku is pure logic
- Don't overlook any constraint (row, column, or box)
- Don't try to overwrite pre-filled cells
- Don't place invalid numbers (must be 1-9)
- Don't use invalid coordinates (must be 1-9)
- Don't repeat a move that was already made

## EXAMPLES

### Example 1: Naked Single
If row 3, column 4 can only contain the number 5:
[3 4 5]

### Example 2: Hidden Single
If the number 8 can only go in one cell in row 1:
[1 7 8]

### Example 3: Row Analysis
Row 2 is missing only value 5, and column 8 is the empty cell:
[2 8 5]

### Example 4: Box Analysis
In the center box, only one cell can contain 9:
[5 5 9]

## BOARD READING

The board is displayed as a 9x9 grid:
- Numbers 1-9 are pre-filled or already placed
- Empty cells are shown as '.'
- Rows are labeled R1-R9 (top to bottom)
- Columns are labeled C1-C9 (left to right)

Example board representation:
```
   C1 C2 C3   C4 C5 C6   C7 C8 C9
R1  .  8  9 |  1  .  . |  .  3  7
R2  2  7  1 |  9  4  3 |  6  .  8
R3  .  6  5 |  .  2  7 |  4  9  .
   - - - - - - - - - - - - - - - -
R4  .  .  . |  7  8  . |  9  2  3
R5  .  9  2 |  .  5  6 |  .  .  4
R6  7  3  8 |  .  .  2 |  1  .  .
   - - - - - - - - - - - - - - - -
R7  8  4  . |  .  .  9 |  5  .  .
R8  5  .  . |  6  .  8 |  3  4  9
R9  9  .  6 |  5  3  4 |  8  7  2
```

## COORDINATE REFERENCE

Row indices (top to bottom): 1, 2, 3, 4, 5, 6, 7, 8, 9
Column indices (left to right): 1, 2, 3, 4, 5, 6, 7, 8, 9

Subgrid layout:
```
Subgrid 1 | Subgrid 2 | Subgrid 3
  (R1-R3)    (R1-R3)     (R1-R3)
  (C1-C3)    (C4-C6)     (C7-C9)
----------+-----------+----------
Subgrid 4 | Subgrid 5 | Subgrid 6
  (R4-R6)    (R4-R6)     (R4-R6)
  (C1-C3)    (C4-C6)     (C7-C9)
----------+-----------+----------
Subgrid 7 | Subgrid 8 | Subgrid 9
  (R7-R9)    (R7-R9)     (R7-R9)
  (C1-C3)    (C4-C6)     (C7-C9)
```

## IMPORTANT CONSTRAINTS

- Coordinates are 1-indexed (1-9 for both row and column)
- Numbers must be 1-9
- One move per response
- Must be a valid move (no rule violations)
- Never repeat a previous move

## YOUR GOAL

Output ONLY your move in the format [row col number]. No explanation, no reasoning, just the move.
"""

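The coordinate reference in the prompt implies a simple mapping from a 1-indexed cell to its subgrid number. The sketch below shows that mapping. The helper `subgrid_index` is hypothetical and is not used elsewhere in the notebook. It is handy for checking the prompt's examples, for instance that `[5 5 9]` in Example 4 lies in the center box, subgrid 5.

```python
def subgrid_index(row: int, col: int) -> int:
    """Return the 1-indexed subgrid number for a 1-indexed (row, col) cell.

    Matches the layout in SYSTEM_PROMPT: subgrids 1-3 span R1-R3, 4-6 span R4-R6,
    7-9 span R7-R9, and within each band columns C1-C3, C4-C6, C7-C9 map left to right.
    """
    if not (1 <= row <= 9 and 1 <= col <= 9):
        raise ValueError(f"row and col must be in 1..9, got ({row}, {col})")
    return 3 * ((row - 1) // 3) + (col - 1) // 3 + 1


print(subgrid_index(5, 5))  # 5: the center box from Example 4
```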
In [4]:
from trl import GRPOTrainer

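max_turns = 100      # Upper bound on moves per episode
debug = False        # Print prompts, model outputs, and scores for every turn
difficulty = "easy"  # Hint level for make_compact_prompt: "easy", "medium", or "hard"
api_delay = 0.0      # Seconds to sleep after each environment call to avoid rate limiting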


def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
    all_prompt_ids = []
    all_completion_ids = []
    all_logprobs = []
    all_correct = []
    all_valid = []
    all_empty_cell = []
    all_repetition = []
    all_progress = []

    for _ in prompts:
        episode = rollout_once(
            trainer=trainer,
            env=client,
            tokenizer=trainer.processing_class,
            system_prompt=SYSTEM_PROMPT,
            max_turns=max_turns,
            debug=debug,
            difficulty=difficulty,
            api_delay=api_delay,
        )
        all_prompt_ids.append(episode["prompt_ids"])
        all_completion_ids.append(episode["completion_ids"])
        all_logprobs.append(episode["logprobs"])
        all_correct.append(episode["correct_reward"])
        all_valid.append(episode["valid_move_reward"])
        all_empty_cell.append(episode["empty_cell_reward"])
        all_repetition.append(episode["repetition_reward"])
        all_progress.append(episode["progress_reward"])

    return {
        "prompt_ids": all_prompt_ids,
        "completion_ids": all_completion_ids,
        "logprobs": all_logprobs,
        "correct_reward": all_correct,
        "valid_move_reward": all_valid,
        "empty_cell_reward": all_empty_cell,
        "repetition_reward": all_repetition,
        "progress_reward": all_progress,
    }
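Because `rollout_func` assembles several parallel lists, a mismatch such as a missing reward entry can surface later as a confusing shape error inside the trainer. A small check like the hypothetical `check_rollout_output` below can catch such problems early while you iterate on the rollout logic. It assumes, as in this notebook, that every returned list holds exactly one entry per prompt and that each completion has one log-probability per generated token.

```python
from typing import Any


def check_rollout_output(output: dict[str, list[Any]], num_prompts: int) -> None:
    """Raise if a rollout_func result does not have one entry per prompt.

    Assumes, as in this notebook, that every returned list is aligned with the prompts
    and that each completion has exactly one log-probability per generated token.
    """
    required = ("prompt_ids", "completion_ids", "logprobs")
    missing = [key for key in required if key not in output]
    if missing:
        raise KeyError(f"rollout output is missing keys: {missing}")
    # Token fields and extra reward fields must all have one entry per prompt
    for key, values in output.items():
        if len(values) != num_prompts:
            raise ValueError(f"{key!r} has {len(values)} entries, expected {num_prompts}")
    # Each completion token should come with its log-probability
    for i, (ids, logps) in enumerate(zip(output["completion_ids"], output["logprobs"])):
        if len(ids) != len(logps):
            raise ValueError(f"episode {i}: {len(ids)} completion tokens but {len(logps)} logprobs")


# Example with a single fake episode (token IDs and values are made up)
check_rollout_output(
    {"prompt_ids": [[1, 2]], "completion_ids": [[3, 4]], "logprobs": [[-0.1, -0.2]], "progress_reward": [0.1]},
    num_prompts=1,
)
```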


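### Define `rollout_once`

`rollout_once` plays a single game. On each turn, it builds a compact prompt from the current board, generates a move with `generate_rollout_completions`, sends the move to the environment, and scores the turn. The valid-move, empty-cell, and repetition scores are averaged over the episode. The correctness reward comes from the final turn, and progress is measured from the most cells filled. Only the last turn's prompt and completion tokens are returned for backpropagation, which keeps sequences short.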

In [5]:
from trl.experimental.openenv import generate_rollout_completions
from envs.textarena_env import TextArenaAction
from transformers import AutoTokenizer

import time
from collections import defaultdict


def rollout_once(
    trainer: GRPOTrainer,
    env: TextArenaEnv,
    tokenizer: AutoTokenizer,
    system_prompt: str,
    max_turns: int,
    debug: bool = False,
    difficulty: str = "hard",
    api_delay: float = 0.0,
) -> dict[str, list]:
    result = env.reset()
    time.sleep(api_delay)  # Avoid rate limiting
    observation = result.observation

    # Only store the LAST turn for backprop (much more efficient!)
    last_turn_data: dict | None = None

    valid_move_scores: list[float] = []
    empty_cell_scores: list[float] = []
    correct_scores: list[float] = []
    repetition_scores: list[float] = []

    move_counts: defaultdict[str, int] = defaultdict(int)

    # Track successful and failed moves for summary
    successful_moves: list[str] = []
    failed_moves: list[str] = []

    # Extract initial board state
    last_board_state = ""
    initial_filled = 0
    for message in observation.messages:
        if message.content and is_valid_board_state(message.content):
            last_board_state = message.content
            initial_filled = count_filled_cells(last_board_state)
            break

    max_filled = initial_filled  # Track max progress

    for turn in range(max_turns):
        if result.done:
            break

        # Build COMPACT prompt (saves tokens!)
        user_prompt = make_compact_prompt(
            board=last_board_state,
            step=turn + 1,
            successful_moves=successful_moves,
            failed_moves=failed_moves,
            difficulty=difficulty,
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False, enable_thinking=False
        )

        if debug:
            print(f"\n{'=' * 60}")
            print(f"STEP {turn + 1}")
            print(f"{'=' * 60}")
            print(f"USER PROMPT:\n{user_prompt}")
            print(f"{'=' * 60}")

        # Generate
        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]

        # Store ONLY this turn's data (replace previous)
        last_turn_data = {
            "prompt_ids": rollout_outputs["prompt_ids"],
            "completion_ids": rollout_outputs["completion_ids"],
            "logprobs": rollout_outputs["logprobs"],
        }

        if debug:
            step_tokens = len(rollout_outputs["prompt_ids"]) + len(rollout_outputs["completion_ids"])
            print(f"TOKENS: this_step={step_tokens} (only last turn used for backprop)")

        completion_text = rollout_outputs.get("text") or tokenizer.decode(
            rollout_outputs["completion_ids"], skip_special_tokens=True
        )

        # Extract move
        move = extract_sudoku_move(completion_text)

        if debug:
            print(f"MODEL OUTPUT: {completion_text}")
            print(f"EXTRACTED MOVE: {move}")

        # Step environment
        result = env.step(TextArenaAction(message=move))
        time.sleep(api_delay)  # Avoid rate limiting
        observation = result.observation
        correct_score = float(result.reward or 0.0)

        # Get feedback
        feedback = extract_feedback(observation)

        # Get environment response
        env_response = ""
        for msg in observation.messages:
            if msg.sender_id == -1:  # Environment message
                env_response = msg.content
                break

        if debug:
            print(
                f"ENV RESPONSE: {env_response[:200]}..."
                if len(env_response) > 200
                else f"ENV RESPONSE: {env_response}"
            )
            print(f"VALID: {feedback['valid_move']}, WARNING: {feedback['got_warning']}, REWARD: {correct_score}")

        # Calculate empty_cell_score
        if last_board_state and move:
            targets_empty = check_move_targets_empty_cell(move, last_board_state)
            empty_cell_score = 1.0 if targets_empty else -1.0
        else:
            empty_cell_score = 0.0

        # Calculate valid_move_score and repetition_score
        is_new_move = move_counts[move] == 0
        repetition_count = move_counts[move]
        move_counts[move] += 1

        # Exponential penalty for repetitions: -2^(n-1), capped at -10
        # 1st repeat: -1, 2nd: -2, 3rd: -4, 4th: -8, 5th+: -10 (capped)
        if repetition_count > 0:
            repetition_score = -min(2 ** (repetition_count - 1), 10.0)
        else:
            repetition_score = 0.0

        if debug:
            print(
                f"SCORES: empty_cell={empty_cell_score}, is_new={is_new_move}, repetitions={repetition_count}, rep_penalty={repetition_score}"
            )

        if not debug:
            print(f"Step {turn + 1}: {move}")

        if feedback["valid_move"] and is_new_move:
            valid_move_score = 1.0
            if move:
                successful_moves.append(move)  # Track for summary
        elif feedback["got_warning"]:
            valid_move_score = -0.5
            if move:
                failed_moves.append(move)  # Track for summary
        else:
            valid_move_score = 0.0

        # Update board state and track progress
        if feedback["board_state"] and is_valid_board_state(feedback["board_state"]):
            last_board_state = feedback["board_state"]
            current_filled = count_filled_cells(last_board_state)
            if current_filled > max_filled:
                max_filled = current_filled

        valid_move_scores.append(valid_move_score)
        empty_cell_scores.append(empty_cell_score)
        correct_scores.append(correct_score)
        repetition_scores.append(repetition_score)

    # Aggregate rewards
    correct_reward = correct_scores[-1] if correct_scores else 0.0
    valid_move_reward = sum(valid_move_scores) / len(valid_move_scores) if valid_move_scores else 0.0
    empty_cell_reward = sum(empty_cell_scores) / len(empty_cell_scores) if empty_cell_scores else 0.0
    repetition_reward = sum(repetition_scores) / len(repetition_scores) if repetition_scores else 0.0

    # Progress reward: how many cells we filled beyond initial state (normalized to 0-1)
    # 81 total cells, so (max_filled - initial_filled) / (81 - initial_filled) gives progress
    remaining_to_fill = 81 - initial_filled
    if remaining_to_fill > 0:
        progress_reward = (max_filled - initial_filled) / remaining_to_fill
    else:
        progress_reward = 1.0  # Already complete

    # Use ONLY last turn for backpropagation (much more efficient!)
    if last_turn_data:
        prompt_ids = last_turn_data["prompt_ids"]
        completion_ids = last_turn_data["completion_ids"]
        logprobs = last_turn_data["logprobs"]
    else:
        prompt_ids = []
        completion_ids = []
        logprobs = []

    total_tokens = len(prompt_ids) + len(completion_ids)
    cells_filled = max_filled - initial_filled
    print(
        f"Episode: empty_cell={empty_cell_reward:.2f}, valid={valid_move_reward:.2f}, "
        f"repetition={repetition_reward:.2f}, progress={progress_reward:.2f} ({cells_filled} cells), "
        f"correct={correct_reward:.2f}, tokens={total_tokens}"
    )

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "correct_reward": correct_reward,
        "valid_move_reward": valid_move_reward,
        "empty_cell_reward": empty_cell_reward,
        "repetition_reward": repetition_reward,
        "progress_reward": progress_reward,
    }
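The two shaped rewards in `rollout_once` are easy to check in isolation. The sketch below restates them as standalone functions, using the hypothetical names `repetition_penalty_for` and `progress_fraction`, and follows the same formulas. The repetition penalty doubles with each repeat of a move until it reaches the cap of 10. Progress is the fraction of initially empty cells that the episode filled.

```python
def repetition_penalty_for(previous_count: int, cap: float = 10.0) -> float:
    """Penalty for a move already played `previous_count` times (0 means it is new)."""
    if previous_count <= 0:
        return 0.0
    # -1, -2, -4, -8, then capped at -cap
    return -min(2.0 ** (previous_count - 1), cap)


def progress_fraction(initial_filled: int, max_filled: int, total_cells: int = 81) -> float:
    """Fraction of the initially empty cells that were filled during the episode."""
    remaining = total_cells - initial_filled
    if remaining <= 0:
        return 1.0  # The board started complete
    return (max_filled - initial_filled) / remaining


print([repetition_penalty_for(n) for n in range(6)])  # [0.0, -1.0, -2.0, -4.0, -8.0, -10.0]
print(progress_fraction(51, 54))  # 0.1
```

For example, a puzzle that starts with 51 filled cells (30 empty) and gains 3 correct placements scores a progress of 0.1.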

### Helper functions

Supporting utilities used in `rollout_once`:

In [6]:
# @title Helpers (click to expand)
import re

def sanitize_name(name: str) -> str:
    return name.replace("/", "-")


def extract_sudoku_move(text: str) -> str:
    """Extract a Sudoku move [row col number] from text."""
    # Try with spaces
    match = re.search(r"\[(\d)\s+(\d)\s+(\d)\]", text)
    if match:
        row, col, num = match.groups()
        return f"[{row} {col} {num}]"

    # Try without spaces
    match = re.search(r"\[(\d)(\d)(\d)\]", text)
    if match:
        row, col, num = match.groups()
        return f"[{row} {col} {num}]"

    return ""


def is_valid_board_state(board_str: str) -> bool:
    """Check if the string contains an actual Sudoku board."""
    return "R1" in board_str and "R9" in board_str and "|" in board_str


def parse_board(board_str: str) -> list[list[int]]:
    """Parse board string into 9x9 grid (0 = empty)."""
    grid = [[0] * 9 for _ in range(9)]
    if not is_valid_board_state(board_str):
        return grid

    for line in board_str.split("\n"):
        line_stripped = line.strip()
        if line_stripped and line_stripped[0] == "R" and len(line_stripped) > 1 and line_stripped[1].isdigit():
            row = int(line_stripped[1]) - 1  # 0-indexed
            cell_part = line_stripped[2:]
            col = 0
            for char in cell_part:
                if char == ".":
                    grid[row][col] = 0
                    col += 1
                elif char.isdigit():
                    grid[row][col] = int(char)
                    col += 1
    return grid


def count_filled_cells(board_str: str) -> int:
    """Count the number of filled cells in the board."""
    if not is_valid_board_state(board_str):
        return 0
    grid = parse_board(board_str)
    return sum(1 for row in grid for cell in row if cell != 0)


def get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[int]:
    """Get valid numbers for a cell based on Sudoku rules."""
    if grid[row][col] != 0:
        return set()

    used = set()

    # Check row
    for c in range(9):
        if grid[row][c] != 0:
            used.add(grid[row][c])

    # Check column
    for r in range(9):
        if grid[r][col] != 0:
            used.add(grid[r][col])

    # Check 3x3 box
    box_row, box_col = 3 * (row // 3), 3 * (col // 3)
    for r in range(box_row, box_row + 3):
        for c in range(box_col, box_col + 3):
            if grid[r][c] != 0:
                used.add(grid[r][c])

    return set(range(1, 10)) - used


def extract_empty_cells_with_candidates(
    board_str: str, sort_by_difficulty: bool = True
) -> list[tuple[int, int, set[int]]]:
    """Extract empty cells with their valid candidate numbers.

    Args:
        sort_by_difficulty: If True, sort by number of candidates (easiest first).
                           If False, keep natural order (top-left to bottom-right).
    """
    grid = parse_board(board_str)
    cells_with_candidates = []

    for row in range(9):
        for col in range(9):
            if grid[row][col] == 0:
                candidates = get_valid_numbers(grid, row, col)
                cells_with_candidates.append((row + 1, col + 1, candidates))  # 1-indexed

    if sort_by_difficulty:
        # Sort by number of candidates (easiest first = naked singles)
        cells_with_candidates.sort(key=lambda x: len(x[2]))

    return cells_with_candidates


def extract_empty_cells(board_str: str) -> list[tuple[int, int]]:
    """Extract list of empty cells (row, col) from board string."""
    empty_cells = []
    if not is_valid_board_state(board_str):
        return empty_cells

    for line in board_str.split("\n"):
        line_stripped = line.strip()
        if line_stripped and line_stripped[0] == "R" and len(line_stripped) > 1 and line_stripped[1].isdigit():
            row = int(line_stripped[1])
            cell_part = line_stripped[2:]
            col = 0
            for char in cell_part:
                if char == ".":
                    col += 1
                    empty_cells.append((row, col))
                elif char.isdigit():
                    col += 1
    return empty_cells


def extract_board_only(text: str) -> str:
    """Extract just the Sudoku grid from a message."""
    if not text:
        return ""

    lines = text.split("\n")
    board_lines = []
    in_board = False

    for line in lines:
        stripped = line.strip()
        if stripped.startswith("C1") or (
            stripped and stripped[0] == "R" and len(stripped) > 1 and stripped[1].isdigit()
        ):
            in_board = True
        if in_board and (stripped.startswith("-") or stripped.startswith("R") or stripped.startswith("C1")):
            board_lines.append(line)
        elif (
            in_board
            and stripped
            and not stripped.startswith("-")
            and not (stripped[0] == "R" and len(stripped) > 1 and stripped[1].isdigit())
        ):
            break

    return "\n".join(board_lines) if board_lines else ""


def make_compact_prompt(
    board: str,
    step: int,
    successful_moves: list[str],
    failed_moves: list[str],
    difficulty: str = "hard",
) -> str:
    """Create a compact prompt with only essential info (saves tokens!).

    Args:
        difficulty: Training difficulty level:
            - "easy": Show guaranteed moves (naked singles) + other options
            - "medium": Only show other options (hints where to look, not exact answers)
            - "hard": No hints (model must learn Sudoku rules by itself)
    """

    # Summary line
    cells_filled = len(successful_moves)
    summary = f"Step {step}. Progress: {cells_filled} cells filled."

    # Board (only show the grid, stripped down)
    board_only = extract_board_only(board) if board else "No board available."

    # Moves already tried (for learning what NOT to do)
    tried_moves_hint = ""
    all_tried = successful_moves + failed_moves
    if all_tried:
        tried_moves_hint = f"\n\n⚠️ MOVES ALREADY TRIED (do not repeat): {', '.join(all_tried)}"

    # Hints based on difficulty
    hints = ""
    if difficulty == "easy" and board:
        # Easy: sorted by difficulty, show guaranteed moves + other easy options
        cells_with_candidates = extract_empty_cells_with_candidates(board, sort_by_difficulty=True)
        if cells_with_candidates:
            guaranteed = []
            other_hints = []
            for row, col, candidates in cells_with_candidates[:10]:
                if len(candidates) == 1:
                    num = list(candidates)[0]
                    guaranteed.append(f"[{row} {col} {num}]")
                elif len(candidates) <= 3:
                    nums = ",".join(str(n) for n in sorted(candidates))
                    other_hints.append(f"({row},{col})→{nums}")

            if guaranteed:
                hints = f"\n\n🎯 GUARANTEED MOVES: {', '.join(guaranteed[:5])}"
            if other_hints:
                hints += f"\nOther options: {' | '.join(other_hints[:5])}"

    elif difficulty == "medium" and board:
        # Medium: NOT sorted, just show empty cells with candidates (no ordering hints)
        cells_with_candidates = extract_empty_cells_with_candidates(board, sort_by_difficulty=False)
        if cells_with_candidates:
            cell_hints = []
            for row, col, candidates in cells_with_candidates[:10]:
                nums = ",".join(str(n) for n in sorted(candidates))
                cell_hints.append(f"({row},{col})→{nums}")
            if cell_hints:
                hints = f"\n\nEmpty cells: {' | '.join(cell_hints)}"

    return f"{summary}\n\nBoard:\n{board_only}{tried_moves_hint}{hints}\n\nYour move:"


def check_move_targets_empty_cell(move: str, board_str: str) -> bool:
    """Check if the move targets an empty cell on the board."""
    if not move or not board_str:
        return False

    match = re.search(r"\[(\d)\s+(\d)\s+(\d)\]", move)
    if not match:
        return False

    row, col = int(match.group(1)), int(match.group(2))
    empty_cells = extract_empty_cells(board_str)
    return (row, col) in empty_cells


def extract_feedback(observation) -> dict:
    """Extract feedback from environment observation."""
    feedback = {"valid_move": True, "got_warning": False, "board_state": ""}

    if not observation or not observation.messages:
        return feedback

    for message in observation.messages:
        content = message.content.lower() if message.content else ""

        if any(kw in content for kw in ["invalid", "error", "cannot", "already", "violation", "lost"]):
            feedback["valid_move"] = False
            if "please resubmit" in content or "avoid penalties" in content:
                feedback["got_warning"] = True

        if message.content and "|" in message.content and "R1" in message.content:
            feedback["board_state"] = message.content

    return feedback
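`extract_sudoku_move` matches digits with `\d`, which also accepts `0`. An out-of-range move such as `[0 3 7]` is therefore sent to the environment and only rejected there. A stricter variant, hypothetical and not part of the original notebook, can restrict each digit to 1-9 and accept optional whitespace in a single pattern:

```python
import re

# Accepts "[5 3 7]", "[537]", or "[5  3 7]", but rejects digits outside 1-9 such as "[0 3 7]"
MOVE_PATTERN = re.compile(r"\[([1-9])\s*([1-9])\s*([1-9])\]")


def extract_sudoku_move_strict(text: str) -> str:
    """Return the first well-formed move as "[row col number]", or "" if none is found."""
    match = MOVE_PATTERN.search(text)
    if not match:
        return ""
    row, col, num = match.groups()
    return f"[{row} {col} {num}]"


print(extract_sudoku_move_strict("My move: [537]"))  # [5 3 7]
```

Unlike the original, this version returns the first match in the text regardless of spacing. The original prefers a spaced match even if an unspaced one appears earlier.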

## Define the reward functions

In [7]:
def reward_empty_cell(completions: list[str], **kwargs) -> list[float]:
    """Reward for targeting empty cells (learn to pick valid positions first)."""
    rewards = kwargs.get("empty_cell_reward")
    if rewards is None:
        return [0.0 for _ in completions]
    return [float(r) for r in rewards]


def reward_valid_moves(completions: list[str], **kwargs) -> list[float]:
    """Reward for making valid moves."""
    rewards = kwargs.get("valid_move_reward")
    if rewards is None:
        return [0.0 for _ in completions]
    return [float(r) for r in rewards]


def reward_correct(completions: list[str], **kwargs) -> list[float]:
    """Reward for solving the puzzle."""
    rewards = kwargs.get("correct_reward")
    if rewards is None:
        return [0.0 for _ in completions]
    return [float(r) for r in rewards]


def reward_repetition(completions: list[str], **kwargs) -> list[float]:
    """Penalty for repeating moves."""
    rewards = kwargs.get("repetition_reward")
    if rewards is None:
        return [0.0 for _ in completions]
    return [float(r) for r in rewards]


def reward_progress(completions: list[str], **kwargs) -> list[float]:
    """Reward for filling more cells in the board."""
    rewards = kwargs.get("progress_reward")
    if rewards is None:
        return [0.0 for _ in completions]
    return [float(r) for r in rewards]
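The five reward functions above differ only in the key they read from `kwargs`. To reduce repetition, a factory like the hypothetical `make_rollout_reward` below can build them from the key name. It sets `__name__` explicitly because TRL uses each reward function's name to label its logged metrics. Treat this as an optional refactor, not part of the original notebook.

```python
from typing import Callable


def make_rollout_reward(key: str) -> Callable[..., list[float]]:
    """Build a reward function that forwards a per-episode score returned by rollout_func."""

    def reward_func(completions: list[str], **kwargs) -> list[float]:
        rewards = kwargs.get(key)
        if rewards is None:
            return [0.0 for _ in completions]
        return [float(r) for r in rewards]

    # Distinct names keep the per-function reward metrics distinguishable in the logs
    reward_func.__name__ = f"reward_{key.removesuffix('_reward')}"
    return reward_func


reward_funcs = [
    make_rollout_reward(key)
    for key in ("empty_cell_reward", "valid_move_reward", "repetition_reward", "progress_reward", "correct_reward")
]
print([f.__name__ for f in reward_funcs])
```

`reward_funcs` could then be passed to `GRPOTrainer` in place of the explicit list.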

## Load the custom dataset

The dataset repeats the same prompt to control the total number of training episodes.

In GRPO, each dataset entry is sampled `num_generations` times, so every entry produces a group of rollout episodes. The prompt text itself is only a placeholder. `rollout_func` ignores it and builds each turn's messages from `SYSTEM_PROMPT` and the current board.

In [8]:
from datasets import Dataset

dataset_prompt = "Play Sudoku like an expert."
dataset_size = 1000

dataset = Dataset.from_dict({"prompt": [dataset_prompt] * dataset_size})

## Fine-tune using TRL and the GRPOTrainer

Next, define the `GRPOConfig`, which sets the key training parameters.

It controls how generation runs in vLLM, how memory and compute are managed, and how training metrics are logged (to Trackio here) during fine-tuning.

In [9]:
from trl import GRPOConfig

output_dir = "sudoku-grpo-qwen3"

grpo_config = GRPOConfig(
    # Generation with vLLM
    use_vllm=True,                                            # Use the vLLM engine for fast generation
    vllm_mode="colocate",                                     # "colocate" runs generation on the same GPU as training
    vllm_gpu_memory_utilization=0.1,                          # Fraction of GPU memory reserved for vLLM
    vllm_max_model_length=2560,                               # Context limit for vLLM, kept far below Qwen3-1.7B's 40960 to save memory

    # Training schedule
    output_dir=output_dir,
    num_train_epochs=1,
    learning_rate=5e-6,
    gradient_accumulation_steps=8,
    per_device_train_batch_size=1,
    warmup_steps=20,
    num_generations=8,                                        # Rollouts (episodes) per prompt in each GRPO group
    max_completion_length=8,                                  # A move such as "[5 3 7]" needs only a few tokens

    # Logging and checkpoints
    logging_steps=1,
    save_strategy="steps",
    save_steps=10,
    report_to="trackio",
    trackio_space_id=output_dir,

    # Memory savings
    gradient_checkpointing=True,                              # Recompute activations during backpropagation to save memory
    gradient_checkpointing_kwargs={"use_reentrant": False},   # Additional args to prevent warnings during gradient checkpointing
    # chat_template_kwargs={"enable_thinking": False},        # Not needed: rollout_once already passes enable_thinking=False
)
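A little arithmetic shows how these settings translate into rollouts. The sketch below assumes a single GPU and TRL's default generation batch of `per_device_train_batch_size × gradient_accumulation_steps` completions. The helper name `grpo_batch_layout` is hypothetical.

```python
def grpo_batch_layout(
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_generations: int,
    num_processes: int = 1,
) -> tuple[int, int]:
    """Return (completions per generation batch, distinct prompts per batch)."""
    completions = per_device_train_batch_size * gradient_accumulation_steps * num_processes
    # GRPO needs whole groups: every prompt must receive exactly num_generations rollouts
    if completions % num_generations != 0:
        raise ValueError(f"{completions} completions cannot be split into groups of {num_generations}")
    return completions, completions // num_generations


completions, prompts = grpo_batch_layout(
    per_device_train_batch_size=1, gradient_accumulation_steps=8, num_generations=8
)
print(f"{completions} episodes per batch, from {prompts} prompt(s)")  # 8 episodes per batch, from 1 prompt(s)
```

Under these assumptions, each batch forms a single GRPO group of 8 Sudoku episodes, and their rewards are compared against one another.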

In [10]:
model_name = "Qwen/Qwen3-1.7B"

In [11]:
trainer = GRPOTrainer(
    model=model_name,
    reward_funcs=[
        reward_empty_cell,  # Learn to pick empty cells
        reward_valid_moves,  # Learn valid numbers
        reward_repetition,  # Penalize repeating moves
        reward_progress,  # Reward filling more cells
        reward_correct,  # Solve the puzzle
    ],
    train_dataset=dataset,
    args=grpo_config,
    rollout_func=rollout_func,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.


Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 5/5 [00:00<00:00, 17.76it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 4/4 [00:00<00:00, 21.47it/s]


In [None]:
trainer_stats = trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


* Trackio project initialized: huggingface
* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/sudoku-grpo-qwen3-dataset
* Found existing space: https://huggingface.co/spaces/sergiopaniego/sudoku-grpo-qwen3
* View dashboard by going to: https://sergiopaniego-sudoku-grpo-qwen3.hf.space/


* Created new run: sergiopaniego-1767180324
Step 1: [1 3 3]
Step 2: [1 5 4]
Step 3: [3 1 4]
Step 4: [2 1 2]
Step 5: [1 1 1]
Step 6: [2 2 7]
Step 7: [2 4 9]
Step 8: [3 8 8]
Step 9: [4 9 5]
Step 10: [2 3 8]
Step 11: [3 4 7]
Step 12: [4 4 6]
Step 13: [4 6 2]
Step 14: [7 8 6]
Step 15: [5 3 6]
Step 16: [6 6 5]
Step 17: [7 6 9]
Step 18: [8 1 6]
Step 19: [7 1 3]
Step 20: [6 4 3]
Step 21: [5 4 1]
Step 22: [5 6 4]
Step 23: [8 9 8]
Step 24: [6 7 7]
Step 25: [7 7 6]
Step 26: [8 2 2]
Step 27: [8 4 4]
Step 28: [9 4 5]
Step 29: [9 6 1]
Step 30: [9 8 2]
Step 31: [9 9 3]
Episode: empty_cell=0.94, valid=0.95, repetition=0.00, progress=1.00 (30 cells), correct=1.00, tokens=2036
Step 1: [1 4 5]
Step 2: [4 9 5]
Step 3: [3 2 5]
Step 4: [4 9 4]
Step 5: [4 9 4]
Episode: empty_cell=1.00, valid=0.20, repetition=-0.20, progress=0.07 (2 cells), correct=-1.00, tokens=1847
Step 1: [1 1 6]
Step 2: [2 4 1]
Step 3: [3 1 9]
Step 4: [4 5 8]
Step 5: [2 9 2]
Step 6: [4 5 8]
Step 7: [3 6 2]
Step 8: [3 9 6]
Step 9: [4 5 8]

| Step | Training Loss |
|---|---|
| 1 | 0.0035 |
| 2 | -0.007 |
| 3 | 0.0304 |
| 4 | 0.1001 |
| 5 | -0.0186 |
| 6 | -0.0116 |
| 7 | 0.0484 |
| 8 | 0.0362 |
| 9 | 0.0204 |
| 10 | 0.0234 |


Step 1: [1 2 4]
Step 2: [1 8 7]
Step 3: [2 3 7]
Step 4: [4 5 6]
Step 5: [4 5 3]
Episode: empty_cell=0.20, valid=0.50, repetition=0.00, progress=0.10 (3 cells), correct=-1.00, tokens=1852
Step 1: [1 6 1]
Step 2: [1 2 9]
Step 3: [2 8 5]
Step 4: [2 6 3]
Step 5: [2 8 5]
Step 6: [2 8 5]
Episode: empty_cell=1.00, valid=0.33, repetition=-0.50, progress=0.10 (3 cells), correct=-1.00, tokens=1860
Step 1: [1 4 3]
Step 2: [3 8 3]
Step 3: [3 5 7]
Step 4: [4 7 5]
Step 5: [4 7 6]
Episode: empty_cell=0.20, valid=0.50, repetition=0.00, progress=0.10 (3 cells), correct=-1.00, tokens=1854
Step 1: [1 4 3]
Step 2: [2 8 8]
Step 3: [2 7 4]
Step 4: [2 1 5]
Step 5: [3 1 4]
Step 6: [2 2 6]
Step 7: [4 5 7]
Step 8: [4 1 1]
Step 9: [2 9 3]
Step 10: [3 4 8]
Step 11: [3 9 1]
Step 12: [4 6 9]
Step 13: [5 2 5]
Step 14: [5 4 6]
Step 15: [5 6 3]
Step 16: [5 8 1]
Step 17: [5 7 9]
Step 18: [7 8 5]
Step 19: [7 8 6]
Step 20: [7 8 6]
Episode: empty_cell=0.80, valid=0.88, repetition=-0.05, progress=0.60 (18 cells), correct=-

In this step, the fine-tuned model is saved locally and uploaded to the Hugging Face Hub using the configured account credentials.

In [None]:
trainer.save_model(output_dir)
trainer.push_to_hub()