Tokenizer 21 Dec testing

In [37]:
import re
from collections import defaultdict
import torch
import kenlm
from transformers import Wav2Vec2Processor, AutoProcessor, AutoTokenizer
from pyctcdecode import build_ctcdecoder
import numpy as np
import os

from typing import List, Tuple, Optional, Union

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CTCTextEncoder:
    def __init__(
        self,
        arpa_path: Optional[str] = None,
        binary_path: Optional[str] = None,
        unigram_path: Optional[str] = None,
        pretrained_tokenizer: str = "facebook/wav2vec2-base-960h",
        lm_weight: float = 0.5,
        beam_size: int = 100,
        blank_token: str = "[pad]",  # Blank token as <pad> for Wav2Vec2
        unk_token: str = "[unk]",     # UNK token
        **kwargs
    ):
        """
        Build the vocabulary from a pretrained tokenizer and, when a language
        model file is available, a beam search decoder on top of it.

        Args:
            arpa_path: Path to an ARPA language model; used only when
                ``binary_path`` is not given.
            binary_path: Path to a binary KenLM model; preferred over
                ``arpa_path``.
            unigram_path: Optional text file with one unigram per line. The
                entries are lowercased and handed to the decoder.
            pretrained_tokenizer: Tokenizer identifier forwarded to
                ``_initialize_wav2vec2_tokenizer``.
            lm_weight: Language model weight stored on the instance.
            beam_size: Beam width stored on the instance.
            blank_token: CTC blank symbol; it must be present in the
                vocabulary produced by the tokenizer.
            unk_token: Unknown-token symbol; stored on the instance.
            **kwargs: Accepted and ignored.

        Raises:
            KeyError: If ``blank_token`` is not part of the vocabulary.
        """
        self.beam_size = beam_size
        self.lm_weight = lm_weight
        self.arpa_path = arpa_path
        self.binary_path = binary_path
        self.blank_token = blank_token
        self.unk_token = unk_token
        # Counters used by ``encode`` to limit per-sample debug output.
        self.printed_samples = 0
        self.max_printed_samples = 5
        print('CTC Text Encoder:')
        print('pretrained_tokenizer:', pretrained_tokenizer)
        print('lm_weight:', lm_weight)
        print('beam_size:', beam_size)
        print('binary_path:', binary_path)

        # Optional unigram list for the decoder (lowercased to match the labels).
        self.unigrams = None
        if unigram_path:
            if os.path.exists(unigram_path):
                print(f"Loading unigrams from: {unigram_path}")
                with open(unigram_path, 'r', encoding='utf-8') as f:
                    self.unigrams = [line.strip().lower() for line in f if line.strip()]
                print(f"Loaded {len(self.unigrams)} unigrams")
            else:
                # A mistyped path should be visible instead of silently
                # producing a decoder without unigrams.
                print(f"Warning: unigram file not found, continuing without unigrams: {unigram_path}")

        self._initialize_wav2vec2_tokenizer(pretrained_tokenizer)

        # Index <-> token lookup tables over the vocabulary.
        self.ind2char = dict(enumerate(self.vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        try:
            self.blank_index = self.char2ind[self.blank_token]
        except KeyError:
            raise KeyError(
                f"Blank token '{self.blank_token}' is not in the tokenizer vocabulary"
            ) from None

        print("\nVocabulary Info:")
        print(f"Size: {len(self.vocab)}")
        print("Full Vocabulary (up to first 50 tokens):", self.vocab[:50])
        print(f"Blank token: {self.blank_token}, Blank index: {self.blank_index}")

        print("Sample ind2char mappings:", dict(list(self.ind2char.items())[:10]))
        print("Sample char2ind mappings:", dict(list(self.char2ind.items())[:10]))

        self._initialize_language_model()

    # def _initialize_wav2vec2_tokenizer(self, pretrained_tokenizer: str):
    #     """Initialize vocabulary using Wav2Vec2 tokenizer."""
    #     self.processor = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)

    #     # Add the unique blank token if not present
    #     if self.blank_token not in self.tokenizer.get_vocab():
    #         self.processor.tokenizer.add_tokens([self.blank_token])
    #         print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")
    #     else:
    #         print(f"'{self.blank_token}' no language exists in the tokenizer's vocabulary.")

    #     # Add the UNK token if not present
    #     if self.unk_token not in self.processor.tokenizer.get_vocab():
    #         self.processor.tokenizer.add_tokens([self.unk_token])
    #         print(f"Added '{self.unk_token}' to the tokenizer's vocabulary.")
    #     else:
    #         print(f"'{self.unk_token}' no lm path exists in the tokenizer's vocabulary.")

    #     # Get vocab, convert to lowercase and replace '|' with ' '
    #     original_vocab = list(self.processor.tokenizer.get_vocab().keys())
    #     self.vocab = [x.lower() for x in original_vocab]
    #     self.vocab = [t.replace('|', ' ') for t in self.vocab]

    #     # Debug: Print a few tokens after modification
    #     print("Modified Vocabulary (first 20 tokens):", self.vocab[:20])
    

    def _initialize_wav2vec2_tokenizer(self, pretrained_tokenizer: str):
        """
        Load the processor and derive the label list and vocabulary from its tokenizer.

        Sets:
            self.processor: The loaded processor.
            self.tokenizer: The processor's tokenizer; the other methods of
                this class access the tokenizer through this attribute.
            self.labels: Lowercased tokens ordered by tokenizer index.
            self.vocab: ``self.labels`` with ``'|'`` replaced by a space.

        NOTE(review): the checkpoint below is hard-coded and
        ``pretrained_tokenizer`` is not used. Switching to the argument changes
        the vocabulary (and therefore which blank token is valid), so confirm
        the intended checkpoint before wiring the parameter through.
        """
        self.processor = AutoProcessor.from_pretrained("hf-test/xls-r-300m-sv")
        self.tokenizer = self.processor.tokenizer

        # Register the blank token BEFORE building the vocabulary so a newly
        # added token is part of ``self.vocab``. The comparison is done on
        # lowercased tokens because the vocabulary below is lowercased too;
        # a case-sensitive check would add a duplicate of an existing token.
        known_tokens = {token.lower() for token in self.tokenizer.get_vocab()}
        if self.blank_token.lower() not in known_tokens:
            self.tokenizer.add_tokens([self.blank_token])
            print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")
        else:
            print(f"'{self.blank_token}' already exists in the tokenizer's vocabulary.")

        # Order tokens by their index so list position == token id.
        # NOTE(review): lowercasing merges tokens that differ only by case,
        # which would shift positions; assumes the vocabulary has none.
        vocab_dict = self.tokenizer.get_vocab()
        tokens_by_index = sorted(vocab_dict.items(), key=lambda item: item[1])
        sorted_vocab_dict = {token.lower(): index for token, index in tokens_by_index}

        self.labels = list(sorted_vocab_dict)

        # The vocabulary used for index <-> text mapping shows the word
        # delimiter as a plain space.
        self.vocab = [token.replace('|', ' ') for token in self.labels]

        # Debug: Print a few tokens after modification
        print("Modified Vocabulary (first 20 tokens):", self.vocab[:20])


    def _initialize_language_model(self):
        """Initialize language model with explicit blank token handling."""
        self.lm = None
        self.decoder = None

        model_path = self.binary_path if self.binary_path else self.arpa_path
        print('model_path: ', model_path)
        if not model_path or not os.path.exists(model_path):
            print("No language model path provided or file does not exist.")
            return

        try:
            self.lm = kenlm.Model(model_path)
            print(f"Loaded {'binary' if self.binary_path else 'ARPA'} language model.")

            # labels = [self.blank_token] + [c for c in self.vocab if c != self.blank_token]

            decoder_config = {
                "labels": self.labels,
                "kenlm_model_path": model_path,
                "alpha": self.lm_weight,
                "beta": 0.1,
                "unk_score_offset": -10.0,
            }

            if self.unigrams:
                print("\n--- Unigrams List ---")
                # Save the unigrams to a file
                with open("unigrams_list.txt", "w", encoding="utf-8") as f:
                    for unigram in self.unigrams:
                        f.write(f"{unigram}\n")
                print(f"Unigrams list saved to 'unigrams_list.txt'. Total unigrams: {len(self.unigrams)}")
                print("----------------------\n")
                decoder_config["unigrams"] = self.unigrams

            self.decoder = build_ctcdecoder(**decoder_config)
            print("Successfully initialized language model and decoder.")

        except Exception as e:
            print(f"Warning: Failed to initialize decoder: {str(e)}")
            self.decoder = None


    def encode(self, text: str) -> torch.Tensor:
        """
        Encode text with Wav2Vec2 tokenizer.
        """
        debug = False

        if self.printed_samples < self.max_printed_samples:
            original_text = text
            text = self.normalize_text(text)
            if debug:
                print(f"samples: {str(self.printed_samples)}")
                print(f"\nEncoding text:")
                print(f" Original: '{original_text}'")
                print(f" Normalized: '{text}'")
                for ch in text:
                    print(ch, ord(ch))


        try:
            encoded = self.tokenizer(text, return_tensors="pt", padding=False, truncation=False)
            token_indices = encoded.input_ids[0].tolist()
            if self.printed_samples < self.max_printed_samples:
                # Convert indices to tokens from self.vocab
                tokens = [self.vocab[idx] if 0 <= idx < len(self.vocab) else "<invalid>" for idx in token_indices]
                # print(f" Tokens (lowercased and '|'->' '): {tokens}")
                # print(f" Token indices: {token_indices}")
                self.printed_samples += 1
            return torch.tensor(token_indices).unsqueeze(0)
        except KeyError as e:
            unknown_tokens = set([token for token in text.split() if token not in self.char2ind])
            raise Exception(f"Unknown tokens: '{' '.join(unknown_tokens)}'")
        except Exception as e:
            raise Exception(f"Encoding error: {str(e)}")

    # latest ver
    # def decode(self, indices: List[int]) -> str:
    #     """
    #     Decode indices to text using beam search decoder if available.
    #     """
    #     if self.decoder:
    #         ctc_decode = self.decoder.decode(indices)
    #         # convert to lower case
    #         decoded_text = decoded_text.lower()
    #         return decoded_text
    #     else:
    #         decoded_text = self.decode_simple(indices)
    #         # convert to lower case
    #         decoded_text = decoded_text.lower()
    #         return
    
    def decode(self, indices: List[int]) -> str:
        """
        Decode indices to text using beam search decoder if available.
        """
        if self.decoder:
            decoded_text = self.decoder.decode(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text
        else:
            decoded_text = self.decode_simple(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text  # Ensure the decoded text is returned

    # latest ver
    # def decode_simple(self, indices: List[int]) -> str:
    #     """
    #     Simple CTC decoding without language model.
    #     """
    #     valid_indices = [
    #         idx for idx in indices
    #         if idx != self.blank_index and 0 <= idx < len(self.ind2char)
    #     ]
    #     try:
    #         tokens = [self.ind2char[idx] for idx in valid_indices]
    #         text = " ".join(tokens).strip().lower()
    #         return self.processor.tokenizer.clean_up_tokenization(text)
    #     except KeyError as e:
    #         return " ".join([self.ind2char[idx] for idx in valid_indices if idx in self.ind2char])

    def decode_simple(self, indices: List[int]) -> str:
        """
        Simple CTC decoding without language model.
        Collapses consecutive duplicate tokens and removes blanks.
        """
        decoded_chars = []

        for idx in indices:
            if idx == self.blank_index:
                continue  # Skip blank tokens
            if 0 <= idx < len(self.ind2char):
                char = self.ind2char[idx]
                decoded_chars.append(char)

        # Join characters without collapsing duplicates and convert to lowercase
        text = "".join(decoded_chars).strip().lower()

        # Clean up tokenization using the tokenizer's method (if available)
        return text


    def decode_logits(self, logits: Union[torch.Tensor, List[List[float]], np.ndarray]) -> str:
        """
        Decode logits using the decoder if available, otherwise use greedy decoding.
        """
        if isinstance(logits, torch.Tensor):
            logits = logits.cpu().numpy()
        elif isinstance(logits, list):
            logits = np.array(logits)
        elif not isinstance(logits, np.ndarray):
            raise TypeError("logits must be a torch.Tensor, list of lists, or numpy.ndarray")

        if logits.ndim == 3:
            logits = logits[0]

        if logits.ndim != 2:
            raise ValueError(f"Logits should be 2D (time_steps, vocab_size), got {logits.ndim}D")

        if self.decoder:
            decoded_text = self.decoder.decode(logits)
            return decoded_text
        else:
            predicted_indices = np.argmax(logits, axis=-1).tolist()
            return self.decode_simple(predicted_indices)

    def decode_indices(self, indices: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Decode token indices to text using simple decoding (no LM).
        """
        if isinstance(indices, torch.Tensor):
            indices = indices.squeeze().tolist()
        elif isinstance(indices, np.ndarray):
            indices = indices.tolist()
        elif not isinstance(indices, list):
            raise TypeError("decode_indices expects a list, torch.Tensor, or numpy.ndarray.")

        return self.decode_simple(indices)


    # latest ver
    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Perform CTC decoding on logits.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 3:
            decoded_text = self.decode_logits(logits)
            return decoded_text
        elif logits.dim() == 2:
            decoded_text = self.decode_logits(logits)
            return decoded_text
        elif logits.dim() == 1:
            decoded_text = self.decode_indices(logits)
            return decoded_text
        else:
            raise ValueError(f"Unsupported logits shape: {logits.shape}. Expected 1D, 2D, or 3D.")

    def ctc_beam_search(self, probs, beam_size: int = 40,
                       use_lm: bool = False, debug: bool = False) -> List[Tuple[str, float]]:
        """
        Beam search with Wav2Vec2 support.
        """
        beam_size = self.beam_size
        debug = False

        if use_lm and self.decoder is not None:
            try:
                if isinstance(probs, torch.Tensor):
                    probs = probs.cpu().numpy()
                elif isinstance(probs, list):
                    probs = np.array(probs)
                elif isinstance(probs, np.ndarray):
                    pass
                else:
                    raise TypeError("probs must be a torch.Tensor, list, or numpy.ndarray")

                beams = self.decoder.decode_beams(
                    probs,
                    beam_prune_logp=-10.0,
                    token_min_logp=-5.0,
                    hotwords=[],
                    hotword_weight=10.0,
                )

                formatted_beams = []
                for beam in beams[:beam_size]:
                    text = beam[0]
                    acoustic_score = beam[3]
                    lm_score = beam[4]

                    text = self.tokenizer.clean_up_tokenization(text)
                    text = text.lower().strip()

                    combined_score = (1 - self.lm_weight) * acoustic_score + self.lm_weight * lm_score
                    text_len = max(1, len(text.split()))
                    normalized_score = combined_score / text_len

                    formatted_beams.append((text, normalized_score))

                if debug:
                    print("\nFormatted beam results with Wav2Vec2:")
                    for text, score in formatted_beams[:3]:
                        print(f"Text: '{text}', Score: {score:.4f}")

                if formatted_beams:
                    return sorted(formatted_beams, key=lambda x: -x[1])
                else:
                    print("No valid beams found, falling back to standard beam search")
                    return self._standard_beam_search(probs, beam_size, debug)

            except Exception as e:
                print(f"Beam search with LM failed: {str(e)}, falling back to standard beam search")
                return self._standard_beam_search(probs, beam_size, debug)
        else:
            return self._standard_beam_search(probs, beam_size, debug)

    def _standard_beam_search(self, probs, beam_size: int = 10, debug: bool = False) -> List[Tuple[str, float]]:
        """Original beam search implementation with improved debugging"""
        beam_size = self.beam_size

        if isinstance(probs, np.ndarray):
            probs = torch.from_numpy(probs)

        if probs.device != torch.device('cpu'):
            probs = probs.cpu()

        dp = {("", self.blank_token): 0.0}
        log_probs = torch.log(probs + 1e-8)

        if debug:
            print("\nStarting beam search with beam size:", beam_size)

        for t, prob in enumerate(log_probs):
            new_dp = defaultdict(lambda: float('-inf'))
            top_k = torch.topk(prob, k=min(beam_size, len(prob)))

            if debug and t < self.max_printed_samples:
                print(f"\nTimestep {t}:")
                print("Top tokens:", [(self.ind2char[idx.item()], val.item()) 
                                    for val, idx in zip(top_k.values, top_k.indices)])

            for val, ind in zip(top_k.values, top_k.indices):
                curr_char = self.ind2char[ind.item()]
                next_token_log_prob = val.item()

                for (prefix, last_char), log_prob in dp.items():
                    if last_char == curr_char and curr_char != " ":
                        new_prefix = prefix
                    else:
                        if curr_char != self.blank_token:
                            if curr_char == " " and prefix.endswith(" "):
                                continue
                            new_prefix = prefix + curr_char
                        else:
                            new_prefix = prefix

                    new_log_prob = log_prob + next_token_log_prob
                    key = (new_prefix, curr_char)
                    new_dp[key] = max(new_dp[key], new_log_prob)

            if len(new_dp) > 0:
                max_score = max(score for _, score in new_dp.items())
                new_dp = {key: score - max_score for key, score in new_dp.items()}

            dp = dict(sorted(new_dp.items(), key=lambda x: -x[1])[:beam_size])

            if debug and t < 2:
                print("\nCurrent beam:")
                for (text, last_char), score in list(dp.items())[:3]:
                    print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")

        final_beams = []
        for (text, _), score in dp.items():
            text = self.tokenizer.clean_up_tokenization(text)
            text = text.lower().strip()
            text_len = max(1, len(text.split()))
            normalized_score = score / text_len
            final_beams.append((text, normalized_score))

        final_beams.sort(key=lambda x: -x[1])
        if not final_beams:
            final_beams = [("", float('-inf'))]

        return final_beams[:beam_size]

    def test_language_model(self):
        """Debug function to verify LM functionality"""
        print("\nTesting Language Model...")

        if self.lm is None:
            print("Error: Language model is not loaded!")
            return

        test_sentences = [
            "this is a good sentence",
            "this is also a good sentence",
            "thiss iss nott aa goodd sentencee",
            "random word salad box cat",
            "the cat sat on the mat",
            "",
            "a",
        ]

        print("\nTesting individual sentences:")
        for sentence in test_sentences:
            score = self.score_with_lm(sentence)
            print(f"\nText: '{sentence}'")
            print(f"LM Score: {score:.4f}")

        test_prefixes = [
            "the quick brown",
            "how are",
            "thank",
            "nice to",
        ]

        print("\nTesting word completions:")
        for prefix in test_prefixes:
            print(f"\nPrefix: '{prefix}'")
            completions = [
                prefix + " " + word for word in ["you", "fox", "cat", "xyz", "meet"]
            ]
            scores = [(completion, self.score_with_lm(completion)) 
                    for completion in completions]
            scores.sort(key=lambda x: x[1], reverse=True)
            print("Top completions by score:")
            for completion, score in scores[:3]:
                print(f"  '{completion}': {score:.4f}")

    def score_with_lm(self, text: str) -> float:
        """
        Score text using language model, handling edge cases
        """
        if self.lm is None:
            return 0.0

        if not text or len(text.strip()) == 0:
            return float('-inf')

        text = text.lower().strip()
        return self.lm.score(text, bos=True, eos=True)

    def _basic_ctc_decode(self, logits: np.ndarray, sequence_length: int) -> List[str]:
        """Basic CTC decoding without LM"""
        argmax_indices = np.argmax(logits, axis=-1)

        if len(argmax_indices.shape) == 0:
            argmax_indices = np.array([argmax_indices])

        if len(argmax_indices.shape) == 1:
            argmax_indices = np.expand_dims(argmax_indices, axis=0)

        predictions = []
        for sequence in argmax_indices:
            decoded = []
            last_idx = None

            for idx in sequence[:sequence_length]:
                if idx != self.blank_index and idx != last_idx:
                    decoded.append(self.ind2char[idx])
                last_idx = idx

            text = "".join(decoded)
            if hasattr(self, 'processor') or self.use_bpe:
                text = self.tokenizer.clean_up_tokenization(text)
            predictions.append(text)

        return predictions

    @staticmethod
    def normalize_text(text: str) -> str:
        """Normalize input text"""
        # text = text.lower()
        text = text.lower()
        text = re.sub(r"[^a-z ]", "", text)
        # text = re.sub(r"[^A-Z ]", "", text)
        return text


    def test_decoder(self, sample_text: str = "test decoder functionality"):
        """Test the decoder setup"""
        print("\nTesting decoder configuration...")

        encoded = self.encode(sample_text)
        decoded = self.decode(encoded[0].tolist())
        print(f"Original text: {sample_text}")
        print(f"Basic decode: {decoded}")

        sequence_length = 50
        vocab_size = len(self)
        fake_logits = torch.randn(1, sequence_length, vocab_size)
        fake_length = torch.tensor([sequence_length])

        if self.decoder is not None:
            print("\nTesting pyctcdecode integration...")
            decoded_with_lm = self.ctc_decode(fake_logits)
            print(f"Decoded with LM: {decoded_with_lm}")

            print(f"\nBeam width: {self.beam_size}")
            print(f"LM weight: {self.lm_weight}")
        else:
            print("\nNo language model loaded - using basic CTC decoding")
            basic_decoded = self._basic_ctc_decode(fake_logits.numpy(), fake_length)
            print(f"Basic CTC decoded: {basic_decoded[0]}")

    def __len__(self):
        return len(self.vocab)

In [30]:
# Notebook smoke test: build an encoder without a language model, round-trip
# a short sentence and decode random logits.
# Initialize the encoder
# NOTE(review): the recorded output below lists Swedish characters in the
# vocabulary, so the checkpoint actually loaded does not appear to be the one
# named by `pretrained_tokenizer` - confirm which class revision ran this cell.
encoder = CTCTextEncoder(
    pretrained_tokenizer="facebook/wav2vec2-base-960h",
    blank_token="[pad]",
    unk_token="[unk]"
)

# Define test text
test_text = "hello world"

# Test encoding
encoded = encoder.encode(test_text)
print(f"Original text: {test_text}")
print(f"Encoded: {encoded}")

# Verify the structure of the encoded tensor
print(f"Encoded tensor shape: {encoded.shape}")

# Correct sequence length extraction
sequence_length = encoded.size(1)  # Number of time steps in the encoded tensor
vocab_size = len(encoder.vocab)

# Create dummy logits for decoding
# The logits are random, so the text decoded from them is expected to be
# meaningless; this only checks that decoding runs end to end.
logits = torch.randn(1, sequence_length, vocab_size)

# Decode logits directly
decoded_logits = encoder.decode_logits(logits)
print(f"Decoded from logits: {decoded_logits}")

# Test decoding from indices
# Round trip: the encoded indices should decode back to the input text.
decoded = encoder.decode(encoded[0].tolist())
print(f"Decoded: {decoded}")

# Test simple decoding
decoded_simple = encoder.decode_simple(encoded[0].tolist())
print(f"Simple decoded: {decoded_simple}")

# Check vocabulary alignment
# Print the first entries of both lookup tables for a visual check.
print("\nVocabulary and mappings check:")
print(f"Sample ind2char mappings: {list(encoder.ind2char.items())[:10]}")
print(f"Sample char2ind mappings: {list(encoder.char2ind.items())[:10]}")


CTC Text Encoder:
pretrained_tokenizer: facebook/wav2vec2-base-960h
lm_weight: 0.5
beam_size: 100
binary_path: None


Fetching 4 files: 100%|█| 4/4 [00:00<00:00, 38746


Added '[pad]' to the tokenizer's vocabulary.
Modified Vocabulary (first 20 tokens): [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's']

Vocabulary Info:
Size: 37
Full Vocabulary (up to first 50 tokens): [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'ä', 'å', 'é', 'ô', 'ö', 'ü', '[unk]', '[pad]', '<s>', '</s>']
Blank token: [pad], Blank index: 34
Sample ind2char mappings: {0: ' ', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i'}
Sample char2ind mappings: {' ': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9}
model_path:  None
No language model path provided or file does not exist.
Original text: hello world
Encoded: tensor([[ 8,  5, 12, 12, 15,  0, 23, 15, 18, 12,  4]])
Encoded tensor shape: torch.Size([1, 11])
Decoded from logits: dbf</s>åévfüo
Decoded: hello world
Simple decoded: hello world

Vocabu

Updating class for tokenizer

In [14]:
import re
from collections import defaultdict
import torch
import kenlm
from transformers import Wav2Vec2Processor, AutoProcessor
from pyctcdecode import build_ctcdecoder
import numpy as np
import os
from string import ascii_lowercase

from typing import List, Tuple, Optional, Union

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CTCTextEncoder:
    def __init__(
        self,
        arpa_path: Optional[str] = None,
        # binary_path: Optional[str] = None,
        # unigram_path: Optional[str] = None,
        binary_path: Optional[str] = "4-gram_lc_correct.bin",
        unigram_path: Optional[str] = "librispeech-vocab.txt",
        pretrained_tokenizer: str = "facebook/wav2vec2-base-960h",
        lm_weight: float = 0.5,
        beam_size: int = 100,
        use_lm: bool = False,     # **Added use_lm parameter**
        use_bpe: bool = False,    # **Added use_bpe parameter**
        blank_token: str = "[pad]",  # Blank token as <pad> for Wav2Vec2
        unk_token: str = "[unk]",     # UNK token
        **kwargs
    ):
        """
        Initialize encoder with conditional tokenizer/processor and language model.

        Parameters:
        - use_lm (bool): Whether to use the Language Model (LM) during decoding.
        - use_bpe (bool): Whether to use Byte Pair Encoding (BPE) via tokenizer/processor.
                           If False, perform character-based encoding/decoding without tokenizer.
        """
        self.beam_size = beam_size
        self.lm_weight = lm_weight
        self.arpa_path = arpa_path
        self.binary_path = binary_path
        self.blank_token = blank_token
        self.unk_token = unk_token
        self.use_lm = use_lm # False
        self.use_bpe = use_bpe # False MANUAL FOR NOW
        # self.use_bpe = False # use_bpe # False # use_bpe
        self.printed_samples = 0
        self.max_printed_samples = 5
        print('CTC Text Encoder:')
        print('pretrained_tokenizer:', pretrained_tokenizer)
        print('lm_weight:', lm_weight)
        print('beam_size:', beam_size)
        print('binary_path:', binary_path)
        print('use_lm:', self.use_lm)
        print('use_bpe:', self.use_bpe)

        # Define blank token
        if use_bpe:
            self.blank_token = "<pad>"
        else:
            self.blank_token = ""
        print("blank token: ", self.blank_token)

        # Load unigrams if provided
        self.unigrams = None
        if unigram_path and os.path.exists(unigram_path):
            print(f"Loading unigrams from: {unigram_path}")
            with open(unigram_path, 'r', encoding='utf-8') as f:
                self.unigrams = [line.strip().lower() for line in f if line.strip()]
            print(f"Loaded {len(self.unigrams)} unigrams")

        # Initialize the tokenizer or set up character-based vocab
        self._initialize_vocabulary(pretrained_tokenizer)

        # Create index mappings
        self.ind2char = dict(enumerate(self.vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        self.blank_index = self.char2ind.get(self.blank_token, None)

        print(f"\nVocabulary Info:")
        print(f"Size: {len(self.vocab)}")
        print("Full Vocabulary (up to first 50 tokens):", self.vocab[:50])
        print(f"Blank token: {self.blank_token}, Blank index: {self.blank_index}")

        print("Sample ind2char mappings:", {k: self.ind2char[k] for k in list(self.ind2char.keys())[:10]})
        print("Sample char2ind mappings:", {k: self.char2ind[k] for k in list(self.char2ind.keys())[:10]})

        # **Conditionally initialize language model based on use_lm**
        if self.use_lm:
            self._initialize_language_model()
            self.test_language_model()
        else:
            print("Language model usage is disabled.")
            self.lm = None
            self.decoder = None

    
    def _initialize_vocabulary(self, pretrained_tokenizer: str):
        """
        Initialize the vocabulary either using a tokenizer/processor (BPE) or character-based.

        Parameters:
        - pretrained_tokenizer (str): Name or path of the pretrained tokenizer.
        """
        if self.use_bpe:
            print("Initializing tokenizer and using BPE for encoding/decoding.")
            self.processor = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)

            vocab_dict = self.processor.tokenizer.get_vocab()
            sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

            self.labels = list(sorted_vocab_dict.keys())
            self.vocab = [token.replace('|', ' ') for token in self.labels]

            if self.blank_token not in self.processor.tokenizer.get_vocab():
                self.processor.tokenizer.add_tokens([self.blank_token])
                print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")

            if self.unk_token not in self.processor.tokenizer.get_vocab():
                self.processor.tokenizer.add_tokens([self.unk_token])
                print(f"Added '{self.unk_token}' to the tokenizer's vocabulary.")

            self.ind2char = dict(enumerate(self.vocab))
            self.char2ind = {v: k for k, v in self.ind2char.items()}
            self.blank_index = self.char2ind.get(self.blank_token, None)
        else:
            print("Initializing character-based vocabulary without using tokenizer.")

            if self.unigrams and self.use_lm:
                # Build alphabet from unigrams
                # print("Building alphabets from unigrams")
                # alphabet_set = set(char for word in self.unigrams for char in word)
                # alphabet_set.add(" ")
                # alphabet = sorted(list(alphabet_set))


                print("Building default alphabet")
                alphabet = list(ascii_lowercase + " ")
                self.alphabet = alphabet

                self.vocab = [self.blank_token] + list(alphabet)
                print(f"Loaded character vocabulary of size: {len(self.vocab)}")
            else:
                # Default to lowercase letters and space
                print("Building default alphabet")
                alphabet = list(ascii_lowercase + " ")

                self.alphabet = alphabet
                # Insert blank token at the beginning of the vocabulary
                self.vocab = [self.blank_token] + list(self.alphabet)
                print(f"Loaded character vocabulary of size: {len(self.vocab)}")
            
            # self.vocab = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
            #             'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
            #             'y', 'z', ' ']
            
            # self.vocab += [self.blank_token, self.unk_token]
            self.ind2char = dict(enumerate(self.vocab))
            self.char2ind = {v: k for k, v in self.ind2char.items()}
            self.blank_index = self.char2ind.get(self.blank_token, None)
            self.processor = None


    def encode(self, text: str) -> torch.Tensor:
        """
        Encode text either using tokenizer/processor (BPE) or character-based encoding.

        Parameters:
        - text (str): The input text to encode.

        Returns:
        - torch.Tensor: Tensor of token indices.
        """
        if self.use_bpe:
            normalized_text = self.normalize_text(text)
            encoded = self.processor.tokenizer(normalized_text, return_tensors="pt", padding=False, truncation=False)
            token_indices = encoded.input_ids[0].tolist()
            return torch.tensor(token_indices).unsqueeze(0)
        else:
            normalized_text = self.normalize_text(text)
            token_indices = [self.char2ind.get(char, self.char2ind.get(self.unk_token)) for char in normalized_text]
            return torch.tensor(token_indices).unsqueeze(0)



    def decode_simple(self, indices: List[int]) -> str:
        """
        Simple CTC decoding without language model.
        Collapses consecutive duplicate tokens and removes blanks.

        Parameters:
        - indices (List[int]): List of token indices.

        Returns:
        - str: Decoded text.
        """
        decoded_chars = []
        previous_idx = None

        for idx in indices:
            if idx == self.blank_index:
                previous_idx = idx
                continue  # Skip blank tokens
            # if idx == previous_idx:
            #     continue  # Skip duplicate tokens # TO CHECK IF THIS DELETION IS OK!!!
            if 0 <= idx < len(self.ind2char):
                char = self.ind2char[idx]
                decoded_chars.append(char)
            previous_idx = idx

        # Join characters without spaces and convert to lowercase
        text = "".join(decoded_chars).strip().lower()

        if self.use_bpe and self.processor:
            return self.processor.tokenizer.clean_up_tokenization(text)
        else:
            return text




    def _initialize_language_model(self):
        """Initialize language model with explicit blank token handling."""
        self.lm = None
        self.decoder = None

        model_path = self.binary_path if self.binary_path else self.arpa_path
        print('model_path: ', model_path)
        if not model_path or not os.path.exists(model_path):
            print("No language model path provided or file does not exist.")
            return

        try:
            if not self.use_bpe:
                # Ensure blank token is at index 0
                # self.labels = [self.blank_token] + [c for c in self.vocab if c != self.blank_token]
                
                # Get vocabulary without EMPTY_TOK
                self.labels = [c for c in self.vocab if c != self.blank_token]
                
                print('DEBUG - labels:', self.labels) 
            
            self.lm = kenlm.Model(model_path)
            print(f"Loaded {'binary' if self.binary_path else 'ARPA'} language model.")

            decoder_config = {
                "labels": self.labels if self.use_bpe else self.vocab,
                "kenlm_model_path": model_path,
                "alpha": self.lm_weight,
                "beta": 0.1,
                "unk_score_offset": -10.0,
            }

            if self.unigrams:
                print("\n--- Unigrams List ---")
                # Save the unigrams to a file
                with open("unigrams_list.txt", "w", encoding="utf-8") as f:
                    for unigram in self.unigrams:
                        f.write(f"{unigram}\n")
                print(f"Unigrams list saved to 'unigrams_list.txt'. Total unigrams: {len(self.unigrams)}")
                print("----------------------\n")
                decoder_config["unigrams"] = self.unigrams

            self.decoder = build_ctcdecoder(**decoder_config)
            print("Successfully initialized language model and decoder.")

        except Exception as e:
            print(f"Warning: Failed to initialize decoder: {str(e)}")
            self.decoder = None

    
    def decode(self, indices: List[int]) -> str:
        """
        Decode indices to text using beam search decoder if available.

        Parameters:
        - indices (List[int]): List of token indices.

        Returns:
        - str: Decoded text.
        """
        if self.decoder:
            decoded_text = self.decoder.decode(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text
        else:
            decoded_text = self.decode_simple(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text  # Ensure the decoded text is returned

    
    def decode_logits(self, logits: Union[torch.Tensor, List[List[float]], np.ndarray]) -> str:
        """
        Decode logits using the decoder if available, otherwise use greedy decoding.

        Parameters:
        - logits (Union[torch.Tensor, List[List[float]], np.ndarray]): Logits output from the model.

        Returns:
        - str: Decoded text.
        """
        if isinstance(logits, torch.Tensor):
            logits = logits.cpu().numpy()
        elif isinstance(logits, list):
            logits = np.array(logits)
        elif not isinstance(logits, np.ndarray):
            raise TypeError("logits must be a torch.Tensor, list of lists, or numpy.ndarray")

        if logits.ndim == 3:
            logits = logits[0]

        if logits.ndim != 2:
            raise ValueError(f"Logits should be 2D (time_steps, vocab_size), got {logits.ndim}D")

        if self.decoder:
            decoded_text = self.decoder.decode(logits)
            return decoded_text
        else:
            predicted_indices = np.argmax(logits, axis=-1).tolist()
            return self.decode_simple(predicted_indices)

    def decode_indices(self, indices: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Decode token indices to text using simple decoding (no LM).

        Parameters:
        - indices (Union[torch.Tensor, List[int], np.ndarray]): Token indices.

        Returns:
        - str: Decoded text.
        """
        if isinstance(indices, torch.Tensor):
            indices = indices.squeeze().tolist()
        elif isinstance(indices, np.ndarray):
            indices = indices.tolist()
        elif not isinstance(indices, list):
            raise TypeError("decode_indices expects a list, torch.Tensor, or numpy.ndarray.")

        return self.decode_simple(indices)

    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Perform CTC decoding on logits.

        Parameters:
        - logits (Union[torch.Tensor, List[int], np.ndarray]): Logits or token indices.

        Returns:
        - str: Decoded text.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 3:
            logits = logits[0]  # Reduce to 2D (sequence length, vocab size)

        if self.use_bpe:
            if self.use_lm and self.decoder:
                # Use LM if available
                return self.decoder.decode(logits)
            else:
                # Use tokenizer-based decoding
                predicted_indices = torch.argmax(logits, axis=-1).tolist()
                return self.decode(predicted_indices)
        else:
            # Use character-based decoding
            predicted_indices = torch.argmax(logits, axis=-1).tolist()
            return self.decode_simple(predicted_indices)

    

    def ctc_beam_search(self, probs, beam_size, use_lm = False, debug: bool = False) -> List[Tuple[str, float]]:
        """
        Beam search with optional Language Model support.

        Parameters:
        - probs: Probability distributions over tokens.
        - debug (bool): Whether to print debug information.

        Returns:
        - List[Tuple[str, float]]: List of decoded text with scores.
        """
        beam_size = self.beam_size
        debug = False

        if self.use_lm and self.decoder is not None:
            try:
                if isinstance(probs, torch.Tensor):
                    probs = probs.cpu().numpy()
                elif isinstance(probs, list):
                    probs = np.array(probs)
                elif isinstance(probs, np.ndarray):
                    pass
                else:
                    raise TypeError("probs must be a torch.Tensor, list, or numpy.ndarray")

                beams = self.decoder.decode_beams(
                    probs,
                    beam_prune_logp=-10.0,
                    token_min_logp=-5.0,
                    hotwords=[],
                    hotword_weight=10.0,
                )

                formatted_beams = []
                for beam in beams[:beam_size]:
                    text = beam[0]
                    acoustic_score = beam[3]
                    lm_score = beam[4]

                    if self.use_bpe and self.processor:
                        text = self.processor.tokenizer.clean_up_tokenization(text)
                    text = text.lower().strip()

                    combined_score = (1 - self.lm_weight) * acoustic_score + self.lm_weight * lm_score
                    text_len = max(1, len(text.split())) if self.use_bpe else max(1, len(text))
                    normalized_score = combined_score / text_len

                    formatted_beams.append((text, normalized_score))

                if debug:
                    print("\nFormatted beam results with Language Model:")
                    for text, score in formatted_beams[:3]:
                        print(f"Text: '{text}', Score: {score:.4f}")

                if formatted_beams:
                    return sorted(formatted_beams, key=lambda x: -x[1])
                else:
                    print("No valid beams found, falling back to standard beam search")
                    return self._standard_beam_search(probs, debug)

            except Exception as e:
                print(f"Beam search with LM failed: {str(e)}, falling back to standard beam search")
                return self._standard_beam_search(probs, debug)
        else:
            return self._standard_beam_search(probs, debug)

    # def _standard_beam_search(self, probs, debug: bool = False) -> List[Tuple[str, float]]:
    #     """
    #     Original beam search implementation without Language Model.

    #     Parameters:
    #     - probs: Probability distributions over tokens.
    #     - debug (bool): Whether to print debug information.

    #     Returns:
    #     - List[Tuple[str, float]]: List of decoded text with scores.
    #     """
    #     beam_size = self.beam_size

    #     debug = True # TEMP

    #     if isinstance(probs, np.ndarray):
    #         probs = torch.from_numpy(probs)

    #     if probs.device != torch.device('cpu'):
    #         probs = probs.cpu()

    #     dp = {("", self.blank_token): 0.0}
    #     log_probs = torch.log(probs + 1e-8)

    #     if debug:
    #         print("\nStarting beam search with beam size:", beam_size)

    #     for t, prob in enumerate(log_probs):
    #         new_dp = defaultdict(lambda: float('-inf'))
    #         top_k = torch.topk(prob, k=min(beam_size, len(prob)))

    #         if debug and t < self.max_printed_samples:
    #             print(f"\nTimestep {t}:")
    #             print("Top tokens:", [(self.ind2char[idx.item()], val.item()) 
    #                                 for val, idx in zip(top_k.values, top_k.indices)])

    #         for val, ind in zip(top_k.values, top_k.indices):
    #             curr_char = self.ind2char[ind.item()]
    #             next_token_log_prob = val.item()

    #             for (prefix, last_char), log_prob in dp.items():
    #                 if last_char == curr_char and curr_char != " ":
    #                     new_prefix = prefix
    #                 else:
    #                     if curr_char != self.blank_token:
    #                         if curr_char == " " and prefix.endswith(" "):
    #                             continue
    #                         new_prefix = prefix + curr_char
    #                     else:
    #                         new_prefix = prefix

    #                 new_log_prob = log_prob + next_token_log_prob
    #                 key = (new_prefix, curr_char)
    #                 new_dp[key] = max(new_dp[key], new_log_prob)

    #         if len(new_dp) > 0:
    #             max_score = max(score for _, score in new_dp.items())
    #             new_dp = {key: score - max_score for key, score in new_dp.items()}

    #         dp = dict(sorted(new_dp.items(), key=lambda x: -x[1])[:beam_size])

    #         if debug and t < 2:
    #             print("\nCurrent beam:")
    #             for (text, last_char), score in list(dp.items())[:3]:
    #                 print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")

    #     final_beams = []
    #     for (text, _), score in dp.items():
    #         if self.use_bpe and self.processor:
    #             text = self.processor.tokenizer.clean_up_tokenization(text)
    #         text = text.lower().strip()
    #         text_len = max(1, len(text.split())) if self.use_bpe else max(1, len(text))
    #         normalized_score = score / text_len
    #         final_beams.append((text, normalized_score))

    #     final_beams.sort(key=lambda x: -x[1])
    #     if not final_beams:
    #         final_beams = [("", float('-inf'))]

    #     return final_beams[:beam_size]


    # def _standard_beam_search(self, probs, beam_size: int = 10, debug: bool = False) -> List[Tuple[str, float]]:
    #     """Original beam search implementation with improved debugging"""
    #     # Convert input to torch tensor if needed
        
    #     # beam_size = self.beam_size
    #     beam_size = 10
        
    #     if isinstance(probs, np.ndarray):
    #         probs = torch.from_numpy(probs)

    #     # Ensure probs is on CPU
    #     if probs.device != torch.device('cpu'):
    #         probs = probs.cpu()

    #     # Initialize beam with empty string
    #     dp = {("", self.blank_token): 0.0}  # Using log probs

    #     # Convert to log probabilities
    #     log_probs = torch.log(probs + 1e-8)
        
    #     if debug:
    #         print("\nStarting beam search with beam size:", beam_size)
        
    #     for t, prob in enumerate(log_probs):
    #         new_dp = defaultdict(lambda: float('-inf'))
            
    #         # Get top-k tokens for this timestep
    #         top_k = torch.topk(prob, k=min(beam_size, len(prob)))
            
    #         if debug and t < 2:  # Print first two timesteps
    #             print(f"\nTimestep {t}:")
    #             print("Top tokens:", [(self.ind2char[idx.item()], val.item()) 
    #                                 for val, idx in zip(top_k.values, top_k.indices)])
            
    #         # Only expand using top-k tokens
    #         for val, ind in zip(top_k.values, top_k.indices):
    #             curr_char = self.ind2char[ind.item()]
    #             next_token_log_prob = val.item()
                
    #             for (prefix, last_char), log_prob in dp.items():
    #                 # Skip repeated characters (except spaces)
    #                 if last_char == curr_char and curr_char != " ":
    #                     new_prefix = prefix
    #                 else:
    #                     if curr_char != self.blank_token:
    #                         # Handle spaces better
    #                         if curr_char == " " and prefix.endswith(" "):
    #                             continue
    #                         new_prefix = prefix + curr_char
    #                     else:
    #                         new_prefix = prefix
                    
    #                 # Update score
    #                 new_log_prob = log_prob + next_token_log_prob
    #                 key = (new_prefix, curr_char)
    #                 new_dp[key] = max(new_dp[key], new_log_prob)
            
    #         # Normalize scores
    #         if len(new_dp) > 0:
    #             max_score = max(score for _, score in new_dp.items())
    #             new_dp = {key: score - max_score for key, score in new_dp.items()}
            
    #         # Truncate beams
    #         dp = dict(sorted(new_dp.items(), key=lambda x: -x[1])[:beam_size])
            
    #         if debug and t < 2:  # Print beam state for first two timesteps
    #             print("\nCurrent beam:")
    #             for (text, last_char), score in list(dp.items())[:3]:
    #                 print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")
        
    #     # Format final results
    #     final_beams = []
    #     for (text, _), score in dp.items():
    #         # Clean up text
    #         if self.use_bpe:
    #             text = self.tokenizer.clean_up_tokenization(text)
    #         else:
    #             text = ' '.join(text.split())
                
    #         if not text.strip():  # Skip empty results
    #             continue
                
    #         # Length normalization
    #         text_len = max(1, len(text.split()))
    #         normalized_score = score / text_len
            
    #         final_beams.append((text, normalized_score))
            
    #     # Sort and ensure we have results
    #     final_beams.sort(key=lambda x: -x[1])
    #     if not final_beams:
    #         final_beams = [("", float('-inf'))]
        
    #     return final_beams[:beam_size]

    def _standard_beam_search(self, probs, beam_size: int = 10, debug: bool = False) -> List[Tuple[str, float]]:
        beam_size = self.beam_size
        beam_size = 100
        
        if isinstance(probs, np.ndarray):
            probs = torch.from_numpy(probs)
        if probs.device != torch.device('cpu'):
            probs = probs.cpu()

        # Initialize beam with empty string
        dp = {("", self.blank_token): 0.0}
        
        # Convert to log probabilities more carefully
        log_probs = torch.log(torch.clamp(probs, min=1e-8))
        
        for t, prob in enumerate(log_probs):
            new_dp = defaultdict(lambda: float('-inf'))
            
            # Consider all tokens, not just top-k
            token_indices = range(len(prob))
            token_log_probs = prob
            
            for ind in token_indices:
                curr_char = self.ind2char[ind]
                next_token_log_prob = token_log_probs[ind].item()
                
                for (prefix, last_char), log_prob in dp.items():
                    # Modified prefix handling
                    if curr_char == self.blank_token:
                        # Blank token: keep prefix unchanged
                        new_prefix = prefix
                    elif last_char == curr_char and curr_char != " ":
                        # Repeated char: merge only if not space
                        new_prefix = prefix
                    else:
                        # New char: add to prefix
                        new_prefix = prefix + curr_char
                    
                    # Update score without aggressive normalization
                    new_log_prob = log_prob + next_token_log_prob
                    key = (new_prefix, curr_char)
                    new_dp[key] = max(new_dp[key], new_log_prob)
            
            # Less aggressive score normalization
            if len(new_dp) > 0:
                max_score = max(score for _, score in new_dp.items())
                new_dp = {key: score - max_score/2 for key, score in new_dp.items()}
            
            # Keep top beams
            dp = dict(sorted(new_dp.items(), key=lambda x: -x[1])[:beam_size])
            
            if debug and t < 2:
                print(f"\nTimestep {t}:")
                top_beams = list(dp.items())[:3]
                for (text, last_char), score in top_beams:
                    print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")
        
        # Modified final scoring
        final_beams = []
        for (text, _), score in dp.items():
            if self.use_bpe:
                text = self.tokenizer.clean_up_tokenization(text)
            else:
                text = ' '.join(text.split())
                
            if not text.strip():
                continue
                
            # Modified length normalization
            text_len = max(1, len(text))  # Use character length instead of word length
            normalized_score = score / (text_len ** 0.5)  # Square root length normalization
            
            final_beams.append((text, normalized_score))
        
        final_beams.sort(key=lambda x: -x[1])
        if not final_beams:
            final_beams = [("", float('-inf'))]
        
        return final_beams[:beam_size]


    # def _standard_beam_search(self, probs, beam_size=50, debug=False):
    #     """
    #     Beam search implementation without Language Model (use_lm=False).
        
    #     Parameters:
    #     - probs: Probability distributions over tokens (can be numpy array or tensor).
    #     - beam_size: Maximum number of beams to keep at each timestep.
    #     - debug (bool): Whether to print debug information.

    #     Returns:
    #     - List[Tuple[str, float]]: Decoded beams with their scores.
    #     """
    #     dp = {("", self.blank_token): 0.0}  # Initialize with log probabilities

    #     # Convert probabilities to a PyTorch tensor if needed
    #     if isinstance(probs, np.ndarray):
    #         probs = torch.tensor(probs)

    #     log_probs = torch.log(probs + 1e-8)  # Add epsilon to avoid log(0)

    #     for t, next_token_log_probs in enumerate(log_probs):
    #         if debug:
    #             print(f"\nTimestep {t}:")
    #             top_k = torch.topk(torch.exp(next_token_log_probs), k=5)  # Convert back to prob for debugging
    #             print("Top 5 tokens and probs:")
    #             for i, (p, idx) in enumerate(zip(top_k.values, top_k.indices)):
    #                 token = self.ind2char[idx.item()]
    #                 print(f"{token}: {p:.4f}")

    #         # Expand and merge paths
    #         dp = self.expand_and_merge_path(dp, next_token_log_probs)

    #         # Normalize scores periodically to prevent underflow
    #         if len(dp) > 0:
    #             max_score = max(score for _, score in dp.items())
    #             dp = {key: score - max_score for key, score in dp.items()}

    #         # Prune to keep top beams
    #         dp = self.truncate_paths(dp, beam_size)

    #         if debug:
    #             print("\nCurrent beam state:")
    #             for (text, last_char), score in list(dp.items())[:5]:
    #                 print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")

    #     # Prepare final results with length normalization
    #     final_beams = []
    #     for (text, _), score in sorted(dp.items(), key=lambda x: -x[1])[:beam_size]:
    #         # Clean up text
    #         text = ' '.join(text.split())
    #         # Length normalization
    #         text_len = max(1, len(text.split()))
    #         normalized_score = score / text_len
    #         final_beams.append((text, normalized_score))

    #     if debug:
    #         print("\nFinal beams:")
    #         for text, score in final_beams[:3]:
    #             print(f"Text: '{text}', Score: {score:.4f}")

    #     return final_beams

    # def truncate_paths(self, dp, beam_size):
    #     """Regular beam search truncation"""
    #     return dict(sorted(dp.items(), key=lambda x: -x[1])[:beam_size])    

    # def expand_and_merge_path(self, dp, next_token_log_probs):
    #     """
    #     Expand and merge paths for CTC decoding without LM.

    #     Parameters:
    #     - dp: Current beams with their scores.
    #     - next_token_log_probs: Log probabilities of the next tokens.

    #     Returns:
    #     - Updated beams with scores.
    #     """
    #     new_dp = defaultdict(lambda: float('-inf'))  # Initialize with log-prob -inf

    #     for ind, next_token_log_prob in enumerate(next_token_log_probs):
    #         if ind >= len(self.ind2char):  # Skip invalid indices
    #             continue
    #         curr_char = self.ind2char[ind]  # Get character corresponding to the index

    #         for (prefix, last_char), log_prob in dp.items():
    #             # Avoid consecutive duplicates unless blank token
    #             if last_char == curr_char and curr_char != self.blank_token:
    #                 new_prefix = prefix
    #             else:
    #                 if curr_char != self.blank_token:
    #                     if curr_char == " " and prefix.endswith(" "):  # Avoid consecutive spaces
    #                         continue
    #                     new_prefix = prefix + curr_char
    #                 else:
    #                     new_prefix = prefix

    #             # Update score for the new prefix
    #             new_log_prob = log_prob + next_token_log_prob
    #             new_dp[(new_prefix, curr_char)] = max(new_dp[(new_prefix, curr_char)], new_log_prob)

    #     return new_dp








    def test_language_model(self):
        """Debug function to verify LM functionality"""
        print("\nTesting Language Model...")

        if self.lm is None:
            print("Error: Language model is not loaded!")
            return

        test_sentences = [
            "this is a good sentence",
            "this is also a good sentence",
            "thiss iss nott aa goodd sentencee",
            "random word salad box cat",
            "the cat sat on the mat",
            "",
            "a",
        ]

        print("\nTesting individual sentences:")
        for sentence in test_sentences:
            score = self.score_with_lm(sentence)
            print(f"\nText: '{sentence}'")
            print(f"LM Score: {score:.4f}")

        test_prefixes = [
            "the quick brown",
            "how are",
            "thank",
            "nice to",
        ]

        print("\nTesting word completions:")
        for prefix in test_prefixes:
            print(f"\nPrefix: '{prefix}'")
            completions = [
                prefix + " " + word for word in ["you", "fox", "cat", "xyz", "meet"]
            ]
            scores = [(completion, self.score_with_lm(completion)) 
                    for completion in completions]
            scores.sort(key=lambda x: x[1], reverse=True)
            print("Top completions by score:")
            for completion, score in scores[:3]:
                print(f"  '{completion}': {score:.4f}")

    def score_with_lm(self, text: str) -> float:
        """
        Score text using language model, handling edge cases.

        Parameters:
        - text (str): The input text to score.

        Returns:
        - float: LM score.
        """
        if self.lm is None:
            return 0.0

        if not text or len(text.strip()) == 0:
            return float('-inf')

        text = text.lower().strip()
        return self.lm.score(text, bos=True, eos=True)

    def _basic_ctc_decode(self, logits: np.ndarray, sequence_length: int) -> List[str]:
        """
        Basic CTC decoding without LM.

        Parameters:
        - logits (np.ndarray): Logits from the model.
        - sequence_length (int): Length of the sequence to decode.

        Returns:
        - List[str]: Decoded text.
        """
        argmax_indices = np.argmax(logits, axis=-1)

        if len(argmax_indices.shape) == 0:
            argmax_indices = np.array([argmax_indices])

        if len(argmax_indices.shape) == 1:
            argmax_indices = np.expand_dims(argmax_indices, axis=0)

        predictions = []
        for sequence in argmax_indices:
            decoded = []
            last_idx = None

            for idx in sequence[:sequence_length]:
                if idx != self.blank_index and idx != last_idx:
                    decoded.append(self.ind2char[idx])
                last_idx = idx

            text = "".join(decoded)
            if self.use_bpe and self.processor:
                text = self.processor.tokenizer.clean_up_tokenization(text)
            predictions.append(text)

        return predictions

    @staticmethod
    def normalize_text(text: str) -> str:
        """Normalize input text."""
        text = text.lower()
        text = re.sub(r"[^a-z ]", "", text)
        return text

    def test_decoder(self, sample_text: str = "test decoder functionality"):
        """Test the decoder setup."""
        print("\nTesting decoder configuration...")

        encoded = self.encode(sample_text)
        decoded = self.decode(encoded[0].tolist())
        print(f"Original text: {sample_text}")
        print(f"Basic decode: {decoded}")

        sequence_length = 50
        vocab_size = len(self)
        fake_logits = torch.randn(1, sequence_length, vocab_size)
        fake_length = torch.tensor([sequence_length])

        if self.decoder is not None:
            print("\nTesting pyctcdecode integration...")
            decoded_with_lm = self.ctc_decode(fake_logits)
            print(f"Decoded with LM: {decoded_with_lm}")

            print(f"\nBeam width: {self.beam_size}")
            print(f"LM weight: {self.lm_weight}")
        else:
            print("\nNo language model loaded - using basic CTC decoding")
            basic_decoded = self._basic_ctc_decode(fake_logits.numpy(), sequence_length)
            print(f"Basic CTC decoded: {basic_decoded[0]}")

    def __len__(self):
        """Return the size of the vocabulary."""
        return len(self.vocab)

    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Perform CTC decoding on logits.

        Parameters:
        - logits (Union[torch.Tensor, List[int], np.ndarray]): Logits or token indices.

        Returns:
        - str: Decoded text.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 3:
            decoded_text = self.decode_logits(logits)
            return decoded_text
        elif logits.dim() == 2:
            decoded_text = self.decode_logits(logits)
            return decoded_text
        elif logits.dim() == 1:
            decoded_text = self.decode_indices(logits)
            return decoded_text
        else:
            raise ValueError(f"Unsupported logits shape: {logits.shape}. Expected 1D, 2D, or 3D.")















# import re
# from collections import defaultdict
# import torch
# import kenlm
# from transformers import Wav2Vec2Processor, AutoProcessor
# from pyctcdecode import build_ctcdecoder
# import numpy as np
# import os

# from typing import List, Tuple, Optional, Union

# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# class CTCTextEncoder:
#     def __init__(
#         self,
#         arpa_path: Optional[str] = None,
#         binary_path: Optional[str] = None,
#         unigram_path: Optional[str] = None,
#         pretrained_tokenizer: str = "facebook/wav2vec2-base-960h",
#         lm_weight: float = 0.5,
#         beam_size: int = 100,
#         use_lm: bool = False,     # **Added use_lm parameter**
#         use_bpe: bool = False,    # **Added use_bpe parameter**
#         blank_token: str = "[pad]",  # Blank token as <pad> for Wav2Vec2
#         unk_token: str = "[unk]",     # UNK token
#         **kwargs
#     ):
#         """
#         Initialize encoder with conditional tokenizer/processor and language model.

#         Parameters:
#         - use_lm (bool): Whether to use the Language Model (LM) during decoding.
#         - use_bpe (bool): Whether to use Byte Pair Encoding (BPE) via tokenizer/processor.
#                            If False, perform character-based encoding/decoding without tokenizer.
#         """
#         self.beam_size = beam_size
#         self.lm_weight = lm_weight
#         self.arpa_path = arpa_path
#         self.binary_path = binary_path
#         self.blank_token = blank_token
#         self.unk_token = unk_token
#         self.use_lm = use_lm # False
#         self.use_bpe = False # False MANUAL FOR NOW
#         # self.use_bpe = False # use_bpe # False # use_bpe
#         self.printed_samples = 0
#         self.max_printed_samples = 5
#         print('CTC Text Encoder:')
#         print('pretrained_tokenizer:', pretrained_tokenizer)
#         print('lm_weight:', lm_weight)
#         print('beam_size:', beam_size)
#         print('binary_path:', binary_path)
#         print('use_lm:', self.use_lm)
#         print('use_bpe:', self.use_bpe)

#         # Load unigrams if provided
#         self.unigrams = None
#         if unigram_path and os.path.exists(unigram_path):
#             print(f"Loading unigrams from: {unigram_path}")
#             with open(unigram_path, 'r', encoding='utf-8') as f:
#                 self.unigrams = [line.strip().lower() for line in f if line.strip()]
#             print(f"Loaded {len(self.unigrams)} unigrams")

#         # Initialize the tokenizer or set up character-based vocab
#         self._initialize_vocabulary(pretrained_tokenizer)

#         # Create index mappings
#         self.ind2char = dict(enumerate(self.vocab))
#         self.char2ind = {v: k for k, v in self.ind2char.items()}
#         self.blank_index = self.char2ind.get(self.blank_token, None)

#         print(f"\nVocabulary Info:")
#         print(f"Size: {len(self.vocab)}")
#         print("Full Vocabulary (up to first 50 tokens):", self.vocab[:50])
#         print(f"Blank token: {self.blank_token}, Blank index: {self.blank_index}")

#         print("Sample ind2char mappings:", {k: self.ind2char[k] for k in list(self.ind2char.keys())[:10]})
#         print("Sample char2ind mappings:", {k: self.char2ind[k] for k in list(self.char2ind.keys())[:10]})

#         # **Conditionally initialize language model based on use_lm**
#         if self.use_lm:
#             self._initialize_language_model()
#         else:
#             print("Language model usage is disabled.")
#             self.lm = None
#             self.decoder = None

#     def _initialize_vocabulary(self, pretrained_tokenizer: str):
#         """
#         Initialize the vocabulary either using a tokenizer/processor (BPE) or character-based.

#         Parameters:
#         - pretrained_tokenizer (str): Name or path of the pretrained tokenizer.
#         """
#         if self.use_bpe:
#             print("Initializing tokenizer and using BPE for encoding/decoding.")
#             # Initialize AutoProcessor (you can switch to Wav2Vec2Processor if preferred)
#             # self.processor = AutoProcessor.from_pretrained(pretrained_tokenizer)
#             self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

#             # Get the vocabulary from tokenizer and process it
#             vocab_dict = self.processor.tokenizer.get_vocab()
#             sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

#             self.labels = list(sorted_vocab_dict.keys())

#             # Replace '|' with ' ' in vocab
#             original_vocab = list(sorted_vocab_dict.keys())
#             self.vocab = [t.replace('|', ' ') for t in original_vocab]

#             # Add the unique blank token to the tokenizer's vocabulary if not present
#             if self.blank_token not in self.processor.tokenizer.get_vocab():
#                 self.processor.tokenizer.add_tokens([self.blank_token])
#                 print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")
#             else:
#                 print(f"'{self.blank_token}' already exists in the tokenizer's vocabulary.")

#             # Add the UNK token if not present
#             if self.unk_token not in self.processor.tokenizer.get_vocab():
#                 self.processor.tokenizer.add_tokens([self.unk_token])
#                 print(f"Added '{self.unk_token}' to the tokenizer's vocabulary.")
#             else:
#                 print(f"'{self.unk_token}' already exists in the tokenizer's vocabulary.")

#             # **Update ind2char and char2ind after adding tokens**
#             self.ind2char = dict(enumerate(self.vocab))
#             self.char2ind = {v: k for k, v in self.ind2char.items()}
#             self.blank_index = self.char2ind.get(self.blank_token, None)

#             # Debug: Print a few tokens after modification
#             print("Modified Vocabulary (first 20 tokens):", self.vocab[:20])
#         else:
#             print("Initializing character-based vocabulary without using tokenizer.")
#             # Define a simple character-based vocabulary: a-z and space
#             self.vocab = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
#                          'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
#                          'y', 'z', ' ']
#             # Optionally, add special tokens
#             self.vocab += [self.blank_token, self.unk_token]

#             # **Create index mappings**
#             self.ind2char = dict(enumerate(self.vocab))
#             self.char2ind = {v: k for k, v in self.ind2char.items()}
#             self.blank_index = self.char2ind.get(self.blank_token, None)

#             # No processor is used in this mode
#             self.processor = None

#     def _initialize_language_model(self):
#         """Initialize language model with explicit blank token handling."""
#         self.lm = None
#         self.decoder = None

#         model_path = self.binary_path if self.binary_path else self.arpa_path
#         print('model_path: ', model_path)
#         if not model_path or not os.path.exists(model_path):
#             print("No language model path provided or file does not exist.")
#             return

#         try:
#             self.lm = kenlm.Model(model_path)
#             print(f"Loaded {'binary' if self.binary_path else 'ARPA'} language model.")

#             decoder_config = {
#                 "labels": self.labels if self.use_bpe else self.vocab,
#                 "kenlm_model_path": model_path,
#                 "alpha": self.lm_weight,
#                 "beta": 0.1,
#                 "unk_score_offset": -10.0,
#             }

#             if self.unigrams:
#                 print("\n--- Unigrams List ---")
#                 # Save the unigrams to a file
#                 with open("unigrams_list.txt", "w", encoding="utf-8") as f:
#                     for unigram in self.unigrams:
#                         f.write(f"{unigram}\n")
#                 print(f"Unigrams list saved to 'unigrams_list.txt'. Total unigrams: {len(self.unigrams)}")
#                 print("----------------------\n")
#                 decoder_config["unigrams"] = self.unigrams

#             self.decoder = build_ctcdecoder(**decoder_config)
#             print("Successfully initialized language model and decoder.")

#         except Exception as e:
#             print(f"Warning: Failed to initialize decoder: {str(e)}")
#             self.decoder = None

#     def encode(self, text: str) -> torch.Tensor:
#         """
#         Encode text either using tokenizer/processor (BPE) or character-based encoding.

#         Parameters:
#         - text (str): The input text to encode.

#         Returns:
#         - torch.Tensor: Tensor of token indices.
#         """
#         debug = False

#         if self.printed_samples < self.max_printed_samples:
#             original_text = text
#             text = self.normalize_text(text)
#             if debug:
#                 print(f"samples: {str(self.printed_samples)}")
#                 print(f"\nEncoding text:")
#                 print(f" Original: '{original_text}'")
#                 print(f" Normalized: '{text}'")
#                 for ch in text:
#                     print(ch, ord(ch))

#         if self.use_bpe:
#             try:
#                 encoded = self.processor.tokenizer(text, return_tensors="pt", padding=False, truncation=False)
#                 token_indices = encoded.input_ids[0].tolist()
#                 if self.printed_samples < self.max_printed_samples:
#                     # Convert indices to tokens from self.vocab
#                     tokens = [self.vocab[idx] if 0 <= idx < len(self.vocab) else "<invalid>" for idx in token_indices]
#                     # Optionally print tokens for debugging
#                     # print(f" Tokens (lowercased and '|'->' '): {tokens}")
#                     # print(f" Token indices: {token_indices}")
#                     self.printed_samples += 1
#                 return torch.tensor(token_indices).unsqueeze(0)
#             except KeyError as e:
#                 unknown_tokens = set([token for token in text.split() if token not in self.char2ind])
#                 raise Exception(f"Unknown tokens: '{' '.join(unknown_tokens)}'")
#             except Exception as e:
#                 raise Exception(f"Encoding error: {str(e)}")
#         else:
#             # Character-based encoding
#             try:
#                 normalized_text = self.normalize_text(text)
#                 token_indices = [self.char2ind.get(char, self.char2ind.get(self.unk_token)) for char in normalized_text]
#                 return torch.tensor(token_indices).unsqueeze(0)
#             except Exception as e:
#                 raise Exception(f"Encoding error: {str(e)}")

#     def decode(self, indices: List[int]) -> str:
#         """
#         Decode indices to text using beam search decoder if available.

#         Parameters:
#         - indices (List[int]): List of token indices.

#         Returns:
#         - str: Decoded text.
#         """
#         if self.decoder:
#             decoded_text = self.decoder.decode(indices)
#             # Convert to lower case
#             decoded_text = decoded_text.lower()
#             return decoded_text
#         else:
#             decoded_text = self.decode_simple(indices)
#             # Convert to lower case
#             decoded_text = decoded_text.lower()
#             return decoded_text  # Ensure the decoded text is returned

#     def decode_simple(self, indices: List[int]) -> str:
#         """
#         Simple CTC decoding without language model.
#         Collapses consecutive duplicate tokens and removes blanks.

#         Parameters:
#         - indices (List[int]): List of token indices.

#         Returns:
#         - str: Decoded text.
#         """
#         decoded_chars = []
#         previous_idx = None

#         for idx in indices:
#             if idx == self.blank_index:
#                 previous_idx = idx
#                 continue  # Skip blank tokens
#             if idx == previous_idx:
#                 continue  # Skip duplicate tokens
#             if 0 <= idx < len(self.ind2char):
#                 char = self.ind2char[idx]
#                 decoded_chars.append(char)
#             previous_idx = idx

#         # Join characters without spaces and convert to lowercase
#         text = "".join(decoded_chars).strip().lower()

#         if self.use_bpe and self.processor:
#             # Clean up tokenization using the tokenizer's method
#             return self.processor.tokenizer.clean_up_tokenization(text)
#         else:
#             # For character-based decoding, additional cleanup can be added if necessary
#             return text

#     def decode_logits(self, logits: Union[torch.Tensor, List[List[float]], np.ndarray]) -> str:
#         """
#         Decode logits using the decoder if available, otherwise use greedy decoding.

#         Parameters:
#         - logits (Union[torch.Tensor, List[List[float]], np.ndarray]): Logits output from the model.

#         Returns:
#         - str: Decoded text.
#         """
#         if isinstance(logits, torch.Tensor):
#             logits = logits.cpu().numpy()
#         elif isinstance(logits, list):
#             logits = np.array(logits)
#         elif not isinstance(logits, np.ndarray):
#             raise TypeError("logits must be a torch.Tensor, list of lists, or numpy.ndarray")

#         if logits.ndim == 3:
#             logits = logits[0]

#         if logits.ndim != 2:
#             raise ValueError(f"Logits should be 2D (time_steps, vocab_size), got {logits.ndim}D")

#         if self.decoder:
#             decoded_text = self.decoder.decode(logits)
#             return decoded_text
#         else:
#             predicted_indices = np.argmax(logits, axis=-1).tolist()
#             return self.decode_simple(predicted_indices)

#     def decode_indices(self, indices: Union[torch.Tensor, List[int], np.ndarray]) -> str:
#         """
#         Decode token indices to text using simple decoding (no LM).

#         Parameters:
#         - indices (Union[torch.Tensor, List[int], np.ndarray]): Token indices.

#         Returns:
#         - str: Decoded text.
#         """
#         if isinstance(indices, torch.Tensor):
#             indices = indices.squeeze().tolist()
#         elif isinstance(indices, np.ndarray):
#             indices = indices.tolist()
#         elif not isinstance(indices, list):
#             raise TypeError("decode_indices expects a list, torch.Tensor, or numpy.ndarray.")

#         return self.decode_simple(indices)

#     def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
#         """
#         Perform CTC decoding on logits.

#         Parameters:
#         - logits (Union[torch.Tensor, List[int], np.ndarray]): Logits or token indices.

#         Returns:
#         - str: Decoded text.
#         """
#         if isinstance(logits, np.ndarray):
#             logits = torch.from_numpy(logits)
#         elif isinstance(logits, list):
#             logits = torch.tensor(logits)

#         if logits.dim() == 3:
#             decoded_text = self.decode_logits(logits)
#             return decoded_text
#         elif logits.dim() == 2:
#             decoded_text = self.decode_logits(logits)
#             return decoded_text
#         elif logits.dim() == 1:
#             decoded_text = self.decode_indices(logits)
#             return decoded_text
#         else:
#             raise ValueError(f"Unsupported logits shape: {logits.shape}. Expected 1D, 2D, or 3D.")

#     def ctc_beam_search(self, probs, beam_size, use_lm = False, debug: bool = False) -> List[Tuple[str, float]]:
#         """
#         Beam search with optional Language Model support.

#         Parameters:
#         - probs: Probability distributions over tokens.
#         - debug (bool): Whether to print debug information.

#         Returns:
#         - List[Tuple[str, float]]: List of decoded text with scores.
#         """
#         beam_size = self.beam_size
#         debug = False

#         if self.use_lm and self.decoder is not None:
#             try:
#                 if isinstance(probs, torch.Tensor):
#                     probs = probs.cpu().numpy()
#                 elif isinstance(probs, list):
#                     probs = np.array(probs)
#                 elif isinstance(probs, np.ndarray):
#                     pass
#                 else:
#                     raise TypeError("probs must be a torch.Tensor, list, or numpy.ndarray")

#                 beams = self.decoder.decode_beams(
#                     probs,
#                     beam_prune_logp=-10.0,
#                     token_min_logp=-5.0,
#                     hotwords=[],
#                     hotword_weight=10.0,
#                 )

#                 formatted_beams = []
#                 for beam in beams[:beam_size]:
#                     text = beam[0]
#                     acoustic_score = beam[3]
#                     lm_score = beam[4]

#                     if self.use_bpe and self.processor:
#                         text = self.processor.tokenizer.clean_up_tokenization(text)
#                     text = text.lower().strip()

#                     combined_score = (1 - self.lm_weight) * acoustic_score + self.lm_weight * lm_score
#                     text_len = max(1, len(text.split())) if self.use_bpe else max(1, len(text))
#                     normalized_score = combined_score / text_len

#                     formatted_beams.append((text, normalized_score))

#                 if debug:
#                     print("\nFormatted beam results with Language Model:")
#                     for text, score in formatted_beams[:3]:
#                         print(f"Text: '{text}', Score: {score:.4f}")

#                 if formatted_beams:
#                     return sorted(formatted_beams, key=lambda x: -x[1])
#                 else:
#                     print("No valid beams found, falling back to standard beam search")
#                     return self._standard_beam_search(probs, debug)

#             except Exception as e:
#                 print(f"Beam search with LM failed: {str(e)}, falling back to standard beam search")
#                 return self._standard_beam_search(probs, debug)
#         else:
#             return self._standard_beam_search(probs, debug)

#     def _standard_beam_search(self, probs, debug: bool = False) -> List[Tuple[str, float]]:
#         """
#         Original beam search implementation without Language Model.

#         Parameters:
#         - probs: Probability distributions over tokens.
#         - debug (bool): Whether to print debug information.

#         Returns:
#         - List[Tuple[str, float]]: List of decoded text with scores.
#         """
#         beam_size = self.beam_size

#         if isinstance(probs, np.ndarray):
#             probs = torch.from_numpy(probs)

#         if probs.device != torch.device('cpu'):
#             probs = probs.cpu()

#         dp = {("", self.blank_token): 0.0}
#         log_probs = torch.log(probs + 1e-8)

#         if debug:
#             print("\nStarting beam search with beam size:", beam_size)

#         for t, prob in enumerate(log_probs):
#             new_dp = defaultdict(lambda: float('-inf'))
#             top_k = torch.topk(prob, k=min(beam_size, len(prob)))

#             if debug and t < self.max_printed_samples:
#                 print(f"\nTimestep {t}:")
#                 print("Top tokens:", [(self.ind2char[idx.item()], val.item()) 
#                                     for val, idx in zip(top_k.values, top_k.indices)])

#             for val, ind in zip(top_k.values, top_k.indices):
#                 curr_char = self.ind2char[ind.item()]
#                 next_token_log_prob = val.item()

#                 for (prefix, last_char), log_prob in dp.items():
#                     if last_char == curr_char and curr_char != " ":
#                         new_prefix = prefix
#                     else:
#                         if curr_char != self.blank_token:
#                             if curr_char == " " and prefix.endswith(" "):
#                                 continue
#                             new_prefix = prefix + curr_char
#                         else:
#                             new_prefix = prefix

#                     new_log_prob = log_prob + next_token_log_prob
#                     key = (new_prefix, curr_char)
#                     new_dp[key] = max(new_dp[key], new_log_prob)

#             if len(new_dp) > 0:
#                 max_score = max(score for _, score in new_dp.items())
#                 new_dp = {key: score - max_score for key, score in new_dp.items()}

#             dp = dict(sorted(new_dp.items(), key=lambda x: -x[1])[:beam_size])

#             if debug and t < 2:
#                 print("\nCurrent beam:")
#                 for (text, last_char), score in list(dp.items())[:3]:
#                     print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")

#         final_beams = []
#         for (text, _), score in dp.items():
#             if self.use_bpe and self.processor:
#                 text = self.processor.tokenizer.clean_up_tokenization(text)
#             text = text.lower().strip()
#             text_len = max(1, len(text.split())) if self.use_bpe else max(1, len(text))
#             normalized_score = score / text_len
#             final_beams.append((text, normalized_score))

#         final_beams.sort(key=lambda x: -x[1])
#         if not final_beams:
#             final_beams = [("", float('-inf'))]

#         return final_beams[:beam_size]

#     def test_language_model(self):
#         """Debug function to verify LM functionality"""
#         print("\nTesting Language Model...")

#         if self.lm is None:
#             print("Error: Language model is not loaded!")
#             return

#         test_sentences = [
#             "this is a good sentence",
#             "this is also a good sentence",
#             "thiss iss nott aa goodd sentencee",
#             "random word salad box cat",
#             "the cat sat on the mat",
#             "",
#             "a",
#         ]

#         print("\nTesting individual sentences:")
#         for sentence in test_sentences:
#             score = self.score_with_lm(sentence)
#             print(f"\nText: '{sentence}'")
#             print(f"LM Score: {score:.4f}")

#         test_prefixes = [
#             "the quick brown",
#             "how are",
#             "thank",
#             "nice to",
#         ]

#         print("\nTesting word completions:")
#         for prefix in test_prefixes:
#             print(f"\nPrefix: '{prefix}'")
#             completions = [
#                 prefix + " " + word for word in ["you", "fox", "cat", "xyz", "meet"]
#             ]
#             scores = [(completion, self.score_with_lm(completion)) 
#                     for completion in completions]
#             scores.sort(key=lambda x: x[1], reverse=True)
#             print("Top completions by score:")
#             for completion, score in scores[:3]:
#                 print(f"  '{completion}': {score:.4f}")

#     def score_with_lm(self, text: str) -> float:
#         """
#         Score text using language model, handling edge cases.

#         Parameters:
#         - text (str): The input text to score.

#         Returns:
#         - float: LM score.
#         """
#         if self.lm is None:
#             return 0.0

#         if not text or len(text.strip()) == 0:
#             return float('-inf')

#         text = text.lower().strip()
#         return self.lm.score(text, bos=True, eos=True)

#     def _basic_ctc_decode(self, logits: np.ndarray, sequence_length: int) -> List[str]:
#         """
#         Basic CTC decoding without LM.

#         Parameters:
#         - logits (np.ndarray): Logits from the model.
#         - sequence_length (int): Length of the sequence to decode.

#         Returns:
#         - List[str]: Decoded text.
#         """
#         argmax_indices = np.argmax(logits, axis=-1)

#         if len(argmax_indices.shape) == 0:
#             argmax_indices = np.array([argmax_indices])

#         if len(argmax_indices.shape) == 1:
#             argmax_indices = np.expand_dims(argmax_indices, axis=0)

#         predictions = []
#         for sequence in argmax_indices:
#             decoded = []
#             last_idx = None

#             for idx in sequence[:sequence_length]:
#                 if idx != self.blank_index and idx != last_idx:
#                     decoded.append(self.ind2char[idx])
#                 last_idx = idx

#             text = "".join(decoded)
#             if self.use_bpe and self.processor:
#                 text = self.processor.tokenizer.clean_up_tokenization(text)
#             predictions.append(text)

#         return predictions

#     @staticmethod
#     def normalize_text(text: str) -> str:
#         """Normalize input text."""
#         text = text.lower()
#         text = re.sub(r"[^a-z ]", "", text)
#         return text

#     def test_decoder(self, sample_text: str = "test decoder functionality"):
#         """Test the decoder setup."""
#         print("\nTesting decoder configuration...")

#         encoded = self.encode(sample_text)
#         decoded = self.decode(encoded[0].tolist())
#         print(f"Original text: {sample_text}")
#         print(f"Basic decode: {decoded}")

#         sequence_length = 50
#         vocab_size = len(self)
#         fake_logits = torch.randn(1, sequence_length, vocab_size)
#         fake_length = torch.tensor([sequence_length])

#         if self.decoder is not None:
#             print("\nTesting pyctcdecode integration...")
#             decoded_with_lm = self.ctc_decode(fake_logits)
#             print(f"Decoded with LM: {decoded_with_lm}")

#             print(f"\nBeam width: {self.beam_size}")
#             print(f"LM weight: {self.lm_weight}")
#         else:
#             print("\nNo language model loaded - using basic CTC decoding")
#             basic_decoded = self._basic_ctc_decode(fake_logits.numpy(), sequence_length)
#             print(f"Basic CTC decoded: {basic_decoded[0]}")

#     def __len__(self):
#         """Return the size of the vocabulary."""
#         return len(self.vocab)

#     def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
#         """
#         Perform CTC decoding on logits.

#         Parameters:
#         - logits (Union[torch.Tensor, List[int], np.ndarray]): Logits or token indices.

#         Returns:
#         - str: Decoded text.
#         """
#         if isinstance(logits, np.ndarray):
#             logits = torch.from_numpy(logits)
#         elif isinstance(logits, list):
#             logits = torch.tensor(logits)

#         if logits.dim() == 3:
#             decoded_text = self.decode_logits(logits)
#             return decoded_text
#         elif logits.dim() == 2:
#             decoded_text = self.decode_logits(logits)
#             return decoded_text
#         elif logits.dim() == 1:
#             decoded_text = self.decode_indices(logits)
#             return decoded_text
#         else:
#             raise ValueError(f"Unsupported logits shape: {logits.shape}. Expected 1D, 2D, or 3D.")
        

In [9]:
from transformers import Wav2Vec2Processor


# Load the Wav2Vec2 processor and derive a lowercase label list from its tokenizer.
pretrained_tokenizer = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)

# token -> id mapping as reported by the wrapped tokenizer
vocab_dict = processor.tokenizer.get_vocab()

# Same mapping, re-keyed by lowercase token and inserted in ascending id order
sorted_vocab_dict = dict(
    (tok.lower(), tok_id)
    for tok, tok_id in sorted(vocab_dict.items(), key=lambda pair: pair[1])
)

# Labels keep the tokenizer's '|' word delimiter; vocab shows it as a plain space
labels = list(sorted_vocab_dict)
vocab = [label.replace('|', ' ') for label in labels]

In [41]:
blank_token = "[PAD]"
unk_token = "[UNK]"

# pretrained_tokenizer = "hf-test/xls-r-300m-sv"
# pretrained_tokenizer = "bert-base-uncased"
pretrained_tokenizer = "facebook/wav2vec2-base-960h"

if pretrained_tokenizer == "facebook/wav2vec2-base-960h":
    # The processor wraps the real tokenizer in its `.tokenizer` attribute.
    tokenizer = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)
    inner_tokenizer = tokenizer.tokenizer
    vocab_dict = inner_tokenizer.get_vocab()
    sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
    labels = list(sorted_vocab_dict.keys())
else:
    # AutoTokenizer returns the tokenizer itself (there is no `.tokenizer` wrapper),
    # so the special-token handling below must not dereference `tokenizer.tokenizer`.
    tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer)
    inner_tokenizer = tokenizer
    vocab_dict = inner_tokenizer.get_vocab()
    # Order labels by token id so that list position matches the id.
    labels = [token for token, _ in sorted(vocab_dict.items(), key=lambda item: item[1])]

# Register the blank/unknown tokens and mirror them in `labels`; otherwise
# `blank_index` below is always None because the token is missing from `vocab`.
# NOTE(review): assumes add_tokens assigns the next free id (end of the vocab),
# which is what keeps list position and token id aligned - confirm for other tokenizers.
for special_token in (blank_token, unk_token):
    if special_token not in inner_tokenizer.get_vocab():
        inner_tokenizer.add_tokens([special_token])
        labels.append(special_token)
        print(f"Added '{special_token}' to the tokenizer's vocabulary.")

# '|' is the tokenizer's word delimiter; expose it as a plain space.
vocab = [token.replace('|', ' ') for token in labels]

ind2char = dict(enumerate(vocab))
char2ind = {v: k for k, v in ind2char.items()}
blank_index = char2ind.get(blank_token, None)

Added '[PAD]' to the tokenizer's vocabulary.
Added '[UNK]' to the tokenizer's vocabulary.


In [43]:
tokenizer.tokenizer.total_vocab_size

34

In [30]:
tokenizer.vocab_size

30522

In [23]:
# bert

pretrained_tokenizer = "bert-base-uncased"

tokenizer_bert = AutoTokenizer.from_pretrained(pretrained_tokenizer)

In [27]:
tokenizer_bert.vocab_size
tokenizer_bert.vocab

{'bye': 9061,
 'infections': 15245,
 'potsdam': 26554,
 'therese': 25598,
 'projections': 21796,
 'grassroots': 23299,
 '##tham': 22536,
 'stirring': 18385,
 'organise': 22933,
 '[unused973]': 978,
 '##chman': 19944,
 'ambient': 17093,
 'edouard': 21627,
 'fragments': 10341,
 'reviewer': 12027,
 '##yck': 28377,
 'struck': 4930,
 'utility': 9710,
 'cane': 11942,
 'vital': 8995,
 '##ull': 18083,
 '##dar': 7662,
 'please': 3531,
 'woody': 13703,
 'warden': 13745,
 'available': 2800,
 'toulon': 27160,
 'belong': 7141,
 '##eon': 10242,
 '##cture': 14890,
 'yankees': 11081,
 'nectar': 24816,
 'observer': 9718,
 '##folk': 29284,
 'packaging': 14793,
 '##〈': 30163,
 'codes': 9537,
 'thoughts': 4301,
 'signage': 29404,
 '有': 1873,
 'multiple': 3674,
 'scar': 11228,
 'hiv': 9820,
 'enthusiasts': 20305,
 'plagued': 17808,
 'portuguese': 5077,
 'decca': 21079,
 'ligand': 27854,
 'syndrome': 8715,
 'rode': 8469,
 '##ব': 29904,
 'nap': 18996,
 'avenge': 24896,
 'optic': 22816,
 'seduction': 26962,
 

In [32]:
vocab


['a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z',
 'ä',
 'å',
 'é',
 'ô',
 'ö',
 'ü',
 ' ',
 '[UNK]',
 '[PAD]']

In [10]:
vocab_dict

{'<pad>': 0,
 '<s>': 1,
 '</s>': 2,
 '<unk>': 3,
 '|': 4,
 'E': 5,
 'T': 6,
 'A': 7,
 'O': 8,
 'N': 9,
 'I': 10,
 'H': 11,
 'S': 12,
 'R': 13,
 'D': 14,
 'L': 15,
 'U': 16,
 'M': 17,
 'W': 18,
 'C': 19,
 'F': 20,
 'G': 21,
 'Y': 22,
 'P': 23,
 'B': 24,
 'V': 25,
 'K': 26,
 "'": 27,
 'X': 28,
 'J': 29,
 'Q': 30,
 'Z': 31}

In [20]:
# %cd ..
%pwd

/workspace/sound_asr


'/workspace/sound_asr'

In [13]:
def test_encoder(encoder_cls, text="hello world"):
    """
    Smoke-test an encoder class over every use_bpe / use_lm combination.

    For each combination a fresh encoder is built, `text` is encoded and
    greedily decoded, and a random logits tensor is decoded through the
    LM path (`ctc_decode`) or the plain path (`decode_logits`). Results
    are only printed; nothing is returned.

    Parameters:
    - encoder_cls: Class of the encoder to test.
    - text (str): The input text to encode and decode.
    """
    # (use_bpe, use_lm) pairs in the order they are exercised
    flag_grid = [(False, False), (True, False), (False, True), (True, True)]

    for bpe_flag, lm_flag in flag_grid:
        print(f"\n--- Testing with use_bpe={bpe_flag} and use_lm={lm_flag} ---")

        encoder = encoder_cls(
            use_bpe=bpe_flag,
            use_lm=lm_flag,
            pretrained_tokenizer="facebook/wav2vec2-base-960h",
            binary_path="4-gram_lc_correct.bin",
            unigram_path="librispeech-vocab.txt",
        )

        # Round trip: text -> indices -> text
        token_ids = encoder.encode(text)
        print(f"Encoded: {token_ids}")

        round_trip = encoder.decode_simple(token_ids[0].tolist())
        print(f"Decoded: {round_trip}")

        # Random logits shaped (1, time, vocab) stand in for model output
        n_steps = token_ids.shape[1]
        n_classes = len(encoder.vocab)
        random_logits = torch.randn(1, n_steps, n_classes)

        if not lm_flag:
            decoded_without_lm = encoder.decode_logits(random_logits)
            print(f"Decoded without LM: {decoded_without_lm}")
            continue

        decoded_with_lm = encoder.ctc_decode(random_logits)
        print(f"Decoded with LM: {decoded_with_lm}")

# Example usage
test_encoder(CTCTextEncoder)



--- Testing with use_bpe=False and use_lm=False ---
CTC Text Encoder:
pretrained_tokenizer: facebook/wav2vec2-base-960h
lm_weight: 0.5
beam_size: 100
binary_path: 4-gram_lc_correct.bin
use_lm: False
use_bpe: False
Loading unigrams from: librispeech-vocab.txt
Loaded 200000 unigrams
Initializing character-based vocabulary without using tokenizer.

Vocabulary Info:
Size: 29
Full Vocabulary (up to first 50 tokens): ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '[pad]', '[unk]']
Blank token: [pad], Blank index: 27
Sample ind2char mappings: {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j'}
Sample char2ind mappings: {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9}
Language model usage is disabled.
Encoded: tensor([[ 7,  4, 11, 11, 14, 26, 22, 14, 17, 11,  3]])
Decoded: hello world
Decoded without LM: ojxqzbu es

--- Testing with use_bpe=True

Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?


Successfully initialized language model and decoder.
Encoded: tensor([[3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3]])
Decoded: <unk><unk><unk><unk><unk> <unk><unk><unk><unk><unk>
Decoded with LM: olivo


## Latest

In [6]:
%cd ..

/workspace/sound_asr


In [32]:
import re
from collections import defaultdict
import torch
import kenlm
from transformers import Wav2Vec2Processor, AutoProcessor, AutoTokenizer
from pyctcdecode import build_ctcdecoder
import numpy as np
import os

from typing import List, Tuple, Optional, Union

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CTCTextEncoder:
    def __init__(
        self,
        arpa_path: Optional[str] = None,
        # binary_path: Optional[str] = None,
        # unigram_path: Optional[str] = None,
        binary_path: Optional[str] = "4-gram_lc_correct.bin",
        unigram_path: Optional[str] = "librispeech-vocab.txt",
        pretrained_tokenizer: str = "facebook/wav2vec2-base-960h",
        lm_weight: float = 0.5,
        beam_size: int = 100,
        use_lm: bool = False,     # **Added use_lm parameter**
        use_bpe: bool = False,    # **Added use_bpe parameter**
        # blank_token: str = "<pad>",  # Blank token as <pad> for Wav2Vec2
        # unk_token: str = "<unk>",     # UNK token
        blank_token: str = "[PAD]",
        unk_token: str = "[UNK]",
        **kwargs
    ):
        """
        Initialize encoder with conditional tokenizer/processor and language model.

        Parameters:
        - use_lm (bool): Whether to use the Language Model (LM) during decoding.
        - use_bpe (bool): Whether to use Byte Pair Encoding (BPE) via tokenizer/processor.
                           If False, perform character-based encoding/decoding without tokenizer.
        """
        self.beam_size = beam_size
        self.lm_weight = lm_weight
        self.arpa_path = arpa_path
        self.binary_path = binary_path
        self.blank_token = blank_token
        self.unk_token = unk_token
        self.use_lm = use_lm # False
        self.use_bpe = use_bpe # False MANUAL FOR NOW
        # self.use_bpe = False # use_bpe # False # use_bpe
        self.printed_samples = 0
        self.max_printed_samples = 5
        print('CTC Text Encoder:')
        print('pretrained_tokenizer:', pretrained_tokenizer)
        print('lm_weight:', lm_weight)
        print('beam_size:', beam_size)
        print('binary_path:', binary_path)
        print('use_lm:', self.use_lm)
        print('use_bpe:', self.use_bpe)
        self.pretrained_tokenizer = pretrained_tokenizer

        # Load unigrams if provided
        self.unigrams = None
        if unigram_path and os.path.exists(unigram_path):
            print(f"Loading unigrams from: {unigram_path}")
            with open(unigram_path, 'r', encoding='utf-8') as f:
                self.unigrams = [line.strip().lower() for line in f if line.strip()]
            print(f"Loaded {len(self.unigrams)} unigrams")

        # Initialize the tokenizer or set up character-based vocab
        self._initialize_vocabulary(pretrained_tokenizer)

        # Create index mappings
        self.ind2char = dict(enumerate(self.vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        self.blank_index = self.char2ind.get(self.blank_token, None)

        print(f"\nVocabulary Info:")
        print(f"Size: {len(self.vocab)}")
        print("Full Vocabulary (up to first 50 tokens):", self.vocab[:50])
        print(f"Blank token: {self.blank_token}, Blank index: {self.blank_index}")

        print("Sample ind2char mappings:", {k: self.ind2char[k] for k in list(self.ind2char.keys())[:10]})
        print("Sample char2ind mappings:", {k: self.char2ind[k] for k in list(self.char2ind.keys())[:10]})

        # **Conditionally initialize language model based on use_lm**
        if self.use_lm:
            self._initialize_language_model()
        else:
            print("Language model usage is disabled.")
            self.lm = None
            self.decoder = None

    
    # def _initialize_vocabulary(self, pretrained_tokenizer: str):
    #     """
    #     Initialize the vocabulary either using a tokenizer/processor (BPE) or character-based.

    #     Parameters:
    #     - pretrained_tokenizer (str): Name or path of the pretrained tokenizer.
    #     """
    #     if self.use_bpe:
    #         print("Initializing tokenizer and using BPE for encoding/decoding.")
    #         if self.pretrained_tokenizer == "facebook/wav2vec2-base-960h":
    #             self.tokenizer = Wav2Vec2Processor.from_pretrained(pretrained_tokenizer)
    #             vocab_dict = self.tokenizer.tokenizer.get_vocab()
    #             sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
    #             self.labels = list(sorted_vocab_dict.keys())
    #             self.vocab = [token.replace('|', ' ') for token in self.labels]

    #         else:
    #             self.tokenizer = None
    #             self.tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer)
    #             vocab_dict = self.tokenizer.vocab
    #             # self.labels = list(vocab_dict.keys())

    #             sorted_vocab = sorted(vocab_dict.items(), key=lambda item: item[1])  # Sort by index
    #             self.vocab = [token for token, _ in sorted_vocab]  # Preserve order
    #             self.labels = self.vocab  # Keep alignment with tokenizer
    #             self.vocab = [token.replace('|', ' ') for token in self.labels]

                
            
            

    #         if self.blank_token not in self.tokenizer.get_vocab():
    #             self.tokenizer.add_tokens([self.blank_token])
    #             print(f"Added '{self.blank_token}' to the tokenizer's vocabulary.")

    #         if self.unk_token not in self.tokenizer.get_vocab():
    #             self.tokenizer.add_tokens([self.unk_token])
    #             print(f"Added '{self.unk_token}' to the tokenizer's vocabulary.")

    #         self.ind2char = dict(enumerate(self.vocab))
    #         self.char2ind = {v: k for k, v in self.ind2char.items()}
    #         self.blank_index = self.char2ind.get(self.blank_token, None)
    #     else:
    #         print("Initializing character-based vocabulary without using tokenizer.")
    #         self.vocab = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
    #                     'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
    #                     'y', 'z', ' ']
    #         # self.vocab += [self.blank_token, self.unk_token]
    #         self.vocab = [self.blank_token] + self.vocab
    #         self.ind2char = dict(enumerate(self.vocab))
    #         self.char2ind = {v: k for k, v in self.ind2char.items()}
    #         self.blank_index = self.char2ind.get(self.blank_token, None)
    #         self.tokenizer = None



    def _initialize_vocabulary(self, pretrained_tokenizer: str):
        """
        Initialize the vocabulary using BPE or character-based approach.
        """
        if self.use_bpe:
            print("Initializing tokenizer and using BPE for encoding/decoding.")
            
            # Load tokenizer (AutoTokenizer supports most HF models)
            self.tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer)
            
            # Get the tokenizer's vocabulary
            vocab_dict = self.tokenizer.get_vocab()
            sorted_vocab = sorted(vocab_dict.items(), key=lambda item: item[1])  # Sort by token ID
            
            # Ensure true BPE subword tokens are used
            self.vocab = [token for token, _ in sorted_vocab]
            print(f"Loaded BPE vocabulary with {len(self.vocab)} tokens.")
            
            # Add blank and unknown tokens if not already present
            if self.blank_token not in self.vocab:
                self.tokenizer.add_tokens([self.blank_token])
                self.vocab.append(self.blank_token)
                print(f"Added '{self.blank_token}' to the vocabulary.")

            if self.unk_token not in self.vocab:
                self.tokenizer.add_tokens([self.unk_token])
                self.vocab.append(self.unk_token)
                print(f"Added '{self.unk_token}' to the vocabulary.")

            # Create index mappings
            self.ind2char = dict(enumerate(self.vocab))
            self.char2ind = {v: k for k, v in self.ind2char.items()}
            self.blank_index = self.char2ind.get(self.blank_token, None)
            
            print(f"Sample vocabulary: {self.vocab[:10]}")

            self.labels = self.vocab

        else:
            print("Initializing character-based vocabulary without using tokenizer.")
            self.vocab = [' '] + list("abcdefghijklmnopqrstuvwxyzäåéôöü")
            self.vocab += [self.unk_token, self.blank_token]
            self.ind2char = dict(enumerate(self.vocab))
            self.char2ind = {v: k for k, v in self.ind2char.items()}
            self.blank_index = self.char2ind.get(self.blank_token, None)
            self.tokenizer = None




    # def encode(self, text: str) -> torch.Tensor:
    #     """
    #     Encode text either using tokenizer/processor (BPE) or character-based encoding.

    #     Parameters:
    #     - text (str): The input text to encode.

    #     Returns:
    #     - torch.Tensor: Tensor of token indices.
    #     """
    #     if self.use_bpe:
    #         if self.pretrained_tokenizer == "facebook/wav2vec2-base-960h":
    #             text = text.upper() # convert to upper for Wav2Vec2 vocab
    #         encoded = self.tokenizer(text, return_tensors="pt", padding=False, truncation=False)
    #         token_indices = encoded.input_ids[0].tolist()
    #         return torch.tensor(token_indices).unsqueeze(0)
    #     else:
    #         normalized_text = self.normalize_text(text)
    #         token_indices = [self.char2ind.get(char, self.char2ind.get(self.unk_token)) for char in normalized_text]
    #         return torch.tensor(token_indices).unsqueeze(0)




    def encode(self, text: str) -> torch.Tensor:
        """
        Encode text either using tokenizer/processor (BPE) or character-based encoding.

        Parameters:
        - text (str): The input text to encode.

        Returns:
        - torch.Tensor: Tensor of token indices.
        """
        if self.use_bpe:
            encoded = self.tokenizer(text, return_tensors="pt", padding=False, truncation=False, add_special_tokens=False)
            token_indices = encoded["input_ids"][0].tolist()
            print(f"[DEBUG] BPE Encoding: {text} -> {token_indices}")
            return torch.tensor(token_indices).unsqueeze(0)
        else:
            normalized_text = self.normalize_text(text)
            token_indices = [self.char2ind.get(char, self.char2ind.get(self.unk_token)) for char in normalized_text]
            print(f"[DEBUG] Character Encoding: {text} -> {token_indices}")
            return torch.tensor(token_indices).unsqueeze(0)



    def decode_simple(self, indices: List[int]) -> str:
        """
        Simple CTC decoding without language model.
        Collapses consecutive duplicate tokens and removes blanks.

        Parameters:
        - indices (List[int]): List of token indices.

        Returns:
        - str: Decoded text.
        """
        decoded_chars = []
        previous_idx = None

        for idx in indices:
            if idx == self.blank_index:
                previous_idx = idx
                continue  # Skip blank tokens
            if idx == previous_idx:                   ##### TO TEST DELETION!!! #####
                continue  # Skip duplicate tokens     ##### TO TEST DELETION!!! #####
            if 0 <= idx < len(self.ind2char):
                char = self.ind2char[idx]
                decoded_chars.append(char)
            previous_idx = idx

        # Join characters without spaces and convert to lowercase
        text = "".join(decoded_chars).strip().lower()

        if self.use_bpe and self.tokenizer:
            return self.tokenizer.clean_up_tokenization(text)
        else:
            return text




    def _initialize_language_model(self):
        """Initialize language model with explicit blank token handling."""
        self.lm = None
        self.decoder = None

        model_path = self.binary_path if self.binary_path else self.arpa_path
        print('model_path: ', model_path)
        if not model_path or not os.path.exists(model_path):
            print("No language model path provided or file does not exist.")
            return

        try:
            self.lm = kenlm.Model(model_path)
            print(f"Loaded {'binary' if self.binary_path else 'ARPA'} language model.")

            decoder_config = {
                "labels": self.labels if self.use_bpe else self.vocab,
                "kenlm_model_path": model_path,
                "alpha": self.lm_weight,
                "beta": 0.1,
                "unk_score_offset": -10.0,
                
            }

            if self.unigrams:
                print("\n--- Unigrams List ---")
                # Save the unigrams to a file
                with open("unigrams_list.txt", "w", encoding="utf-8") as f:
                    for unigram in self.unigrams:
                        f.write(f"{unigram}\n")
                print(f"Unigrams list saved to 'unigrams_list.txt'. Total unigrams: {len(self.unigrams)}")
                print("----------------------\n")
                decoder_config["unigrams"] = self.unigrams

            self.decoder = build_ctcdecoder(**decoder_config)
            print("Successfully initialized language model and decoder.")

        except Exception as e:
            print(f"Warning: Failed to initialize decoder: {str(e)}")
            self.decoder = None

    
    def decode(self, indices: List[int]) -> str:
        """
        Decode indices to text using beam search decoder if available.

        Parameters:
        - indices (List[int]): List of token indices.

        Returns:
        - str: Decoded text.
        """
        if self.decoder:
            decoded_text = self.decoder.decode(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text
        else:
            decoded_text = self.decode_simple(indices)
            # Convert to lower case
            decoded_text = decoded_text.lower()
            return decoded_text  # Ensure the decoded text is returned

    
    def decode_logits(self, logits: Union[torch.Tensor, List[List[float]], np.ndarray]) -> str:
        """
        Decode logits using the decoder if available, otherwise use greedy decoding.

        Parameters:
        - logits (Union[torch.Tensor, List[List[float]], np.ndarray]): Logits output from the model.

        Returns:
        - str: Decoded text.
        """
        if isinstance(logits, torch.Tensor):
            logits = logits.cpu().numpy()
        elif isinstance(logits, list):
            logits = np.array(logits)
        elif not isinstance(logits, np.ndarray):
            raise TypeError("logits must be a torch.Tensor, list of lists, or numpy.ndarray")

        if logits.ndim == 3:
            logits = logits[0]

        if logits.ndim != 2:
            raise ValueError(f"Logits should be 2D (time_steps, vocab_size), got {logits.ndim}D")

        if self.decoder:
            decoded_text = self.decoder.decode(logits)
            return decoded_text
        else:
            predicted_indices = np.argmax(logits, axis=-1).tolist()
            return self.decode_simple(predicted_indices)

    def decode_indices(self, indices: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Decode token indices to text using simple decoding (no LM).

        Parameters:
        - indices (Union[torch.Tensor, List[int], np.ndarray]): Token indices.

        Returns:
        - str: Decoded text.
        """
        if isinstance(indices, torch.Tensor):
            indices = indices.squeeze().tolist()
        elif isinstance(indices, np.ndarray):
            indices = indices.tolist()
        elif not isinstance(indices, list):
            raise TypeError("decode_indices expects a list, torch.Tensor, or numpy.ndarray.")

        return self.decode_simple(indices)

    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Perform CTC decoding on logits.

        NOTE: this class defines `ctc_decode` a second time further down; the
        later definition replaces this one when the class body is executed, so
        this version is never the one that runs.

        Arrays and lists are converted to tensors and a 3D input is reduced to
        its first element. In BPE mode the LM decoder is used when enabled and
        available, otherwise the per-step argmax goes through `decode`; in
        character mode the per-step argmax goes through `decode_simple`.

        Parameters:
        - logits (Union[torch.Tensor, List[int], np.ndarray]): Logits or token indices.

        Returns:
        - str: Decoded text.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 3:
            logits = logits[0]  # Reduce to 2D (sequence length, vocab size)

        if self.use_bpe:
            if self.use_lm and self.decoder:
                # Use LM if available
                # NOTE(review): `logits` is a torch.Tensor at this point - confirm
                # the decoder accepts a tensor rather than a numpy array.
                return self.decoder.decode(logits)
            else:
                # Use tokenizer-based decoding
                predicted_indices = torch.argmax(logits, axis=-1).tolist()
                return self.decode(predicted_indices)
        else:
            # Use character-based decoding
            predicted_indices = torch.argmax(logits, axis=-1).tolist()
            return self.decode_simple(predicted_indices)

    

    def ctc_beam_search(self, probs, beam_size, use_lm = False, debug: bool = False) -> List[Tuple[str, float]]:
        """
        Beam search with optional Language Model support.

        Parameters:
        - probs: Probability distributions over tokens.
        - debug (bool): Whether to print debug information.

        Returns:
        - List[Tuple[str, float]]: List of decoded text with scores.
        """
        beam_size = self.beam_size
        debug = False

        if self.use_lm and self.decoder is not None:
            try:
                if isinstance(probs, torch.Tensor):
                    probs = probs.cpu().numpy()
                elif isinstance(probs, list):
                    probs = np.array(probs)
                elif isinstance(probs, np.ndarray):
                    pass
                else:
                    raise TypeError("probs must be a torch.Tensor, list, or numpy.ndarray")

                beams = self.decoder.decode_beams(
                    probs,
                    beam_prune_logp=-10.0,
                    token_min_logp=-5.0,
                    hotwords=[],
                    hotword_weight=10.0,
                )

                formatted_beams = []
                for beam in beams[:beam_size]:
                    text = beam[0]
                    acoustic_score = beam[3]
                    lm_score = beam[4]

                    if self.use_bpe and self.tokenizer:
                        text = self.tokenizer.clean_up_tokenization(text)
                    text = text.lower().strip()

                    combined_score = (1 - self.lm_weight) * acoustic_score + self.lm_weight * lm_score
                    text_len = max(1, len(text.split())) if self.use_bpe else max(1, len(text))
                    normalized_score = combined_score / text_len

                    formatted_beams.append((text, normalized_score))

                if debug:
                    print("\nFormatted beam results with Language Model:")
                    for text, score in formatted_beams[:3]:
                        print(f"Text: '{text}', Score: {score:.4f}")

                if formatted_beams:
                    return sorted(formatted_beams, key=lambda x: -x[1])
                else:
                    print("No valid beams found, falling back to standard beam search")
                    return self._standard_beam_search(probs, debug)

            except Exception as e:
                print(f"Beam search with LM failed: {str(e)}, falling back to standard beam search")
                return self._standard_beam_search(probs, debug)
        else:
            return self._standard_beam_search(probs, debug)

    def _standard_beam_search(self, probs, debug: bool = False) -> List[Tuple[str, float]]:
        """
        Original beam search implementation without Language Model.

        Parameters:
        - probs: Probability distributions over tokens.
        - debug (bool): Whether to print debug information.

        Returns:
        - List[Tuple[str, float]]: List of decoded text with scores.
        """
        beam_size = self.beam_size

        if isinstance(probs, np.ndarray):
            probs = torch.from_numpy(probs)

        if probs.device != torch.device('cpu'):
            probs = probs.cpu()

        dp = {("", self.blank_token): 0.0}
        log_probs = torch.log(probs + 1e-8)

        if debug:
            print("\nStarting beam search with beam size:", beam_size)

        for t, prob in enumerate(log_probs):
            new_dp = defaultdict(lambda: float('-inf'))
            top_k = torch.topk(prob, k=min(beam_size, len(prob)))

            if debug and t < self.max_printed_samples:
                print(f"\nTimestep {t}:")
                print("Top tokens:", [(self.ind2char[idx.item()], val.item()) 
                                    for val, idx in zip(top_k.values, top_k.indices)])

            for val, ind in zip(top_k.values, top_k.indices):
                curr_char = self.ind2char[ind.item()]
                next_token_log_prob = val.item()

                for (prefix, last_char), log_prob in dp.items():
                    if last_char == curr_char and curr_char != " ":
                        new_prefix = prefix
                    else:
                        if curr_char != self.blank_token:
                            if curr_char == " " and prefix.endswith(" "):
                                continue
                            new_prefix = prefix + curr_char
                        else:
                            new_prefix = prefix

                    new_log_prob = log_prob + next_token_log_prob
                    key = (new_prefix, curr_char)
                    new_dp[key] = max(new_dp[key], new_log_prob)

            if len(new_dp) > 0:
                max_score = max(score for _, score in new_dp.items())
                new_dp = {key: score - max_score for key, score in new_dp.items()}

            dp = dict(sorted(new_dp.items(), key=lambda x: -x[1])[:beam_size])

            if debug and t < 2:
                print("\nCurrent beam:")
                for (text, last_char), score in list(dp.items())[:3]:
                    print(f"Text: '{text}', Last: '{last_char}', Score: {score:.4f}")

        final_beams = []
        for (text, _), score in dp.items():
            if self.use_bpe and self.tokenizer:
                text = self.tokenizer.clean_up_tokenization(text)
            text = text.lower().strip()
            text_len = max(1, len(text.split())) if self.use_bpe else max(1, len(text))
            normalized_score = score / text_len
            final_beams.append((text, normalized_score))

        final_beams.sort(key=lambda x: -x[1])
        if not final_beams:
            final_beams = [("", float('-inf'))]

        return final_beams[:beam_size]

    def test_language_model(self):
        """
        Print language-model scores for a few fixed sentences and completions.

        Does nothing beyond reporting an error when no language model is
        loaded. Otherwise each sample sentence is scored with `score_with_lm`,
        and for each prefix the three best-scoring of five fixed one-word
        completions are printed.
        """
        print("\nTesting Language Model...")

        if self.lm is None:
            print("Error: Language model is not loaded!")
            return

        sample_sentences = (
            "this is a good sentence",
            "this is also a good sentence",
            "thiss iss nott aa goodd sentencee",
            "random word salad box cat",
            "the cat sat on the mat",
            "",
            "a",
        )

        print("\nTesting individual sentences:")
        for sentence in sample_sentences:
            lm_value = self.score_with_lm(sentence)
            print(f"\nText: '{sentence}'")
            print(f"LM Score: {lm_value:.4f}")

        prefixes = ("the quick brown", "how are", "thank", "nice to")
        candidate_words = ("you", "fox", "cat", "xyz", "meet")

        print("\nTesting word completions:")
        for prefix in prefixes:
            print(f"\nPrefix: '{prefix}'")
            phrases = [prefix + " " + word for word in candidate_words]
            # Highest score first; equal scores keep their original order
            ranked = sorted(
                ((phrase, self.score_with_lm(phrase)) for phrase in phrases),
                key=lambda pair: pair[1],
                reverse=True,
            )
            print("Top completions by score:")
            for phrase, value in ranked[:3]:
                print(f"  '{phrase}': {value:.4f}")

    def score_with_lm(self, text: str) -> float:
        """
        Score text using language model, handling edge cases.

        Parameters:
        - text (str): The input text to score.

        Returns:
        - float: LM score.
        """
        if self.lm is None:
            return 0.0

        if not text or len(text.strip()) == 0:
            return float('-inf')

        text = text.lower().strip()
        return self.lm.score(text, bos=True, eos=True)

    def _basic_ctc_decode(self, logits: np.ndarray, sequence_length: int) -> List[str]:
        """
        Basic CTC decoding without LM.

        Parameters:
        - logits (np.ndarray): Logits from the model.
        - sequence_length (int): Length of the sequence to decode.

        Returns:
        - List[str]: Decoded text.
        """
        argmax_indices = np.argmax(logits, axis=-1)

        if len(argmax_indices.shape) == 0:
            argmax_indices = np.array([argmax_indices])

        if len(argmax_indices.shape) == 1:
            argmax_indices = np.expand_dims(argmax_indices, axis=0)

        predictions = []
        for sequence in argmax_indices:
            decoded = []
            last_idx = None

            for idx in sequence[:sequence_length]:
                if idx != self.blank_index and idx != last_idx:
                    decoded.append(self.ind2char[idx])
                last_idx = idx

            text = "".join(decoded)
            if self.use_bpe and self.tokenizer:
                text = self.tokenizer.clean_up_tokenization(text)
            predictions.append(text)

        return predictions

    @staticmethod
    def normalize_text(text: str) -> str:
        """Normalize input text."""
        text = text.lower()
        text = re.sub(r"[^a-z ]", "", text)
        return text

    def test_decoder(self, sample_text: str = "test decoder functionality"):
        """
        Print a round trip of `sample_text` and a decode of random logits.

        The sample is encoded and decoded again with `decode`. Then random
        logits of shape (1, 50, vocabulary size) are decoded with `ctc_decode`
        when a decoder is present, otherwise with `_basic_ctc_decode`.
        """
        print("\nTesting decoder configuration...")

        token_ids = self.encode(sample_text)[0].tolist()
        round_trip = self.decode(token_ids)
        print(f"Original text: {sample_text}")
        print(f"Basic decode: {round_trip}")

        # Random stand-in for model output
        n_frames = 50
        random_logits = torch.randn(1, n_frames, len(self))

        if self.decoder is None:
            print("\nNo language model loaded - using basic CTC decoding")
            greedy_texts = self._basic_ctc_decode(random_logits.numpy(), n_frames)
            print(f"Basic CTC decoded: {greedy_texts[0]}")
            return

        print("\nTesting pyctcdecode integration...")
        lm_text = self.ctc_decode(random_logits)
        print(f"Decoded with LM: {lm_text}")

        print(f"\nBeam width: {self.beam_size}")
        print(f"LM weight: {self.lm_weight}")

    def __len__(self):
        """Return the size of the vocabulary."""
        return len(self.vocab)

    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Perform CTC decoding on logits.

        Parameters:
        - logits (Union[torch.Tensor, List[int], np.ndarray]): Logits or token indices.

        Returns:
        - str: Decoded text.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 3:
            decoded_text = self.decode_logits(logits)
            return decoded_text
        elif logits.dim() == 2:
            decoded_text = self.decode_logits(logits)
            return decoded_text
        elif logits.dim() == 1:
            decoded_text = self.decode_indices(logits)
            return decoded_text
        else:
            raise ValueError(f"Unsupported logits shape: {logits.shape}. Expected 1D, 2D, or 3D.")

## Updated WV2

In [14]:
import re
from collections import defaultdict
import torch
import kenlm
from pyctcdecode import build_ctcdecoder
import numpy as np
import os
import sentencepiece as spm

from typing import List, Tuple, Optional, Union

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CTCTextEncoder:
    """
    Text <-> token-index codec for CTC acoustic models.

    Two vocabularies are supported:
      - character-based (blank token followed by ``a``-``z`` and space) when
        ``use_bpe=False``;
      - SentencePiece subwords when ``use_bpe=True``; ``pretrained_tokenizer``
        is then the path of the SentencePiece model file.

    When ``use_lm=True`` a KenLM model is loaded and wrapped in a pyctcdecode
    beam-search decoder (``self.decoder``); otherwise decoding is greedy.
    """

    def __init__(
        self,
        arpa_path: Optional[str] = None,
        binary_path: Optional[str] = "4-gram_lc_correct.bin",
        unigram_path: Optional[str] = "librispeech-vocab.txt",
        pretrained_tokenizer: str = "facebook/wav2vec2-base-960h",
        lm_weight: float = 0.5,
        beam_size: int = 100,
        use_lm: bool = False,
        use_bpe: bool = False,
        blank_token: str = "[PAD]",
        unk_token: str = "[UNK]",
        **kwargs
    ):
        """
        CTCTextEncoder can do:
          - Character-based encoding if use_bpe=False
          - SentencePiece subword encoding if use_bpe=True

        Parameters:
        - arpa_path: ARPA language-model file, used only if binary_path is falsy.
        - binary_path: KenLM binary language-model file (preferred over arpa_path).
        - unigram_path: text file with one unigram per line; skipped if missing.
        - pretrained_tokenizer: SentencePiece model path (only read when use_bpe=True).
        - lm_weight: passed to pyctcdecode as ``alpha``.
        - beam_size: beam width used for every LM beam-search decode.
        - use_lm: build the KenLM/pyctcdecode decoder.
        - use_bpe: use SentencePiece subwords instead of characters.
        - blank_token: CTC blank symbol (index 0 of the character vocabulary).
        - unk_token: unknown-token symbol (not part of the character vocabulary).
        - **kwargs: accepted and ignored.
        """
        self.beam_size = beam_size
        self.lm_weight = lm_weight
        self.arpa_path = arpa_path
        self.binary_path = binary_path
        self.blank_token = blank_token
        self.unk_token = unk_token
        self.use_lm = use_lm
        self.use_bpe = use_bpe
        self.printed_samples = 0
        self.max_printed_samples = 5
        self.pretrained_tokenizer = pretrained_tokenizer

        print("CTC Text Encoder:")
        print("pretrained_tokenizer:", pretrained_tokenizer)
        print("lm_weight:", lm_weight)
        print("beam_size:", beam_size)
        print("binary_path:", binary_path)
        print("use_lm:", self.use_lm)
        print("use_bpe:", self.use_bpe)

        # Load unigrams if provided
        self.unigrams = None
        if unigram_path and os.path.exists(unigram_path):
            print(f"Loading unigrams from: {unigram_path}")
            with open(unigram_path, 'r', encoding='utf-8') as f:
                self.unigrams = [line.strip().lower() for line in f if line.strip()]
            print(f"Loaded {len(self.unigrams)} unigrams")

        # Initialize the tokenizer or set up character-based vocab
        self._initialize_vocabulary(pretrained_tokenizer)

        # Create index mappings
        self.ind2char = dict(enumerate(self.vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        # None when the vocabulary has no blank entry (possible in BPE mode).
        self.blank_index = self.char2ind.get(self.blank_token, None)

        print("\nVocabulary Info:")
        print(f"Size: {len(self.vocab)}")
        print("Full Vocabulary (up to first 50 tokens):", self.vocab[:50])
        print(f"Blank token: {self.blank_token}, Blank index: {self.blank_index}")
        print("Sample ind2char mappings:", dict(list(self.ind2char.items())[:10]))
        print("Sample char2ind mappings:", dict(list(self.char2ind.items())[:10]))

        # Conditionally init LM; both attributes always exist afterwards.
        self.lm = None
        self.decoder = None
        if self.use_lm:
            self._initialize_language_model()
        else:
            print("Language model usage is disabled.")

    def _initialize_vocabulary(self, pretrained_tokenizer: str):
        """
        If use_bpe=True => load a SentencePiece model from `pretrained_tokenizer`.
        Otherwise => a standard char-based vocab.

        Sets ``self.vocab``, ``self.labels`` and ``self.sp``.

        Raises:
        - FileNotFoundError: in BPE mode when the model file does not exist.
        """
        if not self.use_bpe:
            print("Initializing character-based vocabulary without using tokenizer.")
            # Blank goes first so that it gets index 0.
            self.vocab = [self.blank_token] + list("abcdefghijklmnopqrstuvwxyz ")
            self.labels = self.vocab
            self.sp = None
            return

        print("Initializing subword tokenizer (SentencePiece) for encoding/decoding.")
        if not os.path.exists(pretrained_tokenizer):
            raise FileNotFoundError(
                f"[ERROR] SentencePiece model file not found: {pretrained_tokenizer}"
            )
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(pretrained_tokenizer)

        vocab_size = self.sp.get_piece_size()
        print(f"Loaded SentencePiece model with vocab size={vocab_size}")

        # For LM usage, store the subwords in self.labels
        self.labels = [self.sp.id_to_piece(i) for i in range(vocab_size)]
        self.vocab = self.labels  # Keep indexing consistent

    def encode(self, text: str) -> torch.Tensor:
        """
        Encode text => token IDs, returned as a long tensor of shape [1, seq_len].

        If use_bpe => subword via SentencePiece; else => char-based after
        ``normalize_text``.
        """
        if self.use_bpe and self.sp is not None:
            # The text is passed as-is; casing must match how the
            # SentencePiece model was trained.
            token_ids = self.sp.encode(text, out_type=int)
            return torch.tensor([token_ids], dtype=torch.long)

        # Char-based. The unknown token is not in the character vocabulary,
        # so a character without an index is dropped rather than mapped to None.
        unk_index = self.char2ind.get(self.unk_token)
        token_indices = []
        for ch in self.normalize_text(text):
            index = self.char2ind.get(ch, unk_index)
            if index is not None:
                token_indices.append(index)
        # Explicit dtype keeps an empty encoding integer-typed too.
        return torch.tensor(token_indices, dtype=torch.long).unsqueeze(0)

    def _collapse_ctc(self, indices: List[int]) -> List[int]:
        """Apply the CTC rule: merge consecutive repeats, then drop blanks."""
        collapsed = []
        previous = None
        for idx in indices:
            if idx != self.blank_index and idx != previous:
                collapsed.append(idx)
            previous = idx
        return collapsed

    def _lm_decode(self, logits: np.ndarray) -> str:
        """Beam-search a [T, V] numpy matrix with the configured beam size."""
        return self.decoder.decode(logits, beam_width=self.beam_size)

    def decode_simple(self, indices: List[int]) -> str:
        """
        Simple CTC decoding: collapse repeated tokens, remove blank.

        In BPE mode the collapsed ids are handed to SentencePiece for proper
        subword -> text reconstruction. In character mode the symbols are
        joined, stripped and lowercased; out-of-range indices are ignored.

        Note that repeated symbols not separated by a blank are merged, so
        this is not an inverse of ``encode`` for text with double letters.
        """
        collapsed = self._collapse_ctc(indices)
        if self.use_bpe and self.sp is not None:
            return self.sp.decode([int(idx) for idx in collapsed])

        vocab_size = len(self.ind2char)
        symbols = [self.ind2char[idx] for idx in collapsed if 0 <= idx < vocab_size]
        return "".join(symbols).strip().lower()

    def decode(self, indices: List[int]) -> str:
        """
        Decode a sequence of token indices to lowercased, stripped text.

        The language-model decoder is deliberately not used here: pyctcdecode
        runs beam search over a [T, V] logit matrix and cannot consume a list
        of token ids. Use ``decode_logits`` for LM decoding.
        """
        return self.decode_simple(indices).lower().strip()

    def decode_logits(self, logits: Union[torch.Tensor, List[List[float]], np.ndarray]) -> str:
        """
        If LM is present, decode with beam search. Otherwise greedy decode.

        Accepts [T, V] or [B, T, V] input; only the first batch item is decoded.

        Raises:
        - TypeError: for unsupported input types.
        - ValueError: if the input is not 2D after removing the batch axis.
        """
        if isinstance(logits, torch.Tensor):
            # detach() so tensors that still track gradients can be converted.
            logits = logits.detach().cpu().numpy()
        elif isinstance(logits, list):
            logits = np.array(logits)
        elif not isinstance(logits, np.ndarray):
            raise TypeError("logits must be a torch.Tensor, list of lists, or np.ndarray.")

        if logits.ndim == 3:
            logits = logits[0]
        if logits.ndim != 2:
            raise ValueError(f"Logits should be 2D, got shape={logits.shape}")

        if self.decoder:
            return self._lm_decode(logits)
        indices = np.argmax(logits, axis=-1).tolist()
        return self.decode_simple(indices)

    def decode_indices(self, indices: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Naive decode from token indices (no LM).

        Tensors and arrays are flattened, so a one-element input still yields
        a list rather than a bare scalar.
        """
        if isinstance(indices, (torch.Tensor, np.ndarray)):
            indices = indices.reshape(-1).tolist()
        return self.decode_simple(indices)

    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        CTC decode from logits or direct token IDs.

        - 1D input: token ids, decoded with ``decode_indices``.
        - 2D / 3D input: logits ([T, V] or [B, T, V], first batch item only).
          BPE mode uses LM beam search when a decoder is loaded, otherwise a
          greedy decode that is lowercased and stripped. Character mode is
          always greedy here; use ``decode_logits`` for LM beam search.

        Raises:
        - ValueError: if the input has any other rank.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 1:
            return self.decode_indices(logits)
        if logits.dim() == 3:
            logits = logits[0]
        if logits.dim() != 2:
            raise ValueError(
                f"Unsupported logits shape: {tuple(logits.shape)}. Expected 1D, 2D, or 3D."
            )

        if self.use_bpe and self.use_lm and self.decoder:
            # pyctcdecode works on numpy arrays, not torch tensors.
            return self._lm_decode(logits.detach().cpu().numpy())

        predicted_indices = torch.argmax(logits, dim=-1).tolist()
        if self.use_bpe:
            return self.decode(predicted_indices)
        return self.decode_simple(predicted_indices)

    def _initialize_language_model(self):
        """
        Initialize KenLM + pyctcdecode for an external LM if needed.

        Leaves ``self.lm`` / ``self.decoder`` as None when no model file is
        found; on any initialization error a warning is printed and
        ``self.decoder`` stays None (best-effort: the encoder remains usable
        with greedy decoding).
        """
        self.lm = None
        self.decoder = None

        model_path = self.binary_path if self.binary_path else self.arpa_path
        print("model_path:", model_path)
        if not model_path or not os.path.exists(model_path):
            print("No language model path or file does not exist.")
            return

        try:
            self.lm = kenlm.Model(model_path)
            print("Loaded KenLM language model.")

            decoder_config = {
                "labels": self.labels if self.use_bpe else self.vocab,
                "kenlm_model_path": model_path,
                "alpha": self.lm_weight,
                "beta": 0.1,
                "unk_score_offset": -10.0,
            }

            if self.unigrams:
                print("\n--- Unigrams List ---")
                # Debug dump of the unigrams into the current working directory.
                with open("unigrams_list.txt", "w", encoding="utf-8") as f:
                    f.writelines(f"{unigram}\n" for unigram in self.unigrams)
                print(f"Unigrams list saved to 'unigrams_list.txt'. Count={len(self.unigrams)}")
                decoder_config["unigrams"] = self.unigrams

            self.decoder = build_ctcdecoder(**decoder_config)
            print("Successfully initialized pyctcdecode with LM.")
        except Exception as e:
            print(f"Warning: Could not init LM: {str(e)}")
            self.decoder = None

    def score_with_lm(self, text: str) -> float:
        """
        Score text with the KenLM model (with sentence begin/end markers).

        Returns 0.0 when no LM is loaded and -inf for blank text.
        """
        if self.lm is None:
            return 0.0
        cleaned = text.lower().strip()
        if not cleaned:
            return float("-inf")
        return self.lm.score(cleaned, bos=True, eos=True)

    @staticmethod
    def normalize_text(text: str) -> str:
        """Used only in char-based mode: lowercase and keep only [a-z ]."""
        return re.sub(r"[^a-z ]", "", text.lower())

    def test_decoder(self, sample_text: str = "test decoder functionality"):
        """Print a smoke test: encode/decode a sample, then decode random logits."""
        print("\nTesting decoder configuration...")

        encoded = self.encode(sample_text)
        print(f"Encoded => {encoded}")
        decoded = self.decode_simple(encoded[0].tolist())
        print(f"Naive decode => {decoded}")

        # Test random logits
        seq_len = 10
        vocab_size = len(self)
        fake_logits = torch.randn(1, seq_len, vocab_size)

        if self.decoder:
            print("\nTesting pyctcdecode with random logits...")
            text_lm = self.ctc_decode(fake_logits)
            print(f"Decoded with LM => {text_lm}")
        else:
            print("\nNo LM loaded => using basic greedy decode.")
            pred = self.decode_logits(fake_logits)
            print(f"Greedy decoded => {pred}")

    def __len__(self):
        """Return the size of the vocabulary."""
        return len(self.vocab)


In [20]:
import re
from collections import defaultdict
import torch
import kenlm
from transformers import Wav2Vec2Processor, AutoProcessor, AutoTokenizer
from pyctcdecode import build_ctcdecoder
import numpy as np
import os

from typing import List, Tuple, Optional, Union

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CTCTextEncoder:
    """
    Text <-> token-index codec for CTC acoustic models.

    ``use_bpe=False`` gives a character vocabulary (blank token, ``a``-``z``,
    space); ``use_bpe=True`` loads a SentencePiece model from the path given
    as ``pretrained_tokenizer``. With ``use_lm=True`` a KenLM model is wrapped
    in a pyctcdecode beam-search decoder (``self.decoder``); otherwise all
    decoding is greedy.
    """

    def __init__(
        self,
        arpa_path: Optional[str] = None,
        binary_path: Optional[str] = "4-gram_lc_correct.bin",
        unigram_path: Optional[str] = "librispeech-vocab.txt",
        pretrained_tokenizer: str = "facebook/wav2vec2-base-960h",
        lm_weight: float = 0.5,
        beam_size: int = 100,
        use_lm: bool = False,
        use_bpe: bool = False,
        blank_token: str = "[PAD]",
        unk_token: str = "[UNK]",
        **kwargs
    ):
        """
        Initialize encoder with conditional subword (SentencePiece) or char-based vocab.

        Parameters:
        - arpa_path: ARPA language-model file, used only if binary_path is falsy.
        - binary_path: KenLM binary language-model file (preferred over arpa_path).
        - unigram_path: text file with one unigram per line; skipped if missing.
        - pretrained_tokenizer: SentencePiece model path (only read when use_bpe=True).
        - lm_weight: passed to pyctcdecode as ``alpha``.
        - beam_size: beam width used for every LM beam-search decode.
        - use_lm: build the KenLM/pyctcdecode decoder.
        - use_bpe: use SentencePiece subwords instead of characters.
        - blank_token: CTC blank symbol (index 0 of the character vocabulary).
        - unk_token: unknown-token symbol (not part of the character vocabulary).
        - **kwargs: accepted and ignored.
        """
        self.beam_size = beam_size
        self.lm_weight = lm_weight
        self.arpa_path = arpa_path
        self.binary_path = binary_path
        self.blank_token = blank_token
        self.unk_token = unk_token
        self.use_lm = use_lm
        self.use_bpe = use_bpe
        self.printed_samples = 0
        self.max_printed_samples = 5
        self.pretrained_tokenizer = pretrained_tokenizer

        print("CTC Text Encoder:")
        print("pretrained_tokenizer:", pretrained_tokenizer)
        print("lm_weight:", lm_weight)
        print("beam_size:", beam_size)
        print("binary_path:", binary_path)
        print("use_lm:", self.use_lm)
        print("use_bpe:", self.use_bpe)

        # Load unigrams if provided
        self.unigrams = None
        if unigram_path and os.path.exists(unigram_path):
            print(f"Loading unigrams from: {unigram_path}")
            with open(unigram_path, 'r', encoding='utf-8') as f:
                self.unigrams = [line.strip().lower() for line in f if line.strip()]
            print(f"Loaded {len(self.unigrams)} unigrams")

        # Initialize vocabulary/tokenizer
        self._initialize_vocabulary(pretrained_tokenizer)

        # Create index mappings
        self.ind2char = dict(enumerate(self.vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        # None when the vocabulary has no blank entry (possible in BPE mode).
        self.blank_index = self.char2ind.get(self.blank_token, None)

        print("\nVocabulary Info:")
        print(f"Size: {len(self.vocab)}")
        print("Full Vocabulary (up to first 50 tokens):", self.vocab[:50])
        print(f"Blank token: {self.blank_token}, Blank index: {self.blank_index}")
        print("Sample ind2char mappings:", dict(list(self.ind2char.items())[:10]))
        print("Sample char2ind mappings:", dict(list(self.char2ind.items())[:10]))

        # Conditionally load LM; both attributes always exist afterwards.
        self.lm = None
        self.decoder = None
        if self.use_lm:
            self._initialize_language_model()
        else:
            print("Language model usage is disabled.")

    def _initialize_vocabulary(self, pretrained_tokenizer: str):
        """
        If use_bpe=True, load a SentencePiece model from `pretrained_tokenizer` path.
        Otherwise, use a basic character-based vocab.

        Sets ``self.vocab``, ``self.labels`` and ``self.sp``.

        Raises:
        - FileNotFoundError: in BPE mode when the model file does not exist.
        """
        if not self.use_bpe:
            print("Initializing character-based vocabulary without tokenizer.")
            # Blank goes first so that it gets index 0.
            self.vocab = [self.blank_token] + list("abcdefghijklmnopqrstuvwxyz ")
            self.labels = self.vocab
            self.sp = None
            return

        print("Initializing subword tokenizer using SentencePiece.")
        # Imported lazily so character mode works without sentencepiece installed.
        import sentencepiece as spm

        if not os.path.exists(pretrained_tokenizer):
            raise FileNotFoundError(
                f"SentencePiece model not found at: {pretrained_tokenizer}"
            )

        # Load the SentencePiece model
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(pretrained_tokenizer)

        vocab_size = self.sp.get_piece_size()
        print(f"Loaded SentencePiece model with vocab size={vocab_size}")

        # The piece strings double as decoder labels and as the vocabulary, so
        # len(self.vocab) == sp.get_piece_size(). blank_token / unk_token are
        # not inserted: the SentencePiece model defines its own special ids.
        self.labels = [self.sp.id_to_piece(i) for i in range(vocab_size)]
        self.vocab = self.labels

    def encode(self, text: str) -> torch.Tensor:
        """
        Encode text: if use_bpe => SentencePiece; else => char-based.
        Returns a 2D long tensor [1, seq_len].
        """
        if self.use_bpe:
            # SentencePiece encode, wrapped in a batch dimension
            token_ids = self.sp.encode(text, out_type=int)
            return torch.tensor([token_ids], dtype=torch.long)

        # Character-based. The unknown token is not in the character
        # vocabulary, so a character without an index is dropped rather than
        # mapped to None.
        unk_index = self.char2ind.get(self.unk_token)
        token_indices = []
        for char in self.normalize_text(text):
            index = self.char2ind.get(char, unk_index)
            if index is not None:
                token_indices.append(index)
        # Explicit dtype keeps an empty encoding integer-typed too.
        return torch.tensor(token_indices, dtype=torch.long).unsqueeze(0)

    def _collapse_ctc(self, indices: List[int]) -> List[int]:
        """Apply the CTC rule: merge consecutive repeats, then drop blanks."""
        collapsed = []
        previous = None
        for idx in indices:
            if idx != self.blank_index and idx != previous:
                collapsed.append(idx)
            previous = idx
        return collapsed

    def _lm_decode(self, logits: np.ndarray) -> str:
        """Beam-search a [T, V] numpy matrix with the configured beam size."""
        return self.decoder.decode(logits, beam_width=self.beam_size)

    def decode_simple(self, indices: List[int]) -> str:
        """
        Simple CTC-like decoding: collapses repeats & removes blank.

        In BPE mode the collapsed ids are decoded by SentencePiece and any
        remaining ``▁`` word-boundary markers become spaces. In character
        mode the symbols are joined, stripped and lowercased; out-of-range
        indices are ignored.

        Note that repeated symbols not separated by a blank are merged, so
        this is not an inverse of ``encode`` for text with double letters.
        """
        collapsed = self._collapse_ctc(indices)
        if self.use_bpe and self.sp is not None:
            text = self.sp.decode([int(idx) for idx in collapsed])
            return text.replace("▁", " ").strip()

        vocab_size = len(self.ind2char)
        symbols = [self.ind2char[idx] for idx in collapsed if 0 <= idx < vocab_size]
        return "".join(symbols).strip().lower()

    def decode(self, indices: List[int]) -> str:
        """
        Decode a sequence of token indices to lowercased text.

        The language-model decoder is deliberately not used here: pyctcdecode
        runs beam search over a [T, V] logit matrix and cannot consume a list
        of token ids. Use ``decode_logits`` for LM decoding.
        """
        return self.decode_simple(indices).lower()

    def decode_logits(self, logits: Union[torch.Tensor, List[List[float]], np.ndarray]) -> str:
        """
        Decode from raw logits. If LM is present => use it; else greedy.

        Accepts [T, V] or [B, T, V] input; only the first batch item is decoded.

        Raises:
        - TypeError: for unsupported input types.
        - ValueError: if the input is not 2D after removing the batch axis.
        """
        if isinstance(logits, torch.Tensor):
            # detach() so tensors that still track gradients can be converted.
            logits = logits.detach().cpu().numpy()
        elif isinstance(logits, list):
            logits = np.array(logits)
        elif not isinstance(logits, np.ndarray):
            raise TypeError("logits must be torch.Tensor, list of lists, or np.ndarray")

        if logits.ndim == 3:
            logits = logits[0]

        if logits.ndim != 2:
            raise ValueError(f"Logits should be 2D, got shape {logits.shape}")

        if self.decoder:
            return self._lm_decode(logits)
        # Greedy decode
        predicted_indices = np.argmax(logits, axis=-1).tolist()
        return self.decode_simple(predicted_indices)

    def decode_indices(self, indices: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        Decode token indices to text with naive approach (no LM).

        Tensors and arrays are flattened, so a one-element input still yields
        a list rather than a bare scalar.

        Raises:
        - TypeError: if indices is not a list, torch.Tensor or np.ndarray.
        """
        if isinstance(indices, (torch.Tensor, np.ndarray)):
            indices = indices.reshape(-1).tolist()
        elif not isinstance(indices, list):
            raise TypeError("decode_indices expects list/torch.Tensor/np.ndarray.")

        return self.decode_simple(indices)

    def ctc_decode(self, logits: Union[torch.Tensor, List[int], np.ndarray]) -> str:
        """
        CTC decode from logits or indices.

        - 1D input: token ids, decoded with ``decode_indices``.
        - 2D / 3D input: logits ([T, V] or [B, T, V], first batch item only).
          BPE mode uses LM beam search when a decoder is loaded, otherwise a
          greedy lowercased decode. Character mode is always greedy here; use
          ``decode_logits`` for LM beam search.

        Raises:
        - ValueError: if the input has any other rank.
        """
        if isinstance(logits, np.ndarray):
            logits = torch.from_numpy(logits)
        elif isinstance(logits, list):
            logits = torch.tensor(logits)

        if logits.dim() == 1:
            return self.decode_indices(logits)
        if logits.dim() == 3:
            logits = logits[0]  # [1, T, V] -> [T, V]
        if logits.dim() != 2:
            raise ValueError(
                f"Unsupported logits shape: {tuple(logits.shape)}. Expected 1D, 2D, or 3D."
            )

        if self.use_bpe and self.use_lm and self.decoder:
            # LM-based decode; pyctcdecode works on numpy arrays, not tensors.
            return self._lm_decode(logits.detach().cpu().numpy())

        predicted_indices = torch.argmax(logits, dim=-1).tolist()
        if self.use_bpe:
            # Greedy subword decode
            return self.decode(predicted_indices)
        # Character-based
        return self.decode_simple(predicted_indices)

    def _initialize_language_model(self):
        """
        Initialize KenLM + pyctcdecode if LM is requested.

        Leaves ``self.lm`` / ``self.decoder`` as None when no model file is
        found; on any initialization error a warning is printed and
        ``self.decoder`` stays None (best-effort: the encoder remains usable
        with greedy decoding).
        """
        self.lm = None
        self.decoder = None

        model_path = self.binary_path if self.binary_path else self.arpa_path
        print("model_path: ", model_path)
        if not model_path or not os.path.exists(model_path):
            print("No language model path provided or file does not exist.")
            return

        try:
            self.lm = kenlm.Model(model_path)
            print(f"Loaded {'binary' if self.binary_path else 'ARPA'} language model.")

            decoder_config = {
                "labels": self.labels if self.use_bpe else self.vocab,
                "kenlm_model_path": model_path,
                "alpha": self.lm_weight,
                "beta": 0.1,
                "unk_score_offset": -10.0,
            }

            if self.unigrams:
                print("\n--- Unigrams List ---")
                # Debug dump of the unigrams into the current working directory.
                with open("unigrams_list.txt", "w", encoding="utf-8") as f:
                    f.writelines(f"{unigram}\n" for unigram in self.unigrams)
                print(f"Unigrams list saved to 'unigrams_list.txt'. Total unigrams: {len(self.unigrams)}")
                print("----------------------\n")
                decoder_config["unigrams"] = self.unigrams

            self.decoder = build_ctcdecoder(**decoder_config)
            print("Successfully initialized language model and decoder.")

        except Exception as e:
            print(f"Warning: Failed to initialize decoder: {str(e)}")
            self.decoder = None

    def test_language_model(self):
        """Debug: print LM scores for a few fixed sentences, if an LM is loaded."""
        print("\nTesting Language Model...")

        if self.lm is None:
            print("Error: Language model is not loaded!")
            return

        test_sentences = (
            "this is a good sentence",
            "this is also a good sentence",
            "thiss iss nott aa goodd sentencee",
        )
        for sentence in test_sentences:
            score = self.score_with_lm(sentence)
            print(f"LM Score for '{sentence}': {score:.4f}")

    def score_with_lm(self, text: str) -> float:
        """
        Score text with LM if present (with sentence begin/end markers).

        Returns 0.0 when no LM is loaded and -inf for empty or blank text.
        """
        if self.lm is None:
            return 0.0
        if not text or not text.strip():
            return float("-inf")
        return self.lm.score(text.lower().strip(), bos=True, eos=True)

    @staticmethod
    def normalize_text(text: str) -> str:
        """Normalize for character-based approach: lowercase, keep only [a-z ]."""
        return re.sub(r"[^a-z ]", "", text.lower())

    def _basic_ctc_decode(self, logits: np.ndarray, sequence_length: int) -> List[str]:
        """
        Basic CTC greedy decode without LM.
        Typically for use_bpe=False path or fallback testing.

        Takes the argmax over the last axis, keeps the first
        ``sequence_length`` steps of each sequence, collapses repeats, drops
        blanks and joins the vocabulary strings (no SentencePiece decoding,
        no stripping or lowercasing).

        Returns one string per sequence.
        """
        # Scalar and 1D argmax results are promoted to a batch of one.
        argmax_indices = np.atleast_2d(np.argmax(logits, axis=-1))

        predictions = []
        for sequence in argmax_indices:
            kept = self._collapse_ctc(sequence[:sequence_length].tolist())
            predictions.append("".join(self.ind2char[idx] for idx in kept))
        return predictions

    def test_decoder(self, sample_text: str = "test decoder functionality"):
        """Test the encode/decode pipeline and print the results."""
        print("\nTesting decoder configuration...")

        encoded = self.encode(sample_text)
        decoded = self.decode(encoded[0].tolist())
        print(f"Original text: {sample_text}")
        print(f"Basic decode: {decoded}")

        seq_length = 50
        vocab_size = len(self)
        fake_logits = torch.randn(1, seq_length, vocab_size)

        if self.decoder is not None:
            print("\nTesting pyctcdecode integration with fake logits...")
            decoded_with_lm = self.ctc_decode(fake_logits)
            print(f"Decoded with LM: {decoded_with_lm}")
        else:
            print("\nNo language model loaded - using basic CTC decoding")
            basic_decoded = self._basic_ctc_decode(fake_logits.numpy(), seq_length)
            print(f"Basic CTC decoded: {basic_decoded[0]}")

    def __len__(self):
        """Number of tokens in the vocab."""
        return len(self.vocab)


In [7]:
%ls

hasbooted  onstart.sh


In [9]:
%cd /workspace/sound_asr/

/workspace/sound_asr


In [21]:
def test_encoder(encoder_cls, text="hello world"):
    """
    Exercise an encoder class over every use_bpe / use_lm combination.

    For each combination an encoder is constructed (with the SentencePiece
    model, KenLM binary and unigram list paths fixed below), `text` is encoded
    and the resulting indices are fed back through `decode_simple`; both
    results are printed.

    Parameters:
    - encoder_cls: Class of the encoder to test.
    - text (str): The input text to encode and decode.
    """
    # Order matters for the printed log: (use_bpe, use_lm).
    flag_pairs = [(False, False), (True, False), (False, True), (True, True)]

    for bpe_flag, lm_flag in flag_pairs:
        print(f"\n--- Testing with use_bpe={bpe_flag} and use_lm={lm_flag} ---")

        # Build the encoder for this flag combination
        encoder = encoder_cls(
            use_bpe=bpe_flag,
            use_lm=lm_flag,
            pretrained_tokenizer="sentencepiece_model/librispeech_unigram_model.model",
            binary_path="4-gram_lc_correct.bin",
            unigram_path="librispeech-vocab.txt"
        )

        # Text -> indices
        token_tensor = encoder.encode(text)
        print(f"Encoded: {token_tensor}")

        # Indices -> text via the naive CTC-style decoder
        round_trip = encoder.decode_simple(token_tensor[0].tolist())
        print(f"Decoded: {round_trip}")

test_encoder(CTCTextEncoder)



--- Testing with use_bpe=False and use_lm=False ---
CTC Text Encoder:
pretrained_tokenizer: sentencepiece_model/librispeech_unigram_model.model
lm_weight: 0.5
beam_size: 100
binary_path: 4-gram_lc_correct.bin
use_lm: False
use_bpe: False
Loading unigrams from: librispeech-vocab.txt
Loaded 200000 unigrams
Initializing character-based vocabulary without tokenizer.

Vocabulary Info:
Size: 28
Full Vocabulary (up to first 50 tokens): ['[PAD]', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ']
Blank token: [PAD], Blank index: 0
Sample ind2char mappings: {0: '[PAD]', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i'}
Sample char2ind mappings: {'[PAD]': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9}
Language model usage is disabled.
Encoded: tensor([[ 8,  5, 12, 12, 15, 27, 23, 15, 18, 12,  4]])
Decoded: helo world

--- Testing with use_bpe=True and use_lm=Fals