# Google Tunix Hackathon - Multi-Domain Reasoning Training

**Strategy**: OpenThoughts + GSM8K with GRPO (note: the data cell below currently loads GSM8K only)
**Model**: Gemma 3 1B IT + LoRA

## Key Improvements
1. Multi-reward system (format + accuracy)
2. Optimized hyperparameters for 9-hour TPU session
3. Enhanced evaluation metrics

In [None]:
import os
# Set before any Hugging Face Hub download happens in this session.
os.environ["HF_HUB_DISABLE_XET"] = "1"

# Notebook dependencies ("!" lines are IPython shell escapes, not Python).
!pip install -q kagglehub ipywidgets
!pip install -q tensorflow tensorflow_datasets tensorboardX
!pip install -q transformers grain datasets
# tunix is pinned to 0.1.3 ...
!pip install "google-tunix[prod]==0.1.3"
# ... but flax is then replaced with the latest release.
# NOTE(review): an unpinned flax may drift out of compatibility with the pinned
# tunix version; consider pinning a known-good flax release.
!pip uninstall -q -y flax
!pip install -U flax

In [None]:
import functools
import gc
import os
import re
from pprint import pprint

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma3 import params, model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset

In [None]:
# ====== Model & LoRA ======
LORA_RANK = 32  # passed to qwix.LoraProvider(rank=...)
LORA_ALPHA = 32.0  # passed to qwix.LoraProvider(alpha=...)

# ====== Sharding ======
# Arguments for jax.make_mesh(*MESH): a (1, 4) device grid with axis names
# ("fsdp", "tp").
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
MAX_PROMPT_LENGTH = 256  # RolloutConfig.max_prompt_length
TOTAL_GENERATION_STEPS = 512  # RolloutConfig.max_tokens_to_generate
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
NUM_GENERATIONS = 4  # GRPOConfig.num_generations (completions per prompt)
NUM_ITERATIONS = 1  # GRPOConfig.num_iterations; also scales MAX_STEPS below
BETA = 0.08  # GRPOConfig.beta (presumably the KL-penalty weight; confirm in tunix docs)
EPSILON = 0.2  # GRPOConfig.epsilon (presumably the ratio-clipping range; confirm)

# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 2
# NUM_BATCHES * TRAIN_MICRO_BATCH_SIZE (= 7476) caps the number of training
# examples selected in the data cell.
# NOTE(review): if the GSM8K train split has fewer rows than that, fewer than
# MAX_STEPS batches exist; confirm against len(train_dataset).
NUM_BATCHES = 3738
NUM_TEST_BATCHES = 100  # not referenced anywhere else in this notebook
EVAL_EVERY_N_STEPS = 10
NUM_EPOCHS = 1
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * NUM_EPOCHS)

# ====== Optimizer ======
LEARNING_RATE = 3e-6  # peak value of the warmup-cosine schedule
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
WARMUP_STEPS = int(0.1 * MAX_STEPS)  # first 10% of MAX_STEPS is warmup
MAX_GRAD_NORM = 0.1  # threshold for optax.clip_by_global_norm

# ====== Checkpointing ======
# Holds the base-model state saved by the model-load cell.
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
# checkpoint_root_directory for the RL training run.
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

In [None]:
from datasets import load_dataset, concatenate_datasets
import random

# Reasoning format tags
# The prompt asks the model to use these tags, and the reward functions parse
# completions with the same strings, so they must stay in sync.
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"

# Instruction placed at the top of every user turn.
# NOTE(review): answers are requested as "just one numerical value"; answers
# containing commas or units will not parse as floats in the numeric rewards.
SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

# Gemma chat-turn template. The prompt stops right after the opening of the
# model turn, so generation continues as the model's reply.
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

def extract_hash_answer(text):
    """Return the GSM8K final answer found after the ``####`` marker.

    The result is the whitespace-stripped text between the first ``####`` and
    the next ``####`` (or the end of the string). Returns ``None`` when the
    marker does not occur at all.
    """
    _, marker, remainder = text.partition("####")
    if not marker:
        return None
    # Stop at a second marker, mirroring text.split("####")[1].
    return remainder.split("####", 1)[0].strip()

def format_gsm8k_for_grpo(example):
    """Convert one GSM8K row into the fields used by GRPO training.

    Returns a dict with:
      prompts:  the question wrapped in TEMPLATE together with SYSTEM_PROMPT.
      question: the raw question text.
      answer:   the final answer parsed from the GSM8K solution (text after
                ``####``), or None if no marker is present.
    """
    prompt = TEMPLATE.format(
        system_prompt=SYSTEM_PROMPT,
        question=example['question'],
    )
    question = example['question']
    final_answer = extract_hash_answer(example['answer'])
    return {
        'prompts': prompt,
        'question': question,
        'answer': final_answer,
    }

print("Loading GSM8K for GRPO...")
grpo_dataset = load_dataset("gsm8k", "main", split="train")
# Replace the raw GSM8K columns with the prompts/question/answer fields.
grpo_dataset = grpo_dataset.map(
    format_gsm8k_for_grpo,
    remove_columns=grpo_dataset.column_names,
)
# Keep at most NUM_BATCHES * TRAIN_MICRO_BATCH_SIZE examples.
train_dataset = grpo_dataset.select(
    range(min(len(grpo_dataset), NUM_BATCHES * TRAIN_MICRO_BATCH_SIZE))
)

print(f"GRPO dataset size: {len(train_dataset)}")

In [None]:
import re

# RegEx for format matching
# match_format: optional whitespace, a non-empty <reasoning>...</reasoning>
# block, anything, then <answer>(...)</answer> and optional whitespace. Group 1
# is the answer text. DOTALL lets "." span newlines. With MULTILINE, "^" and "$"
# anchor at line boundaries, so the pattern may match a sub-span of a
# multi-line response rather than the whole response.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

# match_numbers: the first run of digits/dots anywhere after "<answer>".
# Because "." spans newlines, this can also match past "</answer>". Signs and
# thousands separators are not captured: "-5" yields "5" and "1,000" yields "1".
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)

def match_format_exactly(prompts, completions, **kwargs):
    """Give 3.0 to each completion that ``match_format`` finds, otherwise 0.

    ``prompts`` and ``kwargs`` are accepted for the reward-function signature
    and ignored. Returns one score per completion.
    """
    scores = []
    for response in completions:
        found = match_format.search(response)
        scores.append(3.0 if found is not None else 0)
    return scores

def match_format_approximately(prompts, completions, **kwargs):
    """Partial-format reward: +0.5 per tag seen exactly once, -0.5 otherwise.

    The four tags are the reasoning start/end and answer start/end markers, so
    each completion scores between -2.0 and 2.0. ``prompts`` and ``kwargs``
    are ignored.
    """
    tags = (reasoning_start, reasoning_end, solution_start, solution_end)
    return [
        sum(0.5 if completion.count(tag) == 1 else -0.5 for tag in tags)
        for completion in completions
    ]

def check_answer(prompts, completions, answer, **kwargs):
    """Reward completions whose ``<answer>`` content matches the reference.

    Only completions matching ``match_format`` are graded. Scores per completion:
      3.0   extracted answer equals the reference string exactly;
      1.5   equal after stripping surrounding whitespace;
      0.5   numeric ratio guess/reference within [0.9, 1.1];
      0.25  numeric ratio within [0.8, 1.2];
      -1.0  numeric, but outside those bands;
      -0.5  not numerically comparable (unparseable text, or a reference of 0);
      0     format not matched, or no reference answer (``None``).

    Args:
      prompts: Unused; accepted for the reward-function signature.
      completions: Generated completion strings.
      answer: Reference answers aligned with ``completions``. Entries may be
        None when the GSM8K solution had no ``####`` marker.
      **kwargs: Ignored.

    Returns:
      A list with one score per (completion, answer) pair.
    """
    scores = []
    for response, true_answer in zip(completions, answer):
        found = match_format.search(response)
        # Ungradable: wrong layout, or no reference to compare against
        # (previously a None reference crashed on .strip()).
        if found is None or true_answer is None:
            scores.append(0)
            continue
        guess = found.group(1)
        score = 0
        if guess == true_answer:
            score += 3.0
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            try:
                ratio = float(guess) / float(true_answer)
            except (TypeError, ValueError, ArithmeticError):
                # Non-numeric text, or a zero reference (ZeroDivisionError).
                score -= 0.5
            else:
                if 0.9 <= ratio <= 1.1:
                    score += 0.5
                elif 0.8 <= ratio <= 1.2:
                    score += 0.25
                else:
                    score -= 1.0
        scores.append(score)
    return scores

def check_numbers(prompts, completions, answer, **kwargs):
    """Give 1.5 when the first number after ``<answer>`` equals the reference.

    The number is the first digits/dots run captured by ``match_numbers``.
    Both it and the reference are compared as floats. Scores are 1.5 on a
    numeric match and 0.0 on a mismatch. The score is 0 when no number is
    found, the reference is missing (None), or either side fails to parse.

    Args:
      prompts: Unused; accepted for the reward-function signature.
      completions: Generated completion strings.
      answer: Reference answers aligned with ``completions``.
      **kwargs: Ignored.

    Returns:
      A list with one score per (completion, answer) pair.
    """
    scores = []
    for response, true_answer in zip(completions, answer):
        found = match_numbers.search(response)
        if found is None or true_answer is None:
            scores.append(0)
            continue
        try:
            true_answer_num = float(true_answer.strip())
            guess_num = float(found.group(1).strip())
        except (AttributeError, TypeError, ValueError):
            # Unparseable capture (e.g. "." or "1.2.3") or a non-str reference.
            scores.append(0)
            continue
        scores.append(1.5 if guess_num == true_answer_num else 0.0)
    return scores

In [None]:
# Load Model
# Clear artifacts from any previous run ("!" lines are IPython shell escapes).
!rm -rf {INTERMEDIATE_CKPT_DIR}/*
!rm -rf {CKPT_DIR}/*

MODEL_CP_PATH = params.GEMMA3_1B_IT
# NOTE(review): this mesh is not used in this cell and is rebound by the
# get_gemma_ref_model(...) call further down.
mesh = jax.make_mesh(*MESH)
config = model.ModelConfig.gemma3_1b()
gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
tokenizer = params.create_tokenizer()

# Save the base model's state to INTERMEDIATE_CKPT_DIR/state. A later cell
# restores it with an explicit bfloat16 dtype and mesh sharding.
checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(gemma)
checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
checkpointer.wait_until_finished()
# Drop the in-memory copy; the model is re-created from the checkpoint later.
del gemma, state
gc.collect()

In [None]:
def get_gemma_ref_model(ckpt_path):
    """Restore the Gemma 3 1B model from an Orbax checkpoint, sharded on MESH.

    The model structure comes from ``model_config`` (created here) rather than
    any notebook-global config, so the function is self-contained.

    Args:
      ckpt_path: Path of the state saved with ``ocp.StandardCheckpointer``.

    Returns:
      Tuple ``(model, mesh, model_config)``. The parameters are restored as
      bfloat16 with the named sharding derived for ``mesh``.
    """
    mesh = jax.make_mesh(*MESH)
    model_config = model.ModelConfig.gemma3_1b()
    # Evaluated under nnx.eval_shape, so only shapes/dtypes are kept; concrete
    # values come from the checkpoint restore below.
    abs_gemma: nnx.Module = nnx.eval_shape(
        lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, model_config)
    )
    abs_state = nnx.state(abs_gemma)
    # Restore target: every leaf as bfloat16 with its sharding on `mesh`.
    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
        abs_state,
        nnx.get_named_sharding(abs_state, mesh),
    )
    checkpointer = ocp.StandardCheckpointer()
    restored_params = checkpointer.restore(ckpt_path, target=abs_state)
    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored_params)
    return gemma, mesh, model_config

def get_lora_model(base_model, mesh):
    """Attach LoRA adapters to ``base_model`` and shard the result on ``mesh``.

    Adapters (rank LORA_RANK, alpha LORA_ALPHA) are added to every module whose
    path matches the q/kv/attn_vec einsum or gate/up/down projection patterns.
    Returns the LoRA-wrapped model with its state constrained to the partition
    specs reported by ``nnx.get_partition_spec``.
    """
    target_modules = (
        ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
        ".*attn_vec_einsum"
    )
    provider = qwix.LoraProvider(
        module_path=target_modules,
        rank=LORA_RANK,
        alpha=LORA_ALPHA,
    )
    example_inputs = base_model.get_model_input()
    wrapped = qwix.apply_lora_to_model(base_model, provider, **example_inputs)
    with mesh:
        wrapped_state = nnx.state(wrapped)
        constrained = jax.lax.with_sharding_constraint(
            wrapped_state, nnx.get_partition_spec(wrapped_state)
        )
        nnx.update(wrapped, constrained)
    return wrapped

# Restore the base model (passed to RLCluster as the reference), then attach
# LoRA adapters to it to form the trainable policy.
ref_model, mesh, model_config = get_gemma_ref_model(
    ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state"),
)
lora_policy = get_lora_model(base_model=ref_model, mesh=mesh)

In [None]:
# Optimizer: clip gradients by global norm, then apply AdamW whose learning
# rate warms up from 0 to LEARNING_RATE over WARMUP_STEPS and cosine-decays
# back to 0 by MAX_STEPS.
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
    optax.adamw(
        learning_rate=optax.schedules.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=LEARNING_RATE,
            warmup_steps=WARMUP_STEPS,
            decay_steps=MAX_STEPS,
            end_value=0.0,
        ),
        b1=B1,
        b2=B2,
        weight_decay=WEIGHT_DECAY,
    ),
)

# GRPO Config
# Actor, reference and rollout all share the same device mesh.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        # mini and micro batch sizes are equal here (presumably no gradient
        # accumulation; confirm against tunix RLTrainingConfig docs).
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        # NOTE(review): metrics logging is disabled even though
        # `metrics_logger` is imported in this notebook.
        metrics_logging_options=None,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=ocp.CheckpointManagerOptions(
            save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
        ),
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        # Prompt + generation budget plus 256 extra tokens of headroom.
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        # Presumably Gemma's <eos> (1) and <end_of_turn> (106) token ids;
        # confirm against the tokenizer.
        eos_tokens=[1,106],
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

# Trainer
# Bind the LoRA policy (actor), the restored base model (reference) and the
# tokenizer to the cluster configuration above.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# Each reward function returns one score per completion.
# NOTE(review): presumably GRPOLearner combines them (e.g. sums them); confirm
# in the tunix documentation.
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    grpo_config=grpo_config,
)

print("Starting training...")
# NOTE(review): train_dataset is a Hugging Face `datasets.Dataset` that yields
# one example per row. Confirm that GRPOLearner.train in tunix 0.1.3 accepts it
# as-is, or whether it expects batched dicts (e.g. a grain dataset batched to
# TRAIN_MICRO_BATCH_SIZE). `grain` is imported in this notebook but never used.
with mesh:
    grpo_trainer.train(train_dataset)