# Tunix Gemma Setup
This notebook prepares a Gemma checkpoint for Tunix-based supervised fine-tuning (SFT) and GRPO (Group Relative Policy Optimization) experiments on Kaggle TPUs.

## Goals
- Install Tunix and the required libraries on the TPU runtime.
- Authenticate with Hugging Face to download the gated Gemma checkpoints.
- Load the tokenizer and model in Flax/JAX with bf16 precision.
- Scaffold Tunix supervised fine-tuning and GRPO trainers that are ready for custom datasets and reward functions.

In [None]:
# Cell 1 — Environment bootstrap
!pip install --quiet "google-tunix[prod]" flax transformers accelerate datasets sentencepiece

In [None]:
# Cell 2 — Imports and runtime checks
import os
import random
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
import tunix


SEED = 42
random.seed(SEED)
# jax.random.key does not seed any global state: it returns a key value that
# must be kept and threaded (e.g. via jax.random.split) into later JAX random
# calls. Discarding it, as a bare call would, has no effect at all.
RNG_KEY = jax.random.key(SEED)

devices = jax.devices()
print(f"Detected {len(devices)} JAX device(s).")
print(devices[0])
if devices[0].platform != "tpu":
    # Not fatal (handy for CPU smoke tests), but this notebook targets TPU runtimes.
    print(f"Warning: expected a TPU runtime, found platform '{devices[0].platform}'.")

In [None]:
# Cell 3 — Hugging Face authentication helpers
import getpass

# Prefer a token already present in the environment; otherwise prompt for it
# without echoing. Surrounding whitespace (a common copy/paste artifact) is dropped.
HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip()
if not HF_TOKEN:
    HF_TOKEN = getpass.getpass("Enter your Hugging Face token (read access only): ").strip()
if not HF_TOKEN:
    # Gemma checkpoints are gated: fail here rather than on an authorization
    # error deep inside the download in the next cell.
    raise ValueError("A Hugging Face token is required to download Gemma checkpoints.")

os.environ["HF_TOKEN"] = HF_TOKEN
# NOTE: huggingface_hub resolves HF_HOME when it is first imported, and Cell 2
# already imported transformers (which imports huggingface_hub). Setting it here
# therefore does not move the cache used by from_pretrained in this kernel
# session — pass cache_dir explicitly where the location matters.
os.environ.setdefault("HF_HOME", "/kaggle/temp/hf-cache")

MODEL_NAME = "google/gemma-2-2b-it"  # swap to base Gemma variant if desired
TOKENIZER_NAME = MODEL_NAME
print(f"Using checkpoint: {MODEL_NAME}")

In [None]:
# Cell 4 — Load tokenizer and base model
# huggingface_hub fixes its cache path when first imported (Cell 2), so an
# HF_HOME assigned afterwards (Cell 3) is ignored unless cache_dir is passed
# explicitly. Fall back to the library default when HF_HOME is unset.
_hf_home = os.environ.get("HF_HOME")
HF_CACHE_DIR = os.path.join(_hf_home, "hub") if _hf_home else None

tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_NAME,
    token=HF_TOKEN,
    cache_dir=HF_CACHE_DIR,
)
# Pad on the right so real tokens start at position 0; truncate on the left so
# the end of an over-long sequence is the part that is kept.
tokenizer.padding_side = "right"
tokenizer.truncation_side = "left"

# With _do_init=False, from_pretrained skips random initialization and returns
# a (model, params) tuple: the weights live in `params`, not in `model.params`
# (accessing model.params on such a model raises).
# NOTE(review): confirm the installed transformers ships a Flax implementation
# for this checkpoint's architecture (Gemma 2); otherwise loading raises.
model, params = FlaxAutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    cache_dir=HF_CACHE_DIR,
    dtype=jnp.bfloat16,  # computation dtype only — does not cast the weights
    _do_init=False,
)
# Cast the floating-point weights themselves to bf16 so memory use and the
# message below match what was actually loaded.
params = model.to_bf16(params)

param_count = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"Loaded model with ~{param_count / 1e6:.1f}M parameters in bf16.")

In [None]:
# Cell 5 — Tunix supervised fine-tuning scaffold
from datasets import load_dataset
from tunix.training.supervised import SupervisedConfig, SupervisedTrainer

# GSM8K is published under the `openai` namespace on the Hugging Face Hub; the
# canonical repo id avoids depending on the legacy un-namespaced alias.
DATASET_NAME = "openai/gsm8k"
# Small slice of the "main" config's train split — enough for a warmup run.
dataset = load_dataset(DATASET_NAME, "main", split="train[:128]")

def format_example(example):
    """Convert a GSM8K record into a prompt/response pair for SFT.

    The ``answer`` text is split on its last ``####`` marker. The worked
    solution before it becomes the ``<reasoning>`` section and the value after
    it becomes the ``<answer>`` section. This keeps the marker and final value
    from being duplicated inside the reasoning. If no marker is present, the
    whole answer text is used for both sections.

    Args:
        example: Mapping with string fields ``question`` and ``answer``.

    Returns:
        Dict with string fields ``prompt`` and ``response``.
    """
    prompt = f"Question: {example['question']}\nPlease think step by step before answering."
    solution = example["answer"]
    if "####" in solution:
        # rpartition takes the text after the LAST marker as the final answer.
        reasoning, _, final_answer = solution.rpartition("####")
    else:
        reasoning = final_answer = solution
    target = f"<reasoning>{reasoning.strip()}</reasoning><answer>{final_answer.strip()}</answer>"
    return {"prompt": prompt, "response": target}

# Replace the raw GSM8K columns with the formatted prompt/response pair.
processed_dataset = dataset.map(
    format_example,
    remove_columns=dataset.column_names,
)

# Warmup-sized SFT run: 10 steps, 1 example per device, 4 accumulation steps.
# NOTE(review): confirm the `tunix.training.supervised` module path and these
# keyword names against the installed Tunix release.
sft_config = SupervisedConfig(
    output_dir="/kaggle/working/tunix-sft-checkpoints",
    learning_rate=1e-5,
    max_steps=10,
    max_seq_length=768,
    per_device_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=1,
)

# NOTE(review): assumes SupervisedTrainer accepts the `model` object built in
# Cell 4 directly — confirm how Tunix expects model weights to be supplied.
sft_trainer = SupervisedTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=processed_dataset,
    config=sft_config,
)

print("Supervised trainer initialized — ready for warmup runs.")

In [None]:
# Cell 6 — GRPO configuration preview
from tunix.training.rl import GRPOConfig, GRPOTrainer
from tunix.rewards import basic_trace_reward

# GRPO preview: 50 training steps, 4 generations per prompt, and prompts and
# responses each capped at 512 tokens (semantics per the Tunix GRPOConfig).
# NOTE(review): confirm the `tunix.training.rl` / `tunix.rewards` import paths
# and these keyword names against the installed Tunix release.
grpo_config = GRPOConfig(
    total_training_steps=50,
    learning_rate=5e-6,
    kl_weight=0.1,
    num_generations=4,
    max_prompt_length=512,
    max_response_length=512,
    logging_steps=5,
)

# `basic_trace_reward` is a stand-in until the process-aware reward is ready.
# NOTE(review): processed_dataset also carries the gold "response" column;
# confirm whether the reward function needs it and how Tunix forwards it.
# Both trainers in this notebook wrap the same `model` object.
grpo_trainer = GRPOTrainer(
    model=model,
    tokenizer=tokenizer,
    prompt_dataset=processed_dataset,
    reward_fn=basic_trace_reward,
    config=grpo_config,
)

print("GRPO trainer scaffold ready — plug in custom rewards before training.")

## Next Steps
- Replace the placeholder GSM8K slice with the curated custom dataset.
- Swap in the process-aware reward function before launching GRPO training.
- Publish this notebook on Kaggle once it has been verified end-to-end on a TPU runtime.