# Example 6: Save Activations with Attention Masks

This notebook demonstrates how to:
1. Load a language model (Bielik)
2. Create a text dataset
3. Save activations from a specific layer WITH attention masks
4. Load and verify both activations and attention masks from the store
5. Demonstrate how to match attention masks with activations per batch

This example shows the new feature where attention masks are automatically saved alongside activations when using `save_activations_dataset()`, making it easy to identify which tokens are regular (non-padding) tokens in the internal representations.


In [None]:
# Setup and imports
%load_ext autoreload
%autoreload 2

import torch
from pathlib import Path
from datetime import datetime

from amber.datasets import TextDataset
from amber.language_model.language_model import LanguageModel
from amber.store.local_store import LocalStore

print("‚úÖ Imports completed")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  from .autonotebook import tqdm as notebook_tqdm


‚úÖ Imports completed


In [3]:
# Configuration
MODEL_ID = "speakleash/Bielik-1.5B-v3.0-Instruct"  # Bielik model
HF_DATASET = "roneneldan/TinyStories"
DATA_SPLIT = "train"
TEXT_FIELD = "text"
DATA_LIMIT = 100  # Number of text samples to use
MAX_LENGTH = 128  # Maximum sequence length
BATCH_SIZE_SAVE = 16  # Batch size for saving activations

# Choose which layer to hook - you can inspect available layers with model.layers.print_layer_names()
# For Bielik, we'll use a transformer layer - adjust based on actual layer names
LAYER_SIGNATURE = None  # Will be set after model loading

# Storage locations
STORE_DIR = Path("store")
RUN_ID = f"activations_with_masks_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("üöÄ Starting Activations with Attention Masks Example")
print(f"üì± Using device: {DEVICE}")
print(f"üîß Model: {MODEL_ID}")
print(f"üìä Dataset: {HF_DATASET}")
print(f"üíæ Run ID: {RUN_ID}")
print()

STORE_DIR.mkdir(parents=True, exist_ok=True)
print("‚úÖ Output directories created")

üöÄ Starting Activations with Attention Masks Example
üì± Using device: cpu
üîß Model: speakleash/Bielik-1.5B-v3.0-Instruct
üìä Dataset: roneneldan/TinyStories
üíæ Run ID: activations_with_masks_20251209_220425

‚úÖ Output directories created


In [4]:
# Step 1: Load language model and store
print("üì• Loading language model...")

store = LocalStore(base_path=STORE_DIR)
lm = LanguageModel.from_huggingface(MODEL_ID, store=store)

print(f"‚úÖ Model loaded: {lm.model_id}")
print(f"üì± Device: {DEVICE}")
print(f"üìÅ Store location: {lm.context.store.base_path}")
print()

# Print available layers to choose one
print("üîç Available layers (first 20):")
layer_names = lm.layers.get_layer_names()
for i, name in enumerate(layer_names[:20]):
    print(f"  {i}: {name}")
if len(layer_names) > 20:
    print(f"  ... and {len(layer_names) - 20} more")
print()

# Auto-select a transformer layer if available, otherwise use first layer
if LAYER_SIGNATURE is None:
    # Try to find a transformer layer
    transformer_layers = [name for name in layer_names if 'transformer' in name.lower() and ('layer' in name.lower() or 'h_' in name.lower())]
    if transformer_layers:
        LAYER_SIGNATURE = transformer_layers[0]
        print(f"üéØ Auto-selected layer: {LAYER_SIGNATURE}")
    else:
        LAYER_SIGNATURE = layer_names[0] if layer_names else 0
        print(f"üéØ Using first layer: {LAYER_SIGNATURE}")
else:
    print(f"üéØ Using specified layer: {LAYER_SIGNATURE}")

üì• Loading language model...
‚úÖ Model loaded: speakleash_Bielik-1.5B-v3.0-Instruct
üì± Device: cpu
üìÅ Store location: store

üîç Available layers (first 20):
  0: llamaforcausallm_model
  1: llamaforcausallm_model_embed_tokens
  2: llamaforcausallm_model_layers
  3: llamaforcausallm_model_layers_0
  4: llamaforcausallm_model_layers_0_self_attn
  5: llamaforcausallm_model_layers_0_self_attn_q_proj
  6: llamaforcausallm_model_layers_0_self_attn_k_proj
  7: llamaforcausallm_model_layers_0_self_attn_v_proj
  8: llamaforcausallm_model_layers_0_self_attn_o_proj
  9: llamaforcausallm_model_layers_0_mlp
  10: llamaforcausallm_model_layers_0_mlp_gate_proj
  11: llamaforcausallm_model_layers_0_mlp_up_proj
  12: llamaforcausallm_model_layers_0_mlp_down_proj
  13: llamaforcausallm_model_layers_0_mlp_act_fn
  14: llamaforcausallm_model_layers_0_input_layernorm
  15: llamaforcausallm_model_layers_0_post_attention_layernorm
  16: llamaforcausallm_model_layers_1
  17: llamaforcausallm_model_lay

In [5]:
# Step 2: Create dataset
print("üìä Creating dataset...")

from datasets import load_dataset

hf_dataset = load_dataset(HF_DATASET, split=DATA_SPLIT, streaming=False)
if DATA_LIMIT > 0:
    hf_dataset = hf_dataset.select(range(min(DATA_LIMIT, len(hf_dataset))))

dataset = TextDataset(hf_dataset, store=store, text_field=TEXT_FIELD)

print(f"‚úÖ Dataset created: {len(dataset)} samples")
print(f"üìù Sample text: {dataset[0][:100]}..." if len(dataset[0]) > 100 else f"üìù Sample text: {dataset[0]}")

üìä Creating dataset...


Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:00<00:00, 61572.28 examples/s]

‚úÖ Dataset created: 100 samples
üìù Sample text: One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with...





In [None]:
# Step 3: Save activations WITH attention masks
print("üíæ Saving activations with attention masks...")
print(f"   Layer: {LAYER_SIGNATURE}")
print(f"   Batch size: {BATCH_SIZE_SAVE}")
print(f"   Max length: {MAX_LENGTH}")
print(f"   Save attention masks: True (default)")
print()

# Save activations with attention masks enabled (default behavior)
run_name = lm.activations.save_activations_dataset(
    dataset,
    layer_signature=LAYER_SIGNATURE,
    run_name=RUN_ID,
    batch_size=BATCH_SIZE_SAVE,
    max_length=MAX_LENGTH,
    autocast=False,  # Disable autocast for consistency
    verbose=True,
)

print(f"\n‚úÖ Saved activations with attention masks")
print(f"üìÅ Run ID: {run_name}")
print(f"üìÅ Store location: {lm.context.store.base_path}")

2025-12-09 22:04:52,870 [INFO] amber.language_model.activations: Starting save_activations_dataset: run=activations_with_masks_20251209_220425, layer=llamaforcausallm_model, batch_size=16, device=cpu


üíæ Saving activations with attention masks...
   Layer: llamaforcausallm_model
   Batch size: 16
   Max length: 128
   Save attention masks: True (default)



In [None]:
# Step 4: Verify saved data by loading it back
print("üîç Verifying saved data...")

# Get list of batches
batches = lm.context.store.list_run_batches(run_name)
print(f"‚úÖ Found {len(batches)} batches")
print()

# Load first batch to inspect structure
batch_idx = 0
retrieved_metadata, retrieved_tensors = lm.context.store.get_detector_metadata(run_name, batch_idx)

print(f"üì¶ Batch {batch_idx} structure:")
print(f"   Layers with data: {list(retrieved_tensors.keys())}")
print()

# Check activations
if str(LAYER_SIGNATURE) in retrieved_tensors:
    activations = retrieved_tensors[str(LAYER_SIGNATURE)].get("activations")
    if activations is not None:
        print(f"‚úÖ Activations found:")
        print(f"   Shape: {activations.shape}")
        print(f"   Dtype: {activations.dtype}")
        print(f"   Device: {activations.device}")
    else:
        print("‚ùå Activations not found")
else:
    print(f"‚ùå Layer {LAYER_SIGNATURE} not found in saved data")
print()

# Check attention masks
if "attention_masks" in retrieved_tensors:
    attention_mask = retrieved_tensors["attention_masks"].get("attention_mask")
    if attention_mask is not None:
        print(f"‚úÖ Attention masks found:")
        print(f"   Shape: {attention_mask.shape}")
        print(f"   Dtype: {attention_mask.dtype}")
        print(f"   Device: {attention_mask.device}")
        print(f"   Sample values (first 5 tokens of first 3 samples):")
        print(f"   {attention_mask[:3, :5].tolist()}")
    else:
        print("‚ùå Attention mask not found")
else:
    print("‚ùå Attention masks layer not found in saved data")

In [None]:
# Step 5: Demonstrate matching activations with attention masks
print("üîó Demonstrating activation-attention mask matching...")
print()

# Load a batch
batch_idx = 0
retrieved_metadata, retrieved_tensors = lm.context.store.get_detector_metadata(run_name, batch_idx)

# Get activations and attention masks
activations = retrieved_tensors[str(LAYER_SIGNATURE)]["activations"]
attention_mask = retrieved_tensors["attention_masks"]["attention_mask"]

print(f"üìä Batch {batch_idx} data:")
print(f"   Activations shape: {activations.shape}  # [batch_size, seq_len, d_model]")
print(f"   Attention mask shape: {attention_mask.shape}  # [batch_size, seq_len]")
print()

# Verify shapes match
batch_size, seq_len, d_model = activations.shape
mask_batch_size, mask_seq_len = attention_mask.shape

if batch_size == mask_batch_size and seq_len == mask_seq_len:
    print("‚úÖ Shapes match perfectly!")
    print()
    
    # Show how to filter activations using attention mask
    print("üí° Example: Filtering activations for regular (non-padding) tokens:")
    print()
    
    # For first sample in batch
    sample_idx = 0
    sample_activations = activations[sample_idx]  # [seq_len, d_model]
    sample_mask = attention_mask[sample_idx]  # [seq_len]
    
    # Count regular tokens
    num_regular_tokens = sample_mask.sum().item()
    print(f"   Sample {sample_idx}:")
    print(f"      Total tokens: {seq_len}")
    print(f"      Regular tokens (attention_mask=1): {num_regular_tokens}")
    print(f"      Padding tokens (attention_mask=0): {seq_len - num_regular_tokens}")
    print()
    
    # Filter activations to only regular tokens
    regular_activations = sample_activations[sample_mask.bool()]  # [num_regular_tokens, d_model]
    print(f"   Filtered activations shape: {regular_activations.shape}")
    print(f"   ‚úÖ Successfully filtered to {regular_activations.shape[0]} regular token activations")
    print()
    
    # Show how to apply mask across entire batch
    print("üí° Example: Applying mask to entire batch:")
    # Expand mask to match activation dimensions
    mask_expanded = attention_mask.unsqueeze(-1).expand_as(activations)  # [batch_size, seq_len, d_model]
    # Masked activations (set padding positions to zero)
    masked_activations = activations * mask_expanded
    print(f"   Masked activations shape: {masked_activations.shape}")
    print(f"   ‚úÖ Padding positions are now zero")
else:
    print(f"‚ùå Shape mismatch!")
    print(f"   Activations: batch={batch_size}, seq={seq_len}")
    print(f"   Attention mask: batch={mask_batch_size}, seq={mask_seq_len}")

In [None]:
# Step 6: Compare with and without attention masks
print("üîÑ Comparing save with and without attention masks...")
print()

# Save without attention masks
run_name_no_mask = f"{RUN_ID}_no_mask"
print(f"üíæ Saving WITHOUT attention masks (run: {run_name_no_mask})...")

run_name_no_mask = lm.activations.save_activations_dataset(
    dataset,
    layer_signature=LAYER_SIGNATURE,
    run_name=run_name_no_mask,
    batch_size=BATCH_SIZE_SAVE,
    max_length=MAX_LENGTH,
    autocast=False,
    save_attention_mask=False,  # Explicitly disable
    verbose=False,
)

# Check if attention masks were saved
retrieved_metadata_no_mask, retrieved_tensors_no_mask = lm.context.store.get_detector_metadata(run_name_no_mask, 0)

if "attention_masks" in retrieved_tensors_no_mask:
    print("‚ùå Attention masks found (unexpected!)")
else:
    print("‚úÖ No attention masks saved (as expected)")
    print(f"   Available layers: {list(retrieved_tensors_no_mask.keys())}")
print()

# Compare with run that has attention masks
print(f"üìä Comparison:")
print(f"   Run WITH masks ({run_name}):")
print(f"      Layers: {list(retrieved_tensors.keys())}")
print(f"   Run WITHOUT masks ({run_name_no_mask}):")
print(f"      Layers: {list(retrieved_tensors_no_mask.keys())}")

In [None]:
# Step 7: Access attention masks using the convenience method
print("üîç Accessing attention masks using store convenience method...")
print()

# Using get_detector_metadata_by_layer_by_key
batch_idx = 0
attention_mask = lm.context.store.get_detector_metadata_by_layer_by_key(
    run_name, batch_idx, "attention_masks", "attention_mask"
)

print(f"‚úÖ Retrieved attention mask directly:")
print(f"   Shape: {attention_mask.shape}")
print(f"   Dtype: {attention_mask.dtype}")
print()

# Get activations the same way
activations = lm.context.store.get_detector_metadata_by_layer_by_key(
    run_name, batch_idx, str(LAYER_SIGNATURE), "activations"
)

print(f"‚úÖ Retrieved activations directly:")
print(f"   Shape: {activations.shape}")
print(f"   Dtype: {activations.dtype}")
print()
print("üí° Both can be easily accessed and matched per batch!")

## Summary

This example demonstrated:

1. ‚úÖ **Saving activations with attention masks** - Using `save_attention_mask=True` (default) in `save_activations_dataset()`
2. ‚úÖ **Automatic batch matching** - Attention masks are saved per batch, matching the activation batch structure
3. ‚úÖ **Easy access** - Both activations and attention masks can be loaded from the same batch using the store API
4. ‚úÖ **Shape verification** - Attention masks `[batch_size, seq_len]` match activation sequence dimensions `[batch_size, seq_len, d_model]`
5. ‚úÖ **Practical usage** - Filtering activations to only regular (non-padding) tokens using attention masks

**Key Benefits:**
- No need to run separate inference to get attention masks
- Attention masks are automatically matched to activation batches
- Easy to filter activations to only regular tokens
- Consistent API for accessing both activations and attention masks