In [1]:
import sys
sys.path.insert(0, './mamba')
import os
import torch
import time
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.generation import InferenceParams
from transformers import AutoTokenizer

# Setup
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
MODEL_NAME = "state-spaces/mamba-2.8b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {MODEL_NAME}...")
model = MambaLMHeadModel.from_pretrained(MODEL_NAME, device=DEVICE)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

print(f"✓ Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")

  from .autonotebook import tqdm as notebook_tqdm


Loading state-spaces/mamba-2.8b...


  return torch.load(resolved_archive_file, map_location=mapped_device)


✓ Model loaded: 2,768,345,600 parameters


In [2]:
# Test texts
trunk_1 = "This paper addresses the challenges of running multiple machine learning"
trunk_2 = "models on resource-constrained edge devices, which are often equipped with a variety of processors like CPUs, GPUs, and DSPs. The primary goal is"

# Tokenize
tokens_1 = tokenizer(trunk_1, return_tensors="pt", return_attention_mask=False)
tokens_2 = tokenizer(trunk_2, return_tensors="pt", return_attention_mask=False)
input_ids_1 = tokens_1.input_ids.to(DEVICE)
input_ids_2 = tokens_2.input_ids.to(DEVICE)

print(f"Trunk 1: {input_ids_1.shape} tokens")
print(f"Trunk 2: {input_ids_2.shape} tokens")

Trunk 1: torch.Size([1, 10]) tokens
Trunk 2: torch.Size([1, 32]) tokens


## 1. Prefill trunk_1 and get cache

In [3]:
print("=== Test 1: Prefill trunk_1 ===")

# Initialize cache
inference_params = InferenceParams(
    max_seqlen=2048, max_batch_size=1, seqlen_offset=0, key_value_memory_dict={}
)

# Process trunk 1
with torch.no_grad():
    output_1 = model(input_ids_1, inference_params=inference_params)

# Update offset (Mamba requires manual update)
inference_params.seqlen_offset += input_ids_1.shape[1]

# Extract cache
cache_after_1 = {}
for layer_idx, (conv_state, ssm_state) in inference_params.key_value_memory_dict.items():
    cache_after_1[layer_idx] = {
        'conv_state': conv_state.clone().cpu(),
        'ssm_state': ssm_state.clone().cpu()
    }

print(f"✓ Processed trunk_1: seqlen_offset={inference_params.seqlen_offset}")
print(f"✓ Cache extracted: {len(cache_after_1)} layers")
print(f"Sample cache shape: {list(cache_after_1.values())[0]['conv_state'].shape}")

=== Test 1: Prefill trunk_1 ===
✓ Processed trunk_1: seqlen_offset=10
✓ Cache extracted: 64 layers
Sample cache shape: torch.Size([1, 5120, 4])


## 2. Continue with trunk_2 from existing cache

In [4]:
print("=== Test 2: Continue with trunk_2 ===")

# Save cache before
cache_before_2 = {k: {'conv_state': v[0].clone().cpu(), 'ssm_state': v[1].clone().cpu()} 
                 for k, v in inference_params.key_value_memory_dict.items()}

# Process trunk_2 token by token (Mamba requires this with cache)
with torch.no_grad():
    outputs_2 = []
    for i in range(input_ids_2.shape[1]):
        step_output = model(input_ids_2[:, i:i+1], inference_params=inference_params)
        outputs_2.append(step_output.logits if hasattr(step_output, 'logits') else step_output)
        inference_params.seqlen_offset += 1
    
    output_2 = type('MockOutput', (), {'logits': torch.cat(outputs_2, dim=1)})()

print(f"✓ Processed trunk_2: seqlen_offset={inference_params.seqlen_offset}")
print(f"✓ Final sequence length: {input_ids_1.shape[1] + input_ids_2.shape[1]}")
print(f"✓ Output shape: {output_2.logits.shape}")

=== Test 2: Continue with trunk_2 ===
✓ Processed trunk_2: seqlen_offset=42
✓ Final sequence length: 42
✓ Output shape: torch.Size([1, 32, 50280])


## 3. Decode from cache

In [5]:
print("=== Test 3: Decode from cache ===")

# Generate tokens
max_tokens = 20
generated_tokens = []
current_logits = output_2.logits[:, -1:, :]  # Start with last token
current_token = None
decode_start_time = time.time()

for step in range(max_tokens):
    if step == 0:
        # First step: get last token from previous output
        next_token_logits = output_2.logits[:, -1:, :]
    else:
        # Subsequent steps: single token forward pass
        with torch.no_grad():
            # Prepare input in the correct format [batch, seq_len=1]
            if current_token.dim() == 0:  # scalar
                input_tensor = current_token.unsqueeze(0).unsqueeze(0)
            elif current_token.dim() == 1:  # [seq_len]
                input_tensor = current_token.unsqueeze(0)
            elif current_token.dim() == 2:
                if current_token.shape[0] == 1 and current_token.shape[1] == 1:
                    input_tensor = current_tensor  # already [1, 1]
                else:
                    input_tensor = current_tensor[0:1, -1:]  # ensure [1, 1]
            else:
                raise ValueError(f"Unexpected token shape: {current_token.shape}")

            # Ensure the input has exactly shape [batch_size=1, seq_len=1]
            if input_tensor.shape != (1, 1):
                input_tensor = input_tensor.reshape(1, 1)

            step_output = model(
                input_tensor,
                inference_params=inference_params
            )

            # Handle different output formats
            if hasattr(step_output, 'logits'):
                next_token_logits = step_output.logits
            elif isinstance(step_output, tuple):
                next_token_logits = step_output[0]
            else:
                # Assume it's the logits tensor directly
                next_token_logits = step_output

    # Sample next token
    next_token = torch.argmax(next_token_logits, dim=-1)
    generated_tokens.append(next_token.item())

    # Ensure current_token is a scalar for next iteration
    current_token = next_token.squeeze()  # This should be a scalar

    # Update seqlen_offset after each generated token
    inference_params.seqlen_offset += 1

    if step % 5 == 0:
        token_text = tokenizer.decode(next_token.item())
        print(f"Step {step + 1}: token {next_token.item()} = '{tokenizer.decode(next_token.item())}'")

    # Check if we hit an end token
    if next_token.item() in [tokenizer.eos_token_id, 0]:
        print(f"  End token reached at step {step + 1}")
        break

total_decode_time = time.time() - decode_start_time

print(f"\n✓ Generated {len(generated_tokens)} tokens")
print(f"Total decode time: {total_decode_time*1000:.2f} ms")
print(f"Time per token: {total_decode_time/len(generated_tokens)*1000:.2f} ms/token")

=== Test 3: Decode from cache ===
Step 1: token 281 = ' to'
Step 6: token 476 = ' can'
Step 11: token 4715 = ' learning'
Step 16: token 15 = '.'

✓ Generated 20 tokens
Total decode time: 531.29 ms
Time per token: 26.56 ms/token


## 4. Show results

In [6]:
# Combine and decode
full_input = tokenizer.decode(input_ids_1[0]) + tokenizer.decode(input_ids_2[0])
generated_text = tokenizer.decode(generated_tokens)

print("=== FINAL RESULTS ===")
print(f"Input: '{full_input}'")
print(f"Generated: '{generated_text}'")
print(f"\n✅ Mamba cache API test successful!")
print(f"✅ All 3 tests passed:")
print(f"   1. Prefill trunk_1 ✓")
print(f"   2. Continue with trunk_2 ✓")
print(f"   3. Decode from cache ✓")

=== FINAL RESULTS ===
Input: 'This paper addresses the challenges of running multiple machine learningmodels on resource-constrained edge devices, which are often equipped with a variety of processors like CPUs, GPUs, and DSPs. The primary goal is'
Generated: ' to develop a framework that can efficiently run multiple machine learning models on edge devices. The framework is based'

✅ Mamba cache API test successful!
✅ All 3 tests passed:
   1. Prefill trunk_1 ✓
   2. Continue with trunk_2 ✓
   3. Decode from cache ✓
