# Partial Replication of "Take a Step Back: Evoking Reasoning via Abstraction in Large Language Models"

This notebook presents a partial replication of the work described in "Take a Step Back: Evoking Reasoning via Abstraction in Large Language Models" by Zheng et al. (2023). I focus specifically on the MMLU high-school physics portion of their experiments.


## Overview of the Process

The replication follows these main steps:
1. Load the MMLU high-school physics dataset
2. Create an exemplar of extracted physics principles (Table 7, Section C.1 of the original paper)
3. Use the exemplar in a few-shot prompt to extract the principles for each test question (Table 7, Section C.1)
4. Generate solutions based on the extracted principles (Table 8, Section D.1)
5. Evaluate the generated solutions against the official answers
6. Analyze the results

We run this process first with Llama 3 (8B), served locally through Ollama, and then repeat it with another model for comparison.

## Step 1: Data Loading

This code loads both the training and test splits of the MMLU high-school physics dataset. The training split will be used to create our exemplar, while the test split will be used for the main evaluation.

In [1]:
# We'll be using these libraries throughout the replication
import random
import os
import json
from datasets import load_dataset

def load_physics_dataset():
    # Load the MMLU high-school physics dataset.
    physics_dataset_examples = load_dataset("lukaemon/mmlu", "high_school_physics", split="train") # We'll use 'train' to create an exemplar
    physics_dataset_evaluation = load_dataset("lukaemon/mmlu", "high_school_physics", split="test") # We'll use 'test' to evaluate
    return physics_dataset_examples, physics_dataset_evaluation

physics_dataset_examples, physics_dataset_evaluation = load_physics_dataset()

# The datasets are quite small
print(f"Number of examples in training set: {len(physics_dataset_examples)}")
print(f"Number of examples in test set: {len(physics_dataset_evaluation)}")

(Download progress bars omitted. The dataset builder generated 151 test, 17 validation, and 5 train examples.)
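The evaluation step later relies on the record layout of this dataset: each row stores the question stem in `input`, the option texts under the letter keys `A` to `D`, and the correct letter in `target`. Only `input` is sent to the models; the options are used only when judging. A minimal sketch of the answer lookup, using a hand-written row whose option values are illustrative placeholders rather than copied from the dataset (`official_answer` is a hypothetical helper name):

```python
def official_answer(row: dict) -> str:
    """Return the text of the correct option for an MMLU-style row.

    Assumes the layout used in this notebook: option texts under "A"-"D"
    and the letter of the correct option under "target".
    """
    return row[row["target"]]


# Hypothetical row; the option values are placeholders for illustration.
example_row = {
    "input": "A microwave oven is connected to an outlet, 120 V, and draws a current "
             "of 2 amps. At what rate is energy being used by the microwave oven?",
    "A": "10 W",
    "B": "30 W",
    "C": "60 W",
    "D": "240 W",
    "target": "D",
}

print(official_answer(example_row))  # prints "240 W"
```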

## Step 2: Model Setup

Next, we set up the functions to interact with our language models.

In [1]:
import os

import anthropic
import google.generativeai as genai
import ollama

# The Anthropic client reads ANTHROPIC_API_KEY from the environment by default.
anthropic_client = anthropic.Anthropic()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

def call_claude(prompt):
    # Call Claude 3.5 Sonnet with the given prompt and return the text of its reply.
    message = anthropic_client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        messages=[{"role": "user", "content": prompt}],
    )
    return message.content[0].text

def call_llama3(prompt):
    # Call Llama 3 (8B) through a local Ollama server and return the text of its reply.
    messages = [{"role": "user", "content": prompt}]
    response = ollama.chat(model="llama3", messages=messages, stream=False)
    return response["message"]["content"]

def call_gemini_flash(prompt):
    # Call Gemini 1.5 Flash with the given prompt and return the text of its reply.
    model = genai.GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(prompt)
    return response.text

def call_model(model_name, prompt):
    # Route the prompt to the requested model and return its response.
    match model_name:
        case "claude":
            return call_claude(prompt)
        case "llama3":
            return call_llama3(prompt)
        case "gemini":
            return call_gemini_flash(prompt)
        case _:
            raise ValueError(f"Unknown model: {model_name}")

test_prompt = "Hello my friend"
print(call_model("claude", test_prompt))
print(call_model("llama3", test_prompt))
print(call_model("gemini", test_prompt))

ModuleNotFoundError: No module named 'google.generativeai'
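A dictionary that maps model names to functions is a common alternative to the `match` statement, and it makes the routing testable without API keys or installed client libraries. Each route has to *call* the selected function; returning the function object itself silently yields a callable instead of text. The sketch below uses stub functions in place of the real clients; `fake_claude`, `fake_llama3`, `MODEL_REGISTRY`, and `route` are hypothetical names:

```python
from typing import Callable, Dict


def fake_claude(prompt: str) -> str:
    # Stand-in for call_claude; returns a canned reply instead of calling the API.
    return f"[claude] {prompt}"


def fake_llama3(prompt: str) -> str:
    # Stand-in for call_llama3; returns a canned reply instead of calling Ollama.
    return f"[llama3] {prompt}"


MODEL_REGISTRY: Dict[str, Callable[[str], str]] = {
    "claude": fake_claude,
    "llama3": fake_llama3,
}


def route(model_name: str, prompt: str) -> str:
    """Look up the model's function and call it with the prompt."""
    try:
        model_fn = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name}") from None
    return model_fn(prompt)


print(route("llama3", "Hello my friend"))  # prints "[llama3] Hello my friend"
```

Swapping a stub for the real `call_claude` or `call_llama3` only changes the registry entries, not the routing logic.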

In [14]:
import random
import anthropic
import os

# Pick one of the five training questions at random to serve as the exemplar.
examplar = random.choice(physics_dataset_examples)
step_back_prompt = "What are the physics principles behind this question? Describe only the principles and relevant equations without answering the question."

# Passing api_key explicitly is optional; the client reads ANTHROPIC_API_KEY by default.
client = anthropic.Anthropic(
    api_key=os.environ.get("ANTHROPIC_API_KEY")
)

# Call Claude 3.5 Sonnet and return the text of its reply.
def call_claude(prompt):
    message = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    return message.content[0].text

principles_prompt = step_back_prompt + "\n\n" + examplar["input"] + "\n\n" + "Principles:"
principles = call_claude(principles_prompt)
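The cell above and the one that follows form the two-stage step-back pattern in miniature: first ask only for the principles (abstraction), then send the question together with those principles to get the answer (reasoning). A minimal sketch of the two prompt builders as pure functions, assuming the same instruction text and separators as `principles_prompt` and `principles_and_question_prompt`; the function names are hypothetical:

```python
STEP_BACK_INSTRUCTION = (
    "What are the physics principles behind this question? "
    "Describe only the principles and relevant equations without answering the question."
)


def build_principles_prompt(question: str) -> str:
    """Stage 1 (abstraction): ask only for the principles behind the question."""
    return f"{STEP_BACK_INSTRUCTION}\n\n{question}\n\nPrinciples:"


def build_solution_prompt(question: str, principles: str) -> str:
    """Stage 2 (reasoning): give the question together with the extracted principles."""
    return f"{question}\n\nPrinciples: {principles}"


question = (
    "A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. "
    "At what rate is energy being used by the microwave oven?"
)
stage1 = build_principles_prompt(question)
stage2 = build_solution_prompt(question, "P = V × I")
print(stage1.endswith("Principles:"))  # True
```

Keeping the builders free of API calls makes it easy to inspect or unit-test the exact prompts before spending any model calls.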


Next, send the exemplar question together with the extracted principles back to Claude 3.5 Sonnet to get an exemplar solution.

In [36]:
principles_and_question_prompt = examplar["input"] + "\n\n" + f"Principles: {principles}"
exemplar_solution = call_claude(principles_and_question_prompt)

print(f"Question: {examplar['input']}")
print(f"Principles: {principles}")
print(f"Solution: {exemplar_solution}")

Question: A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?
Principles: The physics principles behind this question relate to electrical power and energy consumption. Here are the key principles and relevant equations:

1. Electrical Power:
   Power is the rate at which energy is transferred or converted. In electrical systems, power is the product of voltage and current.

   Equation: P = V × I
   Where:
   P = Power (measured in watts, W)
   V = Voltage (measured in volts, V)
   I = Current (measured in amperes, A)

2. Ohm's Law:
   While not directly used in this problem, Ohm's Law relates voltage, current, and resistance in an electrical circuit.

   Equation: V = I × R
   Where:
   R = Resistance (measured in ohms, Ω)

3. Energy Consumption:
   Energy consumed over time is equal to power multiplied by time.

   Equation: E = P × t
   Where:
   E = Energy (measured in joules, J, or watt-hours, 

## Testing a small model: Llama 3 (8B)

### Step 1: Get principles
For each question in the physics test set, call Llama 3 (8B) to extract the principles involved and store them in a dictionary keyed by question index. This follows Table 7 in Section C.1. The exemplar principles and solution used in these prompts are the Claude 3.5 Sonnet outputs above, lightly edited for clarity.

In [10]:
import ollama

def call_llama3(prompt):
    # Call Llama 3 (8B) through a local Ollama server and return the text of its reply.
    model = 'llama3'
    messages = [
        {
            'role': 'user',
            'content': prompt
        }
    ]

    # Make the request to the Ollama API
    response = ollama.chat(model=model, messages=messages, stream=False)
    return response["message"]["content"]

# The exemplar below is a lightly edited version of the principles Claude 3.5 Sonnet produced above.
get_principles_prompt_template = """You are an expert at Physics.
You are given a Physics problem.
Your task is to extract the Physics concepts and principles involved in solving the problem.
Here is an example:

--- Example ----
Question: A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?
Principles Involved:
The physics principles behind this question relate to electrical power and energy consumption. Here are the key principles and relevant equations:

1. Electrical Power:
Power is the rate at which energy is transferred or converted. In electrical systems, power is the product of voltage and current.
Equation: P = V × I
Where:
P = Power (measured in watts, W)
V = Voltage (measured in volts, V)
I = Current (measured in amperes, A)

2. Ohm's Law:
While not directly used in this problem, Ohm's Law relates voltage, current, and resistance in an electrical circuit.
Equation: V = I × R
Where:
R = Resistance (measured in ohms, Ω)

3. Energy Consumption:
Energy consumed over time is equal to power multiplied by time.
Equation: E = P × t
Where:
E = Energy (measured in joules, J, or watt-hours, Wh)
t = Time (measured in seconds, s, or hours, h)

4. Conservation of Energy:
The principle that energy cannot be created or destroyed, only converted from one form to another. In this case, electrical energy is being converted to other forms (primarily heat and electromagnetic radiation) by the microwave oven.

5. AC Power:
The voltage provided by standard electrical outlets is alternating current (AC). However, for simple power calculations with resistive loads, we can use the root mean square (RMS) values of voltage and current, which are the values typically given for household electricity.
--- End of Example ----

Question: {question}
Principles Involved:"""

num_problems = len(physics_dataset_evaluation)
principles_dict = {}

for idx, physics_problem in enumerate(physics_dataset_evaluation):
    print(f"Processing problem {idx + 1} of {num_problems}.")
    physics_question = physics_problem["input"]
    prompt = get_principles_prompt_template.format(question=physics_question)
    extracted_principles = call_llama3(prompt)
    principles_dict[idx] = {
        "question": physics_question,
        "principles": extracted_principles
    }



Save the extracted principles to a JSON file.

In [39]:
with open("llama3_principles.json", 'w') as json_file:
    json.dump(principles_dict, json_file, indent=4)
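One detail to keep in mind when reloading this file: JSON object keys are always strings, so the integer indices in `principles_dict` come back as `"0"`, `"1"`, and so on, and `physics_dataset_evaluation[idx]` would then fail. A small sketch of a round trip that restores integer keys; `save_results` and `load_results` are hypothetical helper names:

```python
import json
import os
import tempfile


def save_results(results: dict, path: str) -> None:
    # ensure_ascii=False keeps symbols such as "×" and "Ω" readable in the file.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)


def load_results(path: str) -> dict:
    # JSON keys are strings; convert them back to the integer problem indices.
    with open(path, encoding="utf-8") as f:
        return {int(k): v for k, v in json.load(f).items()}


results = {0: {"question": "Q0", "principles": "P = V × I"}}
path = os.path.join(tempfile.mkdtemp(), "principles.json")
save_results(results, path)
print(load_results(path) == results)  # True
```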

### Step 2: Get solution
Now that we have the principles, we can prompt the model to produce a final answer. This follows Table 8 in Section D.1.
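The prompt template used here inlines a single worked example, while the Table 8 template is framed around "a few examples". If more exemplars are wanted, the few-shot block can be assembled from a list instead. A minimal sketch, assuming each exemplar is a (question, principles, answer) triple and reusing the instruction text from this notebook's template; `build_few_shot_solution_prompt` and the second physics question are hypothetical:

```python
from typing import List, Tuple

HEADER = (
    "You are an expert at Physics.\n"
    "You are given a Physics problem and a set of principles involved in solving the problem.\n"
    "Solve the problem step by step by following the principles.\n"
)


def build_few_shot_solution_prompt(
    exemplars: List[Tuple[str, str, str]], question: str, principles: str
) -> str:
    """Assemble a Table 8-style prompt from (question, principles, answer) exemplars."""
    parts = [HEADER, "Here are a few examples:\n"]
    for ex_question, ex_principles, ex_answer in exemplars:
        parts.append(
            f"Question: {ex_question}\nPrinciples: {ex_principles}\nAnswer: {ex_answer}\n"
        )
    # The target question ends with an open "Answer:" for the model to complete.
    parts.append(f"Question: {question}\nPrinciples: {principles}\nAnswer:")
    return "\n".join(parts)


exemplars = [
    ("A 120 V outlet supplies 2 A. What power is drawn?", "P = V × I", "P = 120 V × 2 A = 240 W"),
]
prompt = build_few_shot_solution_prompt(
    exemplars, "A 9 V battery drives 0.5 A. What power is delivered?", "P = V × I"
)
print(prompt.count("Question:"))  # 2
```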

In [15]:
get_solution_prompt_template = """
You are an expert at Physics.
You are given a Physics problem and a set of principles involved in solving the problem.
Solve the problem step by step by following the principles.

Here is an example:

--- Example ----
Question: A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?
Principles:
The physics principles behind this question relate to electrical power and energy consumption. Here are the key principles and relevant equations:

1. Electrical Power:
Power is the rate at which energy is transferred or converted. In electrical systems, power is the product of voltage and current.
Equation: P = V × I
Where:
P = Power (measured in watts, W)
V = Voltage (measured in volts, V)
I = Current (measured in amperes, A)

2. Ohm's Law:
While not directly used in this problem, Ohm's Law relates voltage, current, and resistance in an electrical circuit.
Equation: V = I × R
Where:
R = Resistance (measured in ohms, Ω)

3. Energy Consumption:
Energy consumed over time is equal to power multiplied by time.
Equation: E = P × t
Where:
E = Energy (measured in joules, J, or watt-hours, Wh)
t = Time (measured in seconds, s, or hours, h)

4. Conservation of Energy:
The principle that energy cannot be created or destroyed, only converted from one form to another. In this case, electrical energy is being converted to other forms (primarily heat and electromagnetic radiation) by the microwave oven.

5. AC Power:
The voltage provided by standard electrical outlets is alternating current (AC). However, for simple power calculations with resistive loads, we can use the root mean square (RMS) values of voltage and current, which are the values typically given for household electricity.

Answer:
Let's solve this problem step by step using the principles we've discussed.

Given:
- Voltage (V) = 120 V
- Current (I) = 2 A

We need to find the rate at which energy is being used, which is equivalent to the power consumed by the microwave oven.

Step 1: Apply the electrical power equation:
P = V × I

Step 2: Substitute the known values:
P = 120 V × 2 A

Step 3: Calculate the power:
P = 240 W

Therefore, the microwave oven is using energy at a rate of 240 watts.
--- End of Example ----

Question: {question}
Principles: {principles}
Answer:
"""

step_back_solutions_dict = {}

for idx, entry in principles_dict.items():
    print(f"Processing problem {idx + 1} of {num_problems}.")
    question = entry["question"]
    principles = entry["principles"]
    prompt = get_solution_prompt_template.format(question=question, principles=principles)
    solution = call_llama3(prompt)
    step_back_solutions_dict[idx] = {
        "question": question,
        "principles": principles,
        "solution": solution
    }


Save the resulting solutions to a JSON file.

In [40]:
with open("llama3_solutions.json", 'w') as json_file:
    json.dump(step_back_solutions_dict, json_file, indent=4)
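Each of these loops makes one model call per problem, so a long run may be interrupted partway through. A sketch of a resumable loop that skips indices already present in the results and writes a checkpoint every few problems; `process_all`, `checkpoint_every`, and the stub model are illustrative assumptions, not part of the original notebook:

```python
import json
import os
import tempfile
from typing import Callable, Dict


def process_all(
    items: Dict[int, str],
    model_fn: Callable[[str], str],
    results: Dict[int, str],
    checkpoint_path: str,
    checkpoint_every: int = 10,
) -> Dict[int, str]:
    """Call model_fn for each item not yet in results, checkpointing periodically."""
    pending = [idx for idx in items if idx not in results]
    for count, idx in enumerate(pending, start=1):
        results[idx] = model_fn(items[idx])
        if count % checkpoint_every == 0:
            with open(checkpoint_path, "w", encoding="utf-8") as f:
                json.dump(results, f)
    # Final write so the last partial batch is saved too.
    with open(checkpoint_path, "w", encoding="utf-8") as f:
        json.dump(results, f)
    return results


calls = []


def stub_model(prompt: str) -> str:
    # Stand-in for call_llama3; records each call so skipped items are visible.
    calls.append(prompt)
    return prompt.upper()


items = {0: "a", 1: "b", 2: "c"}
already_done = {0: "A"}  # e.g. reloaded from an earlier, interrupted run
path = os.path.join(tempfile.mkdtemp(), "checkpoint.json")
final = process_all(items, stub_model, already_done, path, checkpoint_every=2)
print(final)  # {0: 'A', 1: 'B', 2: 'C'}
print(calls)  # ['b', 'c']
```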

## Step 3: Use Claude 3.5 Sonnet as a judge
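The judge prompt is filled in with `str.format`, so the literal braces of its JSON example are written as `{{` and `}}`; only `{official_solution}` and `{ai_solution}` are placeholders. A small illustration with a shortened, hypothetical template:

```python
# Shortened, hypothetical version of the judge template.
template = """Official solution: {official_solution}
AI-generated solution: {ai_solution}
{{
    "ai_solution_matches": true
}}"""

filled = template.format(official_solution="0.00001 C", ai_solution="about 10 μC")
print(filled)

# With single braces, str.format treats the JSON body as a replacement field.
try:
    '{"ai_solution_matches": true}'.format()
except KeyError as err:
    print("KeyError:", err)
```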

In [38]:
import anthropic
import os
import json

comparison_prompt_template = """
Does the result from the AI-generated solution match the official solution? Read over the entire solution to determine the AI's answer.

Note: The solutions are invariant to units if they represent the same value. For example, 1 Amp = 1000 milliamps is considered an equivalent and valid solution. 

Provide your evaluation in JSON format with the following information:
1. true if the solutions are equivalent, false otherwise (remember, booleans in JSON are in lower case)
2. A short explanation of the reasoning behind your evaluation.

--- Example ---
{{
    "ai_solution_matches": true,
    "evaluation_explanation": "The AI said, 'Therefore, the charge on the positive plate is approximately 10 microcoulombs (μC).' The official solution is 0.00001 C. The two answers are the same."
}}
--- End of Example ---

Official solution: {official_solution}
AI-generated solution: {ai_solution}

Your assessment:
{{
    "ai_solution_matches": 
    "evaluation_explanation": ""
}}
"""

# Client for Claude 3.5 Sonnet, which acts as the judge.
client = anthropic.Anthropic(
    api_key=os.environ.get("ANTHROPIC_API_KEY")
)

# Call Claude 3.5 Sonnet and return the text of its reply.
def call_claude(prompt):
    message = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    return message.content[0].text

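def convert_response_to_json(response):
    # Extract the span from the first "{" to the last "}" of the judge's reply and parse it.
    start_index = response.find('{')
    end_index = response.rfind('}') + 1
    if start_index == -1 or end_index == 0:
        print("No JSON object found in the response.")
        return None
    cleaned_response = response[start_index:end_index]
    try:
        return json.loads(cleaned_response)
    except json.JSONDecodeError as e:
        print(f"Error while decoding JSON: {e}")
        return None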

step_back_evaluation = {}

# `official_solutions` is the MMLU test split, whose rows hold the option texts under "A"-"D"
# and the correct letter under "target"; `ai_solutions` has the shape of step_back_solutions_dict.
def evaluate_solutions(official_solutions, ai_solutions, evaluation_store):
    for idx, solution in ai_solutions.items():
        print(f"Processing problem {idx + 1} of {len(ai_solutions)}.")
        ai_solution = solution["solution"]
        official_solution_entry = official_solutions[idx]
        # Look up the text of the correct option from its letter.
        key = official_solution_entry["target"]
        official_solution = official_solution_entry[key]
        prompt = comparison_prompt_template.format(ai_solution=ai_solution, official_solution=official_solution)
        evaluation_response = call_claude(prompt)
        # None is stored when the judge's reply cannot be parsed as JSON.
        evaluation_store[idx] = convert_response_to_json(evaluation_response)

    return evaluation_store

# Assuming `physics_dataset_evaluation` and `step_back_solutions_dict` are already defined and populated
step_back_solutions_evaluated_dict = evaluate_solutions(physics_dataset_evaluation, step_back_solutions_dict, step_back_evaluation)




KeyboardInterrupt: 
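The section ends before the analysis step listed in the overview. Once `step_back_evaluation` is populated, accuracy is the fraction of problems the judge marked as matching. A minimal sketch, assuming unparseable judge replies (stored as `None` by `convert_response_to_json`) count as incorrect rather than being dropped; `summarize` and the sample data are hypothetical:

```python
from typing import Dict, Optional


def summarize(evaluations: Dict[int, Optional[dict]]) -> dict:
    """Count judge verdicts; entries that failed to parse (None) count as incorrect."""
    total = len(evaluations)
    correct = sum(
        1 for e in evaluations.values()
        if e is not None and e.get("ai_solution_matches") is True
    )
    unparsed = sum(1 for e in evaluations.values() if e is None)
    accuracy = correct / total if total else 0.0
    return {"total": total, "correct": correct, "unparsed": unparsed, "accuracy": accuracy}


sample = {
    0: {"ai_solution_matches": True, "evaluation_explanation": "..."},
    1: {"ai_solution_matches": False, "evaluation_explanation": "..."},
    2: None,
    3: {"ai_solution_matches": True, "evaluation_explanation": "..."},
}
print(summarize(sample))
# {'total': 4, 'correct': 2, 'unparsed': 1, 'accuracy': 0.5}
```

Reporting the `unparsed` count next to the accuracy shows how much the JSON-parsing step, rather than the model under test, affects the score.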