# Test Notebook: Labeling Utils

This notebook tests the `labeling_utils` module for extracting frame-wise posteriors from CTC models.

**Features tested:**
1. Model loading and vocabulary information (HuggingFace MMS)
2. Emission extraction from a single audio file
3. Greedy CTC decoding of the emissions
4. Batched emission extraction
5. Multilingual MMS (per-language adapters)
6. TorchAudio pipeline backend (MMS_FA)
7. Integration with `audio_frontend`

**Architecture:**
The module uses a plugin-style backend system:
- `labeling_utils.base`: Core abstractions (CTCModelBackend, VocabInfo, BackendConfig)
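- `labeling_utils.registry`: Backend registration and discovery
- `labeling_utils.emissions`: `get_emissions()`, `get_emissions_batched()`, `EmissionResult`
- `labeling_utils.models`: `load_model()` and model presets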
- `labeling_utils.backends/`: Individual backend implementations
  - `huggingface.py`: HuggingFace Transformers (MMS, Wav2Vec2)
  - `torchaudio_backend.py`: TorchAudio pipelines (MMS_FA)
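The plugin-style design amounts to a name → class registry with aliases. The summary at the end of this notebook shows the module's own `register_backend(name, cls, aliases=...)`. The sketch below is a hypothetical, self-contained illustration of that idea. `BackendSketch`, `BackendRegistry` and `DummyBackend` are made-up names and are not the contents of `registry.py`.

```python
from abc import ABC, abstractmethod
from typing import Dict, Iterable, List, Type


class BackendSketch(ABC):
    """Hypothetical stand-in for labeling_utils.base.CTCModelBackend."""

    @abstractmethod
    def load(self) -> None: ...

    @abstractmethod
    def get_labels(self) -> List[str]: ...


class BackendRegistry:
    """Maps canonical backend names (and their aliases) to backend classes."""

    def __init__(self) -> None:
        self._backends: Dict[str, Type[BackendSketch]] = {}
        self._aliases: Dict[str, str] = {}

    def register(self, name: str, cls: Type[BackendSketch], aliases: Iterable[str] = ()) -> None:
        if not issubclass(cls, BackendSketch):
            raise TypeError(f"{cls.__name__} must subclass BackendSketch")
        self._backends[name] = cls
        for alias in aliases:
            self._aliases[alias] = name  # alias -> canonical name

    def get(self, name: str) -> Type[BackendSketch]:
        canonical = self._aliases.get(name, name)
        if canonical not in self._backends:
            raise KeyError(f"unknown backend {name!r}; available: {self.names()}")
        return self._backends[canonical]

    def names(self) -> List[str]:
        return sorted(self._backends)


class DummyBackend(BackendSketch):
    def load(self) -> None:
        self.loaded = True

    def get_labels(self) -> List[str]:
        return ["<blank>", "a", "b"]


registry = BackendRegistry()
registry.register("dummy", DummyBackend, aliases=["dm"])
print(registry.names(), registry.get("dm").__name__)  # ['dummy'] DummyBackend
```

With this design, adding a backend (NeMo, ESPnet, ...) requires no change to `load_model()`. Only a new subclass and one registration call are needed.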

## Setup

In [1]:
# Install dependencies
!pip install -q transformers torch torchaudio

In [2]:
# =============================================================================
# Setup: Clone Repository and Configure Imports
# =============================================================================

import sys
import os
from pathlib import Path

# ===== CONFIGURATION =====
GITHUB_REPO = "https://github.com/huangruizhe/torchaudio_aligner.git"
BRANCH = "dev"  # Use 'dev' for testing, 'main' for stable
# =========================

def setup_imports():
    """Setup Python path for imports based on environment."""
    
    IN_COLAB = 'google.colab' in sys.modules
    
    if IN_COLAB:
        repo_path = '/content/torchaudio_aligner'
        src_path = f'{repo_path}/src'
        
        if not os.path.exists(repo_path):
            print(f"Cloning repository (branch: {BRANCH})...")
            os.system(f'git clone -b {BRANCH} {GITHUB_REPO} {repo_path}')
            print("Repository cloned")
        else:
            print(f"Updating repository (branch: {BRANCH})...")
            os.system(f'cd {repo_path} && git fetch origin && git checkout {BRANCH} && git pull origin {BRANCH}')
            print("Repository updated")
    else:
        possible_paths = [
            Path(".").absolute().parent / "src",
            Path(".").absolute() / "src",
        ]
        
        src_path = None
        for p in possible_paths:
            if p.exists() and (p / "labeling_utils").exists():
                src_path = str(p.absolute())
                break
        
        if src_path is None:
            raise FileNotFoundError("src directory not found")
        
        print(f"Running locally from: {src_path}")
    
    if src_path not in sys.path:
        sys.path.insert(0, src_path)
    
    return src_path

src_path = setup_imports()

# Import labeling_utils - new modular structure
from labeling_utils import (
    # High-level API
    load_model,
    get_emissions,
    get_emissions_batched,
    EmissionResult,
    # Model configuration
    ModelConfig,
    list_presets,
    get_model_info,
    # Backend system
    list_backends,
    get_backend,
    is_backend_available,
    # Core classes
    CTCModelBackend,
    VocabInfo,
    BackendConfig,
)

import torch
import logging
logging.basicConfig(level=logging.INFO)

print()
print("=" * 60)
print("Labeling Utils imported successfully!")
print("=" * 60)
print(f"Available backends: {list_backends()}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
print()
print("Available presets:")
for preset in list_presets():
    info = get_model_info(preset)
    print(f"  {preset}: {info['model_name']} ({info['backend']}, {info['languages']} languages)")

Updating repository (branch: dev)...
Repository updated

Labeling Utils imported successfully!
Available backends: ['huggingface', 'torchaudio']
Device: cuda

Available presets:
  mms: facebook/mms-1b-all (huggingface, 1100+ languages)
  mms-1b-all: facebook/mms-1b-all (huggingface, 1100+ languages)
  mms-1b-fl102: facebook/mms-1b-fl102 (huggingface, 102 languages)
  mms-300m: facebook/mms-300m (huggingface, Multiple languages)
  mms-fa: MMS_FA (torchaudio, 1130+ languages)
  mms-fa-torchaudio: MMS_FA (torchaudio, 1130+ languages)
  mms-fa-hf: MahmoudAshraf/mms-300m-1130-forced-aligner (huggingface, 1130+ languages)
  wav2vec2-base: facebook/wav2vec2-base-960h (huggingface, English languages)
  wav2vec2-large: facebook/wav2vec2-large-960h-lv60-self (huggingface, English languages)
  wav2vec2-large-lv60: facebook/wav2vec2-large-960h-lv60-self (huggingface, English languages)
  wav2vec2-xlsr: facebook/wav2vec2-large-xlsr-53 (huggingface, 53 languages)
  wav2vec2-base-ta: WAV2VEC2_ASR_BAS

## Test 1: Load MMS Model

In [3]:
print("=" * 60)
print("Test 1: Load MMS Model (HuggingFace Backend)")
print("=" * 60)

try:
    # Load MMS model for English
    backend = load_model(
        "facebook/mms-1b-all",
        language="eng",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    
    print(f"Model loaded: {backend}")
    print(f"Is loaded: {backend.is_loaded}")
    print(f"Frame duration: {backend.frame_duration}s")
    print(f"Sample rate: {backend.sample_rate}Hz")
    
    # Get vocab info
    vocab = backend.get_vocab_info()
    print(f"\nVocabulary:")
    print(f"  Size: {len(vocab.labels)}")
    print(f"  Blank ID: {vocab.blank_id} ('{vocab.blank_token}')")
    print(f"  UNK ID: {vocab.unk_id} ('{vocab.unk_token}')")
    print(f"  Sample labels: {vocab.labels[:10]}...")
    
    print("\nTest 1 PASSED - MMS model loaded successfully")
except Exception as e:
    print(f"\nTest 1 FAILED: {e}")
    import traceback
    traceback.print_exc()

Test 1: Load MMS Model (HuggingFace Backend)


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).



Model loaded: HuggingFaceCTCBackend(model='facebook/mms-1b-all', loaded=True)
Is loaded: True
Frame duration: 0.02s
Sample rate: 16000Hz

Vocabulary:
  Size: 154
  Blank ID: 0 ('<pad>')
  UNK ID: 3 ('<unk>')
  Sample labels: ['<pad>', '<s>', '</s>', '<unk>', '|', 'e', 't', 'a', 'o', 'i']...

Test 1 PASSED - MMS model loaded successfully
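The vocabulary printed above has the following pieces: a label list, a blank index, an UNK index, and an id → label lookup. The greedy decoder in Test 3 uses `id_to_label`. The dataclass below is a hypothetical sketch of such a structure. Its field names mirror the attributes Test 1 prints, but it is not the real `VocabInfo` definition.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class VocabSketch:
    """Hypothetical stand-in for labeling_utils.VocabInfo."""
    labels: List[str]
    blank_id: int = 0
    unk_id: Optional[int] = None
    # Derived lookups, filled in __post_init__ rather than passed by the caller
    id_to_label: Dict[int, str] = field(init=False, default_factory=dict)
    label_to_id: Dict[str, int] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        self.id_to_label = dict(enumerate(self.labels))
        self.label_to_id = {label: i for i, label in enumerate(self.labels)}

    @property
    def blank_token(self) -> str:
        return self.labels[self.blank_id]

    @property
    def unk_token(self) -> Optional[str]:
        return None if self.unk_id is None else self.labels[self.unk_id]


# The first ten MMS English labels, as printed in Test 1
demo_vocab = VocabSketch(["<pad>", "<s>", "</s>", "<unk>", "|", "e", "t", "a", "o", "i"],
                         blank_id=0, unk_id=3)
print(demo_vocab.blank_token, demo_vocab.unk_token, demo_vocab.id_to_label[4])  # <pad> <unk> |
```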


## Test 2: Extract Emissions from VOiCES Sample Audio

Using the Lab41 VOiCES dataset sample: "I had that curiosity beside me at this moment"
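The test below checks that `exp` of one emission row sums to about 1. That holds if the backend returns log-softmax outputs (log posteriors). The pure-Python sketch below uses made-up logits for a single frame to show why. It assumes the standard log-softmax definition and does not reproduce the backend's own code.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax: x_i - logsumexp(x)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


frame_logits = [2.0, 0.5, -1.0, 3.0]  # made-up scores for one frame, 4-token vocab
log_probs = log_softmax(frame_logits)
prob_sum = sum(math.exp(lp) for lp in log_probs)
top_id = max(range(len(log_probs)), key=log_probs.__getitem__)
print(f"sum={prob_sum:.4f}, argmax={top_id}, p={math.exp(log_probs[top_id]):.3f}")
```

Subtracting the maximum before exponentiating avoids overflow for large logits. It does not change the result.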

In [None]:
print("=" * 60)
print("Test 2: Extract Emissions from VOiCES Sample Audio")
print("=" * 60)

try:
    import torchaudio
    import urllib.request
    import os
    
    # VOiCES sample audio from repository
    # Transcript: "I had that curiosity beside me at this moment"
    SAMPLE_AUDIO = "Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
    SAMPLE_TEXT = "I had that curiosity beside me at this moment"
    
    # Determine path based on environment
    IN_COLAB = 'google.colab' in sys.modules
    if IN_COLAB:
        sample_path = f"/content/torchaudio_aligner/examples/{SAMPLE_AUDIO}"
    else:
        sample_path = str(Path(src_path).parent / "examples" / SAMPLE_AUDIO)
    
    if not os.path.exists(sample_path):
        # Download from GitHub if not found locally
        url = f"https://raw.githubusercontent.com/huangruizhe/torchaudio_aligner/{BRANCH}/examples/{SAMPLE_AUDIO}"
        print(f"Downloading sample audio...")
        urllib.request.urlretrieve(url, SAMPLE_AUDIO)
        sample_path = SAMPLE_AUDIO
        print(f"Downloaded: {sample_path}")
    
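    # Load audio and downmix to mono if needed (the CTC models take one channel)
    waveform, sample_rate = torchaudio.load(sample_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)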
    print(f"Loaded: {sample_path}")
    print(f"  Transcript: \"{SAMPLE_TEXT}\"")
    print(f"  Waveform shape: {waveform.shape}")
    print(f"  Sample rate: {sample_rate}Hz")
    print(f"  Duration: {waveform.shape[1] / sample_rate:.2f}s")
    
    # Resample if needed
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        sample_rate = 16000
        print(f"  Resampled to: {sample_rate}Hz")
    
    # Extract emissions
    result = get_emissions(backend, waveform, sample_rate=sample_rate)
    
    print(f"\nEmission result:")
    print(f"  Emissions shape: {result.emissions.shape}")
    print(f"  Num frames: {result.num_frames}")
    print(f"  Vocab size: {result.vocab_size}")
    print(f"  Duration: {result.duration:.2f}s")
    
    # Verify log probabilities (should sum to ~1 after exp)
    probs = torch.exp(result.emissions[0])
    prob_sum = probs.sum().item()
    print(f"  Prob sum at frame 0: {prob_sum:.4f} (should be ~1.0)")
    
    # Show top predictions for a few frames
    print(f"\nTop predictions (frames 10-15):")
    vocab = result.vocab_info
    for i in range(10, min(15, result.num_frames)):
        top_idx = result.emissions[i].argmax().item()
        top_prob = torch.exp(result.emissions[i, top_idx]).item()
        label = vocab.id_to_label.get(top_idx, "?")
        print(f"  Frame {i}: '{label}' (prob={top_prob:.3f})")
    
    print("\nTest 2 PASSED - Emissions extracted from VOiCES sample")
except Exception as e:
    print(f"\nTest 2 FAILED: {e}")
    import traceback
    traceback.print_exc()

## Test 3: Greedy Decoding from Emissions

Decode the emissions to verify the model recognizes the speech content.
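Greedy CTC decoding has a fixed order of operations. First, runs of identical frame labels collapse. Second, blanks are dropped. Doing it in this order is what lets a blank separate a genuine double letter. The toy example below uses made-up frame labels, with `<pad>` as blank and `|` as word boundary, as in the MMS vocabulary from Test 1.

```python
from itertools import groupby
from typing import List


def ctc_greedy_collapse(frame_labels: List[str], blank: str = "<pad>", word_sep: str = "|") -> str:
    """Collapse runs of identical labels, drop blanks, map the word separator to a space."""
    collapsed = [label for label, _run in groupby(frame_labels)]
    chars = [" " if label == word_sep else label for label in collapsed if label != blank]
    return "".join(chars).strip()


# Made-up per-frame argmax labels. A blank between the two "o" runs and between
# the two "l" runs keeps both double letters.
frames = ["<pad>", "t", "t", "o", "<pad>", "o", "|", "|",
          "t", "a", "a", "l", "<pad>", "l", "<pad>"]
print(ctc_greedy_collapse(frames))                # too tall
print(ctc_greedy_collapse(["t", "a", "l", "l"]))  # tal (no blank, so "l l" collapses)
```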

In [None]:
print("=" * 60)
print("Test 3: Greedy Decoding from Emissions")
print("=" * 60)

def greedy_decode(emissions: torch.Tensor, vocab_info: VocabInfo) -> str:
    """Greedy CTC decoding: argmax per frame, collapse repeats, then drop blanks.

    Defined at top level (outside the try block) so that Tests 4, 6 and 7 can
    reuse it even if this test fails.
    """
    # Most likely token at each frame
    indices = emissions.argmax(dim=-1).tolist()
    
    # Collapse consecutive duplicates (must happen before blank removal)
    collapsed = []
    prev = None
    for idx in indices:
        if idx != prev:
            collapsed.append(idx)
        prev = idx
    
    # Remove blanks and map ids to labels; "|" is the word-boundary token
    tokens = []
    for idx in collapsed:
        if idx == vocab_info.blank_id:
            continue
        label = vocab_info.id_to_label.get(idx, "")
        tokens.append(" " if label == "|" else label)
    
    return "".join(tokens).strip()

try:
    # Decode the emissions extracted in Test 2
    decoded = greedy_decode(result.emissions, result.vocab_info)
    
    print(f"Ground truth: \"{SAMPLE_TEXT}\"")
    print(f"Decoded:      \"{decoded}\"")
    
    # Check if decoding roughly matches
    gt_normalized = SAMPLE_TEXT.lower().replace("'", "")
    decoded_normalized = decoded.lower().replace("'", "")
    
    # Simple word overlap check
    gt_words = set(gt_normalized.split())
    decoded_words = set(decoded_normalized.split())
    overlap = len(gt_words & decoded_words)
    total = len(gt_words)
    
    print(f"\nWord overlap: {overlap}/{total} ({100*overlap/total:.0f}%)")
    
    if overlap >= total // 2:
        print("\nTest 3 PASSED - Greedy decoding produces reasonable output")
    else:
        print("\nTest 3 WARNING - Low word overlap (check the audio file, sample rate, and language adapter)")
        
except Exception as e:
    print(f"\nTest 3 FAILED: {e}")
    import traceback
    traceback.print_exc()
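The set-based overlap check above ignores word order and repeated words, so a scrambled transcript can still score 100%. A stricter alternative is word error rate (WER): the word-level Levenshtein distance divided by the number of reference words. The pure-Python sketch below is an illustration only. `word_error_rate` is not part of `labeling_utils`.

```python
from typing import List


def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (substitutions + deletions + insertions) / number of reference words."""
    ref: List[str] = reference.lower().split()
    hyp: List[str] = hypothesis.lower().split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j]: edit distance between the first i-1 reference words and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(prev[j] + 1,         # deletion
                          curr[j - 1] + 1,     # insertion
                          prev[j - 1] + cost)  # substitution or match
        prev = curr
    return prev[-1] / len(ref)


ref_text = "I had that curiosity beside me at this moment"
print(word_error_rate(ref_text, "i had the curiosity beside me at this moment"))  # one substitution
scrambled = " ".join(reversed(ref_text.lower().split()))
print(word_error_rate(ref_text, scrambled))  # high WER, though set overlap would be 100%
```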

## Test 4: Batched Emission Extraction

Test batch processing using variations of the VOiCES sample (different segments).
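Variable-length inputs can only share a forward pass after padding. Each item's true length then has to be kept, so that padded frames can be trimmed from its emissions. The sketch below shows that bookkeeping in pure Python under the assumption that `get_emissions_batched` pads each batch to its longest waveform. `make_padded_batches` is a hypothetical helper, not the module's implementation.

```python
from typing import List, Sequence, Tuple


def make_padded_batches(
    waveforms: Sequence[Sequence[float]], batch_size: int, pad_value: float = 0.0
) -> List[Tuple[List[List[float]], List[int]]]:
    """Split inputs into chunks of batch_size, right-pad each chunk to its longest
    item, and keep the true lengths so padding can be ignored downstream."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    batches = []
    for start in range(0, len(waveforms), batch_size):
        chunk = [list(w) for w in waveforms[start:start + batch_size]]
        lengths = [len(w) for w in chunk]
        max_len = max(lengths)
        padded = [w + [pad_value] * (max_len - len(w)) for w in chunk]
        batches.append((padded, lengths))
    return batches


# Three toy "waveforms" of different lengths, batch_size=2 as in Test 4
toy = [[0.1] * 5, [0.2] * 8, [0.3] * 3]
for padded, lengths in make_padded_batches(toy, batch_size=2):
    print(len(padded), len(padded[0]), lengths)
# 2 8 [5, 8]
# 1 3 [3]
```

With three inputs and `batch_size=2`, this gives two forward passes. Each padded batch only needs to be as long as its own longest item.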

In [None]:
print("=" * 60)
print("Test 4: Batched Emission Extraction")
print("=" * 60)

try:
    # Create multiple waveforms from the VOiCES sample (different segments)
    # Use the waveform loaded in Test 2
    full_wav = waveform.squeeze(0)  # Remove channel dim
    total_samples = full_wav.shape[0]
    
    # Create 3 segments of different lengths
    waveforms = [
        full_wav[:total_samples // 3],           # First third
        full_wav[total_samples // 4:],            # Last 3/4
        full_wav[total_samples // 3:2*total_samples // 3],  # Middle third
    ]
    
    print(f"Input: {len(waveforms)} waveforms from VOiCES sample")
    for i, w in enumerate(waveforms):
        print(f"  [{i}] shape={w.shape}, duration={len(w)/16000:.2f}s")
    
    # Extract emissions in batch
    results = get_emissions_batched(
        backend,
        waveforms,
        sample_rate=16000,
        batch_size=2,
    )
    
    print(f"\nOutput: {len(results)} EmissionResults")
    for i, res in enumerate(results):
        decoded = greedy_decode(res.emissions, res.vocab_info)
        print(f"  [{i}] frames={res.num_frames}, duration={res.duration:.2f}s")
        print(f"       decoded: \"{decoded[:50]}{'...' if len(decoded) > 50 else ''}\"")
    
    assert len(results) == len(waveforms), "Output count mismatch"
    
    print("\nTest 4 PASSED - Batched extraction works")
except Exception as e:
    print(f"\nTest 4 FAILED: {e}")
    import traceback
    traceback.print_exc()

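## Test 5: Different Languages (MMS Multilingual)

Each language loads its own MMS adapter (`adapter.<lang>.safetensors`). The "Some weights of Wav2Vec2ForCTC were not initialized" warning in the output is expected with MMS in Transformers. The base checkpoint's `lm_head` covers the 154-token English vocabulary, while other languages need a different output size. The head is therefore re-created and then loaded from the language adapter, which is why the reported vocabulary sizes (314, 2268) match the emission shapes. The `'cmn'` failure message is how a `KeyError` prints: that language code was not found.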

In [7]:
print("=" * 60)
print("Test 5: Load MMS for Different Languages")
print("=" * 60)

# Test a few languages
languages = [
    ("fra", "French"),
    ("cmn", "Mandarin Chinese"),
    ("jpn", "Japanese"),
]

for lang_code, lang_name in languages:
    print(f"\nLoading MMS for {lang_name} ({lang_code})...")
    try:
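        lang_backend = load_model(
            "facebook/mms-1b-all",
            language=lang_code,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        
        lang_vocab = lang_backend.get_vocab_info()
        print(f"  Loaded! Vocab size: {len(lang_vocab.labels)}")
        
        # Quick emission test on 1 s of random noise (checks shapes, not accuracy).
        # Separate names keep `result` and `vocab` from Tests 1-2 intact.
        test_wav = torch.randn(16000)
        lang_result = get_emissions(lang_backend, test_wav, sample_rate=16000)
        print(f"  Emissions shape: {lang_result.emissions.shape}")
        print(f"  PASSED")
        
        # Free this ~1B-parameter model before loading the next language
        del lang_backend
        if torch.cuda.is_available():
            torch.cuda.empty_cache()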
        
    except Exception as e:
        print(f"  FAILED: {e}")

print("\nTest 5 Complete")

Test 5: Load MMS for Different Languages

Loading MMS for French (fra)...



Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/mms-1b-all and are newly initialized because the shapes did not match:
- lm_head.bias: found shape torch.Size([154]) in the checkpoint and torch.Size([314]) in the model instantiated
- lm_head.weight: found shape torch.Size([154, 1280]) in the checkpoint and torch.Size([314, 1280]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  Loaded! Vocab size: 314
  Emissions shape: torch.Size([49, 314])
  PASSED

Loading MMS for Mandarin Chinese (cmn)...
  FAILED: 'cmn'

Loading MMS for Japanese (jpn)...



Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/mms-1b-all and are newly initialized because the shapes did not match:
- lm_head.bias: found shape torch.Size([154]) in the checkpoint and torch.Size([2268]) in the model instantiated
- lm_head.weight: found shape torch.Size([154, 1280]) in the checkpoint and torch.Size([2268, 1280]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  Loaded! Vocab size: 2268
  Emissions shape: torch.Size([49, 2268])
  PASSED

Test 5 Complete
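Both languages that loaded report 49 frames for the 1-second (16,000-sample) input, and Test 1 reports a 0.02 s frame duration. Both numbers follow from the convolutional feature encoder. The sketch below assumes MMS-1B uses the standard Wav2Vec2 conv stack, with Hugging Face `Wav2Vec2Config` defaults `conv_kernel=(10, 3, 3, 3, 3, 2, 2)` and `conv_stride=(5, 2, 2, 2, 2, 2, 2)`, and applies the usual unpadded-convolution length formula layer by layer.

```python
from typing import List, Tuple

# (kernel_size, stride) of the seven conv layers in the standard Wav2Vec2 feature encoder
CONV_LAYERS: List[Tuple[int, int]] = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]


def num_frames(num_samples: int) -> int:
    """Emission frames for an unpadded input: floor((L - kernel) / stride) + 1 per layer."""
    length = num_samples
    for kernel, stride in CONV_LAYERS:
        length = (length - kernel) // stride + 1
    return length


def frame_duration(sample_rate: int = 16000) -> float:
    """Seconds per frame: product of the strides (5 * 2**6 = 320 samples) / sample rate."""
    hop = 1
    for _, stride in CONV_LAYERS:
        hop *= stride
    return hop / sample_rate


print(num_frames(16000), frame_duration())  # 49 0.02
```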


## Test 6: TorchAudio Pipeline Backend (MMS_FA)

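Note: this test goes through TorchAudio's pipeline API (`torchaudio.pipelines.MMS_FA`) instead of Hugging Face Transformers. MMS_FA is a single multilingual alignment model with a fixed vocabulary of romanized characters, so no language is passed to `load_model` here.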
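Both backends expose `frame_duration`, and that value is what turns frame indices into timestamps. The sketch below shows the arithmetic on a greedy (argmax) path. It is not a forced alignment, which would constrain the path to the transcript. The function `token_spans` and the toy path are hypothetical. The blank id 0 matches the MMS vocabulary in Test 1.

```python
from itertools import groupby
from typing import List, Tuple


def token_spans(frame_ids: List[int], blank_id: int,
                frame_duration: float = 0.02) -> List[Tuple[int, float, float]]:
    """Turn per-frame argmax ids into (token_id, start_s, end_s) spans, skipping blanks.
    A run of identical ids counts as one token occurrence."""
    spans = []
    frame = 0
    for token_id, run in groupby(frame_ids):
        n = len(list(run))
        if token_id != blank_id:
            spans.append((token_id, frame * frame_duration, (frame + n) * frame_duration))
        frame += n
    return spans


path = [0, 0, 5, 5, 0, 7, 7, 7, 0]  # toy per-frame argmax ids, blank = 0
for tok, start, end in token_spans(path, blank_id=0):
    print(tok, round(start, 2), round(end, 2))
# 5 0.04 0.08
# 7 0.1 0.16
```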

In [None]:
print("=" * 60)
print("Test 6: TorchAudio Pipeline Backend (MMS_FA)")
print("=" * 60)

try:
    # Check if torchaudio backend is available
    if not is_backend_available("torchaudio"):
        print("TorchAudio backend not available (missing dependencies)")
        print("Skipping test...")
    else:
        # Load MMS_FA using the preset
        ta_backend = load_model("mms-fa")  # Uses TorchAudio backend automatically
        
        print(f"Model loaded: {ta_backend}")
        print(f"Is loaded: {ta_backend.is_loaded}")
        print(f"Frame duration: {ta_backend.frame_duration}s")
        print(f"Sample rate: {ta_backend.sample_rate}Hz")
        
        # Get vocab info
        vocab = ta_backend.get_vocab_info()
        print(f"\nVocabulary:")
        print(f"  Size: {len(vocab.labels)}")
        print(f"  Labels: {vocab.labels}") 
        print(f"  Blank ID: {vocab.blank_id} ('{vocab.blank_token}')")
        print(f"  UNK ID: {vocab.unk_id} ('{vocab.unk_token}')")
        
        # Test with VOiCES sample
        print(f"\nTesting with VOiCES sample:")
        print(f"  Transcript: \"{SAMPLE_TEXT}\"")
        
        ta_result = get_emissions(ta_backend, waveform, sample_rate=16000)
        
        print(f"\nEmission result:")
        print(f"  Emissions shape: {ta_result.emissions.shape}")
        print(f"  Num frames: {ta_result.num_frames}")
        print(f"  Vocab size: {ta_result.vocab_size}")
        
        # MMS_FA emits romanized characters and has no "|" word-boundary token,
        # so the greedy output is expected to contain no spaces
        ta_decoded = greedy_decode(ta_result.emissions, ta_result.vocab_info)
        print(f"\n  Decoded (romanized): \"{ta_decoded}\"")
        
        print("\nTest 6 PASSED - TorchAudio Pipeline backend works with real audio")
except Exception as e:
    print(f"\nTest 6 FAILED: {e}")
    print("Note: This test requires torchaudio with MMS_FA pipeline.")
    import traceback
    traceback.print_exc()

## Test 7: Integration with Audio Frontend

In [ ]:
print("=" * 60)
print("Test 7: Integration with Audio Frontend")
print("=" * 60)

try:
    from audio_frontend import segment_audio
    
    # Use VOiCES sample - segment into smaller chunks
    print(f"Original audio: {waveform.shape[1]/16000:.2f}s")
    
    # Segment into overlapping chunks
    seg_result = segment_audio(
        waveform.squeeze(0), 
        sample_rate=16000, 
        segment_size=1.5,  # 1.5 second segments
        overlap=0.3
    )
    print(f"Segmented into {len(seg_result.segments)} segments")
    
    # Extract emissions for each segment
    all_emissions = []
    for i, segment in enumerate(seg_result.segments):
        seg_emission = get_emissions(backend, segment.waveform)
        all_emissions.append(seg_emission)
        decoded = greedy_decode(seg_emission.emissions, seg_emission.vocab_info)
        print(f"  Segment {i}: frames={seg_emission.num_frames}, decoded=\"{decoded}\"")
    
    print(f"\nTotal emissions extracted: {len(all_emissions)}")
    print(f"Total frames: {sum(e.num_frames for e in all_emissions)}")
    
    print("\nTest 7 PASSED - Audio frontend + labeling utils integration works")
except ImportError:
    print("audio_frontend not available - skipping integration test")
    print("This is expected if running labeling_utils tests only")
except Exception as e:
    print(f"\nTest 7 FAILED: {e}")
    import traceback
    traceback.print_exc()
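With overlapping segments, adjacent windows share audio. Summing frames across segments therefore counts the shared regions more than once. The sketch below shows the window arithmetic under the assumption that `overlap` is in seconds and each window starts `segment_size - overlap` seconds after the previous one. The actual semantics of `segment_audio` are not confirmed here, and `segment_bounds` is a hypothetical helper.

```python
from typing import List, Tuple


def segment_bounds(duration: float, segment_size: float,
                   overlap: float) -> List[Tuple[float, float]]:
    """Hypothetical fixed-window segmentation in seconds; the last window is
    clipped to the audio duration."""
    step = segment_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than segment_size")
    bounds = []
    start = 0.0
    while start < duration:
        end = min(start + segment_size, duration)
        bounds.append((round(start, 6), round(end, 6)))  # rounding hides float drift
        if start + segment_size >= duration:
            break
        start += step
    return bounds


windows = segment_bounds(4.0, segment_size=1.5, overlap=0.3)
print(windows)  # [(0.0, 1.5), (1.2, 2.7), (2.4, 3.9), (3.6, 4.0)]
covered = sum(end - start for start, end in windows)
print(round(covered, 6))  # 4.9 s of windows for 4.0 s of audio
```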

## Summary

In [9]:
print("=" * 60)
print("LABELING UTILS TEST SUMMARY")
print("=" * 60)
print("""
The labeling_utils module provides a plugin-style architecture for
extracting frame-wise posteriors from CTC acoustic models.

ARCHITECTURE:
├── base.py          - Core abstractions (CTCModelBackend, VocabInfo, BackendConfig)
├── registry.py      - Backend registration and discovery
├── emissions.py     - get_emissions(), EmissionResult
├── models.py        - load_model(), model presets
└── backends/        - Individual backend implementations
    ├── huggingface.py      - HuggingFace Transformers
    └── torchaudio_backend.py - TorchAudio pipelines

HIGH-LEVEL API:
1. load_model() - Load CTC models with automatic backend detection
   - Presets: "mms", "mms-fa", "wav2vec2-base", etc.
   - Direct model IDs: "facebook/mms-1b-all"

2. get_emissions() - Extract frame-wise log posteriors
   - Returns EmissionResult with emissions, lengths, vocab_info

3. get_emissions_batched() - Efficient batch processing

SUPPORTED BACKENDS:
- huggingface (hf): MMS (1100+ languages), Wav2Vec2, XLSR
- torchaudio (ta): MMS_FA, WAV2VEC2_ASR_*, HUBERT_ASR_*

EXTENDING:
    from labeling_utils import CTCModelBackend, register_backend

    class MyBackend(CTCModelBackend):
        def load(self): ...
        def get_emissions(self, waveform, lengths=None): ...
        def get_vocab_info(self): ...

    register_backend("mybackend", MyBackend, aliases=["mb"])

NEXT STEPS:
- Add NeMo backend (Conformer-CTC, QuartzNet)
- Add ESPnet backend
- Add OmniASR backend
- Integrate with k2 WFST for alignment
""")

LABELING UTILS TEST SUMMARY

The labeling_utils module provides a plugin-style architecture for
extracting frame-wise posteriors from CTC acoustic models.

ARCHITECTURE:
├── base.py          - Core abstractions (CTCModelBackend, VocabInfo, BackendConfig)
├── registry.py      - Backend registration and discovery
├── emissions.py     - get_emissions(), EmissionResult
├── models.py        - load_model(), model presets
└── backends/        - Individual backend implementations
    ├── huggingface.py      - HuggingFace Transformers
    └── torchaudio_backend.py - TorchAudio pipelines

HIGH-LEVEL API:
1. load_model() - Load CTC models with automatic backend detection
   - Presets: "mms", "mms-fa", "wav2vec2-base", etc.
   - Direct model IDs: "facebook/mms-1b-all"

2. get_emissions() - Extract frame-wise log posteriors
   - Returns EmissionResult with emissions, lengths, vocab_info

3. get_emissions_batched() - Efficient batch processing

SUPPORTED BACKENDS:
- huggingface (hf): MMS (1100+ lang