# LogitLens Tutorial

This tutorial demonstrates how to use the LogitLens tool to analyze intermediate representations in transformer models. LogitLens projects the hidden state from any layer through the output embedding, showing which token the model would predict if decoding stopped at that layer.
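
In symbols, one common way to write the projection is shown below. The notation is assumed here for illustration and is not taken from easyroutine: $h_\ell$ is the hidden state after layer $\ell$, $\mathrm{Norm}_f$ is the model's final normalization, and $W_U$ is the unembedding matrix.

$$
\mathrm{LogitLens}(h_\ell) = \mathrm{softmax}\big(W_U \, \mathrm{Norm}_f(h_\ell)\big)
$$

For the last layer this is the model's ordinary next-token distribution; for earlier layers it is a readout of what the residual stream already encodes.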

In [None]:
from easyroutine.utils import path_to_parents
path_to_parents(1)
%load_ext autoreload
%autoreload 2



In [None]:
import torch
from easyroutine.interpretability import HookedModel, ExtractionConfig
from easyroutine.interpretability.tools import LogitLens
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd



## 1. Setting Up

First, we load a model. This tutorial uses a tiny, randomly initialized test model, so its predictions are not meaningful; replace it with any model you want to study.

In [None]:
# For the tutorial we'll use a tiny test model
model = HookedModel.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# In practice, you can use any model you want, for example:
# model = HookedModel.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")



## 2. Initialize LogitLens

The LogitLens tool needs access to the model's unembedding matrix (the output embedding weights) and the final layer normalization.
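
To make these two ingredients concrete, here is a minimal NumPy sketch of the projection, independent of easyroutine. It assumes an RMSNorm-style final normalization (as in Llama-family models) and uses random weights; `rms_norm` and `logit_lens_probs` are hypothetical helper names, not library functions.

```python
import numpy as np

def rms_norm(hidden, weight, eps=1e-6):
    """RMSNorm: rescale by the root mean square, then apply a learned gain."""
    rms = np.sqrt(np.mean(hidden ** 2, axis=-1, keepdims=True) + eps)
    return hidden / rms * weight

def logit_lens_probs(hidden, norm_weight, unembed):
    """Project one hidden state to vocabulary probabilities.

    hidden:      (d_model,) residual-stream vector from some layer
    norm_weight: (d_model,) gain of the final normalization
    unembed:     (vocab_size, d_model) output embedding matrix
    """
    logits = unembed @ rms_norm(hidden, norm_weight)
    logits = logits - logits.max()  # subtract the max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()

# Toy dimensions and random weights (assumptions for illustration only)
rng = np.random.default_rng(0)
d_model, vocab_size = 16, 50
hidden = rng.normal(size=d_model)
norm_weight = np.ones(d_model)
unembed = rng.normal(size=(vocab_size, d_model))

probs = logit_lens_probs(hidden, norm_weight, unembed)
print(probs.shape, int(probs.argmax()))
```

The result is a probability vector over the vocabulary, which is what `apply_norm=True` and `apply_softmax=True` request later in this tutorial.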

In [None]:
# Create the LogitLens instance from our model
logit_lens = LogitLens.from_model(model)

# You can also create it directly from a model name
# logit_lens = LogitLens.from_model_name("mistralai/Mistral-7B-v0.1")

## 3. Preparing Input and Extracting Activations

Next, we prepare some input and extract activations from the model. The cells below build both a real prompt and a small dataset of random tokens; for a real-world analysis, use meaningful text rather than random tokens.

In [None]:
# Create tokenizer
tokenizer = model.get_tokenizer()

# For real analysis, use a meaningful prompt
prompt = "The capital of France is"
inputs = tokenizer(prompt, return_tensors="pt")

In [None]:
# For demo purposes, we also create a small dataset of random tokens.
# It is only used if you pass it to extract_cache in place of [inputs].
fake_dataset = [
    {
        "input_ids": torch.randint(0, 100, (1, 20)),
        "attention_mask": torch.ones(1, 20),
    },
    {
        "input_ids": torch.randint(0, 100, (1, 20)),
        "attention_mask": torch.ones(1, 20),
    }
]

# Extract activations for all layers - we need residual stream outputs
# and the final layer norm for best results
cache = model.extract_cache(
    [inputs],
    target_token_positions=["last"],
    extraction_config=ExtractionConfig(
        extract_resid_out=True,
        extract_last_layernorm=True
    )
)



Let's look at what's in our cache:

In [None]:
# Print the keys in the cache to understand what we have
print("Cache keys:")
[key for key in cache.keys() if not key.startswith("__")]

ActivationCache(`resid_out_0, resid_out_1, logits, mapping_index, example_dict`)

## 4. Basic LogitLens Analysis

Now we'll apply the LogitLens to see what the model 'predicts' at each layer.

In [None]:
# Number of layers in the model
num_layers = model.model_config.num_hidden_layers
print(f"The model has {num_layers} layers")

# Apply logit lens to all layers
logit_lens_results = {}
for layer in range(num_layers):
    out = logit_lens.compute(
        activations=cache,
        target_key=f"resid_out_{layer}",
        apply_norm=True,  # Apply layer normalization
        apply_softmax=True  # Convert to probabilities
    )
    logit_lens_results[layer] = out[f"logit_lens_resid_out_{layer}"]



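Let's examine the shape of the output for one layer. The last dimension is the vocabulary size, the middle dimension indexes the extracted token positions (only `"last"` here), and the first dimension indexes the input examples (the output below shows 2, as when a two-example dataset such as `fake_dataset` is passed):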

In [None]:
# Look at the shape of the LogitLens output
layer = 0
print(f"Shape of logit lens output for layer {layer}:")
logit_lens_results[layer].shape

torch.Size([2, 1, 32000])

## 5. Analyzing the Results

Let's extract the top-k predicted tokens at each layer and see how they evolve.

In [None]:
def get_top_k_tokens(layer_results, k=5):
    """Get top-k token predictions for the last token in the sequence."""
    # Get probabilities for the last token position
    token_probs = layer_results[0, -1]
    
    # Get top-k predictions
    top_k = torch.topk(token_probs, k)
    
    # Convert token ids to strings
    tokens = [tokenizer.decode(idx.item()) for idx in top_k.indices]
    probs = [prob.item() for prob in top_k.values]
    
    return tokens, probs

# Get top-5 predictions for each layer
all_layer_predictions = {}
for layer in range(num_layers):
    tokens, probs = get_top_k_tokens(logit_lens_results[layer], k=5)
    all_layer_predictions[layer] = list(zip(tokens, probs))

# Create a DataFrame to display the evolution of predictions through layers
predictions_df = pd.DataFrame(
    {f"Layer {layer}": [f"{t} ({p:.3f})" for t, p in preds] 
     for layer, preds in all_layer_predictions.items()}
)

predictions_df

## 6. Comparing Token Directions

We can also compute logit differences between specific tokens to see how the model's 'preference' changes across layers.
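
For a single hidden state, the logit difference between two tokens reduces to a dot product with the difference of their unembedding rows. The NumPy sketch below illustrates this with random weights; `logit_diff_direction` is a hypothetical helper, not part of easyroutine, and it assumes the hidden state has already been normalized.

```python
import numpy as np

def logit_diff_direction(normed_hidden, unembed, token_a, token_b):
    """Return logit(token_a) - logit(token_b) for one normalized hidden state."""
    # The difference of two unembedding rows defines a direction in the
    # residual stream; projecting onto it yields the logit difference.
    direction = unembed[token_a] - unembed[token_b]
    return float(direction @ normed_hidden)

# Toy setup with random weights (assumption for illustration only)
rng = np.random.default_rng(0)
unembed = rng.normal(size=(50, 16))   # (vocab_size, d_model)
normed_hidden = rng.normal(size=16)   # (d_model,)

full_logits = unembed @ normed_hidden
diff = logit_diff_direction(normed_hidden, unembed, token_a=3, token_b=7)
print(diff, float(full_logits[3] - full_logits[7]))
```

A positive value means the hidden state favors the first token over the second; the sign flips when the pair is swapped.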

In [None]:
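# Define a token pair to compare. For the prompt "The capital of France is",
# a natural choice is "Paris" vs "London".

# Get token IDs for comparison. add_special_tokens=False stops the tokenizer
# from prepending a BOS token, which would otherwise be element [0] for both strings.
# Note: [0] keeps only the first sub-token, so check that each string maps to one token.
target_tokens = [" Paris", " London"]
token_ids = [tokenizer.encode(t, add_special_tokens=False)[0] for t in target_tokens]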

# Compute logit differences across all layers
logit_diffs = []
for layer in range(num_layers):
    result = logit_lens.compute(
        activations=cache,
        target_key=f"resid_out_{layer}",
        token_directions=[(token_ids[0], token_ids[1])],
        metric="logit_diff",
        apply_norm=True
    )
    # Extract the logit difference
    logit_diff = result[f"logit_diff_resid_out_{layer}"].item()
    logit_diffs.append(logit_diff)

# Plot the evolution of logit differences
plt.figure(figsize=(10, 6))
plt.plot(range(num_layers), logit_diffs, marker='o')
plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
plt.title(f'Logit Difference: {target_tokens[0]} vs {target_tokens[1]} Across Layers')
plt.xlabel('Layer')
plt.ylabel('Logit Difference')
plt.grid(alpha=0.3)
plt.show()

## 7. Advanced Analysis: Visualizing Intermediate Representations

Let's create a heatmap to visualize how the probabilities for top tokens evolve across layers.

In [None]:
# Get the top-10 tokens from the final layer
final_layer = num_layers - 1
final_probs = logit_lens_results[final_layer][0, -1]
top_indices = torch.topk(final_probs, 10).indices
top_tokens = [tokenizer.decode(idx.item()) for idx in top_indices]

# Create a probability matrix for these tokens across all layers
probs_matrix = np.zeros((num_layers, len(top_tokens)))
for layer in range(num_layers):
    layer_probs = logit_lens_results[layer][0, -1]
    for i, token_idx in enumerate(top_indices):
        probs_matrix[layer, i] = layer_probs[token_idx].item()

# Create a heatmap
plt.figure(figsize=(12, 8))
plt.imshow(probs_matrix, aspect='auto', cmap='viridis')
plt.colorbar(label='Probability')
plt.xlabel('Top Tokens')
plt.ylabel('Layer')
plt.title('Token Probability Evolution Across Layers')
plt.xticks(range(len(top_tokens)), top_tokens, rotation=45, ha='right')
plt.yticks(range(num_layers))
plt.tight_layout()
plt.show()

## 8. Summary

In this tutorial, we've seen how to:

1. Create a LogitLens instance from a model
2. Extract activations from the model
3. Apply LogitLens to analyze intermediate representations
4. Compare specific token directions across layers
5. Visualize how predictions evolve through the network

LogitLens is a powerful tool for interpreting what happens inside transformer models as information flows through the layers. It can help identify where certain concepts emerge in the model and how representations develop throughout the network.
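
One simple way to quantify where a prediction emerges is to find the earliest layer whose top-1 token already matches the final layer's. The sketch below assumes the per-layer probabilities for one position have been stacked into a `(num_layers, vocab_size)` array; the helper name and the toy data are illustrative only.

```python
import numpy as np

def first_layer_matching_final(layer_probs):
    """Earliest layer whose top-1 token equals the final layer's top-1 token.

    layer_probs: (num_layers, vocab_size) array of logit-lens probabilities
                 for a single token position.
    """
    top1_per_layer = layer_probs.argmax(axis=-1)
    matches = np.flatnonzero(top1_per_layer == top1_per_layer[-1])
    return int(matches[0])  # never empty: the final layer always matches itself

# Toy example: 3 layers, vocabulary of 3 tokens (made-up numbers)
toy_probs = np.array([
    [0.7, 0.2, 0.1],   # layer 0 prefers token 0
    [0.4, 0.5, 0.1],   # layer 1 switches to token 1
    [0.1, 0.8, 0.1],   # final layer predicts token 1
])
emergence_layer = first_layer_matching_final(toy_probs)
print(emergence_layer)
```

Note that an early match does not guarantee the prediction stays stable in between; inspecting the full per-layer top-1 sequence is a useful complement.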