In [1]:
!wget https://storage.googleapis.com/magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0-midi.zip
!unzip maestro-v1.0.0-midi.zip

--2019-06-05 15:35:19--  https://storage.googleapis.com/magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0-midi.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.212.128, 2607:f8b0:4001:c07::80
Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.212.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 46579421 (44M) [application/zip]
Saving to: ‘maestro-v1.0.0-midi.zip’


2019-06-05 15:35:20 (61.4 MB/s) - ‘maestro-v1.0.0-midi.zip’ saved [46579421/46579421]

Archive:  maestro-v1.0.0-midi.zip
   creating: maestro-v1.0.0/
  inflating: maestro-v1.0.0/LICENSE  
  inflating: maestro-v1.0.0/maestro-v1.0.0.csv  
   creating: maestro-v1.0.0/2017/
  inflating: maestro-v1.0.0/2017/MIDI-Unprocessed_045_PIANO045_MID--AUDIO-split_07-06-17_Piano-e_2-01_wav--2.midi  
  inflating: maestro-v1.0.0/2017/MIDI-Unprocessed_059_PIANO059_MID--AUDIO-split_07-07-17_Piano-e_2-03_wav--3.midi  
  inflating: maestro-v1.0.0/2017/MIDI-Unprocessed_046_PI

In [0]:
import glob
import random
import pretty_midi
import IPython
import numpy as np
from tqdm import tnrange, tqdm_notebook, tqdm
from random import shuffle, seed
import numpy as np
import numpy as np
from numpy.random import choice
import pickle
import matplotlib.pyplot as plt

import unicodedata
import re
import numpy as np
import os
import io
import time

In [0]:
def get_list_midi(folder = 'maestro-v1.0.0/**/*.midi', seed_int = 666):
    """Get the list of all midi file in the folders

    Parameters
    ==========
    folder : str
    The midi folder.
    seed_int : int
    the random seed.

    Returns
    =======
    The midi files

    """
    list_all_midi = glob.glob(folder)
    seed(seed_int)
    shuffle(list_all_midi)
    return list_all_midi

list_all_midi = get_list_midi()

In [0]:
def generate_dict_time_notes(list_all_midi, batch_song = 16, start_index=0, fs=30, use_tqdm=True):
    """ Generate map (dictionary) of music ( in index ) to piano_roll (in np.array)

    Parameters
    ==========
    list_all_midi : list
        List of midi files
    batch_music : int
      A number of music in one batch
    start_index : int
      The start index to be batched in list_all_midi
    fs : int
      Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    use_tqdm : bool
      Whether to use tqdm or not in the function

    Returns
    =======
    dictionary of music to piano_roll (in np.array)

    """
    assert len(list_all_midi) >= batch_song
    
    dict_time_notes = {}
    process_tqdm_midi = tqdm_notebook(range(start_index, min(start_index + batch_song, len(list_all_midi)))) if use_tqdm else range(start_index,  min(start_index + batch_song, len(list_all_midi)))
    for i in process_tqdm_midi:
        midi_file_name = list_all_midi[i]
        if use_tqdm:
            process_tqdm_midi.set_description("Processing {}".format(midi_file_name))
        try: # Handle exception on malformat MIDI files
            midi_pretty_format = pretty_midi.PrettyMIDI(midi_file_name)
            piano_midi = midi_pretty_format.instruments[0] # Get the piano channels
            piano_roll = piano_midi.get_piano_roll(fs=fs)
            dict_time_notes[i] = piano_roll
        except Exception as e:
            print(e)
            print("broken file : {}".format(midi_file_name))
            pass
    return dict_time_notes

In [5]:
dict_time_notes = generate_dict_time_notes(list_all_midi, batch_song = 16, start_index=0, fs=30, use_tqdm=True)

HBox(children=(IntProgress(value=0, max=16), HTML(value='')))




In [6]:
dict_time_notes[0][:, :50].shape

(128, 50)

In [0]:
def generate_training_data(list_all_midi, batch_song = 16, train_size = 50, target_size = 1, 
                           start_index=0, fs=30, use_tqdm=True, ignore_velocity = True):
    """
    Generate traning batch of training data.
    Data is a sequence of notes in piano-roll representation
    
    Parameters
    ==========
    list_all_midi : list
      List of midi files
    batch_music : int
      A number of music in one batch
    train_size : int
      A number of notes used for prediction
    target_size : int
      A number of notes to be predicted
    start_index : int
      The start index to be batched in list_all_midi
    fs : int
      Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    seq_len : int
      The sequence length of the music to be input of neural network
    use_tqdm : bool
      Whether to use tqdm or not in the function
    ignore_velocity : bool
      Whether to use real valued or binary representation of note's velocity
    
    Returns
    =======
    Tuple of input and target neural network
    """
    
    dict_time_notes = generate_dict_time_notes(list_all_midi, batch_song, start_index, fs, use_tqdm)
    
    list_data, list_target = [], []
    for _, value in dict_time_notes.items():
        for i in range(0, value.shape[1], train_size + target_size):
            sample = value[:, i: i + train_size + target_size]
            if ignore_velocity:
                sample = np.where(sample > 0, 1, sample)
            if sample.shape[1] == train_size + target_size:
                list_data.append(sample[:, :train_size].T)
                list_target.append(sample[:, train_size:].T)
    
    return np.array(list_data), np.array(list_target)

In [8]:
list_data, list_target = generate_training_data(list_all_midi)

HBox(children=(IntProgress(value=0, max=16), HTML(value='')))




In [0]:
def generate_dict_time_notes(list_all_midi, batch_song = 16, start_index=0, fs=30, use_tqdm=True):
    """ Generate map (dictionary) of music ( in index ) to piano_roll (in np.array)

    Parameters
    ==========
    list_all_midi : list
        List of midi files
    batch_music : int
      A number of music in one batch
    start_index : int
      The start index to be batched in list_all_midi
    fs : int
      Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    use_tqdm : bool
      Whether to use tqdm or not in the function

    Returns
    =======
    dictionary of music to piano_roll (in np.array)

    """
    assert len(list_all_midi) >= batch_song
    
    dict_time_notes = {}
    process_tqdm_midi = tqdm_notebook(range(start_index, min(start_index + batch_song, len(list_all_midi)))) if use_tqdm else range(start_index,  min(start_index + batch_song, len(list_all_midi)))
    for i in process_tqdm_midi:
        midi_file_name = list_all_midi[i]
        if use_tqdm:
            process_tqdm_midi.set_description("Processing {}".format(midi_file_name))
        try: # Handle exception on malformat MIDI files
            midi_pretty_format = pretty_midi.PrettyMIDI(midi_file_name)
            piano_midi = midi_pretty_format.instruments[0] # Get the piano channels
            piano_roll = piano_midi.get_piano_roll(fs=fs)
            dict_time_notes[i] = piano_roll
        except Exception as e:
            print(e)
            print("broken file : {}".format(midi_file_name))
            pass
    return dict_time_notes

def generate_input_and_target(dict_keys_time, seq_len=50):
    """ Generate input and the target of our deep learning for one music.
    
    Parameters
    ==========
    dict_keys_time : dict
      Dictionary of timestep and notes
    seq_len : int
      The length of the sequence
      
    Returns
    =======
    Tuple of list of input and list of target of neural network.
    
       
    """
    # Get the start time and end time
    start_time, end_time = list(dict_keys_time.keys())[0], list(dict_keys_time.keys())[-1]
    list_training, list_target = [], []
    for index_enum, time in enumerate(range(start_time, end_time)):
        list_append_training, list_append_target = [], []
        start_iterate = 0
        flag_target_append = False # flag to append the test list
        if index_enum < seq_len:
            start_iterate = seq_len - index_enum - 1
            for i in range(start_iterate): # add 'e' to the seq list. 
                list_append_training.append('e')
                flag_target_append = True

        for i in range(start_iterate,seq_len):
            index_enum = time - (seq_len - i - 1)
            if index_enum in dict_keys_time:
                list_append_training.append(','.join(str(x) for x in dict_keys_time[index_enum]))      
            else:
                list_append_training.append('e')

        # add time + 1 to the list_append_target
        if time+1 in dict_keys_time:
            list_append_target.append(','.join(str(x) for x in dict_keys_time[time+1]))
        else:
            list_append_target.append('e')
        list_training.append(list_append_training)
        list_target.append(list_append_target)
    return list_training, list_target

def process_notes_in_song(dict_time_notes, seq_len = 50):
    """
    Iterate the dict of piano rolls into dictionary of timesteps and note played
    
    Parameters
    ==========
    dict_time_notes : dict
      dict contains index of music ( in index ) to piano_roll (in np.array)
    seq_len : int
      Length of the sequence
      
    Returns
    =======
    Dict of timesteps and note played
    """
    list_of_dict_keys_time = []
    
    for key in dict_time_notes:
        sample = dict_time_notes[key]
        times = np.unique(np.where(sample > 0)[1])
        index = np.where(sample > 0)
        dict_keys_time = {}

        for time in times:
            index_where = np.where(index[1] == time)
            notes = index[0][index_where]
            dict_keys_time[time] = notes
        list_of_dict_keys_time.append(dict_keys_time)
    return list_of_dict_keys_time

def generate_batch_song(list_all_midi, batch_music=16, start_index=0, fs=30, seq_len=50, use_tqdm=False):
    """
    Generate Batch music that will be used to be input and output of the neural network
    
    Parameters
    ==========
    list_all_midi : list
      List of midi files
    batch_music : int
      A number of music in one batch
    start_index : int
      The start index to be batched in list_all_midi
    fs : int
      Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    seq_len : int
      The sequence length of the music to be input of neural network
    use_tqdm : bool
      Whether to use tqdm or not in the function
    
    Returns
    =======
    Tuple of input and target neural network
    
    """
    
    assert len(list_all_midi) >= batch_music
    dict_time_notes = generate_dict_time_notes(list_all_midi, batch_music, start_index, fs, use_tqdm=use_tqdm)
    
    list_musics = process_notes_in_song(dict_time_notes, seq_len)
    collected_list_input, collected_list_target = [], []
     
    for music in list_musics:
        list_training, list_target = generate_input_and_target(music, seq_len)
        collected_list_input += list_training
        collected_list_target += list_target
    return collected_list_input, collected_list_target

In [0]:
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [0]:
class Model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=2, output_emb=128, bidirectional=False):
        """ 
        """
        super(self.__class__, self).__init__()
        
        self.lstm = nn.LSTM(output_emb, hidden_size, n_layers, dropout = 0.3, bidirectional = bidirectional)
        
        if bidirectional:
            self.logits_fc = nn.Linear(hidden_size * 2, output_size)
        else:
            self.logits_fc = nn.Linear(hidden_size, output_size)
        
        self.softmax = nn.Softmax(dim = 1)
        

    def forward(self, input_sequences, hidden=None):
        
        batch_size = input_sequences.shape[0]
        outputs, hidden = self.lstm(input_sequences, hidden)        
        logits = self.logits_fc(outputs[:, -1, :])
        out = self.softmax(logits)
        neg_logits = (1 - logits)
        
        return out, hidden

In [0]:
val_set = list_all_midi[0:2]

In [0]:
inputs_val, outputs_val = generate_training_data(val_set, len(val_set), train_size = 50, target_size = 1, start_index=0, use_tqdm=False)

In [0]:
inputs_tensor_val = torch.tensor(inputs_val, dtype=torch.float).cuda()
outputs_tensor_val = torch.tensor(outputs_val, dtype=torch.float).cuda()
# inputs_val = inputs_val.reshape(-1, 128).cuda()
# outputs_val = outputs_val.reshape(-1, 128).cuda()

In [0]:
from IPython import display

class TrainModel:
  
    def __init__(self, epochs, sampled_midi, frame_per_second, 
               batch_nnet_size, batch_song, optimizer, loss_fn, total_songs, model, seq_len):
        self.epochs = epochs
        self.sampled_midi = sampled_midi
        self.frame_per_second = frame_per_second
        self.batch_nnet_size = batch_nnet_size
        self.batch_song = batch_song
        self.optimizer = optimizer
        #     self.checkpoint = checkpoint
        self.loss_fn = loss_fn
        #     self.checkpoint_prefix = checkpoint_prefix
        self.total_songs = total_songs
        self.model = model
        self.seq_len = seq_len
        self.train_loss_history = []
        self.val_loss_history = []
    
    def train(self):
        loss_history = []
      
      
        for epoch in tqdm_notebook(range(self.epochs),desc='epochs'):
            # for each epochs, we shufle the list of all the datasets
            shuffle(self.sampled_midi)
            
            train_val_separator_index = int(0.9 * len(self.sampled_midi))
            train_data = self.sampled_midi[:train_val_separator_index]
#             val_data = self.sampled_midi[train_val_separator_index:]
            
            loss_total = 0
            train_loss_history = []
            val_loss_history = []
            
            steps = 0
            steps_nnet = 0

            # We will iterate all songs by self.song_size
            for i in tqdm_notebook(range(0, len(train_data), self.batch_song), desc='MUSIC'):

                steps += 1
                
                inputs_train, outputs_train = generate_training_data(train_data, self.batch_song, 
                                                                     train_size = 50, target_size = 1, 
                                                                     start_index=i, use_tqdm=False)
                
#                 inputs_val, outputs_val = generate_training_data(val_data, len(val_data),
#                                                                  train_size = 50, target_size = 1,
#                                                                  start_index=i, use_tqdm=False)
                
                
                index_shuffled = np.arange(start=0, stop=len(inputs_train))
                np.random.shuffle(index_shuffled)

                for nnet_steps in tqdm_notebook(range(0,len(index_shuffled), self.batch_nnet_size)):
        
                    steps_nnet += 1
                    current_index = index_shuffled[nnet_steps:nnet_steps+self.batch_nnet_size]
                    inputs, outputs = inputs_train[current_index], outputs_train[current_index]

                    inputs_tensor_train = torch.tensor(inputs, dtype=torch.float).cuda()
                    outputs_tensor_train = torch.tensor(outputs, dtype=torch.float32).cuda()

                    if len(inputs) // self.batch_nnet_size != 1:
                        break
                        
                    self.model.train()
#                     print(inputs_tensor_train.shape, outputs_tensor_train.shape)
#                     print(inputs_tensor_val.shape, outputs_tensor_val.shape)
                    prediction, _ = self.model.forward(inputs_tensor_train)
                    loss = self.loss_fn(prediction, outputs_tensor_train)
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    loss_total += loss
                    
                    n_verbose = 20
                    if steps_nnet % n_verbose == 0:
                        loss_mean = (loss_total / n_verbose).detach().cpu().numpy()
                        print("epochs {} | Steps {} | total loss : {}".format(epoch + 1, steps_nnet, loss_mean))
                        self.train_loss_history.append(loss_mean)
                        loss_total = 0
                        
                        val_prediction, _ = self.model.forward(inputs_tensor_val)#, self.seq_len)
#                         print(val_prediction.shape)
                        val_loss = self.loss_fn(val_prediction, outputs_tensor_val).detach().cpu().numpy()
                        print(f'val loss: {val_loss}')
                        self.val_loss_history.append(val_loss)
                        
                    self.model.eval()


In [0]:
learning_rate = 0.001
seq_len = 50
EPOCHS = 5
BATCH_SONG = 32
BATCH_NNET_SIZE = 512
TOTAL_SONGS = len(list_all_midi)
FRAME_PER_SECOND = 20

def compute_loss(prediction, targets):
    targets = targets.squeeze()
    n_notes = torch.sum(targets, keepdim=False, dim=1)
    mask = n_notes > 0
    targets = targets[mask, :] / n_notes[mask].view(-1, 1)
    loss_table = -torch.log(prediction[mask, :]) * targets
    loss_table = loss_table.sum(dim=1)
    loss = loss_table.mean()
    return loss

In [17]:
rnn_1 = Model(input_size=seq_len, hidden_size=16, output_size=128, bidirectional=False).cuda()
rnn_2 = Model(input_size=seq_len, hidden_size=32, output_size=128, bidirectional=False).cuda()
rnn_3 = Model(input_size=seq_len, hidden_size=16, output_size=128, bidirectional=True).cuda()
rnn_4 = Model(input_size=seq_len, hidden_size=16, n_layers=3, output_size=128, bidirectional=False).cuda()
rnn_5 = Model(input_size=seq_len, hidden_size=16, n_layers=1, output_size=128, bidirectional=False).cuda()
rnn_6 = Model(input_size=seq_len, hidden_size=8, output_size=128, bidirectional=False).cuda()

# for rnn, name in [(rnn_1, 'hid_16'), (rnn_2, 'hid_32'), (rnn_3, 'hid_16_bid'), (rnn_4, 'hid_16_3layers'), (rnn_5, 'hid_16_1layer'), (rnn_6, 'hid_8')][::-1]:
for rnn, name in [(rnn_6, 'hid_16_1layer')]:
    optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
    train_class = TrainModel(EPOCHS, list_all_midi[:1000], FRAME_PER_SECOND,
                  BATCH_NNET_SIZE, BATCH_SONG, optimizer, compute_loss, TOTAL_SONGS, rnn, seq_len)
    train_class.train()
    train_history = np.array(train_class.train_loss_history)
    val_history = np.array(train_class.val_loss_history)
    print(name, 'done')
    np.savetxt(name + '_train.txt', train_history)
    np.savetxt(name + '_val.txt', val_history)

  "num_layers={}".format(dropout, num_layers))


HBox(children=(IntProgress(value=0, description='epochs', max=5, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='MUSIC', max=29, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 1 | Steps 40 | total loss : 9.224726676940918
val loss: 4.785609245300293


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 1 | Steps 60 | total loss : 4.5010576248168945
val loss: 4.655632972717285


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 1 | Steps 80 | total loss : 4.361612319946289
val loss: 4.466649532318115


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 1 | Steps 100 | total loss : 4.2068095207214355
val loss: 4.314789295196533


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 1 | Steps 120 | total loss : 4.112691402435303
val loss: 4.218844890594482


HBox(children=(IntProgress(value=0, max=23), HTML(value='')))

epochs 1 | Steps 140 | total loss : 4.0431976318359375
val loss: 4.150054931640625
epochs 1 | Steps 160 | total loss : 4.1641459465026855
val loss: 4.082986831665039


HBox(children=(IntProgress(value=0, max=15), HTML(value='')))

HBox(children=(IntProgress(value=0, max=23), HTML(value='')))

epochs 1 | Steps 180 | total loss : 3.7279422283172607
val loss: 4.046439170837402


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 1 | Steps 200 | total loss : 3.9048898220062256
val loss: 4.034699440002441
epochs 1 | Steps 220 | total loss : 4.128667831420898
val loss: 4.036096572875977


HBox(children=(IntProgress(value=0, max=15), HTML(value='')))

HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 1 | Steps 240 | total loss : 3.67722487449646
val loss: 4.025332927703857


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

epochs 1 | Steps 260 | total loss : 3.8900468349456787
val loss: 4.02070951461792


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 1 | Steps 280 | total loss : 3.8395771980285645
val loss: 4.02182149887085


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 1 | Steps 300 | total loss : 3.8802781105041504
val loss: 4.016318321228027


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 1 | Steps 320 | total loss : 3.8533096313476562
val loss: 4.004552841186523


HBox(children=(IntProgress(value=0, max=23), HTML(value='')))

epochs 1 | Steps 340 | total loss : 3.8842365741729736
val loss: 4.004412651062012


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 1 | Steps 360 | total loss : 3.8828067779541016
val loss: 3.991891384124756


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 1 | Steps 380 | total loss : 3.8197484016418457
val loss: 3.993356227874756


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 1 | Steps 400 | total loss : 3.8539726734161377
val loss: 4.001939296722412


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 1 | Steps 420 | total loss : 3.882297992706299
val loss: 4.0105366706848145


HBox(children=(IntProgress(value=0, max=15), HTML(value='')))

epochs 1 | Steps 440 | total loss : 3.886761426925659
val loss: 4.016546249389648


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 1 | Steps 460 | total loss : 3.8448753356933594
val loss: 4.014798164367676


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 1 | Steps 480 | total loss : 3.850735902786255
val loss: 4.022144794464111


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

epochs 1 | Steps 500 | total loss : 3.83561372756958
val loss: 4.021051406860352


HBox(children=(IntProgress(value=0, max=25), HTML(value='')))

epochs 1 | Steps 520 | total loss : 3.8412156105041504
val loss: 4.019843578338623


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 1 | Steps 540 | total loss : 3.8225715160369873
val loss: 4.019814968109131


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, description='MUSIC', max=29, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 2 | Steps 60 | total loss : 11.51364803314209
val loss: 3.9919397830963135


HBox(children=(IntProgress(value=0, max=23), HTML(value='')))

epochs 2 | Steps 80 | total loss : 3.8404624462127686
val loss: 4.009425163269043


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 2 | Steps 120 | total loss : 7.699582099914551
val loss: 4.022160053253174


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

epochs 2 | Steps 140 | total loss : 3.812859296798706
val loss: 4.017430305480957


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 2 | Steps 160 | total loss : 3.8155505657196045
val loss: 4.006255149841309


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 2 | Steps 180 | total loss : 3.841076612472534
val loss: 3.9902942180633545


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

epochs 2 | Steps 200 | total loss : 3.831202268600464
val loss: 4.001410961151123


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 2 | Steps 220 | total loss : 3.79227614402771
val loss: 3.996333599090576


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 2 | Steps 240 | total loss : 3.8136138916015625
val loss: 3.984252452850342


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 2 | Steps 260 | total loss : 3.8521270751953125
val loss: 3.993905782699585


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 2 | Steps 280 | total loss : 3.8472328186035156
val loss: 3.99137020111084


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 2 | Steps 300 | total loss : 3.830503225326538
val loss: 3.968195915222168


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 2 | Steps 320 | total loss : 3.8291003704071045
val loss: 3.9679207801818848


HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

epochs 2 | Steps 340 | total loss : 3.8210575580596924
val loss: 3.9770352840423584


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 2 | Steps 360 | total loss : 3.8429787158966064
val loss: 3.976187229156494


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 2 | Steps 380 | total loss : 3.852722644805908
val loss: 3.974113941192627


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 2 | Steps 400 | total loss : 3.8472466468811035
val loss: 3.980405330657959


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 2 | Steps 420 | total loss : 3.872281789779663
val loss: 3.9789628982543945


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 2 | Steps 440 | total loss : 3.858657121658325
val loss: 3.9715278148651123


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

epochs 2 | Steps 460 | total loss : 3.812913656234741
val loss: 3.9876644611358643


HBox(children=(IntProgress(value=0, max=23), HTML(value='')))

epochs 2 | Steps 480 | total loss : 3.808441638946533
val loss: 3.997544527053833


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 2 | Steps 500 | total loss : 3.8293933868408203
val loss: 3.986876964569092


HBox(children=(IntProgress(value=0, max=13), HTML(value='')))

epochs 2 | Steps 520 | total loss : 3.8321750164031982
val loss: 3.987546682357788


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

epochs 2 | Steps 540 | total loss : 3.8570523262023926
val loss: 3.9889919757843018


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, description='MUSIC', max=29, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 3 | Steps 20 | total loss : 3.857165575027466
val loss: 3.981865882873535


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 3 | Steps 40 | total loss : 3.8507580757141113
val loss: 3.999938488006592


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 3 | Steps 60 | total loss : 3.8831875324249268
val loss: 3.9960432052612305


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 3 | Steps 80 | total loss : 3.8485982418060303
val loss: 3.986969470977783


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

epochs 3 | Steps 100 | total loss : 3.8279435634613037
val loss: 3.980992317199707


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 3 | Steps 120 | total loss : 3.8027477264404297
val loss: 3.9662632942199707


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 3 | Steps 140 | total loss : 3.8037235736846924
val loss: 3.9497451782226562


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 3 | Steps 160 | total loss : 3.7927491664886475
val loss: 3.9462621212005615


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 3 | Steps 180 | total loss : 3.801382541656494
val loss: 3.958564519882202


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 3 | Steps 200 | total loss : 3.7973358631134033
val loss: 3.9714107513427734


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 3 | Steps 220 | total loss : 3.7775676250457764
val loss: 3.9709584712982178


HBox(children=(IntProgress(value=0, max=14), HTML(value='')))

epochs 3 | Steps 240 | total loss : 3.7856597900390625
val loss: 3.963273763656616


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 3 | Steps 260 | total loss : 3.7742791175842285
val loss: 3.94948148727417


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 3 | Steps 280 | total loss : 3.7892472743988037
val loss: 3.958380937576294


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 3 | Steps 300 | total loss : 3.789691925048828
val loss: 3.96376895904541


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 3 | Steps 340 | total loss : 7.577991008758545
val loss: 3.95090389251709


HBox(children=(IntProgress(value=0, max=26), HTML(value='')))

epochs 3 | Steps 360 | total loss : 3.794963836669922
val loss: 3.93576717376709


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 3 | Steps 380 | total loss : 3.7773818969726562
val loss: 3.927234649658203


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 3 | Steps 400 | total loss : 3.7603211402893066
val loss: 3.919691324234009


HBox(children=(IntProgress(value=0, max=23), HTML(value='')))

epochs 3 | Steps 420 | total loss : 3.7580394744873047
val loss: 3.91222882270813


HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

epochs 3 | Steps 440 | total loss : 3.751359224319458
val loss: 3.9220197200775146


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

epochs 3 | Steps 460 | total loss : 3.7534024715423584
val loss: 3.926352024078369


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 3 | Steps 480 | total loss : 3.771019458770752
val loss: 3.919776439666748


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 3 | Steps 500 | total loss : 3.730220079421997
val loss: 3.9154934883117676


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 3 | Steps 520 | total loss : 3.7283616065979004
val loss: 3.8999524116516113


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 3 | Steps 540 | total loss : 3.758462905883789
val loss: 3.8856256008148193


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, description='MUSIC', max=29, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

epochs 4 | Steps 20 | total loss : 3.9393112659454346
val loss: 3.894634962081909


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 4 | Steps 40 | total loss : 3.7014527320861816
val loss: 3.8992815017700195


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 4 | Steps 60 | total loss : 3.7344651222229004
val loss: 3.8975026607513428


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 4 | Steps 80 | total loss : 3.7484779357910156
val loss: 3.8869502544403076


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 4 | Steps 100 | total loss : 3.7700066566467285
val loss: 3.8839194774627686


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 4 | Steps 120 | total loss : 3.731868028640747
val loss: 3.8611791133880615


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

epochs 4 | Steps 140 | total loss : 3.7820022106170654
val loss: 3.855797052383423


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 4 | Steps 180 | total loss : 7.381570339202881
val loss: 3.849388360977173


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 4 | Steps 220 | total loss : 7.126835823059082
val loss: 3.8550024032592773


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 4 | Steps 240 | total loss : 3.686603546142578
val loss: 3.8355841636657715


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 4 | Steps 280 | total loss : 7.594107151031494
val loss: 3.839834451675415


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 4 | Steps 300 | total loss : 3.487306594848633
val loss: 3.8270208835601807


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

epochs 4 | Steps 320 | total loss : 3.6702821254730225
val loss: 3.8138279914855957


HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

epochs 4 | Steps 340 | total loss : 3.6654605865478516
val loss: 3.7988266944885254


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 4 | Steps 360 | total loss : 3.6946938037872314
val loss: 3.805807113647461


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 4 | Steps 380 | total loss : 3.6685523986816406
val loss: 3.8066859245300293


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 4 | Steps 400 | total loss : 3.6667988300323486
val loss: 3.8137452602386475


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 4 | Steps 420 | total loss : 3.676811933517456
val loss: 3.8019673824310303


HBox(children=(IntProgress(value=0, max=26), HTML(value='')))

epochs 4 | Steps 440 | total loss : 3.6608073711395264
val loss: 3.7967798709869385


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

epochs 4 | Steps 460 | total loss : 3.673675537109375
val loss: 3.8047263622283936


HBox(children=(IntProgress(value=0, max=23), HTML(value='')))

epochs 4 | Steps 480 | total loss : 3.6436240673065186
val loss: 3.7941336631774902


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

epochs 4 | Steps 500 | total loss : 3.652334213256836
val loss: 3.8122036457061768


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 4 | Steps 520 | total loss : 3.634352207183838
val loss: 3.794846296310425


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

epochs 4 | Steps 540 | total loss : 3.6156067848205566
val loss: 3.7670934200286865


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, description='MUSIC', max=29, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 5 | Steps 20 | total loss : 3.814404249191284
val loss: 3.74638295173645


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

epochs 5 | Steps 40 | total loss : 3.4604854583740234
val loss: 3.7495810985565186


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 5 | Steps 60 | total loss : 3.6228718757629395
val loss: 3.7380285263061523


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 5 | Steps 80 | total loss : 3.593816041946411
val loss: 3.736933708190918


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 5 | Steps 100 | total loss : 3.623891592025757
val loss: 3.728949785232544


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 5 | Steps 120 | total loss : 3.5959320068359375
val loss: 3.713803768157959


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 5 | Steps 140 | total loss : 3.571964979171753
val loss: 3.7103664875030518


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 5 | Steps 160 | total loss : 3.6105751991271973
val loss: 3.7067208290100098


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 5 | Steps 180 | total loss : 3.5977590084075928
val loss: 3.727585792541504


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 5 | Steps 200 | total loss : 3.581956148147583
val loss: 3.7293992042541504


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 5 | Steps 220 | total loss : 3.6261584758758545
val loss: 3.7367475032806396


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

epochs 5 | Steps 240 | total loss : 3.619389057159424
val loss: 3.7293734550476074


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 5 | Steps 260 | total loss : 3.604356050491333
val loss: 3.7106680870056152


HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

epochs 5 | Steps 280 | total loss : 3.565584897994995
val loss: 3.702362060546875


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

epochs 5 | Steps 300 | total loss : 3.5675415992736816
val loss: 3.699354887008667


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 5 | Steps 320 | total loss : 3.5828957557678223
val loss: 3.7068727016448975


HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

epochs 5 | Steps 340 | total loss : 3.579643726348877
val loss: 3.6981687545776367


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 5 | Steps 360 | total loss : 3.5531728267669678
val loss: 3.7003321647644043


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 5 | Steps 380 | total loss : 3.575399160385132
val loss: 3.6934897899627686


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 5 | Steps 400 | total loss : 3.5695583820343018
val loss: 3.6785192489624023


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 5 | Steps 420 | total loss : 3.562192678451538
val loss: 3.6670734882354736


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

epochs 5 | Steps 440 | total loss : 3.5615546703338623
val loss: 3.6631083488464355


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 5 | Steps 460 | total loss : 3.554725408554077
val loss: 3.6764559745788574


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

epochs 5 | Steps 480 | total loss : 3.5572049617767334
val loss: 3.6923553943634033


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

epochs 5 | Steps 500 | total loss : 3.5555920600891113
val loss: 3.688473701477051


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

epochs 5 | Steps 520 | total loss : 3.5298027992248535
val loss: 3.6828501224517822


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

epochs 5 | Steps 540 | total loss : 3.524007558822632
val loss: 3.67398738861084


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

hid_16_1layer done


In [18]:
from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer()
mlb.fit([list(range(128))])

MultiLabelBinarizer(classes=None, sparse_output=False)

In [0]:
def generate_from_random(unique_notes, batch_size, seq_len=50):
    random_notes = [[x] for x in np.random.choice(128, 50)]
    random_notes = mlb.transform(random_notes)
    return torch.LongTensor(random_notes).view(1, random_notes.shape[0], random_notes.shape[1]).cuda()

def generate_from_constant(unique_notes, batch_size, seq_len=50, note=64):
    random_notes = [[seq_len] for _ in range(seq_len)]
    random_notes = mlb.transform(random_notes)
    return torch.LongTensor(random_notes).view(1, random_notes.shape[0], random_notes.shape[1]).cuda()
  
def generate_notes(previous_notes, model, batch_size, unique_notes, max_generated=1000, seq_len=50, threshold=0.015):
    for n_step in range(max_generated - seq_len):
        prev_notes_truncated = previous_notes[:, -seq_len:, :].float()
        next_notes_probs, _ = model(prev_notes_truncated.cuda())
#         next_notes_probs, _ = model(prev_notes_truncated)
#         new_notes = torch.arange(128)[next_notes_probs[0] > threshold]
        new_notes = torch.multinomial(next_notes_probs, 2)
#         print(new_notes)
        new_notes = list(new_notes.cpu().numpy()[0])
#         print(new_notes)
        
#         print(new_notes)
        new_note_vector = mlb.transform([new_notes])  # WRONG INDEXES!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
#         print(new_note_vector)
        new_note_vector = torch.LongTensor(new_note_vector).view(1, new_note_vector.shape[0], new_note_vector.shape[1]).cuda()
#         for note, note_prob in enumerate(next_notes_probs):
#             if np.random.uniform() < note_prob:
#                 next_notes.append(note)
        
#         NEW_NOTE = torch.multinomial(next_notes_probs, 1)
# #         print(NEW_NOTE)
# #         print(previous_notes)
#         previous_notes = torch.cat([previous_notes[:, :], NEW_NOTE], dim=1, out=previous_notes)
        previous_notes = torch.cat([previous_notes[:, :, :], new_note_vector], dim=1)
        
    
    return previous_notes

def write_midi_file_from_generated(generate, midi_file_name = "result.mid", start_index=49, fs=20, max_generated=1000):
    array_piano_roll = np.zeros((128, max_generated), dtype=np.int16)
    
    for index, notes in enumerate(generate[:, :]):
#         print(notes)
        nonzero_notes = np.nonzero(notes)
        if len(nonzero_notes) > 0:
#             print(nonzero_notes)
            nonzero_notes = nonzero_notes[0]
        else:
            continue

        for j in nonzero_notes:
#             print(f'j = {j}, index = {index}')
            array_piano_roll[int(j),index] = 1
#     print(array_piano_roll.sum())
    generate_to_midi = piano_roll_to_pretty_midi(array_piano_roll, fs=fs)
    for note in generate_to_midi.instruments[0].notes:
        note.velocity = 100
    generate_to_midi.write(midi_file_name)
    
    return array_piano_roll

def piano_roll_to_pretty_midi(piano_roll, fs=100, program=0):
    '''Convert a Piano Roll array into a PrettyMidi object
     with a single instrument.
    Parameters
    ----------
    piano_roll : np.ndarray, shape=(128,frames), dtype=int
        Piano roll of one instrument
    fs : int
        Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    program : int
        The program number of the instrument.
    Returns
    -------
    midi_object : pretty_midi.PrettyMIDI
        A pretty_midi.PrettyMIDI class instance describing
        the piano roll.
    '''
    notes, frames = piano_roll.shape
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=program)

    # pad 1 column of zeros so we can acknowledge inital and ending events
    piano_roll = np.pad(piano_roll, [(0, 0), (1, 1)], 'constant')

    # use changes in velocities to find note on / note off events
    velocity_changes = np.nonzero(np.diff(piano_roll).T)

    # keep track on velocities and note on times
    prev_velocities = np.zeros(notes, dtype=int)
    note_on_time = np.zeros(notes)

    for time, note in zip(*velocity_changes):
        # use time + 1 because of padding above
        velocity = piano_roll[note, time + 1]
        time = time / fs
        if velocity > 0:
            if prev_velocities[note] == 0:
                note_on_time[note] = time
                prev_velocities[note] = velocity
        else:
            pm_note = pretty_midi.Note(
                velocity=prev_velocities[note],
                pitch=note,
                start=note_on_time[note],
                end=time)
            instrument.notes.append(pm_note)
            prev_velocities[note] = 0
    pm.instruments.append(instrument)
    return pm

In [0]:
rnn_1.eval()
max_generate = 1000
# unique_notes = note_tokenizer.unique_word
seq_len = 50

previous_notes = generate_from_constant(128, 1)
previous_notes = generate_notes(previous_notes, rnn, 1, 128, max_generated=max_generate, threshold=0.001).float().cuda()
array_piano_roll = write_midi_file_from_generated(previous_notes[0], "constant.mid", start_index=seq_len-1, 
                                                  fs=10, max_generated = max_generate)

In [0]:
previous_notes = generate_from_random(128, 1)
previous_notes = generate_notes(previous_notes, rnn, 1, 128, max_generated=max_generate, threshold=0.001).float().cuda()
array_piano_roll = write_midi_file_from_generated(previous_notes[0], "random.mid", start_index=seq_len-1, 
                                                  fs=10, max_generated = max_generate)