In [1]:
import sys

# Make modules in the parent directory (relative to the kernel's working
# directory) importable -- the next cell imports `models.duration_model` and
# `preprocess` from there. Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
import os
from models.duration_model import DurationModel
from preprocess import *
import tensorflow as tf
import numpy as np
import math
import scipy.stats as stats
from tqdm import *

# Define parameters

In [3]:
# Optimiser / regularisation settings passed to
# DurationModel.build_train_operations in the training cell.
train_parameters = {
    "lr": 0.0001,          # learning rate
    "decay_steps": 300,
    "decay_rate": 0.9,
    "dropout_prob": 0.2
}

# Layer sizes passed to the DurationModel constructor.
model_parameters = {
    "phonemes_embedding_size": 56,
    "speaker_embedding_size": 16,
    "num_dense_layers": 2,
    "dense_layers_units": 16,
    "num_bidirectional_layers": 2,
    "num_bidirectional_units": 16
}

max_len = 200           # padded length of every sentence (phonemes per row)
# Size of the phoneme id space used by the model and the data simulator.
# NOTE(review): an earlier comment said "50 kinds of phonemes", which does not
# match this value -- TODO confirm which one is intended.
input_vocab_size = 500
num_buckets = 50        # number of duration buckets (transition matrix is num_buckets x num_buckets)
batch_size = 2          # number of sentences in one simulated batch
num_speakers = 10       # size of the speaker id space passed to DurationModel
num_steps = 5           # number of training epochs over the simulated batches
# Checkpoint interval in global steps (name presumably meant `save_every`);
# the active training code only saves once, after the last epoch.
save_energy = 2


# Simulate Dataset

In [4]:
"""
Data needed for a single trianing sentence: {
            phonemes: np.ones((2, 200)), # (2,200) means 2 sentences and the maximum length of the sentence
            phonemes_seq_len: [10,10],  # real sentence length
            speaker_ids: 2 * np.ones((2)), # speaker indexes. [1,2] means speaker1 and speaker2
            durations: durations # (2,200) the real time of each phenome  

"""

###########################
# 200 Simulated sentences #
###########################

d_lower, d_upper = 0., 100. 
d_mu, d_sigma = 30., 10.
min_frame_len = 10
asn_upper = np.log(.95 * d_upper)
asn_lower = np.log(min_frame_len/1000.)
sen_n = 200
train_sentences = [create_duration_sentence(max_len,input_vocab_size,d_lower, d_upper, d_mu, d_sigma, 
                            asn_upper, asn_lower, num_buckets, batch_size) for i in range(sen_n/batch_size)]
val_sen_n = 20
val_sentences = [create_duration_sentence(max_len,input_vocab_size,d_lower, d_upper, d_mu, d_sigma, 
                            asn_upper, asn_lower, num_buckets, batch_size) for i in range(val_sen_n/batch_size)]


  log_duration = np.log(duration)


# Train model

In [5]:
# Build the duration-model graph, train for `num_steps` epochs over the
# simulated batches, report train/validation loss after each epoch, log
# summaries for TensorBoard and save one checkpoint at the end.
log_dir = '../log/train_duration_model_notebook/train'
weights_dir = '../weights/train_duration_model_notebook/'

with tf.Session() as sess:
    # Training inputs.
    phonemes = tf.placeholder(tf.int32, [None, max_len])
    phonemes_seq_len = tf.placeholder(tf.int32, [None])
    speaker_ids = tf.placeholder(tf.int32, [None])
    durations = tf.placeholder(tf.int32, [None, max_len])

    # Named inputs for Viterbi prediction; not fed in this cell.
    prediction_phonemes = tf.placeholder(tf.int32, [None, max_len], name="p_phonemes")
    prediction_phonemes_seq_len = tf.placeholder(tf.int32, [None], name="p_seq_len")
    prediction_speaker_ids = tf.placeholder(tf.int32, [None], name="p_spk_ids")
    prediction_params = tf.placeholder(tf.float32, [num_buckets, num_buckets], name="p_params")

    model = DurationModel(
        input_vocab_size, num_speakers,
        num_buckets, model_parameters
    )

    train_op_tf, loss_tf, global_step_tf, summary_tf, logits_tf, transition_params_tf = model.build_train_operations(
        phonemes, phonemes_seq_len, speaker_ids, durations, train_parameters
    )

    prediction = model.viterbi_predict(
        prediction_phonemes, prediction_phonemes_seq_len, prediction_speaker_ids, prediction_params, True
    )

    train_writer = tf.summary.FileWriter(log_dir, sess.graph)

    tf.global_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=3)

    def make_feed(batch):
        """Build the feed_dict for one simulated batch (train or validation)."""
        return {
            phonemes: batch['phonemes'],
            phonemes_seq_len: batch['phonemes_seq_len'],
            # NOTE(review): every example is fed speaker id 1, ignoring any
            # speaker ids the batch may carry -- kept as-is; TODO confirm.
            speaker_ids: np.ones(batch_size, dtype=np.int32),
            durations: batch['durations'],
        }

    # Stays None if no training step runs (e.g. num_steps == 0), so the
    # export below can be skipped instead of raising NameError.
    transition_params = None
    exploded = False
    try:
        for epoch in range(num_steps):
            print('=' * 30)
            print('epoch: %d' % epoch)
            total_loss = 0.
            for s in tqdm(train_sentences):
                _, loss, global_step, summary, logits, transition_params = sess.run([
                    train_op_tf,
                    loss_tf,
                    global_step_tf,
                    summary_tf,
                    logits_tf,
                    transition_params_tf
                ], feed_dict=make_feed(s))
                total_loss += loss
                # The summary was previously fetched but never written.
                train_writer.add_summary(summary, global_step)
                # Detect gradient explosion (ignored during the first 500 steps).
                if loss > 1e8 and global_step > 500:
                    print('loss exploded')
                    exploded = True
                    break
            # NOTE(review): normalised by the number of sentences, not batches;
            # assumes loss_tf is summed over a batch -- TODO confirm.
            print('Train total loss: %s' % (total_loss / sen_n))
            print('Current sentence loss: %s' % loss)

            print('Calculating validation loss...')
            val_loss = 0
            for s in val_sentences:
                val_loss += sess.run(loss_tf, feed_dict=make_feed(s))
            print('Validation loss: %s' % (val_loss / val_sen_n))

            # Previously the break above only left the batch loop and training
            # continued with the next epoch; stop completely instead.
            if exploded:
                break

        # Export the learned transition matrix so it is stored alongside the
        # checkpoint's meta graph under the "trans_params" collection.
        if transition_params is not None:
            trans_params = tf.convert_to_tensor(transition_params, np.float32, name="trans_params")
            tf.add_to_collection("trans_params", trans_params)

        print('saving weights')
        if not os.path.exists(weights_dir):
            os.makedirs(weights_dir)
        saver.save(sess, os.path.join(weights_dir, 'model.ckpt'))
    finally:
        # Always flush TensorBoard events and stop input threads, even on error.
        train_writer.close()
        coord.request_stop()
        coord.join(threads)
    
    

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  0%|          | 0/100 [00:00<?, ?it/s]

epoch: 0


100%|██████████| 100/100 [00:46<00:00,  2.14it/s]


Train total loss: 194.676608499
Current sentence loss: 472.939
Calculating validation loss...


  0%|          | 0/100 [00:00<?, ?it/s]

Validation loss: 161.396176338
epoch: 1


100%|██████████| 100/100 [00:46<00:00,  2.15it/s]


Train total loss: 160.682972133
Current sentence loss: 369.221
Calculating validation loss...


  0%|          | 0/100 [00:00<?, ?it/s]

Validation loss: 126.520114899
epoch: 2


100%|██████████| 100/100 [00:46<00:00,  2.16it/s]


Train total loss: 132.97721373
Current sentence loss: 324.107
Calculating validation loss...


  0%|          | 0/100 [00:00<?, ?it/s]

Validation loss: 111.062709427
epoch: 3


100%|██████████| 100/100 [00:46<00:00,  2.16it/s]


Train total loss: 121.229941581
Current sentence loss: 303.574
Calculating validation loss...


  0%|          | 0/100 [00:00<?, ?it/s]

Validation loss: 104.180547619
epoch: 4


100%|██████████| 100/100 [00:46<00:00,  2.17it/s]


Train total loss: 115.374780056
Current sentence loss: 292.901
Calculating validation loss...
Validation loss: 100.614594841
saving weights


"\n        train_writer.add_summary(summary, global_step)\n\n        # detect gradient explosion\n        if loss > 1e8 and global_step > 500:\n            print('loss exploded')\n            break\n\n        if global_step % save_energy == 0 and global_step != 0:\n\n            print('saving weights')\n            if not os.path.exists('../weights/train_duration_model_notebook/'):\n                os.makedirs('../weights/train_duration_model_notebook/')    \n            saver.save(sess, '../weights/train_duration_model_notebook/model.ckpt', global_step=global_step)\n\n"