# RetNet State Extraction - Testing and Verification

This notebook tests the state extraction mechanism for RetNet and verifies that retention states are correctly captured from all layers.

**Goals:**
1. Load the RetNet-2.7B model
2. Extract retention states using forward hooks
3. Verify state shapes and dimensions
4. Test state behavior on different inputs
5. Analyze state properties


In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Detect the runtime by probing for the `google.colab` package; the result is
# stored in IN_COLAB and drives the environment bootstrap below.
try:
    import google.colab  # noqa: F401  (imported only to test availability)
    IN_COLAB = True
    print("Running in Google Colab")
except ImportError:
    # Catch only the missing-package case. A bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and mask unrelated failures.
    IN_COLAB = False
    print("Running locally")

if IN_COLAB:
    if not os.path.exists('state-games'):
        print("Cloning repository...")
        !git clone https://github.com/idoavnir-uni/state-games.git
        print("Repository cloned!")
    
    os.chdir('state-games')
    print(f"Current directory: {os.getcwd()}")
    
    print("\nInstalling dependencies...")
    %pip install -q torch>=2.0.0 transformers>=4.30.0 huggingface_hub numpy pandas matplotlib einops h5py scikit-learn
    
    print("\nInstalling Flash Linear Attention library...")
    %pip install -q git+https://github.com/sustcsonglin/flash-linear-attention.git
    
    print("\nDependencies installed!")

if IN_COLAB:
    sys.path.insert(0, '/content/state-games')
else:
    sys.path.insert(0, os.path.abspath('..'))

from models.load_retnet import load_retnet_model, get_model_config, print_model_structure
from models.state_extractor import RetNetStateExtractor, save_states_to_file, load_states_from_file

# Notebook-wide plotting defaults: seaborn grid style and a wide default figure.
plt.rcParams.update({'figure.figsize': (12, 6)})
sns.set_style('whitegrid')

print("Setup complete!")


## 1. Load Model and Configuration


In [None]:
# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Warn up front when falling back to the CPU.
if device == "cpu":
    for cpu_warning in (
        "WARNING: Running on CPU. This will be very slow for 2.7B model.",
        "Consider running on a GPU or using a smaller model for testing.",
    ):
        print(cpu_warning)


In [None]:
# Load the RetNet checkpoint and its tokenizer onto the selected device,
# requesting bfloat16 as the torch dtype.
print("Loading RetNet-2.7B model...")
load_kwargs = dict(
    model_name="fla-hub/retnet-2.7B-100B",
    device=device,
    torch_dtype=torch.bfloat16,
)
model, tokenizer = load_retnet_model(**load_kwargs)


In [None]:
# Fetch the model configuration and print its headline fields.
config = get_model_config(model)

print("\n=== Key Configuration ===")
# (display label, config key) pairs; a missing key is reported as 'Unknown'.
for field_label, field_key in (
    ("Number of layers", "num_layers"),
    ("Number of heads", "num_heads"),
    ("Hidden size", "hidden_size"),
    ("Vocabulary size", "vocab_size"),
    ("Max sequence length", "max_seq_len"),
):
    print(f"{field_label}: {config.get(field_key, 'Unknown')}")


In [None]:
# Print the model structure, limited to max_depth=3, to understand how the
# layers are organized before extracting states from them.
print_model_structure(model, max_depth=3)


## 2. Initialize State Extractor


In [None]:
# Wrap the loaded model in the project's state extractor; later cells call
# methods on `extractor` to pull retention states.
# NOTE(review): verbose=True presumably enables diagnostic output — confirm
# in models/state_extractor.py.
extractor = RetNetStateExtractor(model, verbose=True)

print("\nState extractor ready")


In [None]:
# Debug: Inspect cache structure directly
# The probe sentence is tokenized inside this cell so the cell does not rely
# on `input_ids`, which is only created by a later cell and is therefore
# undefined on Restart & Run All.
debug_text = "The quick brown fox jumps over the lazy dog."
debug_input_ids = tokenizer(debug_text, return_tensors="pt").input_ids.to(device)

with torch.no_grad():
    outputs = model(debug_input_ids, use_cache=True)

print(f"Outputs type: {type(outputs)}")
print(f"past_key_values type: {type(outputs.past_key_values)}")
print(f"Number of layers: {len(outputs.past_key_values)}")

print(f"\nFirst layer cache:")
first_layer = outputs.past_key_values[0]
print(f"Type: {type(first_layer)}")

# Only treat the per-layer cache as a mapping when it actually is a dict;
# otherwise report that and skip the key-level inspection instead of crashing.
if isinstance(first_layer, dict):
    print(f"Keys: {first_layer.keys()}")

    recurrent_state = first_layer.get("recurrent_state")
    if isinstance(recurrent_state, torch.Tensor):
        print(f"\nrecurrent_state shape: {recurrent_state.shape}")
        print(f"recurrent_state dtype: {recurrent_state.dtype}")

    print(f"\nAll available keys in first layer:")
    for key, value in first_layer.items():
        # Non-tensor entries (e.g. None placeholders) have no shape to report.
        if isinstance(value, torch.Tensor):
            print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
else:
    print("Keys: Not a dict")
    print("\nFirst layer cache is not a dict; skipping per-key inspection.")

## 3. Extract States on Sample Input


In [None]:
# Test with a simple sentence
test_text = "The quick brown fox jumps over the lazy dog."

print(f"Input text: '{test_text}'")
# Note: this is the character count of the raw string, not the token count.
print(f"Input length: {len(test_text)} characters")

# Tokenize into PyTorch tensors and move the token IDs to the model's device.
inputs = tokenizer(test_text, return_tensors="pt")
input_ids = inputs.input_ids.to(device)

print(f"Token IDs shape: {input_ids.shape}")
# input_ids[0] selects the first (and here only) sequence for display.
print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")


In [None]:
# Extract states incrementally at each token position
# This gives us the K⊗V memory state after each token is processed
print("\nExtracting states incrementally...")
incremental_states = extractor.extract_states_incremental(input_ids)

print(f"\n=== Incremental State Shapes ===")
print(f"Number of positions: {len(incremental_states)}")
# NOTE(review): this is named "first" but reads index/key 1. If
# `incremental_states` is a 0-based list this is the second position and will
# fail on a single-token input — confirm the container type and indexing
# against extract_states_incremental.
first_pos_states = incremental_states[1]
print(f"Number of layers per position: {len(first_pos_states)}")
print(f"State shape at each position: {first_pos_states[0].shape}")

## 4. Verify State Shapes and Properties

This section verifies that the extracted states have the expected dimensions and properties.


In [None]:
# Display state shapes for all layers
# NOTE(review): `states` is not defined by any earlier cell visible in this
# notebook, so this cell raises NameError on Restart & Run All. It is used
# here as a mapping from layer index to a tensor-like state (it needs .keys()
# and values with .shape/.dtype/.device). TODO: add the cell that produces
# `states` before this one, or switch to a variable that is defined above.
print("\n=== State Shapes ===")
for layer_idx in sorted(states.keys()):
    state = states[layer_idx]
    print(f"Layer {layer_idx:2d}: shape={state.shape}, dtype={state.dtype}, device={state.device}")
