# Interactive Transformer Architecture Visualization

This notebook demonstrates the interactive architecture diagrams produced by `transformer_viz`.

**Features:**
- Hover over an attention block to see its individual heads
- Hover over an MLP block to see its expansion and contraction (`d_model` → `d_mlp` → `d_model`)
- The residual stream runs down the left side, with each block branching off it
- A "+" node shows where each block's output is added back to the residual stream
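
The "+" node is the residual update. The minimal sketch below uses plain Python lists as vectors, and `toy_attn` and `toy_mlp` are hypothetical stand-ins rather than real learned sublayers (layer norms are omitted). It shows each block reading from the residual stream and adding its output back:

```python
from typing import Callable, List

Vector = List[float]


def add(a: Vector, b: Vector) -> Vector:
    """Elementwise sum: the "+" node where a block's output rejoins the residual stream."""
    return [x + y for x, y in zip(a, b)]


def transformer_layer(resid: Vector,
                      attn: Callable[[Vector], Vector],
                      mlp: Callable[[Vector], Vector]) -> Vector:
    # Attention branches off the residual stream; its output is added back.
    resid = add(resid, attn(resid))
    # The MLP reads the updated stream; its output is added back too.
    resid = add(resid, mlp(resid))
    return resid


# Hypothetical toy sublayers (real ones are learned and wrapped in layer norms).
def toy_attn(v: Vector) -> Vector:
    return [0.5 * x for x in v]


def toy_mlp(v: Vector) -> Vector:
    return [1.0 for _ in v]


print(transformer_layer([1.0, 2.0], toy_attn, toy_mlp))  # → [2.5, 4.0]
```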

In [None]:
import sys
# Make the local transformer_viz package (one directory up) importable
sys.path.insert(0, "..")

from transformer_viz import visualize, InteractiveTransformerViz, VisualizationConfig

## Quick Visualization

The simplest way to visualize a model is to call `visualize()` with a model name:

In [None]:
visualize("gpt2", max_layers=4)

## Different Models

The same call works for other pretrained models:

In [None]:
# Pythia-70M: a smaller model
visualize("pythia-70m")

In [None]:
# GPT-2 Medium: a larger model (showing 3 layers)
visualize("gpt2-medium", max_layers=3)

## Custom Architecture

Define your own model configuration:

In [None]:
my_model = {
    "n_layers": 4,
    "d_model": 256,
    "n_heads": 4,
    "d_head": 64,
    "d_mlp": 1024,
    "d_vocab": 10000,
    "model_name": "My Tiny Transformer"
}

visualize(my_model)
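
The dictionary's dimensions have to agree with each other. In standard multi-head attention, `n_heads * d_head` equals `d_model` (4 × 64 = 256 here). The hypothetical `check_config` helper below is not part of `transformer_viz`. It is a sketch that catches typos before rendering. It treats `d_mlp = 0` as "no MLP" and returns a list of problems instead of raising, because some architectures deliberately relax the head-size rule:

```python
REQUIRED_KEYS = ("n_layers", "d_model", "n_heads", "d_head", "d_mlp", "d_vocab")


def check_config(cfg: dict) -> list:
    """Return human-readable problems with a model config dict (empty list if none)."""
    missing = [k for k in REQUIRED_KEYS if k not in cfg]
    if missing:
        return [f"missing keys: {', '.join(missing)}"]
    problems = []
    for key in REQUIRED_KEYS:
        value = cfg[key]
        # d_mlp may be 0 for attention-only models; every other size must be positive.
        minimum = 0 if key == "d_mlp" else 1
        if not isinstance(value, int) or value < minimum:
            problems.append(f"{key} should be an int >= {minimum}, got {value!r}")
    # Only compare head sizes once every field is a valid integer.
    if not problems and cfg["n_heads"] * cfg["d_head"] != cfg["d_model"]:
        problems.append(
            f"n_heads * d_head = {cfg['n_heads'] * cfg['d_head']} "
            f"does not match d_model = {cfg['d_model']}"
        )
    return problems


good = {"n_layers": 4, "d_model": 256, "n_heads": 4, "d_head": 64,
        "d_mlp": 1024, "d_vocab": 10000}
bad = dict(good, d_head=32)
print(check_config(good))  # → []
print(check_config(bad))   # → ['n_heads * d_head = 128 does not match d_model = 256']
```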

## With TransformerLens Model

`visualize()` also accepts TransformerLens `HookedTransformer` models directly:

In [None]:
# Uncomment if you have TransformerLens installed (pip install transformer-lens):
#
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2-small")
# visualize(model, max_layers=4)
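
A TransformerLens model's `model.cfg` (a `HookedTransformerConfig`) has fields with the same names as the custom-architecture dictionary. You can use this to inspect or tweak a model's shape before plotting it. The `config_to_dict` helper below is hypothetical and belongs to neither `transformer_viz` nor TransformerLens. It only reads attributes, so a `SimpleNamespace` with GPT-2 small's dimensions stands in for `model.cfg`. Attention-only TransformerLens configs may leave `d_mlp` as `None`, which is why the helper falls back to 0:

```python
from types import SimpleNamespace


def config_to_dict(cfg) -> dict:
    """Copy the shape fields of a HookedTransformerConfig-like object into a plain dict."""
    return {
        "n_layers": cfg.n_layers,
        "d_model": cfg.d_model,
        "n_heads": cfg.n_heads,
        "d_head": cfg.d_head,
        # Attention-only configs may have d_mlp=None; use 0 to mean "no MLP".
        "d_mlp": cfg.d_mlp or 0,
        "d_vocab": cfg.d_vocab,
        "model_name": cfg.model_name,
    }


# Stand-in for model.cfg, using GPT-2 small's dimensions.
fake_cfg = SimpleNamespace(n_layers=12, d_model=768, n_heads=12, d_head=64,
                           d_mlp=3072, d_vocab=50257, model_name="gpt2")
print(config_to_dict(fake_cfg))
```

With a real model, the resulting dictionary can be edited (for example, lowering `n_layers`) and passed to `visualize()` like any custom configuration.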

## Custom Colors

Pass a `VisualizationConfig` to change the color scheme:

In [None]:
# Purple and green theme
custom_config = VisualizationConfig(
    attention_block_color="#6C5CE7",
    mlp_block_color="#00B894",
    embedding_color="#E17055",
    unembedding_color="#0984E3",
    attention_head_colors=[
        "#a29bfe", "#74b9ff", "#81ecec", "#55efc4",
        "#ffeaa7", "#fab1a0", "#fd79a8", "#e17055"
    ]
)

visualize("gpt2", max_layers=3, config=custom_config)
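
GPT-2 has 12 heads per layer, but the palette above lists only 8 colors. This notebook does not show how `transformer_viz` handles the mismatch. One common approach, sketched below as an assumption, cycles through the palette with a modulo index so that heads 8 to 11 reuse the first four colors. `head_colors` is a hypothetical helper:

```python
from typing import List


def head_colors(n_heads: int, palette: List[str]) -> List[str]:
    """Assign one color per head, wrapping around when heads outnumber palette entries."""
    if not palette:
        raise ValueError("palette must contain at least one color")
    return [palette[i % len(palette)] for i in range(n_heads)]


palette = ["#a29bfe", "#74b9ff", "#81ecec", "#55efc4",
           "#ffeaa7", "#fab1a0", "#fd79a8", "#e17055"]
colors = head_colors(12, palette)
print(colors[8:])  # heads 8 to 11 reuse the first four palette entries
```

If repeated colors make heads hard to tell apart, a palette at least as long as `n_heads` avoids the wrap-around entirely.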

## Attention-Only Models

Setting `d_mlp` to 0 describes a model built from attention blocks only:

In [None]:
# Model with no MLP (attention only)
attn_only = {
    "n_layers": 2,
    "d_model": 512,
    "n_heads": 8,
    "d_head": 64,
    "d_mlp": 0,  # No MLP!
    "d_vocab": 50257,
    "model_name": "Attention-Only 2L"
}

visualize(attn_only)
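
A per-layer weight count shows what dropping the MLP removes. The sketch below uses a hypothetical `block_weight_counts` function. It counts only the attention projections (W_Q, W_K, W_V, W_O, each `d_model × n_heads·d_head`) and the two MLP matrices (`d_model × d_mlp` each), and ignores biases, layer norms, and embeddings:

```python
def block_weight_counts(cfg: dict) -> dict:
    """Weight-matrix entries per sublayer, ignoring biases, layer norms, and embeddings."""
    attn = 4 * cfg["d_model"] * cfg["n_heads"] * cfg["d_head"]  # W_Q, W_K, W_V, W_O
    mlp = 2 * cfg["d_model"] * cfg["d_mlp"]                     # W_in and W_out
    return {"attn_per_layer": attn, "mlp_per_layer": mlp,
            "total": cfg["n_layers"] * (attn + mlp)}


attn_only_cfg = {"n_layers": 2, "d_model": 512, "n_heads": 8, "d_head": 64,
                 "d_mlp": 0, "d_vocab": 50257}
print(block_weight_counts(attn_only_cfg))
# {'attn_per_layer': 1048576, 'mlp_per_layer': 0, 'total': 2097152}
```

For comparison, the 4-layer custom model defined earlier (`d_mlp = 1024`) has 524,288 MLP weights per layer, twice its 262,144 attention weights.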

## Save as HTML

Save the interactive visualization to an HTML file for sharing:

In [None]:
viz = InteractiveTransformerViz()
viz.from_pretrained("gpt2")
viz.render(max_layers=4)
viz.save_html("gpt2_interactive.html")
print("Saved to gpt2_interactive.html")