# Pseudo-Mamba State Lab: Black Box Mapping Demo

This notebook demonstrates how to use the **Pseudo-Mamba State Lab** tools to extract, visualize, and analyze the internal states of a Mamba model.

**Tools Used:**

- `trace_model.py`: extracts full state trajectories from all layers.
- `visualize_states.py`: visualizes state evolution (heatmaps, norms, PCA).
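The "state" traced in this notebook is the hidden state `h_t` of each layer's selective SSM. As a rough mental model, the NumPy sketch below runs a toy diagonal selective scan over one sequence and keeps every `h_t`, producing the kind of `(seq_len, d_inner, d_state)` trajectory the tracer is meant to expose. It follows the simplified recurrence of Mamba's reference scan (`h_t = exp(Δ_t·A) * h_{t-1} + Δ_t·B_t·x_t`, `y_t = C_t·h_t`) but omits the convolution, gating, and skip term `D`. The function name `selective_scan_states` is hypothetical, and this is not the `mamba_ssm` kernel.

```python
import numpy as np


def selective_scan_states(x, delta, A, B, C):
    """Toy diagonal selective scan that records every hidden state h_t.

    x:     (seq_len, d_inner)  input sequence for one example
    delta: (seq_len, d_inner)  positive, input-dependent step sizes
    A:     (d_inner, d_state)  negative real diagonal state matrix
    B, C:  (seq_len, d_state)  input-dependent projections
    Returns y (seq_len, d_inner) and states (seq_len, d_inner, d_state).
    """
    seq_len, d_inner = x.shape
    d_state = A.shape[1]
    h = np.zeros((d_inner, d_state))
    ys = np.empty((seq_len, d_inner))
    states = np.empty((seq_len, d_inner, d_state))
    for t in range(seq_len):
        A_bar = np.exp(delta[t][:, None] * A)                   # discretized decay, in (0, 1)
        Bx = delta[t][:, None] * B[t][None, :] * x[t][:, None]  # simplified input term delta*B*x
        h = A_bar * h + Bx                                      # h_t = A_bar * h_{t-1} + delta*B*x
        states[t] = h
        ys[t] = h @ C[t]                                        # y_t = C_t h_t, one value per channel
    return ys, states


# Usage with random parameters: 8 steps, 4 channels, 3 state entries per channel.
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))
delta = np.log1p(np.exp(rng.standard_normal((8, 4))))  # softplus keeps step sizes positive
A = -np.exp(rng.standard_normal((4, 3)))
B = rng.standard_normal((8, 3))
C = rng.standard_normal((8, 3))
y, states = selective_scan_states(x, delta, A, B, C)
print(states.shape)  # (8, 4, 3)
```

Because `A` is negative and `Δ_t` is positive, each `exp(Δ_t·A)` lies in (0, 1), so how fast a channel forgets depends on the input through `Δ_t`. That input-dependent decay is what the norm, heatmap, and PCA plots let you observe.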

```python
import torch
import matplotlib.pyplot as plt
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from trace_model import trace_model
from visualize_states import plot_state_heatmap, plot_state_norm, analyze_state_pca

%matplotlib inline
```

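```python
# 1. Set up a small Mamba model for demonstration.
# MambaConfig has no d_state/expand fields; per-block SSM settings are passed via ssm_cfg.
config = MambaConfig(
    d_model=256,
    n_layer=4,
    vocab_size=1000,
    ssm_cfg={'d_state': 16, 'expand': 2},
)

# Note: mamba_ssm's fused CUDA/Triton kernels generally require a GPU; the CPU fallback may not run.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MambaLMHeadModel(config).to(device)
print(f'Model created on {device}')
```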

```python
# 2. Create a dummy batch of random token IDs (drawn from the model's vocabulary).
batch_size = 2
seq_len = 64
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)
print(f'Input shape: {input_ids.shape}')
```
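Storing *full* trajectories from every layer is memory-hungry, so it is worth estimating the size before tracing. The helper below (a hypothetical name) counts only the SSM states `h_t`. It assumes the inner width is `expand * d_model`, as in Mamba, and float32 storage, and it ignores convolution states and other activations. With this notebook's settings (2 sequences of 64 tokens, 4 layers, `d_model=256`, `expand=2`, `d_state=16`) it comes to 16 MiB, growing linearly with `seq_len`.

```python
def ssm_trajectory_size(batch_size, seq_len, n_layer, d_model, expand, d_state, bytes_per_value=4):
    """Values and bytes needed to keep every SSM state h_t for every layer and step."""
    d_inner = expand * d_model               # Mamba's expanded inner width
    per_step_per_layer = d_inner * d_state   # one layer's state at one time step
    n_values = batch_size * seq_len * n_layer * per_step_per_layer
    return n_values, n_values * bytes_per_value


n_values, n_bytes = ssm_trajectory_size(
    batch_size=2, seq_len=64, n_layer=4, d_model=256, expand=2, d_state=16
)
print(n_values, n_bytes / 2**20)  # 4194304 16.0
```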

```python
# 3. Run the model with tracing.
print('Tracing model execution...')
trace = trace_model(model, input_ids)
print(f'Captured traces for {len(trace.layer_traces)} layers.')
```
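`trace_model` is the lab's own function, and its internals are not shown here. To make the structure of the returned object concrete, here is a hypothetical stand-in. `LayerTrace`, `ModelTrace`, and `toy_trace` are invented names. Each layer is a simplified time-invariant diagonal SSM rather than a Mamba block. Storing states as a `(batch, seq_len, d_inner, d_state)` array is our assumption about what `trace.layer_traces` holds. The result is bound to `demo` so it does not overwrite the notebook's `trace`.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class LayerTrace:
    """Hypothetical per-layer record of the SSM state at every step."""
    ssm_states: np.ndarray  # (batch, seq_len, d_inner, d_state)


@dataclass
class ModelTrace:
    """Hypothetical container mirroring `trace.layer_traces` in this notebook."""
    layer_traces: list = field(default_factory=list)


def toy_trace(x, n_layer, d_state, delta=0.1, seed=0):
    """Run a stack of toy time-invariant diagonal SSM layers and record all states.

    x: (batch, seq_len, d_inner) input; each layer's output feeds the next layer.
    """
    rng = np.random.default_rng(seed)
    batch, seq_len, d_inner = x.shape
    result = ModelTrace()
    for _ in range(n_layer):
        A = -np.exp(rng.standard_normal((d_inner, d_state)))  # negative rates -> stable decay
        B = 0.1 * rng.standard_normal((d_inner, d_state))
        C = 0.1 * rng.standard_normal((d_inner, d_state))
        A_bar = np.exp(delta * A)
        h = np.zeros((batch, d_inner, d_state))
        states = np.empty((batch, seq_len, d_inner, d_state))
        y = np.empty_like(x)
        for t in range(seq_len):
            h = A_bar * h + delta * B * x[:, t, :, None]  # same update rule at every step
            states[:, t] = h
            y[:, t] = (h * C).sum(axis=-1)                  # read out one value per channel
        result.layer_traces.append(LayerTrace(ssm_states=states))
        x = y
    return result


demo = toy_trace(np.random.default_rng(1).standard_normal((2, 64, 8)), n_layer=4, d_state=16)
print(len(demo.layer_traces), demo.layer_traces[0].ssm_states.shape)  # 4 (2, 64, 8, 16)
```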

```python
# 4. Visualize state norms.
# This shows how the magnitude of the SSM state evolves over time across layers.
plot_state_norm(trace)
```
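A state-norm curve reduces each time step's state to a single number, its L2 norm. Below is a minimal NumPy sketch of that reduction, assuming states are stored as `(batch, seq_len, d_inner, d_state)`. The helper names are hypothetical, and `plot_state_norm` may compute or average differently.

```python
import numpy as np


def state_norms(ssm_states):
    """L2 norm of the full state at each step: (batch, seq_len, d_inner, d_state) -> (batch, seq_len)."""
    batch, seq_len = ssm_states.shape[:2]
    return np.linalg.norm(ssm_states.reshape(batch, seq_len, -1), axis=-1)


def per_layer_norm_curves(per_layer_states):
    """Batch-averaged norm curve for each layer: returns (n_layer, seq_len)."""
    return np.stack([state_norms(s).mean(axis=0) for s in per_layer_states])


# A state that halves every step produces a norm curve that halves too.
decay = 0.5 ** np.arange(5)
halving = np.ones((1, 5, 4, 4)) * decay[None, :, None, None]
print(state_norms(halving)[0].tolist())  # [4.0, 2.0, 1.0, 0.5, 0.25]
```

A curve that keeps rising can hint that a layer is accumulating input, while one that quickly decays toward zero can hint that it forgets fast.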

```python
# 5. Visualize state heatmaps.
# Look at the internal SSM state of layer 0, dimension 0.
layer_idx = 0
dim_idx = 0
plot_state_heatmap(trace.layer_traces[layer_idx], layer_idx, dim_idx, state_type='ssm')
```
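A heatmap for one dimension is a 2-D slice of the trajectory, with time on one axis and the `d_state` entries on the other. The sketch below assumes that `dim_idx` selects a channel on the `d_inner` axis and that states are stored as `(batch, seq_len, d_inner, d_state)`. Both are assumptions about `plot_state_heatmap`, and `state_slice` and `plot_slice` are hypothetical helpers.

```python
import numpy as np


def state_slice(ssm_states, dim_idx, batch_idx=0):
    """One channel's state over time, shaped (d_state, seq_len) so time runs left to right."""
    return ssm_states[batch_idx, :, dim_idx, :].T


def plot_slice(slice_2d, title=""):
    """Draw the slice with a diverging colormap centered on zero."""
    import matplotlib.pyplot as plt  # lazy import: state_slice works without matplotlib

    vmax = float(np.abs(slice_2d).max()) or 1.0  # avoid a zero-width color range
    fig, ax = plt.subplots(figsize=(8, 3))
    im = ax.imshow(slice_2d, aspect="auto", origin="lower", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
    ax.set_xlabel("time step")
    ax.set_ylabel("state index")
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    return fig


demo_states = np.arange(2 * 6 * 3 * 4, dtype=float).reshape(2, 6, 3, 4)
print(state_slice(demo_states, dim_idx=1).shape)  # (4, 6)
```

Centering the colormap on zero makes sign flips in the state stand out, which a default sequential colormap would hide.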

```python
# 6. PCA trajectory analysis.
# Project the state trajectory into a low-dimensional space.
analyze_state_pca(trace.layer_traces[0], layer_idx=0)
```
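PCA on a trajectory treats each time step's flattened state as one sample. Here is a minimal NumPy sketch that uses an SVD of the centered data, assuming states are stored as `(batch, seq_len, d_inner, d_state)`. `pca_trajectory` is a hypothetical helper, and `analyze_state_pca` may be implemented differently (for example with scikit-learn).

```python
import numpy as np


def pca_trajectory(ssm_states, n_components=2, batch_idx=0):
    """Project one sequence's state trajectory onto its top principal components.

    ssm_states: (batch, seq_len, d_inner, d_state).
    Returns coordinates (seq_len, n_components) and each component's explained-variance ratio.
    """
    X = ssm_states[batch_idx].reshape(ssm_states.shape[1], -1)  # one row per time step
    X = X - X.mean(axis=0, keepdims=True)                       # center every feature
    _, S, Vt = np.linalg.svd(X, full_matrices=False)            # principal axes are rows of Vt
    coords = X @ Vt[:n_components].T
    var = S ** 2
    return coords, var[:n_components] / var.sum()


# A trajectory that moves along a single direction is captured entirely by PC1.
direction = np.random.default_rng(0).standard_normal((4, 3))
line = np.arange(10)[:, None, None] * direction  # (10, 4, 3)
coords, ratio = pca_trajectory(line[None])       # add a batch axis -> (1, 10, 4, 3)
print(coords.shape, round(float(ratio[0]), 6))   # (10, 2) 1.0
```

The explained-variance ratio tells you how much of the trajectory a 2-D plot actually shows. A low total means the projection hides much of the dynamics.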

## Conclusion

You have successfully mapped the 'Black Box'! 🕵️‍♂️

Use these tools to:

- Debug why a model fails on specific tokens.
- Analyze how memory is maintained over long sequences.
- Correlate internal states with specific input patterns.
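For the last item, a simple first step is to group per-position state norms by token ID and compare the averages. `mean_norm_per_token` is a hypothetical helper. It assumes `norms` has the same `(batch, seq_len)` shape as `input_ids`, for example one layer's per-step L2 norms of the SSM state.

```python
import numpy as np


def mean_norm_per_token(input_ids, norms):
    """Average state norm at the positions where each token ID appears.

    input_ids, norms: arrays of shape (batch, seq_len). Returns {token_id: mean_norm}.
    """
    ids = np.asarray(input_ids).ravel()
    vals = np.asarray(norms, dtype=float).ravel()
    uniq, inverse = np.unique(ids, return_inverse=True)
    sums = np.bincount(inverse, weights=vals)  # total norm per distinct token
    counts = np.bincount(inverse)              # occurrences per distinct token
    return dict(zip(uniq.tolist(), (sums / counts).tolist()))


ids = np.array([[5, 7, 5, 9]])
norms = np.array([[1.0, 2.0, 3.0, 4.0]])
print(mean_norm_per_token(ids, norms))  # {5: 2.0, 7: 2.0, 9: 4.0}
```

With random inputs like the ones in this demo, differences between tokens are mostly noise. The comparison becomes informative on real text, where specific tokens recur in meaningful contexts.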