Skip to content

Commit

Permalink
TTS folder and travis
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jul 17, 2020
1 parent 82dd465 commit 9033070
Show file tree
Hide file tree
Showing 117 changed files with 13,109 additions and 8 deletions.
18 changes: 18 additions & 0 deletions .github/PR_TEMPLATE.md
@@ -0,0 +1,18 @@
---
name: 'Contribution Guideline '
about: Refer to Contribution Guideline
title: ''
labels: ''
assignees: ''

---
### Contribution Guideline

Please send your PRs to `dev` branch if it is not directly related to a specific branch.
Before making a Pull Request, check your changes for basic mistakes and style problems by using a linter.
We have cardboardlinter set up in this repository, so for example, if you've made some changes and would like to run the linter on just the changed code, you can use the following command:

```bash
pip install pylint cardboardlint
cardboardlinter --refspec master
```
8 changes: 6 additions & 2 deletions .travis.yml
Expand Up @@ -6,6 +6,8 @@ git:
before_install:
- sudo apt-get update
- sudo apt-get -y install espeak
- python -m pip install --upgrade pip
- pip install six==1.12.0

matrix:
include:
Expand All @@ -15,11 +17,13 @@ matrix:
env: TEST_SUITE="lint"
- name: "Unit tests"
python: "3.6"
install: pip install --quiet -r requirements_tests.txt
install:
- python setup.py install
env: TEST_SUITE="unittest"
- name: "Unit tests"
python: "3.6"
install: pip install --quiet -r requirements_tests.txt
install:
- python setup.py install
env: TEST_SUITE="testscripts"

script: ./.travis/script
7 changes: 1 addition & 6 deletions .travis/script
Expand Up @@ -11,12 +11,7 @@ fi

if [[ "$TEST_SUITE" == "unittest" ]]; then
# Run tests on all pushes
pushd tts_namespace
nosetests TTS.speaker_encoder.tests --nocapture
nosetests TTS.vocoder.tests --nocapture
nosetests TTS.tts.tests --nocapture
nosetests TTS.tts.tf.tests --nocapture
popd
nosetests tests --nocapture
fi

if [[ "$TEST_SUITE" == "testscripts" ]]; then
Expand Down
Empty file added TTS/__init__.py
Empty file.
85 changes: 85 additions & 0 deletions TTS/bin/compute_statistics.py
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import argparse

import numpy as np
from tqdm import tqdm

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.utils.io import load_config
from TTS.tts.utils.audio import AudioProcessor

def main():
    """Compute mean and standard deviation of spectrogram features.

    Loads the TTS config given by ``--config_path``, disables signal
    normalization and any pre-defined stats, then computes the linear and mel
    spectrogram of every item in the training split. Sums and squared sums are
    accumulated over axis 1 of each spectrogram, and ``N`` counts the total
    size of that axis (presumably the number of frames -- confirm against
    ``AudioProcessor``). The resulting mean/std vectors, together with the
    adjusted audio config, are saved to ``<out_path>/scale_stats.npy``.

    Raises:
        ValueError: if no feature frames were found in the dataset.
    """
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of spectrogram features.")
    parser.add_argument("--config_path", type=str, required=True,
                        help="TTS config file path to define audio processing parameters.")
    # The output directory is mandatory: without it the script used to fail
    # with a TypeError only after the whole dataset had been processed.
    parser.add_argument("--out_path", type=str, required=True,
                        help="directory to save the output file.")
    args = parser.parse_args()

    # Create the target directory up front so a bad path fails fast instead
    # of after the expensive feature extraction below.
    os.makedirs(args.out_path, exist_ok=True)
    output_file_path = os.path.join(args.out_path, "scale_stats.npy")

    # load config
    CONFIG = load_config(args.config_path)
    CONFIG.audio['signal_norm'] = False  # do not apply earlier normalization
    CONFIG.audio['stats_path'] = None  # discard pre-defined stats

    # load audio processor
    ap = AudioProcessor(**CONFIG.audio)

    # load the meta data of target dataset
    dataset_items = load_meta_data(CONFIG.datasets)[0]  # take only train data
    print(f" > There are {len(dataset_items)} files.")

    mel_sum = 0
    mel_square_sum = 0
    linear_sum = 0
    linear_square_sum = 0
    N = 0
    for item in tqdm(dataset_items):
        # compute features; item[1] is passed to load_wav, so it is
        # presumably the audio file path -- confirm against load_meta_data
        wav = ap.load_wav(item[1])
        linear = ap.spectrogram(wav)
        mel = ap.melspectrogram(wav)

        # accumulate first and second moments over axis 1
        N += mel.shape[1]
        mel_sum += mel.sum(1)
        linear_sum += linear.sum(1)
        mel_square_sum += (mel ** 2).sum(axis=1)
        linear_square_sum += (linear ** 2).sum(axis=1)

    if N == 0:
        # Avoid an opaque ZeroDivisionError when the dataset is empty.
        raise ValueError(
            " [!] No feature frames found; cannot compute statistics.")

    # std = sqrt(E[x^2] - E[x]^2); clamp at zero because floating-point
    # rounding can make the variance slightly negative, which would give NaN.
    mel_mean = mel_sum / N
    mel_scale = np.sqrt(np.maximum(mel_square_sum / N - mel_mean ** 2, 0))
    linear_mean = linear_sum / N
    linear_scale = np.sqrt(
        np.maximum(linear_square_sum / N - linear_mean ** 2, 0))

    stats = {}
    stats['mel_mean'] = mel_mean
    stats['mel_std'] = mel_scale
    stats['linear_mean'] = linear_mean
    stats['linear_std'] = linear_scale

    print(f' > Avg mel spec mean: {mel_mean.mean()}')
    print(f' > Avg mel spec scale: {mel_scale.mean()}')
    print(f' > Avg linear spec mean: {linear_mean.mean()}')
    print(f' > Avg linear spec scale: {linear_scale.mean()}')

    # set default config values for mean-var scaling
    CONFIG.audio['stats_path'] = output_file_path
    CONFIG.audio['signal_norm'] = True
    # remove redundant values; pop() tolerates configs that omit a key, so a
    # missing optional entry cannot discard the statistics computed above
    for redundant_key in ('max_norm', 'min_level_db', 'symmetric_norm',
                          'clip_norm'):
        CONFIG.audio.pop(redundant_key, None)
    stats['audio_config'] = CONFIG.audio
    # stats is a dict, so it is stored as a pickled object array
    np.save(output_file_path, stats, allow_pickle=True)
    print(f' > scale_stats.npy is saved to {output_file_path}')


if __name__ == "__main__":
    main()
33 changes: 33 additions & 0 deletions TTS/bin/convert_melgan_tflite.py
@@ -0,0 +1,33 @@
# Convert Tensorflow MelGAN model to TF-Lite binary

import argparse

from TTS.tts.utils.io import load_config
from TTS.vocoder.tf.utils.generic_utils import setup_generator
from TTS.vocoder.tf.utils.io import load_checkpoint
from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite


# Command-line interface. All three paths are needed by the steps below, so
# they are marked as required to fail with a clear argparse message instead
# of passing None into the loaders.
parser = argparse.ArgumentParser()
parser.add_argument('--tf_model',
                    type=str,
                    required=True,
                    help='Path to the TF MelGAN checkpoint to be converted to TF-Lite.')
parser.add_argument('--config_path',
                    type=str,
                    required=True,
                    help='Path to the config file of the model.')
parser.add_argument('--output_path',
                    type=str,
                    required=True,
                    help='path to tflite output binary.')
args = parser.parse_args()

# Set constants
CONFIG = load_config(args.config_path)

# load the model: build the generator from the config, prepare it for
# inference and restore the weights from the given checkpoint
model = setup_generator(CONFIG)
model.build_inference()
model = load_checkpoint(model, args.tf_model)

# create tflite model; output_path is handed to the converter
tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path)

117 changes: 117 additions & 0 deletions TTS/bin/convert_melgan_torch_to_tf.py
@@ -0,0 +1,117 @@
import argparse
import os

import numpy as np
import tensorflow as tf
import torch
from fuzzywuzzy import fuzz

from TTS.tts.utils.io import load_config
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
from TTS.vocoder.tf.utils.generic_utils import \
setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import save_checkpoint
from TTS.vocoder.utils.generic_utils import setup_generator

# Script flow: load a torch MelGAN generator, build the matching TF model,
# copy the weights by fuzzy-matching variable names, verify that both models
# produce the same outputs, then save the TF checkpoint.

# prevent GPU use
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# define args; all paths are needed below, so they are required
parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path',
                    type=str,
                    required=True,
                    help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
                    type=str,
                    required=True,
                    help='Path to config file of torch model.')
parser.add_argument(
    '--output_path',
    type=str,
    required=True,
    help='path to output file including file name to save TF model.')
args = parser.parse_args()

# load model config
config_path = args.config_path
c = load_config(config_path)

# init torch model
model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path,
                        map_location=torch.device('cpu'))
state_dict = checkpoint['model']
model.load_state_dict(state_dict)
# weight norm is removed before re-reading the state dict, so the names and
# tensors used for the transfer are those of the plain (un-normalized) layers
model.remove_weight_norm()
state_dict = model.state_dict()

# init tf model
model_tf = setup_tf_generator(c)

# get tf_model graph by passing an input
# B x D x T
dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
model_tf(dummy_input, training=False)

# get tf variables
tf_vars = model_tf.weights

# match variable names with fuzzy logic: each TF variable is paired with the
# most similar remaining torch name, which is then removed from the pool so
# that it cannot be assigned twice
torch_var_names = list(state_dict.keys())
tf_var_names = [we.name for we in model_tf.weights]
var_map = []
mapped_tf_names = set()  # O(1) lookup of TF names that are already mapped
for tf_name in tf_var_names:
    # skip re-mapped layer names
    if tf_name in mapped_tf_names:
        continue
    if not torch_var_names:
        # np.argmax on an empty list would raise an opaque ValueError
        raise ValueError(
            f" [!] No torch variable left to match TF variable '{tf_name}'.")
    tf_name_edited = convert_tf_name(tf_name)
    ratios = [
        fuzz.ratio(torch_name, tf_name_edited)
        for torch_name in torch_var_names
    ]
    max_idx = int(np.argmax(ratios))
    matching_name = torch_var_names.pop(max_idx)
    var_map.append((tf_name, matching_name))
    mapped_tf_names.add(tf_name)

# pass weights
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)

# Compare TF and TORCH models
# check the first layer outputs
model.eval()
dummy_input_torch = torch.ones((1, 80, 10))
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
# rearrange the torch-shaped input into the 4D layout fed to the TF layers
dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1])
dummy_input_tf = tf.expand_dims(dummy_input_tf, 2)

out_torch = model.layers[0](dummy_input_torch)
out_tf = model_tf.model_layers[0](dummy_input_tf)
# bring the TF output back to the torch layout before comparing
out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]

assert compare_torch_tf(out_torch, out_tf_) < 1e-5

# check the remaining layers one by one, feeding each the previous output
for i in range(1, len(model.layers)):
    print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}")
    out_torch = model.layers[i](out_torch)
    out_tf = model_tf.model_layers[i](out_tf)
    out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
    diff = compare_torch_tf(out_torch, out_tf_)
    assert diff < 1e-5, diff

# end-to-end check on a random input with inference padding disabled
torch.manual_seed(0)
dummy_input_torch = torch.rand((1, 80, 100))
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
model.inference_padding = 0
model_tf.inference_padding = 0
output_torch = model.inference(dummy_input_torch)
output_tf = model_tf(dummy_input_tf, training=False)
final_diff = compare_torch_tf(output_torch, output_tf)
assert final_diff < 1e-5, final_diff

# save tf model
save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
                args.output_path)
print(' > Model conversion is successfully completed :).')

37 changes: 37 additions & 0 deletions TTS/bin/convert_tacotron2_tflite.py
@@ -0,0 +1,37 @@
# Convert Tensorflow Tacotron2 model to TF-Lite binary

import argparse

from TTS.tts.utils.io import load_config
from TTS.tts.utils.text.symbols import symbols, phonemes
from TTS.tf.utils.generic_utils import setup_model
from TTS.tf.utils.io import load_checkpoint
from TTS.tf.utils.tflite import convert_tacotron2_to_tflite


# Command-line interface. All three paths are needed by the steps below, so
# they are marked as required to fail with a clear argparse message instead
# of passing None into the loaders.
parser = argparse.ArgumentParser()
parser.add_argument('--tf_model',
                    type=str,
                    required=True,
                    help='Path to the TF Tacotron2 checkpoint to be converted to TF-Lite.')
parser.add_argument('--config_path',
                    type=str,
                    required=True,
                    help='Path to the config file of the model.')
parser.add_argument('--output_path',
                    type=str,
                    required=True,
                    help='path to tflite output binary.')
args = parser.parse_args()

# Set constants
CONFIG = load_config(args.config_path)

# load the model
c = CONFIG
num_speakers = 0
# the input vocabulary depends on whether the model was trained on phonemes
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
model.build_inference()
model = load_checkpoint(model, args.tf_model)
# set the decoder's step limit to 1000 for the exported model
model.decoder.set_max_decoder_steps(1000)

# create tflite model; output_path is handed to the converter
tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)

0 comments on commit 9033070

Please sign in to comment.