<a href="https://colab.research.google.com/github/dingkwang/tpu_training/blob/master/12_26_14_56_MBPP_GRPO_Training_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MBPP+ GRPO Training - All in One Notebook
Complete pipeline for training Gemma-3-1B (the instruction-tuned `google/gemma-3-1b-it` checkpoint) on MBPP+ code generation using GRPO (Group Relative Policy Optimization).

## Step 1: Install Dependencies

In [1]:
# !pip install uv
# !uv init

In [2]:
%%time
print("Installing dependencies...")
!pip install datasets pyarrow google-tunix[prod]
print("✓ All dependencies installed")

Installing dependencies...
✓ All dependencies installed
CPU times: user 295 ms, sys: 60.1 ms, total: 355 ms
Wall time: 2.05 s


In [3]:
import sys

import datasets
from datasets import load_dataset

# Report which interpreter and which `datasets` build this kernel is running.
print("python", sys.version)
print("executable", sys.executable)
print("datasets", datasets.__version__)


python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
executable /usr/bin/python3
datasets 4.4.2


## Step 2: Configuration

In [4]:
# Model configuration
# Hugging Face repo id handed to `snapshot_download` when the model is loaded.
MODEL_ID = "google/gemma-3-1b-it"
# TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model" # Commented out to use local file
# Used as the rollout `kv_cache_size`; equals MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS (256 + 768).
MAX_SEQ_LEN = 1024


# Training configuration
# Learning rate of the `optax.adamw` actor optimizer.
LEARNING_RATE = 3e-6
# Passed to GRPOConfig as `num_generations`.
NUM_GENERATIONS = 2
# Rollout sampling settings (RolloutConfig `temperature`, `top_k`, `top_p`).
TEMPERATURE = 0.9
TOP_K = 50
TOP_P = 1.0
# Passed to GRPOConfig as `beta`, `epsilon` and `num_iterations`.
BETA = 0.08
EPSILON = 0.2
NUM_ITERATIONS = 1
# Rollout budget: RolloutConfig `max_tokens_to_generate` and `max_prompt_length`.
TOTAL_GENERATION_STEPS = 768
MAX_PROMPT_LENGTH = 256

print("✓ Configuration loaded")

✓ Configuration loaded


## Step 3: Define Helper Functions

In [5]:
import re
import subprocess
from typing import List

# ============================================================================
# Reward Functions
# ============================================================================

def extract_code(completion: str) -> str | None:
    """Extract Python code from model completion."""
    # Try ```python ... ```
    python_block = re.search(r'```python\s*\n(.*?)\n```', completion, re.DOTALL)
    if python_block:
        return python_block.group(1).strip()

    # Try ``` ... ```
    generic_block = re.search(r'```\s*\n(.*?)\n```', completion, re.DOTALL)
    if generic_block:
        return generic_block.group(1).strip()

    # Look for function definition
    if 'def ' in completion:
        def_start = completion.find('def ')
        if def_start != -1:
            return completion[def_start:].strip()

    return None

def has_code_block(completion: str) -> bool:
    """Tell whether the completion holds a fenced block (optionally tagged ``python``)."""
    fence_match = re.search(
        r'```(?:python)?\s*\n.*?\n```',
        completion,
        flags=re.DOTALL,
    )
    return fence_match is not None

def execute_test(code: str, test: str, test_imports, timeout: float = 3.0) -> tuple[int, int]:
    """Execute an MBPP+ test script against generated code in a subprocess.

    The script is assembled as ``test_imports`` + ``code`` + ``test`` and run
    by a fresh Python interpreter. Scoring is all-or-nothing: the number of
    occurrences of the substring ``assert`` in ``test`` is used as the test
    count, and either all of them or none are reported as passed.

    SECURITY NOTE: ``code`` is model-generated and runs unsandboxed with the
    privileges of this process. Only use this in a disposable environment.

    Args:
        code: Candidate solution source.
        test: Test script appended after ``code``.
        test_imports: ``None``, a string of import statements, or an iterable
            of import lines.
        timeout: Seconds before the subprocess is killed.

    Returns:
        ``(passed, total)``. On success both equal the ``assert`` count (which
        may be 0). On failure, timeout, or any error while building or
        launching the script, ``(0, max(assert_count, 1))``.
    """
    import sys  # local import keeps this function independent of other cells

    assert_count = test.count('assert') if isinstance(test, str) else 0
    failure = (0, assert_count if assert_count > 0 else 1)

    try:
        script_parts = []

        # Imports first, so both the solution and the tests can rely on them.
        if test_imports:
            if isinstance(test_imports, str):
                if test_imports.strip():
                    script_parts.append(test_imports)
            else:
                script_parts.extend(test_imports)

        script_parts.append(code)
        script_parts.append(test)
        full_script = '\n'.join(script_parts)

        # The script is fed on stdin ("-") instead of `-c`: a single argv entry
        # is limited by the OS, so a large test script would raise OSError and
        # be scored as a failure. `sys.executable` runs the same interpreter
        # (and installed packages) as the notebook kernel.
        result = subprocess.run(
            [sys.executable or "python3", "-"],
            input=full_script,
            timeout=timeout,
            capture_output=True,
            text=True,
        )
    except subprocess.TimeoutExpired:
        # Non-terminating candidate code counts as a plain failure.
        return failure
    except Exception:
        # Best effort: a reward function must never abort the training loop.
        return failure

    if result.returncode == 0:
        return (assert_count, assert_count)
    return failure

def mbppplus_verifier_reward(prompts: List[str], completions: List[str],
                              test, test_imports=None, **kwargs) -> List[float]:
    """Main MBPP+ reward: fraction of tests passed by each completion.

    Args:
        prompts: Prompts that produced ``completions`` (unused here).
        completions: Model outputs to score.
        test: Either one test script (str) shared by all completions, or an
            iterable of scripts. An iterable whose length differs from
            ``len(completions)`` is cycled to cover every completion.
        test_imports: ``None``, one string shared by all completions, or an
            iterable. An iterable whose length matches ``len(completions)`` is
            read as one entry per completion; any other length is read as the
            import lines of a single problem and shared by every completion.
        **kwargs: Extra batch columns; ignored.

    Returns:
        One float per completion: ``passed / total`` as reported by
        ``execute_test``, or ``0.0`` when no code can be extracted, no test is
        available, or evaluation raises.
    """
    num_completions = len(completions)

    # --- Align tests with completions -------------------------------------
    if isinstance(test, str):
        tests = [test] * num_completions
    else:
        try:
            test_list = list(test)
        except Exception:
            # Not iterable: hand the object to every completion as-is.
            tests = [test] * num_completions
        else:
            if len(test_list) == num_completions:
                tests = test_list
            elif test_list:
                # Length mismatch: cycle through the available tests.
                tests = [test_list[i % len(test_list)] for i in range(num_completions)]
            else:
                # Nothing to run against; scored 0.0 below.
                tests = [None] * num_completions

    # --- Align imports with completions -----------------------------------
    if test_imports is None:
        imports_list = [None] * num_completions
    elif isinstance(test_imports, str):
        imports_list = [test_imports] * num_completions
    else:
        try:
            per_item_imports = list(test_imports)
        except Exception:
            imports_list = [test_imports] * num_completions
        else:
            if len(per_item_imports) == num_completions:
                imports_list = per_item_imports
            else:
                imports_list = [per_item_imports] * num_completions

    # --- Score each completion --------------------------------------------
    rewards = []
    for completion, test_script, imports in zip(completions, tests, imports_list):
        code = extract_code(completion)
        if code is None or test_script is None:
            rewards.append(0.0)
            continue

        try:
            passed, total = execute_test(code, test_script, imports, timeout=3.0)
            rewards.append(passed / total if total > 0 else 0.0)
        except Exception:
            # A broken sample must not abort the whole training step.
            rewards.append(0.0)

    return rewards

def code_format_reward(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """Format reward: 0.5 for a completion with a fenced code block, -0.2 otherwise."""
    scores: List[float] = []
    for completion in completions:
        if has_code_block(completion):
            scores.append(0.5)
        else:
            scores.append(-0.2)
    return scores

# Reward functions later handed to GRPOLearner through `reward_fns`.
# Each one takes (prompts, completions, **kwargs) and returns one float per completion.
DEFAULT_REWARD_FNS_MBPP = [
    mbppplus_verifier_reward,
    code_format_reward,
]

print("✓ Reward functions defined")

✓ Reward functions defined


## Step 4: Data Loader Function

In [6]:
import grain.python as grain
import pyarrow
import pyarrow.parquet as pq
import numpy as np

def get_mbpp_dataset(
    local_path="./data/mbppplus_hf",
    train_fraction=0.9,
    batch_size=1,
    num_train_batches=None,
    num_test_batches=64,
    num_epochs=1,
    shuffle=True,
    seed=42,
):
    """Load MBPP+ parquet files and build batched grain train/test datasets.

    Args:
        local_path: Directory containing one or more ``*.parquet`` files. All
            files are read, in sorted filename order.
        train_fraction: Share of the rows assigned to the training split.
        batch_size: Batch size for both splits.
        num_train_batches: Keep only the first N training batches (falsy: all).
        num_test_batches: Keep only the first N test batches (falsy: all).
        num_epochs: Number of times the training dataset is repeated.
        shuffle: Shuffle the rows (seeded by ``seed``) before splitting.
        seed: Seed for the shuffle permutation.

    Returns:
        ``(train_ds, None, test_ds, (len(train_ds), 0, len(test_ds)))``. The
        second item is a placeholder for a validation split, which is not built.

    Raises:
        FileNotFoundError: If ``local_path`` contains no parquet file.

    Note:
        When the split leaves no test rows (e.g. ``train_fraction=1.0``), the
        last two rows are reused as the test split, so they can overlap with
        the training data.
    """
    import glob

    # glob order is arbitrary; sorting keeps row order (and the split) stable.
    parquet_files = sorted(glob.glob(f"{local_path}/*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(f"No parquet files found in {local_path}")

    dataset = []
    for parquet_file in parquet_files:
        dataset.extend(pq.read_table(parquet_file).to_pylist())

    # Format prompts
    for item in dataset:
        item['prompts'] = f"""# Problem: {item['prompt']}
# Write a Python function to solve this problem.
# Return only Python code in a ```python ... ``` block.

"""
        # Normalise `test_imports` to a string for every row (empty string when
        # there are no imports) so the column has one uniform type.
        if 'test_imports' in item:
            raw_imports = item['test_imports']
            if raw_imports is None:
                item['test_imports'] = ''
            elif isinstance(raw_imports, (list, tuple)):
                item['test_imports'] = '\n'.join(raw_imports)

    # Split dataset
    total_samples = len(dataset)
    train_size = int(total_samples * train_fraction)

    if shuffle:
        # A private RandomState yields the same permutation as seeding the
        # global legacy RNG, without clobbering global NumPy random state.
        indices = np.random.RandomState(seed).permutation(total_samples)
        dataset = [dataset[i] for i in indices]

    train_data = dataset[:train_size]
    # Fallback keeps the test split non-empty; see the Note in the docstring.
    test_data = dataset[train_size:] if train_size < total_samples else dataset[-2:]

    # Create grain datasets and batch them
    train_ds = grain.MapDataset.source(train_data).batch(batch_size=batch_size)
    test_ds = grain.MapDataset.source(test_data).batch(batch_size=batch_size)

    if num_train_batches:
        train_ds = train_ds[:num_train_batches]
    if num_test_batches:
        test_ds = test_ds[:num_test_batches]

    train_ds = train_ds.repeat(num_epochs)

    dataset_lengths = (len(train_ds), 0, len(test_ds))

    return train_ds, None, test_ds, dataset_lengths

print("✓ Data loader defined")



✓ Data loader defined


## Step 5: Download MBPP+ Dataset

In [7]:
# Sanity check: show the pyarrow / datasets versions active in this kernel.
# `pyarrow` is the module imported in the data-loader cell.
import datasets
print("pyarrow.__version__", pyarrow.__version__)
print("datasets.__version__", datasets.__version__)

pyarrow.__version__ 22.0.0
datasets.__version__ 4.4.2


In [8]:
%%time
import os
import numpy
print(numpy.__version__)
from datasets import load_dataset

print(pyarrow.__version__)
print("Downloading MBPP+ dataset...")
# Target directory matches the default `local_path` of get_mbpp_dataset.
os.makedirs("./data/mbppplus_hf", exist_ok=True)
# Only the "test" split is requested; it is exported as one parquet file
# that get_mbpp_dataset later reads back.
dataset = load_dataset("evalplus/mbppplus", split="test")
dataset.to_parquet("./data/mbppplus_hf/test.parquet")
print(f"✓ Downloaded {len(dataset)} samples")

2.3.5
22.0.0
Downloading MBPP+ dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/518 [00:00<?, ?B/s]

data/test-00000-of-00001-d5781c9c51e0279(…):   0%|          | 0.00/1.13M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/378 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

✓ Downloaded 378 samples
CPU times: user 284 ms, sys: 111 ms, total: 394 ms
Wall time: 2.59 s


## Step 6: Import Libraries and Setup

In [9]:
import jax
from flax import nnx
import optax
from huggingface_hub import snapshot_download

from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.rl.rollout import base_rollout
from tunix.generate import tokenizer_adapter as tokenizer_lib

print("✓ All libraries imported")

✓ All libraries imported


## Step 7: Setup TPU Mesh

In [10]:
NUM_TPUS = len(jax.devices())
print(NUM_TPUS)
# (fsdp, tp) mesh shape used when 8 devices are available.
MESH_COUNTS = (2, 4)
MESH = [MESH_COUNTS, ("fsdp", "tp")]

devices = jax.devices()
device_type = devices[0].platform
num_devices = len(devices)

print(f"Device type: {device_type}")
print(f"Number of devices: {num_devices}")

import numpy as np
if num_devices == 8:
    # Report the shape that is actually built (MESH_COUNTS), not a fixed string.
    print(f"Using 2D mesh: {MESH_COUNTS}")
    mesh = jax.make_mesh(
        *MESH,
        axis_types=(jax.sharding.AxisType.Auto,) * len(MESH_COUNTS),
    )
elif num_devices == 1:
    # Single device: degenerate 1x1 mesh with the same axis names.
    print("Using 2D mesh: (1, 1)")
    devices_2d = np.array(devices).reshape(1, 1)
    mesh = jax.sharding.Mesh(devices_2d, axis_names=MESH[1])
else:
    # Only the two layouts above are defined for this notebook.
    raise ValueError(f"Unsupported device count: {num_devices}")

print(f"✓ Mesh created: {mesh}")

1
Device type: tpu
Number of devices: 1
Using 2D mesh: (1, 1)
✓ Mesh created: Mesh('fsdp': 1, 'tp': 1, axis_types=(Auto, Auto))


## Step 8: Load Dataset (10 samples)

In [11]:
%%time
print("Loading 10 MBPP+ samples...")
# `val_dataset` is always None: get_mbpp_dataset builds no validation split.
train_dataset, val_dataset, test_dataset, dataset_lengths = get_mbpp_dataset(
    local_path="./data/mbppplus_hf",
    # With train_fraction=1.0 no rows are left for the test split, so
    # get_mbpp_dataset falls back to the last two rows as test data.
    train_fraction=1.0,
    batch_size=1,
    # 10 batches of size 1 -> the first 10 rows (shuffle is off).
    num_train_batches=10,
    num_test_batches=2,
    num_epochs=1,
    shuffle=False,
)
print(f"✓ Loaded {dataset_lengths[0]} training batches")

Loading 10 MBPP+ samples...
✓ Loaded 10 training batches
CPU times: user 11.5 ms, sys: 5.37 ms, total: 16.9 ms
Wall time: 15.6 ms


## Step 9: Load Gemma-3-1B Model

In [15]:
import os
from getpass import getpass

# Reuse a token already exported as HF_TOKEN so a full re-run does not block
# on an interactive prompt; otherwise ask for it without echoing.
token = (os.environ.get("HF_TOKEN") or "").strip()
if not token:
    # Pasted tokens often carry stray whitespace; strip it.
    token = getpass("Enter your Hugging Face token (will not be shown): ").strip()

if token:
    os.environ["HF_TOKEN"] = token
else:
    # Nothing entered: do not export an empty HF_TOKEN. `None` is passed on to
    # `snapshot_download(token=...)`, which should then fall back to
    # huggingface_hub's default credential lookup — TODO confirm for this
    # huggingface_hub version.
    token = None

Enter your Hugging Face token (will not be shown): ··········


In [16]:
%%time
import copy
import os

from huggingface_hub import snapshot_download


print("Loading Gemma-3-1B model...")
model_config = gemma_lib.ModelConfig.gemma3_1b()

# Download model
# Fetches the MODEL_ID snapshot (skipping files matching "*.pth") using the
# token collected in the previous cell; returns the local snapshot directory.
print("  Downloading from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    ignore_patterns=["*.pth"],
    token=token
)
print(f"  Model at: {local_model_path}")

# Create model
# The actor is built from the downloaded safetensors on `mesh`; the reference
# model is an independent deep copy of it (both are later given to RLCluster
# as `actor=` and `reference=`).
print("  Creating model on mesh...")
with mesh:
    actor_model = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path, model_config, mesh
    )
    ref_model = copy.deepcopy(actor_model)

print("✓ Model loaded")

# Create tokenizer
# Use local tokenizer from downloaded model instead of GCS
tokenizer_path = os.path.join(local_model_path, "tokenizer.model")
print(f"  Loading tokenizer from: {tokenizer_path}")

tokenizer = tokenizer_lib.Tokenizer(
    tokenizer_path=tokenizer_path,
    tokenizer_type='sentencepiece'
)
print("✓ Tokenizer loaded")

Loading Gemma-3-1B model...
  Downloading from Hugging Face...


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/899 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

  Model at: /root/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752
  Creating model on mesh...
✓ Model loaded
  Loading tokenizer from: /root/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752/tokenizer.model
✓ Tokenizer loaded
CPU times: user 9.11 s, sys: 4.95 s, total: 14.1 s
Wall time: 14.5 s


## Step 10: Create RL Cluster

In [17]:
# AdamW with the learning rate from the configuration cell; used as the actor optimizer.
optimizer = optax.adamw(learning_rate=LEARNING_RATE)
cluster_config = rl_cluster_lib.ClusterConfig(
    # Actor, reference and rollout roles all share the single mesh built earlier.
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        # NOTE(review): both values are hard-coded to 10, matching the 10
        # training batches loaded above — keep them in sync if that changes.
        eval_every_n_steps=10,
        max_steps=10,
        mini_batch_size=1,
        train_micro_batch_size=1,
    ),
    rollout_config=base_rollout.RolloutConfig(
        # Generation budget and sampling settings come from the configuration
        # cell; MAX_SEQ_LEN (1024) = MAX_PROMPT_LENGTH (256) + TOTAL_GENERATION_STEPS (768).
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_SEQ_LEN,
        temperature=TEMPERATURE,
        top_k=TOP_K,
        top_p=TOP_P,
    ),
)

# Bundle the actor model, reference model, tokenizer and configuration.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=actor_model,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)
print("✓ RL Cluster created")

✓ RL Cluster created


## Step 11: Run GRPO Training

In [18]:
%%time
print("Creating GRPO trainer...")
# GRPO hyperparameters all come from the configuration cell.
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    beta=BETA,
    epsilon=EPSILON,
    num_iterations=NUM_ITERATIONS,
)

# The learner scores completions with the reward functions defined earlier
# (test pass rate plus code-block format reward).
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=DEFAULT_REWARD_FNS_MBPP,
    algo_config=grpo_config,
)
print(f"Config: {NUM_GENERATIONS} generations, beta={BETA}, epsilon={EPSILON}")
print(f"Reward functions: {len(DEFAULT_REWARD_FNS_MBPP)}")

print("\n" + "=" * 80)
# NOTE: "10 samples" is a fixed string; the actual count is set by
# `num_train_batches` in the dataset-loading cell.
print("Starting GRPO training on 10 samples...")
print("=" * 80)

# Trains on the batches loaded earlier; `test_dataset` serves as the eval set.
grpo_trainer.train(
    train_ds=train_dataset,
    eval_ds=test_dataset,
)

print("\n" + "=" * 80)
print("✅ Training completed successfully!")
print("=" * 80)

Creating GRPO trainer...
Config: 2 generations, beta=0.08, epsilon=0.2
Reward functions: 2

Starting GRPO training on 10 samples...


Actor Training:   0%|          | 0/10 [00:00<?, ?step/s]


✅ Training completed successfully!
CPU times: user 4min 5s, sys: 2.65 s, total: 4min 8s
Wall time: 1min 37s
