## Imports

In [None]:
!pip install -q jsonlines
!pip install -q kagglehub

!pip install -q tensorboardX
!pip install -q grain
!pip install -q git+https://github.com/google/qwix

In [None]:
import os
import jsonlines
import functools
import humanize
import re
import urllib.request

import jax
import jax.numpy as jnp
import grain
import optax
import kagglehub

from pprint import pprint

from orbax import checkpoint as ocp
from qwix import lora
from flax import nnx
from tunix.examples.gemma_libs import data as data_lib
from tunix.examples.gemma_libs import gemma as gemma_lib
from tunix.examples.gemma_libs import params as params_lib
from tunix.examples.gemma_libs import sampler as sampler_lib
from tunix.sft import metrics_logger

## Hyperparameters

In [None]:
# Data
DATA_SRC_URL = (
    "https://raw.githubusercontent.com/openai/grade-school-math/refs/heads/"
    "master/grade_school_math/data/"
)
DATA_DIR = "./data/"
BATCH_SIZE = 4
# Increase `NUM_BATCHES` for better results; `MAX_STEPS` below is derived
# from it and follows automatically.
NUM_BATCHES = 800

# Reproducibility
SEED = 42

# Model
# Positional arguments for `jax.make_mesh`: a (1, 4) device mesh with axis
# names ("fsdp", "tp"), so 4 devices are required.
MESH = [(1, 4), ("fsdp", "tp")]
# LoRA
RANK = 16
ALPHA = 2.0

# Train
LEARNING_RATE = 5e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
NUM_EPOCHS = 3

# GRPO
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 768
NUM_GENERATIONS = 4
NUM_ITERATIONS = 4
BETA = 0.04
EPSILON = 0.2
TEMPERATURE = 0.9
TOP_P = 0.92
EVAL_EVERY_N_STEPS = 1
# Previously hard-coded as 3200 * NUM_EPOCHS, and 3200 == NUM_BATCHES *
# NUM_ITERATIONS. Deriving it keeps the step budget in sync when the batch or
# iteration count changes (assumes one optimizer step per GRPO iteration per
# batch -- confirm against the trainer).
MAX_STEPS = NUM_BATCHES * NUM_ITERATIONS * NUM_EPOCHS

# Checkpoint saving
CKPT_DIR = "./ckpts/"
SAVE_INTERVAL_STEPS = 1000
MAX_TO_KEEP = 1

## Utility functions

In [None]:
def load_jsonl(path):
    """Read the JSON Lines file at `path` and return its records as a list."""
    with jsonlines.open(path) as reader:
        return [record for record in reader]

In [None]:
def show_hbm_usage():
    """Print device-memory usage for every local JAX device.

    Backends that report no memory statistics (`Device.memory_stats()`
    returns None, e.g. on CPU) get a short notice instead of raising
    TypeError. A missing or zero `bytes_limit` prints usage without a
    percentage instead of raising KeyError/ZeroDivisionError.
    """
    fmt_size = functools.partial(humanize.naturalsize, binary=True)

    for d in jax.local_devices():
        stats = d.memory_stats()
        if not stats:
            print(f"Memory stats unavailable on {d}")
            continue
        used = stats["bytes_in_use"]
        limit = stats.get("bytes_limit")
        if limit:
            print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")
        else:
            print(f"Using {fmt_size(used)} on {d}")

In [None]:
def unbatched_generate(sampler, question, total_generation_steps=768):
    """Generate a single completion for one question.

    The question is wrapped in TEMPLATE together with SYSTEM_PROMPT and sent
    to `sampler` as a batch of one, without echoing the prompt.

    Returns:
      The generated text for that single prompt.
    """
    prompt = TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=question)
    result = sampler(
        input_strings=[prompt],
        total_generation_steps=total_generation_steps,
        echo=False,
    )
    (first_text, *_) = result.text
    return first_text

## Data preprocessing

First, let's define some special tokens. We instruct the model to first reason
between the `<start_working_out>` and `<end_working_out>` tokens. After
reasoning, we expect it to provide the answer between the `<SOLUTION>` and
`</SOLUTION>` tokens.

In [None]:
# Markers the model is asked to wrap its reasoning and its final answer in.
reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

# Instruction text substituted into TEMPLATE's `{system_prompt}` slot.
SYSTEM_PROMPT = "\n".join(
    (
        "You are given a problem.",
        "Think about the problem and provide your working out.",
        f"Place it between {reasoning_start} and {reasoning_end}.",
        f"Then, provide your solution between {solution_start} and {solution_end}",
    )
)

# Gemma-style chat turn; `{system_prompt}` and `{question}` are filled in
# with str.format, and the prompt ends where the model's turn begins.
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model
"""

We use the training split of OpenAI's GSM8K dataset, which consists of grade-school math word problems.

In [None]:
# Download the GSM8K training split to DATA_DIR/data.jsonl.

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(DATA_DIR, exist_ok=True)

# DATA_SRC_URL ends with "/", so plain concatenation yields the file URL;
# os.path.join is meant for filesystem paths, not URLs.
urllib.request.urlretrieve(
    DATA_SRC_URL + "train.jsonl",
    os.path.join(DATA_DIR, "data.jsonl"),
)

In [None]:
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def get_dataset(path: str) -> grain.MapDataset:
    """Load GSM8K-style JSONL records from `path` as a shuffled grain dataset.

    Each element is a dict with:
      "prompts":  the question rendered through TEMPLATE/SYSTEM_PROMPT
                  (fed to the model),
      "question": the raw question text (for reward functions),
      "answer":   the final answer after "####" (for reward functions).
    """

    def to_example(record):
        raw_question = record["question"]
        return {
            "prompts": TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT, question=raw_question
            ),
            "question": raw_question,
            "answer": extract_hash_answer(record["answer"]),
        }

    records = load_jsonl(path)
    shuffled = grain.MapDataset.source(records).shuffle(seed=SEED)
    return shuffled.map(to_example)

In [None]:
# Batch the shuffled examples, keep only the first NUM_BATCHES batches, then
# repeat that subset NUM_EPOCHS times.
dataset = get_dataset(os.path.join(DATA_DIR, "data.jsonl")).batch(BATCH_SIZE)
dataset = dataset[:NUM_BATCHES].repeat(NUM_EPOCHS)

Let's see what one batch of the dataset looks like!


In [None]:
# Pretty-print only the first batch (nothing is printed if the dataset is empty).
first_batch = next(iter(dataset), None)
if first_batch is not None:
    pprint(first_batch)

## Load policy model and reference model

In [None]:
# Fetch the Gemma 2B-IT JAX checkpoint via kagglehub; `ckpt_path` is the
# local path it returns. (Plain string: the old f-prefix had no placeholders.)
ckpt_path = kagglehub.model_download("abheesht75/gemma-tunix/jax/2b-it")

In [None]:
def get_ref_model(shard=False):
    """Restore the Gemma 2B reference model from `ckpt_path` onto a device mesh.

    The model structure is built abstractly with `nnx.eval_shape`. Every
    parameter is turned into a float32 `jax.ShapeDtypeStruct` carrying its
    named sharding on the mesh. The Orbax checkpoint at the module-level
    `ckpt_path` is then restored into that target.

    Args:
      shard: Unused; parameters are always restored with mesh shardings.

    Returns:
      A `(model, mesh)` tuple.
    """
    mesh = jax.make_mesh(*MESH)
    abstract_model: nnx.Module = nnx.eval_shape(
        lambda: gemma_lib.Transformer(
            gemma_lib.TransformerConfig.gemma_2b(), rngs=nnx.Rngs(params=0)
        )
    )
    abstract_state = nnx.state(abstract_model)
    named_shardings = nnx.get_named_sharding(abstract_state, mesh)

    def as_restore_target(leaf, sharding):
        # Restore each parameter as float32, placed according to `sharding`.
        return jax.ShapeDtypeStruct(leaf.shape, jnp.float32, sharding=sharding)

    restore_target = jax.tree.map(as_restore_target, abstract_state, named_shardings)
    restored_state = ocp.StandardCheckpointer().restore(
        ckpt_path, target=restore_target
    )

    graph_def, _ = nnx.split(abstract_model)
    return nnx.merge(graph_def, restored_state), mesh


def get_lora_model(base_model, mesh):
    """Attach LoRA adapters to `base_model` and shard the result onto `mesh`.

    Adapters (`rank=RANK`, `alpha=ALPHA`) go on every module whose path
    matches the regex below: the q/kv/attn-vec einsums and the gate/up/down
    projections. Under `mesh`, the model state is then constrained to the
    partition specs reported by `nnx.get_partition_spec`.

    Returns:
      The LoRA-augmented model.
    """
    target_module_regex = (
        ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
        ".*attn_vec_einsum"
    )
    provider = lora.LoraProvider(
        module_path=target_module_regex, rank=RANK, alpha=ALPHA
    )
    lora_model = lora.apply_lora_to_model(
        base_model, provider, **base_model.get_model_input()
    )

    with mesh:
        model_state = nnx.state(lora_model)
        partition_specs = nnx.get_partition_spec(model_state)
        nnx.update(
            lora_model,
            jax.lax.with_sharding_constraint(model_state, partition_specs),
        )

    return lora_model

In [None]:
# Reference model.
# `get_ref_model` reads the checkpoint from the global `ckpt_path`. Its only
# parameter (`shard`) is unused, so passing `ckpt_path` into it was misleading.
gemma, mesh = get_ref_model()
nnx.display(gemma)

In [None]:
# Policy model: `gemma` with LoRA adapters attached by `get_lora_model`,
# sharded on the same mesh as the reference model.
lora_gemma = get_lora_model(gemma, mesh)
nnx.display(lora_gemma)

## Define reward functions

We define four reward functions:

- reward if the format of the output exactly matches the instruction given in
`SYSTEM_PROMPT`;
- reward if the format of the output approximately matches the instruction given
in `SYSTEM_PROMPT`;
- reward if the answer is correct or partially correct;
- sometimes the text between `<SOLUTION>` and `</SOLUTION>` is not a single
number, so extract the number and reward the model if it is correct.

In [None]:
# Regex for the full expected layout:
#   <start_working_out>...<end_working_out> ... <SOLUTION>answer</SOLUTION>
# Group 1 captures the text between the solution tags. The tags go through
# re.escape so any regex metacharacters in them are matched literally; for
# the current tags this yields exactly the same pattern as before.
# With re.MULTILINE, `^`/`$` anchor at line boundaries, so `search` tolerates
# whole lines before the reasoning block or after the solution. re.DOTALL
# lets `.` span newlines.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?"
    rf"{re.escape(solution_start)}(.+?){re.escape(solution_end)}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

match_format.search(
    "<start_working_out>Let me think!<end_working_out><SOLUTION>2</SOLUTION>",
)

In [None]:
def match_format_exactly(prompts, completions, **kargs):
    """Reward 3.0 for each completion that matches `match_format`, else 0.

    `prompts` and extra keyword arguments are accepted for the shared
    reward-function signature but are not used.
    """
    return [
        3.0 if match_format.search(completion) is not None else 0
        for completion in completions
    ]

In [None]:
def match_format_approximately(prompts, completions, **kargs):
    """Score how well each completion uses the four format markers.

    Each of reasoning_start, reasoning_end, solution_start and solution_end
    contributes +0.5 if it appears exactly once, otherwise -0.5 (missing or
    repeated). The total therefore ranges from -2.0 to 2.0.
    """
    markers = (reasoning_start, reasoning_end, solution_start, solution_end)
    return [
        sum(0.5 if completion.count(marker) == 1 else -0.5 for marker in markers)
        for completion in completions
    ]

In [None]:
def check_answer(prompts, completions, answer, **kargs):
    """Score the <SOLUTION> text of each completion against its reference answer.

    Args:
      prompts: Unused; accepted for the shared reward-function signature.
      completions: Generated completion strings.
      answer: Reference answers, aligned element-wise with `completions`.
      **kargs: Extra batch fields; unused.

    Returns:
      One score per completion:
        3.0   exact string match;
        1.5   match after stripping surrounding whitespace;
        0.5   numeric ratio guess/answer within [0.9, 1.1];
        0.25  numeric ratio within [0.8, 1.2];
        -1.0  numeric but further off;
        -0.5  either side non-numeric, or the reference is zero;
        0     no answer extracted, or the reference answer is missing.
    """
    extracted_responses = [
        guess.group(1) if (guess := match_format.search(r)) is not None else None
        for r in completions
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        # Nothing to compare. Either the completion is malformed, or the
        # reference is None (extract_hash_answer returns None when "####" is
        # missing), which previously crashed on `true_answer.strip()`.
        if guess is None or true_answer is None:
            scores.append(0)
            continue
        score = 0
        if guess == true_answer:
            score += 3.0
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # Partial credit when the answer is numerically close.
            try:
                ratio = float(guess) / float(true_answer)
            except (ValueError, ZeroDivisionError):
                score -= 0.5  # Non-numeric text or a zero reference.
            else:
                if 0.9 <= ratio <= 1.1:
                    score += 0.5
                elif 0.8 <= ratio <= 1.2:
                    score += 0.25
                else:
                    score -= 1.0  # Penalize wrong answers.
        scores.append(score)
    return scores

In [None]:
# Captures the first run of digits/dots after the opening solution tag
# (e.g. "0.34" in the demo below). Signs and thousands separators are not part
# of the character class, so "-5" yields "5" and "1,000" yields "1".
# The tag goes through re.escape so it is always matched literally.
match_numbers = re.compile(
    rf"{re.escape(solution_start)}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall("<SOLUTION>  0.34  </SOLUTION>")

In [None]:
def check_numbers(prompts, completions, answer, **kargs):
    """Reward completions whose first number after <SOLUTION> equals the answer.

    Also prints the first question, answer, completion and extracted number
    of the batch for debugging.

    Args:
      prompts: Unused; accepted for the shared reward-function signature.
      completions: Generated completion strings.
      answer: Reference answers, aligned element-wise with `completions`.
      **kargs: Must contain "question", the batch's raw questions (used only
        for the debug print).

    Returns:
      One score per completion: 1.5 when the extracted number equals the
      reference numerically, otherwise 0 / 0.0.
    """
    question = kargs["question"]
    responses = completions

    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]

    # Debug print of the first sample; skipped for an empty batch, which
    # previously raised IndexError.
    if responses:
        print("START ============================")
        print(f"Question: {question[0]}")
        print(f"Answer: {answer[0]}")
        print(f"Response: {responses[0]}")
        print(f"Extracted: {extracted_responses[0]}")
        print("END ==============================")

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        try:
            true_value = float(true_answer.strip())
            guess_value = float(guess.strip())
        except (AttributeError, ValueError):
            # A capture such as "." or "1.2.3" is not a float, and a missing
            # (None) reference has no .strip().
            scores.append(0)
            continue
        scores.append(1.5 if guess_value == true_value else 0.0)
    return scores

## Generate

Before we train the model, let's see the model outputs so that we can compare
them later.

In [None]:
# Sample from the policy model before training so the output can be compared
# with the post-training behavior.
gemma_tokenizer = data_lib.GemmaTokenizer()
sampler = sampler_lib.Sampler(
    transformer=lora_gemma,
    vocab=gemma_tokenizer.vocab,
)

question = (
    "Trevor and two of his neighborhood friends go to the toy shop every year "
    "to buy toys. Trevor always spends $20 more than his friend Reed on toys, "
    "and Reed spends 2 times as much money as their friend Quinn on the toys. "
    "If Trevor spends $80 every year to buy his toys, calculate how much money "
    "in total the three spend in 4 years."
)
baseline_output = unbatched_generate(sampler, question)
print(baseline_output)

## Train!

In [None]:
# Checkpoint saving: write every SAVE_INTERVAL_STEPS steps and keep at most
# MAX_TO_KEEP checkpoints.
checkpointing_options = ocp.CheckpointManagerOptions(
    max_to_keep=MAX_TO_KEEP,
    save_interval_steps=SAVE_INTERVAL_STEPS,
)

# Metrics logger: logs go to /tmp/tensorboard/grpo, flushed every 20 steps.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    flush_every_n_steps=20,
    log_dir="/tmp/tensorboard/grpo",
)

In [None]:
# Training config.
# NOTE(review): `GrpoTrainingConfig` (and `GrpoTrainer`, used in the next
# cell) are not imported in this notebook's import cell. They presumably come
# from tunix's GRPO module; add that import before running (TODO confirm path).
training_config = GrpoTrainingConfig(
    # Prompt length and generation.
    max_prompt_length=MAX_PROMPT_LENGTH,
    total_generation_steps=TOTAL_GENERATION_STEPS,
    num_generations=NUM_GENERATIONS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    # GRPO hyperparameters.
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
    # Loop length and evaluation cadence.
    max_steps=MAX_STEPS,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    # Optional, currently disabled: max_grad_norm=0.1,
    # Metrics logging.
    metrics_logging_options=metrics_logging_options,
    # Checkpoint saving.
    checkpoint_root_directory=CKPT_DIR,
    checkpointing_options=checkpointing_options,
)

In [None]:
# Sampler over the policy model, handed to the trainer.
gemma_tokenizer = data_lib.GemmaTokenizer()
sampler = sampler_lib.Sampler(transformer=lora_gemma, vocab=gemma_tokenizer.vocab)

# Per-completion scoring functions defined above.
reward_fns = [
    match_format_exactly,
    match_format_approximately,
    check_answer,
    check_numbers,
]

optimizer = optax.adamw(
    learning_rate=LEARNING_RATE,
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

grpo_trainer = GrpoTrainer(
    model=lora_gemma,
    ref_model=gemma,  # the base (non-LoRA) model serves as the reference
    reward_fns=reward_fns,
    sampler=sampler,
    optimizer=optimizer,
    training_config=training_config,
)

In [None]:
# Train under the device mesh. The reference model is restored with named
# shardings on `mesh`, and the LoRA model is sharded inside `with mesh`, so
# training runs in the same context. The `SHARD` / `DO_MEM_PROFILING`
# switches and the `profile_and_capture_log` helper that used to guard this
# call are not defined anywhere in the notebook (NameError). To profile, wrap
# this call in a profiler context of your choice.
with mesh:
    grpo_trainer.train(dataset)