Skip to content

Commit

Permalink
Merge pull request #172 from nnsvs/fix-mdn-multi
Browse files Browse the repository at this point in the history
Fix MDN-based multi-stream models
  • Loading branch information
r9y9 committed Nov 21, 2022
2 parents 3d62e71 + 9fb3b24 commit 8b282f3
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
66 changes: 51 additions & 15 deletions nnsvs/acoustic_models/multistream.py
Expand Up @@ -484,43 +484,62 @@ def forward(self, x, lengths=None, y=None):
# Predict continuous log-F0 first
if is_inference:
lf0, lf0_residual = self.lf0_model.inference(x, lengths), None
if self.lf0_model.prediction_type() == PredictionType.PROBABILISTIC:
lf0_cond = lf0[0]
else:
lf0_cond = lf0
else:
lf0, lf0_residual = self.lf0_model(x, lengths, y_lf0)

# Predict spectral parameters
if is_inference:
mgc_inp = torch.cat([x, lf0], dim=-1)
mgc_inp = torch.cat([x, lf0_cond], dim=-1)
mgc = self.mgc_model.inference(mgc_inp, lengths)
else:
mgc_inp = torch.cat([x, y_lf0], dim=-1)
mgc = self.mgc_model(mgc_inp, lengths, y_mgc)

# Predict aperiodic parameters
if is_inference:
bap_inp = torch.cat([x, lf0], dim=-1)
bap_inp = torch.cat([x, lf0_cond], dim=-1)
bap = self.bap_model.inference(bap_inp, lengths)
else:
bap_inp = torch.cat([x, y_lf0], dim=-1)
bap = self.bap_model(bap_inp, lengths, y_bap)

# Predict V/UV
if is_inference:
if self.vuv_model_bap0_conditioning:
bap_inp = bap[1][:, :, 0:1]
if self.bap_model.prediction_type() == PredictionType.PROBABILISTIC:
bap_cond = bap[0]
else:
bap_inp = bap[1]
vuv_inp = torch.cat([x, lf0, bap_inp], dim=-1)
bap_cond = bap

if self.vuv_model_bap0_conditioning:
bap_cond = bap_cond[:, :, 0:1]

vuv_inp = torch.cat([x, lf0_cond, bap_cond], dim=-1)
vuv = self.vuv_model.inference(vuv_inp, lengths)
else:
if self.vuv_model_bap0_conditioning:
y_bap_inp = y_bap[:, :, 0:1]
y_bap_cond = y_bap[:, :, 0:1]
else:
y_bap_inp = y_bap
vuv_inp = torch.cat([x, lf0, y_bap_inp], dim=-1)
y_bap_cond = y_bap
vuv_inp = torch.cat([x, y_lf0, y_bap_cond], dim=-1)
vuv = self.vuv_model(vuv_inp, lengths, y_vuv)

if is_inference:
out = torch.cat([mgc[0], lf0, vuv, bap[0]], dim=-1)
if self.lf0_model.prediction_type() == PredictionType.PROBABILISTIC:
lf0_ = lf0[0]
else:
lf0_ = lf0
if self.bap_model.prediction_type() == PredictionType.PROBABILISTIC:
bap_ = bap[0]
else:
bap_ = bap
if self.mgc_model.prediction_type() == PredictionType.PROBABILISTIC:
mgc_ = mgc[0]
else:
mgc_ = mgc
out = torch.cat([mgc_, lf0_, vuv, bap_], dim=-1)
assert out.shape[-1] == self.out_dim
# TODO: better design
return out, out
Expand Down Expand Up @@ -794,27 +813,44 @@ def forward(self, x, lengths=None, y=None):
# Predict continuous log-F0 first
if is_inference:
lf0, lf0_residual = self.lf0_model.inference(x, lengths), None
if self.lf0_model.prediction_type() == PredictionType.PROBABILISTIC:
lf0_cond = lf0[0]
else:
lf0_cond = lf0
else:
lf0, lf0_residual = self.lf0_model(x, lengths, y_lf0)

# Predict mel
if is_inference:
mel_inp = torch.cat([x, lf0], dim=-1)
mel_inp = torch.cat([x, lf0_cond], dim=-1)
mel = self.mel_model.inference(mel_inp, lengths)
else:
mel_inp = torch.cat([x, y_lf0], dim=-1)
mel = self.mel_model(mel_inp, lengths, y_mel)

# Predict V/UV
if is_inference:
vuv_inp = torch.cat([x, lf0, mel[1]], dim=-1)
if self.mel_model.prediction_type() == PredictionType.PROBABILISTIC:
mel_cond = mel[0]
else:
mel_cond = mel

vuv_inp = torch.cat([x, lf0_cond, mel_cond], dim=-1)
vuv = self.vuv_model.inference(vuv_inp, lengths)
else:
vuv_inp = torch.cat([x, lf0, y_mel], dim=-1)
vuv_inp = torch.cat([x, y_lf0, y_mel], dim=-1)
vuv = self.vuv_model(vuv_inp, lengths, y_vuv)

if is_inference:
out = torch.cat([mel[0], lf0, vuv], dim=-1)
if self.lf0_model.prediction_type() == PredictionType.PROBABILISTIC:
lf0_ = lf0[0]
else:
lf0_ = lf0
if self.mel_model.prediction_type() == PredictionType.PROBABILISTIC:
mel_ = mel[0]
else:
mel_ = mel
out = torch.cat([mel_, lf0_, vuv], dim=-1)
assert out.shape[-1] == self.out_dim
# TODO: better design
return out, out
Expand Down
4 changes: 4 additions & 0 deletions nnsvs/bin/train_acoustic.py
Expand Up @@ -184,6 +184,8 @@ def train_step(
if prediction_type == PredictionType.MULTISTREAM_HYBRID:
if len(pred_out_feats) == 4:
pred_mgc, pred_lf0, pred_vuv, pred_bap = pred_out_feats
if isinstance(pred_lf0, tuple):
pred_lf0 = mdn_get_most_probable_sigma_and_mu(*pred_lf0)[1]
if isinstance(pred_mgc, tuple):
pred_mgc = mdn_get_most_probable_sigma_and_mu(*pred_mgc)[1]
if isinstance(pred_bap, tuple):
Expand All @@ -193,6 +195,8 @@ def train_step(
)
elif len(pred_out_feats) == 3:
pred_mel, pred_lf0, pred_vuv = pred_out_feats
if isinstance(pred_lf0, tuple):
pred_lf0 = mdn_get_most_probable_sigma_and_mu(*pred_lf0)[1]
if isinstance(pred_mel, tuple):
pred_mel = mdn_get_most_probable_sigma_and_mu(*pred_mel)[1]
pred_out_feats_ = torch.cat([pred_mel, pred_lf0, pred_vuv], dim=-1)
Expand Down
4 changes: 4 additions & 0 deletions nnsvs/train_util.py
Expand Up @@ -1726,6 +1726,8 @@ def eval_spss_model(
# Hybrid
if prediction_type == PredictionType.MULTISTREAM_HYBRID:
pred_mgc, pred_lf0, pred_vuv, pred_bap = outs
if isinstance(pred_lf0, tuple):
pred_lf0 = mdn_get_most_probable_sigma_and_mu(*pred_lf0)[1]
if isinstance(pred_mgc, tuple):
pred_mgc = mdn_get_most_probable_sigma_and_mu(*pred_mgc)[1]
if isinstance(pred_bap, tuple):
Expand Down Expand Up @@ -1944,6 +1946,8 @@ def eval_mel_model(
# Hybrid
if prediction_type == PredictionType.MULTISTREAM_HYBRID:
pred_logmel, pred_lf0, pred_vuv = outs
if isinstance(pred_lf0, tuple):
pred_lf0 = mdn_get_most_probable_sigma_and_mu(*pred_lf0)[1]
if isinstance(pred_logmel, tuple):
pred_logmel = mdn_get_most_probable_sigma_and_mu(*pred_logmel)[1]
pred_out_feats = torch.cat([pred_logmel, pred_lf0, pred_vuv], dim=-1)
Expand Down
11 changes: 8 additions & 3 deletions tests/test_acoustic_models.py
Expand Up @@ -167,8 +167,9 @@ def test_resf0_variance_predictor(num_gaussians):

@pytest.mark.parametrize("reduction_factor", [1, 2])
@pytest.mark.parametrize("num_gaussians", [1, 2, 4])
@pytest.mark.parametrize("use_mdn_lf0", [False, True])
def test_hybrid_multistream_mel_model_vuv_pred_from_mel(
reduction_factor, num_gaussians
reduction_factor, num_gaussians, use_mdn_lf0
):
params = {
"in_dim": 300,
Expand All @@ -183,6 +184,8 @@ def test_hybrid_multistream_mel_model_vuv_pred_from_mel(
num_layers=1,
in_lf0_idx=-1,
out_lf0_idx=0,
use_mdn=use_mdn_lf0,
num_gaussians=num_gaussians,
),
# Decoders
"mel_model": MDN(
Expand Down Expand Up @@ -251,8 +254,9 @@ def test_multistream_mel_model(reduction_factor):

@pytest.mark.parametrize("num_gaussians", [1, 2, 4])
@pytest.mark.parametrize("vuv_model_bap0_conditioning", [False, True])
@pytest.mark.parametrize("use_mdn_lf0", [False, True])
def test_npss_mdn_multistream_parametric_model(
num_gaussians, vuv_model_bap0_conditioning
num_gaussians, vuv_model_bap0_conditioning, use_mdn_lf0
):
params = {
"in_dim": 300,
Expand All @@ -267,7 +271,8 @@ def test_npss_mdn_multistream_parametric_model(
num_layers=1,
in_lf0_idx=-1,
out_lf0_idx=0,
use_mdn=False,
use_mdn=use_mdn_lf0,
num_gaussians=num_gaussians,
),
# Decoders
"mgc_model": MDN(
Expand Down

0 comments on commit 8b282f3

Please sign in to comment.