In [4]:
from brainaudio.inference.decoder import BatchedBeamCTCComputer, LexiconConstraint
import numpy as np
import torch

# Configuration
# Every language-model artifact is derived from LANGUAGE_MODEL_PATH so that
# relocating the directory only requires editing this one constant (the
# trailing slash is required because the paths are built by concatenation).
LANGUAGE_MODEL_PATH = "/data2/brain2text/lm/"
TOKENS_TXT = f"{LANGUAGE_MODEL_PATH}units_pytorch.txt"
WORDS_TXT = f"{LANGUAGE_MODEL_PATH}vocab_lower_100k_pytorch_phoneme.txt"
# .npz archive of saved model outputs; the decoding cell reads "arr_0" and "arr_1".
LOGITS_PATH = "/data2/brain2text/b2t_25/logits/tm_transformer_combined_reduced_reg_seed_0/logits_val_None_None.npz"

# Symbol inventory: the CTC blank first, then 39 phoneme labels, then 'SIL'
# (41 entries in total).
# NOTE(review): presumably position i is the class id on the model's output
# axis -- PHONE_DEF is not referenced elsewhere in this section, so confirm.
PHONE_DEF = ['CTC BLANK'] + (
    "AA AE AH AO AW AY B CH D DH "
    "EH ER EY F G HH IH IY JH K "
    "L M N NG OW OY P R S SH "
    "T TH UH UW V W Y Z ZH SIL"
).split()

def apply_ctc_rules(ids, blank_index=0):
    """Collapse a frame-level CTC label sequence into its label sequence.

    Consecutive repeats of the same label are merged and blank labels are
    removed. A blank between two equal labels separates them, so both are
    kept (e.g. ``[1, 0, 1] -> [1, 1]``).

    Args:
        ids: Iterable of integer label ids. An object exposing ``.cpu()``
            (e.g. a torch tensor) is first moved to host memory and
            converted with ``.numpy()``.
        blank_index: Id of the CTC blank label. Defaults to 0, which is the
            value this function previously hard-coded.

    Returns:
        list[int]: The collapsed label ids as plain Python ints.
    """
    if hasattr(ids, 'cpu'):
        ids = ids.cpu().numpy()

    clean_ids = []
    prev_id = None

    for id_val in ids:
        if id_val == blank_index:
            # A blank ends the current run: the same label may be emitted
            # again right after it.
            prev_id = None
            continue
        if id_val == prev_id:  # Still inside a run of the same label
            continue
        clean_ids.append(int(id_val))
        prev_id = id_val

    return clean_ids

def load_phoneme_to_word_mapping(lexicon_file):
    """Build a phoneme-sequence -> word mapping from a lexicon file.

    Each usable line has the form ``word PH1 PH2 ... [|]``: the first
    whitespace-separated field is the word and the remaining fields are its
    phonemes. ``|`` fields are dropped from the pronunciation.

    Lines with fewer than two fields, and lines whose pronunciation is empty
    once ``|`` is removed, are skipped so that the empty sequence never maps
    to a word.

    When several words share a pronunciation (homophones), the one that
    appears LAST in the file is kept, so ``len()`` of the result counts
    distinct pronunciations rather than words.

    Args:
        lexicon_file: Path to the lexicon text file (read as UTF-8).

    Returns:
        dict[tuple[str, ...], str]: Pronunciation tuple -> word.
    """
    phoneme_to_word = {}
    # Explicit encoding: the default depends on the machine's locale.
    with open(lexicon_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            word = parts[0]
            phonemes = tuple(p for p in parts[1:] if p != '|')
            if not phonemes:
                continue
            phoneme_to_word[phonemes] = word
    return phoneme_to_word

def load_token_to_phoneme_mapping(tokens_file):
    """Load the token ID -> symbol mapping from a tokens file.

    The token ID is the zero-based line number and the symbol is that line
    with surrounding whitespace stripped. Every line consumes an ID, so a
    blank line maps its ID to the empty string.

    Args:
        tokens_file: Path to the tokens text file (read as UTF-8).

    Returns:
        dict[int, str]: Token ID -> symbol.
    """
    # Explicit encoding: the default depends on the machine's locale.
    with open(tokens_file, 'r', encoding='utf-8') as f:
        return {idx: line.strip() for idx, line in enumerate(f)}

IndentationError: unexpected indent (ctc_batched_beam_decoding.py, line 50)

In [6]:
# Load and prepare logits
DEVICE = torch.device('cuda:0')
UTTERANCE_KEYS = ("arr_0", "arr_1")  # arrays of the .npz archive to decode
# NOTE(review): frames past an utterance's true length are filled with -inf,
# as before. That is only safe if the decoder never does arithmetic on frames
# beyond `logits_lengths` (softmax over an all -inf row is NaN) -- confirm
# against BatchedBeamCTCComputer, or switch to a finite value.
PAD_VALUE = float('-inf')

# np.load returns a lazy archive that stays open until closed, so read each
# needed array exactly once and release the file handle immediately.
with np.load(LOGITS_PATH) as archive:
    model_logits = {key: archive[key] for key in UTTERANCE_KEYS}

utterance_logits = [torch.from_numpy(model_logits[key]) for key in UTTERANCE_KEYS]

# True (unpadded) number of frames per utterance, in batch order
logits_lengths = torch.tensor(
    [frames.shape[0] for frames in utterance_logits],
    dtype=torch.int64,
    device=DEVICE,
)

# Pad to same time dimension, batch, and move to the decoding device
logits_batched = torch.nn.utils.rnn.pad_sequence(
    utterance_logits, batch_first=True, padding_value=PAD_VALUE
).to(DEVICE)

print(f"Logits shape: {logits_batched.shape}")
print(f"Device: {logits_batched.device}\n")

# Load lexicon and mappings
lexicon = LexiconConstraint.from_file_paths(
    tokens_file=TOKENS_TXT,
    lexicon_file=WORDS_TXT,
    device=DEVICE,
)
phoneme_to_word = load_phoneme_to_word_mapping(WORDS_TXT)
token_to_symbol = load_token_to_phoneme_mapping(TOKENS_TXT)

# The mapping is keyed by pronunciation, so homophones share one entry and
# this count is the number of distinct pronunciations, not of words.
print(f"Lexicon loaded: {len(phoneme_to_word)} unique pronunciations")

# Create decoder
decoder = BatchedBeamCTCComputer(
    blank_index=lexicon.blank_index,
    beam_size=10,
    lexicon=lexicon,
)
print(f"Decoder created (beam size: {decoder.beam_size})")

# Run beam search
print("Running beam search...")
result = decoder(logits_batched, logits_lengths)
print("Decoding complete!\n")

# Display results (top hypothesis, index 0, of every utterance)
for b in range(logits_batched.shape[0]):
    print(f"=== Utterance {b} ===")
    score = result.scores[b, 0].item()

    # `not (score > -inf)` also rejects NaN, which `score == -inf` would miss.
    if not (score > float('-inf')):
        # Say so explicitly instead of leaving a bare header behind.
        print(f"No valid hypothesis (best score: {score})\n")
        continue

    seq = result.transcript_wb[b, 0]
    # Negative ids are dropped -- presumably unused/padding slots; verify
    # against the decoder's result format.
    seq_filtered = seq[seq >= 0]

    # Apply CTC rules and decode.
    # NOTE(review): apply_ctc_rules treats id 0 as blank while the decoder is
    # configured with lexicon.blank_index -- confirm the two agree.
    ids_no_blanks = apply_ctc_rules(seq_filtered)
    word_alts = lexicon.decode_sequence_to_words(
        ids_no_blanks,
        token_to_symbol,
        phoneme_to_word,
        return_alternatives=True,
    )

    # Display words with homophones
    print("Decoded words:\n")
    for i, (primary_word, alternatives) in enumerate(word_alts, 1):
        if alternatives:
            print(f"{i}. {alternatives}")
        else:
            print(f"{i}. {primary_word}")
    print()

Logits shape: torch.Size([2, 228, 41])
Device: cuda:0

Lexicon loaded: 88200 words
Decoder created (beam size: 10)
Running beam search...
Lexicon loaded: 88200 words
Decoder created (beam size: 10)
Running beam search...
Decoding complete!

=== Utterance 0 ===
Decoded words:

1. ['you', 'u', 'eu', 'yu', 'ew', 'yuh', 'yoo', 'yew', 'eww', 'uwe', 'yue', 'ewe', 'yw']
2. ['can', 'kan', 'caen', 'cann', 'cannae', 'kann', 'cahn']
3. ['see', 'sea', 'c', 'sc', 'si', 'ci', 'sci', 'cee', 'sei', 'sse', 'sie', 'cie']
4. the
5. ['code', "ko'd"]
6. ['at', 'att', "'at"]
7. ['this', "this'"]
8. ['point', 'pointe']
9. ['as', 'az', 'azz']
10. ['well', 'welle', "we'l"]

=== Utterance 1 ===
Decoding complete!

=== Utterance 0 ===
Decoded words:

1. ['you', 'u', 'eu', 'yu', 'ew', 'yuh', 'yoo', 'yew', 'eww', 'uwe', 'yue', 'ewe', 'yw']
2. ['can', 'kan', 'caen', 'cann', 'cannae', 'kann', 'cahn']
3. ['see', 'sea', 'c', 'sc', 'si', 'ci', 'sci', 'cee', 'sei', 'sse', 'sie', 'cie']
4. the
5. ['code', "ko'd"]
6. ['at

In [10]:
import nemo

# Environment check: report which NeMo build this kernel imported and the
# location it was loaded from.
print(f"NeMo version: {nemo.__version__}", nemo.__file__, sep="\n")

NeMo version: 2.7.0rc0
/home/ebrahim/NeMo/nemo/__init__.py
