In [None]:
%load_ext autoreload 
%autoreload 2
import os
import sys
sys.path.append('../')
from lmtd9 import LMTD
from lmtd9 import database as db
from lmtd9 import evaluation
from keras import Model
from keras.layers import Input, Convolution1D, GlobalMaxPooling1D, merge, Dense, Dropout
from keras.layers import BatchNormalization, Activation

In [None]:
# --- Input geometry: the model's Input layer is built as (time_steps, nb_features) ---
time_steps = 240      # length of the temporal axis fed to the network
nb_features = 2048    # size of each per-step feature vector

# --- Network hyper-parameters ---
conv_filters = 384    # number of filters in the 1-D convolution
dropout = 0.5         # rate passed to the Dropout layer
nb_classes = 9        # width of the final Dense (one output unit per class)

# --- Training schedule ---
max_epochs = 10       # number of single-epoch fit() calls in the training loop

In [None]:
# Dataset helper object. Later cells use it to load the precomputed features
# (load_precomp_features), to fetch the train/valid/test splits (get_split)
# and to look up genre names (genres).
lmtd = LMTD()

In [None]:
# Root directory of the LMTD data. It is read from the LMTD_PATH environment
# variable so the notebook runs top-to-bottom without editing; alternatively
# replace the lookup below with a literal path to your copy of the dataset.
LMTD_PATH = os.environ.get('LMTD_PATH')
if not LMTD_PATH:
    # Fail early with an actionable message instead of a confusing path error later.
    raise RuntimeError('LMTD_PATH is not set: export the LMTD_PATH environment '
                       'variable or assign the dataset root directory to '
                       'LMTD_PATH in this cell.')

# Precomputed features live under <LMTD_PATH>/features/.
# NOTE(review): the file is a .pickle, presumably unpickled by
# load_precomp_features -- only point this at files from a trusted source.
features_path = os.path.join(LMTD_PATH, 'features', 'lmtd9_resnet152.pickle')
lmtd.load_precomp_features(features_file=features_path)

In [None]:
# Pull the three dataset splits out of the LMTD helper. Each call is unpacked
# into four values; judging by the names these are features, per-sample
# lengths, labels and sample ids -- TODO confirm against LMTD.get_split.
(x_valid, x_valid_len,
 y_valid, valid_ids) = lmtd.get_split('valid')
(x_train, x_train_len,
 y_train, train_ids) = lmtd.get_split('train')
(x_test, x_test_len,
 y_test, test_ids) = lmtd.get_split('test')

In [None]:
# A very simple architecture for fast training:
#   BatchNorm -> Conv1D -> BatchNorm -> ReLU -> GlobalMaxPooling -> Dropout -> Dense
inputs = Input(shape=(time_steps, nb_features))

# Normalise the raw input features before the convolution.
x = BatchNormalization()(inputs)
# The convolution must consume `x`, not `inputs`: feeding `inputs` here left the
# BatchNormalization above disconnected, so it never became part of the model.
x = Convolution1D(conv_filters, kernel_size=3)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# Reduce the temporal axis to a single vector of conv_filters values.
x = GlobalMaxPooling1D()(x)
x = Dropout(dropout)(x)
# One sigmoid unit per class (paired with binary_crossentropy at compile time),
# so every class is scored independently.
out = Dense(nb_classes, activation='sigmoid')(x)

model = Model(inputs, out)

In [None]:
def print_result(result, genres=None):
    """Print one ``<label> <score>`` line for every entry of ``result``.

    Args:
        result: mapping from a key to a numeric score. Each key is first
            looked up in ``genres``; when that lookup fails the key itself
            is used as the label.
        genres: optional indexable collection of genre names. Defaults to
            ``lmtd.genres`` (the notebook-level LMTD object), which keeps
            existing ``print_result(result)`` calls working unchanged.

    The label is left-aligned in a 15-character column and the score is
    printed with four decimals. The code uses only constructs valid on both
    Python 2 and Python 3 and produces the same text as the former
    two-statement ``print`` version.
    """
    if genres is None:
        genres = lmtd.genres
    for key, score in result.items():
        try:
            # Drop the first two characters of the stored genre name
            # (presumably a fixed prefix -- confirm against LMTD.genres).
            label = genres[key][2:]
        except (IndexError, KeyError, TypeError):
            # Key is not a genre index (e.g. an aggregate metric name):
            # show the key itself, title-cased.
            label = str(key).title()
        print('{:<15s} {:5.4f}'.format(label, score))


In [None]:
# summary() prints the layer table itself and returns None, so wrapping it in
# `print` only added a stray "None" line (and was Python-2-only syntax).
model.summary()

In [None]:
# Adam optimizer with a per-output binary cross-entropy loss, matching the
# sigmoid outputs of the final Dense layer.
model.compile(optimizer='Adam',
              loss='binary_crossentropy')

In [None]:
# Train one epoch per fit() call so validation scores can be reported after
# every epoch.
for epoch in range(max_epochs):

    # initial_epoch/epochs restrict this call to the single epoch
    # [epoch, epoch + 1) while keeping Keras' epoch counter continuous.
    model.fit(x_train, y_train,
              validation_data=(x_valid, y_valid),
              initial_epoch=epoch,
              epochs=epoch + 1,
              batch_size=32)

    # Score the validation split with evaluation.prauc (presumably area under
    # the precision-recall curve -- confirm in lmtd9.evaluation).
    y_pred = model.predict(x_valid)
    result = evaluation.prauc(y_valid, y_pred)

    print_result(result)
    # Blank separator line between epochs; the single-argument call form
    # prints the same thing on Python 2 and Python 3.
    print('')