In [None]:
import os
import numpy as np
from util import *
import random
import gc
import tensorflow as tf
from tensorflow.keras import layers, models, Model, Input
import copy as c
import pickle

In [2]:
shuffle = True
seed = 420
batch_size = 1024
steps_per_epoch = 200
nepochs = 500
memlen = 64          # history window length; the model input spans memlen + 1 frames
mem_size = 5000      # minimum number of steps the generators buffer before batching
narrow_types = 4     # each frame carries narrow_types * 4 step features (+ 3 extra features)
train_txt_fp = 'sym/songs/songs_train.txt'
test_txt_fp = 'sym/songs/songs_test.txt'
model_dir = 'trained_models'
load_checkpoint = False

# Seed Python's `random` (used to shuffle the chart-file lists), NumPy and TensorFlow
# (weight init, dropout) so that Restart & Run All reproduces the same run.
# Previously `seed` was only used inside the generator's buffer shuffle.
tf.keras.utils.set_random_seed(seed)

os.makedirs(model_dir, exist_ok=True)

In [1]:
def generatorify_from_fp_list(dataset_fp_list, memlen = 7, batch_size = 50, mem_size = 50000, shuffle = False, bidirectional_audio = True, aud_memlen = None, shuffle_seed = None):
    """Return an endless generator of ``(inputs, labels)`` batches built from pickled charts.

    Chart files are read in shuffled order into an in-memory buffer until it
    holds at least ``mem_size`` samples. Batches of ``batch_size`` samples are
    then taken from the front of the buffer, and the buffer is refilled when it
    runs low. The file list is cycled forever.

    Each pickle is expected to hold a sequence whose item 0 is a list of charts.
    Each chart is a sequence of entries whose item 0 is a number (presumably
    the step time; TODO confirm) and item 1 is a step string accepted by
    ``ddc_string_to_step`` and made of digit characters.

    Args:
        dataset_fp_list: Paths of pickled chart files. The list is not
            modified; a shuffled copy is used.
        memlen: Passed to ``windowize`` as ``frames``.
        batch_size: Number of samples per yielded batch.
        mem_size: Refill the buffer while it holds fewer samples than this.
        shuffle: If true, the buffer is permuted after every refill. One
            permutation is shared by inputs and labels, so rows stay aligned.
        bidirectional_audio: Unused; kept for interface compatibility.
        aud_memlen: Unused. Accepted so callers that forward it do not fail.
        shuffle_seed: Seed for the buffer permutation. Defaults to the
            notebook-level ``seed``.

    Yields:
        ``(sd, lb)``.
        ``sd`` is one window per sample. In its final frame the first
        ``narrow_types * 4`` values are zeroed and only the last 3 values are
        kept, so the step being predicted is hidden.
        ``lb`` stacks ``sparse_to_categorical(sparceify(digits), 255)`` for
        each sample's step string.

    Raises:
        ValueError: If ``dataset_fp_list`` is empty, or if ``windowize`` does
            not return one window per step (labels would be misaligned).
    """
    # Work on a copy so the caller's list is not reordered as a side effect.
    fp_list = list(dataset_fp_list)
    if not fp_list:
        raise ValueError('dataset_fp_list is empty')
    random.shuffle(fp_list)

    def _chart_to_arrays(chart):
        # One chart -> [model inputs, windowed step history, raw step strings].
        times = [a[0] for a in chart]
        raw_steps = [a[1] for a in chart]
        steps = windowize(np.array([ddc_string_to_step(b) for b in raw_steps]), frames=memlen)
        # Row i = [time_i, time_{i+1}, is_first]; the last row's "next" value is 0.
        next_times = times[1:] + [0]
        timing = [[t, t_next, 0] for t, t_next in zip(times, next_times)]
        timing[0][2] = 1
        timing = windowize(np.array(timing), frames=memlen)
        inputs = np.concatenate((steps, timing), axis=-1)
        # Labels are indexed by step, so each step must get exactly one window.
        if len(inputs) != len(raw_steps):
            raise ValueError('windowize must return one window per step so labels stay aligned')
        return [inputs, steps, raw_steps]

    def _gener():
        rng = None
        if shuffle:
            rng = np.random.default_rng(seed if shuffle_seed is None else shuffle_seed)
        k = 0
        hopper = 0  # number of samples currently held in `song`
        song = None
        while True:
            pending = [[], [], []]
            while hopper < mem_size:
                with open(fp_list[k], 'rb') as f:
                    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
                    loaded = pickle.load(f)
                charts = loaded[0]
                del loaded
                # Visit every file. The old `% (len - 1)` skipped the last file and
                # divided by zero for a single-file list.
                k = (k + 1) % len(fp_list)
                for chart in charts:
                    for part, piece in zip(pending, _chart_to_arrays(chart)):
                        part.append(piece)
                    hopper += len(pending[0][-1])

            if pending[0]:
                # Concatenate once per refill rather than np.append per chart (quadratic copying).
                if song is not None:
                    pending = [[old] + new for old, new in zip(song, pending)]
                song = [np.concatenate(part, axis=0) for part in pending]
                if rng is not None:
                    # Use one permutation for all three arrays. This keeps inputs and labels
                    # aligned without reseeding numpy's global RNG.
                    perm = rng.permutation(len(song[0]))
                    song = [part[perm] for part in song]
                gc.collect()

            assert len(song[0]) > batch_size + memlen, 'buffer too small; increase mem_size'

            # Hide the step being predicted: zero the step features of the last frame
            # and keep only its final 3 values.
            blank_steps = [0] * (narrow_types * 4)
            sd = np.array([list(window[:-1]) + [blank_steps + list(window[-1][-3:])]
                           for window in song[0][:batch_size]])
            lb = np.array([sparse_to_categorical(sparceify([int(c) for c in step]), 255)
                           for step in song[2][:batch_size]])

            song = [part[batch_size:] for part in song]
            hopper -= batch_size
            yield sd, lb
    return _gener()



def get_inputs_and_gens(trn_fp, tst_fp, shuffle = False, batch_size = 1000, bidirectional_audio = True, memlen = 64, aud_memlen = 15):
    """Build the train/test batch generators and the matching Keras input tensor.

    Args:
        trn_fp: Path passed to ``get_dataset_fp_list`` for the training files.
        tst_fp: Path passed to ``get_dataset_fp_list`` for the test files.
        shuffle: Forwarded to both generators.
        batch_size: Samples per batch. Also fixed as the Input's batch size.
        bidirectional_audio: Forwarded to both generators.
        memlen: Window length forwarded to the generators. The Input spans
            ``memlen + 1`` frames.
        aud_memlen: Unused. ``generatorify_from_fp_list`` has no such
            parameter, so forwarding it raised ``TypeError``. It is kept so
            existing call sites stay valid.

    Returns:
        ``(train_gen, test_gen, sym_inp)``.

    The generator buffer size is read from the notebook-level ``mem_size``.
    The per-frame feature width is derived from the notebook-level
    ``narrow_types``.
    """
    trn_ds, tst_ds = get_dataset_fp_list(trn_fp, tst_fp)

    gen_kwargs = dict(batch_size=batch_size, shuffle=shuffle, mem_size=mem_size,
                      memlen=memlen, bidirectional_audio=bidirectional_audio)
    train_gen = generatorify_from_fp_list(trn_ds, **gen_kwargs)
    test_gen = generatorify_from_fp_list(tst_ds, **gen_kwargs)

    # The generator builds frames of narrow_types * 4 step values plus 3 extra values
    # (19 when narrow_types == 4). Derive the width so the two cannot drift apart.
    n_features = narrow_types * 4 + 3
    sym_inp = Input(shape=(memlen + 1, n_features), batch_size=batch_size)

    return train_gen, test_gen, sym_inp

In [None]:
train_gen, test_gen, sym_inp = get_inputs_and_gens(train_txt_fp,
                                                   test_txt_fp,
                                                   shuffle=shuffle,
                                                   batch_size=batch_size,
                                                   memlen=memlen)

# Two stacked LSTMs over the input window, then a 256-way softmax.
sym_proc = layers.LSTM(128, return_sequences=True, dropout=.5)(sym_inp)
sym_proc = layers.LSTM(128, return_sequences=False, dropout=.5)(sym_proc)
output = layers.Dense(256, activation='softmax')(sym_proc)
model = Model(sym_inp, output)

model.compile(
    # NOTE(review): learning_rate=1.0 is unusually high for SGD on LSTMs; confirm it is intended.
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e0),
    # The Dense layer already applies softmax, so loss and metrics receive probabilities.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[
        tf.keras.metrics.AUC(from_logits=False, curve='PR', name='auc'),
        tf.keras.metrics.F1Score(average='micro'),
        tf.keras.metrics.CategoricalAccuracy(),
    ],
)
# summary() prints the table itself and returns None; print() around it only adds "None".
model.summary()

checkpoint_filepath = os.path.join(model_dir, 'sym_ddc_checkpoint.keras')
if load_checkpoint:
    if os.path.isfile(checkpoint_filepath):
        model.load_weights(checkpoint_filepath)
    else:
        print(f'No checkpoint found at {checkpoint_filepath}; training from scratch.')

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    verbose=0,
    save_best_only=True,
    monitor='val_auc',   # 'auc' is the name given to the AUC metric above
    mode='max')

lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=10,
    min_lr=1e-6
)

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',
    patience=20,
    restore_best_weights=True,
    mode='max',
    start_from_epoch=100
)

# batch_size is not passed: the generators already yield fixed-size batches, and Keras
# documents that batch_size must not be given for generator inputs.
model.fit(train_gen,
          epochs=nepochs,
          steps_per_epoch=steps_per_epoch,
          validation_steps=20,
          validation_data=test_gen,
          callbacks=[model_checkpoint_callback, lr_scheduler, early_stopping])

model.save(os.path.join(model_dir, 'sym_ddc_model.keras'))