How to use tensorflow to conv_stft? #10

Open
panhu opened this issue Aug 25, 2022 · 1 comment


panhu commented Aug 25, 2022

Hi, I am trying to implement conv_stft in TensorFlow like this:

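```python
import numpy as np
import tensorflow as tf
from scipy.signal import get_window


def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    else:
        # Square root of a periodic window
        window = get_window(win_type, win_len, fftbins=True)**0.5

    N = fft_len
    # Real-FFT basis truncated to the frame length: (win_len, N // 2 + 1)
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    # One filter per row: (N + 2, win_len), real filters first, then imaginary
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T

    if invers:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    # tf.nn.conv1d expects filters as [filter_width, in_channels, out_channels],
    # not PyTorch's [out_channels, in_channels, filter_width]
    kernel = kernel.T[:, None, :]  # (win_len, 1, N + 2)
    return tf.convert_to_tensor(kernel, tf.float32)


class ConvSTFT(tf.keras.layers.Layer):

    def __init__(self, win_len=400, win_inc=200, fft_len=512, win_type='hann',
                 feature_type='real', fix=True):
        super().__init__()
        self.fft_len = fft_len
        kernel = init_kernels(win_len, win_inc, self.fft_len, win_type)
        # fix=True keeps the Fourier basis frozen during training
        self.weight = tf.Variable(kernel, trainable=not fix)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def call(self, inputs):
        # TF is channels-last (NWC): [batch, samples] -> [batch, samples, 1]
        if len(inputs.shape) == 2:
            inputs = tf.expand_dims(inputs, -1)
        # Use TensorFlow's conv1d; torch.nn.functional.conv1d cannot take TF tensors
        outputs = tf.nn.conv1d(inputs, self.weight, stride=self.stride, padding='VALID')
        # outputs: [batch, frames, fft_len + 2], so split along the last axis
        dim = self.dim // 2 + 1
        real = outputs[:, :, :dim]
        imag = outputs[:, :, dim:]
        return [real, imag]
```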

Is this right?
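One way to check the filter layout without TensorFlow is to emulate `tf.nn.conv1d` in NumPy. The sketch below assumes TF's default `data_format='NWC'` and `padding='VALID'`; `conv1d_nwc_valid` is a hypothetical helper written only for this check. It builds the kernel in the PyTorch layout `(out_channels, in_channels, width)` that the original `init_kernels` produces (window omitted, since it only scales the filters), transposes it to TensorFlow's `(width, in_channels, out_channels)`, and shows that the output has `(T - win_len) // win_inc + 1` frames with the real and imaginary parts on the last axis.

```python
import numpy as np


def conv1d_nwc_valid(x, filters, stride):
    """Hypothetical NumPy stand-in for tf.nn.conv1d(x, filters, stride, 'VALID').

    x: [batch, width, in_channels]; filters: [filter_width, in_channels, out_channels].
    Like tf.nn.conv1d, this is a cross-correlation (the filters are not flipped).
    """
    batch, width, _ = x.shape
    filter_width, _, out_channels = filters.shape
    n_frames = (width - filter_width) // stride + 1
    out = np.empty((batch, n_frames, out_channels))
    for t in range(n_frames):
        frame = x[:, t * stride:t * stride + filter_width, :]  # [batch, filter_width, in_channels]
        out[:, t, :] = np.einsum('bwc,wco->bo', frame, filters)
    return out


win_len, win_inc, fft_len = 400, 200, 512
dim = fft_len // 2 + 1

# PyTorch layout, as built by the original init_kernels (window omitted):
# (out_channels, in_channels, width) = (fft_len + 2, 1, win_len)
basis = np.fft.rfft(np.eye(fft_len))[:win_len]
torch_kernel = np.concatenate([basis.real, basis.imag], 1).T[:, None, :]

# TensorFlow layout: (width, in_channels, out_channels) = (win_len, 1, fft_len + 2)
tf_kernel = np.transpose(torch_kernel, (2, 1, 0))

x = np.random.default_rng(0).standard_normal((2, 16000, 1))  # [batch, samples, 1]
out = conv1d_nwc_valid(x, tf_kernel, win_inc)
real, imag = out[..., :dim], out[..., dim:]
print(out.shape, real.shape, imag.shape)  # (2, 79, 514) (2, 79, 257) (2, 79, 257)
```

With TensorFlow installed, the same comparison can be made by calling `tf.nn.conv1d(x, tf_kernel, stride=win_inc, padding='VALID')` on float tensors of matching dtype.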

@hieule88

Why do we need to use conv1d instead of a normal STFT? Is it the same?
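With a fixed kernel the two compute the same thing: each output channel of the strided convolution is the dot product of one frame with one windowed row of the DFT basis, which is what an STFT computes for that frame. The NumPy sketch below checks this numerically. It assumes a periodic square-root Hann window (as `init_kernels` builds with `win_type='hann'`) and `fft_len >= win_len`, so each frame is zero-padded to `fft_len` before the FFT. Presumably the reason for the convolution form is what the `fix` and `invers` parameters hint at: the basis can be made trainable, and the inverse transform can be written the same way with the pseudo-inverse kernel.

```python
import numpy as np

win_len, win_inc, fft_len = 400, 200, 512
dim = fft_len // 2 + 1

# Periodic Hann window (what scipy's get_window('hann', win_len, fftbins=True) returns),
# square-rooted as in the original init_kernels
n = np.arange(win_len)
window = (0.5 - 0.5 * np.cos(2 * np.pi * n / win_len)) ** 0.5

# Convolution kernel, one filter per row: (fft_len + 2, win_len)
basis = np.fft.rfft(np.eye(fft_len))[:win_len]
kernel = np.concatenate([basis.real, basis.imag], 1).T * window

x = np.random.default_rng(0).standard_normal(16000)
n_frames = (len(x) - win_len) // win_inc + 1
frames = np.stack([x[t * win_inc:t * win_inc + win_len] for t in range(n_frames)])

# conv1d view: correlate every frame (stride = win_inc) with every filter
conv_out = frames @ kernel.T                    # (n_frames, fft_len + 2)
conv_real, conv_imag = conv_out[:, :dim], conv_out[:, dim:]

# Ordinary STFT view: window each frame, zero-pad to fft_len, take the real FFT
spec = np.fft.rfft(frames * window, n=fft_len)  # (n_frames, dim)

print(np.allclose(conv_real, spec.real), np.allclose(conv_imag, spec.imag))  # True True
```

The only differences from a library STFT are framing details such as padding at the signal edges, which the `'VALID'` convolution does not add.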
