In [1]:
import pypianoroll
import os
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import io
import symusic
from pathlib import Path

from chroma_subsystem.BinaryTokenizer import BinaryTokenizer, SimpleSerialChromaTokenizer
from miditok import REMI, TokenizerConfig
from transformers import RobertaTokenizer, RobertaTokenizerFast, RobertaModel

from torch.nn.utils.rnn import pad_sequence

import midi_pianoroll_utils as mpu

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MAX_LENGTH = 1024

In [3]:
class LiveMelCATDataset(Dataset):
    def __init__(self, midis_folder, segment_size=64, resolution=24):
        self.midis_folder = midis_folder
        self.midis_list = os.listdir(midis_folder)
        self.segment_size = segment_size
        self.resolution = resolution
        self.binary_chroma_tokenizer = SimpleSerialChromaTokenizer()
        self.remi_tokenizer = REMI(params=Path('/media/datadisk/data/pretrained_models/midis_REMI_BPE_tokenizer.json'))
        self.roberta_tokenizer_chroma = RobertaTokenizerFast.from_pretrained('/media/datadisk/data/pretrained_models/chroma_mlm_tiny/chroma_wordlevel_tokenizer')
        self.roberta_tokenizer_midi = RobertaTokenizerFast.from_pretrained('/media/datadisk/data/pretrained_models/midi_mlm_tiny/midi_wordlevel_tokenizer')
        self.roberta_tokenizer_text = RobertaTokenizer.from_pretrained('roberta-base')
        self.padding_values = {
            'melody': self.roberta_tokenizer_midi.pad_token_id,
            'chroma': self.roberta_tokenizer_chroma.pad_token_id,
            'text': self.roberta_tokenizer_text.pad_token_id,
            'accomp': self.roberta_tokenizer_midi.pad_token_id
        }
        self.max_seq_lengths = {
            'melody': 1024,
            'chroma': 1024,
            'text': 1024,
            'accomp': 4096
        }
    # end init
    def __len__(self):
        return len(self.midis_list)
    # end len
    def __getitem__(self, idx):
        print('idx:', idx)
        print(self.midis_list[idx])
        # load a midi file in pianoroll
        main_piece = pypianoroll.read(self.midis_folder + os.sep + self.midis_list[idx], resolution=self.resolution)
        main_piece_size = main_piece.downbeat.shape[0]
        # make deepcopy
        new_piece = deepcopy(main_piece)
        # trim piece
        start_idx = np.random.randint( main_piece_size - self.segment_size*main_piece.resolution )
        end_idx = start_idx + self.segment_size*main_piece.resolution
        new_piece.trim(start_idx, end_idx)
        # split melody - accompaniment
        melody_piece, accomp_piece = mpu.split_melody_accompaniment_from_pianoroll( new_piece )
        # keep chroma from accompaniment
        chroma_zoomed_out = mpu.chroma_from_pianoroll(accomp_piece)
        # tokenize chroma to text tokens
        tokenized_chroma = self.binary_chroma_tokenizer(chroma_zoomed_out)
        chroma_string = ' '.join( tokenized_chroma['tokens'] )
        chroma_tokens = self.roberta_tokenizer_chroma( chroma_string )
        # make ghost files of melody and accomp pieces
        melody_file = mpu.pianoroll_to_midi_bytes(melody_piece)
        accomp_file = mpu.pianoroll_to_midi_bytes(accomp_piece)
        # tokenize melody and accompaniment midi to text
        remi_tokenized_melody = self.remi_tokenizer(melody_file)
        melody_string = ' '.join(remi_tokenized_melody[0].tokens)
        melody_tokens = self.roberta_tokenizer_midi(melody_string)
        remi_tokenized_accomp = self.remi_tokenizer(accomp_file)
        accomp_string = ' '.join(remi_tokenized_accomp[0].tokens)
        accomp_tokens = self.roberta_tokenizer_midi(accomp_string)
        # get text from title
        text_description = self.midis_list[idx]
        # tokenize text
        text_tokens = self.roberta_tokenizer_text(text_description)
        # return torch.LongTensor(melody_tokens['input_ids']),
        return {
            'melody': torch.LongTensor(melody_tokens['input_ids']),
            'chroma': torch.LongTensor(chroma_tokens['input_ids']),
            'text': torch.LongTensor(text_tokens['input_ids']),
            'accomp': torch.LongTensor(accomp_tokens['input_ids'])
        }
    # end getitem

    # def chroma_from_pianoroll(self, main_piece, resolution=24):
    #     # first binarize a new deep copy
    #     binary_piece = deepcopy(main_piece)
    #     binary_piece.binarize()
    #     # make chroma
    #     chroma = binary_piece.tracks[0].pianoroll[:,:12]
    #     for i in range(12, 128-12, 12):
    #         chroma = np.logical_or(chroma, binary_piece.tracks[0].pianoroll[:,i:(i+12)])
    #     chroma[:,-6:] = np.logical_or(chroma[:,-6:], binary_piece.tracks[0].pianoroll[:,-6:])
    #     # quarter chroma resolution
    #     chroma_tmp = np.zeros( (1,12) )
    #     chroma_zoomed_out = None
    #     for i in range(chroma.shape[0]):
    #         chroma_tmp += chroma[i,:]
    #         if (i+1)%resolution == 0:
    #             if chroma_zoomed_out is None:
    #                 chroma_zoomed_out = chroma_tmp >= np.mean( chroma_tmp )
    #             else:
    #                 chroma_zoomed_out = np.vstack( (chroma_zoomed_out, chroma_tmp >= np.mean( chroma_tmp )) )
    #     if np.sum( chroma_tmp ) > 0:
    #         if chroma_zoomed_out is None:
    #             chroma_zoomed_out = chroma_tmp >= np.mean( chroma_tmp )
    #         else:
    #             chroma_zoomed_out = np.vstack( (chroma_zoomed_out, chroma_tmp >= np.mean( chroma_tmp )) )
    #     return chroma_zoomed_out
    # # end chroma_from_pianoroll

    # def split_melody_accompaniment(self, pypianoroll_structure):
    #     melody_piece = deepcopy( pypianoroll_structure )
    #     accomp_piece = deepcopy( pypianoroll_structure )

    #     mel_pr = melody_piece.tracks[0].pianoroll
    #     acc_pr = accomp_piece.tracks[0].pianoroll

    #     pr = np.array(melody_piece.tracks[0].pianoroll)
    #     running_melody = -1
    #     i = 0
    #     # for i in range( pr.shape[0] ):
    #     while i < pr.shape[0]:
    #         # check if any note
    #         if np.sum(pr[i,:]) > 0:
    #             # get running max
    #             running_max = np.max( np.nonzero( pr[i,:] ) )
    #             # check if there exists a running melody
    #             if running_melody > -1:
    #                 # check if running melody is continued
    #                 if running_melody == running_max:
    #                     # remove all lower pitches from melody
    #                     mel_pr[i, :running_max] = 0
    #                     # remove higher pitch from accomp
    #                     acc_pr[i, running_max] = 0
    #                 else:
    #                     # running melody may need to change
    #                     # check if new highest pitch just started
    #                     if running_max > running_melody:
    #                         # a new higher note has started
    #                         # finish previous note that was highest until now
    #                         j = 0
    #                         while j+i < mel_pr.shape[0] and mel_pr[i+j, running_melody] > 0 and running_max > running_melody:
    #                             mel_pr[i+j, :running_melody] = 0
    #                             mel_pr[i+j, running_melody+1:running_max] = 0
    #                             acc_pr[i+j, running_melody] = 0
    #                             acc_pr[i+j, running_max] = 0
    #                             if np.sum( pr[i+j,:] ) > 0:
    #                                 running_max = np.max( np.nonzero( pr[i+j,:] ) )
    #                             else:
    #                                 running_melody = -1
    #                                 break
    #                             j += 1
    #                         # start new running melody
    #                         i += j-1
    #                         running_melody = running_max
    #                     else:
    #                         # i should be > 0 since we have that running_melody > -1
    #                         # a lower note has come
    #                         # if has begun earlier, it should be ignored
    #                         if pr[i-1, running_max] > 0:
    #                             # its continuing an existing note - not part of melody
    #                             mel_pr[i, :] = 0
    #                             # running max should not be canceled, it remains as ghost max
    #                             # until a new higher max or a fresh lower max starts
    #                         else:
    #                             # a new fresh lower max starts that shouldn't be ignored
    #                             # start new running melody
    #                             running_melody = running_max
    #                             # remove all lower pitches from melody
    #                             mel_pr[i, :running_max] = 0
    #                             # remove higher pitch from accomp
    #                             acc_pr[i, running_max] = 0
    #             else:
    #                 # no running melody, check max conditions
    #                 # new note started - make it the running melody
    #                 running_melody = running_max
    #                 # remove all lower pitches from melody
    #                 mel_pr[i, :running_max] = 0
    #                 # remove higher pitch from accomp
    #                 acc_pr[i, running_max] = 0
    #             # end if
    #         else:
    #             # there is a gap
    #             running_melody = -1
    #         # end if
    #         i += 1
    #     # end for
    #     return melody_piece, accomp_piece
    # # end split_melody_accompaniment

    # def make_midi_bytes(self, pianoroll_structure):
    #     # initialize bytes handle
    #     b_handle = io.BytesIO()
    #     # write midi data to bytes handle
    #     pianoroll_structure.write(b_handle)
    #     # start read pointer from the beginning
    #     b_handle.seek(0)
    #     # create a buffered reader to read the handle
    #     buffered_reader = io.BufferedReader(b_handle)
    #     # create a midi object from the "file", i.e., buffered reader
    #     midi_bytes = symusic.Score.from_midi(b_handle.getvalue())
    #     # close the bytes handle
    #     b_handle.close()
    #     return midi_bytes
    # # end 

In [4]:
midifolder = '/media/datadisk/datasets/GiantMIDI-PIano/midis_v1.2/midis'
# midifolder = '/media/datadisk/data/Giant_PIano/'
dataset = LiveMelCATDataset(midifolder, segment_size=40)



In [5]:
print(len(dataset))

10761


In [6]:
d0 = dataset[0]

idx: 0
Ismagilov, Timur, Spring Sketches, 2QxuHQoT5Dk.mid


In [7]:
d0

{'melody': tensor([ 26,  70, 157,   5,  12,  52,  17,  20,   5,  12,   5,  10,  69,  59,
          15,  11,   5,   8,   5,   9,  26,  90,  59,  24,   7,   5,  27,   5,
           6,  66,  96,  21,   7,   5,  14,   5,   6,  82, 105,  21,  29,   5,
           8,   5,  10,  26,  83,  52,  17,   7,   5,   8,   5,   6,  73,  52,
         103,   7,   5,   9,   5,   6,  62,  38,  13,   7,   5,  14,   5,   6,
          89,  44,  16,   7,   5,  14,   5,   6,  76,  52,  16,   7,   5,  22,
           5,   6,  26,  73,  59,  19,  20,   5,   8,   5,  10,  26,  61,  44,
          13,   7,   5,   9,   5,   6,  66,  52,  19,  11,   5,  12,   5,   9,
          82,  96,  19,  11,   5,   8,   5,   9,  26,  79,  59,  13,  11,   5,
          10,   5,   9,  26,  92,  59,  21,   7,   5,  25,   5,   6,  61,  96,
          24,   7,   5,   9,   5,   6,  66, 105,  24,  20,   5,  12,   5,  10,
          76, 105,  21,  95,   5,  12,   5,   8,  26,  26,  26,  67,  59,  13,
          11,   5,   8,   5,   9,  82,  59

In [8]:
class MelCATCollator:
    def __init__(self, max_seq_lens=None, padding_values=None):
        self.max_seq_lens = max_seq_lens
        self.padding_values = padding_values
    # end init

    def __call__(self, batch):
        # Create a dictionary to hold the batched data
        batch_dict = {}
        # Assume the batch is a list of dictionaries (each sample is a dict)
        for key in batch[0]:
            batch_dict[key] = {}
            values = [item[key] for item in batch]
            if isinstance(values[0], list):
                # If values are lists (sequences of variable lengths), pad them
                padded_values = pad_sequence([torch.tensor(v) for v in values], 
                                            batch_first=True, padding_value=self.padding_values[key])
            elif isinstance(values[0], torch.Tensor):
                # If values are tensors, stack them directly
                padded_values = pad_sequence(values, batch_first=True, padding_value=self.padding_values[key])
            # trim to max length
            if self.max_seq_lens is not None and padded_values.shape[1] > self.max_seq_lens[key]:
                padded_values = padded_values[:,:self.max_seq_lens[key]]
            batch_dict[key]['input_ids'] = padded_values
            attention_mask = torch.ones_like(padded_values, dtype=torch.long)
            attention_mask[padded_values == self.padding_values[key]] = 0
            batch_dict[key]['attention_mask'] = attention_mask
        return batch_dict
    # end call
# end class


In [9]:
# # OLD
# def custom_collate_fn(batch):
#     # Create a dictionary to hold the batched data
#     batch_dict = {}

#     # Assume the batch is a list of dictionaries (each sample is a dict)
#     for key in batch[0]:
#         batch_dict[key] = {}
#         values = [item[key] for item in batch]
#         if isinstance(values[0], list):
#             # If values are lists (sequences of variable lengths), pad them
#             padded_values = pad_sequence([torch.tensor(v) for v in values], 
#                                          batch_first=True, padding_value=0)
#             batch_dict[key]['input_ids'] = padded_values
#             attention_mask = torch.ones_like(padded_values.shape, dtype=torch.long)
#             attention_mask[padded_values == 0] = 0
#             batch_dict[key]['attention_mask'] = attention_mask
        
#         elif isinstance(values[0], torch.Tensor):
#             # If values are tensors, stack them directly
#             batch_dict[key]['input_ids'] = pad_sequence(values, batch_first=True, padding_value=0)
#             attention_mask = torch.ones_like(batch_dict[key]['input_ids'], dtype=torch.long)
#             attention_mask[batch_dict[key]['input_ids'] == 0] = 0
#             batch_dict[key]['attention_mask'] = attention_mask
    
#     return batch_dict

In [10]:
custom_collate_fn = MelCATCollator(max_seq_lens=dataset.max_seq_lengths, padding_values=dataset.padding_values)

In [11]:
dataloader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate_fn)

In [12]:
b = next(iter(dataloader))

idx: 0
Ismagilov, Timur, Spring Sketches, 2QxuHQoT5Dk.mid


idx: 1
Gurlitt, Cornelius, Frühlingsblumen, Op.215, WD6wHfUb-kU.mid
idx: 2
Singelée, Jean Baptiste, Fantaisie sur des motifs de 'La sonnambula', Op.39, AcaSiJG7mkU.mid
idx: 3
Simpson, Daniel Léo, Kleine Klavierstücke No.9 in F major, R4z8vPF1Hto.mid


In [13]:
print(b)
# print(len(b['melody']['input_ids']))
# print(len(b['accomp']['input_ids']))

{'melody': {'input_ids': tensor([[ 26,  70, 157,  ...,   3,   3,   3],
        [ 26,  70, 157,  ...,  12,   5,   9],
        [ 26,  70, 157,  ...,   3,   3,   3],
        [ 26,  70, 157,  ...,   3,   3,   3]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}, 'chroma': {'input_ids': tensor([[25, 17,  8,  ...,  7, 12, 17],
        [25, 17,  9,  ...,  3,  3,  3],
        [25, 17, 13,  ...,  3,  3,  3],
        [25, 17,  5,  ...,  3,  3,  3]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}, 'text': {'input_ids': tensor([[    0,  6209, 16266,   718,  1417,     6,  2668,   710,     6,  5519,
          4058,   594,  5559,     6,   132,  1864,  1178,   257, 37912,   139,
           565,   245,   495,   330,     4, 16079,     2,     1,     1,     1,
             1,     1,     1, 

In [14]:
print(b['melody']['input_ids'].shape)
print(b['melody']['attention_mask'].shape)
print(b['text']['input_ids'].shape)
print(b['text']['attention_mask'].shape)
print(b['chroma']['input_ids'].shape)
print(b['chroma']['attention_mask'].shape)
print(b['accomp']['input_ids'].shape)
print(b['accomp']['attention_mask'].shape)

torch.Size([4, 402])
torch.Size([4, 402])
torch.Size([4, 40])
torch.Size([4, 40])
torch.Size([4, 301])
torch.Size([4, 301])
torch.Size([4, 889])
torch.Size([4, 889])
