In [2]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import math
import os
from scripts import text_utils
from scripts import language_model
from scripts import embed_meta
%load_ext autoreload
%autoreload 2

In [4]:
def get_data(num_lines):
  """Load, clean and tokenize a prefix of the books corpus.

  Takes the first `num_lines` items returned by `text_utils.get_books_data()`,
  cleans them with `text_utils.clean_up_text`, splits the result on single
  spaces, and maps the tokens to ids through a `text_utils.WordDict`.
  Prints the vocabulary size and the number of tokens as a side effect.

  Args:
    num_lines: upper bound of the slice taken from the books data.

  Returns:
    A dict with two entries:
      'word_dict': the `WordDict` built from the tokens.
      'text_data': the output of `word_dict.tokens_to_ids` for those tokens
        (must expose `.shape[0]`).
  """
  raw_lines = text_utils.get_books_data()[:num_lines]
  tokens = text_utils.clean_up_text(raw_lines).split(' ')

  vocab = text_utils.WordDict(tokens)
  token_ids = vocab.tokens_to_ids(tokens)

  # Report corpus statistics so the reader can sanity-check the split sizes.
  print("vocab size is {}".format(vocab.get_vocab_size()))
  print("total number of words is {}".format(token_ids.shape[0]))

  return dict(word_dict=vocab, text_data=token_ids)

In [5]:
def get_batch(text_data, batch_size, T):
  """Sample a random batch of next-token-prediction windows.

  Each row is a run of T+1 consecutive ids taken from `text_data` at a
  uniformly random start position (drawn with replacement via the global
  NumPy RNG). The first T ids of a row are the inputs; the same row shifted
  by one position gives the labels.

  Args:
    text_data: 1-D NumPy array of token ids.
    batch_size: number of windows (rows) to sample.
    T: number of tokens per input sample.

  Returns:
    (input_data, labels): two int64 arrays of shape [batch_size, T], where
    labels[i, t] is the token that follows input_data[i, t] in `text_data`.

  Raises:
    ValueError: if `text_data` holds fewer than T+1 tokens.
  """
  total_tokens = text_data.shape[0]
  # A window starting at s covers s .. s+T, so valid starts are 0 .. N-T-1
  # inclusive, i.e. N-T windows. (The upper bound is exclusive in randint.)
  num_windows = total_tokens - T
  if num_windows < 1:
    raise ValueError(
        "text_data has {} tokens but at least {} are required for T={}".format(
            total_tokens, T + 1, T))
  starts = np.random.randint(0, num_windows, size=batch_size)

  # (batch_size, 1) + (T+1,) broadcasts to the index of every token of every
  # window, so a single fancy-index gathers the whole batch.
  window_index = starts[:, np.newaxis] + np.arange(T + 1)
  data = text_data[window_index].astype(np.int64)

  input_data = data[:, :-1]
  labels = data[:, 1:]
  return input_data, labels

In [6]:
# Tokenize only the first `num_lines` items of the corpus to keep the run small.
num_lines = 10000
data = get_data(num_lines)
word_dict = data['word_dict']

# Hold-out split by position: the first 80000 token ids are used for training,
# everything after them is kept for validation.
text_data_trn, text_data_val = (
  data['text_data'][:80000],
  data['text_data'][80000:],
)

total number of lines: 258521
vocab size is 7632
total number of words is 132578


In [31]:
batch_size = 64
T = 20 # number of words in one sample

# Build the model in its own graph so re-running this cell starts from a clean
# graph instead of piling ops onto the default one.
g = tf.Graph()
with g.as_default():
  # The step counter is bookkeeping, not a model weight: mark it non-trainable
  # so it stays out of the trainable-variables collection.
  global_step = tf.Variable(0, trainable=False)
  model = language_model.LanguageModel(batch_size, T, global_step, word_dict.get_vocab_size())

In [32]:
# Output directory passed to the summary FileWriter, the checkpoint Saver and
# embed_meta.write_embed_meta in the training cell.
LOG_DIR = './results/'

In [33]:
# Train the language model for a fixed number of steps, print a few sample
# predictions, then save a checkpoint and the embedding metadata to LOG_DIR.
with tf.Session(graph=g) as sess:
  writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
  try:
    num_iters = 200  # number of training steps
    log_every = 5    # print loss/perplexity every this many steps
    fetches = {
      "loss": model.loss,
      "perplexity": model.perplexity,
      "train_op": model.train_op
    }

    sess.run(tf.global_variables_initializer())
    for i in range(num_iters):
      # Use the same batch_size the model was constructed with, rather than a
      # separate literal that can silently drift from it.
      input_data, labels = get_batch(text_data_trn, batch_size, T)
      feed_dict = {
        model.input_data: input_data,
        model.labels: labels
      }
      result = sess.run(fetches, feed_dict)
      if i % log_every == 0:
        print("iter: {}, loss: {:.4f}, perplexity: {:.2f}".format(
            i, result['loss'], result['perplexity']))

    # make prediction on a fresh batch drawn from the training data
    input_data, labels = get_batch(text_data_trn, batch_size, T)
    feed_dict = {
      model.input_data: input_data,
      model.labels: labels
    }
    predictions = sess.run(model.predictions, feed_dict)
    for prediction in predictions[:5]:
      print(word_dict.ids_to_tokens(prediction))

    # Entering the session also makes `g` the default graph, so the Saver
    # created here covers the model's variables.
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(LOG_DIR, 'language_model.ckpt'), global_step=global_step)
    # NOTE(review): write_embed_meta lives in scripts/embed_meta and is not
    # visible here; it receives the writer, so it must run before close().
    embed_meta.write_embed_meta(LOG_DIR, word_dict, model.embeddings, writer)
  finally:
    # Flush pending events and release the event-file handle even if a step
    # above raised.
    writer.close()

iter: 0, loss: 8.9401, perplexity: 7632.07
iter: 5, loss: 8.9318, perplexity: 7569.23
iter: 10, loss: 8.9104, perplexity: 7408.87
iter: 15, loss: 8.7289, perplexity: 6178.66
iter: 20, loss: 7.7738, perplexity: 2377.53
iter: 25, loss: 6.9199, perplexity: 1012.25
iter: 30, loss: 6.5554, perplexity: 703.02
iter: 35, loss: 6.4610, perplexity: 639.70
iter: 40, loss: 6.3000, perplexity: 544.57
iter: 45, loss: 6.0863, perplexity: 439.78
iter: 50, loss: 6.3244, perplexity: 558.00
iter: 55, loss: 6.1888, perplexity: 487.28
iter: 60, loss: 6.1271, perplexity: 458.12
iter: 65, loss: 6.2029, perplexity: 494.18
iter: 70, loss: 6.2022, perplexity: 493.85
iter: 75, loss: 6.1959, perplexity: 490.73
iter: 80, loss: 6.0840, perplexity: 438.79
iter: 85, loss: 6.0986, perplexity: 445.22
iter: 90, loss: 6.2179, perplexity: 501.64
iter: 95, loss: 6.1012, perplexity: 446.41
iter: 100, loss: 6.1575, perplexity: 472.24
iter: 105, loss: 6.0532, perplexity: 425.46
iter: 110, loss: 6.1305, perplexity: 459.67
iter