# Basic Trace Test

Tests basic `model.trace()` functionality by running a remote trace and extracting the hidden state output of the first decoder layer.

**Environment Variables:**
- `MODEL_NAME`: Model to test (default: "openai-community/gpt2"; e.g., "meta-llama/Llama-3.1-8B")
- `NDIF_API`: NDIF API key
- `HF_TOKEN`: Hugging Face token (for gated models)

In [None]:
import os
import time

# Model under test; override it by exporting MODEL_NAME before launching.
MODEL_NAME = os.getenv("MODEL_NAME", default="openai-community/gpt2")
print("Testing model: " + MODEL_NAME)

In [None]:
# Configure credentials used by the remote traces later in this notebook.
from nnsight import CONFIG

# NDIF API key: register it with nnsight if provided; only warn when missing
# so that the rest of the notebook still reports where it fails.
NDIF_API = os.environ.get("NDIF_API")
if NDIF_API:
    CONFIG.set_default_api_key(NDIF_API)
    print("NDIF API key configured")
else:
    print("Warning: NDIF_API not set")

# Hugging Face token (only needed for gated models). The value is read
# straight from os.environ, so it is already present there; writing it back
# would be a no-op. We only report whether it was found.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("HF_TOKEN configured")

In [None]:
# Load the model under test and report how long loading took.
from nnsight import LanguageModel

print(f"Loading {MODEL_NAME}...")
# perf_counter is monotonic, so the measured duration cannot be skewed by
# wall-clock adjustments (NTP sync, DST) the way time.time() can.
start = time.perf_counter()
model = LanguageModel(MODEL_NAME, device_map="auto")
load_time = time.perf_counter() - start
print(f"Model loaded in {load_time:.1f}s")

In [None]:
# Run a basic remote trace, choosing the first decoder layer's module path
# from the model name, and save that layer's output[0] as `hidden`.
# NOTE: Variables from outside model.trace() are NOT available on the server!
# We must use separate trace blocks per architecture.

prompt = "The quick brown fox"
print(f"Running trace on: '{prompt}'")

# Detect architecture from model name.
model_lower = MODEL_NAME.lower()

# Order matters: "gpt-neox" contains "gpt-neo", so GPT-NeoX/Pythia must be
# matched before GPT-Neo. GPT-Neo (not NeoX) uses transformer.h like GPT-2.
start = time.perf_counter()
if "gpt-neox" in model_lower or "pythia" in model_lower:
    # GPT-NeoX, Pythia: gpt_neox.layers
    with model.trace(prompt, remote=True):
        hidden = model.gpt_neox.layers[0].output[0].save()
elif any(tag in model_lower for tag in ("gpt2", "gpt-j", "gpt-neo")):
    # GPT-2, GPT-J, GPT-Neo: transformer.h
    with model.trace(prompt, remote=True):
        hidden = model.transformer.h[0].output[0].save()
else:
    # Llama, Mistral, Qwen, OLMo, etc.: model.layers
    with model.trace(prompt, remote=True):
        hidden = model.model.layers[0].output[0].save()

# Monotonic clock for the elapsed-time measurement.
trace_time = time.perf_counter() - start
print(f"Trace completed in {trace_time:.1f}s")

In [None]:
# Sanity-check the saved hidden state: rank, last-axis size, finiteness and
# magnitude. Any failed check raises AssertionError and stops the notebook.
import torch

shape = hidden.shape
print(f"Hidden state shape: {shape}")
print(f"Hidden state dtype: {hidden.dtype}")

# Rank must be at least 2 and the last (presumably hidden-size) axis non-empty.
assert len(shape) >= 2, f"Expected at least 2D tensor, got {shape}"
assert shape[-1] > 0, "Hidden dimension should be positive"

# Reject non-finite activations, reporting NaN and Inf separately.
has_nan = torch.isnan(hidden).any()
assert not has_nan, "Hidden state contains NaN values"
has_inf = torch.isinf(hidden).any()
assert not has_inf, "Hidden state contains Inf values"

# Magnitude guard: 1000 is a loose heuristic ceiling on |activation|.
peak = hidden.abs().max()
assert peak < 1000, f"Hidden values too large: max={peak}"

banner = "=" * 40
print("\n" + banner)
# NOTE(review): the marker literal is split in two, presumably so the success
# text only appears in output and not in echoed source — confirm before joining.
print("BASIC TRACE " + "TEST PASSED")
print(banner)