Skip to content

Commit

Permalink
Merge pull request #4945 from fujimotos/sf/floordiv
Browse files Browse the repository at this point in the history
Fix '__floordiv__ is deprecated' warnings
  • Loading branch information
sw005320 committed Feb 20, 2023
2 parents 81a90d6 + 919b3e5 commit 179d1f8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
8 changes: 7 additions & 1 deletion espnet/nets/batch_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from typing import Any, Dict, List, NamedTuple, Tuple

import torch
from packaging.version import parse as V
from torch.nn.utils.rnn import pad_sequence

from espnet.nets.beam_search import BeamSearch, Hypothesis

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class BatchHypothesis(NamedTuple):
"""Batchfied/Vectorized hypothesis data type."""
Expand Down Expand Up @@ -99,7 +102,10 @@ def batch_beam(
# Because of the flatten above, `top_ids` is organized as:
# [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
# where V is `self.n_vocab` and K is `self.beam_size`
prev_hyp_ids = top_ids // self.n_vocab
if is_torch_1_9_plus:
prev_hyp_ids = torch.div(top_ids, self.n_vocab, rounding_mode="trunc")
else:
prev_hyp_ids = top_ids // self.n_vocab
new_token_ids = top_ids % self.n_vocab
return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids

Expand Down
10 changes: 9 additions & 1 deletion espnet2/layers/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,15 @@ def forward(
pad = self.n_fft // 2
ilens = ilens + 2 * pad

olens = (ilens - self.n_fft) // self.hop_length + 1
if is_torch_1_9_plus:
olens = (
torch.div(
ilens - self.n_fft, self.hop_length, rounding_mode="trunc"
)
+ 1
)
else:
olens = (ilens - self.n_fft) // self.hop_length + 1
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
else:
olens = None
Expand Down

0 comments on commit 179d1f8

Please sign in to comment.