In [1]:
import librosa
import os
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils import rnn
import torch

In [2]:
class NottinghamDataset(Dataset):
    """Dataset that lazily loads audio files from a directory as waveforms.

    Files are decoded on access with ``librosa.load`` using its default
    settings, and the resulting sample array is returned as a tensor.

    Args:
        data_dir: Path of the directory that contains the audio files.

    Attributes:
        data_dir: The directory passed to the constructor.
        name_list: Sorted file names (not full paths) found in ``data_dir``.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # os.listdir() yields entries in arbitrary order, so sort them to keep
        # the index -> file mapping reproducible between runs and machines.
        # Sub-directories are skipped because they cannot be decoded as audio.
        self.name_list = sorted(
            name
            for name in os.listdir(self.data_dir)
            if os.path.isfile(os.path.join(self.data_dir, name))
        )

    def __len__(self):
        """Return the number of files available in ``data_dir``."""
        return len(self.name_list)

    def __getitem__(self, idx):
        """Decode the ``idx``-th file and return its samples.

        Args:
            idx: Position of the file in ``name_list``.

        Returns:
            A tensor built from the array returned by ``librosa.load``.

        Raises:
            IndexError: If ``idx`` is outside the range of ``name_list``.
        """
        path = os.path.join(self.data_dir, self.name_list[idx])
        # librosa.load returns (samples, sample_rate); the rate is not needed.
        arr, _ = librosa.load(path)
        return torch.from_numpy(arr)

In [3]:
# Dataset over the Nottingham WAV folder; the path is relative to the
# notebook's working directory.
dataset = NottinghamDataset(data_dir='./data/nottingham-dataset/wav')

In [4]:
def collate_fn(data):
    """Collate variable-length sequences into one zero-padded batch.

    The sequences are ordered from longest to shortest, which is the order
    ``rnn.pack_padded_sequence`` expects by default.

    Args:
        data: Sequence (list, tuple, ...) of tensors whose first dimension is
            the sequence length. The input is not modified.

    Returns:
        A tuple ``(padded, lengths)`` where ``padded`` is the zero-padded
        batch (batch first) with an extra trailing dimension of size 1, and
        ``lengths`` is the list of original sequence lengths in the same
        (descending) order as the rows of ``padded``.
    """
    # sorted() instead of list.sort(): the caller's batch is left untouched
    # and non-list sequences such as tuples are accepted as well.
    ordered = sorted(data, key=len, reverse=True)
    lengths = [len(seq) for seq in ordered]
    padded = rnn.pad_sequence(ordered, batch_first=True, padding_value=0)
    # Trailing unsqueeze adds the feature dimension.
    return padded.unsqueeze(-1), lengths

In [5]:
# Sanity check: number of samples in the first recording.
dataset[0].size()

torch.Size([705664])

In [6]:
# Batches of three recordings, padded and length-sorted by collate_fn.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=3,
    collate_fn=collate_fn,
)

In [7]:
# Print the shape of the packed data for the first batch only.
for data, length in dataloader:
    packed = rnn.pack_padded_sequence(data, length, batch_first=True)
    print(packed.data.shape)
    # Break at the end of the first iteration so the DataLoader never has to
    # fetch and collate a second batch just to discard it.
    break

torch.Size([4277920, 1])


In [8]:
# Single-layer LSTM: one input feature per time step, five hidden units.
model = torch.nn.LSTM(input_size=1, hidden_size=5, batch_first=True)

In [9]:
# Run the first batch through the LSTM and inspect the padded output.
for data, length in dataloader:
    # Pack the padded batch together with its true sequence lengths.
    data = rnn.pack_padded_sequence(data, length, batch_first=True)
    packed_output, hidden = model(data)
    # Convert the packed output back to a padded tensor plus lengths.
    output, out_len = rnn.pad_packed_sequence(packed_output, batch_first=True)
    print(output.shape)
    print(output)
    # NOTE(review): if gradients are not needed for this inspection, wrapping
    # the forward pass in torch.no_grad() would likely reduce memory use.
    # Only the first batch is inspected; breaking here avoids loading a
    # second batch that would be thrown away.
    break

: 

: 