In [1]:
import miditoolkit
import remi_utils as utils
from collections import Counter
import pickle
import glob
import json
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np

In [2]:
# create pickle file based on dataset

def extract_events(input_path, chord=False):
    note_items, tempo_items = utils.read_items(input_path)
    note_items = utils.quantize_items(note_items)
    max_time = note_items[-1].end
    if chord:
        chord_items = utils.extract_chords(note_items)
        items = chord_items + tempo_items + note_items
    else:
        items = tempo_items + note_items
    groups = utils.group_items(items, max_time)
    events = utils.item2event(groups)
    return events

In [3]:
all_elements= []
for midi_file in glob.glob("./data_augmented/melody/*/*.mid*", recursive=True):
    #print(midi_file)
    events = extract_events(midi_file) # If you're analyzing chords, use `extract_events(midi_file, chord=True)`
    for event in events:
        element = '{}_{}'.format(event.name, event.value)
        all_elements.append(element)

for midi_file in glob.glob("./data_augmented/piano/*/*.mid*", recursive=True):
    #print(midi_file)
    try:
        events = extract_events(midi_file) # If you're analyzing chords, use `extract_events(midi_file, chord=True)`
    except:
        print(midi_file)
    for event in events:
        element = '{}_{}'.format(event.name, event.value)
        all_elements.append(element)        

counts = Counter(all_elements)
event2word = {c: i for i, c in enumerate(counts.keys())}
word2event = {i: c for i, c in enumerate(counts.keys())}
pickle.dump((event2word, word2event), open('dictionary_augmented.pkl', 'wb'))

KeyboardInterrupt: 

In [36]:
event2word, word2event = pickle.load(open('dictionary_augmented.pkl', 'rb'))

In [4]:
count = 0
intros = []
outros = []
solos = []
with open('solos.json') as json_file: 
    data = json.load(json_file) 
semitones = [-3,-2,-1,0,1,2,3]
for i in range(1,910):
    for k in semitones:
        filename = str(i).zfill(3)
        if filename not in data:
            continue
        for j in range(len(data[filename])):
            count += 1
            # extract intro
            intro = extract_events("./data_augmented/melody/intro/" + str(i).zfill(3) + "_solo_" + str(k) + "_" + str(j) + ".mid")
            w_intro = utils.event_to_word(intro, event2word)
            intros.append(w_intro)
            # extract outro
            outro = extract_events("./data_augmented/melody/outro/" + str(i).zfill(3) + "_solo_" + str(k) + "_" + str(j) + ".mid")
            w_outro = utils.event_to_word(outro, event2word)
            outros.append(w_outro)
            # extract solo
            solo = extract_events("./data_augmented/melody/middle/" + str(i).zfill(3) + "_solo_" + str(k) + "_" + str(j) + ".mid")
            w_solo = utils.event_to_word(solo, event2word)
            solos.append(w_solo)
        

KeyboardInterrupt: 

In [8]:
len(intros)

5250

In [18]:
count = 0
intros_piano = []
outros_piano = []
solos_piano = []
with open('solos.json') as json_file: 
    data = json.load(json_file) 
for i in range(1,910):
    for k in semitones:
        filename = str(i).zfill(3)
        if filename not in data:
            continue
        for j in range(len(data[filename])):
            count += 1
            # extract intro
            intro = extract_events("./data_augmented/piano/intro/" + str(i).zfill(3) + "_solo_" + str(k) + "_" + str(j) + ".mid")
            w_intro = utils.event_to_word(intro, event2word)
            intros_piano.append(w_intro)
            # extract outro
            outro = extract_events("./data_augmented/piano/outro/" + str(i).zfill(3) + "_solo_" + str(k) + "_" + str(j) + ".mid")
            w_outro = utils.event_to_word(outro, event2word)
            outros_piano.append(w_outro)
            # extract solo
            solo = extract_events("./data_augmented/piano/middle/" + str(i).zfill(3) + "_solo_" + str(k) + "_" + str(j) + ".mid")
            w_solo = utils.event_to_word(solo, event2word)
            solos_piano.append(w_solo)


In [20]:
data = [intros, intros_piano, outros, outros_piano, solos, solos_piano]
pickle.dump(data, open('./solo_generation_dataset_augmented/solo_generation_dataset.pkl', 'wb'))

In [4]:
data = pickle.load(open('./solo_generation_dataset_augmented/solo_generation_dataset.pkl', 'rb'))

In [6]:
def find_max_length(series):
    max_length=0
    for i in range(len(series)):
        if max_length < len(series[i]):
            max_length = len(series[i])
    return max_length

def pad_dataset(dataset, word2event):
    pad_value = len(word2event)
    max_length = 0
    for i in range(len(dataset)):
        if max_length < find_max_length(dataset[i]):
            max_length = find_max_length(dataset[i])
    print(max_length)
    for i in range(len(dataset)):
        for j in range(len(dataset[i])):
            while len(dataset[i][j]) < max_length:
                dataset[i][j].append(pad_value)
    return dataset

In [268]:
data_padded = pad_dataset(data,word2event)

1637


In [12]:
pickle.dump(data_padded, open('./solo_generation_dataset_augmented/solo_generation_dataset_padded.pkl', 'wb'))
data_padded = pickle.load(open('./solo_generation_dataset_augmented/solo_generation_dataset_padded.pkl', 'rb'))
#data = data_padded

NameError: name 'data_padded' is not defined

In [5]:
data_text = data
for i in range(len(data)):
    for j in range(len(data[i])):
        string_array = [str(num) for num in data[i][j]]
        data_text[i][j] = ' '.join(string_array)

In [7]:
intros_t, intros_piano_t, outros_t, outros_piano_t, solos_t, solos_piano_t = data_text

In [41]:
intros_piano_t[0] + solos_piano_t[0]

'0 1 43 105 67 25 104 15 67 14 55 52 67 14 46 52 23 11 104 15 72 38 224 137 72 32 131 19 72 11 128 52 91 45 179 52 70 54 131 22 70 14 128 52 70 11 104 52 78 53 179 19 17 64 131 15 90 64 131 52 90 14 128 52 90 11 104 52 90 11 75 52 74 51 179 52 0 67 25 131 31 67 14 128 19 23 32 179 7 8 64 131 15 72 14 153 66 72 64 93 31 72 32 100 52 91 35 131 19 13 14 93 7 70 11 75 7 78 51 104 66 78 24 55 66 90 14 179 66 74 38 127 52 0 67 14 104 15 67 14 55 7 67 14 46 52 23 38 104 22 72 11 185 42 72 14 128 31 72 14 104 31 91 54 129 19 13 11 128 22 70 14 128 15 70 24 104 22 70 14 55 15 78 14 128 34 78 14 100 34 90 11 132 28 74 32 131 34 0 67 25 128 15 67 5 100 31 67 14 75 52 23 11 128 22 72 14 153 66 72 38 93 31 72 11 100 31 91 35 131 52 13 38 93 15 70 35 93 15 70 32 100 22 70 14 75 15 78 32 104 66 78 14 55 66 90 14 179 42 74 45 127 52 0 67 25 104 15 67 14 55 52 67 14 46 52 23 11 104 15 72 32 131 19 72 11 128 52 91 45 179 52 70 54 131 220 1 43 105 67 53 179 19 4 64 131 15 23 64 131 52 23 14 128 52 23 11 

In [42]:
piano = []
for x in range(len(intros_piano_t)):
    piano.append(intros_piano_t[x] + ' ' + solos_piano_t[x] + ' ' + outros_piano_t[x])
main_melody = []
for x in range(len(intros_t)):
    main_melody.append(intros_t[x] + ' ' + solos_t[x] + ' ' + outros_t[x])

In [43]:
piano_train, piano_subset, main_train, main_subset = piano[:5250-1568], piano[-1568:], main_melody[:5250-1568], main_melody[-1568:]

In [44]:
half = len(piano_subset)//2
piano_test, piano_valid, main_test, main_valid = piano_subset[half:], piano_subset[:half], main_subset[half:], main_subset[:half]

In [45]:
main_test = main_test[3::7]
main_valid = main_valid[3::7]
piano_test = piano_test[3::7]
piano_valid = piano_valid[3::7]

In [46]:
len(piano_test)

112

In [47]:
train = []
val = []
test = []

for i in range(len(main_train)):
    train.append([main_train[i], piano_train[i]])
    
for i in range(len(main_valid)):
    val.append([main_valid[i],piano_valid[i]])
    
for i in range(len(main_test)):
    test.append([main_test[i],piano_test[i]])
    

In [48]:
df_train = pd.DataFrame(train,columns=['main', 'piano'])
df_val = pd.DataFrame(test,columns=['main', 'piano'])
df_test = pd.DataFrame(val,columns=['main', 'piano'])

In [49]:
destination_folder="solo_generation_dataset_augmented_mag"
df_train.to_csv(destination_folder + '/train_torchtext.csv', index=False)
df_val.to_csv(destination_folder + '/val_torchtext.csv', index=False)
df_test.to_csv(destination_folder + '/test_torchtext.csv', index=False)

In [50]:
def remove_padding(series, word2event):
    return [value for value in series if value != len(word2event)]

In [51]:
lst_int = [int(x) for x in df_train.values[0][0].split(' ')]
utils.write_midi(remove_padding(lst_int, word2event), word2event, 'test.midi')

In [52]:
for x in lst_int:
    print(word2event[x])

Bar_None
Position_1/16
Tempo Class_mid
Tempo Value_0
Position_2/16
Note Velocity_28
Note On_65
Note Duration_1
Position_4/16
Note Velocity_29
Note On_67
Note Duration_14
Bar_None
Position_2/16
Note Velocity_28
Note On_58
Note Duration_0
Position_3/16
Note Velocity_27
Note On_60
Note Duration_0
Position_4/16
Note Velocity_29
Note On_63
Note Duration_0
Position_5/16
Note Velocity_28
Note On_65
Note Duration_0
Position_6/16
Note Velocity_28
Note On_67
Note Duration_0
Position_8/16
Note Velocity_28
Note On_63
Note Duration_1
Position_10/16
Note Velocity_29
Note On_60
Note Duration_2
Position_12/16
Note Velocity_28
Note On_65
Note Duration_11
Bar_None
Position_6/16
Note Velocity_28
Note On_65
Note Duration_2
Position_8/16
Note Velocity_28
Note On_62
Note Duration_1
Position_10/16
Note Velocity_28
Note On_58
Note Duration_1
Position_12/16
Note Velocity_29
Note On_63
Note Duration_8
Bar_None
Position_2/16
Note Velocity_28
Note On_58
Note Duration_0
Position_3/16
Note Velocity_28
Note On_60
No

In [38]:
lst_int

[0,
 1,
 43,
 105,
 67,
 5,
 46,
 15,
 23,
 25,
 48,
 77,
 0,
 67,
 5,
 104,
 22,
 4,
 14,
 100,
 22,
 23,
 25,
 75,
 22,
 8,
 5,
 46,
 22,
 72,
 5,
 48,
 22,
 91,
 5,
 75,
 15,
 70,
 25,
 100,
 7,
 78,
 5,
 46,
 28,
 0,
 72,
 5,
 46,
 7,
 91,
 5,
 55,
 15,
 70,
 5,
 104,
 15,
 78,
 25,
 75,
 66,
 0,
 67,
 5,
 104,
 22,
 4,
 5,
 100,
 22,
 23,
 5,
 75,
 22,
 8,
 5,
 46,
 22,
 72,
 14,
 48,
 22,
 91,
 5,
 75,
 7,
 70,
 5,
 100,
 15,
 78,
 5,
 46,
 19,
 74,
 5,
 104,
 15,
 0,
 67,
 5,
 46,
 15,
 23,
 25,
 75,
 770,
 1,
 43,
 105,
 91,
 25,
 62,
 15,
 13,
 5,
 58,
 22,
 70,
 25,
 73,
 22,
 16,
 5,
 61,
 22,
 78,
 5,
 86,
 7,
 90,
 64,
 61,
 22,
 27,
 5,
 86,
 22,
 74,
 5,
 73,
 22,
 0,
 67,
 25,
 61,
 34,
 78,
 18,
 61,
 7,
 90,
 25,
 73,
 22,
 27,
 18,
 61,
 22,
 74,
 24,
 57,
 15,
 0,
 67,
 25,
 73,
 34,
 70,
 25,
 104,
 22,
 16,
 5,
 100,
 22,
 78,
 25,
 75,
 7,
 90,
 25,
 100,
 22,
 27,
 25,
 75,
 22,
 74,
 18,
 46,
 7,
 0,
 67,
 5,
 75,
 22,
 4,
 25,
 46,
 22,
 4,
 5,
 46,
 22,
 23,
