# **Testing Aggregation Methods**

We already know how to compute `1-to-1` attributions (from one input token to one output token). We'd like to extend this to:
- `many-to-1` (from an entire specific context document to 1 important token), or
- `many-to-many` (from an entire specific context document to multiple tokens or the entire answer).

### **Imports and Setup**

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Choose the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
cuda_ready = torch.cuda.is_available()
device = torch.device("cuda") if cuda_ready else torch.device("cpu")
print(f"Using device: {device}")
if cuda_ready:
    # Report the name of GPU 0 and its `total_memory` property, scaled to GB (1e9 bytes).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

### **Load Model and Tokenizer**

In [None]:
# Checkpoint used throughout the notebook: Llama-3.2-1B and its tokenizer
model_name = "meta-llama/Llama-3.2-1B"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token
tokenizer.pad_token = tokenizer.eos_token

print("Loading model...")
# Half precision when CUDA is available, full precision otherwise
if torch.cuda.is_available():
    model_dtype = torch.float16
else:
    model_dtype = torch.float32

model_load_options = {
    "torch_dtype": model_dtype,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
    # Eager attention is requested so that output_attentions can be enabled
    "attn_implementation": "eager",
}
model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_options)

# Switch to evaluation mode before running any forward passes
model.eval()
print("Model loaded successfully!")
print(f"Model parameters: {model.num_parameters() / 1e6:.2f}M")

### **Run Generation Task**

- We'll ask the question: `What is the Capital City of Latvia?` (9 input tokens)
- And will receive the answer: `The capital city of Latvia is Riga.` (9 output tokens)

In [None]:
# The question posed to the model
prompt = "What is the Capital City of Latvia?"
# Encode the prompt as PyTorch tensors, move the encoding to the selected
# device, and keep a direct handle on the token IDs
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(device)
input_ids = inputs["input_ids"]

In [None]:
# Greedy, token-by-token decoding of 9 tokens. The model is driven through
# `inputs_embeds` instead of token IDs so that the same pipeline can be reused
# by the gradient-based explainability code later on.
num_tokens_to_generate = 9

# Embedding layer of the model and its lookup table
embeddings_layer = model.get_input_embeddings()
embedding_matrix = embeddings_layer.weight  # (vocab_size, embedding_dim)

# Prompt embeddings, obtained by indexing the table directly with the token IDs
current_token_embeddings = embedding_matrix[input_ids]  # (batch, seq_len, embedding_dim)

# IDs of the tokens produced so far, in generation order
generated_tokens = []

print("Starting manual autoregressive generation using embeddings...")
print(f"Initial prompt: {prompt}")
print(f"Initial tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
print(f"Embedding shape: {current_token_embeddings.shape}\n")

# No gradients are needed for plain generation
with torch.no_grad():
    for step in range(num_tokens_to_generate):
        # Run the model on every embedding accumulated so far
        outputs = model(inputs_embeds=current_token_embeddings)
        logits = outputs.logits

        # Vocabulary scores at the final sequence position of the first batch item
        target_logits = logits[0, -1]

        # Greedy decoding: take the highest-scoring token
        predicted_token_id = torch.argmax(target_logits, dim=-1)

        # Record the choice
        generated_tokens.append(predicted_token_id.item())

        # Show the decoded token for this step
        predicted_token = tokenizer.decode([predicted_token_id])
        print(f"Step {step + 1}: Generated token '{predicted_token}' (ID: {predicted_token_id.item()})")

        # Look up the new token's embedding, add batch and sequence axes,
        # and extend the running sequence along the sequence dimension
        new_token_embedding = embedding_matrix[predicted_token_id][None, None, :]
        current_token_embeddings = torch.cat(
            (current_token_embeddings, new_token_embedding), dim=1
        )

# Summary of the run
print(f"\n{'='*60}")
print("Generation complete!")
print(f"{'='*60}")
print(f"\nGenerated token IDs: {generated_tokens}")
print(f"Generated tokens: {tokenizer.convert_ids_to_tokens(generated_tokens)}")
print(f"\nGenerated text: {tokenizer.decode(generated_tokens)}")
print(f"\nFinal embedding shape: {current_token_embeddings.shape}")
print(f"\nFull text (prompt + generation):")
# Prompt IDs followed by the generated IDs, so the whole exchange can be decoded at once
full_token_ids = torch.cat(
    [input_ids[0], torch.tensor(generated_tokens, device=device)]
)
print(f"{tokenizer.decode(full_token_ids)}")

In [None]:

# Section banner announcing the gradient-based part of the notebook
for banner_line in (
    f"\n\n{'#'*80}",
    "### GRADIENT-BASED GENERATION WITH ATTRIBUTION TRACKING ###",
    f"{'#'*80}\n",
):
    print(banner_line)

def generate_one_token_with_attribution(model, original_input_ids, generated_token_ids):
    """
    Greedily predict the next token and attribute it to the original input tokens.

    The prompt and the tokens generated so far are embedded and passed to
    ``model`` through ``inputs_embeds``. The gradient of the winning logit at
    the last position is then taken with respect to those embeddings, and the
    attribution of each original input token is the L2 norm of its embedding
    gradient.

    Args:
        model: Language model exposing ``get_input_embeddings()`` and callable
            with ``inputs_embeds=...``; the returned object must carry a
            ``logits`` attribute indexable as ``[0, -1]``.
        original_input_ids: Original input token IDs, a tensor indexed here as
            (batch, sequence); only batch item 0 is attributed.
        generated_token_ids: List of already generated token IDs (may be empty).

    Returns:
        attribution_scores: float32 NumPy array with one score per original
            input token (generated tokens are excluded).
        predicted_token_id: The next predicted token ID as a Python int.

    Raises:
        RuntimeError: If no gradient reaches the input embeddings.
    """
    # Build the new tensor on the same device as the prompt IDs instead of
    # relying on a module-level `device` variable.
    ids_device = original_input_ids.device

    # Combine original input and generated tokens
    if len(generated_token_ids) > 0:
        generated_tensor = torch.tensor(
            generated_token_ids,
            dtype=original_input_ids.dtype,
            device=ids_device,
        ).unsqueeze(0)
        combined_input_ids = torch.cat([original_input_ids, generated_tensor], dim=1)
    else:
        combined_input_ids = original_input_ids

    embedding_matrix = model.get_input_embeddings().weight

    # Force gradient tracking even if the caller is inside a no_grad context
    with torch.enable_grad():
        # Fresh leaf tensor: detached from the embedding table so gradients
        # are taken w.r.t. these embedding values only.
        token_embeddings = embedding_matrix[combined_input_ids].detach().clone()
        token_embeddings.requires_grad_(True)

        outputs = model(inputs_embeds=token_embeddings)

        # Logits at the last position and the greedy prediction
        target_logits = outputs.logits[0, -1]
        predicted_token_id = target_logits.argmax(dim=-1)
        target_score = target_logits[predicted_token_id]

        # autograd.grad returns the gradient for the embeddings only; unlike
        # .backward() it does not accumulate into the .grad of every model
        # parameter on each call.
        (gradients,) = torch.autograd.grad(
            target_score, token_embeddings, allow_unused=True
        )

    if gradients is None:
        raise RuntimeError("Gradients were not computed.")

    # Attribution scores (L2 norm of gradients) for ORIGINAL input tokens only.
    # The norm is taken in float32 so half-precision gradients are not squared
    # in fp16.
    original_input_length = original_input_ids.shape[1]
    attribution_scores = (
        gradients[0, :original_input_length].float().norm(dim=-1).cpu().numpy()
    )

    return attribution_scores, predicted_token_id.item()

# Run iterative generation with attribution tracking
print("Starting gradient-based autoregressive generation with attribution tracking...")
print(f"Generating {num_tokens_to_generate} tokens...\n")

# Prompt token IDs and their string forms do not change between steps,
# so they are computed once here instead of on every loop iteration.
input_token_ids = input_ids[0].tolist()
input_tokens = tokenizer.convert_ids_to_tokens(input_token_ids)

# Store results as list of tuples: (generated_token_id, attribution_scores)
generation_results = []
current_generated_tokens = []

for step in range(num_tokens_to_generate):
    # Generate one token, conditioning on the prompt plus everything generated so far
    attr_scores, next_token_id = generate_one_token_with_attribution(
        model,
        input_ids,
        current_generated_tokens,
    )

    # Store the result and extend the context for the next step
    generation_results.append((next_token_id, attr_scores))
    current_generated_tokens.append(next_token_id)

    # Decode and print
    predicted_token = tokenizer.decode([next_token_id])
    print(f"Step {step + 1}: Generated '{predicted_token}' (ID: {next_token_id})")
    print(f"  Attribution scores shape: {attr_scores.shape}")
    print(f"  Top 3 attributed input tokens:")

    # argsort is ascending: take the last three and reverse for highest-first
    top_3_indices = np.argsort(attr_scores)[-3:][::-1]
    for idx in top_3_indices:
        print(f"    - '{input_tokens[idx]}' (pos {idx}): {attr_scores[idx]:.4f}")
    print()

# Final summary
generated_ids_summary = [token_id for token_id, _ in generation_results]
print(f"{'='*80}")
print(f"GRADIENT-BASED GENERATION COMPLETE")
print(f"{'='*80}")
print(f"\nGenerated {len(generation_results)} tokens with attribution scores")
print(f"\nGenerated token IDs: {generated_ids_summary}")
print(f"Generated tokens: {tokenizer.convert_ids_to_tokens(generated_ids_summary)}")
print(f"\nGenerated text: {tokenizer.decode(generated_ids_summary)}")
print(f"\nAttribution scores stored for each generation step.")
print(f"Access as: generation_results[step] = (token_id, attribution_scores)")

# ====== Process attribution scores per input token across all generated output tokens ======
# "processed_results" will be a list of tuples: (input_token, input_token_id, [attr_score_for_each_output_token])

# Number of output tokens generated
num_output_tokens = len(generation_results)

# Build a matrix attribution_scores_matrix[output_j][input_t]
# shape: (num_output_tokens, num_input_tokens)
attribution_scores_matrix = np.stack(
    [attr_scores for (_, attr_scores) in generation_results], axis=0
)

# For each input token, collect its column of the matrix: one score per output token
processed_results = [
    (tok, tok_id, attribution_scores_matrix[:, t].tolist())
    for t, (tok, tok_id) in enumerate(zip(input_tokens, input_token_ids))
]

# processed_results[n] = (input_token, input_token_id, [attribution_to_output_j for j in output_tokens])


### **Do Aggregation**
Look at 3 slices of the input:
- Tokens 1-3 (What is the)
- Tokens 4-5 (Capital City)
- Tokens 6-7 (of Latvia)

and compare their aggregated attribution scores for output tokens 6-7 (R iga).

In [None]:
# Each slice maps a label to the 0-based positions of its tokens in
# `processed_results` (i.e. positions in the tokenized prompt).
# A bare `[1:4]` is not valid Python, so the positions are spelled out with range().
# NOTE(review): position 0 is skipped on purpose — presumably it is the
# tokenizer's BOS token; confirm against the "Initial tokens" printout.
input_slices = {
    "What is the (tokens 1-3)": list(range(1, 4)),
    "Capital City (tokens 4-5)": list(range(4, 6)),
    "of Latvia (tokens 6-7)": list(range(6, 8)),  # both tokens: positions 6 and 7
}

# 0-based positions in `generation_results` of the output tokens to analyze.
# Every label below shows them as idx+1.
# NOTE(review): these are expected to be the "R" / "iga" tokens — verify
# against the generation printout and adjust the indices if they differ.
output_token_indices = [6, 7]

# Get the actual output tokens for labeling
output_tokens_for_plot = [tokenizer.decode([generation_results[i][0]]) for i in output_token_indices]

# Human-readable (idx+1) forms of the analyzed output positions, shared by the
# printouts and plot titles so they cannot drift apart.
output_positions_text = " and ".join(str(idx + 1) for idx in output_token_indices)
output_span_label = "-".join(str(idx + 1) for idx in output_token_indices)

print("="*80)
print("AGGREGATION ANALYSIS")
print("="*80)
print(f"\nInput slices defined:")
for name, indices in input_slices.items():
    tokens = [processed_results[i][0] for i in indices]
    print(f"  {name}: {tokens}")

print(f"\nAnalyzing attribution to output tokens {output_positions_text}:")
for i, idx in enumerate(output_token_indices):
    print(f"  Output token {idx+1}: '{output_tokens_for_plot[i]}' (ID: {generation_results[idx][0]})")

# Compute aggregated (averaged) attribution scores for each input slice and output token.
# processed_results[input_idx] = (token, token_id, [scores_for_each_output]), so
# for every output token we average its score over the input tokens of the slice.
aggregated_scores = {
    slice_name: [
        np.mean([processed_results[input_idx][2][output_idx] for input_idx in input_indices])
        for output_idx in output_token_indices
    ]
    for slice_name, input_indices in input_slices.items()
}

# Print results
print(f"\n{'='*80}")
print("AGGREGATED ATTRIBUTION SCORES")
print(f"{'='*80}\n")

for slice_name, scores in aggregated_scores.items():
    print(f"{slice_name}:")
    for i, (output_idx, score) in enumerate(zip(output_token_indices, scores)):
        print(f"  → Output token {output_idx+1} ('{output_tokens_for_plot[i]}'): {score:.6f}")
    print()

# Plot the results
fig, ax = plt.subplots(figsize=(12, 6))

slice_names = list(aggregated_scores.keys())
num_slices = len(slice_names)
num_outputs = len(output_token_indices)

# Set up bar positions: one group per output token, one bar per slice
x = np.arange(num_outputs)
width = 0.25  # Width of each bar

# Plot bars for each slice, centering each group of bars on its tick
colors = ['#4472C4', '#ED7D31', '#A5A5A5']
for i, slice_name in enumerate(slice_names):
    scores = aggregated_scores[slice_name]
    offset = (i - num_slices/2 + 0.5) * width
    # Cycle through the palette so adding more slices cannot raise IndexError
    ax.bar(x + offset, scores, width, label=slice_name, color=colors[i % len(colors)], alpha=0.8)

# Customize plot
ax.set_xlabel('Output Tokens', fontsize=13, fontweight='bold')
ax.set_ylabel('Average Attribution Score', fontsize=13, fontweight='bold')
ax.set_title(f'Aggregated Attribution Scores: Input Slices → Output Tokens {output_span_label}',
             fontsize=14, fontweight='bold', pad=20)
ax.set_xticks(x)
ax.set_xticklabels([f"Token {idx+1}\n'{tok}'" for idx, tok in zip(output_token_indices, output_tokens_for_plot)])
ax.legend(loc='upper right', fontsize=10)
ax.grid(axis='y', alpha=0.3, linestyle='--')

plt.tight_layout()
plt.show()

# Create a heatmap for better visualization
print("\n" + "="*80)
print("HEATMAP VISUALIZATION")
print("="*80 + "\n")

# Prepare data for heatmap: rows are input slices, columns are output tokens
heatmap_data = np.array([aggregated_scores[name] for name in slice_names])

fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(heatmap_data, cmap='YlOrRd', aspect='auto')

# Set ticks and labels
ax.set_xticks(np.arange(num_outputs))
ax.set_yticks(np.arange(num_slices))
ax.set_xticklabels([f"Token {idx+1}: '{tok}'" for idx, tok in zip(output_token_indices, output_tokens_for_plot)])
ax.set_yticklabels(slice_names)

# Keep the x-axis labels horizontal and centered
plt.setp(ax.get_xticklabels(), rotation=0, ha="center")

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Average Attribution Score', rotation=270, labelpad=20, fontsize=11)

# Add text annotations with the numeric value of every cell
for i in range(num_slices):
    for j in range(num_outputs):
        ax.text(j, i, f'{heatmap_data[i, j]:.4f}',
                ha="center", va="center", color="black", fontsize=11, fontweight='bold')

ax.set_title(f'Attribution Heatmap: Input Slices → Output Tokens {output_span_label}',
             fontsize=14, fontweight='bold', pad=15)
ax.set_xlabel('Output Tokens', fontsize=12, fontweight='bold')
ax.set_ylabel('Input Token Slices', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

print("Aggregation analysis complete!")
