In [1]:
# Put the parent directory ("..", relative to the kernel's working directory)
# on the module search path; the next cell imports `utils.*` and `models.*`,
# which presumably live there.
import sys

sys.path.extend([".."])

In [2]:
import os
from utils.dataset_utils import build_dataset
from models.grapheme_to_phoneme import GraphemeToPhoneme
import tensorflow as tf
import numpy as np
import pickle

# Load data

In [3]:
# Load the character/phoneme lookup tables saved alongside the training data.
# The file must be opened in binary mode: pickle streams are bytes, so text
# mode ('r') fails on Python 3 and can corrupt binary pickles on Windows.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open("../data/cmu.pkl", 'rb') as read_file:
    meta = pickle.load(read_file)

# Forward and reverse mappings; char2id / phoneme2id sizes are used below as
# the input/output vocabulary sizes.
char2id = meta["char2id"]
id2char = meta["id2char"]
phoneme2id = meta["phoneme2id"]
id2phoneme = meta["id2phoneme"]

In [4]:
# Training arrays; the training cell reads the "X", "X_seq_len", "Y" and
# "Y_seq_len" entries of this archive.
cmu_npz_path = "../data/cmu_data.npz"
data = np.load(cmu_npz_path)

# Define parameters

In [5]:
# Optimiser / input-pipeline settings. `batch_size` is passed to build_dataset
# and the whole dict is handed to GraphemeToPhoneme in the training cell.
train_parameters = dict(
    lr=0.001,
    decay_steps=1000,
    decay_rate=0.85,
    batch_size=2,
)

# Architecture settings forwarded to GraphemeToPhoneme.
# NOTE(review): a "dropout_prob" of 0.95 reads like a keep probability rather
# than a drop rate -- confirm against the model implementation.
model_parameters = dict(
    embedding_size=80,
    num_units=10,
    num_layers=2,
    dropout_prob=0.95,
    num_beams=1,
)

# Vocabulary sizes are taken from the lookup tables loaded from cmu.pkl.
input_vocab_size, output_vocab_size = len(char2id), len(phoneme2id)

# Shuffle buffer size passed to build_dataset.
shuffle_buffer_size = 1000

# Id of the "<eos>" phoneme, handed to the model as its end token.
end_token = phoneme2id["<eos>"]

# Training-loop controls.
num_steps = 2    # number of training iterations run by the training cell
save_energy = 2  # NOTE(review): used as "checkpoint every N global steps" by the
                 # training cell; the name looks like a typo for `save_every`

# Train model

In [6]:
# Train the model for `num_steps` iterations. Each iteration runs one optimiser
# step and writes the merged summary for TensorBoard. It aborts if the loss
# blows up, and checkpoints every `save_energy` global steps.
log_dir = '../log/train_grapheme_to_phoneme_model_notebook/train'
checkpoint_dir = '../weights/train_grapheme_to_phoneme_model_notebook/'

with tf.Session() as sess:
    # build_dataset's result is indexed by field name below and wired directly
    # into the model graph (presumably batched input tensors -- confirm).
    dataset = build_dataset(
        sess,
        (data["X"], data["X_seq_len"], data["Y"], data["Y_seq_len"]),
        ("X", "X_seq_len", "Y", "Y_seq_len"),
        train_parameters["batch_size"],
        shuffle_buffer_size
    )

    model = GraphemeToPhoneme(
        dataset["X"], dataset["X_seq_len"], input_vocab_size,
        output_vocab_size, end_token, model_parameters,
        dataset["Y"], dataset["Y_seq_len"], train_parameters
    )

    train_writer = tf.summary.FileWriter(log_dir, sess.graph)

    tf.global_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    saver = tf.train.Saver(max_to_keep=3, keep_checkpoint_every_n_hours=3)

    try:
        for _ in range(num_steps):
            # The input batch is consumed by train_op inside this run call, so
            # it is not fetched back into Python (the original did, unused).
            # NOTE(review): global_step is fetched in the same run as train_op,
            # so TF does not guarantee whether the pre- or post-increment value
            # is returned.
            _, global_step, loss, summary = sess.run([
                model.train_op,
                model.global_step,
                model.loss,
                model.summary,
            ])

            train_writer.add_summary(summary, global_step)

            # Detect gradient explosion (only checked after 500 global steps).
            if loss > 1e8 and global_step > 500:
                print('loss exploded')
                break

            if global_step % save_energy == 0 and global_step != 0:
                print('saving weights')
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                saver.save(sess, checkpoint_dir, global_step=global_step)
    finally:
        # Always stop the queue-runner threads and flush pending summaries,
        # even if a training step raises. Otherwise the threads outlive the
        # loop and buffered events may never reach disk.
        coord.request_stop()
        coord.join(threads)
        train_writer.close()

saving weights
