# OmniASR Backend Testing

This notebook tests the OmniASR backend integration with `labeling_utils`.

**OmniASR** is Facebook/Meta's Omnilingual ASR project supporting 1600+ languages.

**Available models:**
- `omniASR_CTC_300M`: 325M parameters (fastest)
- `omniASR_CTC_1B`: 975M parameters
- `omniASR_CTC_3B`: 3.08B parameters
- `omniASR_CTC_7B`: 6.5B parameters

**References:**
- https://github.com/facebookresearch/omnilingual-asr
- https://github.com/NeuralFalconYT/omnilingual-asr-colab (source of the Colab installation steps used below)

## 1. Installation (Colab)

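OmniASR requires specific PyTorch and fairseq2 versions. Run this cell to install them; it only installs anything when running on Colab. Because it replaces the preinstalled PyTorch, restart the runtime afterwards if `torch` had already been imported in this session.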

In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Installing OmniASR dependencies...")
    
    # Uninstall existing PyTorch
    !pip uninstall -y torch torchaudio torchvision -q
    
    # Install PyTorch 2.8.0 with CUDA 12.8 (fairseq2 requirement)
    !pip install torch==2.8.0+cu128 torchaudio==2.8.0+cu128 torchvision==0.23.0+cu128 --index-url https://download.pytorch.org/whl/cu128 -q
    
    # Install fairseq2 and omnilingual-asr
    !pip install fairseq2==0.6 -q
    !pip install omnilingual-asr==0.1.0 -q
    
    # Additional dependencies
    !pip install silero-vad>=4.0.0 onnxruntime>=1.12.0 uroman==1.3.1.1 -q
    
    # Reinstall fairseq2 with correct CUDA variant
    !pip uninstall fairseq2 -y -q
    !pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.8.0/cu126 -q
    
    # Reinstall omnilingual-asr
    !pip install omnilingual-asr -q
    
    print("Installation complete!")
else:
    print("Not running in Colab. See https://github.com/facebookresearch/omnilingual-asr for installation.")

## 2. Verify Installation

In [None]:
# Check that the OmniASR inference pipeline is importable before doing anything else.
try:
    from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
except ImportError as e:
    # Most likely the installation cell above was skipped or failed.
    print(f"OmniASR not installed: {e}")
    print("Run the installation cell above.")
else:
    print("OmniASR is installed correctly!")

## 3. Setup Repository

In [None]:
import os
import subprocess
import sys
from pathlib import Path

GITHUB_REPO = "https://github.com/huangruizhe/torchaudio_aligner.git"
BRANCH = "dev"

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    repo_path = '/content/torchaudio_aligner'
    src_path = f'{repo_path}/src'

    if not os.path.exists(repo_path):
        print(f"Cloning repository (branch: {BRANCH})...")
        # check=True: without a checkout there is nothing to import, so stop here
        # instead of failing later with a confusing ImportError.
        subprocess.run(['git', 'clone', '-b', BRANCH, GITHUB_REPO, repo_path], check=True)
    else:
        print(f"Updating repository (branch: {BRANCH})...")
        # A failed pull (e.g. no network) is not fatal: the existing checkout still works.
        pull = subprocess.run(['git', 'pull', 'origin', BRANCH], cwd=repo_path)
        if pull.returncode != 0:
            print(f"Warning: git pull failed (exit code {pull.returncode}); using the existing checkout.")
else:
    # Local development: find a src/ directory containing labeling_utils,
    # either beside the current directory or inside it.
    possible_paths = [
        Path(".").absolute().parent / "src",
        Path(".").absolute() / "src",
    ]
    src_path = None
    for p in possible_paths:
        if p.exists() and (p / "labeling_utils").exists():
            src_path = str(p.absolute())
            break
    if src_path is None:
        searched = ", ".join(str(p) for p in possible_paths)
        raise FileNotFoundError(f"src directory not found (searched: {searched})")
    print(f"Running locally from: {src_path}")

# Make the project's src/ importable (labeling_utils, audio_frontend, alignment, ...).
if src_path not in sys.path:
    sys.path.insert(0, src_path)

print(f"Source path: {src_path}")

In [None]:
# Import labeling_utils and reload it, so edits under src/ are picked up without a kernel
# restart. (reload re-executes only the package's __init__; submodules already in
# sys.modules are reused as-is.)
import importlib

import labeling_utils
importlib.reload(labeling_utils)

from labeling_utils import (
    get_emissions,
    get_emissions_batched,
    is_backend_available,
    list_backends,
    list_presets,
    load_model,
)

import torch

# Environment summary: PyTorch build, GPU visibility, and which backends are usable.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Available backends: {list_backends()}")
print(f"OmniASR available: {is_backend_available('omniasr')}")
print(f"OmniASR available: {is_backend_available('omniasr')}")

## 4. Load OmniASR Model

In [None]:
# Start with the smallest CTC checkpoint (300M), the quickest one to test with.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Loading OmniASR CTC 300M on {device}...")
backend = load_model("omniasr-300m", device=device)

# Timing/rate properties reported by the backend.
print("\nModel loaded!")
print(f"  Frame duration: {backend.frame_duration}s")
print(f"  Sample rate: {backend.sample_rate}Hz")

## 5. Vocabulary Information

In [None]:
# Get vocabulary info
vocab = backend.get_vocab_info()

print(f"Vocabulary size: {len(vocab.labels)}")
print(f"Blank ID: {vocab.blank_id} ('{vocab.blank_token}')")
print(f"UNK ID: {vocab.unk_id} ('{vocab.unk_token}')")
print(f"\nFirst 50 tokens:")
print(vocab.labels[:50])

In [None]:
# Explore vocabulary - find specific characters
def find_tokens(pattern, vocab_labels, max_results=20):
    """Find vocabulary tokens containing ``pattern`` (case-insensitive substring match).

    Args:
        pattern: Substring to look for; compared after lower-casing both sides.
        vocab_labels: Iterable of token strings, indexed in iteration order.
        max_results: Maximum number of matches to return. ``None`` means no limit;
            zero or a negative value returns an empty list.

    Returns:
        A list of ``(index, token)`` pairs in vocabulary order.
    """
    if max_results is not None and max_results <= 0:
        return []
    needle = pattern.lower()  # loop-invariant
    results = []
    for i, token in enumerate(vocab_labels):
        if needle in token.lower():
            results.append((i, token))
            # Check the limit only after a match was added.
            if max_results is not None and len(results) >= max_results:
                break
    return results

# Locate a few familiar characters: the space token and single-character ASCII letters.
print("Space token:", [(idx, tok) for idx, tok in enumerate(vocab.labels) if tok == ' '])

print("\nLatin letters (a-z):")
# Single-character, ASCII, alphabetic tokens; note this also admits uppercase A-Z if present.
latin = [
    (idx, tok)
    for idx, tok in enumerate(vocab.labels)
    if len(tok) == 1 and ord(tok) < 128 and tok.isalpha()
]
print(latin[:26])

## 6. Test Emission Extraction

In [None]:
# Load the sample utterance and its reference transcript used throughout this notebook.
from audio_frontend import load_audio, resample
import urllib.request

SAMPLE_AUDIO = "Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SAMPLE_TEXT = "I had that curiosity beside me at this moment"

# Prefer the copy in the repository's examples/ directory.
if IN_COLAB:
    sample_path = f"/content/torchaudio_aligner/examples/{SAMPLE_AUDIO}"
else:
    sample_path = str(Path(src_path).parent / "examples" / SAMPLE_AUDIO)

# Otherwise fall back to a copy in the working directory, downloading it only when
# no local copy exists yet.
if not os.path.exists(sample_path):
    sample_path = SAMPLE_AUDIO
    if not os.path.exists(sample_path):
        url = f"https://raw.githubusercontent.com/huangruizhe/torchaudio_aligner/dev/examples/{SAMPLE_AUDIO}"
        print("Downloading sample audio...")
        # Download under a temporary name and move it into place afterwards, so an
        # interrupted download is never mistaken for a valid cached file.
        partial_path = f"{sample_path}.part"
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, sample_path)

waveform, sample_rate = load_audio(sample_path)
print(f"Loaded: {sample_path}")
print(f"  Shape: {waveform.shape}")
print(f"  Sample rate: {sample_rate}Hz")
# NOTE(review): assumes load_audio returns (channels, samples); later cells squeeze(0),
# i.e. expect a single channel — confirm.
print(f"  Duration: {waveform.shape[1] / sample_rate:.2f}s")
print(f"  Transcript: \"{SAMPLE_TEXT}\"")

# The rest of the notebook passes sample_rate=16000, so bring the audio to that rate.
if sample_rate != 16000:
    waveform = resample(waveform, sample_rate, 16000)
    sample_rate = 16000
    print(f"  Resampled to: {sample_rate}Hz")

In [None]:
# Extract frame-level log-probabilities for the whole utterance.
result = get_emissions(backend, waveform, sample_rate=16000)

print(f"Emissions shape: {result.emissions.shape}")
print(f"Num frames: {result.num_frames}")
print(f"Vocab size: {result.vocab_size}")
print(f"Duration: {result.duration:.2f}s")

# Emissions are log-probabilities, so exp() should sum to ~1 over the vocabulary.
probs = torch.exp(result.emissions[0])
print(f"\nProb sum at frame 0: {probs.sum().item():.4f} (should be ~1.0)")

# Check every frame, not just the first (upcast in case emissions are reduced precision).
frame_sums = torch.exp(result.emissions.float()).sum(dim=-1)
max_deviation = (frame_sums - 1).abs().max().item()
print(f"Max |prob sum - 1| over all frames: {max_deviation:.2e}")
assert max_deviation < 1e-2, "emissions do not look like normalised log-probabilities"

## 7. Greedy Decoding

In [None]:
# Greedy CTC decoding of the emissions, compared with the reference transcript.
decoded = backend.greedy_decode(result.emissions)

print(f"Ground truth: \"{SAMPLE_TEXT}\"")
print(f"Decoded:      \"{decoded}\"")

# Rough sanity metric: share of distinct reference words that also occur in the hypothesis.
gt_words = set(SAMPLE_TEXT.lower().split())
decoded_words = set(decoded.lower().split())
overlap = len(gt_words.intersection(decoded_words))
print(f"\nWord overlap: {overlap}/{len(gt_words)} ({100*overlap/len(gt_words):.0f}%)")

In [None]:
# Show the arg-max token of the first (up to) 30 frames and count how many are blank.
print("Top predictions per frame (first 30 frames):")
print("-" * 50)

blank_count = 0
for i in range(min(30, result.num_frames)):
    frame = result.emissions[i]
    top_idx = frame.argmax().item()
    top_prob = frame[top_idx].exp().item()
    label = vocab.id_to_label.get(top_idx, "?")

    is_blank = top_idx == vocab.blank_id
    if is_blank:
        blank_count += 1
    display_label = "<blank>" if is_blank else repr(label)

    print(f"Frame {i:3d}: {display_label:10s} (idx={top_idx:4d}, prob={top_prob:.3f})")

print(f"\nBlank frames in first 30: {blank_count}")

## 8. Batched Inference

In [None]:
# Batch of three prefixes of the sample (1 s, 2 s, 3 s at 16 kHz) to exercise padding
# of variable-length inputs.
waveforms = [waveform.squeeze(0)[:num_samples] for num_samples in (16000, 32000, 48000)]

print(f"Testing batched inference with {len(waveforms)} samples:")
for idx, wav in enumerate(waveforms):
    print(f"  [{idx}] {len(wav)/16000:.1f}s ({len(wav)} samples)")

# One result per input, in input order.
batch_results = get_emissions_batched(backend, waveforms, sample_rate=16000)

print("\nResults:")
for idx, item in enumerate(batch_results):
    decoded = backend.greedy_decode(item.emissions)
    print(f"  [{idx}] frames={item.num_frames:3d}, decoded=\"{decoded}\"")

## 9. Direct Model Access (Advanced)

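The OmniASR backend calls the underlying model's forward pass directly, describing the batch with fairseq2's `BatchLayout`. The cells below reproduce that call by hand. They rely on private backend attributes (`_model`, `_model_dtype`, `_device_obj`), which are not a stable API.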

In [None]:
# Inspect the wrapped model through the backend's private attributes (not a stable API).
print("OmniASR Model Architecture:")
print(f"  Model type: {type(backend._model).__name__}")
print(f"  Model dtype: {backend._model_dtype}")
print(f"  Device: {backend._device_obj}")

# Direct children only (named_children does not recurse).
print("\nSubmodules:")
for child_name, child in backend._model.named_children():
    print(f"  {child_name}: {type(child).__name__}")

In [None]:
# Reproduce the forward pass by hand: waveform -> BatchLayout -> logits -> log-probabilities.
from fairseq2.nn import BatchLayout
import torch.nn.functional as F

# Batch of one: match the model's device and dtype, then add a leading batch dimension.
test_wav = waveform.squeeze(0).to(backend._device_obj, dtype=backend._model_dtype).unsqueeze(0)

# The single row is unpadded, so its sequence length is simply its sample count.
lengths = torch.tensor([test_wav.shape[1]], dtype=torch.long, device=backend._device_obj)
batch_layout = BatchLayout(test_wav.shape, seq_lens=lengths, device=backend._device_obj)

backend._model.eval()
with torch.inference_mode():
    logits, output_layout = backend._model(test_wav, batch_layout)

# Upcast before log_softmax (the model may run in reduced precision).
emissions = F.log_softmax(logits.float(), dim=-1)

print("Direct forward call:")
print(f"  Input shape: {test_wav.shape}")
print(f"  Logits shape: {logits.shape}")
print(f"  Emissions shape: {emissions.shape}")
print(f"  Output seq lens: {output_layout.seq_lens}")

## 10. Compare Model Sizes

In [None]:
# Compare checkpoints of different sizes on the same utterance.
# Only the 300M model runs by default; uncomment larger presets if you have enough GPU memory.
model_sizes = [
    ("omniasr-300m", "omniASR_CTC_300M (325M params)"),
    # ("omniasr-1b", "omniASR_CTC_1B (975M params)"),
    # ("omniasr-3b", "omniASR_CTC_3B (3.08B params)"),
    # ("omniasr-7b", "omniASR_CTC_7B (6.5B params)"),
]

print("Model comparison:")
print("=" * 60)

for preset, description in model_sizes:
    print(f"\nLoading {description}...")
    model = size_result = None
    try:
        model = load_model(preset, device=device)
        # Dedicated name so the `result` from the emission-extraction section is not clobbered.
        size_result = get_emissions(model, waveform, sample_rate=16000)
        decoded = model.greedy_decode(size_result.emissions)

        print(f"  Vocab size: {len(model.get_vocab_info().labels)}")
        print(f"  Decoded: \"{decoded}\"")
    except Exception as e:
        # Best-effort comparison: report the failure (e.g. out of memory) and move on.
        print(f"  Failed: {e}")
    finally:
        # Release the model even if inference failed part-way, so the next (larger)
        # preset does not share GPU memory with it.
        del model, size_result
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

## 11. Integration with Full Alignment Pipeline

In [None]:
# End-to-end usage: build an Aligner from the OmniASR preset name (presumably it loads
# its own model instance rather than reusing `backend` — confirm).
from alignment import Aligner

aligner = Aligner(backend="omniasr-300m", device=device)

print("Aligner created with OmniASR backend")
print(f"  Backend: {aligner.backend}")

In [None]:
# Align the sample audio file to its reference transcript and list the word spans
# (start/end presumably in seconds — confirm against the alignment module).
result = aligner.align(audio=sample_path, text=SAMPLE_TEXT)

print("Alignment result:")
print(f"  Words: {len(result.words)}")
print()
for word in result.words:
    print(f"  {word.start:.3f} - {word.end:.3f}: {word.label}")

In [None]:
# Display the alignment result from the previous cell with an audio player.
# This is the cell's last expression, so Jupyter also renders display_html()'s return value
# if it returns something displayable.
result.display_html()

## Summary

The OmniASR backend provides:

1. **Character-level vocabulary** with 9812 tokens supporting 1600+ languages
2. **Direct model forward calls** using fairseq2's BatchLayout (no temp files)
3. **Batched inference** with variable-length inputs
4. **Multiple model sizes** from 300M to 7B parameters
5. **Full integration** with the alignment pipeline