In [45]:
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-xls-r-300m")
# processor = AutoProcessor.from_pretrained("hf-test/xls-r-300m-sv")

OSError: Can't load tokenizer for 'facebook/wav2vec2-xls-r-300m'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'facebook/wav2vec2-xls-r-300m' is the correct path to a directory containing all relevant files for a Wav2Vec2CTCTokenizer tokenizer.

In [61]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

In [64]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")

In [65]:
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_dict = {k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

In [66]:
sorted_vocab_dict

{'<pad>': 0,
 '<s>': 1,
 '</s>': 2,
 '<unk>': 3,
 '|': 4,
 'E': 5,
 'T': 6,
 'A': 7,
 'O': 8,
 'N': 9,
 'I': 10,
 'H': 11,
 'S': 12,
 'R': 13,
 'D': 14,
 'L': 15,
 'U': 16,
 'M': 17,
 'W': 18,
 'C': 19,
 'F': 20,
 'G': 21,
 'Y': 22,
 'P': 23,
 'B': 24,
 'V': 25,
 'K': 26,
 "'": 27,
 'X': 28,
 'J': 29,
 'Q': 30,
 'Z': 31}

In [10]:
import re
from collections import defaultdict
import torch
import kenlm
from transformers import Wav2Vec2Processor, AutoProcessor
from pyctcdecode import build_ctcdecoder
import numpy as np
import os

from typing import List, Tuple, Optional, Union

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CTCTextEncoder:
    def __init__(
        self,
        arpa_path: Optional[str] = None,
        binary_path: Optional[str] = None,
        unigram_path: Optional[str] = None,
        pretrained_tokenizer: str = "facebook/wav2vec2-base-960h",
        lm_weight: float = 0.5,
        beam_size: int = 100,
        blank_token: str = "<pad>",  # Blank token as <pad> for Wav2Vec2
        unk_token: str = "<unk>",     # UNK token
        **kwargs
    ):
        """
        Initialize encoder with Wav2Vec2 tokenizer and beam search decoder.

        Changes:
        - Ensure normalization is strictly lowercase with only [a-z ].
        """
        self.beam_size = beam_size
        self.lm_weight = lm_weight
        self.arpa_path = arpa_path
        self.binary_path = binary_path
        self.blank_token = blank_token
        self.unk_token = unk_token
        self.printed_samples = 0
        self.max_printed_samples = 5
        print('CTC Text Encoder:')
        print('pretrained_tokenizer:', pretrained_tokenizer)
        print('lm_weight:', lm_weight)
        print('beam_size:', beam_size)
        print('binary_path:', binary_path)

        # unigram_path = None

        # Load unigrams if provided
        
        self.unigrams = None
        
        # if unigram_path and os.path.exists(unigram_path):
        #     print(f"Loading unigrams from: {unigram_path}")
        #     with open(unigram_path, 'r', encoding='utf-8') as f:
        #         self.unigrams = [line.strip().lower() for line in f if line.strip()]
        #     print(f"Loaded {len(self.unigrams)} unigrams")
        
        if unigram_path and os.path.exists(unigram_path):
            print(f"Loading unigrams from: {unigram_path}")
            with open(unigram_path, 'r', encoding='utf-8') as f:
                self.unigrams = [line.strip().lower() for line in f if line.strip()]
            print(f"Loaded {len(self.unigrams)} unigrams")


        self._initialize_wav2vec2_tokenizer(pretrained_tokenizer)

        # Create index mappings
        self.ind2char = dict(enumerate(self.vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        self.blank_index  = self.char2ind[self.blank_token]

        print(f"\nVocabulary Info:")
        print(f"Size: {len(self.vocab)}")
        print("Full Vocabulary (up to first 50 tokens):", self.vocab[:50])
        print(f"Blank token: {self.blank_token}, Blank index: {self.blank_index}")

        print("Sample ind2char mappings:", {k: self.ind2char[k] for k in list(self.ind2char.keys())[:10]})
        print("Sample char2ind mappings:", {k: self.char2ind[k] for k in list(self.char2ind.keys())[:10]})

        self._initialize_language_model()

    # def _initialize_wav2vec2_tokenizer(self, pretrained_tokenizer: str):
    #     """Initialize vocabulary using Wav2Vec2 tokenizer."""
    #     self.processor = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)

    #     # Add the unique blank token if not present
    #     if self.blank_token not in self.processor.tokenizer.get_vocab():
    #         self.processor.tokenizer.add_tokens([self.blank_token])
    #         print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")
    #     else:
    #         print(f"'{self.blank_token}' already exists in the tokenizer's vocabulary.")

    #     # Add the UNK token if not present
    #     if self.unk_token not in self.processor.tokenizer.get_vocab():
    #         self.processor.tokenizer.add_tokens([self.unk_token])
    #         print(f"Added '{self.unk_token}' to the tokenizer's vocabulary.")
    #     else:
    #         print(f"'{self.unk_token}' already exists in the tokenizer's vocabulary.")

    #     # Get vocab, convert to lowercase and replace '|' with ' '
    #     original_vocab = list(self.processor.tokenizer.get_vocab().keys())
    #     self.vocab = 
    



    #     self.vocab = [t.replace('|', ' ') for t in self.vocab]

    #     # Debug: Print a few tokens after modification
    #     print("Modified Vocabulary (first 20 tokens):", self.vocab[:20])
    

    def _initialize_wav2vec2_tokenizer(self, pretrained_tokenizer: str):
        """Initialize vocabulary using Wav2Vec2 tokenizer."""
        # Initialize Wav2Vec2Processor
        # self.processor = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)

        

        # # Add the UNK token if not present
        # if self.unk_token not in self.processor.tokenizer.get_vocab():
        #     self.processor.tokenizer.add_tokens([self.unk_token])
        #     print(f"Added '{self.unk_token}' to the tokenizer's vocabulary.")
        # else:
        #     print(f"'{self.unk_token}' already exists in the tokenizer's vocabulary.")

        # # Get vocab without altering the case
        # self.vocab = list(self.processor.tokenizer.get_vocab().keys())

        # vocab_dict = processor.tokenizer.get_vocab()
        # sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
        # labels_adj = list(sorted_vocab_dict.keys())

        # self.vocab = [t.replace('|', ' ') for t in self.vocab]

        # # Debugging: Inspect the tokenizer's vocabulary casing
        # print("\n--- Tokenizer Vocabulary Inspection ---")
        # sample_size = 100  # Adjust as needed
        # sample_tokens = self.vocab[:sample_size]
        # print(f"First {sample_size} tokens in vocabulary:")
        # print(sample_tokens)

        # # Save the full tokenizer vocabulary to a file for comparison
        # with open("tokenizer_vocab.txt", "w", encoding="utf-8") as f:
        #     for token in self.vocab:
        #         f.write(f"{token}\n")
        # print("Full tokenizer vocabulary saved to 'tokenizer_vocab.txt'.")
        # print("----------------------------------------\n")
        

        from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") # 100h

        # self.processor = AutoProcessor.from_pretrained("hf-test/xls-r-300m-sv")

        vocab_dict = self.processor.tokenizer.get_vocab()
        # sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
        sorted_vocab_dict = {k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

        sorted_vocab_dict_for_labels= {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
        self.labels = list(sorted_vocab_dict_for_labels.keys())

        # Get vocab, convert to lowercase and replace '|' with ' '
        original_vocab = list(self.processor.tokenizer.get_vocab().keys())
        # self.vocab = [x.lower() for x in original_vocab]
        self.vocab = sorted_vocab_dict
        self.vocab = [t.replace('|', ' ') for t in self.vocab]

    
        # Add the unique blank token to the tokenizer's vocabulary if not present
        if self.blank_token not in self.processor.tokenizer.get_vocab():
            self.processor.tokenizer.add_tokens([self.blank_token])
            print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")
        else:
            print(f"'{self.blank_token}' already exists in the tokenizer's vocabulary.")

        # Debug: Print a few tokens after modification
        print("Modified Vocabulary:", self.vocab)


    def _initialize_language_model(self):
        """Initialize language model with explicit blank token handling."""
        self.lm = None
        self.decoder = None

        model_path = self.binary_path if self.binary_path else self.arpa_path
        print('model_path: ', model_path)
        if not model_path or not os.path.exists(model_path):
            print("No language model path provided or file does not exist.")
            return

        try:
            self.lm = kenlm.Model(model_path)
            print(f"Loaded {'binary' if self.binary_path else 'ARPA'} language model.")

            # labels = [self.blank_token] + [c for c in self.vocab if c != self.blank_token]

            decoder_config = {
                "labels": self.labels,
                "kenlm_model_path": model_path,
                "alpha": self.lm_weight,
                "beta": 0.1,
                "unk_score_offset": -10.0,
            }

            if self.unigrams:
                print("\n--- Unigrams List ---")
                # Save the unigrams to a file
                with open("unigrams_list.txt", "w", encoding="utf-8") as f:
                    for unigram in self.unigrams:
                        f.write(f"{unigram}\n")
                print(f"Unigrams list saved to 'unigrams_list.txt'. Total unigrams: {len(self.unigrams)}")
                print("----------------------\n")
                decoder_config["unigrams"] = self.unigrams

            self.decoder = build_ctcdecoder(**decoder_config)
            print("Successfully initialized language model and decoder.")

        except Exception as e:
            print(f"Warning: Failed to initialize decoder: {str(e)}")
            self.decoder = None


    def encode(self, text: str) -> torch.Tensor:
        """
        Encode text with Wav2Vec2 tokenizer.
        """
        debug = False

        if self.printed_samples < self.max_printed_samples:
            original_text = text
            text = self.normalize_text(text)
            if debug:
                print(f"samples: {str(self.printed_samples)}")
                print(f"\nEncoding text:")
                print(f" Original: '{original_text}'")
                print(f" Normalized: '{text}'")
                for ch in text:
                    print(ch, ord(ch))


        try:
            text = text.upper()
            encoded = self.processor.tokenizer(text, return_tensors="pt", padding=False, truncation=False)
            token_indices = encoded.input_ids[0].tolist()
            if self.printed_samples < self.max_printed_samples:
                # Convert indices to tokens from self.vocab
                tokens = [self.vocab[idx] if 0 <= idx < len(self.vocab) else "<invalid>" for idx in token_indices]
                # print(f" Tokens (lowercased and '|'->' '): {tokens}")
                # print(f" Token indices: {token_indices}")
                self.printed_samples += 1
            return torch.tensor(token_indices).unsqueeze(0)
        except KeyError as e:
            unknown_tokens = set([token for token in text.split() if token not in self.char2ind])
            raise Exception(f"Unknown tokens: '{' '.join(unknown_tokens)}'")
        except Exception as e:
            raise Exception(f"Encoding error: {str(e)}")

    # latest ver
    # def decode(self, indices: List[int]) -> str:
    #     """
    #     Decode indices to text using beam search decoder if available.
    #     """
    #     if self.decoder:
    #         decoded_text = self.decoder.decode(indices)
    #         # convert to lower case
    #         decoded_text = decoded_text.lower()
    #         return decoded_text
    #     else:
    #         decoded_text = self.decode_simple(indices)
    #         # convert to lower case
    #         decoded_text = decoded_text.lower()
    #         return
    
    def decode(self, indices: List[int]) -> str:
        """
        Decode indices to text using beam search decoder if available.
        """
        if self.decoder:
            decoded_text = self.decoder.decode(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text
        else:
            decoded_text = self.decode_simple(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text  # Ensure the decoded text is returned

    # latest ver
    # def decode_simple(self, indices: List[int]) -> str:
    #     """
    #     Simple CTC decoding without language model.
    #     """
    #     valid_indices = [
    #         idx for idx in indices
    #         if idx != self.blank_index and 0 <= idx < len(self.ind2char)
    #     ]
    #     try:
    #         tokens = [self.ind2char[idx] for idx in valid_indices]
    #         text = " ".join(tokens).strip().lower()
    #         return self.processor.tokenizer.clean_up_tokenization(text)
    #     except KeyError as e:
    #         return " ".join([self.ind2char[idx] for idx in valid_indices if idx in self.ind2char])

    # def decode_simple(self, indices: List[int]) -> str:
    #     """
    #     Simple CTC decoding without language model.
    #     Collapses consecutive duplicate tokens and removes blanks.
    #     """
    #     decoded_chars = []
    #     previous_idx = None


    #     print(indices)

    #     for idx in indices:
    #         if idx == self.blank_index:
    #             previous_idx = idx
    #             continue  # Skip blank tokens
    #         if idx == previous_idx:
    #             continue  # Skip duplicate tokens
    #         if 0 <= idx < len(self.ind2char):
    #             char = self.ind2char[idx]
    #             decoded_chars.append(char)
    #         previous_idx = idx

    #     # Join characters without spaces and convert to lowercase
    #     text = "".join(decoded_chars).strip().lower()

    #     # Clean up tokenization using the tokenizer's method
    #     return self.processor.tokenizer.clean_up_tokenization(text)


    def decode_simple(self, indices: List[int]) -> str:
        """
        Simple CTC decoding without language model.
        Collapses consecutive duplicate tokens and removes blanks.
        """
        decoded_chars = []
        previous_idx = None

        # Ensure indices is a list of scalars
        if isinstance(indices, (torch.Tensor, np.ndarray)):
            indices = indices.tolist()

        print("[DEBUG] Decoding Indices:", indices)

        for idx in indices:
            # Flatten or extract scalar from idx if it's a list/tensor
            if isinstance(idx, (list, torch.Tensor, np.ndarray)):
                idx = idx[0] if isinstance(idx, list) else idx.item()

            if not isinstance(idx, int):
                print(f"[DEBUG] Skipping non-integer index: {idx}")
                continue  # Skip invalid indices

            if idx == self.blank_index:
                previous_idx = idx
                continue  # Skip blank tokens

            if idx == previous_idx:
                continue  # Skip duplicate tokens

            if 0 <= idx < len(self.ind2char):
                decoded_chars.append(self.ind2char[idx])
            else:
                print(f"[DEBUG] Invalid index encountered: {idx}")

            previous_idx = idx

        # Join characters into a string
        decoded_text = "".join(decoded_chars).strip().lower()

        # Clean up tokenization if tokenizer is available
        if hasattr(self.processor, "tokenizer"):
            return self.processor.tokenizer.clean_up_tokenization(decoded_text)

        return decoded_text



    def decode_logits(self, logits: Union[torch.Tensor, List[List[float]], np.ndarray]) -> str:
        """
        Decode logits using the decoder if available, otherwise use greedy decoding.
        """
        if isinstance(logits, torch.Tensor):
            logits = logits.cpu().numpy()
        elif isinstance(logits, list):
            logits = np.array(logits)
        elif not isinstance(logits, np.ndarray):
            raise TypeError("logits must be a torch.Tensor, list of lists, or numpy.ndarray")

        if logits.ndim == 3:
            logits = logits[0]

        if logits.ndim != 2:
            raise ValueError(f"Logits should be 2D (time_steps, vocab_size), got {logits.ndim}D")

        if self.decoder:
            decoded_text = self.decoder.decode(logits)
            return decoded_text
        else:
            predicted_indices = np.argmax(logits, axis=-1).tolist()
            return self.decode_simple(predicted_indices)

    def decode_indices(self, indices: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Decode token indices to text using simple decoding (no LM).
        """
        if isinstance(indices, torch.Tensor):
            indices = indices.squeeze().tolist()
        elif isinstance(indices, np.ndarray):
            indices = indices.tolist()
        elif not isinstance(indices, list):
            raise TypeError("decode_indices expects a list, torch.Tensor, or numpy.ndarray.")

        return self.decode_simple(indices)


    # latest ver
    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Perform CTC decoding on logits.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 3:
            decoded_text = self.decode_logits(logits)
            return decoded_text.lower()
        elif logits.dim() == 2:
            decoded_text = self.decode_logits(logits)
            return decoded_text.lower()
        elif logits.dim() == 1:
            decoded_text = self.decode_indices(logits)
            return decoded_text.lower()
        else:
            raise ValueError(f"Unsupported logits shape: {logits.shape}. Expected 1D, 2D, or 3D.")

    def ctc_beam_search(self, probs, beam_size: int = 40,
                       use_lm: bool = False, debug: bool = False) -> List[Tuple[str, float]]:
        """
        Beam search with Wav2Vec2 support.
        """
        beam_size = self.beam_size
        debug = False

        if use_lm and self.decoder is not None:
            try:
                if isinstance(probs, torch.Tensor):
                    probs = probs.cpu().numpy()
                elif isinstance(probs, list):
                    probs = np.array(probs)
                elif isinstance(probs, np.ndarray):
                    pass
                else:
                    raise TypeError("probs must be a torch.Tensor, list, or numpy.ndarray")

                beams = self.decoder.decode_beams(
                    probs,
                    beam_prune_logp=-10.0,
                    token_min_logp=-5.0,
                    hotwords=[],
                    hotword_weight=10.0,
                )

                formatted_beams = []
                for beam in beams[:beam_size]:
                    text = beam[0]
                    acoustic_score = beam[3]
                    lm_score = beam[4]

                    text = self.processor.tokenizer.clean_up_tokenization(text)
                    text = text.lower().strip()

                    combined_score = (1 - self.lm_weight) * acoustic_score + self.lm_weight * lm_score
                    text_len = max(1, len(text.split()))
                    normalized_score = combined_score / text_len

                    formatted_beams.append((text, normalized_score))

                if debug:
                    print("\nFormatted beam results with Wav2Vec2:")
                    for text, score in formatted_beams[:3]:
                        print(f"Text: '{text}', Score: {score:.4f}")

                if formatted_beams:
                    return sorted(formatted_beams, key=lambda x: -x[1])
                else:
                    print("No valid beams found, falling back to standard beam search")
                    return self._standard_beam_search(probs, beam_size, debug)

            except Exception as e:
                print(f"Beam search with LM failed: {str(e)}, falling back to standard beam search")
                return self._standard_beam_search(probs, beam_size, debug)
        else:
            return self._standard_beam_search(probs, beam_size, debug)

    def _standard_beam_search(self, probs, beam_size: int = 10, debug: bool = False) -> List[Tuple[str, float]]:
        """Original beam search implementation with improved debugging"""
        beam_size = self.beam_size

        if isinstance(probs, np.ndarray):
            probs = torch.from_numpy(probs)

        if probs.device != torch.device('cpu'):
            probs = probs.cpu()

        dp = {("", self.blank_token): 0.0}
        log_probs = torch.log(probs + 1e-8)

        if debug:
            print("\nStarting beam search with beam size:", beam_size)

        for t, prob in enumerate(log_probs):
            new_dp = defaultdict(lambda: float('-inf'))
            top_k = torch.topk(prob, k=min(beam_size, len(prob)))

            if debug and t < self.max_printed_samples:
                print(f"\nTimestep {t}:")
                print("Top tokens:", [(self.ind2char[idx.item()], val.item()) 
                                    for val, idx in zip(top_k.values, top_k.indices)])

            for val, ind in zip(top_k.values, top_k.indices):
                curr_char = self.ind2char[ind.item()]
                next_token_log_prob = val.item()

                for (prefix, last_char), log_prob in dp.items():
                    if last_char == curr_char and curr_char != " ":
                        new_prefix = prefix
                    else:
                        if curr_char != self.blank_token:
                            if curr_char == " " and prefix.endswith(" "):
                                continue
                            new_prefix = prefix + curr_char
                        else:
                            new_prefix = prefix

                    new_log_prob = log_prob + next_token_log_prob
                    key = (new_prefix, curr_char)
                    new_dp[key] = max(new_dp[key], new_log_prob)

            if len(new_dp) > 0:
                max_score = max(score for _, score in new_dp.items())
                new_dp = {key: score - max_score for key, score in new_dp.items()}

            dp = dict(sorted(new_dp.items(), key=lambda x: -x[1])[:beam_size])

            if debug and t < 2:
                print("\nCurrent beam:")
                for (text, last_char), score in list(dp.items())[:3]:
                    print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")

        final_beams = []
        for (text, _), score in dp.items():
            text = self.processor.tokenizer.clean_up_tokenization(text)
            text = text.lower().strip()
            text_len = max(1, len(text.split()))
            normalized_score = score / text_len
            final_beams.append((text, normalized_score))

        final_beams.sort(key=lambda x: -x[1])
        if not final_beams:
            final_beams = [("", float('-inf'))]

        return final_beams[:beam_size]

    def test_language_model(self):
        """Debug function to verify LM functionality"""
        print("\nTesting Language Model...")

        if self.lm is None:
            print("Error: Language model is not loaded!")
            return

        test_sentences = [
            "this is a good sentence",
            "this is also a good sentence",
            "thiss iss nott aa goodd sentencee",
            "random word salad box cat",
            "the cat sat on the mat",
            "",
            "a",
        ]

        print("\nTesting individual sentences:")
        for sentence in test_sentences:
            score = self.score_with_lm(sentence)
            print(f"\nText: '{sentence}'")
            print(f"LM Score: {score:.4f}")

        test_prefixes = [
            "the quick brown",
            "how are",
            "thank",
            "nice to",
        ]

        print("\nTesting word completions:")
        for prefix in test_prefixes:
            print(f"\nPrefix: '{prefix}'")
            completions = [
                prefix + " " + word for word in ["you", "fox", "cat", "xyz", "meet"]
            ]
            scores = [(completion, self.score_with_lm(completion)) 
                    for completion in completions]
            scores.sort(key=lambda x: x[1], reverse=True)
            print("Top completions by score:")
            for completion, score in scores[:3]:
                print(f"  '{completion}': {score:.4f}")

    def score_with_lm(self, text: str) -> float:
        """
        Score text using language model, handling edge cases
        """
        if self.lm is None:
            return 0.0

        if not text or len(text.strip()) == 0:
            return float('-inf')

        text = text.lower().strip()
        return self.lm.score(text, bos=True, eos=True)

    def _basic_ctc_decode(self, logits: np.ndarray, sequence_length: int) -> List[str]:
        """Basic CTC decoding without LM"""
        argmax_indices = np.argmax(logits, axis=-1)

        if len(argmax_indices.shape) == 0:
            argmax_indices = np.array([argmax_indices])

        if len(argmax_indices.shape) == 1:
            argmax_indices = np.expand_dims(argmax_indices, axis=0)

        predictions = []
        for sequence in argmax_indices:
            decoded = []
            last_idx = None

            for idx in sequence[:sequence_length]:
                if idx != self.blank_index and idx != last_idx:
                    decoded.append(self.ind2char[idx])
                last_idx = idx

            text = "".join(decoded)
            if hasattr(self, 'processor'):
                text = self.processor.tokenizer.clean_up_tokenization(text)
            predictions.append(text)

        return predictions

    @staticmethod
    def normalize_text(text: str) -> str:
        """Normalize input text"""
        # text = text.lower()
        text = text.lower()
        text = re.sub(r"[^a-z ]", "", text)
        # text = re.sub(r"[^A-Z ]", "", text)
        return text


    def test_decoder(self, sample_text: str = "test decoder functionality"):
        """Test the decoder setup"""
        print("\nTesting decoder configuration...")

        encoded = self.encode(sample_text)
        decoded = self.decode(encoded[0].tolist())
        print(f"Original text: {sample_text}")
        print(f"Basic decode: {decoded}")

        sequence_length = 50
        vocab_size = len(self)
        fake_logits = torch.randn(1, sequence_length, vocab_size)
        fake_length = torch.tensor([sequence_length])

        if self.decoder is not None:
            print("\nTesting pyctcdecode integration...")
            decoded_with_lm = self.ctc_decode(fake_logits)
            print(f"Decoded with LM: {decoded_with_lm}")

            print(f"\nBeam width: {self.beam_size}")
            print(f"LM weight: {self.lm_weight}")
        else:
            print("\nNo language model loaded - using basic CTC decoding")
            basic_decoded = self._basic_ctc_decode(fake_logits.numpy(), fake_length)
            print(f"Basic CTC decoded: {basic_decoded[0]}")

    def __len__(self):
        return len(self.vocab)


In [8]:
%cd ..

/workspace/sound_asr


In [9]:
!ls -l

total 22930916
-rw-r--r--   1 root root 2393853947 Oct  3  2017 3-gram.arpa
-rw-r--r--   1 root root 1819258851 Dec 18 23:09 3-gram.bin
-rw-r--r--   1 root root 2393853947 Dec 18 22:43 3-gram_lc.arpa
-rw-r--r--   1 root root 1819258851 Dec 19 21:09 3-gram_lc.bin
-rw-r--r--   1 root root 4395628122 Oct  3  2017 4-gram.arpa
-rw-r--r--   1 root root 3124591979 Dec 18 23:08 4-gram.bin
-rw-r--r--   1 root root 4395628122 Dec 18 22:43 4-gram_lc.arpa
-rw-r--r--   1 root root 3124591979 Dec 19 21:08 4-gram_lc.bin
-rw-r--r--   1 root root       1070 Dec 18 22:31 LICENSE
-rwxr-xr-x   1 root root        743 Dec 19 21:06 LM_setup.sh
-rw-r--r--   1 root root       2527 Dec 18 23:09 README.md
-rw-r--r--   1 root root          0 Dec 18 22:31 __init__.py
drwxr-xr-x   3 root root         30 Dec 18 22:48 data
-rw-r--r--   1 root root       2217 Dec 18 22:31 inference.py
drwxr-xr-x   9 root root       4096 Dec 18 22:43 kenlm
-rw-r--r--   1 root root    1737588 Oct  3  2017 librispeech-vocab.txt
-rw-r--r-

In [20]:
with open("4-gram_lc.arpa", "r") as read_file, open("4-gram_lc_correct.arpa", "w") as write_file:
  has_added_eos = False
  for line in read_file:
    if not has_added_eos and "ngram 1=" in line:
      count=line.strip().split("=")[-1]
      write_file.write(line.replace(f"{count}", f"{int(count)+1}"))
    elif not has_added_eos and "<s>" in line:
      write_file.write(line)
      write_file.write(line.replace("<s>", "</s>"))
      has_added_eos = True
    else:
      write_file.write(line)

In [11]:
from pyctcdecode import build_ctcdecoder

decoder = build_ctcdecoder(
    labels=list(sorted_vocab_dict.keys()),
    kenlm_model_path="4-gram_correct_lc.arpa",
)

Loading the LM will be faster if you build a binary file.
Reading /workspace/sound_asr/4-gram_correct.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
Unigrams and labels don't seem to agree.


In [18]:
import os
unigram_path = 'librispeech-vocab.txt'

if unigram_path and os.path.exists(unigram_path):
    print(f"Loading unigrams from: {unigram_path}")
    with open(unigram_path, 'r', encoding='utf-8') as f:
        unigrams = [line.strip().lower() for line in f if line.strip()]
    print(f"Loaded {len(unigrams)} unigrams")

Loading unigrams from: librispeech-vocab.txt
Loaded 200000 unigrams


In [50]:
decoder = build_ctcdecoder(
    labels=list(sorted_vocab_dict.keys()),
    kenlm_model_path="4-gram_lc_correct.bin",
    unigrams=unigrams
)

Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?


In [51]:
labels = list(sorted_vocab_dict.keys())
labels

['<pad>',
 '<s>',
 '</s>',
 '<unk>',
 '|',
 'e',
 't',
 'a',
 'o',
 'n',
 'i',
 'h',
 's',
 'r',
 'd',
 'l',
 'u',
 'm',
 'w',
 'c',
 'f',
 'g',
 'y',
 'p',
 'b',
 'v',
 'k',
 "'",
 'x',
 'j',
 'q',
 'z']

In [17]:
unigrams

['A',
 "A''S",
 "A'BODY",
 "A'COURT",
 "A'D",
 "A'GHA",
 "A'GOIN",
 "A'LL",
 "A'M",
 "A'MIGHTY",
 "A'MIGHTY'S",
 "A'MOST",
 "A'N'T",
 "A'PENNY",
 "A'READY",
 "A'RIGHT",
 "A'RONY",
 "A'S",
 "A'TER",
 "A'TERNOON",
 "A'TERWARDS",
 "A'THEGITHER",
 "A'THING",
 "A'TIM",
 "A'VE",
 'AA',
 'AAANTHOR',
 'AACHEN',
 'AAD',
 'AAGE',
 "AAGE'S",
 'AAGOT',
 "AAGOT'S",
 'AAH',
 'AAHMES',
 'AAKRE',
 'AAL',
 'AALBOM',
 'AALST',
 'AAMASH',
 'AAN',
 'AANA',
 'AANY',
 'AAR',
 'AARAAF',
 'AARAU',
 'AARD',
 'AARON',
 "AARON'S",
 'AARONIC',
 'AARONS',
 'AARONSON',
 'AASA',
 'AASE',
 'AASVOGEL',
 'AASVOGELS',
 'AAT',
 'AB',
 "AB'S",
 'ABA',
 'ABAAT',
 "ABAB'DEH",
 'ABABA',
 'ABABDE',
 'ABABDEH',
 'ABACA',
 'ABACK',
 'ABACO',
 'ABACUS',
 'ABAD',
 'ABADAN',
 'ABADDON',
 'ABAFT',
 'ABAHT',
 'ABAJO',
 'ABALONE',
 'ABALONES',
 'ABANA',
 'ABANAZAR',
 'ABANDON',
 "ABANDON'D",
 'ABANDONED',
 'ABANDONEDLY',
 'ABANDONING',
 'ABANDONMENT',
 'ABANDONMENTS',
 'ABANDONS',
 'ABAOUT',
 'ABARA',
 'ABARAK',
 'ABARIA',
 'ABARIAN'

In [17]:
import re
from collections import defaultdict
import torch
import kenlm
from transformers import Wav2Vec2Processor, AutoProcessor
from pyctcdecode import build_ctcdecoder
import numpy as np
import os

from typing import List, Tuple, Optional, Union

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CTCTextEncoder:
    def __init__(
        self,
        arpa_path: Optional[str] = None,
        binary_path: Optional[str] = None,
        unigram_path: Optional[str] = None,
        pretrained_tokenizer: str = "facebook/wav2vec2-base-960h",
        lm_weight: float = 0.5,
        beam_size: int = 100,
        blank_token: str = "[pad]",  # Blank token as <pad> for Wav2Vec2
        unk_token: str = "[unk]",     # UNK token
        **kwargs
    ):
        """
        Initialize encoder with Wav2Vec2 tokenizer and beam search decoder.

        Changes:
        - Ensure normalization is strictly lowercase with only [a-z ].
        """
        self.beam_size = beam_size
        self.lm_weight = lm_weight
        self.arpa_path = arpa_path
        self.binary_path = binary_path
        self.blank_token = blank_token
        self.unk_token = unk_token
        self.printed_samples = 0
        self.max_printed_samples = 5
        print('CTC Text Encoder:')
        print('pretrained_tokenizer:', pretrained_tokenizer)
        print('lm_weight:', lm_weight)
        print('beam_size:', beam_size)
        print('binary_path:', binary_path)

        # unigram_path = None

        # Load unigrams if provided
        
        self.unigrams = None
        
        # if unigram_path and os.path.exists(unigram_path):
        #     print(f"Loading unigrams from: {unigram_path}")
        #     with open(unigram_path, 'r', encoding='utf-8') as f:
        #         self.unigrams = [line.strip().lower() for line in f if line.strip()]
        #     print(f"Loaded {len(self.unigrams)} unigrams")
        
        if unigram_path and os.path.exists(unigram_path):
            print(f"Loading unigrams from: {unigram_path}")
            with open(unigram_path, 'r', encoding='utf-8') as f:
                self.unigrams = [line.strip().lower() for line in f if line.strip()]
            print(f"Loaded {len(self.unigrams)} unigrams")


        self._initialize_wav2vec2_tokenizer(pretrained_tokenizer)

        # Create index mappings
        self.ind2char = dict(enumerate(self.vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        self.blank_index = self.char2ind[self.blank_token]

        print(f"\nVocabulary Info:")
        print(f"Size: {len(self.vocab)}")
        print("Full Vocabulary (up to first 50 tokens):", self.vocab[:50])
        print(f"Blank token: {self.blank_token}, Blank index: {self.blank_index}")

        print("Sample ind2char mappings:", {k: self.ind2char[k] for k in list(self.ind2char.keys())[:10]})
        print("Sample char2ind mappings:", {k: self.char2ind[k] for k in list(self.char2ind.keys())[:10]})

        self._initialize_language_model() # CHECKING WITHOUT

    # def _initialize_wav2vec2_tokenizer(self, pretrained_tokenizer: str):
    #     """Initialize vocabulary using Wav2Vec2 tokenizer."""
    #     self.processor = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)

    #     # Add the unique blank token if not present
    #     if self.blank_token not in self.processor.tokenizer.get_vocab():
    #         self.processor.tokenizer.add_tokens([self.blank_token])
    #         print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")
    #     else:
    #         print(f"'{self.blank_token}' already exists in the tokenizer's vocabulary.")

    #     # Add the UNK token if not present
    #     if self.unk_token not in self.processor.tokenizer.get_vocab():
    #         self.processor.tokenizer.add_tokens([self.unk_token])
    #         print(f"Added '{self.unk_token}' to the tokenizer's vocabulary.")
    #     else:
    #         print(f"'{self.unk_token}' already exists in the tokenizer's vocabulary.")

    #     # Get vocab, convert to lowercase and replace '|' with ' '
    #     original_vocab = list(self.processor.tokenizer.get_vocab().keys())
    #     self.vocab = [x.lower() for x in original_vocab]
    #     self.vocab = [t.replace('|', ' ') for t in self.vocab]

    #     # Debug: Print a few tokens after modification
    #     print("Modified Vocabulary (first 20 tokens):", self.vocab[:20])
    

    def _initialize_wav2vec2_tokenizer(self, pretrained_tokenizer: str):
        """Initialize vocabulary using Wav2Vec2 tokenizer."""
        # Initialize Wav2Vec2Processor
        # self.processor = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)

        

        # # Add the UNK token if not present
        # if self.unk_token not in self.processor.tokenizer.get_vocab():
        #     self.processor.tokenizer.add_tokens([self.unk_token])
        #     print(f"Added '{self.unk_token}' to the tokenizer's vocabulary.")
        # else:
        #     print(f"'{self.unk_token}' already exists in the tokenizer's vocabulary.")

        # # Get vocab without altering the case
        # self.vocab = list(self.processor.tokenizer.get_vocab().keys())

        # vocab_dict = processor.tokenizer.get_vocab()
        # sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
        # labels_adj = list(sorted_vocab_dict.keys())

        # self.vocab = [t.replace('|', ' ') for t in self.vocab]

        # # Debugging: Inspect the tokenizer's vocabulary casing
        # print("\n--- Tokenizer Vocabulary Inspection ---")
        # sample_size = 100  # Adjust as needed
        # sample_tokens = self.vocab[:sample_size]
        # print(f"First {sample_size} tokens in vocabulary:")
        # print(sample_tokens)

        # # Save the full tokenizer vocabulary to a file for comparison
        # with open("tokenizer_vocab.txt", "w", encoding="utf-8") as f:
        #     for token in self.vocab:
        #         f.write(f"{token}\n")
        # print("Full tokenizer vocabulary saved to 'tokenizer_vocab.txt'.")
        # print("----------------------------------------\n")

        self.processor = AutoProcessor.from_pretrained("hf-test/xls-r-300m-sv")

        vocab_dict = self.processor.tokenizer.get_vocab()
        sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

        self.labels = list(sorted_vocab_dict.keys())

        # Get vocab, convert to lowercase and replace '|' with ' '
        original_vocab = list(self.processor.tokenizer.get_vocab().keys())
        # self.vocab = [x.lower() for x in original_vocab]
        self.vocab = sorted_vocab_dict
        self.vocab = [t.replace('|', ' ') for t in self.vocab]

        # Add the unique blank token to the tokenizer's vocabulary if not present
        if self.blank_token not in self.processor.tokenizer.get_vocab():
            self.processor.tokenizer.add_tokens([self.blank_token])
            print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")
        else:
            print(f"'{self.blank_token}' already exists in the tokenizer's vocabulary.")

        # Debug: Print a few tokens after modification
        print("Modified Vocabulary (first 20 tokens):", self.vocab[:20])


    def _initialize_language_model(self):
        """Initialize language model with explicit blank token handling."""
        self.lm = None
        self.decoder = None

        model_path = self.binary_path if self.binary_path else self.arpa_path
        print('model_path: ', model_path)
        if not model_path or not os.path.exists(model_path):
            print("No language model path provided or file does not exist.")
            return

        try:
            self.lm = kenlm.Model(model_path)
            print(f"Loaded {'binary' if self.binary_path else 'ARPA'} language model.")

            # labels = [self.blank_token] + [c for c in self.vocab if c != self.blank_token]

            decoder_config = {
                "labels": self.labels,
                "kenlm_model_path": model_path,
                "alpha": self.lm_weight,
                "beta": 0.1,
                "unk_score_offset": -10.0,
            }

            if self.unigrams:
                print("\n--- Unigrams List ---")
                # Save the unigrams to a file
                with open("unigrams_list.txt", "w", encoding="utf-8") as f:
                    for unigram in self.unigrams:
                        f.write(f"{unigram}\n")
                print(f"Unigrams list saved to 'unigrams_list.txt'. Total unigrams: {len(self.unigrams)}")
                print("----------------------\n")
                decoder_config["unigrams"] = self.unigrams

            self.decoder = build_ctcdecoder(**decoder_config)
            print("Successfully initialized language model and decoder.")

        except Exception as e:
            print(f"Warning: Failed to initialize decoder: {str(e)}")
            self.decoder = None


    def encode(self, text: str) -> torch.Tensor:
        """
        Encode text with Wav2Vec2 tokenizer.
        """
        debug = False

        if self.printed_samples < self.max_printed_samples:
            original_text = text
            text = self.normalize_text(text)
            if debug:
                print(f"samples: {str(self.printed_samples)}")
                print(f"\nEncoding text:")
                print(f" Original: '{original_text}'")
                print(f" Normalized: '{text}'")
                for ch in text:
                    print(ch, ord(ch))


        try:
            encoded = self.processor.tokenizer(text, return_tensors="pt", padding=False, truncation=False)
            token_indices = encoded.input_ids[0].tolist()
            if self.printed_samples < self.max_printed_samples:
                # Convert indices to tokens from self.vocab
                tokens = [self.vocab[idx] if 0 <= idx < len(self.vocab) else "<invalid>" for idx in token_indices]
                # print(f" Tokens (lowercased and '|'->' '): {tokens}")
                # print(f" Token indices: {token_indices}")
                self.printed_samples += 1
            return torch.tensor(token_indices).unsqueeze(0)
        except KeyError as e:
            unknown_tokens = set([token for token in text.split() if token not in self.char2ind])
            raise Exception(f"Unknown tokens: '{' '.join(unknown_tokens)}'")
        except Exception as e:
            raise Exception(f"Encoding error: {str(e)}")

    # latest ver
    # def decode(self, indices: List[int]) -> str:
    #     """
    #     Decode indices to text using beam search decoder if available.
    #     """
    #     if self.decoder:
    #         decoded_text = self.decoder.decode(indices)
    #         # convert to lower case
    #         decoded_text = decoded_text.lower()
    #         return decoded_text
    #     else:
    #         decoded_text = self.decode_simple(indices)
    #         # convert to lower case
    #         decoded_text = decoded_text.lower()
    #         return
    
    def decode(self, indices: List[int]) -> str:
        """
        Decode indices to text using beam search decoder if available.
        """
        if self.decoder:
            decoded_text = self.decoder.decode(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text
        else:
            decoded_text = self.decode_simple(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text  # Ensure the decoded text is returned

    # latest ver
    # def decode_simple(self, indices: List[int]) -> str:
    #     """
    #     Simple CTC decoding without language model.
    #     """
    #     valid_indices = [
    #         idx for idx in indices
    #         if idx != self.blank_index and 0 <= idx < len(self.ind2char)
    #     ]
    #     try:
    #         tokens = [self.ind2char[idx] for idx in valid_indices]
    #         text = " ".join(tokens).strip().lower()
    #         return self.processor.tokenizer.clean_up_tokenization(text)
    #     except KeyError as e:
    #         return " ".join([self.ind2char[idx] for idx in valid_indices if idx in self.ind2char])

    def decode_simple(self, indices: List[int]) -> str:
        """
        Simple CTC decoding without language model.
        Collapses consecutive duplicate tokens and removes blanks.
        """
        decoded_chars = []
        previous_idx = None

        for idx in indices:
            if idx == self.blank_index:
                previous_idx = idx
                continue  # Skip blank tokens
            if idx == previous_idx:
                continue  # Skip duplicate tokens
            if 0 <= idx < len(self.ind2char):
                char = self.ind2char[idx]
                decoded_chars.append(char)
            previous_idx = idx

        # Join characters without spaces and convert to lowercase
        text = "".join(decoded_chars).strip().lower()

        # Clean up tokenization using the tokenizer's method
        return self.processor.tokenizer.clean_up_tokenization(text)


    def decode_logits(self, logits: Union[torch.Tensor, List[List[float]], np.ndarray]) -> str:
        """
        Decode logits using the decoder if available, otherwise use greedy decoding.
        """
        if isinstance(logits, torch.Tensor):
            logits = logits.cpu().numpy()
        elif isinstance(logits, list):
            logits = np.array(logits)
        elif not isinstance(logits, np.ndarray):
            raise TypeError("logits must be a torch.Tensor, list of lists, or numpy.ndarray")

        if logits.ndim == 3:
            logits = logits[0]

        if logits.ndim != 2:
            raise ValueError(f"Logits should be 2D (time_steps, vocab_size), got {logits.ndim}D")

        if self.decoder:
            decoded_text = self.decoder.decode(logits)
            return decoded_text
        else:
            predicted_indices = np.argmax(logits, axis=-1).tolist()
            return self.decode_simple(predicted_indices)

    def decode_indices(self, indices: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Decode token indices to text using simple decoding (no LM).
        """
        if isinstance(indices, torch.Tensor):
            indices = indices.squeeze().tolist()
        elif isinstance(indices, np.ndarray):
            indices = indices.tolist()
        elif not isinstance(indices, list):
            raise TypeError("decode_indices expects a list, torch.Tensor, or numpy.ndarray.")

        return self.decode_simple(indices)


    # latest ver
    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Perform CTC decoding on logits.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 3:
            decoded_text = self.decode_logits(logits)
            return decoded_text
        elif logits.dim() == 2:
            decoded_text = self.decode_logits(logits)
            return decoded_text
        elif logits.dim() == 1:
            decoded_text = self.decode_indices(logits)
            return decoded_text
        else:
            raise ValueError(f"Unsupported logits shape: {logits.shape}. Expected 1D, 2D, or 3D.")

    def ctc_beam_search(self, probs, beam_size: int = 40,
                       use_lm: bool = False, debug: bool = False) -> List[Tuple[str, float]]:
        """
        Beam search with Wav2Vec2 support.
        """
        beam_size = self.beam_size
        debug = False

        if use_lm and self.decoder is not None:
            try:
                if isinstance(probs, torch.Tensor):
                    probs = probs.cpu().numpy()
                elif isinstance(probs, list):
                    probs = np.array(probs)
                elif isinstance(probs, np.ndarray):
                    pass
                else:
                    raise TypeError("probs must be a torch.Tensor, list, or numpy.ndarray")

                beams = self.decoder.decode_beams(
                    probs,
                    beam_prune_logp=-10.0,
                    token_min_logp=-5.0,
                    hotwords=[],
                    hotword_weight=10.0,
                )

                formatted_beams = []
                for beam in beams[:beam_size]:
                    text = beam[0]
                    acoustic_score = beam[3]
                    lm_score = beam[4]

                    text = self.processor.tokenizer.clean_up_tokenization(text)
                    text = text.lower().strip()

                    combined_score = (1 - self.lm_weight) * acoustic_score + self.lm_weight * lm_score
                    text_len = max(1, len(text.split()))
                    normalized_score = combined_score / text_len

                    formatted_beams.append((text, normalized_score))

                if debug:
                    print("\nFormatted beam results with Wav2Vec2:")
                    for text, score in formatted_beams[:3]:
                        print(f"Text: '{text}', Score: {score:.4f}")

                if formatted_beams:
                    return sorted(formatted_beams, key=lambda x: -x[1])
                else:
                    print("No valid beams found, falling back to standard beam search")
                    return self._standard_beam_search(probs, beam_size, debug)

            except Exception as e:
                print(f"Beam search with LM failed: {str(e)}, falling back to standard beam search")
                return self._standard_beam_search(probs, beam_size, debug)
        else:
            return self._standard_beam_search(probs, beam_size, debug)

    def _standard_beam_search(self, probs, beam_size: int = 10, debug: bool = False) -> List[Tuple[str, float]]:
        """Original beam search implementation with improved debugging"""
        beam_size = self.beam_size

        if isinstance(probs, np.ndarray):
            probs = torch.from_numpy(probs)

        if probs.device != torch.device('cpu'):
            probs = probs.cpu()

        dp = {("", self.blank_token): 0.0}
        log_probs = torch.log(probs + 1e-8)

        if debug:
            print("\nStarting beam search with beam size:", beam_size)

        for t, prob in enumerate(log_probs):
            new_dp = defaultdict(lambda: float('-inf'))
            top_k = torch.topk(prob, k=min(beam_size, len(prob)))

            if debug and t < self.max_printed_samples:
                print(f"\nTimestep {t}:")
                print("Top tokens:", [(self.ind2char[idx.item()], val.item()) 
                                    for val, idx in zip(top_k.values, top_k.indices)])

            for val, ind in zip(top_k.values, top_k.indices):
                curr_char = self.ind2char[ind.item()]
                next_token_log_prob = val.item()

                for (prefix, last_char), log_prob in dp.items():
                    if last_char == curr_char and curr_char != " ":
                        new_prefix = prefix
                    else:
                        if curr_char != self.blank_token:
                            if curr_char == " " and prefix.endswith(" "):
                                continue
                            new_prefix = prefix + curr_char
                        else:
                            new_prefix = prefix

                    new_log_prob = log_prob + next_token_log_prob
                    key = (new_prefix, curr_char)
                    new_dp[key] = max(new_dp[key], new_log_prob)

            if len(new_dp) > 0:
                max_score = max(score for _, score in new_dp.items())
                new_dp = {key: score - max_score for key, score in new_dp.items()}

            dp = dict(sorted(new_dp.items(), key=lambda x: -x[1])[:beam_size])

            if debug and t < 2:
                print("\nCurrent beam:")
                for (text, last_char), score in list(dp.items())[:3]:
                    print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")

        final_beams = []
        for (text, _), score in dp.items():
            text = self.processor.tokenizer.clean_up_tokenization(text)
            text = text.lower().strip()
            text_len = max(1, len(text.split()))
            normalized_score = score / text_len
            final_beams.append((text, normalized_score))

        final_beams.sort(key=lambda x: -x[1])
        if not final_beams:
            final_beams = [("", float('-inf'))]

        return final_beams[:beam_size]

    def test_language_model(self):
        """Debug function to verify LM functionality"""
        print("\nTesting Language Model...")

        if self.lm is None:
            print("Error: Language model is not loaded!")
            return

        test_sentences = [
            "this is a good sentence",
            "this is also a good sentence",
            "thiss iss nott aa goodd sentencee",
            "random word salad box cat",
            "the cat sat on the mat",
            "",
            "a",
        ]

        print("\nTesting individual sentences:")
        for sentence in test_sentences:
            score = self.score_with_lm(sentence)
            print(f"\nText: '{sentence}'")
            print(f"LM Score: {score:.4f}")

        test_prefixes = [
            "the quick brown",
            "how are",
            "thank",
            "nice to",
        ]

        print("\nTesting word completions:")
        for prefix in test_prefixes:
            print(f"\nPrefix: '{prefix}'")
            completions = [
                prefix + " " + word for word in ["you", "fox", "cat", "xyz", "meet"]
            ]
            scores = [(completion, self.score_with_lm(completion)) 
                    for completion in completions]
            scores.sort(key=lambda x: x[1], reverse=True)
            print("Top completions by score:")
            for completion, score in scores[:3]:
                print(f"  '{completion}': {score:.4f}")

    def score_with_lm(self, text: str) -> float:
        """
        Score text using language model, handling edge cases
        """
        if self.lm is None:
            return 0.0

        if not text or len(text.strip()) == 0:
            return float('-inf')

        text = text.lower().strip()
        return self.lm.score(text, bos=True, eos=True)

    def _basic_ctc_decode(self, logits: np.ndarray, sequence_length: int) -> List[str]:
        """Basic CTC decoding without LM"""
        argmax_indices = np.argmax(logits, axis=-1)

        if len(argmax_indices.shape) == 0:
            argmax_indices = np.array([argmax_indices])

        if len(argmax_indices.shape) == 1:
            argmax_indices = np.expand_dims(argmax_indices, axis=0)

        predictions = []
        for sequence in argmax_indices:
            decoded = []
            last_idx = None

            for idx in sequence[:sequence_length]:
                if idx != self.blank_index and idx != last_idx:
                    decoded.append(self.ind2char[idx])
                last_idx = idx

            text = "".join(decoded)
            if hasattr(self, 'processor'):
                text = self.processor.tokenizer.clean_up_tokenization(text)
            predictions.append(text)

        return predictions

    @staticmethod
    def normalize_text(text: str) -> str:
        """Normalize input text"""
        # text = text.lower()
        text = text.lower()
        text = re.sub(r"[^a-z ]", "", text)
        # text = re.sub(r"[^A-Z ]", "", text)
        return text


    def test_decoder(self, sample_text: str = "test decoder functionality"):
        """Test the decoder setup"""
        print("\nTesting decoder configuration...")

        encoded = self.encode(sample_text)
        decoded = self.decode(encoded[0].tolist())
        print(f"Original text: {sample_text}")
        print(f"Basic decode: {decoded}")

        sequence_length = 50
        vocab_size = len(self)
        fake_logits = torch.randn(1, sequence_length, vocab_size)
        fake_length = torch.tensor([sequence_length])

        if self.decoder is not None:
            print("\nTesting pyctcdecode integration...")
            decoded_with_lm = self.ctc_decode(fake_logits)
            print(f"Decoded with LM: {decoded_with_lm}")

            print(f"\nBeam width: {self.beam_size}")
            print(f"LM weight: {self.lm_weight}")
        else:
            print("\nNo language model loaded - using basic CTC decoding")
            basic_decoded = self._basic_ctc_decode(fake_logits.numpy(), fake_length)
            print(f"Basic CTC decoded: {basic_decoded[0]}")

    def __len__(self):
        return len(self.vocab)


In [18]:
# Import necessary libraries
import torch
import numpy as np

# Initialize the CTCTextEncoder
encoder = CTCTextEncoder(
    arpa_path="4-gram_lc_correct.arpa",
    binary_path="4-gram_lc_correct.bin",
    unigram_path="librispeech-vocab.txt",
    use_bpe=True,  # Set to False for character-level encoding
    pretrained_tokenizer="bert-base-uncased",
    beam_size=10,
    lm_weight=0.35
)

# 1. Inspect Vocabulary and Mappings
print("=== Vocabulary Inspection ===")
print(f"Vocabulary Size: {len(encoder.vocab)}")
print("Sample ind2char Mapping:", {k: v for k, v in list(encoder.ind2char.items())[:10]})
print("Sample char2ind Mapping:", {k: v for k, v in list(encoder.char2ind.items())[:10]})

# Identify the correct space token
space_char_candidates = ["|", " ", "<space>"]  # Common space tokens
space_char = None
for candidate in space_char_candidates:
    if candidate in encoder.char2ind:
        space_char = candidate
        break

if space_char is None:
    print("Space token not found in char2ind mapping. Adding it now.")
    # Define what the space token should be, e.g., "|"
    space_char = "|"
    if space_char not in encoder.char2ind:
        # Add space token to the tokenizer
        encoder.processor.tokenizer.add_tokens([space_char])
        # Update vocab and mappings
        encoder.vocab.append(space_char)
        new_index = len(encoder.vocab) - 1
        encoder.char2ind[space_char] = new_index
        encoder.ind2char[new_index] = space_char
        print(f"Added space token '{space_char}' with index {new_index}.")

# Get space and blank indices
blank_char = encoder.blank_token
space_index = encoder.char2ind.get(space_char, None)
blank_index = encoder.char2ind.get(blank_char, None)

print(f"\nSpace Token ('{space_char}') Index: {space_index}")
print(f"Blank Token ('{blank_char}') Index: {blank_index}")

# Ensure that space and blank tokens exist
assert space_index is not None, f"Space token '{space_char}' not found in char2ind mapping."
assert blank_index is not None, f"Blank token '{blank_char}' not found in char2ind mapping."

# 2. Create Test Sequence
# Define the test phrase
test_phrase = "HELLO WORLD"  # Using '|' as space
print(f"\nTest Phrase: '{test_phrase}'")

# Encode the test phrase into token indices
encoded_tensor = encoder.encode(test_phrase)
encoded_indices = encoded_tensor[0].tolist()
print(f"Encoded Indices for '{test_phrase}': {encoded_indices}")

# Insert duplicates and blanks into the sequence
# Example: h, h, <pad>, e, e, <pad>, l, l, <pad>, etc.
test_sequence = []
for idx in encoded_indices:
    test_sequence.extend([idx, idx, blank_index])

# Add space with duplicates
test_sequence.extend([space_index, space_index, blank_index])

# Add the rest of the words similarly
additional_phrase = "world"
additional_indices = [encoder.char2ind.get(char, blank_index) for char in additional_phrase]
for idx in additional_indices:
    test_sequence.extend([idx, idx, blank_index])

print(f"\nTest Sequence with Duplicates and Blanks: {test_sequence}")

# 3. Decode Using decode_simple()
decoded_simple = encoder.decode_simple(test_sequence)
print(f"\nDecoded with decode_simple: '{decoded_simple}'")

# 4. Decode Using decode() (Beam Search with LM)
# Simulate logits where the correct tokens have high probabilities
vocab_size = len(encoder)
time_steps = len(test_sequence)
logits = np.full((time_steps, vocab_size), -10.0)  # Initialize with low logits

for t, idx in enumerate(test_sequence):
    if idx is not None and idx < vocab_size:
        logits[t, idx] = 10.0  # Assign high logit to the correct token

# Convert logits to torch tensor with shape (batch_size=1, time_steps, vocab_size)
logits_tensor = torch.tensor(logits).unsqueeze(0)  # Shape: (1, time_steps, vocab_size)

# Decode using decode (beam search with LM)
decoded_beam = encoder.decode(logits_tensor[0].numpy())  # Pass logits for single batch
print(f"\nDecoded with decode (beam search): '{decoded_beam}'")

# 5. Decode Using ctc_decode()
decoded_ctc = encoder.ctc_decode(logits_tensor)
print(f"\nDecoded with ctc_decode: '{decoded_ctc}'")

# 6. Assertions to Verify Correctness
expected_output = test_phrase.lower().replace("|", " ")
print(f"\nExpected Output: '{expected_output}'")

# Assert decode_simple
assert decoded_simple == expected_output, "decode_simple failed to correctly decode the phrase."

# Assert decode (beam search)
assert decoded_beam == expected_output, "decode (beam search) failed to correctly decode the phrase."

# Assert ctc_decode
assert decoded_ctc == expected_output, "ctc_decode failed to correctly decode the phrase."

print("\n✅ All decoding methods passed the unit test!")


CTC Text Encoder:
pretrained_tokenizer: bert-base-uncased
lm_weight: 0.35
beam_size: 10
binary_path: 4-gram_lc_correct.bin


Fetching 4 files: 100%|█| 4/4 [00:00<00:00, 45221


Added '[pad]' to the tokenizer's vocabulary.
Modified Vocabulary (first 20 tokens): [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's']

Vocabulary Info:
Size: 37
Full Vocabulary (up to first 50 tokens): [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'ä', 'å', 'é', 'ô', 'ö', 'ü', '[unk]', '[pad]', '<s>', '</s>']
Blank token: [pad], Blank index: 34
Sample ind2char mappings: {0: ' ', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i'}
Sample char2ind mappings: {' ': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9}
model_path:  4-gram_lc_correct.bin
No language model path provided or file does not exist.
=== Vocabulary Inspection ===
Vocabulary Size: 37
Sample ind2char Mapping: {0: ' ', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i'}
Sample char2ind Mapping: {' ': 0, 'a': 1, 'b': 2, '

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [11]:
# Import necessary libraries
import torch
import numpy as np

# Initialize the CTCTextEncoder
encoder = CTCTextEncoder(
    arpa_path="4-gram_lc_correct.arpa",
    binary_path="4-gram_lc_correct.bin",
    unigram_path="librispeech-vocab.txt",
    use_bpe=False,  # Ensure BPE is disabled for character-level encoding
    pretrained_tokenizer="bert-base-uncased",
    beam_size=10,
    lm_weight=0.35
)

# 1. Inspect Vocabulary and Mappings
print("=== Vocabulary Inspection ===")
print(f"Vocabulary Size: {len(encoder.vocab)}")
print("Sample ind2char Mapping:", {k: v for k, v in list(encoder.ind2char.items())[:10]})
print("Sample char2ind Mapping:", {k: v for k, v in list(encoder.char2ind.items())[:10]})

# Identify the correct space token
space_char_candidates = ["|", " ", "<space>"]  # Common space tokens
space_char = None
for candidate in space_char_candidates:
    if candidate in encoder.char2ind:
        space_char = candidate
        break

if space_char is None:
    print("Space token not found in char2ind mapping. Adding it now.")
    # Define what the space token should be, e.g., "|"
    space_char = "|"
    if space_char not in encoder.char2ind:
        # Add space token to the tokenizer
        encoder.processor.tokenizer.add_tokens([space_char])
        # Update vocab and mappings
        encoder.vocab.append(space_char)
        new_index = len(encoder.vocab) - 1
        encoder.char2ind[space_char] = new_index
        encoder.ind2char[new_index] = space_char
        print(f"Added space token '{space_char}' with index {new_index}.")

# Get space and blank indices
blank_char = encoder.blank_token
space_index = encoder.char2ind.get(space_char, None)
blank_index = encoder.char2ind.get(blank_char, None)

print(f"\nSpace Token ('{space_char}') Index: {space_index}")
print(f"Blank Token ('{blank_char}') Index: {blank_index}")

# Ensure that space and blank tokens exist
assert space_index is not None, f"Space token '{space_char}' not found in char2ind mapping."
assert blank_index is not None, f"Blank token '{blank_char}' not found in char2ind mapping."

# Ensure space token is correctly mapped to ' '
if encoder.ind2char.get(space_index) != " ":
    encoder.ind2char[space_index] = " "
    encoder.char2ind[" "] = space_index
    print(f"Updated space token mapping to ' ' for index {space_index}.")

# 2. Create Test Sequence
# Define the test phrase
# test_phrase = "hello world"  # Using ' ' as space
test_phrase = "in a moment he communicated his thoughts to his companions and in the next moment they resolved to turn back and carry her off to please rodolfo for the rich who are open handed always find parasites ready to encourage their bad propensities and thus to conceive this wicked design to communicate it approve it resolve on ravishing leocadia and to carry that design into effect was the work of a moment"
print(f"\nTest Phrase: '{test_phrase}'")

# Encode the test phrase into token indices
encoded_tensor = encoder.encode(test_phrase)
encoded_indices = encoded_tensor[0].tolist()
print(f"Encoded Indices for '{test_phrase}': {encoded_indices}")

# Insert duplicates and blanks into the sequence
# Example: h, h, <pad>, e, e, <pad>, l, l, <pad>, etc.
test_sequence = []
for idx in encoded_indices:
    test_sequence.extend([idx, idx, blank_index])

print(f"\nTest Sequence with Duplicates and Blanks: {test_sequence}")

# 3. Decode Using decode_simple_no_cleanup()
print("\n=== Decoding with decode_simple_no_cleanup() ===")

# Define a modified decode_simple method with debug prints
def decode_simple_no_cleanup(self, indices: List[int]) -> str:
    """
    Simple CTC decoding without language model.
    Collapses consecutive duplicate tokens and removes blanks.
    Includes debug prints to trace decoding steps.
    """
    decoded_chars = []
    previous_idx = None

    for idx in indices:
        if idx == self.blank_index:
            previous_idx = idx
            continue  # Skip blank tokens
        if idx == previous_idx:
            continue  # Skip duplicate tokens
        if 0 <= idx < len(self.ind2char):
            char = self.ind2char[idx]
            decoded_chars.append(char)
            print(f"Appended char: '{char}'")
        else:
            print(f"Skipped invalid index: {idx}")
        previous_idx = idx

    # Join characters without spaces and convert to lowercase
    text = "".join(decoded_chars).strip().lower()
    print(f"Text before cleanup: '{text}'")

    # Clean up tokenization using the tokenizer's method
    cleaned_text = self.processor.tokenizer.clean_up_tokenization(text)
    print(f"Text after cleanup: '{cleaned_text}'")

    return cleaned_text

# Bind the new method to the encoder instance
import types
encoder.decode_simple_no_cleanup = types.MethodType(decode_simple_no_cleanup, encoder)

# Perform decoding without cleanup
decoded_simple_no_cleanup = encoder.decode_simple_no_cleanup(test_sequence)
print(f"Decoded without cleanup: '{decoded_simple_no_cleanup}'")

# 4. Decode Using decode_simple()
print("\n=== Decoding with decode_simple() ===")
decoded_simple = encoder.decode_simple(test_sequence)
print(f"Decoded with decode_simple: '{decoded_simple}'")

# 5. Decode Using decode() (Beam Search with LM)
# Simulate logits where the correct tokens have high probabilities
vocab_size = len(encoder)
time_steps = len(test_sequence)
logits = np.full((time_steps, vocab_size), -10.0)  # Initialize with low logits

for t, idx in enumerate(test_sequence):
    if idx is not None and idx < vocab_size:
        logits[t, idx] = 10.0  # Assign high logit to the correct token

# Convert logits to torch tensor with shape (batch_size=1, time_steps, vocab_size)
logits_tensor = torch.tensor(logits).unsqueeze(0)  # Shape: (1, time_steps, vocab_size)

# Decode using decode (beam search with LM)
decoded_beam = encoder.decode(logits_tensor[0].numpy())  # Pass logits for single batch
print(f"\nDecoded with decode (beam search): '{decoded_beam}'")

# 6. Decode Using ctc_decode()
print("\n=== Decoding with ctc_decode() ===")
decoded_ctc = encoder.ctc_decode(logits_tensor)
print(f"Decoded with ctc_decode: '{decoded_ctc}'")

# 7. Assertions to Verify Correctness
expected_output = test_phrase.lower().replace("|", " ")
print(f"\nExpected Output: '{expected_output}'")

# Assert decode_simple
assert decoded_simple == expected_output, "decode_simple failed to correctly decode the phrase."

# Assert decode (beam search)
assert decoded_beam == expected_output, "decode (beam search) failed to correctly decode the phrase."

# Assert ctc_decode
assert decoded_ctc == expected_output, "ctc_decode failed to correctly decode the phrase."

print("\n✅ All decoding methods passed the unit test!")


CTC Text Encoder:
pretrained_tokenizer: bert-base-uncased
lm_weight: 0.35
beam_size: 10
binary_path: 4-gram_lc_correct.bin
'<pad>' already exists in the tokenizer's vocabulary.
Modified Vocabulary: ['<pad>', '<s>', '</s>', '<unk>', ' ', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']

Vocabulary Info:
Size: 32
Full Vocabulary (up to first 50 tokens): ['<pad>', '<s>', '</s>', '<unk>', ' ', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']
Blank token: <pad>, Blank index: 0
Sample ind2char mappings: {0: '<pad>', 1: '<s>', 2: '</s>', 3: '<unk>', 4: ' ', 5: 'E', 6: 'T', 7: 'A', 8: 'O', 9: 'N'}
Sample char2ind mappings: {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, ' ': 4, 'E': 5, 'T': 6, 'A': 7, 'O': 8, 'N': 9}
model_path:  4-gram_lc_correct.bin
No language model path provided or file does not exist.
=== Vocabulary Inspe

AssertionError: decode (beam search) failed to correctly decode the phrase.

In [10]:
# Import necessary libraries
import torch
import numpy as np

# Assuming CTCTextEncoder is already defined and imported
# from src.text_encoder import CTCTextEncoder

# Initialize the CTCTextEncoder
encoder = CTCTextEncoder(
    arpa_path="4-gram_lc_correct.arpa",
    binary_path="4-gram_lc_correct.bin",
    unigram_path="librispeech-vocab.txt",
    use_bpe=True,  # Ensure BPE is disabled for character-level encoding
    pretrained_tokenizer="bert-base-uncased",
    beam_size=10,
    lm_weight=0.35
)

# 1. Inspect Vocabulary and Mappings
print("=== Vocabulary Inspection ===")
print(f"Vocabulary Size: {len(encoder.vocab)}")
print("Sample ind2char Mapping:", {k: v for k, v in list(encoder.ind2char.items())[:10]})
print("Sample char2ind Mapping:", {v: k for k, v in list(encoder.char2ind.items())[:10]})

# Identify indices for space and blank token
space_char = " "
blank_char = encoder.blank_token
space_index = encoder.char2ind.get(space_char, None)
blank_index = encoder.char2ind.get(blank_char, None)

print(f"\nSpace Token ('{space_char}') Index: {space_index}")
print(f"Blank Token ('{blank_char}') Index: {blank_index}")

# Ensure that space and blank tokens exist
assert space_index is not None, f"Space token '{space_char}' not found in char2ind mapping."
assert blank_index is not None, f"Blank token '{blank_char}' not found in char2ind mapping."

# 2. Create Test Sequence
# Define the test phrase
test_phrase = "hello world"
print(f"\nTest Phrase: '{test_phrase}'")

# Encode the test phrase into token indices
encoded_tensor = encoder.encode(test_phrase)
encoded_indices = encoded_tensor[0].tolist()
print(f"Encoded Indices for '{test_phrase}': {encoded_indices}")

# Insert duplicates and blanks into the sequence
# Example: h, h, <pad>, e, e, <pad>, l, l, <pad>, etc.
test_sequence = []
for idx in encoded_indices:
    test_sequence.extend([idx, idx, blank_index])

print(f"\nTest Sequence with Duplicates and Blanks: {test_sequence}")

# 3. Decode Using decode_simple_no_cleanup()
print("\n=== Decoding with decode_simple_no_cleanup() ===")

# Define a modified decode_simple method with debug prints
def decode_simple_no_cleanup(self, indices: List[int]) -> str:
    """
    Simple CTC decoding without language model.
    Collapses consecutive duplicate tokens and removes blanks.
    Includes debug prints to trace decoding steps.
    """
    decoded_chars = []
    previous_idx = None

    for idx in indices:
        if idx == self.blank_index:
            previous_idx = idx
            continue  # Skip blank tokens
        if idx == previous_idx:
            continue  # Skip duplicate tokens
        if 0 <= idx < len(self.ind2char):
            char = self.ind2char[idx]
            decoded_chars.append(char)
            print(f"Appended char: '{char}'")
        else:
            print(f"Skipped invalid index: {idx}")
        previous_idx = idx

    # Join characters without spaces and convert to lowercase
    text = "".join(decoded_chars).strip().lower()
    print(f"Text before cleanup: '{text}'")

    # Clean up tokenization using the tokenizer's method
    cleaned_text = self.processor.tokenizer.clean_up_tokenization(text)
    print(f"Text after cleanup: '{cleaned_text}'")

    return cleaned_text

# Bind the new method to the encoder instance
import types
encoder.decode_simple_no_cleanup = types.MethodType(decode_simple_no_cleanup, encoder)

# Perform decoding without cleanup
decoded_simple_no_cleanup = encoder.decode_simple_no_cleanup(test_sequence)
print(f"Decoded without cleanup: '{decoded_simple_no_cleanup}'")

# 4. Decode Using decode_simple()
print("\n=== Decoding with decode_simple() ===")
decoded_simple = encoder.decode_simple(test_sequence)
print(f"Decoded with decode_simple: '{decoded_simple}'")

# 5. Decode Using decode() (Beam Search with LM)
# Simulate logits where the correct tokens have high probabilities
vocab_size = len(encoder)
time_steps = len(test_sequence)
logits = np.full((time_steps, vocab_size), -10.0)  # Initialize with low logits

for t, idx in enumerate(test_sequence):
    if idx is not None and idx < vocab_size:
        logits[t, idx] = 10.0  # Assign high logit to the correct token

# Convert logits to torch tensor with shape (batch_size=1, time_steps, vocab_size)
logits_tensor = torch.tensor(logits).unsqueeze(0)  # Shape: (1, time_steps, vocab_size)

print("\n=== Decoding with decode() (beam search) ===")
decoded_beam = encoder.decode(logits_tensor[0].numpy())  # Pass logits for single batch
print(f"Decoded with decode (beam search): '{decoded_beam}'")

# 6. Decode Using ctc_decode()
print("\n=== Decoding with ctc_decode() ===")
decoded_ctc = encoder.ctc_decode(logits_tensor)
print(f"Decoded with ctc_decode: '{decoded_ctc}'")

# 7. Assertions to Verify Correctness
expected_output = test_phrase.lower().replace("|", " ")
print(f"\nExpected Output: '{expected_output}'")

# Assert decode_simple
assert decoded_simple == expected_output, "decode_simple failed to correctly decode the phrase."

# Assert decode (beam search)
assert decoded_beam == expected_output, "decode (beam search) failed to correctly decode the phrase."

# Assert ctc_decode
assert decoded_ctc == expected_output, "ctc_decode failed to correctly decode the phrase."

print("\n✅ All decoding methods passed the unit test!")


CTC Text Encoder:
pretrained_tokenizer: bert-base-uncased
lm_weight: 0.35
beam_size: 10
binary_path: 4-gram_lc_correct.bin
'<pad>' already exists in the tokenizer's vocabulary.
Modified Vocabulary: ['<pad>', '<s>', '</s>', '<unk>', ' ', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']

Vocabulary Info:
Size: 32
Full Vocabulary (up to first 50 tokens): ['<pad>', '<s>', '</s>', '<unk>', ' ', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']
Blank token: <pad>, Blank index: 0
Sample ind2char mappings: {0: '<pad>', 1: '<s>', 2: '</s>', 3: '<unk>', 4: ' ', 5: 'E', 6: 'T', 7: 'A', 8: 'O', 9: 'N'}
Sample char2ind mappings: {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, ' ': 4, 'E': 5, 'T': 6, 'A': 7, 'O': 8, 'N': 9}
model_path:  4-gram_lc_correct.bin
No language model path provided or file does not exist.
=== Vocabulary Inspe

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [29]:
# Inspect the full char2ind and ind2char mappings
print("=== Full char2ind Mapping ===")
for char, idx in encoder.char2ind.items():
    print(f"'{char}': {idx}")

print("\n=== Full ind2char Mapping ===")
for idx, char in encoder.ind2char.items():
    print(f"{idx}: '{char}'")


=== Full char2ind Mapping ===
' ': 0
'a': 1
'b': 2
'c': 3
'd': 4
'e': 5
'f': 6
'g': 7
'h': 8
'i': 9
'j': 10
'k': 11
'l': 12
'm': 13
'n': 14
'o': 15
'p': 16
'q': 17
'r': 18
's': 19
't': 20
'u': 21
'v': 22
'w': 23
'x': 24
'y': 25
'z': 26
'ä': 27
'å': 28
'é': 29
'ô': 30
'ö': 31
'ü': 32
'[unk]': 33
'[pad]': 34
'<s>': 35
'</s>': 36

=== Full ind2char Mapping ===
0: ' '
1: 'a'
2: 'b'
3: 'c'
4: 'd'
5: 'e'
6: 'f'
7: 'g'
8: 'h'
9: 'i'
10: 'j'
11: 'k'
12: 'l'
13: 'm'
14: 'n'
15: 'o'
16: 'p'
17: 'q'
18: 'r'
19: 's'
20: 't'
21: 'u'
22: 'v'
23: 'w'
24: 'x'
25: 'y'
26: 'z'
27: 'ä'
28: 'å'
29: 'é'
30: 'ô'
31: 'ö'
32: 'ü'
33: '[unk]'
34: '[pad]'
35: '<s>'
36: '</s>'
