In [2]:
import types
import torch
from attention_utils import AttentionPatcher
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Checkpoint to load from the Hugging Face Hub.
model_id = 'meta-llama/Llama-3.2-1B'

# Prefer the first GPU, but fall back to CPU so the notebook still loads on a
# machine without CUDA (behaviour is unchanged when a GPU is available).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load the causal LM in bfloat16 directly onto the selected device.
# NOTE(review): cache_dir is an absolute, machine-specific path — consider
# making it configurable before sharing this notebook.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype=torch.bfloat16,
    cache_dir="/share/u/can/models",
)

# Tokenizer for the same checkpoint (loaded without the custom cache_dir).
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [4]:
# Text to analyse, tokenized to a PyTorch tensor and then moved onto the
# model's device.
prompt = "Hello world,"
input_ids = tokenizer.encode(prompt, return_tensors="pt")
input_ids = input_ids.to(model.device)

# Indices of the attention layers/blocks to cache attention patterns from.
layers = [11]

# Head indices whose attention is saved.
save_attn_for = [0]

# Attention edges to cut (q_idx --> k_idx via a specific head); None here, so
# presumably nothing is cut — confirm against AttentionPatcher.
# Example: cut_edges = [0, AttentionEdge(q_idx=5, k_idx=4)]    # [head_idx, [AttentionEdge(q_idx, k_idx)]]
# NOTE(review): AttentionEdge is not imported in this notebook's import cell.
cut_edges = None

In [6]:
# Storage handed to the patched attention blocks; each per-layer dict is
# passed to AttentionPatcher below and is populated once the model is run.
attn_weights = {}  # attention weights: layer_idx --> head_idx --> attn_matrix
attn_contributions = {}  # attention contributions, i.e. multiplication of head with value matrix: layer_idx --> head_idx --> attn_matrix

# Replace the forward method of each selected attention block with the
# patched version, bound to that block so it receives it as `self`.
for layer in layers:
    attn_weights[layer] = {}
    attn_contributions[layer] = {}
    attn_block_name = f"layers.{layer}.self_attn"
    attn_block = model.model.layers[layer].self_attn
    attn_block.forward = types.MethodType(
        AttentionPatcher(
            block_name=attn_block_name,
            cut_attn_edges=cut_edges,
            save_attn_for=save_attn_for,
            attn_matrices=attn_weights[layer],
            attn_contributions=attn_contributions[layer],
        ),
        attn_block,
    )

# Run the model. This is inference only, so autograd is disabled: otherwise
# every captured tensor keeps a grad_fn and pins the whole graph in memory.
# Drop the no_grad context if gradients through the captured tensors are needed.
with torch.no_grad():
    out = model(input_ids)

In [7]:
# Display the captured attention weights (layer_idx --> head_idx --> attn_matrix).
attn_weights

{11: {0: tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
           [0.9844, 0.0167, 0.0000, 0.0000],
           [0.9727, 0.0053, 0.0227, 0.0000],
           [0.9609, 0.0119, 0.0178, 0.0086]]], device='cuda:0',
         dtype=torch.bfloat16, grad_fn=<CloneBackward0>)}}