# Tunix Gemma Reasoning - Inference

Load and test the trained reasoning model.


In [None]:
import jax
from tunix import models
from pathlib import Path
import re

# List the devices JAX reports for this runtime, as a quick sanity check
# before any model work is done.
print(f"JAX devices: {jax.devices()}")


In [None]:
# Load model
# NOTE(review): the unpacking below requires `models.load_from_checkpoint` to
# return a (model, params) pair — confirm against the installed Tunix API.
# The path is relative; presumably it is resolved against the notebook's
# current working directory — verify before running from another location.
checkpoint_path = "./checkpoints/final_model"
model, params = models.load_from_checkpoint(checkpoint_path)

print("Model loaded successfully")


In [None]:
def format_prompt(question):
    """Build the instruction prompt for a single question.

    The prompt asks for step-by-step reasoning followed by a final answer
    and spells out the ``<reasoning>`` / ``<answer>`` tag layout the reply
    is expected to use.

    Args:
        question: The question to embed; it is interpolated as-is.

    Returns:
        The full prompt as one newline-separated string (no trailing newline).
    """
    prompt_lines = (
        "Answer the following question. Show your reasoning step by step, then provide your final answer.",
        "",
        f"Question: {question}",
        "",
        "Format your response as:",
        "<reasoning>",
        "Your step-by-step reasoning here",
        "</reasoning>",
        "<answer>",
        "Your final answer here",
        "</answer>",
    )
    return "\n".join(prompt_lines)

def extract_reasoning_and_answer(text):
    """Pull the reasoning and answer sections out of a tagged response.

    For each of the ``<reasoning>`` and ``<answer>`` tags, the first
    complete open/close pair in ``text`` is located (the content may span
    multiple lines) and its content is returned with surrounding
    whitespace stripped.

    Args:
        text: The response text to scan.

    Returns:
        A ``(reasoning, answer)`` tuple of strings; a section with no
        complete tag pair is returned as ``""``.
    """
    sections = []
    for pattern in (r'<reasoning>(.*?)</reasoning>', r'<answer>(.*?)</answer>'):
        # DOTALL lets "." cross newlines so multi-line sections are captured.
        found = re.search(pattern, text, re.DOTALL)
        sections.append("" if found is None else found.group(1).strip())

    reasoning, answer = sections
    return reasoning, answer


In [None]:
# Test questions: two short math problems and two open-ended explanations.
test_questions = [
    "If a train travels 120 miles in 2 hours, what is its average speed?",
    "Solve for x: 3x + 7 = 22",
    "Explain how photosynthesis works.",
    "What is the difference between correlation and causation?"
]

# For each question: build the prompt, generate a response, split it into its
# tagged sections, and print a delimited report.
for question in test_questions:
    prompt = format_prompt(question)
    response = model.generate(params, prompt, max_tokens=512, temperature=0.7)

    reasoning, answer = extract_reasoning_and_answer(response)

    print("=" * 60)
    print(f"Question: {question}")
    print("-" * 60)
    print(f"Reasoning:\n{reasoning}")
    print("-" * 60)
    print(f"Answer: {answer}")
    if not (reasoning and answer):
        # A section comes back empty when its tag pair is missing or
        # incomplete (e.g. the generation stopped at max_tokens before the
        # closing tag). Show the raw text so the output is not silently lost.
        print("-" * 60)
        print(f"Raw response (missing or incomplete tags):\n{response}")
    print("=" * 60)
    print()
