<a target="_blank" href="https://colab.research.google.com/github/google-ai-edge/mediapipe-samples/blob/main/codelabs/litert_inference/Gemma3_1b_fine_tune.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Gemma-3-1B fine-tuning with SFT and on-device deployment with AI Edge Torch and MediaPipe

In this colab, we will show you how to fine-tune a Gemma-3-1B model on a synthetic reasoning dataset using LoRA adapters, and then convert the fine-tuned model to LiteRT format. Lastly, we will load the LiteRT model and run inference in the Colab environment.

(Note: to run this colab smoothly you will need a Colab Pro subscription, which gives you GPU and high-RAM access.)

# Prerequisites

- Create a HuggingFace token with read access to
  - google/gemma-3-1b-pt
  - google/gemma-3-1b-it

  This is needed to download the model checkpoint and tokenizer.

- Open Colab Secrets: In your Google Colab notebook, locate the Secrets icon in the left-hand sidebar and click on it.
- Add a new secret: Click the "Add Secret" button.
- Name your secret: Enter "HF_TOKEN" for your token in the "Name" field.
- Paste your token: In the "Value" field, paste the actual token you want to store.

Note: When running notebooks in this repository with Google Colab, some users may see
the following warning message:

![Colab warning](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/data/colab_warning.jpg?raw=true)

Please click `Restart Session` and run again.

# Install dependencies

In [None]:
!pip3 install --upgrade -q -U bitsandbytes==0.46.0
!pip3 install --upgrade -q -U peft==0.15.2
!pip3 install --upgrade -q -U trl==0.18.1
!pip3 install --upgrade -q -U accelerate==1.7.0
!pip3 install --upgrade -q -U datasets==3.6.0
!pip3 install --upgrade -q -U numpy==2.2.6
!pip3 install --force-reinstall transformers==4.52.3

In [None]:
! pip3 install ai-edge-torch-nightly==0.6.0.dev20250605
! pip3 install ai-edge-litert==1.3.0
! pip3 install mediapipe==0.10.21

In [None]:
import os
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

# Download Gemma-3-1B from HuggingFace and set up tokenizer

In [None]:
import os

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, GemmaTokenizer
from transformers.models.gemma3 import Gemma3ForCausalLM

model_id = 'google/gemma-3-1b-pt'
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = Gemma3ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", token=os.environ['HF_TOKEN'], attn_implementation='eager')
# Set up the chat format
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"


Now let's try a simple prompt ("What is the primary function of mitochondria within a cell?"). From the sample output we can see that the base model mostly repeats the user's question, which is expected before the fine-tuning step.

In [None]:
import torch

from transformers import pipeline

# Let's test the base model before training
prompt = "What is the primary function of mitochondria within a cell?"
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe(prompt, max_new_tokens=100)

# Set up LoRA configurations, datasets and SFT training procedure



In [None]:
os.environ["WANDB_DISABLED"] = "true"

from peft import LoraConfig, PeftModel

lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
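
To get a feel for what `r=16` means, the sketch below counts the extra trainable parameters that a rank-`r` LoRA adapter adds to one linear layer: two small matrices with `r * (d_in + d_out)` entries in total. The helper name `lora_extra_params` and the layer sizes are illustrative placeholders, not the actual Gemma-3-1B dimensions.

```python
def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added by one rank-r adapter (A: r x d_in, B: d_out x r)."""
    return r * (d_in + d_out)


# Illustrative layer size (hypothetical, not read from the Gemma config).
d_in, d_out, r = 1024, 4096, 16
full = d_in * d_out
extra = lora_extra_params(d_in, d_out, r)
print(f"full layer: {full} params, LoRA adapter: {extra} params "
      f"({100 * extra / full:.2f}% of the layer)")
```

Because the adapter size grows linearly with `r`, doubling the rank doubles the number of trainable parameters per targeted module.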

Download the SFT reasoning dataset from HuggingFace (argilla/synthetic-concise-reasoning-sft-filtered).

In [None]:
from datasets import load_dataset

ds = load_dataset("argilla/synthetic-concise-reasoning-sft-filtered")
def tokenize_function(examples):
    # Process all examples in the batch
    prompts = examples["prompt"]
    completions = examples["completion"]
    texts = []
    for prompt, completion in zip(prompts, completions):
        text = tokenizer.apply_chat_template([{"role": "user", "content": prompt.strip()}, {"role": "assistant", "content": completion.strip()}], tokenize=False)
        texts.append(text)
    return { "text" : texts }  # Return a list of texts

ds = ds.map(tokenize_function, batched = True)
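
To see what the `text` column looks like after this mapping, here is a minimal pure-Python imitation of the chat template that was set on the tokenizer. It is only a sketch: the function name `render_gemma_chat` is hypothetical, and it assumes the tokenizer's BOS token is the string `<bos>`.

```python
def render_gemma_chat(messages, bos_token="<bos>", add_generation_prompt=False):
    """Pure-Python imitation of the Jinja chat template used in this notebook."""
    if messages and messages[0]["role"] == "system":
        raise ValueError("System role not supported")
    parts = [bos_token]
    for i, message in enumerate(messages):
        # Even positions must be user turns, odd positions must not be.
        if (message["role"] == "user") != (i % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant")
        # The template renames the 'assistant' role to 'model'.
        role = "model" if message["role"] == "assistant" else message["role"]
        parts.append(
            f"<start_of_turn>{role}\n{message['content'].strip()}<end_of_turn>\n"
        )
    if add_generation_prompt:
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


text = render_gemma_chat([
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4."},
])
print(repr(text))
```

Each training example therefore contains one user turn and one model turn, both closed with `<end_of_turn>`.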

Start the fine-tuning with 150 training steps (which will take ~3 minutes on a single A100). Alternatively, you can remove `max_steps` (which otherwise takes precedence) and set `num_train_epochs=1` to train on the entire SFT dataset, which will lead to longer training times.

In [None]:
import transformers
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset = ds['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=150,
        #num_train_epochs=1,
        # Copied from other hugging face tuning blog posts
        learning_rate=2e-4,
        #fp16=True,
        bf16=True,
        # It makes training faster
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        report_to = "none",
    ),
    peft_config=lora_config,
)
trainer.train()
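
As a quick sanity check on the training arguments above, the arithmetic below estimates how many sequences the run actually sees. It assumes a single GPU, which is the usual Colab setup; `num_devices` is not a value taken from the notebook.

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
max_steps = 150
num_devices = 1  # assumption: a single Colab GPU

# One optimizer step consumes this many sequences.
sequences_per_step = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
# Total sequences consumed over the whole run.
sequences_seen = sequences_per_step * max_steps
print(sequences_per_step, sequences_seen)
```

With these settings the run covers a few hundred examples, which is why it finishes in minutes rather than the time a full epoch would take.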

Now, let's save the trainer weights, and run a few inference steps on the fine-tuned model to make sure it can perform question answering. Weights will be saved in a folder named "gemma3-1b-sft".

In [None]:
trainer.save_model("gemma3-1b-sft")

In [None]:
from transformers import pipeline
# Let's test the fine-tuned model after training
prompt = "What is the primary function of mitochondria within a cell?"
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe(prompt, max_new_tokens=100)

Next, we can merge the LoRA weights into the base model. The saved checkpoint will then be imported with ai-edge-torch to create a LiteRT model for on-device inference. Merged weights will be saved in a folder named "merged_model".

In [None]:
from peft import AutoPeftModelForCausalLM
import torch

# Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained("gemma3-1b-sft")
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
# Resize vocab size to match with base model vocabulary table.
merged_model.resize_token_embeddings(262144)
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
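
Conceptually, merging folds the low-rank update into the base weight, `W' = W + scale * (B @ A)`, so the merged layer gives the same outputs without a separate adapter path. The toy sketch below illustrates this with plain Python lists. The helper names and the `scale` value are hypothetical; in PEFT the scale is derived from the LoRA configuration.

```python
def matmul(a, b):
    """Multiplies two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_b, lora_a, scale):
    """Returns W + scale * (B @ A) for a weight W of shape (d_out, d_in)."""
    delta = matmul(lora_b, lora_a)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]


# Toy example: d_out = 2, d_in = 3, rank r = 1.
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]
B = [[1.0],
     [2.0]]            # shape (d_out, r)
A = [[0.5, 0.0, 1.0]]  # shape (r, d_in)
scale = 0.5

merged = merge_lora(W, B, A, scale)
x = [[1.0], [2.0], [3.0]]  # column-vector input

# Adapter kept separate: W x + scale * B (A x).
separate = [[wx[0] + scale * bax[0]]
            for wx, bax in zip(matmul(W, x), matmul(B, matmul(A, x)))]
assert matmul(merged, x) == separate
```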

Let's run inference again on the merged model to ensure it works as expected.

In [None]:
from transformers import pipeline

prompt = "What is the primary function of mitochondria within a cell?"
pipe = pipeline("text-generation", model=merged_model, tokenizer=tokenizer)
pipe(prompt, max_new_tokens=100)

# Load up the checkpoint in AI Edge Torch and convert to LiteRT

Now let's convert our model (including 4-bit block-wise quantization) to LiteRT format. This will take roughly 10+ minutes to finish. The output tflite will be saved in the "/content" directory, with the name "gemma3_1b_finetune_q4_block32_ekv1024.tflite".
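
The conversion cell uses `DYNAMIC_INT4_BLOCK32`, that is, 4-bit weights quantized in blocks of 32 values. As rough intuition only, and not the converter's actual implementation, the sketch below shows a generic symmetric block-wise scheme: each block gets its own scale, and values are rounded to integers in [-8, 7]. The function names and the choice of a symmetric scheme are assumptions made for illustration.

```python
def quantize_block(values, num_bits=4):
    """Symmetric quantization of one block: returns (integer codes, scale)."""
    qmax = 2 ** (num_bits - 1) - 1   # 7 for int4
    qmin = -(2 ** (num_bits - 1))    # -8 for int4
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    codes = [min(qmax, max(qmin, round(v / scale))) for v in values]
    return codes, scale


def quantize_blockwise(weights, block_size=32):
    """Quantizes a flat list of weights block by block."""
    return [quantize_block(weights[i:i + block_size])
            for i in range(0, len(weights), block_size)]


def dequantize_blockwise(blocks):
    """Reconstructs approximate float weights from (codes, scale) pairs."""
    return [code * scale for codes, scale in blocks for code in codes]


weights = [((i * 37) % 101 - 50) / 50.0 for i in range(64)]  # toy data
blocks = quantize_blockwise(weights)
restored = dequantize_blockwise(blocks)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(len(blocks), max_error)
```

Smaller blocks mean more scales to store but a tighter fit to the local weight range, which is the usual trade-off behind block-wise quantization.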

In [None]:
import torch

from ai_edge_torch.generative.examples.gemma3 import gemma3
from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.export_config import ExportConfig


PREFILL_SEQ_LENS = [128]
KV_CACHE_MAX_LEN = 1024

def _create_mask(mask_len, kv_cache_max_len):
  mask = torch.full(
      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
  )
  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
  return mask


def _create_export_config(
    prefill_seq_lens: list[int], kv_cache_max_len: int
) -> ExportConfig:
  """Creates the export config for the model."""
  export_config = ExportConfig()
  if isinstance(prefill_seq_lens, list):
    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
  else:
    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)

  export_config.prefill_mask = prefill_mask

  decode_mask = torch.full(
      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
  )
  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
  export_config.decode_mask = decode_mask
  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
  export_config.mask_as_input = True
  return export_config


def convert_to_litert():
  with torch.inference_mode(True):
    pytorch_model = gemma3.build_model_1b(
      "/content/merged_model", mask_cache_size=KV_CACHE_MAX_LEN,
    )
    converter.convert_to_tflite(
        pytorch_model,
        output_path="/content/",
        output_name_prefix="gemma3_1b_finetune",
        prefill_seq_len=PREFILL_SEQ_LENS,
        kv_cache_max_len=KV_CACHE_MAX_LEN,
        quantize=converter.QuantizationName.DYNAMIC_INT4_BLOCK32,
        lora_ranks=None,
        export_config=_create_export_config(
            prefill_seq_lens=PREFILL_SEQ_LENS,
            kv_cache_max_len=KV_CACHE_MAX_LEN,
        ),
    )

# Run model conversion.
convert_to_litert()
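
The masks built in `_create_export_config` are additive attention masks: 0 where a query position may attend to a cache slot and `-inf` elsewhere. Here is a minimal pure-Python sketch of the same upper-triangular layout, using small hypothetical sizes so the result can be printed. The helper `causal_mask` is not part of the notebook.

```python
NEG_INF = float("-inf")


def causal_mask(num_queries: int, cache_len: int, k: int = 1):
    """Entry [i][j] is 0 when j - i < k (visible) and -inf otherwise."""
    return [[0.0 if j - i < k else NEG_INF for j in range(cache_len)]
            for i in range(num_queries)]


# Prefill: query i may attend to cache slots 0..i.
prefill = causal_mask(num_queries=3, cache_len=5)
# Decode at position 2: the single query may attend to slots 0..2.
decode = causal_mask(num_queries=1, cache_len=5, k=3)

for row in prefill:
    print(row)
print(decode[0])
```

The `k` argument plays the same role as the `diagonal` argument of `torch.triu`: everything on or above the `k`-th diagonal stays `-inf`.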

In [None]:
from ai_edge_litert import interpreter as interpreter_lib
from transformers import AutoTokenizer
import numpy as np
from collections.abc import Sequence
import sys

In [None]:
from transformers import AutoTokenizer

model_id = 'google/gemma-3-1b-pt'
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.chat_template = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"

In [None]:
interpreter = interpreter_lib.InterpreterWithCustomOps(
    custom_op_registerers=["pywrap_genai_ops.GenAIOpsRegisterer"],
    model_path="/content/gemma3_1b_finetune_q4_block32_ekv1024.tflite",
    num_threads=2,
    experimental_default_delegate_latest_features=True)

# Create pipeline with LiteRT models

In [None]:
def _get_mask(shape: Sequence[int], k: int):
  """Gets the mask for the input to the model.

  Args:
    shape: The shape of the mask input to the model.
    k: all elements below the k-th diagonal are set to 0.

  Returns:
    The mask for the input to the model. All the elements in the mask are set
    to -inf except that all the elements below the k-th diagonal are set to 0.
  """
  mask = np.ones(shape, dtype=np.float32) * float("-inf")
  mask = np.triu(mask, k=k)
  return mask

class LiteRTLlmPipeline:

  def __init__(self, interpreter, tokenizer):
    """Initializes the pipeline."""
    self._interpreter = interpreter
    self._tokenizer = tokenizer

    self._prefill_runner = None
    self._decode_runner = self._interpreter.get_signature_runner("decode")


  def _init_prefill_runner(self, num_input_tokens: int):
    """Initializes all the variables related to the prefill runner.

    This method initializes the following variables:
      - self._prefill_runner: The prefill runner based on the input size.
      - self._max_seq_len: The maximum sequence length supported by the model.

    Args:
      num_input_tokens: The number of input tokens.
    """
    if not self._interpreter:
      raise ValueError("Interpreter is not initialized.")

    # Prefill runner related variables are initialized here, on each call to
    # `generate`.
    self._prefill_runner = self._get_prefill_runner(num_input_tokens)
    # input_token_shape has shape (batch, max_seq_len)
    input_token_shape = self._prefill_runner.get_input_details()["tokens"][
        "shape"
    ]
    if len(input_token_shape) == 1:
      self._max_seq_len = input_token_shape[0]
    else:
      self._max_seq_len = input_token_shape[1]

    # kv cache input has shape [batch=1, num_kv_heads, cache_size, head_dim].
    kv_cache_shape = self._prefill_runner.get_input_details()["kv_cache_k_0"][
        "shape"
    ]
    self._max_kv_cache_seq_len = kv_cache_shape[2]

  def _init_kv_cache(self) -> dict[str, np.ndarray]:
    if self._prefill_runner is None:
      raise ValueError("Prefill runner is not initialized.")
    kv_cache = {}
    for input_key in self._prefill_runner.get_input_details().keys():
      if "kv_cache" in input_key:
        kv_cache[input_key] = np.zeros(
            self._prefill_runner.get_input_details()[input_key]["shape"],
            dtype=np.float32,
        )
    return kv_cache

  def _get_prefill_runner(self, num_input_tokens: int):
    """Gets the prefill runner whose input size best fits the prompt.

    Args:
      num_input_tokens: The number of input tokens.

    Returns:
      The prefill runner with the smallest input size that can hold
      `num_input_tokens` tokens.
    """
    best_signature = None
    delta = sys.maxsize
    max_prefill_len = -1
    for key in self._interpreter.get_signature_list().keys():
      if "prefill" not in key:
        continue
      input_pos = self._interpreter.get_signature_runner(key).get_input_details()[
          "input_pos"
      ]
      # input_pos["shape"] has shape (max_seq_len, )
      seq_size = input_pos["shape"][0]
      max_prefill_len = max(max_prefill_len, seq_size)
      if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:
        delta = seq_size - num_input_tokens
        best_signature = key
    if best_signature is None:
      raise ValueError(
          "The largest prefill length supported is %d, but we have %d number of input tokens"
          %(max_prefill_len, num_input_tokens)
      )
    return self._interpreter.get_signature_runner(best_signature)

  def _run_prefill(
      self, prefill_token_ids: Sequence[int],
  ) -> dict[str, np.ndarray]:
    """Runs prefill and returns the kv cache.

    Args:
      prefill_token_ids: The token ids of the prefill input.

    Returns:
      The updated kv cache.
    """
    if not self._prefill_runner:
      raise ValueError("Prefill runner is not initialized.")
    prefill_token_length = len(prefill_token_ids)
    if prefill_token_length == 0:
      return self._init_kv_cache()

    # Prepare the input to be [1, max_seq_len].
    input_token_ids = [0] * self._max_seq_len
    input_token_ids[:prefill_token_length] = prefill_token_ids
    input_token_ids = np.asarray(input_token_ids, dtype=np.int32)
    input_token_ids = np.expand_dims(input_token_ids, axis=0)

    # Prepare the input position to be [max_seq_len].
    input_pos = [0] * self._max_seq_len
    input_pos[:prefill_token_length] = range(prefill_token_length)
    input_pos = np.asarray(input_pos, dtype=np.int32)

    # Initialize kv cache.
    prefill_inputs = self._init_kv_cache()
    # Prepare the tokens and input position inputs.
    prefill_inputs.update({
        "tokens": input_token_ids,
        "input_pos": input_pos,
    })
    if "mask" in self._prefill_runner.get_input_details().keys():
      # For prefill, mask has shape [batch=1, 1, seq_len, kv_cache_size].
      # We want mask[0, 0, i, j] = 0 for j<=i and -inf otherwise.
      prefill_inputs["mask"] = _get_mask(
          shape=self._prefill_runner.get_input_details()["mask"]["shape"],
          k=1,
      )
    prefill_outputs = self._prefill_runner(**prefill_inputs)
    if "logits" in prefill_outputs:
      # Prefill outputs includes logits and kv cache. We only output kv cache.
      prefill_outputs.pop("logits")

    return prefill_outputs

  def _greedy_sampler(self, logits: np.ndarray) -> int:
    return int(np.argmax(logits))


  def _run_decode(
      self,
      start_pos: int,
      start_token_id: int,
      kv_cache: dict[str, np.ndarray],
      max_decode_steps: int,
  ) -> str:
    """Runs greedy decode, printing each token as it is generated.

    Args:
      start_pos: The position of the first token of the decode input.
      start_token_id: The token id of the first token of the decode input.
      kv_cache: The kv cache from the prefill.
      max_decode_steps: The max decode steps.

    Returns:
      The decoded text produced by the greedy sampler.
    """
    next_pos = start_pos
    next_token = start_token_id
    decode_text = []
    decode_inputs = kv_cache

    for _ in range(max_decode_steps):
      decode_inputs.update({
          "tokens": np.array([[next_token]], dtype=np.int32),
          "input_pos": np.array([next_pos], dtype=np.int32),
      })
      if "mask" in self._decode_runner.get_input_details().keys():
        # For decode, mask has shape [batch=1, 1, 1, kv_cache_size].
        # We want mask[0, 0, 0, j] = 0 for j<=next_pos and -inf otherwise.
        decode_inputs["mask"] = _get_mask(
            shape=self._decode_runner.get_input_details()["mask"]["shape"],
            k=next_pos + 1,
        )
      decode_outputs = self._decode_runner(**decode_inputs)
      # Output logits has shape (batch=1, 1, vocab_size). We only take the first
      # element.
      logits = decode_outputs.pop("logits")[0][0]
      next_token = self._greedy_sampler(logits)
      if next_token == self._tokenizer.eos_token_id:
        break
      decode_text.append(self._tokenizer.decode(next_token, skip_special_tokens=True))
      if len(decode_text[-1]) == 0:
        # Break out the loop if we hit the special token.
        break

      print(decode_text[-1], end='', flush=True)
      # Decode outputs includes logits and kv cache. We already popped out
      # logits, so the rest is kv cache. We pass the updated kv cache as input
      # to the next decode step.
      decode_inputs = decode_outputs
      next_pos += 1

    print() # print a new line at the end.
    return ''.join(decode_text)

  def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:
    messages=[{ 'role': 'user', 'content': prompt}]
    token_ids = self._tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
    # Initialize the prefill runner with the suitable input size.
    self._init_prefill_runner(len(token_ids))

    # Run prefill.
    # Prefill up to the second to last token of the prompt, because the last
    # token of the prompt will be used to bootstrap decode.
    prefill_token_length = len(token_ids) - 1

    print('Running prefill')
    kv_cache = self._run_prefill(token_ids[:prefill_token_length])
    # Run decode.
    print('Running decode')
    actual_max_decode_steps = self._max_kv_cache_seq_len - prefill_token_length - 1
    if max_decode_steps is not None:
      actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)
    decode_text = self._run_decode(
        prefill_token_length,
        token_ids[prefill_token_length],
        kv_cache,
        actual_max_decode_steps,
    )
    return decode_text
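
The prefill step above works with fixed-size signatures: the pipeline picks the smallest exported prefill length that fits the prompt and pads the token ids up to that length. A simplified sketch of that selection, independent of the interpreter, might look like this. The helper names and the list of available lengths are hypothetical; this notebook exports only a single prefill length of 128.

```python
def pick_prefill_length(available_lengths, num_input_tokens):
    """Returns the smallest exported prefill length that fits the prompt."""
    candidates = [n for n in available_lengths if n >= num_input_tokens]
    if not candidates:
        raise ValueError(
            f"The largest prefill length supported is {max(available_lengths)}, "
            f"but the prompt has {num_input_tokens} tokens"
        )
    return min(candidates)


def pad_prompt(token_ids, prefill_length, pad_id=0):
    """Right-pads token ids with pad_id up to the fixed prefill length."""
    return list(token_ids) + [pad_id] * (prefill_length - len(token_ids))


length = pick_prefill_length([32, 128, 512], num_input_tokens=40)
padded = pad_prompt([5, 6, 7], prefill_length=8)
print(length, padded)
```

Exporting several prefill lengths trades a larger model file for less wasted computation on short prompts.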

# Generate text from model

In [None]:
# Disclaimer: Model performance demonstrated with the Python API in this notebook is not representative of performance on a local device.
pipeline = LiteRTLlmPipeline(interpreter, tokenizer)

In [None]:
prompt = "What is the primary function of mitochondria within a cell?"
output = pipeline.generate(prompt, max_decode_steps = 100)
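
The `max_decode_steps` argument is only an upper bound: `generate` also caps decoding by the space left in the KV cache. The small function below mirrors that arithmetic. The prompt lengths passed to it are made-up values for illustration.

```python
def decode_step_budget(kv_cache_len, num_prompt_tokens, max_decode_steps=None):
    """Mirrors `generate`: the last prompt token is fed to decode, not prefill."""
    prefill_token_length = num_prompt_tokens - 1
    budget = kv_cache_len - prefill_token_length - 1
    if max_decode_steps is not None:
        budget = min(budget, max_decode_steps)
    return budget


short_prompt = decode_step_budget(1024, num_prompt_tokens=20, max_decode_steps=100)
long_prompt = decode_step_budget(1024, num_prompt_tokens=1000, max_decode_steps=100)
print(short_prompt, long_prompt)
```

For a long prompt the cache size, not `max_decode_steps`, becomes the limiting factor.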

# Prepare task bundle for MediaPipe deployment

The task file will be named "gemma3_1b_finetune_q4_block32_ekv1024.task" and placed under the "/content" directory. Please refer to https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference for how to deploy the `task` file with the MediaPipe LLM Inference example.

In [None]:
import os
from google.colab import userdata
from huggingface_hub import hf_hub_download
import joblib

REPO_ID = "google/gemma-3-1b-it"
FILENAME = "tokenizer.model"
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

tokenizer_model = (
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME, local_dir="/content", token=os.environ['HF_TOKEN'])
)

In [None]:
from mediapipe.tasks.python.genai.bundler import llm_bundler

def build_gemma3_1b_it_block_q4():
  output_file = "/content/gemma3_1b_finetune_q4_block32_ekv1024.task"
  tflite_model = "/content/gemma3_1b_finetune_q4_block32_ekv1024.tflite"
  tokenizer_model = (
      "/content/tokenizer.model"
  )
  config = llm_bundler.BundleConfig(
      tflite_model=tflite_model,
      tokenizer_model=tokenizer_model,
      start_token="<bos>",
      stop_tokens=["<eos>"],
      output_filename=output_file,
      enable_bytes_to_unicode_mapping=False,
      prompt_prefix="<start_of_turn>user\n",
      prompt_suffix="<end_of_turn>\n<start_of_turn>model\n",
  )
  llm_bundler.create_bundle(config)

# Build the MediaPipe task bundle.
build_gemma3_1b_it_block_q4()
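
The `prompt_prefix` and `prompt_suffix` values in the bundle config are chosen to match the chat template used during fine-tuning. The sketch below shows the string a user prompt would turn into, assuming the runtime simply concatenates start token, prefix, user text and suffix. That composition order is an assumption for illustration, not a documented guarantee, and `wrap_prompt` is a hypothetical helper.

```python
PROMPT_PREFIX = "<start_of_turn>user\n"
PROMPT_SUFFIX = "<end_of_turn>\n<start_of_turn>model\n"


def wrap_prompt(user_text: str, start_token: str = "<bos>") -> str:
    """Assumed composition: start token + prefix + user text + suffix."""
    return f"{start_token}{PROMPT_PREFIX}{user_text}{PROMPT_SUFFIX}"


wrapped = wrap_prompt("What is the primary function of mitochondria within a cell?")
print(repr(wrapped))
```

If the prefix and suffix did not match the training format, the fine-tuned model would see prompts shaped differently from its training data.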