fix the issue when using load_pretrained_model
kan-bayashi committed Jun 22, 2023
1 parent bfcb561 commit b685366
Showing 1 changed file with 29 additions and 12 deletions.
espnet2/gan_tts/hifigan/hifigan.py: 41 changes (29 additions & 12 deletions)
@@ -620,43 +620,60 @@ def _load_state_dict_pre_hook(
         See also:
             - https://github.com/espnet/espnet/pull/5240
             - https://github.com/espnet/espnet/pull/5249
+            - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
 
         """
         current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
-        if self.use_weight_norm and not any(
-            ["weight_g" in k for k in current_module_keys]
+        if self.use_weight_norm and any(
+            [k.endswith("weight") for k in current_module_keys]
         ):
             logging.warning(
                 "It seems weight norm is not applied in the pretrained model but the"
                 " current model uses it. To keep the compatibility, we remove the norm"
-                " from the current model. This may causes training error due to the"
-                " parameter mismatch when finetuning. To avoid this issue, please"
-                " change the following parameters in config to false:\n"
+                " from the current model. This may cause unexpected behavior due to the"
+                " parameter mismatch in finetuning. To avoid this issue, please change"
+                " the following parameters in config to false:\n"
                 " - discriminator_params.follow_official_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
-                " See also: https://github.com/espnet/espnet/pull/5240"
+                "\n"
+                "See also:\n"
+                " - https://github.com/espnet/espnet/pull/5240\n"
+                " - https://github.com/espnet/espnet/pull/5249"
             )
             self.remove_weight_norm()
             self.use_weight_norm = False
+            for k in current_module_keys:
+                if k.endswith("weight_g") or k.endswith("weight_v"):
+                    del state_dict[k]
 
-        if self.use_spectral_norm and not any(
-            ["weight_u" in k for k in current_module_keys]
+        if self.use_spectral_norm and any(
+            [k.endswith("weight") for k in current_module_keys]
         ):
             logging.warning(
                 "It seems spectral norm is not applied in the pretrained model but the"
                 " current model uses it. To keep the compatibility, we remove the norm"
-                " from the current model. This may causes training error due to the"
-                " parameter mismatch when finetuning. To avoid this issue, please"
-                " change the following parameters in config to false:\n"
+                " from the current model. This may cause unexpected behavior due to the"
+                " parameter mismatch in finetuning. To avoid this issue, please change"
+                " the following parameters in config to false:\n"
                 " - discriminator_params.follow_official_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
-                " See also: https://github.com/espnet/espnet/pull/5240"
+                "\n"
+                "See also:\n"
+                " - https://github.com/espnet/espnet/pull/5240\n"
+                " - https://github.com/espnet/espnet/pull/5249"
             )
             self.remove_spectral_norm()
             self.use_spectral_norm = False
+            for k in current_module_keys:
+                if (
+                    k.endswith("weight_u")
+                    or k.endswith("weight_v")
+                    or k.endswith("weight_orig")
+                ):
+                    del state_dict[k]
 
 
 class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
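
For context on the new checks (standard PyTorch behavior, not part of this commit): torch.nn.utils.weight_norm replaces a module's "weight" parameter with "weight_g" and "weight_v", while torch.nn.utils.spectral_norm replaces it with a "weight_orig" parameter plus "weight_u"/"weight_v" buffers. A state dict whose keys end in a bare "weight" was therefore saved without the norm applied, which is exactly what the endswith("weight") tests detect before the hook strips the norm from the current model and prunes the leftover norm-specific keys. A minimal sketch:

import torch

# A plain conv stores its kernel under the key "weight".
plain = torch.nn.Conv1d(1, 1, 3)
print(sorted(plain.state_dict()))  # ['bias', 'weight']

# weight_norm reparameterizes "weight" into magnitude/direction tensors.
wn = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
print(sorted(wn.state_dict()))  # ['bias', 'weight_g', 'weight_v']

# spectral_norm keeps the raw weight plus power-iteration buffers.
sn = torch.nn.utils.spectral_norm(torch.nn.Conv1d(1, 1, 3))
print(sorted(sn.state_dict()))  # ['bias', 'weight_orig', 'weight_u', 'weight_v']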

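The mismatch the hook guards against is easy to reproduce outside ESPnet. This hypothetical sketch does by hand what the fixed hook automates, assuming a checkpoint saved without weight norm and a current model that applies it:

import torch

# Checkpoint side: a conv saved without weight norm (or with it removed).
src = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
torch.nn.utils.remove_weight_norm(src)  # state dict now holds a plain "weight"
ckpt = src.state_dict()

# Current-model side: weight norm applied, so it expects "weight_g"/"weight_v".
dst = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
# dst.load_state_dict(ckpt) would raise here: missing "weight_g"/"weight_v"
# and an unexpected "weight" key.

# What the hook now does on the model's behalf so the key sets line up:
torch.nn.utils.remove_weight_norm(dst)
dst.load_state_dict(ckpt)  # succeeds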