In [2]:
import lmdb
from typing import Optional

import torch
import numpy as np
import IPython.display as ipd

from udls import transforms
from udls.generated import AudioExample


In [3]:
class AudioExampleDataset(torch.utils.data.Dataset):

    def __init__(self, db_path : str, transforms: Optional[transforms.Transform] = None) -> None:
        super().__init__()
        
        self._db_path = db_path
        self._transforms = transforms

        self.env = lmdb.open(self._db_path, lock=False)
        
        with self.env.begin(write=False) as txn:
            self.keys = list(txn.cursor().iternext(values=False))
        
    def __len__(self):
        return len(self.keys)
    
    def __getitem__(self, idx):
        with self.env.begin(write=False) as txn:
            ae = AudioExample.FromString(txn.get(self.keys[idx]))
        
        buffer = ae.buffers["waveform"]
        assert buffer.precision == AudioExample.Precision.INT16

        audio = np.frombuffer(buffer.data, dtype=np.int16)
        audio = audio.astype(np.float32) / (2**15 - 1)

        if self._transforms is not None:
            audio = self._transforms(audio)

        return audio

In [6]:
db_path = '/home/etiandre' # ne contient pas data.mdb
dataset = AudioExampleDataset(db_path=db_path)
print(len(dataset))

0


In [34]:
for i in range(10):
    ipd.display(ipd.Audio(dataset[i], rate=sr))

In [35]:
def data_loader(dataset, batch_size, valid_ratio=0.2, num_workers=0):
    # Split it into training and validation sets
    # if valid_ratio = 0.2 : 80%/20% split for train/valid
    nb_train = round((1.0 - valid_ratio) * len(dataset))
    nb_valid = round(valid_ratio * len(dataset))
    train_set, valid_set = torch.utils.data.dataset.random_split(
        dataset,
        [nb_train, nb_valid],
        generator=torch.Generator().manual_seed(42)
    )

    # Define DataLoaders
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size, 
        shuffle=True,
        num_workers=num_workers
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    
    return train_loader, valid_loader

In [36]:
batch = 8
val_ratio = 0.2
train_set, val_set = data_loader(dataset, batch, valid_ratio=val_ratio)


In [37]:
x = next(iter(train_set))
x.shape

torch.Size([8, 65536])