# MIDI-RNN

Tensorflow-based recurrent neural network for generating MIDI music.

The input training data is a <a href="https://www.reddit.com/r/datasets/comments/3akhxy/the_largest_midi_collection_on_the_internet/">huge</a> dataset of MIDI tunes scraped from the internet.

The MIDI files are converted into numpy arrays and fed into a TensorFlow model that concatenates MIDI timing and note information. The network tries to predict the next note and its timing. Two LSTMs are used, along with a number of dense layers. The input data is also passed through a dot-product query layer, loosely akin to attention heads.

The trained network can then be run autoregressively (its predictions are fed back in as inputs) to generate a new MIDI song, which is played with the pygame Python module.

# Data Download

Let's begin by downloading and unzipping the MIDI zip file.

Use <a href="https://mega.co.nz/#!Elg1TA7T!MXEZPzq9s9YObiUcMCoNQJmCbawZqzAkHzY4Ym6Gs_Q">this</a> mega link to retrieve the file, unzip it, and move it to the working directory.

<p style="font-weight:800">WARNING</p>The file is over 1 GB!

## Imports

In [None]:
%%capture
!pip install numpy
!pip intall mido
!pip install tensorflow==1.8.0
!pip install tensorboard
!pip install tqdm
!pip install pygame

In [None]:
%%capture
import random
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn
import mido
from mido import Message, MidiFile, MidiTrack, MAX_PITCHWHEEL
import glob
import copy
from subprocess import Popen
import time
import os
from tqdm import tqdm
from IPython.display import IFrame
import pygame

## Utility Functions

In [None]:
def _append_message(track, data, times):
    '''
    Decode ``data`` as a single MIDI message, time it, and append it to ``track``.

    The message time is mean(times) * 10, rounded and clamped to be >= 0.
    Byte sequences that mido cannot decode are reported with print() and skipped.
    '''
    try:
        # Plain ints: mido validates data bytes, and numpy scalars are converted up front.
        msg = Message.from_bytes([int(b) for b in data])
        msg.time = max(0, int(np.round(np.mean(np.asarray(times)) * 10)))
        track.append(msg)
    except Exception as e:
        print(e)


def create_midi_track(midi_notes, midi_timing, filename='generated_midi.mid'):
    '''
    Utility function to turn RNN outputs into a mido MIDI track and save it.

    Inputs

    midi_notes: sequence of 40-bit rows (5 bytes unpacked into bits), e.g. from the
        Tensorflow forward pass. After re-packing, byte 0 flags a chunk of a message
        longer than three bytes, bytes 1-3 hold message bytes, and byte 4 holds the
        number of valid bytes in the chunk (the layout written by load_midi_file).
    midi_timing: sequence of per-row timing values aligned with midi_notes.
    filename: path of the .mid file to write (defaults to 'generated_midi.mid').

    Function

    Re-packs each row into bytes. Rows flagged as multi-byte chunks are buffered
    until a non-flagged row arrives (or the input ends) and are then decoded as one
    message timed by the mean of the buffered timings. Every other row is decoded as
    its own message. Undecodable rows are reported and skipped. Times are scaled by
    10, rounded and clamped at 0.

    Outputs

    None. The track is saved to ``filename``; save errors are printed, not raised.
    '''
    outfile = MidiFile(type=0)
    track = MidiTrack()
    outfile.tracks.append(track)
    pending_bytes = []
    pending_times = []
    for index, bits in enumerate(midi_notes):
        packed = np.packbits(bits, axis=0)
        n_bytes = min(3, int(packed[4]))
        if packed[0] > 0:
            # Chunk of a long message: keep accumulating, one time entry per byte.
            pending_bytes.extend(packed[1:1 + n_bytes])
            pending_times.extend([midi_timing[index]] * n_bytes)
        else:
            if pending_bytes:
                _append_message(track, pending_bytes, pending_times)
                pending_bytes, pending_times = [], []
            _append_message(track, packed[1:1 + n_bytes], [midi_timing[index]])
    # A long message that runs to the end of the input was previously dropped.
    if pending_bytes:
        _append_message(track, pending_bytes, pending_times)
    try:
        outfile.save(filename)
    except Exception as e:
        print(e)

def load_midi_file(filename):
    '''
    Utility function for loading a midi file as numpy note-data and timing arrays.

    Inputs

    filename: string denoting file location

    Function:

    Each MIDI message is split into chunks of at most three bytes, and each chunk is
    stored as a 5-value row:
      cur[0]   1 if the whole message is longer than three bytes (so its chunks must
               be re-joined when decoding), else 0
      cur[1:4] the chunk's byte values
      cur[4]   number of valid bytes in the chunk (1-3)
    Every chunk gets the message's time value multiplied by ticks_per_beat * 2.
    (Iterating a mido MidiFile yields messages whose time is in seconds.)

    Finally, the rows are unpacked into bits, producing 40 values per row
    (5 bytes * 8 bits = 40).

    Outputs:

    midi_data: (N, 40) int32 array of bits describing the MIDI bytes and which
        rows belong to the same message
    midi_time: length-N float32 timing array aligned with midi_data
    '''
    midi_file = mido.MidiFile(filename)
    time_scale = midi_file.ticks_per_beat * 2
    midi_data = []
    midi_time = []
    for msg in midi_file:
        msg_bytes = msg.bytes()
        n_bytes = len(msg_bytes)
        is_long = n_bytes > 3
        cur = np.zeros(5)
        for ind, byte in enumerate(msg_bytes):
            if is_long:
                cur[0] = 1
            cur[1 + ind % 3] = byte
            # Emit a row when the chunk is full or the message ends.
            if (ind + 1) % 3 == 0 or ind + 1 == n_bytes:
                cur[4] = (ind % 3) + 1
                midi_data.append(cur)
                midi_time.append(msg.time * time_scale)
                cur = np.zeros(5)
    # reshape keeps a 2-D (0, 5) array for files with no messages, so unpackbits(axis=1) works.
    midi_data = np.asarray(midi_data, dtype=np.dtype('B')).reshape(-1, 5)
    midi_time = np.asarray(midi_time).astype(np.float32)
    midi_data = np.unpackbits(midi_data, axis=1).astype(np.int32)
    return midi_data, midi_time
            
def is_midi_0(filename):
    '''
    Report whether ``filename`` loads with mido and is a type 0 (single-track) MIDI file.

    Any exception raised while loading (missing, unreadable or malformed file)
    is treated as "not usable" and yields False.
    '''
    try:
        midi_type = mido.MidiFile(filename).type
    except Exception:
        return False
    return midi_type == 0
    
def extract_samples(midis_array=None,midi_times_array=None,p=0, seq_length=25, current_file=0):
    '''
    Utility function for retrieving the next RNN input and target arrays,
    with basic augmentation applied to the inputs.

    Inputs

    midis_array: list of (N, 40) note-bit arrays, one per file (as from load_midi_file)
    midi_times_array: list of timing arrays aligned row-for-row with midis_array
    p: first row of the window
    seq_length: number of timesteps in the window
    current_file: index of the file within both lists

    Function

    Inputs are rows p .. p+seq_length-1, and targets are the same rows shifted one
    step ahead (p+1 .. p+seq_length). The inputs get Gaussian noise (std 0.1), are
    clamped at 0, and have one randomly chosen timestep set to zero. Targets are
    returned unaltered.

    Outputs

    input_midi_val (seq_length, 40) float32, input_time_val (seq_length, 1) float32,
    target_midi_val (seq_length, 40), target_time_val (seq_length, 1)

    Raises

    ValueError if the file has fewer than p + seq_length + 1 rows.
    '''
    width = 5 * 8  # 5 bytes per row, unpacked into bits
    midi = midis_array[current_file]
    times = midi_times_array[current_file]
    needed = p + seq_length + 1
    if min(len(midi), len(times)) < needed:
        raise ValueError(
            "file {} has {} rows; a window at p={} with seq_length={} needs {}".format(
                current_file, min(len(midi), len(times)), p, seq_length, needed))

    # astype() copies, so the augmentation below never mutates the pooled data.
    input_midi_val = midi[p:p + seq_length].reshape(seq_length, width).astype(np.float32)
    input_time_val = times[p:p + seq_length].reshape(seq_length, 1).astype(np.float32)
    target_midi_val = midi[p + 1:p + seq_length + 1].reshape(seq_length, width)
    target_time_val = times[p + 1:p + seq_length + 1].reshape(seq_length, 1)

    input_midi_val += np.random.normal(scale=0.1, size=seq_length * width).reshape(seq_length, width)
    input_time_val += np.random.normal(scale=0.1, size=seq_length).reshape(seq_length, 1)
    input_midi_val = np.maximum(0, input_midi_val)
    input_time_val = np.maximum(0, input_time_val)

    # Blank out one timestep so the network learns to cope with missing inputs.
    dropped = int(np.random.randint(seq_length))
    input_midi_val[dropped] = 0
    input_time_val[dropped] = 0
    return input_midi_val, input_time_val, target_midi_val, target_time_val

## Data Input

We will use glob to find the MIDI file-paths from the dataset we downloaded.

In [None]:
cwd = os.getcwd()
# One recursive "**" already matches any directory depth (including none).
# Repeating it ("**/**") can make glob report the same file more than once,
# so use a single "**" and dedupe before sorting.
file_list = sorted(set(glob.glob(os.path.join(cwd, "**", "*.mid"), recursive=True)))

## Hyperparameters & Random Seed

In [None]:
seed_value = 42
# Seed every random source this notebook uses: TensorFlow ops, Python's
# `random`, and NumPy. The query-table init, augmentation noise and
# file/window selection all draw from np.random.
tf.set_random_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
hidden_size = 500 # LSTM units, also the width of the dense layers around them
dense_size = 100 # Width of the per-head layers before the note/timing outputs
seq_length = 25 # Timesteps per training window
# Fed to rate_ph, which is passed to tf.nn.dropout. In TF 1.x that argument is
# keep_prob, so this is the probability of KEEPING a unit, not of dropping it.
dropout_rate = 0.95
learning_rate = 1e-4 # Adam Learning Rate
iterations = 1000 # Iterations per training run
max_midi_files_in_memory = 500 # Number of MIDI files stored in memory
loading_rate = 0.99 # A new MIDI file is loaded with probability 1 - loading_rate when a pooled file is exhausted

## Placeholders Etc.

Here we set up the placeholder nodes for the TF graph (RNN inputs and targets, LSTM states, the dropout keep-probability,
and the RNN sequence length), plus the optimiser's global-step variable.

In [None]:
# Step counter advanced by the optimiser; not updated by gradients.
global_step = tf.Variable(0, name='global_step', trainable=False)

# One training window: 40 note bits (5 bytes as bits) plus one timing value per
# timestep, along with the next-step targets for each.
input_midi = tf.placeholder(tf.float32, [seq_length, 5*8], "inputs")
input_time = tf.placeholder(tf.float32, [seq_length, 1], "inputs_time")
targets_midi = tf.placeholder(tf.float32, [seq_length, 5*8], "targets")
targets_time = tf.placeholder(tf.float32, [seq_length, 1], "targets_time")

# Incoming (c, h) states of the two LSTMs (batch size 1), created in the order
# statec0c, statec0h, statec1c, statec1h.
ic0_c, ic0_h, ic1_c, ic1_h = (
    tf.placeholder(tf.float32, [1, hidden_size], state_name)
    for state_name in ("statec0c", "statec0h", "statec1c", "statec1h")
)

# Scalar fed to tf.nn.dropout, plus the number of timesteps the RNNs should run.
rate_ph = tf.placeholder(tf.float32, [], "dropout")
seq_length_ph = tf.placeholder(tf.int32, [1], "seqlength")

initial_state_c0 = tf.nn.rnn_cell.LSTMStateTuple(c=ic0_c, h=ic0_h)
initial_state_c1 = tf.nn.rnn_cell.LSTMStateTuple(c=ic1_c, h=ic1_h)

## Tensorflow Graph

This is the Tensorflow graph for the neural network.
The MIDI timing and note values are concatenated into inputs0.
This layer is passed through an l2-normalised dot-product query-mapping,
and is used to gate another dense layer that uses the same inputs.
The gated dense layer is then fed into the first LSTM.
The output of this LSTM passes through a further dense layer, before
passing through the second LSTM. The output of this final LSTM is passed through
yet another dense layer, which forks into two dense layers, that respectively feed
into dense layers predicting the target MIDI note and timing values for the next timestep.
The forking is used to allow the network capacity to split the hidden representation into
timing and note distributions.

In [None]:
with tf.variable_scope("RNN") as scope:

        # Two LSTM cells of hidden_size units each. reuse=tf.AUTO_REUSE reuses
        # same-named variables if they already exist in this scope.
        rnn1f = rnn.LSTMCell(hidden_size, reuse=tf.AUTO_REUSE,name="lstm1", initializer = tf.random_normal_initializer(stddev=0.01))
        rnn2f = rnn.LSTMCell(hidden_size, reuse=tf.AUTO_REUSE,name="lstm2", initializer = tf.random_normal_initializer(stddev=0.01))

        # Batch-size-1 all-zero states, evaluated elsewhere to reset a file's carried state.
        zero_state1 = rnn1f.zero_state(1,tf.float32)
        zero_state2 = rnn2f.zero_state(1,tf.float32)

        # (seq_length, 41): 40 note bits plus 1 timing value per timestep.
        inputs0 = tf.concat([input_midi,input_time],axis=1)
        # hidden_size learnable query vectors over the 41 input features (NumPy-initialised).
        query_table = tf.Variable(np.random.normal(scale=0.01,size=(hidden_size,41)),dtype=tf.float32,name="query")
        # Unit-length queries, transposed to (41, hidden_size) for the matmul below.
        query_norm = tf.transpose(tf.nn.l2_normalize(1e-7+query_table,axis=1))
        # Learnable scale on the cosine similarities before tanh.
        multiplier = tf.Variable(1.0,dtype=tf.float32,name="multiplier")
        # Cosine similarity of each unit-normalised input row with each query -> (seq_length, hidden_size) gate in (-1, 1).
        gating = tf.nn.tanh(multiplier*tf.matmul(tf.nn.l2_normalize(inputs0 + 1e-7,axis=1),query_norm))
        # Dense projection of the same inputs, scaled element-wise by the gate.
        inputs1_gated = gating*tf.layers.dense(inputs0, name="inputs1_gated",units=hidden_size,activation=tf.nn.leaky_relu,kernel_initializer = tf.random_normal_initializer(stddev=0.01),reuse=tf.AUTO_REUSE)
        # tf.split along axis 0 turns the window into seq_length steps of shape (1, hidden_size);
        # seq_length_ph limits how many steps actually update the state.
        outputs1f, states1f = rnn.static_rnn(rnn1f, tf.split(inputs1_gated,seq_length), initial_state=initial_state_c0, dtype=tf.float32,sequence_length=seq_length_ph)
        mid2b = tf.layers.dense(tf.concat(outputs1f,axis=0), name="mid2b",units=hidden_size,activation=tf.nn.leaky_relu,kernel_initializer = tf.random_normal_initializer(stddev=0.01),reuse=tf.AUTO_REUSE)
        # Second argument of tf.nn.dropout is keep_prob in TF 1.x, so rate_ph is a keep probability.
        outputs2f, states2f = rnn.static_rnn(rnn2f, tf.split(tf.nn.dropout(mid2b,rate_ph),seq_length),initial_state=initial_state_c1, dtype=tf.float32,sequence_length=seq_length_ph)
        # Shared layer feeding the two prediction heads.
        output_fork = tf.layers.dense(tf.concat(outputs2f,axis=0), name="output_fork",units=hidden_size,activation=tf.nn.leaky_relu,kernel_initializer = tf.random_normal_initializer(stddev=0.01),reuse=tf.AUTO_REUSE)
        # Note head: 40 raw logits per timestep (no activation here).
        pre_midi_output = tf.layers.dense(output_fork, units =  dense_size, name='pre_midi_output', activation=tf.nn.leaky_relu, kernel_initializer = tf.random_normal_initializer(stddev=0.01),reuse=tf.AUTO_REUSE)
        midi_output = tf.layers.dense(pre_midi_output, units =  5*8, name='midi_output',activation=None, kernel_initializer = tf.random_normal_initializer(stddev=0.01),reuse=tf.AUTO_REUSE)
        # Timing head: one value per timestep; leaky_relu still allows small negative outputs.
        pre_midi_time = tf.layers.dense(output_fork, units =  dense_size, name='pre_midi_time',activation=tf.nn.leaky_relu, kernel_initializer = tf.random_normal_initializer(stddev=0.01),reuse=tf.AUTO_REUSE)
        midi_time = tf.layers.dense(pre_midi_time, units =  1, name='midi_time',activation=tf.nn.leaky_relu,kernel_initializer = tf.random_normal_initializer(stddev=0.1),reuse=tf.AUTO_REUSE)

## RNN Outputs, Loss Function, Optimiser
We take the sigmoid of the MIDI output predictions to constrain the values between 0 and 1.
We apply both MSE loss and cosine loss so the network learns the bit patterns of MIDI commands
as well as their absolute values. For the timing loss we use RMSE alone.

We apply gradient clipping to stabilise learning.

In [None]:
# Squash the note logits into (0, 1) so they are comparable with the 0/1 target bits.
midi_output = tf.nn.sigmoid(midi_output)
loss_mse = tf.reduce_mean((midi_output - targets_midi)**2)
# Replace a NaN loss with zero rather than letting it propagate into the weights.
loss_mse = tf.where(tf.is_nan(loss_mse), tf.zeros_like(loss_mse), loss_mse)
# Per-timestep cosine distance between prediction and target. tf.losses.cosine_distance
# assumes inputs already unit-normalised along the reduction axis, so normalise
# along axis=1 (each row). The l2_normalize default (whole tensor) shrank every
# per-row dot product towards zero and left this loss almost constant.
loss_cosine = tf.reduce_mean(
    tf.losses.cosine_distance(
        labels=tf.nn.l2_normalize(1e-7 + targets_midi, axis=1),
        predictions=tf.nn.l2_normalize(1e-7 + midi_output, axis=1),
        axis=1)**2)

# RMSE of the timing head; the 1e-7 keeps the sqrt gradient finite at zero error.
loss_time = tf.sqrt(1e-7 + tf.reduce_mean((midi_time - targets_time)**2))
loss_time = tf.where(tf.is_nan(loss_time), tf.zeros_like(loss_time), loss_time)

loss = (loss_mse + 0.001*loss_cosine + loss_time*0.001)

tf.summary.scalar('loss/loss', loss)
tf.summary.scalar('loss/mse', loss_mse)
tf.summary.scalar('loss/cosine', loss_cosine)
tf.summary.scalar('loss/mse_time', loss_time)
tf.summary.histogram('outputs/outputs1', midi_output)
tf.summary.histogram('targets/targets', targets_midi)
tf.summary.histogram('outputs/time', midi_time)
tf.summary.histogram('targets/time', targets_time)

# Minimizer
minimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

# Gradient clipping: clamp each element to [-5, 5] and zero any NaNs before applying.
grad_clipping = tf.constant(5.0, name="grad_clipping")
clipped_grads_and_vars = []
for grad, var in minimizer.compute_gradients(loss, var_list=tf.trainable_variables()):
    if grad is None:
        # Variable does not influence the loss; clip_by_value cannot accept None.
        continue
    clipped_grad = tf.clip_by_value(grad, -grad_clipping, grad_clipping)
    clipped_grad = tf.where(tf.is_nan(clipped_grad), tf.zeros_like(clipped_grad), clipped_grad)
    clipped_grads_and_vars.append((clipped_grad, var))

updates = minimizer.apply_gradients(clipped_grads_and_vars, global_step=global_step)

## Tensorflow Session, Initialisations

In [None]:
sess = tf.Session()
saver = tf.train.Saver()
init = tf.global_variables_initializer()
train_writer = tf.summary.FileWriter("./rnn_audio/summary_rnn", sess.graph)

merged = tf.summary.merge_all()
restore = False
if not restore:
    # A single run initialises every variable. The former extra
    # tf.initialize_all_variables() call was a deprecated alias that just
    # re-ran the same initialisers.
    sess.run(init)
else:
    checkpoint = 0 # Add last checkpoint file number here
    saver.restore(sess, "./rnn_audio/rnn_audio-{}".format(checkpoint))

if not file_list:
    raise FileNotFoundError("No .mid files found; download and unzip the MIDI dataset into the working directory first.")

midis_array = []
midi_times_array = []
first_midi_array, first_midi_times_array = load_midi_file(file_list[0])
midis_array.append(first_midi_array)
midi_times_array.append(first_midi_times_array)
# Indices into file_list of the files in the in-memory pool, kept in step with midis_array.
loaded_files = [0]
p_list = [0]
rnn_state1_list = [copy.deepcopy(sess.run(zero_state1))]
rnn_state2_list = [copy.deepcopy(sess.run(zero_state2))]
file_counter = len(midis_array) - 1
current_file = 0
# Continue iteration numbering from the (possibly restored) global step, so
# summaries of a resumed run don't reuse earlier step numbers. This is 0 for a fresh run.
last_iteration = int(sess.run(global_step))

## Prepare Tensorboard

We use subprocess.Popen for an asynchronous call to TensorBoard, then display the localhost website in an IFrame.
You can run the cell after the training code to terminate the TensorBoard process.

In [None]:
# Start TensorBoard as a background process. Popen returns immediately, and the
# handle is kept in `p` so a later cell can terminate it.
p = Popen(
    ['tensorboard', '--logdir=./rnn_audio/summary_rnn'],
)
# Give TensorBoard a moment to start serving before the iframe loads the page.
time.sleep(10)
IFrame(
    src="http://localhost:6006/",
    width='100%',
    height='500px',
)

## Training Loop

This is the main training loop of the notebook. You can re-run this to keep training the network.
There is a sample-pool of MIDI files that is randomly refreshed.

In [None]:

# The pool bookkeeping list of file_list indices. Create it if the setup cell
# did not; the pool then starts with file_list[0] only.
if 'loaded_files' not in globals():
    loaded_files = [0]
max_load_attempts = 100 # Give up on loading a new file after this many unusable draws
zero_state1_val = sess.run(zero_state1)
zero_state2_val = sess.run(zero_state2)

for iteration in tqdm(range(last_iteration, last_iteration + iterations)):
    # Pick a random file from the in-memory pool.
    current_file = int(np.random.randint(len(midis_array)))
    # If that file has no complete window left, rewind it and pick again. On
    # each rewind, occasionally pull a new file from disk into the pool.
    while p_list[current_file] + seq_length + 1 > midis_array[current_file].shape[0]:
        p_list[current_file] = 0
        rnn_state1_list[current_file] = copy.deepcopy(zero_state1_val)
        rnn_state2_list[current_file] = copy.deepcopy(zero_state2_val)
        current_file = int(np.random.randint(len(midis_array)))
        if np.random.uniform() > loading_rate:
            for _ in range(max_load_attempts):
                file_choice = int(np.random.randint(len(file_list)))
                if file_choice in loaded_files or not is_midi_0(file_list[file_choice]):
                    continue
                try:
                    new_midi, new_times = load_midi_file(file_list[file_choice])
                except Exception as e:
                    print(e)
                    continue
                # A file shorter than one window plus its target can never be sampled.
                if new_midi.shape[0] < seq_length + 1:
                    continue
                midis_array.append(new_midi)
                midi_times_array.append(new_times)
                loaded_files.append(file_choice)
                p_list.append(0)
                rnn_state1_list.append(copy.deepcopy(zero_state1_val))
                rnn_state2_list.append(copy.deepcopy(zero_state2_val))
                # Evict the oldest entries so the pool stays within its memory budget.
                while len(midis_array) > max_midi_files_in_memory:
                    for pool in (midis_array, midi_times_array, loaded_files,
                                 p_list, rnn_state1_list, rnn_state2_list):
                        del pool[0]
                current_file = len(midis_array) - 1
                break

    # Named `position` (not `p`) so the TensorBoard Popen handle `p` survives for p.terminate().
    position = p_list[current_file]
    input_midi_val, input_time_val, target_midi_val, target_time_val = \
    extract_samples(midis_array=midis_array, midi_times_array=midi_times_array, p=position, seq_length=seq_length, current_file=current_file)

    states1f_val, states2f_val, loss_val, _, summary = sess.run([states1f, states2f, loss, updates, merged],
                                      feed_dict={input_midi : input_midi_val,
                                                 input_time : input_time_val,
                                                 ic1_c: rnn_state2_list[current_file].c,
                                                 ic1_h: rnn_state2_list[current_file].h,
                                                 ic0_c: rnn_state1_list[current_file].c,
                                                 ic0_h: rnn_state1_list[current_file].h,
                                                 rate_ph: dropout_rate,
                                                 targets_midi: target_midi_val,
                                                 targets_time : target_time_val,
                                                 seq_length_ph: np.full(1, seq_length, dtype=np.int32)
                                                 })
    # Carry each LSTM's final state into this file's next window.
    rnn_state1_list[current_file] = states1f_val
    rnn_state2_list[current_file] = states2f_val

    if (iteration % 200 == 0):
        train_writer.add_summary(summary, iteration)
        print('iter: {}, p: {}, loss: {}'.format(iteration, position, loss_val))
    p_list[current_file] += seq_length

saver.save(sess, "./rnn_audio/rnn_audio", global_step=global_step)
last_iteration += iterations

In [None]:
# Close Tensorboard, then wait for the process to exit so it does not linger as a
# zombie. If it ignores the terminate request, kill it.
from subprocess import TimeoutExpired
p.terminate()
try:
    p.wait(timeout=10)
except TimeoutExpired:
    p.kill()
    p.wait()

## Generate RNN MIDI Music

Here we feed the predictions of the RNN model into itself, to generate
novel MIDI compositions, which can then be played later.

In [None]:
rnn_state1 = copy.deepcopy(sess.run(zero_state1))
rnn_state2 = copy.deepcopy(sess.run(zero_state2))
midi_data_vals = []
midi_time_vals = []
# Seed generation with the opening window of the most recently trained file.
# astype() makes float32 copies. Without it these are views into the training
# data, which the loop below would overwrite. The note data is also int32, which
# would truncate the fed-back sigmoid outputs in (0, 1) to 0.
input_midi_val = midis_array[current_file][0:seq_length].reshape(seq_length, 5*8).astype(np.float32)
input_time_val = midi_times_array[current_file][0:seq_length].reshape(seq_length, 1).astype(np.float32)

midi_data_vals.append(np.around(input_midi_val[0]).astype(np.int32))
midi_time_vals.append(np.around(input_time_val[0]).astype(np.int32))

generated_midi_length = 2500

print('Sampling a generated MIDI file')
for t in tqdm(range(generated_midi_length)):
    # seq_length_ph = 1: only the first timestep updates the LSTM states, so each
    # call advances the network by one step from the carried-over states.
    midi_data_current, midi_time_current, rnn_state1, rnn_state2 = \
                    sess.run([midi_output, midi_time, states1f, states2f],
                             feed_dict={input_midi: input_midi_val,
                                        input_time : input_time_val,
                                        ic1_c: rnn_state2.c,
                                        ic1_h: rnn_state2.h,
                                        ic0_c: rnn_state1.c,
                                        ic0_h: rnn_state1.h,
                                        rate_ph: 1.0,
                                        seq_length_ph: np.ones((1))
                                        })

    # Feed this step's prediction back in as the next input (NaNs -> 0, clamped at 0).
    input_midi_val[0] = np.maximum(0, np.asarray(np.nan_to_num(midi_data_current[0].reshape(1, 5*8))))
    input_time_val[0] = np.maximum(0, np.asarray(np.nan_to_num(midi_time_current[0].reshape(1, 1))))
    midi_data_vals.append(np.maximum(0, np.around(midi_data_current[0]).astype(np.int32)))
    midi_time_vals.append(np.maximum(0, np.around(midi_time_current[0]).astype(np.float32)))

create_midi_track(midi_data_vals, midi_time_vals)

## Play MIDI tune

MIDI pygame example from <a href="https://www.daniweb.com/programming/software-development/code/216979/embed-and-play-midi-music-in-your-code-python">here</a>

In [None]:
def play_music(music_file):
    """
    Play ``music_file`` with pygame.mixer.music and block until playback stops.

    The mixer streams the file from disk. If pygame cannot load the file, the
    error is printed and the function returns without playing anything.
    """
    ticker = pygame.time.Clock()
    try:
        pygame.mixer.music.load(music_file)
    except pygame.error as err:
        print(err)
        return
    pygame.mixer.music.play()
    # Poll about 30 times per second until the mixer reports it is idle.
    while pygame.mixer.music.get_busy():
        ticker.tick(30)
# Mixer settings for pygame.mixer.init(frequency, size, channels, buffer).
freq = 44100    # audio CD quality
bitsize = -16   # signed 16-bit samples (a negative size selects signed values)
channels = 2    # 1 is mono, 2 is stereo
buffer = 1024    # number of samples
pygame.mixer.init(freq, bitsize, channels, buffer)
# optional volume 0 to 1.0
pygame.mixer.music.set_volume(0.3)
try:
    # play_sample=True plays sample.mid from the working directory. Set it to
    # False to play generated_midi.mid, the file written by create_midi_track.
    play_sample = True
    sample_path = os.path.join(os.getcwd(),"sample.mid")
    generated_path = os.path.join(os.getcwd(),"generated_midi.mid")
    play_music(sample_path if play_sample else generated_path)
except KeyboardInterrupt:
    # if user hits Ctrl/C then fade out, stop playback and exit
    # (works only in console mode)
    pygame.mixer.music.fadeout(1000)
    pygame.mixer.music.stop()
    raise SystemExit