In [188]:
import os, fnmatch
from pathlib import Path
import datetime
from random import shuffle, seed
import numpy as np
from wavinfo import WavInfoReader
import tensorflow as tf
import tensorflow_io as tfio
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map, thread_map
from threading import Lock
from multiprocessing import Pool, RLock, freeze_support
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout,\
    Lambda, Input, Multiply, Layer, Conv1D
from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger,\
    EarlyStopping, ModelCheckpoint

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, GRU
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.backend import clear_session
from librosa.util import frame

In [189]:
class audio_generator:
    '''
    Class to create a Tensorflow dataset based on an iterator from a large scale
    audio dataset. This audio generator only supports single channel audio files.
    '''
    def __init__(self, path_to_input, path_to_s1, len_of_chunks, nfft, window_size_samples, stride_size_samples, train_flag=False):
        '''
        Constructor of the audio generator class.
        Inputs:
            path_to_input       path to the mixtures
            path_to_s1          path to the target source data
            len_of_samples      length of audio snippets in samples
            fs                  sampling rate
            train_flag          flag for activate shuffling of files
        '''
        # set inputs to properties
        self.total_samples = 0
        self.path_to_input = path_to_input
        self.path_to_s1 = path_to_s1
        self.len_of_chunks = len_of_chunks
        self.nfft = nfft
        self.window_size_samples = window_size_samples
        self.stride_size_samples = stride_size_samples
        self.train_flag=train_flag
        # count the number of samples in your data set (depending on your disk,
        #                                               this can take some time)
        self.mutex = Lock()
        self.count_samples()
        # create iterable tf.data.Dataset object
        self.create_tf_data_obj()

    def number_of_chunks(self, len_of_audio, chunk_size, stride_size):
        return 1 + int((len_of_audio - chunk_size) / stride_size)

    def process_file(self, filename):
        info = WavInfoReader(os.path.join(self.path_to_input, filename))
        num_of_chunks = int(np.fix(self.number_of_chunks(info.data.frame_count, self.window_size_samples, self.stride_size_samples)/self.len_of_chunks))
        num_of_chunks = self.number_of_chunks(num_of_chunks, self.len_of_chunks, self.len_of_chunks)
        self.mutex.acquire()
        self.total_samples = self.total_samples + num_of_chunks
        self.mutex.release()

    def count_samples(self):
        '''
        Method to list the data of the dataset and count the number of samples.
        '''

        # list .wav files in directory
        self.file_names = fnmatch.filter(os.listdir(self.path_to_input), '*.wav')
        # count the number of samples contained in the dataset
        self.total_samples = 0

        data_path = os.path.join(self.path_to_input, self.file_names[0])
        self.data_shape = self.convert_audio_to_spectrogram(str(data_path), self.nfft, self.window_size_samples, self.stride_size_samples).shape
        # freeze_support()  # for Windows support
        # thread_map(self.process_file, self.file_names, chunksize=16)
        self.total_samples = 1700
        print(f"Found {self.total_samples} different chunks")

    def convert_audio_to_spectrogram(self, filepath, nfft, window_size_samples, stride_size_samples):
        audio = tfio.audio.AudioIOTensor(filepath)
        audio_slice = audio[:]
        audio_tensor = tf.squeeze(audio_slice, axis=[-1])
        if audio_tensor.dtype != tf.float32:
            audio_tensor = tf.cast(audio_tensor, tf.float32) / 32768.0
        spectrogram = tfio.audio.spectrogram(
            audio_tensor, nfft=nfft, window=window_size_samples, stride=stride_size_samples)

        spectrogram = tf.math.log(spectrogram).numpy()
        return spectrogram

    def create_generator(self):
        '''
        Method to create the iterator.
        '''

        # check if training or validation
        if self.train_flag:
            shuffle(self.file_names)
        # iterate over the files
        for file in self.file_names:
            noisy_path = os.path.join(self.path_to_input, file)
            clean_file = file.replace("_xdata", "_ydata")
            clean_path = os.path.join(self.path_to_s1, clean_file)

            # Calc STFT
            noisy_stft = self.convert_audio_to_spectrogram(str(noisy_path), self.nfft, self.window_size_samples, self.stride_size_samples)
            clean_stft = self.convert_audio_to_spectrogram(str(clean_path), self.nfft, self.window_size_samples, self.stride_size_samples)

            if noisy_stft.shape != clean_stft.shape:
                raise ValueError('Data shapes do not match.')

            # Count number of (len_of_chunks, NFFT/2+1) chunks in STFT
            segmented_data_x = frame(noisy_stft, frame_length=self.len_of_chunks, hop_length=self.len_of_chunks, axis=0)
            segmented_data_y = frame(clean_stft, frame_length=self.len_of_chunks, hop_length=self.len_of_chunks, axis=0)

            if np.isnan(segmented_data_x).any():
                print(segmented_data_x)
                raise ValueError("NaN")

            if np.isnan(segmented_data_y).any():
                print(segmented_data_y)
                raise ValueError("NaN")

            if segmented_data_x.shape != segmented_data_y.shape:
                raise ValueError('Data shapes do not match.')

            # iterate over the number of samples
            for idx in range(segmented_data_x.shape[0]):
                # yield the chunks as float32 data
                yield segmented_data_x[idx].astype('float32'), segmented_data_y[idx].astype('float32')


    def create_tf_data_obj(self):
        '''
        Method to to create the tf.data.Dataset.
        '''
        # creating the tf.data.Dataset from the iterator
        self.tf_data_set = tf.data.Dataset.from_generator(
            self.create_generator,
            (tf.float32, tf.float32),
            output_shapes=(tf.TensorShape([self.len_of_chunks, self.data_shape[1]]),
                           tf.TensorShape([self.len_of_chunks, self.data_shape[1]])),
            args=None)


In [190]:
class NAEC_model:
    def __init__(self):
        # defining default cost function
        self.cost_function = self.snr_cost
        # empty property for the model
        self.model = []
        # defining default parameters
        self.batchsize = 32

        self.chunk_len = 10
        self.nfft = 512
        self.block_shift = int(self.nfft * 0.5)

        self.activation = 'sigmoid'

        self.numUnits = 128
        self.numLayer = 2

        self.dropout = 0.25
        self.lr = 1e-5
        self.max_epochs = 1
        self.encoder_size = 256
        self.eps = 1e-7

        self.mse_loss = keras.losses.MeanSquaredError()
        # reset all seeds to 42 to reduce invariance between training runs
        os.environ['PYTHONHASHSEED'] = str(42)
        seed(42)
        np.random.seed(42)

    @staticmethod
    def snr_cost(s_estimate, s_true, s_input):
        '''
        Static Method defining the cost function.
        The negative signal to noise ratio is calculated here. The loss is
        always calculated over the last dimension.
        '''

        # # calculating the SNR
        # snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) /\
        #       (tf.reduce_mean(tf.math.square(s_true - s_estimate), axis=-1, keepdims=True) + 1e-7)
        #
        # # using some more lines, because TF has no log10
        # num = tf.math.log(snr)
        # denom = tf.math.log(tf.constant(10, dtype=num.dtype))
        # loss = -10 * (num / denom)
        # # returning the loss
        out_vals = s_input * s_estimate
        loss = keras.losses.MSE(out_vals, s_estimate)
        return loss


    def lossWrapper(self):
        '''
        A wrapper function which returns the loss function. This is done to
        to enable additional arguments to the loss function if necessary.
        '''

        def lossFunction(y_true, y_pred):
            # calculating loss and squeezing single dimensions away
            loss = tf.squeeze(self.cost_function(y_pred, y_true))
            # calculate mean over batches
            loss = tf.reduce_mean(loss)
            # return the loss
            # pred_data, input_data = y_pred
            # out_vals = input_data * pred_data
            # loss = keras.losses.MSE(y_pred, y_true)
            return loss

        # returning the loss function as handle
        return lossFunction

    def create_model(self):
        '''
        Method to build and compile the NAEC model. The model takes frequency
        domain batches of size (batchsize, chunk_len, nfft/2 + 1) and returns
        enhanced clips in the same dimensions.
        '''

        # input layer for time signal
        input_output_shape = (self.chunk_len, int(self.nfft / 2) + 1)
        freq_dat = Input(batch_shape=(None, self.chunk_len, int(self.nfft / 2) + 1))
        self.input_freq_layer = freq_dat

        # gru_layer = GRU(int(self.nfft / 2) + 1, activation='tanh', input_shape=input_output_shape, return_sequences=True)(freq_dat)
        mag_norm = InstantLayerNormalization()(freq_dat)
        gru_layer = LSTM(int(self.nfft / 2) + 1, activation='relu', input_shape=input_output_shape, return_sequences=True)(mag_norm)
        estimated_mag = Multiply()([freq_dat, gru_layer])

        self.model = tf.keras.Model(inputs=freq_dat, outputs=estimated_mag)
        # show the model summary
        self.model.summary()


    def compile_model(self):
        '''
        Method to compile the model for training
        '''
        # use the Adam optimizer with a clipnorm of 3
        optimizer_adam = keras.optimizers.Adam(learning_rate=self.lr, clipnorm=1, decay=1e-3)
        # compile model with loss function
        # self.model.compile(loss=self.lossWrapper(), optimizer=optimizer_adam, metrics=['accuracy'])
        self.model.compile(loss='mean_squared_error', optimizer=optimizer_adam, metrics=['accuracy'])
        # model.compile(loss=custom_loss(input_layer),optimizer=opt, metrics=['accuracy'])


    def train_model(self, run_name, path_to_train_x, path_to_train_y,
                    path_to_val_x, path_to_val_y):
        '''
        Method to train the DTLN model.
        '''
        print("Train NAEC model")
        # create save path if not existent
        ct = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        save_path = str(Path("./models/" + str(ct)).mkdir(parents=True, exist_ok=True)) + '/' + run_name + '/'
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        print(f"Save path: {save_path}")

        # create log file writer
        csv_logger = CSVLogger(save_path + 'training_' + run_name + '.log')
        # create callback for the adaptive learning rate
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                      patience=3, min_lr=10 ** (-10), cooldown=1)
        # create callback for early stopping
        early_stopping = EarlyStopping(monitor='val_loss', min_delta=0,
                                       patience=10, verbose=0, mode='auto', baseline=None)
        # create model check pointer to save the best model
        checkpointer = ModelCheckpoint(save_path + run_name + '.h5',
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True,
                                       save_weights_only=False,
                                       mode='auto',
                                       save_freq='epoch'
                                       )

        # create data generator for training data
        print("Load training data")
        generator_input = audio_generator(path_to_train_x,
                                          path_to_train_y,
                                          self.chunk_len,
                                          self.nfft,
                                          self.nfft,
                                          self.block_shift,
                                          train_flag=True)
        dataset = generator_input.tf_data_set
        dataset = dataset.batch(self.batchsize, drop_remainder=True).repeat()
        # calculate number of training steps in one epoch
        steps_train = generator_input.total_samples // self.batchsize

        print("Training generator: {}".format(dataset))
        print("Training steps count: {}".format(steps_train))

        # create data generator for validation data
        print("Load validation data")
        generator_val = audio_generator(path_to_val_x,
                                        path_to_val_y,
                                        self.chunk_len,
                                        self.nfft,
                                        self.nfft,
                                        self.block_shift,
                                        train_flag=False)
        dataset_val = generator_val.tf_data_set
        dataset_val = dataset_val.batch(self.batchsize, drop_remainder=True).repeat()
        # calculate number of validation steps
        steps_val = generator_val.total_samples // self.batchsize

        print("Validation generator: {}".format(dataset_val))
        print("Validation steps count: {}".format(steps_val))

        # start the training of the model
        self.model.fit(
            x=dataset,
            batch_size=None,
            steps_per_epoch=steps_train,
            epochs=self.max_epochs,
            verbose=1,
            validation_data=dataset_val,
            validation_steps=steps_val,
            callbacks=[checkpointer, reduce_lr, early_stopping],
            max_queue_size=50,
            workers=4,
            use_multiprocessing=True)
        # clear out garbage
        tf.keras.backend.clear_session()

class InstantLayerNormalization(Layer):
    '''
    Class implementing instant layer normalization. It can also be called
    channel-wise layer normalization and was proposed by
    Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2)
    '''

    def __init__(self, **kwargs):
        '''
            Constructor
        '''
        super(InstantLayerNormalization, self).__init__(**kwargs)
        self.epsilon = 1e-7
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        '''
        Method to build the weights.
        '''
        shape = input_shape[-1:]
        # initialize gamma
        self.gamma = self.add_weight(shape=shape,
                                     initializer='ones',
                                     trainable=True,
                                     name='gamma')
        # initialize beta
        self.beta = self.add_weight(shape=shape,
                                    initializer='zeros',
                                    trainable=True,
                                    name='beta')


    def call(self, inputs):
        '''
        Method to call the Layer. All processing is done here.
        '''

        # calculate mean of each frame
        mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
        # calculate variance of each frame
        variance = tf.math.reduce_mean(tf.math.square(inputs - mean),
                                       axis=[-1], keepdims=True)
        # calculate standard deviation
        std = tf.math.sqrt(variance + self.epsilon)
        # normalize each frame independently
        outputs = (inputs - mean) / std
        # scale with gamma
        outputs = outputs * self.gamma
        # add the bias beta
        outputs = outputs + self.beta
        # return output
        return outputs

In [191]:
CUR_PATH = os.getcwd()
TRAIN_DATA_PATH_PREFIX = CUR_PATH.replace("notebooks", "data") + "/datasets/train/"
VAL_DATA_PATH_PREFIX = CUR_PATH.replace("notebooks", "data") + "/datasets/validation/"
X_DATA_PATH = (TRAIN_DATA_PATH_PREFIX + "x_data").replace('\\', '/')
Y_DATA_PATH = (TRAIN_DATA_PATH_PREFIX + "y_data").replace('\\', '/')

VAL_X_DATA_PATH = (VAL_DATA_PATH_PREFIX + "x_data").replace('\\', '/')
VAL_Y_DATA_PATH = (VAL_DATA_PATH_PREFIX + "y_data").replace('\\', '/')

In [192]:
model_trainer = NAEC_model()
model_trainer.create_model()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 10, 257)]    0                                            
__________________________________________________________________________________________________
instant_layer_normalization (In (None, 10, 257)      514         input_1[0][0]                    
__________________________________________________________________________________________________
lstm (LSTM)                     (None, 10, 257)      529420      instant_layer_normalization[0][0]
__________________________________________________________________________________________________
multiply (Multiply)             (None, 10, 257)      0           input_1[0][0]                    
                                                                 lstm[0][0]                   

In [193]:
model_trainer.compile_model()

In [194]:
runName = 'NAEC_model'
model_trainer.train_model(runName, X_DATA_PATH, Y_DATA_PATH,
                         VAL_X_DATA_PATH, VAL_Y_DATA_PATH)

Train NAEC model
Save path: None/NAEC_model/
Load training data
Found 1700 different chunks
Training generator: <RepeatDataset shapes: ((32, 10, 257), (32, 10, 257)), types: (tf.float32, tf.float32)>
Training steps count: 53
Load validation data
Found 1700 different chunks
Validation generator: <RepeatDataset shapes: ((32, 10, 257), (32, 10, 257)), types: (tf.float32, tf.float32)>
Validation steps count: 53

Epoch 00001: val_loss did not improve from inf
