
Commit

fix the issue when using load_pretrained_model
kan-bayashi committed Jun 22, 2023
1 parent bfcb561 commit 271aa23
Showing 1 changed file with 27 additions and 10 deletions.
espnet2/gan_tts/hifigan/hifigan.py (37 changes: 27 additions & 10 deletions)
@@ -620,43 +620,60 @@ def _load_state_dict_pre_hook(
         See also:
             - https://github.com/espnet/espnet/pull/5240
             - https://github.com/espnet/espnet/pull/5249
+            - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
 
         """
         current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
-        if self.use_weight_norm and not any(
-            ["weight_g" in k for k in current_module_keys]
+        if self.use_weight_norm and any(
+            [k.endswith("weight") for k in current_module_keys]
         ):
             logging.warning(
                 "It seems weight norm is not applied in the pretrained model but the"
                 " current model uses it. To keep the compatibility, we remove the norm"
-                " from the current model. This may causes training error due to the"
-                " parameter mismatch when finetuning. To avoid this issue, please"
+                " from the current model. This may cause unexpected training behavior due"
+                " to the parameter mismatch when finetuning. To avoid this issue, please"
                 " change the following parameters in config to false:\n"
                 " - discriminator_params.follow_official_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
-                " See also: https://github.com/espnet/espnet/pull/5240"
+                "\n"
+                "See also:\n"
+                " - https://github.com/espnet/espnet/pull/5240\n"
+                " - https://github.com/espnet/espnet/pull/5249"
             )
             self.remove_weight_norm()
             self.use_weight_norm = False
+            for k in current_module_keys:
+                if k.endswith("weight_g") or k.endswith("weight_v"):
+                    del state_dict[k]
 
-        if self.use_spectral_norm and not any(
-            ["weight_u" in k for k in current_module_keys]
+        if self.use_spectral_norm and any(
+            [k.endswith("weight") for k in current_module_keys]
         ):
             logging.warning(
                 "It seems spectral norm is not applied in the pretrained model but the"
                 " current model uses it. To keep the compatibility, we remove the norm"
-                " from the current model. This may causes training error due to the"
-                " parameter mismatch when finetuning. To avoid this issue, please"
+                " from the current model. This may cause unexpected training behavior due"
+                " to the parameter mismatch when finetuning. To avoid this issue, please"
                 " change the following parameters in config to false:\n"
                 " - discriminator_params.follow_official_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
-                " See also: https://github.com/espnet/espnet/pull/5240"
+                "\n"
+                "See also:\n"
+                " - https://github.com/espnet/espnet/pull/5240\n"
+                " - https://github.com/espnet/espnet/pull/5249"
             )
             self.remove_spectral_norm()
             self.use_spectral_norm = False
+            for k in current_module_keys:
+                if (
+                    k.endswith("weight_u")
+                    or k.endswith("weight_v")
+                    or k.endswith("weight_orig")
+                ):
+                    del state_dict[k]
 
 
 class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
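For context on what the new endswith checks match: in PyTorch, applying torch.nn.utils.weight_norm replaces a module's "weight" entry in the state dict with "weight_g" and "weight_v", while torch.nn.utils.spectral_norm replaces it with "weight_orig" plus the power-iteration buffers "weight_u" and "weight_v". A checkpoint saved from a model without the norm therefore still contains plain keys ending in "weight", which is what the updated condition detects, and the keys removed by the new deletion loops are exactly the norm-specific ones. A minimal sketch of the key naming (plain PyTorch, independent of espnet):

import torch

# Plain conv: state dict holds "weight" and "bias".
plain = torch.nn.Conv1d(1, 8, 5)
print(sorted(plain.state_dict()))  # ['bias', 'weight']

# Weight norm splits "weight" into direction ("weight_v") and magnitude ("weight_g").
wn = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 8, 5))
print(sorted(wn.state_dict()))  # ['bias', 'weight_g', 'weight_v']

# Spectral norm keeps the raw tensor as "weight_orig" plus the power-iteration
# buffers "weight_u" and "weight_v".
sn = torch.nn.utils.spectral_norm(torch.nn.Conv1d(1, 8, 5))
print(sorted(sn.state_dict()))  # ['bias', 'weight_orig', 'weight_u', 'weight_v']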

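As for how a hook like _load_state_dict_pre_hook runs: PyTorch lets a module register a pre-hook that receives the incoming state_dict before any parameter copying, so the hook can both modify the module and rewrite or drop checkpoint keys. Below is a minimal, self-contained sketch of the same reconciliation pattern; it assumes registration through the private Module._register_load_state_dict_pre_hook API, and the class and method names are illustrative rather than espnet's actual code:

import torch


class TinyDiscriminator(torch.nn.Module):
    """Toy module reproducing the weight-norm reconciliation pattern."""

    def __init__(self, use_weight_norm=True):
        super().__init__()
        self.conv = torch.nn.Conv1d(1, 8, 5)
        self.use_weight_norm = use_weight_norm
        if use_weight_norm:
            torch.nn.utils.weight_norm(self.conv)
        # The hook is called with the incoming state_dict before any
        # parameter copying starts, so keys can still be rewritten here.
        self._register_load_state_dict_pre_hook(self._pre_hook)

    def _pre_hook(self, state_dict, prefix, local_metadata, strict,
                  missing_keys, unexpected_keys, error_msgs):
        keys = [k for k in state_dict if k.startswith(prefix)]
        if self.use_weight_norm and any(k.endswith("weight") for k in keys):
            # Checkpoint holds plain "weight" tensors: drop the norm from the
            # current model and purge any stale norm keys from the checkpoint.
            torch.nn.utils.remove_weight_norm(self.conv)
            self.use_weight_norm = False
            for k in keys:
                if k.endswith("weight_g") or k.endswith("weight_v"):
                    del state_dict[k]


# Loading a norm-free checkpoint into a weight-normed model now succeeds:
source = TinyDiscriminator(use_weight_norm=False)
target = TinyDiscriminator(use_weight_norm=True)
target.load_state_dict(source.state_dict())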