In [1]:
import os
import sys

# Directory appended to the module search path so that `import yamnet` below
# resolves. Guarded so that re-running this cell does not keep adding duplicate
# entries to sys.path.
YAMNET_DIR = '../models/research/audioset/yamnet'
if YAMNET_DIR not in sys.path:
    sys.path.append(YAMNET_DIR)

import yamnet as yamnet_model
import params  # NOTE(review): presumably the params module shipped next to yamnet — confirm
from datagen_yamnet import DataGenerator, get_files_and_labels

In [2]:
# Gather the train/validation file lists with their labels, plus the class
# dictionary, from the patch directory (typ='npy' is forwarded to the helper).
files_train, labels_train, files_val, labels_val, class_dict = get_files_and_labels(
    './train_set_patches/',
    typ='npy',
)

# Reverse lookup: the stringified first element of each class_dict value maps
# back to the original class_dict key.
class_dict_rev = {
    str(value[0]): key
    for key, value in class_dict.items()
}


In [3]:
# Instantiate the YAMNet model via the local yamnet module, then load its
# weights from the .h5 file stored alongside that module.
yamnet_weights_path = '../models/research/audioset/yamnet/yamnet.h5'
yamnet = yamnet_model.yamnet_model()
yamnet.load_weights(yamnet_weights_path)

In [8]:
# yamnet.summary()

In [5]:
import tensorflow as tf
from tensorflow.keras.models import Model

# NOTE(review): `inpts` is never wired into the model built below (yamnet2
# reuses yamnet.input instead). It is kept only in case another cell refers
# to it — confirm and delete if unused.
inpts = tf.keras.layers.Input(shape=(params.PATCH_FRAMES, params.PATCH_BANDS))

# New classification head: branch off the output of the third-from-last
# YAMNet layer and add Dense(512, relu) -> Dropout(0.5) -> Dense(2, softmax).
x = tf.keras.layers.Dense(512, activation='relu')(yamnet.layers[-3].output)
o = tf.keras.layers.Dropout(0.5)(x)
o = tf.keras.layers.Dense(2, activation='softmax')(o)

# The new model shares YAMNet's input tensor and ends in the 2-way softmax head.
yamnet2 = Model(inputs=yamnet.input, outputs=o)

# Explicitly mark every layer (YAMNet layers and the new head) as trainable,
# so the whole network is updated during fit rather than only the head.
for layer in yamnet2.layers:
    layer.trainable = True

# `learning_rate` is the documented argument name; `lr` is a deprecated alias.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, decay=1e-7)
yamnet2.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

yamnet2.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 96, 64)]          0         
_________________________________________________________________
reshape (Reshape)            (None, 96, 64, 1)         0         
_________________________________________________________________
layer1/conv (Conv2D)         (None, 48, 32, 32)        288       
_________________________________________________________________
layer1/conv/bn (BatchNormali (None, 48, 32, 32)        96        
_________________________________________________________________
layer1/relu (ReLU)           (None, 48, 32, 32)        0         
_________________________________________________________________
layer2/depthwise_conv (Depth (None, 48, 32, 32)        288       
_________________________________________________________________
layer2/depthwise_conv/bn (Ba (None, 48, 32, 32)        96    

In [6]:
# Wrap each split in a DataGenerator; only the file list and the label list
# are passed, so any other generator options keep their defaults.
train_generator = DataGenerator(files_train, labels_train)
validation_generator = DataGenerator(files_val, labels_val)

In [7]:
# Fit for a single epoch. The step counts are taken from the generators'
# lengths, so each epoch consumes len(generator) batches from the respective
# generator. The object returned by fit is kept in model_history.
model_history = yamnet2.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=1,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    verbose=1,
)

