# Logit Lens Tutorial

This tutorial shows how to compute and visualize the **logit lens** for transformer language models.

## What is the Logit Lens?

The logit lens is an interpretability technique that decodes the hidden state at each layer into a probability distribution over the vocabulary. By applying the model's output head (`ln_final` → `lm_head`, followed by a softmax) to intermediate hidden states rather than only the final one, we can see how the model's predictions evolve through its computation.
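
The projection described above can be sketched without any model at all. The following is a minimal, pure-Python illustration, not the library's code: the 3-dimensional hidden states, the identity unembedding matrix, and the 3-token vocabulary are made-up toy values, and the layer norm omits the learned scale and bias that a real `ln_final` would carry.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logit_lens(hidden, unembed):
    """Decode one hidden state: norm -> unembedding -> softmax."""
    normed = layer_norm(hidden)
    logits = [sum(w * h for w, h in zip(row, normed)) for row in unembed]
    return softmax(logits)

# Toy values (hypothetical): one unembedding row per vocabulary token.
vocab = ["the", " Paris", " London"]
unembed = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
# Hidden state at the same position after each of three layers.
hidden_by_layer = [
    [2.0, 0.1, 0.0],  # early layer
    [0.5, 1.0, 0.2],  # middle layer
    [0.0, 3.0, 0.5],  # late layer
]

lens_probs = [logit_lens(h, unembed) for h in hidden_by_layer]
for layer, probs in enumerate(lens_probs):
    best = max(range(len(vocab)), key=lambda i: probs[i])
    print(f"layer {layer}: top token {vocab[best]!r} (p={probs[best]:.2f})")
```

Each row of `lens_probs` is what one cell of the visualization summarizes: the distribution the model would output if computation stopped at that layer.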

## Setup

First, install the required packages:

In [None]:
!pip install -q nnterp "logitlenskit @ git+https://github.com/davidbau/logitlenskit.git#subdirectory=python"

## Part 1: Using the Library

The simplest way to use logitlenskit is with just a few lines of code:

In [None]:
from nnterp import StandardizedTransformer
from logitlenskit import collect_logit_lens, show_logit_lens

# Load a model (GPT-2 is small enough to run locally)
model = StandardizedTransformer("openai-community/gpt2")

# Collect logit lens data
data = collect_logit_lens("The capital of France is", model, remote=False)

# Display interactive visualization
show_logit_lens(data, title="GPT-2: The capital of France is")

### Interacting with the Widget

- **Click cells** to see top-k predictions at that layer/position
- **Click tokens** in the popup to pin their probability trajectories
- **Click input tokens** (left column) to compare multiple positions
- **Drag edges** to resize columns and chart

## Part 2: Understanding the Data Format

Let's examine what `collect_logit_lens` returns:

In [None]:
print("Keys:", list(data.keys()))
print()
print("model:", data["model"])
print("input:", data["input"])
print("layers:", data["layers"][:5], "...", f"({len(data['layers'])} total)")
print()
print("topk shape:", data["topk"].shape, "- [n_layers, n_positions, k]")
print("tracked[0] shape:", data["tracked"][0].shape, "- unique tokens at position 0")
print("probs[0] shape:", data["probs"][0].shape, "- [n_layers, n_tracked] trajectories")
print()
print("vocab (sample):", dict(list(data["vocab"].items())[:5]))

### Data Structure

```python
{
    "model": str,                # Model name
    "input": List[str],          # Input tokens as strings
    "layers": List[int],         # Layer indices analyzed
    "topk": Tensor[int32],       # [n_layers, n_pos, k] - top-k token indices
    "tracked": List[Tensor],     # Per-position unique token indices
    "probs": List[Tensor],       # Per-position [n_layers, n_tracked] probabilities
    "vocab": Dict[int, str],     # Token index → string mapping
}
```
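
To make the relationship between `topk`, `tracked`, `probs`, and `vocab` concrete, here is a small sketch of how one grid cell could be read back out. The helper names `as_list` and `top_predictions` are hypothetical (they are not part of logitlenskit), and the `toy` dict uses nested lists in place of tensors so the example runs without a model.

```python
def as_list(x):
    """Return a plain Python list from a tensor-like object or a list."""
    return x.tolist() if hasattr(x, "tolist") else list(x)

def top_predictions(data, layer_idx, pos):
    """Return [(token_string, probability), ...] for one layer/position cell.

    layer_idx is the row in data["layers"], not the model's own layer number.
    """
    tracked = as_list(data["tracked"][pos])
    probs_row = as_list(data["probs"][pos][layer_idx])
    out = []
    for token_id in as_list(data["topk"][layer_idx][pos]):
        j = tracked.index(token_id)  # column of this token in probs[pos]
        out.append((data["vocab"][token_id], probs_row[j]))
    return out

# Hand-built stand-in for the dict described above: 2 layers, 1 position, k=2.
toy = {
    "model": "toy-model",
    "input": ["The"],
    "layers": [0, 1],
    "topk": [[[7, 3]], [[3, 9]]],                    # [n_layers][n_pos][k]
    "tracked": [[3, 7, 9]],                          # union of top-k ids at position 0
    "probs": [[[0.2, 0.5, 0.1], [0.6, 0.1, 0.3]]],   # [n_pos][n_layers][n_tracked]
    "vocab": {3: " Paris", 7: " the", 9: " London"},
}

print(top_predictions(toy, layer_idx=0, pos=0))
print(top_predictions(toy, layer_idx=1, pos=0))
```

Because every top-k token at a position is also in `tracked` for that position, the lookup `tracked.index(token_id)` always succeeds for data built this way.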

## Part 3: How It Works (Implementation)

Here's the complete implementation of `collect_logit_lens`. This is exactly what the library does internally:

In [None]:
import torch

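def collect_logit_lens(prompt, model, k=5, layers=None, remote=True):
    """
    Collect logit lens data: top-k predictions and probability trajectories.

    Args:
        prompt: Input text to analyze.
        model: nnterp StandardizedTransformer.
        k: Number of top predictions to keep per layer and position.
        layers: Layer indices to decode (default: all layers).
        remote: If True, run the trace on NDIF; if False, run locally.

    Returns:
        Dict with keys "model", "input", "layers", "topk", "tracked",
        "probs", and "vocab".
    """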
    # Tokenize once, client-side
    token_ids = model.tokenizer.encode(prompt)
    n_pos = len(token_ids)

    # Default: all layers
    if layers is None:
        layers = list(range(model.num_layers))
    n_layers = len(layers)

    # Run model, compute logit lens
    # When remote=True, computation happens on NDIF servers
    with model.trace(token_ids, remote=remote):
        all_probs = []
        all_topk = []

        for li in layers:
            # Project hidden state to vocabulary: hidden -> norm -> lm_head
            logits = model.lm_head(model.ln_final(model.layers_output[li]))
            probs = torch.softmax(logits[0], dim=-1)
            all_probs.append(probs)
            all_topk.append(probs.topk(k, dim=-1).indices)

        # Stack top-k indices: [n_layers, n_pos, k]
        topk = torch.stack(all_topk).to(torch.int32)

        # For each position: find unique tokens, extract trajectories
        tracked = []
        probs_out = []
        for pos in range(n_pos):
            # Union of all tokens appearing in top-k at any layer
            unique = torch.unique(topk[:, pos, :].flatten()).to(torch.int32)
            # Extract probability trajectory for each unique token
            traj = torch.stack([all_probs[li][pos, unique] for li in range(n_layers)])
            tracked.append(unique)
            probs_out.append(traj)

        # Save results to transmit from server
        result = {"topk": topk, "tracked": tracked, "probs": probs_out}.save()

    # Build vocabulary map (client-side)
    all_ids = set(result["topk"].flatten().tolist())
    for t in result["tracked"]:
        all_ids.update(t.tolist())
    vocab = {i: model.tokenizer.decode([i]) for i in all_ids}

    return {
        "model": model.config._name_or_path,
        "input": [model.tokenizer.decode([t]) for t in token_ids],
        "layers": layers,
        "topk": result["topk"],
        "tracked": result["tracked"],
        "probs": result["probs"],
        "vocab": vocab,
    }

### Key Insights

1. **Standardized access via nnterp**: `model.layers_output[i]`, `model.ln_final`, and `model.lm_head` work the same way across the transformer architectures that nnterp standardizes

2. **Server-side computation**: When `remote=True`, all the heavy work (forward pass, softmax, top-k) happens on NDIF servers. Only the small results are transmitted back.

3. **Trajectory tracking**: We find all tokens that appear in top-k at *any* layer, then extract their probabilities at *all* layers. This enables visualizing how predictions evolve.

4. **Bandwidth optimization**: Instead of sending full per-layer logits (on the order of 500 MB for a 70B model; the size scales with layers × positions × vocabulary), we send only top-k indices and tracked probabilities (~500 KB).
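
The trajectory-tracking step (insight 3) can be illustrated for a single position with plain lists. This is a sketch under toy assumptions: the probabilities are invented, and `topk_indices` and `track_trajectories` are hypothetical helper names that mirror the `torch.topk` / `torch.unique` logic above rather than reuse it.

```python
def topk_indices(probs, k):
    """Indices of the k largest probabilities, highest first."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

def track_trajectories(layer_probs, k):
    """Track every token that reaches the top-k at any layer.

    layer_probs: [n_layers][vocab] probabilities at ONE position.
    Returns (tracked_ids, trajectories), where trajectories is
    [n_layers][n_tracked], mirroring data["tracked"][pos] and data["probs"][pos].
    """
    tracked = sorted({i for probs in layer_probs for i in topk_indices(probs, k)})
    trajectories = [[probs[i] for i in tracked] for probs in layer_probs]
    return tracked, trajectories

# Invented distributions over a 4-token vocabulary at three layers.
layer_probs = [
    [0.70, 0.10, 0.15, 0.05],  # layer 0: token 0 is on top
    [0.40, 0.35, 0.20, 0.05],  # layer 1: token 0 still on top
    [0.05, 0.80, 0.05, 0.10],  # layer 2: token 1 takes over
]

tracked, trajectories = track_trajectories(layer_probs, k=1)
print("tracked ids:", tracked)
for layer, row in enumerate(trajectories):
    print(f"layer {layer}:", row)
```

Token 1 only enters the top-1 at the last layer, yet its probability is recorded at every layer. That is what lets the widget draw a full trajectory for any pinned token.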

## Part 4: Converting to JavaScript Format

The widget uses a JSON format optimized for the browser. Here's the conversion:

In [None]:
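def to_js_format(data):
    """
    Convert Python API format to JavaScript V2 format.

    Token indices become decoded strings and probabilities are rounded to
    5 decimal places. Note: if two token ids decode to the same string,
    their entries in a position's `tracked` dict share a key and the
    later one overwrites the earlier.
    """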
    vocab = data["vocab"]
    n_layers = len(data["layers"])
    n_pos = len(data["input"])

    # topk: [n_layers, n_pos, k] indices -> [n_layers][n_pos] string lists
    topk_js = [
        [[vocab[idx.item()] for idx in data["topk"][li, pos]]
         for pos in range(n_pos)]
        for li in range(n_layers)
    ]

    # tracked/probs: parallel arrays -> {token: trajectory} dicts per position
    tracked_js = [
        {
            vocab[idx.item()]: [round(p, 5) for p in data["probs"][pos][:, i].tolist()]
            for i, idx in enumerate(data["tracked"][pos])
        }
        for pos in range(n_pos)
    ]

    return {
        "meta": {"version": 2, "model": data["model"]},
        "input": data["input"],
        "layers": data["layers"],
        "topk": topk_js,
        "tracked": tracked_js,
    }

In [None]:
# Convert and examine
js_data = to_js_format(data)

print("JavaScript format keys:", list(js_data.keys()))
print()
print("meta:", js_data["meta"])
print("input:", js_data["input"])
print()
print("topk[0][0] (layer 0, position 0):", js_data["topk"][0][0])
print()
print("tracked[0] (position 0, first 3 tokens):")
for tok, traj in list(js_data["tracked"][0].items())[:3]:
    print(f"  {tok!r}: {traj[:5]}... ({len(traj)} values)")

### JavaScript Format Structure

```javascript
{
    "meta": {"version": 2, "model": "..."},
    "input": ["The", " capital", ...],
    "layers": [0, 1, 2, ...],
    "topk": [                           // [n_layers][n_positions] → token strings
        [["the", "a", ...], ...],       // layer 0
        ...
    ],
    "tracked": [                        // [n_positions] → {token: trajectory}
        {" Paris": [0.01, 0.02, ...], ...},  // position 0
        ...
    ]
}
```

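Key differences from the Python format:
- Token indices are replaced with strings, so no `vocab` dict is needed
- The parallel `tracked` and `probs` arrays are merged into one dict per position: `{token: [prob_at_layer_0, prob_at_layer_1, ...]}`
- Probabilities are rounded to 5 decimal places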
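
Since the converted structure contains only dicts, lists, strings, and numbers, it should serialize directly with the standard `json` module. The sketch below checks the shape invariants implied by the structure above and round-trips a hand-written example. `check_js_format` is a hypothetical helper, not part of logitlenskit, and the `example` payload is invented.

```python
import json

def check_js_format(js_data):
    """Raise ValueError if the nested lists disagree on layer/position counts."""
    n_layers = len(js_data["layers"])
    n_pos = len(js_data["input"])
    if len(js_data["topk"]) != n_layers:
        raise ValueError("topk must have one entry per layer")
    if any(len(row) != n_pos for row in js_data["topk"]):
        raise ValueError("each topk layer must have one entry per position")
    if len(js_data["tracked"]) != n_pos:
        raise ValueError("tracked must have one entry per position")
    for per_pos in js_data["tracked"]:
        if any(len(traj) != n_layers for traj in per_pos.values()):
            raise ValueError("each trajectory must have one value per layer")

# Invented payload: 2 layers, 2 positions, k=2.
example = {
    "meta": {"version": 2, "model": "toy-model"},
    "input": ["The", " capital"],
    "layers": [0, 1],
    "topk": [
        [["the", "a"], [" of", " city"]],      # layer 0
        [[" first", "the"], [" of", " is"]],   # layer 1
    ],
    "tracked": [
        {"the": [0.4, 0.2], "a": [0.3, 0.1], " first": [0.05, 0.3]},
        {" of": [0.5, 0.7], " city": [0.2, 0.05], " is": [0.01, 0.1]},
    ],
}

check_js_format(example)
payload = json.dumps(example)   # string that could be saved or embedded in a page
restored = json.loads(payload)
print(len(payload), "characters;", "round-trip ok:", restored == example)
```

Passing `ensure_ascii=False` to `json.dumps` keeps non-ASCII token strings readable in the output instead of escaping them.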

## Part 5: Using NDIF for Large Models

For models too large to run locally, such as Llama-3.1-70B, use NDIF remote execution:

```python
import os
from nnsight import CONFIG

# Set up NDIF API key
CONFIG.set_default_api_key(os.environ['NDIF_API_KEY'])

# Load large model
model = StandardizedTransformer("meta-llama/Llama-3.1-70B")

# Collect with remote=True (default)
data = collect_logit_lens("The capital of France is", model, remote=True)
show_logit_lens(data)
```

### Bandwidth Savings

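For Llama-3.1-70B (80 layers, 128k vocab), approximate sizes (full logits scale with layers × positions × vocabulary, so the exact figures depend on prompt length):
- Full logits: ~547 MB
- Our format: ~500 KB
- **~1000× reduction**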
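
A back-of-the-envelope version of this comparison can be written as two small functions. This is a rough sketch, not a measurement: it assumes 4-byte values, ignores serialization overhead and the `vocab` strings, and the prompt length and number of tracked tokens per position are illustrative guesses. `full_logits_bytes` and `compact_bytes` are hypothetical names.

```python
def full_logits_bytes(n_layers, n_pos, vocab_size, bytes_per_value=4):
    """Bytes needed to ship every layer's full logits (float32 by default)."""
    return n_layers * n_pos * vocab_size * bytes_per_value

def compact_bytes(n_layers, n_pos, k, n_tracked, bytes_per_value=4):
    """Bytes for topk + tracked + probs, assuming n_tracked tokens per position."""
    topk = n_layers * n_pos * k * 4                         # int32 indices
    tracked = n_pos * n_tracked * 4                         # int32 indices
    probs = n_pos * n_layers * n_tracked * bytes_per_value  # trajectories
    return topk + tracked + probs

# Llama-3.1-70B has 80 layers and a 128,256-token vocabulary.
# Assumed for illustration: 6 prompt tokens, k=5, 50 tracked tokens per position.
full = full_logits_bytes(n_layers=80, n_pos=6, vocab_size=128_256)
compact = compact_bytes(n_layers=80, n_pos=6, k=5, n_tracked=50)

print(f"full logits: {full / 1e6:.0f} MB")
print(f"compact:     {compact / 1e3:.0f} KB")
print(f"ratio:       {full / compact:.0f}x")
```

The ratio is most sensitive to `n_tracked`: the more distinct tokens enter the top-k across layers, the larger the `probs` arrays become.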

## Part 6: Try Different Prompts

In [None]:
prompts = [
    "The quick brown fox jumps over the",
    "To be or not to be, that is the",
    "1 + 1 =",
]

for prompt in prompts:
    data = collect_logit_lens(prompt, model, remote=False)
    display(show_logit_lens(data, title=prompt))

## Summary

You've learned:

1. **What the logit lens is**: Projecting intermediate hidden states to vocabulary probabilities

2. **How to use the library**: `collect_logit_lens()` + `show_logit_lens()`

3. **How it works internally**: The ~50 lines of code that do the actual computation

4. **The data formats**: Python format (tensors + vocab) and JavaScript format (strings)

5. **Why it's efficient**: Server-side top-k extraction reduces bandwidth by roughly 1000×