In [None]:
from brainaudio.inference.decoder.ngram_lm import NGramGPULanguageModel
import torch
import math

# Phoneme inventory, in the same order as units_pytorch.txt with the blank
# symbol left out. A phoneme's position in this list is the index used for it.
PHONEMES = (
    "AA AE AH AO AW AY B CH D DH "
    "EH ER EY F G HH IH IY JH K "
    "L M N NG OW OY P R S SH "
    "T TH UH UW V W Y Z ZH SIL"
).split()

# Reverse lookup: phoneme symbol -> its position in PHONEMES.
PHONEME_TO_IDX = dict(zip(PHONEMES, range(len(PHONEMES))))

# Load the n-gram LM from its ARPA file and move it to the GPU.
# vocab_size is derived from the phoneme table instead of being repeated as a
# literal, so the LM's vocabulary size stays in step with the PHONEMES /
# PHONEME_TO_IDX lookups that are used to index its scores.
lm = NGramGPULanguageModel.from_arpa(
    "/home/ebrahim/brainaudio/test_phoneme_bigram.arpa",
    vocab_size=len(PHONEMES),
    # NOTE(review): presumably the symbol offset used when the ARPA file was
    # generated -- confirm against the script that wrote the file.
    token_offset=100,
)
lm = lm.to("cuda")

# Start a single hypothesis in the beginning-of-sentence state.
states = lm.get_init_states(batch_size=1, bos=True)

# One advance step gives the scores for every phoneme conditioned on <s>,
# along with the follow-up state for each possible choice.
scores, next_states = lm.advance(states)

print("P(phoneme | <s>) - top 10:")
top_scores, top_indices = scores[0].topk(10)
# Convert both tensors to plain Python values once, then format.
for score, idx in zip(top_scores.tolist(), top_indices.tolist()):
    prob = math.exp(score)
    print(f"  {PHONEMES[idx]:5s}: log_prob={score:.4f}, prob={prob:.4f}")

# Score one concrete phoneme string step by step:
# "DH AH SIL K AE T" (the cat).
print("\nScoring sequence 'DH AH SIL K AE T' (the cat):")
sequence = "DH AH SIL K AE T".split()

# Restart from the BOS state so this run does not depend on earlier steps.
states = lm.get_init_states(batch_size=1, bos=True)
total_log_prob = 0.0

for phoneme in sequence:
    idx = PHONEME_TO_IDX[phoneme]
    scores, next_states = lm.advance(states)

    # Score of this phoneme given the context consumed so far.
    log_prob = scores[0, idx].item()
    prob = math.exp(log_prob)
    print(f"  P({phoneme:3s} | context) = {prob:.4f} (log={log_prob:.4f})")
    total_log_prob = total_log_prob + log_prob

    # Follow the transition for the phoneme just consumed.
    states = next_states[:, idx]

# Report the accumulated score in both log and linear form.
total_prob = math.exp(total_log_prob)
print(f"\nTotal log prob: {total_log_prob:.4f}")
print(f"Total prob: {total_prob:.6f}")


IndentationError: unexpected indent (2496436526.py, line 2)