In [1]:
# import argparse
from datetime import datetime
import math
import os
import subprocess
import time
import tensorflow as tf
import traceback

from datasets.new_datafeeder import get_dataset
from hparams import hparams, hparams_debug_string
from models import create_model
from text import sequence_to_text
from util import audio, infolog, plot, ValueWindow
log = infolog.log

In [2]:
class Args:
    def __init__(self):
        self.base_dir = './'
        self.input = 'training/train.txt'
        self.data_dir = os.path.dirname(self.input)
        self.model = 'tacotron'
        self.name = ''
        self.hparams = ''
        self.summary_interval = 100
        self.checkpoint_interval = 1000
        self.tf_log_level = 1    
        self.slack_url = ''
        self.git= ''
        self.max_iter = int(1e5)

In [3]:
args = Args()

In [4]:
def time_string():
    return datetime.now().strftime('%Y-%m-%d %H:%M')

In [5]:
run_name = args.name or args.model
log_dir = os.path.join(args.base_dir, 'logs-%s' % run_name)
os.makedirs(log_dir, exist_ok=True)
infolog.init(os.path.join(log_dir, 'train.log'), run_name, args.slack_url)
hparams.parse(args.hparams)

HParams([('adam_beta1', 0.9), ('adam_beta2', 0.999), ('attention_depth', 256), ('batch_size', 32), ('cleaners', 'english_cleaners'), ('decay_learning_rate', True), ('decoder_depth', 256), ('embed_depth', 256), ('encoder_depth', 256), ('frame_length_ms', 50), ('frame_shift_ms', 12.5), ('griffin_lim_iters', 60), ('initial_learning_rate', 0.002), ('max_iters', 200), ('min_level_db', -100), ('num_freq', 1025), ('num_mels', 80), ('outputs_per_step', 5), ('postnet_depth', 256), ('power', 1.5), ('preemphasis', 0.97), ('prenet_depths', [256, 128]), ('ref_level_db', 20), ('sample_rate', 20000), ('use_cmudict', False)])

In [6]:
with open(args.input, encoding='utf-8') as f:
    metadata = [row.strip().split('|') for row in f]
metadata = sorted(metadata, key=lambda x:x[2])

In [7]:
data_element = get_dataset(metadata, args.data_dir, hparams)

In [8]:
checkpoint_path = os.path.join(log_dir, 'model.ckpt')
global_step = tf.Variable(0, name='global_step', trainable=False)
with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(data_element['input'], 
                     data_element['input_lengths'], 
                     data_element['mel_targets'], 
                     data_element['linear_targets'])
    model.add_loss()
    model.add_optimizer(global_step)

Instructions for updating:
seq_dim is deprecated, use seq_axis instead
Instructions for updating:
batch_dim is deprecated, use batch_axis instead
Initialized Tacotron model. Dimensions: 
    embedding:                  256
    prenet out:                 128
    encoder out:                256
    attention out:              256
    concat attn & out:          512
    decoder cell out:           256
    decoder out (5 frames):     400
    decoder out (1 frame):      80
    postnet out:                256
    linear out:                 1025


In [9]:
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

In [10]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [11]:
for _ in range(args.max_iter):

    start_time = time.time()
    step, mel_loss, lin_loss, loss, opt = sess.run([global_step, model.mel_loss, model.linear_loss, model.loss, model.optimize])
    end_time = time.time()

    message = 'Step %07d [%.03f sec/step, loss = %.05f (mel : %.05f + lin : %.05f)]' % (
               step, end_time - start_time, loss, mel_loss, lin_loss)

    print(message)

    if loss > 100 or math.isnan(loss):
        print('Loss exploded to %.05f at step %d!' % (loss, step))
        raise Exception('Loss Exploded')

    if step % args.checkpoint_interval == 0:
        print('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
        saver.save(sess, checkpoint_path, global_step=step)

        print('Saving audio and alignment...')
        input_seq, spectrogram, alignment = sess.run([model.inputs[0], model.linear_outputs[0], model.alignments[0]])
        waveform = audio.inv_spectrogram(spectrogram.T)
        audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
        plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
            info='%s, %s, step=%d, loss=%.5f' % (args.model, time_string(), step, loss))

        print('Input: %s' % sequence_to_text(input_seq))

Step 0000001 [7.009 sec/step, loss = 0.81838 (mel : 0.34616 + lin : 0.47223)]
Step 0000002 [1.271 sec/step, loss = 0.79893 (mel : 0.34017 + lin : 0.45877)]
Step 0000003 [0.689 sec/step, loss = 0.80719 (mel : 0.33536 + lin : 0.47183)]
Step 0000004 [0.270 sec/step, loss = 0.83416 (mel : 0.35246 + lin : 0.48170)]
Step 0000005 [0.710 sec/step, loss = 0.82291 (mel : 0.34426 + lin : 0.47865)]
Step 0000006 [0.729 sec/step, loss = 0.83648 (mel : 0.34804 + lin : 0.48844)]
Step 0000007 [0.735 sec/step, loss = 0.82933 (mel : 0.35099 + lin : 0.47834)]
Step 0000008 [0.320 sec/step, loss = 0.83063 (mel : 0.34853 + lin : 0.48211)]
Step 0000009 [0.790 sec/step, loss = 0.83021 (mel : 0.34875 + lin : 0.48146)]
Step 0000010 [0.764 sec/step, loss = 0.83449 (mel : 0.34874 + lin : 0.48575)]
Step 0000011 [0.765 sec/step, loss = 0.83960 (mel : 0.35086 + lin : 0.48874)]
Step 0000012 [0.349 sec/step, loss = 0.82084 (mel : 0.34373 + lin : 0.47711)]
Step 0000013 [0.794 sec/step, loss = 0.82258 (mel : 0.33958 + li

Step 0000107 [3.354 sec/step, loss = 0.33567 (mel : 0.15154 + lin : 0.18413)]
Step 0000108 [0.914 sec/step, loss = 0.33284 (mel : 0.15318 + lin : 0.17965)]
Step 0000109 [1.085 sec/step, loss = 0.33220 (mel : 0.15406 + lin : 0.17813)]
Step 0000110 [3.424 sec/step, loss = 0.32667 (mel : 0.15160 + lin : 0.17507)]
Step 0000111 [0.892 sec/step, loss = 0.32189 (mel : 0.15047 + lin : 0.17142)]
Step 0000112 [3.432 sec/step, loss = 0.31937 (mel : 0.15037 + lin : 0.16900)]
Step 0000113 [1.225 sec/step, loss = 0.31614 (mel : 0.14946 + lin : 0.16668)]
Step 0000114 [1.058 sec/step, loss = 0.31445 (mel : 0.14934 + lin : 0.16512)]
Step 0000115 [3.575 sec/step, loss = 0.30945 (mel : 0.14826 + lin : 0.16119)]
Step 0000116 [1.126 sec/step, loss = 0.30732 (mel : 0.14911 + lin : 0.15821)]
Step 0000117 [1.094 sec/step, loss = 0.30369 (mel : 0.14553 + lin : 0.15816)]
Step 0000118 [3.621 sec/step, loss = 0.30244 (mel : 0.14633 + lin : 0.15612)]
Step 0000119 [3.669 sec/step, loss = 0.29929 (mel : 0.14539 + li

Step 0000213 [1.496 sec/step, loss = 0.24512 (mel : 0.12978 + lin : 0.11534)]
Step 0000214 [1.613 sec/step, loss = 0.24205 (mel : 0.12922 + lin : 0.11283)]
Step 0000215 [1.378 sec/step, loss = 0.24348 (mel : 0.12925 + lin : 0.11423)]
Step 0000216 [1.334 sec/step, loss = 0.24255 (mel : 0.12922 + lin : 0.11333)]
Step 0000217 [1.346 sec/step, loss = 0.24303 (mel : 0.12838 + lin : 0.11464)]
Step 0000218 [1.281 sec/step, loss = 0.24123 (mel : 0.12776 + lin : 0.11347)]
Step 0000219 [1.452 sec/step, loss = 0.24220 (mel : 0.12949 + lin : 0.11271)]
Step 0000220 [1.759 sec/step, loss = 0.24161 (mel : 0.12770 + lin : 0.11391)]
Step 0000221 [1.363 sec/step, loss = 0.24324 (mel : 0.12869 + lin : 0.11456)]
Step 0000222 [1.541 sec/step, loss = 0.24258 (mel : 0.12921 + lin : 0.11337)]
Step 0000223 [1.505 sec/step, loss = 0.24262 (mel : 0.12952 + lin : 0.11309)]
Step 0000224 [1.614 sec/step, loss = 0.24235 (mel : 0.12810 + lin : 0.11425)]
Step 0000225 [1.615 sec/step, loss = 0.24111 (mel : 0.12811 + li

KeyboardInterrupt: 