In [29]:
# %% Import modules
import torch
torch.set_grad_enabled(False)
device = torch.device('cpu') #'cuda' if torch.cuda.is_available() else 

import plotly.express as px
import pandas as pd
import numpy as np
import einops
import importlib
import sys
from torch import Tensor
from jaxtyping import Float, Int, Bool
from typing import Callable, Optional
from functools import partial

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens.utils import get_act_name, test_prompt
import plotly.graph_objects as go

import seaborn as sns
from matplotlib import pyplot as plt
sns.set() # Setting seaborn as default style even if use only matplotlib

from plotting import (
    single_head_full_resid_projection,
    ntensor_to_long,
    line_with_river
)
from load_data import (
    get_prompts_t,
    get_token_counts,
)
from utils import (
    projection,
    cos_similarity,
    reinforcement_ratio
)



In [82]:
#%% Setup model & load data
# Load the pretrained 'gelu-4l' checkpoint and enable `use_attn_result` so the
# attention 'result' activation can be hooked by the cells further down.
model = HookedTransformer.from_pretrained('gelu-4l')
model.cfg.use_attn_result = True
model.to(device)

# Prompt budget: a fixed share goes to the text dataset, the rest to code.
n_total_prompts = 1000
text_fraction = 0.8
n_c4_total_prompts = int(text_fraction * n_total_prompts)
n_code_total_prompts = n_total_prompts - n_c4_total_prompts

prompt_counts = {
    'n_text_prompts': n_c4_total_prompts,
    'n_code_prompts': n_code_total_prompts,
}
prompts_t = get_prompts_t(**prompt_counts)

Loaded pretrained model gelu-4l into HookedTransformer
Moving model to device:  cpu
Loading 800 prompts from c4-tokenized-2b...


100%|██████████| 800/800 [00:03<00:00, 228.23it/s]


Loading 200 prompts from code-tokenized...


100%|██████████| 200/200 [00:02<00:00, 83.61it/s] 


In [5]:
# Display the hook-point name used for the layer-0 attention 'result' activation.
get_act_name('result', 0)

'blocks.0.attn.hook_result'

In [26]:
def resample_ablation(activation, hook, head):
    """Resample-ablate one head by rotating its output across the batch.

    Every batch element receives the output that `head` produced for the
    *next* batch element, and the last element receives the first element's
    output. All other heads are left untouched.

    Args:
        activation: tensor laid out as (batch, pos, head, dmodel). It is
            modified in place.
        hook: hook point passed in by the hooking machinery; unused here.
        head: index of the head (third axis) to ablate.

    Returns:
        The same `activation` tensor, after the in-place rotation.

    Note:
        With a batch of size 1 the rotation is the identity, so nothing is
        ablated.
    """
    # torch.roll builds a separate tensor before the write-back, so the
    # assignment never reads from memory it is overwriting (assigning
    # activation[1:] into activation[:-1] directly copies between overlapping
    # views of the same storage).
    rotated = torch.roll(activation[:, :, head, :], shifts=-1, dims=0)
    activation[:, :, head, :] = rotated
    return activation



In [27]:
# Baseline loss on the first two prompts, with no intervention.
eval_batch = prompts_t[:2]
original_loss = model(eval_batch, return_type="loss")

# Entry [layer, head] holds (loss with that head ablated) - (baseline loss).
ablated_loss_diff_matrix = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))

for layer in range(model.cfg.n_layers):
    # The hook point only depends on the layer, so look it up once per layer.
    result_hook_name = get_act_name('result', layer)
    for head in range(model.cfg.n_heads):
        head_hook = partial(resample_ablation, head=head)
        ablated_loss = model.run_with_hooks(
            eval_batch,
            return_type="loss",
            fwd_hooks=[(result_hook_name, head_hook)],
        )
        ablated_loss_diff_matrix[layer, head] = ablated_loss - original_loss


In [25]:
# NOTE(review): this cell is identical to the one that follows it, and its
# execution count (25) is lower than that of the sweep cell (27), so its
# rendered output may predate the current matrix -- consider deleting it.
px.imshow(ablated_loss_diff_matrix, color_continuous_midpoint=0, color_continuous_scale='RdBu')

In [28]:
# Heatmap of the loss change per ablated head. Rows of the matrix are layers
# and columns are heads; the diverging colour scale is centred on zero so the
# sign of the change is visible at a glance.
px.imshow(
    ablated_loss_diff_matrix,
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
    title='Loss change from resample-ablating each attention head',
    labels={'x': 'Head', 'y': 'Layer', 'color': 'Ablated loss - original loss'},
)

### Showing that DLA is misleading, using resample ablation

In [None]:
# For one clean prompt, overwrite a single head's final-position output with
# the output it produced on many unrelated prompts, and record how much the
# probability of the originally predicted token changes each time.
ablated_results = []

text = "It's in the shelf, either on the top or the"
ori_tokens = model.to_tokens(text)

# Head whose final-position output is patched.
layer, head = 0, 2

# Never request more prompts than were loaded: an out-of-range slice would
# silently produce an empty batch instead of raising.
n_prompts = min(1000, prompts_t.shape[0])


# NOTE(review): this reuses the name of the 3-argument `resample_ablation`
# defined in an earlier cell and shadows it once this cell has run. The name
# is kept so any code relying on this signature keeps working; re-run the
# earlier definition before re-running the per-head sweep.
def resample_ablation(activation, hook, corrupted_activation, head):
    """Patch one head's final-position output with its value from another run.

    Args:
        activation: tensor laid out as (batch, pos, head, dmodel). It is
            modified in place.
        hook: hook point being run; `hook.name` selects the matching entry in
            `corrupted_activation`.
        corrupted_activation: mapping from hook name to the activation
            recorded on the corrupted prompt.
        head: index of the head to overwrite.

    Returns:
        The same `activation` tensor, after the in-place patch.
    """
    activation[:, -1, head, :] = corrupted_activation[hook.name][:, -1, head, :]
    return activation


# The clean run does not depend on the corrupted prompt, so it is computed
# once here rather than on every loop iteration.
original_logits = model(ori_tokens, return_type="logits")
ori_prob, ori_token = torch.max(torch.softmax(original_logits[0, -1, :], dim=-1), dim=-1)
ori_prob, ori_token = ori_prob.item(), ori_token.item()

for random_prompt_idx in range(n_prompts):
    # Truncate the corrupted prompt to the clean prompt's length.
    corrupted_tokens = prompts_t[random_prompt_idx:random_prompt_idx+1, :ori_tokens.shape[-1]]

    # Index the (logits, cache) pair instead of unpacking into `_`, which
    # would clobber IPython's last-output variable.
    corrupted_activation = model.run_with_cache(corrupted_tokens)[1]

    ablated_logits = model.run_with_hooks(
        ori_tokens,
        return_type="logits",
        fwd_hooks=[(
            get_act_name('result', layer),
            partial(resample_ablation, corrupted_activation=corrupted_activation, head=head),
        )],
    )

    ablated_probs = torch.softmax(ablated_logits[0, -1, :], dim=-1)
    # Top prediction after ablation, kept for inspection.
    ablated_token = ablated_probs.argmax(dim=-1).item()
    # Track the probability of the ORIGINAL top token. Taking the maximum
    # probability would compare two different tokens whenever the ablation
    # changes the argmax.
    ablated_prob = ablated_probs[ori_token].item()

    ablated_results.append(ablated_prob - ori_prob)

    

In [99]:
# Histogram of the per-prompt probability differences collected in
# `ablated_results`. Each entry is (ablated probability) - (original
# probability), so the axis label is written in that order. The token and
# prompt are taken from `ori_token` and `text` rather than hard-coded, so the
# caption cannot drift from what was actually run.
ori_token_str = model.to_string(ori_token)
px.histogram(
    ablated_results,
    title=(
        f'The probability difference of the correct token between original run and resample ablation for {n_prompts} prompts'
        f'<br>Original probability: {ori_prob:.3f}, token: "{ori_token_str}"'
        f'<br>Prompt: "{text}"'
    ),
    labels={'value': f'Pr("{ori_token_str}" | ablated) - Pr("{ori_token_str}" | original)'},
)