# Partial Replication of "Take a Step Back: Evoking Reasoning via Abstraction in Large Language Models"

This notebook presents a partial replication of the work described in "Take a Step Back: Evoking Reasoning via Abstraction in Large Language Models" by Zheng et al. (2023). I focus specifically on the MMLU high-school physics portion of their experiments.


## Overview of the Process

The replication follows these main steps:
1. Load the MMLU high-school physics dataset
2. Create an exemplar prompt to extract physics principles (following Table 7 in section C.1 in the original paper)
3. Use the exemplar to request principles for each question in the dataset (following Table 7 in section C.1 in the original paper)
4. Generate solutions based on the extracted principles (Table 8 in section D.1.)
5. Evaluate the generated solutions against the official answers
6. Analyze the results

We will perform this process first using the llama3 model, and then repeat it with another model for comparison.

## Step 1: Data Loading

This code loads both the training and test splits of the MMLU high-school physics dataset. The training split will be used to create our exemplar, while the test split will be used for the main evaluation.

In [3]:
# We'll be using these libraries throughout the replication
import random
import os
import json
import re
import time
from datasets import load_dataset

def load_physics_dataset():
    # Load the MMLU high-school physics dataset.
    physics_dataset_examples = load_dataset("lukaemon/mmlu", "high_school_physics", split="train") # We'll use 'train' to create an exemplar
    physics_dataset_evaluation = load_dataset("lukaemon/mmlu", "high_school_physics", split="test") # We'll use 'test' to evaluate
    return physics_dataset_examples, physics_dataset_evaluation

physics_dataset_examples, physics_dataset_evaluation = load_physics_dataset()

# The datasets are quite small
print(f"Number of examples in training set: {len(physics_dataset_examples)}")
print(f"Number of examples in test set: {len(physics_dataset_evaluation)}")

Number of examples in training set: 5
Number of examples in test set: 151


Create utility function to save to JSON.

In [4]:
# Function to save to a JSON file
def save_to_json(data, filename):
    with open(filename, 'w') as f:
        json.dump(data, f, indent=4)

## Step 2: Model Setup

Next, we set up the functions to interact with our language models.

In [82]:
import ollama
import anthropic
from openai import OpenAI
import google.generativeai as genai
	
anthropic_client = anthropic.Anthropic() # Defaults to os.environ.get("ANTHROPIC_API_KEY")
openai_client = OpenAI() # Defaults to os.environ.get("OPENAI_API_KEY"),
genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))

def call_claude(prompt):
	# Call Claude 3.5 Sonnet model with given prompt.
	message = anthropic_client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        messages=[
            {"role": "user", "content": prompt}
        ]
	)
	return message.content[0].text

def call_llama3(prompt):
	# Call Llama 3 model via Ollama with given prompt.
	model = 'llama3'
	messages = [
        {
            'role': 'user',
            'content': prompt
        }
	]

	# Make the request to the Ollama API
	response = ollama.chat(model=model, messages=messages, stream=False)
	return response["message"]["content"]

def call_gemini(prompt):
    # Call Gemini 1.5 Flash model with given prompt.
    model = genai.GenerativeModel("gemini-1.5-flash")
    attempts = 1
    delay = 5

    # Gemini API doesn't seem to be that reliable yet and requires some additional logic.
    while attempts <= 3:
        try:
            response = model.generate_content(
                contents=prompt, 
                request_options={"timeout": 2000}
            )
            return response.text
        except ValueError as e:
            attempts += 1
            print(f"Encountered error {e}. Waiting {delay} seconds and retrying. Attempt{attempts}...")
            time.sleep(delay)

def call_gpt(prompt):
	completion = openai_client.chat.completions.create(
		model="gpt-3.5-turbo",
		messages=[
			{"role": "user", "content": prompt}
		]
	)
	return completion.choices[0].message.content

def call_model(model_name, prompt):
    # Route prompt to specified model and return response.
    match model_name:
        case "claude":
            return call_claude(prompt)
        case "llama3":
            return call_llama3(prompt)
        case "gemini":
            return call_gemini(prompt)
        case "gpt":
            return call_gpt(prompt)
        case _:
            return "Model not found."


test_prompt = "Hello my friend"
print(call_model("claude", test_prompt))
print(call_model("llama3", test_prompt))
print(call_model("gemini", test_prompt))
print(call_model("gpt", test_prompt))

Hello! It's nice to meet you. How can I assist you today? Feel free to ask me any questions or let me know if there's anything you'd like to discuss.
Hello there! It's great to chat with you. Is there something on your mind that you'd like to talk about, or are you just looking for some friendly conversation? I'm all ears!
Hello! It's nice to hear from you.  What can I do for you today? 

Hello! How can I assist you today?


## Step 3: Creating the Exemplar

Now, we'll create an exemplar following Table 7 in Section C.1 in the original paper. This exemplar will be used to guide the model in extracting principles for other questions.

### Step 3a: Get the principles

In [12]:
import random

random.seed(293)
exemplar_question = random.choice(physics_dataset_examples)["input"]
step_back_prompt = "What are the physics principles behind this question? Describe only the principles and relevant equations without answering the question."

# Follows format in Table 7 in Section C.1.
principles_prompt = f"""
{step_back_prompt}
{exemplar_question}
Principles:
"""

exemplar_principles = call_model("claude", principles_prompt)

### Step 3b: Get the solution using the principles

In [13]:
solution_prompt = f"""
Question: {exemplar_question}
Principles: {exemplar_principles}
Solution: 
"""
exemplar_solution = call_model("claude", solution_prompt)

# Follows format in Table 8 in Section D.1.
exemplar = {
    "question": exemplar_question,
    "principles": exemplar_principles,
    "solution": exemplar_solution
}

# Print exemplar for examination
for _, value in exemplar.items():
    print(value)

A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?
The physics principles behind this question relate to electrical power and energy consumption. Here are the relevant principles and equations:

1. Ohm's Law: This fundamental principle relates voltage (V), current (I), and resistance (R) in an electrical circuit.
   V = I * R

2. Electrical Power: The rate at which electrical energy is transferred or converted in a circuit is given by the power equation.
   P = V * I
   Where:
   P is power in watts (W)
   V is voltage in volts (V)
   I is current in amperes (A)

3. Energy and Power Relationship: Energy is the capacity to do work, while power is the rate at which energy is transferred or work is done.
   E = P * t
   Where:
   E is energy in joules (J)
   P is power in watts (W)
   t is time in seconds (s)

4. Unit Conversions: Understanding the relationships between different units of power and ene

## 4. Generating Principles for Test Dataset

Next, we'll generate principles for each problem in the test dataset.

In [83]:
# Follows format in Table 7 in Section C.1.
get_principles_prompt_template = """
You are an expert at Physics.
You are given a Physics problem.
Your task is to extract the Physics concepts and principles involved in solving the problem.
Here is an example:

--- Example ----
Question: {exemplar_question}
Principles:{exemplar_principles}
--- End of Example ----

Question: {current_question}
Do not solve the problem. Only detail the principles and equations involved.
Principles Involved:
"""

# A function to step through each problem and get the associated principles from a model.
def generate_principles(model_name, prompt_template, problem_set, exemplar):
    num_problems = len(problem_set)

    # Initalize results dictionary
    principles_dict = {
        "model_name": model_name,
        "principles": {}
    }

    # Loop through problem set and get principles for each problem.
    for idx, problem in enumerate(problem_set):
        print(f"Using model {model_name}. Processing problem {idx + 1} of {num_problems}.")
        current_question = problem["input"]

        # Assume format of problem_set is physics_dataset_evaluation
        prompt = prompt_template.format(
            exemplar_question=exemplar["question"],
            exemplar_principles=exemplar["principles"],
            current_question=current_question
        )
       
        extracted_principles = call_model(model_name, prompt) # Call associated model to get principles for the given problem
        
        # Save returned principles
        principles_dict["principles"][f"problem_{idx}"] = {
            "question": current_question,
            "principles": extracted_principles
        }

    return principles_dict

# Create principles using both Llama3 (8B) and Gemini Flash
llama3_principles = generate_principles("llama3", get_principles_prompt_template, physics_dataset_evaluation, exemplar)
gemini_principles = generate_principles("gemini", get_principles_prompt_template, physics_dataset_evaluation, exemplar)
gpt_principles = generate_principles("gpt", get_principles_prompt_template, physics_dataset_evaluation, exemplar)

# Save files
save_to_json(llama3_principles, "llama3_principles.json")
save_to_json(gemini_principles, "gemini_principles.json")
save_to_json(gpt_principles, "gpt_principles.json")

Using model gemini. Processing problem 1 of 151.
Using model gemini. Processing problem 2 of 151.
Using model gemini. Processing problem 3 of 151.
Using model gemini. Processing problem 4 of 151.
Using model gemini. Processing problem 5 of 151.
Using model gemini. Processing problem 6 of 151.
Using model gemini. Processing problem 7 of 151.
Using model gemini. Processing problem 8 of 151.
Using model gemini. Processing problem 9 of 151.
Using model gemini. Processing problem 10 of 151.
Using model gemini. Processing problem 11 of 151.
Using model gemini. Processing problem 12 of 151.
Using model gemini. Processing problem 13 of 151.
Using model gemini. Processing problem 14 of 151.
Using model gemini. Processing problem 15 of 151.
Using model gemini. Processing problem 16 of 151.
Using model gemini. Processing problem 17 of 151.
Using model gemini. Processing problem 18 of 151.
Using model gemini. Processing problem 19 of 151.
Using model gemini. Processing problem 20 of 151.
Using mod

## 5. Generating Solutions
Now we'll generate solutions for each problem using the extracted principles. This follows Table 8 in section D.1.

In [90]:
# Follows format in Table 8 in Section D.1.
get_step_back_solution_prompt_template = """
You are an expert at Physics.
You are given a Physics problem and a set of principles involved in solving the problem.
Solve the problem step by step by following the principles.
Here is an example:

--- Example ----
Question: {exemplar_question}
Principles:{exemplar_principles}
Solution: {exemplar_solution}
--- End of Example ----

Question: {current_question}
Principles Involved: {current_principles}
Solution: 
"""

get_standard_solution_prompt_template = """
You are an expert at Physics.
You are given a Physics problem.
Solve the problem step by step.

Question: {current_question}
Solution: 
"""

def generate_step_back_solutions(model_name, prompt_template, principles_dict, exemplar):
    num_problems = len(principles_dict["principles"])

    # Initalize results dictionary
    solutions_dict = {
        "model_name": model_name,
        "prompt_type": "step back",
        "solutions": {}
    }

    for idx, problem in enumerate(principles_dict["principles"].values()):
        print(f"Using model {model_name} with a step-back prompt. Processing problem {idx + 1} of {num_problems}.")
        current_question = problem["question"]
        current_principles = problem["principles"]

        prompt = prompt_template.format(
            exemplar_question=exemplar["question"],
            exemplar_principles=exemplar["principles"],
            exemplar_solution=exemplar["solution"],
            current_question=current_question,
            current_principles=current_principles
        )

        current_solution = call_model(model_name, prompt)
        
        solutions_dict["solutions"][f"problem_{idx}"] = {
            "question": current_question,
            "principles": current_principles,
            "solution": current_solution
        }

    return solutions_dict

def generate_standard_solutions(model_name, prompt_template, problem_set):
    num_problems = len(problem_set)

    # Initalize results dictionary
    solutions_dict = {
        "model_name": model_name,
        "prompt_type": "standard",
        "solutions": {}
    }

    # Loop through problem set and call model to get solution with just the original question and no principles or example solution
    for idx, problem in enumerate(problem_set):
        print(f"Using model {model_name} with a standard chain-of-thought prompt. Processing problem {idx + 1} of {num_problems}.")
        current_question = problem["input"]
        prompt = prompt_template.format(current_question=current_question)
        current_solution = call_model(model_name, prompt)
        
        # Save results
        solutions_dict["solutions"][f"problem_{idx}"] = {
            "question": current_question,
            "solution": current_solution
        }

    return solutions_dict

# Get step back solutions
llama3_step_back_solutions = generate_step_back_solutions("llama3", get_step_back_solution_prompt_template, llama3_principles, exemplar)
gemini_step_back_solutions = generate_step_back_solutions("gemini", get_step_back_solution_prompt_template, llama3_principles, exemplar)
gpt_step_back_solutions = generate_step_back_solutions("gpt", get_step_back_solution_prompt_template, gpt_principles, exemplar)

# Save step back solutions
save_to_json(llama3_step_back_solutions, "llama3_step_back_solutions.json")
save_to_json(gemini_step_back_solutions, "gemini_step_back_solutions.json")
save_to_json(gpt_step_back_solutions, "gpt_step_back_solutions.json")

# Get standard solutions
llama3_standard_solutions = generate_standard_solutions("llama3", get_standard_solution_prompt_template, physics_dataset_evaluation)
gemini_standard_solutions = generate_standard_solutions("gemini", get_standard_solution_prompt_template, physics_dataset_evaluation)
gpt_standard_solutions = generate_standard_solutions("gpt", get_standard_solution_prompt_template, physics_dataset_evaluation)

# Save standard solutions
save_to_json(llama3_standard_solutions, "llama3_standard_solutions.json")
save_to_json(gemini_standard_solutions, "gemini_standard_solutions.json")
save_to_json(gpt_standard_solutions, "gpt_standard_solutions.json")

Using model llama3 with a step-back prompt. Processing problem 1 of 151.
Using model llama3 with a step-back prompt. Processing problem 2 of 151.
Using model llama3 with a step-back prompt. Processing problem 3 of 151.
Using model llama3 with a step-back prompt. Processing problem 4 of 151.
Using model llama3 with a step-back prompt. Processing problem 5 of 151.
Using model llama3 with a step-back prompt. Processing problem 6 of 151.
Using model llama3 with a step-back prompt. Processing problem 7 of 151.
Using model llama3 with a step-back prompt. Processing problem 8 of 151.
Using model llama3 with a step-back prompt. Processing problem 9 of 151.
Using model llama3 with a step-back prompt. Processing problem 10 of 151.
Using model llama3 with a step-back prompt. Processing problem 11 of 151.
Using model llama3 with a step-back prompt. Processing problem 12 of 151.
Using model llama3 with a step-back prompt. Processing problem 13 of 151.
Using model llama3 with a step-back prompt. Pro

## 6. Evaluate solutions
We'll use Claude Sonnet 3.5 to evaluate the solution of each model. The solution could be in different units than the original, but still valid. Likewise, the solution may be difficult to extract using standard text analysis functions so using an LLM as a judge is an appropriate technique.

In [97]:
solution_evaluation_prompt_template = """
Does the result from the AI-generated solution match the official solution? Read over the entire solution to determine the AI's answer.

Note: The solutions are invariant to units if they represent the same value. For example, 1 Amp = 1000 milliamps is considered an equivalent and valid solution. 

Provide your evaluation in JSON format with the following information:
1. true if the solutions are equivalent, false otherwise (remember, booleans in JSON are in lower case)
2. A short explanation of the reasoning behind your evaluation.

--- Example ---
{{
    "ai_solution_matches": true,
    "evaluation_explanation": "The AI said, 'Therefore, the charge on the positive plate is approximately 10 microcoulombs (μC). The official solution is 0.00001 C. The two answers are the same.'"
}}
--- End of Example ---

Official solution: {official_solution}
AI-generated solution: {ai_solution}

Your assessment:
{{
    "ai_solution_matches": 
    "evaluation_explanation": ""
}}
"""

def convert_response_to_json(response):
    start_index = response.find('{')
    end_index = response.find('}') + 1
    cleaned_response = response[start_index:end_index]
    try:
        return json.loads(cleaned_response)
    except json.JSONDecodeError as e:
        print(f"Error while decoding JSON: {e}")
        return None

def get_solution_evaluation(model_solutions, official_solutions, prompt_template, evaluation_model="claude"):
    num_problems = len(model_solutions["solutions"])

    model_name = model_solutions["model_name"]
    prompt_type = model_solutions["prompt_type"]

    # Initalize results dictionary
    solution_evaluation = {
        "model_name": model_name,
        "prompt_type": prompt_type,
        "solution_evaluations": {}
    }

    for idx, model_solution in enumerate(model_solutions["solutions"].values()):
        print(f"Using model {evaluation_model} to evaluate solutions generated by {model_name} using a {prompt_type} prompt strategy. Processing problem {idx + 1} of {num_problems}.")
        
        # Assuming format "official_solutions" are in the format of physics_dataset_evaluation and "model_solutions" are in the format of step_back_solutions_dict.
        # Use key to get offical solution
        official_solution_entry = official_solutions[idx]
        key = official_solution_entry["target"]
        official_solution = official_solution_entry[key]
        
        # Prompt Claude Sonnet 3.5 with model solution and official solution for evaluation
        prompt = prompt_template.format(ai_solution=model_solution["solution"], official_solution=official_solution)
        evaluation_response = call_model(evaluation_model, prompt)
        solution_evaluation["solution_evaluations"][f"problem_{idx}"] = convert_response_to_json(evaluation_response)

    return solution_evaluation

# # Step back solutions
llama3_step_back_solution_evaluation = get_solution_evaluation(llama3_step_back_solutions, physics_dataset_evaluation, solution_evaluation_prompt_template)
gemini_step_back_solution_evaluation = get_solution_evaluation(gemini_step_back_solutions, physics_dataset_evaluation, solution_evaluation_prompt_template)
gpt_step_back_solution_evaluation = get_solution_evaluation(gpt_step_back_solutions, physics_dataset_evaluation, solution_evaluation_prompt_template)

# # Save step back solutions evaluations
save_to_json(llama3_step_back_solution_evaluation, 'llama3_step_back_solution_evaluation.json')
save_to_json(gemini_step_back_solution_evaluation, 'gemini_step_back_solution_evaluation.json')
save_to_json(gpt_step_back_solution_evaluation, 'gpt_step_back_solution_evaluation.json')

# Standard solutions
llama3_standard_solution_evaluation = get_solution_evaluation(llama3_standard_solutions, physics_dataset_evaluation, solution_evaluation_prompt_template)
gemini_standard_solution_evaluation = get_solution_evaluation(gemini_standard_solutions, physics_dataset_evaluation, solution_evaluation_prompt_template)
gpt_standard_solution_evaluation = get_solution_evaluation(gpt_standard_solutions, physics_dataset_evaluation, solution_evaluation_prompt_template)

# Save standard solutions evaluations
save_to_json(llama3_standard_solution_evaluation, 'llama3_standard_solution_evaluation.json')
save_to_json(gemini_standard_solution_evaluation, 'gemini_standard_solution_evaluation.json')
save_to_json(gpt_standard_solution_evaluation, 'gpt_standard_solution_evaluation.json')

Using model claude to evaluate solutions generated by gemini using a step back prompt strategy. Processing problem 1 of 151.
Using model claude to evaluate solutions generated by gemini using a step back prompt strategy. Processing problem 2 of 151.
Using model claude to evaluate solutions generated by gemini using a step back prompt strategy. Processing problem 3 of 151.
Using model claude to evaluate solutions generated by gemini using a step back prompt strategy. Processing problem 4 of 151.
Using model claude to evaluate solutions generated by gemini using a step back prompt strategy. Processing problem 5 of 151.
Using model claude to evaluate solutions generated by gemini using a step back prompt strategy. Processing problem 6 of 151.
Using model claude to evaluate solutions generated by gemini using a step back prompt strategy. Processing problem 7 of 151.
Using model claude to evaluate solutions generated by gemini using a step back prompt strategy. Processing problem 8 of 151.


## Step 7. Calculate accuracy

In [114]:
def calculate_accuracy(solutions, comparison_key):
    return sum(1 for value in solutions.values() if value[comparison_key]) / len(solutions)

def get_evaluation_results(evaluation_results_list):
    for evaluation_results in evaluation_results_list:
        proportion_correct = calculate_accuracy(evaluation_results["solution_evaluations"], 'ai_solution_matches')
        evaluation_results["solution_evaluations"]
        print(f"For model {evaluation_results['model_name']} using the {evaluation_results['prompt_type']} prompt method the accuracy was {proportion_correct:.1%}")

evaluation_results_list = [llama3_step_back_solution_evaluation, gemini_step_back_solution_evaluation, gpt_step_back_solution_evaluation, llama3_standard_solution_evaluation, gemini_standard_solution_evaluation, gpt_standard_solution_evaluation]

get_evaluation_results(evaluation_results_list)          

For model llama3 using the step back prompt method the accuracy was 22.5%
For model gemini using the step back prompt method the accuracy was 58.9%
For model gpt using the step back prompt method the accuracy was 33.1%
For model llama3 using the standard prompt method the accuracy was 20.5%
For model gemini using the standard prompt method the accuracy was 61.6%
For model gpt using the standard prompt method the accuracy was 43.0%


In [116]:
physics_dataset_evaluation[0]

{'input': 'The plates of a capacitor are charged to a potential difference of 5 V. If the capacitance is 2 mF, what is the charge on the positive plate?',
 'A': '0.005 C',
 'B': '0.01 C',
 'C': '0.02 C',
 'D': '0.5 C',
 'target': 'B'}