In [1]:
import transformer_lens as tl
import torch as t
from acdc.TLACDCExperiment import TLACDCCorrespondence
import dataclasses as dc
from typing import Optional

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")
# Load the pretrained "attn-only-2l" checkpoint as a HookedTransformer.
model = tl.HookedTransformer.from_pretrained("attn-only-2l")

# NOTE. The two calls below are copied from the official ACDC repo.
# With them in place, the activation cache printed later in this notebook
# contains per-head `blocks.N.attn.hook_result` entries and separate per-head
# `blocks.N.hook_{q,k,v}_input` entries, which are the hook points the graph
# code below refers to.
# NOTE(review): presumably these flags are what enable those hooks — confirm
# against the TransformerLens documentation.

model.set_use_attn_result(True)     
model.set_use_split_qkv_input(True)

Loaded pretrained model attn-only-2l into HookedTransformer


In [2]:
# Print the name and shape of every parameter tensor registered on the model.
for name, parameters in model.named_parameters():
    print(f"{name}: {parameters.shape}")

embed.W_E: torch.Size([48262, 512])
pos_embed.W_pos: torch.Size([1024, 512])
blocks.0.attn.W_Q: torch.Size([8, 512, 64])
blocks.0.attn.W_O: torch.Size([8, 64, 512])
blocks.0.attn.b_Q: torch.Size([8, 64])
blocks.0.attn.b_O: torch.Size([512])
blocks.0.attn.W_K: torch.Size([8, 512, 64])
blocks.0.attn.W_V: torch.Size([8, 512, 64])
blocks.0.attn.b_K: torch.Size([8, 64])
blocks.0.attn.b_V: torch.Size([8, 64])
blocks.1.attn.W_Q: torch.Size([8, 512, 64])
blocks.1.attn.W_O: torch.Size([8, 64, 512])
blocks.1.attn.b_Q: torch.Size([8, 64])
blocks.1.attn.b_O: torch.Size([512])
blocks.1.attn.W_K: torch.Size([8, 512, 64])
blocks.1.attn.W_V: torch.Size([8, 512, 64])
blocks.1.attn.b_K: torch.Size([8, 64])
blocks.1.attn.b_V: torch.Size([8, 64])
unembed.W_U: torch.Size([512, 48262])
unembed.b_U: torch.Size([48262])


In [3]:
# NOTE. These are the edges built by the ACDC package itself. We may later
# want a helper that compares our own graph against this correspondence.

correspondence = TLACDCCorrespondence.setup_from_model(model)
# Bare last expression so the notebook displays the nested edge mapping.
correspondence.edges


OrderedDefaultdict(<function acdc.acdc_utils.make_nd_dict.<locals>.<lambda>()>,
                   {'blocks.1.hook_resid_post': defaultdict(<function acdc.acdc_utils.make_nd_dict.<locals>.<lambda>.<locals>.<lambda>()>,
                                {[:]: defaultdict(<function acdc.acdc_utils.make_nd_dict.<locals>.<lambda>.<locals>.<lambda>.<locals>.<lambda>()>,
                                             {'blocks.1.attn.hook_result': defaultdict(None,
                                                          {[:, :, 7]: Edge(EdgeType.ADDITION, True),
                                                           [:, :, 6]: Edge(EdgeType.ADDITION, True),
                                                           [:, :, 5]: Edge(EdgeType.ADDITION, True),
                                                           [:, :, 4]: Edge(EdgeType.ADDITION, True),
                                                           [:, :, 3]: Edge(EdgeType.ADDITION, True),
                                    

In [4]:
# NOTE. I got this code from ARENA.

def generate_repeated_tokens(
    model: tl.HookedTransformer, seq_len: int, batch: int = 1
) -> t.Tensor:
    """Build sequences made of a BOS token followed by a random half repeated twice.

    Args:
        model: Supplies the BOS token id (``model.tokenizer.bos_token_id``) and
            the vocabulary size (``model.cfg.d_vocab``) used for sampling.
        seq_len: Length of the random half; it appears twice in every row.
        batch: Number of sequences to generate.

    Returns:
        An int64 tensor of shape ``(batch, 1 + 2 * seq_len)``, moved to the
        notebook-level ``device``.
    """
    bos_column = t.full((batch, 1), model.tokenizer.bos_token_id, dtype=t.int64)
    random_half = t.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=t.int64)
    # Tile the random half along the sequence axis so the second half of each
    # row is an exact copy of the first.
    tokens = t.cat([bos_column, random_half.repeat(1, 2)], dim=-1)
    return tokens.to(device)

# 100 sequences, each 1 BOS token + 2 * 50 random tokens, i.e. a (100, 101) tensor.
# NOTE(review): no random seed is set in the visible cells, so the sampled
# tokens presumably differ between runs — confirm whether that is intended.
repeated_tokens = generate_repeated_tokens(model=model, seq_len=50, batch=100)


In [5]:
# Run the model on the repeated tokens and keep the activation cache it returns.
logits, cache = model.run_with_cache(repeated_tokens)
# Print the hook name and tensor shape of every cached activation.
for name, values in cache.items():
    print(f"{name}: {values.shape}")

hook_embed: torch.Size([100, 101, 512])
hook_pos_embed: torch.Size([100, 101, 512])
blocks.0.hook_resid_pre: torch.Size([100, 101, 512])
blocks.0.hook_q_input: torch.Size([100, 101, 8, 512])
blocks.0.hook_k_input: torch.Size([100, 101, 8, 512])
blocks.0.hook_v_input: torch.Size([100, 101, 8, 512])
blocks.0.ln1.hook_scale: torch.Size([100, 101, 8, 1])
blocks.0.ln1.hook_normalized: torch.Size([100, 101, 8, 512])
blocks.0.attn.hook_q: torch.Size([100, 101, 8, 64])
blocks.0.attn.hook_k: torch.Size([100, 101, 8, 64])
blocks.0.attn.hook_v: torch.Size([100, 101, 8, 64])
blocks.0.attn.hook_attn_scores: torch.Size([100, 8, 101, 101])
blocks.0.attn.hook_pattern: torch.Size([100, 8, 101, 101])
blocks.0.attn.hook_z: torch.Size([100, 101, 8, 64])
blocks.0.attn.hook_result: torch.Size([100, 101, 8, 512])
blocks.0.hook_attn_out: torch.Size([100, 101, 512])
blocks.0.hook_resid_post: torch.Size([100, 101, 512])
blocks.1.hook_resid_pre: torch.Size([100, 101, 512])
blocks.1.hook_q_input: torch.Size([100,

In [13]:
@dc.dataclass
class TransformerNode:
    """A hook point of the model, optionally narrowed to one attention head.

    Attributes:
        hook_name: Name of the hook, e.g. ``blocks.0.attn.hook_q``.
        head_idx: Index of the attention head, or ``None`` when the node is not
            head-specific (the final residual stream node and the embeddings).
    """

    hook_name: str
    head_idx: Optional[int] = None


@dc.dataclass
class TransformerEdge:
    """A directed edge: ``parent_node`` feeds into ``child_node``."""

    parent_node: TransformerNode
    child_node: TransformerNode


# NOTE. Translated from
# https://github.com/ArthurConmy/Automatic-Circuit-Discovery/blob/main/acdc/TLACDCCorrespondence.py
# The ACDC edge type of each connection is kept as a short comment next to the
# place where the node or edge is created, for future reference.

@dc.dataclass
class TransformerGraph:
    """Computational graph of an attention-only transformer.

    Attributes:
        nodes: Every node of the graph, in creation order.
        edges: Every parent -> child edge of the graph, in creation order.
    """

    nodes: list[TransformerNode] = dc.field(default_factory=list)
    edges: list[TransformerEdge] = dc.field(default_factory=list)

    @classmethod
    def from_model(cls, model: "tl.HookedTransformer") -> "TransformerGraph":
        """Build the graph of heads, q/k/v inputs and embeddings for ``model``.

        Layers and heads are visited from last to first. Each head's output is
        connected to every residual stream reader collected so far: the final
        ``hook_resid_post`` node plus the q/k/v input nodes of all later layers.

        Args:
            model: Only ``model.cfg.n_layers`` and ``model.cfg.n_heads`` are read.

        Returns:
            A new graph holding all created nodes and edges.
        """
        graph = cls()
        n_layers = model.cfg.n_layers
        n_heads = model.cfg.n_heads

        # ACDC incoming edge type: ADDITION.
        logits_node = TransformerNode(
            hook_name=f"blocks.{n_layers - 1}.hook_resid_post",
        )
        graph.nodes.append(logits_node)

        # Nodes that read from the residual stream downstream of the layer
        # currently being processed.
        residual_stream_nodes: list[TransformerNode] = [logits_node]

        for layer_idx in range(n_layers - 1, -1, -1):
            new_residual_stream_nodes: list[TransformerNode] = []

            for head_idx in range(n_heads - 1, -1, -1):
                # ACDC incoming edge type: PLACEHOLDER.
                cur_head = TransformerNode(
                    hook_name=f"blocks.{layer_idx}.attn.hook_result",
                    head_idx=head_idx,
                )
                graph.nodes.append(cur_head)

                # ACDC edge type: ADDITION (head output is added to the residual stream).
                graph.edges.extend(
                    TransformerEdge(cur_head, residual_stream_node)
                    for residual_stream_node in residual_stream_nodes
                )

                for letter in "qkv":
                    # ACDC incoming edge type: DIRECT_COMPUTATION.
                    letter_hook_node = TransformerNode(
                        hook_name=f"blocks.{layer_idx}.attn.hook_{letter}",
                        head_idx=head_idx,
                    )
                    graph.nodes.append(letter_hook_node)

                    # ACDC incoming edge type: ADDITION.
                    letter_input_hook_node = TransformerNode(
                        hook_name=f"blocks.{layer_idx}.hook_{letter}_input",
                        head_idx=head_idx,
                    )
                    graph.nodes.append(letter_input_hook_node)

                    # ACDC edge type: PLACEHOLDER (q/k/v -> head result).
                    graph.edges.append(TransformerEdge(letter_hook_node, cur_head))
                    # ACDC edge type: DIRECT_COMPUTATION (q/k/v input -> q/k/v).
                    graph.edges.append(
                        TransformerEdge(letter_input_hook_node, letter_hook_node)
                    )

                    # Store the node itself (not its hook name) so that edges
                    # created for earlier layers get a TransformerNode child.
                    new_residual_stream_nodes.append(letter_input_hook_node)

            # Heads of the same layer do not feed each other, so this layer's
            # inputs only become visible once the whole layer is done.
            residual_stream_nodes.extend(new_residual_stream_nodes)

        # NOTE. For now, we will always include the positional embedding node.
        # The names match the `hook_embed` / `hook_pos_embed` entries of the
        # activation cache. ACDC incoming edge type: PLACEHOLDER.
        for embed_hook_name in ("hook_embed", "hook_pos_embed"):
            embed_node = TransformerNode(hook_name=embed_hook_name)
            graph.nodes.append(embed_node)

            # ACDC edge type: ADDITION.
            graph.edges.extend(
                TransformerEdge(embed_node, residual_stream_node)
                for residual_stream_node in residual_stream_nodes
            )

        return graph


In [14]:
# Build our own graph for the loaded model and display its edge list.
graph = TransformerGraph.from_model(model)
graph.edges

[TransformerEdge(parent_node=TransformerNode(hook_name='blocks.1.attn.hook_result', head_idx=7), child_node=TransformerNode(hook_name='blocks.1.hook_resid_post', head_idx=None)),
 TransformerEdge(parent_node=TransformerNode(hook_name='blocks.1.attn.hook_q', head_idx=7), child_node=TransformerNode(hook_name='blocks.1.attn.hook_result', head_idx=7)),
 TransformerEdge(parent_node=TransformerNode(hook_name='blocks.1.hook_q_input', head_idx=7), child_node=TransformerNode(hook_name='blocks.1.attn.hook_q', head_idx=7)),
 TransformerEdge(parent_node=TransformerNode(hook_name='blocks.1.attn.hook_k', head_idx=7), child_node=TransformerNode(hook_name='blocks.1.attn.hook_result', head_idx=7)),
 TransformerEdge(parent_node=TransformerNode(hook_name='blocks.1.hook_k_input', head_idx=7), child_node=TransformerNode(hook_name='blocks.1.attn.hook_k', head_idx=7)),
 TransformerEdge(parent_node=TransformerNode(hook_name='blocks.1.attn.hook_v', head_idx=7), child_node=TransformerNode(hook_name='blocks.1.at