In [None]:
import random
import tensorflow as tf
from tensorflow.keras import layers, models, Model, Input
import numpy as np
from util import *
import os
import gc
import copy as c
import pickle
import json

In [2]:
# --- Reproducibility / data pipeline -------------------------------------
shuffle = True            # permute the buffered windows of each loaded song
seed = 420                # seed handed to np.random inside the batch generators
batch_size = 256
memlen = 100              # history length; sizes the audio context and difficulty stream
context_radius = 7        # not referenced by the cells visible in this notebook
mem_size = 2500           # minimum number of buffered windows before a batch is drawn

# --- Training schedule ----------------------------------------------------
steps_per_epoch = 200
nepochs = 1000
load_checkpoint = False   # when True, warm-start from an existing checkpoint file

# --- Feature / label geometry ---------------------------------------------
nmelbands = 80            # second-to-last axis of the audio input
nchannels = 3             # last axis of the audio input
npred_steps = 5           # frames labelled per window == number of sigmoid outputs


def name_from_fp(x):
    """Return the file name of path ``x`` without its directory and last extension."""
    stem, _ext = os.path.splitext(os.path.basename(x))
    return stem


# --- Locations on disk ----------------------------------------------------
model_dir = 'old_trained_models'
train_txt_fp = 'json/songs/songs_train.txt'
test_txt_fp = 'json/songs/songs_test.txt'
feats_dir = 'feats/songs'

# Coarse chart difficulty -> integer code. 'Edit' shares the code of 'Challenge'.
diff_dict = dict(
    Beginner=0,
    Easy=1,
    Medium=2,
    Hard=3,
    Challenge=4,
    Edit=4,
)

In [3]:
def generatorify_from_fp_list(dataset_fp,
                              memlen = 100,
                              batch_size = 50,
                              mem_size = 10000,
                              shuffle = False):
    """Create an endless batch generator over the songs listed in ``dataset_fp``.

    ``dataset_fp`` is a text file with one song-metadata JSON path per line.
    Duplicate paths are dropped and the remaining paths are visited cyclically
    in a seeded random order. For every JSON file the matching feature pickle
    ``<feats_dir>/<json name>.pkl`` is loaded.

    One song is buffered at a time (the audio features are per-song, so windows
    from different songs are never mixed in a buffer). For each chart whose
    ``type`` is not ``'dance-double'`` and each start frame
    ``0, npred_steps, 2*npred_steps, ...`` the buffer holds:

    * the start frame index,
    * a ``(memlen + npred_steps - 1, 1)`` stream filled with the chart's
      difficulty code from ``diff_dict``,
    * ``npred_steps`` binary labels, 1 where a non-empty note lands on the
      frame (assumes ``note[2]`` is a time in seconds and features run at
      100 frames per second -- TODO confirm).

    Relies on the module-level names ``seed``, ``npred_steps``, ``feats_dir``,
    ``diff_dict``, ``name_from_fp`` and ``make_onset_feature_context``
    (the latter presumably provided by ``util``).

    Args:
        dataset_fp: path to the text file listing song JSON paths.
        memlen: history length; sizes the audio context and difficulty stream.
        batch_size: number of samples per yielded batch.
        mem_size: a new song is loaded whenever fewer than this many windows
            remain buffered. Must be at least ``npred_steps * batch_size``.
        shuffle: if true, the buffered windows of each song are permuted
            (deterministically, using ``seed``).

    Returns:
        A generator yielding ``((audio_ctx, difficulty_stream), labels)`` where
        ``labels`` has shape ``(batch_size, npred_steps)``.

    Raises (on first iteration):
        ValueError: if ``dataset_fp`` lists no paths.
    """
    def _gener():
        with open(dataset_fp, 'r') as f:
            # Ignore blank lines so stray newlines do not become bogus paths.
            listed = [line for line in f.read().splitlines() if line.strip()]
        if not listed:
            raise ValueError('No song json paths listed in {}'.format(dataset_fp))
        np.random.seed(seed)
        json_fps = list(np.random.permutation(np.unique(listed)))

        k = 0
        hopper = 0          # number of windows still buffered in `song`
        song = None         # [frame indices, difficulty streams, label windows]
        song_feats = None
        while True:
            # Refill: keep loading songs until one provides enough windows.
            # Whatever was left of the previous song is discarded.
            while hopper < mem_size:
                json_fp = json_fps[k]
                # Cycle through *every* listed song (also valid for one song).
                k = (k + 1) % len(json_fps)
                with open(json_fp, 'r') as json_f:
                    meta = json.load(json_f)
                json_name = name_from_fp(json_fp)
                song_feats_fp = os.path.join(feats_dir, '{}.pkl'.format(json_name))
                # NOTE(security): pickle.load can execute arbitrary code; only
                # point feats_dir at feature files you generated yourself.
                with open(song_feats_fp, 'rb') as feats_f:
                    song_feats = pickle.load(feats_f)

                frame_ids, diff_streams, label_windows = [], [], []
                for chart in meta['charts']:
                    if chart['type'] == 'dance-double':
                        continue
                    # Set for O(1) membership tests in the per-frame loop below.
                    placed_notes = {int(round(note[2]*100))
                                    for note in chart['notes']
                                    if note[3] != '0000'}
                    for j in range(0, len(song_feats), npred_steps):
                        frame_ids.append(j)
                        diff_streams.append(
                            [[diff_dict[chart['difficulty_coarse']]]
                             for _ in range(memlen+npred_steps-1)])
                        label_windows.append(
                            [[1] if i in placed_notes else [0]
                             for i in range(j, j+npred_steps)])

                song = [frame_ids, diff_streams, label_windows]
                hopper = len(frame_ids)

                if shuffle:
                    # One index permutation applied to all three parts keeps
                    # frames, difficulty streams and labels aligned.
                    np.random.seed(seed)
                    order = np.random.permutation(len(frame_ids))
                    song = [np.asarray(part)[order] for part in song]
                gc.collect()
            gc.collect()

            assert len(song[0]) > npred_steps*batch_size

            ac = []
            sd = []
            lb = []
            # NOTE(review): only every npred_steps-th buffered window is used
            # and npred_steps*batch_size windows are dropped per batch below;
            # confirm this sub-sampling is intended.
            for i in range(0, npred_steps*batch_size, npred_steps):
                ac.append(make_onset_feature_context(song_feats, song[0][i], radius = 3+npred_steps, left_radius = memlen+3))
                sd.append(song[1][i])
                lb.append(song[2][i])

            ac, sd = np.array(ac), np.array(sd)
            # Drop only the trailing singleton axis so batch/step axes of
            # size 1 survive: (batch, npred_steps, 1) -> (batch, npred_steps).
            lb = np.squeeze(np.array(lb), axis = -1)

            for j in range(3):
                song[j] = song[j][int(npred_steps*batch_size):]
            assert(len(song[0])==len(song[1]) and len(song[1])==len(song[2]))
            hopper -= npred_steps*batch_size
            gc.collect()
            yield (ac,sd), lb
    return _gener()

def get_inputs_and_gens(trn_fp, tst_fp, shuffle = False, batch_size = 1000, memlen = 8, mem_size = 2500):
    """Create the train/test batch generators and the two Keras input tensors.

    Both generators come from ``generatorify_from_fp_list`` and share the same
    ``batch_size``, ``shuffle``, ``mem_size`` and ``memlen`` settings.

    Args:
        trn_fp: text file listing the training song JSON paths.
        tst_fp: text file listing the test song JSON paths.
        shuffle: forwarded to both generators.
        batch_size: forwarded to both generators and fixed on both inputs.
        memlen: history length; determines the time axis of both inputs.
        mem_size: forwarded to both generators.

    Returns:
        ``(train_gen, test_gen, audio_ctx_inp, stream_inp)`` where
        ``audio_ctx_inp`` has per-sample shape
        ``(memlen + 7 + npred_steps, nmelbands, nchannels)`` and ``stream_inp``
        has per-sample shape ``(memlen + npred_steps - 1, 1)``.
    """
    # Settings common to the train and test generators.
    gen_kwargs = dict(batch_size=batch_size,
                      shuffle=shuffle,
                      mem_size=mem_size,
                      memlen=memlen)
    train_gen, test_gen = (generatorify_from_fp_list(fp, **gen_kwargs)
                           for fp in (trn_fp, tst_fp))

    # Per-sample shapes; the batch axis is supplied through `batch_size`.
    audio_ctx_shape = (memlen + 7 + npred_steps, nmelbands, nchannels)
    stream_shape = (memlen + npred_steps - 1, 1)

    audio_ctx_inp = Input(shape=audio_ctx_shape, batch_size=batch_size)
    stream_inp = Input(shape=stream_shape, batch_size=batch_size)

    return train_gen, test_gen, audio_ctx_inp, stream_inp

In [None]:
# Make sure the directory for checkpoints and the final model exists.
os.makedirs(model_dir, exist_ok=True)

train_gen, test_gen, audio_ctx_inp, stream_inp = get_inputs_and_gens(
    train_txt_fp,
    test_txt_fp,
    shuffle,
    batch_size=batch_size,
    memlen=memlen,
    mem_size=mem_size,
)

# Audio branch: normalise, then two unpadded conv + mel-axis pooling stages.
# The 7- and 3-frame kernels trim 6 + 2 frames from the time axis, so
# memlen + 7 + npred_steps input frames become memlen + npred_steps - 1.
audio_proc = layers.BatchNormalization()(audio_ctx_inp)
audio_proc = layers.Conv2D(filters=10, kernel_size=(7, 3))(audio_proc)
audio_proc = layers.MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(audio_proc)
audio_proc = layers.Conv2D(filters=20, kernel_size=(3, 3))(audio_proc)
audio_proc = layers.MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(audio_proc)

# Flatten (mel, channel) into a single feature vector per time step.
audio_out = layers.Reshape(target_shape=(memlen + npred_steps - 1, -1))(audio_proc)

# Append the difficulty stream as one extra feature per time step.
stream_merge = layers.Concatenate(axis=-1)([audio_out, stream_inp])

# Recurrent stack: the second LSTM returns only its final state.
note_comp = layers.LSTM(units=200, return_sequences=True, dropout=.5)(stream_merge)
note_comp = layers.LSTM(units=200, dropout=.5)(note_comp)

note_comp = layers.Dense(units=256, activation='relu')(note_comp)
note_comp = layers.Dense(units=128, activation='relu')(note_comp)

# One independent sigmoid per predicted frame.
output = layers.Dense(units=npred_steps, activation='sigmoid')(note_comp)

model = Model(inputs=[audio_ctx_inp, stream_inp], outputs=output)

# Binary cross-entropy on the per-frame sigmoid outputs. The metric names
# ('auc', 'f1', 'acc') are what the callbacks monitor with a 'val_' prefix.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # Alternative tried: tf.keras.optimizers.SGD(learning_rate = 1e-2, clipvalue = 5)
    loss=tf.keras.losses.BinaryCrossentropy(from_logits = False),
    metrics=[
        tf.keras.metrics.AUC(from_logits = False, curve = 'PR', name = 'auc'),
        tf.keras.metrics.F1Score(average = 'micro', threshold = .5, name = 'f1'),
        tf.keras.metrics.BinaryAccuracy(name = 'acc'),
    ],
)
# summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line to the cell output.
model.summary()
checkpoint_filepath = os.path.join(model_dir, 'onset_ddc_checkpoint.keras')

# Optionally warm-start from the best checkpoint of a previous run.
if load_checkpoint and os.path.isfile(checkpoint_filepath):
    print('Loading weights from checkpoint: {}'.format(checkpoint_filepath))
    model.load_weights(checkpoint_filepath)

# Keep only the weights with the best validation PR-AUC.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    verbose = 0,
    save_best_only = True,
    monitor = 'val_auc',
    mode = 'max')

# Halve the learning rate after 5 epochs without val_loss improvement.
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-6
)

# Stop after 20 stagnant epochs, but never before epoch 100.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',
    patience=20,
    restore_best_weights=True,
    mode = 'max',
    start_from_epoch = 100
)

# The generators already yield complete batches, so `batch_size` must not be
# passed to fit(); the batch size is set where the generators are created.
model.fit(train_gen,
          epochs = nepochs,
          steps_per_epoch = steps_per_epoch,
          validation_steps = 20,
          validation_data = test_gen,
          callbacks = [model_checkpoint_callback, lr_scheduler, early_stopping])

model.save(os.path.join(model_dir, 'onset_ddc_model.keras'))
tf.keras.backend.clear_session()