# Gemma-3-1B fine-tuning with SFT (Supervised Fine-Tuning)

In this notebook, we will see how to fine-tune a Gemma-3-1B model with LoRA adapters on a synthetic reasoning dataset.

SFT is provided by TRL (Transformer Reinforcement Learning), a Python library from Hugging Face designed to facilitate the training of transformer language models and diffusion models using reinforcement learning (RL).

Prerequisite: create a Hugging Face token with permission to access `google/gemma-3-1b-it`.
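To avoid pasting the token into the notebook, it can be read from an environment variable instead. The helper below is a minimal sketch: `read_hf_token` is a hypothetical name, and it assumes the token was exported as `HF_TOKEN` before the notebook was started.

```python
import os


def read_hf_token(env_var: str = "HF_TOKEN") -> str:
    """Return the Hugging Face token stored in an environment variable.

    `read_hf_token` is a hypothetical helper for this notebook, not part of
    huggingface_hub.
    """
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(
            f"Set the {env_var} environment variable to your Hugging Face token."
        )
    return token
```

With this in place, `my_token = read_hf_token()` could replace the hard-coded string before calling `login(token=my_token)`.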

### Download Gemma-3-1B from HuggingFace and set up tokenizer

In [1]:
import os
from huggingface_hub import login
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, GemmaTokenizer
from transformers.models.gemma3 import Gemma3ForCausalLM

my_token = "hf_xxxxxxxxxx"

login(token=my_token)

model_id = 'google/gemma-3-1b-it'
tokenizer = AutoTokenizer.from_pretrained(model_id, token=my_token)
model = Gemma3ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", token=my_token, attn_implementation='eager')
# Set up the chat format
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"



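Now we test the model before fine-tuning with a simple prompt ("What is the primary function of mitochondria within a cell?"). In the sample output, the generated text starts by repeating the user's question. Note that the `text-generation` pipeline returns the prompt together with the continuation by default, so the repeated question is expected here.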

In [2]:
from transformers import pipeline
torch.set_float32_matmul_precision('high')

# Let's test the base model before training
prompt = "What is the primary function of mitochondria within a cell?"
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe(prompt, max_new_tokens=100)

Device set to use cuda:0


[{'generated_text': 'What is the primary function of mitochondria within a cell?\n\n**Answer:** The primary function of mitochondria is to generate energy in the form of ATP (adenosine triphosphate) through cellular respiration.\n\nHere\'s why this is the best answer:\n\n* **Mitochondria and Energy Production:** Mitochondria are often called the "powerhouses" of the cell because they are responsible for producing the energy needed to fuel cellular activities.\n* **ATP Synthesis:** This process of energy production is crucial for all cellular functions, including muscle contraction, nerve impulse transmission'}]
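The `generated_text` field above begins with the prompt itself. If only the continuation is of interest, the prompt prefix can be removed afterwards. The snippet below is a small sketch with a hypothetical helper name (`strip_prompt`) and a shortened copy of the output; passing `return_full_text=False` to the pipeline call should achieve a similar effect, although that is not what this notebook does.

```python
def strip_prompt(prompt: str, generated_text: str) -> str:
    """Return only the continuation when the output echoes the prompt as a prefix."""
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].lstrip()
    # The prompt was not echoed, so there is nothing to remove.
    return generated_text


prompt = "What is the primary function of mitochondria within a cell?"
# Shortened stand-in for the pipeline output shown above.
full_text = prompt + "\n\n**Answer:** The primary function of mitochondria is to generate energy"

continuation = strip_prompt(prompt, full_text)
print(continuation)
```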

Set up LoRA configurations, datasets and SFT training procedure.

In [3]:
os.environ["WANDB_DISABLED"] = "true"

from peft import LoraConfig, PeftModel

lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
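
To get a feel for how small the LoRA update is, we can count parameters by hand. For a linear layer with `in_features` inputs and `out_features` outputs, a rank-`r` adapter adds two matrices of shape `r x in_features` and `out_features x r`. The sketch below uses hypothetical layer shapes chosen for illustration only; they are not read from the Gemma-3-1B configuration.

```python
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """Number of parameters a rank-r LoRA adapter adds to one linear layer."""
    return r * (in_features + out_features)


# Hypothetical (in_features, out_features) shapes for one transformer block.
# These are placeholders, NOT the real Gemma-3-1B dimensions.
layer_shapes = {
    "q_proj": (1024, 1024),
    "k_proj": (1024, 256),
    "v_proj": (1024, 256),
    "o_proj": (1024, 1024),
    "gate_proj": (1024, 4096),
    "up_proj": (1024, 4096),
    "down_proj": (4096, 1024),
}

r = 16  # same rank as in the LoraConfig above

lora_params = sum(lora_param_count(i, o, r) for i, o in layer_shapes.values())
full_params = sum(i * o for i, o in layer_shapes.values())

print(f"LoRA parameters per block:   {lora_params}")
print(f"Frozen parameters per block: {full_params}")
print(f"Trainable fraction:          {lora_params / full_params:.2%}")
```

With the real model, `model.print_trainable_parameters()` on the PEFT-wrapped model reports the actual numbers.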

Download the SFT reasoning dataset from Hugging Face (`argilla/synthetic-concise-reasoning-sft-filtered`).

In [4]:
from datasets import load_dataset

ds = load_dataset("argilla/synthetic-concise-reasoning-sft-filtered")
def tokenize_function(examples):
    # Despite its name, this function does not tokenize: it renders each
    # prompt/completion pair of the batch into one chat-formatted string.
    prompts = examples["prompt"]
    completions = examples["completion"]
    texts = []
    for prompt, completion in zip(prompts, completions):
        messages = [
            {"role": "user", "content": prompt.strip()},
            {"role": "assistant", "content": completion.strip()},
        ]
        texts.append(tokenizer.apply_chat_template(messages, tokenize=False))
    return {"text": texts}  # One formatted string per example in the batch

ds = ds.map(tokenize_function, batched=True)
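
To make the resulting `text` column more concrete, here is a pure-Python sketch that mirrors the chat template set in the first cell and applies it to a batch shaped like the one `map(..., batched=True)` passes in. The names `render_gemma_chat` and `format_batch` are hypothetical, the sample question is made up, and `<bos>` is assumed to be the tokenizer's `bos_token`.

```python
def render_gemma_chat(messages, add_generation_prompt=False, bos_token="<bos>"):
    """Render messages the way the Jinja chat template in the first cell does."""
    if messages and messages[0]["role"] == "system":
        raise ValueError("System role not supported")
    text = bos_token
    for index, message in enumerate(messages):
        # Turns must alternate user/assistant, starting with the user.
        if (message["role"] == "user") != (index % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant")
        role = "model" if message["role"] == "assistant" else message["role"]
        text += f"<start_of_turn>{role}\n{message['content'].strip()}<end_of_turn>\n"
    if add_generation_prompt:
        text += "<start_of_turn>model\n"
    return text


def format_batch(examples):
    """Take a dict of lists (one batch) and return a dict with a list of texts."""
    texts = [
        render_gemma_chat(
            [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": completion},
            ]
        )
        for prompt, completion in zip(examples["prompt"], examples["completion"])
    ]
    return {"text": texts}


# Made-up batch with a single example.
batch = {"prompt": ["What is 2 + 2? "], "completion": [" 4"]}
formatted = format_batch(batch)
print(formatted["text"][0])
```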

Start the fine-tuning with 150 training steps, which takes about 3 minutes on a single A100 (2m 23.5s on my RTX 4060). Alternatively, set `num_train_epochs=1` (and remove `max_steps`, which otherwise takes precedence) to train on the entire SFT dataset; this will take considerably longer.

In [5]:
import transformers
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset = ds['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=150,
        #num_train_epochs=1,
        # Copied from other hugging face tuning blog posts
        learning_rate=2e-4,
        #fp16=True,
        bf16=True,
        # It makes training faster
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        report_to = "none",
    ),
    peft_config=lora_config,
)
trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


| Step | Training Loss |
|------|---------------|
| 1 | 2.0102 |
| 2 | 1.9447 |
| 3 | 1.9797 |
| 4 | 1.8379 |
| 5 | 1.8184 |
| 6 | 1.6584 |
| 7 | 1.7049 |
| 8 | 1.5349 |
| 9 | 1.6183 |
| 10 | 1.8128 |




TrainOutput(global_step=150, training_loss=1.4332614644368489, metrics={'train_runtime': 142.2775, 'train_samples_per_second': 4.217, 'train_steps_per_second': 1.054, 'total_flos': 628927575183360.0, 'train_loss': 1.4332614644368489})
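
The throughput figures in `TrainOutput` can be cross-checked against the training arguments. The arithmetic below is a quick sanity check, assuming a single device so that the effective batch size is simply the per-device batch size times the gradient accumulation steps.

```python
# Values copied from the training arguments and the TrainOutput above.
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
max_steps = 150
train_runtime_s = 142.2775

# Assumption: one device, so no extra multiplication by the device count.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
samples_seen = effective_batch_size * max_steps

samples_per_second = samples_seen / train_runtime_s
steps_per_second = max_steps / train_runtime_s

print(f"effective batch size: {effective_batch_size}")
print(f"samples seen:         {samples_seen}")
print(f"samples per second:   {samples_per_second:.3f}")
print(f"steps per second:     {steps_per_second:.3f}")
```

The computed rates match the reported `train_samples_per_second` (4.217) and `train_steps_per_second` (1.054), so the 150 steps covered 600 training examples.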

Now let's save the fine-tuned LoRA adapter weights. They will be saved in a folder named "gemma3-1b-sft". Later, we will run inference on the fine-tuned model to make sure it can perform question answering.

In [6]:
trainer.save_model("gemma3-1b-sft")

Next, we merge the LoRA weights into the base model. The saved checkpoint can then be imported with ai-edge-torch to create a LiteRT model for on-device inference. Merged weights will be saved in a folder named "merged_model".

In [7]:
from peft import AutoPeftModelForCausalLM
import torch

# Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained("gemma3-1b-sft")
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
# Resize vocab size to match with base model vocabulary table.
merged_model.resize_token_embeddings(262144)
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
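
Conceptually, merging folds each low-rank update into the corresponding base weight, so that the adapter matrices are no longer needed at inference time: `W_merged = W + (lora_alpha / r) * B @ A`. The toy example below is a sketch of that formula in plain Python with made-up numbers; it is not how PEFT implements `merge_and_unload()` internally, and `matmul` and `merge_lora` are hypothetical helpers.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, lora_alpha, r):
    """Return W + (lora_alpha / r) * B @ A for one linear layer."""
    scale = lora_alpha / r
    delta = matmul(lora_b, lora_a)  # shape: out_features x in_features
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Made-up toy layer: 3 inputs, 2 outputs, rank-1 adapter.
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # out_features x in_features
A = [[1.0, 2.0, 3.0]]                   # r x in_features
B = [[0.5], [-1.0]]                     # out_features x r

merged = merge_lora(W, A, B, lora_alpha=2, r=1)

# The merged layer gives the same output as base layer plus adapter.
x = [[1.0], [1.0], [1.0]]
base_plus_adapter = [
    [wx[0] + 2 * bax[0]] for wx, bax in zip(matmul(W, x), matmul(B, matmul(A, x)))
]
print(merged)
print(matmul(merged, x) == base_plus_adapter)
```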


Let's run inference again on the merged model to ensure it works as expected.

In [8]:
from transformers import pipeline

prompt = "What is the primary function of mitochondria within a cell?"
pipe = pipeline("text-generation", model=merged_model, tokenizer=tokenizer)
pipe(prompt, max_new_tokens=100)

Device set to use cuda:0


[{'generated_text': "What is the primary function of mitochondria within a cell?Mitochondria primarily function as the cell's powerhouses, producing ATP (adenosine triphosphate) through cellular respiration. They convert nutrients into energy in the form of ATP, which is then used by the cell for various metabolic processes.\n\nThe primary function of mitochondria within a cell is to generate energy through cellular respiration. This process converts nutrients into ATP, which is used by the cell for various metabolic processes.\nThe main function of mitochondria is to produce ATP (adenosine triphosphate), which is used"}]