In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from torchvision import transforms
import cv2
from PIL import Image

In [None]:
class WineDataset(Dataset):
    """In-memory dataset backed by a comma-separated numeric file.

    The first row of the file is skipped as a header. Column 0 of every
    remaining row becomes the target and all other columns become the features.

    Args:
        csv_path: Path of the CSV file to load. Defaults to ``'wine.csv'``,
            the file this class originally hard-coded.

    Attributes:
        x: float32 tensor holding columns 1..end, one row per sample.
        y: float32 tensor holding column 0, kept 2-D (one column) by the
            list-index ``[0]``.
        n_samples: Number of data rows loaded.
    """

    def __init__(self, csv_path: str = 'wine.csv'):
        # ndmin=2 keeps the array two-dimensional even when the file has a
        # single data row; without it the column slicing below would fail.
        data = np.loadtxt(csv_path, delimiter=',', dtype=np.float32, skiprows=1, ndmin=2)

        self.x = torch.from_numpy(data[:, 1:])
        self.y = torch.from_numpy(data[:, [0]])
        self.n_samples = data.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples loaded from the file."""
        return self.n_samples

In [None]:
class MusicDataset(Dataset):
    """Dataset of preloaded image tensors paired with encoded labels.

    All images under ``root_dir`` are read (and transformed, if a transform is
    given) once, inside the constructor; ``__getitem__`` only indexes the
    resulting tensors. Sample ``i`` is ``data[i]`` paired with ``labels[i]``:
    nothing in this class matches image files to CSV rows by name.

    Relies on the notebook-level helpers ``get_image_files``, ``get_labels``
    and ``load_data`` being defined before an instance is created.

    Args:
        csv_path: CSV file read with ``pd.read_csv``; its ``'label'`` column
            is encoded by ``get_labels``.
        root_dir: Directory searched by ``get_image_files``.
        transform: Optional callable forwarded to ``load_data`` and also kept
            on the instance as ``self.transform``.

    Attributes:
        data: Stacked image tensor returned by ``load_data``.
        labels: Label tensor returned by ``get_labels``.
        n_samples: Number of indexable ``(data, label)`` pairs.
    """

    def __init__(self, csv_path: str, root_dir: str, transform=None):
        import warnings  # local import: only needed for the mismatch warning

        meta_data = pd.read_csv(csv_path)
        image_files = get_image_files(root_dir)

        self.transform = transform
        self.labels = get_labels(meta_data, column_name='label')
        self.data = load_data(image_files, transform)

        # The CSV and the image directory are independent sources, so their
        # sizes can disagree. Reporting the CSV row count alone would let
        # __getitem__ run past the end of `data` (or hide extra images).
        n_labels = self.labels.shape[0]
        n_images = self.data.shape[0]
        self.n_samples = min(n_labels, n_images)
        if n_labels != n_images:
            warnings.warn(
                "MusicDataset: {} label rows but {} images; using the first {} "
                "of each. Check that images and CSV rows line up.".format(
                    n_labels, n_images, self.n_samples
                )
            )

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at position ``index``."""
        return self.data[index], self.labels[index]

    def __len__(self):
        """Return the number of indexable samples."""
        return self.n_samples
        

In [None]:
# Per-channel statistics handed to Normalize (mean first, then std).
NORMALIZE_MEAN = (0.5124, 0.4420, 0.4994)
NORMALIZE_STD = (0.4354, 0.4721, 0.4593)

# ToTensor turns an (H, W, C) uint8 image into a (C, H, W) float32 tensor
# scaled to [0, 1]; Normalize then applies (x - mean) / std per channel.
preprocessing = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])

In [None]:
# Builds the dataset eagerly: every image is read and preprocessed inside the
# constructor.
# NOTE: MusicDataset.__init__ calls get_labels, get_image_files and load_data,
# which are defined in cells further down; run those cells first on a fresh kernel.
dataset = MusicDataset(
    csv_path='d:/Data/features_30_sec.csv',
    root_dir='d:/Data/images_original',
    transform=preprocessing,
)

In [None]:
# Fetch the first sample via MusicDataset.__getitem__: `a` is dataset.data[0]
# and `b` is dataset.labels[0]. `a` is read again by the statistics cell at the
# end of the notebook, so the short names are kept.
a, b = dataset[0]

In [None]:
def get_labels(df, column_name: str):
    """Encode one DataFrame column as numeric class ids.

    Each distinct value receives an id equal to its position in
    ``df[column_name].unique()``, i.e. ids follow order of first appearance.

    Args:
        df: DataFrame containing the label column.
        column_name: Name of the column to encode.

    Returns:
        A float32 tensor of shape ``(len(df), 1)`` holding the class id of
        every row.
    """
    column = df[column_name]
    classes = list(column.unique())
    # Look every value up in the first-appearance list to get its id.
    class_ids = [classes.index(value) for value in column]
    return torch.tensor(class_ids, dtype=torch.float32).reshape(len(class_ids), 1)
    
import os
def get_image_files(rootdir: str) -> list:
    """Collect the paths of all ``.png`` files below ``rootdir``.

    The tree is walked top-down with directories and file names visited in
    sorted order, so the returned list is the same on every run and every
    filesystem. Callers that pair the result positionally with other data
    depend on that stability; plain ``os.walk`` order is filesystem-dependent.

    Args:
        rootdir: Directory to search recursively.

    Returns:
        List of file paths (each built as ``os.path.join(directory, name)``)
        whose name ends with ``".png"`` (case-sensitive). Empty if nothing
        matches or ``rootdir`` does not exist.
    """
    image_file_locations = []
    for subdir, dirs, files in os.walk(rootdir):
        # Sorting `dirs` in place fixes the order in which os.walk descends.
        dirs.sort()
        image_file_locations.extend(
            os.path.join(subdir, name)
            for name in sorted(files)
            if name.endswith(".png")
        )
    return image_file_locations

In [None]:
def load_data(image_file_locations: list, transform=None):
    """Read every image with OpenCV and stack the results into one tensor.

    Args:
        image_file_locations: Paths of the image files to read, in the order
            they should appear along dimension 0 of the result.
        transform: Optional callable applied to each uint8 image array; its
            return value is what gets stacked. When ``None`` the array is
            wrapped with ``torch.from_numpy`` unchanged, i.e. it keeps the
            layout returned by ``cv2.imread``.

    Returns:
        ``torch.stack`` of all per-image tensors along a new leading
        dimension, or an empty tensor (``torch.empty(0)``) when no paths are
        given.

    Raises:
        OSError: If ``cv2.imread`` returns ``None`` for a path (missing or
            unreadable file).
        RuntimeError: From ``torch.stack`` if the per-image tensors do not all
            share the same shape.
    """
    # NOTE(review): OpenCV loads colour images in BGR channel order. Confirm
    # that any per-channel statistics used in `transform` follow the same order.
    images = []
    for image_file in image_file_locations:
        raw_img = cv2.imread(image_file)
        if raw_img is None:
            # imread signals failure by returning None instead of raising.
            raise OSError(
                "cv2.imread could not read image file: {!r}".format(image_file)
            )
        current_img = raw_img.astype('uint8')

        if transform is not None:
            current_img = transform(current_img)
        else:
            # NOTE(review): this branch keeps the array layout as loaded,
            # unlike a ToTensor-style transform which moves channels first.
            current_img = torch.from_numpy(current_img)
        images.append(current_img)

    if not images:
        # torch.stack rejects an empty list; hand back an empty tensor instead.
        return torch.empty(0)
    return torch.stack(images, dim=0)

In [None]:
# Flatten the first sample to 3 rows and display the per-row mean and standard
# deviation, each divided by 255.
# NOTE(review): `a` comes from a dataset built with the `preprocessing`
# pipeline (ToTensor + Normalize), so the extra /255.0 looks like a leftover
# from inspecting raw pixel values. Confirm it is intended.
per_channel = a.view(3, -1)
per_channel.mean(dim=1) / 255.0, per_channel.std(dim=1) / 255.0