# **Testing Aggregation Methods**

We already know how to compute `1-to-1` attributions. We'd now like to aggregate them into:
- `many-to-1` attributions (from an entire specific context document to 1 important token), or
- `many-to-many` attributions (from an entire specific context document to multiple tokens or the entire answer).

### **Imports and Setup**

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Select the compute device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Report on GPU index 0.
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # `total_memory` is the device's full capacity in bytes, not the amount
    # currently free, so the line is labelled as total memory.
    print(f"Total memory: {gpu_properties.total_memory / 1e9:.2f} GB")

### **Load Model and Tokenizer**

In [None]:
# Checkpoint to load: Llama-3.2-1B, for both the tokenizer and the model.
model_name = "meta-llama/Llama-3.2-1B"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

print("Loading model...")
# Half precision when CUDA is available, full precision otherwise.
if torch.cuda.is_available():
    model_dtype = torch.float16
else:
    model_dtype = torch.float32

load_options = {
    "torch_dtype": model_dtype,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
    # Eager attention is selected so that output_attentions can be used.
    "attn_implementation": "eager",
}
model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

# Switch the model to evaluation mode before running it.
model.eval()
print("Model loaded successfully!")
parameter_count_millions = model.num_parameters() / 1e6
print(f"Model parameters: {parameter_count_millions:.2f}M")

### **Run Generation Task**

- We'll ask the question: `What is the Capital City of Latvia?` (9 input tokens)
- And will receive the answer: `The capital city of Latvia is Riga.` (9 output tokens)

In [None]:
# Question posed to the model.
prompt = "What is the Capital City of Latvia?"
# Encode the prompt as PyTorch tensors, then move them to the selected device.
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(device)
# Token IDs of the encoded prompt.
input_ids = inputs["input_ids"]

In [None]:
# Greedy, token-by-token generation of 9 tokens, driven by input embeddings
# rather than token IDs. The model is called through `inputs_embeds`, the
# entry point intended for gradient-based explainability methods later on.
num_tokens_to_generate = 9

# Input embedding layer of the model and its weight matrix.
embeddings_layer = model.get_input_embeddings()
embedding_matrix = embeddings_layer.weight  # Shape: (vocab_size, embedding_dim)

# Look the prompt embeddings up by indexing the weight matrix directly
# instead of calling the embedding layer.
current_token_embeddings = embedding_matrix[input_ids]  # Shape: (batch, seq_len, embedding_dim)

generated_tokens = []

print("Starting manual autoregressive generation using embeddings...")
print(f"Initial prompt: {prompt}")
print(f"Initial tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
print(f"Embedding shape: {current_token_embeddings.shape}\n")

# The generation loop runs under no_grad, so no autograd graph is recorded here.
with torch.no_grad():
    for step in range(num_tokens_to_generate):
        # Forward pass over the full sequence of embeddings built so far.
        outputs = model(inputs_embeds=current_token_embeddings)
        logits = outputs.logits

        # Next-token scores: first batch element, last sequence position.
        target_logits = logits[0, -1]

        # Greedy decoding: pick the highest-scoring token.
        predicted_token_id = torch.argmax(target_logits, dim=-1)
        token_id = predicted_token_id.item()
        generated_tokens.append(token_id)

        # Report the token that was just produced.
        predicted_token = tokenizer.decode([predicted_token_id])
        print(f"Step {step + 1}: Generated token '{predicted_token}' (ID: {token_id})")

        # Embed the new token as a (1, 1, embedding_dim) tensor and extend the
        # sequence along the position axis.
        new_token_embedding = embedding_matrix[predicted_token_id].reshape(1, 1, -1)
        current_token_embeddings = torch.cat(
            (current_token_embeddings, new_token_embedding), dim=1
        )

# Summary of the run.
separator = "=" * 60
print(f"\n{separator}")
print("Generation complete!")
print(separator)
print(f"\nGenerated token IDs: {generated_tokens}")
print(f"Generated tokens: {tokenizer.convert_ids_to_tokens(generated_tokens)}")
print(f"\nGenerated text: {tokenizer.decode(generated_tokens)}")
print(f"\nFinal embedding shape: {current_token_embeddings.shape}")
print("\nFull text (prompt + generation):")
# Prompt IDs followed by the generated IDs, decoded as one sequence.
full_token_ids = torch.cat([input_ids[0], torch.tensor(generated_tokens, device=device)])
print(tokenizer.decode(full_token_ids))
