# Overview

This notebook evaluates phi-2 on OpenAI's GSM8K dataset for mathematical reasoning.

GSM8K consists of 8.5K high-quality grade school math problems created by human problem writers.

GSM8K's main difficulty lies in both properly interpreting a question and reasoning through the steps to solve it.

Sampling strategies:

1. test@1: at test time, we judge performance by autoregressively sampling a single _low temperature_ solution and checking whether the final answer is correct.
1. test@100: solutions are sampled at a higher temperature instead.

We use a low temperature (T = 0) to generate test@1 samples and we use a higher temperature (T = 0.7) to generate test@100 samples.

- [Paper](https://arxiv.org/abs/2110.14168)
- [OpenAI blog post](https://openai.com/research/solving-math-word-problems)


In [None]:
# Shallow-clone (depth 1) OpenAI's grade-school-math repository into
# ./grade-school-math; the data loader below reads its JSONL files from there.
!git clone https://github.com/openai/grade-school-math --depth 1 ./grade-school-math

In [None]:
# Quietly install ipywidgets into the kernel's environment.
# NOTE(review): presumably required by the tqdm.notebook progress bar used in
# the test loop - confirm.
%pip install ipywidgets -q

In [None]:
# Install the cloned grade-school-math checkout in editable mode.
%pip install -e ./grade-school-math

# Evaluation


In [None]:
# setup notebook
import os
import sys

# Current working directory of the kernel. os.getcwd() replaces the `!pwd`
# shell escape so the cell is plain Python and does not depend on a `pwd`
# binary being available. Kept as a one-element list to match the shape the
# shell capture used to produce.
local_path = [os.getcwd()]

# The workspace root is two directories above the working directory.
workspace_dir = os.path.abspath(os.path.join(local_path[0], "..", ".."))
print('workspace:', workspace_dir)

# Make the workspace's `models` directory importable; the guard keeps
# sys.path free of duplicates when this cell is re-run.
models_dir = os.path.join(workspace_dir, "models")
if models_dir not in sys.path:
    sys.path.append(models_dir)

# Location of the locally cached phi-2 checkpoint.
model_path = os.path.join(workspace_dir, '.cache', 'models', 'microsoft', 'phi-2')

In [None]:
import mlx.core as mx

# Make the GPU the default device for subsequent MLX operations.
mx.set_default_device(mx.gpu)

In [None]:
import microsoft_phi2_model as phi

# Load the model and tokenizer from `model_path` (built in the setup cell);
# both are used as globals by `evaluate` below.
model, tokenizer = phi.load(model_path)

In [None]:
import mlx.core as mx
import textwrap
import re

# Matches the final-answer marker "#### <number>"; the captured number may
# carry a leading minus sign, decimal points and thousands separators.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Sentinel returned when no final-answer marker is present.
INVALID_ANS = "[invalid]"


def extract_answer(completion):
    """Return the number following the first ``#### `` marker in ``completion``.

    Commas are removed from the captured number (``"1,234"`` -> ``"1234"``).
    When the text contains no marker, ``INVALID_ANS`` is returned instead.
    """
    found = ANS_RE.search(completion)
    if found is None:
        return INVALID_ANS
    return found.group(1).strip().replace(",", "")


def is_correct(model_completion, gt_example):
    """Tell whether a completion's ``####`` answer equals the reference answer.

    Args:
        model_completion: Text produced by the model.
        gt_example: Reference example; its ``"answer"`` entry must contain a
            ``#### <number>`` marker.

    Returns:
        True when the answer extracted from ``model_completion`` equals the
        one extracted from ``gt_example["answer"]``.

    Raises:
        AssertionError: If the reference answer has no ``####`` marker.
    """
    expected = extract_answer(gt_example["answer"])
    assert expected != INVALID_ANS
    predicted = extract_answer(model_completion)
    return predicted == expected


def evaluate(examples, verbose=False, temp=0.0, max_tokens=512):
    """Generate a completion for one GSM8K example and score it.

    Only a single example is evaluated per call: when ``examples`` is a
    list, its first element is used and the rest are ignored.

    Relies on the notebook-level ``model``, ``tokenizer``, ``phi`` and ``mx``
    objects created in earlier cells.

    Args:
        examples: A list of example dicts (the first one is evaluated) or a
            single example dict. An example needs ``"question"`` and
            ``"answer"`` keys.
        verbose: When True, stream the decoded text (prompt included) to
            stdout while generating.
        temp: Sampling temperature forwarded to ``phi.generate_step``.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        ``(False, output, line)`` when the last non-blank line of the decoded
        text contains no digits. Otherwise
        ``(correct, predicted, answer_gold, output, line)`` where ``predicted``
        is the last run of digits on that line, ``answer_gold`` is the
        reference answer, ``output`` is the decoded prompt plus completion and
        ``line`` is its last non-blank line.

    Raises:
        AssertionError: From ``is_correct`` when the reference answer has no
            ``####`` marker.
    """
    REPLACEMENT_CHAR = "\ufffd"

    # Accept a bare example as well as a list, so both call styles used in
    # this notebook (`evaluate([ex])` and `evaluate(ex)`) work.
    example = examples if isinstance(examples, dict) else examples[0]

    # Encode only the example that is actually evaluated.
    prompt_tokens = tokenizer.encode(
        "Question: " + example["question"] + "\nAnswer:"
    )

    # The decoded output starts with the prompt, followed by generated tokens.
    tokens = prompt_tokens[:]
    skip = 0

    # range() comes first in zip() so that no extra token is generated and
    # discarded once the max_tokens budget is exhausted.
    for _, (token, _prob) in zip(
        range(max_tokens),
        phi.generate_step(mx.array(prompt_tokens), model, temp),
    ):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())

        if verbose:
            ss = tokenizer.decode(tokens)
            # Hold back output while the decoded text still contains the
            # Unicode replacement character; print only the new suffix.
            if REPLACEMENT_CHAR not in ss:
                sys.stdout.write(ss[skip:])
                sys.stdout.flush()
                skip = len(ss)

    output = tokenizer.decode(tokens)
    non_blank_lines = [l for l in output.split("\n") if l.strip()]
    # Guard against an output made only of blank lines.
    line = non_blank_lines[-1] if non_blank_lines else ""
    candidates = re.findall(r"\d+", line)
    if not candidates:
        return False, output, line

    # Score against the evaluated example itself, not a leaked global.
    answer_gold = extract_answer(example["answer"])
    # NOTE(review): `correct` only credits completions that contain a
    # "#### <number>" marker; the last-line number is returned alongside so
    # callers can apply a looser comparison if desired.
    correct = is_correct(output, example)
    return (
        correct,
        candidates[-1],
        answer_gold,
        output,
        line,
    )

In [None]:
import random
import os
import json


def read_jsonl(path: str):
    """Parse a JSON Lines file into a list of Python objects.

    Args:
        path: Path of a UTF-8 encoded file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, encoding="utf-8") as fh:
        # Lines keep their trailing newline, so a plain truthiness test never
        # filters anything; strip first so blank lines are really skipped.
        return [json.loads(line) for line in fh if line.strip()]


def get_examples(split):
    """Load one GSM8K split from the cloned grade-school-math checkout.

    Reads ``./grade-school-math/grade_school_math/data/<split>.jsonl`` and
    returns its records after appending a newline to every ``"question"`` and
    the ``<|endoftext|>`` marker to every ``"answer"``.
    """
    data_file = os.path.join(
        "./grade-school-math", "grade_school_math", "data", f"{split}.jsonl"
    )
    records = read_jsonl(data_file)

    for record in records:
        record["question"] = record["question"] + "\n"
        record["answer"] = record["answer"] + "<|endoftext|>"

    return records


# Load the training split and report how many examples it contains.
split = "train"
train_examples = get_examples(split)
print(f"{len(train_examples)} {split} examples")

In [None]:
# Smoke test: evaluate one randomly chosen training example and stream the
# generated text (verbose=True). No random seed is set in the visible cells,
# so the chosen example differs between runs.
ex = random.choice(train_examples)
evaluate([ex], verbose=True)

# Evaluate on the Test Dataset


In [None]:
import random
import numpy as np

# Load the test split; note that this rebinds the global `split`.
split = "test"
test_examples = get_examples(split)
# ds = dataset.GSMDataset(tokenizer=tokenizer, examples=train_examples)

In [None]:
from tqdm.notebook import tqdm

# Evaluate every test example and show the running accuracy in the
# progress-bar description.
correct = []
# The context manager closes the progress bar even if evaluation fails midway.
with tqdm(total=len(test_examples)) as pbar:
    for ex in test_examples:
        # evaluate() takes a list of examples (as in the smoke-test call);
        # only the first tuple element, the correctness flag, is needed here.
        ok, *_ = evaluate([ex])
        correct.append(ok)
        pbar.update(1)
        pbar.set_description(f"accuracy: {np.sum(correct) / len(correct)}")