# **Setup**

In [1]:
from utils import get_dataloader
from utils.graph import plot_attention, graph_to_adj, adj_to_graph, plot_graph
from utils.dataset.build import generate_planar, generate_misclass 
from models import get_model, Transformer


from transformer_lens import ActivationCache, utils, FactoredMatrix
from transformer_lens.hook_points import HookPoint 
import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import networkx as nx
import torch
from torch import Tensor
import einops
import wandb
from jaxtyping import Float, Int
import numpy as np

import yaml
from tqdm.auto import tqdm
from typing import Optional, Callable, Union, Sequence, List
from functools import partial

# **Load Model**

In [None]:
# Checkpoint stem: `<path>.yaml` is passed to wandb as the config and `<path>.pt`
# is loaded below as the state dict.
path = 'trained_models/transformer_2024-09-05_13-28-07'

# wandb is used here only to read the YAML config into `wandb.config`;
# mode='disabled' is presumably meant to avoid creating/syncing a run.
wandb.init(
    config=path+'.yaml',
    mode='disabled'
)

args = wandb.config

model = get_model(args)
model.eval()
device = model.cfg.device
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust
# (consider weights_only=True if the installed torch version supports it).
state_dict = torch.load(path+'.pt',map_location=device)
model.load_state_dict(state_dict)


# **Test model on val set and random graphs**

In [None]:
# --- Validation set (batch size 1, so `argmax()` over the logits is the prediction) ---
loader = get_dataloader('adj','data/n10_15-21_tn_vn/val.npz',1)
n_planar = n_correct = 0
# Evaluation only: disable autograd so no graph is built for each forward pass.
with torch.no_grad():
    for x,y in loader:
        logits = model(x)
        correct = logits.argmax()==y
        # Count a graph as "planar and correct" only when the label is 1 and the model was right.
        n_planar += y.item()*correct
        n_correct += correct

print(f"Val Set Accuracy: {(100*n_correct/len(loader)).item():.2f}%")
print(f"{(100*n_planar/n_correct).item():.2f}% of correctly classified graphs were planar.")
print()

# --- Random G(n, m) graphs, one accuracy figure per edge count m ---
n_samples = 1000
for m in range(11,23):
    n_planar = n_correct = 0
    with torch.no_grad():
        for sample_idx in range(n_samples):
            # A distinct integer seed per (m, sample) keeps the cell reproducible on re-run
            # while still drawing a different graph every iteration.
            g = nx.gnm_random_graph(args.n_vertices,m,seed=m*n_samples+sample_idx)
            y = nx.is_planar(g)
            # NOTE(review): unlike the loader batches this tensor has no explicit batch
            # axis and is not moved to `device` — confirm the model accepts it as-is.
            x = torch.tensor(nx.adjacency_matrix(g).toarray(),dtype=torch.float)
            logits = model(x)
            correct = logits.argmax()==y
            n_planar += y*correct
            n_correct += correct

    print(f"Accuracy on {n_samples} random graphs with {m} edges: {(100*n_correct/n_samples).item():.2f}%")
    print(f"{(100*n_planar/n_correct).item():.2f}% of correctly classified graphs were planar.")

# **Cache activations from example graphs**

In [3]:
m = 19  # number of edges in the random example graph

# Hand-picked example graphs, converted to the model's adjacency input format.
# NOTE(review): nx.petersen_graph() always has 10 vertices, so this presumably
# relies on model.cfg.n_vertices == 10 — confirm.
petersen_graph = graph_to_adj(nx.petersen_graph())
cycle_graph = graph_to_adj(nx.cycle_graph(model.cfg.n_vertices))
complete_graph = graph_to_adj(nx.complete_graph(model.cfg.n_vertices))
empty_graph = graph_to_adj(nx.empty_graph(model.cfg.n_vertices))
star_graph = graph_to_adj(nx.star_graph(model.cfg.n_vertices-1))
# Fixed seed: the analysis below refers to this specific graph, so it must be
# the same graph on every Restart & Run All.
random_graph = graph_to_adj(nx.gnm_random_graph(model.cfg.n_vertices,m,seed=0))

loader = get_dataloader('adj','data/n10_15-21_tn_vn/val.npz',128)
batch,batch_labels = next(iter(loader))

# Per-head results are only cached while this flag is on; restore it even if a
# forward pass fails so the model is not left in the (more expensive) mode.
model.set_use_attn_result(True)
try:
    _, petersen_cache = model.run_with_cache(petersen_graph)
    _, cycle_cache = model.run_with_cache(cycle_graph)
    _, complete_cache = model.run_with_cache(complete_graph)
    _, empty_cache = model.run_with_cache(empty_graph)
    _, star_cache = model.run_with_cache(star_graph)
    _, random_cache = model.run_with_cache(random_graph)
    _, batch_cache = model.run_with_cache(batch)
finally:
    model.set_use_attn_result(False)

# **Plot attention patterns**

In [None]:
# Attention patterns cached under ("pattern", 1) for the random example graph,
# drawn alongside the graph itself.
plot_attention(adj_to_graph(random_graph),random_cache["pattern",1])

**Observation**: In Head 1.6, vertices of low degree attend strongly to vertex 2. Let's try and work out how each vertex knows whether or not it has low degree. This can easily be read off from the input, but we don't know that's what the model is actually doing.

My first guess as to how the attention pattern in head 1.6 comes about is this:
- Each vertex of low degree has the concept of "I am a vertex of low degree" encoded in its residual stream position. 
- The query matrix for head 1.6 reads this information from the residual stream and produces queries that encode the concept "I am looking for the special vertex position" in low degree vertex positions.
- The key matrix for head 1.6 produces a key that says "I am in the special vertex position" in position 2. I have no idea why; I guess position 2 might contain useful information that the model wants low degree vertices to have?

If my guess is true, then any degree information would be passed into head 1.6 via Q-composition from components earlier in the model. Let's take a look at how earlier components contribute to the query and key in head 1.6:

In [51]:
def mlp(input: Float[Tensor,"batch n_vertices d_model"],model,layer):
    """Run `input` through the MLP sub-module of `model.blocks[layer]` and return its output."""
    block = model.blocks[layer]
    return block.mlp(input)

# Layer-0 MLP bound to the notebook's global `model`.
mlp_0 = partial(mlp,model=model,layer=0)

def decompose_qk_input(cache: ActivationCache) -> Float[Tensor, "batch n_components n_vertices d_model"]:
    """
    Stack the layer-0 contributions (embed, pos_embed, then each layer-0 head result)
    along a new component axis and push each one through the layer-0 MLP with a
    skip connection.

    Component order along dim 1: embed, pos_embed, head 0, head 1, ...

    NOTE(review): the MLP is applied to every component separately; if it is
    non-linear these terms need not sum to the MLP of the full residual stream.
    """
    embed = cache["embed"][:, None]                    # (batch 1 n_vertices d_model)
    pos_embed = cache["pos_embed"][:, None]            # (batch 1 n_vertices d_model)
    head_results = cache["result", 0].transpose(1, 2)  # (batch n_heads n_vertices d_model)
    components = torch.cat((embed, pos_embed, head_results), dim=1)

    return mlp_0(components) + components


def decompose_q(
    decomposed_qk_input: Float[Tensor, "batch n_components n_vertices d_model"],
    head_index: int,
    model: Transformer,
    layer: int = 1,
) -> Float[Tensor, "batch n_components n_vertices d_head"]:
    """
    Project every residual-stream component onto the query space of one head.

    Args:
        decomposed_qk_input: per-component residual vectors (last axis is d_model).
        head_index: head whose `W_Q` is used.
        model: model exposing `W_Q` indexed as `[layer, head]`.
        layer: layer of the head; defaults to 1, the layer studied in this notebook.

    Returns:
        The per-component query contributions (last axis is d_head). No query bias
        is added.
    """
    W_Q = model.W_Q[layer, head_index]  # (d_model d_head)

    # Matmul broadcasts over all leading axes, contracting d_model.
    return decomposed_qk_input @ W_Q


def decompose_k(
    decomposed_qk_input: Float[Tensor, "batch n_components n_vertices d_model"],
    head_index: int,
    model: Transformer,
    layer: int = 1,
) -> Float[Tensor, "batch n_components n_vertices d_head"]:
    """
    Project every residual-stream component onto the key space of one head.

    Same contract as `decompose_q`, but using `model.W_K[layer, head_index]`.
    No key bias is added.
    """
    W_K = model.W_K[layer, head_index]  # (d_model d_head)

    return decomposed_qk_input @ W_K

def decompose_attn_scores(
    decomposed_q: Float[Tensor, "batch n_components n_vertices d_head"],
    decomposed_k: Float[Tensor, "batch n_components n_vertices d_head"]
) -> Float[Tensor, "batch query_component key_component query_pos key_pos"]:
    """
    Dot every query component against every key component at every pair of positions.

    Entry [b, i, j, q, k] is the contribution to the (unscaled) attention score at
    (query_pos=q, key_pos=k) from query component i and key component j.
    """
    # The axis called `d_model` in the pattern is the shared last (d_head) axis
    # that gets contracted.
    pairing = "batch q_comp q_pos d_model, batch k_comp k_pos d_model -> batch q_comp k_comp q_pos k_pos"
    pairwise_scores = einops.einsum(decomposed_q, decomposed_k, pairing)
    return pairwise_scores

In [None]:
# Layer-1 head under study (head 1.6); reused by the cells below.
head_index = 6

# First we get the decomposed q and k inputs for the cached batch.
# NOTE(review): no check is actually performed here, despite what the original
# comment said.
decomposed_qk_input = decompose_qk_input(batch_cache)

decomposed_q = decompose_q(decomposed_qk_input, head_index, model)
decomposed_k = decompose_k(decomposed_qk_input, head_index, model)


# Second, we plot our results: one heatmap each for the query and key, showing the
# L2 norm over the last axis of every component at every position, averaged over the batch.
# Label order must match the component order produced by decompose_qk_input.
component_labels = ["Embed", "PosEmbed"] + [f"0.{h}" for h in range(model.cfg.n_heads)]
for decomposed_input, name in [(decomposed_q, "query"), (decomposed_k, "key")]:
    px.imshow(
        decomposed_input.pow(2).sum(-1).sqrt().mean(0).detach(),
        labels={"x": "Position", "y": "Component"},
        title=f"Norms of components of {name}",
        y=component_labels,
        color_continuous_scale='Blues',
        width=1000, height=400
    ).show()

If we believe that the norms of the components of the query and the key of head 1.6 are correlated with the importance of those components in the function of head 1.6, then these plots support my theory that any degree information is passed into head 1.6 via Q-composition from components earlier in the model.
Indeed if it was passed via K-composition then we would expect to see components of the key which have high norms across all positions, as low degree vertices can occur at any position.

**Further Observations**: 
- We also see that when creating keys, head 1.6 pretty much only cares about the positional information coming from position 2. This suggests that the concept of "I want to gather low degree vertices"&mdash;that is encoded in the position 2 key&mdash;comes entirely from head 1.6 and is not passed on from an earlier component.
- The query plot suggests that, out of the previous components, Embed, PosEmbed and head 0.2 are primarily responsible for passing on degree information to the vertex positions in head 1.6, with heads 0.1 and 0.3 playing a less important role.

These plots are not fully convincing though, as the norm is not obviously a principled indicator of causality from earlier components to later ones. We can get a better idea of which components are causally important for the function of head 1.6 by taking a single example graph, pairing up individual key and query components, and seeing what attention scores they produce:

In [None]:
decomposed_scores = decompose_attn_scores(decomposed_q, decomposed_k)
# Spread of each (query component, key component) score map over all position pairs.
decomposed_stds = einops.reduce(
    decomposed_scores,
    "batch query_decomp key_decomp query_pos key_pos -> batch query_decomp key_decomp",
    torch.std
)

sample = 1 # Indexing a sample from `batch`
zmax = 80 # Colorscale max

# Reference pattern for the same head the decomposition was computed for
# (`head_index`), so the two plots cannot drift apart if the head is changed.
px.imshow(
    batch_cache['pattern',1][sample,head_index].detach(),
    color_continuous_scale='Blues',
    title="Original Sample Attention Pattern"
).show()

n_components = len(component_labels)
fig = make_subplots(
    rows=n_components,
    row_titles=component_labels,
    y_title='Query Components',
    cols=n_components,
    column_titles=component_labels,
    x_title='Key Components'
)

# One heatmap per (query component i, key component j) pair, on a shared symmetric scale.
for i in range(n_components):
    for j in range(n_components):

        heatmap = go.Heatmap(
            z=decomposed_scores[sample, i, j].detach(),
            colorscale='RdBu',
            zmax=zmax,
            zmin=-zmax
        )
        fig.add_trace(heatmap, row=i+1, col=j+1)

# Applies to every subplot axis, so one call after the grid is filled is enough.
fig.update_yaxes(autorange='reversed')

fig.update_layout(height=2000, width=2000)

fig.show()

# std dev over query and key positions, shown by component. Mean over whole batch
px.imshow(
    decomposed_stds.mean(0).detach(),
    labels={"x": "Key Component", "y": "Query Component"},
    title="Standard deviations of attention score contributions (by key and query component)",
    x=component_labels,
    y=component_labels,
    color_continuous_scale='Blues',
    width=800
).show()

These plots make me pretty confident that the QK circuit from head 0.2 to PosEmbed via head 1.6 is responsible for moving information from position 2 to low degree vertices. Specifically, head 0.2 Q-composes with head 1.6 and PosEmbed K-composes with head 1.6.

While I find this all convincing, one problem with the above plots is that they all rely on activations from a single graph input or a small batch input into the model. It would be great if we could show this relationship between components just by looking at the model weights. We do this by studying the *composition scores* of each pair of components:

# **Composition Scores**

In [65]:
def get_comp_score(
    W_A: Float[Tensor, "in_A out_A"],
    W_B: Float[Tensor, "out_A out_B"]
) -> float:
    """
    Composition score of two matrices: ||W_A @ W_B|| / (||W_A|| * ||W_B||),
    where ||.|| is the square root of the sum of squared entries.
    """
    def frobenius(W: Tensor) -> Tensor:
        # Sum runs over every element, whatever the number of dimensions.
        return W.pow(2).sum().sqrt()

    product_norm = frobenius(W_A @ W_B)
    norm_product = frobenius(W_A) * frobenius(W_B)

    return (product_norm / norm_product).item()

def plot_comp_scores(
    model,
    comp_scores,
    component_labels=None,
    title: str = "",
    baseline: Optional[Union[float, Tensor]] = None
) -> None:
    """
    Show a heatmap of composition scores (rows: earlier components, columns: layer-1 heads).

    Args:
        model: used for `model.cfg.n_heads` (column labels, and default row labels).
        comp_scores: 2-D scores, one row per entry of `component_labels`.
        component_labels: row labels. Defaults to "Embed", "PosEmbed" followed by
            one "0.h" label per head, built from the model at call time rather
            than captured from a notebook global at definition time.
        title: figure title.
        baseline: if given, a diverging colour scale centred on this value is used;
            otherwise a sequential scale starting at 0.

    The figure is displayed; nothing is returned.
    """
    if component_labels is None:
        component_labels = ["Embed", "PosEmbed"] + [f"0.{h}" for h in range(model.cfg.n_heads)]

    # plotly converts its input to a numpy array, which fails for tensors that
    # require grad or live on an accelerator.
    if isinstance(comp_scores, torch.Tensor):
        comp_scores = comp_scores.detach().cpu()

    px.imshow(
        comp_scores,
        y=component_labels,
        x=[f"1.{h}" for h in range(model.cfg.n_heads)],
        labels={"x": "Layer 1", "y": "Layer 0"},
        title=title,
        color_continuous_scale="RdBu" if baseline is not None else "Blues",
        color_continuous_midpoint=baseline if baseline is not None else None,
        zmin=None if baseline is not None else 0.0,
    ).show()

def generate_single_random_comp_score() -> float:
    '''
    Composition score of two random low-rank matrices.

    Each matrix is the product of two freshly initialised (d_model, d_head)
    factors, with sizes read from the notebook's global `model`.
    '''
    factor_shape = (model.cfg.d_model, model.cfg.d_head)
    W_A_left, W_B_left, W_A_right, W_B_right = (
        torch.empty(*factor_shape) for _ in range(4)
    )

    # Initialisation order is kept fixed so the random stream is consumed the same way.
    for factor in (W_A_left, W_B_left, W_A_right, W_B_right):
        torch.nn.init.kaiming_uniform_(factor, a=np.sqrt(5))

    return get_comp_score(W_A_left @ W_A_right.T, W_B_left @ W_B_right.T)

In [None]:
# Get all QK and OV matrices (one per [layer, head])
W_QK = model.W_Q @ model.W_K.transpose(-2, -1)
W_OV = model.W_V @ model.W_O

# Output spaces of the earlier components, in the same order as `component_labels`
# (Embed, PosEmbed, then each layer-0 head), each passed through the layer-0 MLP
# with a skip connection.
# note: adding embed and pos_embed to output spaces may not be very principled (it was my idea and I don't know if anyone else does it)
# The embed/pos-embed entries carry an extra leading axis of size 1; get_comp_score
# sums over all elements, so the mixed ranks are tolerated.
output_space = [
    (mlp_0(model.W_E) + model.W_E).unsqueeze(dim=0),
    (mlp_0(model.W_pos) + model.W_pos).unsqueeze(dim=0),
    *list(mlp_0(W_OV[0]) + W_OV[0])
]

# Input spaces of the layer-1 heads; K uses the transposed QK matrix.
input_space = {
    'Q': W_QK[1],
    'K': W_QK[1].transpose(-2,-1),
    'V': W_OV[1]
}

# Define tensors to hold the composition scores
# (rows: Embed, PosEmbed, layer-0 heads; columns: layer-1 heads)
composition_scores = {
    "Q": torch.zeros(model.cfg.n_heads+2, model.cfg.n_heads).to(device),
    "K": torch.zeros(model.cfg.n_heads+2, model.cfg.n_heads).to(device),
    "V": torch.zeros(model.cfg.n_heads+2, model.cfg.n_heads).to(device),
}

for i,out_sp in enumerate(output_space):
    for j in range(model.cfg.n_heads):
        for comp_type in "QKV":
            composition_scores[comp_type][i, j] = get_comp_score(out_sp, input_space[comp_type][j])

# baseline: mean composition score of random matrices, used as the colour midpoint below
n_samples = 300
comp_scores_baseline = np.zeros(n_samples)
for i in tqdm(range(n_samples)):
    comp_scores_baseline[i] = generate_single_random_comp_score()

# Optional sanity plot of the baseline distribution (left disabled).
# px.histogram(
#     comp_scores_baseline,
#     nbins=50,
#     width=800,
#     labels={"x": "Composition score"},
#     title="Random composition scores"
# )

for comp_type in "QKV":
    plot_comp_scores(
        model,
        composition_scores[comp_type], 
        title=f"{comp_type} Composition Scores",
        baseline=comp_scores_baseline.mean()
    )

This strongly supports my theory: it would be a pretty big coincidence if head 0.2 happened to be the only layer 0 head whose output space has a strong overlap with the query input space of head 1.6, and yet head 0.2 wasn't integral to head 1.6's functionality. Similarly, we see that the output space of PosEmbed has a strong overlap with the key input space of head 1.6.

If you still aren't convinced then let's take a "corrupted" graph with (in some sense) no low degree vertices as model input and then patch in the cached head 0.2 activations for a "clean" graph with some low degree vertices and see to what extent we can reproduce the "clean" behaviour of head 1.6:

# **Activation Patching**

In [160]:
def pattern_metric(
    patched_cache: ActivationCache,
    clean_cache: ActivationCache,
) -> Float[Tensor, ""]:
    """
    Distance between the patched and clean attention patterns of head 1.6.

    Only the first example (index 0) and head 6 of 'blocks.1.attn.hook_pattern'
    are compared; the result is the norm of their difference (0 means the
    patched run reproduces the clean pattern exactly).
    """
    hook_name = 'blocks.1.attn.hook_pattern'
    pattern_gap = patched_cache[hook_name][0,6] - clean_cache[hook_name][0,6]
    return torch.norm(pattern_gap)

def patch_head_outputs(
    corrupted_head_outputs: Float[Tensor, "batch n_vertices n_heads d_head"],
    hook: HookPoint,
    head_ids: List[int],
    clean_cache: ActivationCache
) -> Float[Tensor, "batch n_vertices n_heads d_head"]:
    '''
    Hook function: overwrite the selected heads' outputs (before they are added to
    the residual stream) at every vertex position with the values stored in the
    clean cache under `hook.name`.

    The input tensor is modified in place and also returned.
    '''
    clean_head_outputs = clean_cache[hook.name]
    corrupted_head_outputs[:, :, head_ids] = clean_head_outputs[:, :, head_ids]

    return corrupted_head_outputs

def run_with_patched_heads(
    model: Transformer,
    corrupted_input: Float[Tensor, "batch n_vertices n_vertices"],
    clean_cache: ActivationCache,
    head_ids: List[int]
):
    """
    Run the model on `corrupted_input` while replacing the layer-0 outputs of the
    heads in `head_ids` (at 'blocks.0.attn.hook_z') with their clean-cache values.

    The model output is discarded and nothing is returned; callers read the
    effect through hooks they have attached themselves.
    """
    patch_clean_heads = partial(patch_head_outputs, head_ids=head_ids, clean_cache=clean_cache)
    patching_hooks = [('blocks.0.attn.hook_z', patch_clean_heads)]
    model.run_with_hooks(corrupted_input, fwd_hooks=patching_hooks)

def rank_heads_by_patching_metric(
    model: Transformer,
    corrupted_input: Float[Tensor, "batch n_vertices n_vertices"],
    clean_cache: ActivationCache,
    patching_metric: Callable,
    patching_names: Optional[Union[Callable[[str], bool], Sequence[str], str]] 
) -> None:
    '''
    Greedily patches an increasing set of layer-0 heads, each time adding the head that
    gives the lowest `patching_metric`. Plots the heads in the order they were added
    against the metric achieved by that head together with all previously added heads.

    The `patching_metric` function is called as `patching_metric(patched_cache, clean_cache)`
    and must return a scalar (lower is better). The patched cache is filled by caching
    hooks registered for the activations named by `patching_names`.

    All hooks on the model are reset at the start and again at the end, so the caching
    hooks do not leak into later cells. Nothing is returned; the result is the plot.
    '''
    model.reset_hooks()

    # NOTE(review): this relies on the caching hooks surviving each run_with_hooks
    # call made by run_with_patched_heads — confirm for the installed library version.
    patched_cache = model.add_caching_hooks(patching_names)

    importance_ranking = []
    metric_list = []

    remaining_heads = list(range(model.cfg.n_heads))

    try:
        while remaining_heads:

            # +inf sentinel: any finite metric beats it, however large. The fallback
            # head keeps the loop well-defined even if every metric is NaN.
            best_metric = float('inf')
            best_head = remaining_heads[0]

            for head in remaining_heads:
                ids = importance_ranking + [head]

                run_with_patched_heads(model,corrupted_input,clean_cache,ids)

                # Plain float so the plotted values do not hold on to tensors.
                metric = float(patching_metric(patched_cache,clean_cache))
                if metric < best_metric:
                    best_metric = metric
                    best_head = head

            metric_list.append(best_metric)
            remaining_heads.remove(best_head)
            importance_ranking.append(best_head)
    finally:
        model.reset_hooks()

    px.line(x=[str(i) for i in importance_ranking], y=metric_list, labels={'x': 'Head', 'y': 'Pattern Reconstruction Metric'}).show()

In [None]:
# Greedy head ranking: the Petersen graph is the corrupted input, the cached
# random-graph run is the clean reference, and the layer-1 attention pattern is
# the activation cached for the metric.
rank_heads_by_patching_metric(model,petersen_graph,random_cache,pattern_metric,'blocks.1.attn.hook_pattern')

In [None]:
# Reference patterns for head 1.6: the clean (random graph) and corrupted (Petersen graph) runs.
px.imshow(random_cache["pattern",1][0,6],color_continuous_scale='Blues',title="Clean Pattern").show()
px.imshow(petersen_cache["pattern",1][0,6],color_continuous_scale='Blues',title="Corrupted Pattern",zmax=1).show()

model.reset_hooks()
# The returned cache is refreshed on every forward pass while the caching hook is attached.
patched_cache = model.add_caching_hooks('blocks.1.attn.hook_pattern')

# Sets of layer-0 heads whose clean outputs are patched into the corrupted run.
heads_to_patch = [[2],[1,2],[1,2,3],[0],[4],[5],[6],[7],[0,4,5,6,7]]

try:
    for heads in heads_to_patch:
        run_with_patched_heads(model,petersen_graph,random_cache,heads)

        head_string = ' '.join(f'0.{head}' for head in heads)

        px.imshow(
            patched_cache['blocks.1.attn.hook_pattern'][0,6],
            color_continuous_scale='Blues',
            title=f"Pattern with {head_string} patched",
            zmax=1
        ).show()
finally:
    # Detach the caching hook so later forward passes in the notebook do not keep
    # writing into `patched_cache`.
    model.reset_hooks()

# **OV circuit**

In [133]:
def get_OV_circuit(model,layer,head):
    """
    Build the OV circuit of one head.

    Returns a pair `(full_OV_circuit, OV_circuit)` where `OV_circuit` is the
    factored product W_V @ W_O for `[layer, head]`, and `full_OV_circuit` is
    W_E @ OV_circuit @ W_U.
    """
    OV_circuit = FactoredMatrix(model.W_V[layer, head], model.W_O[layer, head])
    return model.W_E @ OV_circuit @ model.W_U, OV_circuit

In [None]:
zmax = 0.4  # symmetric colour range shared by every subplot

layers = list(range(model.cfg.n_layers))
heads = list(range(model.cfg.n_heads))
# Grid size follows the model config (one row per layer, one column per head)
# instead of assuming a 2-layer, 8-head model.
fig = make_subplots(
    rows=len(layers),
    row_titles=[str(i) for i in layers],
    cols=len(heads),
    column_titles=[str(i) for i in heads]
)

for layer in layers:
    # `ov_head` rather than `head_index`, so the global `head_index` chosen
    # earlier in the notebook is not silently overwritten by this loop.
    for ov_head in heads:
        full_OV_circuit, OV_circuit = get_OV_circuit(model,layer,ov_head)
        heatmap = go.Heatmap(z=OV_circuit.AB.detach(),zmin=-zmax,zmax=zmax,colorscale='RdBu')
        fig.add_trace(heatmap, row=layer+1, col=ov_head+1)

# Applies to every subplot axis, so one call after the grid is filled is enough.
fig.update_yaxes(autorange='reversed')

# Update the layout for better visualization
fig.update_layout(height=600, width=2000, showlegend=False)

# Show the plot
fig.show()


In [None]:
# Close-up of the layer 0, head 2 OV matrix (second element returned by
# get_OV_circuit), with plotly's automatic colour range.
px.imshow(
    get_OV_circuit(model,0,2)[1].AB.detach(),
    color_continuous_scale="RdBu"
).show()

# **Visualise neurons**

In [5]:
def plot_acts(model,cache):
    """
    Plots mean neuron activations for all mlp layers in cache.

    For every layer, `cache["mlp_post", layer]` is averaged over its first (batch)
    axis and shown as a vertices-by-neurons heatmap.

    Raises:
        AssertionError: if `cache['embed']` has no batch dimension.
    """
    assert cache['embed'].ndim > 2, "No batch dimension found."
    # Rows are vertices, so label them as such ("0.h" is this notebook's head notation).
    vertex_labels = [f"v{v}" for v in range(model.cfg.n_vertices)]
    for i in range(model.cfg.n_layers):
        # mean over batch dimension; detach/cpu so plotly can convert the tensor
        mean_acts = cache["mlp_post",i].mean(0).detach().cpu()
        px.imshow(
            mean_acts,
            y=vertex_labels,
            labels={"x": "Neurons", "y": "Vertices"},
            title=f"Layer {i} Neuron Activations",
            color_continuous_scale="Blues"
        ).show()

In [None]:
# Neuron activations for the single random example graph.
# NOTE(review): plot_acts averages over the batch axis, so this presumably relies
# on graph_to_adj adding a batch axis of size 1 — confirm.
plot_acts(model,random_cache)

# **Correlation between neuron activation and vertex degree**

In [101]:
def get_degree_act_correlation(
        input: Float[Tensor,"batch n_vertices n_vertices"],
        acts: Float[Tensor,"batch n_vertices d_mlp"]
):
    """
    Pearson correlation between vertex degree and each neuron's activation.

    Every (graph, vertex) pair is one observation: the degree is the row sum of the
    adjacency matrix and the activation is read at that vertex position.

    Args:
        input: adjacency matrices.
        acts: neuron activations at every vertex position.

    Returns:
        Tensor of shape (d_mlp,), one correlation coefficient per neuron. A neuron
        whose activation is constant over all observations yields NaN.
    """
    degrees = input.sum(dim=-1) # batch n_vertices
    degree_row = degrees.reshape(1, -1) # 1 (batch n_vertices)

    d_mlp = acts.shape[-1]
    # Same (batch, vertex) flattening order as the degrees, then one row per neuron.
    neuron_rows = acts.reshape(-1, d_mlp).T # d_mlp (batch n_vertices)

    # Row 0 holds the degrees and rows 1.. hold the neurons, so row 0 of the
    # correlation matrix (minus its first entry) is what we want.
    degree_acts_stack = torch.cat([degree_row, neuron_rows]) # d_mlp+1 (batch n_vertices)
    correlations = torch.corrcoef(degree_acts_stack)[0,1:] # d_mlp

    return correlations

In [None]:
loader = get_dataloader('adj','data/n10_15-21_tn_vn/val.npz',128)
# Named `degree_batch` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the notebook session.
degree_batch, _ = next(iter(loader))
_, cache = model.run_with_cache(degree_batch)
# Correlate vertex degree with the layer-0 post-activation neurons.
degree_act_correlations = get_degree_act_correlation(degree_batch, cache['mlp_post',0])
px.line(
    y=degree_act_correlations.detach(),
    labels={"x": "Neuron", "y": "Correlation with degree"},
    title="Correlation between neuron activation and degree"
).show()

# **Logit attribution**

In [13]:
def get_logit_diff_directions(
    model,
    labels: Int[Tensor,"batch"]
) -> Float[Tensor,"d_model batch"]:
    """
    Unembedding direction of the correct class minus that of the other class,
    for every example of a two-class problem.

    Args:
        model: exposes `W_U` with one column per class (two classes).
        labels: class of each example, as 0/1 integers or booleans.

    Returns:
        Tensor with one column per example: `W_U[:, label] - W_U[:, 1 - label]`.
    """
    W_U = model.W_U
    # Explicit integer indices: a bool tensor would be treated as a mask, and
    # `~labels` on integers only hits the other column through negative-index
    # wraparound.
    correct = labels.to(device=W_U.device, dtype=torch.long)
    incorrect = 1 - correct
    return W_U[:, correct] - W_U[:, incorrect]

def get_logit_attribution(
    component_results: Float[Tensor,"... batch n_vertices d_model"],
    logit_diff_directions: Float[Tensor,"d_model batch"] 
) -> Float[Tensor,"..."]:
    """
    Batch-averaged projection of each component's output onto the logit-difference
    directions.

    The component outputs are first averaged over the vertex axis, then dotted
    with each example's direction, and the result is averaged over the batch.
    Leading axes (e.g. a component axis) are preserved.
    """
    n_examples = logit_diff_directions.size(-1)
    # Mean over the vertex axis: (... batch d_model)
    vertex_mean = component_results.mean(dim=-2)
    summed_over_batch = einops.einsum(
        vertex_mean,
        logit_diff_directions,
        "... batch d_model, d_model batch -> ..."
    )
    return summed_over_batch/n_examples

def plot_logit_attribution(logit_attribution: Float[Tensor,"n_components"],component_captions: list):
    """Show a line plot of one logit-attribution value per component, labelled by `component_captions`."""
    axis_labels = {"x": "Component", "y": "Logit Attribution"}
    attribution_values = logit_attribution.detach()
    figure = px.line(
        x=component_captions,
        y=attribution_values,
        labels=axis_labels,
        title="Component Logit Attributions"
    )
    figure.show()

In [None]:
# Per-component decomposition of the residual stream for the cached batch.
component_results, component_captions = batch_cache.get_full_resid_decomposition(return_labels=True, expand_neurons=False)
# Layer-level outputs: every cached activation whose name contains 'out'.
# Results and captions are both taken from `batch_cache`, so they are guaranteed
# to line up (the captions previously came from an unrelated `cache` variable).
layer_results = torch.stack([v for k,v in batch_cache.items() if 'out' in k])
layer_captions = [k for k in batch_cache.keys() if 'out' in k]

logit_diff_directions = get_logit_diff_directions(model,batch_labels)

component_logit_attribution = get_logit_attribution(component_results,logit_diff_directions)
layer_logit_attribution = get_logit_attribution(layer_results,logit_diff_directions)

plot_logit_attribution(component_logit_attribution,component_captions)
plot_logit_attribution(layer_logit_attribution,layer_captions)

**Investigate 1_mlp_out further**

In [None]:
# Layer-1 per-neuron outputs, moved from (batch n_vertices n_neurons d_model)
# to (n_neurons batch n_vertices d_model) so neurons become the leading component axis.
neuron_results = batch_cache.get_neuron_results(layer=1).permute(2, 0, 1, 3)

neuron_logit_attribution = get_logit_attribution(neuron_results,logit_diff_directions)

plot_logit_attribution(
    neuron_logit_attribution,
    [str(i) for i in range(model.cfg.d_mlp)]
)

# **Make datasets of specific examples**

In [None]:
# Arguments shared by both generated datasets.
# NOTE(review): `start`/`end` presumably bound the edge count — confirm against generate_planar.
dataset_kwargs = dict(
    n=model.cfg.n_vertices,
    size=1099,
    start=15,
    end=22,
)

# Same settings for both sets; only the `non_planar` flag differs.
planar_graphs = generate_planar(**dataset_kwargs)
non_planar_graphs = generate_planar(**dataset_kwargs, non_planar=True)

In [None]:
# Collect graphs the model gets wrong, together with their true labels.
# NOTE(review): `start`/`end`/`size` presumably mean edge-count range and number
# of graphs, as for generate_planar — confirm against utils.dataset.build.
misclassified_graphs, correct_labels = generate_misclass(
    model=model,
    start=15,
    end=17,
    size=4
)

In [None]:
# For each misclassified graph, print whether it is planar and plot the
# certificate networkx returns: with counterexample=True this is a planar
# embedding for planar graphs and a Kuratowski subgraph otherwise.
for adj in misclassified_graphs:
    g = adj_to_graph(adj)
    is_planar,cert = nx.check_planarity(g, counterexample=True)
    print(is_planar)
    plot_graph(cert)

**Study average activations across sets of specific examples**

In [None]:
# Cache activations for each example set, with per-head results enabled for the
# duration of the forward passes and switched off again afterwards.
model.set_use_attn_result(True)
_, planar_cache = model.run_with_cache(planar_graphs)
_, non_planar_cache = model.run_with_cache(non_planar_graphs)
_, misclassified_cache = model.run_with_cache(misclassified_graphs)
model.set_use_attn_result(False)

# Mean neuron activations per vertex position, one set of plots per example set.
plot_acts(model,planar_cache)
plot_acts(model,non_planar_cache)
plot_acts(model,misclassified_cache)

**Logit attributions across sets of specific examples**

In [None]:
def _neurons_first(cache):
    """Layer-1 per-neuron outputs from `cache`, with the neuron axis moved to the front."""
    return einops.rearrange(
        cache.get_neuron_results(layer=1),
        "batch n_vertices n_neurons d_model -> n_neurons batch n_vertices d_model"
    )

planar_neuron_results = _neurons_first(planar_cache)
non_planar_neuron_results = _neurons_first(non_planar_cache)
misclassified_neuron_results = _neurons_first(misclassified_cache)

# Attribution is measured towards each example's correct class: 1 for the planar
# set, 0 for the non-planar set.
planar_logit_diff_directions = get_logit_diff_directions(model,torch.ones(dataset_kwargs['size'],dtype=torch.int64))
non_planar_logit_diff_directions = get_logit_diff_directions(model,torch.zeros(dataset_kwargs['size'],dtype=torch.int64))
# Misclassified graphs can be of either class, so use the true labels returned by
# generate_misclass rather than assuming they are all class 0.
# NOTE(review): assumes `correct_labels` holds one 0/1 label per misclassified graph — confirm.
misclassified_labels = torch.as_tensor(correct_labels).flatten().to(torch.int64)
misclassified_logit_diff_directions = get_logit_diff_directions(model,misclassified_labels)

planar_neuron_logit_attribution = get_logit_attribution(planar_neuron_results,planar_logit_diff_directions)
non_planar_neuron_logit_attribution = get_logit_attribution(non_planar_neuron_results,non_planar_logit_diff_directions)
misclassified_neuron_logit_attribution = get_logit_attribution(misclassified_neuron_results,misclassified_logit_diff_directions)

neuron_captions = [str(i) for i in range(model.cfg.d_mlp)]
plot_logit_attribution(planar_neuron_logit_attribution,neuron_captions)
plot_logit_attribution(non_planar_neuron_logit_attribution,neuron_captions)
plot_logit_attribution(misclassified_neuron_logit_attribution,neuron_captions)

# test theory with patching etc.