# References

https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3N_(4B)-Conversational.ipynb

In [1]:
import torch
from unsloth import FastLanguageModel, FastModel
import os
from datasets import load_dataset, Dataset, IterableDataset
from trl import GRPOConfig, GRPOTrainer
from pprint import pprint
import re
import wandb
from vllm import SamplingParams
from transformers import TextStreamer
from datetime import datetime
import sys
import logging

# Check GPU
if torch.cuda.is_available():
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU detected.")

# Environment variables for torch
os.environ["TORCH_LOGS"] = "recompiles"
os.environ['TORCHDYNAMO_CACHE_SIZE_LIMIT'] = '999999999'
import torch._dynamo 
torch._dynamo.config.cache_size_limit = 64

TRAIN = False

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 07-24 21:19:31 [__init__.py:244] Automatically detected platform cuda.
GPU detected: NVIDIA GeForce RTX 3060


In [None]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_dir = "outputs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "evaluation.log")
file_handler = logging.FileHandler(log_file, mode='w')  # override the log file
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
# Remove all existing file handlers to ensure override
for h in logger.handlers[:]:
    if isinstance(h, logging.FileHandler):
        logger.removeHandler(h)
logger.addHandler(file_handler)

logger.info("Hello world")

In [3]:

reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

system_prompt = \
    f"""You are given a problem, think about the problem and provide your workout. 
    Place it between {reasoning_start} and {reasoning_end}. Then provide your solution
    between {solution_start}{solution_end}"""

print(system_prompt)


You are given a problem, think about the problem and provide your workout. 
    Place it between <start_working_out> and <end_working_out>. Then provide your solution
    between <SOLUTION></SOLUTION>


# Prepare dataset

In [4]:
from datasets import load_dataset

dataset = load_dataset('openai/gsm8k', 'main', split='train')
print(dataset)
dataset[0]


Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})


{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}

In [5]:
def extract_hash_answer(text):
    if "####" not in text: return None
    return text.split("####")[1].strip()

dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["question"]},
    ],
    "answer": extract_hash_answer(x["answer"]),
})

print(dataset)
pprint(dataset[0])
assert int(dataset[0]['answer']), "answer not a number format"



Dataset({
    features: ['question', 'answer', 'prompt'],
    num_rows: 7473
})
{'answer': '72',
 'prompt': [{'content': 'You are given a problem, think about the problem and '
                        'provide your workout. \n'
                        '    Place it between <start_working_out> and '
                        '<end_working_out>. Then provide your solution\n'
                        '    between <SOLUTION></SOLUTION>',
             'role': 'system'},
            {'content': 'Natalia sold clips to 48 of her friends in April, and '
                        'then she sold half as many clips in May. How many '
                        'clips did Natalia sell altogether in April and May?',
             'role': 'user'}],
 'question': 'Natalia sold clips to 48 of her friends in April, and then she '
             'sold half as many clips in May. How many clips did Natalia sell '
             'altogether in April and May?'}


# Format match function 

In [6]:
# This regular expression is used to match a specific format in a string, typically for extracting
# the solution part from a text that contains both reasoning and solution sections.
# - It expects the string to start with optional whitespace.
# - Then it looks for the reasoning section, which starts with the value of `reasoning_start`,
#   contains any characters (non-greedy), and ends with `reasoning_end`.
# - After that, it expects the solution section, which starts with `solution_start`,
#   captures everything up to `solution_end` (the solution itself is captured in a group).
# - Finally, it expects optional whitespace at the end of the string.
# - The flags `re.MULTILINE` and `re.DOTALL` allow the regex to match across multiple lines
#   and let the dot (`.`) match newline characters as well.

match_format = re.compile(
    rf"^[\s]{{0,}}"\
    rf"{reasoning_start}.+?{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)

#test
res = match_format.search(
    "<start_working_out>Let me think!<end_working_out>"\
    "<SOLUTION>2</SOLUTION>",
)

print(res)
print(res.group(1))

res = match_format.search(
    "<SOLUTION>2</SOLUTION>",
)

print(res)
# print(res.group(1))

<re.Match object; span=(0, 71), match='<start_working_out>Let me think!<end_working_out>>
2
None


In [7]:
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        res = completion[0]['content']
        if match_format.search(res) is not None: score += 3.0
        scores.append(score)
    return scores

# Unit tests for match_format_exactly
actual = match_format_exactly([[{'content': "<start_working_out>Reason<end_working_out><SOLUTION>42</SOLUTION>"}]])
assert actual == [3.0], f"Test 1 Failed: Expected [3.0] for valid reasoning and solution, got {actual}"

actual = match_format_exactly([[{'content': "<start_working_out>R<end_working_out><SOLUTION>ans</SOLUTION>"}]])
assert actual == [3.0], f"Test 2 Failed: Expected [3.0] for valid short reasoning and solution, got {actual}"

actual = match_format_exactly([[{'content': "<start_working_out>R<end_working_out><SOLUTION></SOLUTION>"}]])
assert actual == [0.0], f"Test 3 Failed: Expected [0.0] for empty solution but valid format, got {actual}"

actual = match_format_exactly([[{'content': "<start_working_out>R<end_working_out>"}]])
assert actual == [0], f"Test 4 Failed: Expected [0] for missing solution section, got {actual}"

actual = match_format_exactly([[{'content': "<SOLUTION>42</SOLUTION>"}]])
assert actual == [0], f"Test 5 Failed: Expected [0] for missing reasoning section, got {actual}"

actual = match_format_exactly([[{'content': ""}]])
assert actual == [0], f"Test 6 Failed: Expected [0] for empty string, got {actual}"

actual = match_format_exactly([[{'content': "<start_working_out>R<end_working_out><SOLUTION>ans</SOLUTION> extra"}]])
assert actual == [0], f"Test 7 Failed: Expected [0] for extra text after valid format, got {actual}"


In [8]:
def match_format_approx(completions, **kwargs):
    scores = []
    for completion in completions:
        response = completion[0]['content']
        scores.append(
            sum(0.5 if response.count(tag) == 1 else -0.5 
                for tag in [reasoning_start, reasoning_end, solution_start, solution_end])
        )
    return scores


# Unit tests for match_format_approx
actual = match_format_approx([[{'content': "<start_working_out>R<end_working_out><SOLUTION>42</SOLUTION>"}]])
assert actual == [2.0], f"Test 1 Failed: Expected [2.0] for all tags present once, got {actual}"

actual = match_format_approx([[{'content': "<start_working_out>R<end_working_out>"}]])
assert actual == [0.0], f"Test 2 Failed: Expected [0.0] for only reasoning tags, got {actual}"

actual = match_format_approx([[{'content': "<SOLUTION>42</SOLUTION>"}]])
assert actual == [0.0], f"Test 3 Failed: Expected [0.0] for only solution tags, got {actual}"

actual = match_format_approx([[{'content': "<start_working_out>R<end_working_out><SOLUTION>42</SOLUTION> extra"}]])
assert actual == [2.0], f"Test 4 Failed: Expected [2.0] for all tags present with extra text, got {actual}"

actual = match_format_approx([[{'content': ""}]])
assert actual == [-2.0], f"Test 5 Failed: Expected [-2.0] for missing all tags, got {actual}"

actual = match_format_approx([[{'content': "<start_working_out>R<end_working_out><SOLUTION>42</SOLUTION><SOLUTION>43</SOLUTION>"}]])
assert actual == [0.0], f"Test 6 Failed: Expected [0.0] for duplicate solution tags, got {actual}"


In [9]:
def check_answer(prompts, completions, answer, **kwargs):
    responses = [completion[0]['content'] for completion in completions]

    extracted_responses = [
        guess.group(1) # the answer after hash #### 
        if(guess:= match_format.search(res)) is not None else None
        for res in responses
    ]

    scores = []

    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        if guess == true_answer:
            score += 3.0
        elif guess.strip() == true_answer.strip(): # correct answer but there are spaces in between won't get full points
            score += 1.5
        else:
            try:
                ratio = float(guess) / float(true_answer)
                if 0.9 <= ratio <= 1.1:
                    score += 0.5
                elif 0.8 <= ratio <= 1.2:
                    score += 0.25
                else:
                    score -= 1.0 #wrong answer, penalize
            except:
                score -= 0.5 #unknown format 
        scores.append(score)

    return scores

# Compact unit tests for check_answer

# Test 1: Exact match

actual = check_answer([["Q"]], [[{'content': f"{reasoning_start}Some reasoning here{reasoning_end}<SOLUTION>42</SOLUTION>"}]], ["42"])
assert actual == [3.0], f"Test 1 Failed: {actual}"

# Test 2: Whitespace match
actual = check_answer([["Q"]], [[{'content': f"{reasoning_start}Reasoning{reasoning_end}<SOLUTION>   42  </SOLUTION>"}]], ["42"])
assert actual == [1.5], f"Test 2 Failed: {actual}"

# Test 3: Ratio within 10%
actual = check_answer([["Q"]], [[{'content': f"{reasoning_start}Math steps{reasoning_end}<SOLUTION>95</SOLUTION>"}]], ["100"])
assert actual == [0.5], f"Test 3 Failed: {actual}"

# Test 4: Ratio within 20%
actual = check_answer([["Q"]], [[{'content': f"{reasoning_start}Estimate{reasoning_end}<SOLUTION>85</SOLUTION>"}]], ["100"])
assert actual == [0.25], f"Test 4 Failed: {actual}"

# Test 5: Wrong numeric answer
actual = check_answer([["Q"]], [[{'content': f"{reasoning_start}Wrong math{reasoning_end}<SOLUTION>50</SOLUTION>"}]], ["100"])
assert actual == [-1.0], f"Test 5 Failed: {actual}"

# Test 6: Unknown format
actual = check_answer([["Q"]], [[{'content': f"{reasoning_start}Nonsense{reasoning_end}<SOLUTION>foo</SOLUTION>"}]], ["100"])
assert actual == [-0.5], f"Test 6 Failed: {actual}"

# Test 7: No match
actual = check_answer([["Q"]], [[{'content': f"{reasoning_start}No answer here{reasoning_end}"}]], ["100"])
assert actual == [0], f"Test 7 Failed: {actual}"

# Test 8: Multiple answers
actual = check_answer(
    [["Q1"], ["Q2"]],
    [
        [{'content': f"{reasoning_start}Reasoning1{reasoning_end}<SOLUTION>4</SOLUTION>"}],
        [{'content': f"{reasoning_start}Reasoning2{reasoning_end}<SOLUTION>6</SOLUTION>"}]
    ],
    ["4", "6"]
)
assert actual == [3.0, 3.0], f"Test 8 Failed: {actual}"


In [10]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags = re.MULTILINE | re.DOTALL
)
match_numbers.findall("<SOLUTION>  Answer is 0.34, Another answer is 0.45 </SOLUTION>")

['0.34']

In [11]:
def extract_response(response):
    guess = match_numbers.search(response)
    return guess.group(1) if guess != None else None

def check_numbers(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]['content']
    responses = [completion[0]['content'] for completion in completions]

    extracted_responses = [
        # guess.group(1)
        result
        if (result := extract_response(response)) is not None else None
        for response in responses
    ]

    scores = []
    print('*'*20, f"Question:\n{question}", f"Answer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")

    for guess, true_answer in zip(extracted_responses, answer):
        
        if guess is None:
            scores.append(0.0)
            continue
        try:
            true_answer = float(true_answer.strip())
            guess = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        
        except:
            scores.append(0.0)
            continue
    
    return scores

# Test 1: Correct extraction and matching
actual = check_numbers(
    [[{'content': ''}]],
    [[{'content': '<SOLUTION> 0.34 </SOLUTION>'}]],
    ["0.34"]
)
assert actual == [1.5], f"Test 1 Failed: {actual}"

# Test 2: Extraction with extra text
actual = check_numbers(
    [[{'content': ''}]],
    [[{'content': '<SOLUTION> Answer is 0.34, Another answer is 0.45 </SOLUTION>'}]],
    ["0.34"]
)
assert actual == [1.5], f"Test 2 Failed: {actual}"

# Test 3: Incorrect number
actual = check_numbers(
    [[{'content': ''}]],
    [[{'content': '<SOLUTION> 0.45 </SOLUTION>'}]],
    ["0.34"]
)
assert actual == [0.0], f"Test 3 Failed: {actual}"

# Test 4: No number found
actual = check_numbers(
    [[{'content': ''}]],
    [[{'content': '<SOLUTION> no number here </SOLUTION>'}]],
    ["0.34"]
)
assert actual == [0.0], f"Test 4 Failed: {actual}"

# Test 5: Multiple completions
actual = check_numbers(
    [[{'content': ''}, {'content': ''}]],
    [[{'content': '<SOLUTION> 1.23 </SOLUTION>'}], [{'content': '<SOLUTION> 4.56 </SOLUTION>'}]],
    ["1.23", "4.56"]
)
assert actual == [1.5, 1.5], f"Test 5 Failed: {actual}"


******************** Question:
 Answer:
0.34 
Response:
<SOLUTION> 0.34 </SOLUTION> 
Extracted:
0.34
******************** Question:
 Answer:
0.34 
Response:
<SOLUTION> Answer is 0.34, Another answer is 0.45 </SOLUTION> 
Extracted:
0.34
******************** Question:
 Answer:
0.34 
Response:
<SOLUTION> 0.45 </SOLUTION> 
Extracted:
0.45
******************** Question:
 Answer:
0.34 
Response:
<SOLUTION> no number here </SOLUTION> 
Extracted:
None
******************** Question:
 Answer:
1.23 
Response:
<SOLUTION> 1.23 </SOLUTION> 
Extracted:
1.23


# GRPOConfig and GRPOTrainer

In [12]:
if TRAIN:
    max_seq_length = 1024
    model, tokenizer = FastModel.from_pretrained(
        model_name = "unsloth/gemma-3-1b-it",
        max_seq_length = max_seq_length, # Choose any for long context!
        load_in_4bit = False,  # 4 bit quantization to reduce memory
        load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
        full_finetuning = False, # [NEW!] We have full finetuning now!
    )

    model = FastModel.get_peft_model(
        model,
        finetune_vision_layers     = False, # Turn off for just text!
        finetune_language_layers   = True,  # Should leave on!
        finetune_attention_modules = True,  # Attention good for GRPO
        finetune_mlp_modules       = True,  # SHould leave on always!
        r = 8,           # Larger = higher accuracy, but might overfit
        lora_alpha = 8,  # Recommended alpha == r at least
        lora_dropout = 0,
        bias = "none",
        random_state = 3407,
    )
    print(type(model))

In [13]:
if TRAIN:
    max_prompt_length = 256
    max_seq_length = 1024

    grpo_config = GRPOConfig(
        learning_rate = 5e-6,                # The initial learning rate for the optimizer
        adam_beta1 = 0.9,                    # Beta1 parameter for Adam optimizer (exponential decay rate for first moment estimates)
        adam_beta2 = 0.99,                   # Beta2 parameter for Adam optimizer (exponential decay rate for second moment estimates)
        weight_decay = 0.1,                  # Weight decay (L2 penalty) to prevent overfitting
        warmup_ratio = 0.1,                  # Fraction of total steps used for learning rate warmup
        lr_scheduler_type = "cosine",        # Type of learning rate scheduler ("cosine" annealing)
        optim = "adamw_torch_fused",         # Optimizer type (fused AdamW for efficiency)
        logging_steps = 1,                   # Log training metrics every N steps
        per_device_train_batch_size = 1,     # Batch size per device (GPU/CPU) during training
        gradient_accumulation_steps = 1,     # Number of steps to accumulate gradients before updating weights (increase for larger effective batch size)
        num_generations = 4,                 # Number of generations per prompt (reduce if out of memory)
        max_prompt_length = max_prompt_length,                   # Maximum length of the input prompt
        max_completion_length = max_seq_length - max_prompt_length, # Maximum length of the generated completion
        # num_train_epochs = 1,              # Number of training epochs (uncomment and set for full training run)
        max_steps = 50,                      # Total number of training steps
        save_steps = 50,                     # Save checkpoint every N steps
        max_grad_norm = 0.1,                 # Maximum gradient norm for gradient clipping
        report_to = "none",                  # Reporting backend ("none" disables reporting, can use "wandb" for Weights & Biases)
        output_dir = "outputs"               # Directory to save model checkpoints and outputs
        
    )

    trainer = GRPOTrainer(
        model = model,
        processing_class = tokenizer,
        reward_funcs = [
            match_format_exactly,
            match_format_approx,
            check_answer,
            check_numbers,
        ],
        args = grpo_config,
        train_dataset = dataset,
    )

    trainer.train()

    model.save_pretrained('outputs/gemma-3-tune1')
    tokenizer.save_pretrained('outputs/gemma-3-tune1')

# Evaluate pipeline 

In [14]:
# load the model weights
# model.save_pretrained('outputs/gemma-3-tune1')
model, tokenizer = FastModel.from_pretrained('outputs/gemma-3-tune1')
print(type(model))
print(type(tokenizer))
print(model.device)

==((====))==  Unsloth 2025.7.5: Fast Gemma3 patching. Transformers: 4.53.2. vLLM: 0.9.2.
   \\   /|    NVIDIA GeForce RTX 3060. Num GPUs = 1. Max memory: 11.622 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
<class 'peft.peft_model.PeftModelForCausalLM'>
<class 'transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast'>
cuda:0


In [15]:
# Sample inference
messages = [
    {'role': 'system', "content": system_prompt},
    {'role': 'user', "content": "What is the square root of 1010?"},
]

token_ids = tokenizer.apply_chat_template(
    messages, 
    add_generation_prompt = True,
    return_tensors = "pt",
    tokenize=True
).to('cuda')

print(token_ids)
print(type(token_ids))

# from transformers import TextStreamer

output = model.generate(
    token_ids,
    max_new_tokens = 64, # Increase for longer outputs!
    # Recommended Gemma-3 settings!
    temperature = 1.0, top_p = 0.95, top_k = 64,
    # streamer = TextStreamer(tokenizer, skip_prompt = True),
)

print(tokenizer.decode(output[0]))


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


tensor([[     2,    105,   2364,    107,   3048,    659,   2238,    496,   2608,
         236764,   1751,   1003,    506,   2608,    532,   2847,    822,  26149,
         236761, 236743,    107,    140,  24698,    625,   1534,    655,   3041,
         236779,  25421, 236779,    725, 236813,    532,    655,    643, 236779,
          25421, 236779,    725,  24449,   4298,   2847,    822,   3465,    107,
            140,  19195,    655, 213910,   2588, 213910, 236813,    108,   3689,
            563,    506,   6281,   5989,    529, 236743, 236770, 236771, 236770,
         236771, 236881,    106,    107,    105,   4368,    107]],
       device='cuda:0')
<class 'torch.Tensor'>
<bos><start_of_turn>user
You are given a problem, think about the problem and provide your workout. 
    Place it between <start_working_out> and <end_working_out>. Then provide your solution
    between <SOLUTION></SOLUTION>

What is the square root of 1010?<end_of_turn>
<start_of_turn>model
<START_WORKING_OUT>
1010 

In [16]:
test_ds = load_dataset('openai/gsm8k', 'main', split="test")
print(test_ds)

Dataset({
    features: ['question', 'answer'],
    num_rows: 1319
})


In [17]:
a = x if (x:=89) is not None else None
print(a)

89


In [19]:
def generate(model, question, **kwargs):
    MAX_NEW_TOKENS = kwargs.get("max_new_tokens", 64)

    # Sample inference
    messages = [
        {'role': 'system', "content": system_prompt},
        {'role': 'user', "content": question},
    ]

    token_ids = tokenizer.apply_chat_template(
        messages, 
        add_generation_prompt = True,
        return_tensors = "pt",
        tokenize=True
    ).to('cuda')

    # from transformers import TextStreamer

    output = model.generate(
        token_ids,
        max_new_tokens = MAX_NEW_TOKENS, # Increase for longer outputs!
        # Recommended Gemma-3 settings!
        temperature = 1.0, top_p = 0.95, top_k = 64,
    )

    return output

def get_answer_from_completion(output: torch.Tensor):
    assert type(output) == torch.Tensor, 'Wrong type: output type must be torch.tensor'

    completion = tokenizer.decode(output[0])
    return x if (x:=extract_response(completion)) is not None else None, completion


print(test_ds)
logging.info("##### EVALUATION #####")
from tqdm import tqdm

correct_answer = 0
total = 0 
for X in tqdm(test_ds, desc="Evaluating"):
    total += 1
    try:
        question = X["question"]
        answer = extract_hash_answer(X['answer'])
        y = generate(model, question, max_new_tokens = 128)
        pred_answer, completion = get_answer_from_completion(y)

        # Log for later analysis

        logger.info('question\t' + question)
        logger.info('target_answer\t'+  answer)
        logger.info('pred_answer\t'+ pred_answer)
        logger.info('completion\t'+ completion)

        correct_answer += 1 if (pred_answer != None and pred_answer == answer) else 0

    except Exception as e:
        print("error occured, check log")
        logger.error("Error in this question ", question)
    
    accuracy = (correct_answer / total) * 100
    print(f'Accuracy = {accuracy:.2f}')
    logger.info(f'Accuracy = {accuracy:.2f}')
    logger.info("################################################")
    break

# generate(model, test_ds[0]['question'])
    


Dataset({
    features: ['question', 'answer'],
    num_rows: 1319
})


Evaluating:   0%|          | 0/1319 [00:06<?, ?it/s]

Accuracy = 0.00





In [None]:
# from sklearn.metrics import accuracy_score


# x = torch.tensor([1,2,3,4]).numpy()
# y = torch.tensor([1,2,4,4]).numpy()

# res = accuracy_score(x,y)
# print(res)
# print(type(res))

