Skip to content

Commit

Permalink
Support ProDiff in TTS (#4808)
Browse files Browse the repository at this point in the history
Co-authored-by: Tomoki Hayashi <hayashi.tomoki@g.sp.m.is.nagoya-u.ac.jp>
  • Loading branch information
Fhrozen and kan-bayashi committed Jan 23, 2023
1 parent 847d464 commit 3970558
Show file tree
Hide file tree
Showing 7 changed files with 1,713 additions and 0 deletions.
98 changes: 98 additions & 0 deletions egs2/ljspeech/tts1/conf/tuning/train_prodiff.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# This configuration is for ESPnet2 to train ProDiff.
# Train in 2 GPUs (RTX3060) for 2 days.


##########################################################
# TTS MODEL SETTING #
##########################################################
tts: prodiff # model architecture
tts_conf: # keyword arguments for the selected model
adim: 256 # attention dimension
aheads: 2 # number of attention heads
elayers: 4 # number of encoder layers
eunits: 1024 # number of encoder ff units
positionwise_layer_type: conv1d-linear # type of position-wise layer
positionwise_conv_kernel_size: 9 # kernel size of position wise conv layer
use_masking: True # whether to apply masking for padded part in loss calculation
use_scaled_pos_enc: True # whether to use scaled positional encoding
encoder_normalize_before: True # whether to perform layer normalization before the input
reduction_factor: 1 # reduction factor
init_type: xavier_uniform # initialization type
init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding
transformer_enc_dropout_rate: 0.05 # dropout rate for transformer encoder layer
transformer_enc_positional_dropout_rate: 0.05 # dropout rate for transformer encoder positional encoding
transformer_enc_attn_dropout_rate: 0.05 # dropout rate for transformer encoder attention layer
# Duration
duration_predictor_layers: 2 # number of layers of duration predictor
duration_predictor_chans: 256 # number of channels of duration predictor
duration_predictor_kernel_size: 3 # filter size of duration predictor
# Pitch
pitch_predictor_layers: 2 # number of conv layers in pitch predictor
pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor
pitch_predictor_kernel_size: 3 # kernel size of conv leyers in pitch predictor
pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor
pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
stop_gradient_from_pitch_predictor: true # whether to stop the gradient from pitch predictor to encoder
# Energy
energy_predictor_layers: 2 # number of conv layers in energy predictor
energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor
energy_predictor_dropout: 0.5 # dropout rate in energy predictor
energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to encoder
# Denoiser Decoder
denoiser_layers: 20 # Number of layers for the diffusion denoiser decoder
denoiser_channels: 256 # Number of channels of the denoiser
diffusion_steps: 4 # Number of steps for the diffusion
diffusion_timescale: 1 # Number of timesteps of the diffusion
diffusion_beta: 40. # Beta applied to the diffusion
diffusion_scheduler: vpsde # Type of scheduler of the diffusion denoiser
diffusion_cycle_ln: 1 # Number of cycles during the diffusion

# extra module for additional inputs
pitch_extract: dio # pitch extractor type
pitch_normalize: global_mvn # normalizer for the pitch feature
energy_extract: energy # energy extractor type
energy_normalize: global_mvn # normalizer for the energy feature

##########################################################
# OPTIMIZER & SCHEDULER SETTING #
##########################################################
optim: adamw # optimizer type
optim_conf: # keyword arguments for selected optimizer
lr: 1.0 # learning rate
betas: [0.9, 0.98]
scheduler: noamlr # scheduler type
scheduler_conf: # keyword arguments for selected scheduler
model_size: 384 # model size, a.k.a., attention dimension
warmup_steps: 2000 # the number of warmup steps

##########################################################
# OTHER TRAINING SETTING #
##########################################################
# a total of 200K iters
num_iters_per_epoch: 250 # number of iterations per epoch
max_epoch: 800 # number of epochs
grad_clip: 1.0 # gradient clipping norm
grad_noise: false # whether to use gradient noise injection
accum_grad: 1 # gradient accumulation
batch_bins: 6000000 # batch bins (feats_type=raw)
batch_type: numel # how to make batch
sort_in_batch: descending # how to sort data in making batch
sort_batch: descending # how to sort created batches
num_workers: 1 # number of workers of data loader
train_dtype: float32 # dtype in training
log_interval: null # log interval in iterations
keep_nbest_models: 5 # number of models to keep
num_att_plot: 3 # number of attention figures to be saved in every check
seed: 0 # random seed number
# use_amp: True
best_model_criterion: # criterion to save the best models
- - valid
- loss
- min
- - train
- loss
- min
2 changes: 2 additions & 0 deletions espnet2/tasks/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from espnet2.tts.feats_extract.linear_spectrogram import LinearSpectrogram
from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank
from espnet2.tts.feats_extract.log_spectrogram import LogSpectrogram
from espnet2.tts.prodiff import ProDiff
from espnet2.tts.tacotron2 import Tacotron2
from espnet2.tts.transformer import Transformer
from espnet2.tts.utils import ParallelWaveGANPretrainedVocoder
Expand Down Expand Up @@ -91,6 +92,7 @@
transformer=Transformer,
fastspeech=FastSpeech,
fastspeech2=FastSpeech2,
prodiff=ProDiff,
# NOTE(kan-bayashi): available only for inference
vits=VITS,
joint_text2wav=JointText2Wav,
Expand Down
1 change: 1 addition & 0 deletions espnet2/tts/prodiff/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from espnet2.tts.prodiff.prodiff import ProDiff # NOQA

0 comments on commit 3970558

Please sign in to comment.