# Test Notebook: Labeling Utils

This notebook tests the `labeling_utils` module for extracting frame-wise posteriors from CTC models.

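**Features tested:**
1. Model loading through the HuggingFace backend (the Wav2Vec2-based MMS checkpoint `facebook/mms-1b-all`), plus vocabulary inspection
2. Emission extraction from a single synthetic waveform
3. Emission extraction from real audio (an earnings-call recording)
4. Batched emission extraction
5. Loading MMS language adapters for other languages
6. The TorchAudio pipeline backend (`MMS_FA`)
7. Integration with `audio_frontend` (not implemented yet)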

## Setup

In [None]:
# Install dependencies
!pip install -q transformers torch torchaudio

In [None]:
# =============================================================================
# Setup: Clone Repository and Configure Imports
# =============================================================================

import sys
import os
from pathlib import Path

# ===== CONFIGURATION =====
GITHUB_REPO = "https://github.com/huangruizhe/torchaudio_aligner.git"
BRANCH = "dev"  # Use 'dev' for testing, 'main' for stable
# =========================

def setup_imports():
    """Put the repository's ``src`` directory on ``sys.path`` and return it.

    In Google Colab (detected by ``google.colab`` being imported), the repository
    is cloned into ``/content/torchaudio_aligner`` on ``BRANCH``, or updated if
    it is already there. Locally, ``src`` is searched for in the parent of the
    current directory and then in the current directory. A candidate is accepted
    only if it contains a ``labeling_utils`` entry.

    Returns:
        str: the ``src`` path that was added to (or was already on) ``sys.path``.

    Raises:
        RuntimeError: in Colab, if ``git clone`` exits with a non-zero status.
        FileNotFoundError: locally, if no candidate ``src`` directory is found.
    """
    import subprocess  # local import so this cell's top-level imports stay unchanged

    in_colab = 'google.colab' in sys.modules

    if in_colab:
        repo_path = '/content/torchaudio_aligner'
        src_path = f'{repo_path}/src'

        if not os.path.exists(repo_path):
            print(f"Cloning repository (branch: {BRANCH})...")
            # Argument lists avoid shell interpolation of the URL, branch and path.
            clone = subprocess.run(['git', 'clone', '-b', BRANCH, GITHUB_REPO, repo_path])
            if clone.returncode != 0:
                # Without the checkout every later import would fail confusingly; fail here instead.
                raise RuntimeError(
                    f"git clone of {GITHUB_REPO} (branch {BRANCH}) failed "
                    f"with exit code {clone.returncode}"
                )
            print("Repository cloned")
        else:
            print(f"Updating repository (branch: {BRANCH})...")
            update_cmds = [
                ['git', '-C', repo_path, 'fetch', 'origin'],
                ['git', '-C', repo_path, 'checkout', BRANCH],
                ['git', '-C', repo_path, 'pull', 'origin', BRANCH],
            ]
            # Like the original `a && b && c`, stop at the first failing step. Updating is
            # best effort: an existing checkout is still usable, so warn instead of raising.
            for cmd in update_cmds:
                if subprocess.run(cmd).returncode != 0:
                    print(f"WARNING: '{' '.join(cmd)}' failed; continuing with the existing checkout")
                    break
            else:
                print("Repository updated")
    else:
        possible_paths = [
            Path(".").absolute().parent / "src",
            Path(".").absolute() / "src",
        ]

        src_path = None
        for p in possible_paths:
            if p.exists() and (p / "labeling_utils").exists():
                src_path = str(p.absolute())
                break

        if src_path is None:
            raise FileNotFoundError("src directory not found")

        print(f"Running locally from: {src_path}")

    if src_path not in sys.path:
        sys.path.insert(0, src_path)

    return src_path

src_path = setup_imports()

# labeling_utils is importable only after setup_imports() has put src/ on sys.path.
from labeling_utils import (
    EmissionResult,
    ModelConfig,
    get_emissions,
    get_emissions_batched,
    list_backends,
    load_model,
)

import torch
import logging
logging.basicConfig(level=logging.INFO)

# Blank line, banner, message, banner -- one print call, same output as four.
print("\n".join(["", "=" * 60, "Labeling Utils imported successfully!", "=" * 60]))
print(f"Available backends: {list_backends()}")
print("Device: " + ("cuda" if torch.cuda.is_available() else "cpu"))

## Test 1: Load MMS Model

In [None]:
print("=" * 60)
print("Test 1: Load MMS Model (HuggingFace Backend)")
print("=" * 60)

# NOTE: `backend` is reused by Tests 2-4. If loading fails here, those cells
# will fail too (typically with a NameError on a fresh kernel).
try:
    # Load the multilingual MMS checkpoint, selecting the English ("eng") language.
    backend = load_model(
        "facebook/mms-1b-all",
        language="eng",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print(f"Model loaded: {backend}")
    print(f"Is loaded: {backend.is_loaded}")
    print(f"Frame duration: {backend.frame_duration}s")
    print(f"Sample rate: {backend.sample_rate}Hz")

    # Get vocab info
    vocab = backend.get_vocab_info()
    print("\nVocabulary:")
    print(f"  Size: {len(vocab.labels)}")
    print(f"  Blank ID: {vocab.blank_id} ('{vocab.blank_token}')")
    print(f"  UNK ID: {vocab.unk_id} ('{vocab.unk_token}')")
    print(f"  Sample labels: {vocab.labels[:10]}...")

    # Assert (not just print) the basic invariants, so a broken load is reported as FAILED.
    assert backend.is_loaded, "load_model() returned a backend that is not loaded"
    assert len(vocab.labels) > 0, "Vocabulary is empty"

    print("\nTest 1 PASSED - MMS model loaded successfully")
except Exception as e:
    print(f"\nTest 1 FAILED: {e}")
    import traceback
    traceback.print_exc()

## Test 2: Extract Emissions from Sample Audio

In [None]:
print("=" * 60)
print("Test 2: Extract Emissions from Sample Audio")
print("=" * 60)

try:
    # Synthetic input: 2 seconds of seeded Gaussian noise at 16 kHz. Noise is enough to
    # check shapes and normalisation; it is not expected to produce meaningful labels.
    # A local generator keeps the run reproducible without touching the global RNG.
    sample_rate = 16000
    duration = 2.0  # seconds
    gen = torch.Generator().manual_seed(0)
    waveform = torch.randn(1, int(sample_rate * duration), generator=gen)  # (1, samples)

    print(f"Input waveform shape: {waveform.shape}")
    print(f"Duration: {waveform.shape[1] / sample_rate:.2f}s")

    # Extract emissions
    result = get_emissions(backend, waveform, sample_rate=sample_rate)

    print("\nEmission result:")
    print(f"  Emissions shape: {result.emissions.shape}")
    print(f"  Num frames: {result.num_frames}")
    print(f"  Vocab size: {result.vocab_size}")
    print(f"  Duration: {result.duration:.2f}s")
    print(f"  Frame timestamps (first 10): {result.get_frame_timestamps()[:10].tolist()}")

    # Emissions are expected to be laid out as (num_frames, vocab_size).
    assert result.emissions.dim() == 2, f"Expected 2D tensor, got {result.emissions.dim()}D"
    assert result.emissions.shape[-1] == result.vocab_size
    assert result.emissions.shape[0] == result.num_frames, "num_frames disagrees with emissions shape"

    # Emissions are log-probabilities, so exp() of one frame should sum to ~1.
    probs = torch.exp(result.emissions[0])  # First frame
    prob_sum = probs.sum().item()
    print(f"\n  Prob sum at frame 0: {prob_sum:.4f} (should be ~1.0)")
    assert abs(prob_sum - 1.0) < 1e-2, f"Frame 0 probabilities sum to {prob_sum:.4f}, expected ~1.0"

    print("\nTest 2 PASSED - Emissions extracted successfully")
except Exception as e:
    print(f"\nTest 2 FAILED: {e}")
    import traceback
    traceback.print_exc()

## Test 3: Real Audio - Meta Earnings Call

In [None]:
print("=" * 60)
print("Test 3: Real Audio - Extract Emissions from Meta Earnings Call")
print("=" * 60)

try:
    import torchaudio
    import urllib.request
    import os

    SEGMENT_SECONDS = 10   # length of the excerpt to analyse
    TARGET_SR = 16000      # rate the audio is resampled to before extracting emissions
    # torchaudio.load's `num_frames` counts samples at the file's *native* rate, which is
    # unknown before loading. Read enough samples for SEGMENT_SECONDS at up to 48 kHz,
    # then trim to exactly SEGMENT_SECONDS once the native rate is known.
    MAX_NATIVE_SR = 48000

    # Download the recording once and cache it locally
    audio_url = "https://static.seekingalpha.com/cdn/s3/transcripts_audio/4780182.mp3"
    audio_file = "meta_earnings.mp3"

    if not os.path.exists(audio_file):
        print("Downloading audio...")
        # Download under a temporary name first, so an interrupted download does not
        # leave a truncated file that later runs would silently reuse.
        partial_file = audio_file + ".part"
        urllib.request.urlretrieve(audio_url, partial_file)
        os.replace(partial_file, audio_file)
        print(f"Downloaded: {audio_file}")

    # Load the first SEGMENT_SECONDS seconds (trimmed at the native sample rate)
    waveform, sample_rate = torchaudio.load(audio_file, num_frames=SEGMENT_SECONDS * MAX_NATIVE_SR)
    waveform = waveform[:, : SEGMENT_SECONDS * sample_rate]
    print(f"Loaded waveform: shape={waveform.shape}, sample_rate={sample_rate}")

    # Downmix to mono first, so only one channel has to be resampled
    # (resampling is linear, so the order does not change the result).
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed
    if sample_rate != TARGET_SR:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SR)
        sample_rate = TARGET_SR

    print(f"Preprocessed: shape={waveform.shape}")
    loaded_seconds = waveform.shape[1] / sample_rate
    assert abs(loaded_seconds - SEGMENT_SECONDS) < 0.1, (
        f"Expected ~{SEGMENT_SECONDS}s of audio, got {loaded_seconds:.2f}s"
    )

    # Extract emissions
    result = get_emissions(backend, waveform.squeeze(0), sample_rate=sample_rate)

    print("\nEmission result:")
    print(f"  Emissions shape: {result.emissions.shape}")
    print(f"  Num frames: {result.num_frames}")
    print(f"  Duration: {result.duration:.2f}s")
    assert result.emissions.dim() == 2, f"Expected 2D tensor, got {result.emissions.dim()}D"

    # Show top predictions for first few frames
    print("\nTop predictions (first 5 frames):")
    vocab = result.vocab_info
    for i in range(min(5, result.num_frames)):
        top_idx = result.emissions[i].argmax().item()
        top_prob = torch.exp(result.emissions[i, top_idx]).item()
        label = vocab.id_to_label.get(top_idx, "?")
        print(f"  Frame {i}: '{label}' (prob={top_prob:.3f})")

    print("\nTest 3 PASSED - Real audio emissions extracted successfully")
except Exception as e:
    print(f"\nTest 3 FAILED: {e}")
    import traceback
    traceback.print_exc()

## Test 4: Batched Emission Extraction

In [None]:
print("=" * 60)
print("Test 4: Batched Emission Extraction")
print("=" * 60)

try:
    # Seeded noise clips of different lengths, so batching has to pad.
    sample_rate = 16000
    gen = torch.Generator().manual_seed(0)
    clip_seconds = [1.0, 2.0, 1.5]
    waveforms = [torch.randn(int(sample_rate * s), generator=gen) for s in clip_seconds]

    print(f"Input: {len(waveforms)} waveforms")
    for i, w in enumerate(waveforms):
        print(f"  [{i}] shape={w.shape}, duration={len(w)/sample_rate:.2f}s")

    # With 3 inputs and batch_size=2 there is presumably one full and one partial batch.
    results = get_emissions_batched(
        backend,
        waveforms,
        sample_rate=sample_rate,
        batch_size=2,
    )

    print(f"\nOutput: {len(results)} EmissionResults")
    for i, result in enumerate(results):
        print(f"  [{i}] emissions shape={result.emissions.shape}, duration={result.duration:.2f}s")

    assert len(results) == len(waveforms), "Output count mismatch"
    for i, result in enumerate(results):
        assert result.emissions.dim() == 2, f"[{i}] expected 2D emissions, got {result.emissions.dim()}D"
        assert result.emissions.shape[-1] == result.vocab_size, f"[{i}] emissions width != vocab_size"

    # Each result must cover only its own input, not the padded batch length:
    # a longer input must yield strictly more frames, in input order.
    by_length = sorted(range(len(waveforms)), key=lambda k: len(waveforms[k]))
    frame_counts = [r.num_frames for r in results]
    assert all(frame_counts[a] < frame_counts[b] for a, b in zip(by_length, by_length[1:])), (
        f"Frame counts {frame_counts} do not follow input lengths; outputs may be padded or reordered"
    )

    print("\nTest 4 PASSED - Batched extraction works")
except Exception as e:
    print(f"\nTest 4 FAILED: {e}")
    import traceback
    traceback.print_exc()

## Test 5: Different Languages (MMS Multilingual)

In [None]:
print("=" * 60)
print("Test 5: Load MMS for Different Languages")
print("=" * 60)

# Test a few languages
languages = [
    ("fra", "French"),
    ("cmn", "Mandarin Chinese"),
    ("jpn", "Japanese"),
]
failed_languages = []

for lang_code, lang_name in languages:
    print(f"\nLoading MMS for {lang_name} ({lang_code})...")
    lang_backend = None
    try:
        lang_backend = load_model(
            "facebook/mms-1b-all",
            language=lang_code,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

        vocab = lang_backend.get_vocab_info()
        print(f"  Loaded! Vocab size: {len(vocab.labels)}")

        # Quick emission test: 1 second of noise at 16 kHz
        test_wav = torch.randn(16000)
        result = get_emissions(lang_backend, test_wav, sample_rate=16000)
        print(f"  Emissions shape: {result.emissions.shape}")
        print("  PASSED")

    except Exception as e:
        print(f"  FAILED: {e}")
        failed_languages.append(lang_code)
    finally:
        # Drop this language's backend before loading the next one. Otherwise the previous
        # 1B-parameter model stays referenced while the next one loads (and the last one after
        # the loop). NOTE(review): this frees memory only if load_model does not cache models.
        del lang_backend
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

if failed_languages:
    print(f"\nTest 5 Complete - failed for: {', '.join(failed_languages)}")
else:
    print("\nTest 5 Complete")

## Test 6: TorchAudio Pipeline Backend (MMS_FA)

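Note: This test uses the TorchAudio pipeline API, so it does not go through `load_model()`. It builds the backend directly with `TorchAudioPipelineBackend(BackendConfig(...))` and calls `.load()`. If the installed torchaudio does not provide the `MMS_FA` pipeline, a failure here is expected.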

In [None]:
print("=" * 60)
print("Test 6: TorchAudio Pipeline Backend (MMS_FA)")
print("=" * 60)

try:
    from labeling_utils import TorchAudioPipelineBackend, BackendConfig

    # Create config for MMS_FA
    config = BackendConfig(
        model_name="MMS_FA",
        device="cuda" if torch.cuda.is_available() else "cpu",
        with_star=True,
    )

    # Unlike load_model(), the backend is constructed first and then loaded explicitly.
    ta_backend = TorchAudioPipelineBackend(config)
    ta_backend.load()

    print(f"Model loaded: {ta_backend}")
    print(f"Is loaded: {ta_backend.is_loaded}")
    print(f"Frame duration: {ta_backend.frame_duration}s")
    print(f"Sample rate: {ta_backend.sample_rate}Hz")

    # Get vocab info
    vocab = ta_backend.get_vocab_info()
    print("\nVocabulary:")
    print(f"  Size: {len(vocab.labels)}")
    print(f"  Labels: {vocab.labels[:15]}...")
    print(f"  Blank ID: {vocab.blank_id} ('{vocab.blank_token}')")
    print(f"  UNK ID: {vocab.unk_id} ('{vocab.unk_token}')")

    # Test emission extraction on 2 seconds of seeded noise at the backend's own
    # sample rate, rather than a hard-coded 16 kHz.
    ta_sample_rate = int(ta_backend.sample_rate)
    gen = torch.Generator().manual_seed(0)
    test_wav = torch.randn(2 * ta_sample_rate, generator=gen)
    result = get_emissions(ta_backend, test_wav, sample_rate=ta_sample_rate)

    print("\nEmission result:")
    print(f"  Emissions shape: {result.emissions.shape}")
    print(f"  Num frames: {result.num_frames}")
    print(f"  Vocab size: {result.vocab_size}")

    assert ta_backend.is_loaded, "TorchAudioPipelineBackend.load() did not mark the backend as loaded"
    assert result.emissions.dim() == 2, f"Expected 2D tensor, got {result.emissions.dim()}D"
    assert result.emissions.shape[-1] == result.vocab_size

    print("\nTest 6 PASSED - TorchAudio Pipeline backend works")
except Exception as e:
    print(f"\nTest 6 FAILED: {e}")
    print("Note: This test requires torchaudio with MMS_FA pipeline.")
    print("If MMS_FA is not available, this is expected.")
    import traceback
    traceback.print_exc()

## Test 7: Integration with Audio Frontend

In [None]:
print("=" * 60)
print("Test 7: Integration with Audio Frontend")
print("=" * 60)

# TODO: exercise the audio_frontend -> labeling_utils path end to end.
# Until that exists, this cell only reports whether the module can be found, so the
# notebook does not present the summary text as if an integration test had run.
import importlib.util

if importlib.util.find_spec("audio_frontend") is None:
    print("\nTest 7 SKIPPED - audio_frontend module not found on sys.path")
else:
    print("\nTest 7 SKIPPED - audio_frontend is importable, but integration checks are not implemented yet")

## Summary

In [None]:
print("=" * 60)
print("LABELING UTILS TEST SUMMARY")
print("=" * 60)
# Kept in sync with the tests above: Test 6 exercises TorchAudioPipelineBackend,
# so HuggingFace is not the only backend.
print("""
The labeling_utils module provides:

1. load_model() - Load CTC models from HuggingFace or TorchAudio
   - Supports MMS (1100+ languages)
   - Supports Wav2Vec2 variants
   - Automatic language adapter loading

2. get_emissions() - Extract frame-wise log posteriors
   - Returns EmissionResult with metadata
   - Automatic resampling if needed

3. get_emissions_batched() - Efficient batch processing
   - Process multiple audio files at once

4. Extensible backend system:
   - HuggingFaceCTCBackend: For HuggingFace models (facebook/mms-1b-all, etc.)
   - TorchAudioPipelineBackend: For TorchAudio pipelines (MMS_FA, etc.)
   - Easy to add NeMo, ESPnet, etc. via register_backend()

Next steps:
- Use emissions with k2 WFST for alignment
- Add more backends (NeMo, ESPnet, OmniASR)
- Implement the audio_frontend integration test (Test 7)
""")