From 9a03026b75f27c078f775e4b1eb73ebbd63924c9 Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Fri, 3 Oct 2025 22:40:33 +0800 Subject: [PATCH 1/9] Acoustic SR_embed (#270) * Acoustic SR_embed / Cosine annealing * del 'WarmupCosineSchedule' in config del 'WarmupCosineSchedule' in config * Fix the precision problem of 'StretchRegulator' in ONNX model * fix some odds and ends... * set 'use_stretch_embed' true on default * fix some odds and ends... --- configs/templates/config_acoustic.yaml | 2 ++ deployment/modules/fastspeech2.py | 21 ++++++++++------- deployment/modules/toplevel.py | 2 +- modules/commons/common_layers.py | 2 +- modules/fastspeech/acoustic_encoder.py | 32 ++++++++++++++++++++------ modules/fastspeech/tts_modules.py | 2 +- modules/toplevel.py | 2 +- utils/training_utils.py | 6 +++-- 8 files changed, 48 insertions(+), 21 deletions(-) diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 669c3a6da..c25822745 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -1,6 +1,8 @@ base_config: - configs/acoustic.yaml +use_stretch_embed: true + dictionaries: zh: dictionaries/opencpop-extension.txt extra_phonemes: [] diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index bb9d6a7bb..0e64d417a 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -73,6 +73,7 @@ def forward( txt_embed = self.txt_embed(tokens) durations = durations * (tokens > 0) mel2ph = self.lr(durations) + _mel2ph = mel2ph f0 = f0 * (mel2ph > 0) mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size'])) if self.use_variance_scaling: @@ -92,6 +93,13 @@ def forward( encoded = F.pad(encoded, (0, 0, 1, 0)) condition = torch.gather(encoded, 1, mel2ph) + if self.use_stretch_embed: + stretch = self.sr(_mel2ph, durations) + stretch_embed = self.stretch_embed(stretch * 1000) + condition += stretch_embed + stretch_embed_rnn_out, _ =self.stretch_embed_rnn(condition) + condition += stretch_embed_rnn_out + if self.f0_embed_type == 'discrete': pitch = f0_to_coarse(f0) pitch_embed = self.pitch_embed(pitch) @@ -102,30 +110,27 @@ def forward( if self.use_variance_embeds: variance_embeds = torch.stack([ - self.variance_embeds[v_name](variances[v_name][:, :, None]) - * self.variance_scaling_factor[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name]) for v_name in self.variance_embed_list ], dim=-1).sum(-1) condition += variance_embeds if hparams['use_key_shift_embed']: if hasattr(self, 'frozen_key_shift'): - key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None]) + key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None] * self.variance_scaling_factor['key_shift']) else: gender = torch.clip(gender, min=-1., max=1.) gender_mask = (gender < 0.).float() key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min)) - key_shift_embed = self.key_shift_embed(key_shift[:, :, None]) - key_shift_embed *= self.variance_scaling_factor['key_shift'] + key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift']) condition += key_shift_embed if hparams['use_speed_embed']: if velocity is not None: velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max) - speed_embed = self.speed_embed(velocity[:, :, None]) + speed_embed = self.speed_embed(velocity[:, :, None] * self.variance_scaling_factor['speed']) else: - speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None]) - speed_embed *= self.variance_scaling_factor['speed'] + speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None] * self.variance_scaling_factor['speed']) condition += speed_embed if hparams['use_spk_id']: diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 0d1752ef3..4b0429d42 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -290,7 +290,7 @@ def forward_variance_preprocess( for v_retake in (~retake).split(1, dim=2) ] variance_embeds = [ - self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks * self.variance_retake_scaling[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_retake_scaling[v_name]) * v_masks for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks) ] variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index f0e359e52..b6ebca235 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -325,6 +325,6 @@ def forward(self, x): half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] + emb = x.unsqueeze(-1) * emb.unsqueeze(0) emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 90da9cf80..f345eab58 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -5,8 +5,9 @@ from modules.commons.common_layers import ( NormalInitEmbedding as Embedding, XavierUniformInitLinear as Linear, + SinusoidalPosEmb, ) -from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur +from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur, StretchRegulator from utils.hparams import hparams from utils.phoneme_utils import PAD_INDEX @@ -18,6 +19,18 @@ def __init__(self, vocab_size): self.use_lang_id = hparams.get('use_lang_id', False) if self.use_lang_id: self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0) + + self.use_stretch_embed = hparams.get('use_stretch_embed', False) + if self.use_stretch_embed: + self.sr = StretchRegulator() + self.stretch_embed = nn.Sequential( + SinusoidalPosEmb(hparams['hidden_size']), + nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4), + nn.GELU(), + nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']), + ) + self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) + self.dur_embed = Linear(1, hparams['hidden_size']) self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], @@ -84,20 +97,17 @@ def __init__(self, vocab_size): def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances): if self.use_variance_embeds: variance_embeds = torch.stack([ - self.variance_embeds[v_name](variances[v_name][:, :, None]) - * self.variance_scaling_factor[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name]) for v_name in self.variance_embed_list ], dim=-1).sum(-1) condition += variance_embeds if self.use_key_shift_embed: - key_shift_embed = self.key_shift_embed(key_shift[:, :, None]) - key_shift_embed *= self.variance_scaling_factor['key_shift'] + key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift']) condition += key_shift_embed if self.use_speed_embed: - speed_embed = self.speed_embed(speed[:, :, None]) - speed_embed *= self.variance_scaling_factor['speed'] + speed_embed = self.speed_embed(speed[:, :, None] * self.variance_scaling_factor['speed']) condition += speed_embed return condition @@ -125,6 +135,14 @@ def forward( mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) condition = torch.gather(encoder_out, 1, mel2ph_) + if self.use_stretch_embed: + stretch = self.sr(mel2ph, dur) + stretch_embed = self.stretch_embed(stretch * 1000) + condition += stretch_embed + self.stretch_embed_rnn.flatten_parameters() + stretch_embed_rnn_out, _ =self.stretch_embed_rnn(condition) + condition += stretch_embed_rnn_out + if self.use_spk_id: spk_mix_embed = kwargs.get('spk_mix_embed') if spk_mix_embed is not None: diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index 16b358e3c..7ac3d6744 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -354,7 +354,7 @@ def forward(self, mel2ph, dur=None): stretch_delta = 1 - bound_mask * mel2dur stretch_delta = F.pad(stretch_delta, [1, -1], mode='constant', value=0) stretch_denorm = torch.cumsum(stretch_delta, dim=1) - stretch = stretch_denorm / mel2dur + stretch = stretch_denorm.float() / mel2dur return stretch * (mel2ph > 0) diff --git a/modules/toplevel.py b/modules/toplevel.py index 1ad5aa4b7..47c4a7ff5 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -326,7 +326,7 @@ def forward( if variance_retake is not None: variance_embeds = [ - self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name] + self.variance_embeds[v_name](v_input[:, :, None] * self.variance_retake_scaling[v_name]) * ~variance_retake[v_name][:, :, None] for v_name, v_input in zip(self.variance_prediction_list, variance_inputs) ] var_cond += torch.stack(variance_embeds, dim=-1).sum(-1) diff --git a/utils/training_utils.py b/utils/training_utils.py index 26d24eec5..e906f7721 100644 --- a/utils/training_utils.py +++ b/utils/training_utils.py @@ -54,16 +54,18 @@ class WarmupCosineSchedule(LambdaLR): `eta_min` (default=0.0) corresponds to the minimum learning rate reached by the scheduler. """ - def __init__(self, optimizer, warmup_steps, t_total, eta_min=0.0, cycles=.5, last_epoch=-1): + def __init__(self, optimizer, warmup_steps, t_total, warmup_min=0.0, eta_min=0.0, cycles=.5, last_epoch=-1): self.warmup_steps = warmup_steps self.t_total = t_total self.eta_min = eta_min self.cycles = cycles + self.warmup_min = warmup_min super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) def lr_lambda(self, step): if step < self.warmup_steps: - return step / max(1.0, self.warmup_steps) + progress = step / max(1.0, self.warmup_steps) + return self.warmup_min + progress * (1.0 - self.warmup_min) # progress after warmup progress = (step - self.warmup_steps) / max(1, self.t_total - self.warmup_steps) return max(self.eta_min, 0.5 * (1. + math.cos(math.pi * self.cycles * 2.0 * progress))) From 8ceb4180f3e8b86efd12ae4b43ce6887e0e3bddd Mon Sep 17 00:00:00 2001 From: yxlllc Date: Fri, 3 Oct 2025 22:48:41 +0800 Subject: [PATCH 2/9] adjust --- configs/acoustic.yaml | 1 + configs/templates/config_acoustic.yaml | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 0d7250d00..9bc978f26 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -64,6 +64,7 @@ timesteps: 1000 max_beta: 0.02 enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: true use_variance_scaling: true rel_pos: true sampling_algorithm: euler diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index c25822745..acbe25dff 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -1,8 +1,6 @@ base_config: - configs/acoustic.yaml -use_stretch_embed: true - dictionaries: zh: dictionaries/opencpop-extension.txt extra_phonemes: [] @@ -73,6 +71,7 @@ augmentation_args: diffusion_type: reflow enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: true use_variance_scaling: true use_shallow_diffusion: true T_start: 0.4 From 5d1c0ef9511502745a6cf2a6f9e6f9c5e0ff0eba Mon Sep 17 00:00:00 2001 From: yxlllc Date: Mon, 6 Oct 2025 15:19:43 +0800 Subject: [PATCH 3/9] add stretch embed to variance models --- configs/templates/config_variance.yaml | 1 + configs/variance.yaml | 1 + deployment/modules/toplevel.py | 13 ++++++++++--- modules/toplevel.py | 24 ++++++++++++++++++++++-- 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 3947928f2..c1af03e47 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -65,6 +65,7 @@ tension_logit_max: 10.0 enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: true use_variance_scaling: true hidden_size: 384 dur_prediction_args: diff --git a/configs/variance.yaml b/configs/variance.yaml index 5a4799203..27cfc146b 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -36,6 +36,7 @@ predict_tension: false enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: true use_variance_scaling: true rel_pos: true hidden_size: 384 diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 4b0429d42..b21d2f190 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -211,14 +211,21 @@ def forward_linguistic_encoder_phoneme(self, tokens, ph_dur, languages=None): def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None): return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed) - def forward_mel2x_gather(self, x_src, x_dur, x_dim=None): + def forward_mel2x_gather(self, x_src, x_dur, x_dim=None, check_stretch_embed=False): mel2x = self.lr(x_dur) + _mel2x = mel2x if x_dim is not None: x_src = F.pad(x_src, [0, 0, 1, 0]) mel2x = mel2x[..., None].repeat([1, 1, x_dim]) else: x_src = F.pad(x_src, [1, 0]) x_cond = torch.gather(x_src, 1, mel2x) + if self.use_stretch_embed and check_stretch_embed: + stretch = self.sr(_mel2x, x_dur) + stretch_embed = self.stretch_embed(stretch * 1000) + x_cond += stretch_embed + stretch_embed_rnn_out, _ =self.stretch_embed_rnn(x_cond) + x_cond += stretch_embed_rnn_out return x_cond def forward_pitch_preprocess( @@ -226,7 +233,7 @@ def forward_pitch_preprocess( note_midi=None, note_rest=None, note_dur=None, note_glide=None, pitch=None, expr=None, retake=None, spk_embed=None ): - condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) + condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True) if self.use_melody_encoder: if self.melody_encoder.use_glide_embed and note_glide is None: note_glide = torch.LongTensor([[0]]).to(encoder_out.device) @@ -280,7 +287,7 @@ def forward_variance_preprocess( self, encoder_out, ph_dur, pitch, variances: dict = None, retake=None, spk_embed=None ): - condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) + condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True) if self.use_variance_scaling: variance_cond = condition + self.pitch_embed(pitch[:, :, None] / 12) else: diff --git a/modules/toplevel.py b/modules/toplevel.py index 47c4a7ff5..a82b265be 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -10,7 +10,8 @@ from modules.aux_decoder import AuxDecoderAdaptor from modules.commons.common_layers import ( XavierUniformInitLinear as Linear, - NormalInitEmbedding as Embedding + NormalInitEmbedding as Embedding, + SinusoidalPosEmb ) from modules.core import ( GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion, @@ -18,7 +19,7 @@ ) from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic from modules.fastspeech.param_adaptor import ParameterAdaptorModule -from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator +from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator, StretchRegulator from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder from utils.hparams import hparams @@ -133,6 +134,17 @@ def __init__(self, vocab_size): self.predict_dur = hparams['predict_dur'] self.predict_pitch = hparams['predict_pitch'] + self.use_stretch_embed = hparams.get('use_stretch_embed', False) + if self.use_stretch_embed and (self.predict_pitch or self.predict_variances): + self.sr = StretchRegulator() + self.stretch_embed = nn.Sequential( + SinusoidalPosEmb(hparams['hidden_size']), + nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4), + nn.GELU(), + nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']), + ) + self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) + self.use_spk_id = hparams['use_spk_id'] if self.use_spk_id: self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size']) @@ -255,6 +267,14 @@ def forward( mel2ph_ = mel2ph[..., None].repeat([1, 1, hparams['hidden_size']]) condition = torch.gather(encoder_out, 1, mel2ph_) + if self.use_stretch_embed: + stretch = self.sr(mel2ph, ph_dur) + stretch_embed = self.stretch_embed(stretch * 1000) + condition += stretch_embed + self.stretch_embed_rnn.flatten_parameters() + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) + condition = condition + stretch_embed_rnn_out + if self.use_spk_id: condition += spk_embed From 55e16c450322fea47174691b71b742b4af18802a Mon Sep 17 00:00:00 2001 From: yxlllc Date: Mon, 6 Oct 2025 22:55:44 +0800 Subject: [PATCH 4/9] fix --- modules/fastspeech/acoustic_encoder.py | 3 ++- modules/toplevel.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index f345eab58..6f954b4b6 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -20,7 +20,8 @@ def __init__(self, vocab_size): if self.use_lang_id: self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0) - self.use_stretch_embed = hparams.get('use_stretch_embed', False) + self.use_stretch_embed = hparams.get('use_stretch_embed', None) + assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)" if self.use_stretch_embed: self.sr = StretchRegulator() self.stretch_embed = nn.Sequential( diff --git a/modules/toplevel.py b/modules/toplevel.py index a82b265be..d0935c189 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -134,7 +134,8 @@ def __init__(self, vocab_size): self.predict_dur = hparams['predict_dur'] self.predict_pitch = hparams['predict_pitch'] - self.use_stretch_embed = hparams.get('use_stretch_embed', False) + self.use_stretch_embed = hparams.get('use_stretch_embed', None) + assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)" if self.use_stretch_embed and (self.predict_pitch or self.predict_variances): self.sr = StretchRegulator() self.stretch_embed = nn.Sequential( From 16492d90f89592f6b0c1abccef02960e74248be7 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Tue, 7 Oct 2025 16:31:48 +0800 Subject: [PATCH 5/9] fix --- modules/fastspeech/tts_modules.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index 7ac3d6744..a2cf0612c 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -347,12 +347,11 @@ def forward(self, mel2ph, dur=None): """ if dur is None: dur = mel2ph_to_dur(mel2ph, mel2ph.max()) - dur = F.pad(dur, [1, 0], value=1) # Avoid dividing by zero + dur = F.pad(dur, [1, 0], mode='replicate') # Avoid dividing by zero mel2dur = torch.gather(dur, 1, mel2ph) bound_mask = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1]) - bound_mask = F.pad(bound_mask, [0, 1], mode='constant', value=True) - stretch_delta = 1 - bound_mask * mel2dur - stretch_delta = F.pad(stretch_delta, [1, -1], mode='constant', value=0) + stretch_delta = 1 - bound_mask * mel2dur[:, :-1] + stretch_delta = F.pad(stretch_delta, [1, 0]) stretch_denorm = torch.cumsum(stretch_delta, dim=1) stretch = stretch_denorm.float() / mel2dur return stretch * (mel2ph > 0) From fe1bbca1345d8866ce5e71ef68cace156ed20e62 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Tue, 7 Oct 2025 16:44:14 +0800 Subject: [PATCH 6/9] fix --- modules/fastspeech/tts_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index a2cf0612c..a18fba5bd 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -347,7 +347,7 @@ def forward(self, mel2ph, dur=None): """ if dur is None: dur = mel2ph_to_dur(mel2ph, mel2ph.max()) - dur = F.pad(dur, [1, 0], mode='replicate') # Avoid dividing by zero + dur = torch.cat([dur[:, :1] + 1, dur], dim=1) # Avoid dividing by zero mel2dur = torch.gather(dur, 1, mel2ph) bound_mask = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1]) stretch_delta = 1 - bound_mask * mel2dur[:, :-1] From 7fb8139f142dffbaee23717f381392e37705b6e3 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Tue, 7 Oct 2025 17:02:54 +0800 Subject: [PATCH 7/9] optimize --- modules/fastspeech/tts_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index a18fba5bd..882ebc115 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -347,7 +347,7 @@ def forward(self, mel2ph, dur=None): """ if dur is None: dur = mel2ph_to_dur(mel2ph, mel2ph.max()) - dur = torch.cat([dur[:, :1] + 1, dur], dim=1) # Avoid dividing by zero + dur = torch.cat([torch.ones_like(dur[:, :1]), dur], dim=1) # Avoid dividing by zero mel2dur = torch.gather(dur, 1, mel2ph) bound_mask = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1]) stretch_delta = 1 - bound_mask * mel2dur[:, :-1] From 3dba5b5295864d07f4a0dd2f13967f452307227c Mon Sep 17 00:00:00 2001 From: yxlllc Date: Thu, 9 Oct 2025 00:24:18 +0800 Subject: [PATCH 8/9] using lookup table for optimization --- deployment/modules/fastspeech2.py | 7 ++++--- deployment/modules/toplevel.py | 7 ++++--- modules/fastspeech/acoustic_encoder.py | 19 ++++++++++++------- modules/toplevel.py | 9 +++++++-- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index 0e64d417a..b22590c2a 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -94,10 +94,11 @@ def forward( condition = torch.gather(encoded, 1, mel2ph) if self.use_stretch_embed: - stretch = self.sr(_mel2ph, durations) - stretch_embed = self.stretch_embed(stretch * 1000) + stretch = torch.round(1000 * self.sr(_mel2ph, durations)) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) condition += stretch_embed - stretch_embed_rnn_out, _ =self.stretch_embed_rnn(condition) + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) condition += stretch_embed_rnn_out if self.f0_embed_type == 'discrete': diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index b21d2f190..3043168ca 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -221,10 +221,11 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None, check_stretch_embed=Fal x_src = F.pad(x_src, [1, 0]) x_cond = torch.gather(x_src, 1, mel2x) if self.use_stretch_embed and check_stretch_embed: - stretch = self.sr(_mel2x, x_dur) - stretch_embed = self.stretch_embed(stretch * 1000) + stretch = torch.round(1000 * self.sr(_mel2x, x_dur)) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(x_cond) x_cond += stretch_embed - stretch_embed_rnn_out, _ =self.stretch_embed_rnn(x_cond) + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(x_cond) x_cond += stretch_embed_rnn_out return x_cond diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 6f954b4b6..86aa535a3 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -120,11 +120,11 @@ def forward( **kwargs ): txt_embed = self.txt_embed(txt_tokens) - dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float() + dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]) if self.use_variance_scaling: - dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None])) + dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None].float())) else: - dur_embed = self.dur_embed(dur[:, :, None]) + dur_embed = self.dur_embed(dur[:, :, None].float()) if self.use_lang_id: lang_embed = self.lang_embed(languages) extra_embed = dur_embed + lang_embed @@ -137,12 +137,17 @@ def forward( condition = torch.gather(encoder_out, 1, mel2ph_) if self.use_stretch_embed: - stretch = self.sr(mel2ph, dur) - stretch_embed = self.stretch_embed(stretch * 1000) + stretch = torch.round(1000 * self.sr(mel2ph, dur)) + if self.training and stretch.numel() > 1000: + # construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) + else: + stretch_embed = self.stretch_embed(stretch) condition += stretch_embed self.stretch_embed_rnn.flatten_parameters() - stretch_embed_rnn_out, _ =self.stretch_embed_rnn(condition) - condition += stretch_embed_rnn_out + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) + condition = condition + stretch_embed_rnn_out if self.use_spk_id: spk_mix_embed = kwargs.get('spk_mix_embed') diff --git a/modules/toplevel.py b/modules/toplevel.py index d0935c189..3c3129665 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -269,8 +269,13 @@ def forward( condition = torch.gather(encoder_out, 1, mel2ph_) if self.use_stretch_embed: - stretch = self.sr(mel2ph, ph_dur) - stretch_embed = self.stretch_embed(stretch * 1000) + stretch = torch.round(1000 * self.sr(mel2ph, ph_dur)) + if self.training and stretch.numel() > 1000: + # construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) + else: + stretch_embed = self.stretch_embed(stretch) condition += stretch_embed self.stretch_embed_rnn.flatten_parameters() stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) From 626957c9dc75eb5e567fe12b90654b2e14867e4e Mon Sep 17 00:00:00 2001 From: yxlllc Date: Tue, 11 Nov 2025 21:26:14 +0800 Subject: [PATCH 9/9] update --- configs/templates/config_variance.yaml | 2 +- configs/variance.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index c1af03e47..58a4d3a6d 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -65,7 +65,7 @@ tension_logit_max: 10.0 enc_ffn_kernel_size: 3 use_rope: true -use_stretch_embed: true +use_stretch_embed: false use_variance_scaling: true hidden_size: 384 dur_prediction_args: diff --git a/configs/variance.yaml b/configs/variance.yaml index 27cfc146b..6bd86cfad 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -36,7 +36,7 @@ predict_tension: false enc_ffn_kernel_size: 3 use_rope: true -use_stretch_embed: true +use_stretch_embed: false use_variance_scaling: true rel_pos: true hidden_size: 384