In [1]:
import os
# Process-wide environment flags, set before the ML libraries are imported
# in the following cell.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    "HF_ALLOW_CODE_EVAL": "1",
    "TOKENIZERS_PARALLELISM": "false",
})

In [2]:
from datasets import load_dataset, Dataset
import re
from tqdm.auto import tqdm
import os
import yaml
from trl import TextEnvironment, AutoModelForCausalLMWithValueHead
from transformers import load_tool, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [3]:
# Notebook configuration.
# `model_name` is the checkpoint handed to `from_pretrained` for both the model
# and the tokenizer; `prompt_path` is the prompt template read just below.
# NOTE(review): `dataset_path` and `checkpoint_path` are not referenced by any
# cell visible in this notebook — confirm whether they are still needed.
model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'
prompt_path = 'prompts/python_cot.txt'
dataset_path = "data/MATH_DPO_COT"
checkpoint_path = "checkpoint_kto_dataset_python_cot.yaml"

# Read the prompt through a context manager so the file handle is closed
# deterministically instead of being left open until garbage collection.
with open(prompt_path, 'r') as prompt_file:
    prompt = prompt_file.read()

In [4]:
import numpy as np
def _exact_match_reward(responses, answers):
    """Score each response 1.0 or 0.0 against its reference answer.

    A response earns 1.0 when `_get_answer` extracts a number from it and that
    number is within 0.1 of `float(answer)`; every other response earns 0.0.

    Args:
        responses: Iterable of generated response texts.
        answers: Iterable of reference answers, paired positionally with
            `responses`; each must be convertible with `float()`.

    Returns:
        A list of floats (0.0 or 1.0), one per (response, answer) pair.
    """
    scores = []
    for text, target in zip(responses, answers):
        guess = _get_answer(text)
        # Short-circuit keeps `float(target)` from being evaluated when no
        # number could be extracted from the response.
        is_hit = guess is not None and np.abs(guess - float(target)) < 0.1
        scores.append(1.0 if is_hit else 0.0)
    return scores

def _get_answer(response):
    try:
        pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>"
        match_pattern = re.findall(pattern, response)
        if match_pattern:
            return float(match_pattern[0])
        else:
            return None
    except Exception:
        return None

In [5]:
# Quantization settings: 4-bit loading with the "nf4" quant type, double
# quantization enabled, and 'bfloat16' as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype='bfloat16',
) 

# Load the value-head model and the tokenizer for `model_name`; the model is
# loaded with the quantization settings above.
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token as the padding token (generation below also pads with
# the EOS token id).
tokenizer.pad_token = tokenizer.eos_token

# Build the text environment. Arguments are passed positionally: model,
# tokenizer, list of tools, reward function, prompt.
# NOTE(review): the tool is loaded from the Hub without a pinned `revision`;
# the loader warning in this cell's output recommends specifying one.
# NOTE(review): `_exact_match_reward` is bound here as defined earlier in the
# notebook; a later cell redefines the same function names.
env = TextEnvironment(
    model,
    tokenizer,
    [load_tool("lvwerra/python-interpreter")],
    _exact_match_reward,
    prompt,
    generation_kwargs={
        # Cap each generation at 512 new tokens and pad with the EOS id.
        "max_new_tokens": 512,
        "pad_token_id": tokenizer.eos_token_id
    }
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You're loading a tool from the Hub from None. Please make sure this is a source that you trust as the code within that tool will be executed on your machine. Always verify the code of the tools that you load. We recommend specifying a `revision` to ensure you're loading the code that you have checked.


[2024-06-15 18:23:31,146] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


In [6]:
import numpy as np
import re
def _exact_match_reward(responses, answers):
    """Return a 0.0/1.0 reward per response based on its extracted number.

    NOTE: this redefines a function of the same name that appears earlier in
    the notebook; the logic is the same.

    Args:
        responses: Iterable of generated response texts.
        answers: Iterable of reference answers aligned with `responses`; each
            is converted with `float()` when a number was extracted.

    Returns:
        A list with 1.0 where the number extracted by `_get_answer` lies
        within 0.1 of the reference answer, and 0.0 everywhere else.
    """
    def _score(response, answer):
        # No extractable number means no reward (and no float conversion).
        value = _get_answer(response)
        if value is None:
            return 0.0
        return 1.0 if np.abs(value - float(answer)) < 0.1 else 0.0

    return [_score(response, answer) for response, answer in zip(responses, answers)]

def _get_answer(response):
    try:
        pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>"
        match_pattern = re.findall(pattern, response)
        if match_pattern:
            return float(match_pattern[0])
        else:
            return None
    except Exception:
        return None

In [7]:
def get_MATH_test_dataset(data_dir="data/MATH", max_query_chars=500):
    """Load the MATH test split, keeping short problems with numeric answers.

    Only the `test` split is processed; previously every split in the loaded
    dataset was mapped and filtered even though only `test` was returned.

    Args:
        data_dir: Directory of JSON files passed to `load_dataset`.
            Defaults to the original hard-coded "data/MATH".
        max_query_chars: Problems whose text length is not strictly below
            this value are dropped. Defaults to the original 500.

    Returns:
        The filtered test split with an added `answer` column (the text of
        the first `\\boxed{...}` in `solution`) and with `problem` renamed to
        `query`.
    """
    def is_real_number(text):
        """Return True when `float(text)` succeeds (None and non-numeric text fail)."""
        try:
            float(text)
            return True
        except (TypeError, ValueError):
            return False

    def extract_answer(text):
        """Return the content of the first `\\boxed{...}` in `text`, or None."""
        if not isinstance(text, str):
            return None
        # Non-greedy capture stops at the first closing brace, so nested-brace
        # answers are truncated; those are discarded by the numeric filter.
        match = re.search(r"\\boxed{(.+?)}", text)
        return match.group(1) if match else None

    # Select the split first so the map/filter passes below touch only the
    # rows that are actually returned.
    test_split = load_dataset("json", data_dir=data_dir)["test"]

    test_split = test_split.map(lambda x: {"problem": x["problem"], "answer": extract_answer(x["solution"])})
    test_split = test_split.filter(lambda x: is_real_number(x["answer"]))
    test_split = test_split.filter(lambda x: len(x['problem']) < max_query_chars)
    return test_split.rename_column("problem", "query")

def get_aimo_test_dataset():
    """Load `data/val.csv` as a dataset shaped for evaluation.

    Returns:
        The CSV contents with the `problem` column renamed to `query` and the
        `id` column removed.
    """
    return (
        Dataset.from_csv("data/val.csv")
        .rename_column("problem", "query")
        .remove_columns(["id"])
    )

    

In [8]:
def run(queries, answers, run_times=1):
    """Run the environment `run_times` times and keep the best result per query.

    This is a best-of-N evaluation: for each query the highest reward seen
    across the repetitions is kept, together with the history from the
    repetition that produced it. Previously the returned histories always came
    from the last repetition, so they could disagree with the returned rewards.

    Args:
        queries: Batch of query texts passed to `env.run`.
        answers: Reference answers, forwarded to `env.run` as `answers=`.
        run_times: Number of repetitions. Values below 1 run nothing.

    Returns:
        A `(rewards, histories)` tuple of two lists aligned with `queries`.
        Both are empty when `run_times` is below 1.
    """
    best_rewards = []
    best_histories = []

    for attempt in range(run_times):
        _, _, _, rewards, histories = env.run(queries, answers=answers)

        if attempt == 0:
            # Copy so later updates never mutate the lists owned by the env.
            best_rewards = list(rewards)
            best_histories = list(histories)
            continue

        for i, (reward, history) in enumerate(zip(rewards, histories)):
            # `>=` lets ties go to the most recent repetition, so when no
            # repetition improves a score the last transcript is returned.
            if reward >= best_rewards[i]:
                best_rewards[i] = reward
                best_histories[i] = history

    return best_rewards, best_histories
        
    

In [9]:
def evaluate(test_dataset, batch_size=8, run_times=7):
    """Evaluate `test_dataset` in batches and print the mean reward.

    Rewards from every batch are accumulated before averaging. Previously the
    accumulator was rebound to each batch's rewards and then extended with
    itself, so the printed mean reflected only the final batch.

    Args:
        test_dataset: Sized object whose slices expose `'query'` and
            `'answer'` columns (e.g. `test_dataset[i:j]['query']`).
        batch_size: Number of rows sent to `run` per call. Defaults to the
            original hard-coded 8.
        run_times: Repetitions per batch, forwarded to `run`. Defaults to the
            original hard-coded 7.

    Returns:
        A list of the histories returned by `run`, concatenated over batches.
    """
    all_rewards = []
    all_histories = []

    for i in tqdm(range(0, len(test_dataset), batch_size)):
        batch_rows = test_dataset[i:i + batch_size]
        # Keep per-batch results in their own names so the accumulators
        # above are never overwritten.
        batch_rewards, batch_histories = run(batch_rows['query'], batch_rows['answer'], run_times=run_times)
        all_rewards.extend(batch_rewards)
        all_histories.extend(batch_histories)

    # An empty dataset has no mean; report NaN without a NumPy warning.
    mean_reward = np.mean(all_rewards) if all_rewards else float("nan")
    print(f"Exact match reward: {mean_reward}")
    return all_histories


In [10]:
# Evaluate on the first 32 rows of the filtered MATH test split.
# NOTE(review): `.select(range(32))` assumes the filtered split has at least
# 32 rows — confirm for the data in use.
dataset = get_MATH_test_dataset().select(range(32))
# `evaluate` prints the mean reward and returns the collected histories.
histories = evaluate(dataset)

Resolving data files:   0%|          | 0/7500 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
