In [1]:
import os
# Restrict this process to GPU index 2. This runs before the cell that imports torch.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [2]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import LoraConfig, get_peft_model
ckpt = "google/gemma-3-4b-it"

# Load the base checkpoint in bfloat16 with eager attention and automatic device placement.
# NOTE: newer transformers releases warn that `torch_dtype` is deprecated in favour of
# `dtype`; the older keyword is kept here so the cell also runs on earlier releases.
model = AutoModelForImageTextToText.from_pretrained(
    ckpt, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
)

# Load LoRA: rank-16 adapters (alpha 32) attached to every linear layer.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
# print_trainable_parameters() writes the summary itself and returns None, so it is
# called directly; wrapping it in print() emitted a stray "None" line.
model.print_trainable_parameters()

# The processor comes from the same checkpoint as the model.
processor = AutoProcessor.from_pretrained(ckpt)

  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.50it/s]


trainable params: 38,497,792 || all params: 4,338,577,264 || trainable%: 0.8873
None


Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 24672.38it/s]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
import re
from datasets import load_dataset, Dataset

# System message prepended to every training prompt. It asks for reasoning inside
# <think> tags followed by the guess inside <guess> tags.
SYSTEM_PROMPT = (
    "\n"
    "You are playing Wordle, a word-guessing game.\n"
    "\n"
    "Think through the problem and feedback step by step. "
    "Make sure to first add your step by step thought process within <think> </think> tags. "
    "Then, return your guessed word in the following format: <guess> guessed-word </guess>.\n"
)

# Template with {reasoning} and {answer} placeholders, one tag or placeholder per line.
# NOTE(review): it uses <reasoning>/<answer> tags rather than the <think>/<guess> tags
# requested by SYSTEM_PROMPT, and no visible cell references it -- confirm it is still needed.
XML_COT_FORMAT = "".join(
    piece + "\n"
    for piece in (
        "<reasoning>",
        "{reasoning}",
        "</reasoning>",
        "<answer>",
        "{answer}",
        "</answer>",
    )
)

def extract_xml_answer(text: str) -> str:
    """Return the word inside the last ``<guess>`` tag of ``text``.

    The guess is the text after the last ``<guess>`` and before the next
    ``</guess>``, with surrounding whitespace removed. A missing closing tag is
    tolerated: everything after the opening tag is used.

    Args:
        text: A model completion.

    Returns:
        The stripped guess, or ``""`` when ``text`` has no ``<guess>`` tag.
    """
    _, opening, after_open = text.rpartition("<guess>")
    if not opening:
        # Without an opening tag there is no guess. Returning the raw completion
        # here would let an untagged answer be scored as a correct guess.
        return ""
    guess, _, _ = after_open.partition("</guess>")
    return guess.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def get_wordle_dataset(split="train") -> Dataset:
    """Load ``predibase/wordle-grpo`` and map it into prompt/answer form.

    Every row is passed through ``Dataset.map`` with a function that returns:

    * ``prompt``: a two-message chat made of ``SYSTEM_PROMPT`` as the system turn
      and the row's own ``prompt`` text as the user turn;
    * ``answer``: the row's ``secret`` word.

    Args:
        split: Dataset split handed to ``load_dataset``.

    Returns:
        The mapped dataset.
    """
    def to_chat_example(row):
        # Use the full prompt from the dataset as the user turn.
        chat = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['prompt']},
        ]
        # The secret word is the final answer to guess.
        return {'prompt': chat, 'answer': row['secret']}

    raw = load_dataset("predibase/wordle-grpo", split=split)
    return raw.map(to_chat_example)

# Training data: the "train" split, already mapped into prompt/answer form.
dataset = get_wordle_dataset(split="train")

In [4]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Give 2.0 to each completion whose extracted guess equals its answer.

    The comparison ignores case. As a debugging aid, the first prompt's last
    message, the first answer, the first response and its extracted guess are
    printed on every call.

    Args:
        prompts: Chat prompts; only ``prompts[0][-1]['content']`` is read, for the trace.
        completions: One list of messages per completion; ``[0]['content']`` is scored.
        answer: Target words, paired with ``completions`` by position.

    Returns:
        One reward per (completion, answer) pair: 2.0 on a match, else 0.0.
    """
    texts = []
    for completion in completions:
        texts.append(completion[0]['content'])
    last_prompt_message = prompts[0][-1]['content']
    guesses = list(map(extract_xml_answer, texts))

    # Trace the first sample of the batch.
    print('-'*20, f"Prompt:\n{last_prompt_message}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{texts[0]}", f"\nExtracted Guess:\n{guesses[0]}")

    rewards = []
    for guess, target in zip(guesses, answer):
        hit = guess.lower() == target.lower()
        rewards.append(2.0 if hit else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that follow the newline-delimited tag layout exactly.

    The completion must start with a ``<think>`` line, contain a ``</think>`` line
    directly followed by a ``<guess>`` line, and end with a ``</guess>`` line.
    Matching is anchored at the start of the text; the closing anchor also accepts
    a single final newline after ``</guess>``.

    Args:
        completions: One list of messages per completion; ``[0]['content']`` is checked.

    Returns:
        0.5 for each completion that matches, otherwise 0.0.
    """
    strict_layout = re.compile(
        r"^<think>\n.*?\n</think>\n<guess>\n.*?\n</guess>$", re.DOTALL
    )
    rewards = []
    for completion in completions:
        body = completion[0]['content']
        rewards.append(0.5 if strict_layout.match(body) else 0.0)
    return rewards

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 when a think block is followed, anywhere later, by a guess block.

    A ``<think>...</think>`` pair must appear before a ``<guess>...</guess>``
    pair. Any text, including newlines, may surround or separate them.

    Args:
        completions: One list of messages per completion; ``[0]['content']`` is checked.

    Returns:
        0.5 for each completion containing both blocks in that order, otherwise 0.0.
    """
    loose_layout = re.compile(r"<think>.*?</think>.*?<guess>.*?</guess>", flags=re.DOTALL)
    return [
        0.5 if loose_layout.search(completion[0]['content']) is not None else 0.0
        for completion in completions
    ]

def count_xml(text) -> float:
    """Score how closely ``text`` follows the newline-delimited tag layout.

    Each of the four markers ``<think>`` + newline, newline + ``</think>`` + newline,
    newline + ``<guess>`` + newline and newline + ``</guess>`` earns 0.125 when it
    occurs exactly once. Text after the last closing guess tag costs 0.001 per
    character; one newline directly after the tag is free. That cost is charged
    once with the opening-guess credit and once with the closing-guess credit.

    The trailing length is measured from the last closing tag. Previously a
    completion ending in ``</guess>`` with no final newline was charged for its
    entire length, and the closing-tag branch then added 0.001 back as a bonus.
    A completion with no closing tag gets no trailing-text charge; it just
    misses the closing-tag credit.

    Args:
        text: A model completion.

    Returns:
        The format score; 0.5 for a well-formed completion with nothing trailing.
    """
    score = 0.0
    if text.count("<think>\n") == 1:
        score += 0.125
    if text.count("\n</think>\n") == 1:
        score += 0.125

    # Measure what follows the last closing tag, ignoring one separating newline.
    _, closing, tail = text.rpartition("\n</guess>")
    if closing and tail.startswith("\n"):
        tail = tail[1:]
    trailing_penalty = len(tail) * 0.001 if closing else 0.0

    if text.count("\n<guess>\n") == 1:
        score += 0.125
        score -= trailing_penalty
    if text.count("\n</guess>") == 1:
        score += 0.125
        score -= trailing_penalty
    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion's first message with ``count_xml``.

    Args:
        completions: One list of messages per completion; ``[0]['content']`` is scored.

    Returns:
        One ``count_xml`` score per completion, in input order.
    """
    texts = []
    for completion in completions:
        texts.append(completion[0]['content'])
    return list(map(count_xml, texts))




In [5]:
from trl import GRPOConfig, GRPOTrainer

# Token budget: the completion gets what is left of the sequence after the prompt
# (1024 - 256 = 768 tokens).
max_prompt_length = 256
max_seq_length = 1024

training_args = GRPOConfig(
    # --- optimiser and learning-rate schedule ---
    optim="adamw_8bit",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    max_grad_norm=0.1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # --- batching and generations per prompt ---
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    num_generations=2,
    # --- sequence lengths ---
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    # --- run length, checkpointing and logging ---
    num_train_epochs=1,
    max_steps=100,        #250
    save_steps=100,       #250
    logging_steps=1,
    report_to="none",
)

INFO 09-08 17:44:19 [__init__.py:241] Automatically detected platform cuda.


In [None]:
# Build the GRPO trainer around the LoRA-wrapped model. Three reward functions score
# the tag format and one scores whether the guess is correct.
trainer = GRPOTrainer(
    args=training_args,
    model=model,
    processing_class=processor,
    train_dataset=dataset,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        correctness_reward_func,
    ],
)

trainer.train()


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.
`generation_config` default values have been modified to match model-specific defaults: {'top_p': 0.95}. If this is not desired, please set these values explicitly.


-------------------- Prompt:
<|im_start|>system

You are playing Wordle, a word-guessing game.

### Game Rules:
- You have **6 tries** to guess a secret **5-letter** word.
- Each guess must be a valid **5-letter English word**.
- After each guess, you will receive feedback indicating how close your guess was.

### Feedback Format:
Each letter in your guess will receive one of three symbols:
1. ✓ : The letter is in the word and in the CORRECT position.
2. - : The letter is in the word but in the WRONG position.
3. x : The letter is NOT in the word.

### Example:
Secret Word: BRISK

Guess 1: STORM → Feedback: S(-) T(x) O(x) R(-) M(x)
Guess 2: BRAVE → Feedback: B(✓) R(✓) A(x) V(x) E(x)
Guess 3: BRISK → Feedback: B(✓) R(✓) I(✓) S(✓) K(✓)

### Response Format:
Think through the problem and feedback step by step. Make sure to first add your step by step thought process within <think> </think> tags. Then, return your guessed word in the following format: <guess> guessed-word </guess>.
<|im_en

Step,Training Loss
1,0.0
2,0.0
3,0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0


-------------------- Prompt:
<|im_start|>system

You are playing Wordle, a word-guessing game.

### Game Rules:
- You have **6 tries** to guess a secret **5-letter** word.
- Each guess must be a valid **5-letter English word**.
- After each guess, you will receive feedback indicating how close your guess was.

### Feedback Format:
Each letter in your guess will receive one of three symbols:
1. ✓ : The letter is in the word and in the CORRECT position.
2. - : The letter is in the word but in the WRONG position.
3. x : The letter is NOT in the word.

### Example:
Secret Word: BRISK

Guess 1: STORM → Feedback: S(-) T(x) O(x) R(-) M(x)
Guess 2: BRAVE → Feedback: B(✓) R(✓) A(x) V(x) E(x)
Guess 3: BRISK → Feedback: B(✓) R(✓) I(✓) S(✓) K(✓)

### Response Format:
Think through the problem and feedback step by step. Make sure to first add your step by step thought process within <think> </think> tags. Then, return your guessed word in the following format: <guess> guessed-word </guess>.
<|im_en