AttributeError: 'RADTTS' object has no attribute 'dur_pred_layer' #8

Open
windowzzhhuu opened this issue Jul 21, 2022 · 11 comments

@windowzzhhuu

windowzzhhuu commented Jul 21, 2022

I get this error when I run inference on text with the pretrained model.

@rafaelvalle
Contributor

What's the full command you used?

@eschmidbauer

I'm running into the same issue with this command:
python inference.py -c custom_finetuned/config.json -r custom_finetuned/model_237500 -v pretrained_models/hifigan_libritts100360_generator0p5.pt -k pretrained_models/hifigan_22khz_config.json -s 0 -t sentences.txt -o results/

@szprytny

I encountered the same error when I tried to run inference with the model trained in step 1:

Train the decoder
python train.py -c config_ljs_decoder.json -p train_config.output_directory=outdir

After I also trained the model in step 2:

Train the attribute predictor: autoregressive flow (agap), bi-partite flow (bgap) or deterministic (dap)
python train.py -c config_ljs_{agap,bgap,dap}.json -p train_config.output_directory=outdir_wattr train_config.warmstart_checkpoint_path=model_path.pt

inference ran without any problems.

@rafaelvalle
Contributor

Step 1 only trains the decoder, after which you need to train the attribute predictors before you can perform inference.
Step 2 only trains the attribute predictors.

If you're trying to fine-tune the pre-trained model on your data, you can warmstart from the pre-trained model and then either

  1. train only the decoder and then train only the attribute predictor (this is the default from-scratch recipe), or
  2. train the decoder and attribute predictors jointly, which requires setting unfreeze_modules to 'all' (see https://github.com/NVIDIA/radtts/blob/main/configs/config_ljs_decoder.json#L35).
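
For option 2, here is a minimal sketch of how such a config could be prepared in Python. The keys unfreeze_modules, warmstart_checkpoint_path and output_directory under train_config are the ones that appear in this thread. The helper name prepare_joint_finetune_config and the example paths are hypothetical, and this is not code from the repository.

```python
import copy
import json


def prepare_joint_finetune_config(config, warmstart_path, output_dir):
    """Return a copy of a RADTTS-style config dict set up for joint training.

    Hypothetical helper: it only edits the three train_config keys
    mentioned in this thread and leaves everything else untouched.
    """
    cfg = copy.deepcopy(config)  # do not mutate the caller's dict
    train_cfg = cfg.setdefault("train_config", {})
    train_cfg["unfreeze_modules"] = "all"  # train decoder and predictors jointly
    train_cfg["warmstart_checkpoint_path"] = warmstart_path
    train_cfg["output_directory"] = output_dir
    return cfg


# Illustrative input; a real config would be loaded with json.load().
base_config = {"train_config": {"output_directory": "outdir"}}
joint_config = prepare_joint_finetune_config(
    base_config,
    warmstart_path="pretrained_model.pt",  # placeholder path
    output_dir="outdir_joint",             # placeholder path
)
print(json.dumps(joint_config, indent=2))
```

The result could be written to a new JSON file and passed to train.py with -c, so the original config file stays unchanged.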

@rafaelvalle
Contributor

Make sure to use the correct configs during inference when using the model conditioned on f0 and energy: config_ljs_{agap,bgap,dap}.json.

@mepc36

mepc36 commented Sep 14, 2022

Hey Rafael, one quick question: for step #2 (the attribute prediction training), do I pass in radtts' pretrained model for the warmstart arg, or do I pass in the finetuned model I made in step #1 as the warmstart arg? Thanks man!

@mepc36

mepc36 commented Sep 14, 2022

When I ran this command...

python3 train.py \
    -c ./config_ljs_dap.json \
    -p train_config.output_directory=training-output \
    train_config.warmstart_checkpoint_path=radtts_pretrained_dap_model.pt

...I got this error log about Unexpected key(s) in state_dict:

Unable to init server: Could not connect: Connection refused
Unable to init server: Could not connect: Connection refused

(train.py:286816): Gdk-CRITICAL **: 11:58:55.624: gdk_cursor_new_for_display: assertion 'GDK_IS_DISPLAY (display)' failed
train_config.output_directory=/home/ubuntu/1-radtts-repo/6-training-output
output_directory=/home/ubuntu/1-radtts-repo/6-training-output
overriding output_directory with /home/ubuntu/1-radtts-repo/6-training-output
train_config.warmstart_checkpoint_path=/home/ubuntu/1-radtts-repo/1-models/1-radtts-models/1-radtts_pretrained_dap_model.pt
warmstart_checkpoint_path=/home/ubuntu/1-radtts-repo/1-models/1-radtts-models/1-radtts_pretrained_dap_model.pt
overriding warmstart_checkpoint_path with /home/ubuntu/1-radtts-repo/1-models/1-radtts-models/1-radtts_pretrained_dap_model.pt
{'train_config': {'output_directory': '/home/ubuntu/1-radtts-repo/6-training-output', 'epochs': 1002, 'optim_algo': 'RAdam', 'learning_rate': 0.0001, 'weight_decay': 1e-06, 'sigma': 1.0, 'iters_per_checkpoint': 2500, 'batch_size': 16, 'seed': None, 'checkpoint_path': '', 'ignore_layers': [], 'ignore_layers_warmstart': [], 'finetune_layers': [], 'include_layers': [], 'vocoder_config_path': '/home/ubuntu/1-radtts-repo/2-configs/2-hifigan-configs/uberduck-vocoder-notebook-lupe-fiasco-150-2022-09-12-A.json', 'vocoder_checkpoint_path': '/home/ubuntu/1-radtts-repo/1-models/2-hifigan-models/uberduck-vocoder-notebook-lupe-fiasco-150-2022-09-12-A', 'log_attribute_samples': False, 'log_decoder_samples': True, 'warmstart_checkpoint_path': '/home/ubuntu/1-radtts-repo/1-models/1-radtts-models/1-radtts_pretrained_dap_model.pt', 'use_amp': False, 'grad_clip_val': 1.0, 'loss_weights': {'blank_logprob': -1, 'ctc_loss_weight': 0.1, 'binarization_loss_weight': 1.0, 'dur_loss_weight': 1.0, 'f0_loss_weight': 1.0, 'energy_loss_weight': 1.0, 'vpred_loss_weight': 1.0}, 'binarization_start_iter': 6000, 'kl_loss_start_iter': 18000, 'unfreeze_modules': 'all'}, 'data_config': {'training_files': {'LJS': {'basedir': '3-filelists-lupe/', 'audiodir': 'wavs', 'filelist': 'training.txt', 'lmdbpath': ''}}, 'validation_files': {'LJS': {'basedir': '3-filelists-lupe/', 'audiodir': 'wavs', 'filelist': 'validation.txt', 'lmdbpath': ''}}, 'dur_min': 0.1, 'dur_max': 10.2, 'sampling_rate': 22050, 'filter_length': 1024, 'hop_length': 256, 'win_length': 1024, 'n_mel_channels': 80, 'mel_fmin': 0.0, 'mel_fmax': 8000.0, 'f0_min': 80.0, 'f0_max': 640.0, 'max_wav_value': 32768.0, 'use_f0': True, 'use_log_f0': 0, 'use_energy_avg': True, 'use_scaled_energy': True, 'symbol_set': 'radtts', 'cleaner_names': ['radtts_cleaners'], 'heteronyms_path': 'tts_text_processing/heteronyms', 'phoneme_dict_path': 'tts_text_processing/cmudict-0.7b', 'p_phoneme': 1.0, 'handle_phoneme': 'word', 'handle_phoneme_ambiguous': 'ignore', 
'include_speakers': None, 'n_frames': -1, 'betabinom_cache_path': 'data_cache/', 'lmdb_cache_path': '', 'use_attn_prior_masking': True, 'prepend_space_to_text': True, 'append_space_to_text': True, 'add_bos_eos_to_text': False, 'betabinom_scaling_factor': 1.0, 'distance_tx_unvoiced': False, 'mel_noise_scale': 0.0}, 'dist_config': {'dist_backend': 'nccl', 'dist_url': 'tcp://localhost:54321'}, 'model_config': {'n_speakers': 1, 'n_speaker_dim': 16, 'n_text': 185, 'n_text_dim': 512, 'n_flows': 8, 'n_conv_layers_per_step': 4, 'n_mel_channels': 80, 'n_hidden': 1024, 'mel_encoder_n_hidden': 512, 'dummy_speaker_embedding': False, 'n_early_size': 2, 'n_early_every': 2, 'n_group_size': 2, 'affine_model': 'wavenet', 'include_modules': 'decatnvpred', 'scaling_fn': 'tanh', 'matrix_decomposition': 'LUS', 'learn_alignments': True, 'use_speaker_emb_for_alignment': False, 'attn_straight_through_estimator': True, 'use_context_lstm': True, 'context_lstm_norm': 'spectral', 'context_lstm_w_f0_and_energy': True, 'text_encoder_lstm_norm': 'spectral', 'n_f0_dims': 1, 'n_energy_avg_dims': 1, 'use_first_order_features': False, 'unvoiced_bias_activation': 'relu', 'decoder_use_partial_padding': True, 'decoder_use_unvoiced_bias': True, 'ap_pred_log_f0': True, 'ap_use_unvoiced_bias': True, 'ap_use_voiced_embeddings': True, 'dur_model_config': None, 'f0_model_config': None, 'energy_model_config': None, 'v_model_config': {'name': 'dap', 'hparams': {'n_speaker_dim': 16, 'take_log_of_input': False, 'bottleneck_hparams': {'in_dim': 512, 'reduction_factor': 16, 'norm': 'weightnorm', 'non_linearity': 'relu'}, 'arch_hparams': {'out_dim': 1, 'n_layers': 2, 'n_channels': 256, 'kernel_size': 3, 'p_dropout': 0.5, 'lstm_type': '', 'use_linear': 1}}}}}
> got rank 0 and world size 1 ...
/home/ubuntu/1-radtts-repo/6-training-output
Using seed 1113
Applying spectral norm to text encoder LSTM
Applying spectral norm to context encoder LSTM
/home/ubuntu/1-radtts-repo/common.py:391: UserWarning: torch.qr is deprecated in favor of torch.linalg.qr and will be removed in a future PyTorch release.
The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2497.)
  W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
Initializing RAdam optimizer
Traceback (most recent call last):
  File "train.py", line 498, in <module>
    train(n_gpus, rank, **train_config)
  File "train.py", line 353, in train
    model = warmstart(warmstart_checkpoint_path, model, include_layers,
  File "train.py", line 174, in warmstart
    model.load_state_dict(model_dict)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RADTTS:
	Unexpected key(s) in state_dict: "dur_pred_layer.bottleneck_layer.projection_fn.conv.bias", "dur_pred_layer.bottleneck_layer.projection_fn.conv.weight_g", "dur_pred_layer.bottleneck_layer.projection_fn.conv.weight_v", "dur_pred_layer.feat_pred_fn.convolutions.0.bias", "dur_pred_layer.feat_pred_fn.convolutions.0.weight_g", "dur_pred_layer.feat_pred_fn.convolutions.0.weight_v", "dur_pred_layer.feat_pred_fn.convolutions.1.bias", "dur_pred_layer.feat_pred_fn.convolutions.1.weight_g", "dur_pred_layer.feat_pred_fn.convolutions.1.weight_v", "dur_pred_layer.feat_pred_fn.bilstm.weight_ih_l0", "dur_pred_layer.feat_pred_fn.bilstm.bias_ih_l0", "dur_pred_layer.feat_pred_fn.bilstm.bias_hh_l0", "dur_pred_layer.feat_pred_fn.bilstm.weight_ih_l0_reverse", "dur_pred_layer.feat_pred_fn.bilstm.bias_ih_l0_reverse", "dur_pred_layer.feat_pred_fn.bilstm.bias_hh_l0_reverse", "dur_pred_layer.feat_pred_fn.bilstm.weight_hh_l0_orig", "dur_pred_layer.feat_pred_fn.bilstm.weight_hh_l0_reverse_orig", "dur_pred_layer.feat_pred_fn.bilstm.weight_hh_l0_u", "dur_pred_layer.feat_pred_fn.bilstm.weight_hh_l0_v", "dur_pred_layer.feat_pred_fn.bilstm.weight_hh_l0_reverse_u", "dur_pred_layer.feat_pred_fn.bilstm.weight_hh_l0_reverse_v", "dur_pred_layer.feat_pred_fn.dense.weight", "dur_pred_layer.feat_pred_fn.dense.bias", "f0_pred_module.bottleneck_layer.projection_fn.conv.bias", "f0_pred_module.bottleneck_layer.projection_fn.conv.weight_g", "f0_pred_module.bottleneck_layer.projection_fn.conv.weight_v", "f0_pred_module.feat_pred_fn.convolutions.0.bias", "f0_pred_module.feat_pred_fn.convolutions.0.weight_g", "f0_pred_module.feat_pred_fn.convolutions.0.weight_v", "f0_pred_module.feat_pred_fn.convolutions.1.bias", "f0_pred_module.feat_pred_fn.convolutions.1.weight_g", "f0_pred_module.feat_pred_fn.convolutions.1.weight_v", "f0_pred_module.feat_pred_fn.bilstm.weight_ih_l0", "f0_pred_module.feat_pred_fn.bilstm.bias_ih_l0", "f0_pred_module.feat_pred_fn.bilstm.bias_hh_l0", 
"f0_pred_module.feat_pred_fn.bilstm.weight_ih_l0_reverse", "f0_pred_module.feat_pred_fn.bilstm.bias_ih_l0_reverse", "f0_pred_module.feat_pred_fn.bilstm.bias_hh_l0_reverse", "f0_pred_module.feat_pred_fn.bilstm.weight_hh_l0_orig", "f0_pred_module.feat_pred_fn.bilstm.weight_hh_l0_reverse_orig", "f0_pred_module.feat_pred_fn.bilstm.weight_hh_l0_u", "f0_pred_module.feat_pred_fn.bilstm.weight_hh_l0_v", "f0_pred_module.feat_pred_fn.bilstm.weight_hh_l0_reverse_u", "f0_pred_module.feat_pred_fn.bilstm.weight_hh_l0_reverse_v", "f0_pred_module.feat_pred_fn.dense.weight", "f0_pred_module.feat_pred_fn.dense.bias", "energy_pred_module.bottleneck_layer.projection_fn.conv.bias", "energy_pred_module.bottleneck_layer.projection_fn.conv.weight_g", "energy_pred_module.bottleneck_layer.projection_fn.conv.weight_v", "energy_pred_module.feat_pred_fn.convolutions.0.bias", "energy_pred_module.feat_pred_fn.convolutions.0.weight_g", "energy_pred_module.feat_pred_fn.convolutions.0.weight_v", "energy_pred_module.feat_pred_fn.convolutions.1.bias", "energy_pred_module.feat_pred_fn.convolutions.1.weight_g", "energy_pred_module.feat_pred_fn.convolutions.1.weight_v", "energy_pred_module.feat_pred_fn.bilstm.weight_ih_l0", "energy_pred_module.feat_pred_fn.bilstm.bias_ih_l0", "energy_pred_module.feat_pred_fn.bilstm.bias_hh_l0", "energy_pred_module.feat_pred_fn.bilstm.weight_ih_l0_reverse", "energy_pred_module.feat_pred_fn.bilstm.bias_ih_l0_reverse", "energy_pred_module.feat_pred_fn.bilstm.bias_hh_l0_reverse", "energy_pred_module.feat_pred_fn.bilstm.weight_hh_l0_orig", "energy_pred_module.feat_pred_fn.bilstm.weight_hh_l0_reverse_orig", "energy_pred_module.feat_pred_fn.bilstm.weight_hh_l0_u", "energy_pred_module.feat_pred_fn.bilstm.weight_hh_l0_v", "energy_pred_module.feat_pred_fn.bilstm.weight_hh_l0_reverse_u", "energy_pred_module.feat_pred_fn.bilstm.weight_hh_l0_reverse_v", "energy_pred_module.feat_pred_fn.dense.weight", "energy_pred_module.feat_pred_fn.dense.bias". 

But when I re-ran this command using the model I trained in step 1 instead of the pretrained model, it worked. So I think I've answered the question I asked above:

"do I pass in radtts' pretrained model for the warmstart arg, or do I pass in the finetuned model I made in step 1 as the warmstart arg?"

The answer is: the finetuned model I made in step 1.
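
The traceback above is consistent with the checkpoint holding attribute-predictor weights that the model built from this config does not define (the printed config shows 'dur_model_config': None). Below is a small sketch for diagnosing that kind of mismatch before calling load_state_dict. The function name is hypothetical, the lists are illustrative (three key names are copied from the log, the shared key is made up), and in practice the model's keys would come from model.state_dict().keys().

```python
from collections import defaultdict


def unexpected_keys_by_module(checkpoint_keys, model_keys):
    """Group checkpoint keys that the model lacks by their top-level module.

    Hypothetical helper: this is the same set of keys that a strict
    load_state_dict would report as "Unexpected key(s)".
    """
    known = set(model_keys)
    grouped = defaultdict(list)
    for key in checkpoint_keys:
        if key not in known:
            top_level = key.split(".", 1)[0]  # e.g. "dur_pred_layer"
            grouped[top_level].append(key)
    return dict(grouped)


# Three keys taken from the error log, plus one made-up shared key.
checkpoint_keys = [
    "dur_pred_layer.feat_pred_fn.dense.weight",
    "f0_pred_module.feat_pred_fn.dense.bias",
    "energy_pred_module.feat_pred_fn.dense.bias",
    "shared_module.weight",  # hypothetical key present on both sides
]
model_keys = ["shared_module.weight"]

for module, keys in unexpected_keys_by_module(checkpoint_keys, model_keys).items():
    print(f"{module}: {len(keys)} unexpected key(s)")
```

If whole modules show up here, the config used to build the model and the checkpoint probably do not describe the same set of modules.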

@xun-1999

xun-1999 commented Apr 1, 2024

Training RADTTS (without pitch and energy conditioning):

  1. Train the decoder
    python train.py -c config_ljs_radtts.json -p train_config.output_directory=outdir
  2. Further train with the duration predictor
    python train.py -c config_ljs_radtts.json -p train_config.output_directory=outdir_dir train_config.warmstart_checkpoint_path=model_path.pt model_config.include_modules="decatndpm"

The original command for the second step is:
python train.py -c config_ljs_radtts.json -p train_config.output_directory=outdir_dir train_config.warmstart_checkpoint_path=model_path.pt model_config.include_modules="decatndur"

In that command, model_config.include_modules="decatndur" should be changed to model_config.include_modules="decatndpm".

At inference time, the "include_modules" parameter in the configuration file should also be "decatndpm".

@Ashh-Z

Ashh-Z commented May 19, 2024

Did you run inference without pitch and energy conditioning? I was having a bit of trouble understanding the arguments.

@xun-1999

Yes, I did. When running inference without pitch and energy conditioning, you need to change the "include_modules" parameter in the configuration file from 'decatn' to 'decatndpm'. (The original post showed this in a screenshot.)

Alternatively, when running inference without pitch and energy conditioning, pass the config.json file from the folder where the model checkpoints are saved as the -c argument of the inference command. (The original post showed the file path in a screenshot.)

Example inference command:
python inference.py -c outdir_dir/config.json -r RADTTS_PATH -v HG_PATH -k HG_CONFIG_PATH -t TEXT_PATH -s ljs --speaker_attributes ljs --speaker_text ljs -o results/

Sorry if my explanation is not very clear; I hope the description above helps.
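
For anyone scripting this, here is a rough sketch that assembles the same inference command as an argument list in Python. Only the flags shown in the command above are used. The function name and parameter names are hypothetical, and reusing one speaker value for -s, --speaker_attributes and --speaker_text simply mirrors the example.

```python
import shlex


def build_inference_command(config, radtts_ckpt, vocoder_ckpt,
                            vocoder_config, text_file, speaker, output_dir):
    """Build the inference.py argument list using the flags from this thread."""
    return [
        "python", "inference.py",
        "-c", config,          # e.g. the config.json saved next to the checkpoints
        "-r", radtts_ckpt,
        "-v", vocoder_ckpt,
        "-k", vocoder_config,
        "-t", text_file,
        "-s", speaker,
        "--speaker_attributes", speaker,
        "--speaker_text", speaker,
        "-o", output_dir,
    ]


cmd = build_inference_command(
    config="outdir_dir/config.json",
    radtts_ckpt="RADTTS_PATH",
    vocoder_ckpt="HG_PATH",
    vocoder_config="HG_CONFIG_PATH",
    text_file="TEXT_PATH",
    speaker="ljs",
    output_dir="results/",
)
print(shlex.join(cmd))  # shell-quoted form, handy for logging
```

The list can be passed to subprocess.run(cmd, check=True) from the repository root; keeping it as a list avoids shell quoting problems with paths that contain spaces.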

@Ashh-Z

Ashh-Z commented May 19, 2024

Thanks, mate; I was able to successfully run inference using the changes you mentioned. Explained everything clearly, thanks mate
