In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import os
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH '] = 'true'

from os import listdir
from os.path import isfile, join
import sys

import numpy as np
from nltk.util import ngrams
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import SpatialDropout1D, Dense, Flatten, BatchNormalization, Input, concatenate, Bidirectional, Dropout, Attention
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.optimizers import Nadam

from tensorflow.compat.v1.keras.layers import CuDNNLSTM
tf.compat.v1.disable_eager_execution()

# Data Prep

In [2]:
input_dir = "training_data/"

In [3]:
def get_file_names(mypath):
    return [f for f in listdir(mypath) if isfile(join(mypath, f))]


def get_notes(notes_path):
    all_song_notes = []
    file_names = get_file_names(notes_path)
    
    for file_name in file_names:
        file_notes = []
        with open(notes_path + file_name, 'r') as file:
            lines = file.readlines()
            file_notes = np.asarray([x.replace("\n", "").split(" ") for x in lines[3:]])
        all_song_notes.append(file_notes)
    return np.asarray(all_song_notes)


notes_by_song = get_notes(join(input_dir, "timings/"))
notes_by_song.shape

(98,)

In [4]:
def get_binary_rep(arrow_values):
    return (((np.asarray(arrow_values).astype(int)[:,None] & (1 << np.arange(4)))) > 0).astype(int)


def get_extended_binary_rep(arrow_combs):
    extended_binary_rep = []
    for i, arrow_comb in enumerate(arrow_combs):
        binary_rep = np.zeros((4, 4))
        for j, num in enumerate(list(arrow_comb)):
            binary_rep[int(num), j] = 1
        extended_binary_rep.append(binary_rep.ravel())
    return np.asarray(extended_binary_rep)


def create_tokens(timings):
    timings = timings.astype("float32")
    tokens = np.zeros((timings.shape[0], 3))
    tokens[0][0] = 1 # set start token
    next_note_token = np.append(timings[1:] - timings[:-1], np.asarray([0]))
    prev_note_token = np.append(np.asarray([0]),  next_note_token[: -1])
    tokens[:, 1] = prev_note_token.reshape(1, -1)
    tokens[:, 2] = next_note_token.reshape(1, -1)
    return tokens.astype("float32")


def get_notes_ngram(binary_notes, lookback):
    padding = np.zeros((lookback, binary_notes.shape[1]))
    data_w_padding = np.append(padding, binary_notes, axis = 0)
    return np.asarray(list(ngrams(data_w_padding, lookback)))


def get_all_note_combs():
    all_note_combs = []

    for i in range(0, 4):
        for j in range(0, 4):
            for k in range(0, 4):
                for l in range(0, 4):
                    all_note_combs.append(str(i)+str(j)+str(k)+str(l))

    all_note_combs = all_note_combs[1:]
    
    return all_note_combs


def data_prep(notes_by_song, lookback = 5):
    from sklearn.preprocessing import OneHotEncoder
    encoder = OneHotEncoder(categories='auto', sparse = False).fit(np.asarray(get_all_note_combs()).reshape(-1, 1))
    
    
    all_arrows = []
    all_tokens = []
    all_labels = []
    
    for notes in notes_by_song:
        ex_binary_notes = get_extended_binary_rep(notes[:, 0][:-1])
        
        notes_ngram = get_notes_ngram(ex_binary_notes, lookback)
        tokens = create_tokens(notes[:, 1].astype("float32"))
        labels = encoder.transform(notes[:, 0].reshape(-1, 1))
        
        
        all_arrows.append(notes_ngram)
        all_tokens.append(tokens)
        all_labels.append(labels)
        
    return np.concatenate(all_arrows), np.concatenate(all_tokens), np.concatenate(all_labels)

lookback = 64

all_arrows, all_tokens, all_labels = data_prep(notes_by_song, lookback = lookback)
all_arrows.shape, all_tokens.shape, all_labels.shape

((69957, 64, 16), (69957, 3), (69957, 255))

# Training

In [5]:
index_train, index_test, labels_train, labels_test = \
    train_test_split(np.asarray(range(len(all_labels))), all_labels, test_size=0.2, random_state = 42, shuffle = False)

arrows_train = all_arrows[index_train]
arrows_test = all_arrows[index_test]

tokens_train = np.expand_dims(all_tokens[index_train], axis = 1)
tokens_test = np.expand_dims(all_tokens[index_test], axis = 1)

(arrows_train.shape, tokens_train.shape, labels_train.shape), (arrows_test.shape, tokens_test.shape, labels_test.shape)


(((55965, 64, 16), (55965, 1, 3), (55965, 255)),
 ((13992, 64, 16), (13992, 1, 3), (13992, 255)))

In [7]:
def build(arrows_shape, token_shape, output_shape, silent = True):
    arrows = Input(shape = (arrows_shape[1], arrows_shape[2],))
    tokens = Input(shape = (1, token_shape[2],))
    
    x = CuDNNLSTM(256, kernel_initializer='glorot_normal', return_sequences = True)(arrows)
    #x = SeqSelfAttention(attention_activation='relu')(x)
    x = SpatialDropout1D(0.5)(x)
    #x = SpatialDropout1D(0.25)(x)
    x = CuDNNLSTM(128, kernel_initializer='glorot_normal', return_sequences = True)(x)
    #x = SpatialDropout1D(0.5)(x)
    x = Model(inputs = arrows, outputs = x)
    
    y = CuDNNLSTM(128, kernel_initializer='glorot_normal', return_sequences = True)(tokens)
    #y = SeqSelfAttention(attention_activation='relu')(y)
    y = SpatialDropout1D(0.5)(y)
    #y = SpatialDropout1D(0.25)(y)
    #y = CuDNNLSTM(128, kernel_initializer='glorot_normal', return_sequences = True)(y)
    #y = SpatialDropout1D(0.5)(y)
    y = Model(inputs = tokens, outputs = y)
    
    combined = concatenate([x.output, y.output], axis = 1)
    
    z = Flatten()(combined)
    z = Dropout(0.5)(z)
    z = Dense(output_shape[1], activation = "softmax")(z)
    model = Model(inputs = [x.input, y.input], outputs = z)
    model.compile(loss='categorical_crossentropy', optimizer='Nadam', metrics=['accuracy'])
    
    if not silent: 
        model.summary()
    
    return model

In [8]:
model = build(arrows_train.shape, tokens_train.shape, labels_train.shape, silent = False)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 64, 16)]     0                                            
__________________________________________________________________________________________________
cu_dnnlstm (CuDNNLSTM)          (None, 64, 256)      280576      input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 1, 3)]       0                                            
__________________________________________________________________________________________________
spatial_dropout1d (SpatialDropo (None, 64, 256)      0           cu_dnnlstm[0][0]                 
___________

In [9]:
batch_size = 64
model = build(arrows_train.shape, tokens_train.shape, labels_train.shape, silent = True)

callbacks = [EarlyStopping(monitor='val_loss', patience=5, verbose=0)]

history = model.fit([arrows_train, tokens_train], 
                    labels_train,
                    validation_data=([arrows_test, tokens_test], labels_test),
                    epochs = 100,
                    callbacks = callbacks,
                    batch_size= batch_size,
                    verbose = 1)

Train on 55965 samples, validate on 13992 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100


In [10]:
re_model = build(arrows_train.shape, tokens_train.shape, labels_train.shape, silent = True)
re_model.fit([all_arrows, np.expand_dims(all_tokens, axis=1)],
             all_labels,
             epochs = len(history.history['val_loss']),
             batch_size = batch_size,
             validation_split=0,
             verbose = 1)

Train on 69957 samples
Epoch 1/14
Epoch 2/14
Epoch 3/14
Epoch 4/14
Epoch 5/14
Epoch 6/14
Epoch 7/14
Epoch 8/14
Epoch 9/14
Epoch 10/14
Epoch 11/14
Epoch 12/14
Epoch 13/14
Epoch 14/14


<tensorflow.python.keras.callbacks.History at 0x1d61c876fd0>

# Predicting

In [15]:
re_model.save("models/retrained_arrow_model.h5")

In [12]:
from sklearn.preprocessing import OneHotEncoder
encoder = OneHotEncoder(categories='auto', sparse = False).fit(np.asarray(get_all_note_combs()).reshape(-1, 1))
pred_notes = []
ex_timings = notes_by_song[1][:,1]
ex_tokens = np.expand_dims(np.expand_dims(create_tokens(ex_timings), axis = 1), axis = 1)
notes_ngram = np.expand_dims(get_notes_ngram(np.zeros((1, 16)), lookback)[-1], axis = 0)
for i, token in enumerate(ex_tokens):
    #pred_arrow = np.argmax(model.predict([notes_ngram, token]))
    pred = re_model.predict([notes_ngram, token])
    pred_arrow = np.random.choice(all_labels.shape[1], 1, p=pred[0])[0]
    #pred_arrow = sample(model.predict([notes_ngram, token]))
    binary_rep = encoder.categories_[0][pred_arrow]
    pred_notes.append(binary_rep) 
    binary_note = get_extended_binary_rep([binary_rep])
    notes_ngram = np.roll(notes_ngram, -1, axis = 0)
    notes_ngram[0][-1] = binary_note
    print(i, notes_ngram[0][-1], binary_rep)

0 [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] 0002
1 [1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] 0103
2 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
3 [0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1100
4 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
5 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
6 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
7 [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] 0002
8 [0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] 1003
9 [0. 1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1010
10 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
11 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
12 [1. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0011
13 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
14 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
15 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
16 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
17 [0. 1. 1. 1. 1. 0. 0.

156 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
157 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
158 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
159 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
160 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
161 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
162 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
163 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
164 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
165 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
166 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
167 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
168 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
169 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
170 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
171 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
172 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 00

309 [1. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0011
310 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
311 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
312 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
313 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
314 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
315 [1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0101
316 [0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] 1003
317 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
318 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
319 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
320 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
321 [1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0.] 0230
322 [0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] 1300
323 [0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1100
324 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
325 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 00

465 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
466 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
467 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
468 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
469 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
470 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
471 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
472 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
473 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
474 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
475 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
476 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
477 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
478 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
479 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
480 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
481 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 00

620 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
621 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
622 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
623 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
624 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
625 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
626 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
627 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
628 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
629 [1. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] 0200
630 [1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] 0300
631 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
632 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
633 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
634 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
635 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
636 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 00

776 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
777 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
778 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
779 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
780 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
781 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
782 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
783 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
784 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
785 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
786 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
787 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
788 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
789 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
790 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
791 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
792 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 01

931 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
932 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
933 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
934 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010


In [14]:
from sklearn.preprocessing import OneHotEncoder
encoder = OneHotEncoder(categories='auto', sparse = False).fit(np.asarray(get_all_note_combs()).reshape(-1, 1))
pred_notes = []
ex_timings = notes_by_song[1][:,1]
ex_tokens = np.expand_dims(np.expand_dims(create_tokens(ex_timings), axis = 1), axis = 1)
notes_ngram = np.expand_dims(get_notes_ngram(np.zeros((1, 16)), lookback)[-1], axis = 0)
for i, token in enumerate(ex_tokens):
    #pred_arrow = np.argmax(model.predict([notes_ngram, token]))
    pred = model.predict([notes_ngram, token])
    pred_arrow = np.random.choice(all_labels.shape[1], 1, p=pred[0])[0]
    #pred_arrow = sample(model.predict([notes_ngram, token]))
    binary_rep = encoder.categories_[0][pred_arrow]
    pred_notes.append(binary_rep) 
    binary_note = get_extended_binary_rep([binary_rep])
    notes_ngram = np.roll(notes_ngram, -1, axis = 0)
    notes_ngram[0][-1] = binary_note
    print(i, notes_ngram[0][-1], binary_rep)

0 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
1 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
2 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
3 [1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] 0020
4 [0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] 1030
5 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
6 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
7 [0. 1. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.] 2001
8 [0. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.] 3010
9 [1. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] 0200
10 [1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0.] 0310
11 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
12 [0. 1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] 2000
13 [0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0.] 3001
14 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
15 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
16 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
17 [1. 1. 1. 0. 0. 0. 0.

154 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
155 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
156 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
157 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
158 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
159 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
160 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
161 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
162 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
163 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
164 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
165 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
166 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
167 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
168 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
169 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
170 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 10

307 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
308 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
309 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
310 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
311 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
312 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
313 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
314 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
315 [0. 1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] 2000
316 [0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] 3000
317 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
318 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
319 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
320 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
321 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
322 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
323 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 00

460 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
461 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
462 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
463 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
464 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
465 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
466 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
467 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
468 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
469 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
470 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
471 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
472 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
473 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
474 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
475 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
476 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 10

610 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
611 [0. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 1001
612 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
613 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
614 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
615 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
616 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
617 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
618 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
619 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
620 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
621 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
622 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
623 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
624 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
625 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
626 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 10

755 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
756 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
757 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
758 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
759 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
760 [1. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0011
761 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
762 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
763 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
764 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
765 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
766 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
767 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
768 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
769 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
770 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
771 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 10

906 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
907 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
908 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
909 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
910 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
911 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
912 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
913 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
914 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
915 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
916 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
917 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
918 [0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 1000
919 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 0001
920 [1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0010
921 [1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 0100
922 [1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] 00