/
gtzan_dataset.py
64 lines (50 loc) · 2.09 KB
/
gtzan_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
from torch.utils import data
from misc.data_loader import ValidationDataLoader
from misc.transforms import get_train_transform, get_test_transform
from misc.utils import LabelsToOneHot, tensor_to_numpy
class GTANZDataset(data.Dataset):
    """Map-style dataset serving samples stored in a NumPy ``.npz`` archive.

    The archive must provide three arrays that are all indexed by sample
    position: ``"X"`` (the signals), ``"y"`` (the labels) and
    ``"label_name"`` (the label names).  The spelling of the class name is
    kept as-is because callers refer to it under this name.
    """

    def __init__(self, dataset_path, transforms=None, one_hot_labels=False):
        """Load the archive into memory and configure per-item processing.

        Args:
            dataset_path: Anything ``np.load`` accepts (typically the path
                of an ``.npz`` file).
            transforms: Optional callable applied to every sample in
                ``__getitem__``; skipped when falsy.
            one_hot_labels: When true, a ``LabelsToOneHot`` encoder is built
                from the loaded labels and applied in ``__getitem__``.
        """
        # ``np.load`` on an ``.npz`` returns a lazy archive that keeps a file
        # handle open.  Indexing it materialises each array, so the handle
        # can be released as soon as the three arrays have been read.  The
        # local is named ``archive`` so it does not shadow the ``data``
        # module imported at file level.
        archive = np.load(dataset_path)
        try:
            self.X = archive["X"]
            self.y = archive["y"]
            self.label_name = archive["label_name"]
        finally:
            # Plain arrays (non-archive loads) have no ``close``; only
            # archive objects need their handle released.
            close = getattr(archive, "close", None)
            if close is not None:
                close()

        self.transforms = transforms
        self.n = self.X.shape[0]
        self.one_hot_labels = one_hot_labels
        self.one_hot_encoder = LabelsToOneHot(self.y) if one_hot_labels else None

    def instance_dataset(self, dataset_path, transforms):
        """Create a sibling dataset that reuses this dataset's label encoder.

        The new dataset is loaded from ``dataset_path`` with ``transforms``.
        When this dataset uses one-hot labels, the *same* encoder object is
        attached to the new dataset instead of building one from the new
        labels.

        Returns:
            A new ``GTANZDataset``.
        """
        # Built without an encoder first so that no encoder is derived from
        # the new file's labels; this dataset's encoder is shared below.
        new_dataset = GTANZDataset(dataset_path, transforms=transforms, one_hot_labels=False)
        if self.one_hot_labels:
            new_dataset.one_hot_labels = True
            new_dataset.one_hot_encoder = self.one_hot_encoder
        return new_dataset

    def __len__(self) -> int:
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.n

    def __getitem__(self, index) -> dict:
        """Return the sample at ``index`` as a dict.

        Returns:
            A dict with keys ``"sound"`` (the optionally transformed signal),
            ``"class"`` (the optionally one-hot encoded label) and
            ``"class_label"`` (the label name).
        """
        X, y, label_name = self.X[index], self.y[index], self.label_name[index]
        if self.transforms:
            # The signal is handed to the transforms reshaped to (1, -1, 1);
            # their output is converted back with ``tensor_to_numpy``.
            X = tensor_to_numpy(self.transforms(X.reshape((1, -1, 1))))
        if self.one_hot_labels:
            # Only the first row of the encoder output is kept -- presumably
            # the encoder returns one row per input label; confirm against
            # ``LabelsToOneHot``.
            y = self.one_hot_encoder(y)[0, :]
        return {"sound": X, "class": y, "class_label": label_name}
if __name__ == "__main__":
    # Manual smoke test: load the archive once with the training transform
    # and inspect a single item, then load it again with the test transform
    # and print the first batch produced by the validation loader.
    npz_path = "../genres16_test.npz"
    clip_length = 2 ** 14

    train_view = GTANZDataset(
        npz_path,
        transforms=get_train_transform(length=clip_length),
        one_hot_labels=True,
    )
    print(len(train_view))
    print(train_view[5])

    loader_options = {
        'batch_size': 64,
        'shuffle': True,
        'num_workers': 1,
    }
    test_view = GTANZDataset(
        npz_path,
        transforms=get_test_transform(length=clip_length),
        one_hot_labels=True,
    )
    batches = ValidationDataLoader(test_view, **loader_options)

    # Only the first batch is of interest; stop right after printing it.
    for first_batch in batches:
        print(first_batch['sound'].shape)
        print(first_batch)
        break