In [None]:
import torch
import collections
from dataclasses import dataclass
from typing import List, Optional, Dict

@dataclass
class CTCHypothesis:
    """Mimics torchaudio.models.decoder.CTCHypothesis"""
    tokens: torch.Tensor
    words: List[str]
    score: float
    timesteps: torch.Tensor

class VectorizedLexicon:
    def __init__(self, lexicon_path: str, tokens: List[str], blank_token="<blank>"):
        self.tokens = tokens
        self.token_to_id = {t: i for i, t in enumerate(tokens)}
        self.blank_id = self.token_to_id.get(blank_token, 0)  # Default to 0 if not found
        
        # 1. Build the Trie on CPU
        self.trie = {'children': {}, 'is_word': False, 'id': 0}
        self.node_count = 1
        
        with open(lexicon_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if not parts: continue
                word = parts[0]
                # Assuming lexicon format: "word c a t" or just "word" if simple
                # We'll assume the word itself is the sequence of chars for simplicity,
                # or that the file provides space-separated tokens.
                spelling = parts[1:] if len(parts) > 1 else list(word)
                
                node = self.trie
                for char in spelling:
                    if char not in self.token_to_id:
                        continue # Skip unknown tokens
                    token_id = self.token_to_id[char]
                    if token_id not in node['children']:
                        node['children'][token_id] = {'children': {}, 'is_word': False, 'id': self.node_count}
                        self.node_count += 1
                    node = node['children'][token_id]
                node['is_word'] = True

        # 2. Flatten Trie to GPU Tensor: [Num_Nodes, Vocab_Size]
        # transitions[i, k] = index of next node if we are at node i and see token k
        # We initialize with -1 (invalid transition)
        self.vocab_size = len(tokens)
        self.transitions = torch.full((self.node_count, self.vocab_size), -1, dtype=torch.long)
        self.is_word_node = torch.zeros(self.node_count, dtype=torch.bool)
        
        # Queue for BFS traversal to populate tensor
        queue = [self.trie]
        while queue:
            node = queue.pop(0)
            u = node['id']
            if node['is_word']:
                self.is_word_node[u] = True
            
            for token_id, child_node in node['children'].items():
                v = child_node['id']
                self.transitions[u, token_id] = v
                queue.append(child_node)
                
        # Self-loops for blank token (stay at same node)
        # NOTE: Standard CTC logic handles blanks separately, but for the generic 
        # transition matrix, we often mark blank as "stay". 
        # Here we will handle blanks in the decoder logic explicitly to adhere to CTC rules.

    def to(self, device):
        self.transitions = self.transitions.to(device)
        self.is_word_node = self.is_word_node.to(device)
        return self


class GPUCTCDecoder:
    def __init__(self, 
                 lexicon_path: str, 
                 tokens: List[str], 
                 beam_size: int = 10, 
                 blank_token: str = "-"):
        
        self.beam_size = beam_size
        self.tokens = tokens
        self.blank_id = tokens.index(blank_token) if blank_token in tokens else 0
        
        # Load and vectorized lexicon
        self.lexicon = VectorizedLexicon(lexicon_path, tokens, blank_token)
        
    def __call__(self, emissions: torch.Tensor) -> List[List[CTCHypothesis]]:
        """
        Args:
            emissions: (Batch, Time, Vocab) - Logits from acoustic model
        Returns:
            List of lists of hypotheses (Batch -> N-Best)
        """
        device = emissions.device
        if self.lexicon.transitions.device != device:
            self.lexicon.to(device)

        B, T, V = emissions.shape
        
        # Initialize Beams
        # Scores: [B, Beam_Size]
        beam_scores = torch.full((B, self.beam_size), -float('inf'), device=device)
        beam_scores[:, 0] = 0.0  # Start with prob 1.0 (log prob 0.0)
        
        # History Tracking
        # We need to store token history to reconstruct paths later
        # For efficiency in this demo, we'll just store the current state indices
        # In a real optimized version, you'd use a backpointer array.
        
        # Trie Node Indices: [B, Beam_Size]
        # Start at root (0)
        beam_trie_nodes = torch.zeros((B, self.beam_size), dtype=torch.long, device=device)
        
        # Last Token (for CTC repeat collapse): [B, Beam_Size]
        beam_last_tokens = torch.full((B, self.beam_size), self.blank_id, dtype=torch.long, device=device)
        
        # Store paths (inefficient in pure Python, but clear for logic)
        # List of Lists of lists: Batch -> Beam -> Sequence
        beam_paths = [[[] for _ in range(self.beam_size)] for _ in range(B)]

        # --- DECODING LOOP ---
        for t in range(T):
            # 1. Get Logits for this step: [B, V]
            log_probs = torch.nn.functional.log_softmax(emissions[:, t, :], dim=-1)
            
            # 2. Expand Beams: [B, Beam, V]
            # Add current log_probs to existing beam scores
            # Shape: [B, Beam, 1] + [B, 1, V] -> [B, Beam, V]
            next_scores = beam_scores.unsqueeze(-1) + log_probs.unsqueeze(1)
            next_scores = next_scores.view(B, -1) # Flatten to [B, Beam * V]
            
            # --- LEXICON CONSTRAINT IMPLEMENTATION ---
            
            # Logic: We need to know the valid 'next_node' for every possible expansion.
            # Current nodes: beam_trie_nodes [B, Beam]
            # We need to lookup transitions for all V tokens.
            
            # Expand current nodes to match the flattened structure
            # [B, Beam] -> [B, Beam, 1] -> expand -> flatten -> [B, Beam*V]
            current_nodes_expanded = beam_trie_nodes.unsqueeze(-1).expand(-1, -1, V).reshape(B, -1)
            
            # Get potential tokens for every candidate
            # [1, 1, V] -> [B, Beam, V] -> flatten
            candidate_tokens = torch.arange(V, device=device).reshape(1, 1, -1).expand(B, self.beam_size, -1).reshape(B, -1)
            
            # Calculate Next Node using Vectorized Lookup
            # This is the "Step" logic for CTC:
            # - If token is BLANK: Stay at current node
            # - If token == Last Token: Stay at current node (Merge repeats)
            # - If token is NEW: Look up transition table
            
            last_tokens_expanded = beam_last_tokens.unsqueeze(-1).expand(-1, -1, V).reshape(B, -1)
            
            # Default: Move to next node in Trie
            # transitions[current_node, token]
            # We use gather or advanced indexing. 
            # tensor[idx_batch, idx_vocab]
            # Since we have a batch of current nodes, we flatten everything for lookup
            
            # Lookup: [Total_Candidates]
            next_trie_nodes_lookup = self.lexicon.transitions[current_nodes_expanded, candidate_tokens]
            
            is_space = (candidate_tokens == self.space_id)
            
            # Check if the current node (before moving) is a valid end of a word
            # We expand: [B, Beam] -> [B, Beam * V] to match flattened list
            current_is_word = self.lexicon.is_word_node[current_nodes_expanded]
            
            # Logic: 
            # 1. If Token is Space AND Current is Word -> Go to Root (0)
            # 2. If Token is Space AND Current is NOT Word -> Invalid (-1)
            # 3. If Token is NOT Space -> Keep the standard lookup result
            
            # We build a "Space Result" tensor first
            space_target = torch.where(current_is_word, 
                                       torch.tensor(0, device=device),  # Valid Space -> Root
                                       torch.tensor(-1, device=device)) # Invalid Space -> Die
            
            # Apply it to our lookup
            next_trie_nodes_lookup = torch.where(is_space, space_target, next_trie_nodes_lookup)
            
            # Apply CTC Logic
            is_blank = (candidate_tokens == self.blank_id)
            is_repeat = (candidate_tokens == last_tokens_expanded)
            
            # Logic:
            # If blank OR repeat: Next Node = Current Node
            # Else: Next Node = Lookup Result
            
            final_next_nodes = torch.where(
                is_blank | is_repeat,
                current_nodes_expanded,
                next_trie_nodes_lookup
            )
            
            
            # MASKING:
            # If final_next_nodes is -1, it means we tried to step off the Trie.
            # Set score to -inf
            valid_mask = (final_next_nodes != -1)
            next_scores = torch.where(valid_mask, next_scores, torch.tensor(-float('inf'), device=device))

            # ------------------------------------------
            # [[ SLOT FOR YOUR LLM SCORING HERE ]]
            # You have access to `next_scores` (Batch, Beam*V) 
            # and `final_next_nodes` before we top-k.
            # ------------------------------------------

            # 3. Prune (Top-K)
            # Take top beam_size scores
            top_scores, top_indices = torch.topk(next_scores, self.beam_size, dim=1)
            
            # 4. Update State
            beam_scores = top_scores
            
            # Recover which beam and which token generated the top scores
            # index = beam_idx * V + token_idx
            prev_beam_indices = top_indices // V
            new_token_indices = top_indices % V
            
            # Update Trie Nodes
            # We need to gather the node IDs from the `final_next_nodes` array using the top_indices
            beam_trie_nodes = torch.gather(final_next_nodes, 1, top_indices)
            
            # Update Last Tokens
            beam_last_tokens = new_token_indices

            # Update Paths (CPU-side bookkeeping for output)
            # This part is hard to fully vectorize without heavy memory, so usually done on CPU or
            # using a parent-pointer tensor. For simplicity/clarity:
            new_paths = []
            for b in range(B):
                batch_paths = []
                for k in range(self.beam_size):
                    prev_k = prev_beam_indices[b, k].item()
                    token = new_token_indices[b, k].item()
                    
                    # Copy previous path
                    current_path = list(beam_paths[b][prev_k])
                    current_path.append(token)
                    batch_paths.append(current_path)
                new_paths.append(batch_paths)
            beam_paths = new_paths
            
        # --- Finalize Results ---
        results = []
        for b in range(B):
            hyps = []
            for k in range(self.beam_size):
                # Filter blanks and repeats to get final "words"
                # Note: In this specific logic, we only tracked Trie nodes.
                # To get actual words, we'd usually process the raw token path.
                raw_tokens = beam_paths[b][k]
                
                # CTC Collapse for output string
                collapsed_indices = []
                prev = -1
                for t_id in raw_tokens:
                    if t_id != self.blank_id and t_id != prev:
                        collapsed_indices.append(t_id)
                    prev = t_id
                
                # Simple word reconstruction (assuming chars)
                words_str = "".join([self.tokens[i] for i in collapsed_indices])
                
                hyps.append(CTCHypothesis(
                    tokens=torch.tensor(collapsed_indices),
                    words=[words_str], # simplified
                    score=beam_scores[b, k].item(),
                    timesteps=torch.tensor([]) # Placeholder
                ))
            results.append(hyps)
            
        return results

# --- USAGE EXAMPLE ---

# 1. Create a dummy lexicon file
with open("lexicon.txt", "w") as f:
    f.write("CAT c a t\n")
    f.write("DOG d o g\n")
    f.write("BAT b a t\n")

# 2. Setup
tokens = ["-", "c", "a", "t", "d", "o", "g", "b"] # - is blank
decoder = GPUCTCDecoder("lexicon.txt", tokens, beam_size=2)

# 3. Fake Emissions [Batch=1, Time=5, Vocab=8]
# Let's make "d o g" likely
emissions = torch.randn(1, 5, 8).cuda() 
# ... (Set high logits for d, o, g to test) ...

# 4. Run
results = decoder(emissions)
print(results[0][0].words)

['do']


In [12]:
{}[0]

KeyError: 0

In [None]:
def debug_print_trie(node, token_map_inv, indent=0):
    """
    Recursive function to print the Trie structure.
    node: The current dictionary node
    token_map_inv: Dict mapping IDs back to characters {0: 'a', 1: 'b', ...}
    """
    # Visual spacing based on depth
    space = "  " * indent
    
    # Check if this node marks the end of a word
    word_marker = " [WORD]" if node['is_word'] else ""
    
    # Print current node ID
    print(f"{space}Node ID: {node['id']}{word_marker}")
    
    # Recursively print children
    for token_id, child_node in node['children'].items():
        char = token_map_inv.get(token_id, f"id_{token_id}")
        print(f"{space}  --({char})-->")
        debug_print_trie(child_node, token_map_inv, indent + 1)

# --- usage with the previous example ---
# Assuming you have the 'decoder' object from the previous code block:

# 1. Create a dummy lexicon file
with open("lexicon.txt", "w") as f:
    f.write("CAT c a t\n")
    f.write("DOG d o g\n")
    f.write("BAT b a t\n")
    f.write("CAR c a r\n")
    f.write("BAR b a r\n")
    
    

# 2. Setup
tokens = ["-", "c", "a", "t", "d", "o", "g", "b", "r"] # - is blank
decoder = GPUCTCDecoder("lexicon.txt", tokens, beam_size=2)

# Create an inverse map to see characters instead of IDs
id_to_token = {v: k for k, v in decoder.lexicon.token_to_id.items()}

print("--- TRIE STRUCTURE ---")
debug_print_trie(decoder.lexicon.trie, id_to_token)

--- TRIE STRUCTURE ---
Node ID: 0
  --(c)-->
  Node ID: 1
    --(a)-->
    Node ID: 2
      --(t)-->
      Node ID: 3 [WORD]
      --(r)-->
      Node ID: 10 [WORD]
  --(d)-->
  Node ID: 4
    --(o)-->
    Node ID: 5
      --(g)-->
      Node ID: 6 [WORD]
  --(b)-->
  Node ID: 7
    --(a)-->
    Node ID: 8
      --(t)-->
      Node ID: 9 [WORD]
      --(r)-->
      Node ID: 11 [WORD]


In [13]:
l = {}

l[0] = "key"

print(l)

{0: 'key'}
