# LogitLensKit Demo

This notebook demonstrates how to use LogitLensKit to collect and visualize logit lens data for transformer language models.

## What is the Logit Lens?

The **Logit Lens** is an interpretability technique that decodes the hidden state at each layer into a probability distribution over the vocabulary. Applying the model's output projection (the unembedding matrix) to intermediate hidden states shows how the model's prediction evolves from layer to layer.
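
To make the idea concrete, here is a minimal, dependency-free sketch of the decoding step on toy data. It is not LogitLensKit's implementation: the function names, the random weights, and the choice to normalize each hidden state before projecting (as many logit lens implementations do with the model's final layer norm) are all assumptions for illustration.

```python
import math
import random

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logit_lens(hidden_states, unembed):
    """Decode each layer's hidden state into vocabulary probabilities.

    hidden_states: list of per-layer vectors, each of length d_model
    unembed: matrix of shape [vocab_size][d_model] (the output projection)
    """
    per_layer_probs = []
    for h in hidden_states:
        normed = layer_norm(h)
        logits = [sum(w * v for w, v in zip(row, normed)) for row in unembed]
        per_layer_probs.append(softmax(logits))
    return per_layer_probs

# Toy stand-ins for a real model: 3 layers, d_model=4, vocab_size=5.
rng = random.Random(0)
hidden_states = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(3)]
unembed = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(5)]

probs = logit_lens(hidden_states, unembed)
for layer, p in enumerate(probs):
    best = max(range(len(p)), key=p.__getitem__)
    print(f"layer {layer}: top token id {best}, p={p[best]:.3f}")
```

In a real model the hidden states come from the residual stream at one token position and the unembedding matrix is the language-model head, but the per-layer loop has the same shape.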

## Setup

First, install LogitLensKit in editable mode along with its development dependencies:

```bash
cd python
pip install -e ".[dev]"
```

In [None]:
# Add logitlenskit to path (for local development)
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent / 'python' / 'src'))

In [None]:
from nnsight import LanguageModel
from logitlenskit import collect_logit_lens_topk_efficient, show_logit_lens

## Load a Model

We'll start with GPT-2, which is small enough to run locally.

In [None]:
model = LanguageModel('openai-community/gpt2', device_map='auto')
print(f'Loaded {model.config.model_type} with {model.config.n_layer} layers')

## Collect Logit Lens Data

The `collect_logit_lens_topk_efficient` function extracts:
- Top-k predictions at each layer and token position
- Probability trajectories showing how predictions evolve
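
The two outputs listed above can be illustrated with a small sketch on made-up probabilities for a single token position. The helper names and the returned structures are hypothetical and are not claimed to match the format that `collect_logit_lens_topk_efficient` returns.

```python
def top_k_per_layer(probs_by_layer, k):
    """Return the k most probable (token_id, prob) pairs at each layer."""
    result = []
    for probs in probs_by_layer:
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
        result.append([(i, probs[i]) for i in ranked])
    return result

def trajectories(probs_by_layer, topk):
    """Follow every token that is top-k at any layer across all layers."""
    tracked = sorted({tok for layer in topk for tok, _ in layer})
    return {tok: [probs[tok] for probs in probs_by_layer] for tok in tracked}

# Hypothetical probabilities for one position: 3 layers, 4-token vocabulary.
probs_by_layer = [
    [0.40, 0.30, 0.20, 0.10],
    [0.25, 0.45, 0.20, 0.10],
    [0.05, 0.80, 0.05, 0.10],
]

topk = top_k_per_layer(probs_by_layer, k=2)
traj = trajectories(probs_by_layer, topk)

print(topk[-1])  # → [(1, 0.8), (3, 0.1)]
print(traj[1])   # → [0.3, 0.45, 0.8]
```

Tracking the union of top-k tokens across layers is what lets a token's probability be plotted even at layers where it has not yet entered the top k.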

In [None]:
prompt = 'The capital of France is'

data = collect_logit_lens_topk_efficient(
    prompt,
    model,
    top_k=5,
    track_across_layers=True,
    remote=False,  # Set to True for NDIF remote execution
)

print(f'Collected data for {len(data["tokens"])} tokens across {len(data["layers"])} layers')

## Visualize with Interactive Widget

The `show_logit_lens` function creates an interactive visualization:

- **Click cells** to see top-k predictions
- **Click tokens** in the popup to pin their trajectories
- **Click input tokens** (left column) to compare multiple positions
- **Drag edges** to resize columns and chart

In [None]:
show_logit_lens(data, model.tokenizer, title=f'GPT-2: {prompt}')

## Try Different Prompts

Experiment with different prompts to see how the model's predictions evolve:

In [None]:
prompts = [
    'The quick brown fox jumps over the',
    'To be or not to be, that is the',
    '1 + 1 =',
]

for prompt in prompts:
    data = collect_logit_lens_topk_efficient(
        prompt, model, top_k=5, track_across_layers=True, remote=False
    )
    display(show_logit_lens(data, model.tokenizer, title=prompt))

## Using NDIF for Large Models

For models too large to run locally, such as Llama 3.1 70B, use NDIF remote execution. All computation runs on the server, and only the top-k results are sent back to the client.

```python
# Set up NDIF (expects NDIF_API and HF_TOKEN in the environment, e.g. loaded from .env.local)
from nnsight import CONFIG
import os
CONFIG.set_default_api_key(os.environ['NDIF_API'])

# Load a large model
model = LanguageModel('meta-llama/Llama-3.1-70B', device_map='auto', token=os.environ['HF_TOKEN'])

# Collect with remote=True
data = collect_logit_lens_topk_efficient(
    prompt, model, top_k=5, track_across_layers=True,
    remote=True  # Server-side computation
)
```

Bandwidth comparison:
- Naive (full logits): ~819 MB
- With server-side reduction: ~320 KB
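
A back-of-the-envelope estimate shows where savings of this magnitude come from. The shapes below (80 layers, a vocabulary of roughly 128k tokens, 20 token positions, 4-byte floats) are assumptions chosen for illustration and are not taken from LogitLensKit. The top-k estimate also leaves out tracked trajectories, so it comes out smaller than the figure quoted above.

```python
def full_logits_bytes(n_layers, n_positions, vocab_size, bytes_per_value=4):
    """Bytes needed to return every logit for every layer and position."""
    return n_layers * n_positions * vocab_size * bytes_per_value

def top_k_bytes(n_layers, n_positions, k, bytes_per_entry=8):
    """Bytes for k (token id, probability) pairs per layer and position,
    assuming a 4-byte id plus a 4-byte float per entry."""
    return n_layers * n_positions * k * bytes_per_entry

# Assumed shapes (hypothetical): 80 layers, ~128k vocabulary, 20 positions.
naive = full_logits_bytes(n_layers=80, n_positions=20, vocab_size=128_000)
reduced = top_k_bytes(n_layers=80, n_positions=20, k=5)

print(f"full logits: {naive / 1e6:.1f} MB")    # → full logits: 819.2 MB
print(f"top-k only:  {reduced / 1e3:.1f} KB")  # → top-k only:  64.0 KB
```

The full-logits cost scales with vocabulary size, while the reduced payload scales only with `k`, which is why reducing on the server pays off most for models with large vocabularies.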

## Analyzing Specific Layers

You can analyze a subset of layers for faster exploration:

In [None]:
# Analyze every other layer, plus the final layer
data = collect_logit_lens_topk_efficient(
    'Hello world',
    model,
    layers=[0, 2, 4, 6, 8, 10, 11],  # GPT-2 has 12 layers (0-11)
    track_across_layers=True,
    remote=False,
)

show_logit_lens(data, model.tokenizer, title='Layer subset analysis')
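
For deeper models, writing the layer list by hand gets tedious. The helper below is a hypothetical convenience (not part of LogitLensKit) that builds a strided list of layer indices and always keeps the final layer, since that is where the model's actual output is read.

```python
def strided_layers(n_layers, step=2):
    """Every `step`-th layer index, always including the final layer."""
    layers = list(range(0, n_layers, step))
    if layers[-1] != n_layers - 1:
        layers.append(n_layers - 1)
    return layers

print(strided_layers(12))          # → [0, 2, 4, 6, 8, 10, 11]
print(strided_layers(12, step=4))  # → [0, 4, 8, 11]
```

With GPT-2 loaded as above, the result could be passed as `layers=strided_layers(model.config.n_layer)`.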

## Next Steps

- Explore different model architectures (see [MODEL_SUPPORT.md](../docs/MODEL_SUPPORT.md))
- Try the [live demo](https://davidbau.github.io/logitlenskit/) with Llama 70B data
- Check the [API reference](../docs/PYTHON_API.md) for advanced usage