# Run Inference on GRPO Checkpoint

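Load a trained GRPO (LoRA) checkpoint and run inference on math problems. Set `USE_LORA = False` in the configuration cell to run the pure base model without loading a checkpoint.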

## Imports

In [1]:
import os
from pathlib import Path

import jax

from tunix_hack.models import load_model, load_tokenizer, list_checkpoints, find_checkpoint, restore_checkpoint
from tunix_hack.inference import create_sampler, generate

# Report which JAX build is installed and which devices it can see.
jax_version = jax.__version__
print(f"JAX version: {jax_version}")
visible_devices = jax.devices()
print(f"Devices: {visible_devices}")

JAX version: 0.8.1
Devices: [CudaDevice(id=0)]


W1202 20:46:45.915603  165749 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1202 20:46:45.918463  165572 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


## Configuration

In [2]:
# Paths. PROJECT_ROOT defaults to the original author's checkout, but it can be
# redirected with the TUNIX_HACK_ROOT environment variable so the notebook is
# not tied to a single machine.
PROJECT_ROOT = Path(os.environ.get("TUNIX_HACK_ROOT", "/home/jimnix/gitrepos/tunix-hack"))
CKPT_ROOT = PROJECT_ROOT / "outputs" / "checkpoints" / "grpo"

# Model
MODEL_ID = "google/gemma-3-1b-it"
# (axis sizes, axis names) unpacked into jax.make_mesh; a 1x1 mesh = one device.
MESH = ((1, 1), ("fsdp", "tp"))

# LoRA config (set USE_LORA=False for pure base model)
USE_LORA = True
LORA_RANK = 16
LORA_ALPHA = 16

# Checkpoint selection (only used if USE_LORA=True)
RUN_NAME = "demo"
CHECKPOINT_STEP = 45

# Generation config: passed to generate() as max_tokens.
MAX_TOKENS = 256

print(f"Model: {MODEL_ID}")
print(f"USE_LORA: {USE_LORA}")
if USE_LORA:
    print(f"Checkpoint: {RUN_NAME}/step {CHECKPOINT_STEP}")

Model: google/gemma-3-1b-it
USE_LORA: True
Checkpoint: demo/step 45


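## Select Checkpoint

List the checkpoints under `CKPT_ROOT` and resolve the one chosen by `RUN_NAME` / `CHECKPOINT_STEP` in the configuration cell.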

In [3]:
# List the checkpoints under CKPT_ROOT (marking the configured one) and resolve
# the path for RUN_NAME / CHECKPOINT_STEP. With USE_LORA=False no checkpoint is
# needed, and ckpt_path is set to None.
if USE_LORA:
    checkpoints = list_checkpoints(CKPT_ROOT)

    if not checkpoints:
        raise ValueError(f"No checkpoints found in {CKPT_ROOT}")

    print("Available checkpoints:")
    print("-" * 40)
    for ckpt in checkpoints:
        # Compare steps as strings on both sides so the marker works whether
        # list_checkpoints reports the step as an int or as a str.
        is_selected = ckpt["run"] == RUN_NAME and str(ckpt["step"]) == str(CHECKPOINT_STEP)
        marker = " <--" if is_selected else ""
        print(f"  {ckpt['run']} / step {ckpt['step']}{marker}")
    print("-" * 40)

    # Find the selected checkpoint; find_checkpoint returns None when absent.
    ckpt_path = find_checkpoint(CKPT_ROOT, RUN_NAME, CHECKPOINT_STEP)
    if ckpt_path is None:
        raise ValueError(f"Checkpoint not found: {RUN_NAME}/step {CHECKPOINT_STEP}")
    print(f"\nWill load: {ckpt_path}")
else:
    ckpt_path = None
    print("USE_LORA=False: Using pure base model (no checkpoint)")

Available checkpoints:
----------------------------------------
  demo / step 30
  demo / step 40
  demo / step 45 <--
----------------------------------------

Will load: /home/jimnix/gitrepos/tunix-hack/outputs/checkpoints/grpo/demo/actor/45/model_params


## Load Model

In [4]:
# Create the device mesh and load the model.
# axis_types is passed explicitly as Auto, which is the type the mesh already
# reports (`axis_types=(Auto, Auto)`). Behaviour is therefore unchanged, and the
# make_mesh warning about the default changing in a future JAX release goes away.
# NOTE(review): the saved output only captured the warning's source line;
# confirm it is the axis_types deprecation.
mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[1]))
print(f"Mesh: {mesh}")

print(f"\nLoading {MODEL_ID}...")
# load_model returns the model, its config, and a model path; the path is
# passed to load_tokenizer in the next cell.
model, model_config, model_path = load_model(
    MODEL_ID,
    mesh,
    use_lora=USE_LORA,
    lora_rank=LORA_RANK,
    lora_alpha=LORA_ALPHA,
)
print("Model loaded.")

Mesh: Mesh('fsdp': 1, 'tp': 1, axis_types=(Auto, Auto))

Loading google/gemma-3-1b-it...


  mesh = jax.make_mesh(*MESH)


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Model loaded.


## Load Tokenizer

In [5]:
# The tokenizer comes from the same model_path that load_model() returned.
tokenizer = load_tokenizer(model_path)
print("Tokenizer loaded.")

Tokenizer loaded.


## Restore Checkpoint

In [6]:
# Load the saved parameters at ckpt_path into `model`. The return value of
# restore_checkpoint is unused, so presumably it updates `model` in place
# (confirm). When running the plain base model there is nothing to restore.
if not (USE_LORA and ckpt_path):
    print("Using base model weights (no checkpoint to restore)")
else:
    print(f"Restoring checkpoint from: {ckpt_path}")
    restore_checkpoint(model, ckpt_path)
    print("Checkpoint restored!")

Restoring checkpoint from: /home/jimnix/gitrepos/tunix-hack/outputs/checkpoints/grpo/demo/actor/45/model_params




Checkpoint restored!


## Create Sampler

In [7]:
# Bundle model, tokenizer and config into the sampler consumed by generate().
print("Creating sampler...")
sampler = create_sampler(model, tokenizer, model_config)
print("Sampler ready!")

Creating sampler...
Sampler ready!


## Generate Helper

In [8]:
# Convenience wrapper so test cells can just call ask("...").
def ask(
    question: str,
    *,
    max_tokens: int | None = None,
    system_prompt: str = '',
    **kwargs,
) -> str:
    """Run generate() on a single question with this notebook's sampler and mesh.

    Previously max_tokens and system_prompt were hard-coded in the call while
    **kwargs was also forwarded, so ask(q, max_tokens=...) raised TypeError
    ("got multiple values for keyword argument"). Both are now overridable.

    Args:
        question: The question text sent to the model.
        max_tokens: Generation length limit. None (the default) means use the
            notebook-level MAX_TOKENS, looked up at call time.
        system_prompt: Forwarded to generate(). Defaults to '' as before.
            NOTE(review): '' overrides any default system prompt generate()
            may have; if the checkpoint was trained with a specific system
            prompt, pass it here. TODO confirm against the training setup.
        **kwargs: Any other generate() keyword arguments, forwarded unchanged.

    Returns:
        The value returned by generate().
    """
    if max_tokens is None:
        max_tokens = MAX_TOKENS
    return generate(
        sampler,
        mesh,
        question,
        max_tokens=max_tokens,
        system_prompt=system_prompt,
        **kwargs,
    )

print("Ready to generate! Use ask('your question') or generate(sampler, mesh, question)")

Ready to generate! Use ask('your question') or generate(sampler, mesh, question)


## Quick Test

In [9]:
# Smoke test: one trivial question to confirm generation works end to end.
smoke_answer = ask("What is 2 + 2?")
print(smoke_answer)


2 + 2 = 4




## Test Examples

In [10]:
# A few short word problems for eyeballing the model's answers.
test_questions = [
    "What is 15 + 27?",
    "If Mary has 5 apples and gives 2 to John, how many apples does Mary have?",
    "A store sells 3 books for $12. How much does one book cost?",
]

major_rule = "=" * 60
minor_rule = "-" * 40

print("Testing model...")
print(major_rule)

for question in test_questions:
    print(f"\nQ: {question}")
    print(minor_rule)
    print(ask(question))
    print(major_rule)

Testing model...

Q: What is 15 + 27?
----------------------------------------

15 + 27 = 42

So the answer is 42.

Q: If Mary has 5 apples and gives 2 to John, how many apples does Mary have?
----------------------------------------
Mary has 3 apples.

Here's the solution:

* Start with: 5 apples
* Subtract: 2 apples
* Result: 5 - 2 = 3 apples



Q: A store sells 3 books for $12. How much does one book cost?
----------------------------------------
Let the cost of one book be $x$.
The store sells 3 books for $12.
So, the total cost of 3 books is $12.
We can write the equation:
3x = 12
To find the cost of one book, we divide both sides of the equation by 3:
x = \frac{12}{3}
x = 4
Therefore, one book costs $4.

Final Answer: The final answer is $\boxed{4}$


## Interactive Loop (optional)

In [11]:
# Opt-in REPL for ad-hoc questions. It is off by default so that
# "Restart & Run All" never blocks waiting for keyboard input. Set
# RUN_INTERACTIVE = True and re-run this cell to use it.
RUN_INTERACTIVE = False


def interactive_loop() -> None:
    """Prompt for questions and print ask()'s answer until the user quits.

    Typing 'quit', 'exit' or 'q' (case-insensitive, surrounding whitespace
    ignored) ends the loop. End of input (EOFError, e.g. no stdin attached) or
    an interrupt (KeyboardInterrupt) also ends it cleanly instead of raising.
    Blank lines are skipped.
    """
    print("Interactive mode - enter math questions (type 'quit' to exit)")
    print("=" * 60)

    while True:
        try:
            question = input("\nYour question: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            return

        if question.lower() in ('quit', 'exit', 'q'):
            print("Goodbye!")
            return

        if not question:
            continue

        print("-" * 40)
        print(ask(question))
        print("=" * 60)


if RUN_INTERACTIVE:
    interactive_loop()