In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import math
import os
from scripts import text_utils
from scripts import language_model
from scripts import embed_meta
%load_ext autoreload
%autoreload 2

In [2]:
def get_data(num_lines=None):
  """Load the books corpus and convert it to a flat sequence of token ids.

  Args:
    num_lines: if given, only the first `num_lines` items returned by
      `text_utils.get_books_data()` are used (presumably lines of text --
      the helper's return type is not visible here). `None` uses everything.

  Returns:
    dict with keys:
      'word_dict': the `text_utils.WordDict` built from the cleaned text.
      'text_data': `word_dict.tokens_to_ids(text)`; it is indexed with
        `.shape[0]` below, so presumably a 1-D numpy array of ids.
  """
  books_data = text_utils.get_books_data()
  # Identity check is the idiomatic sentinel test and is not affected by an
  # overloaded `__ne__` the way `!= None` is.
  if num_lines is not None:
    books_data = books_data[:num_lines]
  # The cleaned corpus is tokenized by splitting on single spaces.
  text = text_utils.clean_up_text(books_data).split(' ')
  word_dict = text_utils.WordDict(text)
  text_data = word_dict.tokens_to_ids(text)
  vocab_size = word_dict.get_vocab_size()
  print("vocab size is {}".format(vocab_size))
  print("total number of words is {}".format(text_data.shape[0]))
  return {
    'word_dict': word_dict,
    'text_data': text_data
  }

In [3]:
def get_batch(text_data, batch_size, T):
  """Sample `batch_size` random contiguous windows for next-token prediction.

  Each window spans T+1 consecutive ids. The first T ids become the inputs
  and the last T (shifted by one) become the labels, so labels[b, t] is the
  token that follows input_data[b, t].

  Args:
    text_data: 1-D array of token ids.
    batch_size: number of windows to sample (starts drawn with replacement).
    T: number of time steps per sample.

  Returns:
    (input_data, labels): two int64 arrays of shape (batch_size, T).

  Raises:
    ValueError: if `text_data` has fewer than T+1 tokens.
  """
  n = text_data.shape[0]
  if n < T + 1:
    raise ValueError(
        "text_data has {} tokens; need at least T+1={}".format(n, T + 1))
  # Valid window starts are 0..n-T-1 inclusive (window [s, s+T] must fit).
  # randint's upper bound is exclusive, so this includes the last window,
  # which the previous range(0, n-T-1) silently excluded. It also avoids
  # materializing a corpus-sized candidate array on every call.
  starts = np.random.randint(0, n - T, size=batch_size)
  # Gather all windows at once: row b is text_data[starts[b]:starts[b]+T+1].
  window_idx = starts[:, None] + np.arange(T + 1)
  data = np.asarray(text_data)[window_idx].astype(np.int64)
  input_data = data[:, :-1]
  labels = data[:, 1:]
  return input_data, labels

In [4]:
# Load the whole corpus (no line limit); the vocabulary is reused by later cells.
data = get_data(num_lines=None)
word_dict = data['word_dict']

total number of lines: 258521
vocab size is 8000
total number of words is 2715459


In [5]:
# Contiguous train/validation split of the token stream: validation takes the
# tokens immediately after the training slice, so the two never overlap.
num_train = 2000000
num_val   =  400000
num_tokens = data['text_data'].shape[0]
# Slicing past the end would silently shrink the validation set; fail loudly.
assert num_train + num_val <= num_tokens, (
    "train+val split ({}) exceeds corpus size ({})".format(
        num_train + num_val, num_tokens))
text_data_trn = data['text_data'][:num_train]
text_data_val = data['text_data'][num_train:num_train+num_val]
print(text_data_trn.shape)
print(text_data_val.shape)

(2000000,)


In [16]:
# LOG_DIR: root for TensorBoard event files and checkpoints.
# model_count: bumped once per training run so each run gets its own
# "trn<k>"/"val<k>" log subdirectories.
LOG_DIR, model_count = './results/', 0

In [17]:
batch_size = 20
T = 40 # number of words in one sample

g = tf.Graph()
with g.as_default():
  # Step counter handed to the model (presumably incremented by its train_op;
  # confirm in language_model). trainable=False keeps it out of
  # tf.trainable_variables(), so it is not treated as a model parameter by
  # optimizers or gradient code. The default name is kept so checkpoints
  # written by earlier runs still restore.
  global_step = tf.Variable(0, trainable=False)
  # Variables created via tf.get_variable in this scope without their own
  # initializer default to U(-0.04, 0.04).
  with tf.variable_scope("Model", initializer=tf.random_uniform_initializer(-0.04, 0.04)):
    model = language_model.LanguageModel(batch_size, T, global_step, word_dict.get_vocab_size())
  tf.summary.scalar("loss", model.loss)
  merged = tf.summary.merge_all()

In [18]:
# Sanity check: decode the first of two random training windows back to words.
batch_data, _ = get_batch(text_data_trn, batch_size=2, T=T)
first_window_tokens = word_dict.ids_to_tokens(batch_data[0])
print(first_window_tokens)

['short', 'as', 'never', 'dragon', '<unk>', '<unk>', 'or', 'other', 'animal', 'of', 'that', 'species', 'presided', 'over', 'since', 'they', 'first', 'began', 'to', 'interest', 'themselves', 'in', 'household', 'affairs', '.', '\n', 'an', 'old', 'gentleman', 'and']


In [19]:
NUM_ITERS = 20000
LOG_EVERY = 10        # print train loss and evaluate one validation batch
PREDICT_EVERY = 200   # print predictions for a few validation samples
SAVE_EVERY = 5000     # write a checkpoint and the embedding metadata


def _save_checkpoint(sess, saver, writer, step, run_id):
  """Save a checkpoint for run `run_id` at `step` and write embedding metadata."""
  saver.save(sess, os.path.join(LOG_DIR, 'language_model{}.ckpt'.format(run_id)),
             global_step=step)
  embed_meta.write_embed_meta(LOG_DIR, word_dict, model.embeddings, writer)


with tf.Session(graph=g) as sess:
  # Graph-level writer: records the graph and is reused for embedding metadata.
  writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
  fetches_trn = {
    "loss": model.loss,
    "perplexity": model.perplexity,
    "train_op": model.train_op,
    "summary": merged
  }
  # Same as fetches_trn but without train_op, so evaluation does not update weights.
  fetches_val = {
    "loss": model.loss,
    "perplexity": model.perplexity,
    "summary": merged
  }
  model_count += 1
  trn_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, "trn"+str(model_count)))
  val_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, "val"+str(model_count)))
  # Build the Saver once. Constructing it inside the loop adds a new set of
  # save/restore ops to the graph at every checkpoint.
  saver = tf.train.Saver()

  sess.run(tf.global_variables_initializer())

  try:
    for i in range(NUM_ITERS):
      input_data, labels = get_batch(text_data_trn, batch_size, T)
      feed_dict_trn = {
        model.input_data: input_data,
        model.labels: labels,
        model.is_training: True
      }
      result = sess.run(fetches_trn, feed_dict_trn)
      trn_writer.add_summary(result['summary'], i)

      do_eval = i % LOG_EVERY == 0
      do_predict = i % PREDICT_EVERY == 0
      if do_eval:
        print("iter: {}, loss_trn: {:.4f}, perplexity_trn: {:.2f}".format(
            i, result['loss'], result['perplexity']))
      # Sample one validation batch whenever it is needed this iteration. The
      # prediction display therefore no longer relies on PREDICT_EVERY being a
      # multiple of LOG_EVERY to reuse a fresh validation feed.
      if do_eval or do_predict:
        val_inputs, val_labels = get_batch(text_data_val, batch_size, T)
        feed_dict_val = {
          model.input_data: val_inputs,
          model.labels: val_labels,
          model.is_training: False
        }
      if do_eval:
        result = sess.run(fetches_val, feed_dict_val)
        print("iter: {}, loss_val: {:.4f}, perplexity_val: {:.2f}".format(
            i, result['loss'], result['perplexity']))
        val_writer.add_summary(result['summary'], i)
      if do_predict: # show predictions on the validation batch
        predictions = sess.run(model.predictions, feed_dict_val)
        for data_id, prediction in enumerate(predictions[:5]):
          print(word_dict.ids_to_tokens(np.concatenate(
              [val_inputs[data_id][:-1],
               np.array(prediction)])))
      if i != 0 and i % SAVE_EVERY == 0:
        _save_checkpoint(sess, saver, writer, i, model_count)

    # Save once more at the end so the iterations after the last periodic
    # checkpoint are not lost (unless that step was just saved in the loop).
    last_step = NUM_ITERS - 1
    already_saved = last_step != 0 and last_step % SAVE_EVERY == 0
    if NUM_ITERS > 0 and not already_saved:
      _save_checkpoint(sess, saver, writer, last_step, model_count)
  finally:
    # Flush and close every event file, even when training is interrupted.
    for summary_writer in (writer, trn_writer, val_writer):
      summary_writer.close()
      embed_meta.write_embed_meta(LOG_DIR, word_dict, model.embeddings, writer)

iter: 0, loss_trn: 8.9872, perplexity_trn: 7999.99
iter: 0, loss_val: 8.9744, perplexity_val: 7898.42
['in', 'a', 'warm', 'corner', 'of', 'which', 'he', 'stretched', 'his', 'weary', 'limbs', 'and', 'soon', 'fell', 'asleep', '.', '\n', 'when', 'he', 'awoke', 'next', 'morning', 'and', 'tried', 'to', 'recollect', 'his', 'dreams', 'which', 'got', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that']
['the', 'glass', '.', 'and', 'thus', 'he', 'knew', 'and', 'by', 'returning', 'it', 'gave', 'mr', 'pecksniff', 'the', 'information', 'that', 'he', 'knew', 'where', 'the', 'listener', 'had', 'been', ';', 'and', 'that', 'instead', 'of', 'hat', 'the', 'the', 'the', 'the', 'the', 'the', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that']
['rear', 'who', 'were', 'pressing', 'forward', 'to', 'get', 'out', 'of', 'the', 'way', 'but', 'were'

KeyboardInterrupt: 