In [1]:
from auto_circuit.tasks import IOI_TOKEN_CIRCUIT_TASK, DOCSTRING_TOKEN_CIRCUIT_TASK, DOCSTRING_COMPONENT_CIRCUIT_TASK

from elk_experiments.auto_circuit.edge_graph import SeqGraph, sample_path_uniform, get_edge_path_counts, sample_paths, edge_in_path, valid_node, NodeType, node_name_to_type

In [2]:
task = DOCSTRING_COMPONENT_CIRCUIT_TASK
task.init_task()

Loaded pretrained model attn-only-4l into HookedTransformer
seq_len before divergence None
seq_len after divergence None


# Full Sequence Graph 

In [3]:
graph = SeqGraph(task.model.edges, task.token_circuit, task.model.cfg.attn_only)

In [4]:
# check reachable
len([(v.layer, v.node_type, v.seq_idx) for v in graph.seq_nodes if v.reachable]) / len(graph.seq_nodes)

1.0

In [5]:
# check pathcounts 
len([node for node in graph.seq_nodes if node.path_count > 0]) / len(graph.seq_nodes)

1.0

In [6]:
# check edge in path 
len([edge for edge in task.true_edges if edge_in_path(edge, graph)]) / task.true_edge_count

1.0

In [7]:
# check edge in path
all([not graph.valid_node(edge.dest.layer, node_name_to_type(edge.dest.name), edge.seq_idx) for edge in task.true_edges if not edge_in_path(edge, graph)])

True

In [8]:
# check edge patch counts
complement_edges = set(task.model.edges) - set(task.true_edges)
edge_path_counts = get_edge_path_counts(complement_edges, graph)
sum([v for v in edge_path_counts.values()])

VBox(children=(          | 0/? [00:00<?, ?it/s],))

656250

In [9]:
# check sample path uniform
path = sample_path_uniform(graph, edge_path_counts, complement_edges)
assert any(edge in complement_edges for edge in path)
[(edge, edge.seq_idx) for edge in path]


[(Resid Start->A0.2.K, None),
 (A0.2->A1.4.V, None),
 (A1.4->A2.3.K, None),
 (A2.3->A3.1.V, None),
 (A3.1->Resid End, None)]

In [10]:
sample_paths(graph, 256, complement_edges)

VBox(children=(          | 0/? [00:00<?, ?it/s],))

VBox(children=(          | 0/256 [00:00<?, ?it/s],))

[[Resid Start->A0.1.Q,
  A0.1->A1.5.V,
  A1.5->A2.5.Q,
  A2.5->A3.5.K,
  A3.5->Resid End],
 [Resid Start->A0.4.Q,
  A0.4->A1.3.K,
  A1.3->A2.5.K,
  A2.5->A3.4.V,
  A3.4->Resid End],
 [Resid Start->A0.7.K,
  A0.7->A1.5.V,
  A1.5->A2.2.K,
  A2.2->A3.1.V,
  A3.1->Resid End],
 [Resid Start->A0.2.K,
  A0.2->A1.5.K,
  A1.5->A2.1.V,
  A2.1->A3.1.V,
  A3.1->Resid End],
 [Resid Start->A0.4.K,
  A0.4->A1.4.K,
  A1.4->A2.4.V,
  A2.4->A3.4.Q,
  A3.4->Resid End],
 [Resid Start->A0.1.V,
  A0.1->A1.2.Q,
  A1.2->A2.7.V,
  A2.7->A3.2.Q,
  A3.2->Resid End],
 [Resid Start->A0.7.Q,
  A0.7->A1.2.K,
  A1.2->A2.1.V,
  A2.1->A3.4.Q,
  A3.4->Resid End],
 [Resid Start->A0.4.Q, A0.4->A2.5.Q, A2.5->A3.7.K, A3.7->Resid End],
 [Resid Start->A0.0.V,
  A0.0->A1.6.K,
  A1.6->A2.6.V,
  A2.6->A3.1.V,
  A3.1->Resid End],
 [Resid Start->A0.3.K,
  A0.3->A1.4.K,
  A1.4->A2.7.K,
  A2.7->A3.7.Q,
  A3.7->Resid End],
 [Resid Start->A1.6.Q, A1.6->A2.2.V, A2.2->A3.3.Q, A3.3->Resid End],
 [Resid Start->A0.0.Q,
  A0.0->A1.7.Q,
  A1

# Circuit Sequence Graph

In [11]:
circ_graph = SeqGraph(task.true_edges, task.token_circuit, task.model.cfg.attn_only)

In [12]:
# check reachable
len([(v.layer, v.node_type, v.seq_idx) for v in circ_graph.seq_nodes if v.reachable]) / len(circ_graph.seq_nodes)

1.0