In [18]:
# To download the datasets manually, uncomment the URL definitions and `!wget` commands below.
# jsb = "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle"
# piano = "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/piano_midi.pickle"
# muse = "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/muse_data.pickle"
# nottingham = "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/nottingham.pickle"

# !wget $jsb
# !wget $piano
# !wget $muse
# !wget $nottingham

In [164]:
import pickle
import numpy as np
from torch.utils.data import Dataset
import torch
from torch.nn.utils.rnn import pad_sequence
import requests
import os

In [150]:
class MusicDataset(Dataset):
    """Piano-roll dataset built from a pickled polyphonic-music file.

    The pickle at ``path`` is indexed by ``split``. That entry is a list of
    pieces; each piece is a sequence of time steps, and each time step is an
    iterable of integer note numbers such as ``(60, 64, 67)``. Every piece is
    turned into a 0/1 matrix of shape ``(len(piece), max_note)`` with a 1 at
    column ``note - min_note`` for each sounding note. All pieces are then
    right-padded with ``-1`` to the length of the longest piece.

    Args:
        path: Local filesystem path of the pickle file.
        max_note: Number of note columns (the piano-roll width), not the
            highest note number. Default 88.
        min_note: Note number mapped to column 0. Default 21. Presumably this
            is MIDI A0, so the defaults cover notes 21..108; confirm against
            the data.
        split: Key of the pickled dict to load, e.g. ``'train'`` or ``'test'``.

    Attributes:
        encodings: ``(num_pieces, max_len, max_note)`` integer tensor, padded
            with -1.
        masks: Boolean tensor of the same shape, ``False`` at padded positions.
        sequence_lengths: List with the unpadded length of each piece.
    """

    def __init__(self, path, max_note=88, min_note=21, split='train'):
        self.max_note = max_note
        self.min_note = min_note
        self.path = path
        self.split = split

        self.data = self.load_process_data()
        self.sequence_lengths = self.data['sequence_lengths']
        self.encodings = self.data['encodings']
        self.masks = self.data['masks']

    def read_pickle_from_url(self, path, split='train'):
        """Unpickle and return the object stored in the local file ``path``.

        Despite the name, nothing is fetched over the network. ``split`` is
        accepted for backward compatibility and ignored.

        Security: ``pickle.load`` can execute arbitrary code. Only load files
        from trusted sources.
        """
        with open(path, 'rb') as file:
            return pickle.load(file)

    def _piano_roll(self, music):
        """Return the ``(len(music), max_note)`` 0/1 matrix for one piece.

        Raises:
            IndexError: if a note falls outside ``[min_note, min_note + max_note)``.
        """
        roll = np.zeros((len(music), self.max_note), dtype=int)
        for row_index, keys in enumerate(music):
            for note in keys:
                column = note - self.min_note
                # Reject notes below min_note explicitly. A negative column would
                # otherwise wrap around in NumPy and silently set the wrong key.
                if not 0 <= column < self.max_note:
                    raise IndexError(
                        f"note {note} is outside the supported range "
                        f"[{self.min_note}, {self.min_note + self.max_note - 1}]"
                    )
                roll[row_index, column] = 1
        return roll

    def load_process_data(self):
        """Load ``self.split`` from the pickle and build padded piano rolls.

        Returns:
            dict with keys:
              * ``'encodings'``: ``(num_pieces, max_len, max_note)`` tensor of
                0/1, padded with -1.
              * ``'sequence_lengths'``: list of unpadded piece lengths.
              * ``'masks'``: boolean tensor, ``True`` where ``encodings != -1``.

        Raises:
            ValueError: if the split contains no pieces.
            IndexError: if a note is outside the representable range.
        """
        split_data = self.read_pickle_from_url(self.path)[self.split]
        if len(split_data) == 0:
            raise ValueError(f"split '{self.split}' in {self.path} contains no pieces")

        rolls = [self._piano_roll(music) for music in split_data]
        sequence_lengths = [len(music) for music in split_data]

        # Pad with -1 rather than 0 so padding stays distinguishable from "note off".
        encodings = pad_sequence(
            [torch.tensor(roll) for roll in rolls], batch_first=True, padding_value=-1
        )
        masks = encodings != -1

        return {
            'encodings': encodings,
            'sequence_lengths': sequence_lengths,
            'masks': masks,
        }

    def __getitem__(self, index):
        """Return ``(encoding, mask, sequence_length)`` for piece ``index``.

        ``encoding`` and ``mask`` have shape ``(max_len, max_note)``. The default
        DataLoader collation stacks them into ``(batch, max_len, max_note)``.
        """
        return self.encodings[index], self.masks[index], self.sequence_lengths[index]

    def __len__(self):
        """Number of pieces in the loaded split."""
        return len(self.encodings)

In [1]:
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

# NOTE(review): this import replaces the MusicDataset class defined earlier in
# this notebook with the one from dataloader.py. That version is constructed
# from the dataset config node rather than a file path; confirm which one is
# intended.
from dataloader import MusicDataset

config = OmegaConf.load('config.yaml')
dataset = MusicDataset(config.dataset)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=False)



('jsb_chorales.pickle', 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle')
jsb_chorales.pickle already exists. Skipping download.
('piano_midi.pickle', 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/piano_midi.pickle')
piano_midi.pickle already exists. Skipping download.
('muse_data.pickle', 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/muse_data.pickle')
muse_data.pickle already exists. Skipping download.
('nottingham.pickle', 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/nottingham.pickle')
nottingham.pickle already exists. Skipping download.


In [204]:
# Inspect the shapes of the first batch only.
for i, batch in enumerate(dataloader):
    encodings, masks, sequence_lengths = batch
    for tensor in (encodings, masks, sequence_lengths):
        print(tensor.shape)
    break


torch.Size([4, 129, 88])
torch.Size([4, 129, 88])
torch.Size([4])


In [195]:
def download_dataset(config):
    """Download every dataset listed in ``config.dataset.urls`` into ``data/``.

    Does nothing unless ``config.dataset.download_first`` is truthy. Each entry
    of ``config.dataset.urls`` is a single-item mapping ``{file_name: url}``.
    Files that already exist under ``data/`` are skipped without any network
    request.

    Args:
        config: Config object (e.g. loaded with OmegaConf) exposing
            ``dataset.download_first`` and ``dataset.urls``.
    """

    def download(file_name, url):
        """Fetch ``url`` into ``data/<file_name>`` unless that file already exists."""
        os.makedirs('data', exist_ok=True)
        target = os.path.join('data', file_name)

        # Check before requesting, so existing files cost no network traffic.
        if os.path.exists(target):
            print(f"{file_name} already exists. Skipping download.")
            return

        # Without a timeout, a stalled connection blocks the cell indefinitely.
        response = requests.get(url, timeout=60)
        if response.status_code == 200:
            with open(target, 'wb') as file:
                file.write(response.content)
            print(f"Downloaded {file_name}")
        else:
            print(f"Failed to download {file_name}. Status code: {response.status_code}")

    if config.dataset.download_first:
        for data_dict in config.dataset.urls:
            file_name, url = next(iter(data_dict.items()))
            # Argument order must match download(file_name, url).
            download(file_name, url)


In [196]:
# Load the project configuration. Import OmegaConf here so this cell does not
# depend on an import made in another cell.
from omegaconf import OmegaConf

path = 'config.yaml'

config = OmegaConf.load(path)

In [197]:
# Display the configured downloads: a list of single-entry {file_name: url} mappings.
config.dataset.urls

[{'jsb_chorales.pickle': 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle'}, {'piano_midi.pickle': 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/piano_midi.pickle'}, {'muse_data.pickle': 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/muse_data.pickle'}, {'nottingham.pickle': 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/nottingham.pickle'}]

In [198]:
# Display the first configured {file_name: url} download entry.
config.dataset.urls[0]

{'jsb_chorales.pickle': 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle'}

In [199]:
# download_dataset takes the whole config: it checks dataset.download_first
# itself and walks dataset.urls, so no loop or guard is needed here.
download_dataset(config)

Downloaded jsb_chorales.pickle


KeyboardInterrupt: 