# NDIF Monitor - Basic Trace

**Model:** `meta-llama/Llama-3.1-8B`

This notebook runs a basic smoke test of nnsight with remote execution on NDIF. Generated by [NDIF Monitor](https://github.com/davidbau/ndif-monitor).

## Setup

Configure your API keys in Colab Secrets (🔑 icon in left sidebar):
- `NDIF_API_KEY`: Your NDIF API key from [nnsight.net](https://nnsight.net)
- `HF_TOKEN`: Your Hugging Face token (required for gated models such as Llama 3.1)
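The key-loading cell below silently skips any secret it cannot read, so a missing key only shows up later as an authentication error. A minimal sketch of an up-front check follows; `missing_keys` and `REQUIRED_KEYS` are hypothetical names for illustration, not part of nnsight or Colab.

```python
import os

# Hypothetical helper (not part of nnsight or Colab): the secrets this
# notebook expects to find in the environment.
REQUIRED_KEYS = ("NDIF_API_KEY", "HF_TOKEN")


def missing_keys(env=None, keys=REQUIRED_KEYS):
    """Return the names in `keys` that are unset or empty in `env`."""
    env = os.environ if env is None else env
    return [k for k in keys if not env.get(k)]


if __name__ == "__main__":
    missing = missing_keys()
    if missing:
        print(f"Missing secrets: {', '.join(missing)}")
    else:
        print("All required secrets are set.")
```

Run after the key-loading cell, an empty list means both secrets made it into `os.environ`.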


In [None]:
# Install dependencies
!pip install -q nnsight torch

# Load API keys from Colab secrets into environment
# nnsight automatically picks up NDIF_API_KEY from env
import os
try:
    from google.colab import userdata
    for key in ['NDIF_API_KEY', 'HF_TOKEN']:
        try:
            os.environ[key] = userdata.get(key)
        except Exception:  # secret missing or notebook access not granted
            pass
except ImportError:
    pass  # Not in Colab, use existing env vars


In [None]:
# Load model
from nnsight import LanguageModel
import time

MODEL_NAME = 'meta-llama/Llama-3.1-8B'
print(f'Loading {MODEL_NAME}...')

start = time.time()
model = LanguageModel(MODEL_NAME, device_map='auto')
load_time = time.time() - start
print(f'Model loaded in {load_time:.1f}s')


## Basic Trace Test

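Runs a single remote forward pass with `model.trace(..., remote=True)` and saves the output of the first decoder layer (`model.model.layers[0]`) as the hidden state.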
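The validation cell further down only checks that the saved tensor is at least 2-D with a positive last dimension. A tighter check is sketched below, assuming the layer-0 hidden states have shape `(batch, seq_len, hidden_size)` and that Llama-3.1-8B's `hidden_size` is 4096 (its published config value). `check_hidden_shape` is a hypothetical helper, not part of nnsight; it works on plain shape tuples, so it runs without torch.

```python
# Hypothetical helper (not part of nnsight). Assumes the saved hidden
# states are laid out as (batch, seq_len, hidden_size) and that
# Llama-3.1-8B uses hidden_size = 4096.
LLAMA_31_8B_HIDDEN_SIZE = 4096


def check_hidden_shape(shape, hidden_size=LLAMA_31_8B_HIDDEN_SIZE, batch_size=1):
    """Validate a hidden-state shape and return its sequence length.

    Raises ValueError unless `shape` is (batch_size, seq_len, hidden_size)
    with seq_len >= 1.
    """
    shape = tuple(shape)
    if len(shape) != 3:
        raise ValueError(f"expected 3 dims (batch, seq_len, hidden), got {shape}")
    batch, seq_len, hidden = shape
    if batch != batch_size:
        raise ValueError(f"expected batch size {batch_size}, got {batch}")
    if seq_len < 1:
        raise ValueError(f"expected at least one token, got seq_len={seq_len}")
    if hidden != hidden_size:
        raise ValueError(f"expected hidden size {hidden_size}, got {hidden}")
    return seq_len
```

After the trace cell, `check_hidden_shape(hidden.shape)` would return the number of prompt tokens; `torch.Size` is a tuple subclass, so it can be passed directly.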


In [None]:
# Run basic trace
prompt = 'The quick brown fox jumps over the lazy dog'
print(f"Running trace on: '{prompt}'")

start = time.time()
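with model.trace(prompt, remote=True):
    # Save layer 0's hidden states so they are available after the remote run
    hidden = model.model.layers[0].output[0].save()

# Wall-clock time, including the network round-trip and any NDIF queue wait
trace_time = time.time() - start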
print(f'Trace completed in {trace_time:.1f}s')
print(f'Hidden state shape: {hidden.shape}')


In [None]:
# Validate results
import torch

if 'hidden' not in globals():
    raise RuntimeError('Trace was interrupted: hidden state not captured. Try running the trace cell again.')

# Verify shape is reasonable
assert len(hidden.shape) >= 2, f'Expected at least 2D tensor, got {hidden.shape}'
assert hidden.shape[-1] > 0, 'Hidden dimension should be positive'

# Check for NaN/Inf
assert not torch.isnan(hidden).any(), 'Hidden state contains NaN values'
assert not torch.isinf(hidden).any(), 'Hidden state contains Inf values'

print('Validation passed!')


In [None]:
print('\n' + '=' * 50)
print('BASIC TRACE TEST PASSED ✓')
print('=' * 50)
