# Test Notebook: OmniLingual ASR (OmniASR) Backend

This notebook tests and debugs the OmniASR backend for torchaudio_aligner.

**Goals:**
1. Understand OmniASR's internal structure (model, tokenizer, vocab)
2. Properly extract vocabulary from the tokenizer
3. Properly extract CTC emissions/posteriors
4. Test different approaches to get emissions

**Reference:**
- https://github.com/facebookresearch/omnilingual-asr
- Models built on fairseq2

## Setup

In [None]:
# Install dependencies
# WARNING: omnilingual-asr may have conflicting dependencies
!pip install -q omnilingual-asr

In [None]:
import torch
import torch.nn.functional as F

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

## Part 1: Explore OmniASR Pipeline Structure

First, let's understand what's inside the ASRInferencePipeline.

In [None]:
from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline

# Load a smaller model for testing. MODEL_NAME is handed to the pipeline as its
# `model_card`; swap in one of the commented alternatives to explore another model.
MODEL_NAME = "omniASR_CTC_300M"  # Smallest CTC model
# MODEL_NAME = "omniASR_CTC_1B"  # 1B parameter model
# MODEL_NAME = "omniASR_CTC_1B_v2"  # v2 uses different tokenizer

# Build the inference pipeline on the device chosen in the setup cell.
# NOTE(review): presumably this fetches the checkpoint on first use — confirm.
print(f"Loading {MODEL_NAME}...")
pipeline = ASRInferencePipeline(model_card=MODEL_NAME, device=device)
print("Pipeline loaded!")

In [None]:
# Inspect pipeline attributes: list every non-callable attribute, public names
# first and single-underscore ("private") names second. Dunder names are skipped.
rule = "=" * 60
sections = (
    # (heading, name filter, rule printed above the heading)
    ("Pipeline Attributes:",
     lambda n: not n.startswith('_'),
     rule),
    ("Pipeline Private Attributes:",
     lambda n: n.startswith('_') and not n.startswith('__'),
     "\n" + rule),
)

for heading, wanted, top_rule in sections:
    print(top_rule)
    print(heading)
    print(rule)
    for member_name in filter(wanted, dir(pipeline)):
        member = getattr(pipeline, member_name, None)
        if callable(member):
            continue
        print(f"  {member_name}: {type(member).__name__}")

In [None]:
# Try to access the model: check the two conventional attribute names first,
# then fall back to the first pipeline attribute that is a torch.nn.Module.
print("=" * 60)
print("Accessing Model:")
print("=" * 60)

model = None
for known_name in ('model', '_model'):
    if hasattr(pipeline, known_name):
        model = getattr(pipeline, known_name)
        print(f"Found model via '{known_name}' attribute")
        break
else:
    # Neither conventional name exists: scan every attribute for a module.
    for candidate_name in dir(pipeline):
        candidate = getattr(pipeline, candidate_name, None)
        if isinstance(candidate, torch.nn.Module):
            model = candidate
            print(f"Found model via '{candidate_name}' attribute")
            break

if model is None:
    print("Could not find model!")
else:
    print(f"\nModel type: {type(model).__name__}")
    print(f"Model class: {model.__class__.__module__}.{model.__class__.__name__}")

In [None]:
# Inspect model structure: plain data attributes first, then direct children.
if model is not None:
    divider = "=" * 60

    print(divider)
    print("Model Attributes (non-callable):")
    print(divider)
    for attr_name in dir(model):
        if attr_name.startswith('_'):
            continue
        attr_value = getattr(model, attr_name, None)
        # Methods are skipped; sub-modules are listed in the section below.
        if callable(attr_value) or isinstance(attr_value, torch.nn.Module):
            continue
        print(f"  {attr_name}: {type(attr_value).__name__}")
    
    print("\n" + divider)
    print("Model Sub-modules:")
    print(divider)
    for child_name, child in model.named_children():
        print(f"  {child_name}: {type(child).__name__}")

In [None]:
# Deep inspection of model modules: report every (nested) sub-module whose
# qualified name contains one of the hints below, with the dimensions of any
# torch.nn.Linear among them.
if model is not None:
    print("=" * 60)
    print("All Named Modules (looking for CTC/projection layers):")
    print("=" * 60)
    name_hints = ('ctc', 'proj', 'output', 'lm_head', 'linear', 'final')
    for qualified_name, submodule in model.named_modules():
        lowered = qualified_name.lower()
        if not any(hint in lowered for hint in name_hints):
            continue
        print(f"  {qualified_name}: {type(submodule).__name__}")
        if isinstance(submodule, torch.nn.Linear):
            print(f"       in_features={submodule.in_features}, out_features={submodule.out_features}")

## Part 2: Explore Tokenizer and Vocabulary

The tokenizer determines the output vocabulary for CTC.

In [None]:
# Try to access tokenizer from pipeline, falling back to the model.
# Sets `tokenizer` and `tokenizer_source` (both None when nothing is found).
print("=" * 60)
print("Accessing Tokenizer:")
print("=" * 60)

tokenizer = None
tokenizer_source = None

# Check pipeline attributes. An attribute that exists but holds None is skipped
# so the search continues instead of stopping on an empty slot.
for attr in ['tokenizer', '_tokenizer', 'text_tokenizer', '_text_tokenizer']:
    candidate = getattr(pipeline, attr, None)
    if candidate is None:
        continue
    tokenizer = candidate
    tokenizer_source = f"pipeline.{attr}"
    print(f"Found tokenizer via {tokenizer_source}")
    break

# Check model attributes
if tokenizer is None and model is not None:
    for attr in ['tokenizer', '_tokenizer', 'decoder']:
        holder = getattr(model, attr, None)
        if holder is None:
            continue
        nested = getattr(holder, 'tokenizer', None)
        if nested is not None:
            tokenizer = nested
            tokenizer_source = f"model.{attr}.tokenizer"
        elif attr != 'decoder':
            tokenizer = holder
            tokenizer_source = f"model.{attr}"
        else:
            # `decoder` is only probed as a holder of a tokenizer; a decoder
            # without `.tokenizer` must not be reported as the tokenizer itself.
            continue
        print(f"Found tokenizer via {tokenizer_source}")
        break

if tokenizer is not None:
    print(f"\nTokenizer type: {type(tokenizer).__name__}")
    print(f"Tokenizer class: {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}")
else:
    print("Could not find tokenizer directly!")

In [None]:
# Inspect tokenizer attributes: print public data attributes (value truncated
# to 100 characters), then the names of public callables.
if tokenizer is not None:
    print("=" * 60)
    print("Tokenizer Attributes:")
    print("=" * 60)
    method_names = []
    # Single pass over dir(); callables are remembered for the second listing.
    for attr in dir(tokenizer):
        if attr.startswith('_'):
            continue
        try:
            val = getattr(tokenizer, attr, None)
        except Exception as e:
            # getattr's default only covers AttributeError; a property may
            # raise anything else, which should not abort the whole listing.
            print(f"  {attr}: unreadable ({type(e).__name__})")
            continue
        if callable(val):
            method_names.append(attr)
            continue
        try:
            print(f"  {attr}: {type(val).__name__} = {str(val)[:100]}")
        except Exception:
            # str(val) failed; still report the type. `Exception` (not a bare
            # except) lets KeyboardInterrupt / SystemExit propagate.
            print(f"  {attr}: {type(val).__name__}")
    
    print("\nTokenizer Methods:")
    for attr in method_names:
        print(f"  {attr}()")

In [None]:
# Try to get vocabulary from tokenizer.
# Sets `vocab` (a token mapping, when one is exposed) and `vocab_size`.
# Each method is attempted independently and reports its own failure.
print("=" * 60)
print("Extracting Vocabulary:")
print("=" * 60)

vocab = None
vocab_size = None

if tokenizer is not None:
    # Method 1: get_vocab()
    if hasattr(tokenizer, 'get_vocab'):
        try:
            vocab = tokenizer.get_vocab()
            print(f"Method 1 (get_vocab): vocab size = {len(vocab)}")
        except Exception as e:
            print(f"Method 1 failed: {e}")
    
    # Method 2: vocab attribute (only accepted when it is a dict)
    if hasattr(tokenizer, 'vocab'):
        try:
            v = tokenizer.vocab
            print(f"Method 2 (vocab attr): type = {type(v).__name__}")
            if isinstance(v, dict):
                vocab = v
                print(f"  vocab size = {len(vocab)}")
        except Exception as e:
            print(f"Method 2 failed: {e}")
    
    # Method 3: vocab_size attribute
    if hasattr(tokenizer, 'vocab_size'):
        try:
            vocab_size = tokenizer.vocab_size
            print(f"Method 3 (vocab_size): {vocab_size}")
        except Exception as e:
            print(f"Method 3 failed: {e}")
    
    # Method 4: vocab_info. fairseq2 tokenizers expose this as a property, so
    # calling it unconditionally fails; accept both a property and a method.
    if hasattr(tokenizer, 'vocab_info'):
        try:
            vi = tokenizer.vocab_info
            if callable(vi):
                vi = vi()
            print(f"Method 4 (vocab_info): {vi}")
            # NOTE(review): assumes the info object carries an integer `size`
            # (as fairseq2's VocabularyInfo does) — confirm for this tokenizer.
            vi_size = getattr(vi, 'size', None)
            if vocab_size is None and isinstance(vi_size, int):
                vocab_size = vi_size
                print(f"  -> vocab size = {vocab_size}")
        except Exception as e:
            print(f"Method 4 failed: {e}")

    # Last resort: if a vocabulary mapping was found but no explicit size,
    # use the mapping's length so later cells still get a vocab size.
    if vocab_size is None and vocab is not None:
        try:
            vocab_size = len(vocab)
            print(f"Derived vocab_size from vocab: {vocab_size}")
        except Exception as e:
            print(f"Could not derive vocab_size from vocab: {e}")

In [None]:
# Alternative: look the tokenizer up through fairseq2's asset store
print("=" * 60)
print("Loading Tokenizer via Fairseq2 Asset Store:")
print("=" * 60)

try:
    from fairseq2.assets import asset_store, download_manager
    
    print("\nSearching for tokenizer cards...")
    
    # Candidate card names for the OmniASR tokenizer, probed in this order.
    candidate_cards = (
        "omniASR_tokenizer_written_v2",
        "omniASR_tokenizer_v1",
        "omniASR_tokenizer",
    )
    
    for card_name in candidate_cards:
        # Each card is probed independently; a failure is reported and the
        # loop moves on to the next name.
        try:
            card = asset_store.retrieve_card(card_name)
            print(f"\nFound card: {card_name}")
            print(f"  Card: {card}")
            
            # Only cards exposing a `uri` attribute are downloaded.
            if not hasattr(card, 'uri'):
                continue
            path = download_manager.download_tokenizer(card.uri)
            print(f"  Downloaded to: {path}")
        except Exception as e:
            print(f"  {card_name}: {e}")
            
except ImportError as e:
    print(f"fairseq2.assets not available: {e}")
except Exception as e:
    print(f"Error: {e}")

In [None]:
# Alternative: Get vocab from model output dimension.
# Only fills in `vocab_size` when the tokenizer cells left it as None.
print("=" * 60)
print("Getting Vocab Size from Model Output Layer:")
print("=" * 60)

if model is not None:
    # Collect every torch.nn.Linear whose qualified name hints at a projection.
    head_candidates = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if any(s in name.lower() for s in ['ctc', 'proj', 'output', 'final', 'lm']):
                print(f"  {name}: in={module.in_features}, out={module.out_features}")
                head_candidates.append((name, module.out_features))

    if vocab_size is None:
        if head_candidates:
            # 'proj' also matches internal projections that are registered
            # before the output head, so the first match is usually the wrong
            # layer. Use the last match instead.
            # NOTE(review): assumes the output head is registered last — confirm
            # against the module listing printed above.
            head_name, vocab_size = head_candidates[-1]
            print(f"  -> Using {vocab_size} as vocab size")
            print(f"     (taken from {head_name})")
        else:
            print("  No matching torch.nn.Linear layer found")

## Part 3: Test Emission Extraction

Now let's test different approaches to get CTC emissions.

In [None]:
# Create test waveform
print("=" * 60)
print("Creating Test Input:")
print("=" * 60)

# Generate random audio (5 seconds at 16kHz)
batch_size = 1
duration_sec = 5
sample_rate = 16000
num_samples = duration_sec * sample_rate

# A dedicated, seeded CPU generator makes the input identical on every run
# (and on every device) without touching the global RNG state.
rng = torch.Generator().manual_seed(0)
waveform = torch.randn(batch_size, num_samples, generator=rng).to(device)
# One length per batch item; every item spans the full waveform.
lengths = torch.tensor([waveform.shape[1]] * batch_size).to(device)

print(f"Waveform shape: {waveform.shape}")
print(f"Lengths: {lengths}")

In [None]:
# Option 1: call the model directly and look for logits in whatever it returns
print("=" * 60)
print("Option 1: Direct Model Forward")
print("=" * 60)

if model is not None:
    model.eval()
    
    try:
        with torch.inference_mode():
            out = model(waveform, lengths)
        
        print(f"Output type: {type(out).__name__}")
        
        if isinstance(out, torch.Tensor):
            # A bare tensor result is used as the logits directly.
            logits = out
            print(f"Output is tensor: {logits.shape}")
        elif isinstance(out, dict):
            print(f"Output keys: {out.keys()}")
            # Only the first known key that is present is reported.
            known_keys = ('logits', 'ctc_logits', 'emissions', 'encoder_out')
            matched_key = next((k for k in known_keys if k in out), None)
            if matched_key is not None:
                logits = out[matched_key]
                described = logits.shape if isinstance(logits, torch.Tensor) else type(logits)
                print(f"Found '{matched_key}': {described}")
        else:
            # Some other object: list its public attributes and report every
            # known field it carries (no early exit here).
            public_names = [a for a in dir(out) if not a.startswith('_')]
            print(f"Output attributes: {public_names}")
            for field in ('logits', 'ctc_logits', 'emissions', 'output'):
                if not hasattr(out, field):
                    continue
                val = getattr(out, field)
                described = val.shape if isinstance(val, torch.Tensor) else type(val)
                print(f"Found '{field}': {described}")
                    
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

In [None]:
# Option 2: run the encoder and the CTC projection as two explicit steps
print("=" * 60)
print("Option 2: Encoder + CTC Projection")
print("=" * 60)

if model is not None:
    model.eval()
    
    # Find encoder: first attribute name from the list that the model has.
    encoder = None
    encoder_attr = next(
        (a for a in ('encoder', 'wav2vec2', 'feature_extractor') if hasattr(model, a)),
        None,
    )
    if encoder_attr is not None:
        encoder = getattr(model, encoder_attr)
        print(f"Found encoder: model.{encoder_attr}")
    
    # Find CTC projection the same way.
    ctc_proj = None
    proj_attr = next(
        (a for a in ('ctc_proj', 'output_projection', 'ctc', 'ctc_decoder', 'final_proj')
         if hasattr(model, a)),
        None,
    )
    if proj_attr is not None:
        ctc_proj = getattr(model, proj_attr)
        print(f"Found CTC projection: model.{proj_attr}")
    
    if encoder is None or ctc_proj is None:
        print("Could not find encoder and/or CTC projection")
    else:
        try:
            with torch.inference_mode():
                # Encode
                encoder_out = encoder(waveform)
                print(f"Encoder output type: {type(encoder_out).__name__}")
                
                # Get features: unwrap a container result when it exposes a
                # known field; otherwise use the encoder output as-is.
                features = encoder_out
                if not isinstance(encoder_out, torch.Tensor):
                    if hasattr(encoder_out, 'output'):
                        features = encoder_out.output
                    elif hasattr(encoder_out, 'last_hidden_state'):
                        features = encoder_out.last_hidden_state
                    
                feature_desc = features.shape if isinstance(features, torch.Tensor) else type(features)
                print(f"Features shape: {feature_desc}")
                
                # Apply CTC projection
                logits = ctc_proj(features)
                print(f"Logits shape: {logits.shape}")
                
                # Log softmax over the last dimension, computed in float32
                emissions = F.log_softmax(logits.float(), dim=-1)
                print(f"Emissions shape: {emissions.shape}")
                
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()

In [None]:
# Option 3: Hook into the model.
# Registers a forward hook on a projection layer, runs a full forward pass, and
# turns the captured output into log-probabilities.
print("=" * 60)
print("Option 3: Hook-based Extraction")
print("=" * 60)

if model is not None:
    model.eval()
    
    captured = {}
    
    def hook_fn(module, inp, out):
        """Forward hook: stash the hooked layer's output under "logits"."""
        captured["logits"] = out
    
    # Find a likely projection layer. Hints such as "proj" also match internal
    # projections registered earlier in the model, so keep the LAST match
    # rather than stopping at the first one.
    # NOTE(review): assumes the output head is registered last — confirm.
    proj = None
    proj_name = None
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):
            if any(s in name.lower() for s in ["ctc", "proj", "output", "lm_head", "final"]):
                proj = m
                proj_name = name
    
    if proj is not None:
        print(f"Hooking: {proj_name}")
        h = proj.register_forward_hook(hook_fn)
        
        try:
            with torch.inference_mode():
                _ = model(waveform, lengths)
            
            if "logits" in captured:
                logits = captured["logits"]
                print(f"Captured logits shape: {logits.shape}")
                
                emissions = F.log_softmax(logits.float(), dim=-1)
                print(f"Emissions shape: {emissions.shape}")
            else:
                print("Hook did not capture logits")
                
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Always detach the hook, even on KeyboardInterrupt, so it cannot
            # leak into later cells that call the model.
            h.remove()
    else:
        print("Could not find a projection layer to hook")

In [None]:
# Option 4: go through model.extract_features() when the model provides it
print("=" * 60)
print("Option 4: extract_features() Method")
print("=" * 60)

if model is not None:
    model.eval()
    
    if not hasattr(model, 'extract_features'):
        print("Model does not have extract_features() method")
    else:
        try:
            with torch.inference_mode():
                encoder_out, padding_mask = model.extract_features(waveform, padding_mask=None)
            
            print(f"Encoder output shape: {encoder_out.shape}")
            if padding_mask is not None:
                print(f"Padding mask shape: {padding_mask.shape}")
            
            # Find CTC projection: first attribute name the model actually has.
            ctc_proj = None
            proj_attr = next(
                (a for a in ('ctc_proj', 'output_projection', 'proj', 'final_proj')
                 if hasattr(model, a)),
                None,
            )
            if proj_attr is not None:
                ctc_proj = getattr(model, proj_attr)
                print(f"Found CTC projection: model.{proj_attr}")
            
            # Nothing further is printed when no projection was found.
            if ctc_proj is not None:
                logits = ctc_proj(encoder_out)
                print(f"Logits shape: {logits.shape}")
                
                emissions = F.log_softmax(logits.float(), dim=-1)
                print(f"Emissions shape: {emissions.shape}")
                
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()

In [None]:
# Option 5: Full forward with output parsing.
# Tries several call signatures in order, keeps the first one that returns a
# non-None result, and then describes the structure of that result.
print("=" * 60)
print("Option 5: Full Forward with Output Parsing")
print("=" * 60)

if model is not None:
    model.eval()
    
    try:
        with torch.inference_mode():
            # Try different forward signatures, in this order.
            forward_attempts = [
                ("(source=, padding_mask=)", lambda: model(source=waveform, padding_mask=None)),
                ("(waveform, lengths)", lambda: model(waveform, lengths)),
                ("(waveform)", lambda: model(waveform)),
            ]
            
            out = None
            for signature, call_model in forward_attempts:
                try:
                    out = call_model()
                except Exception as e:
                    # Report why this signature was rejected instead of
                    # silently discarding the error; it is the most useful
                    # clue about what the model actually expects.
                    print(f"Forward signature {signature} failed: {type(e).__name__}: {e}")
                    continue
                print(f"Forward signature: {signature}")
                if out is not None:
                    break
            
            if out is not None:
                print(f"\nOutput type: {type(out).__name__}")
                
                # Parse output
                if isinstance(out, torch.Tensor):
                    print(f"Direct tensor output: {out.shape}")
                    logits = out
                elif isinstance(out, dict):
                    print(f"Dict output keys: {out.keys()}")
                    for k, v in out.items():
                        if isinstance(v, torch.Tensor):
                            print(f"  {k}: {v.shape}")
                        elif isinstance(v, dict):
                            print(f"  {k}: dict with keys {v.keys()}")
                else:
                    # Named tuple or custom object
                    print(f"Object attributes:")
                    for attr in dir(out):
                        if not attr.startswith('_'):
                            val = getattr(out, attr, None)
                            if isinstance(val, torch.Tensor):
                                print(f"  {attr}: {val.shape}")
                            elif not callable(val):
                                print(f"  {attr}: {type(val).__name__}")
            else:
                print("Could not call model forward")
                
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

## Part 4: Test Pipeline Audio Loading

Check if the pipeline has built-in audio loading.

In [None]:
# Check pipeline for audio loading methods: list every pipeline attribute
# (dunder and private names included) whose name mentions audio, load or wave.
print("=" * 60)
print("Pipeline Audio Loading Methods:")
print("=" * 60)

audio_hints = ('audio', 'load', 'wave')
for member_name in dir(pipeline):
    lowered = member_name.lower()
    if not any(hint in lowered for hint in audio_hints):
        continue
    member = getattr(pipeline, member_name, None)
    callable_tag = '(callable)' if callable(member) else ''
    print(f"  {member_name}: {type(member).__name__} {callable_tag}")

In [None]:
# Test transcription to verify model works.
# Writes 3 seconds of random noise to a temporary WAV file, transcribes it, and
# always removes the file afterwards. The text itself will be gibberish; the
# point is that the end-to-end path runs.
print("=" * 60)
print("Test Transcription (to verify model works):")
print("=" * 60)

import os
import tempfile

temp_path = None
try:
    # If you have a test audio file, use it:
    # result = pipeline.transcribe(["/path/to/audio.wav"], batch_size=1)
    
    import torchaudio
    
    # Create a simple test audio
    test_waveform = torch.randn(1, 16000 * 3)  # 3 seconds
    # Reserve a path and close the handle before writing: saving into a file
    # that is still open fails on some platforms.
    fd, temp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    torchaudio.save(temp_path, test_waveform, 16000)
    
    # Try to transcribe. The documented entry point is
    # `pipeline.transcribe(list_of_paths, ...)`; calling the pipeline object
    # directly is kept only as a fallback.
    if callable(getattr(pipeline, "transcribe", None)):
        result = pipeline.transcribe([temp_path], batch_size=1)
    else:
        result = pipeline(temp_path)
    print(f"Transcription result: {result}")
    print(f"Result type: {type(result).__name__}")
    
except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Remove the temporary file whether or not transcription succeeded.
    if temp_path is not None and os.path.exists(temp_path):
        os.unlink(temp_path)

## Part 5: Summary and Recommendations

Based on the exploration above, summarize what we learned.

In [None]:
# Recap of what the cells above discovered, followed by a template to fill in.
banner = "=" * 60

print(banner)
print("SUMMARY")
print(banner)

print(f"\n1. Model Name: {MODEL_NAME}")
model_label = type(model).__name__ if model else 'Not found'
print(f"2. Model Type: {model_label}")
tokenizer_label = type(tokenizer).__name__ if tokenizer else 'Not found'
print(f"3. Tokenizer Type: {tokenizer_label}")
vocab_label = vocab_size if vocab_size else 'Unknown'
print(f"4. Vocab Size: {vocab_label}")

print("\n" + banner)
print("RECOMMENDATIONS FOR BACKEND IMPLEMENTATION:")
print(banner)

# Placeholder text: the bracketed items are meant to be filled in by hand.
print("""
Based on the exploration, update omniasr_backend.py to:

1. TOKENIZER/VOCAB:
   - [Fill in based on what worked above]

2. EMISSION EXTRACTION:
   - [Fill in based on what worked above]

3. AUDIO LOADING:
   - [Fill in based on what worked above]
""")