In [1]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import *
import keras.backend as K

BLOCK_LEN = 320

BLOCK_SHIFT = 160

FFT_LENGTH = tf.constant([512])


class DeepFilter(Layer):
    def __init__(self, m=3, n=3, **kwargs):
        super(DeepFilter, self).__init__(**kwargs)

        self.m = m

        self.n = n

    def call(self, inputs, *args, **kwargs):
        x_mic, x_mask = inputs

        x_mic = tf.expand_dims(x_mic, axis=-1)

        x_mask = Reshape((-1, x_mask.shape[2], x_mask.shape[3] // 3, 3))(x_mask)

        x_mask = Lambda(

            lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], self.m, self.n]))(x_mask)

        x_mic = tf.pad(x_mic, paddings=tf.constant([[0, 0], [self.m - 1, 0], [1, 1], [0, 0]]))

        x_mic = tf.image.extract_patches(x_mic, sizes=[1, self.m, self.n, 1], strides=[1, 1, 1, 1],

                                         rates=[1, 1, 1, 1], padding='VALID')

        x_mic = Lambda(

            lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], self.m, self.n]))(x_mic)

        res = x_mic * x_mask

        res = tf.reduce_sum(res, axis=-1)

        res = tf.reduce_sum(res, axis=-1)

        return res


def apply_mask(mic_input, pred_mask):
    pred_mask_real = pred_mask[:, :, :, :9]

    pred_mask_imag = pred_mask[:, :, :, 9:]

    mic_real = DeepFilter()([mic_input[:, :, :, 0], pred_mask_real])

    mic_imag = DeepFilter()([mic_input[:, :, :, 1], pred_mask_imag])

    mic_real = tf.expand_dims(mic_real, axis=-1)

    mic_imag = tf.expand_dims(mic_imag, axis=-1)

    return keras.layers.Concatenate()([mic_real, mic_imag])


def stft_layer(x):
    frames = tf.signal.frame(x[0], x[1], x[2])

    complex_stft = tf.signal.rfft(frames, fft_length=FFT_LENGTH)

    complex_stft = complex_stft[:, :, 1:]

    real = tf.math.real(complex_stft)

    imag = tf.math.imag(complex_stft)

    return real, imag


def istft_layer(x):
    istft = tf.signal.irfft(x[0], fft_length=FFT_LENGTH)

    waveform = tf.signal.overlap_and_add(istft, x[1])

    return waveform


def map_layer(mic_output):
    real = mic_output[..., 0]

    imag = mic_output[..., 1]

    output_spec = tf.complex(real, imag)

    return output_spec




def conv_gru(inputs=Input((None,)), block_length=None, block_shift=None):
    mic_real, mic_imag = Lambda(stft_layer)([inputs, block_length, block_shift])

    mic_imag = Lambda(

        lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], 1]))(mic_imag)

    mic_real = Lambda(
        lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], 1]))(mic_real)

    mic_input = keras.layers.Concatenate()([mic_real, mic_imag])


    enc_1 = tf.pad(mic_input, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))

    enc_1 = Conv2D(filters=4, kernel_size=(3, 3), strides=(1,1),  padding='valid')(enc_1)
    enc_1 = keras.layers.BatchNormalization()(enc_1)
    enc_1 = keras.layers.ReLU()(enc_1)

    skip_1 = Conv2D(filters=4, kernel_size=(1,1),strides=(1,1),  padding='valid')(enc_1) #999,256,4

    enc_2 = tf.pad(enc_1, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))

    enc_2 = Conv2D(filters=8, kernel_size=(3, 3), strides=(1,2),  padding='valid')(enc_2)
    enc_2 = keras.layers.BatchNormalization()(enc_2)
    enc_2 = keras.layers.ReLU()(enc_2)

    skip_2 = Conv2D(filters=8, kernel_size=(1,1),strides=(1,1),  padding='valid')(enc_2)

    enc_3 = tf.pad(enc_2, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))

    enc_3 = Conv2D(filters=16, kernel_size=(3, 3), strides=(1,2),  padding='valid')(enc_3)
    enc_3 = keras.layers.BatchNormalization()(enc_3)
    enc_3 = keras.layers.ReLU()(enc_3) #T,64, 16

    skip_3 = Conv2D(filters=16, kernel_size=(1,1),strides=(1,1),  padding='valid')(enc_3)


    grus = []
    for i in range(0,16):
        gru = GRU(units = 16, return_sequences=True)(enc_3[..., i])
        grus.append(gru)


    concat = tf.concat(grus, axis=-1)
    concat = keras.layers.BatchNormalization()(concat)

    gru= GRU(units= 64, return_sequences=True)(concat)
    gru = keras.layers.BatchNormalization()(gru)

    den = Dense(units=1024)(gru)
    den = keras.layers.BatchNormalization()(den)
    den = keras.layers.ReLU()(den)

    den = Reshape((-1,64,16))(den)


    dec_2 = Conv2DTranspose(filters=8, kernel_size=(3, 3), strides=(1,2), padding='valid')(den+skip_3)
    dec_2 = keras.layers.BatchNormalization()(dec_2)
    dec_2 = keras.layers.ReLU()(dec_2)

    dec_1 = Conv2DTranspose(filters=4, kernel_size=(3, 3), strides=(1,2),  padding='valid')(dec_2[:,:-2,:-1,:]+skip_2)
    dec_1 = keras.layers.BatchNormalization()(dec_1)
    dec_1 = keras.layers.ReLU()(dec_1)

    dec_conv = tf.pad(dec_1[:, :-2,:-1,:], paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))
    skip_1 = tf.pad(skip_1, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))

    dec_conv = Conv2D(filters=18, kernel_size=(3,3), strides=(1,1),padding='valid')(dec_conv+skip_1)


    out = apply_mask(mic_input, dec_conv)

    out = tf.pad(out, paddings=tf.constant([[0, 0], [0, 0], [1, 0], [0, 0]]))

    output_spec = Lambda(map_layer)(out)

    output_wave_1 = Lambda(istft_layer, name='complex_output')([output_spec, block_shift])

    output_wave_2 = Lambda(istft_layer, name='raw_output')([output_spec, block_shift])

    cc_model = keras.models.Model(inputs=inputs, outputs=([output_wave_1, output_wave_2]))

    cc_model.summary()
    cc_model.save('conv_gru.h5')
    return cc_model


if __name__ == "__main__":
    print(conv_gru(block_length=320, block_shift=160))

    # X = tf.random.normal((4, 256, 256, 18))

    # Y = tf.random.normal((4, 256, 256, 2))

    # print(apply_mask(Y, X))

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None)]               0         []                            
                                                                                                  
 lambda (Lambda)             ((None, None, 256),          0         ['input_1[0][0]']             
                              (None, None, 256))                                                  
                                                                                                  
 lambda_2 (Lambda)           (None, None, 256, 1)         0         ['lambda[0][0]']              
                                                                                                  
 lambda_1 (Lambda)           (None, None, 256, 1)         0         ['lambda[0][1]']          

  saving_api.save_model(


<keras.src.engine.functional.Functional object at 0x7907cce47790>


In [2]:

def frame_snr_cost_test(s_true, s_estimate):
    frame_length=320
    frame_shift=160
    s_estimate = s_estimate[:, :int(K.shape(s_true)[-1])]
    sig_power = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True)
    snr = tf.zeros((int(K.shape(s_true)[0]), 1))
    count = tf.constant(0, dtype=snr.dtype)
    for i in range(0, int(K.shape(s_true)[-1])-frame_length+1, frame_shift):
        frame_true = s_true[:, i:i+frame_length]
        frame_estimate = s_estimate[:, i:i+frame_length]

        frame_snr = compute_frame_snr(frame_true, frame_estimate, sig_power)
        snr = snr + frame_snr
        count = count + 1

    mean_snr = snr/count

    return -mean_snr

def compute_frame_snr_dummy(f_true, f_estimate, sig_power):

    #tf.print(s_estimate.shape, s_true.shape)
    f_estimate = f_estimate[:, :int(K.shape(f_true)[-1])]

    num = int(tf.clip_by_value(tf.reduce_mean(tf.math.square(f_true), axis=-1, keepdims=True), clip_value_min = 1e-5, clip_value_max = 1e+4) * 2**32) / \
          int((tf.reduce_mean(tf.math.square(f_true - f_estimate), axis=-1, keepdims=True) + 1e-7) * 2**32)

    num = tf.math.log(num)

    denom = tf.math.log(tf.constant(10, dtype=num.dtype))
    #tf.print("P_sig", sig_power)
    #tf.print("P_segment", tf.reduce_mean(tf.math.square(f_true), axis=-1, keepdims=True))
    loss = -10 * (num / denom)
    loss = tf.cast(loss, dtype= tf.float32)
    #tf.print("Before",loss)
    loss = loss * (tf.reduce_mean(tf.math.square(f_true), axis=-1, keepdims=True) / tf.math.maximum(sig_power,1e-6))
    loss = tf.cast(loss, dtype= tf.float32)
    #tf.print("After",loss)
    return loss

In [3]:
import os
import librosa
import soundfile as sf
import tensorflow as tf
from tensorflow.keras.models import load_model
import keras.backend as K
# from dia_train_28 import DeepFilter
tf.keras.utils.get_custom_objects()['frame_snr_cost_test'] = frame_snr_cost_test
tf.keras.utils.get_custom_objects()['DeepFilter'] = DeepFilter
tf.keras.utils.get_custom_objects()['FFT_LENGTH'] = tf.constant([512])
tf.keras.utils.get_custom_objects()['K'] = K

loaded_model = load_model("/content/hifi_model_trained.h5")
print(loaded_model.summary)

  function = cls._parse_function_from_config(


<bound method Model.summary of <keras.src.engine.functional.Functional object at 0x79085ac40f10>>


In [4]:
layer_weights = loaded_model.get_layer('conv2d_1').get_weights()

print(layer_weights)

[array([[[[ 0.43021345,  0.18396576, -0.6778008 , -0.6516555 ],
         [ 0.7108551 , -0.1714309 , -0.77727175, -0.5446728 ],
         [-0.21081854,  0.35873237,  0.00466848,  0.01241677],
         [ 0.29632774,  0.13928531, -0.10188177, -0.654173  ]]]],
      dtype=float32), array([-0.00198196, -0.00199867,  0.00196853,  0.00199211], dtype=float32)]


In [7]:
import numpy as np
conv_layer = tf.keras.layers.Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), padding='valid', input_shape=(20,20,2))
data = np.random.rand(1,5,5,4)

data


array([[[[0.08735558, 0.02516036, 0.46663856, 0.97814474],
         [0.03029301, 0.09175156, 0.93132864, 0.46998856],
         [0.99934153, 0.94569276, 0.71094907, 0.8327293 ],
         [0.35891874, 0.64470921, 0.58847226, 0.49062343],
         [0.85021439, 0.31821496, 0.36390768, 0.80509833]],

        [[0.22988014, 0.20926591, 0.42101532, 0.68564857],
         [0.96302982, 0.80052675, 0.31212618, 0.81744366],
         [0.65513446, 0.36521154, 0.55334657, 0.61804096],
         [0.53257528, 0.30959641, 0.31400056, 0.6382605 ],
         [0.95577655, 0.86354521, 0.76594888, 0.73937023]],

        [[0.83042559, 0.44858606, 0.36167685, 0.55149401],
         [0.52312661, 0.4793265 , 0.7249799 , 0.14329753],
         [0.42841868, 0.90731582, 0.3450547 , 0.31774168],
         [0.43397553, 0.37322366, 0.06931127, 0.12144709],
         [0.94163643, 0.85786455, 0.2100153 , 0.77737175]],

        [[0.40395135, 0.85768075, 0.45353044, 0.76452005],
         [0.11921593, 0.25197029, 0.10307968, 0.39

In [9]:
model = tf.keras.Sequential()

# Add Conv2D layer to the model
model.add(tf.keras.layers.Conv2D(filters=4, kernel_size=(1, 1), strides=(1, 1), padding='valid', activation=None))

# Now the model has a single Conv2D layer
model.predict(data)
# Set weights and biases for the Conv2D layer
model.layers[0].set_weights(loaded_model.get_layer('conv2d_1').get_weights())
out = model.predict(data)
out



array([[[[ 0.24496031,  0.31339806, -0.1742742 , -0.70271957],
         [ 0.01920188,  0.38740543, -0.13341537, -0.3636128 ],
         [ 1.1970783 ,  0.39075238, -1.4919671 , -1.7002488 ],
         [ 0.632049  ,  0.2329477 , -0.7896594 , -0.8967005 ],
         [ 0.7518509 ,  0.34254363, -0.90197283, -1.2475328 ]],

        [[ 0.3600922 ,  0.25094935, -0.38439047, -0.70509714],
         [ 1.1578143 ,  0.26375848, -1.3548261 , -1.5924706 ],
         [ 0.60596514,  0.34050244, -0.7863345 , -1.0212857 ],
         [ 0.5701544 ,  0.24444506, -0.66321313, -0.92732614],
         [ 1.0806812 ,  0.40354723, -1.3888195 , -1.56536   ]],

        [[ 0.76133275,  0.28042987, -0.96406657, -1.1397735 ],
         [ 0.45342967,  0.29210034, -0.73638886, -0.6847218 ],
         [ 0.8487113 ,  0.08931278, -1.0244062 , -0.97495306],
         [ 0.47140417,  0.05563589, -0.59432626, -0.562682  ],
         [ 1.1990218 ,  0.20778146, -1.3812869 , -1.5848138 ]],

        [[ 0.91242594,  0.19446403, -1.0142543 , 

In [11]:
# Print the weights
weights, biases = model.layers[0].get_weights()
print("Weights:")
print(weights)
print("Biases:")
print(biases)

Weights:
[[[[ 0.43021345  0.18396576 -0.6778008  -0.6516555 ]
   [ 0.7108551  -0.1714309  -0.77727175 -0.5446728 ]
   [-0.21081854  0.35873237  0.00466848  0.01241677]
   [ 0.29632774  0.13928531 -0.10188177 -0.654173  ]]]]
Biases:
[-0.00198196 -0.00199867  0.00196853  0.00199211]


In [10]:
out.shape

(1, 5, 5, 4)