Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ timesteps: 1000
max_beta: 0.02
enc_ffn_kernel_size: 3
use_rope: true
use_stretch_embed: true
use_variance_scaling: true
rel_pos: true
sampling_algorithm: euler
Expand Down
1 change: 1 addition & 0 deletions configs/templates/config_acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ augmentation_args:
diffusion_type: reflow
enc_ffn_kernel_size: 3
use_rope: true
use_stretch_embed: true
use_variance_scaling: true
use_shallow_diffusion: true
T_start: 0.4
Expand Down
1 change: 1 addition & 0 deletions configs/templates/config_variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ tension_logit_max: 10.0

enc_ffn_kernel_size: 3
use_rope: true
use_stretch_embed: false
use_variance_scaling: true
hidden_size: 384
dur_prediction_args:
Expand Down
1 change: 1 addition & 0 deletions configs/variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ predict_tension: false

enc_ffn_kernel_size: 3
use_rope: true
use_stretch_embed: false
use_variance_scaling: true
rel_pos: true
hidden_size: 384
Expand Down
22 changes: 14 additions & 8 deletions deployment/modules/fastspeech2.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def forward(
txt_embed = self.txt_embed(tokens)
durations = durations * (tokens > 0)
mel2ph = self.lr(durations)
_mel2ph = mel2ph
f0 = f0 * (mel2ph > 0)
mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size']))
if self.use_variance_scaling:
Expand All @@ -92,6 +93,14 @@ def forward(
encoded = F.pad(encoded, (0, 0, 1, 0))
condition = torch.gather(encoded, 1, mel2ph)

if self.use_stretch_embed:
stretch = torch.round(1000 * self.sr(_mel2ph, durations))
table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device))
stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition)
condition += stretch_embed
stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition)
condition += stretch_embed_rnn_out

if self.f0_embed_type == 'discrete':
pitch = f0_to_coarse(f0)
pitch_embed = self.pitch_embed(pitch)
Expand All @@ -102,30 +111,27 @@ def forward(

if self.use_variance_embeds:
variance_embeds = torch.stack([
self.variance_embeds[v_name](variances[v_name][:, :, None])
* self.variance_scaling_factor[v_name]
self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name])
for v_name in self.variance_embed_list
], dim=-1).sum(-1)
condition += variance_embeds

if hparams['use_key_shift_embed']:
if hasattr(self, 'frozen_key_shift'):
key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None])
key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None] * self.variance_scaling_factor['key_shift'])
else:
gender = torch.clip(gender, min=-1., max=1.)
gender_mask = (gender < 0.).float()
key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
key_shift_embed *= self.variance_scaling_factor['key_shift']
key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift'])
condition += key_shift_embed

if hparams['use_speed_embed']:
if velocity is not None:
velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max)
speed_embed = self.speed_embed(velocity[:, :, None])
speed_embed = self.speed_embed(velocity[:, :, None] * self.variance_scaling_factor['speed'])
else:
speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
speed_embed *= self.variance_scaling_factor['speed']
speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None] * self.variance_scaling_factor['speed'])
condition += speed_embed

if hparams['use_spk_id']:
Expand Down
16 changes: 12 additions & 4 deletions deployment/modules/toplevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,22 +211,30 @@ def forward_linguistic_encoder_phoneme(self, tokens, ph_dur, languages=None):
def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed)

def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
def forward_mel2x_gather(self, x_src, x_dur, x_dim=None, check_stretch_embed=False):
mel2x = self.lr(x_dur)
_mel2x = mel2x
if x_dim is not None:
x_src = F.pad(x_src, [0, 0, 1, 0])
mel2x = mel2x[..., None].repeat([1, 1, x_dim])
else:
x_src = F.pad(x_src, [1, 0])
x_cond = torch.gather(x_src, 1, mel2x)
if self.use_stretch_embed and check_stretch_embed:
stretch = torch.round(1000 * self.sr(_mel2x, x_dur))
table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device))
stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(x_cond)
x_cond += stretch_embed
stretch_embed_rnn_out, _ = self.stretch_embed_rnn(x_cond)
x_cond += stretch_embed_rnn_out
return x_cond

def forward_pitch_preprocess(
self, encoder_out, ph_dur,
note_midi=None, note_rest=None, note_dur=None, note_glide=None,
pitch=None, expr=None, retake=None, spk_embed=None
):
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True)
if self.use_melody_encoder:
if self.melody_encoder.use_glide_embed and note_glide is None:
note_glide = torch.LongTensor([[0]]).to(encoder_out.device)
Expand Down Expand Up @@ -280,7 +288,7 @@ def forward_variance_preprocess(
self, encoder_out, ph_dur, pitch,
variances: dict = None, retake=None, spk_embed=None
):
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True)
if self.use_variance_scaling:
variance_cond = condition + self.pitch_embed(pitch[:, :, None] / 12)
else:
Expand All @@ -290,7 +298,7 @@ def forward_variance_preprocess(
for v_retake in (~retake).split(1, dim=2)
]
variance_embeds = [
self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks * self.variance_retake_scaling[v_name]
self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_retake_scaling[v_name]) * v_masks
for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
]
variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
Expand Down
2 changes: 1 addition & 1 deletion modules/commons/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,6 @@ def forward(self, x):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
44 changes: 34 additions & 10 deletions modules/fastspeech/acoustic_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from modules.commons.common_layers import (
NormalInitEmbedding as Embedding,
XavierUniformInitLinear as Linear,
SinusoidalPosEmb,
)
from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur
from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur, StretchRegulator
from utils.hparams import hparams
from utils.phoneme_utils import PAD_INDEX

Expand All @@ -18,6 +19,19 @@ def __init__(self, vocab_size):
self.use_lang_id = hparams.get('use_lang_id', False)
if self.use_lang_id:
self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0)

self.use_stretch_embed = hparams.get('use_stretch_embed', None)
assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)"
if self.use_stretch_embed:
self.sr = StretchRegulator()
self.stretch_embed = nn.Sequential(
SinusoidalPosEmb(hparams['hidden_size']),
nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4),
nn.GELU(),
nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']),
)
self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True)

self.dur_embed = Linear(1, hparams['hidden_size'])
self.encoder = FastSpeech2Encoder(
hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'],
Expand Down Expand Up @@ -84,20 +98,17 @@ def __init__(self, vocab_size):
def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances):
if self.use_variance_embeds:
variance_embeds = torch.stack([
self.variance_embeds[v_name](variances[v_name][:, :, None])
* self.variance_scaling_factor[v_name]
self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name])
for v_name in self.variance_embed_list
], dim=-1).sum(-1)
condition += variance_embeds

if self.use_key_shift_embed:
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
key_shift_embed *= self.variance_scaling_factor['key_shift']
key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift'])
condition += key_shift_embed

if self.use_speed_embed:
speed_embed = self.speed_embed(speed[:, :, None])
speed_embed *= self.variance_scaling_factor['speed']
speed_embed = self.speed_embed(speed[:, :, None] * self.variance_scaling_factor['speed'])
condition += speed_embed

return condition
Expand All @@ -109,11 +120,11 @@ def forward(
**kwargs
):
txt_embed = self.txt_embed(txt_tokens)
dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float()
dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1])
if self.use_variance_scaling:
dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None]))
dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None].float()))
else:
dur_embed = self.dur_embed(dur[:, :, None])
dur_embed = self.dur_embed(dur[:, :, None].float())
if self.use_lang_id:
lang_embed = self.lang_embed(languages)
extra_embed = dur_embed + lang_embed
Expand All @@ -125,6 +136,19 @@ def forward(
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
condition = torch.gather(encoder_out, 1, mel2ph_)

if self.use_stretch_embed:
stretch = torch.round(1000 * self.sr(mel2ph, dur))
if self.training and stretch.numel() > 1000:
# construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000)
table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device))
stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition)
else:
stretch_embed = self.stretch_embed(stretch)
condition += stretch_embed
self.stretch_embed_rnn.flatten_parameters()
stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition)
condition = condition + stretch_embed_rnn_out

if self.use_spk_id:
spk_mix_embed = kwargs.get('spk_mix_embed')
if spk_mix_embed is not None:
Expand Down
9 changes: 4 additions & 5 deletions modules/fastspeech/tts_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,13 @@ def forward(self, mel2ph, dur=None):
"""
if dur is None:
dur = mel2ph_to_dur(mel2ph, mel2ph.max())
dur = F.pad(dur, [1, 0], value=1) # Avoid dividing by zero
dur = torch.cat([torch.ones_like(dur[:, :1]), dur], dim=1) # Avoid dividing by zero
mel2dur = torch.gather(dur, 1, mel2ph)
bound_mask = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1])
bound_mask = F.pad(bound_mask, [0, 1], mode='constant', value=True)
stretch_delta = 1 - bound_mask * mel2dur
stretch_delta = F.pad(stretch_delta, [1, -1], mode='constant', value=0)
stretch_delta = 1 - bound_mask * mel2dur[:, :-1]
stretch_delta = F.pad(stretch_delta, [1, 0])
stretch_denorm = torch.cumsum(stretch_delta, dim=1)
stretch = stretch_denorm / mel2dur
stretch = stretch_denorm.float() / mel2dur
return stretch * (mel2ph > 0)


Expand Down
32 changes: 29 additions & 3 deletions modules/toplevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
from modules.aux_decoder import AuxDecoderAdaptor
from modules.commons.common_layers import (
XavierUniformInitLinear as Linear,
NormalInitEmbedding as Embedding
NormalInitEmbedding as Embedding,
SinusoidalPosEmb
)
from modules.core import (
GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion,
RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
)
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
from modules.fastspeech.param_adaptor import ParameterAdaptorModule
from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator
from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator, StretchRegulator
from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder
from utils.hparams import hparams

Expand Down Expand Up @@ -133,6 +134,18 @@ def __init__(self, vocab_size):
self.predict_dur = hparams['predict_dur']
self.predict_pitch = hparams['predict_pitch']

self.use_stretch_embed = hparams.get('use_stretch_embed', None)
assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)"
if self.use_stretch_embed and (self.predict_pitch or self.predict_variances):
self.sr = StretchRegulator()
self.stretch_embed = nn.Sequential(
SinusoidalPosEmb(hparams['hidden_size']),
nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4),
nn.GELU(),
nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']),
)
self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True)

self.use_spk_id = hparams['use_spk_id']
if self.use_spk_id:
self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size'])
Expand Down Expand Up @@ -255,6 +268,19 @@ def forward(
mel2ph_ = mel2ph[..., None].repeat([1, 1, hparams['hidden_size']])
condition = torch.gather(encoder_out, 1, mel2ph_)

if self.use_stretch_embed:
stretch = torch.round(1000 * self.sr(mel2ph, ph_dur))
if self.training and stretch.numel() > 1000:
# construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000)
table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device))
stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition)
else:
stretch_embed = self.stretch_embed(stretch)
condition += stretch_embed
self.stretch_embed_rnn.flatten_parameters()
stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition)
condition = condition + stretch_embed_rnn_out

if self.use_spk_id:
condition += spk_embed

Expand Down Expand Up @@ -326,7 +352,7 @@ def forward(

if variance_retake is not None:
variance_embeds = [
self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name]
self.variance_embeds[v_name](v_input[:, :, None] * self.variance_retake_scaling[v_name]) * ~variance_retake[v_name][:, :, None]
for v_name, v_input in zip(self.variance_prediction_list, variance_inputs)
]
var_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
Expand Down
6 changes: 4 additions & 2 deletions utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,18 @@ class WarmupCosineSchedule(LambdaLR):
`eta_min` (default=0.0) corresponds to the minimum learning rate reached by the scheduler.
"""

def __init__(self, optimizer, warmup_steps, t_total, eta_min=0.0, cycles=.5, last_epoch=-1):
def __init__(self, optimizer, warmup_steps, t_total, warmup_min=0.0, eta_min=0.0, cycles=.5, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
self.eta_min = eta_min
self.cycles = cycles
self.warmup_min = warmup_min
super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

def lr_lambda(self, step):
if step < self.warmup_steps:
return step / max(1.0, self.warmup_steps)
progress = step / max(1.0, self.warmup_steps)
return self.warmup_min + progress * (1.0 - self.warmup_min)
# progress after warmup
progress = (step - self.warmup_steps) / max(1, self.t_total - self.warmup_steps)
return max(self.eta_min, 0.5 * (1. + math.cos(math.pi * self.cycles * 2.0 * progress)))
Expand Down