<a href="https://colab.research.google.com/github/fabiodr/colabs/blob/main/%F0%9F%A7%A0_Gemma3_THINKING.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Let's make Gemma 3 think! 🧠

In this notebook we'll use TRL and `GRPOTrainer` to make Gemma3 think before it answers.

👩‍🎓 If you want to learn more about making models think and reason, check out [The Reasoning Course](https://huggingface.co/reasoning-course).

### Installation

In [None]:
# install this release tag of transformers
!pip install -qqq git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3 \
                  git+https://github.com/huggingface/trl.git@main \
                  bitsandbytes

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import LoraConfig, get_peft_model
ckpt = "google/gemma-3-4b-it"

model = AutoModelForImageTextToText.from_pretrained(
    ckpt, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
)

# Load LoRA
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints the summary itself and returns None

processor = AutoProcessor.from_pretrained(ckpt)

trainable params: 42,734,080 || all params: 4,342,813,552 || trainable%: 0.9840


### Process data to create reasoning chains

Borrowing from [Will Brown's gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) we'll make reasoning chains from GSM8k.

In [None]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Each example becomes a system + user prompt; add example turns between them for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()
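
To see what the two extraction helpers do, here is a small sketch that runs them on made-up strings. The sample solution and completion are hypothetical, not taken from the dataset, and `Optional[str]` stands in for `str | None` so the snippet also runs on older Python versions.

```python
from typing import Optional

def extract_xml_answer(text: str) -> str:
    # Keep whatever sits between the last <answer> tag and the next </answer> tag
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> Optional[str]:
    # GSM8K solutions end with "#### <final answer>"
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Hypothetical GSM8K-style solution: the gold answer follows "####"
gold_solution = "She sold 48 / 2 = 24 clips in May.\n48 + 24 = 72\n#### 72"
# Hypothetical completion in the format requested by SYSTEM_PROMPT
completion = "<reasoning>\n48 + 24 = 72\n</reasoning>\n<answer>\n72\n</answer>\n"

gold = extract_hash_answer(gold_solution)
predicted = extract_xml_answer(completion)
no_tags = extract_xml_answer("The answer is 72")

print(gold)       # 72
print(predicted)  # 72
print(no_tags)    # The answer is 72
```

Note the last case: when the tags are missing, `extract_xml_answer` returns the whole text, so a comparison against the gold answer fails unless the model follows the format.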

# Reward Functions

Now let's define the reward functions. Each one scores every completion, and together they tell the trainer which completions to prefer.

| Reward Function | Purpose |
|---|---|
| `correctness_reward_func` | Rewards the model when its answer matches the correct answer |
| `int_reward_func` | Rewards the model for providing a numeric answer |
| `strict_format_reward_func` and `soft_format_reward_func` | Reward the model for following the specified format |
| `xmlcount_reward_func` | Rewards proper XML tag usage and penalizes extra content after the closing tags |
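
Before writing the real functions, it helps to see the shape of the inputs they receive. The sketch below assumes the conversational format used in this notebook: each completion is a list of message dicts, and extra dataset columns such as `answer` arrive as lists with one entry per completion. `toy_correctness_reward` and the sample data are hypothetical and only illustrate the calling convention.

```python
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def toy_correctness_reward(prompts, completions, answer, **kwargs) -> list[float]:
    # Each completion is a list of messages; the generated text is in the first one
    responses = [completion[0]["content"] for completion in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    # One reward per completion, in the same order as the inputs
    return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]

# Hypothetical batch: the same prompt sampled twice
prompts = [[{"role": "user", "content": "What is 6 * 7?"}]] * 2
completions = [
    [{"role": "assistant", "content": "<reasoning>\n6 * 7 = 42\n</reasoning>\n<answer>\n42\n</answer>\n"}],
    [{"role": "assistant", "content": "<reasoning>\n6 + 7 = 13\n</reasoning>\n<answer>\n13\n</answer>\n"}],
]
answer = ["42", "42"]

rewards = toy_correctness_reward(prompts, completions, answer)
print(rewards)  # [2.0, 0.0]
```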

In [None]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that match the expected format exactly, newlines included."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` span newlines, so multi-line reasoning can still match
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that begin with a reasoning block followed by an answer block."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` span newlines, so multi-line reasoning can still match
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
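
As a quick sanity check, we can run `count_xml` on a few hand-written strings. The sample completions below are hypothetical; the function body is copied from the cell above.

```python
def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # Penalize characters that follow the closing answer tag
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

well_formed = "<reasoning>\n6 * 7 = 42\n</reasoning>\n<answer>\n42\n</answer>\n"
with_trailing_text = well_formed + "extra"  # 5 stray characters after the tags
no_tags = "42"

score_well_formed = count_xml(well_formed)    # all four tags found, no penalty
score_trailing = count_xml(with_trailing_text)  # slightly lower: both penalties apply
score_no_tags = count_xml(no_tags)            # no tags, no reward

print(score_well_formed, round(score_trailing, 3), score_no_tags)
```

Each of the four tags is worth 0.125, so a clean completion earns 0.5. The trailing text is penalized twice (once per closing-tag check) at 0.001 per character, which keeps the penalty small compared with the tag rewards.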

# Train with GRPOTrainer

Now we'll configure training with `GRPOConfig`.

In [None]:
from trl import GRPOConfig, GRPOTrainer

max_prompt_length = 256
max_seq_length = 1024


training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 1,
    num_generations = 2,
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    num_train_epochs = 1,
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none",
)
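
To build some intuition for `num_generations`, here is a rough sketch of how GRPO turns rewards into a training signal: the completions sampled for one prompt form a group, and each completion is scored relative to the rest of its group. This is a simplified illustration, not TRL's actual code. The helper `group_advantages`, the equal weighting of reward functions, the population standard deviation, and the `eps` value are all assumptions made for the example.

```python
from statistics import mean, pstdev

def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Hypothetical helper: normalize rewards within one group of completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population std; eps avoids division by zero
    return [(r - mu) / (sigma + eps) for r in rewards]

# Rows are reward functions, columns are the num_generations = 2 completions of one prompt
per_function_rewards = [
    [0.5, 0.5],  # a format-style reward: both completions are well formed
    [2.0, 0.0],  # a correctness-style reward: only the first answer is right
]

# Assumption: rewards from the different functions are added with equal weight
total = [sum(column) for column in zip(*per_function_rewards)]
advantages = group_advantages(total)

# If every completion in the group gets the same reward, nothing stands out
tied = group_advantages([0.5, 0.5])

print(total)  # [2.5, 0.5]
print(tied)   # [0.0, 0.0]
```

With only two generations per prompt, a group carries a signal only when the two completions receive different total rewards, which is one reason larger `num_generations` values are often used when memory allows.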

# Training Run

In [None]:
processor.pad_token_id = 0
processor.bos_token_id = 1
processor.eos_token_id = 2

trainer = GRPOTrainer(
    model = model,
    processing_class = processor,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)

trainer.train()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


-------------------- Question:
Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers? 
Answer:
100 
Response:
<reasoning>
Let $A$ be Ahmed's grade and $E$ be Emily's grade.
Ahmed has a total grade of 91 from the first 9 assignments.
Emily has a total grade of 92 from the first 9 assignments and a final grade of 90.
So, Emily's total grade is $92 + 90 = 182$.
Ahmed needs to beat Emily, so $A > E$.
We want to find the minimum grade $A$ Ahmed needs to get on the final assignment such that $A > 182$. Since $A$ must be a whole number, we need $A \geq 183$.

However, each assignment is worth the same amount, so we can calculate the average grade for Ahmed and Emily.
Ahmed's averag

Step,Training Loss
1,0.0
2,-0.0
3,-0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0
9,0.0
10,0.0


Streaming output truncated to the last 5000 lines.
12 = M + 5
Subtract 5 from both sides:
M = 12 - 5
M = 7
So, Maurice's current age is 7.

We can check our answer:
Ron's current age is 43.
Maurice's current age is 7.
In five years, Ron will be 43 + 5 = 48.
In five years, Maurice will be 7 + 5 = 12.
Is 48 = 4 * 12?
48 = 48. Yes, the equation holds.

</reasoning>
<answer>
7
</answer> 
Extracted:
7
-------------------- Question:
Cathy and Chris got summer jobs at the cake shop and were supposed to work 20 hours per week each for 2 months. During that time, Chris got sick for one week and Cathy took her shifts. If there are four weeks in a month, how many hours did Cathy work during the 2 months? 
Answer:
180 
Response:
<reasoning>
First, let's calculate the total number of weeks in two months. Since there are four weeks in a month, two months have 2 * 4 = 8 weeks.
Each person was supposed to work 20 hours per week, so without any issues, they would each work 20 * 8 = 160 ho


-------------------- Question:
A family of 12 monkeys collected 10 piles of bananas. 6 piles had 9 hands, with each hand having 14 bananas, while the remaining piles had 12 hands, with each hand having 9 bananas. How many bananas would each monkey get if they divide the bananas equally amongst themselves? 
Answer:
99 
Response:
<reasoning>
Let's first calculate the number of bananas in the 6 piles with 9 hands each. Each hand has 14 bananas, so each pile contains 9 hands * 14 bananas/hand = 126 bananas. Since there are 6 piles, the total number of bananas in these piles is 6 * 126 = 756 bananas.

Next, let's calculate the number of bananas in the remaining piles. There are 10 piles total, and 6 piles have 9 hands, so there are 10 - 6 = 4 piles with 12 hands. Each hand has 9 bananas, so each pile contains 4 piles * 12 hands * 9 bananas/hand = 4 * 108 = 432 bananas.

Now, let's find the total number of bananas collected: 756 + 432 = 1188 bananas.

Finally, we need to divide the total num



TrainOutput(global_step=250, training_loss=3.5115013348217874e-05, metrics={'train_runtime': 24101.7629, 'train_samples_per_second': 0.021, 'train_steps_per_second': 0.01, 'total_flos': 0.0, 'train_loss': 3.5115013348217874e-05})

In [None]:
trainer.push_to_hub("burtenshaw/gemma3-4b-thinking")



CommitInfo(commit_url='https://huggingface.co/burtenshaw/trainer_output/commit/b12f62b31c8e43e186ab6cebcb3a4b360c9daea0', commit_message='burtenshaw/gemma3-4b-thinking', commit_description='', oid='b12f62b31c8e43e186ab6cebcb3a4b360c9daea0', pr_url=None, repo_url=RepoUrl('https://huggingface.co/burtenshaw/trainer_output', endpoint='https://huggingface.co', repo_type='model', repo_id='burtenshaw/trainer_output'), pr_revision=None, pr_num=None)

In [None]:
from transformers import pipeline

question = "The school principal decided that she wanted every class to have an equal number of boys and girls in each first-grade classroom. There are 4 classrooms. There are 56 boys and 44 girls. How many total students are in each classroom?"
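generator = pipeline("text-generation", model=trainer.model, tokenizer=processor.tokenizer)
# Render the chat template to a string (named `prompt` to avoid shadowing the built-in `input`)
prompt = processor.apply_chat_template([{"role": "user", "content": question}], tokenize=False)
# To prefill the reasoning tag, assign the result: prompt = prompt + "<reasoning>"
output = generator(prompt, max_new_tokens=1024)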

Device set to use cuda:0
The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForCausalLM', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCaus

In [None]:
output

[{'generated_text': '<bos><start_of_turn>user\nThe school principal decided that she wanted every class to have an equal number of boys and girls in each first-grade classroom. There are 4 classrooms. There are 56 boys and 44 girls. How many total students are in each classroom?<end_of_turn>\n* * *\n**Solution**\n\n1.  **Find the total number of students:** 56 boys + 44 girls = 100 students\n2.  **Divide the total students by the number of classrooms:** 100 students / 4 classrooms = 25 students per classroom\n\n**Answer:** There are 25 students in each classroom.'}]
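
The generated text above contains no `<reasoning>` or `<answer>` tags. One difference worth checking is that the inference prompt has no system message, while every training prompt started with `SYSTEM_PROMPT`. Here is a minimal sketch of building the same message layout for inference. `build_messages` is a hypothetical helper and the question is made up; whether this changes the output would need to be tested.

```python
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def build_messages(question: str, system_prompt: str = SYSTEM_PROMPT) -> list[dict]:
    """Hypothetical helper: same system + user layout as get_gsm8k_questions."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

question = "A class has 12 boys and 13 girls. How many students are there in total?"
messages = build_messages(question)

# With the notebook's objects loaded, the list could then be rendered with
# processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
for message in messages:
    print(message["role"])
```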

# Next Steps!

Check out [The Reasoning Course](https://huggingface.co/reasoning-course) for more info on GRPO.

In the coming days we'll release a version of this notebook with Unsloth!

<a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>