# DENOISE USING SEGAN WRITTEN IN TENSORFLOW II



![](https://contentgroup.com.au/wp-content/uploads/2019/08/speech.jpg)

## Introduction
- Speech enhancement aims to improve speech quality using various algorithms. Its objective is to improve the intelligibility and/or overall perceptual quality of a degraded speech signal using audio signal processing techniques.

- Enhancement of speech degraded by noise, i.e. noise reduction, is the most important field of speech enhancement and is used in many applications such as mobile phones, VoIP, teleconferencing systems, speech recognition, and hearing aids.


### Algorithms
Speech enhancement algorithms for noise reduction can be categorized into three fundamental classes: filtering techniques, spectral restoration, and model-based methods.

 - Filtering Techniques
   - Spectral Subtraction Method
   - Wiener Filtering
   - Signal subspace approach (SSA)
 - Spectral Restoration
   - Minimum Mean-Square-Error Short-Time Spectral Amplitude Estimator (MMSE-STSA)
 - Speech-Model-Based

### Segan
 - GANs (Generative Adversarial Networks) have achieved impressive results in image generation, image super-resolution, and other tasks.
 - GANs have also produced good results in the audio domain. SEGAN (Speech Enhancement GAN) is a GAN model for speech enhancement.

 
### Segan samples
 http://veu.talp.cat/segan/

# Necessary libraries

In [16]:
# Import libraries
from __future__ import absolute_import

import numpy as np
import glob
import os
import time
import sys
import pathlib
# NOTE(review): this binds the top-level matplotlib package (not
# matplotlib.pyplot) to the name `plt`; pyplot functions are not reachable
# through it. `plt` is not used in the cells shown here.
import matplotlib as plt
import math
import tensorflow as tf
print(tf.__version__)
# NOTE(review): tf.test.is_gpu_available is deprecated in TF 2.x;
# tf.config.list_physical_devices('GPU') is the documented replacement.
print('GPU available:', tf.test.is_gpu_available())

# Presumably additional packages the pipeline needs:
# sox
# soundfile
# tensorflow-addons
from IPython.display import Audio

import random
# Fixed seed for Python's `random` module (noise selection, SNR choice, noise
# crop offsets). numpy's global RNG, used for white noise, is not seeded here.
random.seed(2020)

# Download necessary packages
# !wget https://github.com/usimarit/semetrics
#!pip install librosa
#!pip install tensorflow
import librosa
## app for test
#!apt install octave

2.2.0
GPU available: False


In [17]:
# Mount Google Drive. The config's checkpoint_dir / log_dir point into
# "./drive/My Drive/...", which presumably resolves to this mount point
# (Colab's working directory is /content) -- confirm if run elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# !mkdir Data
# !mkdir Data/Preprocessed
# !mkdir Data/Preprocessed/Noises
# !mkdir Data/Preprocessed/VivosCleanTrain
# !mkdir Data/Preprocessed/VivosCleanTest

# !mkdir trained
# !mkdir trained/logs
# !mkdir trained/ckpts

## SETUP CONFIG

In [0]:
#SEGAN CONFIG
config_dic={'batch_size' : 32,
'num_epochs' : 86,
'kwidth' : 31,
'ratio' : 2,
'noise_std' : 0.,
'denoise_epoch' : 5,
'noise_decay' : 0.9,
'noise_std_lbound' : 0.,
'l1_lambda' : 100.,
'pre_emph' : 0.95,
'window_size' : 2 ** 14,
'sample_rate' : 16000,
'stride' : 0.5,
'noise_conf' : {
    "snr": (-1, 2.5, 7.5, 12.5, 17.5),
    "max_noises": 3},
'noises_dir' : "./Data/Preprocessed/Noises",
'g_learning_rate' : 0.0002,
'd_learning_rate' : 0.0002,
'clean_train_data_dir' : "./Data/Preprocessed/VivosCleanTrain",
'clean_test_data_dir' : "./Data/Preprocessed/ViosCleanTest",
'checkpoint_dir' : "./drive/My Drive/trained/ckpts",
'log_dir' : "./drive/My Drive/trained/log",
}

# Keys a SEGAN configuration is expected to provide.
# NOTE(review): not referenced in the cells shown here.
segan_conf_required = [
    "kwidth",
    "ratio",
    "noise_std",
    "denoise_epoch",
    "noise_decay",
    "noise_std_lbound",
    "l1_lambda",
    "pre_emph",
    "window_size",
    "sample_rate",
    "stride",
    "noise_conf",
]

# Config keys whose values are filesystem paths; get_segan_config expands
# and absolutises each of them.
segan_conf_paths = [
    "clean_train_data_dir",
    "noises_dir",
    "clean_test_data_dir",
    "checkpoint_dir",
    "log_dir",
]

# Fallback noise settings merged into a dataset's noise_conf for missing keys.
# SNRs are in dB; a negative value means "no noise" for that draw.
DEFAULT_NOISE_CONF = dict(
    snr=(-1, 0, 5, 10, 15),
    max_noises=3,
)

# UTILS

In [0]:
# preprocess_paths
def preprocess_paths(paths):
    """Expand ``~`` and make paths absolute.

    Accepts either a single path or a list of paths; a list in gives a list
    out, anything else is treated as one path.
    """
    def _normalize(one_path):
        return os.path.abspath(os.path.expanduser(one_path))

    if not isinstance(paths, list):
        return _normalize(paths)
    return list(map(_normalize, paths))

def get_segan_config(config_dict):
    """Normalise the path-valued entries of ``config_dict`` in place.

    Every key listed in ``segan_conf_paths`` is rewritten through
    ``preprocess_paths``; the same (mutated) dict is returned. A missing key
    raises ``KeyError``.
    """
    for path_key in segan_conf_paths:
        raw_value = config_dict[path_key]
        config_dict[path_key] = preprocess_paths(raw_value)
    return config_dict


# slice_signal
def slice_signal(signal, window_size, stride=0.5):
    """Cut a 1-D signal into fixed-length windows.

    Windows start every ``offset = int(window_size * stride)`` samples. A
    window that runs past the end of ``signal`` is zero-padded on the right to
    ``window_size``. Signals with at most ``window_size - offset`` samples
    produce no windows.

    Args:
        signal: 1-D numpy array.
        window_size: samples per window.
        stride: hop between window starts, as a fraction of ``window_size``.

    Returns:
        float32 array of shape ``[num_windows, window_size]``. The shape stays
        2-D even when ``num_windows == 0``.

    Raises:
        ValueError: if the hop ``int(window_size * stride)`` is smaller than 1.
    """
    assert signal.ndim == 1, signal.ndim
    n_samples = signal.shape[0]
    offset = int(window_size * stride)
    if offset < 1:
        raise ValueError(f"hop int(window_size * stride) must be >= 1, got {offset}")
    slices = []
    for beg_i, end_i in zip(range(0, n_samples, offset),
                            range(window_size, n_samples + offset, offset)):
        slice_ = signal[beg_i:end_i]
        # end_i - beg_i == window_size, so a slice is only ever too short.
        missing = window_size - slice_.shape[0]
        if missing > 0:
            slice_ = np.pad(slice_, (0, missing), 'constant', constant_values=0.0)
        slices.append(slice_)
    # reshape keeps the result 2-D when no window was produced.
    return np.array(slices, dtype=np.float32).reshape(-1, window_size)

# merge_slices
def merge_slices(slices):
    """Lay a batch of windows end to end as one 1-D tensor.

    ``slices`` is expected to be ``[batch, window_size]``; the result has
    ``batch * window_size`` elements. Windows are concatenated without
    overlap-add, so this reconstructs the original signal only when the
    windows did not overlap.
    """
    flat = tf.reshape(slices, [-1])
    return flat

# append_default_keys_dict
def append_default_keys_dict(default_dict, dest_dict):
    """Copy into ``dest_dict`` every key of ``default_dict`` it lacks.

    Existing keys of ``dest_dict`` are left untouched. ``dest_dict`` is
    modified in place and also returned.
    """
    for default_key, default_value in default_dict.items():
        dest_dict.setdefault(default_key, default_value)
    return dest_dict

# Preemphasis
def preemphasis(signal: np.ndarray, coeff=0.97):
    """Apply the first-order pre-emphasis filter ``y[n] = x[n] - coeff * x[n-1]``.

    The first sample passes through unchanged. ``deemphasis`` with the same
    positive ``coeff`` inverts this filter.

    Args:
        signal: 1-D numpy array of samples.
        coeff: filter coefficient; ``None`` or ``0`` disables the filter.

    Returns:
        A new filtered array. ``signal`` itself is returned when the filter is
        disabled or ``signal`` is empty.
    """
    if not coeff:
        return signal
    if len(signal) == 0:
        # Nothing to filter; signal[0] below would raise IndexError.
        return signal
    return np.append(signal[0], signal[1:] - coeff * signal[:-1])

# deemphasis
def deemphasis(signal: np.ndarray, coeff=0.97):
    """Invert ``preemphasis``: ``x[n] = coeff * x[n-1] + signal[n]``.

    Args:
        signal: 1-D numpy array of pre-emphasised samples.
        coeff: filter coefficient; ``None`` or a value ``<= 0`` disables the
            filter (consistent with ``preemphasis`` treating ``None``/``0`` as
            "off").

    Returns:
        A new float32 array of the same length, or ``signal`` itself when the
        filter is disabled.
    """
    if not coeff or coeff <= 0:
        return signal
    x = np.zeros(signal.shape[0], dtype=np.float32)
    if x.shape[0] == 0:
        # Empty input: nothing to integrate (x[0] below would raise).
        return x
    x[0] = signal[0]
    # Recursive (IIR) filter, so each sample depends on the previous output.
    # NOTE(review): a Python-level loop; scipy.signal.lfilter([1], [1, -coeff], signal)
    # would be much faster if scipy becomes a dependency.
    for n in range(1, signal.shape[0]):
        x[n] = coeff * x[n - 1] + signal[n]
    return x

# Read_raw_audio
def read_raw_audio(audio, sample_rate=16000):
    """Load audio as a mono waveform at ``sample_rate``.

    Args:
        audio: either a filesystem path (``str`` or ``os.PathLike``; ``~`` is
            expanded), decoded with ``librosa.load``, or the raw bytes of an
            encoded audio file, decoded with ``soundfile``.
        sample_rate: target sampling rate in Hz.

    Returns:
        1-D float32 waveform (``librosa.load`` defaults: mono, float32; the
        bytes branch is converted to match).

    Raises:
        ValueError: if ``audio`` is neither a path nor bytes.
    """
    if isinstance(audio, (str, os.PathLike)):
        wave, _ = librosa.load(os.path.expanduser(audio), sr=sample_rate)
    elif isinstance(audio, bytes):
        # Local imports: only this branch needs them, and neither module is
        # imported at notebook level.
        import io
        import soundfile as sf

        wave, sr = sf.read(io.BytesIO(audio), dtype="float32")
        if wave.ndim > 1:
            # soundfile returns [frames, channels]; down-mix to mono.
            wave = wave.mean(axis=1)
        if sr != sample_rate:
            # Keyword arguments: they are keyword-only in librosa >= 0.10.
            wave = librosa.resample(wave, orig_sr=sr, target_sr=sample_rate)
    else:
        raise ValueError("input audio must be either a path or bytes")
    return wave

# NOTE(review): this re-defines read_raw_audio; being later in the cell, this
# definition is the one in effect at runtime. One copy can be removed.
def read_raw_audio(audio, sample_rate=16000):
    """Load audio as a mono waveform at ``sample_rate``.

    Args:
        audio: either a filesystem path (``str`` or ``os.PathLike``; ``~`` is
            expanded), decoded with ``librosa.load``, or the raw bytes of an
            encoded audio file, decoded with ``soundfile``.
        sample_rate: target sampling rate in Hz.

    Returns:
        1-D float32 waveform (``librosa.load`` defaults: mono, float32; the
        bytes branch is converted to match).

    Raises:
        ValueError: if ``audio`` is neither a path nor bytes.
    """
    if isinstance(audio, (str, os.PathLike)):
        wave, _ = librosa.load(os.path.expanduser(audio), sr=sample_rate)
    elif isinstance(audio, bytes):
        # Local imports: only this branch needs them, and neither module is
        # imported at notebook level.
        import io
        import soundfile as sf

        wave, sr = sf.read(io.BytesIO(audio), dtype="float32")
        if wave.ndim > 1:
            # soundfile returns [frames, channels]; down-mix to mono.
            wave = wave.mean(axis=1)
        if sr != sample_rate:
            # Keyword arguments: they are keyword-only in librosa >= 0.10.
            wave = librosa.resample(wave, orig_sr=sr, target_sr=sample_rate)
    else:
        raise ValueError("input audio must be either a path or bytes")
    return wave

# NoiseAugment
def get_white_noise(signal: np.ndarray, snr: float = 10):
    """Generate additive white Gaussian noise at a target SNR.

    Args:
        signal: 1-D clean signal the noise is meant to be added to.
        snr: desired signal-to-noise ratio in dB. Negative values are used by
            the noise configs (e.g. ``-1``) to mean "no noise".

    Returns:
        Noise array with ``signal.shape[0]`` samples, or ``None`` if ``snr < 0``.
    """
    if snr < 0:
        return None
    rms_signal = math.sqrt(np.mean(signal ** 2))
    # SNR_dB = 20 * log10(RMS_s / RMS_n)  =>  RMS_n = RMS_s / 10 ** (SNR_dB / 20).
    # (Equivalently, the *power* ratio is 10 ** (SNR_dB / 10).)
    rms_noise = rms_signal / (10 ** (snr / 20))
    # Zero-mean Gaussian noise: its standard deviation equals its RMS.
    return np.random.normal(0, rms_noise, signal.shape[0])

def get_noise_from_sound(signal: np.ndarray, noise: np.ndarray, snr: float = 10):
    """Crop a random segment of a noise recording and scale it to a target SNR.

    Args:
        signal: 1-D clean signal the noise is meant to be added to.
        noise: 1-D noise recording, at least as long as ``signal``.
        snr: desired SNR in dB; a negative value means "no noise".

    Returns:
        Scaled noise segment with ``len(signal)`` samples, or ``None`` when
        ``snr < 0`` or ``noise`` is shorter than ``signal``.
    """
    if snr < 0 or len(noise) < len(signal):
        return None

    # Random crop. randint's upper bound is inclusive, so every valid start
    # offset (including len(noise) - len(signal)) can be chosen.
    idx = random.randint(0, len(noise) - len(signal))
    noise = noise[idx:idx + len(signal)]

    rms_signal = math.sqrt(np.mean(signal ** 2))
    # Required noise RMS: SNR_dB = 20 * log10(RMS_s / RMS_n).
    rms_noise_target = rms_signal / (10 ** (snr / 20))
    rms_noise_current = math.sqrt(np.mean(noise ** 2))
    # The epsilon guards against an all-zero noise segment.
    return noise * (rms_noise_target / (rms_noise_current + 1e-20))

# NOTE(review): this re-defines get_noise_from_sound; being later in the cell,
# this definition is the one add_noise / add_realworld_noise call at runtime.
def get_noise_from_sound(signal: np.ndarray, noise: np.ndarray, snr: float = 10):
    """Crop a random segment of a noise recording and scale it to a target SNR.

    Args:
        signal: 1-D clean signal the noise is meant to be added to.
        noise: 1-D noise recording, at least as long as ``signal``.
        snr: desired SNR in dB; a negative value means "no noise".

    Returns:
        Scaled noise segment with ``len(signal)`` samples, or ``None`` when
        ``snr < 0`` or ``noise`` is shorter than ``signal``.
    """
    if snr < 0 or len(noise) < len(signal):
        return None

    # Random crop. randint's upper bound is inclusive, so every valid start
    # offset (including len(noise) - len(signal)) can be chosen.
    idx = random.randint(0, len(noise) - len(signal))
    noise = noise[idx:idx + len(signal)]

    rms_signal = math.sqrt(np.mean(signal ** 2))
    # Required noise RMS: SNR_dB = 20 * log10(RMS_s / RMS_n).
    rms_noise_target = rms_signal / (10 ** (snr / 20))
    rms_noise_current = math.sqrt(np.mean(noise ** 2))
    # The epsilon guards against an all-zero noise segment.
    return noise * (rms_noise_target / (rms_noise_current + 1e-20))

def add_noise(signal: np.ndarray, noises: list, snr_list: list, max_noises: int, sample_rate=16000, *args, **kwargs):
    """Mix a random number of randomly chosen noises into ``signal``.

    Args:
        signal: 1-D clean signal.
        noises: candidate noises; each entry is either the string
            ``"white_noise"`` or a path to a noise audio file. The list is not
            modified.
        snr_list: SNRs (dB) drawn independently for each selected noise; a
            negative value skips that noise.
        max_noises: upper bound on the number of noises. A count in
            ``[0, max_noises]`` is drawn and capped at ``len(noises)``.
        sample_rate: rate used when loading noise files.
        *args, **kwargs: accepted and ignored.

    Returns:
        The noisy signal (``signal`` itself when no noise was added).
    """
    clean = signal
    num_noises = min(random.randint(0, max_noises), len(noises))
    # Distinct noises, sampled without replacement (the cap above only makes
    # sense that way) and without shuffling the caller's list in place.
    selected_noises = random.sample(noises, num_noises)
    for noise_type in selected_noises:
        snr = random.choice(snr_list)
        if noise_type == "white_noise":
            noise = get_white_noise(clean, snr)
        else:
            noise = read_raw_audio(noise_type, sample_rate=sample_rate)
            noise = get_noise_from_sound(clean, noise, snr)
        if noise is not None:
            # Every SNR is measured against the clean input, not against the
            # partially noised signal from earlier iterations.
            signal = np.add(signal, noise)
    return signal


def add_white_noise(signal: np.ndarray, snr_list: list, *args, **kwargs):
    """Add white Gaussian noise at an SNR drawn at random from ``snr_list``.

    Extra positional and keyword arguments are accepted and ignored. When the
    drawn SNR is negative no noise is generated and ``signal`` is returned
    unchanged.
    """
    white = get_white_noise(signal, random.choice(snr_list))
    return signal if white is None else np.add(signal, white)

def add_realworld_noise(signal: np.ndarray, noises: list, snr_list: list, max_noises: int, sample_rate=16000, *args, **kwargs):
    """Mix a random number of randomly chosen noise recordings into ``signal``.

    Args:
        signal: 1-D clean signal.
        noises: paths to noise audio files. The list is not modified.
        snr_list: SNRs (dB) drawn independently for each selected noise; a
            negative value skips that noise.
        max_noises: upper bound on the number of noises. A count in
            ``[0, max_noises]`` is drawn and capped at ``len(noises)``.
        sample_rate: rate used when loading noise files.
        *args, **kwargs: accepted and ignored.

    Returns:
        The noisy signal (``signal`` itself when no noise was added).
    """
    clean = signal
    num_noises = min(random.randint(0, max_noises), len(noises))
    # Distinct noises, sampled without replacement and without shuffling the
    # caller's list in place.
    for noise_path in random.sample(noises, num_noises):
        snr = random.choice(snr_list)
        noise = read_raw_audio(noise_path, sample_rate=sample_rate)
        # SNR is measured against the clean input, not the partially noised signal.
        noise = get_noise_from_sound(clean, noise, snr)
        if noise is not None:
            signal = np.add(signal, noise)
    return signal

# DATASET

In [0]:
class SeganDataset:
    """Build ``tf.data`` pipelines of (clean, noisy) audio pairs for SEGAN.

    Noisy audio is synthesised on the fly by mixing randomly selected noise
    files into each clean file with ``add_noise``.
    """

    def __init__(self, clean_data_dir, noises_dir, noise_conf=DEFAULT_NOISE_CONF, window_size=2 ** 14, stride=0.5):
        """
        Args:
            clean_data_dir: directory searched recursively for clean ``.wav``
                files. ``create_test`` also accepts a tab-separated text file
                whose first column holds wav paths (its first line is skipped).
            noises_dir: directory searched recursively for noise ``.wav`` files.
            noise_conf: dict with ``"snr"`` (dB values) and ``"max_noises"``;
                missing keys are filled from ``DEFAULT_NOISE_CONF``. The dict
                passed in is not modified.
            window_size: samples per training window.
            stride: hop between windows, as a fraction of ``window_size``.

        Raises:
            FileNotFoundError: if ``clean_data_dir`` or ``noises_dir`` does not exist.
        """
        for path in (clean_data_dir, noises_dir):
            # Explicit exception rather than assert (asserts vanish under python -O).
            if not os.path.exists(path):
                raise FileNotFoundError(f"path does not exist: {path}")
        self.clean_data_dir = clean_data_dir
        # NOTE: despite its name, this holds the list of noise .wav file paths.
        self.noises_dir = glob.glob(os.path.join(noises_dir, "**", "*.wav"), recursive=True)
        self.window_size = window_size
        self.stride = stride
        # Merge into a copy so neither the caller's dict nor the shared default
        # (DEFAULT_NOISE_CONF) is mutated.
        self.noise_conf = append_default_keys_dict(DEFAULT_NOISE_CONF, dict(noise_conf or {}))

    def create(self, batch_size, coeff=0.97, sample_rate=16000, shuffle=True, shuffle_buffer_size=3):
        """Return a batched, prefetched dataset of pre-emphasised windows.

        Each element is a ``(clean, noisy)`` pair of float32 tensors of shape
        ``[batch, window_size]`` (the last batch may be smaller).

        Args:
            batch_size: windows per batch.
            coeff: pre-emphasis coefficient applied to every window.
            sample_rate: rate audio is loaded/resampled at.
            shuffle: whether to shuffle windows.
            shuffle_buffer_size: size of the shuffle buffer. The default of 3
                only mixes nearby windows, which mostly come from the same file.

        Raises:
            NotADirectoryError: if ``clean_data_dir`` is not a directory.
        """
        if not os.path.isdir(self.clean_data_dir):
            raise NotADirectoryError(f"not a directory: {self.clean_data_dir}")

        def _gen_data():
            pattern = os.path.join(self.clean_data_dir, "**", "*.wav")
            for clean_wav_path in glob.iglob(pattern, recursive=True):
                clean_wav = read_raw_audio(clean_wav_path, sample_rate=sample_rate)
                clean_slices = slice_signal(clean_wav, self.window_size, self.stride)

                # Use the configured limit (was hard-coded to 3).
                noisy_wav = add_noise(clean_wav, self.noises_dir, snr_list=self.noise_conf["snr"],
                                      max_noises=self.noise_conf["max_noises"], sample_rate=sample_rate)
                noisy_slices = slice_signal(noisy_wav, self.window_size, self.stride)

                # Both signals have the same length, so the slices line up.
                for clean_slice, noisy_slice in zip(clean_slices, noisy_slices):
                    if len(clean_slice) == 0:
                        continue
                    yield preemphasis(clean_slice, coeff), preemphasis(noisy_slice, coeff)

        dataset = tf.data.Dataset.from_generator(
            _gen_data,
            output_types=(
                tf.float32,
                tf.float32
            ),
            output_shapes=(
                tf.TensorShape([self.window_size]),
                tf.TensorShape([self.window_size])
            )
        )
        if shuffle:
            dataset = dataset.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
        dataset = dataset.batch(batch_size)
        # Prefetch so input preparation overlaps with training.
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset

    def create_test(self, sample_rate=16000):
        """Return an unbatched dataset of whole ``(clean, noisy)`` utterances.

        Clean files come from ``clean_data_dir``: either a directory searched
        recursively for ``.wav`` files, or a tab-separated file whose first
        column holds wav paths (its first line is skipped as a header).
        No pre-emphasis or slicing is applied.
        """
        if os.path.isdir(self.clean_data_dir):
            def _gen_data():
                pattern = os.path.join(self.clean_data_dir, "**", "*.wav")
                for clean_wav_path in glob.iglob(pattern, recursive=True):
                    clean_wav = read_raw_audio(clean_wav_path, sample_rate=sample_rate)
                    noisy_wav = add_noise(clean_wav, self.noises_dir, snr_list=self.noise_conf["snr"],
                                          max_noises=self.noise_conf["max_noises"], sample_rate=sample_rate)
                    yield clean_wav, noisy_wav
        else:
            with open(self.clean_data_dir, "r", encoding="utf-8") as en:
                entries = en.read().splitlines()
                entries = entries[1:]

            def _gen_data():
                for clean_wav_path in entries:
                    clean_wav_path = clean_wav_path.split("\t")[0]
                    clean_wav = read_raw_audio(clean_wav_path, sample_rate=sample_rate)
                    noisy_wav = add_noise(clean_wav, self.noises_dir, snr_list=self.noise_conf["snr"],
                                          max_noises=self.noise_conf["max_noises"], sample_rate=sample_rate)
                    yield clean_wav, noisy_wav

        dataset = tf.data.Dataset.from_generator(
            _gen_data,
            output_types=(
                tf.float32,
                tf.float32
            ),
            output_shapes=(
                tf.TensorShape([None]),
                tf.TensorShape([None])
            )
        )
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset

## GENERATOR

In [0]:
#create_generator
## GENERATOR
#DownConv
class DownConv(tf.keras.layers.Layer):
    """Strided convolution that downsamples the width (time) axis.

    Wraps a ``Conv2D`` with a ``(kwidth, 1)`` kernel and stride ``(pool, 1)``,
    ``padding="same"``. Presumably applied to ``[batch, width, 1, channels]``
    tensors (see Reshape1to3).
    """

    def __init__(self, depth, kwidth=5, pool=2, name="downconv", **kwargs):
        super(DownConv, self).__init__(name=name, **kwargs)
        # Constructor arguments are kept so get_config()/from_config() can
        # rebuild the layer.
        self.depth = depth
        self.kwidth = kwidth
        self.pool = pool
        self.layer = tf.keras.layers.Conv2D(
            filters=depth,
            kernel_size=(kwidth, 1),
            strides=(pool, 1),
            padding="same",
            use_bias=True,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            bias_initializer=tf.keras.initializers.zeros
        )

    def call(self, inputs, training=False):
        return self.layer(inputs, training=training)

    def get_config(self):
        config = super(DownConv, self).get_config()
        # Serialize constructor arguments; the inner Conv2D is rebuilt by __init__.
        config.update({"depth": self.depth, "kwidth": self.kwidth, "pool": self.pool})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

#DeConv
class DeConv(tf.keras.layers.Layer):
    """Transposed convolution that upsamples the width (time) axis.

    Wraps a ``Conv2DTranspose`` with a ``(kwidth, 1)`` kernel and stride
    ``(dilation, 1)``, ``padding="same"``. Despite its name, ``dilation`` is
    the upsampling stride, not a dilation rate.
    """

    def __init__(self, depth, kwidth=5, dilation=2, name="deconv", **kwargs):
        super(DeConv, self).__init__(name=name, **kwargs)
        # Constructor arguments are kept so get_config()/from_config() can
        # rebuild the layer.
        self.depth = depth
        self.kwidth = kwidth
        self.dilation = dilation
        self.layer = tf.keras.layers.Conv2DTranspose(
            filters=depth,
            kernel_size=(kwidth, 1),
            strides=(dilation, 1),
            padding="same",
            use_bias=True,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            bias_initializer=tf.keras.initializers.zeros
        )

    def call(self, inputs, training=False):
        return self.layer(inputs, training=training)

    def get_config(self):
        config = super(DeConv, self).get_config()
        # Serialize constructor arguments; the inner layer is rebuilt by __init__.
        config.update({"depth": self.depth, "kwidth": self.kwidth, "dilation": self.dilation})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

#Reshape1to3
class Reshape1to3(tf.keras.layers.Layer):
    """Reshape ``[batch, width]`` to ``[batch, width, 1, 1]``.

    ``width`` is read from the static shape, so it must be known when the
    layer is called.
    """

    def __init__(self, name="reshape_1_to_3", **kwargs):
        super(Reshape1to3, self).__init__(trainable=False, name=name, **kwargs)

    def call(self, inputs, training=False):
        batch_size = tf.shape(inputs)[0]
        width = inputs.get_shape().as_list()[1]
        return tf.reshape(inputs, [batch_size, width, 1, 1])

    def get_config(self):
        config = super(Reshape1to3, self).get_config()
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # __init__ always passes trainable=False itself; drop the serialized
        # copy to avoid a duplicate keyword argument.
        config.pop("trainable", None)
        return cls(**config)

#Reshape3to1
class Reshape3to1(tf.keras.layers.Layer):
    """Reshape ``[batch, width, 1, 1]`` back to ``[batch, width]``.

    ``width`` is read from the static shape; the trailing dimensions are
    expected to multiply to 1.
    """

    def __init__(self, name="reshape_3_to_1", **kwargs):
        super(Reshape3to1, self).__init__(trainable=False, name=name, **kwargs)

    def call(self, inputs, training=False):
        batch_size = tf.shape(inputs)[0]
        width = inputs.get_shape().as_list()[1]
        return tf.reshape(inputs, [batch_size, width])

    def get_config(self):
        config = super(Reshape3to1, self).get_config()
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # __init__ always passes trainable=False itself; drop the serialized
        # copy to avoid a duplicate keyword argument.
        config.pop("trainable", None)
        return cls(**config)

#SeganPrelu
class SeganPrelu(tf.keras.layers.Layer):
    """Parametric ReLU with one learnable slope per channel.

    ``f(x) = relu(x) + alpha * min(x, 0)``. ``alpha`` is initialised to zero,
    so the layer starts out as a plain ReLU.
    """

    def __init__(self, name="segan_prelu", **kwargs):
        super(SeganPrelu, self).__init__(trainable=True, name=name, **kwargs)

    def build(self, input_shape):
        # One slope per channel (last axis).
        self.alpha = self.add_weight(name="alpha",
                                     shape=input_shape[-1],
                                     initializer=tf.keras.initializers.zeros,
                                     dtype=tf.float32,
                                     trainable=True)

    def call(self, x, training=False):
        pos = tf.nn.relu(x)
        # (x - |x|) / 2 == min(x, 0)
        neg = self.alpha * (x - tf.abs(x)) * .5
        return pos + neg

    def get_config(self):
        config = super(SeganPrelu, self).get_config()
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # __init__ always passes trainable=True itself; drop the serialized
        # copy to avoid a duplicate keyword argument.
        config.pop("trainable", None)
        return cls(**config)


class Z(tf.keras.layers.Layer):
    """Concatenate Gaussian noise ``z`` in front of the input on axis 3.

    ``z`` has the same shape as ``inputs``, so the size of axis 3 doubles.
    """

    def __init__(self, mean=0., stddev=1., name="segan_z", **kwargs):
        # Initialise the Keras Layer before setting attributes on it.
        super(Z, self).__init__(name=name, **kwargs)
        # Plain float (no trailing comma, which would make it a 1-tuple).
        self.mean = mean
        self.stddev = stddev

    def call(self, inputs, training=False):
        z = tf.random.normal(shape=tf.shape(inputs),
                             name="z", mean=self.mean, stddev=self.stddev)
        # tf.concat instead of constructing a new Concatenate layer per call.
        return tf.concat([z, inputs], axis=3)

    def get_config(self):
        config = super(Z, self).get_config()
        config.update({"mean": self.mean, "stddev": self.stddev})
        return config


def create_generator(g_enc_depths, window_size, kwidth=31, ratio=2):
    """Build the SEGAN generator: a convolutional encoder-decoder with skips.

    Args:
        g_enc_depths: encoder channel depths, one per DownConv layer.
        window_size: samples per input window.
        kwidth: convolution kernel width along the time axis.
        ratio: stride of every down/up-sampling convolution.

    Returns:
        A ``tf.keras.Model`` named "segan_gen" mapping a ``[batch, window_size]``
        noisy window to a ``[batch, window_size]`` output (the decoder has as
        many stride-``ratio`` layers as the encoder).
    """
    # Decoder depths mirror the encoder: reversed, without the bottleneck
    # depth, and ending in a single output channel.
    g_dec_depths = g_enc_depths.copy()
    g_dec_depths.reverse()
    g_dec_depths = g_dec_depths[1:] + [1]
    skips = []

    # input_shape = [batch_size, 16384]
    signal = tf.keras.Input(shape=(window_size,),
                            name="noisy_input", dtype=tf.float32)
    c = Reshape1to3("segan_g_reshape_input")(signal)
    # Encoder
    for layer_idx, layer_depth in enumerate(g_enc_depths):
        c = DownConv(depth=layer_depth,
                     kwidth=kwidth,
                     pool=ratio,
                     name=f"segan_g_downconv_{layer_idx}")(c)
        # Skip connections take the pre-activation output of every encoder
        # layer except the last (bottleneck) one.
        if layer_idx < len(g_enc_depths) - 1:
            skips.append(c)
        c = SeganPrelu(name=f"segan_g_downconv_prelu_{layer_idx}")(c)
    # Z: concatenates Gaussian noise of the same shape on axis 3.
    output = Z()(c)
    # Decoder
    for layer_idx, layer_depth in enumerate(g_dec_depths):
        output = DeConv(depth=layer_depth,
                        kwidth=kwidth,
                        dilation=ratio,
                        name=f"segan_g_deconv_{layer_idx}")(output)
        # NOTE(review): the final decoder layer is also followed by a PReLU,
        # so the output is not bounded to [-1, 1] (the SEGAN paper uses tanh
        # on the last layer) -- confirm this is intended.
        output = SeganPrelu(name=f"segan_g_deconv_prelu_{layer_idx}")(output)
        if layer_idx < len(g_dec_depths) - 1:
            # Mirror-order skip: decoder layer i is paired with encoder layer
            # len(g_enc_depths) - 2 - i, which has the same width.
            _skip = skips[-(layer_idx + 1)]
            output = tf.keras.layers.Concatenate(axis=3, name=f"concat_skip_{layer_idx}")([output, _skip])

    reshape_output = Reshape3to1("segan_g_reshape_output")(output)
    # output_shape = [batch_size, 16384]

    return tf.keras.Model(inputs=signal, outputs=reshape_output, name="segan_gen")


@tf.function
def generator_loss(y_true, y_pred, l1_lambda, d_fake_logit):
    """Return the generator's ``(l1_loss, adversarial_loss)`` pair.

    - l1_loss: ``l1_lambda`` times the mean absolute error between
      ``y_pred`` and ``y_true``.
    - adversarial_loss: least-squares GAN term, the mean of
      ``(d_fake_logit - 1) ** 2``.
    """
    abs_error = tf.abs(tf.subtract(y_pred, y_true))
    weighted_l1 = l1_lambda * tf.reduce_mean(abs_error)
    adversarial = tf.reduce_mean(tf.math.squared_difference(d_fake_logit, 1.))
    return weighted_l1, adversarial

## DISCRIMINATOR

In [0]:
# create_discriminator
## DISCRIMINATOR
# DownConv
class DownConv(tf.keras.layers.Layer):
    """Strided convolution that downsamples the width (time) axis.

    Wraps a ``Conv2D`` with a ``(kwidth, 1)`` kernel and stride ``(pool, 1)``,
    ``padding="same"``. Presumably applied to ``[batch, width, 1, channels]``
    tensors (see Reshape1to3).
    """

    def __init__(self, depth, kwidth=5, pool=2, name="downconv", **kwargs):
        super(DownConv, self).__init__(name=name, **kwargs)
        # Constructor arguments are kept so get_config()/from_config() can
        # rebuild the layer.
        self.depth = depth
        self.kwidth = kwidth
        self.pool = pool
        self.layer = tf.keras.layers.Conv2D(
            filters=depth,
            kernel_size=(kwidth, 1),
            strides=(pool, 1),
            padding="same",
            use_bias=True,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            bias_initializer=tf.keras.initializers.zeros
        )

    def call(self, inputs, training=False):
        return self.layer(inputs, training=training)

    def get_config(self):
        config = super(DownConv, self).get_config()
        # Serialize constructor arguments; the inner Conv2D is rebuilt by __init__.
        config.update({"depth": self.depth, "kwidth": self.kwidth, "pool": self.pool})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

# VirtualBatchNorm
class VirtualBatchNorm:
    """Normalise against reference statistics captured at construction time.

    The per-channel mean/variance of the tensor passed to the constructor are
    stored as reference statistics. Each call blends them with the moments of
    the tensor being normalised (weight ``1 / (batch_size + 1)`` on the new
    moments), then applies a learned per-channel scale (gamma) and offset
    (beta).

    NOTE(review): this is a plain Python class, not a tf.keras.layers.Layer,
    so gamma/beta are bare tf.Variable objects. Confirm they are tracked by
    the Keras model (e.g. appear in discriminator.trainable_variables) before
    relying on them being trained or checkpointed.
    """

    def __init__(self, x, name, epsilon=1e-5):
        # epsilon must be a Python float; it is added to the variance.
        assert isinstance(epsilon, float)
        self.epsilon = epsilon
        # Stored only; the variables below are named "gamma"/"beta" regardless.
        self.name = name
        # Batch size of the reference tensor, used for the blending weight.
        self.batch_size = tf.cast(tf.shape(x)[0], tf.float32)
        # One scale/offset per channel (last axis of x).
        self.gamma = tf.Variable(
            initial_value=tf.random_normal_initializer(1., 0.02)(
                shape=[x.get_shape().as_list()[-1]]),
            name="gamma", trainable=True
        )
        self.beta = tf.Variable(
            initial_value=tf.constant_initializer(0.)(
                shape=[x.get_shape().as_list()[-1]]),
            name="beta", trainable=True
        )
        # Reference statistics reduced over axes 0-2, one value per channel
        # (assumes x is 4-D, e.g. [batch, width, 1, channels] -- TODO confirm).
        mean, var = tf.nn.moments(x, axes=[0, 1, 2], keepdims=False)
        self.mean = mean
        self.variance = var



    def __call__(self, x):
        # Weighted blend of the current moments and the reference moments.
        new_coeff = 1. / (self.batch_size + 1.)
        old_coeff = 1. - new_coeff
        new_mean, new_var = tf.nn.moments(x, axes=[0, 1, 2], keepdims=False)
        new_mean = new_coeff * new_mean + old_coeff * self.mean
        new_var = new_coeff * new_var + old_coeff * self.variance
        return tf.nn.batch_normalization(x, mean=new_mean, variance=new_var,
                                         offset=self.beta, scale=self.gamma,
                                         variance_epsilon=self.epsilon)
        
        
# GaussianNoise
class GaussianNoise(tf.keras.layers.Layer):
    """Add zero-mean Gaussian noise with standard deviation ``noise_std``.

    Unlike ``tf.keras.layers.GaussianNoise``, the noise is added regardless
    of the ``training`` flag. ``noise_std`` is fixed at construction time.
    """

    def __init__(self, name, noise_std, **kwargs):
        super(GaussianNoise, self).__init__(trainable=False, name=name, **kwargs)
        self.noise_std = noise_std

    def call(self, inputs, training=False):
        noise = tf.keras.backend.random_normal(shape=tf.shape(inputs),
                                               mean=0.0, stddev=self.noise_std,
                                               dtype=tf.float32)
        return inputs + noise

    def get_config(self):
        config = super(GaussianNoise, self).get_config()
        config.update({"noise_std": self.noise_std})
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # __init__ always passes trainable=False itself; drop the serialized
        # copy to avoid a duplicate keyword argument.
        config.pop("trainable", None)
        return cls(**config)

# Reshape1to3
class Reshape1to3(tf.keras.layers.Layer):
    """Reshape ``[batch, width]`` to ``[batch, width, 1, 1]``.

    ``width`` is read from the static shape, so it must be known when the
    layer is called.
    """

    def __init__(self, name="reshape_1_to_3", **kwargs):
        super(Reshape1to3, self).__init__(trainable=False, name=name, **kwargs)

    def call(self, inputs, training=False):
        batch_size = tf.shape(inputs)[0]
        width = inputs.get_shape().as_list()[1]
        return tf.reshape(inputs, [batch_size, width, 1, 1])

    def get_config(self):
        config = super(Reshape1to3, self).get_config()
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # __init__ always passes trainable=False itself; drop the serialized
        # copy to avoid a duplicate keyword argument.
        config.pop("trainable", None)
        return cls(**config)

def create_discriminator(d_num_fmaps, window_size, kwidth=31, ratio=2, noise_std=0.):
    """Build the SEGAN discriminator.

    The (clean or generated) window and the noisy window are stacked as two
    channels, perturbed with Gaussian noise, then passed through one
    DownConv -> VirtualBatchNorm -> LeakyReLU block per entry of
    ``d_num_fmaps``. A 1x1 Conv1D and a Dense layer produce one logit per
    example.

    Args:
        d_num_fmaps: channel depth of each DownConv block.
        window_size: samples per input window.
        kwidth: convolution kernel width along the time axis.
        ratio: stride of each DownConv.
        noise_std: std of the Gaussian noise added to the stacked input.

    Returns:
        A ``tf.keras.Model`` named "segan_disc" taking a dict with keys
        "clean" and "noisy" (each ``[batch, window_size]``) and returning
        ``[batch, 1]`` logits.
    """
    clean_signal = tf.keras.Input(shape=(window_size,),
                                  name="disc_clean_input", dtype=tf.float32)
    noisy_signal = tf.keras.Input(shape=(window_size,),
                                  name="disc_noisy_input", dtype=tf.float32)

    clean_wav = Reshape1to3("segan_d_reshape_1_to_3_clean")(clean_signal)
    noisy_wav = Reshape1to3("segan_d_reshape_1_to_3_noisy")(noisy_signal)
    hi = tf.keras.layers.Concatenate(name="segan_d_concat_clean_noisy",
                                     axis=3)([clean_wav, noisy_wav])
    # after concatenation shape = [batch_size, 16384, 1, 2]

    hi = GaussianNoise(noise_std=noise_std, name="segan_d_gaussian_noise")(hi)

    for block_idx, nfmaps in enumerate(d_num_fmaps):
        hi = DownConv(depth=nfmaps, kwidth=kwidth, pool=ratio,
                      name=f"segan_d_downconv_{block_idx}")(hi)
        # The reference statistics come from the same tensor being normalised.
        hi = VirtualBatchNorm(hi, name=f"segan_d_vbn_{block_idx}")(hi)
        hi = tf.keras.layers.LeakyReLU(alpha=0.3, name=f"segan_d_leakyrelu_{block_idx}")(hi)

    # Drop the singleton axis: [batch, width, 1, C] -> [batch, width, C].
    hi = tf.squeeze(hi, axis=2)
    # 1x1 convolution down to a single channel.
    hi = tf.keras.layers.Conv1D(filters=1, kernel_size=1,
                                strides=1, padding="same",
                                kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
                                name="segan_d_conv1d")(hi)
    # [batch, width, 1] -> [batch, width]; Dense needs width to be static.
    hi = tf.squeeze(hi, axis=-1)
    hi = tf.keras.layers.Dense(1, name="segan_d_fully_connected")(hi)

    # output_shape = [batch_size, 1]
    return tf.keras.Model(inputs={
        "clean": clean_signal,
        "noisy": noisy_signal,
    }, outputs=hi, name="segan_disc")

# discriminator_loss
@tf.function
def discriminator_loss(d_real_logit, d_fake_logit):
    """Least-squares GAN discriminator loss.

    Pushes logits for real (clean) pairs towards 1 and logits for generated
    pairs towards 0: ``mean((D_real - 1) ** 2) + mean((D_fake - 0) ** 2)``.
    """
    loss_on_real = tf.reduce_mean(tf.math.squared_difference(d_real_logit, 1.))
    loss_on_fake = tf.reduce_mean(tf.math.squared_difference(d_fake_logit, 0.))
    total = loss_on_real + loss_on_fake
    return total

# SEGAN CONCISE

In [0]:
## SEGAN CONCISE
class SEGAN:
    def __init__(self, config_dic, training=True):
        """Build the SEGAN generator and, for training, everything else.

        Args:
            config_dic: configuration dict (see the config cell). Its path
                entries are normalised in place by ``get_segan_config``.
            training: when True, also build the discriminator, RMSprop
                optimizers, summary writer, step counter and checkpoint
                manager, and print both model summaries.
        """
        self.g_enc_depths = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024]
        self.d_num_fmaps = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024]

        self.configs = get_segan_config(config_dic)

        self.kwidth = self.configs["kwidth"]
        self.ratio = self.configs["ratio"]
        self.noise_std = self.configs["noise_std"]
        self.l1_lambda = self.configs["l1_lambda"]
        self.coeff = self.configs["pre_emph"]
        self.window_size = self.configs["window_size"]
        self.stride = self.configs["stride"]
        self.deactivated_noise = False

        self.generator = create_generator(g_enc_depths=self.g_enc_depths,
                                          window_size=self.window_size,
                                          kwidth=self.kwidth,
                                          ratio=self.ratio)

        if training:
            # Pass the configured noise_std so config["noise_std"] actually
            # reaches the discriminator's GaussianNoise layer (it otherwise
            # falls back to create_discriminator's default of 0).
            self.discriminator = create_discriminator(d_num_fmaps=self.d_num_fmaps,
                                                      window_size=self.window_size,
                                                      kwidth=self.kwidth,
                                                      ratio=self.ratio,
                                                      noise_std=self.noise_std)

            self.generator_optimizer = tf.keras.optimizers.RMSprop(
                self.configs["g_learning_rate"])
            self.discriminator_optimizer = tf.keras.optimizers.RMSprop(
                self.configs["d_learning_rate"])

            self.writer = tf.summary.create_file_writer(self.configs["log_dir"])

            # Global training step, saved with the checkpoint.
            self.steps = tf.Variable(initial_value=0, trainable=False, shape=(), dtype=tf.int64)

            self.checkpoint = tf.train.Checkpoint(
                generator=self.generator,
                discriminator=self.discriminator,
                generator_optimizer=self.generator_optimizer,
                discriminator_optimizer=self.discriminator_optimizer,
                steps=self.steps
            )
            self.ckpt_manager = tf.train.CheckpointManager(
                self.checkpoint, self.configs["checkpoint_dir"], max_to_keep=5)

            # summary() prints the table itself and returns None.
            self.generator.summary()
            self.discriminator.summary()

    def train(self, export_dir=None):
        """Train generator and discriminator for ``configs["num_epochs"]`` epochs.

        Resumes from the latest checkpoint in ``checkpoint_dir`` if one exists,
        saves a checkpoint after every epoch and writes loss summaries every
        500 steps.

        Args:
            export_dir: if given, ``self.save(export_dir)`` is called after the
                last epoch.

        Raises:
            ValueError: if an epoch yields no training batches (e.g. no usable
                ``.wav`` files under ``clean_train_data_dir``).
        """
        train_dataset = SeganDataset(clean_data_dir=self.configs["clean_train_data_dir"],
                                     noises_dir=self.configs["noises_dir"],
                                     noise_conf=self.configs["noise_conf"],
                                     window_size=self.window_size, stride=self.stride)

        tf_train_dataset = train_dataset.create(self.configs["batch_size"], coeff=self.coeff,
                                                sample_rate=self.configs["sample_rate"])

        epochs = self.configs["num_epochs"]

        initial_epoch = 0

        if self.ckpt_manager.latest_checkpoint:
            # One checkpoint is saved per epoch (below), so the numeric suffix
            # of "ckpt-N" is taken as the number of completed epochs.
            initial_epoch = int(self.ckpt_manager.latest_checkpoint.split('-')[-1])
            # restoring the latest checkpoint in checkpoint_path
            self.checkpoint.restore(self.ckpt_manager.latest_checkpoint)

        @tf.function
        def train_step(clean_wavs, noisy_wavs):
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                g_clean_wavs = self.generator(noisy_wavs, training=True)

                d_real_logit = self.discriminator({
                    "clean": clean_wavs,
                    "noisy": noisy_wavs,
                }, training=True)
                d_fake_logit = self.discriminator({
                    "clean": g_clean_wavs,
                    "noisy": noisy_wavs,
                }, training=True)

                _gen_l1_loss, _gen_adv_loss = generator_loss(y_true=clean_wavs,
                                                             y_pred=g_clean_wavs,
                                                             l1_lambda=self.l1_lambda,
                                                             d_fake_logit=d_fake_logit)

                _disc_loss = discriminator_loss(d_real_logit, d_fake_logit)

                _gen_loss = _gen_l1_loss + _gen_adv_loss

            gradients_of_generator = gen_tape.gradient(_gen_loss, self.generator.trainable_variables)
            gradients_of_discriminator = disc_tape.gradient(_disc_loss, self.discriminator.trainable_variables)

            self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
            self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator,
                                                             self.discriminator.trainable_variables))
            return _gen_l1_loss, _gen_adv_loss, _disc_loss

        for epoch in range(initial_epoch, epochs):
            start = time.time()
            g_l1_loss = []
            g_adv_loss = []
            d_loss = []

            if epoch > self.configs["denoise_epoch"] and not self.deactivated_noise:
                # NOTE(review): this only updates the Python attribute; the
                # discriminator's GaussianNoise layer keeps the noise_std it
                # was built with, so the decay does not reach the model.
                self.noise_std = self.configs["noise_decay"] * self.noise_std
                if self.noise_std < self.configs["noise_std_lbound"]:
                    self.noise_std = 0.
                    self.deactivated_noise = True

            # self.steps holds (completed epochs) * (batches per epoch), so the
            # per-epoch total is known only once at least one epoch is done.
            steps_per_epoch = self.steps.numpy() // epoch if epoch > 0 else "?"
            num_batches = 0
            for step, (clean_wav, noisy_wav) in tf_train_dataset.enumerate(start=0):
                substart = time.time()
                gen_l1_loss, gen_adv_loss, disc_loss = train_step(clean_wav, noisy_wav)
                num_batches += 1
                g_l1_loss.append(gen_l1_loss)
                g_adv_loss.append(gen_adv_loss)
                d_loss.append(disc_loss)
                sys.stdout.write("\033[K")
                print(f"\rEpoch: {epoch + 1}/{epochs}, step: {int(step)}/{steps_per_epoch}, "
                      f"duration: {(time.time() - substart):.2f}s, "
                      f"gen_l1_loss = {float(gen_l1_loss):.4f}, gen_adv_loss = {float(gen_adv_loss):.4f}, "
                      f"disc_loss = {float(disc_loss):.4f}", end="")
                if self.writer and step % 500 == 0:
                    with self.writer.as_default():
                        tf.summary.scalar("g_l1_loss", tf.reduce_mean(g_l1_loss), step=(self.steps + step))
                        tf.summary.scalar("g_adv_loss", tf.reduce_mean(g_adv_loss), step=(self.steps + step))
                        tf.summary.scalar("d_loss", tf.reduce_mean(d_loss), step=(self.steps + step))

            if num_batches == 0:
                # Without this check the step update below would fail with an
                # unbound `step` variable.
                raise ValueError("training dataset produced no batches; check clean_train_data_dir "
                                 f"({self.configs['clean_train_data_dir']})")

            self.steps.assign((epoch + 1) * num_batches)

            self.ckpt_manager.save()
            print(f"\nSaved checkpoint at epoch {epoch + 1}", flush=True)
            print(f"Time for epoch {epoch + 1} is {time.time() - start} secs")

        if export_dir:
            self.save(export_dir)

    def test(self, export_dir: str, output_file_dir: str):
        """Evaluate the exported generator on the clean test set.

        Loads generator weights from ``export_dir``, enhances each noisy test
        utterance with ``self.generate`` and scores both the enhanced and the
        unprocessed noisy signal against the clean reference with the
        ``semetrics`` PESQ and composite (CSIG, CBAK, COVL, SSNR) measures.
        Running averages for the generated signal are printed; the final
        averages for both signals are written to ``output_file_dir``.

        Args:
            export_dir: Weights path passed to ``self.load_model``.
            output_file_dir: Path of the text file the averages are written to.

        Raises:
            Exception: If the generator weights cannot be loaded.
            ValueError: If the test dataset yields no utterances.
        """
        test_dataset = SeganDataset(clean_data_dir=self.configs["clean_test_data_dir"],
                                    noises_dir=self.configs["noises_dir"],
                                    noise_conf=self.configs["noise_conf"],
                                    window_size=self.window_size, stride=1)

        tf_test_dataset = test_dataset.create_test(sample_rate=self.configs["sample_rate"])

        msg = self.load_model(export_dir)
        if msg:
            raise Exception(msg)

        start = time.time()

        pesq_noisy = csig_noisy = cbak_noisy = covl_noisy = ssnr_noisy = 0
        pesq_gen = csig_gen = cbak_gen = covl_gen = ssnr_gen = 0

        try:
            from semetrics.main import pesq_mos as pesq, composite
            import soundfile as sf
        except ImportError as e:
            # The metrics package is optional: report it and skip evaluation.
            print(e)
            print("Please install https://github.com/usimarit/semetrics")
            return

        sr = self.configs["sample_rate"]
        # The metric functions take file paths, so every triple is written to disk first.
        clean_path = "/tmp/clean_signal.wav"
        gen_path = "/tmp/gen_signal.wav"
        noisy_path = "/tmp/noisy_signal.wav"

        def save_to_tmp(clean_signal, gen_signal, noisy_signal):
            sf.write(clean_path, clean_signal, sr)
            sf.write(gen_path, gen_signal, sr)
            sf.write(noisy_path, noisy_signal, sr)

        # Number of utterances scored so far (enumerate starts at 1). Initialised
        # so that an empty dataset is detected instead of hitting an unbound name.
        step = 0.
        for step, [clean_wav, noisy_wav] in tf_test_dataset.enumerate(start=1):
            step = float(step)
            gen_wav = self.generate(noisy_wav)
            clean_wav = clean_wav.numpy()
            noisy_wav = noisy_wav.numpy()
            save_to_tmp(clean_wav, gen_wav, noisy_wav)

            pesq_gen += pesq(clean_path, gen_path)
            # Baseline score of the unprocessed input; it belongs in pesq_noisy,
            # not pesq_gen.
            pesq_noisy += pesq(clean_path, noisy_path)

            _csig_gen, _cbak_gen, _covl_gen, _ssnr_gen = composite(clean_path, gen_path)
            csig_gen += _csig_gen
            cbak_gen += _cbak_gen
            covl_gen += _covl_gen
            ssnr_gen += _ssnr_gen

            _csig_noisy, _cbak_noisy, _covl_noisy, _ssnr_noisy = composite(clean_path, noisy_path)
            csig_noisy += _csig_noisy
            cbak_noisy += _cbak_noisy
            covl_noisy += _covl_noisy
            ssnr_noisy += _ssnr_noisy

            print(f"\rPESQ_GEN = {(pesq_gen / step):.2f}, CSIG_GEN = {(csig_gen / step):.2f}, "
                  f"CBAK_GEN = {(cbak_gen / step):.2f}, COVL_GEN = {(covl_gen / step):.2f}, "
                  f"SSNR_GEN = {(ssnr_gen / step):.2f}", end="")

        if step == 0:
            raise ValueError("Test dataset is empty; no metrics to report")

        with open(output_file_dir, "w", encoding="utf-8") as fo:
            fo.write(f"PESQ_GEN = {(pesq_gen / step):.2f}, CSIG_GEN = {(csig_gen / step):.2f}, "
                     f"CBAK_GEN = {(cbak_gen / step):.2f}, COVL_GEN = {(covl_gen / step):.2f}, "
                     f"SSNR_GEN = {(ssnr_gen / step):.2f}\n")
            fo.write(f"PESQ_NOISY = {(pesq_noisy / step):.2f}, CSIG_NOISY = {(csig_noisy / step):.2f}, "
                     f"CBAK_NOISY = {(cbak_noisy / step):.2f}, COVL_NOISY = {(covl_noisy / step):.2f}, "
                     f"SSNR_NOISY = {(ssnr_noisy / step):.2f}\n")

        print(f"\nTime for testing is {time.time() - start} secs")

    def save_from_checkpoint(self, export_dir):
        """Restore the newest training checkpoint, then export the generator weights.

        Args:
            export_dir: Destination passed on to ``self.save``.

        Raises:
            ValueError: If the checkpoint manager has no checkpoint yet.
        """
        latest = self.ckpt_manager.latest_checkpoint
        if not latest:
            raise ValueError("Model is not trained")

        # Bring the checkpointed state back before exporting it.
        self.checkpoint.restore(latest)
        self.save(export_dir)

    def save(self, export_dir):
        """Write the generator's weights (only the generator) to ``export_dir``."""
        generator = self.generator
        generator.save_weights(export_dir)

    def load_model(self, export_dir):
        """Load generator weights from ``export_dir``.

        Failures are reported, not raised.

        Returns:
            ``None`` on success, otherwise a ``"Model is not trained: ..."``
            message describing the exception.
        """
        error_message = None
        try:
            self.generator.load_weights(export_dir)
        except Exception as exc:
            error_message = f"Model is not trained: {exc}"
        return error_message

    def generate(self, signal):
        """Enhance a noisy waveform with the trained generator.

        The signal is pre-emphasised with ``configs["pre_emph"]``, cut into
        ``window_size`` slices by ``slice_signal`` (``stride=1``), passed
        through the generator in inference mode (``training=False``), merged
        back with ``merge_slices`` and finally de-emphasised.

        Args:
            signal: Noisy waveform accepted by ``preemphasis``.

        Returns:
            The result of ``deemphasis`` applied to the merged generator output.
        """
        signal = preemphasis(signal, self.configs["pre_emph"])
        slices = slice_signal(signal, self.window_size, stride=1)

        # Run eagerly. Declaring a @tf.function inside this method creates a new
        # function object per call, so every call pays a full graph trace that is
        # never reused (and repeated calls, e.g. from test(), trigger TF's
        # retracing warnings).
        batch = tf.reshape(tf.convert_to_tensor(slices), [-1, self.window_size])
        g_wavs = self.generator(batch, training=False)
        merged = merge_slices(g_wavs)

        # np.asarray accepts both an EagerTensor and a NumPy array.
        return deemphasis(np.asarray(merged), self.configs["pre_emph"])

    def convert_to_tflite(self, export_file, output_file_path):
        """Convert the trained generator to a TFLite flatbuffer.

        Nothing is done when ``output_file_path`` already exists. Otherwise the
        generator weights are loaded from ``export_file``, converted with
        default optimizations (allowing select TF ops), and written to
        ``output_file_path``; missing parent directories are created.

        Raises:
            Exception: If the generator weights cannot be loaded, so untrained
                weights are never exported.
        """
        if os.path.exists(output_file_path):
            return

        msg = self.load_model(export_file)
        if msg:
            # Same failure handling as test(): do not convert untrained weights.
            raise Exception(msg)

        converter = tf.lite.TFLiteConverter.from_keras_model(self.generator)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.experimental_new_converter = True
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
        tflite_model = converter.convert()

        output_path = pathlib.Path(output_file_path)
        output_path.parent.mkdir(exist_ok=True, parents=True)
        output_path.write_bytes(tflite_model)

    def load_interpreter(self, export_dir):
        """Replace ``self.generator`` with a TFLite interpreter read from ``export_dir``.

        The interpreter's tensors are allocated immediately. Failures are
        reported, not raised.

        Returns:
            ``None`` on success, otherwise a ``"Model is not trained: ..."``
            message describing the exception.
        """
        try:
            self.generator = tf.lite.Interpreter(model_path=export_dir)
            self.generator.allocate_tensors()
        except Exception as exc:
            outcome = f"Model is not trained: {exc}"
        else:
            outcome = None
        return outcome

    def generate_interpreter(self, signal):
        """Enhance a noisy waveform with the TFLite interpreter from ``load_interpreter``.

        Same pipeline as ``generate`` (pre-emphasis, slicing, model, merge,
        de-emphasis), but ``self.generator`` must be a ``tf.lite.Interpreter``.

        Args:
            signal: Noisy waveform accepted by ``preemphasis``.

        Returns:
            The result of ``deemphasis`` applied to the merged interpreter output.
        """
        signal = preemphasis(signal, self.configs["pre_emph"])
        slices = slice_signal(signal, self.window_size, stride=1)

        input_details = self.generator.get_input_details()[0]
        input_index = input_details["index"]
        output_index = self.generator.get_output_details()[0]["index"]

        # set_tensor expects a NumPy array whose dtype matches the model input.
        batch = np.asarray(tf.reshape(slices, [-1, self.window_size]),
                           dtype=input_details["dtype"])

        # A Keras model with a `None` batch dimension is converted with a
        # concrete input shape (presumably batch 1 - TODO confirm for this
        # model), so the input must be resized to the number of slices and the
        # tensors re-allocated; otherwise set_tensor rejects any signal that
        # spans more than one window with a dimension mismatch.
        self.generator.resize_tensor_input(input_index, batch.shape)
        self.generator.allocate_tensors()

        self.generator.set_tensor(input_index, batch)
        self.generator.invoke()

        pred = self.generator.get_tensor(output_index)
        merged = merge_slices(pred)
        # np.asarray accepts both an EagerTensor and a NumPy array.
        return deemphasis(np.asarray(merged), self.configs["pre_emph"])


In [25]:
# Build the SEGAN wrapper in training mode from `config_dic`.
segan = SEGAN(
    config_dic,
    training=True,
)

Model: "segan_gen"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
noisy_input (InputLayer)        [(None, 16384)]      0                                            
__________________________________________________________________________________________________
segan_g_reshape_input (Reshape1 (None, 16384, 1, 1)  0           noisy_input[0][0]                
__________________________________________________________________________________________________
segan_g_downconv_0 (DownConv)   (None, 8192, 1, 16)  512         segan_g_reshape_input[0][0]      
__________________________________________________________________________________________________
segan_g_downconv_prelu_0 (Segan (None, 8192, 1, 16)  16          segan_g_downconv_0[0][0]         
__________________________________________________________________________________________

# TRAINING SEGAN

In [0]:
# Restore the newest checkpoint and export the generator weights to Drive.
export_dir = './drive/My Drive/trained/latest_model'
segan.save_from_checkpoint(export_dir=export_dir)

In [0]:
# load_model() reports failures by returning a message instead of raising;
# surface it so later cells do not silently run with untrained weights.
load_error = segan.load_model(export_dir)
if load_error:
    raise RuntimeError(load_error)

In [28]:
# Upload the training logs to tensorboard.dev (the logs become publicly visible).
# NOTE(review): in the recorded run below, this command stopped at its interactive
# "Continue? (yes/NO)" prompt and exited with a traceback, apparently because a
# `!` shell command in the notebook cannot answer the prompt. Consider viewing
# the logs with the `%tensorboard --logdir ...` magic instead, and confirm that
# tensorboard.dev is still available before relying on this cell.
!tensorboard dev upload --logdir /content/drive/My\ Drive/trained/logs/

2020-06-12 08:08:16.020897: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1

***** TensorBoard Uploader *****

This will upload your TensorBoard logs to https://tensorboard.dev/ from
the following directory:

/content/drive/My Drive/trained/logs/

This TensorBoard will be visible to everyone. Do not upload sensitive
data.

Your use of this service is subject to Google's Terms of Service
<https://policies.google.com/terms> and Privacy Policy
<https://policies.google.com/privacy>, and TensorBoard.dev's Terms of Service
<https://tensorboard.dev/policy/terms/>.

This notice will not be shown again while you are logged into the uploader.
To log out, run `tensorboard dev auth revoke`.

Continue? (yes/NO) Traceback (most recent call last):
  File "/usr/local/bin/tensorboard", line 8, in <module>
    sys.exit(run_main())
  File "/usr/local/lib/python3.6/dist-packages/tensorboard/main.py", line 75, in run_main
    app.run(ten

# Live Speech Testing

In [0]:
# all imports
from IPython.display import Javascript
from google.colab import output
from base64 import b64decode
from io import BytesIO
!pip -q install pydub
from pydub import AudioSegment

# JavaScript injected into the Colab frontend. `record(ms)` requests microphone
# access, records for `ms` milliseconds with MediaRecorder and resolves to the
# recording as a base64 data URL (FileReader.readAsDataURL). The microphone
# tracks are stopped afterwards so the browser releases the device.
RECORD = """
const sleep  = time => new Promise(resolve => setTimeout(resolve, time))
const b2text = blob => new Promise(resolve => {
  const reader = new FileReader()
  reader.onloadend = () => resolve(reader.result)
  reader.readAsDataURL(blob)
})
var record = time => new Promise(async resolve => {
  const stream = await navigator.mediaDevices.getUserMedia({ audio: true })
  const recorder = new MediaRecorder(stream)
  const chunks = []
  recorder.ondataavailable = e => chunks.push(e.data)
  recorder.onstop = async () => {
    // Release the microphone; otherwise it stays open after recording.
    stream.getTracks().forEach(track => track.stop())
    const blob = new Blob(chunks)
    const text = await b2text(blob)
    resolve(text)
  }
  recorder.start()
  await sleep(time)
  recorder.stop()
})
"""

def record_and_processed(sec=3):
  """Record `sec` seconds from the browser microphone and save it as `audio.wav`.

  Runs the `RECORD` JavaScript in the Colab frontend, decodes the returned
  base64 data URL and parses the audio with pydub. Browser MediaRecorder
  output is not WAV by default (browser-dependent, commonly WebM or Ogg), so
  the decoded audio is exported as a real WAV file rather than writing the raw
  container bytes under a `.wav` name.

  Args:
    sec: Recording length in seconds.

  Returns:
    The recording as a pydub `AudioSegment`.
  """
  display(Javascript(RECORD))
  s = output.eval_js('record(%d)' % (sec*1000))
  b = b64decode(s.split(',')[1])
  audio = AudioSegment.from_file(BytesIO(b))
  # export() returns the file object it opened for the path; close it so the
  # WAV is fully flushed before other cells read it.
  audio.export('audio.wav', format='wav').close()
  return audio

In [36]:
# Record five seconds of speech from the microphone (written to audio.wav).
record_and_processed(5)

<IPython.core.display.Javascript object>

In [0]:
from IPython.display import Audio
from scipy.io.wavfile import write
noisy_wav_path = './audio.wav'
sr = 16000
# Reuse `sr` instead of repeating the literal so the two cannot drift apart.
noisy_wav = read_raw_audio(noisy_wav_path, sample_rate=sr)

process = segan.generate(noisy_wav)
# Save as 32-bit float WAV: a float64 array would be written as 64-bit float
# WAV, which many players and tools cannot open. `process` itself is unchanged.
write("enhancement_audio.wav", sr, np.asarray(process, dtype=np.float32))

In [38]:
# Play the enhanced waveform inline at the processing sample rate.
Audio(data=process, rate=sr)

# CONCLUSION
- An end-to-end speech enhancement method, implemented here with a Vietnamese voice dataset.
- The fully convolutional encoder-decoder structure makes it run faster than recurrent (RNN-based) solutions.
- The model configuration still has to be fine-tuned on this dataset.
- Other GAN methods could also be tried on this dataset for this problem.

# REFERENCES
1. **GAN paper**
https://arxiv.org/abs/1406.2661

2. **SEGAN paper**
https://arxiv.org/abs/1703.09452

3. **TensorFlow 1 code**
https://github.com/santi-pdp/segan

4. **PyTorch code**
https://github.com/santi-pdp/segan_pytorch

5. **TensorFlow 2 code**
https://github.com/usimarit/TiramisuASR
