In [66]:
from pyctcdecode import build_ctcdecoder
import numpy as np

In [69]:
labels_file_path = "/data2/brain2text/lm/char_lm/units_pytorch_character.txt"
labels = []
with open(labels_file_path, 'r', encoding='utf-8') as file:
    for line in file:
        labels.append(line.strip()) # .strip() removes leading/trailing whitespace and newline characters


logits = np.load("/data2/brain2text/b2t_25/logits/pretrained_RNN/char_logits_val.npz")
logits_list = []          
for key in logits.files:                                                                                                            
    trial = logits[key]                                                                                                             
    trial_rearranged = np.concatenate([trial[:, 1:], trial[:, 0:1]], axis=1)                                                        
    logits_list.append(trial_rearranged)   
    
    
# 1. Remove the blank '-' from labels (first entry)                                                                                 
labels_without_blank = labels[1:]  # Now 32 labels                                                                                  
                                                  
                                                                                                                                    
# 3. Build decoder with the modified labels                                                                                         
decoder = build_ctcdecoder(                                                                                                         
    labels_without_blank,  # 32 labels, decoder will expect 33 columns (32 + blank)                                                 
    kenlm_model_path="/data2/brain2text/lm/lm_dec19_huge_4gram.kenlm",                                                              
    alpha=0.5,                                                                                                                      
    beta=1.0,                                                                                                                       
)                                                                                                                                   
                                                                                                                                    

Unigrams not provided and cannot be automatically determined from LM file (only arpa format). Decoding accuracy might be reduced.
No known unigrams provided, decoding results might be a lot worse.


In [None]:
texts = decoder.decode_batch(logits_list=logits_list, beam_width=100) 

TypeError: BeamSearchDecoderCTC.decode_batch() missing 1 required positional argument: 'logits_list'

In [64]:
def _greedy_decode(logits, labels):
    """Decode argmax of logits and squash in CTC fashion."""
    label_dict = {n: c for n, c in enumerate(labels)}
    prev_c = None
    out = []
    for n in logits.argmax(axis=1):
        c = label_dict.get(n, "")  # if not in labels, then assume it's ctc blank char
        if c != prev_c:
            out.append(c)
        prev_c = c
    return "".join(out)

In [65]:
_greedy_decode(logits_trial_0, labels=labels)

'-you|-can|-se-e|-the|-co-d|-at|-this|-point|-as|-wel-l|'

In [None]:
# pyctcdecode tutorial labels

labels = [
    " ", 
    "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", 
    "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", 
    "'",
]

print(len(labels))

28


In [None]:
# Fix: Remove blank from labels and move blank column to the end of logits
import numpy as np

# Load labels (skip the first one which is the blank token '-')
labels_file_path = "/data2/brain2text/lm/char_lm/units_pytorch_character.txt"
labels = []
with open(labels_file_path, 'r', encoding='utf-8') as file:
    for line in file:
        labels.append(line.strip())

# Remove the blank token from labels (pyctcdecode will add its own at the end)
labels_without_blank = labels[1:]  # Skip '-' at index 0
print(f"Original labels count: {len(labels)}")
print(f"Labels without blank: {len(labels_without_blank)}")
print(f"First few labels: {labels_without_blank[:5]}")

# Load logits
logits = np.load("/data2/brain2text/b2t_25/logits/pretrained_RNN/char_logits_val.npz")
logits_trial_0 = logits['arr_0']
print(f"Original logits shape: {logits_trial_0.shape}")

# Rearrange logits: move blank column (index 0) to the end
# New order: columns 1-32 (the actual characters), then column 0 (blank)
logits_rearranged = np.concatenate([logits_trial_0[:, 1:], logits_trial_0[:, 0:1]], axis=1)
print(f"Rearranged logits shape: {logits_rearranged.shape}")