# 0. **Setup**

In [2]:
from utils import Graph, get_dataloader
from utils.interp import *
from utils.dataset.build import generate_planar, generate_misclass 
from models import get_model

from transformer_lens import FactoredMatrix
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import networkx as nx
import torch
from torch import Tensor
import einops
import wandb
from jaxtyping import Float
import numpy as np
import rich
from rich.table import Table

from tqdm.auto import tqdm
from functools import partial

**Load Model**

In [None]:
path = 'trained_models/transformer_2024-09-05_13-28-07'

wandb.init(
    config=path+'.yaml',
    mode='disabled'
)

args = wandb.config

model = get_model(args)
model.eval()
device = model.cfg.device
state_dict = torch.load(path+'.pt',map_location=device)
model.load_state_dict(state_dict)


**Test model on val set and random graphs**

In [None]:
loader = get_dataloader('adj','data/n10_15-21_tn_vn/val.npz',1)
n_planar = n_correct = 0
for x,y in loader:
    # plot_graph(adj_to_graph(x.squeeze()))
    logits = model(x)
    correct = logits.argmax()==y
    n_planar += y.item()*correct
    n_correct += correct

print(f"Val Set Accuracy: {(100*n_correct/len(loader)).item():.2f}%")
print(f"{(100*n_planar/n_correct).item():.2f}% of correctly classified graphs were planar.")
print()

for m in range(22,30):
    n_planar = n_correct = 0
    n_samples = 1000
    for _ in range(n_samples):
        g = Graph(nx.gnm_random_graph(args.n_vertices,m))
        y = nx.is_planar(g.G)
        logits = model(g.A)
        correct = logits.argmax()==y
        n_planar += y*correct
        n_correct += correct

        # print(f"y: {y}")
        # print(f"logits.argmax(): {logits.argmax()}")
        # print(f"correct: {correct}")
        # plt.figure(figsize=(8, 6))
        # nx.draw(g, with_labels=True, node_color='lightblue', edge_color='gray', node_size=500, font_size=10)
        # plt.title("Random Graph")
        # plt.show()

    print(f"Accuracy on {n_samples} random graphs with {m} edges: {(100*n_correct/n_samples).item():.2f}%")
    print(f"{(100*n_planar/n_correct).item():.2f}% of correctly classified graphs were planar.")

**Cache activations from example graphs**

In [129]:
m = 16

random_graph = Graph(nx.gnm_random_graph(model.cfg.n_vertices,m))
petersen_graph = Graph(nx.petersen_graph())
cycle_graph = Graph(nx.cycle_graph(model.cfg.n_vertices))
complete_graph = Graph(nx.complete_graph(model.cfg.n_vertices))
empty_graph = Graph(nx.empty_graph(model.cfg.n_vertices))
star_graph = Graph(nx.star_graph(model.cfg.n_vertices-1))

loader = get_dataloader('adj','data/n10_15-21_tn_vn/val.npz',5000)
batch,batch_labels = next(iter(loader))

model.set_use_attn_result(True)
_, random_cache = model.run_with_cache(random_graph.A)
_, petersen_cache = model.run_with_cache(petersen_graph.A)
_, cycle_cache = model.run_with_cache(cycle_graph.A)
_, complete_cache = model.run_with_cache(complete_graph.A)
_, empty_cache = model.run_with_cache(empty_graph.A)
_, star_cache = model.run_with_cache(star_graph.A)
_, batch_cache = model.run_with_cache(batch)
model.set_use_attn_result(False)

In [None]:
n_planar = 0
for g in batch:
    graph = Graph(g)
    n_planar += nx.is_planar(graph.G)
print(f"{100*n_planar/len(batch):.2f}% of the batch are planar graphs.")

# 1. **Identifying a Circuit**

## 1.1 **Plot attention patterns**

We begin by looking at the attention patterns for specific inputs to the model. In the case of our model, attention patterns $A^h$ are 10 by 10 matrices that represent how each attention head $h$ moves information between the vertex positions of the input graph. For instance, if $A_{ij}^{1.3}$ is large, that means that lots of information is moved from vertex position $j$ to vertex position $i$. We say that vertex position $i$ *attends strongly* to vertex postion $j$.

I took the usual [CircuitsVis](https://github.com/TransformerLensOrg/CircuitsVis) way of visualising language model attention patterns (in which you visualise which tokens an input sentence attend to each other on the tokens themselves) and applied it to graph inputs. This means you can select an attention pattern and hover over a vertex to see which other vertices it attends to. 

It is important to remember that transformers are designed so that the attention patterns in each attention head *depend on the input*. Indeed, if we rename all the vertices of a graph but keep all the edges the same, we still want the vertices to attend to each other in the same way and so the model must adjust its attention patterns accordingly. It is therefore important to study the attnetion patterns for a whole range of input graphs. This way you can get a much better idea of what each head could be doing.

Have a go yourself at looking at the attention patterns for different input graphs. The example below is for the attention patterns in layer 0 for a randomly generated graph, but you can swap in any of the graphs that we cached during setup.

In [None]:
random_graph.plot_attention(random_cache["pattern",0])

**Observation**: In Head 1.6 vertices of low degree attend stongly to vertex position 2. Let's try and work out how the model does this.

My first guess as to how this attention pattern in head 1.6 comes about is:
- Each vertex of low degree has the concept of "I am a vertex of low degree" encoded in its residual stream position. 
- The query matrix for head 1.6 reads this information from the residual stream and produces queries that encode the concept "I am looking for the special vertex position" in low degree vertex positions.
- The key matrix for head 1.6 produces a key that says "I am in the special vertex position" in position 2. 

## 1.2 **Patching**

Let's test this theory by looking at how earlier components contribute to the attention score in head 1.6.

We are going to take a "corrupted" graph with (in some sense) no low degree vertices as model input and then "patch over" the paths leading from a collection of earlier components to the query and/or the key of head 1.6 using the cached head 0.2 activations for a "clean" graph with some low degree vertices. This will allow us to see to what extent we can reproduce the "clean" behaviour of head 1.6 using a subset of the components earlier in the model. We can use this to test my theory that degree information is passed into head 1.6 via the query and the positional information pinpointing vertex position 2 is passed in via the key. We should be able to go one further and actually identify which earlier components are responsible for passing this information on to head 1.6. 

By "components" I mean attention heads and the embedding. 

By "path" I mean the sequence of latent variables that lead from a component directly (without involving another component) to another component later in the model (in this case head 1.6).
"Patching over" a path from component $A$ to component $B$ means effectively replacing all these latent varaibles with what they would be in the hypothetical scenario that component $A$'s output was replaced but all other components outputs stayed the same (even for components that come after $A$ in the model). By doing this we can get a better idea of the extent to which $A$ helps $B$ achieve its function. 

Notice that I choose not to include the mlps as components. Instead we say that mlps count as part of the direct path between other components. This is because (in the case of this specific circuit) it turns out that all the important paths between other components go through the mlp on layer 0. Insisting on studying paths that skip the mlp obscures important relationships that depend on mlp computation.  

Because you can't in general simultaneously patch two different paths during the same forward pass, we implement the above as follows:
1. Choose two subsets of the components that come before head 1.6 in the model (i.e. Embed, Layer 1 heads)
2. Input a "corrupted" graph into the model
3. Replace only the output of the components within our first chosen subset with the equivalent output obtained when the "clean" graph is the input.
4. Continue the model forward pass and cache the query of head 1.6.
5. Begin another forward pass on the same corrupted output and this time replace only the output of the components within our *second* chosen subset.
6. Continue the forward pass until layer 1 and replace the query in head 1.6 with the cached query from the previous run.
7. Observe the resulting attention pattern in head 1.6 and hopefully pinpoint which components are needed to reproduce the "clean" attention pattern and in what way they compose with head 1.6.

This is a specific implementation of a broader technique known as *path patching*. In the cells below we carry out path patching with the following configuration details:
- `q_heads_to_patch` specifies heads $h$ for which the path from $h$ to the query of head 1.6 will be patched.
- `q_patch_embed` specifies whether or not to patch the path from the Embed to the query of head 1.6.
- `k_heads_to_patch` specifies heads $h$ for which the path from $h$ to the key of head 1.6 will be patched.
- `k_patch_embed` specifies whether or not to patch the path from the Embed to the key of head 1.6.

Have a play aroung with this configuration below and see if you can gain some insight that could help with step 7 above.

<details>
  <summary>Two things to note</summary>
  
PosEmbed is the same for all inputs so patching it is pointless.

The cell below runs path patching for several example configurations and, as such, the four variables described above are actually lists with the ith entry of each forming the configuration for the ith path patching example.
  
</details>

In [8]:
corrupted_input = petersen_graph.A
corrupted_cache = petersen_cache
clean_input = random_graph.A
clean_cache = random_cache

layer_1_head = 6

In [None]:
# Components to patch:
q_heads_to_patch = [[2,1],[3,4],[1,2,3]]
q_patch_embed = [True,True,False]
k_heads_to_patch = [[2,3],[2],[1,2,3]]
k_patch_embed = [True,False,False]
 
component_lists = [q_heads_to_patch,q_patch_embed,k_heads_to_patch,k_patch_embed]

n_examples = len(q_heads_to_patch)

assert all(len(lst) == n_examples for lst in component_lists)

px.imshow(random_cache["pattern",1][0,6],color_continuous_scale='Blues',title="Clean Pattern",width=1000).show()
px.imshow(petersen_cache["pattern",1][0,6],color_continuous_scale='Blues',title="Corrupted Pattern",zmax=1,width=1000).show()

for i in range(n_examples):
    pattern = get_path_patched_attn_pattern(
        model,
        corrupted_input,
        corrupted_cache,
        clean_input,
        clean_cache, 
        q_heads_to_patch = q_heads_to_patch[i],
        q_patch_embed = q_patch_embed[i],
        k_heads_to_patch = k_heads_to_patch[i],
        k_patch_embed = k_patch_embed[i],
        layer_1_head = layer_1_head
    )

    q_embed_string = " Embed and" if q_patch_embed[i] else ""
    k_embed_string = " Embed and" if k_patch_embed[i] else ""
    title = f"Pattern when the paths from{q_embed_string} heads {q_heads_to_patch[i]} to the query of head 1.{layer_1_head}<br>and the paths from{k_embed_string} heads {k_heads_to_patch[i]} to the key of head 1.{layer_1_head} are patched."

    px.imshow(
        pattern,
        color_continuous_scale='Blues',
        title=title,
        zmax=1,
        width=1000
    ).show()

Having played around with the setup above, I made the following observations:
1. I cannot reproduce any of the clean pattern purely by patching paths into the key of head 1.6. This suggests there is little K-composition occuring. This does not in any way rule out my theory that the information pinpointing vertex position 2 enters head 1.6 via the key (but it doesn't support it either).
2. Configurations involving patching the path from head 0.2 to the query of head 1.6 tend to be the ones that reproduce the clean pattern most closely. This suggests that head 0.2 is Q-composing with head 1.6.
3. In configurations that involve head 0.2 but still miss some key parts of the pattern, adding in heads 0.1, 0.3 and/or Embed can help performance. This suggests more than one head is involved in the circuit responsible for the function of head 1.6. It also suggests that Embed has direct effect on head 1.6, not just via layer 0 heads.

To be more confident in observations 2 and 3, we can quantify how well the pattern has been reproduced and then repeat the patching process above for all possible configurations and across a whole batch of graphs, recording how well the pattern has been reproduced each time.
When I say "all possible configurations" I am actually ignoring configurations involving patching paths from layer 0 heads to the key of head 1.6. This to reduce the number of possible configurations to a manageable amount. Based on observation 1, I am confident that this does not missing a configuration that offers a notably good reconstruction of the clean attention pattern.

I quantify the pattern reconstuction performance using the `patching_metric` below (scaled so that 1 represents no improvement over the corrupted input and 0 represents perfect reconstruction).

In [None]:
def patching_metric(
    reconstructed_pattern: Float[Tensor,"batch n_vertices n_vertices"],
    clean_pattern: Float[Tensor,"batch n_vertices n_vertices"],
    corrupted_pattern: Float[Tensor,"batch n_vertices n_vertices"] # batch can be 1 in some cases but broadcasting handles this
):
    return (torch.linalg.matrix_norm(reconstructed_pattern-clean_pattern)/torch.linalg.matrix_norm(clean_pattern-corrupted_pattern)).mean().item()

def get_component_string(results_key):
    """Given key from results dict, return dict expressing which components that key corresponds to."""
    q_components = 'Embed '*results_key[0][0] + ' '.join(f'0.{head}' for head in results_key[0][1:])
    k_components = 'Embed' if results_key[1][0] else ''
    return q_components, k_components

def get_n_components_patched(results_key):
    """Given key from results dict, find total number of components that have been patched"""
    return sum([
        results_key[0][0],
        results_key[1][0],
        len(results_key[0][1:]),
        len(results_key[1][1:])
    ])



results = get_path_patching_metric_results(
    model,
    corrupted_input,
    corrupted_cache,
    batch,
    batch_cache,
    patching_metric,
    layer_1_head
)
   
table = Table('Number of components patched','Best Score Achieved','Q Components Patched','K Components Patched',title='Best paths to patch for each number of components patched:')
one_component_table = Table("Q component", "K Component","Pattern Reconstruction Score", title='Results from patching one component')

for n_components_patched in range(1,model.cfg.n_heads+2):
    filtered_results = {k:v for k,v in results.items() if get_n_components_patched(k)==n_components_patched}
    if n_components_patched==1:
        for k,v in filtered_results.items():
            q_component, k_component = get_component_string(k)
            one_component_table.add_row(q_component, k_component,f"{v:.2f}")
    best_combination = min(filtered_results, key=filtered_results.get)
    best_q_components, best_k_components = get_component_string(best_combination)
    best_metric_score = filtered_results[best_combination]
    table.add_row(str(n_components_patched),f"{best_metric_score:.2f}",best_q_components,best_k_components)

rich.print(one_component_table)
rich.print(table)

The results above make me pretty confident that head 0.2 Q-composing with head 1.6 is the main interaction driving the circuit behind head 1.6's function.

Note that still haven't obtained any evidence supporting my theory that the information pinpointing vertex position 2 enters head 1.6 via the key. While it makes sense the key could have learned to produce a special key vector in position 2 that matches each vector that the query puts in a low degree vertex position, for all we know this information could be passed in via some other less intuitive mechanism. We will come back to address this theory later. 

## 1.3 **Head 0.2**

For now, let's follow this trail back through the model by looking more closely at head 0.2. We will again start by studying the attention pattern.

In [None]:
random_graph.plot_attention(random_cache['pattern',0])

**Observation**: Vertices of high degree attend more to vertices 3 and 6, whereas vertices of low degree seem to attend evenly to all vertices.

Although this pattern seems conceptually similar to head 1.6, there are three main differences that I observe:
1. Instead of vertex position 2, vertices attend to vertex positions 3 and 6.
2. Instead of low degree vertices, it is the high degree vertices that attend strongly to the special vertex positions.
3. Instead of seemingly all "low degree" vertices attending maximally to the special vetex position (like in head 1.6), here it seems like attention scales with degree: the higher a vertex's degree, the more strongly it attends to vertex positions 3 and 6.

Based on these observations, I have a similar theory for this heads mechansim as I did for head 1.6: the query receives the degree information and the key receives the positional information required for pinpointing positions 3 and 6.

As before, we will now test this theory. However, now that we are studying a layer 0 head we have two simplifications that we can take advantage of:
1. There are only two inputs to this head: the Embed and PosEmbed.
2. There are no non-linear mlps between us and the model input. 

This means we can use some more straight forward techniques to study how the Embed and PosEmbed interact with head 0.2.

Let $X$ be an adjacnecy matrix model input, let $W_Q^h$ and $W_K^h$ be the query and key matrices for head $h$ and let $W_{pos}$ and $W_E$ be the positional embedding and embedding matrices. Then the query $q$ and key $k$ of head 0.2 are given by:

$$q = XW_EW_Q^{0.2}+ W_{pos}W_Q^{0.2}$$

$$k = XW_EW_K^{0.2}+ W_{pos}W_K^{0.2}$$

This means we can nicely decompose $q$ and $k$ into their two terms and quantify how much each one contributes to the result. We will do this by taking the norm of each component of the sum and averaging out across a batch of inputs: 

In [None]:
decomposed_qk_input_layer_0 = decompose_qk_input(model, batch_cache, layer=0)
decomposed_q_layer_0 = decompose_q(model,decomposed_qk_input_layer_0,head_index=2,layer=0)
decomposed_k_layer_0 = decompose_k(model,decomposed_qk_input_layer_0,head_index=2,layer=0)

component_labels_layer_0 = ["Embed", "PosEmbed"]
for decomposed_input, name in [(decomposed_q_layer_0, "query"), (decomposed_k_layer_0, "key")]:
    px.imshow(
        decomposed_input.pow(2).sum(-1).sqrt().mean(0).detach(),
        labels={"x": "Position", "y": "Component"},
        title=f"Norms of components of {name}",
        y=component_labels_layer_0,
        color_continuous_scale='Blues',
        width=1000, height=400,
        zmin=0
    ).show()


As expected, these plots suggest that all the degree information enters head 0.2 via the query. Indeed, degree information can only come from the Embed and when averaged out across a whole batch, we expect this information to be even distributed across vertex positions. This matches what we observe in the plots above: specifically a solid dark blue stripe across the Embed row in the first plot as well as a very light blue strip in the same place in the second plot, suggesting that the key cannot be using much degree information. 

These plots also support my theory that the positional information pinpointing vertex positions 3 and 6 enters via the key. Indeed the second plot suggests that the information in these postions in the PosEmbed is consistently of particular importance to the key but the query seems to give similar importance to all positions in PosEmbed.

One additional observation of interest is that, as well as the degree information from the Embed, it seems that there is quite a lot of information from the PosEmbed entering head 0.2 via the query (suggested by the blue stripe across the PosEmbed row of the first plot). While this does not contradict any of my theories, it does suggest that there could be something more happening in head 0.2. Similarly, the are positions other than 3 and 6 in the key with reasonably sized PosEmbed component norms, also suggesting that there could be a more subtle purpose to head 0.2 as well.



One problem with the visualisations above is that, while it believable that it is correlated with the importance of a component, the norm obscures a lot of the detail when it comes to how the PosEmbed and Embed effect head 0.2. We can get a better idea of their effect by recalling that the attention scores for head 0.2 are given by $qk^{T}/\sqrt{d_{head}}$. This means we can use our decompositions of $q$ and $k$ from before to decompose the attention scores into four terms corrseponding to the contribution of each pairing of the two query components and the two key components to the overall attention score. This gives us a clearer idea of the impact each pairing of components has.

For this visualisation we have to specifiy a single graph input with `batch_sample_idx` as opposed to avergaing over a batch. Indeed, averaging over the norms makes some sense but attention scores are very input dependent so averaging over them would just blur a lot of detail.

In [None]:
decomposed_scores_layer_0 = decompose_attn_scores(decomposed_q_layer_0, decomposed_k_layer_0)

plot_decomposed_attn_scores(
    decomposed_scores=decomposed_scores_layer_0,
    cache=batch_cache,
    component_labels=component_labels_layer_0,
    layer=0,
    head_idx=2,
    batch_sample_idx=100,
    zmax=3.2
)

These plots offer much more support to my theory. The most important interactions for creating the attention pattern in head 0.2 are clearly the degree information passing from the Embed to the head via the query and the positional information passing from the PosEmbed to the head via the key. 

However, as we suspected, this is not the whole story. In particular, it seems that the query uses information from the PosEmbed to uniformly subtract from the attention scores in positions 3 and 6 (shown in the bottom right of the grid of four plots above). If the head's purpose is simply to allow vertices to attend to postions 3 and 6 based on their degree (in line with our observations), then this uniform subtraction seems unecessary. Is this just a quirk of the model or does it represent some key computation with an as yet unkown puropse? We will explore this more in the next section.    

# 2. **Zooming in**: following the circuit step by step

Now that we have a broad overview of which components are composing to form the circuit responsible for head 1.6's function, it is time to attempt to follow this circuit through the network from input to head 1.6 in more detail.

Hopefully by doing this we can answer questions like:
- How exactly does $W_Q^{0.2}$ give the query vectors a concept of degree based on $XW_E$?
- What does the mlp layer do to the information moving between heads 0.2 and 1.6?

We begin by trying to answer the first of these questions.

For this section we focus on one single input which we set below. I encourage you to swap this out for another graph that we cached during the setup in order to see how the circuit behaves for different inputs.

In [39]:
graph = random_graph
cache = random_cache

## 2.1 **Input -> Head 0.2**

Recall that the query vectors are given by:

$q = XW_EW_Q^{0.2}+ W_{pos}W_Q^{0.2}$,

and that multiplying by $X$ is just summing the rows of $W_EW_Q^{0.2}$ corresponding to the vertices adjacent to each vertex. 

This means that studying the rows of $W_EW_Q^{0.2}$ and $W_{pos}W_Q^{0.2}$ will help us decompose what $W_Q^{0.2}$ does into smaller parts instead of looking at what it does to the entire residual stream.

In [None]:
layer = 0
head = 2

W_E = model.W_E.detach()
W_pos = model.W_pos.detach()
W_Q = model.W_Q[layer,head].detach()
W_K = model.W_K[layer,head].detach()
resid = cache['resid_pre',0][0]

graph.plot()
px.imshow(W_E,color_continuous_scale='RdBu',title='W_E',width=1500).show()
px.imshow(W_pos,color_continuous_scale='RdBu',title='W_pos',width=1500).show()
px.imshow(W_Q,color_continuous_scale='RdBu',title='W_Q',width=1500).show()
px.imshow(W_E@W_Q,color_continuous_scale='RdBu',title='W_E @ W_Q',width=1500,zmax=0.5,zmin=-0.5).show()
px.imshow(W_pos@W_Q,color_continuous_scale='RdBu',title='W_pos @ W_Q',width=1500,zmax=0.5,zmin=-0.5).show()
px.imshow(resid@W_Q,color_continuous_scale='RdBu',title='q (= resid @ W_Q)',width=1500).show()


From these plots we can observe that, given a vertex in position $i$ with degree $d_i$, $W_Q^{0.2}$ maps the embedding of all $d_i$ adjacent vertices (found in the rows of $W_E$) to pretty much the same vector $\mathbf{v}$ (represented by the verticle stripes in the plot of $W_EW_Q^{0.2}$ above). It also maps the positional embedding of $i$ (found in the rows of $W_{pos}$) to $-2\mathbf{v}$. As described earlier, multiplying by $X$ sums the results in the rows of $W_EW_Q^{0.2}$ meaning that row $i$ of $XW_EW_Q^{0.2}$ contains the vector $d_i\mathbf{v}$. By then summing the results with $W_{pos}W_Q^{0.2}$ we are subtracting $2\mathbf{v}$ meaning the query of head 0.2 has the information $d_i-2$ stored in each vertex position $i$. 

By identifying this direction in the latent space of head 0.2's QK circuit corresponding to the concept of degree, we can read off the degree of each vertex directly from the query $q$ plotted above.  

Given this insight, we now expect the key vectors in positions 3 and 6 of head 0.2 to also point in the direction $\mathbf{v}$, therefore causing vertices with higher degree to attend more strongly to those positions. This means that the $W_{pos}W_K^{0.2}$ should make up the majority of the key $k$ (in line with the earlier decomposition analysis). Sure enough, when we plot $W_EW_K^{0.2}$ and $W_{pos}W_K^{0.2}$, our expectations are met: 

In [None]:
px.imshow(W_K,color_continuous_scale='RdBu',title='W_K',width=1500).show()
px.imshow(W_E@W_K,color_continuous_scale='RdBu',title='W_E @ W_K',width=1500,color_continuous_midpoint=0,zmax=1.25,zmin=-1.25).show()
px.imshow(W_pos@W_K,color_continuous_scale='RdBu',title='W_pos @ W_K',width=1500,color_continuous_midpoint=0,zmax=1.25).show()
px.imshow(resid@W_K,color_continuous_scale='RdBu',title='k (= resid @ W_K)',width=1500).show()

## 2.2 **Head 0.2 -> Head 1.6**

Having reverse engineered the computation behind head 0.2's attention pattern, we now look at how that pattern is used to write information out to the residual stream. If $W_V^h$ is the value matrix for head $h$, then $W_V^{0.2}$ essentially appears to be doing that same thing as $W_K^{0.2}$, except with a different direction corresponding to degree:

In [None]:
W_V = model.W_V[layer,head].detach()
z = cache['z',layer][0,:,head]
result = cache['result',layer][0,:,head]

px.imshow(W_V,color_continuous_scale='RdBu',title='W_V',color_continuous_midpoint=0).show()
px.imshow(W_E@W_V,color_continuous_scale='RdBu',title='W_E @ W_V',color_continuous_midpoint=0).show()
px.imshow(W_pos@W_V,color_continuous_scale='RdBu',title='W_pos @ W_V',color_continuous_midpoint=0).show()
px.imshow(resid@W_V,color_continuous_scale='RdBu',title='v (= resid @ W_V)',color_continuous_midpoint=0).show()
px.imshow(z,color_continuous_scale='RdBu',title='z (attention pattern applied to v)',color_continuous_midpoint=0).show()
px.imshow(result,color_continuous_scale='RdBu',title=f'result (output of head 0.{head})',color_continuous_midpoint=0).show()

From these plots we can see how the value has the attention pattern applied to it (see z plot) and then is written to the residual stream (see result plot). In every plot we can see the patterns caused by every row encoding degree in the same direction in the corresponding latent space. We now have a good understanding of how head 0.2 writes degree information to the residual stream for head 1.6 to read and use for its task.

Between heads 0.2 and 1.6 we have an mlp layer, which makes it hard to track what happens to the output of a specific head. Indeed, because the mlp is non-linear we can't decompose the post mlp residual stream into nice components like we did for the residual stream input into the layer 0 attention heads.
In the general case, just because a layer 0 head outputs a vector in the residual stream pointing in one direction does not mean this vector will be available to be read by a layer 1 head, the mlp can easily mess it up. 

This is a problem for head 1.6. As we have seen above, head 0.2 (among others) has written a result to the residual stream that contains the degree of each vertex as a scaled version of the same vector in each vertex position. In other words, we have an approximate direction in the residual stream that corresponds to the degree of a vertex. For head 1.6's purposes, it would be nice if the mlp did not tamper with this direction at all. Unfotunately, the mlp has lots of other jobs to do and so this nice interpretable direction appears to be lost. To visualise this, we plot two things below:
1. The output obtained when you pass the output of head 0.2 ('result' plotted above) through the mlp. You can think of this as the "theortecial output of the mlp if the output of head 0.2 was the only information the model wanted to pass to the next layer".
2. The full residual stream after the mlp on layer 0 below.

Compare these plots with the 'result' plot above to see the effect the mlp's non-linearity has on one head's output vs when it has to handle a full residual streams worth of information.

In [None]:
mlp_0 = model.blocks[0].mlp

post_mlp = (mlp_0(result)).detach()


px.imshow(post_mlp,color_continuous_scale='RdBu',title=" Theoretical/fake 'Component' of post mlp 0 residual stream orignating from head 0.2").show()
px.imshow(cache['resid_post',0][0],color_continuous_scale='RdBu',title='Post mlp 0 residual stream').show()

Studying those plots does illustrate how even if the mlp is "theoretically" capable of passing a (different, but still clearly interpretable) degree direction to the next layer (shown by the pattern in the first of the two plots above), in general we probably wont have one residual direction from which to "read off" information like a vertex's degree. 

However, in this specific case, I can't ignore the similarity between the entries in the 68th basis direction of the residual stream across all of the last three plots. It might just be a coincidence, but the pattern in that column matches the degree information we want to pass on (low values corresponding to high degree vertices and vice versa). Have a go at switching the input graph at the start of this section and seeing if you can replicate this observation in other cases. 

While this could well be an interesting observation, it only has any relevance to head 1.6 if it actually takes adavntage of this apparent preservation of the degree direction by reading information from the 68th column of the residual stream. Let's plot $W_Q^{1.6}$ to get a better idea of where the query information is read in from (we will plot its transpose for easier visualisation, but remember it has shape (128, 32)): 

In [None]:
W_Q_1_6 = model.W_Q[1,6].detach()

px.imshow(W_Q_1_6.T,color_continuous_scale='RdBu',title='W_Q').show()

This plot makes me pretty confident that what we observed before is not a coincidence and that the non-linear mlp is preserving the degree pattern in a linear way by rotating and scaling it to roughly line up with the 68th basis direction of the residual stream.

We can gather more evidence for this by taking a batch of input graphs and, for each graph, doing the following:
1. Consider the output of head 0.2 and find the residual direction corresponding to degree that this head writes to.
2. For each vertex find the coefficient of this direction vector that scales the vector, encoding the degree of the vertex.
3. Form the 10 dimensional vector of these coefficients. This can be thought of as the "degree signature" of that specific graph and can be clearly visualised by looking at the columns of the output of head 0.2 ('result' plot above).
4. Record the cosine similarity between the degree signature and each column of the post mlp residual stream.

Once completed for each graph, average these results across the whole batch.

These cosine similarities us the extent to which the relative sizes of the vectors in the "degree direction" pre mlp match the relative sizes of the vectors in the "$i^{th}$ basis direction" post mlp (for each $i$ in $[0,...,127]$).

In [130]:
def get_degree_signature_resid_basis_cossim(pre_mlp,post_mlp):
    return einops.einsum(
        pre_mlp/pre_mlp.norm(dim=1).unsqueeze(1),post_mlp/post_mlp.norm(dim=1).unsqueeze(1),
        "batch n_vertices d_model, batch n_vertices d_model -> d_model "
    ).detach()/pre_mlp.shape[0]

post_mlp = batch_cache['resid_pre',1]
head_0_2_output_pre_mlp = batch_cache['result',0][:,:,2]

cossims = get_degree_signature_resid_basis_cossim(head_0_2_output_pre_mlp,post_mlp)

px.line(
    y=cossims,
    labels={"x": "Resid stream basis direction", "y": "Degree signature cosine similarity"},
).show()


In [None]:
px.imshow(post_mlp@W_Q_1_6,color_continuous_scale='RdBu',title='component @ W_Q').show()
px.imshow(cache['resid_pre',1][0]@W_Q_1_6,color_continuous_scale='RdBu',title='q (= full residual stream @ W_Q)').show()


We also observe that $W_Q^{1.6}$ exctracts the degree information to give a query, $q$, exhibiting a simlar structure to the query from head 0.2 (except with a different latent direction for degree). The realtive sizes of the query vectors in each row of $q$ are slightly less correlated with degree than they were in head 0.2, but there is still more than enough information in these vectors to identify vertices of low degree.    

With that, I think we have done enough to say that we have reverse engineered the circuit behind the function of head 1.6, we even followed the input step by step through each layer and identified the directions in each latent space that corresponded to degree. There are of course small contributions to the input to head 1.6 that come from heads other than 0.2 but&mdash;as we saw in our activation patching results&mdash;I don't think they are contributing anything vitally improtant for the behviour of head 1.6 that head 0.2 isn't providing by itself.

The obvious next question is: "what does any of this have to do with the planarity of a graph?". I am not sure at all. My best guess so far is that the model could be leveraging the fact that *vertices of degree less than 3 are not relevant to planarity*. Indeed for vertices of degree 0 or 1, removing them from a graph is never going to change a graph's planarity. For vertices of degree 2, replacing them with a single edge between their adjacent vertices is never going to change a graph’s planarity. Perhaps the purpose of head 1.6 is to help the model ignore this redundant information when making its decision? If true, it could help explain the significance of $W_{pos}W_Q^{0.2}$ representing “-2 in the degree direction”: it could be creating a “cutoff” below which all vertices should be in some sense ignored.

# Unused Stuff

## 1.2 **Decompose input**

Let's test this theory by looking at how earlier components contribute to the query and key in head 1.6.

To do this we take the input into head 1.6 (i.e. the residual stream post layer 0, $x_1$) and think of it as a sum of the outputs of all the previous components 10 components $y_i$ (i.e. the 2 embeddings and 8 layer 0 heads) which is then fed through the mlp on layer 0 $f_0$:
$$
x_1 = f_0\left(\sum_{i = 0}^9 y_i\right) + \sum_{i = 0}^9 y_i = f_0\left(gW_E + W_{pos} + \sum_{h}A^h x_0 W_{OV}^h\right) + gW_E + W_{pos} + \sum_{h}A^h x_0 W_{OV}^h
$$
Decomposing $x_1$ into each of its terms allows us to study the impact each component has on each position in the query $q$ and key $k$ of head 1.6:
$$
q = \left(\sum_{i = 0}^9 y_i\right)W_Q^{1.6} = \sum_{i = 0}^9 y_iW_Q^{1.6} = gW_EW_Q^{1.6} + W_{pos}W_Q^{1.6} + \sum_{h}A^h x_0 W_{OV}^hW_Q^{1.6}
$$


In [None]:
head_index = 6

decomposed_qk_input = decompose_qk_input(model,batch_cache)

actual_resid = batch_cache['resid_post',0][0]
naive_resid = decomposed_qk_input.sum(1)[0].detach()

px.imshow(actual_resid,color_continuous_scale='RdBu',zmax=3.5,zmin=-3.5).show()
px.imshow(naive_resid,color_continuous_scale='RdBu',zmax=4.5,zmin=-4.5).show()
px.imshow(actual_resid-naive_resid,color_continuous_scale='RdBu',zmax=4.5,zmin=-4.5).show()
px.imshow(model.W_Q[1,6].detach(),color_continuous_scale='RdBu').show()

decomposed_q = decompose_q(model, decomposed_qk_input, head_index)
decomposed_k = decompose_k(model, decomposed_qk_input, head_index)

component_labels = ["Embed", "PosEmbed"] + [f"0.{h}" for h in range(model.cfg.n_heads)]
for decomposed_input, name in [(decomposed_q, "query"), (decomposed_k, "key")]:
    px.imshow(
        decomposed_input.pow(2).sum(-1).sqrt().mean(0).detach(),
        labels={"x": "Position", "y": "Component"},
        title=f"Norms of components of {name}",
        y=component_labels,
        color_continuous_scale='Blues',
        width=1000, height=400
    ).show()

If we believe that the Norms of the components of the query and the key of head 1.6 are correlated with the importance of that component in the function of head1.6, then these plots support my theory that any degree information is passed into head 1.6 via Q-composition from components earlier in the model.
Indeed if it was passed via K-composition then we would expect to see components of the key which have high norms across all positions, as low degree vertices can occur at any position.

**Further Observations**: 
- We also see that when creating keys, head 1.6 pretty much only cares about the positional information coming from position 2. This suggests that the concept of "I want to gather low degree vertices"&mdash;that is encoded in the position 2 key&mdash;comes entirely from head 1.6 and is not passed on from an earlier component.
- The query plot suggests that, out of the previous components, Embed, PosEmbed and head 0.2, are primarily responsible for passing on degree infomation to the vertex positions in head 1.6, with heads 0.1 and 0.3 playing a less important role.

These plots are not fully convincing though, as the norm is not obviously a principled indicator of causality from earlier components to later ones. We can get a better idea of which components are causaly important for the function of head 1.6 by taking a single example graph and pairing up individual key and query components and seeing what attention scores they produce:

In [None]:
decomposed_scores = decompose_attn_scores(decomposed_q, decomposed_k)

plot_decomposed_attn_scores(
    decomposed_scores=decomposed_scores,
    cache=batch_cache,
    component_labels=component_labels,
    head_idx=6,
    batch_sample_idx=100
)

This strongly supports my theory: it would be a pretty big coincidence if the output space of head 0.2 happend to be the only layer 0 head to have a strong overlap with the query input space of head 1.6 and head 0.2 wasn't integral to head 1.6's functionality. Similarly, we see that the output space of PosEmbed has a strong overlap with the key input space of head 1.6.

## 1.3 **Composition scores**

In [None]:
head_idx = 2
out = (mlp_0(W_OV[0]) + W_OV[0]).detach()
px.imshow(out[head_idx],color_continuous_scale='RdBu').show()
px.imshow(mlp_0(W_OV[0])[head_idx].detach(),color_continuous_scale='RdBu').show()
px.imshow(W_OV[0,head_idx].detach(),color_continuous_scale='RdBu').show()

In [None]:
# Get all QK and OV matrices
W_QK = model.W_Q @ model.W_K.transpose(-2, -1)
W_OV = model.W_V @ model.W_O

mlp_0 = model.blocks[0].mlp

# note: adding embed and pos_embed to output spaces may not be very principled (it was my idea and I don't know if anyone else does it)
output_space = [
    (mlp_0(model.W_E) + model.W_E).unsqueeze(dim=0),
    (mlp_0(model.W_pos) + model.W_pos).unsqueeze(dim=0),
    *list(mlp_0(W_OV[0]) + W_OV[0])
]

input_space = {
    'Q': W_QK[1],
    'K': W_QK[1].transpose(-2,-1),
    'V': W_OV[1]
}

# Define tensors to hold the composition scores
composition_scores = {
    "Q": torch.zeros(model.cfg.n_heads+2, model.cfg.n_heads).to(device),
    "K": torch.zeros(model.cfg.n_heads+2, model.cfg.n_heads).to(device),
    "V": torch.zeros(model.cfg.n_heads+2, model.cfg.n_heads).to(device),
}

for i,out_sp in enumerate(output_space):
    for j in range(model.cfg.n_heads):
        for comp_type in "QKV":
            composition_scores[comp_type][i, j] = get_comp_score(out_sp, input_space[comp_type][j])

# baseline
n_samples = 300
comp_scores_baseline = np.zeros(n_samples)
for i in tqdm(range(n_samples)):
    comp_scores_baseline[i] = generate_single_random_comp_score(model)

# px.histogram(
#     comp_scores_baseline,
#     nbins=50,
#     width=800,
#     labels={"x": "Composition score"},
#     title="Random composition scores"
# )

for comp_type in "QKV":
    plot_comp_scores(
        model,
        composition_scores[comp_type], 
        title=f"{comp_type} Composition Scores",
        component_labels=component_labels,
        baseline=comp_scores_baseline.mean()
    )

## **OV circuit**

In [59]:
def get_OV_circuit(model,layer,head):
    W_O = model.W_O[layer, head]
    W_V = model.W_V[layer, head]
    W_E = model.W_E
    W_U = model.W_U

    OV_circuit = FactoredMatrix(W_V, W_O)
    full_OV_circuit = W_E @ OV_circuit @ W_U
    return full_OV_circuit, OV_circuit

In [None]:
zmax = 0.4

layers = list(range(model.cfg.n_layers))
heads = list(range(model.cfg.n_heads))
fig = make_subplots(
    rows=2, 
    row_titles=[str(i) for i in layers],
    cols=8,
    column_titles=[str(i) for i in heads]
)

for layer in layers:
    for head_index in heads:
        full_OV_circuit, OV_circuit = get_OV_circuit(model,layer,head_index)
        heatmap = go.Heatmap(z=OV_circuit.AB.detach(),zmin=-zmax,zmax=zmax,colorscale='RdBu')
        fig.add_trace(heatmap, row=layer+1, col=head_index+1)
        fig.update_yaxes(autorange='reversed')

# Update the layout for better visualization
fig.update_layout(height=600, width=2000, showlegend=False)

# Show the plot
fig.show()


In [None]:
px.imshow(
    get_OV_circuit(model,0,2)[1].AB.detach(),
    color_continuous_scale="RdBu"
).show()

## **Visualise neurons**

In [24]:
def plot_acts(model,cache):
    """
    Plots mean neuron activations for all mlp layers in cache.
    """ 
    assert len(cache['embed'].shape) > 2, "No batch dimension found."
    for i in range(model.cfg.n_layers):
        px.imshow(
            cache["mlp_post",i].mean(0), # mean over batch dimension
            y=[f"0.{h}" for h in range(model.cfg.n_vertices)],
            labels={"x": "Neurons", "y": "Vertices"},
            title=f"Layer {i} Neuron Activations",
            color_continuous_scale="Blues"
        ).show()

In [None]:
plot_acts(model,batch_cache)

## **Correlation between neuron activation and vertex degree**

In [123]:
def get_degree_act_correlation(
    input: Float[Tensor,"batch n_vertices n_vertices"],
    acts: Float[Tensor,"batch n_vertices d_mlp"]
):
    degrees = input.sum(dim=-1) # batch n_vertices
    degrees_flattened = einops.rearrange(
        degrees,
        "batch n_vertices -> 1 (batch n_vertices)"
    )

    acts_flattened = einops.rearrange(
        acts,
        "batch n_vertices d_mlp-> d_mlp (batch n_vertices)"
    )

    degree_acts_stack = torch.cat([degrees_flattened,acts_flattened]) # (batch n_vertices) d_mlp+1
    print(degree_acts_stack.shape)
    correlations = torch.corrcoef(degree_acts_stack)[0,1:] # d_mlp

    return correlations

In [124]:
loader = get_dataloader('adj','data/n10_15-21_tn_vn/val.npz',128)
input, _ = next(iter(loader))
_, cache = model.run_with_cache(input)
degree_act_correlations = get_degree_act_correlation(input, cache['resid_pre',1])
px.line(
    y=degree_act_correlations.detach(),
    labels={"x": "Neuron", "y": "Correlation with degree"},
    title=f"Correlation between neuron activation and degree"
).show()

torch.Size([129, 1280])


## **Logit attribution**

In [None]:
component_results, component_captions = batch_cache.get_full_resid_decomposition(return_labels=True, expand_neurons=False)
layer_results = torch.stack([v for k,v in batch_cache.items() if 'out' in k])
layer_captions = [k for k in cache if 'out' in k]

logit_diff_directions = get_logit_diff_directions(model,batch_labels)

component_logit_attribution = get_logit_attribution(component_results,logit_diff_directions)
layer_logit_attribution = get_logit_attribution(layer_results,logit_diff_directions)

plot_logit_attribution(component_logit_attribution,component_captions)
plot_logit_attribution(layer_logit_attribution,layer_captions)

**Investigate 1_mlp_out further**

In [None]:
neuron_results = einops.rearrange(
    batch_cache.get_neuron_results(layer=1),
    "batch n_vertices n_neurons d_model -> n_neurons batch n_vertices d_model"
)

neuron_logit_attribution = get_logit_attribution(neuron_results,logit_diff_directions)

plot_logit_attribution(neuron_logit_attribution,[str(i) for i in range(model.cfg.d_mlp)])

## **Make datasets of specific examples**

In [None]:
dataset_kwargs = {
    'n': model.cfg.n_vertices,
    'size': 1099,
    'start': 15,
    'end': 22
}

planar_graphs = generate_planar(**dataset_kwargs)

non_planar_graphs = generate_planar(**dataset_kwargs, non_planar=True)

In [None]:
misclassified_graphs, correct_labels = generate_misclass(
    model=model,
    start=15,
    end=17,
    size=4
)

In [None]:
for adj in misclassified_graphs:
    g = Graph(adj)
    is_planar,cert = nx.check_planarity(g.G, counterexample=True)
    print(is_planar)
    Graph(cert).plot()

**Study average activations across sets of specific examples**

In [None]:
model.set_use_attn_result(True)
_, planar_cache = model.run_with_cache(planar_graphs)
_, non_planar_cache = model.run_with_cache(non_planar_graphs)
_, misclassified_cache = model.run_with_cache(misclassified_graphs)
model.set_use_attn_result(False)

plot_acts(model,planar_cache)
plot_acts(model,non_planar_cache)
plot_acts(model,misclassified_cache)

**Logit attributions across sets of specific examples**

In [None]:
planar_neuron_results = einops.rearrange(
    planar_cache.get_neuron_results(layer=1),
    "batch n_vertices n_neurons d_model -> n_neurons batch n_vertices d_model"
)

non_planar_neuron_results = einops.rearrange(
    non_planar_cache.get_neuron_results(layer=1),
    "batch n_vertices n_neurons d_model -> n_neurons batch n_vertices d_model"
)

misclassified_neuron_results = einops.rearrange(
    misclassified_cache.get_neuron_results(layer=1),
    "batch n_vertices n_neurons d_model -> n_neurons batch n_vertices d_model"
)

planar_logit_diff_directions = get_logit_diff_directions(model,torch.ones(dataset_kwargs['size'],dtype=torch.int64))
non_planar_logit_diff_directions = get_logit_diff_directions(model,torch.zeros(dataset_kwargs['size'],dtype=torch.int64))
misclassified_logit_diff_directions = get_logit_diff_directions(model,torch.zeros(len(misclassified_graphs),dtype=torch.int64))

planar_neuron_logit_attribution = get_logit_attribution(planar_neuron_results,planar_logit_diff_directions)
non_planar_neuron_logit_attribution = get_logit_attribution(non_planar_neuron_results,non_planar_logit_diff_directions)
misclassified_neuron_logit_attribution = get_logit_attribution(misclassified_neuron_results,misclassified_logit_diff_directions)

plot_logit_attribution(planar_neuron_logit_attribution,[str(i) for i in range(model.cfg.d_mlp)])
plot_logit_attribution(non_planar_neuron_logit_attribution,[str(i) for i in range(model.cfg.d_mlp)])
plot_logit_attribution(misclassified_neuron_logit_attribution,[str(i) for i in range(model.cfg.d_mlp)])