# Test Notebook: OmniLingual ASR (OmniASR) Backend

This notebook tests and debugs the OmniASR backend for torchaudio_aligner.

**Goals:**
1. Understand OmniASR's internal structure (model, tokenizer, vocab)
2. Properly extract vocabulary from the tokenizer
3. Properly extract CTC emissions/posteriors
4. Test different approaches to get emissions

**Reference:**
- https://github.com/facebookresearch/omnilingual-asr
- The models are built on fairseq2

## Setup

In [None]:
# Install dependencies
# WARNING: omnilingual-asr may have conflicting dependencies
!pip install -q omnilingual-asr

In [1]:
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

Device: cuda


## Part 1: Explore OmniASR Pipeline Structure

First, let's understand what's inside the ASRInferencePipeline.

In [2]:
from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline

# Load a smaller model for testing
MODEL_NAME = "omniASR_CTC_300M"  # Smallest CTC model
# MODEL_NAME = "omniASR_CTC_1B"  # 1B parameter model
# MODEL_NAME = "omniASR_CTC_1B_v2"  # v2 uses a different tokenizer

print(f"Loading {MODEL_NAME}...")
pipeline = ASRInferencePipeline(model_card=MODEL_NAME, device=device)
print("Pipeline loaded!")

Loading omniASR_CTC_300M...


100%|██████████| 1.21G/1.21G [00:29<00:00, 44.9MB/s]



Pipeline loaded!


In [3]:
# Inspect pipeline attributes
print("=" * 60)
print("Pipeline Attributes:")
print("=" * 60)
for attr in dir(pipeline):
    if not attr.startswith('_'):
        val = getattr(pipeline, attr, None)
        if not callable(val):
            print(f"  {attr}: {type(val).__name__}")

print("\n" + "=" * 60)
print("Pipeline Private Attributes:")
print("=" * 60)
for attr in dir(pipeline):
    if attr.startswith('_') and not attr.startswith('__'):
        val = getattr(pipeline, attr, None)
        if not callable(val):
            print(f"  {attr}: {type(val).__name__}")

Pipeline Attributes:
  beam_search_generator: NoneType
  device: device
  dtype: dtype
  tokenizer: RawSentencePieceTokenizer

Pipeline Private Attributes:


In [4]:
# Try to access the model
print("=" * 60)
print("Accessing Model:")
print("=" * 60)

model = None
if hasattr(pipeline, 'model'):
    model = pipeline.model
    print(f"Found model via 'model' attribute")
elif hasattr(pipeline, '_model'):
    model = pipeline._model
    print(f"Found model via '_model' attribute")
else:
    # Search for torch.nn.Module in attributes
    for attr in dir(pipeline):
        val = getattr(pipeline, attr, None)
        if isinstance(val, torch.nn.Module):
            model = val
            print(f"Found model via '{attr}' attribute")
            break

if model is not None:
    print(f"\nModel type: {type(model).__name__}")
    print(f"Model class: {model.__class__.__module__}.{model.__class__.__name__}")
else:
    print("Could not find model!")

Accessing Model:
Found model via 'model' attribute

Model type: Wav2Vec2AsrModel
Model class: fairseq2.models.wav2vec2.asr.model.Wav2Vec2AsrModel


In [5]:
# Inspect model structure
if model is not None:
    print("=" * 60)
    print("Model Attributes (non-callable):")
    print("=" * 60)
    for attr in dir(model):
        if not attr.startswith('_'):
            val = getattr(model, attr, None)
            if not callable(val) and not isinstance(val, torch.nn.Module):
                print(f"  {attr}: {type(val).__name__}")
    
    print("\n" + "=" * 60)
    print("Model Sub-modules:")
    print("=" * 60)
    for name, module in model.named_children():
        print(f"  {name}: {type(module).__name__}")

Model Attributes (non-callable):
  T_destination: TypeVar
  call_super_init: bool
  dump_patches: bool
  final_dropout: NoneType
  masker: NoneType
  model_dim: int
  training: bool

Model Sub-modules:
  encoder_frontend: Wav2Vec2Frontend
  encoder: StandardTransformerEncoder
  final_proj: Linear


In [6]:
# Deep inspection of model modules
if model is not None:
    print("=" * 60)
    print("All Named Modules (looking for CTC/projection layers):")
    print("=" * 60)
    for name, module in model.named_modules():
        if any(s in name.lower() for s in ['ctc', 'proj', 'output', 'lm_head', 'linear', 'final']):
            print(f"  {name}: {type(module).__name__}")
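            # No in_features line appears in the output below, so these "Linear" layers
            # are not torch.nn.Linear instances (presumably fairseq2's own Linear class).
            if isinstance(module, torch.nn.Linear):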
                print(f"       in_features={module.in_features}, out_features={module.out_features}")

All Named Modules (looking for CTC/projection layers):
  encoder_frontend.model_dim_proj: Linear
  encoder.layers.0.self_attn.q_proj: Linear
  encoder.layers.0.self_attn.k_proj: Linear
  encoder.layers.0.self_attn.v_proj: Linear
  encoder.layers.0.self_attn.output_proj: Linear
  encoder.layers.0.ffn.inner_proj: Linear
  encoder.layers.0.ffn.output_proj: Linear
  encoder.layers.1.self_attn.q_proj: Linear
  encoder.layers.1.self_attn.k_proj: Linear
  encoder.layers.1.self_attn.v_proj: Linear
  encoder.layers.1.self_attn.output_proj: Linear
  encoder.layers.1.ffn.inner_proj: Linear
  encoder.layers.1.ffn.output_proj: Linear
  encoder.layers.2.self_attn.q_proj: Linear
  encoder.layers.2.self_attn.k_proj: Linear
  encoder.layers.2.self_attn.v_proj: Linear
  encoder.layers.2.self_attn.output_proj: Linear
  encoder.layers.2.ffn.inner_proj: Linear
  encoder.layers.2.ffn.output_proj: Linear
  encoder.layers.3.self_attn.q_proj: Linear
  encoder.layers.3.self_attn.k_proj: Linear
  encoder.layers.

## Part 2: Explore Tokenizer and Vocabulary

The tokenizer determines the output vocabulary for CTC.

In [7]:
# Try to access tokenizer from pipeline
print("=" * 60)
print("Accessing Tokenizer:")
print("=" * 60)

tokenizer = None
tokenizer_source = None

# Check pipeline attributes
for attr in ['tokenizer', '_tokenizer', 'text_tokenizer', '_text_tokenizer']:
    if hasattr(pipeline, attr):
        tokenizer = getattr(pipeline, attr)
        tokenizer_source = f"pipeline.{attr}"
        print(f"Found tokenizer via {tokenizer_source}")
        break

# Check model attributes
if tokenizer is None and model is not None:
    for attr in ['tokenizer', '_tokenizer', 'decoder']:
        if hasattr(model, attr):
            val = getattr(model, attr)
            if hasattr(val, 'tokenizer'):
                tokenizer = val.tokenizer
                tokenizer_source = f"model.{attr}.tokenizer"
            else:
                tokenizer = val
                tokenizer_source = f"model.{attr}"
            print(f"Found tokenizer via {tokenizer_source}")
            break

if tokenizer is not None:
    print(f"\nTokenizer type: {type(tokenizer).__name__}")
    print(f"Tokenizer class: {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}")
else:
    print("Could not find tokenizer directly!")

Accessing Tokenizer:
Found tokenizer via pipeline.tokenizer

Tokenizer type: RawSentencePieceTokenizer
Tokenizer class: fairseq2.data.tokenizers.sentencepiece.RawSentencePieceTokenizer


In [8]:
# Inspect tokenizer attributes
if tokenizer is not None:
    print("=" * 60)
    print("Tokenizer Attributes:")
    print("=" * 60)
    for attr in dir(tokenizer):
        if not attr.startswith('_'):
            val = getattr(tokenizer, attr, None)
            if not callable(val):
                try:
                    print(f"  {attr}: {type(val).__name__} = {str(val)[:100]}")
                except Exception:
                    print(f"  {attr}: {type(val).__name__}")
    
    print("\nTokenizer Methods:")
    for attr in dir(tokenizer):
        if not attr.startswith('_'):
            val = getattr(tokenizer, attr, None)
            if callable(val):
                print(f"  {attr}()")

Tokenizer Attributes:
  vocab_info: VocabularyInfo = VocabularyInfo(size=9812, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1, boh_idx=None, eoh_idx=None)

Tokenizer Methods:
  create_decoder()
  create_encoder()
  create_raw_encoder()


In [11]:
# Try to get vocabulary from tokenizer
print("=" * 60)
print("Extracting Vocabulary:")
print("=" * 60)

vocab = None
vocab_size = None

if tokenizer is not None:
    # Method 1: get_vocab()
    if hasattr(tokenizer, 'get_vocab'):
        try:
            vocab = tokenizer.get_vocab()
            print(f"Method 1 (get_vocab): vocab size = {len(vocab)}")
        except Exception as e:
            print(f"Method 1 failed: {e}")
    
    # Method 2: vocab attribute
    if hasattr(tokenizer, 'vocab'):
        try:
            v = tokenizer.vocab
            print(f"Method 2 (vocab attr): type = {type(v).__name__}")
            if isinstance(v, dict):
                vocab = v
                print(f"  vocab size = {len(vocab)}")
        except Exception as e:
            print(f"Method 2 failed: {e}")
    
    # Method 3: vocab_size attribute
    if hasattr(tokenizer, 'vocab_size'):
        try:
            vocab_size = tokenizer.vocab_size
            print(f"Method 3 (vocab_size): {vocab_size}")
        except Exception as e:
            print(f"Method 3 failed: {e}")
    
    # Method 4: vocab_info()
    if hasattr(tokenizer, 'vocab_info'):
        try:
            vi = tokenizer.vocab_info()
            print(f"Method 4 (vocab_info): {vi}")
        except Exception as e:
            print(f"Method 4 failed: {e}")

Extracting Vocabulary:
Method 4 failed: 'VocabularyInfo' object is not callable
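Method 4 fails because `vocab_info` is a plain attribute holding a `VocabularyInfo` object, not a method (In [24] reads it directly). Below is a minimal accessor that tolerates both forms. It is exercised against a hypothetical stand-in, `FakeVocabInfo`, that copies the field names printed above; the method-style branch is only a defensive assumption, since this tokenizer exposes an attribute.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class FakeVocabInfo:
    """Hypothetical stand-in with the fields printed for VocabularyInfo above."""
    size: int
    unk_idx: Optional[int]
    bos_idx: Optional[int]
    eos_idx: Optional[int]
    pad_idx: Optional[int]


def get_vocab_info(tokenizer: Any) -> Any:
    """Return tokenizer.vocab_info, calling it only if it turns out to be a method."""
    info = getattr(tokenizer, "vocab_info", None)
    if info is None:
        raise AttributeError(f"{type(tokenizer).__name__} has no 'vocab_info'")
    return info() if callable(info) else info


def special_indices(info: Any) -> Dict[str, int]:
    """Collect the special-token indices that are actually set (not None)."""
    names = ("unk_idx", "bos_idx", "eos_idx", "pad_idx")
    return {n: getattr(info, n) for n in names if getattr(info, n, None) is not None}
```

In the notebook, this would be called as `get_vocab_info(pipeline.tokenizer).size`. With the printed values, `special_indices` would return `{"unk_idx": 3, "bos_idx": 0, "eos_idx": 2, "pad_idx": 1}`.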


In [12]:
# Alternative: Load tokenizer via fairseq2 asset store
print("=" * 60)
print("Loading Tokenizer via Fairseq2 Asset Store:")
print("=" * 60)

try:
    from fairseq2.assets import asset_store, download_manager
    
    # List available cards
    print("\nSearching for tokenizer cards...")
    
    # Try to find OmniASR tokenizer
    tokenizer_names = [
        "omniASR_tokenizer_written_v2",
        "omniASR_tokenizer_v1",
        "omniASR_tokenizer",
    ]
    
    for name in tokenizer_names:
        try:
            card = asset_store.retrieve_card(name)
            print(f"\nFound card: {name}")
            print(f"  Card: {card}")
            
            # Try to download
            if hasattr(card, 'uri'):
                path = download_manager.download_tokenizer(card.uri)
                print(f"  Downloaded to: {path}")
        except Exception as e:
            print(f"  {name}: {e}")
            
except ImportError as e:
    print(f"fairseq2.assets not available: {e}")
except Exception as e:
    print(f"Error: {e}")

Loading Tokenizer via Fairseq2 Asset Store:
fairseq2.assets not available: cannot import name 'asset_store' from 'fairseq2.assets' (/usr/local/lib/python3.12/dist-packages/fairseq2/assets/__init__.py)


In [10]:
# Alternative: Get vocab from model output dimension
print("=" * 60)
print("Getting Vocab Size from Model Output Layer:")
print("=" * 60)

if model is not None:
    # Find the final linear layer (CTC projection).
    # This printed nothing: as in In [6], the isinstance check below never matches.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if any(s in name.lower() for s in ['ctc', 'proj', 'output', 'final', 'lm']):
                print(f"  {name}: in={module.in_features}, out={module.out_features}")
                if vocab_size is None:
                    vocab_size = module.out_features
                    print(f"  -> Using {vocab_size} as vocab size")

Getting Vocab Size from Model Output Layer:
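The empty output is consistent with In [6]: the projection layers are not `torch.nn.Linear`, so neither the `isinstance` check nor `out_features` applies. The following is a duck-typed sketch that reads the output width without checking the class. `out_features` is the `torch.nn.Linear` attribute. `output_dim` is an *assumed* attribute name for fairseq2-style layers, so confirm it with `dir(model.final_proj)` first. The weight fallback assumes an `(out, in)` layout, as in `torch.nn.Linear`.

```python
from typing import Any


def projection_output_dim(layer: Any) -> int:
    """Infer a projection layer's output width without relying on isinstance."""
    # torch.nn.Linear uses out_features; output_dim is an assumption for fairseq2-style layers.
    for attr in ("out_features", "output_dim"):
        value = getattr(layer, attr, None)
        if isinstance(value, int):
            return value
    # Fallback: weight stored as (out, in), as in torch.nn.Linear (assumed elsewhere).
    weight = getattr(layer, "weight", None)
    if weight is not None and hasattr(weight, "shape"):
        return int(weight.shape[0])
    raise AttributeError(f"cannot infer output size of {type(layer).__name__}")
```

Applied as `projection_output_dim(model.final_proj)`, the result can be compared with `vocab_info.size` (9812). If the two agree, that supports taking the CTC label set straight from the tokenizer.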


## Part 3: Test Emission Extraction

Now let's test different approaches to get CTC emissions.

In [13]:
# Create test waveform
print("=" * 60)
print("Creating Test Input:")
print("=" * 60)

# Generate random audio (5 seconds at 16 kHz)
batch_size = 1
duration_sec = 5
sample_rate = 16000
waveform = torch.randn(batch_size, duration_sec * sample_rate).to(device)
lengths = torch.tensor([waveform.shape[1]] * batch_size).to(device)

print(f"Waveform shape: {waveform.shape}")
print(f"Lengths: {lengths}")

Creating Test Input:
Waveform shape: torch.Size([1, 80000])
Lengths: tensor([80000], device='cuda:0')
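The number of emission frames that 80,000 samples should produce depends on the convolutional feature encoder in `Wav2Vec2Frontend`. Assuming OmniASR keeps the original wav2vec 2.0 stack (kernels 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2, so a 320-sample or 20 ms hop at 16 kHz), the expected length can be computed in advance and checked against any logits captured later.

```python
from typing import List, Tuple

# (kernel, stride) per conv layer of the original wav2vec 2.0 feature encoder.
# Assumption: OmniASR's frontend uses the same stack.
WAV2VEC2_CONV_LAYERS: List[Tuple[int, int]] = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2


def num_frames(num_samples: int, layers: List[Tuple[int, int]] = WAV2VEC2_CONV_LAYERS) -> int:
    """Output length of a stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in layers:
        if length < kernel:
            return 0  # input shorter than this layer's receptive field
        length = (length - kernel) // stride + 1
    return length


print(num_frames(80000))  # 5 s at 16 kHz -> 249
print(num_frames(16000))  # 1 s at 16 kHz -> 49
```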


In [14]:
# Option 1: Direct model forward
print("=" * 60)
print("Option 1: Direct Model Forward")
print("=" * 60)

if model is not None:
    model.eval()
    
    try:
        with torch.inference_mode():
            out = model(waveform, lengths)
        
        print(f"Output type: {type(out).__name__}")
        
        # Try to extract logits
        if isinstance(out, torch.Tensor):
            logits = out
            print(f"Output is tensor: {logits.shape}")
        elif isinstance(out, dict):
            print(f"Output keys: {out.keys()}")
            for key in ['logits', 'ctc_logits', 'emissions', 'encoder_out']:
                if key in out:
                    logits = out[key]
                    print(f"Found '{key}': {logits.shape if isinstance(logits, torch.Tensor) else type(logits)}")
                    break
        else:
            print(f"Output attributes: {[a for a in dir(out) if not a.startswith('_')]}")
            for attr in ['logits', 'ctc_logits', 'emissions', 'output']:
                if hasattr(out, attr):
                    val = getattr(out, attr)
                    print(f"Found '{attr}': {val.shape if isinstance(val, torch.Tensor) else type(val)}")
                    
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

Option 1: Direct Model Forward
Error: 'Tensor' object has no attribute 'packed'


Traceback (most recent call last):
  File "/tmp/ipython-input-1564891814.py", line 11, in <cell line: 0>
    out = model(waveform, lengths)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fairseq2/models/wav2vec2/asr/model.py", line 125, in forward
    seqs, seqs_layout, _ = self.encoder_frontend.extract_features(seqs, seqs_layout)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fairseq2/models/wav2vec2/frontend.py", line 162, in extract_features
    seqs, seqs_layout = self.feature_extractor(

In [15]:
# Option 2: Encoder + CTC projection
print("=" * 60)
print("Option 2: Encoder + CTC Projection")
print("=" * 60)

if model is not None:
    model.eval()
    
    # Find encoder
    encoder = None
    for attr in ['encoder', 'wav2vec2', 'feature_extractor']:
        if hasattr(model, attr):
            encoder = getattr(model, attr)
            print(f"Found encoder: model.{attr}")
            break
    
    # Find CTC projection
    ctc_proj = None
    for attr in ['ctc_proj', 'output_projection', 'ctc', 'ctc_decoder', 'final_proj']:
        if hasattr(model, attr):
            ctc_proj = getattr(model, attr)
            print(f"Found CTC projection: model.{attr}")
            break
    
    if encoder is not None and ctc_proj is not None:
        try:
            with torch.inference_mode():
                # Encode
                encoder_out = encoder(waveform)
                print(f"Encoder output type: {type(encoder_out).__name__}")
                
                # Get features
                if isinstance(encoder_out, torch.Tensor):
                    features = encoder_out
                elif hasattr(encoder_out, 'output'):
                    features = encoder_out.output
                elif hasattr(encoder_out, 'last_hidden_state'):
                    features = encoder_out.last_hidden_state
                else:
                    features = encoder_out
                    
                print(f"Features shape: {features.shape if isinstance(features, torch.Tensor) else type(features)}")
                
                # Apply CTC projection
                logits = ctc_proj(features)
                print(f"Logits shape: {logits.shape}")
                
                # Log softmax
                emissions = F.log_softmax(logits.float(), dim=-1)
                print(f"Emissions shape: {emissions.shape}")
                
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("Could not find encoder and/or CTC projection")

Option 2: Encoder + CTC Projection
Found encoder: model.encoder
Found CTC projection: model.final_proj
Error: StandardTransformerEncoder.forward() missing 1 required positional argument: 'seqs_layout'


Traceback (most recent call last):
  File "/tmp/ipython-input-3988646838.py", line 29, in <cell line: 0>
    encoder_out = encoder(waveform)
                  ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: StandardTransformerEncoder.forward() missing 1 required positional argument: 'seqs_layout'


In [16]:
# Option 3: Hook into the model
print("=" * 60)
print("Option 3: Hook-based Extraction")
print("=" * 60)

if model is not None:
    model.eval()
    
    captured = {}
    
    def hook_fn(module, inp, out):
        captured["logits"] = out
    
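    # Find a likely projection layer by type (misses final_proj for the same
    # isinstance reason noted in In [6], hence the result below)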
    proj = None
    proj_name = None
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):
            if any(s in name.lower() for s in ["ctc", "proj", "output", "lm_head", "final"]):
                proj = m
                proj_name = name
                break
    
    if proj is not None:
        print(f"Hooking: {proj_name}")
        h = proj.register_forward_hook(hook_fn)
        
        try:
            with torch.inference_mode():
                _ = model(waveform, lengths)
            
            h.remove()
            
            if "logits" in captured:
                logits = captured["logits"]
                print(f"Captured logits shape: {logits.shape}")
                
                emissions = F.log_softmax(logits.float(), dim=-1)
                print(f"Emissions shape: {emissions.shape}")
            else:
                print("Hook did not capture logits")
                
        except Exception as e:
            h.remove()
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("Could not find a projection layer to hook")

Option 3: Hook-based Extraction
Could not find a projection layer to hook
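Selecting the layer by name sidesteps the class mismatch: `final_proj` is listed under "Model Sub-modules" in In [5]. Below is a minimal sketch on a hypothetical toy model (`ToyCtcModel`) that reuses those child names. It appends every hook output to a list instead of overwriting a single dict entry, so all calls are kept if the layer fires more than once (for example, once per batch). It also removes the hook in a `finally` block.

```python
from typing import Callable, List

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCtcModel(nn.Module):
    """Hypothetical stand-in reusing the child names printed for Wav2Vec2AsrModel."""

    def __init__(self, dim: int = 8, vocab: int = 12) -> None:
        super().__init__()
        self.encoder_frontend = nn.Linear(1, dim)
        self.encoder = nn.Linear(dim, dim)
        self.final_proj = nn.Linear(dim, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.final_proj(torch.relu(self.encoder(self.encoder_frontend(x))))


def capture_outputs(model: nn.Module, module_name: str, run: Callable[[], object]) -> List[torch.Tensor]:
    """Call run() and return every output the named submodule produced, in order."""
    target = dict(model.named_modules())[module_name]  # select by name, not by class
    outputs: List[torch.Tensor] = []

    def hook(module, inputs, output):
        outputs.append(output.detach())  # append, so repeated calls are all kept

    handle = target.register_forward_hook(hook)
    try:
        run()
    finally:
        handle.remove()  # detach the hook even if run() raises
    return outputs


toy_model = ToyCtcModel().eval()
toy_input = torch.randn(2, 50, 1)  # (batch, frames, features)
with torch.no_grad():
    toy_logits = capture_outputs(toy_model, "final_proj", lambda: toy_model(toy_input))
    toy_emissions = F.log_softmax(toy_logits[0].float(), dim=-1)
print(toy_emissions.shape)  # torch.Size([2, 50, 12])
```

On the real pipeline, the same helper would be called as `capture_outputs(pipeline.model, "final_proj", lambda: pipeline.transcribe([audio_path]))`.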


In [17]:
# Option 4: Use extract_features if available
print("=" * 60)
print("Option 4: extract_features() Method")
print("=" * 60)

if model is not None:
    model.eval()
    
    if hasattr(model, 'extract_features'):
        try:
            with torch.inference_mode():
                encoder_out, padding_mask = model.extract_features(waveform, padding_mask=None)
            
            print(f"Encoder output shape: {encoder_out.shape}")
            if padding_mask is not None:
                print(f"Padding mask shape: {padding_mask.shape}")
            
            # Find CTC projection
            ctc_proj = None
            for attr in ['ctc_proj', 'output_projection', 'proj', 'final_proj']:
                if hasattr(model, attr):
                    ctc_proj = getattr(model, attr)
                    print(f"Found CTC projection: model.{attr}")
                    break
            
            if ctc_proj is not None:
                logits = ctc_proj(encoder_out)
                print(f"Logits shape: {logits.shape}")
                
                emissions = F.log_softmax(logits.float(), dim=-1)
                print(f"Emissions shape: {emissions.shape}")
                
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("Model does not have extract_features() method")

Option 4: extract_features() Method
Model does not have extract_features() method


In [18]:
# Option 5: Full forward with output parsing
print("=" * 60)
print("Option 5: Full Forward with Output Parsing")
print("=" * 60)

if model is not None:
    model.eval()
    
    try:
        with torch.inference_mode():
            # Try different forward signatures
            out = None
            
            # Try 1: (source, padding_mask)
            try:
                out = model(source=waveform, padding_mask=None)
                print("Forward signature: (source=, padding_mask=)")
            except Exception:
                pass
            
            # Try 2: (waveform, lengths)
            if out is None:
                try:
                    out = model(waveform, lengths)
                    print("Forward signature: (waveform, lengths)")
                except Exception:
                    pass
            
            # Try 3: just waveform
            if out is None:
                try:
                    out = model(waveform)
                    print("Forward signature: (waveform)")
                except Exception:
                    pass
            
            if out is not None:
                print(f"\nOutput type: {type(out).__name__}")
                
                # Parse output
                if isinstance(out, torch.Tensor):
                    print(f"Direct tensor output: {out.shape}")
                    logits = out
                elif isinstance(out, dict):
                    print(f"Dict output keys: {out.keys()}")
                    for k, v in out.items():
                        if isinstance(v, torch.Tensor):
                            print(f"  {k}: {v.shape}")
                        elif isinstance(v, dict):
                            print(f"  {k}: dict with keys {v.keys()}")
                else:
                    # Named tuple or custom object
                    print(f"Object attributes:")
                    for attr in dir(out):
                        if not attr.startswith('_'):
                            val = getattr(out, attr, None)
                            if isinstance(val, torch.Tensor):
                                print(f"  {attr}: {val.shape}")
                            elif not callable(val):
                                print(f"  {attr}: {type(val).__name__}")
            else:
                print("Could not call model forward")
                
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

Option 5: Full Forward with Output Parsing
Could not call model forward
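"Could not call model forward" hides the reason each attempt failed, because every `except` clause simply passes. A small sketch that tries candidate calls in order but keeps each error message. With it, the `'Tensor' object has no attribute 'packed'` and missing `seqs_layout` messages seen in Options 1 and 2 would surface here too. The helper name is hypothetical.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Attempt = Tuple[str, Callable[[], Any]]


def first_working_call(attempts: List[Attempt]) -> Tuple[Optional[str], Any, Dict[str, str]]:
    """Try each (label, thunk) in order; return the first success and every error seen."""
    errors: Dict[str, str] = {}
    for label, thunk in attempts:
        try:
            return label, thunk(), errors
        except Exception as exc:  # record the reason instead of silently passing
            errors[label] = f"{type(exc).__name__}: {exc}"
    return None, None, errors
```

In the notebook, the candidates would be written as thunks, for example `("waveform, lengths", lambda: model(waveform, lengths))`, and `errors` printed when no label succeeds.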


## Part 4: Test Pipeline Audio Loading

Check if the pipeline has built-in audio loading.

In [19]:
# Check pipeline for audio loading methods
print("=" * 60)
print("Pipeline Audio Loading Methods:")
print("=" * 60)

for attr in dir(pipeline):
    if 'audio' in attr.lower() or 'load' in attr.lower() or 'wave' in attr.lower():
        val = getattr(pipeline, attr, None)
        print(f"  {attr}: {type(val).__name__} {'(callable)' if callable(val) else ''}")

Pipeline Audio Loading Methods:
  _build_audio_wavform_pipeline: method (callable)
  _process_context_audio: method (callable)
  audio_decoder: AudioDecoder (callable)
  collater_audio: Collater (callable)


In [20]:
# Test transcription to verify model works
print("=" * 60)
print("Test Transcription (to verify model works):")
print("=" * 60)

# Generate some test audio or use a real file
try:
    # If you have a test audio file, use it:
    # result = pipeline("/path/to/audio.wav")
    
    # Or test with generated audio (will produce gibberish)
    import torchaudio
    import tempfile
    
    # Create a simple test audio
    test_waveform = torch.randn(1, 16000 * 3)  # 3 seconds
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        torchaudio.save(f.name, test_waveform, 16000)
        temp_path = f.name
    
    # Try to transcribe
    result = pipeline(temp_path)
    print(f"Transcription result: {result}")
    print(f"Result type: {type(result).__name__}")
    
    import os
    os.unlink(temp_path)
    
except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()

Test Transcription (to verify model works):
Error: 'ASRInferencePipeline' object is not callable


Traceback (most recent call last):
  File "/tmp/ipython-input-3959037900.py", line 22, in <cell line: 0>
    result = pipeline(temp_path)
             ^^^^^^^^^^^^^^^^^^^
TypeError: 'ASRInferencePipeline' object is not callable
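The pipeline object is not callable. Its entry point is `transcribe`, which, according to its docstring and signature in In [35], takes a *list* of inputs plus a keyword-only `batch_size` and returns a list of strings. The cell above also leaks its temporary file, because `os.unlink` is skipped once the call raises. The sketch below fixes both issues. The helper names are hypothetical, and the fake pipeline in the test only mimics the `transcribe` signature.

```python
import os
import tempfile
from pathlib import Path
from typing import Any, Callable, Union


def transcribe_one(pipeline: Any, path: Union[str, Path]) -> str:
    """Transcribe a single file; transcribe() takes a list and returns a list."""
    return pipeline.transcribe([str(path)], batch_size=1)[0]


def transcribe_temp_wav(pipeline: Any, write_wav: Callable[[str], None]) -> str:
    """Write audio to a temporary .wav via write_wav(path), transcribe it,
    and delete the file whether or not transcription succeeds."""
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # write_wav reopens the file by name
    try:
        write_wav(tmp_path)
        return transcribe_one(pipeline, tmp_path)
    finally:
        os.unlink(tmp_path)
```

In the notebook, this would read `transcribe_temp_wav(pipeline, lambda p: torchaudio.save(p, test_waveform, 16000))`.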


## Part 5: Summary and Recommendations

Based on the exploration above, summarize what we learned.

In [21]:
print("=" * 60)
print("SUMMARY")
print("=" * 60)

print(f"\n1. Model Name: {MODEL_NAME}")
print(f"2. Model Type: {type(model).__name__ if model else 'Not found'}")
print(f"3. Tokenizer Type: {type(tokenizer).__name__ if tokenizer else 'Not found'}")
print(f"4. Vocab Size: {vocab_size if vocab_size else 'Unknown'}")

print("\n" + "=" * 60)
print("RECOMMENDATIONS FOR BACKEND IMPLEMENTATION:")
print("=" * 60)

print("""
Based on the exploration, update omniasr_backend.py to:

1. TOKENIZER/VOCAB:
   - [Fill in based on what worked above]

2. EMISSION EXTRACTION:
   - [Fill in based on what worked above]

3. AUDIO LOADING:
   - [Fill in based on what worked above]
""")

SUMMARY

1. Model Name: omniASR_CTC_300M
2. Model Type: Wav2Vec2AsrModel
3. Tokenizer Type: RawSentencePieceTokenizer
4. Vocab Size: Unknown

RECOMMENDATIONS FOR BACKEND IMPLEMENTATION:

Based on the exploration, update omniasr_backend.py to:

1. TOKENIZER/VOCAB:
   - [Fill in based on what worked above]

2. EMISSION EXTRACTION:
   - [Fill in based on what worked above]

3. AUDIO LOADING:
   - [Fill in based on what worked above]



In [24]:
# Get Vocabulary from Tokenizer

# The tokenizer has vocab_info attribute (not a method!)
print("=== Vocabulary Info ===")
vocab_info = pipeline.tokenizer.vocab_info
print(f"vocab_info type: {type(vocab_info)}")
print(f"vocab_info: {vocab_info}")
print(f"size: {vocab_info.size}")
print(f"unk_idx: {vocab_info.unk_idx}")
print(f"bos_idx: {vocab_info.bos_idx}")
print(f"eos_idx: {vocab_info.eos_idx}")
print(f"pad_idx: {vocab_info.pad_idx}")

# Try to get actual tokens via encoder
encoder = pipeline.tokenizer.create_encoder()
print(f"\nEncoder type: {type(encoder)}")
print(f"Encoder attrs: {[a for a in dir(encoder) if not a.startswith('_')]}")

# Try to decode some token IDs to see what they are
decoder = pipeline.tokenizer.create_decoder()
print(f"\nDecoder type: {type(decoder)}")
print(f"Decoder attrs: {[a for a in dir(decoder) if not a.startswith('_')]}")

# Test decode some IDs
for i in [0, 1, 2, 3, 4, 5, 10, 100, 1000]:
    try:
        # Different ways to decode
        if callable(decoder):
            text = decoder(torch.tensor([[i]]))
            print(f"ID {i}: {text}")
    except Exception as e:
        print(f"ID {i}: error - {e}")

=== Vocabulary Info ===
vocab_info type: <class 'fairseq2.data.tokenizers.vocab_info.VocabularyInfo'>
vocab_info: VocabularyInfo(size=9812, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1, boh_idx=None, eoh_idx=None)
size: 9812
unk_idx: 3
bos_idx: 0
eos_idx: 2
pad_idx: 1

Encoder type: <class 'fairseq2.data.tokenizers.sentencepiece.SentencePieceEncoder'>
Encoder attrs: ['encode_as_tokens', 'prefix_indices', 'suffix_indices']

Decoder type: <class 'fairseq2.data.tokenizers.sentencepiece.SentencePieceDecoder'>
Decoder attrs: ['decode_from_tokens']
ID 0: error - The input tensor must be one dimensional, but has 2 dimension(s) instead.
ID 1: error - The input tensor must be one dimensional, but has 2 dimension(s) instead.
ID 2: error - The input tensor must be one dimensional, but has 2 dimension(s) instead.
ID 3: error - The input tensor must be one dimensional, but has 2 dimension(s) instead.
ID 4: error - The input tensor must be one dimensional, but has 2 dimension(s) instead.
ID 5: error -
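Each call fails because `torch.tensor([[i]])` is 2-D, while the decoder asks for a one-dimensional tensor. Below is a sketch that decodes IDs one at a time and wraps each ID in a 1-D input built by a caller-supplied factory. For the real `SentencePieceDecoder`, that factory would presumably be `lambda i: torch.tensor([i])`; this is an assumption based on the error text. The fake decoder in the test is hypothetical.

```python
from typing import Any, Callable, Dict, Iterable


def decode_ids(
    decode: Callable[[Any], Any],
    ids: Iterable[int],
    make_input: Callable[[int], Any] = lambda i: [i],
) -> Dict[int, str]:
    """Decode token IDs one at a time, passing each as a 1-D sequence."""
    table: Dict[int, str] = {}
    for i in ids:
        try:
            table[i] = str(decode(make_input(i)))
        except Exception as exc:  # keep going; record why this ID failed
            table[i] = f"<error: {exc}>"
    return table
```

On the real decoder, the call would be `decode_ids(decoder, [0, 1, 2, 3, 4, 5, 10, 100, 1000], make_input=lambda i: torch.tensor([i]))`.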

In [34]:
#### Use Pipeline's Internal Methods

# The pipeline has _apply_model methods - let's check those
print("=== Pipeline _apply_model ===")
import inspect

# Check _apply_model_wav2vec2asr signature
sig = inspect.signature(pipeline._apply_model_wav2vec2asr)
print(f"_apply_model_wav2vec2asr signature: {sig}")

# Check _build_audio_wavform_pipeline
sig = inspect.signature(pipeline._build_audio_wavform_pipeline)
print(f"_build_audio_wavform_pipeline signature: {sig}")

#### Use the collater to prepare proper input

# The pipeline has collaters that prepare data in the right format
print("=== Using Pipeline Collaters ===")

# First load audio properly
import torchaudio
import tempfile

# Create test audio file
test_waveform = torch.randn(1, 16000 * 3)  # 3 seconds
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
    torchaudio.save(f.name, test_waveform, 16000)
    temp_path = f.name

# Check audio_decoder
print(f"audio_decoder type: {type(pipeline.audio_decoder)}")
print(f"audio_decoder attrs: {[a for a in dir(pipeline.audio_decoder) if not a.startswith('_')]}")

# Check collater_audio
print(f"\ncollater_audio type: {type(pipeline.collater_audio)}")

# Try file_mapper
print(f"\nfile_mapper type: {type(pipeline.file_mapper)}")

=== Pipeline _apply_model ===
_apply_model_wav2vec2asr signature: (batch: 'Seq2SeqBatch') -> 'List[str]'
_build_audio_wavform_pipeline signature: (inp_list: 'AudioInput') -> 'DataPipelineBuilder'
=== Using Pipeline Collaters ===
audio_decoder type: <class 'fairseq2n.bindings.data.audio.AudioDecoder'>
audio_decoder attrs: []

collater_audio type: <class 'fairseq2n.bindings.data.data_pipeline.Collater'>

file_mapper type: <class 'fairseq2n.bindings.data.data_pipeline.FileMapper'>


In [35]:
#### Trace through transcribe to understand flow

# Look at transcribe source to understand input format
import inspect
print("=== transcribe method source ===")
try:
    source = inspect.getsource(pipeline.transcribe)
    print(source[:2000])  # First 2000 chars
except Exception:
    print("Could not get source")

=== transcribe method source ===
    @torch.inference_mode()
    def transcribe(
        self,
        inp: AudioInput,
        *,
        lang: List[str | None] | List[str] | List[None] | None = None,
        batch_size: int = 2,
    ) -> List[str]:
        """
        Transcribes `AudioInput` into text by preprocessing (decoding, resample to 16kHz, converting to mono, normalizing)
        each input sample and performing inference with `self.model`.

        Works for both CTC and LLM model variants by optionally allowing a language conditioning token to help with LLM generation.
        It is ignored when performing inference with CTC. See `omnilingual_asr/models/wav2vec2_llama/lang_ids.py` for supported languages.

        Args:
            `inp`: Audio input in different forms.
                - `List[ Path | str ]`: Audio file paths
                - `List[ bytes ]`: Raw audio data
                - `List[ np.ndarray ]`: Audio data as uint8 numpy array
                - `List[ di
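The docstring lists the preprocessing `transcribe` applies before the model: decoding, resampling to 16 kHz, converting to mono, and normalizing. Feeding `pipeline.model` directly would mean reproducing these steps. The sketch below covers only the mono downmix and an *assumed* zero-mean, unit-variance normalization. The pipeline's exact normalization is not shown in the truncated source, so it should be checked there. Resampling is left to a library routine such as `torchaudio.functional.resample`.

```python
import math
from typing import List, Sequence


def to_mono(channels: Sequence[Sequence[float]]) -> List[float]:
    """Average channels sample by sample (channels x samples -> samples)."""
    n = len(channels[0])
    if any(len(c) != n for c in channels):
        raise ValueError("all channels must have the same length")
    return [sum(c[t] for c in channels) / len(channels) for t in range(n)]


def normalize(samples: Sequence[float], eps: float = 1e-7) -> List[float]:
    """Zero-mean, unit-variance scaling (assumed meaning of 'normalizing')."""
    mean = sum(samples) / len(samples)
    var = sum((x - mean) ** 2 for x in samples) / len(samples)
    scale = 1.0 / math.sqrt(var + eps)  # eps guards against silent (constant) input
    return [(x - mean) * scale for x in samples]
```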

In [36]:
#### Get SequenceLayout from fairseq2

# Check what SequenceLayout looks like
print("=== Understanding SequenceLayout ===")
from fairseq2.data import SequenceBatch

# Check if we can create proper input
print(f"SequenceBatch attrs: {[a for a in dir(SequenceBatch) if not a.startswith('_')]}")

# Try to find how pipeline creates batches
print(f"\n_create_batch_simple signature: {inspect.signature(pipeline._create_batch_simple)}")

=== Understanding SequenceLayout ===


ImportError: cannot import name 'SequenceBatch' from 'fairseq2.data' (/usr/local/lib/python3.12/dist-packages/fairseq2/data/__init__.py)

In [37]:
#### Direct approach - intercept at the right level

# The cleanest approach: use fairseq2's batch utilities
print("=== Create proper batch ===")

try:
    from fairseq2.data import SequenceBatch
    from fairseq2.nn.padding import PaddingMask

    # Load audio the same way pipeline does
    # Check what audio_decoder returns
    audio_data = pipeline.audio_decoder(temp_path)
    print(f"audio_decoder output type: {type(audio_data)}")
    print(f"audio_decoder output: {audio_data}")

except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()

=== Create proper batch ===
Error: cannot import name 'SequenceBatch' from 'fairseq2.data' (/usr/local/lib/python3.12/dist-packages/fairseq2/data/__init__.py)


Traceback (most recent call last):
  File "/tmp/ipython-input-4276800488.py", line 7, in <cell line: 0>
    from fairseq2.data import SequenceBatch
ImportError: cannot import name 'SequenceBatch' from 'fairseq2.data' (/usr/local/lib/python3.12/dist-packages/fairseq2/data/__init__.py)
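Guessing import paths one cell at a time is slow. A small helper can instead probe a list of candidate modules for a symbol and report where it lives, if anywhere. This is a sketch that uses only `importlib`. `find_symbol` is a hypothetical helper, and the fairseq2 module paths in the usage comment are only the guesses tried elsewhere in this notebook, not known locations.

```python
import importlib


def find_symbol(name, candidate_modules):
    """Return (module_path, obj) for the first module that exports `name`, else None.

    `find_symbol` is a hypothetical exploration helper, not part of fairseq2.
    """
    for module_path in candidate_modules:
        try:
            module = importlib.import_module(module_path)
        except ImportError:  # also covers ModuleNotFoundError
            continue
        if hasattr(module, name):
            return module_path, getattr(module, name)
    return None


# Demo with the standard library; the first path does not exist and is skipped.
print(find_symbol("OrderedDict", ["no_such_module_xyz", "collections"])[0])  # collections

# Usage against fairseq2 (candidate paths are guesses):
# find_symbol("SequenceBatch", ["fairseq2.data", "fairseq2.nn", "fairseq2.models.sequence"])
```

Swallowing `ImportError` also hides modules that exist but fail to import for other reasons (for example a missing optional dependency), so a `None` result means "not found by this probe" rather than proof that the symbol is absent.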


In [39]:
#### Hook into final_proj during actual transcription

# Hook during actual transcription call
print("=== Hook during transcription ===")

import torch.nn.functional as F

audio_path = "/content/torchaudio_aligner/examples/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"

captured = {}

def hook_fn(module, inp, out):
    captured["input"] = inp
    captured["output"] = out
    print(f"Hook fired! Input shape: {inp[0].shape if isinstance(inp[0], torch.Tensor) else type(inp[0])}")
    print(f"Output shape: {out.shape}")

# Register hook on final_proj
h = pipeline.model.final_proj.register_forward_hook(hook_fn)

try:
    # Run transcription - NOTE: pass as a LIST
    result = pipeline.transcribe([audio_path])
    print(f"\nTranscription: {result}")
finally:
    h.remove()

# Now we have the logits!
if "output" in captured:
    logits = captured["output"]
    print(f"\nCaptured logits shape: {logits.shape}")
    emissions = F.log_softmax(logits.float(), dim=-1)
    print(f"Emissions shape: {emissions.shape}")
    print(f"Vocab size: {emissions.shape[-1]}")

    # Verify it matches tokenizer vocab
    print(f"Tokenizer vocab size: {pipeline.tokenizer.vocab_info.size}")

=== Hook during transcription ===
Hook fired! Input shape: torch.Size([1, 169, 1024])
Output shape: torch.Size([1, 169, 9812])

Transcription: ['i had that curiosity beside me at this moment']

Captured logits shape: torch.Size([1, 169, 9812])
Emissions shape: torch.Size([1, 169, 9812])
Vocab size: 9812
Tokenizer vocab size: 9812
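The hook above keeps only the most recent call in `captured`. With several files, `transcribe` runs one forward pass per batch (`batch_size=2` by default, per the signature printed earlier), so earlier batches get overwritten. A context manager that appends every call and always removes the hook is a little more robust. This is a minimal sketch: `capture_output` and `ToyCTCModel` are hypothetical names, and the toy model only stands in for `pipeline.model` because both expose a `final_proj` submodule.

```python
import contextlib

import torch
from torch import nn


@contextlib.contextmanager
def capture_output(module):
    """Record the output of `module` for every forward call made inside the block."""
    outputs = []

    def hook(_module, _inputs, output):
        # Detach so stored tensors never keep an autograd graph alive.
        outputs.append(output.detach())

    handle = module.register_forward_hook(hook)
    try:
        yield outputs
    finally:
        handle.remove()  # unregister even if the forward pass raises


class ToyCTCModel(nn.Module):
    """Hypothetical stand-in: an encoder followed by a `final_proj` to vocab logits."""

    def __init__(self, dim=8, vocab_size=5):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)
        self.final_proj = nn.Linear(dim, vocab_size)

    def forward(self, x):
        return self.final_proj(self.encoder(x))


toy_model = ToyCTCModel()
with capture_output(toy_model.final_proj) as outs:
    with torch.inference_mode():
        toy_model(torch.randn(1, 10, 8))

toy_emissions = torch.log_softmax(outs[0].float(), dim=-1)
print(toy_emissions.shape)  # torch.Size([1, 10, 5])
```

With the real pipeline the same pattern would read `with capture_output(pipeline.model.final_proj) as outs: pipeline.transcribe([audio_path])`, giving one logits tensor per batch. Padded frames in a multi-file batch would still need trimming to each sample's true length.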


In [41]:
# Get vocabulary tokens from tokenizer
print("=== Extracting Vocabulary Tokens ===")

vocab_info = pipeline.tokenizer.vocab_info
vocab_size = vocab_info.size
print(f"Vocab size: {vocab_size}")
print(f"Special tokens: unk={vocab_info.unk_idx}, bos={vocab_info.bos_idx}, eos={vocab_info.eos_idx}, pad={vocab_info.pad_idx}")

# Create decoder to convert IDs to tokens
decoder = pipeline.tokenizer.create_decoder()
print(f"\nDecoder type: {type(decoder)}")

# Try to decode individual token IDs
print("\n=== Decoding sample token IDs ===")
labels = []
for i in range(vocab_size):
    try:
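        # First attempt: wrap the ID as a batch of one sequence (2-D).
        # The output below shows the decoder rejects this; it expects a 1-D tensor.
        token_tensor = torch.tensor([[i]])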
        decoded = decoder(token_tensor)
        if isinstance(decoded, list):
            token = decoded[0] if decoded else ""
        else:
            token = str(decoded)
        labels.append(token)

        # Print first 20 and some samples
        if i < 20 or i in [100, 500, 1000, 5000, 9000]:
            print(f"  ID {i}: '{token}'")
    except Exception as e:
        labels.append(f"<{i}>")
        if i < 20:
            print(f"  ID {i}: ERROR - {e}")

print(f"\nTotal labels extracted: {len(labels)}")
print(f"Sample labels [0:10]: {labels[0:10]}")
print(f"Sample labels [100:110]: {labels[100:110]}")

# Alternative: Try raw encoder to understand token structure
print("=== Testing Encoder ===")

encoder = pipeline.tokenizer.create_encoder()
print(f"Encoder type: {type(encoder)}")
print(f"Encoder attrs: {[a for a in dir(encoder) if not a.startswith('_')]}")

# Encode some text
test_texts = ["hello", "i had that curiosity", "test"]
for text in test_texts:
    try:
        encoded = encoder(text)
        print(f"'{text}' -> {encoded}")
    except Exception as e:
        print(f"'{text}' -> ERROR: {e}")

# Check if there's a way to get the raw sentencepiece model
print("=== Raw SentencePiece Access ===")

tokenizer = pipeline.tokenizer
print(f"Tokenizer attrs: {[a for a in dir(tokenizer) if not a.startswith('_')]}")

# Check for model or vocab access
for attr in ['model', 'sp_model', 'spm', 'vocab', 'get_vocab', 'id_to_token', 'token_to_id']:
    if hasattr(tokenizer, attr):
        val = getattr(tokenizer, attr)
        print(f"  {attr}: {type(val)}")
        if callable(val):
            try:
                result = val()
                print(f"    -> {type(result)}")
            except Exception:
                pass  # some of these need arguments; skip them

# Try to access underlying sentencepiece processor
print("=== Underlying SentencePiece ===")

# RawSentencePieceTokenizer might have _processor or similar
for attr in dir(tokenizer):
    if 'piece' in attr.lower() or 'sp' in attr.lower() or 'processor' in attr.lower() or 'model' in attr.lower():
        print(f"  {attr}: {type(getattr(tokenizer, attr, None))}")

# Build the complete vocabulary by decoding one ID at a time.
# NOTE: the decoder takes a 1-D tensor of IDs; a 2-D tensor raises an error.
print("=== Building Complete Vocabulary ===")

vocab_size = pipeline.tokenizer.vocab_info.size
decoder = pipeline.tokenizer.create_decoder()

# Decode all tokens
all_labels = []
failed = 0

for i in range(vocab_size):
    try:
        token_tensor = torch.tensor([i])  # 1-D, not 2-D
        decoded = decoder(token_tensor)
        if isinstance(decoded, list):
            token = decoded[0] if decoded else ""
        else:
            token = str(decoded)
        all_labels.append(token)
    except Exception:
        all_labels.append("")
        failed += 1

print(f"Successfully decoded: {vocab_size - failed}/{vocab_size}")
print(f"\nFirst 50 labels: {all_labels[:50]}")

# Find blank token (usually index 0 for CTC)
print(f"\nLabel at index 0 (likely blank): '{all_labels[0]}'")
print(f"Label at index 1 (pad): '{all_labels[1]}'")
print(f"Label at index 2 (eos): '{all_labels[2]}'")
print(f"Label at index 3 (unk): '{all_labels[3]}'")

=== Extracting Vocabulary Tokens ===
Vocab size: 9812
Special tokens: unk=3, bos=0, eos=2, pad=1

Decoder type: <class 'fairseq2.data.tokenizers.sentencepiece.SentencePieceDecoder'>

=== Decoding sample token IDs ===
  ID 0: ERROR - The input tensor must be one dimensional, but has 2 dimension(s) instead.
  ID 1: ERROR - The input tensor must be one dimensional, but has 2 dimension(s) instead.
  ID 2: ERROR - The input tensor must be one dimensional, but has 2 dimension(s) instead.
  ID 3: ERROR - The input tensor must be one dimensional, but has 2 dimension(s) instead.
  ID 4: ERROR - The input tensor must be one dimensional, but has 2 dimension(s) instead.
  ID 5: ERROR - The input tensor must be one dimensional, but has 2 dimension(s) instead.
  ID 6: ERROR - The input tensor must be one dimensional, but has 2 dimension(s) instead.
  ID 7: ERROR - The input tensor must be one dimensional, but has 2 dimension(s) instead.
  ID 8: ERROR - The input tensor must be one dimensional, but h

In [43]:
# Fix: Decoder expects 1D tensor
print("=== Decoding with 1D tensor ===")

decoder = pipeline.tokenizer.create_decoder()

# Test with 1D tensor
for i in [0, 1, 2, 3, 4, 5, 10, 100, 500, 1000]:
    try:
        token_tensor = torch.tensor([i])  # 1D, not 2D
        decoded = decoder(token_tensor)
        print(f"  ID {i}: '{decoded}'")
    except Exception as e:
        print(f"  ID {i}: ERROR - {e}")

# Access the underlying SentencePiece model directly
print("=== Access SentencePieceModel ===")

sp_model = pipeline.tokenizer._model
print(f"sp_model type: {type(sp_model)}")
print(f"sp_model attrs: {[a for a in dir(sp_model) if not a.startswith('_')]}")

# Try common sentencepiece methods
for method in ['id_to_piece', 'IdToPiece', 'decode_from_ids', 'index_to_token', 'get_piece']:
    if hasattr(sp_model, method):
        print(f"  Found method: {method}")
        try:
            result = getattr(sp_model, method)(0)
            print(f"    sp_model.{method}(0) = '{result}'")
        except Exception as e:
            print(f"    Error: {e}")

# Try encode_as_tokens to understand the token format
print("=== Using encode_as_tokens ===")

encoder = pipeline.tokenizer.create_encoder()

test_texts = ["hello", "a", "i", " ", "test"]
for text in test_texts:
    try:
        # encode_as_tokens might return actual token strings
        tokens = encoder.encode_as_tokens(text)
        ids = encoder(text)
        print(f"'{text}' -> tokens: {tokens}, ids: {ids.tolist()}")
    except Exception as e:
        print(f"'{text}' -> ERROR: {e}")

# Try to decode a sequence (what we got from transcription)
print("=== Decode a sequence ===")

decoder = pipeline.tokenizer.create_decoder()

# The encoder output for "hello" was [113, 9346, 1875, 1875, 8749]
test_sequence = torch.tensor([113, 9346, 1875, 1875, 8749])
try:
    decoded = decoder(test_sequence)
    print(f"Decoded 'hello' tokens: '{decoded}'")
except Exception as e:
    print(f"Error: {e}")

# Try other sequences
test_sequence2 = torch.tensor([4328, 4, 113, 9499])  # 'i', ' ', 'h', 'a': start of "i had that curiosity"
try:
    decoded = decoder(test_sequence2)
    print(f"Decoded partial sequence: '{decoded}'")
except Exception as e:
    print(f"Error: {e}")

# Check if SentencePieceModel has index_to_token or similar
print("=== SentencePieceModel methods ===")

sp_model = pipeline.tokenizer._model

# List all methods
for attr in sorted(dir(sp_model)):
    if not attr.startswith('_'):
        val = getattr(sp_model, attr)
        if callable(val):
            print(f"  {attr}()")
        else:
            print(f"  {attr} = {type(val).__name__}")

# Try to use the model's token_decoder (from pipeline attributes)
print("=== Using pipeline.token_decoder ===")

token_decoder = pipeline.token_decoder
print(f"token_decoder type: {type(token_decoder)}")
print(f"token_decoder attrs: {[a for a in dir(token_decoder) if not a.startswith('_')]}")

# Try to decode
try:
    test_ids = torch.tensor([113, 9346, 1875, 1875, 8749])
    result = token_decoder(test_ids)
    print(f"Decoded: {result}")
except Exception as e:
    print(f"Error: {e}")

=== Decoding with 1D tensor ===
  ID 0: ''
  ID 1: ''
  ID 2: ''
  ID 3: ' ⁇ '
  ID 4: ' '
  ID 5: 'ዘ'
  ID 10: '祭'
  ID 100: 'ු'
  ID 500: '仗'
  ID 1000: '깨'
=== Access SentencePieceModel ===
sp_model type: <class 'fairseq2.data.tokenizers.sentencepiece.SentencePieceModel'>
sp_model attrs: ['bos_idx', 'eos_idx', 'index_to_token', 'pad_idx', 'token_to_index', 'unk_idx', 'vocabulary_size']
  Found method: index_to_token
    sp_model.index_to_token(0) = '<s>'
=== Using encode_as_tokens ===
'hello' -> tokens: ['h', 'e', 'l', 'l', 'o'], ids: [113, 9346, 1875, 1875, 8749]
'a' -> tokens: ['a'], ids: [9499]
'i' -> tokens: ['i'], ids: [4328]
' ' -> tokens: [], ids: []
'test' -> tokens: ['t', 'e', 's', 't'], ids: [2226, 9346, 7076, 2226]
=== Decode a sequence ===
Decoded 'hello' tokens: 'hello'
Decoded partial sequence: 'i ha'
=== SentencePieceModel methods ===
  bos_idx = int
  eos_idx = int
  index_to_token()
  pad_idx = int
  token_to_index()
  unk_idx = int
  vocabulary_size = int
=== Using
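The `encode_as_tokens` output above suggests that, at least for these Latin-script examples, the tokenizer works character by character ('hello' becomes 'h', 'e', 'l', 'l', 'o'). An aligner usually needs the same character-to-index mapping directly from the labels list. The sketch below assumes one label per character. `build_label_to_id` and `text_to_ids` are hypothetical helpers, and the toy vocabulary only imitates the layout of the real one (specials at 0 to 3, space at 4). Note that the real encoder returned no IDs for a lone `' '`, so whitespace at the edges of a string may be handled differently than by this per-character lookup.

```python
def build_label_to_id(labels):
    """Map each label string to its index; if a label repeats, keep the first index."""
    label_to_id = {}
    for idx, label in enumerate(labels):
        label_to_id.setdefault(label, idx)
    return label_to_id


def text_to_ids(text, label_to_id, unk_id):
    """Map each character of `text` to its vocabulary index, falling back to `unk_id`."""
    return [label_to_id.get(ch, unk_id) for ch in text]


# Toy vocabulary laid out like the real one: specials first, then ' ' at index 4.
toy_labels = ["<s>", "<pad>", "</s>", "<unk>", " ", "h", "e", "l", "o"]
label_to_id = build_label_to_id(toy_labels)
print(text_to_ids("hello", label_to_id, unk_id=3))  # [5, 6, 7, 7, 8]
print(text_to_ids("hi", label_to_id, unk_id=3))     # [5, 3]
```

With the real vocabulary, `toy_labels` would be replaced by the `labels` list built from `sp_model.index_to_token`, and `unk_id` by `sp_model.unk_idx`.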

In [44]:
# Build complete vocabulary using sp_model.index_to_token
print("=== Building Complete Vocabulary ===")

sp_model = pipeline.tokenizer._model
vocab_size = sp_model.vocabulary_size

print(f"Vocab size: {vocab_size}")
print(f"Special indices: bos={sp_model.bos_idx}, eos={sp_model.eos_idx}, pad={sp_model.pad_idx}, unk={sp_model.unk_idx}")

# Build labels list
labels = []
for i in range(vocab_size):
    token = sp_model.index_to_token(i)
    labels.append(token)

print(f"\nFirst 30 tokens:")
for i in range(30):
    print(f"  {i}: '{labels[i]}'")

print(f"\nSample tokens:")
for i in [100, 500, 1000, 5000, 9000, 9811]:
    print(f"  {i}: '{labels[i]}'")

# Check common characters
print(f"\nCommon characters:")
for char in ['a', 'b', 'c', 'h', 'e', 'l', 'o', ' ']:
    try:
        idx = sp_model.token_to_index(char)
        print(f"  '{char}' -> {idx}")
    except Exception:
        print(f"  '{char}' -> NOT FOUND")

# Identify blank token for CTC
print("=== Identifying CTC Blank Token ===")

# In CTC, blank is often:
# - Index 0
# - A special <blank> or <ctc> token
# - The pad token

# Check what's at common blank positions
for i in [0, 1, 2, 3, 4]:
    print(f"  Index {i}: '{labels[i]}'")

# For OmniASR CTC, blank is likely index 0 (<s>/bos) or a dedicated blank
# Let's check by looking at emission argmax patterns
print("\n=== Check emissions for blank patterns ===")

# Get emissions from our earlier capture
if "output" in captured:
    emissions = F.log_softmax(captured["output"].float(), dim=-1)

    # Argmax prediction
    argmax_ids = emissions.argmax(dim=-1)[0]  # [T]
    print(f"Argmax shape: {argmax_ids.shape}")
    print(f"Argmax first 50: {argmax_ids[:50].tolist()}")

    # Count most common IDs (blank should be most frequent)
    from collections import Counter
    counts = Counter(argmax_ids.tolist())
    print(f"\nMost common IDs (likely blank is most frequent):")
    for idx, count in counts.most_common(10):
        print(f"  ID {idx} ('{labels[idx]}'): {count} times")

# Summary: Build VocabInfo for our backend
print("=== VocabInfo Summary for Backend ===")

sp_model = pipeline.tokenizer._model

# Determine blank_id: in greedy CTC output the most frequent argmax ID is usually
# the blank. The argmax check above is dominated by index 0 ('<s>', i.e. bos_idx),
# and fairseq-style wav2vec2 CTC models conventionally reuse the bos symbol as blank.
blank_id = 0

vocab_info_dict = {
    "vocab_size": sp_model.vocabulary_size,
    "blank_id": blank_id,
    "blank_token": sp_model.index_to_token(blank_id),
    "unk_id": sp_model.unk_idx,
    "unk_token": sp_model.index_to_token(sp_model.unk_idx),
    "bos_id": sp_model.bos_idx,
    "eos_id": sp_model.eos_idx,
    "pad_id": sp_model.pad_idx,
}

print(f"VocabInfo:")
for k, v in vocab_info_dict.items():
    print(f"  {k}: {v}")

print(f"\nLabels list: {len(labels)} tokens")
print(f"  First 10: {labels[:10]}")

=== Building Complete Vocabulary ===
Vocab size: 9812
Special indices: bos=0, eos=2, pad=1, unk=3

First 30 tokens:
  0: '<s>'
  1: '<pad>'
  2: '</s>'
  3: '<unk>'
  4: ' '
  5: 'ዘ'
  6: 'љ'
  7: '耷'
  8: '氰'
  9: '賬'
  10: '祭'
  11: '繳'
  12: '渭'
  13: '穗'
  14: '捡'
  15: '栩'
  16: '勸'
  17: '诚'
  18: '戮'
  19: '尉'
  20: '們'
  21: '邮'
  22: '佈'
  23: 'ޯ'
  24: '刻'
  25: '겔'
  26: '纪'
  27: '빠'
  28: '吠'
  29: '螺'

Sample tokens:
  100: 'ු'
  500: '仗'
  1000: '깨'
  5000: '抑'
  9000: '닐'
  9811: '곽'

Common characters:
  'a' -> 9499
  'b' -> 7565
  'c' -> 5943
  'h' -> 113
  'e' -> 9346
  'l' -> 1875
  'o' -> 8749
  ' ' -> 4
=== Identifying CTC Blank Token ===
  Index 0: '<s>'
  Index 1: '<pad>'
  Index 2: '</s>'
  Index 3: '<unk>'
  Index 4: ' '

=== Check emissions for blank patterns ===
Argmax shape: torch.Size([169])
Argmax first 50: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4328, 0, 0, 4, 4, 113, 0, 9499, 0, 1133, 0, 0, 4, 4, 222
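The frequency check above can be wrapped into a helper that also reports how dominant the winner is, since a clear majority is a stronger signal than a narrow one. This is a heuristic sketch, not a guarantee: unusually speech-dense audio could in principle let a character outrank the blank, so it is worth confirming with a greedy decode, as the later cells do. `guess_blank_id` is a hypothetical name, and the input below is a toy sequence shaped like the recorded argmax output (a long run of 0s, then sparse character IDs), not the exact recorded values.

```python
from collections import Counter


def guess_blank_id(argmax_ids):
    """Return (most frequent ID, fraction of frames it wins) from a greedy CTC path.

    Heuristic only: the blank usually dominates CTC argmax paths.
    """
    if not argmax_ids:
        raise ValueError("need at least one frame")
    counts = Counter(argmax_ids)
    top_id, count = counts.most_common(1)[0]
    return top_id, count / len(argmax_ids)


# Toy sequence: a run of blanks, then sparse character IDs separated by blanks.
toy_argmax = [0] * 20 + [4328, 0, 0, 4, 4, 113, 0, 9499, 0, 1133, 0, 0, 4, 4]
guessed_id, share = guess_blank_id(toy_argmax)
print(guessed_id, round(share, 2))  # 0 0.76
```

On the real emissions this would be called as `guess_blank_id(argmax_ids.tolist())`.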

In [46]:
# Explore OmniASR preprocessing pipeline
print("=== Exploring OmniASR Preprocessing ===")

# Check what _build_audio_wavform_pipeline does
import inspect
print("_build_audio_wavform_pipeline signature:")
print(inspect.signature(pipeline._build_audio_wavform_pipeline))

# Check audio_decoder
print(f"\naudio_decoder type: {type(pipeline.audio_decoder)}")
print(f"audio_decoder attrs: {[a for a in dir(pipeline.audio_decoder) if not a.startswith('_')]}")

# Try to understand how pipeline processes audio internally
print("=== Check collater_audio ===")

collater = pipeline.collater_audio
print(f"collater_audio type: {type(collater)}")
print(f"collater_audio attrs: {[a for a in dir(collater) if not a.startswith('_')]}")

# Check full_collater
print(f"\nfull_collater type: {type(pipeline.full_collater)}")

# Look at _apply_model_wav2vec2asr - this is where the magic happens
print("=== _apply_model_wav2vec2asr source ===")
try:
    source = inspect.getsource(pipeline._apply_model_wav2vec2asr)
    print(source[:3000])
except Exception as e:
    print(f"Error: {e}")

# Check what fairseq2 SequenceBatch looks like (different import path)
print("=== Finding SequenceBatch ===")

# Try different import paths
import_attempts = [
    "from fairseq2.nn.padding import PaddingMask",
    "from fairseq2.data import Collater",
    "from fairseq2.nn import SequenceBatch",
    "from fairseq2.models.sequence import SequenceBatch",
]

for imp in import_attempts:
    try:
        exec(imp)
        print(f"SUCCESS: {imp}")
    except ImportError as e:
        print(f"FAILED: {imp} -> {e}")

# Check model's expected input format by looking at forward signature
print("=== Model forward signature ===")

model = pipeline.model
print(f"Model type: {type(model)}")

# Check forward method (inspect is already imported at the top of this cell)
try:
    sig = inspect.signature(model.forward)
    print(f"forward signature: {sig}")
except (TypeError, ValueError) as e:
    print(f"forward signature unavailable: {e}")

# Check what encoder_frontend expects
print(f"\nencoder_frontend type: {type(model.encoder_frontend)}")
if hasattr(model.encoder_frontend, 'forward'):
    try:
        sig = inspect.signature(model.encoder_frontend.forward)
        print(f"encoder_frontend.forward signature: {sig}")
    except (TypeError, ValueError) as e:
        print(f"encoder_frontend.forward signature unavailable: {e}")

# Look at fairseq2's wav2vec2 model source to understand input format
print("=== Wav2Vec2AsrModel forward source ===")

from fairseq2.models.wav2vec2.asr.model import Wav2Vec2AsrModel
try:
    source = inspect.getsource(Wav2Vec2AsrModel.forward)
    print(source[:2000])
except Exception as e:
    print(f"Error: {e}")

# Check if there's a way to create the proper input batch
print("=== Creating proper input ===")

# Look for SequenceBatch or similar in fairseq2
import fairseq2
print(f"fairseq2 version: {fairseq2.__version__}")

# Check nn module
from fairseq2 import nn as fs2_nn
print(f"fairseq2.nn contents: {[a for a in dir(fs2_nn) if not a.startswith('_')]}")

# Try to find how to create a SequenceBatch from raw waveform
print("=== Exploring fairseq2.data ===")

from fairseq2 import data as fs2_data
print(f"fairseq2.data contents: {[a for a in dir(fs2_data) if not a.startswith('_')]}")

# Check for collate utilities
for name in dir(fs2_data):
    if 'collat' in name.lower() or 'batch' in name.lower() or 'sequence' in name.lower():
        print(f"  {name}: {type(getattr(fs2_data, name))}")

=== Exploring OmniASR Preprocessing ===
_build_audio_wavform_pipeline signature:
(inp_list: 'AudioInput') -> 'DataPipelineBuilder'

audio_decoder type: <class 'fairseq2n.bindings.data.audio.AudioDecoder'>
audio_decoder attrs: []
=== Check collater_audio ===
collater_audio type: <class 'fairseq2n.bindings.data.data_pipeline.Collater'>
collater_audio attrs: []

full_collater type: <class 'fairseq2n.bindings.data.data_pipeline.Collater'>
=== _apply_model_wav2vec2asr source ===
    def _apply_model_wav2vec2asr(self, batch: Seq2SeqBatch) -> List[str]:
        batch_layout = BatchLayout(
            batch.source_seqs.shape,
            seq_lens=batch.source_seq_lens,
            device=batch.source_seqs.device,
        )

        logits, bl_out = self.model(batch.source_seqs, batch_layout)
        pred_ids = torch.argmax(logits, dim=-1)
        transcriptions = []

        for i in range(pred_ids.shape[0]):
            # Create a mask for where consecutive elements differ (CTC decoding)
    
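The `_apply_model_wav2vec2asr` source above (cut off in the output) takes the argmax over the logits and then builds a mask where consecutive elements differ: standard greedy CTC decoding. The rule is to merge runs of the same ID first and only then drop blanks. The order matters, because a blank between two identical IDs is what lets a doubled letter such as the 'll' in 'hello' survive. A pure-Python sketch with a hypothetical `ctc_greedy_collapse`, assuming blank index 0 as in the vocabulary analysis above; the IDs are the ones the encoder produced for 'hello'.

```python
def ctc_greedy_collapse(frame_ids, blank_id=0):
    """Greedy CTC decoding: keep the first ID of each run, then drop blanks."""
    decoded = []
    prev = None
    for idx in frame_ids:
        # A new run starts whenever the ID changes; blanks never reach the output.
        if idx != prev and idx != blank_id:
            decoded.append(idx)
        prev = idx
    return decoded


# IDs from this notebook: h=113, e=9346, l=1875, o=8749, blank=0.
frames = [0, 113, 113, 0, 9346, 1875, 1875, 0, 1875, 8749, 8749, 0]
print(ctc_greedy_collapse(frames))  # [113, 9346, 1875, 1875, 8749]
```

Without the blank at position 7 the two 'l' runs would merge into one and the decode would read 'helo'.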

In [None]:
# Explore model structure patterns
print("=== Compare model structures ===")

# OmniASR (fairseq2)
omni_model = pipeline.model
print("OmniASR structure:")
for name, child in omni_model.named_children():
    print(f"  {name}: {type(child).__name__}")

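# Optional: compare with a HuggingFace Wav2Vec2ForCTC checkpoint.
# Check that this repo ID exists on the Hugging Face Hub before relying on it.
from transformers import Wav2Vec2ForCTC
hf_model = Wav2Vec2ForCTC.from_pretrained("facebook/omniASR_CTC_300M")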
print("\nHuggingFace structure:")
for name, child in hf_model.named_children():
    print(f"  {name}: {type(child).__name__}")

=== Compare model structures ===
OmniASR structure:
  encoder_frontend: Wav2Vec2Frontend
  encoder: StandardTransformerEncoder
  final_proj: Linear


In [53]:
# Direct model call with BatchLayout
print("=== Direct Model Call with BatchLayout ===")

from fairseq2.nn import BatchLayout
import torchaudio

# Convert device string to torch.device
device_obj = torch.device(device)

# Check model dtype
model = pipeline.model
model_dtype = next(model.parameters()).dtype
print(f"Model dtype: {model_dtype}")

# Load test audio
audio_path = "/content/torchaudio_aligner/examples/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
waveform, sr = torchaudio.load(audio_path)

# Resample to 16kHz if needed
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)

# Convert to mono if stereo
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)

# Prepare input: (batch, samples) - MATCH MODEL DTYPE
waveform = waveform.squeeze(0)  # Remove channel dim -> (samples,)
waveform_batch = waveform.unsqueeze(0).to(device_obj, dtype=model_dtype)  # Match dtype!

print(f"Waveform shape: {waveform_batch.shape}, dtype: {waveform_batch.dtype}")

# Create BatchLayout
seq_lens = torch.tensor([waveform_batch.shape[1]], device=device_obj)
batch_layout = BatchLayout(
    waveform_batch.shape,
    seq_lens=seq_lens,
    device=device_obj,
)

print(f"BatchLayout: {batch_layout}")

# Call model directly
print("=== Calling model.forward() directly ===")

model.eval()

with torch.inference_mode():
    logits, output_layout = model(waveform_batch, batch_layout)

print(f"Logits shape: {logits.shape}")
print(f"Output layout: {output_layout}")

# Convert to emissions (use float32 for numerical stability)
emissions = F.log_softmax(logits.float(), dim=-1)
print(f"Emissions shape: {emissions.shape}")

# Verify by greedy decoding
pred_ids = torch.argmax(logits, dim=-1)[0]

# output_layout.seq_lens holds ints here, while the input layout held 0-d tensors; handle both
seq_len = output_layout.seq_lens[0]
if hasattr(seq_len, 'item'):
    seq_len = seq_len.item()
print(f"Sequence length: {seq_len}")

pred_ids = pred_ids[:seq_len]

# CTC collapse
mask = torch.ones(pred_ids.shape[0], dtype=torch.bool, device=device_obj)
mask[1:] = pred_ids[1:] != pred_ids[:-1]
decoded_ids = pred_ids[mask]

# Remove blank (index 0)
blank_id = 0
decoded_ids = decoded_ids[decoded_ids != blank_id]

# Decode to text
decoder = pipeline.tokenizer.create_decoder()
text = decoder(decoded_ids)
print(f"\nDecoded text: '{text}'")

=== Direct Model Call with BatchLayout ===
Model dtype: torch.bfloat16
Waveform shape: torch.Size([1, 54400]), dtype: torch.bfloat16
BatchLayout: BatchLayout(width=54400, seq_begin_indices=[0, 54400], seq_lens=[tensor(54400, device='cuda:0')], min_seq_len=54400, max_seq_len=54400, padded=False, packed=False)
=== Calling model.forward() directly ===
Logits shape: torch.Size([1, 169, 9812])
Output layout: BatchLayout(width=169, seq_begin_indices=[0, 169], seq_lens=[169], min_seq_len=169, max_seq_len=169, padded=False, packed=False)
Emissions shape: torch.Size([1, 169, 9812])
Sequence length: 169

Decoded text: 'i had that curiosity beside me at this moment'
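The `transcribe` docstring lists normalizing among its preprocessing steps, while the direct call above feeds the raw waveform. It still produced the same transcript here, but the emissions need not match the pipeline's exactly. The sketch below shows one common choice for wav2vec2-style models, per-utterance zero-mean and unit-variance scaling over the valid samples only. That this is what the OmniASR pipeline means by "normalizing" is an unverified assumption; comparing the resulting logits against the hook-captured ones from the `transcribe` call would be a quick check. `normalize_waveforms` is a hypothetical helper.

```python
import torch


def normalize_waveforms(batch, seq_lens, eps=1e-5):
    """Per-utterance zero-mean, unit-variance scaling that ignores padded samples.

    batch: (B, T) waveforms; seq_lens: (B,) number of valid samples in each row.
    Padded positions stay zero. Statistics are computed in float32 and the result
    is cast back to the input dtype (e.g. bfloat16).
    """
    out = torch.zeros_like(batch)
    for i, n in enumerate(seq_lens.tolist()):
        x = batch[i, :n].float()
        mean = x.mean()
        var = ((x - mean) ** 2).mean()  # biased variance, as in layer norm
        out[i, :n] = ((x - mean) / torch.sqrt(var + eps)).to(batch.dtype)
    return out


toy = torch.tensor([[1.0, 2.0, 3.0, 0.0], [2.0, 4.0, 6.0, 8.0]])
normed = normalize_waveforms(toy, torch.tensor([3, 4]))
```

For a single unpadded row this matches `torch.nn.functional.layer_norm(x, x.shape[-1:])` without affine parameters; the explicit per-row loop exists so that padding zeros in a batch do not skew the statistics.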


In [54]:
# Test batched inference with different lengths
print("=== Test Batched Inference ===")

# Create two waveforms of different lengths
waveform1_len = waveform_batch.shape[1] // 2  # Half length
waveform2_len = waveform_batch.shape[1]       # Full length

waveform1 = waveform_batch[0, :waveform1_len]
waveform2 = waveform_batch[0, :waveform2_len]

# Pad to same length
max_len = max(waveform1_len, waveform2_len)
waveform1_padded = F.pad(waveform1, (0, max_len - waveform1_len))
waveform2_padded = waveform2  # Already max length

# Stack into batch
batch = torch.stack([waveform1_padded, waveform2_padded])
seq_lens = torch.tensor([waveform1_len, waveform2_len], device=device_obj)

print(f"Batch shape: {batch.shape}")
print(f"Seq lens: {seq_lens.tolist()}")

# Create BatchLayout
batch_layout = BatchLayout(
    batch.shape,
    seq_lens=seq_lens,
    device=device_obj,
)

print(f"BatchLayout: {batch_layout}")

# Forward pass
with torch.inference_mode():
    logits, output_layout = model(batch, batch_layout)

print(f"\nOutput logits shape: {logits.shape}")
print(f"Output seq_lens: {output_layout.seq_lens}")

# Decode both
emissions = F.log_softmax(logits.float(), dim=-1)
decoder = pipeline.tokenizer.create_decoder()
blank_id = 0

for i in range(2):
    pred_ids = torch.argmax(logits[i], dim=-1)
    seq_len = output_layout.seq_lens[i]
    if hasattr(seq_len, 'item'):
        seq_len = seq_len.item()
    pred_ids = pred_ids[:seq_len]

    # CTC collapse
    mask = torch.ones(pred_ids.shape[0], dtype=torch.bool, device=device_obj)
    mask[1:] = pred_ids[1:] != pred_ids[:-1]
    decoded_ids = pred_ids[mask]
    decoded_ids = decoded_ids[decoded_ids != blank_id]

    text = decoder(decoded_ids)
    print(f"\nSample {i} (len={output_layout.seq_lens[i]}): '{text}'")

=== Test Batched Inference ===
Batch shape: torch.Size([2, 54400])
Seq lens: [27200, 54400]
BatchLayout: BatchLayout(width=54400, seq_begin_indices=[0, 54400, 108800], seq_lens=[tensor(27200, device='cuda:0'), tensor(54400, device='cuda:0')], min_seq_len=27200, max_seq_len=54400, padded=True, packed=False)

Output logits shape: torch.Size([2, 169, 9812])
Output seq_lens: [84, 169]

Sample 0 (len=84): 'i had that curiosi'

Sample 1 (len=169): 'i had that curiosity beside me at this moment'
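The cell above pads two waveforms by hand. For an arbitrary list of 1-D waveforms, `torch.nn.utils.rnn.pad_sequence` does the right-padding in one call. A minimal sketch, with `pad_and_stack` as a hypothetical name; the returned lengths play the role of the `seq_lens` passed to `BatchLayout` above, which is how the model could report per-sample output lengths such as 84 and 169 frames.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_and_stack(waveforms):
    """Right-pad 1-D waveforms with zeros into a (B, T_max) batch plus per-row lengths."""
    seq_lens = torch.tensor([w.shape[0] for w in waveforms])
    batch = pad_sequence(waveforms, batch_first=True, padding_value=0.0)
    return batch, seq_lens


batch, seq_lens = pad_and_stack([torch.ones(3), torch.ones(5)])
print(batch.shape, seq_lens.tolist())  # torch.Size([2, 5]) [3, 5]
```

To feed the model as in the cells above, both would still be moved to the device and the batch cast to the model dtype, e.g. `batch.to(device_obj, dtype=model_dtype)` and `BatchLayout(batch.shape, seq_lens=seq_lens.to(device_obj), device=device_obj)`.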


In [57]:
import sys
import os

# First, clone repo to get access to install_utils
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    GITHUB_REPO = "https://github.com/huangruizhe/torchaudio_aligner.git"
    BRANCH = "dev"
    repo_path = '/content/torchaudio_aligner'
    src_path = f'{repo_path}/src'
    
    if not os.path.exists(repo_path):
        os.system(f'git clone -b {BRANCH} {GITHUB_REPO} {repo_path}')
    else:
        os.system(f'cd {repo_path} && git pull origin {BRANCH}')
    
    if src_path not in sys.path:
        sys.path.insert(0, src_path)

In [58]:
# Reload the module
import importlib
import labeling_utils
importlib.reload(labeling_utils)

# Test the OmniASR backend
from labeling_utils import load_model

backend = load_model("omniasr-300m", device="cuda")
print(f"Vocab size: {len(backend.get_vocab_info().labels)}")
print(f"First 20 tokens: {backend.get_vocab_info().labels[:20]}")

# Test emission extraction
import torch
waveform = torch.randn(1, 16000).to("cuda")  # 1 second of random noise at 16 kHz
emissions, lengths = backend.get_emissions(waveform)
print(f"Emissions shape: {emissions.shape}")
print(f"Lengths: {lengths}")

Vocab size: 9812
First 20 tokens: ['<s>', '<pad>', '</s>', '<unk>', ' ', 'ዘ', 'љ', '耷', '氰', '賬', '祭', '繳', '渭', '穗', '捡', '栩', '勸', '诚', '戮', '尉']
Emissions shape: torch.Size([1, 49, 9812])
Lengths: tensor([49], device='cuda:0')
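The emission lengths seen in this notebook (169 frames for 54,400 samples, 84 for 27,200, and 49 for the 16,000-sample noise input) are consistent with the standard wav2vec2 convolutional feature extractor: seven unpadded 1-D convolutions with (kernel, stride) = (10, 5), then (3, 2) four times, then (2, 2) twice, for a total stride of 320 samples (20 ms at 16 kHz). That the OmniASR frontend uses exactly this stack is an assumption inferred from those matching numbers, not read from its config. The sketch applies the Conv1d output-length formula layer by layer; `num_frames` is a hypothetical helper that predicts emission lengths without running the model, e.g. to check that there are enough frames for a transcript before aligning.

```python
# Assumed wav2vec2-style conv feature extractor: (kernel, stride) per layer.
WAV2VEC2_CONV_LAYERS = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2


def num_frames(num_samples, conv_layers=WAV2VEC2_CONV_LAYERS):
    """Frames produced by a stack of unpadded, undilated 1-D convolutions."""
    length = num_samples
    for kernel, stride in conv_layers:
        # Conv1d output length with padding=0, dilation=1; clamp so it never goes negative.
        length = max((length - kernel) // stride + 1, 0)
    return length


for n in (16000, 27200, 54400):
    print(n, num_frames(n))
# 16000 49
# 27200 84
# 54400 169
```

The stack's receptive field is 400 samples, so under this assumption any input shorter than 400 samples (25 ms) yields no frames at all.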
