Skip to content

Commit

Permalink
Merge pull request #16 from taroushirani/dev20220803
Browse files Browse the repository at this point in the history
Vibrato model and GAN-based mgc postfilter support.
  • Loading branch information
oatsu-gh committed Aug 4, 2022
2 parents 29be4a4 + 124aeb8 commit 97b197f
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 36 deletions.
12 changes: 9 additions & 3 deletions synthesis/enulib/acoustic.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,13 @@ def timing2acoustic(config: DictConfig, timing_path, acoustic_path):
# pitch_idx = len(binary_dict) + 1
pitch_indices = np.arange(len(binary_dict), len(binary_dict)+3)

# f0の設定を読み取る。
log_f0_conditioning = config.log_f0_conditioning
# check force_clip_input_features (for backward compatibility)
force_clip_input_features = True
try:
force_clip_input_features = config.acoustic.force_clip_input_features
except:
logger.info(f"force_clip_input_features of {typ} is not set so enabled as default")

acoustic_features = predict_acoustic(
device,
duration_modified_labels,
Expand All @@ -115,7 +120,8 @@ def timing2acoustic(config: DictConfig, timing_path, acoustic_path):
continuous_dict,
config.acoustic.subphone_features,
pitch_indices,
log_f0_conditioning
config.log_f0_conditioning,
force_clip_input_features
)

# csvファイルとしてAcousticの行列を出力
Expand Down
31 changes: 20 additions & 11 deletions synthesis/enulib/timing.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,13 @@ def _score2timelag(config: DictConfig, labels):
# pitch_idx = len(binary_dict) + 1
pitch_indices = np.arange(len(binary_dict), len(binary_dict)+3)

# f0の設定を読み取る。
log_f0_conditioning = config.log_f0_conditioning

# check force_clip_input_features (for backward compatibility)
force_clip_input_features = True
try:
force_clip_input_features = config.timelag.force_clip_input_features
except:
logger.info(f"force_clip_input_features of {typ} is not set so enabled as default")

# timelagモデルを適用
# Time-lag
lag = predict_timelag(
Expand All @@ -91,9 +95,10 @@ def _score2timelag(config: DictConfig, labels):
binary_dict,
continuous_dict,
pitch_indices,
log_f0_conditioning,
config.log_f0_conditioning,
config.timelag.allowed_range,
config.timelag.allowed_range_rest
config.timelag.allowed_range_rest,
force_clip_input_features
)
# -----------------------------------------------------
# ここまで nnsvs.bin.synthesis.synthesis() の内容 -----
Expand Down Expand Up @@ -155,14 +160,18 @@ def _score2duration(config: DictConfig, labels):
# config[typ].question_path = config.question_path
# --------------------------------------
# hedファイルを辞書として読み取る。
binary_dict, continuous_dict = \
binary_dict, numeric_dict = \
hts.load_question_set(question_path, append_hat_for_LL=False)
# pitch indices in the input features
# pitch_idx = len(binary_dict) + 1
pitch_indices = np.arange(len(binary_dict), len(binary_dict)+3)

# f0の設定を読み取る。
log_f0_conditioning = config.log_f0_conditioning
# check force_clip_input_features (for backward compatibility)
force_clip_input_features = True
try:
force_clip_input_features = config.duration.force_clip_input_features
except:
logger.info(f"force_clip_input_features of {typ} is not set so enabled as default")

# durationモデルを適用
duration = predict_duration(
Expand All @@ -173,10 +182,10 @@ def _score2duration(config: DictConfig, labels):
in_scaler,
out_scaler,
binary_dict,
continuous_dict,
numeric_dict,
pitch_indices,
log_f0_conditioning,
force_clip_input_features=False
config.log_f0_conditioning,
force_clip_input_features
)
# durationのタプルまたはndarrayを返す
return duration
Expand Down
151 changes: 129 additions & 22 deletions synthesis/enulib/world.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,26 @@
"""
acousticのファイルをWAVファイルにするまでの処理を行う。
"""
import hydra
import numpy as np
import pysptk
import pyworld
from enulib.common import set_checkpoint, set_normalization_stat
from hydra.utils import to_absolute_path
from nnmnkwii.io import hts
from nnsvs.gen import gen_world_params
from nnmnkwii.postfilters import merlin_post_filter
from nnsvs.dsp import bandpass_filter
from nnsvs.gen import (
gen_spsvs_static_features,
gen_world_params
)
from nnsvs.logger import getLogger
from nnsvs.multistream import get_static_stream_sizes
from nnsvs.pitch import lowpass_filter
from nnsvs.postfilters import variance_scaling
from omegaconf import DictConfig, OmegaConf
from scipy.io import wavfile
import torch

logger = None

Expand Down Expand Up @@ -89,7 +101,7 @@ def generate_wav_file(config: DictConfig, wav, out_wav_path):
# # --------------------------------------

# # hedファイルを辞書として読み取る。
# binary_dict, continuous_dict = hts.load_question_set(
# binary_dict, numeric_dict = hts.load_question_set(
# question_path, append_hat_for_LL=False
# )

Expand All @@ -109,7 +121,7 @@ def generate_wav_file(config: DictConfig, wav, out_wav_path):
# duration_modified_labels,
# acoustic_features,
# binary_dict,
# continuous_dict,
# numeric_dict,
# model_config.stream_sizes,
# model_config.has_dynamic_features,
# subphone_features=config.acoustic.subphone_features,
Expand All @@ -127,7 +139,11 @@ def generate_wav_file(config: DictConfig, wav, out_wav_path):


def acoustic2world(config: DictConfig, path_timing, path_acoustic,
path_f0, path_spcetrogram, path_aperiodicity):
path_f0, path_spcetrogram, path_aperiodicity,
trajectory_smoothing=True,
trajectory_smoothing_cutoff=50,
vibrato_scale=1.0,
vuv_threshold=0.1):
"""
Acousticの行列のCSVを読んで、WAVファイルとして出力する。
"""
Expand All @@ -140,11 +156,10 @@ def acoustic2world(config: DictConfig, path_timing, path_acoustic,
duration_modified_labels = hts.load(path_timing).round_()

# CUDAが使えるかどうか
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 各種設定を読み込む
typ = 'acoustic'
model_config = OmegaConf.load(to_absolute_path(config[typ].model_yaml))
acoustic_model_config = OmegaConf.load(to_absolute_path(config["acoustic"].model_yaml))

# hedファイルを読み取る。
question_path = to_absolute_path(config.question_path)
Expand All @@ -155,7 +170,7 @@ def acoustic2world(config: DictConfig, path_timing, path_acoustic,
# --------------------------------------

# hedファイルを辞書として読み取る。
binary_dict, continuous_dict = hts.load_question_set(
binary_dict, numeric_dict = hts.load_question_set(
question_path, append_hat_for_LL=False
)

Expand All @@ -169,25 +184,114 @@ def acoustic2world(config: DictConfig, path_timing, path_acoustic,
path_acoustic, delimiter=',', dtype=np.float64
)

# AcousticからWORLD用のパラメータを取り出す。
f0, spectrogram, aperiodicity = gen_world_params(
# postfilter setting
try:
# substitute of maybe_set_checkpoints_(config)
set_checkpoint(config, "postfilter")
# substitute of maybe_set_normalization_stats_(config)
set_normalization_stat(config, "postfilter")
except:
logger.info(f"There is no post_filter_type setting so merlin is used.")

try:
post_filter_type = config.acoustic.post_filter_type
except:
logger.info(f"There is no post_filter_type setting so merlin is used.")
post_filter_type = "merlin"

if post_filter_type not in ["merlin", "nnsvs", "gv", "none"]:
logger.info(f"Unknown post-filter type: {post_filter_type} so merlin is used.")
post_filter_type = "merlin"

if config.acoustic.post_filter is not None:
logger.info("post_filter is deprecated. Use post_filter_type instead.")

try:
postfilter_out_scaler = joblib.load(config["postfilter"].out_scaler_path)
# Apply GV post-filtering
if post_filter_type in ["nnsvs", "gv"]:
logger.info("Apply GV post-filtering")
static_stream_sizes = get_static_stream_sizes(
acoustic_model_config.stream_sizes,
acoustic_model_config.has_dynamic_features,
acoustic_model_config.num_windows,
)
mgc_end_dim = static_stream_sizes[0]
acoustic_features[:, :mgc_end_dim] = variance_scaling(
postfilter_out_scaler.var_.reshape(-1)[:mgc_end_dim],
acoustic_features[:, :mgc_end_dim],
offset=2,
)
# bap
bap_start_dim = sum(static_stream_sizes[:3])
bap_end_dim = sum(static_stream_sizes[:4])
acoustic_features[:, bap_start_dim:bap_end_dim] = variance_scaling(
postfilter_out_scaler.var_.reshape(-1)[bap_start_dim:bap_end_dim],
acoustic_features[:, bap_start_dim:bap_end_dim],
offset=0,
)

# Learned post-filter using nnsvs
if post_filter_type == "nnsvs":
postfilter_model_config = OmegaConf.load(to_absolute_path(config["postfilter"].model_yaml))
postfilter_model = hydra.utils.instantiate(postfilter_model_config.netG).to(device)

logger.info("Apply mgc_postfilter")
in_feats = (
torch.from_numpy(acoustic_features).float().unsqueeze(0)
)
in_feats = postfilter_out_scaler.transform(in_feats).float().to(device)
out_feats = postfilter_model.inference(in_feats, [in_feats.shape[1]])
acoustic_features = (
postfilter_out_scaler.inverse_transform(out_feats.cpu())
.squeeze(0)
.numpy()
)
except Exception as e:
logger.info(e)
logger.info("Unable to use NNSVS/GV postfilter")

# Generate static features from acoustic features
mgc, lf0, vuv, bap = gen_spsvs_static_features(
duration_modified_labels,
acoustic_features,
binary_dict,
continuous_dict,
model_config.stream_sizes,
model_config.has_dynamic_features,
subphone_features=config.acoustic.subphone_features,
pitch_idx=pitch_idx,
num_windows=model_config.num_windows,
post_filter=config.acoustic.post_filter,
sample_rate=config.sample_rate,
frame_period=config.frame_period,
relative_f0=config.acoustic.relative_f0,
vibrato_scale=1.0,
vuv_threshold=0.3
numeric_dict,
acoustic_model_config.stream_sizes,
acoustic_model_config.has_dynamic_features,
config.acoustic.subphone_features,
pitch_idx,
acoustic_model_config.num_windows,
config.frame_period,
config.acoustic.relative_f0,
vibrato_scale,
vuv_threshold
)

# NOTE: spectral enhancement based on the Merlin's post-filter implementation
if post_filter_type == "merlin":
alpha = pysptk.util.mcepalpha(config.sample_rate)
mgc = merlin_post_filter(mgc, alpha)

# Remove high-frequency components of mgc/bap
# NOTE: It seems to be effective to suppress artifacts of GAN-based post-filtering

if trajectory_smoothing:
modfs = int(1 / 0.005)
for d in range(mgc.shape[1]):
mgc[:, d] = lowpass_filter(
mgc[:, d], modfs, cutoff=trajectory_smoothing_cutoff
)
for d in range(bap.shape[1]):
bap[:, d] = lowpass_filter(
bap[:, d], modfs, cutoff=trajectory_smoothing_cutoff
)

# Generate WORLD parameters
f0, spectrogram, aperiodicity = gen_world_params(
mgc, lf0, vuv, bap, config.sample_rate, vuv_threshold=vuv_threshold
)

# csvファイルとしてf0の行列を出力
for path, array in (
(path_f0, f0),
Expand Down Expand Up @@ -216,5 +320,8 @@ def world2wav(config: DictConfig, path_f0, path_spectrogram, path_aperiodicity,
wav = pyworld.synthesize(
f0, spectrogram, aperiodicity, config.sample_rate, config.frame_period
)

wav = bandpass_filter(wav, config.sample_rate)

# 音量を調整して 32bit float でファイル出力
generate_wav_file(config, wav, path_wav)

0 comments on commit 97b197f

Please sign in to comment.