In [1]:
class CONFIG:
    model_name:str = 'meta-llama/Meta-Llama-3-8B-Instruct'
    dataset_path: str = "data"
    init_result_path: str = "results"

In [2]:
from datasets import load_dataset

dataset_gsm8k = load_dataset("openai/gsm8k", "main")

training_data = dataset_gsm8k['train']
test_data = dataset_gsm8k['test']

print(f"Size of training data: {len(training_data)}")
print(f"Size of test data: {len(test_data)}")

Size of training data: 7473
Size of test data: 1319




In [3]:
import json
import numpy as np
import re
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from tqdm.notebook import tqdm

# Define paths
prompt_path = '/home/ubuntu/lam-a10-cal/AIMO/prompts/costar_cot_1shot.txt'
checkpoint_path = 'meta-llama/Meta-Llama-3-8B-Instruct'

# Load prompt
prompt = open(prompt_path, 'r').read()
prompt = 'Submit your answer with the format: "Result = 72 <submit>"'

# Load model and tokenizer with quantization
quantization_config = BitsAndBytesConfig(
    load_in_2bit=True,
    load_out_2bit=True,
    load_quantized=True,
    quantize_inference=True,
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.pad_token = tokenizer.eos_token

def generate(model, input_texts: list[str]):
    inputs = tokenizer(input_texts, return_tensors='pt', padding=True, truncation=False).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
    output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [output[len(input):] for input, output in zip(input_texts, output_texts)]

def _exact_match_reward(responses, answers):
    """Reward if generated response contains correct answer."""
    rewards = []
    for response, answer in zip(responses, answers):
        reward = 0.0
        predicted_number = _get_answer(response)
        if predicted_number is not None:
            if np.abs(predicted_number - float(answer)) < 0.1:
                reward += 1.0
        else:
            reward = 0.0
        rewards.append(reward)
    return rewards

def _get_answer(response):
    try:
        pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>"
        match_pattern = re.findall(pattern, response)
        if match_pattern:
            return float(match_pattern[0])
        else:
            return None
    except Exception:
        return None

def evaluate(test_dataset):
    batch_size = len(test_dataset)
    responses = []
    incorrect_pairs = []

    for i in tqdm(range(0, len(test_dataset), batch_size)):
        batch = test_dataset[i:i + batch_size]
        batch_queries = [prompt + row for row in batch['question']]
        batch_responses = generate(model, batch_queries)
        responses.extend(batch_responses)
        
        # Identify incorrect pairs
        for query, response, answer in zip(batch['question'], batch_responses, batch['answer']):
            predicted_number = _get_answer(response)
            if predicted_number is None or np.abs(predicted_number - float(answer)) >= 0.1:
                incorrect_pairs.append({
                    "question": query,
                    "expected_solution": answer,
                    "actual_solution": response
                })

    answers = test_dataset['answer']
    rewards = _exact_match_reward(responses, answers)
    print(f"Exact match reward: {np.mean(rewards)}")
    return responses, incorrect_pairs

def llama3_8b_corrective_prompt(problem, previous_solution, correction_hint):
    """
    Perform corrective prompting using Llama3 8B model with LoRA.

    Parameters:
    problem (str): The problem statement.
    previous_solution (str): The previous incorrect solution.
    correction_hint (str): The hint or correction to guide the model.

    Returns:
    str: The generated solution.
    """
    prompt = f"Problem: {problem}\nPrevious Solution: {previous_solution}\nCorrection Hint: {correction_hint}\nNew Solution:"

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=150, num_return_sequences=1)  # Adjust max_length as needed
    solution = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the generated solution after "New Solution:"
    solution = solution.split("New Solution:")[1].strip() if "New Solution:" in solution else solution.strip()
    return solution

def categorize_responses(incorrect_pairs, evaluate_function):
    """
    Categorize responses into accepted and rejected after corrective prompting.

    Parameters:
    incorrect_pairs (list): List of dictionaries containing incorrect problem-solution pairs.
    evaluate_function (callable): The function that evaluates the solution.

    Returns:
    tuple: A tuple containing JSON strings of accepted and rejected datasets.
    """
    accepted = []
    rejected = []

    for item in incorrect_pairs:
        problem = item['question']
        expected_solution = item['expected_solution']
        previous_solution = item['actual_solution']
        correction_hint = f"The correct solution should be: {expected_solution}"

        new_solution = llama3_8b_corrective_prompt(problem, previous_solution, correction_hint)

        if new_solution.strip() == expected_solution.strip():
            accepted.append({
                'question': problem,
                'expected_solution': expected_solution,
                'previous_solution': previous_solution,
                'new_solution': new_solution
            })
        else:
            rejected.append({
                'question': problem,
                'expected_solution': expected_solution,
                'previous_solution': previous_solution,
                'new_solution': new_solution
            })

    accepted_json = json.dumps(accepted, indent=4)
    rejected_json = json.dumps(rejected, indent=4)

    return accepted_json, rejected_json

def merge_for_dpo(accepted_json, rejected_json):
    """
    Merge accepted and rejected datasets for DPO finetuning.

    Parameters:
    accepted_json (str): JSON string of accepted responses.
    rejected_json (str): JSON string of rejected responses.

    Returns:
    str: Merged JSON string suitable for DPO finetuning.
    """
    accepted = json.loads(accepted_json)
    rejected = json.loads(rejected_json)

    dpo_data = []

    for acc in accepted:
        problem = acc['question']
        expected_solution = acc['expected_solution']
        previous_solution = acc['previous_solution']
        new_solution = acc['new_solution']

        # Format for DPO: each entry contains a pair of responses (accepted and rejected)
        dpo_data.append({
            "prompt": problem,
            "accept": new_solution,
            "reject": previous_solution,
            "expected_solution": expected_solution
        })

    for rej in rejected:
        problem = rej['question']
        expected_solution = rej['expected_solution']
        previous_solution = rej['previous_solution']
        new_solution = rej['new_solution']

        # Format for DPO: each entry contains a pair of responses (accepted and rejected)
        dpo_data.append({
            "prompt": problem,
            "accept": expected_solution,
            "reject": new_solution,
            "expected_solution": expected_solution
        })

    dpo_json = json.dumps(dpo_data, indent=4)
    return dpo_json

# Example usage:
# Note: Replace the following test_data and incorrect_pairs with actual data
responses, incorrect_pairs = evaluate(test_data)
print("Initial Evaluation Responses and Incorrect Pairs Captured.")

# Use the incorrect pairs from the initial evaluation for corrective prompting
accepted_json, rejected_json = categorize_responses(incorrect_pairs, llama3_8b_corrective_prompt)
dpo_json = merge_for_dpo(accepted_json, rejected_json)

with open('accepted.json', 'w') as f:
    f.write(accepted_json)
    
with open('rejected.json', 'w') as f:
    f.write(rejected_json)
    
with open('dpo.json', 'w') as f:
    f.write(dpo_json)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/1 [00:00<?, ?it/s]

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.02 GiB (GPU 0; 21.99 GiB total capacity; 17.14 GiB already allocated; 1.63 GiB free; 20.08 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF