In [1]:
from nemo.collections.asr.parts.submodules.ctc_batched_beam_decoding import BatchedBeamCTCComputer
from nemo.collections.asr.parts.submodules.lexicon_constraint import LexiconConstraint
import numpy as np
import torch

PHONE_DEF = [
    'CTC BLANK', 'AA', 'AE', 'AH', 'AO', 'AW',
    'AY', 'B',  'CH', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G',
    'HH', 'IH', 'IY', 'JH', 'K',
    'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH',from nemo.collections.asr.parts.submodules.ctc_batched_beam_decoding import BatchedBeamCTCComputer
from nemo.collections.asr.parts.submodules.lexicon_constraint import LexiconConstraint
import numpy as np
import torch

# Phoneme inventory indexed by model output ID. Entry 0 is the CTC blank
# and the final entry is 'SIL' (presumably the silence marker -- confirm
# against the model's training vocabulary).
PHONE_DEF = [
    "CTC BLANK",
    "AA", "AE", "AH", "AO", "AW", "AY", "B", "CH",
    "D", "DH", "EH", "ER", "EY", "F", "G", "HH",
    "IH", "IY", "JH", "K", "L", "M", "N", "NG",
    "OW", "OY", "P", "R", "S", "SH", "T", "TH",
    "UH", "UW", "V", "W", "Y", "Z", "ZH",
    "SIL",
]

def decode_phonemes(ids, remove_silence=False):
    """
    Decode phoneme IDs to a readable string with CTC merging rules applied.

    - Removes blanks (ID 0)
    - Merges consecutive repeated phonemes (a blank between two identical
      IDs keeps both occurrences)
    - Optionally removes silence tokens ('SIL')

    Args:
        ids: Iterable of integer phoneme IDs. Objects exposing ``.cpu()``
            (e.g. torch tensors) are moved to host memory and converted to
            a NumPy array first.
        remove_silence: When True, 'SIL' entries are left out of the output.
            Defaults to False, which matches the previous behaviour.

    Returns:
        Space-separated phoneme symbols. IDs outside ``PHONE_DEF`` are
        rendered as ``<UNK:id>``.
    """
    # Handle torch tensors
    if hasattr(ids, 'cpu'):
        ids = ids.cpu().numpy()

    phonemes = []
    prev_id = None

    for id_val in ids:
        # A blank is never emitted and resets repeat tracking, so the same
        # phoneme on both sides of a blank is kept twice.
        if id_val == 0:
            prev_id = None
            continue

        # Skip repeats
        if id_val == prev_id:
            continue
        prev_id = id_val

        if 0 <= id_val < len(PHONE_DEF):
            symbol = PHONE_DEF[id_val]
        else:
            symbol = f'<UNK:{id_val}>'

        # Silence is filtered only after repeat merging, so phonemes on
        # either side of a silence are never merged with each other.
        if remove_silence and symbol == 'SIL':
            continue

        phonemes.append(symbol)

    return ' '.join(phonemes)

def decode_to_words(ids, lexicon_file, tokens_file='/data2/brain2text/lm/units_pytorch.txt'):
    """
    Convert phoneme IDs to words by matching against the lexicon.

    The ID sequence is collapsed with the CTC rules (blank ID 0 dropped,
    consecutive repeats merged), mapped to token symbols via ``tokens_file``,
    split into groups at every ``|`` symbol, and each group is looked up in
    the lexicon.

    Args:
        ids: Sequence of phoneme IDs. Objects exposing ``.cpu()`` (e.g. torch
            tensors) are converted to a NumPy array first.
        lexicon_file: Path to lexicon file, one entry per line in the form
            ``word PHONEME1 PHONEME2 ... |``.
        tokens_file: Path to the token list, one symbol per line; the
            zero-based line number is the token ID. Defaults to the path
            that used to be hard-coded, so existing callers are unaffected.

    Returns:
        String of words separated by spaces. A phoneme group without a
        lexicon entry is rendered as ``<UNK:PH1 PH2 ...>``.
    """
    # Handle torch tensors
    if hasattr(ids, 'cpu'):
        ids = ids.cpu().numpy()

    # Apply CTC rules: remove blanks and merge repeats
    phoneme_ids = []
    prev_id = None
    for id_val in ids:
        if id_val == 0:  # Blank: drop it and reset repeat tracking
            prev_id = None
            continue
        if id_val == prev_id:  # Skip repeats
            continue
        phoneme_ids.append(int(id_val))
        prev_id = id_val

    # Token ID -> symbol, keyed by line number in the tokens file
    with open(tokens_file, 'r') as f:
        token_to_phoneme = {idx: line.strip() for idx, line in enumerate(f)}

    # Build reverse lexicon: phoneme sequence -> word.
    # Homophones share a key; keep the FIRST word listed instead of letting
    # later entries overwrite it.
    # NOTE(review): assumes the lexicon lists the preferred/common spelling
    # first -- confirm against how the lexicon file was generated.
    phoneme_to_word = {}
    with open(lexicon_file, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            # Drop the "|" word-boundary marker from the key
            key = tuple(p for p in parts[1:] if p != '|')
            phoneme_to_word.setdefault(key, parts[0])

    def lookup(group):
        # Resolve one phoneme group to a word, or an <UNK:...> placeholder.
        return phoneme_to_word.get(tuple(group), f"<UNK:{' '.join(group)}>")

    # Split the sequence at word boundaries (|) and look up each group
    words = []
    current_phonemes = []
    for pid in phoneme_ids:
        # IDs beyond the end of the tokens file map to an empty symbol
        phoneme = token_to_phoneme.get(pid, '')
        if phoneme == '|':
            if current_phonemes:
                words.append(lookup(current_phonemes))
                current_phonemes = []
        else:
            current_phonemes.append(phoneme)

    # Handle any remaining phonemes (sequence without a trailing |)
    if current_phonemes:
        words.append(lookup(current_phonemes))

    return ' '.join(words)

if __name__ == "__main__":
    # Lexicon-constrained CTC beam-search demo. Everything lives under the
    # guard so that importing this module does not touch the GPU or the
    # data files (previously only the path assignments were guarded, and an
    # import failed with NameError on `model_logits`).
    model_logits = np.load("/data2/brain2text/b2t_25/logits/tm_transformer_combined_reduced_reg_seed_0/logits_val_None_None.npz")
    language_model_path = "/data2/brain2text/lm/"
    tokens_txt = f"{language_model_path}units_pytorch.txt"
    words_txt = "/data2/brain2text/lm/vocab_lower_100k_pytorch_phoneme.txt"

    logits_0 = model_logits["arr_0"]
    logits_1 = model_logits["arr_1"]

    # Record the true lengths before padding so the archive is not read again
    logits_lengths = np.array([logits_0.shape[0], logits_1.shape[0]])

    # Stack along batch dimension with time padding
    max_time = max(logits_0.shape[0], logits_1.shape[0])
    vocab_size = logits_0.shape[1]

    # Pad the shorter utterance along axis 0 with -inf up to max_time
    if logits_0.shape[0] < max_time:
        pad_width = ((0, max_time - logits_0.shape[0]), (0, 0))
        logits_0 = np.pad(logits_0, pad_width, mode='constant', constant_values=-np.inf)

    if logits_1.shape[0] < max_time:
        pad_width = ((0, max_time - logits_1.shape[0]), (0, 0))
        logits_1 = np.pad(logits_1, pad_width, mode='constant', constant_values=-np.inf)

    # Stack along batch dimension: (batch=2, time, vocab)
    logits_batched = np.stack([logits_0, logits_1], axis=0)

    # Convert to torch and move to GPU
    logits_batched = torch.from_numpy(logits_batched).to('cuda:0')
    logits_lengths = torch.from_numpy(logits_lengths).to('cuda:0')

    print(f"Logits shape: {logits_batched.shape}")
    print(f"Device: {logits_batched.device}\n")

    # Load lexicon constraint
    lexicon = LexiconConstraint.from_file_paths(
        tokens_file=tokens_txt,
        lexicon_file=words_txt,
        device=torch.device('cuda:0'),
    )

    # Build word mapping for decoding: phoneme tuple -> word.
    # Later lexicon lines overwrite earlier ones that share a pronunciation.
    phoneme_to_word = {}
    with open(words_txt, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 2:
                continue
            word = parts[0]
            # Remove trailing "|" from phonemes for the key
            phonemes = tuple(p for p in parts[1:] if p != '|')
            phoneme_to_word[phonemes] = word

    print("Lexicon loaded successfully!")
    print(f"  Words in lexicon: {len(phoneme_to_word)}")

    # Token ID -> symbol table. It is identical for every utterance, so it
    # is read once here instead of inside the display loop.
    token_to_symbol = {}
    with open(tokens_txt, 'r') as f:
        for idx, line in enumerate(f):
            token_to_symbol[idx] = line.strip()

    # Create decoder with lexicon constraint
    decoder = BatchedBeamCTCComputer(
        blank_index=lexicon.blank_index,
        beam_size=10,
        lexicon=lexicon,  # Add lexicon constraint
    )

    print(f"\nDecoder created with lexicon constraint (beam size: {decoder.beam_size})")

    # Run decoding
    print("Running beam search...")
    result = decoder(logits_batched, logits_lengths)
    print("Decoding complete!\n")

    # Display results - only show first hypothesis
    for b in range(logits_batched.shape[0]):
        print(f"=== Utterance {b} ===")
        seq = result.transcript_wb[b, 0]  # Only first hypothesis
        seq_filtered = seq[seq >= 0]  # Keep non-negative IDs only
        score = result.scores[b, 0].item()
        if score > float('-inf'):
            # Apply CTC rules.
            # NOTE(review): blank is hard-coded as 0 here while the decoder
            # uses lexicon.blank_index -- confirm the two agree.
            ids_no_blanks = []
            prev_id = None
            for id_val in seq_filtered.cpu().numpy():
                if id_val == 0:  # Skip blank
                    prev_id = None
                    continue
                if id_val == prev_id:  # Skip repeats
                    continue
                ids_no_blanks.append(int(id_val))
                prev_id = id_val

            # Get words with alternatives (homophones)
            word_alts = lexicon.decode_sequence_to_words(ids_no_blanks, token_to_symbol, phoneme_to_word, return_alternatives=True)

            # Display only words with their homophones
            print("Decoded words (with homophones where available):\n")
            for i, (primary_word, alternatives) in enumerate(word_alts, 1):
                if alternatives:
                    print(f"{i}. {alternatives}")
                else:
                    print(f"{i}. {primary_word}")

            print()  # Blank line between utterances

    'T', 'TH', 'UH', 'UW', 'V',
    'W', 'Y', 'Z', 'ZH', 'SIL'
]

def decode_phonemes(ids, remove_silence=False):
    """
    Decode phoneme IDs to a readable string with CTC merging rules applied.

    - Removes blanks (ID 0)
    - Merges consecutive repeated phonemes (a blank between two identical
      IDs keeps both occurrences)
    - Optionally removes silence tokens ('SIL')

    Args:
        ids: Iterable of integer phoneme IDs. Objects exposing ``.cpu()``
            (e.g. torch tensors) are moved to host memory and converted to
            a NumPy array first.
        remove_silence: When True, 'SIL' entries are left out of the output.
            Defaults to False, which matches the previous behaviour.

    Returns:
        Space-separated phoneme symbols. IDs outside ``PHONE_DEF`` are
        rendered as ``<UNK:id>``.
    """
    # Handle torch tensors
    if hasattr(ids, 'cpu'):
        ids = ids.cpu().numpy()

    phonemes = []
    prev_id = None

    for id_val in ids:
        # A blank is never emitted and resets repeat tracking, so the same
        # phoneme on both sides of a blank is kept twice.
        if id_val == 0:
            prev_id = None
            continue

        # Skip repeats
        if id_val == prev_id:
            continue
        prev_id = id_val

        if 0 <= id_val < len(PHONE_DEF):
            symbol = PHONE_DEF[id_val]
        else:
            symbol = f'<UNK:{id_val}>'

        # Silence is filtered only after repeat merging, so phonemes on
        # either side of a silence are never merged with each other.
        if remove_silence and symbol == 'SIL':
            continue

        phonemes.append(symbol)

    return ' '.join(phonemes)

def decode_to_words(ids, lexicon_file, tokens_file='/data2/brain2text/lm/units_pytorch.txt'):
    """
    Convert phoneme IDs to words by matching against the lexicon.

    The ID sequence is collapsed with the CTC rules (blank ID 0 dropped,
    consecutive repeats merged), mapped to token symbols via ``tokens_file``,
    split into groups at every ``|`` symbol, and each group is looked up in
    the lexicon.

    Args:
        ids: Sequence of phoneme IDs. Objects exposing ``.cpu()`` (e.g. torch
            tensors) are converted to a NumPy array first.
        lexicon_file: Path to lexicon file, one entry per line in the form
            ``word PHONEME1 PHONEME2 ... |``.
        tokens_file: Path to the token list, one symbol per line; the
            zero-based line number is the token ID. Defaults to the path
            that used to be hard-coded, so existing callers are unaffected.

    Returns:
        String of words separated by spaces. A phoneme group without a
        lexicon entry is rendered as ``<UNK:PH1 PH2 ...>``.
    """
    # Handle torch tensors
    if hasattr(ids, 'cpu'):
        ids = ids.cpu().numpy()

    # Apply CTC rules: remove blanks and merge repeats
    phoneme_ids = []
    prev_id = None
    for id_val in ids:
        if id_val == 0:  # Blank: drop it and reset repeat tracking
            prev_id = None
            continue
        if id_val == prev_id:  # Skip repeats
            continue
        phoneme_ids.append(int(id_val))
        prev_id = id_val

    # Token ID -> symbol, keyed by line number in the tokens file
    with open(tokens_file, 'r') as f:
        token_to_phoneme = {idx: line.strip() for idx, line in enumerate(f)}

    # Build reverse lexicon: phoneme sequence -> word.
    # Homophones share a key; keep the FIRST word listed instead of letting
    # later entries overwrite it.
    # NOTE(review): assumes the lexicon lists the preferred/common spelling
    # first -- confirm against how the lexicon file was generated.
    phoneme_to_word = {}
    with open(lexicon_file, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            # Drop the "|" word-boundary marker from the key
            key = tuple(p for p in parts[1:] if p != '|')
            phoneme_to_word.setdefault(key, parts[0])

    def lookup(group):
        # Resolve one phoneme group to a word, or an <UNK:...> placeholder.
        return phoneme_to_word.get(tuple(group), f"<UNK:{' '.join(group)}>")

    # Split the sequence at word boundaries (|) and look up each group
    words = []
    current_phonemes = []
    for pid in phoneme_ids:
        # IDs beyond the end of the tokens file map to an empty symbol
        phoneme = token_to_phoneme.get(pid, '')
        if phoneme == '|':
            if current_phonemes:
                words.append(lookup(current_phonemes))
                current_phonemes = []
        else:
            current_phonemes.append(phoneme)

    # Handle any remaining phonemes (sequence without a trailing |)
    if current_phonemes:
        words.append(lookup(current_phonemes))

    return ' '.join(words)

# Lexicon-constrained CTC beam-search demo: load two utterances' logits,
# pad and batch them, decode on the GPU, and print the first hypothesis of
# each utterance as words with their homophones.
model_logits = np.load("/data2/brain2text/b2t_25/logits/tm_transformer_combined_reduced_reg_seed_0/logits_val_None_None.npz")
language_model_path = "/data2/brain2text/lm/"
tokens_txt = f"{language_model_path}units_pytorch.txt"
words_txt = "/data2/brain2text/lm/vocab_lower_100k_pytorch_phoneme.txt"

logits_0 = model_logits["arr_0"]
logits_1 = model_logits["arr_1"]

# Record the true lengths before padding so the archive is not read again
logits_lengths = np.array([logits_0.shape[0], logits_1.shape[0]])

# Stack along batch dimension with time padding
max_time = max(logits_0.shape[0], logits_1.shape[0])
vocab_size = logits_0.shape[1]

# Pad the shorter utterance along axis 0 with -inf up to max_time
if logits_0.shape[0] < max_time:
    pad_width = ((0, max_time - logits_0.shape[0]), (0, 0))
    logits_0 = np.pad(logits_0, pad_width, mode='constant', constant_values=-np.inf)

if logits_1.shape[0] < max_time:
    pad_width = ((0, max_time - logits_1.shape[0]), (0, 0))
    logits_1 = np.pad(logits_1, pad_width, mode='constant', constant_values=-np.inf)

# Stack along batch dimension: (batch=2, time, vocab)
logits_batched = np.stack([logits_0, logits_1], axis=0)

# Convert to torch and move to GPU
logits_batched = torch.from_numpy(logits_batched).to('cuda:0')
logits_lengths = torch.from_numpy(logits_lengths).to('cuda:0')

print(f"Logits shape: {logits_batched.shape}")
print(f"Device: {logits_batched.device}\n")

# Load lexicon constraint
lexicon = LexiconConstraint.from_file_paths(
    tokens_file=tokens_txt,
    lexicon_file=words_txt,
    device=torch.device('cuda:0'),
)

# Build word mapping for decoding: phoneme tuple -> word.
# Later lexicon lines overwrite earlier ones that share a pronunciation.
phoneme_to_word = {}
with open(words_txt, 'r') as f:
    for line in f:
        parts = line.strip().split()
        if len(parts) < 2:
            continue
        word = parts[0]
        # Remove trailing "|" from phonemes for the key
        phonemes = tuple(p for p in parts[1:] if p != '|')
        phoneme_to_word[phonemes] = word

print("Lexicon loaded successfully!")
print(f"  Words in lexicon: {len(phoneme_to_word)}")

# Token ID -> symbol table. It is identical for every utterance, so it is
# read once here instead of inside the display loop.
token_to_symbol = {}
with open(tokens_txt, 'r') as f:
    for idx, line in enumerate(f):
        token_to_symbol[idx] = line.strip()

# Create decoder with lexicon constraint
decoder = BatchedBeamCTCComputer(
    blank_index=lexicon.blank_index,
    beam_size=10,
    lexicon=lexicon,  # Add lexicon constraint
)

print(f"\nDecoder created with lexicon constraint (beam size: {decoder.beam_size})")

# Run decoding
print("Running beam search...")
result = decoder(logits_batched, logits_lengths)
print("Decoding complete!\n")

# Display results - only show first hypothesis
for b in range(logits_batched.shape[0]):
    print(f"=== Utterance {b} ===")
    seq = result.transcript_wb[b, 0]  # Only first hypothesis
    seq_filtered = seq[seq >= 0]  # Keep non-negative IDs only
    score = result.scores[b, 0].item()
    if score > float('-inf'):
        # Apply CTC rules.
        # NOTE(review): blank is hard-coded as 0 here while the decoder uses
        # lexicon.blank_index -- confirm the two agree.
        ids_no_blanks = []
        prev_id = None
        for id_val in seq_filtered.cpu().numpy():
            if id_val == 0:  # Skip blank
                prev_id = None
                continue
            if id_val == prev_id:  # Skip repeats
                continue
            ids_no_blanks.append(int(id_val))
            prev_id = id_val

        # Get words with alternatives (homophones)
        word_alts = lexicon.decode_sequence_to_words(ids_no_blanks, token_to_symbol, phoneme_to_word, return_alternatives=True)

        # Display only words with their homophones
        print("Decoded words (with homophones where available):\n")
        for i, (primary_word, alternatives) in enumerate(word_alts, 1):
            if alternatives:
                print(f"{i}. {alternatives}")
            else:
                print(f"{i}. {primary_word}")

        print()  # Blank line between utterances


  from .autonotebook import tqdm as notebook_tqdm
[NeMo W 2025-11-21 17:20:53 megatron_init:62] Megatron num_microbatches_calculator not found, using Apex version.
[NeMo W 2025-11-21 17:20:53 megatron_init:62] Megatron num_microbatches_calculator not found, using Apex version.
OneLogger: Setting error_handling_strategy to DISABLE_QUIETLY_AND_REPORT_METRIC_ERROR for rank (rank=0) with OneLogger disabled. To override: explicitly set error_handling_strategy parameter.
No exporters were provided. This means that no telemetry data will be collected.
OneLogger: Setting error_handling_strategy to DISABLE_QUIETLY_AND_REPORT_METRIC_ERROR for rank (rank=0) with OneLogger disabled. To override: explicitly set error_handling_strategy parameter.
No exporters were provided. This means that no telemetry data will be collected.
    
    


  from .autonotebook import tqdm as notebook_tqdm
[NeMo W 2025-11-21 17:20:53 megatron_init:62] Megatron num_microbatches_calculator not found, using Apex version.
[NeMo W 2025-11-21 17:20:53 megatron_init:62] Megatron num_microbatches_calculator not found, using Apex version.
OneLogger: Setting error_handling_strategy to DISABLE_QUIETLY_AND_REPORT_METRIC_ERROR for rank (rank=0) with OneLogger disabled. To override: explicitly set error_handling_strategy parameter.
No exporters were provided. This means that no telemetry data will be collected.
OneLogger: Setting error_handling_strategy to DISABLE_QUIETLY_AND_REPORT_METRIC_ERROR for rank (rank=0) with OneLogger disabled. To override: explicitly set error_handling_strategy parameter.
No exporters were provided. This means that no telemetry data will be collected.
    
    


Logits shape: torch.Size([2, 228, 41])
Device: cuda:0

Lexicon loaded successfully!
  Words in lexicon: 88200

Decoder created with lexicon constraint (beam size: 10)
Running beam search...
Lexicon loaded successfully!
  Words in lexicon: 88200

Decoder created with lexicon constraint (beam size: 10)
Running beam search...
Decoding complete!

=== Utterance 0 ===
Decoded words (with homophones where available):

1. ['you', 'u', 'eu', 'yu', 'ew', 'yuh', 'yoo', 'yew', 'eww', 'uwe', 'yue', 'ewe', 'yw']
2. ['can', 'kan', 'caen', 'cann', 'cannae', 'kann', 'cahn']
3. ['see', 'sea', 'c', 'sc', 'si', 'ci', 'sci', 'cee', 'sei', 'sse', 'sie', 'cie']
4. the
5. ['code', "ko'd"]
6. ['at', 'att', "'at"]
7. ['this', "this'"]
8. ['point', 'pointe']
9. ['as', 'az', 'azz']
10. ['well', 'welle', "we'l"]

=== Utterance 1 ===
Decoding complete!

=== Utterance 0 ===
Decoded words (with homophones where available):

1. ['you', 'u', 'eu', 'yu', 'ew', 'yuh', 'yoo', 'yew', 'eww', 'uwe', 'yue', 'ewe', 'yw']
2. ['