In [13]:
from dataclasses import dataclass

import numpy as np
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader

from datalab import Dataset
torch.set_num_threads(1)


@dataclass
class DataConfig:
    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int


class Collate:
    
    def __init__(self, train, num_samples, sampling_rate):
        self.train = train
        self.num_samples = num_samples
        self.sampling_rate = sampling_rate
        pass
        
    def __call__(self, batch):
        print(batch)
        y, sr = batch["audio"].
        if y.size(0) > 1:
            # mix to mono
            y = y.mean(dim=0, keepdim=True)
        gain = np.random.uniform(-1, -6) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
        if y.size(-1) < self.num_samples:
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, take always the first segment for determinism
            y = y[:, : self.num_samples]

        return y[0]

        
class VocosDataModule(LightningDataModule):
    def __init__(self, train_params: DataConfig, val_params: DataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloder(self, cfg: DataConfig, train: bool):
        dataset = VocosDataset(cfg, train=train)
        dataloader = DataLoader(
            dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        dset = Dataset(
            sources=["libritts-r-train-*"],
            fields=["audio"],
            max_batch_size=1
        ).collate(
           Collate()
        )
        cfg = self.train_config
        return DataLoader(
            dset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=False, pin_memory=True
        )

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.val_config, train=False)


class VocosDataset(Dataset):
    def __init__(self, cfg: DataConfig, train: bool):
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        audio_path = self.filelist[index]
        y, sr = torchaudio.load(audio_path)
        if y.size(0) > 1:
            # mix to mono
            y = y.mean(dim=0, keepdim=True)
        gain = np.random.uniform(-1, -6) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
        if y.size(-1) < self.num_samples:
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, take always the first segment for determinism
            y = y[:, : self.num_samples]

        return y[0]


In [15]:
from itertools import islice
dset = Dataset(
            sources=["libritts-r-train-*"],
            fields=["audio"],
            shuffle=False,
            max_batch_size=1
        ).collate(
           Collate(train=True, num_samples=16384, sampling_rate=24000)
        )

dloader = DataLoader(
    dset, batch_size=None, num_workers=2, shuffle=False, pin_memory=True
)
for batch in islice(dloader, 2):
    print(batch)



   id                              shard_id                      source  \
0   0  008ca261-0e23-4050-9737-db98d2073036  libritts-r-train-clean-100   

                                               audio  
0  ([[tensor(0.), tensor(0.), tensor(0.), tensor(...  
tensor([0.0000e+00, 2.0742e-05, 2.0742e-05,  ..., 1.2463e-01, 1.2168e-01,
        1.2276e-01])
   id                              shard_id                      source  \
1   1  008ca261-0e23-4050-9737-db98d2073036  libritts-r-train-clean-100   

                                               audio  
1  ([[tensor(0.), tensor(0.), tensor(0.), tensor(...  
          id                              shard_id  \
17232  17232  8351adc5-31d9-4457-bd64-5162a5c7e5a9   

                           source  \
17232  libritts-r-train-clean-100   

                                                   audio  
17232  ([[tensor(0.), tensor(0.), tensor(0.), tensor(...  
   id                              shard_id                      source  \
2   2 