Skip to content

Commit

Permalink
Add support for mcep-based aperiodicity parametrization
Browse files Browse the repository at this point in the history
  • Loading branch information
r9y9 committed Nov 14, 2022
1 parent da39015 commit 44d81fb
Show file tree
Hide file tree
Showing 26 changed files with 251 additions and 35 deletions.
32 changes: 26 additions & 6 deletions nnsvs/bin/anasyn.py
Expand Up @@ -4,6 +4,7 @@

import hydra
import numpy as np
import pysptk
import pyworld
import torch
from hydra.utils import to_absolute_path
Expand Down Expand Up @@ -31,6 +32,7 @@ def anasyn(
frame_period=5,
vuv_threshold=0.5,
use_world_codec=True,
use_mcep_aperiodicity=False,
feature_type="world",
vocoder_type="world",
):
Expand Down Expand Up @@ -62,6 +64,7 @@ def anasyn(
sample_rate,
vuv_threshold=vuv_threshold,
use_world_codec=use_world_codec,
use_mcep_aperiodicity=use_mcep_aperiodicity,
)
wav = pyworld.synthesize(
f0,
Expand Down Expand Up @@ -97,17 +100,34 @@ def anasyn(
elif vocoder_type == "usfgan":
if feature_type == "world":
fftlen = pyworld.get_cheaptrick_fft_size(sample_rate)
aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(bap).astype(np.float64), sample_rate, fftlen
)
use_mcep_aperiodicity = bap.shape[-1] > 5
if use_mcep_aperiodicity:
mcep_aperiodicity_order = bap.shape[-1] - 1
alpha = pysptk.util.mcepalpha(sample_rate)
aperiodicity = pysptk.mc2sp(
np.ascontiguousarray(bap).astype(np.float64),
fftlen=fftlen,
alpha=alpha,
)
else:
aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(bap).astype(np.float64), sample_rate, fftlen
)
# fill aperiodicity with ones for unvoiced regions
aperiodicity[vuv.reshape(-1) < vuv_threshold, 0] = 1.0
# WORLD fails catastrophically for out of range aperiodicity
aperiodicity = np.clip(aperiodicity, 0.0, 1.0)
# back to bap
bap = pyworld.code_aperiodicity(aperiodicity, sample_rate).astype(
np.float32
)
if use_mcep_aperiodicity:
bap = pysptk.sp2mc(
aperiodicity,
order=mcep_aperiodicity_order,
alpha=alpha,
)
else:
bap = pyworld.code_aperiodicity(aperiodicity, sample_rate).astype(
np.float32
)

aux_feats = (
torch.from_numpy(
Expand Down
4 changes: 4 additions & 0 deletions nnsvs/bin/conf/prepare_features/acoustic/melf0_48k.yaml
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 1
Expand Down
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 3
Expand Down
4 changes: 4 additions & 0 deletions nnsvs/bin/conf/prepare_features/acoustic/static_only.yaml
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 1
Expand Down
2 changes: 2 additions & 0 deletions nnsvs/bin/prepare_features.py
Expand Up @@ -176,6 +176,8 @@ def my_app(config: DictConfig) -> None:
dynamic_features_flags=config.acoustic.dynamic_features_flags,
use_world_codec=config.acoustic.use_world_codec,
res_type=config.acoustic.res_type,
use_mcep_aperiodicity=config.acoustic.use_mcep_aperiodicity,
mcep_aperiodicity_order=config.acoustic.mcep_aperiodicity_order,
)
elif config.acoustic.feature_type == "melf0":
out_acoustic_source = MelF0AcousticSource(
Expand Down
2 changes: 2 additions & 0 deletions nnsvs/bin/prepare_voc_features.py
Expand Up @@ -75,6 +75,8 @@ def my_app(config: DictConfig) -> None:
config.acoustic.mgc_order,
config.acoustic.num_windows,
config.acoustic.vibrato_mode,
use_mcep_aperiodicity=config.acoustic.use_mcep_aperiodicity,
mcep_aperiodicity_order=config.acoustic.mcep_aperiodicity_order,
)
elif config.acoustic.feature_type == "melf0":
stream_sizes = [80, 1, 1]
Expand Down
14 changes: 13 additions & 1 deletion nnsvs/data/data_source.py
Expand Up @@ -170,6 +170,8 @@ def __init__(
correct_f0=False,
dynamic_features_flags=None,
use_world_codec=False,
use_mcep_aperiodicity=False,
mcep_aperiodicity_order=24,
res_type="scipy",
):
self.utt_list = utt_list
Expand Down Expand Up @@ -197,6 +199,8 @@ def __init__(
self.correct_vuv = correct_vuv
self.correct_f0 = correct_f0
self.use_world_codec = use_world_codec
self.use_mcep_aperiodicity = use_mcep_aperiodicity
self.mcep_aperiodicity_order = mcep_aperiodicity_order
if dynamic_features_flags is None:
# NOTE: we have up to 6 streams: (mgc, lf0, vuv, bap, vib, vib_flags)
dynamic_features_flags = [True, True, False, True, True, False]
Expand Down Expand Up @@ -428,7 +432,15 @@ def collect_features(self, wav_path, label_path):
np.where(is_voiced)[0],
aperiodicity[is_voiced, k],
)
bap = pyworld.code_aperiodicity(aperiodicity, fs)

if self.use_mcep_aperiodicity:
bap = pysptk.sp2mc(
aperiodicity,
order=self.mcep_aperiodicity_order,
alpha=pysptk.util.mcepalpha(fs),
)
else:
bap = pyworld.code_aperiodicity(aperiodicity, fs)

# Parameter trajectory smoothing
if self.trajectory_smoothing:
Expand Down
22 changes: 18 additions & 4 deletions nnsvs/gen.py
Expand Up @@ -708,7 +708,13 @@ def gen_spsvs_static_features(


def gen_world_params(
mgc, lf0, vuv, bap, sample_rate, vuv_threshold=0.3, use_world_codec=False
mgc,
lf0,
vuv,
bap,
sample_rate,
vuv_threshold=0.3,
use_world_codec=False,
):
"""Generate WORLD parameters from mgc, lf0, vuv and bap.
Expand All @@ -726,6 +732,8 @@ def gen_world_params(
"""
fftlen = pyworld.get_cheaptrick_fft_size(sample_rate)
alpha = pysptk.util.mcepalpha(sample_rate)
use_mcep_aperiodicity = bap.shape[-1] > 5

if use_world_codec:
spectrogram = pyworld.decode_spectral_envelope(
np.ascontiguousarray(mgc).astype(np.float64), sample_rate, fftlen
Expand All @@ -734,9 +742,15 @@ def gen_world_params(
spectrogram = pysptk.mc2sp(
np.ascontiguousarray(mgc), fftlen=fftlen, alpha=alpha
)
aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(bap).astype(np.float64), sample_rate, fftlen
)

if use_mcep_aperiodicity:
aperiodicity = pysptk.mc2sp(
np.ascontiguousarray(bap), fftlen=fftlen, alpha=alpha
)
else:
aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(bap).astype(np.float64), sample_rate, fftlen
)

# fill aperiodicity with ones for unvoiced regions
aperiodicity[vuv.reshape(-1) < vuv_threshold, 0] = 1.0
Expand Down
39 changes: 28 additions & 11 deletions nnsvs/svs.py
Expand Up @@ -317,7 +317,8 @@ def synthesis_from_timings(
mel[:, d], modfs, cutoff=trajectory_smoothing_cutoff
)

if feature_type == "world":
use_mcep_aperiodicity = bap.shape[-1] > 5
if feature_type == "world" and not use_mcep_aperiodicity:
bap = np.clip(bap, a_min=-60, a_max=0)

# Waveform generation by (1) WORLD or (2) neural vocoder
Expand Down Expand Up @@ -365,19 +366,35 @@ def synthesis_from_timings(
elif vocoder_type == "usfgan":
if feature_type == "world":
fftlen = pyworld.get_cheaptrick_fft_size(sample_rate)
aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(bap).astype(np.float64),
sample_rate,
fftlen,
)
# fill aperiodicity with ones for unvoiced regions
if use_mcep_aperiodicity:
aperiodicity_order = bap.shape[-1] - 1
alpha = pysptk.util.mcepalpha(sample_rate)
aperiodicity = pysptk.mc2sp(
np.ascontiguousarray(bap).astype(np.float64),
fftlen=fftlen,
alpha=alpha,
)
else:
aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(bap).astype(np.float64),
sample_rate,
fftlen,
)
# fill aperiodicity with ones for unvoiced regions
aperiodicity[vuv.reshape(-1) < vuv_threshold, 0] = 1.0
# WORLD fails catastrophically for out of range aperiodicity
aperiodicity = np.clip(aperiodicity, 0.0, 1.0)
# back to bap
bap = pyworld.code_aperiodicity(aperiodicity, sample_rate).astype(
np.float32
)

if use_mcep_aperiodicity:
bap = pysptk.sp2mc(
aperiodicity,
order=aperiodicity_order,
alpha=alpha,
)
else:
bap = pyworld.code_aperiodicity(aperiodicity, sample_rate).astype(
np.float32
)
aux_feats = [mgc, bap]
elif feature_type == "melf0":
aux_feats = [mel]
Expand Down
24 changes: 19 additions & 5 deletions nnsvs/train_util.py
Expand Up @@ -1383,7 +1383,12 @@ def synthesize(
else:
# Fallback to WORLD
f0, spectrogram, aperiodicity = gen_world_params(
mgc, lf0, vuv, bap, sr, use_world_codec=use_world_codec
mgc,
lf0,
vuv,
bap,
sr,
use_world_codec=use_world_codec,
)
wav = pyworld.synthesize(f0, spectrogram, aperiodicity, sr, 5)

Expand Down Expand Up @@ -1941,6 +1946,7 @@ def plot_spsvs_params(
fftlen = pyworld.get_cheaptrick_fft_size(sr)
alpha = pysptk.util.mcepalpha(sr)
hop_length = int(sr * 0.005)
use_mcep_aperiodicity = bap.shape[-1] > 5

# Log-F0
if lf0_score is not None:
Expand Down Expand Up @@ -2080,7 +2086,10 @@ def plot_spsvs_params(
fig, ax = plt.subplots(2, 1, figsize=(8, 6))
ax[0].set_title("Reference aperiodicity")
ax[1].set_title("Predicted aperiodicity")
aperiodicity = pyworld.decode_aperiodicity(bap.astype(np.float64), sr, fftlen).T
if use_mcep_aperiodicity:
aperiodicity = pysptk.mc2sp(bap, fftlen=fftlen, alpha=alpha).T
else:
aperiodicity = pyworld.decode_aperiodicity(bap.astype(np.float64), sr, fftlen).T
mesh = librosa.display.specshow(
20 * np.log10(aperiodicity),
sr=sr,
Expand All @@ -2091,9 +2100,14 @@ def plot_spsvs_params(
ax=ax[0],
)
fig.colorbar(mesh, ax=ax[0], format="%+2.f dB")
pred_aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(pred_bap).astype(np.float64), sr, fftlen
).T
if use_mcep_aperiodicity:
pred_aperiodicity = pysptk.mc2sp(
np.ascontiguousarray(pred_bap), fftlen=fftlen, alpha=alpha
).T
else:
pred_aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(pred_bap).astype(np.float64), sr, fftlen
).T
mesh = librosa.display.specshow(
20 * np.log10(pred_aperiodicity),
sr=sr,
Expand Down
11 changes: 9 additions & 2 deletions nnsvs/util.py
Expand Up @@ -55,7 +55,12 @@ def init_func(m):


def get_world_stream_info(
sr: int, mgc_order: int, num_windows: int = 3, vibrato_mode: str = "none"
sr: int,
mgc_order: int,
num_windows: int = 3,
vibrato_mode: str = "none",
use_mcep_aperiodicity: bool = False,
mcep_aperiodicity_order: int = 24,
):
"""Get stream sizes for WORLD-based acoustic features
Expand All @@ -73,7 +78,9 @@ def get_world_stream_info(
(mgc_order + 1) * num_windows,
num_windows,
1,
pyworld.get_num_aperiodicities(sr) * num_windows,
pyworld.get_num_aperiodicities(sr) * num_windows
if not use_mcep_aperiodicity
else mcep_aperiodicity_order + 1,
]
if vibrato_mode == "diff":
# vib
Expand Down
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 1
Expand Down
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 1
Expand Down
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 1
Expand Down
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 3
Expand Down
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 1
Expand Down
Expand Up @@ -31,6 +31,10 @@ mgc_order: 59
# Use WORLD-based coding for spectral envelope or not
use_world_codec: true

# Use mcep-based aperiodicity parametrization
use_mcep_aperiodicity: false
mcep_aperiodicity_order: 24

# windows to compute delta and delta-delta features
# set 1 to disable
num_windows: 1
Expand Down

0 comments on commit 44d81fb

Please sign in to comment.