In [2]:
from torchaudio.models.decoder import ctc_decoder, cuda_ctc_decoder
import numpy as np 
import torch
from torch.nn.functional import softmax, log_softmax
from brainaudio.inference.inference_utils import _cer_and_wer
import torch.distributions as dist


In [3]:
import re

# --- The dictionary of informal shorthand words ---
# Does not include "a" or "I" as they are standard English.
SHORTHAND_MAP = {
    'b': 'be',
    'c': 'see',
    'r': 'are',
    'u': 'you',
    'y': 'why'
}

def normalize_shorthand(text: str) -> str:
    """
    Converts informal single-character shorthand (like 'u', 'r', 'c')
    in a string to their full-word equivalents.
    ASSUMES INPUT TEXT IS ALREADY LOWERCASE.

    Args:
        text: The input string (assumed to be lowercase).

    Returns:
        The modified string with shorthand words replaced.
    """
    
    modified_text = text
    
    for shorthand, full_word in SHORTHAND_MAP.items():
        
        # This is the regex pattern to find the whole word.
        # \b = word boundary (matches start/end of a word)
        # re.escape(shorthand) = the letter itself (e.g., 'c')
        # No re.IGNORECASE needed as we assume lowercase input.
        pattern = r'\b' + re.escape(shorthand) + r'\b'

        # Simplified replacement: just use the lowercase full word.
        modified_text = re.sub(
            pattern, 
            full_word, 
            modified_text
        )

    return modified_text

In [5]:
language_model_path = "/data2/brain2text/lm/languageModel/"
lexicon_phonemes_file = f"{language_model_path}lexicon_phonemes.txt"
units_txt_file_pytorch = f"{language_model_path}units_pytorch.txt"

units_txt_file_pytorch_char = f"{language_model_path}units_pytorch_character.txt"
lexicon_char_file= f"{language_model_path}lexicon_char.txt"

imagineville_vocab_phoneme = "/data2/brain2text/lm/vocab_lower_100k_pytorch_phoneme.txt"

In [6]:
import pandas as pd
val_transcripts = pd.read_pickle("/data2/brain2text/b2t_24/transcripts_val.pkl")

with open("/data2/brain2text/b2t_24/wfst_txt/tm_transformer_b2t_24+25_large_wide_bidir_grad_clip_cosine_decay.txt") as file:
    
    wfst_txt =  [line.strip() for line in file]

In [7]:
model_logits = np.load("/data2/brain2text/b2t_24/logits/tm_transformer_b2t_24+25_large_wide_bidir_grad_clip_cosine_decay/logits_val.npz")

In [8]:
decoder = ctc_decoder(tokens=units_txt_file_pytorch, lexicon=imagineville_vocab_phoneme, 
                       beam_size=30, nbest=1, lm="/data2/brain2text/lm/lm_dec19_huge_4gram.kenlm", lm_weight=2.0, word_score=0.1)

[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached l

In [9]:
idx = 412
acoustic_scale = 0.7
print("Transcript: ", val_transcripts[idx])
print("WFST: ", wfst_txt[idx])

# 1. Get the original logits
single_trial_logits = torch.as_tensor(model_logits[f'arr_{idx}']).float().unsqueeze(0)

# 4. Perform beam search on the NOISY logits (and apply temperature)
beam_search_outs = decoder(single_trial_logits[0:1, :]*acoustic_scale)

for i in range(len(beam_search_outs[0])):
    beam_search_transcript = normalize_shorthand(" ".join(beam_search_outs[0][i].words).strip())
    print("Pytorch: ", beam_search_transcript)

Transcript:  you are frequently exploited
WFST:  you are physically exploited
Pytorch:  you are fickle exploited


In [10]:
ground_truth_arr = []
pred_arr = []
for idx in range(880):
    if idx % 100 == 0:
        print(idx)
    single_trial_logits = torch.as_tensor(model_logits[f'arr_{idx}']).float().unsqueeze(0)
    beam_search_char = decoder(single_trial_logits * acoustic_scale)
    beam_search_transcript_char = normalize_shorthand(" ".join(beam_search_char[0][0].words).strip())
    ground_truth_sentence = val_transcripts[idx]
    ground_truth_arr.append(ground_truth_sentence.replace("-", ""))
    pred_arr.append(beam_search_transcript_char)
    
cer, wer, wer_sent = _cer_and_wer(pred_arr, ground_truth_arr)
print(wer)

0
100
200
300
400
500
600
700
800
0.20294599018003273


In [19]:
for p, w, g in zip(pred_arr, wfst_txt, ground_truth_arr):
    print("Trial: ")
    print(g)
    print("\n")
    print(w)
    print("\n")
    print(p)
    print("\n")
    

Trial: 
theocracy reconsidered


they aka recanted


they aka recanted


Trial: 
rich purchased several signed lithographs


takes purchase ever signed lithographs


take purchase ever said lithograph


Trial: 
so rules we made, in unabashed collusion


so rules we made in in a vast collusion


so rules we made in innovate collusion


Trial: 
lori's costume needed black gloves to be completely elegant


troy system needed back loves to be completely ellen


try awesome needed back love to bb compel ellen


Trial: 
the tooth fairy forgot to come when roger's tooth fell out


the truth i forgot to come run watchers tooth fell out


the truth i forgot to come run wires to fail out


Trial: 
that stinging vapor was caused by chloride vaporization


that singing vapor power was so the kid paper edition


that singing may pao was so the kid paper i decision


Trial: 
before thursday's exam, review every formula


before this i am reveal over formula


before this i am reveal ever formless




In [145]:
lm_weights = [1, 1.5, 2, 2.5, 3]   
acoustic_score = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
word_penalty = [0, 0.1, -0.1, 0.2, -0.2, 0.3, -0.3]
sil_score = [0, -0.1, -0.2, -0.3]
beam_size = [100]

wer_dict = {}
wer_dict['sil'] = []
wer_dict['lmw'] = []
wer_dict['wp'] = []
wer_dict['bs'] = []
wer_dict['ac'] = []

wer_dict['wer'] = []

for sil in sil_score:
    for wp in word_penalty:
      for lmw in lm_weights:
          for bs in beam_size:
          
            decoder = ctc_decoder(tokens=units_txt_file_pytorch, lexicon=imagineville_vocab, 
                       beam_size=bs, nbest=1, lm="/data2/brain2text/lm/lm_dec19_huge_4gram.kenlm", 
                       lm_weight=lmw, word_score=wp, sil_score=sil, log_add=True, beam_threshold=1e3)

            for ac in acoustic_score:
              
              ground_truth_arr = []
              pred_arr = []
          
              for idx in range(880):
                  single_trial_logits = torch.as_tensor(model_logits[f'arr_{idx}']).float().unsqueeze(0)
                  beam_search_outs = decoder(single_trial_logits*ac)
                  beam_search_transcript = normalize_shorthand(" ".join(beam_search_outs[0][0].words).strip())
                  ground_truth_sentence = val_transcripts[idx]
                  ground_truth_arr.append(ground_truth_sentence)
                  pred_arr.append(beam_search_transcript)
                  
              cer, wer, wer_sent = _cer_and_wer(pred_arr, ground_truth_arr)
              
              
              wer_dict['sil'].append(sil)
              wer_dict['lmw'].append(lmw)
              wer_dict['wp'].append(wp)
              wer_dict['bs'].append(bs)
              wer_dict['ac'].append(ac)
              wer_dict['wer'].append(wer)
              
              
              print(f"sil score: {sil}, lm weight: {lmw}, word penalty: {wp}, beam size: {bs}, acoustic score: {ac},  wer: {wer}")


[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached limit: 6
[Trie] Trie label number reached l

KeyError: 'ac'