In [18]:
# To download the datasets, uncomment the lines below (wget saves each file to the current working directory).
# jsb = "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle"
# piano = "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/piano_midi.pickle"
# muse = "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/muse_data.pickle"
# nottingham = "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/nottingham.pickle"

# !wget $jsb
# !wget $piano
# !wget $muse
# !wget $nottingham

In [1]:
import pickle
import numpy as np
from torch.utils.data import Dataset
import torch

In [44]:
class MusicDataset(Dataset):
    def __init__(self, path, max_note=88, min_note=21, split='train'):
        self.max_note = max_note
        self.min_note = min_note
        self.path = path
        self.data = self.load_process_data(self.path, max_note=self.max_note, min_note=self.min_note, split=split)
        self.split = split
        self.sequence_lengths = self.data['sequence_lengths']
        self.encodings = self.data['encodings']
    
    def read_pickle_from_url(self, path, split='train'):
        with open(path, 'rb') as file:
            data = pickle.load(file)
        return data


    def load_process_data(self, path, max_note=88, min_note=21, split='train'):
        data_dict = {}
        encodings = []
        sequence_lengths = []
        
        data = self.read_pickle_from_url(path) #229 music data
        tr_data = data[split]
        
        for i, music in enumerate(tr_data):
            one_hot_music = torch.zeros((len(music), max_note))
            # print(one_hot_music.shape)
            for j, keys in enumerate(music): #the tuples
                # print(keys)
                one_hot_vector = [0] * max_note
                if len(keys) == 0:
                    continue
                for key in keys:
                    key = key - 21
                    one_hot_vector[key] = 1
                one_hot_music[j,:] = torch.tensor(one_hot_vector)
            
            encodings.append(one_hot_music)
            # print(one_hot_music)
            sequence_lengths.append(len(one_hot_music))
            
        data_dict['encodings'] = encodings
        data_dict['sequence_lengths'] = torch.tensor(sequence_lengths, dtype=torch.long)
        
        return data_dict
        
    def __getitem__(self, index):
        return self.encodings[index], self.sequence_lengths[index]
        #encodings dim: (1, seq_len, 88)
    def __len__(self):
        return len(self.encodings)

In [20]:
# processed_dataset = {}
# for split, data_split in data.items():
#     processed_dataset[split] = {}
#     n_seqs = len(data_split)
#     processed_dataset[split]['sequence_lengths'] = torch.zeros(n_seqs, dtype=torch.long)
#     processed_dataset[split]['sequences'] = []
#     for seq in range(n_seqs):
#         seq_length = len(data_split[seq])
#         processed_dataset[split]['sequence_lengths'][seq] = seq_length
#         processed_sequence = torch.zeros((seq_length, note_range))
#         for t in range(seq_length):
#             note_slice = torch.tensor(list(data_split[seq][t])) - min_note
#             slice_length = len(note_slice)
#             if slice_length > 0:
#                 processed_sequence[t, note_slice] = torch.ones(slice_length)
#         processed_dataset[split]['sequences'].append(processed_sequence)
# pickle.dump(processed_dataset, open(output, "wb"), pickle.HIGHEST_PROTOCOL)
# print("dumped processed data to %s" % output)

In [23]:
from torch.utils.data import DataLoader

# The wget cell above saves to the working directory; move the file into
# data/ (or edit this path) before running.
DATA_PATH = 'data/jsb_chorales.pickle'
dataset = MusicDataset(split='train', path=DATA_PATH)
# Pieces have different lengths, so the default collate_fn cannot stack
# sequences of unequal length into one batch; batch_size=1 sidesteps that.
dataloader = DataLoader(dataset, shuffle=False, batch_size=1)

In [43]:
# Peek at the first batch only: encodings is (batch, seq_len, 88) here
# because the DataLoader prepends the batch dimension.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    encodings, sequence_lengths = first_batch
    for shown in (encodings.shape, encodings, sequence_lengths):
        print(shown)


torch.Size([1, 129, 88])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
tensor([129])
