This tutorial demonstrates how to preference-tune a Gemma model using Direct
Preference Optimization (DPO). We will use the UltraFeedback dataset, a
large-scale collection of high-quality AI feedback on user-assistant
conversations.

DPO is a preference tuning method for aligning large language models with
human or AI preferences. It is a simpler, more computationally lightweight
alternative to reinforcement learning from human feedback (RLHF). Instead of
training a separate reward model and then optimizing against it with
reinforcement learning, DPO trains the model directly on paired examples of
"chosen" (preferred) and "rejected" responses. This fine-tunes the model to
increase the probability of generating preferred outputs and decrease the
likelihood of rejected ones, relative to a frozen reference model. As a result,
DPO avoids sampling from the model during training and, according to the DPO
paper, does not require significant hyperparameter tuning.
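For intuition, the DPO objective can be written in a few lines of JAX. This is a minimal sketch of the loss from the DPO paper, not Tunix's implementation. It assumes you already have the summed log-probability of each chosen and rejected response under both the policy model and the frozen reference model. The `beta` argument plays the same role as `BETA` in the configuration further down.

```python
import jax
import jax.numpy as jnp


def dpo_loss(
    policy_chosen_logps: jax.Array,
    policy_rejected_logps: jax.Array,
    ref_chosen_logps: jax.Array,
    ref_rejected_logps: jax.Array,
    beta: float = 0.1,
) -> jax.Array:
  """Mean DPO loss over a batch of preference pairs.

  Each argument has shape [batch] and holds the summed log-probability of a
  full response given its prompt.
  """
  # Implicit rewards: beta * log(pi_policy(y | x) / pi_ref(y | x)).
  chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
  rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
  # -log(sigmoid(margin)) shrinks as the chosen response is favoured more.
  return -jnp.mean(jax.nn.log_sigmoid(chosen_rewards - rejected_rewards))


logps = jnp.array([-10.0, -20.0])
print(dpo_loss(logps, logps, logps, logps, beta=0.04))
```

When the policy and reference agree, the margin is zero and the loss equals log 2 (about 0.693). The loss falls as the policy assigns relatively more probability to the chosen response than the reference does.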

This notebook has been tested on Colab's `v6e-1` TPU instance, with 30 GB
memory.

For reference:

- Dataset: UltraFeedback ([Paper](https://arxiv.org/pdf/2310.01377), [HuggingFace](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned))

- Algorithm: [Direct Preference Optimization](https://arxiv.org/pdf/2305.1829)

## Install necessary libraries

In [None]:
!pip install -q kagglehub
!pip install -q ipywidgets
!pip install -q humanize

!pip install -q datasets
!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

# Replace any preinstalled Flax with the development version from GitHub.
!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

## Imports

In [None]:
import functools
import gc
import os

from datasets import load_dataset
from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
import qwix
from tunix.examples.data import translation_dataset as data_lib
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma3 import params as gemma3_params_lib
from tunix.sft import metrics_logger
from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig
from tunix.sft.peft_main import obtain_model_config

## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you might have
to train the model for longer.

In [None]:
# ====== Data ======
TRAIN_FRACTION = 0.75

# ====== LoRA ======
RANK = 16
ALPHA = 2.0

# ====== Sharding ======
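# Optional (axis_shapes, axis_names) pair passed to jax.make_mesh(*MESH).
# None disables sharding.
MESH = None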

# ====== DPO ======
BETA = 0.04
MAX_PROMPT_LENGTH = 512
MAX_RESPONSE_LENGTH = 1024

# ====== Training ======
BATCH_SIZE = 1
NUM_BATCHES = 100
NUM_EPOCHS = 2  # can potentially train for more epochs
# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * TRAIN_FRACTION * NUM_EPOCHS)
EVAL_EVERY_N_STEPS = 10

# ====== AdamW optimizer (constant learning rate) ======
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Generation/Inference ======
TOTAL_GENERATION_STEPS = 1024

TOP_P = 1.0
TOP_K = 50
TEMPERATURE = 0.9
CACHE_SIZE = TOTAL_GENERATION_STEPS + MAX_PROMPT_LENGTH
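As a quick check on how these values interact, the sketch below recomputes the split sizes using the numbers from this cell. It mirrors the split cell further down: take the first `NUM_BATCHES` batches, keep the first `TRAIN_FRACTION` of them for training, and repeat each split `NUM_EPOCHS` times. The helper `split_sizes` is hypothetical and not part of Tunix.

```python
def split_sizes(num_batches: int, train_fraction: float, num_epochs: int):
  """Returns (train_len, test_len, max_steps) in batches, after repetition."""
  num_train = int(num_batches * train_fraction)
  num_test = num_batches - num_train
  max_steps = int(num_batches * train_fraction * num_epochs)
  return num_train * num_epochs, num_test * num_epochs, max_steps


train_len, test_len, max_steps = split_sizes(100, 0.75, 2)
print(train_len, test_len, max_steps)  # 150 50 150
```

With the defaults, the repeated training split has 150 batches, which matches `MAX_STEPS`. Assuming the trainer takes one optimizer step per batch, training passes over the training split exactly `NUM_EPOCHS` times.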

## Utility functions

In [None]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    if stats is None:  # Some backends (e.g. CPU) do not report memory stats.
      print(f"Memory stats unavailable on {d}")
      continue
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")


show_hbm_usage()

## Data preprocessing

Most of the data preprocessing, such as tokenization, is handled by Tunix's
`DPOTrainer`. All we need to do is make sure our prompts are in the format the
model expects, for example, wrapped in Gemma's chat special tokens such as
`<start_of_turn>` and `<end_of_turn>`.

In [None]:
PROMPT_TEMPLATE = """<start_of_turn>user
{prompt}<end_of_turn>
<start_of_turn>model"""

In [None]:
def get_dataset() -> grain.MapDataset:
  hf_dataset = load_dataset(
      "argilla/ultrafeedback-binarized-preferences-cleaned"
  )["train"]

  def _get_response(x):
    for element in x:
      if element["role"] == "assistant":
        return element["content"]

  dataset = grain.MapDataset.source(hf_dataset).map(
      lambda x: {
          "prompts": PROMPT_TEMPLATE.format(prompt=x["prompt"]),
          "chosen_responses": _get_response(x["chosen"]),
          "rejected_responses": _get_response(x["rejected"]),
      }
  )
  return dataset
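To make the expected record layout concrete, the sketch below applies the same mapping to one hand-written record in plain Python. The record is invented for illustration and only assumes what `_get_response` already implies: `chosen` and `rejected` are lists of `{"role": ..., "content": ...}` messages. `get_response` is a standalone copy of `_get_response`, and `prompt_template` is a copy of `PROMPT_TEMPLATE`.

```python
prompt_template = """<start_of_turn>user
{prompt}<end_of_turn>
<start_of_turn>model"""


def get_response(messages):
  """Returns the content of the first assistant message, or None."""
  for message in messages:
    if message["role"] == "assistant":
      return message["content"]
  return None


# Hypothetical record shaped like one UltraFeedback row.
record = {
    "prompt": "What is 2 + 2?",
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}

# The same three fields that `get_dataset` produces for each row.
example = {
    "prompts": prompt_template.format(prompt=record["prompt"]),
    "chosen_responses": get_response(record["chosen"]),
    "rejected_responses": get_response(record["rejected"]),
}
print(example)
```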

In [None]:
dataset = get_dataset().batch(BATCH_SIZE)[:NUM_BATCHES]

if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  test_dataset = None
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  test_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(
      NUM_EPOCHS
  )

len(train_dataset), len(test_dataset) if test_dataset is not None else 0

Let's see what one batch of data looks like.

In [None]:
for ele in train_dataset:
  print(ele)
  break

## Load policy model and reference model

The policy model is the model that is actually trained, i.e., whose weights are
updated. The reference model stays fixed during training and serves as the
baseline that the policy model is compared against.

Typically, the reference model is the base model, and the policy model is the
same base model with LoRA parameters added. Only the LoRA parameters are updated.

Note: We perform full-precision (fp32) training. You can, however, use Qwix for quantization-aware training (QAT).

To load the model, you need a [Kaggle](https://www.kaggle.com/) account and
must have accepted the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/). Instead of logging in interactively, we recommend storing your credentials in Colab Secrets, so you don't have to enter your Kaggle username and API key every time you run the notebook.
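One way to wire this up is to copy the secrets into the environment variables that the login cell below checks. This is a sketch that assumes you stored two Colab secrets named `KAGGLE_USERNAME` and `KAGGLE_KEY` and granted this notebook access to them; `google.colab.userdata.get` is Colab's API for reading secrets, and `export_kaggle_secrets` is a hypothetical helper.

```python
import os


def export_kaggle_secrets() -> bool:
  """Copies Kaggle credentials from Colab Secrets into environment variables.

  Returns True if both secrets were exported, False otherwise.
  """
  try:
    # `google.colab.userdata` only exists inside a Colab runtime.
    from google.colab import userdata
  except ImportError:
    return False
  try:
    for name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
      os.environ[name] = userdata.get(name)
  except Exception:  # Secret missing, or notebook access not granted.
    return False
  return True


if export_kaggle_secrets():
  print("Kaggle credentials loaded from Colab Secrets.")
```

Outside Colab the helper simply returns `False`, so the login cell falls back to `kagglehub.login()`.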

In [None]:
# Log in (no need to fill this in if you've set up Colab Secrets)
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  kagglehub.login()

In [None]:
model_path = {
    "gemma3": "google/gemma-3/flax/",
}
model_family = "gemma3"
model_params = "gemma3-1b"
model_version = "gemma3-1b-it"

print(f"{model_path[model_family]}{model_version}")

kaggle_ckpt_path = kagglehub.model_download(
    f"{model_path[model_family]}{model_version}"
)

### Model Loading and LoRA Application

These two functions work together to load a base model from a checkpoint and apply LoRA (Low-Rank Adaptation) adapters to it.

* `get_ref_model`: Loads the complete Gemma model from the given checkpoint path. If `MESH` is set, it builds a JAX device mesh so the model parameters can be sharded across multiple devices; with the default `MESH = None`, the model is loaded without sharding.
* `get_lora_model`: Takes the base model and applies LoRA layers to it. It uses a Qwix `LoraProvider` whose `module_path` regular expression selects which layers (the attention projections and MLP layers) are adapted. If a mesh is available, the resulting model's state is then sharded according to its partition specs so that it is ready for distributed training.
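As background, the standard LoRA formulation keeps a frozen weight `W` and learns a low-rank update `A @ B`, scaled by `alpha / rank`. The sketch below applies it to a single dense layer in plain JAX. It illustrates the idea only and is not Qwix's implementation, whose initialization and scaling details may differ; `init_lora` and `lora_dense` are hypothetical helpers, and `rank`/`alpha` play the roles of `RANK`/`ALPHA` in the configuration.

```python
import jax
import jax.numpy as jnp


def init_lora(key, in_features: int, out_features: int, rank: int):
  """A is small random, B is zero, so the adapter starts as a no-op."""
  a = jax.random.normal(key, (in_features, rank)) * 0.01
  b = jnp.zeros((rank, out_features))
  return {"a": a, "b": b}


def lora_dense(x, w_frozen, lora, rank: int, alpha: float):
  """y = x W + (alpha / rank) * x A B; only A and B would be trained."""
  base = x @ w_frozen
  update = (x @ lora["a"]) @ lora["b"]
  return base + (alpha / rank) * update


key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
in_f, out_f, rank, alpha = 1024, 1024, 16, 2.0
w = jax.random.normal(key_w, (in_f, out_f))
lora = init_lora(key_a, in_f, out_f, rank)
x = jax.random.normal(key_x, (4, in_f))

# With B = 0, the adapted layer matches the frozen layer exactly.
print(jnp.allclose(lora_dense(x, w, lora, rank, alpha), x @ w))
# Trainable LoRA params, rank * (in + out), vs. frozen params, in * out.
print(lora["a"].size + lora["b"].size, w.size)
```

The parameter count is why LoRA is attractive here: for this layer the adapter has about 3% as many trainable parameters as the frozen weight.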

In [None]:
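def get_ref_model(kaggle_ckpt_path):
  """Loads the base Gemma model, which also serves as the DPO reference model."""
  mesh = None
  if MESH is not None:
    # MESH is an (axis_shapes, axis_names) pair for jax.make_mesh.
    mesh = jax.make_mesh(*MESH)

  model_config = obtain_model_config(model_params)
  ref_model = gemma3_params_lib.create_model_from_checkpoint(
      # Pass the constructed Mesh object, not the MESH config tuple.
      os.path.join(kaggle_ckpt_path, model_version), model_config, mesh
  )
  return ref_model, mesh, model_config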


def get_lora_model(base_model, mesh):
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  if mesh is not None:
    with mesh:
      state = nnx.state(lora_model)
      pspecs = nnx.get_partition_spec(state)
      sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
      nnx.update(lora_model, sharded_state)

  return lora_model

In [None]:
# Reference model
ref_model, mesh, model_config = get_ref_model(kaggle_ckpt_path)

In [None]:
# Policy model
lora_model = get_lora_model(ref_model, mesh=mesh)
nnx.display(lora_model)

In [None]:
# Tokenizer
tokenizer = data_lib.GemmaTokenizer(
    os.path.join(kaggle_ckpt_path, "tokenizer.model")
)

## Train

Let's first set up all the configs (checkpointing, metric logging, and
training), and then train the model.

Note: To get good results, train the model for longer.

In [None]:
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/dpo", flush_every_n_steps=20
)

In [None]:
# Logs
%load_ext tensorboard
%tensorboard --logdir /tmp/content/tmp/tensorboard/dpo --port=0

In [None]:
# Training config
training_config = DPOTrainingConfig(
    beta=BETA,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    metrics_logging_options=metrics_logging_options,
    checkpoint_root_directory=CKPT_DIR,
    checkpointing_options=checkpointing_options,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_response_length=MAX_RESPONSE_LENGTH,
)

# Trainer
dpo_trainer = DPOTrainer(
    model=lora_model,
    ref_model=ref_model,
    optimizer=optax.adamw(
        learning_rate=LEARNING_RATE,
        b1=B1,
        b2=B2,
        weight_decay=WEIGHT_DECAY,
    ),
    tokenizer=tokenizer,
    training_config=training_config,
)

In [None]:
if mesh is None:
  dpo_trainer.train(train_dataset)
else:
  with mesh:
    dpo_trainer.train(train_dataset)

## Evaluate

We evaluate the tuned model on the test split. For this initial analysis, we perform a qualitative comparison: for an example from the test set, we compare the model's output with the chosen and rejected responses to get a sense of the output quality.

In [None]:
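def generate(prompt, sampler, temperature=1.0, top_k=64, top_p=0.95):
  """Generates text for a prompt string or a list of prompt strings.

  Each raw prompt is wrapped in PROMPT_TEMPLATE before sampling.
  """
  prompts = [prompt] if isinstance(prompt, str) else list(prompt)
  input_batch = [PROMPT_TEMPLATE.format(prompt=p) for p in prompts]

  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=TOTAL_GENERATION_STEPS,
      max_prompt_length=MAX_PROMPT_LENGTH,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
  )

  output = out_data.text

  # Return a single string for a single prompt, a list otherwise.
  if isinstance(prompt, str):
    return output[0]
  return output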

In [None]:
# Restore the trained LoRA weights from the final checkpoint.
trained_ckpt_path = os.path.join(CKPT_DIR, str(MAX_STEPS), "model_params")

# Shape/dtype-only target so Orbax knows the structure to restore into.
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_model, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# Copy the restored values into the policy model's LoRA parameters.
nnx.update(
    lora_model,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_model, nnx.LoRAParam),
        trained_lora_params,
    ),
)

In [None]:
sampler = sampler_lib.Sampler(
    transformer=lora_model,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=CACHE_SIZE + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
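# Take an example from the test split and visually compare the DPO-tuned
# model's output with the chosen and rejected responses.
assert test_dataset is not None, "Set TRAIN_FRACTION < 1.0 to get a test split."

test_index = 20
# Each element is a batch of BATCH_SIZE examples; take the first one. The keys
# match those produced by `get_dataset`.
example = test_dataset[test_index]
formatted_prompt = str(example["prompts"][0])
chosen = str(example["chosen_responses"][0])
rejected = str(example["rejected_responses"][0])

# The stored prompt already includes PROMPT_TEMPLATE and `generate` applies it
# again, so strip the template to recover the raw user prompt.
prefix, suffix = PROMPT_TEMPLATE.split("{prompt}")
raw_prompt = formatted_prompt.removeprefix(prefix).removesuffix(suffix)

print("prompt: \n\n", raw_prompt)
print("==" * 10)
print("chosen: \n\n", chosen)
print("==" * 10)
print("rejected: \n\n", rejected)
print("==" * 10)
print("DPO tuned model output")
text = generate(
    prompt=raw_prompt,
    sampler=sampler,
    temperature=TEMPERATURE,
    top_k=TOP_K,
    top_p=TOP_P,
)
print(text)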