# DPO Demo with math (gsm8k)

This notebook demonstrates how to fine-tune a Gemma3-1B model using Direct Preference Optimization (DPO). DPO is a method for training language models to align with human preferences without requiring a separate reward model.

## What this example covers:
- Loading and setting up a pre-trained Gemma3-1B instruction-tuned model
- Applying LoRA (Low-Rank Adaptation) for efficient fine-tuning
- Processing DPO training data with prompt/chosen/rejected response pairs
- Training the model using the DPO trainer from Tunix
- Evaluating model performance on GSM8K mathematical reasoning tasks

The training uses the Argilla DPO dataset containing preference pairs, focusing on GSM8K training examples to improve mathematical reasoning capabilities.

This notebook has been tested on a v6e-1 TPU instance, with 32 GB HBM.

In [None]:
# Install necessary libraries

# General dependencies installed from PyPI.
!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
# Tunix and Qwix are installed from the HEAD of their GitHub repositories.
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

# Replace whichever Flax version is currently installed with the one from
# GitHub HEAD.
!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

# Hugging Face tooling: `huggingface_hub` for the model snapshot download and
# `datasets` for loading the DPO preference data.
!pip install -q huggingface_hub
!pip install -q datasets

In [None]:
# Imports

import os

from datasets import concatenate_datasets
from datasets import load_dataset
from flax import nnx
import grain
from huggingface_hub import snapshot_download
import jax
import optax
from orbax import checkpoint as ocp
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma3_model_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.sft import metrics_logger
from tunix.sft.dpo.dpo_trainer import DPOTrainer
from tunix.sft.dpo.dpo_trainer import DPOTrainingConfig
from tunix.sft.utils import show_hbm_usage

In [None]:
# Hyperparameters/Config
#
# Every tunable value of the notebook lives in this cell. Each constant is
# assigned exactly once so that the value read here is the value used below.

model_id = "google/gemma-3-1b-it"
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"

# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"

# ====== LoRA ======
RANK = 32
ALPHA = 16.0

# ====== Sharding ======
# (axis_shapes, axis_names) pair, unpacked into `jax.make_mesh(*MESH)`.
# NOTE(review): a (1, 8) mesh needs 8 devices, while the introduction mentions
# a v6e-1 instance -- confirm the shape matches the devices actually available.
MESH = [(1, 8), ("fsdp", "tp")]

# ====== Sequence lengths, sampling and DPO ======
MAX_PROMPT_LENGTH = 192
MAX_RESPONSE_LENGTH = 192
TEMPERATURE = 0.7
TOP_P = 1.0
TOP_K = 50
# Passed as `beta` to DPOTrainingConfig.
BETA = 0.1

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to LEARNING_RATE over the first
# 10% of training steps, then gradually decrease it to 0 with a cosine
# schedule.
BATCH_SIZE = 1
NUM_BATCHES = 512
NUM_TEST_BATCHES = 100
EVAL_EVERY_N_STEPS = 1024

NUM_EPOCHS = 1  # can potentially train for more epochs
TRAIN_FRACTION = 1.0
MAX_STEPS = int(NUM_BATCHES * TRAIN_FRACTION * NUM_EPOCHS)

# Step counts are integers; truncate the 10% warmup to a whole number of steps.
WARMUP_STEPS = int(0.1 * MAX_STEPS)
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
# Named sampling presets; unpack one into `evaluate(...)` / `generate(...)`.
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

## Load reference model and LoRA model

### Reference Model and LoRA Model

**Reference Model:** This is the original pre-trained Gemma3-1B instruction-tuned model that serves as the base for fine-tuning. It's loaded from the Hugging Face Hub.

**LoRA Model:** This is a Low-Rank Adaptation of the reference model. LoRA is a parameter-efficient fine-tuning technique that injects small, trainable matrices into specific layers of the pre-trained model, significantly reducing the number of parameters that need to be updated during training. This makes fine-tuning much faster and requires less memory compared to fine-tuning the entire model. The LoRA model is built on top of the reference model, inheriting its pre-trained weights and capabilities, while allowing for efficient adaptation to the DPO task.

In [None]:
# Log in to the Hugging Face Hub from the command line so that the model
# download below can use your account credentials.
!huggingface-cli login

In [None]:
# Only PyTorch .pth weight files are excluded from the snapshot download.
ignore_patterns = ["*.pth"]

print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id,
    ignore_patterns=ignore_patterns,
)
print(f"Model successfully downloaded to: {local_model_path}")

In [None]:
# Baseline HBM report, taken before the model is loaded, for comparison with
# the report printed after loading.
print("\n--- HBM Usage BEFORE Model Load ---")
show_hbm_usage()

In [None]:
# The checkpoint is read from the local Hugging Face snapshot directory.
MODEL_CP_PATH = local_model_path

model_config = (
    gemma3_model_lib.ModelConfig.gemma3_1b()
)  # pick corresponding config based on model version
# MESH holds (axis_shapes, axis_names); unpack it into the mesh constructor.
mesh = jax.make_mesh(*MESH)
with mesh:
  # `gemma3` is used below as the base for the LoRA model and as the DPO
  # reference model.
  gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
      MODEL_CP_PATH, model_config, mesh
  )
  nnx.display(gemma3)

In [None]:
# HBM report after the model has been loaded; compare with the report printed
# before loading.
print("\n--- HBM Usage AFTER Model Load ---")
show_hbm_usage()

In [None]:
# Tokenizer shared by both samplers and by the DPO trainer.
gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
# Sampler over the un-adapted model, used to evaluate the reference model.
sampler = sampler_lib.Sampler(
    transformer=gemma3,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        # Prompt budget + response budget + 256 extra positions.
        cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
def get_lora_model(base_model, mesh):
  """Applies LoRA to `base_model` and shards the resulting state on `mesh`.

  Args:
    base_model: Model to adapt. Its `get_model_input()` result is forwarded
      to Qwix as keyword arguments.
    mesh: Mesh used as the context in which the sharding constraint is
      applied.

  Returns:
    The model returned by `qwix.apply_lora_to_model`, updated in place with
    its state constrained to the state's own partition specs.
  """
  provider = qwix.LoraProvider(
      # Regex alternation over module paths that receive LoRA weights, using
      # the notebook-level RANK and ALPHA.
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  example_inputs = base_model.get_model_input()
  adapted_model = qwix.apply_lora_to_model(
      base_model, provider, **example_inputs
  )

  with mesh:
    adapted_state = nnx.state(adapted_model)
    partition_specs = nnx.get_partition_spec(adapted_state)
    nnx.update(
        adapted_model,
        jax.lax.with_sharding_constraint(adapted_state, partition_specs),
    )

  return adapted_model

In [None]:
# Policy model
# The LoRA-adapted model is the one passed to the DPO trainer as `model`; the
# un-adapted `gemma3` is passed to the trainer as `ref_model`.
lora_gemma = get_lora_model(gemma3, mesh=mesh)
nnx.display(lora_gemma)

## Load evaluation data and evaluate the reference model

In [None]:
# Chat-turn prompt template; `{question}` is replaced with the user's question.
TEMPLATE = """<start_of_turn>user
{question}<end_of_turn>
<start_of_turn>model"""


def generate(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Generates a completion for one question or for a batch of questions.

  Args:
    question: Either a single question string or an iterable of question
      strings. Each question is wrapped in `TEMPLATE` before sampling.
    sampler: Callable invoked with `input_strings`, `max_generation_steps`,
      `temperature`, `top_k`, `top_p`, `echo` and `seed` keyword arguments;
      its result must expose the generated strings as `.text`.
    temperature: Sampling temperature forwarded to `sampler`.
    top_k: Top-k value forwarded to `sampler`.
    top_p: Top-p value forwarded to `sampler`.
    seed: Seed forwarded to `sampler` unchanged (may be None).

  Returns:
    A single generated string when `question` is a string, otherwise the
    sampler's `.text` sequence with one entry per question.
  """
  is_single = isinstance(question, str)
  questions = [question] if is_single else question
  prompts = [TEMPLATE.format(question=q) for q in questions]

  sampled = sampler(
      input_strings=prompts,
      max_generation_steps=MAX_RESPONSE_LENGTH,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
  )

  completions = sampled.text
  return completions[0] if is_single else completions

In [None]:
def _answer_in_response(answer, response):
  """Returns True if `answer` occurs in `response`, as-is or with commas removed."""
  if answer.strip() in response.strip():
    return True
  return answer.replace(",", "").strip() in response.replace(",", "").strip()


def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes substring-match accuracy of `sampler` over `dataset`.

  For every batch, `generate` is called `num_passes` times (with seeds
  0..num_passes-1). A question counts as correct as soon as any one of its
  responses contains the reference answer as a substring, either verbatim or
  after commas are removed from both strings. This is a lenient check: the
  answer may appear anywhere in the response.

  Args:
    dataset: Iterable of batches; each batch maps "question" and "answer" to
      equally long sequences.
    sampler: Sampler forwarded to `generate`.
    temperature: Sampling temperature forwarded to `generate`.
    top_k: Top-k value forwarded to `generate`.
    top_p: Top-p value forwarded to `generate`.
    num_passes: Number of generations attempted per question.
    corr_lst: With `make_lst`, selects which questions are collected: the
      correctly answered ones if True, the incorrectly answered ones if False.
    make_lst: If True, additionally return the collected
      `(question, answer, responses)` tuples.

  Returns:
    `(corr, total, accuracy_percent)`, or `((corr, total, accuracy_percent),
    response_lst)` when `make_lst` is True. The accuracy is 0.0 when the
    dataset yields no questions.
  """

  response_lst = []
  corr = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    # One list of responses per question, filled pass by pass.
    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      is_correct = False

      for response in multiple_call_response:
        # Simple Accuracy: check for answer anywhere in the full response
        try:
          is_correct = _answer_in_response(answer, response)
        except (AttributeError, TypeError):
          # A missing answer (None) or a non-string value cannot be compared;
          # the sample stays incorrect. Only these errors are caught so that
          # interrupts and unrelated failures still propagate.
          print("SKIPPED accuracy check")

        if is_correct:
          break

      if is_correct:
        corr += 1

      # Collect either the correct or the incorrect samples, per `corr_lst`.
      if make_lst and bool(corr_lst) == is_correct:
        response_lst.append((question, answer, multiple_call_response))

      total += 1
      if total % 10 == 0:
        print(f"===> {corr=}, {total=}, {corr / total * 100=}")

  # Guard the empty-dataset case, which would otherwise divide by zero.
  accuracy = corr / total * 100 if total else 0.0
  to_return = (corr, total, accuracy)
  if make_lst:
    return to_return, response_lst
  return to_return

In [None]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def get_dataset(data_dir, split="train") -> grain.MapDataset:
  """Loads a GSM8K split via TFDS and maps it to prompt/question/answer dicts.

  Args:
    data_dir: Directory used by TFDS for the dataset; created if it does not
      exist yet.
    split: Name of the GSM8K split to load.

  Returns:
    A shuffled (seed 42) `grain.MapDataset` whose elements are dicts with
    "prompts" (the question wrapped in `TEMPLATE`), "question" (the raw
    question text) and "answer" (the text after "####" in the reference
    solution, or None if the marker is absent).
  """
  # Make sure the download target exists before TFDS writes into it.
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  source = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

  def _to_example(record):
    # The records hold bytes; decode them before formatting.
    return {
        # Templated prompt text.
        "prompts": TEMPLATE.format(
            question=record["question"].decode("utf-8"),
        ),
        # Raw question; read by `evaluate`.
        "question": record["question"].decode("utf-8"),
        # Reference answer; read by `evaluate`.
        "answer": extract_hash_answer(record["answer"].decode("utf-8")),
    }

  return grain.MapDataset.source(source).shuffle(seed=42).map(_to_example)


# Batch the GSM8K test split and keep only the first NUM_TEST_BATCHES batches
# for evaluation.
test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

# Display the number of evaluation batches.
len(test_dataset)

In [None]:
# Evaluate
# After evaluating the reference model on the GSM8K test dataset, we achieved an accuracy of around 65%.
# With num_passes=5, a question counts as correct if any of its 5 sampled
# responses contains the reference answer. make_lst=True (with the default
# corr_lst=False) also returns the questions that were answered incorrectly.

(corr, total, accuracy), responses = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["standard"],
    make_lst=True,
    num_passes=5,
)
print(f"{corr=}, {total=}, {accuracy=}%")

## DPO Dataset Preparation

The DPO training dataset is loaded from the "argilla/distilabel-intel-orca-dpo-pairs" dataset on the Hugging Face Hub. This dataset contains preference pairs (chosen and rejected responses) for various prompts.

To improve the model's performance on mathematical reasoning tasks, we prioritize samples from the GSM8K training set by filtering the dataset for records where `in_gsm8k_train` is True.

Since the number of GSM8K training samples might be less than the `NUM_BATCHES * BATCH_SIZE` samples needed for training, we add enough random samples from the rest of the dataset to reach that target. This ensures we have enough data for training while giving more weight to the GSM8K examples, and also helps improve the model's performance on general use cases.

In [None]:
def get_dataset() -> grain.MapDataset:
  """Builds the DPO preference dataset, prioritising GSM8K training samples.

  All records of "argilla/distilabel-intel-orca-dpo-pairs" flagged with
  `in_gsm8k_train` are kept. If they number fewer than
  `NUM_BATCHES * BATCH_SIZE`, the remainder is filled with randomly chosen
  (seed 42) records from the rest of the dataset, so that no GSM8K record is
  included twice.

  NOTE(review): this function reuses the name `get_dataset` of the GSM8K
  loader defined earlier in the notebook and shadows it once this cell has
  run. Re-run the earlier definition before rebuilding `test_dataset`;
  consider giving the two loaders distinct names.

  Returns:
    A `grain.MapDataset` whose elements are dicts with "prompts" (the
    record's `input` wrapped in `TEMPLATE`), "chosen_responses" and
    "rejected_responses".
  """
  dpo_dataset = load_dataset(
      "argilla/distilabel-intel-orca-dpo-pairs", split="train"
  )
  gsm8k_train_dpo_dataset = dpo_dataset.filter(lambda x: x["in_gsm8k_train"])

  # Get the number of samples in the filtered dataset
  num_gsm8k_train_samples = len(gsm8k_train_dpo_dataset)
  print(
      f"Number of samples with in_gsm8k_train=True: {num_gsm8k_train_samples}"
  )

  # Calculate how many more samples are needed (never negative).
  total_samples_needed = NUM_BATCHES * BATCH_SIZE
  samples_to_add = max(0, total_samples_needed - num_gsm8k_train_samples)
  print(f"Number of additional random samples needed: {samples_to_add}")

  if samples_to_add > 0:
    # Draw the filler samples only from records that are NOT already in the
    # GSM8K subset; sampling from the full dataset could duplicate them.
    other_dpo_dataset = dpo_dataset.filter(lambda x: not x["in_gsm8k_train"])
    # Ensure we don't sample more than the total available in that remainder.
    random_samples = other_dpo_dataset.shuffle(seed=42).select(
        range(min(samples_to_add, len(other_dpo_dataset)))
    )
    print(f"Number of random samples selected: {len(random_samples)}")

    # Combine the filtered dataset and the random samples
    combined_dpo_dataset = concatenate_datasets(
        [gsm8k_train_dpo_dataset, random_samples]
    )
  else:
    combined_dpo_dataset = gsm8k_train_dpo_dataset

  print(f"Total samples in the combined dataset: {len(combined_dpo_dataset)}")

  dataset = grain.MapDataset.source(combined_dpo_dataset).map(
      lambda x: {
          "prompts": TEMPLATE.format(question=x["input"]),
          "chosen_responses": x["chosen"],
          "rejected_responses": x["rejected"],
      }
  )
  return dataset


# Batch the DPO data and cap it at NUM_BATCHES batches.
dataset = get_dataset().batch(BATCH_SIZE)[:NUM_BATCHES]

# Optionally train on only the leading TRAIN_FRACTION of the batches; in both
# branches the selected batches are repeated NUM_EPOCHS times.
if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

# Display the total number of training batches.
len(train_dataset)

## Define optimizer and DPO Trainer

In [None]:
# Ckpt saving
# Save interval and retention count come from the config cell.
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger
# The log directory is the same path that the TensorBoard cell points at.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/dpo", flush_every_n_steps=20
)

In [None]:
# Logs
# Start TensorBoard inline on the directory the metrics logger writes to.
%load_ext tensorboard
%tensorboard --logdir /tmp/content/tmp/tensorboard/dpo --port=0

In [None]:
# Optimizer, learning rate scheduler, gradient clipping
# AdamW driven by a schedule that starts at 0, peaks at LEARNING_RATE after
# WARMUP_STEPS, and ends at 0; `decay_steps` is set to MAX_STEPS.
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
# Global-norm clipping is listed first in the chain, ahead of the AdamW
# transformation; set MAX_GRAD_NORM to None to disable it.
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )

In [None]:
# Configure DPO Training
# All values come from the config cell and the options objects built above.
dpo_config = DPOTrainingConfig(
    beta=BETA,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_response_length=MAX_RESPONSE_LENGTH,
    metrics_logging_options=metrics_logging_options,
    checkpoint_root_directory=CKPT_DIR,
    checkpointing_options=checkpointing_options,
)

# Display the resulting configuration.
dpo_config

In [None]:
# The LoRA model is the trained policy (`model`); the un-adapted `gemma3` is
# passed as the reference model (`ref_model`).
dpo_trainer = DPOTrainer(
    model=lora_gemma,
    ref_model=gemma3,
    optimizer=optimizer,
    training_config=dpo_config,
    tokenizer=gemma_tokenizer,
)

## Train and evaluate LoRA model

In [None]:
# HBM report taken right before training starts.
show_hbm_usage()

In [None]:
# Run DPO training, inside the mesh context whenever a mesh is defined.
# NOTE(review): `mesh` is assigned from `jax.make_mesh` in the model-loading
# cell, so the `mesh is None` branch looks like a fallback for a setup without
# a mesh -- confirm it is still needed.
if mesh is None:
  dpo_trainer.train(train_dataset)
else:
  with mesh:
    dpo_trainer.train(train_dataset)

In [None]:
# Sampler over the LoRA model, built with the same cache configuration as the
# reference-model sampler, for the post-training evaluation.
lora_sampler = sampler_lib.Sampler(
    transformer=lora_gemma,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        # Prompt budget + response budget + 256 extra positions.
        cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
# Evaluate
# After evaluating the finetuned model on the GSM8K test dataset, we achieved an accuracy of around 70%.
# Same test batches, sampling preset and num_passes as the reference-model
# evaluation. With num_passes=5, a question counts as correct if any of its 5
# sampled responses contains the reference answer.

(corr, total, accuracy), responses = evaluate(
    test_dataset,
    lora_sampler,
    **GENERATION_CONFIGS["standard"],
    make_lst=True,
    num_passes=5,
)
print(f"{corr=}, {total=}, {accuracy=}%")

In [None]:
# Release the Colab runtime once the notebook has finished. The import is
# guarded so that "Run All" also completes outside Colab, where the
# `google.colab` package is not available.
try:
  from google.colab import runtime
except ImportError:
  runtime = None  # Not running in Colab; there is nothing to release.

if runtime is not None:
  runtime.unassign()