In [23]:
# %% [code]
"""
MBPP+ GRPO Training on Kaggle TPU
Train Gemma-3-1B on MBPP+ code generation with 10 samples.
"""

import os
import sys

print("=" * 80)
print("MBPP+ GRPO Training - 10 Samples")
print("=" * 80)

# Step 1: Download MBPP+ dataset
print("\n[1/6] Downloading MBPP+ dataset...")
import subprocess
result = subprocess.run(
    ["pip", "install", "-q", "datasets"],
    capture_output=True,
    text=True,
    timeout=60
)
print("  ✓ datasets library installed")

# Download dataset
from datasets import load_dataset
print("  Downloading evalplus/mbppplus...")
os.makedirs("./data/mbppplus_hf", exist_ok=True)
dataset = load_dataset("evalplus/mbppplus", split="test")
dataset.to_parquet("./data/mbppplus_hf/test.parquet")
print(f"  ✓ Downloaded {len(dataset)} samples")

# Step 2: Import libraries
print("\n[2/6] Importing libraries...")
import jax
from flax import nnx
import optax

from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.rl.rollout import base_rollout
from huggingface_hub import snapshot_download

from mbpp_data_loader import get_mbpp_dataset
from reward_functions_mbpp import DEFAULT_REWARD_FNS_MBPP
from config import *

print("  ✓ All libraries imported")

# Step 3: Detect TPU and create mesh
print("\n[3/6] Setting up TPU mesh...")
devices = jax.devices()
device_type = devices[0].platform
num_devices = len(devices)
print(f"  Device type: {device_type}")
print(f"  Number of devices: {num_devices}")

if num_devices == 8:
    # TPU v3-8: reshape to 2D (1, 8)
    print(f"  Using 2D mesh: (1, 8)")
    import numpy as np
    devices_2d = np.array(devices).reshape(1, 8)
    mesh = jax.sharding.Mesh(
        devices_2d,
        axis_names=('fsdp', 'tp')
    )
elif num_devices == 1:
    # Single device: use 2D mesh (1, 1)
    print(f"  Using 2D mesh (1, 1) for single device")
    import numpy as np
    devices_2d = np.array(devices).reshape(1, 1)
    mesh = jax.sharding.Mesh(
        devices_2d,
        axis_names=('fsdp', 'tp')
    )
else:
    # Use config.MESH
    mesh = jax.make_mesh(
        *MESH,
        axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
    )

print(f"  ✓ Mesh created: {mesh}")

# Step 4: Load dataset (10 samples)
print("\n[4/6] Loading 10 MBPP+ samples...")
train_dataset, val_dataset, test_dataset, dataset_lengths = get_mbpp_dataset(
    local_path="./data/mbppplus_hf",
    train_fraction=1.0,
    batch_size=1,
    num_train_batches=10,
    num_test_batches=2,
    num_epochs=1,
    shuffle=False,
)
print(f"  ✓ Loaded {dataset_lengths[0]} training batches")

# Step 5: Load model
print("\n[5/6] Loading Gemma-3-1B model...")
model_config = gemma_lib.ModelConfig.gemma3_1b_it()

# Download model
print("  Downloading from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    ignore_patterns=["*.pth"]
)
print(f"  Model at: {local_model_path}")

# Create model from safetensors
print("  Creating model on TPU mesh...")
with mesh:
    actor_model = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path, model_config, mesh
    )
print("  ✓ Model loaded")

# Create tokenizer
from tunix.generate import tokenizer_adapter as tokenizer_lib
tokenizer = tokenizer_lib.Tokenizer(
    tokenizer_path=TOKENIZER_PATH,
    tokenizer_type='sentencepiece'
)
print("  ✓ Tokenizer loaded")

# Create optimizer and cluster config
optimizer = optax.adamw(learning_rate=LEARNING_RATE)
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=10,
        max_steps=10,
        mini_batch_size=1,
        train_micro_batch_size=1,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_SEQ_LEN,
        temperature=TEMPERATURE,
        top_k=TOP_K,
        top_p=TOP_P,
    ),
)

# Create RL cluster
rl_cluster = rl_cluster_lib.RLCluster(
    actor=actor_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)
print("  ✓ RL Cluster created")

# Step 6: Create GRPO trainer and run
print("\n[6/6] Running GRPO training on 10 samples...")
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    beta=BETA,
    epsilon=EPSILON,
    num_iterations=NUM_ITERATIONS,
)

grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=DEFAULT_REWARD_FNS_MBPP,
    algo_config=grpo_config,
)
print(f"  Config: {NUM_GENERATIONS} generations, beta={BETA}, epsilon={EPSILON}")
print(f"  Reward functions: {len(DEFAULT_REWARD_FNS_MBPP)}")

print("\n" + "=" * 80)
print("Starting training...")
print("=" * 80)

try:
    grpo_trainer.train(
        train_ds=train_dataset,
        eval_ds=test_dataset,
    )

    print("\n" + "=" * 80)
    print("✅ Training completed successfully!")
    print("=" * 80)

except Exception as e:
    print("\n" + "=" * 80)
    print(f"❌ Training failed: {e}")
    print("=" * 80)
    import traceback
    traceback.print_exc()
    sys.exit(1)


MBPP+ GRPO Training - 10 Samples

[1/6] Downloading MBPP+ dataset...
  ✓ datasets library installed


AttributeError: module 'pyarrow' has no attribute 'PyExtensionType'