# RetNet State Extraction - Testing and Verification

This notebook tests the state extraction mechanism for RetNet and compares different extraction methods.

**Goals:**
1. Load the RetNet-2.7B model
2. Compare 3 extraction methods:
   - `extract_final_states` - Final state only (single forward pass)
   - `extract_incremental_states_dumb_rerunning` - All positions (O(N²) - reruns the full prefix for every position, slow)
   - `extract_incremental_states_single_pass` - All positions (O(N) - efficient)
3. Verify correctness by comparing results
4. Measure and compare performance


In [None]:
import sys
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

try:
    import google.colab
    IN_COLAB = True
    print("Running in Google Colab")
except:
    IN_COLAB = False
    print("Running locally")

if IN_COLAB:
    if not os.path.exists('state-games'):
        print("Cloning repository...")
        !git clone https://github.com/idoavnir-uni/state-games.git
        print("Repository cloned!")
    
    os.chdir('state-games')
    print(f"Current directory: {os.getcwd()}")
    
    print("\nInstalling dependencies...")
    %pip install -q torch>=2.0.0 transformers>=4.30.0 huggingface_hub numpy pandas matplotlib einops h5py scikit-learn
    
    print("\nInstalling Flash Linear Attention library...")
    %pip install -q git+https://github.com/sustcsonglin/flash-linear-attention.git
    
    print("\nDependencies installed!")

if IN_COLAB:
    sys.path.insert(0, '/content/state-games')
else:
    sys.path.insert(0, os.path.abspath('..'))

from models.load_retnet import load_retnet_model, get_model_config, print_model_structure
from models.state_extractor import RetNetStateExtractor, save_states_to_file, load_states_from_file

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("Setup complete!")


## 1. Load Model and Configuration


In [None]:
# Pick the accelerator when PyTorch can see one; every later cell places tensors on `device`.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

if device == "cpu":
    print(
        "WARNING: Running on CPU. This will be very slow for 2.7B model.",
        "Consider running on a GPU or using a smaller model for testing.",
        sep="\n",
    )


In [None]:
# Fetch the pretrained RetNet checkpoint and its tokenizer; weights are requested in bfloat16.
print("Loading RetNet-2.7B model...")
model, tokenizer = load_retnet_model(model_name="fla-hub/retnet-2.7B-100B", device=device, torch_dtype=torch.bfloat16)


In [None]:
# Summarise the hyper-parameters reported by the project's config helper;
# any key the helper does not provide is shown as 'Unknown'.
config = get_model_config(model)

print("\n=== Key Configuration ===")
for _label, _config_key in (
    ("Number of layers", "num_layers"),
    ("Number of heads", "num_heads"),
    ("Hidden size", "hidden_size"),
    ("Vocabulary size", "vocab_size"),
    ("Max sequence length", "max_seq_len"),
):
    print(f"{_label}: {config.get(_config_key, 'Unknown')}")


In [None]:
# Walk the module tree three levels deep to see how the layers are organised
print_model_structure(
    model,
    max_depth=3,
)


## 2. Initialize State Extractor and Prepare Test Input


In [None]:
extractor = RetNetStateExtractor(model, verbose=True)

# A short sentence keeps the O(N²) re-running method below affordable
test_text = "The quick brown fox jumps over the lazy dog."
print(f"Input text: '{test_text}'")

inputs = tokenizer(test_text, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

print(f"Token IDs shape: {input_ids.shape}")
print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
print("\nState extractor ready")


In [None]:
# Debug: Inspect cache structure directly.
# Runs one cached forward pass and prints what the model returns per layer, so the
# extraction methods below can be checked against the raw cache layout.
with torch.no_grad():
    outputs = model(input_ids, use_cache=True)

cache = outputs.past_key_values
print(f"Outputs type: {type(outputs)}")
print(f"past_key_values type: {type(cache)}")
print(f"Number of layers: {len(cache)}")

print("\nFirst layer cache:")
first_layer = cache[0]
print(f"Type: {type(first_layer)}")

# The key-based inspection below only makes sense for a dict-like per-layer cache;
# calling .items() on anything else would raise.
first_layer_is_dict = isinstance(first_layer, dict)
print(f"Keys: {first_layer.keys() if first_layer_is_dict else 'Not a dict'}")

if first_layer_is_dict:
    recurrent_state = first_layer.get("recurrent_state")
    # The entry may exist but be None, so check for a tensor before reading .shape
    if isinstance(recurrent_state, torch.Tensor):
        print(f"\nrecurrent_state shape: {recurrent_state.shape}")
        print(f"recurrent_state dtype: {recurrent_state.dtype}")

    print("\nAll available keys in first layer:")
    for key, value in first_layer.items():
        if isinstance(value, torch.Tensor):
            print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
        else:
            print(f"  {key}: {type(value).__name__}")
else:
    print("\nFirst-layer cache is not a dict; skipping per-key inspection.")

## 3. Compare Extraction Methods

We compare three extraction methods:
1. **`extract_final_states`** - Gets only the final state (single forward pass)
2. **`extract_incremental_states_dumb_rerunning`** - Gets all intermediate states (O(N²) - runs the full prefix for each position)
3. **`extract_incremental_states_single_pass`** - Gets all intermediate states (O(N) - efficient incremental processing)


In [None]:
# Method 1: extract_final_states - Final state only (single forward pass)
print("=" * 60)
print("METHOD 1: extract_final_states (final state only)")
print("=" * 60)

# CUDA kernels run asynchronously: synchronize before starting and before stopping the
# clock so the interval covers exactly this call's GPU work. perf_counter is the
# monotonic, high-resolution clock intended for interval timing.
if device == "cuda":
    torch.cuda.synchronize()
start_time = time.perf_counter()
final_states = extractor.extract_final_states(input_ids)
if device == "cuda":
    torch.cuda.synchronize()
time_method1 = time.perf_counter() - start_time

print(f"\nTime: {time_method1:.4f}s")
print(f"Number of layers: {len(final_states)}")
if final_states:
    # final_states maps layer index -> state; show the first layer without assuming its key
    first_layer_state = next(iter(final_states.values()))
    print(f"State shape per layer: {first_layer_state.shape}")


In [None]:
# Method 2: extract_incremental_states_dumb_rerunning - All positions (O(N²) - slow)
print("=" * 60)
print("METHOD 2: extract_incremental_states_dumb_rerunning (O(N²) - slow)")
print("=" * 60)

# Synchronize around the call so asynchronous CUDA work is fully included in the timing.
if device == "cuda":
    torch.cuda.synchronize()
start_time = time.perf_counter()
incremental_states = extractor.extract_incremental_states_dumb_rerunning(input_ids)
if device == "cuda":
    torch.cuda.synchronize()
time_method2 = time.perf_counter() - start_time

print(f"\nTime: {time_method2:.4f}s")
print(f"Number of positions: {len(incremental_states)}")
if incremental_states:
    # incremental_states maps position -> {layer_idx: state}; inspect the earliest position
    # and its first layer without hard-coding whether indices start at 0 or 1.
    first_pos_states = incremental_states[min(incremental_states)]
    print(f"Number of layers per position: {len(first_pos_states)}")
    print(f"State shape at each position: {next(iter(first_pos_states.values())).shape}")

In [None]:
# Method 3: extract_incremental_states_single_pass - All positions (O(N) - efficient)
print("=" * 60)
print("METHOD 3: extract_incremental_states_single_pass (O(N) - efficient)")
print("=" * 60)

# Synchronize around the call so asynchronous CUDA work is fully included in the timing.
if device == "cuda":
    torch.cuda.synchronize()
start_time = time.perf_counter()
single_pass_states = extractor.extract_incremental_states_single_pass(input_ids)
if device == "cuda":
    torch.cuda.synchronize()
time_method3 = time.perf_counter() - start_time

print(f"\nTime: {time_method3:.4f}s")
print(f"Number of positions: {len(single_pass_states)}")
if single_pass_states:
    # single_pass_states maps position -> {layer_idx: state}; inspect the earliest position
    # and its first layer without hard-coding whether indices start at 0 or 1.
    first_pos_states_sp = single_pass_states[min(single_pass_states)]
    print(f"Number of layers per position: {len(first_pos_states_sp)}")
    print(f"State shape at each position: {next(iter(first_pos_states_sp.values())).shape}")

## 4. Compare Results and Verify Correctness

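This section verifies that the three methods agree. The final state from Method 1 must match the last position of Methods 2 and 3, and Methods 2 and 3 must match at every position (within floating-point tolerance).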


In [None]:
# Side-by-side wall-clock summary of the three extraction strategies
banner = "=" * 60
print(banner, "TIMING COMPARISON", banner, sep="\n")

seq_len = input_ids.shape[1]
print(f"\nSequence length: {seq_len} tokens")

for timing_row in (
    f"\nMethod 1 (final state only):    {time_method1:.4f}s",
    f"Method 2 (incremental, O(N²)):  {time_method2:.4f}s",
    f"Method 3 (single-pass, O(N)):   {time_method3:.4f}s",
):
    print(timing_row)

speedup = time_method2 / time_method3
print(f"\nSpeedup (Method 3 vs Method 2): {speedup:.2f}x")


In [None]:
# Verify correctness: compare final states from all methods
print("=" * 60)
print("CORRECTNESS VERIFICATION")
print("=" * 60)

seq_len = input_ids.shape[1]

# Baseline tolerances from the original check. They are loosened automatically when either
# tensor is half precision: the model is loaded in bfloat16, whose resolution (~7.8e-3)
# is far coarser than rtol=1e-4, so identical math done in a different order would
# otherwise be reported as a mismatch. Half-precision values follow the
# torch.testing.assert_close defaults.
BASE_RTOL, BASE_ATOL = 1e-4, 1e-6
_HALF_PRECISION_TOLERANCES = {
    torch.bfloat16: (1.6e-2, 1e-5),
    torch.float16: (1e-3, 1e-5),
}


def _state_tolerances(a, b):
    """Return (rtol, atol) suitable for comparing tensors `a` and `b`.

    Uses the baseline tolerances, widened to the half-precision defaults when
    either tensor is float16/bfloat16.
    """
    rtol, atol = BASE_RTOL, BASE_ATOL
    for dtype in (a.dtype, b.dtype):
        half_rtol, half_atol = _HALF_PRECISION_TOLERANCES.get(dtype, (0.0, 0.0))
        rtol, atol = max(rtol, half_rtol), max(atol, half_atol)
    return rtol, atol


def _compare_layer_states(reference, candidate):
    """Compare two {layer_idx: tensor} dicts layer by layer.

    Args:
        reference: per-layer states treated as ground truth.
        candidate: per-layer states to check against `reference`.

    Returns:
        List of (layer_idx, max_abs_diff) for every layer of `reference` that does not
        match. A layer missing from `candidate`, or with a different shape, is reported
        with max_abs_diff = inf instead of raising.
    """
    mismatched_layers = []
    for layer_idx, ref_state in reference.items():
        cand_state = candidate.get(layer_idx)
        if cand_state is None or cand_state.shape != ref_state.shape:
            mismatched_layers.append((layer_idx, float("inf")))
            continue
        # Compare in float32 on a common device: torch.allclose rejects mixed dtypes/devices.
        ref32 = ref_state.detach().float()
        cand32 = cand_state.detach().float().to(ref32.device)
        rtol, atol = _state_tolerances(ref_state, cand_state)
        if not torch.allclose(ref32, cand32, rtol=rtol, atol=atol):
            mismatched_layers.append((layer_idx, (ref32 - cand32).abs().max().item()))
    return mismatched_layers


# Compare final state from method 1 with last position from method 2 and 3
print("\n1. Comparing final state (Method 1) vs last position (Method 2):")
mismatches_1_vs_2 = _compare_layer_states(final_states, incremental_states.get(seq_len, {}))
for layer_idx, max_diff in mismatches_1_vs_2:
    print(f"   Layer {layer_idx}: MISMATCH (max diff: {max_diff:.2e})")
all_match_1_vs_2 = not mismatches_1_vs_2
print(f"   All layers match: {all_match_1_vs_2}")

print("\n2. Comparing final state (Method 1) vs last position (Method 3):")
mismatches_1_vs_3 = _compare_layer_states(final_states, single_pass_states.get(seq_len, {}))
for layer_idx, max_diff in mismatches_1_vs_3:
    print(f"   Layer {layer_idx}: MISMATCH (max diff: {max_diff:.2e})")
all_match_1_vs_3 = not mismatches_1_vs_3
print(f"   All layers match: {all_match_1_vs_3}")

print("\n3. Comparing all positions (Method 2 vs Method 3):")
mismatches = []
for pos, states_m2 in incremental_states.items():
    for layer_idx, max_diff in _compare_layer_states(states_m2, single_pass_states.get(pos, {})):
        mismatches.append((pos, layer_idx, max_diff))
# Positions produced only by Method 3 are also a disagreement between the methods
extra_positions = sorted(set(single_pass_states) - set(incremental_states))

if extra_positions:
    print(f"   Positions only present in Method 3: {extra_positions}")
if mismatches:
    print(f"   Found {len(mismatches)} mismatches:")
    for pos, layer_idx, diff in mismatches[:5]:
        print(f"     Position {pos}, Layer {layer_idx}: max diff = {diff:.2e}")
all_match_2_vs_3 = not mismatches and not extra_positions
print(f"   All positions match: {all_match_2_vs_3}")


In [None]:
# Visualize timing comparison
fig, ax = plt.subplots(figsize=(10, 6))

methods = ['Method 1\n(final only)', 'Method 2\n(incremental O(N²))', 'Method 3\n(single-pass O(N))']
times = [time_method1, time_method2, time_method3]
colors = ['#2ecc71', '#e74c3c', '#3498db']

bars = ax.bar(methods, times, color=colors, edgecolor='black', linewidth=1.5)

# Offset the value labels by a fraction of the tallest bar rather than a fixed 0.01 s:
# for millisecond-scale timings a fixed offset puts labels above the y-limit set below.
label_offset = 0.01 * max(times)
for bar, t in zip(bars, times):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset,
            f'{t:.3f}s', ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.set_ylabel('Time (seconds)', fontsize=12)
ax.set_title(f'State Extraction Methods Comparison\n(Sequence length: {seq_len} tokens)', fontsize=14)
ax.set_ylim(0, max(times) * 1.2)

plt.tight_layout()
plt.show()

print("\nSummary:")
print("- For final state only: Use Method 1 (fastest)")
print("- For all intermediate states: Use Method 3 (single-pass)")
# Guard the ratio so a (timer-resolution) zero for Method 3 cannot abort the cell
if time_method3 > 0:
    print(f"- Method 3 is {time_method2 / time_method3:.1f}x faster than Method 2 for this sequence")
else:
    print("- Method 3 finished below timer resolution; speedup vs Method 2 not computable")
