# NDIF Monitor - Hidden States

**Model:** `meta-llama/Llama-2-7b-hf`

This notebook smoke-tests nnsight with remote execution on NDIF. Generated by [NDIF Monitor](https://github.com/davidbau/ndif-monitor).

## Setup

First, configure your API keys in Colab Secrets (🔑 icon in the left sidebar). When running outside Colab, set them as environment variables instead:
- `NDIF_API`: Your NDIF API key from [nnsight.net](https://nnsight.net)
- `HF_TOKEN`: Your Hugging Face token (required for gated models such as Llama 2)


In [None]:
# Install dependencies
!pip install -q nnsight torch


In [None]:
# Configure authentication from Colab secrets (environment variables as fallback)
import os

try:
    from google.colab import userdata
    print('Loaded secrets from Colab')
except ImportError:
    # Fallback for local testing: not running inside Colab
    userdata = None
    print('Using environment variables')


def _read_secret(name):
    """Return secret `name` from Colab Secrets if available, else from os.environ.

    Each secret is looked up independently, so one missing secret (e.g. an
    undefined HF_TOKEN) does not discard another that was found (e.g. NDIF_API).

    Returns:
        The secret string, or None when it is set in neither place.
    """
    if userdata is not None:
        try:
            value = userdata.get(name)
        except Exception:
            # e.g. secret not defined or notebook access not granted; narrower
            # than a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            value = None
        if value:
            return value
    return os.environ.get(name)


NDIF_API = _read_secret('NDIF_API')
HF_TOKEN = _read_secret('HF_TOKEN')

# Configure nnsight
from nnsight import CONFIG
if NDIF_API:
    CONFIG.set_default_api_key(NDIF_API)
    print('NDIF API key configured')
else:
    print('Warning: NDIF_API not set - add it to Colab Secrets')

if HF_TOKEN:
    # Expose the token to Hugging Face libraries that read HF_TOKEN from the env.
    os.environ['HF_TOKEN'] = HF_TOKEN
    print('HuggingFace token configured')
else:
    print('Warning: HF_TOKEN not set - gated models may fail to load')


In [None]:
# Load model
from nnsight import LanguageModel
import time

MODEL_NAME = 'meta-llama/Llama-2-7b-hf'
# Single braces interpolate; `{{...}}` inside an f-string is an escaped literal
# brace and would print the text "{MODEL_NAME}" instead of its value.
print(f'Loading {MODEL_NAME}...')

# perf_counter is monotonic, so it is the right clock for elapsed time.
start = time.perf_counter()
model = LanguageModel(MODEL_NAME, device_map='auto')
load_time = time.perf_counter() - start
print(f'Model loaded in {load_time:.1f}s')


## Hidden States Extraction

Runs a remote trace that saves the output of every layer in `model.model.layers`, then checks each saved state for NaN/Inf values.


In [None]:
# Extract hidden states from all layers
prompt = 'Hello world'
print(f"Extracting hidden states from: '{prompt}'")

layers = model.model.layers
num_layers = len(layers)
# Single braces: `{{...}}` would print the literal text instead of the value.
print(f'Model has {num_layers} layers')

start = time.perf_counter()
# remote=True sends the trace to NDIF; `.save()` keeps each value available
# after the `with` block exits.
with model.trace(prompt, remote=True):
    # NOTE(review): assumes each layer's output is a tuple whose first element
    # is the hidden-state tensor; if the installed transformers version returns
    # the tensor directly, `[0]` would index the batch dimension — confirm.
    states = [layer.output[0].save() for layer in layers]

extract_time = time.perf_counter() - start
print(f'Extraction completed in {extract_time:.1f}s')

PREVIEW_LAYERS = 5  # how many layer shapes to print before summarizing the rest
print(f'\nExtracted {len(states)} layer states:')
for i, state in enumerate(states[:PREVIEW_LAYERS]):
    print(f'  Layer {i}: {state.shape}')
if len(states) > PREVIEW_LAYERS:
    print(f'  ... and {len(states) - PREVIEW_LAYERS} more layers')


In [None]:
# Validate hidden states
import torch

# Fail fast if the trace produced no saved states at all.
assert len(states) >= 1, 'No hidden states extracted'

# Each saved layer output must contain only finite values; NaN is checked
# before Inf for every layer, so the first offending layer is reported.
for layer_idx, hidden in enumerate(states):
    has_nan = bool(torch.isnan(hidden).any())
    assert not has_nan, f'Layer {layer_idx} contains NaN'
    has_inf = bool(torch.isinf(hidden).any())
    assert not has_inf, f'Layer {layer_idx} contains Inf'

print('Validation passed!')


In [None]:
# Success banner; under Run All, a failing assertion in an earlier cell stops
# execution before this cell runs.
separator = '=' * 50
print('\n' + separator)
print('HIDDEN STATES TEST PASSED ✓')
print(separator)
