In [None]:
import os
import sys
import json

# Make the modules under ./audioset/yamnet importable. This must run before
# the `import yamnet` below. (Uses `sys.path` directly rather than the
# undocumented `os.sys` alias.)
sys.path.append('./audioset/yamnet')
import yamnet as yamnet_model
import params
from datagen_yamnet import DataGenerator, get_files_and_labels

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau


# Directory storing the spectrogram inputs for each class
train_dir = './train_set_patches/'

# Path prefix for the saved model files: '<prefix>.json' (architecture),
# '<prefix>_classes.json' (class mapping) and '<prefix>.h5' (checkpoint)
model_out = './saved_models/model'


#### Get training data files and labels

In [None]:
# 'train_split' is the fraction of files kept for training; the remaining
# (1 - train_split) are randomly held out for validation (here 90% / 10%).
files_train, labels_train, files_val, labels_val, class_dict = get_files_and_labels(train_dir,
                                                                                    typ='npy',
                                                                                    train_split=0.9)

# If you want to use your own separate folder with test samples instead
# (define `test_dir` first), do this:
# files_train, labels_train, _, _, class_dict = get_files_and_labels(train_dir,
#                                                                     typ='npy',
#                                                                     train_split=1)
# files_val, labels_val, _, _, _ = get_files_and_labels(test_dir,
#                                                         typ='npy',
#                                                         train_split=1)

# Sanity-check the split: a file/label mismatch or an empty split would
# otherwise only fail later, and less clearly, inside model.fit.
assert len(files_train) == len(labels_train), 'training files/labels length mismatch'
assert len(files_val) == len(labels_val), 'validation files/labels length mismatch'
assert len(files_train) > 0 and len(files_val) > 0, f'empty train or validation split for {train_dir}'

# Reverse lookup keyed by str(v[0]) -> class name.
# NOTE(review): presumably v[0] is the class index — confirm against get_files_and_labels.
class_dict_rev = {str(v[0]): k for k, v in class_dict.items()}


#### Load pre-trained YAMNet

In [None]:
# Build the YAMNet architecture, then restore its pre-trained weights from disk.
yamnet_weights_path = './audioset/yamnet/yamnet.h5'
yamnet = yamnet_model.yamnet_model()
yamnet.load_weights(yamnet_weights_path)

In [None]:
# yamnet.summary()

#### Define custom "top" of YAMNet graph

In [None]:
# One softmax unit per class in the label mapping (the original hard-coded 2).
# NOTE(review): assumes class_dict has exactly one entry per class.
n_classes = len(class_dict)

# NOTE(review): layers[-3] is assumed to be YAMNet's pooled embedding layer
# (the one feeding its original classifier) — confirm with yamnet.summary().
embeddings = yamnet.layers[-3].output

hidden = tf.keras.layers.Dense(64, activation='relu')(embeddings)
hidden = tf.keras.layers.Dropout(0.5)(hidden)
outputs = tf.keras.layers.Dense(n_classes, activation='softmax')(hidden)

# Reuse YAMNet's own input, so the new head sits directly on top of the
# pre-trained graph (the previously declared standalone Input was never used).
model = Model(inputs=yamnet.input, outputs=outputs)

# Fine-tune the whole network, pre-trained YAMNet layers included.
for layer in model.layers:
    layer.trainable = True

model.summary()

#### Initialize data generators

In [None]:
# Batch size shared by the training and validation generators.
batch_size = 32

train_generator = DataGenerator(files_train, labels_train, batch_size=batch_size)
validation_generator = DataGenerator(files_val, labels_val, batch_size=batch_size)

#### Save architecture, define training callbacks, and compile the model

In [None]:
# Create the output directory first; open() below fails if it does not exist.
os.makedirs(os.path.dirname(model_out) or '.', exist_ok=True)

# Save model architecture and the class mapping used for training.
# NOTE(review): json.dump requires class_dict values to be JSON-serializable
# (e.g. lists, not numpy arrays) — confirm against get_files_and_labels.
model_json = model.to_json()
with open(model_out + '.json', 'w') as json_file:
    json_file.write(model_json)
with open(model_out + '_classes.json', 'w') as f:
    json.dump(class_dict, f)
print('Saved model architecture')

# Save the model to '<model_out>.h5' only when val_loss improves on the best so far.
checkpoint = ModelCheckpoint(model_out + '.h5',
                             monitor='val_loss',
                             verbose=1,
                             save_best_only=True,
                             mode='auto')

# Halve the learning rate after 3 epochs without val_loss improvement.
reducelr = ReduceLROnPlateau(monitor='val_loss',
                             factor=0.5,
                             patience=3,
                             verbose=1)

# Compile model.
# - `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
#   optimizers reject.
# - categorical_crossentropy matches the softmax output layer. For 2-class
#   one-hot labels it equals the previous binary_crossentropy value.
#   NOTE(review): assumes DataGenerator yields one-hot labels.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    

#### Train

In [None]:
# Number of passes over the training data.
epochs = 20

# No try/except: a training failure should stop here with its traceback,
# rather than being stashed away and leaving `model_history` undefined
# for the plotting cell.
model_history = model.fit(train_generator,
                          steps_per_epoch=len(train_generator),
                          epochs=epochs,
                          validation_data=validation_generator,
                          validation_steps=len(validation_generator),
                          verbose=1,
                          callbacks=[checkpoint, reducelr])

#### Plot training history

In [None]:
import matplotlib.pyplot as plt

# Draw training and validation loss per epoch on the current axes.
curves = model_history.history
ax = plt.gca()
for series in ('loss', 'val_loss'):
    ax.plot(curves[series])
ax.set_title('model loss')
ax.set_ylabel('loss')
ax.set_xlabel('epoch')
ax.legend(['training', 'validation'], loc='upper left')
plt.show()