In [None]:
import torch
import einops

# autoreload
%load_ext autoreload
%autoreload 2

from circuit_lens import CircuitLens
from circuit_discovery import CircuitDiscovery

In [None]:
# Prompt analysed in this first section of the notebook.
prompt = "When John and Mary went to the store, John gave the bag to Mary"
# Circuit-discovery object built for that single prompt.
cd = CircuitDiscovery(prompt=prompt)

In [None]:
# Build the directed graph on `cd`, requesting 4 contributors per node.
cd.build_directed_graph(contributors_per_node=4)

In [None]:
def get_directed_graph_edges_with_weights(cd):
    """Flatten the edge tracker's receiver -> contributors mapping into edge triples.

    Args:
        cd: object exposing ``transformer_model.edge_tracker._reciever_to_contributors``,
            a mapping ``receiver -> {contributor: weight}``.

    Returns:
        A list of ``(contributor, receiver, weight)`` tuples, in the iteration
        order of the underlying mappings.
    """
    receiver_map = cd.transformer_model.edge_tracker._reciever_to_contributors
    return [
        (source, target, strength)
        for target, sources in receiver_map.items()
        for source, strength in sources.items()
    ]

# Extract every (contributor, receiver, weight) edge from the graph built above
edges_with_weights = get_directed_graph_edges_with_weights(cd)
# Total number of edges found
print(len(edges_with_weights))

In [None]:
# Type tags (first element of each node tuple) that appear as edge sources...
unique_names_src = {edge[0][0] for edge in edges_with_weights}
# ...and as edge targets
unique_names_trgt = {edge[1][0] for edge in edges_with_weights}
# Every type tag that appears on either end of an edge
unique_names = unique_names_src | unique_names_trgt
print(unique_names)

In [None]:
def print_connectivity_stats(edges_with_weights):
    """Print how many edges fall into each source-type -> target-type bucket.

    Edges are ``(source, target, weight)`` triples whose node tuples start with
    a type tag. Edges between ``attn_head`` and ``mlp_feature`` nodes are
    counted per pairing; everything else is counted as ``other``.
    """
    short_names = {"attn_head": "attn", "mlp_feature": "ff"}
    tally = {
        "attn_to_attn": 0,
        "attn_to_ff": 0,
        "ff_to_attn": 0,
        "ff_to_ff": 0,
        "other": 0,
    }

    for edge in edges_with_weights:
        src_kind = short_names.get(edge[0][0])
        # Only inspect the target when the source is one of the tracked types.
        dst_kind = short_names.get(edge[1][0]) if src_kind is not None else None
        if src_kind is None or dst_kind is None:
            tally["other"] += 1
        else:
            tally[f"{src_kind}_to_{dst_kind}"] += 1

    # dict preserves insertion order, so the report order is fixed.
    for bucket, count in tally.items():
        print(f"{bucket}: {count}")

# Report the edge counts per type pairing for the raw graph edges.
print_connectivity_stats(edges_with_weights)

In [None]:
from collections import defaultdict
import networkx as nx

def filter_and_aggregate_edges(edges_with_weights):
    """Keep attention-head / MLP-feature edges and add two-hop shortcut edges.

    Args:
        edges_with_weights: iterable of ``(src, dst, weight)`` triples whose
            node tuples start with a type tag.

    Returns:
        The ``edges(data=True)`` view of a new ``nx.DiGraph`` containing
        * every kept direct edge whose weight is not ``None``, and
        * for every two-hop path ``a -> b -> c`` where both weights are not
          ``None``, an edge ``a -> c`` weighted by the product of the two.

    NOTE(review): ``add_edge`` replaces the weight of an edge that already
    exists, so when several routes connect the same pair the last one written
    wins (iteration-order dependent). Confirm whether summing was intended.

    NOTE(review): a later cell defines another function with this same name,
    which shadows this one once that cell has run.
    """
    keep_types = ('attn_head', 'mlp_feature')

    # Graph restricted to edges whose two endpoints are both kept types.
    graph = nx.DiGraph()
    for src, dst, weight in edges_with_weights:
        if src[0] in keep_types and dst[0] in keep_types:
            graph.add_edge(src, dst, weight=weight)

    aggregated_graph = nx.DiGraph()
    for node in graph.nodes:
        for neighbor in graph.successors(node):
            first_weight = graph[node][neighbor]['weight']
            if first_weight is None:
                # Without a first-hop weight there is neither a direct edge nor
                # a product to form (multiplying None would raise TypeError).
                continue
            aggregated_graph.add_edge(node, neighbor, weight=first_weight)

            for next_neighbor in graph.successors(neighbor):
                second_weight = graph[neighbor][next_neighbor]['weight']
                if second_weight is None:
                    continue
                aggregated_graph.add_edge(
                    node, next_neighbor, weight=first_weight * second_weight
                )

    return aggregated_graph.edges(data=True)

# Filter the raw edges to attention-head / MLP-feature nodes and add the aggregated edges.
aggregated_edges = filter_and_aggregate_edges(edges_with_weights)
print(len(aggregated_edges))

# Each item is a (src, dst, data) triple; data holds the 'weight' attribute.
for edge in aggregated_edges:
    print(edge)

In [None]:
# Reduce every node id to the fields we aggregate on:
#   attention heads -> (type, layer, head), MLP features -> (type, layer).
# Any other node type is passed through unchanged.
cleaned_edges = []
for src, dst, data in aggregated_edges:
    src = (
        (src[0], src[1], src[2]) if src[0] == 'attn_head'
        else (src[0], src[1]) if src[0] == 'mlp_feature'
        else src
    )
    dst = (
        (dst[0], dst[1], dst[2]) if dst[0] == 'attn_head'
        else (dst[0], dst[1]) if dst[0] == 'mlp_feature'
        else dst
    )
    cleaned_edges.append((src, dst, data['weight']))

print(f"Length of cleaned edges: {len(cleaned_edges)}")
for edge in cleaned_edges:
    print(edge)

In [None]:
# Cleaning can map several edges onto the same (src, dst) pair; add their weights together.
cleaned_edges_dict = {}
for src, dst, weight in cleaned_edges:
    try:
        cleaned_edges_dict[(src, dst)] += weight
    except KeyError:
        # First time this pair is seen.
        cleaned_edges_dict[(src, dst)] = weight

# Number of distinct (src, dst) pairs
print(f"Length of cleaned edges: {len(cleaned_edges_dict)}")

cleaned_edges_dict

In [None]:
# Count the de-duplicated edges per (source type, destination type) pairing.
attn_to_attn = attn_to_mlp = mlp_to_attn = mlp_to_mlp = 0

for (src, dst), weight in cleaned_edges_dict.items():
    kinds = (src[0], dst[0])
    if kinds == ('attn_head', 'attn_head'):
        attn_to_attn += 1
    elif kinds == ('attn_head', 'mlp_feature'):
        attn_to_mlp += 1
    elif kinds == ('mlp_feature', 'attn_head'):
        mlp_to_attn += 1
    elif kinds == ('mlp_feature', 'mlp_feature'):
        mlp_to_mlp += 1

print(f"Attn to Attn: {attn_to_attn}")
print(f"Attn to MLP: {attn_to_mlp}")
print(f"MLP to Attn: {mlp_to_attn}")
print(f"MLP to MLP: {mlp_to_mlp}")

## Greedy paths to build adjacency matrix

In [None]:
from circuit_discovery import CircuitDiscovery, CircuitDiscoveryHeadNode, CircuitDiscoveryRegularNode

# Autoreload
%load_ext autoreload
%autoreload 2

In [None]:
# Same prompt as in the first section; a fresh CircuitDiscovery is created for the greedy experiment.
prompt = "When John and Mary went to the store, John gave the bag to Mary"
cd = CircuitDiscovery(prompt=prompt)

In [None]:
# Assume cd is an instance of the CircuitDiscovery class
M = 100  # Number of times to call the greedy function
k = 5  # Top k contributors at each step

# Grow the circuit by calling the greedy step M times.
for _ in range(M):
    cd.greedily_add_top_contributors(k=k)

In [None]:
# Functions to collect nodes and edges
def collect_nodes_and_edges(cd):
    """Walk the discovered circuit and gather its node ids and weighted edges.

    Args:
        cd: circuit-discovery object providing ``traverse_graph`` and
            ``transformer_model.edge_tracker._reciever_to_contributors``.

    Returns:
        ``(nodes, edges)`` where ``nodes`` is the set of visited ``tuple_id``
        values and ``edges`` is a list of
        ``(contributor_id, receiver_id, weight)`` triples. For head nodes the
        weight is looked up under the per-head-type id ("q", "k", "v"), but
        the edge itself is recorded against the head's plain ``tuple_id``.
    """
    visited_ids = set()
    edge_list = []

    def lookup_weight(receiver_key, contributor_key):
        # Resolved lazily on every call so the tracker is read at visit time.
        return cd.transformer_model.edge_tracker._reciever_to_contributors[receiver_key][contributor_key]

    def visit_node(node):
        visited_ids.add(node.tuple_id)

        if isinstance(node, CircuitDiscoveryRegularNode):
            for contributor in node.contributors_in_graph:
                edge_list.append(
                    (contributor.tuple_id, node.tuple_id, lookup_weight(node.tuple_id, contributor.tuple_id))
                )
            return

        if not isinstance(node, CircuitDiscoveryHeadNode):
            return

        for head_type in ("q", "k", "v"):
            for contributor in node.contributors_in_graph(head_type):
                edge_list.append(
                    (
                        contributor.tuple_id,
                        node.tuple_id,
                        lookup_weight(node.tuple_id_for_head_type(head_type), contributor.tuple_id),
                    )
                )

    cd.traverse_graph(visit_node)

    return visited_ids, edge_list

# Collect node ids and (contributor, receiver, weight) edges from the greedy circuit
nodes, edges = collect_nodes_and_edges(cd)

# Print the nodes and edges
print("Nodes:")
for node in nodes:
    print(node)

print("\nEdges (with weights):")
for edge in edges:
    print(edge)

In [None]:
# Reduce every node id to the fields we aggregate on:
#   attention heads -> (type, layer, head), MLP features -> (type, layer).
# Any other node type is passed through unchanged.
cleaned_edges = []
for src, dst, weight in edges:
    src = (
        (src[0], src[1], src[2]) if src[0] == 'attn_head'
        else (src[0], src[1]) if src[0] == 'mlp_feature'
        else src
    )
    dst = (
        (dst[0], dst[1], dst[2]) if dst[0] == 'attn_head'
        else (dst[0], dst[1]) if dst[0] == 'mlp_feature'
        else dst
    )
    cleaned_edges.append((src, dst, weight))

print(f"Length of cleaned edges: {len(cleaned_edges)}")
for edge in cleaned_edges:
    print(edge)

In [None]:
def is_keep_type(node):
    """Return True when the node's type tag is 'attn_head' or 'mlp_feature'."""
    keep_kinds = frozenset(('attn_head', 'mlp_feature'))
    return node[0] in keep_kinds

def get_layer(node):
    """Return the node's layer, stored as the second element of the node tuple."""
    layer = node[1]
    return layer

def find_paths(graph, start, path, paths, weight):
    """Append every admissible one-edge extension of ``path`` to ``paths``.

    Args:
        graph: adjacency mapping ``node -> [(next_node, edge_weight), ...]``
            as produced by ``build_graph``.
        start: node the path started from; stored with each recorded path.
        path: list of nodes walked so far; its last element is the current node.
        paths: output list, extended in place with
            ``(node_list, start, end_node, total_weight)`` tuples.
        weight: weight accumulated along ``path`` so far.

    An outgoing edge is admissible when
      * it alternates component types (an MLP feature may only feed an
        attention head; any other node may only feed an MLP feature), and
      * it leads to a strictly later layer.

    Every admissible successor is an attention head or an MLP feature, i.e. a
    node ``is_keep_type`` accepts, so it always terminates the path. No deeper
    search is needed, and the previous recursive branch could never run.

    The type check is done before the layer is read, so neighbours of other
    types are skipped without touching their second field.
    """
    current_node = path[-1]
    current_layer = get_layer(current_node)
    expected_type = 'attn_head' if current_node[0] == 'mlp_feature' else 'mlp_feature'

    for next_node, edge_weight in graph.get(current_node, []):
        # Enforce the alternating attention/MLP pattern.
        if next_node[0] != expected_type:
            continue

        # Only follow edges into strictly later layers.
        if get_layer(next_node) <= current_layer:
            continue

        paths.append((path + [next_node], start, next_node, weight + edge_weight))

def build_graph(edges):
    """Turn ``(src, dst, weight)`` triples into ``{src: [(dst, weight), ...]}``.

    Nodes that only ever appear as destinations get no key of their own.
    """
    adjacency = {}
    for source, target, strength in edges:
        adjacency.setdefault(source, []).append((target, strength))
    return adjacency

def aggregate_paths(paths):
    """Sum path weights per ``(src, dst)`` endpoint pair.

    ``paths`` holds ``(node_list, src, dst, weight)`` tuples; the node list is
    ignored. Returns ``{(src, dst): total_weight}``.
    """
    totals = {}
    for _nodes, src, dst, weight in paths:
        endpoints = (src, dst)
        totals[endpoints] = totals.get(endpoints, 0.0) + weight
    return totals

def filter_and_aggregate_edges(edges):
    """Collapse cleaned edges into one weighted edge per (src, dst) pair.

    Builds the adjacency mapping, collects the paths ``find_paths`` records
    from every attention-head / MLP-feature source node, sums the path weights
    per endpoint pair and returns ``[(src, dst, weight), ...]``.
    """
    adjacency = build_graph(edges)

    # Only kept node types may start a path.
    found_paths = []
    for origin in filter(is_keep_type, adjacency):
        find_paths(adjacency, origin, [origin], found_paths, 0.0)

    totals = aggregate_paths(found_paths)
    return [(src, dst, total) for (src, dst), total in totals.items()]

# Filter and aggregate edges (uses the definition from this cell, which replaces the earlier networkx-based one)
final_edges = filter_and_aggregate_edges(cleaned_edges)

In [None]:
# Display the aggregated (src, dst, weight) edges.
final_edges

In [None]:
# Find and print the edge with the largest aggregated weight.
# max() is used instead of a running maximum seeded with 0, which reported
# None / 0 whenever every weight was <= 0. Ties resolve to the first edge,
# and an empty edge list still yields None / 0.
max_edge = max(final_edges, key=lambda edge: edge[2], default=None)
max_weight = max_edge[2] if max_edge is not None else 0

print(max_edge)
print(max_weight)

In [None]:
import numpy as np

def get_node_index(node, num_layers=12, num_heads_per_layer=12):
    """Map a node tuple to its index in a block layout (all heads, then all MLPs).

    Args:
        node: ``('attn_head', layer, head)`` or ``('mlp_feature', layer)``.
        num_layers: number of layers (default 12, the previously hard-coded value).
        num_heads_per_layer: attention heads per layer (default 12).

    Returns:
        ``layer * num_heads_per_layer + head`` for attention heads,
        ``num_layers * num_heads_per_layer + layer`` for MLP features,
        and ``None`` for any other node type.

    NOTE(review): ``create_labels`` / ``build_adjacency_matrix`` use an
    interleaved layout (each layer's heads followed by that layer's MLP), so
    these indices do not address rows or columns of that matrix.
    """
    kind = node[0]
    if kind == 'attn_head':
        return node[1] * num_heads_per_layer + node[2]
    if kind == 'mlp_feature':
        return num_layers * num_heads_per_layer + node[1]
    return None

def create_labels(num_layers, num_heads_per_layer):
    """Return node labels layer by layer: each layer's heads, then its MLP.

    Labels look like ``attn_head_<layer>_<head>`` and ``mlp_feature_<layer>``;
    the list has ``num_layers * (num_heads_per_layer + 1)`` entries.
    """
    names = []
    for layer in range(num_layers):
        names.extend(f'attn_head_{layer}_{head}' for head in range(num_heads_per_layer))
        names.append(f'mlp_feature_{layer}')
    return names

def build_adjacency_matrix(final_edges, num_layers, num_heads_per_layer):
    """Fill a dense (source x destination) weight matrix from aggregated edges.

    Rows and columns follow the ordering of ``create_labels``. Returns
    ``(adj_matrix, labels)``. An edge whose label is not in that ordering
    raises ``KeyError``.
    """
    size = num_layers * (num_heads_per_layer + 1)
    adj_matrix = np.zeros((size, size))

    labels = create_labels(num_layers, num_heads_per_layer)
    label_to_index = {label: position for position, label in enumerate(labels)}

    def label_of(node):
        # MLP nodes are (type, layer); everything else is labelled (type, layer, head).
        if node[0] == 'mlp_feature':
            return f'{node[0]}_{node[1]}'
        return f'{node[0]}_{node[1]}_{node[2]}'

    for src, dst, weight in final_edges:
        src_label, dst_label = label_of(src), label_of(dst)
        row = label_to_index[src_label]
        col = label_to_index[dst_label]
        adj_matrix[row, col] = weight

    return adj_matrix, labels

In [None]:
# Build the matrix with num_layers=12 and num_heads_per_layer=12.
adjacency_matrix, labels = build_adjacency_matrix(final_edges, 12, 12)

# Print the adjacency matrix
print(adjacency_matrix)

In [None]:
# plotly imshow for adjacency matrix
import plotly.express as px

# Heatmap of the adjacency matrix (rows are sources, columns are destinations, as filled by build_adjacency_matrix).
fig = px.imshow(adjacency_matrix, zmax=1, color_continuous_scale='blues', width=600)
fig.show()

In [None]:
def adj_matrix_to_pred(adjacency_matrix, num_layers=12, num_heads_per_layer=12):
    """Score each attention head by the softmax of its total incoming weight.

    Args:
        adjacency_matrix: square matrix in the ``create_labels`` node order,
            i.e. per layer ``num_heads_per_layer`` heads followed by one MLP
            node. Column sums (``axis=0``) are used, which is the weight
            arriving at each node when rows index sources, as in
            ``build_adjacency_matrix``.
        num_layers: number of layers (default 12).
        num_heads_per_layer: attention heads per layer (default 12).

    Returns:
        1-D array of length ``num_layers * num_heads_per_layer`` where entry
        ``layer * num_heads_per_layer + head`` is that head's share of the
        softmax taken over all nodes. Nodes whose total is exactly 0 get
        probability 0; if no node has weight, all zeros are returned.

    Raises:
        ValueError: if the matrix size does not match
            ``num_layers * (num_heads_per_layer + 1)``.
    """
    total_weights = np.sum(adjacency_matrix, axis=0)

    # A total of exactly 0 is treated as -inf, i.e. zero probability mass.
    connected = total_weights != 0
    probs = np.zeros(total_weights.shape)
    if connected.any():
        # Subtract the maximum before exponentiating to avoid overflow.
        exps = np.exp(total_weights[connected] - total_weights[connected].max())
        probs[connected] = exps / exps.sum()

    # Each layer occupies (heads + 1) consecutive slots with the MLP last;
    # drop that last slot to keep only the heads, in (layer, head) order.
    per_layer = probs.reshape(num_layers, num_heads_per_layer + 1)
    return per_layer[:, :num_heads_per_layer].reshape(-1)
    

In [None]:
# Total outgoing weight per source node (row sums)
total_weights = adjacency_matrix.sum(axis=1)
# total_weights += np.sum(adjacency_matrix, axis=0)
# Nodes with a total of exactly zero are sent to -inf so they get zero probability
total_weights = np.where(total_weights == 0, -np.inf, total_weights)
# Softmax over all nodes
total_weights = np.exp(total_weights) / np.exp(total_weights).sum()

total_weights

In [None]:
# Keep only the attention-head entries of total_weights.
# The matrix follows create_labels' layout: per layer, 12 heads followed by
# one MLP node. Viewing the vector as (12 layers, 13 slots) and dropping the
# last slot of each layer gives the 144 head scores in (layer, head) order.
# The previous `i % 13 != 0` loop zeroed real heads, kept MLP slots and never
# re-indexed to layer * 12 + head.
y_pred = total_weights.reshape(12, 13)[:, :12].reshape(-1)

In [None]:
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS

# Flatten the ground-truth array to 1-D so it lines up with y_pred.
# NOTE: this rebinds the imported name in the notebook namespace.
IOI_GROUND_TRUTH_HEADS = IOI_GROUND_TRUTH_HEADS.flatten()

IOI_GROUND_TRUTH_HEADS.shape

In [None]:
# Display the flattened ground-truth labels.
IOI_GROUND_TRUTH_HEADS

In [None]:
# Display the per-head prediction scores.
y_pred

In [None]:
from sklearn.metrics import roc_auc_score

# ROC AUC of the per-head scores against the flattened ground-truth labels.
roc_auc_score(IOI_GROUND_TRUTH_HEADS, y_pred)

## Get cumulative adjacency matrix for a bunch of prompts

In [None]:
from data.ioi_dataset import NAMES, SINGLE_TOKEN_NAMES, ABBA_TEMPLATES, gen_prompt_uniform, BABA_TEMPLATES, NOUNS_DICT, gen_templated_prompts
from transformer_lens import HookedTransformer, utils, ActivationCache
import torch


# Generate prompts from the first BABA and first ABBA template.
# NOTE(review): the positional 10 and True are presumably the sample count and a flag — confirm against gen_prompt_uniform.
p = gen_prompt_uniform(
    [BABA_TEMPLATES[0], ABBA_TEMPLATES[0]], SINGLE_TOKEN_NAMES, NOUNS_DICT, 10, True
)

In [None]:
# One string per generated sample: its text followed by a space and its "IO" field
prompts = [sample['text'] + " " + sample["IO"] for sample in p]

In [None]:
# Display the prompt strings.
prompts

In [None]:
from tqdm import tqdm
import numpy as np

# Running sum of the per-prompt adjacency matrices (12 layers * 13 nodes per layer = 156).
agg_adj_matrix = np.zeros((156, 156))

for prompt in tqdm(prompts):

    cd = CircuitDiscovery(prompt=prompt)

    M = 5  # Number of greedy steps per prompt
    k = 3  # Top k contributors at each step
    for _ in range(M):
        cd.greedily_add_top_contributors(k=k)

    nodes, edges = collect_nodes_and_edges(cd)

    # Reduce node ids: heads -> (type, layer, head), MLP features -> (type, layer).
    cleaned_edges = []
    for src, dst, weight in edges:
        src = (
            (src[0], src[1], src[2]) if src[0] == 'attn_head'
            else (src[0], src[1]) if src[0] == 'mlp_feature'
            else src
        )
        dst = (
            (dst[0], dst[1], dst[2]) if dst[0] == 'attn_head'
            else (dst[0], dst[1]) if dst[0] == 'mlp_feature'
            else dst
        )
        cleaned_edges.append((src, dst, weight))

    # Collapse to one weighted edge per (src, dst) and place it in a dense matrix.
    final_edges = filter_and_aggregate_edges(cleaned_edges)
    adjacency_matrix, labels = build_adjacency_matrix(final_edges, 12, 12)

    agg_adj_matrix += adjacency_matrix

    # Drop the per-prompt objects before the next iteration.
    del adjacency_matrix, labels, cd

In [None]:
# imshow agg matrix
# Heatmap of the summed adjacency matrix; rows are sources and columns are destinations.
fig = px.imshow(agg_adj_matrix, zmax=1, color_continuous_scale='blues', 
                labels={'x': 'Destination', 'y': 'Source'}, width=600)
fig.show()

In [None]:
def adj_matrix_to_pred(adjacency_matrix, num_layers=12, num_heads_per_layer=12):
    """Score each attention head by the softmax of its total in + out weight.

    Args:
        adjacency_matrix: square matrix in the ``create_labels`` node order,
            i.e. per layer ``num_heads_per_layer`` heads followed by one MLP
            node. Row sums and column sums are added together, so every edge
            counts for both of its endpoints.
        num_layers: number of layers (default 12).
        num_heads_per_layer: attention heads per layer (default 12).

    Returns:
        1-D array of length ``num_layers * num_heads_per_layer`` where entry
        ``layer * num_heads_per_layer + head`` is that head's share of the
        softmax taken over all nodes. Nodes whose total is exactly 0 get
        probability 0; if no node has weight, all zeros are returned.

    Raises:
        ValueError: if the matrix size does not match
            ``num_layers * (num_heads_per_layer + 1)``.
    """
    total_weights = np.sum(adjacency_matrix, axis=1) + np.sum(adjacency_matrix, axis=0)

    # A total of exactly 0 is treated as -inf, i.e. zero probability mass.
    connected = total_weights != 0
    probs = np.zeros(total_weights.shape)
    if connected.any():
        # Subtract the maximum before exponentiating to avoid overflow.
        exps = np.exp(total_weights[connected] - total_weights[connected].max())
        probs[connected] = exps / exps.sum()

    # Each layer occupies (heads + 1) consecutive slots with the MLP last;
    # drop that last slot to keep only the heads, in (layer, head) order.
    per_layer = probs.reshape(num_layers, num_heads_per_layer + 1)
    return per_layer[:, :num_heads_per_layer].reshape(-1)

# Per-head scores derived from the matrix summed over all prompts.
y_pred = adj_matrix_to_pred(agg_adj_matrix)

In [None]:
# ROC AUC of the aggregated per-head scores against the flattened IOI ground truth.
roc = roc_auc_score(IOI_GROUND_TRUTH_HEADS, y_pred)
print(f"ROC AUC: {roc}")

## Task Evaluator

In [None]:
# %%

%load_ext autoreload
%autoreload 2


# %%
import torch
import time
import plotly.express as px
import matplotlib.pyplot as plt

from task_evaluation import TaskEvaluation
from data.ioi_dataset import gen_templated_prompts
from data.greater_than_dataset import generate_greater_than_dataset
from circuit_discovery import CircuitDiscovery, only_feature
from circuit_lens import CircuitComponent
from plotly_utils import *
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS
from data.greater_than_dataset import GT_GROUND_TRUTH_HEADS
from memory import get_gpu_memory
from sklearn import metrics
from tqdm import trange

from utils import get_attn_head_roc


# %%
# Disable gradient tracking globally for the rest of the notebook.
torch.set_grad_enabled(False)
# %%
#dataset_prompts = gen_templated_prompts(template_idex=1, N=500)


# Use the greater-than dataset here; the IOI alternative is the commented-out line above.
dataset_prompts = generate_greater_than_dataset(N=100)


# %%

def component_filter(component: str):
    """Return True when ``component`` is one of the allowed circuit component kinds.

    Entries left commented out below are deliberately excluded.
    """
    allowed_components = (
        CircuitComponent.Z_FEATURE,
        CircuitComponent.MLP_FEATURE,
        CircuitComponent.ATTN_HEAD,
        CircuitComponent.UNEMBED,
        # CircuitComponent.UNEMBED_AT_TOKEN,
        CircuitComponent.EMBED,
        CircuitComponent.POS_EMBED,
        # CircuitComponent.BIAS_O,
        CircuitComponent.Z_SAE_ERROR,
        # CircuitComponent.Z_SAE_BIAS,
        # CircuitComponent.TRANSCODER_ERROR,
        # CircuitComponent.TRANSCODER_BIAS,
    )
    return component in allowed_components


# Selects the branch taken by strategy(): True -> greedy passes, False -> repeated top-k additions.
pass_based = True

# Settings for the pass-based branch.
passes = 5
node_contributors = 1
first_pass_minimal = True

# Settings for the optional sub-passes (only run when do_sub_pass is True).
sub_passes = 3
do_sub_pass = False
layer_thres = 9
minimal = True


# Settings for the non-pass-based branch.
num_greedy_passes = 20
k = 1
# N is reassigned further down in this cell before it is used.
N = 30

# Passed as reciever_threshold in the non-pass-based branch.
thres = 4

def strategy(cd: CircuitDiscovery):
    """Grow the circuit on ``cd`` according to the module-level settings.

    The settings (``pass_based``, ``passes``, ``do_sub_pass``, ...) are read
    from globals when the function is called.
    """
    if not pass_based:
        # Plain greedy mode: repeatedly add the top-k contributors.
        for _ in range(num_greedy_passes):
            cd.greedily_add_top_contributors(k=k, reciever_threshold=thres)
        return

    for _ in range(passes):
        cd.add_greedy_pass(contributors_per_node=node_contributors, minimal=first_pass_minimal)

        if not do_sub_pass:
            continue

        for _ in range(sub_passes):
            cd.add_greedy_pass_against_all_existing_nodes(contributors_per_node=node_contributors, skip_z_features=True, layer_threshold=layer_thres, minimal=minimal)



# Evaluation harness over the dataset prompts, using the strategy and component filter defined above.
task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)

# Circuit discovery for the prompt at index 20.
cd = task_eval.get_circuit_discovery_for_prompt(20)
# f = task_eval.get_features_at_heads_over_dataset(N=30)
# Value passed as N to the dataset-level call below (replaces the N = 30 set earlier in this cell).
N = 100

attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True)


# %%
# Ground truth for the greater-than task (swap in IOI_GROUND_TRUTH_HEADS for IOI).
ground_truth = GT_GROUND_TRUTH_HEADS #IOI_GROUND_TRUTH_HEADS

# fp, tp, thresh = get_attn_head_roc(ground_truth, a.flatten().softmax(dim=-1), "IOI", visualize=True, additional_title="(No Counterfactuals)")
# Scores are the softmax of the flattened head frequencies; only the first returned value is kept.
score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "GT", visualize=True, additional_title="(No Counterfactuals)")

In [None]:
from tqdm import tqdm

#dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
dataset_prompts = generate_greater_than_dataset(N=100)
# Values of N to sweep; one score is appended to roc_scores_gt per value.
N_list = [1, 5, 10, 50, 100]
roc_scores_gt = []
for N in tqdm(N_list):

    # A new TaskEvaluation is created for every N.
    task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)
    # NOTE(review): cd is not used again inside this loop — confirm whether this call is needed.
    cd = task_eval.get_circuit_discovery_for_prompt(20)
    attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True)
    ground_truth = GT_GROUND_TRUTH_HEADS #IOI_GROUND_TRUTH_HEADS
    score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "GT", visualize=False, additional_title="(No Counterfactuals)")
    roc_scores_gt.append(score)

In [None]:
# IOI prompts from template index 1.
dataset_prompts = gen_templated_prompts(template_idex=1, N=100)
# Values of N to sweep; one score is appended to roc_scores_ioi per value.
N_list = [1, 5, 10, 50, 100]
roc_scores_ioi = []
for N in tqdm(N_list):

    # A new TaskEvaluation is created for every N.
    task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)
    # NOTE(review): cd is not used again inside this loop — confirm whether this call is needed.
    cd = task_eval.get_circuit_discovery_for_prompt(20)
    attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True)
    ground_truth = IOI_GROUND_TRUTH_HEADS
    # The task label is "IOI" to match the ground truth scored here (it was "GT", copied from the greater-than sweep).
    score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "IOI", visualize=False, additional_title="(No Counterfactuals)")
    roc_scores_ioi.append(score)

In [None]:
import plotly.graph_objects as go

# Create traces: one ROC-vs-N line per task
trace_ioi = go.Scatter(x=N_list, y=roc_scores_ioi, mode='lines', name='IOI')
# The greater-than scores live in roc_scores_gt; `roc_scores` is never defined in this notebook.
trace_gt = go.Scatter(x=N_list, y=roc_scores_gt, mode='lines', name='GT')

# Create layout
layout = go.Layout(xaxis_title='No. examples', yaxis_title='ROC', width=600)

# Create figure
fig = go.Figure(data=[trace_ioi, trace_gt], layout=layout)

# Show figure
fig.show()

In [None]:
# Plotly line plot for ROC scores
# NOTE(review): this previously plotted the undefined name `roc_scores` against
# N_list[:-1] (4 x-values for 5 scores). It now plots the greater-than scores
# against the full N_list — confirm this is the intended series.
fig = px.line(x=N_list, y=roc_scores_gt, labels={'x': 'No. examples', 'y': 'ROC AUC Score'}, width=600)
fig.show()

In [None]:
# %%
import torch
import time
import plotly.express as px
import matplotlib.pyplot as plt

from task_evaluation import TaskEvaluation
from data.ioi_dataset import gen_templated_prompts
from data.greater_than_dataset import generate_greater_than_dataset
from circuit_discovery import CircuitDiscovery, only_feature
from circuit_lens import CircuitComponent
from plotly_utils import *
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS
# from data.ioi_dataset import GT_GROUND_TRUTH_HEADS
from memory import get_gpu_memory
from sklearn import metrics
from tqdm import trange

from utils import get_attn_head_roc

# Autoreload
%load_ext autoreload
%autoreload 2

# %%
# Disable gradient tracking globally.
torch.set_grad_enabled(False)

# %%
# IOI prompts from template index 1.
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)

# dataset_prompts = generate_greater_than_dataset(N=100)

# %%
def component_filter(component: str):
    """Return True when ``component`` is one of the allowed circuit component kinds.

    Entries left commented out below are deliberately excluded.
    """
    allowed_components = (
        CircuitComponent.Z_FEATURE,
        CircuitComponent.MLP_FEATURE,
        CircuitComponent.ATTN_HEAD,
        CircuitComponent.UNEMBED,
        # CircuitComponent.UNEMBED_AT_TOKEN,
        CircuitComponent.EMBED,
        CircuitComponent.POS_EMBED,
        # CircuitComponent.BIAS_O,
        CircuitComponent.Z_SAE_ERROR,
        # CircuitComponent.Z_SAE_BIAS,
        # CircuitComponent.TRANSCODER_ERROR,
        # CircuitComponent.TRANSCODER_BIAS,
    )
    return component in allowed_components

# Selects the branch taken by strategy(): True -> greedy passes, False -> repeated top-k additions.
pass_based = True

# Settings for the pass-based branch.
passes = 5
node_contributors = 1
first_pass_minimal = True

# Settings for the optional sub-passes (only run when do_sub_pass is True).
sub_passes = 3
do_sub_pass = False
layer_thres = 9
minimal = True

# Settings for the non-pass-based branch.
num_greedy_passes = 20
k = 1
# N is reassigned further down in this cell before it is used.
N = 30

# Passed as reciever_threshold in the non-pass-based branch.
thres = 4

def strategy(cd: CircuitDiscovery):
    """Grow the circuit on ``cd`` according to the module-level settings.

    The settings (``pass_based``, ``passes``, ``do_sub_pass``, ...) are read
    from globals when the function is called.
    """
    if not pass_based:
        # Plain greedy mode: repeatedly add the top-k contributors.
        for _ in range(num_greedy_passes):
            cd.greedily_add_top_contributors(k=k, reciever_threshold=thres)
        return

    for _ in range(passes):
        cd.add_greedy_pass(contributors_per_node=node_contributors, minimal=first_pass_minimal)

        if not do_sub_pass:
            continue

        for _ in range(sub_passes):
            cd.add_greedy_pass_against_all_existing_nodes(contributors_per_node=node_contributors, skip_z_features=True, layer_threshold=layer_thres, minimal=minimal)

# Evaluation harness over the IOI prompts, using the strategy and component filter defined above.
task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)

# Circuit discovery for the prompt at index 20.
cd = task_eval.get_circuit_discovery_for_prompt(20)
# f = task_eval.get_features_at_heads_over_dataset(N=30)
# Value passed as N to the dataset-level call below (replaces the N = 30 set earlier in this cell).
N = 5

# Weighted variant of the head-frequency computation, with visualisation enabled.
attn_freqs = task_eval.get_weighted_attn_head_freqs_over_dataset(N=N, visualize=True, return_freqs=True)

print(attn_freqs.shape)

# Softmax across row (do not flatten, apply softmax across rows of matrix)
# attn_freqs = attn_freqs.softmax(dim=-1)

# %%
# Ground truth for the IOI task.
ground_truth = IOI_GROUND_TRUTH_HEADS

# fp, tp, thresh = get_attn_head_roc(ground_truth, a.flatten().softmax(dim=-1), "IOI", visualize=True, additional_title="(No Counterfactuals)")
# Scores are the softmax of the flattened head frequencies; only the first returned value is kept.
score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "IOI", visualize=True, additional_title="(No Counterfactuals)")

In [None]:
# Show the computed head frequencies and the ground truth as heatmaps for visual comparison.
px.imshow(attn_freqs, zmax=1, color_continuous_scale='blues', width=600).show()
px.imshow(ground_truth, zmax=1, color_continuous_scale='blues', width=600).show()