In [15]:
import glob
import numpy as np
from scipy.io import wavfile
import torch
from torch.utils.data import Dataset
from utils import bin_and_one_hot

In [16]:
class WaveNetDataSet(Dataset):
    """Dataset of random fixed-length sample windows drawn from a directory of .wav files.

    Item ``idx`` corresponds to one file. A window of
    ``num_frames_input + num_frames_output`` consecutive samples is taken at a
    uniformly random offset, then split into an input part ``x`` (the first
    ``num_frames_input`` samples) and a target part ``y`` (the following
    ``num_frames_output`` samples).
    """

    def __init__(self, audio_dir, num_frames_input, num_frames_output):
        """
        Assume we have a directory audio_dir, containing wav files as our dataset.

        Args:
            audio_dir: directory whose ``*.wav`` files (non-recursive) form the dataset.
            num_frames_input: number of samples returned as the input ``x``.
            num_frames_output: number of samples returned as the target ``y``.
        """
        # glob yields files in arbitrary filesystem order; sort so the
        # index -> file mapping is stable across runs and machines.
        self.data = sorted(glob.glob(audio_dir + "/*.wav"))
        self.num_frames_input = num_frames_input
        self.num_frames_output = num_frames_output
        # Total window length that must fit inside each file.
        self.len_subset = num_frames_input + num_frames_output

    def __len__(self):
        """Number of .wav files found in ``audio_dir``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x, y)`` tensors cut from a random window of file ``self.data[idx]``.

        The tensor dtype follows the sample format returned by
        ``scipy.io.wavfile.read`` (e.g. int16 samples -> ``torch.ShortTensor``).

        Raises:
            ValueError: if the file has fewer than
                ``num_frames_input + num_frames_output`` samples.
        """
        fname = self.data[idx]
        _sample_rate, data = wavfile.read(fname)
        n_samples = len(data)
        if n_samples < self.len_subset:
            raise ValueError(
                f"{fname!r} has {n_samples} samples, fewer than the "
                f"{self.len_subset} required (num_frames_input + num_frames_output)"
            )
        # randint's upper bound is exclusive: +1 makes the window that ends on the
        # last sample reachable, and allows files of exactly len_subset samples.
        start = np.random.randint(0, n_samples - self.len_subset + 1)
        window = data[start:start + self.len_subset]
        x = window[:self.num_frames_input]
        y = window[self.num_frames_input:]
        return torch.from_numpy(x), torch.from_numpy(y)

In [17]:
# Each item: 10 consecutive input samples followed by the 1 sample to predict.
dset = WaveNetDataSet(audio_dir="./data", num_frames_input=10, num_frames_output=1)

In [24]:
# Seed NumPy's global RNG: __getitem__ picks a random window offset, so without
# a seed this cell's output changes on every Restart & Run All.
np.random.seed(0)
dset[1]

(
 -357
 -409
 -412
 -383
 -362
 -379
 -376
 -398
 -375
 -339
 [torch.ShortTensor of size 10], 
 -327
 [torch.ShortTensor of size 1])