# Example 5: Special Token Mask Detection

This notebook demonstrates how to:
1. Load a language model
2. Create a `ModelInputDetector` with special token mask detection enabled
3. Use the model's automatic special token detection OR provide custom special token IDs
4. Run inference and capture special token masks
5. Visualize the mask and verify its correctness
6. Save the mask to the store

The special token mask is a boolean mask (`True`/1 for special tokens, `False`/0 for regular tokens) with the same shape as `input_ids`. It is useful for:
- Filtering out special tokens during analysis
- Understanding tokenization behavior
- Creating attention masks that exclude special tokens
- Analyzing model behavior on special vs regular tokens
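
As a rough illustration of what such a mask contains and how it can feed an attention mask, the sketch below rebuilds it in plain Python. The helper names `build_special_token_mask` and `exclude_special_tokens` are hypothetical and not part of `amber`; the token IDs are illustrative values in the style of this notebook's tokenizer (pad = 2, bos = 1).

```python
def build_special_token_mask(input_ids, special_ids):
    """Return a mask shaped like input_ids (True = special token)."""
    special = set(special_ids)
    return [[tok in special for tok in row] for row in input_ids]


def exclude_special_tokens(attention_mask, special_mask):
    """Keep attention only on positions that are attended AND not special."""
    return [
        [int(bool(att) and not is_special) for att, is_special in zip(att_row, sp_row)]
        for att_row, sp_row in zip(attention_mask, special_mask)
    ]


# Illustrative batch with one left-padded sequence (assumed IDs: pad=2, bos=1).
input_ids = [[2, 2, 1, 739, 1437, 289]]
attention_mask = [[0, 0, 1, 1, 1, 1]]
special_ids = [1, 4, 0, 2, 3, 5, 6]

special_mask = build_special_token_mask(input_ids, special_ids)
content_mask = exclude_special_tokens(attention_mask, special_mask)

print(special_mask)
print(content_mask)
```

With tensors, the same membership test would typically be vectorized rather than written as nested loops; the list version is used here only to keep the idea visible.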


In [10]:
%load_ext autoreload
%autoreload 2

import torch
from pathlib import Path
import numpy as np

from amber.hooks import ModelInputDetector
from amber.language_model.language_model import LanguageModel
from amber.store.local_store import LocalStore

print("✅ Imports completed")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
✅ Imports completed


In [3]:
MODEL_ID = "speakleash/Bielik-1.5B-v3.0-Instruct"
STORE_DIR = Path("store")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TEST_TEXTS = [
    "Hello world! This is a test.",
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is fascinating.",
]

print("🚀 Starting Special Token Mask Example")
print(f"📱 Using device: {DEVICE}")
print(f"🔧 Model: {MODEL_ID}")
print(f"📝 Number of test texts: {len(TEST_TEXTS)}")
print()

STORE_DIR.mkdir(parents=True, exist_ok=True)
print("✅ Output directories created")


🚀 Starting Special Token Mask Example
📱 Using device: cpu
🔧 Model: speakleash/Bielik-1.5B-v3.0-Instruct
📝 Number of test texts: 3

✅ Output directories created


In [4]:
print("📥 Loading language model...")

store = LocalStore(STORE_DIR)
lm = LanguageModel.from_huggingface(MODEL_ID, store=store)
lm.model.to(DEVICE)

print(f"✅ Model loaded: {lm.model_id}")
print(f"📱 Device: {DEVICE}")
print(f"📁 Store location: {lm.context.store.base_path}")
print()

tokenizer = lm.tokenizer
print("🔍 Tokenizer special tokens:")
special_token_attrs = ['pad_token_id', 'eos_token_id', 'bos_token_id', 'unk_token_id', 
                       'cls_token_id', 'sep_token_id', 'mask_token_id']
special_tokens = {}
for attr in special_token_attrs:
    token_id = getattr(tokenizer, attr, None)
    if token_id is not None:
        special_tokens[attr] = token_id
        print(f"  {attr}: {token_id}")

if hasattr(tokenizer, 'all_special_ids'):
    print(f"  all_special_ids: {tokenizer.all_special_ids}")


📥 Loading language model...
✅ Model loaded: speakleash_Bielik-1.5B-v3.0-Instruct
📱 Device: cpu
📁 Store location: store

🔍 Tokenizer special tokens:
  pad_token_id: 2
  eos_token_id: 4
  bos_token_id: 1
  unk_token_id: 0
  all_special_ids: [1, 4, 0, 2, 3, 5, 6]


## Option 1: Auto-detect Special Tokens from Model

When `special_token_ids` is left as `None`, the detector automatically extracts the special token IDs from the model's tokenizer or config.
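
To make the idea of auto-detection concrete, here is a minimal sketch of how special token IDs could be gathered from a tokenizer-like object. It is not `amber`'s actual implementation, which may differ and may also consult the model config; `collect_special_token_ids` and `fake_tokenizer` are hypothetical names, and the stand-in object simply mirrors the values printed for the tokenizer above.

```python
from types import SimpleNamespace

# Attribute names commonly exposed by Hugging Face tokenizers.
SPECIAL_TOKEN_ATTRS = (
    "pad_token_id", "eos_token_id", "bos_token_id", "unk_token_id",
    "cls_token_id", "sep_token_id", "mask_token_id",
)


def collect_special_token_ids(tokenizer):
    """Hypothetical helper: gather special token IDs from a tokenizer-like object."""
    # Start from all_special_ids when the tokenizer provides it.
    ids = set(getattr(tokenizer, "all_special_ids", None) or [])
    # Add any individually defined special token IDs.
    for attr in SPECIAL_TOKEN_ATTRS:
        token_id = getattr(tokenizer, attr, None)
        if token_id is not None:
            ids.add(token_id)
    return sorted(ids)


# Stand-in object (assumption) mirroring the tokenizer values printed earlier.
fake_tokenizer = SimpleNamespace(
    pad_token_id=2, eos_token_id=4, bos_token_id=1, unk_token_id=0,
    all_special_ids=[1, 4, 0, 2, 3, 5, 6],
)

auto_ids = collect_special_token_ids(fake_tokenizer)
print(auto_ids)
```

Collecting into a set keeps the result free of duplicates when the same ID appears both in `all_special_ids` and in an individual attribute.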


In [5]:
print("🔧 Creating ModelInputDetector with auto-detection of special tokens...")

layer_signature = "model_inputs_with_mask"
if layer_signature not in lm.layers.name_to_layer:
    lm.layers.name_to_layer[layer_signature] = lm.model
    print(f"📝 Added '{layer_signature}' to layers registry")

input_detector = ModelInputDetector(
    layer_signature=layer_signature,
    hook_id="model_input_detector_with_mask",
    save_input_ids=True,
    save_attention_mask=False,
    save_special_token_mask=True,
    special_token_ids=None,
)

hook_id = lm.layers.register_hook(layer_signature, input_detector)

print("✅ Detector attached to model")
print(f"🆔 Detector ID: {input_detector.id}")
print("💾 Will save: input_ids, special_token_mask")


🔧 Creating ModelInputDetector with auto-detection of special tokens...
📝 Added 'model_inputs_with_mask' to layers registry
✅ Detector attached to model
🆔 Detector ID: model_input_detector_with_mask
💾 Will save: input_ids, special_token_mask


In [7]:
print("🚀 Running inference...")
print(f"📝 Processing {len(TEST_TEXTS)} texts")

input_detector.clear_captured()

output, encodings = lm.forwards(
    TEST_TEXTS,
    tok_kwargs={"max_length": 128, "padding": True, "truncation": True, "add_special_tokens": True},
    autocast=False,
)

input_detector.set_inputs_from_encodings(encodings, module=lm.model)

print("✅ Inference completed")
print(f"📊 Encodings keys: {list(encodings.keys())}")
print()
print("💡 Data captured in detector - ready to inspect")


🚀 Running inference...
📝 Processing 3 texts
✅ Inference completed
📊 Encodings keys: ['input_ids', 'attention_mask']

💡 Data captured in detector - ready to inspect


In [8]:
input_ids = input_detector.get_captured_input_ids()
special_token_mask = input_detector.get_captured_special_token_mask()

print("📊 Captured Data:")
print(f"  input_ids shape: {input_ids.shape}")
print(f"  special_token_mask shape: {special_token_mask.shape}")
print(f"  special_token_mask dtype: {special_token_mask.dtype}")
print()

assert input_ids.shape == special_token_mask.shape, "Shapes must match!"
print("✅ Shapes match!")
print()

print("🔍 Special Token Mask Analysis:")
for i, text in enumerate(TEST_TEXTS):
    print(f"\nText {i+1}: {text[:50]}...")
    print(f"  input_ids: {input_ids[i].tolist()}")
    print(f"  mask:      {special_token_mask[i].int().tolist()}")
    
    num_special = special_token_mask[i].sum().item()
    num_total = len(input_ids[i])
    print(f"  Special tokens: {num_special}/{num_total} ({100*num_special/num_total:.1f}%)")


📊 Captured Data:
  input_ids shape: torch.Size([3, 18])
  special_token_mask shape: torch.Size([3, 18])
  special_token_mask dtype: torch.bool

✅ Shapes match!

🔍 Special Token Mask Analysis:

Text 1: Hello world! This is a test....
  input_ids: [2, 2, 2, 2, 2, 2, 2, 2, 1, 10404, 397, 22299, 31964, 22382, 3707, 322, 6291, 31917]
  mask:      [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  Special tokens: 9/18 (50.0%)

Text 2: The quick brown fox jumps over the lazy dog....
  input_ids: [1, 2091, 9108, 23156, 31225, 31892, 2228, 31967, 590, 4742, 31896, 17419, 1226, 1347, 395, 303, 31908, 31917]
  mask:      [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  Special tokens: 1/18 (5.6%)

Text 3: Machine learning is fascinating....
  input_ids: [2, 2, 2, 2, 2, 1, 739, 1437, 289, 568, 300, 4957, 3707, 1075, 31896, 5910, 19217, 31917]
  mask:      [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  Special tokens: 6/18 (33.3%)
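
The printed mask can also be cross-checked by hand. The sketch below, which assumes nothing beyond the values shown for Text 3 and the `all_special_ids` list printed earlier, recomputes the mask from set membership, filters out the special tokens, and reproduces the reported percentage. Variable names are illustrative only.

```python
# Values copied from the Text 3 output above.
input_ids = [2, 2, 2, 2, 2, 1, 739, 1437, 289, 568, 300, 4957,
             3707, 1075, 31896, 5910, 19217, 31917]
mask = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
special_ids = {0, 1, 2, 3, 4, 5, 6}  # from the printed all_special_ids

# 1. The mask should agree with direct membership in the special ID set.
expected = [int(tok in special_ids) for tok in input_ids]
mask_is_consistent = expected == mask

# 2. Keep only the regular (non-special) tokens.
regular_ids = [tok for tok, is_special in zip(input_ids, mask) if not is_special]

# 3. Share of special tokens in the padded sequence.
special_ratio = sum(mask) / len(mask)

print(mask_is_consistent, len(regular_ids), f"{100 * special_ratio:.1f}%")
```

Because padding is on the left here, the special positions form a contiguous prefix (padding followed by the beginning-of-sequence token), which is why the shorter texts show a higher share of special tokens.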


In [9]:
# Step 5: Visualize mask
print("📈 Visualizing Special Token Mask:")
print()

for i, text in enumerate(TEST_TEXTS):
    print(f"Text {i+1}: {text}")
    print(f"  Token IDs:     {input_ids[i].tolist()}")
    print(f"  Special Mask:  {' '.join(['█' if m else '░' for m in special_token_mask[i].tolist()])}")
    print(f"  Values:        {' '.join(['1' if m else '0' for m in special_token_mask[i].tolist()])}")
    print()
    
    # Decode tokens to verify
    token_ids_list = input_ids[i].tolist()
    mask_list = special_token_mask[i].tolist()
    
    print("  Token breakdown:")
    for j, (token_id, is_special) in enumerate(zip(token_ids_list, mask_list)):
        token_str = tokenizer.decode([token_id])
        special_marker = "[SPECIAL]" if is_special else ""
        print(f"    [{j:2d}] ID={token_id:4d} | {token_str:20s} {special_marker}")
    print()


📈 Visualizing Special Token Mask:

Text 1: Hello world! This is a test.
  Token IDs:     [2, 2, 2, 2, 2, 2, 2, 2, 1, 10404, 397, 22299, 31964, 22382, 3707, 322, 6291, 31917]
  Special Mask:  █ █ █ █ █ █ █ █ █ ░ ░ ░ ░ ░ ░ ░ ░ ░
  Values:        1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0

  Token breakdown:
    [ 0] ID=   2 | </s>                 [SPECIAL]
    [ 1] ID=   2 | </s>                 [SPECIAL]
    [ 2] ID=   2 | </s>                 [SPECIAL]
    [ 3] ID=   2 | </s>                 [SPECIAL]
    [ 4] ID=   2 | </s>                 [SPECIAL]
    [ 5] ID=   2 | </s>                 [SPECIAL]
    [ 6] ID=   2 | </s>                 [SPECIAL]
    [ 7] ID=   2 | </s>                 [SPECIAL]
    [ 8] ID=   1 | <s>                  [SPECIAL]
    [ 9] ID=10404 | Hel                  
    [10] ID= 397 | lo                   
    [11] ID=22299 | world                
    [12] ID=31964 | !                    
    [13] ID=22382 | This                 
  