# Example 8: Inference with Hooks

This notebook demonstrates how to:
1. Load the Bielik model
2. Attach a hook (`ModelInputDetector`, which captures attention masks)
3. Verify that the hook works correctly
4. Run inference on a list of texts with `infer_texts()`
5. Run inference on a dataset with `infer_dataset()`
6. Verify that the metadata was saved correctly

It shows the new inference API, which separates basic inference from activation saving.

In [1]:
# Setup and imports
%load_ext autoreload
%autoreload 2

import torch
from pathlib import Path
from datetime import datetime

from amber.datasets import TextDataset
from amber.language_model.language_model import LanguageModel
from amber.hooks.implementations.model_input_detector import ModelInputDetector
from amber.hooks import HookType
from amber.store.local_store import LocalStore
from datasets import load_dataset

print("✅ Imports completed")


✅ Imports completed


In [2]:
# Configuration
MODEL_ID = "speakleash/Bielik-1.5B-v3.0-Instruct"
STORE_DIR = Path("store")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
MAX_LENGTH = 128
DATA_LIMIT = 10

HF_DATASET = "roneneldan/TinyStories"
TEXT_FIELD = "text"
DATA_SPLIT = "train"

print("⚙️ Configuration:")
print(f"   Model: {MODEL_ID}")
print(f"   Device: {DEVICE}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Max length: {MAX_LENGTH}")
print(f"   Dataset: {HF_DATASET}")
print(f"   Data limit: {DATA_LIMIT} samples")

⚙️ Configuration:
   Model: speakleash/Bielik-1.5B-v3.0-Instruct
   Device: cpu
   Batch size: 4
   Max length: 128
   Dataset: roneneldan/TinyStories
   Data limit: 10 samples


In [3]:
# Step 1: Load Bielik model
print("📥 Loading Bielik model...")

store = LocalStore(STORE_DIR)
lm = LanguageModel.from_huggingface(MODEL_ID, store=store)
lm.model.to(DEVICE)

print(f"✅ Model loaded: {lm.model_id}")
print(f"📱 Device: {DEVICE}")
print(f"📁 Store location: {lm.store.base_path}")

📥 Loading Bielik model...
✅ Model loaded: speakleash_Bielik-1.5B-v3.0-Instruct
📱 Device: cpu
📁 Store location: store


In [4]:
# Step 2: Create small dataset
print("📊 Creating dataset...")

hf_dataset = load_dataset(HF_DATASET, split=DATA_SPLIT, streaming=False)
if DATA_LIMIT > 0:
    hf_dataset = hf_dataset.select(range(min(DATA_LIMIT, len(hf_dataset))))

dataset = TextDataset(hf_dataset, store=store, text_field=TEXT_FIELD)

print(f"✅ Dataset created: {len(dataset)} samples")
# Show at most the first 100 characters of the first sample.
sample = dataset[0]
preview = f"{sample[:100]}..." if len(sample) > 100 else sample
print(f"📝 Sample text: {preview}")

📊 Creating dataset...

Saving the dataset (1/1 shards): 100%|██████████| 10/10 [00:00<00:00, 5175.60 examples/s]

✅ Dataset created: 10 samples
📝 Sample text: One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with...





In [5]:
# Step 3: Attach ModelInputDetector hook for attention masks
print("🔧 Attaching ModelInputDetector hook...")
print()

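# Hooks are registered by layer name. Attention masks are inputs to the whole
# model, so the root model is added to the registry under a synthetic name.
attention_mask_layer_sig = "attention_masks"
root_model = lm.model

if attention_mask_layer_sig not in lm.layers.name_to_layer:
    lm.layers.name_to_layer[attention_mask_layer_sig] = root_model
    print(f"   📝 Added '{attention_mask_layer_sig}' to layers registry")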

attention_mask_detector = ModelInputDetector(
    layer_signature=attention_mask_layer_sig,
    hook_id="attention_mask_detector",
    save_input_ids=False,
    save_attention_mask=True,
)
attention_mask_hook_id = lm.layers.register_hook(
    attention_mask_layer_sig, attention_mask_detector, HookType.PRE_FORWARD
)
print("   ✅ Attached to root model")
print(f"   🆔 Hook ID: {attention_mask_hook_id}")
print()
print("✅ Hook attached successfully!")

🔧 Attaching ModelInputDetector hook...

   📝 Added 'attention_masks' to layers registry
   ✅ Attached to root model
   🆔 Hook ID: attention_mask_detector

✅ Hook attached successfully!
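
The pre-forward hook pattern used in the cell above can be illustrated without the library. The following is a minimal, self-contained sketch: `ToyModule`, `MaskRecorder`, and `register_pre_forward_hook` are hypothetical names invented for illustration, and this is not how amber or PyTorch implement hooks. It only shows the idea of a callback that runs before `forward` and records the `attention_mask` keyword argument.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyModule:
    """Hypothetical stand-in for a model that supports pre-forward hooks."""

    def __init__(self) -> None:
        self._pre_hooks: Dict[str, Callable[[Dict[str, Any]], None]] = {}

    def register_pre_forward_hook(
        self, hook_id: str, hook: Callable[[Dict[str, Any]], None]
    ) -> str:
        # Store the callback under its id so it can be found or replaced later.
        self._pre_hooks[hook_id] = hook
        return hook_id

    def forward(
        self, input_ids: List[List[int]], attention_mask: List[List[bool]]
    ) -> int:
        # Toy "computation": count the positions that are not padding.
        return sum(sum(row) for row in attention_mask)

    def __call__(self, **kwargs: Any) -> int:
        # Every registered hook sees the inputs before forward() runs.
        for hook in self._pre_hooks.values():
            hook(kwargs)
        return self.forward(**kwargs)


class MaskRecorder:
    """Hypothetical detector that keeps the most recent attention mask."""

    def __init__(self) -> None:
        self.captured: Optional[List[List[bool]]] = None

    def __call__(self, kwargs: Dict[str, Any]) -> None:
        self.captured = kwargs.get("attention_mask")


model = ToyModule()
recorder = MaskRecorder()
hook_id = model.register_pre_forward_hook("attention_mask_detector", recorder)

real_tokens = model(
    input_ids=[[0, 11, 12], [21, 22, 23]],
    attention_mask=[[False, True, True], [True, True, True]],
)
print(hook_id, real_tokens, recorder.captured)
```

Because the recorder runs before `forward`, it sees the inputs exactly as the model receives them, which is why a pre-forward hook suits capturing masks rather than activations.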


In [6]:
# Step 4: Verify hook works with basic inference
print("🔍 Verifying hook works with basic inference...")
print()

test_texts = ["Hello, world!", "This is a test."]

attention_mask_detector.clear_captured()

output, encodings = lm.forwards(
    test_texts,
    tok_kwargs={
        "max_length": MAX_LENGTH,
        "padding": True,
        "truncation": True,
        "add_special_tokens": True
    },
    autocast=False,
)

captured_mask = attention_mask_detector.get_captured_attention_mask()

if captured_mask is not None:
    print("✅ Hook works! Captured attention mask:")
    print(f"   Shape: {captured_mask.shape}")
    print(f"   Dtype: {captured_mask.dtype}")
    print(f"   Sample values (first 5 tokens of first sample): {captured_mask[0, :5].tolist()}")
else:
    print("❌ Hook did not capture attention mask")

🔍 Verifying hook works with basic inference...

✅ Hook works! Captured attention mask:
   Shape: torch.Size([2, 6])
   Dtype: torch.bool
   Sample values (first 5 tokens of first sample): [False, True, True, True, True]
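
The leading `False` in the first sample is what one would expect if the tokenizer pads on the left: the shorter sequence gets its padding positions at the start, and the mask marks them as `False`. The notebook does not show the tokenizer's padding side, so treat this as an assumption. The sketch below builds such a mask in plain Python; `left_padded_mask` and the token counts (5 and 6) are hypothetical, chosen only to match the shape `[2, 6]` printed above.

```python
from typing import List


def left_padded_mask(lengths: List[int]) -> List[List[bool]]:
    """Build a boolean mask for sequences padded on the left to the longest length."""
    width = max(lengths)
    # Padding positions come first (False), real tokens follow (True).
    return [[False] * (width - n) + [True] * n for n in lengths]


# Hypothetical token counts for the two test sentences.
mask = left_padded_mask([5, 6])
print(mask[0][:5])                 # [False, True, True, True, True]
print([sum(row) for row in mask])  # real tokens per sample: [5, 6]
```

Summing each row of the mask gives the number of real tokens per sample, which is a quick sanity check on any captured mask.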


In [7]:
# Step 5: Mode 1 - Inference on texts using infer_texts()
print("🚀 Mode 1: Inference on texts using infer_texts()")
print("=" * 60)
print()

texts = [dataset[i] for i in range(min(6, len(dataset)))]
print(f"📝 Processing {len(texts)} texts...")
print()

run_name_texts = f"inference_texts_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
print(f"📁 Run name: {run_name_texts}")
print()

attention_mask_detector.clear_captured()

outputs, encodings = lm.inference.infer_texts(
    texts,
    run_name=run_name_texts,
    batch_size=3,
    tok_kwargs={
        "max_length": MAX_LENGTH,
        "padding": True,
        "truncation": True,
        "add_special_tokens": True
    },
    autocast=False,
    verbose=True,
)

print()
print("✅ Inference completed!")
print(f"   Number of batches: {len(outputs) if isinstance(outputs, list) else 1}")
print(f"   Output type: {type(outputs)}")
print()

batches = lm.store.list_run_batches(run_name_texts)
print(f"📦 Saved {len(batches)} batches to store")

🚀 Mode 1: Inference on texts using infer_texts()

📝 Processing 6 texts...

📁 Run name: inference_texts_20251209_220644



2025-12-09 22:06:44,866 [INFO] amber.language_model.inference: Saved batch 0 for run=inference_texts_20251209_220644
2025-12-09 22:06:45,658 [INFO] amber.language_model.inference: Saved batch 1 for run=inference_texts_20251209_220644



✅ Inference completed!
   Number of batches: 2
   Output type: <class 'list'>

📦 Saved 2 batches to store


In [8]:
# Step 6: Verify saved metadata from infer_texts()
print("🔍 Verifying saved metadata from infer_texts()...")
print()

if len(batches) > 0:
    batch_idx = 0
    retrieved_metadata, retrieved_tensors = lm.store.get_detector_metadata(run_name_texts, batch_idx)
    
    print(f"📊 Batch {batch_idx} metadata:")
    print(f"   Layers with data: {list(retrieved_tensors.keys())}")

    if "attention_masks" in retrieved_tensors:
        attention_mask = retrieved_tensors["attention_masks"].get("attention_mask")
        if attention_mask is not None:
            print("   ✅ Attention mask found:")
            print(f"      Shape: {attention_mask.shape}")
            print(f"      Dtype: {attention_mask.dtype}")

    run_metadata = lm.store.get_run_metadata(run_name_texts)
    if run_metadata:
        print("   ✅ Run metadata found:")
        print(f"      Model: {run_metadata.get('model', 'N/A')}")
        print(f"      Batch size: {run_metadata.get('options', {}).get('batch_size', 'N/A')}")
else:
    print("❌ No batches found")

🔍 Verifying saved metadata from infer_texts()...

📊 Batch 0 metadata:
   Layers with data: ['attention_masks']
   ✅ Attention mask found:
      Shape: torch.Size([3, 128])
      Dtype: torch.bool
   ✅ Run metadata found:
      Model: LlamaForCausalLM
      Batch size: 3


In [9]:
# Step 7: Mode 2 - Inference on dataset using infer_dataset()
print("🚀 Mode 2: Inference on dataset using infer_dataset()")
print("=" * 60)
print()

run_name_dataset = f"inference_dataset_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
print(f"📁 Run name: {run_name_dataset}")
print(f"📊 Dataset size: {len(dataset)} samples")
print(f"📦 Batch size: {BATCH_SIZE}")
print()

run_name = lm.inference.infer_dataset(
    dataset,
    run_name=run_name_dataset,
    batch_size=BATCH_SIZE,
    tok_kwargs={
        "max_length": MAX_LENGTH,
        "padding": True,
        "truncation": True,
        "add_special_tokens": True
    },
    autocast=False,
    verbose=True,
)

print()
print("✅ Inference completed!")
print(f"📁 Run name: {run_name}")
print()

batches = lm.store.list_run_batches(run_name)
print(f"📦 Saved {len(batches)} batches to store")

2025-12-09 22:06:45,701 [INFO] amber.language_model.inference: Starting infer_dataset: run=inference_dataset_20251209_220645, batch_size=4, device=cpu


🚀 Mode 2: Inference on dataset using infer_dataset()

📁 Run name: inference_dataset_20251209_220645
📊 Dataset size: 10 samples
📦 Batch size: 4



2025-12-09 22:06:46,720 [INFO] amber.language_model.inference: Saved batch 0 for run=inference_dataset_20251209_220645
2025-12-09 22:06:47,764 [INFO] amber.language_model.inference: Saved batch 1 for run=inference_dataset_20251209_220645
2025-12-09 22:06:48,357 [INFO] amber.language_model.inference: Saved batch 2 for run=inference_dataset_20251209_220645
2025-12-09 22:06:48,358 [INFO] amber.language_model.inference: Completed infer_dataset: run=inference_dataset_20251209_220645, batches_saved=3



✅ Inference completed!
📁 Run name: inference_dataset_20251209_220645

📦 Saved 3 batches to store


In [10]:
# Step 8: Verify saved metadata from infer_dataset()
print("🔍 Verifying saved metadata from infer_dataset()...")
print()

if len(batches) > 0:
    batch_idx = 0
    retrieved_metadata, retrieved_tensors = lm.store.get_detector_metadata(run_name, batch_idx)
    
    print(f"📊 Batch {batch_idx} metadata:")
    print(f"   Layers with data: {list(retrieved_tensors.keys())}")

    if "attention_masks" in retrieved_tensors:
        attention_mask = retrieved_tensors["attention_masks"].get("attention_mask")
        if attention_mask is not None:
            print("   ✅ Attention mask found:")
            print(f"      Shape: {attention_mask.shape}")
            print(f"      Dtype: {attention_mask.dtype}")

    run_metadata = lm.store.get_run_metadata(run_name)
    if run_metadata:
        print("   ✅ Run metadata found:")
        print(f"      Model: {run_metadata.get('model', 'N/A')}")
        print(f"      Batch size: {run_metadata.get('options', {}).get('batch_size', 'N/A')}")
        print(f"      Dataset length: {run_metadata.get('dataset', {}).get('length', 'N/A')}")
    
    print()
    print("📊 All batches summary:")
    for i in range(min(3, len(batches))):
        _, tensors = lm.store.get_detector_metadata(run_name, i)
        # This is the saved mask tensor itself; its shape is printed below.
        mask = tensors.get("attention_masks", {}).get("attention_mask")
        if mask is not None:
            print(f"   Batch {i}: attention_mask shape {mask.shape}")
else:
    print("❌ No batches found")

🔍 Verifying saved metadata from infer_dataset()...

📊 Batch 0 metadata:
   Layers with data: ['attention_masks']
   ✅ Attention mask found:
      Shape: torch.Size([4, 128])
      Dtype: torch.bool
   ✅ Run metadata found:
      Model: LlamaForCausalLM
      Batch size: 4
      Dataset length: 10

📊 All batches summary:
   Batch 0: attention_mask shape torch.Size([4, 128])
   Batch 1: attention_mask shape torch.Size([4, 128])
   Batch 2: attention_mask shape torch.Size([2, 128])
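
The batch shapes above follow from simple arithmetic: 10 samples with a batch size of 4 give two full batches and one batch of 2, and the 6 texts passed to `infer_texts()` with a batch size of 3 give two batches. A small helper makes this expectation explicit. `expected_batch_sizes` is a hypothetical name, not part of the library, and it assumes the final partial batch is kept rather than dropped, as the output above suggests.

```python
from typing import List


def expected_batch_sizes(num_samples: int, batch_size: int) -> List[int]:
    """Return the size of each batch when the last partial batch is kept."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)
    return sizes


print(expected_batch_sizes(10, 4))  # [4, 4, 2]
print(expected_batch_sizes(6, 3))   # [3, 3]
```

Comparing this list with the first dimension of each saved attention mask is one way to check that no batch was lost or duplicated.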


## Summary

This example demonstrated:

1. ✅ **Loading the Bielik model** - Loaded from Hugging Face
2. ✅ **Attaching hooks** - `ModelInputDetector` for attention masks
3. ✅ **Verifying that hooks work** - Confirmed the hook captures the attention mask during inference
4. ✅ **Mode 1: `infer_texts()`** - Inference on a list of texts with metadata saving
5. ✅ **Mode 2: `infer_dataset()`** - Inference on a whole dataset with batch processing
6. ✅ **Verification** - Confirmed that attention masks and run metadata were saved to the store

**Key Benefits:**
- `infer_texts()` - Simple inference on text lists with optional batching
- `infer_dataset()` - Batch processing of whole datasets
- Both methods save metadata automatically when `run_name` is provided
- The same hook works with both inference modes
- The metadata structure is the same in both modes

**Conclusion:** ✅ The new inference API cleanly separates basic inference from activation saving.