In [None]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import *
import keras.backend as K

# Frame length; same value as the `block_length` this notebook passes to
# `conv_gru` (presumably measured in samples — TODO confirm).
BLOCK_LEN = 320

# Frame hop; same value as the `block_shift` this notebook passes to
# `conv_gru` (presumably measured in samples — TODO confirm).
BLOCK_SHIFT = 160

# FFT size handed as `fft_length` to `tf.signal.rfft` in `stft_layer` and to
# `tf.signal.irfft` in `istft_layer`; stored as a 1-element tensor.
FFT_LENGTH = tf.constant([512])


class DeepFilter(Layer):
    """Apply a per-bin ``m x n`` filter to one spectrogram plane.

    For every (frame, bin) position the layer gathers an ``m x n`` patch of
    the input plane, multiplies it element-wise with the ``m * n`` predicted
    coefficients for that position and sums the products.

    The patch covers the current frame plus the ``m - 1`` preceding frames
    (the time axis is padded only at the front) and ``n`` neighbouring bins
    around the current bin (the frequency axis is padded on both sides).

    Args:
        m: Patch size along axis 1 (frames).
        n: Patch size along axis 2 (bins).

    Call args:
        inputs: ``[x_mic, x_mask]`` where ``x_mic`` is a rank-3 plane
            ``(batch, frames, bins)`` and ``x_mask`` is rank-4 with
            ``m * n`` coefficients on the last axis.

    Returns:
        A rank-3 tensor ``(batch, frames, bins)`` with the filtered plane.
    """

    def __init__(self, m=3, n=3, **kwargs):
        super(DeepFilter, self).__init__(**kwargs)
        self.m = m
        self.n = n

    def get_config(self):
        """Serialize ``m`` and ``n`` so a saved model restores the same filter size."""
        config = super(DeepFilter, self).get_config()
        config.update({"m": self.m, "n": self.n})
        return config

    def call(self, inputs, *args, **kwargs):
        x_mic, x_mask = inputs

        # Plain TF ops are used instead of instantiating Reshape/Lambda layers
        # on every call; the target shape now follows self.m / self.n rather
        # than a hard-coded 3.
        mask_shape = tf.shape(x_mask)
        x_mask = tf.reshape(
            x_mask, [mask_shape[0], mask_shape[1], mask_shape[2], self.m, self.n])

        # extract_patches expects a channel axis.
        x_mic = tf.expand_dims(x_mic, axis=-1)

        # Front-only padding on the frame axis; the bin axis is padded so that
        # a 'VALID' n-wide window yields one patch per original bin
        # (1 + 1 for the default n=3, as before).
        bins_before = (self.n - 1) // 2
        bins_after = self.n - 1 - bins_before
        x_mic = tf.pad(
            x_mic,
            paddings=[[0, 0], [self.m - 1, 0], [bins_before, bins_after], [0, 0]])

        patches = tf.image.extract_patches(
            x_mic, sizes=[1, self.m, self.n, 1], strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1], padding='VALID')
        patch_shape = tf.shape(patches)
        patches = tf.reshape(
            patches, [patch_shape[0], patch_shape[1], patch_shape[2], self.m, self.n])

        # Weighted sum over the m x n neighbourhood.
        return tf.reduce_sum(patches * x_mask, axis=[-2, -1])


def apply_mask(mic_input, pred_mask):
    """Filter both planes of ``mic_input`` with the predicted coefficients.

    The last axis of ``pred_mask`` is split at index 9: the first 9 channels
    filter plane 0 of ``mic_input`` and the remaining channels filter plane 1,
    each through its own default ``DeepFilter``. The two filtered planes are
    returned concatenated on a new trailing axis.
    """
    coeffs_real, coeffs_imag = pred_mask[:, :, :, :9], pred_mask[:, :, :, 9:]

    filtered_planes = [
        DeepFilter()([mic_input[:, :, :, plane], coeffs])
        for plane, coeffs in enumerate((coeffs_real, coeffs_imag))
    ]

    # Give each plane a trailing axis so they can be joined as channels.
    with_channel_axis = [tf.expand_dims(plane, axis=-1) for plane in filtered_planes]

    return keras.layers.Concatenate()(with_channel_axis)


def stft_layer(x):
    """Frame a signal and return the real and imaginary parts of its rFFT.

    ``x`` is indexed as ``x[0]`` = signal, ``x[1]`` = frame length and
    ``x[2]`` = frame step. Frames are transformed with
    ``fft_length=FFT_LENGTH`` and the first bin on the last axis is dropped
    before splitting into real and imaginary parts.
    """
    signal, frame_length, frame_step = x[0], x[1], x[2]

    spectrum = tf.signal.rfft(
        tf.signal.frame(signal, frame_length, frame_step),
        fft_length=FFT_LENGTH,
    )
    # Discard bin 0 of the last axis.
    spectrum = spectrum[:, :, 1:]

    return tf.math.real(spectrum), tf.math.imag(spectrum)


def istft_layer(x):
    """Invert a complex spectrogram to a waveform by irFFT and overlap-add.

    ``x[0]`` is the complex spectrogram and ``x[1]`` the frame step used for
    overlap-add; the inverse FFT uses ``fft_length=FFT_LENGTH``.
    """
    spectrum, frame_step = x[0], x[1]
    time_frames = tf.signal.irfft(spectrum, fft_length=FFT_LENGTH)
    return tf.signal.overlap_and_add(time_frames, frame_step)


def map_layer(mic_output):
    """Combine channel 0 (real) and channel 1 (imaginary) of the last axis into a complex tensor."""
    return tf.complex(mic_output[..., 0], mic_output[..., 1])




def _causal_pad(x):
    """Pad 2 entries in front of axis 1 and 1 entry on each side of axis 2."""
    return tf.pad(x, paddings=tf.constant([[0, 0], [2, 0], [1, 1], [0, 0]]))


def _trim_transpose_overhang(x):
    """Drop the last 2 entries of axis 1 and the last entry of axis 2."""
    return x[:, :-2, :-1, :]


def _encoder_stage(x, filters, freq_stride):
    """Padded 3x3 conv + BatchNorm + ReLU, plus a 1x1 conv that forms the skip branch."""
    x = _causal_pad(x)
    x = Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, freq_stride), padding='valid')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    skip = Conv2D(filters=filters, kernel_size=(1, 1), strides=(1, 1), padding='valid')(x)
    return x, skip


def _decoder_stage(x, filters):
    """3x3 transposed conv (stride 2 on axis 2) + BatchNorm + ReLU."""
    x = Conv2DTranspose(filters=filters, kernel_size=(3, 3), strides=(1, 2), padding='valid')(x)
    x = keras.layers.BatchNormalization()(x)
    return keras.layers.ReLU()(x)


def conv_gru(inputs=None, block_length=None, block_shift=None):
    """Build, summarize, save and return the conv/GRU deep-filtering model.

    Pipeline: STFT -> three conv encoder stages (each with a 1x1 skip branch)
    -> one GRU per bottleneck channel -> joint GRU -> dense -> two transposed
    conv decoder stages -> 18-channel coefficient map -> ``apply_mask`` ->
    inverse STFT.

    Args:
        inputs: Keras input tensor holding the waveform. When ``None`` a fresh
            ``Input((None,))`` is created for this call (previously a single
            ``Input`` was created once at definition time and shared).
        block_length: Frame length for the STFT; defaults to ``BLOCK_LEN``.
        block_shift: Frame step for STFT and overlap-add; defaults to
            ``BLOCK_SHIFT``.

    Returns:
        A ``keras.models.Model`` with two outputs named ``complex_output`` and
        ``raw_output``; both are produced by ``istft_layer`` from the same
        spectrum.

    Side effects:
        Prints the model summary and writes the model to ``conv_gru.h5`` in
        the current working directory.
    """
    if inputs is None:
        inputs = Input((None,))
    if block_length is None:
        block_length = BLOCK_LEN
    if block_shift is None:
        block_shift = BLOCK_SHIFT

    # Sizes at the bottleneck: the Dense layer emits bins * channels units that
    # are reshaped back to (frames, bins, channels).
    bottleneck_bins = 64
    bottleneck_channels = 16

    # --- Analysis: STFT, then stack real/imag as two channels -----------------
    mic_real, mic_imag = Lambda(stft_layer)([inputs, block_length, block_shift])

    # Kept as literal lambdas built on K.reshape (and created imag-first) so the
    # auto-generated Lambda layers match the original graph.
    mic_imag = Lambda(
        lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], 1]))(mic_imag)
    mic_real = Lambda(
        lambda x: K.reshape(x, [K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], 1]))(mic_real)

    mic_input = keras.layers.Concatenate()([mic_real, mic_imag])

    # --- Encoder: axis 2 is halved by each stride-2 stage ---------------------
    enc_1, skip_1 = _encoder_stage(mic_input, filters=4, freq_stride=1)
    enc_2, skip_2 = _encoder_stage(enc_1, filters=8, freq_stride=2)
    enc_3, skip_3 = _encoder_stage(enc_2, filters=bottleneck_channels, freq_stride=2)

    # --- Recurrent bottleneck: one GRU per encoder channel, then a joint GRU --
    grus = [
        GRU(units=16, return_sequences=True)(enc_3[..., channel])
        for channel in range(bottleneck_channels)
    ]

    concat = tf.concat(grus, axis=-1)
    concat = keras.layers.BatchNormalization()(concat)

    gru = GRU(units=64, return_sequences=True)(concat)
    gru = keras.layers.BatchNormalization()(gru)

    den = Dense(units=bottleneck_bins * bottleneck_channels)(gru)
    den = keras.layers.BatchNormalization()(den)
    den = keras.layers.ReLU()(den)
    den = Reshape((-1, bottleneck_bins, bottleneck_channels))(den)

    # --- Decoder: each stage adds the matching skip branch, then upsamples ----
    dec_2 = _decoder_stage(den + skip_3, filters=8)
    dec_1 = _decoder_stage(_trim_transpose_overhang(dec_2) + skip_2, filters=4)

    dec_conv = _causal_pad(_trim_transpose_overhang(dec_1))
    skip_1 = _causal_pad(skip_1)

    # 18 output channels: apply_mask uses the first 9 for the real plane and
    # the rest for the imaginary plane.
    dec_conv = Conv2D(filters=18, kernel_size=(3, 3), strides=(1, 1), padding='valid')(dec_conv + skip_1)

    # --- Filtering and synthesis ----------------------------------------------
    out = apply_mask(mic_input, dec_conv)

    # stft_layer dropped bin 0; put a zero bin back in front of axis 2 so the
    # inverse FFT sees the full set of bins.
    out = tf.pad(out, paddings=tf.constant([[0, 0], [0, 0], [1, 0], [0, 0]]))

    output_spec = Lambda(map_layer)(out)

    output_wave_1 = Lambda(istft_layer, name='complex_output')([output_spec, block_shift])
    output_wave_2 = Lambda(istft_layer, name='raw_output')([output_spec, block_shift])

    cc_model = keras.models.Model(inputs=inputs, outputs=([output_wave_1, output_wave_2]))

    cc_model.summary()
    cc_model.save('conv_gru.h5')
    return cc_model


if __name__ == "__main__":
    # Build the model using the notebook-level framing constants instead of
    # repeating their values as literals.
    print(conv_gru(block_length=BLOCK_LEN, block_shift=BLOCK_SHIFT))

    # Optional smoke test of the filtering stage on random tensors
    # (18 coefficient channels, 2 spectrogram planes):

    # X = tf.random.normal((4, 256, 256, 18))

    # Y = tf.random.normal((4, 256, 256, 2))

    # print(apply_mask(Y, X))

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None)]               0         []                            
                                                                                                  
 lambda (Lambda)             ((None, None, 256),          0         ['input_1[0][0]']             
                              (None, None, 256))                                                  
                                                                                                  
 lambda_2 (Lambda)           (None, None, 256, 1)         0         ['lambda[0][0]']              
                                                                                                  
 lambda_1 (Lambda)           (None, None, 256, 1)         0         ['lambda[0][1]']          

  saving_api.save_model(


<keras.src.engine.functional.Functional object at 0x7da405aa7c40>


In [None]:
def frame_snr_cost_test(s_true, s_estimate):
    """Return the negated mean of per-frame SNR values over overlapping frames.

    ``s_estimate`` is cropped to the length of ``s_true``. Frames of 320
    entries are taken every 160 entries along the last axis, each frame pair
    is scored by ``compute_frame_snr`` together with the mean power of the
    whole ``s_true`` signal, and the scores are averaged.

    NOTE(review): ``compute_frame_snr`` is not defined in the notebook cells
    visible here (only ``compute_frame_snr_dummy`` is) — it must be provided
    elsewhere before this function is called.
    NOTE(review): shapes are converted with ``int(...)``, so this presumably
    only works on concrete (eager) tensors — confirm before using in training.
    """
    frame_length = 320
    frame_shift = 160

    n_samples = int(K.shape(s_true)[-1])
    s_estimate = s_estimate[:, :n_samples]

    # Mean power of the full reference signal, one value per batch row.
    sig_power = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True)

    snr = tf.zeros((int(K.shape(s_true)[0]), 1))
    count = tf.constant(0, dtype=snr.dtype)

    start = 0
    while start + frame_length <= n_samples:
        stop = start + frame_length
        snr = snr + compute_frame_snr(s_true[:, start:stop], s_estimate[:, start:stop], sig_power)
        count = count + 1
        start += frame_shift

    return -(snr / count)

def compute_frame_snr_dummy(f_true, f_estimate, sig_power):
    """Score one frame with an integer-quantized, power-weighted log ratio.

    Frame power (clipped to [1e-5, 1e+4]) and error power (+1e-7) are scaled
    by 2**32 and truncated to Python ints before being divided. The result is
    ``-10 * log10(ratio)`` multiplied by the frame power relative to
    ``sig_power`` (floored at 1e-6), cast to float32.

    NOTE(review): the ``int(...)`` conversions presumably require concrete
    single-element tensors — confirm the intended batch size.
    """
    f_estimate = f_estimate[:, :int(K.shape(f_true)[-1])]

    frame_power = tf.reduce_mean(tf.math.square(f_true), axis=-1, keepdims=True)

    # Quantize both powers to integers at a 2**32 scale before dividing.
    scale = 2**32
    signal_q = int(tf.clip_by_value(frame_power, clip_value_min=1e-5, clip_value_max=1e+4) * scale)
    noise_q = int((tf.reduce_mean(tf.math.square(f_true - f_estimate), axis=-1, keepdims=True) + 1e-7) * scale)
    ratio = signal_q / noise_q

    # log10(ratio) via natural logs.
    log_ratio = tf.math.log(ratio)
    log_ten = tf.math.log(tf.constant(10, dtype=log_ratio.dtype))

    loss = tf.cast(-10 * (log_ratio / log_ten), dtype=tf.float32)

    # Weight by how much of the overall signal power this frame carries.
    weight = frame_power / tf.math.maximum(sig_power, 1e-6)
    loss = tf.cast(loss * weight, dtype=tf.float32)
    return loss

In [None]:
import os
import librosa
import soundfile as sf
import tensorflow as tf
from tensorflow.keras.models import load_model
import keras.backend as K
# from dia_train_28 import DeepFilter
# Register the names the saved model needs when it is rebuilt: the loss
# function and the custom layer, plus FFT_LENGTH and K (presumably referenced
# by the model's Lambda functions — confirm).
tf.keras.utils.get_custom_objects().update({
    'frame_snr_cost_test': frame_snr_cost_test,
    'DeepFilter': DeepFilter,
    'FFT_LENGTH': tf.constant([512]),
    'K': K,
})

loaded_model = load_model("/content/hifi_model_trained.h5")

# Call summary(); printing the attribute only shows the bound-method repr.
loaded_model.summary()

  function = cls._parse_function_from_config(


<bound method Model.summary of <keras.src.engine.functional.Functional object at 0x7da48830a410>>


In [None]:
# Inspect the trained parameters of the 'conv2d_transpose_1' layer:
# entry 0 is printed in full with its shape, entry 1 only by shape.
transpose_layer = loaded_model.get_layer('conv2d_transpose_1')
layer_weights = transpose_layer.get_weights()
first_array, second_array = layer_weights[0], layer_weights[1]

print(first_array)
print(first_array.shape)
print(second_array.shape)

[[[[ 0.20603593 -0.1053673   0.21307464  0.23260458  0.17352068
     0.22625953 -0.17755628 -0.11798199]
   [ 0.03249808 -0.09212871  0.03292048 -0.22919315  0.08419845
    -0.1842408   0.04824622 -0.16967453]
   [ 0.05117333  0.04646664  0.12970708  0.05191975 -0.11390295
    -0.04532446 -0.06852844 -0.0972741 ]
   [-0.1414966  -0.13771692  0.1442204  -0.22704923  0.2057147
    -0.12258978 -0.20247671 -0.14216329]]

  [[-0.09865525  0.2343484   0.229276    0.15289222  0.11777714
    -0.15227972 -0.04537116 -0.21296959]
   [ 0.20959625 -0.06644318 -0.01334101 -0.01173794  0.15321589
     0.07079036 -0.12906411 -0.11078808]
   [-0.08003378  0.15076679  0.11265761  0.01760438  0.05972962
    -0.16053289 -0.11137882 -0.10106172]
   [-0.02457727 -0.219209   -0.04814198  0.01746732  0.23087211
    -0.0159227   0.1132485   0.19268897]]

  [[-0.01951299 -0.04634837 -0.13448939  0.14898497 -0.00260332
    -0.23094498 -0.1155812   0.09015334]
   [-0.02473945  0.14935379  0.11190612 -0.01385557 

In [None]:
import numpy as np
# Random probe input of shape (1, 5, 5, 8) with values in [0, 1).
# A fixed seed keeps the displayed values and downstream outputs
# reproducible across kernel restarts.
data = np.random.default_rng(seed=0).random((1, 5, 5, 8))

data

array([[[[0.47082111, 0.56454155, 0.47411139, 0.39789071, 0.44853204,
          0.73068485, 0.86399994, 0.07422767],
         [0.84931383, 0.85477148, 0.3568433 , 0.81971868, 0.2452446 ,
          0.17581495, 0.23532371, 0.82130231],
         [0.10966626, 0.67651353, 0.9491342 , 0.29720435, 0.24681756,
          0.41709466, 0.57499046, 0.23735115],
         [0.70311151, 0.14250128, 0.21326755, 0.90256225, 0.37859084,
          0.69157972, 0.33527991, 0.22003314],
         [0.320075  , 0.6328385 , 0.24901225, 0.65323447, 0.08193617,
          0.87434319, 0.20142911, 0.84109871]],

        [[0.43992733, 0.90050966, 0.71342332, 0.35606619, 0.42981462,
          0.80004913, 0.49894221, 0.88943446],
         [0.48347193, 0.11624301, 0.5034592 , 0.60121954, 0.93783206,
          0.90958671, 0.43012235, 0.7001503 ],
         [0.26764639, 0.09322448, 0.04715708, 0.34360132, 0.30328439,
          0.78465257, 0.37492467, 0.39985707],
         [0.99590235, 0.0529896 , 0.54476707, 0.08893046, 0.31

In [None]:
import tensorflow as tf

# Stand-alone probe: a Sequential model holding a single Conv2DTranspose layer.
model = tf.keras.Sequential()

# NOTE(review): `conv_gru` builds its transposed convolutions with
# strides=(1, 2); this probe uses (1, 1) — confirm that is intended.
transpose_probe = tf.keras.layers.Conv2DTranspose(
    filters=4,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding='valid',
    activation=None,
)
model.add(transpose_probe)

# Run once on the probe input so the layer's variables exist before they are
# overwritten (result discarded).
model.predict(data)

# Copy the trained kernel and bias from the loaded model, then run again.
trained_weights = loaded_model.get_layer('conv2d_transpose_1').get_weights()
model.layers[0].set_weights(trained_weights)
out = model.predict(data)
out



array([[[[ 0.31210956, -0.18001121, -0.01816782, -0.34914988],
         [ 0.43724567, -0.31607807, -0.04629875, -0.45520526],
         [ 0.06708038, -0.05458531, -0.29018027, -0.08902754],
         [ 0.8451784 , -0.32676837, -0.4320821 , -0.5229645 ],
         [ 0.05057544, -0.13383114, -0.6079838 , -0.40270492],
         [-0.13219623, -0.16147466, -0.40564448,  0.23582646],
         [-0.12130856, -0.09886925, -0.4468798 ,  0.16481067]],

        [[ 0.5216667 , -0.210163  ,  0.1774085 , -0.61449194],
         [ 0.96904325, -0.59908915,  0.41428813, -0.27566007],
         [ 0.5087993 ,  0.5904898 , -0.6637124 , -0.1496342 ],
         [ 0.37597913, -0.438268  ,  0.38345644, -0.06146175],
         [ 0.32808155,  0.04362009, -0.1572225 ,  0.16611546],
         [ 0.6260822 , -0.18594189, -0.13143562, -0.4311794 ],
         [-0.0018307 ,  0.04982366, -0.21289918,  0.0722415 ]],

        [[ 0.92146814, -0.09619468,  0.3218786 ,  0.0243309 ],
         [ 0.57857966, -0.92200637,  0.40681306, -0

In [None]:
# Shape check: a 3x3 'valid' transposed convolution with stride 1 grows each
# 5-long spatial axis to 5 + 3 - 1 = 7, with 4 output filters.
out.shape

(1, 7, 7, 4)