# Let's make Gemma 3 1b think! 🍎

This is another notebook to make Gemma 3 think. This time focusing on the smallest 1b variant. You should be able to download this notebook for Mac silicone.

![logo](https://storage.googleapis.com/gweb-uniblog-publish-prod/images/Gemma3_KeywordBlog_RD3_V01b.width-1200.format-webp.webp)

👩‍🎓 If you want to learn more about making models think and reason, check out [The Reasoning Course](https://huggingface.co/reasoning-course)

### Installation

In [1]:
# install this release tag of transformers
!pip install -qqq git+https://github.com/huggingface/trl.git@main \
                  bitsandbytes

!pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

!pip install git+https://github.com/huggingface/peft.git

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
langchain-cohere 0.1.9 requires langchain-core<0.3,>=0.2.2, but you have langchain-core 0.3.28 which is incompatible.
ragas 0.1.20 requires langchain-core<0.3, but you have langchain-core 0.3.28 which is incompatible.
embedchain 0.1.121 requires langchain<=0.3,>0.2, but you have langchain 0.3.13 which is incompatible.
embedchain 0.1.121 requires langchain-openai<0.2.0,>=0.1.7, but you have langchain-openai 0.2.13 which is incompatible.
embedchain 0.1.121 requires tiktoken<0.8.0,>=0.7.0, but you have tiktoken 0.8.0 which is incompatible.
crewai 0.51.1 requires langchain<=0.3,>0.2, but you have langchain 0.3.13 which is incompatible.
crewai 0.51.1 requires regex<2024.0.0,>=2023.12.25, but you have regex 2024.11.6 which is incompatible.
aider-chat 0.70.0 requires fsspec==2024.10.0, but you have fsspec 2024.6.1 w

In [2]:
from huggingface_hub import notebook_login
notebook_login()

  from .autonotebook import tqdm as notebook_tqdm


ImportError: The `notebook_login` function can only be used in a notebook (Jupyter or Colab) and you need the `ipywidgets` module: `pip install ipywidgets`.

In [None]:
import torch
from transformers import Gemma3ForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

torch_dtype = torch.bfloat16

model = Gemma3ForCausalLM.from_pretrained(
    pretrained_model_name_or_path="google/gemma-3-1b-it",
    device_map="auto" if not torch.mps.is_available() else torch.device("mps"),  # switch to mac silicone
    attn_implementation="sdpa",
    torch_dtype=torch_dtype
)

# Load LoRA
peft_config = LoraConfig(
    lora_alpha=4,
    lora_dropout=0.05,
    r=4,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],  # make sure to save the lm_head and embed_tokens as you train the special tokens
)

model = get_peft_model(model, peft_config)
print(model.print_trainable_parameters())

processor = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

trainable params: 607,241,216 || all params: 1,607,127,168 || trainable%: 37.7843
None


### Process data to create reasoning chains

Borrowing from [Will Brown's gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) we'll make reasoning chains from GSM8k.

In [None]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>"""

XML_COT_FORMAT = """<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward Functions

Now, let's define reward functions. These are the functions we'll need to setup reward chains.

| Reward Function | Purpose |
|---|---|
| `correctness_reward_func` | Rewards the model when its answer matches the correct answer |
| `int_reward_func` | Rewards the model for providing a numeric answer |
| `strict_format_reward_func` and `soft_format_reward_func` | Reward the model for following the specified format |
| `xmlcount_reward_func` | Rewards proper XML tag usage and penalizes extra content after the closing tags |

In [None]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

# Train with GRPOTrainer

Now we'll confgure training with the `GRPOConfig`

In [None]:
from trl import GRPOConfig, GRPOTrainer
from transformers import GenerationConfig

max_prompt_length = 1024
max_seq_length = 2048


training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "constant",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,
    num_generations = 2,
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    num_train_epochs = 1,
    max_steps = 5,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none",
    cache_implementation="hybrid"
)

# Start trainer

In [None]:
from trl.trainer.utils import pad
import torch

trainer = GRPOTrainer(
    model = model,
    processing_class = processor,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)

trainer.train()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


-------------------- Question:
Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers? 
Answer:
100 
Response:
<reasoning>
The problem asks for the minimum grade Ahmed needs to achieve to beat Emily. We are given that the final assignment is worth the same amount as all the other assignments. This means that the difference between Ahmed’s and Emily’s final grades is irrelevant to the comparison. We only need to consider the overall grade difference. We are given that Ahmed has a 91, and Emily has a 92. The final assignment is worth the same amount as all other assignments. Therefore, the difference between their final grades is 92 - 91 = 1. This means that Ahmed needs to get 

Step,Training Loss
1,0.0
2,0.0408
3,0.0003
4,0.0
5,0.0


-------------------- Question:
Marie has 98 unread messages on her phone. She decides to clear them by reading 20 messages a day. However, she also gets 6 new messages a day. How many days will it take her to read all her unread messages? 
Answer:
7 
Response:
<reasoning>
Marie needs to clear 98 messages. She reads 20 messages per day and gets 6 new messages per day. This means she gains 20 - 6 = 14 messages per day. To clear 98 messages, it will take 98 / 14 = 6.857... days. Since she can't read a fraction of a day, we need to round up to the nearest whole number.

</reasoning>
<answer>
6 days</answer>یحسن!iliحسن!(Good!)
othi
ieri c'è stata una precedente richiesta.














 ...Since the prompt is the same and there is no change in the instructions, I will simply state: "It is the same."?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?"?

TrainOutput(global_step=5, training_loss=0.008234577474649996, metrics={'train_runtime': 2197.8061, 'train_samples_per_second': 0.036, 'train_steps_per_second': 0.002, 'total_flos': 0.0, 'train_loss': 0.008234577474649996})

In [None]:
trainer.push_to_hub("burtenshaw/gemma3-1b-thinking")



tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/1.51G [00:00<?, ?B/s]

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/burtenshaw/trainer_output/commit/b12f62b31c8e43e186ab6cebcb3a4b360c9daea0', commit_message='burtenshaw/gemma3-4b-thinking', commit_description='', oid='b12f62b31c8e43e186ab6cebcb3a4b360c9daea0', pr_url=None, repo_url=RepoUrl('https://huggingface.co/burtenshaw/trainer_output', endpoint='https://huggingface.co', repo_type='model', repo_id='burtenshaw/trainer_output'), pr_revision=None, pr_num=None)

In [None]:
from transformers import pipeline

question = "The school principal decided that she wanted every class to have an equal number of boys and girls in each first-grade classroom. There are 4 classrooms. There are 56 boys and 44 girls. How many total students are in each classroom?"
generator = pipeline("text-generation", model=trainer.model, tokenizer=processor.tokenizer)
input = processor.apply_chat_template([{"role": "user", "content": question}])
input + "<reasoning>"
output = generator(input, max_new_tokens=1024)

Device set to use cuda:0
The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForCausalLM', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCaus

In [None]:
output

[{'generated_text': '<bos><start_of_turn>user\nThe school principal decided that she wanted every class to have an equal number of boys and girls in each first-grade classroom. There are 4 classrooms. There are 56 boys and 44 girls. How many total students are in each classroom?<end_of_turn>\n* * *\n**Solution**\n\n1.  **Find the total number of students:** 56 boys + 44 girls = 100 students\n2.  **Divide the total students by the number of classrooms:** 100 students / 4 classrooms = 25 students per classroom\n\n**Answer:** There are 25 students in each classroom.'}]

# Next Steps!

Checkout the [The Reasoing Course](https://huggingface.co/reasoning-course) for more info on GRPO.

In the coming days we'll release a version of this notebook with Unsloth!

<a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>