In [1]:
!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q huggingface_hub
!pip install -q datasets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m620.7/620.7 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.5/57.5 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.5/24.5 MB[0m [31m131.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m123.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m147.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.5/72.5 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import functools
import gc
import os
from pprint import pprint
import re
import time

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm

from tunix.examples.data import translation_dataset as data_lib
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma3 import params as params_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib

from tunix.sft import metrics_logger
from datasets import load_dataset
from tunix.sft.dpo.dpo_trainer import DpoTrainingConfig
from tunix.sft.dpo.dpo_trainer import DpoTrainer
from tunix.sft.dpo.dpo_trainer import TrainingInput
from huggingface_hub import snapshot_download
from tunix.sft.dpo.dpo_trainer import _generate_ids_and_masks
from tunix.models.gemma3 import model as gemma3_model_lib




In [3]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
# Fraction of the training data to use.
TRAIN_FRACTION = 1.0

INTERMEDIATE_CKPT_DIR = "/content/intermediate_ckpt/"

# ====== LoRA ======
RANK = 8
ALPHA = 16.0

# ====== Sharding ======
# (mesh shape, axis names), unpacked as `jax.make_mesh(*MESH)`.
MESH = [(1, 1), ("fsdp", "tp")]

# ====== Sequence lengths ======
# Prompt token budget (also used when tokenizing DPO prompts).
MAX_PROMPT_LENGTH = 192
# Max new tokens per sampled response; MAX_PROMPT_LENGTH + this is also the
# token budget for DPO chosen/rejected completions.
TOTAL_GENERATION_STEPS = 192

# ====== Settings carried over from the GRPO recipe ======
# NOTE(review): TEMPERATURE, TOP_P, TOP_K, NUM_GENERATIONS and EPSILON are not
# referenced by the DPO cells of this notebook; they are kept only so any other
# code that reads them keeps working.
# Sampling parameters for rollouts (a high-ish temperature gives diverse
# responses).
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# Responses sampled per prompt (`G` in Algorithm 1 of the GRPO paper).
NUM_GENERATIONS = 2
# PPO-style clipping epsilon (𝜀 in the GRPO loss).
EPSILON = 0.2
# Optimisation passes per batch (𝜇 in GRPO algorithm 1); a factor of MAX_STEPS.
NUM_ITERATIONS = 1

# ====== DPO ======
# Passed as `beta` to `DpoTrainingConfig`.
BETA = 0.1
# Trainer evaluation cadence in steps (this doesn't matter if
# `TRAIN_FRACTION = 1.0`).
EVAL_EVERY_N_STEPS = 10

# ====== AdamW ======
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

# ====== Steps ======
BATCH_SIZE = 1
# NOTE(review): only 100 DPO preference pairs are tokenized for training later
# in the notebook, so MAX_STEPS derived from this overshoots the data.
NUM_BATCHES = 3738
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. The GSM8K test
# split has 1319 examples, i.e. at most ceil(1319 / BATCH_SIZE) batches.
NUM_TEST_BATCHES = 100
NUM_EPOCHS = 1  # can potentially train for more epochs
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to LEARNING_RATE over the first
# 10% of training steps, then cosine-decay it to 0 by MAX_STEPS. Kept an int
# because optax's schedule API takes integer step counts.
WARMUP_STEPS = int(0.1 * MAX_STEPS)

# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this important to keep the
# KL divergence in check.
MAX_GRAD_NORM = 0.1

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

In [4]:
# Interactive Hugging Face login: prompts for an access token and stores it
# locally (see the output below).
# NOTE(review): presumably needed because the Gemma checkpoints on the Hub are
# gated — confirm. In Colab, an `HF_TOKEN` secret is a non-interactive
# alternative, as the download cell's output suggests.
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) y
Token is valid (permission: fineGrained).
The token `amz` has been saved to /root/.cache/huggingface/stored_tokens
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenti

In [5]:
# Fetch the instruction-tuned Gemma 3 1B checkpoint from the Hugging Face Hub,
# skipping PyTorch `.pth` weight files (the model is built from safetensors).
model_id = "google/gemma-3-1b-it"
ignore_patterns = ["*.pth"]

print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(repo_id=model_id, ignore_patterns=ignore_patterns)
print(f"Model successfully downloaded to: {local_model_path}")

Downloading google/gemma-3-1b-it from Hugging Face...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

README.md:   0%|          | 0.00/24.3k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/899 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

Model successfully downloaded to: /root/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752


In [6]:
def show_hbm_usage():
  """Prints memory usage for every local JAX device.

  Per-HBM counters (`device:0:HBM0:bytes_in_use` / `device:0:HBM0:bytes_limit`)
  are used when present with a non-zero limit. Otherwise the generic
  `bytes_in_use` / `bytes_limit` counters are shown. A device whose
  `memory_stats()` returns None is reported as unavailable instead of raising
  AttributeError.
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  print("\n--- TPU HBM Usage ---")
  for i, d in enumerate(jax.local_devices()):
    stats = d.memory_stats()
    if stats is None:
      # memory_stats() may return None on backends without allocator stats.
      print(f"Device {i} ({d.device_kind}): memory stats unavailable")
      continue

    used = stats.get("bytes_in_use", 0)
    limit = stats.get("bytes_limit", 0)

    # NOTE(review): these keys are hard-coded to device 0 for every device i;
    # confirm whether they should be device-specific on multi-chip hosts.
    hbm_used = stats.get("device:0:HBM0:bytes_in_use", used)
    hbm_limit = stats.get("device:0:HBM0:bytes_limit", limit)

    # Fallback if specific HBM stats report a zero limit.
    if hbm_limit == 0:
      hbm_used = used
      hbm_limit = limit

    percentage = (hbm_used / hbm_limit * 100) if hbm_limit > 0 else 0

    print(
        f"Device {i} ({d.device_kind}): Using {fmt_size(hbm_used)} /"
        f" {fmt_size(hbm_limit)} ({percentage:.2f}%)"
    )

  print("--- End HBM Usage ---")

In [7]:
# Baseline accelerator memory before any model weights are loaded.
print("\n--- HBM Usage BEFORE Model Load ---")
show_hbm_usage()


--- HBM Usage BEFORE Model Load ---

--- TPU HBM Usage ---
Device 0 (TPU v6 lite): Using 32.1 KiB / 31.2 GiB (0.00%)
--- End HBM Usage ---


In [8]:
MODEL_CP_PATH = local_model_path

# Pick the config matching the downloaded checkpoint (gemma-3-1b-it -> 1B).
model_config = gemma3_model_lib.Gemma3Config.gemma3_1b()

# MESH comes from the configuration cell; it is intentionally not redefined
# here so the sharding layout has a single source of truth.
mesh = jax.make_mesh(*MESH)
with mesh:
  # Build the Gemma 3 model from the downloaded safetensors files, passing the
  # device mesh to the loader.
  gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
      MODEL_CP_PATH, model_config, mesh
  )
  nnx.display(gemma3)

In [9]:
# Memory after loading the base model, for comparison with the baseline.
print("\n--- HBM Usage AFTER Model Load ---")
show_hbm_usage()


--- HBM Usage AFTER Model Load ---

--- TPU HBM Usage ---
Device 0 (TPU v6 lite): Using 1.9 GiB / 31.2 GiB (5.97%)
--- End HBM Usage ---


In [10]:
# Log in to Kaggle unless both credentials are already provided through the
# KAGGLE_USERNAME / KAGGLE_KEY environment variables.
# NOTE(review): no later cell shown here downloads from Kaggle (the model comes
# from the Hugging Face Hub), so this login may be removable.
if not all(key in os.environ for key in ("KAGGLE_USERNAME", "KAGGLE_KEY")):
  kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


In [12]:
def get_lora_model(base_model, mesh, rank=None, alpha=None, module_path=None):
  """Applies LoRA adapters to `base_model` and re-shards its state on `mesh`.

  Args:
    base_model: Model to adapt. Its `get_model_input()` result is forwarded as
      keyword arguments to `qwix.apply_lora_to_model`.
    mesh: Device mesh used as the context for the sharding constraint.
    rank: LoRA rank. Defaults to the RANK config constant (read at call time).
    alpha: LoRA alpha. Defaults to the ALPHA config constant (read at call
      time).
    module_path: Regex of module paths that receive adapters. Defaults to the
      q/kv/attn_vec einsums and the gate/up/down projections.

  Returns:
    The LoRA-wrapped model.
  """
  if rank is None:
    rank = RANK
  if alpha is None:
    alpha = ALPHA
  if module_path is None:
    module_path = (
        ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
        ".*attn_vec_einsum"
    )

  # For a quantised base model, this cell previously carried (commented out)
  # `weight_qtype="nf4"` and `tile_size=4` as extra LoraProvider arguments.
  lora_provider = qwix.LoraProvider(
      module_path=module_path,
      rank=rank,
      alpha=alpha,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with mesh:
    # Constrain the model state to the partition specs recorded on it, then
    # write the constrained state back into the model.
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model

In [13]:
# Policy model: `gemma3` wrapped with LoRA adapters. The DPO trainer cell trains
# this model and passes the unwrapped `gemma3` as `ref_model`.
lora_gemma = get_lora_model(gemma3, mesh=mesh)
nnx.display(lora_gemma)

In [14]:
# First 100 examples of the GSM8K test split from the Hugging Face Hub.
# NOTE(review): `eval_dataset` is not referenced by any later cell shown here;
# the baseline evaluation uses the TFDS-based `test_dataset` instead.
eval_dataset = load_dataset("gsm8k", "main", split="test")
eval_dataset = eval_dataset.select(range(100))

README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [15]:
# Tags delimiting the model's reasoning and its final answer.
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

# Instruction asking for tagged reasoning followed by one numeric answer.
SYSTEM_PROMPT = (
    "You are given a problem. Think about the problem and provide your "
    f"reasoning. Place it between {reasoning_start} and {reasoning_end}. "
    "Then, provide the final answer (i.e., just one numerical value) "
    f"between {solution_start} and {solution_end}."
)

# Gemma chat turn layout; `{system_prompt}` and `{question}` are filled in with
# str.format.
TEMPLATE = (
    "<start_of_turn>user\n"
    "{system_prompt}\n"
    "\n"
    "{question}<end_of_turn>\n"
    "<start_of_turn>model"
)

def generate(
    question,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    seed=None,
    max_generation_steps=None,
):
  """Formats question(s) with the chat template and samples completions.

  Args:
    question: A single question string, or an iterable of question strings.
    sampler: Callable that accepts the keyword arguments passed below and
      returns an object whose `.text` holds one completion per input prompt.
    temperature: Sampling temperature forwarded to `sampler`.
    top_k: Top-k value forwarded to `sampler`.
    top_p: Nucleus (top-p) value forwarded to `sampler`.
    seed: Optional sampling seed, forwarded unchanged.
    max_generation_steps: Generation budget; defaults to TOTAL_GENERATION_STEPS
      (read at call time) when None.

  Returns:
    The completion string when `question` is a str. Otherwise the `.text`
    sequence returned by the sampler, in input order. `echo=False` is passed.
  """
  single = isinstance(question, str)
  questions = [question] if single else question
  input_batch = [
      TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
      for q in questions
  ]
  if max_generation_steps is None:
    max_generation_steps = TOTAL_GENERATION_STEPS

  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=max_generation_steps,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
  )

  output = out_data.text
  return output[0] if single else output

In [16]:
# Full-format check: optional leading whitespace, a non-empty reasoning block,
# then an answer block whose contents are captured, then optional trailing
# whitespace. MULTILINE lets ^/$ anchor at line boundaries; DOTALL lets `.`
# span newlines inside the tagged blocks.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?"
    rf"{re.escape(solution_start)}(.+?){re.escape(solution_end)}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

# Sanity check: a minimal well-formed response must match (previously the
# result of this search was computed and silently discarded).
assert match_format.search(
    f"{reasoning_start}Let me"
    f" think!{reasoning_end}{solution_start}2{solution_end}",
) is not None

# Number extraction: the first run of digits/dots after the opening answer tag.
# Note a lone "." also matches; callers must handle float() failures.
match_numbers = re.compile(
    rf"{re.escape(solution_start)}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

['0.34']

In [17]:
def _to_float(text):
  """Returns `text` stripped and parsed as a float, or None if that fails.

  None is also returned when `text` is None (e.g. a GSM8K answer without a
  "####" marker) or has no `.strip()` method.
  """
  if text is None:
    return None
  try:
    return float(text.strip())
  except (AttributeError, ValueError):
    return None


def _grade_numeric(extracted, expected):
  """Returns (exact_match, within_10_percent) for two parsed numbers.

  The ratio test is undefined for a zero reference answer. In that case
  "within 10 percent" falls back to an exact match instead of raising
  ZeroDivisionError.
  """
  exact = extracted == expected
  if expected == 0:
    return exact, exact
  ratio = extracted / expected
  return exact, 0.9 <= ratio <= 1.1


def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format.

  Args:
    dataset: Iterable of batches. Each batch maps "question" and "answer" to
      equal-length sequences of strings.
    sampler: Passed through to `generate`.
    temperature: Sampling temperature passed to `generate`.
    top_k: Top-k value passed to `generate`.
    top_p: Top-p value passed to `generate`.
    num_passes: Samples drawn per question (pass p uses seed=p). A question
      counts as correct / partially correct / well formatted if any sample is.
    corr_lst: When `make_lst` is set, collect correctly answered questions if
      True, or incorrectly answered ones if False.
    make_lst: If True, also return the collected
      (question, answer, responses) tuples.

  Returns:
    (corr, total, accuracy %, partial accuracy %, format accuracy %), plus the
    collected list when `make_lst` is True. Percentages are 0.0 for an empty
    dataset.
  """
  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      expected = _to_float(answer)
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        # check answer; "-1000000" is the sentinel used when no number follows
        # the answer tag.
        guess = match_numbers.search(response)
        extracted = _to_float(
            guess.group(1) if guess is not None else "-1000000"
        )
        if extracted is None or expected is None:
          # e.g. a lone "." captured by match_numbers, or a missing reference.
          print("SKIPPED")
        else:
          exact, partial = _grade_numeric(extracted, expected)
          if exact:
            corr_ctr_per_question += 1
          if partial:
            partially_corr_per_question += 1

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        # All three criteria already met; more samples cannot change the
        # outcome for this question.
        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      elif not corr_lst and make_lst:
        response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  if total:
    to_return = (
        corr,
        total,
        corr / total * 100,
        partially_corr / total * 100,
        corr_format / total * 100,
    )
  else:
    # Empty dataset: report 0% instead of raising ZeroDivisionError.
    to_return = (corr, total, 0.0, 0.0, 0.0)
  if make_lst:
    return to_return, response_lst
  return to_return

In [18]:
import sentencepiece as spm
from etils import epath
import numpy as np
class Gemma3Tokenizer(spm.SentencePieceProcessor):
  """SentencePiece processor loaded with the Gemma 3 tokenizer model.

  Adds a `tokenize` helper that wraps text with an optional prefix/suffix and
  BOS/EOS ids.

  NOTE(review): the Hugging Face snapshot downloaded earlier also contains a
  `tokenizer.model`. Passing that path would avoid the GCS read; confirm it is
  the same tokenizer before switching.
  """

  _GEMMA3_TOKENIZER_PATH: epath.PathLike = (
      'gs://gemma-data/tokenizers/tokenizer_gemma3.model'
  )

  def __init__(self, model_path: str = _GEMMA3_TOKENIZER_PATH):
    # The serialized model is read before the base class is initialised.
    serialized = epath.Path(model_path).read_bytes()
    super().__init__()
    self.LoadFromSerializedProto(serialized)

  def tokenize(
      self,
      example: str,
      prefix: str = "",
      suffix: str = "",
      add_eos: bool = True,
  ) -> np.ndarray:
    """Encodes `prefix + example + suffix` into an int32 id array.

    Args:
      example: Input string to tokenize.
      prefix: Text prepended to `example` before encoding.
      suffix: Text appended to `example` before encoding.
      add_eos: If True, the EOS id is appended after the encoded text.

    Returns:
      The BOS id, then the encoded text ids, then (optionally) the EOS id, as
      an int32 numpy array.
    """
    ids = [self.bos_id(), *self.EncodeAsIds(prefix + example + suffix)]
    if add_eos:
      ids += [self.eos_id()]
    return np.asarray(ids, dtype=np.int32)

In [19]:
gemma_tokenizer = Gemma3Tokenizer()

# Sampler over the base (non-LoRA) `gemma3`, used for the baseline evaluation.
# The cache is sized for the prompt budget plus the generation budget, with
# 256 extra positions of headroom.
sampler = sampler_lib.Sampler(
    transformer=gemma3,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
    ),
)

In [20]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()

def get_dataset(data_dir, split="train", seed=42) -> grain.MapDataset:
  """Loads a GSM8K split via TFDS and maps it to prompt/question/answer dicts.

  Args:
    data_dir: Directory where TFDS stores (and downloads if needed) the data.
      Created if it does not exist.
    split: TFDS split name, e.g. "train" or "test".
    seed: Shuffle seed (default 42, the previously hard-coded value).

  Returns:
    A shuffled `grain.MapDataset` of dicts with keys "prompts" (chat-formatted
    model input), "question" (raw question text) and "answer" (as returned by
    `extract_hash_answer`, possibly None).
  """
  # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
  os.makedirs(data_dir, exist_ok=True)

  data = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

  def _to_example(record):
    # Records hold UTF-8 bytes; decode the question once and reuse it.
    question = record["question"].decode("utf-8")
    return {
        # passed to model forward pass
        "prompts": TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question,
        ),
        # passed to reward functions
        "question": question,
        # passed to reward functions
        "answer": extract_hash_answer(record["answer"].decode("utf-8")),
    }

  return grain.MapDataset.source(data).shuffle(seed=seed).map(_to_example)

# GSM8K test split in batches of BATCH_SIZE, truncated to the first
# NUM_TEST_BATCHES batches so the baseline evaluation stays quick.
test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(BATCH_SIZE)
test_dataset = test_dataset[:NUM_TEST_BATCHES]

len(test_dataset)



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to data/test/gsm8k/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/4 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling data/test/gsm8k/incomplete.3YNLS8_1.0.0/gsm8k-train.array_record*...:   0%|          | 0/7473 [00:00…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling data/test/gsm8k/incomplete.3YNLS8_1.0.0/gsm8k-test.array_record*...:   0%|          | 0/1319 [00:00<…

Generating train_socratic examples...: 0 examples [00:00, ? examples/s]

Shuffling data/test/gsm8k/incomplete.3YNLS8_1.0.0/gsm8k-train_socratic.array_record*...:   0%|          | 0/74…

Generating test_socratic examples...: 0 examples [00:00, ? examples/s]

Shuffling data/test/gsm8k/incomplete.3YNLS8_1.0.0/gsm8k-test_socratic.array_record*...:   0%|          | 0/131…

Dataset gsm8k downloaded and prepared to data/test/gsm8k/1.0.0. Subsequent calls will reuse this data.


100

In [None]:
# Baseline accuracy of the sampler over the base model, using the "standard"
# sampling preset. With make_lst=True (and corr_lst left False), `responses`
# collects the incorrectly answered questions.
(corr, total, accuracy, partial_accuracy, format_accuracy), responses = evaluate(
    test_dataset, sampler, make_lst=True, **GENERATION_CONFIGS["standard"]
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
)

  0%|          | 0/100 [00:00<?, ?it/s]

SKIPPED
===> corr=3, total=10, corr / total * 100=30.0, partially_corr / total * 100=30.0, corr_format / total * 100=0.0
===> corr=7, total=20, corr / total * 100=35.0, partially_corr / total * 100=40.0, corr_format / total * 100=0.0
===> corr=11, total=30, corr / total * 100=36.666666666666664, partially_corr / total * 100=40.0, corr_format / total * 100=0.0
===> corr=11, total=40, corr / total * 100=27.500000000000004, partially_corr / total * 100=30.0, corr_format / total * 100=0.0
===> corr=12, total=50, corr / total * 100=24.0, partially_corr / total * 100=30.0, corr_format / total * 100=0.0
===> corr=15, total=60, corr / total * 100=25.0, partially_corr / total * 100=31.666666666666664, corr_format / total * 100=0.0
===> corr=23, total=70, corr / total * 100=32.857142857142854, partially_corr / total * 100=40.0, corr_format / total * 100=0.0
===> corr=25, total=80, corr / total * 100=31.25, partially_corr / total * 100=37.5, corr_format / total * 100=0.0
===> corr=27, total=90, c

In [21]:
# Preference pairs for DPO. Each row has "system", "input" (prompt), "chosen"
# and "rejected" fields, as the sample row displayed below shows.
dpo_dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
dpo_dataset[0]

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/79.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12859 [00:00<?, ? examples/s]

{'system': '',
 'input': "You will be given a definition of a task first, then some input of the task.\nThis task is about using the specified sentence and converting the sentence to Resource Description Framework (RDF) triplets of the form (subject, predicate object). The RDF triplets generated must be such that the triplets accurately capture the structure and semantics of the input sentence. The input is a sentence and the output is a list of triplets of the form [subject, predicate, object] that capture the relationships present in the sentence. When a sentence has more than 1 RDF triplet possible, the output must contain all of them.\n\nAFC Ajax (amateurs)'s ground is Sportpark De Toekomst where Ajax Youth Academy also play.\nOutput:",
 'chosen': '[\n  ["AFC Ajax (amateurs)", "has ground", "Sportpark De Toekomst"],\n  ["Ajax Youth Academy", "plays at", "Sportpark De Toekomst"]\n]',
 'rejected': " Sure, I'd be happy to help! Here are the RDF triplets for the input sentence:\n\n[AFC

In [22]:
# Optimizer: AdamW with a linear-warmup + cosine-decay learning-rate schedule,
# optionally preceded by global-norm gradient clipping.
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        # optax takes integer step counts; WARMUP_STEPS may be a float since it
        # is derived as 0.1 * MAX_STEPS.
        warmup_steps=int(WARMUP_STEPS),
        decay_steps=int(MAX_STEPS),
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
if MAX_GRAD_NORM is not None:
  # optax.chain applies transforms in order, so gradients are clipped before
  # they reach AdamW.
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )

In [23]:
# Orbax checkpointing (redundant: `ocp` is already imported in the imports cell; re-importing is harmless).
import orbax.checkpoint as ocp

# DPO trainer configuration built from the constants in the config cell.
# NOTE(review): max_steps=MAX_STEPS (3738 with the current constants) is far
# larger than the 100 preference pairs tokenized for training — confirm intent.
dpo_config = DpoTrainingConfig(
    max_steps=MAX_STEPS,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    beta=BETA,
)

In [24]:
# Display the resolved config, including the trainer's default fields.
dpo_config

DpoTrainingConfig(eval_every_n_steps=10, max_steps=3738, gradient_accumulation_steps=None, checkpoint_root_directory=None, checkpointing_options=None, metrics_logging_options=None, profiler_options=None, data_sharding_axis=('fsdp',), max_inflight_computations=2, beta=0.1, label_smoothing=0.0, padding_value=0)

In [25]:
# `lora_gemma` is the model being trained; the base `gemma3` is passed as the
# reference model.
# NOTE(review): confirm `qwix.apply_lora_to_model` does not modify `gemma3` in
# place; otherwise the reference model would carry the adapters too.
dpo_trainer = DpoTrainer(
    model=lora_gemma,
    ref_model=gemma3,
    training_config=dpo_config,
    optimizer=optimizer,
)

In [29]:
def _tokenize_dpo_example(
    example,
    tokenizer,
    max_prompt_length,
    completion_length,
    prompt_key,
    chosen_key,
    rejected_key,
):
  """Tokenizes one preference pair into prompt/chosen/rejected ids and masks.

  The prompt uses `left_pad=True` and the completions use `left_pad=False`.
  The result is a dict keyed by the `TrainingInput` field names.
  """
  prompt_ids, prompt_mask = _generate_ids_and_masks(
      [example[prompt_key]], tokenizer, max_prompt_length, left_pad=True
  )
  chosen_ids, chosen_mask = _generate_ids_and_masks(
      [example[chosen_key]], tokenizer, completion_length, left_pad=False
  )
  rejected_ids, rejected_mask = _generate_ids_and_masks(
      [example[rejected_key]], tokenizer, completion_length, left_pad=False
  )
  return {
      "prompt_ids": prompt_ids,
      "prompt_mask": prompt_mask,
      "chosen_ids": chosen_ids,
      "chosen_mask": chosen_mask,
      "rejected_ids": rejected_ids,
      "rejected_mask": rejected_mask,
  }


def process_dpo_dataset(
    dataset,
    tokenizer,
    max_prompt_length,
    total_generation_steps,
    batch_size,
    prompt_key="input",
    chosen_key="chosen",
    rejected_key="rejected",
):
  """Tokenizes a preference dataset into a list of batched `TrainingInput`s.

  Args:
    dataset: Hugging Face `Dataset` with prompt/chosen/rejected text columns.
    tokenizer: Tokenizer passed to tunix's `_generate_ids_and_masks`.
    max_prompt_length: Token budget for prompts.
    total_generation_steps: Added to `max_prompt_length` to give the token
      budget for chosen/rejected completions.
    batch_size: Number of examples per returned `TrainingInput`; the last
      batch may be smaller.
    prompt_key: Column holding the prompt text (default "input").
    chosen_key: Column holding the preferred completion (default "chosen").
    rejected_key: Column holding the dispreferred completion (default
      "rejected").

  Returns:
    List of `TrainingInput`, one per batch. Each field is the per-example
    arrays stacked with `np.array` along a new leading axis.

  NOTE(review): prompts are tokenized as raw text, without the chat TEMPLATE
  used for evaluation, and the dataset's "system" column is ignored — confirm
  this is intended. `_generate_ids_and_masks` is a private tunix helper and may
  change without notice.
  """
  completion_length = max_prompt_length + total_generation_steps
  processed_batches = []
  for start in tqdm(range(0, len(dataset), batch_size)):
    stop = min(start + batch_size, len(dataset))
    processed_examples = [
        _tokenize_dpo_example(
            example,
            tokenizer,
            max_prompt_length,
            completion_length,
            prompt_key,
            chosen_key,
            rejected_key,
        )
        for example in dataset.select(range(start, stop))
    ]

    # Stack each field across the batch; the leading axis indexes examples.
    fields = {
        key: np.array([example[key] for example in processed_examples])
        for key in processed_examples[0]
    }
    processed_batches.append(TrainingInput(**fields))

  return processed_batches

# Only the first 100 preference pairs are tokenized for training.
processed_dpo_dataset = process_dpo_dataset(
    dpo_dataset.select(range(100)),
    gemma_tokenizer,
    MAX_PROMPT_LENGTH,
    TOTAL_GENERATION_STEPS,
    BATCH_SIZE,
)

# To inspect the first processed batch, evaluate: processed_dpo_dataset[0]

  0%|          | 0/100 [00:00<?, ?it/s]

In [30]:
from tunix.sft.dpo import dpo_trainer as dpo_lib
import grain
class MySource:
  """Minimal random-access data source over an indexable sequence.

  Provides the `__len__` / `__getitem__` pair that grain's
  `MapDataset.source` reads.
  NOTE(review): not used by the cell below, which passes a `range` as the
  source instead.
  """

  def __init__(self, data):
    self._data = data

  def __len__(self):
    return len(self._data)

  def __getitem__(self, index):
    return self._data[index]
def _dummy_dataset(
    source: "range | MySource",
    prompt_ids: "list[np.ndarray] | np.ndarray",
    prompt_mask: "list[np.ndarray] | np.ndarray",
    chosen_ids: "list[np.ndarray] | np.ndarray",
    chosen_mask: "list[np.ndarray] | np.ndarray",
    rejected_ids: "list[np.ndarray] | np.ndarray",
    rejected_mask: "list[np.ndarray] | np.ndarray",
):
  """Builds a grain MapDataset of `dpo_lib.TrainingInput` records.

  For each value x produced by `source`, the element is a TrainingInput made
  from `prompt_ids[x]`, `prompt_mask[x]`, ..., `rejected_mask[x]`. Despite the
  name, the call site uses this for the real training data: `source` is
  `range(len(processed_dpo_dataset))` and each field is a list of per-batch
  arrays. The annotations reflect that, not a `MySource`/ndarray-only input.
  """

  def _to_training_input(x):
    return dpo_lib.TrainingInput(
        prompt_ids=prompt_ids[x],
        prompt_mask=prompt_mask[x],
        chosen_ids=chosen_ids[x],
        chosen_mask=chosen_mask[x],
        rejected_ids=rejected_ids[x],
        rejected_mask=rejected_mask[x],
    )

  return grain.MapDataset.source(source).map(_to_training_input)
# Training dataset indexed by batch number: element i re-packs the fields of
# processed_dpo_dataset[i] into a TrainingInput.
train_ds_dpo = _dummy_dataset(
    range(len(processed_dpo_dataset)),
    *(
        [getattr(batch, field) for batch in processed_dpo_dataset]
        for field in (
            "prompt_ids",
            "prompt_mask",
            "chosen_ids",
            "chosen_mask",
            "rejected_ids",
            "rejected_mask",
        )
    ),
)

In [31]:
# Shape check on one training element's prompt ids.
train_ds_dpo[2].prompt_ids.shape

(1, 192)

In [32]:
# Device memory right before training starts.
show_hbm_usage()


--- TPU HBM Usage ---
Device 0 (TPU v6 lite): Using 1.9 GiB / 31.2 GiB (6.12%)
--- End HBM Usage ---


In [33]:
# Start training
# NOTE(review): the recorded run failed before completing a step with
# "TypeError: expected number, got jaxlib._jax.ArrayImpl"; the root cause is not
# visible in this notebook. The progress bar also targets MAX_STEPS (3738),
# while `train_ds_dpo` holds only 100 batches (100 pairs, BATCH_SIZE=1).
print("Starting DPO training...")

dpo_trainer.train(train_ds=train_ds_dpo)
print("DPO training finished.")

Starting DPO training...


Training:   0%|          | 0/3738 [00:00<?, ?step/s]

TypeError: expected number, got jaxlib._jax.ArrayImpl