<a href="https://colab.research.google.com/github/google/tunix/blob/main/examples/grpo_demo.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This tutorial demonstrates training the Gemma 2 2B-IT model on the GSM8K math
reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can
enhance your model's problem-solving skills on mathematical word problems,
coding problems, etc.

GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It
is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by
eliminating the need for a separate value function model. For each prompt, GRPO
generates a group of responses and scores them with a reward model (in this
tutorial, a set of rule-based reward functions). It then updates the policy
using each response's reward relative to the rest of its group as the advantage.
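The group-relative advantage at the heart of GRPO is small enough to sketch directly. The snippet below is a minimal illustration. It assumes the outcome-reward variant from the DeepSeekMath paper, in which each response's reward is normalized by the mean and standard deviation of its group. `group_relative_advantages` and the `eps` term are illustrative choices, not Tunix's API. Tunix may use a different epsilon or standard-deviation estimator.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
  """Normalizes one group's rewards: A_i = (r_i - mean(r)) / (std(r) + eps).

  `rewards` holds the scalar rewards of the G responses sampled for a single
  prompt. Responses that beat the group average get a positive advantage.
  """
  mean = statistics.fmean(rewards)
  # Population standard deviation of the group (an assumption; some
  # implementations use the sample standard deviation). `eps` avoids division
  # by zero when every response receives the same reward.
  std = statistics.pstdev(rewards)
  return [(r - mean) / (std + eps) for r in rewards]


# Four responses to the same prompt (G = 4) with made-up rewards.
print(group_relative_advantages([3.0, 0.0, 1.5, 1.5]))
```

Responses that score above their group's average are reinforced, and those below it are discouraged. If every response in a group gets the same reward, all advantages are zero and that prompt contributes no policy-gradient signal.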

In this tutorial we use Colab's `v2-8` TPU. Let's get started!

## Install necessary libraries

In [1]:
!pip install -q kagglehub


# !pip install -q tensorflow
# !pip install -q tensorboardX
# !pip install -q grain
# !pip install --force-reinstall "jax==0.6.2" "jaxlib==0.6.2" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install "jax[tpu]==0.6.2" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q git+https://github.com/google/tunix
# For local development, install an editable checkout instead:
# !pip install -e ~/tunix_base/tunix/
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
# !pip install -q git+https://github.com/google/flax.git
!pip install -q git+https://github.com/google/flax.git@7a429f33fca2179079f163934a11658f6ddcb039
!pip install -q tensorflow-datasets

!pip install -q git+https://github.com/AI-Hypercomputer/pathways-utils.git



In [2]:
!pip install ipywidgets



## Imports

In [3]:
import functools
import gc
import os
from pprint import pprint
import re
import time

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
# import kagglehub
import optax
from orbax import checkpoint as ocp
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma import data as data_lib
from tunix.models.gemma import gemma as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
from tunix.sft import metrics_logger

from tunix.models.llama3 import params as llama3_params
from tunix.models.llama3 import model as llama3_model

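# Environment-specific: points JAX at a particular libtpu build. Adjust or
# remove this line for your setup; it must be set before JAX initializes its
# TPU backend.
os.environ['TPU_LIBRARY_PATH'] = '/home/linchai_google_com/miniconda3/envs/vllm/lib/python3.12/site-packages/libtpu/libtpu.so'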

## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you might have
to train the model for longer.

In [4]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
MESH = [(1, 8), ("fsdp", "tp")]

# ====== GRPO ======
# === Generation during GRPO training ===
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 768
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# The number of responses the policy generates for each prompt within a single
# training step. This corresponds to `G` in Algorithm 1 of the GRPO paper
# (DeepSeekMath). The "group" in GRPO comes from here.
NUM_GENERATIONS = 2

# === other GRPO configs ===
# The number of iterations per batch (𝜇 in GRPO algo 1).
NUM_ITERATIONS = 1
# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
# Important to keep a high enough value for this, otherwise, the KL divergence
# can increase unchecked.
BETA = 0.08
# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
# stable updates.
EPSILON = 0.2

# ====== Training ======
BATCH_SIZE = 1
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
# NUM_BATCHES = 3738
NUM_BATCHES = 4
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. The GSM8K test
# split has 1,319 examples, so this can be increased to at most
# ceil(1319 / BATCH_SIZE), e.g., 330 with a batch size of 4.
NUM_TEST_BATCHES = 100

EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1  # can potentially train for more epochs

# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to `LEARNING_RATE` over the first
# 10% of training steps, and then gradually decrease it to 0 with a cosine
# schedule.
WARMUP_STEPS = 0.1 * MAX_STEPS
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/home/linchai_google_com/content/intermediate_ckpt/"
CKPT_DIR = "/home/linchai_google_com/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}
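The following sketch shows how `EPSILON` and `BETA` enter the loss. It follows the GRPO objective in the DeepSeekMath paper: a PPO-style clipped surrogate on the importance ratio r = π_θ / π_θ_old, minus `BETA` times a KL penalty toward the reference model. It is a plain-Python, per-token illustration with hypothetical helper names, and the KL value is simply passed in. Tunix's batched implementation (token masking, averaging over tokens and responses) may differ in detail.

```python
def clipped_surrogate(ratio: float, advantage: float, epsilon: float = 0.2) -> float:
  """PPO-style clipped surrogate for one token: min(r*A, clip(r, 1-eps, 1+eps)*A)."""
  clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
  # Taking the minimum makes the objective pessimistic: the policy gains
  # nothing from pushing the ratio beyond the clip range.
  return min(ratio * advantage, clipped_ratio * advantage)


def grpo_token_loss(
    ratio: float,
    advantage: float,
    kl: float,
    epsilon: float = 0.2,
    beta: float = 0.08,
) -> float:
  """Negated per-token objective to minimize: -(clipped surrogate - beta * KL)."""
  return -(clipped_surrogate(ratio, advantage, epsilon) - beta * kl)


# A token whose probability rose 50% under a positive advantage: the gain is
# capped at 1 + epsilon.
print(clipped_surrogate(ratio=1.5, advantage=1.0))
print(grpo_token_loss(ratio=1.0, advantage=2.0, kl=0.5))
```

With `NUM_ITERATIONS = 1` and a single update per batch, the responses are sampled by the same policy that is being updated. The ratio therefore evaluates to 1 and the clip bounds are not reached. Clipping matters once a batch is reused for several updates.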

## Utility functions

In [5]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define the tags that structure the model's output. We instruct the
model to first reason between the `<reasoning>` and `</reasoning>` tags. After
reasoning, we expect it to provide the answer between the `<answer>` and
`</answer>` tags.

In [6]:
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"


SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

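We use OpenAI's GSM8K dataset, which comprises grade school math word problems.
Each reference solution ends with `#### <final answer>`, which
`extract_hash_answer` uses to pull out the final answer.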

In [7]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def get_dataset(data_dir, split="train") -> grain.MapDataset:
  # Make sure the data directory exists; `tfds.data_source` downloads the data
  # into it below.
  os.makedirs(data_dir, exist_ok=True)

  data = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

  dataset = (
      grain.MapDataset.source(data)
      .shuffle(seed=42)
      .map(
          lambda x: {
              # passed to model forward pass
              "prompts": TEMPLATE.format(
                  system_prompt=SYSTEM_PROMPT,
                  question=x["question"].decode("utf-8"),
              ),
              # passed to reward functions
              "question": x["question"].decode("utf-8"),
              # passed to reward functions
              "answer": extract_hash_answer(x["answer"].decode("utf-8")),
          }
      )
  )
  return dataset
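To see what `extract_hash_answer` returns, here it is applied to a made-up string in GSM8K's answer style: a worked solution followed by `#### ` and the final answer. The function is repeated so that the snippet runs on its own. The sample text is illustrative and not taken from the dataset.

```python
from __future__ import annotations


def extract_hash_answer(text: str) -> str | None:
  # GSM8K reference solutions end with "#### <final answer>".
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


sample = "Each coat dries separately: 2 + 3 + 3 + 5 = 13 minutes.\n#### 13"
print(extract_hash_answer(sample))  # 13
print(extract_hash_answer("No final-answer marker here"))  # None
```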

In [8]:
dataset = get_dataset(TRAIN_DATA_DIR, "train").batch(BATCH_SIZE)[:NUM_BATCHES]

if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

len(train_dataset), len(val_dataset) if val_dataset is not None else 0, len(
    test_dataset
)

(4, 0, 100)

Let's see what one batch of the dataset looks like!


In [9]:
for ele in train_dataset[:1]:
  pprint(ele)

{'answer': array(['13'], dtype='<U2'),
 'prompts': array(['<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nJane is painting her fingernails. She applies a base coat that takes 2 minutes to dry, two color coats that take 3 minutes each to dry, and a clear top coat that takes 5 minutes to dry. How many minutes total does Jane spend waiting for her nail polish to dry?<end_of_turn>\n<start_of_turn>model'],
      dtype='<U535'),
 'question': array(['Jane is painting her fingernails. She applies a base coat that takes 2 minutes to dry, two color coats that take 3 minutes each to dry, and a clear top coat that takes 5 minutes to dry. How many minutes total does Jane spend waiting for her nail polish to dry?'],
      dtype='<U260')}


## Load the policy model and the reference model

The policy model is the model that is actually trained and whose weights are
updated. The reference model is the model against which we compute the KL
divergence. The KL penalty keeps policy updates small, so that the policy does
not deviate too far from the reference model.

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are updated.
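The KL term is usually not computed exactly over the whole vocabulary. The DeepSeekMath paper, which introduced GRPO, uses a per-token estimator evaluated on the sampled tokens: π_ref/π_θ - log(π_ref/π_θ) - 1. The estimate is always non-negative and is zero when the two models agree. The minimal plain-Python sketch below assumes per-token log-probabilities from both models are already available. `per_token_kl` is a hypothetical helper, and Tunix's own implementation may differ.

```python
import math


def per_token_kl(policy_logps: list[float], ref_logps: list[float]) -> list[float]:
  """Per-token KL estimate between the policy and the reference model.

  For each sampled token, with d = log pi_ref - log pi_theta, returns
  exp(d) - d - 1, which is >= 0 and equals 0 exactly when both models assign
  the same probability to the token.
  """
  kls = []
  for logp, ref_logp in zip(policy_logps, ref_logps):
    d = ref_logp - logp
    kls.append(math.exp(d) - d - 1.0)
  return kls


# Identical models give zero; a policy that has drifted gives a positive penalty.
print(per_token_kl([-1.0, -2.0], [-1.0, -2.0]))
print(per_token_kl([-0.5, -3.0], [-1.0, -2.0]))
```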

Note: We perform full precision (fp32) training. You can, however, leverage
Qwix for QAT.

To load the model, you need to be on [Kaggle](https://www.kaggle.com/) and need
to have agreed to the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

In [10]:
MODEL_CP_PATH = "/home/linchai_google_com/llama3_1_8b/meta-llama/Llama-3.1-8B"

In [11]:
def get_ref_model():
  mesh = jax.make_mesh(*MESH)
  model_config = llama3_model.ModelConfig.llama3_1_8b()
  
  llama3 = llama3_params.create_model_from_safe_tensors(MODEL_CP_PATH, model_config, mesh)
  return llama3, mesh, model_config


def get_lora_model(base_model, mesh):
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model
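This paragraph explains what `RANK` and `ALPHA` control. LoRA keeps each base weight matrix W frozen and learns two small matrices: A with shape (rank, d_in) and B with shape (d_out, rank). The adapted layer then computes y = W x + (alpha / rank) B A x. The plain-Python sketch below uses this standard scaling from the original LoRA paper. That Qwix applies exactly the same scaling is an assumption. `lora_forward` is a hypothetical helper, not a Qwix function.

```python
def matvec(m: list[list[float]], v: list[float]) -> list[float]:
  """Multiplies a matrix (given as a list of rows) by a vector."""
  return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(
    w: list[list[float]],
    a: list[list[float]],
    b: list[list[float]],
    x: list[float],
    alpha: float,
    rank: int,
) -> list[float]:
  """Computes W x + (alpha / rank) * B (A x) without materializing B A."""
  base = matvec(w, x)  # frozen base projection
  low_rank = matvec(b, matvec(a, x))  # trainable rank-`rank` update
  scale = alpha / rank
  return [y + scale * u for y, u in zip(base, low_rank)]


# d_in = 3, d_out = 2, rank = 1.
w = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
a = [[1.0, 1.0, 1.0]]  # shape (rank, d_in)
b = [[0.0], [0.0]]  # shape (d_out, rank); zero-initialized as in standard LoRA
x = [1.0, 2.0, 3.0]
print(lora_forward(w, a, b, x, alpha=64.0, rank=1))
```

Standard LoRA initializes B to zero, so the adapted model initially matches the base model exactly. With `RANK = ALPHA = 64`, the scale factor is 1. Each adapted layer adds only rank * (d_in + d_out) trainable parameters.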

In [12]:
# Reference model
llama3_ref, mesh, model_config = get_ref_model()
nnx.display(llama3_ref)

In [13]:
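# Policy model. It is loaded as a second full copy of the base model, so all of
# its weights are trained; the LoRA alternative below is left commented out.
# lora_gemma = get_lora_model(gemma, mesh=mesh)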
llama3_policy, mesh, model_config = get_ref_model()
nnx.display(llama3_policy)

## Define reward functions

We define four reward functions:

- reward if the format of the output exactly matches the instruction given in
`SYSTEM_PROMPT`;
- reward if the format of the output approximately matches the instruction given
in `SYSTEM_PROMPT`;
- reward if the answer is correct or partially correct;
- reward if the number extracted from the text between `<answer>` and
  `</answer>` is correct, since that text is not always a single number.
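Each reward function returns one score per completion, and GRPO needs a single scalar reward per completion before computing group-relative advantages. The minimal sketch below assumes the scores from all reward functions are simply summed. That is a common choice; check `GrpoLearner` in Tunix for the exact aggregation. `total_rewards`, `has_answer_tag`, and `is_short` are hypothetical helpers that share the `(prompts, completions, **kwargs)` signature used by the reward functions in this notebook.

```python
def total_rewards(reward_fns, prompts, completions, **kwargs) -> list[float]:
  """Sums the scores from every reward function, completion by completion."""
  totals = [0.0] * len(completions)
  for fn in reward_fns:
    scores = fn(prompts, completions, **kwargs)
    for i, score in enumerate(scores):
      totals[i] += score
  return totals


# Hypothetical toy reward functions with the notebook's signature.
def has_answer_tag(prompts, completions, **kwargs):
  return [1.0 if "<answer>" in c else 0.0 for c in completions]


def is_short(prompts, completions, **kwargs):
  return [0.5 if len(c) < 30 else 0.0 for c in completions]


completions = ["<answer>13</answer>", "I think the answer is thirteen, probably."]
print(total_rewards([has_answer_tag, is_short], ["q", "q"], completions))
```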

The reward functions are inspired by
[this gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).

First, let's define a regular expression for checking whether the format matches.

In [14]:
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

match_format.search(
    f"{reasoning_start}Let me"
    f" think!{reasoning_end}{solution_start}2{solution_end}",
)

<re.Match object; span=(0, 54), match='<reasoning>Let me think!</reasoning><answer>2</an>

Give the model a reward of 3 points if the format matches exactly.

In [15]:
def match_format_exactly(prompts, completions, **kargs):
  scores = []
  for completion in completions:
    score = 0
    response = completion
    # Match if format is seen exactly!
    if match_format.search(response) is not None:
      score += 3.0
    scores.append(score)
  return scores

We also reward the model if the format of the output matches partially.

In [16]:
def match_format_approximately(prompts, completions, **kargs):
  scores = []

  for completion in completions:
    score = 0
    response = completion
    # Count each tag: +0.5 if it appears exactly once, -0.5 if it is missing
    # or repeated.
    score += 0.5 if response.count(reasoning_start) == 1 else -0.5
    score += 0.5 if response.count(reasoning_end) == 1 else -0.5
    score += 0.5 if response.count(solution_start) == 1 else -0.5
    score += 0.5 if response.count(solution_end) == 1 else -0.5
    scores.append(score)
  return scores

Reward the model if the answer is correct. If the answer does not match exactly,
the model can still earn partial credit based on how close it is to the correct
value, measured by the ratio of the two.

In [17]:
def check_answer(prompts, completions, answer, **kargs):
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_format.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  for guess, true_answer in zip(extracted_responses, answer):
    score = 0
    if guess is None:
      scores.append(0)
      continue
    # Correct answer gets 3 points!
    if guess == true_answer:
      score += 3.0
    # Same answer apart from surrounding whitespace gets 1.5 points.
    elif guess.strip() == true_answer.strip():
      score += 1.5
    else:
      # Otherwise, give partial credit if the answer is numerically close,
      # measured by the ratio of the model's answer to the true answer.
      try:
        ratio = float(guess) / float(true_answer)
        if 0.9 <= ratio <= 1.1:
          score += 0.5
        elif 0.8 <= ratio <= 1.2:
          score += 0.25
        else:
          score -= 1.0  # Penalize wrong answers
      except (ValueError, ZeroDivisionError):
        score -= 0.5  # Penalize answers that cannot be compared numerically
    scores.append(score)
  return scores

Sometimes, the text between `<answer>` and `</answer>` might not be one
number; it can be a sentence. So, we extract the number and compare the answer.

In [18]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

['0.34']

In [19]:
def check_numbers(prompts, completions, answer, **kargs):
  question = kargs["question"]
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  # Log the first example of the batch to monitor generations during training.
  print("START ============================")
  print(f"Question: {question[0]}")
  print(f"Answer: {answer[0]}")
  print(f"Response: {responses[0]}")
  print(f"Extracted: {extracted_responses[0]}")
  print("END ==============================")
  for guess, true_answer in zip(extracted_responses, answer):
    if guess is None:
      scores.append(0)
      continue
    # Convert to numbers; an exact numerical match earns 1.5 points.
    try:
      true_answer = float(true_answer.strip())
      guess = float(guess.strip())
      scores.append(1.5 if guess == true_answer else 0.0)
    except (AttributeError, ValueError):
      # The extracted text is not a valid number (e.g., "1.2.3"), or the true
      # answer is missing.
      scores.append(0)
  return scores

## Evaluate


Before training the model, let's evaluate it on the test set so that we can
measure the improvement after training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the
correct final numerical answer.
* **Answer (Partial) Accuracy**: percentage of samples for which the model
predicts a final numerical answer such that the `model answer / answer`
ratio lies between 0.9 and 1.1.
* **Format Accuracy**: percentage of samples for which the model outputs the
correct format, i.e., reasoning between the `<reasoning>` and `</reasoning>`
tags, and the final answer between the `<answer>` and `</answer>` tags.
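The three quantitative metrics reduce to simple counts. The sketch below computes them for a single pass. It assumes each output has already been reduced to an extracted answer string (or `None`) and a flag recording whether the output matched the expected format. `accuracy_metrics` is a hypothetical helper; the `evaluate` function in this notebook additionally supports several sampling passes per question.

```python
from __future__ import annotations


def accuracy_metrics(
    predictions: list[str | None],
    answers: list[str],
    format_ok: list[bool],
) -> dict[str, float]:
  """Returns answer, partial, and format accuracy as percentages."""
  correct = partial = formatted = 0
  for pred, answer, fmt in zip(predictions, answers, format_ok):
    if fmt:
      formatted += 1
    if pred is None:
      continue  # no answer could be extracted
    try:
      guess, target = float(pred), float(answer)
    except ValueError:
      continue  # extracted text was not a number
    if guess == target:
      correct += 1
    # Partial accuracy: `model answer / answer` ratio within [0.9, 1.1].
    if target != 0 and 0.9 <= guess / target <= 1.1:
      partial += 1
  total = len(answers)
  return {
      "accuracy": 100.0 * correct / total,
      "partial_accuracy": 100.0 * partial / total,
      "format_accuracy": 100.0 * formatted / total,
  }


print(accuracy_metrics(
    ["13", "95", None, "abc"], ["13", "100", "7", "4"], [True, True, False, True]
))
```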

**Qualitative**

We'll also print outputs for a few given questions so that we can compare the generated output later.


In [20]:
# def generate(
#     question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
# ):
#   """Given prompt, generates text."""

#   if isinstance(question, str):
#     input_batch = [
#         TEMPLATE.format(
#             system_prompt=SYSTEM_PROMPT,
#             question=question,
#         ),
#     ]
#   else:
#     input_batch = [
#         TEMPLATE.format(
#             system_prompt=SYSTEM_PROMPT,
#             question=q,
#         )
#         for q in question
#     ]

#   out_data = sampler(
#       input_strings=input_batch,
#       total_generation_steps=768,
#       temperature=temperature,
#       top_k=top_k,
#       top_p=top_p,
#       echo=False,
#       seed=seed if seed is not None else None,
#   )

#   output = out_data.text
#   if isinstance(question, str):
#     return output[0]
#   return output

In [21]:
# def evaluate(
#     dataset,
#     sampler,
#     temperature=0.7,
#     top_k=50,
#     top_p=0.95,
#     num_passes=1,
#     corr_lst=False,
#     make_lst=False,
# ):
#   """Computes accuracy and percentage of outputs matching the format."""

#   response_lst = []
#   corr = 0
#   partially_corr = 0
#   corr_format = 0
#   total = 0

#   for batch in tqdm(dataset):
#     answers = batch["answer"]
#     questions = batch["question"]

#     multiple_call_responses = [[] for _ in range(len(questions))]
#     for p in range(num_passes):
#       responses = generate(
#           questions, sampler, temperature, top_k, top_p, seed=p
#       )
#       for idx, response in enumerate(responses):
#         multiple_call_responses[idx].append(response)

#     for question, multiple_call_response, answer in zip(
#         questions, multiple_call_responses, answers
#     ):
#       # check answer
#       corr_ctr_per_question = 0
#       partially_corr_per_question = 0
#       corr_format_per_question = 0
#       for response in multiple_call_response:
#         extracted_response = (
#             guess.group(1)
#             if (guess := match_numbers.search(response)) is not None
#             else "-1000000"
#         )
#         try:
#           if float(extracted_response.strip()) == float(answer.strip()):
#             corr_ctr_per_question += 1

#           ratio = float(extracted_response.strip()) / float(answer.strip())
#           if ratio >= 0.9 and ratio <= 1.1:
#             partially_corr_per_question += 1
#         except:
#           print("SKIPPED")

#         # check format
#         if match_format.search(response) is not None:
#           corr_format_per_question += 1

#         if (
#             corr_ctr_per_question > 0
#             and partially_corr_per_question > 0
#             and corr_format_per_question > 0
#         ):
#           break

#       if corr_ctr_per_question > 0:
#         corr += 1
#         if corr_lst and make_lst:
#           response_lst.append((question, answer, multiple_call_response))
#       else:
#         if not corr_lst and make_lst:
#           response_lst.append((question, answer, multiple_call_response))
#       if partially_corr_per_question > 0:
#         partially_corr += 1
#       if corr_format_per_question > 0:
#         corr_format += 1

#       total += 1
#       if total % 10 == 0:
#         print(
#             f"===> {corr=}, {total=}, {corr / total * 100=}, "
#             f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
#         )

#   to_return = (
#       corr,
#       total,
#       corr / total * 100,
#       partially_corr / total * 100,
#       corr_format / total * 100,
#   )
#   if make_lst:
#     return to_return, response_lst
#   return to_return

In [22]:
# gemma_tokenizer = data_lib.GemmaTokenizer()
# sampler = sampler_lib.Sampler(
#     transformer=lora_gemma,
#     tokenizer=gemma_tokenizer,
#     cache_config=sampler_lib.CacheConfig(
#         cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
#         num_layers=model_config.num_layers,
#         num_kv_heads=model_config.num_kv_heads,
#         head_dim=model_config.head_dim,
#     ),
# )

In [23]:
# (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
#     test_dataset,
#     sampler,
#     **GENERATION_CONFIGS["greedy"],
# )
# print(
#     f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
#     f" {format_accuracy=}%"
# )

In [24]:
# for eval_example in QUALITATIVE_EVAL_EXAMPLES:
#   question = eval_example["question"]
#   answer = eval_example["answer"]
#   response = generate(
#       question,
#       sampler,
#       temperature=INFERENCE_TEMPERATURE,
#       top_k=INFERENCE_TOP_K,
#       top_p=INFERENCE_TOP_P,
#   )

#   print(f"Question:\n{question}")
#   print(f"Answer:\n{answer}")
#   print(f"Response:\n{response}")
#   print("===============")

## Train

Let's set up all the configs first (checkpointing, metrics logging, and
training), and then train the model.

In [25]:
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/home/linchai_google_com/content/tmp/tensorboard/grpo", flush_every_n_steps=20
)

In [26]:
# Logs
# %load_ext tensorboard
# %tensorboard --logdir /home/linchai_google_com/content/tmp/tensorboard/grpo --port=0

In [27]:
# Optimizer, learning rate scheduler, gradient clipping
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )
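The schedule above warms up linearly and then decays with a cosine. To inspect the values it produces, the sketch below re-implements that shape in plain Python. It follows the documented behavior of `optax.schedules.warmup_cosine_decay_schedule`, where `decay_steps` is the total schedule length including warmup. `warmup_cosine_lr` is a hypothetical helper for inspection only; the optimizer itself uses the optax schedule.

```python
import math


def warmup_cosine_lr(
    step: int,
    peak_value: float,
    warmup_steps: float,
    decay_steps: float,
    init_value: float = 0.0,
    end_value: float = 0.0,
) -> float:
  """Learning rate at `step`: linear warmup, then cosine decay.

  `decay_steps` is the total length of the schedule, warmup included, so the
  cosine phase lasts `decay_steps - warmup_steps` steps.
  """
  if step < warmup_steps:
    # Linear ramp from init_value to peak_value.
    return init_value + (peak_value - init_value) * step / warmup_steps
  cosine_steps = decay_steps - warmup_steps
  if cosine_steps <= 0:
    return end_value
  # Fraction of the cosine phase completed, clipped to 1 after the schedule ends.
  progress = min(step - warmup_steps, cosine_steps) / cosine_steps
  cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
  return end_value + (peak_value - end_value) * cosine


# Same settings as the cell above: MAX_STEPS = 4, WARMUP_STEPS = 0.4.
for step in range(4):
  print(step, warmup_cosine_lr(step, peak_value=3e-6, warmup_steps=0.4, decay_steps=4))
```

With `MAX_STEPS = 4`, `WARMUP_STEPS` is 0.4. In this sketch, the warmup therefore covers only step 0 (learning rate 0), and the cosine decay spans the remaining steps. Longer runs get a proportionally longer warmup.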

In [28]:
# Training config
rollout_mesh = jax.make_mesh(*[(1, 8), ("fsdp", "tp")])
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=True,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        gradient_accumulation_steps=1,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
    ),
)

grpo_config = GrpoConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

In [None]:
# Smoke test: train only the actor (policy) on dummy data to measure HBM usage
# and capture a profile, without rollouts or the reference model.
import numpy as np
from tunix.sft import peft_trainer

show_hbm_usage()
actor_trainer = rl_cluster_lib.rl_trainer.Trainer(
    model=llama3_policy,
    optimizer=cluster_config.training_config.actor_optimizer,
    training_config=cluster_config.training_config)


def dummy_datasets(batch_size: int, repeat: int = 1, seq_len: int = 256):
  # Shape (num_batch, batch_size, seq_len) with num_batch = 1, so iterating
  # yields 2-D (batch_size, seq_len) batches; the list is repeated `repeat`
  # times.
  dummy_input = np.arange(batch_size * seq_len).reshape((1, batch_size, seq_len))
  return [
      peft_trainer.TrainingInput(
          input_tokens=x, input_mask=jnp.ones(x.shape, dtype=jnp.int32)
      )
      for x in dummy_input
  ] * repeat


def dummy_gen_model_input_fn(x: peft_trainer.TrainingInput):
  return {
      'input_tokens': x.input_tokens,
      'input_mask': x.input_mask,
      'positions': jnp.arange(x.input_tokens.shape[1]),
      'attention_mask': jnp.ones_like(x.input_tokens),
  }


actor_trainer.with_gen_model_input_fn(dummy_gen_model_input_fn)
# Write a profiler trace of the dummy training steps (environment-specific path).
jax.profiler.start_trace("gs://linchai-bucket/tunix_native_xprof/jax_traces_tunix_vanilla/grpo")
try:
  with mesh:
    actor_trainer.train(dummy_datasets(1, 4))
finally:
  jax.profiler.stop_trace()


Training:   0%|          | 0/4 [00:00<?, ?step/s]

In [30]:
# # RL cluster
# from transformers import AutoTokenizer
# rl_cluster = rl_cluster_lib.RLCluster(
#     actor=llama3_policy,
#     reference=llama3_ref,
#     tokenizer=AutoTokenizer.from_pretrained(MODEL_CP_PATH),
#     cluster_config=cluster_config,
# )

# # GRPO Trainer
# grpo_trainer = GrpoLearner(
#     rl_cluster=rl_cluster,
#     reward_fns=[
#         match_format_exactly,
#         match_format_approximately,
#         check_answer,
#         check_numbers,
#     ],
#     grpo_config=grpo_config,
# )

In [31]:
show_hbm_usage()

Using 9.6 GiB / 31.2 GiB (30.823993%) on TPU_0(process=0,(0,0,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_1(process=0,(1,0,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_2(process=0,(0,1,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_3(process=0,(1,1,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_4(process=0,(0,2,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_5(process=0,(1,2,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_6(process=0,(0,3,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_7(process=0,(1,3,0,0))


In [32]:
# import time

# with mesh:
#   grpo_trainer.train(dataset)

# import jax
# jax.profiler.start_trace("gs://linchai-bucket/tunix_native_xprof/jax_traces_tunix_vanilla/grpo")
# with mesh:
#   grpo_trainer.train(dataset)
# # time.sleep(5)  # Give some time for the profiler to collect data
# jax.profiler.stop_trace()

In [33]:
show_hbm_usage()

Using 9.6 GiB / 31.2 GiB (30.823993%) on TPU_0(process=0,(0,0,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_1(process=0,(1,0,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_2(process=0,(0,1,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_3(process=0,(1,1,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_4(process=0,(0,2,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_5(process=0,(1,2,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_6(process=0,(0,3,0,0))
Using 9.6 GiB / 31.2 GiB (30.816944%) on TPU_7(process=0,(1,3,0,0))


## Evaluate

Let's evaluate our model!

In [34]:
# # Load checkpoint first.

# trained_ckpt_path = os.path.join(CKPT_DIR, str(MAX_STEPS), "model_params")

# abs_params = jax.tree.map(
#     lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
#     nnx.state(lora_gemma, nnx.LoRAParam),
# )
# checkpointer = ocp.StandardCheckpointer()
# trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# nnx.update(
#     lora_gemma,
#     jax.tree.map(
#         lambda a, b: b,
#         nnx.state(lora_gemma, nnx.LoRAParam),
#         trained_lora_params,
#     ),
# )

In [35]:
# gemma_tokenizer = data_lib.GemmaTokenizer()
# sampler = sampler_lib.Sampler(
#     transformer=lora_gemma,
#     tokenizer=gemma_tokenizer,
#     cache_config=sampler_lib.CacheConfig(
#         cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
#         num_layers=model_config.num_layers,
#         num_kv_heads=model_config.num_kv_heads,
#         head_dim=model_config.head_dim,
#     ),
# )

In [36]:
# (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
#     test_dataset,
#     sampler,
#     **GENERATION_CONFIGS["greedy"],
# )
# print(
#     f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
#     f" {format_accuracy=}%"
# )

In [37]:
# for eval_example in QUALITATIVE_EVAL_EXAMPLES:
#   question = eval_example["question"]
#   answer = eval_example["answer"]
#   response = generate(
#       question,
#       sampler,
#       temperature=INFERENCE_TEMPERATURE,
#       top_k=INFERENCE_TOP_K,
#       top_p=INFERENCE_TOP_P,
#   )

#   print(f"Question:\n{question}")
#   print(f"Answer:\n{answer}")
#   print(f"Response:\n{response}")
#   print("===============")