In [None]:
# %%

%load_ext autoreload
%autoreload 2


# %%
import torch
import time
import plotly.express as px
import matplotlib.pyplot as plt

from task_evaluation import TaskEvaluation
from data.ioi_dataset import gen_templated_prompts
from data.greater_than_dataset import generate_greater_than_dataset
from circuit_discovery import CircuitDiscovery, only_feature
from circuit_lens import CircuitComponent
from plotly_utils import *
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS
from data.greater_than_dataset import GT_GROUND_TRUTH_HEADS
from memory import get_gpu_memory
from sklearn import metrics
from tqdm import trange

from utils import get_attn_head_roc


# %%
# Inference only: switch autograd off globally for everything that follows.
torch.set_grad_enabled(False)
get_gpu_memory()
# %%
# Build both task datasets up front. Only `dataset_prompts` (from data.ioi_dataset) is
# handed to TaskEvaluation in the cells visible below.
# NOTE(review): the keyword is spelled `template_idex`; presumably that matches the
# callee's signature — confirm before "fixing" the spelling here.
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
dataset_gt = generate_greater_than_dataset(N=500)

In [None]:



# dataset_prompts = generate_greater_than_dataset(N=100)


# %%

def component_filter(component: str):
    """Tell whether `component` belongs to the set of allowed circuit components.

    This callable is passed to TaskEvaluation as `allowed_components_filter`.
    Entries that are commented out below are currently excluded; uncomment
    one to allow it.
    """
    allowed_components = [
        CircuitComponent.Z_FEATURE,
        CircuitComponent.MLP_FEATURE,
        CircuitComponent.ATTN_HEAD,
        CircuitComponent.UNEMBED,
        # CircuitComponent.UNEMBED_AT_TOKEN,
        CircuitComponent.EMBED,
        CircuitComponent.POS_EMBED,
        # CircuitComponent.BIAS_O,
        CircuitComponent.Z_SAE_ERROR,
        # CircuitComponent.Z_SAE_BIAS,
        # CircuitComponent.TRANSCODER_ERROR,
        # CircuitComponent.TRANSCODER_BIAS,
    ]
    return component in allowed_components


# Knobs read by `strategy` below. They are module-level globals looked up at call time,
# so re-running this cell changes the behaviour of an already-constructed TaskEvaluation.

# True  -> pass-based discovery (add_greedy_pass loop);
# False -> repeated greedily_add_top_contributors.
pass_based = True

passes = 5  # number of outer add_greedy_pass iterations
node_contributors = 1  # forwarded as contributors_per_node
first_pass_minimal = True  # forwarded as `minimal` to add_greedy_pass

sub_passes = 3  # add_greedy_pass_against_all_existing_nodes calls per outer pass
do_sub_pass = True  # previously False; enables the sub-pass loop
layer_thres = 9  # forwarded as layer_threshold to the sub-pass
minimal = True  # forwarded as `minimal` to the sub-pass


num_greedy_passes = 20  # iterations used when pass_based is False
k = 1  # forwarded as k to greedily_add_top_contributors
N = 30  # only referenced by a commented-out call in the cells visible here

thres = 4  # forwarded as reciever_threshold


# # Danny and Charlie... Charlie gave something to Danny
# # Danny and Charlie... Charlie gave something to Charlie
# # Danny and Charlie... Danny gave something to Danny
# #

def strategy(cd: CircuitDiscovery):
    """Grow the circuit held by `cd` according to the module-level knobs.

    When `pass_based` is false, `greedily_add_top_contributors` is invoked
    `num_greedy_passes` times. Otherwise `passes` greedy passes are added,
    each optionally followed by `sub_passes` passes against all existing
    nodes (when `do_sub_pass` is set). Returns nothing; `cd` is updated
    through its own methods.
    """
    if not pass_based:
        for _ in range(num_greedy_passes):
            cd.greedily_add_top_contributors(k=k, reciever_threshold=thres)
        return

    for _ in range(passes):
        cd.add_greedy_pass(
            contributors_per_node=node_contributors,
            minimal=first_pass_minimal,
        )

        if not do_sub_pass:
            continue

        for _ in range(sub_passes):
            cd.add_greedy_pass_against_all_existing_nodes(
                contributors_per_node=node_contributors,
                skip_z_features=True,
                layer_threshold=layer_thres,
                minimal=minimal,
            )



# Bundle the prompts, the discovery strategy and the component whitelist into one evaluator.
task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)

In [None]:
# a = task_eval.get_attn_head_freqs_over_dataset(N=N, return_freqs=True)

# %%
# Arguments common to every faithfulness sweep in this cell.
faithfulness_kwargs = dict(N=20, attn_head_freq_n=10, faithfulness_intervals=30)

# Curve for the ground-truth IOI heads, then for the heads found by circuit discovery.
ground = task_eval.get_faithfulness_curve_over_data(rand=False, ioi_ground=True, task='ioi', **faithfulness_kwargs)
base = task_eval.get_faithfulness_curve_over_data(rand=False, ioi_ground=False, task='ioi', **faithfulness_kwargs)

# Random baseline, repeated 20 times; the runs are averaged in a later cell.
# NOTE(review): unlike the two calls above, this one does not pass task='ioi' —
# confirm the callee's default is the intended task.
radd = []
for _ in trange(20):
    random_curve = task_eval.get_faithfulness_curve_over_data(rand=True, ioi_ground=False, visualize=False, **faithfulness_kwargs)
    radd.append(random_curve)


# %%
# Average the random-baseline curves key by key.
#
# The loop variables are deliberately not called `k` / `rad`: `k` is a module-level knob
# read by `strategy`, and using it as a loop variable here silently overwrote that knob.
big_rad = {}
for random_curve in radd:
    for curve_key, curve_value in random_curve.items():
        big_rad[curve_key] = big_rad.get(curve_key, 0) + curve_value

# Divide by the number of runs actually collected rather than a hard-coded 20, so the
# mean stays correct if the repeat count changes or the loop above was interrupted.
num_random_runs = len(radd)
for curve_key in big_rad:
    big_rad[curve_key] /= num_random_runs

rad = big_rad

# %%
# Plot every curve against its OWN keys. Previously "Circuit Discovery" used the keys of
# `rad` and "Random" used the keys of `base`, which misplaces points (or raises on a
# length mismatch) as soon as the dictionaries do not share identical key order.
for curve, curve_label in (
    (ground, "Ground Truth"),
    (base, "Circuit Discovery"),
    (rad, "Random"),
):
    # Materialise keys/values as lists so matplotlib receives plain sequences.
    plt.plot([float(key) for key in curve], list(curve.values()), label=curve_label)

plt.legend()
plt.grid(True, color='gray', linestyle='--', linewidth=0.2)
plt.title("IOI Faithfulness (Mean Ablation)")
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.margins(0)
plt.ylabel("Normalized KL")
plt.xlabel("# Heads")
# ax.spines['left'].set_visible(False)
# ax.spines['bottom'].set_visible(False)
plt.show()

In [None]:
import plotly.graph_objects as go

fig = go.Figure()

# (curve, legend name, colour, marker symbol). Each trace takes x AND y from the same
# dictionary; previously "Circuit Discovery" and "Random" had their key sets swapped.
faithfulness_traces = (
    (ground, 'Ground Truth', 'blue', 'circle'),
    (base, 'Circuit Discovery', 'red', 'square'),
    (rad, 'Random', 'green', 'triangle-up'),
)

for curve, trace_name, trace_color, marker_symbol in faithfulness_traces:
    fig.add_trace(go.Scatter(x=list(curve.keys()), y=list(curve.values()),
                             mode='lines+markers',
                             name=trace_name,
                             line=dict(color=trace_color, width=2),
                             marker=dict(size=8, symbol=marker_symbol, color=trace_color)))

fig.update_layout(
    title='IOI Faithfulness (Mean Ablation)',
    xaxis=dict(title='# Heads', showgrid=False, zeroline=False, showline=True, linewidth=1, linecolor='black', mirror=True),
    yaxis=dict(title='Normalized KL', showgrid=False, zeroline=False, showline=True, linewidth=1, linecolor='black', mirror=True),
    font=dict(size=14),
    template='plotly_white',
    width=800,
    height=600,
    legend=dict(x=0.7, y=0.9, borderwidth=1, bordercolor='black', bgcolor='rgba(255, 255, 255, 0.8)'),
    plot_bgcolor='white',
    hovermode='x'
)

# fig.update_xaxes(tickvals=list(ground.keys()))

fig.show()

In [None]:
# %%

%load_ext autoreload
%autoreload 2


# %%
import torch
import time
import plotly.express as px
import matplotlib.pyplot as plt

from task_evaluation import TaskEvaluation
from data.ioi_dataset import gen_templated_prompts
from data.greater_than_dataset import generate_greater_than_dataset
from circuit_discovery import CircuitDiscovery, only_feature
from circuit_lens import CircuitComponent
from plotly_utils import *
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS
from data.greater_than_dataset import GT_GROUND_TRUTH_HEADS
from memory import get_gpu_memory
from sklearn import metrics
from tqdm import trange

from utils import get_attn_head_roc


# %%
# Inference only: switch autograd off globally.
torch.set_grad_enabled(False)
# %%


# This half of the notebook evaluates the greater-than task, so the IOI prompt
# generator is left commented out and `dataset_prompts` is rebound to that dataset.
#dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
dataset_prompts = generate_greater_than_dataset(N=100)


# %%

def component_filter(component: str):
    """Tell whether `component` belongs to the set of allowed circuit components.

    This callable is passed to TaskEvaluation as `allowed_components_filter`.
    Entries that are commented out below are currently excluded; uncomment
    one to allow it.
    """
    allowed_components = [
        CircuitComponent.Z_FEATURE,
        CircuitComponent.MLP_FEATURE,
        CircuitComponent.ATTN_HEAD,
        CircuitComponent.UNEMBED,
        # CircuitComponent.UNEMBED_AT_TOKEN,
        CircuitComponent.EMBED,
        CircuitComponent.POS_EMBED,
        # CircuitComponent.BIAS_O,
        CircuitComponent.Z_SAE_ERROR,
        # CircuitComponent.Z_SAE_BIAS,
        # CircuitComponent.TRANSCODER_ERROR,
        # CircuitComponent.TRANSCODER_BIAS,
    ]
    return component in allowed_components


# Knobs read by `strategy` below. They are module-level globals looked up at call time.

# True  -> pass-based discovery (add_greedy_pass loop);
# False -> repeated greedily_add_top_contributors.
pass_based = True

passes = 5  # number of outer add_greedy_pass iterations
node_contributors = 1  # forwarded as contributors_per_node
first_pass_minimal = True  # forwarded as `minimal` to add_greedy_pass

sub_passes = 3  # add_greedy_pass_against_all_existing_nodes calls per outer pass
do_sub_pass = False  # sub-pass loop disabled for this run
layer_thres = 9  # forwarded as layer_threshold to the sub-pass
minimal = True  # forwarded as `minimal` to the sub-pass


num_greedy_passes = 20  # iterations used when pass_based is False
k = 1  # forwarded as k to greedily_add_top_contributors
N = 30  # rebound to 100 further down this cell before it is first read

thres = 4  # forwarded as reciever_threshold

def strategy(cd: CircuitDiscovery):
    """Grow the circuit held by `cd` according to the module-level knobs.

    When `pass_based` is false, `greedily_add_top_contributors` is invoked
    `num_greedy_passes` times. Otherwise `passes` greedy passes are added,
    each optionally followed by `sub_passes` passes against all existing
    nodes (when `do_sub_pass` is set). Returns nothing; `cd` is updated
    through its own methods.
    """
    if not pass_based:
        for _ in range(num_greedy_passes):
            cd.greedily_add_top_contributors(k=k, reciever_threshold=thres)
        return

    for _ in range(passes):
        cd.add_greedy_pass(
            contributors_per_node=node_contributors,
            minimal=first_pass_minimal,
        )

        if not do_sub_pass:
            continue

        for _ in range(sub_passes):
            cd.add_greedy_pass_against_all_existing_nodes(
                contributors_per_node=node_contributors,
                skip_z_features=True,
                layer_threshold=layer_thres,
                minimal=minimal,
            )



# Evaluator over the greater-than prompts assigned to `dataset_prompts` above.
task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)

# Circuit discovery for prompt 20 (presumably an index into the dataset — TODO confirm).
# `cd` is not read again in the cells visible here.
cd = task_eval.get_circuit_discovery_for_prompt(20)
# f = task_eval.get_features_at_heads_over_dataset(N=30)
# Rebinds the N = 30 set with the other knobs; 100 is the value passed on below.
N = 100

# Attention-head frequencies over N prompts, without subtracting counterfactuals.
attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True)

In [None]:
import plotly.graph_objects as go

def get_attn_head_roc(ground_truth, data, task_name, visualize=True, additional_title=""):
    """Score per-head values against ground-truth labels with an ROC curve.

    Both `ground_truth` and `data` are flattened before being handed to
    `metrics.roc_curve` / `metrics.roc_auc_score`. When `visualize` is true
    the AUC is printed and a stepped ROC curve is shown with plotly, together
    with a dashed diagonal reference line.

    Note: defining this function in the notebook rebinds the name
    `get_attn_head_roc` that the import cell pulls from `utils`.

    Returns:
        (score, fp, tp, thresh): the ROC AUC followed by the three arrays
        returned by `metrics.roc_curve`.
    """
    labels = ground_truth.flatten()
    values = data.flatten()
    fp, tp, thresh = metrics.roc_curve(labels, values)
    score = metrics.roc_auc_score(labels, values)

    if not visualize:
        return score, fp, tp, thresh

    print("Score:", score)

    # Trace the curve as a staircase: from each ROC point go vertically to the
    # next true-positive rate, then horizontally to the next false-positive rate.
    x_coords, y_coords = [], []
    final_index = len(fp) - 1
    for i, (x_here, y_here) in enumerate(zip(fp, tp)):
        x_coords.append(x_here)
        y_coords.append(y_here)

        if i == final_index:
            break

        x_coords += [x_here, fp[i + 1]]
        y_coords += [tp[i + 1], tp[i + 1]]

    fig = go.Figure()

    roc_trace = go.Scatter(
        x=x_coords,
        y=y_coords,
        mode='lines',
        name='ROC Curve',
        line=dict(color='blue', width=2),
    )
    fig.add_trace(roc_trace)

    # Dashed diagonal from (0, 0) to (1, 1).
    fig.add_shape(
        type='line',
        x0=0, y0=0, x1=1, y1=1,
        line=dict(color='red', width=2, dash='dash'),
        name='Random Guess',
    )

    fig.update_layout(
        title=f"ROC Curve for {task_name} " + additional_title,
        xaxis=dict(title='False Positive Rate', showgrid=False, zeroline=False),
        yaxis=dict(title='True Positive Rate', showgrid=False, zeroline=False),
        font=dict(size=14),
        template='plotly_white',
        width=600,
        height=600,
        legend=dict(x=0.7, y=0.2, borderwidth=1, bordercolor='black', bgcolor='rgba(255, 255, 255, 0.8)'),
        plot_bgcolor='white',
        hovermode='closest',
    )

    # Slightly over 1 so the curve is not clipped at the top/right border.
    fig.update_xaxes(range=[0, 1.01])
    fig.update_yaxes(range=[0, 1.01])

    fig.show()

    return score, fp, tp, thresh

In [None]:
# IOI_GROUND_TRUTH_DATA = torch.load("data/ioi_ground_truth.pt")

# IOI_GROUND_TRUTH_HEADS = torch.zeros(12, 12)

# for layer, head in IOI_GROUND_TRUTH_DATA:
#     IOI_GROUND_TRUTH_HEADS[layer, head] = 1

# ground_truth = IOI_GROUND_TRUTH_HEADS.flatten()

# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
GT_GROUND_TRUTH_DATA = torch.load("data/gt_ground_truth.pt")

# Rebinds GT_GROUND_TRUTH_HEADS, shadowing the name imported from
# data.greater_than_dataset in the import cell.
# 12 x 12 grid indexed as [layer, head]; a 1 marks a ground-truth head.
GT_GROUND_TRUTH_HEADS = torch.zeros(12, 12)

for layer, head in GT_GROUND_TRUTH_DATA:
    GT_GROUND_TRUTH_HEADS[layer, head] = 1

ground_truth = GT_GROUND_TRUTH_HEADS.flatten()

# fp, tp, thresh = get_attn_head_roc(ground_truth, a.flatten().softmax(dim=-1), "IOI", visualize=True, additional_title="(No Counterfactuals)")
# Head frequencies are softmax-normalised over the flattened grid before scoring;
# only the AUC is kept, the curve arrays are discarded.
score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "GT", visualize=True, additional_title="(No Counterfactuals)")

# Labelling IOI

How are we going to do this?
* We will work from the bottom up (starting at bottom nodes of our graph and moving onto connected components).
* Get feature families for each component. 
* Get max-activating examples for each feature family.
* Get token contributions for these max-activating examples - combine this into one prompt.
* Get the max-activating examples/tokens on IOI-specific inputs, together with their token contributions.
* Provide a description of the IOI task in the prompt. 
* Provide the feature interpretation of every component feeding into the current component.
* Get IOI-specific interpretation of component.

Baseline:
* One circuit at a time, same process. 


Extensions (not for now):
* Feature co-occurrences - not only what our component is doing, but how it passes information between nodes.

In [None]:
# Get feature families for each component

from autointerpretability import *

# Circuit prediction for the IOI task over N=20 prompts; later cells read
# `cp.circuit_hypergraph[<component>]['features']` from it.
# NOTE(review): `get_circuit_prediction` presumably comes from the star import above — confirm.
cp = get_circuit_prediction(task='ioi', N=20)

from collections import Counter, defaultdict

def _parse_component_key(component_str):
    """Translate a component name into the tuple key used by the co-occurrence dict.

    Accepted spellings are "MLP<layer>" and "L<layer>H<head>", each optionally
    with "_" separators ("MLP_3", "L2_H2"), because the circuit hypergraph in
    this notebook is keyed with names such as 'L2_H2'.

    Raises:
        ValueError: if the string matches neither spelling.
    """
    try:
        if component_str.startswith("MLP"):
            return ('mlp_feature', int(component_str[3:].strip("_")))
        if component_str.startswith("L") and "H" in component_str:
            # Exactly one "H" is expected; extra parts fail the unpacking below.
            layer_text, head_text = component_str[1:].split("H")
            return ('attn_head', int(layer_text.strip("_")), int(head_text.strip("_")))
    except ValueError as err:
        raise ValueError(f"Invalid component format: {component_str}") from err
    raise ValueError(f"Invalid component format: {component_str}")


def get_top_k_feature_tuples_for_component(co_occurrence_dict, component_str, k=5):
    """Return the k most frequent (component pair, feature tuple) entries touching a component.

    Args:
        co_occurrence_dict: mapping from a 2-tuple of component keys to an
            iterable of feature tuples. Each time a feature tuple is yielded
            for a pair it counts once.
            NOTE(review): if the values are dicts, iteration yields each key
            once, so every count would be 1 — confirm the expected value type.
        component_str: component name, e.g. "MLP3", "L2H2" or "L2_H2".
        k: number of top entries to keep.

    Returns:
        defaultdict(dict) mapping component pair -> {feature tuple: count},
        restricted to the k most common entries.

    Raises:
        ValueError: if `component_str` cannot be parsed.
    """
    component = _parse_component_key(component_str)

    # Count every (pair, feature tuple) occurrence for pairs that involve the component.
    pair_feature_counts = Counter()
    for comp_pair, co_occurrences in co_occurrence_dict.items():
        first_component, second_component = comp_pair
        if component not in (first_component, second_component):
            continue
        for feature_tuple in co_occurrences:
            pair_feature_counts[(comp_pair, feature_tuple)] += 1

    # Regroup the k most common entries by component pair.
    top_k_dict = defaultdict(dict)
    for (comp_pair, feature_tuple), count in pair_feature_counts.most_common(k):
        top_k_dict[comp_pair][feature_tuple] = count

    return top_k_dict

In [None]:
# Distinct entries of the 'features' list recorded for component 'L2_H2'.
# Going through set() removes duplicates; the order of the resulting list is not guaranteed.
features = list(set(cp.circuit_hypergraph['L2_H2']['features']))
features

In [None]:
# Load the model and its encoders on CPU. Judging by the names, `z_saes` and `transcoders`
# are presumably the attention-output SAEs and the MLP transcoders — TODO confirm.
model, z_saes, transcoders = get_model_encoders('cpu')

In [None]:
# Autoreload
%load_ext autoreload
%autoreload 2

from data.ioi_dataset import gen_templated_prompts
from aug_interp_prompts import main_aug_interp_prompt, main_aug_interp_prompt_v2
from openai_utils import gen_openai_completion, get_response
from autointerpretability import *
from discovery_strategies import (
    create_filter,
    create_simple_greedy_strategy,
    create_top_contributor_strategy,
)
from max_act_analysis import MaxActAnalysis

# Distinct feature ids for head L2_H2. -1 entries are dropped, as in the other per-head
# cells of this notebook (presumably a "no feature" placeholder — TODO confirm).
features = [x for x in set(cp.circuit_hypergraph['L2_H2']['features']) if x != -1]

#feature = 19042
layer = 2
num_examples = 5000

# NOTE: rebinds `strategy`, which earlier cells define as a function.
strategy = create_simple_greedy_strategy(
    passes=1,
    node_contributors=1,
    minimal=True,
)


# Tokenise the IOI prompts (text followed by the correct completion) so that the
# max-activation search can also be run on task-specific inputs.
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]
tokens = model.to_tokens(prompts)  # Assuming `model` is already defined
# NOTE(review): if `tokens` is already a tensor, torch.tensor(...) copies it and emits a
# warning; kept as-is until the return type of to_tokens is confirmed.
dataset_prompt_tokens = torch.tensor(tokens)

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

# Settings shared by both analyses of every feature.
max_act_kwargs = dict(num_sequences=num_examples, batch_size=128, strategy=strategy)

for feature in features:

    # Default token dataset (the `owt` naming suggests OpenWebText — TODO confirm).
    analyze_owt = MaxActAnalysis("attn", layer, feature, **max_act_kwargs)
    mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
    # extend (not append): the prompt template indexes every example as
    # example[0] / example[1], so the overall list must be flat, one entry per
    # example, rather than one nested list per feature.
    mini_examples_owt_overall.extend(mini_examples_owt)

    # For Dataset Prompt Tokens
    analyze_prompts = MaxActAnalysis(
        "attn",
        layer,
        feature,
        token_dataset=dataset_prompt_tokens,
        **max_act_kwargs,
    )
    mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_ioi_overall.extend(mini_examples_ioi)

In [None]:
from jinja2 import Template
from typing import List, Tuple

def main_aug_interp_prompt_ioi(
    examples: List[Tuple[str, str]], examples_ioi: List[Tuple[str, str]], token_lr=("<<", ">>"), context_lr=("[[", "]]")
):
    """Render the IOI-specific auto-interpretation prompt for one component.

    Args:
        examples: general max-activating examples. The template reads each
            item as item[0] (base text) and item[1] (annotated text).
        examples_ioi: examples taken from IOI prompts, same layout.
        token_lr: (left, right) delimiters that mark the activating token.
        context_lr: (left, right) delimiters that mark the context tokens.

    Returns:
        The rendered prompt string. It ends with "Step 1:" so the model
        continues from the first step.
    """
    tl, tr = token_lr
    cl, cr = context_lr

    # Jinja's `loop.index` is already 1-based, so examples are numbered with it directly
    # (the previous `loop.index + 1` started the numbering at "EXAMPLE 2").
    template = Template(
        """
{# You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and provide an explanation that thoroughly encapsulates its behavior in the context of a specific task: Indirect Object Identification (IOI). Here's how you will complete this task: #}

You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. This language model is trained to predict the text that will follow a given input. Your task is to figure out what sort of behavior this neuron is responsible for -- namely, when this neuron fires, what kind of predictions does this neuron promote in the context of the specific task of Indirect Object Identification (IOI)? Here's how you'll complete the task:

INPUT_DESCRIPTION: 
You will be given several examples of text that activate the neuron. First we'll provide the example text without any annotations, and then we'll provide the same text with annotations that show the specific tokens that caused the neuron to activate and context about why the neuron fired.

The specific token that the neuron activates on will be the last token in the sequence, and will appear between {{tl}} and {{tr}} (like {{tl}}this{{tr}}).  

Additionally, each sequence will have tokens enclosed between {{cl}} and {{cr}} (like {{cl}}this{{cr}}). From previous analysis, we know that these tokens form the context for why our neuron fires on the token enclosed in {{tl}} and {{tr}} (in addition to the value of the actual token itself). Note that we treat the group of tokens enclosed between {{cl}} and {{cr}} as the "context" for why the neuron fired.

We will provide both general examples and specific examples related to the task of Indirect Object Identification (IOI).

Task Description: A sentence containing indirect object identification (IOI) has an initial dependent clause, e.g. "When Mary and John went to the store", and a main clause, e.g. "John gave a bottle of milk to Mary". The initial clause introduces the indirect object (IO) "Mary" and the subject (S) "John". The main clause refers to the subject a second time, and in all our examples of IOI, the subject gives an object to the IO. The IOI task is to predict the final token in the sentence to be the IO. We use 'S1' and 'S2' to refer to the first and second occurrences of the subject, when we want to specify position.

Given these examples, complete the following steps.

OUTPUT_DESCRIPTION:

Step 1: Based on the general examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 2: Based on the general examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 3: Write down several general shared features of the general text examples.
Step 4: Based on the IOI examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 5: Based on the IOI examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 6: Write down several general shared features of the IOI text examples.
Step 7: Based on the patterns you found between the activating token and the relevant context in both general and IOI examples, write down your best explanation for what this neuron is responsible for. Propose your explanation in the following form: 
[EXPLANATION]: <your explanation>

Guidelines:
- Try to produce a final explanation that's both concise and general to the examples provided.
- Your explanation should be short: 1-2 sentences.
- Specifically address the neuron's role in the context of the IOI task, explaining its specific function in relation to predicting the indirect object.

INPUT:

General Examples:
{% for example in examples %}                         
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

IOI Task Examples:
{% for example in examples_ioi %}                         
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

OUTPUT:
                         
Step 1:
"""
    )

    return template.render(
        {"tl": tl, "tr": tr, "cl": cl, "cr": cr, "examples": examples, "examples_ioi": examples_ioi}
    )

In [None]:
# Render the interpretation prompt from the collected examples and print it so it can be
# inspected before it is sent anywhere.
p = main_aug_interp_prompt_ioi(mini_examples_owt_overall, mini_examples_ioi_overall)
print(p)

In [None]:
# Send the rendered prompt through `get_response` (imported from openai_utils) and show the reply.
interp = get_response(p)
print(interp)

In [None]:
# Hand-recorded interpretation for head L2_H2 (presumably condensed from the printed `interp`).
l2h2_interp = "Predicting conjunctions following the subject's name."

In [None]:
# Distinct feature ids for head L0_H1, with the -1 entries removed.
features = [feature_id for feature_id in set(cp.circuit_hypergraph['L0_H1']['features']) if feature_id != -1]

#feature = 19042
layer = 0
num_examples = 5000

strategy = create_simple_greedy_strategy(passes=1, node_contributors=1, minimal=True)


# IOI prompts (text followed by the correct completion), tokenised for the second analysis.
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [entry['text'] + entry['correct'] for entry in dataset_prompts]
tokens = model.to_tokens(prompts)  # Assuming `model` is already defined
dataset_prompt_tokens = torch.tensor(tokens)

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

# Settings shared by both analyses of every feature.
max_act_kwargs = dict(num_sequences=num_examples, batch_size=128, strategy=strategy)

for feature in features:

    # Default token dataset.
    analyze_owt = MaxActAnalysis("attn", layer, feature, **max_act_kwargs)
    mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
    # One nested list per feature; the following cell flattens these lists.
    mini_examples_owt_overall.append(mini_examples_owt)

    # For Dataset Prompt Tokens
    analyze_prompts = MaxActAnalysis(
        "attn",
        layer,
        feature,
        token_dataset=dataset_prompt_tokens,
        **max_act_kwargs,
    )
    mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_ioi_overall.append(mini_examples_ioi)

In [None]:
# Convert list of lists into a single list
# NOTE(review): this cell is not idempotent — it rebinds the same names, so running it a
# second time flattens one level deeper. Re-run the collection cell before re-running it.
mini_examples_owt_overall = [item for sublist in mini_examples_owt_overall for item in sublist]
mini_examples_ioi_overall = [item for sublist in mini_examples_ioi_overall for item in sublist]

In [None]:
# Render the prompt from the flattened examples, send it through `get_response`, show the reply.
p = main_aug_interp_prompt_ioi(mini_examples_owt_overall, mini_examples_ioi_overall)
interp = get_response(p)
print(interp)

In [None]:
# Hand-recorded interpretation for head L0_H1 (presumably condensed from the printed `interp`).
l0h1_interp = "Identifying the later appearance of the indirect object in a sentence structure where the indirect object is being given something."

In [None]:
# Distinct feature ids for head L3_H0, with the -1 entries removed.
features = [feature_id for feature_id in set(cp.circuit_hypergraph['L3_H0']['features']) if feature_id != -1]

# #feature = 19042
layer = 3
num_examples = 5000

strategy = create_simple_greedy_strategy(passes=1, node_contributors=1, minimal=True)


# IOI prompts (text followed by the correct completion), tokenised for the second analysis.
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [entry['text'] + entry['correct'] for entry in dataset_prompts]
tokens = model.to_tokens(prompts)  # Assuming `model` is already defined
dataset_prompt_tokens = torch.tensor(tokens)

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

# Settings shared by both analyses of every feature.
max_act_kwargs = dict(num_sequences=num_examples, batch_size=128, strategy=strategy)

for feature in features:

    # Default token dataset.
    analyze_owt = MaxActAnalysis("attn", layer, feature, **max_act_kwargs)
    mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_owt_overall.extend(mini_examples_owt)

    # For Dataset Prompt Tokens
    analyze_prompts = MaxActAnalysis(
        "attn",
        layer,
        feature,
        token_dataset=dataset_prompt_tokens,
        **max_act_kwargs,
    )
    mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_ioi_overall.extend(mini_examples_ioi)

# Both lists are flat here, so they go straight into the prompt.
p = main_aug_interp_prompt_ioi(mini_examples_owt_overall, mini_examples_ioi_overall)
interp = get_response(p)
print(interp)

In [None]:
# Hand-recorded interpretation for head L3_H0 (presumably condensed from the printed `interp`).
l3h0_interp = "Aids in flagging when and where the subject or indirect object from a prior clause reappears in the text."

In [None]:
# Distinct feature ids for head L4_H11, with the -1 entries removed.
features = [feature_id for feature_id in set(cp.circuit_hypergraph['L4_H11']['features']) if feature_id != -1]
print(features)

# #feature = 19042
layer = 4
num_examples = 5000

strategy = create_simple_greedy_strategy(passes=1, node_contributors=1, minimal=True)


# IOI prompts (text followed by the correct completion), tokenised for the second analysis.
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [entry['text'] + entry['correct'] for entry in dataset_prompts]
tokens = model.to_tokens(prompts)  # Assuming `model` is already defined
dataset_prompt_tokens = torch.tensor(tokens)

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

# Settings shared by both analyses of every feature.
max_act_kwargs = dict(num_sequences=num_examples, batch_size=128, strategy=strategy)

for feature in features:

    # Default token dataset.
    analyze_owt = MaxActAnalysis("attn", layer, feature, **max_act_kwargs)
    mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_owt_overall.extend(mini_examples_owt)

    # For Dataset Prompt Tokens
    analyze_prompts = MaxActAnalysis(
        "attn",
        layer,
        feature,
        token_dataset=dataset_prompt_tokens,
        **max_act_kwargs,
    )
    mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_ioi_overall.extend(mini_examples_ioi)

# Both lists are flat here, so they go straight into the prompt.
p = main_aug_interp_prompt_ioi(mini_examples_owt_overall, mini_examples_ioi_overall)
interp = get_response(p)
print(interp)

In [None]:
# Hand-recorded interpretation for head L4_H11 (presumably condensed from the printed `interp`).
# The literal keeps its surrounding whitespace/newlines.
l4h11_interp = """ 
Predicts a determiner that is following a structure of "<Name_1> and <Name_2>", signalling the association between two entities or characters and promoting predictions in the context of an action or state involving them. In the context of the IOI task, this neuron helps identify and predict an interaction between the subject and the indirect object.
"""

In [None]:
from jinja2 import Template
from typing import List, Tuple, Optional

def main_aug_interp_prompt_ioi(
    examples: List[Tuple[str, str]], examples_ioi: List[Tuple[str, str]], 
    token_lr=("<<", ">>"), context_lr=("[[", "]]"),
    incoming_information: Optional[List[Tuple[str, str]]] = None,
):
    """Render the IOI-specific auto-interpretation prompt for one component.

    Args:
        examples: general max-activating examples. The template reads each
            item as item[0] (base text) and item[1] (annotated text).
        examples_ioi: examples taken from IOI prompts, same layout.
        token_lr: (left, right) delimiters that mark the activating token.
        context_lr: (left, right) delimiters that mark the context tokens.
        incoming_information: optional value exposed to the template under
            the same name. The template text below does not reference it; the
            parameter exists because the render call passes it, and it
            previously referred to an undefined name (NameError on every call).

    Returns:
        The rendered prompt string. It ends with "Step 1:" so the model
        continues from the first step.
    """
    tl, tr = token_lr
    cl, cr = context_lr

    # Jinja's `loop.index` is already 1-based, so examples are numbered with it directly
    # (the previous `loop.index + 1` started the numbering at "EXAMPLE 2").
    template = Template(
        """
{# You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and provide an explanation that thoroughly encapsulates its behavior in the context of a specific task: Indirect Object Identification (IOI). Here's how you will complete this task: #}

You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. This language model is trained to predict the text that will follow a given input. Your task is to figure out what sort of behavior this neuron is responsible for -- namely, when this neuron fires, what kind of predictions does this neuron promote in the context of the specific task of Indirect Object Identification (IOI)? Here's how you'll complete the task:

INPUT_DESCRIPTION: 
You will be given several examples of text that activate the neuron. First we'll provide the example text without any annotations, and then we'll provide the same text with annotations that show the specific tokens that caused the neuron to activate and context about why the neuron fired.

The specific token that the neuron activates on will be the last token in the sequence, and will appear between {{tl}} and {{tr}} (like {{tl}}this{{tr}}).  

Additionally, each sequence will have tokens enclosed between {{cl}} and {{cr}} (like {{cl}}this{{cr}}). From previous analysis, we know that these tokens form the context for why our neuron fires on the token enclosed in {{tl}} and {{tr}} (in addition to the value of the actual token itself). Note that we treat the group of tokens enclosed between {{cl}} and {{cr}} as the "context" for why the neuron fired.

We will provide both general examples and specific examples related to the task of Indirect Object Identification (IOI).

Task Description: A sentence containing indirect object identification (IOI) has an initial dependent clause, e.g. "When Mary and John went to the store", and a main clause, e.g. "John gave a bottle of milk to Mary". The initial clause introduces the indirect object (IO) "Mary" and the subject (S) "John". The main clause refers to the subject a second time, and in all our examples of IOI, the subject gives an object to the IO. The IOI task is to predict the final token in the sentence to be the IO. We use 'S1' and 'S2' to refer to the first and second occurrences of the subject, when we want to specify position.

Given these examples, complete the following steps.

OUTPUT_DESCRIPTION:

Step 1: Based on the general examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 2: Based on the general examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 3: Write down several general shared features of the general text examples.
Step 4: Based on the IOI examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 5: Based on the IOI examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 6: Write down several general shared features of the IOI text examples.
Step 7: Based on the patterns you found between the activating token and the relevant context in both general and IOI examples, write down your best explanation for what this neuron is responsible for. Propose your explanation in the following form: 
[EXPLANATION]: <your explanation>

Guidelines:
- Try to produce a final explanation that's both concise and general to the examples provided.
- Your explanation should be short: 1-2 sentences.
- Specifically address the neuron's role in the context of the IOI task, explaining its specific function in relation to predicting the indirect object.
- If provided, incorporate the interpretation of the previous neurons into your explanation, considering how the current neuron processes and uses the information from these previous neurons.

INPUT:

General Examples:
{% for example in examples %}                         
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

IOI Task Examples:
{% for example in examples_ioi %}                         
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

OUTPUT:
                         
Step 1:
"""
    )

    return template.render(
        {"tl": tl, "tr": tr, "cl": cl, "cr": cr, "examples": examples, "examples_ioi": examples_ioi, "incoming_information": incoming_information}
    )

def main_aug_interp_prompt_ioi_incoming(
    examples: List[Tuple[str, str]], examples_ioi: List[Tuple[str, str]],
    incoming_information: Optional[List[Tuple[str, str]]] = None,
    token_lr=("<<", ">>"), context_lr=("[[", "]]")
):
    """Render the IOI auto-interpretation prompt that also lists upstream-neuron interpretations.

    Args:
        examples: General examples. The template reads ``example[0]`` as the
            base text and ``example[1]`` as the annotated text, so each item is
            presumably a (base, annotated) pair -- confirm against the producer.
        examples_ioi: IOI-task examples, indexed the same way as ``examples``.
        incoming_information: Optional list of ``(neuron_name, interpretation)``
            tuples for neurons feeding into the one being explained. ``None``
            is treated as an empty list.
        token_lr: (left, right) delimiters marking the activating token.
        context_lr: (left, right) delimiters marking context tokens.

    Returns:
        The rendered prompt string, ending with "Step 1:" for the model to continue.
    """
    tl, tr = token_lr
    cl, cr = context_lr

    # Jinja2's `loop.index` is already 1-based, so examples are numbered with it
    # directly (the previous `loop.index + 1` started the numbering at 2).
    template = Template(
        """
{# You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and provide an explanation that thoroughly encapsulates its behavior in the context of a specific task: Indirect Object Identification (IOI). Here's how you will complete this task: #}

You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. This language model is trained to predict the text that will follow a given input. Your task is to figure out what sort of behavior this neuron is responsible for -- namely, when this neuron fires, what kind of predictions does this neuron promote in the context of the specific task of Indirect Object Identification (IOI)? Here's how you'll complete the task:

INPUT_DESCRIPTION: 
You will be given several examples of text that activate the neuron. First we'll provide the example text without any annotations, and then we'll provide the same text with annotations that show the specific tokens that caused the neuron to activate and context about why the neuron fired.

The specific token that the neuron activates on will be the last token in the sequence, and will appear between {{tl}} and {{tr}} (like {{tl}}this{{tr}}).  

Additionally, each sequence will have tokens enclosed between {{cl}} and {{cr}} (like {{cl}}this{{cr}}). From previous analysis, we know that these tokens form the context for why our neuron fires on the token enclosed in {{tl}} and {{tr}} (in addition to the value of the actual token itself). Note that we treat the group of tokens enclosed between {{cl}} and {{cr}} as the "context" for why the neuron fired.

We will provide both general examples and specific examples related to the task of Indirect Object Identification (IOI).

Task Description: A sentence containing indirect object identification (IOI) has an initial dependent clause, e.g. "When Mary and John went to the store", and a main clause, e.g. "John gave a bottle of milk to Mary". The initial clause introduces the indirect object (IO) "Mary" and the subject (S) "John". The main clause refers to the subject a second time, and in all our examples of IOI, the subject gives an object to the IO. The IOI task is to predict the final token in the sentence to be the IO. We use 'S1' and 'S2' to refer to the first and second occurrences of the subject, when we want to specify position.

Previous Neuron Information:
You will also be provided with information about important previous neurons that feed into the current neuron. These neurons play a significant role in the IOI task and move information into the current neuron. The incoming information will be presented as a list of tuples, where each tuple contains the neuron's name and its interpretation in the context of the IOI task.

{% for neuron in incoming_information %}
Neuron {{neuron[0]}}:
- Interpretation in IOI context: {{neuron[1]}}

{% endfor %}

Use this incoming information to help interpret the current neuron's role, considering how it processes and uses the information from these previous neurons.

Given these examples, complete the following steps.

OUTPUT_DESCRIPTION:

Step 1: Based on the general examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 2: Based on the general examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 3: Write down several general shared features of the general text examples.
Step 4: Based on the IOI examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 5: Based on the IOI examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 6: Write down several general shared features of the IOI text examples.
Step 7: Based on the patterns you found between the activating token and the relevant context in both general and IOI examples, write down your best explanation for what this neuron is responsible for. Propose your explanation in the following form: 
[EXPLANATION]: <your explanation>

Guidelines:
- Try to produce a final explanation that's both concise and general to the examples provided.
- Your explanation should be short: 1-2 sentences.
- Specifically address the neuron's role in the context of the IOI task, explaining its specific function in relation to predicting the indirect object.
- If provided, incorporate the interpretation of the previous neurons into your explanation, considering how the current neuron processes and uses the information from these previous neurons.

INPUT:

General Examples:
{% for example in examples %}                         
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

IOI Task Examples:
{% for example in examples_ioi %}                         
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

OUTPUT:
                         
Step 1:
"""
    )

    return template.render(
        {
            "tl": tl,
            "tr": tr,
            "cl": cl,
            "cr": cr,
            "examples": examples,
            "examples_ioi": examples_ioi,
            # Iterating None inside the template raises, so fall back to an empty list.
            "incoming_information": incoming_information or [],
        }
    )

In [None]:
# Interpret head L5_H5: gather top-activating examples for each of its SAE features
# (default token source and the IOI prompts), then ask the LLM for an explanation
# that takes the upstream heads' interpretations into account.
# NOTE(review): `cp`, `model`, `MaxActAnalysis`, `create_simple_greedy_strategy` and
# `get_response` are only defined/imported in cells further down the notebook, so this
# cell does not survive Restart & Run All in the current cell order.
# Entries equal to -1 are dropped (presumably a "no feature" placeholder -- confirm).
features = [x for x in set(cp.circuit_hypergraph['L5_H5']['features']) if x != -1]
print(features)

layer = 5
num_examples = 5000

strategy = create_simple_greedy_strategy(
    passes=1,
    node_contributors=1,
    minimal=True,
)


dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]
tokens = model.to_tokens(prompts)
# `tokens` is already a tensor; clone().detach() is the supported way to copy it
# (torch.tensor(tensor) emits a UserWarning).
dataset_prompt_tokens = tokens.clone().detach()

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

for feature in features:

    # No token_dataset given: MaxActAnalysis uses its default token source
    # (presumably OpenWebText, going by the `owt` naming -- confirm).
    analyze_owt = MaxActAnalysis(
        "attn",
        layer,
        feature,
        num_sequences=num_examples,
        batch_size=128,
        strategy=strategy
    )
    mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_owt_overall.extend(mini_examples_owt)

    # Same analysis restricted to the tokenised IOI prompts.
    analyze_prompts = MaxActAnalysis(
        "attn",
        layer,
        feature,
        num_sequences=num_examples,
        batch_size=128,
        strategy=strategy,
        token_dataset=dataset_prompt_tokens
    )
    mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_ioi_overall.extend(mini_examples_ioi)


# Interpretations of the heads treated as inputs to L5_H5.
incoming_information = [
    ("L2H2", l2h2_interp),
    ("L0H1", l0h1_interp),
    ("L3H0", l3h0_interp),
    ("L4H11", l4h11_interp),
]

p = main_aug_interp_prompt_ioi_incoming(mini_examples_owt_overall, mini_examples_ioi_overall, incoming_information)
interp = get_response(p)
print(interp)

In [None]:
# Interpretation text for head L5H5, stored so downstream cells can pass it as
# incoming information (presumably copied from the get_response output above -- confirm).
l5h5_interp = "This neuron predicts the appearance of named entities, specifically indirect objects, influenced by a previous introduction of the same entity or entities in the text (especially when entities were linked with a conjunction), crucial in identifying these entities when they reappear in a later context."

In [None]:
# Interpret head L8_H6: gather top-activating examples for each of its SAE features
# (default token source and the IOI prompts), then ask the LLM for an explanation
# that takes the L5H5 interpretation into account.
# NOTE(review): `cp`, `model`, `MaxActAnalysis`, `create_simple_greedy_strategy` and
# `get_response` are only defined/imported in cells further down the notebook, so this
# cell does not survive Restart & Run All in the current cell order.
# Entries equal to -1 are dropped (presumably a "no feature" placeholder -- confirm).
features = [x for x in set(cp.circuit_hypergraph['L8_H6']['features']) if x != -1]
print(features)

layer = 8
num_examples = 5000

strategy = create_simple_greedy_strategy(
    passes=1,
    node_contributors=1,
    minimal=True,
)


dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]
tokens = model.to_tokens(prompts)
# `tokens` is already a tensor; clone().detach() is the supported way to copy it
# (torch.tensor(tensor) emits a UserWarning).
dataset_prompt_tokens = tokens.clone().detach()

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

for feature in features:

    # No token_dataset given: MaxActAnalysis uses its default token source
    # (presumably OpenWebText, going by the `owt` naming -- confirm).
    analyze_owt = MaxActAnalysis(
        "attn",
        layer,
        feature,
        num_sequences=num_examples,
        batch_size=128,
        strategy=strategy
    )
    mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_owt_overall.extend(mini_examples_owt)

    # Same analysis restricted to the tokenised IOI prompts.
    analyze_prompts = MaxActAnalysis(
        "attn",
        layer,
        feature,
        num_sequences=num_examples,
        batch_size=128,
        strategy=strategy,
        token_dataset=dataset_prompt_tokens
    )
    mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_ioi_overall.extend(mini_examples_ioi)


# Only the L5H5 interpretation is fed in for this head; the earlier heads are disabled.
incoming_information = [
    # ("L2H2", l2h2_interp),
    # ("L0H1", l0h1_interp),
    # ("L3H0", l3h0_interp),
    # ("L4H11", l4h11_interp),
    ("L5H5", l5h5_interp),
]

p = main_aug_interp_prompt_ioi_incoming(mini_examples_owt_overall, mini_examples_ioi_overall, incoming_information)
interp = get_response(p)
print(interp)

In [None]:
# Interpretation text for head L8H6, stored so downstream cells can pass it as
# incoming information. The surrounding newlines are part of the string value.
l8h6_interp = """ 
This neuron is responsible for detecting a linking or coordinating structure ("and", "to") between multiple named entities when those entities have already been introduced, as determined by L5H5.
"""

In [None]:
# Display the raw feature list recorded for head L9_H9 (the next cell filters out -1 entries).
cp.circuit_hypergraph['L9_H9']['features']

In [None]:
# Gather top-activating examples for every SAE feature attributed to head L9_H9,
# from the default token source and from the IOI prompts.
# NOTE(review): `cp`, `model`, `MaxActAnalysis` and `create_simple_greedy_strategy`
# are only defined/imported in cells further down the notebook, so this cell does
# not survive Restart & Run All in the current cell order.
# Entries equal to -1 are dropped (presumably a "no feature" placeholder -- confirm).
features = [x for x in set(cp.circuit_hypergraph['L9_H9']['features']) if x != -1]
print(features)

layer = 9
num_examples = 2500

strategy = create_simple_greedy_strategy(
    passes=1,
    node_contributors=1,
    minimal=True,
)


dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]
tokens = model.to_tokens(prompts)
# `tokens` is already a tensor; clone().detach() is the supported way to copy it
# (torch.tensor(tensor) emits a UserWarning).
dataset_prompt_tokens = tokens.clone().detach()

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

for feature in features:
    try:

        # No token_dataset given: MaxActAnalysis uses its default token source
        # (presumably OpenWebText, going by the `owt` naming -- confirm).
        analyze_owt = MaxActAnalysis(
            "attn",
            layer,
            feature,
            num_sequences=num_examples,
            batch_size=128,
            strategy=strategy
        )
        mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
        mini_examples_owt_overall.extend(mini_examples_owt)

        # Same analysis restricted to the tokenised IOI prompts.
        analyze_prompts = MaxActAnalysis(
            "attn",
            layer,
            feature,
            num_sequences=num_examples,
            batch_size=128,
            strategy=strategy,
            token_dataset=dataset_prompt_tokens
        )
        mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
        mini_examples_ioi_overall.extend(mini_examples_ioi)

    except Exception as exc:
        # Best effort: skip features that fail, but report why. Catching Exception
        # (not a bare except) lets KeyboardInterrupt stop this long-running loop.
        print(f"Error with feature {feature}: {exc!r}")

In [None]:
# Build the interpretation prompt from the examples collected in the previous cell,
# passing only the L8H6 interpretation as incoming information (the other heads
# are disabled), then query the LLM and show its answer.
incoming_information = [
    # ("L2H2", l2h2_interp),
    # ("L0H1", l0h1_interp),
    # ("L3H0", l3h0_interp),
    # ("L4H11", l4h11_interp),
    #("L5H5", l5h5_interp),
    ("L8H6", l8h6_interp),
]

p = main_aug_interp_prompt_ioi_incoming(mini_examples_owt_overall, mini_examples_ioi_overall, incoming_information)
interp = get_response(p)
print(interp)

In [None]:
# Show the fully rendered prompt for manual inspection.
print(p)

In [None]:
# Gather top-activating examples for every SAE feature attributed to head L10_H7,
# from the default token source and from the IOI prompts.
# NOTE(review): `cp`, `model`, `MaxActAnalysis` and `create_simple_greedy_strategy`
# are only defined/imported in cells further down the notebook, so this cell does
# not survive Restart & Run All in the current cell order.
# Entries equal to -1 are dropped (presumably a "no feature" placeholder -- confirm).
features = [x for x in set(cp.circuit_hypergraph['L10_H7']['features']) if x != -1]
print(features)

layer = 10
num_examples = 2500

strategy = create_simple_greedy_strategy(
    passes=1,
    node_contributors=1,
    minimal=True,
)


dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]
tokens = model.to_tokens(prompts)
# `tokens` is already a tensor; clone().detach() is the supported way to copy it
# (torch.tensor(tensor) emits a UserWarning).
dataset_prompt_tokens = tokens.clone().detach()

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

for feature in features:
    try:

        # No token_dataset given: MaxActAnalysis uses its default token source
        # (presumably OpenWebText, going by the `owt` naming -- confirm).
        analyze_owt = MaxActAnalysis(
            "attn",
            layer,
            feature,
            num_sequences=num_examples,
            batch_size=128,
            strategy=strategy
        )
        mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
        mini_examples_owt_overall.extend(mini_examples_owt)

        # Same analysis restricted to the tokenised IOI prompts.
        analyze_prompts = MaxActAnalysis(
            "attn",
            layer,
            feature,
            num_sequences=num_examples,
            batch_size=128,
            strategy=strategy,
            token_dataset=dataset_prompt_tokens
        )
        mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
        mini_examples_ioi_overall.extend(mini_examples_ioi)

    except Exception as exc:
        # Best effort: skip features that fail, but report why. Catching Exception
        # (not a bare except) lets KeyboardInterrupt stop this long-running loop.
        print(f"Error with feature {feature}: {exc!r}")

In [None]:
# Build the interpretation prompt from the examples collected in the previous cell,
# then query the LLM and show its answer.
# NOTE(review): this passes the same L8H6 interpretation as the L9_H9 prompt cell, and
# no L9H9 interpretation is stored anywhere in view -- confirm that L9H9 is meant to be
# left out of the incoming information here.
incoming_information = [
    # ("L2H2", l2h2_interp),
    # ("L0H1", l0h1_interp),
    # ("L3H0", l3h0_interp),
    # ("L4H11", l4h11_interp),
    #("L5H5", l5h5_interp),
    ("L8H6", l8h6_interp),
]

p = main_aug_interp_prompt_ioi_incoming(mini_examples_owt_overall, mini_examples_ioi_overall, incoming_information)
interp = get_response(p)
print(interp)

# New and improved

In [None]:

from data.ioi_dataset import gen_templated_prompts
from aug_interp_prompts import main_aug_interp_prompt, main_aug_interp_prompt_v2
from openai_utils import gen_openai_completion, get_response
from autointerpretability import *
from discovery_strategies import (
    create_filter,
    create_simple_greedy_strategy,
    create_top_contributor_strategy,
)
from max_act_analysis import MaxActAnalysis

# Get feature families for each component

from autointerpretability import *

# Circuit prediction for the IOI task; other cells read per-head feature lists
# from `cp.circuit_hypergraph[<head>]['features']`.
cp = get_circuit_prediction(task='ioi', N=20)

from collections import Counter, defaultdict

def get_top_k_feature_tuples_for_component(co_occurrence_dict, component_str, k=5):
    """Return the k most frequent co-occurring feature tuples that involve one component.

    Args:
        co_occurrence_dict: Maps a component pair ``(comp1, comp2)`` to an
            iterable of feature tuples; every item yielded counts as one occurrence.
        component_str: ``"MLP<layer>"`` (e.g. ``"MLP3"``) or
            ``"L<layer>H<head>"`` (e.g. ``"L5H5"``).
        k: Number of (pair, feature tuple) entries to keep.

    Returns:
        A ``defaultdict(dict)`` of ``{component_pair: {feature_tuple: count}}``
        holding only the k most common entries.

    Raises:
        ValueError: If ``component_str`` matches neither format, or its numeric
            parts cannot be parsed.
    """
    # Translate the short component name into the tuple key used in the pairs.
    if component_str.startswith("MLP"):
        target = ('mlp_feature', int(component_str[3:]))
    elif component_str.startswith("L") and "H" in component_str:
        layer_idx, head_idx = (int(part) for part in component_str[1:].split("H"))
        target = ('attn_head', layer_idx, head_idx)
    else:
        raise ValueError(f"Invalid component format: {component_str}")

    def matching_entries():
        # Yield one (pair, feature_tuple) item per occurrence in pairs touching the target.
        for pair, occurrences in co_occurrence_dict.items():
            first, second = pair
            if target in (first, second):
                for feature_tuple in occurrences:
                    yield pair, feature_tuple

    tally = Counter(matching_entries())

    # Regroup the winners by component pair.
    grouped = defaultdict(dict)
    for (pair, feature_tuple), n_seen in tally.most_common(k):
        grouped[pair][feature_tuple] = n_seen

    return grouped

# Disable autograd globally, then unpack the model and its two encoder collections
# returned by get_model_encoders for the 'cpu' device.
# NOTE(review): the next cell repeats these two lines verbatim; one of the two loads
# looks redundant.
torch.set_grad_enabled(False)
model, z_saes, transcoders = get_model_encoders('cpu')

In [None]:
# Disable autograd globally, then unpack the model and its two encoder collections
# returned by get_model_encoders for the 'cpu' device.
# NOTE(review): identical to the last two lines of the previous cell; running both
# loads everything twice.
torch.set_grad_enabled(False)
model, z_saes, transcoders = get_model_encoders('cpu')

In [None]:
from functools import partial
import transformer_lens.utils as utils
import numpy as np
import einops
from tabulate import tabulate
from termcolor import colored

def pretty_print_results(top_k_increases_indices, top_k_increases, top_k_decreases_indices, top_k_decreases, model, k, visualize):
    """Format the top-k logit increases and decreases as two text tables.

    Args:
        top_k_increases_indices: Token ids with the largest logit increases.
        top_k_increases: Logit changes matching ``top_k_increases_indices``.
        top_k_decreases_indices: Token ids with the largest logit decreases.
        top_k_decreases: Logit changes matching ``top_k_decreases_indices``.
        model: Object whose ``to_string([token_id])`` decodes a token id.
        k: Number of rows taken from each sequence (the first ``k`` entries).
        visualize: When true, also print both tables with coloured headings.

    Returns:
        One string containing both tables, each under its own heading.
    """
    def build_table(token_ids, deltas, value_header):
        # One row per token: decoded text plus its change rounded to two decimals.
        rows = [
            [model.to_string([token_ids[row]]), f"{deltas[row]:.2f}"]
            for row in range(k)
        ]
        return tabulate(rows, headers=["Token", value_header], tablefmt="pretty")

    increases_table = build_table(top_k_increases_indices, top_k_increases, "Increase")
    decreases_table = build_table(top_k_decreases_indices, top_k_decreases, "Decrease")

    if visualize:
        sections = (
            ("\nTop k Increases:", 'green', increases_table),
            ("\nTop k Decreases:", 'red', decreases_table),
        )
        for heading, colour, table in sections:
            print(colored(heading, colour))
            print(table)

    return f"Top k Increases:\n{increases_table}\n\nTop k Decreases:\n{decreases_table}"

def find_top_changes(prompt, model, layer, features, sae, k=10, visualize=True, max_z=True):
    """Report which next-token logits change most when the given SAE features are ablated.

    The prompt is run once to find the position where the selected features of
    ``sae`` (applied to the flattened attention ``z`` of ``layer``) are most
    active. The prompt is then truncated right after that position and run
    twice: once unmodified, and once with a hook that zeroes those features in
    the SAE encoding and writes the SAE reconstruction back as ``z``. The
    last-position logits of the two runs are compared.

    Args:
        prompt: Text to analyse. The position search flattens batch and
            sequence together, so a single prompt is assumed.
        model: Model exposing ``to_tokens``, ``to_string``, ``run_with_cache``
            and ``run_with_hooks``.
        layer: Attention layer whose ``z`` activations are encoded and patched.
        features: Indices of the SAE features to ablate.
        sae: Encoder/decoder with ``encode`` and ``decode`` for that layer's ``z``.
        k: Number of increases and of decreases to report.
        visualize: Forwarded to ``pretty_print_results`` (print the tables).
        max_z: If true, score each position by the maximum of the selected
            feature activations; otherwise by their sum.

    Returns:
        ``(result_string, cut_prompt)``: the formatted tables and the truncated
        prompt text (without the token at position 0). Returns ``(None, "")``
        when no position after index 0 has a positive score.
    """
    tokens = model.to_tokens(prompt)

    # Pass 1: find where the selected features fire most strongly.
    _, cache = model.run_with_cache(tokens)
    z = cache["z", layer, "attn"]
    flat_z = einops.rearrange(z, "b s n d -> (b s) (n d)")
    feature_acts = sae.encode(flat_z)[:, features]

    # Collapse the selected features into one score per position.
    if max_z:
        position_scores, _ = torch.max(feature_acts, dim=1)
    else:
        position_scores = feature_acts.sum(dim=1)

    # Never pick position 0 (presumably the BOS token added by to_tokens -- confirm).
    position_scores[0] = 0.0
    max_act_idx = int(position_scores.argmax())
    if max_act_idx == 0:
        return None, ""

    # Keep position 0 in the tokens fed to the model, so the rerun sees exactly the
    # prefix on which the activation was measured. The previous code sliced from 1,
    # which ran the model on a shifted sequence. Only the displayed text omits it.
    cut_tokens = tokens[:, : max_act_idx + 1]
    cut_prompt = model.to_string(tokens[:, 1 : max_act_idx + 1])[0]

    clean_logits, _ = model.run_with_cache(cut_tokens)

    def ablate_features(z, hook):
        # Encode z, zero the chosen features, decode, and restore the original shape.
        # NOTE(review): the patched run uses the SAE reconstruction for all of z, so the
        # logit difference also includes SAE reconstruction error, not only the ablation.
        flat = einops.rearrange(z, "b s n d -> (b s) (n d)")
        hidden = sae.encode(flat)
        hidden[:, features] = 0.0
        rebuilt = sae.decode(hidden)
        return einops.rearrange(
            rebuilt,
            "(b s) (n d) -> b s n d",
            b=z.shape[0], s=z.shape[1], n=z.shape[2], d=z.shape[3],
        )

    patched_logits = model.run_with_hooks(
        cut_tokens,
        fwd_hooks=[(utils.get_act_name("z", layer, "attn"), ablate_features)],
        return_type="logits"
    )

    # Index explicitly instead of squeeze(): squeeze() would also drop a length-1
    # sequence dimension and make the `[-1, :]` lookup fail.
    difference = clean_logits[0, -1, :] - patched_logits[0, -1, :]

    # torch.topk returns values already sorted (descending for the largest,
    # ascending for the smallest) and works on any device.
    top_increases = torch.topk(difference, k)
    top_decreases = torch.topk(difference, k, largest=False)

    result_string = pretty_print_results(
        top_increases.indices.tolist(),
        top_increases.values.tolist(),
        top_decreases.indices.tolist(),
        top_decreases.values.tolist(),
        model,
        k,
        visualize=visualize,
    )
    return result_string, cut_prompt

# Example usage: ablate a single layer-9 SAE feature on one IOI-style prompt and
# print the truncated prompt followed by the logit-change tables.
prompt = "Then, Joseph and Brandon had a lot of fun at the restaurant. Joseph gave a ring to"
layer = 9
features = [3520]
sae = z_saes[layer]  # encoder for the chosen layer
result_string, cut_prompt = find_top_changes(prompt, model, layer, features, sae)
print(cut_prompt)
print(result_string)

In [None]:
# Now let's try it over a bunch of prompts
dataset_prompts = gen_templated_prompts(template_idex=1, N=5)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]

layer = 9
features = [4729, 12471, 3520, 10391, 21753, 22975, 13173, 2581, 14056, 3481] #[16513, 2623]
sae = z_saes[layer]

all_results = []

for prompt in prompts:
    try:
        result_string, cut_prompt = find_top_changes(prompt, model, layer, features, sae, k=5, visualize=False)
        all_results.append(f"Prompt = '{cut_prompt}'\n\n"+result_string)
    except:
        continue

# Combine all results into one string with \n\n\n between each result
all_results_string = "\n\n\n".join(all_results)

In [None]:
# Show the combined per-prompt logit-change tables.
print(all_results_string)

In [None]:
from jinja2 import Template
from typing import List, Tuple, Optional

def main_aug_interp_prompt_ioi(
    examples: List[Tuple[str, str]], examples_ioi: List[Tuple[str, str]],
    logit_increase_decrease: Optional[str] = None,
    token_lr=("<<", ">>"), context_lr=("[[", "]]")
):
    """Render the IOI auto-interpretation prompt that includes logit increase/decrease tables.

    Args:
        examples: General examples. The template reads ``example[0]`` as the
            base text and ``example[1]`` as the annotated text, so each item is
            presumably a (base, annotated) pair -- confirm against the producer.
        examples_ioi: IOI-task examples, indexed the same way as ``examples``.
        logit_increase_decrease: Logit-change report(s). A single string is
            inserted as one block; any other iterable is inserted item by item;
            ``None`` inserts nothing.
        token_lr: (left, right) delimiters marking the activating token.
        context_lr: (left, right) delimiters marking context tokens.

    Returns:
        The rendered prompt string, ending with "Step 1:" for the model to continue.
    """
    tl, tr = token_lr
    cl, cr = context_lr

    # The template loops over this value. Iterating a plain string would emit one
    # character per line and iterating None raises, so normalise to a list first.
    if logit_increase_decrease is None:
        logit_blocks = []
    elif isinstance(logit_increase_decrease, str):
        logit_blocks = [logit_increase_decrease]
    else:
        logit_blocks = list(logit_increase_decrease)

    # Jinja2's `loop.index` is already 1-based, so examples are numbered with it
    # directly (the previous `loop.index + 1` started the numbering at 2). The
    # "General Examples" heading uses the configured delimiters instead of
    # hard-coded ones; with the defaults the rendered text is unchanged.
    template = Template(
        """
{# You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and provide an explanation that thoroughly encapsulates its behavior in the context of a specific task: Indirect Object Identification (IOI). Here's how you will complete this task: #}

You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. This language model is trained to predict the text that will follow a given input. Your task is to figure out what sort of behavior this neuron is responsible for -- namely, when this neuron fires, what kind of predictions does this neuron promote in the context of the specific task of Indirect Object Identification (IOI)? Here's how you'll complete the task:

INPUT_DESCRIPTION:
You will be given several examples of text that activate the neuron. First we'll provide the example text without any annotations, and then we'll provide the same text with annotations that show the specific tokens that caused the neuron to activate and context about why the neuron fired.

The specific token that the neuron activates on will be the last token in the sequence, and will appear between {{tl}} and {{tr}} (like {{tl}}this{{tr}}).

Additionally, each sequence will have tokens enclosed between {{cl}} and {{cr}} (like {{cl}}this{{cr}}). From previous analysis, we know that these tokens form the context for why our neuron fires on the token enclosed in {{tl}} and {{tr}} (in addition to the value of the actual token itself). Note that we treat the group of tokens enclosed between {{cl}} and {{cr}} as the "context" for why the neuron fired.

We will provide both general examples and specific examples related to the task of Indirect Object Identification (IOI).

Task Description: A sentence containing indirect object identification (IOI) has an initial dependent clause, e.g. "When Mary and John went to the store", and a main clause, e.g. "John gave a bottle of milk to Mary". The initial clause introduces the indirect object (IO) "Mary" and the subject (S) "John". The main clause refers to the subject a second time, and in all our examples of IOI, the subject gives an object to the IO. The IOI task is to predict the final token in the sentence to be the IO. We use 'S1' and 'S2' to refer to the first and second occurrences of the subject, when we want to specify position.

Given these examples, complete the following steps.

OUTPUT_DESCRIPTION:

Step 1: Based on the general examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 2: Based on the general examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 3: Write down several general shared features of the general text examples.
Step 4: Based on the IOI examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 5: Based on the IOI examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 6: Write down several general shared features of the IOI text examples.
Step 7: Based on the logit increase/decrease information provided, analyze how the neuron's activation affects the probability of certain tokens being predicted next.
Step 8: Based on the patterns you found between the activating token and the relevant context in both general and IOI examples, and the logit increase/decrease information, write down your best explanation for (1) which exact tokens the neuron fires on, and (2) how this firing influence the prediction of the NEXT token. Propose your explanation in the following form:
[EXPLANATION]: <your explanation>

Guidelines:
- Try to produce a final explanation that's both concise and general to the examples provided.
- Your explanation should be short: 1-2 sentences.
- Specifically address the neuron's role in the context of the IOI task, explaining its specific function in relation to predicting the indirect object.
- Use the logit increase/decrease information to explain the neuron's effect on the next token prediction.
- When looking at the context tokens that contributed to the neuron activating on the activation token, think about how they relate to the activation token. Has the neuron seen these tokens before? How are they related to activation, and how are they related to the token that might come next?
- Not all of the information will be useful! Sometimes, the neuron may just be looking at a specific previous token, or alternatively may not care about previous tokens and is just doing something to the prediction.
- Your final explanation should be a combination of (1) what tokens the neuron specifically activates on, and (2) how this neuron then affects the next token.

INPUT:

General Examples for tokens the neuron activates on ({{tl}} {{tr}}), and the other tokens that provide context ({{cl}} {{cr}}):
{% for example in examples %}                         
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

IOI Task Examples:
{% for example in examples_ioi %}                         
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

Use these examples to determine what tokens the neuron fires/activates on, and the surrounding context (tokens) that also contribute to this.

Logit Increase/Decrease Information:
Below is information on how this particular neuron increases the probability of certain tokens (i.e. their logits) and decreases the probability of others, when predicting the next token in the prompt i.e. after the final word in the prompt provided. This information will help you discern the neuron's effect on the NEXT token to predict after the token it activates on, which it doesn't actually see in the examples.

{% for x in logit_increase_decrease %}
{{x}}
{% endfor %}

Use this information to help interpret the effect of the neuron's activation on the model's predictions. Be specific about whether it boosts the logits of the subject, the indirect object, or a different type of neuron, and similarly whether it decreases the logits of the subject, the indirect object, or a different type of neuron.

OUTPUT:
                         
Step 1:
"""
    )

    return template.render(
        {"tl": tl, "tr": tr, "cl": cl, "cr": cr, "examples": examples, "examples_ioi": examples_ioi,
         "logit_increase_decrease": logit_blocks}
    )

# {# You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model.
#     # Your task is to analyze the neuron and provide an explanation that thoroughly encapsulates its behavior in the context of a specific task: Indirect Object Identification (IOI).
#     # Here's how you will complete this task:}


def main_aug_interp_prompt_ioi_incoming(
    examples: List[str], examples_ioi: List[str], 
    incoming_information: Optional[List[Tuple[str, str]]] = None, 
    logit_increase_decrease: Optional[str] = None,
    token_lr=("<<", ">>"), context_lr=("[[", "]]")
):
    """Render the IOI auto-interpretation prompt that also lists upstream neurons.

    Args:
        examples: General max-activating examples. The template reads each item
            as ``example[0]`` (base text) and ``example[1]`` (annotated text).
        examples_ioi: IOI-task examples, read the same way.
        incoming_information: ``(neuron name, interpretation)`` pairs for the
            neurons feeding into this one. ``None`` is treated as no pairs.
        logit_increase_decrease: Logit increase/decrease report(s). A single
            string or a list of strings is accepted; ``None`` renders nothing.
        token_lr: Delimiters placed around the activating token.
        context_lr: Delimiters placed around the context tokens.

    Returns:
        The rendered prompt text, ending with ``Step 1:`` for the model to
        continue.
    """
    tl, tr = token_lr
    cl, cr = context_lr

    # The template loops over both of these, so the `None` defaults would
    # raise during rendering; normalise them to empty lists.
    if incoming_information is None:
        incoming_information = []
    if logit_increase_decrease is None:
        logit_increase_decrease = []
    elif isinstance(logit_increase_decrease, str):
        # A bare string would otherwise be rendered one character per line.
        logit_increase_decrease = [logit_increase_decrease]

    template = Template(
        """
{# You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and provide an explanation that thoroughly encapsulates its behavior in the context of a specific task: Indirect Object Identification (IOI). Here's how you will complete this task: #}

You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. This language model is trained to predict the text that will follow a given input. Your task is to figure out what sort of behavior this neuron is responsible for -- namely, when this neuron fires, what kind of predictions does this neuron promote in the context of the specific task of Indirect Object Identification (IOI)? Here's how you'll complete the task:

INPUT_DESCRIPTION:
You will be given several examples of text that activate the neuron. First we'll provide the example text without any annotations, and then we'll provide the same text with annotations that show the specific tokens that caused the neuron to activate and context about why the neuron fired.

The specific token that the neuron activates on will be the last token in the sequence, and will appear between {{tl}} and {{tr}} (like {{tl}}this{{tr}}).

Additionally, each sequence will have tokens enclosed between {{cl}} and {{cr}} (like {{cl}}this{{cr}}). From previous analysis, we know that these tokens form the context for why our neuron fires on the token enclosed in {{tl}} and {{tr}} (in addition to the value of the actual token itself). Note that we treat the group of tokens enclosed between {{cl}} and {{cr}} as the "context" for why the neuron fired.

We will provide both general examples and specific examples related to the task of Indirect Object Identification (IOI).

Task Description: A sentence containing indirect object identification (IOI) has an initial dependent clause, e.g. "When Mary and John went to the store", and a main clause, e.g. "John gave a bottle of milk to Mary". The initial clause introduces the indirect object (IO) "Mary" and the subject (S) "John". The main clause refers to the subject a second time, and in all our examples of IOI, the subject gives an object to the IO. The IOI task is to predict the final token in the sentence to be the IO. We use 'S1' and 'S2' to refer to the first and second occurrences of the subject, when we want to specify position.

Previous Neuron Information:
You will also be provided with information about important previous neurons that feed into the current neuron. These neurons play a significant role in the IOI task and move information into the current neuron. The incoming information will be presented as a list of tuples, where each tuple contains the neuron's name and its interpretation in the context of the IOI task.

{% for neuron in incoming_information %}
Neuron {{neuron[0]}}:
- Interpretation in IOI context: {{neuron[1]}}

{% endfor %}

Use this incoming information to help interpret the current neuron's role, considering how it processes and uses the information from these previous neurons.

Given these examples, complete the following steps.

OUTPUT_DESCRIPTION:

Step 1: Based on the general examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 2: Based on the general examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 3: Write down several general shared features of the general text examples.
Step 4: Based on the IOI examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 5: Based on the IOI examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 6: Write down several general shared features of the IOI text examples.
Step 7: Based on the logit increase/decrease information provided, analyze how the neuron's activation affects the probability of certain tokens being predicted next.
Step 8: Based on the patterns you found between the activating token and the relevant context in both general and IOI examples, and the logit increase/decrease information, write down your best explanation for (1) which exact tokens the neuron fires on, and (2) how this firing influences the prediction of the NEXT token. Propose your explanation in the following form:
[EXPLANATION]: <your explanation>

Guidelines:
- Try to produce a final explanation that's both concise and general to the examples provided.
- Your explanation should be short: 1-2 sentences.
- Specifically address the neuron's role in the context of the IOI task, explaining its specific function in relation to predicting the indirect object.
- If provided, incorporate the interpretation of the previous neurons into your explanation, considering how the current neuron processes and uses the information from these previous neurons.
- Use the logit increase/decrease information to explain the neuron's effect on the next token prediction.
- When looking at the context tokens that contributed to the neuron activating on the activation token, think about how they relate to the activation token. Has the neuron seen these tokens before? How are they related to activation, and how are they related to the token that might come next?
- Not all of the information will be useful! Sometimes, the neuron may just be looking at a specific previous token, or alternatively may not care about previous tokens and is just doing something to the prediction.
- Your final explanation should be a combination of (1) what tokens the neuron specifically activates on, and (2) how this neuron then affects the next token.

INPUT:

General Examples for tokens the neuron activates on ({{tl}} {{tr}}), and the other tokens that provide context ({{cl}} {{cr}}):
{% for example in examples %}
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

IOI Task Examples:
{% for example in examples_ioi %}
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

Use these examples to determine what tokens the neuron fires/activates on, and the surrounding context (tokens) that also contribute to this.

Logit Increase/Decrease Information:
Below is information on how this particular neuron increases the probability of certain tokens (i.e. their logits) and decreases the probability of others, when predicting the next token in the prompt i.e. after the final word in the prompt provided. This information will help you discern the neuron's effect on the NEXT token to predict after the token it activates on, which it doesn't actually see in the examples.

{% for x in logit_increase_decrease %}
{{x}}
{% endfor %}

Use this information to help interpret the effect of the neuron's activation on the model's predictions. Be specific about whether it boosts the logits of the subject, the indirect object, or a different type of token, and similarly whether it decreases the logits of the subject, the indirect object, or a different type of token.

OUTPUT:

Step 1:
"""
    )

    return template.render(
        {"tl": tl, "tr": tr, "cl": cl, "cr": cr, "examples": examples, "examples_ioi": examples_ioi, 
         "incoming_information": incoming_information, "logit_increase_decrease": logit_increase_decrease}
    )

In [None]:
# Collect logit increase/decrease reports for two layer-8 features over a batch
# of IOI prompts.
# NOTE(review): within the cells shown here, `find_top_changes` is defined in a
# later cell, so this cell depends on that cell having been run first.
layer = 8
features = [16513, 10461]
sae = z_saes[layer]
dataset_prompts = gen_templated_prompts(template_idex=1, N=50)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]

all_results = []

for prompt in prompts:
    try:
        result_string, cut_prompt = find_top_changes(prompt, model, layer, features, sae, k=8, visualize=False, max_z=True)
    except Exception as err:
        # Best effort: report the failing prompt and carry on, instead of
        # silently swallowing everything (including KeyboardInterrupt).
        print(f"Skipping prompt {prompt!r}: {err!r}")
        continue

    # find_top_changes returns (None, "") when the strongest activation is on
    # the first position; there is nothing to report for such prompts.
    if result_string is None:
        continue

    all_results.append(f"Prompt = '{cut_prompt}'\n\n"+result_string)

# Combine all results into one string with \n\n\n between each result
all_results_string = "\n\n\n".join(all_results)
print("Successfully generated logit increase/decrease info.")

print(all_results_string)

In [None]:
# Distinct feature ids recorded for L8_H6 in the circuit hypergraph (-1 entries dropped).
features = [x for x in list(set(cp.circuit_hypergraph['L8_H6']['features'])) if x!=-1]
print(features)

num_examples = 2500
layer = 8
# features = [16513, 2623]
sae = z_saes[layer]

strategy = create_simple_greedy_strategy(
    passes=1,
    node_contributors=1,
    minimal=True,
)


# Now let's try it over a bunch of prompts
dataset_prompts = gen_templated_prompts(template_idex=1, N=8)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]

all_results = []

for prompt in prompts:
    result_string, cut_prompt = find_top_changes(prompt, model, layer, features, sae, k=8, visualize=False)
    # find_top_changes returns (None, "") when the strongest activation is on
    # the first position; concatenating None below would raise a TypeError.
    if result_string is None:
        continue
    all_results.append(f"Prompt = '{cut_prompt}'\n\n"+result_string)

# Combine all results into one string with \n\n\n between each result
all_results_string = "\n\n\n".join(all_results)
print("Successfully generated logit increase/decrease info.")

print(all_results_string)

# Larger IOI prompt set, tokenised for the max-activation analysis below.
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
prompts = [x['text'] + x['correct'] for x in dataset_prompts]
tokens = model.to_tokens(prompts)  # Assuming `model` is already defined
# Copy without torch.tensor(tensor), which emits a copy-construct warning.
# Assumes `to_tokens` returns a tensor (the original code already did) -- TODO confirm.
dataset_prompt_tokens = tokens.detach().clone()

mini_examples_owt_overall = []
mini_examples_ioi_overall = []

for feature in features:

    # No token_dataset given: MaxActAnalysis falls back to its own default data.
    analyze_owt = MaxActAnalysis(
        "attn", 
        layer, 
        feature, 
        num_sequences=num_examples, 
        batch_size=128, 
        strategy=strategy
    )
    mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_owt_overall.extend(mini_examples_owt)

    # For Dataset Prompt Tokens
    analyze_prompts = MaxActAnalysis(
        "attn", 
        layer, 
        feature, 
        num_sequences=num_examples, 
        batch_size=128, 
        strategy=strategy, 
        token_dataset=dataset_prompt_tokens
    )
    mini_examples_ioi = analyze_prompts.get_context_referenced_prompts_for_range(0, 5)
    mini_examples_ioi_overall.extend(mini_examples_ioi)

# NOTE(review): `l5h5_interp` is assigned in a later cell of this notebook, so
# this part only runs if that cell was executed first -- confirm the cell order.
incoming_information = [
    # ("L2H2", l2h2_interp),
    # ("L0H1", l0h1_interp),
    # ("L3H0", l3h0_interp),
    # ("L4H11", l4h11_interp),
    ("L5H5", l5h5_interp),
]

p = main_aug_interp_prompt_ioi_incoming(mini_examples_owt_overall, mini_examples_ioi_overall, incoming_information, [all_results_string])
# print(p)
interp = get_response(p)
print(interp)

In [None]:
# Hand-written interpretation of L5H5, supplied to later prompts as incoming
# information. The leading " \n" and trailing "\n" are kept so the text matches
# the original triple-quoted literal exactly.
l5h5_interp = (
    " \n"
    "Activates on the appearance of named entities, specifically indirect objects, "
    "influenced by a previous introduction of the same entity or entities in the text, "
    "and identifies the latter of these entities when they reappear together in a later context.\n"
)

In [None]:
# Upstream components whose interpretations are shown to the model as context.
# Only L5H5 is active; the commented-out entries are other candidates.
incoming_information = [
    # ("L2H2", l2h2_interp),
    # ("L0H1", l0h1_interp),
    # ("L3H0", l3h0_interp),
    # ("L4H11", l4h11_interp),
    ("L5H5", l5h5_interp),
]

# Build the prompt from the example lists and logit report left in notebook
# scope by an earlier cell, then ask the model for an interpretation.
p = main_aug_interp_prompt_ioi_incoming(mini_examples_owt_overall, mini_examples_ioi_overall, incoming_information, [all_results_string])
# print(p)
interp = get_response(p)
print(interp)

In [None]:
from typing import List, Tuple, Dict, Optional
import torch
from collections import Counter

# Function to generate logit increase/decrease information
def generate_logit_info(prompts: List[str], model, layer: int, features: List[int], sae, max_results: int = 8) -> str:
    """Build the logit increase/decrease report for a set of features.

    Each prompt is passed to ``find_top_changes``; for every prompt that yields
    a report, the last space-separated word of the returned cut prompt is
    wrapped in ``<<`` ``>>`` and placed above the report.

    Args:
        prompts: Prompt strings to try, in order.
        model: Model forwarded to ``find_top_changes``.
        layer: Layer index forwarded to ``find_top_changes``.
        features: Feature ids forwarded to ``find_top_changes``.
        sae: Encoder forwarded to ``find_top_changes``.
        max_results: Maximum number of per-prompt reports to keep (default 8).

    Returns:
        The kept reports joined by two blank lines; ``""`` when there are none.
    """
    all_results = []
    for prompt in prompts:
        # Only the first `max_results` reports are returned, so stop running
        # the model once they have been collected.
        if len(all_results) >= max_results:
            break

        try:
            result_string, cut_prompt = find_top_changes(prompt, model, layer, features, sae, k=8, visualize=False)
        except Exception as err:
            # Best effort: one failing prompt must not abort the whole report.
            print(f"Error for prompt {prompt}: {err!r}")
            continue

        # find_top_changes signals "nothing to report" with a None result.
        if result_string is None:
            continue

        # Surround the last word in cut prompt with << and >> for the logit info
        words = cut_prompt.split(" ")
        words[-1] = f"<<{words[-1]}>>"
        marked_prompt = " ".join(words)
        all_results.append(f"Prompt = '{marked_prompt}'\n\n" + result_string)

    return "\n\n\n".join(all_results)

# Function to generate mini examples for both general and IOI tasks
def generate_mini_examples(features: List[int], layer: int, num_examples: int, dataset_prompt_tokens, strategy) -> Tuple[List[str], List[str]]:
    """Collect annotated max-activation examples for each feature.

    For every feature two ``MaxActAnalysis`` runs are made: one without a
    ``token_dataset`` (the general examples) and one with
    ``dataset_prompt_tokens`` (the task examples). From each run the examples
    returned by ``get_context_referenced_prompts_for_range(0, 5)`` are kept.

    A failure in either run is reported and skipped, so one bad feature does
    not discard the examples gathered for the others.

    Returns:
        ``(general_examples, task_examples)``, each accumulated over all features.
    """
    mini_examples_owt_overall = []
    mini_examples_ioi_overall = []

    for feature in features:
        try:
            analyze_owt = MaxActAnalysis("attn", layer, feature, num_sequences=num_examples, batch_size=128, strategy=strategy)
            mini_examples_owt_overall.extend(analyze_owt.get_context_referenced_prompts_for_range(0, 5))
        except Exception as err:
            # Narrower than a bare `except:` and reports the cause.
            print(f"Error for OWT tokens with feature {feature}: {err!r}")

        try:
            analyze_prompts = MaxActAnalysis("attn", layer, feature, num_sequences=num_examples, batch_size=128, strategy=strategy, token_dataset=dataset_prompt_tokens)
            mini_examples_ioi_overall.extend(analyze_prompts.get_context_referenced_prompts_for_range(0, 5))
        except Exception as err:
            print(f"Error for IOI tokens with feature {feature}: {err!r}")

    return mini_examples_owt_overall, mini_examples_ioi_overall

# Function to interpret the neuron using the provided prompt
def interpret_neuron(mini_examples_owt: List[str], mini_examples_ioi: List[str], incoming_information: Optional[List[Tuple[str, str]]], logit_info: str) -> Tuple[str, str]:
    """Build the interpretation prompt and ask the model for an explanation.

    When ``incoming_information`` contains at least one entry, the prompt
    variant that lists upstream neurons is used; otherwise (empty list or
    ``None``) the plain IOI prompt is used.

    Returns:
        ``(response, prompt)``: the result of ``get_response`` and the prompt
        that produced it.
    """
    # Truthiness instead of len(): the argument is Optional and len(None) raises.
    if incoming_information:
        print("Using incoming information.")
        prompt = main_aug_interp_prompt_ioi_incoming(mini_examples_owt, mini_examples_ioi, incoming_information, [logit_info])
    else:
        print("Not using incoming information.")
        prompt = main_aug_interp_prompt_ioi(mini_examples_owt, mini_examples_ioi, [logit_info])
    return get_response(prompt), prompt

# Function to get the k (default 10) most frequently occurring features of a component
def get_top_features(component: str, cp, k: int = 10) -> List[int]:
    """Return the most frequent feature ids recorded for ``component``.

    Args:
        component: Key into ``cp.circuit_hypergraph``.
        cp: Object whose ``circuit_hypergraph[component]['features']`` is an
            iterable of feature ids; entries equal to -1 are ignored.
        k: Maximum number of features to return.

    Returns:
        Up to ``k`` feature ids, most frequent first; features with equal
        counts keep the order in which they were first seen.
    """
    feature_counts = Counter(
        feature for feature in cp.circuit_hypergraph[component]['features'] if feature != -1
    )
    return [feature for feature, _ in feature_counts.most_common(k)]

# Higher-level function to process components based on the graph notation
def process_components(graph: Dict[str, List[str]], model, z_saes, strategy, num_examples: int = 2500):
    """Interpret every component listed in ``graph`` and collect the results.

    For each key of ``graph`` (named like ``"L9_H9"``) this gathers the
    component's top features, a logit increase/decrease report, general and
    IOI max-activation examples, and then asks for an interpretation.

    Args:
        graph: Maps a component name to the names of the components it feeds
            into. It is only read; the caller's dict is left untouched.
        model: Model used for tokenisation and passed on to the helpers.
        z_saes: Indexable by layer number; ``z_saes[layer]`` is passed on as
            the encoder for that layer.
        strategy: Discovery strategy passed on to the max-activation analysis.
        num_examples: Number of sequences requested from the max-activation
            analysis.

    Returns:
        ``{component: {"interp": <model response>, "prompt": <prompt text>}}``.

    Note:
        The circuit data is read from the notebook-level ``cp`` variable, not
        from an argument.
    """
    # Initialize data structure to manage incoming information
    component_interpretations = {}

    # successor -> predecessors. Kept in a local mapping: writing these links
    # back into `graph` (as was done before) adds keys to the dict while it is
    # being iterated, which raises RuntimeError, and alters the caller's
    # adjacency lists.
    predecessors = {}

    # Generate prompts
    dataset_prompts = gen_templated_prompts(template_idex=1, N=50)
    prompts = [x['text'] + x['correct'] for x in dataset_prompts]

    # Process each component
    for component, next_components in graph.items():
        layer, head = component.split("_")
        layer = int(layer[1:])
        head = int(head[1:])
        features = get_top_features(component, cp)
        print(f"Doing {component} with features {features} for layer {layer} head {head}")

        sae = z_saes[layer]

        # Generate logit increase/decrease information
        logit_info = generate_logit_info(prompts, model, layer, features, sae)
        print(logit_info)
        print(f"Successfully generated logit increase/decrease info for {component}.")

        # Generate mini examples
        dataset_prompts_full = gen_templated_prompts(template_idex=1, N=500)
        full_prompts = [x['text'] + x['correct'] for x in dataset_prompts_full]
        tokens = model.to_tokens(full_prompts)
        # Copy without torch.tensor(tensor), which emits a copy-construct warning.
        # Assumes `to_tokens` returns a tensor (the original code already did) -- TODO confirm.
        dataset_prompt_tokens = tokens.detach().clone()

        # `num_examples` is the caller's value; it is no longer reset to 2500 here.
        mini_examples_owt, mini_examples_ioi = generate_mini_examples(features, layer, num_examples, dataset_prompt_tokens, strategy)

        # Prepare incoming information (currently disabled: always empty)
        incoming_information = [] #[(prev_component, component_interpretations[prev_component]) for prev_component in graph if component in graph[prev_component]]
        print(f"Incoming information for {component}:\n{incoming_information}\n")

        # Interpret the neuron
        interpretation, p = interpret_neuron(mini_examples_owt, mini_examples_ioi, incoming_information, logit_info)
        print(f"Interpretation for {component}:\n{interpretation}\n")
        print(f"Prompt for {component}:\n{p}\n")
        component_interpretations[component] = {"interp": interpretation, "prompt": p}

        # Record this component as a predecessor of each of its successors.
        for next_component in next_components:
            predecessors.setdefault(next_component, []).append(component)

    return component_interpretations

# Example graph notation: each key is a component to interpret and its value
# lists the components it feeds into. Only L9_H9 is enabled for this run; the
# commented-out entries are the other components of the graph.
graph_notation = {
    # "L2_H2": ["L5_H5"],
    # "L4_H11": ["L5_H5"],
    # "L0_H1": ["L5_H5"],
    # "L3_H0": ["L5_H5"],
    # "L5_H5": ["L8_H6"],
    # "L8_H6": ["L10_H7", "L9_H9"],
    "L9_H9": [],
    # "L10_H7": []
}

# # Example usage
# model = ...  # Your model here
# sae = z_saes[layer]
# Load the model and encoders on CPU; the third returned value is not needed here.
model, z_saes, _ = get_model_encoders('cpu')
# Rebinds the notebook-level `strategy` name to a single-pass greedy strategy.
strategy = create_simple_greedy_strategy(passes=1, node_contributors=1, minimal=True)
# Runs the full interpretation pipeline for the enabled components.
component_interpretations_l9h9 = process_components(graph_notation, model, z_saes, strategy)

In [None]:
# Show the model response stored for L9_H9 by the process_components run above.
print(component_interpretations_l9h9['L9_H9']['interp'])

In [None]:
# Inspect which fields are stored for L10_H7.
# NOTE(review): `component_interpretations` is not assigned at notebook level in
# the nearby cells (only `component_interpretations_l9h9` is); presumably it was
# left in the kernel by an earlier run -- confirm before relying on this cell.
component_interpretations['L10_H7'].keys()

In [None]:
# Manual override of the stored L3_H0 interpretation. The leading " \n" and
# trailing "\n" are kept so the text matches the original triple-quoted literal.
component_interpretations['L3_H0']['interp'] = (
    " \n"
    "[EXPLANATION]: Activates when the subject or indirect object from a prior clause reappears in the text.\n"
)

In [None]:
# Print the final explanation stored for every component.
# Descriptive loop names are used on purpose: the earlier `for k, v in ...`
# rebound the notebook-level `k` setting to a component name once this cell ran.
for component_name, record in component_interpretations.items():
    print(f"Component: {component_name}")
    interp_text = record['interp']
    # Keep only what follows the explanation marker, in either casing.
    for marker in ('[EXPLANATION]:', '[Explanation]:'):
        if marker in interp_text:
            print(f"Interpretation: {interp_text.split(marker)[-1].strip()}")
            break
    else:
        # No marker found: show the whole response instead of dropping it silently.
        print(f"Interpretation: {interp_text.strip()}")
    print("\n\n")

In [None]:
# Print the prompt that was stored for L8_H6, to inspect what the model was asked.
# Overwrites the notebook-level `p` used by earlier cells.
p = component_interpretations['L8_H6']['prompt']
print(p)

## Greater-than

In [None]:

from data.ioi_dataset import gen_templated_prompts
from aug_interp_prompts import main_aug_interp_prompt, main_aug_interp_prompt_v2
from openai_utils import gen_openai_completion, get_response
from autointerpretability import *
from discovery_strategies import (
    create_filter,
    create_simple_greedy_strategy,
    create_top_contributor_strategy,
)
from max_act_analysis import MaxActAnalysis

from autointerpretability import *

# From here on `cp` holds the circuit prediction for the 'gt' task (N=50),
# replacing whatever `cp` referred to in earlier cells.
cp = get_circuit_prediction(task='gt', N=50)

from collections import Counter, defaultdict

def get_top_k_feature_tuples_for_component(co_occurrence_dict, component_str, k=5):
    """Return the k most frequent co-occurring feature tuples for a component.

    Args:
        co_occurrence_dict: Maps a pair of component keys to an iterable of
            feature tuples. Component keys have the form
            ``('mlp_feature', layer)`` or ``('attn_head', layer, head)``.
        component_str: Component name such as ``"MLP3"`` or ``"L8H6"``. The
            underscore forms used for hypergraph keys elsewhere in this
            notebook (``"L8_H6"``, ``"MLP_3"``) are accepted as well.
        k: Number of ``(component pair, feature tuple)`` entries to keep.

    Returns:
        A ``defaultdict`` mapping each retained component pair to
        ``{feature_tuple: count}``.

    Raises:
        ValueError: If ``component_str`` is in neither supported format.
    """
    # Parse the component string to get the appropriate tuple key
    if component_str.startswith("MLP"):
        component = ('mlp_feature', int(component_str[3:].lstrip("_")))
    elif component_str.startswith("L") and "H" in component_str:
        layer_text, head_text = component_str[1:].split("H")
        # "L8_H6" leaves a trailing underscore on the layer part.
        component = ('attn_head', int(layer_text.rstrip("_")), int(head_text))
    else:
        raise ValueError(f"Invalid component format: {component_str}")

    # Count every (pair, feature tuple) occurrence in pairs that involve the component
    pair_feature_counts = Counter()
    for comp_pair, co_occurrences in co_occurrence_dict.items():
        comp1, comp2 = comp_pair
        if component in (comp1, comp2):
            pair_feature_counts.update(
                (comp_pair, feature_tuple) for feature_tuple in co_occurrences
            )

    # Regroup the top-k entries by component pair
    top_k_dict = defaultdict(dict)
    for (comp_pair, feature_tuple), count in pair_feature_counts.most_common(k):
        top_k_dict[comp_pair][feature_tuple] = count

    return top_k_dict

# Gradients are not needed for the analysis below.
torch.set_grad_enabled(False)
# Load the model together with its attention encoders and transcoders on CPU.
model, z_saes, transcoders = get_model_encoders('cpu')

In [None]:
# Walk the circuit hypergraph: report every entry with a positive frequency
# (name, frequency, number of distinct features other than -1) and remember
# the ones above `threshold` that have at least one such feature.
hypergraph = cp.circuit_hypergraph

threshold = 0.3
components = []

for key, value in hypergraph.items():
    # Entries that are not above zero are neither printed nor, with the
    # positive threshold used here, collected.
    if not value['freq'] > 0.0:
        continue

    distinct_features = {x for x in value['features'] if x != -1}
    print(key, value['freq'], len(distinct_features))

    if value['freq'] > threshold and len(distinct_features) > 0:
        components.append(key)


In [None]:
# Display the component names selected by the frequency threshold above.
components

In [None]:
from jinja2 import Template
from typing import List, Optional

def main_aug_interp_prompt_gt(
    examples: List[str], examples_gt: List[str], 
    logit_increase_decrease: Optional[str] = None,
    token_lr=("<<", ">>"), context_lr=("[[", "]]")
):
    """Render the auto-interpretation prompt for the Greater-Than (GT) task.

    Args:
        examples: General max-activating examples. The template reads each item
            as ``example[0]`` (base text) and ``example[1]`` (annotated text).
        examples_gt: GT-task examples, read the same way.
        logit_increase_decrease: Logit increase/decrease report(s). A single
            string or a list of strings is accepted; ``None`` renders nothing.
        token_lr: Delimiters placed around the activating token.
        context_lr: Delimiters placed around the context tokens.

    Returns:
        The rendered prompt text, ending with ``Step 1:`` for the model to
        continue.
    """
    tl, tr = token_lr
    cl, cr = context_lr

    # The template loops over this value, so the `None` default would raise
    # during rendering and a bare string would be rendered one character per
    # line; normalise both to a list.
    if logit_increase_decrease is None:
        logit_increase_decrease = []
    elif isinstance(logit_increase_decrease, str):
        logit_increase_decrease = [logit_increase_decrease]

    template = Template(
        """
{# You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and provide an explanation that thoroughly encapsulates its behavior in the context of a specific task: Greater-Than Prediction (GT). Here's how you will complete this task: #}

You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. This language model is trained to predict the text that will follow a given input. Your task is to figure out what sort of behavior this neuron is responsible for -- namely, when this neuron fires, what kind of predictions does this neuron promote in the context of the specific task of Greater-Than Prediction (GT)? Here's how you'll complete the task:

INPUT_DESCRIPTION:
You will be given several examples of text that activate the neuron. First we'll provide the example text without any annotations, and then we'll provide the same text with annotations that show the specific tokens that caused the neuron to activate and context about why the neuron fired.

The specific token that the neuron activates on will be the last token in the sequence, and will appear between {{tl}} and {{tr}} (like {{tl}}this{{tr}}).

Additionally, each sequence will have tokens enclosed between {{cl}} and {{cr}} (like {{cl}}this{{cr}}). From previous analysis, we know that these tokens form the context for why our neuron fires on the token enclosed in {{tl}} and {{tr}} (in addition to the value of the actual token itself). Note that we treat the group of tokens enclosed between {{cl}} and {{cr}} as the "context" for why the neuron fired.

We will provide both general examples and specific examples related to the task of Greater-Than Prediction (GT).

Task Description: The greater-than task involves predicting a year that is greater than a given year in the context of sentences framed like “The <noun> lasted from the year XXYY to the year XX”. The initial part of the sentence introduces a time span, and the model's task is to assign higher probabilities to years that are greater than YY. Each example is designed to have at least one correct and one incorrect validly tokenized answer.

Given these examples, complete the following steps.

OUTPUT_DESCRIPTION:

Step 1: Based on the general examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 2: Based on the general examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 3: Write down several general shared features of the general text examples.
Step 4: Based on the GT examples provided, write down observed patterns between the tokens that caused the neuron to activate (just the tokens enclosed in {{tl}} and {{tr}}).
Step 5: Based on the GT examples provided, write down patterns you see in the context for why the neuron fired. (Remember, the "context" for an example is the group of tokens enclosed in {{cl}} and {{cr}}). Include any patterns in the relationships between different tokens in the context, and any patterns in the relationship between the context and the rest of the text.
Step 6: Write down several general shared features of the GT text examples.
Step 7: Based on the logit increase/decrease information provided, analyze how the neuron's activation affects the probability of certain tokens being predicted next.
Step 8: Based on the patterns you found between the activating token and the relevant context in both general and GT examples, and the logit increase/decrease information, write down your best explanation for (1) which exact tokens the neuron fires on, and (2) how this firing influences the prediction of the NEXT token. Propose your explanation in the following form:
[EXPLANATION]: <your explanation>

Guidelines:
- Try to produce a final explanation that's both concise and general to the examples provided.
- Your explanation should be short: 1-2 sentences.
- Specifically address the neuron's role in the context of the GT task, explaining its specific function in relation to predicting the greater-than year.
- Use the logit increase/decrease information to explain the neuron's effect on the next token prediction.
- When looking at the context tokens that contributed to the neuron activating on the activation token, think about how they relate to the activation token. Has the neuron seen these tokens before? How are they related to activation, and how are they related to the token that might come next?
- Not all of the information will be useful! Sometimes, the neuron may just be looking at a specific previous token, or alternatively may not care about previous tokens and is just doing something to the prediction.
- Your final explanation should be a combination of (1) what tokens the neuron specifically activates on, and (2) how this neuron then affects the next token.

INPUT:

General Examples for tokens the neuron activates on ({{tl}} {{tr}}), and the other tokens that provide context ({{cl}} {{cr}}):
{% for example in examples %}
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

GT Task Examples:
{% for example in examples_gt %}
EXAMPLE {{loop.index}}:
- Base Text -
=================================================
{{example[0]}}
=================================================

- Annotated Text -
=================================================
{{example[1]}}
=================================================

{% endfor %}

Use these examples to determine what tokens the neuron fires/activates on, and the surrounding context (tokens) that also contribute to this.

Logit Increase/Decrease Information:
Below is information on how this particular neuron increases the probability of certain tokens (i.e. their logits) and decreases the probability of others, when predicting the next token in the prompt i.e. after the final word in the prompt provided. This information will help you discern the neuron's effect on the NEXT token to predict after the token it activates on, which it doesn't actually see in the examples.

{% for x in logit_increase_decrease %}
{{x}}
{% endfor %}

Use this information to help interpret the effect of the neuron's activation on the model's predictions. Be specific about whether it boosts the logits of the greater-than year, the context tokens, or a different type of token, and similarly whether it decreases the logits of the greater-than year, the context tokens, or a different type of token.

OUTPUT:

Step 1:
"""
    )

    return template.render(
        {"tl": tl, "tr": tr, "cl": cl, "cr": cr, "examples": examples, "examples_gt": examples_gt, 
         "logit_increase_decrease": logit_increase_decrease}
    )


In [None]:
from functools import partial
import transformer_lens.utils as utils
import numpy as np
import einops
from tabulate import tabulate
from termcolor import colored

def pretty_print_results(top_k_increases_indices, top_k_increases, top_k_decreases_indices, top_k_decreases, model, k, visualize):
    """Format the top-k logit increases and decreases as two text tables.

    The first ``k`` entries of each index/value pair are turned into
    ``[token text, value with two decimals]`` rows, with the token text
    obtained from ``model.to_string``. When ``visualize`` is true the tables
    are also printed under coloured headings.

    Returns:
        One string containing both tables under plain-text headings.
    """
    def _table_rows(token_indices, deltas):
        # One [token, "x.xx"] row per rank, for the first k ranks.
        return [
            [model.to_string([token_indices[rank]]), f"{deltas[rank]:.2f}"]
            for rank in range(k)
        ]

    increases_table = tabulate(
        _table_rows(top_k_increases_indices, top_k_increases),
        headers=["Token", "Increase"],
        tablefmt="pretty",
    )
    decreases_table = tabulate(
        _table_rows(top_k_decreases_indices, top_k_decreases),
        headers=["Token", "Decrease"],
        tablefmt="pretty",
    )

    if visualize:
        sections = (
            ("\nTop k Increases:", 'green', increases_table),
            ("\nTop k Decreases:", 'red', decreases_table),
        )
        for heading, colour, table in sections:
            print(colored(heading, colour))
            print(table)

    return f"Top k Increases:\n{increases_table}\n\nTop k Decreases:\n{decreases_table}"

def find_top_changes(prompt, model, layer, features, sae, k=10, visualize=True, max_z=True):
    """Report which next-token logits change most when the given SAE features are ablated.

    The prompt is run once to find the position where the selected features are
    most active. The prompt is then truncated to end at that position and the
    model is run twice on the truncated tokens: unmodified, and with a hook that
    replaces the layer's attention ``z`` by the SAE decoding of its encoding with
    ``features`` zeroed. The final-position logit difference (clean - patched)
    is ranked to produce the top-k increases and decreases.

    Args:
        prompt: Text to analyse (a single prompt; positions are found on the
            flattened ``(batch * seq)`` axis).
        model: Model exposing ``to_tokens``, ``to_string``, ``run_with_cache``,
            ``run_with_hooks`` and a forward call returning logits.
        layer: Layer whose attention ``z`` activation is encoded and patched.
        features: Indices of the encoder features to inspect and ablate.
        sae: Encoder object exposing ``encode`` and ``decode``.
        k: Number of tokens to report in each direction (capped at the vocab size).
        visualize: Forwarded to ``pretty_print_results`` to control printing.
        max_z: If true, combine the selected features per position with a max;
            otherwise sum them.

    Returns:
        ``(result_string, cut_prompt)``: the formatted tables and the truncated
        prompt text. Returns ``(None, "")`` when the strongest activation is at
        position 0.

    Notes:
        * The truncated tokens start at index 1, so the first token is dropped
          before the clean/patched runs (presumably the BOS token added by
          ``to_tokens`` - TODO confirm this is intended).
        * The patched run substitutes the SAE reconstruction for ``z``, while the
          clean run uses the model's own ``z``; the reported difference therefore
          also contains whatever the reconstruction changes on its own.
        * NOTE(review): the hook always targets attention ``z``, yet callers in
          this notebook also pass MLP transcoders as ``sae`` - confirm that is
          what is wanted for MLP components.
    """
    tokens = model.to_tokens(prompt)

    # Encode the layer's attention output, one row per (batch, position).
    _, cache = model.run_with_cache(tokens)
    z = cache["z", layer, "attn"]
    flat_z = einops.rearrange(z, "b s n d -> (b s) (n d)")
    feature_acts = sae.encode(flat_z)[:, features]

    # Collapse the selected features into a single score per position.
    if max_z:
        position_acts, _ = torch.max(feature_acts, dim=1)
    else:
        position_acts = feature_acts.sum(dim=1)

    max_act_idx = int(position_acts.argmax())

    # Nothing useful to report when the peak is on the very first token.
    if max_act_idx == 0:
        return None, ""

    # Keep tokens 1..max_act_idx so the prompt ends on the most active position.
    tokens = tokens[:, 1:max_act_idx + 1]
    cut_prompt = model.to_string(tokens)[0]

    # Only the logits are needed here, so skip building an activation cache.
    clean_logits = model(tokens, return_type="logits")

    def ablate_features(activation, hook):
        # Round-trip z through the encoder with the chosen features switched off.
        batch, seq, n_heads, d_head = activation.shape
        hidden = sae.encode(einops.rearrange(activation, "b s n d -> (b s) (n d)"))
        hidden[:, features] = 0.0
        reconstructed = sae.decode(hidden)
        return einops.rearrange(
            reconstructed, "(b s) (n d) -> b s n d", b=batch, s=seq, n=n_heads, d=d_head
        )

    patched_logits = model.run_with_hooks(
        tokens,
        fwd_hooks=[(utils.get_act_name("z", layer, "attn"), ablate_features)],
        return_type="logits"
    )

    # Index explicitly instead of squeeze(): a one-token prompt would otherwise
    # lose its sequence axis and make the final-position lookup fail.
    difference = clean_logits[0, -1, :] - patched_logits[0, -1, :]

    # torch.topk returns values already sorted (descending for the largest,
    # ascending for the smallest) and works on any device.
    k = min(k, difference.shape[-1])
    top_increases, top_increase_ids = torch.topk(difference, k, largest=True, sorted=True)
    top_decreases, top_decrease_ids = torch.topk(difference, k, largest=False, sorted=True)

    result_string = pretty_print_results(
        top_increase_ids.tolist(),
        top_increases.tolist(),
        top_decrease_ids.tolist(),
        top_decreases.tolist(),
        model,
        k,
        visualize=visualize,
    )
    return result_string, cut_prompt

# Example usage: ablate one feature at layer 9 on a greater-than style prompt and
# show the truncated prompt together with the resulting logit-change tables.
# NOTE(review): find_top_changes hooks the attention `z` activation, while the
# encoder passed here is the layer's transcoder rather than `z_saes[layer]` -
# confirm which encoder is intended.
# `model` and `transcoders` must already exist in the kernel when this cell runs.
prompt = "The war lasted from the year 1746 to the year 17"
layer = 9
features = [5463]
# sae = z_saes[layer]
transcoder = transcoders[layer]
result_string, cut_prompt = find_top_changes(prompt, model, layer, features, transcoder)
print(cut_prompt)
print(result_string)

In [None]:
# Distinct feature ids recorded for MLP9 in the circuit hypergraph, with the -1
# entries filtered out (presumably a "no feature" placeholder - TODO confirm).
# `cp` must already exist in the kernel.
[x for x in list(set(cp.circuit_hypergraph["MLP9"]['features'])) if x != -1]

In [None]:
# Build 50 greater-than examples and turn each into a plain prompt: keep only the
# text after the last '<|endoftext|>' marker and append the correct completion.
# NOTE: this rebinds the notebook-level `dataset_prompts`.
dataset_prompts = generate_greater_than_dataset(N=50)
prompts = [x['text'].split('<|endoftext|>')[-1] + x['correct'] for x in dataset_prompts]
# Display the first prompt as the cell output.
prompts[0]

In [None]:
from typing import List, Tuple, Dict, Optional
import torch
from collections import Counter

from data.greater_than_dataset import (
    generate_greater_than_dataset,
    GT_GROUND_TRUTH_HEADS,
)

# Function to generate logit increase/decrease information
def generate_logit_info(prompts: List[str], model, layer: int, features: List[int], sae, max_results: int = 10, k: int = 8) -> str:
    """Collect logit increase/decrease reports for the first prompts that yield one.

    Each prompt is passed to ``find_top_changes``; prompts that raise or that
    return no result are skipped. Processing stops as soon as ``max_results``
    reports have been gathered, so the remaining prompts are never run through
    the model.

    Args:
        prompts: Candidate prompt strings, tried in order.
        model: Forwarded to ``find_top_changes``.
        layer: Forwarded to ``find_top_changes``.
        features: Feature indices forwarded to ``find_top_changes``.
        sae: Encoder forwarded to ``find_top_changes``.
        max_results: Maximum number of reports to include.
        k: Number of tokens per direction requested from ``find_top_changes``.

    Returns:
        The reports joined by two blank lines. Each report starts with the
        truncated prompt, whose last space-separated word is wrapped in
        ``<<`` and ``>>``.
    """
    all_results = []
    for prompt in prompts:
        if len(all_results) >= max_results:
            break

        try:
            result_string, cut_prompt = find_top_changes(prompt, model, layer, features, sae, k=k, visualize=False)
        except Exception as err:
            # Best effort: report the failure and move on, but let
            # KeyboardInterrupt/SystemExit propagate.
            print(f"Error for prompt {prompt}: {err!r}")
            continue

        if result_string is None:
            continue

        # Mark the last word of the truncated prompt for the logit info.
        words = cut_prompt.split(" ")
        words[-1] = f"<<{words[-1]}>>"
        marked_prompt = " ".join(words)
        all_results.append(f"Prompt = '{marked_prompt}'\n\n" + result_string)

    return "\n\n\n".join(all_results)

# Function to generate mini examples for both general and IOI tasks
def generate_mini_examples(component_type: str, features: List[int], layer: int, num_examples: int, dataset_prompt_tokens, strategy) -> Tuple[List[str], List[str]]:
    """Gather context-referenced example prompts for each feature from two sources.

    For every feature, two ``MaxActAnalysis`` runs are attempted: one without a
    ``token_dataset`` (the analysis' default data - presumably OpenWebText, given
    the naming; TODO confirm) and one over ``dataset_prompt_tokens``. A failure
    in either run is reported and skipped so the other features still contribute.

    Args:
        component_type: ``"attn"`` selects the attention analysis; any other
            value selects the ``"mlp"`` analysis.
        features: Feature indices to analyse.
        layer: Layer the features belong to.
        num_examples: Passed to ``MaxActAnalysis`` as ``num_sequences``.
        dataset_prompt_tokens: Token dataset for the task-specific run.
        strategy: Discovery strategy forwarded to ``MaxActAnalysis``.

    Returns:
        ``(default_dataset_examples, task_dataset_examples)``, each the
        concatenation over features of
        ``get_context_referenced_prompts_for_range(0, 5)``.
    """
    analysis_kind = "attn" if component_type == "attn" else "mlp"

    mini_examples_owt_overall = []
    mini_examples_task_overall = []

    for feature in features:
        try:
            analyze_owt = MaxActAnalysis(analysis_kind, layer, feature, num_sequences=num_examples, batch_size=128, strategy=strategy)
            mini_examples_owt_overall.extend(analyze_owt.get_context_referenced_prompts_for_range(0, 5))
        except Exception as err:
            # Best effort per feature; KeyboardInterrupt/SystemExit still propagate.
            print(f"Error for OWT tokens with feature {feature}: {err!r}")

        try:
            analyze_prompts = MaxActAnalysis(analysis_kind, layer, feature, num_sequences=num_examples, batch_size=128, strategy=strategy, token_dataset=dataset_prompt_tokens)
            mini_examples_task_overall.extend(analyze_prompts.get_context_referenced_prompts_for_range(0, 5))
        except Exception as err:
            print(f"Error for IOI tokens with feature {feature}: {err!r}")

    return mini_examples_owt_overall, mini_examples_task_overall

# Function to interpret the neuron using the provided prompt
def interpret_neuron(mini_examples_owt: List[str], mini_examples_ioi: List[str], incoming_information: Optional[List[Tuple[str, str]]], logit_info: str) -> Tuple[str, str]:
    """Build the interpretation prompt for a component and query the LLM with it.

    Args:
        mini_examples_owt: Example prompts from the general dataset.
        mini_examples_ioi: Example prompts from the task dataset.
        incoming_information: Optional ``(component, interpretation)`` pairs for
            upstream components. ``None`` and an empty list both select the
            prompt variant that ignores incoming information.
        logit_info: Logit increase/decrease report; passed to the prompt builder
            wrapped in a one-element list.

    Returns:
        ``(response, prompt)``: the result of ``get_response`` and the prompt
        that was sent.
    """
    # Truthiness covers both None and [] (len() would raise on None).
    if incoming_information:
        print("Using incoming information.")
        prompt = main_aug_interp_prompt_ioi_incoming(mini_examples_owt, mini_examples_ioi, incoming_information, [logit_info])
    else:
        print("Not using incoming information.")
        prompt = main_aug_interp_prompt_gt(mini_examples_owt, mini_examples_ioi, [logit_info])
    return get_response(prompt), prompt

# Function to get the most frequently occurring features (10 by default)
def get_top_features(component: str, cp, top_n: int = 10) -> List[int]:
    """Return the feature ids that occur most often for ``component``.

    Args:
        component: Key into ``cp.circuit_hypergraph``.
        cp: Object whose ``circuit_hypergraph[component]['features']`` is an
            iterable of feature ids.
        top_n: Maximum number of feature ids to return.

    Returns:
        Up to ``top_n`` feature ids ordered from most to least frequent, with
        entries equal to -1 ignored (presumably a "no feature" placeholder -
        TODO confirm).
    """
    feature_counts = Counter(
        feature
        for feature in cp.circuit_hypergraph[component]['features']
        if feature != -1
    )
    return [feature for feature, _ in feature_counts.most_common(top_n)]

# Higher-level function to process components based on the graph notation
def process_components(graph: Dict[str, List[str]], model, z_saes, trancoders, strategy, num_examples: int = 2500, task='gt'):
    """Generate an LLM interpretation for every component in ``graph``.

    For each component the function picks its most frequent features, gathers
    logit increase/decrease information and example prompts, and asks the LLM
    for an interpretation.

    Args:
        graph: Maps a component name to the components it feeds into. Names
            starting with ``'L'`` are parsed as ``'L<layer>_H<head>'`` attention
            heads; any other name is parsed as ``'MLP<layer>'``. The mapping is
            not modified.
        model: Model used for tokenisation and for the logit analysis.
        z_saes: Per-layer encoders used for attention components.
        trancoders: Per-layer transcoders used for MLP components (the
            parameter name keeps its original spelling for compatibility).
        strategy: Discovery strategy forwarded to the example generation.
        num_examples: Number of sequences requested per feature analysis.
        task: ``'gt'`` (greater-than) or ``'ioi'``; selects the prompt dataset.

    Returns:
        Dict mapping each component name to
        ``{"interp": <LLM response>, "prompt": <prompt sent>}``.

    Raises:
        ValueError: If ``task`` is neither ``'ioi'`` nor ``'gt'``.

    Notes:
        Feature selection reads the notebook-level ``cp`` object, which must
        exist before this function is called.
    """
    component_interpretations = {}

    # Build the task dataset once; it does not depend on the component.
    if task == 'ioi':
        dataset_prompts = gen_templated_prompts(template_idex=1, N=250)
        prompts = [x['text'] + x['correct'] for x in dataset_prompts]
    elif task == 'gt':
        dataset_prompts = generate_greater_than_dataset(N=250)
        # Prompts for the logit analysis omit the correct completion.
        prompts = [x['text'].split('<|endoftext|>')[-1] for x in dataset_prompts]
    else:
        raise ValueError(f"Invalid task: {task}")

    # Token dataset for the example search: full text plus the correct completion.
    full_prompts = [x['text'] + x['correct'] for x in dataset_prompts]
    dataset_prompt_tokens = model.to_tokens(full_prompts)

    # Predecessor bookkeeping lives in a local map so the caller's graph is never
    # mutated (adding keys to `graph` while iterating it raises RuntimeError).
    predecessors: Dict[str, List[str]] = {}

    for component, next_components in list(graph.items()):
        is_attn = component[0] == 'L'
        component_type = "attn" if is_attn else "mlp"
        features = get_top_features(component, cp)

        if is_attn:
            layer_part, head_part = component.split("_")
            layer = int(layer_part[1:])
            head = int(head_part[1:])
            print(f"Doing {component} with features {features} for layer {layer} head {head} with ZSAEs")
            sae = z_saes[layer]
        else:
            layer = int(component[3:])
            print(f"Doing {component} with features {features} for layer {layer} with transcoders")
            # Use the argument that was passed in rather than a notebook global.
            sae = trancoders[layer]

        # Generate logit increase/decrease information
        logit_info = generate_logit_info(prompts, model, layer, features, sae)
        print(logit_info)
        print(f"Successfully generated logit increase/decrease info for {component}.")

        # Generate mini examples
        mini_examples_owt, mini_examples_ioi = generate_mini_examples(component_type, features, layer, num_examples, dataset_prompt_tokens, strategy)

        # Incoming information is currently disabled. To re-enable it, build the
        # list from `predecessors.get(component, [])` and `component_interpretations`.
        incoming_information = []
        print(f"Incoming information for {component}:\n{incoming_information}\n")

        # Interpret the neuron
        interpretation, p = interpret_neuron(mini_examples_owt, mini_examples_ioi, incoming_information, logit_info)
        print(f"Interpretation for {component}:\n{interpretation}\n")
        print(f"Prompt for {component}:\n{p}\n")
        component_interpretations[component] = {"interp": interpretation, "prompt": p}

        # Record this component as a predecessor of everything it feeds into.
        for next_component in next_components:
            predecessors.setdefault(next_component, []).append(component)

    return component_interpretations

# Example graph notation: each key is a component name and its value lists the
# components it feeds into. Attention heads are written 'L<layer>_H<head>' and
# MLPs 'MLP<layer>' (the formats process_components parses). Only 'MLP4' is
# active; the commented-out names are other candidate components and would need
# to be written as `'NAME': [...]` entries to be enabled.
graph_notation = {
    # 'L0_H1',
    # 'L0_H3',
    # 'MLP0',
    # 'L1_H0',
    # 'L1_H10',
    # 'MLP1',
    # 'MLP2',
    # 'MLP3',
    # 'MLP4',
    # 'L5_H5',
    # 'L8_H7',
    'MLP4': [],
    # 'MLP9',
    # 'MLP10',
    # 'MLP11',
}

# # Example usage
# model = ...  # Your model here
# sae = z_saes[layer]
# Load the model with its encoders ('cpu' is presumably the target device), build
# a single-pass greedy strategy and interpret every component in graph_notation.
# NOTE: this rebinds the notebook-level names `model`, `z_saes`, `transcoders`
# and `strategy`; process_components also reads the notebook-level `cp`.
model, z_saes, transcoders = get_model_encoders('cpu')
strategy = create_simple_greedy_strategy(passes=1, node_contributors=1, minimal=True)
component_interpretations_mlp1 = process_components(graph_notation, model, z_saes, transcoders, strategy)

In [None]:
# Show the LLM's interpretation text for the MLP4 component.
print(component_interpretations_mlp1['MLP4']['interp'])

In [None]:
# Show the prompt that produced the MLP4 interpretation. 'MLP4' is the only
# component in graph_notation, so it is the only key in the result dict
# (indexing 'MLP9' raises KeyError on a fresh run).
print(component_interpretations_mlp1['MLP4']['prompt'])

In [None]:
from data.ioi_dataset import gen_templated_prompts
from aug_interp_prompts import main_aug_interp_prompt, main_aug_interp_prompt_v2
from openai_utils import gen_openai_completion, get_response
from autointerpretability import *
from discovery_strategies import (
    create_filter,
    create_simple_greedy_strategy,
    create_top_contributor_strategy,
)
from max_act_analysis import MaxActAnalysis

# Run a max-activation analysis for a single MLP feature at layer 9 and collect
# its context-referenced example prompts for the range (0, 5).
# NOTE: this rebinds the notebook-level `layer`, `feature`, `num_examples`,
# `strategy` and `mini_examples_owt`.
layer = 9
feature = 12072#, 5463, 15687
num_examples=1500
strategy = create_simple_greedy_strategy(passes=1, node_contributors=1, minimal=True)
analyze_owt = MaxActAnalysis("mlp", layer, feature, num_sequences=num_examples, batch_size=128, strategy=strategy)
mini_examples_owt = analyze_owt.get_context_referenced_prompts_for_range(0, 5)

In [None]:
# Print the final explanation for each interpreted component: everything after
# the last '[EXPLANATION]: ' marker in the LLM response.
# Descriptive loop names are used because loop variables leak into the notebook
# namespace; `k` would otherwise overwrite the config value `k` set earlier.
# `component_interpretations_l9h9` must already exist in the kernel.
for component_name, component_result in component_interpretations_l9h9.items():
    explanation = component_result['interp'].split('[EXPLANATION]: ')[-1].strip()
    print(f"Component: {component_name}")
    print(f"Interpretation: {explanation}")
    print("\n\n")