# GSM8K Evaluation Notebook

This notebook evaluates Gemma JAX on the GSM8K mathematical reasoning dataset.

### Imports and other boilerplate


In [None]:
# Colab only: mount Google Drive, clone the repository into it (if missing)
# and cd into it. Outside Colab the cell is a no-op apart from the final print.
import os
import subprocess

try:
    from google.colab import drive
except ImportError:  # not running in Colab
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    os.chdir('/content/drive/My Drive')

    if not os.path.exists('gemma-jax'):
        # check=True: fail loudly instead of continuing into a missing/empty repo.
        subprocess.run(
            ["git", "clone", "https://github.com/baricev/gemma-jax"],
            check=True,
        )

    os.chdir('gemma-jax')

print(f"Current working directory: {os.getcwd()}")

### Install

Alternatively, install the dependencies directly with `pip install "jax[tpu]" orbax datasets`; the `--quiet` flag works there too.

In [None]:
! pip install -e . --quiet

### Package Imports

Import core gemma-jax functions and datastructures.

In [None]:
import time
import argparse
from functools import partial
from pathlib import Path
import jax
from jax import Array
import jax.numpy as jnp
import re
import datasets
import json
from tqdm import tqdm

# Assuming gemma_jax is installed in editable mode (`pip install -e .`)
from gemma_jax.core.weights import (
    create_gemma3_config,
    create_device_mesh,
    load_model
)
from gemma_jax.core.model import (
    forward_fn,
    setup_scan_fn,
    scan_generate_step
)
from gemma_jax.core.cache import (
  KVCache,
  LayoutType,
  init_cache,
  aliases_map,
  layout_map,
  shard_dims,
  SEQUENCE_HEADS,
  HEADS_SEQUENCE,
)
from gemma_jax.core.rope import load_rope_cache
from gemma_jax.core.sp_tokenizer import SentencePieceTokenizer, process_and_pad_inputs, encode_text, decode_tokens, format_prompt
from gemma_jax.core.inference import greedy_sample

## Configuration Defaults

Set the default configuration values here. These replace the command-line arguments used in the script version.

**Important:** Update `checkpoint_path` and `tokenizer_path` to your actual absolute paths.



In [None]:
# Default locations, relative to the current working directory (the repo root
# after the Colab cell above has run).
root_dir = Path.cwd()
checkpoint_path = root_dir / "4b"               # TODO: Replace with your ABSOLUTE path
tokenizer_path = root_dir / "tokenizer.model"   # TODO: Replace with your ABSOLUTE path

if not tokenizer_path.exists():
    # Fall back one directory up, e.g. when this notebook runs from `examples/`.
    # `__file__` is undefined inside a notebook kernel, so only use it when it
    # exists (i.e. when this code is executed as a script/module).
    _script_file = globals().get("__file__")
    root_dir = (
        Path(_script_file).resolve().parent.parent  # Adjust if the file lives elsewhere
        if _script_file
        else Path.cwd().parent
    )
    tokenizer_path = root_dir / "tokenizer.model"
    checkpoint_path = root_dir / "4b"              # TODO: Replace with your ABSOLUTE path

print(f"Using default tokenizer path: {tokenizer_path}")
print(f"Using default checkpoint path: {checkpoint_path}")

### Model Settings

In [None]:
# Gemma variant, in billions of parameters. Choices: [1, 4, 12, 27]
model_size = 4
# KV cache length (tokens).
cache_length = 2 * 1024
# Padded input sequence length (tokens).
padded_input_size = 2048
# Attention window size for sliding-window attention.
window_size = 1024
# Number of prompts per inference batch.
batch_size = 4
# Number of tokens to generate after prefill.
generate_steps = 1024
# Model parameter dtype name. Choices: ['bfloat16', 'float16', 'float32']
dtype_str = "bfloat16"

# Resolve each supported dtype name to the jax.numpy dtype of the same name.
dtype_map = {name: getattr(jnp, name) for name in ("bfloat16", "float16", "float32")}
model_dtype = dtype_map[dtype_str]

## Setup: Initialization

The cell below initializes the tokenizer, model configuration, device mesh, loads the model parameters, and initializes the KV and RoPE caches.



In [None]:
print("Starting setup...")
start_setup = time.time()

# 0. Validate input paths before any slow work, so a bad tokenizer path is
#    reported immediately rather than after the checkpoint has been loaded.
for _label, _path in (("Checkpoint", checkpoint_path), ("Tokenizer", tokenizer_path)):
    if not _path.exists():
        raise FileNotFoundError(f"{_label} path {_path} does not exist.")
    if not _path.is_absolute():
        raise ValueError(f"{_label} path {_path} must be an absolute path.")

# 1. Model config
config = create_gemma3_config(
    model_size=model_size,
    batch_size=batch_size,
    padded_input_size=padded_input_size,
    cache_length=cache_length,
    window_size=window_size,
    generate_steps=generate_steps,
)
print(f"Model Config created for Gemma-{model_size}b")

# 2. Device mesh
num_devices = len(jax.devices())
# TODO: Configure mesh shape
mesh = create_device_mesh()
print(f"Device mesh created with shape: {mesh.shape}")

# 3. Load model weights
print(f"Loading model from: {checkpoint_path} (dtype: {dtype_str})...")
load_start = time.time()
model = load_model(checkpoint_path, mesh, config, dtype=model_dtype)
print(f"Model loaded in {time.time() - load_start:.2f}s")

# 4. Initialize caches
# rope_cache = load_rope_cache(mesh, config)  # RoPE cache dtype is float32 internally
rope_cache = None  # pass None to compute embeddings at runtime

# Configure memory layout, sharding or cache update functions in "cache.py"
# or use pre-configured settings (SEQUENCE_HEADS, HEADS_SEQUENCE)
cache_layout = SEQUENCE_HEADS

# NOTE(review): the KV cache dtype is fixed to bfloat16 and does not follow
# `dtype_str` / `model_dtype`; confirm this is intended for float16/float32 runs.
cache = init_cache(
    mesh=mesh,
    config=config,
    dtype=jnp.bfloat16,
    kind=cache_layout,
    layout_map=layout_map,
)

# 5. Tokenizer
tokenizer = SentencePieceTokenizer(tokenizer_path)
print(f"Tokenizer loaded from: {tokenizer_path}")

print(f"Setup complete in {time.time() - start_setup:.2f}s")

## Evaluation Constants


In [None]:
# --- Role IDs ---
ROLE_USER = 0
ROLE_MODEL = 1
ROLE_SYSTEM = 2
# Role id -> role name used in the chat template.
ROLE_STR = dict(zip((ROLE_USER, ROLE_MODEL, ROLE_SYSTEM), ("user", "model", "system")))

# --- Tokenizer Constants ---
# Special-token ids. NOTE(review): assumed to match the SentencePiece vocab in use.
PAD_ID: int = 0
EOS_ID: int = 1
BOS_ID: int = 2


## Prompt Templates

GSM8K prompts taken from:

[google-deepmind/gemma/tree/main/colabs/old/gsm8k_eval.ipynb](https://github.com/google-deepmind/gemma/blob/2a162e21be390aa0ec635deb7176fb64fb1868b1/colabs/old/gsm8k_eval.ipynb)

See commit: 2a162e21be390aa0ec635deb7176fb64fb1868b1 for original work.


In [None]:
# The library module ships its own prompt pieces. Import them under distinct
# names so the local definitions below do not silently shadow them.
from gemma_jax.core.gsm8k_eval import (
    PREAMBLE as LIB_PREAMBLE,
    FEWSHOT as LIB_FEWSHOT,
    EXTRA_3_SHOTS,
    PROMPT as EIGHT_SHOT_PROMPT,
)

# --- System prompt (used by the evaluation below) ---
PREAMBLE = """As an expert problem solver solve step by step the following mathematical questions."""

# --- 2-shot exemplars (used by the evaluation below) ---
FEWSHOT = """
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.
"""

## Evaluation Setup

Create a text-processing function and the inference functions. These are passed to the benchmarking function to run the evaluation.


In [None]:
# Text -> padded model inputs, with the tokenizer and the sequence/cache
# lengths fixed to this notebook's configuration.
process_partial = partial(
    process_and_pad_inputs,
    tokenizer=tokenizer,
    cache_len=cache_length,
    max_sequence_length=padded_input_size,
)

### Prefill and Generate Functions

The `prefill_partial` function prefills the model with the input tokens. It takes the padded input IDs, positions, and attention mask as input and returns the logits and the updated cache. Note: the cache object is updated in place by `prefill_partial`.

Setup the scan function with the model, cache, and other parameters using `setup_scan_fn`.



In [None]:
# Prefill: `forward_fn` with the model, the KV cache (bound here) and
# `write_index=0` fixed.
prefill_partial = partial(
    forward_fn,
    model=model,
    config=config,
    cache=cache,
    rope_cache=rope_cache,
    layout=cache_layout,
    write_index=0,
)

# Auto-regressive generation: `scan_generate_step` with the model and layout fixed.
# NOTE(review): unlike prefill, no `cache` is bound here; presumably the caller
# threads it through. Confirm against `chat`.
generate_partial = partial(
    scan_generate_step,
    model=model,
    config=config,
    rope_cache=rope_cache,
    layout=cache_layout,
)


## Helpers

In [None]:
def find_numbers(x: str) -> list[str]:
    """Return every number-like substring of ``x`` (sign, commas, decimals kept), in order."""
    number_pattern = re.compile(r"-?[\d,]*\.?\d+")
    matches = number_pattern.findall(x)
    return matches

def find_number(x: str, answer_delimiter: str = "The answer is") -> str:
    """Return the most likely answer number in ``x``.

    Prefers the first number after the LAST ``answer_delimiter``; otherwise
    falls back to the last number anywhere in ``x``; "" if there is none.
    """
    _, delimiter_found, after = x.rpartition(answer_delimiter)
    if delimiter_found:
        after_delimiter = find_numbers(after)
        if after_delimiter:
            return after_delimiter[0]

    everywhere = find_numbers(x)
    return everywhere[-1] if everywhere else ""

def normalize_answer(ans: str) -> str:
    """Trim whitespace, drop thousands separators and strip a trailing non-numeric suffix."""
    without_commas = ans.strip().replace(",", "")
    return re.sub(r"[^\d\.\-]+$", "", without_commas)

# --- Utilities ----

_NUM_RE = r"-?[\d,]*\.?\d+"


def _strip(s: str) -> str:
    return s.strip(" \n\t")


def _visible_segment(raw: str) -> str:
    """
    Return the model-visible part of a raw reply emitted by the chat loop.

    Recognised layouts, tried in order:
         "<start_of_turn>model … <end_of_turn>"   -> text of the first tagged model turn
         "user\n … model\n …"                      -> text after the LAST "\nmodel"
    Otherwise the whole string is returned. All results are trimmed with `_strip`.
    """
    tagged = re.search(r"<start_of_turn>model(.*?)<end_of_turn>", raw, flags=re.S)
    if tagged is not None:
        return _strip(tagged.group(1))

    # Use the LAST occurrence so few-shot exemplars earlier in the text are skipped.
    _, marker, tail = raw.rpartition("\nmodel")
    if marker:
        return _strip(tail)

    # Best effort: nothing matched.
    return _strip(raw)


def _last_answer_block(visible: str) -> str:
    """
    Given the visible text, grab the **final** answer paragraph:

        … Q: <few-shot #N>
           A: <few-shot #N answer>
        Q: <real question>
        A: <real answer>

    We split on *lines* that start with 'Q:' or 'A:' and keep the
    trailing answer.
    """
    # Normalise line endings & get rid of leading markdown bullets
    lines = [l.lstrip("* ").rstrip() for l in visible.replace("\r", "").split("\n")]

    # Walk backwards to find the last line that *begins* with 'A:'
    for i in range(len(lines) - 1, -1, -1):
        if lines[i].startswith(("A:", "Answer:", "**Answer:**")):
            # Everything FROM that line onwards is the answer block
            return "\n".join(lines[i:])

    # Fallback -no marker, keep entire visible
    return visible


def _extract_number(ans_block: str) -> str:
    """
    Pull the numeric answer out of the final answer block, with commas removed.

    Priority:
        1.  "The answer is <num>"
        2.  "Answer: <blah> <num>". "answer" must be a whole word, so text such
            as "answered 10 questions" is not mistaken for the answer marker.
        3.  Last number in the block
    Returns "" when the block contains no number.

    TODO:
    handle model outputs with "commas" or other variations, for example, 'model_answer': 99,076.92 ,   'predicted': '99076.92', is a valid answer, but missed by the parser.
    """
    # 1) "The answer is <num>"
    m = re.search(r"The answer is\s*(%s)" % _NUM_RE, ans_block, flags=re.I)
    if m:
        return _strip(m.group(1)).replace(",", "")

    # 2) "Answer:" / "**Answer:**". \b keeps words like "answered" from matching.
    m = re.search(r"\bAnswer\b[:\s\*]*.*?(%s)" % _NUM_RE, ans_block, flags=re.I | re.S)
    if m:
        return _strip(m.group(1)).replace(",", "")

    # 3) Anything else: last number in the block
    nums = re.findall(_NUM_RE, ans_block)
    return nums[-1].replace(",", "") if nums else ""


def _parse_reply(raw_reply: str) -> tuple[str, str]:
    """
    Parse one raw model reply.

    Returns:
        (visible_str, numeric_prediction): the model-visible text, and the
        number extracted from its final answer block ("" if none was found).
    """
    visible = _visible_segment(raw_reply)
    prediction = _extract_number(_last_answer_block(visible))
    return visible, prediction



## Evaluation

The key functions, imported from `gemma_jax.core.conversation_state`, are `create_empty_state_batched` (used to initialize the stateful conversation object) and `chat`, which orchestrates the inference loop.


In [None]:
from gemma_jax.core.conversation_state import create_empty_state_batched, chat

QUESTION_TEMPLATE = "\nQ: {question}\nA:"

# Absolute tolerance used when comparing parsed answers numerically.
_ANSWER_ATOL = 1e-6


def _gsm8k_ground_truth(answer: str) -> str:
    """Return the reference number from a GSM8K ``answer`` string, commas removed.

    The final answer is expected after a ``####`` marker (GSM8K's answer
    format). Reading it from there avoids picking up numbers from the worked
    solution, e.g. after the word "answered". Falls back to ``_extract_number``
    when no marker or number is found.
    """
    _, marker, tail = answer.rpartition("####")
    if marker:
        numbers = re.findall(_NUM_RE, tail)
        if numbers:
            return numbers[0].replace(",", "")
    return _extract_number(answer)


def _answers_match(pred: str, truth: str) -> bool:
    """True when ``pred`` equals ``truth`` numerically (so "18.0" matches "18").

    An empty prediction or truth never matches. Values that are not valid
    floats fall back to exact string comparison.
    """
    if not pred or not truth:
        return False
    try:
        return abs(float(pred) - float(truth)) <= _ANSWER_ATOL
    except ValueError:
        return pred == truth


def _save_results(results: list[dict], save_path: str) -> None:
    """Write the per-example records to ``save_path`` as indented JSON."""
    with open(save_path, "w") as fp:
        json.dump(results, fp, indent=2)


def evaluate_gsm8k_batched(
    batch_size: int = 4,
    save_path: str = "report.json",
    generate_steps: int = 256,
    verbose: bool = True,
    save_every: int = 50,
    print_every: int = 20,
    max_num_examples: int | None = None,
    max_turns: int = 10,
    tokenizer=tokenizer,
    config=config,
    benchmark=FEWSHOT,
):
    """Run few-shot evaluation on the GSM8K test split.

    Args:
        batch_size: prompts sent to `chat` per call. A short final batch is
            padded with copies of its last prompt; the padded replies are discarded.
        save_path: JSON file the per-example records are written to.
        generate_steps: passed to `chat` as `generate_steps`.
        verbose: print periodic progress and checkpoint messages.
        save_every: checkpoint the results every N examples (<= 0 disables).
        print_every: print a progress example every N examples (<= 0 disables).
        max_num_examples: evaluate only the first N test examples (None = all).
        max_turns: passed to `create_empty_state_batched`.
        tokenizer, config: bound into `chat`. `config.batch_size` also sizes the
            chat state.
        benchmark: few-shot exemplar text placed between PREAMBLE and the question.

    Returns:
        (accuracy, results): the fraction of scored examples judged correct, and
        one record dict per scored example.

    Raises:
        ValueError: if batch_size < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # 0. Load data and assemble prompts
    gsm8k = datasets.load_dataset("gsm8k", "main")["test"]
    if max_num_examples is not None:
        # Clamp so asking for more examples than exist does not raise.
        gsm8k = gsm8k.select(range(min(max_num_examples, len(gsm8k))))

    prompts = [
        format_prompt(
            f"{PREAMBLE}\n{benchmark}{QUESTION_TEMPLATE.format(question=ex['question'])}"
        )
        for ex in gsm8k
    ]
    ground_truths = [ex["answer"] for ex in gsm8k]
    questions = [ex["question"] for ex in gsm8k]
    num_examples = len(prompts)
    print(f"Number of examples: {num_examples}")

    # 1. Initialise the chat state. The same empty state is passed for every
    #    batch; the updated state returned by `chat` is discarded.
    # NOTE(review): the state is sized by `config.batch_size`, not `batch_size`.
    conv_state = create_empty_state_batched(
        batch_size=config.batch_size,
        cache_length=config.cache_length,
        max_turns=max_turns,
        with_trace=False,
        pad_id=PAD_ID,
    )

    # 2. Bind the fixed inference components
    chat_partial = partial(
        chat,
        prefill_partial,
        process_partial,
        setup_scan_fn,
        generate_partial,
        tokenizer,
        config,
    )

    # Book-keeping accumulators
    results: list[dict] = []
    correct = 0
    truncations = 0

    print(f"Evaluating {num_examples} examples in batches of {batch_size} …")

    # 3. Main loop: batched inference + parsing
    for i in tqdm(range(0, num_examples, batch_size), desc="GSM8K batched"):
        batch_prompts = prompts[i : i + batch_size]
        cur_bs = len(batch_prompts)
        if cur_bs < batch_size:
            # Short last batch: pad to a full batch instead of silently dropping
            # these examples. The padded replies are ignored below.
            batch_prompts = batch_prompts + [batch_prompts[-1]] * (batch_size - cur_bs)

        _, batch_replies = chat_partial(
            conv_state,
            batch_prompts,
            generate_steps=generate_steps,
            role=ROLE_USER,
        )

        # Only the first `cur_bs` replies correspond to real examples.
        batch_vis_preds = [_parse_reply(reply) for reply in batch_replies[:cur_bs]]

        # 4. Per-example metrics and logging
        for j, ((visible, pred), question, gt_raw) in enumerate(
            zip(
                batch_vis_preds,
                questions[i : i + cur_bs],
                ground_truths[i : i + cur_bs],
            )
        ):
            idx = i + j
            truth = _gsm8k_ground_truth(gt_raw)
            is_correct = _answers_match(pred, truth)

            # Simple truncation heuristic: no number could be extracted
            is_truncated = pred == ""

            if is_correct:
                correct += 1
            if is_truncated:
                truncations += 1

            results.append(
                {
                    "idx": idx,
                    "question": question,
                    "ground_truth": truth,
                    "model_answer": visible,
                    "predicted": pred,
                    "is_correct": is_correct,
                    "truncated": is_truncated,
                    "raw": batch_replies[j],
                }
            )

            # TODO: print all incorrect answers by adding: or not is_correct
            if verbose and print_every > 0 and (idx + 1) % print_every == 0:
                running_acc = correct / (idx + 1)
                running_trunc = truncations / (idx + 1)
                print(
                    f"\n--- Example {idx+1}/{num_examples}"
                    f"\nQ: {question}"
                    f"\nGT: {truth}\nPred: {pred} | Correct: {is_correct}"
                    f"\nRunning Acc: {running_acc:.2%} | Trunc rate: {running_trunc:.2%}"
                )

            # Periodic checkpoint
            if save_every > 0 and (idx + 1) % save_every == 0:
                _save_results(results, save_path)
                if verbose:
                    print(f"[checkpoint] saved {idx+1} examples → {save_path}")

    # 5. Final stats + save
    accuracy = correct / len(results) if results else 0
    print(
        f"\nGSM8K accuracy: {accuracy:.2%} "
        f"({correct}/{len(results)}) | truncations: {truncations}"
    )

    _save_results(results, save_path)
    print(f"[final] wrote results to {save_path}")

    return accuracy, results



## Results

### Execute the cell below to run the full evaluation

Note: This may take significant time depending on hardware.

In [None]:
# Run the evaluation. This may take significant time depending on hardware.
# Settings: the first 120 GSM8K test questions, one prompt per batch,
# generate_steps=1024. Progress is printed every 40 examples and results are
# checkpointed every 100.
eval_settings = dict(
    max_num_examples=120,
    batch_size=1,
    generate_steps=1024,
    print_every=40,
    save_every=100,
    verbose=True,
)
accuracy, results = evaluate_gsm8k_batched(**eval_settings)

## Gemma 3 Technical Report


The Gemma 3 technical report  (https://arxiv.org/pdf/2503.19786v1) provides the following results for the GSM8K benchmark:


- For the instruction fine-tuned 4B model (the one used in this notebook), they report a score of 89.2% (8-shot, CoT).
- The score for the pre-trained model was 38.4%.

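Our results:
- 61.67% on the first 120 test examples (batch size 1, 1024 generation steps), using the 2-shot `FEWSHOT` prompt rather than the paper's 8-shot CoT prompt. Note that the two `FEWSHOT` exemplars do contain short step-by-step rationales.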

See Table 10, and Table 18 in the paper for full results.