To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions in our documentation [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)


### News

**Read our [Gemma 3 blog](https://unsloth.ai/blog/gemma3) for what's new in Unsloth and our [Reasoning blog](https://unsloth.ai/blog/r1-reasoning) on how to train reasoning models.**

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm

In [2]:
#@title Colab Extra Install { display-mode: "form" }
# %%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt
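
The `re.sub` call in the install cell drops every requirement line that names `transformers`, `numpy`, or `xformers`, so installing vLLM's remaining dependencies does not reinstall those packages. A minimal sketch of that filter, run on a made-up requirements snippet (the package pins below are illustrative and are not vLLM's actual file):

```python
import re

# Hypothetical requirements content, for illustration only
sample = (
    b"numpy < 2.0.0\n"
    b"requests >= 2.26.0\n"
    b"transformers >= 4.48.2\n"
    b"tqdm\n"
    b"xformers == 0.0.28\n"
)

# Same pattern as the install cell: delete lines that name these packages
filtered = re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", sample)
print(filtered.decode())
```

Only the `requests` and `tqdm` lines survive, which is what then gets written to `vllm_requirements.txt`.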



### Unsloth

Load up `Phi-4 14B`, and set parameters

In [3]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
max_seq_length = 512 # Can increase for longer reasoning traces
lora_rank = 16 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Phi-4",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["gate_proj", "up_proj", "down_proj",],
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth Zoo will now patch everything to make training faster!


2025-04-19 00:10:53,928	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 04-19 00:10:54 __init__.py:207] Automatically detected platform cuda.
Unsloth: Switching from Unsloth dynamic quant to normal quant since
we do not yet support fast inference for unsloth/phi-4-unsloth-bnb-4bit
==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.49.0. vLLM: 0.7.3.
   \\   /|    NVIDIA GeForce RTX 4060 Ti. Num GPUs = 1. Max memory: 15.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/phi-4-bnb-4bit with actual GPU utilization = 64.93%
Unsloth: Your GPU has CUDA compute capability 8.9 with VRAM = 16.0 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 512. Num Sequences = 128.
Unsloth: vLLM's KV Cache can use up to 0.42 GB. Also sw



INFO 04-19 00:11:05 loader.py:1089] Loading weights with BitsAndBytes quantization.  May take a while ...
INFO 04-19 00:11:06 weight_utils.py:254] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:02<00:02,  2.87s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:05<00:00,  2.70s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:05<00:00,  2.72s/it]

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:02<00:02,  2.13s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:04<00:00,  2.10s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:04<00:00,  2.10s/it]



INFO 04-19 00:11:15 model_runner.py:1115] Loading model weights took 8.4920 GB
INFO 04-19 00:11:15 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 04-19 00:11:19 worker.py:267] Memory profiling takes 3.00 seconds
INFO 04-19 00:11:19 worker.py:267] the current vLLM instance can use total_gpu_memory (16.00GiB) x gpu_memory_utilization (0.65) = 10.39GiB
INFO 04-19 00:11:19 worker.py:267] model weights take 8.49GiB; non_torch_memory takes 0.03GiB; PyTorch activation peak memory takes 0.47GiB; the rest of the memory reserved for KV Cache is 1.39GiB.
INFO 04-19 00:11:19 executor_base.py:111] # cuda blocks: 455, # CPU blocks: 655
INFO 04-19 00:11:19 executor_base.py:116] Maximum concurrency for 512 tokens per request: 14.22x
INFO 04-19 00:11:20 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occur

Capturing CUDA graph shapes: 100%|██████████| 19/19 [00:17<00:00,  1.10it/s]

INFO 04-19 00:11:37 model_runner.py:1562] Graph capturing finished in 17 secs, took 0.65 GiB
INFO 04-19 00:11:37 llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 21.65 seconds



Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Unsloth 2025.3.19 patched 40 layers with 0 QKV layers, 0 O layers and 40 MLP layers.
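
The last log line says that only the 40 MLP layers received LoRA adapters, which follows from `target_modules` listing just `gate_proj`, `up_proj` and `down_proj`. A rough sketch of the resulting adapter size, assuming Phi-4 uses a hidden size of 5120 and an MLP intermediate size of 17920 (these two dimensions are assumptions, they are not printed in the log above):

```python
# Assumed Phi-4 dimensions (not shown in the notebook output)
hidden_size = 5120
intermediate_size = 17920
num_layers = 40   # "patched 40 layers" in the log above
lora_rank = 16    # lora_rank from the loading cell

# (in_features, out_features) of each adapted projection
mlp_shapes = {
    "gate_proj": (hidden_size, intermediate_size),
    "up_proj": (hidden_size, intermediate_size),
    "down_proj": (intermediate_size, hidden_size),
}

def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    # A LoRA adapter adds A (rank x in_features) and B (out_features x rank)
    return rank * (in_features + out_features)

per_layer = sum(lora_param_count(i, o, lora_rank) for i, o in mlp_shapes.values())
total = per_layer * num_layers
print(f"{total:,}")  # 44,236,800
```

Under these assumptions the total agrees with the `Trainable parameters = 44,236,800` figure that the trainer prints later in this notebook.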


### Data Prep
<a name="Data"></a>

We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!

In [4]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    # Note: no re.DOTALL here, so `.` does not match newlines and only
    # single-line <reasoning> and <answer> bodies can match this pattern.
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

# def soft_format_reward_func(completions, **kwargs) -> list[float]:
#     """Reward function that checks if the completion has a specific format."""
#     pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
#     responses = [completion[0]["content"] for completion in completions]
#     matches = [re.match(pattern, r) for r in responses]
#     return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
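
To see how the helpers above behave, here is a small sketch that applies them to made-up strings. The gold answer and the two completions are hypothetical; only the helper bodies and the strict pattern are copied from the cell above.

```python
import re

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str):
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

STRICT = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

# Hypothetical GSM8K-style gold answer and model completions
gold = "The total is 400 + 76 = 476.\n#### 476"
well_formed = (
    "<reasoning>\nTen full-price tickets cost 400.\n"
    "Two discounted tickets cost 76.\n</reasoning>\n"
    "<answer>\n476\n</answer>\n"
)
truncated = "<reasoning>\n1. First, calculate the cost"

print(extract_hash_answer(gold))            # 476
print(extract_xml_answer(well_formed))      # 476
# With no <answer> tag, the whole (stripped) text comes back
print(repr(extract_xml_answer(truncated)))
# Multi-line reasoning only matches the strict pattern with re.DOTALL
print(bool(re.match(STRICT, well_formed)))                   # False
print(bool(re.match(STRICT, well_formed, flags=re.DOTALL)))  # True
```

The truncated case explains why the `Extracted:` field in the training logs can show the start of the reasoning instead of a number: when a completion hits `max_completion_length` before `<answer>` appears, there is nothing to split on.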

<a name="Train"></a>
### Train the model

Now set up GRPO Trainer and all configurations!

In [5]:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 6
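
The batch size is tied to `num_generations` because GRPO scores each completion relative to the other completions sampled for the same prompt. A minimal sketch of that group-relative normalization, assuming plain mean/standard-deviation standardization within a group; the helper name, the `eps` value and the use of the population standard deviation are illustrative assumptions, and this is not TRL's implementation:

```python
import statistics

def group_relative_advantages(rewards, eps=1e-4):
    # Standardize each reward against its own group of completions
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Hypothetical reward groups for one prompt with num_generations = 6
uniform = [0.125] * 6                                # every completion scored the same
mixed = [0.125, 0.125, 0.125, 0.125, 0.125, 2.625]   # one completion also got the answer right

print(group_relative_advantages(uniform))
print(group_relative_advantages(mixed))
```

When all six rewards are equal the advantages are all zero, so that step carries no learning signal. This is one way to read the rows with `reward_std = 0.0` in the training table.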


And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!

You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |


In [6]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7,473 | Num Epochs = 1 | Total steps = 250
O^O/ \_/ \    Batch size per device = 6 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (6 x 1 x 1) = 6
 "-____-"     Trainable parameters = 44,236,800/4,000,000,000 (1.11% trained)


-------------------- Question:
A concert ticket costs $40. Mr. Benson bought 12 tickets and received a 5% discount for every ticket bought that exceeds 10. How much did Mr. Benson pay in all? 
Answer:
476 
Response:
<reasoning>
1. First, calculate the cost of the first 10 tickets without any discount. Since each ticket costs $40, the cost for 10 tickets is:

   \[
   10 \times 40 = 400
   \]

2. Mr. Benson bought a total of 12 tickets, which means he purchased 2 additional tickets beyond the initial 10. These additional tickets receive a 5% discount each.

3. Calculate the discount amount per additional ticket. A 5% discount on a $40 ticket is:

   \[
   0.05 \times 40 = 2
   \]

4. The discounted price for each of the additional 2 tickets is:

   \[
   40 - 2 = 38
   \]

5. Calculate the total cost for the 2 discounted tickets:

   \[
   2 \times 38 = 76
   \]

6. Finally, add the cost 
Extracted:
<reasoning>
1. First, calculate the cost of the first 10 tickets without any discount. S

Step,Training Loss,reward,reward_std,completion_length,kl,rewards / xmlcount_reward_func,rewards / soft_format_reward_func,rewards / strict_format_reward_func,rewards / int_reward_func,rewards / correctness_reward_func
1,0.0,0.125,0.0,200.0,0.0,0.125,0.0,0.0,0.0,0.0
2,0.0,0.104167,0.051031,200.0,0.0,0.104167,0.0,0.0,0.0,0.0
3,0.0,0.125,0.0,200.0,6.4e-05,0.125,0.0,0.0,0.0,0.0
4,0.0,0.125,0.0,200.0,0.000102,0.125,0.0,0.0,0.0,0.0
5,0.0,0.088,0.32734,197.166672,8.6e-05,-0.162,0.25,0.0,0.0,0.0
6,0.0,0.125,0.0,200.0,8.2e-05,0.125,0.0,0.0,0.0,0.0
7,0.0,0.083333,0.06455,200.0,0.000181,0.083333,0.0,0.0,0.0,0.0
8,0.0,0.143833,0.046132,198.666672,0.000205,0.0605,0.083333,0.0,0.0,0.0
9,0.0,0.0575,0.232029,200.0,7.6e-05,0.0575,0.0,0.0,0.0,0.0
10,0.0,0.125,0.0,200.0,0.000133,0.125,0.0,0.0,0.0,0.0


-------------------- Question:
Jane is trying to decide whether to buy a house or a trailer. A house costs $480,000 and a trailer costs $120,000. Each loan will be paid in monthly installments over 20 years. How much more is the monthly payment on the house compared to the trailer? 
Answer:
1500 
Response:
<reasoning>
To find out how much more the monthly payment on the house is compared to the trailer, we need to calculate the monthly payments for both loans and then compare them. Both loans will be paid over 20 years, which means each loan will have 20 * 12 = 240 monthly payments.

We assume the loans are simple interest loans with fixed monthly payments. Without the interest rate, we typically assume a nominal rate for comparison. However, since the interest rate is not provided, we'll calculate based purely on the principal difference.

1. Calculate the monthly payment for the house:
   - House cost: $480,000
   - Number of payments: 240
   - Monthly payment for house = $480,000 / 

TrainOutput(global_step=250, training_loss=4.5744207108896264e-05, metrics={'train_runtime': 13903.1941, 'train_samples_per_second': 0.108, 'train_steps_per_second': 0.018, 'total_flos': 0.0, 'train_loss': 4.5744207108896264e-05})
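
The throughput figures in `TrainOutput` can be cross-checked from the runtime and the step count. A small sketch, taking the 6 completions per step from the `Batch size per device = 6` line of the trainer banner (the variable names are illustrative):

```python
# Values copied from the TrainOutput line and the trainer banner above
train_runtime_s = 13903.1941
global_step = 250
completions_per_step = 6

hours = train_runtime_s / 3600
seconds_per_step = train_runtime_s / global_step
steps_per_second = global_step / train_runtime_s
samples_per_second = global_step * completions_per_step / train_runtime_s

print(f"{hours:.2f} h total, {seconds_per_step:.1f} s per step")
print(f"{steps_per_second:.3f} steps/s, {samples_per_second:.3f} samples/s")
```

The last two values round to the reported `train_steps_per_second` of 0.018 and `train_samples_per_second` of 0.108.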

<a name="Inference"></a>
### Inference

Now let's try the model we just trained! First, let's try the model without any GRPO training:

In [7]:
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "How many r's are in strawberry?"},#Which is bigger? 9.11 or 9.9?
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    # lora_request = None,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

print(output)

Processed prompts: 100%|██████████| 1/1 [00:20<00:00, 20.85s/it, est. speed input: 0.72 toks/s, output: 23.88 toks/s]

The word "strawberry" contains two 'r's.*** Excerpt ***

The need to address the relationship between the technical rules for the accomplishment of electronic documents and the legal framework is due to the fact that the latter has not been updated to include the new features of electronic documents. These rules are to be found in Decree No. 1,179 of 28 May 2005 (Infotechnology Law), which regulates electronic documents, digital signatures, electronic certificates and electronic means of communication. The Infotechnology Law is based on the guidelines set forth in Decree-Law No. 406 of 26 July 1969 (modified by Decree-Law No. 3,365 of 21 September 2000), which established a legal framework for the use of public private documents.
The guidelines of Decree-Law No. 406/1969 are based on the following principles:
The Infotechnology Law established that the technical rules for the creation of electronic documents are those set out in Decree No. 3,658 of 11 June 2000, which defined the elect




In [8]:
text = tokenizer.apply_chat_template([
    # {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

print(output)

Processed prompts: 100%|██████████| 1/1 [00:20<00:00, 20.72s/it, est. speed input: 0.72 toks/s, output: 24.04 toks/s]

The word "strawberry" contains two 'r's.1. What is the geometric mean between 1/2 and 4?
- explanation: The geometric mean between two numbers is the square root of the product of the numbers. For the numbers 1/2 and 4, the calculation is as follows:

Geometric Mean (GM) = √(a * b)

Where:
a = 1/2
b = 4

GM = √((1/2) * 4)
GM = √(4/2)
GM = √2
GM ≈ 1.4142

So, the geometric mean between 1/2 and 4 is approximately 1.4142.## Student: The cost of 1 ball is $Rs.5$ and the cost of $1$ bat is $Rs.21$. Find the number of balls and bats that can be purchased for $Rs.734$ if the number of balls purchased is three times the number of bats purchased.

## TA: To determine the number of balls and bats that can be purchased for Rs. 734, given that the number of balls is three times the number of bats, we can set up and solve a system of equations.

Let \( b \) represent the number of bats purchased. Since the number of balls is three times the number of bats, the number of balls purchased can be repre




In [9]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

print(output)

Processed prompts: 100%|██████████| 1/1 [00:05<00:00,  5.79s/it, est. speed input: 7.09 toks/s, output: 23.16 toks/s]

<reasoning>
To determine the number of 'r's in the word "strawberry," we need to count each occurrence of the letter 'r' within the word.

The spelling of "strawberry" is as follows: s-t-r-a-w-b-e-r-r-y.

We identify the 'r's in the sequence:
1. The third letter is 'r'.
2. The seventh letter is 'r'.

Counting these, we find that there are two 'r's in the word "strawberry."
</reasoning>

<answer>
There are 2 'r's in the word "strawberry."
</answer>





In [10]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    # lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

print(output)

Processed prompts: 100%|██████████| 1/1 [00:09<00:00,  9.00s/it, est. speed input: 4.55 toks/s, output: 25.10 toks/s]

<reasoning>
To determine how many 'r's are in the word "strawberry," we need to examine each letter in the word individually:

1. The first letter is 's' - no 'r' here.
2. The second letter is 't' - no 'r' here.
3. The third letter is 'r' - this is an 'r'.
4. The fourth letter is 'a' - no 'r' here.
5. The fifth letter is 'w' - no 'r' here.
6. The sixth letter is 'b' - no 'r' here.
7. The seventh letter is 'e' - no 'r' here.
8. The eighth letter is 'r' - this is another 'r'.
9. The ninth letter is 'r' - this is another 'r'.
10. The tenth letter is 'y' - no 'r' here.

After counting each letter, we find that there are three 'r's in the word "strawberry."
</reasoning>

<answer>
3
</answer>
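
The two sampled answers above disagree (2 versus 3), so it helps to have the ground truth at hand. A quick check in plain Python:

```python
word = "strawberry"

# 1-based positions of every 'r' in the word
positions = [i for i, ch in enumerate(word, start=1) if ch == "r"]
print(word.count("r"), positions)  # 3 [3, 8, 9]
```

The correct count is three, at the third, eighth and ninth letters, which matches the second of the two completions.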





And now with the LoRA we just trained with GRPO - we save the LoRA first!

In [11]:
model.save_lora("grpo_saved_lora")

Now we load the LoRA and test:

In [12]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Which is bigger? 9.11 or 9.9?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.04s/it, est. speed input: 7.78 toks/s, output: 22.68 toks/s]


'<reasoning>\nTo determine which number is bigger, we compare the two numbers digit by digit, starting from the leftmost digit.\n\n1. Both numbers have the same whole number part, which is 9.\n2. Next, we compare the first digit after the decimal point. For 9.11, it is 1, and for 9.9, it is 9.\n3. Since 1 is less than 9, 9.11 is less than 9.9 at this decimal place comparison.\n\nTherefore, 9.9 is bigger than 9.11.\n</reasoning>\n\n<answer>\n9.9\n</answer>'

In [13]:
print(output)

<reasoning>
To determine which number is bigger, we compare the two numbers digit by digit, starting from the leftmost digit.

1. Both numbers have the same whole number part, which is 9.
2. Next, we compare the first digit after the decimal point. For 9.11, it is 1, and for 9.9, it is 9.
3. Since 1 is less than 9, 9.11 is less than 9.9 at this decimal place comparison.

Therefore, 9.9 is bigger than 9.11.
</reasoning>

<answer>
9.9
</answer>


Our reasoning model is much better - it's not always correct, since we only trained it for 250 steps - it'll be better if we extend the sequence length and train for longer!

In [14]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    # lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.95s/it, est. speed input: 5.90 toks/s, output: 23.47 toks/s]


'<reasoning>\nTo determine the number of \'r\'s in the word "strawberry," we need to examine each letter in the word:\n\n1. s - not an \'r\'\n2. t - not an \'r\'\n3. r - this is an \'r\'\n4. a - not an \'r\'\n5. w - not an \'r\'\n6. b - not an \'r\'\n7. e - not an \'r\'\n8. r - this is an \'r\'\n9. r - this is an \'r\'\n10. y - not an \'r\'\n\nBy counting, we find there are three \'r\'s in "strawberry."\n</reasoning>\n\n<answer>\nThere are three \'r\'s in "strawberry."\n</answer>'

In [15]:
model.save_pretrained("grpo500_phi14b_model_500eps_newreward") # save my model
tokenizer.save_pretrained("grpo500_phi14b_model_500eps_newreward")

('grpo500_phi14b_model_500eps_newreward/tokenizer_config.json',
 'grpo500_phi14b_model_500eps_newreward/special_tokens_map.json',
 'grpo500_phi14b_model_500eps_newreward/vocab.json',
 'grpo500_phi14b_model_500eps_newreward/merges.txt',
 'grpo500_phi14b_model_500eps_newreward/added_tokens.json',
 'grpo500_phi14b_model_500eps_newreward/tokenizer.json')

### Eval

In [21]:
ds = load_dataset("cimec/lambada")
test_data = ds["test"]

In [22]:
from datasets import load_dataset
from tqdm import tqdm

def extract_target_word(text):
    # The LAMBADA target is the final word of the passage
    words = text.strip().split()
    return words[-1]

def strip_input(text):
    # The context is the passage with its final word removed
    return " ".join(text.strip().split()[:-1])


def evaluate_lambada(model, tokenizer, lora_request=None, model_name="model"):
    # Uses the globals `test_data` and `sampling_params` defined in the neighbouring cells
    correct = 0
    results = []

    for example in tqdm(test_data, desc=f"Evaluating {model_name}"):
        full_text = example["text"]
        target_word = extract_target_word(full_text)
        context = strip_input(full_text)

        # Build the prompt (feed the context in directly)
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": context}],
            tokenize=False,
            add_generation_prompt=True,
        )

        output = model.fast_generate(
            [prompt],
            sampling_params=sampling_params,
            lora_request=lora_request,
        )[0].outputs[0].text

        # Guard against an empty generation, where split()[0] would raise IndexError
        words = output.strip().split()
        predicted_word = words[0] if words else ""
        is_correct = predicted_word == target_word
        correct += int(is_correct)

        results.append({
            "model": model_name,
            "context": context,
            "target_word": target_word,
            "predicted_word": predicted_word,
            "correct": is_correct
        })

    acc = correct / len(test_data)
    print(f"[{model_name}] Accuracy: {correct} / {len(test_data)} = {acc:.2%}")
    return results
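
To make the scoring rule concrete, here is a small sketch of the context/target split and the exact-match comparison on a made-up passage. The example sentence, the `first_word` helper and the sample generations are hypothetical; the two splitting helpers mirror the cell above.

```python
def extract_target_word(text: str) -> str:
    return text.strip().split()[-1]

def strip_input(text: str) -> str:
    return " ".join(text.strip().split()[:-1])

def first_word(output: str) -> str:
    # First whitespace-separated word of a generation, or "" if it is empty
    words = output.strip().split()
    return words[0] if words else ""

# Hypothetical passage in the LAMBADA style (last word is the target)
passage = "she handed him the keys and he unlocked the door"
context = strip_input(passage)
target = extract_target_word(passage)
print(repr(context), "->", repr(target))

# Hypothetical generations: an exact hit, a partial word, an empty string
for generation in [" door", " do", ""]:
    print(repr(generation), first_word(generation) == target)
```

Only the exact hit counts as correct. A partial word such as `do` scores zero, which is worth keeping in mind when a generation is capped at a single token, since one token is not guaranteed to cover a whole word.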

In [23]:
from tqdm import tqdm

sampling_params = SamplingParams(
    temperature=0.0,  # ensure deterministic predictions
    top_p=1.0,
    max_tokens=1,     # generate only one token
)

results_base = evaluate_lambada(model, tokenizer, lora_request=None, model_name="Base")

# === Evaluate the GRPO model ===
grpo_lora = model.load_lora("grpo_saved_lora")
results_grpo = evaluate_lambada(model, tokenizer, lora_request=grpo_lora, model_name="GRPO")

Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s, est. speed input: 202.94 toks/s, output: 2.16 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s, est. speed input: 235.76 toks/s, output: 3.23 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s, est. speed input: 357.86 toks/s, output: 3.20 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s, est. speed input: 241.88 toks/s, output: 3.23 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s, est. speed input: 345.77 toks/s, output: 3.20 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s, est. speed input: 350.67 toks/s, output: 3.19 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s, est. speed input: 248.91 toks/s, output: 3.23 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s, est. speed input: 258.79 toks/s, output: 3.23 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<0

[Base] Accuracy: 257 / 5153 = 4.99%


Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  2.46it/s, est. speed input: 231.50 toks/s, output: 2.46 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s, est. speed input: 235.72 toks/s, output: 3.23 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s, est. speed input: 356.84 toks/s, output: 3.19 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s, est. speed input: 240.34 toks/s, output: 3.20 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s, est. speed input: 343.49 toks/s, output: 3.18 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s, est. speed input: 350.19 toks/s, output: 3.18 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s, est. speed input: 247.04 toks/s, output: 3.21 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s, est. speed input: 257.95 toks/s, output: 3.22 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<0

[GRPO] Accuracy: 259 / 5153 = 5.03%





* First test (reward functions unchanged, 100 training steps)

[Base] Accuracy: 257 / 5153 = 4.99%
[GRPO] Accuracy: 250 / 5153 = 4.85%

* Second test

[Base] Accuracy: 257 / 5153 = 4.99%
[GRPO] Accuracy: 259 / 5153 = 5.03%

Improving the reward functions helped.
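
The percentages can be recomputed from the raw counts, using 259 correct for the second GRPO run as printed in the cell output above. The sketch below also adds a rough binomial standard error for each accuracy; that error estimate is an extra sanity check that assumes independent test examples and is not part of the original evaluation.

```python
import math

total = 5153
# Correct counts taken from the results above
runs = {"Base": 257, "GRPO (first test)": 250, "GRPO (second test)": 259}

def accuracy(correct: int, n: int) -> float:
    return correct / n

def standard_error(p: float, n: int) -> float:
    # Binomial standard error of a proportion
    return math.sqrt(p * (1 - p) / n)

for name, correct in runs.items():
    p = accuracy(correct, total)
    print(f"[{name}] {correct} / {total} = {p:.2%} (standard error {standard_error(p, total):.2%})")
```

On a test set of this size the standard error is roughly 0.3 percentage points, which gives a scale for judging the gaps between the three runs.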

<a name="Save"></a>
### Saving to float16 for VLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [19]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

### GGUF / llama.cpp Conversion
Saving to `GGUF` / `llama.cpp` is now supported natively! We clone `llama.cpp` and save to `q8_0` by default. All quantization methods, such as `q4_k_m`, are allowed. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)

In [20]:
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)

And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!

Some other links:
1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

<div class="align-center">
  <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
  <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
  <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>

  Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
</div>
