In [None]:
import torch
from transformers import AutoTokenizer
import json
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from mamba_ssm.utils.generation import InferenceParams
import os
import sys

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def compute_min_max_state_norms(tensor: torch.Tensor) -> tuple[float, float]:

    norms = torch.norm(tensor, p=2, dim=-1)  # shape: [batch_size, num_heads, num_states]
    
    if norms.shape[0] == 1:
        norms = norms.squeeze(0)  # shape: [num_heads, num_states]
    
    min_norm = torch.min(norms).item()  # global minimum norm
    max_norm = torch.max(norms).item()  # global maximum norm
    
    return min_norm, max_norm
model_name = "state-spaces/mamba2-1.3b"

if "state-spaces/mamba" in model_name and "hf" not in "state-spaces/mamba":
    tokenizer = AutoTokenizer.from_pretrained(
            "EleutherAI/gpt-neox-20b", model_max_length=sys.maxsize, trust_remote_code=True
        )
    model = MambaLMHeadModel.from_pretrained(model_name).cuda()
else:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
model.eval()

name = 'pg19'
data_path = f'data/{name}'
name = os.path.join(data_path, name)

  return torch.load(resolved_archive_file, map_location=mapped_device)


In [3]:
with open(f"{name}.json", "r", encoding="utf-8") as f:
    loaded_texts = json.load(f)

In [None]:
longest = loaded_texts

print(f"Longest length (chars): {len(longest)}")
tokens = tokenizer(longest, truncation=False, return_tensors="pt")
inputs = tokens
second_diff = 0
second_layer_outlier = 0
second_state_outlier = 0
second_mmax = 0
second_mmin = 0

max_list = []
min_list = []

for max_l in [1024, 2048, 4096, 8192, 16384, 32768, 65536][::-1]:
    seq_len = inputs.input_ids[:, :max_l].shape[1]
    print(f'seq_len: {seq_len}')
    inference_params = InferenceParams(max_seqlen=seq_len, max_batch_size=1)

    # Run model
    start = 0
    with torch.no_grad():
        _ = model(input_ids=inputs.input_ids[:, start:max_l+start].cuda(), inference_params=inference_params)

    print("Layer-wise SSM state norms:")
    layer_states = inference_params.key_value_memory_dict

    diff = 0
    layer_outlier = 0
    state_outlier = 0
    mmax = 0
    mmin = 0

    all_diffs = []

    for layer_idx, state_tuple in layer_states.items():
        for state_idx, state in enumerate(state_tuple):
            if 'mamba2' in model_name:
                min_val, max_val = compute_min_max_state_norms(state[-1])
            else:
                min_val, max_val = compute_min_max_state_norms(state[0])
            
            cur_diff = max_val - min_val
            all_diffs.append((cur_diff, layer_idx, state_idx, min_val, max_val))
            max_list.append(max_val)
            min_list.append(min_val)

            if cur_diff > diff:
                second_diff = diff
                second_layer_outlier = layer_outlier
                second_state_outlier = state_outlier
                second_mmax = mmax
                second_mmin = mmin

                diff = cur_diff
                layer_outlier = layer_idx
                state_outlier = state_idx
                mmax = max_val
                mmin = min_val

    mean_max = sum(max_list) / len(max_list) if max_list else 0
    mean_min = sum(min_list) / len(min_list) if min_list else 0

    print(f'maxlength: {max_l}-layer idx-{layer_outlier}-state idx-{state_outlier}_min:{mmin:.4f}-max:{mmax:.4f}')
    print(f'2nd max diff layer idx-{second_layer_outlier}-state idx-{second_state_outlier}_min:{second_mmin:.4f}-max:{second_mmax:.4f}')
    print(f'mean of all max values: {mean_max:.4f}')
    print(f'mean of all min values: {mean_min:.4f}')
    print(f'-----------------{max_l}-----------------------')

Longest length (chars): 1665152
seq_len: 65536
Layer-wise SSM state norms:
maxlength: 65536-layer idx-32-state idx-1_min:0.0002-max:1242.7589
2nd max diff layer idx-8-state idx-1_min:0.0021-max:271.7547
mean of all max values: 35.4999
mean of all min values: 0.0616
-----------------65536-----------------------
seq_len: 32768
Layer-wise SSM state norms:
maxlength: 32768-layer idx-32-state idx-1_min:0.0003-max:1201.5530
2nd max diff layer idx-8-state idx-1_min:0.0027-max:208.3220
mean of all max values: 34.3853
mean of all min values: 0.0674
-----------------32768-----------------------
seq_len: 16384
Layer-wise SSM state norms:
maxlength: 16384-layer idx-32-state idx-1_min:0.0005-max:1241.3979
2nd max diff layer idx-28-state idx-1_min:0.0047-max:227.3754
mean of all max values: 34.1940
mean of all min values: 0.0674
-----------------16384-----------------------
seq_len: 8192
Layer-wise SSM state norms:
maxlength: 8192-layer idx-32-state idx-1_min:0.0004-max:876.0833
2nd max diff layer i