In [17]:
# %% Import modules
import torch
# Disable autograd globally for the whole kernel; none of the cells visible here call backward().
torch.set_grad_enabled(False)
device = torch.device('cpu') # GPU auto-select is disabled; it was: 'cuda' if torch.cuda.is_available() else 'cpu'

import plotly.express as px
import pandas as pd
import numpy as np
import einops
import importlib
import sys
from torch import Tensor
from jaxtyping import Float, Int, Bool
from typing import Callable, Optional
from functools import partial

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens.utils import get_act_name, test_prompt
import plotly.graph_objects as go

import seaborn as sns
from matplotlib import pyplot as plt
sns.set() # Setting seaborn as default style even if use only matplotlib

from plotting import (
    single_head_full_resid_projection,
    ntensor_to_long,
    line_with_river
)
from load_data import (
    get_prompts_t,
    get_token_counts,
)
from utils import (
    projection,
    cos_similarity,
    reinforcement_ratio
)



In [2]:
#%% Setup model & load data
# Load the pretrained 'gelu-4l' checkpoint through TransformerLens.
model = HookedTransformer.from_pretrained('gelu-4l')
# Later cells hook the per-head 'result' activation (blocks.{layer}.attn.hook_result).
# NOTE(review): this flag is presumably what makes that hook fire — confirm against the
# TransformerLens config documentation.
model.cfg.use_attn_result = True
model.to(device)

# Prompt tokens from the project-local load_data helper. Later cells index this as
# prompts_t[0] and prompts_t[:2, :], so it has at least two dimensions.
prompts_t = get_prompts_t()

Loaded pretrained model gelu-4l into HookedTransformer
Moving model to device:  cpu
Loading 80 prompts from c4-tokenized-2b...


100%|██████████| 80/80 [00:02<00:00, 32.15it/s]


Loading 20 prompts from code-tokenized...


100%|██████████| 20/20 [00:02<00:00,  9.25it/s]


In [3]:
# Run the model on the first prompt only, returning its logits together with an activation cache.
# NOTE(review): neither `logits` nor `cache` is read by any cell visible here.
logits, cache = model.run_with_cache(prompts_t[0])


In [5]:
# Sanity check: display the hook name that get_act_name builds for layer 0's attention 'result'.
get_act_name('result', 0)

'blocks.0.attn.hook_result'

In [26]:
def resample_ablation(activation, hook, head):
    """Resample-ablate one attention head by rotating its output across the batch.

    Each prompt in the batch receives the chosen head's output from the *next*
    prompt in the batch; the last prompt wraps around and receives the first
    prompt's output. All other heads are left untouched.

    Args:
        activation: per-head attention output, (batch, pos, head, dmodel).
            Modified in place.
        hook: hook point passed in by the hook machinery; unused here.
        head: index along the head axis of the head to ablate.

    Returns:
        The same ``activation`` tensor, with ``activation[:, :, head, :]``
        shifted by one position along the batch axis.

    Note:
        With a batch of size one there is no other prompt to resample from, so
        the activation comes back unchanged.
    """
    # torch.roll builds a separate tensor before the write-back, so the source
    # and destination never alias. Copying activation[1:] into activation[:-1]
    # directly would read and write overlapping views of the same storage.
    activation[:, :, head, :] = torch.roll(activation[:, :, head, :], shifts=-1, dims=0)
    return activation



In [27]:
# Baseline: loss on the first two prompts with no intervention.
original_loss = model(prompts_t[:2, :], return_type="loss")

# One entry per (layer, head): ablated loss minus the baseline loss.
n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
ablated_loss_diff_matrix = torch.zeros((n_layers, n_heads))

for layer in range(n_layers):
    # The hook name depends only on the layer, so build it once per layer.
    result_hook_name = get_act_name('result', layer)
    for head in range(n_heads):
        # Bind the head index so the hook matches the (activation, hook) call signature.
        head_hook = partial(resample_ablation, head=head)
        ablated_loss = model.run_with_hooks(
            prompts_t[:2],
            return_type="loss",
            fwd_hooks=[(result_hook_name, head_hook)],
        )
        ablated_loss_diff_matrix[layer, head] = ablated_loss - original_loss


In [25]:
# NOTE(review): this cell is identical to the one that follows it. Its execution count (25)
# is lower than that of the cells defining resample_ablation (26) and filling the matrix (27),
# so any saved output here presumably comes from an earlier run — confirm or remove.
px.imshow(ablated_loss_diff_matrix, color_continuous_midpoint=0, color_continuous_scale='RdBu')

In [28]:
# Heatmap of the per-head loss change (ablated - original) on a diverging scale centred at 0.
# Rows of the matrix are layers and columns are heads, so label the axes accordingly.
px.imshow(
    ablated_loss_diff_matrix,
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
    labels=dict(x="Head", y="Layer", color="Loss diff (ablated - original)"),
    title="Change in loss when resample-ablating each attention head",
)