# 🧠 Teach Gemma 3 to Reason: The "Plan-Evaluate-Execute" Loop (Final Research Version)

**Objective:** Train Gemma 3 1B to simulate System 2 thinking (deep reasoning) using GRPO.

**Fixed Issues in this Version:**
* ✅ **Fixed Mode Collapse:** Solved the 0% accuracy issue where the model only generated empty XML tags.
* ✅ **Fixed Cold Start:** Injected a "One-Shot" example into the training data so the model knows *what* to write.
* ✅ **Fixed Infinite Loops:** Added `repetition_penalty` to inference to stop `</answer></answer>...` spam.
* ✅ **Robust Config:** Patched `NoneType` errors in the Tunix configuration loader.

**Hardware:** TPU VM v3-8 (Recommended) or GPU T4 x2.

---

### **Step 1: Environment Setup**
*Installs the correct JAX/Tunix ecosystem. **Restart Kernel** after running this.*

In [None]:
# --- INSTALLATION ---
import os
import sys

print("Installing Research Environment...")

# 1. Clean Slate
!pip uninstall -y jax jaxlib flax tunix qwix libtpu-nightly

# 2. Install JAX (Detects TPU or GPU automatically)
try:
    if os.environ.get("TPU_NAME"):
        !pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    else:
        !pip install -U "jax[cuda12]"
except:
    pass

# 3. Install Tunix Ecosystem
!pip install git+https://github.com/google/tunix.git
!pip install git+https://github.com/google/qwix.git
!pip install git+https://github.com/google/flax.git

# 4. Dependencies
!pip install -q "numpy>2" tensorflow tensorflow_datasets tensorboardX transformers grain huggingface_hub datasets

print("\n" + "="*50)
print("‚ö†Ô∏è  ACTION REQUIRED: RESTART KERNEL NOW  ‚ö†Ô∏è")
print("Go to 'Run' > 'Restart Kernel' or 'Runtime' > 'Restart Session'")
print("="*50)

### **Step 2: Initialization & Config Injection**
*Loads the model and applies the critical configuration patches.*

In [None]:
import os
import jax
import jax.numpy as jnp
from flax import nnx
import qwix
from huggingface_hub import login, snapshot_download
import json
import inspect
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib

# --- CHECK HARDWARE ---
# Report the JAX version and the visible devices, and record in IS_TPU whether
# the first device is a TPU. Later cells read IS_TPU to choose the mesh layout,
# batch size and number of generations.
print(f"‚úÖ JAX Version: {jax.__version__}")
try:
    devices = jax.devices()
    print(f"‚úÖ Detected Devices: {len(devices)} ({devices[0].platform})")
    # Compare the platform name directly instead of substring-matching the
    # device repr, whose format is not a stable contract.
    IS_TPU = devices[0].platform == "tpu"
except Exception as e:
    # No usable backend: fall back to the non-TPU settings.
    print(f"‚ùå Hardware Error: {e}")
    IS_TPU = False

# --- AUTH ---
# Log in to the Hugging Face Hub. Prefer the "HF_TOKEN" Kaggle secret; if the
# kaggle_secrets module, the secret, or the token login is unavailable, fall
# back to the interactive `login()` prompt.
#
# HF_TOKEN is always defined afterwards (None on the fallback path) because the
# model download below passes it as `token=`.
HF_TOKEN = None
try:
    # Imported inside the try so the cell also works outside Kaggle, where
    # this module does not exist.
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
    login(token=HF_TOKEN)
except Exception:
    # Discard a token that could not be fetched or was rejected, so later
    # calls do not reuse it.
    HF_TOKEN = None
    print("‚ö†Ô∏è Secrets not found. Attempting manual login...")
    login()

# --- MODEL CONFIGURATION ---
# Download the checkpoint from the Hugging Face Hub, read its config.json and
# translate the relevant fields into a Tunix Gemma 3 ModelConfig.
MODEL_ID = "google/gemma-3-1b-it"
local_path = snapshot_download(repo_id=MODEL_ID, ignore_patterns=["*.pth"], token=HF_TOKEN)

# Read as UTF-8 explicitly so the result does not depend on the platform's
# default locale encoding.
with open(os.path.join(local_path, "config.json"), "r", encoding="utf-8") as f:
    hf_config = json.load(f)

# Robust Config Mapping
# Keys are ModelConfig argument names, values are the Hugging Face config.json
# keys they are read from.
mappings = {
    "num_embed": "vocab_size", "embed_dim": "hidden_size", "hidden_dim": "intermediate_size",
    "num_heads": "num_attention_heads", "num_kv_heads": "num_key_value_heads",
    "num_layers": "num_hidden_layers", "head_dim": "head_dim", "sliding_window_size": "sliding_window",
    "rope_base_frequency": "rope_theta"
}

# Only pass arguments the installed ModelConfig actually accepts. If its
# signature cannot be introspected, assume every mapped name is accepted.
try:
    valid_keys = set(inspect.signature(gemma_lib.ModelConfig).parameters.keys())
except (TypeError, ValueError):
    # The two exceptions inspect.signature documents for unsupported callables.
    valid_keys = set(mappings.keys())

# Keep a mapping only when the target argument is accepted AND the source key
# is present in the downloaded config.
config_args = {
    tunix_k: hf_config[hf_k]
    for tunix_k, hf_k in mappings.items()
    if tunix_k in valid_keys and hf_k in hf_config
}

model_config = gemma_lib.ModelConfig(**config_args)

# --- ATTRIBUTE INJECTION (Fixes NoneType Errors) ---
# For each (attribute, fallback) pair: when the attribute is missing from
# model_config or is None, set it to the fallback value. Existing non-None
# values are left untouched.
for _attr_name, _fallback in (
    ("query_pre_attn_scalar", 256.0),
    ("sliding_window_size", 4096),
):
    try:
        if getattr(model_config, _attr_name, None) is None:
            # object.__setattr__ bypasses any __setattr__ defined on the
            # config class (presumably it is frozen -- TODO confirm).
            object.__setattr__(model_config, _attr_name, _fallback)
    except AttributeError:
        # The attribute cannot be set on this object; skip it.
        continue

# --- MESH SETUP ---
# Build a 2-D ("fsdp", "tp") device mesh, load the base weights onto it, and
# wrap the base model with LoRA adapters to obtain the trainable policy.
print("Initializing Mesh...")
_device_count = len(jax.devices())
# On TPU every device is placed on the "fsdp" axis; otherwise every device is
# placed on the "tp" axis. The remaining axis always has size 1.
_mesh_shape = (_device_count, 1) if IS_TPU else (1, _device_count)
mesh = jax.make_mesh(_mesh_shape, axis_names=("fsdp", "tp"))

with mesh:
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
        local_path, model_config, mesh
    )

    # Modules whose path matches this regex receive rank-32 LoRA adapters.
    lora_provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum",
        rank=32,
        alpha=64.0,
    )

    # The model supplies its own example inputs, which are forwarded to qwix
    # as keyword arguments.
    model_input = base_model.get_model_input()
    lora_policy = qwix.apply_lora_to_model(base_model, lora_provider, **model_input)

    # Re-apply the partition spec to the LoRA-wrapped state and write the
    # constrained state back into the policy.
    state = nnx.state(lora_policy)
    sharded_state = jax.lax.with_sharding_constraint(
        state, nnx.get_partition_spec(state)
    )
    nnx.update(lora_policy, sharded_state)

print("‚úÖ Policy Model Initialized (Attributes Patched).")

### **Step 3: Data Pipeline (The "Cold Start" Fix)**
**CRITICAL UPDATE:** We inject a "One-Shot" example into every prompt. This prevents the model from generating empty tags by showing it *exactly* what filled tags look like.

In [None]:
import tensorflow_datasets as tfds
import grain.python as grain
import re

# 1. THE GOLDEN EXAMPLE (Teaches the model HOW to reason)
# A fully worked toy problem showing all four tags with non-empty content.
# It is interpolated verbatim into every training and inference prompt.
# NOTE(review): the example carries its own `<end_of_turn>` and
# `<start_of_turn>model` markers, yet the prompts that embed it place it
# inside the user turn and then continue with "Problem: ..." without opening
# a new `<start_of_turn>user`. Confirm this turn structure is intended.
ONE_SHOT_EXAMPLE = (
    "Problem: I have 2 apples and buy 2 more. How many total?\n"
    "<end_of_turn>\n<start_of_turn>model\n"
    "<brainstorm> 1. Simple addition: 2+2. 2. Counting on fingers. </brainstorm>\n"
    "<evaluate> Addition is faster and standard. </evaluate>\n"
    "<solve> 2 + 2 = 4. </solve>\n"
    "<answer> 4 </answer><end_of_turn>\n"
)

# Instruction text placed at the start of the user turn. It names the four
# tags (<brainstorm>, <evaluate>, <solve>, <answer>) that the reward functions
# later look for in completions.
SYSTEM_PROMPT = (
    "You are an advanced reasoning engine. For every problem, you must follow this strict XML structure:\n"
    "1. <brainstorm> List 2-3 distinct approaches. </brainstorm>\n"
    "2. <evaluate> Critique the approaches. </evaluate>\n"
    "3. <solve> Solve step-by-step. </solve>\n"
    "4. <answer> Put the final number here. </answer>"
)

def get_structured_dataset(split="train"):
    """Build a shuffled, batched Grain dataset of formatted GSM8K prompts.

    The requested TFDS split of "gsm8k" is loaded fully into memory, each
    example is turned into a prompt record, and the records are shuffled with
    a fixed seed and batched.

    Args:
        split: TFDS split name passed to `tfds.load`.

    Returns:
        A `grain.MapDataset` of batches whose records have the keys
        "prompts" (full chat prompt), "question" (raw question text) and
        "answer" (text after the last "####" in the reference solution).
    """
    raw_examples = list(
        tfds.load("gsm8k", split=split, as_supervised=False).as_numpy_iterator()
    )

    def to_prompt_record(example):
        # TFDS yields bytes; decode before doing any string work.
        question = example['question'].decode('utf-8')
        solution = example['answer'].decode('utf-8')
        # Everything after the last "####" is treated as the final answer.
        final_answer = solution.split("####")[-1].strip()

        # Instructions, then the worked example, then the actual problem,
        # ending with an open model turn for the policy to complete.
        prompt = (
            "<start_of_turn>user\n" + SYSTEM_PROMPT + "\n\n"
            + "Example:\n" + ONE_SHOT_EXAMPLE + "\n\n"
            + "Problem: " + question + "<end_of_turn>\n<start_of_turn>model\n"
        )
        return {"prompts": prompt, "question": question, "answer": final_answer}

    # Batch size optimization (8 for TPU, 1 for GPU)
    per_batch = 8 if IS_TPU else 1

    dataset = grain.MapDataset.source(raw_examples)
    dataset = dataset.map(to_prompt_record)
    dataset = dataset.shuffle(seed=42)
    return dataset.batch(per_batch)

# Build the batched training pipeline and repeat it 100 times.
# NOTE(review): presumably the repeat keeps the dataset from running out before
# the trainer's step limit is reached -- confirm against the training config.
train_ds = get_structured_dataset("train").repeat(100)
print("‚úÖ Dataset Ready with One-Shot Injection.")

### **Step 4: Reward Engineering (The "Anti-Hacking" Fix)**
We introduce penalties to stop the model from gaming the system.
* **Content Penalty:** `-5.0` points if `<answer>` tags are empty.
* **Strict Correctness:** `+10.0` points for correct math (jackpot).

In [None]:
# 1. CONTENT PENALTY (Prevents Empty Tags)
def content_presence_reward(prompts, completions, **kwargs):
    """Reward completions whose <answer> tag actually contains something.

    Args:
        prompts: Unused; accepted for the reward-function calling convention.
        completions: Iterable of generated completion strings.
        **kwargs: Ignored extra fields supplied by the trainer.

    Returns:
        A list with one float per completion:
          *  0.5 when both "<answer>" and "</answer>" are present and the text
             between the first "<answer>" and the next tag boundary is
             non-blank,
          * -5.0 when both tags are present but that text is blank,
          *  0.0 when either tag is missing or the completion cannot be
             inspected (e.g. it is not a string).
    """
    rewards = []
    for c in completions:
        try:
            if "<answer>" in c and "</answer>" in c:
                content = c.split("<answer>")[1].split("</answer>")[0].strip()
                # Empty tags are penalised hard to discourage reward hacking.
                score = 0.5 if content else -5.0
            else:
                score = 0.0
        except Exception:
            # Malformed completion (e.g. None): neutral reward. Using
            # `Exception` rather than a bare except keeps Ctrl-C working
            # during training.
            score = 0.0
        rewards.append(score)
    return rewards

# 2. STRICT CORRECTNESS (High Stakes)
def _numeric_answers_match(pred_clean, gt_clean):
    """Return True when two cleaned numeric strings denote the same number."""
    if not pred_clean:
        # An empty prediction never counts, even against an empty target.
        return False
    if pred_clean == gt_clean:
        return True
    # Fall back to a numeric comparison so "18.0" or "18." (a sentence-ending
    # period left over after cleaning) still matches "18".
    try:
        return float(pred_clean) == float(gt_clean)
    except ValueError:
        # Not parseable as a single number (e.g. "1.2.3" or "-").
        return False


def correctness_reward_strict(prompts, completions, answer, **kwargs):
    """Give a large reward when the <answer> content equals the ground truth.

    The text after the first "<answer>" (up to "</answer>" when present) and
    the ground truth are both reduced to the characters 0-9, "." and "-".
    They match when the cleaned strings are identical or parse to the same
    number.

    Args:
        prompts: Unused; accepted for the reward-function calling convention.
        completions: Iterable of generated completion strings.
        answer: Iterable of ground-truth answer strings, aligned with
            `completions`.
        **kwargs: Ignored extra fields supplied by the trainer.

    Returns:
        A list with one float per (completion, answer) pair: 10.0 for a
        match, 0.0 otherwise (including a missing tag, an empty prediction or
        an uninspectable completion).
    """
    rewards = []
    for c, gt in zip(completions, answer):
        try:
            if "<answer>" in c:
                pred = c.split("<answer>")[1].split("</answer>")[0].strip()
                pred_clean = re.sub(r"[^0-9\.\-]", "", pred)
                gt_clean = re.sub(r"[^0-9\.\-]", "", gt)
                score = 10.0 if _numeric_answers_match(pred_clean, gt_clean) else 0.0
            else:
                score = 0.0
        except Exception:
            # Non-string completion or answer: neutral reward. `Exception`
            # (not a bare except) lets KeyboardInterrupt propagate.
            score = 0.0
        rewards.append(score)
    return rewards

# 3. FORMAT GATE (Weak Nudge)
def format_reward_weak(prompts, completions, **kwargs):
    """Score each completion on whether all four opening tags are present.

    Only the opening tags are checked; their order and the closing tags are
    not.

    Args:
        prompts: Unused; accepted for the reward-function calling convention.
        completions: Iterable of generated completion strings.
        **kwargs: Ignored extra fields supplied by the trainer.

    Returns:
        A list with one float per completion: 0.1 when "<brainstorm>",
        "<evaluate>", "<solve>" and "<answer>" all occur, otherwise -1.0.
    """
    needed = ("<brainstorm>", "<evaluate>", "<solve>", "<answer>")

    def score(text):
        for tag in needed:
            if tag not in text:
                return -1.0  # Penalty for breaking structure
        return 0.1

    return [score(text) for text in completions]

# Status message confirming the reward-function cell ran to completion.
print("‚úÖ Rewards Ready: Penalties Active.")

### **Step 5: GRPO Training**
We train with `beta=0.1` (the coefficient of the KL penalty against the frozen reference model) to keep the policy from drifting too far from the base model (Mode Collapse protection).

In [None]:
import optax
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.generate import tokenizer_adapter as tokenizer_lib

# Tokenizer model is read from a Google Cloud Storage path.
# NOTE(review): this requires access to the gs:// bucket at runtime -- confirm
# the environment can reach it.
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
# Absolute path, resolved against the current working directory, that is
# handed to the trainer as its checkpoint root.
CHECKPOINT_DIR = os.path.abspath("checkpoints/grpo_final")

# Actor, reference and rollout roles all run on the same device mesh.
_role_to_mesh = {
    role: mesh
    for role in (
        rl_cluster_lib.Role.ACTOR,
        rl_cluster_lib.Role.REFERENCE,
        rl_cluster_lib.Role.ROLLOUT,
    )
}

# Optimizer, step budget, batch size and checkpoint location for the actor.
# NOTE(review): eval_every_n_steps (1000) is larger than max_steps (300);
# presumably this is meant to disable periodic evaluation -- confirm.
_training_config = rl_cluster_lib.RLTrainingConfig(
    actor_optimizer=optax.adamw(learning_rate=2e-6),
    max_steps=300,
    mini_batch_size=8 if IS_TPU else 1,
    checkpoint_root_directory=CHECKPOINT_DIR,
    eval_every_n_steps=1000,
)

# Sampling settings used when generating completions during training.
_rollout_config = base_rollout.RolloutConfig(
    max_tokens_to_generate=512,
    temperature=0.85,
    eos_tokens=[tokenizer.eos_id()],
)

cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh=_role_to_mesh,
    training_config=_training_config,
    rollout_config=_rollout_config,
)

# Wire the models into the RL cluster: the LoRA-wrapped policy is the actor
# being trained, and the base model it was created from is the reference.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=base_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# GRPO learner using the three reward functions defined earlier in this file.
# NOTE(review): how the individual rewards are combined is decided inside
# GRPOLearner and is not visible here.
trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[format_reward_weak, content_presence_reward, correctness_reward_strict],
    algo_config=GRPOConfig(
        # Completions sampled per prompt: 8 on TPU, 4 otherwise.
        num_generations=8 if IS_TPU else 4,
        beta=0.1, # Increased Beta to stabilize logic
    ),
)

# Run the GRPO training loop inside the device-mesh context.
print("üöÄ Starting Final GRPO Training Loop...")
with mesh:
    # NOTE(review): the second positional argument is presumably the
    # evaluation dataset, disabled here by passing None -- confirm.
    trainer.train(train_ds, None)
print("‚úÖ Training Complete.")

### **Step 6: Inference & Evaluation (Loop Fix)**
We sample with `repetition_penalty=1.2` so that the model stops repeating closing tags (e.g. `</answer></answer>...`) and terminates once the answer is written.

In [None]:
from tunix.generate import sampler as sampler_lib
import tunix.generate
import inspect

# --- ROBUST CACHE CONFIG --- #
# Build the sampler's cache config from the model dimensions, passing only
# the arguments that the installed CacheConfig constructor accepts.
CacheConfigClass = tunix.generate.sampler.CacheConfig
valid_keys = set(inspect.signature(CacheConfigClass).parameters)

# Candidate settings; any name the constructor does not know is dropped below.
cache_args = dict(
    cache_size=2048,
    num_layers=model_config.num_layers,
    num_kv_heads=model_config.num_kv_heads,
    head_dim=model_config.head_dim,
    dtype=jnp.bfloat16,
)

final_args = {}
for _name, _value in cache_args.items():
    if _name in valid_keys:
        final_args[_name] = _value

cache_cfg = CacheConfigClass(**final_args)

# --- ROBUST SAMPLER INIT --- #
# Try the keyword form first. If the constructor rejects one of these keyword
# names (TypeError), retry with the same three values passed positionally.
try:
    sampler = sampler_lib.Sampler(module=lora_policy, tokenizer=tokenizer, cache_config=cache_cfg)
except TypeError:
    sampler = sampler_lib.Sampler(lora_policy, tokenizer, cache_cfg)

# --- TEST PROMPT --- #
# Same layout as the training prompts: instructions, the one-shot example,
# then the problem, ending with an open model turn for the sampler to fill.
prompt_text = (
    f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n"
    f"Example:\n{ONE_SHOT_EXAMPLE}\n\n"
    "Problem: A store sells apples for $2 and oranges for $3. Alice buys 5 fruits and spends $12. How many apples did she buy?"
    "<end_of_turn>\n<start_of_turn>model\n"
)

# Generate one completion for the test prompt inside the device-mesh context.
print("üß† Generating with Repetition Penalty (Anti-Loop)...")
with mesh:
    try:
        # The Repetition Penalty is key here
        outputs = sampler(input_strings=[prompt_text], max_generation_steps=1024, temperature=0.7, repetition_penalty=1.2)
    except TypeError as err:
        # Fallback for older Tunix versions without rep_penalty kwarg.
        # Say so explicitly, because the output may loop without the penalty.
        print(f"Sampler rejected repetition_penalty ({err}); retrying without it.")
        # Keep the keyword arguments: passed positionally, 0.7 would bind to
        # whatever the third parameter of the sampler is, which is not
        # guaranteed to be `temperature`.
        outputs = sampler(input_strings=[prompt_text], max_generation_steps=1024, temperature=0.7)

# --- PARSING --- #
# Show the raw generation, then pull out the text between the first
# "<answer>" and the following "</answer>" (if any).
output_text = outputs.text[0]
_rule = "=" * 60
print(_rule)
print(f"üìù RAW OUTPUT:\n{output_text}")
print(_rule)

if "<answer>" not in output_text:
    print("‚ùå No Answer Tag Found (Check logs).")
else:
    _after_open_tag = output_text.split("<answer>")[1]
    ans = _after_open_tag.split("</answer>")[0]
    print(f"üéØ FINAL ANSWER EXTRACTED: {ans}")