# NDIF Monitor - Hidden States

**Model:** `meta-llama/Llama-2-7b-hf`

This notebook checks that nnsight can extract hidden states from a model running remotely on NDIF. Generated by [NDIF Monitor](https://github.com/davidbau/ndif-monitor).

## Setup

Add your API keys as Colab Secrets (🔑 icon in the left sidebar) and enable notebook access for each one:
- `NDIF_API_KEY`: Your NDIF API key from [nnsight.net](https://nnsight.net)
- `HF_TOKEN`: Your HuggingFace token (required for gated models such as Llama-2)


In [None]:
# Install dependencies
!pip install -q nnsight torch

# Load API keys from Colab secrets into environment
# nnsight automatically picks up NDIF_API_KEY from env
import os


def load_colab_secrets(keys=('NDIF_API_KEY', 'HF_TOKEN')):
    """Copy the named Colab secrets into ``os.environ``.

    Outside Colab (``google.colab`` not importable) nothing is changed and the
    existing environment variables are used as-is.

    Args:
        keys: Names of the secrets to load; each is stored under the same
            environment-variable name.

    Returns:
        List of the key names that were actually loaded into the environment.
    """
    try:
        from google.colab import userdata
    except ImportError:
        return []  # Not in Colab, use existing env vars

    loaded = []
    for key in keys:
        try:
            value = userdata.get(key)
        except Exception as exc:
            # Best effort: a missing secret or denied notebook access should not
            # stop the notebook, but it must be visible. Never print the value.
            print(f'Could not load {key} from Colab secrets ({type(exc).__name__}); '
                  f'using existing environment.')
            continue
        if value:  # os.environ rejects None; an empty secret is treated as unset
            os.environ[key] = value
            loaded.append(key)
    return loaded


load_colab_secrets()


In [None]:
# Load model
from nnsight import LanguageModel
import time

MODEL_NAME = 'meta-llama/Llama-2-7b-hf'
print(f'Loading {MODEL_NAME}...')

# perf_counter is a monotonic clock, so the measured duration cannot be
# skewed by wall-clock adjustments the way time.time() can.
start = time.perf_counter()
# NOTE(review): the trace below runs with remote=True; confirm whether
# device_map='auto' is needed when the model only executes on NDIF.
model = LanguageModel(MODEL_NAME, device_map='auto')
load_time = time.perf_counter() - start
print(f'Model loaded in {load_time:.1f}s')


## Hidden States Extraction

Runs a remote trace on NDIF that saves the output of every decoder layer, then checks that each saved state is finite.


In [None]:
# Extract hidden states from all layers
prompt = 'Hello world'
print(f"Extracting hidden states from: '{prompt}'")

layers = model.model.layers
num_layers = len(layers)
print(f'Model has {num_layers} layers')

# How many per-layer shapes to print before summarizing the remainder.
PREVIEW_LAYERS = 5

# The trace body is executed remotely on NDIF (remote=True); keep it minimal.
# NOTE(review): layer.output[0] assumes each decoder layer returns a tuple whose
# first element is the hidden-state tensor. If the server's transformers version
# returns the tensor directly, [0] would index the batch dimension instead — confirm.
# NOTE(review): calling .save() on a plain list presumably relies on nnsight
# mounting .save() onto arbitrary objects (newer nnsight releases) — confirm version.
start = time.perf_counter()
with model.trace(prompt, remote=True):
    # Collect layer outputs and save as a list
    states = [layer.output[0] for layer in layers]
    states.save()  # nnsight adds .save() to lists

extract_time = time.perf_counter() - start
print(f'Extraction completed in {extract_time:.1f}s')

print(f'\nExtracted {len(states)} layer states:')
for i, state in enumerate(states[:PREVIEW_LAYERS]):
    print(f'  Layer {i}: {state.shape}')
if len(states) > PREVIEW_LAYERS:
    print(f'  ... and {len(states) - PREVIEW_LAYERS} more layers')


In [None]:
# Validate hidden states
import torch

# Exactly one saved state is expected per decoder layer.
assert len(states) == num_layers, f'Expected {num_layers} states, got {len(states)}'

# Each layer's output must contain no NaN and no Inf values; for every layer
# the NaN check runs before the Inf check.
non_finite_checks = ((torch.isnan, 'NaN'), (torch.isinf, 'Inf'))
for i, state in enumerate(states):
    for is_bad, label in non_finite_checks:
        assert not is_bad(state).any(), f'Layer {i} contains {label}'

# The message is built from two literals, like the final banner; kept as-is.
print('Validation ' + 'passed!')


In [None]:
# Final banner; reached only if every cell above (including its assertions) succeeded.
rule = '=' * 50
print('\n' + rule)
print('HIDDEN STATES TEST ' + 'PASSED ✓')
print(rule)
