# Training the network on a dataset of MIDI files

In [2]:
import tensorflow.compat.v1 as tf
from tensorflow.keras import datasets, layers, models, optimizers
# Session configuration: cap TensorFlow's GPU memory use for this process via
# per_process_gpu_memory_fraction (0.8 = 80%).
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.8
# Install a session built from that config as the Keras backend session.
# NOTE(review): `tf` is already bound to `tensorflow.compat.v1` by the import
# above, so the extra `tf.compat.v1.` prefix here looks redundant — confirm
# before simplifying.
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))

import glob
import pickle
import numpy as np
import matplotlib.pyplot as plt

from music21 import converter, instrument, note, chord

from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Activation
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.python.keras.utils import np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import BatchNormalization as BatchNorm

### Reading the Data 

In [2]:
# Parse one MIDI file with music21 so its contents can be inspected.
file = converter.parse("midi_songs/bicycle-ride.mid")

# Walk the parsed result recursively and keep every element that is yielded.
# The elements themselves are stored here, not their string form.
components = [item for item in file.recurse()]

# Print each collected element in the order it was encountered.
for element in components:
    print(element)

<music21.stream.Part 0x7feb10407690>
(---o---): Piano
<music21.tempo.MetronomeMark allegro Quarter=133.0>
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
C major
<music21.meter.TimeSignature 4/4>
<music21.note.Rest rest>
<music21.note.Note C>
<music21.note.Note D>
<music21.note.Note E>
<music21.note.Note C>
<music21.note.Note B>
<music21.note.Rest rest>
<music21.note.Note A>
<music21.note.Note G>
<music21.note.Note F>
<music21.note.Rest rest>
<music21.note.Note F>
<music21.note.Note G>
<music21.note.Note A>
<music21.note.Note D>
<music21.note.Note C>
<music21.note.Note B>
<music21.note.Note A>
<music21.note.Note B>
<music21.note.Note C>
<music21.note.Rest rest>
<music21.note.Note A>
<music21.note.Note G>
<music21.note.Rest rest>
<music21.note.Note C>
<music21.note.Note B>
<music21.note.Note C>
<music21.note.Note A>
<music21.note.Rest rest>
<music21.not

### Data Preparation

In [3]:
def train_network():
    """Run the full training pipeline end to end.

    Steps, in order: collect the notes with ``get_notes``, count the distinct
    note strings, turn the notes into training sequences with
    ``prepare_sequences``, build the model with ``create_network`` and fit it
    with ``train``. Returns nothing.
    """
    parsed_notes = get_notes()

    # Vocabulary size = number of distinct note/chord strings.
    vocab_size = len(set(parsed_notes))

    inputs, targets = prepare_sequences(parsed_notes, vocab_size)
    network = create_network(inputs, vocab_size)

    train(network, inputs, targets)

In [4]:
# get_notes (defined in the next cell) collects the notes and chords of every
# MIDI file into one sequential list.
# The following step cuts that list into sequences that serve as the input
# for the network.

In [5]:
def get_notes():
    """Collect the notes and chords of every MIDI file in ``midi_songs/``.

    Each ``note.Note`` is recorded as the string form of its pitch. Each
    ``chord.Chord`` is recorded as the integers of its normal order joined
    with "." into a single string. Every other element type is skipped.
    The resulting list is also pickled to ``data/notes``.

    Returns:
        list[str]: the encoded notes and chords of all files, concatenated in
        the order ``glob`` returned the files.
    """
    # Local import: only needed to make sure the output directory exists.
    import os

    notes = []

    for file in glob.glob("midi_songs/*.mid"):
        # converter.parse loads the MIDI file into a music21 stream.
        midi = converter.parse(file)

        print("Parsing %s" % file)

        try:  # file has instrument parts
            # Separate the different instruments present in the file and read
            # the first part only.
            s2 = instrument.partitionByInstrument(midi)
            notes_to_parse = s2.parts[0].recurse()
        except Exception:  # file has notes in a flat structure
            # `Exception` instead of a bare `except`, so KeyboardInterrupt and
            # SystemExit still stop a long parsing run.
            notes_to_parse = midi.flat.notes

        for element in notes_to_parse:
            if isinstance(element, note.Note):
                # a single note: keep the string notation of its pitch
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                # a chord: encode all of its notes into one string
                notes.append('.'.join(str(n) for n in element.normalOrder))

    # Create the target directory first, so the parsing work is not lost to a
    # FileNotFoundError on a fresh checkout.
    os.makedirs('data', exist_ok=True)
    with open('data/notes', 'wb') as filepath:
        pickle.dump(notes, filepath)

    return notes

### Data Preprocessing

In [6]:
# A mapping converts the string-based categorical data (the note and chord
# strings) to integers, because the network cannot work on strings directly.
# The integer-encoded notes are then cut into input sequences for the network,
# each paired with its respective output: the note that follows the sequence.

In [7]:
def prepare_sequences(notes, n_vocab, sequence_length=100):
    """Turn a list of note strings into training inputs and one-hot targets.

    Every window of ``sequence_length`` consecutive notes becomes one input
    sequence, and the note right after the window is its expected output.

    Args:
        notes: list of note/chord strings.
        n_vocab: size of the vocabulary. Used to normalise the inputs and as
            the width of the one-hot targets; must be at least the number of
            distinct strings in ``notes``.
        sequence_length: number of preceding notes used to predict the next
            one (default 100).

    Returns:
        tuple: ``(network_input, network_output)`` where ``network_input`` has
        shape ``(n_patterns, sequence_length, 1)`` with values divided by
        ``n_vocab``, and ``network_output`` is a float32 one-hot array of
        shape ``(n_patterns, n_vocab)``.

    Raises:
        ValueError: if ``notes`` is too short to form a single sequence, or if
            ``n_vocab`` is smaller than the number of distinct notes.
    """
    if len(notes) <= sequence_length:
        raise ValueError(
            "need more than %d notes to build a training sequence, got %d"
            % (sequence_length, len(notes))
        )

    # get all pitch names
    pitchnames = sorted(set(notes))
    if n_vocab < len(pitchnames):
        raise ValueError(
            "n_vocab (%d) is smaller than the number of distinct notes (%d)"
            % (n_vocab, len(pitchnames))
        )

    # create a dictionary to map pitches to integers
    note_to_int = {pitch: number for number, pitch in enumerate(pitchnames)}

    # Encode the whole list once; the windows below are plain slices of it.
    encoded = [note_to_int[item] for item in notes]

    # create input sequences and the corresponding outputs
    n_patterns = len(encoded) - sequence_length
    windows = [encoded[i:i + sequence_length] for i in range(n_patterns)]
    targets = encoded[sequence_length:]

    # reshape the input into a format compatible with LSTM layers
    network_input = np.reshape(windows, (n_patterns, sequence_length, 1))
    # normalize input
    network_input = network_input / float(n_vocab)

    # One-hot encode the targets with a fixed width of n_vocab. Inferring the
    # width from the largest target index produces too few columns whenever
    # the highest-indexed pitch never occurs as a target.
    network_output = np.zeros((n_patterns, n_vocab), dtype="float32")
    network_output[np.arange(n_patterns), targets] = 1.0

    return (network_input, network_output)

### Defining Model Parameters

In [8]:
def create_network(network_input, n_vocab):
    """Build and compile the LSTM model.

    The stack is three LSTM layers of 512 units (the first two return full
    sequences and use a recurrent dropout of 0.3), then batch normalisation
    and dropout, a 256-unit dense layer with ReLU, batch normalisation and
    dropout again, and a final dense layer of ``n_vocab`` units with softmax.
    The model is compiled with categorical cross-entropy and RMSprop.

    Args:
        network_input: array whose second and third dimensions give the
            input shape of the first LSTM layer.
        n_vocab: number of units in the output layer.

    Returns:
        The compiled ``Sequential`` model.
    """
    timesteps, features = network_input.shape[1], network_input.shape[2]

    # Layers are listed in the order they are stacked.
    layer_stack = (
        LSTM(
            512,
            input_shape=(timesteps, features),
            recurrent_dropout=0.3,
            return_sequences=True,
        ),
        LSTM(512, return_sequences=True, recurrent_dropout=0.3),
        LSTM(512),
        BatchNorm(),
        Dropout(0.3),
        Dense(256),
        Activation('relu'),
        BatchNorm(),
        Dropout(0.3),
        Dense(n_vocab),
        Activation('softmax'),
    )

    model = Sequential()
    for layer in layer_stack:
        model.add(layer)

    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    return model

### Model Training

In [9]:
def train(model, network_input, network_output, epochs=200, batch_size=128):
    """Fit the model, checkpointing it whenever the training loss improves.

    Args:
        model: compiled model exposing ``fit``.
        network_input: training inputs passed to ``fit``.
        network_output: training targets passed to ``fit``.
        epochs: number of training epochs (default 200).
        batch_size: number of samples per gradient update (default 128).

    Returns:
        Whatever ``model.fit`` returns (for Keras models, the ``History``
        object holding the per-epoch loss), so the caller can inspect or plot
        the training curve.
    """
    # The epoch number and loss are formatted into each checkpoint file name.
    filepath = "weights-improvement-{epoch:02d}-{loss:.4f}-bigger.hdf5"
    checkpoint = ModelCheckpoint(
        filepath,
        monitor='loss',
        verbose=0,
        save_best_only=True,  # only write a file when the monitored loss improves
        mode='min'
    )
    callbacks_list = [checkpoint]

    return model.fit(
        network_input,
        network_output,
        epochs=epochs,
        batch_size=batch_size,
        callbacks=callbacks_list,
    )
    
# Run the whole pipeline (parse, encode, build, train) when this code is
# executed as the main module.
if __name__ == '__main__':
    train_network()

Parsing midi_songs/route-16-dark-pop-remix-.mid
Parsing midi_songs/pokemon-center.mid
Parsing midi_songs/pallet-town.mid
Parsing midi_songs/route-1-boss-remix-.mid
Parsing midi_songs/pokemon-center-2-.mid
Parsing midi_songs/mt-moon.mid
Parsing midi_songs/viridian-forest.mid
Parsing midi_songs/route-24-2-.mid
Parsing midi_songs/lavender-town-2-v1-1-.mid
Parsing midi_songs/pokemon-tower-arranged-.mid
Parsing midi_songs/indigo-plateau.mid
Parsing midi_songs/yellow-pikachu-beach.mid
Parsing midi_songs/team-rocket-s-hideout-villainous-remix-.mid
Parsing midi_songs/route-12.mid
Parsing midi_songs/vermilion-city-2-.mid
Parsing midi_songs/bicycle-ride.mid
Parsing midi_songs/show-me-around-2-.mid
Parsing midi_songs/show-me-around.mid
Parsing midi_songs/male-trainer-encounter.mid
Parsing midi_songs/trainer-battle.mid
Parsing midi_songs/trainer-battle-arranged-xg-.mid
Parsing midi_songs/credits.mid
Parsing midi_songs/s-s-anne-remix-.mid
Parsing midi_songs/silph-co-.mid
Parsing midi_songs/celadon-

Epoch 150/200
Epoch 151/200
Epoch 152/200
Epoch 153/200
Epoch 154/200
Epoch 155/200
Epoch 156/200
Epoch 157/200
Epoch 158/200
Epoch 159/200
Epoch 160/200
Epoch 161/200
Epoch 162/200
Epoch 163/200
Epoch 164/200
Epoch 165/200
Epoch 166/200
Epoch 167/200
Epoch 168/200
Epoch 169/200
Epoch 170/200
Epoch 171/200
Epoch 172/200
Epoch 173/200
Epoch 174/200
Epoch 175/200
Epoch 176/200
Epoch 177/200
Epoch 178/200
Epoch 179/200
Epoch 180/200
Epoch 181/200
Epoch 182/200
Epoch 183/200
Epoch 184/200
Epoch 185/200
Epoch 186/200
Epoch 187/200
Epoch 188/200
Epoch 189/200
Epoch 190/200
Epoch 191/200
Epoch 192/200
Epoch 193/200
Epoch 194/200
Epoch 195/200
Epoch 196/200
Epoch 197/200
Epoch 198/200
Epoch 199/200
Epoch 200/200
