In [4]:
import warnings
warnings.filterwarnings("ignore")

import argparse 
import torch

from datasets import load_dataset
from transformers import AutoImageProcessor
from lens import HookedVisionTransformer


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load a model (e.g. ViT Base)
model_name_or_path = "google/vit-base-patch16-224"
model = HookedVisionTransformer.from_pretrained(model_name_or_path, device)
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)
dataset = load_dataset(
    "cifar10",
    cache_dir="./../cache",
    task="image-classification"
)

# run the model and get outputs and activations
with torch.no_grad():
    # Index the row first and the column second: `dataset['test']['image'][0]`
    # materialises the whole image column only to keep its first element.
    features = image_processor(dataset['test'][0]['image'], return_tensors="pt")
    # Rebind the result so the device move does not rely on `.to` mutating in place.
    features = features.to(device)
    outputs, activations = model.run_with_cache(**features)

In [5]:
# Names of the activations returned by `run_with_cache` in the previous cell.
activations.keys()

dict_keys(['embeddings', 'encoder.0.ln1', 'encoder.0.attn.q', 'encoder.0.attn.k', 'encoder.0.attn.v', 'encoder.0.ln2', 'encoder.0.intermediate', 'encoder.1.ln1', 'encoder.1.attn.q', 'encoder.1.attn.k', 'encoder.1.attn.v', 'encoder.1.ln2', 'encoder.1.intermediate', 'encoder.2.ln1', 'encoder.2.attn.q', 'encoder.2.attn.k', 'encoder.2.attn.v', 'encoder.2.ln2', 'encoder.2.intermediate', 'encoder.3.ln1', 'encoder.3.attn.q', 'encoder.3.attn.k', 'encoder.3.attn.v', 'encoder.3.ln2', 'encoder.3.intermediate', 'encoder.4.ln1', 'encoder.4.attn.q', 'encoder.4.attn.k', 'encoder.4.attn.v', 'encoder.4.ln2', 'encoder.4.intermediate', 'encoder.5.ln1', 'encoder.5.attn.q', 'encoder.5.attn.k', 'encoder.5.attn.v', 'encoder.5.ln2', 'encoder.5.intermediate', 'encoder.6.ln1', 'encoder.6.attn.q', 'encoder.6.attn.k', 'encoder.6.attn.v', 'encoder.6.ln2', 'encoder.6.intermediate', 'encoder.7.ln1', 'encoder.7.attn.q', 'encoder.7.attn.k', 'encoder.7.attn.v', 'encoder.7.ln2', 'encoder.7.intermediate', 'encoder.8.ln1'

In [19]:
import math
import plotly.express as px
import torch.nn as nn


def compute_attention_patterns(activations: dict, layer: int, num_attention_heads: "int | None" = None):
    """Compute the per-head attention probabilities of one encoder layer.

    The cached query and key activations of ``layer`` are split into heads
    along their last dimension, and ``softmax(Q @ K^T / sqrt(head_size))`` is
    returned, normalised over the last dimension.

    Args:
        activations: mapping that contains the entries
            ``encoder.{layer}.attn.q`` and ``encoder.{layer}.attn.k``. Each
            must be a 3-D tensor whose last dimension holds all heads
            concatenated.
        layer: index of the encoder layer to read from ``activations``.
        num_attention_heads: number of heads to split the last dimension
            into. When omitted, it is read from the notebook-level
            ``model.config.num_attention_heads``.

    Returns:
        A 4-D tensor of attention probabilities: the head axis is moved to
        position 1 and the last two axes are (query position, key position).

    Raises:
        ValueError: if the last dimension of an activation cannot be divided
            evenly among the heads.
    """
    if num_attention_heads is None:
        # Fall back to the model loaded earlier in the notebook.
        num_attention_heads = model.config.num_attention_heads

    def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension into (head, head_size) and bring the head
        # axis in front of the position axis.
        if x.size(-1) % num_attention_heads != 0:
            raise ValueError(
                f"last dimension {x.size(-1)} is not divisible by "
                f"{num_attention_heads} attention heads"
            )
        attention_head_size = x.size(-1) // num_attention_heads
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        return x.view(new_x_shape).permute(0, 2, 1, 3)

    key_layer = transpose_for_scores(activations[f"encoder.{layer}.attn.k"])
    query_layer = transpose_for_scores(activations[f"encoder.{layer}.attn.q"])

    # take the dot product between "query" and "key" to get the raw attention scores,
    # scaled by the square root of the per-head size (last axis after the split)
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(query_layer.size(-1))

    # normalize the attention scores to probabilities
    return nn.functional.softmax(attention_scores, dim=-1)


def display_attention_head(attention_patterns: torch.Tensor, head: int):
    """Show the attention pattern of a single head as a heatmap.

    Args:
        attention_patterns: attention patterns indexed as
            ``[head, row, column]``.
        head: index of the head to display.
    """
    attention_pattern = attention_patterns[head, :, :]
    if isinstance(attention_pattern, torch.Tensor):
        # plotly converts its input through NumPy, which cannot read tensors
        # that live on an accelerator or that are attached to an autograd graph.
        attention_pattern = attention_pattern.detach().cpu().numpy()
    # NOTE(review): the colour range is kept at [-1, 1] as in the original;
    # softmax outputs lie in [0, 1], so only half of the scale is used.
    fig = px.imshow(
        attention_pattern,
        title=f"Attention Pattern of Head {head}",
        color_continuous_scale='RdBu_r',
        zmin=-1, zmax=1,
        height=600, width=600
    )
    fig.show()


# Attention probabilities of encoder layer 0, computed from the cached activations.
attention_patterns = compute_attention_patterns(activations, 0)
# Take the first entry along the leading dimension, then plot head 0 of it.
display_attention_head(attention_patterns[0], 0)

In [26]:
import transformer_lens.utils as utils

from jaxtyping import Float
from transformer_lens.hook_points import HookPoint


# use head ablation hooks
# Index of the layer whose value activation receives the ablation hook.
layer_to_ablate = 0
# Index of the attention head that the hook sets to zero.
head_index_to_ablate = 8

def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero the value vectors of head ``head_index_to_ablate`` in place.

    Two activation layouts are handled:

    * 4-D ``(batch, pos, head_index, d_head)``: the head axis is indexed
      directly, as in the original implementation.
    * any other rank: the last dimension is treated as all heads
      concatenated (the layout that ``compute_attention_patterns`` splits for
      the cached q/k activations), and the slice of columns that belongs to
      the ablated head is zeroed. The per-head width is derived from the
      notebook-level ``model.config.num_attention_heads``.

    Args:
        value: the value activation passed in by the hook point; it is
            modified in place.
        hook: the hook point that invoked this function (unused).

    Returns:
        The same tensor object, with the selected head set to zero.
    """
    print(f"Shape of the value tensor: {value.shape}")
    if value.dim() == 4:
        value[:, :, head_index_to_ablate, :] = 0.
    else:
        # Heads are laid out contiguously along the last dimension, so head h
        # occupies columns [h * d_head, (h + 1) * d_head).
        d_head = value.size(-1) // model.config.num_attention_heads
        start = head_index_to_ablate * d_head
        value[..., start:start + d_head] = 0.
    return value

# Baseline: forward pass without hooks, reading the "loss" entry of the output.
# NOTE(review): the recorded output of this cell is `KeyError: 'loss'`, i.e. the
# output has no "loss" entry — presumably because `features` carries no
# target/labels; confirm and supply one (see the TODO after this cell).
original_loss = model(**features, return_dict=True)["loss"]
print(original_loss)
# Forward pass with `head_ablation_hook` attached to the value activation of
# `layer_to_ablate`.
# NOTE(review): `utils.get_act_name` comes from transformer_lens, while the cache
# keys printed earlier in this notebook look like 'encoder.0.attn.v' — verify that
# the name it returns is one HookedVisionTransformer recognises. `features` is also
# passed positionally here but unpacked with `**` for `run_with_cache`; confirm
# which form `run_with_hooks` expects and whether it supports `return_type`.
ablated_loss = model.run_with_hooks(
    features, 
    return_type="loss", 
    fwd_hooks=[(
        utils.get_act_name("v", layer_to_ablate), 
        head_ablation_hook
        )]
    )
# Report both losses side by side.
print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")

KeyError: 'loss'

# TODO: change to VisionTransformerForMaskedImageModelling (the cell above currently fails with `KeyError: 'loss'`).