In [8]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout
from tensorflow.keras.layers import LSTM
from tensorflow.keras.utils import get_file
import numpy as np
import random
import sys
import os
import pdb

In [10]:
tf.version.VERSION

'2.1.0'

## Model creation

In [None]:
def get_model(maxlen, num_chars, num_layers):
    print('Build model...')
    model = Sequential()
    for layer_idx in range(num_layers):
        if layer_idx == 0:
            model.add(LSTM(num_units, return_sequences=True, input_shape=(maxlen, num_chars)))
        else:
            model.add(LSTM(num_units, return_sequences=False))
        model.add(Dropout(0.2))

    model.add(Dense(num_chars))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model

In [None]:
def sample(a, temperature=1.0):
    # helper function to sample an index from a probability array
    a = np.log(a) / temperature
    a = np.exp(a) / np.sum(np.exp(a))
    return np.argmax(np.random.multinomial(1, a, 1))

In [None]:
def run(is_character=False, maxlen=None, num_units=None, model_prefix=''):

    character_mode = is_character

    if character_mode:
        if maxlen == None:
            maxlen = 1024
        if num_units == None:
            num_units = 32
        step = 2*17 # step to create training data for truncated-BPTT
    else: # word mode
        if maxlen == None:
            maxlen = 128 # maxlength used in RNN input
        if num_units == None: 
            num_units = 512 #number of unit per layer LSTM 512 
        step = 8

    if character_mode:
        num_char_pred = maxlen*3/2
    else: 
        num_char_pred = 17*30 #this should be the number of elements predicted in the output. How "long" is my output sequence

    num_layers = 2
    # 
    if character_mode:
        prefix = 'char'
    else:
        prefix = 'word'

    path = 'metallica_drums_text.txt' # Corpus file
    text = open(path).read()
    print('corpus length:', len(text))

    if character_mode:
        chars = set(text)
    else:
        chord_seq = text.split(' ')
        chars = set(chord_seq) #contains the unique words in my dictionary. They are 119
        text = chord_seq #contains the full text in an array format. Each entry of my array is a word of type 0xb0110101010 

    char_indices = dict((c, i) for i, c in enumerate(chars))
    indices_char = dict((i, c) for i, c in enumerate(chars))
    num_chars = len(char_indices)
    print('total chars:', num_chars)

    # cut the text in semi-redundant sequences of maxlen characters

    sentences = []
    next_chars = []
    for i in range(0, len(text) - maxlen, step):
        sentences.append(text[i: i + maxlen])
        next_chars.append(text[i + maxlen])
    print('nb sequences:', len(sentences))
    print('Vectorization...')

    X = np.zeros((len(sentences), maxlen, num_chars), dtype=np.bool)
    y = np.zeros((len(sentences), num_chars), dtype=np.bool)
    for i, sentence in enumerate(sentences):
        for t, char in enumerate(sentence):
            X[i, t, char_indices[char]] = 1
        y[i, char_indices[next_chars[i]]] = 1

    # build the model: 2 stacked LSTM
    model = get_model(maxlen, num_chars, num_layers)

    result_directory = 'result_%s_%s_%d_%d_units/' % (prefix, model_prefix, maxlen, num_units)
    filepath_model = '%sbest_model.hdf' % result_directory
    description_model = '%s, %d layers, %d units, %d maxlen, %d steps' % (prefix, num_layers, num_units, maxlen, step)
    checker = keras.callbacks.ModelCheckpoint(filepath_model, monitor='loss', verbose=0, save_best_only=True, mode='auto')
    early_stop = keras.callbacks.EarlyStopping(monitor='loss', patience=15, verbose=0, mode='auto')

    if not os.path.exists(result_directory):
        os.mkdir(result_directory)

    # write a description file.
    with open(result_directory+description_model, 'w') as f_description:
        pass

    # train the model, output generated text after each iteration
    batch_size = 128
    loss_history = []
    pt_x = [1,29,30,40,100,100,200,300,400]
    nb_epochs = [np.sum(pt_x[:i+1]) for i in range(len(pt_x))]

    # not random seed, but the same seed for all.
    start_index = random.randint(0, len(text) - maxlen - 1)

    for iteration, nb_epoch in zip(pt_x,nb_epochs):
        if os.path.exists('stop_asap.keunwoo'):
            os.remove('stop_asap.keunwoo')
            break

        print('-' * 50)
        print('Iteration', iteration)

        result = model.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch, callbacks=[checker, early_stop]) 
        loss_history = loss_history + result.history['loss']

        print 'Saving model after %d epochs...' % nb_epoch
        model.save_weights('%smodel_after_%d.hdf'%(result_directory, nb_epoch), overwrite=True)

        for diversity in [0.9, 1.0, 1.2]:
            with open(('%sresult_%s_iter_%02d_diversity_%4.2f.txt' % (result_directory, prefix, iteration, diversity)), 'w') as f_write:

                print()
                print('----- diversity:', diversity)
                f_write.write('diversity:%4.2f\n' % diversity)
                if character_mode:
                    generated = ''
                else:
                    generated = []
                sentence = text[start_index: start_index + maxlen]
                seed_sentence = text[start_index: start_index + maxlen]

                if character_mode:
                    generated += sentence
                else:
                    generated = generated + sentence


                print('----- Generating with seed:')

                if character_mode:
                    print(sentence)
                    sys.stdout.write(generated)
                else:
                    print(' '.join(sentence))

                for i in xrange(num_char_pred):
                    # if generated.endswith('_END_'):
                    # 	break
                    x = np.zeros((1, maxlen, num_chars))

                    for t, char in enumerate(sentence):
                        x[0, t, char_indices[char]] = 1.

                    preds = model.predict(x, verbose=0)[0]
                    next_index = sample(preds, diversity)
                    next_char = indices_char[next_index]

                    if character_mode:
                        generated += next_char
                        sentence = sentence[1:] + next_char
                    else:
                        generated.append(next_char)
                        sentence = sentence[1:]
                        sentence.append(next_char)

                    if character_mode:
                        sys.stdout.write(next_char)
                    # else:
                    # 	for ch in next_char:
                    # 		sys.stdout.write(ch)	

                    sys.stdout.flush()

                if character_mode:
                    f_write.write(seed_sentence + '\n')
                    f_write.write(generated)
                else:
                    f_write.write(' '.join(seed_sentence) + '\n')
                    f_write.write(' ' .join(generated))

        np.save('%sloss_%s.npy'%(result_directory, prefix), loss_history)

    print 'Done! You might want to run main_post_process.py to get midi files. '
    print 'You need python-midi (https://github.com/vishnubob/python-midi) to run it.'


# Testing code
Basically i'm taking part of the main code and printing them here to understand better what it does.

In [17]:
path = 'metallica_drums_text.txt' # Corpus file
text = open(path).read()

In [18]:
chord_seq = text.split(' ')
chars = set(chord_seq)
text = chord_seq
print(chars)
print(len(chars))

{'', '0b100001000', '0b001000010', '0b110000101', '0b000011000', '0b001010000', '0b101010000', '0b100001101', '0b011000011', '0b010001100', '0b011000100', '0b100010011', '0b110001100', '0b011001011', '0b010000011', '0b000000100', '0b110010100', '0b000000001', '0b010100011', '0b100001100', '0b100010100', '0b010100000', '0b000010100', '0b100000010', '0b101000000', '0b000101000', '0b110000010', '0b001010010', '0b000000011', '0b110000000', '0b000010010', '0b111000011', '0b100000000', '0b111000000', '0b110001000', '0b001000100', '0b001000001', '0b111010000', '0b000100000', '0b001001101', '0b010000001', '0b001001000', '0b100000101', '0b111000001', '0b010001101', '0b000001100', '0b010010000', '0b000000101', '0b010000010', '0b101100000', '0b000100011', '0b001100000', '0b011001010', '0b100011000', '0b010001001', '0b100000100', '0b000010001', '0b101000001', '0b101011000', '0b001000111', '0b101011011', '0b000001001', '0b101000101', '0b001000101', '0b010000101', '0b111100001', '0b110100001', '0b01

In [13]:
print(text)

['0b010000000', '0b010000000', '0b000000000', '0b010000000', '0b010000000', '0b000001000', '0b000000000', '0b000001000', '0b010000000', '0b010000000', '0b000000000', '0b010000000', '0b010000000', '0b000001000', '0b000000000', '0b000001000', 'BAR', '0b010000000', '0b010000000', '0b000000000', '0b010000000', '0b010000000', '0b000001000', '0b000000000', '0b000001000', '0b010000000', '0b000000000', '0b000000000', '0b000001000', '0b000000000', '0b000001000', '0b000001000', '0b000000000', 'BAR', '0b100000001', '0b000000000', '0b000000000', '0b000000000', '0b010000001', '0b000000000', '0b000000000', '0b000000000', '0b100000001', '0b000000000', '0b000000000', '0b000000000', '0b010000001', '0b000000000', '0b000000000', '0b000000000', 'BAR', '0b100000001', '0b000000000', '0b000000000', '0b000000000', '0b010000001', '0b000000000', '0b000000000', '0b000000000', '0b100000001', '0b000000000', '0b000000000', '0b000000000', '0b010000001', '0b000000000', '0b000000000', '0b000000000', 'BAR', '0b10000000

In [19]:
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))
num_chars = len(char_indices)
print('total chars:', num_chars)

total chars: 119


In [20]:
print(char_indices)

{'': 0, '0b100001000': 1, '0b001000010': 2, '0b110000101': 3, '0b000011000': 4, '0b001010000': 5, '0b101010000': 6, '0b100001101': 7, '0b011000011': 8, '0b010001100': 9, '0b011000100': 10, '0b100010011': 11, '0b110001100': 12, '0b011001011': 13, '0b010000011': 14, '0b000000100': 15, '0b110010100': 16, '0b000000001': 17, '0b010100011': 18, '0b100001100': 19, '0b100010100': 20, '0b010100000': 21, '0b000010100': 22, '0b100000010': 23, '0b101000000': 24, '0b000101000': 25, '0b110000010': 26, '0b001010010': 27, '0b000000011': 28, '0b110000000': 29, '0b000010010': 30, '0b111000011': 31, '0b100000000': 32, '0b111000000': 33, '0b110001000': 34, '0b001000100': 35, '0b001000001': 36, '0b111010000': 37, '0b000100000': 38, '0b001001101': 39, '0b010000001': 40, '0b001001000': 41, '0b100000101': 42, '0b111000001': 43, '0b010001101': 44, '0b000001100': 45, '0b010010000': 46, '0b000000101': 47, '0b010000010': 48, '0b101100000': 49, '0b000100011': 50, '0b001100000': 51, '0b011001010': 52, '0b100011000'

In [21]:
print(indices_char)

{0: '', 1: '0b100001000', 2: '0b001000010', 3: '0b110000101', 4: '0b000011000', 5: '0b001010000', 6: '0b101010000', 7: '0b100001101', 8: '0b011000011', 9: '0b010001100', 10: '0b011000100', 11: '0b100010011', 12: '0b110001100', 13: '0b011001011', 14: '0b010000011', 15: '0b000000100', 16: '0b110010100', 17: '0b000000001', 18: '0b010100011', 19: '0b100001100', 20: '0b100010100', 21: '0b010100000', 22: '0b000010100', 23: '0b100000010', 24: '0b101000000', 25: '0b000101000', 26: '0b110000010', 27: '0b001010010', 28: '0b000000011', 29: '0b110000000', 30: '0b000010010', 31: '0b111000011', 32: '0b100000000', 33: '0b111000000', 34: '0b110001000', 35: '0b001000100', 36: '0b001000001', 37: '0b111010000', 38: '0b000100000', 39: '0b001001101', 40: '0b010000001', 41: '0b001001000', 42: '0b100000101', 43: '0b111000001', 44: '0b010001101', 45: '0b000001100', 46: '0b010010000', 47: '0b000000101', 48: '0b010000010', 49: '0b101100000', 50: '0b000100011', 51: '0b001100000', 52: '0b011001010', 53: '0b100011