# Hidden States Test

Tests extracting hidden states from several representative layers (first, quarter, middle, three-quarter, and last) of a model running remotely on NDIF.

**Environment Variables:**
- `MODEL_NAME`: Hugging Face model ID to test (default: `openai-community/gpt2`)
- `NDIF_API`: NDIF API key used for remote execution
- `HF_TOKEN`: Hugging Face access token

In [None]:
import os
import time

# Model under test; override with the MODEL_NAME environment variable.
MODEL_NAME = os.getenv("MODEL_NAME", "openai-community/gpt2")
print("Testing model: " + MODEL_NAME)

In [None]:
# Configure NDIF
from nnsight import CONFIG

NDIF_API = os.environ.get("NDIF_API")
if NDIF_API:
    CONFIG.set_default_api_key(NDIF_API)
    print("NDIF API key configured")
else:
    # Every trace in this notebook uses remote=True, so a missing key is
    # likely to surface later as a less obvious remote failure.
    print("WARNING: NDIF_API not set; remote traces may fail")

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    # The token was read from os.environ, so it is already visible to anything
    # that reads HF_TOKEN from the environment; writing it back was a no-op.
    print("HF_TOKEN configured")

In [None]:
# Load model
from nnsight import LanguageModel

print(f"Loading {MODEL_NAME}...")
# perf_counter is monotonic, so the measured duration cannot be skewed by
# wall-clock adjustments the way time.time() differences can.
start = time.perf_counter()
model = LanguageModel(MODEL_NAME, device_map="auto")
load_time = time.perf_counter() - start
print(f"Model loaded in {load_time:.1f}s")

In [None]:
# Get number of layers from model config
def get_num_layers(config):
    """Return the number of transformer layers declared by a model config.

    The attribute names below are tried in order and the first one present
    wins: ``num_hidden_layers``, ``n_layer``, ``n_layers``, ``num_layers``.
    If none exists on *config* itself, the same names are tried on a nested
    ``config.text_config`` (as found on e.g. multimodal configs).

    Args:
        config: A model configuration object, such as ``model.config``.

    Returns:
        The value stored under the first matching attribute.

    Raises:
        ValueError: If no layer-count attribute is found.
    """
    attr_names = ("num_hidden_layers", "n_layer", "n_layers", "num_layers")
    # Top-level attributes take precedence over the nested text config.
    for cfg in (config, getattr(config, "text_config", None)):
        if cfg is None:
            continue
        for attr in attr_names:
            if hasattr(cfg, attr):
                return getattr(cfg, attr)
    raise ValueError(
        f"Cannot determine number of layers from {type(config).__name__}"
    )

n_layers = get_num_layers(model.config)
print(f"Model has {n_layers} layers")

# Representative depths: the 0, 1/4, 1/2 and 3/4 points plus the final layer.
# The set collapses duplicates on shallow models; sorting yields ascending order.
test_layer_indices = sorted(
    {k * n_layers // 4 for k in range(4)} | {n_layers - 1}
)
print(f"Testing layers: {test_layer_indices}")

In [None]:
# Extract hidden states from selected layers
# NOTE: The trace bodies below read five plain index variables (i0..i4) and
# issue five explicit .save() calls instead of looping/comprehending over
# test_layer_indices inside model.trace(), which this notebook avoids.

prompt = "Hello world"
print(f"Running trace on: '{prompt}'")
print(f"Extracting hidden states from {len(test_layer_indices)} layers...")

# HF model_type values whose decoder blocks live under `transformer.h`,
# and those that use `gpt_neox.layers`.
_TRANSFORMER_H_TYPES = {"gpt2", "gptj", "gpt_neo", "bloom", "falcon"}
_GPT_NEOX_TYPES = {"gpt_neox"}


def _layer_path(model_type, model_name):
    """Return the layer container to trace: 'transformer.h', 'gpt_neox.layers' or 'model.layers'.

    The config's model_type is trusted first (robust to local paths or
    fine-tune names); otherwise the model name is matched as a fallback.
    """
    if model_type in _GPT_NEOX_TYPES:
        return "gpt_neox.layers"
    if model_type in _TRANSFORMER_H_TYPES:
        return "transformer.h"
    name = model_name.lower()
    # Test GPT-NeoX before GPT-Neo: 'gpt-neox' contains 'gpt-neo', but plain
    # GPT-Neo stores its blocks under transformer.h, not gpt_neox.layers.
    if "gpt-neox" in name or "pythia" in name:
        return "gpt_neox.layers"
    if "gpt-j" in name or "gpt2" in name or "gpt-neo" in name:
        return "transformer.h"
    # Llama, Mistral, Qwen, OLMo, etc.
    return "model.layers"


layer_path = _layer_path(getattr(model.config, "model_type", "") or "", MODEL_NAME)
print(f"Layer container: {layer_path}")

# The trace always fills exactly five slots. Shallow models (< 5 layers) select
# fewer distinct indices, so pad with the last one instead of indexing past the
# end (previously an IndexError for <= 2 layers).
# NOTE(review): padding saves the same layer output more than once; the old
# code already did this for 3-4 layer models — confirm nnsight allows it.
slots = (list(test_layer_indices) + [test_layer_indices[-1]] * 5)[:5]
i0, i1, i2, i3, i4 = slots

start = time.perf_counter()
if layer_path == "transformer.h":
    # GPT-2, GPT-J, GPT-Neo, BLOOM, Falcon
    with model.trace(prompt, remote=True):
        h0 = model.transformer.h[i0].output[0].save()
        h1 = model.transformer.h[i1].output[0].save()
        h2 = model.transformer.h[i2].output[0].save()
        h3 = model.transformer.h[i3].output[0].save()
        h4 = model.transformer.h[i4].output[0].save()
elif layer_path == "gpt_neox.layers":
    # GPT-NeoX, Pythia
    with model.trace(prompt, remote=True):
        h0 = model.gpt_neox.layers[i0].output[0].save()
        h1 = model.gpt_neox.layers[i1].output[0].save()
        h2 = model.gpt_neox.layers[i2].output[0].save()
        h3 = model.gpt_neox.layers[i3].output[0].save()
        h4 = model.gpt_neox.layers[i4].output[0].save()
else:
    # Llama, Mistral, Qwen, OLMo, etc.
    with model.trace(prompt, remote=True):
        h0 = model.model.layers[i0].output[0].save()
        h1 = model.model.layers[i1].output[0].save()
        h2 = model.model.layers[i2].output[0].save()
        h3 = model.model.layers[i3].output[0].save()
        h4 = model.model.layers[i4].output[0].save()

# Map saved outputs back onto test_layer_indices so hidden_states[k] always
# belongs to test_layer_indices[k]; padded duplicate slots are dropped.
saved_by_layer = dict(zip(slots, (h0, h1, h2, h3, h4)))
hidden_states = [saved_by_layer[i] for i in test_layer_indices]

trace_time = time.perf_counter() - start
print(f"Extraction completed in {trace_time:.1f}s")

In [None]:
# Validate all hidden states
import torch

print(f"\nValidating {len(hidden_states)} layer outputs...")

# zip() below would silently truncate on a length mismatch, hiding layers that
# were never validated (or outputs attributed to the wrong layer index).
assert hidden_states, "No hidden states were extracted"
assert len(hidden_states) == len(test_layer_indices), (
    f"Got {len(hidden_states)} outputs for {len(test_layer_indices)} tested layers"
)

for layer_idx, hidden in zip(test_layer_indices, hidden_states):
    # Check shape
    assert len(hidden.shape) >= 2, f"Layer {layer_idx}: Expected at least 2D tensor"

    # Check for NaN/Inf
    assert not torch.isnan(hidden).any(), f"Layer {layer_idx}: Contains NaN"
    assert not torch.isinf(hidden).any(), f"Layer {layer_idx}: Contains Inf"

    # Check reasonable values (coarse sanity bound on activation magnitude)
    max_val = hidden.abs().max().item()
    assert max_val < 10000, f"Layer {layer_idx}: Values too large ({max_val})"

    print(f"  Layer {layer_idx}: shape={hidden.shape}, max={max_val:.2f}")

# Every layer should share the same final (hidden) dimension
shapes = [h.shape for h in hidden_states]
hidden_dim = shapes[0][-1]
for layer_idx, shape in zip(test_layer_indices, shapes):
    assert shape[-1] == hidden_dim, f"Layer {layer_idx}: Inconsistent hidden dim"

print(f"\nAll {len(test_layer_indices)} tested layers validated!")
print(f"Hidden dimension: {hidden_dim}")

print("\n" + "=" * 40)
print("HIDDEN STATES " + "TEST PASSED")
print("=" * 40)