Skip to content

Commit

Permalink
Merge branch 'master' into mSUPERB
Browse files Browse the repository at this point in the history
  • Loading branch information
guapaQAQ committed Feb 8, 2023
2 parents 96350bc + b041c9d commit 7950964
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions espnet2/layers/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from espnet2.layers.inversible_interface import InversibleInterface
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask

is_torch_1_10_plus = V(torch.__version__) >= V("1.10.0")


is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


Expand Down Expand Up @@ -89,8 +92,11 @@ def forward(
window = None

# For the compatibility of ARM devices, which do not support
# torch.stft() due to the lake of MKL.
if input.is_cuda or torch.backends.mkl.is_available():
# torch.stft() due to the lack of MKL (on older pytorch versions),
# there is an alternative replacement implementation with librosa.
# Note: pytorch >= 1.10.0 now has native support for FFT and STFT
# on all cpu targets including ARM.
if is_torch_1_10_plus or input.is_cuda or torch.backends.mkl.is_available():
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
Expand All @@ -111,9 +117,11 @@ def forward(
)

# use stft_kwargs to flexibly control different PyTorch versions' kwargs
# note: librosa does not support a win_length that is < n_ftt
# but the window can be manually padded (see below).
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
win_length=self.n_fft,
hop_length=self.hop_length,
center=self.center,
window=window,
Expand Down

0 comments on commit 7950964

Please sign in to comment.