In [1]:
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# NOTE(review): not referenced in the visible cells of this notebook.
text_env_tools = []
prompt_path = "prompts/costar_cot_1shot.txt"  # template prepended to every query
dataset_path = "data/MATH_DPO_COT"  # where the generated DPO dataset is saved
checkpoint_path = "checkpoint.yaml"  # resumable progress for the generation loop

# Read the prompt template with an explicit encoding and close the handle promptly.
with open(prompt_path, "r", encoding="utf-8") as prompt_file:
    prompt = prompt_file.read()

In [2]:
import math
import os
import re

import torch
import yaml
from datasets import load_dataset, Dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [3]:
# Tokenizer: pad on the left so batched prompts end right where generation starts,
# and reuse the EOS token for padding (presumably because the checkpoint defines no
# dedicated pad token — verify).
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Model: 4-bit NF4 weights with double quantization; compute and storage dtypes bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    quantization_config=bnb_config,
)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
def generate(input_texts: list[str], max_new_tokens: int = 512) -> list[str]:
    """Generate one continuation per prompt in a single padded batch.

    Args:
        input_texts: Raw prompt strings, tokenized as-is (no chat template).
        max_new_tokens: Upper bound on generated tokens per prompt (default 512).

    Returns:
        The decoded newly generated text for each prompt, in input order,
        with special tokens removed. An empty input list returns [].
    """
    if not input_texts:
        return []
    inputs = tokenizer(
        input_texts, return_tensors="pt", padding=True, truncation=False
    ).to(model.device)
    outputs = model.generate(
        **inputs, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id
    )
    # The tokenizer is loaded with padding_side="left", so every row's (padded)
    # prompt fills exactly the first input_ids.shape[1] positions of the output.
    # Cutting token ids is exact, whereas cutting the decoded text by len(prompt)
    # breaks whenever decoding does not reproduce the prompt string verbatim.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

In [5]:
# Load the raw MATH problems from the JSON files under data/MATH.
# NOTE(review): this path is hard-coded here rather than set in the config cell.
dataset = load_dataset("json", data_dir="data/MATH")


def is_real_number(text):
    """Return True if `text` parses as a finite float (e.g. "3", "-2.5", "1e3").

    None and non-numeric strings such as LaTeX ("\\frac{1}{2}") return False.
    "nan" and "inf" also return False: float() accepts them, but a numeric
    absolute-difference comparison against them can never succeed.
    """
    try:
        value = float(text)
    except (TypeError, ValueError, OverflowError):
        return False
    return math.isfinite(value)


def extract_answer(text):
    """Return the contents of the last ``\\boxed{...}`` in `text`, honoring nested braces.

    Returns None when `text` is not a string, has no ``\\boxed{``, the box is
    never closed, or the box is empty.
    """
    if not isinstance(text, str):
        return None
    marker = "\\boxed{"
    # Solutions may box intermediate results; the final answer is the last box.
    start = text.rfind(marker)
    if start == -1:
        return None
    content_start = start + len(marker)
    depth = 1
    for pos in range(content_start, len(text)):
        char = text[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                # Balanced close brace found; an empty box counts as no answer.
                return text[content_start:pos] or None
    return None


# Attach an extracted "answer" to every problem, keep problems whose answer passes
# is_real_number and whose statement is shorter than 500 characters, then rename
# the problem text to "query" (the column the generation loop reads).
dataset_with_answer = (
    dataset.map(
        lambda row: {
            "problem": row["problem"],
            "answer": extract_answer(row["solution"]),
        }
    )
    .filter(lambda row: is_real_number(row["answer"]))
    .filter(lambda row: len(row["problem"]) < 500)
    .rename_column("problem", "query")
)

Resolving data files:   0%|          | 0/7500 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/5000 [00:00<?, ?it/s]

In [6]:
# Show the per-split columns and row counts left after answer extraction and filtering.
dataset_with_answer

DatasetDict({
    train: Dataset({
        features: ['query', 'level', 'type', 'solution', 'answer'],
        num_rows: 4567
    })
    test: Dataset({
        features: ['query', 'level', 'type', 'solution', 'answer'],
        num_rows: 3068
    })
})

In [7]:
import numpy as np


def _exact_match_reward(responses, answers, tolerance=0.1):
    """Score each response 1.0 if its parsed result matches the reference answer, else 0.0.

    Args:
        responses: Generated texts; the predicted number is parsed with `_get_answer`.
        answers: Reference answers aligned with `responses`; each must be
            float-convertible when its response yields a prediction.
        tolerance: Maximum absolute difference (exclusive) still counted as a
            match. Defaults to 0.1, the previously hard-coded value.

    Returns:
        A list of 1.0/0.0 rewards; as with zip, extra items in the longer input
        are ignored.
    """
    rewards = []
    for response, answer in zip(responses, answers):
        predicted = _get_answer(response)
        # Short-circuit: an unparseable response scores 0 without touching `answer`.
        is_match = predicted is not None and abs(predicted - float(answer)) < tolerance
        rewards.append(1.0 if is_match else 0.0)
    return rewards


def _get_answer(response):
    try:
        pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>"
        match_pattern = re.findall(pattern, response)
        if match_pattern:
            return float(match_pattern[0])
        else:
            return None
    except Exception:
        return None

In [8]:
# Build DPO triples (prompt, chosen = a correct response, rejected = a wrong one)
# by sampling each training problem up to `try_count` times. Progress is
# checkpointed after every batch so an interrupted run resumes where it stopped.
batch_size = 12  # prompts per `generate` call
try_count = 10  # max sampling rounds per batch


def _load_checkpoint(path):
    """Return the checkpoint dict saved at `path`, or {} when there is none."""
    if not os.path.exists(path):
        return {}
    with open(path, "r") as f:
        # yaml.safe_load returns None for an empty file; treat that as "no checkpoint".
        return yaml.safe_load(f) or {}


def _save_checkpoint(path, state):
    """Write `state` to `path` as YAML via a temp file so an interrupted dump cannot corrupt it."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as f:
        yaml.safe_dump(state, f)
    # os.replace moves the finished file into place in one step (atomic on POSIX).
    os.replace(tmp_path, path)


def _collect_pairs(batch_inputs, answers, max_tries):
    """Sample until each row has both a correct and a wrong response, or tries run out.

    Each round regenerates only the rows still missing one of the two.

    Returns:
        (correct_answer, wrong_answer): lists aligned with `batch_inputs`; an
        entry is None when no response of that kind was produced for the row.
    """
    n_rows = len(batch_inputs)
    correct_answer = [None] * n_rows
    wrong_answer = [None] * n_rows
    for _ in range(max_tries):
        pending = [
            k
            for k in range(n_rows)
            if correct_answer[k] is None or wrong_answer[k] is None
        ]
        if not pending:
            break
        responses = generate([batch_inputs[k] for k in pending])
        rewards = _exact_match_reward(responses, [answers[k] for k in pending])
        for k, response, reward in zip(pending, responses, rewards):
            if reward > 0:
                correct_answer[k] = response
            else:
                wrong_answer[k] = response
    return correct_answer, wrong_answer


checkpoint = _load_checkpoint(checkpoint_path)
start_index = checkpoint.get("start_index", 0)
prompts = checkpoint.get("prompts", [])
chosen = checkpoint.get("chosen", [])
rejected = checkpoint.get("rejected", [])

train_split = dataset_with_answer["train"]
for i in tqdm(range(start_index, len(train_split), batch_size)):
    # Slicing a Dataset yields a dict of column lists, not a list of rows.
    batch_rows = train_split[i : i + batch_size]
    batch_inputs = [prompt + query for query in batch_rows["query"]]
    correct_answer, wrong_answer = _collect_pairs(
        batch_inputs, batch_rows["answer"], try_count
    )

    # Walk row positions explicitly; enumerating batch_rows would iterate its column names.
    for batch_input, good, bad in zip(batch_inputs, correct_answer, wrong_answer):
        if good is not None and bad is not None:
            prompts.append(batch_input)
            chosen.append(good)
            rejected.append(bad)

    checkpoint = {
        "start_index": i + batch_size,
        "prompts": prompts,
        "chosen": chosen,
        "rejected": rejected,
    }
    _save_checkpoint(checkpoint_path, checkpoint)

new_dataset = Dataset.from_dict(
    {"prompt": prompts, "chosen": chosen, "rejected": rejected}
)
os.makedirs(dataset_path, exist_ok=True)

new_dataset.save_to_disk(dataset_path)

# The dataset is saved, so the resume checkpoint is no longer needed.
if os.path.exists(checkpoint_path):
    os.remove(checkpoint_path)

  0%|          | 0/201 [00:00<?, ?it/s]