In [1]:
from functools import partial
import torch
from einops import rearrange, pack, unpack

In [2]:
# Tensor dimensions: batch size (B), stereo channels (C), time samples (T)
B = 8
C = 2
T = 16000

In [3]:
# Uniform random stand-in for audio, values in [0, 1), shaped (B, C, T).
# Drawn from a dedicated, seeded generator so that a fresh kernel run always
# produces the same tensor, without altering the global RNG state.
NOISE_SEED = 0
noise = torch.rand(B, C, T, generator=torch.Generator().manual_seed(NOISE_SEED))

In [4]:
# STFT hyper-parameters: the window spans the whole FFT size and the hop is
# a quarter of it.
stft_n_fft, stft_hop_length, stft_win_length = 2048, 512, 2048
stft_normalized, center_padded = False, True

# Keyword arguments shared by the torch.stft calls in this notebook.
stft_kwargs = {
    "n_fft": stft_n_fft,
    "hop_length": stft_hop_length,
    "win_length": stft_win_length,
    "center": center_padded,
    "normalized": stft_normalized,
}

# Zero-argument factory: each call builds a Hann window of stft_win_length samples.
stft_window_fn = partial(torch.hann_window, stft_win_length)

With `center=True`, `input` is padded on both sides so that the $t$-th frame is centered at time $t \times \text{hop\_length}$.

In [5]:
# Expected number of STFT frames, using the frame-count rule documented for
# torch.stft. Integer floor division avoids the float round-trip, and the
# branch keeps the figure correct if `center_padded` is switched off:
#   center=True  -> 1 + T // hop_length
#   center=False -> 1 + (T - n_fft) // hop_length
if center_padded:
    stft_frame_count = 1 + T // stft_hop_length
else:
    stft_frame_count = 1 + (T - stft_n_fft) // stft_hop_length
print("STFT Frame Count: ", stft_frame_count)

STFT Frame Count:  32


The lowercase $c$ axis holds the complex components: each complex value is split into two real numbers, one for its real part and one for its imaginary part.

In [6]:
def pack_one(t, pattern):
    """Pack a single tensor with einops ``pack``.

    ``t`` is wrapped in a one-element list and packed according to
    ``pattern``. Returns the pair produced by ``pack``: the packed tensor
    and the packed shapes that ``unpack`` later needs to undo the packing.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes

def unpack_one(t, ps, pattern):
    """Undo a single-tensor packing with einops ``unpack``.

    ``ps`` is the packed-shapes value returned alongside the packed tensor,
    and ``pattern`` is the einops pattern describing ``t``. Only the first
    element of the sequence returned by ``unpack`` is handed back.
    """
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]

# Variant 1: flatten every leading axis with einops pack, run the STFT on the
# resulting 2-D batch, then restore the leading axes with unpack.
raw_audio1, batch_audio_channel_packed_shape = pack_one(noise, '* t')

stft_window = stft_window_fn()
stft_repr1 = torch.view_as_real(
    torch.stft(raw_audio1, window=stft_window, return_complex=True, **stft_kwargs)
)  # the complex output becomes a trailing real axis of size 2
stft_repr1 = unpack_one(stft_repr1, batch_audio_channel_packed_shape, '* f t c')
print(stft_repr1.shape)

torch.Size([8, 2, 1025, 32, 2])


In [7]:
# Variant 2: merge the batch and channel axes explicitly with rearrange, run
# the STFT, then split the merged axis back into (batch, channel).
raw_audio2 = rearrange(noise, 'b C t -> (b C) t')

stft_window = stft_window_fn()
stft_repr2 = torch.stft(raw_audio2, window=stft_window, return_complex=True, **stft_kwargs)
stft_repr2 = rearrange(
    torch.view_as_real(stft_repr2),  # complex -> trailing real axis of size 2
    '(b C) f m c -> b C f m c',
    C=C,
)
print(stft_repr2.shape)

torch.Size([8, 2, 1025, 32, 2])


In [8]:
# Check that the pack/unpack variant and the rearrange variant give tensors of
# the same shape with exactly equal elements; the boolean is the cell's output.
torch.equal(stft_repr1, stft_repr2)

True