# GSM8K Reasoning Enhancement: SFT + GRPO Pipeline

## Training Strategy

**Stage 1 - SFT (1 Epoch)**  
Teach format using 5K samples (3K B + 1K C + 1K A). Single epoch to avoid overfitting.

**Stage 2 - GRPO (1 Epoch over 7.4K samples; set by `GRPO_EPOCHS`)**  
Use entire GSM8K dataset with curriculum learning; samples are ordered within each pass:
- First: Easy problems (SFT data, model succeeds)
- Then: the remaining problems
- Last: Hard problems (base model failures)

## Data Strategy: 3K/1K/1K Split

| Type | Count | Description |
|------|-------|-------------|
| **B** | 3,000 | Rich chain-of-thought |
| **C** | 1,000 | Alternative reasoning |
| **A** | 1,000 | Concise baseline |

---
**Hardware**: TPU v5e-8

In [None]:
# Notebook shell escapes: install the pinned numpy/flax/tunix versions first,
# then the data and hub tooling used by the later cells.
!pip install -q "numpy<2" "flax==0.12.0" "google-tunix[prod]==0.1.3"
!pip install -q kagglehub grain datasets pandas huggingface_hub

import jax
print(f"JAX {jax.__version__} | {jax.device_count()} devices")
# Fail fast: the (2, 4) mesh configured later needs exactly 8 devices.
assert jax.device_count() == 8, "Requires TPU v5e-8"

In [None]:
import os, gc, re, json
import pandas as pd
import numpy as np
from pathlib import Path
from flax import nnx
import grain
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
import qwix

from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma import model as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft.peft_trainer import TrainingInput

# Device mesh: axis shapes and axis names, unpacked into jax.make_mesh(*MESH).
MESH = [(2, 4), ("fsdp", "tp")]
# LoRA adapter rank and scaling factor passed to qwix.LoraProvider.
LORA_RANK, LORA_ALPHA = 64, 64.0

# SFT: 1 epoch only
SFT_EPOCHS = 1
SFT_BATCH = 4
SFT_LR = 1e-5
# Every SFT example is truncated or zero-padded to this many tokens.
MAX_SEQ = 1024

# GRPO: 1 epoch over the full 7.4K
GRPO_EPOCHS = 1
GRPO_BATCH = 2
GRPO_LR = 3e-6
# Passed as GRPOConfig(num_generations=...).
NUM_GENS = 4
# Passed as GRPOConfig(beta=...).
KL_COEF = 0.08

# Output locations; the checkpoint directories are created eagerly so that
# later cells can write into them without checking.
WORK_DIR = Path("/kaggle/working")
CKPT_DIR = WORK_DIR / "ckpts"
SFT_CKPT = CKPT_DIR / "sft"
GRPO_CKPT = CKPT_DIR / "grpo"
for p in [CKPT_DIR, SFT_CKPT, GRPO_CKPT]: p.mkdir(parents=True, exist_ok=True)

In [None]:
# Copy the Kaggle credentials stored as notebook secrets into the environment
# variables KAGGLE_USERNAME / KAGGLE_KEY.
# presumably these are what kagglehub reads for authenticated downloads — TODO confirm
from kaggle_secrets import UserSecretsClient
secrets = UserSecretsClient()
os.environ["KAGGLE_USERNAME"] = secrets.get_secret("kaggle_username")
os.environ["KAGGLE_KEY"] = secrets.get_secret("kaggle_key")

In [None]:
# Data: 3K/1K/1K for SFT, Full 7.4K for GRPO
try:
    ds_path = kagglehub.dataset_download("bazingawaggle/gsm8k-merged")
    df = pd.read_csv(f"{ds_path}/gsm8k_merged.csv")
except Exception as err:
    # Catch real failures only: a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and silently start a second download.
    print(f"Merged dataset unavailable ({err!r}); falling back to raw GSM8K")
    from datasets import load_dataset
    raw = load_dataset("gsm8k", "main", split="train")
    df = pd.DataFrame([{"question": x["question"], "gold": x["answer"].split("####")[-1].strip(), "has_B": True, "has_C": False} for x in raw])

# Fixed seed so the SFT subset and the GRPO ordering are reproducible.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)


def _flag_mask(frame, column, default):
    """Return a boolean row mask for `column`, or a constant mask if it is absent.

    `frame.get(column, default) == True` collapses to a scalar when the column
    is missing, and indexing the frame with that scalar raises KeyError.
    """
    if column in frame.columns:
        return frame[column] == True  # noqa: E712 - element-wise; NaN compares False
    return pd.Series(default, index=frame.index, dtype=bool)


def _text_or(row, column, fallback):
    """Return row[column] when it is present and non-null, else `fallback`."""
    value = row.get(column)
    return value if pd.notna(value) else fallback


has_b = _flag_mask(df, "has_B", True)
has_c = _flag_mask(df, "has_C", False)

# SFT: 3K B + 1K C + 1K A
sft_samples, used_ids = [], set()

for idx in df[has_b].index[:3000]:
    used_ids.add(idx)
    row = df.loc[idx]
    fallback = f"<reasoning>Solving step by step.</reasoning><answer>{row['gold']}</answer>"
    sft_samples.append({"q": row["question"], "a": _text_or(row, "formatted_B", fallback), "type": "B"})

# Type C only draws from rows not already used for type B.
for idx in df[has_c & ~df.index.isin(used_ids)].index[:1000]:
    used_ids.add(idx)
    row = df.loc[idx]
    fallback = f"<reasoning>Alternative approach.</reasoning><answer>{row['gold']}</answer>"
    sft_samples.append({"q": row["question"], "a": _text_or(row, "formatted_C", fallback), "type": "C"})

# Type A takes the first 1000 rows that are still unused.
for idx in df[~df.index.isin(used_ids)].index[:1000]:
    used_ids.add(idx)
    row = df.loc[idx]
    a = row.get('formatted_A', row.get('reasoning_A', f"<reasoning>Direct calculation.</reasoning><answer>{row['gold']}</answer>"))
    sft_samples.append({"q": row["question"], "a": a if pd.notna(a) else f"<reasoning>Direct.</reasoning><answer>{row['gold']}</answer>", "type": "A"})

type_counts = {kind: sum(1 for s in sft_samples if s["type"] == kind) for kind in ("B", "C", "A")}
print(f"SFT: {len(sft_samples)} (B:{type_counts['B']}, C:{type_counts['C']}, A:{type_counts['A']})")

# GRPO: Full dataset with curriculum (Easy first, Hard later)
easy = df[has_b].copy()
hard = df[has_c].copy()
rest = df[~(has_b | has_c)].copy()
# drop_duplicates keeps the first occurrence, so a question flagged both B and C
# stays in the "easy" segment.
grpo_curriculum = pd.concat([easy, rest, hard]).drop_duplicates('question').reset_index(drop=True)

grpo_samples = [{"q": r["question"], "gold": str(r["gold"])} for _, r in grpo_curriculum.iterrows()]
print(f"GRPO: {len(grpo_samples)} samples (Easy→Medium→Hard curriculum)")

In [None]:
# Load model: download the Gemma 2 2B-it Flax weights, build the transformer,
# and place its state on the device mesh.
model_path = kagglehub.model_download("google/gemma-2/flax/gemma2-2b-it")
params = params_lib.load_and_format_params(f"{model_path}/gemma2-2b-it")
mesh = jax.make_mesh(*MESH)

base_model = gemma_lib.Transformer.from_params(params, version="2-2b-it")
with mesh:
    # Re-apply the model state with its own partition specs as sharding constraints.
    state = nnx.state(base_model)
    nnx.update(base_model, jax.lax.with_sharding_constraint(state, nnx.get_partition_spec(state)))

# The first positional argument selects the tokenizer type; the second is the
# path to the SentencePiece model file shipped with the weights.
tokenizer = tokenizer_lib.Tokenizer("sentencepiece", f"{model_path}/tokenizer.model")
# Drop the raw parameter tree now that the model has been built from it.
del params; gc.collect()

# LoRA adapters are attached to every module whose path matches this regex
# (presumably the attention einsums and MLP projections — verify against the model).
lora_cfg = qwix.LoraProvider(module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum", rank=LORA_RANK, alpha=LORA_ALPHA)
model = qwix.apply_lora_to_model(base_model, lora_cfg, **base_model.get_model_input())
with mesh:
    # Same sharding step as above, now for the LoRA-wrapped model.
    state = nnx.state(model)
    nnx.update(model, jax.lax.with_sharding_constraint(state, nnx.get_partition_spec(state)))
print("Model ready")

## Stage 1: SFT (1 Epoch)

In [None]:
# One full training example: a user turn holding the question {q}, followed by
# a model turn holding the target answer {a}.
CHAT_TEMPLATE = "<start_of_turn>user\n{q}<end_of_turn>\n<start_of_turn>model\n{a}<end_of_turn>"

def make_sft_input(ex):
    """Turn one SFT sample into fixed-length model inputs.

    Args:
        ex: Mapping with the question under "q" and the target answer under "a".

    Returns:
        Dict of NumPy arrays, all sized by MAX_SEQ:
        "input_tokens" (int32 ids, truncated or right-padded with 0),
        "positions" (0 .. MAX_SEQ-1), "attention_mask" (lower-triangular
        float32 matrix) and "input_mask" (1.0 where the token id is non-zero).
    """
    token_ids = tokenizer.encode(CHAT_TEMPLATE.format(q=ex['q'], a=ex['a']))

    # Start from an all-zero (padding) buffer and copy in as many ids as fit.
    kept = min(len(token_ids), MAX_SEQ)
    padded_ids = np.zeros(MAX_SEQ, dtype=np.int32)
    padded_ids[:kept] = token_ids[:kept]

    # Lower-triangular ones: each position may attend to itself and earlier ones.
    lower_triangular = np.tri(MAX_SEQ, dtype=np.float32)

    # Token id 0 is treated as padding, so the mask is derived from the ids.
    non_padding = (padded_ids != 0).astype(np.float32)

    return {
        "input_tokens": padded_ids,
        "positions": np.arange(MAX_SEQ, dtype=np.int32),
        "attention_mask": lower_triangular,
        "input_mask": non_padding,
    }

# Tokenize lazily and group into batches of SFT_BATCH. No shuffle is applied
# here, so batches follow the order in which sft_samples was built (B, C, A).
sft_ds = grain.MapDataset.source(sft_samples).map(make_sft_input).batch(SFT_BATCH)

from tunix.sft.peft_trainer import PeftTrainer, TrainingConfig
print(f"--- SFT: {len(sft_samples)} samples, {SFT_EPOCHS} epoch ---")

# Calculate total steps: full batches per pass times the number of passes.
# NOTE(review): sft_ds is not repeated, so with SFT_EPOCHS > 1 the data would
# presumably be exhausted before max_steps is reached — confirm before raising it.
total_steps = (len(sft_samples) // SFT_BATCH) * SFT_EPOCHS

# Create optimizer (constant learning rate, no schedule)
optimizer = optax.adamw(learning_rate=SFT_LR)

# Create TrainingConfig
training_config = TrainingConfig(
    eval_every_n_steps=100,
    max_steps=total_steps,
    checkpoint_root_directory=str(SFT_CKPT),
    data_sharding_axis=("fsdp",),
    pbar_description="SFT Training"
)

# Create PeftTrainer instance around the LoRA-wrapped model
trainer = PeftTrainer(
    model=model,
    optimizer=optimizer,
    training_config=training_config
)

# Run training (no eval dataset is passed)
trainer.train(sft_ds)

# Save checkpoint: only the LoRA parameters, under <SFT_CKPT>/final
ckptr = ocp.StandardCheckpointer()
ckptr.save(str(SFT_CKPT / "final"), nnx.state(model, nnx.LoRAParam))
# Block until the save has completed before reporting success.
ckptr.wait_until_finished()
print("SFT complete")

## Stage 2: GRPO (1 Epoch, Full 7.4K, Curriculum)

In [None]:
# THE COMPLETE REWARD FUNCTION
#
# Components:
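#   1. Correctness: +3.0 if the prediction is within 1e-3 of gold, -0.5 if it is wrong, -0.3 if gold or prediction has no parsable number
#   2. Format (Strict): +2.0 if both <reasoning>...</reasoning> and <answer>...</answer> tags are present
#   3. Format (Soft): +0.1 per tag present (opening and closing tags counted separately)
#   4. Reasoning Quality: +0.5 if the reasoning is longer than 50 characters and contains math operators or math words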

# A signed integer or decimal such as "42", "-3.5" or ".75".
_NUMBER_RE = re.compile(r"[-+]?\d*\.?\d+")
# A comma sitting between a digit and a group of exactly three digits ("1,234").
_THOUSANDS_SEP_RE = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")


def _last_number(fragment):
    """Return the last number in `fragment` as a float, or None if it has none."""
    # Strip thousands separators first; otherwise "1,234" is read as 1 and 234.
    matches = _NUMBER_RE.findall(_THOUSANDS_SEP_RE.sub("", fragment))
    return float(matches[-1]) if matches else None


def extract_answer_num(text):
    """Extract the numeric answer from a model completion.

    Sources are tried in order and the first one that yields a number wins:
      1. the contents of the first <answer>...</answer> pair,
      2. the rest of the line after a "**Answer:**" style marker,
      3. the whole text.
    Within a source the last number is used, and thousands separators
    ("1,234") are ignored.

    Args:
        text: Completion text to search.

    Returns:
        The answer as a float, or None when the text contains no number.
    """
    # Try <answer> tags first
    tagged = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if tagged:
        value = _last_number(tagged.group(1))
        if value is not None:
            return value
    # Fallback: **Answer:** markdown format
    marked = re.search(r'\*\*Answer[:\*]*\s*(.+?)(?:\n|$)', text, re.IGNORECASE)
    if marked:
        value = _last_number(marked.group(1))
        if value is not None:
            return value
    # Last resort: final number in text
    return _last_number(text)

def check_ascii_math(text):
    """Report whether `text` looks like it contains arithmetic.

    Returns True when any of the operator characters + - * / = % occurs in
    the text, or when any of the listed math words occurs as a
    case-insensitive substring (so "assume" matches through "sum").
    """
    symbols = ('+', '-', '*', '/', '=', '%')
    keywords = ('calculate', 'compute', 'multiply', 'divide', 'add', 'subtract', 'total', 'sum', 'equals')

    symbol_found = False
    for symbol in symbols:
        if symbol in text:
            symbol_found = True
            break

    lowered = text.lower()
    keyword_found = any(keyword in lowered for keyword in keywords)

    return symbol_found or keyword_found

_ITER = [0]  # mutable counter: number of reward-function calls so far

# A signed integer or decimal such as "42", "-3.5" or ".75".
_GOLD_NUMBER_RE = re.compile(r"[-+]?\d*\.?\d+")
# A comma sitting between a digit and a group of exactly three digits ("1,000").
_GOLD_THOUSANDS_RE = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")


def _parse_gold(gold):
    """Return the last number in the gold answer as a float, or None if absent."""
    # Without stripping separators, a gold value of "1,000" would parse as 0.0.
    matches = _GOLD_NUMBER_RE.findall(_GOLD_THOUSANDS_RE.sub("", str(gold)))
    return float(matches[-1]) if matches else None


def gsm8k_reward_final(prompts, completions, answer, **kwargs):
    """Score each completion on correctness, format and reasoning quality.

    Args:
        prompts: Prompt strings, parallel to `completions` (not used for scoring).
        completions: Generated texts to score.
        answer: Gold answers, parallel to `completions`.
        **kwargs: Extra fields passed by the caller; ignored.

    Returns:
        A list with one float score per (prompt, completion, answer) triple:
        +3.0 correct / -0.5 wrong / -0.3 unparsable, +2.0 when both tag pairs
        are present, +0.1 per individual tag, and +0.5 for reasoning longer
        than 50 characters that contains math.
    """
    _ITER[0] += 1
    rewards = []

    for i, (_prompt, completion, gold) in enumerate(zip(prompts, completions, answer)):
        gold_num = _parse_gold(gold)
        pred_num = extract_answer_num(completion)
        score = 0.0

        # 1. CORRECTNESS (main signal)
        if gold_num is not None and pred_num is not None:
            if abs(pred_num - gold_num) < 1e-3:
                score += 3.0  # Correct answer
                # Print a sample only for the first item of every 50th call.
                if i == 0 and _ITER[0] % 50 == 0:
                    print(f"[Step {_ITER[0]}] CORRECT: pred={pred_num}, gold={gold_num}")
            else:
                score -= 0.5  # Wrong answer penalty
        else:
            score -= 0.3  # No answer extracted

        # 2. FORMAT (strict)
        has_reasoning = '<reasoning>' in completion and '</reasoning>' in completion
        has_answer = '<answer>' in completion and '</answer>' in completion
        if has_reasoning and has_answer:
            score += 2.0  # Full format compliance

        # 3. FORMAT (soft - partial credit)
        score += 0.1 if '<reasoning>' in completion else 0
        score += 0.1 if '</reasoning>' in completion else 0
        score += 0.1 if '<answer>' in completion else 0
        score += 0.1 if '</answer>' in completion else 0

        # 4. REASONING QUALITY
        reasoning_match = re.search(r'<reasoning>(.*?)</reasoning>', completion, re.DOTALL)
        if reasoning_match:
            reasoning_text = reasoning_match.group(1)
            if len(reasoning_text) > 50 and check_ascii_math(reasoning_text):
                score += 0.5  # Good reasoning

        rewards.append(score)

    return rewards

print("Reward function ready")

In [None]:
# GRPO setup: wrap every question in the instruction prompt and batch the result.
GRPO_PROMPT = "<start_of_turn>user\nSolve step-by-step. Use <reasoning> for work, <answer> for final answer.\n\n{q}<end_of_turn>\n<start_of_turn>model"

grpo_formatted = []
for sample in grpo_samples:
    grpo_formatted.append({
        "prompts": GRPO_PROMPT.format(q=sample["q"]),
        "question": sample["q"],
        "answer": sample["gold"],
    })

# Repeat the ordered dataset once per epoch (GRPO_EPOCHS passes in total).
grpo_data_2epochs = grpo_formatted * GRPO_EPOCHS
grpo_ds = grain.MapDataset.source(grpo_data_2epochs).batch(GRPO_BATCH)
# One training step per batch.
GRPO_STEPS = len(grpo_ds)

print(f"GRPO: {len(grpo_samples)} x {GRPO_EPOCHS} epochs = {len(grpo_data_2epochs)} samples, {GRPO_STEPS} steps")

In [None]:
# GRPO training - With Reward Logging
# Drop the notebook's reference to the un-adapted base model and run the
# garbage collector before the GRPO components are created.
del base_model
gc.collect()

# === Reward Logging Callback ===
class RewardLogger:
    """Keeps a running history of rewards and echoes each entry to stdout.

    `steps` and `rewards` are parallel lists: entry i of each belongs to the
    i-th call to `log`.
    """

    def __init__(self):
        self.rewards, self.steps = [], []

    def log(self, step, reward):
        """Record `reward` for `step`, then print it with four decimals."""
        self.steps.append(step)
        self.rewards.append(reward)
        print(f"Step {step}: Reward = {reward:.4f}")


# Module-level logger instance.
reward_logger = RewardLogger()

# Wrap reward function to log values
def gsm8k_reward_with_logging(prompts, completions, **kwargs):
    """Compute rewards with gsm8k_reward_final and print the batch average.

    All arguments are forwarded unchanged; the rewards are returned exactly
    as gsm8k_reward_final produced them.
    """
    batch_rewards = gsm8k_reward_final(prompts, completions, **kwargs)

    # Array-like results expose .mean; plain sequences are averaged by hand.
    if hasattr(batch_rewards, 'mean'):
        avg_reward = float(jnp.mean(batch_rewards))
    else:
        avg_reward = sum(batch_rewards)/len(batch_rewards)

    print(f"  → Batch Reward: {avg_reward:.4f}")
    return batch_rewards

# Optimizer: clip the global gradient norm to 0.1, then AdamW with a learning
# rate that warms up from 0 to GRPO_LR over the first 10% of steps and
# cosine-decays back to 0 by GRPO_STEPS.
opt = optax.chain(
    optax.clip_by_global_norm(0.1),
    optax.adamw(optax.schedules.warmup_cosine_decay_schedule(0, GRPO_LR, int(0.1*GRPO_STEPS), GRPO_STEPS, 0), b1=0.9, b2=0.99, weight_decay=0.1)
)

# Actor, reference and rollout roles all share the same device mesh.
cluster_cfg = rl_cluster_lib.ClusterConfig(
    role_to_mesh={r: mesh for r in [rl_cluster_lib.Role.ACTOR, rl_cluster_lib.Role.REFERENCE, rl_cluster_lib.Role.ROLLOUT]},
    rollout_engine='vanilla',
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=opt,
        max_steps=GRPO_STEPS,
        mini_batch_size=GRPO_BATCH,
        checkpoint_root_directory=str(GRPO_CKPT),
        eval_every_n_steps=100  
    ),
    rollout_config=base_rollout.RolloutConfig(max_tokens_to_generate=512, max_prompt_length=256, kv_cache_size=1024, temperature=0.9, top_k=50)
)

# NOTE(review): actor and reference are the same object here. If the library
# does not snapshot the reference, the KL term (beta=KL_COEF) would compare the
# policy against itself — confirm whether a frozen copy is needed.
cluster = rl_cluster_lib.RLCluster(
    actor=model,
    reference=model,
    tokenizer=tokenizer,
    cluster_config=cluster_cfg
)

# The learner uses the logging wrapper, so every reward call prints its batch average.
trainer = GRPOLearner(
    rl_cluster=cluster,
    reward_fns=[gsm8k_reward_with_logging],  # wrapper around gsm8k_reward_final
    grpo_config=GRPOConfig(num_generations=NUM_GENS, beta=KL_COEF)
)

print(f"--- GRPO: {GRPO_STEPS} steps, {GRPO_EPOCHS} epochs ---")
with mesh:
    trainer.train(grpo_ds)
print("GRPO complete")

In [None]:
import zipfile

# Persist the complete model state (not only the LoRA parameters) under
# <GRPO_CKPT>/final, waiting for the write to finish before zipping it.
final = GRPO_CKPT / "final"
ckptr.save(str(final), nnx.state(model))
ckptr.wait_until_finished()

# Pack every file of the checkpoint into submission.zip, storing each one
# under its path relative to the checkpoint directory.
zip_out = WORK_DIR / "submission.zip"
with zipfile.ZipFile(zip_out, "w", zipfile.ZIP_DEFLATED) as zf:
    for dirpath, _, filenames in os.walk(final):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            zf.write(file_path, os.path.relpath(file_path, final))
print(f"Saved: {zip_out}")