In [1]:
import glob
import numpy as np
from scipy.io import wavfile
import torch
from torch.utils.data import Dataset
from utils import bin_and_one_hot

In [2]:
class WaveNetDataSet(Dataset):
    def __init__(self, audio_dir, num_frames_input, num_frames_output):
        """
        Assume we have a directory audio_dir, containing wav files as our dataset
        """
        self.data = glob.glob(audio_dir + "/*.wav")
        self.num_frames_input = num_frames_input
        self.num_frames_output = num_frames_output
        self.len_subset = num_frames_input + num_frames_output
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        fname = self.data[idx]
        # read a wavfile
        frequency, data = wavfile.read(fname)
        # select a random starting point within that file (leaving room to subset into x/y)
        random_start = np.random.randint(0, len(data) - self.len_subset)
        # select a subsete of that wavfile
        subset = data[random_start:random_start + self.len_subset]
        # x is the first set of frames in the subset
        x = subset[: self.num_frames_input]
        # we are going to try and predict everything after the split
        y = subset[self.num_frames_input: ]
        return torch.from_numpy(x), torch.from_numpy(y)

In [3]:
dset = WaveNetDataSet("./data", 100, 10)

In [4]:
dset[1]

(
   11
   65
  126
  150
  111
  123
  164
  184
  179
  171
  161
  196
  199
  222
  255
  274
  280
  312
  380
  450
  481
  458
  442
  436
  436
  405
  373
  378
  375
  392
  428
  408
  421
  448
  448
  507
  565
  541
  511
  483
  482
  455
  425
  347
  342
  391
  411
  378
  370
  414
  470
  460
  398
  400
  422
  417
  349
  282
  299
  320
  304
  243
  278
  289
  301
  270
  272
  239
  231
  194
  164
  158
  117
   13
   -4
   82
  124
   62
    1
   33
   99
   94
   25
   36
  110
  137
   48
  -15
  -49
  -65
 -128
 -129
 -100
 -130
 -140
 -119
  -33
   17
   79
   32
 [torch.ShortTensor of size 100], 
   28
   56
   80
  119
  107
   86
   56
   35
   44
  109
 [torch.ShortTensor of size 10])