<a href="https://colab.research.google.com/github/iamfaham/gemma3-rzero/blob/main/gemma3_rzero.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Unsloth official install (fastest path)
!pip install -q unsloth
# Training / quantization stack used by the cells below (transformers, bitsandbytes, TRL).
!pip install -q accelerate transformers bitsandbytes trl

# tqdm and datasets are imported unconditionally by the next cell;
# requests is not referenced by any code visible in this notebook.
!pip install -q requests tqdm  datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.2/53.2 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.2/317.2 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m503.6/503.6 kB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.7/564.7 kB[0m [31m46.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m257.7/257.7 kB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 MB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m40.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.5/132.5 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os, json, re, random, torch
from tqdm import tqdm
from datasets import Dataset


# ---- Unsloth + TRL (GRPO)
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer

# ---- Reproducibility: seed Python's and PyTorch's RNGs
random.seed(42)
torch.manual_seed(42)

# NOTE(review): not referenced by any other code visible in this notebook — confirm it is still needed.
MAX_TOKENS_CHALLENGER = 512

NUM_TASKS = 200          # how many tasks to pre-generate for the curriculum
# Challenger token budgets: max new tokens for task generation and for grading replies
TASK_MAX_NEW_TOKENS  = 96
GRADE_MAX_NEW_TOKENS = 64

# RL knobs (safe for T4; scale later)
NUM_GENERATIONS = 2
BATCH_SIZE      = 2          # MUST be a multiple of NUM_GENERATIONS (asserted in the trainer cell)
MAX_STEPS       = 50

# Judge batch size (smaller = less VRAM, fewer spikes); number of items sent per grading prompt
BATCH_GRADE_SIZE = 3



LR = 1e-5                # conservative LR for LoRA RL
# Directory for solver checkpoints; created up front so the later save cannot fail on a missing path.
SAVE_DIR = "/content/solver_checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Disable Weights & Biases logging (the recorded trainer output reports this variable as deprecated).
os.environ["WANDB_DISABLED"] = "true"
print("Device:", device)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Device: cuda


In [None]:
# --- Challenger: Phi-3.5-mini-instruct (4-bit NF4 + offload) ---
# NOTE: the `gptoss_*` variable names are historical; the model actually loaded is CHALLENGER_ID below.
import os, gc, torch, json, re, contextlib
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Configure the PyTorch CUDA caching allocator through its environment variable.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"

# Primary small challenger (great JSON discipline)
CHALLENGER_ID = "microsoft/Phi-3.5-mini-instruct"

# If a previous challenger model/tokenizer exists (e.g. when re-running this cell), drop the
# globals and release cached GPU memory before loading again.
for _name in ["gptoss_model", "gptoss_tok"]:
    if _name in globals():
        try: del globals()[_name]
        except: pass
torch.cuda.empty_cache(); gc.collect()

has_gpu = torch.cuda.is_available()
# Folder handed to from_pretrained as `offload_folder`.
OFFLOAD_DIR = "/content/offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)

# 4-bit NF4 quantization via BitsAndBytes (double quantization, bf16 compute)
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# accelerate expects integer GPU keys
max_mem = ({0: "14GiB", "cpu": "48GiB"} if has_gpu else {"cpu": "48GiB"})

gptoss_tok = AutoTokenizer.from_pretrained(
    CHALLENGER_ID, use_fast=True, trust_remote_code=True
)

gptoss_model = AutoModelForCausalLM.from_pretrained(
    CHALLENGER_ID,
    quantization_config=(bnb_cfg if has_gpu else None),  # CPU path loads in fp32
    dtype=(torch.bfloat16 if has_gpu else torch.float32),
    device_map=("auto" if has_gpu else "cpu"),
    max_memory=max_mem,
    low_cpu_mem_usage=True,
    offload_folder=OFFLOAD_DIR,
    trust_remote_code=True,
)

# Fall back to EOS as the padding token when the tokenizer does not define one.
if gptoss_tok.pad_token_id is None:
    gptoss_tok.pad_token_id = gptoss_tok.eos_token_id

def challenger_chat(prompt: str, max_tokens: int, deterministic: bool) -> str:
    """Generate a continuation of ``prompt`` with the Challenger model.

    Args:
        prompt: Raw text fed to the tokenizer (no chat template is applied).
        max_tokens: Upper bound on newly generated tokens.
        deterministic: ``True`` selects greedy decoding (used for strict JSON
            task generation); ``False`` samples at temperature 0.15 (grading).

    Returns:
        Only the newly generated text, with leading whitespace removed.
    """
    inputs = gptoss_tok(prompt, return_tensors="pt")
    if has_gpu:
        inputs = {k: v.to(gptoss_model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": not deterministic,
        "pad_token_id": gptoss_tok.eos_token_id,
        "eos_token_id": gptoss_tok.eos_token_id,
        "use_cache": False,  # Disable cache to avoid AttributeError
    }
    if not deterministic:
        # Temperature only applies when sampling; passing 0.0 alongside greedy
        # decoding is meaningless and triggers generation-config warnings.
        gen_kwargs["temperature"] = 0.15

    amp_ctx = (
        torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
        if has_gpu
        else contextlib.nullcontext()
    )
    with torch.no_grad(), amp_ctx:
        out = gptoss_model.generate(**inputs, **gen_kwargs)

    # The returned sequence starts with the prompt tokens. Slice them off by
    # length instead of string-matching the decoded text, which silently leaves
    # the prompt in place whenever decoding does not reproduce it exactly.
    new_tokens = out[0][prompt_len:]
    return gptoss_tok.decode(new_tokens, skip_special_tokens=True).lstrip()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Instruction sent verbatim to the Challenger to obtain one task; the reply is expected to be
# a single JSON object with the keys "passage" and "question".
TASK_PROMPT = (
    "Generate a SHORT (2–3 sentences) passage and a question that requires cause–effect reasoning.\n"
    "Respond ONLY with a single-line JSON object exactly like this:\n"
    '{"passage":"...", "question":"..."}\n'
    "Rules: no prose, no code fences, no explanations, no trailing text. Keys must be 'passage' and 'question'."
)

def _extract_json_block(s: str):
    """Find a JSON object with 'passage' and 'question' keys inside free text.

    Every ``{`` in ``s`` is tried as the start of a JSON object using the real
    JSON parser, so braces inside string values and nested wrapper objects are
    handled (a brace-matching regex mis-splits both).

    Args:
        s: Raw model output, possibly with prose or code fences around the JSON.

    Returns:
        The decoded dict of the longest matching object (the earliest one on a
        tie), or ``None`` when no such object exists.
    """
    import json

    if not s:
        return None

    decoder = json.JSONDecoder()
    best, best_span = None, -1
    start = s.find("{")
    while start != -1:
        try:
            obj, end = decoder.raw_decode(s, start)
        except (ValueError, RecursionError):
            # Not a valid object at this brace; try the next one.
            start = s.find("{", start + 1)
            continue
        if "passage" in obj and "question" in obj:
            if end - start > best_span:
                best, best_span = obj, end - start
            # Anything nested inside this object is shorter; skip past it.
            start = s.find("{", end)
        else:
            # Valid object without the keys: the task may be nested inside it.
            start = s.find("{", start + 1)
    return best

def generate_task(max_retries: int = 2):
    """Ask the Challenger for one cause-effect task.

    The first attempt uses greedy decoding on ``TASK_PROMPT``. If the reply
    does not contain a usable JSON object, up to ``max_retries`` further
    attempts are made with the task prompt plus a corrective reminder.
    Retries sample instead of decoding greedily: a greedy retry on a fixed
    prompt would return the same text every time.

    NOTE(review): the first attempt is greedy on a constant prompt, so repeated
    calls are likely to yield the same task — confirm the curriculum is diverse.

    Args:
        max_retries: Maximum number of corrective attempts after the first one.

    Returns:
        ``{"passage": str, "question": str}`` with stripped values, or the same
        dict with empty strings when no usable task was produced.
    """

    def _usable(candidate):
        # Accept only dicts whose two fields are non-blank strings.
        if not isinstance(candidate, dict):
            return None
        passage, question = candidate.get("passage"), candidate.get("question")
        if not (isinstance(passage, str) and isinstance(question, str)):
            return None
        passage, question = passage.strip(), question.strip()
        if not passage or not question:
            return None
        return {"passage": passage, "question": question}

    raw = challenger_chat(TASK_PROMPT, max_tokens=TASK_MAX_NEW_TOKENS, deterministic=True)
    task = _usable(_extract_json_block(raw))

    # Keep the task description in the retry so the model still knows what to write.
    retry_prompt = (
        TASK_PROMPT
        + "\n\n"
        + "Respond ONLY with a single-line JSON object with exactly keys 'passage' and 'question'. "
        + 'Example: {"passage":"Two sentences here.", "question":"One question here?"}'
    )
    for _ in range(max_retries):
        if task is not None:
            break
        raw = challenger_chat(retry_prompt, max_tokens=TASK_MAX_NEW_TOKENS, deterministic=False)
        task = _usable(_extract_json_block(raw))

    if task is not None:
        return task
    return {"passage": "", "question": ""}

# --- Batched grading + single-answer grading ---
# Judge prompt template; `{batch_text}` is the only placeholder and is filled by evaluate_batch()
# with the numbered items. The expected reply is a JSON list of per-item score lists.
EVAL_PROMPT_TMPL = """You are a strict evaluator. For each item below, score EACH candidate answer in [0,1]
for correctness, completeness, and clarity. Respond ONLY with valid JSON: a list where each element
is a list of floats for that item's candidate answers, in order.

{batch_text}

Return format example (for 2 items, 3 answers each):
[[0.5, 0.8, 0.2],[1.0,0.9,0.1]]
"""

def evaluate_batch(items):
    """Grade candidate answers for a batch of items with the Challenger.

    Args:
        items: List of ``{'passage': str, 'question': str, 'answers': [str, ...]}``.

    Returns:
        The JSON array found in the judge's reply, ideally ``List[List[float]]``
        aligned with the answers of each item. The shape is NOT validated here;
        callers must tolerate flat or short lists. When no JSON array can be
        decoded, a zero score is returned for every answer. An empty ``items``
        list yields ``[]`` without calling the model.
    """
    import json

    if not items:
        return []

    parts = []
    for idx, it in enumerate(items, 1):
        lines = [
            f"Item {idx}:",
            f"Passage: {it['passage']}",
            f"Question: {it['question']}",
            "Candidate answers:",
        ]
        lines.extend(f"  {j}. {a}" for j, a in enumerate(it["answers"], 1))
        parts.append("\n".join(lines))
    prompt = EVAL_PROMPT_TMPL.format(batch_text="\n\n".join(parts))

    raw = challenger_chat(prompt, max_tokens=GRADE_MAX_NEW_TOKENS, deterministic=False)

    # Decode JSON arrays with the real parser, starting at each '['. A greedy
    # "first '[' to last ']'" match breaks as soon as the judge appends any
    # bracketed text after the scores.
    decoder = json.JSONDecoder()
    candidates = []
    start = raw.find("[")
    while start != -1:
        try:
            arr, end = decoder.raw_decode(raw, start)
        except (ValueError, RecursionError):
            start = raw.find("[", start + 1)
            continue
        candidates.append(arr)
        start = raw.find("[", end)

    # Prefer the documented list-of-lists shape over stray arrays such as "[1]".
    for arr in candidates:
        if arr and all(isinstance(row, list) for row in arr):
            return arr
    if candidates:
        return candidates[0]
    return [[0.0] * len(it["answers"]) for it in items]

def evaluate_answer(passage: str, question: str, answer: str) -> float:
    """Grade a single answer by wrapping it in a one-item batch.

    Used by the sanity check at the end of the notebook. Returns the first
    score of the first item as a float, or 0.0 if the grader output cannot be
    indexed or converted.
    """
    single_item = {"passage": passage, "question": question, "answers": [answer]}
    graded = evaluate_batch([single_item])
    try:
        first_score = graded[0][0]
        return float(first_score)
    except Exception:
        return 0.0



In [None]:
# --- Utilities: reload the Challenger on demand after it has been freed ---
def _ensure_challenger_loaded():
    """Load the Challenger model and tokenizer if they are not in globals.

    The reload mirrors the initial load (same 4-bit ``bnb_cfg`` quantization
    and bf16 dtype on GPU, fp32 on CPU). This keeps the memory footprint the
    same as the first load and matches the bf16 autocast used by
    ``challenger_chat``. The pad token is restored the same way as well.

    Side effects:
        Rebinds the module-level ``gptoss_model`` and ``gptoss_tok``.
    """
    global gptoss_model, gptoss_tok
    if "gptoss_model" in globals() and "gptoss_tok" in globals():
        return

    gpu_available = torch.cuda.is_available()
    max_mem = ({0: "14GiB", "cpu": "48GiB"} if gpu_available else {"cpu": "48GiB"})

    gptoss_tok = AutoTokenizer.from_pretrained(CHALLENGER_ID, use_fast=True, trust_remote_code=True)
    gptoss_model = AutoModelForCausalLM.from_pretrained(
        CHALLENGER_ID,
        quantization_config=(bnb_cfg if gpu_available else None),  # CPU path loads in fp32
        dtype=(torch.bfloat16 if gpu_available else torch.float32),
        device_map=("auto" if gpu_available else "cpu"),
        max_memory=max_mem,
        low_cpu_mem_usage=True,
        offload_folder=OFFLOAD_DIR,
        trust_remote_code=True,
    )

    # generate() is called with pad_token_id=eos; keep the tokenizer consistent.
    if gptoss_tok.pad_token_id is None:
        gptoss_tok.pad_token_id = gptoss_tok.eos_token_id

# Turn off PyTorch's "reduced precision reduction" option for bf16 matmuls on CUDA.
# NOTE(review): presumably set for numerical stability of the bf16 judge — confirm it is still wanted.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

In [None]:
# --- Task cache & reuse ---
TASK_CACHE_PATH = "/content/tasks_cache.jsonl"

def load_task_cache(path=None):
    """Read cached tasks from a JSON-lines file.

    Blank lines and lines that are not valid JSON (for example a record
    truncated by an interrupted write) are skipped instead of aborting the
    whole load.

    Args:
        path: File to read; defaults to ``TASK_CACHE_PATH``.

    Returns:
        List of decoded records in file order; empty when the file is missing.
    """
    cache_path = TASK_CACHE_PATH if path is None else path
    pool = []
    try:
        with open(cache_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    pool.append(json.loads(line))
                except json.JSONDecodeError:
                    # Corrupt record: drop it and keep the rest of the cache.
                    continue
    except FileNotFoundError:
        pass
    return pool

def save_task_cache(tasks, path=None):
    """Write tasks to a JSON-lines file, one record per line.

    The data is written to a sibling temporary file first and then moved over
    the target with ``os.replace``. An interruption during the write therefore
    cannot leave a truncated cache behind.

    Args:
        tasks: Iterable of JSON-serializable records.
        path: Destination file; defaults to ``TASK_CACHE_PATH``.
    """
    cache_path = TASK_CACHE_PATH if path is None else path
    tmp_path = f"{cache_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(t, ensure_ascii=False) + "\n" for t in tasks)
    os.replace(tmp_path, cache_path)

# Load previously generated tasks once when this cell runs (empty list when no cache file exists).
_cached_tasks = load_task_cache()


In [None]:
# --- Cell 5: Pre-generate curriculum with cache (REPLACE ENTIRE CELL) ---
def is_valid_task(t):
    """Return True when ``t`` is a dict with non-blank string 'passage' and 'question'.

    Non-string field values (``None``, numbers, lists) make the task invalid
    instead of raising, so records read back from a cache file can be filtered
    safely.
    """
    if not isinstance(t, dict):
        return False
    passage, question = t.get("passage"), t.get("question")
    if not (isinstance(passage, str) and isinstance(question, str)):
        return False
    return bool(passage.strip()) and bool(question.strip())

# Reuse cached tasks if they exist; only top up to NUM_TASKS if needed.
# NOTE: `need` counts every cached record (valid or not), and exactly `need` generation attempts
# are made; attempts that fail validation are dropped, so the pool can end up below NUM_TASKS.
need = max(0, NUM_TASKS - len(_cached_tasks))
if need > 0:
    new_tasks = []
    for _ in tqdm(range(need), desc="Generating tasks"):
        t = generate_task()
        if is_valid_task(t):
            new_tasks.append(t)
    _cached_tasks.extend(new_tasks)
    # The cache is written once, after the whole loop has finished.
    save_task_cache(_cached_tasks)

# Keep only valid tasks and trim to at most NUM_TASKS
tasks = [t for t in _cached_tasks if is_valid_task(t)][:NUM_TASKS]
print(f"Valid tasks found: {len(tasks)}")

# --- Fallback: if the Challenger returned nothing usable, synthesize simple tasks locally ---
def synth_task(i: int):
    """Build a fixed landslide-themed task whose passage mentions village ``i``."""
    return {
        "passage": f"Heavy rainfall saturated the soil on a hillside near village {i}, and the slope became unstable.",
        "question": "What is the most likely effect of saturated soil on a hillside?",
    }

# With no valid task at all, fall back to locally synthesized tasks so training can still run.
if len(tasks) == 0:
    print("No valid tasks from Challenger; using local fallback tasks.")
    tasks = [synth_task(i) for i in range(max(16, NUM_GENERATIONS * 8))]  # at least a few batches

# Store prompts as JSON strings so the reward functions can parse passage/question back out
train_prompts = [json.dumps(t, ensure_ascii=False) for t in tasks]
from datasets import Dataset
# Single-column dataset; the column is named "prompt".
ds = Dataset.from_dict({"prompt": train_prompts})

print("Dataset size:", len(ds))
print("Sample prompt:", train_prompts[0][:200], "...")


Generating tasks: 100%|██████████| 200/200 [28:28<00:00,  8.54s/it]

Valid tasks found: 200
Dataset size: 200
Sample prompt: {"passage": "A sudden drop in temperature caused the lake to freeze over. The local wildlife adapted by migrating to warmer areas.", "question": "What was the effect of the sudden drop in temperature  ...





In [None]:
# --- FREE CHALLENGER GPU MEMORY BEFORE LOADING SOLVER ---
# The challenger globals are deleted here; the reward functions reload them on demand through
# _ensure_challenger_loaded().
import gc, torch
try:
    del gptoss_model
    del gptoss_tok
except NameError:
    pass
torch.cuda.empty_cache()
gc.collect()
# --------------------------------------------------------


# Solver: Gemma 3 4B instruct, loaded through Unsloth as a 4-bit quantized base.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "google/gemma-3-4b-it",
    load_in_4bit = True,   # quantized base
)

# Attach LoRA adapters to the attention and MLP projection layers.
# NOTE: the recorded Unsloth output warns that lora_dropout != 0 disables its fast patching of the
# LoRA matrices, and shows adapters being attached inside the vision tower as well.
model = FastLanguageModel.get_peft_model(
    model,
    r=8,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = FastLanguageModel.for_training(model)
# NOTE(review): some transformers versions reject .to() on bitsandbytes-quantized models; it ran in
# the recorded session — confirm this call is still needed.
model.to(device)


==((====))==  Unsloth 2025.10.1: Fast Gemma3 patching. Transformers: 4.56.2.
   \\   /|    NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.


model.safetensors:   0%|          | 0.00/4.56G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.json: 0.00B [00:00, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.


Unsloth: Making `base_model.model.model.vision_tower.vision_model` require gradients


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3ForConditionalGeneration(
      (model): Gemma3Model(
        (vision_tower): SiglipVisionModel(
          (vision_model): SiglipVisionTransformer(
            (embeddings): SiglipVisionEmbeddings(
              (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
              (position_embedding): Embedding(4096, 1152)
            )
            (encoder): SiglipEncoder(
              (layers): ModuleList(
                (0-26): 27 x SiglipEncoderLayer(
                  (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                  (self_attn): SiglipAttention(
                    (k_proj): lora.Linear(
                      (base_layer): Linear(in_features=1152, out_features=1152, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.05, inplace=False)
                      )
                      (lor

In [None]:
# --- Cell 7: Reward functions (each returns one score per entry of the batch it is given) ---
EXPECTED_GENERATIONS = NUM_GENERATIONS  # alias of NUM_GENERATIONS from the config cell


def _coerce_score(value) -> float:
    """Turn one entry of the judge output into a finite score in [0, 1] (0.0 if unusable)."""
    import math

    if isinstance(value, (list, tuple)):
        # Expected shape is [score] per item; only the first candidate is used.
        if not value:
            return 0.0
        value = value[0]
    try:
        score = float(value)
    except (TypeError, ValueError):
        return 0.0
    if not math.isfinite(score):
        # JSON decoding accepts NaN/Infinity; never let them become a reward.
        return 0.0
    return min(1.0, max(0.0, score))


def _grade_many(passage_list, question_list, answer_list, batch_size=BATCH_GRADE_SIZE):
    """Grade (passage, question, answer) triples with the Challenger in batches.

    Each triple is sent as an item with a single candidate answer. The judge
    output is normalized entry by entry, so one malformed entry only zeroes
    its own item. Accepted shapes per chunk are ``[[s], [s], ...]``,
    ``[s, s, ...]`` and a bare number (replicated for the chunk). Scores are
    clamped to [0, 1]; missing or unusable entries become 0.0.

    Args:
        passage_list: Passages, aligned with the other two lists.
        question_list: Questions.
        answer_list: Candidate answers to grade.
        batch_size: Number of items per judge prompt (values below 1 are
            treated as 1).

    Returns:
        ``list[float]`` with exactly ``len(answer_list)`` entries.
    """
    items = [
        {"passage": p, "question": q, "answers": [a]}  # single answer per item
        for p, q, a in zip(passage_list, question_list, answer_list)
    ]
    step = max(1, int(batch_size))

    scores_all = []
    for i in range(0, len(items), step):
        chunk = items[i:i + step]
        raw_scores = evaluate_batch(chunk)

        if isinstance(raw_scores, (list, tuple)):
            normalized = [_coerce_score(entry) for entry in raw_scores[:len(chunk)]]
        elif isinstance(raw_scores, (int, float)):
            normalized = [_coerce_score(raw_scores)] * len(chunk)
        else:
            normalized = []

        # Pad when the judge returned fewer scores than items.
        normalized += [0.0] * (len(chunk) - len(normalized))
        scores_all.extend(normalized)

    # Ensure final length matches input size
    expected = len(answer_list)
    if len(scores_all) < expected:
        scores_all += [0.0] * (expected - len(scores_all))
    return scores_all[:expected]


def _extract_passage_question(prompt_json_str):
    """Decode a JSON prompt string into its (passage, question) pair.

    Missing keys default to "". Any failure while decoding or reading the
    fields yields ("", "").
    """
    try:
        payload = json.loads(prompt_json_str)
        passage = payload.get("passage", "")
        question = payload.get("question", "")
    except Exception:
        return "", ""
    return passage, question

def _completion_text(completion, gen_idx: int = 0) -> str:
    """Return the plain text of one completion entry handed over by the trainer."""
    if isinstance(completion, str):
        # Standard (non-conversational) format: the completion IS the text.
        return completion
    if isinstance(completion, dict):
        value = completion.get("content") or completion.get("text") or completion.get("generated_text") or ""
        return value if isinstance(value, str) else str(value)
    if isinstance(completion, (list, tuple)):
        if completion and all(isinstance(part, dict) for part in completion):
            # Conversational format: a list of message dicts for one completion.
            return "".join(_completion_text(part) for part in completion)
        # Legacy shape: a group of candidate texts; pick the requested one.
        if 0 <= gen_idx < len(completion):
            return _completion_text(completion[gen_idx])
        return ""
    return str(completion)


def make_reward_func(gen_idx: int = 0):
    """Create a reward function that grades every completion with the Challenger.

    TRL calls a reward function with flat, aligned ``prompts`` and
    ``completions`` lists: one entry per sampled completion, with each prompt
    repeated once per generation. Each completion is therefore graded as a
    whole. Indexing into a string completion with ``gen_idx`` would grade a
    single character.

    Args:
        gen_idx: Only used when a completion entry is itself a list of
            candidate texts (legacy shape); selects which candidate to grade.

    Returns:
        ``reward_func_i(prompts, completions, **kwargs) -> list[float]`` with
        one score per entry of ``completions``.
    """
    def reward_func_i(prompts, completions, **kwargs):
        # prompts: List[str] (JSON strings); completions: one entry per sampled completion
        _ensure_challenger_loaded()
        passages, questions, answers = [], [], []
        for p_str, completion in zip(prompts, completions):
            passage, question = _extract_passage_question(p_str)
            passages.append(passage)
            questions.append(question)
            answers.append(_completion_text(completion, gen_idx))

        # One score per completion, in the same order as `completions`.
        return _grade_many(passages, questions, answers)
    return reward_func_i

# One reward function is enough: it already receives every generation of the batch.
# Building one per generation would only repeat the same judge calls.
reward_funcs = [make_reward_func(0)]


In [None]:
# --- Cell 8: GRPO config & trainer (ensure gens/batch match) ---
grpo_cfg = GRPOConfig(
    learning_rate = LR,
    per_device_train_batch_size = BATCH_SIZE,   # set in the config cell (currently 2)
    num_generations = NUM_GENERATIONS,          # set in the config cell (currently 2)
    max_steps = MAX_STEPS,
    logging_steps = 10,
    bf16 = False, fp16 = False,                 # mixed-precision flags of the trainer left off
)

# NOTE: this check runs after GRPOConfig has already been constructed.
assert BATCH_SIZE % NUM_GENERATIONS == 0, "BATCH_SIZE must be a multiple of NUM_GENERATIONS"
print(f"Using num_generations={NUM_GENERATIONS}, per_device_train_batch_size={BATCH_SIZE}")

# --- Sanity checks before GRPOTrainer(...) ---
assert len(ds) > 0, "Training dataset is empty. Re-run Cell 5 to build tasks."
# Same check as above, repeated.
assert BATCH_SIZE % NUM_GENERATIONS == 0, "BATCH_SIZE must be a multiple of NUM_GENERATIONS"
print(f"Ready to train with {len(ds)} prompts | gens={NUM_GENERATIONS}, batch_size={BATCH_SIZE}, steps={MAX_STEPS}")


# NOTE(review): `tokenizer=` and `dataset_text_field=` were accepted in the recorded run; confirm
# they are still valid keyword arguments for the installed TRL/Unsloth versions.
trainer = GRPOTrainer(
    model = model,
    tokenizer = tokenizer,
    args = grpo_cfg,
    reward_funcs = reward_funcs,      # <-- pass the list created in Cell 7
    train_dataset = ds,
    dataset_text_field = "prompt",
)


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Using num_generations=2, per_device_train_batch_size=2
Ready to train with 200 prompts | gens=2, batch_size=2, steps=50


In [None]:
# --- Cell 9: Train ---
# NOTE: The dataset 'ds' holds all prompts; TRL calls the reward functions, which grade in batches
# through the Challenger.
trainer.train()


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 200 | Num Epochs = 1 | Total steps = 50
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 2 x 1) = 4
 "-____-"     Trainable parameters = 16,394,240 of 4,316,473,712 (0.38% trained)
`generation_config` default values have been modified to match model-specific defaults: {'top_p': 0.95}. If this is not desired, please set these values explicitly.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / reward_func_i / mean,rewards / reward_func_i / std,rewards / reward_func_i / mean.1,rewards / reward_func_i / std.1
10,0.0,1.45375,0.309359,256.0,256.0,256.0,1.0,0.0,0.0,0.0,0,0,0,0,0,0.000432,0.726875,0.213592,0.726875,0.213592
20,0.0,1.4725,0.251023,256.0,256.0,256.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,0.000821,0.73625,0.190515,0.73625,0.190515
30,0.0,1.45,0.240416,256.0,256.0,256.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,0.000827,0.725,0.166818,0.725,0.166818
40,0.0,1.5925,0.116673,256.0,256.0,256.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,0.000792,0.79625,0.097373,0.79625,0.097373
50,0.0,1.5425,0.194454,256.0,256.0,256.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,0.00084,0.77125,0.179511,0.77125,0.179511


TrainOutput(global_step=50, training_loss=1.6384216360165737e-06, metrics={'train_runtime': 3493.7756, 'train_samples_per_second': 0.057, 'train_steps_per_second': 0.014, 'total_flos': 0.0, 'train_loss': 1.6384216360165737e-06})

In [None]:
# Save the trained solver (model.save_pretrained) and its tokenizer under SAVE_DIR.
save_path = f"{SAVE_DIR}/gemma3_4b_grpo_lora_rzero"
os.makedirs(save_path, exist_ok=True)
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print("Saved to:", save_path)


Saved to: /content/solver_checkpoints/gemma3_4b_grpo_lora_rzero


In [None]:
# --- End-to-end sanity check: generate one task, let the solver answer it, grade the answer ---
# The challenger may have been freed before the solver was loaded; reload it if necessary.
_ensure_challenger_loaded()

test = generate_task()
print(test)

prompt = json.dumps(test, ensure_ascii=False)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=80)

# generate() returns prompt + continuation. Decode only the new tokens, otherwise the "answer"
# handed to the judge is an echo of the passage and question.
prompt_len = inputs["input_ids"].shape[1]
ans = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
print("\nSolver answer:\n", ans)

score = evaluate_answer(test["passage"], test["question"], ans)
print("\nChallenger score:", score)


{'passage': 'Drought leads to crop failure. Crop failure results in food shortages. Food shortages cause price hikes. Price hikes lead to economic instability.', 'question': 'What is the effect of drought on economic stability?'}

Solver answer:
 {"passage": "Drought leads to crop failure. Crop failure results in food shortages. Food shortages cause price hikes. Price hikes lead to economic instability.", "question": "What is the effect of drought on economic stability?"}


Challenger score: 1.0
