# Código generado por Gemini o Copilot


## Creación de dataset y dataloader

In [1]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

class SpeechJsonDataset(Dataset):
    """
    A PyTorch Dataset for reading speech data from a JSON file.

    The JSON file is expected to have a top-level dictionary where keys are
    sample IDs and values are dictionaries containing speech data attributes
    like 'wav', 'duration', 'spk_id', 'phn', 'wrd', and 'ground_truth_phn_ends'.
    'phn', 'wrd' and 'ground_truth_phn_ends' are whitespace-separated strings.

    Vocabularies are built from the file itself:
      * ``phn_vocab`` / ``wrd_vocab``: "<PAD>" -> 0, "<UNK>" -> 1, and every other
        token gets the next free index in order of first appearance.
      * ``spk_vocab``: speaker IDs get consecutive indices starting at 0 in order
        of first appearance; missing or empty 'spk_id' values are not added.
    """

    def __init__(self, json_file_path):
        """
        Load the JSON file and build the phoneme, word and speaker vocabularies.

        Args:
            json_file_path (str): The path to the JSON dataset file.
        """
        # Decode explicitly as UTF-8 (the JSON standard encoding) instead of
        # relying on the platform's locale default.
        with open(json_file_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # Sample IDs in file order; __getitem__ indexes into this list.
        self.sample_ids = list(self.data.keys())

        # Indices 0 and 1 are reserved for padding and unknown tokens.
        self.phn_vocab = {"<PAD>": 0, "<UNK>": 1}
        self.wrd_vocab = {"<PAD>": 0, "<UNK>": 1}
        self.spk_vocab = {}

        self._build_vocabularies()

    def _build_vocabularies(self):
        """
        Assign numerical indices to every phoneme, word and speaker ID in the data.

        A new entry always receives ``len(vocab)`` as its index, i.e. the next
        free integer after the reserved/previously seen entries.
        """
        for sample_id in self.sample_ids:
            entry = self.data[sample_id]

            for phn in entry.get('phn', '').split():
                self.phn_vocab.setdefault(phn, len(self.phn_vocab))

            for wrd in entry.get('wrd', '').split():
                self.wrd_vocab.setdefault(wrd, len(self.wrd_vocab))

            spk_id = entry.get('spk_id')
            if spk_id:  # skip missing / empty speaker IDs
                self.spk_vocab.setdefault(spk_id, len(self.spk_vocab))

        print(f"Phoneme Vocabulary Size: {len(self.phn_vocab)}")
        print(f"Word Vocabulary Size: {len(self.wrd_vocab)}")
        print(f"Speaker Vocabulary Size: {len(self.spk_vocab)}")

    @staticmethod
    def _encode(text, vocab):
        """Map a whitespace-separated string to a LongTensor of vocab indices (unknown -> "<UNK>")."""
        unk_index = vocab["<UNK>"]
        return torch.tensor([vocab.get(token, unk_index) for token in text.split()], dtype=torch.long)

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """
        Retrieves a single sample from the dataset at the given index.

        Sequences are returned unpadded; padding is left to the collate function.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            dict: Keys 'sample_id' (str), 'wav_path' (str), 'duration' (0-d float32
                  tensor, 0.0 if missing), 'speaker_id' (0-d long tensor, -1 if the
                  speaker is unknown/missing), 'phonemes' and 'words' (1-D long
                  tensors of vocab indices), 'phoneme_ends' (1-D float32 tensor).
        """
        sample_id = self.sample_ids[idx]
        entry = self.data[sample_id]

        speaker_id = self.spk_vocab.get(entry.get('spk_id', ''), -1)  # -1 for unknown speaker

        # str.split() never yields empty strings, so every field is a number.
        phoneme_ends = [float(end) for end in entry.get('ground_truth_phn_ends', '').split()]

        return {
            'sample_id': sample_id,
            'wav_path': entry.get('wav', ''),
            'duration': torch.tensor(entry.get('duration', 0.0), dtype=torch.float32),
            'speaker_id': torch.tensor(speaker_id, dtype=torch.long),
            'phonemes': self._encode(entry.get('phn', ''), self.phn_vocab),
            'words': self._encode(entry.get('wrd', ''), self.wrd_vocab),
            'phoneme_ends': torch.tensor(phoneme_ends, dtype=torch.float32),
        }

def custom_collate_fn(batch):
    """
    Collate a list of dataset samples into one padded batch.

    Variable-length fields ('phonemes', 'words', 'phoneme_ends') are
    right-padded to the longest sequence in the batch: token fields with 0
    (the "<PAD>" index of SpeechJsonDataset's vocabularies) and 'phoneme_ends'
    with 0.0. Each field's original lengths are returned so downstream code
    can mask the padding.

    Args:
        batch (list): A list of sample dicts as returned by
                      SpeechJsonDataset.__getitem__.

    Returns:
        dict: 'sample_id' and 'wav_path' as Python lists; 'duration' and
              'speaker_id' stacked into 1-D tensors; 'phonemes', 'words' and
              'phoneme_ends' as (batch_size, max_len) tensors; and
              'phoneme_lengths', 'word_lengths', 'phoneme_ends_lengths' as
              1-D long tensors of the unpadded lengths.

    Raises:
        ValueError: If ``batch`` is empty (there is no maximum length to pad to).
    """
    from torch.nn.utils.rnn import pad_sequence

    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    def padded(key, pad_value):
        # Right-pad every sample's 1-D tensor to the batch maximum.
        return pad_sequence([item[key] for item in batch], batch_first=True, padding_value=pad_value)

    def lengths(key):
        return torch.tensor([len(item[key]) for item in batch], dtype=torch.long)

    return {
        'sample_id': [item['sample_id'] for item in batch],  # List of strings, not a tensor
        'wav_path': [item['wav_path'] for item in batch],    # List of strings, not a tensor
        'duration': torch.stack([item['duration'] for item in batch]),
        'speaker_id': torch.stack([item['speaker_id'] for item in batch]),
        'phonemes': padded('phonemes', 0),        # 0 is the <PAD> index
        'words': padded('words', 0),              # 0 is the <PAD> index
        'phoneme_ends': padded('phoneme_ends', 0.0),
        'phoneme_lengths': lengths('phonemes'),
        'word_lengths': lengths('words'),
        'phoneme_ends_lengths': lengths('phoneme_ends'),
    }


if __name__ == '__main__':
    # Demo driver: make sure a small TIMIT-style p.json exists, then show the
    # first padded batch produced by the DataLoader.
    json_file_path = 'p.json'

    # Three sample utterances; written only when p.json does not exist yet.
    demo_json_text = """
{
  "mrws1_sx320": {
    "wav": "/dbase/timit/test/dr5/mrws1/sx320.wav",
    "duration": 3.28325,
    "spk_id": "mrws1",
    "phn": "sil dh ih n ih r ih s ih n ih sil g aa sil m ey n aa sil b iy w ih th ih n w aa sil k ih ng sil d ih s sil t ih n sil s sil",
    "wrd": "the nearest synagogue may not be within walking distance",
    "ground_truth_phn_ends": "2360 2840 3216 4511 5556 7018 7880 10440 11160 12040 13160 13960 14200 17640 18280 19160 20360 21560 23800 25320 25720 26520 27800 28825 30440 31248 32208 34130 35880 36760 37640 37960 39101 40120 40360 41640 43000 43320 44200 44440 45280 46680 49560 52480"
  },
  "mrws1_sx230": {
    "wav": "/dbase/timit/test/dr5/mrws1/sx230.wav",
    "duration": 3.2064375,
    "spk_id": "mrws1",
    "phn": "sil ah l aw l iy w ey hh iy er sil b ah r ae sh sil n l ay z aa l eh r er z sil",
    "wrd": "allow leeway here but rationalize all errors",
    "ground_truth_phn_ends": "3000 3592 4600 8605 10402 12360 13618 16280 17169 18757 20469 23000 23443 24520 26402 28600 30160 30600 31400 32360 35800 36360 38809 40040 41650 42945 46120 48600 51280"
  },
  "mrws1_sx50": {
    "wav": "/dbase/timit/test/dr5/mrws1/sx50.wav",
    "duration": 3.232,
    "spk_id": "mrws1",
    "phn": "sil k ae dx ih s sil t r aa f ih sil k iy sil k ih n aa m ih sil k ah sil b ae sil k s sil n ih sil g l eh sil dh ah sil p aa r sil",
    "wrd": "catastrophic economic cutbacks neglect the poor",
    "ground_truth_phn_ends": "2040 3080 4428 4720 5880 7480 7880 8600 9293 10680 12040 13000 13953 14680 15640 16680 17348 18200 18966 20762 21840 22520 23800 25080 26318 28880 29240 31160 32279 32584 33520 34360 34920 35926 37000 37472 38200 39644 42234 42760 43240 44840 45720 48333 49640 51680"
  }
}
"""
    try:
        # Mode 'x' creates the file and fails if it already exists, so an
        # existing p.json is never overwritten.
        with open(json_file_path, 'x') as out_file:
            out_file.write(demo_json_text)
    except FileExistsError:
        print("p.json already exists. Skipping dummy file creation.")

    # Dataset over the JSON file, batched with padding by custom_collate_fn.
    dataset = SpeechJsonDataset(json_file_path)
    batch_size = 2
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

    print(f"\n--- Iterating through DataLoader with batch_size={batch_size} ---")
    for i, batch in enumerate(dataloader):
        print(f"\nBatch {i+1}:")
        print(f"  Sample IDs: {batch['sample_id']}")
        print(f"  WAV Paths (first 2): {batch['wav_path'][:2]}...")
        print(f"  Durations:\n{batch['duration']}")
        print(f"  Speaker IDs:\n{batch['speaker_id']}")
        print(f"  Phonemes (numerical, padded):\n{batch['phonemes']}")
        print(f"  Phoneme Lengths (original):\n{batch['phoneme_lengths']}")
        print(f"  Words (numerical, padded):\n{batch['words']}")
        print(f"  Word Lengths (original):\n{batch['word_lengths']}")
        print(f"  Phoneme Ends (padded):\n{batch['phoneme_ends']}")
        print(f"  Phoneme Ends Lengths (original):\n{batch['phoneme_ends_lengths']}")

        # In a model these fields would be consumed e.g. as
        #   model_output = your_model(batch['phonemes'], batch['phoneme_lengths'])
        #   loss = criterion(model_output, batch['speaker_id'])

        break  # Only the first batch is shown, for brevity.

p.json already exists. Skipping dummy file creation.
Phoneme Vocabulary Size: 33
Word Vocabulary Size: 23
Speaker Vocabulary Size: 1

--- Iterating through DataLoader with batch_size=2 ---

Batch 1:
  Sample IDs: ['mrws1_sx50', 'mrws1_sx320']
  WAV Paths (first 2): ['/dbase/timit/test/dr5/mrws1/sx50.wav', '/dbase/timit/test/dr5/mrws1/sx320.wav']...
  Durations:
tensor([3.2320, 3.2833])
  Speaker IDs:
tensor([0, 0])
  Phonemes (numerical, padded):
tensor([[ 2, 16, 25, 30,  4,  7,  2, 19,  6,  9, 31,  4,  2, 16, 13,  2, 16,  4,
          5,  9, 10,  4,  2, 16, 20,  2, 12, 25,  2, 16,  7,  2,  5,  4,  2,  8,
         21, 29,  2,  3, 20,  2, 32,  9,  6,  2],
        [ 2,  3,  4,  5,  4,  6,  4,  7,  4,  5,  4,  2,  8,  9,  2, 10, 11,  5,
          9,  2, 12, 13, 14,  4, 15,  4,  5, 14,  9,  2, 16,  4, 17,  2, 18,  4,
          7,  2, 19,  4,  5,  2,  7,  2,  0,  0]])
  Phoneme Lengths (original):
tensor([46, 44])
  Words (numerical, padded):
tensor([[18, 19, 20, 21,  2, 22,  0,  0,  0],
  

## Creación de un greedy ctc decoder

In [5]:
import torch

def greedy_ctc_decode(emissions: torch.Tensor, blank_idx: int, tokens=None, sep: str = "") -> list[str]:
    """
    Greedy (best-path) CTC decoding of a batch of per-frame class scores.

    At each frame the highest-scoring class is taken; then runs of identical
    consecutive labels are merged and blank labels are dropped. A blank
    between two equal labels therefore keeps both ("a <blank> a" -> "aa").

    Args:
        emissions: Tensor of shape (seq_len, batch_size, num_classes), e.g.
            log-probabilities. Only the per-frame argmax is used.
        blank_idx: Index of the blank token.
        tokens: Sequence mapping class index -> output symbol. If omitted, a
            module-level ``tokens`` list is used (what earlier versions of this
            function always did). Pass it explicitly so the result does not
            depend on whatever that global currently holds.
        sep: String placed between decoded symbols. The default "" suits
            character vocabularies; use " " for multi-character units such as
            phonemes.

    Returns:
        A list with one decoded string per sequence in the batch.

    Raises:
        ValueError: If ``tokens`` is not given and no module-level ``tokens`` exists.
    """
    from itertools import groupby

    if tokens is None:
        # Backward-compatible fallback to the global vocabulary.
        tokens = globals().get("tokens")
        if tokens is None:
            raise ValueError(
                "greedy_ctc_decode needs `tokens` (class index -> symbol); pass it explicitly."
            )

    # One argmax over the whole batch: (seq_len, batch) -> (batch, seq_len),
    # converted to Python ints so the collapse below compares plain integers.
    best_paths = emissions.argmax(dim=-1).transpose(0, 1).tolist()

    decoded_sequences = []
    for path in best_paths:
        # CTC rule: collapse consecutive repeats first, then remove blanks.
        labels = [label for label, _ in groupby(path) if label != blank_idx]
        decoded_sequences.append(sep.join(tokens[label] for label in labels))
    return decoded_sequences

# Example usage
tokens = ["<blank>", "a", "b", "c", " "] # Your actual vocabulary
blank_idx = tokens.index("<blank>")

# Example emissions (seq_len, batch_size, num_classes).
# Best class per frame: a, <blank>, a, <blank>, b, <blank>, c.
# CTC merges only *consecutive* repeats, and the blank between the two "a"
# frames keeps them separate, so the expected greedy output is "aabc".
emissions_example = torch.zeros(7, 1, len(tokens))
emissions_example[0, 0, tokens.index("a")] = 10
emissions_example[1, 0, blank_idx] = 10
emissions_example[2, 0, tokens.index("a")] = 10
emissions_example[3, 0, blank_idx] = 10
emissions_example[4, 0, tokens.index("b")] = 10
emissions_example[5, 0, blank_idx] = 10
emissions_example[6, 0, tokens.index("c")] = 10

# Add small Gaussian noise (std 0.1) to every class so the scores look less
# artificial; against a margin of 10 it does not change the per-frame argmax
# in practice.
emissions_example += torch.randn_like(emissions_example) * 0.1

decoded_greedy = greedy_ctc_decode(emissions_example, blank_idx)
print(f"Greedy decoded: {decoded_greedy}")

Greedy decoded: ['aabc']


## Lectura del archivo de texto con la codificación

In [None]:
import torch
from torchtext.vocab import Vocab
from collections import OrderedDict

def load_vocab_from_file(filepath: str) -> Vocab:
    """
    Load a torchtext Vocab from a text file of ``<token> <index>`` lines.

    The index is the text after the LAST space on a line and the token is
    everything before it. Only the line terminator is stripped, so a space
    character can be listed: a line made of one space followed by the index
    (``" 5"``) is read as the space token. Blank lines are ignored; lines
    without a space or with a non-integer index are reported and skipped.

    A torchtext Vocab numbers tokens by insertion position, so tokens are
    inserted in index order. This reproduces the file's numbering exactly, which
    requires the indices to be unique and to cover 0..N-1 without gaps.

    If a ``<unk>`` token is present, its index becomes the default index, so
    looking up an unknown token returns it.

    Args:
        filepath (str): The path to the UTF-8 vocabulary file.

    Returns:
        torchtext.vocab.Vocab: Vocab whose token -> index mapping matches the file.

    Raises:
        ValueError: If a token or an index appears twice, or the indices are not
            exactly 0..N-1.
    """
    # The `vocab` factory (not the Vocab class constructor) is torchtext's
    # builder from an ordered token -> frequency mapping.
    from torchtext.vocab import vocab as build_vocab

    index_to_token = {}
    seen_tokens = set()
    unk_index = None

    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            # Drop only the line terminator: spaces may belong to the token.
            entry = line.rstrip('\r\n')
            if not entry.strip():
                continue  # blank line
            char, sep, index_str = entry.rpartition(' ')
            if not sep:
                print(f"Warning: Skipping malformed line '{line.strip()}' - expected 'char index'.")
                continue
            try:
                index = int(index_str)
            except ValueError:
                print(f"Warning: Skipping malformed line '{line.strip()}' - index is not an integer.")
                continue
            if char == '':
                # " 5": the only character before the index is the separator
                # itself, which this format uses to denote the space token.
                char = ' '
            if index in index_to_token:
                raise ValueError(
                    f"Index {index} is assigned to both {index_to_token[index]!r} and {char!r} in {filepath!r}"
                )
            if char in seen_tokens:
                raise ValueError(f"Token {char!r} appears more than once in {filepath!r}")
            index_to_token[index] = char
            seen_tokens.add(char)
            if char == "<unk>":  # Assuming your unknown token is named "<unk>"
                unk_index = index

    if sorted(index_to_token) != list(range(len(index_to_token))):
        raise ValueError(
            f"Vocabulary indices in {filepath!r} must be exactly 0..{len(index_to_token) - 1} "
            f"with no gaps; got {sorted(index_to_token)}"
        )

    # Insertion order == file index order; the dummy frequency of 1 passes the
    # factory's min_freq=1 filter.
    ordered_tokens = OrderedDict((index_to_token[i], 1) for i in range(len(index_to_token)))
    vocab = build_vocab(ordered_tokens, min_freq=1)

    if unk_index is not None:
        vocab.set_default_index(unk_index)
    else:
        print("Warning: No '<unk>' token found or its index is invalid. Default index not set.")
        # Without a default index, looking up an unknown token is an error.

    return vocab

# --- Example Usage ---
# 1. Create a dummy vocab file for demonstration.
#    The line " 5" (a single leading space) is meant to give the space
#    character index 5.
vocab_file_path = 'char_vocab.txt'
dummy_vocab_content = """<unk> 0
<pad> 1
a 2
b 3
c 4
 5
! 6
. 7
x 8
y 9
z 10
"""
with open(vocab_file_path, 'w', encoding='utf-8') as f:
    f.write(dummy_vocab_content)

# 2. Load the vocabulary
my_vocab = load_vocab_from_file(vocab_file_path)

print(f"Loaded Vocab size: {len(my_vocab)}")

# Test conversions
print(f"Index of 'a': {my_vocab['a']}")
print(f"Character at index 3: {my_vocab.lookup_token(3)}")
print(f"Index of unknown character 'ñ': {my_vocab['ñ']}") # Should return default index for <unk>

# Imported for reference only; the example below uses its own character tokenizer.
from torchtext.data.utils import get_tokenizer

def char_tokenizer(text: str) -> list[str]:
    """Split a string into single characters, e.g. "hello" -> ['h', 'e', 'l', 'l', 'o']."""
    return list(text)

sample_text = "ab c!.xyz"
# Deliberately not named `tokens`: that module-level name holds the CTC
# vocabulary that greedy_ctc_decode reads, and rebinding it to this character
# list would make later decoding calls that rely on it emit the wrong symbols.
char_tokens = char_tokenizer(sample_text)
indices = my_vocab(char_tokens)

print(f"Original text: '{sample_text}'")
print(f"Tokens: {char_tokens}")
print(f"Indices: {indices}")

decoded_tokens = my_vocab.lookup_tokens(indices)
decoded_text = "".join(decoded_tokens)
print(f"Decoded tokens: {decoded_tokens}")
print(f"Decoded text: '{decoded_text}'")

In [None]:
import torch
from torchtext.vocab import Vocab
import os # For creating a dummy file

def load_vocab_from_phoneme_file(filepath: str) -> Vocab:
    """
    Load a torchtext Vocab from a text file of ``<phoneme> <index>`` lines.

    Lines may appear in any order (the example file lists ``sil 39`` first).
    A torchtext Vocab numbers tokens by insertion position, so phonemes are
    inserted in index order and each one keeps exactly the index written in the
    file. Blank lines are ignored; any other deviation is an error.

    No default index is set, so an out-of-vocabulary phoneme lookup fails
    instead of mapping to a fallback.

    Args:
        filepath (str): The path to the UTF-8 vocabulary file.

    Returns:
        torchtext.vocab.Vocab: Vocab whose phoneme -> index mapping matches the file.

    Raises:
        ValueError: If a line is not exactly ``phoneme index``, an index is not an
            integer, a phoneme or index is repeated, or the indices are not
            exactly 0..N-1 (a Vocab cannot contain holes).
    """
    from collections import OrderedDict
    from torchtext.vocab import vocab as build_vocab

    index_to_phoneme = {}
    seen_phonemes = set()

    with open(filepath, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue  # blank line
            if len(parts) != 2:
                raise ValueError(
                    f"Malformed line {line_number} (expected 'phoneme index'): '{line.strip()}'"
                )
            phoneme, index_str = parts
            try:
                index = int(index_str)
            except ValueError as err:
                raise ValueError(
                    f"Could not parse index on line {line_number}: '{line.strip()}'"
                ) from err
            if index in index_to_phoneme:
                raise ValueError(
                    f"Index {index} is assigned to both {index_to_phoneme[index]!r} and {phoneme!r} (line {line_number})"
                )
            if phoneme in seen_phonemes:
                raise ValueError(f"Phoneme {phoneme!r} appears more than once (line {line_number})")
            index_to_phoneme[index] = phoneme
            seen_phonemes.add(phoneme)

    if sorted(index_to_phoneme) != list(range(len(index_to_phoneme))):
        raise ValueError(
            f"Phoneme indices in {filepath!r} must be exactly 0..{len(index_to_phoneme) - 1} "
            f"with no gaps; got {sorted(index_to_phoneme)}"
        )

    # Insertion order == file index order; the dummy frequency of 1 passes the
    # factory's min_freq=1 filter.
    ordered_phonemes = OrderedDict((index_to_phoneme[i], 1) for i in range(len(index_to_phoneme)))

    # Intentionally no set_default_index(): unknown phonemes must not be
    # silently mapped. NOTE(review): recent torchtext releases report such a
    # lookup as a RuntimeError rather than a KeyError — verify for the
    # installed version.
    return build_vocab(ordered_phonemes, min_freq=1)

# --- Example Usage ---

# 1. Create a dummy vocab file for demonstration
#    (lines are deliberately not sorted by index: "sil 39" comes first).
dummy_vocab_content = """sil 39
d 1
ih 2
jh 3
uw 4
iy 5
y 6
eh 7
w 8
dh 9
ey 10
n 11
aw 12
er 13
sh 14
ae 15
dx 16
t 17
ah 18
k 19
m 20
r 21
ay 22
aa 23
l 24
hh 25
f 26
v 27
s 28
ow 29
b 30
z 31
th 32
g 33
ng 34
p 35
ch 36
uh 37
oy 38
<blank> 0
"""
vocab_file_name = 'phoneme_vocab.txt'
with open(vocab_file_name, 'w', encoding='utf-8') as f:
    f.write(dummy_vocab_content)

# 2. Load the vocabulary from the file; the file is removed afterwards either way.
try:
    my_phoneme_vocab = load_vocab_from_phoneme_file(vocab_file_name)
    print(f"Vocabulary loaded successfully from {vocab_file_name}.")
    print(f"Loaded Vocab size: {len(my_phoneme_vocab)}")

    def phoneme_string_to_list(text_with_phonemes: str) -> list[str]:
        """Split a whitespace-separated phoneme string into phoneme tokens."""
        # split() with no argument also tolerates repeated spaces/tabs.
        return text_with_phonemes.split()

    # Example phoneme sequence for encoding
    input_phonemes = "d ih jh uw <blank> iy ah"
    phoneme_tokens = phoneme_string_to_list(input_phonemes)

    # Encode phoneme tokens to indices
    encoded_indices = my_phoneme_vocab(phoneme_tokens)
    print(f"\nOriginal phonemes: {phoneme_tokens}")
    print(f"Encoded indices: {encoded_indices}")

    # Decode indices back to phonemes
    decoded_phonemes = my_phoneme_vocab.lookup_tokens(encoded_indices)
    print(f"Decoded phonemes: {decoded_phonemes}")

    # Test direct lookup
    print(f"\nIndex of 'sil': {my_phoneme_vocab['sil']}")
    print(f"Phoneme at index 0: {my_phoneme_vocab.lookup_token(0)}")
    print(f"Index of '<blank>': {my_phoneme_vocab['<blank>']}")
    print(f"Phoneme at index 39: {my_phoneme_vocab.lookup_token(39)}")

    # No default index is set, so an out-of-vocabulary lookup fails. Depending on
    # the torchtext version this surfaces as a RuntimeError or a KeyError, so
    # both are caught here.
    print("\nAttempting to lookup an out-of-vocabulary phoneme (expected to fail: no default index is set):")
    try:
        my_phoneme_vocab['not_a_phoneme']
    except (KeyError, RuntimeError) as e:
        print(f"Caught expected {type(e).__name__}: {e}")

finally:
    # Clean up the dummy file
    if os.path.exists(vocab_file_name):
        os.remove(vocab_file_name)