Skip to content

Commit

Permalink
Check the value of n_shift == upsample_factor in GAN_TTS (#5299)
Browse files Browse the repository at this point in the history
Co-authored-by: Jiatong <728307998@qq.com>
  • Loading branch information
imdanboy and ftshijt committed Jul 24, 2023
1 parent 97080b4 commit c9cd4de
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions espnet2/gan_tts/espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ def __init__(
tts, "discriminator"
), "discriminator module must be registered as tts.discriminator"

if feats_extract is not None:
if hasattr(tts.generator, "vocoder"):
upsample_factor = tts.generator["vocoder"].upsample_factor
else:
upsample_factor = tts.generator.upsample_factor
assert (
feats_extract.get_parameters()["n_shift"] == upsample_factor
), "n_shift must be equal to upsample_factor"

def forward(
self,
text: torch.Tensor,
Expand Down

0 comments on commit c9cd4de

Please sign in to comment.