In [1]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import *
import keras.backend as K

BLOCK_LEN = 320

BLOCK_SHIFT = 160

FFT_LENGTH = tf.constant([512])


class DeepFilter(Layer):
    def __init__(self, m=3, n=3, **kwargs):
        super(DeepFilter, self).__init__(**kwargs)

        self.m = m

        self.n = n

    def call(self, inputs, *args, **kwargs):
        x_mic, x_mask = inputs

        x_mic = tf.expand_dims(x_mic, axis=-1)

        x_mask = Reshape((-1, x_mask.shape[2], x_mask.shape[3] // 3, 3))(x_mask)

        x_mask = Lambda(

            lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], self.m, self.n]))(x_mask)

        x_mic = tf.pad(x_mic, paddings=tf.constant([[0, 0], [self.m - 1, 0], [1, 1], [0, 0]]))

        x_mic = tf.image.extract_patches(x_mic, sizes=[1, self.m, self.n, 1], strides=[1, 1, 1, 1],

                                         rates=[1, 1, 1, 1], padding='VALID')

        x_mic = Lambda(

            lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], self.m, self.n]))(x_mic)

        res = x_mic * x_mask

        res = tf.reduce_sum(res, axis=-1)

        res = tf.reduce_sum(res, axis=-1)

        return res


def apply_mask(mic_input, pred_mask):
    pred_mask_real = pred_mask[:, :, :, :9]

    pred_mask_imag = pred_mask[:, :, :, 9:]

    mic_real = DeepFilter()([mic_input[:, :, :, 0], pred_mask_real])

    mic_imag = DeepFilter()([mic_input[:, :, :, 1], pred_mask_imag])

    mic_real = tf.expand_dims(mic_real, axis=-1)

    mic_imag = tf.expand_dims(mic_imag, axis=-1)

    return keras.layers.Concatenate()([mic_real, mic_imag])


def stft_layer(x):
    frames = tf.signal.frame(x[0], x[1], x[2])

    complex_stft = tf.signal.rfft(frames, fft_length=FFT_LENGTH)

    complex_stft = complex_stft[:, :, 1:]

    real = tf.math.real(complex_stft)

    imag = tf.math.imag(complex_stft)

    return real, imag


def istft_layer(x):
    istft = tf.signal.irfft(x[0], fft_length=FFT_LENGTH)

    waveform = tf.signal.overlap_and_add(istft, x[1])

    return waveform


def map_layer(mic_output):
    real = mic_output[..., 0]

    imag = mic_output[..., 1]

    output_spec = tf.complex(real, imag)

    return output_spec




def conv_gru(inputs=Input((None,)), block_length=None, block_shift=None):
    mic_real, mic_imag = Lambda(stft_layer)([inputs, block_length, block_shift])

    mic_imag = Lambda(

        lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], 1]))(mic_imag)

    mic_real = Lambda(
        lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], 1]))(mic_real)

    mic_input = keras.layers.Concatenate()([mic_real, mic_imag])


    enc_1 = tf.pad(mic_input, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))

    enc_1 = Conv2D(filters=4, kernel_size=(3, 3), strides=(1,1),  padding='valid')(enc_1)
    enc_1 = keras.layers.BatchNormalization()(enc_1)
    enc_1 = keras.layers.ReLU()(enc_1)

    skip_1 = Conv2D(filters=4, kernel_size=(1,1),strides=(1,1),  padding='valid')(enc_1) #999,256,4

    enc_2 = tf.pad(enc_1, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))

    enc_2 = Conv2D(filters=8, kernel_size=(3, 3), strides=(1,2),  padding='valid')(enc_2)
    enc_2 = keras.layers.BatchNormalization()(enc_2)
    enc_2 = keras.layers.ReLU()(enc_2)

    skip_2 = Conv2D(filters=8, kernel_size=(1,1),strides=(1,1),  padding='valid')(enc_2)

    enc_3 = tf.pad(enc_2, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))

    enc_3 = Conv2D(filters=16, kernel_size=(3, 3), strides=(1,2),  padding='valid')(enc_3)
    enc_3 = keras.layers.BatchNormalization()(enc_3)
    enc_3 = keras.layers.ReLU()(enc_3) #T,64, 16

    skip_3 = Conv2D(filters=16, kernel_size=(1,1),strides=(1,1),  padding='valid')(enc_3)


    grus = []
    for i in range(0,16):
        gru = GRU(units = 16, return_sequences=True)(enc_3[..., i])
        grus.append(gru)


    concat = tf.concat(grus, axis=-1)
    concat = keras.layers.BatchNormalization()(concat)

    gru= GRU(units= 64, return_sequences=True)(concat)
    gru = keras.layers.BatchNormalization()(gru)

    den = Dense(units=1024)(gru)
    den = keras.layers.BatchNormalization()(den)
    den = keras.layers.ReLU()(den)

    den = Reshape((-1,64,16))(den)


    dec_2 = Conv2DTranspose(filters=8, kernel_size=(3, 3), strides=(1,2), padding='valid')(den+skip_3)
    dec_2 = keras.layers.BatchNormalization()(dec_2)
    dec_2 = keras.layers.ReLU()(dec_2)

    dec_1 = Conv2DTranspose(filters=4, kernel_size=(3, 3), strides=(1,2),  padding='valid')(dec_2[:,:-2,:-1,:]+skip_2)
    dec_1 = keras.layers.BatchNormalization()(dec_1)
    dec_1 = keras.layers.ReLU()(dec_1)

    dec_conv = tf.pad(dec_1[:, :-2,:-1,:], paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))
    skip_1 = tf.pad(skip_1, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))

    dec_conv = Conv2D(filters=18, kernel_size=(3,3), strides=(1,1),padding='valid')(dec_conv+skip_1)


    out = apply_mask(mic_input, dec_conv)

    out = tf.pad(out, paddings=tf.constant([[0, 0], [0, 0], [1, 0], [0, 0]]))

    output_spec = Lambda(map_layer)(out)

    output_wave_1 = Lambda(istft_layer, name='complex_output')([output_spec, block_shift])

    output_wave_2 = Lambda(istft_layer, name='raw_output')([output_spec, block_shift])

    cc_model = keras.models.Model(inputs=inputs, outputs=([output_wave_1, output_wave_2]))

    cc_model.summary()
    cc_model.save('conv_gru.h5')
    return cc_model


if __name__ == "__main__":
    print(conv_gru(block_length=320, block_shift=160))

    # X = tf.random.normal((4, 256, 256, 18))

    # Y = tf.random.normal((4, 256, 256, 2))

    # print(apply_mask(Y, X))

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None)]               0         []                            
                                                                                                  
 lambda (Lambda)             ((None, None, 256),          0         ['input_1[0][0]']             
                              (None, None, 256))                                                  
                                                                                                  
 lambda_2 (Lambda)           (None, None, 256, 1)         0         ['lambda[0][0]']              
                                                                                                  
 lambda_1 (Lambda)           (None, None, 256, 1)         0         ['lambda[0][1]']          

  saving_api.save_model(


<keras.src.engine.functional.Functional object at 0x792ca7a58610>


In [2]:

def frame_snr_cost_test(s_true, s_estimate):
    frame_length=320
    frame_shift=160
    s_estimate = s_estimate[:, :int(K.shape(s_true)[-1])]
    sig_power = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True)
    snr = tf.zeros((int(K.shape(s_true)[0]), 1))
    count = tf.constant(0, dtype=snr.dtype)
    for i in range(0, int(K.shape(s_true)[-1])-frame_length+1, frame_shift):
        frame_true = s_true[:, i:i+frame_length]
        frame_estimate = s_estimate[:, i:i+frame_length]

        frame_snr = compute_frame_snr(frame_true, frame_estimate, sig_power)
        snr = snr + frame_snr
        count = count + 1

    mean_snr = snr/count

    return -mean_snr

def compute_frame_snr_dummy(f_true, f_estimate, sig_power):

    #tf.print(s_estimate.shape, s_true.shape)
    f_estimate = f_estimate[:, :int(K.shape(f_true)[-1])]

    num = int(tf.clip_by_value(tf.reduce_mean(tf.math.square(f_true), axis=-1, keepdims=True), clip_value_min = 1e-5, clip_value_max = 1e+4) * 2**32) / \
          int((tf.reduce_mean(tf.math.square(f_true - f_estimate), axis=-1, keepdims=True) + 1e-7) * 2**32)

    num = tf.math.log(num)

    denom = tf.math.log(tf.constant(10, dtype=num.dtype))
    #tf.print("P_sig", sig_power)
    #tf.print("P_segment", tf.reduce_mean(tf.math.square(f_true), axis=-1, keepdims=True))
    loss = -10 * (num / denom)
    loss = tf.cast(loss, dtype= tf.float32)
    #tf.print("Before",loss)
    loss = loss * (tf.reduce_mean(tf.math.square(f_true), axis=-1, keepdims=True) / tf.math.maximum(sig_power,1e-6))
    loss = tf.cast(loss, dtype= tf.float32)
    #tf.print("After",loss)
    return loss

In [3]:
import os
import librosa
import soundfile as sf
import tensorflow as tf
from tensorflow.keras.models import load_model
import keras.backend as K
# from dia_train_28 import DeepFilter
tf.keras.utils.get_custom_objects()['frame_snr_cost_test'] = frame_snr_cost_test
tf.keras.utils.get_custom_objects()['DeepFilter'] = DeepFilter
tf.keras.utils.get_custom_objects()['FFT_LENGTH'] = tf.constant([512])
tf.keras.utils.get_custom_objects()['K'] = K

loaded_model = load_model("/content/drive/MyDrive/model_weights/hifi_model_trained.h5")
print(loaded_model.summary)

  function = cls._parse_function_from_config(


<bound method Model.summary of <keras.src.engine.functional.Functional object at 0x792d42959ae0>>


In [4]:
layer_weights = loaded_model.get_layer('gru').get_weights()

print(layer_weights[0].shape)
print(layer_weights[0])

print(layer_weights[1].shape)
print(layer_weights[1])

print(layer_weights[2].shape)
print(layer_weights[2])


(64, 48)
[[-0.1434389  -0.20858648  0.20759891 ...  0.01773929  0.04577521
  -0.07254796]
 [-0.10140295 -0.06223341 -0.00511233 ... -0.1802425  -0.13721494
  -0.14422353]
 [ 0.1845568  -0.20590062 -0.1926724  ... -0.00490277 -0.08100107
   0.15629466]
 ...
 [-0.01744551  0.2267371  -0.19957733 ... -0.10138661  0.07186574
   0.06788415]
 [-0.01311083  0.05433643  0.07957821 ... -0.02163168  0.20765662
  -0.15125848]
 [ 0.1754362  -0.12720212 -0.12106782 ...  0.00886258 -0.18973505
   0.17253236]]
(16, 48)
[[-0.40212885  0.0066552   0.28615785 -0.06380619 -0.05344085  0.02687167
   0.1000255  -0.05071123  0.12062956 -0.013603    0.03999829  0.18187569
   0.14805487  0.00278836 -0.03221505  0.12844019  0.15025136  0.16567063
   0.01864385 -0.26977506 -0.15789211  0.09496179  0.14653824  0.22047243
  -0.22723769  0.1691048  -0.22099207 -0.04839944  0.08146601 -0.27499515
   0.02382363  0.11896657 -0.2273812  -0.19150305  0.09037432 -0.03927749
   0.01974286 -0.17716427 -0.04934331  0.07091

In [17]:
import numpy as np
data = np.random.rand(1,1,64)

data

array([[[0.47637218, 0.47232874, 0.10006439, 0.65960998, 0.18673663,
         0.87878256, 0.66238111, 0.76815995, 0.51716311, 0.70981192,
         0.36708417, 0.68924672, 0.3011454 , 0.77961593, 0.25662221,
         0.04709273, 0.61420269, 0.59986929, 0.15164045, 0.0170851 ,
         0.95081352, 0.0607439 , 0.6265779 , 0.30833318, 0.33380683,
         0.18424628, 0.90719183, 0.52202839, 0.62564159, 0.21200452,
         0.81683269, 0.54351541, 0.48566905, 0.88709811, 0.76747826,
         0.36386286, 0.19028927, 0.25362354, 0.52663637, 0.87948085,
         0.97126507, 0.59546867, 0.89165427, 0.57032343, 0.11570343,
         0.12485452, 0.71210728, 0.48673068, 0.96335789, 0.87646222,
         0.68542523, 0.13255757, 0.72568856, 0.91785024, 0.27703435,
         0.89551934, 0.35075227, 0.89713249, 0.9727304 , 0.22751713,
         0.97903469, 0.37426406, 0.03749561, 0.31970003]]])

In [46]:
import tensorflow as tf
import numpy as np

# Create a Sequential model
model = tf.keras.Sequential()

# Add a GRU layer to the model with matching units and input shape
model.add(tf.keras.layers.GRU(units=16, activation = "tanh", return_sequences=True, input_shape=(None, 64)))

# Set the weights for the GRU layer
model.layers[0].set_weights(loaded_model.get_layer('gru').get_weights())

# Now, you can predict using the model
out = model.predict(data)
print(out)

[[[ 0.3709669   0.32922623  0.652393   -0.20656379 -0.41090733
   -0.4643127   0.29308087  0.19394399 -0.04536588 -0.10882095
    0.24199013  0.30431995 -0.32257396  0.16878651 -0.41683164
   -0.3145412 ]]]
