# Example 7: Save Activations and Attention Masks Together

This notebook demonstrates how to:
1. Load the Bielik model
2. Attach three hooks:
   - two LayerActivationDetector hooks for layer activations
   - one ModelInputDetector hook for attention masks
3. Run inference on a small dataset in batches
4. Save both activations and attention masks per batch
5. Verify that everything was saved correctly to disk

This verifies that attention masks can be made easily accessible for each batch
of representations, with the same batch structure as the saved activations.


In [28]:
# Setup and imports
%load_ext autoreload
%autoreload 2

import torch
from pathlib import Path
from datetime import datetime

from amber.datasets import TextDataset
from amber.language_model.language_model import LanguageModel
from amber.hooks.implementations.layer_activation_detector import LayerActivationDetector
from amber.hooks.implementations.model_input_detector import ModelInputDetector
from amber.hooks import HookType
from amber.store.local_store import LocalStore
from datasets import load_dataset

print("✅ Imports completed")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
✅ Imports completed


In [29]:
# Configuration
MODEL_ID = "speakleash/Bielik-1.5B-v3.0-Instruct"
STORE_DIR = Path("store")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
MAX_LENGTH = 128
DATA_LIMIT = 10

HF_DATASET = "roneneldan/TinyStories"
TEXT_FIELD = "text"
DATA_SPLIT = "train"

print("⚙️ Configuration:")
print(f"   Model: {MODEL_ID}")
print(f"   Device: {DEVICE}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Max length: {MAX_LENGTH}")
print(f"   Dataset: {HF_DATASET}")
print(f"   Data limit: {DATA_LIMIT} samples")

⚙️ Configuration:
   Model: speakleash/Bielik-1.5B-v3.0-Instruct
   Device: cpu
   Batch size: 4
   Max length: 128
   Dataset: roneneldan/TinyStories
   Data limit: 10 samples


In [30]:
# Step 1: Load Bielik model
print("📥 Loading Bielik model...")

store = LocalStore(STORE_DIR)
lm = LanguageModel.from_huggingface(MODEL_ID, store=store)
lm.model.to(DEVICE)

print(f"✅ Model loaded: {lm.model_id}")
print(f"📱 Device: {DEVICE}")
print(f"📁 Store location: {lm.store.base_path}")

📥 Loading Bielik model...
✅ Model loaded: speakleash_Bielik-1.5B-v3.0-Instruct
📱 Device: cpu
📁 Store location: store


In [31]:
# Step 2: Create small dataset
print("📊 Creating dataset...")

hf_dataset = load_dataset(HF_DATASET, split=DATA_SPLIT, streaming=False)
if DATA_LIMIT > 0:
    hf_dataset = hf_dataset.select(range(min(DATA_LIMIT, len(hf_dataset))))

dataset = TextDataset(hf_dataset, store=store, text_field=TEXT_FIELD)

print(f"✅ Dataset created: {len(dataset)} samples")
print(f"📝 Sample text: {dataset[0][:100]}..." if len(dataset[0]) > 100 else f"📝 Sample text: {dataset[0]}")

📊 Creating dataset...


Saving the dataset (1/1 shards): 100%|██████████| 10/10 [00:00<00:00, 5252.07 examples/s]

✅ Dataset created: 10 samples
📝 Sample text: One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with...





In [32]:
# Step 3: Find two layers to attach activation detectors to
print("🔍 Finding layers to attach activation detectors...")

layer_names = lm.layers.get_layer_names()
print(f"📋 Found {len(layer_names)} layers")

# Find transformer layers (usually contains 'transformer' or 'layer')
transformer_layers = [name for name in layer_names if 'transformer' in name.lower() or 'layer' in name.lower()]
if transformer_layers:
    # Try to find attention layers
    attention_layers = [name for name in transformer_layers if 'attn' in name.lower()]
    if len(attention_layers) >= 2:
        LAYER_SIGNATURE_1 = attention_layers[0]
        LAYER_SIGNATURE_2 = attention_layers[1]
    elif len(attention_layers) == 1:
        LAYER_SIGNATURE_1 = attention_layers[0]
        # Fall back to any other transformer layer (e.g. MLP or norm) for the second hook
        other_layers = [name for name in transformer_layers if name != LAYER_SIGNATURE_1]
        LAYER_SIGNATURE_2 = other_layers[0] if other_layers else None
    else:
        LAYER_SIGNATURE_1 = transformer_layers[0]
        LAYER_SIGNATURE_2 = transformer_layers[1] if len(transformer_layers) > 1 else None
else:
    LAYER_SIGNATURE_1 = layer_names[0] if layer_names else None
    LAYER_SIGNATURE_2 = layer_names[1] if len(layer_names) > 1 else None

if LAYER_SIGNATURE_1 and LAYER_SIGNATURE_2:
    print(f"✅ Selected layer 1: {LAYER_SIGNATURE_1}")
    print(f"✅ Selected layer 2: {LAYER_SIGNATURE_2}")
else:
    raise ValueError("Could not find two suitable layers")

🔍 Finding layers to attach activation detectors...
📋 Found 422 layers
✅ Selected layer 1: llamaforcausallm_model_layers_0_self_attn
✅ Selected layer 2: llamaforcausallm_model_layers_0_self_attn_q_proj
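
The fallback order used in the cell above (attention layers first, then any transformer layer, then any layer at all) can be condensed into a small helper. This is a minimal sketch, not part of the amber API: `pick_two_layers` is a hypothetical name, and the `embed_tokens` and `mlp` layer names in the usage lines are made up in the style of the real output. It roughly mirrors the original logic, with one difference: if only one transformer layer exists, it falls back to the full layer list instead of returning `None`.

```python
from typing import List, Tuple


def pick_two_layers(layer_names: List[str]) -> Tuple[str, str]:
    """Pick two distinct layer names, preferring attention layers.

    Hypothetical helper for illustration only.
    """
    transformer = [n for n in layer_names
                   if "transformer" in n.lower() or "layer" in n.lower()]
    attention = [n for n in transformer if "attn" in n.lower()]

    # Walk the candidate pools from most to least preferred, keeping the
    # original order and skipping names that were already chosen.
    chosen: List[str] = []
    for pool in (attention, transformer, layer_names):
        for name in pool:
            if name not in chosen:
                chosen.append(name)
            if len(chosen) == 2:
                return chosen[0], chosen[1]
    raise ValueError("Could not find two suitable layers")


# Illustrative names only; the real list comes from lm.layers.get_layer_names().
names = [
    "llamaforcausallm_model_embed_tokens",
    "llamaforcausallm_model_layers_0_self_attn",
    "llamaforcausallm_model_layers_0_self_attn_q_proj",
    "llamaforcausallm_model_layers_0_mlp",
]
print(pick_two_layers(names))
```

Note that a substring match on `attn` also picks up sub-modules such as `q_proj`, which is why the second selected layer above is a projection inside the first one rather than the attention block of the next transformer layer.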


In [33]:
# Step 4: Attach three hooks (one ModelInputDetector and two LayerActivationDetectors)
print("🔧 Attaching hooks...")
print()

# Hook 1: ModelInputDetector for attention masks
print("1️⃣ Setting up ModelInputDetector for attention masks...")
attention_mask_layer_sig = "attention_masks"
root_model = lm.model

# Add layer signature to registry for root model
if attention_mask_layer_sig not in lm.layers.name_to_layer:
    lm.layers.name_to_layer[attention_mask_layer_sig] = root_model
    print(f"   📝 Added '{attention_mask_layer_sig}' to layers registry")

attention_mask_detector = ModelInputDetector(
    layer_signature=attention_mask_layer_sig,
    hook_id="attention_mask_detector",
    save_input_ids=False,
    save_attention_mask=True,
)
attention_mask_hook_id = lm.layers.register_hook(
    attention_mask_layer_sig, attention_mask_detector, HookType.PRE_FORWARD
)
print("   ✅ Attached to root model")
print(f"   🆔 Hook ID: {attention_mask_hook_id}")
print()

# Hook 2: LayerActivationDetector for first layer activations
print("2️⃣ Setting up LayerActivationDetector for first layer...")
activation_detector_1 = LayerActivationDetector(
    layer_signature=LAYER_SIGNATURE_1,
    hook_id="activation_detector_1"
)
activation_hook_id_1 = lm.layers.register_hook(LAYER_SIGNATURE_1, activation_detector_1, HookType.FORWARD)
print(f"   ✅ Attached to layer: {LAYER_SIGNATURE_1}")
print(f"   🆔 Hook ID: {activation_hook_id_1}")
print()

# Hook 3: LayerActivationDetector for second layer activations
print("3️⃣ Setting up LayerActivationDetector for second layer...")
activation_detector_2 = LayerActivationDetector(
    layer_signature=LAYER_SIGNATURE_2,
    hook_id="activation_detector_2"
)
activation_hook_id_2 = lm.layers.register_hook(LAYER_SIGNATURE_2, activation_detector_2, HookType.FORWARD)
print(f"   ✅ Attached to layer: {LAYER_SIGNATURE_2}")
print(f"   🆔 Hook ID: {activation_hook_id_2}")
print()
print("✅ All hooks attached successfully!")

🔧 Attaching hooks...

1️⃣ Setting up ModelInputDetector for attention masks...
   📝 Added 'attention_masks' to layers registry
   ✅ Attached to root model
   🆔 Hook ID: attention_mask_detector

2️⃣ Setting up LayerActivationDetector for first layer...
   ✅ Attached to layer: llamaforcausallm_model_layers_0_self_attn
   🆔 Hook ID: activation_detector_1

3️⃣ Setting up LayerActivationDetector for second layer...
   ✅ Attached to layer: llamaforcausallm_model_layers_0_self_attn_q_proj
   🆔 Hook ID: activation_detector_2

✅ All hooks attached successfully!


In [34]:
# Step 5: Run inference on dataset in batches and save
print("🚀 Running inference on dataset in batches...")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Total samples: {len(dataset)}")
print()

run_name = f"activations_with_masks_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
print(f"📁 Run name: {run_name}")
print()

batch_counter = 0

with torch.inference_mode():
    for batch_index, batch in enumerate(dataset.iter_batches(BATCH_SIZE)):
        # Extract texts from batch
        texts = dataset.extract_texts_from_batch(batch)
        
        # Clear previous captures
        attention_mask_detector.clear_captured()
        activation_detector_1.clear_captured()
        activation_detector_2.clear_captured()
        
        # Run inference
        output, encodings = lm.forwards(
            texts,
            tok_kwargs={
                "max_length": MAX_LENGTH,
                "padding": True,
                "truncation": True,
                "add_special_tokens": True
            },
            autocast=False,
        )
        
        # For HuggingFace models, we need to manually set attention masks from encodings
        # because pre_forward hook doesn't receive kwargs
        attention_mask_detector.set_inputs_from_encodings(encodings, module=lm.model)
        
        # Save detector metadata for this batch
        lm.save_detector_metadata(run_name, batch_index)
        
        batch_counter += 1
        print(f"✅ Saved batch {batch_index} ({len(texts)} samples)")

print()
print(f"✅ Completed! Saved {batch_counter} batches")

🚀 Running inference on dataset in batches...
   Batch size: 4
   Total samples: 10

📁 Run name: activations_with_masks_20251209_213113

✅ Saved batch 0 (4 samples)
✅ Saved batch 1 (4 samples)
✅ Saved batch 2 (2 samples)

✅ Completed! Saved 3 batches
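
The batch sizes in the output above (4, 4, 2) follow directly from the sample count and `BATCH_SIZE`. The sketch below reproduces that arithmetic; `batch_sizes` is a hypothetical helper, and it assumes the batch iterator yields a final partial batch instead of dropping it, which is what the output above suggests.

```python
import math
from typing import List


def batch_sizes(total_samples: int, batch_size: int) -> List[int]:
    """Return the size of each batch when the last partial batch is kept."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_batches = math.ceil(total_samples / batch_size)
    # Every batch is full except possibly the last one.
    return [min(batch_size, total_samples - i * batch_size)
            for i in range(n_batches)]


print(batch_sizes(10, 4))  # [4, 4, 2]
```

This is a handy sanity check when verifying a run: the number of stored batches should equal `len(batch_sizes(len(dataset), BATCH_SIZE))`.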


In [35]:
# Step 6: Verify saved data
print("🔍 Verifying saved data...")
print()

# Get list of batches
batches = lm.context.store.list_run_batches(run_name)
print(f"📦 Found {len(batches)} batches in store")
print()

# Load first batch to inspect
batch_idx = 0
retrieved_metadata, retrieved_tensors = lm.context.store.get_detector_metadata(run_name, batch_idx)

print(f"📊 Batch {batch_idx} structure:")
print(f"   Layers with data: {list(retrieved_tensors.keys())}")
print()

# Check activations of the first hooked layer
# (LAYER_SIGNATURE_1; the undefined name LAYER_SIGNATURE would raise a NameError in a fresh kernel)
if str(LAYER_SIGNATURE_1) in retrieved_tensors:
    activations = retrieved_tensors[str(LAYER_SIGNATURE_1)].get("activations")
    if activations is not None:
        print("✅ Activations found:")
        print(f"   Shape: {activations.shape}  # [batch_size, seq_len, d_model]")
        print(f"   Dtype: {activations.dtype}")
        print(f"   Device: {activations.device}")
    else:
        print("❌ Activations not found")
else:
    print(f"❌ Layer {LAYER_SIGNATURE_1} not found in saved data")
print()

# Check attention masks
if "attention_masks" in retrieved_tensors:
    attention_mask = retrieved_tensors["attention_masks"].get("attention_mask")
    if attention_mask is not None:
        print("✅ Attention masks found:")
        print(f"   Shape: {attention_mask.shape}  # [batch_size, seq_len]")
        print(f"   Dtype: {attention_mask.dtype}")
        print(f"   Device: {attention_mask.device}")
        print(f"   Sample values (first 5 tokens of first 3 samples):")
        print(f"   {attention_mask[:3, :5].tolist()}")
    else:
        print("❌ Attention mask not found")
else:
    print("❌ Attention masks layer not found in saved data")

🔍 Verifying saved data...

📦 Found 3 batches in store

📊 Batch 0 structure:
   Layers with data: ['attention_masks', 'llamaforcausallm_model_layers_0_self_attn', 'llamaforcausallm_model_layers_0_self_attn_q_proj']

✅ Activations found:
   Shape: torch.Size([4, 128, 1536])  # [batch_size, seq_len, d_model]
   Dtype: torch.float32
   Device: cpu

✅ Attention masks found:
   Shape: torch.Size([4, 128])  # [batch_size, seq_len]
   Dtype: torch.bool
   Device: cpu
   Sample values (first 5 tokens of first 3 samples):
   [[False, True, True, True, True], [False, True, True, True, True], [False, True, True, True, True]]


In [36]:
# Step 7: Verify shapes match and demonstrate usage
print("🔗 Verifying activation-attention mask matching...")
print()

# Compare the first hooked layer (LAYER_SIGNATURE_1) against the saved mask
if str(LAYER_SIGNATURE_1) in retrieved_tensors and "attention_masks" in retrieved_tensors:
    activations = retrieved_tensors[str(LAYER_SIGNATURE_1)]["activations"]
    attention_mask = retrieved_tensors["attention_masks"]["attention_mask"]
    
    batch_size, seq_len, d_model = activations.shape
    mask_batch_size, mask_seq_len = attention_mask.shape
    
    print("📊 Shape comparison:")
    print(f"   Activations: {activations.shape}  # [batch_size, seq_len, d_model]")
    print(f"   Attention mask: {attention_mask.shape}  # [batch_size, seq_len]")
    print()
    
    if batch_size == mask_batch_size and seq_len == mask_seq_len:
        print("✅ Shapes match perfectly!")
        print()
        
        # Demonstrate filtering activations using attention mask
        print("💡 Example: Filtering activations for regular (non-padding) tokens:")
        print()
        
        sample_idx = 0
        sample_activations = activations[sample_idx]  # [seq_len, d_model]
        sample_mask = attention_mask[sample_idx]  # [seq_len]
        
        num_regular_tokens = sample_mask.sum().item()
        print(f"   Sample {sample_idx}:")
        print(f"      Total tokens: {seq_len}")
        print(f"      Regular tokens (non-padding): {num_regular_tokens}")
        print(f"      Padding tokens: {seq_len - num_regular_tokens}")
        print()
        
        # Filter activations to only regular tokens
        regular_activations = sample_activations[sample_mask.bool()]  # [num_regular_tokens, d_model]
        print(f"      Filtered activations shape: {regular_activations.shape}")
        print("      ✅ Successfully filtered to only regular tokens!")
    else:
        print("❌ Shape mismatch!")
        print(f"   Batch size: {batch_size} vs {mask_batch_size}")
        print(f"   Sequence length: {seq_len} vs {mask_seq_len}")
else:
    print("❌ Cannot verify - missing activations or attention masks")

🔗 Verifying activation-attention mask matching...

📊 Shape comparison:
   Activations: torch.Size([4, 128, 1536])  # [batch_size, seq_len, d_model]
   Attention mask: torch.Size([4, 128])  # [batch_size, seq_len]

✅ Shapes match perfectly!

💡 Example: Filtering activations for regular (non-padding) tokens:

   Sample 0:
      Total tokens: 128
      Regular tokens (non-padding): 127
      Padding tokens: 1

      Filtered activations shape: torch.Size([127, 1536])
      ✅ Successfully filtered to only regular tokens!
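
Beyond filtering a single sample, the mask can be applied to a whole batch at once, for example to average each sample's activations over its non-padding tokens. The following is a minimal sketch using plain PyTorch on synthetic tensors; `masked_mean` is a hypothetical helper, not an amber function, and it only assumes the shapes shown above (`[batch_size, seq_len, d_model]` and `[batch_size, seq_len]`).

```python
import torch


def masked_mean(activations: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average activations over non-padding tokens.

    activations:    [batch_size, seq_len, d_model]
    attention_mask: [batch_size, seq_len], True (or 1) for regular tokens
    returns:        [batch_size, d_model]
    """
    # Broadcast the mask over the feature dimension: [batch_size, seq_len, 1]
    mask = attention_mask.to(activations.dtype).unsqueeze(-1)
    summed = (activations * mask).sum(dim=1)
    # Guard against division by zero for samples that are entirely padding.
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts


# Synthetic example: one sample, three tokens, the last one is padding.
acts = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = torch.tensor([[True, True, False]])
pooled = masked_mean(acts, mask)
print(pooled)  # tensor([[2., 3.]])
```

The padded position (values of 100) does not affect the result, which is the point of storing the mask next to the activations.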


In [37]:
# Step 8: Verify all batches
print("🔍 Verifying all batches...")
print()

all_batches_valid = True
for batch_idx in range(len(batches)):
    retrieved_metadata, retrieved_tensors = lm.store.get_detector_metadata(run_name, batch_idx)
    
    has_activations_1 = str(LAYER_SIGNATURE_1) in retrieved_tensors and \
                        "activations" in retrieved_tensors[str(LAYER_SIGNATURE_1)]
    has_activations_2 = str(LAYER_SIGNATURE_2) in retrieved_tensors and \
                        "activations" in retrieved_tensors[str(LAYER_SIGNATURE_2)]
    has_attention_mask = "attention_masks" in retrieved_tensors and \
                        "attention_mask" in retrieved_tensors["attention_masks"]
    
    if has_activations_1 and has_activations_2 and has_attention_mask:
        activations_1 = retrieved_tensors[str(LAYER_SIGNATURE_1)]["activations"]
        activations_2 = retrieved_tensors[str(LAYER_SIGNATURE_2)]["activations"]
        attention_mask = retrieved_tensors["attention_masks"]["attention_mask"]
        
        # Verify shapes match
        if (activations_1.shape[:2] == attention_mask.shape and
            activations_2.shape[:2] == attention_mask.shape):
            print(f"✅ Batch {batch_idx}: layer1 {activations_1.shape}, layer2 {activations_2.shape}, mask {attention_mask.shape}")
        else:
            print(f"❌ Batch {batch_idx}: shape mismatch!")
            all_batches_valid = False
    else:
        print(f"❌ Batch {batch_idx}: missing data (layer1: {has_activations_1}, layer2: {has_activations_2}, mask: {has_attention_mask})")
        all_batches_valid = False

print()
if all_batches_valid:
    print("✅ All batches verified successfully!")
    print(f"📁 Run name: {run_name}")
    print(f"📁 Store location: {lm.context.store.base_path}")
    print()
    print("💡 Summary:")
    print(f"   - Activations saved per batch: [batch_size, seq_len, d_model]")
    print(f"   - Attention masks saved per batch: [batch_size, seq_len]")
    print(f"   - Both are easily accessible and matched per batch")
    print(f"   - No need to run separate inference for attention masks")
else:
    print("❌ Some batches failed verification")

🔍 Verifying all batches...

✅ Batch 0: layer1 torch.Size([4, 128, 1536]), layer2 torch.Size([4, 128, 1536]), mask torch.Size([4, 128])
✅ Batch 1: layer1 torch.Size([4, 128, 1536]), layer2 torch.Size([4, 128, 1536]), mask torch.Size([4, 128])
✅ Batch 2: layer1 torch.Size([2, 128, 1536]), layer2 torch.Size([2, 128, 1536]), mask torch.Size([2, 128])

✅ All batches verified successfully!
📁 Run name: activations_with_masks_20251209_213113
📁 Store location: store

💡 Summary:
   - Activations saved per batch: [batch_size, seq_len, d_model]
   - Attention masks saved per batch: [batch_size, seq_len]
   - Both are easily accessible and matched per batch
   - No need to run separate inference for attention masks
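
Because every batch stores activations and a matching mask, non-padding token activations can be gathered across all batches into one flat tensor, for instance as input to a downstream analysis. The sketch below uses synthetic tensors in place of store reads; `collect_regular_tokens` is a hypothetical helper, and in a real run the `(activations, mask)` pairs would presumably come from the per-batch tensors loaded in the cell above.

```python
from typing import Iterable, Tuple

import torch


def collect_regular_tokens(
    batches: Iterable[Tuple[torch.Tensor, torch.Tensor]]
) -> torch.Tensor:
    """Concatenate non-padding token activations from several batches.

    Each item is (activations [B, T, D], attention_mask [B, T]).
    Returns a tensor of shape [total_regular_tokens, D].
    """
    # Boolean indexing with a [B, T] mask flattens the first two dimensions
    # and keeps only positions where the mask is True.
    rows = [acts[mask.bool()] for acts, mask in batches]
    return torch.cat(rows, dim=0)


# Synthetic stand-ins for two saved batches (d_model = 3, seq_len = 4).
batch_a = (torch.ones(2, 4, 3), torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1]]))
batch_b = (torch.zeros(1, 4, 3), torch.tensor([[1, 1, 1, 1]]))
tokens = collect_regular_tokens([batch_a, batch_b])
print(tokens.shape)  # torch.Size([9, 3])
```

Batches with different batch sizes (such as the final batch of 2 above) concatenate without special handling, since only the feature dimension has to agree.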


## Summary

This example demonstrated:

1. ✅ **Loading the Bielik model** - Successfully loaded from HuggingFace
2. ✅ **Attaching three hooks** - Two LayerActivationDetector hooks and one ModelInputDetector hook
3. ✅ **Running inference on dataset** - Processed dataset in batches
4. ✅ **Saving both activations and attention masks** - Saved per batch, matching structure
5. ✅ **Verification** - Confirmed all data saved correctly to disk

**Key Benefits:**
- Attention masks are saved per batch, matching activation batch structure
- Both are easily accessible from the same batch using the store API
- No need to run separate inference to get attention masks
- Shapes match perfectly: activations `[batch_size, seq_len, d_model]` and masks `[batch_size, seq_len]`
- Can easily filter activations to only regular (non-padding) tokens using attention masks

**Conclusion:** ✅ The requested workflow is supported by the current tools.
We can attach both LayerActivationDetector and ModelInputDetector hooks, run inference once,
and save both activations and attention masks together per batch.