diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 0d7250d0..9bc978f2 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -64,6 +64,7 @@ timesteps: 1000 max_beta: 0.02 enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: true use_variance_scaling: true rel_pos: true sampling_algorithm: euler diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 669c3a6d..acbe25df 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -71,6 +71,7 @@ augmentation_args: diffusion_type: reflow enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: true use_variance_scaling: true use_shallow_diffusion: true T_start: 0.4 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 3947928f..58a4d3a6 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -65,6 +65,7 @@ tension_logit_max: 10.0 enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: false use_variance_scaling: true hidden_size: 384 dur_prediction_args: diff --git a/configs/variance.yaml b/configs/variance.yaml index 5a479920..6bd86cfa 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -36,6 +36,7 @@ predict_tension: false enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: false use_variance_scaling: true rel_pos: true hidden_size: 384 diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index bb9d6a7b..b22590c2 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -73,6 +73,7 @@ def forward( txt_embed = self.txt_embed(tokens) durations = durations * (tokens > 0) mel2ph = self.lr(durations) + _mel2ph = mel2ph f0 = f0 * (mel2ph > 0) mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size'])) if self.use_variance_scaling: @@ -92,6 +93,14 @@ def forward( encoded = F.pad(encoded, (0, 0, 1, 0)) condition = torch.gather(encoded, 1, mel2ph) + if self.use_stretch_embed: + stretch = torch.round(1000 * self.sr(_mel2ph, durations)) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) + condition += stretch_embed + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) + condition += stretch_embed_rnn_out + if self.f0_embed_type == 'discrete': pitch = f0_to_coarse(f0) pitch_embed = self.pitch_embed(pitch) @@ -102,30 +111,27 @@ def forward( if self.use_variance_embeds: variance_embeds = torch.stack([ - self.variance_embeds[v_name](variances[v_name][:, :, None]) - * self.variance_scaling_factor[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name]) for v_name in self.variance_embed_list ], dim=-1).sum(-1) condition += variance_embeds if hparams['use_key_shift_embed']: if hasattr(self, 'frozen_key_shift'): - key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None]) + key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None] * self.variance_scaling_factor['key_shift']) else: gender = torch.clip(gender, min=-1., max=1.) gender_mask = (gender < 0.).float() key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min)) - key_shift_embed = self.key_shift_embed(key_shift[:, :, None]) - key_shift_embed *= self.variance_scaling_factor['key_shift'] + key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift']) condition += key_shift_embed if hparams['use_speed_embed']: if velocity is not None: velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max) - speed_embed = self.speed_embed(velocity[:, :, None]) + speed_embed = self.speed_embed(velocity[:, :, None] * self.variance_scaling_factor['speed']) else: - speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None]) - speed_embed *= self.variance_scaling_factor['speed'] + speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None] * self.variance_scaling_factor['speed']) condition += speed_embed if hparams['use_spk_id']: diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 0d1752ef..3043168c 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -211,14 +211,22 @@ def forward_linguistic_encoder_phoneme(self, tokens, ph_dur, languages=None): def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None): return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed) - def forward_mel2x_gather(self, x_src, x_dur, x_dim=None): + def forward_mel2x_gather(self, x_src, x_dur, x_dim=None, check_stretch_embed=False): mel2x = self.lr(x_dur) + _mel2x = mel2x if x_dim is not None: x_src = F.pad(x_src, [0, 0, 1, 0]) mel2x = mel2x[..., None].repeat([1, 1, x_dim]) else: x_src = F.pad(x_src, [1, 0]) x_cond = torch.gather(x_src, 1, mel2x) + if self.use_stretch_embed and check_stretch_embed: + stretch = torch.round(1000 * self.sr(_mel2x, x_dur)) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(x_cond) + x_cond += stretch_embed + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(x_cond) + x_cond += stretch_embed_rnn_out return x_cond def forward_pitch_preprocess( @@ -226,7 +234,7 @@ def forward_pitch_preprocess( note_midi=None, note_rest=None, note_dur=None, note_glide=None, pitch=None, expr=None, retake=None, spk_embed=None ): - condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) + condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True) if self.use_melody_encoder: if self.melody_encoder.use_glide_embed and note_glide is None: note_glide = torch.LongTensor([[0]]).to(encoder_out.device) @@ -280,7 +288,7 @@ def forward_variance_preprocess( self, encoder_out, ph_dur, pitch, variances: dict = None, retake=None, spk_embed=None ): - condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) + condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True) if self.use_variance_scaling: variance_cond = condition + self.pitch_embed(pitch[:, :, None] / 12) else: @@ -290,7 +298,7 @@ def forward_variance_preprocess( for v_retake in (~retake).split(1, dim=2) ] variance_embeds = [ - self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks * self.variance_retake_scaling[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_retake_scaling[v_name]) * v_masks for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks) ] variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index f0e359e5..b6ebca23 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -325,6 +325,6 @@ def forward(self, x): half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] + emb = x.unsqueeze(-1) * emb.unsqueeze(0) emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 90da9cf8..86aa535a 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -5,8 +5,9 @@ from modules.commons.common_layers import ( NormalInitEmbedding as Embedding, XavierUniformInitLinear as Linear, + SinusoidalPosEmb, ) -from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur +from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur, StretchRegulator from utils.hparams import hparams from utils.phoneme_utils import PAD_INDEX @@ -18,6 +19,19 @@ def __init__(self, vocab_size): self.use_lang_id = hparams.get('use_lang_id', False) if self.use_lang_id: self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0) + + self.use_stretch_embed = hparams.get('use_stretch_embed', None) + assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)" + if self.use_stretch_embed: + self.sr = StretchRegulator() + self.stretch_embed = nn.Sequential( + SinusoidalPosEmb(hparams['hidden_size']), + nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4), + nn.GELU(), + nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']), + ) + self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) + self.dur_embed = Linear(1, hparams['hidden_size']) self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], @@ -84,20 +98,17 @@ def __init__(self, vocab_size): def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances): if self.use_variance_embeds: variance_embeds = torch.stack([ - self.variance_embeds[v_name](variances[v_name][:, :, None]) - * self.variance_scaling_factor[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name]) for v_name in self.variance_embed_list ], dim=-1).sum(-1) condition += variance_embeds if self.use_key_shift_embed: - key_shift_embed = self.key_shift_embed(key_shift[:, :, None]) - key_shift_embed *= self.variance_scaling_factor['key_shift'] + key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift']) condition += key_shift_embed if self.use_speed_embed: - speed_embed = self.speed_embed(speed[:, :, None]) - speed_embed *= self.variance_scaling_factor['speed'] + speed_embed = self.speed_embed(speed[:, :, None] * self.variance_scaling_factor['speed']) condition += speed_embed return condition @@ -109,11 +120,11 @@ def forward( **kwargs ): txt_embed = self.txt_embed(txt_tokens) - dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float() + dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]) if self.use_variance_scaling: - dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None])) + dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None].float())) else: - dur_embed = self.dur_embed(dur[:, :, None]) + dur_embed = self.dur_embed(dur[:, :, None].float()) if self.use_lang_id: lang_embed = self.lang_embed(languages) extra_embed = dur_embed + lang_embed @@ -125,6 +136,19 @@ def forward( mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) condition = torch.gather(encoder_out, 1, mel2ph_) + if self.use_stretch_embed: + stretch = torch.round(1000 * self.sr(mel2ph, dur)) + if self.training and stretch.numel() > 1000: + # construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) + else: + stretch_embed = self.stretch_embed(stretch) + condition += stretch_embed + self.stretch_embed_rnn.flatten_parameters() + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) + condition = condition + stretch_embed_rnn_out + if self.use_spk_id: spk_mix_embed = kwargs.get('spk_mix_embed') if spk_mix_embed is not None: diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index 16b358e3..882ebc11 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -347,14 +347,13 @@ def forward(self, mel2ph, dur=None): """ if dur is None: dur = mel2ph_to_dur(mel2ph, mel2ph.max()) - dur = F.pad(dur, [1, 0], value=1) # Avoid dividing by zero + dur = torch.cat([torch.ones_like(dur[:, :1]), dur], dim=1) # Avoid dividing by zero mel2dur = torch.gather(dur, 1, mel2ph) bound_mask = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1]) - bound_mask = F.pad(bound_mask, [0, 1], mode='constant', value=True) - stretch_delta = 1 - bound_mask * mel2dur - stretch_delta = F.pad(stretch_delta, [1, -1], mode='constant', value=0) + stretch_delta = 1 - bound_mask * mel2dur[:, :-1] + stretch_delta = F.pad(stretch_delta, [1, 0]) stretch_denorm = torch.cumsum(stretch_delta, dim=1) - stretch = stretch_denorm / mel2dur + stretch = stretch_denorm.float() / mel2dur return stretch * (mel2ph > 0) diff --git a/modules/toplevel.py b/modules/toplevel.py index 1ad5aa4b..3c312966 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -10,7 +10,8 @@ from modules.aux_decoder import AuxDecoderAdaptor from modules.commons.common_layers import ( XavierUniformInitLinear as Linear, - NormalInitEmbedding as Embedding + NormalInitEmbedding as Embedding, + SinusoidalPosEmb ) from modules.core import ( GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion, @@ -18,7 +19,7 @@ ) from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic from modules.fastspeech.param_adaptor import ParameterAdaptorModule -from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator +from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator, StretchRegulator from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder from utils.hparams import hparams @@ -133,6 +134,18 @@ def __init__(self, vocab_size): self.predict_dur = hparams['predict_dur'] self.predict_pitch = hparams['predict_pitch'] + self.use_stretch_embed = hparams.get('use_stretch_embed', None) + assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)" + if self.use_stretch_embed and (self.predict_pitch or self.predict_variances): + self.sr = StretchRegulator() + self.stretch_embed = nn.Sequential( + SinusoidalPosEmb(hparams['hidden_size']), + nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4), + nn.GELU(), + nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']), + ) + self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) + self.use_spk_id = hparams['use_spk_id'] if self.use_spk_id: self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size']) @@ -255,6 +268,19 @@ def forward( mel2ph_ = mel2ph[..., None].repeat([1, 1, hparams['hidden_size']]) condition = torch.gather(encoder_out, 1, mel2ph_) + if self.use_stretch_embed: + stretch = torch.round(1000 * self.sr(mel2ph, ph_dur)) + if self.training and stretch.numel() > 1000: + # construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) + else: + stretch_embed = self.stretch_embed(stretch) + condition += stretch_embed + self.stretch_embed_rnn.flatten_parameters() + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) + condition = condition + stretch_embed_rnn_out + if self.use_spk_id: condition += spk_embed @@ -326,7 +352,7 @@ def forward( if variance_retake is not None: variance_embeds = [ - self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name] + self.variance_embeds[v_name](v_input[:, :, None] * self.variance_retake_scaling[v_name]) * ~variance_retake[v_name][:, :, None] for v_name, v_input in zip(self.variance_prediction_list, variance_inputs) ] var_cond += torch.stack(variance_embeds, dim=-1).sum(-1) diff --git a/utils/training_utils.py b/utils/training_utils.py index 26d24eec..e906f772 100644 --- a/utils/training_utils.py +++ b/utils/training_utils.py @@ -54,16 +54,18 @@ class WarmupCosineSchedule(LambdaLR): `eta_min` (default=0.0) corresponds to the minimum learning rate reached by the scheduler. """ - def __init__(self, optimizer, warmup_steps, t_total, eta_min=0.0, cycles=.5, last_epoch=-1): + def __init__(self, optimizer, warmup_steps, t_total, warmup_min=0.0, eta_min=0.0, cycles=.5, last_epoch=-1): self.warmup_steps = warmup_steps self.t_total = t_total self.eta_min = eta_min self.cycles = cycles + self.warmup_min = warmup_min super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) def lr_lambda(self, step): if step < self.warmup_steps: - return step / max(1.0, self.warmup_steps) + progress = step / max(1.0, self.warmup_steps) + return self.warmup_min + progress * (1.0 - self.warmup_min) # progress after warmup progress = (step - self.warmup_steps) / max(1, self.t_total - self.warmup_steps) return max(self.eta_min, 0.5 * (1. + math.cos(math.pi * self.cycles * 2.0 * progress)))