In [5]:
class CONFIG:
    """Run-wide settings, read directly as class attributes (e.g. ``CONFIG.model_name``)."""

    # Hugging Face Hub identifier of the model that gets evaluated.
    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct"
    # Local data directory (not referenced by the code visible in this notebook section).
    dataset_path: str = "data"
    # Directory name intended for result files (not referenced by the visible code).
    init_result_path: str = "results"

In [6]:
from datasets import load_dataset

# Download (or reuse the cached copy of) the GSM8K "main" configuration.
dataset_gsm8k = load_dataset("openai/gsm8k", "main")

# Keep direct handles on both splits for the later cells.
training_data, test_data = dataset_gsm8k["train"], dataset_gsm8k["test"]

print("Size of training data:", len(training_data))
print("Size of test data:", len(test_data))

Size of training data: 7473
Size of test data: 1319


In [None]:
import json
import os
import re

from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

# Tokenizer plus causal LM (wrapped with TRL's value head), both for the
# checkpoint named in CONFIG. No dtype/device arguments are passed, so the
# library defaults apply -- NOTE(review): confirm an 8B model fits that way.
tokenizer = AutoTokenizer.from_pretrained(CONFIG.model_name)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    CONFIG.model_name
)

# Four worked examples used as the few-shot chain-of-thought prefix.
# Each entry has a "problem" and a "solution" that ends with the final computation.
cot_examples = [
    dict(
        problem="If a train travels 60 miles in 1 hour, how far will it travel in 4 hours?",
        solution="First, calculate the distance traveled in one hour, which is 60 miles. Since the train travels at a constant speed, in 4 hours, it will travel 4 times the distance of one hour. Therefore, the solution is 60 miles * 4 = 240 miles.",
    ),
    dict(
        problem="Sarah has 5 apples. She buys 7 more apples. How many apples does she have now?",
        solution="First, note that Sarah initially has 5 apples. Then, she buys 7 more apples. To find the total number of apples Sarah has, add the apples she had initially to the apples she bought. Therefore, the solution is 5 + 7 = 12 apples.",
    ),
    dict(
        problem="Tom has 8 candies. He gives 3 candies to his friend. How many candies does Tom have left?",
        solution="First, note that Tom starts with 8 candies. He gives away 3 candies. To find out how many candies Tom has left, subtract the number of candies he gives away from the number he started with. Therefore, the solution is 8 - 3 = 5 candies.",
    ),
    dict(
        problem="A rectangle has a length of 10 cm and a width of 5 cm. What is the area of the rectangle?",
        solution="First, note that the area of a rectangle is calculated by multiplying its length by its width. Given the length is 10 cm and the width is 5 cm, the area is 10 cm * 5 cm. Therefore, the solution is 10 * 5 = 50 square cm.",
    ),
]

def llama3_8b_evaluate(problem, max_new_tokens=256):
    """
    Generate a solution with the Llama-3 8B model using a 4-shot
    chain-of-thought prompt built from ``cot_examples``.

    Uses the module-level ``tokenizer``, ``model`` and ``cot_examples``.

    Parameters:
    problem (str): The problem statement.
    max_new_tokens (int): Maximum number of tokens to generate beyond the
        prompt. A prompt-inclusive ``max_length`` is not used because the
        4-shot prompt alone exceeds 150 tokens.

    Returns:
    str: The model's solution text for ``problem``, cut off before any
        further "Problem:" the model goes on to invent.
    """
    # Few-shot prefix: every example as "Problem: ...\nSolution: ...",
    # followed by the target problem with an open "Solution:" slot.
    prompt = "".join(
        f"Problem: {example['problem']}\nSolution: {example['solution']}\n\n"
        for example in cot_examples
    )
    prompt += f"Problem: {problem}\nSolution:"

    # Place the encoded prompt on the same device as the model weights.
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=False,  # greedy decoding so the measured accuracy is reproducible
        pad_token_id=tokenizer.eos_token_id,  # explicit pad id so generate() need not guess one
    )

    # Decode only the newly generated tokens. Decoding the whole sequence
    # would include the prompt, whose first "Solution:" is a few-shot example.
    prompt_length = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # The model tends to continue the few-shot pattern with new problems;
    # keep only the answer to the one that was asked.
    return generated.split("\nProblem:", 1)[0].strip()

# Integers or decimals with an optional sign and thousands separators ("1,000").
_NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def _extract_final_answer(text):
    """Return the last number in ``text`` as a normalized string, or None.

    If the GSM8K final-answer marker ``####`` is present, only the text after
    it is searched. Thousands separators are removed and trailing fractional
    zeros dropped, so "1,000" and "1000.00" both normalize to "1000".
    """
    if "####" in text:
        text = text.rsplit("####", 1)[1]
    numbers = _NUMBER_RE.findall(text)
    if not numbers:
        return None
    number = numbers[-1].replace(",", "")
    if "." in number:
        number = number.rstrip("0").rstrip(".")
    return number


def _answers_match(actual, expected):
    """Return True if ``actual`` equals ``expected`` (ignoring outer whitespace)
    or both contain the same normalized final number."""
    if actual.strip() == expected.strip():
        return True
    actual_number = _extract_final_answer(actual)
    return actual_number is not None and actual_number == _extract_final_answer(expected)


def _get_field(item, *keys):
    """Return ``item[key]`` for the first of ``keys`` present in ``item``."""
    for key in keys:
        if key in item:
            return item[key]
    raise KeyError(f"test item has none of the keys {keys}")


def run_eval(evaluate_function, test_data):
    """
    Run initial evaluation to find the accuracy of the result without alignment
    and return the incorrect problem-solution pairs in JSON format.

    Parameters:
    evaluate_function (callable): Takes a problem string and returns the
        generated solution string.
    test_data (iterable): Dict-like items holding either 'problem' and
        'expected_solution', or the GSM8K fields 'question' and 'answer'.

    Returns:
    tuple: (accuracy, JSON string of incorrect pairs). Accuracy is 0.0 for
        empty ``test_data``. Each incorrect pair has the keys 'problem',
        'expected_solution' and 'actual_solution'.

    A prediction counts as correct when it equals the reference exactly, or
    when the final number in it equals the reference's final number (for
    GSM8K, the value after '####').
    """
    incorrect_pairs = []
    correct_count = 0
    total = 0

    for item in test_data:
        total += 1
        problem = _get_field(item, "problem", "question")
        expected_solution = _get_field(item, "expected_solution", "answer")
        actual_solution = evaluate_function(problem)

        if _answers_match(actual_solution, expected_solution):
            correct_count += 1
        else:
            incorrect_pairs.append({
                'problem': problem,
                'expected_solution': expected_solution,
                'actual_solution': actual_solution
            })

    # Guard against ZeroDivisionError on an empty test set.
    accuracy = correct_count / total if total else 0.0
    incorrect_pairs_json = json.dumps(incorrect_pairs, indent=4)

    return accuracy, incorrect_pairs_json

def save_incorrect_pairs(file_path="./", file_name="incorrect_pairs.json", pairs_json=None):
    """
    Write the incorrect problem-solution pairs JSON to ``file_path/file_name``.

    Parameters:
    file_path (str): Directory to write into; created if it does not exist.
    file_name (str): Name of the output file.
    pairs_json (str | None): JSON text to write. When None, falls back to the
        module-level ``incorrect_pairs_json`` produced by ``run_eval``
        (the original behavior).

    Returns:
    str: Full path of the written file.
    """
    if pairs_json is None:
        pairs_json = incorrect_pairs_json
    # os.path.join inserts the separator that plain "+" concatenation dropped
    # whenever file_path lacked a trailing slash.
    full_path = os.path.join(file_path, file_name)
    os.makedirs(file_path or ".", exist_ok=True)
    with open(full_path, "w", encoding="utf-8") as file:
        file.write(pairs_json)

    print(f"Incorrect problem-solution pairs saved to {full_path}")
    return full_path

# Evaluate the model on the GSM8K test split and persist the failures.
accuracy, incorrect_pairs_json = run_eval(llama3_8b_evaluate, test_data)
print("Accuracy:", accuracy)
print("Incorrect Problem-Solution Pairs JSON:", incorrect_pairs_json)
# Pass the name via file_name: the first positional parameter is the
# directory, so passing it positionally produced
# "...incorrect_pairs.jsonincorrect_pairs.json".
save_incorrect_pairs(file_name="llama_3_8b_4shots_cot_gsm8k_incorrect_pairs.json")