In [1]:
import sys
sys.path.append("..")

In [2]:
import os
from models.frequency_model import FrequencyModel
from preprocess import *
import tensorflow as tf
import numpy as np
from tqdm import *

# Define parametes

In [3]:
train_parameters = {
    "lr": 0.0001,
    "decay_steps": 300,
    "decay_rate": 0.9,
    "dropout_prob": 0.2
}

model_parameters = {
    "phonemes_embedding_size": 16,
    "speaker_embedding_size": 16,
    "num_bidirectional_layers": 2,
    "num_bidirectional_units": 16,
    "conv_widths": [2, 2],  # [1,2] the width of each conv
    "output_dimension": 2 # the number of units in the GRUcell
}


max_len = 512
input_vocab_size = 500 # 50 kinds of phonemes
batch_size = 2
num_speakers = 10
num_buckets = 50
num_steps = 5
save_energy = 2


# Simulate Dataset

In [4]:
voiced_thresh = 1000.
lower,upper = -5000.,5000.
mu,sigma = 0., 500.
sen_n = 200

train_sentences = [create_frequency_sentence(max_len,input_vocab_size,mu,sigma,lower,upper,
                                            input_vocab_size,voiced_thresh,batch_size) for i in range(sen_n/batch_size)]
#val_sen_n = 20
#val_sentence = [create_frame_sentence(max_len,input_vocab_size,mu,sigma,
#                                        lower,upper,batch_size) for i in range(val_sen_n/batch_size)]


# Train model

In [5]:
with tf.Session() as sess:
    phonemes = tf.placeholder(tf.int32, [None, max_len])
    phonemes_seq_len = tf.placeholder(tf.int32, [None])
    speaker_ids = tf.placeholder(tf.int32, [None])
    voiced_target = tf.placeholder(tf.int32, [None, max_len])
    frequency_target = tf.placeholder(tf.float32, [None, max_len])
    
    prediction_phonemes = tf.placeholder(tf.int32, [None, max_len],name="p_phonemes")
    prediction_phonemes_seq_len = tf.placeholder(tf.int32, [None],name="p_seq_len")
    prediction_speaker_ids = tf.placeholder(tf.int32, [None],name="p_spk_ids")

    model = FrequencyModel(
        input_vocab_size, num_speakers,
        model_parameters
    )
    
    train_op_tf, loss_tf, global_step_tf, summary_tf = model.build_train_operations(
        phonemes, phonemes_seq_len, speaker_ids, voiced_target, frequency_target, train_parameters
    )
    
    prediction_voiced, prediction_frequencies = model.build_prediction(
        prediction_phonemes, prediction_phonemes_seq_len, prediction_speaker_ids, True
    )

    train_writer = tf.summary.FileWriter('../log/train_frequency_model_notebook/train', sess.graph)
    
    tf.global_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    saver = tf.train.Saver(max_to_keep=3, keep_checkpoint_every_n_hours=3)
    
    for i in xrange(num_steps):
        print '='*30
        print "epoch:",i
        for s in tqdm(train_sentences):
            out = sess.run([
                train_op_tf,
                loss_tf,
                global_step_tf,
                summary_tf,
            ], feed_dict={
                phonemes: s['phonemes'],  
                phonemes_seq_len: s['phonemes_seq_len'],
                speaker_ids: 1 * np.ones((batch_size)),
                voiced_target: s['voiced_target'], 
                frequency_target: s['frequency_target'] 
            })
            _, loss, global_step, summary  = out

            #print global_step
        print "Train loss:",loss
            
        # detect gradient explosion
        if loss > 1e8 and global_step > 500:
            print('loss exploded')
            break

        #train_writer.add_summary(summary, global_step)
    saver.save(sess, '../weights/train_frequency_model_notebook/model.ckpt')
    print "Weights saved."
    coord.request_stop()
    coord.join(threads)
    


  0%|          | 0/100 [00:00<?, ?it/s]

epoch: 0


100%|██████████| 100/100 [01:31<00:00,  1.09it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

Train loss: 378.317
epoch: 1


100%|██████████| 100/100 [01:31<00:00,  1.10it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

Train loss: 378.175
epoch: 2


100%|██████████| 100/100 [01:31<00:00,  1.10it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

Train loss: 378.079
epoch: 3


100%|██████████| 100/100 [01:31<00:00,  1.10it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

Train loss: 378.069
epoch: 4


100%|██████████| 100/100 [01:31<00:00,  1.10it/s]


Train loss: 378.068
