In [1]:
from datasets import load_dataset
from dataset_generator import (
    generate_completion_dataset,
    generate_corrective_dataset,
    generate_kto_dataset,
    generate_copy_dataset,
)
from transformers import (
    pipeline,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import re

# --- Model setup -----------------------------------------------------------
# 4-bit NF4 quantization with double quantization; matmuls are computed in
# bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="bfloat16",
)
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path, quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Reuse EOS as the padding token so that batched inputs can be padded.
tokenizer.pad_token_id = tokenizer.eos_token_id
# Decoder-only models continue from the last input token, so batched prompts
# must be padded on the LEFT; right-padding would place pad tokens between the
# prompt and the generated continuation.
tokenizer.padding_side = "left"
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# --- Generation / evaluation configuration ---------------------------------
# NOTE(review): sampling is enabled and no seed is set in this cell, so
# repeated runs may produce different completions.
generate_kargs = {"max_new_tokens": 3000, "do_sample": True, "batch_size": 2}

# Template filled with each problem statement; `{{}}` renders as a literal `{}`
# after str.format, i.e. the model is asked to answer inside \boxed{}.
prompt = "{problem} \nPlease reason step by step, and put your final answer within \\boxed{{}}"

# GSM8K "main" configuration; only the test split is used below.
dataset = load_dataset("openai/gsm8k", "main")
testing_set = dataset["test"]

def get_answer_from_output(text):
    """Extract the numeric answer a model placed inside ``\\boxed{...}``.

    The prompt asks the model to put its *final* answer in ``\\boxed{}``, so
    when several boxed numbers appear the last one is used.

    Accepted number forms inside the box: optional sign, digits with optional
    thousands separators (``1,000``), optional decimal part (``3.5``), and an
    optional leading ``$`` / ``\\$`` currency marker. Surrounding whitespace
    inside the braces is ignored.

    Args:
        text: The generated completion. Non-string values are tolerated.

    Returns:
        The parsed answer as a ``float``, or ``None`` when ``text`` is not a
        string or contains no boxed number.
    """
    if not isinstance(text, str):
        return None

    boxed_number = (
        r"\\boxed\{\s*"                              # opening \boxed{
        r"(?:\\?\$)?\s*"                             # optional $ or \$
        r"([-+]?(?:\d[\d,]*(?:\.\d+)?|\.\d+))"       # the number itself
        r"\s*\}"                                     # closing brace
    )
    matches = re.findall(boxed_number, text)
    if not matches:
        return None

    # Thousands separators are not understood by float(); strip them first.
    return float(matches[-1].replace(",", ""))

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:  44%|####4     | 2.21G/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

In [None]:
# Generate one completion per test problem; each is scored using
# get_answer_from_output to extract the model's answer.
completion_dataset = generate_completion_dataset(
    pipe,
    testing_set,
    prompt,
    get_answer_from_output,
    generate_kwargs=generate_kargs,
    generate_count_per_problem=1,
)

# Calculate the accuracy of the model: the fraction of generated records whose
# "correct" field is truthy. Counted in a single pass over the records.
correct = 0
total = 0
for problem in completion_dataset:
    total += 1
    if problem["correct"]:
        correct += 1

# Guard against an empty result set, which would otherwise raise
# ZeroDivisionError and hide the real problem (nothing was generated).
if total:
    print(f"Accuracy: {correct/total}")
else:
    print("Accuracy: n/a (no completions were generated)")