In [103]:
import tensorflow as tf
import tensorflow_io as tfio
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

from tensorflow.keras.layers import Input, Conv2D, Lambda
from tensorflow.keras import Model, Layer

from tensorflow.keras.utils import register_keras_serializable

import numpy as np
import random
from audiomentations import AddBackgroundNoise
import pandas as pd
from tqdm import tqdm
from keras_tqdm import TQDMNotebookCallback
import keras

In [104]:
import os
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"



In [105]:
import warnings
warnings.filterwarnings("ignore")

In [106]:
class CONFIG:
    SEED = 42
    CLASSIFIER_LR = 3e-4
    EPOCH = 200
    BATCH_SIZE = 8
    VOICE_DIR = './cv-corpus-19.0-2024-09-13/ko/clips/'
    NOISE_DIR = './ESC-50-master/audio/'

In [107]:
def periodic_hann_window(window_length, dtype):
    return 0.5 - 0.5 * tf.math.cos(2.0 *
                                   np.pi *
                                   tf.range(tf.cast(window_length, tf.float32)) /
                                   tf.cast(window_length, tf.float32))

In [108]:
def wave2log_mel_spectrogram(wave):
    signal_stft = tf.signal.stft(tf.cast(wave, tf.float32),
                                 frame_length=400,
                                 frame_step=160,
                                 fft_length=1024,
                                 window_fn=periodic_hann_window)
    # print(signal_stft.shape) # (98, 513)
    spectogram = tf.abs(signal_stft)

    linear_to_mel = tf.signal.linear_to_mel_weight_matrix(40,
                                signal_stft.shape[-1],
                                16000,
                                300.0,
                                4000.0)
    mel_spectrogram = tf.tensordot(spectogram, linear_to_mel, 1)
    log_mel_spectrogram = tf.math.log(mel_spectrogram + 1e-12)
    return log_mel_spectrogram

In [109]:
def data_generator(df):
    for _, row in df.iterrows():
        path = row['path']
        label = row['label']

        audio_raw = tf.io.read_file(path)
        wave, sr = tf.audio.decode_wav(audio_raw, desired_channels=1)
        wave = tf.squeeze(wave, axis=-1)
        # print(wave.shape) (16000, )
        log_mel_spectrogram = wave2log_mel_spectrogram(wave)
        log_mel_spectrogram = np.expand_dims(log_mel_spectrogram, axis=-1)
        # print(log_mel_spectrogram.shape) #(49, 80)

        yield log_mel_spectrogram, label


In [110]:
train_df = pd.read_csv('train_dataset.csv')


In [111]:
valid_df = pd.read_csv('valid_dataset.csv')


In [112]:
test_df = pd.read_csv('test_dataset.csv')


In [113]:
train_dataset = tf.data.Dataset.from_generator(lambda: data_generator(train_df), output_signature=(tf.TensorSpec(shape=(98, 40, 1), dtype=tf.float32), tf.TensorSpec(shape=(), dtype=tf.int32))).batch(CONFIG.BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
valid_dataset = tf.data.Dataset.from_generator(lambda: data_generator(valid_df), output_signature=(tf.TensorSpec(shape=(98, 40, 1), dtype=tf.float32), tf.TensorSpec(shape=(), dtype=tf.int32))).batch(CONFIG.BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
test_dataset = tf.data.Dataset.from_generator(lambda: data_generator(test_df), output_signature=(tf.TensorSpec(shape=(98, 40, 1), dtype=tf.float32), tf.TensorSpec(shape=(), dtype=tf.int32))).batch(CONFIG.BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)

In [114]:
@tf.keras.utils.register_keras_serializable(package="Custom")
class CustomAttentionModel(Model):
    def __init__(self, **kwargs):
        super(CustomAttentionModel, self).__init__(**kwargs)
        
        # Define layers
        self.lstm_cell = layers.LSTMCell(40)  # Single LSTM cell
        self.rnn = layers.RNN(self.lstm_cell, return_sequences=True)
        # self.lstm = layers.LSTM(40, return_sequences=True, input_shape=(98, 40))
        self.dense = layers.Dense(2, activation='softmax')
        self.query_vector = self.add_weight(
            shape=(40,),
            initializer=tf.keras.initializers.RandomNormal(),
            trainable=True,
            name="query_vector",
        )
        self.reshape = layers.Reshape((98, 40))  # Reshape for RNN compatibility

    def call(self, inputs):
        # Forward pass
        inputs = self.reshape(inputs)  # (batch_size, 98, 40)
        lstm_output = self.rnn(inputs)  # (batch_size, time_steps, 40)

        # Attention mechanism
        query_vector = tf.expand_dims(self.query_vector, axis=0)  # Shape: (1, 40)
        query_vector = tf.expand_dims(query_vector, axis=0)  # Shape: (1, 1, 40)
        attention_scores = tf.matmul(lstm_output, query_vector, transpose_b=True)  # Shape: (batch_size, time_steps, 1)
        attention_weights = tf.nn.softmax(attention_scores, axis=1)  # Shape: (batch_size, time_steps, 1)

        # Context vector
        context_vector = tf.reduce_sum(lstm_output * attention_weights, axis=1)  # Shape: (batch_size, 40)

        # Classification
        outputs = self.dense(context_vector)  # Shape: (batch_size, 2)
        return outputs

    def get_config(self):
        config = super(CustomAttentionModel, self).get_config()
        return config
    
    @classmethod
    def from_config(cls, config):
        return cls(**config)
    
    
# Instantiate and build the model
inputs = tf.keras.Input(shape=(98, 40, 1))
model = CustomAttentionModel()
outputs = model(inputs)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

In [115]:

model.compile(optimizer=Adam(learning_rate=CONFIG.CLASSIFIER_LR),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=2, restore_best_weights=True)


history = model.fit(train_dataset, verbose=2, epochs=CONFIG.EPOCH, validation_data=valid_dataset, callbacks=[early_stopping], steps_per_epoch=len(train_df)//CONFIG.BATCH_SIZE, validation_steps=len(valid_df)//CONFIG.BATCH_SIZE)

Epoch 1/200
53/53 - 9s - 177ms/step - accuracy: 0.7052 - loss: 0.6163 - val_accuracy: 0.7212 - val_loss: 0.5844
Epoch 2/200


2024-11-22 17:59:53.014867: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:53.014916: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 1s - 16ms/step - accuracy: 1.0000 - loss: 0.3721 - val_accuracy: 1.0000 - val_loss: 0.3544
Epoch 3/200


2024-11-22 17:59:53.305270: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:53.305326: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 6ms/step - accuracy: 0.7406 - loss: 0.5416 - val_accuracy: 0.7115 - val_loss: 0.5379
Epoch 4/200
53/53 - 0s - 217us/step - accuracy: 1.0000 - loss: 0.3679 - val_accuracy: 1.0000 - val_loss: 0.2641
Epoch 5/200


2024-11-22 17:59:53.636232: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:53.642851: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:53.642881: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.7500 - loss: 0.5005 - val_accuracy: 0.7115 - val_loss: 0.5004
Epoch 6/200
53/53 - 0s - 200us/step - accuracy: 1.0000 - loss: 0.3650 - val_accuracy: 1.0000 - val_loss: 0.2199
Epoch 7/200


2024-11-22 17:59:53.927119: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:53.927142: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.7642 - loss: 0.4677 - val_accuracy: 0.7308 - val_loss: 0.4701
Epoch 8/200
53/53 - 0s - 200us/step - accuracy: 1.0000 - loss: 0.3559 - val_accuracy: 1.0000 - val_loss: 0.1939
Epoch 9/200


2024-11-22 17:59:54.215730: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:54.215754: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:54.222071: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:54.222087: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.7972 - loss: 0.4413 - val_accuracy: 0.7404 - val_loss: 0.4522
Epoch 10/200
53/53 - 0s - 193us/step - accuracy: 1.0000 - loss: 0.3372 - val_accuracy: 1.0000 - val_loss: 0.1727
Epoch 11/200


2024-11-22 17:59:54.503523: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:54.503544: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:54.509728: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:54.509744: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8137 - loss: 0.4212 - val_accuracy: 0.7500 - val_loss: 0.4409
Epoch 12/200
53/53 - 0s - 202us/step - accuracy: 1.0000 - loss: 0.3192 - val_accuracy: 1.0000 - val_loss: 0.1576
Epoch 13/200


2024-11-22 17:59:54.792541: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:54.792577: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:54.798880: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:54.798896: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8231 - loss: 0.4033 - val_accuracy: 0.7500 - val_loss: 0.4367
Epoch 14/200
53/53 - 0s - 195us/step - accuracy: 1.0000 - loss: 0.2979 - val_accuracy: 1.0000 - val_loss: 0.1412
Epoch 15/200


2024-11-22 17:59:55.080305: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:55.080329: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:55.086511: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:55.086533: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8278 - loss: 0.3885 - val_accuracy: 0.7500 - val_loss: 0.4368
Epoch 16/200
53/53 - 0s - 199us/step - accuracy: 1.0000 - loss: 0.2779 - val_accuracy: 1.0000 - val_loss: 0.1293
Epoch 17/200


2024-11-22 17:59:55.369839: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:55.369864: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:55.376035: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:55.376050: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8349 - loss: 0.3748 - val_accuracy: 0.7500 - val_loss: 0.4373
Epoch 18/200
53/53 - 0s - 201us/step - accuracy: 1.0000 - loss: 0.2671 - val_accuracy: 1.0000 - val_loss: 0.1190
Epoch 19/200


2024-11-22 17:59:55.657904: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:55.657928: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:55.664202: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:55.664218: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8373 - loss: 0.3624 - val_accuracy: 0.7500 - val_loss: 0.4380
Epoch 20/200
53/53 - 0s - 198us/step - accuracy: 1.0000 - loss: 0.2602 - val_accuracy: 1.0000 - val_loss: 0.1039
Epoch 21/200


2024-11-22 17:59:55.952474: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:55.952512: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8491 - loss: 0.3498 - val_accuracy: 0.7596 - val_loss: 0.4398
Epoch 22/200
53/53 - 0s - 200us/step - accuracy: 1.0000 - loss: 0.2501 - val_accuracy: 1.0000 - val_loss: 0.0849
Epoch 23/200


2024-11-22 17:59:56.240497: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:56.240521: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8491 - loss: 0.3373 - val_accuracy: 0.7596 - val_loss: 0.4432
Epoch 24/200
53/53 - 0s - 199us/step - accuracy: 1.0000 - loss: 0.2434 - val_accuracy: 1.0000 - val_loss: 0.0681
Epoch 25/200


2024-11-22 17:59:56.531259: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:56.531282: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8538 - loss: 0.3255 - val_accuracy: 0.7596 - val_loss: 0.4472
Epoch 26/200
53/53 - 0s - 199us/step - accuracy: 1.0000 - loss: 0.2407 - val_accuracy: 1.0000 - val_loss: 0.0569
Epoch 27/200


2024-11-22 17:59:56.820360: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:56.820385: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8561 - loss: 0.3124 - val_accuracy: 0.7500 - val_loss: 0.4558
Epoch 28/200
53/53 - 0s - 200us/step - accuracy: 1.0000 - loss: 0.2354 - val_accuracy: 1.0000 - val_loss: 0.0457
Epoch 29/200


2024-11-22 17:59:57.103093: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:57.103130: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:57.109425: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:57.109440: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8561 - loss: 0.3042 - val_accuracy: 0.7404 - val_loss: 0.4695
Epoch 30/200
53/53 - 0s - 200us/step - accuracy: 1.0000 - loss: 0.2336 - val_accuracy: 1.0000 - val_loss: 0.0370
Epoch 31/200


2024-11-22 17:59:57.393135: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:57.393156: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:57.399479: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:57.399496: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8585 - loss: 0.2955 - val_accuracy: 0.7596 - val_loss: 0.4781
Epoch 32/200
53/53 - 0s - 199us/step - accuracy: 1.0000 - loss: 0.2314 - val_accuracy: 1.0000 - val_loss: 0.0273
Epoch 33/200


2024-11-22 17:59:57.685232: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:57.685255: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:57.691531: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:57.691547: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8561 - loss: 0.2987 - val_accuracy: 0.7500 - val_loss: 0.4865
Epoch 34/200
53/53 - 0s - 189us/step - accuracy: 1.0000 - loss: 0.2324 - val_accuracy: 1.0000 - val_loss: 0.0398
Epoch 35/200


2024-11-22 17:59:57.976832: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:57.976856: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615
2024-11-22 17:59:57.983233: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 17:59:57.983248: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


53/53 - 0s - 5ms/step - accuracy: 0.8608 - loss: 0.2856 - val_accuracy: 0.7500 - val_loss: 0.4667
Epoch 35: early stopping
Restoring model weights from the end of the best epoch: 32.


In [116]:
test_loss, test_accuracy = model.evaluate(test_dataset)

[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 117ms/step - accuracy: 0.7644 - loss: 0.4906


2024-11-22 18:00:03.411261: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 1964356605870186667
2024-11-22 18:00:03.411288: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 14900770143387590615


In [117]:
inputs = tf.random.normal([32, 98, 40, 1])  # Example input
model(inputs)  # Build the model by passing inputs
model.save('./vad_slimnet.h5')




In [118]:
model.export('./vad_slimnet')

INFO:tensorflow:Assets written to: ./vad_slimnet/assets


INFO:tensorflow:Assets written to: ./vad_slimnet/assets


Saved artifact at './vad_slimnet'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 98, 40, 1), dtype=tf.float32, name='keras_tensor_12')
Output Type:
  TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)
Captures:
  139327687697936: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139329709788672: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139327687699168: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139329709415184: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139327687688784: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139327687703216: TensorSpec(shape=(), dtype=tf.resource, name=None)
