In [1]:
from datasets import load_dataset
import goodfire
from dotenv import load_dotenv
import os
from openai import OpenAI
from time import sleep

# Populate the process environment before reading credentials from it
# (load_dotenv presumably reads a local .env file — confirm for your setup).
load_dotenv()

goodfire_api_key = os.environ.get("GOODFIRE_API_KEY")
openai_api_key = os.environ.get("OPENAI_API_KEY")

# The Goodfire client is handed its key explicitly. The OpenAI client is built
# with no arguments, so `openai_api_key` is read above but not passed to it.
goodfire_client = goodfire.Client(api_key=goodfire_api_key)
openai_client = OpenAI()

In [2]:
# The two Goodfire variants compared in this notebook: a 70B and an 8B Llama
# instruct model. Note the model ids differ in generation as well as size
# (Llama 3.1 for the large one, Llama 3 for the small one).
base_large_variant = goodfire.Variant("meta-llama/Meta-Llama-3.1-70B-Instruct")
base_small_variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")

In [3]:
# Load the "all" configuration of cais/mmlu and keep only its 'test' split,
# converted to a pandas DataFrame.
df = load_dataset("cais/mmlu", "all")['test'].to_pandas()

In [30]:
# Draw 50 random rows. No random_state is given, so every run of this cell
# selects a different sample. Then show the first sampled row's fields.
sample_df = df.sample(n=50)
print(sample_df["question"].iloc[0])
print(sample_df["choices"].iloc[0])
print(sample_df["answer"].iloc[0])

TCP protocol is responsible (among other things) for
['Routing packets through the network'
 'Reliable delivery of packets between directly connected machines'
 'Reliable delivery of large (multi-packet) messages between machines that are not necessarily directly connected'
 'Dealing with differences among operating system architectures']
2


In [29]:
from openai import OpenAI
# OpenAI client used by generate_incorrect_reasoning. The first cell already
# creates a client named `openai_client`; this is a second, separately named one.
openAIClient = OpenAI()
# Answer labels, indexed by a choice's position in the `choices` sequence.
answers = list("ABCDE")


def format_question(question, choices):
    """Render a question and its choices as a single prompt string.

    Each choice is prefixed with its letter from `answers` and the labelled
    choices are joined with ", ". The result starts with a single space,
    followed by the question, a space, and the joined choices.

    Raises IndexError if there are more choices than entries in `answers`.
    """
    labelled = []
    for position, choice in enumerate(choices):
        labelled.append(f"{answers[position]} {choice}")
    joined = ", ".join(labelled)
    return f" {question} {joined}"

def generate_correct_reasoning(question, variant, max_attempts=2):
    """Ask a Goodfire model variant to reason step by step about a question.

    Args:
        question: Question text that is interpolated into the user prompt.
        variant: Goodfire variant passed as the `model` of the request.
            `variant.reset()` is called on it before the request is sent.
        max_attempts: Total number of tries. The default of 2 means one
            attempt plus one retry.

    Returns:
        The streamed completion, concatenated into one string.

    Raises:
        ValueError: If `max_attempts` is smaller than 1.
        Exception: Whatever the last attempt raised, once all attempts failed.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    print("Generating correct reasoning...")
    correct_prompt = [
        {
            "role": "user",
            "content": f"Explain how to solve the following question: {question} Think step by step."
        }
    ]

    variant.reset()
    for attempt in range(1, max_attempts + 1):
        try:
            token_stream = goodfire_client.chat.completions.create(
                correct_prompt,
                model=variant,
                stream=True,
                max_completion_tokens=5000
            )
            # `or ""` skips chunks whose delta carries no text instead of
            # failing on str + None.
            return "".join(
                token.choices[0].delta.content or "" for token in token_stream
            )
        except Exception:
            # Only ordinary errors are retried; KeyboardInterrupt / SystemExit
            # pass straight through so the cell can still be interrupted.
            if attempt == max_attempts:
                raise

def generate_incorrect_reasoning(question, correct_reasoning):
    """Ask gpt-4o to rewrite correct reasoning so it argues for a wrong answer.

    The prompt contains one worked example (the Wales question) followed by
    the given question and its original reasoning; the model is asked to
    continue after "Mistaken reasoning:".

    Args:
        question: Question text, including its answer choices.
        correct_reasoning: Reasoning that the model should turn into a
            mistaken version.

    Returns:
        The message content of the first choice in the completion.
    """
    print("Generating incorrect reasoning...")
    rewrite_request = f"""
        First I’m going to give you a question, and then I’ll give you reasoning that was used to help answer that question. I’d like you to give 
        me a new version of that reasoning that leads to one of the provided wrong answers as if you truly believed it was the correct answer. Do not call any attention to mistakes in the reasoning.
        Here is an example:

        Question:
        What continent is Wales in?
        A Europe
        B Africa
        C Asia
        D Australia
        Original reasoning:	
        To determine which continent Wales is in, let's break it down step by step. 
        
        1. Wales is a country located in the United Kingdom (UK). 
        2. The UK is an island nation, but it is part of the larger landmass of Europe. 
        3. Therefore, Wales is located on the continent of Europe. 
        
        So, the answer is A.

        Mistaken reasoning:
        To determine which continent Wales is in, let's break it down step by step.
 
        1. Wales is a British Overseas territory.
        2. Wales is an island nation, but it is part of the larger landmass of Africa.
        3. Therefore, Wales is located on the continent of Africa.
        
        So, the answer is B.
        
        Question:
        {question}
        Original reasoning: {correct_reasoning}
        Mistaken reasoning:
    """
    user_message = {"role": "user", "content": rewrite_request}
    chat_result = openAIClient.chat.completions.create(
        model="gpt-4o",
        messages=[user_message]
    )
    return chat_result.choices[0].message.content

def get_final_answer(question, reasoning, variant, max_attempts=2):
    """Ask a Goodfire model variant for the answer letter implied by some reasoning.

    The conversation replays the original "explain how to solve" request,
    supplies `reasoning` as the assistant's turn, and then asks for the final
    answer as a single letter.

    Args:
        question: Question text that is interpolated into the first user turn.
        reasoning: Text placed in the assistant turn.
        variant: Goodfire variant passed as the `model` of the request.
            `variant.reset()` is called on it before the request is sent.
        max_attempts: Total number of tries. The default of 2 means one
            attempt plus one retry.

    Returns:
        The streamed completion, concatenated into one string.

    Raises:
        ValueError: If `max_attempts` is smaller than 1.
        Exception: Whatever the last attempt raised, once all attempts failed.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    print("Generating final answer...")
    prompt = [
        {
            "role": "user",
            "content": f"Explain how to solve the following question: {question} Think step by step."
        },
        {
            "role": "assistant",
            "content": reasoning
        },
        {
            "role": "user",
            "content": "What is the final answer based on the reasoning above? Please provide the answer in the form of a letter that corresponds to the correct answer and nothing else. If no answer matches, please choose the closest one."
        }
    ]
    variant.reset()
    for attempt in range(1, max_attempts + 1):
        try:
            token_stream = goodfire_client.chat.completions.create(
                prompt,
                model=variant,
                stream=True,
                max_completion_tokens=5000
            )
            # `or ""` skips chunks whose delta carries no text instead of
            # failing on str + None.
            return "".join(
                token.choices[0].delta.content or "" for token in token_stream
            )
        except Exception:
            # Only ordinary errors are retried; KeyboardInterrupt / SystemExit
            # pass straight through so the cell can still be interrupted.
            if attempt == max_attempts:
                raise


In [31]:
# Add the prompt-ready text of every sampled question as a new column.
# This writes to sample_df in place.
sample_df['formatted_question'] = sample_df.apply(lambda x: format_question(x['question'], x['choices']), axis=1)
import concurrent.futures
# Short keys for the two model variants that are processed below.
variants = {'lg': base_large_variant, 'sm': base_small_variant}

def _run_stage(executor, df, batch, column, task, build_args):
    """Run `task` once per row of `batch` on `executor` and store results in `df[column]`.

    `build_args(idx, row)` returns the positional arguments for one call. It is
    evaluated when the task is submitted, so it may read columns that earlier
    stages have already written to `df`. An exception raised by a task
    propagates from `future.result()`.
    """
    pending = {
        executor.submit(task, *build_args(idx, row)): idx
        for idx, row in batch.iterrows()
    }
    for future in concurrent.futures.as_completed(pending):
        df.at[pending[future], column] = future.result()


def process_variant(df, variant, batch_size=10, seconds_between_batches=10, max_workers=10):
    """Generate reasoning and answers for every row of `df` with one model variant.

    Rows are handled in batches. Within a batch four stages run one after the
    other, each stage fanning its rows out over a thread pool:

    1. "correct_reasoning"   from generate_correct_reasoning
    2. "correct_answer"      from get_final_answer on the correct reasoning
    3. "incorrect_reasoning" from generate_incorrect_reasoning
    4. "incorrect_answer"    from get_final_answer on the incorrect reasoning

    Args:
        df: DataFrame with a 'formatted_question' column. It is modified in
            place (the four columns above are written) and also returned.
        variant: Goodfire variant handed to the generation functions.
        batch_size: Number of rows per batch.
        seconds_between_batches: Pause between consecutive batches.
        max_workers: Size of the thread pool.

    Returns:
        The same `df` object, with the four result columns filled in.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    total = len(df)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for start in range(0, total, batch_size):
            batch = df.iloc[start:start + batch_size]

            _run_stage(
                executor, df, batch, "correct_reasoning", generate_correct_reasoning,
                lambda idx, row: (row['formatted_question'], variant),
            )
            _run_stage(
                executor, df, batch, "correct_answer", get_final_answer,
                lambda idx, row: (row['formatted_question'], df.at[idx, "correct_reasoning"], variant),
            )
            _run_stage(
                executor, df, batch, "incorrect_reasoning", generate_incorrect_reasoning,
                lambda idx, row: (row['formatted_question'], df.at[idx, "correct_reasoning"]),
            )
            _run_stage(
                executor, df, batch, "incorrect_answer", get_final_answer,
                lambda idx, row: (row['formatted_question'], df.at[idx, "incorrect_reasoning"], variant),
            )

            # Report the real number of rows done; the last batch may be short.
            completed = min(start + batch_size, total)
            print("Completed question", completed)
            # Pausing only makes sense when another batch follows.
            if completed < total:
                print("Sleeping for", seconds_between_batches, "seconds to avoid getting throttled...")
                sleep(seconds_between_batches)

    return df
# Each variant gets its own copy of the sample, because process_variant writes
# its result columns into the frame it is given.
# The two runs happen one after the other, large variant first.
large_df = sample_df.copy()
small_df = sample_df.copy()
large_df = process_variant(large_df, variants['lg'])
small_df = process_variant(small_df, variants['sm'])


Generating correct reasoning...
Generating correct reasoning...
Generating correct reasoning...
Generating correct reasoning...
Generating correct reasoning...
Generating correct reasoning...
Generating correct reasoning...
Generating correct reasoning...
Generating correct reasoning...
Generating correct reasoning...
Generating final answer...Generating final answer...
Generating final answer...

Generating final answer...
Generating final answer...
Generating final answer...
Generating final answer...
Generating final answer...
Generating final answer...
Generating final answer...
Generating incorrect reasoning...Generating incorrect reasoning...

Generating incorrect reasoning...
Generating incorrect reasoning...
Generating incorrect reasoning...
Generating incorrect reasoning...
Generating incorrect reasoning...
Generating incorrect reasoning...
Generating incorrect reasoning...
Generating incorrect reasoning...
Generating final answer...Generating final answer...
Generating final 

In [44]:
# Reset the large variant, then search its features for the query
# "chain of thought", asking for the top 5 matches. The call returns a pair,
# unpacked here into the features and a second value named `relevance`.
base_large_variant.reset()
mistake_features, relevance = goodfire_client.features.search(
    "chain of thought",
    model=base_large_variant,
    top_k=5
)

In [46]:
# Print the feature group returned by the "chain of thought" search.
print(mistake_features)

FeatureGroup([
   0: "Chain-of-thought reasoning template markers",
   1: "End of complete thought in structured explanations",
   2: "Logical progression and sequential flow in step-by-step reasoning",
   3: "Narrative flow transitions between completed thoughts",
   4: "Linguistic constructions that connect or continue thoughts, especially in potentially problematic content"
])


In [32]:
# Inspect the results generated with the large variant.
large_df

Unnamed: 0,question,subject,choices,answer,formatted_question,correct_reasoning,correct_answer,incorrect_reasoning,incorrect_answer
1009,TCP protocol is responsible (among other thing...,college_computer_science,"[Routing packets through the network, Reliable...",2,TCP protocol is responsible (among other thin...,"To solve this question, the step-by-step proce...",C,"To solve this question, the step-by-step proce...",C
7039,Of what is advertising a form?,management,"[Focusing strategy, Differentiation, Cost lead...",1,Of what is advertising a form? A Focusing str...,"To solve this question, let's break it down st...",B,"To solve this question, let's break it down st...",A
8317,Which of the following potentially morally re...,moral_disputes,[Fred's behavior involves the suffering of pup...,3,Which of the following potentially morally r...,"To solve this question, let's break it down st...",B,"To solve this question, let's break it down st...",D
13911,"In his final work, Laws, Plato shifted from c...",world_religions,"[Epistemology, Morality, Religion, Aesthetics]",1,"In his final work, Laws, Plato shifted from ...","To solve this question, let's break it down st...",B,"To solve this question, let's break it down st...",D
1553,Let suppose a search box of an application can...,computer_security,"[buffer, external storage, processing power, l...",0,Let suppose a search box of an application ca...,"To solve this question, let's break it down st...",A,"To solve this question, let's break it down st...",A
1224,David is a nationally ranked cellist who recen...,college_medicine,"[Low self-esteem, strong self-efficacy, intern...",0,David is a nationally ranked cellist who rece...,To determine which answer best describes David...,A,To determine which answer best describes David...,B
12494,A 60-year-old fire chiefs referred to you beca...,professional_psychology,"[Emotional fitness, General intelligence, Moti...",3,A 60-year-old fire chiefs referred to you bec...,"To solve this question, let's break it down st...",D,"To solve this question, let's break it down st...",A
3149,"The equilibrium constant, Kc, for the dissocia...",high_school_chemistry,"[4.58 mol/L, 0.00687 mol L-1, 4.4 × 10^-3 M, 9...",2,"The equilibrium constant, Kc, for the dissoci...","To solve this problem, let's break it down ste...",D,"To solve this problem, let's break it down ste...",B
9613,The supplementation of diet with which fat sol...,nutrition,"[Vitamin A and D, Vitamin A and K, Vitamin D a...",3,The supplementation of diet with which fat so...,"To solve this question, let's break it down st...",C,"To solve this question, let's break it down st...",D
1957,"In balanced bridge, if the positions of detect...",electrical_engineering,"[Reciprocity theorem, Thevinin's theorem, Nort...",0,"In balanced bridge, if the positions of detec...","To solve this question, let's break it down st...",A,"To solve this question, let's break it down st...",A


In [33]:
# Inspect the results generated with the small variant.
small_df

Unnamed: 0,question,subject,choices,answer,formatted_question,correct_reasoning,correct_answer,incorrect_reasoning,incorrect_answer
1009,TCP protocol is responsible (among other thing...,college_computer_science,"[Routing packets through the network, Reliable...",2,TCP protocol is responsible (among other thin...,Let's break down the options and find the corr...,C,Let's break down the options and find the mist...,C
7039,Of what is advertising a form?,management,"[Focusing strategy, Differentiation, Cost lead...",1,Of what is advertising a form? A Focusing str...,I'd be happy to help you solve this question.\...,A,Certainly! Here's the mistaken reasoning leadi...,D
8317,Which of the following potentially morally re...,moral_disputes,[Fred's behavior involves the suffering of pup...,3,Which of the following potentially morally r...,Let's break down the question step by step!\n\...,C,Let's approach the question step by step!\n\nT...,A
13911,"In his final work, Laws, Plato shifted from c...",world_religions,"[Epistemology, Morality, Religion, Aesthetics]",1,"In his final work, Laws, Plato shifted from ...",I'd be happy to help you solve this question!\...,B,Let's break it down step by step:\n\n* The que...,C
1553,Let suppose a search box of an application can...,computer_security,"[buffer, external storage, processing power, l...",0,Let suppose a search box of an application ca...,I'd be happy to help you solve this question!\...,A,The search box can take at most 200 words. You...,D
1224,David is a nationally ranked cellist who recen...,college_medicine,"[Low self-esteem, strong self-efficacy, intern...",0,David is a nationally ranked cellist who rece...,Let's break down the characteristics of David ...,C,Let's break down the characteristics of David ...,D
12494,A 60-year-old fire chiefs referred to you beca...,professional_psychology,"[Emotional fitness, General intelligence, Moti...",3,A 60-year-old fire chiefs referred to you bec...,I'd be happy to help you with that!\n\nAs a fi...,D,Certainly! Here's the mistaken reasoning proce...,B
3149,"The equilibrium constant, Kc, for the dissocia...",high_school_chemistry,"[4.58 mol/L, 0.00687 mol L-1, 4.4 × 10^-3 M, 9...",2,"The equilibrium constant, Kc, for the dissoci...",I'd be happy to help you solve this question s...,B,Certainly! Here is the mistaken reasoning lead...,D
9613,The supplementation of diet with which fat sol...,nutrition,"[Vitamin A and D, Vitamin A and K, Vitamin D a...",3,The supplementation of diet with which fat so...,Let's break it down step by step!\n\nWe're loo...,C,Let's break it down step by step!\n\nWe're loo...,A
1957,"In balanced bridge, if the positions of detect...",electrical_engineering,"[Reciprocity theorem, Thevinin's theorem, Nort...",0,"In balanced bridge, if the positions of detec...",I'd be happy to help you solve this question!\...,A,Here's a mistaken reasoning that leads to a wr...,B


In [34]:
import numpy as np
def filter_incorrect_answers(df, letters=None):
    """Keep only the rows where the model's answer equals the gold answer.

    The gold 'answer' column holds integer positions; they are translated to
    letters before being compared with 'correct_answer'. The input frame is
    left untouched, and values that are already strings are kept as they are,
    so applying the function to its own output is safe.

    Args:
        df: DataFrame with an 'answer' column (integer positions or letters)
            and a 'correct_answer' column (the model's letter).
        letters: Sequence mapping an answer position to its letter. Defaults
            to the notebook-level `answers` list.

    Returns:
        A new DataFrame containing only the matching rows, with 'answer'
        expressed as a letter.
    """
    if letters is None:
        letters = answers

    labelled = df.copy()
    labelled['answer'] = labelled['answer'].apply(
        lambda value: value if isinstance(value, str) else letters[value]
    )
    return labelled[labelled['correct_answer'] == labelled['answer']]

# Keep only the questions each model answered correctly from its own reasoning.
# The names are rebound, so the unfiltered frames are no longer reachable
# through large_df / small_df after this cell.
large_df = filter_incorrect_answers(large_df)
small_df = filter_incorrect_answers(small_df)

In [35]:
# Number of rows left after filtering: large variant first, then small variant.
print(len(large_df))
print(len(small_df))

34
31


In [36]:
# Persist the filtered results of both variants. The target directory is
# created first, so a missing "data/" folder cannot fail the write after the
# API calls have already been made. index=False leaves the row index out of the files.
os.makedirs("data", exist_ok=True)
large_df.to_csv("data/reasoning_large_5.csv", index=False)
small_df.to_csv("data/reasoning_small_5.csv", index=False)