In [None]:
import torch
import einops

# autoreload
%load_ext autoreload
%autoreload 2

from circuit_lens import CircuitLens
from circuit_discovery import CircuitDiscovery

In [18]:
prompt = "When John and Mary went to the store, John gave the bag to Mary"
cd = CircuitDiscovery(prompt=prompt)

In [78]:
cd.build_directed_graph(contributors_per_node=4)

In [79]:
def get_directed_graph_edges_with_weights(cd):
    edges_with_weights = []
    for receiver, contributors in cd.transformer_model.edge_tracker._reciever_to_contributors.items():
        for contributor, weight in contributors.items():
            edges_with_weights.append((contributor, receiver, weight))
    return edges_with_weights

# Access the edges with weights
edges_with_weights = get_directed_graph_edges_with_weights(cd)
print(len(edges_with_weights))

126620


In [80]:
# Print unique names (first thing in tuple) across edges with weights
unique_names_src = set([x[0][0] for x in edges_with_weights])
# Add unique names for trgt
unique_names_trgt = set([x[1][0] for x in edges_with_weights])
unique_names = unique_names_src.union(unique_names_trgt)
print(unique_names)

{'transcoder_error', 'attn_head', 'mlp_feature', 'embed', 'z_sae_error', 'z_feature', 'unembed_at_token', 'b_O', 'pos_embed', 'transcoder_bias', 'z_sae_bias'}


In [81]:
def print_connectivity_stats(edges_with_weights):
    attn_to_attn = 0
    attn_to_ff = 0
    ff_to_attn = 0
    ff_to_ff = 0
    other = 0

    for edge in edges_with_weights:
        if edge[0][0] == "attn_head" and edge[1][0] == "attn_head":
            attn_to_attn += 1
        elif edge[0][0] == "attn_head" and edge[1][0] == "mlp_feature":
            attn_to_ff += 1
        elif edge[0][0] == "mlp_feature" and edge[1][0] == "attn_head":
            ff_to_attn += 1
        elif edge[0][0] == "mlp_feature" and edge[1][0] == "mlp_feature":
            ff_to_ff += 1
        else:
            other += 1

    print(f"attn_to_attn: {attn_to_attn}")
    print(f"attn_to_ff: {attn_to_ff}")
    print(f"ff_to_attn: {ff_to_attn}")
    print(f"ff_to_ff: {ff_to_ff}")
    print(f"other: {other}")

print_connectivity_stats(edges_with_weights)

attn_to_attn: 0
attn_to_ff: 0
ff_to_attn: 32038
ff_to_ff: 923
other: 93659


In [82]:
from collections import defaultdict
import networkx as nx

def filter_and_aggregate_edges(edges_with_weights):
    filtered_edges = []
    
    # Filter edges and prepare the graph
    graph = nx.DiGraph()
    for src, dst, weight in edges_with_weights:
        if (src[0] in ['attn_head', 'mlp_feature'] and dst[0] in ['attn_head', 'mlp_feature']):
            #print(f"Found a direct path between {src} and {dst} with weight {weight}")
            graph.add_edge(src, dst, weight=weight)
    
    # Aggregation dictionaries
    attn_head_aggregation = defaultdict(lambda: defaultdict(list))
    mlp_aggregation = defaultdict(list)
    
    # Aggregate contributions for attention heads and MLPs
    for node in graph.nodes:
        if node[0] == 'attn_head':
            layer, head = node[1], node[2]
            attn_head_aggregation[layer][head].append(node)
        elif node[0] == 'mlp_feature':
            layer = node[1]
            mlp_aggregation[layer].append(node)
    
    # New graph to store aggregated edges
    aggregated_graph = nx.DiGraph()
    
    # Aggregate the edges with at most one intermediate node
    for node in graph.nodes:
        for neighbor in graph.successors(node):
            if graph[node][neighbor]['weight'] is not None:
                aggregated_graph.add_edge(node, neighbor, weight=graph[node][neighbor]['weight'])
            for next_neighbor in graph.successors(neighbor):
                if graph[neighbor][next_neighbor]['weight'] is not None:
                    combined_weight = graph[node][neighbor]['weight'] * graph[neighbor][next_neighbor]['weight']
                    aggregated_graph.add_edge(node, next_neighbor, weight=combined_weight)
    
    return aggregated_graph.edges(data=True)

aggregated_edges = filter_and_aggregate_edges(edges_with_weights)
print(len(aggregated_edges))

for edge in aggregated_edges:
    print(edge)

24140
(('mlp_feature', 0, 4, 5545), ('attn_head', 10, 0, 4, 14, -1), {'weight': 0.04073655808147336})
(('mlp_feature', 0, 4, 5545), ('attn_head', 10, 11, 4, 14, -1), {'weight': 0.018364919481609787})
(('mlp_feature', 0, 4, 5545), ('attn_head', 10, 7, 4, 14, -1), {'weight': 0.14910197340068443})
(('mlp_feature', 0, 4, 5545), ('attn_head', 9, 6, 4, 14, 11368), {'weight': 0.017093155626921597})
(('mlp_feature', 0, 4, 5545), ('attn_head', 9, 9, 4, 14, 11368), {'weight': 0.018195938148835022})
(('mlp_feature', 0, 4, 5545), ('attn_head', 9, 8, 4, 14, 11368), {'weight': 0.019684991770769322})
(('mlp_feature', 0, 4, 5545), ('attn_head', 8, 11, 4, 14, -1), {'weight': 0.0051213979167685775})
(('mlp_feature', 0, 4, 5545), ('mlp_feature', 6, 4, 22620), {'weight': 0.011944838360218313})
(('mlp_feature', 0, 4, 5545), ('mlp_feature', 7, 4, 12771), {'weight': 0.35617712140083313})
(('mlp_feature', 0, 4, 5545), ('attn_head', 8, 2, 4, 14, 17232), {'weight': 0.16399270296096802})
(('mlp_feature', 0, 4, 5

In [83]:
# Go through aggregated edges and only include information in tuples we care about
cleaned_edges = []
for src, dst, data in aggregated_edges:
    # If src is an attention head
    if src[0] == 'attn_head':
        src = (src[0], src[1], src[2])
    # If src is an MLP
    elif src[0] == 'mlp_feature':
        src = (src[0], src[1])
    
    # If dst is an attention head
    if dst[0] == 'attn_head':
        dst = (dst[0], dst[1], dst[2])
    # If dst is an MLP
    elif dst[0] == 'mlp_feature':
        dst = (dst[0], dst[1])

    cleaned_edges.append((src, dst, data['weight']))

print(f"Length of cleaned edges: {len(cleaned_edges)}")
for edge in cleaned_edges:
    print(edge)

Length of cleaned edges: 24140
(('mlp_feature', 0), ('attn_head', 10, 0), 0.04073655808147336)
(('mlp_feature', 0), ('attn_head', 10, 11), 0.018364919481609787)
(('mlp_feature', 0), ('attn_head', 10, 7), 0.14910197340068443)
(('mlp_feature', 0), ('attn_head', 9, 6), 0.017093155626921597)
(('mlp_feature', 0), ('attn_head', 9, 9), 0.018195938148835022)
(('mlp_feature', 0), ('attn_head', 9, 8), 0.019684991770769322)
(('mlp_feature', 0), ('attn_head', 8, 11), 0.0051213979167685775)
(('mlp_feature', 0), ('mlp_feature', 6), 0.011944838360218313)
(('mlp_feature', 0), ('mlp_feature', 7), 0.35617712140083313)
(('mlp_feature', 0), ('attn_head', 8, 2), 0.16399270296096802)
(('mlp_feature', 0), ('attn_head', 8, 9), 1.506545066833496)
(('mlp_feature', 0), ('mlp_feature', 2), 0.050479200328219065)
(('mlp_feature', 0), ('mlp_feature', 5), 0.012681874636538026)
(('mlp_feature', 0), ('attn_head', 4, 6), 0.9344635829468189)
(('mlp_feature', 0), ('mlp_feature', 3), 0.01937143396971308)
(('mlp_feature', 0

In [84]:
# We need to sum the weights for duplicate src and trg nodes
cleaned_edges_dict = {}
for src, dst, weight in cleaned_edges:
    if (src, dst) in cleaned_edges_dict:
        cleaned_edges_dict[(src, dst)] += weight
    else:
        cleaned_edges_dict[(src, dst)] = weight

# Print length of cleaned edges
print(f"Length of cleaned edges: {len(cleaned_edges_dict)}")

cleaned_edges_dict

Length of cleaned edges: 567


{(('mlp_feature', 0), ('attn_head', 10, 0)): 0.04073655808147336,
 (('mlp_feature', 0), ('attn_head', 10, 11)): 0.018364919481609787,
 (('mlp_feature', 0), ('attn_head', 10, 7)): 0.18343051566134466,
 (('mlp_feature', 0), ('attn_head', 9, 6)): 0.017093155626921597,
 (('mlp_feature', 0), ('attn_head', 9, 9)): 0.027710512478806456,
 (('mlp_feature', 0), ('attn_head', 9, 8)): 0.14845890158113867,
 (('mlp_feature', 0), ('attn_head', 8, 11)): 25.29536393111058,
 (('mlp_feature', 0), ('mlp_feature', 6)): 1.7596578076469858,
 (('mlp_feature', 0), ('mlp_feature', 7)): 1.5740405574180172,
 (('mlp_feature', 0), ('attn_head', 8, 2)): 8.029552797750402,
 (('mlp_feature', 0), ('attn_head', 8, 9)): 2.8768720874009555,
 (('mlp_feature', 0), ('mlp_feature', 2)): 10.776186672936417,
 (('mlp_feature', 0), ('mlp_feature', 5)): 3.472386995028584,
 (('mlp_feature', 0), ('attn_head', 4, 6)): 113.70588114520864,
 (('mlp_feature', 0), ('mlp_feature', 3)): 11.335792336531433,
 (('mlp_feature', 0), ('attn_head'

In [85]:
# Print number of times we have attn to attn
attn_to_attn = 0
attn_to_mlp = 0
mlp_to_attn = 0
mlp_to_mlp = 0

for (src, dst), weight in cleaned_edges_dict.items():
    if src[0] == 'attn_head' and dst[0] == 'attn_head':
        attn_to_attn += 1
    elif src[0] == 'attn_head' and dst[0] == 'mlp_feature':
        attn_to_mlp += 1
    elif src[0] == 'mlp_feature' and dst[0] == 'attn_head':
        mlp_to_attn += 1
    elif src[0] == 'mlp_feature' and dst[0] == 'mlp_feature':
        mlp_to_mlp += 1

print(f"Attn to Attn: {attn_to_attn}")
print(f"Attn to MLP: {attn_to_mlp}")
print(f"MLP to Attn: {mlp_to_attn}")
print(f"MLP to MLP: {mlp_to_mlp}")

Attn to Attn: 0
Attn to MLP: 0
MLP to Attn: 524
MLP to MLP: 43


## Greedy paths to build adjacency matrix

In [38]:
from circuit_discovery import CircuitDiscovery, CircuitDiscoveryHeadNode, CircuitDiscoveryRegularNode

# Autoreload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [39]:
prompt = "When John and Mary went to the store, John gave the bag to Mary"
cd = CircuitDiscovery(prompt=prompt)

In [40]:
# Assume cd is an instance of the CircuitDiscovery class
M = 100  # Number of times to call the greedy function
k = 5  # Top k contributors at each step

for _ in range(M):
    cd.greedily_add_top_contributors(k=k)

In [53]:
# Functions to collect nodes and edges
def collect_nodes_and_edges(cd):
    nodes = set()
    edges = []

    def visit_node(node):
        nodes.add(node.tuple_id)
        if isinstance(node, CircuitDiscoveryRegularNode):
            for contributor in node.contributors_in_graph:
                edges.append((contributor.tuple_id, node.tuple_id, cd.transformer_model.edge_tracker._reciever_to_contributors[node.tuple_id][contributor.tuple_id]))
        elif isinstance(node, CircuitDiscoveryHeadNode):
            for head_type in ["q", "k", "v"]:
                for contributor in node.contributors_in_graph(head_type):
                    edges.append((contributor.tuple_id, node.tuple_id, cd.transformer_model.edge_tracker._reciever_to_contributors[node.tuple_id_for_head_type(head_type)][contributor.tuple_id]))

    cd.traverse_graph(visit_node)
    
    return nodes, edges

# Collect nodes and edges
nodes, edges = collect_nodes_and_edges(cd)

# Print the nodes and edges
print("Nodes:")
for node in nodes:
    print(node)

print("\nEdges (with weights):")
for edge in edges:
    print(edge)

Nodes:
('attn_head', 2, 8, 11, 11, -1)
('b_O', 0, 1, -1)
('z_sae_bias', 2, 3, -1)
('attn_head', 0, 1, 5, 5, 22805)
('attn_head', 1, 11, 2, 10, 6506)
('attn_head', 0, 9, 0, 2, 16453)
('z_feature', 0, 2, 9303)
('transcoder_error', 0, 2, -1)
('attn_head', 4, 8, 0, 11, -1)
('transcoder_bias', 1, 4, -1)
('attn_head', 1, 1, 3, 6, 24389)
('z_sae_bias', 0, 7, -1)
('b_O', 0, 9, -1)
('attn_head', 5, 5, 3, 10, 27535)
('attn_head', 2, 6, 1, 11, -1)
('b_O', 7, 10, -1)
('pos_embed', 0, 7, -1)
('attn_head', 0, 5, 2, 10, 9715)
('mlp_feature', 0, 3, 21708)
('mlp_feature', 0, 1, 17726)
('attn_head', 1, 10, 3, 3, -1)
('z_feature', 8, 14, 2623)
('z_sae_bias', 1, 4, -1)
('transcoder_bias', 0, 4, -1)
('transcoder_error', 2, 10, -1)
('z_sae_bias', 0, 10, -1)
('mlp_feature', 2, 11, 15455)
('attn_head', 8, 6, 10, 14, 16513)
('attn_head', 1, 11, 11, 11, 12552)
('attn_head', 2, 0, 6, 10, -1)
('z_sae_bias', 4, 11, -1)
('transcoder_bias', 0, 7, -1)
('embed', 0, 8, -1)
('z_feature', 0, 1, 12028)
('transcoder_error'

In [54]:
cleaned_edges = []
for src, dst, weight in edges:
    if src[0] == 'attn_head':
        src = (src[0], src[1], src[2])
    elif src[0] == 'mlp_feature':
        src = (src[0], src[1])
    
    if dst[0] == 'attn_head':
        dst = (dst[0], dst[1], dst[2])
    elif dst[0] == 'mlp_feature':
        dst = (dst[0], dst[1])

    cleaned_edges.append((src, dst, weight))

print(f"Length of cleaned edges: {len(cleaned_edges)}")
for edge in cleaned_edges:
    print(edge)

Length of cleaned edges: 1012
(('z_sae_error', 10, 14, -1), ('unembed_at_token', 14), 0.4225853681564331)
(('z_feature', 9, 14, 11368), ('unembed_at_token', 14), 0.39217451214790344)
(('z_sae_error', 8, 14, -1), ('unembed_at_token', 14), 0.03144604712724686)
(('transcoder_error', 10, 14, -1), ('unembed_at_token', 14), 0.029998723417520523)
(('z_feature', 8, 14, 2623), ('unembed_at_token', 14), 0.025278732180595398)
(('attn_head', 10, 0), ('z_sae_error', 10, 14, -1), 0.7936676144599915)
(('attn_head', 9, 6), ('z_feature', 9, 14, 11368), 0.4577443599700928)
(('attn_head', 9, 9), ('z_feature', 9, 14, 11368), 0.34887272119522095)
(('attn_head', 8, 6), ('z_sae_error', 8, 14, -1), 1.3374806642532349)
(('attn_head', 8, 4), ('z_sae_error', 8, 14, -1), 0.31561723351478577)
(('attn_head', 8, 6), ('z_feature', 8, 14, 2623), 0.48295125365257263)
(('z_sae_bias', 0, 14, -1), ('attn_head', 10, 0), 0.16439153254032135)
(('b_O', 1, 4, -1), ('attn_head', 10, 0), 0.29095473885536194)
(('mlp_feature', 0),

In [57]:
def is_keep_type(node):
    """Check if a node is a keep type (attention head or MLP feature)."""
    return node[0] in {'attn_head', 'mlp_feature'}

def get_layer(node):
    """Extract the layer number from a node."""
    return node[1]

def find_paths(graph, start, path, paths, weight):
    """Recursively find all valid paths starting from the given node."""
    current_node = path[-1]
    current_layer = get_layer(current_node)
    current_type_is_mlp = current_node[0] == 'mlp_feature'

    for edge in graph.get(current_node, []):
        next_node, edge_weight = edge
        next_layer = get_layer(next_node)
        
        # Skip if not strictly increasing layers
        if next_layer <= current_layer:
            continue
        
        # Ensure the correct alternating pattern
        if (current_type_is_mlp and next_node[0] != 'attn_head') or (not current_type_is_mlp and next_node[0] != 'mlp_feature'):
            continue
        
        new_weight = weight + edge_weight
        
        if is_keep_type(next_node):
            # If the next node is a keep type, record the path and start a new path
            paths.append((path + [next_node], start, next_node, new_weight))
        else:
            # Continue to search deeper
            find_paths(graph, start, path + [next_node], paths, new_weight)

def build_graph(edges):
    """Build a graph representation from the list of edges."""
    graph = {}
    for src, dst, weight in edges:
        if src not in graph:
            graph[src] = []
        graph[src].append((dst, weight))
    return graph

def aggregate_paths(paths):
    """Aggregate paths into edges with combined weights."""
    aggregated_edges = {}
    for path, src, dst, weight in paths:
        if (src, dst) not in aggregated_edges:
            aggregated_edges[(src, dst)] = 0.0
        aggregated_edges[(src, dst)] += weight
    return aggregated_edges

def filter_and_aggregate_edges(edges):
    """Main function to filter edges and aggregate paths."""
    # Step 1: Build the graph
    graph = build_graph(edges)
    
    # Step 2: Find all valid paths
    paths = []
    for node in graph:
        if is_keep_type(node):
            find_paths(graph, node, [node], paths, 0.0)
    
    # Step 3: Aggregate paths into edges
    aggregated_edges = aggregate_paths(paths)
    
    # Convert aggregated_edges to list format
    final_edges = [(src, dst, weight) for (src, dst), weight in aggregated_edges.items()]
    
    return final_edges

# Filter and aggregate edges
final_edges = filter_and_aggregate_edges(cleaned_edges)

In [58]:
final_edges

[(('mlp_feature', 0), ('attn_head', 10, 0), 0.2532690167427063),
 (('mlp_feature', 0), ('attn_head', 9, 6), 0.30155935883522034),
 (('mlp_feature', 0), ('attn_head', 9, 9), 0.5751477628946304),
 (('mlp_feature', 0), ('attn_head', 8, 4), 1.4676207900047302),
 (('mlp_feature', 0), ('attn_head', 8, 6), 0.28148677945137024),
 (('mlp_feature', 0), ('attn_head', 6, 9), 0.4516158401966095),
 (('mlp_feature', 0), ('attn_head', 5, 6), 1.5188753604888916),
 (('mlp_feature', 0), ('attn_head', 5, 11), 13.372763633728027),
 (('mlp_feature', 0), ('attn_head', 5, 0), 2.041739344596863),
 (('mlp_feature', 0), ('attn_head', 5, 5), 1.4683692157268524),
 (('mlp_feature', 0), ('attn_head', 3, 0), 1.0596462786197662),
 (('mlp_feature', 0), ('attn_head', 2, 2), 5.412833392620087),
 (('mlp_feature', 0), ('attn_head', 2, 3), 2.358521282672882),
 (('mlp_feature', 0), ('attn_head', 1, 0), 2.714501678943634),
 (('mlp_feature', 0), ('attn_head', 3, 6), 4.417906999588013),
 (('mlp_feature', 0), ('attn_head', 1, 10

In [59]:
# print max weight
max_weight = 0
max_edge = None
for edge in final_edges:
    if edge[2] > max_weight:
        max_weight = edge[2]
        max_edge = edge

print(max_edge)
print(max_weight)

(('mlp_feature', 0), ('attn_head', 3, 1), 30.762391567230225)
30.762391567230225


In [60]:
import numpy as np

def get_node_index(node):
    """Get the index of the node in the adjacency matrix."""
    if node[0] == 'attn_head':
        return node[1] * 12 + node[2]
    elif node[0] == 'mlp_feature':
        return 144 + node[1]

def create_labels(num_layers, num_heads_per_layer):
    """Create labels for the nodes in the adjacency matrix."""
    labels = []
    for layer in range(num_layers):
        for head in range(num_heads_per_layer):
            labels.append(f'attn_head_{layer}_{head}')
        labels.append(f'mlp_feature_{layer}')
    return labels

def build_adjacency_matrix(final_edges, num_layers, num_heads_per_layer):
    """Build the adjacency matrix from the final edges."""
    num_nodes = num_layers * (num_heads_per_layer + 1)
    adj_matrix = np.zeros((num_nodes, num_nodes))

    label_to_index = {}
    labels = create_labels(num_layers, num_heads_per_layer)
    
    for idx, label in enumerate(labels):
        label_to_index[label] = idx

    for src, dst, weight in final_edges:
        src_label = f'{src[0]}_{src[1]}' if src[0] == 'mlp_feature' else f'{src[0]}_{src[1]}_{src[2]}'
        dst_label = f'{dst[0]}_{dst[1]}' if dst[0] == 'mlp_feature' else f'{dst[0]}_{dst[1]}_{dst[2]}'
        
        src_idx = label_to_index[src_label]
        dst_idx = label_to_index[dst_label]
        
        adj_matrix[src_idx, dst_idx] = weight

    return adj_matrix, labels

In [61]:
adjacency_matrix, labels = build_adjacency_matrix(final_edges, 12, 12)

# Print the adjacency matrix
print(adjacency_matrix)

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


In [62]:
# plotly imshow for adjacency matrix
import plotly.express as px

fig = px.imshow(adjacency_matrix, zmax=1, color_continuous_scale='blues', width=600)
fig.show()

In [63]:
def adj_matrix_to_pred(adjacency_matrix):
    # Go through the adjacency matrix and add up the total weights for each source
    total_weights = np.sum(adjacency_matrix, axis=0)
    # If zero set to -inf
    total_weights[total_weights == 0] = -np.inf
    # Normalise with softmax
    total_weights = np.exp(total_weights) / np.sum(np.exp(total_weights))

    # Go through total weights and remove every 13th element
    y_pred = np.zeros(144)
    for i in range(144):
        if i % 13 != 0:
            y_pred[i] = total_weights[i]

    return y_pred
    

In [70]:
# Go through the adjacency matrix and add up the total weights for each source
total_weights = np.sum(adjacency_matrix, axis=1)
# total_weights += np.sum(adjacency_matrix, axis=0)
# If zero set to -inf
total_weights[total_weights == 0] = -np.inf
# Normalise with softmax
total_weights = np.exp(total_weights) / np.sum(np.exp(total_weights))

total_weights

array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 4.87286764e-35, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 7.91132108e-43, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.00867541e-45,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
      

In [71]:
# Go through total weights and remove every 13th element
y_pred = np.zeros(144)
for i in range(144):
    if i % 13 != 0:
        y_pred[i] = total_weights[i]

In [72]:
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS

IOI_GROUND_TRUTH_HEADS = IOI_GROUND_TRUTH_HEADS.flatten()

IOI_GROUND_TRUTH_HEADS.shape

torch.Size([144])

In [73]:
IOI_GROUND_TRUTH_HEADS

tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0.,
        1., 1., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0.])

In [74]:
y_pred

array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 4.87286764e-35, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 7.91132108e-43, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.00867541e-45,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
      

In [75]:
from sklearn.metrics import roc_auc_score

roc_auc_score(IOI_GROUND_TRUTH_HEADS, y_pred)

0.4661016949152542

## Get cumulative adjacency matrix for a bunch of prompts

In [19]:
from data.ioi_dataset import NAMES, SINGLE_TOKEN_NAMES, ABBA_TEMPLATES, gen_prompt_uniform, BABA_TEMPLATES, NOUNS_DICT, gen_templated_prompts
from transformer_lens import HookedTransformer, utils, ActivationCache
import torch


p = gen_prompt_uniform(
    [BABA_TEMPLATES[0], ABBA_TEMPLATES[0]], SINGLE_TOKEN_NAMES, NOUNS_DICT, 10, True
)

In [20]:
# Get list of prompts
prompts = []
for prompt in p:
    prompts.append(prompt['text'] + " " + prompt["IO"])

In [21]:
prompts

['Then, Ryan and Charles went to the restaurant. Charles gave a snack to Ryan',
 'Then, Charles and Ryan went to the restaurant. Ryan gave a snack to Charles',
 'Then, Eric and Joseph went to the house. Joseph gave a bone to Eric',
 'Then, Joseph and Eric went to the house. Eric gave a bone to Joseph',
 'Then, Paul and Steven went to the office. Steven gave a bone to Paul',
 'Then, Steven and Paul went to the office. Paul gave a bone to Steven',
 'Then, James and Jessica went to the station. Jessica gave a bone to James',
 'Then, Jessica and James went to the station. James gave a bone to Jessica',
 'Then, Rachel and Kyle went to the garden. Rachel gave a ring to Kyle',
 'Then, Kyle and Rachel went to the garden. Kyle gave a ring to Rachel']

In [22]:
from tqdm import tqdm
import numpy as np

agg_adj_matrix = np.zeros((156, 156))

for prompt in tqdm(prompts):

    cd = CircuitDiscovery(prompt=prompt)

    # Assume cd is an instance of the CircuitDiscovery class
    M = 5  # Number of times to call the greedy function
    k = 3  # Top k contributors at each step

    for _ in range(M):
        cd.greedily_add_top_contributors(k=k)

    nodes, edges = collect_nodes_and_edges(cd)

    cleaned_edges = []
    for src, dst, weight in edges:
        if src[0] == 'attn_head':
            src = (src[0], src[1], src[2])
        elif src[0] == 'mlp_feature':
            src = (src[0], src[1])
        
        if dst[0] == 'attn_head':
            dst = (dst[0], dst[1], dst[2])
        elif dst[0] == 'mlp_feature':
            dst = (dst[0], dst[1])

        cleaned_edges.append((src, dst, weight))

    
    # Filter and aggregate edges
    final_edges = filter_and_aggregate_edges(cleaned_edges)

    adjacency_matrix, labels = build_adjacency_matrix(final_edges, 12, 12)

    # add to agg_adj_matrix
    agg_adj_matrix += adjacency_matrix

    del adjacency_matrix
    del labels
    del cd

100%|██████████| 10/10 [00:38<00:00,  3.83s/it]


In [23]:
# imshow agg matrix
fig = px.imshow(agg_adj_matrix, zmax=1, color_continuous_scale='blues', 
                labels={'x': 'Destination', 'y': 'Source'}, width=600)
fig.show()

In [28]:
def adj_matrix_to_pred(adjacency_matrix):
    # Go through the adjacency matrix and add up the total weights for each source
    total_weights = np.sum(adjacency_matrix, axis=1)
    total_weights += np.sum(adjacency_matrix, axis=0)
    # If zero set to -inf
    total_weights[total_weights == 0] = -np.inf
    # Normalise with softmax
    total_weights = np.exp(total_weights) / np.sum(np.exp(total_weights))

    # Go through total weights and remove every 13th element
    y_pred = np.zeros(144)
    for i in range(144):
        if i % 13 != 0:
            y_pred[i] = total_weights[i]

    return y_pred

y_pred = adj_matrix_to_pred(agg_adj_matrix)

In [29]:
roc = roc_auc_score(IOI_GROUND_TRUTH_HEADS, y_pred)
print(f"ROC AUC: {roc}")

ROC AUC: 0.4827249022164276


## Task Evaluator

In [52]:
# %%

%load_ext autoreload
%autoreload 2


# %%
import torch
import time
import plotly.express as px
import matplotlib.pyplot as plt

from task_evaluation import TaskEvaluation
from data.ioi_dataset import gen_templated_prompts
from data.greater_than_dataset import generate_greater_than_dataset
from circuit_discovery import CircuitDiscovery, only_feature
from circuit_lens import CircuitComponent
from plotly_utils import *
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS
from data.greater_than_dataset import GT_GROUND_TRUTH_HEADS
from memory import get_gpu_memory
from sklearn import metrics
from tqdm import trange

from utils import get_attn_head_roc


# %%
torch.set_grad_enabled(False)
# %%
#dataset_prompts = gen_templated_prompts(template_idex=1, N=500)


dataset_prompts = generate_greater_than_dataset(N=100)


# %%

def component_filter(component: str):
    return component in [
        CircuitComponent.Z_FEATURE,
        CircuitComponent.MLP_FEATURE,
        CircuitComponent.ATTN_HEAD,
        CircuitComponent.UNEMBED,
        # CircuitComponent.UNEMBED_AT_TOKEN,
        CircuitComponent.EMBED,
        CircuitComponent.POS_EMBED,
        # CircuitComponent.BIAS_O,
        CircuitComponent.Z_SAE_ERROR,
        # CircuitComponent.Z_SAE_BIAS,
        # CircuitComponent.TRANSCODER_ERROR,
        # CircuitComponent.TRANSCODER_BIAS,
    ]


pass_based = True

passes = 5
node_contributors = 1
first_pass_minimal = True

sub_passes = 3
do_sub_pass = False
layer_thres = 9
minimal = True


num_greedy_passes = 20
k = 1
N = 30

thres = 4

def strategy(cd: CircuitDiscovery):
    if pass_based:
        for _ in range(passes):
            cd.add_greedy_pass(contributors_per_node=node_contributors, minimal=first_pass_minimal)

            if do_sub_pass:
                for _ in range(sub_passes):
                    cd.add_greedy_pass_against_all_existing_nodes(contributors_per_node=node_contributors, skip_z_features=True, layer_threshold=layer_thres, minimal=minimal)
    else:
        for _ in range(num_greedy_passes):
            cd.greedily_add_top_contributors(k=k, reciever_threshold=thres)



task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)

cd = task_eval.get_circuit_discovery_for_prompt(20)
# f = task_eval.get_features_at_heads_over_dataset(N=30)
N = 100

attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True)


# %%
ground_truth = GT_GROUND_TRUTH_HEADS #IOI_GROUND_TRUTH_HEADS

# fp, tp, thresh = get_attn_head_roc(ground_truth, a.flatten().softmax(dim=-1), "IOI", visualize=True, additional_title="(No Counterfactuals)")
score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "GT", visualize=True, additional_title="(No Counterfactuals)")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


KeyboardInterrupt: 

In [42]:
from tqdm import tqdm

#dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
dataset_prompts = generate_greater_than_dataset(N=100)
N_list = [1, 5, 10, 50, 100]
roc_scores_gt = []
for N in tqdm(N_list):

    task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)
    cd = task_eval.get_circuit_discovery_for_prompt(20)
    attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True)
    ground_truth = GT_GROUND_TRUTH_HEADS #IOI_GROUND_TRUTH_HEADS
    score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "GT", visualize=False, additional_title="(No Counterfactuals)")
    roc_scores_gt.append(score)

100%|██████████| 1/1 [00:01<00:00,  1.80s/it]


100%|██████████| 5/5 [00:08<00:00,  1.69s/it]


100%|██████████| 10/10 [00:16<00:00,  1.65s/it]


100%|██████████| 50/50 [01:24<00:00,  1.69s/it]


100%|██████████| 100/100 [02:51<00:00,  1.71s/it]


 50%|█████     | 100/200 [02:59<02:59,  1.79s/it]
 83%|████████▎ | 5/6 [07:59<01:35, 95.87s/it]


IndexError: list index out of range

In [47]:
dataset_prompts = gen_templated_prompts(template_idex=1, N=100)
N_list = [1, 5, 10, 50, 100]
roc_scores_ioi = []
for N in tqdm(N_list):

    task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)
    cd = task_eval.get_circuit_discovery_for_prompt(20)
    attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True)
    ground_truth = IOI_GROUND_TRUTH_HEADS
    score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "GT", visualize=False, additional_title="(No Counterfactuals)")
    roc_scores_ioi.append(score)

100%|██████████| 1/1 [00:02<00:00,  2.21s/it]


100%|██████████| 5/5 [00:10<00:00,  2.08s/it]


100%|██████████| 10/10 [00:21<00:00,  2.12s/it]


100%|██████████| 50/50 [01:59<00:00,  2.39s/it]


100%|██████████| 100/100 [03:53<00:00,  2.34s/it]


100%|██████████| 5/5 [06:43<00:00, 80.76s/it] 


In [51]:
import plotly.graph_objects as go

# Create traces
trace_ioi = go.Scatter(x=N_list, y=roc_scores_ioi, mode='lines', name='IOI')
trace_gt = go.Scatter(x=N_list, y=roc_scores, mode='lines', name='GT')

# Create layout
layout = go.Layout(xaxis_title='No. examples', yaxis_title='ROC', width=600)

# Create figure
fig = go.Figure(data=[trace_ioi, trace_gt], layout=layout)

# Show figure
fig.show()

In [46]:
# Plotly line plot for ROC scores
fig = px.line(x=N_list[:-1], y=roc_scores, labels={'x': 'No. examples', 'y': 'ROC AUC Score'}, width=600)
# 
fig.show()

In [34]:
# %%
import torch
import time
import plotly.express as px
import matplotlib.pyplot as plt

from task_evaluation import TaskEvaluation
from data.ioi_dataset import gen_templated_prompts
from data.greater_than_dataset import generate_greater_than_dataset
from circuit_discovery import CircuitDiscovery, only_feature
from circuit_lens import CircuitComponent
from plotly_utils import *
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS
# from data.ioi_dataset import GT_GROUND_TRUTH_HEADS
from memory import get_gpu_memory
from sklearn import metrics
from tqdm import trange

from utils import get_attn_head_roc

# Autoreload
%load_ext autoreload
%autoreload 2

# %%
torch.set_grad_enabled(False)

# %%
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)

# dataset_prompts = generate_greater_than_dataset(N=100)

# %%
def component_filter(component: str):
    return component in [
        CircuitComponent.Z_FEATURE,
        CircuitComponent.MLP_FEATURE,
        CircuitComponent.ATTN_HEAD,
        CircuitComponent.UNEMBED,
        # CircuitComponent.UNEMBED_AT_TOKEN,
        CircuitComponent.EMBED,
        CircuitComponent.POS_EMBED,
        # CircuitComponent.BIAS_O,
        CircuitComponent.Z_SAE_ERROR,
        # CircuitComponent.Z_SAE_BIAS,
        # CircuitComponent.TRANSCODER_ERROR,
        # CircuitComponent.TRANSCODER_BIAS,
    ]

pass_based = True

passes = 5
node_contributors = 1
first_pass_minimal = True

sub_passes = 3
do_sub_pass = False
layer_thres = 9
minimal = True

num_greedy_passes = 20
k = 1
N = 30

thres = 4

def strategy(cd: CircuitDiscovery):
    if pass_based:
        for _ in range(passes):
            cd.add_greedy_pass(contributors_per_node=node_contributors, minimal=first_pass_minimal)

            if do_sub_pass:
                for _ in range(sub_passes):
                    cd.add_greedy_pass_against_all_existing_nodes(contributors_per_node=node_contributors, skip_z_features=True, layer_threshold=layer_thres, minimal=minimal)
    else:
        for _ in range(num_greedy_passes):
            cd.greedily_add_top_contributors(k=k, reciever_threshold=thres)

task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)

cd = task_eval.get_circuit_discovery_for_prompt(20)
# f = task_eval.get_features_at_heads_over_dataset(N=30)
N = 5

attn_freqs = task_eval.get_weighted_attn_head_freqs_over_dataset(N=N, visualize=True, return_freqs=True)

print(attn_freqs.shape)

# Softmax across row (do not flatten, apply softmax across rows of matrix)
# attn_freqs = attn_freqs.softmax(dim=-1)

# %%
ground_truth = IOI_GROUND_TRUTH_HEADS

# fp, tp, thresh = get_attn_head_roc(ground_truth, a.flatten().softmax(dim=-1), "IOI", visualize=True, additional_title="(No Counterfactuals)")
score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "IOI", visualize=True, additional_title="(No Counterfactuals)")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


100%|██████████| 5/5 [00:12<00:00,  2.53s/it]


torch.Size([12, 12])
Score: 0.6756844850065189


In [27]:
px.imshow(attn_freqs, zmax=1, color_continuous_scale='blues', width=600).show()
px.imshow(ground_truth, zmax=1, color_continuous_scale='blues', width=600).show()