Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

---

# GSM8K evaluation using RecurrentGemma

The [GSM8K dataset](https://arxiv.org/pdf/2110.14168.pdf) presents a good evaluation challenge for small models for several reasons:

1. **Conceptual Simplicity:** While the problems in GSM8K require multi-step reasoning, they primarily involve elementary mathematical concepts and basic arithmetic operations. This makes the dataset accessible to smaller models that may not have the capacity to handle complex mathematical reasoning.

2. **Linguistic Diversity:** GSM8K emphasizes linguistic diversity, ensuring that problems are not simply variations of the same template. This forces models to generalize their understanding of language and mathematical concepts, rather than relying on superficial pattern matching.

3. **Moderate Difficulty:** The problems in GSM8K are challenging enough to test the limits of small models without being completely intractable. This allows for meaningful evaluation and comparison of different models and methods within a reasonable difficulty range.

4. **Natural Language Solutions:** GSM8K provides solutions in natural language, encouraging models to develop verbal analytical skills and produce human-interpretable reasoning steps. This is particularly relevant for smaller models that may struggle with purely symbolic or equation-based solutions.

By focusing on grade-school math concepts and emphasizing linguistic diversity, GSM8K provides a valuable benchmark for evaluating the informal reasoning abilities of smaller language models and identifying areas for improvement.


## Installation

In [None]:
! pip install git+https://github.com/google-deepmind/recurrentgemma.git#egg=recurrentgemma[jax]
! pip install --user kaggle
! pip install datasets  # Required for the task

## Downloading the checkpoint

To use RecurrentGemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.
4. Either log in through the UI or set your Kaggle username and key (`KAGGLE_USERNAME` and `KAGGLE_KEY`) via Colab secrets.

Then run the cell below.

In [None]:
import os
from google.colab import userdata
import kagglehub

try:
  os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
  os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
except userdata.SecretNotFoundError:
  kagglehub.login()

If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
```

Now select and download the checkpoint you want to try.

In [None]:
# @title Python imports

import pathlib
import re
import datasets

import sentencepiece as spm
from recurrentgemma import jax as recurrentgemma

In [None]:
VARIANT = '2b' # @param ['2b', '2b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/recurrentgemma/flax/{VARIANT}')
weights_dir = pathlib.Path(weights_dir)
ckpt_path = weights_dir / VARIANT
vocab_path = weights_dir / 'tokenizer.model'

## Load GSM8K dataset

In [None]:
gsm8k = datasets.load_dataset("gsm8k", "main", cache_dir='/tmp')
gsm8k_train, gsm8k_test = gsm8k['train'], gsm8k['test']
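
Each GSM8K record has a `question` field and an `answer` field, where the reference solution ends with a `####` marker followed by the final number. The sketch below uses a hard-coded record in that format, so it needs no download, and a hypothetical helper `split_reference` (not part of this notebook) to separate the reasoning from the final answer.

```python
# Illustrative record in the GSM8K format, hard-coded so no download is needed.
sample = {
    "question": (
        "Natalia sold clips to 48 of her friends in April, and then she sold "
        "half as many clips in May. How many clips did Natalia sell "
        "altogether in April and May?"
    ),
    "answer": (
        "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
        "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
        "#### 72"
    ),
}


def split_reference(answer: str) -> tuple[str, str]:
  """Splits a reference solution into (reasoning, final_answer).

  Hypothetical helper: assumes the final answer follows the last '####'.
  """
  reasoning, _, final = answer.rpartition("####")
  return reasoning.strip(), final.strip()


reasoning, final_answer = split_reference(sample["answer"])
print(final_answer)
```

The notebook itself does not split on `####`. Its `find_number` helper reaches the same value by taking the last number in the reference text.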

In [None]:
# @title Testing library

def find_numbers(x: str) -> list[str]:
  """Finds all numbers in a string."""
  # Search for a number, possibly negative (hyphen), with thousands separators
  # (commas), and with a decimal point (a period between digits).
  numbers = re.compile(
      r'-?[\d,]*\.?\d+',
      re.MULTILINE | re.DOTALL | re.IGNORECASE,
  ).findall(x)
  return numbers


def find_number(x: str,
                answer_delimiter: str = 'The answer is') -> str:
  """Finds the most relevant number in a string."""
  # If model uses the answer delimiter, then select the first number following
  # that format.
  if answer_delimiter in x:
    answer = x.split(answer_delimiter)[-1]
    numbers = find_numbers(answer)
    if numbers:
      return numbers[0]

  # In general, select the last number in the string.
  numbers = find_numbers(x)
  if numbers:
    return numbers[-1]
  return ''


def maybe_remove_comma(x: str) -> str:
  """Removes thousands separators, e.g. '5,600' -> '5600'."""
  return x.replace(',', '')
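
To see how the extraction behaves on typical inputs, here is a small self-contained check. It reuses the same regular expression but folds the helpers into one hypothetical function, `extract_answer`, so that the snippet runs on its own. The example strings are made up for illustration.

```python
import re

# Same pattern as in the helpers above, repeated so this snippet runs alone.
NUMBER_RE = re.compile(r'-?[\d,]*\.?\d+')


def extract_answer(text: str, delimiter: str = 'The answer is') -> str:
  """Hypothetical condensed version of find_number + maybe_remove_comma."""
  if delimiter in text:
    # Prefer the first number after the last occurrence of the delimiter.
    after = NUMBER_RE.findall(text.split(delimiter)[-1])
    if after:
      return after[0].replace(',', '')
  # Otherwise fall back to the last number anywhere in the text.
  numbers = NUMBER_RE.findall(text)
  return numbers[-1].replace(',', '') if numbers else ''


examples = [
    'She pays 5 * $3 = $15. The answer is 5,600.',
    'Total: 12 apples, then 3.5 more',
    '#### -7',
    'no digits here',
]
results = [extract_answer(e) for e in examples]
print(results)  # ['5600', '3.5', '-7', '']
```

Text without any digits yields an empty string, which the evaluation loop has to handle when converting answers to floats.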

In [None]:
# @title GSM8K Prompts

PREAMBLE = """As an expert problem solver solve step by step the following mathematical questions."""

# The default gsm8k prompt from the CoT paper
# https://arxiv.org/pdf/2201.11903.pdf page 35.

PROMPT = """Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.

Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.

Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.

Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.

Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.

Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.

Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8."""
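
The evaluation loop later in this notebook joins the preamble, the few-shot examples, and each test question into a single string. A minimal sketch of that assembly, with a shortened stand-in for the few-shot block and a hypothetical `build_prompt` helper, looks like this:

```python
# Shortened stand-ins for the PREAMBLE and PROMPT strings defined above.
preamble = (
    "As an expert problem solver solve step by step the following "
    "mathematical questions."
)
few_shot = (
    "Q: If there are 3 cars in the parking lot and 2 more cars arrive, "
    "how many cars are in the parking lot?\n"
    "A: There are 3 cars in the parking lot already. 2 more arrive. "
    "Now there are 3 + 2 = 5 cars. The answer is 5."
)
# Same shape as the TEMPLATE used in the evaluation loop.
template = "\nQ: {question}\nA:"


def build_prompt(question: str) -> str:
  """Joins preamble, few-shot examples, and the new question (hypothetical)."""
  return preamble + "\n\n" + few_shot + "\n" + template.format(question=question)


full_prompt = build_prompt("What is 2 + 2?")
print(full_prompt)
```

Ending the prompt with `A:` invites the model to continue in the same "reasoning, then `The answer is N.`" pattern as the examples, which is what `find_number` looks for.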

## Load and prepare your LLM's checkpoint for use with Flax

Start by loading the weights of your model.

In [None]:
# Load parameters
params = recurrentgemma.load_parameters(ckpt_path, 'single_device')

Then load the tokenizer.

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(str(vocab_path))

Finally, build a sampler from the model configuration deduced from the checkpoint.

In [None]:
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params, preset=recurrentgemma.Preset.RECURRENT_GEMMA_2B_V1)
model = recurrentgemma.Griffin(model_config)

# Create a sampler with the right param shapes for the GSM8K prompt above
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)

## Main Evaluation loop

You should expect a score of 19.33% with the 2B model on TPUv2. The loop runs over the full test split one problem at a time, so the evaluation takes some time to run.

In [None]:
%%time
all_responses = {}
short_responses = {}
idx = 0
correct = 0

TEMPLATE = """
Q: {question}
A:"""

for task_id, problem in enumerate(gsm8k_test):

  if task_id in all_responses: continue

  # Print Task ID
  print(f"task_id {task_id}")

  # Formulate the full prompt
  full_prompt = (PREAMBLE +'\n\n' + PROMPT + '\n' +
                 TEMPLATE.format(question=problem['question']))

  input_batch = [full_prompt]
  response = sampler(input_strings=input_batch, total_generation_steps=1024)
  print(response.text)

  all_responses[task_id] = response.text[0].split('\nQ:')[0]
  short_responses[task_id] = maybe_remove_comma(find_number(all_responses[task_id]))
  print(f"Short answer: {short_responses[task_id]}")
  try:
    correct += float(maybe_remove_comma(
        find_number(problem['answer']))) == float(short_responses[task_id])
  except ValueError:
    # Fall back to string comparison when a value cannot be parsed as a float.
    correct += maybe_remove_comma(
        find_number(problem['answer'])) == maybe_remove_comma(
            find_number(short_responses[task_id]))
  print('-'*40)
  print(f"Ground truth answer {problem['answer']}")
  print(f"Short ground truth answer {find_number(problem['answer'])}")
  print(f"Correct: {correct} out of {idx+1}")
  print("="*40)
  idx += 1
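
The scoring steps in the loop can be isolated into small functions, which makes them easier to test without running the model. The sketch below is an assumption-laden illustration: `first_answer`, `answers_match`, and the sample pairs are hypothetical and not part of the notebook.

```python
def first_answer(generated: str) -> str:
  """Keeps only the text before the model starts a new 'Q:' block."""
  return generated.split('\nQ:')[0]


def answers_match(predicted: str, reference: str) -> bool:
  """Compares numerically when possible, otherwise as plain strings."""
  try:
    return float(predicted) == float(reference)
  except ValueError:
    # For example, an empty prediction cannot be parsed as a float.
    return predicted == reference


# Hypothetical (prediction, reference) pairs with commas already removed.
pairs = [("72", "72"), ("72.0", "72"), ("", "10"), ("8", "9")]
correct = sum(answers_match(p, r) for p, r in pairs)
accuracy = 100.0 * correct / len(pairs)
print(f"Accuracy: {accuracy:.2f}%")  # Accuracy: 50.00%
```

Comparing as floats treats `72` and `72.0` as the same answer, while the string fallback keeps an unparseable prediction from raising an error and counts it as wrong.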
