In [None]:
# Configuration: where the YAMNet source code and its pre-trained weights live.
# Both can be overridden with environment variables (YAMNET_PATH, YAMNET_WEIGHTS)
# so the notebook runs on other machines without editing this cell.
import os

# path to the yamnet directory of the tensorflow models repo; the `yamnet` and
# `params` modules are imported from here further down
# https://github.com/tensorflow/models/tree/master/research/audioset
yamnet_path = os.environ.get('YAMNET_PATH',
                             '/home/ubuntu/models/research/audioset/yamnet/')

# path to the pre-trained yamnet weights file (yamnet.h5); it must be downloaded
# separately, see https://github.com/tensorflow/models/tree/master/research/audioset/yamnet
weights_path = os.environ.get('YAMNET_WEIGHTS',
                              os.path.join(yamnet_path, 'yamnet.h5'))


In [None]:
import os
import sys
import json
import tensorflow as tf

# Make the modules in the yamnet directory (yamnet, params) importable. Drop any
# earlier copy of the entry first, so re-running this cell keeps the path at the
# front of sys.path without piling up duplicates. Use `sys` directly rather than
# the undocumented `os.sys` alias.
while yamnet_path in sys.path:
    sys.path.remove(yamnet_path)
sys.path.insert(0, yamnet_path)

import yamnet
import params
from datagen_yamnet import DataGenerator, get_files_and_labels

# directory storing the spectrogram inputs for each class
train_dir = './train_set_patches/'

# path prefix for the output model (the checkpoint callback appends '.h5');
# create its folder up front so saving a checkpoint cannot fail on a missing directory
model_out = './saved_models/model'
os.makedirs(os.path.dirname(model_out), exist_ok=True)


#### Get training data files and labels

In [None]:
# Split the files in train_dir into a training and a validation set.
# train_split is presumably the fraction of files kept for training
# (0.8 -> 80% training, the remaining 20% randomly sampled for validation);
# confirm against get_files_and_labels in datagen_yamnet.
train_split = 0.8
files_train, files_val, labels = get_files_and_labels(train_dir,
                                                      typ='npy',
                                                      train_split=train_split)

# if you want to use a separate folder with validation samples instead, do this
# (define test_dir first, pointing at that folder):
# files_train, _, labels = get_files_and_labels(train_dir,
#                                               typ='npy',
#                                               train_split=1)
# files_val, _, _ = get_files_and_labels(test_dir,
#                                        typ='npy',
#                                        train_split=1)


#### Build the pre-trained YAMNet model and load its weights

In [None]:
# Build the full pre-trained YAMNet model and load its weights.
# The params module is imported under its own alias because the next line rebinds
# the name `params` to a Params instance (read later when computing input_shape).
# Calling `params.Params()` directly would shadow the module, and re-running this
# cell would then fail with AttributeError.
import params as yamnet_params_module

params = yamnet_params_module.Params()
yamnet_model = yamnet.yamnet_frames_model(params)
yamnet_model.load_weights(weights_path)


#### Build the YAMNet base model (backbone)

In [None]:
# Backbone: the part of YAMNet from the input of its 'reshape' layer up to the output
# of 'layer14/pointwise_conv/relu' (presumably the last conv block before YAMNet's
# own pooling/classifier, which is left out here).
backbone_input = yamnet_model.get_layer('reshape').input
backbone_output = yamnet_model.get_layer('layer14/pointwise_conv/relu').output

base_model = tf.keras.models.Model(backbone_input, backbone_output, name='yamnet_base')
base_model.summary()


#### Define the custom model top (classification head)

In [None]:
# Shape of one example fed to the model: (frames per patch, mel bands), where
# frames per patch = patch window length / STFT hop length. Round before converting
# to int: a float quotient such as 95.99999999999999 would otherwise be truncated
# to one frame too few.
input_shape = (int(round(params.patch_window_seconds / params.stft_hop_seconds)),
               params.mel_bands)

def get_model(num_classes=None, dropout_rate=0.5):
    """Build the transfer-learning classifier: YAMNet backbone plus a new classification head.

    Reads the notebook globals `base_model` (the backbone), `input_shape` and `labels`.

    Args:
        num_classes: number of output units; defaults to len(labels).
        dropout_rate: dropout rate applied before the output layer (regularization).

    Returns:
        An uncompiled tf.keras Model named 'custom_yamnet' with a softmax output.
    """
    if num_classes is None:
        num_classes = len(labels)

    inputs = tf.keras.layers.Input(input_shape)

    # call the base model with training=False so its batch norm layers run in
    # inference mode, even if the backbone is unfrozen later for fine-tuning
    # https://keras.io/guides/transfer_learning/
    x = base_model(inputs, training=False)

    # model top: GlobalAveragePooling2D already returns a rank-2 (batch, channels)
    # tensor, so no Flatten layer is needed before the dense output
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)  # for regularization
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

    return tf.keras.models.Model(inputs, outputs, name='custom_yamnet')

# instantiate the classifier (backbone + new head) and print its layer summary
model = get_model()
model.summary()


In [None]:
# Data generators for the training and validation file lists.
# NOTE: batch_size = 2 only keeps this demo light; use a larger batch size in a real case.
batch_size = 2

train_generator, validation_generator = (
    DataGenerator(file_list, labels, batch_size=batch_size)
    for file_list in (files_train, files_val)
)

# Training callback: write the model to <model_out>.h5, but only when the
# monitored validation loss is the best seen so far (save_best_only=True).
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    f'{model_out}.h5',
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)


#### Train

In [None]:
# Stage 1: train only the new classification head.
# Freeze the YAMNet backbone so its weights are left untouched.
model.get_layer('yamnet_base').trainable = False

# (re)compile so the new trainable setting takes effect
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='categorical_crossentropy')

fit_kwargs = dict(
    steps_per_epoch=len(train_generator),
    epochs=3,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    verbose=1,
    callbacks=[checkpoint],
)
model_history = model.fit(train_generator, **fit_kwargs)


#### Optionally fine-tune the entire model


In [None]:
# Stage 2 (optional): fine-tune the whole network.
# Unfreeze the backbone so its weights are updated too (this previously set
# trainable = False, which silently left the backbone frozen). Its batch norm layers
# still run in inference mode, because get_model() calls base_model(..., training=False).
model.get_layer('yamnet_base').trainable = True

# recompile (required for the trainable change to take effect) with a much lower
# learning rate, so the pre-trained weights are only adjusted slightly
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
model.compile(loss='categorical_crossentropy',
              optimizer=optimizer)

model_history = model.fit(train_generator,
                          steps_per_epoch=len(train_generator),
                          epochs=3,
                          validation_data=validation_generator,
                          validation_steps=len(validation_generator),
                          verbose=1,
                          callbacks=[checkpoint])
