# Tunix Gemma Reasoning Training

This notebook fine-tunes Gemma 2 (2B) or Gemma 3 (1B) on TPU with Tunix, using GRPO (Group Relative Policy Optimization) to teach the model step-by-step reasoning.

## Setup


In [None]:
# Install dependencies
%pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
%pip install -q flax optax transformers datasets
# Install Tunix from GitHub
%pip install -q git+https://github.com/google/tunix.git


In [None]:
import json
import os
import warnings
from datetime import datetime
from pathlib import Path

import jax
import jax.numpy as jnp
import yaml

# Verify the accelerator setup. Printing alone does not catch a missing TPU:
# if no TPU is visible, JAX runs on another backend (e.g. CPU) and training
# would be far slower than intended, so warn explicitly in that case.
backend = jax.default_backend()
print(f"JAX devices: {jax.devices()}")
print(f"JAX platform: {backend}")
if backend != "tpu":
    warnings.warn(
        f"Expected a TPU backend but JAX is using {backend!r}; "
        "check the libtpu installation and the runtime type.",
        RuntimeWarning,
    )


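## Configuration

All hyperparameters and paths for the run are defined in the next cell. The `train_file` and `eval_file` values are placeholders: point them at your JSONL datasets before training.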


In [None]:
# Training configuration
config = {
    'model': {
        'name': 'gemma2-2b',  # or 'gemma3-1b'
        'max_length': 2048,
        'temperature': 0.7
    },
    'training': {
        'batch_size': 4,
        'learning_rate': 1e-5,
        'max_steps': 2000,
        'save_steps': 500,
        'eval_steps': 250,
        'group_size': 4,  # GRPO group size
        'kl_coefficient': 0.1
    },
    'data': {
        'train_file': '/path/to/train.jsonl',
        'eval_file': '/path/to/eval.jsonl'
    },
    'output': {
        'output_dir': './checkpoints/',
        'run_name': 'gemma-reasoning'
    }
}

print("Configuration:")
print(json.dumps(config, indent=2))
