This tutorial demonstrates how to fine-tune the Gemma 3 model on the GSM8K math
reasoning dataset using Proximal Policy Optimization (PPO).

Proximal Policy Optimization is a reinforcement learning (RL) algorithm that has
become a standard for aligning Large Language Models (LLMs) with human
preferences. PPO operates on an actor-critic architecture, and its key
innovation is a "clipped surrogate objective" that constrains policy updates to
prevent large, destabilizing changes during training. This ensures a more stable
and reliable alignment process. Implementing PPO for LLMs can be complex,
typically requiring four models to be active in memory during training: the
policy model being trained, a frozen reference model, a reward model to score
outputs, and a value model to estimate future rewards. In this example, we use a
reward function instead of a reward model, and do not use a reference model at
all.

This notebook can be run on Colab's `v6e-1` TPU.

## Install necessary libraries

In [None]:
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q datasets
!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install -q git+https://github.com/google/tunix@test_811362001
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q datasets

## Imports

In [None]:
import contextlib
import functools
import gc
import os
import re

from datasets import load_dataset
from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.examples.data import translation_dataset as data_lib
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma import gemma as gemma_lib
from tunix.models.gemma3 import params as gemma3_params_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.ppo.ppo_learner import PPOConfig, PPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from tunix.sft.peft_main import obtain_model_config

## Hyperparameters

Let's define the configuration we are going to use.

In [None]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
# Fraction of the NUM_BATCHES training batches used for training; the
# remaining batches are held out for validation.
TRAIN_FRACTION = 1.0

# ====== Reproducibility ======
SEED = 42

# ====== LoRA (for Actor) ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
# Positional arguments for `jax.make_mesh`: mesh shape and axis names.
MESH = [(1, 1), ("fsdp", "tp")]

# ====== PPO ======
# ====== Generation during PPO training ======
MAX_PROMPT_LENGTH = 512
TOTAL_GENERATION_STEPS = 512
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0  # implies we don't do nucleus sampling
TOP_K = 50

# ====== Other PPO configs ======
# Number of internal PPO loops.
NUM_PPO_EPOCHS = 2
# No KL divergence used.
BETA = 0.0
# Epsilon value for clipping policy loss, for stable updates.
EPSILON = 0.2
# Discount factor for future rewards in GAE.
GAMMA = 1.0
# Lambda parameter for GAE.
GAE_LAMBDA = 1.0
# Range for clipping the value function loss.
CLIP_RANGE_VALUE = 0.5

# ====== Training ======
ROLLOUT_BATCH_SIZE = 16
MINI_BATCH_SIZE = 8
TRAINING_MICRO_BATCH_SIZE = 2
ROLLOUT_MICRO_BATCH_SIZE = 4
COMPUTE_LOGPS_MICRO_BATCH_SIZE = 4

NUM_BATCHES = 10
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. The GSM8K test
# split has 1,319 examples, so at most ceil(1319 / ROLLOUT_BATCH_SIZE) test
# batches exist (83 with a batch size of 16).
NUM_TEST_BATCHES = 10
EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
# Not to be confused with `num_ppo_epochs`.
NUM_EPOCHS = 1

# Number of batches actually used for training. This truncates the same way as
# the train/validation split in the data-loading cell
# (`int(len(dataset) * TRAIN_FRACTION)`). That keeps MAX_STEPS, and so the step
# of the final checkpoint restored after training, in sync with the steps
# really run when TRAIN_FRACTION < 1.
NUM_TRAIN_BATCHES = int(NUM_BATCHES * TRAIN_FRACTION)

# Number of training steps.
MAX_STEPS = int(
    NUM_TRAIN_BATCHES * NUM_EPOCHS * ROLLOUT_BATCH_SIZE // MINI_BATCH_SIZE
)

# ====== Optimizers & Schedulers ======
ACTOR_LEARNING_RATE = 1e-6
CRITIC_LEARNING_RATE = 1e-5
B1, B2 = 0.9, 0.999
WEIGHT_DECAY = 0.01
# optax's warmup schedule takes an integer number of steps.
WARMUP_STEPS = int(0.1 * MAX_STEPS)
MAX_GRAD_NORM = 1.0

# ====== Checkpoint Saving ======
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 5
MAX_TO_KEEP = 4
DO_MEM_PROFILING = False

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}
CACHE_SIZE = 1024

## Utility functions

In [None]:
def show_hbm_usage():
  """Prints memory usage (bytes in use / limit) for every local JAX device.

  `Device.memory_stats()` returns None on backends that do not report memory
  statistics (e.g. CPU). Such devices are reported as unavailable instead of
  raising a TypeError.
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    if not stats:
      print(f"Memory stats unavailable on {d}")
      continue
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define a system prompt for the model. We instruct the model to think
step-by-step, and then produce the final answer.

In [None]:
# Instruction appended to every question. "####" marks the final answer, which
# is what the answer extraction and reward functions search for.
SYSTEM_PROMPT = (
    "Let's think step by step and output the final answer after \"####\"."
)

# Gemma chat format: a user turn with the question and instruction, followed by
# an opened model turn for the response.
TEMPLATE = (
    "<start_of_turn>user\n"
    "{question} {system_prompt}<end_of_turn>\n"
    "<start_of_turn>model"
)

We use OpenAI's [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).
GSM8K comprises grade school math word problems.

In [None]:
def extract_solution(solution_str):
  """Returns the final numeric answer from a GSM8K reference solution.

  The first "#### <number>" marker is used, and thousands separators (",") are
  removed from the number.

  Args:
    solution_str: Reference solution text from the GSM8K "answer" field.

  Returns:
    The final answer as a string, e.g. "1234" for "#### 1,234".

  Raises:
    ValueError: If `solution_str` contains no "#### <number>" marker. An
      explicit exception is used because `assert` is stripped under `-O`.
  """
  solution = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
  if solution is None:
    raise ValueError(f"No '#### <answer>' marker found in: {solution_str!r}")
  return solution.group(1).replace(",", "")


_GSM8K_ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")


def _extract_gsm8k_answer(answer_text):
  """Returns the final numeric answer ("#### <number>") of a GSM8K solution.

  This is deliberately a private helper instead of a call to
  `extract_solution`. A later cell redefines `extract_solution` for reward
  scoring, and grain applies `map` lazily when elements are read. A global
  lookup would therefore pick up whichever definition is current at read time.
  """
  match = _GSM8K_ANSWER_RE.search(answer_text)
  if match is None:
    raise ValueError(f"No '#### <answer>' marker found in: {answer_text!r}")
  return match.group(1).replace(",", "")


def get_dataset(data_dir, split="train", seed=42) -> grain.MapDataset:
  """Loads a GSM8K split as a shuffled grain `MapDataset`.

  Args:
    data_dir: Directory TFDS downloads the dataset into; created if missing.
    split: TFDS split name, e.g. "train" or "test".
    seed: Seed for the dataset shuffle.

  Returns:
    A `MapDataset` of dicts with keys:
      "prompts": the question formatted with TEMPLATE and SYSTEM_PROMPT
        (passed to the model forward pass),
      "question": the raw question text (passed to reward functions),
      "answer": the final numeric answer as a string (passed to reward
        functions).
  """
  # Download data
  os.makedirs(data_dir, exist_ok=True)

  data = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

  def _to_record(example):
    question = example["question"].decode("utf-8")
    return {
        # passed to model forward pass
        "prompts": TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question,
        ),
        # passed to reward functions
        "question": question,
        # passed to reward functions
        "answer": _extract_gsm8k_answer(example["answer"].decode("utf-8")),
    }

  return grain.MapDataset.source(data).shuffle(seed=seed).map(_to_record)

We split the dataset into train and validation datasets, and also load the test dataset.

In [None]:
# Batch the GSM8K train split and keep only the first NUM_BATCHES batches.
dataset = get_dataset(TRAIN_DATA_DIR, "train").batch(ROLLOUT_BATCH_SIZE)[
    :NUM_BATCHES
]

if TRAIN_FRACTION != 1.0:
  # Leading batches are used for training; the rest are held out.
  num_train = int(len(dataset) * TRAIN_FRACTION)
  train_dataset = dataset[:num_train].repeat(NUM_EPOCHS)
  val_dataset = dataset[num_train:].repeat(NUM_EPOCHS)
else:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None

test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(ROLLOUT_BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

# (train, validation, test) sizes, displayed as the cell output.
(
    len(train_dataset),
    0 if val_dataset is None else len(val_dataset),
    len(test_dataset),
)

Let's see what one batch of the dataset looks like!

In [None]:
# Print the first training batch to inspect its keys and contents.
for first_batch in train_dataset[:1]:
  print(first_batch)

## Load Policy, Critic Models

In PPO, we use four models:

- The policy model is the live language model being trained.
  It's initialized from a supervised fine-tuned (SFT) model, and its weights are
  updated to generate responses that maximize a reward.
- The reference model is a frozen, initial copy of the SFT model. It's used to
  calculate a KL divergence penalty that is subtracted from the reward score
  (`reward - KL_penalty`). This acts as a regularizer, preventing the policy
  from deviating too far from the SFT model's coherent style.
- The critic model, or value model, is also trained alongside the policy. Its
  job is to predict the expected future reward (the "value") from a given
  sequence. This value is essential for calculating advantages — a signal that
  tells the policy whether its actions were better or worse than expected,
  helping it learn more efficiently.
- The reward model scores each generated response.

In this example, we skip the reference model, and use a reward function instead
of a reward model.

Note: We perform full-precision (fp32) training. You can, however, leverage Qwix
for quantization-aware training (QAT).

To load the model, you need to be logged in to [Kaggle](https://www.kaggle.com/)
and need to have agreed to the Gemma license
[here](https://www.kaggle.com/models/google/gemma-3/flax/). Instead of logging in
interactively, we recommend using Colab Secrets. This way, you don't have to
manually enter your Kaggle username and API key every time you run the notebook.

In [None]:
# Log in (no need to fill this in if you've set up Colab Secrets). The login
# prompt only appears when either Kaggle credential is missing from the
# environment.
if not all(
    var in os.environ for var in ("KAGGLE_USERNAME", "KAGGLE_KEY")
):
  kagglehub.login()

In [None]:
model_family = "gemma3"
model_params = "gemma3-1b"
model_version = "gemma3-1b-it"
model_path = {
    "gemma3": "google/gemma-3/flax/",
}

# Kaggle model handle, e.g. "google/gemma-3/flax/gemma3-1b-it".
print(model_path[model_family] + model_version)

kaggle_ckpt_path = kagglehub.model_download(
    model_path[model_family] + model_version
)

In [None]:
def get_base_model(kaggle_ckpt_path):
  """Loads the Gemma 3 model described by `model_params` from a checkpoint.

  Args:
    kaggle_ckpt_path: Directory returned by `kagglehub.model_download`. The
      checkpoint is read from its `model_version` subdirectory.

  Returns:
    A `(model, mesh, model_config)` tuple. `mesh` is built from MESH, or is
    None when MESH is None.
  """
  mesh = jax.make_mesh(*MESH) if MESH is not None else None

  model_config = obtain_model_config(model_params)
  # `with None:` raises, so use a no-op context when no mesh is configured.
  mesh_ctx = mesh if mesh is not None else contextlib.nullcontext()
  with mesh_ctx:
    model = gemma3_params_lib.create_model_from_checkpoint(
        os.path.join(kaggle_ckpt_path, model_version), model_config, mesh
    )
  return model, mesh, model_config


def get_lora_model(base_model, mesh):
  """Applies Qwix LoRA adapters (rank RANK, alpha ALPHA) to `base_model`.

  Adapters are attached to the modules whose path matches the `module_path`
  regex (attention einsums and MLP projections). When `mesh` is given, the
  resulting model state is constrained to its partition specs under that mesh.

  Args:
    base_model: Model returned by `get_base_model`.
    mesh: JAX mesh, or None to skip the sharding step.

  Returns:
    The LoRA-wrapped model.
  """
  provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )
  lora_model = qwix.apply_lora_to_model(
      base_model, provider, **base_model.get_model_input()
  )

  if mesh is None:
    return lora_model

  with mesh:
    lora_state = nnx.state(lora_model)
    constrained_state = jax.lax.with_sharding_constraint(
        lora_state, nnx.get_partition_spec(lora_state)
    )
    nnx.update(lora_model, constrained_state)
  return lora_model

In [None]:
# Policy (actor): the Gemma 3 base model with Qwix LoRA adapters applied.
# `mesh` and `model_config` are reused by later cells.
actor_model, mesh, model_config = get_base_model(kaggle_ckpt_path)
actor_model = get_lora_model(actor_model, mesh)

In [None]:
# Value (critic): a second copy of the base model, loaded from the same
# checkpoint and wrapped with a score head. Only the model is taken from
# `get_base_model`'s (model, mesh, config) tuple.
critic_model = gemma_lib.TransformerWithScoreHead(
    get_base_model(kaggle_ckpt_path)[0],
    rngs=nnx.Rngs(params=jax.random.key(SEED)),
    shd_cfg=("fsdp", None),
)

In [None]:
# Print per-device memory usage now that the actor and critic models are loaded.
show_hbm_usage()

## Define reward function

The reward function gives a reward of 1.0 if the answer is correct, and 0.1 if the answer is wrong but correctly formatted (i.e. the response ends with a `#### <number>` answer). Every other response gets 0.

This function is adapted from [verl's PPO trainer example](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer).

In [None]:
_SOLUTION_CLIP_CHARS = 300
FORMAT_SCORE = 0.1
CORR_ANS_SCORE = 1.0
# "#### <number>", optionally with a "$" before the number (so "#### $18" is
# accepted). The number may contain "-", digits, "." and ",".
_ANSWER_RE = re.compile(r"#### \$?(-?[0-9.,]+)")


def extract_solution(solution_str):
  """Extracts the final answer from a model completion.

  Only the last `_SOLUTION_CLIP_CHARS` characters are searched, and the last
  "#### <number>" match wins. Thousands separators and trailing periods are
  removed, so "#### 1,000." yields "1000".

  Args:
    solution_str: Completion text produced by the model.

  Returns:
    The answer string, or None if no "#### <number>" answer is found. A None
    result also means the format check failed.
  """
  # The final answer is expected at the end, so only the tail is searched.
  tail = solution_str[-_SOLUTION_CLIP_CHARS:]
  matches = _ANSWER_RE.findall(tail)
  if not matches:
    return None
  # A sentence-final period ("#### 18.") is punctuation, not part of the number.
  final_answer = matches[-1].replace(",", "").rstrip(".")
  # A match made only of punctuation (e.g. "#### ,") is not a number.
  return final_answer or None


def compute_score(solution_str, ground_truth):
  """Scores one completion against its ground-truth answer.

  Args:
    solution_str: Completion text produced by the model.
    ground_truth: Expected answer; compared as `str(ground_truth)`.

  Returns:
    CORR_ANS_SCORE if the extracted answer matches, FORMAT_SCORE if an answer
    was extracted but differs, and 0.0 if no well-formed answer was found.
  """
  answer = extract_solution(solution_str=solution_str)
  if answer is None:
    return 0.0
  return CORR_ANS_SCORE if answer == str(ground_truth) else FORMAT_SCORE


def reward_fn(prompts, completions, answer, **kwargs):
  """Computes one reward per completion.

  Args:
    prompts: Prompts of the batch; accepted but unused.
    completions: Model completions.
    answer: Ground-truth answers, aligned with `completions`.
    **kwargs: Any other keyword arguments from the caller; ignored.

  Returns:
    A list of scores, one per `(completion, answer)` pair.
  """
  return [
      compute_score(completion, true_answer)
      for completion, true_answer in zip(completions, answer)
  ]

Let's test our reward function

In [None]:
# Expected scores: 1.0, 0.1, 0.
reward_fn(
    prompts=["prompt 1", "prompt 2", "prompt 3"],
    completions=[
        "some thinking here #### 10",
        "no thinking #### 10",
        "wrong format",
    ],
    answer=[10, 100, 2],
)

## Evaluate

Before we train the model, let's evaluate the model on the test set so we can
see the improvement post training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the
  correct final numerical answer.
* **Format Accuracy**: percentage of samples for which the model outputs the
  correct format.

**Qualitative**

We'll also print outputs for a few given questions so that we can compare the
generated output later.

We define a helper function to generate an answer, given a prompt.

In [None]:
def generate(
    question,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    seed=None,
):
  """Generates answers for one question or a batch of questions.

  Each question is formatted with TEMPLATE and SYSTEM_PROMPT. The sampler then
  generates up to TOTAL_GENERATION_STEPS tokens without echoing the prompt.

  Args:
    question: A single question string, or an iterable of question strings.
    sampler: Sampler called with the formatted prompts; its result's `.text`
      holds one generated string per prompt.
    temperature: Sampling temperature.
    top_k: Top-k sampling cutoff.
    top_p: Nucleus-sampling cutoff.
    seed: Seed forwarded to the sampler.

  Returns:
    The generated text for a single question, or the list of generated texts
    for a batch.
  """
  is_single = isinstance(question, str)
  questions = [question] if is_single else question
  prompts = [
      TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
      for q in questions
  ]

  texts = sampler(
      input_strings=prompts,
      max_generation_steps=TOTAL_GENERATION_STEPS,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
  ).text
  return texts[0] if is_single else texts

In [None]:
def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format.

  Every question is answered `num_passes` times; pass `p` uses seed `p`.
  - A question counts as correct if any pass scores CORR_ANS_SCORE.
  - It counts as correctly formatted if any pass scores CORR_ANS_SCORE or
    FORMAT_SCORE.

  Args:
    dataset: Iterable of batches, each with "question" and "answer" entries.
    sampler: Sampler passed to `generate`.
    temperature: Sampling temperature passed to `generate`.
    top_k: Top-k cutoff passed to `generate`.
    top_p: Top-p cutoff passed to `generate`.
    num_passes: Number of generations per question.
    corr_lst: With `make_lst`, collect correctly answered questions when True,
      and incorrectly answered ones otherwise.
    make_lst: Whether to also return the collected
      `(question, answer, responses)` tuples.

  Returns:
    `(corr, total, accuracy_pct, format_accuracy_pct)`. When `make_lst` is
    True, returns that tuple together with the collected list. Percentages are
    0.0 when `dataset` is empty.
  """
  response_lst = []
  corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    # multiple_call_responses[i] holds the `num_passes` responses to question i.
    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      is_correct = False
      is_formatted = False
      for response in multiple_call_response:
        score = compute_score(response, answer)
        if score == CORR_ANS_SCORE:
          # A correct answer is correctly formatted too; remaining passes
          # cannot change the outcome.
          is_correct = True
          is_formatted = True
          break
        if score == FORMAT_SCORE:
          # Wrong answer, but the "#### <number>" format is right.
          is_formatted = True

      if is_correct:
        corr += 1
      if is_formatted:
        corr_format += 1
      # Collect correct questions when `corr_lst` is set, otherwise wrong ones.
      if make_lst and is_correct == bool(corr_lst):
        response_lst.append((question, answer, multiple_call_response))

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{corr_format / total * 100=}"
        )

  # Avoid ZeroDivisionError when the dataset yields no questions.
  accuracy = corr / total * 100 if total else 0.0
  format_accuracy = corr_format / total * 100 if total else 0.0
  to_return = (corr, total, accuracy, format_accuracy)
  if make_lst:
    return to_return, response_lst
  return to_return

In [None]:
# Tokenizer shared by the sampler and the RL cluster.
gemma_tokenizer = data_lib.GemmaTokenizer()

# Sampler over the (not yet trained) actor, with a KV cache sized from the
# model config.
sampler = sampler_lib.Sampler(
    tokenizer=gemma_tokenizer,
    transformer=actor_model,
    cache_config=sampler_lib.CacheConfig(
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
        cache_size=CACHE_SIZE,
    ),
)

In [None]:
# Baseline: greedy decoding on the test set before any PPO training.
corr, total, accuracy, format_accuracy = evaluate(
    test_dataset, sampler, **GENERATION_CONFIGS["greedy"]
)
print(f"{corr=}, {total=}, {accuracy=}%, {format_accuracy=}%")

## Train

Let's set up all the configs first - checkpointing, metric logging and training. We then train the model.

In [None]:
# Ckpt saving: write a checkpoint every SAVE_INTERVAL_STEPS steps and keep at
# most MAX_TO_KEEP of them.
checkpointing_options = ocp.CheckpointManagerOptions(
    max_to_keep=MAX_TO_KEEP,
    save_interval_steps=SAVE_INTERVAL_STEPS,
)

# Metrics logger: metrics are written under this directory.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    flush_every_n_steps=20,
    log_dir="/tmp/tensorboard/ppo",
)

In [None]:
# Log
%load_ext tensorboard
%tensorboard --logdir tmp/tensorboard/ppo --port=0

In [None]:
def create_optimizer(lr):
  """Builds an AdamW optimizer with a warmup + cosine-decay learning rate.

  The learning rate warms up from 0 to `lr` over WARMUP_STEPS steps, then
  decays (cosine) to 0 at step MAX_STEPS. When MAX_GRAD_NORM is truthy,
  gradients are clipped to that global norm before the AdamW update.

  Args:
    lr: Peak learning rate.

  Returns:
    An optax gradient transformation.
  """
  schedule = optax.schedules.warmup_cosine_decay_schedule(
      init_value=0.0,
      peak_value=lr,
      warmup_steps=WARMUP_STEPS,
      decay_steps=MAX_STEPS,
      end_value=0.0,
  )
  adamw = optax.adamw(
      learning_rate=schedule,
      b1=B1,
      b2=B2,
      weight_decay=WEIGHT_DECAY,
  )
  if not MAX_GRAD_NORM:
    return adamw
  return optax.chain(optax.clip_by_global_norm(MAX_GRAD_NORM), adamw)


actor_optimizer, critic_optimizer = (
    create_optimizer(ACTOR_LEARNING_RATE),
    create_optimizer(CRITIC_LEARNING_RATE),
)

In [None]:
# Training config. Every role runs on the same mesh. The REFERENCE entry won't
# be used, since this example has no reference model (BETA = 0.0).
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        role: mesh
        for role in (
            rl_cluster_lib.Role.ACTOR,
            rl_cluster_lib.Role.REFERENCE,
            rl_cluster_lib.Role.CRITIC,
            rl_cluster_lib.Role.ROLLOUT,
        )
    },
    training_config=rl_cluster_lib.RLTrainingConfig(
        # Optimizers.
        actor_optimizer=actor_optimizer,
        critic_optimizer=critic_optimizer,
        # Step schedule.
        max_steps=MAX_STEPS,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        # Batch sizes.
        mini_batch_size=MINI_BATCH_SIZE,
        training_micro_batch_size=TRAINING_MICRO_BATCH_SIZE,
        rollout_micro_batch_size=ROLLOUT_MICRO_BATCH_SIZE,
        compute_logps_micro_batch_size=COMPUTE_LOGPS_MICRO_BATCH_SIZE,
        # Logging and checkpointing.
        metrics_logging_options=metrics_logging_options,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        temperature=TEMPERATURE,
        top_k=TOP_K,
        top_p=TOP_P,
    ),
)

# PPO algorithm settings (see the hyperparameter cell for their meaning).
ppo_config = PPOConfig(
    epsilon=EPSILON,
    clip_range_value=CLIP_RANGE_VALUE,
    beta=BETA,
    gamma=GAMMA,
    gae_lambda=GAE_LAMBDA,
    num_ppo_epochs=NUM_PPO_EPOCHS,
)

In [None]:
# RL cluster. The actor is also passed as the reference model because this
# example does not use a separate reference model (BETA = 0.0).
rl_cluster = rl_cluster_lib.RLCluster(
    actor=actor_model,
    critic=critic_model,
    reference=actor_model,
    tokenizer=gemma_tokenizer,
    cluster_config=cluster_config,
)

# PPO trainer that optimizes the cluster's models using `reward_fn`.
ppo_learner = PPOLearner(
    rl_cluster=rl_cluster,
    ppo_config=ppo_config,
    reward_fns=reward_fn,
    data_shuffle_seed=SEED,
)

In [None]:
# Train with PPO. `val_dataset` is None when TRAIN_FRACTION == 1.0, which
# disables evaluation. Otherwise the held-out batches are used for evaluation,
# per `eval_every_n_steps` in the training config.
ppo_learner.train(train_dataset, eval_ds=val_dataset)

## Evaluate!

In [None]:
# Load checkpoint first: restore the actor's LoRA weights saved at step
# MAX_STEPS.

trained_ckpt_path = os.path.join(
    CKPT_DIR, "actor", str(MAX_STEPS), "model_params"
)

# Shape/dtype-only target tree so Orbax restores into the LoRA param structure.
abs_params = jax.tree.map(
    lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype),
    nnx.state(actor_model, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# Take every leaf from the restored params. Mapping over the live LoRA state
# requires both trees to have the same structure before the model is updated.
nnx.update(
    actor_model,
    jax.tree.map(
        lambda _current, restored: restored,
        nnx.state(actor_model, nnx.LoRAParam),
        trained_lora_params,
    ),
)

In [None]:
# Sampler over the actor after its trained LoRA weights have been loaded.
sampler = sampler_lib.Sampler(
    tokenizer=gemma_tokenizer,
    transformer=actor_model,
    cache_config=sampler_lib.CacheConfig(
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
        cache_size=CACHE_SIZE,
    ),
)

In [None]:
# Post-training: greedy decoding on the same test batches as the baseline.
corr, total, accuracy, format_accuracy = evaluate(
    test_dataset, sampler, **GENERATION_CONFIGS["greedy"]
)
print(f"{corr=}, {total=}, {accuracy=}%, {format_accuracy=}%")