# Reconocimiento de fonemas usando CTC

In [1]:
import json
import torch
import torch.utils.data as data
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import  Dataset, DataLoader

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


## Implementación del dataset y el dataloader

### Dataset
Convierte los datos crudos a un formato que pueda usar el `DataLoader`. Permite implementar el método `__getitem__()`, en el que leemos los datos y devolvemos, por ejemplo, el wav y la transcripción de cada dato.

In [None]:
# JSON manifests for each split; the validation split is stored as "dev".
train_json, test_json, valid_json = (
    'data/train.json',
    'data/test.json',
    'data/dev.json',
)

class TimitDataset(Dataset):
    """Expose the entries of a JSON manifest as ``(wav, duration, phn)`` tuples.

    Items are returned raw: the audio is not loaded and the ``'phn'`` field is
    returned exactly as stored in the JSON. Presumably the manifest maps
    utterance ids to records -- TODO confirm for data/*.json; a top-level
    list of records is also accepted.
    """

    def __init__(self, json_file):
        """Load the manifest stored at ``json_file``."""
        # JSON text is UTF-8 by specification; don't rely on the locale default.
        with open(json_file, 'r', encoding='utf-8') as f:
            self.datos_json = json.load(f)
        # DataLoader samplers request items by integer position (0..len-1),
        # but a dict manifest is keyed by utterance id. Remember the ids in
        # file order so positions can be translated into keys.
        if isinstance(self.datos_json, dict):
            self.sample_ids = list(self.datos_json.keys())
        else:
            self.sample_ids = None

    def __len__(self):
        return len(self.datos_json)

    def __getitem__(self, key):
        """Return ``(wav, duration, phn)`` for ``key``.

        ``key`` may be an integer position (what ``DataLoader`` passes) or,
        for dict manifests, an utterance id string looked up directly.
        """
        if self.sample_ids is not None and not isinstance(key, str):
            key = self.sample_ids[key]
        entry = self.datos_json[key]
        return entry['wav'], entry['duration'], entry['phn']

def collate_fn(batch, padding_value=0):
    """Collate ``(sequence, label)`` pairs into padded batch tensors.

    Args:
        batch: list of 2-tuples ``(sequence, label)``. ``sequence`` is a tensor
            (or anything ``torch.as_tensor`` accepts) whose first dimension is
            the variable length. ``label`` is either a scalar or a 1-D
            variable-length sequence (e.g. phoneme ids used as a CTC target).
        padding_value: value used to right-pad sequences and sequence labels.

    Returns:
        ``(padded_sequences, labels)``. ``padded_sequences`` has shape
        ``(batch, max_len, *)``. ``labels`` is a 1-D tensor when every label is
        a scalar, otherwise a ``(batch, max_label_len)`` padded tensor.

    Raises:
        ValueError: if an item is not a ``(sequence, label)`` pair.
    """
    # NOTE(review): TimitDataset in this file yields (wav_path, duration, phn)
    # 3-tuples of raw values, which this function cannot collate.
    arities = {len(item) for item in batch}
    if arities - {2}:
        raise ValueError(
            f"collate_fn expects (sequence, label) pairs; got items of length {sorted(arities)}"
        )

    sequences, labels = zip(*batch)
    padded_sequences = pad_sequence(
        [torch.as_tensor(seq) for seq in sequences],
        batch_first=True,
        padding_value=padding_value,
    )

    label_tensors = [torch.as_tensor(label) for label in labels]
    if all(t.dim() == 0 for t in label_tensors):
        # Scalar labels (e.g. class ids): plain 1-D tensor, as before.
        labels = torch.stack(label_tensors)
    else:
        # Variable-length label sequences cannot go through torch.tensor(); pad them.
        labels = pad_sequence(label_tensors, batch_first=True, padding_value=padding_value)
    return padded_sequences, labels  # This is the output of the DataLoader
   
train_ds = TimitDataset(train_json)
test_ds = TimitDataset(test_json)
valid_ds = TimitDataset(valid_json)

# Only the training split is shuffled; evaluation splits keep a fixed order.
# Each loader wraps its own split (test_dl -> test_ds, valid_dl -> valid_ds).
# NOTE(review): collate_fn expects (sequence, label) pairs, while TimitDataset
# yields (wav_path, duration, phn); iterating these loaders will fail until the
# two agree.
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
test_dl = DataLoader(test_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)

   
        

### Versión creada por Gemini

In [10]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

class SpeechJsonDataset(Dataset):
    """
    A PyTorch Dataset for reading speech data from a JSON file.

    The JSON file is expected to have a top-level dictionary where keys are
    sample IDs and values are dictionaries containing speech data attributes
    like 'wav', 'duration', 'spk_id', 'phn', 'wrd', and 'ground_truth_phn_ends'.

    By default each instance builds its own phoneme/word/speaker vocabularies
    from its file. Ids are assigned in order of first appearance, so two
    datasets built from different files (e.g. train and test) would give the
    same phoneme different ids. Pass the training dataset's vocabularies to
    the other splits to keep ids consistent.
    """

    def __init__(self, json_file_path, phn_vocab=None, wrd_vocab=None, spk_vocab=None):
        """
        Load the JSON file and set up the vocabularies.

        Args:
            json_file_path (str): The path to the JSON dataset file.
            phn_vocab (dict, optional): Existing phoneme -> id mapping to reuse
                (e.g. ``train_ds.phn_vocab``). It is copied and never extended;
                it must contain "<UNK>".
            wrd_vocab (dict, optional): Existing word -> id mapping to reuse.
                Copied, never extended; must contain "<UNK>".
            spk_vocab (dict, optional): Existing speaker -> id mapping to reuse.
                Copied, never extended; unseen speakers get id -1.

        Raises:
            ValueError: if a supplied phoneme or word vocabulary lacks "<UNK>".
        """
        # __getitem__ maps unseen tokens to "<UNK>", so a reused vocab needs it.
        for name, vocab in (("phn_vocab", phn_vocab), ("wrd_vocab", wrd_vocab)):
            if vocab is not None and "<UNK>" not in vocab:
                raise ValueError(f"{name} must contain an '<UNK>' entry")

        # Load the JSON data (JSON text is UTF-8 by specification).
        with open(json_file_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # Sample IDs in file order; integer indices are positions in this list.
        self.sample_ids = list(self.data.keys())

        # Fresh phoneme/word vocabularies reserve 0 for <PAD> and 1 for <UNK>.
        # Supplied vocabularies are copied so the caller's dicts are untouched.
        self.phn_vocab = dict(phn_vocab) if phn_vocab is not None else {"<PAD>": 0, "<UNK>": 1}
        self.wrd_vocab = dict(wrd_vocab) if wrd_vocab is not None else {"<PAD>": 0, "<UNK>": 1}
        self.spk_vocab = dict(spk_vocab) if spk_vocab is not None else {}

        self._build_vocabularies(
            build_phn=phn_vocab is None,
            build_wrd=wrd_vocab is None,
            build_spk=spk_vocab is None,
        )

    def _build_vocabularies(self, build_phn=True, build_wrd=True, build_spk=True):
        """
        Extend the selected vocabularies with every token seen in the data,
        then print the resulting vocabulary sizes.

        Each new token gets the next free id (the current vocabulary size), so
        fresh phoneme/word ids start at 2 (after <PAD>/<UNK>) and fresh speaker
        ids start at 0.
        """
        for sample_id in self.sample_ids:
            entry = self.data[sample_id]

            if build_phn:
                self._add_tokens(self.phn_vocab, entry.get('phn', '').split())
            if build_wrd:
                self._add_tokens(self.wrd_vocab, entry.get('wrd', '').split())

            spk_id = entry.get('spk_id')
            if build_spk and spk_id:
                self._add_tokens(self.spk_vocab, [spk_id])

        print(f"Phoneme Vocabulary Size: {len(self.phn_vocab)}")
        print(f"Word Vocabulary Size: {len(self.wrd_vocab)}")
        print(f"Speaker Vocabulary Size: {len(self.spk_vocab)}")

    @staticmethod
    def _add_tokens(vocab, tokens):
        """Assign the next free id (current size of ``vocab``) to each unseen token."""
        for token in tokens:
            if token not in vocab:
                vocab[token] = len(vocab)

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """
        Retrieves a single sample from the dataset at the given index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            dict: A dictionary containing the processed data for the sample.
                  Keys include 'sample_id', 'wav_path', 'duration',
                  'speaker_id', 'phonemes', 'words', 'phoneme_ends'.
        """
        sample_id = self.sample_ids[idx]
        entry = self.data[sample_id]

        # Extract raw data
        wav_path = entry.get('wav', '')
        duration = entry.get('duration', 0.0)

        # Convert speaker ID to numerical
        spk_id_str = entry.get('spk_id', '')
        speaker_id = self.spk_vocab.get(spk_id_str, -1) # -1 for unknown speaker

        # Process phonemes: tokenize and numericalize (unseen -> <UNK>)
        phonemes_raw = entry.get('phn', '').split()
        phonemes_numerical = [self.phn_vocab.get(p, self.phn_vocab["<UNK>"]) for p in phonemes_raw]
        # Convert to tensor; padding will be handled by collate_fn
        phonemes_tensor = torch.tensor(phonemes_numerical, dtype=torch.long)

        # Process words: tokenize and numericalize (unseen -> <UNK>)
        words_raw = entry.get('wrd', '').split()
        words_numerical = [self.wrd_vocab.get(w, self.wrd_vocab["<UNK>"]) for w in words_raw]
        # Convert to tensor; padding will be handled by collate_fn
        words_tensor = torch.tensor(words_numerical, dtype=torch.long)

        # Process ground_truth_phn_ends: convert to list of floats and then to tensor
        phn_ends_raw = entry.get('ground_truth_phn_ends', '').split()
        phoneme_ends = [float(end) for end in phn_ends_raw if end.strip()] # Ensure no empty strings
        phoneme_ends_tensor = torch.tensor(phoneme_ends, dtype=torch.float32)

        return {
            'sample_id': sample_id,
            'wav_path': wav_path,
            'duration': torch.tensor(duration, dtype=torch.float32),
            'speaker_id': torch.tensor(speaker_id, dtype=torch.long),
            'phonemes': phonemes_tensor,
            'words': words_tensor,
            'phoneme_ends': phoneme_ends_tensor
        }

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable-length sequences (phonemes, words, phoneme_ends)
    by padding them to the maximum length within each batch.

    Sequences are right-padded with 0 via ``pad_sequence``. For phonemes and
    words, 0 is the <PAD> id; for the float phoneme end positions it is 0.0.
    The original lengths are returned alongside so the padding can be masked.

    Args:
        batch (list): A list of dictionaries, where each dictionary is a sample
                      returned by the __getitem__ method of the dataset.

    Returns:
        dict: 'sample_id' and 'wav_path' as lists of strings; 'duration' and
              'speaker_id' stacked into 1-D tensors; 'phonemes', 'words' and
              'phoneme_ends' as (batch, max_len) padded tensors; and
              'phoneme_lengths', 'word_lengths', 'phoneme_ends_lengths' as
              1-D long tensors of the unpadded lengths.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    # Local import keeps this function independent of earlier notebook cells.
    from torch.nn.utils.rnn import pad_sequence

    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    def padded(key):
        # (batch, max_len) tensor whose dtype follows the per-sample tensors.
        return pad_sequence([item[key] for item in batch], batch_first=True, padding_value=0)

    def lengths(key):
        # Unpadded length of each sample's sequence.
        return torch.tensor([len(item[key]) for item in batch], dtype=torch.long)

    return {
        'sample_id': [item['sample_id'] for item in batch],  # List of strings, not a tensor
        'wav_path': [item['wav_path'] for item in batch],    # List of strings, not a tensor
        'duration': torch.stack([item['duration'] for item in batch]),
        'speaker_id': torch.stack([item['speaker_id'] for item in batch]),
        'phonemes': padded('phonemes'),
        'words': padded('words'),
        'phoneme_ends': padded('phoneme_ends'),
        'phoneme_lengths': lengths('phonemes'),
        'word_lengths': lengths('words'),
        'phoneme_ends_lengths': lengths('phoneme_ends'),
    }


if __name__ == '__main__':
    # Demo run. A tiny three-utterance manifest is written to ./p.json unless
    # that file already exists ('x' mode refuses to overwrite an existing file);
    # then the first shuffled batch produced by the DataLoader is printed.
    json_file_path = 'p.json'
    try:
        with open(json_file_path, 'x') as manifest:
            manifest.write("""
{
  "mrws1_sx320": {
    "wav": "/dbase/timit/test/dr5/mrws1/sx320.wav",
    "duration": 3.28325,
    "spk_id": "mrws1",
    "phn": "sil dh ih n ih r ih s ih n ih sil g aa sil m ey n aa sil b iy w ih th ih n w aa sil k ih ng sil d ih s sil t ih n sil s sil",
    "wrd": "the nearest synagogue may not be within walking distance",
    "ground_truth_phn_ends": "2360 2840 3216 4511 5556 7018 7880 10440 11160 12040 13160 13960 14200 17640 18280 19160 20360 21560 23800 25320 25720 26520 27800 28825 30440 31248 32208 34130 35880 36760 37640 37960 39101 40120 40360 41640 43000 43320 44200 44440 45280 46680 49560 52480"
  },
  "mrws1_sx230": {
    "wav": "/dbase/timit/test/dr5/mrws1/sx230.wav",
    "duration": 3.2064375,
    "spk_id": "mrws1",
    "phn": "sil ah l aw l iy w ey hh iy er sil b ah r ae sh sil n l ay z aa l eh r er z sil",
    "wrd": "allow leeway here but rationalize all errors",
    "ground_truth_phn_ends": "3000 3592 4600 8605 10402 12360 13618 16280 17169 18757 20469 23000 23443 24520 26402 28600 30160 30600 31400 32360 35800 36360 38809 40040 41650 42945 46120 48600 51280"
  },
  "mrws1_sx50": {
    "wav": "/dbase/timit/test/dr5/mrws1/sx50.wav",
    "duration": 3.232,
    "spk_id": "mrws1",
    "phn": "sil k ae dx ih s sil t r aa f ih sil k iy sil k ih n aa m ih sil k ah sil b ae sil k s sil n ih sil g l eh sil dh ah sil p aa r sil",
    "wrd": "catastrophic economic cutbacks neglect the poor",
    "ground_truth_phn_ends": "2040 3080 4428 4720 5880 7480 7880 8600 9293 10680 12040 13000 13953 14680 15640 16680 17348 18200 18966 20762 21840 22520 23800 25080 26318 28880 29240 31160 32279 32584 33520 34360 34920 35926 37000 37472 38200 39644 42234 42760 43240 44840 45720 48333 49640 51680"
  }
}
""")
    except FileExistsError:
        print("p.json already exists. Skipping dummy file creation.")

    # Dataset and loader; each batch is the dict assembled by custom_collate_fn.
    dataset = SpeechJsonDataset(json_file_path)
    batch_size = 2
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

    # (title, batch key) for every tensor field, printed as "  <title>:\n<tensor>".
    tensor_fields = (
        ("Durations", 'duration'),
        ("Speaker IDs", 'speaker_id'),
        ("Phonemes (numerical, padded)", 'phonemes'),
        ("Phoneme Lengths (original)", 'phoneme_lengths'),
        ("Words (numerical, padded)", 'words'),
        ("Word Lengths (original)", 'word_lengths'),
        ("Phoneme Ends (padded)", 'phoneme_ends'),
        ("Phoneme Ends Lengths (original)", 'phoneme_ends_lengths'),
    )

    print(f"\n--- Iterating through DataLoader with batch_size={batch_size} ---")
    for i, batch in enumerate(dataloader):
        print(f"\nBatch {i+1}:")
        print(f"  Sample IDs: {batch['sample_id']}")
        print(f"  WAV Paths (first 2): {batch['wav_path'][:2]}...")
        for title, key in tensor_fields:
            print(f"  {title}:\n{batch[key]}")

        # Show only the first batch, for brevity.
        if i == 0:
            break

Phoneme Vocabulary Size: 33
Word Vocabulary Size: 23
Speaker Vocabulary Size: 1

--- Iterating through DataLoader with batch_size=2 ---

Batch 1:
  Sample IDs: ['mrws1_sx50', 'mrws1_sx230']
  WAV Paths (first 2): ['/dbase/timit/test/dr5/mrws1/sx50.wav', '/dbase/timit/test/dr5/mrws1/sx230.wav']...
  Durations:
tensor([3.2320, 3.2064])
  Speaker IDs:
tensor([0, 0])
  Phonemes (numerical, padded):
tensor([[ 2, 16, 25, 30,  4,  7,  2, 19,  6,  9, 31,  4,  2, 16, 13,  2, 16,  4,
          5,  9, 10,  4,  2, 16, 20,  2, 12, 25,  2, 16,  7,  2,  5,  4,  2,  8,
         21, 29,  2,  3, 20,  2, 32,  9,  6,  2],
        [ 2, 20, 21, 22, 21, 13, 14, 11, 23, 13, 24,  2, 12, 20,  6, 25, 26,  2,
          5, 21, 27, 28,  9, 21, 29,  6, 24, 28,  2,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
  Phoneme Lengths (original):
tensor([46, 29])
  Words (numerical, padded):
tensor([[18, 19, 20, 21,  2, 22,  0],
        [11, 12, 13, 14, 15, 16, 17]])
  Word Lengths (original