# Llama-3.2-1B Token Attribution Analysis

This notebook demonstrates three different methods for analyzing which input tokens are most important in generating model predictions:

1. **Gradient-Based Attribution**: Compute gradients of output logits w.r.t. input embeddings
2. **Attention Weights Visualization**: Analyze attention patterns across layers
3. **Integrated Gradients**: Use Captum library for sophisticated attribution

We'll use the prompt: "What is the Capital City of Latvia?" and analyze which tokens the model uses to generate the answer.


## 1. Setup and Model Loading

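First, let's import the necessary libraries. The required packages (`torch`, `transformers`, `captum`, `numpy`, `matplotlib`, `seaborn`) must already be installed, e.g. with `pip install torch transformers captum numpy matplotlib seaborn`.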


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import IntegratedGradients
import warnings
warnings.filterwarnings('ignore')

# Notebook-wide plotting defaults: seaborn's white-grid style and a wide
# default figure size for every matplotlib figure created afterwards.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})


In [None]:
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print(f"Using device: {device}")
if cuda_available:
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is the card's full capacity, not the amount currently
    # free, so the label says "Total" rather than "Available".
    print(f"Total GPU memory: {gpu_properties.total_memory / 1e9:.2f} GB")


In [None]:
# Load the Llama-3.2-1B model and tokenizer
model_name = "meta-llama/Llama-3.2-1B"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Reuse EOS as the padding token

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # Half precision on GPU to save memory; full precision on CPU.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
    # The attention-weights analysis later in this notebook calls the model
    # with output_attentions=True. Fused attention kernels (SDPA / flash
    # attention) do not materialize the attention matrix, so request the
    # eager implementation explicitly.
    attn_implementation="eager",
)

model.eval()  # Set to evaluation mode (disables dropout)
print("Model loaded successfully!")
print(f"Model parameters: {model.num_parameters() / 1e6:.2f}M")


## 2. Simple Generation Task

Let's generate an answer to the question: "What is the Capital City of Latvia?"


In [None]:
# The single question every attribution method in this notebook is run on.
prompt = "What is the Capital City of Latvia?"

# Echo the prompt, then a blank line and a progress message.
print(f"Prompt: {prompt}", "\nGenerating response...", sep="\n")


In [None]:
# Turn the prompt into PyTorch tensors and move them to the selected device.
inputs = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]

# Show the token strings and how many tokens the prompt occupies
# (row 0 is the only sequence in the batch).
prompt_token_strings = tokenizer.convert_ids_to_tokens(input_ids[0])
print(f"Input tokens: {prompt_token_strings}")
print(f"Number of input tokens: {input_ids.shape[1]}")


In [None]:
# Generate a continuation of the prompt without tracking gradients.
with torch.no_grad():
    outputs = model.generate(
        input_ids,
        # Pass the tokenizer's mask explicitly: the pad token is set to EOS,
        # so generate() cannot reliably infer the mask from the ids alone.
        attention_mask=inputs["attention_mask"],
        max_new_tokens=20,
        do_sample=False,  # Use greedy decoding for reproducibility
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode the full sequence (prompt + continuation)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"\nGenerated text: {generated_text}")

# Decode only the newly generated tokens by slicing off the prompt prefix
prompt_length = input_ids.shape[1]
generated_only = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print(f"\nGenerated answer: {generated_only}")


## 3. Method A: Gradient-Based Attribution

This method computes the gradients of the output logits with respect to the input token embeddings. Tokens with higher gradient magnitudes have more influence on the prediction.


In [None]:
def gradient_based_attribution(model, input_ids, target_token_idx=-1):
    """
    Compute gradient-based attribution scores for input tokens.

    The score of a token is the L2 norm of the gradient of the predicted
    token's logit with respect to that token's input embedding. Only the
    first sequence of the batch (index 0) is analysed.

    Args:
        model: The language model. Must expose ``get_input_embeddings()`` and
            accept ``inputs_embeds=...``, returning an object with ``.logits``.
        input_ids: Input token IDs, shape (batch, seq_len)
        target_token_idx: Position whose logits are explained (-1 for last token)

    Returns:
        attribution_scores: 1-D float32 NumPy array of length seq_len
        predicted_token_id: ID (int) of the highest-scoring token at the
            target position

    Raises:
        RuntimeError: If the logits do not depend on the input embeddings, so
            no gradient can be computed.
    """
    # Ensure gradients are enabled even if the caller is inside no_grad()
    with torch.enable_grad():
        embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, embedding_dim)

        # Look up the embeddings and turn them into a fresh leaf tensor so the
        # gradient is taken w.r.t. these values, not the embedding table.
        token_embeddings = embedding_matrix[input_ids].detach().clone()
        token_embeddings.requires_grad_(True)

        # Forward pass with custom embeddings
        logits = model(inputs_embeds=token_embeddings).logits
        target_logits = logits[0, target_token_idx]

        # Explain the token the model actually predicts at that position
        predicted_token_id = target_logits.argmax(dim=-1)
        target_score = target_logits[predicted_token_id]

        # autograd.grad returns only the requested gradient. Unlike
        # .backward(), it does not populate .grad on every model parameter,
        # which would waste memory and accumulate across repeated calls.
        (gradients,) = torch.autograd.grad(target_score, token_embeddings)

    # Take the norm in float32 (half-precision norms lose accuracy) and index
    # the batch dimension explicitly so a one-token prompt still yields a
    # 1-D array instead of a 0-d scalar.
    attribution_scores = gradients[0].float().norm(dim=-1).cpu().numpy()

    return attribution_scores, predicted_token_id.item()

# Run the gradient attribution on the prompt and show which token the model
# predicts next (the token whose logit is being explained).
print("Computing gradient-based attribution...")
grad_scores, predicted_id = gradient_based_attribution(model, input_ids=input_ids)
predicted_token = tokenizer.decode([predicted_id])
print(f"Predicted next token: '{predicted_token}'")


In [None]:
# Bar chart of the gradient-norm score for every prompt token.
input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
token_positions = range(len(input_tokens))

plt.figure(figsize=(14, 6))
plt.bar(token_positions, grad_scores, color='steelblue', alpha=0.7)
plt.title(f'Gradient-Based Attribution: Predicting "{predicted_token}"', fontsize=14, fontweight='bold')
plt.xlabel('Token Position', fontsize=12)
plt.ylabel('Attribution Score (Gradient Norm)', fontsize=12)
plt.xticks(token_positions, input_tokens, rotation=45, ha='right')
plt.tight_layout()
plt.grid(axis='y', alpha=0.3)
plt.show()

# List the highest-scoring tokens, best first.
top_k = 5
top_indices = np.argsort(grad_scores)[::-1][:top_k]
print(f"\nTop {top_k} contributing tokens (Gradient-Based):")
for rank, idx in enumerate(top_indices, start=1):
    print(f"{rank}. '{input_tokens[idx]}' (position {idx}): {grad_scores[idx]:.4f}")


## 4. Method B: Attention Weights Visualization

This method analyzes the attention patterns across all layers to see which input tokens the model focuses on when making predictions.


In [None]:
def extract_attention_weights(model, input_ids):
    """
    Extract attention weights from all layers.

    Only the first sequence of the batch (index 0) is summarised.

    Args:
        model: The language model; called as
            ``model(input_ids, output_attentions=True)``
        input_ids: Input token IDs

    Returns:
        last_token_attention: 1-D NumPy array of length seq_len holding the
            attention paid by the last position to every position, averaged
            over all layers and heads
        all_attention: float32 tensor of shape
            (num_layers, batch, num_heads, seq_len, seq_len)

    Raises:
        RuntimeError: If the model returns no attention tensors.
    """
    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)

    # attentions is a tuple with one tensor per layer, each of shape
    # (batch_size, num_heads, seq_len, seq_len)
    attentions = outputs.attentions

    # Fused attention kernels do not materialize the attention matrix; fail
    # with an actionable message instead of an opaque torch.stack error.
    if attentions is None or any(layer is None for layer in attentions):
        raise RuntimeError(
            "The model did not return attention weights. Load it with "
            "attn_implementation=\"eager\" to use output_attentions=True."
        )

    # Stack the layers and average in float32 so half-precision models do
    # not lose accuracy in the mean.
    all_attention = torch.stack(attentions).float()  # (num_layers, batch, num_heads, seq_len, seq_len)

    # Average across layers (dim 0) and heads (dim 2)
    avg_attention = all_attention.mean(dim=(0, 2))  # (batch, seq_len, seq_len)

    # Row -1 is the attention from the last token to all tokens
    last_token_attention = avg_attention[0, -1, :].cpu().numpy()

    return last_token_attention, all_attention

# Run the attention extraction on the prompt and report what came back.
print("Extracting attention weights...")
attention_scores, all_attention = extract_attention_weights(model, input_ids)
num_decoder_layers = len(model.model.layers)
print(f"Shape of all attention weights: {all_attention.shape}")
print(f"Number of layers: {num_decoder_layers}")


In [None]:
# Bar chart of the layer- and head-averaged attention from the last token.
attention_positions = range(len(input_tokens))

plt.figure(figsize=(14, 6))
plt.bar(attention_positions, attention_scores, color='coral', alpha=0.7)
plt.title('Attention Weights: Last Token Attending to Input Tokens', fontsize=14, fontweight='bold')
plt.xlabel('Token Position', fontsize=12)
plt.ylabel('Attention Score (Averaged)', fontsize=12)
plt.xticks(attention_positions, input_tokens, rotation=45, ha='right')
plt.tight_layout()
plt.grid(axis='y', alpha=0.3)
plt.show()

# List the most-attended tokens, best first.
top_k = 5
top_indices = np.argsort(attention_scores)[::-1][:top_k]
print(f"\nTop {top_k} attended tokens (Attention Weights):")
for rank, idx in enumerate(top_indices, start=1):
    print(f"{rank}. '{input_tokens[idx]}' (position {idx}): {attention_scores[idx]:.4f}")


In [None]:
# Per-layer view: for batch item 0, take the last query position's attention
# row from every head, then average over heads (dim 1 after the slice).
last_row_per_head = all_attention[:, 0, :, -1, :]  # (num_layers, num_heads, seq_len)
layer_attention = last_row_per_head.mean(dim=1).cpu().numpy()  # (num_layers, seq_len)
layer_labels = [f'Layer {layer_idx}' for layer_idx in range(layer_attention.shape[0])]

plt.figure(figsize=(14, 8))
sns.heatmap(
    layer_attention,
    annot=False,
    cmap='YlOrRd',
    cbar_kws={'label': 'Attention Weight'},
    xticklabels=input_tokens,
    yticklabels=layer_labels,
)
plt.title('Attention Heatmap Across All Layers (Last Token)', fontsize=14, fontweight='bold')
plt.xlabel('Input Tokens', fontsize=12)
plt.ylabel('Transformer Layers', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()


## 5. Method C: Integrated Gradients

Integrated Gradients is a more sophisticated attribution method that computes the path integral of gradients from a baseline to the actual input. This provides more faithful attributions.


In [None]:
def model_forward(embeddings, attention_mask=None):
    """
    Forward wrapper handed to Captum's IntegratedGradients.

    Runs the module-level ``model`` on pre-computed embeddings instead of
    token IDs, so attributions can be taken with respect to the embeddings.

    Args:
        embeddings: Input embeddings passed as ``inputs_embeds``
        attention_mask: Attention mask forwarded unchanged (optional)

    Returns:
        The logits at the last sequence position for every batch item
    """
    logits = model(inputs_embeds=embeddings, attention_mask=attention_mask).logits
    # Keep only the final position: that is where the next token is predicted.
    return logits[:, -1, :]

def compute_integrated_gradients(model, input_ids, baseline_type='zero'):
    """
    Compute integrated gradients attribution.

    Attributions are taken with respect to the input embeddings for the
    logit of the token the model predicts at the last position. Only the
    first sequence of the batch (index 0) is summarised.

    Args:
        model: The language model
        input_ids: Input token IDs, shape (batch, seq_len)
        baseline_type: Type of baseline: 'zero' (all-zero embeddings) or
            'pad' (embeddings of the module-level tokenizer's pad token)

    Returns:
        attribution_scores: 1-D float32 NumPy array of length seq_len with the
            L2 norm of each token's attribution vector
        predicted_token_id: ID of the predicted token

    Raises:
        ValueError: If ``baseline_type`` is not 'zero' or 'pad', or if 'pad'
            is requested but the tokenizer has no pad token.
    """
    # Reject typos up front instead of silently falling back to 'pad'.
    if baseline_type not in ('zero', 'pad'):
        raise ValueError(
            f"baseline_type must be 'zero' or 'pad', got {baseline_type!r}"
        )

    embedding_matrix = model.get_input_embeddings().weight

    # Detached copy of the prompt embeddings; Captum enables gradients on
    # its interpolated inputs itself.
    input_embeddings = embedding_matrix[input_ids].detach().clone()

    # Create baseline
    if baseline_type == 'zero':
        baseline_embeddings = torch.zeros_like(input_embeddings)
    else:  # 'pad'
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            raise ValueError("baseline_type='pad' requires tokenizer.pad_token_id to be set")
        baseline_ids = torch.full_like(input_ids, pad_token_id)
        baseline_embeddings = embedding_matrix[baseline_ids].detach().clone()

    # Get predicted token
    with torch.no_grad():
        predicted_token_id = model(input_ids).logits[0, -1].argmax().item()

    # Bind the forward pass to the `model` argument so the attribution is
    # computed on the model the caller passed in, not a module-level global.
    def last_position_logits(embeddings):
        return model(inputs_embeds=embeddings).logits[:, -1, :]

    ig = IntegratedGradients(last_position_logits)

    # internal_batch_size=1 evaluates one interpolation step at a time to
    # keep peak memory low.
    attributions = ig.attribute(
        inputs=input_embeddings,
        baselines=baseline_embeddings,
        target=predicted_token_id,
        n_steps=50,
        internal_batch_size=1,
    )

    # Take the norm in float32 and index the batch dimension explicitly so a
    # one-token prompt still yields a 1-D array.
    attribution_scores = attributions[0].detach().float().norm(dim=-1).cpu().numpy()

    return attribution_scores, predicted_token_id

# Run Integrated Gradients with the all-zero embedding baseline and show the
# token whose logit is being explained.
print("Computing Integrated Gradients (this may take a moment)...")
ig_scores, ig_predicted_id = compute_integrated_gradients(
    model,
    input_ids,
    baseline_type='zero',
)
ig_predicted_token = tokenizer.decode([ig_predicted_id])
print(f"Predicted next token: '{ig_predicted_token}'")


In [None]:
# Bar chart of the Integrated Gradients score for every prompt token.
ig_positions = range(len(input_tokens))

plt.figure(figsize=(14, 6))
plt.bar(ig_positions, ig_scores, color='mediumseagreen', alpha=0.7)
plt.title(f'Integrated Gradients Attribution: Predicting "{ig_predicted_token}"', fontsize=14, fontweight='bold')
plt.xlabel('Token Position', fontsize=12)
plt.ylabel('Attribution Score (IG)', fontsize=12)
plt.xticks(ig_positions, input_tokens, rotation=45, ha='right')
plt.tight_layout()
plt.grid(axis='y', alpha=0.3)
plt.show()

# List the highest-scoring tokens, best first.
top_k = 5
top_indices = np.argsort(ig_scores)[::-1][:top_k]
print(f"\nTop {top_k} contributing tokens (Integrated Gradients):")
for rank, idx in enumerate(top_indices, start=1):
    print(f"{rank}. '{input_tokens[idx]}' (position {idx}): {ig_scores[idx]:.4f}")


## 6. Comparison of All Three Methods

Let's compare all three attribution methods side by side to see which tokens each method identifies as most important.


In [None]:
# Normalize scores for better comparison
def normalize_scores(scores):
    """
    Min-max scale scores into the range [0, 1].

    Args:
        scores: Array-like of numeric scores

    Returns:
        float64 NumPy array of the same shape; all zeros when every score is
        identical.
    """
    # Work in float64: with half-precision input the 1e-10 guard below can
    # underflow to zero, turning constant scores into 0/0 = NaN.
    scores = np.asarray(scores, dtype=np.float64)
    lowest = scores.min()
    spread = scores.max() - lowest
    # The small epsilon keeps the division defined when all scores are equal.
    return (scores - lowest) / (spread + 1e-10)

# Bring all three score vectors onto a common [0, 1] scale.
grad_scores_norm = normalize_scores(grad_scores)
attention_scores_norm = normalize_scores(attention_scores)
ig_scores_norm = normalize_scores(ig_scores)

# One stacked panel per method: (scores, bar colour, panel title).
comparison_panels = [
    (grad_scores_norm, 'steelblue', 'A) Gradient-Based Attribution'),
    (attention_scores_norm, 'coral', 'B) Attention Weights'),
    (ig_scores_norm, 'mediumseagreen', 'C) Integrated Gradients'),
]
panel_positions = range(len(input_tokens))

fig, axes = plt.subplots(3, 1, figsize=(16, 12))

for panel_ax, (panel_scores, panel_color, panel_title) in zip(axes, comparison_panels):
    panel_ax.bar(panel_positions, panel_scores, color=panel_color, alpha=0.7)
    panel_ax.set_ylabel('Normalized Score', fontsize=11)
    panel_ax.set_title(panel_title, fontsize=13, fontweight='bold')
    panel_ax.set_xticks(panel_positions)
    panel_ax.set_xticklabels(input_tokens, rotation=45, ha='right')
    panel_ax.grid(axis='y', alpha=0.3)

# Only the bottom panel carries the shared x-axis label.
axes[2].set_xlabel('Token Position', fontsize=12)

plt.suptitle('Comparison of Token Attribution Methods', fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()


In [None]:
# One heatmap row per attribution method, one column per prompt token.
method_labels = ['Gradient-Based', 'Attention Weights', 'Integrated Gradients']
attribution_matrix = np.vstack(
    (grad_scores_norm, attention_scores_norm, ig_scores_norm)
)

plt.figure(figsize=(16, 6))
sns.heatmap(
    attribution_matrix,
    annot=True,
    fmt='.2f',
    linewidths=0.5,
    cmap='RdYlGn',
    cbar_kws={'label': 'Normalized Attribution Score'},
    xticklabels=input_tokens,
    yticklabels=method_labels,
)
plt.title('Combined Attribution Heatmap', fontsize=14, fontweight='bold')
plt.xlabel('Input Tokens', fontsize=12)
plt.ylabel('Attribution Method', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()


In [None]:
# Text summary: the prompt, the generated answer, and each method's top three
# tokens by normalized score.
heavy_rule = "=" * 80
light_rule = "-" * 80

print(heavy_rule)
print("SUMMARY: Token Attribution Analysis")
print(heavy_rule)
print(f"\nPrompt: {prompt}")
print(f"Generated Answer: {generated_only}")
print(f"\nPredicted next token: '{predicted_token}'")

print("\n" + light_rule)
print("Top 3 Most Important Tokens by Each Method:")
print(light_rule)

methods = [
    ('Gradient-Based', grad_scores_norm),
    ('Attention Weights', attention_scores_norm),
    ('Integrated Gradients', ig_scores_norm),
]

for method_name, scores in methods:
    # Indices of the three largest scores, highest first.
    top_3 = np.argsort(scores)[::-1][:3]
    print(f"\n{method_name}:")
    for rank, idx in enumerate(top_3, start=1):
        print(f"  {rank}. '{input_tokens[idx]}' (pos {idx}): {scores[idx]:.4f}")

print("\n" + heavy_rule)


## Analysis and Insights

### Method Comparison:

1. **Gradient-Based Attribution**:
   - Shows which tokens have the largest gradient magnitudes
   - Indicates tokens that, if changed slightly, would most affect the output
   - Fast to compute but can be noisy

2. **Attention Weights**:
   - Shows which tokens the model explicitly attends to
   - Provides interpretability through the attention mechanism
   - May not fully capture all influences (attention is just one component)

3. **Integrated Gradients**:
   - More theoretically grounded attribution method
   - Satisfies desirable axioms like completeness and sensitivity
   - Slower to compute but generally more reliable

### Key Observations:

- All three methods typically identify the question word ("What") and the key entity ("Latvia") as important
- The specific tokens "Capital", "City", and "Latvia" are usually highly weighted
- Different methods may emphasize different aspects of the input
- Combining multiple attribution methods provides a more complete picture
