## **Interpretation in LLMs**

In [23]:
%%capture
!pip install captum

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings

warnings.filterwarnings('ignore')

import torch.nn.functional as F
from captum.attr import IntegratedGradients


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


### **Attention Based Attribution**

In [15]:
# Load model
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, output_attentions=True)
model.eval()

prompt = (
    "On a sunny afternoon in the countryside, villagers set up colorful tents and wooden stalls for the annual harvest celebration. "
    "Children chased butterflies near the fields, while parents prepared traditional dishes over open fires. "
    "Music from flutes and drums filled the air as stories were shared under the shade of tall oak trees. "
    "Among the crowd stood a kind-hearted little"
)
target_token = "girl"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
target_id = tokenizer.convert_tokens_to_ids(target_token)

In [6]:
with torch.no_grad():
    outputs = model(input_ids)
    logits = outputs.logits
    attentions = outputs.attentions  # List of attention maps from each layer

In [7]:
# Get the prediction logits
pred_logits = logits[0, -1]  # Last position (prediction after prompt)

# Get top predictions
probs = torch.softmax(pred_logits, dim=-1)
topk = torch.topk(probs, k=10)
top_tokens = tokenizer.convert_ids_to_tokens(topk.indices.tolist())
print("Top predictions:", top_tokens)

Top predictions: ['Ġgirl', 'Ġboy', 'Ġman', 'Ġvillage', 'Ġwoman', 'Ġchild', 'Ġbrother', 'Ġlady', 'Ġdog', 'Ġtown']


In [8]:
num_layers = len(attentions)
batch_size, num_heads, seq_len, _ = attentions[0].shape

# linear weighting -> higher layers get more weight (according to BERTology some layers speciailise in local dependencies)
layer_weights = torch.linspace(1.0, 2.0, steps=num_layers)
layer_weights /= layer_weights.sum() # normalized to 1

weighted_attn = torch.zeros(seq_len)

for i, layer_attn in enumerate(attentions):
    last_token_attn = layer_attn[0, :, -1, :] # (heads, src_len)
    mean_attn = last_token_attn.mean(dim=0)
    weighted_attn += layer_weights[i] * mean_attn

weighted_attn /= weighted_attn.sum()

In [10]:
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
token_attn_pairs = list(zip(tokens, weighted_attn.tolist()))
token_attn_sorted = sorted(token_attn_pairs, key=lambda x: x[1], reverse=True)

# Print top influential tokens based on attention
print("\n📊 Top tokens based on weighted attention to predicted token:")
for token, score in token_attn_sorted:
    token = token.replace('Ġ', '')
    print(f"{token:20} → {score:.4f}")


📊 Top tokens based on weighted attention to predicted token:
In                   → 0.5264
hearted              → 0.0710
A                    → 0.0584
little               → 0.0546
.                    → 0.0431
kind                 → 0.0396
-                    → 0.0170
Children             → 0.0166
.                    → 0.0151
village              → 0.0119
families             → 0.0114
laughter             → 0.0092
festival             → 0.0091
the                  → 0.0087
and                  → 0.0086
faces                → 0.0079
ran                  → 0.0072
filled               → 0.0072
painted              → 0.0069
,                    → 0.0066
with                 → 0.0064
air                  → 0.0063
around               → 0.0055
gathered             → 0.0050
annual               → 0.0049
for                  → 0.0044
the                  → 0.0044
spring               → 0.0041
a                    → 0.0039
quiet                → 0.0038
nest                 → 0.0036
hills   

### **Integrated Gradients**

In [21]:
# Load GPT-2
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

prompt = (
    "On a sunny afternoon in the countryside, villagers set up colorful tents and wooden stalls for the annual harvest celebration. "
    "Children chased butterflies near the fields, while parents prepared traditional dishes over open fires. "
    "Music from flutes and drums filled the air as stories were shared under the shade of tall oak trees. "
    "Among the crowd stood a kind-hearted little"
)

target_token = "girl"
target_id = tokenizer.convert_tokens_to_ids(target_token)

In [22]:
# Tokenize
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)  # (1, seq_len)
attention_mask = inputs["attention_mask"].to(device)
seq_len = input_ids.shape[1]

# Get embedding layer
embedding_layer = model.transformer.wte

# Forward function to return logit of target token
def forward_func(input_embeds):
    fake_attention_mask = torch.ones(input_embeds.shape[:2], dtype=torch.long).to(device)
    outputs = model(inputs_embeds=input_embeds, attention_mask=fake_attention_mask)
    logits = outputs.logits  # shape: (1, seq_len, vocab_size)
    return logits[:, -1, target_id]  # return only logit for target

# Create baseline (zero embedding)
baseline = torch.zeros_like(embedding_layer(input_ids))

# Get embeddings from input tokens
input_embed = embedding_layer(input_ids).detach()
input_embed.requires_grad = True

# Run Integrated Gradients
ig = IntegratedGradients(forward_func)
attributions, delta = ig.attribute(
    inputs=input_embed,
    baselines=baseline,
    return_convergence_delta=True
)

# Aggregate attribution scores across embedding dimensions
attribution_scores = attributions.sum(dim=-1).squeeze(0)  # (seq_len,)

# Decode input tokens
tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))

# Pair tokens and attribution scores
token_attr_pairs = list(zip(tokens, attribution_scores.tolist()))
token_attr_pairs_sorted = sorted(token_attr_pairs, key=lambda x: abs(x[1]), reverse=True)

print("\n📊 Token Influence on Predicting:", target_token)
for token, score in token_attr_pairs_sorted:
    token = token.replace("Ġ", "")
    print(f"{token:20} → {score:.4f}")


📊 Token Influence on Predicting: girl
sunny                → 24.9788
a                    → 21.0395
little               → -9.8640
oak                  → 8.1216
On                   → -6.1214
in                   → 6.0677
afternoon            → 6.0600
Among                → 5.6198
Music                → -2.9468
over                 → -2.6909
tents                → -2.2920
drums                → 2.2845
hearted              → -2.2663
colorful             → -2.2488
set                  → -2.2285
the                  → 2.1400
a                    → -1.9853
villagers            → -1.8188
,                    → 1.7516
open                 → -1.6642
shade                → 1.6149
from                 → -1.6038
trees                → 1.5972
up                   → -1.5608
crowd                → -1.5132
near                 → -1.4539
harvest              → 1.4393
celebration          → 1.3505
fires                → -1.3296
dishes               → -1.1240
wooden               → 1.0828
shared      