In [102]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from typing import Literal

In [103]:
class CommonVoice(torch.utils.data.Dataset):
    def __init__(
        self, split: Literal["train", "validation", "test"], streaming: bool = False
    ):
        self.dataset = load_dataset(
            "mozilla-foundation/common_voice_11_0",
            "en",
            split=split,
            streaming=streaming,
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        x = torch.tensor(item["audio"]["array"])
        y = item["sentence"]
        return x, y


def collate_batch(batch):
    batch_x = [torch.tensor(item[0], dtype=torch.float32) for item in batch]
    batch_y = [item[1] for item in batch]

    batch_x_padded = pad_sequence(batch_x, batch_first=True, padding_value=0)

    return batch_x_padded, batch_y

In [105]:
ds_train = CommonVoice("train")
ds_val = CommonVoice("validation")
ds_test = CommonVoice("test")

In [107]:
ds_train[0]

(tensor([ 0.0000e+00,  1.9661e-15, -7.5534e-14,  ...,  1.6872e-05,
          1.4191e-07,  1.1668e-04], dtype=torch.float64),
 'The track appears on the compilation album "Kraftworks".')

In [106]:
dl_train = DataLoader(ds_train, batch_size=32, collate_fn=collate_batch, shuffle=True)

In [108]:
batch = next(iter(dl_train))
xs, ys = batch

  batch_x = [torch.tensor(item[0], dtype=torch.float32) for item in batch]


In [110]:
xs

tensor([[ 0.0000e+00, -6.4042e-16, -2.6238e-16,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  2.5533e-16, -1.0709e-14,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-4.5475e-13, -4.5475e-12, -4.0927e-12,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00,  5.0786e-16,  2.5514e-15,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -4.4333e-16, -1.6788e-15,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -3.0650e-14, -5.4769e-14,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])