<a href="https://colab.research.google.com/github/chiffonng/mnemonic-gen/blob/sft-re/notebooks/gemma3-grpo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Finetuning Gemma-3 with GRPO and LoRA: Enhancing AI Reasoning Capabilities

# Installation

In [None]:
import sys
import os


# Environment detection functions
def is_colab():
    """Return True when any environment variable name contains ``COLAB_``.

    Each variable name is checked on its own. Joining all names into one
    string (the previous approach) could produce a false positive when one
    name ends with ``COLAB`` and the next one starts with ``_``.
    """
    return any("COLAB_" in name for name in os.environ)


def is_kaggle():
    """Return True when the ``KAGGLE_URL_BASE`` environment variable is set."""
    kaggle_url_base = os.environ.get("KAGGLE_URL_BASE")
    return kaggle_url_base is not None


print(is_colab())  # True only when a COLAB_* environment variable is present
print(is_kaggle())  # True only when KAGGLE_URL_BASE is set

In [None]:
%%capture

# Install unsloth, vllm and transformers with the pip flags suited to the
# detected platform. All output of this cell is hidden by %%capture.
if not is_colab() and not is_kaggle():
    # Neither Colab nor Kaggle: plain install, dependencies resolved by pip.
    !pip install unsloth vllm "transformers>=4.50.0"
elif is_kaggle():
    # Kaggle: install unsloth with its "kaggle-new" extra.
    !pip install unsloth[kaggle-new] vllm "transformers>=4.50.0"
else:
    # Colab: --no-deps installs only the named packages; their dependencies
    # are installed explicitly by the "Colab Extra Install" cell.
    !pip install --no-deps unsloth vllm "transformers>=4.50.0"

In [None]:
# @title Colab Extra Install { display-mode: "form" }
%%capture
# NOTE(review): IPython normally requires a cell magic such as %%capture to be
# the very first line of a cell — confirm this cell runs with the @title line
# placed above it.
import os

if not is_colab():
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Everything below is for Colab ONLY. Outside Colab the plain
    # `pip install unsloth vllm` in the branch above is used instead.
    # Skip restarting message in Colab
    import sys, re, requests

    # Drop every already-imported module whose name contains "PIL" or
    # "google" from sys.modules (iterating over a snapshot of the keys).
    modules = list(sys.modules.keys())
    for x in modules:
        sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    # Download vLLM's common requirements file from its main branch ...
    f = requests.get(
        "https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt"
    ).content
    # ... and write it out with the transformers / numpy / xformers entries
    # removed, so the install below does not touch those packages.
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

## Utility functions

In [None]:
import os
from huggingface_hub import login

# Authentication handling based on environment.
# Each branch sets HF_TOKEN (required) and WB_API_KEY (optional: Weights &
# Biases logging is simply skipped later when it is missing).
if is_kaggle():
    # For Kaggle, use Kaggle Secrets
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
    WB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
    _token_source = "Kaggle secrets"
elif is_colab():
    # For Colab, use the notebook's user secrets
    from google.colab import userdata

    HF_TOKEN = userdata.get("HF_TOKEN")
    WB_API_KEY = userdata.get("WANDB_API_KEY")
    _token_source = "Google Colab userdata"
else:
    # Elsewhere, read from the process environment (optionally via a .env file)
    from dotenv import load_dotenv

    load_dotenv()
    # os.getenv returns None for a missing variable instead of raising, so the
    # explicit check below (not a try/except KeyError) is what catches it.
    HF_TOKEN = os.getenv("HF_TOKEN")
    WB_API_KEY = os.getenv("WANDB_API_KEY")
    _token_source = "environment variables"

# The Hugging Face token is mandatory in every environment.
if not HF_TOKEN:
    raise KeyError(f"HF_TOKEN not found in {_token_source}.")

# Login to Hugging Face; outside Kaggle the token is also added to the git
# credential store.
login_options = {} if is_kaggle() else {"add_to_git_credential": True}
login(token=HF_TOKEN, **login_options)

# Initialize wandb if using
import wandb

# Settings for the Weights & Biases run (only used when an API key is present).
wandb_run_settings = {
    "project": "ft-gemma-3-4b-it-en-mnemonics-reason",
    "job_type": "training",
    "anonymous": "allow",
}

if not WB_API_KEY:
    # No API key: skip Weights & Biases entirely.
    use_wandb = False
else:
    wandb.login(key=WB_API_KEY)
    use_wandb = True
    run = wandb.init(**wandb_run_settings)

print(use_wandb)

### Load models and wrap with LoRA adapters

In [None]:
from unsloth import FastModel, is_bfloat16_supported
import torch

# Sequence budget and LoRA rank shared by the loading and adapter steps.
lora_rank = 16  # Larger rank can improve performance but may slow down training.
max_seq_length = 2048  # Increase if you need longer reasoning traces.

# Options for loading the base checkpoint.
base_model_options = dict(
    model_name="unsloth/gemma-3-4b-it",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # 4 bit quantization to reduce memory
    load_in_8bit=False,  # A bit more accurate, uses 2x memory
    full_finetuning=False,
)

# Options for the LoRA adapters wrapped around the base model.
lora_options = dict(
    finetune_vision_layers=False,  # Text-only finetuning
    finetune_language_layers=True,
    finetune_attention_modules=True,  # Attention good for GRPO
    finetune_mlp_modules=True,
    r=lora_rank,  # Larger = higher accuracy, but might overfit
    lora_alpha=2 * lora_rank,
    lora_dropout=0,
    bias="none",
    random_state=42,
    use_rslora=True,  # Rank stabilized LoRA
)

model, tokenizer = FastModel.from_pretrained(**base_model_options)
model = FastModel.get_peft_model(model, **lora_options)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 03-19 15:51:40 [__init__.py:256] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.17: Fast Gemma3 patching. Transformers: 4.50.0.dev0. vLLM: 0.8.0.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

### Data Prep
<a name="Data"></a>

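We're using the `chiffonng/en-vocab-mnemonics-grpo` dataset for training and `chiffonng/en-vocab-mnemonics-test` for testing. (Some outputs saved below still appear to come from an earlier run on OpenAI's GSM8K dataset.)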

In [None]:
from datasets import load_dataset

# Hugging Face Hub repositories holding the training and test data.
train_repo_id, test_repo_id = (
    "chiffonng/en-vocab-mnemonics-grpo",
    "chiffonng/en-vocab-mnemonics-test",
)

train_dataset = load_dataset(path=train_repo_id, split="train")
test_dataset = load_dataset(path=test_repo_id, split="test")

# Display the training split summary as the cell output.
train_dataset

README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

Let's look at the first row:

In [2]:
import textwrap


def pretty_format(text_string, width=80):
    """
    Wrap every line of a text to a maximum width, keeping its indentation.

    Each input line is wrapped independently. The leading whitespace of a
    line is reused for all of its wrapped continuation lines, so an indented
    line stays indented after wrapping. Blank (whitespace-only) lines are
    emitted as empty lines so paragraph breaks survive.

    Args:
        text_string: A string that may contain newlines and quotes
        width: Maximum length of each output line (default 80)

    Returns:
        Formatted string with consistent indentation
    """
    formatted_lines = []
    for line in text_string.split("\n"):
        content = line.lstrip()

        # Keep blank lines as paragraph separators
        if not content.strip():
            formatted_lines.append("")
            continue

        # Leading whitespace of this line; tabs are expanded so the indent is
        # counted in columns, as textwrap does for the text itself.
        indent = line[: len(line) - len(content)].expandtabs()

        # Apply the same indent to the first and all continuation lines
        wrapped = textwrap.fill(
            content,
            width=width,
            initial_indent=indent,
            subsequent_indent=indent,
        )
        formatted_lines.append(wrapped)

    # Join the formatted lines back together
    return "\n".join(formatted_lines)


# Example usage: the second line is longer than 80 characters and gets
# wrapped; the empty line is kept as a paragraph break.
sample_text = """This is a long line of text that will be wrapped to make it more readable.
Here's another line with a quoted string like \"this is quoted\" which should be preserved.

This paragraph comes after an empty line above."""

print(pretty_format(sample_text))

This is a long line of text that will be wrapped to make it more readable.
Here's another line with a quoted string like "this is quoted" which should be
preserved.

This paragraph comes after an empty line above.


In [None]:
# Print the wrapped reasoning trace and solution of the first training row.
# NOTE(review): the saved Dataset repr lists only 'question' and 'answer'
# columns — confirm 'reasoning' and 'solution' exist in the current dataset.
print(pretty_format(train_dataset[0]["reasoning"]))
print(pretty_format(train_dataset[0]["solution"]))

In [None]:
# Prompt of the first training row.
# NOTE(review): confirm the 'prompts' column exists in the current dataset.
train_dataset[0]["prompts"]

'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?'

In [None]:
# Reference completion of the first training row.
# NOTE(review): confirm the 'completions' column exists in the current dataset.
train_dataset[0]["completions"]

'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'

## Reward functions

In [None]:
import re

# Tags that delimit the reasoning trace and the final solution in a response.
reasoning_start = "<think>"
reasoning_end = "</think>"
solution_start = "<solution>"
solution_end = "</solution>"


def match_format_exactly(completions, **kwargs):
    """Reward completions that consist of exactly one think/solution layout.

    A response scores 3.0 when the WHOLE response (ignoring leading and
    trailing whitespace) is a reasoning section followed by a solution
    section; otherwise it scores 0.0.

    The pattern is applied with ``fullmatch`` and without ``re.MULTILINE``.
    With ``MULTILINE``, ``^``/``$`` anchor to any line, so a response with
    unrelated text on the lines before or after the tags was still counted
    as an exact match.

    Args:
        completions: list of completions; each is a list whose first item is
            a dict with a "content" string.
        **kwargs: ignored.

    Returns:
        List of float scores, one per completion.
    """
    exact_format = re.compile(
        rf"\s*{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?"
        rf"{re.escape(solution_start)}(.+?){re.escape(solution_end)}\s*",
        flags=re.DOTALL,  # let "." span the newlines inside each section
    )
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        is_exact = exact_format.fullmatch(response) is not None
        scores.append(3.0 if is_exact else 0.0)
    return scores

If it fails, we want to reward the model if it at least follows the format partially, by counting each symbol:

In [None]:
def match_format_approximately(completions, **kwargs):
    """Give partial credit for each format tag that appears exactly once.

    Every one of the four tags contributes +0.5 when it occurs exactly once
    in the response and -0.5 otherwise (missing or repeated), so a response
    scores between -2.0 and 2.0.
    """
    format_tags = (reasoning_start, reasoning_end, solution_start, solution_end)

    def tag_score(response):
        total = 0
        for tag in format_tags:
            # Exactly one occurrence is rewarded; zero or several is penalized
            total += 0.5 if response.count(tag) == 1 else -0.5
        return total

    return [tag_score(completion[0]["content"]) for completion in completions]

## Feature extraction

In [None]:
import re

# Feature labels recognised in a response, in lookup priority order.
LINGUISTIC_FEATURES = [
    "phonetics",
    "orthography",
    "etymology",
    "morphology",
    "semantics",
    "custom",
]


# Helper functions to extract specific parts from the model's response
def extract_linguistic_feature(text):
    """Return the first known feature named on the ``linguistic_feature:`` line.

    The lookup is case-insensitive. Only the remainder of the line that
    follows the first ``linguistic_feature:`` marker is inspected, and the
    features are tried in ``LINGUISTIC_FEATURES`` order. Returns ``None``
    when the marker is absent or no known feature appears on that line.
    """
    marker = "linguistic_feature:"
    lowered = text.lower()
    if marker not in lowered:
        return None

    after_marker = lowered.split(marker)[1]
    feature_line = after_marker.split("\n")[0].strip()
    return next(
        (name for name in LINGUISTIC_FEATURES if name in feature_line),
        None,
    )


def extract_mnemonic(text):
    """Extract the mnemonic from the response.

    Returns the lower-cased, stripped remainder of the line that follows the
    first ``mnemonic:`` marker (matched case-insensitively). Returns an empty
    string when the marker is absent, so callers can always treat the result
    as a string.
    """
    text = text.lower()
    if "mnemonic:" in text:
        return text.split("mnemonic:")[1].split("\n")[0].strip()
    return ""


def extract_example(text):
    """Return the lower-cased text following the first ``example:`` marker.

    Only the rest of that line is returned, stripped of surrounding
    whitespace. An empty string is returned when the marker is absent.
    """
    marker = "example:"
    lowered = text.lower()
    if marker not in lowered:
        return ""

    after_marker = lowered.split(marker)[1]
    first_line = after_marker.split("\n")[0]
    return first_line.strip()


# Extract reasoning from text between reasoning_start and reasoning_end
def extract_reasoning(text):
    """Extract the reasoning trace from the response.

    Returns the stripped text between the first ``reasoning_start`` tag and
    the next ``reasoning_end`` tag. Returns an empty string when either tag
    is missing.
    """
    _, start_tag, remainder = text.partition(reasoning_start)
    if not start_tag:
        return ""
    body, end_tag, _ = remainder.partition(reasoning_end)
    return body.strip() if end_tag else ""


# Extract solution from text between solution_start and solution_end
def extract_solution(text):
    """Extract the solution from the response.

    Returns the stripped text between the first ``solution_start`` tag and
    the next ``solution_end`` tag. Returns an empty string when either tag
    is missing.
    """
    _, start_tag, remainder = text.partition(solution_start)
    if not start_tag:
        return ""
    body, end_tag, _ = remainder.partition(solution_end)
    return body.strip() if end_tag else ""

## Reward functions for GRPO training

In [None]:
# 1. Format reward: Check if response follows the required format
def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion follows the expected format.

    Scoring per response (markers are matched case-insensitively):
      * ``linguistic_feature:`` present -> +0.3, absent -> -0.2
      * ``mnemonic:`` present -> +0.4, absent -> -0.3
      * ``example:`` present -> +0.1 (no penalty when absent)

    The feature marker is ``linguistic_feature:``, the same marker that
    ``extract_linguistic_feature`` parses. The previous check looked for
    ``linguistic:``, which is not a substring of ``linguistic_feature:`` and
    therefore penalized correctly formatted responses.

    Args:
        completions: list of completions; each is a list whose first item is
            a dict with a "content" string.
        **kwargs: ignored.

    Returns:
        List of float scores, one per completion.
    """
    rewards = []
    for completion in completions:
        lowered = completion[0]["content"].lower()
        score = 0.0
        # Linguistic feature section
        score += 0.3 if "linguistic_feature:" in lowered else -0.2
        # Mnemonic section
        score += 0.4 if "mnemonic:" in lowered else -0.3
        # Example section is optional: bonus only
        if "example:" in lowered:
            score += 0.1
        rewards.append(score)

    return rewards


# 2. Linguistic feature reward: Check which linguistic features are mentioned
def contains_linguistic_feature(completions, **kwargs):
    """Score each response by the linguistic features it mentions.

    A primary bonus is taken from the first matching rule: "etymology"
    (only when "greek" is absent) or "morphology" give 0.5; "phonetics",
    "orthography" or "semantics" give 0.4; "custom" gives 0.3. Then 0.2 is
    added for every entry of LINGUISTIC_FEATURES found in the response.
    The total is capped at 1.5.
    """
    # Rules tried in order when the etymology rule does not apply
    fallback_bonuses = (
        ("morphology", 0.5),
        ("phonetics", 0.4),
        ("orthography", 0.4),
        ("semantics", 0.4),
        ("custom", 0.3),
    )

    rewards = []
    for completion in completions:
        lowered = completion[0]["content"].lower()

        if "etymology" in lowered and "greek" not in lowered:
            score = 0.5
        else:
            score = next(
                (bonus for keyword, bonus in fallback_bonuses if keyword in lowered),
                0.0,
            )

        # Smaller reward for each distinct feature that is mentioned
        for feature in LINGUISTIC_FEATURES:
            if feature in lowered:
                score += 0.2

        rewards.append(min(score, 1.5))  # Cap reward at 1.5

    return rewards


# 3. Check if the mnemonic contains the term from "term" column
def contains_term(completions, term, **kwargs):
    """Reward function that checks if the term appears in the mnemonic.

    The previous body called ``contains_term`` itself with a mnemonic string
    and a single term, which cannot work; it is replaced by a direct,
    case-insensitive substring test.

    Args:
        completions: list of completions; each is a list whose first item is
            a dict with a "content" string.
        term: sequence with one target term per completion (indexed in
            parallel with ``completions``).
        **kwargs: ignored.

    Returns:
        List with 1.0 where the term occurs in the extracted mnemonic and
        0.0 otherwise (including when no mnemonic or an empty term is found).
    """
    responses = [completion[0]["content"] for completion in completions]

    rewards = []
    for i, response in enumerate(responses):
        # extract_mnemonic lower-cases its result; guard against a missing one
        mnemonic_part = extract_mnemonic(response) or ""
        target = str(term[i]).strip().lower()
        if target and target in mnemonic_part:
            rewards.append(1.0)
        else:
            rewards.append(0.0)

    return rewards


# 4. Reasoning quality reward: Check for reasoning about the term
def contains_reasoning_indicators(completions, **kwargs):
    """Reward function that evaluates reasoning quality.

    Adds 0.1 for every reasoning indicator found in the lower-cased response,
    capped at 1.0.

    Connectives are matched as whole words, so "as" no longer fires inside
    "was" and "so" no longer fires inside "also". Content stems are matched
    as word prefixes, so "reason" still covers "reasoning" and "connect"
    covers "connects".

    Args:
        completions: list of completions; each is a list whose first item is
            a dict with a "content" string.
        **kwargs: ignored.

    Returns:
        List of float scores in [0.0, 1.0], one per completion.
    """
    # Whole-word connectives.
    # NOTE(review): the original list had the single entry "thinkdue to",
    # which looks like a typo and could practically never match. It is
    # replaced by "due to"; "think" is not added on its own because it would
    # also match the <think> tag — confirm this is the intent.
    connectives = (
        "because",
        "since",
        "as",
        "therefore",
        "thus",
        "hence",
        "due to",
        "so",
    )
    # Stems matched at the start of a word
    stems = ("reason", "analyze", "connect", "relation")

    indicator_patterns = [
        re.compile(rf"\b{re.escape(word)}\b") for word in connectives
    ] + [re.compile(rf"\b{re.escape(stem)}") for stem in stems]

    rewards = []
    for completion in completions:
        response_lower = completion[0]["content"].lower()

        score = 0.0
        for pattern in indicator_patterns:
            if pattern.search(response_lower):
                score += 0.1

        rewards.append(min(score, 1.0))  # Cap reward at 1.0

    return rewards


# 5. Check for vivid imagery and associations
def contains_associations(completions, **kwargs):
    """Reward function that evaluates word/subword associations.

    Adds 0.1 for every association pattern found in the extracted mnemonic.

    Fixes over the previous version:
      * "+" is escaped (a bare "+" is not a valid regular expression);
      * a missing comma had fused "picture" with the next pattern;
      * a response without a mnemonic scores 0.0 instead of raising.

    Args:
        completions: list of completions; each is a list whose first item is
            a dict with a "content" string.
        **kwargs: ignored.

    Returns:
        List of float scores, one per completion.
    """
    mnemonics = [
        extract_mnemonic(completion[0]["content"]) for completion in completions
    ]

    # NOTE(review): the original "think(\w+)of" / "break(\w+)down" could not
    # match phrases containing spaces; they are read here as "think of" and
    # "break (it) down" — confirm that is the intent.
    association_regex = [
        r"=",
        r"\+",
        "related",
        "similar",
        "imagine",
        "picture",
        r"think\s+of",
        r"break\s+(?:\w+\s+)?down",
        "associate",
    ]

    rewards = []
    for mnemonic in mnemonics:
        # The more association patterns match, the higher the reward
        score = 0.0
        for regex in association_regex:
            if re.search(regex, mnemonic or ""):
                score += 0.1
        rewards.append(score)

    return rewards


# Reward functions passed to the trainer.
# NOTE(review): only format_reward_func is registered; add the other reward
# functions defined in this notebook here if they should be used.
reward_funcs = [
    format_reward_func,
]
# One weight per registered reward function, so both lists always have the
# same length (previously six weights were given for a single function).
reward_weights = [1.0] * len(reward_funcs)

<a name="Train"></a>
### Train the model

Now set up GRPO Trainer and all configurations!

In [None]:
max_prompt_length = 1500
max_completion_length = 1500  # reasoning + solution
# NOTE(review): prompt + completion budgets add up to 3000 tokens, more than
# the max_seq_length of 2048 the model was loaded with — confirm the values.

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",  # was "adamw_paged8bit", not a valid optimizer name
    logging_steps=10,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_generations=4,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    num_train_epochs=3,
    max_steps=50,  # a positive max_steps overrides num_train_epochs
    save_steps=50,
    max_grad_norm=0.1,
    # Per-function reward weights are a GRPOConfig option, not a trainer argument
    reward_weights=reward_weights,
    # Only report to Weights & Biases when the login cell enabled it
    report_to="wandb" if use_wandb else "none",
    output_dir="outputs",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=train_dataset,  # the keyword is `train_dataset`, not `dataset`
)
trainer.train()

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!

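You might have to wait 150 to 200 steps for any action, and you'll probably get 0 reward for the first 100 steps. Note that the configuration above stops after `max_steps=50`, so increase `max_steps` if you want to train long enough to see the reward improve. Please be patient!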

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |


<a name="Inference"></a>
### Inference
Now let's try the model we just trained!

In [None]:
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer

word = "ephemeral"

# NOTE(review): `system_prompt` is not defined in any cell visible in this
# notebook — define it (presumably the same system message used for the
# training prompts) before running this cell.
messages = [
    {"role": "system", "content": system_prompt},
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": f"Create a memory aid so that I could learn the word '{word}'",
            }
        ],
    },
]

tokenizer = get_chat_template(
    tokenizer,
    chat_template="gemma-3",
)

# Render the conversation to a prompt STRING. tokenize=False is explicit so
# the result is text regardless of the object's default; it is tokenized in
# the next step.
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # Must add for generation
    tokenize=False,
)

# Pass the prompt by keyword so it is always taken as text, whether
# `tokenizer` is a plain tokenizer or a multimodal processor.
model_inputs = tokenizer(text=[text], return_tensors="pt").to("cuda")

_ = model.generate(
    **model_inputs,
    max_new_tokens=1024,
    # Recommended Gemma-3 settings!
    temperature=1.0,
    top_p=0.95,
    top_k=64,
    streamer=TextStreamer(tokenizer, skip_prompt=True),
)

<start_working_out>
The square root of 101 is approximately 10.0498756.
<SOLUTION>
10.0498756<end_of_turn>


<a name="Save"></a>
### Saving, loading finetuned models


### Saving to float16 for VLLM

We also support saving to `float16` directly for deployment! The cell below saves the merged model in the folder `gemma-3-4b-it-mnemonics` and then pushes it to the Hugging Face Hub under the same name.

In [None]:
# Name used both for the local save and for the Hub upload.
merged_model_name = "gemma-3-4b-it-mnemonics"
merge_options = {"save_method": "merged_16bit"}

# Save the merged 16-bit model locally, then push the same export to the Hub.
model.save_pretrained_merged(merged_model_name, tokenizer, **merge_options)
model.push_to_hub_merged(
    merged_model_name, tokenizer, token=HF_TOKEN, **merge_options
)

### GGUF / llama.cpp Conversion
To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!

In [None]:
# Upload GGUF exports of the model to the Hub repository below, once per
# listed quantization method (q4_k_m and f16).
# Reference: https://docs.unsloth.ai/basics/running-and-saving-models/saving-to-gguf
model.push_to_hub_gguf(
    "chiffonng/gemma-3-4b-it-vmm",
    tokenizer,
    quantization_method=["q4_k_m", "f16"],
    token=HF_TOKEN,
)

# Ollama

In [None]:
# Export the model locally in GGUF format (f16) under the name
# "gemma-3-4b-it-vmm" — presumably the input for the Ollama steps below.
model.save_pretrained_gguf("gemma-3-4b-it-vmm", tokenizer, quantization_method="f16")

In [None]:
# Download Ollama's install script and run it with sh.
# NOTE(review): this pipes a remote script straight into a shell — inspect or
# pin the script if this notebook runs outside a throwaway environment.
!curl -fsSL https://ollama.com/install.sh | sh

In [None]:
import subprocess
import time

# Launch the Ollama server as a background process (the handle is not kept),
# then pause briefly so it has time to start.
OLLAMA_STARTUP_SECONDS = 3
subprocess.Popen(["ollama", "serve"])
time.sleep(OLLAMA_STARTUP_SECONDS)

In [None]:
# Print the Modelfile text stored on the tokenizer's private
# `_ollama_modelfile` attribute (presumably set by Unsloth — confirm).
print(tokenizer._ollama_modelfile)

In [None]:
# !ollama create gemma3_4b_it_vmm -f ./model/Modelfile