In [1]:
import sys
import matplotlib.pyplot as plt
import numpy as np
import glob
import os
import pickle
import tensorflow as tf
from music21 import converter, instrument, note, chord, stream
from keras.layers import Input, Dense, Reshape, Dropout, LSTM, Bidirectional
from keras.layers import BatchNormalization, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model, load_model
from keras.optimizers import Adam
from keras.utils import np_utils
import time
from tqdm.notebook import trange, tqdm

In [2]:
# List all NVIDIA GPUs as avaialble in this computer (or Colab's session)
!nvidia-smi -L

GPU 0: NVIDIA GeForce GTX 1080 Ti (UUID: GPU-1f6e482f-2053-a9f6-b7ed-5888013cd4db)


In [3]:
print( f"Detected GPU(s): {tf.config.experimental.list_physical_devices('GPU')}" )

Detected GPU(s): [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [4]:
def get_notes(path):
    """ Get all the notes and chords from the midi files """
    notes = []

    for file in glob.glob(path + "/*.mid"):
        try:
            midi = converter.parse(file)
            print("Parsing %s" % file)
            
            notes_to_parse = None
            try: # file has instrument parts
                s2 = instrument.partitionByInstrument(midi)
                notes_to_parse = s2.parts[0].recurse() 
            except: # file has notes in a flat structure
                notes_to_parse = midi.flat.notes

            for element in notes_to_parse:
                if isinstance(element, note.Note):
                    notes.append(str(element.pitch))
                elif isinstance(element, chord.Chord):
                    notes.append('.'.join(str(n) for n in element.normalOrder))
        except:
            print("ERROR Parsing(SKIP) %s" % file)

    return notes

In [5]:
def prepare_sequences(notes, n_vocab, seq_len = 100):
    """ Prepare the sequences used by the Neural Network """
    sequence_length = seq_len

    # Get all pitch names
    pitchnames = sorted(set(item for item in notes))

    # Create a dictionary to map pitches to integers
    note_to_int = dict((note, number) for number, note in enumerate(pitchnames))

    network_input = []
    network_output = []

    # create input sequences and the corresponding outputs
    for i in range(0, len(notes) - sequence_length, 1):
        sequence_in = notes[i:i + sequence_length]
        sequence_out = notes[i + sequence_length]
        network_input.append([note_to_int[char] for char in sequence_in])
        network_output.append(note_to_int[sequence_out])

    n_patterns = len(network_input)

    # Reshape the input into a format compatible with LSTM layers
    network_input = np.reshape(network_input, (n_patterns, sequence_length, 1))
    
    # Normalize input between -1 and 1
    network_input = (network_input - float(n_vocab)/2) / (float(n_vocab)/2)
    # Normalize input between 0 and 1
    #network_input = network_input/float(n_vocab)
    network_output = np_utils.to_categorical(network_output)

    return (network_input, network_output)

In [6]:
def create_midi(prediction_output, filename):
    """ convert the output from the prediction to notes and create a midi file
        from the notes """
    offset = 0
    output_notes = []

    # create note and chord objects based on the values generated by the model
    for item in prediction_output:
        pattern = item[0]
        # pattern is a chord
        if ('.' in pattern) or pattern.isdigit():
            notes_in_chord = pattern.split('.')
            notes = []
            for current_note in notes_in_chord:
                output_notes.append(instrument.Piano())                
                new_note = note.Note(int(current_note))
                
                #new_note.storedInstrument = instrument.StringInstrument()
                notes.append(new_note)
            new_chord = chord.Chord(notes)
            new_chord.offset = offset
            output_notes.append(new_chord)
        # pattern is a note
        else:
            output_notes.append(instrument.Piano())
            new_note = note.Note(pattern)
            new_note.offset = offset            
            #new_note.storedInstrument = instrument.StringInstrument()
            output_notes.append(new_note)

        # increase offset each iteration so that notes do not stack
        offset += 0.5

    midi_stream = stream.Stream(output_notes)
    midi_stream.write('midi', fp='{}.mid'.format(filename))

In [10]:
def make_generator_model(seq_len,latent_dim = 500):
        model = Sequential()
        model.add(Dense(512, input_dim=latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(2048))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        #model.add(Dense(np.prod(self.seq_shape), activation='sigmoid'))
        model.add(Dense(seq_len, activation='tanh'))
        model.add(Reshape((seq_len,1)))
        model.summary()
        
        noise = Input(shape=(latent_dim,))
        seq = model(noise)
        return Model(noise, seq)

In [11]:
seq_len = 100
latent_dim = 500
generator = make_generator_model(seq_len,latent_dim)

noise = tf.random.normal([1, latent_dim])
generated_seq = generator(noise, training=False)

print(generated_seq)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 512)               256512    
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 512)               0         
_________________________________________________________________
batch_normalization (BatchNo (None, 512)               2048      
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_2 (Dense)              (None, 2048)              2

In [12]:
def make_discriminator_model(seq_len):
    model = Sequential()
    model.add(LSTM(512, input_shape=(seq_len,1), return_sequences=True))
    model.add(Bidirectional(LSTM(512)))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1))
    model.summary()

    seq = Input(shape=(seq_len,1))
    validity = model(seq)

    return Model(seq, validity)

In [13]:
discriminator = make_discriminator_model(seq_len)
decision = discriminator(generated_seq)
print (decision)

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  (None, 100, 512)          1052672   
_________________________________________________________________
bidirectional (Bidirectional (None, 1024)              4198400   
_________________________________________________________________
dense_4 (Dense)              (None, 1024)              1049600   
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
dense_5 (Dense)              (None, 512)               524800    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 1)                

In [14]:
# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [15]:
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

In [16]:
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [17]:
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [18]:
checkpoint_dir = './training_checkpoints3'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [16]:
path = 'Piano Sonatas'
notes = get_notes(path)

Parsing Piano Sonatas\K19d Piano Sonata Duet.mid
Parsing Piano Sonatas\K279 Piano sonata n01 1mov.mid
Parsing Piano Sonatas\K279 Piano sonata n01 2mov.mid
Parsing Piano Sonatas\K279 Piano sonata n01 3mov.mid
Parsing Piano Sonatas\K280 Piano sonata n02 1mov.mid
Parsing Piano Sonatas\K280 Piano sonata n02 2mov.mid
Parsing Piano Sonatas\K280 Piano sonata n02 3mov.mid
Parsing Piano Sonatas\K281 Piano Sonata n03 1mov.mid
Parsing Piano Sonatas\K281 Piano Sonata n03 2mov.mid
Parsing Piano Sonatas\K281 Piano Sonata n03 3mov.mid
Parsing Piano Sonatas\K282 Piano Sonata n04 .mid
Parsing Piano Sonatas\K284 Piano Sonata n06 .mid
Parsing Piano Sonatas\K309 Piano Sonata n10 1mov.mid
Parsing Piano Sonatas\K309 Piano Sonata n10 2mov.mid
Parsing Piano Sonatas\K309 Piano Sonata n10 3mov.mid
Parsing Piano Sonatas\K330 Piano Sonata n10 1mov.mid
Parsing Piano Sonatas\K330 Piano Sonata n10 2mov.mid
Parsing Piano Sonatas\K330 Piano Sonata n10 3mov.mid
Parsing Piano Sonatas\K331 Piano Sonata n11 .mid
Parsing P

In [19]:
filename = 'notes_chopin'
with open('notes/'+filename, 'rb') as filepath:
    notes = pickle.load(filepath)

In [17]:
n_vocab = len(set(notes))
X_train, y_train = prepare_sequences(notes, n_vocab)
print(X_train.shape[0])

236293


In [18]:
BATCH_SIZE = 256
BUFFER_SIZE = int(X_train.shape[0]/10)
train_dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

In [19]:
EPOCHS = 50
noise_dim = 500
#num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
#seed = tf.random.normal([num_examples_to_generate, noise_dim])

In [20]:
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(real_data):  
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_data = generator(noise, training=True)

        real_output = discriminator(real_data, training=True)
        fake_output = discriminator(fake_data, training=True)

        disc_loss = discriminator_loss(real_output, fake_output)        
        gen_loss = generator_loss(fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss

In [21]:
def plot_loss(g_loss, d_loss):    
    plt.plot(d_loss, c='red')
    plt.plot(g_loss, c='blue')
    plt.title("GAN Loss per Epoch (Chopin)")
    plt.legend(['Discriminator', 'Generator'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig('plot/GAN_Loss_per_Epoch_Chopin.png')
    plt.close()

In [22]:
def train(dataset, epochs):
    g_loss = [] 
    d_loss = []
    for epoch in tqdm(range(epochs), desc='Epoch:'):     
        start = time.time()
        batch_g_loss = [] 
        batch_d_loss = []
        for note_batch in tqdm(dataset, desc='Batch:', leave = False):
            gen_loss, disc_loss = train_step(note_batch)
            batch_g_loss.append(gen_loss.numpy())
            batch_d_loss.append(disc_loss.numpy())
        # Save the model every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)
        g_loss.append( float (sum(batch_g_loss)/len(batch_g_loss)) )
        d_loss.append( float (sum(batch_d_loss)/len(batch_d_loss)) )
        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

    # Generate after the final epoch
    generate(generator, notes, 'gan_chopin')
    generator.save('model/model_gen_chopin.hdf5')
    plot_loss(g_loss, d_loss)

In [20]:
def generate(model, input_notes, filename, latent_dim = 500):
    # Get pitch names and store in a dictionary
    notes = input_notes
    pitchnames = sorted(set(item for item in notes))
    int_to_note = dict((number, note) for number, note in enumerate(pitchnames))

    # Use random noise to generate sequences
    noise = tf.random.normal([1, latent_dim])
    #noise = np.random.normal(0, 1, (1, latent_dim))
    predictions = model.predict(noise)
    #for normalized -1 to 1
    p = len(pitchnames)/2
    pred_notes = [x*p+p for x in predictions[0]]
    #for normalized 0 to 1
    #p = len(pitchnames)
    #pred_notes = [x*p for x in predictions[0]]
    pred_notes = [int_to_note[int(x)] for x in pred_notes]
    print('pred_notes:',pred_notes)
    create_midi(pred_notes, filename)

In [24]:
train(train_dataset, EPOCHS)

Epoch::   0%|          | 0/50 [00:00<?, ?it/s]

Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 1 is 1177.4114904403687 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 2 is 1155.535148859024 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 3 is 1176.2864968776703 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 4 is 1179.0138175487518 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 5 is 1176.6950953006744 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 6 is 1210.861605167389 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 7 is 1174.195665359497 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 8 is 1145.7797470092773 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 9 is 1201.9405312538147 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 10 is 1195.4035806655884 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 11 is 1168.7515313625336 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 12 is 1161.1961178779602 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 13 is 1218.5167307853699 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 14 is 1179.2723305225372 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 15 is 1178.849612236023 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 16 is 1180.8568227291107 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 17 is 1180.9270164966583 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 18 is 1195.5957033634186 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 19 is 1194.5388247966766 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 20 is 1195.7495908737183 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 21 is 1219.742265701294 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 22 is 1253.3908994197845 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 23 is 1219.8371691703796 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 24 is 1216.1593153476715 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 25 is 1197.077392578125 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 26 is 1226.2160828113556 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 27 is 1216.1699674129486 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 28 is 1224.25031042099 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 29 is 1218.6020030975342 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 30 is 1224.2223620414734 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 31 is 1209.0956931114197 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 32 is 1188.3309259414673 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 33 is 1198.052397966385 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 34 is 1211.461730480194 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 35 is 1210.3528175354004 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 36 is 1199.704835653305 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 37 is 1215.1956152915955 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 38 is 1216.5816662311554 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 39 is 1230.9336602687836 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 40 is 1220.8138840198517 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 41 is 1234.3891031742096 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 42 is 1227.6885380744934 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 43 is 1181.852374792099 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 44 is 1150.089703321457 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 45 is 1147.8264601230621 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 46 is 1148.6581389904022 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 47 is 1153.5840742588043 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 48 is 1150.323100566864 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 49 is 1123.4426946640015 sec


Batch::   0%|          | 0/924 [00:00<?, ?it/s]

Time for epoch 50 is 1133.115144252777 sec
pred_notes: ['9.0.2.4', '3.7.8.10', 'G5', 'F3', '9.1.2', 'G5', 'B-5', 'D4', 'G#3', 'G4', '9.11.1', 'F6', 'C4', 'G#1', 'C#2', 'F6', 'C#3', 'G1', 'C4', 'C#2', 'F6', 'C2', 'E-2', 'B6', 'D6', 'B-5', 'E-3', 'G1', 'B5', 'C6', 'D2', 'F1', 'G3', 'B3', 'C6', 'G4', 'B4', 'F6', 'F3', 'C#6', 'G#3', 'C#4', 'C#6', 'F#1', 'C5', 'G5', 'E-5', 'F6', 'F#6', 'E1', 'B3', 'G#3', 'F#1', 'C#3', 'G4', 'G5', 'B-5', 'F1', 'G6', 'G#3', 'E1', 'G#2', 'C#4', 'F#2', 'E7', 'G#1', 'B5', 'A6', 'B6', 'G1', 'B4', 'E6', 'F#3', 'B-5', 'G3', 'D6', 'C#4', 'G2', 'C#4', 'E5', 'F#6', 'C#2', 'G#3', 'C#3', 'F3', 'F5', 'C#6', 'B5', 'F#3', 'A1', 'G2', 'F6', 'G3', 'E2', 'F#1', 'G#3', 'E-2', '7.11.0', 'D7', 'F#6']


In [27]:
generate(generator, notes, 'gan_chopin_3')

pred_notes: ['9.0.2.5', '6.8.0', '9.0.3.5', '6.8.0', '7.8', '5.6', '3.6.10', '8.0', '4.8', '9.11.2.4', '8.11.2', '3.7.8.10', '4.6.7', '9.0.2.5', '11.0.2', '5.8', '3.6.7', '4.8.11', '5.8', '7.10', '5.7.9', '3.5.8', '8.9', '7.9.0.3', '7.11.2', '7.10.0.3', '4.5.7.10', '3.5.9', '3.9', '6.10', '3.4.8', '6.9', '7.8.11', '7.10.1', '2.3.8', '7.11.1', '7.0', '5.7', '7.10.2', '6.9.0', '4.6.9.11', '2.4.9', '8.11.2.4', '7.8.11', '7.11.1', '2.5.8', '7.10.1', '6.7.9', '5', '3.7.9', '5.7', '5.7.11.0', '4.8.11', '5.8.11', '2.6', '5.7', '9.1.2', '1.5.7', '4.6.11', '3.5.9', '4', '11.4', '6.7', '3.7.9', '3.6.9', '7.8', '5.8', '3.6.8', '3.6.9', '3.6.10', '6.9.0', '2.5', '11.4', 'C#2', '6.7.11.2', '7.11.2', '7.11.1', '9.11.1', '5.7.9.11.2', '2.5.8.11', '4.5.9', '3.6.10', '6.7.11.2', '4.6.11', '8.1', '3.5.9', '9.0.3.5', '2.3.8', '5.7.11', '6.9.0.2', '6.9', '3.5.8', '2.4.7', '9.11.4', '7.9.0', '7.10.0', '5.7', '5.10', '2.5.9', '3.6']


## Export chopin note (Required when generating music)

In [48]:
with open('notes/notes_chopin', 'wb') as filepath:
    pickle.dump(notes, filepath)

## Restore checkpoint

In [25]:
#checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
checkpoint.restore(tf.train.get_checkpoint_state(checkpoint_dir, latest_filename='ckpt-3' ))

<tensorflow.python.training.tracking.util.InitializationOnlyStatus at 0x2d0e110bfd0>

In [26]:
 generator.save('model/model_gen_chopin_ckpt3.hdf5')