In [None]:
!pip install datasets
!pip install textstat
!pip install bitsandbytes

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (

In [None]:
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from datasets import load_dataset
import re
import numpy as np
from tqdm import tqdm
import torch
import nltk
import textstat
import random
import pickle
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import spacy
from collections import defaultdict

In [None]:
def get_dataset():
    """Load the GSM8K "main" configuration from the Hugging Face Hub.

    Returns:
        tuple: ``(train_split, test_split)``, loaded in that order.
    """
    train_split, test_split = (
        load_dataset("openai/gsm8k", "main", split=split_name)
        for split_name in ("train", "test")
    )
    return train_split, test_split

In [None]:
# Per-model load settings. cost_per_token is the relative weight consumed by
# calculate_cost (its unit is not stated in this notebook).
_MODEL_SPECS = {
    "wizardmath": {
        "repo_id": "WizardLM/WizardMath-7B-V1.1",
        "trust_remote_code": False,
        "cost_per_token": 0.7,
    },
    "phi2": {
        "repo_id": "microsoft/phi-2",
        "trust_remote_code": True,
        "cost_per_token": 0.13,
    },
}


def get_model(model_name):
    """Load a 4-bit (NF4, double-quantized) causal LM and its tokenizer onto GPU 0.

    Args:
        model_name: One of the keys of ``_MODEL_SPECS`` ("wizardmath", "phi2").

    Returns:
        dict with keys ``'model'``, ``'model_name'``, ``'tokenizer'`` and
        ``'cost_per_token'``.

    Raises:
        ValueError: if ``model_name`` is not a known model. Previously this
            silently returned None and failed later with an unrelated error.
    """
    spec = _MODEL_SPECS.get(model_name)
    if spec is None:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(_MODEL_SPECS)}"
        )

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )
    # Only pass trust_remote_code for models that need it, matching the original calls.
    remote_code_kwargs = {"trust_remote_code": True} if spec["trust_remote_code"] else {}

    tokenizer = AutoTokenizer.from_pretrained(spec["repo_id"], **remote_code_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        spec["repo_id"],
        quantization_config=quantization_config,
        device_map={"": 0},
        torch_dtype=torch.float16,
        **remote_code_kwargs
    )
    return {
        'model': model,
        'model_name': model_name,
        'tokenizer': tokenizer,
        'cost_per_token': spec["cost_per_token"]
    }

In [None]:
def extract_answer(answer_text):
    """Return the final numeric answer from a GSM8K reference solution.

    GSM8K solutions end with a ``#### <answer>`` line. Thousands separators are
    removed so the result can be passed straight to ``float()``.

    Args:
        answer_text: The dataset's ``answer`` field.

    Returns:
        The answer as a string (e.g. ``"72"``, ``"1200"``, ``"-5"``, ``"2.5"``),
        or None if the text is empty or has no ``####`` marker followed by a number.
    """
    if not answer_text:
        return None
    # Optional sign, optional comma thousands-grouping, optional decimal part;
    # the old pattern (-?\d+) truncated "1,200" to "1" and "2.5" to "2".
    match = re.search(r'####\s*(-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?)', answer_text)
    if match:
        return match.group(1).replace(',', '')
    return None

In [None]:
# A number as a model may write it: optional sign, optional comma
# thousands-grouping, optional decimal part (e.g. "-7", "1,200", "3.50").
_NUMBER_PATTERN = r'-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?'
_HASH_ANSWER_RE = re.compile(r'####\s*\$?\s*(' + _NUMBER_PATTERN + r')')
_PHRASE_ANSWER_RE = re.compile(r'(?:final answer|the answer is)[^0-9]*?(' + _NUMBER_PATTERN + r')')
_NUMBER_RE = re.compile(_NUMBER_PATTERN)


def _build_prompt(question):
    """Return the instruction prompt sent to the model for one question."""
    return f"""

Follow these instructions:
1. Work through the problem step by step
2. Calculate the numerical answer
3. On the last line, write ONLY: #### <numerical answer>. Do not add any units like "kg" or "m", or any currency symbols like "$".
4. Do not write anything after the final answer

-------------------
EXAMPLE FORMAT:
Step 1: [explanation]
Step 2: [explanation]
Final calculation: [calculation]
#### [numerical answer]
-------------------

NOW SOLVE THE PROBLEM CORRECTLY: {question}
"""


def _normalize_number(text):
    """Drop thousands separators so the string is float()-parsable."""
    return text.replace(',', '')


def _parse_numeric_answer(model_response):
    """Extract the final numeric answer from generated text, or None.

    Strategies, in order:
    1. the last ``#### <number>`` marker (the prompt asks for it on the last line);
    2. a number following "final answer" / "the answer is" (case-insensitive);
    3. the last number on one of the last four lines, scanning upward and
       skipping blank lines and lines mentioning "step" or "explanation".
    """
    hash_matches = _HASH_ANSWER_RE.findall(model_response)
    if hash_matches:
        return _normalize_number(hash_matches[-1])

    phrase_match = _PHRASE_ANSWER_RE.search(model_response.lower())
    if phrase_match:
        return _normalize_number(phrase_match.group(1))

    # lines[-4:] includes line 0 for short replies (the old range() skipped it),
    # and the last number on a line is the result ("5 * 3 = 15" -> "15").
    for line in reversed(model_response.split('\n')[-4:]):
        if not line.strip() or any(word in line.lower() for word in ("step", "explanation")):
            continue
        numbers = _NUMBER_RE.findall(line)
        if numbers:
            return _normalize_number(numbers[-1])
    return None


def process_problem(problem, model_index, models, max_new_tokens=1024,
                    temperature=0.1, do_sample=True):
    """Generate a solution for one GSM8K problem and parse its numeric answer.

    Args:
        problem: Mapping with a ``'question'`` key (a GSM8K row).
        model_index: Index into ``models``.
        models: Sequence of dicts (as built by ``get_model``) providing
            ``'model'`` and ``'tokenizer'``.
        max_new_tokens, temperature, do_sample: Forwarded to ``model.generate``;
            defaults match the original hard-coded values.

    Returns:
        dict with ``'prompt'`` (text sent to the model), ``'response'`` (only the
        newly generated text, never the prompt) and ``'answer'`` (numeric string
        without thousands separators, or None if nothing could be parsed).
    """
    prompt = _build_prompt(problem['question'])
    model_obj = models[model_index]['model']
    tokenizer = models[model_index]['tokenizer']

    inputs = tokenizer(prompt, return_tensors="pt").to(model_obj.device)
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # generate() applies this same fallback itself; passing it explicitly
        # silences the "Setting `pad_token_id` to `eos_token_id`" warning per call.
        pad_token_id = tokenizer.eos_token_id
    outputs = model_obj.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=do_sample,
        pad_token_id=pad_token_id,
    )

    # The returned sequence is prompt + continuation. Decode only the new tokens
    # rather than searching the decoded text for the question: that search fails
    # whenever decoding does not reproduce the prompt exactly, and the prompt's
    # own "final answer ... Step 1" wording was then parsed as the answer.
    prompt_length = inputs.input_ids.shape[-1]
    model_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    return {
        'prompt': prompt,
        'response': model_response,
        'answer': _parse_numeric_answer(model_response),
    }


In [None]:
# Load GSM8K once and key the splits by name.
train_split, test_split = get_dataset()
gsm8k_dataset = {'train': train_split, 'test': test_split}

# Only WizardMath is evaluated below; swap in get_model('phi2') for the smaller model.
models = [get_model('wizardmath')]

tokenizer_config.json:   0%|          | 0.00/948 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/167 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [None]:
def calculate_cost(prediction_data, model):
    """Estimate one prediction's cost as (prompt tokens + response tokens) * rate.

    Args:
        prediction_data: dict with 'prompt' and 'response' text.
        model: dict providing 'tokenizer' and 'cost_per_token'.

    Returns:
        Combined token count of both texts multiplied by model['cost_per_token'].
    """
    rate = model['cost_per_token']
    encode = model['tokenizer'].encode
    # encode() returns the id list; its length equals the tensor's shape[1].
    n_tokens = sum(len(encode(prediction_data[key])) for key in ('prompt', 'response'))
    return n_tokens * rate

In [None]:
import time

# Evaluate WizardMath on a fixed slice of the GSM8K train split.
SEED = 0
# process_problem samples (do_sample=True), so seed torch for more reproducible runs.
torch.manual_seed(SEED)

total_correct = 0
total_cost = 0

start_idx = 2001
num_problems = 500
cur_problem_idx = 0
subset = gsm8k_dataset['train'].select(range(start_idx, start_idx + num_problems))
wizardmath_preds = []

start_time = time.time()
for problem in tqdm(subset, desc="Processing problems"):
    cur_problem_idx += 1
    correct_answer = extract_answer(problem['answer'])

    prediction = process_problem(problem, 0, models)
    predicted_ans = prediction['answer']

    # total_cost used to stay 0: calculate_cost was defined but never called.
    total_cost += calculate_cost(prediction, models[0])

    # Guard both sides so an unparsable gold answer counts as incorrect
    # instead of crashing the whole run with float(None).
    is_correct = (
        predicted_ans is not None
        and correct_answer is not None
        and float(predicted_ans) == float(correct_answer)
    )
    if is_correct:
        total_correct += 1
    wizardmath_preds.append({
        'question': problem['question'],
        'answer': problem['answer'],
        'predicted_answer': predicted_ans,
        'is_correct': is_correct,
    })

end_time = time.time()

Processing problems:   0%|          | 0/500 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Processing problems:   0%|          | 1/500 [00:14<1:57:39, 14.15s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Processing problems:   0%|          | 2/500 [00:29<2:05:13, 15.09s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Processing problems:   1%|          | 3/500 [00:39<1:44:35, 12.63s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Processing problems:   1%|          | 4/500 [00:58<2:04:35, 15.07s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Processing problems:   1%|          | 5/500 [01:08<1:49:55, 13.32s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Processing problems:   1%|          | 6/500 [01:19<1:43:18, 12.55s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Processing problems:   1%|▏         | 7/500 

In [None]:
import pandas as pd

# Persist one row per evaluated problem for offline comparison between runs.
problem_data = pd.DataFrame(wizardmath_preds)
problem_data.to_csv('wizardmath_preds_solo.csv', index=False)

elapsed_seconds = end_time - start_time
accuracy = total_correct / num_problems
print(f"Accuracy: {accuracy * 100:.2f}%")
print(f"Time taken: {elapsed_seconds:.2f} seconds")

Accuracy: 88.80%
Time taken: 6654.15 seconds
