re-fix norm compatibility in scale discriminator (#5249)
kan-bayashi committed Jun 23, 2023
1 parent 1968c60 commit e741743
Showing 3 changed files with 32 additions and 14 deletions.
1 change: 1 addition & 0 deletions egs2/TEMPLATE/asr1/utils/download_from_google_drive.sh
4 changes: 2 additions & 2 deletions egs2/jvs/tts1/README.md
@@ -188,10 +188,10 @@ Since we use the same language data for fine-tuning, we need to use the token li
The downloaded pretrained model has `tokens_list` in the config, so first we create `tokens.txt` (`token_list`) from the config.

```sh
-$ pyscripts/utils/make_token_list_from_config.py downloads/f3698edf589206588f58f5ec837fa516/exp/tts_train_vits_raw_phn_jaconv_pyopenjtalk_with_accent/config.yaml
+$ pyscripts/utils/make_token_list_from_config.py downloads/f3698edf589206588f58f5ec837fa516/exp/tts_train_vits_raw_phn_jaconv_pyopenjtalk_accent_with_pause/config.yaml

# tokens.txt is created in model directory
-$ ls downloads/f3698edf589206588f58f5ec837fa516/exp/exp/tts_train_vits_raw_phn_jaconv_pyopenjtalk_accent_with_pause
+$ ls downloads/f3698edf589206588f58f5ec837fa516/exp/tts_train_vits_raw_phn_jaconv_pyopenjtalk_accent_with_pause
config.yaml images train.total_count.ave_10best.pth
```
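For reference, a minimal Python sketch of what this conversion step boils down to, assuming the config stores the tokens under a `token_list` key; the actual `pyscripts/utils/make_token_list_from_config.py` may handle more cases:

```python
# Hypothetical sketch: read an ESPnet config.yaml and write its token
# list as tokens.txt, one token per line, next to the config file.
from pathlib import Path

import yaml

config_path = Path(
    "downloads/f3698edf589206588f58f5ec837fa516/exp"
    "/tts_train_vits_raw_phn_jaconv_pyopenjtalk_accent_with_pause/config.yaml"
)
with config_path.open() as f:
    config = yaml.safe_load(f)

# Write tokens.txt into the model directory.
(config_path.parent / "tokens.txt").write_text(
    "\n".join(config["token_list"]) + "\n"
)
```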

41 changes: 29 additions & 12 deletions espnet2/gan_tts/hifigan/hifigan.py
@@ -620,43 +620,60 @@ def _load_state_dict_pre_hook(
         See also:
             - https://github.com/espnet/espnet/pull/5240
+            - https://github.com/espnet/espnet/pull/5249
+            - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
         """
         current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
-        if self.use_weight_norm and not any(
-            ["weight_g" in k for k in current_module_keys]
+        if self.use_weight_norm and any(
+            [k.endswith("weight") for k in current_module_keys]
         ):
             logging.warning(
                 "It seems weight norm is not applied in the pretrained model but the"
                 " current model uses it. To keep the compatibility, we remove the norm"
-                " from the current model. This may causes training error due to the"
-                " parameter mismatch when finetuning. To avoid this issue, please"
-                " change the following parameters in config to false:\n"
+                " from the current model. This may cause unexpected behavior due to the"
+                " parameter mismatch in finetuning. To avoid this issue, please change"
+                " the following parameters in config to false:\n"
                 " - discriminator_params.follow_official_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
-                " See also: https://github.com/espnet/espnet/pull/5240"
+                "\n"
+                "See also:\n"
+                " - https://github.com/espnet/espnet/pull/5240\n"
+                " - https://github.com/espnet/espnet/pull/5249"
             )
             self.remove_weight_norm()
             self.use_weight_norm = False
+            for k in current_module_keys:
+                if k.endswith("weight_g") or k.endswith("weight_v"):
+                    del state_dict[k]
 
-        if self.use_spectral_norm and not any(
-            ["weight_u" in k for k in current_module_keys]
+        if self.use_spectral_norm and any(
+            [k.endswith("weight") for k in current_module_keys]
         ):
             logging.warning(
                 "It seems spectral norm is not applied in the pretrained model but the"
                 " current model uses it. To keep the compatibility, we remove the norm"
-                " from the current model. This may causes training error due to the"
-                " parameter mismatch when finetuning. To avoid this issue, please"
-                " change the following parameters in config to false:\n"
+                " from the current model. This may cause unexpected behavior due to the"
+                " parameter mismatch in finetuning. To avoid this issue, please change"
+                " the following parameters in config to false:\n"
                 " - discriminator_params.follow_official_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
                 " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
-                " See also: https://github.com/espnet/espnet/pull/5240"
+                "\n"
+                "See also:\n"
+                " - https://github.com/espnet/espnet/pull/5240\n"
+                " - https://github.com/espnet/espnet/pull/5249"
             )
             self.remove_spectral_norm()
             self.use_spectral_norm = False
+            for k in current_module_keys:
+                if (
+                    k.endswith("weight_u")
+                    or k.endswith("weight_v")
+                    or k.endswith("weight_orig")
+                ):
+                    del state_dict[k]
 
 
 class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
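The compatibility checks above key off PyTorch's norm-wrapper naming: `torch.nn.utils.weight_norm` replaces a module's `weight` entry with `weight_g`/`weight_v`, while `torch.nn.utils.spectral_norm` keeps `weight_orig` plus the power-iteration buffers `weight_u`/`weight_v`. A standalone illustration (not ESPnet code):

```python
import torch

# weight_norm splits `weight` into magnitude (`weight_g`) and direction
# (`weight_v`); spectral_norm keeps `weight_orig` plus the buffers
# `weight_u` / `weight_v`. A plain module keeps `weight` itself.
plain = torch.nn.Conv1d(4, 4, 3)
wn = torch.nn.utils.weight_norm(torch.nn.Conv1d(4, 4, 3))
sn = torch.nn.utils.spectral_norm(torch.nn.Conv1d(4, 4, 3))

print(sorted(plain.state_dict()))  # ['bias', 'weight']
print(sorted(wn.state_dict()))     # ['bias', 'weight_g', 'weight_v']
print(sorted(sn.state_dict()))     # ['bias', 'weight_orig', 'weight_u', 'weight_v']
```

A key ending in plain `weight` therefore signals a checkpoint saved without either norm applied, which is exactly what the re-fixed condition tests.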

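And a condensed, self-contained sketch of the load-time mechanism the hook relies on, using PyTorch's `Module._register_load_state_dict_pre_hook` (the same pattern the diff's `_load_state_dict_pre_hook` uses); the module layout and names below are illustrative, not ESPnet's:

```python
import torch

class CompatDiscriminator(torch.nn.Module):
    """Toy module that drops its weight norm when the checkpoint lacks it."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 8, 3))
        self.use_weight_norm = True
        self._register_load_state_dict_pre_hook(self._pre_hook)

    def _pre_hook(self, state_dict, prefix, local_metadata, strict,
                  missing_keys, unexpected_keys, error_msgs):
        keys = [k for k in state_dict if k.startswith(prefix)]
        # Plain "...weight" keys mean the checkpoint was saved without
        # weight norm, so remove the norm here before strict key matching.
        if self.use_weight_norm and any(k.endswith("weight") for k in keys):
            torch.nn.utils.remove_weight_norm(self.conv)
            self.use_weight_norm = False

# A checkpoint from a model without weight norm now loads without key errors.
ckpt = {"conv.weight": torch.randn(8, 1, 3), "conv.bias": torch.zeros(8)}
model = CompatDiscriminator()
model.load_state_dict(ckpt)
```

Because the pre-hook runs before strict key matching, the checkpoint loads cleanly instead of failing on missing `weight_g`/`weight_v` entries.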