In [1]:
%matplotlib inline
import os
import matplotlib.pyplot as plt
from xml_to_mid import xml_to_mid
from loader_util import *

folder = '/home/azuma/workspace/git/lead-sheet-dataset/datasets/xml/a'
save_path = './dataset'


"""
p_xml_list = get_p_extension_list(folder, 'xml')
samples = p_xml_list[:]

xml_to_mid(samples, save_path)
"""

"\np_xml_list = get_p_extension_list(folder, 'xml')\nsamples = p_xml_list[:]\n\nxml_to_mid(samples, save_path)\n"

In [3]:
import numpy as np

def generate_time_note_dict(pianoroll_dict):
    time_note_dict = {} # key: file_num, value: time_note_dict
    
    for name_num in pianoroll_dict.keys():
        pianoroll = pianoroll_dict[name_num]
        pianoroll_T = pianoroll.T
        time_note_list = []
        
        # add top note idx
        for i in range(pianoroll_T.shape[0]):
            note = np.nonzero(pianoroll_T[i])[0] 
            if len(note) == 0:
                time_note_list.append('e')
            else:
                time_note_list.append(max(note))
                
        time_note_dict[name_num] = time_note_list
        
    return time_note_dict


In [4]:
def generate_input_and_target(time_note, seq_len=50):
    start, end = 0, len(time_note)
    input_list, target_list = [], []
    
    for idx in range(start, end):
        input_sample, target_sample = [], []
        start_iterate = 0
        
        if idx < seq_len:
            start_iterate = seq_len - idx - 1
            for i in range(start_iterate):
                input_sample.append('e')
                
        for i in range(start_iterate, seq_len):
            current_idx = idx - (seq_len - i - 1)
            input_sample.append(time_note[current_idx])
                
        if idx + 1 < end:
            target_sample.append(time_note[idx + 1])
        else:
            target_sample.append('e')
            
        input_list.append(input_sample)
        target_list.append(target_sample)
        
    return input_list, target_list


def generate_batch(p_midi_list, batch_song=16, start_idx=0, fs=30, seq_len=50):
    assert len(p_midi_list) >= batch_song
    
    batch_input, batch_target = [], []
    pianoroll_dict = generate_pianoroll_dict(p_midi_list, batch_song=batch_song, fs=fs)
    time_note_dict = generate_time_note_dict(pianoroll_dict)
    
    for i in list(time_note_dict.keys()):
        input_list, target_list = generate_input_and_target(time_note_dict[i], seq_len)
        batch_input += input_list
        batch_target += target_list
        
    return batch_input, batch_target
    

In [5]:
batch_song = 10
fs = 10
seq_len = 50

In [6]:
def align_dicts(pianoroll_dict, notes_chord_dict):
    
    # get key that has .mid and .chord
    pianoroll_keys = list(pianoroll_dict.keys())
    notes_chord_keys = list(notes_chord_dict.keys())
    common_keys = list(set(pianoroll_keys) & set(notes_chord_keys))
    
    # rm abundant item
    rm_list = list(set(pianoroll_keys) - set(common_keys))
    for i in rm_list:
        del pianoroll_dict[i]
    rm_list = list(set(notes_chord_keys) - set(common_keys))
    for i in rm_list:
        del notes_chord_dict[i]
        
    # align length of value 
    for key in common_keys:                                                    
        pianoroll_dict_len = pianoroll_dict[key].shape[1]
        notes_chord_dict_len = len(notes_chord_dict[key])
        if pianoroll_dict_len >= notes_chord_dict_len:
            pianoroll_dict[key] = pianoroll_dict[key][:notes_chord_dict_len]
        else:
            notes_chord_dict[key] = notes_chord_dict[key][:pianoroll_dict_len]
            
    return pianoroll_dict, notes_chord_dict


In [7]:
pianoroll_dict = generate_pianoroll_dict(p_midi_list, batch_song=batch_song, start_idx=0, fs=fs)
notes_chord_dict = generate_notes_chord_dict(p_midi_list, batch_song=batch_song, start_idx=0, fs=fs)
pianoroll_dict, notes_chord_dict = align_dicts(pianoroll_dict, notes_chord_dict)

In [8]:
for key in list(pianoroll_dict.keys()):
    print(pianoroll_dict[key].shape)
    print(len(notes_chord_dict[key]))

(128, 234)
234
(114, 114)
114
(128, 303)
303
(128, 149)
149
(128, 182)
182
(128, 149)
149
(128, 168)
168
(128, 405)
405
(127, 158)
127
(128, 299)
281


In [10]:
for i in range(5):
    start_idx = i * batch_song
    pianoroll_dict = generate_pianoroll_dict(p_midi_list, batch_song=batch_song, start_idx=start_idx, fs=fs)
    notes_chord_dict = generate_notes_chord_dict(p_midi_list, batch_song=batch_song, start_idx=start_idx, fs=fs)
    # pianoroll_dict, notes_chord_dict = align_dicts(pianoroll_dict, notes_chord_dict)

    for key in list(pianoroll_dict.keys()):
        print(pianoroll_dict[key].shape)
        print(len(notes_chord_dict[key]))

(128, 234)
234
(128, 114)
114
(128, 303)
303
(128, 149)
150
(128, 182)
182
(128, 149)
150
(128, 168)
206
(128, 405)
405
(128, 158)
127
(128, 299)
281
(128, 144)
144
(128, 152)
152
(128, 149)
140
(128, 202)
202
(128, 127)
127
(128, 138)
138
(128, 160)
160
(128, 147)
147
(128, 218)
232
(128, 145)
145
(128, 416)
416
(128, 240)
240
(128, 307)
307
(128, 278)
278
(128, 177)
177
(128, 168)
159
(128, 318)
320
(128, 243)
243
(128, 252)
252
(128, 175)
200
(128, 290)
300
(128, 93)
93
(128, 149)
150
(128, 147)
147
(128, 93)
93
(128, 110)
110
(128, 399)
412
(128, 335)
352
(128, 302)
302
(128, 581)
600
(128, 142)
150
(128, 148)
148
(128, 149)
150
(128, 183)
183
(128, 342)
342
(128, 132)
133
(128, 237)
249
(128, 218)
290
(128, 99)
101
(128, 149)
150
