## Importing Dependencies

In [51]:
#These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
%matplotlib inline
from __future__ import print_function
import collections
import math
import numpy as np
import os
import random
import tensorflow as tf
import zipfile
from matplotlib import pylab
from six.moves import range
from six.moves.urllib.request import urlretrieve
import tensorflow as tf
import csv

## Downloading Stories
Stories are automatically downloaded from https://www.cs.cmu.edu/~spok/grimmtmp/, if not detected in the disk. The total size of stories is around ~500KB. The dataset consists of 100 stories.

In [52]:
url = 'https://www.cs.cmu.edu/~spok/grimmtmp/'

#create a directory if needed
dir_name = 'stories'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)
    
def download_data(filename):
    """Download a file if not present"""
    print('Downloading file: ', dir_name+os.sep+filename)
    
    if not os.path.exists(dir_name + os.sep+filename):
        filename, _ = urlretrieve(url + filename, dir_name+os.sep+filename)
    else: 
        print('file ', filename, ' already exits')
        
    return filename

num_files = 209 

filenames = [format(i, '03d')+'.txt' for i in range(1, num_files+1)]

for fn in filenames:
    download_data(fn)

Downloading file:  stories/001.txt
file  001.txt  already exits
Downloading file:  stories/002.txt
file  002.txt  already exits
Downloading file:  stories/003.txt
file  003.txt  already exits
Downloading file:  stories/004.txt
file  004.txt  already exits
Downloading file:  stories/005.txt
file  005.txt  already exits
Downloading file:  stories/006.txt
file  006.txt  already exits
Downloading file:  stories/007.txt
file  007.txt  already exits
Downloading file:  stories/008.txt
file  008.txt  already exits
Downloading file:  stories/009.txt
file  009.txt  already exits
Downloading file:  stories/010.txt
file  010.txt  already exits
Downloading file:  stories/011.txt
file  011.txt  already exits
Downloading file:  stories/012.txt
file  012.txt  already exits
Downloading file:  stories/013.txt
file  013.txt  already exits
Downloading file:  stories/014.txt
file  014.txt  already exits
Downloading file:  stories/015.txt
file  015.txt  already exits
Downloading file:  stories/016.txt
file 

In [53]:
## check if the files are downloaded 
for i in range(len(filenames)):
    file_exists = os.path.isfile(os.path.join(dir_name, filenames[i]))
    assert file_exists
print('%d files found.' %len(filenames))

209 files found.


## Reading data 
Data will be stored in a list of lists where each list represents document and a document is a list of words. we will then break the text into bigrams

In [54]:
def read_data(filename):
    
    with open(filename) as f:
        data = tf.compat.as_str(f.read())
        #make all the words lowercaser
        data = data.lower()
        data = list(data)
    return data

documents = []
global documents
for i in range(num_files):
    print('\nProcessing file %s' %os.path.join(dir_name, filenames[i]))
    chars = read_data(os.path.join(dir_name, filenames[i]))
    
    #break the data into bigrams
    two_grams = [''.join(chars[ch_i:ch_i+2]) for ch_i in range(0, len(chars)-2, 2)]
    #create a list of lists with bigrams
    documents.append(two_grams)
    print('Data size (characters) (Document %d) %d' %(i, len(two_grams)))
    print('sample string (Documents %d) %s' %(i, two_grams[:50]))


Processing file stories/001.txt
Data size (characters) (Document 0) 3667
sample string (Documents 0) ['in', ' o', 'ld', 'en', ' t', 'im', 'es', ' w', 'he', 'n ', 'wi', 'sh', 'in', 'g ', 'st', 'il', 'l ', 'he', 'lp', 'ed', ' o', 'ne', ', ', 'th', 'er', 'e ', 'li', 've', 'd ', 'a ', 'ki', 'ng', '\nw', 'ho', 'se', ' d', 'au', 'gh', 'te', 'rs', ' w', 'er', 'e ', 'al', 'l ', 'be', 'au', 'ti', 'fu', 'l,']

Processing file stories/002.txt
Data size (characters) (Document 1) 4928
sample string (Documents 1) ['ha', 'rd', ' b', 'y ', 'a ', 'gr', 'ea', 't ', 'fo', 're', 'st', ' d', 'we', 'lt', ' a', ' w', 'oo', 'd-', 'cu', 'tt', 'er', ' w', 'it', 'h ', 'hi', 's ', 'wi', 'fe', ', ', 'wh', 'o ', 'ha', 'd ', 'an', '\no', 'nl', 'y ', 'ch', 'il', 'd,', ' a', ' l', 'it', 'tl', 'e ', 'gi', 'rl', ' t', 'hr', 'ee']

Processing file stories/003.txt
Data size (characters) (Document 2) 9745
sample string (Documents 2) ['a ', 'ce', 'rt', 'ai', 'n ', 'fa', 'th', 'er', ' h', 'ad', ' t', 'wo', ' s', 'on', 's,',

Data size (characters) (Document 23) 2529
sample string (Documents 23) ['th', 'e ', 'mo', 'th', 'er', ' o', 'f ', 'ha', 'ns', ' s', 'ai', 'd,', ' w', 'hi', 'th', 'er', ' a', 'wa', 'y,', ' h', 'an', 's.', '  ', 'ha', 'ns', ' a', 'ns', 'we', 're', 'd,', ' t', 'o\n', 'gr', 'et', 'el', '. ', ' b', 'eh', 'av', 'e ', 'we', 'll', ', ', 'ha', 'ns', '. ', ' o', 'h,', ' i', "'l"]

Processing file stories/025.txt
Data size (characters) (Document 24) 2416
sample string (Documents 24) ['an', ' a', 'ge', 'd ', 'co', 'un', 't ', 'on', 'ce', ' l', 'iv', 'ed', ' i', 'n ', 'sw', 'it', 'ze', 'rl', 'an', 'd,', ' w', 'ho', ' h', 'ad', ' a', 'n ', 'on', 'ly', ' s', 'on', ',\n', 'bu', 't ', 'he', ' w', 'as', ' s', 'tu', 'pi', 'd,', ' a', 'nd', ' c', 'ou', 'ld', ' l', 'ea', 'rn', ' n', 'ot']

Processing file stories/026.txt
Data size (characters) (Document 25) 3369
sample string (Documents 25) ['th', 'er', 'e ', 'wa', 's ', 'on', 'ce', ' a', ' m', 'an', ' w', 'ho', ' h', 'ad', ' a', ' d', 'au', 'gh', 'te', 'r

Data size (characters) (Document 50) 5608
sample string (Documents 50) ['\ta', 'll', 'er', 'le', 'ir', 'au', 'h\n', '\nt', 'he', 're', ' w', 'as', ' o', 'nc', 'e ', 'up', 'on', ' a', ' t', 'im', 'e ', 'a ', 'ki', 'ng', ' w', 'ho', ' h', 'ad', ' a', ' w', 'if', 'e ', 'wi', 'th', ' g', 'ol', 'de', 'n ', 'ha', 'ir', ',\n', 'an', 'd ', 'sh', 'e ', 'wa', 's ', 'so', ' b', 'ea']

Processing file stories/052.txt
Data size (characters) (Document 51) 1287
sample string (Documents 51) ['th', 'er', 'e ', 'wa', 's ', 'on', 'ce', ' a', ' w', 'om', 'an', ' a', 'nd', ' h', 'er', ' d', 'au', 'gh', 'te', 'r ', 'wh', 'o ', 'li', 've', 'd ', 'in', ' a', '\np', 're', 'tt', 'y ', 'ga', 'rd', 'en', ' w', 'it', 'h ', 'ca', 'bb', 'ag', 'es', '. ', ' a', 'nd', ' a', ' l', 'it', 'tl', 'e ', 'ha']

Processing file stories/053.txt
Data size (characters) (Document 52) 2841
sample string (Documents 52) ['th', 'er', 'e ', 'wa', 's ', 'on', 'ce', ' a', ' k', 'in', "g'", 's ', 'so', 'n ', 'wh', 'o ', 'ha', 'd ', 'a ',

Data size (characters) (Document 85) 8758
sample string (Documents 85) ['th', 'er', 'e ', 'wa', 's ', 'on', 'ce', ' u', 'po', 'n ', 'a ', 'ti', 'me', ' a', ' k', 'in', 'g ', 'wh', 'o ', 'ha', 'd ', 'a ', 'li', 'tt', 'le', ' b', 'oy', ' i', 'n ', 'wh', 'os', 'e ', 'st', 'ar', 's\n', 'it', ' h', 'ad', ' b', 'ee', 'n ', 'fo', 're', 'to', 'ld', ' t', 'ha', 't ', 'he', ' s']

Processing file stories/087.txt
Data size (characters) (Document 86) 3109
sample string (Documents 86) ['th', 'er', 'e ', 'wa', 's ', 'on', 'ce', ' u', 'po', 'n ', 'a ', 'ti', 'me', ' a', ' p', 'ri', 'nc', 'es', 's ', 'wh', 'o ', 'wa', 's ', 'ex', 'tr', 'em', 'el', 'y ', 'pr', 'ou', 'd.', ' i', 'f ', 'a\n', 'wo', 'oe', 'r ', 'ca', 'me', ' s', 'he', ' g', 'av', 'e ', 'hi', 'm ', 'so', 'me', ' r', 'id']

Processing file stories/088.txt
Data size (characters) (Document 87) 1365
sample string (Documents 87) ['a ', 'ta', 'il', 'or', "'s", ' a', 'pp', 're', 'nt', 'ic', 'e ', 'wa', 's ', 'tr', 'av', 'el', 'in', 'g ', 'ab', 'o

Data size (characters) (Document 126) 5064
sample string (Documents 126) ['a ', 'po', 'or', ' w', 'oo', 'd-', 'cu', 'tt', 'er', ' l', 'iv', 'ed', ' w', 'it', 'h ', 'hi', 's ', 'wi', 'fe', ' a', 'nd', ' t', 'hr', 'ee', ' d', 'au', 'gh', 'te', 'rs', ' i', 'n\n', 'a ', 'li', 'tt', 'le', ' h', 'ut', ' o', 'n ', 'th', 'e ', 'ed', 'ge', ' o', 'f ', 'a ', 'lo', 'ne', 'ly', ' f']

Processing file stories/128.txt
Data size (characters) (Document 127) 10096
sample string (Documents 127) ['th', 'er', 'e ', 'wa', 's ', 'on', 'ce', ' u', 'po', 'n ', 'a ', 'ti', 'me', ' a', ' v', 'er', 'y ', 'ol', 'd ', 'wo', 'ma', 'n,', ' w', 'ho', ' l', 'iv', 'ed', ' w', 'it', 'h ', 'he', 'r\n', 'fl', 'oc', 'k ', 'of', ' g', 'ee', 'se', ' i', 'n ', 'a ', 're', 'mo', 'te', ' c', 'le', 'ar', 'in', 'g ']

Processing file stories/129.txt
Data size (characters) (Document 128) 1845
sample string (Documents 128) ['wh', 'en', ' a', 'da', 'm ', 'an', 'd ', 'ev', 'e ', 'we', 're', ' d', 'ri', 've', 'n ', 'ou', 't ', 'of', '

Data size (characters) (Document 159) 989
sample string (Documents 159) ['th', 'er', 'e ', 'we', 're', ' o', 'nc', 'e ', 'a ', 'co', 'ck', ' a', 'nd', ' a', ' h', 'en', ' w', 'ho', ' w', 'an', 'te', 'd ', 'to', ' t', 'ak', 'e ', 'a ', 'jo', 'ur', 'ne', 'y\n', 'to', 'ge', 'th', 'er', '. ', ' s', 'o ', 'th', 'e ', 'co', 'ck', ' b', 'ui', 'lt', ' a', ' b', 'ea', 'ut', 'if']

Processing file stories/161.txt
Data size (characters) (Document 160) 3114
sample string (Documents 160) ['a ', 'sh', 'ee', 'p-', 'do', 'g ', 'ha', 'd ', 'no', 't ', 'a ', 'go', 'od', ' m', 'as', 'te', 'r,', ' b', 'ut', ', ', 'on', ' t', 'he', ' c', 'on', 'tr', 'ar', 'y,', ' o', 'ne', ' w', 'ho', '\nl', 'et', ' h', 'im', ' s', 'uf', 'fe', 'r ', 'hu', 'ng', 'er', '. ', ' a', 's ', 'he', ' c', 'ou', 'ld']

Processing file stories/162.txt
Data size (characters) (Document 161) 6731
sample string (Documents 161) ['th', 'er', 'e ', 'wa', 's ', 'on', 'ce', ' u', 'po', 'n ', 'a ', 'ti', 'me', ' a', ' m', 'an', ' w', 'ho', ' w

Data size (characters) (Document 202) 681
sample string (Documents 202) ['a ', 'me', 'rc', 'ha', 'nt', ' h', 'ad', ' d', 'on', 'e ', 'go', 'od', ' b', 'us', 'in', 'es', 's ', 'at', ' t', 'he', ' f', 'ai', 'r.', '  ', 'he', ' h', 'ad', ' s', 'ol', 'd ', 'hi', 's\n', 'wa', 're', 's,', ' a', 'nd', ' l', 'in', 'ed', ' h', 'is', ' m', 'on', 'ey', '-b', 'ag', 's ', 'wi', 'th']

Processing file stories/204.txt
Data size (characters) (Document 203) 3501
sample string (Documents 203) ['th', 'is', ' s', 'to', 'ry', ', ', 'my', ' d', 'ea', 'r ', 'yo', 'un', 'g ', 'fo', 'lk', 's,', ' s', 'ee', 'ms', ' t', 'o ', 'be', ' f', 'al', 'se', ', ', 'bu', 't ', 'it', ' r', 'ea', 'll', 'y\n', 'is', ' t', 'ru', 'e,', ' f', 'or', ' m', 'y ', 'gr', 'an', 'df', 'at', 'he', 'r,', ' f', 'ro', 'm ']

Processing file stories/205.txt
Data size (characters) (Document 204) 978
sample string (Documents 204) ['th', 'er', 'e ', 'wa', 's ', 'on', 'ce', ' u', 'po', 'n ', 'a ', 'ti', 'me', ' a', ' f', 'ar', '-s', 'ig', 'ht'

## Building a Dictionaries (Bigrams)
Build the following. to understand each of these elements, let use also assume the test "I like to go to school"
- dictionary : maps a string word to an ID . ({I:O, like:1, to:2, go:3, school:4})
- reverse_dictionary: maps ID to words({0:I, 1:like, 2:to, 3:go, 4:school})
- count: list of list of (word, frequency) elements (e.g [(I,1), (to,2), (go,1), (school,1)])
- data: Contain the string of text we read, where string words are replaced with word IDs e.g [0, 1,2,3,4]

In [55]:
def build_dataset(documents):
    chars = []
    # This is going to be a list of lists
    # Where the outer list denote each document
    # and the inner lists denote words in a given document
    data_list = []
  
    for d in documents:
        chars.extend(d)
    print('%d Characters found.'%len(chars))
    count = []
    # Get the bigram sorted by their frequency (Highest comes first)
    count.extend(collections.Counter(chars).most_common())
    
    # Create an ID for each bigram by giving the current length of the dictionary
    # And adding that item to the dictionary
    # Start with 'UNK' that is assigned to too rare words
    dictionary = dict({'UNK':0})
    for char, c in count:
        # Only add a bigram to dictionary if its frequency is more than 10
        if c > 10:
            dictionary[char] = len(dictionary)    
    
    unk_count = 0
    # Traverse through all the text we have
    # to replace each string word with the ID of the word
    for d in documents:
        data = list()
        for char in d:
            # If word is in the dictionary use the word ID,
            # else use the ID of the special token "UNK"
            if char in dictionary:
                index = dictionary[char]        
            else:
                index = dictionary['UNK']
                unk_count += 1
            data.append(index)
            
        data_list.append(data)
        
    reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 
    return data_list, count, dictionary, reverse_dictionary

global data_list, count, dictionary, reverse_dictionary,vocabulary_size

# Print some statistics about data
data_list, count, dictionary, reverse_dictionary = build_dataset(documents)
print('Most common words (+UNK)', count[:5])
print('Least common words (+UNK)', count[-15:])
print('Sample data', data_list[0][:10])
print('Sample data', data_list[1][:10])
print('Vocabulary: ',len(dictionary))
vocabulary_size = len(dictionary)
del documents  # To reduce memory.

727505 Characters found.
Most common words (+UNK) [('e ', 24554), ('he', 24203), (' t', 21726), ('th', 21031), ('d ', 16995)]
Least common words (+UNK) [('md', 1), ('dt', 1), ('xu', 1), ('x-', 1), ('-.', 1), ('tp', 1), ('-j', 1), ('lg', 1), ('uj', 1), ('kd', 1), ('z.', 1), ('kt', 1), ('oj', 1), ('c-', 1), ('!"', 1)]
Sample data [16, 27, 88, 26, 3, 96, 72, 11, 2, 17]
Sample data [23, 157, 25, 36, 78, 183, 42, 9, 87, 19]
Vocabulary:  572


## Generating Batches of Data
The following object generates a batch of data which will be used to train the LSTM. More specifically the generator breaks a given sequence of words into batch_size segments. we also maintain a cursor of each segment. so whenever we create a bacth of data, we sample one item from each segment and update the cursor of each segment

In [56]:
class DataGeneratorOHE(object):
    
    def __init__(self, text, batch_size, num_unroll):
        #Text where a bigram is denoted by its ID
        self._text = text
        #Number of bigrams in the text
        self._text_size = len(self._text)
        #number of data points ina batch of data
        self._batch_size = batch_size
        #Num unroll is the number of steps we unroll the RNN in a single training step
        self._num_unroll = num_unroll
        #we break the text into several segments and the batch of data is sampled by samplying a single item from a single segemnt
        self._segments = self._text_size // self._batch_size
        self._cursor = [offset * self._segments for offset in range(self._batch_size)]
    
    def next_batch(self):
        '''Generates a single batch of data'''
        #train inputs (one-hot-encoded)  and train outputs (one-hot-encode)
        batch_data = np.zeros((self._batch_size, vocabulary_size), dtype=np.float32)
        batch_labels = np.zeros((self._batch_size, vocabulary_size), dtype=np.float32)
        
        #fill in the batch datapoint by datapoint
        for b in range(self._batch_size):
            #if the cursor of a given segment exceeds the segment length
            #we reset the cursor back to the beginning of that segment
            if self._cursor[b]+1 >= self._text_size:
                self._cursor[b] = b * self._segments
                
            #add the text at the cursor as the input
            batch_data[b, self._text[self._cursor[b]]] = 1.0
            #add the preceeding bigram as the label to be predicted
            batch_labels[b, self._text[self._cursor[b]+1]] = 1.0
            #update the cursor
            self._cursor[b] = (self._cursor[b]+1)%self._text_size
            
        return batch_data, batch_labels
    
    def unroll_batches(self):
        '''This produces a list of num_unroll batches as required by a single step of training the RNN'''
        
        unroll_data, unroll_labels = [], []
        for ui in range(self._num_unroll):
            data, labels= self.next_batch()
            unroll_data.append(data)
            unroll_labels.append(labels)
            
        return unroll_data, unroll_labels
    
    def reset_indices(self):
        '''used to reset all the cursors if needed'''
        self._cursor = [offset * self._segments for offset in range(self._batch_size)]
        
# running a tiny set to see if things are correct
dg = DataGeneratorOHE(data_list[0][25:50], 5, 5)
u_data, u_labels =dg.unroll_batches()

# Iterate through each data batch in the unrolled set of batches
for ui,(dat,lbl) in enumerate(zip(u_data,u_labels)):   
    print('\n\nUnrolled index %d'%ui)
    dat_ind = np.argmax(dat,axis=1)
    lbl_ind = np.argmax(lbl,axis=1)
    print('\tInputs:')
    for sing_dat in dat_ind:
        print('\t%s (%d)'%(reverse_dictionary[sing_dat],sing_dat),end=", ")
    print('\n\tOutput:')
    for sing_lbl in lbl_ind:        
        print('\t%s (%d)'%(reverse_dictionary[sing_lbl],sing_lbl),end=", ")



Unrolled index 0
	Inputs:
	e  (1), 	ki (152), 	 d (49), 	 w (11), 	be (69), 
	Output:
	li (98), 	ng (34), 	au (215), 	er (13), 	au (215), 

Unrolled index 1
	Inputs:
	li (98), 	ng (34), 	au (215), 	er (13), 	au (215), 
	Output:
	ve (43), 	
w (167), 	gh (109), 	e  (1), 	ti (112), 

Unrolled index 2
	Inputs:
	ve (43), 	
w (167), 	gh (109), 	e  (1), 	ti (112), 
	Output:
	d  (5), 	ho (61), 	te (62), 	al (80), 	fu (235), 

Unrolled index 3
	Inputs:
	d  (5), 	ho (61), 	te (62), 	al (80), 	fu (235), 
	Output:
	a  (78), 	se (56), 	rs (138), 	l  (57), 	l, (260), 

Unrolled index 4
	Inputs:
	a  (78), 	se (56), 	rs (138), 	l  (57), 	be (69), 
	Output:
	ki (152), 	 d (49), 	 w (11), 	be (69), 	au (215), 

## Defining the LSTM
This is a standard LSTM. the LSTM has 5 main component
- Cell state
- Hidden state
- Input gate
- Forget gate
- output gate

Each gate has three sets of weights (1 set for the current input, 1 set for the previous hidden state and 1 bias)

## Defining hyperparameters¶
Here we define several hyperparameters. However additionally we use dropout; a technique that helps to avoid overfitting.

In [57]:
#number of neurons in the hidden state varibles
num_nodes = 128

#number of data points in a batch we proces
batch_size = 64

#number of times steps we unroll for during optimization
num_unrollings = 50

dropout = 0.0 

#use this in the scv filename when saving 
filename_extension = ''
if dropout > 0.0:
    filename_extension = '_dropout'

filename_to_save = 'lstm'+filename_extension+'.csv' # use to save perplexity values

## Defining Inputs and Outputs

In the code we define two different types of inputs. 
* Training inputs (The stories we downloaded) (batch_size > 1 with unrolling)
* Validation inputs (An unseen validation dataset) (bach_size =1, no unrolling)
* Test input (New story we are going to generate) (batch_size=1, no unrolling)

In [58]:
tf.reset_default_graph()

#training input data.
train_inputs, train_labels = [],[]

#defining unrolled training inputs
for ui in range(num_unrollings):
    train_inputs.append(tf.placeholder(tf.float32, shape=[batch_size, vocabulary_size], name='train_inputs_%d'%ui))
    train_labels.append(tf.placeholder(tf.float32, shape=[batch_size, vocabulary_size], name='train_labels_%d'%ui))
    
    #validation data placeholders
    valid_inputs = tf.placeholder(tf.float32, shape=[1, vocabulary_size], name='valid_inputs')
    valid_labels = tf.placeholder(tf.float32, shape=[1, vocabulary_size], name='valid_labels')
    
    #text generation: batch 1, no unrolling
    test_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size], name='test_input')

## Defining Model Parameters

Now we define model parameters. Compared to RNNs, LSTMs have a large number of parameters. Each gate (input, forget, memory and output) has three different sets of parameters.

In [63]:
#input gate (i_t) - How much memory to write to cell state
#connect the current input to the input gate
ix = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], stddev=0.02))
#connects the prvious hidden state to the input gate
im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], stddev=0.02))
#bais of the input gate
ib = tf.Variable(tf.random_uniform([1, num_nodes], -0.02, 0.02))

#Forget gate (f_t) - How much memoery is discard from the cell state
# connect the current input to the forget gate
fx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], stddev=0.02))
#connects the prvious hidden state to the forget gate
fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], stddev=0.02))
#bais of the forget gate
fb = tf.Variable(tf.random_uniform([1, num_nodes], -0.02, 0.02))

#Candidate value (c~_t) - used to compute the current cell state
# connect the current input to the candidate
cx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], stddev=0.02))
#connects the prvious hidden state to the candidate
cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], stddev=0.02))
#bais of the candidate
cb = tf.Variable(tf.random_uniform([1, num_nodes], -0.02, 0.02))


# Output gate (o_t) - How much memory to output from the cell state
# Connects the current input to the output gate
ox = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], stddev=0.02))
# Connects the previous hidden state to the output gate
om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], stddev=0.02))
# Bias of the output gate
ob = tf.Variable(tf.random_uniform([1, num_nodes],-0.02,0.02))

#softmax Classifier weights and biases.
w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], stddev=0.02))
b = tf.Variable(tf.random_uniform([1, num_nodes], -0.02, 0.02))

# Variables saving state across unrollings.
# Hidden state
saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False, name='train_hidden')
# Cell state
saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False, name='train_cell')

# Same variables for validation phase
saved_valid_output = tf.Variable(tf.zeros([1, num_nodes]),trainable=False, name='valid_hidden')
saved_valid_state = tf.Variable(tf.zeros([1, num_nodes]),trainable=False, name='valid_cell')

# Same variables for testing phase
saved_test_output = tf.Variable(tf.zeros([1, num_nodes]),trainable=False, name='test_hidden')
saved_test_state = tf.Variable(tf.zeros([1, num_nodes]),trainable=False, name='test_cell')


# Softmax Classifier weights and biases.
w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], stddev=0.02)) 
b = tf.Variable(tf.random_uniform([vocabulary_size],-0.02,0.02))


## Defining LSTM Computations
Here first we define the LSTM cell computations as a consice function. Then we use this function to define training and test-time inference logic.

In [64]:
#definition of the cell computation
def lstm_cell(i, o, state):
    """create an lstm cell"""
    input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
    forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
    update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb
    state = forget_gate * state + input_gate * tf.tanh(update)
    output_gate = tf.sigmoid(tf.matmul(i,ox) + tf.matmul(o, om) + ob)
    
    return output_gate * tf.tanh(state), state

In [73]:
# =========================================================
#Training related inference logic

# Keeps the calculated state outputs in all the unrollings
# Used to calculate loss
outputs = list()

#These two python variables are iteratively updated at each step of unrolling
output = saved_output
state = saved_state

#compute the hidden state (output) and cell state (state)
# recursively for all the teps in unrolling
for i in train_inputs: 
    output, state = lstm_cell(i, output, state)
    output = tf.nn.dropout(output, keep_prob=1.0-dropout)
    #append eachcomputed output vale
    outputs.append(output)
    
#calculate the score values
logits = tf.matmul(tf.concat(axis=0, values=outputs), w) + b

# Compute predictions.
train_prediction = tf.nn.softmax(logits)

# Compute training perplexity
train_perplexity_without_exp = tf.reduce_sum(tf.concat(train_labels,0)*-tf.log(tf.concat(train_prediction,0)+1e-10))/(num_unrollings*batch_size)

# =====================================================================
# validation phase related inference logic

#compute the LSTM cell output for validation data
valid_output, valid_state = lstm_cell(valid_inputs, saved_valid_output, saved_valid_state)

#compute the logits
valid_logits = tf.nn.xw_plus_b(valid_output, w, b)

#compute training perplexity 
train_perplexity_without_exp = tf.reduce_sum(tf.concat(train_labels,0)*-tf.log(tf.concat(train_prediction,0)+1e-10))/(num_unrollings*batch_size)

#Make sure that the state variables are updated
#before moving on to the next iteration of generation
with tf.control_dependencies([saved_valid_output.assign(valid_output), 
                            saved_valid_state.assign(valid_state)]):
    valid_prediction = tf.nn.softmax(valid_logits)


# Compute validation perplexity
valid_perplexity_without_exp = tf.reduce_sum(valid_labels*-tf.log(valid_prediction+1e-10))

# ========================================================================
# Testing phase related inference logic

# Compute the LSTM cell output for testing data
test_output, test_state = lstm_cell(test_input, saved_test_output, saved_test_state)

#compute test logit
test_logits = tf.nn.xw_plus_b(test_output, w, b)

#Make sure that the state variables are updated
#before moving on to the next iteration of generation
with tf.control_dependencies([saved_test_output.assign(test_output),
                            saved_test_state.assign(test_state)]):
    test_prediction = tf.nn.softmax(test_logits)
print(test_prediction)

Tensor("Softmax_16:0", shape=(1, 572), dtype=float32)


## Calculating LSTM Loss
we calculate the training loss of the LSTM. It's a typical cross entropy loss calculated over all the scores we obtained for training data(loss) '

In [74]:
# before calculating the training loss,
#save the hidden state and the cell state to their respective Tensorflow variables
with tf.control_dependencies([saved_output.assign(output), 
                              saved_state.assign(state)]):
    #calculate the training loss by concatenation the results form all the unrolled time steps
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
        logits = logits, labels=tf.concat(axis=0, values=train_labels))
    )
print(logits)

Tensor("add_3259:0", shape=(3200, 572), dtype=float32)


## Defining Learning Rate and the Optimizer with Gradient Clipping
Here we define the learning rate and the optimizer we're going to use. We will be using Adam optimizer. We use gradient clipping to prevent any gradient explosions 

In [77]:
#learning rate decay
gstep  = tf.Variable(0, trainable=False, name='global_step')

#Running this operation will cause the value of gstep to increase, while in turn reducing the learning rate
inc_gstep = tf.assign(gstep, gstep+1)

# decays learning rate everytime the gstep increase
tf_learning_rate = tf.train.exponential_decay(0.001, gstep,
                                             decay_steps=1, decay_rate=0.5)
#Adam optimizer and gradient clipping
optimizer = tf.train.AdamOptimizer(tf_learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
#clipping gradients
gradients, _ = tf.clip_by_global_norm(gradients, 5.0)

optimizer = optimizer.apply_gradients(
    zip(gradients, v)
)
print(optimizer)

name: "Adam_1"
op: "NoOp"
input: "^Adam_1/update_Variable_16/ApplyAdam"
input: "^Adam_1/update_Variable_17/ApplyAdam"
input: "^Adam_1/update_Variable_18/ApplyAdam"
input: "^Adam_1/update_Variable_19/ApplyAdam"
input: "^Adam_1/update_Variable_20/ApplyAdam"
input: "^Adam_1/update_Variable_21/ApplyAdam"
input: "^Adam_1/update_Variable_22/ApplyAdam"
input: "^Adam_1/update_Variable_23/ApplyAdam"
input: "^Adam_1/update_Variable_24/ApplyAdam"
input: "^Adam_1/update_Variable_25/ApplyAdam"
input: "^Adam_1/update_Variable_26/ApplyAdam"
input: "^Adam_1/update_Variable_27/ApplyAdam"
input: "^Adam_1/update_Variable_30/ApplyAdam"
input: "^Adam_1/update_Variable_31/ApplyAdam"
input: "^Adam_1/Assign"
input: "^Adam_1/Assign_1"



## Resetting Operations for Resetting Hidden States
Sometimes the state variable needs to be reset (e.g. when starting predictions at a beginning of a new epoch)

In [80]:
# Reset train state
reset_train_state = tf.group(tf.assign(saved_state, tf.zeros([batch_size, num_nodes])),
                            tf.assign(saved_output, tf.zeros([batch_size, num_nodes])))

#Reset valid state
reset_valid_state = tf.group(tf.assign(saved_valid_state, tf.zeros([1, num_nodes])),
                            tf.assign(saved_valid_output, tf.zeros([1, num_nodes])))
# Reset test state
reset_test_state = tf.group(tf.assign(saved_test_output.assign(tf.random_normal([1, num_nodes], stddev=0.05)),
                                     saved_test_state.assign(tf.random_normal([1, num_nodes], stddev=0.05))))


## Greedy Sampling to Break the Repetition
Here we write some simple logic to break the repetition in text. Specifically instead of always getting the word that gave this highest prediction probability, we sample randomly where the probability of being selected given by their prediction probabilities.

In [85]:
def sample(distribution):
    '''Greedy sampling
    We pick the three best predictions given by the LSTM and sample one of them with very high 
    probability of picking the best one'''
    
    best_inds = np.argsort(distribution)[-3:]
    best_probs = distribution[best_inds] / np.sum(distribution[best_inds])
    best_idx = np.random.choice(best_inds, p=best_probs)
    return best_idx
    

## Running the LSTM to Generate Text
Here we train the LSTM on the available data and generate text using the trained LSTM for several steps. From each document we extract text for steps_per_document steps to train the LSTM on. We also report the train perplexity at the end of each step. Finally we test the LSTM by asking it to generate some new text starting from a randomly picked bigram.

### Learning rate Decay Logic
Here we define the logic to decrease learning rate whenever the validation perplexity does not decrease



In [86]:
# Learning rate decay related
# If valid perpelxity does not decrease
# continuously for this many epochs
# decrease the learning rate
decay_threshold = 5
# Keep counting perplexity increases
decay_count = 0

min_perplexity = 1e10

# Learning rate decay logic
def decay_learning_rate(session, v_perplexity):
    global decay_threshold, decay_count, min_perplexity  
    # Decay learning rate
    if v_perplexity < min_perplexity:
        decay_count = 0
        min_perplexity= v_perplexity
    else: 
        decay_count += 1
    
    if decay_count >= decay_threshold:
        print('\t Reducing learning rate')
        decay_count = 0
        
        session.run(inc_gstep)

In [87]:
# Some hyperparameters needed for the training process

num_steps = 26
steps_per_document = 100
valid_summary = 1
train_doc_count = 100
docs_per_step = 10

# Capture the behavior of train perplexity over time
train_perplexity_ot = []
valid_perplexity_ot = []

session = tf.InteractiveSession()

# Initializing variables
tf.global_variables_initializer().run()
print('Initialized Global Variables ')

average_loss = 0 # Calculates the average loss ever few steps

# We use the first 10 documents that has 
# more than 10*steps_per_document bigrams for creating the validation dataset

# Identify the first 10 documents following the above condition
long_doc_ids = []
for di in range(num_files):
  if len(data_list[di])>10*steps_per_document:
    long_doc_ids.append(di)
  if len(long_doc_ids)==10:
    break
    
# Generating validation data
data_gens = []
valid_gens = []
for fi in range(num_files):
  # Get all the bigrams if the document id is not in the validation document ids
  if fi not in long_doc_ids:
    data_gens.append(DataGeneratorOHE(data_list[fi],batch_size,num_unrollings))
  # if the document is in the validation doc ids, only get up to the 
  # last steps_per_document bigrams and use the last steps_per_document bigrams as validation data
  else:
    data_gens.append(DataGeneratorOHE(data_list[fi][:-steps_per_document],batch_size,num_unrollings))
    # Defining the validation data generator
    valid_gens.append(DataGeneratorOHE(data_list[fi][-steps_per_document:],1,1))


feed_dict = {}
for step in range(num_steps):
    
    print('Training (Step: %d)'%step,end=' ')
    for di in np.random.permutation(train_doc_count)[:docs_per_step]:            
        doc_perplexity = 0
        for doc_step_id in range(steps_per_document):
            
            # Get a set of unrolled batches
            u_data, u_labels = data_gens[di].unroll_batches()
            
            # Populate the feed dict by using each of the data batches
            # present in the unrolled data
            for ui,(dat,lbl) in enumerate(zip(u_data,u_labels)):            
                feed_dict[train_inputs[ui]] = dat
                feed_dict[train_labels[ui]] = lbl
            
            # Running the TensorFlow operations
            _, l, step_perplexity = session.run([optimizer, loss, train_perplexity_without_exp], 
                                                       feed_dict=feed_dict)
            # Update doc_perpelxity variable
            doc_perplexity += step_perplexity 
            
            # Update the average_loss variable
            average_loss += step_perplexity 
            
        # Show the printing progress <train_doc_id_1>.<train_doc_id_2>. ...
        print('(%d).'%di,end='') 
    
        # resetting hidden state after processing a single document
        # It's still questionable if this adds value in terms of learning
        # One one hand it's intuitive to reset the state when learning a new document
        # On the other hand this approach creates a bias for the state to be zero
        # We encourage the reader to investigate further the effect of resetting the state
        #session.run(reset_train_state) # resetting hidden state for each document
    print('')
    
    # Generate new samples
    if (step+1) % valid_summary == 0:
      
      # Compute average loss
      average_loss = average_loss / (valid_summary*docs_per_step*steps_per_document)
      
      # Print losses
      print('Average loss at step %d: %f' % (step+1, average_loss))
      print('\tPerplexity at step %d: %f' %(step+1, np.exp(average_loss)))
      train_perplexity_ot.append(np.exp(average_loss))
        
      average_loss = 0 # reset loss
      valid_loss = 0 # reset loss
        
      # calculate valid perplexity
      for v_doc_id in range(10):
          # Remember we process things as bigrams
          # So need to divide by 2
          for v_step in range(steps_per_document//2):
            uvalid_data,uvalid_labels = valid_gens[v_doc_id].unroll_batches()        

            # Run validation phase related TensorFlow operations       
            v_perp = session.run(
                valid_perplexity_without_exp,
                feed_dict = {valid_inputs:uvalid_data[0],valid_labels: uvalid_labels[0]}
            )

            valid_loss += v_perp
            
          session.run(reset_valid_state)
      
          # Reset validation data generator cursor
          valid_gens[v_doc_id].reset_indices()      
    
      print()
      v_perplexity = np.exp(valid_loss/(steps_per_document*10.0//2))
      print("Valid Perplexity: %.2f\n"%v_perplexity)
      valid_perplexity_ot.append(v_perplexity)
          
      decay_learning_rate(session, v_perplexity)
        
      # Generating new text ...
      # We will be generating one segment having 500 bigrams
      # Feel free to generate several segments by changing
      # the value of segments_to_generate
      print('Generated Text after epoch %d ... '%step)  
      segments_to_generate = 1
      chars_in_segment = 500
    
      for _ in range(segments_to_generate):
        print('======================== New text Segment ==========================')
        
        # Start with a random word
        test_word = np.zeros((1,vocabulary_size),dtype=np.float32)
        rand_doc = data_list[np.random.randint(0,num_files)]
        test_word[0,rand_doc[np.random.randint(0,len(rand_doc))]] = 1.0
        print("\t",reverse_dictionary[np.argmax(test_word[0])],end='')
        
        # Generating words within a segment by feeding in the previous prediction
        # as the current input in a recursive manner
        for _ in range(chars_in_segment):    
          sample_pred = session.run(test_prediction, feed_dict = {test_input:test_word})  
          next_ind = sample(sample_pred.ravel())
          test_word = np.zeros((1,vocabulary_size),dtype=np.float32)
          test_word[0,next_ind] = 1.0
          print(reverse_dictionary[next_ind],end='')
        print("")
        
        # Reset train state
        session.run(reset_test_state)
        print('====================================================================')
      print("")

session.close()

# Write training and validation perplexities to a csv file
with open(filename_to_save, 'wt') as f:
    writer = csv.writer(f, delimiter=',')
    writer.writerow(train_perplexity_ot)
    writer.writerow(valid_perplexity_ot)



Initialized Global Variables 
Training (Step: 0) (36).(98).(73).(94).(61).(88).(47).(45).(21).(24).
Average loss at step 1: 4.407770
	Perplexity at step 1: 82.086177

Valid Perplexity: 67.70

Generated Text after epoch 0 ... 
	 irstard ther, and the to then the poped said, and the pome, and the came, and ther, but he was he pome to be had the did, and now he had the dredersed the was he pome the was his was they the poped they his whould the took the was his his was the was he pome he the pome, and there the poped him the pope, which the pope, and the pord, and that then he whoulders, what his his he pome the was then he had to the pome that the to to then his his had he took to be to be to the to to be the was the do the to as the pord, and not the to the was the dide he was his to the hen he was his the don, and the pome then the pome his to the pord of he his was he the to as the
dere, and said, and the pope.  and did, but the poped the pome he the to the was his had the was he hen 


Valid Perplexity: 44.27

Generated Text after epoch 6 ... 
	 t of the fat nothingen herwas and cried, it in the miden from white way prand indow and lighted that the heard that she toged to her that the king's spikes, and she was
at last her came, and the king ast once had bedpick that he had goitting the heard kingdom from much in thek down in the king's peachack them dow, and
heard the for her lorst to great for
him, they that he shall
defive his was as, and then the hedgehog's skin the redgehog's daughter would conformed.  the bate and was to be there by which oner to the royal pangeraced to him his so may only to betting there was a lastone, and the hedgehog hans the hedgehog, however, by that in the father was
ate and to by the fidger the hedgesiant told she came to the royal palace.  but the king camed him with the king and celle it and life a mangewas to be to
him, the isling a cone by
himself on the kingdom from the age.  then she reventme coulded his sleep.  the king asked hi

bird by the bailiff.  what was shut bride long come to the box on the ears, and put as if he would but when the king's ate wilk.  when she have her beautiful drins went into the mill-stouchen heard it out, and went out of it, he tred on the second time, he will was to the birst.

at the wedding outside.  then they wed them, he was standing about him the have had grown, said if
not, was away to them, and have him not no longer and bed, how he
had a little, and
ther, but reat of her fare of the mill,
too, he
said, no, no on her own form and ring four-footed the heares, the had one your father in some
he, said, i have mading he at last bire arest back for hards, and began to the night, for the others, and said, that they were hon walk to him, but, i have along of whom to gare him again she was to anythen she which and her which was away in

Training (Step: 13) (57).(20).(62).(13).(18).(72).(19).(26).(43).(71).
Average loss at step 14: 1.873695
	Perplexity at step 14: 6.512316

Valid Perpl

wanting the whole likewise and said 'if it was took her father, and only the poor with suchan the presented to a down to this.  then said the man was all the two little witch to his
kingdom, and
said, "you have not the goldier to the two with
began to him, and went again, and
said the thousands is will see that she said, "yould compart was as a little stoge, but she opened the good and said that they will be greather, and was again, and he led it were she threet out with her leavens, and went to
thise beater the king's son and said, oh, god, but
he should his place, and when she saw her to see if i stold it.  it into that, and while they said, and the town there she had been her.  the little house was on
the wedding was to change cantle doom, who made in the
door, they are the could creep.  then she said to go out again.  she went to the recognized hers but whing h

Training (Step: 19) (87).(98).(68).(55).(4).(45).(31).(44).(90).(42).
Average loss at step 20: 2.149967
	Perplexity at st

"oh, not master came and to pleep and said, for her full of joy, the brood-would not
do, they had go and was to at and own and cried.  what is your trade, and the witched up to take a that there all off mean and they were in a husband, and said, "i
could not beard.  if you will not down comes of the that she was to be stranger high that they could not have you will before.

they would which he was again,"

"the doge of through the king's daughter had took the boy wite one off the wolf's hause, and that he could not
be take one of them into them.  the miller, and when he satisfie

Training (Step: 25) (38).(71).(90).(12).(16).(77).(18).(24).(15).(91).
Average loss at step 26: 2.100774
	Perplexity at step 26: 8.172496

Valid Perplexity: 21.85

Generated Text after epoch 25 ... 
	 little head only him, but the king's son was the one was a peachied the servants on the said that they how do you stepishes, and they were hered, but the two rech on the shorment well, and it was
into the little


## LSTM with Beam-Search
Here we alter the previously defined prediction related TensorFlow operations to employ beam-search. Beam search is a way of predicting several time steps ahead. Concretely instead of predicting the best prediction we have at a given time step, we get predictions for several time steps and get the sequence of highest joint probability.

In [88]:
beam_length = 5 # number of steps to look ahead
beam_neighbors = 5 # number of neighbors to compare to at each step

# We redefine the sample generation with beam search
sample_beam_inputs = [tf.placeholder(tf.float32, shape=[1, vocabulary_size]) for _ in range(beam_neighbors)]

best_beam_index = tf.placeholder(shape=None, dtype=tf.int32)
best_neighbor_beam_indices = tf.placeholder(shape=[beam_neighbors], dtype=tf.int32)

# Maintains output of each beam
saved_sample_beam_output = [tf.Variable(tf.zeros([1, num_nodes])) for _ in range(beam_neighbors)]
# Maintains the state of each beam
saved_sample_beam_state = [tf.Variable(tf.zeros([1, num_nodes])) for _ in range(beam_neighbors)]

# Resetting the sample beam states (should be done at the beginning of each text snippet generation)
reset_sample_beam_state = tf.group(
    *[saved_sample_beam_output[vi].assign(tf.zeros([1, num_nodes])) for vi in range(beam_neighbors)],
    *[saved_sample_beam_state[vi].assign(tf.zeros([1, num_nodes])) for vi in range(beam_neighbors)]
)

# We stack them to perform gather operation below
stacked_beam_outputs = tf.stack(saved_sample_beam_output)
stacked_beam_states = tf.stack(saved_sample_beam_state)

# The beam states for each beam (there are beam_neighbor-many beams) needs to be updated at every depth of tree
# Consider an example where you have 3 classes where we get the best two neighbors (marked with star)
#     a`      b*       c  
#   / | \   / | \    / | \
#  a  b c  a* b` c  a  b  c
# Since both the candidates from level 2 comes from the parent b
# We need to update both states/outputs from saved_sample_beam_state/output to have index 1 (corresponding to parent b)
update_sample_beam_state = tf.group(
    *[saved_sample_beam_output[vi].assign(tf.gather_nd(stacked_beam_outputs,[best_neighbor_beam_indices[vi]])) for vi in range(beam_neighbors)],
    *[saved_sample_beam_state[vi].assign(tf.gather_nd(stacked_beam_states,[best_neighbor_beam_indices[vi]])) for vi in range(beam_neighbors)]
)

# We calculate lstm_cell state and output for each beam
sample_beam_outputs, sample_beam_states = [],[] 
for vi in range(beam_neighbors):
    tmp_output, tmp_state = lstm_cell(
        sample_beam_inputs[vi], saved_sample_beam_output[vi], saved_sample_beam_state[vi]
    )
    sample_beam_outputs.append(tmp_output)
    sample_beam_states.append(tmp_state)

# For a given set of beams, outputs a list of prediction vectors of size beam_neighbors
# each beam having the predictions for full vocabulary
sample_beam_predictions = []
for vi in range(beam_neighbors):
    with tf.control_dependencies([saved_sample_beam_output[vi].assign(sample_beam_outputs[vi]),
                                saved_sample_beam_state[vi].assign(sample_beam_states[vi])]):
        sample_beam_predictions.append(tf.nn.softmax(tf.nn.xw_plus_b(sample_beam_outputs[vi], w, b)))
        

## Running the LSTM with Beam Search to Generate Text
Here we train the LSTM on the available data and generate text using the trained LSTM for several steps. From each document we extract text for steps_per_document steps to train the LSTM on. We also report the train perplexity at the end of each step. Finally we test the LSTM by asking it to generate some new text with beam search starting from a randomly picked bigram.

### Learning rate Decay Logic
Here we define the logic to decrease learning rate whenever the validation perplexity does not decrease



In [None]:
# Learning rate decay related
# If valid perpelxity does not decrease
# continuously for this many epochs
# decrease the learning rate
decay_threshold = 5
# Keep counting perplexity increases
decay_count = 0

min_perplexity = 1e10

# Learning rate decay logic
def decay_learning_rate(session, v_perplexity):
  global decay_threshold, decay_count, min_perplexity  
  # Decay learning rate
  if v_perplexity < min_perplexity:
    decay_count = 0
    min_perplexity= v_perplexity
  else:
    decay_count += 1

  if decay_count >= decay_threshold:
    print('\t Reducing learning rate')
    decay_count = 0
    session.run(inc_gstep)

### Defining the Beam Prediction Logic
Here we define function that takes in the session as an argument and output a beam of predictions

In [90]:
test_word = None

def get_beam_prediction(session):
    
    # Generating words within a segment with Beam Search
    # To make some calculations clearer, we use the example as follows
    # We have three classes with beam_neighbors=2 (best candidate denoted by *, second best candidate denoted by `)
    # For simplicity we assume best candidate always have probability of 0.5 in output prediction
    # second best has 0.2 output prediction
    #           a`                   b*                   c                <--- root level
    #    /     |     \         /     |     \        /     |     \   
    #   a      b      c       a*     b`     c      a      b      c         <--- depth 1
    # / | \  / | \  / | \   / | \  / | \  / | \  / | \  / | \  / | \
    # a b c  a b c  a b c   a*b c  a`b c  a b c  a b c  a b c  a b c       <--- depth 2
    # So the best beams at depth 2 would be
    # b-a-a and b-b-a
        
    global test_word
    global sample_beam_predictions
    global update_sample_beam_state
    
    # Calculate the candidates at the root level
    feed_dict = {}
    for b_n_i in range(beam_neighbors):
        feed_dict.update({sample_beam_inputs[b_n_i]: test_word})

    # We calculate sample predictions for all neighbors with the same starting word/character
    # This is important to update the state for all instances of beam search
    sample_preds_root = session.run(sample_beam_predictions, feed_dict = feed_dict)  
    sample_preds_root = sample_preds_root[0]

    # indices of top-k candidates
    # b and a in our example (root level)
    this_level_candidates =  (np.argsort(sample_preds_root,axis=1).ravel()[::-1])[:beam_neighbors].tolist() 

    # probabilities of top-k candidates
    # 0.5 and 0.2
    this_level_probs = sample_preds_root[0,this_level_candidates] 

    # Update test sequence produced by each beam from the root level calculation
    # Test sequence looks like for our example (at root)
    # [b,a]
    test_sequences = ['' for _ in range(beam_neighbors)]
    for b_n_i in range(beam_neighbors):
        test_sequences[b_n_i] += reverse_dictionary[this_level_candidates[b_n_i]]

    # Make the calculations for the rest of the depth of the beam search tree
    for b_i in range(beam_length-1):
        test_words = [] # candidate words for each beam
        pred_words = [] # Predicted words of each beam

        # computing feed_dict for the beam search (except root)
        # feed dict should contain the best words/chars/bigrams found by the previous level of search

        # For level 1 in our example this would be
        # sample_beam_inputs[0]: b, sample_beam_inputs[1]:a
        feed_dict = {}
        for p_idx, pred_i in enumerate(this_level_candidates):                    
            # Updating the feed_dict for getting next predictions
            test_words.append(np.zeros((1,vocabulary_size),dtype=np.float32))
            test_words[p_idx][0,this_level_candidates[p_idx]] = 1.0

            feed_dict.update({sample_beam_inputs[p_idx]:test_words[p_idx]})

        # Calculating predictions for all neighbors in beams
        # This is a list of vectors where each vector is the prediction vector for a certain beam
        # For level 1 in our example, the prediction values for 
        #      b             a  (previous beam search results)
        # [a,  b,  c],  [a,  b,  c] (current level predictions) would be
        # [0.1,0.1,0.1],[0.5,0.2,0]
        sample_preds_all_neighbors = session.run(sample_beam_predictions, feed_dict=feed_dict)

        # Create a single vector with 
        # Making our example [0.1,0.1,0.1,0.5,0.2,0] 
        sample_preds_all_neighbors_concat = np.concatenate(sample_preds_all_neighbors,axis=1)

        # Update this_level_candidates to be used for the next iteration
        # And update the probabilities for each beam
        # In our example these would be [3,4] (indices with maximum value from above vector)
        this_level_candidates = np.argsort(sample_preds_all_neighbors_concat.ravel())[::-1][:beam_neighbors]

        # In the example this would be [1,1]
        parent_beam_indices = this_level_candidates//vocabulary_size

        # normalize this_level_candidates to fall between [0,vocabulary_size]
        # In this example this would be [0,1]
        this_level_candidates = (this_level_candidates%vocabulary_size).tolist()

        # Here we update the final state of each beam to be
        # the state that was at the index 1. Because for both the candidates at this level the parent is 
        # at index 1 (that is b from root level)
        session.run(update_sample_beam_state, feed_dict={best_neighbor_beam_indices: parent_beam_indices})

        # Here we update the joint probabilities of each beam and add the newly found candidates to the sequence
        tmp_this_level_probs = np.asarray(this_level_probs) #This is currently [0.5,0.2]
        tmp_test_sequences = list(test_sequences) # This is currently [b,a]

        for b_n_i in range(beam_neighbors):
            # We make the b_n_i element of this_level_probs to be the probability of parents
            # In the example the parent indices are [1,1]
            # So this_level_probs become [0.5,0.5]
            this_level_probs[b_n_i] = tmp_this_level_probs[parent_beam_indices[b_n_i]]

            # Next we multipyle these by the probabilities of the best candidates from current level 
            # [0.5*0.5, 0.5*0.2] = [0.25,0.1]
            this_level_probs[b_n_i] *= sample_preds_all_neighbors[parent_beam_indices[b_n_i]][0,this_level_candidates[b_n_i]]

            # Make the b_n_i element of test_sequences to be the correct parent of the current best candidates
            # In the example this becomes [b, b]
            test_sequences[b_n_i] = tmp_test_sequences[parent_beam_indices[b_n_i]]

            # Now we append the current best candidates
            # In this example this becomes [ba,bb]
            test_sequences[b_n_i] += reverse_dictionary[this_level_candidates[b_n_i]]

            # Create one-hot-encoded representation for each candidate
            pred_words.append(np.zeros((1,vocabulary_size),dtype=np.float32))
            pred_words[b_n_i][0,this_level_candidates[b_n_i]] = 1.0

    # Calculate best beam id based on the highest beam probability
    # Using the highest beam probability always lead to very monotonic text
    # Let us sample one randomly where one being sampled is decided by the likelihood of that beam
    rand_cand_ids = np.argsort(this_level_probs)[-3:]
    rand_cand_probs = this_level_probs[rand_cand_ids]/np.sum(this_level_probs[rand_cand_ids])
    random_id = np.random.choice(rand_cand_ids,p=rand_cand_probs)

    best_beam_id = parent_beam_indices[random_id]

    # Update state and output variables for test prediction
    session.run(update_sample_beam_state,feed_dict={best_neighbor_beam_indices:[best_beam_id for _ in range(beam_neighbors)]})

    # Make the last word/character/bigram from the best beam
    test_word = pred_words[best_beam_id]
    
    return test_sequences[best_beam_id]


### Running Training, Validation and Generation
We traing the LSTM on existing training data, check the validaiton perplexity on an unseen chunk of text and generate a fresh segment of text

In [None]:
filename_to_save = 'lstm_beam_search_dropout'

# Some hyperparameters needed for the training process

num_steps = 26
steps_per_document = 100
valid_summary = 1
train_doc_count = 100
docs_per_step = 10


beam_nodes = []

beam_train_perplexity_ot = []
beam_valid_perplexity_ot = []
session = tf.InteractiveSession()

tf.global_variables_initializer().run()

print('Initialized')
average_loss = 0

# We use the first 10 documents that has 
# more than 10*steps_per_document bigrams for creating the validation dataset

# Identify the first 10 documents following the above condition
long_doc_ids = []
for di in range(num_files):
  if len(data_list[di])>10*steps_per_document:
    long_doc_ids.append(di)
  if len(long_doc_ids)==10:
    break
    
# Generating validation data
data_gens = []
valid_gens = []
for fi in range(num_files):
  # Get all the bigrams if the document id is not in the validation document ids
  if fi not in long_doc_ids:
    data_gens.append(DataGeneratorOHE(data_list[fi],batch_size,num_unrollings))
  # if the document is in the validation doc ids, only get up to the 
  # last steps_per_document bigrams and use the last steps_per_document bigrams as validation data
  else:
    data_gens.append(DataGeneratorOHE(data_list[fi][:-steps_per_document],batch_size,num_unrollings))
    # Defining the validation data generator
    valid_gens.append(DataGeneratorOHE(data_list[fi][-steps_per_document:],1,1))


feed_dict = {}
for step in range(num_steps):
    
    for di in np.random.permutation(train_doc_count)[:docs_per_step]:            
        doc_perplexity = 0
        for doc_step_id in range(steps_per_document):
            
            # Get a set of unrolled batches
            u_data, u_labels = data_gens[di].unroll_batches()
            
            # Populate the feed dict by using each of the data batches
            # present in the unrolled data
            for ui,(dat,lbl) in enumerate(zip(u_data,u_labels)):            
                feed_dict[train_inputs[ui]] = dat
                feed_dict[train_labels[ui]] = lbl
            
            # Running the TensorFlow operations
            _, l, step_perplexity = session.run([optimizer, loss, train_perplexity_without_exp], 
                                                       feed_dict=feed_dict)
            # Update doc_perpelxity variable
            doc_perplexity += step_perplexity 
            
            # Update the average_loss variable
            average_loss += step_perplexity 
            
        # Show the printing progress <train_doc_id_1>.<train_doc_id_2>. ...
        print('(%d).'%di,end='') 
    
    # resetting hidden state after processing a single document
    # It's still questionable if this adds value in terms of learning
    # One one hand it's intuitive to reset the state when learning a new document
    # On the other hand this approach creates a bias for the state to be zero
    # We encourage the reader to investigate further the effect of resetting the state
    #session.run(reset_train_state) # resetting hidden state for each document
    print('')
    
    if (step+1) % valid_summary == 0:
      
      # Compute average loss
      average_loss = average_loss / (docs_per_step*steps_per_document*valid_summary)
      
      # Print loss
      print('Average loss at step %d: %f' % (step+1, average_loss))
      print('\tPerplexity at step %d: %f' %(step+1, np.exp(average_loss)))
      beam_train_perplexity_ot.append(np.exp(average_loss))
    
      average_loss = 0 # reset loss
        
      valid_loss = 0 # reset loss
        
      # calculate valid perplexity
      for v_doc_id in range(10):
          # Remember we process things as bigrams
          # So need to divide by 2
          for v_step in range(steps_per_document//2):
            uvalid_data,uvalid_labels = valid_gens[v_doc_id].unroll_batches()        

            # Run validation phase related TensorFlow operations       
            v_perp = session.run(
                valid_perplexity_without_exp,
                feed_dict = {valid_inputs:uvalid_data[0],valid_labels: uvalid_labels[0]}
            )

            valid_loss += v_perp
            
          session.run(reset_valid_state)
      
          # Reset validation data generator cursor
          valid_gens[v_doc_id].reset_indices()      
    
      print()
      v_perplexity = np.exp(valid_loss/(steps_per_document*10.0//2))
      print("Valid Perplexity: %.2f\n"%v_perplexity)
      beam_valid_perplexity_ot.append(v_perplexity)
      
      # Decay learning rate
      decay_learning_rate(session, v_perplexity)
    
      # Generating new text ...
      # We will be generating one segment having 500 bigrams
      # Feel free to generate several segments by changing
      # the value of segments_to_generate
      print('Generated Text after epoch %d ... '%step)  
      segments_to_generate = 1
      chars_in_segment = 500//beam_length
    
      for _ in range(segments_to_generate):
        print('======================== New text Segment ==========================')
        # first word randomly generated
        test_word = np.zeros((1,vocabulary_size),dtype=np.float32)
        rand_doc = data_list[np.random.randint(0,num_files)]
        test_word[0,rand_doc[np.random.randint(0,len(rand_doc))]] = 1.0
        print("",reverse_dictionary[np.argmax(test_word[0])],end='')
        
        for _ in range(chars_in_segment):
            
            test_sequence = get_beam_prediction(session)
            print(test_sequence,end='')
            
        print("")
        session.run([reset_sample_beam_state])
        
        print('====================================================================')
      print("")

session.close()
    
with open(filename_to_save, 'wt') as f:
    writer = csv.writer(f, delimiter=',')
    writer.writerow(beam_train_perplexity_ot)
    writer.writerow(beam_valid_perplexity_ot)



Initialized
(55).(70).(83).(31).(27).(32).(45).(84).(86).(46).
Average loss at step 1: 4.394968
	Perplexity at step 1: 81.042013

Valid Perplexity: 56.04

Generated Text after epoch 0 ... 
 oaut and to and the kill, and then then then then then then ther, and ther, and then him and then the king, and
when him taid the kill, and ther, and then him and then the king, and
when him taid the kill, and ther, and ther, and then him and and then him the kill, and ther, and then him and and then him the king to the king, and
when him, when then him the kill, and then him and then the king then him the king to the king, and
when him taid the kill, and then then ther, and then him and then the king then ing and then him and and then him the kill, and then him and then the king, and
when him, when then the king, when then him the kill, and ther, and then then then then then then ther, and then him and then the king, and
when him taid the kill, and ther, and then him and then the king, and
when him

you and you at the wers came ing.  on a fifhe flight, i peeper brought he said, she who had said that.  he died to the fiant, but thing flight, i peeped
through the keUNKole of a doo, which wedding ing
fingers where were hisry to find by son.y one hearthey, and away.  he haUNKer, then the thethan tweth, and that were you flight, he took at hasses.  on
the shoved
lived hing, and breat him wight would he dhove himself, an he to the godfather had the fourth flight, he saw his strant, be and saked the the fingers which i came frors.  the fife.
the fir thought, but that said the maiden the carriage, but is not true.  the mountaid that.  and had said that, the fishes came
and served
themselves up.  as the maid the hearth
flight i a saw of my said,  oh, that is not truese.  the
mas became alarmed, again, and brought then the firth got to hiver, and, he said, said the godfather would have done to h

(85).(6).(76).(12).(31).(51).(70).(21).(88).(13).
Average loss at step 8: 2.236765
	Perplexity 

out of them to the golden haiden with his back, answered, answered this they were hand, answered them to the stones, and there wicked her that he was histed to his fathers, and then he saw a little bother willing you wished him, i will go for it.  and the maid he had said that that he was blow the willowed home.  then they were them to to that her father, if you have be ther horsed the son under stonest her got to him, and
that he had been stole the forest, and when, however said, yes had been ness, and i will great deal of them into
his back that they were all to the stone.

then his marden said him, and the maiden whom side, that she had go that the king whice that the father, why should have them forth that you are nailong, answered him that he was to make when the windowed how himself and said, "i have not be the maiden without, and as he looking at the horse worn, and she

(69).(40).(96).(80).(38).(92).(0).(37).(51).(35).
Average loss at step 14: 2.174636
	Perplexity at step 14: 8

  when he wants to dief, the wedding was spring ouseened in them, and told him the door, and when he was spring up and said to his lay, that her he tooked him, this dear of back to home, and took him to a back and, but they said the maiden, accully of your life."
then she said to down to his legs.

the king's said them rest once more and saw that is nothing towas to the door, and it was stonce was oured and sisted that it was stood standing but in its stonest, when this they went in the water and begged then he had saw his was oblight, but what it, and where is the most beaut out of them, and that went into the kingdom, and when she had done thing.  the king's son came to thrown that had they came to them, and they went inside him as he was s

(49).(50).(27).(43).(60).(29).(3).(99).(68).(47).
Average loss at step 20: 2.237818
	Perplexity at step 20: 9.372862

Valid Perplexity: 26.16

Generated Text after epoch 19 ... 
 thful
bettleg, which heard the wook and saw that, anyou shall went 