# Session 17: Character-based Language Modelling with LSTMs

------------------------------------------------------
*Introduction to Data Science & Machine Learning*

*Pablo M. Olmos olmos@tsc.uc3m.es*

------------------------------------------------------

The goal of this notebook is to train an LSTM character prediction model on the plain-text database [Text8](http://mattmahoney.net/dc/textdata).

This is a personal wrap-up of all the material provided by [Google's Deep Learning course on Udacity](https://www.udacity.com/course/deep-learning--ud730), so all credit goes to them. 


In [2]:
# These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
from __future__ import print_function
import os
import numpy as np
import random
import string
import tensorflow as tf
import zipfile
from six.moves import range
from six.moves.urllib.request import urlretrieve

#Here I provide some text preprocessing functions
import preprocessing as pr

In [3]:
# Let's check which version of TensorFlow is installed. The provided scripts should run with TF 1.0 and above

print(tf.__version__)

1.3.0


We will use the plain text database from this [link](http://mattmahoney.net/dc/textdata.html)

In [4]:
filename = './text8.zip'

Read the dataset into a single string of characters ...

In [5]:
text = pr.read_data(filename)
print(type(text))
print('Data size %d' % len(text))

<class 'str'>
Data size 100000000


In [6]:
v_size = len(string.ascii_lowercase) + 1 #  Number of characters [a-z] + ' '

The functions char2id and id2char map characters to numeric IDs and back ...

In [7]:
print(pr.char2id('a'), pr.char2id('z'), pr.char2id(' '),'\n')

print(pr.char2id('ï'),'\n')

print(pr.id2char(1), pr.id2char(26), pr.id2char(0))

1 26 0 

Unexpected character: ï
0 

a z  
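
The `preprocessing` module itself is not shown in this notebook. As a rough illustration, a mapping that reproduces the output above could look like the sketch below. The names `char2id_sketch` and `id2char_sketch` are hypothetical, the sketch assumes a single-character argument, and the real `pr.char2id` / `pr.id2char` may be implemented differently.

```python
import string

FIRST_LETTER = ord(string.ascii_lowercase[0])  # ord('a')


def char2id_sketch(char):
    """Map 'a'..'z' to 1..26 and ' ' to 0 (hypothetical stand-in for pr.char2id)."""
    if char in string.ascii_lowercase:
        return ord(char) - FIRST_LETTER + 1
    if char == ' ':
        return 0
    # Any other character is reported and folded into the space ID.
    print('Unexpected character: %s' % char)
    return 0


def id2char_sketch(dictid):
    """Inverse mapping: 0 -> ' ', 1..26 -> 'a'..'z' (hypothetical stand-in for pr.id2char)."""
    if dictid > 0:
        return chr(dictid + FIRST_LETTER - 1)
    return ' '


print(char2id_sketch('a'), char2id_sketch('z'), char2id_sketch(' '))
print(id2char_sketch(1), id2char_sketch(26), id2char_sketch(0))
```

With 26 letters plus the space, this gives the vocabulary size `v_size = 27` used throughout the notebook.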


## An RNN-based Character Language Model

We will train an RNN with the goal of predicting the next character given the current one. The structure is as follows:

<img src="RNN_CLM.png" width="600" height="400">


### Create a small validation set

In [8]:
valid_size = 1000
valid_text = text[:valid_size]
train_text = text[valid_size-1:] # Training text starts at the last validation character (one-character overlap)
train_size = len(train_text)

print("First 20 characters of validation set: ",  valid_text[:20], '\n\n')

print("First 20 characters of train set: ",  train_text[:20])

First 20 characters of validation set:   anarchism originate 


First 20 characters of train set:  ions anarchists advo


'BatchGenerator' is a class with a method to generate training batches for the LSTM model.

In [9]:
batch_size = 64 # Number of sequences in the batch
num_unrollings = 10 # Number of characters per sequence


train_batches = pr.BatchGenerator(text=train_text, batch_size=batch_size,num_unrollings=num_unrollings,
                                  vocabulary_size=v_size)
valid_batches = pr.BatchGenerator(text=valid_text, batch_size=1,num_unrollings=1,
                                  vocabulary_size=v_size)

print('Each element of train_batches.next() is a binary matrix of size ', train_batches.next()[0].shape)

print('Each sequence is ',len(train_batches.next()),' characters long')

Each element of train_batches.next() is a binary matrix of size  (64, 27)
Each sequence is  11  characters long


In [10]:
print(pr.batches2string(train_batches.next()))

['cate social', 'nments fail', 'nal park ph', 'ries index ', 'cess of cas', 'r h provide', 'nguage amon', 'ngers in de', 'nal media a', 'e during th', ' known manu', ' seven a wi', 'ss covering', 'een one of ', 'ize single ', 'e first car', 'n in jersey', 'the poverty', 'igns of hum', 'd cause so ', 'in denatura', 'ice formati', ' the input ', 'ick to pull', 'fusion inab', 'complete an', 'st of the m', 'e it fort d', 'attempts by', 'formats for', 'esoteric ch', ' growing po', 'original do', 'ne nine eig', 'arch eight ', 'character l', 'cal mechani', 'n gm compar', 'is fundamen', 'elieve the ', 'east not pa', 'd upon by h', 'm example r', 'sed on the ', 'the officia', 'ion at this', 'ine three t', 'linux enter', 't daily col', 'tration cam', 'nehru wishe', ' stiff from', 'harman s sy', 'to to begin', 'nitiatives ', ' these auth', 'ricky ricar', 'ew of mathe', 'nent of arm', 'ccredited p', 'ne external', 'y other sta', 'l buddhism ', 'evices poss']
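
To make the batch layout concrete, here is a minimal sketch of how such a generator might work. It assumes that each of the `batch_size` sequences reads from its own evenly spaced cursor in the text, and that consecutive calls overlap by one character so that the last input of one call becomes the first of the next. `BatchGeneratorSketch` and the local `char2id` are hypothetical names; the actual `pr.BatchGenerator` is not shown here and may differ.

```python
import numpy as np


def char2id(char):
    # Assumed mapping: ' ' -> 0, 'a'..'z' -> 1..26.
    return 0 if char == ' ' else ord(char) - ord('a') + 1


class BatchGeneratorSketch(object):
    def __init__(self, text, batch_size, num_unrollings, vocabulary_size):
        self._text = text
        self._text_size = len(text)
        self._batch_size = batch_size
        self._num_unrollings = num_unrollings
        self._vocabulary_size = vocabulary_size
        # Each sequence in the batch starts at its own evenly spaced position.
        segment = self._text_size // batch_size
        self._cursor = [offset * segment for offset in range(batch_size)]
        self._last_batch = self._next_batch()

    def _next_batch(self):
        """One-hot encode the character under each cursor, then advance the cursors."""
        batch = np.zeros((self._batch_size, self._vocabulary_size), dtype=np.float32)
        for b in range(self._batch_size):
            batch[b, char2id(self._text[self._cursor[b]])] = 1.0
            self._cursor[b] = (self._cursor[b] + 1) % self._text_size
        return batch

    def next(self):
        """Return num_unrollings + 1 matrices; the first repeats the previous call's last."""
        batches = [self._last_batch]
        for _ in range(self._num_unrollings):
            batches.append(self._next_batch())
        self._last_batch = batches[-1]
        return batches


gen = BatchGeneratorSketch('abcdefgh', batch_size=2, num_unrollings=3, vocabulary_size=27)
first = gen.next()
print(len(first), first[0].shape)
```

On this toy text the two sequences are `abcd` and `efgh`: a list of `num_unrollings + 1` one-hot matrices, each of shape `(batch_size, vocabulary_size)`, which matches the shapes printed for `train_batches` above.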


### Useful functions to evaluate the model

In [11]:

def logprob(predictions, labels):
    """Average cross entropy of the predictions in a batch."""
    predictions[predictions < 1e-10] = 1e-10
    return np.sum(np.multiply(labels, -np.log(predictions))) / labels.shape[0]


def sample_distribution(distribution):
    """Sample one element from a distribution assumed to be an array of normalized
        probabilities.
    """
        
    r = random.uniform(0,1)
    s = 0
    for i in range(len(distribution)):
        s += distribution[i]
        if s >= r:
            return i
    return len(distribution) - 1


def sample(prediction):
    """Turn a prediction into a 1-hot encoded sample."""
    p = np.zeros(shape=[1, v_size], dtype=float)  # np.float was removed in recent NumPy releases
    p[0, sample_distribution(prediction[0])] = 1.0
    return p


def random_distribution():
    """Generate at random a discrete pmf for the characters"""
    b = np.random.uniform(0.0, 1.0, size=[1, v_size])
    return b / np.sum(b, 1)[:, None]
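
As a sanity check on `logprob`, consider a model that knows nothing and assigns probability 1/27 to every character. Its cross entropy is log(27), so its perplexity is 27, which is the value an untrained model should start close to. The sketch below restates `logprob` without the in-place modification so that it runs on its own; the targets are arbitrary.

```python
import numpy as np


def logprob(predictions, labels):
    """Average cross entropy (in nats) over the rows of a batch."""
    predictions = np.maximum(predictions, 1e-10)  # avoid log(0)
    return np.sum(labels * -np.log(predictions)) / labels.shape[0]


v_size = 27
uniform = np.full((4, v_size), 1.0 / v_size)  # a model with no knowledge
labels = np.eye(v_size)[[1, 0, 26, 5]]        # four arbitrary one-hot targets
perplexity = float(np.exp(logprob(uniform, labels)))
print('Perplexity of a uniform model: %.2f' % perplexity)
```

Perplexity can be read as the effective number of characters the model is choosing between at each step, so lower is better and 1 is a perfect prediction.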


## LSTM model


<img src="./files/LSTM_full_4.png" width="800" height="400">

We will use a cross-entropy loss function between $\hat{y}^{(t)}$ and the true one-hot labels.
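
For reference, the cell update can be written out as follows. This is a reconstruction from the comments in the implementation cell (gates $i$, $f$, $o$, candidate $g$, state $s$, output $h$), not a transcription of the figure. It assumes the row-vector convention of the code, where $x^{(t)}$ is the one-hot input and $\odot$ is the element-wise product:

$$
\begin{aligned}
i^{(t)} &= \sigma\left(x^{(t)} W^{ix} + h^{(t-1)} W^{ih} + b_i\right) \\
f^{(t)} &= \sigma\left(x^{(t)} W^{fx} + h^{(t-1)} W^{fh} + b_f\right) \\
g^{(t)} &= \tanh\left(x^{(t)} W^{gx} + h^{(t-1)} W^{gh} + b_g\right) \\
o^{(t)} &= \sigma\left(x^{(t)} W^{ox} + h^{(t-1)} W^{oh} + b_o\right) \\
s^{(t)} &= f^{(t)} \odot s^{(t-1)} + i^{(t)} \odot g^{(t)} \\
h^{(t)} &= o^{(t)} \odot \tanh\left(s^{(t)}\right) \\
\hat{y}^{(t)} &= \mathrm{softmax}\left(h^{(t)} W + b\right)
\end{aligned}
$$

With $B$ sequences per batch, $T$ unrolled steps and $K$ characters in the vocabulary, the loss averaged over all predictions in a batch is

$$
\mathcal{L} = -\frac{1}{BT} \sum_{n=1}^{B} \sum_{t=1}^{T} \sum_{k=1}^{K} y_{n,k}^{(t)} \log \hat{y}_{n,k}^{(t)}
$$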

About the TF implementation below, see the following excellent [post](http://www.thushv.com/sequential_modelling/long-short-term-memory-lstm-networks-implementing-with-tensorflow-part-2/)

About the zip() and zip(*) operators, see the [Python documentation](https://docs.python.org/2/library/functions.html#zip)
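
As a quick illustration of the two operators: `zip()` pairs up the elements of several sequences, and `zip(*...)` transposes a list of sequences. The toy data below is invented for this example. It mimics the shape of the batches in this notebook: a list of time steps, each holding one character per sequence.

```python
# Three time steps, each holding one character for each of two sequences.
steps = [['c', 'n'], ['a', 'm'], ['t', 'e']]

# zip(*steps) regroups the data per sequence instead of per time step.
sequences = [''.join(chars) for chars in zip(*steps)]
print(sequences)  # ['cat', 'nme']

# zip() pairs each sequence with its index in the batch.
indexed = list(zip(range(len(sequences)), sequences))
print(indexed)  # [(0, 'cat'), (1, 'nme')]
```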

In [12]:
num_nodes = 64

graph = tf.Graph()
with graph.as_default():
  
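    # Parameters.
    # Note: tf.truncated_normal(shape, mean, stddev) takes its arguments positionally,
    # so the calls below draw weights with mean -0.1 and standard deviation 0.1.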
    
    #i(t) parameters
    # Input gate: input, previous output, and bias.
    ix = tf.Variable(tf.truncated_normal([v_size, num_nodes], -0.1, 0.1))   ##W^ix
    im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1)) ## W^ih
    ib = tf.Variable(tf.zeros([1, num_nodes])) ##b_i
    
    #f(t) parameters
    # Forget gate: input, previous output, and bias.
    fx = tf.Variable(tf.truncated_normal([v_size, num_nodes], -0.1, 0.1)) ##W^fx
    fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1)) ##W^fh
    fb = tf.Variable(tf.zeros([1, num_nodes])) ##b_f
    
    #g(t) parameters
    # Memory cell: input, state and bias.                             
    cx = tf.Variable(tf.truncated_normal([v_size, num_nodes], -0.1, 0.1)) ##W^gx
    cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1)) ##W^gh
    cb = tf.Variable(tf.zeros([1, num_nodes]))  ##b_g
    
    #o(t) parameters
    # Output gate: input, previous output, and bias.
    ox = tf.Variable(tf.truncated_normal([v_size, num_nodes], -0.1, 0.1))  ##W^ox
    om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))  ##W^oh
    ob = tf.Variable(tf.zeros([1, num_nodes])) ##b_o
    
    # Variables saving state across unrollings.
    saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False) #h(t)
    saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False) #s(t)
    
    
    # Classifier weights and biases (over h(t) to labels)
    w = tf.Variable(tf.truncated_normal([num_nodes, v_size], -0.1, 0.1))
    b = tf.Variable(tf.zeros([v_size]))
  
    # Definition of the cell computation.
    def lstm_cell(i, o, state):

        """Create an LSTM cell. See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf
        Note that in this formulation, we omit the various connections between the
        previous state and the gates."""
        
        input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
        forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
        update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb       
        state = forget_gate * state + input_gate * tf.tanh(update)    #tf.tanh(update) is g(t)
        output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)
        return output_gate * tf.tanh(state), state      #h(t) is output_gate * tf.tanh(state)

    # Input data is a list of placeholders!
    # The train labels are just the inputs shifted by one time step
    
    train_data = list()
    for _ in range(num_unrollings + 1):
        train_data.append(tf.placeholder(tf.float32, shape=[batch_size,v_size]))
    train_inputs = train_data[:num_unrollings]
    train_labels = train_data[1:]  

    # Unrolled LSTM loop
    
    outputs = list()
    output = saved_output
    state = saved_state
    for i in train_inputs:
        output, state = lstm_cell(i, output, state)
        outputs.append(output)
        
    
    # State saving across unrollings.
    with tf.control_dependencies([saved_output.assign(output),saved_state.assign(state)]):
        # tf.control_dependencies forces the two assign ops (which store the final output and
        # state for the next batch) to run before the logits and the loss are evaluated.
        #Classifier.
        logits = tf.nn.xw_plus_b(tf.concat(axis=0,values=outputs), w, b)
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=tf.concat(axis=0, values=train_labels),logits=logits))

    # Optimizer.
    
    optimizer = tf.train.AdamOptimizer(learning_rate=5e-03).minimize(loss) 

    # Predictions.
    train_prediction = tf.nn.softmax(logits)
  
    # Sampling and validation eval
    sample_input = tf.placeholder(tf.float32, shape=[1, v_size])
    
    saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))
    saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))
    
    # Create an op that groups multiple operations.
    reset_sample_state = tf.group(saved_sample_output.assign(tf.zeros([1, num_nodes])),
                                  saved_sample_state.assign(tf.zeros([1, num_nodes])))
    
    sample_output, sample_state = lstm_cell(sample_input, saved_sample_output, saved_sample_state)
    
    with tf.control_dependencies([saved_sample_output.assign(sample_output),saved_sample_state.assign(sample_state)]):
        sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))
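

To check the shapes flowing through `lstm_cell`, the same update can be sketched in plain NumPy. The helper names (`lstm_step`, `params`) are made up for this example, and the weights are drawn uniformly from [-0.1, 0.1] for simplicity, which is not the initialisation used in the graph above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h_prev, s_prev, params):
    """One LSTM step; params maps a gate name to (W_x, W_h, bias)."""
    def pre(name):
        w_x, w_h, bias = params[name]
        return np.dot(x, w_x) + np.dot(h_prev, w_h) + bias

    i = sigmoid(pre('i'))      # input gate
    f = sigmoid(pre('f'))      # forget gate
    g = np.tanh(pre('g'))      # candidate memory
    o = sigmoid(pre('o'))      # output gate
    s = f * s_prev + i * g     # new cell state
    h = o * np.tanh(s)         # new output
    return h, s


rng = np.random.RandomState(0)
batch_size, v_size, num_nodes = 64, 27, 64
params = {name: (rng.uniform(-0.1, 0.1, (v_size, num_nodes)),
                 rng.uniform(-0.1, 0.1, (num_nodes, num_nodes)),
                 np.zeros((1, num_nodes)))
          for name in 'ifgo'}

# A batch of random one-hot characters, reused at every step for illustration.
x = np.eye(v_size)[rng.randint(0, v_size, size=batch_size)]
h = np.zeros((batch_size, num_nodes))
s = np.zeros((batch_size, num_nodes))
for _ in range(10):  # unroll 10 steps, as num_unrollings does in the graph
    h, s = lstm_step(x, h, s, params)
print(h.shape, s.shape)
```

Both the output and the state keep the shape `(batch_size, num_nodes)` at every step, which is why they can be stored in `saved_output` and `saved_state` between batches.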


In [13]:
num_steps = 10001
summary_frequency = 100

with tf.Session(graph=graph) as session:
    tf.global_variables_initializer().run()
    print('Initialized')
    mean_loss = 0
    for step in range(num_steps):
        batches = train_batches.next()
        feed_dict = dict()
        
        # List of input placeholders
        for i in range(num_unrollings + 1):
            feed_dict[train_data[i]] = batches[i]
            
        # We run the optimizer    
        _, l, predictions = session.run([optimizer, loss, train_prediction], feed_dict=feed_dict)
        
        mean_loss += l
        if step % summary_frequency == 0:
            
            if step > 0:
                mean_loss /= summary_frequency
                
            # The mean loss is an estimate of the loss over the last few batches.
            print('Average loss at step %d: %f' % (step, mean_loss))
            
            mean_loss = 0
            
            labels = np.concatenate(list(batches)[1:])
            
            # The perplexity is simply the exponential of the cross entropy. The smaller the better
            print('Minibatch perplexity: %.2f' % float(np.exp(logprob(predictions, labels))))
            
            if step % (summary_frequency * 10) == 0:
                # Generate some samples.
                print('=' * 80)
                for _ in range(5):
                    
                    feed = sample(random_distribution())
                    sentence = pr.characters(feed)[0] # decode the one-hot encoding into a character
                    reset_sample_state.run()
                    
                    # We sample one by one
                    for _ in range(79):
                        prediction = sample_prediction.eval({sample_input: feed})
                        feed = sample(prediction)
                        sentence += pr.characters(feed)[0]
                    print(sentence)
                print('=' * 80)
                
                # Measure validation set perplexity.
                reset_sample_state.run()
                valid_logprob = 0

                for _ in range(valid_size):
                    b = valid_batches.next()
                    predictions = sample_prediction.eval({sample_input: b[0]})
                    valid_logprob = valid_logprob + logprob(predictions, b[1])
                    
                print('Validation set perplexity: %.2f' % float(np.exp(valid_logprob / valid_size)))


Initialized
Average loss at step 0: 3.295102
Minibatch perplexity: 26.98
wxbitywrnxbhvnxoqi toarizdv ijcipltstyly ddigukrlxjemkxjhww xvxhcoag rdl ejrqzlk
qaztpqqv ykexorvbqwjxroihxrxsscszuxsh cvosgzwdgbssaifzbziueihjrcwjebimfbgjqqcstb
jymffuppqyysrxouagrhdcttjovwsyqnzuwogcxnnysfutqxybeqkw fmrzeguepzbchrwnjjyrhinbu
oayrh vdiwtcnvjnxduwzqoeglywyqhjkfbmctedieuyttmaarqzqmincjj lsiadskjzzu pyogyqzz
gsjokabstrmc tmgzo mtdlvxlmykqjxuinuz tpwtskjturhmgsjlqufgrxsnkhsubtmvjoy zxwuak
Validation set perplexity: 26.63
Average loss at step 100: 2.721629
Minibatch perplexity: 10.74
Average loss at step 200: 2.295853
Minibatch perplexity: 10.13
Average loss at step 300: 2.163366
Minibatch perplexity: 8.39
Average loss at step 400: 2.066477
Minibatch perplexity: 7.96
Average loss at step 500: 2.001935
Minibatch perplexity: 6.90
Average loss at step 600: 1.967257
Minibatch perplexity: 6.70
Average loss at step 700: 1.910152
Minibatch perplexity: 5.61
Average loss at step 800: 1.872794
Minibatch perplexi

Average loss at step 6400: 1.555308
Minibatch perplexity: 5.31
Average loss at step 6500: 1.570216
Minibatch perplexity: 4.89
Average loss at step 6600: 1.610280
Minibatch perplexity: 4.80
Average loss at step 6700: 1.588376
Minibatch perplexity: 4.70
Average loss at step 6800: 1.614277
Minibatch perplexity: 5.17
Average loss at step 6900: 1.585228
Minibatch perplexity: 4.58
Average loss at step 7000: 1.587541
Minibatch perplexity: 4.79
lises worls in have was orgining action on continten in the preselds societ one 
xed im for one nine seven bongia filemor ja regagemal like so is defed the chain
querwe bonforth sund filom balania sangdly be officiajiva a chip layer and partu
wors every to forbidings part he seem of the gallory but devellations of mill lo
peldy the was known typic field beeng in matur os lingu the class the knein when
Validation set perplexity: 4.45
Average loss at step 7100: 1.588127
Minibatch perplexity: 4.83
Average loss at step 7200: 1.577650
Minibatch perplexity: 4