# 00 - SFT Warmup for GRPO Training

This notebook trains the model on GSM8K with XML format before GRPO training.

**Why SFT warmup?**
- GRPO needs variance in rewards to learn
- Gemma 3 1B outputs `reasoning:` text format, not `<reasoning>` XML tags
- Without SFT, all generations get the same -2.0 reward = no gradient
- DeepSeek R1 used "cold-start data" (SFT) before RL for the same reason

**Target format:**
```
<reasoning>step-by-step solution</reasoning><answer>final answer</answer>
```

## Cell 1: Environment Setup & Imports

In [1]:
import os

import functools
import gc
import time
from pathlib import Path

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import numpy as np
import optax
from orbax import checkpoint as ocp
import qwix
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download
from datasets import load_dataset

# Tunix imports
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma3 import model as gemma3_model
from tunix.models.gemma3 import params_safetensors as gemma3_safetensors
from tunix.sft import peft_trainer
from tunix.sft import metrics_logger

# Our library
from tunix_hack.models import load_tokenizer
from tunix_hack.inference import create_sampler

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

JAX version: 0.8.1
Devices: [CudaDevice(id=0)]
Default backend: gpu


W1203 22:11:56.865842  510955 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1203 22:11:56.867944  510896 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


## Cell 2: Configuration Constants

In [2]:
# Run configuration
RUN_NAME = input("Enter run name (e.g., 'warmup', 'sft_v1'): ").strip()
if not RUN_NAME:
    RUN_NAME = f"sft_{int(time.time())}"
print(f"Run name: {RUN_NAME}")

# Model configuration
MODEL_FAMILY = "gemma3"
MODEL_VERSION = "gemma3-1b-it"

# Mesh configuration for single GPU
MESH = ((1, 1), ("fsdp", "tp"))

# SFT Training configuration
BATCH_SIZE = 1              # Reduced for RTX 3090 memory
MAX_SEQ_LENGTH = 768        # Max sequence length (prompt + response)
MAX_STEPS = 500             # Brief warmup - just teach format
EVAL_EVERY_N_STEPS = 50     # Evaluation frequency
NUM_EPOCHS = 1              # Single epoch for warmup
TRAIN_FRACTION = 0.9        # Train/val split

# Optimizer hyperparameters (higher LR than GRPO)
LEARNING_RATE = 1e-4        # Higher than GRPO's 3e-6 for faster format learning
WEIGHT_DECAY = 0.01
B1 = 0.9
B2 = 0.99
WARMUP_STEPS = int(0.1 * MAX_STEPS)  # 10% warmup

# LoRA configuration - same as GRPO for compatibility
RANK = 16
ALPHA = 32

# Inference configuration (for testing)
TOTAL_GENERATION_STEPS = 512
GENERATION_CONFIGS = {
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
}

# Paths
PROJECT_ROOT = Path("/home/jimnix/gitrepos/tunix-hack")
SFT_CKPT_DIR = str(PROJECT_ROOT / "outputs" / "checkpoints" / "sft" / RUN_NAME)
TENSORBOARD_DIR = str(PROJECT_ROOT / "tmp" / "tensorboard" / "sft" / RUN_NAME)

# Create directories
os.makedirs(SFT_CKPT_DIR, exist_ok=True)
os.makedirs(TENSORBOARD_DIR, exist_ok=True)

print(f"\nConfiguration loaded.")
print(f"  LoRA rank: {RANK}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Max steps: {MAX_STEPS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"Checkpoint directory: {SFT_CKPT_DIR}")

Run name: warmup

Configuration loaded.
  LoRA rank: 16
  Batch size: 1
  Max steps: 500
  Learning rate: 0.0001
Checkpoint directory: /home/jimnix/gitrepos/tunix-hack/outputs/checkpoints/sft/warmup


## Cell 3: Download Model from HuggingFace

In [3]:
MODEL_ID = "google/gemma-3-1b-it"

print(f"Downloading {MODEL_ID} from HuggingFace...")
model_path = snapshot_download(MODEL_ID)

print(f"Model downloaded to: {model_path}")
for f in Path(model_path).iterdir():
    print(f"  {f.name}")

Downloading google/gemma-3-1b-it from HuggingFace...


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Model downloaded to: /home/jimnix/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752
  .gitattributes
  special_tokens_map.json
  tokenizer.model
  model.safetensors
  added_tokens.json
  README.md
  config.json
  generation_config.json
  tokenizer_config.json
  tokenizer.json


## Cell 4: JAX Mesh Setup

In [4]:
mesh = jax.make_mesh(*MESH)
print(f"Mesh created: {mesh}")
print(f"Mesh devices: {mesh.devices}")

Mesh created: Mesh('fsdp': 1, 'tp': 1, axis_types=(Auto, Auto))
Mesh devices: [[CudaDevice(id=0)]]


  mesh = jax.make_mesh(*MESH)


## Cell 5: Load Tokenizer

In [5]:
tokenizer = load_tokenizer(model_path)
print("Tokenizer loaded.")

# Get special token IDs
PAD_ID = tokenizer.pad_id() if hasattr(tokenizer, 'pad_id') else 0
EOS_ID = 1  # Gemma EOS token

print(f"PAD_ID: {PAD_ID}")
print(f"EOS_ID: {EOS_ID}")

Tokenizer loaded.
PAD_ID: 0
EOS_ID: 1


## Cell 6: Helper Functions for Model Loading

In [6]:
def get_model_config():
    """Get Gemma 3 1B model configuration."""
    return gemma3_model.ModelConfig.gemma3_1b()


def get_base_model(model_path: str):
    """Load base Gemma 3 model from HuggingFace safetensors."""
    mesh = jax.make_mesh(*MESH)
    model_config = get_model_config()
    
    model = gemma3_safetensors.create_model_from_safe_tensors(
        model_path,
        model_config,
        mesh,
    )
    return model, mesh, model_config


def get_lora_model(base_model, mesh):
    """Apply LoRA to base model."""
    lora_provider = qwix.LoraProvider(
        module_path=(
            ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
            ".*attn_vec_einsum"
        ),
        rank=RANK,
        alpha=ALPHA,
    )
    
    model_input = base_model.get_model_input()
    lora_model = qwix.apply_lora_to_model(
        base_model, lora_provider, **model_input
    )
    
    return lora_model


print("Helper functions defined.")

Helper functions defined.


## Cell 7: Load Base Model

In [7]:
print("Loading base model...")
base_model, mesh, model_config = get_base_model(model_path)
print("Base model loaded.")

Loading base model...


  mesh = jax.make_mesh(*MESH)


Base model loaded.


## Cell 8: Create LoRA Model

In [8]:
print("Creating LoRA model...")
lora_model = get_lora_model(base_model, mesh=mesh)
print(f"LoRA model created. Rank: {RANK}, Alpha: {ALPHA}")

Creating LoRA model...
LoRA model created. Rank: 16, Alpha: 32


## Cell 9: Load GSM8K and Format for SFT

GSM8K has step-by-step solutions. We format them as:
- Input: Full prompt with question
- Target: `<reasoning>solution</reasoning><answer>N</answer>`

In [9]:
# System prompt (same as GRPO)
SYSTEM_PROMPT = """You are a math tutor. Solve the problem step by step.
Format your response EXACTLY as:
<reasoning>your step-by-step solution</reasoning><answer>final numerical answer only</answer>"""

TEMPLATE = "<start_of_turn>user\n{system_prompt}\n\n{question}<end_of_turn>\n<start_of_turn>model\n"


def extract_hash_answer(answer_text: str) -> str:
    """Extract final answer from GSM8K format (#### followed by number)."""
    if "####" in answer_text:
        return answer_text.split("####")[-1].strip()
    return answer_text.strip()


def extract_reasoning(answer_text: str) -> str:
    """Extract reasoning (everything before ####)."""
    if "####" in answer_text:
        return answer_text.split("####")[0].strip()
    return answer_text.strip()


def format_for_sft(example):
    """Format a GSM8K example for SFT training.
    
    Returns dict with 'input' (prompt) and 'target' (XML formatted response).
    """
    question = example["question"]
    full_answer = example["answer"]
    
    # Extract reasoning and answer
    reasoning = extract_reasoning(full_answer)
    answer = extract_hash_answer(full_answer)
    
    # Create input prompt
    input_text = TEMPLATE.format(
        system_prompt=SYSTEM_PROMPT,
        question=question
    )
    
    # Create XML formatted target
    target_text = f"<reasoning>{reasoning}</reasoning><answer>{answer}</answer>"
    
    return {
        "input": input_text,
        "target": target_text,
    }


# Load GSM8K dataset
print("Loading GSM8K dataset...")
gsm8k_train = load_dataset("gsm8k", "main", split="train")
gsm8k_test = load_dataset("gsm8k", "main", split="test")

print(f"Train examples: {len(gsm8k_train)}")
print(f"Test examples: {len(gsm8k_test)}")

# Format for SFT
train_data = [
    format_for_sft({"question": ex["question"], "answer": ex["answer"]})
    for ex in tqdm(gsm8k_train, desc="Formatting train data")
]

test_data = [
    format_for_sft({"question": ex["question"], "answer": ex["answer"]})
    for ex in tqdm(gsm8k_test, desc="Formatting test data")
]

# Show example
print("\nExample formatted data:")
print("INPUT:")
print(train_data[0]["input"])
print("\nTARGET:")
print(train_data[0]["target"][:500] + "..." if len(train_data[0]["target"]) > 500 else train_data[0]["target"])

Loading GSM8K dataset...
Train examples: 7473
Test examples: 1319


Formatting train data:   0%|          | 0/7473 [00:00<?, ?it/s]

Formatting test data:   0%|          | 0/1319 [00:00<?, ?it/s]


Example formatted data:
INPUT:
<start_of_turn>user
You are a math tutor. Solve the problem step by step.
Format your response EXACTLY as:
<reasoning>your step-by-step solution</reasoning><answer>final numerical answer only</answer>

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<end_of_turn>
<start_of_turn>model


TARGET:
<reasoning>Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.</reasoning><answer>72</answer>


## Cell 10: Tokenize and Create Datasets

Tunix SFT needs tokenized data with:
- `input_tokens`: Full sequence (prompt + target)
- `input_mask`: 1 for target tokens (compute loss), 0 for prompt tokens

In [10]:
def tokenize_for_sft(example):
    """Tokenize example for SFT training.
    
    Creates:
    - input_tokens: Full tokenized sequence (prompt + target + EOS)
    - input_mask: 1 for target tokens (loss), 0 for prompt tokens
    """
    # Tokenize prompt and target separately to know where prompt ends
    prompt_tokens = tokenizer.encode(example["input"])
    target_tokens = tokenizer.encode(example["target"])
    
    # Combine: prompt + target + EOS
    # Note: tokenizer.encode adds BOS, so we skip it for target
    full_tokens = prompt_tokens + target_tokens[1:] + [EOS_ID]  # Skip BOS from target
    
    # Truncate if too long
    if len(full_tokens) > MAX_SEQ_LENGTH:
        full_tokens = full_tokens[:MAX_SEQ_LENGTH]
    
    # Create input_mask: 0 for prompt, 1 for target (where loss is computed)
    prompt_len = len(prompt_tokens)
    input_mask = [0] * prompt_len + [1] * (len(full_tokens) - prompt_len)
    
    return {
        "input_tokens": np.array(full_tokens, dtype=np.int32),
        "input_mask": np.array(input_mask, dtype=np.int32),
    }


def pad_batch(batch, pad_id=PAD_ID):
    """Pad a batch of examples to the same length."""
    tokens_list = batch["input_tokens"]
    masks_list = batch["input_mask"]

    max_len = max(len(t) for t in tokens_list)

    padded_tokens = []
    padded_masks = []

    for tokens, mask in zip(tokens_list, masks_list):
        pad_len = max_len - len(tokens)
        padded_tokens.append(np.pad(tokens, (0, pad_len), constant_values=pad_id))
        padded_masks.append(np.pad(mask, (0, pad_len), constant_values=0))
        
    return {
        "input_tokens": np.stack(padded_tokens),
        "input_mask": np.stack(padded_masks),
    }


# Tokenize all examples
print("Tokenizing training data...")
train_tokenized = [
    tokenize_for_sft(ex) for ex in tqdm(train_data, desc="Tokenizing")
]

print("Tokenizing test data...")
test_tokenized = [
    tokenize_for_sft(ex) for ex in tqdm(test_data, desc="Tokenizing")
]

# Create grain datasets
num_train = int(len(train_tokenized) * TRAIN_FRACTION)
train_examples = train_tokenized[:num_train]
val_examples = train_tokenized[num_train:]

# Limit to MAX_STEPS batches for warmup
num_batches = min(MAX_STEPS, len(train_examples) // BATCH_SIZE)

train_ds = (
    grain.MapDataset.source(train_examples)
    .shuffle(seed=42)
    .batch(BATCH_SIZE)
    .map(pad_batch)
    [:num_batches]
)

val_ds = (
    grain.MapDataset.source(val_examples)
    .batch(BATCH_SIZE)
    .map(pad_batch)
    [:50]  # Limit validation batches
)

print(f"\nDatasets created:")
print(f"  Train batches: {len(train_ds)}")
print(f"  Val batches: {len(val_ds)}")

# Show first batch shape
first_batch = train_ds[0]
print(f"\nFirst batch shapes:")
print(f"  input_tokens: {first_batch['input_tokens'].shape}")
print(f"  input_mask: {first_batch['input_mask'].shape}")

Tokenizing training data...


Tokenizing:   0%|          | 0/7473 [00:00<?, ?it/s]

Tokenizing test data...


Tokenizing:   0%|          | 0/1319 [00:00<?, ?it/s]


Datasets created:
  Train batches: 500
  Val batches: 50

First batch shapes:
  input_tokens: (1, 285)
  input_mask: (1, 285)


## Cell 11: Define gen_model_input_fn

Required by PeftTrainer to format batch for Gemma model forward pass.

In [11]:
# Use Tunix's SFT utilities for correct attention mask format
from tunix.sft import utils as sft_utils


def gen_model_input_fn(x):
    """Transform training input for Gemma model forward pass.
    
    Uses tunix.sft.utils for correct position/attention mask formats.
    
    Args:
        x: TrainingInput with input_tokens and input_mask
        
    Returns:
        Dict with keys expected by Gemma model.
    """
    # Handle both TrainingInput namedtuple and dict
    if hasattr(x, 'input_tokens'):
        input_tokens = x.input_tokens
        input_mask = x.input_mask
    else:
        input_tokens = x['input_tokens']
        input_mask = x['input_mask']
    
    # Convert to JAX arrays if needed
    input_tokens = jnp.asarray(input_tokens)
    input_mask = jnp.asarray(input_mask)
    
    # Build padding mask (True where not padding)
    pad_mask = input_tokens != PAD_ID
    
    # Use Tunix's utilities (handles Gemma3 attention format correctly)
    positions = sft_utils.build_positions_from_mask(pad_mask)
    attention_mask = sft_utils.make_causal_attn_mask(pad_mask)
    
    return {
        'input_tokens': input_tokens,
        'input_mask': input_mask,
        'positions': positions,
        'attention_mask': attention_mask,
    }


print("gen_model_input_fn defined (using tunix.sft.utils).")

# Test it
test_input = gen_model_input_fn(train_ds[0])
print(f"Test output keys: {test_input.keys()}")
print(f"  positions shape: {test_input['positions'].shape}")
print(f"  attention_mask shape: {test_input['attention_mask'].shape}")

gen_model_input_fn defined (using tunix.sft.utils).
Test output keys: dict_keys(['input_tokens', 'input_mask', 'positions', 'attention_mask'])
  positions shape: (1, 285)
  attention_mask shape: (1, 285, 285)


## Cell 12: Configure Optimizer

In [12]:
# AdamW optimizer with warmup cosine decay
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

print("Optimizer configured.")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Warmup steps: {WARMUP_STEPS}")
print(f"  Weight decay: {WEIGHT_DECAY}")

Optimizer configured.
  Learning rate: 0.0001
  Warmup steps: 50
  Weight decay: 0.01


## Cell 13: Create PeftTrainer and Train

In [13]:
# Metrics logging
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir=TENSORBOARD_DIR,
    flush_every_n_steps=20,
)

# Training configuration
training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    checkpoint_root_directory=SFT_CKPT_DIR,
    metrics_logging_options=metrics_logging_options,
)

# Create trainer
print("Creating PeftTrainer...")
trainer = peft_trainer.PeftTrainer(
    lora_model,
    optimizer,
    training_config,
).with_gen_model_input_fn(gen_model_input_fn)

print("PeftTrainer created.")
print(f"  Max steps: {MAX_STEPS}")
print(f"  Eval every: {EVAL_EVERY_N_STEPS} steps")
print(f"  Checkpoint dir: {SFT_CKPT_DIR}")

Creating PeftTrainer...


[34m[1mwandb[0m: Currently logged in as: [33mimmpanda[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


PeftTrainer created.
  Max steps: 500
  Eval every: 50 steps
  Checkpoint dir: /home/jimnix/gitrepos/tunix-hack/outputs/checkpoints/sft/warmup


In [None]:
# Run training!
print("Starting SFT training...")
print(f"  Training batches: {len(train_ds)}")
print(f"  Validation batches: {len(val_ds)}")
print()

with mesh:
    trainer.train(train_ds, val_ds)

print("\nSFT Training complete!")
print(f"Checkpoints saved to: {SFT_CKPT_DIR}")

Starting SFT training...
  Training batches: 500
  Validation batches: 50



Training:   0%|          | 0/500 [00:00<?, ?step/s]



## Cell 14: Test Trained Model

Verify the model now generates XML format.

In [None]:
# Create sampler for inference
sampler = create_sampler(
    lora_model,
    tokenizer,
    model_config,
    max_cache_size=256 + TOTAL_GENERATION_STEPS + 256,
)
print("Sampler created.")

In [None]:
# Test questions
test_questions = [
    "What is 15 + 27?",
    "If Mary has 5 apples and gives 2 to John, how many apples does Mary have?",
    "A store sells 3 books for $12. How much does one book cost?",
]

print("Testing trained model...\n")
print("Checking if model now generates XML format:\n")

for question in test_questions:
    prompt = TEMPLATE.format(
        system_prompt=SYSTEM_PROMPT,
        question=question
    )
    
    result = sampler(
        input_strings=[prompt],
        max_generation_steps=TOTAL_GENERATION_STEPS,
        **GENERATION_CONFIGS["greedy"],
        echo=False,
        eos_tokens=[1, 106],
    )
    
    response = result.text[0]
    
    # Check for XML tags
    has_reasoning = "<reasoning>" in response and "</reasoning>" in response
    has_answer = "<answer>" in response and "</answer>" in response
    
    print(f"Question: {question}")
    print(f"Response: {response[:500]}" + ("..." if len(response) > 500 else ""))
    print(f"Has <reasoning> tags: {has_reasoning}")
    print(f"Has <answer> tags: {has_answer}")
    print("-" * 60)

## Cell 15: Save Checkpoint Info

In [None]:
# Find latest checkpoint
import os

ckpt_dir = Path(SFT_CKPT_DIR)
if ckpt_dir.exists():
    checkpoints = [d for d in ckpt_dir.iterdir() if d.is_dir() and d.name.isdigit()]
    if checkpoints:
        latest = max(checkpoints, key=lambda x: int(x.name))
        print(f"Latest checkpoint: {latest}")
        print(f"\nTo use in GRPO notebook, set:")
        print(f'SFT_CKPT_PATH = "{latest}/model_params"')
    else:
        print("No checkpoints found!")
else:
    print(f"Checkpoint directory not found: {SFT_CKPT_DIR}")

In [None]:
# Free memory
del trainer
del sampler
gc.collect()
print("Memory freed.")