### Installation

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

### Unsloth

`FastModel` supports loading nearly any model now! This includes Vision and Text models!

In [3]:
from unsloth import FastModel
import torch

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",

    # Other Gemma 3 quants
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
] # More models at https://huggingface.co/unsloth
max_seq_length = 2048
model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E2B-it",
    dtype = None, # None for auto detection by unsloth
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = True,  # 4 bit dynamic quantization for superior accuracy and lower memory use
    full_finetuning = False,
    # token = "hf_...", # use one if using gated models
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 08-06 09:46:51 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 08-06 09:46:51 [__init__.py:239] Automatically detected platform cuda.
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.54.0. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/2.65G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/469M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

In [4]:
print(f"Allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
print(f"Cached: {torch.cuda.memory_reserved() / 1024 ** 2:.2f} MB")

# Check model quantization status
print(f"4-bit loaded: {model.is_loaded_in_4bit}")
print(f"Quantized: {model.is_quantized}")
print(f"Method: {model.quantization_method}")

Allocated: 7770.68 MB
Cached: 7790.00 MB
4-bit loaded: True
Quantized: True
Method: QuantizationMethod.BITS_AND_BYTES


# Gemma 3N can process Text, Vision and Audio!

Let's first experience how Gemma 3N can handle multimodal inputs. We use Gemma 3N's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64`

In [5]:
from transformers import TextStreamer
# Helper function for inference
def do_gemma_3n_inference(messages, max_new_tokens = 128):
    _ = model.generate(
        **tokenizer.apply_chat_template(
            messages,
            add_generation_prompt = True, # Must add for generation
            tokenize = True,
            return_dict = True,
            return_tensors = "pt",
        ).to("cuda"),
        max_new_tokens = max_new_tokens,
        temperature = 1.0, top_p = 0.95, top_k = 64,
        streamer = TextStreamer(tokenizer, skip_prompt = True),
    )

We now add LoRA adapters so we only need to update a small amount of parameters!

In [6]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # SHould leave on always!
    r = 16,           # Larger = higher accuracy, but no significant improvements. checkout [LoRA](https://arxiv.org/abs/2106.09685) paper
    lora_alpha = 32,  # Recommended alpha >= r
    lora_dropout = 0, # 0.1 provides moderate regularization without being too aggressive. Common range is 0.05-0.2 for LoRA fine-tuning. Since we are doing text only, 0.1 is a good default.
    bias = "none", # Bias terms are simple additive constants that shift neuron outputs. They're less critical for task adaptation because:
    # What bias does: If a neuron computes Wx + b, the bias b just shifts the entire output up/down by a constant.
    # Why freezing works: The main "intelligence" comes from the weight matrix W learning new patterns. The bias shifts are usually already well-calibrated from pretraining.
    random_state = 3407,
    use_rslora = True,
    use_gradient_checkpointing = "unsloth",
    loftq_config = {}
)
model.print_trainable_parameters()

Unsloth: Making `model.base_model.model.model.language_model` require gradients
trainable params: 21,135,360 || all params: 5,460,573,632 || trainable%: 0.3871


<a name="Data"></a>
### Data Prep
We now use the `Gemma-3` format for conversation style finetunes.

```
<bos><start_of_turn>user
Hello!<end_of_turn>
<start_of_turn>model
Hey there!<end_of_turn>
```

We use our `get_chat_template` function to get the correct chat template. We support `zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, phi3, llama3, phi4, qwen2.5, gemma3` and more.

In [7]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)

In [8]:
from datasets import load_dataset

train_ds = load_dataset("muzzz/coedit-cot-reasoning", split="train")
val_ds = load_dataset("muzzz/coedit-cot-reasoning", split="validation")

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/82.6M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/3.40M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/69071 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1712 [00:00<?, ? examples/s]

In [9]:
print("Train Dataset Info:")
print(train_ds)
print("features: ", train_ds.features)
print("example: ", train_ds[0])
print("--------------------------------")
print("\nValidation Dataset Info:")
print(val_ds)
print("features: ", val_ds.features)
print("example: ", val_ds[0])

Train Dataset Info:
Dataset({
    features: ['_id', 'task', 'src', 'tgt', 'reasoning'],
    num_rows: 69071
})
features:  {'_id': Value(dtype='string', id=None), 'task': Value(dtype='string', id=None), 'src': Value(dtype='string', id=None), 'tgt': Value(dtype='string', id=None), 'reasoning': Value(dtype='string', id=None)}
example:  {'_id': '1', 'task': 'gec', 'src': 'Remove all grammatical errors from this text: For example, countries with a lot of deserts can terraform their desert to increase their habitable land and using irrigation to provide clean water to the desert.', 'tgt': 'For example, countries with a lot of deserts can transform their desert to increase their habitable land and use irrigation to provide clean water to the desert.', 'reasoning': '<think>\n1.  **Instruction Analysis:** The goal is to "Remove all grammatical errors" from the provided text. I need to read the text carefully and identify any points that violate standard English grammar rules.\n2.  **Source Text

In [10]:
# Check how many rows have empty reasoning column
empty_reasoning_count_train = sum(1 for example in train_ds if not example["reasoning"] or example["reasoning"].strip() == "")
empty_reasoning_count_val = sum(1 for example in val_ds if not example["reasoning"] or example["reasoning"].strip() == "")

print(f"Train dataset - Empty reasoning rows: {empty_reasoning_count_train} out of {len(train_ds)} ({empty_reasoning_count_train/len(train_ds)*100:.2f}%)")
print(f"Validation dataset - Empty reasoning rows: {empty_reasoning_count_val} out of {len(val_ds)} ({empty_reasoning_count_val/len(val_ds)*100:.2f}%)")

print("\nExamples with empty reasoning from train dataset:")
empty_examples_train = [example for example in train_ds if not example["reasoning"] or example["reasoning"].strip() == ""]
for i, example in enumerate(empty_examples_train[:3]):
    print(f"Example {i+1}:")
    print(f"  src: {example['src'][:100]}...")
    print(f"  reasoning: '{example['reasoning']}'")
    print(f"  tgt: {example['tgt'][:100]}...")


Train dataset - Empty reasoning rows: 12335 out of 69071 (17.86%)
Validation dataset - Empty reasoning rows: 0 out of 1712 (0.00%)

Examples with empty reasoning from train dataset:
Example 1:
  src: Remove unsourced opinions: they were at war with mecca, and saw no wrong in raiding meccan caravans....
  reasoning: 'None'
  tgt: they considered themselves to be at war with mecca, and saw no wrong in raiding meccan caravans....
Example 2:
  src: Make this sentence more neutral: crowd of curious people day after the tragedy....
  reasoning: 'None'
  tgt: crowd of curious people day after the fire....
Example 3:
  src: Neutralize the text: "the city that fun forgot", often used sarcastically by residents of ottawa...
  reasoning: 'None'
  tgt: "the city that fun forgot", used sarcastically by residents of ottawa...


### the dataset being used [here](https://huggingface.co/datasets/muzzz/coedit-cot-reasoning)

In [11]:
reasoning_start_token = "<think>"
reasoning_end_token = "</think>"

# The detailed system prompt for SFT phases (1 and 3)
system_prompt_sft = f"""You are an expert text editor. First, think step-by-step about the user's instruction and the source text. Place your reasoning inside {reasoning_start_token} and {reasoning_end_token} tags. Then, provide the final, edited text immediately after the closing tag. Your reasoning must follow a logical structure: Instruction analysis, Sentence analysis, Identify error(s), Apply correction(s), and Synthesized correction."""

# The simpler system prompt for the GRPO phase (2)
system_prompt_grpo = f"""You are an expert text editor. First, think step-by-step about the user's instruction and the source text. Place your reasoning inside {reasoning_start_token} and {reasoning_end_token} tags. Then, provide the final, edited text immediately after the closing tag."""

In [12]:
def format_sft_dataset(example, tokenizer, system_prompt):
    # For gemma_chatml, include system prompt as part of the user message
    user_content = system_prompt + "\n\n" + example["src"]
    messages = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": (example["reasoning"] or "") + "\n" + example["tgt"]},
    ]
    formatted_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    if formatted_text.startswith(tokenizer.bos_token):
        formatted_text = formatted_text[len(tokenizer.bos_token):]
    return {"text": formatted_text}


def format_grpo_dataset(example, system_prompt):
    """
    Formats data for GRPO. The 'prompt' is a list of messages for generation,
    and 'answer' is the ground truth for the reward function.
    """
    # For gemma_chatml, include system prompt as part of the user message
    user_content = system_prompt + "\n\n" + example["src"]
    return {
        "prompt": [
            {"role": "user", "content": user_content},
        ],
        "answer": example["tgt"],
    }

In [13]:
# --- Testing Chat Template Formatting ---

# --- [Phase 1 & 3] SFT Template Test ---
print("--- [Phase 1 & 3] SFT Template Test ---")
print("This template should include the detailed system prompt and the full model response.")

# Reset tokenizer to use the base gemma-3 template
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

# Format the messages just like the SFTTrainer will (system prompt included in user message)
user_content_sft = system_prompt_sft + "\n\nWhat is the capital of France?"
sft_messages = [
    {"role": "user", "content": user_content_sft},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
sft_formatted = tokenizer.apply_chat_template(sft_messages, tokenize=False, add_generation_prompt=False)
print(sft_formatted)

# --- [Phase 2] GRPO Template Test ---
print("\n--- [Phase 2] GRPO Template Test ---")
print("This template should include the simpler system prompt and add the generation prompt at the end.")

# Format the messages just like the GRPOTrainer will (system prompt included in user message)
user_content_grpo = system_prompt_grpo + "\n\nWhat is the capital of France?"
grpo_messages = [
    {"role": "user", "content": user_content_grpo},
]
grpo_formatted = tokenizer.apply_chat_template(grpo_messages, tokenize=False, add_generation_prompt=True)
print(grpo_formatted)

--- [Phase 1 & 3] SFT Template Test ---
This template should include the detailed system prompt and the full model response.
<bos><start_of_turn>user
You are an expert text editor. First, think step-by-step about the user's instruction and the source text. Place your reasoning inside <think> and </think> tags. Then, provide the final, edited text immediately after the closing tag. Your reasoning must follow a logical structure: Instruction analysis, Sentence analysis, Identify error(s), Apply correction(s), and Synthesized correction.

What is the capital of France?<end_of_turn>
<start_of_turn>model
The capital of France is Paris.<end_of_turn>


--- [Phase 2] GRPO Template Test ---
This template should include the simpler system prompt and add the generation prompt at the end.
<bos><start_of_turn>user
You are an expert text editor. First, think step-by-step about the user's instruction and the source text. Place your reasoning inside <think> and </think> tags. Then, provide the final, 

In [14]:
# GRPO REWARD FUNCS
import re

extraction_pattern = re.compile(rf"{re.escape(reasoning_end_token)}(.*)", flags=re.DOTALL)

def reward_reasoning_structure(completions, **kwargs):
    """
    Rewards completions for having the correct <think> tags and the expected
    step-by-step reasoning structure based on your data generation prompt.
    """
    structural_keywords = [
        re.compile(r"instruction.*analysis", re.IGNORECASE),
        re.compile(r"sentence.*analysis", re.IGNORECASE),
        re.compile(r"identify.*error", re.IGNORECASE),
        re.compile(r"apply.*correction", re.IGNORECASE),
        re.compile(r"synthesized.*correction", re.IGNORECASE),
    ]
    scores = []
    for completion in completions:
        score = 0
        response_text = completion[0]["content"]

        # Base check for the enclosing <think> tags. This is fundamental.
        if reasoning_start_token in response_text and reasoning_end_token in response_text:
            score += 1.0  # Base reward for correct tag usage
        else:
            scores.append(-4.0) # Penalize heavily if tags are missing
            continue

        # Additive reward for each structural keyword found.
        num_keywords_found = sum(1 for keyword_regex in structural_keywords if keyword_regex.search(response_text))

        # Scale the reward. Max of +3 points for a perfectly structured response.
        if num_keywords_found > 0:
            score += (num_keywords_found / len(structural_keywords)) * 3.0

        scores.append(score)
    return scores

def reward_target_match(completions, answer, **kwargs):
    """
    Heavily rewards completions where the final extracted text exactly matches the target.
    """
    scores = []
    ground_truth_tgts = answer # The 'answer' here is `example["tgt"]` from format_grpo_dataset

    for completion, true_tgt in zip(completions, ground_truth_tgts):
        score = 0
        response_text = completion[0]["content"]

        # Extract the model's generated final answer
        extracted_match = extraction_pattern.search(response_text)

        if extracted_match:
            generated_text = extracted_match.group(1).strip()
            if generated_text == true_tgt.strip():
                score += 5.0  # High reward for an exact match
            else:
                score -= 2.0  # Penalize if it generates something, but it's wrong
        else:
            score -= 4.0  # Penalize heavily if it fails to produce any final answer

        scores.append(score)
    return scores

reward_funcs = [
    reward_reasoning_structure,
    reward_target_match,
]

In [15]:
# PHASE 1
print("preparing phase 1")
# Reset tokenizer to use the base gemma-3 template
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

sft_dataset_full = train_ds.map(lambda x: format_sft_dataset(x, tokenizer, system_prompt_sft), batched=False)
sft_val_dataset = val_ds.map(lambda x: format_sft_dataset(x, tokenizer, system_prompt_sft), batched=False)

train_ds_subset = sft_dataset_full.shuffle(seed=42).select(range(len(sft_dataset_full) // 10))
val_ds_subset = sft_val_dataset.shuffle(seed=42).select(range(len(sft_val_dataset) // 10))

print(f"Full SFT dataset size: {len(sft_dataset_full)}")
print(sft_dataset_full[0]['text'][:500])
print(f"SFT subset size: {len(train_ds_subset)}")

preparing phase 1


Map:   0%|          | 0/69071 [00:00<?, ? examples/s]

Map:   0%|          | 0/1712 [00:00<?, ? examples/s]

Full SFT dataset size: 69071
<start_of_turn>user
You are an expert text editor. First, think step-by-step about the user's instruction and the source text. Place your reasoning inside <think> and </think> tags. Then, provide the final, edited text immediately after the closing tag. Your reasoning must follow a logical structure: Instruction analysis, Sentence analysis, Identify error(s), Apply correction(s), and Synthesized correction.

Remove all grammatical errors from this text: For example, countries with a lot of deser
SFT subset size: 6907


In [16]:
from trl import SFTTrainer, SFTConfig

trainer_phase1 = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds_subset,
    eval_dataset=sft_val_dataset,
    args=SFTConfig(
        output_dir="./phase1",
        dataset_text_field="text",
        per_device_train_batch_size=16,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        save_total_limit=12,
        load_best_model_at_end=True,
        greater_is_better=False,
        metric_for_best_model="eval_loss",
        warmup_steps=100,
        num_train_epochs=1,
        learning_rate=4e-5,
        logging_steps=25,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        report_to="none",
        eval_strategy="steps",
        eval_steps=25,
        save_strategy="steps",
        save_steps=25,
    )
)

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/6907 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/1712 [00:00<?, ? examples/s]

We also use Unsloth's `train_on_completions` method to only train on the assistant outputs and ignore the loss on the user's inputs. This helps increase accuracy of finetunes!

In [17]:
from unsloth.chat_templates import train_on_responses_only
trainer_phase1 = train_on_responses_only(
    trainer_phase1,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

Map (num_proc=12):   0%|          | 0/6907 [00:00<?, ? examples/s]

Map (num_proc=12):   0%|          | 0/1712 [00:00<?, ? examples/s]

In [18]:
print(trainer_phase1.train_dataset)

Dataset({
    features: ['_id', 'task', 'src', 'tgt', 'reasoning', 'text', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 6907
})


In [19]:
print(trainer_phase1.train_dataset[100]["text"])
print()
print(trainer_phase1.train_dataset[100]["input_ids"])
print()
print(trainer_phase1.train_dataset[100]["attention_mask"])
print()
print(trainer_phase1.train_dataset[100]["labels"])

<start_of_turn>user
You are an expert text editor. First, think step-by-step about the user's instruction and the source text. Place your reasoning inside <think> and </think> tags. Then, provide the final, edited text immediately after the closing tag. Your reasoning must follow a logical structure: Instruction analysis, Sentence analysis, Identify error(s), Apply correction(s), and Synthesized correction.

Rewrite this with simpler wording: The question made me feel as stupid as Ralph Roberts's observation that I needed a vacation.<end_of_turn>
<start_of_turn>model
<think>
1.  **Instruction Analysis:** The goal is to rewrite the source sentence using "simpler wording." This means identifying complex words, phrases, or grammatical structures and replacing them with more common, direct, or less formal alternatives while preserving the original meaning.
2.  **Source Sentence Analysis:** "The question made me feel as stupid as Ralph Roberts's observation that I needed a vacation."
    * 

In [20]:
total_training_tokens = sum(len(x) for x in trainer_phase1.train_dataset['input_ids'])
print("Total training tokens in CoEdIT dataset:", total_training_tokens)

Total training tokens in CoEdIT dataset: 5401275


In [21]:
import numpy as np

lengths = [len(x) for x in trainer_phase1.train_dataset['input_ids']]
print(f"Token lengths stats:")
print(f"Min: {np.min(lengths)}")
print(f"Max: {np.max(lengths)}")
print(f"Mean: {np.mean(lengths)}")
print(f"Median: {np.median(lengths)}")
print(f"95th percentile: {np.percentile(lengths, 95)}")
print(f"99th percentile: {np.percentile(lengths, 99)}")

Token lengths stats:
Min: 282
Max: 2048
Mean: 782.0001447806573
Median: 728.0
95th percentile: 1313.0
99th percentile: 2048.0


Let's verify masking the instruction part is done! Let's print the 100th row again.  Notice how the sample only has a single `<bos>` as expected!

In [22]:
tokenizer.decode(trainer_phase1.train_dataset[100]["input_ids"])

'<bos><start_of_turn>user\nYou are an expert text editor. First, think step-by-step about the user\'s instruction and the source text. Place your reasoning inside <think> and </think> tags. Then, provide the final, edited text immediately after the closing tag. Your reasoning must follow a logical structure: Instruction analysis, Sentence analysis, Identify error(s), Apply correction(s), and Synthesized correction.\n\nRewrite this with simpler wording: The question made me feel as stupid as Ralph Roberts\'s observation that I needed a vacation.<end_of_turn>\n<start_of_turn>model\n<think>\n1.  **Instruction Analysis:** The goal is to rewrite the source sentence using "simpler wording." This means identifying complex words, phrases, or grammatical structures and replacing them with more common, direct, or less formal alternatives while preserving the original meaning.\n2.  **Source Sentence Analysis:** "The question made me feel as stupid as Ralph Roberts\'s observation that I needed a v

Now let's print the masked out example - you should see only the answer is present:

In [23]:
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer_phase1.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

'                                                                                                                     <think>\n1.  **Instruction Analysis:** The goal is to rewrite the source sentence using "simpler wording." This means identifying complex words, phrases, or grammatical structures and replacing them with more common, direct, or less formal alternatives while preserving the original meaning.\n2.  **Source Sentence Analysis:** "The question made me feel as stupid as Ralph Roberts\'s observation that I needed a vacation."\n    *   The core meaning is a comparison: feeling stupid about a question, using a previous feeling of stupidity related to Ralph Roberts\'s comment about needing a vacation as the point of comparison.\n    *   Phrases to evaluate for simplification:\n        *   "made me feel as stupid as": This structure is okay, but could potentially be slightly more direct like "I felt as stupid as...". Adding *what* the feeling was about ("about that question") coul

In [24]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.
7.688 GB of memory reserved.


# Let's train the model!

To resume a training run, set `trainer.train(resume_from_checkpoint = True)`

In [25]:
print("--- starting phase 1 (primer): initial SFT ---")
trainer_stats = trainer_phase1.train()
#resume_from_checkpoint=True

--- starting phase 1 (primer): initial SFT ---


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 6,907 | Num Epochs = 1 | Total steps = 216
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 2 x 1) = 32
 "-____-"     Trainable parameters = 21,135,360 of 5,460,573,632 (0.39% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 6,907 | Num Epochs = 1 | Total steps = 247
O^O/ \_/ \    Batch size per device = 14 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (14 x 2 x 1) = 28
 "-____-"     Trainable parameters = 21,135,360 of 5,460,573,632 (0.39% trained)
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 6,907 | Num Epochs = 1 | Total steps = 288
O^O/ \_/ \    Batch size 

Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
25,2.7187,1.458365
50,1.6772,1.338516
75,1.3851,1.193472
100,1.2635,1.151698
125,1.2013,1.133579
150,1.1862,1.117951
175,1.1333,1.110629
200,1.1327,1.097295


Unsloth: Not an error, but Gemma3nForConditionalGeneration does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


Step,Training Loss,Validation Loss
25,2.7187,1.458365
50,1.6772,1.338516
75,1.3851,1.193472
100,1.2635,1.151698
125,1.2013,1.133579
150,1.1862,1.117951
175,1.1333,1.110629
200,1.1327,1.097295
225,1.1176,1.101213
250,1.0898,1.096799


In [54]:

print('phase one training complete')

q = "fix grammar: she go to store yesterday"
message = [{
    "role": "user",
    "content": [{"type": "text", "text": q}]
}]
print(message)

phase one training complete
[{'role': 'user', 'content': [{'type': 'text', 'text': 'fix grammar: she go to store yesterday'}]}]


In [55]:
inputs = tokenizer.apply_chat_template(
    message,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to("cuda")

from transformers import TextStreamer
_ = model.generate(
    **inputs,
    temperature=1.0,
    max_new_tokens=528,
    streamer=TextStreamer(tokenizer, skip_prompt=False),
)


<bos><start_of_turn>user
fix grammar: she go to store yesterday<end_of_turn>
<start_of_turn>model
The correct grammar for the sentence "she go to store yesterday" is **"She went to the store yesterday."**

Here's a breakdown of the corrections:

*   **"she go"** is missing the correct past tense form of the verb "go". The past tense is "went".
*   **"store"** is a singular countable noun. It requires the definite article "the" before it when referring to a specific store. The sentence should be "she went to *the* store yesterday."
*   **"yesterday"** is an adverb of time. It correctly modifies the verb "went".

**Therefore, the complete and grammatically correct sentence is: She went to the store yesterday.**<end_of_turn>


In [56]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.
38.982 GB of memory reserved.


In [57]:
# PHASE 2
print("preparing phase 2")
# Reset tokenizer to use the base gemma-3 template
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

grpo_dataset_full = train_ds.map(lambda x: format_grpo_dataset(x, system_prompt_grpo), remove_columns=list(train_ds.features))

subset_size = len(train_ds) // 10
grpo_train_subset = grpo_dataset_full.shuffle(seed=42).select(range(subset_size))
grpo_val_subset = val_ds.map(lambda x: format_grpo_dataset(x, system_prompt_grpo), remove_columns=list(val_ds.features)).shuffle(seed=42).select(range(len(val_ds) // 10))

print(f"Full GRPO dataset size: {len(grpo_dataset_full)}")
print(grpo_dataset_full[0])
print(f"GRPO subset size: {len(grpo_train_subset)}")

preparing phase 2


Map:   0%|          | 0/69071 [00:00<?, ? examples/s]

Map:   0%|          | 0/1712 [00:00<?, ? examples/s]

Full GRPO dataset size: 69071
{'prompt': [{'content': "You are an expert text editor. First, think step-by-step about the user's instruction and the source text. Place your reasoning inside <think> and </think> tags. Then, provide the final, edited text immediately after the closing tag.\n\nRemove all grammatical errors from this text: For example, countries with a lot of deserts can terraform their desert to increase their habitable land and using irrigation to provide clean water to the desert.", 'role': 'user'}], 'answer': 'For example, countries with a lot of deserts can transform their desert to increase their habitable land and use irrigation to provide clean water to the desert.'}
GRPO subset size: 6907


In [58]:
grpo_train_subset

Dataset({
    features: ['prompt', 'answer'],
    num_rows: 6907
})

In [59]:
from vllm import SamplingParams
vllm_sampling_params = SamplingParams(
    min_p=0.1,
    top_p=1.0,
    top_k=-1,
    temperature=1.0,
    stop=[tokenizer.eos_token],
    max_tokens=max_seq_length,
)


In [60]:

# # Reset generation config to avoid HybridCache issues
# model.generation_config.cache_implementation = None
# if hasattr(model.generation_config, 'use_cache'):
#     model.generation_config.use_cache = True

# # Also clear any cached states in the model config
# if hasattr(model.config, 'cache_implementation'):
#     model.config.cache_implementation = None



In [61]:
from trl import GRPOTrainer, GRPOConfig

trainer_phase2_grpo = GRPOTrainer(
    model=model,
    tokenizer=tokenizer,
    reward_funcs=reward_funcs,
    train_dataset=grpo_train_subset,
    args=GRPOConfig(
        output_dir='./phase2',
        vllm_sampling_params=vllm_sampling_params,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_generations=4,
        max_prompt_length=max_seq_length // 2,
        max_completion_length=max_seq_length // 2,
        max_steps=200,
        save_steps=50,
        logging_steps=5,
        learning_rate=4e-6,
        report_to="none",
    ),
    eval_dataset=grpo_val_subset,
)


Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 2 to the `num_generations` of 4


In [62]:
print("starting phase 2")
trainer_stats = trainer_phase2_grpo.train()

starting phase 2


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 6,907 | Num Epochs = 1 | Total steps = 200
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 8 x 1) = 32
 "-____-"     Trainable parameters = 21,135,360 of 5,460,573,632 (0.39% trained)
`generation_config` default values have been modified to match model-specific defaults: {'cache_implementation': 'hybrid', 'top_p': 0.95}. If this is not desired, please set these values explicitly.


ValueError: Max cache length is not consistent across layers: [512, 512, 512, 512, 1130, 512, 512, 512, 512, 1130, 512, 512, 512, 512, 1130, 512, 512, 512, 512, 1130, 512, 512, 512, 512, 1130, 512, 512, 512, 512, 1130]

In [None]:
# PHASE 3
from trl import SFTTrainer, SFTConfig

trainer_phase3 = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset_full,
    eval_dataset=sft_val_dataset,
    args=SFTConfig(
        output_dir="./phase3",
        dataset_text_field="text",
        per_device_train_batch_size=16,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        save_total_limit=12,
        load_best_model_at_end=True,
        greater_is_better=False,
        metric_for_best_model="eval_loss",
        warmup_steps=100,
        num_train_epochs=1,
        learning_rate=2e-5,
        logging_steps=25,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        report_to="none",
        eval_strategy="steps",
        eval_steps=25,
        save_strategy="steps",
        save_steps=25,
    )
)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

In [None]:
from transformers import EarlyStoppingCallback
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience = 5,     # How many steps we will wait if the eval loss doesn't decrease
                                     # For example the loss might increase, but decrease after 3 steps
    early_stopping_threshold = 0.01,  # Can set higher - sets how much loss should decrease by until
                                     # we consider early stopping. For eg 0.01 means if loss was
                                     # 0.02 then 0.01, we consider to early stop the run.
)
trainer_phase3.add_callback(early_stopping_callback)

In [None]:
print('starting phase 3')
phase3_stats = trainer_phase3.train()


<a name="Inference"></a>
### Inference

In [None]:
print(trainer_phase3.state.best_model_checkpoint)
print(trainer_phase3.state.best_metric)

In [None]:
messages = [{
    "role": "user",
    "content": [{"type" : "text", "text" : "Fix grammar in this sentence: hello there the angle from my nightmare the shadow in teh background of the morgue",}]
}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = True,
    return_tensors = "pt",
    return_dict = True,
).to("cuda")


from transformers import TextStreamer
_ = model.generate(
    **inputs,
    max_new_tokens = 3072,
    # Recommended Gemma-3 settings!
    temperature = 0.3, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

In [None]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)
messages = [{
    "role": "user",
    "content": [{
        "type" : "text",
        "text" : "Fix grammar in this sentence: If engineers do not come up with new ideas, they cannot find best solution for the problems.",
    }]
}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
    tokenize = True,
    return_dict = True,
    enable_thinking=True,
).to("cuda")
outputs = model.generate(
    **inputs,
    max_new_tokens = 3072, # Increase for longer outputs!
    # Recommended Gemma-3 settings!
    temperature = 1.0, top_p = 0.95, top_k = 64,
)
tokenizer.batch_decode(outputs)

 You can also use a `TextStreamer` for continuous inference - so you can see the generation token by token, instead of waiting the whole time!

In [None]:
messages = [{
    "role": "user",
    "content": [{"type" : "text", "text" : "Why is the sky blue?",}]
}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
    tokenize = True,
    return_dict = True,
).to("cuda")

from transformers import TextStreamer
_ = model.generate(
    **inputs,
    max_new_tokens = 64, # Increase for longer outputs!
    # Recommended Gemma-3 settings!
    temperature = 1.0, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, either use Huggingface's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [None]:
model.save_pretrained("gemma-3n")  # Local saving
tokenizer.save_pretrained("gemma-3n")
# model.push_to_hub("HF_ACCOUNT/gemma-3", token = "...") # Online saving
# tokenizer.push_to_hub("HF_ACCOUNT/gemma-3", token = "...") # Online saving

Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:

In [None]:
if False:
    from unsloth import FastModel
    model, tokenizer = FastModel.from_pretrained(
        model_name = "lora_model", # YOUR MODEL YOU USED FOR TRAINING
        max_seq_length = 2048,
        load_in_4bit = True,
    )

messages = [{
    "role": "user",
    "content": [{"type" : "text", "text" : "What is Gemma-3N?",}]
}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
    tokenize = True,
    return_dict = True,
).to("cuda")

from transformers import TextStreamer
_ = model.generate(
    **inputs,
    max_new_tokens = 128, # Increase for longer outputs!
    # Recommended Gemma-3 settings!
    temperature = 1.0, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

### Saving to float16 for VLLM

We also support saving to `float16` directly for deployment! We save it in the folder `gemma-3N-finetune`. Set `if False` to `if True` to let it run!

In [None]:
if False: # Change to True to save finetune!
    model.save_pretrained_merged("gemma-3N-finetune", tokenizer)

If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!

In [None]:
if False: # Change to True to upload finetune
    model.push_to_hub_merged(
        "HF_ACCOUNT/gemma-3N-finetune", tokenizer,
        token = "hf_..."
    )

### GGUF / llama.cpp Conversion
To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!

In [None]:
if False: # Change to True to save to GGUF
    model.save_pretrained_gguf(
        "gemma-3N-finetune",
        quantization_type = "Q8_0", # For now only Q8_0, BF16, F16 supported
    )

Likewise, if you want to instead push to GGUF to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!

In [None]:
if False: # Change to True to upload GGUF
    model.push_to_hub_gguf(
        "gemma-3N-finetune",
        quantization_type = "Q8_0", # Only Q8_0, BF16, F16 supported
        repo_id = "HF_ACCOUNT/gemma-3N-finetune-gguf",
        token = "hf_...",
    )