# LogitLensKit E2E Test: Basic Workflow

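This notebook tests the basic end-to-end workflow:
1. Load a model (GPT-2, run locally; no NDIF access required)
2. Collect logit lens data and validate its structure
3. Format the data for the widget and render the interactive widget HTML
4. Collect data for a subset of layers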

Run with: `pytest --nbmake notebooks/test_e2e_basic.ipynb`

In [None]:
# Setup: Load environment variables
import os
import sys
from pathlib import Path

# Assumes the working directory is this notebook's own directory (e.g.
# `notebooks/`), so the repository root is its parent.
project_root = Path.cwd().parent

# Add python/src to path for local development
sys.path.insert(0, str(project_root / 'python' / 'src'))


def load_env_file(path):
    """Load ``KEY=VALUE`` lines from a dotenv-style file into ``os.environ``.

    Blank lines, lines without ``=``, and ``#`` comment lines (even when
    indented) are skipped. Whitespace around keys and values is trimmed, an
    optional leading ``export `` on the key is ignored, and exactly one pair
    of matching single or double quotes around a value is removed. Variables
    already present in ``os.environ`` are overwritten.

    Args:
        path: Path to the env file (str or ``pathlib.Path``).

    Returns:
        dict mapping every key that was set to its (unquoted) value.
    """
    loaded = {}
    with open(path, encoding='utf-8') as env_file:
        for raw_line in env_file:
            # Strip first so indented comments and blank lines are recognised.
            line = raw_line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            key, val = line.split('=', 1)
            key = key.strip()
            if key.startswith('export '):
                key = key[len('export '):].strip()
            if not key:
                continue
            val = val.strip()
            # Remove one matching pair of surrounding quotes only, so values
            # containing quote characters elsewhere are preserved intact.
            if len(val) >= 2 and val[0] == val[-1] and val[0] in ('"', "'"):
                val = val[1:-1]
            os.environ[key] = val
            loaded[key] = val
    return loaded


# Load .env.local
env_path = project_root / '.env.local'
if env_path.exists():
    load_env_file(env_path)
    print('Loaded .env.local')
else:
    print('Warning: .env.local not found')

In [None]:
# Test 1: Import logitlenskit
# Bring in the public names used by the rest of this notebook; an ImportError
# here means logitlenskit (or one of these names) could not be imported.
from logitlenskit import (
    MODEL_CONFIGS,
    collect_logit_lens_topk_efficient,
    detect_model_type,
    show_logit_lens,
)

print('Imports successful')
supported_types = list(MODEL_CONFIGS.keys())
print(f'Supported model types: {supported_types}')

In [None]:
# Test 2: Load GPT-2 model (local, no NDIF required)
from nnsight import LanguageModel

# Load the GPT-2 checkpoint for local execution (used by every later cell).
model = LanguageModel('openai-community/gpt2', device_map='auto')
print(f'Loaded model: {model.config.model_type}')

# detect_model_type must classify this checkpoint as 'gpt2'.
expected_type = 'gpt2'
detected = detect_model_type(model)
assert detected == expected_type, f'Expected {expected_type}, got {detected}'
print(f'Model type detected: {detected}')

In [None]:
# Test 3: Collect logit lens data (local execution)
prompt = 'The quick brown fox'

# Collector options; remote=False runs against the locally loaded model.
collect_options = {
    'top_k': 5,
    'track_across_layers': True,
    'remote': False,
}
data = collect_logit_lens_topk_efficient(prompt, model, **collect_options)

print(f'Tokens: {data["tokens"]}')
print(f'Layers analyzed: {len(data["layers"])}')
print(f'Top indices shape: {data["top_indices"].shape}')
tracked_per_position = [len(entry) for entry in data['tracked_indices']]
print(f'Tracked tokens per position: {tracked_per_position}')

In [None]:
# Test 4: Validate data structure
# Every field produced by the collector must be present (checked in order).
required_fields = (
    'tokens',
    'layers',
    'top_indices',
    'top_probs',
    'tracked_indices',
    'tracked_probs',
)
for required_field in required_fields:
    assert required_field in data, f'Missing {required_field}'

# Shape checks: top-k arrays are (layers, positions, k); tracked data has one
# entry per input position. k must match the top_k passed in Test 3.
n_layers = len(data['layers'])
n_tokens = len(data['tokens'])
k = 5
expected_shape = (n_layers, n_tokens, k)

for array_name in ('top_indices', 'top_probs'):
    actual_shape = data[array_name].shape
    assert actual_shape == expected_shape, f'{array_name} shape mismatch: {actual_shape}'

for seq_name in ('tracked_indices', 'tracked_probs'):
    actual_len = len(data[seq_name])
    assert actual_len == n_tokens, f'{seq_name} length mismatch: {actual_len}'

print('Data structure validation passed!')

In [None]:
# Test 5: Format data for widget (v2 compact format)
from logitlenskit import format_data_for_widget

model_name = 'openai-community/gpt2'
widget_data = format_data_for_widget(data, model.tokenizer, model_name=model_name)

# Check v2 format structure
for section in ('meta', 'layers', 'input', 'tracked', 'topk'):
    assert section in widget_data, f'Missing {section}'

# Check meta
meta = widget_data['meta']
assert meta['version'] == 2, f'Version should be 2, got {meta["version"]!r}'
assert 'timestamp' in meta, 'Missing timestamp'
assert meta['model'] == model_name, f'Model name mismatch: {meta["model"]!r}'

# Check dimensions (n_tokens / n_layers come from Test 4)
n_input = len(widget_data['input'])
assert n_input == n_tokens, f'input length mismatch: {n_input} != {n_tokens}'
n_tracked = len(widget_data['tracked'])
assert n_tracked == n_tokens, f'tracked length mismatch: {n_tracked} != {n_tokens}'
n_topk_layers = len(widget_data['topk'])
assert n_topk_layers == n_layers, f'topk layers mismatch: {n_topk_layers} != {n_layers}'

# Check tracked structure (dict of token -> trajectory) at every position;
# each trajectory holds one entry per analysed layer.
for pos, tracked_at_pos in enumerate(widget_data['tracked']):
    assert isinstance(tracked_at_pos, dict), f'tracked[{pos}] should be dict'
    for token, trajectory in tracked_at_pos.items():
        assert isinstance(trajectory, list), (
            f'tracked[{pos}][{token!r}] trajectory should be list'
        )
        assert len(trajectory) == n_layers, (
            f'tracked[{pos}][{token!r}] trajectory length {len(trajectory)} '
            f'should match {n_layers} layers'
        )
# Fail with a clear message (not an IndexError) if nothing is tracked.
assert widget_data['tracked'][0], 'tracked[0] should contain at least one token'

# Check topk structure: topk[layer][pos] is a list of token strings, with one
# entry per input position in every layer.
for layer_idx, layer_topk in enumerate(widget_data['topk']):
    assert len(layer_topk) == n_tokens, (
        f'topk positions mismatch at layer {layer_idx}: {len(layer_topk)} != {n_tokens}'
    )
    for pos, entry in enumerate(layer_topk):
        assert isinstance(entry, list), f'topk[{layer_idx}][{pos}] should be list'
        assert all(isinstance(t, str) for t in entry), (
            f'topk[{layer_idx}][{pos}] entries should be strings'
        )

print(f'Widget data (v2): {len(widget_data["input"])} tokens, {len(widget_data["layers"])} layers')
print(f'Tracked tokens at pos 0: {len(widget_data["tracked"][0])} unique tokens')
print(f'Sample topk at layer 0, pos 0: {widget_data["topk"][0][0][:3]}')

In [None]:
# Test 6: Display widget (visual verification)
# show_logit_lens returns a display object whose `.data` holds the generated
# HTML; as the cell's final expression it renders in a live Jupyter session.
widget_title = 'GPT-2: The quick brown fox'
widget_html = show_logit_lens(data, model.tokenizer, title=widget_title)

# The markup must exist, embed the widget code, and carry the requested title.
assert widget_html is not None
html_str = widget_html.data
for needle, failure_msg in (
    ('LogitLensWidget', 'Widget code not in HTML'),
    (widget_title, 'Title not in HTML'),
):
    assert needle in html_str, failure_msg

print(f'Generated HTML: {len(html_str)} characters')
print('Widget HTML generation successful!')

# Display (in real Jupyter this would show the interactive widget)
widget_html

In [None]:
# Test 7: Layer subset
layers_subset = [0, 3, 6, 9, 11]  # GPT-2 has 12 layers
subset_top_k = 3

data_subset = collect_logit_lens_topk_efficient(
    'Test',
    model,
    top_k=subset_top_k,
    layers=layers_subset,
    remote=False,
)

# Normalise to a list of Python ints so the comparison also works if the
# library returns a tuple, NumPy array, or tensor instead of a list.
returned_layers = [int(layer) for layer in data_subset['layers']]
assert returned_layers == layers_subset, f'layers mismatch: {returned_layers}'

# Same (layers, positions, k) contract as Test 4, applied to the subset.
expected_shape = (len(layers_subset), len(data_subset['tokens']), subset_top_k)
for array_name in ('top_indices', 'top_probs'):
    actual_shape = tuple(data_subset[array_name].shape)
    assert actual_shape == expected_shape, (
        f'{array_name} shape mismatch: {actual_shape} != {expected_shape}'
    )

print(f'Layer subset test passed: {layers_subset}')

In [None]:
# Summary
# Assumes top-to-bottom execution that stops at the first failing cell (as
# under nbmake), so reaching this cell means the checks above all succeeded.
completed_checks = [
    'Imports',
    'Model loading',
    'Data collection',
    'Data validation',
    'Widget formatting',
    'HTML generation',
    'Layer subset',
]
separator = '=' * 50

print(separator)
print('E2E Basic Test Summary')
print(separator)
for number, check_name in enumerate(completed_checks, start=1):
    print(f'{number}. {check_name}: PASSED')
print(separator)
print('All tests passed!')