Add multilingual support #11

Closed
wants to merge 95 commits into from
95 commits
9fddb8f
Added reversal classifier in tacotron2
WeberJulian Feb 21, 2021
154408e
Added loss and config for adverserial classifier
WeberJulian Feb 21, 2021
9cf7a10
Removing cosine similarity classifier
WeberJulian Feb 21, 2021
446c829
Removing unused GradientClippingFunction
WeberJulian Feb 21, 2021
3c162e6
fixes reversal classifier
WeberJulian Feb 22, 2021
616b18b
reversal classifier first training
WeberJulian Feb 23, 2021
c94eeb0
Add resample script
WeberJulian Mar 3, 2021
3686713
Using path.join instead of concat
WeberJulian Mar 3, 2021
e09743f
fix speaker_id default value for evaluation
WeberJulian Mar 4, 2021
cccc5d6
linter + test
WeberJulian Mar 5, 2021
b133cbf
test case
WeberJulian Mar 5, 2021
32968fa
fix french_cleaners
WeberJulian Mar 5, 2021
8fd272c
fix linter issues
WeberJulian Mar 6, 2021
fdf9dad
Merge branch 'dev' of https://github.com/coqui-ai/TTS into dev
WeberJulian Mar 11, 2021
358f86f
Merge branch 'dev' into multilingual
WeberJulian Mar 11, 2021
24ef107
Fix input dim with gst
WeberJulian Mar 12, 2021
807afd3
Adding multilingual encoder
WeberJulian Mar 13, 2021
82a8d26
linter fixes
WeberJulian Mar 13, 2021
f0baa74
fix linter issues again
WeberJulian Mar 13, 2021
c63a568
last linter fix
WeberJulian Mar 13, 2021
4bfa6ff
fix tests
WeberJulian Mar 14, 2021
36a1460
Trains without crash
WeberJulian Apr 11, 2021
e88ea1d
Actually using multilingual in Tacotron2
WeberJulian Apr 12, 2021
9f3f69e
working with test sentences
WeberJulian Apr 12, 2021
aaf43f1
retore path works
WeberJulian Apr 13, 2021
d9f3bed
Replace lang_ and langs_ contractions
WeberJulian Apr 13, 2021
199aefe
Correct synthesis bug and add GL sythesis notbook
WeberJulian Apr 13, 2021
fd95f39
Merge branch 'dev'
WeberJulian Apr 13, 2021
4072cdd
Enhancments
WeberJulian Apr 13, 2021
ca4dafd
remove unused language_embedding
WeberJulian Apr 13, 2021
904c6b1
fix synthesis
WeberJulian Apr 13, 2021
523e87b
fix resample after optimization
WeberJulian Apr 18, 2021
c8d1767
add weighted_sampler
WeberJulian Apr 18, 2021
2ef3868
notebook changes
WeberJulian Apr 18, 2021
6152bc0
fir odd number of languages
WeberJulian Apr 19, 2021
e12d01b
Add feature to specify speaker/language test file
WeberJulian Apr 19, 2021
698663d
Add language embedding after encoder
WeberJulian Apr 24, 2021
bc9bbb7
new preprocessors
WeberJulian Apr 24, 2021
6be6d1b
Merge remote-tracking branch 'coqui/dev' into multilingual
WeberJulian Apr 24, 2021
99bb49f
HifiGan Sythesis
WeberJulian Apr 29, 2021
d732ebe
Edresson's fix
WeberJulian Apr 30, 2021
de298e3
cleanup after first successfull trainning
WeberJulian May 20, 2021
7d525c0
quick fix
WeberJulian May 25, 2021
72df02f
Refacto and reversal loss fix
WeberJulian May 25, 2021
6ae4695
set speaker_embedding_dim back to 512
WeberJulian May 25, 2021
abfad8f
Temporary fix
WeberJulian May 27, 2021
4920676
added genereted
WeberJulian May 28, 2021
95d983e
support single language training
WeberJulian Jun 1, 2021
e971b9a
separate speaker and language sampler
WeberJulian Jun 1, 2021
44c8791
perfect sampler and generated fixes
WeberJulian Jun 4, 2021
5094923
Generated encoder now runs but slow
WeberJulian Jun 4, 2021
bf69366
first training generated encoder
WeberJulian Jun 4, 2021
ee9bfa7
Fixed inference
WeberJulian Jun 10, 2021
c5faf4e
Generated encoder working
WeberJulian Jun 14, 2021
357be71
add glowTTS multilingual support
Edresson Jun 14, 2021
2ec6232
switch to batch sampler
WeberJulian Jun 16, 2021
120a701
fix batch_n_iter
WeberJulian Jun 16, 2021
4bd6ff4
Merge pull request #1 from Edresson/multilingual
Edresson Jun 17, 2021
f09ec9b
Fixes
WeberJulian Jun 17, 2021
e986558
Merge branch 'multilingual' of https://github.com/WeberJulian/TTS-1 i…
WeberJulian Jun 17, 2021
2299c18
Bug fix on LibriTTS preprocess
Edresson Jun 17, 2021
a51a91f
bug fix
Edresson Jun 18, 2021
d32d3f8
add script for remove silence using VAD
Edresson Jun 18, 2021
bb3897e
fix split dataset
WeberJulian Jun 19, 2021
0353fa3
Merge branch 'multilingual' of https://github.com/WeberJulian/TTS-1 i…
WeberJulian Jun 19, 2021
8e99c13
add stochastic duration predictor
Edresson Jun 20, 2021
b124a5e
Merge branch 'multilingual' of https://github.com/WeberJulian/TTS-1 i…
Edresson Jun 20, 2021
ed4777e
fix documentation
Edresson Jun 20, 2021
e7ecac5
bug fix on Vad remove silence script
Edresson Jun 20, 2021
ea26c10
add extra slots for new languages
Edresson Jun 21, 2021
500fef3
Move Dataloaders closer to the train and eval call
WeberJulian Jun 24, 2021
2f35f76
set glowtts noise_scale value to 0
Edresson Jun 27, 2021
fb03de5
cond language embedding on duration predictor and add inference for v…
Edresson Jun 28, 2021
4905202
bug fix in decoderr inference
Edresson Jun 28, 2021
69230f0
add reversal classifier in GlowTTS
Edresson Jun 30, 2021
2542426
bugfix
Edresson Jul 1, 2021
eb5ede1
add extract spectrogram script
Edresson Jul 2, 2021
8b819f8
fixes
WeberJulian Jul 5, 2021
e85b047
fix eval bug
WeberJulian Jul 5, 2021
c4bf6e7
Merge remote-tracking branch 'origin/multilingual' into multilingual
WeberJulian Jul 5, 2021
8e5afec
Allow for reversal classifier when eval contains unseen speakers
WeberJulian Jul 19, 2021
1ccd691
Allow for differences in feat and wav paths for vocoder training
WeberJulian Jul 21, 2021
fe7eb0a
add pitch predictor support
Edresson Jul 21, 2021
e1f1476
add freeze model parts option in config
Edresson Jul 21, 2021
6e71856
add pitch predictor support
Edresson Jul 21, 2021
a3d523a
Merge remote-tracking branch 'origin/multilingual' into multilingual
WeberJulian Jul 21, 2021
7a1d186
add config datasets support to the gan dataloader
Edresson Jul 21, 2021
a30eadd
add pitch transform
Edresson Jul 23, 2021
3394378
bug fix
Edresson Jul 23, 2021
a19ab0b
pitch predictor bug fix
Edresson Jul 24, 2021
6c8eb30
glowtts singke speaker train bug fix
Edresson Jul 23, 2021
f7e1e37
update pitch predictor network
Edresson Jul 25, 2021
43e4415
bug fix
Edresson Aug 2, 2021
aa10b54
add VITS model support
Edresson Aug 5, 2021
a7963e0
bug fix
Edresson Aug 5, 2021
3 changes: 3 additions & 0 deletions .gitignore
@@ -133,3 +133,6 @@ TTS/tts/layers/glow_tts/monotonic_align/core.c
.vscode-upload.json
temp_build/*
recipes/*

# nohup
*.out
329 changes: 329 additions & 0 deletions TTS/bin/extract_tts_spectrograms.py
@@ -0,0 +1,329 @@
#!/usr/bin/env python3
"""Extract Mel spectrograms with teacher forcing."""

import argparse
import os

import numpy as np
import torch
from shutil import copyfile
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.utils.io import load_config
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import parse_speakers, load_language_mapping
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters

use_cuda = torch.cuda.is_available()


def setup_loader(ap, r, verbose=False):
    dataset = MyDataset(
        r,
        c.text_cleaner,
        compute_linear_spec=False,
        meta_data=meta_data,
        ap=ap,
        tp=c.characters if "characters" in c.keys() else None,
        add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
        batch_group_size=0,
        min_seq_len=c.min_seq_len,
        max_seq_len=c.max_seq_len,
        phoneme_cache_path=c.phoneme_cache_path,
        use_phonemes=c.use_phonemes,
        phoneme_language=c.phoneme_language,
        enable_eos_bos=c.enable_eos_bos_chars,
        use_noise_augment=False,
        verbose=verbose,
        speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
    )

    if c.use_phonemes and c.compute_input_seq_cache:
        # precompute phonemes to have a better estimate of sequence lengths.
        dataset.compute_input_seq(c.num_loader_workers)
    dataset.sort_items()

    loader = DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=None,
        num_workers=c.num_loader_workers,
        pin_memory=False,
    )
    return loader


def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split(".")[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
    return file_name, wavq_path, mel_path, wav_gl_path, wav_path


def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4].permute(0, 2, 1)  # B x D x T
    mel_lengths = data[5]
    language_names = data[7]
    item_idx = data[8]
    attn_mask = data[10]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())
    if c.use_speaker_embedding:
        if c.use_external_speaker_embedding_file:
            # return precomputed embedding vector
            speaker_c = data[9]
        else:
            # return speaker_id to be used by an embedding layer
            speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_c = torch.LongTensor(speaker_c)
    else:
        speaker_c = None

    if c.use_language_embedding:
        language_ids = [language_mapping[language_name] for language_name in language_names]
        language_ids = torch.LongTensor(language_ids)
    else:
        language_ids = None

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_c is not None:
            speaker_c = speaker_c.cuda(non_blocking=True)
        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)
        if language_ids is not None:
            language_ids = language_ids.cuda(non_blocking=True)
    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_c,
        language_ids,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )


@torch.no_grad()
def inference(
    model_name,
    model,
    ap,
    text_input,
    text_lengths,
    mel_input,
    mel_lengths,
    attn_mask=None,
    speaker_c=None,
    language_ids=None,
):
    if model_name == "glow_tts":
        # mel_input = mel_input.permute(0, 2, 1)  # B x D x T
        if args.only_decoder:
            print("only decoder")
            model_output, *_ = model.decoder_inference(
                mel_input, mel_lengths, g=speaker_c
            )
        else:
            model_output, *_ = model.inference_with_MAS(
                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c, language_ids=language_ids
            )

        model_output = model_output.transpose(1, 2).detach().cpu().numpy()

    elif "tacotron" in model_name:
        speaker_ids = None
        speaker_embeddings = None
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                speaker_embeddings = speaker_c
            else:
                speaker_ids = speaker_c

        _, postnet_outputs, *_ = model(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids=speaker_ids,
            speaker_embeddings=speaker_embeddings,
            language_ids=language_ids,
        )

        # normalize tacotron output
        if model_name == "tacotron":
            mel_specs = []
            postnet_outputs = postnet_outputs.data.cpu().numpy()
            for b in range(postnet_outputs.shape[0]):
                postnet_output = postnet_outputs[b]
                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
            model_output = torch.stack(mel_specs).cpu().numpy()

        elif model_name == "tacotron2":
            model_output = postnet_outputs.detach().cpu().numpy()
    return model_output


def extract_spectrograms(
    data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
):
    model.eval()
    export_metadata = []
    for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_c,
            language_ids,
            _,
            _,
            attn_mask,
            item_idx
        ) = format_data(data)

        model_output = inference(
            c.model.lower(),
            model,
            ap,
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            attn_mask,
            speaker_c,
            language_ids,
        )

        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

            # quantize and save wav
            if quantized_wav:
                wav = ap.load_wav(wav_file_path)
                wavq = ap.quantize(wav)
                np.save(wavq_path, wavq)

            if not os.path.exists(mel_path):
                # save TTS mel
                mel = model_output[idx]
                mel_length = mel_lengths[idx]
                mel = mel[:mel_length, :].T
                np.save(mel_path, mel)

            export_metadata.append([wav_file_path, mel_path])
            if save_audio:
                if not os.path.exists(wav_path):
                    copyfile(wav_file_path, wav_path)
                    # ap.save_wav(wav, wav_path)

            if debug:
                print("Audio for debug saved at:", wav_gl_path)
                wav = ap.inv_melspectrogram(mel)
                ap.save_wav(wav, wav_gl_path)

    with open(os.path.join(output_path, metada_name), "w") as f:
        for data in export_metadata:
            f.write(f"{data[0]}|{data[1]+'.npy'}\n")


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data, symbols, phonemes, model_characters, speaker_mapping, language_mapping
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if "characters" in c.keys() and c["characters"]:
        symbols, phonemes = make_symbols(**c.characters)

    # set model characters
    model_characters = phonemes if c.use_phonemes else symbols
    num_chars = len(model_characters)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True, ignore_generated_eval=True)

    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval
    print("Num samples:", len(meta_data))

    # parse speakers
    num_speakers, _, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, None, meta_data_eval, training=False)

    # parse languages
    language_mapping = load_language_mapping(os.path.dirname(args.restore_path))
    num_langs = len(language_mapping.keys())
    language_embedding_dim = None

    # setup model
    model = setup_model(num_chars, num_speakers, num_langs, c, speaker_embedding_dim, language_embedding_dim)

    # restore model
    checkpoint = torch.load(args.restore_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])

    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)
    # set r
    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(ap, r, verbose=True)

    extract_spectrograms(
        own_loader,
        model,
        ap,
        args.output_path,
        quantized_wav=args.quantized,
        save_audio=args.save_audio,
        debug=args.debug,
        metada_name="metada.txt",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--restore_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
    parser.add_argument("--only_decoder", default=False, action="store_true", help="Use only the decoder on GlowTTS inference")
    parser.add_argument("--quantized", default=False, action="store_true", help="Save quantized audio files")
    args = parser.parse_args()

    c = load_config(args.config_path)
    c.audio["do_trim_silence"] = False  # IMPORTANT!!!!!!!!!!!!!!! disable to align mel

    main(args)
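
Note on consuming the output: extract_spectrograms writes one line per sample to the metadata file (named "metada.txt" by default), in the form `<wav_path>|<mel_path>.npy`. A minimal sketch of reading that file back is shown below. The helper name `read_extraction_metadata` is hypothetical and is not part of this PR; the sketch also assumes that neither path contains a `|` character.

```python
import os
import tempfile


def read_extraction_metadata(metadata_path):
    """Parse '<wav_path>|<mel_path>.npy' lines into (wav_path, mel_path) tuples.

    Hypothetical helper for illustration only; it is not part of the PR.
    """
    pairs = []
    with open(metadata_path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            line = line.rstrip("\n")
            if not line:
                # Skip blank lines instead of failing on them.
                continue
            fields = line.split("|")
            if len(fields) != 2:
                # Assumption: paths never contain '|', so exactly two fields are expected.
                raise ValueError(f"Malformed metadata line {line_number}: {line!r}")
            pairs.append((fields[0], fields[1]))
    return pairs


if __name__ == "__main__":
    # Small self-contained demo with a made-up metadata file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        demo_path = os.path.join(tmp_dir, "metada.txt")
        with open(demo_path, "w", encoding="utf-8") as f:
            f.write("data/wavs/a.wav|out/mel/a.npy\n")
            f.write("data/wavs/b.wav|out/mel/b.npy\n")
        for wav_path, mel_path in read_extraction_metadata(demo_path):
            print(wav_path, "->", mel_path)
```

Each returned mel path can then be loaded with `np.load`. Judging from the slicing and transpose in extract_spectrograms, the saved array should have the mel channels on the first axis and frames on the second.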