# Testing on GSM8k Math reasoning

In this notebook, we show how Ally can learn and improve its math reasoning skills on the [GSM8k dataset](https://github.com/openai/grade-school-math).

We use a handful of annotated examples from the training set and measure answer accuracy on the test set, which contains 1319 examples. The cells below evaluate on small slices of that test set (the first 5 and the first 10 examples).

In [1]:
import sys
sys.path.append('../')

import pandas as pd
from datasets import load_dataset
gsm8k = load_dataset("gsm8k", "main")

df_train = pd.DataFrame({'question': gsm8k['train']['question'], 'answer': gsm8k['train']['answer']})
df_test = pd.DataFrame({'question': gsm8k['test']['question'], 'answer': gsm8k['test']['answer']})

df_train.head()

|   | question | answer |
|---|----------|--------|
| 0 | Natalia sold clips to 48 of her friends in Apr... | Natalia sold 48/2 = <<48/2=24>>24 clips in May... |
| 1 | Weng earns $12 an hour for babysitting. Yester... | Weng earns 12/60 = $<<12/60=0.2>>0.2 per minut... |
| 2 | Betty is saving money for a new wallet which c... | In the beginning, Betty has only 100 / 2 = $<<... |
| 3 | Julie is reading a 120-page book. Yesterday, s... | Maila read 12 x 2 = <<12*2=24>>24 pages today.... |
| 4 | James writes a 3-page letter to 2 different fr... | He writes each friend 3*2=<<3*2=6>>6 pages a w... |


The following code scores an answer by comparing numbers: a prediction counts as correct if any number it contains equals the last number in the ground-truth answer.

In [2]:
import re

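def extract_and_convert_numbers(text):
  # Match integers and decimals with optional thousands separators, e.g. "1,234.5".
  # A raw string keeps "\d" from being treated as a string escape.
  pattern = r"\d{1,3}(?:,\d{3})*\.?\d*"
  numbers = re.findall(pattern, text)
  # Drop the separators and any fractional part, then convert to int.
  return [int(num.replace(',', '').split('.')[0]) for num in numbers]

def evaluate_answers(ground_truth, prediction):
  pred = extract_and_convert_numbers(prediction)
  gt = extract_and_convert_numbers(ground_truth)
  # Correct if any number in the prediction equals the last number in the ground truth.
  return any(p == gt[-1] for p in pred)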
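
To make the matching rule concrete, here is a small self-contained sketch that repeats the two functions and runs them on a hand-written ground truth in the GSM8k style (GSM8k answers end with a `#### <number>` line, which is why taking the last number works). The example strings are illustrative and are not copied from the dataset.

```python
import re

def extract_and_convert_numbers(text):
    # Same regex as the evaluation cell, written as a raw string.
    pattern = r"\d{1,3}(?:,\d{3})*\.?\d*"
    return [int(n.replace(",", "").split(".")[0]) for n in re.findall(pattern, text)]

def evaluate_answers(ground_truth, prediction):
    pred = extract_and_convert_numbers(prediction)
    gt = extract_and_convert_numbers(ground_truth)
    return any(p == gt[-1] for p in pred)

# Hand-written ground truth in the GSM8k style (illustrative only).
ground_truth = "She sold 48/2 = <<48/2=24>>24 clips in May.\n#### 72"

print(evaluate_answers(ground_truth, "72"))                  # True
print(evaluate_answers(ground_truth, "48 + 24 = 72 clips"))  # True
print(evaluate_answers(ground_truth, "1,072"))               # False
print(evaluate_answers(ground_truth, "Either 24 or 72"))     # True
print(evaluate_answers(ground_truth, "72.9"))                # True
```

The last two calls show that the check is lenient: a prediction that lists several candidate numbers, or a decimal that truncates to the right integer, still counts as correct.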

Now we can create and run a baseline agent. We start with a naive baseline that is given only the question and is expected to produce the answer, following the prompt template below without any additional instructions:
```
Q: {question}
A: {answer}
```
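
As a rough illustration of what this template produces, the sketch below fills it with plain `str.format`. This is not how Ally renders prompts internally; `render_prompt` and the constant names are hypothetical helpers introduced only for this example.

```python
INPUT_TEMPLATE = "Q: {question}"
OUTPUT_TEMPLATE = "A: {answer}"

def render_prompt(question, answer=""):
    """Fill the two-line Q/A template; leave the answer blank for the model to complete."""
    return "\n".join([
        INPUT_TEMPLATE.format(question=question),
        OUTPUT_TEMPLATE.format(answer=answer),
    ])

# With no answer supplied, the prompt ends right after "A: ".
print(render_prompt("What is 12 / 60?"))
```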

In [3]:
from ally.skills.skillset import LinearSkillSet
from ally.skills.base import TransformSkill

skills = LinearSkillSet(skills=[
  TransformSkill(
    name='math_solver',
    # we start with no instructions, then show how the agent can learn them
    instructions='',
    # instructions=prompt,
    input_template='Q: {question}',
    # here is the baseline established in the Kojima et al., 2022 paper
    # output_template='A: Let’s think step by step. {rationale}\nFinal numerical answer:{answer}',
    output_template=[{
      "name": "answer",
      "description": "The final numerical answer",
    }]
  )
])

In [4]:
from ally.runtimes.openai import OpenAIRuntime
from app.core.settings import settings

openai_runtime = OpenAIRuntime(
  verbose=True,
  api_key=settings.openai_api_key,
  gpt_model_name="gpt-3.5-turbo",
)
openai_teacher_runtime = OpenAIRuntime(
  verbose=True,
  api_key=settings.openai_api_key,
  gpt_model_name="gpt-4",
)

In [5]:
from ally.agents.base import Agent
from ally.environments.base import StaticEnvironment


agent = Agent(
  skills=skills,
  runtimes={
    'openai': openai_runtime
  },
  teacher_runtimes={
    'openai': openai_teacher_runtime
  },
  environment=StaticEnvironment(
    df=df_train,
  )
)

In [6]:
# run the baseline agent on the first 5 test examples
result_baseline = agent.run(df_test.drop(columns='answer')[:5])

In [7]:
# evaluate baseline results
accuracy = StaticEnvironment(df=df_test, matching_function=evaluate_answers).get_feedback(skills, result_baseline).get_accuracy()

print(f'Baseline accuracy: {accuracy["answer"]}')

Baseline accuracy: 0.4
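
An accuracy measured on 5 examples is a very noisy estimate. The sketch below uses the standard binomial standard error, sqrt(p(1-p)/n), to show how much the uncertainty would shrink on the 1319-example test set. It assumes independent examples, and reusing 0.4 for the full set is purely hypothetical, since the notebook does not report that number.

```python
import math

def accuracy_standard_error(accuracy, n):
    """Binomial standard error of an accuracy measured on n independent examples."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n)

# 0.4 is the baseline accuracy printed above; n=5 is the slice it was measured on.
se_small = accuracy_standard_error(0.4, 5)
# Hypothetical: the same accuracy measured on all 1319 test examples.
se_full = accuracy_standard_error(0.4, 1319)

print(f"n=5:    0.4 +/- {se_small:.2f}")
print(f"n=1319: 0.4 +/- {se_full:.3f}")
```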


# Improve the baseline

The agent can improve its initial skill by learning new instructions from the provided ground-truth signals:

In [8]:
agent.learn(batch_size=5, learning_iterations=5)
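
The notebook does not show what `agent.learn` does internally. As a generic illustration of the idea of refining instructions from ground truth, here is a toy loop: predict on a batch, collect the mistakes, and ask a teacher to revise the instructions. Every name here (`learn`, `toy_student`, `toy_teacher`) is hypothetical, and the stand-ins replace real model calls so the sketch runs offline.

```python
def learn(instructions, examples, predict, revise, batch_size=5, learning_iterations=5):
    """Generic refine loop: predict on a batch, collect mistakes, ask a teacher to revise."""
    for _ in range(learning_iterations):
        batch = examples[:batch_size]
        results = [(q, predict(instructions, q), truth) for q, truth in batch]
        mistakes = [r for r in results if r[1] != r[2]]
        if not mistakes:
            break  # nothing left to learn from on this batch
        instructions = revise(instructions, mistakes)
    return instructions

# Toy stand-ins (assumptions, not Ally APIs) so the sketch needs no API calls.
def toy_student(instructions, question):
    a, b = question
    # The toy student only adds correctly once the instructions tell it to.
    return a + b if "add" in instructions else 0

def toy_teacher(instructions, mistakes):
    # A real teacher would inspect the mistakes; this one always gives the same hint.
    return (instructions + " Always add the two numbers.").strip()

examples = [((1, 2), 3), ((2, 5), 7), ((4, 4), 8)]
learned = learn("", examples, toy_student, toy_teacher)
print(learned)
```

In the notebook, the student and teacher roles are played by the `gpt-3.5-turbo` and `gpt-4` runtimes configured earlier.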

In [12]:
from ally.utils.logs import print_text


print_text(agent.skills['math_solver'].instruction_template)

In [14]:
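# re-run with the learned instructions on the first 10 test examples
# (the baseline above was measured on the first 5, so the two numbers are not directly comparable)
result_new = agent.run(df_test.drop(columns='answer')[:10])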
accuracy = StaticEnvironment(df=df_test, matching_function=evaluate_answers).get_feedback(skills, result_new).get_accuracy()
print(f'New accuracy: {accuracy["answer"]}')

New accuracy: 0.2857142857142857
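
The printed value is easier to interpret as a fraction. The sketch below recovers the simplest fraction behind each reported accuracy with `fractions.Fraction.limit_denominator`, capping the denominator at the number of examples passed to `agent.run`. The helper name `as_small_fraction` is introduced here for illustration.

```python
from fractions import Fraction

def as_small_fraction(value, max_denominator):
    """Closest fraction to value whose denominator does not exceed max_denominator."""
    return Fraction(value).limit_denominator(max_denominator)

baseline = as_small_fraction(0.4, 5)                   # run on 5 examples
new = as_small_fraction(0.2857142857142857, 10)        # run on 10 examples

print(baseline)  # 2/5
print(new)       # 2/7
```

A denominator of 7 rather than 10 suggests that only 7 of the 10 rows were scored. One plausible reading is that some predictions were missing or could not be matched, but the notebook output does not confirm this.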



