# Test Notebook: Labeling Utils

This notebook tests the `labeling_utils` module for extracting frame-wise posteriors from CTC models.

**Features tested:**
1. Model loading (HuggingFace MMS, Wav2Vec2)
2. Emission extraction (single audio)
3. Batched emission extraction
4. Vocabulary information
5. TorchAudio Pipeline Backend (MMS_FA)
6. Integration with audio_frontend
7. **NeMo Backend** (FastConformer hybrid RNN-T/CTC)
8. **OmniASR Backend** (1600+ languages)

**Architecture:**
The module uses a plugin-style backend system:
- `labeling_utils.base`: Core abstractions (CTCModelBackend, VocabInfo, BackendConfig)
- `labeling_utils.registry`: Backend registration and discovery
- `labeling_utils.backends/`: Individual backend implementations
  - `huggingface.py`: HuggingFace Transformers (MMS, Wav2Vec2)
  - `torchaudio_backend.py`: TorchAudio pipelines (MMS_FA)
  - `nemo_backend.py`: NVIDIA NeMo (Conformer-CTC, FastConformer)
  - `omniasr_backend.py`: Facebook OmniASR (1600+ languages)

## Setup

In [None]:
# Install dependencies
# The commented-out line is the same install command with pip's -q (quiet) flag;
# swap the two lines if the full pip log is not wanted.
# !pip install -q transformers torch torchaudio torchcodec soundfile
!pip install transformers torch torchaudio torchcodec soundfile

In [None]:
# =============================================================================
# Setup: Clone Repository and Configure Imports
# =============================================================================

import sys
import os
from pathlib import Path

# ===== CONFIGURATION =====
# Git URL that setup_imports() clones when running on Google Colab.
GITHUB_REPO = "https://github.com/huangruizhe/torchaudio_aligner.git"
# Branch that setup_imports() clones / checks out on Colab.
BRANCH = "dev"  # Use 'dev' for testing, 'main' for stable
# =========================

# Test result tracking: maps a test name (e.g. "Test 1") to its status string.
# Each test cell writes its own entry; the summary cell at the end reads them.
test_results = {}

def setup_imports():
    """Put the repository's ``src`` directory on ``sys.path`` and return it.

    On Google Colab (detected via ``'google.colab' in sys.modules``) the
    repository is cloned to ``/content/torchaudio_aligner`` on branch
    ``BRANCH``, or fetched/checked out/pulled if it is already there.
    Elsewhere, ``../src`` and ``./src`` (relative to the current working
    directory) are probed for a ``labeling_utils`` sub-directory.

    Returns:
        The ``src`` directory path as a string.

    Raises:
        RuntimeError: On Colab, if the initial ``git clone`` fails.
        FileNotFoundError: Locally, if no candidate ``src`` directory
            contains ``labeling_utils``.
    """
    import subprocess

    def run_git(*args, cwd=None):
        """Run one git command without a shell; return True on success."""
        try:
            return subprocess.run(["git", *args], cwd=cwd, check=False).returncode == 0
        except OSError:
            # git executable missing or not runnable
            return False

    IN_COLAB = 'google.colab' in sys.modules

    if IN_COLAB:
        repo_path = '/content/torchaudio_aligner'
        src_path = f'{repo_path}/src'

        if not os.path.exists(repo_path):
            print(f"Cloning repository (branch: {BRANCH})...")
            # Without a checkout nothing below can work, so fail loudly here
            # instead of reporting success and failing later at import time.
            if not run_git("clone", "-b", BRANCH, GITHUB_REPO, repo_path):
                raise RuntimeError(
                    f"git clone of {GITHUB_REPO} (branch: {BRANCH}) failed"
                )
            print("Repository cloned")
        else:
            print(f"Updating repository (branch: {BRANCH})...")
            # Same fetch -> checkout -> pull chain as before; all() stops at the
            # first failing step, mirroring the shell's `&&`.
            update_steps = (
                ("fetch", "origin"),
                ("checkout", BRANCH),
                ("pull", "origin", BRANCH),
            )
            if all(run_git(*step, cwd=repo_path) for step in update_steps):
                print("Repository updated")
            else:
                # Updating is best-effort: keep going with the existing checkout.
                print("WARNING: repository update failed; using existing checkout")
    else:
        here = Path(".").absolute()
        possible_paths = [here.parent / "src", here / "src"]

        src_path = next(
            (
                str(p.absolute())
                for p in possible_paths
                if p.exists() and (p / "labeling_utils").exists()
            ),
            None,
        )

        if src_path is None:
            searched = ", ".join(str(p) for p in possible_paths)
            raise FileNotFoundError(f"src directory not found (searched: {searched})")

        print(f"Running locally from: {src_path}")

    if src_path not in sys.path:
        sys.path.insert(0, src_path)

    return src_path

# Resolve the source tree first so the package import below can succeed.
src_path = setup_imports()

# Public names of the modular labeling_utils package used by this notebook:
# the high-level API, model/preset configuration, the backend registry and
# the core classes.
from labeling_utils import (
    BackendConfig,
    CTCModelBackend,
    EmissionResult,
    ModelConfig,
    VocabInfo,
    get_backend,
    get_emissions,
    get_emissions_batched,
    get_model_info,
    is_backend_available,
    list_backends,
    list_presets,
    load_model,
)

import logging
import torch

logging.basicConfig(level=logging.INFO)

# Banner confirming the import, followed by an overview of the environment.
rule = "=" * 60
print()
print(rule)
print("‚úÖ Labeling Utils imported successfully!")
print(rule)
print(f"Available backends: {list_backends()}")
if torch.cuda.is_available():
    print("Device: cuda")
else:
    print("Device: cpu")
print()
print("Available presets:")
# One line per preset with its model name, backend and language count.
for preset in list_presets():
    info = get_model_info(preset)
    print(
        f"  ‚Ä¢ {preset}: {info['model_name']} "
        f"({info['backend']}, {info['languages']} languages)"
    )

## Test 1: Load MMS Model

In [None]:
rule = "=" * 60
print(rule)
print("Test 1: Load MMS Model (HuggingFace Backend)")
print(rule)

# Loads the English MMS checkpoint and prints its basic properties and
# vocabulary. `backend` stays defined for the later test cells.
try:
    backend = load_model(
        "facebook/mms-1b-all",
        device=("cuda" if torch.cuda.is_available() else "cpu"),
        language="eng",
    )

    print(f"Model loaded: {backend}")
    print(f"  ‚Ä¢ Is loaded: {backend.is_loaded}")
    print(f"  ‚Ä¢ Frame duration: {backend.frame_duration}s")
    print(f"  ‚Ä¢ Sample rate: {backend.sample_rate}Hz")

    # Vocabulary overview: size, special tokens and the first ten labels.
    vocab = backend.get_vocab_info()
    print("\nVocabulary:")
    print(f"  ‚Ä¢ Size: {len(vocab.labels)}")
    print(f"  ‚Ä¢ Blank ID: {vocab.blank_id} ('{vocab.blank_token}')")
    print(f"  ‚Ä¢ UNK ID: {vocab.unk_id} ('{vocab.unk_token}')")
    print(f"  ‚Ä¢ Sample labels: {vocab.labels[:10]}...")

    test_results["Test 1"] = "‚úÖ PASSED"
    print("\n‚úÖ Test 1 PASSED - MMS model loaded successfully")
except Exception as err:
    # Record the failure and show the full stack for debugging.
    test_results["Test 1"] = "‚ùå FAILED"
    print(f"\n‚ùå Test 1 FAILED: {err}")
    import traceback
    traceback.print_exc()

## Test 2: Extract Emissions from VOiCES Sample Audio

Using the Lab41 VOiCES dataset sample: *"I had that curiosity beside me at this moment"*

In [None]:
print("=" * 60)
print("Test 2: Extract Emissions from VOiCES Sample Audio")
print("=" * 60)

# Locates the VOiCES sample (downloading it when it is not on disk), resamples
# it to 16 kHz if needed and runs it through the backend loaded in Test 1.
# `waveform`, `sample_rate`, `result` and `SAMPLE_TEXT` stay defined for the
# cells that follow.
try:
    from audio_frontend import load_audio, resample
    import urllib.request
    import os

    # VOiCES sample audio from repository
    # Transcript: "I had that curiosity beside me at this moment"
    SAMPLE_AUDIO = "Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
    SAMPLE_TEXT = "I had that curiosity beside me at this moment"

    # Determine path based on environment
    IN_COLAB = 'google.colab' in sys.modules
    if IN_COLAB:
        sample_path = f"/content/torchaudio_aligner/examples/{SAMPLE_AUDIO}"
    else:
        sample_path = str(Path(src_path).parent / "examples" / SAMPLE_AUDIO)

    if not os.path.exists(sample_path):
        # Download from GitHub if not found locally. The branch follows the
        # notebook's BRANCH setting so the sample matches the code under test
        # (it used to be pinned to "dev" regardless of BRANCH).
        url = (
            "https://raw.githubusercontent.com/huangruizhe/torchaudio_aligner/"
            f"{BRANCH}/examples/{SAMPLE_AUDIO}"
        )
        print("üì• Downloading sample audio...")
        urllib.request.urlretrieve(url, SAMPLE_AUDIO)
        sample_path = SAMPLE_AUDIO
        print(f"   Downloaded: {sample_path}")

    # Load audio using our own API
    waveform, sample_rate = load_audio(sample_path)
    print(f"üéµ Loaded: {sample_path}")
    print(f"   Transcript: \"{SAMPLE_TEXT}\"")
    print(f"   ‚Ä¢ Waveform shape: {waveform.shape}")
    print(f"   ‚Ä¢ Sample rate: {sample_rate}Hz")
    # shape[1] is used as the number of samples - assumes a (channels, samples)
    # layout from load_audio; TODO confirm
    print(f"   ‚Ä¢ Duration: {waveform.shape[1] / sample_rate:.2f}s")

    # Resample if needed using our own API
    if sample_rate != 16000:
        waveform = resample(waveform, sample_rate, 16000)
        sample_rate = 16000
        print(f"   ‚Ä¢ Resampled to: {sample_rate}Hz")

    # Extract emissions
    result = get_emissions(backend, waveform, sample_rate=sample_rate)

    print("\nüìä Emission result:")
    print(f"   ‚Ä¢ Emissions shape: {result.emissions.shape}")
    print(f"   ‚Ä¢ Num frames: {result.num_frames}")
    print(f"   ‚Ä¢ Vocab size: {result.vocab_size}")
    print(f"   ‚Ä¢ Duration: {result.duration:.2f}s")

    # Sanity check on frame 0: exp() of log-probabilities should sum to ~1.
    # This is only printed for inspection; it is not asserted.
    probs = torch.exp(result.emissions[0])
    prob_sum = probs.sum().item()
    print(f"   ‚Ä¢ Prob sum at frame 0: {prob_sum:.4f} (should be ~1.0)")

    # Show top predictions for a few frames (fewer if the clip is very short)
    print("\nüî§ Top predictions (frames 10-15):")
    vocab = result.vocab_info
    for i in range(10, min(15, result.num_frames)):
        top_idx = result.emissions[i].argmax().item()
        top_prob = torch.exp(result.emissions[i, top_idx]).item()
        label = vocab.id_to_label.get(top_idx, "?")
        print(f"   Frame {i}: '{label}' (prob={top_prob:.3f})")

    test_results["Test 2"] = "‚úÖ PASSED"
    print("\n‚úÖ Test 2 PASSED - Emissions extracted from VOiCES sample")
except Exception as e:
    test_results["Test 2"] = "‚ùå FAILED"
    print(f"\n‚ùå Test 2 FAILED: {e}")
    import traceback
    traceback.print_exc()

## Test 3: Greedy Decoding from Emissions

Decode the emissions to verify the model recognizes the speech content.

In [None]:
rule = "=" * 60
print(rule)
print("Test 3: Greedy Decoding from Emissions")
print(rule)

# Decodes the Test 2 emissions with the backend's own greedy decoder and
# compares the word sets of the hypothesis and the reference transcript.
try:
    decoded = backend.greedy_decode(result.emissions)

    print(f"üìù Ground truth: \"{SAMPLE_TEXT}\"")
    print(f"üîä Decoded:      \"{decoded}\"")

    # Case-insensitive comparison with apostrophes removed on both sides.
    gt_words = set(SAMPLE_TEXT.lower().replace("'", "").split())
    decoded_words = set(decoded.lower().replace("'", "").split())

    # Order-insensitive overlap: how many reference words were recovered.
    overlap = len(gt_words.intersection(decoded_words))
    total = len(gt_words)

    print(f"\nüìà Word overlap: {overlap}/{total} ({100*overlap/total:.0f}%)")

    # Fewer than half of the reference words recovered counts as a warning.
    if overlap < total // 2:
        test_results["Test 3"] = "‚ö†Ô∏è WARNING"
        print("\n‚ö†Ô∏è Test 3 WARNING - Low word overlap (model may need tuning)")
    else:
        test_results["Test 3"] = "‚úÖ PASSED"
        print("\n‚úÖ Test 3 PASSED - Greedy decoding produces reasonable output")

except Exception as err:
    test_results["Test 3"] = "‚ùå FAILED"
    print(f"\n‚ùå Test 3 FAILED: {err}")
    import traceback
    traceback.print_exc()

## Test 4: Batched Emission Extraction

Test batch processing using segments from the VOiCES sample.

In [None]:
rule = "=" * 60
print(rule)
print("Test 4: Batched Emission Extraction")
print(rule)

# Cuts the VOiCES waveform into three slices of different length and pushes
# them through get_emissions_batched with a batch size of 2.
try:
    full_wav = waveform.squeeze(0)  # drop the leading (channel) dimension
    total_samples = full_wav.shape[0]

    one_third = total_samples // 3
    one_quarter = total_samples // 4
    two_thirds = (2 * total_samples) // 3

    waveforms = [
        full_wav[:one_third],            # first third
        full_wav[one_quarter:],          # last three quarters
        full_wav[one_third:two_thirds],  # middle third
    ]

    print(f"üì¶ Input: {len(waveforms)} waveforms from VOiCES sample")
    for idx, wav in enumerate(waveforms):
        print(f"   [{idx}] shape={wav.shape}, duration={len(wav)/16000:.2f}s")

    # Run the three slices through the batched API.
    results = get_emissions_batched(backend, waveforms, sample_rate=16000, batch_size=2)

    print(f"\nüìä Output: {len(results)} EmissionResults")
    for idx, res in enumerate(results):
        decoded = backend.greedy_decode(res.emissions)
        # Show at most 50 characters, with an ellipsis only when truncated.
        ellipsis = '...' if len(decoded) > 50 else ''
        print(f"   [{idx}] frames={res.num_frames}, duration={res.duration:.2f}s")
        print(f"       decoded: \"{decoded[:50]}{ellipsis}\"")

    # One EmissionResult is expected per input waveform.
    assert len(results) == len(waveforms), "Output count mismatch"

    test_results["Test 4"] = "‚úÖ PASSED"
    print("\n‚úÖ Test 4 PASSED - Batched extraction works")
except Exception as err:
    test_results["Test 4"] = "‚ùå FAILED"
    print(f"\n‚ùå Test 4 FAILED: {err}")
    import traceback
    traceback.print_exc()

## Test 5: Different Languages (MMS Multilingual)

In [None]:
import gc
import traceback

print("=" * 60)
print("Test 5: Load MMS for Different Languages")
print("=" * 60)

# Loads the MMS checkpoint once per language and runs a quick emission
# extraction on the VOiCES waveform. Only "load + extract works" is checked;
# the audio is English, so the decoded content is not evaluated.

# Test a few languages
languages = [
    ("fra", "French"),
    ("deu", "German"),
    ("jpn", "Japanese"),
]

test5_passed = 0
test5_total = len(languages)

for lang_code, lang_name in languages:
    print(f"\nüåç Loading MMS for {lang_name} ({lang_code})...")
    lang_backend = None
    try:
        lang_backend = load_model(
            "facebook/mms-1b-all",
            language=lang_code,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

        vocab = lang_backend.get_vocab_info()
        print(f"   ‚Ä¢ Vocab size: {len(vocab.labels)}")

        # Quick emission test with VOiCES sample
        test_result = get_emissions(lang_backend, waveform, sample_rate=16000)
        print(f"   ‚Ä¢ Emissions shape: {test_result.emissions.shape}")
        print(f"   ‚úÖ {lang_name} PASSED")
        test5_passed += 1

    except Exception as e:
        print(f"   ‚ùå {lang_name} FAILED: {e}")
        # Show the stack like every other test cell does.
        traceback.print_exc()
    finally:
        # Drop this language's model before the next one is loaded so that two
        # large checkpoints are not held at once, and so the last one does not
        # stay resident for the remaining tests.
        lang_backend = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Overall status: all languages, some of them, or none.
if test5_passed == test5_total:
    test_results["Test 5"] = "‚úÖ PASSED"
    print(f"\n‚úÖ Test 5 PASSED - All {test5_total} languages loaded successfully")
elif test5_passed > 0:
    test_results["Test 5"] = f"‚ö†Ô∏è PARTIAL ({test5_passed}/{test5_total})"
    print(f"\n‚ö†Ô∏è Test 5 PARTIAL - {test5_passed}/{test5_total} languages loaded")
else:
    test_results["Test 5"] = "‚ùå FAILED"
    print("\n‚ùå Test 5 FAILED - No languages loaded")

## Test 6: TorchAudio Pipeline Backend (MMS_FA)

Note: This test uses the TorchAudio pipeline API which has a different interface than HuggingFace.

In [None]:
rule = "=" * 60
print(rule)
print("Test 6: TorchAudio Pipeline Backend (MMS_FA)")
print(rule)

# Exercises the "mms-fa" preset end to end (load, vocabulary, emissions,
# greedy decode). The test is recorded as skipped when the torchaudio backend
# reports itself unavailable.
try:
    if is_backend_available("torchaudio"):
        # The preset name alone selects the TorchAudio backend.
        ta_backend = load_model("mms-fa")

        print(f"üîß Model loaded: {ta_backend}")
        print(f"   ‚Ä¢ Is loaded: {ta_backend.is_loaded}")
        print(f"   ‚Ä¢ Frame duration: {ta_backend.frame_duration}s")
        print(f"   ‚Ä¢ Sample rate: {ta_backend.sample_rate}Hz")

        # Full label list is printed here (not just a sample).
        vocab = ta_backend.get_vocab_info()
        print("\nüìö Vocabulary:")
        print(f"   ‚Ä¢ Size: {len(vocab.labels)}")
        print(f"   ‚Ä¢ Labels: {vocab.labels}")
        print(f"   ‚Ä¢ Blank ID: {vocab.blank_id} ('{vocab.blank_token}')")
        print(f"   ‚Ä¢ UNK ID: {vocab.unk_id} ('{vocab.unk_token}')")

        # Emissions for the VOiCES waveform prepared in Test 2.
        print("\nüéµ Testing with VOiCES sample:")
        print(f"   Transcript: \"{SAMPLE_TEXT}\"")

        ta_result = get_emissions(ta_backend, waveform, sample_rate=16000)

        print("\nüìä Emission result:")
        print(f"   ‚Ä¢ Emissions shape: {ta_result.emissions.shape}")
        print(f"   ‚Ä¢ Num frames: {ta_result.num_frames}")
        print(f"   ‚Ä¢ Vocab size: {ta_result.vocab_size}")

        # Decode with the backend's own greedy decoder (romanized output).
        ta_decoded = ta_backend.greedy_decode(ta_result.emissions)
        print(f"\nüî§ Decoded (romanized): \"{ta_decoded}\"")

        test_results["Test 6"] = "‚úÖ PASSED"
        print("\n‚úÖ Test 6 PASSED - TorchAudio Pipeline backend works")
    else:
        test_results["Test 6"] = "‚è≠Ô∏è SKIPPED"
        print("‚è≠Ô∏è TorchAudio backend not available (missing dependencies)")
        print("   Skipping test...")
except Exception as err:
    test_results["Test 6"] = "‚ùå FAILED"
    print(f"\n‚ùå Test 6 FAILED: {err}")
    print("   Note: This test requires torchaudio with MMS_FA pipeline.")
    import traceback
    traceback.print_exc()

## Test 7: Integration with Audio Frontend

In [None]:
print("=" * 60)
print("Test 7: Integration with Audio Frontend")
print("=" * 60)

# Splits the VOiCES waveform into overlapping segments with the audio
# frontend and extracts emissions for every segment with the Test 1 backend.
# The test is recorded as skipped when audio_frontend cannot be imported.
try:
    from audio_frontend import segment_waveform

    # Use VOiCES sample - segment into smaller chunks
    print(f"üéµ Original audio: {waveform.shape[1]/16000:.2f}s")

    # Segment into overlapping chunks using segment_waveform (works with tensors)
    seg_result = segment_waveform(
        waveform.squeeze(0),  # 1D tensor
        sample_rate=16000,
        segment_size=1.5,  # 1.5 second segments
        overlap=0.3,
    )
    print(f"‚úÇÔ∏è Segmented into {len(seg_result.segments)} segments")

    # Extract emissions for each segment
    all_emissions = []
    for i, segment in enumerate(seg_result.segments):
        # Pass the sample rate explicitly, as every other get_emissions call in
        # this notebook does; the segments were cut at 16 kHz just above.
        seg_emission = get_emissions(backend, segment.waveform, sample_rate=16000)
        all_emissions.append(seg_emission)
        decoded = backend.greedy_decode(seg_emission.emissions)
        print(f"   Segment {i}: frames={seg_emission.num_frames}, decoded=\"{decoded}\"")

    print(f"\nüìä Total emissions extracted: {len(all_emissions)}")
    # Segments overlap, so this total counts the overlapping frames twice.
    print(f"   Total frames: {sum(e.num_frames for e in all_emissions)}")

    test_results["Test 7"] = "‚úÖ PASSED"
    print("\n‚úÖ Test 7 PASSED - Audio frontend + labeling utils integration works")
except ImportError:
    test_results["Test 7"] = "‚è≠Ô∏è SKIPPED"
    print("‚è≠Ô∏è audio_frontend not available - skipping integration test")
    print("   This is expected if running labeling_utils tests only")
except Exception as e:
    test_results["Test 7"] = "‚ùå FAILED"
    print(f"\n‚ùå Test 7 FAILED: {e}")
    import traceback
    traceback.print_exc()

## Test 8: NeMo Backend (FastConformer Hybrid RNN-T/CTC)

Test the NeMo backend using the same model as in the tutorial: `nvidia/stt_en_fastconformer_hybrid_large_pc`

This is a hybrid RNN-T/CTC model - we use the CTC head for emission extraction.

In [None]:
# Optional: NeMo backend (heavy install, ~5-10 min)
# ! pip install nemo_toolkit[asr]
# ! pip install nemo_toolkit[all]

In [None]:
print("=" * 60)
print("Test 8: NeMo Backend (FastConformer Hybrid RNN-T/CTC)")
print("=" * 60)

# Exercises the "nemo-fastconformer" preset end to end and compares the greedy
# transcript with the reference by word overlap. The test is recorded as
# skipped when the NeMo backend reports itself unavailable.
try:
    # Check if nemo backend is available
    if not is_backend_available("nemo"):
        test_results["Test 8"] = "‚è≠Ô∏è SKIPPED"
        print("‚è≠Ô∏è NeMo backend not available (nemo_toolkit not installed)")
        print("   Install with: pip install nemo_toolkit[asr]")
        print("   Skipping test...")
    else:
        import string

        # Load FastConformer Hybrid model (same as in tutorial)
        # This is the model used in nemo_forced_aligner_tutorial.py
        print("üîß Loading NeMo FastConformer Hybrid model...")
        print("   Model: nvidia/stt_en_fastconformer_hybrid_large_pc")

        nemo_backend = load_model(
            "nemo-fastconformer",  # Uses nvidia/stt_en_fastconformer_hybrid_large_pc
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

        print(f"\nüì¶ Model loaded: {nemo_backend}")
        print(f"   ‚Ä¢ Is loaded: {nemo_backend.is_loaded}")
        print(f"   ‚Ä¢ Frame duration: {nemo_backend.frame_duration}s")
        print(f"   ‚Ä¢ Sample rate: {nemo_backend.sample_rate}Hz")

        # Get vocab info
        nemo_vocab = nemo_backend.get_vocab_info()
        print("\nüìö Vocabulary (BPE):")
        print(f"   ‚Ä¢ Size: {len(nemo_vocab.labels)}")
        print(f"   ‚Ä¢ Blank ID: {nemo_vocab.blank_id} ('{nemo_vocab.blank_token}')")
        print(f"   ‚Ä¢ Sample tokens: {nemo_vocab.labels[1:11]}...")  # Skip blank

        # Test with VOiCES sample
        print("\nüéµ Testing with VOiCES sample:")
        print(f"   Transcript: \"{SAMPLE_TEXT}\"")

        nemo_result = get_emissions(nemo_backend, waveform, sample_rate=16000)

        print("\nüìä Emission result:")
        print(f"   ‚Ä¢ Emissions shape: {nemo_result.emissions.shape}")
        print(f"   ‚Ä¢ Num frames: {nemo_result.num_frames}")
        print(f"   ‚Ä¢ Vocab size: {nemo_result.vocab_size}")

        # Greedy decode using backend's tokenizer
        nemo_decoded = nemo_backend.greedy_decode(nemo_result.emissions)

        print("\nüî§ Greedy decoding:")
        print(f"   üìù Ground truth: \"{SAMPLE_TEXT}\"")
        print(f"   üîä Decoded:      \"{nemo_decoded}\"")

        # Check word overlap. The "_pc" model name suggests punctuated,
        # capitalised output, so all ASCII punctuation (apostrophes included)
        # is stripped from both sides; otherwise a token such as "moment."
        # would never match the reference word "moment".
        strip_punct = str.maketrans("", "", string.punctuation)
        gt_normalized = SAMPLE_TEXT.lower().translate(strip_punct)
        decoded_normalized = nemo_decoded.lower().translate(strip_punct)

        gt_words = set(gt_normalized.split())
        decoded_words = set(decoded_normalized.split())
        overlap = len(gt_words & decoded_words)
        total = len(gt_words)

        print(f"\nüìà Word overlap: {overlap}/{total} ({100*overlap/total:.0f}%)")

        # At least half of the reference words recovered counts as a pass.
        if overlap >= total // 2:
            test_results["Test 8"] = "‚úÖ PASSED"
            print("\n‚úÖ Test 8 PASSED - NeMo backend works with reasonable decoding")
        else:
            test_results["Test 8"] = "‚ö†Ô∏è WARNING"
            print("\n‚ö†Ô∏è Test 8 WARNING - Low word overlap (but backend works)")

except Exception as e:
    test_results["Test 8"] = "‚ùå FAILED"
    print(f"\n‚ùå Test 8 FAILED: {e}")
    print("   Note: This test requires nemo_toolkit[asr]")
    import traceback
    traceback.print_exc()

## Test 9: OmniASR Backend (1600+ Languages)

Test the OmniASR backend from Facebook/Meta's Omnilingual ASR project.

Note: This requires the `omnilingual-asr` package: `pip install omnilingual-asr`

In [None]:
# OmniASR installation helper (Google Colab only).
# Credit: https://github.com/NeuralFalconYT/omnilingual-asr-colab
#
# WARNING: omnilingual-asr pins specific PyTorch/fairseq2 versions that may
# conflict with other packages, so run this in a fresh environment.

IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    print("OmniASR installation instructions are for Colab only.")
    print("For local installation, see: https://github.com/facebookresearch/omnilingual-asr")
else:
    # Uncomment the lines below to install OmniASR. The version specifiers
    # containing ">=" are quoted so the shell does not treat ">" as an output
    # redirection.
    # !pip uninstall -y torch torchaudio
    # !pip install torch==2.8.0+cu128 torchaudio==2.8.0+cu128 torchvision==0.23.0+cu128 --index-url https://download.pytorch.org/whl/cu128
    # !pip install fairseq2==0.6
    # !pip install omnilingual-asr==0.1.0
    # !pip install "silero-vad>=4.0.0" "onnxruntime>=1.12.0" uroman==1.3.1.1
    # !pip uninstall fairseq2 -y
    # !pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.8.0/cu126
    # !pip install omnilingual-asr
    pass

In [None]:
print("=" * 60)
print("Test 9: OmniASR Backend (1600+ Languages)")
print("=" * 60)

# Exercises the "omniasr-300m" preset: single-utterance emissions with a
# word-overlap check against the reference, plus a small batched run. The
# test is recorded as skipped when the OmniASR backend reports itself
# unavailable.
try:
    # Check if omniasr backend is available
    if not is_backend_available("omniasr"):
        test_results["Test 9"] = "‚è≠Ô∏è SKIPPED"
        print("‚è≠Ô∏è OmniASR backend not available (omnilingual-asr not installed)")
        print("   Install with the cell above (Colab) or see:")
        print("   https://github.com/facebookresearch/omnilingual-asr")
        print("   Skipping test...")
    else:
        # Load OmniASR CTC model (300M is fastest for testing)
        print("üîß Loading OmniASR CTC model...")
        print("   Model: omniASR_CTC_300M (325M parameters)")

        omni_backend = load_model(
            "omniasr-300m",  # Fastest model for testing
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

        print(f"\nüì¶ Model loaded: {omni_backend}")
        print(f"   ‚Ä¢ Is loaded: {omni_backend.is_loaded}")
        print(f"   ‚Ä¢ Frame duration: {omni_backend.frame_duration}s")
        print(f"   ‚Ä¢ Sample rate: {omni_backend.sample_rate}Hz")

        # Get vocab info
        # NOTE(review): an earlier comment expected the full 9812 tokens here;
        # that number is not checked by this cell - confirm against the model.
        omni_vocab = omni_backend.get_vocab_info()
        print("\nüìö Vocabulary (Character-level SentencePiece):")
        print(f"   ‚Ä¢ Size: {len(omni_vocab.labels)}")
        print(f"   ‚Ä¢ Blank ID: {omni_vocab.blank_id} ('{omni_vocab.blank_token}')")
        print(f"   ‚Ä¢ First 20 tokens: {omni_vocab.labels[:20]}")

        # Test with VOiCES sample
        print("\nüéµ Testing with VOiCES sample:")
        print(f"   Transcript: \"{SAMPLE_TEXT}\"")

        omni_result = get_emissions(omni_backend, waveform, sample_rate=16000)

        print("\nüìä Emission result:")
        print(f"   ‚Ä¢ Emissions shape: {omni_result.emissions.shape}")
        print(f"   ‚Ä¢ Num frames: {omni_result.num_frames}")
        print(f"   ‚Ä¢ Vocab size: {omni_result.vocab_size}")

        # Greedy decode using backend's tokenizer
        omni_decoded = omni_backend.greedy_decode(omni_result.emissions)

        print("\nüî§ Greedy decoding:")
        print(f"   üìù Ground truth: \"{SAMPLE_TEXT}\"")
        print(f"   üîä Decoded:      \"{omni_decoded}\"")

        # Check word overlap (case-insensitive, apostrophes removed)
        gt_normalized = SAMPLE_TEXT.lower().replace("'", "")
        decoded_normalized = omni_decoded.lower().replace("'", "")

        gt_words = set(gt_normalized.split())
        decoded_words = set(decoded_normalized.split())
        overlap = len(gt_words & decoded_words)
        total = len(gt_words)

        print(f"\nüìà Word overlap: {overlap}/{total} ({100*overlap/total:.0f}%)")

        # Test batched inference on two prefixes of the sample
        print("\nüîÑ Testing batched inference...")
        test_waveforms = [
            waveform.squeeze(0)[:16000],  # 1 second
            waveform.squeeze(0)[:32000],  # 2 seconds
        ]
        batch_results = get_emissions_batched(omni_backend, test_waveforms, sample_rate=16000)
        print(f"   ‚Ä¢ Batch size: {len(batch_results)}")
        for i, res in enumerate(batch_results):
            decoded = omni_backend.greedy_decode(res.emissions)
            # Append the ellipsis only when the preview is really truncated
            # (it used to be printed unconditionally).
            suffix = "..." if len(decoded) > 40 else ""
            print(f"   ‚Ä¢ [{i}] frames={res.num_frames}, decoded=\"{decoded[:40]}{suffix}\"")

        # At least half of the reference words recovered counts as a pass.
        if overlap >= total // 2:
            test_results["Test 9"] = "‚úÖ PASSED"
            print("\n‚úÖ Test 9 PASSED - OmniASR backend works with reasonable decoding")
        else:
            test_results["Test 9"] = "‚ö†Ô∏è WARNING"
            print("\n‚ö†Ô∏è Test 9 WARNING - Low word overlap (but backend works)")

except Exception as e:
    test_results["Test 9"] = "‚ùå FAILED"
    print(f"\n‚ùå Test 9 FAILED: {e}")
    print("   Note: This test requires omnilingual-asr package")
    import traceback
    traceback.print_exc()

## 📋 Test Summary

In [None]:
print("=" * 60)
print("üìã TEST RESULTS SUMMARY")
print("=" * 60)

# Prints every recorded result, then tallies them by the status marker each
# test cell stored in `test_results`.

# Display test results
print("\n" + "-" * 40)
for test_name, result in test_results.items():
    print(f"  {result}  {test_name}")
print("-" * 40)

# Count results by status marker. A "PARTIAL" result carries the warning
# marker and is therefore counted as a warning.
passed = sum(1 for r in test_results.values() if "‚úÖ" in r)
failed = sum(1 for r in test_results.values() if "‚ùå" in r)
skipped = sum(1 for r in test_results.values() if "‚è≠Ô∏è" in r)
warning = sum(1 for r in test_results.values() if "‚ö†Ô∏è" in r)
total = len(test_results)

print(f"\n  Total: {total} tests")
print(f"  ‚úÖ Passed:  {passed}")
if warning > 0:
    print(f"  ‚ö†Ô∏è Warning: {warning}")
if skipped > 0:
    print(f"  ‚è≠Ô∏è Skipped: {skipped}")
if failed > 0:
    print(f"  ‚ùå Failed:  {failed}")

print("\n" + "=" * 60)
# Claim "all passed" only when every recorded test really passed; skipped or
# warning results, or an empty result table, get their own message.
if failed > 0:
    print(f"‚ö†Ô∏è {failed} test(s) failed - please check above for details")
elif total == 0:
    print("No test results recorded - run the test cells above first")
elif passed == total:
    print("üéâ All tests passed!")
else:
    print(f"No failures, but only {passed}/{total} test(s) fully passed - see warnings/skips above")
print("=" * 60)